MLP

MLP线性

在LlamaMLP类中,门控投影(gate projection)、上升投影(up projection)和下降投影(down projection)是前馈神经网络(FFN)中的三个关键步骤。这些步骤主要用于处理输入张量的变换,从而实现非线性特征提取。以下是对这三个步骤的详细解释:

门控投影(Gate Projection)

门控投影是一种线性变换,其作用是为后续的计算提供一组变换后的特征,这些特征将与上升投影的输出进行元素级别的相乘(门控操作)。

公式: [ \text{gate_proj}(x) = W_{\text{gate}} x + b_{\text{gate}} ]

其中:

  • (Wgate)( W_{\text{gate}} ) 是门控投影的权重矩阵。

  • (bgate)( b_{\text{gate}} ) 是门控投影的偏置项。

  • (x)( x ) 是输入张量。

上升投影(Up Projection)

上升投影也是一种线性变换,其主要作用是将输入张量的维度从隐藏层尺寸增加到中间层尺寸。这一步骤增加了模型处理的特征数量,为接下来的非线性变换提供了更多的特征表示。

公式: [ \text{up_proj}(x) = W_{\text{up}} x + b_{\text{up}} ]

其中:

  • (Wup)( W_{\text{up}} ) 是上升投影的权重矩阵。

  • (bup)( b_{\text{up}} ) 是上升投影的偏置项。

下降投影(Down Projection)

下降投影是最后一步线性变换,其作用是将之前上升到中间层尺寸的特征重新压缩回隐藏层尺寸。这一步骤使得输出的维度与输入的维度相匹配,从而可以继续进行下一层的计算。

公式: [ \text{down_proj}(x) = W_{\text{down}} x + b_{\text{down}} ]

其中:

  • (Wdown)( W_{\text{down}} ) 是下降投影的权重矩阵。

  • (bdown)( b_{\text{down}} ) 是下降投影的偏置项。

整体流程伪代码

class LlamaMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        
        # 定义线性层
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
        
        # 选择激活函数
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        # 线性变换:门控投影
        gate_proj_output = self.gate_proj(x)
        
        # 线性变换:上升投影
        up_proj_output = self.up_proj(x)
        
        # 元素级别相乘,并通过激活函数变换
        intermediate_output = self.act_fn(gate_proj_output) * up_proj_output
        
        # 线性变换:下降投影
        final_output = self.down_proj(intermediate_output)
        
        return final_output

解释

  1. 门控投影(Gate Projection):计算出一组特征,通过激活函数进行非线性变换。

  2. 上升投影(Up Projection):增加特征维度,为后续计算提供更多的信息。

  3. 元素级别相乘:将门控投影的结果与上升投影的结果相乘,实现门控机制。

  4. 下降投影(Down Projection):将特征维度压缩回原始隐藏层尺寸,准备进入下一层的计算。

通过这三个步骤的协同作用,MLP能够在保持输入-输出维度一致的前提下,进行复杂的特征映射,从而增强模型的表达能力。

最后更新于