Skip to content

Commit

Permalink
fix vit quant (#268)
Browse files Browse the repository at this point in the history
  • Loading branch information
helloyongyang authored Dec 16, 2024
1 parent 6953a5b commit 408ba31
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 64 deletions.
43 changes: 43 additions & 0 deletions configs/quantization/methods/Awq/awq_w_only_vit.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# AWQ quantization config for a ViT image classifier (vit-base-patch16-224).
# NOTE(review): despite the "w_only" filename this config also defines an `act`
# section (8-bit per-token activation quant) — confirm whether activation quant
# is intended here.
base:
    # Single seed anchored here and reused below via the *seed alias.
    seed: &seed 42
model:
    type: Vit
    path: /mnt/nvme1/yongyang/models/vit-base-patch16-224
    tokenizer_mode: fast
    torch_dtype: auto
calib:
    # "images" selects the raw-image calibration dataset loader.
    name: images
    download: False
    path: /mnt/nvme1/yongyang/general_custom_data
    n_samples: 128
    bs: 1
    # Chat templates do not apply to image inputs.
    apply_chat_template: False
    seed: *seed
eval:
    # Evaluate at three stages: original weights, after transform, after fake quant.
    eval_pos: [pretrain, transformed, fake_quant]
    name: imagenet
    type: acc
    download: False
    path: /mnt/nvme1/yongyang/datasets/imagenet/val
    bs: 512
quant:
    method: Awq
    weight:
        bit: 8
        symmetric: True
        granularity: per_channel
        # -1 means no grouping (whole-channel scales).
        group_size: -1
    act:
        bit: 8
        symmetric: True
        granularity: per_token
    special:
        # Apply the AWQ scale transformation before quantizing.
        trans: True
        # The options for "trans_version" include "v1" and "v2".
        trans_version: v2
        weight_clip: False
        clip_sym: True
save:
    save_trans: False
    save_fake: False
    # Placeholder output directory — set before running.
    save_path: /path/to/save/
107 changes: 57 additions & 50 deletions llmc/data/dataset/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,72 +71,79 @@ def build_calib_dataset(self):
else:
raise Exception(f'Not support {self.calib_dataset_name} dataset.')
else:
if self.calib_dataset_name == 'custom_txt' or self.calib_dataset_name == 'custom_mm':
if self.calib_dataset_name == 'custom_txt' or self.calib_dataset_name == 'custom_mm' or self.calib_dataset_name == 'images': # noqa
self.calib_dataset = self.get_cutomdata(self.calib_dataset_path)
else:
self.calib_dataset = load_from_disk(self.calib_dataset_path)

def get_calib_model_inputs(self, samples):
if not self.padding:
assert not self.calib_dataset_name == 'custom_mm'
if self.calib_dataset_name == 'custom_txt':
txts = self.batch_process(
samples,
calib_or_eval='calib',
apply_chat_template=self.apply_chat_template,
return_inputs=False
)
if self.calib_dataset_name == 'images':
calib_model_inputs = self.get_batch_process(samples)
else:
txts = self.calib_dataset
preproc = PREPROC_REGISTRY[self.preproc]
preproc_param_dict = {
'calib_dataset': txts,
'tokenizer': self.tokenizer,
'n_samples': self.n_samples,
'seq_len': self.seq_len
}
if self.preproc == 'txt_general_preproc':
preproc_param_dict['key'] = self.key
samples = preproc(**preproc_param_dict)
calib_model_inputs = []
if self.calib_bs < 0:
batch = torch.cat(samples, dim=0)
calib_model_inputs.append({'input_ids': batch})
elif self.calib_bs == 1:
for i in range(len(samples)):
calib_model_inputs.append({'input_ids': samples[i]})
elif self.calib_bs > 1:
for i in range(0, len(samples), self.calib_bs):
start = i
end = min(i + self.calib_bs, len(samples))
batch = samples[start:end]
batch = torch.cat(batch, dim=0)
assert not self.calib_dataset_name == 'custom_mm'
if self.calib_dataset_name == 'custom_txt':
txts = self.batch_process(
samples,
calib_or_eval='calib',
apply_chat_template=self.apply_chat_template,
return_inputs=False
)
else:
txts = self.calib_dataset
preproc = PREPROC_REGISTRY[self.preproc]
preproc_param_dict = {
'calib_dataset': txts,
'tokenizer': self.tokenizer,
'n_samples': self.n_samples,
'seq_len': self.seq_len
}
if self.preproc == 'txt_general_preproc':
preproc_param_dict['key'] = self.key
samples = preproc(**preproc_param_dict)
calib_model_inputs = []
if self.calib_bs < 0:
batch = torch.cat(samples, dim=0)
calib_model_inputs.append({'input_ids': batch})
elif self.calib_bs == 1:
for i in range(len(samples)):
calib_model_inputs.append({'input_ids': samples[i]})
elif self.calib_bs > 1:
for i in range(0, len(samples), self.calib_bs):
start = i
end = min(i + self.calib_bs, len(samples))
batch = samples[start:end]
batch = torch.cat(batch, dim=0)
calib_model_inputs.append({'input_ids': batch})
else:
assert self.calib_dataset_name == 'custom_txt' or self.calib_dataset_name == 'custom_mm'
calib_model_inputs = []
if self.calib_bs < 0:
calib_model_inputs = self.get_batch_process(samples)
return calib_model_inputs

def get_batch_process(self, samples):
calib_model_inputs = []
if self.calib_bs < 0:
calib_model_inputs.append(
self.batch_process(
samples,
calib_or_eval='calib',
apply_chat_template=self.apply_chat_template
)
)
elif self.calib_bs == 1:
calib_model_inputs = [self.batch_process([sample], calib_or_eval='calib', apply_chat_template=self.apply_chat_template) for sample in samples] # noqa
elif self.calib_bs > 1:
for i in range(0, len(samples), self.calib_bs):
start = i
end = min(i + self.calib_bs, len(samples))
batch = samples[start:end]
calib_model_inputs.append(
self.batch_process(
samples,
batch,
calib_or_eval='calib',
apply_chat_template=self.apply_chat_template
)
)
elif self.calib_bs == 1:
calib_model_inputs = [self.batch_process([sample], calib_or_eval='calib', apply_chat_template=self.apply_chat_template) for sample in samples] # noqa
elif self.calib_bs > 1:
for i in range(0, len(samples), self.calib_bs):
start = i
end = min(i + self.calib_bs, len(samples))
batch = samples[start:end]
calib_model_inputs.append(
self.batch_process(
batch,
calib_or_eval='calib',
apply_chat_template=self.apply_chat_template
)
)
return calib_model_inputs

def get_calib_dataset(self):
Expand Down
8 changes: 4 additions & 4 deletions llmc/eval/eval_acc.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@


class AccuracyEval:
def __init__(self, config, batch_size=256, num_workers=8):
def __init__(self, config):
self.eval_config = config.eval
self.imagenet_root = self.eval_config['path']
self.batch_size = batch_size
self.num_workers = num_workers
self.bs = self.eval_config['bs']
self.num_workers = self.eval_config.get('num_workers', 8)
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def load_imagenet(self):
Expand All @@ -23,7 +23,7 @@ def load_imagenet(self):
val_dataset = ImageFolder(root=self.imagenet_root, transform=val_transform)
val_loader = DataLoader(
val_dataset,
batch_size=self.batch_size,
batch_size=self.bs,
shuffle=False,
num_workers=self.num_workers,
collate_fn=lambda x: x,
Expand Down
21 changes: 11 additions & 10 deletions llmc/models/vit.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import inspect

import torch.nn as nn
from loguru import logger
from transformers import (AutoConfig, AutoProcessor, ViTForImageClassification,
from PIL import Image
from transformers import (AutoConfig, ViTForImageClassification,
ViTImageProcessor)

from llmc.utils.registry_factory import MODEL_REGISTRY
Expand Down Expand Up @@ -72,13 +70,16 @@ def get_softmax_in_block(self, block):
def __str__(self):
    """Return a human-readable dump of the wrapped model architecture."""
    return f'\nModel: \n{str(self.model)}'

def batch_process(self, imgs):
processor = AutoProcessor.from_pretrained(self.model_path)
samples = []
def batch_process(self, imgs, calib_or_eval='eval', apply_chat_template=False, return_inputs=True): # noqa
    """Load images from disk and preprocess them into a single model-input batch.

    Args:
        imgs: iterable of dicts, each carrying an ``'image'`` key that holds a
            file path to open with ``PIL.Image.open``.
        calib_or_eval: must be ``'calib'`` or ``'eval'``; only validated here,
            the two modes are not otherwise distinguished in this body.
        apply_chat_template: must be ``False`` — chat templates do not apply to
            ViT image inputs.
        return_inputs: unused in this implementation; presumably kept for
            signature parity with other models' batch_process — TODO confirm.

    Returns:
        The processor output for the whole batch (``return_tensors='pt'``).
    """
    assert calib_or_eval == 'calib' or calib_or_eval == 'eval'
    assert not apply_chat_template
    img_data_list = []
    for img in imgs:
        # Each sample references its image by path; decode lazily here.
        path = img['image']
        img_data = Image.open(path)
        img_data_list.append(img_data)
    # self.processor is expected to be the ViT image processor set up
    # elsewhere in this class — NOTE(review): not visible in this hunk, confirm.
    inputs = self.processor(images=img_data_list, return_tensors='pt')
    return inputs

def get_subsets_in_block(self, block):
return [
Expand Down

0 comments on commit 408ba31

Please sign in to comment.