TST通过训练阶段叠加token,在不改模型架构和推理方式下提升预训练效率。
原文标题:低成本提效,解锁大模型预训练的全新提速思路
原文作者:数据派THU
冷月清谈:
怜星夜思:
2、把多个token平均成一个表示,会不会损失语序信息?为什么实验里反而还能提速并降低loss?
3、TST的提速来自“看了更多数据”,那如果高质量数据不够,它还会有效吗?
4、TST对长上下文能力可能有帮助吗,还是只是训练loss更好看?
原文内容
本文约2000字,建议阅读5分钟本文介绍了 TST 训练法,零改模型架构,大幅提升大模型预训练效率。
不改模型架构和推理方式,只在预训练前半程调整 token 表示和预测目标,就让 10B-A1B MoE 跑出同等 loss 下最高 2.5 倍提速。
不改模型架构和推理方式,只在预训练前半程调整 token 表示和预测目标,就让 10B-A1B MoE 跑出同等 loss 下最高 2.5 倍提速。
标准 LLM 预训练里,每个训练 step 通常只处理一段给定长度的 token 序列。
想在同样算力下让模型接触更多文本,常见办法是换分词器、改 attention 结构、上 MoE,或者额外加入多 token 预测头。
但这些做法往往会改变最终模型本身,让训练效率和推理结构绑在一起,后续很难判断收益到底来自训练吞吐、架构变化,还是额外预测目标。
那么,有没有可能在一行模型架构代码都不改的前提下,只在训练阶段提高预训练吞吐?
Nous Research 近期提出了 token 叠加训练,也就是 Token Superposition Training(TST)。
它不改模型架构、并行策略、优化器、分词器和训练数据,只在预训练前半程把连续 token 的 embedding 临时平均成一个新表示,并让模型预测下一组 token,后半程再切回标准逐 token 预测训练。
在 10B-A1B MoE 实验中,TST 用不到 40% 的训练时间达到基线模型的最终损失水平,对应该设置下同等损失约 2.5 倍的预训练提速。
〓 10B MoE 模型在叠加阶段切换至恢复阶段后的 loss 表现
训练结束后,模型仍然是标准自回归结构,推理方式不变,也不需要额外模块。
论文标题:
Efficient Pre-Training with Token Superposition
论文链接:
https://arxiv.org/abs/2605.06546
1、纯拼吞吐量:TST 的极简训练流
TST 最容易和 MTP、SuperBPE 混在一起看,但它们改的不是同一个地方。
从宏观架构来看,多 token 预测(MTP)通过增加预测头来提供更密的局部监督,但它并不会提高每单位 FLOPs 处理的 token 数,还会带来额外参数。
〓 TST与标准下一Token预测、多Token预测及SuperBPE的架构对比
TST 的思路则完全侧重于训练吞吐量。整个训练过程分为叠加阶段(Superposition Phase)与恢复阶段(Recovery Phase)。
在叠加阶段,模型以一组连续 token 的叠加表示作为输入,并预测下一组 token。
达到预设训练比例后,再切回标准逐 token 预测。这个切换并不是完全平滑的,loss 会短暂上升,随后进入恢复阶段。
2、输入端:物理层面的张量折叠
在叠加阶段的输入端,长度为 L 的连续序列会被切成若干个不重叠的 token bag,每个 bag 包含 s 个 token。
在 embedding 层,这 s 个 Token 的向量表示会被取平均,形成一个新的 latent s-token 表示。
〓 通过 reshape 将序列切成 token bag,再用 mean 得到叠加后的 token 表示
经过这一步,模型实际处理的 latent 序列长度变为原来的 1/s。
为了保持与 baseline 相同的每步 FLOPs,TST 在这一阶段把原始输入长度扩大 s 倍。这样一来,模型在相同每步计算量下,可以接触到更多原始 token。
3、输出端:预测下一组 Token
输出端也同步进行调整,模型不再预测单个 next token,而是预测下一组 token。
为了让一个预测位置同时对应 s 个有效 label,作者把标准 CE 换成了多热交叉熵损失(MCE)。
MCE 可以理解为把目标概率均分给下一组 token 中的 s 个 label。其完整的数学展开式如下:
其中 对应 token bag 的大小,也就是 。由于 是常数项,不影响梯度,训练时可以直接去掉。
去掉常数项后,训练中使用的 MCE 可以写成 bag 内多个标准 CE 的平均值:
这个简化形式大大降低了技术迁移成本。实际实现时,可以直接复用现有预训练库中的 fused CE kernel,对 bag 内 label 分别计算 CE 后求平均,不需要额外写 CUDA kernel。
〓 简化后的 MCE 可以直接复用标准 CE 实现
4、2.5倍提速怎么来的?
研究团队在 TorchTitan 框架下结合 FSDP 并行策略,对 TST 进行了多尺度的规模化验证,覆盖了 270M、600M、3B 稠密模型以及 10B 混合专家模型。
〓 TST 在各模型规模下的预训练表现及下游评测全景数据
3B 实验给了三个比较角度:
-
在同等计算量下,TST 的最终训练损失更低;
-
在同等 loss 下,TST 需要的训练时间更短;
-
在同等数据量下,TST 每个原始 Token 获得的计算更少,表现反而弱于 baseline。
〓 3B 稠密模型在同等算力、同等 loss 及同等数据量三种约束条件下的训练 loss 曲线
超参数上,step ratio r 在 0.2 到 0.4 之间相对稳定,bag size s 则呈现明显的 U 型趋势。
〓 在 270M 与 600M 规模下,最终 loss 随包大小 s 呈现 U 型曲线,且最优区间随模型增大而右移
随着模型参数量的增加,最优的包大小区间整体向更大的数值发生偏移。在 10B-A1B 的大规模验证中,作者采用了 s=16,最终 loss 从 baseline 的 2.252 降到 2.236。
作者也尝试过一些替代设计,但效果并不如默认设置。
例如,BCE、hinge loss 等替代 bag loss 明显弱于 MCE;官方 blog 还提到,一些试图补回 bag 内位置信息的设计也没有带来稳定收益。
这些结果至少说明,强行恢复 bag 内顺序并不是 TST 收益的关键。
此外,在较大的 bag size 下,均匀加权未必是最稳的选择。
〓 DCLM 数据集中 token间互信息随距离呈现幂律衰减
通过对 DCLM 数据集进行分析,团队发现自然语言中 token 间的互信息随距离呈现出明显的幂律衰减规律。引入随距离衰减的非均匀加权多热交叉熵(Weighted MCE),能够在此类配置下获得更低的最终损失。
5、底层表示绝对共享
消融实验进一步拆解了输入与输出叠加的独立贡献。结果表明,单独应用输入叠加(压缩粒度)或输出叠加(改造梯度信号)均能带来超越基线的增益,而两者的结合并未产生干扰,证明了这两种机制在底层是完全正交的。
〓 输入叠加与输出叠加机制各自独立生效,结合使用能取得最大综合收益
不少多阶段训练方法在切换目标时,会引入 adapter 或额外 alignment phase 来缓解表示不匹配。TST 不引入额外 adapter,关键在于其在两个阶段中保持了完全相同的底层表示。
如果在恢复阶段开始时,随机重新初始化模型的输入 embedding 层和输出 LM Head,TST 在前期积累的所有优化红利将彻底清零,最终的损失值会高过从零开始的基线模型。
〓 在恢复阶段重置输入输出表示层会导致 TST 前期积累的收益完全消失,证明了表示对齐的必要性
这说明,跨阶段保持输入输出表示对齐,很可能是 TST 收益能够延续到标准训练阶段的重要条件。
6、结语
TST 本质上是用更多数据消耗,换同等计算下更低的训练 loss。在算力资源紧缺且数据仍充足的预训练设置下,这类低侵入训练方法很有吸引力。
由于第一阶段会把 s 个 token 折成一个 latent 位置,模型在相同 latent 序列长度下对应的是更长的原始文本跨度。
这可能减少长文档在训练中被截断或切分的情况,但论文没有评测最终长上下文能力,因此只能视为一个后续问题,而不是已经验证的结论。
编辑:于腾凯
校对:李享沣













