MLP
MLP线性
在LlamaMLP类中,门控投影(gate projection)、上升投影(up projection)和下降投影(down projection)是前馈神经网络(FFN)中的三个关键步骤。这些步骤主要用于处理输入张量的变换,从而实现非线性特征提取。以下是对这三个步骤的详细解释:
门控投影(Gate Projection)
门控投影是一种线性变换,其作用是为后续的计算提供一组变换后的特征,这些特征将与上升投影的输出进行元素级别的相乘(门控操作)。
公式:
其中:
是门控投影的权重矩阵。
是门控投影的偏置项。
是输入张量。
上升投影(Up Projection)
上升投影也是一种线性变换,其主要作用是将输入张量的维度从隐藏层尺寸增加到中间层尺寸。这一步骤增加了模型处理的特征数量,为接下来的非线性变换提供了更多的特征表示。
公式:
其中:
是上升投影的权重矩阵。
是上升投影的偏置项。
下降投影(Down Projection)
下降投影是最后一步线性变换,其作用是将之前上升到中间层尺寸的特征重新压缩回隐藏层尺寸。这一步骤使得输出的维度与输入的维度相匹配,从而可以继续进行下一层的计算。
公式:
其中:
是下降投影的权重矩阵。
是下降投影的偏置项。
整体流程伪代码
class LlamaMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
# 定义线性层
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
# 选择激活函数
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x):
# 线性变换:门控投影
gate_proj_output = self.gate_proj(x)
# 线性变换:上升投影
up_proj_output = self.up_proj(x)
# 元素级别相乘,并通过激活函数变换
intermediate_output = self.act_fn(gate_proj_output) * up_proj_output
# 线性变换:下降投影
final_output = self.down_proj(intermediate_output)
return final_output
解释
门控投影(Gate Projection):计算出一组特征,通过激活函数进行非线性变换。
上升投影(Up Projection):增加特征维度,为后续计算提供更多的信息。
元素级别相乘:将门控投影的结果与上升投影的结果相乘,实现门控机制。
下降投影(Down Projection):将特征维度压缩回原始隐藏层尺寸,准备进入下一层的计算。
通过这三个步骤的协同作用,MLP能够在保持输入-输出维度一致的前提下,进行复杂的特征映射,从而增强模型的表达能力。
最后更新于