43 changes: 38 additions & 5 deletions examples/autoround/README.md
@@ -70,15 +70,39 @@ ds = get_dataset(
### 3) Apply Quantization

With the dataset ready, we will now apply AutoRound quantization to the model.
Pass the `--fp8_kv` flag when running the script if you also want to quantize the KV cache to FP8.
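
For example, using the example script in this folder (omit the flag for a weight-only run):

```bash
# Quantize weights to W4A16 and the KV cache to FP8
python llama3_example.py --fp8_kv
```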

```python
from llmcompressor import oneshot
from llmcompressor.modifiers.autoround import AutoRoundModifier

# Configure the quantization algorithm to run.
recipe = AutoRoundModifier(
    targets="Linear", scheme="W4A16", ignore=["lm_head"], iters=200
)
# `args.fp8_kv` comes from the script's `--fp8_kv` command-line flag.
if args.fp8_kv:
    # Quantize the KV cache to FP8 (static, per-tensor scales) in addition to the weights.
    recipe = """
quant_stage:
  quant_modifiers:
    QuantizationModifier:
      kv_cache_scheme:
        num_bits: 8
        type: float
        strategy: tensor
        dynamic: false
        symmetric: true
    AutoRoundModifier:
      targets: [Linear]
      scheme: W4A16
      ignore: [lm_head]
      iters: 200
"""
else:
    # Weight-only W4A16 quantization with AutoRound.
    recipe = """
quant_stage:
  quant_modifiers:
    AutoRoundModifier:
      targets: [Linear]
      scheme: W4A16
      ignore: [lm_head]
      iters: 200
"""

# Apply quantization.
oneshot(
@@ -116,6 +140,7 @@ Run the following to test accuracy on GSM-8K:

```bash
# If the KV cache was quantized, also add kv_cache_dtype=fp8 to --model_args below.
lm_eval --model vllm \
  --model_args pretrained="./Meta-Llama-3-8B-Instruct-W4A16-G128-AutoRound",add_bos_token=true \
  --tasks gsm8k \
  --num_fewshot 5 \
@@ -125,11 +150,19 @@

We can see the resulting scores look good!

```text
w/o kv cache quantization:
| Tasks | Version | Filter | n-shot | Metric | | Value | | Stderr |
| ----- | ------: | ---------------- | -----: | ----------- | --- | ----: | --- | -----: |
| gsm8k | 3 | flexible-extract | 5 | exact_match | ↑ | 0.737 | ± | 0.0139 |
| | | strict-match | 5 | exact_match | ↑ | 0.736 | ± | 0.0139 |

w/ kv cache quantization:

| Tasks | Version | Filter | n-shot | Metric | | Value | | Stderr |
| ----- | ------: | ---------------- | -----: | ----------- | --- | ----: | --- | -----: |
| gsm8k | 3 | flexible-extract | 5 | exact_match | ↑ | 0.740 | ± | 0.0139 |
| | | strict-match | 5 | exact_match | ↑ | 0.742 | ± | 0.0138 |
```
> Note: quantized model accuracy may vary slightly due to nondeterminism.
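
As a quick sanity check, the saved checkpoint can also be loaded directly with vLLM's offline API. This is a minimal sketch; set `kv_cache_dtype="fp8"` only if the model was quantized with `--fp8_kv`, otherwise leave it at its default.

```python
from vllm import LLM, SamplingParams

# Load the compressed checkpoint produced by the example script.
# kv_cache_dtype="fp8" applies only when the KV cache was quantized.
llm = LLM(
    model="./Meta-Llama-3-8B-Instruct-W4A16-G128-AutoRound",
    kv_cache_dtype="fp8",
)

outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```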

40 changes: 34 additions & 6 deletions examples/autoround/llama3_example.py
@@ -1,10 +1,15 @@
import argparse
from auto_round.calib_dataset import get_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.autoround import AutoRoundModifier
from llmcompressor.utils import dispatch_for_generation


parser = argparse.ArgumentParser()
parser.add_argument("--fp8_kv", action="store_true")
args = parser.parse_args()

# Select model and load it.
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
@@ -21,13 +26,36 @@
    nsamples=NUM_CALIBRATION_SAMPLES,
)


# Configure the quantization algorithm to run.
# * quantize the weights to 4 bits with AutoRound using a group size of 128
recipe = AutoRoundModifier(
    targets="Linear", scheme="W4A16", ignore=["lm_head"], iters=200
)

# * optionally quantize the KV cache to FP8 (enabled with the --fp8_kv flag)
if args.fp8_kv:
    recipe = """
quant_stage:
  quant_modifiers:
    QuantizationModifier:
      kv_cache_scheme:
        num_bits: 8
        type: float
        strategy: tensor
        dynamic: false
        symmetric: true
    AutoRoundModifier:
      targets: [Linear]
      scheme: W4A16
      ignore: [lm_head]
      iters: 200
"""
else:
    recipe = """
quant_stage:
  quant_modifiers:
    AutoRoundModifier:
      targets: [Linear]
      scheme: W4A16
      ignore: [lm_head]
      iters: 200
"""

# Apply algorithms.
oneshot(