清华大学揭示BF16 FlashAttention训练崩溃之谜:极简修复方案助力稳定训练

清华大学揭示BF16 FlashAttention训练崩溃机制:低精度下数值偏置被放大导致loss爆炸。提出极简修复方案,稳定训练。

原文标题:为什么BF16的FlashAttention会把训练「炸掉」?清华首次给出机制解释,用极简改动稳住训练

原文作者:机器之心

冷月清谈:

清华大学的研究揭示了长期困扰社区的难题:在使用BF16等低精度训练时,FlashAttention并非随机出错,而是在特定条件下触发有方向的数值偏置。这种偏置借助注意力机制中相似的低秩更新方向被持续放大,最终导致权重谱范数和激活值失控,从而引发loss突然爆炸。研究团队通过严格的实验复现并逐步定位问题,最终发现FlashAttention反向传播中的一个特定项是罪魁祸首。他们进一步揭示了导致偏差产生的两个关键机制:相似低秩结构使误差累积而非抵消,以及safe softmax和BF16舍入误差共同作用下产生“离散触发器”。基于此,论文提出了一种几乎不改变模型结构的极简修复方案,即在safe softmax中动态调整行移位常数,以避免出现多个最大值的情况,从而有效稳定训练。该研究不仅解决了FlashAttention在低精度训练中的问题,更提供了一种诊断和解决类似数值稳定性问题的通用范式。

怜星夜思:

1、论文中提到的“相似低秩结构”是如何让误差变成“持续推力”而不是噪声的?除了文中提到的注意力机制,还有哪些模型结构可能存在类似的现象?
2、safe softmax中“多个相同最大值”是如何触发BF16加法系统性误差的?这个现象在其他低精度计算中是否也存在?
3、论文提出的极简修复方案,通过动态调整safe softmax的行移位常数来避免出现多个最大值。除了这种方法,还有没有其他思路可以解决这个问题?例如,从优化器的角度入手,或者从数据预处理的角度入手?

原文内容


一句话总结:社区里困扰了多年的一个 “玄学” 现象终于被拆解清楚了:在 BF16 等低精度训练里,FlashAttention 不是随机出 bug,而是会在特定条件下触发有方向的数值偏置,借助注意力中涌现的相似低秩更新方向被持续放大,最终把权重谱范数和激活推到失控,导致 loss 突然爆炸。论文还给出一个几乎不改模型、只在 safe softmax 里做的极小修改,实测能显著稳定训练。


因果链总览(论文 Figure 1)



  • 标题:Why Low-Precision Transformer Training Fails: An Analysis on Flash Attention

  • 作者:邱海权,姚权铭

  • 机构:清华大学 电子工程系

  • 投稿:ICLR 2026 Oral

  • 关键词:低精度训练,BF16,FlashAttention,数值稳定性,舍入误差(rounding error),低秩表示(low-rank)

  • 论文链接:https://arxiv.org/abs/2510.04212

  • 代码链接:https://github.com/ucker/why-low-precision-training-fails


背景:低精度训练越来越 “刚需”,但注意力比你想的更敏感


大模型训练的现实是:显存和吞吐决定一切。工业界普遍在混合精度里使用 BF16/FP16,甚至把 FFN 推到 FP8,以换取更高的训练效率。但工程实践同样残酷:越接近 “极限精度”,训练越容易出现难以解释的不稳定。


Flash Attention 是长上下文训练的关键加速组件,几乎成了标配。问题在于,社区长期存在一个可复现却难以解释的失败案例:


  • 用 FlashAttention + BF16 训练 GPT-2,一开始正常收敛,但在几千 step 之后突然 loss 爆炸。

  • 你可以通过回退到标准注意力、或把关键计算提高到 FP32 来 “救火”,但代价是吞吐和显存优势没了。


这类问题被报告了多年(相关 issue 在多个开源项目里反复出现),却一直缺少一条能 “从数值误差一路解释到 loss 爆炸” 的机制链。


论文做了什么:把 “锅” 一步步缩小到 FlashAttention 反传里的一个 图片


作者的做法很工程,且足够 “可复现”:


1. 严格复现失败:GPT-2(12 层、12 头、d=768、context=1024),OpenWebText 预训练;并且通过记录并重放相同的数据 batch 序列,把数据顺序带来的随机性排除掉。

2. 定位到层、头:用谱范数(spectral norm)等指标快速缩小范围,发现异常主要来自某一层注意力模块,甚至少数几个 attention head。

3. 定位到一个关键中间量:FlashAttention 反向传播为效率会计算图片


论文发现:只要让这里用到的 图片 走一条更 “高精度 / 等价但不同数值路径” 的计算方式,训练就能恢复稳定。换言之,训练炸掉的导火索并不是 “整个低精度训练”,而是非常具体的:低精度下图片的数值误差进入 图片,污染后续梯度。


机制解释 1:相似低秩结构,让误差变成 “持续推力” 而不是噪声


