DeepSeek-V3核心架构解析:轻量级模型MLA与MoE实现指南

深入解析Mini DeepSeek-V3模型MLA与MoE核心架构,详释高效训练与推理实现。

原文标题:原创 | 一文读懂DeepSeek-V3(上)

原文作者:数据派THU

冷月清谈:

本文深入剖析了在有限算力下构建Mini DeepSeek-V3模型架构的核心实现细节。文章重点介绍了多头低秩注意力(MLA)在查询、键值、矩阵吸收及输出计算等前向传播关键步骤,揭示其如何通过优化提升计算效率。接着,详细阐述了混合专家模型(MoE)的组成,包括MLP、专家和门控网络的协同工作。尤其重要的是,文中深入解析了无辅助损失的负载均衡策略以及序列级辅助损失如何在MoE部分实现,以确保专家利用率的均衡,并有效提升模型训练与推理的效率。文章为理解DeepSeek-V3的底层机制提供了详尽的代码级实现思路。

怜星夜思:

1、文章里提到DeepSeek-V3的MLA通过低秩压缩和权重吸收来提高效率,感觉挺巧妙的。但在实际搞大模型的时候,除了DeepSeek-V3这种思路,你觉得还有哪些attention机制的优化,能显著提升推理速度或者减少显存占用,特别是面对超长上下文或者大Batch Size的时候?大家有没有在项目里尝试过什么黑科技?
2、MoE模型里面专家负载均衡一直是个老大难问题,DeepSeek-V3这篇提了无辅助损失和序列级辅助损失两种策略。抛开这些具体实现,你觉得未来MoE模型的负载均衡还有哪些更具创新性或者更普适性的优化方向?比如,能不能根据不同任务动态调整专家分配,或者有啥能确保稀疏激活下性能最大化的新招?
3、文章教我们怎么用小算力实现Mini DeepSeek-V3,对于我们这些资源有限的个人开发者太有用了!除了文章里说的缩小模型规模,大家在尝试复现或者学习这种复杂大模型的时候,还有哪些特别实用的小技巧或者工具,能帮助我们更好地在自己的笔记本或者单张显卡上跑起来并真正理解它们?求分享实用经验!

原文内容

作者:王坤擎
本文约6500字,建议阅读10+分钟

本文主要介绍如何利用较小算力资源,实现一个Mini-DeepSeek-V3模型架构,需要对DeepSeek-V3相关理论具有一定了解。

 

完整的DeepSeek-V3涉及到庞大的工程优化和算力资源,本文旨在基于较小的算力,从MLADeepSeekMoEMTP、无辅助损失的负载均衡策略、序列级辅助损失等核心架构来实现一个Mini-DeepSeek-V3的训练和推理Demo。由于模型规模较小,采用单卡或DDP的方式训练,且不考虑使用YaRN来扩展上下文长度,使用原始RoPE


一、MLA实现


为了解释清楚MLA的具体实现,我们将代码拆解成小块逐个讲解。首先,给出MLA的整体实现框架:



(一)定义和初始化


创建MLA类,所需要传入的参数如下:



下面的表格更加直观的列出上述参数和论文中变量的对应关系:



捋清楚这些变量的含义非常重要,结合MLA的理论,就基本可以构想出后续forward方法的实现。


(二)MLA的前向传播


首先给出完整代码:



forward方法接收的参数包括输入序列x、当前推理步骤的起始位置start_pos、复数RoPE矩阵freqs_cis和掩码mask。接下来,分块来一步一步看forward代码,核心代码可分为五部分。


1. query部分



1行:对原始向量依次进行下投影、RMSNorm和上投影+解耦,对应的公式为:



输出变量q实际上是:


2行:划分注意力头数,即:


图片

其中,

是最终执行注意力计算的每个query头,且此时解耦的部分还未加入RoPE


3行:将

切分为不需携带位置编码的q_nope和需要携带位置编码的q_pe两部分。


4行:为需要携带位置编码的q_pe应用RoPE位置编码,其形状不改变。

至此,完成了下图中红框所示的部分:



2. key/value部分



1行:对原始向量下投影+解耦,即:



输出变量kv实际上是:


2行:将

切分为潜在向量kv和需要携带位置编码的k_pe两部分。


3行:首先为k_pe添加头数维度,从而适应apply_rotary_emb函数,然后为其应用RoPE位置编码,形状不改变。


至此,完成了下图中红框所示的部分:


3.第一次矩阵吸收



