Feat: Add support of multiple datasets in config #889
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/889. Note: links to docs will display an error until the docs builds have been completed.

✅ No failures as of commit f7a3f95 with merge base aa65012. This comment was automatically generated by Dr. CI and updates every 15 minutes.
Hi @EvilFreelancer! Thank you for your pull request and welcome to our community.

Action Required: In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process: In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA. Once the CLA is signed, our tooling will perform checks and validations, and the pull request will be tagged accordingly.

If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!
Hi! I've signed the CLA; how do I rerun the failed check?
Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!
Thanks for this awesome PR @EvilFreelancer! Handling multiple datasets has been a north star for us, as most mature data pipelines for fine-tuning models typically incorporate multiple data sources, so I appreciate you adding this.
This is my main concern: the concatenated datasets must have the same columns AND the same instruct template / chat format. Also, all the keyword arguments (max_seq_len, train_on_input) will need to be shared. This is quite restrictive and will require users to do a lot of offline preprocessing work. Ideally, using multiple datasets should be flexible enough that each dataset can have different columns and a different template, yet we're able to coalesce them together; in the end, all the data gets tokenized into the same format anyway. What are your thoughts on letting a multi-dataset wrapper take in a list of already-instantiated datasets?
Then, in the config, you could specify multiple datasets like this (or something similar):
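For example (a sketch, mirroring the list-of-datasets format the PR ultimately adopted later in this thread):

```yaml
dataset:
  - _component_: torchtune.datasets.instruct_dataset
    source: tatsu-lab/alpaca
    template: AlpacaInstructTemplate
    split: train
  - _component_: torchtune.datasets.chat_dataset
    source: Open-Orca/SlimOrca-Dedup
    conversation_style: sharegpt
    chat_format: Llama2ChatFormat
    split: train
```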
And the recipe can instantiate a MultiDataset if the dataset param is a list. This way, each dataset can keep its individual parameters. Ideally, we can make use of existing PyTorch dataset utilities here.
Hi @RdoubleA, thank you for your response! I have a couple of spare days, so I may start working on implementing this solution, as it is critical for the project I am working on to have the ability to train on various combinations of datasets.
Hi! I've implemented the logic of MultiDataset and enabled it in the training recipes. I've also removed the original multi-source logic and cleaned up the tests. How to use:

```yaml
dataset:
  - _component_: torchtune.datasets.instruct_dataset
    source: tatsu-lab/alpaca
    template: AlpacaInstructTemplate
    split: train
    train_on_input: True
  - _component_: torchtune.datasets.instruct_dataset
    source: vicgalle/alpaca-gpt4
    template: AlpacaInstructTemplate
    split: train
    train_on_input: True
  - _component_: torchtune.datasets.chat_dataset
    source: Open-Orca/SlimOrca-Dedup
    conversation_style: sharegpt
    chat_format: Llama2ChatFormat
    max_seq_len: 1024
    split: train
  - _component_: torchtune.datasets.chat_dataset
    source: ajibawa-2023/Code-290k-ShareGPT
    conversation_style: sharegpt
    chat_format: Llama2ChatFormat
    max_seq_len: 1024
    split: train
seed: null
shuffle: True
```

By the way, this gave an interesting side effect: using this logic, you can freely mix datasets of any format. In the example that I gave, instruct and chat datasets are mixed together.
I like the simplicity of the MultiDataset class you implemented, very easy to understand! There are some considerations around how we instantiate each individual dataset that we should figure out.
recipes/full_finetune_distributed.py (Outdated)
```python
    tokenizer=self._tokenizer,
)

if isinstance(cfg_dataset.get(0), DictConfig):
```
You can also just directly check whether it's a ListConfig; if it's a single dataset, then this might fail:
```diff
- if isinstance(cfg_dataset.get(0), DictConfig):
+ if isinstance(cfg_dataset, ListConfig):
```
I also wonder if there's a better way to handle this so we don't have to repeat this if-else check across all recipes...
Good point, I've replaced it with the ListConfig check.
> I also wonder if there's a better way to handle this so we don't have to repeat this if-else check across all recipes...
Yeah, this logic can be moved (for example) into the config.instantiate method, but I guess that would break the single responsibility principle, so I suggest leaving it as is.
Do we still need to make these changes to the test samples to use Dataset across all the individual dataset test files?
I think we need to keep the datasets.Dataset objects, since this is a common format and anyone can see that the data should come in this format; plus, there is no need to perform any extra transformations, as was the case, for example, in the _chat tests.
@RdoubleA hi! Thanks for your review; the requested fixes have been added. By the way, I noticed one small thing, but (UPD) I figured out what was causing it.
This is turning into something really nice! Thanks for your patience working through this. Could you add a unit test for MultiDataset on some toy data?

Also tagging @ebsmothers, @kartikayk to get their thoughts on how we should handle instantiating this in the recipes. I think we could just make every dataset a MultiDataset, since you could just pass a list with a single dataset and it should work the same (correct me if I'm wrong here). That should simplify the logic in the recipes.
Co-authored-by: Rafi Ayub <33648637+RdoubleA@users.noreply.github.com>
@RdoubleA hi! I've added the fixes you mentioned, and a couple of simple tests for MultiDataset on toy data. What do you think about:

```python
if not isinstance(cfg_dataset, ListConfig):
    cfg_dataset = [cfg_dataset]
datasets = [
    config.instantiate(cfg_item, tokenizer=self._tokenizer)
    for cfg_item in cfg_dataset
]
ds = utils.MultiDataset(datasets=datasets)
```

instead of:

```python
if isinstance(cfg_dataset, ListConfig):
    datasets = [
        config.instantiate(single_cfg_dataset, tokenizer=self._tokenizer)
        for single_cfg_dataset in cfg_dataset
    ]
    ds = utils.MultiDataset(datasets=datasets)
else:
    ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer)
```
@EvilFreelancer Thanks for the updates! I'll take another pass soon. In the meantime, since this is a significant feature we are adding, I'd like to make sure this is rigorously tested. Do you mind confirming the following:

- unit tests pass for the new dataset class
- a distributed recipe runs end to end with multiple datasets on a multi-GPU machine
- a single-device recipe runs end to end with multiple datasets
Also, I forgot to mention this earlier, but I think a better location for MultiDataset would be under torchtune/datasets rather than torchtune/utils.
@EvilFreelancer Can you update the README with examples of how this will now look in the YAML config?
torchtune/utils/multi_dataset.py (Outdated)
```python
# Calculate distribution of indexes in all datasets
cumulative_index = 0
for idx, dataset in enumerate(datasets):
    next_cumulative_index = cumulative_index + len(dataset)
```
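For context, a minimal sketch of the bookkeeping this hunk is building, assuming the class stores a half-open index range per wrapped dataset (names here are illustrative, not the PR's exact code):

```python
from typing import List, Tuple

from torch.utils.data import Dataset


class ConcatSketch(Dataset):
    """Sketch: concatenate datasets and route a global index to the right one."""

    def __init__(self, datasets: List[Dataset]):
        self._datasets = datasets
        # Precompute the half-open index range [start, stop) owned by each dataset.
        self._ranges: List[Tuple[int, int]] = []
        cumulative_index = 0
        for dataset in datasets:
            next_cumulative_index = cumulative_index + len(dataset)
            self._ranges.append((cumulative_index, next_cumulative_index))
            cumulative_index = next_cumulative_index
        self._total_len = cumulative_index

    def __len__(self) -> int:
        return self._total_len

    def __getitem__(self, index: int):
        # Walk the ranges and translate the global index into a local one.
        for dataset, (start, stop) in zip(self._datasets, self._ranges):
            if start <= index < stop:
                return dataset[index - start]
        raise IndexError(f"Index {index} out of range for dataset of length {self._total_len}")
```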
@RdoubleA Hi! I've made some fixes to the PR.

The reason I believed that the utils package was the right location no longer applies: following your suggestion, the class now lives under torchtune/datasets.

The description of the PR was updated, and a note on how to run unit tests was added:

```
(venv) [pasha-pc] ~/Documents/Repository/nn-nlp/torchtune $ pytest ./tests/torchtune/datasets/test_multi_dataset.py
=========================== test session starts ===========================
platform linux -- Python 3.11.2, pytest-8.1.2, pluggy-1.5.0
rootdir: /home/pasha/Documents/Repository/nn-nlp/torchtune
configfile: pyproject.toml
plugins: integration-0.2.3, mock-3.14.0, cov-5.0.0
collected 3 items

tests/torchtune/datasets/test_multi_dataset.py ...
```
Unfortunately, my local server is equipped with only one GPU. As a result, I've written up instructions and enlisted the help of one of my subscribers who owns a multi-GPU server. A detailed update will be provided tomorrow. For now, I can only confirm that the training has started on the multi-GPU server and that the number of training steps matches what I see on my side.
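For reference, a multi-GPU run with the torchtune CLI is launched roughly like this (the recipe and config names here are illustrative, not necessarily the ones used in these test runs):

```bash
tune run --nproc_per_node 2 full_finetune_distributed --config llama2/7B_full
```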
I've also written up instructions and trained Gemma 2b on a single device. Here is the WandB report on my attempt. GPU: 1x RTX 4090; CPU: 1x AMD 5950X.
@joecummings hi! I've added a small section to the dataset page in the tutorial.
@EvilFreelancer Thanks for sharing the test runs! Sorry for the delay in following up on this. An update from our side is that we've been having a lot of discussions around iterable datasets, which is something that we want to support more and design around moving forward. This would enable more flexibility around interleaving, weighted sampling, etc. for multiple datasets that won't fit in memory. That being said, I still think the concatenation approach in this PR is valuable to land now. Other than that, I have no major concerns with the rest of the changes. Let me know if this makes sense!
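To illustrate the direction mentioned here, a hypothetical sketch of weighted interleaving over streaming datasets; this is not torchtune's API, just one way the idea could look:

```python
import random
from typing import Iterator, List

from torch.utils.data import IterableDataset


class InterleavedDataset(IterableDataset):
    """Hypothetical sketch: weighted sampling across streaming datasets."""

    def __init__(self, datasets: List[IterableDataset], weights: List[float], seed: int = 0):
        self._datasets = datasets
        self._weights = weights
        self._seed = seed

    def __iter__(self) -> Iterator:
        rng = random.Random(self._seed)
        iterators = [iter(ds) for ds in self._datasets]
        weights = list(self._weights)
        # Pick the next source proportionally to its weight; retire a stream
        # (and its weight) once it is exhausted.
        while iterators:
            i = rng.choices(range(len(iterators)), weights=weights, k=1)[0]
            try:
                yield next(iterators[i])
            except StopIteration:
                del iterators[i]
                del weights[i]
```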
torchtune/datasets/_multi.py (Outdated)
```python
log = utils.get_logger("DEBUG")


class MultiDataset(Dataset):
```
Let's add a comprehensive docstring here
A comprehensive docstring has been added.
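For illustration, a docstring along these lines would cover the contract (a sketch, not the PR's exact text):

```python
from typing import List

from torch.utils.data import Dataset


class ConcatDataset(Dataset):
    """Concatenate multiple map-style datasets into a single dataset.

    Indices are mapped sequentially: [0, len(d0)) hits the first dataset,
    [len(d0), len(d0) + len(d1)) the second, and so on. All wrapped datasets
    must already be instantiated and must return examples in the same
    tokenized format, since the sampler draws from all sources freely.

    Args:
        datasets (List[Dataset]): datasets to concatenate.

    Example:
        >>> concat = ConcatDataset([alpaca_ds, slimorca_ds])
        >>> len(concat) == len(alpaca_ds) + len(slimorca_ds)
        True
    """
```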
Hey @EvilFreelancer, are you still planning to make the changes? If not, I'm happy to do it for you and get this merged in. Let me know how you'd like to proceed.
Hi @RdoubleA, I wanted to update you on the progress regarding the MultiDataset class. It has been renamed to ConcatDataset, along with corresponding updates to all related tests and imports in the recipes. I've also added a comprehensive docstring to the class to enhance clarity and usability. Apologies for the delay in implementing these changes; the last couple of days have been particularly hectic at work, and I couldn't address these tasks sooner. Thank you for your understanding.
```python
if isinstance(cfg_dataset, ListConfig):
    datasets = [
        config.instantiate(single_cfg_dataset, tokenizer=self._tokenizer)
        for single_cfg_dataset in cfg_dataset
    ]
```
Not necessary for this PR, but it'd be nice to have a utility like

```python
def instantiate_list_config(cfg: ListConfig, **kwargs) -> List[Any]:
    return [instantiate(element, **kwargs) for element in cfg]
```

to reduce copypasta across the recipes a bit. cc @RdoubleA to tackle in a follow-up
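A hypothetical fragment of recipe code using that helper together with this PR's ConcatDataset (assuming omegaconf's ListConfig, as in the recipes above):

```python
# Hypothetical: one path for lists of dataset configs, one for a single config.
if isinstance(cfg_dataset, ListConfig):
    ds = ConcatDataset(
        datasets=instantiate_list_config(cfg_dataset, tokenizer=self._tokenizer)
    )
else:
    ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer)
```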
```python
        ),
    }
]
load_dataset.return_value = Dataset.from_list(
```
What's the story with these Dataset.from_list changes? I know it will give us the right return type (Dataset instead of a raw List); anything besides that motivating the change? (I am fine with keeping them in, mainly asking out of curiosity.)
Technically it's a more accurate return type for load_dataset. I think it's fine to just leave these out and stick to primitives for simplicity, but I don't have a strong opinion here.
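For reference, a minimal sketch of the testing pattern being discussed; the patch target and sample row below are illustrative, not the PR's exact test code:

```python
from unittest import mock

from datasets import Dataset

# Return a real datasets.Dataset from a mocked load_dataset call, so the test
# exercises the same return type as the real Hugging Face API.
samples = [{"instruction": "Count to three.", "input": "", "output": "1, 2, 3."}]
with mock.patch("torchtune.datasets._instruct.load_dataset") as load_dataset:
    load_dataset.return_value = Dataset.from_list(samples)
    # ...construct the dataset under test here; it will consume `samples`.
```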
I've invested considerable time into understanding how to test my new dataset class; it was initially unclear what format and content its elements should have. Hence, I suggest providing clear guidelines on the expected format of dataset elements, to save fellow programmers time.
Thanks, this is very good feedback. I agree that the contracts of various dataset components are not always obvious and take some time to sort through. Aside from improving live docs and better code comments, I'm open to any suggestions you have on how to make this clearer based on your experience.
Co-authored-by: ebsmothers <ebs@meta.com>
@ebsmothers typo fixed
Looks great, thank you for adding this!
Thank you for the meticulous code review and excellent advice. I thoroughly enjoyed working on this project with all of you!
Context
What is the purpose of this PR? Is it to

- [x] add a new feature
- [ ] fix a bug
- [ ] update tests and/or documentation
- [ ] other
Please link to any issues this PR addresses.
Changelog
I've added the ability to use multiple sources for any type of dataset.
Once this PR is merged, users of TorchTune will be able to pass multiple datasets in different formats. For example, it will be possible to mix chat and instruct datasets using different templates, splits, etc.
Example of the new version of the config (the same multi-dataset setup shown earlier in this thread):
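```yaml
dataset:
  - _component_: torchtune.datasets.instruct_dataset
    source: tatsu-lab/alpaca
    template: AlpacaInstructTemplate
    split: train
    train_on_input: True
  - _component_: torchtune.datasets.instruct_dataset
    source: vicgalle/alpaca-gpt4
    template: AlpacaInstructTemplate
    split: train
    train_on_input: True
  - _component_: torchtune.datasets.chat_dataset
    source: Open-Orca/SlimOrca-Dedup
    conversation_style: sharegpt
    chat_format: Llama2ChatFormat
    max_seq_len: 1024
    split: train
  - _component_: torchtune.datasets.chat_dataset
    source: ajibawa-2023/Code-290k-ShareGPT
    conversation_style: sharegpt
    chat_format: Llama2ChatFormat
    max_seq_len: 1024
    split: train
seed: null
shuffle: True
```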
For backward compatibility, users can continue using the previous format of the dataset field if they do not wish to use multiple datasets; a sketch of that single-dataset form is shown below.
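(Field values here are illustrative, mirroring the alpaca entry above:)

```yaml
dataset:
  _component_: torchtune.datasets.instruct_dataset
  source: tatsu-lab/alpaca
  template: AlpacaInstructTemplate
  split: train
  train_on_input: True
```

To run unit tests:

```bash
pytest tests/torchtune/datasets/test_multi_dataset.py
```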
Test plan
Please make sure to do each of the following if applicable to your PR. (If you're not sure about any one of these just ask and we will happily help.)
- run pre-commit hooks and linters (make sure you've first installed them via `pre-commit install`)
- run unit tests via `pytest tests`
- run recipe/integration tests via `pytest tests -m integration_test`