haiku实现三角乘法模块
- 其他
- 2025-08-06 02:00:04

三角乘法(TriangleMultiplication)模块出自 AlphaFold2(Jumper et al., 2021),被设计为三角注意力(TriangleAttention)的一种更对称、计算开销更低的替代模块。
import haiku as hk  # The code below refers to Haiku as `hk`.
import jax
import jax.numpy as jnp

# NOTE(review): `common_modules` and `utils` are used below but are not
# imported anywhere in this snippet. They look like AlphaFold's
# `alphafold.model.common_modules` / `alphafold.model.utils` -- import them
# from wherever your project provides them, otherwise building the module
# raises NameError.


def _layer_norm(axis=-1, name='layer_norm'):
    """Build the LayerNorm used by the fused triangle-multiplication path.

    Args:
      axis: Axis to normalize over; also used as the parameter axis.
      name: Haiku module name of the layer.

    Returns:
      A `common_modules.LayerNorm` with a learned scale (initialized to 1),
      a learned offset (initialized to 0), eps=1e-5 and fast variance.
    """
    return common_modules.LayerNorm(
        axis=axis,
        create_scale=True,
        create_offset=True,
        eps=1e-5,
        use_fast_variance=True,
        scale_init=hk.initializers.Constant(1.),
        offset_init=hk.initializers.Constant(0.),
        param_axis=axis,
        name=name)


class TriangleMultiplication(hk.Module):
    """Triangle multiplication layer ("outgoing" or "incoming").

    Jumper et al. (2021) Suppl. Alg. 11 "TriangleMultiplicationOutgoing"
    Jumper et al. (2021) Suppl. Alg. 12 "TriangleMultiplicationIncoming"

    The variant is selected by the einsum string `config.equation`;
    `config.fuse_projection_weights` selects the fused or unfused
    parameterisation.
    """

    def __init__(self, config, global_config, name='triangle_multiplication'):
        """Store the configs.

        Args:
          config: Module config; `fuse_projection_weights`,
            `num_intermediate_channel` and `equation` are read from it.
          global_config: Global config, passed to `utils.final_init`.
          name: Haiku module name.
        """
        super().__init__(name=name)
        self.config = config
        self.global_config = global_config

    def __call__(self, left_act, left_mask, is_training=True):
        """Builds TriangleMultiplication module.

        Arguments:
          left_act: Pair activations, shape [N_res, N_res, c_z]
          left_mask: Pair mask, shape [N_res, N_res].
          is_training: Whether the module is in training mode (unused).

        Returns:
          Outputs, same shape/type as left_act.
        """
        del is_training  # This layer does not depend on the training mode.
        if self.config.fuse_projection_weights:
            return self._fused_triangle_multiplication(left_act, left_mask)
        return self._triangle_multiplication(left_act, left_mask)

    # `hk.transparent` stops Haiku from opening an extra name scope for this
    # method: modules/parameters created inside it are named as if they were
    # created directly in the enclosing module's scope.
    @hk.transparent
    def _triangle_multiplication(self, left_act, left_mask):
        """Implementation of TriangleMultiplication used in AF2 and AF-M<2.3."""
        c = self.config
        gc = self.global_config

        # Add a trailing channel axis so the mask broadcasts over channels.
        mask = left_mask[..., None]

        # Plain LayerNorm here (not `_layer_norm`), unlike the fused path --
        # presumably to match the parameters of the older models; TODO confirm.
        act = common_modules.LayerNorm(
            axis=[-1],
            create_scale=True,
            create_offset=True,
            name='layer_norm_input')(left_act)
        input_act = act

        left_projection = common_modules.Linear(
            c.num_intermediate_channel, name='left_projection')
        left_proj_act = mask * left_projection(act)

        right_projection = common_modules.Linear(
            c.num_intermediate_channel, name='right_projection')
        right_proj_act = mask * right_projection(act)

        left_gate_values = jax.nn.sigmoid(common_modules.Linear(
            c.num_intermediate_channel,
            bias_init=1.,
            initializer=utils.final_init(gc),
            name='left_gate')(act))

        right_gate_values = jax.nn.sigmoid(common_modules.Linear(
            c.num_intermediate_channel,
            bias_init=1.,
            initializer=utils.final_init(gc),
            name='right_gate')(act))

        left_proj_act *= left_gate_values
        right_proj_act *= right_gate_values

        # "Outgoing" edges equation: 'ikc,jkc->ijc'
        # "Incoming" edges equation: 'kjc,kic->ijc'
        # Note on the Suppl. Alg. 11 & 12 notation:
        # For the "outgoing" edges, a = left_proj_act and b = right_proj_act
        # For the "incoming" edges, it's swapped:
        #   b = left_proj_act and a = right_proj_act
        act = jnp.einsum(c.equation, left_proj_act, right_proj_act)

        act = common_modules.LayerNorm(
            axis=[-1],
            create_scale=True,
            create_offset=True,
            name='center_layer_norm')(act)

        # Project back to the input channel count (c_z).
        output_channel = int(input_act.shape[-1])

        act = common_modules.Linear(
            output_channel,
            initializer=utils.final_init(gc),
            name='output_projection')(act)

        # The output gate is computed from the normalised input, not `act`.
        gate_values = jax.nn.sigmoid(common_modules.Linear(
            output_channel,
            bias_init=1.,
            initializer=utils.final_init(gc),
            name='gating_linear')(input_act))
        act *= gate_values

        return act

    @hk.transparent
    def _fused_triangle_multiplication(self, left_act, left_mask):
        """TriangleMultiplication with fused projection weights."""
        mask = left_mask[..., None]
        c = self.config
        gc = self.global_config
        num_channels = c.num_intermediate_channel

        # Note: `left_act` is rebound to the normalised input; the output
        # projection size and the output gate below both use this value.
        left_act = _layer_norm(axis=-1, name='left_norm_input')(left_act)

        # Both left and right projections are fused into projection.
        projection = common_modules.Linear(
            2 * num_channels, name='projection')
        proj_act = mask * projection(left_act)

        # Both left + right gate are fused into gate_values.
        gate_values = common_modules.Linear(
            2 * num_channels,
            name='gate',
            bias_init=1.,
            initializer=utils.final_init(gc))(left_act)
        proj_act *= jax.nn.sigmoid(gate_values)

        # Split the fused channel axis back into the left and right halves.
        left_proj_act = proj_act[..., :num_channels]
        right_proj_act = proj_act[..., num_channels:]
        act = jnp.einsum(c.equation, left_proj_act, right_proj_act)

        act = _layer_norm(axis=-1, name='center_norm')(act)

        output_channel = int(left_act.shape[-1])

        act = common_modules.Linear(
            output_channel,
            initializer=utils.final_init(gc),
            name='output_projection')(act)

        gate_values = common_modules.Linear(
            output_channel,
            bias_init=1.,
            initializer=utils.final_init(gc),
            name='gating_linear')(left_act)
        act *= jax.nn.sigmoid(gate_values)

        return act
haiku实现三角乘法模块由讯客互联其他栏目发布,感谢您对讯客互联的认可,以及对我们原创作品以及文章的青睐,非常欢迎各位朋友分享到个人网站或者朋友圈,但转载请说明文章出处“haiku实现三角乘法模块”