Redian新闻
>
比标准Attention提速5-9倍,大模型都在用的FlashAttention v2来了

比标准Attention提速5-9倍,大模型都在用的FlashAttention v2来了

公众号新闻

机器之心报道

编辑:小舟、杜伟

一年时间,斯坦福大学提出的新型 Attention 算法 ——FlashAttention 完成了进化。这次在算法、并行化和工作分区等方面都有了显著改进,对大模型的适用性也更强了。


近来,几种长上下文语言模型陆续问世,包括 GPT-4(上下文长度为 32k)、MosaicML 的 MPT(上下文长度为 65k)Anthropic 的 Claude(上下文长度为 100k)。长文档查询和故事写作等新兴用例已经表明扩展语言模型上下文窗口是非常必要的。


然而,扩大 Transformer 的上下文长度是一个挑战,因为其核心的注意力层在时间复杂度和空间复杂度与输入序列长度的平方成正比。


一年前,来自斯坦福大学、纽约州立大学布法罗分校的研究者共同提出一种快速、内存高效的注意力算法 ——FlashAttention。该算法无需任何近似即可加速注意力并减少内存占用。现在,已经有许多机构和研究实验室采用 FlashAttention 来加速训练和推理


FlashAttention 示意图。


尽管 FlashAttention 的速度已经是优化基线的 2-4 倍,但它仍然有相当大的改进空间。FlashAttention 仍然不如优化过的矩阵乘法 (GEMM) 运算快,仅达到理论最大 FLOPs/s 的 25-40%。


现在,研究团队宣布推出 FlashAttention-2。FlashAttention-2 完全从头开始重写,使用 Nvidia 的 CUTLASS 3.x 及其核心库 CuTe 的原语(primitive)。


FlashAttention-2 开发者 Tri Dao。他是斯坦福大学博士生,还是 Together.AI 首席科学家,并将于 2024 年 9 月开始任职普林斯顿大学计算机科学助理教授。


FlashAttention-2 的速度是 FlashAttention 的 2 倍,在 A100 GPU 上达到 230 TFLOPs/s。在端到端训练 GPT 类语言模型时,FlashAttention-2 可让训练速度高达 225 TFLOPs/s(模型 FLOP 利用率为 72%)。


FlashAttention-2 将加速现有模型的训练、微调和推理。这意味着我们可以用相同成本训练 2 倍上下文长度的语言模型。这将有助于语言模型理解长篇书籍和报告、高分辨率图像、音频和视频。



  • 项目地址:https://github.com/Dao-AILab/flash-attention

  • 技术报告:https://tridao.me/publications/flash2/flash2.pdf


FlashAttention 是什么?


FlashAttention 是一种重新排序注意力计算的算法,它利用平铺、重计算等经典技术来显著提升计算速度,并将序列长度中的内存使用实现从二次到线性减少。其中平铺意味着将输入块从 HBM(GPU 内存)加载到 SRAM(快速缓存),并对该块执行注意力操作,更新 HBM 中的输出。


此外通过不将大型中间注意力矩阵写入 HBM,内存读写量减少,带来了 2-4 倍的时钟时间加速。


下图为 FlashAttention 的前向传递图:通过平铺和 softmax 重新缩放,研究者按块进行操作,避免从 HBM 中读取 / 写入,同时获得正确的输出,无需近似操作。



然而,FlashAttention 仍然存在一些低效率问题,原因在于不同线程块之间的工作分区不理想以及 GPU 上的 warp。这些导致低占用率或不必要的共享内存读写。


FlashAttention-2

更好的算法、并行化和工作分区


更少的非矩阵乘法 Flops


研究者调整了 FlashAttention 的算法,从而减少了非矩阵乘法(non-matmul)的 Flops 数量。这点很重要,因为现代 GPU 具有专门的计算单元(例如 Nvidia GPU 上的张量核心),使得矩阵乘法速度更快。


