From 52b7a4c1cbf5aed61b5e20222b0a33318ce78e11 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Mon, 1 Apr 2024 14:32:04 -0700 Subject: [PATCH] Adding quantization support in torchtune Summary: Allows user to specify quantization_mode in generating model in full_finetune_single_device.py and inference with the quantized model in generate.py Test Plan: tested locally Reviewers: Subscribers: Tasks: Tags: --- recipes/configs/eleuther_eval.yaml | 13 ++-- recipes/configs/generate.yaml | 1 + recipes/configs/quant_generate.yaml | 29 +++++++++ recipes/configs/quantize.yaml | 18 ++++++ recipes/eleuther_eval.py | 30 ++++++--- recipes/generate.py | 19 ++++-- recipes/quantize.py | 97 +++++++++++++++++++++++++++++ requirements.txt | 2 +- torchtune/_recipe_registry.py | 9 +++ torchtune/utils/__init__.py | 5 ++ torchtune/utils/quantization.py | 52 ++++++++++++++++ 11 files changed, 256 insertions(+), 19 deletions(-) create mode 100644 recipes/configs/quant_generate.yaml create mode 100644 recipes/configs/quantize.yaml create mode 100644 recipes/quantize.py create mode 100644 torchtune/utils/quantization.py diff --git a/recipes/configs/eleuther_eval.yaml b/recipes/configs/eleuther_eval.yaml index e0e70f2556..3d1bdd5a9a 100644 --- a/recipes/configs/eleuther_eval.yaml +++ b/recipes/configs/eleuther_eval.yaml @@ -9,15 +9,15 @@ model: checkpointer: _component_: torchtune.utils.FullModelTorchTuneCheckpointer - checkpoint_dir: /tmp/llama/ - checkpoint_files: [finetuned_model.pt] - output_dir: /tmp/llama/ + checkpoint_dir: /tmp/llama2/ + checkpoint_files: [meta_model_0.4w.pt] + output_dir: /tmp/llama2/ model_type: LLAMA2 # Tokenizer tokenizer: _component_: torchtune.models.llama2.llama2_tokenizer - path: /tmp/llama/tokenizer.model + path: /tmp/llama2/tokenizer.model # Environment device: cuda @@ -25,6 +25,9 @@ dtype: bf16 seed: 217 # EleutherAI specific eval args -tasks: ["truthfulqa_mc2"] +tasks: ["truthfulqa_mc2", "hellaswag"] limit: null max_seq_length: 4096 + +# Quantization +quantization_mode: 4w diff --git a/recipes/configs/generate.yaml b/recipes/configs/generate.yaml index 2865f33d42..aa6635768d 100644 --- a/recipes/configs/generate.yaml +++ b/recipes/configs/generate.yaml @@ -16,6 +16,7 @@ checkpointer: device: cuda dtype: bf16 +quantization_mode: null seed: 1234 diff --git a/recipes/configs/quant_generate.yaml b/recipes/configs/quant_generate.yaml new file mode 100644 index 0000000000..f1157d2bb0 --- /dev/null +++ b/recipes/configs/quant_generate.yaml @@ -0,0 +1,29 @@ + +# Model arguments +model: + _component_: torchtune.models.llama2.llama2_7b + +checkpointer: + _component_: torchtune.utils.FullModelTorchTuneCheckpointer + checkpoint_dir: /tmp/llama2/ + checkpoint_files: [meta_model_0.f16a4w.pt] + output_dir: /tmp/llama2/ + model_type: LLAMA2 + +device: cpu +dtype: bf16 +seed: 1234 + +# Quantization Arguments +quantization_mode: f16a4w + +# Tokenizer arguments +tokenizer: + _component_: torchtune.models.llama2.llama2_tokenizer + path: /tmp/llama2/tokenizer.model + +# Generation arguments; defaults taken from gpt-fast +prompt: "Hello, my name is" +max_new_tokens: 300 +temperature: 0.8 +top_k: 300 diff --git a/recipes/configs/quantize.yaml b/recipes/configs/quantize.yaml new file mode 100644 index 0000000000..9e82bad2c5 --- /dev/null +++ b/recipes/configs/quantize.yaml @@ -0,0 +1,18 @@ + +# Model arguments +model: + _component_: torchtune.models.llama2.llama2_7b + +checkpointer: + _component_: torchtune.utils.FullModelMetaCheckpointer + checkpoint_dir: /tmp/llama2/ + checkpoint_files: [meta_model_0.pt] + 
output_dir: /tmp/llama2/
+  model_type: LLAMA2
+
+device: cuda
+dtype: bf16
+seed: 1234
+
+# Quantization Arguments
+quantization_mode: 4w
diff --git a/recipes/eleuther_eval.py b/recipes/eleuther_eval.py
index a08f5017a7..e2bdbf4d93 100644
--- a/recipes/eleuther_eval.py
+++ b/recipes/eleuther_eval.py
@@ -124,9 +124,9 @@ class EleutherEvalRecipe(EvalRecipeInterface):
     def __init__(self, cfg: DictConfig) -> None:
         self._cfg = cfg

-    def load_checkpoint(self, checkpointer_cfg: DictConfig) -> Dict[str, Any]:
+    def load_checkpoint(self, checkpointer_cfg: DictConfig, weights_only: bool = True) -> Dict[str, Any]:
         checkpointer = config.instantiate(checkpointer_cfg)
-        checkpoint_dict = checkpointer.load_checkpoint()
+        checkpoint_dict = checkpointer.load_checkpoint(weights_only=weights_only)
         return checkpoint_dict

     def setup(self) -> None:
@@ -134,10 +134,13 @@ def setup(self) -> None:
         self._dtype = utils.get_dtype(dtype=self._cfg.dtype)
         self._limit = self._cfg.limit
         self._tasks = list(self._cfg.tasks)
+        self._quantization_mode = self._cfg.quantization_mode

         utils.set_seed(seed=self._cfg.seed)

-        ckpt_dict = self.load_checkpoint(self._cfg.checkpointer)
+        # weights_only needs to be False when loading a quantized model
+        weights_only = (self._quantization_mode is None)
+        ckpt_dict = self.load_checkpoint(self._cfg.checkpointer, weights_only=weights_only)
         self._model = self._setup_model(
             model_cfg=self._cfg.model,
             model_state_dict=ckpt_dict[utils.MODEL_KEY],
@@ -150,13 +153,24 @@ def _setup_model(
         model_cfg: DictConfig,
         model_state_dict: Dict[str, Any],
     ) -> nn.Module:
-        with utils.set_default_dtype(self._dtype), self._device:
-            model = config.instantiate(model_cfg)
-
-        model.load_state_dict(model_state_dict)
+        if self._quantization_mode is not None:
+            # build on the meta device, quantize, then adopt the quantized weights via assign=True
+            with torch.device("meta"):
+                model = config.instantiate(model_cfg)
+            quantizer = utils.get_quantizer(self._quantization_mode)
+            model = quantizer.quantize(model)
+            model.load_state_dict(model_state_dict, assign=True)
+            utils.reset_parameters(model)
+            model = model.to(device=self._device, dtype=self._dtype)
+        else:
+            with utils.set_default_dtype(self._dtype), self._device:
+                model = config.instantiate(model_cfg)
+            model.load_state_dict(model_state_dict, assign=True)

         # Validate model was loaded in with the expected dtype.
- utils.validate_expected_param_dtype(model.named_parameters(), dtype=self._dtype) + # TODO: enable dtype checking for quantization + if self._quantization_mode is None: + utils.validate_expected_param_dtype(model.named_parameters(), dtype=self._dtype) logger.info(f"Model is initialized with precision {self._dtype}.") return model diff --git a/recipes/generate.py b/recipes/generate.py index 8a3043229d..42500c0c5e 100644 --- a/recipes/generate.py +++ b/recipes/generate.py @@ -28,16 +28,19 @@ class InferenceRecipe: def __init__(self, cfg: DictConfig) -> None: self._device = utils.get_device(device=cfg.device) self._dtype = utils.get_dtype(dtype=cfg.dtype) + self._quantization_mode = cfg.quantization_mode utils.set_seed(seed=cfg.seed) - def load_checkpoint(self, checkpointer_cfg: DictConfig) -> Dict[str, Any]: + def load_checkpoint(self, checkpointer_cfg: DictConfig, weights_only: bool = True) -> Dict[str, Any]: checkpointer = config.instantiate(checkpointer_cfg) - checkpoint_dict = checkpointer.load_checkpoint() + checkpoint_dict = checkpointer.load_checkpoint(weights_only=weights_only) return checkpoint_dict def setup(self, cfg: DictConfig) -> None: - ckpt_dict = self.load_checkpoint(cfg.checkpointer) + # weights_only needs to be False when loading a quantized model + weights_only = (self._quantization_mode is None) + ckpt_dict = self.load_checkpoint(cfg.checkpointer, weights_only=weights_only) self._model = self._setup_model( model_cfg=cfg.model, model_state_dict=ckpt_dict[utils.MODEL_KEY], @@ -52,10 +55,16 @@ def _setup_model( with utils.set_default_dtype(self._dtype), self._device: model = config.instantiate(model_cfg) - model.load_state_dict(model_state_dict) + if self._quantization_mode is not None: + quantizer = utils.get_quantizer(self._quantization_mode) + model = quantizer.quantize(model) + + model.load_state_dict(model_state_dict, assign=True) # Validate model was loaded in with the expected dtype. - utils.validate_expected_param_dtype(model.named_parameters(), dtype=self._dtype) + # TODO: enable this for quantization as well + if self._quantization_mode is None: + utils.validate_expected_param_dtype(model.named_parameters(), dtype=self._dtype) logger.info(f"Model is initialized with precision {self._dtype}.") # Ensure the cache is setup on the right device diff --git a/recipes/quantize.py b/recipes/quantize.py new file mode 100644 index 0000000000..1b2e95fc05 --- /dev/null +++ b/recipes/quantize.py @@ -0,0 +1,97 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import sys +import time +from typing import Any, Dict + +import torch +from omegaconf import DictConfig + +from torch import nn + +from torchtune import config, utils + +logger = utils.get_logger("DEBUG") + + +class QuantizationRecipe: + """ + Recipe for quantizing a Transformer-based LLM. 
+
+    Supported quantization modes are:
+    8w: int8 weight only per axis group quantization
+    4w: int4 weight only per axis group quantization
+    4w-gptq: int4 weight only per axis group quantization with GPTQ
+    after torch 2.3.0:
+    8da4w: int8 dynamic activation quantization and int4 weight per axis group quantization
+    8da4w-gptq: int8 dynamic activation quantization and int4 weight per axis group quantization with GPTQ
+    """
+
+    def __init__(self, cfg: DictConfig) -> None:
+        self._device = utils.get_device(device=cfg.device)
+        self._dtype = utils.get_dtype(dtype=cfg.dtype)
+        self._quantization_mode = cfg.quantization_mode
+        utils.set_seed(seed=cfg.seed)
+
+    def load_checkpoint(self, checkpointer_cfg: DictConfig) -> Dict[str, Any]:
+        self._checkpointer = config.instantiate(checkpointer_cfg)
+        checkpoint_dict = self._checkpointer.load_checkpoint()
+        return checkpoint_dict
+
+    def setup(self, cfg: DictConfig) -> None:
+        ckpt_dict = self.load_checkpoint(cfg.checkpointer)
+        self._model = self._setup_model(
+            model_cfg=cfg.model,
+            model_state_dict=ckpt_dict[utils.MODEL_KEY],
+        )
+
+    def _setup_model(
+        self,
+        model_cfg: DictConfig,
+        model_state_dict: Dict[str, Any],
+    ) -> nn.Module:
+        with utils.set_default_dtype(self._dtype), self._device:
+            model = config.instantiate(model_cfg)
+
+        model.load_state_dict(model_state_dict, assign=True)
+
+        # Validate model was loaded in with the expected dtype.
+        utils.validate_expected_param_dtype(model.named_parameters(), dtype=self._dtype)
+        logger.info(f"Model is initialized with precision {self._dtype}.")
+
+        # Note: KV caches are not needed for quantization, so setup_caches is skipped:
+        # with self._device:
+        #     model.setup_caches(max_batch_size=1, dtype=self._dtype)
+
+        return model
+
+    @torch.no_grad()
+    def quantize(self, cfg: DictConfig):
+        quantizer = utils.get_quantizer(self._quantization_mode)
+        t0 = time.perf_counter()
+        self._model = quantizer.quantize(self._model)
+        t = time.perf_counter() - t0
+        logger.info(
+            f"Time for quantization: {t:.02f} sec"
+        )
+        logger.info(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")
+
+    def save_checkpoint(self, cfg: DictConfig):
+        ckpt_dict = self._model.state_dict()
+        file_name = cfg.checkpointer.checkpoint_files[0].split(".")[0]
+        torch.save(ckpt_dict, cfg.checkpointer.output_dir + file_name + "."
+            + self._quantization_mode + ".pt")
+
+
+@config.parse
+def main(cfg: DictConfig) -> None:
+    recipe = QuantizationRecipe(cfg=cfg)
+    recipe.setup(cfg=cfg)
+    recipe.quantize(cfg=cfg)
+    recipe.save_checkpoint(cfg=cfg)
+
+
+if __name__ == "__main__":
+    sys.exit(main())
diff --git a/requirements.txt b/requirements.txt
index ffcecfe799..2d645039e1 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,4 +8,4 @@ tqdm
 omegaconf

 # Quantization
-torchao-nightly==2024.3.29
+torchao-nightly==2024.4.2
diff --git a/torchtune/_recipe_registry.py b/torchtune/_recipe_registry.py
index 20e578e306..3e4c572b5f 100644
--- a/torchtune/_recipe_registry.py
+++ b/torchtune/_recipe_registry.py
@@ -86,6 +86,7 @@ class Recipe:
         file_path="generate.py",
         configs=[
             Config(name="generate", file_path="generate.yaml"),
+            Config(name="quant_generate", file_path="quant_generate.yaml"),
         ],
         supports_distributed=False,
     ),
@@ -97,6 +98,14 @@ class Recipe:
         ],
         supports_distributed=False,
     ),
+    Recipe(
+        name="quantize",
+        file_path="quantize.py",
+        configs=[
+            Config(name="quantize", file_path="quantize.yaml"),
+        ],
+        supports_distributed=False,
+    ),
 ]
diff --git a/torchtune/utils/__init__.py b/torchtune/utils/__init__.py
index 666711a34c..b08e2fdbab 100644
--- a/torchtune/utils/__init__.py
+++ b/torchtune/utils/__init__.py
@@ -53,6 +53,10 @@
     validate_expected_param_dtype,
 )
 from .seed import set_seed
+from .quantization import (
+    get_quantizer,
+    reset_parameters,
+)

 __all__ = [
     "save_checkpoint",
@@ -80,4 +84,5 @@
     "OptimizerInBackwardWrapper",
     "create_optim_in_bwd_wrapper",
     "register_optim_in_bwd_hooks",
+    "get_quantizer",
 ]
diff --git a/torchtune/utils/quantization.py b/torchtune/utils/quantization.py
new file mode 100644
index 0000000000..e3011c7b21
--- /dev/null
+++ b/torchtune/utils/quantization.py
@@ -0,0 +1,52 @@
+from typing import Any
+import torch
+from torchao.quantization.quant_api import (
+    change_linear_weights_to_int4_woqtensors,
+    change_linear_weights_to_int8_woqtensors,
+    Quantizer,
+    Int4WeightOnlyGPTQQuantizer,
+)
+from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3
+
+class Int4WeightOnlyQuantizer(Quantizer):
+    def quantize(
+        self, model: torch.nn.Module, *args: Any, **kwargs: Any
+    ) -> torch.nn.Module:
+        change_linear_weights_to_int4_woqtensors(model)
+        return model
+
+
+class Int8WeightOnlyQuantizer(Quantizer):
+    def quantize(
+        self, model: torch.nn.Module, *args: Any, **kwargs: Any
+    ) -> torch.nn.Module:
+        change_linear_weights_to_int8_woqtensors(model)
+        return model
+
+
+def get_quantizer(quantization_mode: str, *args, **kwargs):
+    qmode_to_quantizer = {
+        # weight-only modes, available on all supported torch versions
+        "4w": Int4WeightOnlyQuantizer,
+        "8w": Int8WeightOnlyQuantizer,
+        "4w-gptq": Int4WeightOnlyGPTQQuantizer,
+    }
+    if TORCH_VERSION_AFTER_2_3:
+        from torchao.quantization.quant_api import (
+            Int8DynActInt4WeightQuantizer,
+            Int8DynActInt4WeightGPTQQuantizer,
+        )
+
+        qmode_to_quantizer |= {
+            "8da4w": Int8DynActInt4WeightQuantizer,
+            # TODO: merge into 8da4w
+            "8da4w-gptq": Int8DynActInt4WeightGPTQQuantizer,
+        }
+    if quantization_mode not in qmode_to_quantizer:
+        raise ValueError(f"Unsupported quantization mode: {quantization_mode}, supported modes are: {list(qmode_to_quantizer.keys())}")
+    return qmode_to_quantizer[quantization_mode](*args, **kwargs)
+
+def reset_parameters(model: torch.nn.Module):
+    for name, module in model.named_modules():
+        if hasattr(module, "reset_parameters"):
+            module.reset_parameters()
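
Usage sketch (illustrative, not part of the diff above): the snippet below walks through the core flow that recipes/quantize.py automates, using only APIs touched by this change (torchtune.utils.get_quantizer, utils.set_default_dtype, the llama2_7b builder). The checkpoint paths and the "4w" mode are assumptions taken from the example configs, and the checkpointer step is elided.

    # Sketch only: the quantize-and-save flow from recipes/quantize.py.
    import torch

    from torchtune import utils
    from torchtune.models.llama2 import llama2_7b

    # Build the model in bf16 on CUDA, mirroring quantize.yaml (device: cuda, dtype: bf16).
    with utils.set_default_dtype(torch.bfloat16), torch.device("cuda"):
        model = llama2_7b()

    # (In the recipe, FullModelMetaCheckpointer loads the fine-tuned weights into
    #  `model` at this point; that step is elided here.)

    # "4w" maps to Int4WeightOnlyQuantizer in torchtune/utils/quantization.py.
    quantizer = utils.get_quantizer("4w")
    model = quantizer.quantize(model)

    # The recipe names the output "<input stem>.<mode>.pt", e.g. meta_model_0.4w.pt,
    # which is the file eleuther_eval.yaml points at.
    torch.save(model.state_dict(), "/tmp/llama2/meta_model_0.4w.pt")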
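
A second sketch, also illustrative rather than part of the patch, shows how the quantized checkpoint is consumed again, mirroring the quantized branch added to generate.py and eleuther_eval.py; the path and mode are the same assumptions as above, and torch.load stands in for the checkpointer.

    # Sketch only: load a quantized checkpoint for inference/eval.
    import torch

    from torchtune import utils
    from torchtune.models.llama2 import llama2_7b

    with utils.set_default_dtype(torch.bfloat16), torch.device("cuda"):
        model = llama2_7b()

    # Quantize first so the module structure matches the saved quantized tensors,
    # then adopt them with assign=True. weights_only must be False because the
    # quantized state dict contains tensor subclasses rather than plain tensors.
    quantizer = utils.get_quantizer("4w")
    model = quantizer.quantize(model)
    state_dict = torch.load("/tmp/llama2/meta_model_0.4w.pt", weights_only=False)
    model.load_state_dict(state_dict, assign=True)
    model.eval()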