new-about-transformer

Ikko Lv4

发现自己对 Transformer 的理解有些偏颇,开此文章记录一下。

输入

设句子有 $T$ 个 token,embedding 维度为 $d$。

输入 embedding 矩阵可以表示为一个大小为 $T \times d$ 的矩阵:

1
2
3
4
5
6
7
X = [
x_1 -> token 1
x_2 -> token 2
x_3 -> token 3
...
x_T -> token T
] shape: (T, d)

这里每一行表示一个 token 的向量。Transformer 的核心工作,就是让每个 token 在上下文中不断更新自己的表示。

Attention 计算行为

对于第 $i$ 个 token,它会“看”其它 token,并计算自己对它们的关注程度。

整个注意力矩阵是一个 $T \times T$ 的矩阵,每一行都对应一个 token 的视角。

在自回归生成场景下,通常会使用 causal mask,也就是只保留对自己以及对自己之前 token 的关注,屏蔽未来 token。

为什么 Transformer 采用多头注意力

Transformer 使用多头注意力(Multi-Head Attention)的原因,不只是“头更多”,而是让模型能够在同一层里,从不同表示子空间中同时捕捉不同层次的信息。

1. 并行捕捉多样化特征

每个头都有自己独立的权重矩阵 $W^Q, W^K, W^V$。在计算时,不同头可以关注不同的信息:

  • 有的头专注于语法关系,比如主谓一致。
  • 有的头专注于代词指代。
  • 有的头专注于上下文语义关联。

如果只有单头,模型就必须把这些信息压缩进同一个注意力池里,容易出现信息相互抵消、表征变模糊的问题。

2. 增加容错性和稳定性

多头机制很像集成学习。多个头共同参与决策,可以减少单个头计算偏差带来的影响,让训练过程更稳定,泛化能力也更好。

3. 更有利于表达复杂映射

把高维向量切分成多个低维头来计算,本质上是在保持参数总量大致不变的前提下,增加了并行的非线性变换路径。

这会让模型更容易学习到复杂函数映射,而不是只靠一个大而粗的注意力通道去硬拟合。

LLM 里 Transformer 的层数与参数配置

“LLM 一般是多少个 Transformer” 这个问题,通常要拆成两个维度来看:

  • 层数(Layers / Blocks):Transformer 堆叠了多少层。
  • 头数(Attention Heads):每层里有多少个注意力头。

现代大模型大多数都基于 Decoder-only 架构堆叠而成,整体趋势是“变深”和“变宽”同时进行。

下面是一些典型的规模参考:

模型规模 (Parameter Scale) 层数 (Layers / Blocks) 注意力头数 (Attention Heads) 隐藏层维度 ($d_{model}$) 典型代表
1B - 3B 16 - 24 16 - 32 2048 Llama-3.2-3B, Gemma-2B
7B - 8B 30 - 32 32 4096 Llama-3-8B, Mistral-7B
14B - 32B 40 - 60 32 - 40 5120 Qwen-2.5-32B
70B - 175B 80 - 96 64 - 128 8192 Llama-3-70B, GPT-3

关键规律

  1. 头数与维度通常满足关系:

    $$d_{model} = \text{Heads} \times \text{Head Dim}$$

    实际工程里,每个头的维度常常固定在 128 左右,例如 Llama 系列。

  2. GQA(Grouped Query Attention)是现在 LLM 里常见的优化方式。为了节省显存和推理开销,Query 头数量往往多于 Key/Value 头数量。例如 Llama-3-8B 有 32 个 Query Heads,但只有 8 个 KV Heads。

  3. 模型变大时,通常会同步增加层数和隐藏层维度,而不是只增加某一个维度。

KV Cache demo

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import torch
import torch.nn.functional as F


class SimpleKVCache:
def __init__(self, num_layers, batch_size, num_heads, max_seq_len, head_dim, dtype=torch.float16):
self.k_cache = torch.empty(
num_layers, batch_size, num_heads, max_seq_len, head_dim,
device="cuda",
dtype=dtype,
)
self.v_cache = torch.empty_like(self.k_cache)

def append(self, layer_id, pos, k_new, v_new):
"""
k_new/v_new: [batch, heads, 1, head_dim]
"""
self.k_cache[layer_id, :, :, pos:pos+1, :] = k_new
self.v_cache[layer_id, :, :, pos:pos+1, :] = v_new

