清华大学揭示BF16 FlashAttention训练崩溃之谜:极简修复方案助力稳定训练

清华研究揭示BF16 FlashAttention训练崩溃机制:相似低秩结构放大舍入误差,导致loss爆炸。提出极简修复方案,稳定训练。

原文标题:为什么BF16的FlashAttention会把训练「炸掉」?清华首次给出机制解释,用极简改动稳住训练

原文作者:机器之心

冷月清谈:

本文总结了清华大学的一篇论文,该论文深入研究了在使用 FlashAttention 和 BF16 等低精度训练时,模型训练容易崩溃的问题。研究表明,这并非随机 bug,而是在特定条件下,FlashAttention 会触发有方向的数值偏置,并借助注意力机制中涌现的相似低秩更新方向被持续放大,最终导致 loss 爆炸。论文通过实验复现了该问题,并定位到 FlashAttention 反传过程中的一个关键项。通过分析,作者发现BF16的舍入误差是罪魁祸首,当一行 score 出现多个最大值时,会触发系统性偏差,导致误差累积。针对此问题,论文提出了一个极简的修复方案:动态调整 safe softmax 的行移位常数,避免出现多个最大值,从而切断误差链。实验证明,该方案能显著稳定训练过程。这项研究不仅解决了一个实际问题,更重要的是,它提供了一种诊断和解决低精度训练中数值不稳定问题的范式。

怜星夜思:

1、论文中提到的“相似低秩结构”具体是指什么?除了注意力机制,还有哪些模型结构容易出现这种现象?
2、论文中提到,一些稳定化技巧(如 QK normalization、Gated Attention)可能通过 “打散结构相似性” 来阻止误差同向累积。这些技巧是如何起作用的?有没有其他有效的稳定训练技巧?
3、这篇论文主要关注的是 FlashAttention 在 BF16 下的问题,那么 FlashAttention 在其他精度下,或者其他 attention 机制在 BF16 下,是否也会出现类似问题?这种研究对于未来大模型训练有哪些启示?

原文内容


一句话总结:社区里困扰了多年的一个 “玄学” 现象终于被拆解清楚了:在 BF16 等低精度训练里,FlashAttention 不是随机出 bug,而是会在特定条件下触发有方向的数值偏置,借助注意力中涌现的相似低秩更新方向被持续放大,最终把权重谱范数和激活推到失控,导致 loss 突然爆炸。论文还给出一个几乎不改模型、只在 safe softmax 里做的极小修改,实测能显著稳定训练。


因果链总览(论文 Figure 1)



  • 标题:Why Low-Precision Transformer Training Fails: An Analysis on Flash Attention

  • 作者:邱海权,姚权铭

  • 机构:清华大学 电子工程系

  • 投稿:ICLR 2026 Oral

  • 关键词:低精度训练,BF16,FlashAttention,数值稳定性,舍入误差(rounding error),低秩表示(low-rank)

  • 论文链接:https://arxiv.org/abs/2510.04212

  • 代码链接:https://github.com/ucker/why-low-precision-training-fails


背景:低精度训练越来越 “刚需”,但注意力比你想的更敏感


大模型训练的现实是:显存和吞吐决定一切。工业界普遍在混合精度里使用 BF16/FP16,甚至把 FFN 推到 FP8,以换取更高的训练效率。但工程实践同样残酷:越接近 “极限精度”,训练越容易出现难以解释的不稳定。


Flash Attention 是长上下文训练的关键加速组件,几乎成了标配。问题在于,社区长期存在一个可复现却难以解释的失败案例:


  • 用 FlashAttention + BF16 训练 GPT-2,一开始正常收敛,但在几千 step 之后突然 loss 爆炸。

  • 你可以通过回退到标准注意力、或把关键计算提高到 FP32 来 “救火”,但代价是吞吐和显存优势没了。


这类问题被报告了多年(相关 issue 在多个开源项目里反复出现),却一直缺少一条能 “从数值误差一路解释到 loss 爆炸” 的机制链。


