Allow passing the token_ids as instruction_template in DataCollatorForCompletionOnlyLM (#749)

* Update utils.py

* correctly assign instruction_template in DataCollatorForCompletionOnlyLM

* correctly use instruction_token_ids in DataCollatorForCompletionOnlyLM

* DataCollatorForCompletionOnlyLM: fix instruction_template / response_template type check: handle cases where instruction_template is None

* make precommit

* Test DataCollatorForCompletionOnlyLM with pre-tokenized instruction_template
devxpy authored Sep 26, 2023
1 parent 92b03f5 commit d608fea
Showing 2 changed files with 31 additions and 6 deletions.
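
For context, a minimal sketch of the usage this commit enables: passing both templates to the collator as pre-tokenized ids instead of strings. This assumes a trl release that includes this change; the tokenizer name and the [2:] slicing mirror the new test below, and the exact ids to keep depend on the tokenizer.

from transformers import AutoTokenizer
from trl import DataCollatorForCompletionOnlyLM

tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/dummy-GPT2-correct-vocab")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2-style tokenizers define no pad token by default

instruction_template = "\n### User:"
response_template = "\n### Assistant:"

# Encode each template once and keep only the ids that also appear when the
# template occurs inside a longer prompt (see the token-id comments in the test below).
instruction_ids = tokenizer.encode(instruction_template, add_special_tokens=False)[2:]
response_ids = tokenizer.encode(response_template, add_special_tokens=False)[2:]

# Both response_template and instruction_template may now be lists of token ids.
collator = DataCollatorForCompletionOnlyLM(
    response_ids, instruction_template=instruction_ids, tokenizer=tokenizer
)

text = """### System: You are a helpful assistant.

### User: How much is 2+2?

### Assistant: 2+2 equals 4"""
example = tokenizer(text, add_special_tokens=False)["input_ids"]
batch = collator.torch_call([example])  # labels outside the assistant completion become ignore_index (-100)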
15 changes: 15 additions & 0 deletions tests/test_data_collator_completion_only.py
@@ -21,14 +21,22 @@

 class DataCollatorForCompletionOnlyLMTester(unittest.TestCase):
     def test_data_collator_finds_response_template_llama2_tokenizer(self):
         # this should ideally be tested with meta-llama/Llama-2-7b-hf
         self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/dummy-GPT2-correct-vocab")
         self.instruction = """### System: You are a helpful assistant.

### User: How much is 2+2?

### Assistant: 2+2 equals 4"""
         self.instruction_template = "\n### User:"
         self.response_template = "\n### Assistant:"

+        # GPT2Tokenizer: [198, 21017, 11787, 25] -> [11787, 25]
+        # Llama2Tokenizer: [29871, 13, 2277, 29937, 4911, 29901] -> [2277, 29937, 4911, 29901]
+        self.tokenized_instruction_w_context = self.tokenizer.encode(
+            self.instruction_template, add_special_tokens=False
+        )[2:]
+
+        # GPT2Tokenizer: [198, 21017, 15286, 25] -> [15286, 25]
+        # Llama2Tokenizer: [29871, 13, 2277, 29937, 4007, 22137, 29901] -> [2277, 29937, 4007, 22137, 29901]
         self.tokenized_response_w_context = self.tokenizer.encode(self.response_template, add_special_tokens=False)[2:]
@@ -42,6 +50,13 @@ def test_data_collator_finds_response_template_llama2_tokenizer(self):
         self.collator = DataCollatorForCompletionOnlyLM(self.tokenized_response_w_context, tokenizer=self.tokenizer)
         self.collator.torch_call([self.tokenized_instruction])

+        # Test for PR #749
+        # Pass already tokenized (w context) instruction and response both so token_ids are like in the instruction + response
+        self.collator = DataCollatorForCompletionOnlyLM(
+            self.tokenized_response_w_context, self.tokenized_instruction_w_context, tokenizer=self.tokenizer
+        )
+        self.collator.torch_call([self.tokenized_instruction])
+
     def test_data_collator_handling_of_long_sequences(self):
         self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/dummy-GPT2-correct-vocab")
         self.instruction = """### System: You are a helpful assistant.
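
The [2:] slices above keep only the tail of each encoded template, because some tokenizers (notably SentencePiece-based ones such as Llama-2's, whose leading 29871/13 pieces appear in the comments) encode a template differently on its own than inside a longer prompt. A small snippet reproducing the GPT-2 side of those comments; the Llama-2 ids require the gated meta-llama/Llama-2-7b-hf tokenizer, hence the dummy model here:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/dummy-GPT2-correct-vocab")

full = tokenizer.encode("\n### User:", add_special_tokens=False)
print(full)      # expected per the test comment: [198, 21017, 11787, 25]
print(full[2:])  # [11787, 25] -- the ids actually handed to the collator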
22 changes: 16 additions & 6 deletions trl/trainer/utils.py
@@ -75,21 +75,31 @@ class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling):
     def __init__(
         self,
         response_template: Union[str, List[int]],
-        instruction_template: Optional[str] = None,
+        instruction_template: Union[str, List[int]] = None,
         *args,
         mlm: bool = False,
         ignore_index: int = -100,
         **kwargs,
     ):
         super().__init__(*args, mlm=mlm, **kwargs)

         self.instruction_template = instruction_template
+        if isinstance(instruction_template, str):
+            # The user provides a string, must tokenize
+            self.instruction_token_ids = self.tokenizer.encode(self.instruction_template, add_special_tokens=False)
+        else:
+            # The user already provides the token ids
+            self.instruction_token_ids = instruction_template
+
         self.response_template = response_template
-        self.ignore_index = ignore_index
-        if type(response_template) == list:
+        if isinstance(response_template, str):
+            # The user provides a string, must tokenize
+            self.response_token_ids = self.tokenizer.encode(self.response_template, add_special_tokens=False)
+        else:
+            # The user already provides the token ids
             self.response_token_ids = response_template
-        else:
-            self.response_token_ids = self.tokenizer.encode(self.response_template, add_special_tokens=False)
+
+        self.ignore_index = ignore_index

     def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
         batch = super().torch_call(examples)
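
The constructor change boils down to one pattern, applied to both templates: accept either a string (tokenize it) or a ready-made list of token ids (use it as-is, with None falling through the same branch). An illustrative standalone helper, not part of trl, that mirrors those branches:

from typing import List, Optional, Union

from transformers import PreTrainedTokenizerBase


def resolve_template_ids(
    template: Optional[Union[str, List[int]]],
    tokenizer: PreTrainedTokenizerBase,
) -> Optional[List[int]]:
    """Mirror of the branches above: strings are tokenized without special
    tokens; token-id lists (and None) are returned unchanged."""
    if isinstance(template, str):
        # The user provides a string, must tokenize
        return tokenizer.encode(template, add_special_tokens=False)
    # The user already provides the token ids (or no template at all)
    return template

With this shape, a pre-tokenized instruction_template is resolved once in the constructor instead of being re-encoded, which is what the torch_call change below relies on.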
@@ -142,7 +152,7 @@ def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
                     )
                     batch["labels"][i, :] = self.ignore_index

-                human_token_ids = self.tokenizer.encode(self.instruction_template, add_special_tokens=False)
+                human_token_ids = self.instruction_token_ids
                 for human_idx in np.where(batch["labels"][i] == human_token_ids[0])[0]:
                     # find the indexes of the start of a human answer.
                     if human_token_ids == batch["labels"][i][human_idx : human_idx + len(human_token_ids)].tolist():
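
With instruction_token_ids precomputed, torch_call no longer re-encodes the template string, which would fail when the template was passed as token ids. The surrounding loop locates each occurrence of those ids by matching the first token and then comparing the full window; a self-contained sketch of that pattern, using hypothetical names:

from typing import List

import numpy as np


def find_template_starts(labels: List[int], template_ids: List[int]) -> List[int]:
    """Return the start index of every occurrence of template_ids in labels,
    using the same first-token-then-window comparison as the collator."""
    labels_arr = np.array(labels)
    starts = []
    for idx in np.where(labels_arr == template_ids[0])[0]:
        # Only a full window match counts as an occurrence of the template.
        if template_ids == labels_arr[idx : idx + len(template_ids)].tolist():
            starts.append(int(idx))
    return starts


# e.g. find_template_starts([5, 11787, 25, 9, 11787, 25], [11787, 25]) -> [1, 4]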
