type
status
date
slug
summary
tags
category
icon
password
😀
读论文Mamba: Linear-Time Sequence Modeling with Selective State Spaces
从 Mamba 的发展前世今生,学习全新架构 Mamba 相关知识点(适合快速入门)

1、现有架构的问题

序列建模的核心问题是:同时解决有效高效。有效是指能够选择性记忆历史信息,解决长距离依赖(Long-Range Dependencies,LRDs)问题;高效是指计算高效。
尽管传统的模型如循环神经网络(RNNs)、卷积神经网络(CNNs)和 Transformers 在处理长距离依赖方面有专门的变体,但它们在处理超过 10000 步的极长序列时仍然面临挑战。

1.1 Transformer 的优点与问题

Transformer 的一个主要优点是,无论它接收到多长的输入,它都使用序列中的所有 token 信息(无论序列有多长)来对输入数据进行处理。
但是为了获得全局信息,注意力机制在长序列上非常耗费显存。注意力创建一个矩阵,将每个 token 与之前的每个 token 进行比较。矩阵中的权重由 token 对之间的相关性决定。
在训练过程中,Attention 计算可以并行化,所以可以极大地加快训练速度。但是在推理过程中,生成新 token 时需要重新计算整个序列的注意力。
长度为 L 的序列生成 token 大约需要 L² 的计算量,如果序列长度增加,计算量会平方级增长。因此,需要重新计算整个序列是 Transformer 体系结构的主要瓶颈。Transformer 训练快、推理慢。

1.2、RNN 的优点与问题

在生成输出时,RNN 只需要考虑之前的隐藏状态和当前的输入。这样不会重新计算以前的隐藏状态,这正Transformer 不具备的。这种结构可以让 RNN 进行快速推理,并且理论上可以无限扩展上下文长度,因为每次推理只取一个隐藏状态和当前输入,内存占用非常稳定。
RNN 的每个隐藏状态都是之前所有隐藏状态的聚合。但是这里会有一个问题,在生成 token "Liang" 时,最后一个隐藏状态不再包含关于 token "Hello" 的信息。这会导致随着时间的推移,RNN 会忘记更久的信息,因为它只考虑前一个状态。
notion image
并且 RNN 的这种顺序性产生了另一个问题。训练不能并行进行,因为它需要按顺序完成每一步。与 Transformer 相比,RNN 的问题完全相反!它的推理速度非常快,但不能并行化导致训练很慢。RNN训练慢、推理快
notion image
人们一直在寻找一种既能像 Transformer 那样并行化训练,能够记住先前的信息,又能在推理时时间是随序列长度线性增长的模型,Mamba 就是这样应运而生的。解下来我们从 SSM 开始,逐步介绍 Mamba。

2、状态空间模型 SSM

状态空间模型将一维输入信号 映射到N维潜在状态 ,然后再映射到一维输出信号 。SSM主要包含两个部分:状态更新方程和输出方程。
notion image
其中, 是状态转移矩阵, 是输入到状态的矩阵, 是状态到输出的矩阵, 是直接从输入到输出的参数(很多时候)。

3、HiPPO架构

HiPPO 主要为了解决如何在有限的存储空间中有效地解决序列建模的长距离依赖问题。HiPPO 通过函数逼近产生状态矩阵 A 的最优解,有效的解决了长距离依赖问题。

3.1、HiPPO 架构如何解决长距离依赖(LRDs)问题:

LRDs 是序列建模中的一个关键挑战,因为它们涉及到在序列中跨越大量时间步的依赖关系。
作者指出,基本的 SSM 在实际应用中表现不佳,特别是在处理 LRDs 时。这是因为线性一阶常微分方程(ODEs)的解通常是指数函数,这可能导致梯度在序列长度上呈指数级增长,从而引发梯度消失或爆炸的问题。 为了解决这个问题,作者利用了 HiPPO 理论。HiPPO 理论指定了一类特殊的矩阵 A,当这些矩阵被纳入 SSM 的方程中时,可以使状态 x(t) 能够记住输入 u(t) 的历史信息。这些特殊矩阵被称为 HiPPO 矩阵,它们具有特定的数学形式,可以有效地捕捉长期依赖关系。 HiPPO 矩阵的一个关键特性是它们允许 SSM 在数学和实证上捕捉 LRDs。例如,通过将随机矩阵 A 替换为 HiPPO 矩阵,可以在序列 MNIST 基准测试上显著提高 SSM 的性能。

