[ Examples ] E2E Examples #5

Merged: 30 commits (Jul 2, 2024)
The diff below shows changes from 7 of the 30 commits.

Commits
a8c3ad8
added examples
robertgshaw2-neuralmagic Jun 24, 2024
539d31a
updated examples
robertgshaw2-neuralmagic Jun 24, 2024
6f298a7
set to 32 samples for testing
robertgshaw2-neuralmagic Jun 24, 2024
cfc1ec0
fix
robertgshaw2-neuralmagic Jun 25, 2024
82e8910
Update llama7b_quantize_sparse_cnn.py
robertgshaw2-neuralmagic Jun 25, 2024
62f8011
Merge branch 'main' into rs/examples
robertgshaw2-neuralmagic Jun 25, 2024
af0be23
tweak W8A8
robertgshaw2-neuralmagic Jun 25, 2024
931c504
firx w4a16
robertgshaw2-neuralmagic Jun 26, 2024
e12b65e
added example
robertgshaw2-neuralmagic Jun 27, 2024
982e3ee
tweak fp8 example
Jun 27, 2024
5971dce
remove changes
Jun 27, 2024
438b01e
fix
Jun 27, 2024
8822f3c
update examples to use tokenized data
Jun 27, 2024
a6bcb90
save
Jun 27, 2024
466cdb6
Merge branch 'main' into rs/examples
robertgshaw2-neuralmagic Jul 2, 2024
f430e43
fp8 example end to end
robertgshaw2-neuralmagic Jul 2, 2024
b0eaf12
tweak README
robertgshaw2-neuralmagic Jul 2, 2024
a020ebe
rename title
robertgshaw2-neuralmagic Jul 2, 2024
7c58ff4
update title
robertgshaw2-neuralmagic Jul 2, 2024
556eca2
finished example
robertgshaw2-neuralmagic Jul 2, 2024
39f2ef0
refactored directory structure
robertgshaw2-neuralmagic Jul 2, 2024
284a0f0
nits
robertgshaw2-neuralmagic Jul 2, 2024
2da06f9
restructure w4a16
robertgshaw2-neuralmagic Jul 2, 2024
367fb0f
fixed w4a16
robertgshaw2-neuralmagic Jul 2, 2024
956e1a4
added w8a8-int8 example
robertgshaw2-neuralmagic Jul 2, 2024
5911c45
finalized example
robertgshaw2-neuralmagic Jul 2, 2024
3d4d03b
added back example
robertgshaw2-neuralmagic Jul 2, 2024
d600009
stash
robertgshaw2-neuralmagic Jul 2, 2024
708f288
format
robertgshaw2-neuralmagic Jul 2, 2024
59ea79e
Update examples/quantization_w4a16/README.md
robertgshaw2-neuralmagic Jul 2, 2024
72 changes: 72 additions & 0 deletions examples/quantization/example-w4a16.py
@@ -0,0 +1,72 @@
from datasets import load_dataset
from transformers import AutoTokenizer
from llmcompressor.transformers import SparseAutoModelForCausalLM, oneshot


# Select model and load it.
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
model = SparseAutoModelForCausalLM.from_pretrained(
    MODEL_ID, device_map="auto", torch_dtype="auto",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Select calibration dataset.
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"

# Select number of samples. 512 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES=512
MAX_SEQUENCE_LENGTH=2048

# Load dataset and preprocess.
ds = load_dataset(DATASET_ID, split=DATASET_SPLIT)
ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES))
def preprocess(example):
    return {"text": tokenizer.apply_chat_template(
        example["messages"], tokenize=False,
    )}
ds = ds.map(preprocess)

# Configure the quantization algorithm. In this case, we:
# * quantize the weights to 4 bit with GPTQ, using a group strategy with group_size 128
recipe = """
quant_stage:
    quant_modifiers:
        GPTQModifier:
            sequential_update: false
            ignore: ["lm_head"]
            config_groups:
                group_0:
                    weights:
                        num_bits: 4
                        type: "int"
                        symmetric: true
                        strategy: "group"
                        group_size: 128
                    targets: ["Linear"]
"""

# Apply algorithms.
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

# Save to disk compressed.
SAVE_DIR = MODEL_ID.split("/")[1] + "-W4A16"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
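
Since the model is saved with save_compressed=True, the output directory is a compressed-tensors checkpoint. As a minimal sketch outside this PR, assuming a vLLM build with compressed-tensors support, the saved W4A16 model could be loaded for inference roughly like this:

# Sketch only (not part of this diff): load the saved checkpoint with vLLM,
# assuming compressed-tensors support is available in the installed vLLM.
from vllm import LLM, SamplingParams

llm = LLM(model="Meta-Llama-3-8B-Instruct-W4A16")  # the SAVE_DIR produced above
sampling = SamplingParams(max_tokens=100)
outputs = llm.generate(["Hello my name is"], sampling)
print(outputs[0].outputs[0].text)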
82 changes: 82 additions & 0 deletions examples/quantization/example-w8a8-int8.py
@@ -0,0 +1,82 @@
from datasets import load_dataset
from llmcompressor.transformers import SparseAutoModelForCausalLM, oneshot
from transformers import AutoTokenizer

# Select model and load it.
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
model = SparseAutoModelForCausalLM.from_pretrained(
    MODEL_ID, device_map="auto", torch_dtype='auto',
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Select calibration dataset.
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"

# Select number of samples. 512 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES=512
MAX_SEQUENCE_LENGTH=2048

# Load dataset and preprocess.
ds = load_dataset(DATASET_ID, split=DATASET_SPLIT)
ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES))
def preprocess(example):
    return {"text": tokenizer.apply_chat_template(
        example["messages"], tokenize=False,
    )}
ds = ds.map(preprocess)

# Configure algorithms. In this case, we:
# * apply SmoothQuant to make the activations easier to quantize
# * quantize the weights to 8 bit with GPTQ, using a static per-channel strategy
# * quantize the activations to 8 bit with a dynamic per-token strategy
recipe = """
quant_stage:
    quant_modifiers:
        SmoothQuantModifier:
            smoothing_strength: 0.8
            mappings: [
                [["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], "re:.*input_layernorm"],
                [["re:.*gate_proj", "re:.*up_proj"], "re:.*post_attention_layernorm"]
            ]
        GPTQModifier:
            sequential_update: false
            ignore: ["lm_head"]
            config_groups:
                group_0:
                    weights:
                        num_bits: 8
                        type: "int"
                        symmetric: true
                        strategy: "channel"
                    input_activations:
                        num_bits: 8
                        type: "int"
                        symmetric: true
                        dynamic: true
                        strategy: "token"
                    targets: ["Linear"]
"""

# Apply algorithms.
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

# Save to disk compressed.
SAVE_DIR = MODEL_ID.split("/")[1] + "-W8A8-Dynamic-Per-Token"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
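
The commit history above notes that the examples were later updated to use tokenized data. A minimal sketch of that preprocessing step, reusing the tokenizer, ds, and MAX_SEQUENCE_LENGTH defined in the example above and assuming the standard Hugging Face tokenizer call:

# Sketch only (not part of this diff): tokenize the calibration set up front
# so oneshot receives input_ids rather than raw text.
def tokenize(sample):
    return tokenizer(
        sample["text"],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )

ds = ds.map(tokenize, remove_columns=ds.column_names)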
181 changes: 0 additions & 181 deletions examples/quantization/llama7b_w4a16_quantization.ipynb

This file was deleted.

52 changes: 0 additions & 52 deletions examples/quantization/llama7b_w4a16_quantization.py

This file was deleted.
