字节Seed团队提出PHD-Transformer:突破预训练长度扩展,解决KV缓存膨胀难题

字节Seed团队提出PHD-Transformer,通过创新KV缓存管理实现预训练长度扩展,解决KV缓存膨胀难题,显著提升模型性能和推理速度。

原文标题:字节Seed团队PHD-Transformer突破预训练长度扩展!破解KV缓存膨胀难题

原文作者:机器之心

冷月清谈:

本文介绍了字节跳动Seed团队提出的PHD-Transformer,一种在预训练阶段实现长度扩展的新方法。该方法通过重复输入token,并创新性地管理KV缓存,有效解决了传统方法中KV缓存膨胀导致的推理速度慢和内存占用高等问题。PHD-Transformer及其变体PHD-SWA和PHD-CSWA,在保持甚至降低KV缓存大小的同时,实现了显著的性能提升和推理加速。实验结果表明,该方法在多个基准测试中均取得了优异的成绩,为预训练长度扩展提供了一个高效可行的解决方案。

怜星夜思:

1、PHD-Transformer通过重复输入token实现长度扩展,这种方式会不会导致模型过拟合,尤其是在重复次数较多时?
2、论文中提到PHD-Transformer的核心在于原始token和隐藏解码token的解耦,那么这种解耦具体是如何影响模型的性能和效率的?
3、PHD-CSWA通过分块处理滑动窗口注意力来减少预填充时间,那么这个块大小的选择会对模型的性能有什么影响?

原文内容

机器之心报道

编辑:杜伟


最近,DeepSeek-R1 和 OpenAI o1/03 等推理大模型在后训练阶段探索了长度扩展(length scaling),通过强化学习(比如 PPO、GPRO)训练模型生成很长的推理链(CoT),并在奥数等高难度推理任务上取得了显著的效果提升。


受此启发,研究人员开始探索预训练阶段的长度扩展,已有方法包括在序列中插入文本、插入潜在向量(如 Coconut)、复用中间层隐藏状态(如 CoTFormer)以及将中间隐藏状态映射为概念(如 COCOMix)。不过,这些方法普遍存在问题,比如需要更大的 KV 缓存导致推理慢 / 占内存多。


本文中,来自 ByteDance Seed 团队的研究者提出了更简单的方法:直接重复输入 tokens(1/2/3/4 次),不做中间层处理。他们观察到了训练损失和模型性能随重复倍数扩展的趋势,如下图 1a 和 1b 所示。但是,直接重复 tokens 也带来了新问题,包括 KV 缓存规模线性增加,内存压力大;预填充时间超线性增加;解码延迟变长。这些都是实现预训练长度扩展需要重点解决的挑战。



  • 论文标题:Efficient Pretraining Length Scaling

  • arXiv 地址:https://arxiv.org/pdf/2504.14992


研究者提出了一种推理友好的新颖长度扩展方法,核心是 PHD-Transformer(Parallel Hidden Decoding Transformer),它保持了与原始 transformer 相同的 KV 缓存大小,同时实现有效的长度扩展。PHD-Transformer 通过创新的 KV 缓存管理策略实现了这些能力。


具体来讲,研究者将第一个 token 表示原始 token,将重复的 token 表示为解码 token。同时仅保留从原始 token 生成的 KV 缓存来用于长距离依赖建模,并在隐藏解码 token 用于下一个 token 预测之后丢弃它们的 KV 缓存。因此,PHD-Transformer 提供了与原始 transformer 相同的 KV 缓存,同时相较于简单的 token 重复实现了显著的推理加速(如图 1d 所示)。


另外,为了更好地保留隐藏解码 token 的 KV 缓存的性能优势,研究者引入了一种滑动窗口注意力 ——PHD-SWA,保持了这些 token 的局部滑动窗口缓存,在实现显著性能提升的同时,仅需要

的额外 KV 缓存内存。


研究者还注意到,在 PHD-SWA 中,隐藏解码 token 的 KV 缓存表现出了顺序依赖关系,这导致预填充时间呈线性增长。为了解决这个问题,研究者提出了逐块滑动窗口注意力 —— PHD-CSWA,从而限制了每个块内的顺序依赖关系。


因此,得益于只有最后一个块的预填充时间呈线性增长,PHD-CSWA 显著缩短了预填充时间(如图 1c 所示)。



方法概览


