斯坦福提出大模型最强架构TTT,超越Transformers
夕小瑶科技说 原创
作者 | 谢年年
在Transformer被提出以前,以LSTMs为代表的RNNs网络结构由于计算简单、擅长处理序列数据常被用作NLP领域的基础架构。但受其结构限制,RNNs容易出现梯度消失和梯度爆炸问题,也无法像Transformer那样进行缩放或有效地利用长上下文。而自注意力机制则擅长处理长文本,但它计算起来有些复杂,复杂度跟数据长度的平方成正比。
最近,来自Stanford的团队设计了一种新的序列建模层——测试时训练(Learn at Test Time)(TTT)层。这个层既保持了线性复杂度的好处,又让隐藏状态变得更加强大和灵活。TTT受自监督学习启发,把隐藏状态本身变成一个小型的机器学习模型,然后每次处理新数据时,都用自监督学习的方式来更新这个模型。这样,隐藏状态就能不断学习和进步,就像我们人类在学习新知识一样。
论文标题:
Learning to (Learn at Test Time): RNNs with Expressive Hidden States
论文链接:
https://arxiv.org/pdf/2407.04620
作者测试了从125M到1.3B不同参数规模的TTT层,发现它们的表现和目前最先进的Transformer模型以及现代RNN模型Mamba相比,毫不逊色,甚至在某些情况下还更好。特别是,当文本很长时,TTT层还能继续降低困惑度,而Mamba在文本超过一定长度后就做不到了。
经过优化,TTT-Linear在处理8K上下文时,速度已经比Transformer还要快了,而且和Mamba相当。让我们一起看看是如何做到的吧~
前言:RNNs陷入效率提升与长文本处理的僵局
在2020年,OpenAI的研究缩放定律发现,传统的LSTM模型在处理大量数据和长文本时,很难像Transformer那样进行缩放,也无法有效地利用长上下文。
不过,随着技术的进步,一种更现代的RNN模型——Mamba被提出。从下图中可以看到Mamba在规模和性能上已经跟上了强大的Transformer,比以前的LSTM进步了很多。
但是,当我们仔细观察时,还是发现了和LSTM类似的问题。理论上,处于序列靠后的Token应该更容易被预测,因为它们可以以来更多的上下文。对于Transformer来说,确实是这样的,随着处理的文本越来越长,困惑度逐步下降。但是,对于Mamba在16K后指标就趋于平稳了。
这揭示了一个尴尬的现实:虽然RNN(包括Mamba)在计算上比Transformer更有效率(因为它们的计算复杂度是线性的,而Transformer是二次的),但这种优势只有在处理非常长的文本时才真正显现。如下图所示:在长下文长度超过8K时,对于Transformer,每个token的转发时间随着上下文长度的增加而线性增长,但对于其他两种方法则大致保持不变。
但是就目前来看,随着文本长度增长,RNN难以充分利用这些额外的信息。与自注意力机制不同,RNN层必须将上下文压缩到固定大小的隐藏状态中。作为一种压缩启发式方法,更新规则需要发现数千甚至数百万个token之间的底层结构和关系,这对于RNN来说是困难的。
因此要利用RNNs的高效率,就必须先解决难以处理长上下文的挑战。
作者发现:自监督学习可以将大量训练集压缩到LLM的权重中,这些模型通常能够深入理解其训练数据中的语义连接,这正是RNNs需要学习的。
因此作者设计了一类新的序列模型层,其中隐藏状态是一个模型,更新规则是自监督学习的一步。因为在测试序列上更新隐藏状态的过程等同于测试时训练模型,这类新层被称为测试时训练 (TTT) 层。
作者在这一类中引入了两种简单的实例化:TTT-Linear 和 TTT-MLP,其中隐藏状态分别是线性模型和两层MLP。 TTT层可以集成到任何网络架构中,并且像RNN层和自注意力一样进行端到端优化。
方法
TTT作为更新隐藏状态
语言模型本身就是压缩知识的优秀例子。通过自我监督任务进行下一个Token预测的训练,它们的权重可以被视为对互联网上现有知识的压缩形式存储。通过查询语言模型,可以从它们的权重中提取知识。更重要的是,语言模型经常展现出对现有知识之间语义联系的深刻理解,在新的推理任务上表现出色。
受此启发,本文的关键思路是使用使用自我监督学习将历史上下文压缩成隐状态,将上下文视为无标签数据集,状态视为模型。具体来说,隐状态现在等同于模型 的权重,可以是线性模型、一个小型神经网络或其他任何形式。输出规则也很简单:
直观上,输出Token只是 使用更新后的权重对的预测。更新规则是对某个自监督损失进行梯度下降的一步:
从压缩的角度来看,每种启发式都需要决定记住或遗忘哪些输入。让记住产生大梯度的输入,此时学到更多的输入。
损失的另一种选择是重建本身,将处理成一个损坏的输入,然后进行优化:
与去噪自编码器相似, 需要发现 的维度之间的相关性,以便从部分信息中重构它。但此时梯度下降可以减小,但不能减小到零,如下图所示:
训练具有TTT层的网络
TTT最重要的部分是自监督任务,因为它决定了从测试序列中学到的特征类型。为了让 在语言建模上表现良好,本文选择直接为下一个token的预测目标优化自监督任务,而不是手工制作任务。
具体来说,将自监督任务作为外层循环的一部分来学习,从简单的重建任务开始,添加可学习的外循环参数。通过设计低秩投影 作为训练视图,并用 作为标签视图,其中 和 都是可学习的矩阵。最终,自监督损失定义如下:
在上述公式中 和各种 一起出现,但它们本质上是不同的。在内循环中,只有 被优化,因此它是 的一个参数;而 是该损失函数的超参数。在外循环中,、、 与 一起被优化,而 只是一个隐藏状态,而不是参数。下图的代码说明了这一区别,其中 和 是TTT层的参数,类似于自注意力机制中的Key和Value参数。
最后,由于训练视图 的维度比 小,需要重新建立输出规则。最简单的解决方案是创建一个测试视图 ,并将输出规则更改为:
这样训练和标签视图压缩了 中的信息到 并随时间传播。测试视图则指定了映射到当前输出 并通过网络层传播的不同信息,从而增加了自监督任务的灵活性。
所有可能的 、、 选择构成了一系列多视图重建任务,外循环则选择其中一个任务。为了简化,本文将所有视图设计为线性投影。
小批量TTT的并行优化
目前的简单TTT层在浮点运算(FLOPs)上已经很高效,但其更新规则 无法并行化,因为 在减号前和 内部都依赖于 。由于 包含大部分计算,接下来重点并行化这部分。
在线梯度下降
梯度下降(GD)有许多变体,其更新规则可以表示为:
其中 是下降方向。一旦计算了 对于 ,可以上述公式的第二部分累加和得到所有的 。简单的在线梯度下降使用 。
批量梯度下降
为了的进行并行化,可以将他们全部对 进行计算,这个变体使用 ,称为批量梯度下降,因为 与 作为一个批次相对于 的梯度是相同的。然而,批量GD中 实际上只离 一步之遥,而在线GD中 离 有 步之遥,因此批量GD的有效搜索空间较小,影响语言建模性能。
小批量梯度下降
TTT小批量梯度下降数据的高级计算图如下图所示,其中节点代表变量,边代表计算。蓝色节点是输入变量,黄色节点是输出。由于没有相互连接,它们之间不存在顺序依赖关系,因此可以并行计算。
用 表示TTT批量大小,使用 ,其中 是前一个小批量的最后一个时间步(对于第一个小批量,),这样可以同时并行计算 个梯度。
批量大小 的消融实验
本文对TTT小批量大小 进行了消融实验,如下图所示,其中 是在线梯度下降(GD), 是批量GD。较小的 提高了困惑度(perplexity),因为进行的GD步数更多。当 时,困惑度为11.09,这表明 控制着速度和质量的权衡
总结,有两个通道可以将信息从 传播到 ():累加和和梯度操作。累加和总是活跃的,但梯度通道只有在 来自前一个小批量时才活跃。不同的梯度下降变体只影响梯度通道,即下降方向 ,特别是相对于哪个 计算梯度。然而,下降步 总是从 开始,因为更新规则的自回归性质,这与 的选择是正交的。
对偶形式
前文的并行化是必要的,但还不足以在wall-clock时间上提高效率。尤其是在现代加速器如NVIDIA A100 GPU上,它们专门优化了矩阵乘法(matmuls)操作,但TTT层的现有实现中仍然存在matmuls利用率不足的问题。
以最简单的损失函数为例,其中,且仅考虑大小为的第一个TTT小批量。此外,将视为线性模型。在时间的损失是:
如上一节讨论,我们可以并行计算:
但是仅通过上述并行化我们不能简单地通过单个matmul计算所有,而是需要逐个计算个外积。更进一步,对于每个,的维度是,这会带来较大的内存和I/O开销。
为解决这些问题,作者发现:我们并不需要实际化,只要能在小批量结束时计算出和输出token即可。
现在,以简化的TTT-Linear情况展示这些计算过程。设,则:
因此,可以方便地通过matmul计算。对于计算,可以知道:
制定,以及矩阵,可以推导出:
其中是下三角掩码(类似于注意力掩码,但用零代替无穷大),项可以从计算时重复利用。现在,也可以通过matmul方便地计算。将代入公式,我们得到。
这种计算过程称为对偶形式,与之前需要显式计算和的原始形式相对。这两种形式在输出上是等效的。
对偶形式 VS 原始形式
在TTT小批量内,原始形式的时间复杂度为。对偶形式在计算时时间复杂度为,计算的额外时间复杂度为。对比原始形式,对偶形式在理论复杂性上有所牺牲,但能更有效地利用硬件。实际应用中,通常为几百,本文选择为16。因此,计算的壁钟时间相对较小,如下图所示:
理论等价:TTT层等价于线性注意力
前文提到 可以是一个线性模型或一个神经网络。还讨论了三种更新规则的变体:在线梯度下降、批量梯度下降和小批量梯度下降。这2 × 3组合中的每一个会引发TTT层的不同实例化,如下图所示:
在上图中,Parametric learners需要定义两个属性:model和optimizer(左侧),每个learner都唯一诱导了一个TTT层(右侧)。本文提出了两种诱导的TTT层:TTT-Linear和TTT-MLP。作者通过详尽的证明过程从理论上证明了具有线性模型和批量梯度下降的TTT层相当于线性注意力。
实验
作者通过与两个基准模型——Transformer和Mamba(一种现代RNN)进行比较来评估TTT-Linear和TTT-MLP。TTT-Linear和TTT-MLP始终使用Mamba骨干架构,除非另有说明。当一幅图同时包含Transformer骨干和Mamba骨干时,分别用(T)和(M)表示。
简短的上下文
通过在Pile上进行了2k和8k上下文长度的标准实验,得到一些有趣结论:
在2k上下文长度下,TTT-Linear(M)、Mamba和Transformer表现相当,因为它们的曲线大部分重叠。TTT-MLP(M)在大的FLOP预算下表现稍逊。尽管TTT-MLP在每个模型尺寸下的困惑度更佳,但额外的FLOP成本抵消了这一优势。 在8k上下文长度下,TTT-Linear(M)和TTT-MLP(M)都显著优于Mamba,与2k的观察相比截然不同。甚至使用Transformer骨干的TTT-MLP(T)在大约13亿时表现略优于Mamba。
随着上下文长度的增加,TTT层相比于Mamba的优势扩大。
作者从两个方面分析这一现象的原因:
骨干的影响
当尝试将TTT层从Mamba骨干迁移到Transformer骨干时,观察到两个关键变化。首先,Mamba骨干下的TTT层在评估中表现更佳。其次,尽管在Mamba中TTT-MLP与TTT-Linear性能相当,但在Transformer骨干中,TTT-MLP显著优于TTT-Linear。
这表明,Mamba的时间卷积对于提升较弱隐藏状态的表现尤为重要,而Transformer则为TTT-MLP提供了更大的发挥空间。因此,TTT-Linear因结构简单,在Mamba中更受益于时间卷积;而TTT-MLP在Transformer中则能更充分地展现其优势。
缺少线性拟合
Chinchilla论文中的一个重要观察是,通过他们的方法得到的计算最优模型在特定性能指标(如FLOPs与困惑度)的对数-对数图上呈现出一条清晰的线性关系,这通常被视为尺度定律的一个典型表现。
然而本文却并未能观察到类似的清晰线性拟合,即便是针对Transformer模型也是如此。这其实不奇怪,因为数据集、文本长短、分词方法和模型结构都不一样,肯定会影响结果。
所以本文选择了将数据点直接相连,而不是强行通过线性回归来拟合它们。通过直接观察数据点的分布趋势,可以更加直观地理解不同因素对模型性能和计算复杂度的影响。
长上下文
为了测试模型处理长文本的能力,选用Pile的子集Book3:
在2k上下文中,和之前Pile 2k的结果差不多,但这次Mamba稍微超过了TTT-Linear。 在32k上下文中,TTT-Linear(M)和TTT-MLP(M)的表现都优于Mamba,这与Pile 8k的观察相似。即使是使用Transformer骨干的TTT-MLP(T),在32k上下文中也略优于Mamba。 在1.3B规模下,TTT-MLP(T)仅略逊于TTT-MLP(M)。正如所讨论的,由于缺乏清晰的线性拟合,很难得出经验性缩放法则。但TTT-MLP(T)的强劲表现说明,Transformer骨干可能更适合更大的模型和更长的文本处理。
Transformer微调
虽然前文的实验按照Mamba论文从头开始训练Transformer,但在实际中这种方法很少用于长上下文。标准做法是先短后长地训练Transformer,即先短文本训练,再长文本微调。为此,本文在4K以上文本上引入了TF微调基线,它基于Books 2k训练的模型,并针对特定长度微调。
上下文长度作为超参数。
虽然输入序列的长度由用户决定,但语言模型处理输入的上下文长度由工程师作为设计选择来确定。因此,上下文长度是一个可以选择的超参数。对于线性模型选困惑度最低的,因计算量相近;Transformer则考虑计算量与性能的平衡,寻找最优边界点。实验结果如下图所示:
TTT-Linear和TTT-MLP的表现最佳,线条几乎完全重叠。 FLOPs之后,Mamba 和 TF微调 的曲线也基本重叠。 TF微调因利用长文本优势而不增额外训练成本,优于直接长文本训练。 所有从头训练模型在文本过长时困惑度上升。
Wall-clock time
LLM的训练和推理过程简单分为三个部分:前向、后向和生成。推理时提示处理(也叫预填充)跟训练时的前向一样,但不用保存中间结果来反向传播。因为前向和后向操作能同时做,作者使用对偶方式来加速。而生成tokens(解码)是顺序的,则保持原样处理。
由于资源限制,本文实验使用 JAX 编写并在 TPUs 上运行,此时**TTT-Linear方法不用额外优化就比标准的Transformer快了10%**,然而,Mamba(在 PyTorch、Triton 和 CUDA 中实现)只能在 GPUs 上运行,为了公平比较,作者重写了该方法以在GPUs 上运行。
上图左显示,随着文本变长,标准Transformer处理每个token的时间变长,但本文的方法几乎不变。
图右显示对于生成(解码)过程,TTT-Linear和Mamba具有几乎相同的延迟,明显小于Transformer和TTT-MLP。
结语
过去,我们用机器学习来模仿人类学习,但那种方式有点像是在一堆打乱的数据里找规律,训练时用的是一套数据,测试时又是另一套,跟人类真实的学习过程不太一样。
人类学习可不是这样“死板”的。我们不会把生活分成“训练时间”和“测试时间”,也不会说某个时刻的数据只能用来学习,另一个时刻的数据只能用来检验。我们的学习是连续的,每时每刻都在进行,而且我们学习的数据往往都是有时间关联的,比如我们学习一门语言,就是在不断地听、说、读、写中慢慢提高的。
本文提出的框架TTT就像是模拟了这种连续、灵活的学习方式。它认为,数据不应该被简单地分成训练集和测试集,而是应该看作一个长长的、有时间顺序的序列。在这个序列里,每一个数据点都可以既是学习的材料,也是检验我们学习成果的“考题”。这种学习方式更接近人类真实的学习过程,也让我们看到了AI发展的一个新方向,为网络架构设计提供了新的思考维度,还开启了探索更灵活、适应性更强学习系统的广阔空间。
微信扫码关注该文公众号作者