论文做了什么:把 “锅” 一步步缩小到 FlashAttention 反传里的一个 图片


作者的做法很工程,且足够 “可复现”:


1. 严格复现失败:GPT-2(12 层、12 头、d=768、context=1024),OpenWebText 预训练;并且通过记录并重放相同的数据 batch 序列,把数据顺序带来的随机性排除掉。

2. 定位到层、头:用谱范数(spectral norm)等指标快速缩小范围,发现异常主要来自某一层注意力模块,甚至少数几个 attention head。

3. 定位到一个关键中间量:FlashAttention 反向传播为效率会计算图片


论文发现:只要让这里用到的 图片 走一条更 “高精度 / 等价但不同数值路径” 的计算方式,训练就能恢复稳定。换言之,训练炸掉的导火索并不是 “整个低精度训练”,而是非常具体的:低精度下图片的数值误差进入 图片,污染后续梯度。


机制解释 1:相似低秩结构,让误差变成 “持续推力” 而不是噪声


定位到图片之后,关键问题变成:为什么一个看似很小的数值误差,能在训练中被放大到灾难?


论文把高精度(hp)与低精度(lp)的梯度差写成一种很直观的形式:梯度误差与 图片 成正比,并被注意力里的某些项调制。进一步分解后,误差更新可以近似看作许多 rank-1 项的叠加;更关键的是,作者在实证中观察到:


  • 不同 token、不同训练 step 下,相关的矩阵结构出现了强相似性,可以抽象为一个共同的低秩方向 R。

  • 如果 图片 的系数在统计上还出现偏置(不是围绕 0 对称波动),那么误差不会抵消,而会沿着 R持续累积


结果就是:权重更新被 “带偏”,谱范数和激活异常增长,最终把训练推到 loss 爆炸。



低秩结构相似性与偏置累积(论文 Figure 4/5)


机制解释 2:偏置从哪来?safe softmax + BF16 舍入误差里藏着一个 “离散触发器”


第二条链更 “反直觉”,但也更关键:为什么图片会偏向同一方向?


作者把问题追到了 FlashAttention 前向里的未归一化输出:


  •  图片(safe softmax 的常见写法)

  • 图片

  • 图片


论文的关键观察是:图片在 BF16 下会出现系统性偏差,并且偏差的触发条件很具体:


触发条件:一行 score 出现 “多个相同最大值”,图片里会出现多个精确的 1


当 S 的某一行里,有不止一个位置等于行最大值时,这些位置在指数里都会变成 图片,也就是:图片 中会出现多个精确的 1(不是接近 1,而是 float 表示上真的等于 1)。


这看上去像个细节,但它会把后续 图片 的点积推到一个危险区间。


偏置来源:当 图片且某些维度的 V 以负数为主时,BF16 加法会系统性 “越加越负”


在某些特征维度上,V [:, i] 的分布可能以负数为主。此时当图片,乘积项就是 V [t,i] 本身(负的 BF16 数)。多个负数在 BF16 的加法舍入中会更容易触发尾数溢出、右移与 sticky bit 相关的舍入行为,导致误差贡献出现不对称,表现为:


  • 图片 相对 图片 更倾向于 “偏负”

  • 如果上游梯度 图片 在对应维度也倾向为负,那么在 图片里就会形成偏正的误差项

  • 偏正的 图片 再去驱动前一节提到的相似低秩方向 R,就形成 “越训越偏” 的闭环


当 图片 出现时,图片 的误差会发生明显 “负跳变”(论文 Figure 6)


极简修复:让 图片永远严格小于 1


既然问题的离散触发器是 图片 中出现精确的 1,作者给出的修复思路非常直接:


  • 检测一行 S 中最大值是否出现多次

  • 一旦出现 “重复最大值”,就动态调整 safe softmax 的行移位常数 m,让最大位置的指数也变成严格小于 1


论文给出的实现(概念上)如下:


