Rotary Embedding
LlamaRotaryEmbedding 是一种位置编码方法,通过对查询(query)和键(key)向量施加旋转变换,使得模型能够将相对位置信息融入到注意力机制中。相对于传统的位置编码,Rotary Embedding 在计算效率和效果上都有提升。
公式解释
LlamaRotaryEmbedding 的核心在于通过旋转矩阵对查询和键向量进行变换,使得位置编码直接作用于词向量。其公式如下:
旋转频率计算: [ \text{inv_freq}[i] = \frac{1}{\text{base}^{\frac{2i}{d}}} ] 其中 ( d ) 为嵌入维度的一半,(\text{base}) 是一个常数(通常为10000)。
位置编码计算: [ \text{pos_enc}_{t, i} = t \cdot \text{inv_freq}[i] ] 其中 ( t ) 是位置序号。
构造旋转矩阵: [ \text{cos_t} = \cos(\text{pos_enc}_t) \quad \text{sin_t} = \sin(\text{pos_enc}_t) ]
应用旋转位置编码: 对于查询和键向量 ( q ) 和 ( k ),其变换公式为: [ q' = q \cdot \cos_t + (-q_{\text{half}}) \cdot \sin_t ] [ k' = k \cdot \cos_t + (-k_{\text{half}}) \cdot \sin_t ] 其中,( q_{\text{half}} ) 表示拆分后的向量 ( q ) 的后一半。
伪代码
class LlamaRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
super().__init__()
self.scaling_factor = scaling_factor
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
# 计算逆频率
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
@torch.no_grad()
def forward(self, x, position_ids):
# 扩展逆频率和位置ID
inv_freq_expanded = self.inv_freq[None, :, None].expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
# 计算位置编码
freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
def apply_rotary_pos_emb(q, k, cos, sin):
def rotate_half(x):
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
# 应用旋转位置编码
q_embed = q * cos + rotate_half(q) * sin
k_embed = k * cos + rotate_half(k)
return q_embed, k_embed
解释
初始化:
self.inv_freq
是一个(1, dim//2)大小的张量,用于存储频率的倒数。self.scaling_factor
是一个缩放因子,用于位置缩放(可选)。
forward 方法:
inv_freq_expanded
扩展了频率倒数,以便与输入位置ID相乘。position_ids_expanded
将位置ID扩展为浮点数,并调整形状以进行矩阵乘法。通过矩阵乘法计算 (\text{freqs}),然后将其拼接为位置编码,并计算 cos 和 sin。
apply_rotary_pos_emb 函数:
定义了一个辅助函数
rotate_half
,用于旋转向量的一半。对查询和键向量施加位置旋转编码 ( q' ) 和 ( k' )。
通过伪代码和公式,你可以了解 LlamaRotaryEmbedding 是如何通过旋转变换将位置信息融入到查询和键向量中的。这种方法相对于传统位置编码具有更好的性能和计算效率。
最后更新于