trl+DPO算法
- 开源代码
- 2025-09-09 10:24:01

一、定义
1. 数据集格式
2. 损失函数
3. 模型训练 demo
4. 模型加载与合并
二、实现
1. 数据集格式：需要的字段为 prompt、chosen、rejected，对应的 trl 数据处理方法如下：
# Excerpts from TRL's `DPOTrainer` (the enclosing class is not shown here).
# The training split is prepared inside the trainer with:
#
#     train_dataset = self._prepare_dataset(train_dataset, processing_class, args, "train")


def _prepare_dataset(
    self,
    dataset: Union[Dataset, IterableDataset],
    processing_class: Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin],
    args: DPOConfig,
    dataset_name: str,
) -> Union[Dataset, IterableDataset]:
    """Turn a raw preference dataset into token ids ready for DPO training.

    Three ``map`` passes run in order: extract the prompt if needed
    (``maybe_extract_prompt``), apply the chat template if needed
    (``maybe_apply_chat_template``), then tokenize every row (``tokenize_row``,
    or ``process_row`` for vision models). The tokenize pass removes the raw
    ``prompt`` / ``chosen`` / ``rejected`` text columns.

    Args:
        dataset: Rows carrying ``prompt``, ``chosen`` and ``rejected`` fields.
        processing_class: Tokenizer/processor used for templating and tokenizing.
        args: Training config; this method reads ``dataset_num_proc``, ``tools``,
            ``max_prompt_length`` and ``max_completion_length``.
        dataset_name: Label used only in the progress-bar descriptions.

    Returns:
        The mapped dataset.
    """
    # Build the kwargs for the `map` function
    map_kwargs = {"writer_batch_size": 10}
    if isinstance(dataset, Dataset):  # IterableDataset does not support num_proc
        map_kwargs["num_proc"] = args.dataset_num_proc

    # Presumably lets the local main process build the `datasets` cache first so
    # the other ranks can reuse it — confirm against accelerate's PartialState docs.
    with PartialState().local_main_process_first():
        # Extract prompt if needed
        if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
            map_kwargs["desc"] = f"Extracting prompt in {dataset_name} dataset"
        dataset = dataset.map(maybe_extract_prompt, **map_kwargs)

        # Apply the chat template if needed
        if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
            map_kwargs["desc"] = f"Applying chat template to {dataset_name} dataset"
        dataset = dataset.map(
            maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class, "tools": args.tools}, **map_kwargs
        )

        # Tokenize the dataset
        if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
            map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"
        dataset = dataset.map(
            self.tokenize_row if not self.is_vision_model else self.process_row,
            remove_columns=["prompt", "chosen", "rejected"],
            fn_kwargs={
                "processing_class": processing_class,
                "max_prompt_length": args.max_prompt_length,
                "max_completion_length": args.max_completion_length,
                # for enc-dec, we add the special tokens ([bos_token] + prompt + [eos_token]; completion + [eos_token])
                "add_special_tokens": False,
            },
            **map_kwargs,
        )

    return dataset


