A10单卡24G复现DeepSeek R1强化学习模型训练过程

使用A10单卡24G显存复现DeepSeek R1强化学习模型训练,分析训练过程和奖励函数,验证强化学习效果。

原文标题:使用A10单卡24G复现DeepSeek R1强化学习过程

原文作者:阿里云开发者

冷月清谈:

本文介绍了如何使用单张A10 24G显存的GPU复现DeepSeek R1强化学习模型的训练过程。DeepSeek 包含三个主要模型:DeepSeek-R1-Zero、DeepSeek-R1 和 DeepSeek-R1-Distill*。其中,DeepSeek-R1-Zero 是在 DeepSeek-V3 基础上进行强化学习得到的,是 DeepSeek 的核心模型之一。

复现过程基于Qwen2.5-0.5B-Instruct 模型,采用了 Hugging Face 提供的 GRPO 训练方案和 gsm8k 数据集。文中详细描述了环境配置、依赖安装、训练代码以及五个关键的奖励函数:correctness_reward_func(正确性奖励)、int_reward_func(整数检测奖励)、strict_format_reward_func(严格格式奖励)、soft_format_reward_func(宽松格式奖励)和 xmlcount_reward_func(XML 结构评分)。

训练过程中使用了梯度累积、限制完成长度、定期保存检查点等策略。通过分析训练日志,特别是准确性奖励和格式奖励的变化趋势,可以了解模型的学习过程。最后,通过对比微调前后模型的推理结果,验证了强化学习的有效性。实验结果表明,即使是小模型和少量数据,也能通过此流程体验强化学习,并理解 DeepSeek 的方案设计思路,例如冷启动问题的处理。

怜星夜思:

1、文章中提到了DeepSeek-R1为了解决冷启动问题使用了SFT数据,那么除了SFT还有什么其他方法可以用来解决强化学习中的冷启动问题呢?
2、文章使用了gsm8k数据集,如果换成其他的数学推理数据集,例如MAWPS、MathQA等,训练效果会有怎样的变化?
3、文章中提到了DeepSeek模型的参数量高达6710亿,如此大的模型在实际应用中会面临哪些挑战?

原文内容

阿里妹导读


本文描述DeepSeek的三个模型的学习过程,其中DeepSeek-R1-Zero模型所涉及的强化学习算法,是DeepSeek最核心的部分之一会重点展示。

一、背景

随着DeepSeek的火爆使用,其背后的训练技术也值得深入学习,整体DeepSeek相关的训练过程如下图所示。





其中主要涉及以下三个模型,其中DeepSeek-R1-Zero模型所涉及的强化学习算法,是DeepSeek最核心的部分之一,本次我们主要重现的也是这个部分。


1. DeepSeek-R1-Zero

是在基础模型DeepSeek-V3上进行强化学习(RL)后得到了DeepSeek-R1-Zero模型。该模型学会了如何推理、创建思维链序列,并具备自我验证和反思等能力。尽管DeepSeek-R1-Zero的学习能力令人惊叹,但它存在语言混合、可读性差等严重问题。

2. DeepSeek-R1

首先使用数千个思维链(CoT)序列示例形式的冷启动数据,在DeepSeek-V3上进行监督微调(SFT),目的是为强化学习创建一个更稳定的起点,解决DeepSeek-R1-Zero存在的问题。接着进行强化学习,并设置奖励机制,以促进语言一致性,增强在科学、编码和数学等任务上的推理能力。然后,再次进行监督微调,这次加入了非推理重点的训练示例,帮助模型保留写作、角色扮演等更多通用能力。最后,再次进行强化学习,以更好地符合人类偏好。最终得到了一个拥有6710亿参数的高性能模型。

3. DeepSeek-R1-Distill*

他们基于Qwen和Llama架构,对参数在15亿 - 700亿之间的较小模型进行微调,得到了一组更轻量、更高效且推理能力更强的模型。这极大地提高了开发人员的可及性,因为许多提炼后的模型可以在他们的设备上快速运行。

二、方案

1. 环境信息

强化学习(TRL):主要采用了huggingface提供的grpo_trainer方案(参考链接:https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb)

