Use any reward model for online methods #2276

@@ -93,8 +93,13 @@
        trust_remote_code=model_config.trust_remote_code,
        **model_kwargs,
    )
    reward_tokenizer = AutoTokenizer.from_pretrained(
        training_args.reward_model_path,
        trust_remote_code=model_config.trust_remote_code,
    )

Review comment: Should we also set truncation, the truncation side, and the max_length for the reward tokenizer?
Reply: Ok for truncation and truncation side. Not sure what's the best way to let the user set the max_length. Ok for doing this in a follow-up PR?
Reply: Sounds good to follow up in a separate PR - it should generally be safe for RMs that define the max length implicitly in their config anyway.

else:
    reward_model = None
    reward_tokenizer = None

if training_args.judge is not None:
    judge_cls = JUDGES[training_args.judge]

@@ -123,6 +128,7 @@
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
        processing_class=tokenizer,
        reward_processing_class=reward_tokenizer,
        peft_config=get_peft_config(model_config),
    )
    generation_config = GenerationConfig(
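
For context on the follow-up discussed in the review thread above: truncation_side can be set when the reward tokenizer is loaded, while truncation and max_length are call-time arguments. A rough sketch of what that could look like (not part of this diff; "gpt2" and the value 512 are placeholders):

from transformers import AutoTokenizer

# "gpt2" is only a stand-in so the snippet runs; in the script above this would be
# training_args.reward_model_path.
reward_tokenizer = AutoTokenizer.from_pretrained("gpt2", truncation_side="left")

# Truncation and max_length are applied when tokenizing; a follow-up PR could expose
# the max_length through the training arguments (512 here is an arbitrary example).
inputs = reward_tokenizer(
    "What color is the sky? Blue",
    truncation=True,
    max_length=512,
    return_tensors="pt",
)
print(inputs["input_ids"].shape)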

@@ -44,7 +44,7 @@
from transformers.training_args import OptimizerNames
from transformers.utils import is_peft_available, is_sagemaker_mp_enabled, logging

from ..data_utils import is_conversational, maybe_apply_chat_template
from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
from ..models import create_reference_model
from ..models.utils import unwrap_model_for_generation
from .judges import BasePairwiseJudge

@@ -137,6 +137,7 @@ def __init__(
        processing_class: Optional[
            Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
        ] = None,
        reward_processing_class: Optional[PreTrainedTokenizerBase] = None,
        peft_config: Optional[Dict] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,

@@ -161,6 +162,7 @@ def __init__(
            raise ValueError("Either `reward_model` or `judge` must be provided.")

        self.reward_model = reward_model
        self.reward_processing_class = reward_processing_class
        self.judge = judge

        if args.missing_eos_penalty is not None and judge is not None:

@@ -428,18 +430,23 @@ def training_step(
        ref_logprobs = torch.take_along_dim(ref_all_logprobs, completion_ids.unsqueeze(-1), dim=2).squeeze(-1)
        del ref_output, ref_logits, ref_all_logprobs  # free memory

        # Get the reward from the reward model or judge:
        if self.judge is not None:
            completions = self.processing_class.batch_decode(
                prompt_completion_ids[:, context_length:], skip_special_tokens=True
            )
            completions = [completion.strip() for completion in completions]  # remove the leading space

Review comment: I think we don't need strip.

        # Decode the completions, and format them if the input is conversational
        device = prompt_completion_ids.device
        completions_ids = prompt_completion_ids[:, context_length:]
        completions = self.processing_class.batch_decode(completions_ids, skip_special_tokens=True)
        if is_conversational({"prompt": prompts[0]}):
            completions = [[{"role": "assistant", "content": completion}] for completion in completions]

        # Get the reward from the reward model or judge
        if self.judge is not None:
            # Once formatted, conversational data may contain special tokens (such as <|im_start|>) that are not
            # directly understandable by the judge and could alter its judgment. To avoid this and make the judge
            # independent of the model's chat template, we use the raw conversation data, and apply our own chat
            # template to it.
            if is_conversational({"prompt": prompts[0]}):
                completions = [[{"role": "assistant", "content": completion}] for completion in completions]
                environment = jinja2.Environment()
                template = environment.from_string(SIMPLE_CHAT_TEMPLATE)
                prompts = [template.render(messages=message) for message in prompts]
                prompts = [template.render(messages=prompt) for prompt in prompts]
                completions = [template.render(messages=completion) for completion in completions]

            ranks_of_first_completion = self.judge.judge(
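
For intuition, here is a small, self-contained sketch of the rendering done in the judge branch above. The template string below is only a stand-in in the spirit of SIMPLE_CHAT_TEMPLATE, not necessarily TRL's exact definition:

import jinja2

# Stand-in chat template: renders a list of {"role", "content"} messages as plain text,
# so the judge never sees model-specific special tokens such as <|im_start|>.
STAND_IN_TEMPLATE = "{% for message in messages %}{{ message['role'] }}: {{ message['content'] }}\n{% endfor %}"

environment = jinja2.Environment()
template = environment.from_string(STAND_IN_TEMPLATE)

prompt = [{"role": "user", "content": "What color is the sky?"}]
completion = [{"role": "assistant", "content": "Blue"}]

print(template.render(messages=prompt))      # user: What color is the sky?
print(template.render(messages=completion))  # assistant: Blue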

@@ -449,24 +456,47 @@
            # convert ranks to a True/False mask:
            # when rank == 0, it means the first completion is the best
            # when rank == 1, it means the second completion is the best
            mask = torch.tensor([rank == 0 for rank in ranks_of_first_completion], device=prompt_completion_ids.device)
            mask = torch.tensor([rank == 0 for rank in ranks_of_first_completion], device=device)
        else:
            _, scores, _ = get_reward(
                self.reward_model, prompt_completion_ids, self.processing_class.pad_token_id, context_length
            )
            # The reward model may not have the same chat template or tokenizer as the model, so we need to use the
            # raw data (string), apply the chat template (if needed), and tokenize it with the reward processing class.
            prompts = 2 * prompts  # repeat the prompt: [prompt0, prompt1] -> [prompt0, prompt1, prompt0, prompt1]

Review comment: Why do we do this? Is it to align a prompt with chosen/rejected?
Reply: At this point, we have:

    prompts = ["What color is the sky?", "What's the capital of France?"]
    completions = ["Blue", "Lyon", "Green", "Paris"]

and later, we need to concat the prompts and the completions to compute the reward:

    prompt_completion_ids = torch.cat((prompts_ids, completions_ids), dim=1)
    _, scores, _ = get_reward(
        self.reward_model, prompt_completion_ids, self.reward_processing_class.pad_token_id, context_length
    )

so we need to repeat the prompt.
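
To make the alignment concrete, here is a tiny self-contained sketch (reusing the strings from the comment above) of how repeating the prompts lines up each of the 2 * num_examples completions with its prompt:

# Completions are laid out as [first sample per prompt, then second sample per prompt].
prompts = ["What color is the sky?", "What's the capital of France?"]
completions = ["Blue", "Lyon", "Green", "Paris"]
num_examples = len(prompts)

repeated_prompts = 2 * prompts  # [sky, France, sky, France]
for prompt, completion in zip(repeated_prompts, completions):
    print(f"{prompt!r} -> {completion!r}")

# Completion i belongs to prompt i % num_examples, so indices i and i + num_examples
# are the two candidate completions for the same prompt.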

            if is_conversational({"prompt": prompts[0]}):
                examples = [{"prompt": p, "completion": c} for p, c in zip(prompts, completions)]
                examples = [apply_chat_template(example, self.reward_processing_class) for example in examples]
                prompts = [example["prompt"] for example in examples]
                completions = [example["completion"] for example in examples]

            # Tokenize the prompts
            prompts_ids = self.reward_processing_class(
                prompts, padding=True, return_tensors="pt", padding_side="left"
            )["input_ids"].to(device)
            context_length = prompts_ids.shape[1]

            # Tokenize the completions
            completions_ids = self.reward_processing_class(
                completions, padding=True, return_tensors="pt", padding_side="right"
            )["input_ids"].to(device)

            # Concatenate the prompts and completions and get the reward
            prompt_completion_ids = torch.cat((prompts_ids, completions_ids), dim=1)
            with torch.inference_mode():
                _, scores, _ = get_reward(
                    self.reward_model, prompt_completion_ids, self.reward_processing_class.pad_token_id, context_length
                )

            # Filter completion. Ensure that the sample contains stop_token_id
            # Completions not passing that filter will receive a lower score.
            if self.args.missing_eos_penalty is not None:
                scores[~contain_eos_token] -= self.args.missing_eos_penalty
            # Filter completion. Ensure that the sample contains stop_token_id
            # Completions not passing that filter will receive a lower score.
            if self.args.missing_eos_penalty is not None:
                scores[~contain_eos_token] -= self.args.missing_eos_penalty

        # Split the scores in 2 (the prompts of the first half are the same as the second half)
        first_half, second_half = scores.split(num_examples)

        # Get the indices of the chosen and rejected examples
        mask = first_half >= second_half

        num_examples_range = torch.arange(num_examples, device=prompt_completion_ids.device)
        num_examples_range = torch.arange(num_examples, device=device)
        chosen_indices = num_examples_range + (~mask * num_examples)
        rejected_indices = num_examples_range + (mask * num_examples)
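
The index arithmetic above is easy to check on a toy example. A self-contained sketch with made-up scores for two prompts (two completions each):

import torch

# Scores laid out as [first completion of prompt 0, first of prompt 1,
#                     second of prompt 0, second of prompt 1].
scores = torch.tensor([1.0, -0.5, 0.2, 0.7])
num_examples = 2

first_half, second_half = scores.split(num_examples)
mask = first_half >= second_half  # tensor([True, False])

num_examples_range = torch.arange(num_examples)
chosen_indices = num_examples_range + (~mask * num_examples)   # tensor([0, 3])
rejected_indices = num_examples_range + (mask * num_examples)  # tensor([2, 1])
print(chosen_indices, rejected_indices)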

Review comment: Is the reason to use a processing class in case we want to support other modalities beyond text?
Reply: Possibly. And since the tokenizer is now called processing_class within trainers, I'd recommend always aligning with it (even if only the textual modality is supported). Unless you have a good reason not to.
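
For reference, a minimal usage sketch of the new reward_processing_class argument. This assumes the trainer is TRL's OnlineDPOTrainer (the class name does not appear in this excerpt) and uses placeholder checkpoint and dataset ids; the example script above wires these up through CLI arguments instead:

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
from trl import OnlineDPOConfig, OnlineDPOTrainer

# Policy model and its tokenizer.
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

# Reward model with its own tokenizer/chat template, which this PR makes possible.
reward_model = AutoModelForSequenceClassification.from_pretrained("trl-lib/Qwen2-0.5B-Reward", num_labels=1)
reward_tokenizer = AutoTokenizer.from_pretrained("trl-lib/Qwen2-0.5B-Reward")

train_dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train")

trainer = OnlineDPOTrainer(
    model=model,
    reward_model=reward_model,
    args=OnlineDPOConfig(output_dir="online-dpo-output"),
    train_dataset=train_dataset,
    processing_class=tokenizer,
    reward_processing_class=reward_tokenizer,  # new in this PR
)
trainer.train()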