diff --git a/recipes/configs/llama3_1/8B_full.yaml b/recipes/configs/llama3_1/8B_full.yaml
index 32aff922c..a04847c6d 100644
--- a/recipes/configs/llama3_1/8B_full.yaml
+++ b/recipes/configs/llama3_1/8B_full.yaml
@@ -83,6 +83,11 @@ output_dir: /tmp/full-llama3.1-finetune
 log_every_n_steps: 1
 log_peak_memory_stats: True
 
+# mixed precision (disabled)
+mixed_precision:
+  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
+  enabled: false
+
 # Profiler (disabled)
 profiler:
   _component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/llama3_2/1B_full_single_device.yaml b/recipes/configs/llama3_2/1B_full_single_device.yaml
index e2aa1c110..d445eedbe 100644
--- a/recipes/configs/llama3_2/1B_full_single_device.yaml
+++ b/recipes/configs/llama3_2/1B_full_single_device.yaml
@@ -79,6 +79,11 @@ output_dir: /tmp/full-llama3.2-finetune
 log_every_n_steps: 1
 log_peak_memory_stats: True
 
+# mixed precision (disabled)
+mixed_precision:
+  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
+  enabled: false
+
 # Profiler (disabled)
 profiler:
   _component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/llama3_2_vision/11B_full.yaml b/recipes/configs/llama3_2_vision/11B_full.yaml
index 51173f162..ed8978772 100644
--- a/recipes/configs/llama3_2_vision/11B_full.yaml
+++ b/recipes/configs/llama3_2_vision/11B_full.yaml
@@ -83,6 +83,11 @@ metric_logger:
 log_every_n_steps: 1
 log_peak_memory_stats: True
 
+# mixed precision (disabled)
+mixed_precision:
+  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
+  enabled: false
+
 # Profiler (disabled)
 profiler:
   _component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py
index 98d34b5f9..c6ec73b9e 100644
--- a/recipes/full_finetune_distributed.py
+++ b/recipes/full_finetune_distributed.py
@@ -279,6 +279,7 @@ def setup(self, cfg: DictConfig) -> None:
             model_state_dict=checkpoint_dict[training.MODEL_KEY],
             ac_mode=cfg.get("ac_mode", None),
             ac_option=cfg.get("ac_option", None),
+            mixed_precision_cfg=cfg.get("mixed_precision", None),
         )
 
         self._tokenizer = config.instantiate(cfg.tokenizer)
@@ -419,6 +420,7 @@ def _setup_model(
         custom_sharded_layers: Optional[List[str]] = None,
         ac_mode: Optional[str] = None,
         ac_option: Optional[int] = None,
+        mixed_precision_cfg: Optional[DictConfig] = None,
     ) -> nn.Module:
         """
         Model initialization has some important considerations:
@@ -459,6 +461,15 @@ def _setup_model(
                 model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
             )
 
+        if mixed_precision_cfg is not None and mixed_precision_cfg.get(
+            "enabled", False
+        ):
+            log.info(f"Preparing model with {mixed_precision_cfg._component_}")
+            cfg = mixed_precision_cfg.copy()
+            cfg.pop("enabled", None)
+            quantizer = config.instantiate(cfg)
+            model = quantizer.prepare(model)
+
         # For FSDP sharding
         fsdp_shard_conditions = [
             partial(
diff --git a/recipes/full_finetune_single_device.py b/recipes/full_finetune_single_device.py
index 0ab6ff3e6..6c84d1496 100644
--- a/recipes/full_finetune_single_device.py
+++ b/recipes/full_finetune_single_device.py
@@ -271,6 +271,7 @@ def setup(self, cfg: DictConfig) -> None:
             enable_activation_offloading=self._enable_activation_offloading,
             compile_model=self._compile,
             model_state_dict=ckpt_dict[training.MODEL_KEY],
+            mixed_precision_cfg=cfg.get("mixed_precision", None),
         )
         self._tokenizer = config.instantiate(cfg.tokenizer)
         log.info("Tokenizer is initialized from file.")
@@ -414,6 +415,7 @@ def _setup_model(
         enable_activation_offloading: bool,
         compile_model: bool,
         model_state_dict: Dict[str, Any],
+        mixed_precision_cfg: Optional[DictConfig] = None,
     ) -> nn.Module:
         """
         Set up the model including enabling activation checkpointing.
@@ -429,6 +431,15 @@ def _setup_model(
                 model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
             )
 
+        if mixed_precision_cfg is not None and mixed_precision_cfg.get(
+            "enabled", False
+        ):
+            log.info(f"Preparing model with {mixed_precision_cfg._component_}")
+            cfg = mixed_precision_cfg.copy()
+            cfg.pop("enabled", None)
+            quantizer = config.instantiate(cfg)
+            model = quantizer.prepare(model)
+
         model.load_state_dict(model_state_dict)
 
         # Validate model was loaded in with the expected dtype.
diff --git a/recipes/lora_finetune_single_device.py b/recipes/lora_finetune_single_device.py
index fcdb3e4ea..4475f4898 100644
--- a/recipes/lora_finetune_single_device.py
+++ b/recipes/lora_finetune_single_device.py
@@ -276,6 +276,7 @@ def setup(self, cfg: DictConfig) -> None:
                 if self._resume_from_checkpoint
                 else None
             ),
+            mixed_precision_cfg=cfg.get("mixed_precision", None),
         )
 
         self._tokenizer = config.instantiate(cfg.tokenizer)
@@ -420,6 +421,7 @@ def _setup_model(
         compile_model: bool,
         base_model_state_dict: Dict[str, Any],
         lora_weights_state_dict: Optional[Dict[str, Any]] = None,
+        mixed_precision_cfg: Optional[DictConfig] = None,
     ) -> nn.Module:
         with training.set_default_dtype(self._dtype), self._device:
             model = config.instantiate(cfg_model)
@@ -441,6 +443,15 @@ def _setup_model(
                 model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
             )
 
+        if mixed_precision_cfg is not None and mixed_precision_cfg.get(
+            "enabled", False
+        ):
+            log.info(f"Preparing model with {mixed_precision_cfg._component_}")
+            cfg = mixed_precision_cfg.copy()
+            cfg.pop("enabled", None)
+            quantizer = config.instantiate(cfg)
+            model = quantizer.prepare(model)
+
         base_missing, base_unexpected = model.load_state_dict(
             base_model_state_dict, strict=False
         )
diff --git a/tests/torchtune/training/test_quantization.py b/tests/torchtune/training/test_quantization.py
new file mode 100644
index 000000000..2b8a5afea
--- /dev/null
+++ b/tests/torchtune/training/test_quantization.py
@@ -0,0 +1,43 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import pytest
+import torch
+from tests.test_utils import gpu_test
+from torch import nn
+from torchtune.training.quantization import (
+    _SUPPORTS_INT8_MIXED_PRECISION_TRAINING,
+    Int8MixedPrecisionTrainingQuantizer,
+)
+
+
+@gpu_test(gpu_count=1)
+@pytest.mark.skipif(
+    not _SUPPORTS_INT8_MIXED_PRECISION_TRAINING,
+    reason="INT8 mixed-precision training is not supported",
+)
+def test_int8_mixed_precision_training_quantizer():
+    quantizer = Int8MixedPrecisionTrainingQuantizer()
+    model = nn.Sequential(
+        nn.Linear(32, 32),
+        nn.ReLU(),
+        nn.Linear(32, 32),
+    ).cuda()
+    quantizer.prepare(model)
+
+    # make sure the Linear classes were swapped
+    assert model[0].__class__ != nn.Linear
+    assert model[2].__class__ != nn.Linear
+
+    # smoke test forward and backward
+    model(torch.randn(2, 32).cuda()).sum().backward()
+    for p in model.parameters():
+        assert p.grad is not None
+
+    # state dict values are plain tensors
+    state_dict = model.state_dict()
+    for v in state_dict.values():
+        assert v.__class__ == torch.Tensor
diff --git a/torchtune/training/quantization.py b/torchtune/training/quantization.py
index 7ff9315f4..2bf31c870 100644
--- a/torchtune/training/quantization.py
+++ b/torchtune/training/quantization.py
@@ -7,6 +7,11 @@
 from typing import Callable, Optional
 from warnings import warn
 
+import torch
+import torchao
+from packaging.version import Version
+from torch import nn
+
 try:
     # torchao 0.7+
     from torchao.dtypes import TensorCoreTiledLayout
@@ -43,6 +48,22 @@
     Int8DynActInt4WeightQATQuantizer,
 )
 
+from torchtune.utils._version import torch_version_ge
+
+
+_SUPPORTS_INT8_MIXED_PRECISION_TRAINING = (
+    torch_version_ge("2.4.0")
+    and Version(torchao.__version__) >= Version("0.7.0.dev")
+    and torch.cuda.is_available()
+    and torch.cuda.get_device_capability() >= (8, 0)
+)
+
+if _SUPPORTS_INT8_MIXED_PRECISION_TRAINING:
+    from torchao.prototype.quantized_training import (
+        int8_mixed_precision_training,
+        Int8MixedPrecisionTrainingConfig,
+    )
+
 
 __all__ = [
     "get_quantizer_mode",
@@ -52,6 +73,7 @@
     "Int8DynActInt4WeightQuantizer",
     "Int8DynActInt4WeightQATQuantizer",
     "Int8DynActInt4WeightQATQuantizerModuleSwap",
+    "Int8MixedPrecisionTrainingQuantizer",
 ]
 
 
@@ -144,6 +166,86 @@ def quantize(self, model):
 ] = enable_8da4w_fake_quant_module_swap
 
 
+class Int8MixedPrecisionTrainingQuantizer:
+    """Apply INT8 mixed-precision training. This only affects weights of ``nn.Linear``
+    modules. During training, weights and activations are dynamically quantized to INT8
+    to utilize fast matrix multiplication with INT8 tensor cores. This is also done in
+    the backward pass.
+
+    The expected end-to-end speedup is 40% on a single A100 and 70% on a single 4090,
+    with minimal accuracy loss. If convergence is an issue, please refer to the torchao
+    documentation below.
+
+    For more details, including details about this quantizer's arguments, please refer to
+    https://github.com/pytorch/ao/tree/main/torchao/prototype/quantized_training#int8-mixed-precision
+
+    Args:
+        output (bool): whether to apply INT8 mixed-precision for calculating output. Default: True
+        grad_input (bool): whether to apply INT8 mixed-precision for calculating grad_input. Default: True
+        grad_weight (bool): whether to apply INT8 mixed-precision for calculating grad_weight. Default: True
+
+    Raises:
+        RuntimeError: If runtime requirements for INT8 mixed-precision training are not met.
+
+    NOTE: Due to the limitations of the current implementation, the following
+    requirements must be satisfied to enjoy the expected speedup:
+
+    1. Must use ``torch.compile()`` (set ``compile=True``).
+    2. Inputs to the model must not be too dynamic. For example, if the input token
+       length changes for every batch, you won't see the expected speedup.
+
+    To satisfy (2), you can use :class:`~torchtune.datasets.PackedDataset` (set
+    ``dataset.packed=True`` and ``tokenizer.max_seq_len`` to a desired value), which
+    ensures input tokens always have a fixed length.
+    """
+
+    def __init__(
+        self,
+        output: bool = True,
+        grad_input: bool = True,
+        grad_weight: bool = True,
+    ) -> None:
+        if not _SUPPORTS_INT8_MIXED_PRECISION_TRAINING:
+            raise RuntimeError(
+                "INT8 mixed-precision training requires torch>=2.4, torchao>=0.7, and"
+                " a CUDA-capable device with compute capability >= 8.0"
+            )
+
+        self._config = Int8MixedPrecisionTrainingConfig(
+            output=output,
+            grad_input=grad_input,
+            grad_weight=grad_weight,
+        )
+
+    def prepare(self, model: nn.Module) -> nn.Module:
+        # use the module-swap implementation so that the state_dict remains plain
+        # tensors, and for better FSDP compatibility in torchtune.
+        quantize_fn = int8_mixed_precision_training(self._config, module_swap=True)
+
+        # custom filter_fn to work with torchtune's peft
+        def filter_fn(module: nn.Module, name: str) -> bool:
+            if isinstance(module, nn.Linear):
+                # skip LoRA adapters since they are too small, so the speedup will not
+                # outweigh the quantization overhead.
+                # also skip the LM head since the end-to-end speedup is slightly worse.
+                # there are also possible issues with tied word embeddings.
+                if (
+                    name.endswith(".lora_a")
+                    or name.endswith(".lora_b")
+                    or module.weight.shape[0] >= 32_000
+                ):
+                    return False
+                else:
+                    return True
+
+            return False
+
+        # don't set the inductor config, otherwise compile will be very slow
+        # (it would affect the global torch.compile() config)
+        quantize_(model, quantize_fn, filter_fn=filter_fn, set_inductor_config=False)
+        return model
+
+
 def get_quantizer_mode(quantizer: Optional[Callable]) -> Optional[str]:
     """Given a quantizer object, returns a string that specifies the type of quantization.
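Reviewer note (not part of the patch): a minimal sketch of how the new quantizer is meant to be used directly, assuming torch>=2.4, torchao>=0.7, and a CUDA device with compute capability >= 8.0. The toy model, optimizer, and shapes below are illustrative only; in the recipes the same wiring is driven by the `mixed_precision` config section added above.

```python
import torch
from torch import nn

from torchtune.training.quantization import Int8MixedPrecisionTrainingQuantizer

# Stand-in for a transformer: any stack of nn.Linear layers is enough to
# exercise the module swap (hypothetical sizes).
model = nn.Sequential(
    nn.Linear(256, 256),
    nn.ReLU(),
    nn.Linear(256, 256),
).cuda()

# Swap nn.Linear modules for INT8 mixed-precision training variants; the
# module-swap path keeps the state_dict as plain tensors (see prepare() above).
quantizer = Int8MixedPrecisionTrainingQuantizer(
    output=True, grad_input=True, grad_weight=True
)
model = quantizer.prepare(model)

# Per the class docstring, torch.compile is required to see the advertised speedup.
model = torch.compile(model)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Fixed-shape inputs (e.g. packed sequences) avoid recompiles.
x = torch.randn(8, 256, device="cuda")
loss = model(x).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```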
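The three recipe hunks share one pattern: read an optional `mixed_precision` section, instantiate the component only when `enabled` is true, and strip the `enabled` key before instantiation. A standalone sketch of that pattern, assuming OmegaConf and `torchtune.config` as used in the recipes (the toy model is hypothetical):

```python
from omegaconf import OmegaConf
from torch import nn

from torchtune import config

# Mirrors the YAML block added to the configs in this diff, but with the
# feature turned on; the shipped configs keep enabled: false.
cfg = OmegaConf.create(
    {
        "mixed_precision": {
            "_component_": "torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer",
            "enabled": True,
        }
    }
)

# Illustrative stand-in for the recipe's model.
model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64)).cuda()

mixed_precision_cfg = cfg.get("mixed_precision", None)
if mixed_precision_cfg is not None and mixed_precision_cfg.get("enabled", False):
    mp_cfg = mixed_precision_cfg.copy()
    # the quantizer's __init__ does not accept `enabled`, so drop it first
    mp_cfg.pop("enabled", None)
    quantizer = config.instantiate(mp_cfg)
    model = quantizer.prepare(model)
```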