rm = rowmax (S)
rs = rowsum (S == rm)  # 最大值出现次数
if rs > 1 and rm > 0: m = β * rm   (β > 1)
if rs > 1 and rm < 0: m = 0
else:                 m = rm
Pbar = exp (S - m)     # 从而 max (Pbar) < 1


这一步在精确算术下不改变注意力结果(softmax 对 “整行减常数” 不敏感),但在有限精度下能避免 图片 触发后续 BF16 累加的偏置舍入,从根上切断误差链。


实验结果:稳定训练不再 “突然炸”


论文在 BF16 设置下验证了上述分析与修复:


  • GPT-2S:使用修改后的 FlashAttention,在 AdamW 与 Muon 两种优化器下,都能稳定训练到 600K steps

  • GPT-2M:同样能在 AdamW 下稳定训练(论文展示到 100K steps)

  • 论文还提到该现象与结论在多种硬件上保持一致(包括 A100、RTX 4090、Ascend 910B)


验证集 loss 曲线对比(论文 Figure 7)


更重要的启示:别把低精度误差当成 “零均值噪声”


这篇论文的价值不只在 “修了一个 bug”,更在于给出了一个可迁移的诊断范式:


  • 数值误差未必是随机噪声。在特定分布与离散事件(如重复最大值、概率精确为 1)下,舍入误差可能形成系统性偏置。

  • 模型结构会放大偏置。注意力里涌现的相似低秩更新方向,让偏置误差更容易 “同向叠加”。

  • 经验修复为什么有效也能被解释:论文讨论了 attention sinks 与多最大值的关系,并给出了一个数值层面的连接;同时也指出一些稳定化技巧(如 QK normalization、Gated Attention)可能通过 “打散结构相似性” 来阻止误差同向累积。


作者介绍


邱海权是清华大学在读博士研究生,研究方向涵盖机器学习理论、表示学习与大模型机制分析。他的研究围绕模型表达能力、结构归纳偏置以及参数空间几何与优化动力学之间的内在联系展开,关注模型在不同结构约束与训练条件下的泛化行为与可组合性问题。整体上,他强调以可分析的理论框架刻画模型的能力边界与机制来源,从结构与原理层面理解深度模型为何有效、何时失效。


姚权铭,清华大学电子工程系副教授。长期致力于数据高效学习与智能体系统研究,在少样本学习、图学习、知识图谱与生物医药智能等方向取得系统性成果。发表 Nature 子刊、TPAMI、JMLR、ICML、NeurIPS、ICLR 等论文 130 余篇,被引 1.4 万余次。代表性工作包括抗噪学习算法 Co-teaching、小样本学习综述、自动化图学习方法及新药物相互作用预测模型。现任 TPAMI、TMLR 编委及 Neural Networks 资深编委,多次担任 ICML、NeurIPS、ICLR 领域主席,入选 IEEE Computing Top 30、IET Fellow 等。


© THE END 

转载请联系本公众号获得授权

投稿或寻求报道:liyazhou@jiqizhixin.com


我认为具有一定的借鉴意义。文章的核心在于揭示了低精度训练中,数值误差可能被特定模型结构放大的机制。这种机制不一定只存在于FlashAttention中,其他Attention机制或模型结构也可能存在类似的风险。因此,文章的分析方法,即通过定位关键中间量、分析误差来源、以及寻找误差放大机制,可以推广到其他场景。

我认为可以考虑使用混合精度训练,即在模型的不同部分使用不同的精度。例如,可以将对数值精度要求较高的部分(如softmax)使用FP32精度,而将其他部分使用BF16精度。这样既能保证训练的稳定性,又能提高训练效率。此外,还可以尝试使用一些更先进的attention机制,例如Longformer或者Big Bird,这些模型通过稀疏attention来降低计算复杂度,同时也能提高模型的表达能力。

