Skip to content

Commit

Permalink
fix missing self.quantize_config.seqlen (#298)
Browse files Browse the repository at this point in the history
Co-authored-by: LRL-ModelCloud <lrl@modelcloud.ai>
  • Loading branch information
LRL-ModelCloud and LRL-ModelCloud authored Jul 25, 2024
1 parent 102b65c commit 01e4c96
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions gptqmodel/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,33 +252,34 @@ def quantize(
import torch.nn.functional as F
from torch.utils.data import DataLoader

# set the nsamples/seqlen according to the actual size of the calibration_dataset.
nsamples = len(calibration_dataset)
seqlen = max_input_id_length

@torch.no_grad()
def collate_batch(batch):
    """Collate calibration samples into one right-padded batch.

    Each element of *batch* is a dict whose "input_ids" and
    "attention_mask" entries are tensors of shape (1, length); the
    leading dimension is stripped, the sequence is truncated to
    ``seqlen`` (a closure variable computed from the calibration
    dataset — NOT ``self.quantize_config.seqlen``, which may be unset;
    see issue #298), zero-padded on the right to ``seqlen``, and stacked.

    Returns a dict with stacked "input_ids" and "attention_mask"
    tensors of shape (len(batch), seqlen), or None for an empty batch.
    """
    input_ids_new = []
    attention_mask_new = []
    for text in batch:
        input_ids, attention_mask = text["input_ids"][0], text["attention_mask"][0]

        # Truncate both tensors to the shared calibration sequence length.
        input_ids = input_ids[:seqlen]
        input_ids_new.append(input_ids)

        attention_mask = attention_mask[:seqlen]
        attention_mask_new.append(attention_mask)

    if len(input_ids_new) == 0:
        return None

    # Right-pad every sequence to a uniform length so vstack succeeds.
    input_ids_new = [F.pad(t, (0, seqlen - t.size(0))) for t in input_ids_new]
    attention_mask_new = [F.pad(t, (0, seqlen - t.size(0))) for t in attention_mask_new]

    input_ids_new = torch.vstack(input_ids_new)
    attention_mask_new = torch.vstack(attention_mask_new)
    res = {"input_ids": input_ids_new, "attention_mask": attention_mask_new}
    return res

# set the nsamples/seqlen according to the actual size of the calibration_dataset.
nsamples = len(calibration_dataset)
seqlen = max_input_id_length
dataloader = DataLoader(calibration_dataset, collate_fn=collate_batch, shuffle=False, batch_size=nsamples)

self.autoround = AutoRound(self.model,
Expand Down

0 comments on commit 01e4c96

Please sign in to comment.