提升AI泛化能力:阿姆斯特丹大学博士论文探索测试时学习

阿姆斯特丹大学博士论文聚焦AI泛化能力,提出多种测试时学习方法,提升模型在未知环境下的表现,并展望了测试时自适应的未来。

原文标题:【阿姆斯特丹博士论文】在测试时学习泛化

原文作者:数据派THU

冷月清谈:

这篇阿姆斯特丹大学的博士论文深入探讨了人工智能领域中至关重要的泛化能力问题,特别是“测试时泛化”。论文着重研究如何在训练阶段无法获取测试数据的情况下,提升模型在实际测试环境中的表现。论文主要从四个方面展开研究:**训练阶段的泛化模型学习**,利用贝叶斯神经网络中的不变性学习,训练更具泛化能力的模型;**测试阶段的泛化模型学习**,在无标签和额外测试信息的情况下,运用元学习和变分推断技术,使模型能够自适应每个测试样本;**测试阶段的泛化样本学习**, 通过能量模型将测试样本适配到训练分布,避免直接调整模型从而防止灾难性遗忘;**测试阶段的泛化提示学习**,针对多模态基础模型,设计新的prompt-learning框架,实现针对任意分布偏移的测试任务专属提示生成和在线提示更新的动态测试时调整。论文还系统回顾了测试时自适应(TTA)领域的研究进展,并对测试时泛化的未来进行了展望。

怜星夜思:

1、论文中提到的“灾难性遗忘”具体是指什么?又有哪些常见的解决策略呢?
2、论文提到了“提示学习(prompt-learning)”,这个概念在多模态基础模型中如何应用?
3、论文中提到的“测试时自适应(Test-Time Adaptation, TTA)”有哪些典型的应用场景?在实际应用中可能会遇到哪些挑战?

原文内容

来源:专知
本文约1000字,建议阅读5分钟
本论文聚焦于提升泛化能力这一关键问题,尤其是在测试时泛化,即在训练阶段无法访问测试数据的前提下提高模型在测试阶段的表现。

泛化能力,即将从已见上下文中学习到的知识有效应用于陌生情境的能力,是人类智能的重要特征,但对当前的人工智能系统而言仍是一项重大挑战。传统的机器学习算法通常依赖于训练数据与测试数据来自相同分布的假设,因此在面临分布偏移时,其性能往往显著下降。本论文聚焦于提升泛化能力这一关键问题,尤其是在测试时泛化,即在训练阶段无法访问测试数据的前提下提高模型在测试阶段的表现。

本论文的结构如下:


  1. 训练阶段的泛化模型学习:通过贝叶斯神经网络中的不变性学习实现更具泛化能力的模型训练;

  2. 测试阶段的泛化模型学习:在无标签和无额外测试信息的情况下,利用元学习和变分推断技术,使模型能对每个测试样本直接进行自适应;

  3. 测试阶段的泛化样本学习:采用能量模型将测试样本适配至训练分布,以避免调整模型本身,从而规避灾难性遗忘问题;

  4. 测试阶段的泛化提示学习(prompt-learning):面向多模态基础模型,设计新颖的提示学习框架,涵盖针对任意类型分布偏移的测试任务专属提示生成,以及用于在线提示更新的动态测试时提示调整方法。


每一章均提出了创新方法,详细介绍了方法论与实验结果,展示了在测试阶段提升泛化能力的全面路径。


最后,论文进一步探讨了测试时泛化的历史与未来,并系统回顾了测试时自适应(Test-Time Adaptation, TTA)领域的研究进展,为测试时泛化的发展提供了全面总结与未来展望。


https://hdl.handle.net/11245.1/a165fad4-684a-4767-9e55-1caa83e59f59





关于我们

数据派THU作为数据科学类公众号,背靠清华大学大数据研究中心,分享前沿数据科学与大数据技术创新研究动态、持续传播数据科学知识,努力建设数据人才聚集平台、打造中国大数据最强集团军。




新浪微博:@数据派THU

微信视频号:数据派THU

今日头条:数据派THU


灾难性遗忘 (Catastrophic Forgetting) 就是神经网络在学习新知识的时候,会把之前学到的东西给忘了。这在持续学习 (Continual Learning) 领域是个大问题。想象一下,你训练一个模型识别猫,然后再让它识别狗,结果它连猫都不认识了,这就是灾难性遗忘。

解决策略有很多,我稍微列几个我知道的:

* EWC (Elastic Weight Consolidation): 这个方法会找出对于旧任务比较重要的权重,然后用正则化的方式限制这些权重的改变。
* iCaRL (Incremental Classifier and Representation Learning): 这个方法会保存一些旧任务的样本,然后用新旧样本一起来训练,避免模型忘记旧知识。
* GEM (Gradient Episodic Memory): 这个方法会保存一些旧任务的梯度,然后用这些梯度来约束新任务的训练方向,避免模型偏离旧知识太远。

每个方法都有自己的优缺点,具体用哪个取决于你的应用场景。

