
Commit 81b6abe

comments
1 parent d25f657 commit 81b6abe


3 files changed: +23, -14 lines changed


docs/source/basics/packing.rst

Lines changed: 6 additions & 4 deletions
@@ -4,10 +4,12 @@
 Sample packing
 ==============
 
-You can use sample packing with any of the single dataset builders by passing in
-:code:`packed=True`. This requires some pre-processing of the dataset which may
+Sample packing involves concatenating multiple samples from your dataset into a single sequence, up to a maximum
+sequence length. This requires some pre-processing of the dataset which may
 slow down time-to-first-batch, but can introduce significant training speedups
-depending on the dataset.
+depending on the dataset. In torchtune, sample packing is done by iterating through your dataset and performing
+greedy packing upon dataset initialization. You can use sample packing with any of the single dataset builders by passing in
+:code:`packed=True`.
 
 To set the max sequence length to pack to, make sure to define ``max_seq_len`` on your tokenizer.
 
@@ -48,5 +50,5 @@ To set the max sequence length to pack to, make sure to define ``max_seq_len`` o
 torchtune will automatically handle document masking and relative position IDs when sample packing is enabled
 to prevent different irrelevant samples from cross-attending. This is done via PyTorch's `Flex Attention <https://pytorch.org/blog/flexattention/#document-maskingjagged-sequences>`_,
 which enables the use of flash attention with non-causal masks. If your hardware does not support Flex Attention
-(for CUDA devices, it must be Turing or above), standard SDPA with ememory-efficient attention will be used as a fallback,
+(for CUDA devices, it must be Turing or above), standard SDPA with memory-efficient attention will be used as a fallback,
 while retaining the document masking and relative position IDs.
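For reference, a minimal sketch (not part of this commit) of enabling the packing described above from Python, assuming the Alpaca dataset builder and a Llama3 tokenizer; the tokenizer path is a placeholder:

# Sketch only: greedy packing runs once, at dataset initialization.
from torchtune.datasets import alpaca_dataset
from torchtune.models.llama3 import llama3_tokenizer

# max_seq_len on the tokenizer sets the length each pack is filled up to.
tokenizer = llama3_tokenizer(
    path="/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model",
    max_seq_len=8192,
)

# packed=True iterates the dataset and packs samples greedily up front.
dataset = alpaca_dataset(tokenizer=tokenizer, packed=True)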

recipes/configs/generation.yaml

Lines changed: 3 additions & 1 deletion
@@ -30,7 +30,9 @@ tokenizer:
   prompt_template: null
 
 # Generation arguments; defaults taken from gpt-fast
-prompt: "Tell me a joke?"
+prompt:
+  system: null
+  user: "Tell me a joke."
 max_new_tokens: 300
 temperature: 0.6 # 0.8 and 0.6 are popular values to try
 top_k: 300
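A hypothetical illustration (not part of the commit) of what the restructured prompt field parses to; torchtune configs are OmegaConf-based, so the system key can be left as null and read back as None:

from omegaconf import OmegaConf

# Same shape as the updated generation.yaml snippet above.
cfg = OmegaConf.create(
    """
    prompt:
      system: null
      user: "Tell me a joke."
    """
)
assert cfg.prompt.system is None
assert cfg.prompt.user == "Tell me a joke."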

recipes/generate.py

Lines changed: 14 additions & 9 deletions
@@ -13,7 +13,7 @@
 from torch import nn
 
 from torchtune import config, generation, training, utils
-from torchtune.data import Message
+from torchtune.data import Message, Role
 from torchtune.training import FullModelTorchTuneCheckpointer
 
 logger = utils.get_logger("DEBUG")
@@ -99,17 +99,22 @@ def _setup_model(
 
     def convert_prompt_to_tokens(
         self,
-        prompt: str,
+        prompt: Dict[Role, str],
     ) -> List[int]:
         """
-        Convert the prompt string to a user message and tokenize using the prompt template
-        defined on the tokenizer.
+        Convert the prompt string to a user message with optional system messages
+        and tokenize using the prompt template defined on the tokenizer.
         """
-        messages = [
-            Message(role="user", content=prompt),
-            # Empty assistant message to kick-start generation
-            Message(role="assistant", content=""),
-        ]
+        messages = []
+        if "system" in prompt and prompt["system"] is not None:
+            messages.append(Message(role="system", content=prompt["system"]))
+        messages.extend(
+            [
+                Message(role="user", content=prompt["user"]),
+                # Empty assistant message to kick-start generation
+                Message(role="assistant", content=""),
+            ]
+        )
         return self._tokenizer({"messages": messages}, inference=True)["tokens"]
 
     @torch.inference_mode()
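A standalone sketch (my illustration, not part of the recipe) of how the updated convert_prompt_to_tokens builds the message list from the new prompt mapping before handing it to the tokenizer:

from torchtune.data import Message

# Shape matches the updated generation.yaml: optional system, required user.
prompt = {"system": "You are a comedian.", "user": "Tell me a joke."}

messages = []
if prompt.get("system") is not None:
    messages.append(Message(role="system", content=prompt["system"]))
messages.extend(
    [
        Message(role="user", content=prompt["user"]),
        # Empty assistant message to kick-start generation
        Message(role="assistant", content=""),
    ]
)
assert [m.role for m in messages] == ["system", "user", "assistant"]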
