From 1300c12ac8fe65c09fff0b0a7470abfe9e4ad297 Mon Sep 17 00:00:00 2001
From: Andy Liu
Date: Wed, 5 Jun 2024 17:53:52 -0700
Subject: [PATCH 1/4] Update _text_completion.py

---
 torchtune/datasets/_text_completion.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/torchtune/datasets/_text_completion.py b/torchtune/datasets/_text_completion.py
index 02ce933482..d118d36f33 100644
--- a/torchtune/datasets/_text_completion.py
+++ b/torchtune/datasets/_text_completion.py
@@ -9,6 +9,7 @@
 from datasets import load_dataset
 from torch.utils.data import Dataset
 from torchtune.data import truncate
+from torchtune.datasets._packed import PackedDataset
 from torchtune.modules.tokenizers import Tokenizer
 
 
@@ -26,6 +27,7 @@ class TextCompletionDataset(Dataset):
         max_seq_len (Optional[int]): Maximum number of tokens in the returned input and label token id lists.
             Default is None, disabling truncation. We recommend setting this to the highest you can fit in memory
             and is supported by the model. For example, llama2-7B supports up to 4096 for sequence length.
+        packed (bool): Whether or not to pack the dataset to ``max_seq_len`` prior to training. Default is False.
         **load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``.
     """
 
@@ -68,6 +70,7 @@ def text_completion_dataset(
     source: str,
     column: Optional[str] = None,
     max_seq_len: Optional[int] = None,
+    packed: bool = False,
     **load_dataset_kwargs: Dict[str, Any],
 ) -> TextCompletionDataset:
     """
@@ -94,6 +97,7 @@ def text_completion_dataset(
         ...     column="text",
         ...     max_seq_len=2096,
         ...     data_dir="realnewslike",
+        ...     packed=True,
         ... )
 
     This can also be accomplished via the yaml config::
@@ -104,14 +108,17 @@ def text_completion_dataset(
           column: text
           max_seq_len: 2096
           data_dir: realnewslike
+          packed: True
 
     Returns:
-        TextCompletionDataset: the configured :class:`~torchtune.datasets.TextCompletionDataset`
+        TextCompletionDataset or PackedDataset: the configured :class:`~torchtune.datasets.TextCompletionDataset`
+        or :class:`~torchtune.datasets.PackedDataset` if ``packed=True`
     """
-    return TextCompletionDataset(
+    ds = TextCompletionDataset(
         tokenizer=tokenizer,
         source=source,
         column=column,
         max_seq_len=max_seq_len,
         **load_dataset_kwargs,
     )
+    return PackedDataset(ds, max_seq_len=max_seq_len) if packed else ds

From 79f6d019f768ac2492b8ad9855e6cb4e0dc7f3da Mon Sep 17 00:00:00 2001
From: Andy Liu
Date: Wed, 5 Jun 2024 20:13:05 -0700
Subject: [PATCH 2/4] docstring bug fix

---
 torchtune/datasets/_text_completion.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchtune/datasets/_text_completion.py b/torchtune/datasets/_text_completion.py
index d118d36f33..96c0028cfb 100644
--- a/torchtune/datasets/_text_completion.py
+++ b/torchtune/datasets/_text_completion.py
@@ -27,7 +27,6 @@ class TextCompletionDataset(Dataset):
         max_seq_len (Optional[int]): Maximum number of tokens in the returned input and label token id lists.
             Default is None, disabling truncation. We recommend setting this to the highest you can fit in memory
             and is supported by the model. For example, llama2-7B supports up to 4096 for sequence length.
-        packed (bool): Whether or not to pack the dataset to ``max_seq_len`` prior to training. Default is False.
         **load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``.
     """
 
@@ -87,6 +86,7 @@ def text_completion_dataset(
         max_seq_len (Optional[int]): Maximum number of tokens in the returned input and label token id lists.
             Default is None, disabling truncation. We recommend setting this to the highest you can fit in memory
             and is supported by the model. For example, llama2-7B supports up to 4096 for sequence length.
+        packed (bool): Whether or not to pack the dataset to ``max_seq_len`` prior to training. Default is False.
         **load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``.
 
     Examples:

From 231b68b368f26a70e8269ffa28e414f707e04d88 Mon Sep 17 00:00:00 2001
From: Andy Liu
Date: Mon, 10 Jun 2024 19:48:48 -0700
Subject: [PATCH 3/4] Update _text_completion.py

---
 torchtune/datasets/_text_completion.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/torchtune/datasets/_text_completion.py b/torchtune/datasets/_text_completion.py
index 96c0028cfb..7e24f72444 100644
--- a/torchtune/datasets/_text_completion.py
+++ b/torchtune/datasets/_text_completion.py
@@ -97,7 +97,7 @@ def text_completion_dataset(
         ...     column="text",
         ...     max_seq_len=2096,
         ...     data_dir="realnewslike",
-        ...     packed=True,
+        ...     packed=False,
         ... )
 
     This can also be accomplished via the yaml config::
@@ -108,7 +108,7 @@ def text_completion_dataset(
           column: text
           max_seq_len: 2096
           data_dir: realnewslike
-          packed: True
+          packed: False
 
     Returns:
         TextCompletionDataset or PackedDataset: the configured :class:`~torchtune.datasets.TextCompletionDataset`

From 6631b0747e54cb45875e5504a7d353361bcee16a Mon Sep 17 00:00:00 2001
From: RdoubleA
Date: Mon, 10 Jun 2024 22:25:11 -0700
Subject: [PATCH 4/4] missing backtick

---
 torchtune/datasets/_text_completion.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchtune/datasets/_text_completion.py b/torchtune/datasets/_text_completion.py
index 7e24f72444..394451859b 100644
--- a/torchtune/datasets/_text_completion.py
+++ b/torchtune/datasets/_text_completion.py
@@ -112,7 +112,7 @@ def text_completion_dataset(
 
     Returns:
         TextCompletionDataset or PackedDataset: the configured :class:`~torchtune.datasets.TextCompletionDataset`
-        or :class:`~torchtune.datasets.PackedDataset` if ``packed=True`
+        or :class:`~torchtune.datasets.PackedDataset` if ``packed=True``
     """
     ds = TextCompletionDataset(
         tokenizer=tokenizer,
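
For context, a minimal usage sketch of the ``packed`` flag as it stands after this series; the llama2 tokenizer and its checkpoint path below are placeholders for illustration and are not part of the patches::

    from torchtune.datasets import text_completion_dataset
    from torchtune.models.llama2 import llama2_tokenizer

    # Placeholder path; substitute a real SentencePiece tokenizer checkpoint.
    tokenizer = llama2_tokenizer("/tmp/llama2/tokenizer.model")

    # With packed=True, text_completion_dataset wraps the TextCompletionDataset
    # in PackedDataset, which packs samples into sequences of up to max_seq_len tokens.
    ds = text_completion_dataset(
        tokenizer=tokenizer,
        source="allenai/c4",
        column="text",
        max_seq_len=2096,
        data_dir="realnewslike",
        packed=True,
    )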