def get_kv(self, layer_id, end_pos):
"""
end_pos 表示当前有效长度
"""
return (
self.k_cache[layer_id, :, :, :end_pos, :],
self.v_cache[layer_id, :, :, :end_pos, :],
)


def decode_attention(q, k, v):
head_dim = q.shape[-1]

scores = torch.matmul(q, k.transpose(-1, -2)) / (head_dim ** 0.5)
probs = F.softmax(scores, dim=-1)
out = torch.matmul(probs, v)

return out


def main():
torch.manual_seed(0)

num_layers = 2
batch_size = 1
num_heads = 4
max_seq_len = 16
head_dim = 8
steps = 8

cache = SimpleKVCache(
num_layers=num_layers,
batch_size=batch_size,
num_heads=num_heads,
max_seq_len=max_seq_len,
head_dim=head_dim,
)

for pos in range(steps):
print(f"\ndecode step {pos}")

for layer_id in range(num_layers):
q = torch.randn(batch_size, num_heads, 1, head_dim, device="cuda", dtype=torch.float16)
k_new = torch.randn(batch_size, num_heads, 1, head_dim, device="cuda", dtype=torch.float16)
v_new = torch.randn(batch_size, num_heads, 1, head_dim, device="cuda", dtype=torch.float16)

cache.append(layer_id, pos, k_new, v_new)

k, v = cache.get_kv(layer_id, pos + 1)
out = decode_attention(q, k, v)

print(f"layer={layer_id}, seq_len={pos+1}, out={tuple(out.shape)}")


if __name__ == "__main__":
main()

本质是在 GPU global memory 中预分配 [layer, batch, head, max_seq_len, head_dim] 的 K/V tensor。decode 每生成一个 token,就把当前 token 的 K/V append 到对应位置;下一步 attention 会读取从 0 到当前长度的全部历史 KV。

这个 demo 能说明 decode 阶段为什么容易 memory-bound:每一步只生成一个 token,Q 很小,但是需要读取越来越长的历史 K/V。随着 seq_len 增长,读 KV cache 的显存流量线性增长,而计算并行度又不高,所以小 batch decode 很难把 GPU 吃满。

Paged KV Cache demo

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import torch


class BlockAllocator:
# 物理块分配器:维护所有可用的 GPU 物理块 ID。
# 可以把它理解成一个简化版的物理内存池,负责“申请”和“回收”块。
def __init__(self, num_blocks):
# 初始时所有块都空闲,编号从 0 到 num_blocks - 1。
self.free_blocks = list(range(num_blocks))

def allocate(self):
# 当新请求到来或序列变长时,分配一个空闲物理块。
# 如果没有空闲块,就说明显存池耗尽,直接报 OOM。
if not self.free_blocks:
raise RuntimeError("No free KV blocks")
return self.free_blocks.pop()

def free(self, block_id):
# 请求结束后,把物理块重新放回空闲池。
self.free_blocks.append(block_id)


class PagedKVCache:
# 分页式 KV Cache 核心管理器。
# 它预先在 GPU 上分配好连续的 KV blocks,
# 再通过 request_id -> 逻辑块列表 -> 物理块 ID 的映射关系,
# 把逻辑 token 序列和物理显存解耦。
def __init__(self, num_blocks, block_size, num_heads, head_dim, dtype=torch.float16):
self.num_blocks = num_blocks
self.block_size = block_size
self.num_heads = num_heads
self.head_dim = head_dim

# 物理显存:预先分配固定数量的 KV 块。
# 形状 [num_blocks, num_heads, block_size, head_dim],每个 block 可容纳固定数量 token。
self.k_blocks = torch.empty(
num_blocks, num_heads, block_size, head_dim,
device="cuda",
dtype=dtype,
)
self.v_blocks = torch.empty_like(self.k_blocks)

# 维护物理块分配情况的分配器。
self.allocator = BlockAllocator(num_blocks)

# block_tables: request_id -> [physical_block_id, physical_block_id, ...]
# 逻辑块 ID 由列表索引表示,列表中的值是对应的物理块 ID。
self.block_tables = {}

def _ensure_block(self, request_id, logical_block_id):
# 在写入新 token 之前,先确认该请求的逻辑块映射是否已经足够。
# 如果逻辑块表长度不够,就继续向 allocator 申请新的物理块并追加进去。
if request_id not in self.block_tables:
self.block_tables[request_id] = []

