ToST:基于统计学的线性注意力机制,革新Transformer效率

ToST提出基于统计学的线性注意力机制,显著提升Transformer效率,在多个领域的任务中取得出色性能。

原文标题:首个基于统计学的线性注意力机制ToST,高分拿下ICLR Spotlight

原文作者:机器之心

冷月清谈:

加州大学伯克利分校等机构的研究者提出了一种名为Token Statistics Transformer (ToST) 的新型Transformer架构,其核心在于一种线性时间复杂度的注意力机制——Token Statistics Self-Attention (TSSA)。与传统Transformer的二次方复杂度相比,ToST通过对序列特征进行统计建模,避免了token两两相似度的计算,极大提高了效率。

ToST基于变分编码率缩减 (VRR) 框架,并通过对最大编码率缩减 (MCR²) 目标进行变分展开和梯度下降推导出TSSA。TSSA利用token特征的统计量构建注意力,而非传统的两两相似性计算。这使得ToST在计算和内存复杂度上都实现了线性扩展,尤其在长序列任务中优势显著。

实验结果表明,ToST在自然语言处理和计算机视觉等多个领域的任务中均取得了与传统Transformer相当的性能,同时显著降低了计算资源消耗。在ImageNet-1k等视觉数据集以及长序列任务基准测试中,ToST展现了出色的性能和效率。此外,ToST还增强了模型的可解释性,其注意力操作基于统计量的低秩投影,更易于理解和分析。

怜星夜思:

1、ToST的线性复杂度注意力机制如何应用于实际的工程项目中?有哪些需要注意的点?
2、相比于其他的线性注意力机制,例如Performer,Linear Transformer等,ToST的优势和劣势分别是什么?
3、ToST如何处理不同模态的数据,例如文本、图像、音频等?未来在多模态学习方面有哪些潜在的应用?

原文内容

图片

AIxiv专栏是机器之心发布学术、技术内容的栏目。过去数年,机器之心AIxiv专栏接收报道了2000多篇内容,覆盖全球各大高校与企业的顶级实验室,有效促进了学术交流与传播。如果您有优秀的工作想要分享,欢迎投稿或者联系报道。投稿邮箱:[email protected][email protected]


本文第一作者为加州大学伯克利分校三年级博士生吴梓阳,导师为马毅教授。吴的主要研究方向为表征学习与多模态学习。该工作由多所学校与机构的研究者共同完成,包括加州大学伯克利分校、宾夕法尼亚大学、密歇根大学、清华大学、忆生科技、香港大学、约翰·霍普金斯大学等。据悉,马毅教授已受邀在今年四月的ICLR大会上就和此项成果相关的一系列白盒神经网络相关工作,进行为时一小时的主题报告(Keynote)。


Transformer 架构在过去几年中通过注意力机制在多个领域(如计算机视觉、自然语言处理和长序列任务)中取得了非凡的成就。然而,其核心组件「自注意力机制」 的计算复杂度随输入 token 数量呈二次方增长,导致资源消耗巨大,难以扩展到更长的序列或更大的模型。


Token Statistics Transformer (ToST) 提出了一种新的注意力机制,它的时间复杂度是线性的。通过对序列特征的统计建模,ToST 提高了序列处理任务中的效率。文章探讨了基于变分编码率缩减(Variational Rate Reduction, VRR)的框架,并通过实验验证了其在不同任务中的性能,通过革新传统注意力机制,解决了这些长期困扰 Transformer 架构的效率瓶颈。


ToST 也作为 Spotlight 论文,入选了 ICLR 2025 大会。



  • 论文标题:Token Statistics Transformer: Linear-Time Attention via Variational Rate Reduction
  • 论文地址:https://arxiv.org/abs/2412.17810
  • 项目主页:https://robinwu218.github.io/ToST/
  • 目前该工作已开源:https://github.com/RobinWu218/ToST


研究背景与动机


一直以来,自注意力机制依赖于对输入 token 两两相似性的计算,这一过程虽然有效,但其资源开销显著;尤其当输入 token 数量极大时,传统注意力机制(如 Transformer 中的全局注意力)在计算复杂度和内存使用上的瓶颈问题愈发显著。


为了应对这一挑战,本文提出了一种基于统计学特征的注意力机制:Token Statistics Self-Attention (TSSA)。它通过避免两两相似性的计算,仅依赖于 token 特征的统计量,显著降低了计算复杂度。

 

Token Statistics Transformer (ToST) 的架构。Token Statistics Self-Attention (TSSA) 运算符通过对投影后的 token 进行行标量化变换,从而实现了线性复杂度。


核心方法


ToST 的核心方法是通过特定的概率分布函数对输入序列进行建模,减少冗余信息并提取关键特征。具体包括:


1. 统计特征提取:对序列中的每个 token 提取其统计特征。

2. 变分编码率缩减:利用 VRR 框架对特征进行压缩,减少信息冗余。

