主页 > 手机  > 

【复现DeepSeek-R1之OpenR1实战】系列6:GRPO源码逐行深度解析(上)

【复现DeepSeek-R1之OpenR1实战】系列6:GRPO源码逐行深度解析(上)

目录 4 GRPO源码分析4.1 数据类 `GRPOScriptArguments`4.2 系统提示字符串 `SYSTEM_PROMPT`4.3 奖励函数4.3.1 accuracy_reward函数4.3.2 verify函数4.3.3 format_reward函数 4.4 将数据集格式化为对话形式4.5 初始化GRPO Trainer


【复现DeepSeek-R1之Open R1实战】系列3:SFT和GRPO源码逐行深度解析(上) 【复现DeepSeek-R1之Open R1实战】系列5:SFT和GRPO源码逐行深度解析(中)

4 GRPO源码分析

前面两篇博文已经详细介绍了一些基础知识和SFT源码,本文继续解读GRPO源码。与SFT源码差不多的部分,我们就不展开细说了,这里只解析GRPO独特的部分。

4.1 数据类 GRPOScriptArguments

该类使用了 Python 的 dataclass 装饰器,这是一种简化类定义的方式,特别是对于那些主要用来存储数据的类。它继承自 ScriptArguments 类。

reward_funcs: 这是一个列表,包含了一系列可能的奖励函数名称,默认值为 ["accuracy", "format"]。这些奖励函数可能是用于评估模型性能的不同标准。

reward_funcs: list[str] = field( default_factory=lambda: ["accuracy", "format"], metadata={ "help": "List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine', 'repetition_penalty', 'length'" }, )

cosine_min_value_wrong 和 cosine_max_value_wrong: 分别表示错误答案在余弦相似度尺度上的最小和最大奖励值,默认分别为 0.0 和 -0.5。

cosine_min_value_correct 和 cosine_max_value_correct: 分别表示正确答案在余弦相似度尺度上的最小和最大奖励值,默认分别为 0.5 和 1.0。

cosine_max_len: 表示余弦相似度尺度的最大长度,默认值为 1000。

repetition_n_grams: 表示用于重复惩罚奖励的n-gram数量,默认值为 3。

repetition_max_penalty: 表示重复惩罚奖励的最大负值,默认值为 -1.0。

每个字段都使用了 field() 函数来定义其默认值和元数据(如帮助信息)。这有助于工具和库更好地理解和处理这些字段,例如生成命令行解析器时。

4.2 系统提示字符串 SYSTEM_PROMPT SYSTEM_PROMPT = ( "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant " "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning " "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., " "<think> reasoning process here </think><answer> answer here </answer>" )

字符串描述了一个对话场景,用户先提问,助手首先思考推理过程,然后提供答案。推理过程和答案分别用 <think> 和 <answer> 标签包裹,这种格式化有助于区分和识别不同的部分,和DeepSeek-R1的思考过程格式一致。

4.3 奖励函数

奖励函数的定义如下,GRPO默认用到了accuracy_reward和format_reward这两个函数。

# Get reward functions REWARD_FUNCS_REGISTRY = { "accuracy": accuracy_reward, "format": format_reward, "reasoning_steps": reasoning_steps_reward, "cosine": get_cosine_scaled_reward( min_value_wrong=script_args.cosine_min_value_wrong, max_value_wrong=script_args.cosine_max_value_wrong, min_value_correct=script_args.cosine_min_value_correct, max_value_correct=script_args.cosine_max_value_correct, max_len=script_args.cosine_max_len, ), "repetition_penalty": get_repetition_penalty_reward( ngram_size=script_args.repetition_n_grams, max_penalty=script_args.repetition_max_penalty, ), "length": len_reward, } reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs]

这段代码定义了一个奖励函数注册表 REWARD_FUNCS_REGISTRY,并根据用户提供的配置动态生成一个奖励函数列表 reward_funcs。每个奖励函数用于评估模型输出的不同方面,如准确性、格式、推理步骤等。

注册表定义 accuracy: 使用 accuracy_reward 函数评估模型输出的准确性。format: 使用 format_reward 函数评估模型输出的格式。reasoning_steps: 使用 reasoning_steps_reward 函数评估模型输出的推理步骤。cosine: 使用 get_cosine_scaled_reward 函数计算余弦相似度奖励,参数包括: min_value_wrong: 错误情况下的最小值。max_value_wrong: 错误情况下的最大值。min_value_correct: 正确情况下的最小值。max_value_correct: 正确情况下的最大值。max_len: 最大长度。 repetition_penalty: 使用 get_repetition_penalty_reward 函数计算重复惩罚奖励,参数包括: ngram_size: n-gram 的大小。max_penalty: 最大惩罚值。 length: 使用 len_reward 函数评估模型输出的长度。 动态生成奖励函数列表 reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs] 根据 script_args.reward_funcs 中指定的奖励函数名称,从 REWARD_FUNCS_REGISTRY 中获取相应的奖励函数,并生成一个列表 reward_funcs。 4.3.1 accuracy_reward函数

