From 258cd8b9379f7f4e454c9dde59f9c2ebe52e8393 Mon Sep 17 00:00:00 2001
From: andrewor14
Date: Fri, 1 Nov 2024 16:05:51 -0700
Subject: [PATCH] Update QAT: add grad clipping, torch.compile, collate fn
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

**Summary:** Update the qat_distributed recipe to match the
full_finetune_distributed recipe. This commit adds features to QAT such as
gradient clipping, torch.compile, and a user-configurable collate function
for data pre-processing. Mirrors all changes in full_finetune_distributed
as of 506e099.

Helpful commands for quick review:

```
diff --color recipes/full_finetune_distributed.py recipes/qat_distributed.py
diff --color recipes/configs/llama2/7B_full.yaml recipes/configs/llama2/7B_qat_full.yaml
diff --color recipes/configs/llama3/8B_full.yaml recipes/configs/llama3/8B_qat_full.yaml
```

**Test Plan:** Fine-tune on the alpaca dataset for 1 epoch with and without QAT:

```
CUDA_VISIBLE_DEVICES=2,3,4,5,6,7 tune run --nnodes 1 --nproc_per_node 6 qat_distributed --config llama3/8B_qat_full \
    epochs=1 \
    checkpointer.output_dir="$LOG_DIR" \
    metric_logger.output_dir="${LOG_DIR}/metrics" \
    quantizer._component_=torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer

CUDA_VISIBLE_DEVICES=1 tune run quantize --config recipes/configs/quantization.yaml \
    model._component_=torchtune.models.llama3.llama3_8b \
    checkpointer._component_=torchtune.training.FullModelMetaCheckpointer \
    checkpointer.checkpoint_dir="$LOG_DIR" \
    checkpointer.output_dir="$LOG_DIR" \
    checkpointer.checkpoint_files=[meta_model_0.pt] \
    checkpointer.model_type=LLAMA3 \
    quantizer._component_=torchtune.training.quantization.Int8DynActInt4WeightQuantizer

CUDA_VISIBLE_DEVICES=1 tune run eleuther_eval --config eleuther_evaluation \
    tasks=[wikitext] \
    model._component_=torchtune.models.llama3.llama3_8b \
    checkpointer._component_=torchtune.training.FullModelTorchTuneCheckpointer \
    checkpointer.checkpoint_dir="$LOG_DIR" \
    checkpointer.output_dir="$LOG_DIR" \
    checkpointer.checkpoint_files=[meta_model_0-8da4w.pt] \
    checkpointer.model_type=LLAMA3 \
    tokenizer._component_=torchtune.models.llama3.llama3_tokenizer \
    tokenizer.path=/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model \
    quantizer._component_=torchtune.training.quantization.Int8DynActInt4WeightQuantizer
```

With QAT:

```
| Tasks  |Version|Filter|n-shot|    Metric     |   | Value |   |Stderr|
|--------|------:|------|------|---------------|---|------:|---|------|
|wikitext|      2|none  |None  |bits_per_byte  |↓  | 0.9821|±  |   N/A|
|        |       |none  |None  |byte_perplexity|↓  | 1.9754|±  |   N/A|
|        |       |none  |None  |word_perplexity|↓  |38.1039|±  |   N/A|
```

Without QAT:

```
| Tasks  |Version|Filter|n-shot|    Metric     |   |  Value  |   |Stderr|
|--------|------:|------|------|---------------|---|--------:|---|------|
|wikitext|      2|none  |None  |bits_per_byte  |↓  |   2.2017|±  |   N/A|
|        |       |none  |None  |byte_perplexity|↓  |   4.6003|±  |   N/A|
|        |       |none  |None  |word_perplexity|↓  |3501.1122|±  |   N/A|
```
---
 recipes/configs/llama2/7B_qat_full.yaml |   2 +-
 recipes/configs/llama3/8B_qat_full.yaml |   9 +-
 recipes/qat_distributed.py              | 306 ++++++++++++++++++------
 3 files changed, 240 insertions(+), 77 deletions(-)

diff --git a/recipes/configs/llama2/7B_qat_full.yaml b/recipes/configs/llama2/7B_qat_full.yaml
index 0cbf6c7b7a..e404b0c4dc 100644
--- a/recipes/configs/llama2/7B_qat_full.yaml
+++ b/recipes/configs/llama2/7B_qat_full.yaml
@@ -67,7 +67,7 @@ device: cuda
 
 # Memory management
 enable_activation_checkpointing: True # True reduces memory
-memory_efficient_fsdp_wrap: False
+enable_activation_offloading: False # True reduces memory
 
 # Reduced precision
 dtype: bf16
 
diff --git a/recipes/configs/llama3/8B_qat_full.yaml b/recipes/configs/llama3/8B_qat_full.yaml
index ce409d1bbb..2b08cbb10f 100644
--- a/recipes/configs/llama3/8B_qat_full.yaml
+++ b/recipes/configs/llama3/8B_qat_full.yaml
@@ -44,8 +44,6 @@ resume_from_checkpoint: False
 # Fine-tuning arguments
 batch_size: 2
 epochs: 3
-compile: False # pytorch compile, set to true for better perf/memory
-optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1
 
 # QAT arguments
 quantizer:
@@ -60,13 +58,16 @@ loss:
   _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
 max_steps_per_epoch: null
 gradient_accumulation_steps: 1 # Use to increase virtual batch size
+compile: False # pytorch compile, set to true for better perf/memory
+optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1
 
 # Training env
 device: cuda
 
 # Memory management
 enable_activation_checkpointing: True # True reduces memory
-memory_efficient_fsdp_wrap: True
+enable_activation_offloading: False # True reduces memory
+custom_sharded_layers: ['tok_embeddings', 'output'] # Layers to shard separately (useful for large vocab size models). Lower memory, but lower speed.
 
 # Reduced precision
 dtype: bf16
@@ -75,7 +76,7 @@ dtype: bf16
 metric_logger:
   _component_: torchtune.training.metric_logging.DiskLogger
   log_dir: ${output_dir}
-output_dir: /tmp/alpaca-llama3-finetune
+output_dir: /tmp/full-llama3-finetune
 log_every_n_steps: 1
 log_peak_memory_stats: True
 
diff --git a/recipes/qat_distributed.py b/recipes/qat_distributed.py
index f09ffc1c7b..b1040880d0 100644
--- a/recipes/qat_distributed.py
+++ b/recipes/qat_distributed.py
@@ -4,7 +4,6 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-import os
 import sys
 import time
 
@@ -21,11 +20,13 @@
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader, DistributedSampler
 from torchtune import config, modules, training, utils
-from torchtune.data import padded_collate_packed, padded_collate_sft
+from torchtune.config._utils import _get_component_from_path
+from torchtune.data import padded_collate_packed
 from torchtune.datasets import ConcatDataset
 from torchtune.recipe_interfaces import FTRecipeInterface
 from torchtune.training import DummyProfiler, PROFILER_KEY
 from torchtune.training.activations import apply_selective_activation_checkpointing
+from torchtune.training.lr_schedulers import get_lr
 
 from tqdm import tqdm
 
@@ -50,7 +51,7 @@ class QATRecipeDistributed(FTRecipeInterface):
         to improved quantized accuracy. This can be specified through ``fake_quant_after_n_steps``.
 
     - FSDP. Supported using PyTorch's FSDP APIs. CPU offload of parameters, gradients, and optimizer states
-      is supported via the ``fsdp_cpu_offload``. Resharding of parameters after the forward pass is
+      is supported via ``fsdp_cpu_offload``. Resharding of parameters after the forward pass is
       done by default (corresponding to FULL_SHARD sharding strategy), but can be disabled by setting the config
       ``fsdp_reshard_after_forward`` to False (this corresponds to SHARD_GRAD_OP sharding strategy).
       DDP is currently not supported. Training on CPU is not supported.
 
@@ -62,6 +63,18 @@ class QATRecipeDistributed(FTRecipeInterface):
       come at the cost of training performance. In most cases training can slow-down quite a bit as a result
       of this activation recomputation.
 
+    - Activation Offloading. This can be controlled using the ``enable_activation_offloading``
+      flag. Activation offloading is a technique similar to activation checkpointing that helps
+      reduce the memory footprint to prevent OOMs on CUDA and enable bigger batches. Where activation
+      checkpointing drops the activation in the forward to recompute it later in the backward,
+      activation offloading will drop the activation in the forward to the CPU and bring it
+      back during the backward pass. As always, there is a tradeoff--these savings in memory can
+      come at the cost of training performance and CPU resources. To recover some runtime cost,
+      we've added an option to enable offloading on a different stream to permit overlapping with
+      the computation. This option is currently only available on PyTorch 2.5 or later and will
+      be enabled by default if an acceptable torch version is found. Activation offloading can be
+      used in conjunction with activation checkpointing.
+
     - Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype``
       flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In most
       cases this should halve the memory footprint of full precision (fp32) training, without
@@ -93,6 +106,10 @@ class QATRecipeDistributed(FTRecipeInterface):
 
     - Logging. Terminal, Disk, WandB and TensorBoard are all supported.
 
+    - Gradient Clipping. Gradient clipping is supported using the ``clip_grad_norm`` flag. By default,
+      ``clip_grad_norm`` is set to ``None``. If you only want to log the grad norm, you can set
+      ``clip_grad_norm='inf'``.
+
     For a full list of example configs for this recipe, run ``tune ls`` on the command line. Each config
     has example commands for how to kick-off training.
 
@@ -102,6 +119,9 @@ class QATRecipeDistributed(FTRecipeInterface):
     Raises:
         ValueError: If ``dtype`` is set to fp16.
         RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16.
+        RuntimeError: If ``left_pad_sequence`` is set as the data collator.
+        RuntimeError: If ``enable_activation_offloading`` is True and device is not CUDA.
+        RuntimeError: If ``enable_activation_offloading`` is True and ``enable_activation_checkpointing`` is False.
     """
 
     def __init__(self, cfg: DictConfig) -> None:
@@ -141,12 +161,50 @@ def __init__(self, cfg: DictConfig) -> None:
         # Training cfg
         self._resume_from_checkpoint = cfg.resume_from_checkpoint
         self._gradient_accumulation_steps = cfg.gradient_accumulation_steps
-        self._fsdp_sharding_strategy = torch.distributed.fsdp.ShardingStrategy[
-            cfg.get("fsdp_sharding_strategy", "FULL_SHARD")
-        ]
+        self._optimizer_in_bwd = cfg.get("optimizer_in_bwd", False)
+        self._clip_grad_norm = cfg.get("clip_grad_norm", None)
         self._fake_quant_after_n_steps = cfg.get("fake_quant_after_n_steps", None)
         self._quantizer_mode = None
 
+        # Optimizer in backward is not compatible with gradient accumulation or gradient clipping
+        if self._optimizer_in_bwd:
+            if self._clip_grad_norm is not None:
+                raise RuntimeError(
+                    "Gradient clipping is not supported with optimizer in bwd. "
+                    "Please set clip_grad_norm=None, or optimizer_in_bwd=False."
+                )
+            if self._gradient_accumulation_steps > 1:
+                raise RuntimeError(
+                    "Gradient accumulation is not supported with optimizer in bwd. "
+                    "Please set gradient_accumulation_steps=1, or optimizer_in_bwd=False."
+ ) + + # activation checkpointing/offloading + self._enable_activation_checkpointing = cfg.get( + "enable_activation_checkpointing", False + ) + self._enable_activation_offloading = cfg.get( + "enable_activation_offloading", False + ) + if self._enable_activation_offloading: + if self._device.type != "cuda": + raise RuntimeError( + "enable_activation_offloading should only be True when training on CUDA" + ) + if not self._enable_activation_checkpointing: + raise RuntimeError( + "enable_activation_offloading should only be True when enable_activation_checkpointing is True" + ) + elif ( + self._enable_activation_checkpointing + and cfg.checkpointer.model_type != "LLAMA3_VISION" + ): + utils.log_rank_zero( + log, + "Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. " + "Enabling activation offloading should reduce memory further.", + ) + # These are public properties which are updated by the checkpoint loader # when ``resume_from_checkpoint`` is `True` or validated in tests self.seed = training.set_seed(seed=cfg.seed) @@ -223,10 +281,11 @@ def setup(self, cfg: DictConfig) -> None: checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer) - self._model_compile = cfg.get("compile", False) + self._compile = cfg.get("compile", False) self._model = self._setup_model( cfg_model=cfg.model, - enable_activation_checkpointing=cfg.enable_activation_checkpointing, + enable_activation_checkpointing=self._enable_activation_checkpointing, + enable_activation_offloading=self._enable_activation_offloading, custom_sharded_layers=cfg.get("custom_sharded_layers", None), fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False), reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True), @@ -239,6 +298,7 @@ def setup(self, cfg: DictConfig) -> None: self._optimizer = self._setup_optimizer( cfg_optimizer=cfg.optimizer, + optimizer_in_bwd=self._optimizer_in_bwd, opt_state_dict=( checkpoint_dict[training.OPT_KEY] if self._resume_from_checkpoint @@ -248,30 +308,25 @@ def setup(self, cfg: DictConfig) -> None: # initialize loss self._loss_fn = config.instantiate(cfg.loss) - backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor") + + if self._compile: + training.compile_loss(self._loss_fn, verbose=self._is_rank_zero) + if self._loss_fn.__class__.__name__ == "CEWithChunkedOutputLoss": # set num_output_chunks for model self._model.set_num_output_chunks(self._loss_fn.num_output_chunks) - if self._model_compile: - log.info("Compiling loss with torch.compile...") - # For CEWithChunkedOutputLoss, if we compile the entire class - # we lose the benefits from the chunked loss. 
- # Therefore, we only compile the cross entropy function + upcasting - self._loss_fn.compute_cross_entropy = torch.compile( - self._loss_fn.compute_cross_entropy, backend=backend - ) - else: - if self._model_compile: - log.info("Compiling loss with torch.compile...") - self._loss_fn = torch.compile(self._loss_fn, backend=backend) - log.info("Loss is initialized.") + + if self._is_rank_zero: + log.info("Loss is initialized.") # sampler and dataloader depend on the tokenizer and loss_fn and should be # setup after both of these are initialized + collate_name = cfg.get("collate_fn", "torchtune.data.padded_collate_sft") self._sampler, self._dataloader = self._setup_data( cfg_dataset=cfg.dataset, shuffle=cfg.shuffle, batch_size=cfg.batch_size, + collate_fn=collate_name, ) # Finally update the recipe state which can only be correctly set after all of the @@ -371,6 +426,7 @@ def _setup_model( self, cfg_model: DictConfig, enable_activation_checkpointing: bool, + enable_activation_offloading: bool, fsdp_cpu_offload: bool, reshard_after_forward: bool, model_state_dict: Dict[str, Any], @@ -396,6 +452,9 @@ def _setup_model( with training.set_default_dtype(self._dtype), torch.device("meta"): model = config.instantiate(cfg_model) + if self._compile: + training.compile_model(model, verbose=self._is_rank_zero) + # We currently have two versions of activation checkpointing in this recipe # for testing and BC purposes. ``enable_activation_checkpointing`` controls # the older version of AC and this behavior is unchanged @@ -451,7 +510,17 @@ def _setup_model( # This method will convert the full model state dict into a sharded state # dict and load into the model training.load_from_full_model_state_dict( - model, model_state_dict, self._device, self._is_rank_zero, strict=True + model, + model_state_dict, + self._device, + self._is_rank_zero, + strict=True, + cpu_offload=fsdp_cpu_offload, + ) + + # activation offloading + self.activations_handling_ctx = training.get_act_offloading_ctx_manager( + model, enable_activation_offloading ) # Ensure no params and buffers are on meta device @@ -470,25 +539,64 @@ def _setup_model( return model def _setup_optimizer( - self, cfg_optimizer: DictConfig, opt_state_dict: Optional[Dict[str, Any]] = None - ) -> Optimizer: - optimizer = config.instantiate(cfg_optimizer, self._model.parameters()) - if opt_state_dict: - training.load_from_full_optimizer_state_dict( - optimizer, - opt_state_dict, - self._device, + self, + cfg_optimizer: DictConfig, + optimizer_in_bwd: bool = False, + opt_state_dict: Optional[Dict[str, Any]] = None, + ) -> Optional[Optimizer]: + if optimizer_in_bwd: + # Maintain a dict of optims for every parameter. + optim_dict = { + param: config.instantiate(cfg_optimizer, [param]) + for param in self._model.parameters() + } + + # Register optimizer step hooks on the model to run optimizer in backward. + training.register_optim_in_bwd_hooks( + model=self._model, optim_dict=optim_dict ) + # Create a wrapper for checkpoint save/load of optimizer states when running in backward. + self._optim_ckpt_wrapper = training.create_optim_in_bwd_wrapper( + model=self._model, optim_dict=optim_dict + ) + # Load optimizer states for each param. If optimizer states are being restored in an optimizer in + # backward run, these need to have been saved with the same setting. Cannot restore from runs that + # did not use optimizer in backward. 
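+            # NOTE: opt_state_dict here is keyed per parameter; each key must also be
+            # present in self._optim_ckpt_wrapper.state_dict() for the lookup below.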
+            if opt_state_dict is not None:
+                for param in opt_state_dict.keys():
+                    try:
+                        training.load_from_full_optimizer_state_dict(
+                            self._optim_ckpt_wrapper.state_dict()[param],
+                            opt_state_dict[param],
+                            self._device,
+                        )
+                    except BaseException as e:
+                        raise RuntimeError(
+                            "Failed loading in-backward optimizer checkpoints. "
+                            "Please make sure the run being restored from was using the in-backward optimizer."
+                        ) from e
+            if self._is_rank_zero:
+                log.info("In-backward optimizers are set up.")
+            return None
+        else:
+            optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
+            if opt_state_dict:
+                training.load_from_full_optimizer_state_dict(
+                    optimizer,
+                    opt_state_dict,
+                    self._device,
+                )
 
-        if self._is_rank_zero:
-            log.info("Optimizer is initialized.")
-        return optimizer
+            if self._is_rank_zero:
+                log.info("Optimizer is initialized.")
+            return optimizer
 
     def _setup_data(
         self,
         cfg_dataset: DictConfig,
         shuffle: bool,
         batch_size: int,
+        collate_fn: str,
     ) -> Tuple[DistributedSampler, DataLoader]:
         """
         All data related setup happens here. Currently this recipe only supports the
@@ -499,15 +607,20 @@ def _setup_data(
 
         if isinstance(cfg_dataset, ListConfig):
             datasets = [
-                config.instantiate(single_cfg_dataset, tokenizer=self._tokenizer)
+                config.instantiate(single_cfg_dataset, self._tokenizer)
                 for single_cfg_dataset in cfg_dataset
             ]
             ds = ConcatDataset(datasets=datasets)
             packed = False
         else:
-            ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer)
+            ds = config.instantiate(cfg_dataset, self._tokenizer)
             packed = cfg_dataset.get("packed", False)
 
+        # Instantiate collate_fn
+        if "left_pad_sequence" in collate_fn:
+            raise RuntimeError("left_pad_sequence collator is only for inference.")
+        collate_fn = _get_component_from_path(collate_fn)
+
         sampler = DistributedSampler(
             ds, num_replicas=world_size, rank=rank, shuffle=shuffle, seed=0
         )
@@ -519,14 +632,12 @@
             drop_last=True,
             collate_fn=(
                 partial(
-                    padded_collate_sft,
+                    collate_fn,
                     padding_idx=self._tokenizer.pad_id,
                     ignore_idx=self._loss_fn.ignore_index,
                 )
                 if not packed
-                else partial(
-                    padded_collate_packed,
-                )
+                else padded_collate_packed
             ),
         )
 
@@ -553,25 +664,54 @@ def save_checkpoint(
         checkpoint_dict = {}
 
         intermediate_checkpoint = epoch + 1 < self.total_epochs
+
+        if self._is_rank_zero:
+            log.info(
+                "Saving checkpoint. This may take some time. Retrieving full model state dict..."
+            )
+            start = time.perf_counter()
+
         # To prevent GPU memory from spiking during checkpoint save,
         # we consolidate the full model and optim state dicts on CPU for rank 0
         cpu_state_dict = training.get_full_model_state_dict(
             self._model,
             self._is_rank_zero,
+            device=self._device,
         )
 
-        if intermediate_checkpoint:
-            opt_state_dict = training.get_full_optimizer_state_dict(
-                self._optimizer,
-                self._is_rank_zero,
+        if self._is_rank_zero:
+            log.info(
+                f"Getting full model state dict took {time.perf_counter() - start:.2f} secs"
             )
+
+        if intermediate_checkpoint:
+            start = time.perf_counter()
+            if self._is_rank_zero:
+                log.info("Getting optimizer state dict...")
+            if not self._optimizer_in_bwd:
+                opt_state_dict = training.get_full_optimizer_state_dict(
+                    self._optimizer,
+                    self._is_rank_zero,
+                    device=self._device,
+                )
+            else:
+                opt_state_dict = {}
+                for param, opt in self._optim_ckpt_wrapper.optim_map.items():
+                    opt_state_dict[param] = training.get_full_optimizer_state_dict(
+                        opt, self._is_rank_zero, device=self._device
+                    )
+            if self._is_rank_zero:
+                log.info(
+                    f"Getting optimizer state dict took {time.perf_counter() - start:.2f} secs"
+                )
         else:
             opt_state_dict = None
 
         # Now that we have the model and opt state dict, create the actual checkpoint dict
         # to be sent to the checkpointer and ultimately written to file
-        if self._is_rank_zero:
+        if self._is_rank_zero:
+            start = time.perf_counter()
             checkpoint_dict.update({training.MODEL_KEY: cpu_state_dict})
 
             # if training is in-progress, checkpoint the optimizer state and recipe state
@@ -592,6 +732,9 @@ def save_checkpoint(
                 epoch=epoch,
                 intermediate_checkpoint=intermediate_checkpoint,
             )
+            log.info(f"Saving checkpoint took {time.perf_counter() - start:.2f} secs")
+
+        torch.distributed.barrier()
 
     def train(self) -> None:
         """
         """
         # clean up before training begins
         training.cleanup_before_training()
+
         world_size, rank = training.get_world_size_and_rank()
 
         # zero out the gradients before starting training
-        self._optimizer.zero_grad()
+        if not self._optimizer_in_bwd:
+            self._optimizer.zero_grad()
+        else:
+            for opt in self._optim_ckpt_wrapper.optim_map.values():
+                opt.zero_grad()
 
         # Initialize tokens count and running loss (for grad accumulation)
         t0 = time.perf_counter()
@@ -612,7 +760,6 @@ def train(self) -> None:
         self._profiler.start()
         # self.epochs_run should be non-zero when we're resuming from a checkpoint
         for curr_epoch in range(self.epochs_run, self.total_epochs):
-
             # Update the sampler to ensure data is correctly shuffled across epochs
             # in case shuffle is True
             self._sampler.set_epoch(curr_epoch)
@@ -635,13 +782,6 @@
                 ):
                     torch.cuda.memory._record_memory_history()
 
-                # Both are shape [b, s]
-                tokens, labels = batch["tokens"], batch["labels"]
-                # Get the attention mask and position ids from the dataset if they
-                # exist. Currently, only sample packing in PackedDataset returns these
-                mask = batch.get("mask", None)  # shape [b, s, s]
-                input_pos = batch.get("input_pos", None)  # shape [b, s]
-
                 # Optionally wait N steps before enabling fake quant
                 if self._fake_quant_after_n_steps is not None:
                     if self.global_step == 0:
@@ -663,20 +803,20 @@
                         )
                         self._model.apply(enable_fq)
 
-                tokens = tokens.to(self._device)
+                utils.batch_to_device(batch, self._device)
 
                 # Calculate the number of unmasked tokens in the current batch
                 # and increment the total number of tokens seen in the step
-
-                utils.batch_to_device(batch, self._device)
-
                 current_num_tokens = (
                     batch["labels"] != self._loss_fn.ignore_index
                 ).sum()
                 num_tokens += current_num_tokens
+
+                # Shape [b, s], needed for the loss, not the model
                 labels = batch.pop("labels")
 
-                logits = self._model(**batch)
+                with self.activations_handling_ctx:
+                    logits = self._model(**batch)
 
                 # Shift labels to compute loss
                 # equivalent to doing labels[..., 1:] and logits[..., :-1, :]
@@ -689,25 +829,40 @@
                     logits = logits.reshape(-1, logits.size(-1))
 
                 # Compute loss
+                # Loss is normalized by default so we multiply by the number of tokens
+                # This way we can normalize by the total number of tokens if we're accumulating gradients
                 current_loss = self._loss_fn(logits, labels) * current_num_tokens
 
                 # free logits otherwise it peaks backward memory
                 del logits
 
                 running_loss += current_loss
-                current_loss.backward()
 
-                # Step with optimizer
-                if (idx + 1) % self._gradient_accumulation_steps == 0:
-                    # Get total number of tokens across all ranks to normalize gradients
+                # For optimizer in backward, we need to normalize before calling backward
+                # This case and gradient accumulation are mutually exclusive
+                if self._optimizer_in_bwd:
                     torch.distributed.all_reduce(num_tokens)
-                    # This will ensure that the logged loss matches what we're optimizing
                     torch.distributed.all_reduce(running_loss)
-                    # Manually scale the gradients from unnormalized loss by total # of tokens
-                    training.scale_grads(self._model, 1 / num_tokens)
+                    current_loss = current_loss / num_tokens
+
+                current_loss.backward()
 
-                    self._optimizer.step()
-                    self._optimizer.zero_grad(set_to_none=True)
+                # Step with optimizer
+                if (idx + 1) % self._gradient_accumulation_steps == 0:
+                    if not self._optimizer_in_bwd:
+                        # Get total number of tokens across all ranks to normalize gradients
+                        torch.distributed.all_reduce(num_tokens)
+                        # This will ensure that the logged loss matches what we're optimizing
+                        torch.distributed.all_reduce(running_loss)
+                        # Manually scale the gradients from unnormalized loss by total # of tokens
+                        training.scale_grads(self._model, 1 / num_tokens)
+                    if self._clip_grad_norm is not None:
+                        grad_norm = torch.nn.utils.clip_grad_norm_(
+                            self._model.parameters(),
+                            max_norm=float(self._clip_grad_norm),
+                        )
+                    self._optimizer.step()
+                    self._optimizer.zero_grad(set_to_none=True)
 
                     # Update the number of steps when the weights are updated
                     self.global_step += 1
@@ -726,15 +881,22 @@
                         time_per_step = time.perf_counter() - t0
                         log_dict = {
                             "loss": loss_to_log,
-                            "lr": self._optimizer.param_groups[0]["lr"],
-                            "tokens_per_second_per_gpu": (
-                                num_tokens / time_per_step * world_size
+                            "lr": get_lr(
+                                (
+                                    self._optimizer
+                                    if not self._optimizer_in_bwd
+                                    else self._optim_ckpt_wrapper
+                                ),
                             ),
+                            "tokens_per_second_per_gpu": num_tokens
+                            / (time_per_step * world_size),
                         }
                         if self._log_peak_memory_stats:
                             log_dict.update(
                                 training.get_memory_stats(device=self._device)
                             )
+                        if self._clip_grad_norm is not None:
+                            log_dict.update({"grad_norm": grad_norm})
                         self._metric_logger.log_dict(
                             log_dict,
                             step=self.global_step,
@@ -784,7 +946,7 @@ def recipe_main(cfg: DictConfig) -> None:
     """
     if not training.is_distributed():
         raise RuntimeError(
-            "Distributed QAT recipe should be run via a distributed launcher."
+            "Distributed finetune recipe should be run via a distributed launcher. "
            "If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]"
         )
     init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl")
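
Also helpful for quick review: the new knobs can all be exercised as one-off
config overrides from the CLI. A minimal sketch (the override values below are
illustrative, not the config defaults; per the new checks in __init__,
clip_grad_norm must stay None whenever optimizer_in_bwd=True):

```
tune run --nnodes 1 --nproc_per_node 6 qat_distributed --config llama3/8B_qat_full \
    compile=True \
    clip_grad_norm=1.0 \
    optimizer_in_bwd=False \
    collate_fn=torchtune.data.padded_collate_sft
```

Setting clip_grad_norm='inf' computes and logs the grad norm without
effectively clipping, and collate_fn defaults to
torchtune.data.padded_collate_sft (left_pad_sequence collators are rejected
by _setup_data as inference-only).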