table = self.block_tables[request_id]

while len(table) <= logical_block_id:
table.append(self.allocator.allocate())

def append(self, request_id, pos, k_new, v_new):
"""
k_new/v_new: [heads, 1, head_dim]
"""
# 将全局 token 位置 pos 拆成两部分:
# logical_block_id:属于第几个逻辑块。
# offset:当前 token 在块内的相对位置。
logical_block_id = pos // self.block_size
offset = pos % self.block_size

self._ensure_block(request_id, logical_block_id)

# 通过 block table 找到该逻辑块对应的物理块,然后把 K/V 精确写到对应位置。
physical_block_id = self.block_tables[request_id][logical_block_id]

self.k_blocks[physical_block_id, :, offset:offset+1, :] = k_new
self.v_blocks[physical_block_id, :, offset:offset+1, :] = v_new

def get_kv(self, request_id, end_pos):
"""
这是一个验证 / 调试接口:把离散的 paged blocks gather 成连续 tensor。
真实 vLLM 不会在推理路径里做这个拼接,因为 gather 会带来很大的显存带宽开销。
真正的 PagedAttention 做法,是把 block_table 直接传给 CUDA Kernel,
在 Kernel 内部按映射关系直接寻址读取离散数据。
"""
table = self.block_tables[request_id]

k_list = []
v_list = []

for logical_block_id, physical_block_id in enumerate(table):
# 当前逻辑块对应的全局起点。
start = logical_block_id * self.block_size
remain = end_pos - start

# 如果已经超过有效长度,就停止拼接。
if remain <= 0:
break

# 只保留当前序列最后一个块里真正有效的部分。
valid_len = min(self.block_size, remain)

k_list.append(self.k_blocks[physical_block_id, :, :valid_len, :])
v_list.append(self.v_blocks[physical_block_id, :, :valid_len, :])

# 将离散块拼成连续 tensor,方便调试或做单元测试。
k = torch.cat(k_list, dim=1) # [heads, seq_len, head_dim]
v = torch.cat(v_list, dim=1)

return k, v

def free_request(self, request_id):
# 请求结束后,先取出该请求占用的逻辑块表,再把所有物理块还给 allocator。
table = self.block_tables.pop(request_id, [])
for block_id in table:
self.allocator.free(block_id)

这段代码实现的是一个简化版的 Paged KV Cache(分页键值缓存)。它借鉴了操作系统虚拟内存分页的思路:把连续的 token 序列切成固定大小的块(block),再离散地存放到 GPU 的物理显存中,从而降低显存碎片和大块连续分配的压力。

逐模块说明

1. BlockAllocator:物理块分配器

这个类负责管理系统里所有可用的物理内存块,作用类似一个简化版的物理内存池。

  • __init__:初始化一个包含所有空闲块 ID 的列表,从 0num_blocks - 1
  • allocate:当有新请求到来,或者某个序列变长需要新块时,从空闲列表里弹出一个物理块 ID。如果没有空闲块,就直接报 OOM。
  • free:当请求结束后,把已经使用过的物理块 ID 重新放回空闲列表。

2. PagedKVCache:分页缓存核心管理

这个类负责预分配 GPU 显存,并维护每个请求的“逻辑块”到“物理块”的映射表(block table)。

  • __init__:在 GPU 上一次性开辟固定大小的连续张量 k_blocksv_blocks,形状为 [num_blocks, num_heads, block_size, head_dim]。其中 block_size 表示一个块能装下多少个 token。
  • block_tables:一个字典,键是 request_id,值是一个列表。这个列表相当于页表,列表索引表示逻辑块 ID,列表中的值表示对应的物理块 ID。
  • _ensure_block:在写入新 token 前,检查当前请求的逻辑块数量是否足够覆盖即将写入的位置。如果目标 logical_block_id 超出当前页表长度,就向 allocator 申请新的物理块并追加到映射表中。
  • append:把全局绝对位置 pos 拆成 logical_block_id = pos // block_sizeoffset = pos % block_size,再通过映射表找到物理块,把 k_newv_new 写入预分配张量对应的位置。
  • get_kv:用于验证和调试。它会遍历映射表,把离散的物理块逐个切片提取出来,并通过 torch.cat 拼成连续的 [heads, seq_len, head_dim] 张量。真实推理时不会做这个 gather,因为会带来很大的显存带宽开销。
  • free_request:当请求结束时,把该请求的映射表从字典里移除,并把占用过的所有物理块回收到 allocator 中。

