Hand-Written LLM Code (Interview Prep)
1. RoPE (Rotary Positional Embedding)
Core idea: inject absolute position information into Q and K via rotation matrices, in a way that also yields a relative-position property. Interview focus: the complex-multiplication implementation (rotating the dimensions in pairs).

To inject position information into the model, we implement Rotary Positional Embedding (RoPE) (Su et al., 2021). Consider a token's query vector $q^{(i)}$ at position $i$ in the sequence, with dimension $d$. We apply a pairwise rotation matrix $R^i$ to it, obtaining $q'^{(i)} = R^i q^{(i)} = R^i W_q x^{(i)}$.

That is, $R^i$ treats every two elements of the query vector as a 2D vector and rotates it by the angle $\theta_{i,k}$, where $\theta_{i,k} = \frac{i}{\Theta^{(2k-2)/d}} \quad (k = 1, \ldots, d/2)$, and $\Theta$ is a constant (usually 10000, matching the original Transformer positional encoding).

Therefore $R^i$ can be viewed as a block-diagonal matrix whose 2×2 blocks are $R^i_k= \begin{bmatrix} \cos(\theta_{i,k}) & -\sin(\theta_{i,k}) \\ \sin(\theta_{i,k}) & \cos(\theta_{i,k}) \end{bmatrix}$
So the full rotation matrix $R^i$ is: \(R^i= \begin{bmatrix} R^i_1 & 0 & 0 & ... & 0 \\ 0 & R^i_2 & 0 & ... & 0 \\ 0 & 0 & R^i_3 & ... & 0 \\ ... & ... & ... & ... & ... \\ 0 & 0 & 0 & ... & R^i_{d/2} \end{bmatrix}\)
where each $0$ denotes a 2×2 zero matrix.
Although we could explicitly construct the full $d \times d$ rotation matrix, it is more efficient to exploit its structure and rotate the vector directly.
Moreover, since we only care about the relative rotation between tokens within a sequence, all layers can share the same set of cos and sin tables. This layer therefore usually stores cos and sin with `self.register_buffer(..., persistent=False)` rather than as trainable parameters.
Finally, both Q and K are rotated with their corresponding $R^i$. Note: this layer has no learnable parameters.
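Why this gives a relative-position property: rotation matrices compose as $(R^i)^\top R^j = R^{j-i}$, so the attention score between a rotated query at position $i$ and a rotated key at position $j$ depends only on the offset $j - i$ (a short standard derivation, added here for completeness):

$$
\langle R^i q,\, R^j k \rangle = q^\top (R^i)^\top R^j k = q^\top R^{\,j-i}\, k
$$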
- Direct implementation
import torch
from torch import nn
class Rope(nn.Module):
def __init__(self, theta: float, d_k: int, max_seq_len: int, device=None):
"""
计算$$\theta_{i,k} = i \cdot \frac{\Theta}{10000^{2k/d}}$$
- 构造 RoPE 模块,并在需要时创建缓存(buffers)。
- `theta`: RoPE 中的 Θ 值(控制旋转角度的频率基底)。
- `d_k`: 查询(query)和键(key)向量的维度。
- `max_seq_len`: 输入序列的最大长度。
- `device`: 存储缓存张量的设备(`torch.device` 或 `None`)。
"""
super().__init__()
if d_k % 2 != 0:
raise ValueError("d_k must be even for RoPE")
self.theta = theta
self.d_k = d_k
self.max_seq_len = max_seq_len
self.device = device
f = 1.0 / (theta ** (torch.arange(0, d_k, 2, device=device).float() / d_k))
# position
p = torch.arange(max_seq_len, device=device).float()
# sinusoids
s = torch.outer(p, f)
self.register_buffer("cos_cache", s.cos(), persistent=False)
self.register_buffer("sin_cache", s.sin(), persistent=False)
def forward(self, x: torch.Tensor, token_positions: torch.Tensor) -> torch.Tensor:
"""
Args:
x: shape (*, seq_len, d_k), 输入张量(支持任意批维度)
token_positions: shape (*, seq_len), 每个 token 的绝对位置索引
Returns:
shape (*, seq_len, d_k), 应用旋转编码后的输出
"""
cos = self.cos_cache[token_positions]
sin = self.sin_cache[token_positions]
        # Split the input into even- and odd-indexed dimensions: x = [x_even, x_odd]
        x_even = x[..., 0::2]  # even indices: 0, 2, 4, ...
        x_odd = x[..., 1::2]   # odd indices: 1, 3, 5, ...
        # Apply the rotation formula
        out_even = x_even * cos - x_odd * sin
        out_odd = x_even * sin + x_odd * cos
        # Interleave: stack (even, odd) along a new last dim, then flatten
        out = torch.stack([out_even, out_odd], dim=-1)  # shape: (*, seq_len, d_k//2, 2)
        out = out.flatten(-2)  # shape: (*, seq_len, d_k)
return out
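A minimal usage sketch for the module above (the shapes and values are illustrative, not from the original post):

import torch

rope = Rope(theta=10000.0, d_k=64, max_seq_len=128)
x = torch.randn(2, 16, 64)            # (batch, seq_len, d_k)
pos = torch.arange(16).expand(2, 16)  # (batch, seq_len) absolute positions
out = rope(x, pos)
print(out.shape)                      # torch.Size([2, 16, 64])
# Position 0 has rotation angle 0, so that token is left unchanged.
print(torch.allclose(out[:, 0], x[:, 0]))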
- Complex-number implementation
import torch
import torch.nn as nn
class RoPE(nn.Module):
    def __init__(self, dim, max_seq_len=4096, theta=10000.0):
        super().__init__()
        # Frequencies: theta ^ (-2k/dim)
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        # Positions t: [0, 1, ..., max_seq_len-1]
        t = torch.arange(max_seq_len, device=freqs.device)
        # Outer product of positions and frequencies: (seq_len, dim/2)
        freqs = torch.outer(t, freqs).float()
        # Convert to polar form so the rotation becomes a complex multiplication.
        # freqs_cis: (seq_len, dim/2), complex64; stored as a non-persistent buffer.
        self.register_buffer("freqs_cis", torch.polar(torch.ones_like(freqs), freqs), persistent=False)

    def forward(self, x):
        # x shape: (batch, seq_len, n_heads, head_dim)
        # View x as complex numbers: (batch, seq_len, n_heads, head_dim/2)
        x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        # Take the frequencies for the current sequence length and rely on broadcasting
        freqs_cis = self.freqs_cis[:x.shape[1]].view(1, x.shape[1], 1, -1)
        # Complex multiplication is exactly the rotation
        x_rotated = x_complex * freqs_cis
        # Back to real values and flatten: (batch, seq_len, n_heads, head_dim)
        x_out = torch.view_as_real(x_rotated).flatten(3)
        return x_out.type_as(x)
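As a sanity check (again illustrative, not from the original), the two implementations should agree, since both rotate consecutive pairs $(x_{2k}, x_{2k+1})$ with the same frequencies; for example with a single head:

import torch

d, L = 8, 5
direct = Rope(theta=10000.0, d_k=d, max_seq_len=32)
complex_rope = RoPE(dim=d, max_seq_len=32, theta=10000.0)
x = torch.randn(1, L, d)
pos = torch.arange(L).unsqueeze(0)                      # (1, L)
y_direct = direct(x, pos)                               # (1, L, d)
y_complex = complex_rope(x.unsqueeze(2)).squeeze(2)     # add/remove a dummy head dim
print(torch.allclose(y_direct, y_complex, atol=1e-5))   # expected: True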
2. Attention
import torch
import math
from torch import nn
class RoPE(nn.Module):
def __init__(self, dim, max_seq_len, theta):
        super().__init__()
        f = 1.0 / theta ** (torch.arange(0, dim, 2).float() / dim)
        p = torch.arange(0, max_seq_len).float()
s = torch.outer(p, f)
self.register_buffer("cos_cache", s.cos(), persistent=False)
self.register_buffer("sin_cache", s.sin(), persistent=False)
def forward(self, x, token_positions):
cos = self.cos_cache[token_positions]
sin = self.sin_cache[token_positions]
x_even = x[..., 0::2]
x_odd = x[..., 1::2]
out_even = cos * x_even - sin * x_odd
out_odd = sin * x_even + cos * x_odd
out = torch.stack([out_even, out_odd], dim=-1)
out = out.flatten(-2)
return out
class ScaledDotProductAttention(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, Q, K, V, mask=None):
        # Q, K, V shape: (..., seq_len, d_k)
        # mask: optional bool tensor broadcastable to (..., seq_len, seq_len); True = masked out
        d_k = Q.shape[-1]
        scale = math.sqrt(d_k)
        score = torch.matmul(Q, K.transpose(-2, -1)) / scale
        if mask is not None:
            score = score.masked_fill(mask, float('-inf'))
        attention_weight = torch.softmax(score, dim=-1)
        output = torch.matmul(attention_weight, V)
        return output
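For reference, PyTorch 2.0+ also ships a fused version of this operation, `torch.nn.functional.scaled_dot_product_attention`; a minimal sketch of calling it (the shapes are illustrative):

import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 16, 64)  # (batch, num_heads, seq_len, head_dim)
k = torch.randn(2, 8, 16, 64)
v = torch.randn(2, 8, 16, 64)
# is_causal=True applies the causal mask internally (no explicit mask tensor needed)
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)               # torch.Size([2, 8, 16, 64])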
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
self.w_q = nn.Linear(d_model, d_model)
self.w_k = nn.Linear(d_model, d_model)
self.w_v = nn.Linear(d_model, d_model)
self.w_o = nn.Linear(d_model, d_model)
max_seq_len = 2048
causal_mask = torch.triu(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool), diagonal=1)
        # diagonal=1 keeps True strictly above the main diagonal; everything else is False.
self.register_buffer('causal_mask', causal_mask, persistent=False)
def forward(self, x):
Q = self.w_q(x) # (batch_size, seq_len, d_model)
K = self.w_k(x) # (batch_size, seq_len, d_model)
V = self.w_v(x) # (batch_size, seq_len, d_model)
batch_size, seq_len, _ = x.shape
Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
K = K.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
V = V.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
# (batch_size, num_heads, seq_len, head_dim)
scale = torch.sqrt(torch.tensor(self.head_dim))
scores = torch.matmul(Q, K.transpose(-2, -1)) / scale
# (batch_size, num_heads, seq_len, head_dim) * (batch_size, num_heads, head_dim, seq_len)
# -> (batch_size, num_heads, seq_len, seq_len)
        # if mask is not None:
        #     scores = scores.masked_fill(mask == 0, -1e9)
        # 1. Slice the precomputed mask to the current seq_len
        #    self.causal_mask has shape (max_len, max_len) -> sliced to (seq_len, seq_len)
        causal_mask_slice = self.causal_mask[:seq_len, :seq_len].unsqueeze(0).unsqueeze(0)
        # 2. Fill masked positions with negative infinity
        #    Note: positions where causal_mask is True (the upper triangle) are filled with -inf
        scores = scores.masked_fill(causal_mask_slice, float('-inf'))
attn_weights = torch.softmax(scores, dim=-1)
context = torch.matmul(attn_weights, V)
# (batch_size, num_heads, seq_len, seq_len) * (batch_size, num_heads, seq_len, head_dim)
# -> (batch_size, num_heads, seq_len, head_dim)
        context = context.transpose(1, 2).contiguous()  # -> (batch_size, seq_len, num_heads, head_dim)
context = context.view(batch_size, seq_len, self.d_model)
output = self.w_o(context)
return output
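A quick shape check for the MHA module (values are illustrative):

import torch

mha = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(2, 16, 512)   # (batch, seq_len, d_model)
print(mha(x).shape)           # torch.Size([2, 16, 512])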
import torch
import torch.nn as nn
import math
class MultiQueryAttention(nn.Module):
def __init__(self, d_model, num_heads):
super().__init__()
self.d_model = d_model
self.num_heads = num_heads
self.head_dim = d_model // num_heads
        # --- Difference 1: the linear projections ---
        # Q: still multi-head, output dimension d_model (i.e. num_heads * head_dim)
        self.w_q = nn.Linear(d_model, d_model)
        # K, V: a single shared head, output dimension is only head_dim
        # In MQA, all query heads share the same key head and value head
        self.w_k = nn.Linear(d_model, self.head_dim)
        self.w_v = nn.Linear(d_model, self.head_dim)
self.w_o = nn.Linear(d_model, d_model)
max_seq_len = 2048
causal_mask = torch.triu(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool), diagonal=1)
self.register_buffer('causal_mask', causal_mask, persistent=False)
def forward(self, x):
batch_size, seq_len, _ = x.shape
        # 1. Projections
        Q = self.w_q(x)  # (batch_size, seq_len, d_model)  -> (B, L, H * D)
        K = self.w_k(x)  # (batch_size, seq_len, head_dim) -> (B, L, 1 * D)
        V = self.w_v(x)  # (batch_size, seq_len, head_dim) -> (B, L, 1 * D)
        # 2. Reshape & transpose
        # Q: (B, L, H, D) -> (B, H, L, D)
        Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # --- Difference 2: handling the K, V dimensions ---
        # K, V: the head dimension here is 1
        # (B, L, 1, D) -> (B, 1, L, D)
        K = K.view(batch_size, seq_len, 1, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, seq_len, 1, self.head_dim).transpose(1, 2)
        # 3. Compute scores (relying on broadcasting)
        scale = math.sqrt(self.head_dim)
        # Q: (B, H, L, D)
        # K.transpose: (B, 1, D, L)
        # PyTorch broadcasts K to (B, H, D, L) to match the number of query heads
        scores = torch.matmul(Q, K.transpose(-2, -1)) / scale
# Result: (B, H, L, L)
# 4. Masking
causal_mask_slice = self.causal_mask[:seq_len, :seq_len].unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(causal_mask_slice, float('-inf'))
# 5. Softmax
attn_weights = torch.softmax(scores, dim=-1)
        # 6. Compute the context (again via broadcasting)
        # attn_weights: (B, H, L, L)
        # V: (B, 1, L, D) -> broadcast to (B, H, L, D)
        context = torch.matmul(attn_weights, V)
        # Result: (B, H, L, D)
        # 7. Output projection
        context = context.transpose(1, 2).contiguous()  # (B, L, H, D)
context = context.view(batch_size, seq_len, self.d_model) # (B, L, d_model)
output = self.w_o(context)
return output
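Because W_k and W_v project to a single head, the K/V projection weights (and the cached K/V at inference time) shrink by a factor of num_heads compared with MHA; a quick check using the two modules above (illustrative values):

import torch

mha = MultiHeadAttention(d_model=512, num_heads=8)
mqa = MultiQueryAttention(d_model=512, num_heads=8)
x = torch.randn(2, 16, 512)
print(mqa(x).shape)           # torch.Size([2, 16, 512]), same output shape as MHA
print(mha.w_k.weight.shape)   # torch.Size([512, 512])
print(mqa.w_k.weight.shape)   # torch.Size([64, 512])  -> head_dim x d_model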
import torch
import torch.nn as nn
import math
class GroupedQueryAttention(nn.Module):
def __init__(self, d_model, num_heads, num_kv_heads):
"""
:param d_model: 模型维度
:param num_heads: Query 的头数 (例如 8)
:param num_kv_heads: Key/Value 的头数 (例如 2)。必须能被 num_heads 整除。
"""
super().__init__()
# 检查是否整除
assert num_heads % num_kv_heads == 0, "num_heads must be divisible by num_kv_heads"
self.d_model = d_model
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.head_dim = d_model // num_heads
        # How many query heads share each KV head (the group size)
        # e.g. 8 query heads, 2 KV heads -> each group of 4 query heads shares 1 KV head
        self.num_rep = num_heads // num_kv_heads
        # Q stays as before: output num_heads * head_dim
        self.w_q = nn.Linear(d_model, num_heads * self.head_dim)
        # K, V use fewer heads: output num_kv_heads * head_dim
        self.w_k = nn.Linear(d_model, num_kv_heads * self.head_dim)
        self.w_v = nn.Linear(d_model, num_kv_heads * self.head_dim)
self.w_o = nn.Linear(d_model, d_model)
max_seq_len = 2048
causal_mask = torch.triu(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool), diagonal=1)
self.register_buffer('causal_mask', causal_mask, persistent=False)
def forward(self, x):
batch_size, seq_len, _ = x.shape
        # 1. Projections
        Q = self.w_q(x)  # (B, L, num_heads * D)
        K = self.w_k(x)  # (B, L, num_kv_heads * D)
        V = self.w_v(x)  # (B, L, num_kv_heads * D)
        # 2. Reshape into heads
        # Q: (B, L, num_heads, D) -> (B, num_heads, L, D)
        Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # K, V: (B, L, num_kv_heads, D) -> (B, num_kv_heads, L, D)
        K = K.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
        # --- The core GQA step: repeat the KV heads ---
        # Goal: expand K, V from (B, num_kv_heads, L, D) to (B, num_heads, L, D) so they line up with Q
        # repeat_interleave along dim=1 (the head dimension)
        # e.g. KV heads [K1, K2] with num_rep=2 -> [K1, K1, K2, K2]
        K = K.repeat_interleave(self.num_rep, dim=1)
        V = V.repeat_interleave(self.num_rep, dim=1)
        # K, V now have shape (B, num_heads, L, D), matching the MHA computation
        # 3. Compute scores
scale = math.sqrt(self.head_dim)
scores = torch.matmul(Q, K.transpose(-2, -1)) / scale
# (B, H, L, D) @ (B, H, D, L) -> (B, H, L, L)
# 4. Masking
causal_mask_slice = self.causal_mask[:seq_len, :seq_len].unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(causal_mask_slice, float('-inf'))
# 5. Softmax & Context
attn_weights = torch.softmax(scores, dim=-1)
context = torch.matmul(attn_weights, V) # (B, H, L, D)
        # 6. Output
context = context.transpose(1, 2).contiguous()
context = context.view(batch_size, seq_len, self.d_model)
output = self.w_o(context)
return output
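GQA interpolates between the two extremes: with `num_kv_heads == num_heads` it behaves like standard MHA, and with `num_kv_heads == 1` it behaves like MQA. A small shape check (illustrative values):

import torch

x = torch.randn(2, 16, 512)  # (batch, seq_len, d_model)
gqa = GroupedQueryAttention(d_model=512, num_heads=8, num_kv_heads=2)
mha_like = GroupedQueryAttention(d_model=512, num_heads=8, num_kv_heads=8)  # behaves like MHA
mqa_like = GroupedQueryAttention(d_model=512, num_heads=8, num_kv_heads=1)  # behaves like MQA
for m in (gqa, mha_like, mqa_like):
    print(m(x).shape)        # torch.Size([2, 16, 512])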