HYMBA论文粗总结

Ikko Lv3

Hymba:小型语言模型的新标杆——融合SSM与Attention的混合架构

structure
状态空间模型(SSM,如Mamba)以线性复杂度崭露头角,但其回忆能力不足。2025年ICLR,NVIDIA团队提出了Hymba,一种小型语言模型(LM),通过创新的“混合头”(Hybrid-Head)架构,巧妙融合SSM和Attention。

一、SSM(状态空间模型)的结构

SSM是一种基于控制论的序列建模方法,Mamba(Gu & Dao, 2023)是其代表。它通过线性复杂度处理长序列,成为Transformer的有力竞争者。

1.1 SSM的基本原理

SSM将序列建模视为一个动态系统,用状态转移方程描述输入到输出的映射:

  • 输入:序列,每个
  • 状态:隐状态 (s是状态维度,如Hymba中的16)。
  • 输出

其离散形式为:

  • :状态转移矩阵。
  • :输入投影矩阵。
  • :输出投影矩阵。

1.2 Mamba的改进

Mamba在传统SSM基础上引入了数据依赖的参数和门控机制:

  • 动态参数
    • :固定但可学习。
    • :随输入动态变化。
    • :同上。
    • :时间步长,控制遗忘速度。
  • 门控:引入门控向量 ,输出为:
  • 计算
    • 递归形式:
    • 并行形式(推理优化):用卷积或矩阵运算展开:

      其中 是预计算的转移函数。

1.3 SSM的优劣

  • 优点:线性复杂度(),小缓存(Hymba中仅1.87MB)。
  • 缺点:低分辨率回忆,推理能力弱(Tab. 1显示Mamba在回忆任务仅19.23%)。

二、Hymba的结构

Hymba-1.5B是Hymba家族的主力型号,参数量1.52亿,专为小型设备设计。其架构以32层混合头模块为核心,融合了SSM和Attention,并引入优化策略。

2.1 总体架构

  • 层数:32层。
  • 隐藏维度:1600。
  • 输入:原始序列 前置128个Meta Tokens,形成
  • 输出:通过线性投影生成词汇概率。

参数配置

属性 Hymba-1.5B
层数 32
隐藏维度 1600
SSM状态维度 16
Attention头数 25
查询分组(GQA) 5
全局注意力层 3(1、16、32)
窗口大小 1024
MLP隐藏维度 5504

2.2 混合头模块(Hybrid-Head Module)

每层是一个混合头模块,同时包含Attention和SSM头,输入 被并行处理。

Attention头

  • 数量:25个头。
  • 计算
    • 头维度
  • 类型
    • 全局(3层):完整序列。
    • 局部(SWA,29层):窗口1024。
  • GQA:25头分5组,共享键值。

SSM头

  • 实现:基于Mamba。
  • 计算
    • :门控。
    • :输入特征。
    • 同Mamba。
  • 状态维度:16。

融合

  • 输出融合
    • :可学习缩放。
    • :平衡幅度差异(Fig. 9)。

2.3 Meta Tokens

  • 数量:128个
  • 作用:优化注意力分布,存储知识。
  • 实现:训练时优化,推理时预计算为KV初始化。

2.4 KV优化

  • 全局+局部:3层全局,29层SWA。
  • 跨层共享:32层分13组,每组共享KV。

三、Hymba如何融合SSM与Attention

Hymba的融合方式是其最大亮点,区别于Jamba的串联设计,它采用并行融合

3.1 并行处理

  • 同一层内:输入 同时送入Attention和SSM头。
  • 协同作用
    • Attention:提供“快照记忆”,高分辨率回忆(Tab. 1 recall从19.23%提升至51.79%)。
    • SSM:提供“渐退记忆”,高效总结上下文。
  • 对比串联:Jamba分层交替可能导致瓶颈,Hymba并行避免信息丢失(ERF更大,Fig. 8)。

3.2 输出融合

  • 加权平均:Attention和SSM输出通过归一化和缩放融合,确保稳定性。
  • 实验验证:Tab. 11显示并行优于拼接(concat),性能更高且参数更少。

3.3 Meta Tokens的桥梁作用

  • 注意力引导:减少BOS关注(Fig. 11),让Attention更专注关键token。
  • SSM增强:降低熵(Fig. 13),提升SSM的上下文提取能力。

3.4 KV优化的协同

  • SSM支持:全局上下文由SSM总结,允许Attention用局部SWA。
  • 结果:缓存减少11.67倍,吞吐量提升3.49倍(Tab. 2)。

四、前向传播伪代码

由原文伪代码理解得到
forward
理解如下
输入:X = [x_1, …, x_n]

  1. 前置Meta Tokens:\tilde{X}^0 = [R, X]
  2. 遍历32层:
    for i in [1, …, 32]:
    if i in [1, 16, 32]: # 全局注意力
    \tilde{X}^i = HymbaBlock-GA(\tilde{X}^{i-1})
    else: # 局部注意力
    if i是KV共享组首层:
    \tilde{X}^i, KV^i = HymbaBlock-SWA(\tilde{X}^{i-1})
    else:
    复用上一层KV^{i-1}
    \tilde{X}^i = HymbaBlock-SWA(\tilde{X}^{i-1}, KV^{i-1})
  3. 输出:最后一层结果经投影生成预测。