PHD 的架构下图 2 所示,与原始 Transformer 相比,PHD 保留了相同的模型架构,仅在输入序列和注意力矩阵的设计上有所不同。具体而言,他们仅允许原始 token

生成 KV 缓存,并且可以被所有 token 全局关注;同时隐藏状态的 KV 缓存在并行隐藏解码后会被立即丢弃。注意力矩阵的策略具体如下: 


研究者在推理过程中实现了与原始 Transformer 相同的 KV 缓存大小和内存访问模式。虽然需要 K 次 FLOP,但这些计算可以并行处理,从而在内存受限的推理场景中最大限度地降低延迟开销。该架构的核心优势在于原始 token 和隐藏解码 token 之间的解耦。在预填充期间,只有原始 token 需要计算。


这种设计确保预填充时间与原始 Transformer 相同,并且无论扩展因子 K 如何变化,预填充时间都保持不变。而对于损失计算,研究者仅使用 token 的最终副本进行下一个 token 的预测。总之,使用 token 的第一个副本进行 KV 缓存生成,使用 token 的最后一个副本进行下一个 token 的预测。



内核设计


M^ij_mn 的简单实现会导致注意力层计算量增加 K^2 倍,FFN 层计算量也增加 K 倍。然而,由于注意力是稀疏计算的,

的注意力可以大幅降低。因此,研究者将原始 token 和隐藏解码 token 分成两组,并将它们连接在一起。


下图 3 展示了 K = 3 的示例,可以得到一个包含 t 个原始 token 的序列和一个包含 2t 个隐藏解码序列的序列。通过重新排列 token 的位置,研究者将掩码注意力的位置保留在一个连续块中,从而优化了注意力计算,将注意力计算复杂度降低到



PHD-SWA 和 PHD-CSWA


与简单的 token 重复相比,PHD-Transformer 在保持原始 KV 缓存大小的同时实现了长度扩展。然而通过经验观察到,为隐藏解码 token 保留一些 KV 缓存可以带来显著的性能提升。因此,为了在保持效率的同时获得这些优势,研究者引入了 PHD-SWA,将滑动窗口注意力限制在 W 个先前的隐藏解码 token 上。


如下图 4 所示,PHD-SWA 的注意力模式将对原始 token 的全局访问与对 W 个最近隐藏解码 token 的局部访问相结合。这种改进的注意力机制实现了显著的性能提升,同时仅需要

的额外 KV 缓存内存。


虽然 PHD-SWA 滑动窗口方法提升了模型性能,但由于隐藏解码 token 的 KV 缓存中存在顺序依赖关系,它会产生 K 倍的预填充开销。为了解决这个问题,研究者引入了 PHD-CSWA,它可以在独立的块内处理注意力。 


如下图 4 所示,PHD-CSWA 将滑动窗口注意力限制在单个块内运行。这种架构创新将额外的预填充开销减少到最终块内的 K 次重复,而不是整个序列重复,这使得额外的计算成本几乎可以忽略不计,同时保留了局部注意力模式的优势。



实验结果


在实验中,研究者使用 OLMo2 作为代码库,并在 ARC、HellaSwag、PIQA、Winogrande、MMLU 和 CommonsenseQA 等公开基准测试集上进行了评估。


训练细节:研究者使用 1.2B 参数规模的模型,它是一个 16 层的密集模型。每个 token 的隐藏层维数设置为 2048,FFN 层的隐藏层大小设置为 16384。同时使用组查询注意力 (Group-Query Attention,GQA),它包含 32 个查询头和 8 个键 / 值头,每个头的隐藏层维数设置为 64。研究者使用 500B 个 token 训练该模型。


对于本文提出的 PHD 系列设置,研究者预训练了以下两种 PHD-CSWA 变体:


  • PHD-CSWA-2-16-32,其中训练 token 重复两次。保留一个包含 16 个 token 的局部窗口,并将块大小设置为 32 个 token。 

  • PHD-CSWA-3-16-32,其中训练 token 重复三次。局部窗口大小和块大小与 PHD-CSWA-2-16-32 的设置相同。


PHD-CSWA 在各个基准测试中均实现了持续的性能提升。下图 5 中展示了训练曲线,下表 1 中展示了主要结果。本文提出的 PHD-CSWA-2-16-32 在这些基准测试中平均实现了 1.5% 的准确率提升,训练损失降低了 0.025;而 PHD-CSWA-3-16-32 在这些基准测试中平均实现了 2.0% 的准确率提升,训练损失降低了 0.034。




