fix dataset load and calibration
Benjamin committed Apr 16, 2024
1 parent 362bc91 commit 2badee9
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions examples/llama_1.1b/compressed_tensors_ptq.py
@@ -22,7 +22,8 @@
 )
 from sparseml.transformers.finetune.data.data_args import DataTrainingArguments
 from sparseml.transformers.finetune.data.base import TextGenerationDataset
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, DefaultDataCollator
+from torch.utils.data import DataLoader


 config_file = "example_quant_config.json"
@@ -48,7 +49,6 @@
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 data_args = DataTrainingArguments(
     dataset=dataset_name,
-    dataset_config_name="main",
     max_seq_length=max_seq_length,
     pad_to_max_length=pad_to_max_length,
 )
@@ -61,10 +61,16 @@
 calib_dataset = dataset_manager.tokenize_and_process(
     dataset_manager.get_raw_dataset()
 )
+data_loader = DataLoader(
+    calib_dataset, batch_size=1, collate_fn=DefaultDataCollator()
+)

 # run calibration
-for _ in tqdm(num_calibration_samples(10)):
-    _ = model(**tokenizer("", return_tensors="pt"))
+for idx, sample in tqdm(enumerate(data_loader)):
+    _ = model(**sample)
+
+    if idx >= num_calibration_samples:
+        break

 # freeze params after calibration
 model.apply(freeze_module_quantization)
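For reference, the calibration flow after this commit follows the pattern sketched below. This is a minimal, self-contained sketch rather than the example script itself: the model id, prompts, and sample budget are placeholders, and the small tokenized list stands in for the dataset that dataset_manager.tokenize_and_process produces earlier in compressed_tensors_ptq.py.

    import torch
    from torch.utils.data import DataLoader
    from tqdm import tqdm
    from transformers import AutoModelForCausalLM, AutoTokenizer, DefaultDataCollator

    model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # placeholder; any causal LM works
    num_calibration_samples = 8                        # placeholder sample budget

    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Stand-in for the tokenized calibration split; each item is a dict of 1-D tensors.
    prompts = ["What is post-training quantization?"] * (num_calibration_samples + 1)
    calib_dataset = [
        {k: v[0] for k, v in tokenizer(p, return_tensors="pt").items()} for p in prompts
    ]

    data_loader = DataLoader(
        calib_dataset, batch_size=1, collate_fn=DefaultDataCollator()
    )

    # Forward passes only; in the real script, the quantization observers attached by
    # the modifier record activation ranges during these calls.
    model.eval()
    with torch.no_grad():
        for idx, sample in tqdm(enumerate(data_loader)):
            _ = model(**sample)
            if idx >= num_calibration_samples:
                break

With batch_size=1 and DefaultDataCollator, each tokenized sample is stacked into model-ready tensors without any padding logic, which is all the calibration forward passes need.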
