diff --git a/tests/test_data_collator_completion_only.py b/tests/test_data_collator_completion_only.py
index 96813b5a82..c895a616e1 100644
--- a/tests/test_data_collator_completion_only.py
+++ b/tests/test_data_collator_completion_only.py
@@ -21,14 +21,22 @@ class DataCollatorForCompletionOnlyLMTester(unittest.TestCase):
     def test_data_collator_finds_response_template_llama2_tokenizer(self):
+        # this should ideally be tested with meta-llama/Llama-2-7b-hf
         self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/dummy-GPT2-correct-vocab")
         self.instruction = """### System: You are a helpful assistant.

 ### User: How much is 2+2?

 ### Assistant: 2+2 equals 4"""
+        self.instruction_template = "\n### User:"
         self.response_template = "\n### Assistant:"

+        # GPT2Tokenizer: [198, 21017, 11787, 25] -> [11787, 25]
+        # Llama2Tokenizer: [29871, 13, 2277, 29937, 4911, 29901] -> [2277, 29937, 4911, 29901]
+        self.tokenized_instruction_w_context = self.tokenizer.encode(
+            self.instruction_template, add_special_tokens=False
+        )[2:]
+
         # GPT2Tokenizer: [198, 21017, 15286, 25] -> [15286, 25]
         # Llama2Tokenizer: [29871, 13, 2277, 29937, 4007, 22137, 29901] -> [2277, 29937, 4007, 22137, 29901]
         self.tokenized_response_w_context = self.tokenizer.encode(self.response_template, add_special_tokens=False)[2:]
@@ -42,6 +50,13 @@ def test_data_collator_finds_response_template_llama2_tokenizer(self):
         self.collator = DataCollatorForCompletionOnlyLM(self.tokenized_response_w_context, tokenizer=self.tokenizer)
         self.collator.torch_call([self.tokenized_instruction])

+        # Test for PR #749
+        # Pass already tokenized (w context) instruction and response both so token_ids are like in the instruction + response
+        self.collator = DataCollatorForCompletionOnlyLM(
+            self.tokenized_response_w_context, self.tokenized_instruction_w_context, tokenizer=self.tokenizer
+        )
+        self.collator.torch_call([self.tokenized_instruction])
+
     def test_data_collator_handling_of_long_sequences(self):
         self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/dummy-GPT2-correct-vocab")
         self.instruction = """### System: You are a helpful assistant.
diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py
index 451552f14b..fa58fd3374 100644
--- a/trl/trainer/utils.py
+++ b/trl/trainer/utils.py
@@ -75,21 +75,31 @@ class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling):
     def __init__(
         self,
         response_template: Union[str, List[int]],
-        instruction_template: Optional[str] = None,
+        instruction_template: Union[str, List[int]] = None,
         *args,
         mlm: bool = False,
         ignore_index: int = -100,
         **kwargs,
     ):
         super().__init__(*args, mlm=mlm, **kwargs)

+        self.instruction_template = instruction_template
+        if isinstance(instruction_template, str):
+            # The user provides a string, must tokenize
+            self.instruction_token_ids = self.tokenizer.encode(self.instruction_template, add_special_tokens=False)
+        else:
+            # The user already provides the token ids
+            self.instruction_token_ids = instruction_template
+
         self.response_template = response_template
-        self.ignore_index = ignore_index
-        if type(response_template) == list:
+        if isinstance(response_template, str):
+            # The user provides a string, must tokenize
+            self.response_token_ids = self.tokenizer.encode(self.response_template, add_special_tokens=False)
+        else:
             # The user already provides the token ids
             self.response_token_ids = response_template
-        else:
-            self.response_token_ids = self.tokenizer.encode(self.response_template, add_special_tokens=False)
+
+        self.ignore_index = ignore_index

     def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
         batch = super().torch_call(examples)
@@ -142,7 +152,7 @@ def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> D
                 )
                 batch["labels"][i, :] = self.ignore_index

-            human_token_ids = self.tokenizer.encode(self.instruction_template, add_special_tokens=False)
+            human_token_ids = self.instruction_token_ids
             for human_idx in np.where(batch["labels"][i] == human_token_ids[0])[0]:
                 # find the indexes of the start of a human answer.
                 if human_token_ids == batch["labels"][i][human_idx : human_idx + len(human_token_ids)].tolist():
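
Usage sketch (not part of the patch): the change allows instruction_template, like response_template, to be passed as pre-computed token ids. This mirrors the new test case above; the tokenizer checkpoint and the [2:] slicing are taken from that test and are illustrative only.

# Minimal sketch, assuming a TRL build with this patch applied.
from transformers import AutoTokenizer
from trl import DataCollatorForCompletionOnlyLM

tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/dummy-GPT2-correct-vocab")

# Context-sensitive tokenizers (e.g. Llama 2's) prepend extra tokens when a
# template is encoded in isolation; dropping them ([2:]) and passing the raw
# ids lets the collator find the template inside a full prompt.
instruction_ids = tokenizer.encode("\n### User:", add_special_tokens=False)[2:]
response_ids = tokenizer.encode("\n### Assistant:", add_special_tokens=False)[2:]

collator = DataCollatorForCompletionOnlyLM(
    response_ids, instruction_ids, tokenizer=tokenizer
)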