@@ -11992,6 +11992,7 @@ static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy
11992
11992
float sumqx[4], sumq2[4];
11993
11993
11994
11994
iq1m_scale_t s;
11995
+ const float * xx;
11995
11996
11996
11997
for (int ibl = 0; ibl < nbl; ++ibl) {
11997
11998
@@ -12126,7 +12127,6 @@ static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy
12126
12127
scale = -scale;
12127
12128
best_k = best_k == 0 ? 3 : best_k == 1 ? 2 : best_k == 2 ? 1 : 0;
12128
12129
}
12129
- const float * xx;
12130
12130
bool all_on_grid = true;
12131
12131
for (int k = 0; k < block_size/8; ++k) {
12132
12132
if (k == 0) xx = best_k < 2 ? x_p : x_m;
@@ -12173,13 +12173,33 @@ static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy
12173
12173
uint16_t * sc = (uint16_t *)y[ibl].scales;
12174
12174
float d = max_scale/15;
12175
12175
float id = 1/d;
12176
+ float sumqx_f = 0, sumq2_f = 0;
12176
12177
for (int ib = 0; ib < QK_K/block_size; ++ib) {
12177
12178
int l = nearest_int(0.5f*(id*scales[ib+0]-1));
12178
12179
l = MAX(0, MIN(7, l));
12179
12180
sc[ib/4] |= (l << 3*(ib%4));
12180
12181
y[ibl].qh[ib] |= masks[shifts[ib]];
12182
+ const float * xb = xbl + block_size*ib;
12183
+ if (quant_weights) {
12184
+ const float * qw = quant_weights + QK_K*ibl + block_size*ib;
12185
+ for (int i = 0; i < block_size; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
12186
+ } else {
12187
+ for (int i = 0; i < block_size; ++i) weight[i] = xb[i]*xb[i];
12188
+ }
12189
+ for (int k = 0; k < block_size/8; ++k) {
12190
+ if (k == 0) xx = shifts[ib] < 2 ? x_p : x_m;
12191
+ else xx = shifts[ib]%2 == 0 ? x_p : x_m;
12192
+ const int8_t * pg = (const int8_t *)(kgrid_q2xs + y[ibl].qs[2*ib+k] + ((y[ibl].qh[ib] << (8 - 4*k)) & 0x700));
12193
+ for (int j = 0; j < 8; ++j) {
12194
+ float w = weight[8*k + j];
12195
+ float q = xx[(pg[j] - 1)/2]*(2*l+1);
12196
+ sumqx_f += w*q*xb[8*k+j];
12197
+ sumq2_f += w*q*q;
12198
+ }
12199
+ }
12181
12200
}
12182
- s.fp16 = GGML_FP32_TO_FP16(d*1.125f); // 1.125f is another fudge factor. Don't ask me why it is needed.
12201
+ if (sumq2_f > 0) d = sumqx_f/sumq2_f;
12202
+ s.fp16 = GGML_FP32_TO_FP16(d*1.1125f); // 1.1125f is another fudge factor. Don't ask me why it is needed.
12183
12203
sc[0] |= ((s.u16 & 0x000f) << 12);
12184
12204
sc[1] |= ((s.u16 & 0x00f0) << 8);
12185
12205
sc[2] |= ((s.u16 & 0x0f00) << 4);
0 commit comments