数据集:主要通过数据集gsm8k进行训练

GPU: 单张A10,显存24G

模型:Qwen2.5-0.5B-Instruct

2. 依赖安装

# 基于目前最新的vllm 0.7.2进行验证
pip install vllm -U

基于目前最新的trl 0.15.1进行验证

pip install trl -U

3. 训练

import re

import torch
from modelscope import AutoTokenizer, AutoModelForCausalLM
from modelscope.msdatasets import MsDataset
from trl import GRPOConfig, GRPOTrainer
SYSTEM_PROMPT = “”"
You need to answer in XML format, include <reasoning> and <answer>, respond in the following format:
<reasoning>

</reasoning>
<answer>

</answer>
“”"
XML_COT_FORMAT = “”"
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
“”"
def extract_xml_answer(text: str) -> str:
   answer = text.split(“<answer>”)[-1]
   answer = answer.split(“</answer>”)[0]
   return answer.strip()
def extract_hash_answer(text: str) -> str | None:
   if “####” not in text:
       return None
   return text.split(“####”)[1].strip()
def get_gsm8k_questions(split=“train”) -> MsDataset:
   data = MsDataset.load(‘modelscope/gsm8k’, subset_name=‘main’, split=split)
   data = data.map(lambda x: {
       ‘prompt’: [
           {‘role’: ‘system’, ‘content’: SYSTEM_PROMPT},
           {‘role’: ‘user’, ‘content’: x[‘question’]}
       ],
       ‘answer’: extract_hash_answer(x[‘answer’])
   })
   return data
dataset = get_gsm8k_questions()

Reward functions

def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
   responses = [completion[0][‘content’] for completion in completions]
   q = prompts[0][-1][‘content’]
   extracted_responses = [extract_xml_answer(r) for r in responses]
   print(‘-’ * 20, f"Question:\n{q}“, f”\nAnswer:\n{answer[0]}“, f”\nResponse:\n{responses[0]}“,
         f”\nExtracted:\n{extracted_responses[0]}")
   return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]
def int_reward_func(completions, **kwargs) -> list[float]:
   responses = [completion[0][‘content’] for completion in completions]
   extracted_responses = [extract_xml_answer(r) for r in responses]
   return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]

def strict_format_reward_func(completions, **kwargs) -> list[float]:

    pattern = r"\n<reasoning>\n.?\n</reasoning>\n<answer>\n.?\n</answer>\n$"

    responses = [completion[0][“content”] for completion in completions]

    matches = [re.fullmatch(pattern, r, re.DOTALL) for r in responses]

    return [0.5 if match else 0.0 for match in matches]

