清华大学揭示BF16 FlashAttention训练崩溃之谜:极简修复方案助力稳定训练

这个问题很有意思!感觉Transformer里各种Normalization层,比如LayerNorm,也挺容易出问题的。Normalization本质上是对数据分布做调整,如果低精度下统计量(均值、方差)计算有偏差,或者Normalization本身引入了额外的舍入误差,都可能对后续的梯度产生意想不到的影响。诊断的话,我觉得可以像这篇论文一样,先定位到具体的层或者模块,然后仔细分析中间变量的数值分布,看看有没有出现异常的偏置。

感觉这个现象挺常见的,尤其是在上下文比较长的时候。因为模型倾向于给一些“重要”的token分配比较高的attention score,而这些重要token可能不止一个。从数据层面来说,可以尝试做一些数据增强,比如随机mask掉一些token,或者用同义词替换,来增加数据的多样性,避免模型过度依赖某些特定的token组合。模型设计层面,可以考虑用sparse attention或者longformer之类的结构,来降低attention矩阵的秩,减少出现多个相同最大值的概率。

我最近在做相似文本检索,发现余弦相似度也经常出现多个最大值的情况。这可能和embedding的维度有关,如果embedding维度太低,不同文本的表示就容易挤在一起。所以,我觉得增加模型的capacity,比如增加embedding的维度或者增加模型的层数,也可能有帮助。

从原理上讲,safe softmax 通过引入一个小于1的缩放因子,避免了softmax输出出现绝对的1,从而减轻了后续计算中BF16舍入误差带来的累积效应。这就像在电路中增加了一个限幅器,防止信号过强导致失真。

个人认为,更通用的思路是提升硬件本身的计算精度。例如,有些新型加速器开始支持更灵活的混合精度计算,可以根据不同的计算需求动态调整精度,从而在保证效率的同时,提高数值稳定性。此外,还可以探索一些新的数值表示方法,例如 posit,它在小数值和大数值上都具有更高的精度,有望在低精度训练中发挥更大的作用。

我感觉这种方法比较巧妙,因为它只是在数值计算层面做了一些微调,并没有改变模型的结构和参数。如果调整的幅度比较小,应该不会对模型的表达能力产生太大的影响。不过,在实际应用中,最好设置一个合适的阈值,避免过度调整行移位常数,导致模型性能下降。

我觉得这反映了我们对大模型内部机制的理解还远远不够深入。很多时候,我们只能凭借经验和直觉来解决问题,缺少一种系统性的方法论。未来,我们需要加强对模型内部表示、梯度流动等方面的研究,建立更加完善的理论框架,才能更好地诊断和解决这类“玄学”问题。

这个问题有点硬核啊!safe softmax 这个东西,我理解是为了防止overflow。但是没想到在 BF16 下,好心办坏事,反而成了bug的诱因。这告诉我们,工程上的trick,还是要知其所以然,不然就容易踩坑。话说回来,这个结论的得出,感觉还是实验+经验主义。更深入的理论分析,可能要请教数学大佬了。

同意楼上的观点,低秩结构相似性本质上是信息冗余的一种表现。除了RNN和CNN,我觉得像自编码器(Autoencoder)这种试图学习输入数据低维表示的模型,以及生成对抗网络(GAN)中的生成器和判别器,也可能因为学习到的低秩表示而放大误差。此外,一些模型压缩技术,比如知识蒸馏,如果学生模型过度拟合教师模型的低秩表示,也可能导致误差累积。关键在于找到误差放大的“共振点”。

我觉得这个修复方法有点像“打补丁”,虽然有效,但不够优雅。一个更通用的思路是,从优化算法入手,例如使用二阶优化算法(如AdamW)或者一些自适应学习率调整策略,来更好地处理低精度训练中的梯度噪声。另外,也有研究表明,使用不同的初始化方法可以改善低精度训练的稳定性。总之,这是一个系统性的问题,需要综合考虑模型结构、优化算法和硬件特性。

