@@ -101,7 +101,13 @@ def create_weights(self, layer: torch.nn.Module,
                        output_partition_sizes: list[int], input_size: int,
                        output_size: int, params_dtype: torch.dtype,
                        **extra_weight_attrs):
-        if input_size_per_partition % self.quant_config.group_size != 0:
+        # Normalize group_size
+        if self.quant_config.group_size != -1:
+            group_size = self.quant_config.group_size
+        else:
+            group_size = input_size
+
+        if input_size_per_partition % group_size != 0:
             raise ValueError(
                 "The input size is not aligned with the quantized "
                 "weight shape. This can be caused by too large "
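
For context, GPTQ-style checkpoints conventionally use group_size == -1 to mean a single quantization group spanning the entire input dimension (per-column scales and zero points), which is what the normalization above encodes. A minimal standalone sketch of that rule; the helper name normalize_group_size and the sample size are illustrative, not part of the patch:

def normalize_group_size(group_size: int, input_size: int) -> int:
    # -1 is the conventional marker for one group covering the
    # whole input dimension (per-column quantization).
    return input_size if group_size == -1 else group_size

input_size = 4096                                     # illustrative
assert normalize_group_size(128, input_size) == 128   # 4096 / 128 = 32 groups
assert normalize_group_size(-1, input_size) == 4096   # a single group
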
@@ -127,9 +133,11 @@ def create_weights(self, layer: torch.nn.Module,
             packed_factor=self.quant_config.pack_factor,
             weight_loader=weight_loader)
 
+        num_groups = input_size_per_partition // group_size
+
         qzeros = PackedvLLMParameter(
             data=torch.empty(
-                input_size_per_partition // self.quant_config.group_size,
+                num_groups,
                 output_size_per_partition // self.quant_config.pack_factor,
                 dtype=torch.int32,
             ),
@@ -140,7 +148,7 @@ def create_weights(self, layer: torch.nn.Module,
             weight_loader=weight_loader)
 
         scales = GroupQuantScaleParameter(data=torch.empty(
-            input_size_per_partition // self.quant_config.group_size,
+            num_groups,
             output_size_per_partition,
             dtype=params_dtype,
         ),
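
After normalization, num_groups = input_size_per_partition // group_size becomes the leading dimension of both qzeros and scales; in the group_size == -1 case it collapses to a single row (assuming the input dimension is not sharded, so input_size_per_partition == input_size). A rough shape check under those assumptions, with all sizes and a pack_factor of 8 (4-bit weights packed into int32) chosen for illustration:

import torch

input_size_per_partition = 4096    # illustrative; equals input_size here
output_size_per_partition = 11008  # illustrative
group_size = 4096                  # normalized from -1
pack_factor = 8                    # 32-bit storage / 4-bit weights

num_groups = input_size_per_partition // group_size  # 1
qzeros = torch.empty(num_groups,
                     output_size_per_partition // pack_factor,
                     dtype=torch.int32)
scales = torch.empty(num_groups, output_size_per_partition,
                     dtype=torch.float16)
print(qzeros.shape, scales.shape)  # [1, 1376] and [1, 11008]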