1行:获取self.wkv_b的权重矩阵,在nn.Linear中,权重矩阵的形状为(out_features, in_features),因此wkv_b的形状为(n_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank)


2行:进一步将wkv_b变为(n_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank)。我们知道,wkv_b实际上是

,这一行的目的是为了方便后续切分出,它们分别需要吸收到中。


3行:在query部分的代码实现中,我们得到了q_nope: (batch_size, seq_len, n_heads, qk_nope_head_dim),它实际上是



因此有:



其中,对于第i个注意力头,

的分块矩阵。同理,正常来说,在key/value部分,本应计算出k_nope,类似的应该有:



当我们聚焦于第i个注意力头的计算时,假设位置tquery对位置jkey做点积计算,会有:



其中,图片是每个注意力头

吸收了得到的新矩阵。因此,对所有注意力头而言,有新的,那么在矩阵吸收之后,新的q_nope变为图片,即新q_nope的形状变为:(batch_size, seq_len, n_heads, kv_lora_rank)


因此,矩阵吸收后,在key/value部分中是不用计算出k_nope的,直接使用潜在向量kv即可。上述第3行代码是通过爱因斯坦求和约定:einsum来实现矩阵吸收这一过程的。


至此,完成了下图中红框所示的部分:

4.注意力实现



我们从推理部分的计算开始看起:


6行:为kv应用RMSNorm后,缓存至kv_cache

7行:在key/value部分中,为了给k_pe添加位置编码,为其添加了head维度,形状变为(batch_size, seq_len, 1, qk_rope_head_dim),因此这里将head维度去除后,缓存至pe_cache

89行:计算注意力分数,这里没有将无位置信息的nope部分和有位置信息的pe部分拼接起来再计算注意力,而是分别计算nope部分和pe部分的点积,然后相加,并乘以注意力计算的缩放因子。这是为了避免不必要的数据移动和冗余计算,从而提高计算效率。

1011行:应用掩码mask

12行:对最后一个维度应用softmax,将点积转换为权重。

训练阶段(第1-4行)与推理阶段类似,只是不用进行缓存操作。


至此,完成了下图中红框所示的部分:


5.计算输出+第二次矩阵吸收



同样从推理阶段开始看(训练阶段同样只是不用执行缓存操作):


4行:在计算输出时,会进行

的矩阵吸收。仍聚焦于第i个注意力头,对位置t的输出,有:



其中,

表示在第i个注意力头中,当前位置t对位置j计算得到的注意力权重。为每个注意力头的输出向量。因此,最终的输出为:



其中,各变量的形状为:



4行代码首先计算了scoreskv_cache相乘

5行:从wkv_b中切分出wkv_b[:, :self.v_head_dim],即形状为(n_heads, v_head_dim, kv_lora_rank)的部分,它代表了每个头的

6:将上述结果进行输出维度转换,即(batch_size, seq_len, n_heads, v_head_dim)转换为(batch_size, seq_len, dim)


至此,完成了下图中红框所示的部分,MLA部分代码完成:


 

二、MoE实现


MoE部分主要包括四个类,分别是MLPExpertGateMoE无辅助损失的负载均衡策略序列级辅助损失均在此部分实现,由于源代码未开源训练部分,这两块由本人根据论文理解实现,仅供参考


(一)MLPExpert


DeepSeek-V3源码中,MLPExpert类的结构是完全一致的,只是做了用处的区分。原DeepSeek-V3中,前3层是Dense Layer,这是因为前面几层的负载均衡收敛较慢,MLP用于构建Dense Layer的前馈网络。此外,MLP也用于实例化共享专家。而Expert则专门用于实例化路由专家。这里只列举MLP的代码如下:



其中,F.silu函数是β=1时的SwiGLUSwiGLU结合了SwishGLU两者的特点。


1. Swish


Swish是一个非线性激活函数,定义如下:



其中,β为可学习参数。Swish可以比ReLU激活函数更好,因为它在0附近提供了更平滑的转换,这可以带来更好的优化。下图为不同β值对应的Swish激活函数图像:



2. GLU


GLUGated Linear Unit)定义为两个线性变换的分量积,其中一个线性变换由sigmoid激活。它其实不算是一种激活函数,而是一种神经网络层。它是一个线性变换后面接门控机制的结构。其中门控机制是一个sigmoid函数用来控制信息能够通过多少,定义如下:



LLM中常用的SwiGLU其实就是采用Swish作为激活函数的GLU变体:


图片

