Adding quantization support in torchtune
Summary: Allows the user to specify quantization_mode when generating a model in full_finetune_single_device.py and to run inference with the quantized model in generate.py.

Test Plan: tested locally
Commit 7137b6d (1 parent: 32d66df)
Showing 11 changed files with 349 additions and 16 deletions.
@@ -29,3 +29,5 @@ prompt: "Hello, my name is"
 max_new_tokens: 300
 temperature: 0.8
 top_k: 300
+
+quantizer: null
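(Illustration, not part of this diff.) To actually generate with a quantized checkpoint, quantizer would be set to the same mode used during quantization, and the checkpointer pointed at the quantized file; a minimal sketch, assuming the int4 weight-only mode and the base-name-plus-mode file naming used by the quantize recipe added below:

quantizer:
  _component_: torchtune.utils.quantization.Int4WeightOnlyQuantizer
  groupsize: 256
# and point the checkpointer at the quantized file produced by the quantize recipe,
# e.g. checkpoint_files: [meta_model_0.4w.pt]  (hypothetical name: original base name + "4w" mode suffix)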
@@ -0,0 +1,52 @@
# Config for QuantizationRecipe in quantize.py
#
# To launch, run the following command from the root torchtune directory:
#    tune quantize --config quantize
#
# Supported quantization modes are:
# 8w:
#    torchtune.utils.quantization.Int8WeightOnlyQuantizer
#    int8 weight only per axis group quantization
#
# 4w:
#    torchtune.utils.quantization.Int4WeightOnlyQuantizer
#    int4 weight only per axis group quantization
#    Args:
#      `groupsize` (int): a parameter of int4 weight only quantization,
#      it refers to the size of quantization groups which get independent quantization parameters,
#      e.g. 32, 64, 128, 256; smaller numbers mean more fine-grained and higher accuracy
#
# 4w-gptq:
#    torchtune.utils.quantization.Int4WeightOnlyGPTQQuantizer
#    int4 weight only per axis group quantization with GPTQ
#    Args:
#      `groupsize`: see description in `4w`
#      `blocksize`: GPTQ is applied to a 'block' of columns at a time;
#      larger blocks trade off memory for perf, recommended to be a constant
#      multiple of groupsize.
#      `percdamp`: GPTQ stabilization hyperparameter, recommended to be .01
#
# future note: blocksize and percdamp should not have to be 'known' by users by default.
# Similar to the momentum constant in MovingAverageObserver, it can be tuned,
# but 99% of users don't need to pay attention to it. blocksize should probably be set at
# max(`groupsize`, 128) and percdamp at .01

#
# Model arguments
model:
  _component_: torchtune.models.llama2.llama2_7b

checkpointer:
  _component_: torchtune.utils.FullModelMetaCheckpointer
  checkpoint_dir: /tmp/llama2/
  checkpoint_files: [meta_model_0.pt]
  output_dir: /tmp/llama2/
  model_type: LLAMA2

device: cuda
dtype: bf16
seed: 1234

quantizer:
  _component_: torchtune.utils.quantization.Int4WeightOnlyQuantizer
  groupsize: 256
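For the 4w-gptq mode, the quantizer block would also carry the GPTQ arguments described in the header comments; a minimal sketch, where the values simply follow the guidance above (blocksize at max(groupsize, 128), percdamp at .01) rather than anything set in this commit:

quantizer:
  _component_: torchtune.utils.quantization.Int4WeightOnlyGPTQQuantizer
  groupsize: 256
  blocksize: 256   # illustrative: a constant multiple of groupsize, here max(groupsize, 128)
  percdamp: 0.01   # illustrative: the recommended stabilization value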
@@ -0,0 +1,88 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import sys
import time
from typing import Any, Dict

import torch
from omegaconf import DictConfig

from torch import nn

from torchtune import config, utils

logger = utils.get_logger("DEBUG")


class QuantizationRecipe:
    """
    Recipe for quantizing a Transformer-based LLM.
    Uses quantizer classes from torchao to quantize a model.
    Please refer to `recipes/configs/quantize.yaml` for supported quantizers and how to use them.
    """

    def __init__(self, cfg: DictConfig) -> None:
        self._device = utils.get_device(device=cfg.device)
        self._dtype = utils.get_dtype(dtype=cfg.dtype)
        self._quantizer = config.instantiate(cfg.quantizer)
        self._quantization_mode = utils.get_quantizer_mode(self._quantizer)
        utils.set_seed(seed=cfg.seed)

    def load_checkpoint(self, checkpointer_cfg: DictConfig) -> Dict[str, Any]:
        self._checkpointer = config.instantiate(checkpointer_cfg)
        checkpoint_dict = self._checkpointer.load_checkpoint()
        return checkpoint_dict

    def setup(self, cfg: DictConfig) -> None:
        ckpt_dict = self.load_checkpoint(cfg.checkpointer)
        self._model = self._setup_model(
            model_cfg=cfg.model,
            model_state_dict=ckpt_dict[utils.MODEL_KEY],
        )

    def _setup_model(
        self,
        model_cfg: DictConfig,
        model_state_dict: Dict[str, Any],
    ) -> nn.Module:
        with utils.set_default_dtype(self._dtype), self._device:
            model = config.instantiate(model_cfg)

        model.load_state_dict(model_state_dict, assign=True)

        # Validate model was loaded in with the expected dtype.
        utils.validate_expected_param_dtype(model.named_parameters(), dtype=self._dtype)
        logger.info(f"Model is initialized with precision {self._dtype}.")
        return model

    @torch.no_grad()
    def quantize(self, cfg: DictConfig):
        t0 = time.perf_counter()
        self._model = self._quantizer.quantize(self._model)
        t = time.perf_counter() - t0
        logger.info(f"Time for quantization: {t:.02f} sec")
        logger.info(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")

    def save_checkpoint(self, cfg: DictConfig):
        ckpt_dict = self._model.state_dict()
        # Name the output after the input checkpoint plus the quantization mode, e.g. meta_model_0.4w.pt
        file_name = cfg.checkpointer.checkpoint_files[0].split(".")[0]
        quantized_file_name = (
            cfg.checkpointer.output_dir + file_name + "." + self._quantization_mode + ".pt"
        )
        torch.save(ckpt_dict, quantized_file_name)
        logger.info(f"Saved quantized model to {quantized_file_name}")


@config.parse
def main(cfg: DictConfig) -> None:
    recipe = QuantizationRecipe(cfg=cfg)
    recipe.setup(cfg=cfg)
    recipe.quantize(cfg=cfg)
    recipe.save_checkpoint(cfg=cfg)


if __name__ == "__main__":
    sys.exit(main())
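On the inference side, the quantized state dict written by save_checkpoint has a different layout from the original weights, so the model would be quantized with the same quantizer before loading it; a minimal sketch reusing the APIs from the recipe above, with the config path and output file name being hypothetical:

import torch
from omegaconf import OmegaConf
from torchtune import config, utils

# Hypothetical path to the config shown above
cfg = OmegaConf.load("recipes/configs/quantize.yaml")

# Build the model skeleton on the target device and dtype, as in _setup_model
with utils.set_default_dtype(utils.get_dtype("bf16")), utils.get_device("cuda"):
    model = config.instantiate(cfg.model)

# Apply the same quantization transform so parameter shapes match the saved state dict
quantizer = config.instantiate(cfg.quantizer)
model = quantizer.quantize(model)

# Hypothetical file name produced by QuantizationRecipe.save_checkpoint
state_dict = torch.load("/tmp/llama2/meta_model_0.4w.pt", map_location="cpu")
model.load_state_dict(state_dict, assign=True)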
@@ -8,4 +8,4 @@ tqdm
 omegaconf

 # Quantization
-torchao-nightly==2024.3.29
+torchao==0.1