蚂蚁GCA:ICML 2025 论文揭示千倍长度泛化的长文本理解新方案

蚂蚁提出GCA注意力机制,实现千倍长度泛化,16M上下文精准理解。端到端学习检索相关片段,显著降低显存开销,为长文本处理带来新思路。

原文标题:ICML 2025 | 千倍长度泛化!蚂蚁新注意力机制GCA实现16M长上下文精准理解

原文作者:机器之心

冷月清谈:

蚂蚁技术团队提出了一种名为GCA(Grouped Cross Attention)的新型注意力机制,旨在解决长文本建模中的挑战。GCA通过端到端学习如何从上文检索并挑选最相关片段,实现了超长序列的高性能处理与泛化能力。该方法模仿人类记忆的特性,将不常用的上文信息卸载到CPU/磁盘,只在需要时动态加载,显著降低了显存开销。实验结果表明,整合GCA的模型在长文本数据集上表现出更优的perplexity,并实现了1000倍以上的长度泛化能力,在16K上下文预训练的模型可在16M长上下文密钥检索中实现100%准确率。GCA的核心优势在于其端到端的可训练性,通过两阶段注意力机制,使得每个chunk的检索分能参与到自回归预测中。模型架构结合了GCA和滑动窗口注意力,前者负责长程信息检索,后者负责整合短程信息。实验还表明,GCA学到的不仅是字面相似性,更包含了语义乃至逻辑相关性,例如在arXiv-math数据集上,模型会检索下文生成中可能会用到的引理及变量声明。

怜星夜思:

1、GCA通过卸载不常用的KV cache到CPU来降低显存占用,这种CPU-GPU数据交换会不会成为性能瓶颈?实际应用中如何平衡显存占用和推理速度?
2、GCA的核心在于可导的检索模块,可以端到端学习。那么,GCA的检索模块具体是如何实现的?它与传统的RAG方法中的检索模块有哪些关键区别,使得GCA能够实现更好的性能?
3、论文提到GCA在16K上下文预训练的模型可在16M长上下文密钥检索实现100%准确率,这个结果非常惊艳。但是,实际应用中,我们真的需要如此长的上下文吗?GCA主要适用于哪些具体的应用场景?

原文内容


该工作第一作者为蚂蚁技术研究院副研究员胡翔,蚂蚁技术研究院高级研究员武威为通讯作者。


在大语言模型如火如荼的当下,长文本建模仍然是一个极具挑战的问题。纠其根源,一方面在于主流 LLMs 的架构 Transformers 中平方复杂度及随序列长度线性增长的推理阶段显存开销;另一方面在于 full-attention 有限的外推能力,难以泛化到远超预训练阶段长度的输入。


而高效处理长上下文能力,除了简单的工业界降本增效的需求外,还涉及通用人工智能 (AGI) 的核心问题:具有永久记忆的智能体。如果将人类从出生开始接收到的信息视作长上下文,人类拥有记忆无非是访问这些上下文。因此记忆可以看作是超长上下文访问能力,而拥有与用户所有对话记忆的智能体,很可能为大语言模型公司构建数据护城河 (事实上,OpenAI 已经开放了类似能力)。


近日,蚂蚁的研究团队为这个问题带来了一个新思路。就像人类开卷考试只会挑和当前问题相关的关键页作为参考,语言模型也可以只关注与当前上下文相关的过去片段。以此为出发点,他们提出一种基于因果检索的注意力机制 GCA (Grouped Cross Attention),完全端到端地学习如何从上文检索并挑选最相关片段,从而实现超长序列高性能处理与泛化能力。人类记忆的另一个特性是大部分时候记忆处于沉睡状态,相关记忆片段只会在激活时进入意识。类似地,GCA 通过将上文信息卸载到 CPU / 磁盘,只在需要的时候动态加载需要的片段到 GPU 的方式,大幅降低了长文本处理的显存开销。


目前,GCA 的 Triton kernel 实现已全部开源,相关论文已被 ICML 2025 接收。


  • 论文标题:Efficient Length-Generalizable Attention via Causal Retrieval for Long-Context Language Modeling

  • 论文地址:https://arxiv.org/abs/2410.01651

  • GitHub 主页:https://github.com/ant-research/long-context-modeling


