手撕attention

Ikko Lv4

手撕 Attention

这篇用两个版本把 Scaled Dot-Product Attention 从零实现一遍:

  • 纯 Python 版:方便看清楚每一步数学细节。
  • PyTorch 版:直接可用于训练和推理。

先回顾公式(单头):

$$
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}} + M\right)V
$$

其中:

  • $Q, K, V$ 分别是 Query、Key、Value。
  • $d_k$ 是 key 向量维度,用于缩放,避免点积过大。
  • $M$ 是可选 mask(例如因果 mask)。

1. 纯 Python 版本(无第三方依赖)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import math


def matmul(a, b):
    """Multiply two 2D lists and return the product as a new 2D list.

    Raises ValueError when the inner dimensions disagree.
    """
    n_rows, inner = len(a), len(a[0])
    if inner != len(b):
        raise ValueError("shape mismatch in matmul")
    n_cols = len(b[0])
    # Sum over the shared dimension in ascending k order (same float
    # accumulation order as a classic triple loop).
    return [
        [sum(a[i][k] * b[k][j] for k in range(inner)) for j in range(n_cols)]
        for i in range(n_rows)
    ]


def transpose(x):
    """Swap rows and columns of a 2D list."""
    n_rows = len(x)
    n_cols = len(x[0]) if n_rows else 0
    return [[x[i][j] for i in range(n_rows)] for j in range(n_cols)]


def row_softmax(x):
    """Apply a numerically stable softmax independently to each row."""

    def _softmax(row):
        # Subtracting the row maximum prevents exp() overflow.
        peak = max(row)
        exps = [math.exp(v - peak) for v in row]
        norm = sum(exps)
        return [e / norm for e in exps]

    return [_softmax(row) for row in x]


def add_mask(scores, mask):
    """Apply a mask to attention scores.

    mask entries: 0 keeps the score, 1 replaces it with -1e9
    (a stand-in for -inf, so softmax drives it to ~0).
    Returns a new matrix; `scores` is not modified.
    """
    if mask is None:
        return scores
    return [
        [value if keep == 0 else -1e9 for value, keep in zip(score_row, mask_row)]
        for score_row, mask_row in zip(scores, mask)
    ]


def attention_python(q, k, v, mask=None):
    """Scaled dot-product attention on plain Python lists.

    q: [Lq, Dk] queries
    k: [Lk, Dk] keys
    v: [Lk, Dv] values
    mask: [Lq, Lk]; 0 means visible, 1 means hidden

    Returns (output [Lq, Dv], attention probabilities [Lq, Lk]).
    """
    inv_scale = 1.0 / math.sqrt(len(q[0]))
    raw = matmul(q, transpose(k))                           # [Lq, Lk]
    scaled = [[s * inv_scale for s in row] for row in raw]  # divide by sqrt(d_k)
    probs = row_softmax(add_mask(scaled, mask))             # [Lq, Lk]
    return matmul(probs, v), probs                          # output is [Lq, Dv]


if __name__ == "__main__":
    q = [[1.0, 0.0], [0.0, 1.0]]
    k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
    v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]

    # Example: the second query is not allowed to see the last key.
    mask = [
        [0, 0, 0],
        [0, 0, 1],
    ]

    out, probs = attention_python(q, k, v, mask)
    print("attention probs:", probs)
    print("output:", out)

代码说明

  • 第一步 QK^T:每个 query 和所有 key 做点积,得到相关性分数。
  • 第二步 1/sqrt(d_k) 缩放:减小分数方差,训练更稳定。
  • 第三步 mask:把不允许看到的位置置为极小值(近似 $-\infty$)。
  • 第四步 softmax:将分数变成概率分布。
  • 第五步和 V 相乘:按概率对 value 做加权求和。

2. PyTorch 版本(工程可用)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import math
import torch
import torch.nn as nn


