qwen部署
环境与部署
1 | conda create -n vllm-qwen python=3.11 -y |
这个模型必须通过两卡 + pipeline parallel + 降并发的方式部署,否则会因为显存不足而报错。
1 | CUDA_VISIBLE_DEVICES=0,1 vllm serve /home/share/HDstorage/xyc/qwen/models/qwen14b \ |
部署成功后,可以通过以下方式测试:
1 | curl http://127.0.0.1:8000/v1/chat/completions \ |
关闭服务:
1 | pkill -f "vllm serve /home/share/HDstorage/xyc/qwen/models/qwen14b" |
任务目标
这次工作分三段:
- 基于 vLLM 的 BitsAndBytes 量化支持,跑通小型 4-bit 模型推理。
- 分析 NF4 权重恢复在量化推理链路中的位置与开销。
- 基于 vLLM 的自定义量化扩展机制,设计实验性的 NF4 恢复/融合路径,并在 decode-like 小 batch 场景下评估其收益。
结论先说
这轮实验已经回答了几个关键问题:
- vLLM 跑 pre-quantized BitsAndBytes 4-bit 模型时,普通 dense 层并不是“先把 NF4 权重整层反量化到显存再计算”,而是尽量直接消费 packed weight 和
QuantState。 - 显式反量化成整块 dense bf16 权重的路径,主要瓶颈不是 LUT 查表本身,而是中间 dense weight 的显存写回和后续再读取。
- 如果把 bitsandbytes 小 batch fast path 的 GEMV/GEMM 组织方式迁到自定义 kernel,再把其中的 NF4 解码部分换成自己的版本,确实可以在 decode-like 小 batch 上得到端到端正收益。
- 只替换
q_proj时,真实 vLLM 服务端到端吞吐大约提升1.6%;扩到q/k/v/o后,提升大约2.5%;单独替down_proj基本没有端到端收益。
我是怎么做的
1. 跑通 vLLM + BitsAndBytes 4-bit 推理
实际运行环境在服务器 deep。模型为 unsloth/Qwen2.5-14B-Instruct-bnb-4bit 的本地副本。服务通过两卡、pipeline-parallel-size=2、max-num-seqs=1 成功启动并返回结果。
这里需要特别说明:真正跑通的是 PP=2,不是 TP=2。vLLM 日志里显示:
- rank 0:
PP rank 0, TP rank 0 - rank 1:
PP rank 1, TP rank 0
这是因为 vLLM 对 pre-quantized BitsAndBytes checkpoint 不支持 tensor parallel,但允许 pipeline parallel。
2. 梳理 NF4 在 vLLM 链路中的位置
我直接读了 vLLM 安装目录里的 BitsAndBytes 相关代码,重点看了:
vllm/model_executor/model_loader/bitsandbytes_loader.pyvllm/model_executor/layers/quantization/bitsandbytes.py
结论如下:
- loader 会先把 checkpoint 中的量化元数据组装成
QuantState。 - pre-quantized checkpoint 的
weight.absmax、weight.quant_map、weight.nested_absmax、weight.nested_quant_map等元数据,会在加载阶段被收集并重建。 _bind_quant_states_to_params()会把 double-quant 的 scale 元数据恢复出来,也就是把 nested 的absmax展开成 float32。- 但 packed 的 4-bit 权重本体仍然保留为
uint8,不会在普通 dense 路径中提前整层展开成 bf16/fp16 常驻显存。 - dense 线性层热路径最终走的是
bitsandbytes.matmul_4bit(x, packed_weight.T, quant_state)。 - MoE 那条路径更接近“先
dequantize_4bit(...),再交给后续算子”。
所以更准确地说,NF4 “恢复”发生在两个层面:
- 加载阶段:恢复
QuantState和 double-quant 的 scale 元数据。 - 推理阶段:普通 dense 层尽量直接消费 packed weight;如果走显式
dequantize_4bit路径,才会生成完整 dense weight。
3. 先做服务 benchmark
我写了一个 20 请求的流式 benchmark,统计:
- 首 token 延迟 TTFT
- 平均端到端耗时
- completion 吞吐
- 不同 prompt 长度下的变化
基线结果是:
- overall avg TTFT:
0.0858 s - p50/p95 TTFT:
0.0791 s / 0.1154 s - overall completion throughput:
70.76 tok/s
分组后大致如下:
- short, 约
88prompt tokens:0.0793 s,70.24 tok/s - medium, 约
363prompt tokens:0.0861 s,71.47 tok/s - long, 约
963prompt tokens:0.0986 s,69.70 tok/s
结论是:低并发条件下,prompt 变长会推高 TTFT,但 decode 吞吐整体比较稳定。
显式反量化路径的实验
实验设计
我没有直接改 bitsandbytes 本体,而是在自己的仓库里做了一套实验性扩展:
fs_plugins/custom_ops/nf4_ikko.cppfs_plugins/custom_ops/nf4_ikko.cufs_plugins/custom_ops/nf4_ikko.pybenchmark_nf4_ikko.py
第一版的目标不是替生产实现,而是把成本拆开量化。具体做法是:
- 从 checkpoint 读取 packed 4-bit 权重和量化元数据。
- 自己实现 NF4 restore kernel,把 packed weight 恢复成 dense bf16 权重。
- 再调用
F.linear。 - 与默认
bitsandbytes.matmul_4bit做单层对照。
恢复 kernel 中迁入的优化
后续我又把 mainla.cu 里的关键优化迁到了 restore kernel:
- shared LUT
- warp 内广播 scale
- pair write
- tail 处理
单层结论
我选了两个代表层做对照:
model.layers.0.self_attn.q_proj.weightmodel.layers.0.mlp.down_proj.weight
结果很一致:restore kernel 本身可以通过优化做快,但“先恢复整层 dense 权重再算”这条路径整体仍然打不过默认 bnb。
down_proj 上,优化后:
- restore+decode:
0.2600 ms -> 0.2424 ms - decode only:
0.2407 ms -> 0.2120 ms
但 restore+linear 总成本仍高于 bnb:
- batch 1: bnb
0.0314 ms,显式 restore0.3227 ms - batch 8: bnb
0.3399 ms,显式 restore0.3591 ms
q_proj 上,优化后:
- restore+decode:
0.1087 ms -> 0.0746 ms - decode only:
0.0749 ms -> 0.0415 ms
但总路径依然落后于 bnb:
- batch 1: bnb
0.0283 ms,显式 restore0.0596 ms - batch 8: bnb
0.0612 ms,显式 restore0.0686 ms
这里真正慢在哪
结论很明确:显式反量化路径的主要损失不在 LUT 或 scale 恢复,而在:
- 把整块 dense bf16 权重写回显存。
- 后续 GEMM 又要把这块 dense weight 从显存读一遍。
也就是说,显式 restore 的结构性问题是中间写回开销,而不是 NF4 解码逻辑本身。
从显式 restore 转向 fused decode + matmul
为什么转向 fused
既然问题出在“恢复成整块 dense 再写回显存”,那合理方向就是不再落地 dense weight,而是在 kernel 内边解码边做乘法。
第一版 fused 原型
我先做了一个朴素 fused 原型:
- 输入
x - 输入 packed
uint8weight - 输入
quant_state.absmax - 在 kernel 里边解码边累加
结果说明方向是对的,但朴素 per-output 累加写法太慢,打不过 bnb。
借 bitsandbytes kernel 组织方式做第二轮迭代
后面我参考了 bitsandbytes 源码里的 kernels.cu,特别是 kgemm_4bit_inference_naive 这类小 batch fast path 的组织方式,把 warp/GEMV 风格迁到了自己的 kernel 里,再把其中的 NF4 解码部分改成自己的实现。
这个版本的核心点是:
- 仍然直接消费 packed weight。
- 不生成整块 dense weight。
- 改成更接近 bnb fast path 的 warp 级输出组织。
修正 packed weight 行偏移错误之后,这条 fused 路径稳定跑通。
单层 fused 对比
对 q_proj 的 microbenchmark,结果如下:
- batch 1
- bnb:
0.0159 ms - 显式 restore:
0.0595 ms - 新 fused:
0.0129 ms
- bnb:
- batch 2
- bnb:
0.0601 ms - 显式 restore:
0.0668 ms - 新 fused:
0.0227 ms
- bnb:
- batch 4
- bnb:
0.0597 ms - 显式 restore:
0.0670 ms - 新 fused:
0.0413 ms
- bnb:
- batch 8
- bnb:
0.0620 ms - 显式 restore:
0.0689 ms - 新 fused:
0.0787 ms
- bnb:
这说明 fused 方向在 decode-like 小 batch 场景上是成立的:batch 1/2/4 可以赢,batch 8 开始退化。
嵌入 vLLM 真实推理流程
接入方式
我没有直接改 vLLM 安装包,而是用了 monkeypatch:
sitecustomize.pyvllm_ikko_sitecustomize.py
启动 vLLM 时通过环境变量控制:
PYTHONPATH=/home/xyc/PCFG-NATVLLM_IKKO_ENABLE=1VLLM_IKKO_MODE=<mode>
patch 的是 BitsAndBytesLinearMethod._apply_4bit_weight,即只替换指定层的 4-bit 线性层计算路径,其他层仍走默认 bnb。
支持的模式
目前已经做了这几类:
qproj: 只替换shape == (5120, 5120)的q_projattn: 替换 attention 线相关投影(5120, 5120):q_proj/o_proj(1024, 5120):k_proj/v_proj
down_proj: 只替换shape == (5120, 13824)的down_proj
端到端结果
所有端到端对比都使用同一组服务参数:
- 两卡
PP=2 --gpu-memory-utilization 0.6--max-num-seqs 1- 20 个串行流式请求
- 相同 benchmark 脚本和 prompt 组
1. 基线服务
默认 bitsandbytes 路径:
- avg TTFT:
0.0994 s - avg e2e:
0.5774 s - throughput:
68.58 tok/s
2. 只替 q_proj
- avg TTFT:
0.0896 s - avg e2e:
0.5684 s - throughput:
69.67 tok/s
相对基线:
- throughput 提升约
+1.6% - avg e2e 降低约
-1.6%
3. 扩到 q/k/v/o 四个 attention 投影
- avg TTFT:
0.0884 s - avg e2e:
0.5634 s - throughput:
70.28 tok/s
相对基线:
- throughput 提升约
+2.5% - avg e2e 降低约
-2.4%
这说明 attention 线整体替换比只替 q_proj 更有效。
4. 只替 down_proj
我又额外跑了一轮 down_proj 的端到端对比:
- avg TTFT:
0.0894 s - avg e2e:
0.5681 s - throughput:
69.70 tok/s
相对基线:
- throughput 提升约
+1.6% - avg e2e 降低约
-1.6%
这个数值和只替 q_proj 很接近,但没有超过 attention 全替换版本。结合前面的单层实验,可以更稳妥地理解为:
down_proj单层上并不是最适合 decode-like 优化的层。- 端到端里出现的小幅改善,更多是“局部替换仍然能成立”,但收益上限不如 attention 线。
- 真正最值得继续扩的仍然是
q/k/v/o这类更贴近 decode 热路径的层。
补充 Benchmark:固定 Prompt + 5 Repeats
前面的端到端结果是单轮 benchmark,能说明方向,但还不够干净。为了收尾,我又补了一轮更严格的对照:
- 固定 20 条请求,不再在不同模式之间重新生成 prompt。
- 只保留三组最有代表性的模式:
- baseline
q_projattn(q/k/v/o)
- 每组都跑 5 轮,统计
mean / std / min / max。
这轮 benchmark 使用的是同一份固定 prompt 集 fixed_qwen_prompts.json,每一轮都按完全相同的请求顺序执行。
结果汇总
| mode | avg TTFT mean ± std (s) | avg e2e mean ± std (s) | throughput mean ± std (tok/s) | throughput min/max |
|---|---|---|---|---|
| baseline | 0.0800 ± 0.0004 |
0.5605 ± 0.0005 |
71.01 ± 0.07 |
70.91 / 71.08 |
| q_proj | 0.0811 ± 0.0040 |
0.5620 ± 0.0029 |
70.75 ± 0.52 |
69.84 / 71.04 |
| attn | 0.0810 ± 0.0038 |
0.5619 ± 0.0039 |
70.76 ± 0.64 |
69.62 / 71.07 |
这一轮更干净 benchmark 的结论
这组结果和前面的单轮结果相比,更适合拿来做最终结论。因为它把 prompt 集和重复波动都控制住了。
从这轮数据看:
- baseline 的平均吞吐反而略高于
q_proj和attn两个自定义路径。 - 三组模式之间的差距已经缩到
0.2~0.3 tok/s量级,远小于q_proj/attn自身跨轮波动。 q_proj和attn的std明显高于 baseline,说明自定义路径当前还没有展现出更稳定的端到端收益。
所以如果只看这轮“固定 prompt + 5 repeats”的正式 benchmark,应该更保守地下结论:
- 自定义 fused 路径已经可以正确嵌入 vLLM 真实推理流程。
- 它在单层 microbenchmark 上能在 decode-like 小 batch 场景取得优势。
- 但在当前实现水平下,这个优势还没有稳定转化成端到端、统计上更有说服力的收益。
换句话说,前面那组 +1.6% / +2.5% 更适合被理解成“单轮观测到的正向信号”,而不是已经被重复实验充分证实的稳定收益。
Profiling 与结果解释
为了把现象解释清楚,我又补了一轮 profiling。原本想直接用 ncu / nsys 拿硬件计数器,但这台服务器上的 CUDA driver 和 Nsight CLI 版本有兼容性问题,硬件计数器和标准 report 导出都不稳定,所以最后采用的是 torch.profiler 做算子级 CUDA 时间统计。这个方法拿不到 occupancy、L2 hit rate 这类硬件计数器,但足够回答下面三个问题:
batch 8的退化是不是发生在 fused kernel 本体。- 退化更像 warp 利用率问题,还是更像访存/数据复用问题。
- attention 线为什么比
down_proj更值得继续优化。
1. batch 8 为什么开始退化
先看 fused kernel 本体的 CUDA 时间。下面的数字是 torch.profiler 对 20 次调用统计出的平均单次 CUDA 时间:
q_proj, batch 1:12.354 usq_proj, batch 8:77.877 usdown_proj, batch 1:29.637 usdown_proj, batch 8:205.781 us
这个结果说明两点:
- 退化确实集中在 fused kernel 本身,不是 Python 调度、
aten::t、cudaLaunchKernel之类的外围开销。profiler 里最主要的 CUDA 时间全部落在nf4_fused_matmul_absmax_*上。 - 当前 kernel 在
batch 1/2/4的工作点上是合适的,但到batch 8已经开始从“decode-like 小 batch GEMV”向“小 GEMM”过渡了。现在这版实现仍然沿用偏小 batch 的 warp/GEMV 组织,所以当M增大时,对输入激活x的复用不够好,batch 维上的工作没有像真正 tiled GEMM 那样被高效摊开。
更直白一点说:batch 8 慢下来,不是 NF4 解码突然出了问题,而是当前 fused kernel 的最优工作区间本来就偏 batch 1/2/4。当 batch 增长后,kernel mapping 开始偏离最优点。
2. 更像 warp 利用率问题,还是寄存器/访存平衡问题
从这轮 profiling 和现有 kernel 结构看,主因更像“访存/数据复用不足”,其次才可能是寄存器压力,而不是单纯的 warp 利用率问题。
依据主要有三条:
- 当前 fused kernel 在
q_proj的batch 1/2/4上已经能赢 bitsandbytes,这说明 warp 级输出组织本身不是根本错误。如果 warp 利用率一开始就很差,它不会在这些工作点上取得正收益。 batch 8的退化是随着M增大逐步出现的,更符合“输入x和量化权重/scale 的流式读取量增加,而 shared-memory / tile 复用不够”的模式。- profiler 看到的主要增长来自 kernel CUDA 时间本体,而不是 launch 次数或额外算子堆积,这说明问题在 kernel 内部的数据流,而不是外层调度。
可以把当前 fused kernel 理解成:它已经完成了“边解码边乘”的第一步,但还没有做到真正 GEMM-style 的 tile 化。于是它的瓶颈更像:
- packed weight 和
absmax仍然偏 streaming 访问 x在 batch 维增大后没有被充分复用- 算术强度不够高,更多表现为内存流量随 batch 增大而放大
所以这部分更适合写成:
- 现阶段更像“访存/数据复用与 kernel mapping 问题”
- 不是先把锅甩给寄存器或 occupancy
当然,要最终把这个判断坐实,后面还是应该补一轮真正的硬件 profiling,目标指标包括:
- achieved occupancy
- registers per thread
- local memory spill
- dram throughput
- L2 hit rate
- warp stall reason
3. 为什么 attention 线比 down_proj 更值得优化
这部分既能从端到端结果看,也能从 profiling 里看。
先看端到端:
- 只替
q_proj:68.58 -> 69.67 tok/s - 替
q/k/v/o:68.58 -> 70.28 tok/s - 只替
down_proj:68.58 -> 69.70 tok/s
attention 线整体替换的收益最高。
再看 profiler 中 fused kernel 的单次 CUDA 时间:
q_proj, batch 1:12.354 usq_proj, batch 8:77.877 usdown_proj, batch 1:29.637 usdown_proj, batch 8:205.781 us
当前 kernel 对 attention 线更友好,主要有三个原因:
attention 投影更贴近 decode 热路径
在自回归 decode 里,每一步都会频繁经过q/k/v/o这些投影,而且输入 batch 通常很小,这正好落在当前 fused kernel 最擅长的工作区间。attention 投影的形状更匹配当前 warp/GEMV 风格实现
当前已经验证过的 attention 形状主要是:(5120, 5120):q_proj/o_proj(1024, 5120):k_proj/v_proj
这些形状更接近当前 kernel 的小 batch 目标场景。
down_proj更宽,更容易暴露当前 kernel 的访存问题down_proj的形状是(5120, 13824),输出更宽,意味着每次调用要流过更多 packed weight 和 scale 数据。由于当前 fused kernel 还没有做真正的 shared-memory tile 复用,这类更宽的矩阵更容易变成内存流量主导。
所以当前阶段最合理的结论不是“down_proj 没价值”,而是:
down_proj也可以接,而且局部路径是能跑通的。- 但它对 kernel 组织的要求更接近通用 GEMM,而不是 decode-like GEMV。
- 在现有实现水平下,attention 线更容易把 fused 路径的优势转成真实端到端收益。
最终判断
NF4 恢复到底发生在哪
- 加载阶段恢复的是
QuantState和 double-quant scale 元数据。 - 普通 dense 路径不会先把整层 NF4 权重恢复成 dense bf16 常驻显存。
- 如果走显式
dequantize_4bit路径,才会发生整层 dense weight 的中间写回。
显式反量化值不值得
对 decode-like 小 batch 场景,不值得直接采用“restore -> dense -> linear”这条路径。即使 restore kernel 局部优化有效,中间写回开销仍然会吞掉大部分收益。
什么方向值得继续
当前最值得继续的方向不是进一步优化显式 restore,而是:
- 继续做直接消费 packed weight 的 fused decode + matmul。
- 优先替 attention 线上的更多层。
- 重点服务 decode-like 小 batch,而不是追求一开始就做通用大 GEMM。
这次改了哪些东西
本地仓库中,本轮核心修改集中在这些文件:
bench_vllm_qwen.pybenchmark_nf4_ikko.pyfs_plugins/custom_ops/nf4_ikko.cppfs_plugins/custom_ops/nf4_ikko.cufs_plugins/custom_ops/nf4_ikko.pysitecustomize.pyvllm_ikko_sitecustomize.pyvllm_bnb_benchmark_notes.mdoutputs/reports/nf4_ikko_experiments.md
其中:
nf4_ikko.cu是 restore kernel 和 fused kernel 的主体。nf4_ikko.py负责 PyTorch 扩展加载与接口封装。vllm_ikko_sitecustomize.py负责把自定义 fused 路径嵌到 vLLM 的 BitsAndBytes 线性层计算中。
小结
这轮实验最大的收获不是“证明 NF4 restore kernel 能更快”,而是把问题切清楚了:
- 默认 bnb 路径真正占优的关键,不只是 LUT 或解码实现,而是它避免了中间 dense weight 写回。
- 只优化 restore 本身,端到端收益有限。
- 直接消费 packed weight 的 fused 路径,在真实 vLLM 服务里已经能拿到小幅但稳定的正收益。
- 目前最值得继续扩展的是 attention 线,而不是单纯围绕显式反量化做更多微优化。
真实命中点排查
在继续做端到端对比之前,我专门做了一轮“真实调用点排查”。原因是前面我尝试在 BitsAndBytesLinearMethod._apply_4bit_weight 这一层做 timing hook,但 hook 没有命中真实 serve 热路径,所以那一层不适合作为最终 profiling 入口。
排查顺序按下面三层往下走:
- vLLM 自己的 BitsAndBytes 包装层
apply_bnb_4bitbitsandbytes.matmul_4bit
源码链路
从 vLLM 安装目录的源码可以直接确认这条链:
_apply_4bit_weightapply_bnb_4bit_apply_bnb_4bitbitsandbytes.matmul_4bit
其中:
apply_bnb_4bit不是普通 Python 函数,而是torch.ops.vllm.apply_bnb_4bit- 它是由 vLLM 把
_apply_bnb_4bit注册成 custom op 得到的
真实请求命中结果
我直接在安装包源码里给这两层加了最粗粒度探针,只记录:
- 有没有被调用
- 输入 shape
- quantized weight 的 shard shape
- 在
matmul_4bit里最终走的是gemv_4bit还是MatMul4Bit.apply
真实请求之后,两个日志文件都稳定命中:
/home/xyc/vllm_apply_bnb_4bit_hits.log/home/xyc/bnb_matmul_4bit_hits.log
关键结论如下:
apply_bnb_4bit确实是vllm serve真实 decode 流程中的热路径。- decode 阶段的
A=(1, 5120)/A=(1, 13824)最终落到了bitsandbytes.matmul_4bit。 - 对低 batch decode,
matmul_4bit最终走的是gemv_4bitfast path,不是MatMul4Bit.apply慢路径。 - prefill 大 shape(例如
A=(2048, 5120))更接近MatMul4Bit.apply这条通用路径。
也就是说,真实 decode 热路径并不是:
- “先
dequantize_4bit再通用linear”
而是:
apply_bnb_4bit -> bitsandbytes.matmul_4bit -> gemv_4bit
q/k/v/o 的真实 shard 顺序
进一步把 _apply_bnb_4bit 里 quant_states[i] 的真实循环顺序打出来之后,可以确认 attention 线在 decode 阶段的三分支顺序是:
q_projk_projv_proj
对应的实际 shard shape 是:
q_proj:(5120, 5120)k_proj:(1024, 5120)v_proj:(1024, 5120)
而独立的 o_proj 则是单独的:
o_proj:(5120, 5120)
这一步很重要,因为它说明后续如果要做 q/k/v/o 的逐项 timing,就应该把上下文从 _apply_bnb_4bit 传到 bitsandbytes.matmul_4bit,而不是继续在更上层的错误入口做统计。
我是如何定位、替换并计时的
这一段单独总结一下方法。整个过程我实际上是按三步做的:
- 先找真实调用链
- 再选替换点
- 最后做计时
1. 我是怎么找到 vLLM 调 bnb 的
最开始先读 vLLM 里 BitsAndBytes 相关代码,入口是:
vllm/model_executor/layers/quantization/bitsandbytes.py
从这里可以先看到高层结构:
BitsAndBytesLinearMethod._apply_4bit_weight(...)- 它会调
apply_bnb_4bit(...) - 这个名字虽然像 Python 函数,但实际注册成了
torch.ops.vllm.apply_bnb_4bit
然后继续顺着看 bitsandbytes 本身,在:
bitsandbytes/autograd/_functions.py
里可以确认:
bitsandbytes.matmul_4bit(...)- decode-like 小 batch 时会优先走
gemv_4bit - 大 shape 才更容易落到
MatMul4Bit.apply
但光读代码还不够,因为 vllm serve 的真实 worker 路径不一定和最初猜的 Python 层入口一致。所以接下来做的不是先计时,而是先打最粗粒度的“命中探针”。
我在远端安装包里直接加了最简单的日志,只记录:
- 有没有被调用
- 输入 shape
- weight shard shape
- 最终走的是
gemv_4bit还是MatMul4Bit.apply
具体加在这两个位置:
- vLLM 的
bitsandbytes.py里,给apply_bnb_4bit/_apply_bnb_4bit打日志 - bitsandbytes 的
_functions.py里,给matmul_4bit打日志
真实请求跑完后,日志文件稳定命中:
/home/xyc/vllm_apply_bnb_4bit_hits.log/home/xyc/bnb_matmul_4bit_hits.log
这样就能确认真实 decode 热路径是:
BitsAndBytesLinearMethod._apply_4bit_weightapply_bnb_4bit_apply_bnb_4bitbitsandbytes.matmul_4bitgemv_4bit
也就是说,真实低 batch decode 不是走通用的 dequantize_4bit + linear,而是走 bnb 的 gemv_4bit fast path。
2. 我是怎么替换的
替换我做过两种方式。
第一种是早期的猴补丁:
sitecustomize.pyvllm_ikko_sitecustomize.py
这条路适合快速试验,但后来验证发现,它不能稳定命中 vllm serve 的真实 worker 热路径,所以后面的正式分析没有继续依赖它。
第二种是直接改远端已安装包,这才是后面真正用于确认和实验的主线。
我在远端这两个文件里做了替换逻辑:
- vLLM 的
bitsandbytes.py - bitsandbytes 的
_functions.py
替换策略不是全局替换,而是只在目标形状上切换:
(5120, 5120)对应q_proj/o_proj(1024, 5120)对应k_proj/v_proj
真正的自定义实现仍然放在本地仓库:
fs_plugins/custom_ops/nf4_ikko.cufs_plugins/custom_ops/nf4_ikko.cppfs_plugins/custom_ops/nf4_ikko.py
然后在远端 matmul_4bit 的真实命中点上做条件分支:
- baseline:继续走原始
gemv_4bit - patch:如果是 attention 投影的目标 shape,就切到
matmul_4bit_ikko - 其他层全部保持原样
这样做的好处是:
- 调用链没变
- 只替换真实热路径里的局部算子
- baseline 和 patch 的系统环境几乎一致,更便于对比
3. 我是怎么计时的
计时也分三层。
第一层是命中确认,不计时间。
这是为了避免在错误入口上做 profiling。先只记录:
- 有没有被调用
xshapeqweightshape- shard 顺序
靠这个先确认:
q/k/v三个 shard 在真实_apply_bnb_4bit里的顺序o_proj是单独一个(5120, 5120)调用- decode 阶段最终落在
gemv_4bit
第二层是在真实命中点做聚合计时。
命中点确认之后,我在 bitsandbytes 的 _functions.py 里给 matmul_4bit 加了 CUDA event 计时。做法是:
- 进入
matmul_4bit前记录torch.cuda.Event - 返回后再记录一个 event
torch.cuda.synchronize()后取 elapsed time- 按 layer kind 聚合
为了让 matmul_4bit 知道当前是哪一个 projection,我先在 vLLM 那层把当前 shard 标成:
q_projk_projv_projo_proj
然后把这个上下文传给 bitsandbytes 层。这样 matmul_4bit 内部就知道这次调用属于哪个 projection,最后每个 worker 进程把统计写到 JSON,例如:
/home/xyc/proj_baseline_eager.*.json/home/xyc/proj_attn_eager.*.json
再把多个 worker 的 JSON 聚合,得到:
q_projk_projv_projo_proj- 合成后的
per-token projection time
第三层是和端到端 benchmark 对齐。
端到端我单独用 bench_vllm_qwen.py 跑固定 prompt、重复多轮。这样最终手里有两组数据:
- 局部真实命中点时间
- projection time 到底降了多少
- 端到端系统时间
- throughput / e2e 到底有没有跟着降
靠这两组数据,才能最后判断:
- kernel 接入是否真的生效
- 局部优化有没有传导到系统收益
- 如果没有,问题是在 kernel 本身还是系统其他部分
一句话总结
整个过程可以压成四步:
- 先从 vLLM 的
bitsandbytes.py和 bitsandbytes 的_functions.py读调用链 - 再在远端真实安装包里加命中日志,确认
vllm servedecode 真正走到matmul_4bit -> gemv_4bit - 然后只在这个真实命中点上做条件替换,把 attention 投影切到自定义 fused kernel
- 最后在这个真实命中点上做 CUDA event 聚合计时,再和固定 prompt 的端到端 benchmark 对比
Attention Projection 时间对比
在确认真实命中点之后,我又做了一轮更直接的实验:
- 把
q/k/v/o的上下文从 vLLM 的_apply_bnb_4bit传到bitsandbytes.matmul_4bit - 在
matmul_4bit这个真实热路径上聚合 decode 阶段的 CUDA 时间 - 做两组对比:
- baseline
- attn patch
为什么这里用了辅助配置
这里有一个必须说明的点。
如果直接在生产态配置下给 matmul_4bit 内联插入 CUDA event timing,vLLM 的 torch.compile / cudagraph / capture 流程会在 warmup 阶段报错。因此:
- 调用链的确认,是在生产态配置下完成的
- 时间聚合,则放到辅助观测配置下完成:
--enforce-eager
这个配置不适合拿来代表最终 e2e 性能,但非常适合回答一个更具体的问题:
- “真实命中点上的 attention projection 时间,到底降了没有?”
baseline vs attn patch
两边都使用同一份固定 prompt 集,并只统计 decode 阶段命中的 q/k/v/o 投影时间。
结果如下,单位都是 ms_per_token_row:
| projection | baseline | attn patch |
|---|---|---|
q_proj |
0.03298 ms |
0.01560 ms |
k_proj |
0.01964 ms |
0.01088 ms |
v_proj |
0.02114 ms |
0.01024 ms |
o_proj |
0.03620 ms |
0.01151 ms |
把四项加总得到:
- baseline per-token projection time:
0.10997 ms - attn patch per-token projection time:
0.04824 ms
也就是:
- 每 token 的 attention projection 总时间下降了
0.06173 ms - 降幅约
56.1%
这说明什么
这个结果非常关键,因为它回答了一个此前没有分清的问题:
- attention projection 这一项本身,确实明显下降了。
- 所以问题不在“kernel 根本没接进去”。
- 如果 e2e benchmark 仍然没有稳定改善,那么更大的问题就在系统别处,而不是 attention projection 本身。
换句话说,当前结论应该分两层:
kernel / projection 子系统层面:优化是成立的,而且降幅很明显serve 端到端层面:收益还没有稳定穿透整条系统链路
这也把后续方向收窄了:
- 继续提升 attention 覆盖率是合理的,但不是唯一问题。
- 需要进一步量化非 attention 层、框架调度、prefill 路径以及其他 decode 固定开销。
- 如果 projection 已经降了 50% 以上而 e2e 没明显跟上,那真正拖慢整体的部分已经不再是 attention projection 本身。
理论收益上限 vs 实测收益
在确认了 per-token projection time 的真实下降之后,可以进一步算一个很关键的量:
- 即便 attention projection 优化完全兑现到 decode 阶段,理论上最多能给端到端 completion 吞吐带来多大提升?
1. baseline 的 decode per-token 总时间
用前面“固定 prompt + 5 repeats”的 baseline completion 吞吐:
- baseline throughput:
71.01 tok/s
把它换算成每个输出 token 的平均时间:
- baseline decode per-token total time =
1000 / 71.01 ≈ 14.08 ms/token
也就是说,从系统视角看,baseline 每生成一个 token,大约要花 14.08 ms。
2. projection 优化带来的理论可回收时间
前面的真实命中点 timing 已经算出:
- baseline per-token projection time:
0.10997 ms - attn patch per-token projection time:
0.04824 ms
所以 attention projection 这部分每个 token 实际减少了:
0.10997 - 0.04824 = 0.06173 ms/token
3. 理论收益上限
如果假设:
- 这
0.06173 ms/token能无损地全部体现在 decode 总时间上 - 其他所有部分都完全不变
- 没有新增任何框架开销、同步开销或集成成本
那么新的理论最优 decode 时间应为:
14.08 - 0.06173 = 14.02 ms/token
对应的理论最优吞吐约为:
1000 / 14.02 ≈ 71.32 tok/s
相对 baseline 的理论上限提升约为:
(14.08 - 14.02) / 14.08 ≈ 0.44%
也可以直接理解成:
- 即便 attention projection 这一项已经降了
56.1% - 但它在整个 decode token 时间里原本只占
0.10997 / 14.08 ≈ 0.78% - 所以它能带来的端到端理论上限,本来就只有大约
0.44%
4. 和实际 benchmark 的对照
更干净的 5 repeats benchmark 结果是:
- baseline:
71.01 ± 0.07 tok/s - attn:
70.76 ± 0.64 tok/s
所以实测上:
- 实际没有出现稳定正收益
- 反而比 baseline 低了大约
0.25 tok/s - 约为
-0.35%
把这个结果和理论上限放在一起看,就很容易理解了:
- attention projection 优化本身是成立的,而且在真实命中点上降幅很大。
- 但它在 decode 总时间中所占比例本来就很小,因此理论收益上限只有
0.44%左右。 - 这个量级已经接近甚至低于系统级 benchmark 的自然波动和额外集成开销。
- 所以最终没有在 e2e benchmark 里形成稳定提升,是完全合理的。
5. 这一步真正说明了什么
这一步最大的价值不是证明“优化失败”,而是把问题定量化了:
- attention projection 不是没有优化成功,而是它对整体 decode token 时间的贡献太小。
- 即使局部减少了一半以上,这一项对全局吞吐的理论影响仍然不到
0.5%。 - 这意味着如果想在端到端上看到稳定、明显的收益,下一步必须优化更大头的部分,而不是继续只盯着 attention projection。
理论上限与实际收益
既然已经确认了 attention projection 总项确实下降,下一步自然要回答:
- baseline 的 decode per-token 总时间是多少?
- 如果只优化了 attention projection,这个优化理论上最多能带来多大收益?
- 实验里最终测到的实际收益是多少?
1. baseline 的 decode per-token 总时间
用固定 prompt、5 次重复 benchmark 里的 baseline completion 吞吐:
- throughput:
71.01 tok/s
可以直接换算得到 decode 每生成 1 个 token 的平均时间:
- baseline decode per-token total time =
1 / 71.01 s - 约等于
14.08 ms/token
这就是 baseline 的整体 decode 时间预算。
2. attention projection 优化的理论收益上限
前面已经量到,在真实命中点上:
- baseline per-token projection time:
0.10997 ms - attn patch per-token projection time:
0.04824 ms
所以 attention projection 这一项实际节省的是:
0.10997 - 0.04824 = 0.06173 ms/token
把这个节省量放到 baseline 的总 decode 时间预算里:
- theoretical max speedup =
0.06173 / 14.08 - 约等于
0.44%
也就是说,即使把 attention projection 这一项的优化完全、无损地传导到端到端 decode,理论收益上限也只有大约:
0.44%
这个数字很关键,因为它解释了为什么前面虽然看到 projection 本身降了很多,但 e2e benchmark 仍然不明显。
3. 实验里测到的实际收益
固定 prompt、5 次重复 benchmark 的端到端结果是:
- baseline throughput:
71.01 ± 0.07 tok/s - attn throughput:
70.76 ± 0.64 tok/s
换成相对 baseline 的实际收益:
- throughput 变化约为
-0.35%
对应的平均端到端时延:
- baseline avg e2e:
0.5605 ± 0.0005 s - attn avg e2e:
0.5619 ± 0.0039 s
相对变化约为:
- avg e2e 变化约为
+0.25%
4. 这一组数字的含义
把三组数字放在一起看:
- baseline decode total:
14.08 ms/token - projection 节省量:
0.06173 ms/token - 理论上限:
0.44% - 实际测得:
-0.35%throughput,基本落在“没有稳定收益”的区间
所以现在可以更清楚地下结论:
- attention projection 确实优化成功了。
- 但它在整个 decode 总时间里只占很小一部分。
- 因此,即使这一项下降了
56%,折算到整体 decode 上,理论收益上限也只有不到0.5%。 - 这已经足以解释为什么端到端 benchmark 看不到稳定提升。
换句话说,当前不是“优化没生效”,而是:
- 优化生效了
- 但覆盖到的系统时间占比还不够大
- 所以总体收益天花板本来就很低
Decode 时间拆分:projection vs non-projection
为了把最后一个问题闭环,还需要再做一个简单但关键的拆分:
- total decode per-token time
- projection time
- non-projection time
其中:
non-projection time = total decode per-token time - projection time
这样就能直接回答:
- 除了“attention projection 预算太小”之外,是否还有额外系统开销把收益吃掉了?
1. baseline 的拆分
前面已经有两组 baseline 数据:
- baseline total decode per-token time:
14.08 ms/token - baseline per-token projection time:
0.10997 ms/token
所以 baseline 的 non-projection time 是:
14.08 - 0.10997 = 13.97 ms/token
更精确一点,按前文同样的换算口径:
- baseline total decode per-token time:
1000 / 71.01 ≈ 14.0825 ms/token - baseline non-projection time:
14.0825 - 0.10997 ≈ 13.9725 ms/token
2. attn patch 的拆分
attn patch 对应的数据是:
- attn throughput:
70.76 tok/s - attn total decode per-token time:
1000 / 70.76 ≈ 14.1323 ms/token - attn per-token projection time:
0.04824 ms/token
所以 attn patch 的 non-projection time 是:
14.1323 - 0.04824 ≈ 14.0841 ms/token
3. 拆分结果对照
把两边并排放在一起:
| 模式 | total decode | projection | non-projection |
|---|---|---|---|
| baseline | 14.0825 ms/token |
0.10997 ms/token |
13.9725 ms/token |
| attn patch | 14.1323 ms/token |
0.04824 ms/token |
14.0841 ms/token |
对应变化量:
- projection 下降:
0.10997 - 0.04824 = 0.06173 ms/token - non-projection 上升:
14.0841 - 13.9725 = 0.1116 ms/token - total 反而增加:
14.1323 - 14.0825 = 0.0498 ms/token
4. 这组数字说明什么
这组拆分把整件事说得比较完整了:
- attention projection 这一项本身确实明显下降了,节省了
0.06173 ms/token。 - 但与此同时,non-projection 部分反而增加了约
0.1116 ms/token。 - 因此在总账上,节省掉的 projection 时间不仅被完全吃掉了,还额外多出了一点系统开销。
也就是说,当前结果并不只是“projection 预算太小”,还包括:
- 接入自定义 kernel 之后,系统的非 projection 部分出现了少量额外成本
这类额外成本可能来自:
- 不同 kernel dispatch 带来的额外调度成本
- 与原始 bnb
gemv_4bitfast path 不同的 runtime 行为 - 框架层集成、同步、launch、capture 兼容性差异
- 更细粒度算子替换之后,原有重叠与流水被轻微打破
5. 最终闭环
因此,这轮实验的闭环结论可以写成:
- attention projection 优化本身是有效的,而且在真实命中点上下降了约
56.1%。 - 但这一项在 baseline decode 总时间里本来只占不到
1%,理论收益上限只有约0.44%。 - 同时,接入自定义 kernel 后,non-projection 部分又增加了大约
0.11 ms/token的额外系统开销。 - 于是最终端到端收益被完全抵消,clean benchmark 里看不到稳定提升。
需要说明的是,这里的 projection time 来自辅助观测配置下的真实命中点 timing,而 total decode time 来自固定 prompt、5 次重复的 clean benchmark;两者口径不完全相同,因此这个拆分更适合用来做系统解释,而不是把每一位小数都当成绝对精确值。
额外开销排查:到底可疑在哪
在做完上面的时间拆分之后,还需要进一步问一句:
- 这部分看起来像
+0.11 ms/token的 non-projection 开销,到底最可疑的是哪一层?
这里我先做了一轮代码级排查,而不是继续猜。
1. 已经基本排除的部分
有一类开销虽然存在,但不太可能解释 clean benchmark 里的端到端结果:
- profiling 用的
torch.cuda.Event end.synchronize()- JSON dump / atexit 写文件
这些逻辑确实存在于远端安装包里的 bitsandbytes _functions.py 改动中,但它们都挂在显式环境变量后面:
VLLM_IKKO_PROFILE_PROJ=1
只有做 projection timing 时才会启用。也就是说,它们会影响辅助观测配置下的 timing,但不应该进入最终那轮 clean benchmark。
所以 clean benchmark 里的额外系统成本,主因大概率不是 profiling 插桩残留。
2. 目前最可疑的真实来源
现在最可疑的一点,其实在自定义 op 的 Python 包装层里。
fused_matmul_4bit_ikko(...) 这一层当前的实现是:
1 | packed_weight = _as_cuda_contiguous(packed_weight_t.t().contiguous(), torch.uint8) |
也就是说,在每次调用时,它都会对输入的 packed weight 做:
.t().contiguous()_as_cuda_contiguous(...)
如果 packed_weight_t 在 vLLM / bitsandbytes 的真实调用点上本来不是目标布局,那么这里就可能在每次 projection 调用时引入一次额外的 packed-weight 重排或拷贝。
这件事的性质和前面分析的“projection 算子内部时间”不一样:
- projection timing 统计的是
matmul_4bit这一调用本身的时间 - 但如果包装层为了适配自定义 kernel 做了额外 layout 处理,它可能带来额外的显存流量和 runtime 成本
- 这正是最像“收益被系统侧吃掉”的那类问题
因此,当前最强的怀疑对象不是 NF4 decode 本身,而是:
- 自定义 fused kernel 仍然没有直接吃到 vLLM / bnb 真实路径上的原生 packed layout
- 中间为了适配接口,发生了额外的 transpose / contiguous / copy
3. 第二可疑来源
除了 weight layout 适配之外,第二可疑来源是接入方式本身:
- 现在 patch 是在 bitsandbytes 的 Python
matmul_4bit包装层做条件分支 - 命中 attention shape 后再转到自定义扩展
这意味着即使 projection 算子本体更快,仍然可能引入:
- 额外的 Python dispatch 分支
- 不同于原始
gemv_4bitfast path 的 runtime 行为 - 与 cudagraph / compile / capture 协同不如原生路径
这类成本通常不大,但在理论收益上限只有 0.44% 的前提下,已经足够把端到端收益吃光。
4. 当前阶段最合理的判断
基于现有代码排查,可以先下一个相对稳妥的判断:
- 额外开销的主因不像 profiling 插桩残留。
- 最可疑的真实来源,是自定义 op 包装层里对 packed weight 的
transpose + contiguous适配。 - 第二可疑来源,是在 Python 包装层切换路径,而不是像原生 bnb 那样直接落到更紧的 native fast path。
也就是说,这部分额外系统成本更像是:
- layout 适配成本
- 接入层 runtime 成本
而不是:
- NF4 恢复 kernel 本身太慢
5. 如果要继续验证,最值得做的一件事
下一步最有价值的单点验证不是继续改 kernel,而是直接改接口:
- 避免
packed_weight_t.t().contiguous() - 直接让自定义 kernel 消费 vLLM / bitsandbytes 当前真实命中的 packed layout
如果这样改完之后,端到端表现明显改善,那么就能基本确认:
- 被吃掉的收益,主要就是 layout 适配和接入层成本
6. 一个最小验证:直接吃当前 packed layout
针对上面的第一嫌疑,我又专门做了一轮最小验证:
- 不改算法
- 不加新优化
- 只把 kernel 接口改成直接消费当前
packed_weight_t的布局 - 目标是验证:去掉
t().contiguous()这一步之后,额外系统成本能不能明显下降
这轮修改新增了一条 direct-layout 路径:
fused_matmul_4bit_ikko_direct(...)- 对应 CUDA 实现:
nf4_fused_matmul_absmax_transposed_kernel
它和原来的 copy 版 fused path 的区别只有一个:
- 原版:先
packed_weight_t.t().contiguous(),再交给 fused kernel - direct 版:直接吃当前
packed_weight_t的原始布局
也就是说,这轮实验只验证“去掉 layout 适配”这一件事,没有引入新的算法变化。
在 q_proj 上的单层 microbenchmark 结果如下:
| batch | bnb | fused(copy) | fused(direct layout) |
|---|---|---|---|
| 1 | 0.01759 ms |
0.01308 ms |
0.04872 ms |
| 2 | 0.06081 ms |
0.02292 ms |
0.08468 ms |
| 4 | 0.06072 ms |
0.04186 ms |
0.15874 ms |
| 8 | 0.06302 ms |
0.07926 ms |
0.30421 ms |
这组结果说明:
- “直接吃当前 packed layout” 这件事本身并没有自动减少开销。
- 相反,这个 direct-layout 版本比原来的 copy 版 fused kernel 明显更慢。
- 这意味着此前怀疑的
t().contiguous()虽然可疑,但它至少不是一个可以靠“简单去掉”就收回收益的成本点。
更准确地说,问题不只是“有没有 copy”,还包括:
- 当前自定义 kernel 的访存组织,本身就是围绕转置后的 row-major packed layout 写的
- 一旦直接吃原始
packed_weight_t布局,访存模式会明显变差 - 所以虽然少了一步显式布局转换,但 kernel 内部的 global memory 访问更不友好,反而总体更慢
因此,这轮最小验证给出的结论是:
- 额外系统成本不能靠“直接去掉
t().contiguous()”这一招简单消除 - 如果真要把 layout 适配成本收回来,必须连 kernel 的访存组织一起改,而不是只改接口
7. 对“第二可疑来源”的验证:Python 包装层分支是不是主因
前面还有一个第二嫌疑:
- patch 现在是在 bitsandbytes 的 Python
matmul_4bit包装层里做条件分支 - 再从这里转到自定义 extension
这个路径确实比原生 gemv_4bit fast path 多了一层 runtime / dispatch,所以需要单独验证:
- 额外开销到底是不是主要来自这层 Python 分支?
为了把这个问题拆干净,我没有继续用 vllm serve 做整机实验,而是做了一个单层 wrapper-mode benchmark,只测 q_proj 的 decode-like 小 batch。这个实验固定同一个输入和同一个 matmul_4bit 调用点,只改三种模式:
baseline- 原始 bnb 路径
- 不传 projection 上下文
- 不走任何 patch
shim- 走同样的 projection 上下文和 Python 条件判断
- 但最终仍然落回原始 bnb
gemv_4bit - 也就是说,它只保留“接入层分支”,不切换到自定义 kernel
patch- 走同样的 projection 上下文和 Python 条件判断
- 命中后切到自定义 fused kernel
如果 shim 明显慢于 baseline,就说明 Python 包装层这点分支本身就已经很贵。
如果 shim ≈ baseline,而 patch 才和它们拉开,那就说明真正的成本主要在自定义 op 路径,而不是那几行 Python 判断。
q_proj 的结果如下:
| batch | baseline | shim | patch |
|---|---|---|---|
| 1 | 0.02289 ms |
0.01973 ms |
0.03404 ms |
| 2 | 0.06114 ms |
0.06057 ms |
0.06058 ms |
| 4 | 0.06084 ms |
0.06091 ms |
0.06086 ms |
| 8 | 0.06250 ms |
0.06253 ms |
0.06256 ms |
更关键的是差值:
- batch 1:
shim - baseline = -0.00317 mspatch - baseline = +0.01114 ms
- batch 2:
shim - baseline = -0.00057 mspatch - baseline = -0.00056 ms
- batch 4:
shim - baseline = +0.00008 mspatch - baseline = +0.00003 ms
- batch 8:
shim - baseline = +0.00003 mspatch - baseline = +0.00006 ms
这组结果可以很直接地说明:
shim几乎贴着baseline,说明 Python 包装层那点 projection 分支本身不是主要成本来源。- 至少在单层
matmul_4bit这个粒度上,“多一层 Python 条件判断”带来的成本基本可以忽略。 - 因此,第二可疑来源虽然在系统层面仍可能影响 capture / graph 协同,但它不是当前额外开销的主要解释。
换句话说,这轮验证把“第二嫌疑”明显降级了:
- Python 包装层分支不是主要问题
- 更值得继续怀疑的,仍然是自定义 kernel 路径本身的 layout / 访存 / runtime 组织
- Title: qwen部署
- Author: Ikko
- Created at : 2026-03-29 17:23:18
- Updated at : 2026-04-25 22:00:05
- Link: http://ikko-debug.github.io/2026/03/29/qwen/
- License: This work is licensed under CC BY-NC-SA 4.0.