使用SwiGLU函数构造一个前馈网络,不使用偏置项,有:


图片

其结构如下图所示:



(二)Gate


Gate主要用于动态路由,其代码如下:



Gate类中,我们的目标是返回当前token选中专家的门控权重,和选中专家对应的索引,从而进行下一步计算。同时,在选中专家的过程中,我们会应用无辅助损失的负载均衡策略,即为亲和度得分添加一个根据过往专家负载情况来更新的偏置bias。此外,原文中除了使用无辅助损失的负载均衡策略,还使用了节点路由限制,一方面是为了保证不同节点的负载均衡,另一方面也是为了节省通信开销。由于本文的模型规模较小,因此所有专家都在一个GPU上。但仍可以通过类似的思想来实现专家选择上的负载均衡。


我们首先列出Gate运算的大致过程:


  • 对专家进行分组n_groups个组;

  • 每个组计算2个最大亲和度得分之和,其中,亲和度得分可以使用bias来调整;

  • 根据上述结果,选出得分最大的topk_groups个组;

  • 从上述topk_groups个组的所有专家中选出topk个专家,也就是最终需要激活的专家。


需要注意的是,Gate中的无辅助损失的负载均衡策略属于应用部分,即只负责给亲和度得分加入biasbias更新部分的逻辑我们在MoE中实现。bias使用nn.Parameter初始化为0,因此它会作为模型参数的一部分,但是它是不需要梯度的,因为bias的更新逻辑实际上是根据过往的专家负载情况来动态更新的,而不是通过loss


我们主要看前向传播部分的代码,可以大致分为三个部分


1.计算亲和度分数


forward的输入张量形状为(batch_size * seq_len, dim),在输入前已经在外部调整好了形状,后续代码中我们会看到。


图片


为第ttoken的输入,这两行代码对应的公式为:



其中:

-

token与第i个专家之间的亲和度得分,即某个token被分配给某个专家的概率或权重;

-

:第i个路由专家的质心向量,用于衡量token和专家的匹配程度。

可见,这里的质心向量实际上就是初始化的self.weight,它会在训练中学习到。scores的形状为(batch_size * seq_len, n_routed_experts)


图片


这里对求得的scores进行了两个赋值。第一个赋值是因为在后续的MoE中,实现序列级辅助损失时,需要用到token对专家的原始得分,因此需要保存下来,以便后续使用。第二个赋值是为了避免后续对scores的原地操作,从而导致梯度回传时出现问题。


图片


这里对原始分数加上了bias,从而能够影响后续对专家的选择。


2.专家分组


图片

若对专家进行了分组,则将原始得分的形状由(batch_size * seq_len, n_routed_experts)变为(batch_size * seq_len, n_groups, n_routed_experts_per_group)



如果没有应用无辅助损失的负载均衡策略,就选取每一组最大的得分作为这组的得分,如果使用了无辅助损失的负载均衡策略,就选组每一组top 2的得分之和作为这一组的得分。


图片


topk函数会返回一个元组,即(values, indicies),这里从所有组的得分中,选出topk_groups个组的索引,后续将从这几个组的所有专家中,选出最终的topk个专家。


图片

首先创建一个形状为(x.size(0), self.n_groups),即(batch_size, n_groups)的全True(全1)的mask张量。然后使用scatter_()函数将选中的组标记为Falsescatter_()的作用是:


- index 指定的位置,将 value 的值填充到目标张量。

-沿dim 维度 进行填充(例如 dim=0 按行,dim=1 按列)。

这样,mask对选中的组为False,对未选中的组为True


图片


最后,首先将mask增加最后一个维度,变为(batch_size * seq_len, n_groups, 1),以适应分数张量(batch_size * seq_len, n_groups, n_routed_experts_per_group),将对应maskTrue的,也就是未选中的组的所有专家得分置为负无穷,这样就只保留了选中组的所有专家的得分,并展平为(batch_size * seq_len, n_routed_experts)


3.计算权重和索引


图片


从选中组的所有专家中,选出topk个专家的索引,这就是最终确定需要激活的专家。



gather()函数用于按照指定索引index和维度dim提取数据。提取出的数据形状为(batch_size * seq_len, topk),将这topk个专家得分进行归一化,并进行缩放(缩放self.rout_scale默认为1,根据需要调整),就得到了每个专家的权重。最后将权重和选中专家索引返回,用于下一步计算。


(三)MoE


