Update _text_completion.py to support packed mode #1061

Merged
11 changes: 9 additions & 2 deletions torchtune/datasets/_text_completion.py
@@ -9,6 +9,7 @@
 from datasets import load_dataset
 from torch.utils.data import Dataset
 from torchtune.data import truncate
+from torchtune.datasets._packed import PackedDataset
 from torchtune.modules.tokenizers import Tokenizer


@@ -68,6 +69,7 @@ def text_completion_dataset(
     source: str,
     column: Optional[str] = None,
     max_seq_len: Optional[int] = None,
+    packed: bool = False,
     **load_dataset_kwargs: Dict[str, Any],
 ) -> TextCompletionDataset:
     """
@@ -84,6 +86,7 @@
         max_seq_len (Optional[int]): Maximum number of tokens in the returned input and label token id lists.
             Default is None, disabling truncation. We recommend setting this to the highest value you can fit in memory
             and that is supported by the model. For example, llama2-7B supports up to 4096 for sequence length.
+        packed (bool): Whether or not to pack the dataset to ``max_seq_len`` prior to training. Default is False.
         **load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``.

     Examples:
@@ -94,6 +97,7 @@
         ...     column="text",
         ...     max_seq_len=2096,
         ...     data_dir="realnewslike",
+        ...     packed=True,
         ... )

     This can also be accomplished via the yaml config::
@@ -104,14 +108,17 @@
           column: text
           max_seq_len: 2096
           data_dir: realnewslike
+          packed: True

     Returns:
-        TextCompletionDataset: the configured :class:`~torchtune.datasets.TextCompletionDataset`
+        TextCompletionDataset or PackedDataset: the configured :class:`~torchtune.datasets.TextCompletionDataset`
+        or :class:`~torchtune.datasets.PackedDataset` if ``packed=True``
     """
-    return TextCompletionDataset(
+    ds = TextCompletionDataset(
         tokenizer=tokenizer,
         source=source,
         column=column,
         max_seq_len=max_seq_len,
         **load_dataset_kwargs,
     )
+    return PackedDataset(ds, max_seq_len=max_seq_len) if packed else ds
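
For context on what the new flag does: packing concatenates tokenized samples into fixed-length rows of ``max_seq_len`` tokens so that training steps waste little or no compute on padding. The snippet below is a minimal sketch of that idea, not torchtune's actual ``PackedDataset`` implementation; it assumes pre-tokenized samples and simply drops the trailing partial pack rather than padding it.

# Illustrative greedy packing, NOT torchtune's PackedDataset.
# Assumes `samples` are already-tokenized lists of token ids;
# the trailing partial pack is dropped instead of padded.
from typing import List

def pack_samples(samples: List[List[int]], max_seq_len: int) -> List[List[int]]:
    """Greedily concatenate token lists into rows of exactly max_seq_len tokens."""
    packs: List[List[int]] = []
    buffer: List[int] = []
    for tokens in samples:
        for tok in tokens:
            buffer.append(tok)
            if len(buffer) == max_seq_len:
                packs.append(buffer)
                buffer = []
    return packs  # leftover tokens in `buffer` are discarded here

# Three short "documents" packed into rows of 8 tokens:
docs = [[1, 2, 3, 4, 5], [6, 7, 8], [9, 10, 11, 12, 13, 14]]
print(pack_samples(docs, max_seq_len=8))  # [[1, 2, 3, 4, 5, 6, 7, 8]]

Note the design choice visible in the diff: packing is applied as a wrapper around the built dataset at the builder level rather than inside ``TextCompletionDataset`` itself, so the base dataset is unchanged and the ``packed`` flag stays purely opt-in.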