move code over from sparseml PR
Sara Adkins committed Jun 24, 2024
1 parent 06e3131 commit 5722531
Showing 5 changed files with 79 additions and 8 deletions.
38 changes: 38 additions & 0 deletions examples/quantization/llama7b_fp8_quantization.py
@@ -0,0 +1,38 @@
import torch
from datasets import load_dataset
from transformers import AutoTokenizer

from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.transformers import SparseAutoModelForCausalLM, oneshot

model_stub = "meta-llama/Meta-Llama-3-8B-Instruct"
output_dir = "Meta-Llama-3-8B-Instruct-FP8-Compressed"
num_calibration_samples = 512

tokenizer = AutoTokenizer.from_pretrained(model_stub, use_fast=True)
tokenizer.pad_token = tokenizer.eos_token


def preprocess(batch):
    text = tokenizer.apply_chat_template(batch["messages"], tokenize=False)
    tokenized = tokenizer(text, padding=True, truncation=True, max_length=2048)
    return tokenized


ds = load_dataset("mgoin/ultrachat_2k", split="train_sft")
examples = ds.map(preprocess, remove_columns=ds.column_names)

recipe = GPTQModifier(targets=["Linear"], scheme="FP8", ignore=["lm_head"])

model = SparseAutoModelForCausalLM.from_pretrained(
    model_stub, torch_dtype=torch.bfloat16, device_map="auto"
)

oneshot(
    model=model,
    dataset=examples,
    recipe=recipe,
    output_dir=output_dir,
    num_calibration_samples=num_calibration_samples,
    save_compressed=True,
)
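
Once the script above finishes, the compressed checkpoint is written to output_dir. Below is a minimal, hedged sketch of reloading it for a quick generation smoke test; it reuses SparseAutoModelForCausalLM from the example and assumes the saved directory can be reloaded with from_pretrained on this version of the library. The tokenizer is reloaded from the original stub, since the example does not save one to output_dir.

# Sketch (assumptions noted above): reload the FP8-compressed checkpoint written by the example.
import torch
from transformers import AutoTokenizer

from llmcompressor.transformers import SparseAutoModelForCausalLM

output_dir = "Meta-Llama-3-8B-Instruct-FP8-Compressed"

# The example does not save a tokenizer to output_dir, so load it from the original stub.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

model = SparseAutoModelForCausalLM.from_pretrained(
    output_dir, torch_dtype=torch.bfloat16, device_map="auto"
)

# Single short generation to confirm the reloaded model runs end to end.
inputs = tokenizer("The capital of France is", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))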
20 changes: 12 additions & 8 deletions src/llmcompressor/transformers/compression/quantization_format.py
@@ -45,9 +45,11 @@ def infer_quantization_format(
         return quantization_format
 
     if save_compressed:
-        quant_depths = _get_quant_depths(model)
-        if quant_depths == [4]: # save packed if everything is int4
+        quant_types = _get_quant_types(model)
+        if quant_types == ["int4"]: # save packed if everything is int4
             return CompressionFormat.pack_quantized
+        elif quant_types == ["float8"]:
+            return CompressionFormat.float_quantized
 
         # otherwise just quantize to int8
         return CompressionFormat.int_quantized
@@ -56,17 +58,19 @@ def infer_quantization_format(
     return None
 
 
-def _get_quant_depths(model):
+def _get_quant_types(model):
     """
-    Gets a list of all the quantized bit depths present in model
+    Gets a list of all the quantized types present in model
     """
-    quant_depths = []
+    quant_info = []
     for _, submodule in iter_named_leaf_modules(model):
         if is_module_quantized(submodule):
             weight_scheme = submodule.quantization_scheme.weights
             if weight_scheme is not None:
                 weight_bit_depth = weight_scheme.num_bits
-                if weight_bit_depth not in quant_depths:
-                    quant_depths.append(weight_bit_depth)
+                weight_type = weight_scheme.type
+                weight_info = f"{weight_type}{weight_bit_depth}"
+                if weight_info not in quant_info:
+                    quant_info.append(weight_info)
 
-    return quant_depths
+    return quant_info
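
The net effect of this change is that format selection now keys on a combined type-plus-bit-width string rather than bit depth alone. A standalone sketch of that decision, using plain strings as stand-ins for the real CompressionFormat members and a precomputed list in place of _get_quant_types(model):

# Illustrative only: stand-in strings for CompressionFormat values.
def choose_format(quant_types):
    if quant_types == ["int4"]:  # everything int4 -> packed format
        return "pack-quantized"
    elif quant_types == ["float8"]:  # new in this commit: all-FP8 weights
        return "float-quantized"
    return "int-quantized"  # otherwise fall back to int8 quantization

assert choose_format(["int4"]) == "pack-quantized"
assert choose_format(["float8"]) == "float-quantized"
assert choose_format(["int4", "float8"]) == "int-quantized"  # mixed models fall back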
@@ -0,0 +1,5 @@
cadence: "nightly"
test_type: "regression"
model_stub: "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
new_recipe: "tests/llmcompressor/transformers/compression/recipes/new_quant_fp8.yaml"
ppl_threshold: 20
@@ -0,0 +1,5 @@
cadence: "commit"
test_type: "regression"
model_stub: "Xenova/llama2.c-stories15M"
new_recipe: "tests/llmcompressor/transformers/compression/recipes/new_quant_fp8.yaml"
ppl_threshold: 5000
@@ -0,0 +1,19 @@
quant_stage:
  quant_modifiers:
    GPTQModifier:
      sequential_update: false
      ignore: ["lm_head"]
      config_groups:
        group_0:
          weights:
            num_bits: 8
            type: "float"
            symmetric: true
            strategy: channel
          input_activations:
            num_bits: 8
            type: "float"
            symmetric: true
            dynamic: true
            strategy: token
          targets: ["Linear"]
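
This recipe appears to be the new_quant_fp8.yaml file referenced by path in the two test configs above. A hedged sketch of applying such a recipe file outside the test harness follows, assuming oneshot accepts a YAML recipe path the same way it accepts a modifier object in the example script at the top of this commit; the dataset name and calibration sample count are placeholders.

# Sketch (assumptions noted above): apply the FP8 recipe file via oneshot.
from llmcompressor.transformers import SparseAutoModelForCausalLM, oneshot

# Model stub taken from the commit-cadence test config in this commit.
model = SparseAutoModelForCausalLM.from_pretrained("Xenova/llama2.c-stories15M")

oneshot(
    model=model,
    dataset="open_platypus",  # placeholder: any small registered calibration dataset
    recipe="tests/llmcompressor/transformers/compression/recipes/new_quant_fp8.yaml",
    output_dir="llama2.c-stories15M-FP8",
    num_calibration_samples=64,  # placeholder count
    save_compressed=True,
)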