3. 线性复杂度实现:通过一系列优化,其计算复杂度从 O (n²) 降低为 O (n)。

 

ToST 的方法概述。在 CRATE 的理论基础上,ToST 通过几何空间的结构化特征实现 token 分组和映射。


网络架构的推导


该团队通过扩展先前的 CRATE 工作推导出网络架构。CRATE 显示,一种 Transformer 风格的架构可以通过 "白盒" 架构设计自然生成,其中网络的每一层都旨在实现最大编码率缩减目标 (MCR²) 的增量优化步骤。


具体来说,该团队推导了 MCR² 目标的一个新颖的变分形式,并表明通过对该变分目标进行展开梯度下降所得到的架构会引入一种新的注意力模块,称为 Token Statistics Self-Attention (TSSA)。TSSA 拥有线性的计算和内存复杂度,并从根本上不同于典型的注意力架构,其后者通过计算 token 之间的两两相似性来实现。


关键公式 MCR² 目标函数定义


技术细节


1. 线性时间注意力机制:Token Statistics Self-Attention (TSSA)


通过白盒设计方法(algorithmic unrolling),TSSA 从最大编码率减少(Maximal Coding Rate Reduction, MCR² )的变分形式中推导而来。


传统 Transformer 依赖于 pairwise 相似度计算,而 TSSA 则基于 token 特征的统计量构建注意力机制,其计算复杂度从 O (n²) 降低为 O (n),内存占用同样显著减少。


2. 创新性的网络结构:Token Statistics Transformer (ToST)


ToST 通过将 TSSA 替代标准的自注意力模块,不仅实现了显著的效率提升,还增强了模型的可解释性。


传统模型不同,ToST 架构中的注意力操作基于统计量的低秩投影,通过减少不必要的计算路径,大幅优化了资源使用。


3. 理论支撑与数学推导


基于 MCR² 的变分形式,提出了一种新颖的压缩项公式,可对大型矩阵进行有效的特征提取。


通过设计数据相关的低秩投影,TSSA 在保留关键信息的同时,消除了冗余方向。


实验验证与性能分析


实验覆盖了自然言语处理(NLP)、计算机视觉(CV)等多个领域的任务,包括文本分类、机器翻译、图像识别等。结果表明,ToST 在保证模型性能的同时,大幅降低了计算资源消耗。


1. 计算和内存的线性复杂度分析


实验结果显示,与现有的注意力机制相比,TSSA 的时间和内存复杂度更低。具体而言,TSSA 的复杂度为 O (pn),显著优于传统 Transformer 的 O (n²)。

ToST 在计算时间和内存使用上均随序列长度实现线性扩展,使其显著优于标准 Transformer 的效率。如下:


复杂度分析对比

 

在 GPU 上评估的速度和内存使用对比


2. 视觉任务性能分析


在 ImageNet-1k 等主流视觉数据集上的实验表明,ToST 的性能可与传统 Transformer 架构(如 ViT 和 XCiT)相媲美,同时显著减少了模型参数量和计算开销。


迁移学习实验中,ToST 在 CIFAR、Oxford Flowers 等数据集上的表现进一步验证了其在多种视觉任务中的适应性。


结果展示了与传统 Transformer 相当的性能,同时在计算效率上显著更高。

 

3. 长序列任务和语言建模


  • 长序列任务


在长序列任务基准测试(如 Long-Range Arena)中,ToST 展现出优异的长距离建模能力,其性能超越了现有 Transformer 变体。


  • 语言建模


ToST 可以扩展并适用于多种任务场景,包括因果语言建模。针对语言建模,ToST 采用了一种因果版本的 TSSA,在多个数据集上实现了高效的预测能力。此外,即使在参数规模扩大的情况下,ToST 依然保持了优异的时间和内存效率。


NLP 任务中的表现


4. 有原理支持的模型设计


由于 ToST 是通过展开从学习目标中推导出来的,我们可以以有原理支持的方式逐层分析学习到的模型行为。


ToST 模型不同层次的 TSSA 输出的变分压缩项


5. 学习表示的可解释性分析


ToST 通过统计量驱动的注意力机制,使每一层的注意力操作更加透明,便于解释和分析。其分组机制展现了 token 特征在低维空间中的聚类效果,直观反映了模型的决策过程。


ToST 在无需复杂的自监督训练的情况下,自然生成了可解释的注意力模式。


倒数第二个全局类注意力层中最后一个头部的 [CLS] token 注意力图的比较


在 TSSA 层中,可视化估计的隶属矩阵 Π 的每一行(经过重塑后)


可能对未来产生的影响


1. 大模型的高效化


随着语言模型、生成模型和多模态模型规模的持续扩展,计算效率成为核心瓶颈。ToST 展示的统计量驱动注意力机制,为实现线性复杂度的大模型提供了可能性。