4、S4 架构(Structured State Space Model)

S4 是 HiPPO 的后续工作,S4 的主要工作是将HiPPO中的矩阵 A(称为 HiPPO 矩阵)转换为正规矩阵(正规矩阵可以分解为对角矩阵)和低秩矩阵的和,以此提高计算效率。 S4 通过这种分解,将计算复杂度降低到了O(N+L),其中 N 是 HiPPO 矩阵的维度,L 是序列长度。

4.1、为什么对角化可以减少 SSM 计算复杂度

对角化是一种线性代数技术,它可以将一个矩阵转换为对角形式,从而简化矩阵的乘法和其他运算。在 SSM 的上下文中,对角化可以显著减少计算复杂度,因为对角矩阵的幂运算(如在递归方程中出现的)可以通过简单的元素指数运算来完成。
下面我们解释下,为什么对角化可以减少 SSM 计算复杂度。
首先我们引入论文中的定理 3.1
(Lemma 3.1): 共轭是SSM中的等价关系,即:
也就是将矩阵变为,最后得到的输出 保持不变。那么如果矩阵是对角矩阵,则输出 的计算复杂度将从变成。只要Lemma 3.1 成立,我们就可以使用对角化技术,降低计算复杂度。
状态空间模型可用于建模文本序列,但仍有一系列我们想要避免的缺点。

5、Mamba 架构(S6)

Mamba 的两大主要贡献:
  • 一种选择性扫描算法,该算法允许模型过滤(不)相关信息;
  • 一种硬件感知算法,该算法允许通过并行扫描、内核融合和重新计算来高效存储(中间)结果。

5.1、通过选择机制改进 SSM

SSM 在某些对语言建模和生成至关重要的任务上的糟糕表现说明了时间不变 SSM 的潜在问题,即矩阵 A、B 和 C 的静态性质导致内容感知方面的问题。
为了解决上面的问题,作者提出了一种新的选择性 SSM(Selective State Space Models,简称 S6 或 Mamba)。这种模型通过让 SSM 的矩阵 A、B、C 依赖于输入数据,从而实现了选择性。这意味着模型可以根据当前的输入动态地调整其状态,选择性地传播或忽略信息。
Mamba 集成了 S4 和 Transformer 的精华,一个更加高效(S4),一个更加强大(Transformer)。
Mamba 通过将输入序列的长度和批次大小结合起来,使矩阵 B 和 C,甚至步长 ∆ 都依赖于输入。这意味着对于每个输入 token,我们现在有不同的 B 和 C 矩阵。备注:这里矩阵 A 保持不变,因为我们希望状态本身保持静态,但影响它的方式 (通过 B 和 C) 是动态的。它们一起选择性地决定在隐藏状态中保留什么和忽略什么,因为它们现在依赖于输入。
在 SSM 中,通过调整 ∆,模型可以控制对当前输入的关注度,从而实现类似于 RNN 门控的效果。例如,当 ∆ 较大时,模型倾向于关注当前输入并忽略之前的信息;而当∆较小时,模型则倾向于保留更多的历史信息。
S4 和 选择性 SSM 的核心区别在于,它们将几个关键参数(∆, B, C)设定为输入的函数,并且伴随着整个 tensor 形状的相关变化。特别是,这些参数现在具有一个长度维度 L,这意味着模型已经从时间不变(time-invariant)转变为时间变化(time-varying)。
notion image

