
Feat: Add support of multiple datasets in config #889

Merged
25 commits merged into pytorch:main on May 3, 2024

Conversation

EvilFreelancer
Contributor

@EvilFreelancer commented Apr 27, 2024

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Please link to any issues this PR addresses.

Changelog

I've added the ability to use multiple sources for any type of dataset.

After this PR is merged, users of TorchTune will be able to pass multiple datasets in different formats. For example, it will be possible to mix chat and instruct datasets that use different templates, splits, etc.

Example of the new version of the config:

dataset:
  - _component_: torchtune.datasets.instruct_dataset
    source: tatsu-lab/alpaca
    template: AlpacaInstructTemplate
    split: train
    train_on_input: True
  - _component_: torchtune.datasets.chat_dataset
    source: Open-Orca/SlimOrca-Dedup
    conversation_style: sharegpt
    chat_format: Llama2ChatFormat
    max_seq_len: 1024
    split: train
seed: null
shuffle: True

For backward compatibility, users can continue using the previous format of the dataset field if they do not wish to use multiple datasets:

dataset:
  _component_: torchtune.datasets.instruct_dataset
  source: tatsu-lab/alpaca
  template: AlpacaInstructTemplate
  split: train
  train_on_input: True
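
For context, each training recipe checks whether the dataset field in the config is a list; if it is, every entry is instantiated and wrapped in a single concatenating dataset. The snippet below is a simplified sketch of the final logic discussed later in this thread; it runs inside a recipe's data-setup method (cfg_dataset and self._tokenizer are assumed defined there), and exact names and module paths may differ:

from omegaconf import ListConfig

from torchtune import config
from torchtune.datasets import ConcatDataset

if isinstance(cfg_dataset, ListConfig):
    datasets = [
        config.instantiate(single_cfg, tokenizer=self._tokenizer)
        for single_cfg in cfg_dataset
    ]
    # The wrapper class ended up being named ConcatDataset after review
    ds = ConcatDataset(datasets=datasets)
else:
    ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer)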

To run unit tests:

pytest ./tests/torchtune/datasets/test_multi_dataset.py

Test plan

Please make sure to do each of the following if applicable to your PR. (If you're not sure about any one of these just ask and we will happily help.)

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
    • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)


pytorch-bot bot commented Apr 27, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/889

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit f7a3f95 with merge base aa65012:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot

Hi @EvilFreelancer!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!

@EvilFreelancer
Contributor Author

Hi! I've signed the CLA. How do I rerun the failed check?

@facebook-github-bot

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!

@facebook-github-bot added the "CLA Signed" label on Apr 27, 2024.
@RdoubleA
Contributor

Thanks for this awesome PR @EvilFreelancer! Handling multiple datasets has been a north star for us, as most mature data pipelines for fine-tuning models typically incorporate multiple data sources, so I appreciate you adding this.

The only problem is: the datasets must have similar formats.

This is my main concern - the concatenated datasets must have the same columns AND the same instruct template / chat format. Also, all the keyword arguments will need to be shared (max_seq_len, train_on_input). This is quite restrictive and will require users to do a lot of offline preprocessing work. Ideally, using multiple datasets should be flexible enough that each dataset can have different columns and different template yet we're able to coalesce these together. In the end, all the data gets tokenized as Messages, so I do think this is possible.
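
For illustration, every sample eventually becomes a list of Message objects before tokenization, regardless of the source format. A rough example follows; the exact Message fields are an assumption based on torchtune.data at the time, not a quote from this PR:

from torchtune.data import Message

# One chat turn expressed as Messages, independent of the original dataset format
messages = [
    Message(role="user", content="What is the capital of France?"),
    Message(role="assistant", content="The capital of France is Paris."),
]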

What are your thoughts on letting InstructDataset and ChatDataset handle single data sources only, and creating some container class that can hold multiple InstructDatasets or ChatDatasets? Since the dataset classes return the token IDs ready to be used by the model, the container class can focus on handling the logic of sampling from the list of datasets.

class MultiDataset(torch.utils.data.Dataset):
    def __init__(self, datasets: List[Dataset]):
        self.datasets = datasets

    def __getitem__(self, index: int) -> Tuple[List[int], List[int]]:
        # Figure out how to sample/interleave multiple datasets here

Then, in the config, you could specify multiple datasets like this (or something similar):

dataset:
  - _component_: torchtune.datasets.instruct_dataset
    source: tatsu-lab/alpaca
    ...
  - _component_: torchtune.datasets.instruct_dataset
    source: vicgalle/alpaca-gpt4
    ...

And the recipe can instantiate a MultiDataset if the dataset param is a list. This way, each dataset can keep their individual parameters.

Ideally, we could make use of concatenate_datasets and interleave_datasets from HF, but since we have preprocessing logic in our dataset classes that is specific to each data source, we might have to create similar logic ourselves.
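
For illustration only, one possible shape for such a container is sketched below. This is a minimal, hedged example of the cumulative-index mapping eventually used in this PR, not the code that actually landed:

from typing import List, Tuple

from torch.utils.data import Dataset


class MultiDataset(Dataset):
    """Exposes several map-style datasets as one concatenated dataset (sketch)."""

    def __init__(self, datasets: List[Dataset]):
        self._datasets = datasets
        # Precompute the (start, end) global index range each child dataset covers
        self._ranges: List[Tuple[int, int]] = []
        start = 0
        for ds in datasets:
            self._ranges.append((start, start + len(ds)))
            start += len(ds)

    def __len__(self) -> int:
        return self._ranges[-1][1] if self._ranges else 0

    def __getitem__(self, index: int):
        # Map the global index to the child dataset that owns it
        for ds, (start, end) in zip(self._datasets, self._ranges):
            if start <= index < end:
                return ds[index - start]
        raise IndexError(f"Index {index} is out of range for length {len(self)}")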

@EvilFreelancer
Contributor Author

Hi @RdoubleA, thank you for your response!

Hm, the MultiDataset class and the possibility of passing an array of datasets through the dataset parameter is a great idea. It's much simpler for end users to understand than having to wrangle datasets into similar formats.

I have a couple of spare days, so I may start working on implementing this solution, as it is critical for the project I am working on to have the ability to train on various combinations of datasets.

@EvilFreelancer
Contributor Author

Hi! I've implemented the MultiDataset logic and enabled it in the training recipes.

I've also removed the original multi-source logic and cleaned up the tests.

How to use:

dataset:
  - _component_: torchtune.datasets.instruct_dataset
    source: tatsu-lab/alpaca
    template: AlpacaInstructTemplate
    split: train
    train_on_input: True
  - _component_: torchtune.datasets.instruct_dataset
    source: vicgalle/alpaca-gpt4
    template: AlpacaInstructTemplate
    split: train
    train_on_input: True
  - _component_: torchtune.datasets.chat_dataset
    source: Open-Orca/SlimOrca-Dedup
    conversation_style: sharegpt
    chat_format: Llama2ChatFormat
    max_seq_len: 1024
    split: train
  - _component_: torchtune.datasets.chat_dataset
    source: ajibawa-2023/Code-290k-ShareGPT
    conversation_style: sharegpt
    chat_format: Llama2ChatFormat
    max_seq_len: 1024
    split: train
seed: null
shuffle: True

By the way, this had an interesting side effect: with this logic, you can freely mix datasets of any format.

In the example above, the instruct and chat dataset formats are mixed; it seems to me that this will be a very interesting and convenient feature of TorchTune.

Contributor

@RdoubleA left a comment

I like the simplicity of the MultiDataset class you implemented, very easy to understand! There are some considerations around how we instantiate each individual dataset that we should figure out.

tokenizer=self._tokenizer,
)

if isinstance(cfg_dataset.get(0), DictConfig):
Contributor

you can also just directly check if it's a ListConfig. If it's a single dataset, then this might fail.

Suggested change
if isinstance(cfg_dataset.get(0), DictConfig):
if isinstance(cfg_dataset, ListConfig):

I also wonder if there's a better way to handle this so we don't have to repeat this if-else check across all recipes...

Contributor Author

Good point, I've replaced it with a ListConfig check.

Contributor Author

@EvilFreelancer Apr 28, 2024

I also wonder if there's a better way to handle this so we don't have to repeat this if-else check across all recipes...

Yeah, this logic could be moved into the config.instantiate method, for example, but I guess that would break the single responsibility principle. So I suggest leaving it as is.

Contributor

Do we still need to make these changes to the test samples to use Dataset across all the individual dataset test files?

Contributor Author

I think we should keep the datasets.Dataset objects, since this is a common format and anyone can see that the data should come in this format. There is also no need to perform any extra transformations, as was the case, for example, in the _chat tests.

torchtune/utils/multi_dataset.py (3 resolved review threads, outdated)
@EvilFreelancer
Contributor Author

EvilFreelancer commented Apr 28, 2024

@RdoubleA hi! Thanks for your review; the requested fixes have been added.

By the way, I've noticed one small thing: the vicgalle/alpaca-gpt4 dataset has 52k rows, so why does TorchTune show me 26k items during the training stage? Maybe it's because of train/val splits?

UPD: Ah, it's because of batch_size: 2 (52k examples / batch size of 2 = 26k steps per epoch), got it.

Contributor

@RdoubleA left a comment

This is turning into something really nice! Thanks for your patience working through this. Could you add a unit test for MultiDataset on some toy data?

Also tagging @ebsmothers, @kartikayk to get their thoughts on how we should handle instantiating this in the recipes. I think we could just make every dataset a MultiDataset, since you could just pass a list with a single dataset and it should work the same (correct me if I'm wrong here). That should simplify the logic in the recipes.

torchtune/utils/multi_dataset.py (3 resolved review threads, outdated)
@EvilFreelancer
Contributor Author

EvilFreelancer commented Apr 29, 2024

@RdoubleA hi! I've added the fixes you mentioned, and a couple of simple tests for the MultiDataset class.

What do you think about:

if not isinstance(cfg_dataset, ListConfig):
    cfg_dataset = [cfg_dataset]
datasets = [config.instantiate(cfg_item, tokenizer=self._tokenizer) for cfg_item in cfg_dataset]
ds = utils.MultiDataset(datasets=datasets)

instead of:

if isinstance(cfg_dataset, ListConfig):
    datasets = [
        config.instantiate(single_cfg_dataset, tokenizer=self._tokenizer)
        for single_cfg_dataset in cfg_dataset
    ]
    ds = utils.MultiDataset(datasets=datasets)
else:
    ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer)

@RdoubleA
Contributor

RdoubleA commented Apr 29, 2024

@EvilFreelancer Thanks for the updates! I'll take another pass soon.

In the meantime, since this is a significant feature we are adding, I'd like to make sure this is rigorously tested. Do you mind confirming the following:

  • The unit test for MultiDataset passes (please add the test command to the PR summary)
  • Run one of the distributed recipes using MultiDataset and confirm that loss curves and tokens/sec look reasonable (you can use WandBLogger for easy visualization). Main thing I want to confirm here is that we can still sample from multiple datasets correctly in a distributed environment
  • Run one of the single device recipes using MultiDataset for a similar reason above

Also, I forgot to mention this earlier, but I think a better location for MultiDataset would be in torchtune/datasets instead of torchtune/utils. What do you think?

@joecummings
Contributor

@EvilFreelancer Can you update the README with examples on how this will now look in the YAML config?

torchtune/utils/multi_dataset.py (resolved review thread, outdated)
# Calculate distribution of indexes in all datasets
cumulative_index = 0
for idx, dataset in enumerate(datasets):
    next_cumulative_index = cumulative_index + len(dataset)
Contributor

Not a blocker for this PR, but worth thinking about how we're gonna do this for iterable datasets @RdoubleA @SLR722
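
For context on why iterable datasets need different handling: they have no len(), so a cumulative-index mapping like the one above cannot be used. A round-robin interleaving approach would look more like the rough sketch below (illustrative only, not part of this PR):

from typing import Iterator, List

from torch.utils.data import IterableDataset


class InterleavedDataset(IterableDataset):
    """Round-robin interleaving of iterable datasets (illustrative sketch)."""

    def __init__(self, datasets: List[IterableDataset]):
        self._datasets = datasets

    def __iter__(self) -> Iterator:
        iterators = [iter(ds) for ds in self._datasets]
        while iterators:
            # Iterate over a copy so exhausted iterators can be removed safely
            for it in list(iterators):
                try:
                    yield next(it)
                except StopIteration:
                    iterators.remove(it)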

torchtune/utils/multi_dataset.py (resolved review thread, outdated)
@EvilFreelancer
Contributor Author

EvilFreelancer commented Apr 30, 2024

@RdoubleA Hi! I've made some fixes to the PR.

Also, I forgot to mention this earlier, but I think a better location for MultiDataset would be in torchtune/datasets instead of torchtune/utils, what do you think?

The reason I thought the torchtune/utils namespace was more suitable for this class is that MultiDataset does not represent an actual dataset, such as Alpaca or OpenOrca; instead, it serves merely as a wrapper over several datasets. However, I agree that it is most logical to move this class to the torchtune/datasets namespace. (Code refactored.)

The unit test for MultiDataset passes

The description of the PR was updated, and a note on how to run unit tests was added.

pytest ./tests/torchtune/datasets/test_multi_dataset.py
(venv) [pasha-pc] ~/Documents/Repository/nn-nlp/torchtune $ pytest ./tests/torchtune/datasets/test_multi_dataset.py
= test session starts =
platform linux -- Python 3.11.2, pytest-8.1.2, pluggy-1.5.0
rootdir: /home/pasha/Documents/Repository/nn-nlp/torchtune
configfile: pyproject.toml
plugins: integration-0.2.3, mock-3.14.0, cov-5.0.0
collected 3 items

tests/torchtune/datasets/test_multi_dataset.py ...

Run one of the distributed recipes using MultiDataset and confirm that loss curves and tokens/sec look reasonable (you can use WandBLogger for easy visualization). Main thing I want to confirm here is that we can still sample from multiple datasets correctly in a distributed environment

Unfortunately, my local server is equipped with only one GPU. As a result, I've written up instructions and enlisted the help of one of my subscribers who owns a multi-GPU server. A detailed update will be provided tomorrow. For now, I can only confirm that the training has started on the multi-GPU server and that the number of training steps is the same as on my side.

Run one of the single device recipes using MultiDataset for a similar reason above

I've also written up instructions and trained Gemma 2B on a single device. Here is the WandB report from my attempt.

GPU: 1x RTX 4090, CPU: 1x AMD 5950X.

@EvilFreelancer
Contributor Author

@joecummings hi!

Can you update the README with examples on how this will now look in the YAML config?

I've added a small section to the datasets page in the tutorial.

@EvilFreelancer
Contributor Author

EvilFreelancer commented May 1, 2024

The tests for Multi GPU Gemma 2B training are available here.

UPD: A second training run on the same hardware, without the MultiDataset class, on a single Alpaca dataset.

As you can see, the tokens_per_second distribution is almost the same.

GPU: 2x RTX 4090, CPU: 1x Xeon Gold 6336Y.

@RdoubleA
Contributor

RdoubleA commented May 1, 2024

@EvilFreelancer Thanks for sharing the test runs! Sorry for the delay in following up on this. An update from our side is that we've been having a lot of discussions around iterable datasets, which is something that we want to support more and design around moving forward. This would enable more flexibility around interleaving, weighted sampling, etc. for multiple datasets that won't fit in memory.

That being said, I still think the MultiDataset here is valuable for datasets that fit in memory and can still leverage map-style functionality. This will unblock users that want to quickly concatenate multiple data sources until we add iterable datasets and a more powerful MultiDataset. So to accurately reflect the scope of this class, my suggestion is to rename this to ConcatDataset (you might've started with this name to begin with, I apologize) and make it clear in the docstrings that this is for map datasets that all fit in memory.

Other than that, I have no major concerns for the rest of the changes. Let me know if this makes sense!
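
For illustration, a docstring along the following lines would capture that scope; the wording is a sketch, not necessarily the docstring that landed:

from typing import List

from torch.utils.data import Dataset


class ConcatDataset(Dataset):
    """
    A dataset class for concatenating multiple map-style datasets into one.

    This class is intended for datasets that all fit in memory: each sample is
    addressed through a cumulative index over the child datasets, so iterable
    (streaming) datasets, interleaving, and weighted sampling are out of scope.

    Args:
        datasets (List[Dataset]): list of already-instantiated datasets to concatenate.
    """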

log = utils.get_logger("DEBUG")


class MultiDataset(Dataset):
Contributor

Let's add a comprehensive docstring here

Contributor Author

A comprehensive docstring has been added.

@RdoubleA
Contributor

RdoubleA commented May 3, 2024

Hey @EvilFreelancer, are you still planning to make the changes? If not, I'm happy to do it for you and get this merged in. Let me know how you'd like to proceed.

@EvilFreelancer
Contributor Author

Hi @RdoubleA,

I wanted to update you on the progress regarding the MultiDataset class. It has been renamed to ConcatDataset, along with corresponding updates to all related tests and imports in the recipes. I've also added a comprehensive docstring to the class to enhance clarity and usability.

Apologies for the delay in implementing these changes; the last couple of days have been particularly hectic at work, and I couldn't address these tasks sooner. Thank you for your understanding.

docs/source/tutorials/datasets.rst (resolved review thread, outdated)
Comment on lines +361 to +365
if isinstance(cfg_dataset, ListConfig):
    datasets = [
        config.instantiate(single_cfg_dataset, tokenizer=self._tokenizer)
        for single_cfg_dataset in cfg_dataset
    ]
Contributor

Not necessary for this PR, but it'd be nice to have a utility like

def instantiate_list_config(cfg: ListConfig, **kwargs) -> List[Any]:
    return [
        instantiate(element, **kwargs) for element in cfg
    ]

to reduce copypasta across the recipes a bit. cc @RdoubleA to tackle in a follow-up
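
Hypothetical usage inside a recipe, assuming such a utility existed (the names here are illustrative, matching the snippets above):

# Fragment from a recipe's data-setup method; cfg_dataset and self._tokenizer
# are assumed to be defined there.
if isinstance(cfg_dataset, ListConfig):
    ds = ConcatDataset(
        datasets=instantiate_list_config(cfg_dataset, tokenizer=self._tokenizer)
    )
else:
    ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer)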

load_dataset.return_value = Dataset.from_list(
Contributor

What's the story with these Dataset.from_list changes? I know it will give us the right return type (Dataset instead of raw List), anything besides that motivating the change? (I am fine with keeping them in, mainly asking out of curiosity)

Contributor

Technically it's a more accurate return type for load_dataset. I think it's ok to just leave these out and stick to primitives for simplicity, but I don't have a strong opinion here.

Contributor Author

I've invested considerable time in understanding how to test my new dataset class; it was initially unclear what format and content its elements should have. Hence, I suggest providing clear guidelines on the expected format of dataset elements to save fellow programmers time.
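
As one concrete example of what such guidelines could show, a single sample in the sharegpt conversation style used in these tests looks roughly like this; the field names follow the common ShareGPT convention and are given here as an assumption rather than a torchtune guarantee:

# One sharegpt-style sample: a "conversations" list of turns with "from"/"value" keys
sample = {
    "conversations": [
        {"from": "system", "value": "You are a helpful assistant."},
        {"from": "human", "value": "What is the capital of France?"},
        {"from": "gpt", "value": "The capital of France is Paris."},
    ]
}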

Contributor

Thanks, this is very good feedback. I agree that the contracts of various dataset components are not always obvious and take some time to sort through. Aside from improving live docs and better code comments, I'm open to any suggestions you have on how to make this clearer based on your experience.

@ebsmothers
Contributor

Ah crap one more thing: my suggestion screwed up the rendering of the multi-dataset YAML. Sorry about that! I think you just need to add back a newline before the codeblock statement.

(Screenshot: broken rendering of the multi-dataset YAML in the docs, 2024-05-03.)

@EvilFreelancer
Contributor Author

@ebsmothers typo fixed

Contributor

@ebsmothers left a comment

Looks great, thank you for adding this!

@ebsmothers merged commit d36e818 into pytorch:main on May 3, 2024
29 checks passed
@EvilFreelancer
Contributor Author

Thank you for the meticulous code-review and excellent advice, I thoroughly enjoyed working on this project with all of you!
