|
4 | 4 | # This source code is licensed under the BSD-style license found in the |
5 | 5 | # LICENSE file in the root directory of this source tree. |
6 | 6 |
|
7 | | -from typing import Any, Callable, Dict, List, Mapping, Optional, Union |
| 7 | +from typing import Any, Dict, Optional, Union |
8 | 8 |
|
9 | | -import numpy as np |
10 | | - |
11 | | -from datasets import load_dataset |
12 | | -from torch.utils.data import Dataset |
13 | | -from torchtune.data._chat_formats import ChatFormat |
14 | | -from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX |
15 | | -from torchtune.data._messages import ( |
16 | | - Message, |
17 | | - OpenAIToMessages, |
18 | | - ShareGPTToMessages, |
19 | | - validate_messages, |
20 | | -) |
| 9 | +from torchtune.data._messages import OpenAIToMessages, ShareGPTToMessages |
21 | 10 | from torchtune.datasets._packed import PackedDataset |
22 | 11 | from torchtune.datasets._sft import SFTDataset |
23 | 12 | from torchtune.modules.tokenizers import ModelTokenizer |
24 | | -from torchtune.utils._logging import deprecated |
25 | | - |
26 | | - |
27 | | -@deprecated(msg="Please use `torchtune.datasets.SFTDataset` for custom chat data.") |
28 | | -class ChatDataset(Dataset): |
29 | | - """ |
30 | | - Note: |
31 | | - This class is deprecated and will be removed in a future release. Please use |
32 | | - :class:`~torchtune.datasets.SFTDataset` or :func:`~torchtune.datasets.chat_dataset` |
33 | | - for custom chat data. |
34 | | -
|
35 | | - Class that supports any custom dataset with multiturn conversations. |
36 | | -
|
37 | | - The general flow from loading a sample to tokenized prompt is: |
38 | | - load sample -> apply transform -> foreach turn{format into template -> tokenize} |
39 | | -
|
40 | | - Use ``convert_to_messages`` to prepare your dataset into the Llama2 chat format |
41 | | - and roles:: |
42 | | -
|
43 | | - [ |
44 | | - Message( |
45 | | - role=<system|user|assistant>, |
46 | | - content=<message>, |
47 | | - ), |
48 | | - ... |
49 | | - ] |
50 | | -
|
51 | | - This class supports multi-turn conversations. If a tokenizer sample with multiple |
52 | | - turns does not fit within ``max_seq_len`` then it is truncated. |
53 | | -
|
54 | | - Args: |
55 | | - tokenizer (ModelTokenizer): Tokenizer used by the model that implements the ``tokenize_messages`` method. |
56 | | - source (str): path to dataset repository on Hugging Face. For local datasets, |
57 | | - define source as the data file type (e.g. "json", "csv", "text") and pass |
58 | | - in the filepath in ``data_files``. See Hugging Face's ``load_dataset`` |
59 | | - (https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset.path) |
60 | | - for more details. |
61 | | - convert_to_messages (Callable[[Mapping[str, Any]], List[Message]]): function that keys into the desired field in the sample |
62 | | - and converts to a list of :class:`~torchtune.data.Message` that follows the Llama format with the expected keys |
63 | | - chat_format (Optional[ChatFormat]): template used to format the chat. This is used to add structured text around the actual |
64 | | - messages, such as the [INST] tags in Llama2 and in Mistral. The extra text will still get tokenized as normal text, not |
65 | | - as special tokens. In models like Llama3 where the tokenizer adds tags as special tokens, ``chat_format`` is not needed, |
66 | | - unless you want to structure messages in a particular way for inference. |
67 | | - max_seq_len (int): Maximum number of tokens in the returned input and label token id lists. |
68 | | - train_on_input (bool): Whether the model is trained on the prompt or not. Default is False. |
69 | | - **load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``, |
70 | | - such as ``data_files`` or ``split``. |
71 | | - """ |
72 | | - |
73 | | - def __init__( |
74 | | - self, |
75 | | - *, |
76 | | - tokenizer: ModelTokenizer, |
77 | | - source: str, |
78 | | - convert_to_messages: Callable[[Mapping[str, Any]], List[Message]], |
79 | | - chat_format: Optional[ChatFormat] = None, |
80 | | - max_seq_len: int, |
81 | | - train_on_input: bool = False, |
82 | | - **load_dataset_kwargs: Dict[str, Any], |
83 | | - ) -> None: |
84 | | - |
85 | | - self._tokenizer = tokenizer |
86 | | - self._data = load_dataset(source, **load_dataset_kwargs) |
87 | | - self._convert_to_messages = convert_to_messages |
88 | | - self.chat_format = chat_format |
89 | | - self.max_seq_len = max_seq_len |
90 | | - self.train_on_input = train_on_input |
91 | | - |
92 | | - def __len__(self): |
93 | | - return len(self._data) |
94 | | - |
95 | | - def __getitem__(self, index: int) -> Dict[str, List[int]]: |
96 | | - sample = self._data[index] |
97 | | - return self._prepare_sample(sample) |
98 | | - |
99 | | - def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, List[int]]: |
100 | | - messages = self._convert_to_messages(sample, self.train_on_input) |
101 | | - if self.chat_format is not None: |
102 | | - messages = self.chat_format.format(messages) |
103 | | - validate_messages(messages) |
104 | | - tokens, mask = self._tokenizer.tokenize_messages( |
105 | | - messages, |
106 | | - ) |
107 | | - # Wherever mask == True, set to CROSS_ENTROPY_IGNORE_IDX. Otherwise keep as tokens |
108 | | - labels = list(np.where(mask, CROSS_ENTROPY_IGNORE_IDX, tokens)) |
109 | | - assert len(tokens) == len(labels) |
110 | | - |
111 | | - return {"tokens": tokens, "labels": labels} |
112 | 13 |
|
113 | 14 |
|
114 | 15 | def chat_dataset( |
|
0 commit comments