
Commit 986f90c

Merge branch 'main' into hsdp

2 parents 0be173f + 952078e


42 files changed (+13336 / -254 lines)

README.md

Lines changed: 87 additions & 78 deletions
Large diffs are not rendered by default.

docs/source/api_ref_modules.rst

Lines changed: 1 addition & 0 deletions
@@ -50,6 +50,7 @@ model specific tokenizers.

    transforms.tokenizers.SentencePieceBaseTokenizer
    transforms.tokenizers.TikTokenBaseTokenizer
+   transforms.tokenizers.HuggingFaceBaseTokenizer
    transforms.tokenizers.ModelTokenizer
    transforms.tokenizers.BaseTokenizer

docs/source/api_ref_training.rst

Lines changed: 1 addition & 0 deletions
@@ -102,6 +102,7 @@ Various logging utilities.

    metric_logging.TensorBoardLogger
    metric_logging.StdoutLogger
    metric_logging.DiskLogger
+   metric_logging.MLFlowLogger

.. _perf_profiling_label:

docs/source/basics/model_transforms.rst

Lines changed: 2 additions & 2 deletions
@@ -32,8 +32,8 @@ These are intended to be drop-in replacements for tokenizers in multimodal datas

    Message(
        role="user",
        content=[
-           {"type": "image", "content": Image.new(mode="RGB", size=(224, 224))},
-           {"type": "image", "content": Image.new(mode="RGB", size=(224, 224))},
+           {"type": "image", "content": Image.new(mode="RGB", size=(560, 560))},
+           {"type": "image", "content": Image.new(mode="RGB", size=(560, 560))},
            {"type": "text", "content": "What is common in these two images?"},
        ],
    ),

docs/source/basics/multimodal_datasets.rst

Lines changed: 1 addition & 1 deletion
@@ -45,7 +45,7 @@ in the text, ``"<image>"`` for where to place the image tokens. This will get re

    from torchtune.models.llama3_2_vision import llama3_2_vision_transform
    from torchtune.datasets.multimodal import multimodal_chat_dataset

-   model_transform = Llama3VisionTransform(
+   model_transform = llama3_2_vision_transform(
        path="/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model",
        prompt_template="torchtune.data.QuestionAnswerTemplate",
        max_seq_len=8192,

docs/source/basics/tokenizers.rst

Lines changed: 24 additions & 0 deletions
@@ -222,6 +222,30 @@ to do the actual encoding and decoding.

    print(sp_tokenizer.encode(text))
    # [1, 6312, 28709, 1526, 2]

+ .. _hf_tokenizers:
+
+ Using Hugging Face tokenizers
+ -----------------------------
+
+ Sometimes tokenizers hosted on Hugging Face do not contain files compatible with one of torchtune's
+ existing tokenizer classes. In this case, we provide :class:`~torchtune.modules.transforms.tokenizers.HuggingFaceBaseTokenizer`
+ to parse the Hugging Face ``tokenizer.json`` file and define the correct ``encode`` and ``decode`` methods to
+ match torchtune's other :class:`~torchtune.modules.transforms.tokenizers.BaseTokenizer` classes. You should also pass the path to
+ either ``tokenizer_config.json`` or ``generation_config.json``, which will allow torchtune to infer BOS and EOS tokens.
+ Continuing with the Mistral example:
+
+ .. code-block:: python
+
+     hf_tokenizer = HuggingFaceBaseTokenizer(
+         tokenizer_json_path="/tmp/Mistral-7B-v0.1/tokenizer.json",
+         tokenizer_config_json_path="/tmp/Mistral-7B-v0.1/tokenizer_config.json",
+     )
+
+     text = "hello world"
+
+     print(hf_tokenizer.encode(text))
+     # [1, 6312, 28709, 1526, 2]
+
.. _model_tokenizers:

Model tokenizers
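
Not part of the diff, but for context: since the new section states that the wrapper implements both ``encode`` and ``decode`` to match torchtune's other ``BaseTokenizer`` classes, a round trip along the following lines should work. The import path is inferred from the class reference above, and the encoded output comment is copied from the example in the diff; treat this as a sketch rather than the documented API usage.

    from torchtune.modules.transforms.tokenizers import HuggingFaceBaseTokenizer

    hf_tokenizer = HuggingFaceBaseTokenizer(
        tokenizer_json_path="/tmp/Mistral-7B-v0.1/tokenizer.json",
        tokenizer_config_json_path="/tmp/Mistral-7B-v0.1/tokenizer_config.json",
    )

    token_ids = hf_tokenizer.encode("hello world")
    print(token_ids)
    # [1, 6312, 28709, 1526, 2]

    # Decoding should recover the original text (modulo special tokens).
    print(hf_tokenizer.decode(token_ids))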

docs/source/deep_dives/checkpointer.rst

Lines changed: 7 additions & 7 deletions
@@ -293,8 +293,8 @@ For more details about each file, please check the End-to-End tutorial mentioned

    │ ├── adapter_model.pt
    │ ├── adapter_model.safetensors
    │ ├── config.json
-   │ ├── ft-model-00001-of-00002.safetensors
-   │ ├── ft-model-00002-of-00002.safetensors
+   │ ├── model-00001-of-00002.safetensors
+   │ ├── model-00002-of-00002.safetensors
    │ ├── generation_config.json
    │ ├── LICENSE.txt
    │ ├── model.safetensors.index.json
@@ -313,8 +313,8 @@ For more details about each file, please check the End-to-End tutorial mentioned

    │ ├── adapter_model.pt
    │ ├── adapter_model.safetensors
    │ ├── config.json
-   │ ├── ft-model-00001-of-00002.safetensors
-   │ ├── ft-model-00002-of-00002.safetensors
+   │ ├── model-00001-of-00002.safetensors
+   │ ├── model-00002-of-00002.safetensors
    │ ├── generation_config.json
    │ ├── LICENSE.txt
    │ ├── model.safetensors.index.json
@@ -394,7 +394,7 @@ you'll need to **update** the following fields in your configs:

**resume_from_checkpoint**: Set it to True;

- **checkpoint_files**: change the path to ``epoch_{YOUR_EPOCH}/ft-model={}-of-{}.safetensors``;
+ **checkpoint_files**: change the path to ``epoch_{YOUR_EPOCH}/model-{}-of-{}.safetensors``;

Notice that we do **not** change our checkpoint_dir or output_dir. Since we are resuming from checkpoint, we know
to look for it in the output_dir.
@@ -405,8 +405,8 @@ to look for it in the output_dir.

    # checkpoint files. Note that you will need to update this
    # section of the config with the intermediate checkpoint files
    checkpoint_files: [
-       epoch_{YOUR_EPOCH}/ft-model-00001-of-00002.safetensors,
-       epoch_{YOUR_EPOCH}/ft-model-00001-of-00002.safetensors,
+       epoch_{YOUR_EPOCH}/model-00001-of-00002.safetensors,
+       epoch_{YOUR_EPOCH}/model-00001-of-00002.safetensors,
    ]

    # set to True if restarting training

docs/source/tutorials/e2e_flow.rst

Lines changed: 7 additions & 7 deletions
@@ -142,8 +142,8 @@ There are 3 types of folders:

    │ ├── adapter_model.pt
    │ ├── adapter_model.safetensors
    │ ├── config.json
-   │ ├── ft-model-00001-of-00002.safetensors
-   │ ├── ft-model-00002-of-00002.safetensors
+   │ ├── model-00001-of-00002.safetensors
+   │ ├── model-00002-of-00002.safetensors
    │ ├── generation_config.json
    │ ├── LICENSE.txt
    │ ├── model.safetensors.index.json
@@ -168,7 +168,7 @@ There are 3 types of folders:

Let's understand the files:

- ``adapter_model.safetensors`` and ``adapter_model.pt`` are your LoRA trained adapter weights. We save a duplicated .pt version of it to facilitate resuming from checkpoint.
- - ``ft-model-{}-of-{}.safetensors`` are your trained full model weights (not adapters). When LoRA finetuning, these are only present if we set ``save_adapter_weights_only=False``. In that case, we merge the merged base model with trained adapters, making inference easier.
+ - ``model-{}-of-{}.safetensors`` are your trained full model weights (not adapters). When LoRA finetuning, these are only present if we set ``save_adapter_weights_only=False``. In that case, we merge the base model with the trained adapters, making inference easier.
- ``adapter_config.json`` is used by Huggingface PEFT when loading an adapter (more on that later);
- ``model.safetensors.index.json`` is used by Hugging Face ``from_pretrained()`` when loading the model weights (more on that later)
- All other files were originally in the checkpoint_dir. They are automatically copied during training. Files over 100MiB and ending on .safetensors, .pth, .pt, .bin are ignored, making it lightweight.
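
For orientation only (not part of this commit): a minimal sketch of how the artifacts listed above are typically consumed downstream, assuming a hypothetical output folder such as ``/tmp/torchtune/llama3_2_lora/epoch_0`` and ``meta-llama/Llama-3.2-3B-Instruct`` as an example base model; substitute your own paths and model.

    from peft import PeftModel
    from transformers import AutoModelForCausalLM

    ckpt_dir = "/tmp/torchtune/llama3_2_lora/epoch_0"  # hypothetical output_dir/epoch_N folder

    # model.safetensors.index.json + model-*.safetensors are what from_pretrained() reads
    merged_model = AutoModelForCausalLM.from_pretrained(ckpt_dir)

    # adapter_config.json + adapter_model.safetensors are what Hugging Face PEFT reads,
    # loading the LoRA adapters on top of the original base model instead
    base_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
    adapted_model = PeftModel.from_pretrained(base_model, ckpt_dir)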
@@ -223,8 +223,8 @@ Notice that we are using the merged weights, and not the LoRA adapters.

    _component_: torchtune.training.FullModelHFCheckpointer
    checkpoint_dir: ${output_dir}
    checkpoint_files: [
-       ft-model-00001-of-00002.safetensors,
-       ft-model-00002-of-00002.safetensors,
+       model-00001-of-00002.safetensors,
+       model-00002-of-00002.safetensors,
    ]
    output_dir: ${output_dir}
    model_type: LLAMA3_2
@@ -299,8 +299,8 @@ Let's modify ``custom_generation_config.yaml`` to include the following changes.

    _component_: torchtune.training.FullModelHFCheckpointer
    checkpoint_dir: ${checkpoint_dir}
    checkpoint_files: [
-       ft-model-00001-of-00002.safetensors,
-       ft-model-00002-of-00002.safetensors,
+       model-00001-of-00002.safetensors,
+       model-00002-of-00002.safetensors,
    ]
    output_dir: ${output_dir}
    model_type: LLAMA3_2

docs/source/tutorials/qat_finetune.rst

Lines changed: 85 additions & 46 deletions
@@ -64,32 +64,47 @@ Between these two steps, training can proceed exactly as before.

Applying QAT to Llama3 models
-----------------------------

- We can easily apply the above QAT transformations to Llama3 in torchtune for fine-tuning:
+ We can easily apply the above QAT transformations to Llama3 for fine-tuning,
+ leveraging the APIs in torchao as follows:

.. code-block:: python

-     from torchtune.training.quantization import Int8DynActInt4WeightQATQuantizer
+     import copy
+     import torch
+     from torchao.quantization import quantize_
+     from torchao.quantization.qat import (
+         FakeQuantizeConfig,
+         IntXQuantizationAwareTrainingConfig,
+     )
    from torchtune.models.llama3 import llama3_8b

    model = llama3_8b()
+     original_model = copy.deepcopy(model)
+
+     # Config for int8 dynamic asymmetric per token activations +
+     # int4 symmetric per group weights, only for linear layers
+     activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
+     weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
+     qat_config = IntXQuantizationAwareTrainingConfig(activation_config, weight_config)

-     # Quantizer for int8 dynamic per token activations +
-     # int4 grouped per channel weights, only for linear layers
-     quantizer = Int8DynActInt4WeightQATQuantizer()
+     # Prepare the model for quantization-aware fine-tuning.
+     #
+     # This step inserts "fake quantize" ops that simulate
+     # quantization numerics during fine-tuning without
+     # actually casting the activations/weights to lower-bit
+     # dtypes like in "real" quantization.
+     quantize_(model, qat_config)

-     # Insert "fake quantize" operations into linear layers.
-     # These operations simulate quantization numerics during
-     # fine-tuning without performing any dtype casting
-     prepared_model = quantizer.prepare(model)
+     prepared_model = model

- If we print the model we’ll see that all linear layers have been swapped with
- :code:`Int8DynActInt4WeightQATLinear`, which simulates the numerics of int8
- dynamic per token activations + int4 grouped per channel weights. Now the model
- is ready for fine-tuning.
+ The model is now ready for QAT fine-tuning! If we print the model we’ll see that
+ all linear layers have been swapped with :code:`FakeQuantizedLinear`, which simulates
+ the numerics of int8 dynamic asymmetric per token activations + int4 symmetric
+ per group weights:

.. code-block:: bash

-     >>> print(model.layers[0].attn)
+     >>> original_model.layers[0].attn
    MultiHeadAttention(
      (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
      (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
@@ -98,37 +113,71 @@ is ready for fine-tuning.

      (pos_embeddings): RotaryPositionalEmbeddings()
    )

-     >>> print(prepared_model.layers[0].attn)
+ .. code-block:: bash
+
+     >>> prepared_model.layers[0].attn
    MultiHeadAttention(
-       (q_proj): Int8DynActInt4WeightQATLinear(in_features=4096, out_features=4096, bias=False)
-       (k_proj): Int8DynActInt4WeightQATLinear(in_features=4096, out_features=1024, bias=False)
-       (v_proj): Int8DynActInt4WeightQATLinear(in_features=4096, out_features=1024, bias=False)
-       (output_proj): Int8DynActInt4WeightQATLinear(in_features=4096, out_features=4096, bias=False)
+       (q_proj): FakeQuantizedLinear(
+         in_features=4096, out_features=4096, bias=False
+         (activation_fake_quantizer): FakeQuantizer(FakeQuantizeConfig(dtype=torch.int8, granularity=PerToken(), mapping_type=<MappingType.ASYMMETRIC: 3>, scale_precision=torch.float32, zero_point_precision=torch.int32, zero_point_domain=<ZeroPointDomain.INT: 1>, is_dynamic=True, range_learning=False))
+         (weight_fake_quantizer): FakeQuantizer(FakeQuantizeConfig(dtype=torch.int4, granularity=PerGroup(group_size=32), mapping_type=<MappingType.SYMMETRIC: 1>, scale_precision=torch.float32, zero_point_precision=torch.int32, zero_point_domain=<ZeroPointDomain.INT: 1>, is_dynamic=True, range_learning=False))
+       )
+       (k_proj): FakeQuantizedLinear(
+         in_features=4096, out_features=1024, bias=False
+         (activation_fake_quantizer): FakeQuantizer(FakeQuantizeConfig(dtype=torch.int8, granularity=PerToken(), mapping_type=<MappingType.ASYMMETRIC: 3>, scale_precision=torch.float32, zero_point_precision=torch.int32, zero_point_domain=<ZeroPointDomain.INT: 1>, is_dynamic=True, range_learning=False))
+         (weight_fake_quantizer): FakeQuantizer(FakeQuantizeConfig(dtype=torch.int4, granularity=PerGroup(group_size=32), mapping_type=<MappingType.SYMMETRIC: 1>, scale_precision=torch.float32, zero_point_precision=torch.int32, zero_point_domain=<ZeroPointDomain.INT: 1>, is_dynamic=True, range_learning=False))
+       )
+       (v_proj): FakeQuantizedLinear(
+         in_features=4096, out_features=1024, bias=False
+         (activation_fake_quantizer): FakeQuantizer(FakeQuantizeConfig(dtype=torch.int8, granularity=PerToken(), mapping_type=<MappingType.ASYMMETRIC: 3>, scale_precision=torch.float32, zero_point_precision=torch.int32, zero_point_domain=<ZeroPointDomain.INT: 1>, is_dynamic=True, range_learning=False))
+         (weight_fake_quantizer): FakeQuantizer(FakeQuantizeConfig(dtype=torch.int4, granularity=PerGroup(group_size=32), mapping_type=<MappingType.SYMMETRIC: 1>, scale_precision=torch.float32, zero_point_precision=torch.int32, zero_point_domain=<ZeroPointDomain.INT: 1>, is_dynamic=True, range_learning=False))
+       )
+       (output_proj): FakeQuantizedLinear(
+         in_features=4096, out_features=4096, bias=False
+         (activation_fake_quantizer): FakeQuantizer(FakeQuantizeConfig(dtype=torch.int8, granularity=PerToken(), mapping_type=<MappingType.ASYMMETRIC: 3>, scale_precision=torch.float32, zero_point_precision=torch.int32, zero_point_domain=<ZeroPointDomain.INT: 1>, is_dynamic=True, range_learning=False))
+         (weight_fake_quantizer): FakeQuantizer(FakeQuantizeConfig(dtype=torch.int4, granularity=PerGroup(group_size=32), mapping_type=<MappingType.SYMMETRIC: 1>, scale_precision=torch.float32, zero_point_precision=torch.int32, zero_point_domain=<ZeroPointDomain.INT: 1>, is_dynamic=True, range_learning=False))
+       )
      (pos_embeddings): RotaryPositionalEmbeddings()
    )

- After fine-tuning, we can convert the model to get an actual quantized model.
- If we print the converted model, we’ll see that the QAT linears have been
- swapped with `Int8DynActInt4WeightLinear <https://github.com/pytorch/ao/blob/428084356ace4ea94c22a3a9b3d74cff8ee41db3/torchao/quantization/prototype/qat.py#L38>`_, which are the quantized versions
- of the linear layers. This quantized model can then be saved to checkpoint and
- used for inference or generation.
+ After fine-tuning, we can convert the model to get an actual quantized model:

.. code-block:: python

+     from torchao.quantization.qat import (
+         FromIntXQuantizationAwareTrainingConfig,
+     )
+     from torchao.quantization import (
+         Int8DynamicActivationInt4WeightConfig,
+     )
+
    # Fine-tune as before
    train_loop(prepared_model)

-     # Convert fake quantize to actual quantize operations
-     converted_model = quantizer.convert(prepared_model)
+     # Convert the fake quantized model into an actual quantized model
+     #
+     # First, we swap `FakeQuantizedLinear` back to `torch.nn.Linear`
+     # while keeping the QAT fine-tuned weights. Then, we perform standard
+     # post-training quantization (PTQ), which inserts quantized activation
+     # and weight tensor subclasses
+     quantize_(prepared_model, FromIntXQuantizationAwareTrainingConfig())
+     quantize_(prepared_model, Int8DynamicActivationInt4WeightConfig(group_size=32))
+
+     converted_model = prepared_model
+
+ The model is now fully quantized to int8 and int4 and ready for inference
+ or generation. If we print the model now, we will see the linear layers
+ are now swapped back to :code:`torch.nn.Linear`, but with quantized tensor
+ activations and weights:

.. code-block:: bash

-     >>> print(converted_model.layers[0].attn)
+     >>> converted_model.layers[0].attn
    MultiHeadAttention(
-       (q_proj): Int8DynActInt4WeightLinear()
-       (k_proj): Int8DynActInt4WeightLinear()
-       (v_proj): Int8DynActInt4WeightLinear()
-       (output_proj): Int8DynActInt4WeightLinear()
+       (q_proj): Linear(in_features=4096, out_features=4096, weight=LinearActivationQuantizedTensor(activation=<function _int8_asymm_per_token_quant at 0x7f801ce08790>, weight=AffineQuantizedTensor(shape=torch.Size([4096, 4096]), block_size=(1, 32), device=cpu, _layout=PlainLayout(), tensor_impl_dtype=torch.int8, quant_min=-8, quant_max=7)))
+       (k_proj): Linear(in_features=4096, out_features=1024, weight=LinearActivationQuantizedTensor(activation=<function _int8_asymm_per_token_quant at 0x7f801ce08790>, weight=AffineQuantizedTensor(shape=torch.Size([1024, 4096]), block_size=(1, 32), device=cpu, _layout=PlainLayout(), tensor_impl_dtype=torch.int8, quant_min=-8, quant_max=7)))
+       (v_proj): Linear(in_features=4096, out_features=1024, weight=LinearActivationQuantizedTensor(activation=<function _int8_asymm_per_token_quant at 0x7f801ce08790>, weight=AffineQuantizedTensor(shape=torch.Size([1024, 4096]), block_size=(1, 32), device=cpu, _layout=PlainLayout(), tensor_impl_dtype=torch.int8, quant_min=-8, quant_max=7)))
+       (output_proj): Linear(in_features=4096, out_features=4096, weight=LinearActivationQuantizedTensor(activation=<function _int8_asymm_per_token_quant at 0x7f801ce08790>, weight=AffineQuantizedTensor(shape=torch.Size([4096, 4096]), block_size=(1, 32), device=cpu, _layout=PlainLayout(), tensor_impl_dtype=torch.int8, quant_min=-8, quant_max=7)))
      (pos_embeddings): RotaryPositionalEmbeddings()
    )
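
As an aside (not part of this commit), a quick way to sanity-check the converted model is to push a batch of token ids through it and persist the result. The sketch below assumes the llama3_8b vocabulary size of 128256, that the plain-layout int8/int4 kernels can run a forward pass on the current device, and that the quantized tensor subclasses round-trip through a regular ``state_dict``; the output path is hypothetical.

    import torch

    # A sketch, not part of the diff above: sanity-check and save `converted_model`.
    with torch.no_grad():
        tokens = torch.randint(0, 128_256, (1, 8))   # 128256 = llama3_8b vocab size
        logits = converted_model(tokens)
        print(logits.shape)                          # expected: torch.Size([1, 8, 128256])

    # Assumption: the quantized weights serialize through the regular state dict,
    # so a plain torch.save checkpoint is enough for later reuse.
    torch.save(converted_model.state_dict(), "/tmp/llama3_8b_qat_8da4w.pt")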
@@ -150,23 +199,21 @@ modifications accordingly:

.. code-block:: yaml

-     # Dataset
    dataset:
      _component_: torchtune.datasets.text_completion_dataset
      source: allenai/c4
-       max_seq_len: 8192
      column: text
      name: en
      split: train
-     seed: null
-     shuffle: True

    ...

    epochs: 1
    max_steps_per_epoch: 2000
    fake_quant_after_n_steps: 1000
-     memory_efficient_fsdp_wrap: False
+
+ By default, this uses the :code:`torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer`,
+ which uses the same fake quantization configurations as the example above.

Empirically, we observed that disabling fake quantization for the first N steps
led to better results, presumably because doing so allows the weights to stabilize
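
To make the delayed-fake-quantization idea behind ``fake_quant_after_n_steps`` concrete (this sketch is not part of the commit): the gist is to leave the fake quantizers inactive for the first N optimizer steps and only switch them on afterwards. ``toggle_fake_quant`` below is a hypothetical helper, not a torchao or torchtune API, and it assumes the ``FakeQuantizer`` modules shown earlier expose an ``enabled`` flag.

    # Hypothetical helper, not a torchao/torchtune API: walk the prepared model and
    # flip every fake quantizer on or off (assumes FakeQuantizer has an `enabled` flag).
    def toggle_fake_quant(model, enabled: bool) -> None:
        for module in model.modules():
            for attr in ("activation_fake_quantizer", "weight_fake_quantizer"):
                fake_quantizer = getattr(module, attr, None)
                if fake_quantizer is not None:
                    fake_quantizer.enabled = enabled

    toggle_fake_quant(prepared_model, enabled=False)   # during the first N optimizer steps
    # ... after `fake_quant_after_n_steps` steps of ordinary fine-tuning ...
    toggle_fake_quant(prepared_model, enabled=True)    # fake quantization kicks in from here on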
@@ -213,15 +260,13 @@ copy and make the following modifications to the quantization config:

.. code-block:: yaml

-     # Model arguments
    model:
      _component_: torchtune.models.llama3.llama3_8b

    checkpointer:
      _component_: torchtune.training.FullModelMetaCheckpointer
      checkpoint_dir: <your QAT checkpoint dir>
-       checkpoint_files: [meta_model_0.pt]
-       recipe_checkpoint: null
+       checkpoint_files: [ft-model-00001-of-00001.bin]
      output_dir: <your QAT checkpoint dir>
      model_type: LLAMA3

@@ -259,25 +304,19 @@ integrated in torchtune. First, copy the evaluation config and make the followin

.. code-block:: yaml

-     # Model arguments
    model:
      _component_: torchtune.models.llama3.llama3_8b

    checkpointer:
      _component_: torchtune.training.FullModelTorchTuneCheckpointer
      checkpoint_dir: <your quantized model checkpoint dir>
-       checkpoint_files: [meta_model_0-8da4w.pt]
-       recipe_checkpoint: null
+       checkpoint_files: [ft-model-00001-of-00001-8da4w.bin]
      output_dir: <your quantized model checkpoint dir>
      model_type: LLAMA3

    ...

-     # EleutherAI specific eval args
    tasks: ["hellaswag", "wikitext"]
-     limit: null
-     max_seq_length: 8192
-     batch_size: 8

    quantizer:
      _component_: torchtune.training.quantization.Int8DynActInt4WeightQuantizer
