@@ -77,6 +77,11 @@ static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t *
77
77
}
78
78
return 1 /iscale ;
79
79
}
80
+ bool return_early = false;
81
+ if (rmse_type < 0 ) {
82
+ rmse_type = - rmse_type ;
83
+ return_early = true;
84
+ }
80
85
int weight_type = rmse_type %2 ;
81
86
float sumlx = 0 ;
82
87
float suml2 = 0 ;
@@ -89,56 +94,9 @@ static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t *
89
94
suml2 += w * l * l ;
90
95
}
91
96
float scale = sumlx /suml2 ;
97
+ if (return_early ) return suml2 > 0 ? 0.5f * (scale + 1 /iscale ) : 1 /iscale ;
92
98
float best = scale * sumlx ;
93
- for (int itry = 0 ; itry < 3 ; ++ itry ) {
94
- iscale = 1 /scale ;
95
- float slx = 0 ;
96
- float sl2 = 0 ;
97
- bool changed = false;
98
- for (int i = 0 ; i < n ; ++ i ) {
99
- int l = nearest_int (iscale * x [i ]);
100
- l = MAX (- nmax , MIN (nmax - 1 , l ));
101
- if (l + nmax != L [i ]) { changed = true; }
102
- float w = weight_type == 1 ? x [i ] * x [i ] : 1.f ;
103
- slx += w * x [i ]* l ;
104
- sl2 += w * l * l ;
105
- }
106
- if (!changed || sl2 == 0 || slx * slx <= best * sl2 ) { break ; }
107
- for (int i = 0 ; i < n ; ++ i ) {
108
- int l = nearest_int (iscale * x [i ]);
109
- L [i ] = nmax + MAX (- nmax , MIN (nmax - 1 , l ));
110
- }
111
- sumlx = slx ; suml2 = sl2 ;
112
- scale = sumlx /suml2 ;
113
- best = scale * sumlx ;
114
- }
115
- for (int itry = 0 ; itry < 5 ; ++ itry ) {
116
- int n_changed = 0 ;
117
- for (int i = 0 ; i < n ; ++ i ) {
118
- float w = weight_type == 1 ? x [i ]* x [i ] : 1 ;
119
- int l = L [i ] - nmax ;
120
- float slx = sumlx - w * x [i ]* l ;
121
- if (slx > 0 ) {
122
- float sl2 = suml2 - w * l * l ;
123
- int new_l = nearest_int (x [i ] * sl2 / slx );
124
- new_l = MAX (- nmax , MIN (nmax - 1 , new_l ));
125
- if (new_l != l ) {
126
- slx += w * x [i ]* new_l ;
127
- sl2 += w * new_l * new_l ;
128
- if (sl2 > 0 && slx * slx * suml2 > sumlx * sumlx * sl2 ) {
129
- L [i ] = nmax + new_l ; sumlx = slx ; suml2 = sl2 ;
130
- scale = sumlx / suml2 ; best = scale * sumlx ;
131
- ++ n_changed ;
132
- }
133
- }
134
- }
135
- }
136
- if (!n_changed ) { break ; }
137
- }
138
- if (rmse_type < 3 ) {
139
- return scale ;
140
- }
141
- for (int is = -4 ; is <= 4 ; ++ is ) {
99
+ for (int is = -9 ; is <= 9 ; ++ is ) {
142
100
if (is == 0 ) {
143
101
continue ;
144
102
}
@@ -221,12 +179,17 @@ static float make_q3_quants(int n, int nmax, const float * restrict x, int8_t *
221
179
return 1 /iscale ;
222
180
}
223
181
224
- static float make_qkx1_quants (int n , int nmax , const float * restrict x , uint8_t * restrict L , float * restrict the_min , int ntry ) {
182
+ static float make_qkx1_quants (int n , int nmax , const float * restrict x , uint8_t * restrict L , float * restrict the_min ,
183
+ int ntry , float alpha ) {
225
184
float min = x [0 ];
226
185
float max = x [0 ];
186
+ float sum_x = 0 ;
187
+ float sum_x2 = 0 ;
227
188
for (int i = 1 ; i < n ; ++ i ) {
228
189
if (x [i ] < min ) min = x [i ];
229
190
if (x [i ] > max ) max = x [i ];
191
+ sum_x += x [i ];
192
+ sum_x2 += x [i ]* x [i ];
230
193
}
231
194
if (max == min ) {
232
195
for (int i = 0 ; i < n ; ++ i ) L [i ] = 0 ;
@@ -254,7 +217,7 @@ static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t
254
217
for (int i = 0 ; i < n ; ++ i ) {
255
218
sum += x [i ] - scale * L [i ];
256
219
}
257
- min = sum /n ;
220
+ min = alpha * min + ( 1 - alpha ) * sum /n ;
258
221
if (min > 0 ) min = 0 ;
259
222
iscale = 1 /scale ;
260
223
if (!did_change ) break ;
@@ -263,6 +226,82 @@ static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t
263
226
return scale ;
264
227
}
265
228
229
/*
 * Quantize n values to integer levels in [0, nmax] under the affine model
 *     x[i] ~= scale * L[i] + min,   with min clamped to be <= 0.
 *
 * A first guess spreads the levels evenly over [min, max]. When nstep >= 1 a
 * grid of nstep + 1 candidate inverse scales,
 *     (rmin + rdelta*is + nmax) / (max - min),   is = 0 .. nstep,
 * is then tried. For each candidate the levels are recomputed into Laux, the
 * scale and min are refit in closed form (weighted linear least squares), and
 * the candidate replaces the current solution only if its weighted error is
 * strictly smaller.
 *
 *   n, nmax   number of values / largest level
 *   x         input values (at least n entries; x[0] is read unconditionally)
 *   weights   per-value weights used in the fit and in the error measure
 *   L         output levels, n entries
 *   the_min   receives the NEGATED offset, i.e. -min (so it is >= 0)
 *   Laux      scratch buffer for candidate levels, n entries
 *   use_mad   true: weighted absolute error, false: weighted squared error
 *
 * Returns the scale; 0 when max equals the clamped min (all levels are then 0).
 */
static float make_qkx2_quants(int n, int nmax, const float * restrict x, const float * restrict weights,
        uint8_t * restrict L, float * restrict the_min, uint8_t * restrict Laux,
        float rmin, float rdelta, int nstep, bool use_mad) {
    /* Range of the data plus the weighted sums that do not depend on the levels. */
    float lo = x[0];
    float hi = x[0];
    float w_total  = weights[0];
    float wx_total = w_total * x[0];
    for (int k = 1; k < n; ++k) {
        if (x[k] < lo) { lo = x[k]; }
        if (x[k] > hi) { hi = x[k]; }
        const float wk = weights[k];
        w_total  += wk;
        wx_total += wk * x[k];
    }

    /* The offset is never allowed to be positive. */
    if (lo > 0) { lo = 0; }

    /* Degenerate range: nothing to scale. */
    if (hi == lo) {
        for (int k = 0; k < n; ++k) { L[k] = 0; }
        *the_min = -lo;
        return 0.f;
    }

    /* Baseline: levels spread evenly across [lo, hi]. */
    float inv_step = nmax/(hi - lo);
    float step     = 1/inv_step;
    float best_err = 0;
    for (int k = 0; k < n; ++k) {
        int q = nearest_int(inv_step*(x[k] - lo));
        L[k] = MAX(0, MIN(nmax, q));
        /* Read the level back from L so the error sees the stored 8-bit value. */
        float err = step * L[k] + lo - x[k];
        if (use_mad) {
            err = fabsf(err);
        } else {
            err = err * err;
        }
        const float wk = weights[k];
        best_err += wk * err;
    }

    /* With nstep < 1 the baseline is the answer; otherwise scan the grid. */
    if (nstep >= 1) {
        for (int t = 0; t <= nstep; ++t) {
            /* Note: lo may already have been refined by an earlier accepted candidate. */
            inv_step = (rmin + rdelta*t + nmax)/(hi - lo);

            float acc_l = 0, acc_l2 = 0, acc_xl = 0;
            for (int k = 0; k < n; ++k) {
                int q = nearest_int(inv_step*(x[k] - lo));
                q = MAX(0, MIN(nmax, q));
                Laux[k] = q;
                const float wk = weights[k];
                acc_l  += wk * q;
                acc_l2 += wk * q * q;
                acc_xl += wk * q * x[k];
            }

            /* Determinant of the 2x2 normal equations; skip singular fits. */
            const float det = w_total * acc_l2 - acc_l * acc_l;
            if (!(det > 0)) {
                continue;
            }

            float cand_scale = (w_total * acc_xl - wx_total * acc_l)/det;
            float cand_min   = (acc_l2 * wx_total - acc_l * acc_xl)/det;
            if (cand_min > 0) {
                /* Offset pinned at zero: refit the scale alone. */
                cand_min   = 0;
                cand_scale = acc_xl / acc_l2;
            }

            float cand_err = 0;
            for (int k = 0; k < n; ++k) {
                float err = cand_scale * Laux[k] + cand_min - x[k];
                if (use_mad) {
                    err = fabsf(err);
                } else {
                    err = err * err;
                }
                const float wk = weights[k];
                cand_err += wk * err;
            }

            /* Keep the candidate only on a strict improvement. */
            if (!(cand_err < best_err)) {
                continue;
            }
            for (int k = 0; k < n; ++k) {
                L[k] = Laux[k];
            }
            best_err = cand_err;
            step     = cand_scale;
            lo       = cand_min;
        }
    }

    *the_min = -lo;
    return step;
}
304
+
266
305
#if QK_K == 256
267
306
static inline void get_scale_min_k4 (int j , const uint8_t * restrict q , uint8_t * restrict d , uint8_t * restrict m ) {
268
307
if (j < 4 ) {
@@ -281,6 +320,8 @@ void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict
281
320
const int nb = k / QK_K ;
282
321
283
322
uint8_t L [QK_K ];
323
+ uint8_t Laux [16 ];
324
+ float weights [16 ];
284
325
float mins [QK_K /16 ];
285
326
float scales [QK_K /16 ];
286
327
@@ -291,7 +332,8 @@ void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict
291
332
float max_scale = 0 ; // as we are deducting the min, scales are always positive
292
333
float max_min = 0 ;
293
334
for (int j = 0 ; j < QK_K /16 ; ++ j ) {
294
- scales [j ] = make_qkx1_quants (16 , 3 , x + 16 * j , L + 16 * j , & mins [j ], 5 );
335
+ for (int l = 0 ; l < 16 ; ++ l ) weights [l ] = fabsf (x [16 * j + l ]);
336
+ scales [j ] = make_qkx2_quants (16 , 3 , x + 16 * j , weights , L + 16 * j , & mins [j ], Laux , -0.5f , 0.1f , 15 , true);
295
337
float scale = scales [j ];
296
338
if (scale > max_scale ) {
297
339
max_scale = scale ;
@@ -637,6 +679,8 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
637
679
const int nb = k / QK_K ;
638
680
639
681
uint8_t L [QK_K ];
682
+ uint8_t Laux [32 ];
683
+ float weights [32 ];
640
684
float mins [QK_K /32 ];
641
685
float scales [QK_K /32 ];
642
686
@@ -645,7 +689,12 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
645
689
float max_scale = 0 ; // as we are deducting the min, scales are always positive
646
690
float max_min = 0 ;
647
691
for (int j = 0 ; j < QK_K /32 ; ++ j ) {
648
- scales [j ] = make_qkx1_quants (32 , 15 , x + 32 * j , L + 32 * j , & mins [j ], 5 );
692
+ //scales[j] = make_qkx1_quants(32, 15, x + 32*j, L + 32*j, &mins[j], 9, 0.5f);
693
+ float sum_x2 = 0 ;
694
+ for (int l = 0 ; l < 32 ; ++ l ) sum_x2 += x [32 * j + l ] * x [32 * j + l ];
695
+ float av_x = sqrtf (sum_x2 /32 );
696
+ for (int l = 0 ; l < 32 ; ++ l ) weights [l ] = av_x + fabsf (x [32 * j + l ]);
697
+ scales [j ] = make_qkx2_quants (32 , 15 , x + 32 * j , weights , L + 32 * j , & mins [j ], Laux , -1.f , 0.1f , 20 , false);
649
698
float scale = scales [j ];
650
699
if (scale > max_scale ) {
651
700
max_scale = scale ;
@@ -798,6 +847,8 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict
798
847
uint8_t L [QK_K ];
799
848
float mins [QK_K /32 ];
800
849
float scales [QK_K /32 ];
850
+ float weights [32 ];
851
+ uint8_t Laux [32 ];
801
852
#else
802
853
int8_t L [QK_K ];
803
854
float scales [QK_K /16 ];
@@ -810,7 +861,12 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict
810
861
float max_scale = 0 ; // as we are deducting the min, scales are always positive
811
862
float max_min = 0 ;
812
863
for (int j = 0 ; j < QK_K /32 ; ++ j ) {
813
- scales [j ] = make_qkx1_quants (32 , 31 , x + 32 * j , L + 32 * j , & mins [j ], 5 );
864
+ //scales[j] = make_qkx1_quants(32, 31, x + 32*j, L + 32*j, &mins[j], 9, 0.5f);
865
+ float sum_x2 = 0 ;
866
+ for (int l = 0 ; l < 32 ; ++ l ) sum_x2 += x [32 * j + l ] * x [32 * j + l ];
867
+ float av_x = sqrtf (sum_x2 /32 );
868
+ for (int l = 0 ; l < 32 ; ++ l ) weights [l ] = av_x + fabsf (x [32 * j + l ]);
869
+ scales [j ] = make_qkx2_quants (32 , 31 , x + 32 * j , weights , L + 32 * j , & mins [j ], Laux , -0.5f , 0.1f , 15 , false);
814
870
float scale = scales [j ];
815
871
if (scale > max_scale ) {
816
872
max_scale = scale ;
0 commit comments