举例而言,A100 GPU 的 FP16/BF16 矩阵乘法的最大理论吞吐量为 312 TFLOPs/s,但非矩阵乘法 FP32 的理论吞吐量仅为 19.5 TFLOPs/s。


换一种思考方式,每个非矩阵乘法 FLOP 比矩阵乘法 FLOP 的代价高 16 倍。为了保持高吞吐量,研究者希望在矩阵乘法 FLOP 上花费尽可能多的时间。因此他们重写了 FlashAttention 中使用的在线 softmax 技巧,以减少重新缩放操作、边界检查和因果掩码操作的数量,而无需更改输出


更好的并行化


FlashAttention v1 在批大小和头(head)数量上进行并行化。研究者使用 1 个线程块来处理一个注意力头,总共有(批大小 * 头数量)个线程块。每个线程块都计划在流式多处理器(SM)上运行,例如 A100 GPU 上有 108 个这样的 SM。当这个数字非常大(如 >= 80)时,这种调度是有效的,这时可以高效地使用 GPU 上几乎所有计算资源。


在长序列的情况下(通常意味着小批量或少量头),为了更好地利用 GPU 上的多处理器,现在研究者在序列长度维数上额外地进行并行化,使该机制显著加速


更好的工作分区


即使在每个线程块内,研究者也必须决定如何在不同的 warp 之间划分工作(一组 32 个线程一起工作)。通常情况下,每个线程块使用 4 或 8 个 warp,分区方案如下图所述。 


研究者改进了 FlashAttention-2 中的这种分区,减少不同 warp 之间的同步和通信量,进而减少共享内存读写



对于每个块,FlashAttention 将 K 和 V 分割到 4 个 warp 上,同时保持 Q 可被所有 warp 访问。这被称为「sliced-K」方案。不过,这种方案是低效的,原因在于所有 warp 都需要将它们的中间结果写入共享内存,并同步,然后将中间结果相加。这些共享内存读写会减慢 FlashAttention 中的前向传递速度。


在 FlashAttention-2 中,研究者将 Q 分割在 4 个 warp 上,同时保持 K 和 V 可被所有的 warp 访问。每个 warp 执行矩阵乘法以获得 Q K^T 的切片,然后只需与 V 的共享切片相乘就能获得相应的输出切片。warp 之间不需要通信。共享内存读写的减少也可以提升速度。


新特性:头维数高达 256、多查询注意力


我们知道,FlashAttention 仅支持最高 128 的头维数,这适用于大多数模型,但有一些模型被遗漏了。


因此,FlashAttention-2 支持了高达 256 的头维数,这意味着 GPT-J、CodeGen 和 CodeGen2、StableDiffusion 1.x 等模型可以使用 FlashAttention-2 来获得加速和节省内存


此外,FlashAttention-2 还支持了多查询注意力(multi-query attention, MQA)以及分组查询注意力(grouped-query attention, GQA)。它们是注意力的变体,其中多个查询头关注相同的键和值头,以减少推理过程中 KV 缓存的大小,并可以显著提高推理吞吐量。


注意力基准结果


研究者在 A100 80GB SXM4 GPU 上,测量不同设置(无 / 有因果掩码、头维数 64 或 128)下不同注意力方法的运行时。 


结果发现, FlashAttention-2 的速度是 FlashAttention(以及 xformers 库和 Triton 中的其他实现)的 2 倍。与 PyTorch 中的标准注意力实现相比,FlashAttention-2 的速度最高是它们的 9 倍。


A100 GPU 上的注意力前向 + 后向速度。


此外只需要在 H100 GPU 上 运行相同的实现(不使用特殊指令来利用 TMA 和第四代 Tensor Core 等新硬件功能),研究者最高获得了 335 TFLOPs/s。


H100 GPU 上的注意力前向 + 后向速度。