5.2、选择性 SSM 和门控之间的关系

  • 时间步∆:时间步∆和 RNN 的门控有很强的关联,依赖输入的∆跟 RNN 的遗忘门的功能类似。
  • 矩阵 B 和 C:在 SSM 中,修改 B 和 C 以使其具有选择性,允许模型更精细地控制是否让输入进入状态h或状态进入输出y,所以B和C类似于RNN中的输入门和输出门。
  • 矩阵 A:A有点类似起到多尺度/细粒度门控的作用。虽然∆已经有点遗忘门的作用,但注意到对于每个输入维度来说,∆只是一个标量,而,也就是说对应这个维度的SSM来说,A在每个hidden state维度上的作用可以不相同,起到细粒度门控的作用,这也是LSTM网络里面用element-wise product的原因(LSTM中遗忘门是跟隐藏层维度相同的一个向量,而不仅仅是一个标量)。

5.3、Mamba高效实现

因为现在的参数ABC都是输入相关了,所以不再是线性时间不变系统,也就失去了卷积的性质,不能用 FFT来进行高效训练了。
作者采用了一种硬件感知的算法,实际上就是用三种经典技术来解决这个问题:内核融合(kernel fusion)、并行扫描(parallel scan)和重计算(recomputation)。
 
Mamba 的实现比其它方法实现快很多倍,scan 在输入长度 2k 的时候就开始比 FlashAttention 快了,之后越长越快。同时 scan 也比 Convolution 快。
注:
  • 矩阵 A 在状态空间模型(SSM)中表示状态转移矩阵,它定义了状态如何从一个时间步转移到下一个时间步。
  • 矩阵 来表示经过某种变换或离散化后的状态转移矩阵,以适应离散时间模型。
  • 矩阵 B 表示输入矩阵(Input Matrix),用于将输入信号 x(t) 映射到状态空间,控制输入信号如何影响状态。
  • 矩阵 C 表示输出矩阵(Output Matrix),用于将状态 h(t) 映射到输出 y(t) ,控制状态如何影响输出。
  • 参数 D 表示前馈矩阵(Feedforward Matrix),用于直接将输入 x(t) 映射到输出 y(t) ,直接控制输入对输出的影响。
  • 时间步 Δ 表示SSM中相邻时间步之间的时间间隔。它是离散化过程中用于将连续时间系统转换为离散时间系统的关键参数。在选择性SSMs中,Δ 可以被参数化为输入的函数,以实现选择性机制。

5.4、Mamba架构

notion image
之前的 SSM 模型要 work,都会加上 output gating,之后再过个线性层 channel mixing,如上图的最左边所示。这两个部分跟 Gated MLP(上图中间)右边的支路和最上面的 channel mixing 是一样的。所以 SSM 层如果跟Gated MLP 合并的话,难免会感觉有点冗余,所以作者干脆把两个合二为一,把 token mixing 层和 channel mixing。
它首先进行线性投影以扩展输入 embedding。然后,在应用选择性 SSM 之前进行卷积。选择性 SSM 具有以下属性:
  • 通过离散化创建递归 SSM;
  • 对矩阵 A 进行 HiPPO 初始化,以捕获远程依赖关系;
  • 选择性扫描算法,有选择地压缩信息;
  • 硬件感知算法,加速计算。
下面是一个Mamba 架构端到端(输入到输出)的例子:
notion image
Mamba和Transformer以及RNN对比:
notion image
总结:融合 SSM 和 LSTM,将 LSTM 选择性的思想融入 SSM 中,全方位的实现优化,使得 Mamba 既具备 Transformer 高效训练的特点,又具备 S4 中支持长文本的优点,同时具备 LSTM 一样选择性记忆的特点。

🤗 总结归纳

本文的初衷是记录博主入门Mamba时的一些知识点,,方便后续复习,适合快速入门者实用,更深层次的探讨请参考论文原文。

📎 参考文章

💡
有关Mamba架构的相关问题,欢迎您在底部评论区留言,一起交流~
 
数据结构课程设计——EXT2文件系统(本人原版项目书)OpenCV-Python学习笔记
  • Giscus
  • Cusdis
  • Utterance
Naipings
Naipings
一个普通的大学生,分享自己学习的“有趣”知识
Announcement
type
status
date
slug
summary
tags
category
icon
password
🎉 感谢您的支持 🎉
-- 点击收藏不迷路 ---
👏欢迎更新体验👏