
Add masking strategies to message transforms #2284

Merged

Conversation

supreethmanyam
Contributor

@supreethmanyam supreethmanyam commented Jan 20, 2025

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Please link to any issues this PR addresses.
#2261

Changelog

What are the changes made in this PR?

  • Added a mask_messages utility in torchtune/data/_messages.py that masks each message in a list of messages according to a user-provided masking_strategy
  • Fixed built-in message transforms to always mask the system message
  • Added tests for calling message transforms with masking strategies
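The masking behavior described in the changelog could be sketched roughly as follows. This is a hypothetical, simplified stand-in for the actual torchtune utility: the `Message` class here is a minimal mock rather than torchtune's real `Message`, and the exact signature is an assumption.

```python
from typing import List


class Message:
    """Minimal stand-in for torchtune's Message: a role plus a masked flag.
    masked=True means the message is excluded from the training loss."""

    def __init__(self, role: str, masked: bool = False):
        self.role = role
        self.masked = masked


def mask_messages(messages: List[Message], masking_strategy: str) -> None:
    """Set each message's masked flag in place according to the strategy."""
    if masking_strategy not in ("train_on_all", "train_on_assistant", "train_on_last"):
        raise ValueError(f"Unsupported masking strategy: {masking_strategy}")
    # Index of the final assistant turn, used by train_on_last.
    last_assistant_idx = max(
        (i for i, m in enumerate(messages) if m.role == "assistant"), default=-1
    )
    for i, message in enumerate(messages):
        if message.role == "system":
            message.masked = True  # system prompts are always masked
        elif masking_strategy == "train_on_assistant":
            message.masked = message.role != "assistant"
        elif masking_strategy == "train_on_last":
            message.masked = i != last_assistant_idx
        else:  # train_on_all: train on every non-system message
            message.masked = False
```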

Test plan

Please make sure to do each of the following if applicable to your PR. If you're unsure about any of these, just ask and we will happily help. We also have a contributing page for guidance on contributing.

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example
and a tutorial example

  • I did not change any public API
  • I have added an example to docs or docstrings


pytorch-bot bot commented Jan 20, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2284

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 806f437 with merge base d5d12fe:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jan 20, 2025
@supreethmanyam supreethmanyam force-pushed the add-chat-loss-masking-strategies branch from db9a565 to 255bdcd Compare January 20, 2025 23:14
@supreethmanyam supreethmanyam force-pushed the add-chat-loss-masking-strategies branch from 255bdcd to ac6272f Compare January 20, 2025 23:22
@supreethmanyam
Contributor Author

supreethmanyam commented Jan 21, 2025

@RdoubleA Added mask_messages utility in torchtune/data/_messages.py along with relevant unit tests.
I will update the documentation once the changes look good. Please let me know.

@RdoubleA RdoubleA mentioned this pull request Jan 21, 2025
Collaborator

@RdoubleA RdoubleA left a comment

This looks great, really like the tests 👌 left a couple of comments

@@ -176,12 +185,24 @@ class InputOutputToMessages(Transform):

def __init__(
self,
train_on_input: bool = False,
masking_strategy: Optional[str] = "train_on_all",
Collaborator

If the default value for train_on_input was False, this would correspond to train_on_assistant. Let's keep that same default behavior.
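The correspondence between the old flag and the new strategies that this comment relies on amounts to the following mapping (hypothetical helper name, for illustration only):

```python
def masking_strategy_from_train_on_input(train_on_input: bool) -> str:
    # train_on_input=True keeps user/input tokens in the loss, which matches
    # "train_on_all"; train_on_input=False matches "train_on_assistant".
    return "train_on_all" if train_on_input else "train_on_assistant"
```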

Contributor Author

Updated default strategy

@@ -305,11 +332,23 @@ class ChosenRejectedToMessages(Transform):

def __init__(
self,
train_on_input: bool = False,
masking_strategy: Optional[str] = "train_on_all",
Collaborator

same comment here

Contributor Author

Updated default strategy


Raises:
ValueError: If ``column_map`` is provided and ``conversations`` not in ``column_map``.
"""

def __init__(
self,
train_on_input: bool = False,
masking_strategy: Optional[str] = "train_on_all",
Collaborator

same comment here

Contributor Author

Updated default strategy

"""

def __init__(
self, train_on_input: bool = True, column_map: Optional[Dict[str, str]] = None
self,
masking_strategy: Optional[str] = "train_on_all",
Collaborator

contrary to the other comment, this default value should remain train_on_all since train_on_input was defaulted to True. so let's keep this

ValueError: If the masking strategy is not one of the supported strategies:
`train_on_all`, `train_on_assistant`, `train_on_last`.
"""
if masking_strategy not in ["train_on_all", "train_on_assistant", "train_on_last"]:
Collaborator

One suggestion that I didn't mention in the issue: we could make masking_strategy an Enum so that it can be type-checked more easily. Users can still specify it as a string in the config and in the message transforms above, but here in mask_messages you could convert it to the Enum, and the ValueError will be raised automatically.
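The Enum approach suggested here might look like the following sketch. The class and function names are assumptions for illustration, not torchtune's actual definitions:

```python
from enum import Enum


class MaskingStrategy(Enum):
    TRAIN_ON_ALL = "train_on_all"
    TRAIN_ON_ASSISTANT = "train_on_assistant"
    TRAIN_ON_LAST = "train_on_last"


def validate_strategy(masking_strategy: str) -> MaskingStrategy:
    # Enum lookup by value raises ValueError for unsupported strings,
    # so no explicit membership check is needed.
    return MaskingStrategy(masking_strategy)
```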

Contributor Author

@supreethmanyam supreethmanyam Jan 28, 2025

Converted masking_strategy to an Enum; a ValueError is now raised automatically if the value doesn't match a supported strategy.

@supreethmanyam
Contributor Author

@RdoubleA

In addition to converting masking_strategy to an Enum and fixing the default strategy in the built-in transforms, I made the following small changes:

  1. I modified the structure of column_map usage in InputOutputToMessages and AlpacaToMessages to make it consistent across transforms.
  2. I realized there were no unit tests for AlpacaToMessages, so I added a few.

Contributor

@ebsmothers ebsmothers left a comment

Sorry @supreethmanyam, I think this PR fell through the cracks a bit. I have a couple comments but after that I think it looks good. cc @RdoubleA to confirm

@@ -177,31 +193,42 @@ class InputOutputToMessages(Transform):

def __init__(
self,
train_on_input: bool = False,
masking_strategy: Optional[str] = "train_on_assistant",
Contributor

Why do we switch the order of these arguments? If someone is passing InputOutputToMessages(False) this will break them, right? (Since we unfortunately don't enforce keyword-only args on these APIs). Personally I would leave train_on_input in the first position and put masking_strategy last
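The breakage described here can be illustrated with hypothetical signatures (these are simplified stand-ins, not the actual transform constructors):

```python
# With the original signature, a positional caller like
# InputOutputToMessages(False) binds False to train_on_input.
def old_signature(train_on_input=False, column_map=None):
    return {"train_on_input": train_on_input}


# If masking_strategy were inserted in the first position, the same
# positional call would silently bind False to masking_strategy instead,
# changing behavior without any error being raised.
def reordered_signature(masking_strategy="train_on_assistant", train_on_input=False):
    return {"masking_strategy": masking_strategy, "train_on_input": train_on_input}
```

This is why keeping train_on_input first and appending masking_strategy at the end preserves backward compatibility until the old flag is deprecated.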

Collaborator

agreed, let's keep the positions of the arguments the same until we deprecate

Contributor Author

Ah, my bad!

I restored the argument order: train_on_input is back in the first position, as before, and masking_strategy is last. Updated the docstrings accordingly.

Comment on lines +302 to +304
("train_on_all", [[False, False], [False, False]]),
("train_on_assistant", [[True, False], [True, False]]),
("train_on_last", [[True, False], [True, False]]),
Contributor

Ideally we should explicitly test a case where the results of train_on_assistant and train_on_last differ

Contributor Author

I have added test cases using multi-turn samples where the results of train_on_assistant and train_on_last differ, where applicable (ChosenRejectedToMessages, ShareGPTToMessages, OpenAIToMessages).
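A minimal illustration of a multi-turn case where the two strategies diverge (hypothetical expected masks, with masked=True meaning the message is excluded from the loss):

```python
# A two-turn chat: the strategies only differ once there is more than one
# assistant message, which is why single-turn tests can't distinguish them.
roles = ["user", "assistant", "user", "assistant"]


def expected_masks(strategy):
    last = max(i for i, r in enumerate(roles) if r == "assistant")
    if strategy == "train_on_assistant":
        return [r != "assistant" for r in roles]
    if strategy == "train_on_last":
        return [i != last for i in range(len(roles))]
    return [False] * len(roles)  # train_on_all
```

Here train_on_assistant keeps both assistant turns in the loss, while train_on_last keeps only the final one.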

@@ -479,6 +609,200 @@ def test_call_image_messages(self, mock_load_image):
mock_load_image.assert_called_once_with("https://example.com")


class TestAlpacaToMessages:
Contributor

What's the difference between this and the existing TestAlpacaToMessages here? They look fairly similar

Contributor Author

I did not know we had TestAlpacaToMessages in test_alpaca_dataset.py and added one in test_messages.py.

Both are similar, and the one in test_messages.py runs more tests. Please let me know if I should remove TestAlpacaToMessages from here.

return {"messages": messages}


def validate_messages(
Contributor
The reason will be displayed to describe this comment to others. Learn more.

Similar q here: don't we have an existing validate_messages function?

Collaborator

it was just moved down here I believe

Contributor Author

I moved validate_messages down to keep all the message transforms at one place.

):
self.train_on_input = train_on_input
self.column_map = column_map
if train_on_input is not None:
Collaborator

I thought we had a deprecated param utility, but it looks like it never got reviewed: #2321. should make a note there to also replace this code if this PR merges first @ebsmothers

@supreethmanyam
Contributor Author

supreethmanyam commented Mar 12, 2025

@RdoubleA @ebsmothers
I have made all the changes and updated the documentation. Requesting another review.

The current references to train_on_input in the docs come from the dataset modules, so I haven't made any changes there yet. Please let me know if I should create another issue for the dataset module changes; I'd be happy to update the dataset modules to use masking_strategy instead of train_on_input and update the docs accordingly.

@ebsmothers
Contributor

Thanks @supreethmanyam for your patience on this one. Since we have not been great about timely reviews here, I took the liberty of pushing some changes directly to the PR (hope that's alright). Mainly I changed the deprecation logic slightly: when train_on_input and masking_strategy are both passed, we should still use train_on_input but warn the user (otherwise it's possible for silent behavior changes without the user actually migrating). While doing this, I noticed that our SlimOrca test case seems incorrect -- it appears as though we were never masking the system message (see here). cc @RdoubleA to sanity check me here. After this (and green CI) I think this should be good to land. Thanks again for your patience!
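The deprecation behavior described here could be sketched as follows. This is a hypothetical helper, not the actual torchtune code; the name and signature are assumptions:

```python
import warnings


def resolve_masking_strategy(train_on_input=None, masking_strategy="train_on_assistant"):
    # If the deprecated train_on_input flag is explicitly passed, it still
    # wins, but a warning nudges the user to migrate. This avoids silently
    # changing behavior for existing configs that set train_on_input.
    if train_on_input is not None:
        warnings.warn(
            "train_on_input is deprecated; use masking_strategy instead",
            DeprecationWarning,
        )
        return "train_on_all" if train_on_input else "train_on_assistant"
    return masking_strategy
```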

@RdoubleA
Collaborator

That's a tough one... it does look like system is being masked, but if train_on_input is true it will not be masked, same as the user message. So yeah, big oof

Contributor

@ebsmothers ebsmothers left a comment

Really appreciate you adding this @supreethmanyam!

@ebsmothers ebsmothers merged commit 32d195c into pytorch:main Mar 26, 2025
17 checks passed