2. 推动 Transformer 的普适化应用


高效的注意力机制使得 ToST 能够更广泛地应用于资源受限场景,如边缘计算、实时系统、嵌入式设备等。这为人工智能技术从中心化计算向分布式、边缘化方向的发展奠定了基础。


3. 多模态融合的可能性


ToST 的低复杂度机制为处理多模态长序列任务提供了新的技术框架,使未来多模态大模型在生成、分析和交互中的效率显著提升。


4. 促进跨学科应用


ToST 对数学理论与工程实现的有机结合,不仅在传统 AI 任务中表现突出,还可能推动其在新兴领域(如量子计算、生物信息学和材料设计)中的应用。


Token Statistics Transformer (ToST) 重塑了注意力机制,它不需要计算 token 之间的两两交互,而是基于投影后 token 特征的二阶矩统计量构建,其基于数据压缩和表示学习的理论原则目标,为 Transformer 的发展开辟了新路径。其基于统计特性的低复杂度设计,不仅优化了现有架构的性能,还为未来大模型的高效化、多模态融合和跨学科应用提供了启示。



© THE END 

转载请联系本公众号获得授权

投稿或寻求报道:[email protected]


Performer和Linear Transformer我记得主要用核方法来近似注意力矩阵,而ToST则是通过统计特征来构建注意力。相比之下,ToST的方法可能对输入数据的分布更加敏感,如果数据分布比较复杂,可能会影响ToST的性能。而Performer和Linear Transformer的泛化能力可能更好一些,对不同类型的数据集的适应性更强。

关于ToST处理多模态数据的问题,我觉得文章中提到的“统计特征”是很关键的一点。因为不同的模态数据的数据表示方式差异很大,例如文本数据通常用词向量表示,图像数据用像素值或特征图表示,音频数据用波形或频谱表示。ToST如果要处理多模态数据,就需要找到一种方法,能够将不同模态的数据转换成统一的统计特征表示,这样才能应用TSSA进行注意力计算。这是一个很有挑战性的问题。

从实际应用的角度来看,我觉得ToST的优势在于它的计算效率更高,内存占用更少。这对于处理长序列数据非常重要,可以有效地减少训练时间和计算成本。但是ToST的劣势在于它的理论相对比较复杂,理解和实现起来可能比Performer和Linear Transformer更难一些,对开发人员的要求更高。所以选择哪种方法,需要根据具体的应用场景和技术水平来决定。

ToST应用到工程项目中,我觉得除了硬件适配,还要考虑模型的规模和数据集的特点。ToST虽然复杂度低,但也需要根据实际任务的需求选择合适的模型大小。如果数据集规模很大,特征维度很高,那么即使是线性复杂度,也可能需要大量的计算资源。所以在实际应用中,需要根据数据集的特点,选择合适的模型规模和参数配置,并在训练过程中监控模型的性能和资源消耗,以便进行及时的调整。

关于ToST在线性复杂度注意力机制在实际工程项目中的应用,我感觉首先要考虑的是硬件的适配性。因为ToST的优势在于线性复杂度,所以理论上它对硬件的要求更低,可以部署在一些资源受限的设备上,例如移动端、嵌入式设备等等。但是实际操作中,需要针对具体的硬件平台进行优化,例如针对不同的GPU架构、CPU指令集进行代码的适配和优化,才能最大限度地发挥ToST的性能优势。

除了硬件和模型规模,我觉得软件生态也挺重要的。目前深度学习框架对ToST的支持程度如何?有没有相关的库和工具可以方便地使用?如果需要自己从头实现,开发成本和维护成本会有多高?这些都是需要在实际工程项目中考虑的问题。一个良好的软件生态可以大大降低开发和部署的难度,加快项目的进度。

从更长远的角度来看,我觉得ToST甚至可以应用于一些更复杂的跨模态任务,例如将文本描述转换成图像、将语音转换成动画等等。这些任务都需要模型能够理解不同模态数据之间的关联性,并进行跨模态的转换和生成。ToST的线性复杂度注意力机制,可以为这些任务提供高效的计算支持,从而推动多模态学习的发展。

如果ToST能够有效地处理多模态数据,那么在多模态学习方面会有很大的应用潜力。例如,可以将ToST应用于视频理解、图像 captioning、语音识别等等任务。通过对不同模态数据的联合建模,可以提高模型的理解和生成能力。比如在视频理解中,可以结合视频的图像信息和音频信息,更全面地理解视频内容。

说到ToST和其它线性注意力机制的比较,我觉得一个很重要的点是ToST的理论基础。ToST是基于变分编码率缩减(VRR)框架和最大编码率减少(MCR²)目标推导出来的,这使得ToST的设计更加有理有据,而不是像一些其他的线性注意力机制那样,仅仅是基于经验的改进。这种理论上的支撑,使得ToST的性能更加稳定,也更具可解释性。