不过他给出的伪代码中并没有给出ssm的计算,于是在官方库推理代码中找到关键代码如下
解码器

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109

class HymbaDecoderLayer(nn.Module):
def __init__(self, config: HymbaConfig, num_experts: int, layer_idx: int, reuse_kv: bool = False):
super().__init__()

self.config = config
self.layer_idx = layer_idx
self.reuse_kv = reuse_kv

self.mamba = HymbaBlock(config=config, layer_idx=layer_idx, reuse_kv=reuse_kv)

self.input_layernorm = HymbaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

self.intermediate_size = config.intermediate_size
if self.intermediate_size > 0:
num_experts_per_tok = config.num_experts_per_tok if num_experts > 1 else 1

self.moe = HymbaSparseMoeBlock(config, num_experts=num_experts, num_experts_per_tok=num_experts_per_tok)

self.pre_moe_layernorm = HymbaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)


def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
attention_mask_raw: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
output_attentions: Optional[bool] = False,
output_router_logits: Optional[bool] = False,
use_cache: Optional[bool] = False,
kv_last_layer = None,
use_swa=False,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, sequence_length)` where padding elements are indicated by 0.
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_router_logits (`bool`, *optional*):
Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
should not be returned during inference.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
"""

residual = hidden_states

hidden_states = self.input_layernorm(hidden_states)

hidden_states, attn_key_value, present_key_value = self.mamba(
hidden_states=hidden_states,
past_key_value=past_key_value,
attention_mask=attention_mask,
position_ids=position_ids,
kv_last_layer=kv_last_layer,
use_cache=use_cache,
use_swa=use_swa
)

bs, seqlen, _ = hidden_states.shape
past_seqlen = self._get_past_seqlen(past_key_value, seqlen)
num_attention_heads = self.mamba.config.num_attention_heads
self_attn_weights = torch.empty(bs, num_attention_heads, seqlen, past_seqlen, device="meta")

# residual connection after mamba
hidden_states = residual + hidden_states

if self.intermediate_size > 0:
residual = hidden_states
hidden_states = self.pre_moe_layernorm(hidden_states)
hidden_states, router_logits = self.moe(hidden_states)
hidden_states = residual + hidden_states

outputs = (hidden_states,)

if output_attentions:
outputs += (self_attn_weights,)

if use_cache:
outputs += (present_key_value,)

if output_router_logits:
outputs += (router_logits,)

outputs += (attn_key_value,)

return outputs

def _get_past_seqlen(self, past_key_value, seqlen):
if past_key_value is None:
return seqlen
past_seqlen = past_key_value.get_seq_length()

if past_seqlen == 0:
return seqlen

return past_seqlen

attention计算

1
2
3
4
5
6
7
8
9
10
if self.reuse_kv:
query_states, hidden_states = projected_states.tensor_split((self.attn_hidden_size,), dim=1)
query_states = query_states.transpose(1, 2)
attn_outputs, attn_key_value = self.self_attn(attention_mask=attention_mask, position_ids=position_ids, query_states=query_states, kv_last_layer=kv_last_layer, use_swa=use_swa, use_cache=use_cache, past_key_value=cache_params)
else:
query_states, key_states, value_states, hidden_states = projected_states.tensor_split((self.attn_hidden_size, self.attn_hidden_size + self.k_hidden_size, self.attn_hidden_size + self.k_hidden_size + self.v_hidden_size), dim=1)
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
attn_outputs, attn_key_value = self.self_attn(attention_mask=attention_mask, position_ids=position_ids, query_states=query_states, key_states=key_states, value_states=value_states, use_swa=use_swa, use_cache=use_cache, past_key_value=cache_params)

ssm计算

1
2
3
4
5
6
7
8
index = 0
ssm_parameters = self.x_proj[index](hidden_states.transpose(1, 2))
time_step, B, C = torch.split(ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1)
time_step, B, C = self._apply_layernorms(time_step, B, C)
discrete_time_step = self.dt_proj[index](time_step).transpose(1, 2)
A = -torch.exp(self.A_log[index].float())
scan_outputs = selective_scan_fn(hidden_states, discrete_time_step, A, B.transpose(1, 2), C.transpose(1, 2), self.D[index].float(), z=gate, delta_bias=time_proj_bias, delta_softplus=True, return_last_state=True)
scan_outputs = scan_outputs.transpose(1, 2)

融合

1
2
hidden_states = (self.pre_avg_layernorm1(attn_outputs) + self.pre_avg_layernorm2(scan_outputs)) / 2
contextualized_states = self.out_proj(hidden_states)
  • Title: HYMBA论文粗总结
  • Author: Ikko
  • Created at : 2025-03-11 14:51:10
  • Updated at : 2025-03-14 16:33:19
  • Link: http://ikko-debug.github.io/2025/03/11/HYMBA/
  • License: This work is licensed under CC BY-NC-SA 4.0.
Comments