©PaperWeekly 原创 · 作者 | 苏剑林 在生成扩散模型的发展史上,DDIM 和同期 Song Yang 的扩散 SDE 都称得上是里程碑式的工作,因为它们建立起了扩散模型与随机微分方程(SDE)、常微分方程(ODE)这两个数学领域的紧密联系,从而允许我们可以利用 SDE、ODE 已有的各种数学工具来对分析、求解和拓展扩散模型,比如后续大量的加速采样工作都以此为基础,可以说这打开了生成扩散模型的一个全新视角。
本文我们聚焦于 ODE。在本系列的(六) 、(十二) 、(十四) 、(十五) 等文章中,我们已经推导过 ODE 与扩散模型的联系,本文则对扩散 ODE 的采样加速做简单介绍,并重点介绍一种巧妙地利用“中值定理”思想的新颖采样加速方案“AMED”。
欧拉方法
正如前面所说,我们已经有多篇文章推导过扩散模型与 ODE 的联系,所以这里不重复介绍,而是直接将扩散 ODE 的采样定义为如下 ODE 的求解:
其中 ,初值条件是 ,要返回的结果是 。原则上我们并不关心 时的中间值 ,只需要最终的 。为了数值求解,我们还需要选定节 点 , 常见的选择是 其中 。该形式来自《Elucidating the Design Space of Diffusion-Based Generative Models》 [1] (EDM),AMED 也沿用了该方案,个人认为节点的选择不算关键要素,因此本文对此不做深究。 这通常也直接称为 DDIM 方法,因为是 DDIM 首先注意到它的采样过程对应于 ODE 的欧拉法,继而反推出对应的 ODE。
高阶方法 从数值求解的角度来看,欧拉方法属于一阶近似,特点是简单快捷,缺点是精度差,所以步长不能太小,这意味着单纯利用欧拉法不大可能明显降低采样步数并且保证采样质量。因此,后续的采样加速工作都应用了更高阶的方法。 比如,直觉上差 分 应 该更接近中间点的导数而不是边界的导数,所以右端也换成 和 的平均应该会有更高的精度: 然而,右端出现了 ,而我们要做的就是计算 ,所以这样的等式并不能直接用来迭代,为此,我们用欧拉方法“预估”一下 ,然后替换掉上式中的 : 这就是 EDM 所用的“Heun 方法 [2] ”,是一种二阶方法。这样每步迭代需要算两次 ,但精度明显提高,因此可以明显减少迭代步数,总的计算成本是降低的。 二阶方法还有很多变体,比如式 (5) 的右端我们可以直接换成中间点 的函数值,这得到 中间点也有不同的求法,除了代数平均 外,也可以考虑几何平均 事实上,式 (9) 就是 DPM-Solver-2 [3 ] 的一个特例。 除了二阶方法外,ODE 的求解还有不少更高阶的方法,如"Runge-Kutta 方法 [4 ] ”、“线性多步法 [5 ] ”等。然而,不管是二阶方法还是高阶方法,虽然都能一定程度上加速扩散 ODE 的采样,但由于这些都是“通法”,没有针对扩散模型的背景和形式进行定制,因此很难将采样过程的计算步数降到极致(个位数)。
中值定理 至此,本文的主角 AMED 登场了,其论文《Fast ODE-based Sampling for Diffusion Models in Around 5 Steps》[6] 前两天才放到 Arxiv,可谓“新鲜滚热辣”。AMED 并非像传统的 ODE 求解器那样一味提高理论精度,而是巧妙地类比了“中值定理”,并加上非常小的蒸馏成本,为扩散 ODE 定制 了高速的求解器。 ▲ 几种扩散ODE-Solver示意图
首先,我们对方程 (1) 两端积分,那么可以写出精确的等式: 如果 只是一维的标量函数,那么由“积分中值定理 [7 ] ”我们可以知道存在点 ,使得 很遗憾,中值定理对一般的向量函数并不成立。不过,在 不太大以及一定的假设之下,我们依然可以类比地写出近似 其中 是训练参数, 是 U-Net 模型 的中间特征。最后,为了求解参数 ,我们采用蒸馏的思想,预先用步数更多的求解器求出精度更高的轨迹点对 ( ),然后最小化估计误差。这就是论文中的 AMED-Solver( A pproximate ME an- D irection Solver),它具备常规 ODE-Solver 的形式,但又需要额外的蒸馏成本,然而这点蒸馏成本相比其他蒸馏加速方法又几乎可以忽略不计,所以笔者将它理解为“定制”求解器。 定制一词非常关键,扩散 ODE 的采样加速研究由来已久,在众多研究人员的贡献加成下,非训练的求解器大概已经走了非常远,但依然未能将采样步数降到极致,除非未来我们对扩散模型的理论理解有进一步的突破,否则笔者不认为非训练的求解器还有显著的提升空间。因此,AMED 这种带有少量训练成本的加速度,既是“剑走偏锋”、“另辟蹊径”,也是“应运而生”、“顺理成章”。
实验结果 在看实验结果之前,我们首先了解一个名为“NFE”的概念,全称是“Number of Function Evaluations”,说白了就是模型 的执行次数,它跟计算量直接挂钩。 比如,一阶方法每步迭代的 NFE 是 1,因为只需要执行一次 ,而二阶方法每一步迭代的 NFE 是 2,AMED-Solver 的 计算量很小,可以忽略不计,所以 AMED-Solver 每一步的 NFE 也算是 2。为了实现公平的比较,需要保持整个采样过程中总的 NFE 不变,来对比不同 Solver 的效果。 基本的实验结果是原论文的 Table 2:
▲ AMED的实验结果(Table 2)
这个表格有几个值得特别留意的地方。第一,在 NFE 不超过 5 时,二阶的 DPM-Solver、EDM 效果还不如一阶的 DDIM,这是因为 Solver 的误差不仅跟阶次有关,还跟步长 有关,大致上的关系就是 ,其中 m 就是“阶”,在总 NFE 较小时,高阶方法只能取较大的步长,所以实际精度反而更差,从而效果不佳。 第二,同样是二阶方法的 SMED-Solver,在小 NFE 时效果取得了全面 SOTA,这充分体现了“定制”的重要性;第三,这里的“AMED-Plugin”是原论文提出的将 AMED 的思想作为其他 ODESolver 的“插件”的用法,细节更加复杂一些,但取得了更好的效果。 可能有读者会疑问:既然二阶方法每一步迭代都需要 2 个 NFE,那么表格中怎么会出现奇数的 NFE?其实,这是因为作者用到了一个名为“AFS(Analytical First Step)”的技巧来减少了 1 个 NFE。 该技巧出自《Genie: Higher-order denoising diffusion solvers》[8 ] ,具体是指在扩散模型背景下我们发现 与 非常接近(不同的扩散模型表现可能不大一样,但核心思想都是第一步可以直接解析求解),于是在采样的第一步直接用 替代 ,这就省了一个 NFE。论文附录的 Table 8、Table 9、Table 10 也更详尽地评估了 AFS 对效果的影响,有兴趣的读者可以自行分析。 最后,由于 AMED 使用了蒸馏的方法来训练 ,那么也许会有读者想知道它跟其他蒸馏加速的方案的效果差异,不过很遗憾,论文没有提供相关对比。 为此我也邮件咨询过作者,作者表示 AMED 的蒸馏成本是极低的,CIFAR10 只需要在单张 A100 上训练不到 20 分钟,256 大小的图片也只需要在 4 张 A100 上训练几个小时,而相比之下其他蒸馏加速的思路需要的时间是数天甚至数十天,因此作者将 AMED 视为 Solver 的工作而不是蒸馏的工作。不过作者也表示,后面有机会也尽可能补上跟蒸馏工作的对比。 假设分析 前面在讨论中值定理到向量函数的推广时,我们提到“一定的假设之下”,那么这里的假设是什么呢?是否真的成立呢? 不难举出反例证明,即便是二维函数积分中值定理都不恒成立,换言之积分中值定理只在一维函数上成立,这意味着如果高维函数成立积分中值定理,那么该函数所描述的空间轨迹只能是一条直线,也就是说采样过程中所有的 构成一条直线。 这个假设自然非常强,实际上几乎不可能成立,但也侧面告诉我们,要想积分中值定理在高维空间尽可能成立,那么采样轨迹要保持在一个尽可能低维的子空间中。 为了验证这一点,论文作者加大了采样步数得到了较为精确的采样轨迹,然后对轨迹做主成分分析,结果如下图所示: 主成分分析的结果显示,只保留 top1 的主成分,就可以保留轨迹的大部分精度,而同时保留前两个主成本,那么后面的误差几乎可以忽略了,这告诉我们采样轨迹几乎都集中在一个二维子平面上,甚至非常接近这个子平面上的的一个直线,于是在 并不是特别大的时候,扩散模型的高维空间的积分中值定理也近似成立。 这个结果可能会让人比较意外,但事后来看其实也能解释:在《生成扩散模型漫谈:构建ODE的一般步骤(下)》 我们介绍了先指定 到 的“伪轨迹”,然后再构建对应的扩散 ODE 的一般步骤,而实际应用中,我们所构建的“伪轨迹”都是 与 的线性插值(关于 t 可能是非线性的,关于 和 则是线性的),于是构建的“伪轨迹”都是直线,这会进一步鼓励真实的扩散轨迹是一条直线,这就解释了主成分分析的结果。
文章小结 本文简单回顾了扩散 ODE 的采样加速方法,并重点介绍了前两天刚发布的一个名为“AMED”的新颖加速采样方案,该 Solver 类比了积分中值定理来构建迭代格式,以极低的蒸馏成本提高了 Solver 在低 NFE 时的表现。
[1] https://arxiv.org/abs/2206.00364
[2] https://en.wikipedia.org/wiki/Heun%27s_method
[3] https://arxiv.org/abs/2206.00927
[4] https://en.wikipedia.org/wiki/Runge%E2%80%93Kutta_methods
[5] https://en.wikipedia.org/wiki/Linear_multistep_method
[6] https://arxiv.org/abs/2312.00094
[7] https://en.wikipedia.org/wiki/Mean_value_theorem#Mean_value_theorems_for_definite_integrals
[8] https://arxiv.org/abs/2210.05475
如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。
总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。
PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读 ,也可以是学术热点剖析 、科研心得 或竞赛经验讲解 等。我们的目的只有一个,让知识真正流动起来。
📝 稿件基本要求:
• 文章确系个人原创作品 ,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注
• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题
• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬 ,具体依据文章阅读量和文章质量阶梯制结算
📬 投稿通道:
• 投稿邮箱: [email protected]
• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者
• 您也可以直接添加小编微信(pwbot02 )快速投稿,备注:姓名-投稿
△长按添加PaperWeekly小编
🔍
现在,在「知乎」 也能找到我们了
进入知乎首页搜索「PaperWeekly」
点击「关注」 订阅我们的专栏吧