谁动了我的显存?——深度学习训练过程显存占用分析及优化
深度学习自然语言处理 分享
作者:游凯超 (知乎,已授权)
编辑:西柚
在大语言模型时代,不仅语言模型变得越来越大,而且几乎所有的模型都想变得越来越大,试图在模型变大之后观察到一些涌现出来的能力。
模型变大之后,最突出的问题就是显存不够用了。本文对深度学习训练过程中的显存占用问题进行一些具体分析,加深我对训练过程的理解,能够进行一些简单的显存优化操作。如果读者们有更多的相关资料、优化技巧,欢迎在评论区补充。
进NLP群—>加入NLP交流群
显存占用概述
深度学习训练过程中的显存占用,大致可以分为三部分:
框架占用,例如pytorch框架的cuda context会占用大约几百MB显存.
模型参数相关的占用,比如7B的模型以FP16格式要占用14GB显存。此处还包括优化器、梯度相关的参数占用,全量微调的情况下,梯度与参数一样大,优化器状态是梯度的1~2倍(SGD为1倍,Adam为2倍)。如果使用DDP进行多卡训练,则还需要乘以显卡数量;如果使用FSDP进行多卡训练,显存占用与显卡数无关,但是会增加通信开销。
特征相关的占用,这部分显存占用是最复杂的,因为它与模型的具体计算流程有关。很多地方只会笼统地说这类占用与batchsize成正比,但是具体的比例系数很难分析。
本文希望详细解析特征相关的显存占用到底是多少。
统计方法
我们用一个样例程序,来使用不同的方法、在不同的情况下计算这样一个简单的函数。具体程序为:
import torch
# Create two tensors with 1GB memory footprint each, initialized randomly, in fp16 format
# For a tensor of float16 (2 bytes), 1GB of memory can hold 1GB / 2B = 500M elements
tensor_size = 512 * 1024 * 1024
x = torch.randn(tensor_size, dtype=torch.float16, device='cuda')
y = torch.randn(tensor_size, dtype=torch.float16, device='cuda')
# Record current memory footprint, and reset max memory counter
current_memory = torch.cuda.memory_allocated()
torch.cuda.reset_peak_memory_stats()
def compute(x, y):
return (x + 1) * (y + 1)
z = compute(x, y)
# Record the additional memory (both peak memory and persistent memory) after calculating the resulting tensor
additional_memory = torch.cuda.memory_allocated() - (current_memory + 1e9)
peak_memory = torch.cuda.max_memory_allocated()
additional_peak_memory = peak_memory - (current_memory + 1e9)
print(f"Additional memory used: {additional_memory / (1024 ** 3)} GB")
print(f"Additional peak memory used: {additional_peak_memory / (1024 ** 3)} GB")
在这个函数计算过程中,输入、,输出不可避免地要占用显存。我们希望在不同情况下、改变不同的计算方式,观察/理解为了计算这个函数所需要的额外显存。
这里需要区分两个概念:峰值显存占用 与 持续显存占用 。在计算一个函数的过程中,我们可能创建了很多中间结果,它们需要临时占用显存;但是当函数计算完成之后,只有一部分结果需要持续存在(直到反向传播结束),另一部分可以被释放。上述示例小脚本,会分别输出持续显存占用和峰值显存占用。
一:不需要计算梯度的情况
上述示例脚本,直接运行的结果是:
Additional memory used: 0.06867742538452148 GB
Additional peak memory used: 2.0686774253845215 GB
也就是说,函数运行期间需要大约2GB的显存占用,运行结束之后几乎不占显存。
具体来说,函数计算过程中需要创建和两个临时变量,乘积结果放在中。因此大约需要2GB的显存来存储临时变量,它们在计算结束后会被释放。
至于为什么持续显存占用不严格为0、峰值显存占用不严格为2GB,这就与pytorch的具体显存管理策略、对象的显存布局有关,我们暂时不关心这部分内容。
二:需要计算梯度的情况
我们把计算函数改写为:
def compute(x, y):
x.requires_grad_(True)
y.requires_grad_(True)
return (x + 1) * (y + 1)
得到的结果为:
Additional memory used: 2.0686774253845215 GB
Additional peak memory used: 2.0686774253845215 GB
也就是说,需要计算梯度时,计算过程中的临时变量并不会被释放,反而会持续存在于显存中,等待后续用于反向传播计算。
这个问题可以变得更复杂一些,如果我们让一个输入要求梯度、一个参数不要求梯度,会发生什么呢?
def compute(x, y):
x.requires_grad_(True)
return (x + 1) * (y + 1)
得到的结果是:
Additional memory used: 1.0686774253845215 GB
Additional peak memory used: 2.0686774253845215 GB
可以看到,计算完成后释放了一个临时变量,还有一个临时变量持续存在。这是因为我们只要求能计算梯度,不用计算梯度。
有意思的是,大部分人看到这里,都觉得既然不需要计算梯度,那么肯定是这个临时变量被释放了。然而,事实上是这个临时变量被释放掉了。
为了说清楚这个问题,我们用具体的值来区分和,这里的值是1,的值是2,于是临时变量的值是2,的值是3。通过计算结果记录的中间变量的值,我们可以区分具体记录了哪个中间结果。
def compute(x, y):
x.zero_()
y.zero_()
x += 1
y += 2
x.requires_grad_(True)
z = (x + 1) * (y + 1)
print(z.grad_fn._saved_other.mean().item())
return z
这段代码的运行结果是:
3.0
Additional memory used: 1.0686774253845215 GB
Additional peak memory used: 2.0686774253845215 GB
可以看到,虽然是要求梯度,但是在计算过程中保留的变量却是。
为了从原理上理解这个现象,我们来看看反向传播的本质:梯度求导。
考虑神经网络中的某个函数,输入为和两个参数,输出为。将继续参与后续运算,得到损失函数。反向传播的任务,就是在已知的情况下,计算和。
不失一般性而言,是和的函数。于是,为了反向传播,我们需要完整记录和。这是最简单粗暴的方法。
实际上,对于很多简单函数来说,偏导数的表达式并不复杂。以本文的小脚本为例,,于是只和有关。也就是说,为了计算的反向传播,只需要记录的值。
于是,我们就能理解,为什么需要梯度(对应地也需要梯度)时,反向传播记录的是。
三:不使用自动微分计算梯度
有什么办法能够绕开自动微分的限制,使得显存开销更低吗?
有的,答案就是pytorch提供的torch.autograd.Function
。
我们把计算部分的代码替换成Function的实现,直接用一个算子实现的功能:
from torch.autograd import Function
class AddMulFunction(Function):
@staticmethod
def forward(ctx, x, y):
ctx.save_for_backward(x, y)
return (x + 1) * (y + 1)
@staticmethod
def backward(ctx, grad_output):
x, y = ctx.saved_tensors
grad_x = grad_output * (y + 1)
grad_y = grad_output * (x + 1)
return grad_x, grad_y
func = AddMulFunction.apply
def compute(x, y):
x.requires_grad_(True)
y.requires_grad_(True)
return func(x, y)
输出结果为:
Additional memory used: 0.06867742538452148 GB
Additional peak memory used: 2.0686774253845215 GB
这个算子也能够进行反向传播,而且计算结束之后并不会占用显存。这是因为我们在它的backward函数里手动计算了这个算子的梯度,使得它不用记录临时变量和也能进行反向传播。
从这个算子的实现中,我们能清晰地看到ctx.save_for_backward函数,它为反向传播过程记录必要的参数。
关于torch.autograd.Function
,有一个细节值得注意:torch.autograd.Function
设计的初衷就是为了让高级用户绕开自动微分的限制,因此torch.autograd.Function
的forward和backward函数执行过程中,并不会记录梯度操作。大致可以理解为:torch.autograd.Function
的forward和backward函数执行过程被包裹在 with torch.no_grad()
环境中。
例如,我们把计算代码改成:
from torch.autograd import Function
class AddMulFunction(Function):
@staticmethod
def forward(ctx, x, y):
ctx.save_for_backward(x, y)
z = (x + 1) * (y + 1)
print(z.requires_grad)
print(z.grad_fn)
return z
@staticmethod
def backward(ctx, grad_output):
x, y = ctx.saved_tensors
grad_x = grad_output * (y + 1)
grad_y = grad_output * (x + 1)
return grad_x, grad_y
func = AddMulFunction.apply
def compute(x, y):
x.requires_grad_(True)
y.requires_grad_(True)
return func(x, y)
z = compute(x, y)
print(z.requires_grad)
print(z.grad_fn)
输出结果为:
False
None
True
<torch.autograd.function.AddMulFunctionBackward object at 0x7fea4304a5e0>
Additional memory used: 0.06867742538452148 GB
Additional peak memory used: 2.0686774253845215 GB
即使和是需要梯度的,在Function的forward函数中,是不需要梯度的。然而,当走出forward函数之后,pytorch会为它加上需要梯度的标志,并且通过grad_fn属性记录其反向传播需要执行的函数。
通过这一细节,我们可以理解,为什么定义了AddMulFunction之后,不能直接使用AddMulFunction.forward
函数,而必须用func = AddMulFunction.apply
。
以上涉及的内容,其实就是“算子融合”,通过手动计算反向传播过程,节约不必要的显存开销。上述算子还可以进一步优化,把峰值显存占用也降下来。感兴趣的朋友可以试试。
我们日常使用的很多算子,都是融合过的。
以sigmoid算子为例,如果我们自己来实现:
def compute(x):
x.requires_grad_(True)
z = 1 / (1 + torch.exp(-x))
return z
z = compute(x)
输出结果为:
Additional memory used: 2.0686774253845215 GB
Additional peak memory used: 3.0686774253845215 GB
峰值显存占用为3GB,持续显存占用为2GB。
如果改为pytorch自带的已经融合过的算子:
def compute(x):
x.requires_grad_(True)
z = torch.nn.Sigmoid()(x)
return z
z = compute(x)
输出结果为:
Additional memory used: 0.06867742538452148 GB
Additional peak memory used: 0.06867742538452148 GB
峰值显存占用与持续显存占用几乎都是0!
这是怎么做到的呢?
sigmoid函数是element-wise的函数,只需要申请一次显存,把所有的操作都变成in-place,再把这块显存作为输出内容,就不用申请临时空间了。 sigmoid函数 的导数是,为了计算反向传播,只需要记录输出。而在我们的示例程序中,z原本就会保留,因此sigmoid函数的反向传播记录的就不用额外占用空间。
算子显存占用分析中的记账问题
上述分析中,关于sigmoid算子显存占用为0的结论并不严谨。它占用的显存刚好是我们的输出,因此没有算在它的显存开销中。
为了更准确地反映这一问题,我们让它多计算几次:
def compute(x):
x.requires_grad_(True)
for i in range(5):
x = torch.nn.Sigmoid()(x)
return x
z = compute(x)
计算5次,额外占用显存为4GB:
Additional memory used: 4.0686774253845215 GB
Additional peak memory used: 4.0686774253845215 GB
大体上来说,一个算子持续占用的显存,就是它在前向传播过程中保存下来的变量所占的显存。但一个程序占用的显存总量,并不能用全部算子占用的显存数进行求和,因为这些变量之间可能有重复(正如我们的示例中的输入变量、输出变量那样)。
总结
本文介绍了深度学习训练过程中的显存占用分析方法、自动求导与手动算子融合、优化等技术原理。算子融合是深度学习编译器等技术的核心,而算子优化目前还需要人工设计。对算子优化感兴趣的朋友,可以看看FlashAttention论文(参见《Flashattention: Fast and memory-efficient exact attention with io-awareness》),它是一个十分优雅的算子优化的例子。
注:
如何查看pytorch自带算子为反向传播保存的变量?可以通过输出的grad_fn
属性的dir(var.grad_fn)
看到,里面的_saved_xxx
就是为了反向传播保存的变量。
对于乘法,这个属性是_saved_other
,因为乘法的梯度是另一个变量;对于sigmoid算子,这个属性是_saved_result
,因为sigmoid的梯度和计算结果有关。
大部分的pytorch算子都可以通过这种方式获得保存的具体变量内容,例如卷积算子保留了以下内容:_saved_bias_sym_sizes_opt/_saved_dilation/_saved_groups/_saved_input/_saved_output_padding/_saved_padding/_saved_stride/_saved_transposed/_saved_weight.
其中大部分都是卷积的配置(例如padding大小、stride大小等内容),真正对显存占用影响最大的就是_saved_input
和_saved_weight
。
附上这部分代码,感兴趣的读者可以用它来分析pytorch自带算子的具体计算机制。
var = z
names =[k for k in dir(var.grad_fn) if k.startswith('_saved')]
for k in names:
v = getattr(var.grad_fn, k)
if isinstance(v, torch.Tensor):
print(k, v.shape)
else:
print(k, v)
原文链接:
https://zhuanlan.zhihu.com/p/641894014
进NLP群—>加入NLP交流群
微信扫码关注该文公众号作者