TST训练法:不改模型架构,也能让大模型预训练提速

TST通过训练阶段叠加token,在不改模型架构和推理方式下提升预训练效率。

原文标题:低成本提效,解锁大模型预训练的全新提速思路

原文作者:数据派THU

冷月清谈:

Nous Research 提出的 Token Superposition Training(TST)尝试在不改变模型架构、分词器、优化器、并行策略和推理方式的前提下,提高大模型预训练效率。它将训练分为两个阶段:前半程把连续多个 token 的 embedding 平均成一个“token bag”表示,并让模型预测下一组 token;后半程再切回标准自回归逐 token 预测。这样模型在相同每步计算量下能接触更多原始文本。实验覆盖 270M、600M、3B 稠密模型和 10B-A1B MoE,其中 10B-A1B 在不到 40% 训练时间内达到基线最终 loss,对应同等 loss 下最高约 2.5 倍提速。研究还发现,bag size、训练切换比例、MCE 损失设计以及输入输出表示对齐都会影响效果。TST 的价值在于低侵入、易迁移,但它更适合数据充足、算力紧张的预训练场景,对长上下文能力等下游影响仍需进一步验证。

怜星夜思:

1、TST这种“不改架构,只改训练过程”的方法,会不会比改模型结构更适合工业落地?
2、把多个token平均成一个表示,会不会损失语序信息?为什么实验里反而还能提速并降低loss?
3、TST的提速来自“看了更多数据”,那如果高质量数据不够,它还会有效吗?
4、TST对长上下文能力可能有帮助吗,还是只是训练loss更好看?

原文内容

图片
本文约2000字,建议阅读5分钟
本文介绍了 TST 训练法,零改模型架构,大幅提升大模型预训练效率。


不改模型架构和推理方式,只在预训练前半程调整 token 表示和预测目标,就让 10B-A1B MoE 跑出同等 loss 下最高 2.5 倍提速。


标准 LLM 预训练里,每个训练 step 通常只处理一段给定长度的 token 序列。


想在同样算力下让模型接触更多文本,常见办法是换分词器、改 attention 结构、上 MoE,或者额外加入多 token 预测头。


但这些做法往往会改变最终模型本身,让训练效率和推理结构绑在一起,后续很难判断收益到底来自训练吞吐、架构变化,还是额外预测目标。


那么,有没有可能在一行模型架构代码都不改的前提下,只在训练阶段提高预训练吞吐?


Nous Research 近期提出了 token 叠加训练,也就是 Token Superposition Training(TST)。


不改模型架构、并行策略、优化器、分词器和训练数据,只在预训练前半程把连续 token 的 embedding 临时平均成一个新表示,并让模型预测下一组 token,后半程再切回标准逐 token 预测训练。


 10B-A1B MoE 实验中,TST 用不到 40% 的训练时间达到基线模型的最终损失水平,对应该设置下同等损失约 2.5 倍的预训练提速。


〓 10B MoE 模型在叠加阶段切换至恢复阶段后的 loss 表现


训练结束后,模型仍然是标准自回归结构,推理方式不变,也不需要额外模块。

论文标题:

Efficient Pre-Training with Token Superposition

论文链接:

https://arxiv.org/abs/2605.06546


1、纯拼吞吐量:TST 的极简训练流


TST 最容易和 MTP、SuperBPE 混在一起看,但它们改的不是同一个地方。


从宏观架构来看,多 token 预测(MTP)通过增加预测头来提供更密的局部监督,但它并不会提高每单位 FLOPs 处理的 token 数,还会带来额外参数。


〓 TST与标准下一Token预测、多Token预测及SuperBPE的架构对比


TST 的思路则完全侧重于训练吞吐量。整个训练过程分为叠加阶段(Superposition Phase)与恢复阶段(Recovery Phase)。


在叠加阶段,模型以一组连续 token 的叠加表示作为输入,并预测下一组 token。


达到预设训练比例后,再切回标准逐 token 预测。这个切换并不是完全平滑的,loss 会短暂上升,随后进入恢复阶段。


2、输入端:物理层面的张量折叠


在叠加阶段的输入端,长度为 L 的连续序列会被切成若干个不重叠的 token bag,每个 bag 包含 s 个 token。


在 embedding 层,这 s 个 Token 的向量表示会被取平均,形成一个新的 latent s-token 表示。


〓 通过 reshape 将序列切成 token bag,再用 mean 得到叠加后的 token 表示


经过这一步,模型实际处理的 latent 序列长度变为原来的 1/s。


为了保持与 baseline 相同的每步 FLOPs,TST 在这一阶段把原始输入长度扩大 s 倍。这样一来,模型在相同每步计算量下,可以接触到更多原始 token。


3、输出端:预测下一组 Token


输出端也同步进行调整,模型不再预测单个 next token,而是预测下一组 token。


为了让一个预测位置同时对应 s 个有效 label,作者把标准 CE 换成了多热交叉熵损失(MCE)。


MCE 可以理解为把目标概率均分给下一组 token 中的 s 个 label。其完整的数学展开式如下:



其中   对应 token bag 的大小,也就是  。由于   是常数项,不影响梯度,训练时可以直接去掉。


去掉常数项后,训练中使用的 MCE 可以写成 bag 内多个标准 CE 的平均值:



这个简化形式大大降低了技术迁移成本。实际实现时,可以直接复用现有预训练库中的 fused CE kernel,对 bag 内 label 分别计算 CE 后求平均,不需要额外写 CUDA kernel。


〓 简化后的 MCE 可以直接复用标准 CE 实现


4、2.5倍提速怎么来的?


研究团队在 TorchTitan 框架下结合 FSDP 并行策略,对 TST 进行了多尺度的规模化验证,覆盖了 270M、600M、3B 稠密模型以及 10B 混合专家模型



〓 TST 在各模型规模下的预训练表现及下游评测全景数据


3B 实验给了三个比较角度:

  • 在同等计算量下,TST 的最终训练损失更低;

  • 在同等 loss 下,TST 需要的训练时间更短;

  • 在同等数据量下,TST 每个原始 Token 获得的计算更少,表现反而弱于 baseline。


〓 3B 稠密模型在同等算力、同等 loss 及同等数据量三种约束条件下的训练 loss 曲线


超参数上,step ratio r 在 0.2 到 0.4 之间相对稳定,bag size s 则呈现明显的 U 型趋势。



〓 在 270M 与 600M 规模下,最终 loss 随包大小 s 呈现 U 型曲线,且最优区间随模型增大而右移


随着模型参数量的增加,最优的包大小区间整体向更大的数值发生偏移。在 10B-A1B 的大规模验证中,作者采用了 s=16,最终 loss 从 baseline 的 2.252 降到 2.236。


作者也尝试过一些替代设计,但效果并不如默认设置。


例如,BCE、hinge loss 等替代 bag loss 明显弱于 MCE;官方 blog 还提到,一些试图补回 bag 内位置信息的设计也没有带来稳定收益。


这些结果至少说明,强行恢复 bag 内顺序并不是 TST 收益的关键。


此外,在较大的 bag size 下,均匀加权未必是最稳的选择。


〓 DCLM 数据集中 token间互信息随距离呈现幂律衰减


通过对 DCLM 数据集进行分析,团队发现自然语言中 token 间的互信息随距离呈现出明显的幂律衰减规律。引入随距离衰减的非均匀加权多热交叉熵(Weighted MCE),能够在此类配置下获得更低的最终损失。


5、底层表示绝对共享


消融实验进一步拆解了输入与输出叠加的独立贡献。结果表明,单独应用输入叠加(压缩粒度)或输出叠加(改造梯度信号)均能带来超越基线的增益,而两者的结合并未产生干扰,证明了这两种机制在底层是完全正交的。


〓 输入叠加与输出叠加机制各自独立生效,结合使用能取得最大综合收益


不少多阶段训练方法在切换目标时,会引入 adapter 或额外 alignment phase 来缓解表示不匹配。TST 不引入额外 adapter,关键在于其在两个阶段中保持了完全相同的底层表示。


如果在恢复阶段开始时,随机重新初始化模型的输入 embedding 层和输出 LM Head,TST 在前期积累的所有优化红利将彻底清零,最终的损失值会高过从零开始的基线模型。



〓 在恢复阶段重置输入输出表示层会导致 TST 前期积累的收益完全消失,证明了表示对齐的必要性


这说明,跨阶段保持输入输出表示对齐,很可能是 TST 收益能够延续到标准训练阶段的重要条件。


6、结语


TST 本质上是用更多数据消耗,换同等计算下更低的训练 loss。在算力资源紧缺且数据仍充足的预训练设置下,这类低侵入训练方法很有吸引力。


由于第一阶段会把 s 个 token 折成一个 latent 位置,模型在相同 latent 序列长度下对应的是更长的原始文本跨度。


这可能减少长文档在训练中被截断或切分的情况,但论文没有评测最终长上下文能力,因此只能视为一个后续问题,而不是已经验证的结论。


编辑:于腾凯

校对:李享沣



关于我们

数据派THU作为数据科学类公众号,背靠清华大学大数据研究中心,分享前沿数据科学与大数据技术创新研究动态、持续传播数据科学知识,努力建设数据人才聚集平台、打造中国大数据最强集团军。




新浪微博:@数据派THU

微信视频号:数据派THU

今日头条:数据派THU

回答“TST在数据不够时是否有效”:文章里其实已经暗示了限制,它本质是用更多数据消耗换同等算力下更低loss。如果高质量语料不足,重复数据太多,继续增加token接触量可能收益会明显下降,甚至加剧过拟合或数据污染。

2 个赞

关于“TST是否提升长上下文能力”,目前不能直接下结论。它在叠加阶段让同样latent长度覆盖更长原始文本,理论上可能减少文档截断,但论文没有系统评测长上下文任务,所以只能算一个值得验证的假设。

1 个赞

我理解这里的关键不是让模型在bag内部精确建模顺序,而是提升单位算力下的数据覆盖率。自然语言里很多统计规律不一定依赖每个token的严格顺序,早期预训练可能更需要大规模分布信息,所以平均表示未必马上致命。

1 个赞

针对“平均token会不会损失语序”这个问题,肯定会损失一部分局部顺序信息,但TST不是全程这么训。它只在前半段用叠加,让模型先以更低计算成本看更多文本,后半段再恢复标准token级训练,相当于先粗读大量材料,再精读校准。

3 个赞

“loss好看”和“真会长上下文”之间差着一个评测集。很多方法训练曲线很漂亮,一到needle-in-a-haystack或者长文档问答就露馅。所以这块我等后续实验,不急着吹。

2 个赞