def scaled_dot_product_attention(q, k, v, mask=None):
    """Single scaled dot-product attention step.

    q: [B, H, Lq, Dk]
    k: [B, H, Lk, Dk]
    v: [B, H, Lk, Dv]
    mask: [B, 1|H, Lq, Lk] boolean; True marks positions to hide.

    Returns (output [B, H, Lq, Dv], attention weights [B, H, Lq, Lk]).
    """
    scale = math.sqrt(q.size(-1))
    scores = q @ k.transpose(-2, -1) / scale  # [B, H, Lq, Lk]

    # Masked positions become -inf so softmax assigns them zero weight.
    if mask is not None:
        scores = scores.masked_fill(mask, float("-inf"))

    weights = torch.softmax(scores, dim=-1)
    return weights @ v, weights  # [B, H, Lq, Dv]


class MultiHeadSelfAttention(nn.Module):
    """Multi-head self-attention with a final output projection.

    Input and output shape: [B, L, d_model].
    """

    def __init__(self, d_model, num_heads, dropout=0.0):
        super().__init__()
        if d_model % num_heads != 0:
            raise ValueError("d_model must be divisible by num_heads")
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        # Q/K/V projections followed by the output projection. Creation order
        # is kept stable so parameter init is reproducible under a fixed seed.
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, x):
        """Reshape [B, L, D] -> [B, H, L, Dh]."""
        bsz, seq_len = x.shape[0], x.shape[1]
        return x.reshape(bsz, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

    def _merge_heads(self, x):
        """Reshape [B, H, L, Dh] -> [B, L, D]."""
        bsz, seq_len = x.shape[0], x.shape[2]
        return x.permute(0, 2, 1, 3).contiguous().reshape(bsz, seq_len, self.d_model)

    def forward(self, x, attn_mask=None):
        """x: [B, L, D]; attn_mask: boolean, True = hidden. Returns (out, attn)."""
        q, k, v = (self._split_heads(proj(x)) for proj in (self.w_q, self.w_k, self.w_v))
        ctx, attn = scaled_dot_product_attention(q, k, v, attn_mask)
        out = self.w_o(self.dropout(self._merge_heads(ctx)))
        return out, attn


if __name__ == "__main__":
    torch.manual_seed(0)
    bsz, seq_len, d_model, nhead = 2, 4, 8, 2
    x = torch.randn(bsz, seq_len, d_model)

    # Causal mask: True above the diagonal hides future positions.
    causal = torch.ones(seq_len, seq_len, dtype=torch.bool).triu(diagonal=1)
    causal = causal[None, None]  # [1, 1, L, L], broadcast over batch and heads

    mha = MultiHeadSelfAttention(d_model=d_model, num_heads=nhead, dropout=0.1)
    y, attn = mha(x, attn_mask=causal)

    print("y shape:", y.shape)  # [B, L, D]
    print("attn shape:", attn.shape)  # [B, H, L, L]

代码说明

  • 线性层 w_q/w_k/w_v 把输入投影到多头子空间。
  • `_split_heads` / `_merge_heads` 完成 [B, L, D] <-> [B, H, L, Dh] 的维度变换。
  • scaled_dot_product_attention 是核心算子,完全对应数学公式。
  • 因果 mask 适用于自回归场景,保证当前位置不能看未来 token。

from online

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import torch
from torch import nn

class MultiHeadAttention(torch.nn.Module):
    """Multi-head attention with an optional KV cache for incremental decoding.

    `forward` accepts either a full sequence or a single new token; when
    `past_key_value` is supplied, the new keys/values are appended to it.
    """

    def __init__(self, hidden_size, num_heads):
        super().__init__()
        # Validate up front; otherwise head splitting silently misbehaves.
        if hidden_size % num_heads != 0:
            raise ValueError("hidden_size must be divisible by num_heads")
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        self.q_linear = nn.Linear(hidden_size, hidden_size)
        self.k_linear = nn.Linear(hidden_size, hidden_size)
        self.v_linear = nn.Linear(hidden_size, hidden_size)
        self.o_linear = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_state, causal_mask=None, past_key_value=None, use_cache=False):
        """
        hidden_state: [B, L_new, hidden_size] (L_new may be 1 during decoding)
        causal_mask: broadcastable to [B, H, L_new, L_total]; nonzero/True = hidden
        past_key_value: optional (past_key, past_value), each [B, H, L_past, head_dim]
        use_cache: when True, also return the updated (key, value) cache

        Returns output [B, L_new, hidden_size], plus the new cache if use_cache.
        """
        batch_size = hidden_state.size(0)

        # Project Q, K, V for the new token(s).
        query = self.q_linear(hidden_state)  # (batch_size, L_new, hidden_size)
        key = self.k_linear(hidden_state)
        value = self.v_linear(hidden_state)

        # Split heads: (batch_size, num_heads, L_new, head_dim)
        query = query.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        # If a cache exists, prepend the cached K/V along the sequence axis.
        if past_key_value is not None:
            past_key, past_value = past_key_value
            key = torch.cat([past_key, key], dim=2)  # (batch_size, num_heads, seq_len, head_dim)
            value = torch.cat([past_value, value], dim=2)

        # Save the updated cache only when requested.
        new_past_key_value = (key, value) if use_cache else None

        # Attention scores: (batch_size, num_heads, L_new, seq_len)
        attention_scores = torch.matmul(query, key.transpose(-1, -2)) \
            / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))

        # Additive masking: hidden positions get a large negative score.
        if causal_mask is not None:
            attention_scores = attention_scores + causal_mask * -1e9

        attention_probs = torch.softmax(attention_scores, dim=-1)
        output = torch.matmul(attention_probs, value)

        # Merge heads and project. Bug fix: the transposed tensor is
        # non-contiguous, so .view() raised for multi-token inputs; use
        # .reshape(), which copies only when necessary.
        output = output.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
        output = self.o_linear(output)

        return (output, new_past_key_value) if use_cache else output

