mamba

ssm与其它序列模型的比较
(1) RNN(循环神经网络)
基本结构:
$$
h_t = \sigma(W_h h_{t-1} + W_x x_t + b)
$$- 通过隐藏状态 $ h_t $ 传递历史信息。
- 使用 简单非线性激活(如 tanh) 更新状态。
问题:
- 梯度消失/爆炸:难以学习长距离依赖。
- 顺序计算:无法并行训练。
(2) LSTM(长短期记忆网络)
改进点:引入 门控机制(输入门、遗忘门、输出门) 控制信息流:
$$
\begin{aligned}
f_t &= \sigma(W_f [h_{t-1}, x_t] + b_f) \quad &\text{(遗忘门)} \
i_t &= \sigma(W_i [h_{t-1}, x_t] + b_i) \quad &\text{(输入门)} \
o_t &= \sigma(W_o [h_{t-1}, x_t] + b_o) \quad &\text{(输出门)} \
\tilde{C}t &= \tanh(W_C [h{t-1}, x_t] + b_C) \quad &\text{(候选记忆)} \
C_t &= f_t \odot C_{t-1} + i_t \odot \tilde{C}_t \quad &\text{(记忆更新)} \
h_t &= o_t \odot \tanh(C_t) \quad &\text{(输出)}
\end{aligned}
$$优势:
- 比普通 RNN 更能捕捉长距离依赖。
- 通过门控缓解梯度消失问题。
局限:
- 仍依赖 顺序计算,训练速度慢。
- 长序列下可能仍会丢失早期信息。
(3) SSM(结构化状态空间模型)
核心思想:
- 受控于连续时间状态空间方程,离散化后处理序列:
$$
\begin{aligned}
h_t &= \overline{A} h_{t-1} + \overline{B} x_t \
y_t &= C h_t + D x_t
\end{aligned}
$$ - 通过 结构化参数(如对角矩阵) 约束 $ A $,提升计算效率。
- 受控于连续时间状态空间方程,离散化后处理序列:
优势:
- 线性复杂度(O(N)),适合超长序列。
- 训练时可并行(通过卷积或并行扫描)。
- 推理时可递推(类似RNN),内存占用低。
** 关键区别总结**
特性 | RNN | LSTM | SSM(如S4/Mamba) |
---|---|---|---|
长序列建模能力 | ❌ 梯度消失 | ✅ 优于RNN | ✅ 最优(理论无限上下文) |
训练并行性 | ❌ 顺序计算 | ❌ 顺序计算 | ✅ 卷积/并行扫描 |
计算复杂度 | O(N) | O(N) | O(N) |
推理模式 | 递推 | 递推 | 可递推或并行 |
参数效率 | 低 | 较高(门控机制) | 高(结构化参数) |
典型应用 | 短序列任务 | 文本、语音 | 超长序列(DNA、音频、时间序列) |
ssm的并行性
SSM(状态空间模型)在并行计算中依赖快速傅里叶变换(FFT)。FFT是计算卷积的最优工具之一,尤其在处理长序列时。
1. SSM与卷积的等价性
在现代SSM实现中,状态空间模型可以通过数学推导重写为卷积形式。具体来说,SSM的状态更新方程:
- $ h_t = A h_{t-1} + B x_t $
- $ y_t = C h_t $
可以展开为一个输入序列 $ x_t $ 与一个由 $ A $、$ B $、$ C $ 参数生成的卷积核 $ K $ 的卷积: - $ y_t = (K * x)_t $
这里的 $ K $ 是一个由状态空间参数决定的核函数(例如,$ K_t = C A^t B $)。这意味着,SSM的输出可以看作输入序列与某个固定核的卷积结果。
2. 卷积的直接计算复杂度高
如果直接计算这个卷积,对于长度为 $ L $ 的序列,时间复杂度是 $ O(L^2) $,因为每个输出 $ y_t $ 需要对输入序列的 $ L $ 个元素与核 $ K $ 进行加权求和。当序列很长时(比如 $ L $ 在千或万级别),这种朴素计算方式在时间和计算资源上都非常昂贵,无法高效并行。
3. FFT加速卷积
快速傅里叶变换(FFT)提供了一种高效计算卷积的方法。根据卷积定理:
- 时域中的卷积等价于频域中的点乘。
- 即 $ y = K * x $ 可以通过 $ \mathcal{F}^{-1}(\mathcal{F}(K) \cdot \mathcal{F}(x)) $ 计算,其中 $ \mathcal{F} $ 表示傅里叶变换,$ \mathcal{F}^{-1} $ 表示逆傅里叶变换。
使用FFT的步骤如下:
- 将输入序列 $ x $ 和核 $ K $ 转换为频域:$ \mathcal{F}(x) $ 和 $ \mathcal{F}(K) $,复杂度为 $ O(L \log L) $。
- 在频域中进行逐元素乘法:$ \mathcal{F}(y) = \mathcal{F}(K) \cdot \mathcal{F}(x) $,复杂度为 $ O(L) $。
- 将结果转换回时域:$ y = \mathcal{F}^{-1}(\mathcal{F}(y)) $,复杂度为 $ O(L \log L) $。
总复杂度降为 $ O(L \log L) $,远低于直接卷积的 $ O(L^2) $。
4. 并行化的实现
- 频域计算的独立性:在频域中,每个频率分量的乘法是独立的,可以在GPU等并行硬件上同时计算。
- 批量处理:FFT本身是一个高度优化的算法,现代硬件和库(如cuFFT)能够一次性处理整个序列的变换,进一步提升并行效率。
- 全局操作:通过FFT,SSM不再需要按时间步顺序递归计算,而是将整个序列的卷积一次性完成,消除了时间依赖性,天然适合并行化。
5. 评价
- 长序列优化:对于短序列,直接卷积可能更快,但SSM的目标往往是高效处理长序列(如数千甚至数十万时间步),此时FFT的 $ O(L \log L) $ 优势显著。
- 数值稳定性:直接展开状态转移(如反复计算 $ A^t $)可能导致数值不稳定(例如矩阵 $ A $ 的幂次放大误差),而频域计算通过核的傅里叶表示避免了这种问题。
- 硬件支持:FFT算法在现代深度学习框架和硬件中有高度优化的实现(如NVIDIA的cuFFT),使其成为并行计算的理想选择。
- Title: mamba
- Author: Ikko
- Created at : 2025-03-31 13:27:18
- Updated at : 2025-03-31 14:14:57
- Link: http://ikko-debug.github.io/2025/03/31/mamba/
- License: This work is licensed under CC BY-NC-SA 4.0.