@@ -9048,8 +9048,6 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict
9048
9048
int8_t L [32 ];
9049
9049
int8_t Laux [32 ];
9050
9050
float waux [32 ];
9051
- bool is_on_grid [4 ];
9052
- bool is_on_grid_aux [4 ];
9053
9051
uint8_t block_signs [4 ];
9054
9052
uint32_t q2 [2 * (QK_K /32 )];
9055
9053
@@ -9099,10 +9097,11 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict
9099
9097
memset (L , 0 , 32 );
9100
9098
continue ;
9101
9099
}
9100
+ float scale = make_qp_quants (32 , kMaxQ + 1 , xval , (uint8_t * )L , weight );
9101
+ float eff_max = scale * kMaxQ ;
9102
9102
float best = 0 ;
9103
- float scale = max /(2 * kMaxQ - 1 );
9104
- for (int is = -9 ; is <= 9 ; ++ is ) {
9105
- float id = (2 * kMaxQ - 1 + is * 0.1f )/max ;
9103
+ for (int is = -6 ; is <= 6 ; ++ is ) {
9104
+ float id = (2 * kMaxQ - 1 + is * 0.1f )/eff_max ;
9106
9105
float this_scale = 1 /id ;
9107
9106
for (int k = 0 ; k < 4 ; ++ k ) {
9108
9107
for (int i = 0 ; i < 8 ; ++ i ) {
@@ -9112,9 +9111,7 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict
9112
9111
uint16_t u = 0 ;
9113
9112
for (int i = 0 ; i < 8 ; ++ i ) u |= (Laux [8 * k + i ] << 2 * i );
9114
9113
int grid_index = kmap_q2xs [u ];
9115
- is_on_grid_aux [k ] = true;
9116
9114
if (grid_index < 0 ) {
9117
- is_on_grid_aux [k ] = false;
9118
9115
const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs [u ] - 1 ;
9119
9116
grid_index = iq2_find_best_neighbour (neighbours , kgrid_q2xs , xval + 8 * k , waux + 8 * k , this_scale , Laux + 8 * k );
9120
9117
}
@@ -9128,16 +9125,12 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict
9128
9125
}
9129
9126
if (sumq2 > 0 && sumqx * sumqx > best * sumq2 ) {
9130
9127
scale = sumqx /sumq2 ; best = scale * sumqx ;
9131
- for (int i = 0 ; i < 32 ; ++ i ) L [i ] = Laux [i ];
9132
- for (int k = 0 ; k < 4 ; ++ k ) is_on_grid [k ] = is_on_grid_aux [k ];
9128
+ memcpy (L , Laux , 32 );
9133
9129
}
9134
9130
}
9135
- int n_not_ongrid = 0 ;
9136
- for (int k = 0 ; k < 4 ; ++ k ) if (!is_on_grid [k ]) ++ n_not_ongrid ;
9137
- if (n_not_ongrid > 0 && scale > 0 ) {
9131
+ if (scale > 0 ) {
9138
9132
float id = 1 /scale ;
9139
9133
for (int k = 0 ; k < 4 ; ++ k ) {
9140
- if (is_on_grid [k ]) continue ;
9141
9134
uint16_t u = 0 ;
9142
9135
for (int i = 0 ; i < 8 ; ++ i ) {
9143
9136
int l = nearest_int (0.5f * (id * xval [8 * k + i ]- 1 ));
@@ -9193,49 +9186,10 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict
9193
9186
float d = max_scale /31 ;
9194
9187
y [ibl ].d = GGML_FP32_TO_FP16 (d );
9195
9188
float id = 1 /d ;
9196
- float sumqx = 0 , sumq2 = 0 ;
9197
9189
for (int ib = 0 ; ib < QK_K /32 ; ++ ib ) {
9198
9190
int l = nearest_int (0.5f * (id * scales [ib ]- 1 ));
9199
9191
l = MAX (0 , MIN (15 , l ));
9200
9192
q2 [2 * ib + 1 ] |= ((uint32_t )l << 28 );
9201
- const float * xb = xbl + 32 * ib ;
9202
- const float * qw = quant_weights + QK_K * ibl + 32 * ib ;
9203
- for (int i = 0 ; i < 32 ; ++ i ) weight [i ] = qw [i ] * sqrtf (sigma2 + xb [i ]* xb [i ]);
9204
- const uint8_t * aux8 = (const uint8_t * )(q2 + 2 * ib );
9205
- const float db = d * (1 + 2 * l );
9206
- uint32_t u = 0 ;
9207
- for (int k = 0 ; k < 4 ; ++ k ) {
9208
- const int8_t * signs = keven_signs_q2xs + 8 * ((q2 [2 * ib + 1 ] >> 7 * k ) & 127 );
9209
- const float * xk = xb + 8 * k ;
9210
- const float * wk = weight + 8 * k ;
9211
- const uint8_t * grid = (const uint8_t * )(kgrid_q2xs + aux8 [k ]);
9212
- float best_mse = 0 ; int best_index = aux8 [k ];
9213
- for (int j = 0 ; j < 8 ; ++ j ) {
9214
- float diff = db * grid [j ] * signs [j ] - xk [j ];
9215
- best_mse += wk [j ] * diff * diff ;
9216
- }
9217
- for (int idx = 0 ; idx < 256 ; ++ idx ) {
9218
- grid = (const uint8_t * )(kgrid_q2xs + idx );
9219
- float mse = 0 ;
9220
- for (int j = 0 ; j < 8 ; ++ j ) {
9221
- float diff = db * grid [j ] * signs [j ] - xk [j ];
9222
- mse += wk [j ] * diff * diff ;
9223
- }
9224
- if (mse < best_mse ) {
9225
- best_mse = mse ; best_index = idx ;
9226
- }
9227
- }
9228
- u |= (best_index << 8 * k );
9229
- grid = (const uint8_t * )(kgrid_q2xs + best_index );
9230
- //grid = (const uint8_t *)(kgrid_q2xs + aux8[k]);
9231
- for (int j = 0 ; j < 8 ; ++ j ) {
9232
- float q = db * grid [j ] * signs [j ];
9233
- sumqx += wk [j ] * q * xk [j ];
9234
- sumq2 += wk [j ] * q * q ;
9235
- }
9236
- }
9237
- q2 [2 * ib ] = u ;
9238
- if (sumq2 > 0 ) y [ibl ].d = GGML_FP32_TO_FP16 (d * sumqx /sumq2 );
9239
9193
}
9240
9194
memcpy (y [ibl ].qs , q2 , QK_K /4 );
9241
9195
}
0 commit comments