突破算力瓶颈:上下文并行与Ring Attention助力大模型训练百万Token

探索大模型如何突破百万Token上下文限制:上下文并行与Ring Attention技术解析。

原文标题:大模型如何训练百万 Token 上下文:上下文并行与 Ring Attention

原文作者:数据派THU

冷月清谈:

本文深入探讨了大模型训练中如何突破上下文窗口限制的关键技术,特别是针对百万Token级别上下文的处理方案。主要介绍了上下文并行Ring Attention两种技术,它们通过将序列在多个GPU上切分,并优化GPU间的通信,有效降低了内存占用,解决了传统并行策略在处理超长上下文时遇到的瓶颈。文章还对比了序列并行与上下文并行,阐述了Zig-Zag Ring Attention在负载均衡方面的优势,并对训练百万Token上下文模型所需的硬件配置进行了建议。总而言之,这些技术的出现,使得处理大规模代码库、海量论文和长时间对话记录成为可能。

怜星夜思:

1、Ring Attention通过环形传输K/V,实现了计算和通信的重叠,但这种方式是否会引入新的延迟?如果环上的GPU数量过多,会不会导致等待时间过长,反而降低效率?
2、文章提到Zig-Zag Ring Attention解决了因果注意力机制下的负载不均衡问题,那么,在非因果注意力场景下(比如BERT),是否还需要Zig-Zag的策略?或者说,有没有其他更适合非因果场景的优化方案?
3、上下文并行虽然能扩展上下文长度,但同时也带来了更高的通信成本。除了采用更快的互连技术(如NVLink、InfiniBand)外,还有没有其他方法可以降低上下文并行带来的通信开销?例如,能不能通过一些压缩或量化的方法来减少传输的数据量?

原文内容

图片
来源:DeepHub IMBA
本文约2000字,建议阅读5分钟
上下文并行本质上是拿通信开销换内存空间,而网络带宽是最要命的瓶颈。


只用了几年时间,上下文窗口就从 4k  膨胀到 1000 万。Meta 发布的 Llama 4 Scout 的时候说这个模型支持 1000 万 Token,是 Llama 3 那 128k 的 78 倍。而Google Gemini 3 Pro 是 100 万,Claude 4 也桐乡市100万。

一次推理跑完整个代码库、几百篇论文、连续好几天的对话记录在技术上可行了,但问题是硬件跟不上。

405B 参数的模型,32 位精度下光权重就要 6.5TB 内存。再算上梯度、状态、激活值,后者还随上下文长度二次方增长。单台 NVIDIA HGX B300 配了 2.3TB HBM3e都不够。

这就逼着必须做多节点分布式训练和推理,几十上百块 NVIDIA Blackwell GPU 、NVLink 再加上 InfiniBand,就成了数据中心的标配。所以难点就变味了 GPU 之间的通信瓶颈。

并行化基础


模型或数据集超出单卡容量,就得上并行策略,但是每种策略本质上都是拿通信开销换内存空间。

数据并行是最直接的方案:整个模型复制到每张卡上,训练数据切开,每张卡跑不同的 batch跑完一步同步梯度。适合小模型,计算是瓶颈、内存不是问题的场景。

模型并行针对大模型:单卡装不下,就把模型拆开,不同的层放不同的卡上,按顺序跑。405B 这种规模只能这样,并且下游的卡得等上游算完中间是有空转的。

张量并行更极端:连单个矩阵乘法都塞不进一张卡。就需要把矩阵按行或按列切开,分到各卡上算,再通过 all-reduce 合起来。

但这些都有共同的局限。模型大、上下文又长到几百万 Token,张量并行也顶不住。因为注意力的二次方内存增长太凶,激活值直接占满显存。128k 上下文的激活值内存是 8k 的 16 倍,这个目前没办法,因为就是这么夸张。

上下文并行与序列并行


序列并行和上下文并行都是在设备间切序列来省内存,但切法不一样。

序列并行配合张量并行使用,只切那些非矩阵乘法的操作,比如层归一化、dropout。张量并行管不到的地方,序列并行接手,每张卡处理一部分激活值。两者配合能把序列撑长一些,但到 128k 以上还是会有问题,因为注意力的二次方增长是绕不过去。

上下文并行更彻底:整个序列在所有模块里都切开,包括注意力。每个操作拿到的都是分区后的序列。百万级上下文的训练就靠这个,把激活值的内存占用分摊到各卡上。

注意力一直是最麻烦的问题,因为模型的其他操作基本都是逐 Token 独立处理并行起来很自然。但注意力不行,每个 Token 都要"看"序列里所有其他 Token。序列切到多张卡上之后,GPU 1 的 Token 怎么看 GPU 2 的 Token?直接等数据传完再算,整个流水线就卡住了。

Ring Attention 就是来解决这个问题的,让多节点多卡的大模型训练和推理能在大规模数据中心里跑起来。

Zig Zag Ring Attention:通信和计算重叠


Ring Attention 把 GPU 组织成环形拓扑。每张卡的工作流程是这样的:持有序列中 Q、K、V 张量的一个分块;用本地的 K 和 V 给自己的 Q 分块算注意力;把 K 和 V 传给环里的下一张卡;从上一张卡接收 K 和 V;循环往复,直到所有 Q Token 都跟所有 K/V Token 算完注意力。

