 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-import os
 import sys
 import time

 from torch.optim import Optimizer
 from torch.utils.data import DataLoader, DistributedSampler
 from torchtune import config, modules, training, utils
-from torchtune.data import padded_collate_packed, padded_collate_sft
+from torchtune.config._utils import _get_component_from_path
+from torchtune.data import padded_collate_packed
 from torchtune.datasets import ConcatDataset
 from torchtune.recipe_interfaces import FTRecipeInterface
 from torchtune.training import DummyProfiler, PROFILER_KEY
@@ -50,7 +50,7 @@ class QATRecipeDistributed(FTRecipeInterface):
             to improved quantized accuracy. This can be specified through ``fake_quant_after_n_steps``.

         - FSDP. Supported using PyTorch's FSDP APIs. CPU offload of parameters, gradients, and optimizer states
-            is supported via the ``fsdp_cpu_offload``. Resharding of parameters after the forward pass is
+            is supported via ``fsdp_cpu_offload``. Resharding of parameters after the forward pass is
             done by default (corresponding to FULL_SHARD sharding strategy), but can be disabled by setting the config
             ``fsdp_reshard_after_forward`` to False (this corresponds to SHARD_GRAD_OP sharding strategy).
             DDP is currently not supported. Training on CPU is not supported.
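Note on the FSDP bullet above: the two flags it mentions map onto arguments of PyTorch's FSDP wrapper, with ``fsdp_cpu_offload`` toggling CPU offload and ``fsdp_reshard_after_forward`` choosing between the FULL_SHARD and SHARD_GRAD_OP sharding strategies. The sketch below only illustrates that mapping with stock PyTorch APIs; it is not the recipe's actual wrapping code, and it assumes a process group has already been initialized.

import torch.nn as nn
from torch.distributed.fsdp import CPUOffload, ShardingStrategy
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def wrap_with_fsdp(
    model: nn.Module, fsdp_cpu_offload: bool, fsdp_reshard_after_forward: bool
) -> FSDP:
    # reshard_after_forward=True corresponds to FULL_SHARD; False keeps the
    # gathered parameters after forward, i.e. SHARD_GRAD_OP.
    strategy = (
        ShardingStrategy.FULL_SHARD
        if fsdp_reshard_after_forward
        else ShardingStrategy.SHARD_GRAD_OP
    )
    return FSDP(
        model,
        sharding_strategy=strategy,
        cpu_offload=CPUOffload(offload_params=fsdp_cpu_offload),
    )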
@@ -93,6 +93,10 @@ class QATRecipeDistributed(FTRecipeInterface):

         - Logging. Terminal, Disk, WandB and TensorBoard are all supported.

+        - Gradient Clipping. Gradient clipping is supported using the ``clip_grad_norm`` flag. By default,
+            ``clip_grad_norm`` is set to ``None``. If you only want to log the grad norm, you can set
+            ``clip_grad_norm='inf'``.
+
     For a full list of example configs for this recipe, run ``tune ls`` on the command line. Each config
     has example commands for how to kick-off training.

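The new gradient clipping bullet matches the change to the train loop later in this diff: after gradient accumulation, ``torch.nn.utils.clip_grad_norm_`` is called with ``max_norm=float(clip_grad_norm)`` and the returned total norm is logged. A minimal sketch of that behavior, using the standard PyTorch API and a hypothetical helper name:

from typing import Optional, Union

import torch


def clip_and_report(
    model: torch.nn.Module, clip_grad_norm: Optional[Union[str, float]]
) -> Optional[torch.Tensor]:
    # clip_grad_norm_ returns the total gradient norm computed before clipping.
    # Passing float("inf") leaves gradients untouched, so the value can be
    # logged without actually clipping anything.
    if clip_grad_norm is None:
        return None
    return torch.nn.utils.clip_grad_norm_(
        model.parameters(), max_norm=float(clip_grad_norm)
    )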
@@ -102,6 +106,7 @@ class QATRecipeDistributed(FTRecipeInterface):
     Raises:
         ValueError: If ``dtype`` is set to fp16.
         RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16.
+        RuntimeError: If ``left_pad_sequence`` is set as the data collator.
     """

     def __init__(self, cfg: DictConfig) -> None:
@@ -135,9 +140,6 @@ def __init__(self, cfg: DictConfig) -> None:
         # Training cfg
         self._resume_from_checkpoint = cfg.resume_from_checkpoint
         self._gradient_accumulation_steps = cfg.gradient_accumulation_steps
-        self._fsdp_sharding_strategy = torch.distributed.fsdp.ShardingStrategy[
-            cfg.get("fsdp_sharding_strategy", "FULL_SHARD")
-        ]
         self._fake_quant_after_n_steps = cfg.get("fake_quant_after_n_steps", None)
         self._quantizer_mode = None

@@ -148,6 +150,7 @@ def __init__(self, cfg: DictConfig) -> None:
         self.total_epochs = cfg.epochs
         self.max_steps_per_epoch = cfg.max_steps_per_epoch
         self.global_step = 0
+        self._clip_grad_norm = cfg.get("clip_grad_norm", None)

     def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
         """
@@ -217,7 +220,7 @@ def setup(self, cfg: DictConfig) -> None:

         checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)

-        self._model_compile = cfg.get("compile", False)
+        self._compile = cfg.get("compile", False)
         self._model = self._setup_model(
             cfg_model=cfg.model,
             enable_activation_checkpointing=cfg.enable_activation_checkpointing,
@@ -240,30 +243,25 @@ def setup(self, cfg: DictConfig) -> None:

         # initialize loss
         self._loss_fn = config.instantiate(cfg.loss)
-        backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor")
+
+        if self._compile:
+            training.compile_loss(self._loss_fn, verbose=self._is_rank_zero)
+
         if self._loss_fn.__class__.__name__ == "CEWithChunkedOutputLoss":
             # set num_output_chunks for model
             self._model.set_num_output_chunks(self._loss_fn.num_output_chunks)
-            if self._model_compile:
-                log.info("Compiling loss with torch.compile...")
-                # For CEWithChunkedOutputLoss, if we compile the entire class
-                # we lose the benefits from the chunked loss.
-                # Therefore, we only compile the cross entropy function + upcasting
-                self._loss_fn.compute_cross_entropy = torch.compile(
-                    self._loss_fn.compute_cross_entropy, backend=backend
-                )
-        else:
-            if self._model_compile:
-                log.info("Compiling loss with torch.compile...")
-                self._loss_fn = torch.compile(self._loss_fn, backend=backend)
-        log.info("Loss is initialized.")
+
+        if self._is_rank_zero:
+            log.info("Loss is initialized.")

         # sampler and dataloader depend on the tokenizer and loss_fn and should be
         # setup after both of these are initialized
+        collate_name = cfg.get("collate_fn", "torchtune.data.padded_collate_sft")
         self._sampler, self._dataloader = self._setup_data(
             cfg_dataset=cfg.dataset,
             shuffle=cfg.shuffle,
             batch_size=cfg.batch_size,
+            collate_fn=collate_name,
         )

         # Finally update the recipe state which can only be correctly set after all of the
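``collate_fn`` is now passed around as a dotted-path string (defaulting to ``torchtune.data.padded_collate_sft``) and resolved inside ``_setup_data`` with torchtune's private ``_get_component_from_path``, after rejecting ``left_pad_sequence`` because that collator is inference-only. A rough standard-library stand-in for the resolution step, shown only to illustrate the idea rather than torchtune's implementation:

import importlib
from typing import Callable


def resolve_dotted_path(path: str) -> Callable:
    # "torchtune.data.padded_collate_sft" -> the callable that name points to.
    module_name, _, attr = path.rpartition(".")
    return getattr(importlib.import_module(module_name), attr)


# Usage sketch mirroring the recipe's guard:
# if "left_pad_sequence" in collate_name:
#     raise RuntimeError("left_pad_sequence collator is only for inference.")
# collate_fn = resolve_dotted_path(collate_name)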
@@ -388,6 +386,9 @@ def _setup_model(
         with training.set_default_dtype(self._dtype), torch.device("meta"):
             model = config.instantiate(cfg_model)

+        if self._compile:
+            training.compile_model(model, verbose=self._is_rank_zero)
+
         # We currently have two versions of activation checkpointing in this recipe
         # for testing and BC purposes. ``enable_activation_checkpointing`` controls
         # the older version of AC and this behavior is unchanged
@@ -459,7 +460,12 @@ def _is_layer_fqn(s: str) -> bool:
         # This method will convert the full model state dict into a sharded state
         # dict and load into the model
         training.load_from_full_model_state_dict(
-            model, model_state_dict, self._device, self._is_rank_zero, strict=True
+            model,
+            model_state_dict,
+            self._device,
+            self._is_rank_zero,
+            strict=True,
+            cpu_offload=fsdp_cpu_offload,
         )

         # Ensure no params and buffers are on meta device
@@ -497,6 +503,7 @@ def _setup_data(
         cfg_dataset: DictConfig,
         shuffle: bool,
         batch_size: int,
+        collate_fn: str,
     ) -> Tuple[DistributedSampler, DataLoader]:
         """
         All data related setup happens here. Currently this recipe only supports the
@@ -507,15 +514,20 @@ def _setup_data(

         if isinstance(cfg_dataset, ListConfig):
             datasets = [
-                config.instantiate(single_cfg_dataset, tokenizer=self._tokenizer)
+                config.instantiate(single_cfg_dataset, self._tokenizer)
                 for single_cfg_dataset in cfg_dataset
             ]
             ds = ConcatDataset(datasets=datasets)
             packed = False
         else:
-            ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer)
+            ds = config.instantiate(cfg_dataset, self._tokenizer)
             packed = cfg_dataset.get("packed", False)

+        # Instantiate collate_fn
+        if "left_pad_sequence" in collate_fn:
+            raise RuntimeError("left_pad_sequence collator is only for inference.")
+        collate_fn = _get_component_from_path(collate_fn)
+
         sampler = DistributedSampler(
             ds, num_replicas=world_size, rank=rank, shuffle=shuffle, seed=0
         )
@@ -526,14 +538,12 @@ def _setup_data(
             # dropping last avoids shape issues with compile + flex attention
             drop_last=True,
             collate_fn=partial(
-                padded_collate_sft,
+                collate_fn,
                 padding_idx=self._tokenizer.pad_id,
                 ignore_idx=self._loss_fn.ignore_index,
             )
             if not packed
-            else partial(
-                padded_collate_packed,
-            ),
+            else padded_collate_packed,
         )

         if self._is_rank_zero:
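Once resolved, the collate callable has its padding arguments bound up front with ``functools.partial`` so the ``DataLoader`` can invoke it with just a list of samples; packed datasets bypass the binding and use ``padded_collate_packed`` directly. A generic sketch of the same pattern with a toy collate function (hypothetical, not torchtune's):

from functools import partial
from typing import Any, Dict, List

import torch


def toy_collate(batch: List[Dict[str, Any]], padding_idx: int, ignore_idx: int):
    # Pad token ids with padding_idx and labels with ignore_idx up to the
    # longest sequence in the batch.
    max_len = max(len(x["tokens"]) for x in batch)
    tokens = torch.full((len(batch), max_len), padding_idx, dtype=torch.long)
    labels = torch.full((len(batch), max_len), ignore_idx, dtype=torch.long)
    for i, x in enumerate(batch):
        tokens[i, : len(x["tokens"])] = torch.tensor(x["tokens"])
        labels[i, : len(x["labels"])] = torch.tensor(x["labels"])
    return {"tokens": tokens, "labels": labels}


# collate = partial(toy_collate, padding_idx=0, ignore_idx=-100)
# loader = torch.utils.data.DataLoader(ds, batch_size=8, drop_last=True, collate_fn=collate)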
@@ -564,12 +574,14 @@ def save_checkpoint(
         cpu_state_dict = training.get_full_model_state_dict(
             self._model,
             self._is_rank_zero,
+            device=self._device,
         )

         if intermediate_checkpoint:
             opt_state_dict = training.get_full_optimizer_state_dict(
                 self._optimizer,
                 self._is_rank_zero,
+                device=self._device,
             )
         else:
             opt_state_dict = None
@@ -642,13 +654,6 @@ def train(self) -> None:
                 ):
                     torch.cuda.memory._record_memory_history()

-                # Both are shape [b, s]
-                tokens, labels = batch["tokens"], batch["labels"]
-                # Get the attention mask and position ids from the dataset if they
-                # exist. Currently, only sample packing in PackedDataset returns these
-                mask = batch.get("mask", None)  # shape [b, s, s]
-                input_pos = batch.get("input_pos", None)  # shape [b, s]
-
                 # Optionally wait N steps before enabling fake quant
                 if self._fake_quant_after_n_steps is not None:
                     if self.global_step == 0:
@@ -670,15 +675,13 @@ def train(self) -> None:
                         )
                         self._model.apply(enable_fq)

-                tokens = tokens.to(self._device)
-                num_tokens += tokens.numel()
-                labels = labels.to(self._device)
-                mask = mask.to(self._device) if mask is not None else None
-                input_pos = (
-                    input_pos.to(self._device) if input_pos is not None else None
-                )
+                utils.batch_to_device(batch, self._device)
+                num_tokens += batch["tokens"].numel()
+
+                # Shape [b, s], needed for the loss not the model
+                labels = batch.pop("labels")

-                logits = self._model(tokens, mask=mask, input_pos=input_pos)
+                logits = self._model(**batch)

                 # Shift labels to compute loss
                 # equivalent to doing labels[..., 1:] and logits[..., :-1, :]
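The per-tensor device copies are replaced by a single ``utils.batch_to_device(batch, self._device)`` call; ``labels`` is then popped (only the loss needs it) and the remaining keys are forwarded as keyword arguments via ``self._model(**batch)``. A hedged sketch of the flat-dict case with a hypothetical helper (torchtune's own utility also handles nested containers):

from typing import Dict

import torch


def move_batch_to_device(batch: Dict[str, torch.Tensor], device: torch.device) -> None:
    # In-place: move every tensor in the batch onto the target device.
    for key, value in batch.items():
        batch[key] = value.to(device)


# Usage sketch inside a training step:
# move_batch_to_device(batch, device)
# labels = batch.pop("labels")   # only the loss consumes labels
# logits = model(**batch)        # tokens / mask / input_pos pass through as kwargs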
@@ -692,6 +695,7 @@ def train(self) -> None:

                 # Compute loss
                 loss = self._loss_fn(logits, labels)
+
                 # free logits otherwise it peaks backward memory
                 del logits

@@ -701,6 +705,11 @@ def train(self) -> None:

                 # Step with optimizer
                 if (idx + 1) % self._gradient_accumulation_steps == 0:
+                    if self._clip_grad_norm is not None:
+                        grad_norm = torch.nn.utils.clip_grad_norm_(
+                            self._model.parameters(),
+                            max_norm=float(self._clip_grad_norm),
+                        )
                     self._optimizer.step()
                     self._optimizer.zero_grad(set_to_none=True)

@@ -728,6 +737,8 @@ def train(self) -> None:
                             log_dict.update(
                                 training.get_memory_stats(device=self._device)
                             )
+                    if self._clip_grad_norm is not None:
+                        log_dict.update({"grad_norm": grad_norm})
                     self._metric_logger.log_dict(
                         log_dict,
                         step=self.global_step,