LLM Code from Scratch (Interview Hand-Written Implementations)

1. RoPE (Rotary Positional Embedding)

Core idea: inject absolute position information into Q and K via rotation matrices, in a way that preserves relative-position structure. Interview focus: the complex-multiplication formulation (rotating elements in pairs).

To inject position information into the model, we implement Rotary Positional Embedding (RoPE) (Su et al., 2021). For a token's query vector $q^{(i)}$ at position $i$ in the sequence, with dimension $d$, we apply a pairwise rotation matrix $R^i$ to obtain: $q'^{(i)} = R^i q^{(i)} = R^i W_q x^{(i)}$.

That is, $R^i$ treats each consecutive pair of elements of the query vector as a 2D vector and rotates it by the angle $\theta_{i,k}$, where $\theta_{i,k} = \frac{i}{\Theta^{(2k-2)/d}} \quad (k = 1, \ldots, d/2)$, and $\Theta$ is a constant (usually 10000, as in the original Transformer positional encoding). $R^i$ can therefore be viewed as a block-diagonal matrix whose 2×2 blocks are: $R^i_k= \begin{bmatrix} \cos(\theta_{i,k}) & -\sin(\theta_{i,k}) \\ \sin(\theta_{i,k}) & \cos(\theta_{i,k}) \end{bmatrix}$

So the full rotation matrix $R^i$ is: \(R^i= \begin{bmatrix} R^i_1 & 0 & 0 & ... & 0 \\ 0 & R^i_2 & 0 & ... & 0 \\ 0 & 0 & R^i_3 & ... & 0 \\ ... & ... & ... & ... & ... \\ 0 & 0 & 0 & ... & R^i_{d/2} \end{bmatrix}\)

where every $0$ denotes a 2×2 zero matrix.

Although we could explicitly construct the full $d \times d$ rotation matrix, it is more efficient to exploit its structure and rotate the vector directly.

Moreover, since we only care about the relative rotation between tokens within a sequence, all layers can share the same cos and sin tables. This layer therefore typically stores cos and sin via self.register_buffer(..., persistent=False) rather than as trainable parameters.

Finally, both Q and K are rotated with their corresponding $R^i$. Note: this layer has no learnable parameters.
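As a brief sketch of why this construction yields relative positions (standard RoPE algebra, not specific to the code below): because each 2×2 block is a planar rotation, $R^{i\top} R^j = R^{j-i}$, so the attention logit between a rotated query at position $i$ and a rotated key at position $j$ depends only on the offset $j - i$:

\( \langle R^i q^{(i)}, R^j k^{(j)} \rangle = q^{(i)\top} R^{i\top} R^j k^{(j)} = q^{(i)\top} R^{j-i} k^{(j)} \)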

  • Direct implementation
import torch
from torch import nn

class Rope(nn.Module):
    
    def __init__(self, theta: float, d_k: int, max_seq_len: int, device=None):
        """
        计算$$\theta_{i,k} = i \cdot \frac{\Theta}{10000^{2k/d}}$$  
        - 构造 RoPE 模块,并在需要时创建缓存(buffers)。  
        - `theta`: RoPE 中的 Θ 值(控制旋转角度的频率基底)。  
        - `d_k`: 查询(query)和键(key)向量的维度。  
        - `max_seq_len`: 输入序列的最大长度。  
        - `device`: 存储缓存张量的设备(`torch.device` 或 `None`)。  
        """
        super().__init__()
        if d_k % 2 != 0:
            raise ValueError("d_k must be even for RoPE")
        
        self.theta = theta
        self.d_k = d_k
        self.max_seq_len = max_seq_len
        self.device = device

        f = 1.0 / (theta ** (torch.arange(0, d_k, 2, device=device).float() / d_k))

        # position
        p = torch.arange(max_seq_len, device=device).float()
        # sinusoids
        s = torch.outer(p, f)

        self.register_buffer("cos_cache", s.cos(), persistent=False)
        self.register_buffer("sin_cache", s.sin(), persistent=False)

        
    def forward(self, x: torch.Tensor, token_positions: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: shape (*, seq_len, d_k), 输入张量(支持任意批维度)
            token_positions: shape (*, seq_len), 每个 token 的绝对位置索引

        Returns:
            shape (*, seq_len, d_k), 应用旋转编码后的输出
        """

        cos = self.cos_cache[token_positions]
        sin = self.sin_cache[token_positions]

        # split the input into even- and odd-indexed dimensions: x = [x_even, x_odd]
        x_even = x[..., 0::2]  # even indices: 0, 2, 4, ...
        x_odd  = x[..., 1::2]  # odd indices: 1, 3, 5, ...

        # apply the rotation formula
        out_even = x_even * cos - x_odd * sin
        out_odd  = x_even * sin + x_odd * cos

        # interleave: stack (even, odd) along a new last dim and flatten
        out = torch.stack([out_even, out_odd], dim=-1)  # shape: (*, seq_len, d_k//2, 2)
        out = out.flatten(-2)  # shape: (*, seq_len, d_k)

        return out
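A minimal usage sketch (the shapes and hyper-parameter values here are illustrative assumptions, not from the original):

# Hypothetical usage: rotate a batch of query vectors.
rope = Rope(theta=10000.0, d_k=64, max_seq_len=2048)
q = torch.randn(2, 128, 64)                    # (batch, seq_len, d_k)
positions = torch.arange(128).expand(2, 128)   # (batch, seq_len)
q_rot = rope(q, positions)
print(q_rot.shape)                             # torch.Size([2, 128, 64])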
        
  • Complex-number implementation
import torch
import torch.nn as nn

class RoPE(nn.Module):
    def __init__(self, dim, max_seq_len=4096, theta=10000.0):
        super().__init__()
        # compute the frequencies: theta ^ (-2i/d)
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        # positions t: [0, 1, ..., max_seq_len - 1]
        t = torch.arange(max_seq_len, device=freqs.device)
        # outer product -> rotation angles, shape (seq_len, dim/2)
        freqs = torch.outer(t, freqs).float()
        # convert to polar form so the rotation becomes a complex multiplication
        # freqs_cis: (seq_len, dim/2), complex64; cached as a non-persistent buffer
        self.register_buffer("freqs_cis", torch.polar(torch.ones_like(freqs), freqs), persistent=False)

    def forward(self, x):
        # x shape: (batch, seq_len, n_heads, head_dim)
        # view x as complex pairs: (batch, seq_len, n_heads, head_dim/2)
        x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        
        # slice the cached frequencies to the current sequence length and reshape for broadcasting
        freqs_cis = self.freqs_cis[:x.shape[1]].view(1, x.shape[1], 1, -1)
        
        # complex multiplication performs the rotation
        x_rotated = x_complex * freqs_cis
        
        # convert back to real and flatten: (batch, seq_len, n_heads, head_dim)
        x_out = torch.view_as_real(x_rotated).flatten(3)
        return x_out.type_as(x)
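Since view_as_complex pairs consecutive elements exactly like the even/odd split above, the two implementations should agree. A quick sanity check one might run (assuming both classes above are defined in the same scope; shapes are illustrative):

# Hypothetical check that the direct and complex implementations match.
B, L, D = 2, 16, 32
x = torch.randn(B, L, D)

direct = Rope(theta=10000.0, d_k=D, max_seq_len=L)
cplx = RoPE(dim=D, max_seq_len=L)

out_direct = direct(x, torch.arange(L))
out_cplx = cplx(x.view(B, L, 1, D)).view(B, L, D)  # add a dummy head dimension

print(torch.allclose(out_direct, out_cplx, atol=1e-5))  # expected: True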

2. Attention

import torch
import math
from torch import nn

class RoPE(nn.Module):
    def __init__(self, dim, max_seq_len, theta):

        super().__init__()

        f = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        p = torch.arange(0, max_seq_len).float()
        s = torch.outer(p, f)

        self.register_buffer("cos_cache", s.cos(), persistent=False)
        self.register_buffer("sin_cache", s.sin(), persistent=False)

    def forward(self, x, token_positions):
        cos = self.cos_cache[token_positions]
        sin = self.sin_cache[token_positions]

        x_even = x[..., 0::2]
        x_odd = x[..., 1::2]
        
        out_even = cos * x_even - sin * x_odd
        out_odd = sin * x_even + cos * x_odd

        out = torch.stack([out_even, out_odd], dim=-1)
        out = out.flatten(-2)

        return out

class ScaledDotProductAttention(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, Q, K, V, mask=None):

        # Q, K, V shape (*, batch_size, seq_len, d_k)

        d_k = Q.shape[-1]

        scale = math.sqrt(d_k)
        score = torch.matmul(Q, K.transpose(-2, -1)) / scale

        # positions where mask is True are masked out (e.g. a causal mask)
        if mask is not None:
            score = score.masked_fill(mask, float('-inf'))

        attention_weight = torch.softmax(score, dim=-1)

        output = torch.matmul(attention_weight, V)

        return output
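A minimal usage sketch with an explicit causal mask (shapes here are illustrative assumptions):

# Hypothetical usage of ScaledDotProductAttention with a causal mask.
attn = ScaledDotProductAttention()
B, L, D = 2, 8, 64
Q, K, V = torch.randn(B, L, D), torch.randn(B, L, D), torch.randn(B, L, D)
# True above the diagonal -> future positions are masked out
causal = torch.triu(torch.ones(L, L, dtype=torch.bool), diagonal=1)
out = attn(Q, K, V, mask=causal)   # (B, L, D)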
    

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()

        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

        max_seq_len = 2048
        causal_mask = torch.triu(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool), diagonal=1)
        # diagonal=1 keeps True strictly above the main diagonal; all other positions are False.
        self.register_buffer('causal_mask', causal_mask, persistent=False)

    def forward(self, x):
        Q = self.w_q(x) # (batch_size, seq_len, d_model)
        K = self.w_k(x) # (batch_size, seq_len, d_model)
        V = self.w_v(x) # (batch_size, seq_len, d_model)

        batch_size, seq_len, _ = x.shape
        Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # (batch_size, num_heads, seq_len, head_dim)

        scale = math.sqrt(self.head_dim)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / scale
        # (batch_size, num_heads, seq_len, head_dim) * (batch_size, num_heads, head_dim, seq_len)
        # -> (batch_size, num_heads, seq_len, seq_len)

        # if mask is not None:
        #     scores = scores.masked_fill(mask == 0, -1e9)

        # 1. Slice the precomputed mask to the current input seq_len
        #    self.causal_mask has shape (max_len, max_len) -> sliced to (seq_len, seq_len)
        causal_mask_slice = self.causal_mask[:seq_len, :seq_len].unsqueeze(0).unsqueeze(0)
        
        # 2. Fill masked positions with negative infinity
        #    Note: positions where causal_mask is True (upper triangle) are filled with -inf
        scores = scores.masked_fill(causal_mask_slice, float('-inf'))

        attn_weights = torch.softmax(scores, dim=-1)

        context = torch.matmul(attn_weights, V) 
        # (batch_size, num_heads, seq_len, seq_len) * (batch_size, num_heads, seq_len, head_dim)
        # -> (batch_size, num_heads, seq_len, head_dim)

        context = context.transpose(1, 2).contiguous() # -> (batch_size, seq_len, num_heads, head_dim)
        context = context.view(batch_size, seq_len, self.d_model)

        output = self.w_o(context)

        return output
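A minimal shape check (the values are illustrative assumptions):

# Hypothetical shape check for MultiHeadAttention.
mha = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(2, 16, 512)   # (batch, seq_len, d_model)
y = mha(x)
print(y.shape)                # torch.Size([2, 16, 512])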

import torch
import torch.nn as nn
import math

class MultiQueryAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        # --- Difference 1: linear layer definitions ---
        # Q: keeps all heads, output dim is still d_model (i.e. num_heads * head_dim)
        self.w_q = nn.Linear(d_model, d_model)
        
        # K, V: a single shared head, output dim is only head_dim
        # In MQA, all query heads share the same key and value head
        self.w_k = nn.Linear(d_model, self.head_dim)
        self.w_v = nn.Linear(d_model, self.head_dim)
        
        self.w_o = nn.Linear(d_model, d_model)

        max_seq_len = 2048
        causal_mask = torch.triu(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool), diagonal=1)
        self.register_buffer('causal_mask', causal_mask, persistent=False)

    def forward(self, x):
        batch_size, seq_len, _ = x.shape

        # 1. Projections
        Q = self.w_q(x) # (batch_size, seq_len, d_model) -> (B, L, H * D)
        K = self.w_k(x) # (batch_size, seq_len, head_dim) -> (B, L, 1 * D)
        V = self.w_v(x) # (batch_size, seq_len, head_dim) -> (B, L, 1 * D)

        # 2. Reshape & transpose
        # Q: (B, L, H, D) -> (B, H, L, D)
        Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        
        # --- Difference 2: handling the K, V dimensions ---
        # K, V: the head dimension here is 1
        # (B, L, 1, D) -> (B, 1, L, D)
        K = K.view(batch_size, seq_len, 1, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, seq_len, 1, self.head_dim).transpose(1, 2)

        # 3. Compute scores (relying on broadcasting)
        scale = math.sqrt(self.head_dim)
        
        # Q: (B, H, L, D)
        # K.transpose: (B, 1, D, L)
        # PyTorch broadcasts K to (B, H, D, L) to match the number of query heads
        scores = torch.matmul(Q, K.transpose(-2, -1)) / scale
        # Result: (B, H, L, L)

        # 4. Masking
        causal_mask_slice = self.causal_mask[:seq_len, :seq_len].unsqueeze(0).unsqueeze(0)
        scores = scores.masked_fill(causal_mask_slice, float('-inf'))

        # 5. Softmax
        attn_weights = torch.softmax(scores, dim=-1)

        # 6. Compute context (relying on broadcasting)
        # attn_weights: (B, H, L, L)
        # V: (B, 1, L, D) -> broadcast to (B, H, L, D)
        context = torch.matmul(attn_weights, V)
        # Result: (B, H, L, D)

        # 7. Output projection
        context = context.transpose(1, 2).contiguous() # (B, L, H, D)
        context = context.view(batch_size, seq_len, self.d_model) # (B, L, d_model)

        output = self.w_o(context)

        return output
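As a rough comparison of the projection sizes (a sketch assuming the classes above; d_model=512 and num_heads=8 are illustrative):

# Hypothetical parameter-count comparison of the K projections in MHA vs. MQA.
def n_params(module):
    return sum(p.numel() for p in module.parameters())

mha = MultiHeadAttention(d_model=512, num_heads=8)
mqa = MultiQueryAttention(d_model=512, num_heads=8)

print(n_params(mha.w_k), n_params(mqa.w_k))  # 262656 vs. 32832 (512*512+512 vs. 512*64+64)
print(n_params(mha), n_params(mqa))          # MQA needs fewer parameters overall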

import torch
import torch.nn as nn
import math

class GroupedQueryAttention(nn.Module):
    def __init__(self, d_model, num_heads, num_kv_heads):
        """
        :param d_model: 模型维度
        :param num_heads: Query 的头数 (例如 8)
        :param num_kv_heads: Key/Value 的头数 (例如 2)。必须能被 num_heads 整除。
        """
        super().__init__()
        
        # check that num_kv_heads evenly divides num_heads
        assert num_heads % num_kv_heads == 0, "num_heads must be divisible by num_kv_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = d_model // num_heads
        
        # how many Q heads share each KV head (group size)
        # e.g. 8 Q heads, 2 KV heads -> each group of 4 Q heads shares 1 KV head
        self.num_rep = num_heads // num_kv_heads

        # Q stays the same: output num_heads * head_dim
        self.w_q = nn.Linear(d_model, num_heads * self.head_dim)
        
        # K, V use fewer heads: output num_kv_heads * head_dim
        self.w_k = nn.Linear(d_model, num_kv_heads * self.head_dim)
        self.w_v = nn.Linear(d_model, num_kv_heads * self.head_dim)
        
        self.w_o = nn.Linear(d_model, d_model)

        max_seq_len = 2048
        causal_mask = torch.triu(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool), diagonal=1)
        self.register_buffer('causal_mask', causal_mask, persistent=False)

    def forward(self, x):
        batch_size, seq_len, _ = x.shape

        # 1. Projections
        Q = self.w_q(x) # (B, L, num_heads * D)
        K = self.w_k(x) # (B, L, num_kv_heads * D)
        V = self.w_v(x) # (B, L, num_kv_heads * D)

        # 2. Reshape into heads
        # Q: (B, L, num_heads, D) -> (B, num_heads, L, D)
        Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        
        # K, V: (B, L, num_kv_heads, D) -> (B, num_kv_heads, L, D)
        K = K.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # --- GQA core operation: repeat the KV heads ---
        # Goal: expand K, V from (B, num_kv_heads, L, D) to (B, num_heads, L, D) so they line up with Q
        # repeat_interleave copies along dim=1 (the head dimension)
        # e.g. KV heads [K1, K2], num_rep=2 -> [K1, K1, K2, K2]
        K = K.repeat_interleave(self.num_rep, dim=1) 
        V = V.repeat_interleave(self.num_rep, dim=1)
        
        # K, V now have shape (B, num_heads, L, D), so the rest matches the MHA computation

        # 3. Compute scores
        scale = math.sqrt(self.head_dim)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / scale 
        # (B, H, L, D) @ (B, H, D, L) -> (B, H, L, L)

        # 4. Masking
        causal_mask_slice = self.causal_mask[:seq_len, :seq_len].unsqueeze(0).unsqueeze(0)
        scores = scores.masked_fill(causal_mask_slice, float('-inf'))

        # 5. Softmax & Context
        attn_weights = torch.softmax(scores, dim=-1)
        context = torch.matmul(attn_weights, V) # (B, H, L, D)

        # 6. Output projection
        context = context.transpose(1, 2).contiguous()
        context = context.view(batch_size, seq_len, self.d_model)

        output = self.w_o(context)

        return output
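GQA interpolates between the two extremes: with num_kv_heads == num_heads its projection shapes match MHA, and with num_kv_heads == 1 it reduces to MQA. A quick check one might run (values are illustrative assumptions):

# Hypothetical check of the per-token KV width for MHA-like / GQA / MQA-like settings.
d_model, num_heads = 512, 8
for num_kv_heads in (8, 2, 1):          # 8 -> MHA-like, 2 -> GQA, 1 -> MQA-like
    gqa = GroupedQueryAttention(d_model, num_heads, num_kv_heads)
    kv_width = gqa.w_k.out_features + gqa.w_v.out_features
    print(num_kv_heads, kv_width)       # 1024, 256, 128 values cached per token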


