主页 > 软件开发  > 

transformer(4):FFN编码器块

transformer(4):FFN编码器块

文章目录 FFN1.目的2.代码 编码器块

FFN 1.目的

注意力机制捕捉了序列里面不同位置的相关关系,并没有加强非线性表达能力,所以添加FFN用于增强非线性表达能力。

2.代码 class Pos_FFN(nn.Module): def __init__(self, *args, **kwargs) -> None: super(Pos_FFN, self).__init__(*args, **kwargs) self.lin_1 = nn.Linear(num_hiddens, 1024, bias=False) self.relu1 = nn.ReLU() self.lin_2 = nn.Linear(1024, num_hiddens, bias=False) self.relu2 = nn.ReLU() def forward(self, X): X = self.lin_1(X) X = self.relu1(X) X = self.lin_2(X) X = self.relu2(X) # 可写可不写 return X 编码器块

class Encoder_block(nn.Module): def __init__(self, *args, **kwargs) -> None: super(Encoder_block, self).__init__(*args, **kwargs) self.attention = Attention_block() self.add_norm_1 = AddNorm() self.FFN = Pos_FFN() self.add_norm_2 = AddNorm() def forward(self, X, I_m): I_m = I_m.unsqueeze(-2) X_1 = self.attention(X, I_m) X = self.add_norm_1(X, X_1) X_1 = self.FFN(X) X = self.add_norm_2(X, X_1) return X
标签:

transformer(4):FFN编码器块由讯客互联软件开发栏目发布,感谢您对讯客互联的认可,以及对我们原创作品以及文章的青睐,非常欢迎各位朋友分享到个人网站或者朋友圈,但转载请说明文章出处“transformer(4):FFN编码器块