
Correct masking when the same roles are present in adjacent messages in DataCollatorForCompletionOnlyLM #1994


Open
lsy641 opened this issue Aug 30, 2024 · 5 comments
Labels
🗃️ data Related to data ✨ enhancement New feature or request

Comments

@lsy641

lsy641 commented Aug 30, 2024

Feature request

In the torch_call function of DataCollatorForCompletionOnlyLM, the proposed feature would support correct masking of user messages even when user and assistant messages do not strictly alternate.

The current version requires that every assistant message follow a user message and every user message follow an assistant message.
Two adjacent messages with the same role cause incorrect masking, because the current code does not account for a large start index being paired with a smaller end index when the two roles do not strictly take turns:

for idx, (start, end) in enumerate(zip(human_token_ids_idxs, response_token_ids_idxs)):
    # Make pytorch loss function ignore all non response tokens
    if idx != 0:
        batch["labels"][i, start:end] = self.ignore_index
    else:
        batch["labels"][i, :end] = self.ignore_index

if len(response_token_ids_idxs) < len(human_token_ids_idxs):
    batch["labels"][i, human_token_ids_idxs[-1]:] = self.ignore_index

Using two pointers solves the issue; below is an example solution:

# Build test cases for response_token_ids_idxs and human_token_ids_idxs
response_token_ids_idxs = [1, 4, 6, 7, 8, 9, 15, 36, 57, 88, 89, 200]
human_token_ids_idxs = [2, 5, 12, 13, 56, 66, 90, 199, 201, 202]
pointer_human = 0
pointer_response = 0
mask_start = -1
mask_end = -1

while pointer_response <= len(response_token_ids_idxs) - 1 and pointer_human <= len(human_token_ids_idxs) - 1:
    if mask_start == -1:
        mask_start = 0 if response_token_ids_idxs[0] != 0 else human_token_ids_idxs[pointer_human]
    if mask_end == -1:
        mask_end = response_token_ids_idxs[0]
    if response_token_ids_idxs[pointer_response] > human_token_ids_idxs[pointer_human]:
        if mask_end < mask_start:
            mask_end = response_token_ids_idxs[pointer_response]
        pointer_human += 1      
    elif response_token_ids_idxs[pointer_response] < human_token_ids_idxs[pointer_human]:
        if mask_start < mask_end:
            print(mask_start, "~", mask_end)  # will be replaced with batch["labels"][i, mask_start:mask_end] = self.ignore_index in the pull request
            mask_start = human_token_ids_idxs[pointer_human]
        pointer_response += 1
    else:
        raise Exception("response_token_id and human_token_id cannot be the same. Please check your response and human template ids")
if pointer_human <= len(human_token_ids_idxs) - 1:
    while pointer_human <= len(human_token_ids_idxs) - 1 and human_token_ids_idxs[pointer_human] < mask_end:
        pointer_human += 1
    if pointer_human <= len(human_token_ids_idxs) - 1:
        print(human_token_ids_idxs[pointer_human], "~", "end")  # will be replaced with batch["labels"][i, human_token_ids_idxs[pointer_human]:] = self.ignore_index in the pull request

Running this code produces:

### output
0 ~ 1
2 ~ 4
5 ~ 6
12 ~ 15
56 ~ 57
66 ~ 88
90 ~ 200
201 ~ end
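
The printed ranges map one-to-one onto label masking. Below is a minimal sketch of applying them, with a hypothetical apply_mask_ranges helper standing in for the collator's assignments (a None end means "mask to the end of the sequence"):

import torch

# Hypothetical helper mirroring batch["labels"][i, mask_start:mask_end] = self.ignore_index
def apply_mask_ranges(labels, ranges, ignore_index=-100):
    for mask_start, mask_end in ranges:
        labels[mask_start:mask_end] = ignore_index  # mask_end=None masks to the end
    return labels

labels = torch.arange(210)  # stand-in for one row of batch["labels"]
ranges = [(0, 1), (2, 4), (5, 6), (12, 15), (56, 57), (66, 88), (90, 200), (201, None)]
apply_mask_ranges(labels, ranges)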

Motivation

Support flexible and correct masking strategies for DataCollatorForCompletionOnlyLM, in particular correct masking when consecutive messages come from the same role.

Your contribution

I have submitted a PR: #2000

@qgallouedec
Member

The current version requires that every assistant message follow a user message and every user message follow an assistant message.

I'm not sure why we would want to have a dataset in which the roles are not interleaved. Moreover, some chat templates explicitly assume that messages are an interleaving of user and assistant messages.
Do you have an example?

@lsy641
Author

lsy641 commented Sep 16, 2024

We encountered this problem because we wanted to fine-tune models on real human conversations. In natural conversation, it is common for one utterance to be followed by another from the same speaker after a pause, for example in counseling conversations.
I think not all researchers aim to build an LLM as an AI assistant; if that were the goal, I agree that interleaved roles alone would suffice.
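
For example, a counseling-style sample (hypothetical data) where the same speaker sends two messages in a row:

# Hypothetical sample: the user speaks twice in a row, so roles do not alternate.
messages = [
    {"role": "user", "content": "I've been feeling overwhelmed lately."},
    {"role": "user", "content": "Mostly because of work, I think."},
    {"role": "assistant", "content": "That sounds hard. What about work feels heaviest right now?"},
]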

@lsy641
Author

lsy641 commented Sep 16, 2024

Another scenario is NPC dialogue in games: when the player doesn't interrupt, the LLM may keep speaking in the same role. @qgallouedec

@qgallouedec
Member

Thank you very much for the clarification. We are currently working on a new dataset format that could be related (though with a different motivation). See #2148

@qgallouedec qgallouedec added ✨ enhancement New feature or request 🗃️ data Related to data labels Oct 20, 2024
@lsy641
Author

lsy641 commented Oct 31, 2024

Thank you!

Kirili4ik added a commit to Kirili4ik/trl that referenced this issue Apr 3, 2025
…ForCompletionOnlyLM (huggingface#3223)

Refactors the masking logic in `DataCollatorForCompletionOnlyLM` to correctly handle conversations with multiple instruction roles (e.g., user, tool) and consecutive assistant turns, enabling its use for more complex dialogue formats like agent trajectories.

Previously, the collator assumed a strict alternation of a single instruction template and a response template (e.g., User -> Assistant). This failed for:
1.  Datasets with multiple instruction roles (e.g., user prompts and tool calls).
2.  Sequences with consecutive assistant messages (e.g., Assistant -> Assistant).

This commit addresses these limitations:
- Updates `__init__` to accept a list of strings or pre-tokenized IDs for `instruction_template`, allowing multiple distinct instruction roles.
- Rewrites the core masking logic in `torch_call`:
    - It now identifies all occurrences of the response template and of every specified instruction template.
    - For each assistant response, it unmasks tokens from the end of its template up to the beginning of the *next* instruction template or the sequence end.
    - Correctly handles consecutive assistant turns by masking the template tokens of subsequent responses while unmasking their content.
- Adds comprehensive unit tests (`test_masking_*`) covering multi-role scenarios, consecutive assistant messages, left-padding, and initialization with tokenized templates.

This allows `DataCollatorForCompletionOnlyLM` to process conversational data commonly found in ChatML formats and agent fine-tuning datasets.

Related: huggingface#1994, huggingface#2545
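
Under this change, initialization might look like the following sketch (the list-valued instruction_template follows the commit description; the model name and template strings are illustrative, not taken from the commit):

from transformers import AutoTokenizer
from trl import DataCollatorForCompletionOnlyLM

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # illustrative model

# instruction_template accepts a list, so user prompts and tool outputs
# are both treated as instruction turns and excluded from the loss.
collator = DataCollatorForCompletionOnlyLM(
    response_template="<|im_start|>assistant",
    instruction_template=["<|im_start|>user", "<|im_start|>tool"],
    tokenizer=tokenizer,
)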