研究者还分析了 PHD 和 PHD-SWA 的扩展性能,以分析扩展解码计算的性能。 训练细节:使用相同的 550M 模型配置,将窗口大小 W 设置为 16,并在 {2, 3, 5} 范围内改变扩展因子 K。对于局部窗口大小,研究者在所有实验中都将窗口大小设置为 16。


PHD-SWA 的性能在增加扩展因子时有效扩展。如下图 8 所示,使用固定窗口大小时,损失曲线和下游性能会随着 token 重复次数而有效扩展。通过将扩展因子设置为 5,可以实现接近 0.06 的损失降低,同时显著提升下游性能。


下表 2 中的定量结果表明,当扩展至 K = 5 时,所有基准测试的平均准确率提高了 1.8%,这证实了本文的方法在更激进的扩展方面仍然有效。




更多实验结果请参阅原论文。



© THE END 

转载请联系本公众号获得授权

投稿或寻求报道:liyazhou@jiqizhixin.com

过拟合确实是个值得关注的点。从信息论的角度看,重复token并没有增加新的信息,反而可能让模型过度关注这些已有的信息,从而降低泛化能力。不过,如果将重复token看作是一种数据增强手段,可以提高模型的鲁棒性,对抗噪声数据。关键在于找到一个平衡点,避免过度重复导致过拟合。可能需要在验证集上仔细调参,找到最优的重复次数。

我觉得这个问题要辩证地看。重复token虽然看似增加了冗余信息,但同时也可能帮助模型更好地捕捉序列中的长程依赖关系。Transformer本身就是通过注意力机制来捕捉依赖关系的,重复token可以强化这种注意力,让模型更加关注重要的信息。当然,过拟合的风险确实存在,但可以通过一些技巧来缓解,比如作者提到的滑动窗口注意力,限制模型只关注局部信息,避免全局信息的过度重复。

块大小的选择肯定会影响模型性能。如果块太小,那滑动窗口的优势就无法充分发挥,模型可能无法捕捉到足够的局部信息;如果块太大,那预填充时间的优化效果就会受到影响,而且模型可能会变得更加复杂。所以,需要根据具体的任务和数据集来选择一个合适的块大小,找到一个性能和效率之间的平衡点。

解耦的最大好处应该是效率的提升。传统的Transformer需要对每个token都计算KV缓存,而PHD-Transformer只需要对原始token计算KV缓存,大大减少了计算量。同时,隐藏解码token的KV缓存在使用后立即丢弃,也降低了内存占用。这种设计思路非常巧妙,既保证了性能,又提高了效率,非常适合在资源受限的场景下使用。

从信息论的角度来看,块大小的选择实际上是在控制模型的信息处理粒度。更大的块大小对应着更粗粒度的信息处理,模型更关注全局信息;更小的块大小对应着更细粒度的信息处理,模型更关注局部细节。选择合适的块大小,可以让模型在不同的信息粒度上进行有效的学习,从而提高整体性能。

我理解的解耦更像是一种“分而治之”的策略。原始token负责“骨架”信息的提取,隐藏解码token负责“血肉”信息的填充。通过这种方式,模型可以更加高效地处理长序列,避免信息拥堵。想象一下,如果所有token都参与全局信息的计算,那计算量会非常庞大,而且容易造成信息混乱。而通过解耦,模型可以更加清晰地处理信息,提高效率。

我认为块大小的选择涉及到模型容量和计算效率的权衡。更大的块大小意味着更大的感受野,模型可以捕捉到更长距离的依赖关系,但也意味着更大的计算量和内存占用。更小的块大小则相反。所以,需要根据实际的硬件资源和任务需求来选择一个合适的块大小。

我感觉会有过拟合的风险,毕竟重复的token本质上是给模型提供了更多的“冗余”信息。虽然论文里说他们的方法能有效利用这些重复信息,但具体效果可能还要看数据集和任务的特性。如果数据集本身就比较小,或者任务比较简单,那过拟合的可能性就更高了。解决这个问题可能需要在训练过程中加入一些正则化的手段,比如dropout或者weight decay,来限制模型的复杂度。

我觉得解耦的关键在于让原始token负责全局信息的提取和传递,而隐藏解码token负责局部信息的增强。原始token生成KV缓存,捕捉长程依赖;隐藏解码token则通过滑动窗口注意力,关注局部细节。这种分工合作的方式,既保证了模型的全局视野,又提高了局部信息的处理能力,从而提升了整体性能。