mamba

Ikko Lv3

ssm与其它序列模型的比较

(1) RNN(循环神经网络)

  • 基本结构
    $$
    h_t = \sigma(W_h h_{t-1} + W_x x_t + b)
    $$

    • 通过隐藏状态 $ h_t $ 传递历史信息。
    • 使用 简单非线性激活(如 tanh) 更新状态。
  • 问题

    • 梯度消失/爆炸:难以学习长距离依赖。
    • 顺序计算:无法并行训练。

(2) LSTM(长短期记忆网络)

  • 改进点:引入 门控机制(输入门、遗忘门、输出门) 控制信息流:
    $$
    \begin{aligned}
    f_t &= \sigma(W_f [h_{t-1}, x_t] + b_f) \quad &\text{(遗忘门)} \
    i_t &= \sigma(W_i [h_{t-1}, x_t] + b_i) \quad &\text{(输入门)} \
    o_t &= \sigma(W_o [h_{t-1}, x_t] + b_o) \quad &\text{(输出门)} \
    \tilde{C}t &= \tanh(W_C [h{t-1}, x_t] + b_C) \quad &\text{(候选记忆)} \
    C_t &= f_t \odot C_{t-1} + i_t \odot \tilde{C}_t \quad &\text{(记忆更新)} \
    h_t &= o_t \odot \tanh(C_t) \quad &\text{(输出)}
    \end{aligned}
    $$

  • 优势

    • 比普通 RNN 更能捕捉长距离依赖。
    • 通过门控缓解梯度消失问题。
  • 局限

    • 仍依赖 顺序计算,训练速度慢。
    • 长序列下可能仍会丢失早期信息。

(3) SSM(结构化状态空间模型)

  • 核心思想

    • 受控于连续时间状态空间方程,离散化后处理序列:
      $$
      \begin{aligned}
      h_t &= \overline{A} h_{t-1} + \overline{B} x_t \
      y_t &= C h_t + D x_t
      \end{aligned}
      $$
    • 通过 结构化参数(如对角矩阵) 约束 $ A $,提升计算效率。
  • 优势

    • 线性复杂度(O(N)),适合超长序列。
    • 训练时可并行(通过卷积或并行扫描)。
    • 推理时可递推(类似RNN),内存占用低。

** 关键区别总结**

特性 RNN LSTM SSM(如S4/Mamba)
长序列建模能力 ❌ 梯度消失 ✅ 优于RNN ✅ 最优(理论无限上下文)
训练并行性 ❌ 顺序计算 ❌ 顺序计算 ✅ 卷积/并行扫描
计算复杂度 O(N) O(N) O(N)
推理模式 递推 递推 可递推或并行
参数效率 较高(门控机制) 高(结构化参数)
典型应用 短序列任务 文本、语音 超长序列(DNA、音频、时间序列)

ssm的并行性

SSM(状态空间模型)在并行计算中依赖快速傅里叶变换(FFT)。FFT是计算卷积的最优工具之一,尤其在处理长序列时。

1. SSM与卷积的等价性

在现代SSM实现中,状态空间模型可以通过数学推导重写为卷积形式。具体来说,SSM的状态更新方程:

  • $ h_t = A h_{t-1} + B x_t $
  • $ y_t = C h_t $
    可以展开为一个输入序列 $ x_t $ 与一个由 $ A $、$ B $、$ C $ 参数生成的卷积核 $ K $ 的卷积:
  • $ y_t = (K * x)_t $
    这里的 $ K $ 是一个由状态空间参数决定的核函数(例如,$ K_t = C A^t B $)。这意味着,SSM的输出可以看作输入序列与某个固定核的卷积结果。

2. 卷积的直接计算复杂度高

如果直接计算这个卷积,对于长度为 $ L $ 的序列,时间复杂度是 $ O(L^2) $,因为每个输出 $ y_t $ 需要对输入序列的 $ L $ 个元素与核 $ K $ 进行加权求和。当序列很长时(比如 $ L $ 在千或万级别),这种朴素计算方式在时间和计算资源上都非常昂贵,无法高效并行。

3. FFT加速卷积

快速傅里叶变换(FFT)提供了一种高效计算卷积的方法。根据卷积定理:

  • 时域中的卷积等价于频域中的点乘。
  • 即 $ y = K * x $ 可以通过 $ \mathcal{F}^{-1}(\mathcal{F}(K) \cdot \mathcal{F}(x)) $ 计算,其中 $ \mathcal{F} $ 表示傅里叶变换,$ \mathcal{F}^{-1} $ 表示逆傅里叶变换。

使用FFT的步骤如下:

  1. 将输入序列 $ x $ 和核 $ K $ 转换为频域:$ \mathcal{F}(x) $ 和 $ \mathcal{F}(K) $,复杂度为 $ O(L \log L) $。
  2. 在频域中进行逐元素乘法:$ \mathcal{F}(y) = \mathcal{F}(K) \cdot \mathcal{F}(x) $,复杂度为 $ O(L) $。
  3. 将结果转换回时域:$ y = \mathcal{F}^{-1}(\mathcal{F}(y)) $,复杂度为 $ O(L \log L) $。

总复杂度降为 $ O(L \log L) $,远低于直接卷积的 $ O(L^2) $。

4. 并行化的实现

  • 频域计算的独立性:在频域中,每个频率分量的乘法是独立的,可以在GPU等并行硬件上同时计算。
  • 批量处理:FFT本身是一个高度优化的算法,现代硬件和库(如cuFFT)能够一次性处理整个序列的变换,进一步提升并行效率。
  • 全局操作:通过FFT,SSM不再需要按时间步顺序递归计算,而是将整个序列的卷积一次性完成,消除了时间依赖性,天然适合并行化。

5. 评价

  • 长序列优化:对于短序列,直接卷积可能更快,但SSM的目标往往是高效处理长序列(如数千甚至数十万时间步),此时FFT的 $ O(L \log L) $ 优势显著。
  • 数值稳定性:直接展开状态转移(如反复计算 $ A^t $)可能导致数值不稳定(例如矩阵 $ A $ 的幂次放大误差),而频域计算通过核的傅里叶表示避免了这种问题。
  • 硬件支持:FFT算法在现代深度学习框架和硬件中有高度优化的实现(如NVIDIA的cuFFT),使其成为并行计算的理想选择。
  • Title: mamba
  • Author: Ikko
  • Created at : 2025-03-31 13:27:18
  • Updated at : 2025-03-31 14:14:57
  • Link: http://ikko-debug.github.io/2025/03/31/mamba/
  • License: This work is licensed under CC BY-NC-SA 4.0.
Comments