Adding quantization support in torchtune
Summary:
Allows the user to specify quantization_mode when generating a model in full_finetune_single_device.py
and to run inference with the quantized model in generate.py.

Test Plan:
tested locally

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 committed Apr 4, 2024
1 parent 32d66df commit 7137b6d
Showing 11 changed files with 349 additions and 16 deletions.
115 changes: 115 additions & 0 deletions recipes/README.md
@@ -111,3 +111,118 @@ python3 convert.py /tmp/llama2/meta_model_0.pt --ctx 4096
```

This will output a gguf file in the same precision which can be used for running inference.

### Architecture Optimization

TorchTune integrates with [`torchao`](https://github.com/pytorch-labs/ao/) for architecture optimization techniques, including quantization and sparsity. Currently only a subset of quantization techniques is integrated; see `recipes/configs/quantize.yaml` for more details.

#### Quantize
To quantize a model (the default is int4 weight-only quantization):
```
tune run quantize --config quantize
```
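
The quantizer used by this command is configured in `recipes/configs/quantize.yaml`. As a minimal sketch, the default int4 weight-only settings look like the following (mirroring the shipped config; `groupsize` can be tuned for the accuracy/size trade-off):

```
quantizer:
  _component_: torchtune.utils.quantization.Int4WeightOnlyQuantizer
  groupsize: 256
```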

#### Eval
To evaluate a quantized model, add the following to `recipes/configs/eleuther_eval.yaml`:


```
# make sure to change the checkpointer component
checkpointer:
  _component_: torchtune.utils.FullModelTorchTuneCheckpointer

# Quantization specific args
quantizer:
  _component_: torchtune.utils.Int4WeightOnlyQuantizer
  groupsize: 256
```
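
Note that `checkpoint_files` under the checkpointer should point at the quantized checkpoint written by the quantize recipe. A minimal sketch, assuming the default Meta checkpoint name and the `4w` quantization mode (which produces `meta_model_0.4w.pt`):

```
checkpointer:
  _component_: torchtune.utils.FullModelTorchTuneCheckpointer
  checkpoint_files: [meta_model_0.4w.pt]
```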

and run the eval command:
```
tune run eleuther_eval --config eleuther_eval
```

#### Generate
Make the following changes in `recipes/configs/generate.yaml`:
```
# Model arguments
# make sure to change the checkpointer component
checkpointer:
  _component_: torchtune.utils.FullModelTorchTuneCheckpointer
  checkpoint_files: [meta_model_0.4w.pt]

# Quantization Arguments
quantizer:
  _component_: torchtune.utils.Int4WeightOnlyQuantizer
  groupsize: 256
```
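
Here `meta_model_0.4w.pt` is the checkpoint written by the quantize recipe; the quantized weights are saved as `<original checkpoint name>.<quantization mode>.pt`, so `4w` for int4 weight-only quantization.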

and run the generate command:
```
tune run generate --config generate
```

#### GPTQ

For GPTQ specifically, the following changes are needed:

`recipes/configs/quantize.yaml`

We'll publish doc pages for the different quantizers in torchao a bit later. For the int4 weight-only GPTQ quantizer, here is a brief description of what each argument means:

```
quantizer:
  _component_: torchtune.utils.quantization.Int4WeightOnlyGPTQQuantizer
  blocksize: 128
  percdamp: 0.01
  groupsize: 256

tokenizer:
  _component_: torchtune.models.llama2.llama2_tokenizer
  path: /tmp/llama2/tokenizer.model
```
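
As a rule of thumb (see the notes in `recipes/configs/quantize.yaml`), `blocksize` can usually be left at `max(groupsize, 128)` and `percdamp` at 0.01.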

`recipes/quantize.py`

```
def quantize(self, cfg: DictConfig):
    from torchao.quantization.GPTQ import InputRecorder

    tokenizer = config.instantiate(cfg.tokenizer)
    calibration_seq_length = 100
    calibration_tasks = ['wikitext']
    # Record example inputs used for GPTQ calibration
    inputs = InputRecorder(
        tokenizer,
        calibration_seq_length,
        vocab_size=self._model.tok_embeddings.weight.shape[0],
        device="cpu",
    ).record_inputs(
        calibration_tasks,
        5,  # calibration_limit
    ).get_inputs()
    t0 = time.perf_counter()
    self._model = self._quantizer.quantize(self._model, inputs)
    ...
```

Run the quantize command:
```
tune run quantize --config quantize
```
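
As with the other modes, the GPTQ-quantized weights are written next to the original checkpoint with the quantization mode as a suffix, and downstream configs should point at that file. A sketch, assuming the default checkpoint name and the `4w-gptq` mode:

```
checkpointer:
  _component_: torchtune.utils.FullModelTorchTuneCheckpointer
  checkpoint_files: [meta_model_0.4w-gptq.pt]
```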

`recipes/eleuther_eval.py`

```
# to skip running through GPTQ, change model = quantizer.quantize(model) to:
model = quantizer._convert_for_runtime(model)
```

`recipes/configs/eleuther_eval.yaml`
```
quantizer:
  _component_: torchtune.utils.quantization.Int4WeightOnlyGPTQQuantizer
  blocksize: 128
  percdamp: 0.01
  groupsize: 256
```
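
and run the same eval command as before (`tune run eleuther_eval --config eleuther_eval`).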

7 changes: 6 additions & 1 deletion recipes/configs/eleuther_eval.yaml
@@ -25,6 +25,11 @@ dtype: bf16
seed: 217

# EleutherAI specific eval args
tasks: ["truthfulqa_mc2"]
tasks: ["truthfulqa_mc2", "hellaswag", "wikitext"]
limit: null
max_seq_length: 4096

# Quantization specific args
quantizer: null


2 changes: 2 additions & 0 deletions recipes/configs/generate.yaml
@@ -29,3 +29,5 @@ prompt: "Hello, my name is"
max_new_tokens: 300
temperature: 0.8
top_k: 300

quantizer: null
52 changes: 52 additions & 0 deletions recipes/configs/quantize.yaml
@@ -0,0 +1,52 @@
# Config for QuantizationRecipe in quantize.py
#
# To launch, run the following command from root torchtune directory:
# tune quantize --config quantize
#
# Supported quantization modes are:
# 8w:
#   torchtune.utils.quantization.Int8WeightOnlyQuantizer
#   int8 weight only per axis group quantization
#
# 4w:
#   torchtune.utils.quantization.Int4WeightOnlyQuantizer
#   int4 weight only per axis group quantization
#   Args:
#     `groupsize` (int): a parameter of int4 weight only quantization,
#       it refers to the size of quantization groups which get independent quantization parameters,
#       e.g. 32, 64, 128, 256; smaller numbers mean more fine-grained quantization and higher accuracy
#
# 4w-gptq:
#   torchtune.utils.quantization.Int4WeightOnlyGPTQQuantizer
#   int4 weight only per axis group quantization with GPTQ
#   Args:
#     `groupsize`: see description in `4w`
#     `blocksize`: GPTQ is applied to a 'block' of columns at a time;
#       larger blocks trade off memory for perf, recommended to be a constant
#       multiple of groupsize.
#     `percdamp`: GPTQ stabilization hyperparameter, recommended to be .01
#
# future note: blocksize and percdamp should not have to be 'known' by users by default.
# Similar to the momentum constant in MovingAverageObserver, it can be tuned,
# but 99% of users don't need to pay attention to it. blocksize should probably be set to
# max(`groupsize`, 128) and percdamp to .01

#
# Model arguments
model:
  _component_: torchtune.models.llama2.llama2_7b

checkpointer:
  _component_: torchtune.utils.FullModelMetaCheckpointer
  checkpoint_dir: /tmp/llama2/
  checkpoint_files: [meta_model_0.pt]
  output_dir: /tmp/llama2/
  model_type: LLAMA2

device: cuda
dtype: bf16
seed: 1234

quantizer:
  _component_: torchtune.utils.quantization.Int4WeightOnlyQuantizer
  groupsize: 256
30 changes: 23 additions & 7 deletions recipes/eleuther_eval.py
@@ -124,20 +124,28 @@ class EleutherEvalRecipe(EvalRecipeInterface):
def __init__(self, cfg: DictConfig) -> None:
self._cfg = cfg

def load_checkpoint(self, checkpointer_cfg: DictConfig) -> Dict[str, Any]:
def load_checkpoint(self, checkpointer_cfg: DictConfig, weights_only: bool = True) -> Dict[str, Any]:
checkpointer = config.instantiate(checkpointer_cfg)
checkpoint_dict = checkpointer.load_checkpoint()
checkpoint_dict = checkpointer.load_checkpoint(weights_only=weights_only)
return checkpoint_dict

def setup(self) -> None:
self._device = utils.get_device(device=self._cfg.device)
self._dtype = utils.get_dtype(dtype=self._cfg.dtype)
self._limit = self._cfg.limit
self._tasks = list(self._cfg.tasks)
self._quantizer = config.instantiate(self._cfg.quantizer)
self._quantization_mode = utils.get_quantizer_mode(self._quantizer)

utils.set_seed(seed=self._cfg.seed)

ckpt_dict = self.load_checkpoint(self._cfg.checkpointer)
checkpointer = config.instantiate(self._cfg.checkpointer)
if self._quantization_mode is None:
ckpt_dict = checkpointer.load_checkpoint()
else:
# weights_only needs to be False when loading a quantized model
ckpt_dict = checkpointer.load_checkpoint(weights_only=False)

self._model = self._setup_model(
model_cfg=self._cfg.model,
model_state_dict=ckpt_dict[utils.MODEL_KEY],
@@ -150,13 +158,21 @@ def _setup_model(
model_cfg: DictConfig,
model_state_dict: Dict[str, Any],
) -> nn.Module:
with utils.set_default_dtype(self._dtype), self._device:
if self._quantization_mode is not None:
model = config.instantiate(model_cfg)

model.load_state_dict(model_state_dict)
model = self._quantizer.quantize(model)
model = model.to(device=self._device, dtype=self._dtype)
model.load_state_dict(model_state_dict)
else:
print("non quant path")
with utils.set_default_dtype(self._dtype), self._device:
model = config.instantiate(model_cfg)
model.load_state_dict(model_state_dict)

# Validate model was loaded in with the expected dtype.
utils.validate_expected_param_dtype(model.named_parameters(), dtype=self._dtype)
# TODO (before landing): enable dtype checking for quantization
if self._quantization_mode is None:
utils.validate_expected_param_dtype(model.named_parameters(), dtype=self._dtype)
logger.info(f"Model is initialized with precision {self._dtype}.")
return model

23 changes: 16 additions & 7 deletions recipes/generate.py
@@ -28,16 +28,19 @@ class InferenceRecipe:
def __init__(self, cfg: DictConfig) -> None:
self._device = utils.get_device(device=cfg.device)
self._dtype = utils.get_dtype(dtype=cfg.dtype)
self._quantizer = config.instantiate(cfg.quantizer)
self._quantization_mode = utils.get_quantizer_mode(self._quantizer)

utils.set_seed(seed=cfg.seed)

def load_checkpoint(self, checkpointer_cfg: DictConfig) -> Dict[str, Any]:
checkpointer = config.instantiate(checkpointer_cfg)
checkpoint_dict = checkpointer.load_checkpoint()
return checkpoint_dict

def setup(self, cfg: DictConfig) -> None:
ckpt_dict = self.load_checkpoint(cfg.checkpointer)
checkpointer = config.instantiate(cfg.checkpointer)
if self._quantization_mode is None:
ckpt_dict = checkpointer.load_checkpoint()
else:
# weights_only needs to be False when loading a quantized model
ckpt_dict = checkpointer.load_checkpoint(weights_only=False)

self._model = self._setup_model(
model_cfg=cfg.model,
model_state_dict=ckpt_dict[utils.MODEL_KEY],
@@ -52,10 +55,16 @@ def _setup_model(
with utils.set_default_dtype(self._dtype), self._device:
model = config.instantiate(model_cfg)

if self._quantization_mode is not None:
model = self._quantizer.quantize(model)
model = model.to(device=self._device, dtype=self._dtype)

model.load_state_dict(model_state_dict)

# Validate model was loaded in with the expected dtype.
utils.validate_expected_param_dtype(model.named_parameters(), dtype=self._dtype)
# TODO: enable this for quantization as well
if self._quantization_mode is None:
utils.validate_expected_param_dtype(model.named_parameters(), dtype=self._dtype)
logger.info(f"Model is initialized with precision {self._dtype}.")

# Ensure the cache is setup on the right device
88 changes: 88 additions & 0 deletions recipes/quantize.py
@@ -0,0 +1,88 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import sys
import time
from typing import Any, Dict

import torch
from omegaconf import DictConfig

from torch import nn

from torchtune import config, utils

logger = utils.get_logger("DEBUG")


class QuantizationRecipe:
    """
    Recipe for quantizing a Transformer-based LLM.
    Uses quantizer classes from torchao to quantize a model.
    Please refer to `recipes/configs/quantize.yaml` for supported quantizers and how to use them.
    """

    def __init__(self, cfg: DictConfig) -> None:
        self._device = utils.get_device(device=cfg.device)
        self._dtype = utils.get_dtype(dtype=cfg.dtype)
        self._quantizer = config.instantiate(cfg.quantizer)
        self._quantization_mode = utils.get_quantizer_mode(self._quantizer)
        utils.set_seed(seed=cfg.seed)

    def load_checkpoint(self, checkpointer_cfg: DictConfig) -> Dict[str, Any]:
        self._checkpointer = config.instantiate(checkpointer_cfg)
        checkpoint_dict = self._checkpointer.load_checkpoint()
        return checkpoint_dict

    def setup(self, cfg: DictConfig) -> None:
        ckpt_dict = self.load_checkpoint(cfg.checkpointer)
        self._model = self._setup_model(
            model_cfg=cfg.model,
            model_state_dict=ckpt_dict[utils.MODEL_KEY],
        )

    def _setup_model(
        self,
        model_cfg: DictConfig,
        model_state_dict: Dict[str, Any],
    ) -> nn.Module:
        with utils.set_default_dtype(self._dtype), self._device:
            model = config.instantiate(model_cfg)

        model.load_state_dict(model_state_dict, assign=True)

        # Validate model was loaded in with the expected dtype.
        utils.validate_expected_param_dtype(model.named_parameters(), dtype=self._dtype)
        logger.info(f"Model is initialized with precision {self._dtype}.")
        return model

    @torch.no_grad()
    def quantize(self, cfg: DictConfig):
        t0 = time.perf_counter()
        self._model = self._quantizer.quantize(self._model)
        t = time.perf_counter() - t0
        logger.info(
            f"Time for quantization: {t:.02f} sec"
        )
        logger.info(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")

    def save_checkpoint(self, cfg: DictConfig):
        ckpt_dict = self._model.state_dict()
        # The quantized weights are saved as <checkpoint name>.<quantization mode>.pt
        file_name = cfg.checkpointer.checkpoint_files[0].split(".")[0]
        quantized_file_name = cfg.checkpointer.output_dir + file_name + "." + self._quantization_mode + ".pt"
        torch.save(ckpt_dict, quantized_file_name)
        logger.info(f"Saved quantized model to {quantized_file_name}")


@config.parse
def main(cfg: DictConfig) -> None:
    recipe = QuantizationRecipe(cfg=cfg)
    recipe.setup(cfg=cfg)
    recipe.quantize(cfg=cfg)
    recipe.save_checkpoint(cfg=cfg)


if __name__ == "__main__":
    sys.exit(main())
2 changes: 1 addition & 1 deletion requirements.txt
@@ -8,4 +8,4 @@ tqdm
omegaconf

# Quantization
torchao-nightly==2024.3.29
torchao==0.1