©PaperWeekly 原创 · 作者 | 张一帆
Domain Adaptation(DA: 域自适应),Domain Generalization(DG: 域泛化)一直以来都是各大顶会的热门研究方向。DA 假设我们有有一个带标签的训练集(源域),这时候我们想让模型在另一个数据集上同样表现很好(目标域),利用目标域的无标签数据,提升模型在域间的适应能力是 DA 所强调的。以此为基础,DG 进一步弱化了假设,我们只有多个源域的数据,根本不知道目标域是什么,这个时候如何提升模型泛化性呢?传统 DG 方法就是在源域 finetune 预训练模型,然后部署时不经过任何调整,核心在于如何利用多个源域带来的丰富信息。然而一些文献表明,在不利用目标域信息的情况下实现很难实现泛化到任意分布这一目标。为了解决这一问题,测试时间自适应(TTA)方法被提出并得到了广泛研究,然而
1. 现有的 TTA 方法在推理阶段需要离线目标数据或更复杂的优化过程,如下图所示,各种 TTA 的方法,要么需要根据测试样本重新训练模型,要么需要更新模型的部分参数,或者需要额外的分支。2. 绝大多数方法没有一个理论上的验证甚至是直觉。本文介绍我们发表于 ICML 2023 的文章《AdaNPC: Exploring Non-Parametric Classifier for Test-Time Adaptation》,感谢来自北大,Meta,阿里达摩院,普林斯顿的合作者们。
本文提出了非参数化测试时间自适应的方法,不需要任何的梯度更新。在此基础上,我们从理论上验证了该框架的有效性,说明了通过引入测试样本信息,我们能够取得更好的泛化效果。据我们所知,这也是第一篇对 TTA 进行理论分析的工作。https://arxiv.org/abs/2304.12566代码链接:
https://github.com/yfzhang114/AdaNPC
在最近的研究中,人们发现在没有在推理期间利用目标样本的情况下,使模型对任何未知分布具有鲁棒性几乎是不可能的。测试时自适应(TTA)方法近期受到了广泛关注,以利用具有计算可行性约束的目标样本。然而,当前的 TTA 方法存在几个缺点。
1. 计算开销:现有的TTA方法需要批处理目标数据进行梯度更新和/或一个额外的模型进行微调,这在目标样本以在线方式一个接一个到达时是不可接受的。
2. 灾难性遗忘和模型扩展性弱:现有的 TTA 方法需要对训练模型进行更改。模型会逐渐失去对训练域的预测能力,这表明一些知识损失是不可避免的。这个问题在连续推理一系列领域时尤为重要。以 Rotated MNIST 数据集为例,我们使用最新的 TTA 方法 T3A 和 Tent 依次对 进行测试时自适应。在上图中,我们观察到所有现有方法对 的泛化能力即使在前四个域进行适应后仍然很差。我们还总结了不断使用的模型在源域 的表现,如下图所示,随着模型的 adaptation,其在源域的性能显著下降。也就是说,当前的 TTA 方法不能适应一系列在线域,很容易忘记历史知识。Method
为此,我们提出了一种名为 AdaNPC 的非参数适应方法。如下图所示1. 训练阶段,我们依然可以使用 ERM,CoRAL 等算法进行训练,我们的目标是获得更好的 representation,因此我们的框架和目前绝大多数 DG 方法都是正交的,他们学到的 representation 越好,我们 TTA 的效果也会越好。2. 测试阶段,我们只需要将模型最后的 Linear 分类器替换成 KNN。在 test time 的时候,我们将所有 training sample 的特征存入一个 memory bank 中,分类器每次 infer K 个最像的样本然后根据他们的标签生成最终的结果。
3. 模型自适应:这一步更加简单,我们只需要将目标域样本的特征和 pseudo label 存入 memory,就完成了整个 TTA 的流程。这一过程不涉及任何梯度反传,模型优化等,因此更为简单和高效。
除此之外,我们还介绍了一些可用的,但不是关键设计的 trick:
1. KNN loss:在训练阶段,我们默认采用使用线性分类器的 ERM,这就造成了 train-test 的 mismatch,因此我们也可以在训练阶段就使用 KNN 分类器,然后使用如下的损失函数来训练整个网络。通过这种配置,具有相同标签的特征被放大,而具有不同标签的特征被推开。与交叉熵损失相比,这种训练范式可以产生更好的表示,我们在原文中验证过这个观点。但是下式的优化是非平凡的,因为 随着模型参数 以不可微的方式变化。本文采用 EM 算法对其进行近似求解。我们只是定期更新 ,并在剩余的时间内保持它们的固定,这样我们就可以很容易地在 PyTorch 或 TensorFlow 中应用标准优化方案。当然,在源域的训练并不是我们的重点,在实验部分,我们表明我们甚至不需要在源域中微调预训练模型,AdaNPC 仍然可以获得良好的泛化性能。2. BN retraining:非参数分类器的性能高度依赖于模型表示,为了获得更强大的表示并保持 AdaNPC 的简单性,我们可以选择在分类器之前添加一个 BN 层。然后在评估时,通过最小化预测熵,只重新训练 BN 层参数。
理论分析
这里我们要做的事情有两个:1. 我们从理论上验证使用 KNN 作为分类器可以显式地减少域散度。2. 加入目标域样本,即非参数化的 Test-Time 自适应,将进一步减少目标域的期望误差。本文不会跳入证明细节,只是简单的提供基本的 intuition 和最终的结果,用到的假设涉及到对分布 Density,测度空间,函数平滑性等内容的一些常用假设,基本上可以算作常用假设。3.1 非参数化(KNN)分类器能够显式减小domain divergence对于传统的域泛化误差,我们知道它大概由三部分组成,即给定 hypothesis ,目标域的期望误差误差 由三个项组成、最小组合误差 、源域误差 和一个常数项的阶数 所限定。第一项与假设空间有关,一般认为是常数项,第三项即目标域和源域的 divergence,这里刻画为一阶的 wasserstein distance。当切换到非参数分类器时,在下面的定理中,我们将 替换为在 中显式衰减的量。直观地说,因为我们使用 KNN,因此,我们对每一个 target sample 的 decision,只取决于和他最近的那一群 neighbor,而不是传统的 linear classifier 那样,权重取决于整个源域训练数据。因此对每个目标域样本,我们的 domain divergence 实际上变成了他在源域的 neighbor set 和它本身,这就大大减小了 divergence。正式的说,通过构造 ,我们只保留源域与目标域有足够相似度的样本,这自然会缩短 到 的距离。我们还想强调该命题应该暗示非参数分类器可能能够从大型预训练源数据集中获得更多好处,我们在实验中验证了这一观点。最后,为什么是根据 显式衰减的量,我们可以考虑,随着源域样本数据变多 ,样本空间被铺满,那么对于任意一个目标域样本,他的邻居们都和他无限近,divergence 自然变成了 0。关于 和 影响的详细讨论见附录。当然这个 bound 是基于传统做法推出来的,因此有很多项阻碍了我们进一步研究非参数分类器的好处,比如 。我们在下一节中给出一个 bound,他只取决于我们模型参数的选择,而与假设空间本身没有关系,或者说这个关系被隐含在了假设中。3.2 AdaNPC 通过引入目标域无标签样本进一步减小目标域期望损失
在本节中,我们开发了基于 covariate-shift 和 posterior-shift 设置下的目标域 excess error 的上届,这进一步阐明了影响我们算法性能的所有因素和使用在线目标数据的好处。首先,所谓的 excess error,实际上就是给定分类器和贝叶斯最优分类器的 error 之差(二分类)。
有了这个定义我们就可以得到如下结论(下面的Proposition 2.),
1. bound受 的影响,类似于命题 1 的讨论。不同的是,只有设置 时,当 ,我们有高概率可以得到,covariate-shift 设置下的 excess error 会降为 。2. 命题 2 显示了对 选择的权衡。虽然一个小的 减少了 domain divergence 或表示相似性 ,但众所周知,小的 K 会导致 KNN 模型将变得过于具体,无法很好地泛化。3. 当回归函数不同时(可以看作条件分布 不同),bound 中引入了一个额外的项,即自适应差 ,用来度量两个回归函数的差。这种差距可以用现有的方法来估计和缩小。最后,本文提出的 AdaNPC 是一种特殊的 Test-Time 自适应方法,它可以利用在线目标样本来提高预测泛化。接下来,我们从理论上验证,通过将在线目标样本纳入 KNN 存储库,可以进一步减小 excess error。我们想要强调的是,这个错误边界比不更新内存库的情况更加 tight,也就是说,通过将新的 test instances 引入 memory,我们得到了更好的 error bound。
4.1 AdaNPC在域泛化,鲁棒性等benchmark上都取得SOTA
这里的 AdaNPC 即使用了 BN retraining 的策略。除此之外,我们发现,当 batch size 非常小时,现有方法(基于梯度更新或者其他参数更新方法)往往会产生负面影响,因为单个样本的梯度噪声非常高,这不利于模型优化。
然而,应该强调的是,批数据(batch of online data)不符合在线学习的设置,在线学习需要按需推理而不是等待一批一批的数据传入,或者当推理发生在边缘设备(如手机)上时,没有机会进行批处理。因此,AdaNPC 这种对批量大小不敏感的 TTA 方法对当前的研究领域具有重要价值。4.2 AdaNPC克服了灾难性遗忘,有很强的知识可扩展性
在本文中我们介绍一个新的 setting,即 Successive adaptation: 在域 0 上训练的模型将适应一系列的域,即通过运行 TTA 算法适应 ,我们会得到模型i,然后模型 i 将在域 i+1 上进行调整和评估,同时我们也评估了每个适应后的模型在源域的表现。以 Rotated MNIST 数据集为例,下图显示,最新的 TTA 方法,即 T3A 和 Tent,在进行测试时间适应的情况下,其性能略高于甚至低于 ERM 基线。相反,AdaNPC 记住了所有 adaptation 过程中的信息,因此取得了惊人的效果。也正是因为这个原因,AdaNPC 在源域的效果不会随着自适应的进行而变差,这是相对于现有算法的另一个突出优势。4.3 AdaNPC:不需要任何源域训练的域泛化
下图显示了直接在目标域上评估预训练模型的结果,而没有在源域进行任何微调,对于 AdaNPC,只是将源域特征进行了存储。在 PACS 数据集上(下图a)和 Rotated MNIST 数据集上(下图b)使用 MLP 分类器的平均泛化性能低于 ,即使使用强大的主干(ViT-L16)。相反,使用 KNN 分类器可以达到平均泛化精度 。如今,由于预训练模型的规模不断增长,微调通常在计算上是昂贵的。AdaNPC 要求的不是基于梯度的更新,而是外部大容量存储来存储用于图像分类的知识,例如图像特征图,这为利用预训练的知识提供了一个新的有前途的方向。此外,随着源域实例数量的增加,下图(c)表明 AdaNPC 获得了更好的性能,这验证了我们的理论结果。4.4 AdaNPC有较强的可解释性,允许引入专家信息
下图显示了 KNN 分类器如何使用源域的知识。决策过程将不再是一个黑匣子。例如,下图(b)中的长颈鹿被分类为低置信度,因为它最近的邻居是大多数具有相似姿势的人或狗。也就是说,encoder representation 忽略了一些重要特征,例如面部形状。然而,这些特征很容易被人类识别;因此,当我们得到低置信度预测时,AdaNPC 允许我们手动删除一些明显错误的邻居。在这种情况下,我们的分类结果将更加准确和自信,这对于高风险任务来说很有希望结合专家知识以获得更好的分类结果。
该论文提出了一种新的域泛化测试时间自适应方法,AdaNPC,它引入了一个非参数分类器,即 KNN 分类器,用于预测和自适应。与需要模型更新且容易忘记先前知识的当前领域泛化或测试时间自适应方法不同,所提出的方法是无参数的并且可以记住所有知识,使得AdaNPC适用于实际设置,特别是模型需要适应一系列域的时候。
我们推导出协变量偏移和后验概率偏移设置下的误差界限,其中 AdaNPC 理论上显示能够减少看不见域的目标误差。此外,AdaNPC 具有更快的收敛性、更好的可解释性和强大的知识可扩展性。更重要的是,AdaNPC 无需对源域进行任何微调即可实现高泛化精度,这为利用规模不断增长的预训练模型提供了一个有前途的方向。
一个可能的 concern 是 AdaNPC 需要进行 dense vector searching,以及需要存储大量的源域特征,对显存/内存有较高要求。我们在文中实验部分对这些 concern 进行了解答,目前的搜索速度和内存要求和现有算法至少是可以相比较或者更快的。但是我们在未来工作中也会考虑如何更加简化 memory 的构造,加速整个框架的推理时间。
如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。
总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。
PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析、科研心得或竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。
📝 稿件基本要求:
• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注
• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题
• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算
📬 投稿通道:
• 投稿邮箱:[email protected]
• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者
• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿
△长按添加PaperWeekly小编
🔍
现在,在「知乎」也能找到我们了
进入知乎首页搜索「PaperWeekly」
点击「关注」订阅我们的专栏吧