From 5722531f70c55608a9d224f29a9a2d47f4f73105 Mon Sep 17 00:00:00 2001
From: Sara Adkins
Date: Mon, 24 Jun 2024 21:35:35 +0000
Subject: [PATCH] move code over from sparseml PR

---
 .../quantization/llama7b_fp8_quantization.py  | 38 +++++++++++++++++++
 .../compression/quantization_format.py        | 20 ++++++----
 .../compression/configs/fp8_1.1b.yaml         |  5 +++
 .../compression/configs/fp8_15m.yaml          |  5 +++
 .../compression/recipes/new_quant_fp8.yaml    | 19 ++++++++++
 5 files changed, 79 insertions(+), 8 deletions(-)
 create mode 100644 examples/quantization/llama7b_fp8_quantization.py
 create mode 100644 tests/llmcompressor/transformers/compression/configs/fp8_1.1b.yaml
 create mode 100644 tests/llmcompressor/transformers/compression/configs/fp8_15m.yaml
 create mode 100644 tests/llmcompressor/transformers/compression/recipes/new_quant_fp8.yaml

diff --git a/examples/quantization/llama7b_fp8_quantization.py b/examples/quantization/llama7b_fp8_quantization.py
new file mode 100644
index 000000000..c876ee2c1
--- /dev/null
+++ b/examples/quantization/llama7b_fp8_quantization.py
@@ -0,0 +1,38 @@
+import torch
+from datasets import load_dataset
+from transformers import AutoTokenizer
+
+from llmcompressor.modifiers.quantization import GPTQModifier
+from llmcompressor.transformers import SparseAutoModelForCausalLM, oneshot
+
+model_stub = "meta-llama/Meta-Llama-3-8B-Instruct"
+output_dir = "Meta-Llama-3-8B-Instruct-FP8-Compressed"
+num_calibration_samples = 512
+
+tokenizer = AutoTokenizer.from_pretrained(model_stub, use_fast=True)
+tokenizer.pad_token = tokenizer.eos_token
+
+
+def preprocess(batch):
+    text = tokenizer.apply_chat_template(batch["messages"], tokenize=False)
+    tokenized = tokenizer(text, padding=True, truncation=True, max_length=2048)
+    return tokenized
+
+
+ds = load_dataset("mgoin/ultrachat_2k", split="train_sft")
+examples = ds.map(preprocess, remove_columns=ds.column_names)
+
+recipe = GPTQModifier(targets=["Linear"], scheme="FP8", ignore=["lm_head"])
+
+model = SparseAutoModelForCausalLM.from_pretrained(
+    model_stub, torch_dtype=torch.bfloat16, device_map="auto"
+)
+
+oneshot(
+    model=model,
+    dataset=examples,
+    recipe=recipe,
+    output_dir=output_dir,
+    num_calibration_samples=num_calibration_samples,
+    save_compressed=True,
+)
diff --git a/src/llmcompressor/transformers/compression/quantization_format.py b/src/llmcompressor/transformers/compression/quantization_format.py
index 2f9cf409f..01427e38a 100644
--- a/src/llmcompressor/transformers/compression/quantization_format.py
+++ b/src/llmcompressor/transformers/compression/quantization_format.py
@@ -45,9 +45,11 @@ def infer_quantization_format(
         return quantization_format
 
     if save_compressed:
-        quant_depths = _get_quant_depths(model)
-        if quant_depths == [4]:  # save packed if everything is int4
+        quant_types = _get_quant_types(model)
+        if quant_types == ["int4"]:  # save packed if everything is int4
             return CompressionFormat.pack_quantized
+        elif quant_types == ["float8"]:
+            return CompressionFormat.float_quantized
         # otherwise just quantize to int8
         return CompressionFormat.int_quantized
 
@@ -56,17 +58,19 @@ def infer_quantization_format(
         return None
 
 
-def _get_quant_depths(model):
+def _get_quant_types(model):
     """
-    Gets a list of all the quantized bit depths present in model
+    Gets a list of all the quantized types present in model
     """
-    quant_depths = []
+    quant_info = []
     for _, submodule in iter_named_leaf_modules(model):
         if is_module_quantized(submodule):
             weight_scheme = submodule.quantization_scheme.weights
             if weight_scheme is not None:
                 weight_bit_depth = weight_scheme.num_bits
-                if weight_bit_depth not in quant_depths:
-                    quant_depths.append(weight_bit_depth)
+                weight_type = weight_scheme.type
+                weight_info = f"{weight_type}{weight_bit_depth}"
+                if weight_info not in quant_info:
+                    quant_info.append(weight_info)
 
-    return quant_depths
+    return quant_info
diff --git a/tests/llmcompressor/transformers/compression/configs/fp8_1.1b.yaml b/tests/llmcompressor/transformers/compression/configs/fp8_1.1b.yaml
new file mode 100644
index 000000000..9bda49378
--- /dev/null
+++ b/tests/llmcompressor/transformers/compression/configs/fp8_1.1b.yaml
@@ -0,0 +1,5 @@
+cadence: "nightly"
+test_type: "regression"
+model_stub: "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
+new_recipe: "tests/llmcompressor/transformers/compression/recipes/new_quant_fp8.yaml"
+ppl_threshold: 20
\ No newline at end of file
diff --git a/tests/llmcompressor/transformers/compression/configs/fp8_15m.yaml b/tests/llmcompressor/transformers/compression/configs/fp8_15m.yaml
new file mode 100644
index 000000000..181351c05
--- /dev/null
+++ b/tests/llmcompressor/transformers/compression/configs/fp8_15m.yaml
@@ -0,0 +1,5 @@
+cadence: "commit"
+test_type: "regression"
+model_stub: "Xenova/llama2.c-stories15M"
+new_recipe: "tests/llmcompressor/transformers/compression/recipes/new_quant_fp8.yaml"
+ppl_threshold: 5000
\ No newline at end of file
diff --git a/tests/llmcompressor/transformers/compression/recipes/new_quant_fp8.yaml b/tests/llmcompressor/transformers/compression/recipes/new_quant_fp8.yaml
new file mode 100644
index 000000000..54b248716
--- /dev/null
+++ b/tests/llmcompressor/transformers/compression/recipes/new_quant_fp8.yaml
@@ -0,0 +1,19 @@
+quant_stage:
+    quant_modifiers:
+        GPTQModifier:
+            sequential_update: false
+            ignore: ["lm_head"]
+            config_groups:
+                group_0:
+                    weights:
+                        num_bits: 8
+                        type: "float"
+                        symmetric: true
+                        strategy: channel
+                    input_activations:
+                        num_bits: 8
+                        type: "float"
+                        symmetric: true
+                        dynamic: true
+                        strategy: token
+                    targets: ["Linear"]
\ No newline at end of file
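
Not part of the patch above: a minimal sanity-check sketch of the new format-inference path. It assumes model is the FP8-quantized model left in memory by the oneshot() call from the example script, and it assumes CompressionFormat can be imported from the compressed-tensors package; infer_quantization_format and save_compressed come from the patched module itself.

    # Illustration only, not part of this patch. Assumes `model` was already
    # quantized in-place by the oneshot() call in the example script, and that
    # CompressionFormat is provided by the compressed-tensors package (assumption).
    from compressed_tensors import CompressionFormat

    from llmcompressor.transformers.compression.quantization_format import (
        infer_quantization_format,
    )

    # With every Linear weight quantized to 8-bit float, _get_quant_types(model)
    # should report ["float8"], so the float-quantized format is selected instead
    # of the packed int4 or int8 fallbacks.
    fmt = infer_quantization_format(model, save_compressed=True)
    assert fmt == CompressionFormat.float_quantized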