def test_MHA_with_cache():
    """Demo: decode a sequence token-by-token through MultiHeadAttention,
    carrying the KV cache between steps."""
    batch_size = 2
    seq_len = 5
    hidden_size = 64
    num_heads = 4

    # Simulated input for the whole sequence.
    hidden_state = torch.randn(batch_size, seq_len, hidden_size)
    causal_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()

    # Bug fix: the module was referenced but never instantiated (NameError).
    mha = MultiHeadAttention(hidden_size, num_heads)

    # Deliberately process step by step, using the KV cache.
    past_key_value = None
    outputs = []
    for i in range(seq_len):
        # Current step input (a single token).
        current_input = hidden_state[:, i:i+1, :]
        # Causal-mask row for this step (may attend positions <= i only).
        current_causal_mask = causal_mask[i:i+1, :i+1]  # (1, i+1)
        # Forward pass with cache.
        output_step, past_key_value = mha(
            current_input,
            causal_mask=current_causal_mask,
            past_key_value=past_key_value,
            use_cache=True
        )
        outputs.append(output_step)

    # Concatenate the per-step outputs back into a full sequence.
    output = torch.cat(outputs, dim=1)

    print("Input shape:", hidden_state.shape)
    print("Output shape:", output.shape)

if __name__ == "__main__":
    # Entry point: run the incremental-decoding (KV cache) demo.
    test_MHA_with_cache()

3. 纯 Python 和 PyTorch 的关系

  • 纯 Python 版强调可读性,适合手推公式和面试讲解。
  • PyTorch 版强调张量并行和可训练性,适合直接接入模型。
  • 两者核心步骤完全一致,差别主要在张量组织方式和计算后端。

4. CUDA 版本说明

CUDA 版本见算子实例:source/_posts/suanzi.md

其中已包含 attention 相关 kernel(含原生 attention 与优化版本)和完整 CUDA 工程式写法,本文不重复展开。

  • Title: 手撕attention
  • Author: Ikko
  • Created at : 2026-03-31 14:22:08
  • Updated at : 2026-03-31 15:39:01
  • Link: http://ikko-debug.github.io/2026/03/31/手撕attention/
  • License: This work is licensed under CC BY-NC-SA 4.0.
Comments