RMSNorm模块
- IT业界
- 2025-09-10 23:27:01

目录 代码代码解释1. 初始化方法 `__init__`2. 前向传播方法 `forward`3. 总结4. 使用场景 可视化 代码 class RMSNorm(torch.nn.Module): def __init__(self, dim: int, eps: float): super().__init__() self.eps = eps self.weight = nn.Parameter(torch.ones(dim)) def forward(self, x): return self.weight * ( x.float() * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) ).type_as(x) 代码解释
这段代码定义了一个自定义的PyTorch模块 RMSNorm,用于实现Root Mean Square Normalization (RMSNorm)。RMSNorm是一种归一化技术,类似于Layer Normalization,但它只对输入进行缩放,而不进行平移(即没有偏置项)。下面是代码的详细解释:
1. 初始化方法 __init__ def __init__(self, dim: int, eps: float): super().__init__() self.eps = eps self.weight = nn.Parameter(torch.ones(dim)) dim: int: 输入特征的维度。eps: float: 一个小常数,用于数值稳定性,避免除以零的情况。self.weight: 一个可学习的参数,形状为 (dim,),初始化为全1的张量。这个参数用于对归一化后的输入进行缩放。 2. 前向传播方法 forward def forward(self, x): return self.weight * ( x.float() * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) ).type_as(x) x: 输入张量,形状通常为 (batch_size, ..., dim)。x.pow(2): 对输入 x 的每个元素求平方。x.pow(2).mean(-1, keepdim=True): 沿着最后一个维度(即特征维度 dim)计算平方的均值,并保持维度不变。结果形状为 (batch_size, ..., 1)。torch.rsqrt(...): 计算均方根的倒数(即1除以平方根),用于归一化。x.float() * torch.rsqrt(...): 将输入 x 转换为浮点数后,乘以均方根的倒数,得到归一化后的结果。.type_as(x): 将结果转换回与输入 x 相同的数据类型。self.weight * (...): 最后,将归一化后的结果乘以可学习的权重 self.weight,进行缩放。 3. 总结 RMSNorm 通过对输入进行归一化,使得每个特征的均方根值为1,然后通过可学习的权重进行缩放。与LayerNorm不同,RMSNorm没有偏置项,只进行缩放操作。eps 用于防止除以零的情况,增加数值稳定性。 4. 使用场景RMSNorm通常用于深度学习模型中,特别是在Transformer架构中,作为LayerNorm的替代方案。它可以加速训练并提高模型的稳定性。
可视化 dim = 64 eps = 1e-5 m = RMSNorm(dim, eps) x = torch.randn(32, 10, dim) # 示例输入 (batch_size, seq_len, dim) f = "rms_norm.onnx" # 导出的 ONNX 文件名 torch.onnx.export(m, x, f) # 模型 # 示例输入在 netron.app/ 上打开 rms_norm.onnx