现代深度学习不确定性建模:从Masksembles到ZigZag

这篇论文介绍了深度学习中不确定性建模的新方法,包括Masksembles、ZigZag和IT3框架,旨在提高模型在现实场景中的可靠性和鲁棒性。

原文标题:【EPFL博士论文】现代深度学习中的不确定性建模

原文作者:数据派THU

冷月清谈:

本文总结了一篇关于现代深度学习中不确定性建模的博士论文。该论文着重解决深度学习模型在评估预测可信度方面的不足,以及由此带来的高风险问题。当前深度学习方法面临计算复杂性的挑战,尤其是在训练和推理阶段。为了降低训练成本,研究者提出了 Masksembles 方法,该方法通过训练单个模型在推理时实现集成效果,提高了计算效率并在 MC-Dropout 与深度集成方法之间实现了无缝插值。此外,论文还引入了神经网络在不确定性估计中的幂等性属性,提出了 ZigZag 方法,该方法通过训练神经网络在有无附加预测信息的情况下输出一致的结果,并以其差异度量不确定性,实现了业界领先的不确定性估计效果。在此基础上,进一步提出了幂等测试时训练(IT3)框架,旨在应对分布偏移问题。最后,论文还提出了一种针对迭代结构的不确定性估计方法,通过分析连续输出的收敛速率来量化不确定性。该方法在贝叶斯优化和分布外检测等任务中表现出色。

怜星夜思:

1、深度集成(Deep Ensembling)虽然效果好,但是计算负担大,除了文章中提到的Masksembles方法,还有没有其他降低深度集成计算成本的有效方法?
2、文章中提到的幂等性在不确定性估计中起到了什么作用?为什么训练神经网络在有无附加预测信息的情况下输出一致的结果可以度量不确定性?
3、文章提出的IT3框架,如何利用ZigZag提供的不确定性得分作为测试阶段的训练损失来提升模型性能?测试时训练(Test-Time Training)在实际应用中会遇到什么问题?

原文内容

来源:专知
本文约1000字,建议阅读5分钟
我们引入神经网络确定估计中的属性基于提出一种采样方法 ZigZag [DDLF24]。



论文中,我们聚焦现代深度学习中的一个基本挑战:确定估计。尽管深度神经网络多个关键领域取得显著成功——机器技术、大型语言模型先进信息检索系统——它们评估预测可信度方面能力仍然有限。随着这些系统日益应用风险现实场景中,缺口带来重大挑战。随着机器学习依赖不断增强,能够适应确定性、具有可靠性模型需求增长。尽管确定估计重要性日益凸显,深度学习中的实际应用面临挑战,包括扩展性、效率以及适应性。

我们首先着重解决当前深度学习方法中的一个核心问题:训练推理过程中的计算复杂性。目前深度学习受欢迎、有效确定估计方法之一——深度集成(Deep Ensembling)[LPB17]——训练推理阶段存在显著计算负担,使很多应用变得不切实际。为了解决训练阶段复杂问题,我们提出了 Masksembles方法,方法训练一个模型,却能推理实现集成效果。策略显著降低训练成本,同时保持确定估计质量。Masksembles 提高计算效率,在 MC-Dropout [GG16] 深度集成方法之间实现无缝值,融合两者优势。我们一个合成人群计数实验验证方法有效性,场景中,训练合成数据模型常常难以适应真实图像转移问题。通过使用 Masksembles,我们一个结合合成图像真实图像训练流程,基于确定引导标签方法 [LDF22] 实现强健适应能力,保持推理开销同时,超越当前先进方法。

此外,我们引入神经网络确定估计中的属性基于提出一种采样方法 ZigZag [DDLF24],方法具有效率高、通用特点,实现业界领先确定估计效果。ZigZag 通过训练神经网络附加预测信息情况输出一致结果,以其差异度量确定性。方法性能可与深度集成方法媲美,计算效率显著更高。在此基础上,我们进一步提出了 测试训练(Idempotent Test-Time Training, IT3) [DSO+24] 框架,一个领域无关方法,应对分布偏移问题。IT3 利用 ZigZag 提供确定得分作为测试阶段训练损失,推理过程中将模型表示训练分布齐,从而提升性能。框架适用多种任务,无缝集成任何模型架构中,包括 MLP、CNN 和 GNN,一点当前测试训练方法具备的。