该函数用于计算模型生成的补全与真实答案之间的准确性奖励。它通过解析和验证生成的内容与真实答案来确定奖励值。

def accuracy_reward(completions, solution, **kwargs): """Reward function that checks if the completion is the same as the ground truth.""" contents = [completion[0]["content"] for completion in completions] rewards = [] for content, sol in zip(contents, solution): gold_parsed = parse( sol, extraction_mode="first_match", extraction_config=[LatexExtractionConfig()], ) if len(gold_parsed) != 0: # We require the answer to be provided in correct latex (no malformed operators) answer_parsed = parse( content, extraction_config=[ LatexExtractionConfig( normalization_config=NormalizationConfig( nits=False, malformed_operators=False, basic_latex=True, equations=True, boxed="all", units=True, ), # Ensures that boxed is tried first boxed_match_priority=0, try_extract_without_anchor=False, ) ], extraction_mode="first_match", ) # Reward 1 if the content is the same as the ground truth, 0 otherwise reward = float(verify(answer_parsed, gold_parsed)) else: # If the gold solution is not parseable, we reward 1 to skip this example reward = 1.0 print("Failed to parse gold solution: ", sol) rewards.append(reward) return rewards completions (list): 包含多个补全结果的列表,每个补全结果是一个包含内容的字典列表。solution (list): 真实答案的列表。kwargs: 其他可选参数(在本函数中未使用)。

提取补全内容

contents = [completion[0]["content"] for completion in completions] 从 completions 列表中提取每个补全的第一个内容(假设每个补全是单个元素的列表),形成一个新的 contents 列表。

初始化奖励列表

rewards = []

遍历每个补全和对应的真实答案

for content, sol in zip(contents, solution): gold_parsed = parse( sol, extraction_mode="first_match", extraction_config=[LatexExtractionConfig()], ) 使用 zip 函数将 contents 和 solution 配对。对于每一对补全内容和真实答案,首先解析真实答案 sol,使用 parse 函数提取其中的信息。

处理解析结果