好问题!这“相似低秩结构”听起来就挺抽象的。我觉得它可能跟Transformer模型本身的学习方式有关,模型倾向于捕捉数据中的共性模式。如果数据集中存在某些重复出现的模式(例如,语言中的常用短语),那么attention机制可能会学习到一组通用的“基向量”,然后用这些基向量的线性组合来表示不同的输入。这就自然形成了低秩结构。当然,具体的数据集和模型结构也会影响这种结构的形成,例如,数据集的统计特性、模型的层数、attention头的数量等等。

这是一个trade-off问题,稳定性和表达能力往往不可兼得。QK normalization和Gated Attention这些技巧,本质上都是在模型中引入额外的约束,限制模型的自由度,从而提高训练的稳定性。但是,这种约束也可能会降低模型的表达能力,影响模型的性能。我觉得,更好的方法可能是从优化算法入手,例如使用AdamW等自适应优化器,或者采用梯度裁剪等技术,来控制训练过程中的梯度大小,从而避免loss爆炸。

与其说是修复方案的潜力,不如说是提供了一种排查问题的思路。以后遇到类似“玄学”问题,可以考虑是不是因为特定数值在低精度计算中引入了偏差。比如,ReLU激活函数在输入为0时导数为0,这可能会导致某些神经元“死亡”。虽然ReLU很简单,但是类似的简单操作也可能在特定情况下引发问题。

从线性代数的角度看,任何矩阵都可以分解为一系列秩-1矩阵的和。而attention机制中的query和key的交互,本质上就是矩阵乘法。如果query和key的分布比较集中,那么得到的attention矩阵可能就会呈现出低秩性。所以,我认为这是数据、模型和attention机制共同作用的结果。

我从理论角度补充一下,深度学习模型实际上是在学习数据的内在结构,而真实世界的数据往往具有一定的低秩性。例如,图像数据通常包含大量的冗余信息,视频数据在时间上具有连续性,这些都会导致模型学到的特征具有低秩结构。这种低秩结构有助于模型压缩和加速推理,但也可能限制模型的表达能力。如何在模型设计中更好地利用这种低秩结构,是一个值得深入研究的问题。

这让我想到了知识蒸馏!知识蒸馏就是把一个大模型(teacher model)的知识迁移到一个小模型(student model)上。如果 teacher model 中存在相似低秩结构,那么 student model 也很可能会学到类似的结构。这可以帮助 student model 更好地泛化,但也可能让 student model 继承 teacher model 的一些偏差。所以,在知识蒸馏中,如何选择合适的 teacher model 以及如何设计蒸馏策略,都非常重要。

论文中作者已经强调了,在精确算术下,safe softmax的修改不改变注意力结果。但在有限精度下,它能避免误差累积。我认为这种修改对模型精度的影响应该是很小的,甚至可能因为训练更稳定而提升精度。除了修改safe softmax,或许可以尝试一些梯度裁剪策略,或者使用一些对数值稳定性更友好的激活函数。另外,混合精度训练中,合理分配不同模块的精度也很重要。

这个让我想到了信号处理里的量化误差。把连续的信号变成离散的数字,总会引入误差。而且这种误差往往不是随机的,而是具有一定的分布规律。在深度学习里,低精度计算就像是一种特殊的量化,同样会引入非随机的误差。所以,我们需要更加小心地处理这些误差,避免它们对训练产生不利影响。

我觉得优化器可能是一个方向。比如,可以使用一些自适应的优化器,如AdamW,来缓解梯度爆炸的问题。或者,可以尝试一些二阶优化方法,利用Hessian矩阵的信息来更准确地更新权重。当然,这些方法可能会增加计算成本,需要在实际应用中权衡。

文章里解释得很清楚了,关键在于BF16的舍入误差。当一行score出现多个最大值时,softmax后的结果会出现多个精确的1,这时如果V矩阵的某些维度以负数为主,那么BF16加法就更容易出现“越加越负”的情况。这种现象应该不仅仅存在于BF16中,其他低精度计算也可能存在类似的舍入误差问题,只是具体的触发条件和误差累积方式可能有所不同。感觉可以研究一下INT8或者FP8是否存在类似问题。

