Transformer升级之路:长度外推性与位置鲁棒性
诚然,目前语言模型的诸多指标看来局部注意力的思路确实能解决长度外推问题,但这种“强行截断”的做法也许会不符合某些读者的审美,因为人工雕琢痕迹太强,缺乏了自然感,同时也让人质疑它们在非语言模型任务上的有效性。
本文我们从模型对位置编码的鲁棒性角度来重新审视长度外推性这个问题,此思路可以在基本不对注意力进行修改的前提下改进 Transformer 的长度外推效果,并且还适用多种位置编码,总体来说方法更为优雅自然,而且还适用于非语言模型任务。
其中,第 2 点说的是更多的 token 会导致注意力更加分散(或者说注意力的熵变大),从而导致的训练和预测不一致问题,其实我们在《从熵不变性看Attention的Scale操作》已经初步讨论并解决了它,答案是将 Attention 从:
第 1 点不一致性,即“预测的时候用到了没训练过的位置编码”,那么为了解决它,就应该做到“训练阶段把预测所用的位置编码也训练一下”。一篇 ACL22 还在匿名评审的论文《Randomized Positional Encodings Boost Length Generalization of Transformers》[1] 首次从这个角度考虑了该问题,并且提出了解决方案。
1def random_position_ids(N, L=2048):
2 """从[0, L)中随机不重复挑N个整数,并从小到大排列
3 """
4 return np.sort(np.random.permutation(L)[:N])
预测阶段,也可以同样的方式随机采样位置序列,也可以直接均匀采样位置序列(个人的实验效果显示均匀采样的效果一般好些),这就解决了预测阶段的位置编码没有被训练过的问题。
很多相关工作,包括上一篇文章提到的各种 Local Attention 及其变体的方案,都以语言模型任务构建评测指标,但不管是单向 GPT 还是双向的 MLM,它们都高度依赖局部信息(局域性),所以之前的方案很可能只是因为语言模型的局域性才有良好的外推表现,假如换一个非局域性的任务,效果可能就变差了。
也许正因为如此,这篇论文的评测并非是常规的语言模型任务,而是 Google 去年在论文《Neural Networks and the Chomsky Hierarchy》[2] 专门提出的一个长度外泛化基准(下面简称该测试基准为“CHE 基准”,即“Chomsky Hierarchy Evaluation Benchmark”),这给我们提供了理解长度外推的一个新视角。
这个基准包含多个任务,分为 R(Regular)、DCF(Deterministic Context-Free)、CS(Context-Sensitive)三个级别,每个级别的难度依次递增,每个任务的简介如下:
Even Pairs,难度 R,给定二元序列,如“aabba”,判断 2-gram 中 ab 和 ba 的总数是否为偶数,该例子中 2-gram 有 aa、ab、bb、ba,其中 ab 和 ba 共有 2 个,即输出“是”,该题也等价于判断二元序列的首尾字符是否相同。
Parity Check,难度 R,给定二元序列,如“aaabba”,判断 b 的数目是否为偶数,该例子中 b 的数目为 2,那么输出“是”。
Cycle Navigation,难度 R,给定三元序列,其中每个元分别代表 +0、+1、-1 之一,输出从 0 出发该序列最终的运算结果模 5 的值,比如 0,1,2 分别代表+0,+1,-1,那么 010211 代表 0 + 0 + 1 + 0 − 1 + 1 + 1 = 2,模 5 后输出 2。
Reverse String,难度 DCF,给定二元序列,如“aabba”,输出其反转序列,该例子中应该输出“abbaa”。
Stack Manipulation,难度 DCF,给定二元序列,如“abbaa”,以及由“POP/PUSH a/PUSH b”三个动作组成的堆栈操作序列,如“POP / PUSH a / POP”,输出最后的堆栈结果,该例子中应该输出“abba”。
Binary Addition,难度 CS,给定两个二进制数,输出它们的和的二进制表示,如输入 10010 和 101,输出 10111,注意,这需要都在字符层面而不是数值层面输入到模型中进行训练和预测,并且两个数字是串行而不是并行对齐地提供的(可以理解为输入的是字符串 10010+101)。
Binary Multiplication,难度 CS,给定两个二进制数,输出它们的和的二进制表示,如输入 100 和 10110,输出 1011000,同 Binary Addition 一样,这需要都在字符层面而不是数值层面输入到模型中进行训练和预测,并且两个数字是串行而不是并行对齐地提供的(可以理解为输入的是字符串 100 × 10110)。
Duplicate String,难度 CS,给定一个二元序列,如“abaab”,输出重复一次后的序列,该例子应该输出“abaababaab”,这个简单的任务看上去是难度 R,但实际上是 CS,大家可以想想为什么。
Missing Duplicate,难度 CS,给定一个带有缺失值的二元序列,如“ab_aba”,并且已知原始的完整序列是一个重复序列(上一个任务的 Duplicate String),预测确实值,该例子应该输出 a。
细思之下,“随机位置训练”会很让人困惑。简单起见,我们不妨设 L=2048, N=64, M=512,这样一来,训练阶段所用的平均位置序列大致为 [0, 32, 64, ···, 2016],预测阶段所用的平均位置序列是 [0, 8, 16, ···, 2040],训练阶段和预测阶段的相邻位置差不一样,这也可以说是某种不一致性,但它表现依然良好,这是为什么呢?
我们可以从“序”的角度去理解它。由于训练阶段的位置 id 是随机采样的,那么相邻位置差也是随机的,所以不管是相对位置还是绝对位置,模型不大可能通过精确的位置 id 来获取位置信息,取而代之是一个模糊的位置信号,更准确地说,是通过位置序列的“序”来编码位置而不是通过位置 id 本身来编码位置。
比如,位置序列 [1,3,5] 跟 [2,4,8] 是等价的,因为它们都是从小到大排列的一个序列,随机位置训练“迫使”模型学会了一个等价类,即所有从小到大排列的位置序列都是等价的,都可以相互替换,这是位置鲁棒性的真正含义。
然而,笔者自己在 MLM 上做的实验结果显示,这个“等价类”的学习对模型还是有一定的困难的,更理想的方法是训练阶段依然使用随机位置,使得预测阶段的位置编码也被训练过,但是预测阶段的位置序列前面部分应该跟随机位置的平均结果一致。
于是,笔者考虑了如下思路:
参考代码为:
1def random_position_ids(N):
2 """先随机采样n,然后从[0, n]均匀取N个点
3 """
4 n = sample_from_xxx()
5 return np.linspace(0, 1, N) * n
参考文献
[1] https://openreview.net/forum?id=nMYj4argap
[2] https://arxiv.org/abs/2207.02098
更多阅读
#投 稿 通 道#
让你的文字被更多人看到
如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。
总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。
PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析、科研心得或竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。
📝 稿件基本要求:
• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注
• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题
• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算
📬 投稿通道:
• 投稿邮箱:[email protected]
• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者
• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿
△长按添加PaperWeekly小编
🔍
现在,在「知乎」也能找到我们了
进入知乎首页搜索「PaperWeekly」
点击「关注」订阅我们的专栏吧
微信扫码关注该文公众号作者