定位到图片之后,关键问题变成:为什么一个看似很小的数值误差,能在训练中被放大到灾难?


论文把高精度(hp)与低精度(lp)的梯度差写成一种很直观的形式:梯度误差与 图片 成正比,并被注意力里的某些项调制。进一步分解后,误差更新可以近似看作许多 rank-1 项的叠加;更关键的是,作者在实证中观察到:


  • 不同 token、不同训练 step 下,相关的矩阵结构出现了强相似性,可以抽象为一个共同的低秩方向 R。

  • 如果 图片 的系数在统计上还出现偏置(不是围绕 0 对称波动),那么误差不会抵消,而会沿着 R持续累积


结果就是:权重更新被 “带偏”,谱范数和激活异常增长,最终把训练推到 loss 爆炸。



低秩结构相似性与偏置累积(论文 Figure 4/5)


机制解释 2:偏置从哪来?safe softmax + BF16 舍入误差里藏着一个 “离散触发器”


第二条链更 “反直觉”,但也更关键:为什么图片会偏向同一方向?


作者把问题追到了 FlashAttention 前向里的未归一化输出:


  •  图片(safe softmax 的常见写法)

  • 图片

  • 图片


论文的关键观察是:图片在 BF16 下会出现系统性偏差,并且偏差的触发条件很具体:


触发条件:一行 score 出现 “多个相同最大值”,图片里会出现多个精确的 1


当 S 的某一行里,有不止一个位置等于行最大值时,这些位置在指数里都会变成 图片,也就是:图片 中会出现多个精确的 1(不是接近 1,而是 float 表示上真的等于 1)。


这看上去像个细节,但它会把后续 图片 的点积推到一个危险区间。


偏置来源:当 图片且某些维度的 V 以负数为主时,BF16 加法会系统性 “越加越负”


在某些特征维度上,V [:, i] 的分布可能以负数为主。此时当图片,乘积项就是 V [t,i] 本身(负的 BF16 数)。多个负数在 BF16 的加法舍入中会更容易触发尾数溢出、右移与 sticky bit 相关的舍入行为,导致误差贡献出现不对称,表现为:


  • 图片 相对 图片 更倾向于 “偏负”

  • 如果上游梯度 图片 在对应维度也倾向为负,那么在 图片里就会形成偏正的误差项

  • 偏正的 图片 再去驱动前一节提到的相似低秩方向 R,就形成 “越训越偏” 的闭环


当 图片 出现时,图片 的误差会发生明显 “负跳变”(论文 Figure 6)


极简修复:让 图片永远严格小于 1


既然问题的离散触发器是 图片 中出现精确的 1,作者给出的修复思路非常直接:


  • 检测一行 S 中最大值是否出现多次

  • 一旦出现 “重复最大值”,就动态调整 safe softmax 的行移位常数 m,让最大位置的指数也变成严格小于 1


论文给出的实现(概念上)如下:


rm = rowmax (S)
rs = rowsum (S == rm)  # 最大值出现次数
if rs > 1 and rm > 0: m = β * rm   (β > 1)
if rs > 1 and rm < 0: m = 0
else:                 m = rm
Pbar = exp (S - m)     # 从而 max (Pbar) < 1


这一步在精确算术下不改变注意力结果(softmax 对 “整行减常数” 不敏感),但在有限精度下能避免 图片 触发后续 BF16 累加的偏置舍入,从根上切断误差链。


实验结果:稳定训练不再 “突然炸”


论文在 BF16 设置下验证了上述分析与修复:


  • GPT-2S:使用修改后的 FlashAttention,在 AdamW 与 Muon 两种优化器下,都能稳定训练到 600K steps

  • GPT-2M:同样能在 AdamW 下稳定训练(论文展示到 100K steps)

  • 论文还提到该现象与结论在多种硬件上保持一致(包括 A100、RTX 4090、Ascend 910B)


验证集 loss 曲线对比(论文 Figure 7)


更重要的启示:别把低精度误差当成 “零均值噪声”


这篇论文的价值不只在 “修了一个 bug”,更在于给出了一个可迁移的诊断范式:


  • 数值误差未必是随机噪声。在特定分布与离散事件(如重复最大值、概率精确为 1)下,舍入误差可能形成系统性偏置。

  • 模型结构会放大偏置。注意力里涌现的相似低秩更新方向,让偏置误差更容易 “同向叠加”。

  • 经验修复为什么有效也能被解释:论文讨论了 attention sinks 与多最大值的关系,并给出了一个数值层面的连接;同时也指出一些稳定化技巧(如 QK normalization、Gated Attention)可能通过 “打散结构相似性” 来阻止误差同向累积。


作者介绍