MoE基于上述类构造,并加入了无辅助损失负载均衡策略的bias更新逻辑和序列级辅助损失逻辑,完整代码如下:



MoE的前向传播部分同样分割为个部分 


1.变量准备



这几行代码均用于准备或初始化后面需要用到的变量,其中x被重新划分为形状(batch_size * seq_len, dim),然后输入到Gate中,获取到的weightsindices形状均为(batch_size * seq_len, topk)global_counts用于记录每个批次里全局的专家激活次数情况,这里的“全局”意思是,如果使用DDP训练,记录的是所有GPU上专家激活次数的总和。


图片

bincount()函数用于计算非负整数张量中每个值的出现次数。indices.flatten()将每个token激活的专家索引由(batch_size * seq_len, topk)展平为(batch_size * seq_len * topk),参数minlength指定了输出张量的最小长度,使在当前的indices.flatten()中,某些专家索引可能一次都没有出现,设置minlength=self.n_routed_experts可以确保输出的counts张量长度一定等于总的专家数量。如果某个专家的索引i(其中i < self.n_routed_experts)在输入中没有出现,那么输出counts张量中对应位置counts[i]的值将是 0。综上,counts保存了一个batch里每个专家对应的激活次数。


2.无辅助损失负载均衡策略


无辅助损失负载均衡策略bias更新的过程如下图所示,该过程只在训练时使用:



在经过Gate后,本batch的专家负载情况就确定了,因此能够根据本轮的负载情况调整bias的值,从而使下一个batch的负载情况更加均衡。如果当前处于DDP训练环境,那么每个GPU的专家负载情况是不同的,那么每个GPU分别更新bias的值也会不同,因此要基于所有GPU的专家激活情况来统一确定如何更新bias。首先,将本GPU的负载计数counts拷贝给global_counts,如果当前处于DDP训练环境,就收集所有GPUglobal_counts,得到全局负载情况。最后计算全局所有专家的平均激活情况avg_counts



如果当前是DDP训练中的主进程,或者当前是使用单卡进行训练,那么计算平均负载情况与每个专家实际激活的差值,并基于此差值和偏置更新速度来计算新的bias值。最后,如果是DDP,就将这个新的bias广播给所有的GPU,这样就确保了每个GPU的模型参数更新是一致的。以下计算流程是源于论文Auxiliary-Loss-Free Load Balancing Strategy for Mixture-of-Experts中的算法:



3.序列级别的辅助损失


无辅助损失的负载均衡策略主要关注全局的专家负载均衡,确保在整个batch级别上,专家的负载相对均衡。但在序列级别(Sequence-Wise)上,仍然可能出现负载不均衡的情况。例如,一个输入序列中的多个token可能会集中分配给某些专家,导致这些专家在单个序列内负载过高。序列级别的辅助损失只在训练时计算。


首先回顾序列级辅助损失的计算公式:



其中:



首先将Gate的原始得分从(batch_size * seq_len, n_routed_experts)变为(batch_size, seq_len, n_routed_experts),此即原始的si,t,沿着n_routed_experts的方向归一化,形成s'i,t。而后沿着token的方向求平均,得到Pi
,形状为(batch_size, n_routed_experts),含义为第i个专家在一个序列中每个token上的平均归一化亲和度得分。



现在来计算fi,即第i个专家在一个序列中每个token上的平均激活次数。indices的初始形状为(batch_size * seq_len, topk),表示一个batch中每个token激活了哪些专家。现在我们需要计算的是在一个batch的每个序列中,每个专家被哪些token激活,可以使用one-hot编码来实现这一过程。


- 上述代码第1行:首先将indices形状变为(batch_size, seq_len * topk),而后使用one-hot编码,类别数为n_routed_experts,得到形状为(batch_size, seq_len * topk, n_routed_experts)ont-hot编码。


-上述代码第2行:沿着seq_len * topk维度相加后,可求出每个专家被多少个token激活,得到形状(batch_size, n_routed_experts),得到了

图片


- 上述代码第3行:乘以

系数,得到第i个专家在一个序列中每个token上的平均激活次数。


最后,根据

求得当前层MoE所计算出的序列级辅助损失。直观理解上, Pi可由调整模型权重来改变,而fi是由Pi导致的客观结果,专家的得分大,自然被激活的次数就多。因此,若在一个序列中,各个token最终激活专家 的频率很大,那么该专家的得分就应该减小,反之亦然,从而鼓励每个序列上的专家负载变得均衡。


4.计算专家输出



