
Implementing the Triangle Multiplication Module with Haiku


Triangle multiplication (TriangleMultiplication) was designed as a more symmetric and computationally cheaper alternative to triangle attention (TriangleAttention) for updating the pair representation.
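At its core, the update is a single einsum over the third residue k: for every pair (i, j), information is aggregated either along the edges leaving i and j ("outgoing") or along the edges arriving at them ("incoming"). A minimal sketch of the two contraction patterns used by the module below (array names are illustrative):

```python
import jax
import jax.numpy as jnp

N_res, c = 4, 2
key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.uniform(key_a, (N_res, N_res, c))  # stands in for the left projection
b = jax.random.uniform(key_b, (N_res, N_res, c))  # stands in for the right projection

# "Outgoing" edges: out[i, j, c] = sum_k a[i, k, c] * b[j, k, c]
outgoing = jnp.einsum('ikc,jkc->ijc', a, b)
# "Incoming" edges: out[i, j, c] = sum_k a[k, j, c] * b[k, i, c]
incoming = jnp.einsum('kjc,kic->ijc', a, b)

assert outgoing.shape == incoming.shape == (N_res, N_res, c)
```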

```python
import haiku as hk
import jax
import jax.numpy as jnp

# The LayerNorm/Linear wrappers and final_init come from the AlphaFold model
# code; adjust these imports to wherever those helpers live in your project.
from alphafold.model import common_modules
from alphafold.model import utils


def _layer_norm(axis=-1, name='layer_norm'):
  return common_modules.LayerNorm(
      axis=axis,
      create_scale=True,
      create_offset=True,
      eps=1e-5,
      use_fast_variance=True,
      scale_init=hk.initializers.Constant(1.),
      offset_init=hk.initializers.Constant(0.),
      param_axis=axis,
      name=name)


class TriangleMultiplication(hk.Module):
  """Triangle multiplication layer ("outgoing" or "incoming").

  Jumper et al. (2021) Suppl. Alg. 11 "TriangleMultiplicationOutgoing"
  Jumper et al. (2021) Suppl. Alg. 12 "TriangleMultiplicationIncoming"
  """

  def __init__(self, config, global_config, name='triangle_multiplication'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self, left_act, left_mask, is_training=True):
    """Builds TriangleMultiplication module.

    Arguments:
      left_act: Pair activations, shape [N_res, N_res, c_z].
      left_mask: Pair mask, shape [N_res, N_res].
      is_training: Whether the module is in training mode.

    Returns:
      Outputs, same shape/type as left_act.
    """
    del is_training

    if self.config.fuse_projection_weights:
      return self._fused_triangle_multiplication(left_act, left_mask)
    else:
      return self._triangle_multiplication(left_act, left_mask)

  # @hk.transparent is a Haiku decorator that makes the method "transparent":
  # it does not open a new module scope, so parameters created inside it are
  # registered under the enclosing TriangleMultiplication module's name, which
  # allows the two implementation paths to live in the same parameter tree.
  @hk.transparent
  def _triangle_multiplication(self, left_act, left_mask):
    """Implementation of TriangleMultiplication used in AF2 and AF-M < 2.3."""
    c = self.config
    gc = self.global_config

    mask = left_mask[..., None]

    act = common_modules.LayerNorm(axis=[-1], create_scale=True,
                                   create_offset=True,
                                   name='layer_norm_input')(left_act)
    input_act = act

    left_projection = common_modules.Linear(
        c.num_intermediate_channel,
        name='left_projection')
    left_proj_act = mask * left_projection(act)

    right_projection = common_modules.Linear(
        c.num_intermediate_channel,
        name='right_projection')
    right_proj_act = mask * right_projection(act)

    left_gate_values = jax.nn.sigmoid(common_modules.Linear(
        c.num_intermediate_channel,
        bias_init=1.,
        initializer=utils.final_init(gc),
        name='left_gate')(act))

    right_gate_values = jax.nn.sigmoid(common_modules.Linear(
        c.num_intermediate_channel,
        bias_init=1.,
        initializer=utils.final_init(gc),
        name='right_gate')(act))

    left_proj_act *= left_gate_values
    right_proj_act *= right_gate_values

    # "Outgoing" edges equation: 'ikc,jkc->ijc'
    # "Incoming" edges equation: 'kjc,kic->ijc'
    # Note on the Suppl. Alg. 11 & 12 notation:
    # For the "outgoing" edges, a = left_proj_act and b = right_proj_act
    # For the "incoming" edges, it's swapped:
    #   b = left_proj_act and a = right_proj_act
    act = jnp.einsum(c.equation, left_proj_act, right_proj_act)

    act = common_modules.LayerNorm(
        axis=[-1],
        create_scale=True,
        create_offset=True,
        name='center_layer_norm')(
            act)

    output_channel = int(input_act.shape[-1])

    act = common_modules.Linear(
        output_channel,
        initializer=utils.final_init(gc),
        name='output_projection')(act)

    gate_values = jax.nn.sigmoid(common_modules.Linear(
        output_channel,
        bias_init=1.,
        initializer=utils.final_init(gc),
        name='gating_linear')(input_act))
    act *= gate_values

    return act

  @hk.transparent
  def _fused_triangle_multiplication(self, left_act, left_mask):
    """TriangleMultiplication with fused projection weights."""
    mask = left_mask[..., None]
    c = self.config
    gc = self.global_config

    left_act = _layer_norm(axis=-1, name='left_norm_input')(left_act)

    # Both left and right projections are fused into projection.
    projection = common_modules.Linear(
        2 * c.num_intermediate_channel, name='projection')
    proj_act = mask * projection(left_act)

    # Both left + right gate are fused into gate_values.
    gate_values = common_modules.Linear(
        2 * c.num_intermediate_channel,
        name='gate',
        bias_init=1.,
        initializer=utils.final_init(gc))(left_act)
    proj_act *= jax.nn.sigmoid(gate_values)

    left_proj_act = proj_act[:, :, :c.num_intermediate_channel]
    right_proj_act = proj_act[:, :, c.num_intermediate_channel:]
    act = jnp.einsum(c.equation, left_proj_act, right_proj_act)

    act = _layer_norm(axis=-1, name='center_norm')(act)

    output_channel = int(left_act.shape[-1])

    act = common_modules.Linear(
        output_channel,
        initializer=utils.final_init(gc),
        name='output_projection')(act)

    gate_values = common_modules.Linear(
        output_channel,
        bias_init=1.,
        initializer=utils.final_init(gc),
        name='gating_linear')(left_act)
    act *= jax.nn.sigmoid(gate_values)

    return act
```
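To see how the module is wired up end to end, here is a minimal usage sketch. The config fields mirror exactly what the code above reads (num_intermediate_channel, equation, fuse_projection_weights); the use of ml_collections for the configs and the zero_init flag consumed by utils.final_init are assumptions about the surrounding setup, not part of the listing above.

```python
import haiku as hk
import jax
import jax.numpy as jnp
import ml_collections

# Illustrative config; the fields are the ones TriangleMultiplication reads.
config = ml_collections.ConfigDict({
    'num_intermediate_channel': 64,
    'equation': 'ikc,jkc->ijc',        # "outgoing" edges; 'kjc,kic->ijc' for "incoming"
    'fuse_projection_weights': False,  # True selects the fused code path
})
# Assumption: utils.final_init reads a zero_init flag to zero-init final layers.
global_config = ml_collections.ConfigDict({'zero_init': True})

def forward(pair_act, pair_mask):
  # TriangleMultiplication as defined above; the instance name is illustrative.
  module = TriangleMultiplication(config, global_config,
                                  name='triangle_multiplication_outgoing')
  return module(pair_act, pair_mask)

forward = hk.transform(forward)

N_res, c_z = 16, 32
pair_act = jnp.zeros((N_res, N_res, c_z))  # [N_res, N_res, c_z] pair activations
pair_mask = jnp.ones((N_res, N_res))       # [N_res, N_res] pair mask

rng = jax.random.PRNGKey(0)
params = forward.init(rng, pair_act, pair_mask)
out = forward.apply(params, rng, pair_act, pair_mask)
print(out.shape)  # (16, 16, 32): same shape as the input pair activations
```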

 