邱海权是清华大学在读博士研究生,研究方向涵盖机器学习理论、表示学习与大模型机制分析。他的研究围绕模型表达能力、结构归纳偏置以及参数空间几何与优化动力学之间的内在联系展开,关注模型在不同结构约束与训练条件下的泛化行为与可组合性问题。整体上,他强调以可分析的理论框架刻画模型的能力边界与机制来源,从结构与原理层面理解深度模型为何有效、何时失效。


姚权铭,清华大学电子工程系副教授。长期致力于数据高效学习与智能体系统研究,在少样本学习、图学习、知识图谱与生物医药智能等方向取得系统性成果。发表 Nature 子刊、TPAMI、JMLR、ICML、NeurIPS、ICLR 等论文 130 余篇,被引 1.4 万余次。代表性工作包括抗噪学习算法 Co-teaching、小样本学习综述、自动化图学习方法及新药物相互作用预测模型。现任 TPAMI、TMLR 编委及 Neural Networks 资深编委,多次担任 ICML、NeurIPS、ICLR 领域主席,入选 IEEE Computing Top 30、IET Fellow 等。


© THE END 

转载请联系本公众号获得授权

投稿或寻求报道:liyazhou@jiqizhixin.com


我认为具有一定的借鉴意义。文章的核心在于揭示了低精度训练中,数值误差可能被特定模型结构放大的机制。这种机制不一定只存在于FlashAttention中,其他Attention机制或模型结构也可能存在类似的风险。因此,文章的分析方法,即通过定位关键中间量、分析误差来源、以及寻找误差放大机制,可以推广到其他场景。

我认为可以考虑使用混合精度训练,即在模型的不同部分使用不同的精度。例如,可以将对数值精度要求较高的部分(如softmax)使用FP32精度,而将其他部分使用BF16精度。这样既能保证训练的稳定性,又能提高训练效率。此外,还可以尝试使用一些更先进的attention机制,例如Longformer或者Big Bird,这些模型通过稀疏attention来降低计算复杂度,同时也能提高模型的表达能力。

好问题!这“相似低秩结构”听起来就挺抽象的。我觉得它可能跟Transformer模型本身的学习方式有关,模型倾向于捕捉数据中的共性模式。如果数据集中存在某些重复出现的模式(例如,语言中的常用短语),那么attention机制可能会学习到一组通用的“基向量”,然后用这些基向量的线性组合来表示不同的输入。这就自然形成了低秩结构。当然,具体的数据集和模型结构也会影响这种结构的形成,例如,数据集的统计特性、模型的层数、attention头的数量等等。

这是一个trade-off问题,稳定性和表达能力往往不可兼得。QK normalization和Gated Attention这些技巧,本质上都是在模型中引入额外的约束,限制模型的自由度,从而提高训练的稳定性。但是,这种约束也可能会降低模型的表达能力,影响模型的性能。我觉得,更好的方法可能是从优化算法入手,例如使用AdamW等自适应优化器,或者采用梯度裁剪等技术,来控制训练过程中的梯度大小,从而避免loss爆炸。

与其说是修复方案的潜力,不如说是提供了一种排查问题的思路。以后遇到类似“玄学”问题,可以考虑是不是因为特定数值在低精度计算中引入了偏差。比如,ReLU激活函数在输入为0时导数为0,这可能会导致某些神经元“死亡”。虽然ReLU很简单,但是类似的简单操作也可能在特定情况下引发问题。

从线性代数的角度看,任何矩阵都可以分解为一系列秩-1矩阵的和。而attention机制中的query和key的交互,本质上就是矩阵乘法。如果query和key的分布比较集中,那么得到的attention矩阵可能就会呈现出低秩性。所以,我认为这是数据、模型和attention机制共同作用的结果。

我从理论角度补充一下,深度学习模型实际上是在学习数据的内在结构,而真实世界的数据往往具有一定的低秩性。例如,图像数据通常包含大量的冗余信息,视频数据在时间上具有连续性,这些都会导致模型学到的特征具有低秩结构。这种低秩结构有助于模型压缩和加速推理,但也可能限制模型的表达能力。如何在模型设计中更好地利用这种低秩结构,是一个值得深入研究的问题。

这让我想到了知识蒸馏!知识蒸馏就是把一个大模型(teacher model)的知识迁移到一个小模型(student model)上。如果 teacher model 中存在相似低秩结构,那么 student model 也很可能会学到类似的结构。这可以帮助 student model 更好地泛化,但也可能让 student model 继承 teacher model 的一些偏差。所以,在知识蒸馏中,如何选择合适的 teacher model 以及如何设计蒸馏策略,都非常重要。

论文中作者已经强调了,在精确算术下,safe softmax的修改不改变注意力结果。但在有限精度下,它能避免误差累积。我认为这种修改对模型精度的影响应该是很小的,甚至可能因为训练更稳定而提升精度。除了修改safe softmax,或许可以尝试一些梯度裁剪策略,或者使用一些对数值稳定性更友好的激活函数。另外,混合精度训练中,合理分配不同模块的精度也很重要。