遍历counts,如果第i个元素不为0,说明第i个路由专家被激活了。indices的形状为(batch_size * seq_len, topk)torch.where(indices == i)用于找到激活了第i个专家的tokenidx代表行索引(即第几个token),top代表列索引(即该tokentop几选择),idxtop的类型为torch.Tensorx的形状为(batch_size * seq_len, dim),将x中索引为idxtoken输入到它激活的专家expert中,同时乘以其对应的权重weights[idx, top, None],将其赋值给前面初始化的y。遍历完counts之后,y中只保留的激活专家的输出值,未激活的则为0


图片


z计算出共享专家的输出,而后将共享专家和路由专家相加,并转换为原始的(batch_size, seq_len, dim)返回。此外,还返回了序列级辅助损失和全局专家负载情况,前者用于后续收集各层的总loss,最终用于梯度计算,后者用于输出到模型外部,记录分析每层专家的负载情况。


结语


项目具体的训练代码等在此不做介绍了,可在https://github.com/WKQ9411/Mini-LLM查看所有代码。最终效果如下:


演示视频

https://github.com/user-attachments/assets/af546e22-5c8a-4524-9bad-746909ed49d5

参考链接

1. [DeepSeek-V3 technical report]( http://arxiv.org/abs/2412.19437)

2. [DeepSeek-V2: A strong, economical, and efficient mixture-of-experts language model]( http://arxiv.org/abs/2405.04434)

3. [DeepSeekMoE: Towards ultimate expert specialization in mixture-of-experts language models]( http://arxiv.org/abs/2401.06066)

4. [Auxiliary-loss-free load balancing strategy for mixture-of-experts]( http://arxiv.org/abs/2408.15664)

5. [优雅地实现多头自注意力——使用einsum(爱因斯坦求和)进行矩阵运算](https://www.cnblogs.com/qftie/p/16245124.html)

6. [DeepSeek-V3 MLA优化全攻略:从低秩压缩到权重吸收,揭秘高性能推理的优化之道](https://zhuanlan.zhihu.com/p/25449691772)

7. [全网最细!DeepSeekMTPToken预测:从算法原理到代码实现](https://www.bilibili.com/video/BV1QEwReKEHg/?spm_id_from=333.1391.0.0&vd_source=8d0e80baab699baab100ac9fdf2c4028)

8. [deepseek技术解读(2)-MTPMulti-Token Prediction)的前世今生](https://zhuanlan.zhihu.com/p/18056041194)




编辑:黄继彦





欢迎大家扫码加入粉丝群




图片


欢迎在评论区留言与本文作者互动交流!



作者简介

王坤擎,国防科技大学智能科学学院| 控制科学与工程(认知科学创新实验室) 硕士。对大模型和智能交互技术抱有浓厚兴趣和持续学习的热情,致力于实现更高效、自然的智能交互喜欢深入研究探索新鲜技术,记录并分享收获与心得。

数据派研究部介绍




数据派研究部成立于2017年初,以兴趣为核心划分多个组别,各组既遵循研究部整体的知识分享实践项目规划,又各具特色:


算法模型组:积极组队参加kaggle等比赛,原创手把手教系列文章;

调研分析组:通过专访等方式调研大数据的应用,探索数据产品之美;

系统平台组:追踪大数据&人工智能系统平台技术前沿,对话专家;

自然语言处理组:重于实践,积极参加比赛及策划各类文本分析项目;

制造业大数据组:秉工业强国之梦,产学研政结合,挖掘数据价值;

数据可视化组:将信息与艺术融合,探索数据之美,学用可视化讲故事;

网络爬虫组:爬取网络信息,配合其他各组开发创意项目。


点击文末“阅读原文”,报名数据派研究部志愿者,总有一组适合你~



转载须知


如需转载,请在开篇显著位置注明作者和出处(转自:数据派THUID:DatapiTHU),并在文章结尾放置数据派醒目二维码。有原创标识文章,请发送【文章名称-待授权公众号名称及ID】至联系邮箱,申请白名单授权并按要求编辑。

未经许可的转载以及改编者,我们将依法追究其法律责任。




关于我们

数据派THU作为数据科学类公众号,背靠清华大学大数据研究中心,分享前沿数据科学与大数据技术创新研究动态、持续传播数据科学知识,努力建设数据人才聚集平台、打造中国大数据最强集团军。




新浪微博:@数据派THU

微信视频号:数据派THU

今日头条:数据派THU

点击“阅读原文”拥抱组织