关键在于计算和通信是重叠的。GPU 1 拿着当前的 K/V 分块算注意力的时候,同时在从 GPU 0 接收下一批分块。通信延迟减少了,因为不用干等数据全到了再开算。

GPT 这类自回归模型有个额外的麻烦:Token 只能看前面的 Token不能看后面的。所以会导致负载不均衡有些卡会空转,Zig-Zag Ring Attention 解决这个问题的办法是交错分配,不是按顺序切块而是 GPU 0 拿 Token [0, 4, 8...],GPU 1 拿 [1, 5, 9...],以此类推。每张卡都拿到早期和晚期 Token 的混合,因果注意力计算时负载就均衡了环里不会有卡闲着。

但是代价是索引逻辑稍微复杂一点,不过大规模场景下性能收益很可观,因果掩码下也能做到接近满 GPU 利用率。

上下文并行与 Ring Attention 常见问题


上下文并行把输入序列切到多张 GPU 上,突破训练时的内存限制。跟张量并行、数据并行不同,它在所有模型模块里都切序列维度。单卡装不下的百万级 Token 上下文,只有靠这个才能训。

Ring Attention 把 GPU 排成环,每张卡一边算当前数据的注意力,一边把键值对往下传。通信和计算重叠,全对全的注意力计算不用等完整序列数据到齐,GPU 不会干等。

而序列并行只切非矩阵乘法操作(层归一化之类的),配合张量并行用。上下文并行在所有模块里都切序列,包括注意力。超过 128k Token 的上下文必须用后者,因为激活值内存二次方增长太猛了。

为什么 Zig-Zag Ring Attention 比标准 Ring Attention 更好?

Zig-Zag 用交错分配代替顺序分配,因果掩码计算时各卡负载更均衡。标准 Ring Attention 会让后面的卡等前面的分块,造成计算空闲。Zig-Zag 把早期和晚期 Token 均匀撒到各卡上,避免这个问题。

那么训练百万级 Token 上下文的模型需要什么硬件?

多节点 GPU 集群,配 HBM 内存,加高速互连——NVIDIA NVLink 1.8TB/s 或者 InfiniBand。405B 参数模型 32 位精度从头训练加推理,4 台 NVIDIA HGX B300 的机架部署是个不错的起点。

总结


上下文并行本质上是拿通信开销换内存空间,而网络带宽是最要命的瓶颈。Ring Attention 要在 GPU 之间不停交换键值对,传输时间一旦超过计算时间,各卡就会从"边算边传"退化成"等数据"。NVIDIA NVLink 1.8TB/s 加 InfiniBand 的高速互连,在多机架部署里不是可选项是必需品。互连带宽必须匹配 GPU 计算吞吐量,否则上下文并行的效果会大打折扣。

by Khang Pham

编辑:文婧



关于我们

数据派THU作为数据科学类公众号,背靠清华大学大数据研究中心,分享前沿数据科学与大数据技术创新研究动态、持续传播数据科学知识,努力建设数据人才聚集平台、打造中国大数据最强集团军。




新浪微博:@数据派THU

微信视频号:数据派THU

今日头条:数据派THU


我觉得这个问题要辩证地看。一方面,对于参数量很小的模型,可能单卡GPU就足够了,并行化带来的收益可能还不如同步和通信的开销。另一方面,如果硬件资源比较紧张,或者需要加速推理过程,即使是小参数模型,也可以考虑使用并行策略。比如,可以使用TensorRT等工具对模型进行优化,然后利用多卡进行并发推理,从而提高吞吐量。总之,要根据实际情况进行权衡,选择最合适的策略。

这个问题问得好!Ring Attention虽然巧妙地重叠了计算和通信,但在GPU数量非常多的情况下,确实可能引入额外的延迟。想象一下,每个GPU都要等一圈才能拿到所有K/V的信息,如果这个环太大,通信时间就会变得很显著,反而不如直接传输快了。所以,实际应用中需要根据GPU的性能和网络带宽,仔细权衡环的大小,找到一个最佳平衡点。

我觉得可以参考一下联邦学习里的差分隐私技术,在保证一定隐私性的前提下,降低需要传输的数据量。 不过这块我了解的不是很多,不知道在上下文并行里是否适用,抛砖引玉了。

BERT这种非因果模型,本身就可以并行计算,不存在需要’看前面的token’才能计算的问题,所以感觉没必要用Zig-Zag了。 个人觉得可以考虑针对attention矩阵本身的稀疏性做优化,减少不必要的计算。现在很多attention加速的工作,比如Sparse Attention,Linformer,都是这个思路。

我不太同意楼上的观点。即使在非因果注意力场景下,Zig-Zag策略仍然可能带来一些好处。例如,它可以提高GPU的利用率,减少空闲时间。此外,Zig-Zag还可以降低通信延迟,因为每个GPU都可以更快地获取到所有需要的信息。当然,具体是否需要Zig-Zag策略,还需要根据具体的模型和数据进行评估。

楼上的分析很到位!我认为可以从两个方面来看待这个问题:一是优化通信本身的效率,比如采用更快的NVLink或InfiniBand互连;二是改进Ring Attention的算法,例如通过更智能的调度策略,减少GPU的等待时间。此外,还可以考虑将Ring Attention与其他并行策略结合使用,以进一步降低延迟。