[WIP] Example PTQ flow with quant lifecycle #10

Closed
wants to merge 4 commits into from
78 changes: 78 additions & 0 deletions examples/llama_1.1b/compressed_tensors_ptq.py
@@ -0,0 +1,78 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from tqdm import tqdm

from compressed_tensors.quantization import (
    apply_quantization_config,
    freeze_module_quantization,
    QuantizationConfig,
    QuantizationStatus,
)
from sparseml.transformers.finetune.data.data_args import DataTrainingArguments
from sparseml.transformers.finetune.data.base import TextGenerationDataset
from transformers import AutoModelForCausalLM, AutoTokenizer, DefaultDataCollator
from torch.utils.data import DataLoader


config_file = "example_quant_config.json"
model_name = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
dataset_name = "open_platypus"
split = "train"
num_calibration_samples = 10
max_seq_length = 1024
pad_to_max_length = False


model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval()  # inference only; the base model weights are never updated during PTQ
config = QuantizationConfig.parse_file(config_file)

# set status to calibration
config.quantization_status = QuantizationStatus.CALIBRATION

# initialize quantization
apply_quantization_config(model, config)

# create dataset
tokenizer = AutoTokenizer.from_pretrained(model_name)
data_args = DataTrainingArguments(
    dataset=dataset_name,
    max_seq_length=max_seq_length,
    pad_to_max_length=pad_to_max_length,
)
dataset_manager = TextGenerationDataset.load_from_registry(
    data_args.dataset,
    data_args=data_args,
    split=split,
    tokenizer=tokenizer,
)
calib_dataset = dataset_manager.tokenize_and_process(
    dataset_manager.get_raw_dataset()
)
data_loader = DataLoader(
    calib_dataset, batch_size=1, collate_fn=DefaultDataCollator()
)

# run calibration: forward passes only, so the quantization observers can collect statistics
with torch.no_grad():
    for idx, sample in enumerate(tqdm(data_loader, total=num_calibration_samples)):
        if idx >= num_calibration_samples:
            break
        _ = model(**sample)

# freeze params after calibration
model.apply(freeze_module_quantization)

# TODO: save
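
The script above stops at the `# TODO: save` placeholder. As a sketch of what that step might look like (not part of this PR; it assumes the fake-quantized modules round-trip through the standard `transformers` serialization APIs, and the output directory name is made up):

```python
# Hedged sketch only: assumes fake-quant scales/zero-points survive the standard
# Hugging Face save path; the actual save flow is still TODO in this PR.
output_dir = "./tinyllama-1.1b-w8a8-fakequant"  # hypothetical output location

model.save_pretrained(output_dir)      # writes weights + config.json
tokenizer.save_pretrained(output_dir)  # writes the tokenizer files alongside the model
```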
12 changes: 3 additions & 9 deletions examples/llama_1.1b/example_quant_config.json
@@ -1,7 +1,6 @@
 {
   "quant_method": "sparseml",
   "format": "fakequant",
-  "quantization_status": "frozen",
   "global_compression_ratio": null,
   "config_groups": {
     "group_1": {
@@ -14,7 +13,7 @@
       "input_activations": {
         "num_bits": 8,
         "type": "int",
-        "symmetric": true,
+        "symmetric": false,
         "strategy": "tensor"
       },
       "targets": ["Linear"]
@@ -23,17 +22,12 @@
       "weights": {
         "num_bits": 8,
         "type": "int",
-        "symmetric": false,
+        "symmetric": true,
         "strategy": "tensor"
       },
       "input_activations": null,
       "targets": ["Embedding"]
     }
   },
-  "ignore": [
-    "LlamaRotaryEmbedding", "LlamaRMSNorm", "SiLUActivation",
-    "model.layers.1.mlp.down_proj", "MatMulLeftInput_QK", "MatMulRightInput_QK",
-    "MatMulOutput_QK", "MatMulLeftInput_PV", "MatMulRightInput_PV",
-    "MatMulOutput_PV"
-  ]
+  "ignore": ["LlamaRotaryEmbedding"]
 }
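
For reference, the `symmetric` flags flipped in this config mean the Linear input activations are now quantized asymmetrically (a zero point is allowed) while the Embedding weights are quantized symmetrically (zero point fixed at 0). A minimal, generic per-tensor int8 fake-quant sketch illustrating that distinction, written for this review only and not taken from the compressed-tensors implementation:

```python
import torch

def fake_quant_int8(x: torch.Tensor, symmetric: bool) -> torch.Tensor:
    """Illustrative per-tensor int8 fake-quant; not the library's actual kernel."""
    qmin, qmax = -128, 127
    if symmetric:
        # zero point pinned to 0; scale derived from the max absolute value
        scale = x.abs().max().clamp(min=1e-8) / qmax
        zero_point = torch.tensor(0.0)
    else:
        # asymmetric: scale and zero point cover the full observed [min, max] range
        scale = (x.max() - x.min()).clamp(min=1e-8) / (qmax - qmin)
        zero_point = qmin - torch.round(x.min() / scale)
    q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax)
    return (q - zero_point) * scale  # dequantize back to float ("fakequant")

# activations ("symmetric": false) tend to have skewed ranges, e.g. after SiLU/ReLU,
# while weights ("symmetric": true) are roughly zero-centered, so a fixed zero point suffices
acts = torch.relu(torch.randn(4, 16))
weights = torch.randn(128, 64) * 0.02
print(fake_quant_int8(acts, symmetric=False).abs().max())
print(fake_quant_int8(weights, symmetric=True).abs().max())
```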