总结

这段代码的核心,是把原本随着序列长度动态增长、且需要连续存储的 KV Cache,转换成按固定大小分块分配、离散存储的模式。通过维护 request_id -> 逻辑块列表 -> 物理块 ID 的映射关系,逻辑序列和物理内存空间被解耦,进而缓解了变长序列推理中的显存碎片化和显存浪费问题。

KV Cache 为什么有 num_layers 维度

在推理系统里,KV Cache 通常会设计成包含 num_layers 维度的张量,因为 Transformer 的每一层都有独立的 Self-Attention 模块,而这些模块拥有不同的 $W^K$ 和 $W^V$。

也就是说,每一层在计算时,都会生成只属于这一层的 K 和 V,用来表征当前层级的特征抽象。所以每一层都需要单独缓存。

时间轴上的逐层写入

在自回归生成(Autoregressive Generation)阶段,假设模型正在生成第 $t$ 个 token:

  1. 输入先流经第 0 层,Layer 0 的 Attention 计算出当前 token 的 K 和 V,并写入缓存的第一个切片,例如 self.K[0, :, t, :]
  2. Layer 0 计算完毕后输出结果,作为 Layer 1 的输入。
  3. Layer 1 根据新的输入,计算自己的 K 和 V,并写入第二个切片,例如 self.K[1, :, t, :]
  4. 依次类推,直到第 num_layers - 1 层写入完毕。

在这个过程中,self.Kself.V 里的 num_layers 维度是一步步、逐层更新的,而不是同一时刻被并行填满。

为什么要设计成一个大张量

从 AI Infrastructure 和 CUDA 的角度看,把 KV Cache 统一初始化成 [num_layers, num_heads, max_seq_len, head_dim] 这样的连续大张量,有几个明显好处:

  • 避免显存碎片化:如果每一层都在自己的类里单独 torch.zeros() 分配缓存,会在显存中产生很多不连续的小块内存。一次性分配连续大张量,对显存管理更友好。
  • 指针偏移更简单:在底层 CUDA Kernel 里,计算到第 $i$ 层时,只需要拿到基础地址并加上层偏移即可,传参和调度都更直接。
  • 状态统一管理:在多轮对话、Prompt Caching、模型保存与恢复等场景里,把所有 KV Cache 集中在一个变量中,更方便做复制、保存和加载。

所以,物理存储上它是一个包含所有层信息的高维大数组,但在执行流里,访问和更新仍然是沿着层顺序逐层切片进行的。

各种kvcache优化策略

优缺点

在工程里,KV Cache 相关优化可以分成三个明确维度:模型架构层系统工程层算法与压缩层。这样区分后,哪些方法需要改模型、哪些只改推理引擎、哪些可以在推理阶段动态应用,就会清楚很多。

1. 模型架构层

这一层的策略在训练阶段就已经固定了,推理端只能按模型原本的结构去适配,不能在不改权重和结构的前提下,给标准 MHA 强行“切换”成 MQA 或 MLA。

策略 核心思路 优点 缺点 典型说明
MQA (Multi-Query Attention) 多个 Query 头共享同一组 Key/Value 头 KV Cache 体积显著下降,带宽压力更小 表达能力弱于标准 MHA,可能有质量损失 需要在预训练阶段就采用这种结构
GQA (Grouped-Query Attention) 将 Query 头分组,每组共享少量 KV 头 在质量和显存之间折中,推理友好 仍然不是完全独立的 KV,压缩幅度有限于 MQA 现代主流 LLM 常见,如 Llama 系列
MLA (Multi-head Latent Attention) 先把 KV 压到低维 latent,再恢复/投影 KV Cache 更小,适合超长上下文 结构和实现复杂,通常要从模型设计开始介入 更像专门为长上下文和推理效率设计的架构
Cross-Layer KV Sharing 相邻或部分层共享同一份 KV Cache 理论上可进一步压缩 KV 规模 会损伤层级表达能力,必须在预训练中介入 适合模型结构级别的实验性优化

2. 系统工程层

这一层主要由推理引擎、内存管理器和调度器决定,重点是怎么把缓存放好、怎么复用、怎么切分、怎么搬运。

