Skip to content

Commit bac6699

Browse files
ikawrakowKawrakow
andauthored
Quantization imrovements for k_quants (#2707)
* Improve LLaMA-2 2-, 3- and 4-bit quantization * Q3_K_S: use Q5_K for 1st 2 layers of attention.wv and feed_forward.w2 * Q4_K_S: use Q6_K for 1st 2 layers of attention.wv and feed_forward.w2 * Q2_K and Q3_K_M: use Q5_K instead of Q4_K for 1st 2 layers of attention.wv and feed_forward.w2 This leads to a slight model sized increase as follows: Q2_K : 2.684G vs 2.670G Q3_K_S: 2.775G vs 2.745G Q3_K_M: 3.071G vs 3.057G Q4_K_S: 3.592G vs 3.563G LLaMA-2 PPL for context 512 changes as follows: Q2_K : 6.6691 vs 6.8201 Q3_K_S: 6.2129 vs 6.2584 Q3_K_M: 6.0387 vs 6.1371 Q4_K_S: 5.9138 vs 6.0041 There are improvements for LLaMA-1 as well, but they are way smaller than the above. * Minor 4-bit quantization improvement For the same model size as previus commit, we get PPL = 5.9069 vs 5.9138. * Some more fine tuning * Adding make_qkx2_quants With it, we get PPL = 5.8828 for L2-7B Q4_K_S. * Another minor improvement * Q2_K improvement Smaller model, lower perplexity. 7B: file size = 2.632G, PPL = 6.3772 vs original 2.670G PPL = 6.8201 12B: file size = 5.056G, PPL = 5.4577 vs original 5.130G PPL = 5.7178 It is mostly Q3_K except for tok_embeddings, attention.wq, attention.wk, which are Q2_K * Iterating * Revert Q5_K back to make_qkx1_quants * Better Q6_K * make_qkx2_quants is better for Q5_K after all * Fix after rebasing on master * Fix for changed tensor names --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
1 parent 519c981 commit bac6699

File tree

2 files changed

+130
-58
lines changed

2 files changed

+130
-58
lines changed

k_quants.c

+110-54
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,11 @@ static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t *
7777
}
7878
return 1/iscale;
7979
}
80+
bool return_early = false;
81+
if (rmse_type < 0) {
82+
rmse_type = -rmse_type;
83+
return_early = true;
84+
}
8085
int weight_type = rmse_type%2;
8186
float sumlx = 0;
8287
float suml2 = 0;
@@ -89,56 +94,9 @@ static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t *
8994
suml2 += w*l*l;
9095
}
9196
float scale = sumlx/suml2;
97+
if (return_early) return suml2 > 0 ? 0.5f*(scale + 1/iscale) : 1/iscale;
9298
float best = scale * sumlx;
93-
for (int itry = 0; itry < 3; ++itry) {
94-
iscale = 1/scale;
95-
float slx = 0;
96-
float sl2 = 0;
97-
bool changed = false;
98-
for (int i = 0; i < n; ++i) {
99-
int l = nearest_int(iscale * x[i]);
100-
l = MAX(-nmax, MIN(nmax-1, l));
101-
if (l + nmax != L[i]) { changed = true; }
102-
float w = weight_type == 1 ? x[i] * x[i] : 1.f;
103-
slx += w*x[i]*l;
104-
sl2 += w*l*l;
105-
}
106-
if (!changed || sl2 == 0 || slx*slx <= best*sl2) { break; }
107-
for (int i = 0; i < n; ++i) {
108-
int l = nearest_int(iscale * x[i]);
109-
L[i] = nmax + MAX(-nmax, MIN(nmax-1, l));
110-
}
111-
sumlx = slx; suml2 = sl2;
112-
scale = sumlx/suml2;
113-
best = scale * sumlx;
114-
}
115-
for (int itry = 0; itry < 5; ++itry) {
116-
int n_changed = 0;
117-
for (int i = 0; i < n; ++i) {
118-
float w = weight_type == 1 ? x[i]*x[i] : 1;
119-
int l = L[i] - nmax;
120-
float slx = sumlx - w*x[i]*l;
121-
if (slx > 0) {
122-
float sl2 = suml2 - w*l*l;
123-
int new_l = nearest_int(x[i] * sl2 / slx);
124-
new_l = MAX(-nmax, MIN(nmax-1, new_l));
125-
if (new_l != l) {
126-
slx += w*x[i]*new_l;
127-
sl2 += w*new_l*new_l;
128-
if (sl2 > 0 && slx*slx*suml2 > sumlx*sumlx*sl2) {
129-
L[i] = nmax + new_l; sumlx = slx; suml2 = sl2;
130-
scale = sumlx / suml2; best = scale * sumlx;
131-
++n_changed;
132-
}
133-
}
134-
}
135-
}
136-
if (!n_changed) { break; }
137-
}
138-
if (rmse_type < 3) {
139-
return scale;
140-
}
141-
for (int is = -4; is <= 4; ++is) {
99+
for (int is = -9; is <= 9; ++is) {
142100
if (is == 0) {
143101
continue;
144102
}
@@ -221,12 +179,17 @@ static float make_q3_quants(int n, int nmax, const float * restrict x, int8_t *
221179
return 1/iscale;
222180
}
223181

224-
static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t * restrict L, float * restrict the_min, int ntry) {
182+
static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t * restrict L, float * restrict the_min,
183+
int ntry, float alpha) {
225184
float min = x[0];
226185
float max = x[0];
186+
float sum_x = 0;
187+
float sum_x2 = 0;
227188
for (int i = 1; i < n; ++i) {
228189
if (x[i] < min) min = x[i];
229190
if (x[i] > max) max = x[i];
191+
sum_x += x[i];
192+
sum_x2 += x[i]*x[i];
230193
}
231194
if (max == min) {
232195
for (int i = 0; i < n; ++i) L[i] = 0;
@@ -254,7 +217,7 @@ static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t
254217
for (int i = 0; i < n; ++i) {
255218
sum += x[i] - scale*L[i];
256219
}
257-
min = sum/n;
220+
min = alpha*min + (1 - alpha)*sum/n;
258221
if (min > 0) min = 0;
259222
iscale = 1/scale;
260223
if (!did_change) break;
@@ -263,6 +226,82 @@ static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t
263226
return scale;
264227
}
265228

229+
static float make_qkx2_quants(int n, int nmax, const float * restrict x, const float * restrict weights,
230+
uint8_t * restrict L, float * restrict the_min, uint8_t * restrict Laux,
231+
float rmin, float rdelta, int nstep, bool use_mad) {
232+
float min = x[0];
233+
float max = x[0];
234+
float sum_w = weights[0];
235+
float sum_x = sum_w * x[0];
236+
for (int i = 1; i < n; ++i) {
237+
if (x[i] < min) min = x[i];
238+
if (x[i] > max) max = x[i];
239+
float w = weights[i];
240+
sum_w += w;
241+
sum_x += w * x[i];
242+
}
243+
if (min > 0) min = 0;
244+
if (max == min) {
245+
for (int i = 0; i < n; ++i) L[i] = 0;
246+
*the_min = -min;
247+
return 0.f;
248+
}
249+
float iscale = nmax/(max - min);
250+
float scale = 1/iscale;
251+
float best_mad = 0;
252+
for (int i = 0; i < n; ++i) {
253+
int l = nearest_int(iscale*(x[i] - min));
254+
L[i] = MAX(0, MIN(nmax, l));
255+
float diff = scale * L[i] + min - x[i];
256+
diff = use_mad ? fabsf(diff) : diff * diff;
257+
float w = weights[i];
258+
best_mad += w * diff;
259+
}
260+
if (nstep < 1) {
261+
*the_min = -min;
262+
return scale;
263+
}
264+
for (int is = 0; is <= nstep; ++is) {
265+
iscale = (rmin + rdelta*is + nmax)/(max - min);
266+
float sum_l = 0, sum_l2 = 0, sum_xl = 0;
267+
for (int i = 0; i < n; ++i) {
268+
int l = nearest_int(iscale*(x[i] - min));
269+
l = MAX(0, MIN(nmax, l));
270+
Laux[i] = l;
271+
float w = weights[i];
272+
sum_l += w*l;
273+
sum_l2 += w*l*l;
274+
sum_xl += w*l*x[i];
275+
}
276+
float D = sum_w * sum_l2 - sum_l * sum_l;
277+
if (D > 0) {
278+
float this_scale = (sum_w * sum_xl - sum_x * sum_l)/D;
279+
float this_min = (sum_l2 * sum_x - sum_l * sum_xl)/D;
280+
if (this_min > 0) {
281+
this_min = 0;
282+
this_scale = sum_xl / sum_l2;
283+
}
284+
float mad = 0;
285+
for (int i = 0; i < n; ++i) {
286+
float diff = this_scale * Laux[i] + this_min - x[i];
287+
diff = use_mad ? fabsf(diff) : diff * diff;
288+
float w = weights[i];
289+
mad += w * diff;
290+
}
291+
if (mad < best_mad) {
292+
for (int i = 0; i < n; ++i) {
293+
L[i] = Laux[i];
294+
}
295+
best_mad = mad;
296+
scale = this_scale;
297+
min = this_min;
298+
}
299+
}
300+
}
301+
*the_min = -min;
302+
return scale;
303+
}
304+
266305
#if QK_K == 256
267306
static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * restrict d, uint8_t * restrict m) {
268307
if (j < 4) {
@@ -281,6 +320,8 @@ void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict
281320
const int nb = k / QK_K;
282321

283322
uint8_t L[QK_K];
323+
uint8_t Laux[16];
324+
float weights[16];
284325
float mins[QK_K/16];
285326
float scales[QK_K/16];
286327

@@ -291,7 +332,8 @@ void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict
291332
float max_scale = 0; // as we are deducting the min, scales are always positive
292333
float max_min = 0;
293334
for (int j = 0; j < QK_K/16; ++j) {
294-
scales[j] = make_qkx1_quants(16, 3, x + 16*j, L + 16*j, &mins[j], 5);
335+
for (int l = 0; l < 16; ++l) weights[l] = fabsf(x[16*j + l]);
336+
scales[j] = make_qkx2_quants(16, 3, x + 16*j, weights, L + 16*j, &mins[j], Laux, -0.5f, 0.1f, 15, true);
295337
float scale = scales[j];
296338
if (scale > max_scale) {
297339
max_scale = scale;
@@ -637,6 +679,8 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
637679
const int nb = k / QK_K;
638680

639681
uint8_t L[QK_K];
682+
uint8_t Laux[32];
683+
float weights[32];
640684
float mins[QK_K/32];
641685
float scales[QK_K/32];
642686

@@ -645,7 +689,12 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
645689
float max_scale = 0; // as we are deducting the min, scales are always positive
646690
float max_min = 0;
647691
for (int j = 0; j < QK_K/32; ++j) {
648-
scales[j] = make_qkx1_quants(32, 15, x + 32*j, L + 32*j, &mins[j], 5);
692+
//scales[j] = make_qkx1_quants(32, 15, x + 32*j, L + 32*j, &mins[j], 9, 0.5f);
693+
float sum_x2 = 0;
694+
for (int l = 0; l < 32; ++l) sum_x2 += x[32*j + l] * x[32*j + l];
695+
float av_x = sqrtf(sum_x2/32);
696+
for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]);
697+
scales[j] = make_qkx2_quants(32, 15, x + 32*j, weights, L + 32*j, &mins[j], Laux, -1.f, 0.1f, 20, false);
649698
float scale = scales[j];
650699
if (scale > max_scale) {
651700
max_scale = scale;
@@ -798,6 +847,8 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict
798847
uint8_t L[QK_K];
799848
float mins[QK_K/32];
800849
float scales[QK_K/32];
850+
float weights[32];
851+
uint8_t Laux[32];
801852
#else
802853
int8_t L[QK_K];
803854
float scales[QK_K/16];
@@ -810,7 +861,12 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict
810861
float max_scale = 0; // as we are deducting the min, scales are always positive
811862
float max_min = 0;
812863
for (int j = 0; j < QK_K/32; ++j) {
813-
scales[j] = make_qkx1_quants(32, 31, x + 32*j, L + 32*j, &mins[j], 5);
864+
//scales[j] = make_qkx1_quants(32, 31, x + 32*j, L + 32*j, &mins[j], 9, 0.5f);
865+
float sum_x2 = 0;
866+
for (int l = 0; l < 32; ++l) sum_x2 += x[32*j + l] * x[32*j + l];
867+
float av_x = sqrtf(sum_x2/32);
868+
for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]);
869+
scales[j] = make_qkx2_quants(32, 31, x + 32*j, weights, L + 32*j, &mins[j], Laux, -0.5f, 0.1f, 15, false);
814870
float scale = scales[j];
815871
if (scale > max_scale) {
816872
max_scale = scale;

llama.cpp

+20-4
Original file line numberDiff line numberDiff line change
@@ -3547,24 +3547,40 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
35473547
new_type = GGML_TYPE_Q6_K;
35483548
}
35493549
} else if (name.find("attn_v.weight") != std::string::npos) {
3550-
if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q4_K;
3550+
if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
3551+
else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) {
3552+
new_type = i_attention_wv < 2 ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K;
3553+
}
35513554
else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K;
35523555
else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) &&
35533556
use_more_bits(i_attention_wv, n_attention_wv)) new_type = GGML_TYPE_Q6_K;
3557+
else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && i_attention_wv < 4) new_type = GGML_TYPE_Q5_K;
35543558
else if (QK_K == 64 && (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S) &&
35553559
(i_attention_wv < n_attention_wv/8 || i_attention_wv >= 7*n_attention_wv/8)) new_type = GGML_TYPE_Q6_K;
35563560
++i_attention_wv;
35573561
} else if (name.find("ffn_down.weight") != std::string::npos) {
3558-
if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q4_K;
3562+
if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
3563+
else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) {
3564+
new_type = i_feed_forward_w2 < 2 ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K;
3565+
}
35593566
else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K;
35603567
else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) &&
35613568
use_more_bits(i_feed_forward_w2, n_feed_forward_w2)) new_type = GGML_TYPE_Q6_K;
3562-
//else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && i_feed_forward_w2 < n_feed_forward_w2/8) new_type = GGML_TYPE_Q6_K;
3569+
else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && i_feed_forward_w2 < 4) new_type = GGML_TYPE_Q5_K;
35633570
++i_feed_forward_w2;
35643571
} else if (name.find("attn_output.weight") != std::string::npos) {
3565-
if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q4_K;
3572+
if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K ) new_type = GGML_TYPE_Q3_K;
3573+
else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) new_type = GGML_TYPE_Q4_K;
35663574
else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K;
35673575
}
3576+
else if (name.find("ffn_gate.weight") != std::string::npos || name.find("ffn_up.weight") != std::string::npos) {
3577+
if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
3578+
}
3579+
// This can be used to reduce the size of the Q5_K_S model.
3580+
// The associated PPL increase is fully in line with the size reduction
3581+
//else {
3582+
// if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_S) new_type = GGML_TYPE_Q4_K;
3583+
//}
35683584
bool convert_incompatible_tensor = false;
35693585
if (new_type == GGML_TYPE_Q2_K || new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_Q4_K ||
35703586
new_type == GGML_TYPE_Q5_K || new_type == GGML_TYPE_Q6_K) {

0 commit comments

Comments
 (0)