GPTQ Activation Ordering #94

Merged
merged 64 commits into main
Aug 28, 2024

Conversation

@kylesayrs (Collaborator) commented Aug 16, 2024

Summary

Add support for compressed-tensors models that have been quantized using activation ordering (group-wise quantization performed in decreasing order of activation magnitude, so the most salient weight columns are quantized first).
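
For context: activation ordering permutes weight columns so that GPTQ quantizes the most-activated (highest Hessian-diagonal) columns first, then records the permutation so the weights can be restored to their original order after quantization. Below is a minimal sketch of that permutation, not taken from this PR; the names and shapes are illustrative assumptions.

import torch

def actorder_permutation(H: torch.Tensor) -> torch.Tensor:
    # H is a (num_cols x num_cols) Hessian estimate accumulated from
    # calibration activations; its diagonal measures column salience.
    return torch.argsort(torch.diag(H), descending=True)

# Illustrative usage: quantize the most salient columns first, then
# invert the permutation to restore the original column order.
H = torch.rand(8, 8)
W = torch.rand(4, 8)
perm = actorder_permutation(H)
W_perm = W[:, perm]
# ... GPTQ would quantize W_perm column-by-column here ...
W_restored = W_perm[:, torch.argsort(perm)]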

Usage Script

compress_actorder.py
import os
import pickle
import datetime
from datasets import load_dataset
from transformers import AutoTokenizer
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.transformers import SparseAutoModelForCausalLM, oneshot

def get_current_time():
    now = datetime.datetime.now()
    formatted_time = now.strftime("%Y%m%d_%H%M%S")
    return str(formatted_time)

# Select model and load it.
MODEL_ID="Qwen/Qwen2-0.5B-Instruct"
model = SparseAutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype="auto",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Select calibration dataset.
DATASET_ID = "openai/gsm8k"
DATASET_SUBSET = "main"
DATASET_SPLIT = "train"
PICKLE_FILE = "pickle.pkl"

# Select number of samples. 512 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 512

# Preprocess: render each question through the chat template.
def preprocess(example):
    return {
        "text": tokenizer.apply_chat_template(
            [{"role": "user", "content": example["question"]}],
            tokenize=False,
            add_generation_prompt=True,
        )
    }

# Tokenize the preprocessed text.
def tokenize(sample):
    return tokenizer(
        sample["text"],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )

# Check if the preprocessed dataset is already saved
if os.path.exists(PICKLE_FILE):
    # Load the dataset from the pickle file
    with open(PICKLE_FILE, "rb") as f:
        ds = pickle.load(f)
    print("Loaded dataset from pickle file.")
else:
    # Load, preprocess, and tokenize the dataset
    ds = load_dataset(DATASET_ID, DATASET_SUBSET, split=DATASET_SPLIT)
    ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES))
    ds = ds.map(preprocess)
    ds = ds.map(tokenize, remove_columns=ds.column_names)

    # Save the preprocessed dataset to a pickle file
    with open(PICKLE_FILE, "wb") as f:
        pickle.dump(ds, f)
    print("Saved preprocessed dataset to pickle file.")

recipe = """
    quant_stage:
        quant_modifiers:
            GPTQModifier:
                sequential_update: false
                ignore: ["lm_head"]
                config_groups:
                    group_0:
                        weights:
                            num_bits: 4
                            type: "int"
                            symmetric: true
                            strategy: "group"
                            group_size: 128
                            actorder: true
                        targets: ["Linear"]
"""
# Apply algorithm
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

# Save compressed model and tokenizer
SAVE_DIR = "actorder" + get_current_time()
print(SAVE_DIR)
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)  # tokenizers have no compressed format

# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=50)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

Evaluation

Accuracy

Full Precision

vllm (pretrained=Qwen/Qwen2-0.5B-Instruct,add_bos_token=True), gen_kwargs: (None), limit: 1000.0, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.387|↑  |0.0154|
|     |       |strict-match    |     5|exact_match|↑  |0.385|↑  |0.0154|

Group Quantization Only

vllm (pretrained=/home/ksayers/llm-compressor/gwen_group,add_bos_token=True), gen_kwargs: (None), limit: 1000.0, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.226|↑  |0.0132|
|     |       |strict-match    |     5|exact_match|↑  |0.212|↑  |0.0129|

Group Quantization Only on main (regression test)

vllm (pretrained=/home/ksayers/llm-compressor/gwen_regression,add_bos_token=True), gen_kwargs: (None), limit: 1000.0, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.226|↑  |0.0132|
|     |       |strict-match    |     5|exact_match|↑  |0.212|↑  |0.0129|

Activation Ordering

vllm (pretrained=/home/ksayers/llm-compressor/gwen_actorder,add_bos_token=True), gen_kwargs: (None), limit: 1000.0, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.235|↑  |0.0134|
|     |       |strict-match    |     5|exact_match|↑  |0.231|↑  |0.0133|
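
Net effect: activation ordering recovers one to two points of exact-match over group quantization alone (0.235 vs. 0.226 flexible-extract; 0.231 vs. 0.212 strict-match), while both remain well below the full-precision baseline (0.387/0.385).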

Latency Regression

Namespace(model='/home/ksayers/llm-compressor/gwen_actorder/', speculative_model=None,
num_speculative_tokens=None, speculative_draft_tensor_parallel_size=None, tokenizer=None,
quantization=None, tensor_parallel_size=1, input_len=32, output_len=128, batch_size=32, n=1,
use_beam_search=False, num_iters_warmup=10, num_iters=30, trust_remote_code=False,
max_model_len=None, dtype='auto', enforce_eager=False, kv_cache_dtype='auto',
quantization_param_path=None, profile=False, profile_result_dir=None, device='auto',
block_size=16, enable_chunked_prefill=False, enable_prefix_caching=False,
use_v2_block_manager=False, ray_workers_use_nsight=False, download_dir=None,
output_json=None, gpu_memory_utilization=0.9, load_format='auto',
distributed_executor_backend=None, otlp_traces_endpoint=None)
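
(These arguments appear to come from vLLM's benchmarks/benchmark_latency.py: 32 input tokens, 128 output tokens, batch size 32, 30 measured iterations after 10 warmup iterations. The exact command is not shown in the PR.)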

Group Quantization Only

Avg latency: 0.8884373404396076 seconds
10% percentile latency: 0.8715801022946834 seconds
25% percentile latency: 0.8739993472117931 seconds
50% percentile latency: 0.876951577141881 seconds
75% percentile latency: 0.8830150356516242 seconds
90% percentile latency: 0.9393035409972071 seconds
99% percentile latency: 0.9404808702412992 seconds

Activation Ordering

Avg latency: 0.9159474782645702 seconds
10% percentile latency: 0.9001966264098883 seconds
25% percentile latency: 0.9010569080710411 seconds
50% percentile latency: 0.9041027296334505 seconds
75% percentile latency: 0.9064613012596965 seconds
90% percentile latency: 0.9662564094178379 seconds
99% percentile latency: 0.9761117453686893 seconds
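
On this 0.5B model, activation ordering adds roughly 3% to average latency over group quantization alone (0.916 s vs. 0.888 s).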

PR Dependencies

Activation Ordering Support (neuralmagic/compressed-tensors#97)

@kylesayrs kylesayrs marked this pull request as draft August 16, 2024 22:23
@kylesayrs kylesayrs mentioned this pull request Aug 16, 2024
@kylesayrs kylesayrs changed the title Activation Ordering GPTQ Activation Ordering Aug 16, 2024
@kylesayrs kylesayrs requested a review from Satrat August 23, 2024 18:46
@Satrat Satrat (Contributor) left a comment

Changes look good with respect to the Hessian memory management. I'd still like to see an e2e test for activation reordering that tests perplexity and reloading. See tests/llmcompressor/transformers/compression/test_quantization.py for an example of this. I believe it should just be a matter of adding a new recipe and config; let me know if you need help with that.

@kylesayrs kylesayrs changed the base branch from main to gptq-cleanup August 27, 2024 20:37
@kylesayrs (Collaborator, Author)

Performed the tests and got the same accuracy and latency results.

Base automatically changed from gptq-cleanup to main August 28, 2024 20:19
@kylesayrs (Collaborator, Author)

Using the compressed-tensors main branch, I confirmed that tests/llmcompressor/modifiers, tests/llmcompressor/transformers/compression, and tests/llmcompressor/modifiers/quantization/gptq/utils/test_gptq_wrapper.py all pass.

@kylesayrs kylesayrs merged commit 6ad6e05 into main Aug 28, 2024
4 of 7 checks passed
@kylesayrs kylesayrs deleted the kylesayrs/activation-ordering branch August 28, 2024 21:18
markmc pushed a commit to markmc/llm-compressor that referenced this pull request Nov 13, 2024