Skip to content

Commit 3e6e70d

Browse files
authored
Add enum llama_ftype, sync ggml_type to model files (#709)
1 parent 2663d2c commit 3e6e70d

File tree

5 files changed

+74
-57
lines changed

5 files changed

+74
-57
lines changed

examples/quantize/quantize.cpp

+5-5
Original file line number | Diff line number | Diff line change
@@ -5,15 +5,15 @@
55
#include <string>
66

77
// usage:
8-
// ./llama-quantize models/llama/ggml-model.bin models/llama/ggml-model-quant.bin type
8+
// ./quantize models/llama/ggml-model.bin models/llama/ggml-model-quant.bin type
99
//
1010
int main(int argc, char ** argv) {
1111
ggml_time_init();
1212

1313
if (argc != 4) {
1414
fprintf(stderr, "usage: %s model-f32.bin model-quant.bin type\n", argv[0]);
15-
fprintf(stderr, " type = 2 - q4_0\n");
16-
fprintf(stderr, " type = 3 - q4_1\n");
15+
fprintf(stderr, " type = %d - q4_0\n", LLAMA_FTYPE_MOSTLY_Q4_0);
16+
fprintf(stderr, " type = %d - q4_1\n", LLAMA_FTYPE_MOSTLY_Q4_1);
1717
return 1;
1818
}
1919

@@ -27,7 +27,7 @@ int main(int argc, char ** argv) {
2727
const std::string fname_inp = argv[1];
2828
const std::string fname_out = argv[2];
2929

30-
const int itype = atoi(argv[3]);
30+
const enum llama_ftype ftype = (enum llama_ftype)atoi(argv[3]);
3131

3232
const int64_t t_main_start_us = ggml_time_us();
3333

@@ -37,7 +37,7 @@ int main(int argc, char ** argv) {
3737
{
3838
const int64_t t_start_us = ggml_time_us();
3939

40-
if (llama_model_quantize(fname_inp.c_str(), fname_out.c_str(), itype)) {
40+
if (llama_model_quantize(fname_inp.c_str(), fname_out.c_str(), ftype)) {
4141
fprintf(stderr, "%s: failed to quantize model from '%s'\n", __func__, fname_inp.c_str());
4242
return 1;
4343
}

ggml.c

+16-19
Original file line number | Diff line number | Diff line change
@@ -2560,29 +2560,26 @@ inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x
25602560
//
25612561

25622562
static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
2563-
QK,
2564-
QK,
2565-
1,
2566-
1,
2567-
1,
2568-
1,
2569-
1,
2563+
[GGML_TYPE_F32] = 1,
2564+
[GGML_TYPE_F16] = 1,
2565+
[GGML_TYPE_Q4_0] = QK,
2566+
[GGML_TYPE_Q4_1] = QK,
2567+
[GGML_TYPE_I8] = 1,
2568+
[GGML_TYPE_I16] = 1,
2569+
[GGML_TYPE_I32] = 1,
25702570
};
2571-
2572-
static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_COUNT != 5");
2571+
static_assert(GGML_TYPE_COUNT == 7, "GGML_BLCK_SIZE is outdated");
25732572

25742573
static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
2575-
sizeof(block_q4_0),
2576-
sizeof(block_q4_1),
2577-
sizeof(int8_t ),
2578-
sizeof(int16_t),
2579-
sizeof(int32_t),
2580-
sizeof(ggml_fp16_t),
2581-
sizeof(float ),
2574+
[GGML_TYPE_F32] = sizeof(float),
2575+
[GGML_TYPE_F16] = sizeof(ggml_fp16_t),
2576+
[GGML_TYPE_Q4_0] = sizeof(block_q4_0),
2577+
[GGML_TYPE_Q4_1] = sizeof(block_q4_1),
2578+
[GGML_TYPE_I8] = sizeof(int8_t),
2579+
[GGML_TYPE_I16] = sizeof(int16_t),
2580+
[GGML_TYPE_I32] = sizeof(int32_t),
25822581
};
2583-
2584-
// don't forget to update the array above when adding new types
2585-
static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_COUNT != 5");
2582+
static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_SIZE is outdated");
25862583

25872584
static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
25882585
"NONE",

ggml.h

+5-4
Original file line number | Diff line number | Diff line change
@@ -198,13 +198,14 @@ struct ggml_object;
198198
struct ggml_context;
199199

200200
enum ggml_type {
201-
GGML_TYPE_Q4_0,
202-
GGML_TYPE_Q4_1,
201+
// explicitly numbered values are used in llama.cpp files
202+
GGML_TYPE_F32 = 0,
203+
GGML_TYPE_F16 = 1,
204+
GGML_TYPE_Q4_0 = 2,
205+
GGML_TYPE_Q4_1 = 3,
203206
GGML_TYPE_I8,
204207
GGML_TYPE_I16,
205208
GGML_TYPE_I32,
206-
GGML_TYPE_F16,
207-
GGML_TYPE_F32,
208209
GGML_TYPE_COUNT,
209210
};
210211

llama.cpp

+39-28
Original file line number | Diff line number | Diff line change
@@ -82,7 +82,7 @@ struct llama_hparams {
8282
uint32_t n_head = 32;
8383
uint32_t n_layer = 32;
8484
uint32_t n_rot = 64;
85-
uint32_t f16 = 1;
85+
enum llama_ftype ftype = LLAMA_FTYPE_MOSTLY_F16;
8686

8787
bool operator!=(const llama_hparams & other) const {
8888
return memcmp(this, &other, sizeof(llama_hparams));
@@ -432,7 +432,7 @@ struct llama_file_loader {
432432
hparams.n_head = file.read_u32();
433433
hparams.n_layer = file.read_u32();
434434
hparams.n_rot = file.read_u32();
435-
hparams.f16 = file.read_u32();
435+
hparams.ftype = (enum llama_ftype) file.read_u32();
436436
}
437437
void read_vocab() {
438438
vocab.id_to_token.resize(hparams.n_vocab);
@@ -458,20 +458,21 @@ struct llama_file_loader {
458458
llama_load_tensor_shard shard;
459459
uint32_t n_dims = file.read_u32();
460460
uint32_t name_len = file.read_u32();
461-
uint32_t ftype = file.read_u32();
461+
shard.type = (enum ggml_type) file.read_u32();
462462
shard.ne.resize(n_dims);
463463
file.read_raw(shard.ne.data(), sizeof(shard.ne[0]) * n_dims);
464464
std::string name = file.read_string(name_len);
465465
if (n_dims < 1 || n_dims > 2) {
466466
throw format("llama.cpp: tensor '%s' should not be %u-dimensional", name.c_str(), n_dims);
467467
}
468-
switch (ftype) {
469-
case 0: shard.type = GGML_TYPE_F32; break;
470-
case 1: shard.type = GGML_TYPE_F16; break;
471-
case 2: shard.type = GGML_TYPE_Q4_0; break;
472-
case 3: shard.type = GGML_TYPE_Q4_1; break;
468+
switch (shard.type) {
469+
case GGML_TYPE_F32:
470+
case GGML_TYPE_F16:
471+
case GGML_TYPE_Q4_0:
472+
case GGML_TYPE_Q4_1:
473+
break;
473474
default: {
474-
throw format("unrecognized ftype %u\n", ftype);
475+
throw format("unrecognized tensor type %u\n", shard.type);
475476
}
476477
}
477478

@@ -502,26 +503,26 @@ struct llama_file_loader {
502503
struct llama_file_saver {
503504
llama_file file;
504505
llama_file_loader * any_file_loader;
505-
llama_file_saver(const char * fname, llama_file_loader * any_file_loader, uint32_t new_f16)
506+
llama_file_saver(const char * fname, llama_file_loader * any_file_loader, enum llama_ftype new_ftype)
506507
: file(fname, "wb"), any_file_loader(any_file_loader) {
507508
fprintf(stderr, "llama.cpp: saving model to %s\n", fname);
508509
write_magic();
509-
write_hparams(new_f16);
510+
write_hparams(new_ftype);
510511
write_vocab();
511512
}
512513
void write_magic() {
513514
file.write_u32('ggjt'); // magic
514515
file.write_u32(1); // version
515516
}
516-
void write_hparams(uint32_t new_f16) {
517+
void write_hparams(enum llama_ftype new_ftype) {
517518
const llama_hparams & hparams = any_file_loader->hparams;
518519
file.write_u32(hparams.n_vocab);
519520
file.write_u32(hparams.n_embd);
520521
file.write_u32(hparams.n_mult);
521522
file.write_u32(hparams.n_head);
522523
file.write_u32(hparams.n_layer);
523524
file.write_u32(hparams.n_rot);
524-
file.write_u32(new_f16);
525+
file.write_u32(new_ftype);
525526
}
526527
void write_vocab() {
527528
if (any_file_loader->file_version == LLAMA_FILE_VERSION_GGML) {
@@ -536,17 +537,17 @@ struct llama_file_saver {
536537
}
537538
}
538539
void write_tensor(llama_load_tensor & tensor, enum ggml_type new_type, const void * new_data, size_t new_size) {
539-
uint32_t ftype;
540540
switch (new_type) {
541-
case GGML_TYPE_F32: ftype = 0; break;
542-
case GGML_TYPE_F16: ftype = 1; break;
543-
case GGML_TYPE_Q4_0: ftype = 2; break;
544-
case GGML_TYPE_Q4_1: ftype = 3; break;
541+
case GGML_TYPE_F32:
542+
case GGML_TYPE_F16:
543+
case GGML_TYPE_Q4_0:
544+
case GGML_TYPE_Q4_1:
545+
break;
545546
default: LLAMA_ASSERT(false);
546547
}
547548
file.write_u32((uint32_t) tensor.ne.size());
548549
file.write_u32((uint32_t) tensor.name.size());
549-
file.write_u32(ftype);
550+
file.write_u32(new_type);
550551
file.write_raw(tensor.ne.data(), sizeof(tensor.ne[0]) * tensor.ne.size());
551552
file.write_raw(tensor.name.data(), tensor.name.size());
552553
file.seek(-file.tell() & 31, SEEK_CUR);
@@ -820,6 +821,16 @@ static const char *llama_file_version_name(llama_file_version version) {
820821
}
821822
}
822823

824+
static const char *llama_ftype_name(enum llama_ftype ftype) {
825+
switch (ftype) {
826+
case LLAMA_FTYPE_ALL_F32: return "all F32";
827+
case LLAMA_FTYPE_MOSTLY_F16: return "mostly F16";
828+
case LLAMA_FTYPE_MOSTLY_Q4_0: return "mostly Q4_0";
829+
case LLAMA_FTYPE_MOSTLY_Q4_1: return "mostly Q4_1";
830+
default: LLAMA_ASSERT(false);
831+
}
832+
}
833+
823834
static const char *llama_model_type_name(e_model type) {
824835
switch (type) {
825836
case MODEL_7B: return "7B";
@@ -872,7 +883,7 @@ static void llama_model_load_internal(
872883
fprintf(stderr, "%s: n_head = %u\n", __func__, hparams.n_head);
873884
fprintf(stderr, "%s: n_layer = %u\n", __func__, hparams.n_layer);
874885
fprintf(stderr, "%s: n_rot = %u\n", __func__, hparams.n_rot);
875-
fprintf(stderr, "%s: f16 = %u\n", __func__, hparams.f16);
886+
fprintf(stderr, "%s: ftype = %u (%s)\n", __func__, hparams.ftype, llama_ftype_name(hparams.ftype));
876887
fprintf(stderr, "%s: n_ff = %u\n", __func__, n_ff);
877888
fprintf(stderr, "%s: n_parts = %zu\n", __func__, ml->file_loaders.size());
878889
fprintf(stderr, "%s: model size = %s\n", __func__, llama_model_type_name(model.type));
@@ -1544,17 +1555,17 @@ static llama_vocab::id llama_sample_top_p_top_k(
15441555
// quantization
15451556
//
15461557

1547-
static void llama_model_quantize_internal(const std::string & fname_inp, const std::string & fname_out, int itype) {
1558+
static void llama_model_quantize_internal(const std::string & fname_inp, const std::string & fname_out, enum llama_ftype ftype) {
15481559
ggml_type quantized_type;
1549-
switch (itype) {
1550-
case 2: quantized_type = GGML_TYPE_Q4_0; break;
1551-
case 3: quantized_type = GGML_TYPE_Q4_1; break;
1552-
default: throw format("invalid quantization type %d\n", itype);
1560+
switch (ftype) {
1561+
case LLAMA_FTYPE_MOSTLY_Q4_0: quantized_type = GGML_TYPE_Q4_0; break;
1562+
case LLAMA_FTYPE_MOSTLY_Q4_1: quantized_type = GGML_TYPE_Q4_1; break;
1563+
default: throw format("invalid output file type %d\n", ftype);
15531564
};
15541565

15551566
std::unique_ptr<llama_model_loader> model_loader(new llama_model_loader(fname_inp.c_str(), /*use_mmap*/ false,
15561567
/*vocab_only*/ false));
1557-
llama_file_saver file_saver(fname_out.c_str(), model_loader->file_loaders.at(0).get(), (uint32_t) itype);
1568+
llama_file_saver file_saver(fname_out.c_str(), model_loader->file_loaders.at(0).get(), ftype);
15581569

15591570
size_t total_size_org = 0;
15601571
size_t total_size_new = 0;
@@ -1745,9 +1756,9 @@ void llama_free(struct llama_context * ctx) {
17451756
int llama_model_quantize(
17461757
const char * fname_inp,
17471758
const char * fname_out,
1748-
int itype) {
1759+
enum llama_ftype ftype) {
17491760
try {
1750-
llama_model_quantize_internal(fname_inp, fname_out, itype);
1761+
llama_model_quantize_internal(fname_inp, fname_out, ftype);
17511762
return 0;
17521763
} catch (const std::string & err) {
17531764
fprintf(stderr, "%s: failed to quantize: %s\n", __func__, err.c_str());

llama.h

+9-1
Original file line number | Diff line number | Diff line change
@@ -65,6 +65,14 @@ extern "C" {
6565
void * progress_callback_user_data;
6666
};
6767

68+
// model file types
69+
enum llama_ftype {
70+
LLAMA_FTYPE_ALL_F32 = 0,
71+
LLAMA_FTYPE_MOSTLY_F16 = 1, // except 1d tensors
72+
LLAMA_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors
73+
LLAMA_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors
74+
};
75+
6876
LLAMA_API struct llama_context_params llama_context_default_params();
6977

7078
LLAMA_API bool llama_mmap_supported();
@@ -85,7 +93,7 @@ extern "C" {
8593
LLAMA_API int llama_model_quantize(
8694
const char * fname_inp,
8795
const char * fname_out,
88-
int itype);
96+
enum llama_ftype ftype);
8997

9098
// Returns the KV cache that will contain the context for the
9199
// ongoing prediction with the model.

0 commit comments

Comments (0)