[6/7] SFTDataset: revamp instruct/chat #1286
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1286
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 1ea12f4 with merge base 3e29e6b.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
torchtune/utils/logging.py (Outdated)

@@ -25,3 +29,34 @@ def get_logger(level: Optional[str] = None) -> logging.Logger:
     level = getattr(logging, level.upper())
     logger.setLevel(level)
     return logger

+def deprecated(msg: str = "") -> Callable[[T], T]:
To avoid blowing up people's logs, is it possible to make sure that:
- This only logs once (not every time it is hit)
- This only logs on rank0
Nice-to-have: include an explicit version number when we will remove support for the API
Also this is kinda orthogonal, but I don't like that our utils directory == trainer utils. Because it takes deps on pretty much every other directory, which means we may run into circular import issues if we try to use this in our lower-level components. I would think about whether there's a place further upstream in our dependency graph we can put this (if only we didn't already commandeer the utils folder name with our furthest downstream components..)
- the lru cache ensures it's logged only once per class, also verified it in the unit test
- good point, will add this
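A minimal sketch of how both points can be handled, using `functools.lru_cache` so each unique message fires only once per process. The helper name and rank-zero note are assumptions, not the PR's actual code; the warning format matches the message asserted in the unit test further down:

```python
import functools
import warnings
from typing import Any, Callable, TypeVar

T = TypeVar("T", bound=Callable[..., Any])


@functools.lru_cache(maxsize=None)
def _warn_once(msg: str) -> None:
    # lru_cache keys on msg, so each unique deprecation message is emitted
    # only once per process; a rank-zero guard (e.g. checking the
    # torch.distributed rank) would wrap this call to keep multi-GPU logs quiet.
    warnings.warn(msg, category=FutureWarning, stacklevel=3)


def deprecated(msg: str = "") -> Callable[[T], T]:
    """Mark a class or function as deprecated, warning once on first use."""

    def decorator(obj: T) -> T:
        warning = (
            f"{obj.__name__} is deprecated and will be removed "
            f"in future versions. {msg}"
        )

        @functools.wraps(obj)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            _warn_once(warning)
            return obj(*args, **kwargs)

        return wrapper  # type: ignore[return-value]

    return decorator
```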
> there's a place further upstream in our dependency graph we can put this

not entirely sure what this would practically mean, do you mean a different file?
Probably a separate directory. Because otherwise anywhere we add `from utils.file_name import SomeClass`, it will go through `utils/__init__.py`, right? And that will import all our APIs with all their upstream dependencies in data, datasets, modules, models. So we'll get stuck with a circular dependency unless we move this out of utils entirely.
@@ -37,6 +39,7 @@ def format(
     pass

+@deprecated()
Why no message for this one?
re our discussion on the alpaca PR, the template will just be absorbed into the message transform. this class will be removed anyway instead of deprecated
torchtune/datasets/_instruct.py (Outdated)

+@deprecated(msg="Please use `torchtune.datasets.SFTDataset` for custom chat data.")
I think @ebsmothers already raised this, but worth pointing the user to a concrete example of a replacement?
If not, this says "chat" instead of instruct btw
Can you update the PR summary? I think it's not actually clearly emphasizing the latest set of changes you've made
@@ -36,7 +36,7 @@ def test_label_no_masking(self, load_dataset, tokenizer):
     ]
 )

-grammar_ds = grammar_dataset(model_transform=tokenizer, train_on_input=True)
+grammar_ds = grammar_dataset(tokenizer=tokenizer, train_on_input=True)
didn't we just change this lol. Not opposed to changing it back, just curious why we changed our mind here
I'm honestly on the fence here, but I discussed with @pbontrager and we agreed to take a look at these holistically at the end and make a call once we work on multimodal recipes. For now I've been following: text datasets = `tokenizer`, `SFTDataset` and multimodal = `model_transform`.
+from torchtune.data._utils import deprecated

+def test_deprecated():
nit: put in a test class?
@@ -73,3 +77,34 @@ def validate_messages(
         f"System message at index {i} in messages, but system messages must come first"
     )
     last_turn = message.role

+def deprecated(msg: str = "") -> Callable[[T], T]:
Sorry, I will never stop harping on this, but this time I come with an actual proposal. How about a `utils/_internal` directory with its own `__init__.py` file? Then we just do `from torchtune.utils._internal import deprecated`; pretty sure any such usage will not trigger the imports in the parent directory's `__init__.py`. Admittedly a little bit confusing, but better than what we have now and I think it should work. Please let me know if I am fundamentally misunderstanding how Python imports work (always a possibility)
I can try this
unfortunately, any import of `torchtune.utils` requires hitting the `__init__.py` file, so this does not work
ugh should've known it was too good to be true..
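For reference, this is standard Python import behavior: importing a subpackage always executes every parent package's `__init__.py` first. A minimal sketch with a hypothetical package layout (not torchtune's actual code) that demonstrates it:

```python
# Hypothetical layout:
#   pkg/__init__.py            contains: print("pkg/__init__.py runs")
#   pkg/_internal/__init__.py  contains: print("pkg/_internal/__init__.py runs")

# Even this narrowly targeted import executes the parent package first:
import pkg._internal
# Output:
#   pkg/__init__.py runs
#   pkg/_internal/__init__.py runs
```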
with pytest.warns(
    FutureWarning,
    match="DummyClass is deprecated and will be removed in future versions. Please use `TotallyAwesomeClass` instead.",
this is a nit but it's a bit awkward to me that we split the log message across the utility and the arg we pass. Like, passing msg="Please use `TotallyAwesomeClass` instead" assumes that you know exactly what the first half of the message is, which is somewhat annoying to go check every time
Yeah, I could just pass in the full message every time but that gets tedious. I suppose I could make it a required argument and do that for simplicity.
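To make the trade-off concrete, here is a rough sketch of that alternative, where the caller supplies the entire warning text as a required argument (hypothetical, not what the PR implements; `_warn_once` is the once-per-message helper from the sketch earlier in this thread):

```python
import functools
from typing import Any, Callable, TypeVar

T = TypeVar("T", bound=Callable[..., Any])


def deprecated(msg: str) -> Callable[[T], T]:
    """Alternative API: msg is required and is the complete warning text,
    so nothing is split between the utility and the call site."""

    def decorator(obj: T) -> T:
        @functools.wraps(obj)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            _warn_once(msg)  # msg already contains the full message
            return obj(*args, **kwargs)

        return wrapper  # type: ignore[return-value]

    return decorator
```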
message_transform = InputOutputToMessages(
    train_on_input=train_on_input, column_map=column_map
)
Did we discuss renaming this or did I imagine that? Also I'm a bit out of the loop on #1366, but why do we use a single static system prompt in `InputOutputToMessages` now? Is that sufficient? (Sorry if you guys already hashed this all out)
Don't recall that discussion but @joecummings was interested in potentially renaming this.
As for the system prompt, yes a single prompt is sufficient to add a system message for every sample conversation.
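As a rough illustration of why a single string is enough (a simplified stand-in for the real transform; the `new_system_prompt` name, dict-based messages, and column names are assumptions here):

```python
from typing import Any, Mapping, Optional


class InputOutputToMessagesSketch:
    """Simplified stand-in: maps an input/output sample to a list of messages,
    optionally prepending one static system prompt to every sample."""

    def __init__(self, new_system_prompt: Optional[str] = None):
        self.new_system_prompt = new_system_prompt

    def __call__(self, sample: Mapping[str, Any]) -> dict:
        messages = [
            {"role": "user", "content": sample["input"]},
            {"role": "assistant", "content": sample["output"]},
        ]
        if self.new_system_prompt is not None:
            # The same prompt is inserted at the front of every sample's
            # conversation, which is why one static string is sufficient.
            messages.insert(0, {"role": "system", "content": self.new_system_prompt})
        return {"messages": messages}
```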
Looks great @RdoubleA. Thanks for this. Only 14.285714285% of the work left!
Context

Repurposes the old `instruct_dataset` and `chat_dataset` builders as config-friendly builders of `SFTDataset` (a rough sketch of the builder shape follows the list below):

- `instruct_dataset`: `SFTDataset` with `InputOutputToMessages` as the default message transform, since most instruct datasets follow this format
- `chat_dataset`: `SFTDataset` with `ShareGPTToMessages` or `JSONToMessages`, depending on the selected conversation style, as the default message transform, since most chat datasets follow this format
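The exact signature below is an assumption for illustration; the PR's real builders may accept additional arguments (e.g. for packing or filtering):

```python
from torchtune.data import InputOutputToMessages
from torchtune.datasets import SFTDataset


def instruct_dataset(
    tokenizer,
    *,
    source: str,
    train_on_input: bool = False,
    column_map=None,
    **load_dataset_kwargs,
):
    # Default message transform for instruct-style data, as described above.
    message_transform = InputOutputToMessages(
        train_on_input=train_on_input, column_map=column_map
    )
    return SFTDataset(
        source=source,
        message_transform=message_transform,
        model_transform=tokenizer,
        **load_dataset_kwargs,
    )
```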
Also adds a new decorator function `deprecated`, the grim reaper coming to announce the expiration of your favorite classes. The following are currently on the chopping block:

- `ChatDataset` (replaced by `SFTDataset`)
- `InstructDataset` (replaced by `SFTDataset`)
- `get_openai_messages` (replaced by `JSONToMessages`)
- `get_sharegpt_messages` (replaced by `ShareGPTToMessages`)

Rest In Power to the following:

- `Llama2ChatFormat` (replaced by `Llama2ChatTemplate`)
- `MistralChatFormat` (replaced by `MistralChatTemplate`)
- `ChatMLFormat` (replaced by `ChatMLTemplate`)

`AlpacaInstructTemplate` will be removed by #1284 and `StackExchangedPairedTemplate` by #1276.

Other changes:
- Test `ASSETS` consolidated in one location and refactored

Test plan
- Updated tests for `instruct_dataset` and `chat_dataset`; these now load in a tiny JSON dataset with the expected format
- Removed tests for chat formats
- Unit test for `deprecated`; you can also see how the logs look when you run the tests locally: