Better wikitext (#1258)
RdoubleA authored Aug 5, 2024
1 parent 8519c35 commit 5019074
Showing 4 changed files with 55 additions and 25 deletions.
2 changes: 1 addition & 1 deletion tests/torchtune/datasets/test_wikitext_dataset.py
@@ -23,7 +23,7 @@ def test_dataset_get_item(self, load_dataset, tokenizer, max_seq_len):
# Sample data from wikitext dataset
load_dataset.return_value = [
{
"text": "Bart , like the rest of his family , has yellow skin . "
"page": "Bart , like the rest of his family , has yellow skin . "
"Bart usually wears a red T @-@ shirt , blue shorts and blue trainers . "
"When the Simpson family goes to church in the episodes , or to school "
"events or shows , Bart wears a blue suit with a white shirt , a purple "
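For context, the mocked sample now uses the "page" key because the EleutherAI/wikitext_document_level dataset stores each full article under a "page" column rather than "text". A quick, hypothetical way to confirm the column name (not part of this commit; depending on your datasets version, loading may require extra options such as trust_remote_code):

from datasets import load_dataset

ds = load_dataset("EleutherAI/wikitext_document_level", "wikitext-103-v1", split="test")
print(ds.column_names)       # expected to include "page"
print(ds[0]["page"][:80])    # first characters of a full Wikipedia article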
37 changes: 27 additions & 10 deletions torchtune/datasets/_text_completion.py
@@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

- from typing import Any, Dict, List, Mapping, Optional
+ from typing import Any, Dict, List, Mapping, Optional, Union

from datasets import load_dataset
from torch.utils.data import Dataset
@@ -26,8 +26,9 @@ class TextCompletionDataset(Dataset):
(https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset.path)
for more details.
column (str): name of column in the sample that contains the text data. This is typically required
- for Hugging Face datasets or tabular data. For local datasets with a single column, use the default "text",
- which is what is assigned by Hugging Face datasets when loaded into memory. Default is "text".
+ for Hugging Face datasets or tabular data. For local datasets with a single column
+ (e.g. unstructured txt files), use the default "text" which is used by Hugging Face datasets
+ when loaded into memory. Default is "text".
max_seq_len (Optional[int]): Maximum number of tokens in the returned input and label token id lists.
Default is None, disabling truncation. We recommend setting this to the highest you can fit in memory
and is supported by the model. For example, llama2-7B supports up to 4096 for sequence length.
@@ -75,12 +76,13 @@ def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, List[int]]:
def text_completion_dataset(
tokenizer: ModelTokenizer,
source: str,
- column: Optional[str] = None,
+ column: str = "text",
max_seq_len: Optional[int] = None,
add_eos: bool = True,
packed: bool = False,
+ split_across_pack: bool = True,
**load_dataset_kwargs: Dict[str, Any],
- ) -> TextCompletionDataset:
+ ) -> Union[TextCompletionDataset, PackedDataset]:
"""
Build a configurable dataset from a freeform, unstructured text corpus similar
to datasets used in pre-training. This method should be
@@ -89,15 +91,25 @@
Args:
tokenizer (ModelTokenizer): Tokenizer used by the model that implements the ``tokenize_messages`` method.
- source (str): path string of dataset, anything supported by Hugging Face's ``load_dataset``
+ source (str): path to dataset repository on Hugging Face. For local datasets,
+ define source as the data file type (e.g. "json", "csv", "text") and pass
+ in the filepath in ``data_files``. See Hugging Face's ``load_dataset``
  (https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset.path)
- column (Optional[str]): name of column in the sample that contains the text data. This is typically required
- for Hugging Face datasets or tabular data, but can be omitted for local datasets. Default is None.
+ for more details.
+ column (str): name of column in the sample that contains the text data. This is typically required
+ for Hugging Face datasets or tabular data. For local datasets with a single column
+ (e.g. unstructured txt files), use the default "text" which is used by Hugging Face datasets
+ when loaded into memory. Default is "text".
max_seq_len (Optional[int]): Maximum number of tokens in the returned input and label token id lists.
Default is None, disabling truncation. We recommend setting this to the highest you can fit in memory
and is supported by the model. For example, llama2-7B supports up to 4096 for sequence length.
add_eos (bool): Whether to add an EOS token to the end of the sequence. Default is True.
packed (bool): Whether or not to pack the dataset to ``max_seq_len`` prior to training. Default is False.
+ split_across_pack (bool): if the last sample in a pack does not fit in ``max_seq_len``,
+ split the sample into the next pack, or move it entirely to the beginning of the next pack.
+ For pre-training, typically this is set to True for general text completion. For
+ fine-tuning, typically this is set to False to avoid truncating sentences in instruct
+ tuning. This argument is ignored if ``packed=False``. Default is True.
**load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``.
Examples:
@@ -122,7 +134,7 @@
packed: False
Returns:
- TextCompletionDataset or PackedDataset: the configured :class:`~torchtune.datasets.TextCompletionDataset`
+ Union[TextCompletionDataset, PackedDataset]: the configured :class:`~torchtune.datasets.TextCompletionDataset`
or :class:`~torchtune.datasets.PackedDataset` if ``packed=True``
"""
ds = TextCompletionDataset(
@@ -134,7 +146,12 @@
**load_dataset_kwargs,
)
return (
- PackedDataset(ds, max_seq_len=max_seq_len, padding_idx=tokenizer.pad_id)
+ PackedDataset(
+     ds,
+     max_seq_len=max_seq_len,
+     padding_idx=tokenizer.pad_id,
+     split_across_pack=split_across_pack,
+ )
if packed
else ds
)
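For reference, a minimal usage sketch of the updated text_completion_dataset signature; the tokenizer checkpoint and corpus file below are hypothetical and not part of this commit:

from torchtune.datasets import text_completion_dataset
from torchtune.models.llama2 import llama2_tokenizer

tokenizer = llama2_tokenizer("/tmp/tokenizer.model")   # hypothetical path

ds = text_completion_dataset(
    tokenizer=tokenizer,
    source="text",               # local data: pass the file type as the source
    data_files="my_corpus.txt",  # hypothetical unstructured txt file
    column="text",               # default column name for local txt files
    max_seq_len=2048,
    packed=True,                 # returns a PackedDataset instead of TextCompletionDataset
    split_across_pack=True,      # new flag, forwarded to PackedDataset
)

With packed=False the call returns the plain TextCompletionDataset and split_across_pack is ignored.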
39 changes: 26 additions & 13 deletions torchtune/datasets/_wikitext.py
@@ -4,46 +4,59 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

- from typing import Any, Dict, Optional
+ from typing import Any, Dict, Optional, Union

- from torchtune.datasets._text_completion import TextCompletionDataset
+ from torchtune.datasets._packed import PackedDataset
+
+ from torchtune.datasets._text_completion import (
+     text_completion_dataset,
+     TextCompletionDataset,
+ )

from torchtune.modules.tokenizers import ModelTokenizer


def wikitext_dataset(
tokenizer: ModelTokenizer,
- source: str = "wikitext",
- subset: str = "wikitext-103-raw-v1",
+ source: str = "EleutherAI/wikitext_document_level",
+ subset: str = "wikitext-103-v1",
max_seq_len: Optional[int] = None,
+ packed: bool = False,
split: str = "train",
**load_dataset_kwargs: Dict[str, Any],
- ) -> TextCompletionDataset:
+ ) -> Union[TextCompletionDataset, PackedDataset]:
"""
- Support for family of datasets similar to `wikitext <https://huggingface.co/datasets/wikitext>`_,
- an unstructured text corpus consisting of articles from Wikipedia.
+ Support for family of datasets similar to `wikitext
+ <https://huggingface.co/datasets/EleutherAI/wikitext_document_level>`_,
+ an unstructured text corpus consisting of full articles from Wikipedia.
Args:
tokenizer (ModelTokenizer): Tokenizer used by the model that implements the ``tokenize_messages`` method.
- source (str): path string of dataset, anything supported by Hugging Face's ``load_dataset``
+ source (str): path to dataset repository on Hugging Face. For local datasets,
+ define source as the data file type (e.g. "json", "csv", "text") and pass
+ in the filepath in ``data_files``. See Hugging Face's ``load_dataset``
  (https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset.path)
- subset (str): name of subset of data to use, see the `wikitext page <https://huggingface.co/datasets/wikitext#data-fields>`_
- for available subsets.
+ for more details.
+ subset (str): name of subset of data to use, see the `wikitext page
+ <https://huggingface.co/datasets/EleutherAI/wikitext_document_level#data-instances>`_
+ for available subsets. Default is ``"wikitext-103-v1"``.
max_seq_len (Optional[int]): Maximum number of tokens in the returned input and label token id lists.
Default is None, disabling truncation. We recommend setting this to the highest you can fit in memory
and is supported by the model. For example, llama2-7B supports up to 4096 for sequence length.
+ packed (bool): Whether or not to pack the dataset to ``max_seq_len`` prior to training. Default is False.
split (str): ``split`` argument for ``datasets.load_dataset``. You can use this argument to load a subset
of a given split, e.g. ``split="train[:10%]"``. Default is "train".
**load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``.
Returns:
- TextCompletionDataset: the configured TextCompletionDataset
+ Union[TextCompletionDataset, PackedDataset]: the configured :class:`~torchtune.datasets.TextCompletionDataset`
+ or :class:`~torchtune.datasets.PackedDataset` if ``packed=True``
"""

- return TextCompletionDataset(
+ return text_completion_dataset(
tokenizer=tokenizer,
source=source,
- column="text",
+ column="page",
max_seq_len=max_seq_len,
name=subset,
split=split,
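A short usage sketch of the updated wikitext_dataset defaults; the tokenizer path and split slice are illustrative assumptions, not part of this commit:

from torchtune.datasets import wikitext_dataset
from torchtune.models.llama2 import llama2_tokenizer

tokenizer = llama2_tokenizer("/tmp/tokenizer.model")   # hypothetical path

ds = wikitext_dataset(
    tokenizer=tokenizer,
    # source/subset now default to "EleutherAI/wikitext_document_level"
    # and "wikitext-103-v1"; shown explicitly here for clarity.
    source="EleutherAI/wikitext_document_level",
    subset="wikitext-103-v1",
    max_seq_len=4096,
    packed=True,           # wraps the result in PackedDataset
    split="train[:1%]",    # small slice for a quick smoke test
)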
2 changes: 1 addition & 1 deletion torchtune/utils/collate.py
@@ -68,7 +68,7 @@ def padded_collate(
(0, labels_seq_len - input_ids_seq_len),
value=padding_idx,
)
return {"tokens": input_ids, "labels": labels}
return {"tokens": input_ids.long(), "labels": labels.long()}


def padded_collate_dpo(
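A minimal sketch of the effect of the collate change, assuming padded_collate's (batch, padding_idx, ignore_idx) signature: both returned tensors are now explicitly int64 ("long"), which embedding and cross-entropy layers expect. The batch below is made up for illustration:

import torch
from torchtune.utils import padded_collate

batch = [
    {"tokens": [1, 2, 3], "labels": [4, 5, 6]},
    {"tokens": [7], "labels": [8]},
]
out = padded_collate(batch, padding_idx=0, ignore_idx=-100)
assert out["tokens"].dtype == torch.long
assert out["labels"].dtype == torch.long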
