Rotary Embedding
LlamaRotaryEmbedding 是一种位置编码方法,通过对查询(query)和键(key)向量施加旋转变换,使得模型能够将相对位置信息融入到注意力机制中。相对于传统的位置编码,Rotary Embedding 在计算效率和效果上都有提升。
公式解释
LlamaRotaryEmbedding 的核心在于通过旋转矩阵对查询和键向量进行变换,使得位置编码直接作用于词向量。其公式如下:
旋转频率计算: [ \text{inv_freq}[i] = \frac{1}{\text{base}^{\frac{2i}{d}}} ] 其中 ( d ) 为嵌入维度的一半,(\text{base}) 是一个常数(通常为10000)。
位置编码计算: [ \text{pos_enc}_{t, i} = t \cdot \text{inv_freq}[i] ] 其中 ( t ) 是位置序号。
构造旋转矩阵: [ \text{cos_t} = \cos(\text{pos_enc}_t) \quad \text{sin_t} = \sin(\text{pos_enc}_t) ]
应用旋转位置编码: 对于查询和键向量 ( q ) 和 ( k ),其变换公式为: [ q' = q \cdot \cos_t + (-q_{\text{half}}) \cdot \sin_t ] [ k' = k \cdot \cos_t + (-k_{\text{half}}) \cdot \sin_t ] 其中,( q_{\text{half}} ) 表示拆分后的向量 ( q ) 的后一半。
伪代码
解释
初始化:
self.inv_freq
是一个(1, dim//2)大小的张量,用于存储频率的倒数。self.scaling_factor
是一个缩放因子,用于位置缩放(可选)。
forward 方法:
inv_freq_expanded
扩展了频率倒数,以便与输入位置ID相乘。position_ids_expanded
将位置ID扩展为浮点数,并调整形状以进行矩阵乘法。通过矩阵乘法计算 (\text{freqs}),然后将其拼接为位置编码,并计算 cos 和 sin。
apply_rotary_pos_emb 函数:
定义了一个辅助函数
rotate_half
,用于旋转向量的一半。对查询和键向量施加位置旋转编码 ( q' ) 和 ( k' )。
通过伪代码和公式,你可以了解 LlamaRotaryEmbedding 是如何通过旋转变换将位置信息融入到查询和键向量中的。这种方法相对于传统位置编码具有更好的性能和计算效率。
最后更新于