Skip to content

Commit

Permalink
fix(qbits): correct stride when packing
Browse files Browse the repository at this point in the history
  • Loading branch information
dacorvo committed Jul 25, 2024
1 parent 67490c9 commit 1fb0aac
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions optimum/quanto/tensor/qbits/qbits.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,14 @@ def dequantize(self):

@staticmethod
def load_from_state_dict(state_dict, prefix, qtype, axis, group_size, size, stride):
data_size = size if group_size is None else grouped_shape(size, axis, group_size)
# In row major, inner dimension (stride 1) is the last one
data_stride = (data_size[1], 1)
if group_size is None:
data_size = size
data_stride = stride
else:
data_size = grouped_shape(size, axis, group_size)
assert len(data_size) == 2
# In row major, inner dimension (stride 1) is the last one
data_stride = (data_size[1], 1)
inner_tensors_dict = {
"_data": PackedTensor.load_from_state_dict(
state_dict, prefix + "_data.", qtype.bits, data_size, data_stride
Expand Down

0 comments on commit 1fb0aac

Please sign in to comment.