策略 核心思路 优点 缺点 适用场景
静态预分配 Cache 一次性申请 [num_layers, num_heads, max_seq_len, head_dim] 大张量 实现简单,访问连续,Kernel 传参方便 占用显存固定,长上下文时浪费大 小型推理服务、单机原型
PagedAttention / 分页式 Cache 把 KV 按固定块分页管理,类似虚拟内存 减少碎片化,支持更高并发,更适合长上下文 调度和索引更复杂,工程实现成本高 vLLM、在线推理服务
Prefix Cache / Prompt Cache 复用相同前缀的 KV,避免重复计算 对多轮对话、系统提示词复用非常有效 只对重复前缀生效,命中率依赖业务模式 Chat、Agent、模板化请求
CPU Offload 把不活跃 KV 或部分层缓存放到 CPU 降低 GPU 显存压力 PCIe 传输开销大,延迟通常更高 显存紧张、吞吐要求不极端的场景
Chunked Prefill 将长 Prompt 分块预填充,而不是一次性算完 降低长输入的显存峰值,缓解 TTFT 瓶颈 调度逻辑更复杂,前后阶段衔接成本更高 长文本输入、在线批处理推理

3. 算法与压缩层

这一层可以在推理阶段动态应用,重点是少存、少算、少搬运,通常不要求彻底改变模型主干结构。

策略 核心思路 优点 缺点 适用场景
KV Quantization 将 KV 从 FP16/BF16 压到 INT8、FP8 等低比特 显著降低显存占用,提高带宽利用率 可能带来精度损失,量化/反量化有额外开销 大模型低成本部署
Sliding Window Cache 只保留最近窗口内的 KV,丢弃更早历史 显存占用可控,适合超长序列 会损失远距离依赖信息 长文档流式推理、局部上下文任务;典型如 Mistral-7B、Qwen-1.5 的部分版本
Attention Sinks / StreamingLLM 永久保留开头少量 token 作为 sink,再配合滑动窗口 比纯滑窗更稳,可支持更长甚至近似无限的流式生成 仍会丢失中间较远的细节,设计依赖具体任务 流式对话、长时间在线生成
动态/稀疏 KV 驱逐 根据 attention 分数或重要性动态保留高价值 token,丢弃不重要 KV 比纯时间窗口保留更多全局信息 需要额外重要性评估逻辑,底层 Kernel 适配更难 H2O、SnapKV 一类方法
Speculative Decoding 先用小模型草拟多个 token,再由大模型批量验证 显著减少大模型解码步数,降低端到端延迟 需要额外 draft model,通常还会增加一点显存占用 追求低延迟的在线生成

KV Cache 相关优化,本质上是在 显存占用、推理延迟、实现复杂度 之间做权衡。不同策略解决的问题不一样,实际工程里通常会组合使用。

其中有两个修正点值得单独强调:

  • KV Quantization 不只包括 INT 格式:在 Hopper 架构上,FP8 在实际工程里非常常见,动态范围和硬件支持通常更好,也更容易和 Tensor Core 结合。很多时候它比 INT8 更接近“几乎无损”的量化路径。
  • Speculative Decoding 不一定省显存:它的核心收益是减少大模型的解码步数,用计算量和少量额外显存去换更低的延迟;由于还要维护 draft model 及其缓存,整体显存占用通常不会比基础解码更低。

补充理解

这些方案并不是互斥的,而且也不完全属于同一层面的优化。比如一个系统可以同时使用:

  • PagedAttention 来管理物理内存布局;
  • Prefix Cache 来复用重复 prompt;
  • KV Quantization 来压缩每个 block 的存储成本;
  • Sliding Window Cache + Attention Sinks 来做流式长上下文生成。

所以,KV Cache 优化更像是一个系统设计问题,而不是单一算法问题。最终目标都是:尽量少算、少占、少搬运

如果你还想继续往下补,下一块通常会接 FlashAttention / FlashAttention-2,它主要优化 attention 计算时的显存访问和中间结果存储。

  • Title: new-about-transformer
  • Author: Ikko
  • Created at : 2025-07-21 14:33:12
  • Updated at : 2026-05-04 13:51:54
  • Link: http://ikko-debug.github.io/2025/07/21/new-about-transformer/
  • License: This work is licensed under CC BY-NC-SA 4.0.
Comments