提示学习是一种通过设计合适的“提示”(prompt),引导预训练模型完成特定任务的方法。在多模态基础模型中,提示学习可以用于将不同模态的信息(例如文本、图像)统一到同一个语义空间中,从而实现跨模态的理解和生成。举个例子,我们想要让模型根据一张图片生成描述文本,可以设计一个提示“这张图片描述的是[MASK]”,然后让模型根据图片内容填充[MASK]部分。

提示学习的关键在于如何设计有效的提示,这需要对预训练模型的特性和任务需求有深入的理解。例如,我们可以使用可学习的向量作为提示,通过训练来优化提示的内容和形式,从而更好地引导模型完成任务。

提示学习,我的理解就是给模型一点“暗示”,让它更容易理解你的意图。

在多模态模型里,这个“暗示”可以很灵活。比如,你想让模型看图说话,你可以给它一个提示:“图里有什么?”模型就会根据图片内容回答。

或者你想让模型根据文字生成图片,你可以给它一个提示:“画一只戴帽子的猫”,模型就会生成相应的图片。

总之,提示学习就是通过巧妙地设计提示,让模型更好地发挥它的能力。

测试时自适应是一种在模型部署后,根据测试数据动态调整模型参数的技术。它在实际应用中有很多场景,例如:

* 图像识别:当模型部署到新的环境中,例如光照条件变化、摄像头参数不同等,TTA可以帮助模型适应这些变化,提高识别精度。
* 自然语言处理:当模型处理不同领域的文本时,例如从新闻领域到医疗领域,TTA可以帮助模型适应新的领域知识和语言风格。
* 医疗诊断:在医疗领域,不同医院的设备参数、扫描协议可能存在差异,TTA可以帮助模型适应这些差异,提高诊断准确率。

在实际应用中,TTA可能会遇到以下挑战:

* 计算资源:TTA需要在测试时进行模型调整,这会消耗一定的计算资源。
* 模型稳定性:不合理的调整可能会导致模型性能下降,甚至崩溃。
* 隐私保护:TTA可能需要访问用户的隐私数据,如何保护用户隐私是一个重要的挑战。

TTA 的应用场景可多了,我觉得最有潜力的几个:

* 自动驾驶: 路况 постоянно меняющиеся, 环境也复杂, TTA 可以帮助模型快速适应新的场景,提高安全性。
* 智能监控: 摄像头的位置、光照条件 постоянно меняющиеся, TTA 可以帮助模型适应这些变化,提高监控效果。
* 推荐系统: 用户的兴趣 постоянно меняющиеся, TTA 可以帮助模型快速适应用户的变化,提高推荐精度。

挑战也是有的,主要是:

* 效率问题: TTA 需要实时调整模型,计算量比较大,需要高性能的硬件支持。
* 鲁棒性问题: 如果测试数据质量不好,可能会导致模型调整出错,反而降低性能。
* 泛化性问题: TTA 只能适应测试数据中的变化,对于没有见过的变化可能无能为力。

楼上两位大佬讲的都很专业!我来个通俗点的。

灾难性遗忘就像是金鱼的记忆… 刚学的东西,转身就忘了!模型也一样,学了新的数据,旧的知识就被覆盖了。

缓解方法也挺多的,我理解的“重放”就像是“温故而知新”,时不时拿旧知识出来复习一下,加深印象。

还有一种“打补丁”的策略,就是给模型打个补丁,专门用来记住旧知识,这样新知识来了也不会把旧知识冲掉。

“灾难性遗忘”是指模型在学习新任务时,会迅速忘记之前已经学会的知识。这就像你学会了开车,结果学了骑自行车后,把开车的技能忘得一干二净。常见的解决策略包括:

* 正则化方法:通过添加正则化项,限制模型参数的变化,从而保护旧知识。
* 重放(Replay):保存一部分旧数据,与新数据混合训练,让模型回忆旧知识。
* 参数隔离:为每个任务分配独立的参数空间,避免任务间的相互干扰。

这篇论文里提到的能量模型适配测试样本到训练分布,某种程度上也可以理解为一种避免灾难性遗忘的策略,因为它不需要直接修改模型参数。

Prompt-learning 在多模态模型里,就相当于给模型一个“引导语”,让它知道你想让它做什么。比如,你给模型一张猫的图片,然后给它一个 prompt: “这是一张猫的[MASK]照片”,模型就会自动填空,生成 “这是一张猫的可爱照片”。关键在于, prompt 的设计要能够激发模型内部已经学到的知识,让它能够更好地完成任务。

在多模态场景下, prompt 可以是文本、图像、音频等等。比如说,你想让模型根据一段音乐生成一段描述文字,你可以给它一个 prompt : “这段音乐听起来[MASK]”,模型就会根据音乐的节奏、旋律等信息来填空。

这种方法的优点在于,你可以通过修改 prompt 来控制模型的输出,而不需要重新训练模型。这对于快速适应不同的任务非常有用。

TTA 感觉很像“临阵磨枪”,就是模型在真正用的时候,再根据实际情况微调一下。

应用场景嘛,我觉得像推荐系统就挺适合的。用户的喜好一直在变,TTA 可以让推荐系统更快地适应用户的口味。

挑战的话,主要还是怕“磨错枪”。万一测试数据有问题,把模型带偏了,那就得不偿失了。