最后,我们提出一种针对迭代结构确定估计方法 [DOL+24],通过分析连续输出速率量化确定性。方法实现当前领先估计质量,能够有效支持化,训练分布之外空间进行高效探索(例如空气动力形状化),同时图像中的道路检测任务实现高效分布检测。

关键确定估计,概率模,异常性,分布化,主动学习,





关于我们

数据派THU作为数据科学类公众号,背靠清华大学大数据研究中心,分享前沿数据科学与大数据技术创新研究动态、持续传播数据科学知识,努力建设数据人才聚集平台、打造中国大数据最强集团军。




新浪微博:@数据派THU

微信视频号:数据派THU

今日头条:数据派THU


从工程角度来说,可以考虑用分布式计算框架来并行化深度集成的训练和推理过程。把任务拆解成小块,分发到不同的机器上同时处理,最后再汇总结果。比如用TensorFlow的分布式策略或者PyTorch的torch.distributed,都能实现并行计算。当然,这需要一定的集群资源和配置能力。

我最近在研究知识蒸馏,感觉它和Masksembles思路有点像,都是想用更少的资源达到类似集成的效果。另外,还可以考虑使用一些更高效的硬件加速,比如GPU或者TPU,也能在一定程度上缓解计算压力。从算法层面,也可以尝试一些更轻量级的集成策略,比如随机森林或者梯度提升树,虽然可能精度上不如深度集成,但是计算效率会高很多。

深度集成的计算成本确实是个问题。除了Masksembles这种单模型模拟集成的方法,还可以考虑模型蒸馏。用一个更大的集成模型“教”一个小模型,让小模型也能达到接近集成的效果,但推理速度更快。另外,像一些剪枝或量化的模型压缩技术,也能在一定程度上降低集成模型中单个模型的计算量,从而变相降低整体成本。不过,具体效果还得看任务和数据。

幂等性在这里的关键我认为是提供了一个“参照系”。如果模型在有额外信息和没额外信息的情况下,输出的结果应该是一致(或者非常接近)的,那么就可以认为模型对当前的预测是“自信”的。如果输出结果差异很大,就说明模型对这个预测没有把握,即存在不确定性。这个差异的大小,就反映了不确定性的程度。

我认为测试时训练最大的问题在于数据污染和过拟合。如果在测试阶段引入了错误的标签或者噪声数据,模型很容易被误导,导致性能下降。另外,如果模型过度关注测试集中的特定样本,可能会丧失对未知数据的泛化能力。因此,在实际应用中,需要谨慎控制测试时训练的学习率和迭代次数,并采取一些正则化措施来防止过拟合。

从工程角度来看,测试时训练的部署也是一个挑战。它需要在推理过程中持续进行模型更新,对系统的实时性和稳定性提出了更高的要求。比如,需要考虑如何高效地管理和存储测试数据,如何保证模型更新的原子性和一致性,以及如何在资源有限的边缘设备上实现测试时训练。这些都需要精心的设计和优化。

IT3框架挺有意思的,它利用ZigZag给出的不确定性得分,将测试阶段的数据也加入到训练循环中,相当于让模型在“实战”中不断调整自己。如果ZigZag认为某个测试样本的不确定性很高,IT3就会加大对这个样本的训练力度,让模型努力去适应它,从而提高整体的泛化能力。这有点像“哪里不会学哪里”的感觉。

幂等性其实是假设了一种“理想状态”,即模型已经完全掌握了数据中的固有规律,不论你给它额外的信息,它都应该能做出一样的判断。 但现实是模型总有缺陷,对某些信息不够敏感或者过度敏感。当给它额外信息时,如果模型没能保持输出一致,就表明模型对这部分数据的理解还不够到位,不确定性就体现在这种不一致性上。我觉得这有点像控制变量法,通过引入变量来观察模型是否稳定。

从数学角度看,幂等性可以理解为一种约束条件。通过施加这个约束,我们可以迫使模型学习到更加鲁棒的特征表示。如果模型在受到扰动(即引入额外信息)的情况下,输出仍然能够保持不变,说明模型已经学到了数据中的本质信息,而不是过度依赖于输入中的噪声。这种鲁棒性就对应于较低的不确定性。