Skip to content

Commit 918a847

Browse files
ElizaWszola authored and dsikka committed
Channelwise fix
Signed-off-by: ElizaWszola <ewszola@redhat.com>
1 parent 9610d4a commit 918a847

File tree

1 file changed

+4
-2
lines changed
  • vllm/model_executor/layers/quantization/kernels/mixed_precision

1 file changed

+4
-2
lines changed

vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -93,12 +93,14 @@ def transform_w_s(x):
9393

9494
if c.zero_points:
9595
# TODO figure out a more efficient way to do it
96+
grouped_k = (c.partition_weight_shape[0] //
97+
c.group_size if c.group_size != -1 else 1)
9698
self._transform_param(layer, self.w_zp_name, lambda x: \
9799
marlin_zero_points(
98100
unpack_cols(x.t(), c.weight_type.size_bits,
99-
c.partition_weight_shape[0] // c.group_size,
101+
grouped_k,
100102
c.partition_weight_shape[1]),
101-
size_k=c.partition_weight_shape[0] // c.group_size,
103+
size_k=grouped_k,
102104
size_n=c.partition_weight_shape[1],
103105
num_bits=c.weight_type.size_bits))
104106
else:

0 commit comments

Comments
 (0)