理论上,这个方法在精确算术下不改变结果,但在有限精度下,微小的改变确实可能对模型的表达能力产生影响。不过,文章的实验结果表明,这种影响是可接受的,至少在GPT-2上是这样。更通用的解决方案可能包括:1)使用更高精度的浮点数进行计算(但这会降低训练效率);2)引入更强的正则化,防止权重过度集中;3)探索更鲁棒的注意力机制,例如使用不同的归一化方法或者引入随机性。

从信息论的角度来看,所谓“系统性偏置”,本质上是模型学习到了数据中隐藏的 correlation。要缓解这种偏置,就需要在训练过程中引入更多的“信息熵”。比如,可以使用 dropout、mixup 等技术,增加模型学习的随机性。另外,还可以尝试使用对比学习等方法,让模型学习到更加普适的特征表示。

换个角度想,这篇文章其实是在告诉我们,不要想当然地认为低精度训练带来的误差就是随机噪声。在某些特定情况下,这些误差可能会被放大,导致严重的后果。所以,在使用低精度训练时,我们需要更加谨慎,要对模型进行充分的测试和验证,确保模型的稳定性和可靠性。同时也提醒我们,在优化模型性能的同时,也要关注模型的鲁棒性,避免出现“一招鲜,吃遍天”的情况。

从信息论的角度来看,这些稳定化技巧可能是在增加模型学习过程中的熵。QK normalization 和 Gated Attention 相当于在 attention 分布中引入了更多的随机性,降低了信息集中度,从而减弱了误差的累积效应。其他一些dropout、label smoothing等技巧,本质上也是在增加模型的熵,提高模型的鲁棒性。但是,熵也不是越高越好,过高的熵可能会导致模型学习效率降低。所以,如何在稳定性和学习效率之间取得平衡,是一个需要权衡的问题。

QK normalization,简单来说就是对Query和Key进行归一化,这样可以控制attention score的分布,避免出现某些score过大,从而降低低秩结构中相似性的影响。而Gated Attention则通过引入门控机制,动态地调整attention的权重,从而打散结构相似性,阻止误差累积。其他有效的稳定训练技巧还有很多,比如梯度裁剪、权重衰减、Layer Normalization等等。梯度裁剪可以限制梯度的最大值,避免梯度爆炸;权重衰减可以防止权重过大,提高模型的泛化能力;Layer Normalization可以加速训练,提高模型的鲁棒性。总的来说,稳定训练是一个系统性的问题,需要综合考虑模型结构、优化算法、数据预处理等多个方面。

我觉得这项研究最大的价值在于,它提供了一种诊断和解决低精度训练中数值不稳定问题的思路。以前我们遇到这类问题,往往只能靠经验或者“炼丹”,很难找到根本原因。而这篇论文通过严谨的实验和分析,一步步地缩小问题范围,最终找到了问题的根源,并提出了有效的解决方案。这种研究方法对于我们未来解决大模型训练中的其他问题,也具有重要的借鉴意义。

楼上说的有道理,我补充一点。这个“相似低秩结构”本质上是模型中的一些参数,它们的变化会高度相关。这种相关性可能源于模型的固有结构,也可能源于训练数据。比如,在处理语言数据时,某些词语或短语可能会频繁共现,导致模型在学习这些词语或短语的表示时,参数更新方向趋于一致。除了CNN和RNN,一些图神经网络也可能存在类似的问题,因为图结构本身就可能引入节点之间的相关性。我觉得理解这一点,对于我们设计更鲁棒的模型结构很有帮助。

我理解的“打散结构相似性” 就是减少模型在更新时的“共振”现象。你想啊,如果大家都朝着一个方向使劲,一旦方向错了,那就会错的离谱。QK normalization 像是给高亢的声音加了个均衡器,让大家的音量更平均,Gated Attention则像是引入了“调音师”,时不时调整一下各个声部的音量,避免共振。除了这些,还有一些“玄学”操作,比如warm up,让模型在初期慢慢适应数据,避免一下子被“冲垮”。