ICLR 2023 | DIFFormer: 扩散过程启发的Transformer
机器之心专栏
本⽂介绍⼀项近期的研究⼯作,试图建⽴能量约束扩散微分⽅程与神经⽹络架构的联系,从而原创性的提出了物理启发下的 Transformer,称作 DIFFormer。作为⼀种通⽤的可以灵活⾼效的学习样本间隐含依赖关系的编码器架构,DIFFormer 在各类任务上都展现了强大潜⼒。这项工作已被 ICLR 2023 接收,并在⾸轮评审就收到了四位审稿⼈给出的 10/8/8/6 评分(最终均分排名位于前 0.5%)。
论⽂地址:https://arxiv.org/pdf/2301.09474.pdf 项⽬地址:https://github.com/qitianwu/DIFFormer
如果是⼀个的单位矩阵:(1)式中每个样本的表征计算只取决于⾃⼰(与其他样本独⽴),此时给出的是 Multi-Layer Perceptron (MLP) 的更新公式,即每个样本被单独输⼊进 encoder 计算表征; 如果在固定位置存在⾮零值(如输⼊图中存在连边的位置):(1)式中每个样本的表征更新会依赖于图中相邻的其他节点,此时给出的是 Graph Neural Networks (GNN) 的更新公式,其中是传播矩阵(propagation matrix),例如图卷积⽹络(GCN)模型采⽤归⼀化后的邻接矩阵; 如果在所有位置都允许有⾮零值,且每层的都可以发⽣变化:(1)式中每个样本的表征更新会依赖于其他所有节点,且每次更新两两节点间的影响也会适应性的变化,此时 (1) 式给出的是 Transformer 结构的更新公式,表示第层的 attention 矩阵。
我们这⾥引⼊⼀个能量函数,来刻画每时每刻由系统中所有节点表征所定义的内在⼀致性,通过能量的最⼩化来引导扩散过程中节点信号的演 变⽅向。具体的,对于样本表征,其对应的能量定义为:
基于此,我们考虑⼀种带能量约束的扩散过程,每⼀步的扩散率被定义为⼀个待优化的隐变量,我们希望它给出的每⼀步的节点表征都能够使得系统整体的能量下降。带能量约束的扩散过程可以被形式化的描述为:
DIFFormer-s:采⽤简单的 dot-product 来衡量相似性,作为 attention function(这⾥使⽤ L2 normalization 将输⼊向量限制在 [-1,1] 之间从⽽保证得到的注意⼒权重⾮负):
DIFFormer-a:在计算相似度时引⼊⾮线性,从⽽提升模型学习复杂结构的表达能⼒:
我们可以把代⼊更新单个样本的聚合公式,然后通过矩阵乘法结合律交换矩阵运算的顺序(这⾥假设):
更进⼀步的,我们可以引⼊更多设计来提升模型的适⽤性和灵活度。上述的模型主要考虑了样本间的 all-pair attention。对于输⼊数据本身就含有样本间图结构的情况,我们可以加⼊现有图神经⽹络(GNN)中常⽤的传播矩阵(propagation matrix)来融合已知的图结构信息,从⽽定义每层的样本表征更新如下
此时输⼊数据是⼀张图,图中的每个节点是⼀个样本(包含特征和标签),⽬标是利⽤节点特征和图结构来预测节点的标签。我们⾸先考虑⼩规模图 的实验,此时可以将⼀整图输⼊ DIFFormer。相⽐于同类模型例如 GNN,DIFFormer 的优势在于可以不受限于输⼊图,学习未被观测到的连边关系,从⽽更好的捕捉⻓距离依赖和潜在关系。下图展示了与 SOTA ⽅法的对⽐结果。
第⼆个场景我们考虑⼀般的分类问题,输⼊是⼀些独⽴的样本(如图⽚、⽂本),样本间没有已观测到的依赖关系。此时尽管没有输⼊图结构, DIFFormer 仍然可以学习隐含在数据中的样本依赖关系。对于对⽐⽅法 GCN/GAT,由于依赖于输⼊图,我们这⾥使⽤ K 近邻⼈⼯构造⼀个样本间的图结构。
进⼀步的,我们考虑时空预测任务,此时模型需要根据历史的观测图⽚段(包含上⼀时刻节点标签和图结构)来预测下⼀时刻的节点标签。这⾥我们横向对⽐ 了 DIFFormer-s/DIFFormer-a 在使⽤输⼊图和不使⽤输⼊图(w/o g)时的性能,发现在不少情况下不使⽤输⼊图模型反⽽能给出的较⾼预测精度。这也说明了在这类任务中,给定的观测图结构可能是不可靠的,⽽ DIFFormer 则可以通过从数据中学习依赖关系得到更有⽤的结构信息。
从设计思想上看:模型结构从能量下降扩散过程的⻆度导出,相⽐于直接的启发式设计更加具有理论依据; 从模型实现上看:在保留了学习每层所有节点全局 all-pair attention 的表达能⼒的同时,DIFFormer-s 只需要复杂度来更新个节点的表征,同时兼容 mini-batch training,可以有效扩展到⼤规模数据集。
建模含有观测结构的数据,得到节点表征(简⾔之就是使⽤ GNN 的场景):输⼊是⼀张图包含了互连的节点,需要计算图中节点的表征。这是⼀个相对已被⼴泛研究的领域,DIFFormer 的优势在于可以挖掘未被观测的隐式结构(如图中的缺失边、⻓距离依赖等),以及在低标签率的情况下提升精度。 建模不含观测结构但样本间存在隐式依赖的数据(如⼀般的分类 / 回归问题):数据集包含⼀系列独⽴样本,样本间的依赖关系未知。此时 DIFFormer 可⽤于学习样本间的隐式依赖关系,利⽤全局信息来计算每个样本的表征。这是⼀个较少被研究的领域,传统⽅法的主要 bottleneck 是在⼩数据集上容易过拟合(由于考虑了样本依赖模型过于复杂),⼤数据集上⼜⽆法有效扩展(学习任意两两样本的关系带来了平⽅复杂度)。DIFFormer 的优势在于简单的模型结构有效避免了过拟合问题,⽽且保证了相对于样本数量的复杂度可以有效扩展到⼤规模数据集。 作为⼀般的即插即⽤式 encoder,解决各式各样的下游任务(如⽣成 / 预测 / 决策问题)。此时 DIFFormer 可以直接⽤于⼤框架下的某个部件,得到输⼊数据的隐空间表征,⽤于下游任务。相⽐于其他 encoder (如 MLP/GNN/Transformer),DIFFormer 的优势在于可以⾼效的计算全局 attention,同时具有⼀定的理论基础(能量下降扩散过程的观点)。
© THE END
转载请联系本公众号获得授权
投稿或寻求报道:[email protected]
微信扫码关注该文公众号作者
戳这里提交新闻线索和高质量文章给我们。
来源: qq
点击查看作者最近其他文章