手撕attention
手撕 Attention
这篇用两个版本把 Scaled Dot-Product Attention 从零实现一遍:
- 纯 Python 版:方便看清楚每一步数学细节。
- PyTorch 版:直接可用于训练和推理。
先回顾公式(单头):
$$
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}} + M\right)V
$$
其中:
- $Q, K, V$ 分别是 Query、Key、Value。
- $d_k$ 是 key 向量维度,用于缩放,避免点积过大。
- $M$ 是可选 mask(例如因果 mask)。
1. 纯 Python 版本(无第三方依赖)
1 | import math |
代码说明
- 第一步
QK^T:每个 query 和所有 key 做点积,得到相关性分数。 - 第二步
1/sqrt(d_k)缩放:减小分数方差,训练更稳定。 - 第三步 mask:把不允许看到的位置置为极小值(近似 $-\infty$)。
- 第四步 softmax:将分数变成概率分布。
- 第五步和
V相乘:按概率对 value 做加权求和。
2. PyTorch 版本(工程可用)
1 | import math |
代码说明
- 线性层
w_q/w_k/w_v把输入投影到多头子空间。 split_heads和merge_heads完成[B, L, D] <-> [B, H, L, Dh]的维度变换。scaled_dot_product_attention是核心算子,完全对应数学公式。- 因果 mask 适用于自回归场景,保证当前位置不能看未来 token。
from online
1 | import torch |
3. 纯 Python 和 PyTorch 的关系
- 纯 Python 版强调可读性,适合手推公式和面试讲解。
- PyTorch 版强调张量并行和可训练性,适合直接接入模型。
- 两者核心步骤完全一致,差别主要在张量组织方式和计算后端。
4. CUDA 版本说明
CUDA 版本见算子实例:source/_posts/suanzi.md。
其中已包含 attention 相关 kernel(含原生 attention 与优化版本)和完整 CUDA 工程式写法,本文不重复展开。
- Title: 手撕attention
- Author: Ikko
- Created at : 2026-03-31 14:22:08
- Updated at : 2026-03-31 15:39:01
- Link: http://ikko-debug.github.io/2026/03/31/手撕attention/
- License: This work is licensed under CC BY-NC-SA 4.0.
Comments