@staticmethod
def tokenize_row(features, processing_class, max_prompt_length, max_completion_length, add_special_tokens):
    """
    Tokenize a row of the dataset.

    Args:
        features (`dict[str, str]`):
            Row of the dataset, should contain the keys `"prompt"`, `"chosen"`, and `"rejected"`.
        processing_class:
            The tokenizer used to encode the three texts.
        max_prompt_length (`int` or `None`):
            Keep only the last this-many prompt tokens (no truncation when `None`).
        max_completion_length (`int` or `None`):
            Keep only the first this-many tokens of each completion (no truncation when `None`).
        add_special_tokens (`bool`):
            Wrap the prompt in BOS/EOS (typically for encoder-decoder models).

    Returns:
        `dict[str, list[int]]`:
            Tokenized sequences with the keys `"prompt_input_ids"`, `"chosen_input_ids"`, and
            `"rejected_input_ids"`.

    Example:
    ```python
    >>> from transformers import GPT2Tokenizer
    >>> tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    >>> features = {"prompt": "The sky is", "chosen": " blue", "rejected": " green"}
    >>> DPOTrainer.tokenize_row(
    ...     features, tokenizer, max_prompt_length=3, max_completion_length=3, add_special_tokens=False
    ... )
    {'prompt_input_ids': [464, 6766, 318], 'chosen_input_ids': [4171, 50256], 'rejected_input_ids': [4077, 50256]}
    ```
    """
    tokenizer = processing_class  # the processing class is a tokenizer
    prompt_input_ids = tokenizer(features["prompt"], add_special_tokens=False)["input_ids"]
    chosen_input_ids = tokenizer(features["chosen"], add_special_tokens=False)["input_ids"]
    rejected_input_ids = tokenizer(features["rejected"], add_special_tokens=False)["input_ids"]

    # Add special tokens (typically for encoder-decoder models)
    if add_special_tokens:
        if tokenizer.bos_token_id is not None:
            prompt_input_ids = [tokenizer.bos_token_id] + prompt_input_ids
        if tokenizer.eos_token_id is not None:
            prompt_input_ids = prompt_input_ids + [tokenizer.eos_token_id]
    # Completions always get EOS appended, even with add_special_tokens=False
    # (the docstring example appends 50256, GPT-2's EOS id, in exactly that case).
    chosen_input_ids = chosen_input_ids + [tokenizer.eos_token_id]
    rejected_input_ids = rejected_input_ids + [tokenizer.eos_token_id]

    # Truncate: keep the END of the prompt (nearest the answer) and the START of each completion.
    if max_prompt_length is not None:
        prompt_input_ids = prompt_input_ids[-max_prompt_length:]
    if max_completion_length is not None:
        chosen_input_ids = chosen_input_ids[:max_completion_length]
        rejected_input_ids = rejected_input_ids[:max_completion_length]

    return {
        "prompt_input_ids": prompt_input_ids,
        "chosen_input_ids": chosen_input_ids,
        "rejected_input_ids": rejected_input_ids,
    }


# 2. Loss function