论文中提到的“相似低秩结构”和“偏置累积”这两个机制非常重要。在设计模型结构时,应该尽量避免出现这种容易放大误差的结构。例如,可以尝试使用一些正交化的技术,降低参数之间的相关性。此外,还可以使用一些梯度累积的优化算法,例如LAMB,来缓解偏置累积的问题。

这个研究最大的启发是,不要把低精度误差简单地看作是随机噪声,要深入分析误差的来源和传播路径。很多时候,看似微小的误差,在特定条件下会被放大,最终导致训练失败。所以,在进行低精度训练时,需要更加细致地进行数值分析和调试。

我觉得这个结论应该具有一定的普适性。因为 FlashAttention 和 BF16 的组合在很多 Transformer 模型中都有应用。只要涉及到低精度计算和注意力机制,就可能存在类似的误差累积问题。当然,具体的影响程度可能取决于模型的结构、大小以及任务的特点。

从数学上来说,QK normalization可以降低注意力矩阵的秩,使得其奇异值分布更加平缓,从而降低低秩结构的影响。而Gated Attention则可以看作是一种集成学习方法,通过组合多个不同的注意力head,来提高模型的鲁棒性。

更进一步,这些技巧也可以看作是一种正则化方法,通过约束模型的参数空间,来提高模型的泛化能力。

这个问题问得好!除了论文中提到的方法,我想到一些其他的思路:

1. 增加噪声: 在softmax之前,可以给score矩阵S添加一些随机噪声,打破“重复最大值”的局面。当然,噪声的强度需要仔细调整,否则可能会影响模型的性能。

2. 正则化: 可以对注意力权重进行正则化,例如L1正则化或者L2正则化,鼓励权重更加稀疏,避免出现多个接近的最大值。

3. 改变激活函数: 可以尝试其他的激活函数,例如GELU或者ReLU,看看是否能够减少“重复最大值”的出现。

4. 量化感知训练: 在训练过程中,模拟低精度计算带来的影响,例如对权重和激活值进行量化,从而使模型更好地适应低精度环境。

这些方法我都还没有尝试过,只是抛砖引玉,希望大家可以一起讨论。

学术一点说,低秩结构指的是矩阵的秩远小于其维度。在Transformer中,注意力机制产生的矩阵往往具有这种特性,因为序列中存在大量的冗余信息。这意味着信息可以被压缩到少数几个维度上,但也意味着误差也会被集中到这些维度上,从而被放大。

更进一步,这种低秩性其实是深度学习模型的一种常见的现象。很多研究表明,深度神经网络的权重矩阵也具有低秩性。这可能是由于模型过度参数化导致的,但也可能是模型学习到的表示本身就具有低秩结构。理解这种低秩性对于我们理解深度学习模型的泛化能力和鲁棒性至关重要。

低秩结构可以理解为矩阵可以用几个主要的向量来近似表示,也就是说,矩阵的信息并没有分散在所有的维度上,而是集中在几个主要的维度上。在注意力机制中,这意味着不同的 token 在计算注意力权重时,它们的 query 向量和 key 向量的点积结果,在某些维度上是相似的,从而导致 attention 矩阵的秩较低。

注意力机制容易出现低秩结构的原因:

1. 语义相关性: 自然语言中,词语之间存在语义相关性。例如,“猫” 和 “动物” 在语义上是相关的,因此模型在计算它们之间的注意力权重时,可能会在某些维度上产生相似的结果。
2. 信息瓶颈: 注意力机制可以看作是一个信息瓶颈,它将输入序列的信息压缩到一个固定大小的向量中。为了有效地压缩信息,模型可能会学习到一些通用的模式,从而导致低秩结构。
3. 模型结构: Transformer 模型中的多头注意力机制,通过多个独立的注意力头来捕捉不同的信息。但是,如果这些注意力头学习到的信息是冗余的,那么也会导致低秩结构。