
[1/n] Merged fine-tuning dataset: grammar + samsum #1234

Merged: 22 commits into pytorch:main on Aug 5, 2024

Conversation

Contributor

@RdoubleA RdoubleA commented Jul 27, 2024

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

As discussed in the RFC in #1186, we will merge the instruct and chat datasets into the following unified pipeline, which can better support multimodal (a rough sketch follows the list):

  1. message_transform to create a List[Message] from the dataset, with full flexibility on columns, ad-hoc modifications, etc. For multimodal, images are additionally loaded from the provided path
  2. prompt_template as an optional way to add structured text around specific roles in the list of messages
  3. model_transform that takes the list of messages and tokenizes them. For multimodal, it will additionally apply model-specific image transforms to the images associated with the sample
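To make the flow concrete, here is a rough sketch of how the three stages compose. The Message class and the tokenization below are simplified stand-ins so the example is self-contained; they are not the actual torchtune implementations in this PR.

```python
# Illustrative sketch of the unified pipeline; Message and the toy tokenization
# are stand-ins, not the classes added in this PR.
from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass
class Message:  # stand-in for torchtune.data.Message
    role: str
    content: str


def message_transform(sample: Dict[str, Any]) -> List[Message]:
    # 1. Build a list of Messages from the raw columns of one dataset row.
    return [
        Message(role="user", content=sample["input"]),
        Message(role="assistant", content=sample["output"]),
    ]


def prompt_template(messages: List[Message]) -> List[Message]:
    # 2. Optionally wrap structured text around specific roles.
    tags = {"user": ("Correct this to standard English: ", "\n---\n")}
    formatted = []
    for m in messages:
        prepend, append = tags.get(m.role, ("", ""))
        formatted.append(Message(role=m.role, content=prepend + m.content + append))
    return formatted


def model_transform(messages: List[Message]) -> List[int]:
    # 3. Tokenize the templated messages (for multimodal, image transforms would also run here).
    return [ord(ch) % 256 for m in messages for ch in m.content]  # toy char-level "tokenizer"


sample = {"input": "i has a apple", "output": "I have an apple."}
tokens = model_transform(prompt_template(message_transform(sample)))
```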

For ease of review, we will stage this as multiple moderate-sized PRs. This PR creates the unified dataset class and refactors grammar and samsum as a starting point. As a result, a few key changes were made:

  • New FinetuneDataset class (not married to the name; it's the best I could think of) implementing the unified pipeline, with an associated unit test
  • Add ToInputOutputMessages (open to better names), a generic message transform that maps the input column -> user message and the output column -> assistant message (see the sketch after this list)
  • Refactor the grammar and samsum datasets. They now create a message transform using ToInputOutputMessages, add a default prompt template that can be changed, and use the new FinetuneDataset class
  • Move Message from _types.py to _messages.py. This file will now contain everything related to Message, including generic message transforms (like the ones in _converters.py, which will eventually migrate here)
  • Make all tokenizers also double as Transforms for use in the model_transform argument
  • New unified PromptTemplate interface that merges the functionality of InstructTemplate and ChatFormat. Refactored the grammar and summarize templates to use a common CustomPromptTemplate class that takes in any template and formats a list of messages accordingly. It also adds the ability to specify a custom template from configs, which we were missing. More will be discussed in an upcoming tutorial update.
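As a rough sketch of the shape of this refactor (using SFTDataset, the name the unified class eventually landed under; the ToInputOutputMessages body and the builder signature here are simplified assumptions, not the exact code):

```python
# Hedged sketch of a generic input/output message transform and a refactored
# builder composing it with the unified dataset class.
from typing import Any, Dict, Mapping, Optional

from torchtune.data import Message
from torchtune.datasets import SFTDataset


class ToInputOutputMessages:
    """Generic transform: "input" column -> user message, "output" column -> assistant message."""

    def __init__(self, column_map: Optional[Dict[str, str]] = None):
        self._map = column_map or {"input": "input", "output": "output"}

    def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
        return {
            "messages": [
                Message(role="user", content=sample[self._map["input"]]),
                Message(role="assistant", content=sample[self._map["output"]]),
            ]
        }


def grammar_dataset(tokenizer, source: str = "liweili/c4_200m", **load_dataset_kwargs) -> SFTDataset:
    # The tokenizer doubles as the model_transform (see the tokenizer bullet above).
    return SFTDataset(
        source=source,
        message_transform=ToInputOutputMessages(),
        model_transform=tokenizer,
        **load_dataset_kwargs,
    )
```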

Test plan

  • update live docs and check rendering
  • unit tests for FinetuneDataset, grammar, samsum
  • unit tests for ToInputOutputMessages
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • compare loss curves of grammar, samsum to original versions on main
    [loss curve comparison screenshots]


pytorch-bot bot commented Jul 27, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1234

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 75db622 with merge base 8519c35:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed label Jul 27, 2024
@RdoubleA RdoubleA changed the title [WIP][1/n] Merged fine-tuning dataset, grammar + samsum [1/n] Merged fine-tuning dataset: grammar + samsum Jul 30, 2024
Contributor

@joecummings joecummings left a comment

Where are my deprecation warnings? :)

@RdoubleA
Contributor Author

Where are my deprecation warnings? :)

will be added in an upcoming PR... there's a lot stacked on this one :)

@codecov-commenter

codecov-commenter commented Aug 1, 2024

Codecov Report

Attention: Patch coverage is 87.04319% with 39 lines in your changes missing coverage. Please review.

Project coverage is 69.60%. Comparing base (bc6b7e9) to head (9822cdd).
Report is 3 commits behind head on main.

Files Patch % Lines
torchtune/models/llama3/_tokenizer.py 50.00% 7 Missing ⚠️
torchtune/models/phi3/_tokenizer.py 50.00% 7 Missing ⚠️
torchtune/models/gemma/_tokenizer.py 50.00% 5 Missing ⚠️
torchtune/models/llama2/_tokenizer.py 50.00% 5 Missing ⚠️
torchtune/models/mistral/_tokenizer.py 50.00% 5 Missing ⚠️
torchtune/models/qwen2/_tokenizer.py 60.00% 2 Missing ⚠️
recipes/full_finetune_distributed.py 0.00% 1 Missing ⚠️
recipes/full_finetune_single_device.py 0.00% 1 Missing ⚠️
recipes/lora_finetune_distributed.py 0.00% 1 Missing ⚠️
recipes/lora_finetune_single_device.py 0.00% 1 Missing ⚠️
... and 4 more
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #1234       +/-   ##
===========================================
+ Coverage   27.41%   69.60%   +42.19%     
===========================================
  Files         233      238        +5     
  Lines       10591    10771      +180     
===========================================
+ Hits         2903     7497     +4594     
+ Misses       7688     3274     -4414     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Contributor

@felipemello1 felipemello1 left a comment

First pass, not thinking too hard about the design. The main thing that caught my attention was trying to understand what exactly FinetuneDataset is. The first line makes it sound like it's a base or general-purpose dataset that other datasets would use, but Chat and Instruct don't, so it is not immediately clear to me when to use one or the other. I hope that makes sense.


class FinetuneDataset(Dataset):
"""
Dataset class for creating instruct, chat, tool, or multimodal datasets for fine-tuning.
Contributor

Do you plan to use it in ChatDataset? At first, when reading this, I was thinking: "Ok, this is a base class or a general purpose dataset that other datasets will generally use", but then I checked ChatDataset and it is not there. Same for InstructDataset.

So, if the answer is no, I would be a bit confused about this description/location/naming.

If the answer is yes, maybe BaseDataset/GeneralDataset could be candidates?

Contributor Author

No. This will replace both Instruct and ChatDataset. We are essentially merging the two, while also adding support for multimodal.

I haven't thought of a better name than FinetuneDataset. Maybe TuneDataset or TokenizedDataset?

pass


class CustomPromptTemplate(PromptTemplate):
Contributor

I'm in favor of renaming this to PromptTemplate and calling the other thing PromptTemplateInterface.

The "Custom" part of prompt template feels redundant. Like, what's the difference between a CustomPromptTemplate and a regular PromptTemplate? Looking at the code, I see it's because one is an interface, but that's not clear from the names. This should be evident before having to go to the docs.

Contributor

We also use the Interface naming for protocols with our recipes so it makes sense.

Contributor Author

that makes sense to me, although if a user needs a bit more custom behavior that's not offered by CustomPromptTemplate, I wouldn't want them to accidentally inherit from PromptTemplate instead of PromptTemplateInterface.
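For illustration, a minimal sketch of the naming proposal being discussed: a bare interface plus a concrete, config-driven template. Both classes (and the Message stand-in) are simplified assumptions, not the actual torchtune definitions.

```python
# Simplified stand-ins illustrating PromptTemplateInterface vs. PromptTemplate.
from dataclasses import dataclass
from typing import Dict, List, Protocol, Tuple


@dataclass
class Message:  # stand-in for torchtune.data.Message
    role: str
    content: str


class PromptTemplateInterface(Protocol):
    """The bare contract: take a list of messages, return a formatted list of messages."""

    def __call__(self, messages: List[Message]) -> List[Message]: ...


class PromptTemplate:
    """Concrete template: wrap role-specific (prepend, append) tags around each message."""

    def __init__(self, template: Dict[str, Tuple[str, str]]):
        self.template = template

    def __call__(self, messages: List[Message]) -> List[Message]:
        formatted = []
        for m in messages:
            prepend, append = self.template.get(m.role, ("", ""))
            formatted.append(Message(role=m.role, content=prepend + m.content + append))
        return formatted


# A custom template is then just an instance of PromptTemplate (or, for fully
# bespoke behavior, a class implementing PromptTemplateInterface directly).
llama2_style = PromptTemplate({"user": ("[INST] ", " [/INST] ")})
```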

from torchtune.modules.transforms import Transform


class FinetuneDataset(Dataset):
Contributor

Is this actually only for SFT? If so, what makes it SFT specific?

And if it is, then should we rename to SFTDataset????

Contributor Author

It is SFT specific in the sense that it isn't Preference or TextCompletion. i.e., it covers the previous instruct, chat, and eventual multimodal datasets, but it does not cover 1) chosen/rejected messages in Preference, or 2) using tokenizer.encode directly instead of tokenizer.tokenize_messages in TextCompletion

I'm cool with SFTDataset tho

Contributor

Let's get some other opinions.

@felipemello1 @SalmanMohammadi @pbontrager @ebsmothers

Collaborator

@SalmanMohammadi SalmanMohammadi Aug 5, 2024

milord? you asked for me?

SupervisedDataset? TaskDataset? SupervisedTaskDataset? STDataset?

otherwise +1 for something slightly more descriptive than FinetuneDataset; SFTDataset is OK too

All datasets are formatted into :class:`~torchtune.data.Message`s because for
fine-tuning, datasets can be considered as "conversations" with the model,
or AI assistant. Thus, we can standardize all text content as messages in a conversation assigned to
a :class:`~torchtune.data.Role`:
Contributor

This isn't rendering correctly.

Contributor Author

oh no

multimodal datasets requires processing the images in a way specific to the vision
encoder being used by the model and is agnostic to the specific dataset.

Tokenization is handled by the ``model_transform``. All :class:`~torchtune.modules.tokenizers.ModelTokenizer`s
Contributor

:class:~torchtune.modules.tokenizers.ModelTokenizer is not rendering correctly :(

Collaborator

because we don't actually generate the doc for that anywhere perhaps?

Tokenization is handled by the ``model_transform``. All :class:`~torchtune.modules.tokenizers.ModelTokenizer`s
can be treated as a ``model_transform`` since it uses the model-specific tokenizer to
transform the list of messages outputted from the ``message_transform`` into tokens
used by the model for training. Text-only datasets will simply pass the :class:`~torchtune.modules.tokenizers.ModelTokenizer`
Contributor

Not rendered.

- Task-specific templates to gear models for a particular task that it will expect after training
- Model-specific templates that are required whenever the model is prompted, such as the [INST]
tags in Llama2 and in Mistral
- Community standardized templates, such as :class:`~torchtune.data.ChatMLTemplate`
Contributor

Not rendered

Args:
source (str): path to dataset repository on Hugging Face. For local datasets,
define source as the data file type (e.g. "json", "csv", "text") and pass
in the filepath in ``data_files``. See Hugging Face's ``load_dataset``
Contributor

Can we make this a hyperlink?

filter_fn (Optional[Callable]): callable used to filter the dataset prior to any pre-processing. See
the Hugging Face `docs <https://huggingface.co/docs/datasets/v2.20.0/process#select-and-filter>`_ for more
details.
**load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``.
Contributor

Hyperlink

self,
*,
source: str,
message_transform: Transform,
Contributor

For some reason, Transform is not being picked up. Is it included in the docs somewhere?

I'm guessing no...

Contributor

@joecummings joecummings left a comment

I am the King of Nits.

+ [10, 1, 6, -1]
]
ds = SFTDataset(
source="iam/agoofy/goober",
Collaborator

nice

Contributor

@joecummings joecummings left a comment

This looks world-class - docs are really clear and easy to understand. I'm excited for this to land 💯

@RdoubleA RdoubleA merged commit 167bb01 into pytorch:main Aug 5, 2024
29 checks passed
@RdoubleA RdoubleA deleted the merged_dataset_1 branch August 5, 2024 23:48
Contributor

@ebsmothers ebsmothers left a comment

Please don't hate me for only getting around to reviewing this after you landed it. Anyways it looks great

@@ -130,113 +130,6 @@ def format(
return prompt


class GrammarErrorCorrectionTemplate(InstructTemplate):
Contributor

😍

@@ -8,6 +8,7 @@
from torchtune.datasets._chat import chat_dataset, ChatDataset
from torchtune.datasets._cnn_dailymail import cnn_dailymail_articles_dataset
from torchtune.datasets._concat import ConcatDataset
from torchtune.datasets._finetune import SFTDataset
Contributor

Sorry to be that guy but imo it's a bit confusing that our canonical dataset does not match the pattern of all our other datasets (class name is the capitalized version of the filename). I'm good with the name SFTDataset, but then maybe we should rename the file too?

Contributor Author

yeah I missed that. will include that change in the subsequent PRs

mask = truncate(mask, max_seq_len, True)
if self.max_seq_len:
tokens = truncate(tokens, self.max_seq_len, self.eos_id)
mask = truncate(mask, self.max_seq_len, True)
Contributor

This one doesn't need to add __call__/ Transform?

Contributor Author

oh good call out
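For reference, a simplified sketch of the pattern in question: giving a tokenizer a `__call__` that delegates to `tokenize_messages`, so it can be passed directly as a `model_transform`. The class, the char-level "tokenization", and the masking rule below are toy stand-ins, not the actual torchtune code.

```python
# Toy illustration of a tokenizer doubling as a Transform via __call__.
from typing import Any, List, Mapping, Tuple


class ToyTokenizer:
    """Toy tokenizer showing the pattern: __call__ delegates to tokenize_messages."""

    eos_id = 0

    def tokenize_messages(self, messages: List[Any]) -> Tuple[List[int], List[bool]]:
        tokens, mask = [], []
        for m in messages:
            for ch in str(m.content):
                tokens.append(ord(ch) % 256)        # toy char-level "tokenization"
                mask.append(m.role != "assistant")  # e.g. mask out non-assistant tokens
        tokens.append(self.eos_id)
        mask.append(True)
        return tokens, mask

    def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
        # This is what lets the tokenizer be passed directly as ``model_transform``:
        # it consumes the "messages" produced by the message_transform and returns
        # the tokenized fields used for training.
        tokens, mask = self.tokenize_messages(sample["messages"])
        return {**sample, "tokens": tokens, "mask": mask}
```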

tokenizer (ModelTokenizer): Tokenizer used by the model that implements the ``tokenize_messages`` method.
source (str): path string of dataset, anything supported by Hugging Face's `load_dataset`.
model_transform (Transform): model specific transform to convert a list of messages
output by the dataset to tokens. This will always be a :class:`~torchtune.modules.tokenizers.ModelTokenizer`.
Contributor

This will always be a :class:~torchtune.modules.tokenizers.ModelTokenizer

Just curious, why do we not type it as such then?

Contributor Author

mainly to be consistent with SFTDataset and the upcoming multimodal dataset builders. I'd be open to still calling it model_transform but typing it ModelTokenizer, but I don't know if we'd want to go all the way and just call this tokenizer

Default is False.
column_map (Optional[Dict[str, str]]): a mapping to change the expected "input"
and "output" column names to the actual column names in the dataset. Default is None,
keeping the default "input" and "output" column names.
Contributor

Should give an example in the docstring
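Such an example could look roughly like this, assuming the builder keeps the `tokenizer` and `column_map` arguments shown in the docstring above and forwards extras like `data_files` to `load_dataset`; the paths and column names are hypothetical.

```python
# Hypothetical column_map usage: a local CSV whose columns are "question"/"answer"
# instead of the expected "input"/"output". Keys are the expected names, values
# are the actual column names in the dataset.
from torchtune.datasets import grammar_dataset
from torchtune.models.llama3 import llama3_tokenizer

tokenizer = llama3_tokenizer(path="/path/to/tokenizer.model")  # placeholder path

ds = grammar_dataset(
    tokenizer=tokenizer,
    source="csv",
    data_files="/path/to/my_data.csv",  # forwarded to load_dataset
    column_map={"input": "question", "output": "answer"},
)
```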