if len(gold_parsed) != 0: answer_parsed = parse( content, extraction_config=[ LatexExtractionConfig( normalization_config=NormalizationConfig( nits=False, malformed_operators=False, basic_latex=True, equations=True, boxed="all", units=True, ), # Ensures that boxed is tried first boxed_match_priority=0, try_extract_without_anchor=False, ) ], extraction_mode="first_match", ) 如果解析得到的真实答案 gold_parsed 非空,则继续解析生成的补全内容 content。使用 LatexExtractionConfig 和 NormalizationConfig 进行详细配置,确保解析过程中考虑了各种格式要求(如方程、单位等)。

计算奖励

reward = float(verify(answer_parsed, gold_parsed)) 使用 verify 函数比较生成的补全解析结果和真实答案的解析结果。如果两者匹配,则返回 1.0,否则返回 0.0。

处理无法解析的情况

else: reward = 1.0 print("Failed to parse gold solution: ", sol) 如果真实答案无法解析,则默认给予奖励 1.0 并打印警告信息。

添加奖励到列表

rewards.append(reward)

返回所有奖励

return rewards 4.3.2 verify函数

该函数用于验证目标表达式是否与参考表达式匹配,它通过多种比较策略来处理不同的数学对象(如数字、表达式、集合、矩阵等),并提供灵活的配置选项以适应不同的需求。

def verify( gold: list[Basic | MatrixBase | str] | Basic | MatrixBase | str, target: list[Basic | MatrixBase | str] | Basic | MatrixBase | str, float_rounding: int=6, numeric_precision: int=15, strict: bool=True, timeout_seconds: int=3 ) -> bool: gold: 参考或正确的表达式,可以是单个 SymPy 表达式(Basic 或 MatrixBase)、字符串或这些类型的列表。target: 需要验证的表达式,类型同 gold。float_rounding: 浮点数舍入的小数位数,默认为 6。numeric_precision: 数值比较时考虑的小数位数,默认为 15。strict: 是否启用严格比较模式,默认为 True。 在严格模式下:变量很重要,集合不可与元组比较。在非严格模式下:变量按位置匹配,集合可与元组比较。 timeout_seconds: 单次比较操作的最大超时时间(秒),默认为 3 秒。

定义内部比较函数 compare_single_extraction

@timeout(timeout_seconds=timeout_seconds) def compare_single_extraction(gold: Basic | MatrixBase | str, target: Basic | MatrixBase | str) -> bool: ... 使用装饰器 @timeout 设置超时保护,默认超时时间为 timeout_seconds。比较两个表达式: 如果两者都是 SymPy 表达式(Basic 或 MatrixBase),则调用 sympy_expr_eq 进行比较。如果两者都是字符串,则进行简单的字符串比较。

定义包装函数 compare_single_extraction_wrapper

def compare_single_extraction_wrapper(g, t): try: return compare_single_extraction(g, t) except Exception as e: logger.exception(f"Error comparing {g} and {t}") return False 包装 compare_single_extraction,捕获并记录任何异常,返回 False 以避免程序中断。

处理输入列表

if not isinstance(gold, list): gold = [gold] if not isinstance(target, list): target = [target] 如果 gold 或 target 不是列表,则将其转换为单元素列表,以便统一处理。

组合所有可能的比较

return any(compare_single_extraction_wrapper(g, t) for g, t in product(gold, target)) 使用 itertools.product 生成所有可能的 gold 和 target 组合。对每个组合调用 compare_single_extraction_wrapper,如果任意一对匹配成功,则返回 True。 4.3.3 format_reward函数

函数用于检查给定的完成文本是否符合特定的格式,它验证完成文本是否包含 <think> 和 <answer> 标签,并且这两个标签的内容是非空的。

def format_reward(completions, **kwargs): """Reward function that checks if the completion has a specific format.""" pattern = r"^<think>.*?</think>\s*<answer>.*?</answer>$" completion_contents = [completion[0]["content"] for completion in completions] matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completion_contents] return [1.0 if match else 0.0 for match in matches] completions: 这是一个列表,其中每个元素都是一个包含完成内容的对象(通常是字典)。假设每个完成对象的第一个元素包含一个键 "content",其值是需要检查的文本。kwargs: 其他关键字参数,这里没有使用,但可以为未来的扩展提供灵活性。

正则表达式模式定义:

pattern = r"^<think>.*?</think>\s*<answer>.*?</answer>$" 这个正则表达式用于匹配字符串是否以 <think> 开始,紧接着是任意字符(非贪婪匹配),然后是 </think>,接着可能有任意数量的空白字符(包括换行符),最后是以 <answer> 开始并以 </answer> 结束。.*? 是非贪婪匹配,确保尽可能少地匹配字符。\s* 匹配零个或多个空白字符(包括换行符)。re.DOTALL | re.MULTILINE 标志允许点号 . 匹配所有字符(包括换行符),并且使多行文本中的每一行都可以独立匹配。

提取完成内容:

completion_contents = [completion[0]["content"] for completion in completions] 这里通过列表推导式从 completions 列表中提取每个完成对象的第一个元素的 "content" 字段,形成一个新的列表 completion_contents。

匹配正则表达式:

matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completion_contents] 使用 re.match 函数对 completion_contents 中的每个内容应用正则表达式模式。matches 列表将包含 re.Match 对象(如果匹配成功)或 None(如果匹配失败)。

生成奖励分数:

return [1.0 if match else 0.0 for match in matches] 最后一步是根据匹配结果生成奖励分数。如果匹配成功(即 match 不是 None),则返回 1.0;否则返回 0.0。

示例代码:

completions = [ [{"content": "<think>This is reasoning.</think><answer>This is answer.</answer>"}], [{"content": "<think>This is reasoning.</think>"}], [{"content": "<answer>This is answer.</answer>"}], [{"content": "This does not match."}] ] reward_scores = format_reward(completions) print(reward_scores) # 输出: [1.0, 0.0, 0.0, 0.0]

在这个例子中:

第一个完成内容完全匹配正则表达式,因此得分为 1.0。后三个完成内容不符合要求,因此得分均为 0.0。 4.4 将数据集格式化为对话形式 # Format into conversation def make_conversation(example): return { "prompt": [ {"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": example["problem"]}, ], } dataset = dataset.map(make_conversation) for split in dataset: if "messages" in dataset[split].column_names: dataset[split] = dataset[split].remove_columns("messages")

将一个数据集中的每个示例转换为对话格式,并确保数据集中没有多余的列(如 messages)。

输入:example 是一个字典,包含单个数据样本的信息,其中 "problem" 键对应的值是用户的问题或任务描述。输出:返回一个新的字典,包含一个 "prompt" 键,其值是一个对话列表: 第一条消息是系统消息,内容由 SYSTEM_PROMPT 定义。第二条消息是用户消息,内容是 example["problem"]。 dataset.map(make_conversation):使用 map 方法将 make_conversation 函数应用到数据集的每个示例上,生成新的对话格式。移除多余列:遍历数据集的每个拆分(split),如果存在 "messages" 列,则将其移除。 4.5 初始化GRPO Trainer trainer = GRPOTrainer( model=model_args.model_name_or_path, reward_funcs=reward_funcs, args=training_args, train_dataset=dataset[script_args.dataset_train_split], eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, peft_config=get_peft_config(model_args), callbacks=get_callbacks(training_args, model_args), )

篇幅有限,训练部分的代码我们放到下一篇博文详细解读!

标签:

【复现DeepSeek-R1之OpenR1实战】系列6:GRPO源码逐行深度解析(上)由讯客互联手机栏目发布,感谢您对讯客互联的认可,以及对我们原创作品以及文章的青睐,非常欢迎各位朋友分享到个人网站或者朋友圈,但转载请说明文章出处“【复现DeepSeek-R1之OpenR1实战】系列6:GRPO源码逐行深度解析(上)