实验结果也令人振奋:整合 GCA 的模型不仅在长文本数据集上展现了更优的 perplexity,更展现了 1000 倍以上的长度泛化能力,在 16K 上下文预训练的模型可在 16M 长上下文密钥检索 (passkey retrieval) 实现 100% 准确率,并在更复杂的多跳检索任务持续展现了超强外推能力。此外长度泛化与检索能力效果拔群,基于 GCA 的模型训练开销随序列长度几乎呈线性关系,并且推理的显存开销接近常数,同时基本持平 Transformers 推理速度。


值得一提的是,本工作 24 年 10 月在 arXiv 公开后,国产之光 DeepSeek 在 25 年初公开了 NSA,两者思路都是通过挑选过去 chunk 并 attention 的方式实现性能优化。但各有侧重,GCA 核心亮点在于超长的长度泛化,NSA 通过巧妙的 kernel 设计实现了逐 token 的稀疏 attention。受 NSA 的启发,GCA 的后继工作 HSA (https://arxiv.org/abs/2504.16795) 结合了两者的优点进行了融合。


长文本处理难点及现有方案的局限性


近年来,有不少工作讨论 Transformers (TRMs) 架构如何高效处理长文本。因为基于全量上文 attention 的 TRMs 有一个很显著的局限:输入长度超过预训练长度一定程度后,perplexity 会飙升,无法生成正常文本。如果只是解决正常生成的问题,一个最简单的思路是滑动窗口注意力,即每个 token 仅关注最邻近的 N 个 token 即可。这种方式可以保证 LLMs 持续生成,但它牺牲了长程信息获取能力。


另一种思路是认为 attention 窗口扩大到预训练长度范围之外后会导致原本的 attention 权重分布发生变化,因此通过调整 softmax 温度的方式进行长度泛化。但这类方法经实验验证往往泛化的倍率也有限。


因此,attention 长度泛化的难点在于处理超长序列的同时,能够真正有效利用上文中的信息。


GCA: 基于端到端因果检索的注意力机制


现有一些工作通过检索增强 (RAG) 的思路来进行长文本建模,其基本思路是将文本分段,譬如每 64 个 token 为一个 chunk;每生成一个 chunk 后,模型根据当前上文信息检索历史 chunk 来辅助下一个 chunk 的生成。理想情况下,只要能检索到对下文生成最有帮助的 chunk,再通过 cross-attention 机制从相关 chunk 收集信息即可。但通常检索模块是单独训练的,只能检索到相似内容,无法保证挑选对下文生成最有帮助的 chunk。


和已有工作相比,GCA 的一个显著优势是能够与自回归语言模型联合预训练,从而实现端到端学习。


上图对比了 GCA 与传统检索方式的运作区别。传统方式中 (a), 检索模块检索并返回相关 chunk,但检索分只用于挑选 chunk 完全不参与 forward 运算,因此无法获得梯度,无法学习。GCA 的核心创新在于通过一种两阶段的注意力机制,使得每个 chunk 的检索分能参与到自回归预测中,如图中(b)所示。


1. 分组注意力机制


不同于 (a) 中直接将 chunk 拼接在一起进行 attention, GCA 分别对每个 chunk 进行 attention (分组 attention),从各个 chunk 收集 token 粒度的信息并整合,作为每个 chunk 整体的信息。


2. Chunk-level 信息融合


GCA 将每个 chunk 的检索相关分通过 softmax 得到一个概率分布,将其作为权重对第一步所有 chunk 的表征进行加权求和,融合所有 chunk 信息用于下一个 token 预测。在反向传播过程中,更有助于预测下文的 chunk 将被分配更大的权重,从而实现检索模块的端到端学习。


模型整体架构是通过 GCA 与 sliding window attention 结合实现长上下文建模;前者负责长程信息检索,后者负责整合短程信息。为了进一步提升 GCA 性能,降低显存开销,研究团队将整个 GCA 封装成由 Triton 实现的 kernel,方便未来工作可以直接复用。


实验结果


在语言模型,长程检索等任务上的实验表明:


1. 基于 GCA 的 128M 的模型在大海捞针任务即可超越大部分主流 7B 模型,达成 1000 倍外推,实现 16M 上下文的完美大海捞针


在该实验中,所有模型都仅在不超过 16K 的上下文进行预训练,baseline 囊括了包含 sliding window attention 等主流注意力机制。基于 GCA 的模型无论在简单大海捞针,还是更复杂的变量追踪任务,都保持了稳定的外推能力。


注意到几乎所有 baseline 在上下文长度超过 64K 后几乎都归零,这些不同模型存在不同原因。划窗注意力因为只能看最邻近的 token,无法实现长程信息获取;基于循环结构的由于所有上下文信息都被压缩在一个固定维度的表征,必然存在信息损失的问题;基于单独训练检索器的模型 (RPTContriever) 的结果进一步验证了检索模型未必能检索到对下文有帮助的上文。


这一结果经验性地为可长度泛化的注意力机制提供了一个成功的概念原型。同时证明可泛化的长程信息获取能力取决于注意力机制原理上的改进,与参数量的提升无关。


在摘要及 RULER 榜单的效果


2. 预训练高效,推理时显存开销接近常数:GCA 是一种 sparse attention,其 attention 的视野域保持常数,因此在 batch size 一定的情况下,训练开销几乎与序列长度呈线性。由于 GCA 在生成阶段将所有上文的 KV cache 都卸载到 CPU,每次检索的时候才把相关 chunk 的 kv cache 载入 GPU,因此超长上文也不会有 KV cache 显存爆炸的问题。而 GPU-CPU 的交换控制在每 64 个 token 一次,因此对推理速度影响非常小,从而实现接近常数的显存开销,但仍保持高效的推理速度及长程信息获取能力。


训练时间及 ppl 随序列长度的变化


推理速度与显存开销相比基线 (基于划窗注意力的 Transformers) 的倍率关系(越低越好)


相同条件不同模型各个参数规模下的训练吞吐量,相比划窗注意力有额外 20% 的开销,但带来超长程信息获取的能力


3. 在 arXiv-math 上的数据分析发现,通过 GCA,语言模型会根据当前上下文,检索下文生成中可能会用到的引理及变量声明。这说明 GCA 学到的不仅仅是字面相似性,更包含了语义乃至逻辑相关性。


黑体是当前 chunk,红色,蓝色,黄色,分别代表 top3 相关 chunk、


结语


本工作提出一种可以长度泛化的稀疏注意力机制 GCA, 其核心在于可导的检索模块,可以有效处理 1000 倍于预训练长度的文本,首次实现在 16M 长度完美的大海捞针。虽然当前实验的模型规模较小,但期望该工作可以为机器如何实现永久记忆提供新的研究思路。


© THE END 

转载请联系本公众号获得授权

投稿或寻求报道:[email protected]

16M上下文的100%准确率确实让人眼前一亮!虽然日常对话可能用不到这么长的上下文,但在某些特定领域,超长上下文处理能力至关重要。

我认为GCA在以下场景具有很大的应用潜力:

1. 法律/金融文档分析:分析大量法规、合同、财务报表,从中提取关键信息,进行风险评估和合规检查。
2. 医学报告解读:整合患者的病历、影像资料、基因组数据等,辅助医生进行诊断和制定治疗方案。
3. 科学研究:分析大量的实验数据、文献资料,发现新的科学规律。
4. 软件开发:理解大型代码库,进行代码重构、bug修复和安全漏洞检测。
5. 历史研究:分析大量的历史文献、档案资料,还原历史事件的真相。

这些场景都需要模型能够理解和处理海量的信息,并从中提取出关键的线索。GCA的超长上下文处理能力,能够帮助人们更好地应对这些挑战。

我理解GCA的这个设计就像是给LLM加了个“外接硬盘”,虽然容量大了,但读写速度肯定不如直接用GPU上的显存。不过,好在GCA的精髓在于“只读需要的”,所以只要检索算法足够精准,每次交换的数据量并不大,对整体性能的影响应该能控制在一个可以接受的范围内。

当然,如果CPU和GPU之间的带宽成了瓶颈,那就得考虑升级硬件了,或者看看有没有更高级的压缩算法,能在不损失太多信息的前提下,减少数据传输量。说不定以后会有专门的“LLM外接硬盘”接口呢,专门优化CPU-GPU之间的通信!

我来泼一盆冷水。 16M上下文听起来很美好,但实际应用中,信息的质量往往参差不齐。如果模型处理了大量的噪音数据,反而会影响性能。

所以,我认为在追求长上下文的同时,更重要的是提高数据的质量和信噪比。可以考虑使用一些数据清洗和过滤技术,去除冗余和无关的信息。

另外,GCA的检索模块也需要进一步优化,提高检索的准确性和效率。如果检索到了错误的信息,长上下文反而会成为累赘。

这个问题问到了点子上!CPU-GPU数据交换确实是潜在的瓶颈。个人觉得,GCA这种方案更适合对延迟不敏感,但对成本和显存有极致要求的场景,比如离线文档处理、长文本分析等。对于在线交互式应用,可能还需要更激进的优化手段。

补充一点,除了你提到的策略,还可以考虑用NUMA架构的服务器,让CPU访问GPU附近的内存,减少跨节点的数据传输延迟。这属于硬件层面的优化了。

这是一个很关键的问题!虽然GCA通过CPU-GPU卸载KV cache降低了显存占用,但频繁的数据交换确实可能带来性能损耗。我认为可以从以下几个方面进行优化:

1. 更智能的卸载策略:目前是每64个token交换一次,可以尝试动态调整这个频率,比如根据上下文信息熵或attention score,预测哪些chunk更可能被用到,并优先保留在GPU中。
2. 异步数据传输:利用CUDA的异步传输能力,在GPU计算的同时进行CPU-GPU的数据传输,隐藏一部分传输延迟。
3. 混合精度训练与推理:使用FP16甚至INT8来表示KV cache,进一步降低显存占用和传输带宽需求。

总的来说,需要在实际应用中根据具体场景和硬件配置,找到一个显存占用和推理速度之间的最佳平衡点。

GCA的检索模块是其能够实现端到端学习的关键。与传统RAG方法相比,GCA的检索模块并非独立训练的,而是与自回归语言模型联合预训练。这意味着检索模块的优化目标直接指向语言模型的生成效果,而非仅仅是检索相似的内容。

具体来说,GCA通过两阶段的注意力机制,使得每个chunk的检索得分能够参与到自回归预测中。在反向传播过程中,更有助于预测下文的chunk将被分配更大的权重,从而实现检索模块的端到端学习。这种端到端的学习方式使得GCA能够学习到字面相似性之外的语义甚至逻辑相关性,从而挑选出对下文生成最有帮助的chunk。

而传统的RAG方法中,检索模块通常是单独训练的,只能检索到相似内容,无法保证挑选对下文生成最有帮助的chunk。这种分离式的训练方式导致检索模块的优化目标与语言模型的生成效果存在gap,限制了整体性能的提升。

谢邀,这个问题我来试着答一下。GCA的检索模块之所以work,是因为它把“检索”这个动作也变成了模型的一部分,可以一起训练。你可以把它想象成一个“智能推荐系统”,这个系统不仅知道你搜了啥,还知道这些东西对你完成任务有没有帮助。

传统的RAG就像是“关键词搜索”,只能找到字面相似的东西,但GCA能理解你的意图,找到真正有用的信息。这种区别就像是“精准导航”和“瞎猫碰死耗子”,效果当然不一样啦!

当然,这种端到端的方式也带来了更高的训练成本,需要更多的数据和更强的计算能力,但为了更好的性能,我认为是值得的。

我觉得这个问题问得好!有时候我们可能会陷入“技术炫技”的误区,追求极致的指标,而忽略了实际需求。16M上下文固然厉害,但关键是whether we really need it.

除了楼上提到的专业领域,我认为GCA在以下场景也可能有用武之地:

* 游戏AI:为游戏角色提供更丰富的背景故事和记忆,使其行为更加真实和可信。
* 虚拟助手:记住用户所有的对话历史,提供更个性化的服务。
* 教育:为学生提供个性化的学习内容和辅导,根据学生的学习历史和知识掌握程度进行调整。

但是,这些场景对成本和延迟也有很高的要求。所以,GCA需要在这些方面进行进一步的优化,才能真正落地应用。

我从GCA论文里找到了更详细的描述:GCA的检索模块实际上是通过一个chunk-level的softmax来实现的。每个chunk经过分组注意力后,会得到一个表征向量,然后通过softmax计算每个chunk的“相关性得分”。这个得分会被用来加权融合所有chunk的信息,用于下一个token的预测。

关键在于,这个softmax的输出是参与到forward过程的,因此可以获得梯度,从而实现端到端学习。而传统的RAG方法,检索模块的输出只是作为输入,不参与梯度计算,无法实现端到端优化。

所以,GCA的创新之处在于把“检索”这个离散的操作,转化成了一个可微的过程,从而可以利用深度学习的强大能力进行优化。