Skip to content

Commit 86b06e0

Browse files
ikawrakow and Kawrakow
authored and committed
iq2_xxs: tune quantization (ggml-org#5320)
We get slightly better PPL (perplexity), and we cut quantization time nearly in half. The trick is to first quantize without forcing points onto the E8 lattice. We can then use a narrower search range around the block scale obtained that way. Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
1 parent 35b93a6 commit 86b06e0

File tree

1 file changed

+6
-52
lines changed

1 file changed

+6
-52
lines changed

ggml-quants.c

+6-52
Original file line number | Diff line number | Diff line change
@@ -9048,8 +9048,6 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict
90489048
int8_t L[32];
90499049
int8_t Laux[32];
90509050
float waux[32];
9051-
bool is_on_grid[4];
9052-
bool is_on_grid_aux[4];
90539051
uint8_t block_signs[4];
90549052
uint32_t q2[2*(QK_K/32)];
90559053

@@ -9099,10 +9097,11 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict
90999097
memset(L, 0, 32);
91009098
continue;
91019099
}
9100+
float scale = make_qp_quants(32, kMaxQ+1, xval, (uint8_t*)L, weight);
9101+
float eff_max = scale*kMaxQ;
91029102
float best = 0;
9103-
float scale = max/(2*kMaxQ-1);
9104-
for (int is = -9; is <= 9; ++is) {
9105-
float id = (2*kMaxQ-1+is*0.1f)/max;
9103+
for (int is = -6; is <= 6; ++is) {
9104+
float id = (2*kMaxQ-1+is*0.1f)/eff_max;
91069105
float this_scale = 1/id;
91079106
for (int k = 0; k < 4; ++k) {
91089107
for (int i = 0; i < 8; ++i) {
@@ -9112,9 +9111,7 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict
91129111
uint16_t u = 0;
91139112
for (int i = 0; i < 8; ++i) u |= (Laux[8*k+i] << 2*i);
91149113
int grid_index = kmap_q2xs[u];
9115-
is_on_grid_aux[k] = true;
91169114
if (grid_index < 0) {
9117-
is_on_grid_aux[k] = false;
91189115
const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
91199116
grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, this_scale, Laux + 8*k);
91209117
}
@@ -9128,16 +9125,12 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict
91289125
}
91299126
if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
91309127
scale = sumqx/sumq2; best = scale*sumqx;
9131-
for (int i = 0; i < 32; ++i) L[i] = Laux[i];
9132-
for (int k = 0; k < 4; ++k) is_on_grid[k] = is_on_grid_aux[k];
9128+
memcpy(L, Laux, 32);
91339129
}
91349130
}
9135-
int n_not_ongrid = 0;
9136-
for (int k = 0; k < 4; ++k) if (!is_on_grid[k]) ++n_not_ongrid;
9137-
if (n_not_ongrid > 0 && scale > 0) {
9131+
if (scale > 0) {
91389132
float id = 1/scale;
91399133
for (int k = 0; k < 4; ++k) {
9140-
if (is_on_grid[k]) continue;
91419134
uint16_t u = 0;
91429135
for (int i = 0; i < 8; ++i) {
91439136
int l = nearest_int(0.5f*(id*xval[8*k+i]-1));
@@ -9193,49 +9186,10 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict
91939186
float d = max_scale/31;
91949187
y[ibl].d = GGML_FP32_TO_FP16(d);
91959188
float id = 1/d;
9196-
float sumqx = 0, sumq2 = 0;
91979189
for (int ib = 0; ib < QK_K/32; ++ib) {
91989190
int l = nearest_int(0.5f*(id*scales[ib]-1));
91999191
l = MAX(0, MIN(15, l));
92009192
q2[2*ib+1] |= ((uint32_t)l << 28);
9201-
const float * xb = xbl + 32*ib;
9202-
const float * qw = quant_weights + QK_K*ibl + 32*ib;
9203-
for (int i = 0; i < 32; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
9204-
const uint8_t * aux8 = (const uint8_t *)(q2 + 2*ib);
9205-
const float db = d * (1 + 2*l);
9206-
uint32_t u = 0;
9207-
for (int k = 0; k < 4; ++k) {
9208-
const int8_t * signs = keven_signs_q2xs + 8*((q2[2*ib+1] >> 7*k) & 127);
9209-
const float * xk = xb + 8*k;
9210-
const float * wk = weight + 8*k;
9211-
const uint8_t * grid = (const uint8_t *)(kgrid_q2xs + aux8[k]);
9212-
float best_mse = 0; int best_index = aux8[k];
9213-
for (int j = 0; j < 8; ++j) {
9214-
float diff = db * grid[j] * signs[j] - xk[j];
9215-
best_mse += wk[j] * diff * diff;
9216-
}
9217-
for (int idx = 0; idx < 256; ++idx) {
9218-
grid = (const uint8_t *)(kgrid_q2xs + idx);
9219-
float mse = 0;
9220-
for (int j = 0; j < 8; ++j) {
9221-
float diff = db * grid[j] * signs[j] - xk[j];
9222-
mse += wk[j] * diff * diff;
9223-
}
9224-
if (mse < best_mse) {
9225-
best_mse = mse; best_index = idx;
9226-
}
9227-
}
9228-
u |= (best_index << 8*k);
9229-
grid = (const uint8_t *)(kgrid_q2xs + best_index);
9230-
//grid = (const uint8_t *)(kgrid_q2xs + aux8[k]);
9231-
for (int j = 0; j < 8; ++j) {
9232-
float q = db * grid[j] * signs[j];
9233-
sumqx += wk[j] * q * xk[j];
9234-
sumq2 += wk[j] * q * q;
9235-
}
9236-
}
9237-
q2[2*ib] = u;
9238-
if (sumq2 > 0) y[ibl].d = GGML_FP32_TO_FP16(d*sumqx/sumq2);
92399193
}
92409194
memcpy(y[ibl].qs, q2, QK_K/4);
92419195
}

0 commit comments

Comments
 (0)