feat(trainer): Support multi-role & consecutive turns in DataCollatorForCompletionOnlyLM (#3223) #3224

Open · wants to merge 1 commit into main
331 changes: 331 additions & 0 deletions tests/test_data_collator_completion_only.py
@@ -20,6 +20,124 @@
from trl import DataCollatorForCompletionOnlyLM


# Define samples globally for reuse
CHATML_SAMPLE_BASIC_MULTI_TURN = """<|im_start|>system
system prompt system prompt system prompt
<|im_end|>
<|im_start|>user
U U U U U<|im_end|>
<|im_start|>assistant
A A A A A<|im_end|>
<|im_start|>user
U U U U U<|im_end|>
<|im_start|>assistant
A A A A A<|im_end|>
<|im_start|>user
T T T T T<|im_end|>
<|im_start|>user
T T T T T<|im_end|>
<|im_start|>user
T T T T T<|im_end|>
<|im_start|>assistant
A A A A A<|im_end|>
<|im_start|>user
T T T T T<|im_end|>
<|im_start|>assistant
A A A A A<|im_end|>""" # 4 assistant turns

CHATML_SAMPLE_MULTI_ROLE_MULTI_TURN = """<|im_start|>system
system prompt system prompt system prompt
<|im_end|>
<|im_start|>user
U U U U U<|im_end|>
<|im_start|>assistant
A A A A A<|im_end|>
<|im_start|>tool
T T T T T<|im_end|>
<|im_start|>assistant
A A A A A<|im_end|>
<|im_start|>tool
T T T T T<|im_end|>
<|im_start|>assistant
A A A A A<|im_end|>
<|im_start|>tool
T T T T T<|im_end|>
<|im_start|>assistant
A A A A A<|im_end|>
<|im_start|>user
U U U U U<|im_end|>
<|im_start|>assistant
A A A A A<|im_end|>
<|im_start|>tool
T T T T T<|im_end|>
<|im_start|>assistant
A A A A A<|im_end|>
<|im_start|>user
U U U U U<|im_end|>
<|im_start|>assistant
A A A A A<|im_end|>
<|im_start|>tool
T T T T T<|im_end|>
<|im_start|>tool
T T T T T<|im_end|>
<|im_start|>assistant
A A A A A<|im_end|>""" # 8 assistant turns

CHATML_SAMPLE_CONSECUTIVE_ASSISTANT_MULTI_ROLE = """<|im_start|>system
system prompt system prompt system prompt
<|im_end|>
<|im_start|>user
U U U U U<|im_end|>
<|im_start|>assistant
A A A A A<|im_end|>
<|im_start|>assistant
A A A A A<|im_end|>
<|im_start|>assistant
A A A A A<|im_end|>
<|im_start|>tool
T T T T T<|im_end|>
<|im_start|>assistant
A A A A A<|im_end|>
<|im_start|>user
U U U U U<|im_end|>
<|im_start|>assistant
A A A A A<|im_end|>
<|im_start|>tool
T T T T T<|im_end|>
<|im_start|>assistant
A A A A A<|im_end|>
<|im_start|>user
U U U U U<|im_end|>
<|im_start|>assistant
A A A A A<|im_end|>
<|im_start|>tool
T T T T T<|im_end|>
<|im_start|>tool
T T T T T<|im_end|>
<|im_start|>assistant
A A A A A<|im_end|>""" # 8 assistant turns

CHATML_SAMPLE_CONSECUTIVE_ASSISTANT_SIMPLE = """<|im_start|>system
Prompt.
<|im_end|>
<|im_start|>user
User query.<|im_end|>
<|im_start|>assistant
Assistant response 1.<|im_end|>
<|im_start|>assistant
Assistant response 2.<|im_end|>
<|im_start|>user
Another user query.<|im_end|>
<|im_start|>assistant
Assistant response 3.<|im_end|>""" # 3 assistant turns total, 2 consecutive

# Expected decoded output for a single assistant turn based on the samples above
EXPECTED_DECODED_ASSISTANT_CHUNK = "A A A A A<|im_end|>\n"


class DataCollatorForCompletionOnlyLMTester(unittest.TestCase):
def test_data_collator_finds_response_template_llama2_tokenizer(self):
# this should ideally be tested with meta-llama/Llama-2-7b-hf
@@ -167,3 +285,216 @@ def test_data_collator_for_completion_only_lm(self):
self.assertEqual(batch["cu_seq_lens_k"].tolist(), [[0, 6, 13]]) # idem
self.assertEqual(batch["max_length_k"], torch.tensor([7])) # max length in batch, here 7 (second sequence)
self.assertEqual(batch["max_length_q"], torch.tensor([7])) # idem

def test_masking_basic_multi_turn(self):
tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
if tokenizer.pad_token_id is None:
tokenizer.pad_token = tokenizer.eos_token

instruction_template = "<|im_start|>user\n"
response_template = "<|im_start|>assistant\n"

data_collator = DataCollatorForCompletionOnlyLM(
instruction_template=instruction_template,
response_template=response_template,
tokenizer=tokenizer,
mlm=False,
)

conversations = [
CHATML_SAMPLE_BASIC_MULTI_TURN,
CHATML_SAMPLE_BASIC_MULTI_TURN,
] # Batch of 2 identical samples
tokenized = tokenizer(conversations, add_special_tokens=False)

# Prepare input for collator in the typical dictionary format
batch_input = [
{"input_ids": tokenized.input_ids[i], "attention_mask": tokenized.attention_mask[i]}
for i in range(len(tokenized.input_ids))
]
collated_batch = data_collator(batch_input)

# Expected output: 4 assistant turns per sample
expected_decoded_output = EXPECTED_DECODED_ASSISTANT_CHUNK * 4

# Check labels for each sample in the batch
for i in range(len(collated_batch["labels"])):
valid_indices = collated_batch["labels"][i] != -100
valid_labels = collated_batch["labels"][i][valid_indices]
decoded_text = tokenizer.decode(valid_labels, skip_special_tokens=False)
# strip potential leading/trailing whitespace artefacts from decode
self.assertEqual(
decoded_text.strip(), expected_decoded_output.strip(), f"Mismatch in decoded labels for sample {i}"
)

def test_masking_multi_role_multi_template(self):
tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
if tokenizer.pad_token_id is None:
tokenizer.pad_token = tokenizer.eos_token

# Use a list for multiple instruction templates
instruction_template = ["<|im_start|>tool\n", "<|im_start|>user\n"]
response_template = "<|im_start|>assistant\n"

data_collator = DataCollatorForCompletionOnlyLM(
instruction_template=instruction_template,
response_template=response_template,
tokenizer=tokenizer,
mlm=False,
)

conversations = [CHATML_SAMPLE_MULTI_ROLE_MULTI_TURN, CHATML_SAMPLE_CONSECUTIVE_ASSISTANT_MULTI_ROLE]
tokenized = tokenizer(conversations, add_special_tokens=False)

batch_input = [
{"input_ids": tokenized.input_ids[i], "attention_mask": tokenized.attention_mask[i]}
for i in range(len(tokenized.input_ids))
]
collated_batch = data_collator(batch_input)

# Expected outputs based on the number of assistant turns
expected_outputs = [
EXPECTED_DECODED_ASSISTANT_CHUNK * 8, # CHATML_SAMPLE_MULTI_ROLE_MULTI_TURN has 8 assistant turns
EXPECTED_DECODED_ASSISTANT_CHUNK
* 8, # CHATML_SAMPLE_CONSECUTIVE_ASSISTANT_MULTI_ROLE has 8 assistant turns
]

# Check labels for each sample in the batch
self.assertEqual(len(collated_batch["labels"]), len(expected_outputs), "Batch size mismatch")

for i in range(len(collated_batch["labels"])):
valid_indices = collated_batch["labels"][i] != -100
valid_labels = collated_batch["labels"][i][valid_indices]
decoded_text = tokenizer.decode(valid_labels, skip_special_tokens=False)
self.assertEqual(
decoded_text.strip(), expected_outputs[i].strip(), f"Mismatch in decoded labels for sample {i}"
)

def test_masking_consecutive_assistant(self):
tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
if tokenizer.pad_token_id is None:
tokenizer.pad_token = tokenizer.eos_token

instruction_template = ["<|im_start|>tool\n", "<|im_start|>user\n"]
response_template = "<|im_start|>assistant\n"

data_collator = DataCollatorForCompletionOnlyLM(
instruction_template=instruction_template,
response_template=response_template,
tokenizer=tokenizer,
mlm=False,
)

tokenized = tokenizer([CHATML_SAMPLE_CONSECUTIVE_ASSISTANT_SIMPLE], add_special_tokens=False)
batch_input = [
{"input_ids": tokenized.input_ids[i], "attention_mask": tokenized.attention_mask[i]}
for i in range(len(tokenized.input_ids))
]
collated_batch = data_collator(batch_input)

# Expected: only the content *after* the response_template should be unmasked for every assistant turn.
# Consecutive assistant turns are handled by unmasking from each response start up to the *next*
# instruction (or the end of the sequence), while the response template tokens between them stay masked.
expected_decoded_output = (
"Assistant response 1.<|im_end|>\nAssistant response 2.<|im_end|>\nAssistant response 3.<|im_end|>\n"
)

valid_indices = collated_batch["labels"][0] != -100
valid_labels = collated_batch["labels"][0][valid_indices]
decoded_text = tokenizer.decode(valid_labels, skip_special_tokens=False)
self.assertEqual(
decoded_text.strip(),
expected_decoded_output.strip(),
"Mismatch in decoded labels for consecutive assistant test",
)

def test_masking_left_padding(self):
tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
# Explicitly set left padding
tokenizer.padding_side = "left"
if tokenizer.pad_token_id is None:
tokenizer.pad_token = tokenizer.eos_token

instruction_template = ["<|im_start|>tool\n", "<|im_start|>user\n"]
response_template = "<|im_start|>assistant\n"

data_collator = DataCollatorForCompletionOnlyLM(
instruction_template=instruction_template,
response_template=response_template,
tokenizer=tokenizer,
mlm=False,
)

conversations = [CHATML_SAMPLE_MULTI_ROLE_MULTI_TURN, CHATML_SAMPLE_BASIC_MULTI_TURN]
tokenized = tokenizer(conversations, add_special_tokens=False, padding=True, truncation=True, max_length=512)

batch_input = [
{"input_ids": tokenized.input_ids[i], "attention_mask": tokenized.attention_mask[i]}
for i in range(len(tokenized.input_ids))
]
collated_batch = data_collator(batch_input)

# Expected outputs based on the number of assistant turns in the specific samples used
expected_outputs = [
EXPECTED_DECODED_ASSISTANT_CHUNK * 8, # CHATML_SAMPLE_MULTI_ROLE_MULTI_TURN
EXPECTED_DECODED_ASSISTANT_CHUNK * 4, # CHATML_SAMPLE_BASIC_MULTI_TURN
]

self.assertEqual(len(collated_batch["labels"]), len(expected_outputs), "Batch size mismatch")

for i in range(len(collated_batch["labels"])):
valid_indices = collated_batch["labels"][i] != -100
valid_labels = collated_batch["labels"][i][valid_indices]
decoded_text = tokenizer.decode(valid_labels, skip_special_tokens=False)
self.assertEqual(
decoded_text.strip(),
expected_outputs[i].strip(),
f"Mismatch in decoded labels for left padding, sample {i}",
)

def test_masking_tokenized_templates(self):
tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
if tokenizer.pad_token_id is None:
tokenizer.pad_token = tokenizer.eos_token

# Pre-tokenize the templates
instruction_templates_str = ["<|im_start|>tool\n", "<|im_start|>user\n"]
response_template_str = "<|im_start|>assistant\n"

instruction_token_ids = [
tokenizer.encode(tmpl, add_special_tokens=False) for tmpl in instruction_templates_str
]
response_token_ids = tokenizer.encode(response_template_str, add_special_tokens=False)

data_collator = DataCollatorForCompletionOnlyLM(
instruction_template=instruction_token_ids, # Pass List[List[int]]
response_template=response_token_ids, # Pass List[int]
tokenizer=tokenizer,
mlm=False,
)

conversations = [CHATML_SAMPLE_MULTI_ROLE_MULTI_TURN, CHATML_SAMPLE_CONSECUTIVE_ASSISTANT_MULTI_ROLE]
tokenized = tokenizer(conversations, add_special_tokens=False, padding=True, truncation=True, max_length=512)

batch_input = [
{"input_ids": tokenized.input_ids[i], "attention_mask": tokenized.attention_mask[i]}
for i in range(len(tokenized.input_ids))
]
collated_batch = data_collator(batch_input)

# Expected outputs based on the number of assistant turns
expected_outputs = [
EXPECTED_DECODED_ASSISTANT_CHUNK * 8,
EXPECTED_DECODED_ASSISTANT_CHUNK * 8,
]

self.assertEqual(len(collated_batch["labels"]), len(expected_outputs), "Batch size mismatch")

for i in range(len(collated_batch["labels"])):
valid_indices = collated_batch["labels"][i] != -100
valid_labels = collated_batch["labels"][i][valid_indices]
decoded_text = tokenizer.decode(valid_labels, skip_special_tokens=False)
self.assertEqual(
decoded_text.strip(),
expected_outputs[i].strip(),
f"Mismatch in decoded labels for tokenized templates, sample {i}",
)
128 changes: 96 additions & 32 deletions trl/trainer/utils.py
@@ -77,19 +77,25 @@ class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling):
response_template (`Union[str, list[int]]`): the template form that indicates the start of the response, typically something like
'### Response:\n'. It can also be passed as tokenized ids, which can be useful when using a tokenizer that encodes the response
differently if it does not have proper context.
instruction_template (`Union[str, list[int]]`): the template form that indicates the start of the human instruction, typically something like
'### Human:\n'. Useful for assistant-style conversation datasets. It can also be passed as tokenized ids.
mlm (`bool`, *optional*, defaults to `False`): Whether to use masked language modeling in the underlying
instruction_template (`Union[str, list[int], list[str]]`, *optional*, defaults to `None`):
The template form that indicates the start of the human instruction, typically something like
'### Human:\n'. Useful for assistant-style conversation datasets. It can also be passed as tokenized ids
or as a list of strings when multiple instruction templates need to be detected (useful for multi-role conversations, e.g. `["<system>", "<tool>", "<user>"]`).
mlm (`bool`, *optional*, defaults to `False`):
Whether to use masked language modeling in the underlying
`DataCollatorForLanguageModeling` class. Note that this option currently has no effect but is present
for flexibility and backwards-compatibility.
ignore_index (`int`, *optional*, defaults to `-100`):
The index to use to ignore the initial tokens with
padding_free (`bool`, *optional*, defaults to `False`):
Whether to use padding-free training. When set to True, padding tokens are removed and positional ids are
added to the inputs to enable proper attention.
"""

def __init__(
self,
response_template: Union[str, list[int]],
instruction_template: Optional[Union[str, list[int]]] = None,
instruction_template: Optional[Union[str, list[int], list[str]]] = None,
*args,
mlm: bool = False,
ignore_index: int = -100,
@@ -99,12 +105,27 @@ def __init__(
super().__init__(*args, mlm=mlm, **kwargs)

self.instruction_template = instruction_template
self.has_multiple_instruction_templates = False

if isinstance(instruction_template, str):
# The user provides a string, must tokenize
self.instruction_token_ids = self.tokenizer.encode(self.instruction_template, add_special_tokens=False)
elif isinstance(instruction_template, list) and isinstance(instruction_template[0], str):
# The user provides a list of strings, must tokenize each template
self.instruction_token_ids = []
for template in self.instruction_template:
self.instruction_token_ids.append(self.tokenizer.encode(template, add_special_tokens=False))
self.has_multiple_instruction_templates = True
else:
# The user already provides the token ids
self.instruction_token_ids = instruction_template
# Check if it's a list of lists (multiple templates)
if (
isinstance(instruction_template, list)
and instruction_template
and isinstance(instruction_template[0], list)
):
self.has_multiple_instruction_templates = True

self.response_template = response_template
if isinstance(response_template, str):
@@ -129,6 +150,12 @@ def __init__(
def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]:
batch = super().torch_call(examples)

sequence_lengths = (batch["input_ids"] != self.tokenizer.pad_token_id).sum(dim=1)
content_starts = (
batch["input_ids"].shape[1] - sequence_lengths
if self.tokenizer.padding_side == "left"
else torch.zeros_like(sequence_lengths)
)
if self.instruction_template is None:
for i in range(len(examples)):
response_token_ids_start_idx = None
@@ -157,57 +184,94 @@ def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> d

else:
for i in range(len(examples)):
response_token_ids_idxs = []
human_token_ids_idxs = []
response_start_positions = []
instruction_start_positions = []

for assistant_idx in np.where(batch["labels"][i] == self.response_token_ids[0])[0]:
# find the indexes of the start of a response.
if (
self.response_token_ids
== batch["labels"][i][assistant_idx : assistant_idx + len(self.response_token_ids)].tolist()
):
response_token_ids_idxs.append(assistant_idx + len(self.response_token_ids))
response_start_positions.append(assistant_idx + len(self.response_token_ids))

if len(response_token_ids_idxs) == 0:
if len(response_start_positions) == 0:
warnings.warn(
f"Could not find response key `{self.response_template}` in the following instance: "
f"{self.tokenizer.decode(batch['input_ids'][i])}. This instance will be ignored in loss "
"calculation. Note, if this happens often, consider increasing the `max_length`.",
UserWarning,
)
batch["labels"][i, :] = self.ignore_index

human_token_ids = self.instruction_token_ids
for human_idx in np.where(batch["labels"][i] == human_token_ids[0])[0]:
# find the indexes of the start of a human answer.
if human_token_ids == batch["labels"][i][human_idx : human_idx + len(human_token_ids)].tolist():
human_token_ids_idxs.append(human_idx)

if len(human_token_ids_idxs) == 0:
continue

# Find all instruction token positions
if self.has_multiple_instruction_templates:
# Handle multiple instruction templates
for instruction_token_ids in self.instruction_token_ids:
for instruction_idx in np.where(batch["labels"][i] == instruction_token_ids[0])[0]:
if (
instruction_token_ids
== batch["labels"][i][
instruction_idx : instruction_idx + len(instruction_token_ids)
].tolist()
):
instruction_start_positions.append(instruction_idx)
instruction_start_positions = sorted(instruction_start_positions)
else:
instruction_token_ids = self.instruction_token_ids
for instruction_idx in np.where(batch["labels"][i] == instruction_token_ids[0])[0]:
# find the indexes of the start of an instruction.
if (
instruction_token_ids
== batch["labels"][i][
instruction_idx : instruction_idx + len(instruction_token_ids)
].tolist()
):
instruction_start_positions.append(instruction_idx)

if len(instruction_start_positions) == 0:
warnings.warn(
f"Could not find instruction key `{self.instruction_template}` in the following instance: "
f"{self.tokenizer.decode(batch['input_ids'][i])}. This instance will be ignored in loss "
"calculation. Note, if this happens often, consider increasing the `max_length`.",
UserWarning,
)
batch["labels"][i, :] = self.ignore_index

if (
len(human_token_ids_idxs) > 0
and len(response_token_ids_idxs) > 0
and human_token_ids_idxs[0] > response_token_ids_idxs[0]
):
human_token_ids_idxs = [0] + human_token_ids_idxs

for idx, (start, end) in enumerate(zip(human_token_ids_idxs, response_token_ids_idxs)):
# Make pytorch loss function ignore all non response tokens
if idx != 0:
batch["labels"][i, start:end] = self.ignore_index
continue

# Mask everything first and we will unmask step by step
batch["labels"][i, :] = self.ignore_index

# Unmask regions between each response and next instruction (or till end)
sequence_length = sequence_lengths[i].item()
content_start = content_starts[i].item()
last_processed_instruction_pos = -1
for response_pos in response_start_positions:
# Find the first instruction position that comes after this response
next_instruction_pos = None
for instruction_pos in instruction_start_positions:
if instruction_pos > response_pos:
next_instruction_pos = instruction_pos
break

# If no instruction position found after response, use sequence length from input_ids
if next_instruction_pos is None:
# Calculate actual sequence length using pad token positions
next_instruction_pos = content_start + sequence_length

# Handle consecutive responses
if response_pos > last_processed_instruction_pos:
# Unmask from response start to instruction start (or end); base case
batch["labels"][i, response_pos:next_instruction_pos] = batch["input_ids"][
i, response_pos:next_instruction_pos
]
last_processed_instruction_pos = next_instruction_pos
else:
batch["labels"][i, :end] = self.ignore_index

if len(response_token_ids_idxs) < len(human_token_ids_idxs):
batch["labels"][i, human_token_ids_idxs[-1] :] = self.ignore_index
# Two responses in a row, so mask the response template tokens between them
batch["labels"][i, response_pos - len(self.response_token_ids) : response_pos] = (
self.ignore_index
)

if self.padding_free:
# remove padding, `attention_mask` and add `position_ids`
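As a side note on the left-padding handling added at the top of `torch_call`, the following self-contained sketch mirrors the `sequence_lengths` / `content_starts` arithmetic on made-up tensors (the values are assumptions for illustration only):

```python
import torch

pad_token_id = 0
# Two left-padded sequences of length 6; real content lengths are 4 and 6.
input_ids = torch.tensor(
    [
        [0, 0, 5, 6, 7, 8],
        [3, 4, 5, 6, 7, 8],
    ]
)

# Number of non-pad tokens per row.
sequence_lengths = (input_ids != pad_token_id).sum(dim=1)  # tensor([4, 6])

padding_side = "left"
content_starts = (
    input_ids.shape[1] - sequence_lengths
    if padding_side == "left"
    else torch.zeros_like(sequence_lengths)
)  # tensor([2, 0])

# Real tokens for row i live in [content_starts[i], content_starts[i] + sequence_lengths[i]),
# which is the fallback span end the collator uses when no instruction follows a response.
print(content_starts.tolist(), sequence_lengths.tolist())  # [2, 0] [4, 6]
```

With right padding, `content_starts` is all zeros and the fallback end is simply the unpadded sequence length.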