RMSNorm
RMSNorm
公式解释
伪代码
class LlamaRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
初始化
参数:
- hidden_size: 特征大小,即张量最后一个维度的大小
- eps: 防止除零错误的常数,默认值为1e-6
"""
super().__init__()
# 初始化可学习参数,初始值为1
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
"""
前向传播
参数:
- hidden_states: 输入张量
返回:
- 归一化后的张量
"""
# 保存输入的数据类型
input_dtype = hidden_states.dtype
# 将输入张量转换为浮点数类型以进行计算
hidden_states = hidden_states.to(torch.float32)
# 计算每个特征维度的平方均值
variance = hidden_states.pow(2).mean(-1, keepdim=True)
# 归一化,并乘以可学习参数
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
# 返回归一化后的张量,并将其转换回原来的数据类型
return self.weight * hidden_states.to(input_dtype)解释
最后更新于