WSDM 2023 | 学习蒸馏图神经网络
图神经网络 (GNNs) 能够有效地获取图的拓扑和属性信息,在许多领域得到了广泛的研究。近年来,为提高 GNN 的效率和有效性,为 GNN 上配置知识蒸馏成为一种新趋势。然而,据我们所知,现有的应用于 GNN 的知识蒸馏方法都采用了预定义的蒸馏过程,这些过程由几个超参数控制,而不受蒸馏模型性能的监督。蒸馏和评价之间的这种隔离会导致次优结果。
在这项工作中,我们旨在提出一个通用的知识蒸馏框架,可以应用于任何预先训练的 GNN 模型,以进一步提高它们的性能。为了解决分离问题,我们提出了参数化和学习适合蒸馏 GNN 的蒸馏过程。
具体地说,我们没有像以前的大多数工作那样引入一个统一的温度超参数,我们将学习节点特定的蒸馏温度,以获得更好的蒸馏模型性能。我们首先通过一个关于节点邻域编码和预测分布的函数将每个节点的温度参数化,然后设计了一种新的迭代学习过程来进行模型蒸馏和温度学习。我们还引入了我们的方法的一个可扩展的变体来加速模型训练。
论文链接:
简介
图神经网络 (GNNs) 已经成为最先进的图上的半监督学习技术,并在过去的五年中受到了广泛的关注。数以百计的图神经网络模型已经被提出并成功地应用于各种领域,如计算机视觉、自然语言处理和数据挖掘。近年来,在图神经网络中加入知识蒸馏来达到更好的效率或效果是一种新趋势。
在知识蒸馏中,学生模型通过训练来模仿预先训练的教师模型的软预测来学习知识。从效率的角度来看,知识蒸馏可以将深层的图卷积神经网络(GCN)模型(教师)压缩为浅层模型(学生),从而实现更快的推理。从有效性的角度来看,知识蒸馏可以提取图神经网络模型(教师)的知识,并将其注入到设计良好的非图神经网络模型(学生)中,从而利用更多的先验知识,得到更准确的预测结果。
除了教师和学生的选择,蒸馏过程决定了教师和学生模型的软预测在损失函数中如何匹配,也对蒸馏后的学生对下游任务的预测表现至关重要。例如,全局超参数“温度”在知识蒸馏中被广泛采用,它软化了教师模型和学生模型的预测,以促进知识转移。
然而,据我们所知,应用于图神经网络的现有知识蒸馏方法都采用了预先定义的蒸馏过程,即只有超参数而没有任何可学习的参数。换句话说,蒸馏过程是启发式或经验式设计的,没有任何来自蒸馏学生的监督,这将分离蒸馏与评价,从而导致次优结果。针对现有的图上知识蒸馏方法的上述缺点,本文提出了一种参数化蒸馏过程的框架。
在本工作中,我们的目标是提出一个通用的知识蒸馏框架,可以应用于任何预训练过的图神经网络模型,以进一步提高其性能。注意,我们关注的是蒸馏过程的研究,而不是学生模型的选择,因此,就像 BAN 建议的那样,简单地让一个学生模型拥有与其老师相同的神经结构。为了克服蒸馏和评估之间的隔离问题,我们没有将全局温度作为超参数引入,而是创新性地提出通过蒸馏 GNN 学生的表现来学习特定节点的温度。
本工作的主要思想是为图上的每个节点学到一个特定的温度。我们通过一个关于节点邻域编码和节点预测分布的函数来参数化每个节点的温度。由于传统知识蒸馏框架存在隔离问题,经过蒸馏的学生的性能对节点温度的偏导数不存在,这使得温度参数化中的参数学习有着一定的困难。
2.1 节点分类
2.2 图神经网络
2.3 知识蒸馏
我们没有引入全局温度作为超参数,而是创新地提出学习特定节点的温度,以获得更好的蒸馏性能。我们将首先介绍如何在温度参数化中引入可学习参数,然后设计一种基于迭代学习过程的参数训练新算法。
3.1 参数化温度
直接为每个节点指定一个自由参数作为节点特定温度将导致严重的过拟合问题。因此,我们假设具有相似编码和邻域预测的节点应该具有相似的蒸馏温度。
在实际应用中,每个节点 𝑣的温度可以通过一个函数来参数化,该函数需要用到以下特征:1)学生的 logit 向量,它直接表征了学生模型当前的预测状态;2)Logits 向量的 L2 范数,由于 softmax 函数中的指数算子,较大的范数通常表示较硬的预测分布;3)中心节点邻居的预测熵,描述了节点邻居的标签多样性。
3.2 迭代学习过程
评估蒸馏的学生和监督温度的损失是:
实验
在本节中,我们对五个基准数据集进行了实验,以回答以下研究问题 (RQs):
RQ1:我们的 LTD 蒸馏出的 GNN 学生是否优于其他知识蒸馏框架蒸馏出的学生?与其他蒸馏框架相比,我们模型的效率如何?
RQ2:我们的模型在不同的环境下(如消融研究,GNN 教师/学生的不同组合)表现如何?
RQ3:我们可以从学习的 LTD 参数(即特定节点的温度)中观察到什么模式?
4.1 主实验
4.2 消融实验
4.3 学习到的温度分析
我们分析了在 5 × 5 = 25 GNN 数据集组合中学习到的节点特定温度,并提出了以下基于 GAT 的案例研究,以说明 LTD 如何帮助学习更好的蒸馏。
1. 首先,我们计算了随机初始化温度和学习温度之间的 Pearson 相关系数,以证明训练过程后节点温度发生了显著变化,并且真正具有节点特异性。
2. 我们观察到“令人困惑”的类(即与其他类混合)中的节点往往具有更高的温度。例如,在下图中我们将 GAT 老师学习到的节点嵌入可视化,并使用节点颜色来表示它们的标签。
3. 我们观察到具有较小的 L2 范数的节点倾向于具有较高的温度。
4. 请注意,我们允许负节点温度,这将完全颠倒先训练的教师的预测。我们观察到具有负学习温度的节点很可能被 GNN 老师错误地预测。
结论
在本文中,我们提出了一种新的知识蒸馏框架 LTD,可以应用于任何预训练的 GNN 模型,以进一步提高其预测性能。我们没有像以前的大多数工作那样引入一个全局温度超参数,而是创新地提出通过蒸馏学生的表现来学习节点特定的蒸馏温度。
更多阅读
#投 稿 通 道#
让你的文字被更多人看到
如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。
总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。
PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析、科研心得或竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。
📝 稿件基本要求:
• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注
• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题
• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算
📬 投稿通道:
• 投稿邮箱:[email protected]
• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者
• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿
△长按添加PaperWeekly小编
🔍
现在,在「知乎」也能找到我们了
进入知乎首页搜索「PaperWeekly」
微信扫码关注该文公众号作者