Merge pull request #4 from vllm-project/fp8_support
Add FP8 Support
Sara Adkins authored Jun 25, 2024
2 parents 7d358a1 + fab9363 commit dcca2be
Showing 7 changed files with 86 additions and 10 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -801,3 +801,4 @@ nm_temp_test_logs/*
 sparse_logs/*
 wandb/
 output_finetune/
+env_log.json
49 changes: 49 additions & 0 deletions examples/quantization/llama7b_fp8_quantization.py
@@ -0,0 +1,49 @@
import torch
from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationScheme,
    QuantizationType,
)
from datasets import load_dataset
from transformers import AutoTokenizer

from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.transformers import SparseAutoModelForCausalLM, oneshot

model_stub = "meta-llama/Meta-Llama-3-8B-Instruct"
output_dir = "Meta-Llama-3-8B-Instruct-FP8-Compressed"
num_calibration_samples = 512

tokenizer = AutoTokenizer.from_pretrained(model_stub, use_fast=True)
tokenizer.pad_token = tokenizer.eos_token


def preprocess(batch):
    text = tokenizer.apply_chat_template(batch["messages"], tokenize=False)
    tokenized = tokenizer(text, padding=True, truncation=True, max_length=2048)
    return tokenized


ds = load_dataset("mgoin/ultrachat_2k", split="train_sft")
examples = ds.map(preprocess, remove_columns=ds.column_names)

quant_args = QuantizationArgs(type=QuantizationType.FLOAT)
quant_scheme = QuantizationScheme(
    weights=quant_args, input_activations=quant_args, targets=["Linear"]
)
recipe = QuantizationModifier(
    config_groups={"group_0": quant_scheme}, ignore=["lm_head"]
)

model = SparseAutoModelForCausalLM.from_pretrained(
    model_stub, torch_dtype=torch.bfloat16, device_map="auto"
)

oneshot(
    model=model,
    dataset=examples,
    recipe=recipe,
    output_dir=output_dir,
    num_calibration_samples=num_calibration_samples,
    save_compressed=True,
)
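
As an aside, here is a minimal sketch (not part of this commit) of how the compressed checkpoint written to output_dir by the script above might be reloaded for a quick smoke test. Whether the FP8 weights decompress automatically on load is assumed, and the prompt, generation settings, and tokenizer handling are placeholders.

# Sketch only (assumptions as noted above); reloads the directory written by the
# example script and runs a short generation as a sanity check.
from transformers import AutoTokenizer

from llmcompressor.transformers import SparseAutoModelForCausalLM

output_dir = "Meta-Llama-3-8B-Instruct-FP8-Compressed"  # matches the script above
model = SparseAutoModelForCausalLM.from_pretrained(output_dir, device_map="auto")
# the tokenizer may not be saved next to the model, so load it from the original stub
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

inputs = tokenizer("The capital of France is", return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=20)[0]))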
20 changes: 12 additions & 8 deletions src/llmcompressor/transformers/compression/quantization_format.py
@@ -30,9 +30,11 @@ def infer_quantization_format(
         return quantization_format

     if save_compressed:
-        quant_depths = _get_quant_depths(model)
-        if quant_depths == [4]:  # save packed if everything is int4
+        quant_types = _get_quant_types(model)
+        if quant_types == ["int4"]:  # save packed if everything is int4
             return CompressionFormat.pack_quantized
+        elif quant_types == ["float8"]:
+            return CompressionFormat.float_quantized

         # otherwise just quantize to int8
         return CompressionFormat.int_quantized
@@ -41,17 +43,19 @@ def infer_quantization_format(
         return None


-def _get_quant_depths(model):
+def _get_quant_types(model):
     """
-    Gets a list of all the quantized bit depths present in model
+    Gets a list of all the quantized types present in model
     """
-    quant_depths = []
+    quant_info = []
     for _, submodule in iter_named_leaf_modules(model):
         if is_module_quantized(submodule):
             weight_scheme = submodule.quantization_scheme.weights
             if weight_scheme is not None:
                 weight_bit_depth = weight_scheme.num_bits
-                if weight_bit_depth not in quant_depths:
-                    quant_depths.append(weight_bit_depth)
+                weight_type = weight_scheme.type
+                weight_info = f"{weight_type}{weight_bit_depth}"
+                if weight_info not in quant_info:
+                    quant_info.append(weight_info)

-    return quant_depths
+    return quant_info
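
To make the new branch concrete, here is a tiny standalone illustration of the selection logic in the hunk above; the strings stand in for CompressionFormat members and choose_format is a toy restatement written for this page, not the library function.

# Toy restatement of the save_compressed branch of infer_quantization_format above.
# Strings stand in for CompressionFormat members; illustration only.
def choose_format(quant_types, save_compressed=True):
    if not save_compressed:
        return None  # format is inferred from the config instead
    if quant_types == ["int4"]:  # everything is int4 -> packed storage
        return "pack-quantized"
    if quant_types == ["float8"]:  # everything is fp8 -> the new float format
        return "float-quantized"
    return "int-quantized"  # anything else falls back to int8 storage


assert choose_format(["int4"]) == "pack-quantized"
assert choose_format(["float8"]) == "float-quantized"
assert choose_format(["int4", "float8"]) == "int-quantized"  # mixed types fall back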
@@ -0,0 +1,5 @@
cadence: "nightly"
test_type: "regression"
model_stub: "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
new_recipe: "tests/llmcompressor/transformers/compression/recipes/new_quant_fp8.yaml"
ppl_threshold: 20
@@ -0,0 +1,5 @@
cadence: "commit"
test_type: "regression"
model_stub: "Xenova/llama2.c-stories15M"
new_recipe: "tests/llmcompressor/transformers/compression/recipes/new_quant_fp8.yaml"
ppl_threshold: 21000
@@ -0,0 +1,12 @@
quant_stage:
  quant_modifiers:
    QuantizationModifier:
      ignore: ["lm_head"]
      config_groups:
        group_0:
          weights:
            num_bits: 8
            type: "float"
            symmetric: true
            strategy: channel
          targets: ["Linear"]
@@ -108,7 +108,7 @@ def test_quantization_reload(self):
             n_scale, n_zp, n_weight = reloaded_weights[name]
             assert o_scale.dtype == n_scale.dtype == self.weight_dtype
             assert torch.equal(o_scale, n_scale)
-            assert o_zp.dtype == n_zp.dtype == torch.int8
+            assert o_zp.dtype == n_zp.dtype
             assert torch.equal(o_zp, n_zp)

             # we don't expect an exact match here because o_weight still has the
@@ -119,7 +119,7 @@ def test_quantization_reload(self):
             n_scale, n_zp = reloaded_inputs[name]
             assert o_scale.dtype == n_scale.dtype == self.weight_dtype
             assert torch.equal(o_scale, n_scale)
-            assert o_zp.dtype == n_zp.dtype == torch.int8
+            assert o_zp.dtype == n_zp.dtype
             assert torch.equal(o_zp, n_zp)

     def _get_dataloader(self, data_args, tokenizer):