当用于端到端 GPT 类模型训练时,FlashAttention-2 有助于在 A100 GPU 上实现最高 225 TFLOPs/s(模型 FLOPs 利用率为 72%)。与优化良好的 FlashAttention 模型相比,端到端实现 1.3 倍加速。



这里的基线是不使用 FlashAttention 的 Megatron-LM,它现在也可以选择使用 FlashAttention 了。不久的将来,FlashAttention-2 也将集成到 Megatron-LM 中


研究团队表示:下一步将针对 H100 GPU 优化 FlashAttention-2,以使用新的硬件功能。


参考链接:

https://princeton-nlp.github.io/flash-atttention-2/



© THE END 

转载请联系本公众号获得授权

投稿或寻求报道:[email protected]

微信扫码关注该文公众号作者

戳这里提交新闻线索和高质量文章给我们。
相关阅读
出事的陈师兄16亿人在用的TikTok,为什么超越不了10亿人用的抖音?LLM推理提速2.8倍,CMU清华姚班校友提出「投机式推理」引擎SpecInfer,小模型撬动大模型高效推理百度华为阿里等入选大模型“国家队”;盘古大模型3.0发布;阿里云推AI绘画大模型丨AIGC大事日报ChatGPT最强竞品Claude2来了:代码、GRE成绩超越GPT-4,免费可用最先被GPT革掉命的,大概率是你每天都在用的验证码Jupyter推出免费AI助手,不只会写代码,多种大模型都能调用开源大模型FLM-101B:训练成本最低的超100B参数大模型【城事】巴黎市长将重修Châtelet 广场以方便行人让Attention提速9倍!FlashAttention燃爆显存,Transformer上下文长度史诗级提升巴黎市长将重修Châtelet 广场以方便行人大模型速度狂飙2.39倍!清华联手微软首提SoT,让LLM思考更像人类浙商大佬将收获第三家上市公司,估值两年涨9倍,投资人爆赚Logstash、Fluentd、Fluent Bit 和 Vector,谁才是开源日志收集最强王者?斯坦福博士一己之力让Attention提速9倍!FlashAttention燃爆显存,Transformer上下文长度史诗级提升Erklärung zur ZusammenarbeitU设计周大谈AI时代的设计,不懂点大模型都落伍了世上“永不沉没”的Friendship,要从Lisbon Maru Ship说起……Jupyter大升级:各种大模型都能连,聊天就能生成代码、错误修改2023回国 农家乐一日游(多图)首开!Deloitte (US) 开放2026 Winter Internship世界人工智能大会上的大模型都在这了,让你一次看个够等了5年,亚马逊HQ2来了...零售行业都在用哪些软件|36氪企服点评在用趋势中科院提出FastSAM快速分割一切模型!比Meta原版提速50倍!5074 血壮山河之武汉会战 黄广战役 6让注意力提速9倍!FlashAttention燃爆显存,Transformer上下文长度史诗级提升!四大卷王 | Deloitte 率先开启2026 Winter Internship一文读懂领先的餐饮连锁企业都在用什么软件|36氪企服点评在用趋势火星乐园第三部《灰界》第十八章 信心价值数十家企业参编中国大模型标准;大模型创企获2.5亿美元投资;微软签署数十亿美元AI算力协议丨AIGC大事日报大妈是一种威武的存在只给大模型LeetCode编号,也能解题!大模型表现好是源于对训练数据的记忆吗?请不要迷信大模型膳食纤维是红薯19倍,清甜解渴,补水滋润,这才是夏天不能缺的汤水!中科院版「分割一切」模型来了,比Meta原版提速50倍 | GitHub 2.4K+星
logo
联系我们隐私协议©2024 redian.news
Redian新闻
Redian.news刊载任何文章,不代表同意其说法或描述,仅为提供更多信息,也不构成任何建议。文章信息的合法性及真实性由其作者负责,与Redian.news及其运营公司无关。欢迎投稿,如发现稿件侵权,或作者不愿在本网发表文章,请版权拥有者通知本网处理。