通过Unit Scaling进行简单的FP16和FP8训练
近年来,深度学习社区已经从FP32数字格式过渡到FP16和BFLOAT16格式,使得内存、带宽和计算要求大幅降低,这也是目前日趋增大的模型所需要的。
如今,支持FP8的硬件的发展(如C600 PCIe卡①中使用的Graphcore拟未IPU Bow处理器)使得进一步节省低精度的效率成为可能。然而,到目前为止,这些较小的低精度格式在实践中并不简单易用。对于FP8来说,有可能更加困难。
最重要的挑战是,这些较小的格式往往将用户限制在一个较窄的可表示值范围内。因此,问题出现了:我们如何确保我们的模型始终保持在较小格式的范围内?为了解决这个问题,Graphcore Research开发了一种新的方法,我们称之为Unit Scaling。
在不同尺度上,在FP16和FP8中定量的正态分布的信噪比(SNR)。对于较小的数字格式,信号在较窄的尺度范围内是较强的。
Unit Scaling是一种模型设计技术,在初始化时根据理想的缩放原则进行操作,即根据激活、权重和梯度的单元方差进行缩放。这是通过考虑模型中每个操作所带来的方差变化和引入固定的比例因子来抵消这种变化实现的。
由此产生的模型自动生成了对低精度数字格式具有良好缩放的张量,并且使用简单,并最大限度降低这些表示的缺点。与其他低精度训练方法不同,其成本和额外的复杂性是最小的。
我们的方法取得了突破性的成果:首次在FP16甚至FP8中准确地训练了BERT Base和BERT Large模型,并且没有因缩放产生性能损失。Unit Scaling是开箱即用的,训练时不需要额外的扫描或超参数。接下来,Unit Scaling模型可用于推理,无需额外的约束或修改。
对于关心效率并希望以FP16和FP8进行训练的使用者来说,Unit Scaling提供了一个直接的解决方案。IPU非常适合这些用例,拟未目前的Bow IPU处理器提供加速的FP16计算,而下一代IPU硬件则增加了加速的FP8计算。用户可以通过附带的Paperspace notebook亲自尝试Unit Scaling。
Unit Scaling:使用指南 ②
现有的FP16/FP8训练方法
FP16和FP8训练需要某种形式的缩放来保持数值在范围内。目前的方法如下:
(静态)损失缩放
缩小范围对于训练期间的后向传递来说特别具有挑战性,往往会导致梯度下溢。为了解决这个问题,一种方法是将损失乘以一个损失比例超参数来增加梯度的大小[1]。由于没有原则性的方法来提前选择损失比例,这个超参数可能需要被扫描到,通常需要多次完整的运行。
自动损失缩放
可以通过动态调整基于运行时梯度溢出(或直方图)的损失比例来避免超参数的扫描[2]。这也可以对抗训练期间张量分布的变化。然而,自动方案可能会增加开销和复杂性。
每张量的缩放
上述方法的另一个缺点是,它们只提供一个单一的全局损失比例。另一种解决方案是根据张量统计学在本地重新调整值[3]。这也是一个自动/运行时的方案,因此可能很复杂,从而难以有效实施。
低精度训练技术的比较。“∼”表示该方法在理想情况下不需要调整,但在实践中可能会引入需要扫描的超参数。
Unit Scaling在前向和后向传递中引入了局部缩放因子,以控制值的范围。然而,我们选择这些因素是基于对每个运算符如何影响值的比例的理论理解,而不是基于使用运行时分析得到的。
通过选择正确的缩放因子,每个操作都近似地保留了其输入的规模。通过将其应用至所有项目,将初始(单元)比例传播到整个模型,从而实现全局的Unit Scaling。
请注意,我们的分析是基于初始化时(即在训练开始之前)的值的比例。虽然在训练过程中比例会发生变化,但我们发现良好的初始缩放提供了足够的空间,无需再缩放(未来的工作将进一步研究这个方向,评估当我们转向更大的模型时,在更长的时间间隔内进行再缩放的可能性)。
我们的方法比自动缩放方案更简单,唯一的额外开销是应用缩放因子(一个标量乘法,可以融合到之前的操作中)。对于BERT Large而言,这在FLOPs中引入了一个可以忽略不计的0.2%的增加。
秘诀
一个模型可以通过应用以下方法进行Unit Scaling:
1. 用单元方差初始化非偏态参数
2. 计算所有操作的理想比例因子
3. 识别非切割边并约束消耗它们的操作,使之具有相同的缩放比例
4. 用加权加法取代加法。
我们将在下文中更详细地解释这些规则。
理想的缩放因子
我们可以对一些操作进行数学分析,以确定它们如何影响其输入的方差。
例如,一个基本的矩阵乘法XW(其中X是一个(b×m)矩阵,W是一个(m×n)矩阵)的输出方差为σ(X)²-σ(W)²-m。为了将这一操作Unit Scaling化,我们必须确保σ(X)² = σ(W)² = 1(通过缩放以前的操作),然后在输出中加入1/√m乘法。
以上属于正向传递的情况。反向传递则需要引入两个新的矩阵乘法,理想的缩放因子为1/√n和1/√b。其他操作也可以进行类似的分析,在输出方差不容易分析的情况下,可以用经验方法来寻找比例因子。
我们在arxiv论文③中提供了更详细的分析,同时还提供了常见操作及其理想比例因子的汇编。
切边
在正向和反向传递中直接应用这些理想的缩放因子会产生无效的梯度。为了避免这种情况,我们要求某些操作使用一个共享的缩放因子。
具体来说,我们采取正向计算图,并找到所有没有被切边(如果被移除,将把图分成两个不相连的小图的边)代表的变量。下图展示了一个transformer FFN层:
FFN层中的切边的可视化,以及相关的缩放因子。
在这种情况下,我们在权重、输入和输出变量上有切边。图中还显示了第二个matmul的反向传递所产生的梯度运算(注意:我们只考虑正向图的切边)。
因为x₃不是切边,我们可以限制∇x₃的matmul使用与向前传递相同的缩放因子。然而,由于w₂是切边,它被允许有自己的反向缩放因子。为了选择用于受限操作的共享缩放因子,我们取之前计算的理想缩放因子的几何平均值。
虽然这个切边规则听起来很复杂,但在实践中,它通常归结为一个简单的程序:给权重梯度以自己的缩放因子,以及模型中的任何编码器/解码器层。
加权加法
我们秘诀的最后一步是用加权加法取代加法操作。Unit Scaling的设计产生了具有相同尺度的变量,这意味着如果我们将两个张量相加,这两个张量实际上具有相等的权重。然而,在某些情况下,尤其是残差连接,我们可能需要一个不平衡的权重来达到良好的性能。
为了说明这一点,我们用一个加权(和Unit Scaling)的等价物来取代加法操作。对于残差连接,我们用它来推导出以下推荐方案:
Unit Scaling的残差连接方案。
实现
下面的代码显示了PyTorch中Unit Scaling的FFN层的实现。我们在代码库④和demo notebook⑤中提供了更多的实现示例。
我们首先定义了一些缩放原语,它允许我们创建基本操作的缩放版本,如scaled_projection:
这样,我们就可以创建全缩放的层。在这里,我们展示了一个标准的FFN和它Unit Scaling后的版本:
结果
实验结果表明,Unit Scaling在广泛的模型中都是有效的,而且开箱即用,不需要额外的超参数调整。
小规模试验
我们的第一组实验验证了Unit Scaling在不同模型架构中的广泛适用性。我们在FP32和FP16中训练了一大批具有和不具有Unit Scaling的小型字符级语言模型,并比较了结果。这些配置相当于进行了2092次运行扫描:
字符语言建模,显示了在广泛的模型中每个字符的验证。每个点代表以下的一个组合:{Conv, RNN, Attention}, {Pre, Post, No norm}, {Fixed, Running-mean residual}, {SGD, Adam}, {2, 8 Layers}。每个点都是在学习率扫描中最佳的终值。
我们的结果证明了以下几点:首先,使用FP16时需要某种形式的缩放(损耗或单元)。这是由于梯度下溢造成的,因为损失缩放的因子为2048,可以解决这个问题。其次,尽管Unit Scaling改变了模型的训练行为,不再只是数字的训练,但在几乎所有情况下都与基线性能相符,甚至略有改善。最后,当从FP32转换到FP16时,无需调优。
大规模试验
我们的第二组实验在一个更大、更现实的生产级模型BERT[4]上验证了Unit Scaling的有效性。我们对Unit Scaling模型进行了调整,使其与标准BERT实现保持一致,然后使用来自英文维基百科文章的文本对其进行训练。
我们对SQuAD v1.0和SQuAD v2.0评估任务的结果如下:
我们为每个模型-方法-格式组合预训练3个模型,然后为每个模型微调5个SQuAD v1.1和5个V2.0运行。显示的数值代表15次运行的平均值,±表示3个分组的平均分数的标准偏差。†来自Devlin等人(2019)。‡来自Noune等人(2022)。
Unit Scaling能够达到与标准(基线)模型相同的性能,而基线需要扫描损失规模。Unit Scaling在所有情况下都能开箱即用。基线和Unit Scaling模型并不完全等同,但其下游性能的偏差很小(Unit Scaling的BERT Base略低于基线,而BERT Large略高于基线)。
我们的FP8实现是基于拟未、AMD和高通最近提出的标准化格式。拟未的研究之前证明了在FP8中训练损失缩放的BERT没有退化[5],我们现在证明同样的情况可以通过Unit Scaling来实现。
如果要使FP8优于FP16,不需要额外的技术,我们只需将我们的matmul输入量化到FP8中,就能准确地进行训练(在FP8的E4变体中使用权重和激活,在E5中使用梯度)。这些结果代表了BERT Base或BERT Large首次在不需要损失缩放的情况下在FP16或FP8中进行训练。
低精度训练的未来
随着支持FP8硬件的 AI 社区越来越多,有效、直接且有原则的模型缩放方法的重要性也将越来越高。Unit Scaling适用于广泛的模型和优化器,并且计算成本最低。
下一代大型模型可能会广泛使用低精度格式,因此类似Unit Scaling的方法十分必要。希望我们的方法可以帮助这些应用,并为未来的缩放研究打下坚实的基础。低精度训练的效率优势是巨大的,Unit Scaling表明不会牺牲低精度训练的效率。
阅读论文⑥ | 代码⑦ | PyTorch demo notebook⑧
参考文献
[1] P. Micikevicius et al., Mixed precision training (2018). 6th International Conference on Learning Representations
[2] O. Kuchaiev et al., Mixed-precision training for nlp and speech recognition with openseq2seq (2018), arXiv preprint arXiv:1805.10387
[3] P. Micikevicius et al., FP8 formats for deep learning (2022). arXiv preprint arXiv:2209.05433
[4] J. Devlin et al., BERT: Pre-training of deep bidirectional transformers for language understanding (2019). NAACL-HLT
[5] B. Noune et al., 8-bit numerical formats for deep neural networks (2019). arXiv preprint arXiv:2206.02915
① https://medium.com/r/?url=https%3A%2F%2Fwww.graphcore.ai%2Fproducts%2Fc600
② https://ipu.dev/qXfm2a
③ https://ipu.dev/gP3Wng
④ https://ipu.dev/csvf7o
⑤ https://ipu.dev/qXfm2a
⑥ https://ipu.dev/gP3Wng
⑦ https://ipu.dev/csvf7o
⑧ https://ipu.dev/qXfm2a
获取更多Graphcore资讯,阅读深度技术文章,并与其他创新者们一起交流,请至中国官网graphcore.cn,以及关注Graphcore微信、微博和知乎创新社区。
Graphcore中国官网
Graphcore官方微信
Graphcore微博创新社区
Graphcore知乎创新社区
点击阅读原文,查看英文blog。
微信扫码关注该文公众号作者