
[6/7] SFTDataset: revamp instruct/chat #1286

Merged · 13 commits · Aug 26, 2024

Conversation

Contributor

@RdoubleA commented Aug 7, 2024

Context

Repurposes the old instruct_dataset and chat_dataset builders as config-friendly builders for SFTDataset.

  • instruct_dataset
    • creates SFTDataset with InputOutputToMessages as the default message transform, since most instruct datasets follow this format
  • chat_dataset
    • creates SFTDataset with ShareGPTToMessages or JSONToMessages (depending on the selected conversation style) as the default message transform, since most chat datasets follow this format
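With these builders, a dataset can be swapped in straight from config. A hypothetical YAML snippet follows; the `_component_` key is torchtune's config convention, but the exact argument names shown here are illustrative, not guaranteed:

```yaml
dataset:
  _component_: torchtune.datasets.instruct_dataset
  source: json
  data_files: path/to/my_data.json
  train_on_input: False
```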

Also adds a new decorator function, deprecated, the grim reaper coming to announce the expiration of your favorite classes. The following are currently on the chopping block:

  • ChatDataset (replaced by SFTDataset)
  • InstructDataset (replaced by SFTDataset)
  • get_openai_messages (replaced by JSONToMessages)
  • get_sharegpt_messages (replaced by ShareGPTToMessages)
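The warning wiring can be sketched roughly as below. This is a minimal illustration, not torchtune's exact implementation; it reuses the DummyClass/TotallyAwesomeClass names from the unit test and the once-per-class `functools.lru_cache` trick mentioned in the review discussion:

```python
import functools
import warnings


def deprecated(msg: str = ""):
    """Sketch of a deprecation decorator: emits a FutureWarning once per decorated object."""

    # lru_cache ensures warn_once runs at most once per distinct name.
    @functools.lru_cache(maxsize=None)
    def warn_once(name: str) -> None:
        warnings.warn(
            f"{name} is deprecated and will be removed in future versions. {msg}",
            category=FutureWarning,
            stacklevel=3,
        )

    def decorator(obj):
        @functools.wraps(obj)
        def wrapper(*args, **kwargs):
            warn_once(obj.__name__)
            return obj(*args, **kwargs)

        return wrapper

    return decorator


@deprecated(msg="Please use `TotallyAwesomeClass` instead.")
class DummyClass:
    pass
```

Instantiating DummyClass twice emits the FutureWarning only on the first construction.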

Rest In Power to the following:

  • Llama2ChatFormat (replaced by Llama2ChatTemplate)
  • MistralChatFormat (replaced by MistralChatTemplate)
  • ChatMLFormat (replaced by ChatMLTemplate)

AlpacaInstructTemplate will be removed by #1284 and StackExchangedPairedTemplate removed by #1276.

Other changes:

  • Centralize ASSETS in one location and refactor

Test plan

Updated tests for instruct_dataset and chat_dataset; these now load a tiny JSON dataset with the expected format.

Removed tests for chat formats.

Added a unit test for deprecated. You can also see how the logs look when you run the tests locally:

tests/torchtune/data/test_instruct_templates.py:43
  /data/users/rafiayub/torchtune-rafiayub/tests/torchtune/data/test_instruct_templates.py:43: FutureWarning: AlpacaInstructTemplate is deprecated and will be removed in future versions. 
    template = AlpacaInstructTemplate()

tests/torchtune/data/test_converters.py::TestShareGPTToLlama2Messages::test_conversion
  /data/users/rafiayub/torchtune-rafiayub/tests/torchtune/data/test_converters.py:35: FutureWarning: get_sharegpt_messages is deprecated and will be removed in future versions. Please use `torchtune.data.ShareGPTToMessages` with `torchtune.datasets.SFTDataset` instead.
    converted_messages = get_sharegpt_messages(self.samples)

tests/torchtune/data/test_converters.py::TestOpenAIToLlama2Messages::test_conversion_conversations_key
  /data/users/rafiayub/torchtune-rafiayub/tests/torchtune/data/test_converters.py:81: FutureWarning: get_openai_messages is deprecated and will be removed in future versions. Please use `torchtune.data.JSONToMessages` with `torchtune.datasets.SFTDataset` instead.
    converted_messages_1 = get_openai_messages(self.samples_1)


pytorch-bot bot commented Aug 7, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1286

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 1ea12f4 with merge base 3e29e6b:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Aug 7, 2024. (This label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed.)
@@ -25,3 +29,34 @@ def get_logger(level: Optional[str] = None) -> logging.Logger:
level = getattr(logging, level.upper())
logger.setLevel(level)
return logger


def deprecated(msg: str = "") -> Callable[[T], T]:
Contributor

To avoid blowing up people's logs, is it possible to make sure that:

  1. This only logs once (not every time it is hit)
  2. This only logs on rank0

Contributor

Nice-to-have: include an explicit version number when we will remove support for the API

Contributor

Also this is kinda orthogonal, but I don't like that our utils directory == trainer utils. Because it takes deps on pretty much every other directory, which means we may run into circular import issues if we try to use this in our lower-level components. I would think about whether there's a place further upstream in our dependency graph we can put this (if only we didn't already commandeer the utils folder name with our furthest downstream components..)

Contributor Author

  1. The lru cache ensures it's logged only once per class; I also verified this in the unit test
  2. Good point, will add this

Contributor Author

> there's a place further upstream in our dependency graph we can put this

Not entirely sure what this would mean in practice; do you mean a different file?

Contributor

Probably a separate directory. Because otherwise anywhere we add from utils.file_name import SomeClass will go through utils/__init__.py, right? And that will import all our APIs with all their upstream dependencies in data, datasets, modules, models. So we'll get stuck with a circular dependency unless we move this out of utils entirely


@@ -37,6 +39,7 @@ def format(
pass


@deprecated()
Contributor

Why no message for this one?

Contributor Author

Re: our discussion on the Alpaca PR, the template will just be absorbed into the message transform. This class will be removed outright instead of deprecated.

torchtune/data/_converters.py (outdated; resolved)


@deprecated(msg="Please use `torchtune.datasets.SFTDataset` for custom chat data.")
Collaborator

I think @ebsmothers already raised this, but worth pointing the user to a concrete example of a replacement?
If not, this says "chat" instead of instruct btw

@RdoubleA RdoubleA changed the title [6/7] SFTDataset: deprecate instruct/chat [6/7] SFTDataset: revamp instruct/chat Aug 20, 2024
Contributor

@ebsmothers left a comment

Can you update the PR summary? I think it's not actually clearly emphasizing the latest set of changes you've made

torchtune/utils/logging.py (outdated; resolved)
torchtune/datasets/_stack_exchange_paired.py (outdated; resolved)
@@ -36,7 +36,7 @@ def test_label_no_masking(self, load_dataset, tokenizer):
]
)

- grammar_ds = grammar_dataset(model_transform=tokenizer, train_on_input=True)
+ grammar_ds = grammar_dataset(tokenizer=tokenizer, train_on_input=True)
Contributor

didn't we just change this lol. Not opposed to changing it back, just curious why we changed our mind here

Contributor Author

I'm honestly on the fence here, but discussed with @pbontrager and agreed that we'll take a look at these holistically at the end and make a call once we work on multimodal recipes. For now I've been following text datasets = tokenizer, SFTDataset and multimodal = model transform

from torchtune.data._utils import deprecated


def test_deprecated():
Contributor

nit: put in a test class?

@@ -73,3 +77,34 @@ def validate_messages(
f"System message at index {i} in messages, but system messages must come first"
)
last_turn = message.role


def deprecated(msg: str = "") -> Callable[[T], T]:
Contributor

Sorry I will never stop harping on this, but this time I come with an actual proposal.

How about a utils/_internal directory with its own __init__.py file. Then we just do from torchtune.utils._internal import deprecated, pretty sure any such usage will not trigger the imports in the parent directory's __init__.py. Admittedly a little bit confusing but better than what we have now and I think it should work. Please let me know if I am fundamentally misunderstanding how Python imports work (always a possibility)

Contributor Author

I can try this

Contributor Author

Unfortunately, any import from torchtune.utils requires hitting the __init__.py file, so this does not work.
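The reason is a basic property of Python's import system: importing a subpackage always executes the parent package's `__init__.py` first. A throwaway demonstration (all package names here are made up for illustration; `fakepkg` stands in for `torchtune/utils`):

```python
import importlib
import os
import sys
import tempfile

# Build a throwaway package on disk:
#   fakepkg/__init__.py            <- stands in for torchtune/utils/__init__.py
#   fakepkg/_internal/__init__.py  <- where `deprecated` would live
tmpdir = tempfile.mkdtemp()
internal = os.path.join(tmpdir, "fakepkg", "_internal")
os.makedirs(internal)
with open(os.path.join(tmpdir, "fakepkg", "__init__.py"), "w") as f:
    f.write("PARENT_INIT_RAN = True\n")  # heavy downstream imports would live here
with open(os.path.join(internal, "__init__.py"), "w") as f:
    f.write("def deprecated(msg=''):\n    return lambda obj: obj\n")

sys.path.insert(0, tmpdir)
importlib.import_module("fakepkg._internal")

# Importing the subpackage imported the parent first, running its __init__.py:
parent_ran = getattr(sys.modules["fakepkg"], "PARENT_INIT_RAN", False)
```

So `from torchtune.utils._internal import deprecated` would still pull in everything `torchtune/utils/__init__.py` imports, which is why the `_internal` proposal cannot break the cycle.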

Contributor

ugh should've known it was too good to be true..

torchtune/datasets/_grammar.py (resolved)

with pytest.warns(
FutureWarning,
match="DummyClass is deprecated and will be removed in future versions. Please use `TotallyAwesomeClass` instead.",
Contributor

This is a nit, but it's a bit awkward to me that we split the log message across the utility and the arg we pass. Passing msg="Please use TotallyAwesomeClass instead" assumes that you know exactly what the first half of the message is, which is somewhat annoying to go check every time.

Contributor Author

Yeah, I could just pass in the full message every time, but that gets tedious. I suppose I could make it a required argument and do that for simplicity.

Comment on lines +237 to +239
message_transform = InputOutputToMessages(
train_on_input=train_on_input, column_map=column_map
)
Contributor

Did we discuss renaming this or did I imagine that? Also I'm a bit out of the loop on #1366 but why do we use a single static system prompt in InputOutputToMessages now? Is that sufficient? (Sorry if you guys already hashed this all out)

Contributor Author

Don't recall that discussion but @joecummings was interested in potentially renaming this.

As for the system prompt: yes, a single prompt is sufficient to add a system message to every sample conversation.
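The idea can be sketched as below. This is a hedged illustration of the transform's role, not the actual torchtune.data.InputOutputToMessages API; field names like "masked" and the function signature are assumptions for the sake of the example:

```python
from typing import Any, Dict, List, Optional


def input_output_to_messages(
    sample: Dict[str, Any],
    train_on_input: bool = False,
    new_system_prompt: Optional[str] = None,
) -> List[Dict[str, Any]]:
    """Map an {"input": ..., "output": ...} sample to a list of chat messages."""
    messages = []
    if new_system_prompt is not None:
        # The same static prompt is prepended to every sample conversation.
        messages.append({"role": "system", "content": new_system_prompt, "masked": True})
    messages.append(
        # The user turn is masked from the loss unless train_on_input is set.
        {"role": "user", "content": sample["input"], "masked": not train_on_input}
    )
    messages.append({"role": "assistant", "content": sample["output"], "masked": False})
    return messages


msgs = input_output_to_messages(
    {"input": "Fix my grammar: he go home", "output": "He goes home."},
    new_system_prompt="You are a helpful grammar assistant.",
)
```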

torchtune/data/_prompt_templates.py (resolved)
tests/torchtune/datasets/test_instruct_dataset.py (outdated; resolved)
Collaborator

@SalmanMohammadi left a comment

Looks great @RdoubleA. Thanks for this. Only 14.285714285% of the work left!

@RdoubleA RdoubleA merged commit 7e084d9 into pytorch:main Aug 26, 2024
20 checks passed
@RdoubleA RdoubleA deleted the deprecate_instruct_chat branch August 26, 2024 19:20