From 18347701f213f269ae86678c40f178f5386a90ce Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?=
Date: Mon, 28 Oct 2024 10:58:24 +0000
Subject: [PATCH] truncation left for reward tokenizer

---
 examples/scripts/dpo_online.py | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/examples/scripts/dpo_online.py b/examples/scripts/dpo_online.py
index 8164eda6e7..58d7c4a2c0 100644
--- a/examples/scripts/dpo_online.py
+++ b/examples/scripts/dpo_online.py
@@ -96,6 +96,8 @@
         reward_tokenizer = AutoTokenizer.from_pretrained(
             training_args.reward_model_path,
             trust_remote_code=model_config.trust_remote_code,
+            truncation=True,
+            truncation_side="left",  # since we judge the completion, truncating left is more appropriate
         )
     else:
         reward_model = None
@@ -131,11 +133,14 @@
         reward_processing_class=reward_tokenizer,
         peft_config=get_peft_config(model_config),
     )
-    generation_config = GenerationConfig(
-        max_new_tokens=training_args.max_new_tokens, do_sample=True, temperature=training_args.temperature
-    )
-    completions_callback = LogCompletionsCallback(trainer, generation_config, num_prompts=8)
-    trainer.add_callback(completions_callback)
+
+    if training_args.eval_strategy != "no":
+        generation_config = GenerationConfig(
+            max_new_tokens=training_args.max_new_tokens, do_sample=True, temperature=training_args.temperature
+        )
+        completions_callback = LogCompletionsCallback(trainer, generation_config, num_prompts=8)
+        trainer.add_callback(completions_callback)
+
     trainer.train()

     # Save and push to hub
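
For context on the first hunk (not part of the patch): `truncation_side` is a standard `transformers` tokenizer attribute, and setting it to `"left"` makes overflow get dropped from the start of the sequence rather than the end. Since the reward model scores the completion, which sits at the end of the prompt+completion string, left truncation keeps the judged text intact. A minimal sketch of the behavior; `"gpt2"` is just an illustrative checkpoint, not what the script loads:

```python
from transformers import AutoTokenizer

# Illustrative sketch, not part of the patch: with truncation_side="left",
# overflow is dropped from the *start* of the sequence, so the completion
# at the end -- the part the reward model actually scores -- survives.
tokenizer = AutoTokenizer.from_pretrained("gpt2", truncation_side="left")

text = "a long prompt " * 50 + "the completion being judged"
ids = tokenizer(text, truncation=True, max_length=16)["input_ids"]

# The kept tokens are the *last* 16, so the decoded text still ends with
# the completion rather than a truncated prompt.
print(tokenizer.decode(ids))
```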
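On the second hunk: `LogCompletionsCallback` logs sample completions on the trainer's evaluation schedule, so registering it when `eval_strategy == "no"` buys nothing, and the patch now builds the `GenerationConfig` and attaches the callback only when evaluation is enabled. A toy sketch of the pattern, assuming only the standard `transformers` `TrainerCallback` API; the class below is a hypothetical stand-in, not TRL's implementation:

```python
from transformers import TrainerCallback

class CompletionsLoggerSketch(TrainerCallback):
    """Hypothetical stand-in for LogCompletionsCallback."""

    def on_evaluate(self, args, state, control, **kwargs):
        # This hook only fires when the trainer actually evaluates,
        # i.e. when args.eval_strategy != "no" -- hence the guard
        # added around the callback registration in the patch.
        print(f"would log sample completions at step {state.global_step}")
```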