def get_batch_loss_metrics(
    self,
    model,
    batch: dict[str, Union[list, torch.LongTensor]],
    train_eval: Literal["train", "eval"] = "train",
):
    """Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
    metrics = {}

    # Policy-model forward pass over both the chosen and the rejected completions.
    model_output = self.concatenated_forward(model, batch)

    # if ref_chosen_logps and ref_rejected_logps in batch use them, otherwise use the reference model
    if "ref_chosen_logps" in batch and "ref_rejected_logps" in batch:
        ref_chosen_logps = batch["ref_chosen_logps"]
        ref_rejected_logps = batch["ref_rejected_logps"]
    else:
        # The reference model is not always "the model itself": the training script
        # passes a separate copy as `ref_model` for full fine-tuning and `None` only
        # when LoRA/PEFT is used (presumably TRL then scores with the adapter
        # disabled — confirm in `compute_ref_log_probs`).
        ref_chosen_logps, ref_rejected_logps = self.compute_ref_log_probs(batch)

    losses, chosen_rewards, rejected_rewards = self.dpo_loss(
        model_output["chosen_logps"], model_output["rejected_logps"], ref_chosen_logps, ref_rejected_logps
    )
    # ... (the rest of the method — metric logging and the returned loss — is omitted in this excerpt)


# The DPO loss is computed from four log-probability tensors: the policy's
# `chosen_logps` / `rejected_logps` and the reference's `ref_chosen_logps` /
# `ref_rejected_logps`.
#
# `concatenated_forward` does not generate or sample anything: it scores the chosen
# and rejected completions in a single forward pass over a batch that stacks them,
# then splits the per-sequence log-probs back apart:
#
#     output["chosen_logps"] = all_logps[:num_examples]
#     output["rejected_logps"] = all_logps[num_examples:]


def dpo_loss(
    self,
    chosen_logps: torch.FloatTensor,
    rejected_logps: torch.FloatTensor,
    ref_chosen_logps: torch.FloatTensor,
    ref_rejected_logps: torch.FloatTensor,
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
    """Compute the sigmoid DPO loss and the implicit rewards (excerpt).

    With ``logits = (chosen_logps - rejected_logps) - (ref_chosen_logps - ref_rejected_logps)``
    the loss for each example is::

        -(1 - eps) * log_sigmoid(beta * logits) - eps * log_sigmoid(-beta * logits)

    where ``beta = self.beta`` and ``eps = self.label_smoothing`` (``eps = 0`` is plain DPO).
    When ``self.reference_free`` is set, the reference log-ratio is taken as zero.

    Returns:
        ``(losses, chosen_rewards, rejected_rewards)``, computed elementwise over the batch.
        The rewards are ``beta * (policy logp - reference logp)`` and are detached from the graph.
    """
    device = self.accelerator.device

    # How much more the policy (resp. the reference) prefers chosen over rejected.
    logratios = chosen_logps.to(device) - rejected_logps.to(device)
    if self.reference_free:
        # Reference-free DPO: an implicit reference that rates both answers equally.
        ref_logratios = torch.zeros_like(logratios)
    else:
        ref_logratios = ref_chosen_logps.to(device) - ref_rejected_logps.to(device)
    logits = logratios - ref_logratios

    # Label smoothing mixes in the loss for the flipped preference (conservative DPO).
    losses = (
        -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
        - F.logsigmoid(-self.beta * logits) * self.label_smoothing
    )

    # Implicit rewards, logged as rewards/chosen and rewards/rejected.
    chosen_rewards = self.beta * (chosen_logps.to(device) - ref_chosen_logps.to(device)).detach()
    rejected_rewards = self.beta * (rejected_logps.to(device) - ref_rejected_logps.to(device)).detach()
    return losses, chosen_rewards, rejected_rewards


# The loss is then back-propagated to optimize the policy model.

# 3. Training demo
# LoRA DPO run of the script below. Bash-style `\` continuations are used and the
# Windows paths are single-quoted so bash keeps their backslashes; in Windows cmd,
# drop the quotes and use `^` for continuations (a backtick in PowerShell).
# NOTE(review): the dataset folder was probably 'E:\trl-lib\ultrafeedback_binarized'
# (a local copy of trl-lib/ultrafeedback_binarized) — confirm the actual path.
python test.py \
    --model_name_or_path 'E:\Qwen2.5-0.5B-Instruct' \
    --dataset_name 'E:\trl-libultrafeedback_binarized' \
    --learning_rate 5.0e-6 \
    --num_train_epochs 1 \
    --output_dir Qwen2-0.5B-DPO \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 2 \
    --gradient_checkpointing 1 \
    --logging_steps 2 \
    --use_peft 1 \
    --lora_r 32 \
    --lora_alpha 16

# Meaning of the logged metrics:
#   rewards/chosen     - mean of beta * (policy logp - reference logp) on the chosen
#                        responses, i.e. how far the policy has moved on them relative
#                        to the reference model
#   rewards/rejected   - the same quantity for the rejected responses
#   rewards/accuracies - fraction of examples where chosen_rewards > rejected_rewards
#
# Code:
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
# Full training
python trl/scripts/dpo.py \
    --dataset_name trl-lib/ultrafeedback_binarized \
    --model_name_or_path Qwen/Qwen2-0.5B-Instruct \
    --learning_rate 5.0e-7 \
    --num_train_epochs 1 \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 8 \
    --gradient_checkpointing \
    --logging_steps 25 \
    --eval_strategy steps \
    --eval_steps 50 \
    --output_dir Qwen2-0.5B-DPO \
    --no_remove_unused_columns

# LoRA:
python trl/scripts/dpo.py \
    --dataset_name trl-lib/ultrafeedback_binarized \
    --model_name_or_path Qwen/Qwen2-0.5B-Instruct \
    --learning_rate 5.0e-6 \
    --num_train_epochs 1 \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 8 \
    --gradient_checkpointing \
    --logging_steps 25 \
    --eval_strategy steps \
    --eval_steps 50 \
    --output_dir Qwen2-0.5B-DPO \
    --no_remove_unused_columns \
    --use_peft \
    --lora_r 32 \
    --lora_alpha 16
"""

import argparse

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from trl import (
    DPOConfig,
    DPOTrainer,
    ModelConfig,
    ScriptArguments,
    TrlParser,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
)
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE


def _take_first(dataset, max_samples):
    """Return the first ``max_samples`` rows of ``dataset``, or all rows when ``max_samples`` is None.

    The count is clamped to ``len(dataset)`` so that a split shorter than
    ``max_samples`` does not make ``Dataset.select`` fail with an out-of-range index.
    """
    if max_samples is None:
        return dataset
    return dataset.select(range(min(max_samples, len(dataset))))


def main(script_args, training_args, model_args, max_samples=20):
    """Fine-tune a causal LM with DPO (optionally through LoRA) and save the result.

    Args:
        script_args: ``ScriptArguments`` (dataset name/config/splits, ``ignore_bias_buffers``, ...).
        training_args: ``DPOConfig`` handed to ``DPOTrainer``.
        model_args: ``ModelConfig`` (model path, dtype, quantization and PEFT settings).
        max_samples: Number of rows taken from each split. Defaults to 20 so the demo
            finishes quickly; pass ``None`` to train/evaluate on the full splits.
    """
    ################
    # Model & Tokenizer
    ################
    torch_dtype = (
        model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
    )
    quantization_config = get_quantization_config(model_args)
    model_kwargs = dict(
        revision=model_args.model_revision,
        attn_implementation=model_args.attn_implementation,
        torch_dtype=torch_dtype,
        # transformers does not allow the KV cache together with gradient checkpointing.
        use_cache=False if training_args.gradient_checkpointing else True,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs
    )

    peft_config = get_peft_config(model_args)
    if peft_config is None:
        # Full fine-tuning: load a second copy of the starting model to act as the DPO reference.
        ref_model = AutoModelForCausalLM.from_pretrained(
            model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs
        )
    else:
        # LoRA: no second copy is loaded; presumably DPOTrainer computes reference
        # log-probs with the adapter disabled — confirm for the installed TRL version.
        ref_model = None

    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    if tokenizer.chat_template is None:
        tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE
    if script_args.ignore_bias_buffers:
        # torch distributed hack
        model._ddp_params_and_buffers_to_ignore = [
            name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
        ]

    ################
    # Dataset
    ################
    dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
    train_dataset = _take_first(dataset[script_args.dataset_train_split], max_samples)
    eval_dataset = (
        _take_first(dataset[script_args.dataset_test_split], max_samples)
        if training_args.eval_strategy != "no"
        else None
    )

    ##########
    # Training
    ################
    trainer = DPOTrainer(
        model,
        ref_model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        processing_class=tokenizer,
        peft_config=peft_config,
    )

    trainer.train()

    if training_args.eval_strategy != "no":
        metrics = trainer.evaluate()
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    # Save and push to hub
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)


def make_parser(subparsers: argparse._SubParsersAction = None):
    """Build the CLI parser for the three config dataclasses (standalone or as a ``dpo`` subcommand)."""
    dataclass_types = (ScriptArguments, DPOConfig, ModelConfig)
    if subparsers is not None:
        parser = subparsers.add_parser("dpo", help="Run the DPO training script", dataclass_types=dataclass_types)
    else:
        parser = TrlParser(dataclass_types)
    return parser


if __name__ == "__main__":
    parser = make_parser()
    script_args, training_args, model_args = parser.parse_args_and_config()
    main(script_args, training_args, model_args)

# 4. Model loading and merging
# Merge the trained LoRA adapter into its base model and save a standalone checkpoint.
from peft import AutoPeftModelForCausalLM  # was missing: the snippet raised NameError without it
from transformers import AutoTokenizer

# Loads the base model recorded in the adapter directory and attaches the adapter to it.
model = AutoPeftModelForCausalLM.from_pretrained(
    "Qwen2-0.5B-DPO",  # YOUR MODEL YOU USED FOR TRAINING
    torch_dtype="auto",
    device_map="auto",
)
# The training script's `trainer.save_model` wrote the tokenizer to the same directory.
tokenizer = AutoTokenizer.from_pretrained("Qwen2-0.5B-DPO")

# Fold the LoRA weights into the base weights and drop the PEFT wrappers.
model = model.merge_and_unload()
model.save_pretrained("merged-model")
tokenizer.save_pretrained("merged-model")

# Previous article:
JUC并发-4.wait和notify以及Atomic原理
下一篇
宏任务和微任务