def strict_format_reward_func(completions, **kwargs) -> list[float]:
   pattern = r"<reasoning>\n.?\n</reasoning>\n<answer>\n.?\n</answer>"
   responses = [completion[0][“content”] for completion in completions]
   # 新增调试日志
   matches =
   for idx, r in enumerate(responses):
       print(f"\n— Processing response {idx} —“)
       print(“Raw content:”, repr(r))  # 使用 repr() 显示转义字符
       match = re.fullmatch(pattern, r, re.DOTALL)
       print(“Match result:”, bool(match))
       matches.append(match)
   return [0.5 if match else 0.0 for match in matches]
def soft_format_reward_func(completions, **kwargs) -> list[float]:
   pattern = r”<reasoning>.?</reasoning>\s<answer>.*?</answer>"
   responses = [completion[0][“content”] for completion in completions]
   matches = [re.fullmatch(pattern, r, re.DOTALL) for r in responses]
   return [0.5 if match else 0.0 for match in matches]
def count_xml(text) -> float:
   count = 0.0
   if text.count(“<reasoning>\n”) == 1:
       count += 0.125
   if text.count(“\n</reasoning>\n”) == 1:
       count += 0.125
   if text.count(“\n<answer>\n”) == 1:
       count += 0.125
       count -= len(text.split(“\n</answer>\n”)[-1]) * 0.001
   if text.count(“\n</answer>”) == 1:
       count += 0.125
       count -= (len(text.split(“\n</answer>”)[-1]) - 1) * 0.001
   return count
def xmlcount_reward_func(completions, **kwargs) -> list[float]:
   contents = [completion[0][“content”] for completion in completions]
   return [count_xml(c) for c in contents]
model_name = “Qwen/Qwen2.5-0.5B-Instruct”
output_dir = “outputs/Qwen-0.5B-GRPO”
run_name = “Qwen-0.5B-GRPO-gsm8k”
training_args = GRPOConfig(
   output_dir=output_dir,
   run_name=run_name,
   learning_rate=5e-6,
   adam_beta1=0.9,
   adam_beta2=0.99,
   weight_decay=0.1,
   warmup_ratio=0.1,
   lr_scheduler_type=‘cosine’,
   logging_steps=1,
   bf16=True,
   per_device_train_batch_size=8,
   gradient_accumulation_steps=4,
   num_generations=8,
   max_prompt_length=256,
   max_completion_length=200,
   num_train_epochs=1,
   save_steps=100,
   max_grad_norm=0.1,
   log_on_each_node=False,
   use_vllm=True,
   vllm_gpu_memory_utilization=.2,
   vllm_device=“cuda:0”,
   report_to=“none”
)
model = AutoModelForCausalLM.from_pretrained(
   model_name,
   torch_dtype=torch.bfloat16,
   device_map=None
).to(“cuda”)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
trainer = GRPOTrainer(
   model=model,
   processing_class=tokenizer,
   reward_funcs=[
       xmlcount_reward_func,
       soft_format_reward_func,
       strict_format_reward_func,
       int_reward_func,
       correctness_reward_func],
   args=training_args,
   train_dataset=dataset,
)
trainer.train()

4. reward_funcs(奖励函数)

如上面代码所示,主要涉及以下5个奖励函数

4.1.  correctness_reward_func(正确性奖励函数)

检查模型的输出是否与参考答案 (answer) 完全匹配,匹配则奖励 2.0,否则 0.0。

4.2. int_reward_func(整数检测奖励函数)

检查模型输出是否是纯数字(整数),是则奖励 0.5,否则 0.0。

4.3. strict_format_reward_func(严格格式奖励函数)

严格格式奖励,必须完全匹配 <reasoning>...</reasoning><answer>...</answer>,包括其中的换行符,都必须满足格式,如果符合格式的奖励 0.5,否则 0.0

4.4. soft_format_reward_func(宽松格式奖励函数)

允许更灵活的格式,只要包含 <reasoning>...</reasoning><answer>...</answer>,即奖励 0.5,对比严格模式更加宽松

4.5. count_xml,xmlcount_reward_func(XML 结构评分函数)

计算模型输出 XML 结构的完整度,并给予相应奖励。奖励规则:

检查 XML 结构完整度:

每个正确的标签匹配增加 0.125 奖励:

<reasoning>\\n:+0.125

</reasoning>\\n:+0.125

<answer>\\n:+0.125

</answer>:+0.125

考虑额外文本的惩罚:

如果 </answer> 后面有多余的内容,则减少奖励 0.001 × 额外字符数

5. 训练参数

核心参数说明如下:

1.gradient_accumulation_steps=4:每进行4次的前向传播和反向传播后,才会执行一次权重更新;

2.max_completion_length=200: 表示限制模型返回最大长度200;

3.save_steps=100:表示每运行100步才保存一次checkpoint;

gsm8k数据集一共接近8000条数据,每4次会更新一次,则需要更新2000次,每100步保存一次,则需要生成20个checkpoint。

三、过程日志分析

1. 日志分析

通过python train.py > train.log运行代码,通过tail -f train.log进行实时日志查看,最后整体效果如下图所示,最后有效数据1868个,运行时间是2:25:25。





2. 训练数据分析

GRPO Trainer会记录很多训练过程中的指标,主要包括在:

  • completion_length:完成时长;
  • reward/{reward_func_name}:每个 reward 函数计算的奖励;
  • reward:平均奖励;
  • reward_std :奖励组内的平均标准差
  • kl : 根据完成次数计算的模型和参考模型之间的平均 KL 散度。

其中我们主要关注以下两个奖励指标:

  • 准确性奖励:基于响应的正确性(对应correctness_reward_func)
  • 格式奖励:确保响应符合结构指南(对应strict_format_reward_func和soft_format_reward_func)

2.1. 准确性奖励





2.2. 格式奖励









四、推理验证

1. 微调前的模型

格式和答案都不对,而且不稳定:





2. 微调后的模型

格式和答案都满足要求:





五、思考

通过对比微调前后的模型,虽然我们这次使用的是一个0.5B的小模型,数据量也不大,但是还是可以通过这个流程,体验强化学习的整个流程,对我们理解强化学习还是很有好处的。并且从整个实验中,也理解了DeepSeek整个方案设计的原因,其中以下几个点印象深刻。

1. 训练数据分析

通过对训练后的奖励函数数据进行分析发现,其中模型的格式奖励函数strict_format_reward_func和soft_format_reward_func,都是在训练到固定步数左右的时候,得分开始突然上升,然后后续就逐渐稳定,如下图所示。可以看到,宽松校验在500步的时候已经基本稳定到0.5的分数,而由于严格模式对格式更加严格,所以严格模式在1000步的时候才到稳定。通过这样的数据,可以指导我们下一步进行实验数据调整,从而获取最佳的checkponit模型进行导出。





2. 冷启动的问题

我们可以看到模型在早期训练的时候,效果很差,模型基本都是在瞎试。所以为了加快训练,deepseek加入了SFT的数据解决冷启动的问题,如下面的截图所示。通过R1-Zero生成SFT的数据,解决了R1的冷启动问题。






Lindorm泛时序数据一站式解决方案


随着业务增长带来的数据量激增,如何高效地获取和分析这些数据成为业务洞察和决策的关键挑战,Lindorm作为阿里云自研的云原生多模数据库,具备低成本存储、弹性高可用的能力,提供一站式的分析与洞察。    


点击阅读原文查看详情。


我觉得换数据集的话,模型的泛化能力可能会受到影响。gsm8k 的数据分布和 MAWPS、MathQA 肯定不一样,如果只在一个数据集上训练,模型可能在另一个数据集上表现不佳。

关于“文章中提到了DeepSeek模型的参数量高达6710亿,如此大的模型在实际应用中会面临哪些挑战?”这个问题,我想说,首先就是计算资源的消耗。这么大的模型需要大量的计算资源来进行训练和推理,这对于普通用户来说是一个很大的门槛。

这个问题问得好!除了SFT,模仿学习也是一种不错的选择。通过模仿专家策略的 demonstrations,可以让 agent 初期就获得一个比较合理的策略,避免一开始就胡乱探索。

关于冷启动,我觉得引入一些先验知识也是很有帮助的。比如,可以根据任务的特性设计一些规则,或者利用一些已有的知识库来指导 agent 的行为。这样可以有效地减少 agent 早期探索的盲目性。

对于“文章使用了gsm8k数据集,如果换成其他的数学推理数据集,例如MAWPS、MathQA等,训练效果会有怎样的变化?”这个问题,我个人认为,不同的数据集具有不同的特点和难度,因此训练效果可能会有差异。例如,MAWPS 数据集侧重于简单的算术运算,而 MathQA 则包含更复杂的数学问题。因此,如果使用不同的数据集,模型的性能可能会发生变化。

针对“文章中提到了DeepSeek-R1为了解决冷启动问题使用了SFT数据,那么除了SFT还有什么其他方法可以用来解决强化学习中的冷启动问题呢?”这个问题,我想补充一点,除了SFT,预训练也是一种常见的冷启动解决方案。通过在大量无标签数据上进行预训练,模型可以学习到一定的语言模式和知识,从而在后续的强化学习过程中更快地达到较好的性能。

除了计算资源,还有推理速度的问题。大模型的推理速度通常比较慢,这在一些实时性要求较高的应用场景中是不可接受的。

此外,大模型的部署和维护也是一个挑战。如何将如此大的模型部署到生产环境中,并保证其稳定性和可靠性,是一个需要认真考虑的问题。

如果要换数据集,可能需要对模型的结构和参数进行调整,才能达到最佳的训练效果。不同数据集的规模、难度、数据分布等都有差异,需要针对具体情况进行优化。