ICLR 2023 | 扩散生成模型新方法:极度简化,一步生成
Diffusion Generative Models(扩散式生成模型)已经在各种生成式建模任务中大放异彩,但是,其复杂的数学推导却常常让大家望而却步,缓慢的生成速度也极大地阻碍了研究的快速迭代和高效部署。研究过 DDPM 的同学可能见到过这种画风的变分法(Variational Inference)推导(截取自 What are Diffusion Models):
总体上推导的难度和对数学的要求还是比较高的。在连续时间的形式下,还需要随机微分方程(Stochastic Differential Equation(SDE))的知识,有不低的入门门槛。除此以外,扩散式生成模型的一个众所周知的老大难问题就是生成速度慢:生成一张图需要模拟一整个基于复杂的深度模型的扩散过程。缓慢的生成速度是阻碍这些模型更广泛的普及的一个主要瓶颈。
Rectified Flow,一个“简简单单走直线”生成模型,是我们对这些挑战的一个回答:极度简单,一步生成。我们的方法有以下要点:
(1)我们无需一般扩散模型复杂的推导,代之以一个简单的“沿直线生成”的思想。算法理解上不需要变分法或随机微分方程等基础知识。我们的方法是基于一个简单的常微分方程(ODE),通过构造一个“尽量走直线”的连续运动系统来产生想要的数据分布。
(2)“尽量走直线”的目的是让我们模型实现快速生成。通过一个叫“reflow”的方法,我们可以实现梦想中的“一步生成”:只需一步计算就直接产生高质量的结果,而不需要调用计算量大的数值求解器来迭代式地模拟整个扩散过程。
(3)通常的扩散模型是把高斯白噪声转换成想要的数据(比如图片)。我们的方法可以把任何一种数据或噪声(比如猫脸照片)转换成另外一种数据(比如人脸照片)。所以我们的方法不仅可以做生成模型,还可以应用于很多更广泛的迁移学习(比如 domain transfer)任务上。
有兴趣的同学可以参见我们的论文(Arxiv 或 OpenReview,以及和最优传输(optimal transport)相关的深入理论 Arxiv)。代码,示例 Colab Notebook 和预训练模型已经开源在 github。一个英文版简介在这里。欢迎大家使用和交流!
问题-传输映射(将一个分布搬运到另一个分布)
我们先定义好要解决的问题。无论是从噪声生成图片(generative modeling),还是将人脸转化为猫脸(domain transfer),都可以这样概括成将一个分布转化成另一个分布的问题:
走直线,走得快
除了希望 ,我们还希望这个连续运动系统能够在计算机里快速地模拟出来。注意到,在实际计算过程中,上面的连续系统通常是用 Euler 法(或其变种)在离散化的时间上近似:,
这里 是一个步长参数。我们需要适当的选择 来平衡速度和精度: 需要足够小来保证近似的精度,但同时小的 意味着我们从 到 要跑很多步,速度就慢。
那么问题来了,什么样的系统能最快地用 Euler 法来模拟呢?也就是说,什么样的体系能允许我们在用较大的步长 的同时还能得到很好的精度呢?
答案是“走直线”。如下图所示,如果粒子的运动轨迹是弯曲的,我们需要很细的离散化来得到很好的结果。如果粒子的轨迹是直线,那么即使我们取最大的步长(),只用一步走到 时刻,还是能得到正确的结果!
所以,我们希望我们学习出来的速度模型 既能保证 ,又能给出尽量直的轨迹。怎么同时实现这两个目的在数学上是一个非常不简单(non-trivial)的问题,涉及最优传输(optimal transport)的一些深刻理论。但是我们发现其实可以用一个非常简单的方法来解决这个问题。
▲ 蓝色:真实 ODE 轨迹;绿色:Euler 法得到的离散轨迹。左:弯曲的运动轨迹需要较小的步长来离散化才能得到较小误差,所以需要更多的步数;右:笔直的运动轨迹甚至可以在计算机里用一步进行完美的模拟。
Rectified Flow-基于直线ODE学习生成模型
假设我们有从两个分布中的采样 ,(比如 是从 里出来的随机噪声, 是一个随机的数据(服从 ))。我们把 和 用一个线性插值连接起来,得到
这里 和 是随机,或者说,以任意方式配对的。你也许觉得 和 应该用一种有意义的方式配对好,这样能够得到更好的效果。我们先忽略这个问题,待会回来解决它。
现在,如果我们拿 对时间 求导,我们其实已经可以得到一个能够将数据从 传输到 的“ODE”了,
但是,这个“ODE”并不实用而且很奇怪,所以要打个引号:它不是一个“因果”(causal),或者“可前向模拟”(forward simulatable)的系统,因为要计算 在 时刻的速度 需要提前(在 时)知道 ODE 轨迹的终点 。如果我们都已经知道 了,那其实也就没有必要模拟 ODE 了。
那么我们能不能学习 ,使得我们想要的“可前向模拟”的 ODE 能尽可能逼近刚才这个“不可前向模拟”的过程呢?最简单的方法就是优化 来最小化这两个系统的速度函数(分别是 和 )之间的平方误差:
这是一个标准的优化任务。我们可以将 设置成一个神经网络,并用随机梯度下降或者 Adam 来优化,进而得到我们的可模拟 ODE 模型。
这就是我们的基本方法。数学上,我们可以证明这样学出来的 确实可以保证生成想要的分布 。对数学感兴趣的同学可以看一看论文里的理论推导。下面我们只用这个图来给一些直观的解释。
图(a):在我们用直线连接 和 时,有些线会在中间的地方相交,这是导致 非因果的原因(在交叉点, 既可以沿蓝线走,也可以沿绿线走,因此粒子不知该向岔路的哪边走)。
图(b):我们学习出的 ODE 因为必须是因果的,所以不能出现道路相交的情况,它会在原来相交的地方把道路交换成不交叉的形式。这样,我们学习出来的 ODE 仍然保留了原来的基本路径,但是做了一个重组来避免相交的情况。这样的结果是,图(a)和图(b)里的系统在每个时刻 的边际分布是一样的,即使总体的路径不一样。
我们的方法起名为 Rectified Flow。这里 rectified 是“拉直”,“规整”的意思。我们这个框架其实也可以用来推导和解释其他的扩散模型(如 DDPM)。我们论文里有详细说明,这里就不赘述了。我们现在的算法版本应该是在已知的算法空间里最简单的选项了。我们提供了 Colab Notebook 来帮助大家通过实践来理解这个过程。
Reflow-拉直轨迹,一步生成
Reflow与Distillation
给定一个配对 ,要想实现一步生成,也就是 , 我们好像也可以通过优化下面的平方误差来直接"蒸馏(distillation)"出一个一步模型:
这个目标函数和上面的 Reflow 的目标函数很像,只是把所有的时间 都设成 了。
尽管如此,Distillation 和 Reflow 是有本质的区别的。Distillation 试图一五一十地复现 配对的关系。但是,如果 的配对是随机的,Distillation最多只能得到 在给定 时的条件平均,也就是 ,并不能成功地完全匹配 。即使 有确定的一一对应关系,他们的配对关系也可能很复杂,导致直接蒸馏很困难。
▲ 图中,每个红点代表一次两随机的直线交叉的事件。随着 reflow,交叉的概率逐渐降低,对应的 ODE 的轨迹也越来越直。
理论保证
Rectified Flow 不仅简洁,而且在理论上也有很好的性质。我们在此给出一些理论保证的非正式表述,如果大家对理论部分感兴趣,欢迎大家阅读我们文章的细节。
1.边际分布不变:当 取得最优值时,对任意时间 ,我们有 和 的分布相等。因为 ,因此 确实可以将 转移到 。
实验结果-Rectified Flow能做到什么?
▲ CIFAR-10实验结果
使用 Runge Kutta-45 求解器,1-Rectified Flow 在 CIFAR10 上得到 IS=9.6, FID=2.58,recall=0.57,基本与之前的 VP SDE/sub-VP SDE [2] 相同,但是平均只需要 127 步进行模拟。
Reflow 可以使 ODE 轨迹变直,因此2-Rectified Flow 和 3-Rectified Flow 在仅用一步(N=1)时也可以有效的生成图片(FID=12.21/8.15)。
Reflow 可以降低传输损失,因此在进行蒸馏时会得到更好的表现。用 2-Rectified Flow + 蒸馏,我们在仅用一步生成时得到了 FID=4.85,远超之前最好的仅基于蒸馏/基于 GAN loss 的快速扩散式生成模型(当用一步采样时 FID=8.91)。同时,比起 GAN,Rectified Flow + 蒸馏有更好的多样性(recall>0.5)。
我们的方法也可以用于高清图片生成或无监督图像转换。
▲ 1-rectified flow: 256分辨率图像生成
▲ 1-rectified flow: 256分辨率无监督图像转换
同期相关工作
有意思的是,今年 ICLR 在 openreview 上出现了好几篇投稿论文提出了类似的想法。
这些工作都或多或少地提出了用拟合插值过程来构建生成式 ODE 模型的方法。除此之外,我们的工作还阐明了这个路径相交重组的直观解释和最优传输的内在联系,提出了 Reflow 算法,实现了一步生成,建立了比较完善的理论基础。大家不约而同地在一个地方发力,说明这个方法的出现是有很大的必然性的。因为它的简单形式和很好的效果,相信以后有很大的潜力。
如有任何问题,欢迎留言或者发邮件!
主要论文:
参考文献
更多阅读
#投 稿 通 道#
让你的文字被更多人看到
如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。
总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。
PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析、科研心得或竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。
📝 稿件基本要求:
• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注
• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题
• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算
📬 投稿通道:
• 投稿邮箱:[email protected]
• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者
• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿
△长按添加PaperWeekly小编
🔍
现在,在「知乎」也能找到我们了
进入知乎首页搜索「PaperWeekly」
点击「关注」订阅我们的专栏吧
微信扫码关注该文公众号作者