MIT研究:Transformer解决经验贝叶斯问题,速度提升100倍

MIT团队利用Transformer解决经验贝叶斯问题,速度比传统方法提升百倍,性能更优,展现良好泛化能力。

原文标题:MIT三人团队:用Transformer解决经验贝叶斯问题,比经典方法快100倍

原文作者:机器之心

冷月清谈:

MIT 的研究人员利用 Transformer 成功解决了经验贝叶斯(EB)均值估计问题,速度比经典的非参数最大似然 (NPMLE) 方法快了近100倍。他们发现 Transformer 能够有效地学习可交换数据,并自然地表现出收缩效应,这对于 EB 估计至关重要。
研究人员使用了泊松-EB 任务进行实验,该任务涉及从未知先验分布中采样,然后根据该先验生成泊松分布的样本。目标是根据观察到的样本估计原始的先验值。
实验结果表明,即使参数规模很小的 Transformer 也能在这个任务上表现出色,并且在合成数据和真实数据集上都优于 NPMLE。此外,Transformer 还展现了良好的长度泛化能力,即使测试序列长度是训练长度的 4 倍,也能保持较低的后悔值。
研究人员还通过理论分析和线性探针技术研究了 Transformer 的工作机制,发现它并非简单地学习 Robbins 估计器或 NPMLE 的特征,而是学习了贝叶斯估计器的本质。

怜星夜思:

1、文章提到Transformer具有长度泛化能力,那么在其他类型的统计问题中,这种泛化能力是否仍然有效?
2、Transformer 在 EB 问题上的成功是否意味着它可以替代传统的统计方法?
3、文章中提到的“收缩效应”是什么?它对 EB 问题的求解有什么影响?

原文内容

机器之心报道
机器之心编辑部

Transformer 很成功,更一般而言,我们甚至可以将(仅编码器)Transformer 视为学习可交换数据的通用引擎。由于大多数经典的统计学任务都是基于独立同分布(iid)采用假设构建的,因此很自然可以尝试将 Transformer 用于它们。

针对经典统计问题训练 Transformer 的好处有两个:

  • 可以得到更好的估计器;

  • 可以在一个有比 NLP 更加容易和更好理解的统计结构的领域中阐释 Transformer 的工作机制。


近日,MIT 的三位研究者 Anzo Teh、Mark Jabbour 和 Yury Polyanskiy 宣称找到了一个可以满足这种需求 「可能存在的最简单的这类统计任务」,即 empirical Bayes (EB) mean estimation(经验贝叶斯均值估计)。


  • 论文标题:Solving Empirical Bayes via Transformers

  • 论文地址:https://arxiv.org/pdf/2502.09844


该团队表示:「我们认为 Transformer 适用于 EB,因为 EB 估计器会自然表现出收缩效应(即让均值估计偏向先验的最近模式),而 Transformer 也是如此,注意力机制会倾向于关注聚类 token。」对注意力机制的相关研究可参阅论文《The emergence of clusters in self-attention dynamics》。

此外,该团队还发现,EB 均值估计问题具有置换不变性,无需位置编码。

另一方面,人们非常需要这一问题的估计器,但麻烦的是最好的经典估计器(非参数最大似然 / NPMLE)也存在收敛速度缓慢的问题。

MIT 这个三人团队的研究表明 Transformer 不仅性能表现胜过 NPMLE,同时还能以其近 100 倍的速度运行!

总之,本文证明了即使对于经典的统计问题,Transformer 也提供了一种优秀的替代方案(在运行时间和性能方面)。对于简单的 1D 泊松 - EB 任务,本文还发现,即使是参数规模非常小的 Transformer(< 10 万参数)也能表现出色。

定义 EB 任务

泊松 - EB 任务:通过一个两步式过程以独立同分布(iid)方式生成 n 个样本 X_1, . . . , X_n.

第一步,从某个位于实数域 ℝ 的未知先验 π 采样 θ_1, . . . , θ_n。这里的 π 的作用是作为一个未曾见过的(非参数)隐变量,并且对其不做任何假设(设置没有连续性和平滑性假设)。

第二步,给定 θ_i,通过 X_i ∼ Poi (θ_i) 以 iid 方式有条件地对 X_i 进行采样。

这里的目标是根据看到的 X_1, . . . , X_n,通过image.png估计 θ_1, . . . , θ_n,以最小化期望的均方误差(MSE)image.png如果 π 是已知的,则这个最小化该 MSE 的贝叶斯估计器便是 θ 的后验均值,其形式如下:


其中 图片是 x 的后验密度。由于 π 是未知的,于是估计器 π 只能近似 图片这里该团队的做法是将估计器的质量量化为后悔值,定义成了图片多于图片的 MSE:


通过 Transformer 求解泊松 - EB

简单来说,该团队求解泊松 - EB 的方式如下:首先,生成合成数据并使用这些数据训练 Transformer;然后,冻结它们的权重并提供要估计的新数据。

该团队表示,这应该是首个使用神经网络模型来估计经验贝叶斯的研究工作。

理解 Transformer 是如何工作的

论文第四章试图解释 Transformer 是如何工作的,并从两个角度来实现这一目标。首先,他们建立了关于 Transformer 在解决经验贝叶斯任务中的表达能力的理论结果。其次,他们使用线性探针来研究 Transformer 的预测机制。

本文从 clipped Robbins 估计器开始,其定义如下:
 
得出:transformer 可以学习到任意精度的 clipped Robbins 估计器。即:

image.png

类似地,本文证明了 transformer 还可以近似 NPMLE。即:


完整的证明过程在附录 B 中,论文正文只提供了一个大致的概述。

接下来,研究者探讨了 Transformer 模型是如何学习的。他们通过线性探针(linear probe)技术来研究 Transformer 学习机制。

这项研究的目的是要了解 Transformer 模型是否像 Robbins 估计或 NPMLE 那样工作。图 1 中的结果显示,Transformer 模型不仅仅是学习这些特征,而是在学习贝叶斯估计器图片是什么。


总结而言,本章证明了 Transformer 可以近似 Robbins 估计器和 NPMLE(非参数最大似然估计器)。

此外,本文还使用线性探针(linear probes)来证明,经过预训练的 Transformer 的工作方式与上述两种估计器不同。

合成数据实验与真实数据实验

表 1 为模型参数设置,本文选取了两个模型,并根据层数将它们命名为 T18 和 T24,两个模型都大约有 25.6k 个参数。此外,本文还定义了 T18r 和 T24r 两个模型。


在这个实验中,本文评估了 Transformer 适应不同序列长度的能力。图 2 报告了 4096 个先验的平均后悔值。


图 6 显示 transformer 的运行时间与 ERM 的运行时间相当。


合成实验的一个重要意义在于,Transformer 展示了长度泛化能力:即使在未见过的先验分布上,当测试序列长度达到训练长度的 4 倍时,它们仍能实现更低的后悔值。这一点尤为重要,因为多项研究表明 Transformer 在长度泛化方面的表现参差不齐 [ZAC+24, WJW+24, KPNR+24, AWA+22]。

最后,本文还在真实数据集上对这些 Transformer 模型进行了评估,以完成类似的预测任务,结果表明它们通常优于经典基线方法,并且在速度方面大幅领先。


从表 3 可以看出,在大多数数据集中,Transformer 比传统方法有显著的改进。


总之,本文证明了 Transformer 能够通过上下文学习(in-context learning)掌握 EB - 泊松问题。实验过程中,作者展示了随着序列长度的增加,Transformer 能够实现后悔值的下降。在真实数据集上,本文证明了这些预训练的 Transformer 在大多数情况下能够超越经典基线方法。

© THE END 
转载请联系本公众号获得授权
投稿或寻求报道:[email protected]


我认为可以从理论上分析Transformer的泛化能力边界,就像文章中分析其表达能力一样。这样可以更有针对性地判断其在不同问题中的适用性。

我觉得长期来看,Transformer这类基于深度学习的方法很有可能取代一部分传统统计方法,尤其是在处理大规模数据和复杂问题时。

不好说,泛化能力这东西很玄学,得看具体情况。说不定换个问题,Transformer就水土不服了。

这得看实际应用场景,如果对可解释性的要求很高,那么传统方法依然是首选。如果更注重效率和性能,Transformer则更具优势。

关于“收缩效应”,可以参考James-Stein estimator,这个估计器证明了即使对独立的正态分布均值进行估计,将估计值向均值的均值收缩也能改进精度。

“收缩效应”是指将估计值向先验分布的中心靠拢的趋势。在EB问题中,由于先验分布未知,收缩效应可以帮助我们更好地利用数据信息,提高估计精度。

可以理解为是一种正则化,防止过拟合。

我觉得这个问题的关键在于其他统计问题的数据结构是否和EB问题类似。如果数据的可交换性在其他问题中也成立,那么Transformer的长度泛化能力很可能依然有效。

替代不至于,传统方法自有其优势,比如理论基础更完善,可解释性更强。Transformer可以作为一种补充,在某些特定问题上提供更高效的解决方案。