【复现DeepSeek-R1之OpenR1实战】系列3:跑通GRPO!
- 电脑硬件
- 2025-09-06 06:09:01

目录 1 配置环境2 训练2.1 命令和配置参数2.2 num_generations2.2.1 参数定义2.2.2 参数含义2.2.3 示例2.2.4 使用场景2.2.5 示例代码 2.3 显存占用和耗时 1 配置环境
关于环境配置,可以参考这篇博文:【复现DeepSeek-R1之Open R1实战】系列1:跑通SFT(一步步操作,手把手教学)
关于flash-attention依赖库的安装问题,运行以下命令,等待一小时左右,依赖库就安装成功了:
pip install flash-attn --no-cache-dir 2 训练 2.1 命令和配置参数训练的命令如下,和SFT差不多:
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero2.yaml \ --num_processes=7 src/open_r1/grpo.py \ --config /nfs/ofs-902-1/fusion/zhongyudong/open-r1/recipes/Qwen2.5-1.5B-Instruct/grpo/config_demo.yaml我们需要修改config配置文件:recipes/Qwen2.5-1.5B-Instruct/grpo/config_demo.yaml,主要是将和Huggingface的链接关掉,修改模型路径、数据集路径、GPU个数(num_processes=GPU个数-1,因为vLLM使用了一张卡)。
在训练过程中,我发现torch的DDP不稳定,容易接收不到Worker的信号导致训练失败,所以保存策略改成了每步都保存(save_strategy: “steps”)。
完整的配置如下:
# Model arguments model_name_or_path: /nfs/ofs-902-1/pnc/huggingface_hub/Qwen/Qwen2.5-1.5B-Instruct # model_revision: main torch_dtype: bfloat16 attn_implementation: flash_attention_2 # Data training arguments dataset_name: /nfs/ofs-902-1/fusion/zhongyudong/open-r1/datas/NuminaMath-TIR/data dataset_configs: - all # Num processes is less by 1 as vLLM is using 1 GPU num_processes: 7 # GRPO trainer config bf16: true use_vllm: true vllm_device: auto vllm_gpu_memory_utilization: 0.7 do_eval: true eval_strategy: steps eval_steps: 100 gradient_accumulation_steps: 16 gradient_checkpointing: true gradient_checkpointing_kwargs: use_reentrant: false # hub_model_id: Qwen2.5-1.5B-Open-R1-GRPO # hub_strategy: every_save learning_rate: 2.0e-05 log_level: info logging_steps: 5 logging_strategy: steps lr_scheduler_type: cosine max_prompt_length: 512 max_completion_length: 1024 max_steps: -1 num_generations: 7 num_train_epochs: 1 output_dir: /nfs/ofs-902-1/fusion/zhongyudong/open-r1/outputs/Qwen2.5-1.5B-Open-R1-GRPO overwrite_output_dir: true per_device_eval_batch_size: 32 per_device_train_batch_size: 16 push_to_hub: false # report_to: # - wandb save_strategy: "steps" seed: 42 warmup_ratio: 0.1重点解释一下num_generations这个参数,主要是控制每个提示(Prompt)生成的样本数量。
2.2 num_generations参数 num_generations 用于指定每个提示(prompt)生成的样本数量。这个参数在生成模型中非常常见,特别是在文本生成、对话系统或其他需要从模型中采样多个输出的任务中。以下是对该参数及其使用场景的详细解释:
2.2.1 参数定义 num_generations (`int` or `None`, *optional*, defaults to `8`): Number of generations per prompt to sample. The global batch size (num_processes * per_device_batch_size) must be divisible by this value. 类型: 可以是整数(int)或 None。默认值: 默认为 8。可选性: 是一个可选参数。 2.2.2 参数含义生成数量:
num_generations 指定了对于每一个输入的提示(prompt),模型将生成多少个不同的输出样本。例如,如果你设置 num_generations=3,那么对于每一个输入提示,模型会生成3个不同的输出。全局批处理大小的约束:
全局批处理大小(global batch size)是指所有进程和设备上批处理大小的总和,通常计算为 num_processes * per_device_batch_size。这个全局批处理大小必须能够被 num_generations 整除。也就是说,global_batch_size % num_generations == 0 必须成立。这个约束确保了在分布式训练或多设备环境中,每个设备上的生成任务可以均匀分配。 2.2.3 示例假设你有以下配置:
per_device_batch_size = 4num_processes = 2(即你在使用两个GPU或其他并行计算单元)num_generations = 8在这种情况下:
全局批处理大小为 global_batch_size = num_processes * per_device_batch_size = 2 * 4 = 8因为 global_batch_size 等于 num_generations,所以条件满足。如果我们将 num_generations 改为 6,则:
全局批处理大小仍然是 8,但 8 % 6 != 0,这会导致错误,因为无法均匀分配生成任务。 2.2.4 使用场景 文本生成在文本生成任务中,你可能希望从一个提示生成多个不同的输出,以便选择最好的结果或者展示多样性。例如,在故事生成或对话系统中,生成多个候选答案可以让用户有更多的选择。
对话系统在对话系统中,生成多个回复可以帮助系统提供更丰富的互动体验。通过生成多个回复,系统可以选择最合适的回答,或者让用户选择他们喜欢的回答。
多模态生成在多模态生成任务(如图像字幕生成、视频描述等)中,生成多个输出可以提高生成内容的多样性和准确性。
2.2.5 示例代码以下是一个简单的示例,展示了如何使用 num_generations 参数:
from transformers import AutoModelForCausalLM, AutoTokenizer # 加载预训练模型和分词器 model_name = "your-pretrained-model" model = AutoModelForCausalLM.from_pretrained(model_name) tokenizer = AutoTokenizer.from_pretrained(model_name) # 定义输入提示 prompt = "Once upon a time" # 将提示编码为模型输入格式 input_ids = tokenizer(prompt, return_tensors="pt").input_ids # 设置生成参数 num_generations = 5 # 每个提示生成5个样本 # 生成多个样本 outputs = model.generate( input_ids, num_return_sequences=num_generations, # 设置num_generations max_length=50, do_sample=True ) # 解码生成的样本 for i, output in enumerate(outputs): print(f"Generated text {i+1}:") print(tokenizer.decode(output, skip_special_tokens=True)) print()在这个例子中,num_return_sequences 参数对应于 num_generations,它指定了要生成的样本数量。
2.3 显存占用和耗时8卡H20,每张卡占用40~50G。
要跑将15个多小时。
【复现DeepSeek-R1之OpenR1实战】系列3:跑通GRPO!由讯客互联电脑硬件栏目发布,感谢您对讯客互联的认可,以及对我们原创作品以及文章的青睐,非常欢迎各位朋友分享到个人网站或者朋友圈,但转载请说明文章出处“【复现DeepSeek-R1之OpenR1实战】系列3:跑通GRPO!”