
Commit 8b55ce7

Update QAT: add grad clipping, torch.compile, collate fn
**Summary:** Update the qat_distributed recipe to match the full_finetune_distributed recipe. This commit adds features to QAT such as gradient clipping, torch.compile, and a user-configurable collate function for data pre-processing.

**Test Plan:** TBD
1 parent 7d29c21 commit 8b55ce7

2 files changed: +56, -45 lines
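As a quick illustration of the gradient-clipping behavior this commit wires in, here is a minimal standalone PyTorch sketch (not the recipe code itself; the toy model, optimizer, and learning rate below are placeholders):

```python
import torch

# Toy model and optimizer; the recipe applies the same pattern to its FSDP-wrapped model.
model = torch.nn.Linear(16, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

# Mirrors the config flag: None disables clipping, 'inf' only measures the norm.
clip_grad_norm = "inf"

loss = model(torch.randn(8, 16)).sum()
loss.backward()

if clip_grad_norm is not None:
    # clip_grad_norm_ returns the total gradient norm, so max_norm=inf
    # leaves gradients untouched while still giving a value to log.
    grad_norm = torch.nn.utils.clip_grad_norm_(
        model.parameters(), max_norm=float(clip_grad_norm)
    )
    print(f"grad_norm: {grad_norm:.4f}")

optimizer.step()
optimizer.zero_grad(set_to_none=True)
```

With `max_norm=inf` the clip coefficient is clamped to 1, so gradients are left as-is while the returned norm can still be logged, which is why the recipe documents `clip_grad_norm='inf'` as a logging-only mode.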

recipes/configs/llama3/8B_qat_full.yaml

Lines changed: 2 additions & 2 deletions
@@ -63,7 +63,7 @@ device: cuda
 
 # Memory management
 enable_activation_checkpointing: True
-memory_efficient_fsdp_wrap: True
+custom_sharded_layers: ['tok_embeddings', 'output']
 
 # Reduced precision
 dtype: bf16
@@ -72,6 +72,6 @@ dtype: bf16
 metric_logger:
   _component_: torchtune.training.metric_logging.DiskLogger
   log_dir: ${output_dir}
-output_dir: /tmp/alpaca-llama3-finetune
+output_dir: /tmp/full-llama3-finetune
 log_every_n_steps: 1
 log_peak_memory_stats: False

recipes/qat_distributed.py

Lines changed: 54 additions & 43 deletions
@@ -4,7 +4,6 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-import os
 import sys
 import time
 
@@ -21,7 +20,8 @@
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader, DistributedSampler
 from torchtune import config, modules, training, utils
-from torchtune.data import padded_collate_packed, padded_collate_sft
+from torchtune.config._utils import _get_component_from_path
+from torchtune.data import padded_collate_packed
 from torchtune.datasets import ConcatDataset
 from torchtune.recipe_interfaces import FTRecipeInterface
 from torchtune.training import DummyProfiler, PROFILER_KEY
@@ -50,7 +50,7 @@ class QATRecipeDistributed(FTRecipeInterface):
         to improved quantized accuracy. This can be specified through ``fake_quant_after_n_steps``.
 
     - FSDP. Supported using PyTorch's FSDP APIs. CPU offload of parameters, gradients, and optimizer states
-        is supported via the ``fsdp_cpu_offload``. Resharding of parameters after the forward pass is
+        is supported via ``fsdp_cpu_offload``. Resharding of parameters after the forward pass is
         done by default (corresponding to FULL_SHARD sharding strategy), but can be disabled by setting the config
         ``fsdp_reshard_after_forward`` to False (this corresponds to SHARD_GRAD_OP sharding strategy).
         DDP is currently not supported. Training on CPU is not supported.
@@ -93,6 +93,10 @@ class QATRecipeDistributed(FTRecipeInterface):
 
     - Logging. Terminal, Disk, WandB and TensorBoard are all supported.
 
+    - Gradient Clipping. Gradient clipping is supported using the ``clip_grad_norm`` flag. By default,
+        ``clip_grad_norm`` is set to ``None``. If you only want to log the grad norm, you can set
+        ``clip_grad_norm='inf'``.
+
     For a full list of example configs for this recipe, run ``tune ls`` on the command line. Each config
     has example commands for how to kick-off training.
 
@@ -102,6 +106,7 @@ class QATRecipeDistributed(FTRecipeInterface):
     Raises:
         ValueError: If ``dtype`` is set to fp16.
        RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16.
+        RuntimeError: If ``left_pad_sequence`` is set as the data collator.
    """
 
     def __init__(self, cfg: DictConfig) -> None:
@@ -135,9 +140,6 @@ def __init__(self, cfg: DictConfig) -> None:
         # Training cfg
         self._resume_from_checkpoint = cfg.resume_from_checkpoint
         self._gradient_accumulation_steps = cfg.gradient_accumulation_steps
-        self._fsdp_sharding_strategy = torch.distributed.fsdp.ShardingStrategy[
-            cfg.get("fsdp_sharding_strategy", "FULL_SHARD")
-        ]
         self._fake_quant_after_n_steps = cfg.get("fake_quant_after_n_steps", None)
         self._quantizer_mode = None
 
@@ -148,6 +150,7 @@ def __init__(self, cfg: DictConfig) -> None:
         self.total_epochs = cfg.epochs
         self.max_steps_per_epoch = cfg.max_steps_per_epoch
         self.global_step = 0
+        self._clip_grad_norm = cfg.get("clip_grad_norm", None)
 
     def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
         """
@@ -217,7 +220,7 @@ def setup(self, cfg: DictConfig) -> None:
 
         checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)
 
-        self._model_compile = cfg.get("compile", False)
+        self._compile = cfg.get("compile", False)
         self._model = self._setup_model(
             cfg_model=cfg.model,
             enable_activation_checkpointing=cfg.enable_activation_checkpointing,
@@ -240,30 +243,25 @@ def setup(self, cfg: DictConfig) -> None:
 
         # initialize loss
         self._loss_fn = config.instantiate(cfg.loss)
-        backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor")
+
+        if self._compile:
+            training.compile_loss(self._loss_fn, verbose=self._is_rank_zero)
+
         if self._loss_fn.__class__.__name__ == "CEWithChunkedOutputLoss":
             # set num_output_chunks for model
             self._model.set_num_output_chunks(self._loss_fn.num_output_chunks)
-            if self._model_compile:
-                log.info("Compiling loss with torch.compile...")
-                # For CEWithChunkedOutputLoss, if we compile the entire class
-                # we lose the benefits from the chunked loss.
-                # Therefore, we only compile the cross entropy function + upcasting
-                self._loss_fn.compute_cross_entropy = torch.compile(
-                    self._loss_fn.compute_cross_entropy, backend=backend
-                )
-        else:
-            if self._model_compile:
-                log.info("Compiling loss with torch.compile...")
-                self._loss_fn = torch.compile(self._loss_fn, backend=backend)
-        log.info("Loss is initialized.")
+
+        if self._is_rank_zero:
+            log.info("Loss is initialized.")
 
         # sampler and dataloader depend on the tokenizer and loss_fn and should be
         # setup after both of these are initialized
+        collate_name = cfg.get("collate_fn", "torchtune.data.padded_collate_sft")
         self._sampler, self._dataloader = self._setup_data(
             cfg_dataset=cfg.dataset,
             shuffle=cfg.shuffle,
             batch_size=cfg.batch_size,
+            collate_fn=collate_name,
         )
 
         # Finally update the recipe state which can only be correctly set after all of the
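The inline compilation logic removed above is a useful reference for what the new `training.compile_loss` call takes over. The sketch below simply mirrors that removed code; it is illustrative only and is not torchtune's actual helper:

```python
import os

import torch


def compile_loss_sketch(loss_fn: torch.nn.Module) -> torch.nn.Module:
    """Illustrative only: mirrors the inline code removed in this hunk."""
    backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor")
    if loss_fn.__class__.__name__ == "CEWithChunkedOutputLoss":
        # Compiling the whole chunked loss would forfeit the memory benefit of
        # chunking, so only the inner cross-entropy (+ upcasting) is compiled.
        loss_fn.compute_cross_entropy = torch.compile(
            loss_fn.compute_cross_entropy, backend=backend
        )
        return loss_fn
    return torch.compile(loss_fn, backend=backend)
```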
@@ -388,6 +386,9 @@ def _setup_model(
         with training.set_default_dtype(self._dtype), torch.device("meta"):
             model = config.instantiate(cfg_model)
 
+        if self._compile:
+            training.compile_model(model, verbose=self._is_rank_zero)
+
         # We currently have two versions of activation checkpointing in this recipe
         # for testing and BC purposes. ``enable_activation_checkpointing`` controls
         # the older version of AC and this behavior is unchanged
@@ -459,7 +460,12 @@ def _is_layer_fqn(s: str) -> bool:
         # This method will convert the full model state dict into a sharded state
         # dict and load into the model
         training.load_from_full_model_state_dict(
-            model, model_state_dict, self._device, self._is_rank_zero, strict=True
+            model,
+            model_state_dict,
+            self._device,
+            self._is_rank_zero,
+            strict=True,
+            cpu_offload=fsdp_cpu_offload,
         )
 
         # Ensure no params and buffers are on meta device
@@ -497,6 +503,7 @@ def _setup_data(
         cfg_dataset: DictConfig,
         shuffle: bool,
         batch_size: int,
+        collate_fn: str,
     ) -> Tuple[DistributedSampler, DataLoader]:
         """
         All data related setup happens here. Currently this recipe only supports the
@@ -507,15 +514,20 @@
 
         if isinstance(cfg_dataset, ListConfig):
             datasets = [
-                config.instantiate(single_cfg_dataset, tokenizer=self._tokenizer)
+                config.instantiate(single_cfg_dataset, self._tokenizer)
                 for single_cfg_dataset in cfg_dataset
             ]
             ds = ConcatDataset(datasets=datasets)
             packed = False
         else:
-            ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer)
+            ds = config.instantiate(cfg_dataset, self._tokenizer)
             packed = cfg_dataset.get("packed", False)
 
+        # Instantiate collate_fn
+        if "left_pad_sequence" in collate_fn:
+            raise RuntimeError("left_pad_sequence collator is only for inference.")
+        collate_fn = _get_component_from_path(collate_fn)
+
         sampler = DistributedSampler(
             ds, num_replicas=world_size, rank=rank, shuffle=shuffle, seed=0
         )
@@ -526,14 +538,12 @@
             # dropping last avoids shape issues with compile + flex attention
             drop_last=True,
             collate_fn=partial(
-                padded_collate_sft,
+                collate_fn,
                 padding_idx=self._tokenizer.pad_id,
                 ignore_idx=self._loss_fn.ignore_index,
             )
             if not packed
-            else partial(
-                padded_collate_packed,
-            ),
+            else padded_collate_packed,
         )
 
         if self._is_rank_zero:
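Taken together, the hunks above make the collate function a config choice: a dotted path is resolved to a callable, the inference-only `left_pad_sequence` collator is rejected, and the padding arguments are bound before the DataLoader sees it. A rough sketch under those assumptions follows; `resolve_component` is a stand-in for torchtune's internal `_get_component_from_path`, and the `padding_idx`/`ignore_idx` values are placeholders for `tokenizer.pad_id` and `loss_fn.ignore_index`:

```python
import importlib
from functools import partial


def resolve_component(dotted_path: str):
    # Stand-in for torchtune.config._utils._get_component_from_path:
    # "torchtune.data.padded_collate_sft" -> the padded_collate_sft callable.
    module_path, _, attr = dotted_path.rpartition(".")
    return getattr(importlib.import_module(module_path), attr)


collate_name = "torchtune.data.padded_collate_sft"  # the config default
if "left_pad_sequence" in collate_name:
    raise RuntimeError("left_pad_sequence collator is only for inference.")

collate_fn = resolve_component(collate_name)
# Placeholder values; the recipe binds tokenizer.pad_id and loss_fn.ignore_index here
# and then passes the partial to DataLoader(collate_fn=...).
collate_sft = partial(collate_fn, padding_idx=0, ignore_idx=-100)
```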
@@ -564,12 +574,14 @@ def save_checkpoint(
         cpu_state_dict = training.get_full_model_state_dict(
             self._model,
             self._is_rank_zero,
+            device=self._device,
         )
 
         if intermediate_checkpoint:
             opt_state_dict = training.get_full_optimizer_state_dict(
                 self._optimizer,
                 self._is_rank_zero,
+                device=self._device,
             )
         else:
             opt_state_dict = None
@@ -642,13 +654,6 @@ def train(self) -> None:
                 ):
                     torch.cuda.memory._record_memory_history()
 
-                # Both are shape [b, s]
-                tokens, labels = batch["tokens"], batch["labels"]
-                # Get the attention mask and position ids from the dataset if they
-                # exist. Currently, only sample packing in PackedDataset returns these
-                mask = batch.get("mask", None)  # shape [b, s, s]
-                input_pos = batch.get("input_pos", None)  # shape [b, s]
-
                 # Optionally wait N steps before enabling fake quant
                 if self._fake_quant_after_n_steps is not None:
                     if self.global_step == 0:
@@ -670,15 +675,13 @@ def train(self) -> None:
                     )
                     self._model.apply(enable_fq)
 
-                tokens = tokens.to(self._device)
-                num_tokens += tokens.numel()
-                labels = labels.to(self._device)
-                mask = mask.to(self._device) if mask is not None else None
-                input_pos = (
-                    input_pos.to(self._device) if input_pos is not None else None
-                )
+                utils.batch_to_device(batch, self._device)
+                num_tokens += batch["tokens"].numel()
+
+                # Shape [b, s], needed for the loss not the model
+                labels = batch.pop("labels")
 
-                logits = self._model(tokens, mask=mask, input_pos=input_pos)
+                logits = self._model(**batch)
 
                 # Shift labels to compute loss
                 # equivalent to doing labels[..., 1:] and logits[..., :-1, :]
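The training-loop change above replaces per-tensor device moves with a single batch-level move and then feeds the remaining batch keys straight to the model. A small sketch of that flow; `move_batch_to_device` is a hypothetical stand-in for `utils.batch_to_device`:

```python
import torch


def move_batch_to_device(batch: dict, device: torch.device) -> None:
    # Hypothetical stand-in for utils.batch_to_device: move every tensor
    # in the batch dict to the target device, in place.
    for key, value in batch.items():
        if isinstance(value, torch.Tensor):
            batch[key] = value.to(device)


batch = {
    "tokens": torch.randint(0, 100, (2, 8)),
    "labels": torch.randint(0, 100, (2, 8)),
}
move_batch_to_device(batch, torch.device("cpu"))

labels = batch.pop("labels")  # shape [b, s]; only the loss needs the labels
# The remaining keys (tokens, plus mask/input_pos when sample packing is used)
# are passed straight through: logits = model(**batch)
```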
@@ -692,6 +695,7 @@ def train(self) -> None:
 
                 # Compute loss
                 loss = self._loss_fn(logits, labels)
+
                 # free logits otherwise it peaks backward memory
                 del logits
 
@@ -701,6 +705,11 @@ def train(self) -> None:
 
                 # Step with optimizer
                 if (idx + 1) % self._gradient_accumulation_steps == 0:
+                    if self._clip_grad_norm is not None:
+                        grad_norm = torch.nn.utils.clip_grad_norm_(
+                            self._model.parameters(),
+                            max_norm=float(self._clip_grad_norm),
+                        )
                     self._optimizer.step()
                     self._optimizer.zero_grad(set_to_none=True)
 
@@ -728,6 +737,8 @@ def train(self) -> None:
                         if self._log_peak_memory_stats:
                             log_dict.update(
                                 training.get_memory_stats(device=self._device)
                             )
+                        if self._clip_grad_norm is not None:
+                            log_dict.update({"grad_norm": grad_norm})
                         self._metric_logger.log_dict(
                             log_dict,
                             step=self.global_step,
