diff --git a/CMakeLists.txt b/CMakeLists.txt index d88c5b101..54d18b003 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -31,7 +31,7 @@ option(GGML_NO_ACCELERATE "ggml: disable Accelerate framework" OFF) # sanitizers if (GGML_SANITIZE_THREAD) - set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fsanitize=thread") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fsanitize=thread") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=thread") endif() diff --git a/examples/gpt-2/CMakeLists.txt b/examples/gpt-2/CMakeLists.txt index 9960cfe81..3b7ab5efe 100644 --- a/examples/gpt-2/CMakeLists.txt +++ b/examples/gpt-2/CMakeLists.txt @@ -4,3 +4,10 @@ set(TEST_TARGET gpt-2) add_executable(${TEST_TARGET} main.cpp) target_link_libraries(${TEST_TARGET} PRIVATE ggml ggml_utils) + +# +# gpt-2-quantize + +set(TEST_TARGET gpt-2-quantize) +add_executable(${TEST_TARGET} quantize.cpp) +target_link_libraries(${TEST_TARGET} PRIVATE ggml ggml_utils) diff --git a/examples/gpt-2/README.md b/examples/gpt-2/README.md index 60fea55dc..ed61f861f 100644 --- a/examples/gpt-2/README.md +++ b/examples/gpt-2/README.md @@ -94,7 +94,7 @@ Done! Model '117M' saved in 'models/gpt-2-117M/' Run the convert-ckpt-to-ggml.py script to convert the model to ggml format. - python /Users/john/ggml/examples/gpt-2/convert-ckpt-to-ggml.py models/gpt-2-117M/ + python /Users/john/ggml/examples/gpt-2/convert-ckpt-to-ggml.py models/gpt-2-117M/ 1 ``` diff --git a/examples/gpt-2/convert-ckpt-to-ggml.py b/examples/gpt-2/convert-ckpt-to-ggml.py index 7ae438013..60cd963d2 100644 --- a/examples/gpt-2/convert-ckpt-to-ggml.py +++ b/examples/gpt-2/convert-ckpt-to-ggml.py @@ -45,8 +45,18 @@ def bytes_to_unicode(): cs = [chr(n) for n in cs] return dict(zip(bs, cs)) -if len(sys.argv) < 2: - print("Usage: convert-ckpt-to-ggml.py dir-model [use-f32]\n") +# helper method to convert a numpy array to different float types +def convert_to_ftype(data, ftype): + # fp16 + if ftype == 1: + return data.astype(np.float16) + + assert False, "Invalid ftype: " + str(ftype) + +if len(sys.argv) < 3: + print("Usage: convert-ckpt-to-ggml.py dir-model ftype\n") + print(" ftype == 0 -> float32") + print(" ftype == 1 -> float16") sys.exit(1) # output in the same directory as the model @@ -59,11 +69,20 @@ def bytes_to_unicode(): with open(dir_model + "/hparams.json", "r") as f: hparams = json.load(f) -# use 16-bit or 32-bit floats -use_f16 = True +# possible data types +# ftype == 0 -> float32 +# ftype == 1 -> float16 +# +# map from ftype to string +ftype_str = ["f32", "f16"] + +ftype = 1 if len(sys.argv) > 2: - use_f16 = False - fname_out = sys.argv[1] + "/ggml-model-f32.bin" + ftype = int(sys.argv[2]) + if ftype < 0 or ftype > 1: + print("Invalid ftype: " + str(ftype)) + sys.exit(1) + fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin" list_vars = tf.train.list_variables(dir_model) @@ -75,7 +94,7 @@ def bytes_to_unicode(): fout.write(struct.pack("i", hparams["n_embd"])) fout.write(struct.pack("i", hparams["n_head"])) fout.write(struct.pack("i", hparams["n_layer"])) -fout.write(struct.pack("i", use_f16)) +fout.write(struct.pack("i", ftype)) byte_encoder = bytes_to_unicode() byte_decoder = {v:k for k, v in byte_encoder.items()} @@ -93,9 +112,22 @@ def bytes_to_unicode(): data = tf.train.load_variable(dir_model, name).squeeze() n_dims = len(data.shape); - # ftype == 0 -> float32, ftype == 1 -> float16 - ftype = 0; - if use_f16: + # for efficiency - transpose the projection matrices + # "model/h.*/attn/c_attn/w" + # "model/h.*/attn/c_proj/w" + # "model/h.*/mlp/c_fc/w" + # "model/h.*/mlp/c_proj/w" + if name[-14:] == "/attn/c_attn/w" or \ + name[-14:] == "/attn/c_proj/w" or \ + name[-11:] == "/mlp/c_fc/w" or \ + name[-13:] == "/mlp/c_proj/w": + print(" Transposing") + data = data.transpose() + + dshape = data.shape + + ftype_cur = 0 + if ftype != 0: # match name: # "model/wte" # "model/h.*/attn/c_attn/w" @@ -103,24 +135,19 @@ def bytes_to_unicode(): # "model/h.*/mlp/c_fc/w" # "model/h.*/mlp/c_proj/w" if name == "model/wte" or name[-2:] == "/w": - print(" Converting to float16") - data = data.astype(np.float16) - ftype = 1 + print(" Converting to " + ftype_str[ftype]) + data = convert_to_ftype(data, ftype) + ftype_cur = ftype else: print(" Converting to float32") data = data.astype(np.float32) - ftype = 0 - - # for efficiency - transpose the projection matrices - if name[-13:] == "/mlp/c_proj/w": - print(" Transposing") - data = data.transpose() + ftype_cur = 0 # header str = name.encode('utf-8') - fout.write(struct.pack("iii", n_dims, len(str), ftype)) + fout.write(struct.pack("iii", n_dims, len(str), ftype_cur)) for i in range(n_dims): - fout.write(struct.pack("i", data.shape[n_dims - 1 - i])) + fout.write(struct.pack("i", dshape[n_dims - 1 - i])) fout.write(str); # data diff --git a/examples/gpt-2/main.cpp b/examples/gpt-2/main.cpp index 5a0ab01fe..f371bc667 100644 --- a/examples/gpt-2/main.cpp +++ b/examples/gpt-2/main.cpp @@ -42,7 +42,7 @@ struct gpt2_layer { struct ggml_tensor * c_mlp_fc_w; struct ggml_tensor * c_mlp_fc_b; - struct ggml_tensor * c_mlp_proj_w_trans; // transposed for efficiency + struct ggml_tensor * c_mlp_proj_w; struct ggml_tensor * c_mlp_proj_b; }; @@ -130,9 +130,23 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab & } } - // for the big tensors, we have the option to store the data in 16-bit floats + // for the big tensors, we have the option to store the data in 16-bit floats or quantized // in order to save memory and also to speed up the computation - const ggml_type wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32; + ggml_type wtype = GGML_TYPE_COUNT; + switch (model.hparams.f16) { + case 0: wtype = GGML_TYPE_F32; break; + case 1: wtype = GGML_TYPE_F16; break; + case 2: wtype = GGML_TYPE_Q4_0; break; + case 3: wtype = GGML_TYPE_Q4_1; break; + default: + { + fprintf(stderr, "%s: invalid model file '%s' (bad f16 value %d)\n", + __func__, fname.c_str(), model.hparams.f16); + return false; + } + } + + const ggml_type wtype2 = GGML_TYPE_F32; auto & ctx = model.ctx; @@ -146,32 +160,32 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab & const int n_ctx = hparams.n_ctx; const int n_vocab = hparams.n_vocab; - ctx_size += n_embd*ggml_type_size(GGML_TYPE_F32); // ln_f_g - ctx_size += n_embd*ggml_type_size(GGML_TYPE_F32); // ln_f_b + ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_g + ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_b - ctx_size += n_vocab*n_embd*ggml_type_size(wtype); // wte - ctx_size += n_ctx*n_embd*ggml_type_size(GGML_TYPE_F32); // wpe + ctx_size += n_vocab*n_embd*ggml_type_sizef(wtype); // wte + ctx_size += n_ctx*n_embd*ggml_type_sizef(GGML_TYPE_F32); // wpe - ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_1_g - ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_1_b + ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_g + ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_b - ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_2_g - ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_2_b + ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_g + ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_b - ctx_size += n_layer*(3*n_embd*n_embd*ggml_type_size(wtype)); // c_attn_attn_w - ctx_size += n_layer*( 3*n_embd*ggml_type_size(GGML_TYPE_F32)); // c_attn_attn_b + ctx_size += n_layer*(3*n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_attn_w + ctx_size += n_layer*( 3*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_attn_b - ctx_size += n_layer*(n_embd*n_embd*ggml_type_size(wtype)); // c_attn_proj_w - ctx_size += n_layer*( n_embd*ggml_type_size(GGML_TYPE_F32)); // c_attn_proj_b + ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_proj_w + ctx_size += n_layer*( n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_proj_b - ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_size(wtype)); // c_mlp_fc_w - ctx_size += n_layer*( 4*n_embd*ggml_type_size(GGML_TYPE_F32)); // c_mlp_fc_b + ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_fc_w + ctx_size += n_layer*( 4*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_fc_b - ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_size(wtype)); // c_mlp_proj_w - ctx_size += n_layer*( n_embd*ggml_type_size(GGML_TYPE_F32)); // c_mlp_proj_b + ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_proj_w + ctx_size += n_layer*( n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_proj_b - ctx_size += n_ctx*n_layer*n_embd*ggml_type_size(GGML_TYPE_F32); // memory_k - ctx_size += n_ctx*n_layer*n_embd*ggml_type_size(GGML_TYPE_F32); // memory_v + ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_k + ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_v ctx_size += (6 + 12*n_layer)*256; // object overhead @@ -219,23 +233,23 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab & for (int i = 0; i < n_layer; ++i) { auto & layer = model.layers[i]; - layer.ln_1_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); - layer.ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); + layer.ln_1_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); + layer.ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); - layer.ln_2_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); - layer.ln_2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); + layer.ln_2_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); + layer.ln_2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); - layer.c_attn_attn_w = ggml_new_tensor_2d(ctx, wtype, 3*n_embd, n_embd); - layer.c_attn_attn_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 3*n_embd); + layer.c_attn_attn_w = ggml_new_tensor_2d(ctx, wtype, n_embd, 3*n_embd); + layer.c_attn_attn_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 3*n_embd); - layer.c_attn_proj_w = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd); - layer.c_attn_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); + layer.c_attn_proj_w = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd); + layer.c_attn_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); - layer.c_mlp_fc_w = ggml_new_tensor_2d(ctx, wtype, 4*n_embd, n_embd); - layer.c_mlp_fc_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_embd); + layer.c_mlp_fc_w = ggml_new_tensor_2d(ctx, wtype, n_embd, 4*n_embd); + layer.c_mlp_fc_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_embd); - layer.c_mlp_proj_w_trans = ggml_new_tensor_2d(ctx, wtype, 4*n_embd, n_embd); - layer.c_mlp_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); + layer.c_mlp_proj_w = ggml_new_tensor_2d(ctx, wtype, 4*n_embd, n_embd); + layer.c_mlp_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); // map by name model.tensors["model/h" + std::to_string(i) + "/ln_1/g"] = layer.ln_1_g; @@ -253,7 +267,7 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab & model.tensors["model/h" + std::to_string(i) + "/mlp/c_fc/w"] = layer.c_mlp_fc_w; model.tensors["model/h" + std::to_string(i) + "/mlp/c_fc/b"] = layer.c_mlp_fc_b; - model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/w"] = layer.c_mlp_proj_w_trans; + model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/w"] = layer.c_mlp_proj_w; model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/b"] = layer.c_mlp_proj_b; } } @@ -321,9 +335,26 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab & return false; } - const size_t bpe = (ftype == 0) ? sizeof(float) : sizeof(ggml_fp16_t); + if (0) { + static const char * ftype_str[] = { "f32", "f16", "q4_0", "q4_1", }; + printf("%24s - [%5d, %5d], type = %6s, %6.2f MB, %9zu bytes\n", name.data(), ne[0], ne[1], ftype_str[ftype], ggml_nbytes(tensor)/1024.0/1024.0, ggml_nbytes(tensor)); + } - if (nelements*bpe != ggml_nbytes(tensor)) { + size_t bpe = 0; + + switch (ftype) { + case 0: bpe = ggml_type_size(GGML_TYPE_F32); break; + case 1: bpe = ggml_type_size(GGML_TYPE_F16); break; + case 2: bpe = ggml_type_size(GGML_TYPE_Q4_0); assert(ne[0] % 64 == 0); break; + case 3: bpe = ggml_type_size(GGML_TYPE_Q4_1); assert(ne[0] % 64 == 0); break; + default: + { + fprintf(stderr, "%s: unknown ftype %d in model file\n", __func__, ftype); + return false; + } + }; + + if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) { fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n", __func__, name.data(), ggml_nbytes(tensor), nelements*bpe); return false; @@ -331,7 +362,6 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab & fin.read(reinterpret_cast(tensor->data), ggml_nbytes(tensor)); - //printf("%24s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0); total_size += ggml_nbytes(tensor); } @@ -433,7 +463,7 @@ bool gpt2_eval( // [2304, N] { cur = ggml_mul_mat(ctx0, - ggml_transpose(ctx0, model.layers[il].c_attn_attn_w), + model.layers[il].c_attn_attn_w, cur); cur = ggml_add(ctx0, @@ -509,11 +539,13 @@ bool gpt2_eval( // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous() // [n_past + N, 64, 12] struct ggml_tensor * V_trans = - ggml_permute(ctx0, - ggml_reshape_3d(ctx0, - ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd), - n_embd/n_head, n_head, n_past + N), - 1, 2, 0, 3); + ggml_cpy(ctx0, + ggml_permute(ctx0, + ggml_reshape_3d(ctx0, + ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd), + n_embd/n_head, n_head, n_past + N), + 1, 2, 0, 3), + ggml_new_tensor_3d(ctx0, model.memory_v->type, n_past + N, n_embd/n_head, n_head)); // KQV = transpose(V) * KQ_soft_max // [64, N, 12] @@ -540,7 +572,7 @@ bool gpt2_eval( // [768, N] { cur = ggml_mul_mat(ctx0, - ggml_transpose(ctx0, model.layers[il].c_attn_proj_w), + model.layers[il].c_attn_proj_w, cur); cur = ggml_add(ctx0, @@ -577,7 +609,7 @@ bool gpt2_eval( // cur = fc_w*cur + fc_b // [3072, N] cur = ggml_mul_mat(ctx0, - ggml_transpose(ctx0, model.layers[il].c_mlp_fc_w), + model.layers[il].c_mlp_fc_w, cur); cur = ggml_add(ctx0, @@ -597,7 +629,7 @@ bool gpt2_eval( // cur = proj_w*cur + proj_b // [768, N] cur = ggml_mul_mat(ctx0, - model.layers[il].c_mlp_proj_w_trans, + model.layers[il].c_mlp_proj_w, cur); cur = ggml_add(ctx0, @@ -714,8 +746,12 @@ int main(int argc, char ** argv) { params.n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int) embd_inp.size()); - printf("%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size()); - printf("\n"); + printf("%s: prompt: '%s'\n", __func__, params.prompt.c_str()); + printf("%s: number of tokens in prompt = %zu, first 8 tokens: ", __func__, embd_inp.size()); + for (int i = 0; i < std::min(8, (int) embd_inp.size()); i++) { + printf("%d ", embd_inp[i]); + } + printf("\n\n"); // submit the input prompt token-by-token // this reduces the memory usage during inference, at the cost of a bit of speed at the beginning diff --git a/examples/gpt-2/quantize.cpp b/examples/gpt-2/quantize.cpp new file mode 100644 index 000000000..3cc48ea39 --- /dev/null +++ b/examples/gpt-2/quantize.cpp @@ -0,0 +1,322 @@ +#include "ggml/ggml.h" + +#include "utils.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// default hparams (GPT-2 117M) +struct gpt2_hparams { + int32_t n_vocab = 50257; + int32_t n_ctx = 1024; + int32_t n_embd = 768; + int32_t n_head = 12; + int32_t n_layer = 12; + int32_t f16 = 1; +}; + +// quantize a model +bool gpt2_model_quantize(const std::string & fname_inp, const std::string & fname_out, int itype) { + ggml_type type = GGML_TYPE_Q4_1; + + switch (itype) { + case 2: type = GGML_TYPE_Q4_0; break; + case 3: type = GGML_TYPE_Q4_1; break; + default: fprintf(stderr, "%s: invalid quantization type %d\n", __func__, itype); return 1; + }; + + if (type != GGML_TYPE_Q4_0 && type != GGML_TYPE_Q4_1) { + fprintf(stderr, "%s: invalid quantization type %d\n", __func__, type); + return false; + } + + gpt_vocab vocab; + + printf("%s: loading model from '%s'\n", __func__, fname_inp.c_str()); + + auto finp = std::ifstream(fname_inp, std::ios::binary); + if (!finp) { + fprintf(stderr, "%s: failed to open '%s' for reading\n", __func__, fname_inp.c_str()); + return false; + } + + auto fout = std::ofstream(fname_out, std::ios::binary); + if (!fout) { + fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname_out.c_str()); + return false; + } + + // verify magic + { + uint32_t magic; + finp.read((char *) &magic, sizeof(magic)); + if (magic != 0x67676d6c) { + fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname_inp.c_str()); + return false; + } + + fout.write((char *) &magic, sizeof(magic)); + } + + gpt2_hparams hparams; + + // load hparams + { + finp.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab)); + finp.read((char *) &hparams.n_ctx, sizeof(hparams.n_ctx)); + finp.read((char *) &hparams.n_embd, sizeof(hparams.n_embd)); + finp.read((char *) &hparams.n_head, sizeof(hparams.n_head)); + finp.read((char *) &hparams.n_layer, sizeof(hparams.n_layer)); + finp.read((char *) &hparams.f16, sizeof(hparams.f16)); + + printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab); + printf("%s: n_ctx = %d\n", __func__, hparams.n_ctx); + printf("%s: n_embd = %d\n", __func__, hparams.n_embd); + printf("%s: n_head = %d\n", __func__, hparams.n_head); + printf("%s: n_layer = %d\n", __func__, hparams.n_layer); + printf("%s: f16 = %d\n", __func__, hparams.f16); + + fout.write((char *) &hparams.n_vocab, sizeof(hparams.n_vocab)); + fout.write((char *) &hparams.n_ctx, sizeof(hparams.n_ctx)); + fout.write((char *) &hparams.n_embd, sizeof(hparams.n_embd)); + fout.write((char *) &hparams.n_head, sizeof(hparams.n_head)); + fout.write((char *) &hparams.n_layer, sizeof(hparams.n_layer)); + fout.write((char *) &itype, sizeof(hparams.f16)); + } + + // load vocab + { + int32_t n_vocab = 0; + finp.read ((char *) &n_vocab, sizeof(n_vocab)); + fout.write((char *) &n_vocab, sizeof(n_vocab)); + + if (n_vocab != hparams.n_vocab) { + fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n", + __func__, fname_inp.c_str(), n_vocab, hparams.n_vocab); + return false; + } + + std::string word; + for (int i = 0; i < n_vocab; i++) { + uint32_t len; + finp.read ((char *) &len, sizeof(len)); + fout.write((char *) &len, sizeof(len)); + + word.resize(len); + finp.read ((char *) word.data(), len); + fout.write((char *) word.data(), len); + + vocab.token_to_id[word] = i; + vocab.id_to_token[i] = word; + } + } + + // load weights + { + size_t total_size_org = 0; + size_t total_size_new = 0; + + std::vector work; + + std::vector data_u8; + std::vector data_f16; + std::vector data_f32; + + std::vector hist_all(1 << 4, 0); + + while (true) { + int32_t n_dims; + int32_t length; + int32_t ftype; + + finp.read(reinterpret_cast(&n_dims), sizeof(n_dims)); + finp.read(reinterpret_cast(&length), sizeof(length)); + finp.read(reinterpret_cast(&ftype), sizeof(ftype)); + + if (finp.eof()) { + break; + } + + int32_t nelements = 1; + int32_t ne[2] = { 1, 1 }; + for (int i = 0; i < n_dims; ++i) { + finp.read (reinterpret_cast(&ne[i]), sizeof(ne[i])); + nelements *= ne[i]; + } + + std::string name(length, 0); + finp.read (&name[0], length); + + { + static const char * ftype_str[] = { "f32", "f16", "q4_0", "q4_1", }; + printf("%24s - [%5d, %5d], type = %6s ", name.data(), ne[0], ne[1], ftype_str[ftype]); + } + + // regexes of tensor names to be quantized + const std::vector k_names = { + "model/wte", + "model/h.*/attn/c_attn/w", + "model/h.*/attn/c_proj/w", + "model/h.*/mlp/c_fc/w", + "model/h.*/mlp/c_proj/w", + }; + + bool quantize = false; + for (const auto & s : k_names) { + if (std::regex_match(name, std::regex(s))) { + quantize = true; + break; + } + } + + if (quantize) { + if (ftype != 0 && ftype != 1) { + fprintf(stderr, "%s: unsupported ftype %d for integer quantization\n", __func__, ftype); + return false; + } + + if (ftype == 1) { + data_f16.resize(nelements); + finp.read(reinterpret_cast(data_f16.data()), nelements * sizeof(ggml_fp16_t)); + data_f32.resize(nelements); + for (int i = 0; i < nelements; ++i) { + data_f32[i] = ggml_fp16_to_fp32(data_f16[i]); + } + } else { + data_f32.resize(nelements); + finp.read(reinterpret_cast(data_f32.data()), nelements * sizeof(float)); + } + + ftype = itype; + } else { + const int bpe = (ftype == 0) ? sizeof(float) : sizeof(uint16_t); + + data_u8.resize(nelements*bpe); + finp.read(reinterpret_cast(data_u8.data()), nelements * bpe); + } + + fout.write(reinterpret_cast(&n_dims), sizeof(n_dims)); + fout.write(reinterpret_cast(&length), sizeof(length)); + fout.write(reinterpret_cast(&ftype), sizeof(ftype)); + for (int i = 0; i < n_dims; ++i) { + fout.write(reinterpret_cast(&ne[i]), sizeof(ne[i])); + } + fout.write(&name[0], length); + + if (quantize) { + printf("quantizing .. "); + work.resize(nelements); // for quantization + + size_t cur_size = 0; + std::vector hist_cur(1 << 4, 0); + + switch (type) { + case GGML_TYPE_Q4_0: + { + cur_size = ggml_quantize_q4_0(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data()); + } break; + case GGML_TYPE_Q4_1: + { + cur_size = ggml_quantize_q4_1(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data()); + } break; + default: + { + fprintf(stderr, "%s: unsupported quantization type %d\n", __func__, type); + return false; + } + } + + fout.write(reinterpret_cast(work.data()), cur_size); + total_size_new += cur_size; + + printf("size = %8.2f MB -> %8.2f MB | hist: ", nelements * sizeof(float)/1024.0/1024.0, cur_size/1024.0/1024.0); + for (int i = 0; i < hist_cur.size(); ++i) { + hist_all[i] += hist_cur[i]; + } + + for (int i = 0; i < hist_cur.size(); ++i) { + printf("%5.3f ", hist_cur[i] / (float)nelements); + } + printf("\n"); + } else { + printf("size = %8.3f MB\n", data_u8.size()/1024.0/1024.0); + fout.write(reinterpret_cast(data_u8.data()), data_u8.size()); + total_size_new += data_u8.size(); + } + + total_size_org += nelements * sizeof(float); + } + + printf("%s: model size = %8.2f MB\n", __func__, total_size_org/1024.0/1024.0); + printf("%s: quant size = %8.2f MB\n", __func__, total_size_new/1024.0/1024.0); + + { + int64_t sum_all = 0; + for (int i = 0; i < hist_all.size(); ++i) { + sum_all += hist_all[i]; + } + + printf("%s: hist: ", __func__); + for (int i = 0; i < hist_all.size(); ++i) { + printf("%5.3f ", hist_all[i] / (float)sum_all); + } + printf("\n"); + } + } + + finp.close(); + fout.close(); + + return true; +} + +// usage: +// ./gpt-2-quantize models/gpt-2-117M/ggml-model.bin models/gpt-2-117M/ggml-model-quant.bin type +// +int main(int argc, char ** argv) { + if (argc != 4) { + fprintf(stderr, "usage: %s model-f32.bin model-quant.bin type\n", argv[0]); + fprintf(stderr, " type = 2 - q4_0\n"); + fprintf(stderr, " type = 3 - q4_1\n"); + return 1; + } + + const std::string fname_inp = argv[1]; + const std::string fname_out = argv[2]; + + const int itype = atoi(argv[3]); + + const int64_t t_main_start_us = ggml_time_us(); + + int64_t t_quantize_us = 0; + + // load the model + { + const int64_t t_start_us = ggml_time_us(); + + if (!gpt2_model_quantize(fname_inp, fname_out, itype)) { + fprintf(stderr, "%s: failed to quantize model from '%s'\n", __func__, fname_inp.c_str()); + return 1; + } + + t_quantize_us = ggml_time_us() - t_start_us; + } + + // report timing + { + const int64_t t_main_end_us = ggml_time_us(); + + printf("\n"); + printf("%s: quantize time = %8.2f ms\n", __func__, t_quantize_us/1000.0f); + printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f); + } + + return 0; +} diff --git a/examples/gpt-j/CMakeLists.txt b/examples/gpt-j/CMakeLists.txt index 4199a3fae..390746d50 100644 --- a/examples/gpt-j/CMakeLists.txt +++ b/examples/gpt-j/CMakeLists.txt @@ -4,3 +4,10 @@ set(TEST_TARGET gpt-j) add_executable(${TEST_TARGET} main.cpp) target_link_libraries(${TEST_TARGET} PRIVATE ggml ggml_utils) + +# +# gpt-j-quantize + +set(TEST_TARGET gpt-j-quantize) +add_executable(${TEST_TARGET} quantize.cpp) +target_link_libraries(${TEST_TARGET} PRIVATE ggml ggml_utils) diff --git a/examples/gpt-j/convert-h5-to-ggml.py b/examples/gpt-j/convert-h5-to-ggml.py index 310e60e0a..e254f2cc3 100644 --- a/examples/gpt-j/convert-h5-to-ggml.py +++ b/examples/gpt-j/convert-h5-to-ggml.py @@ -47,8 +47,10 @@ def bytes_to_unicode(): cs = [chr(n) for n in cs] return dict(zip(bs, cs)) -if len(sys.argv) < 2: +if len(sys.argv) < 3: print("Usage: convert-h5-to-ggml.py dir-model [use-f32]\n") + print(" ftype == 0 -> float32") + print(" ftype == 1 -> float16") sys.exit(1) # output in the same directory as the model @@ -64,11 +66,21 @@ def bytes_to_unicode(): with open(dir_model + "/config.json", "r") as f: hparams = json.load(f) -# use 16-bit or 32-bit floats -use_f16 = True +# possible data types +# ftype == 0 -> float32 +# ftype == 1 -> float16 +# +# map from ftype to string +ftype_str = ["f32", "f16"] + +ftype = 1 if len(sys.argv) > 2: - use_f16 = False - fname_out = sys.argv[1] + "/ggml-model-f32.bin" + ftype = int(sys.argv[2]) + if ftype < 0 or ftype > 1: + print("Invalid ftype: " + str(ftype)) + sys.exit(1) + fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin" + model = GPTJForCausalLM.from_pretrained(dir_model, low_cpu_mem_usage=True) #print (model) @@ -85,7 +97,7 @@ def bytes_to_unicode(): fout.write(struct.pack("i", hparams["n_head"])) fout.write(struct.pack("i", hparams["n_layer"])) fout.write(struct.pack("i", hparams["rotary_dim"])) -fout.write(struct.pack("i", use_f16)) +fout.write(struct.pack("i", ftype)) byte_encoder = bytes_to_unicode() byte_decoder = {v:k for k, v in byte_encoder.items()} @@ -114,34 +126,40 @@ def bytes_to_unicode(): n_dims = len(data.shape); # ftype == 0 -> float32, ftype == 1 -> float16 - ftype = 0; - if use_f16: + ftype_cur = 0; + if ftype != 0: if name[-7:] == ".weight" and n_dims == 2: print(" Converting to float16") data = data.astype(np.float16) - ftype = 1 + ftype_cur = 1 else: print(" Converting to float32") data = data.astype(np.float32) - ftype = 0 + ftype_cur = 0 + else: + if data.dtype != np.float32: + print(" Converting to float32") + data = data.astype(np.float32) + ftype_cur = 0 # for efficiency - transpose these matrices: - # "transformer.h.*.mlp.fc_in.weight - # "transformer.h.*.attn.out_proj.weight + # (note - with latest ggml this is no longer more efficient, so disabling it) + # "transformer.h.*.mlp.fc_in.weight" + # "transformer.h.*.attn.out_proj.weight" # "transformer.h.*.attn.q_proj.weight" # "transformer.h.*.attn.k_proj.weight" # "transformer.h.*.attn.v_proj.weight" - if name.endswith(".mlp.fc_in.weight") or \ - name.endswith(".attn.out_proj.weight") or \ - name.endswith(".attn.q_proj.weight") or \ - name.endswith(".attn.k_proj.weight") or \ - name.endswith(".attn.v_proj.weight"): - print(" Transposing") - data = data.transpose() + #if name.endswith(".mlp.fc_in.weight") or \ + # name.endswith(".attn.out_proj.weight") or \ + # name.endswith(".attn.q_proj.weight") or \ + # name.endswith(".attn.k_proj.weight") or \ + # name.endswith(".attn.v_proj.weight"): + # print(" Transposing") + # data = data.transpose() # header str = name.encode('utf-8') - fout.write(struct.pack("iii", n_dims, len(str), ftype)) + fout.write(struct.pack("iii", n_dims, len(str), ftype_cur)) for i in range(n_dims): fout.write(struct.pack("i", data.shape[n_dims - 1 - i])) fout.write(str); diff --git a/examples/gpt-j/main.cpp b/examples/gpt-j/main.cpp index 5c6ce93a0..c059e1b72 100644 --- a/examples/gpt-j/main.cpp +++ b/examples/gpt-j/main.cpp @@ -40,7 +40,7 @@ struct gptj_layer { struct ggml_tensor * c_mlp_fc_w; struct ggml_tensor * c_mlp_fc_b; - struct ggml_tensor * c_mlp_proj_w_trans; + struct ggml_tensor * c_mlp_proj_w; struct ggml_tensor * c_mlp_proj_b; }; @@ -132,9 +132,23 @@ bool gptj_model_load(const std::string & fname, gptj_model & model, gpt_vocab & } } - // for the big tensors, we have the option to store the data in 16-bit floats + // for the big tensors, we have the option to store the data in 16-bit floats or quantized // in order to save memory and also to speed up the computation - const ggml_type wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32; + ggml_type wtype = GGML_TYPE_COUNT; + switch (model.hparams.f16) { + case 0: wtype = GGML_TYPE_F32; break; + case 1: wtype = GGML_TYPE_F16; break; + case 2: wtype = GGML_TYPE_Q4_0; break; + case 3: wtype = GGML_TYPE_Q4_1; break; + default: + { + fprintf(stderr, "%s: invalid model file '%s' (bad f16 value %d)\n", + __func__, fname.c_str(), model.hparams.f16); + return false; + } + } + + const ggml_type wtype2 = GGML_TYPE_F32; auto & ctx = model.ctx; @@ -148,31 +162,31 @@ bool gptj_model_load(const std::string & fname, gptj_model & model, gpt_vocab & const int n_ctx = hparams.n_ctx; const int n_vocab = hparams.n_vocab; - ctx_size += n_embd*ggml_type_size(GGML_TYPE_F32); // ln_f_g - ctx_size += n_embd*ggml_type_size(GGML_TYPE_F32); // ln_f_b + ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_g + ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_b - ctx_size += n_embd*n_vocab*ggml_type_size(wtype); // wte + ctx_size += n_embd*n_vocab*ggml_type_sizef(wtype); // wte - ctx_size += n_embd*n_vocab*ggml_type_size(wtype); // lmh_g - ctx_size += n_vocab*ggml_type_size(GGML_TYPE_F32); // lmh_b + ctx_size += n_embd*n_vocab*ggml_type_sizef(wtype); // lmh_g + ctx_size += n_vocab*ggml_type_sizef(GGML_TYPE_F32); // lmh_b - ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_1_g - ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_1_b + ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_g + ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_b - ctx_size += n_layer*(n_embd*n_embd*ggml_type_size(wtype)); // c_attn_q_proj_w - ctx_size += n_layer*(n_embd*n_embd*ggml_type_size(wtype)); // c_attn_k_proj_w - ctx_size += n_layer*(n_embd*n_embd*ggml_type_size(wtype)); // c_attn_v_proj_w + ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_q_proj_w + ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_k_proj_w + ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_v_proj_w - ctx_size += n_layer*(n_embd*n_embd*ggml_type_size(wtype)); // c_attn_proj_w + ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_proj_w - ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_size(wtype)); // c_mlp_fc_w - ctx_size += n_layer*( 4*n_embd*ggml_type_size(GGML_TYPE_F32)); // c_mlp_fc_b + ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_fc_w + ctx_size += n_layer*( 4*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_fc_b - ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_size(wtype)); // c_mlp_proj_w_trans - ctx_size += n_layer*( n_embd*ggml_type_size(GGML_TYPE_F32)); // c_mlp_proj_b + ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_proj_w + ctx_size += n_layer*( n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_proj_b - ctx_size += n_ctx*n_layer*n_embd*ggml_type_size(GGML_TYPE_F32); // memory_k - ctx_size += n_ctx*n_layer*n_embd*ggml_type_size(GGML_TYPE_F32); // memory_v + ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_k + ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_v ctx_size += (5 + 10*n_layer)*256; // object overhead @@ -224,20 +238,20 @@ bool gptj_model_load(const std::string & fname, gptj_model & model, gpt_vocab & for (int i = 0; i < n_layer; ++i) { auto & layer = model.layers[i]; - layer.ln_1_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); - layer.ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); + layer.ln_1_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); + layer.ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); - layer.c_attn_q_proj_w = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd); - layer.c_attn_k_proj_w = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd); - layer.c_attn_v_proj_w = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd); + layer.c_attn_q_proj_w = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd); + layer.c_attn_k_proj_w = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd); + layer.c_attn_v_proj_w = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd); - layer.c_attn_proj_w = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd); + layer.c_attn_proj_w = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd); - layer.c_mlp_fc_w = ggml_new_tensor_2d(ctx, wtype, 4*n_embd, n_embd); - layer.c_mlp_fc_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_embd); + layer.c_mlp_fc_w = ggml_new_tensor_2d(ctx, wtype, n_embd, 4*n_embd); + layer.c_mlp_fc_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_embd); - layer.c_mlp_proj_w_trans = ggml_new_tensor_2d(ctx, wtype, 4*n_embd, n_embd); - layer.c_mlp_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); + layer.c_mlp_proj_w = ggml_new_tensor_2d(ctx, wtype, 4*n_embd, n_embd); + layer.c_mlp_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); // map by name model.tensors["transformer.h." + std::to_string(i) + ".ln_1.weight"] = layer.ln_1_g; @@ -252,7 +266,7 @@ bool gptj_model_load(const std::string & fname, gptj_model & model, gpt_vocab & model.tensors["transformer.h." + std::to_string(i) + ".mlp.fc_in.weight"] = layer.c_mlp_fc_w; model.tensors["transformer.h." + std::to_string(i) + ".mlp.fc_in.bias"] = layer.c_mlp_fc_b; - model.tensors["transformer.h." + std::to_string(i) + ".mlp.fc_out.weight"] = layer.c_mlp_proj_w_trans; + model.tensors["transformer.h." + std::to_string(i) + ".mlp.fc_out.weight"] = layer.c_mlp_proj_w; model.tensors["transformer.h." + std::to_string(i) + ".mlp.fc_out.bias"] = layer.c_mlp_proj_b; } } @@ -323,9 +337,26 @@ bool gptj_model_load(const std::string & fname, gptj_model & model, gpt_vocab & return false; } - const size_t bpe = tensor->type == GGML_TYPE_I8 ? 1 : (ftype == 0) ? sizeof(float) : sizeof(ggml_fp16_t); + if (0) { + static const char * ftype_str[] = { "f32", "f16", "q4_0", "q4_1", }; + printf("%24s - [%5d, %5d], type = %6s, %6.2f MB, %9zu bytes\n", name.data(), ne[0], ne[1], ftype_str[ftype], ggml_nbytes(tensor)/1024.0/1024.0, ggml_nbytes(tensor)); + } - if (nelements*bpe != ggml_nbytes(tensor)) { + size_t bpe = 0; + + switch (ftype) { + case 0: bpe = ggml_type_size(GGML_TYPE_F32); break; + case 1: bpe = ggml_type_size(GGML_TYPE_F16); break; + case 2: bpe = ggml_type_size(GGML_TYPE_Q4_0); assert(ne[0] % 64 == 0); break; + case 3: bpe = ggml_type_size(GGML_TYPE_Q4_1); assert(ne[0] % 64 == 0); break; + default: + { + fprintf(stderr, "%s: unknown ftype %d in model file\n", __func__, ftype); + return false; + } + }; + + if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) { fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n", __func__, name.data(), ggml_nbytes(tensor), nelements*bpe); return false; @@ -430,9 +461,9 @@ bool gptj_eval( // self-attention { - struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, ggml_transpose(ctx0, model.layers[il].c_attn_q_proj_w), cur); - struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, ggml_transpose(ctx0, model.layers[il].c_attn_k_proj_w), cur); - struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, ggml_transpose(ctx0, model.layers[il].c_attn_v_proj_w), cur); + struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].c_attn_q_proj_w, cur); + struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].c_attn_k_proj_w, cur); + struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].c_attn_v_proj_w, cur); // store key and value to memory if (N >= 1) { @@ -481,11 +512,13 @@ bool gptj_eval( // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous() struct ggml_tensor * V_trans = - ggml_permute(ctx0, - ggml_reshape_3d(ctx0, - ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd), - n_embd/n_head, n_head, n_past + N), - 1, 2, 0, 3); + ggml_cpy(ctx0, + ggml_permute(ctx0, + ggml_reshape_3d(ctx0, + ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd), + n_embd/n_head, n_head, n_past + N), + 1, 2, 0, 3), + ggml_new_tensor_3d(ctx0, model.memory_v->type, n_past + N, n_embd/n_head, n_head)); // KQV = transpose(V) * KQ_soft_max struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max); @@ -500,7 +533,7 @@ bool gptj_eval( // projection (no bias) cur = ggml_mul_mat(ctx0, - ggml_transpose(ctx0, model.layers[il].c_attn_proj_w), + model.layers[il].c_attn_proj_w, cur); } @@ -511,7 +544,7 @@ bool gptj_eval( { // note here we pass inpSA instead of cur cur = ggml_mul_mat(ctx0, - ggml_transpose(ctx0, model.layers[il].c_mlp_fc_w), + model.layers[il].c_mlp_fc_w, inpSA); cur = ggml_add(ctx0, @@ -524,7 +557,7 @@ bool gptj_eval( // projection // cur = proj_w*cur + proj_b cur = ggml_mul_mat(ctx0, - model.layers[il].c_mlp_proj_w_trans, + model.layers[il].c_mlp_proj_w, cur); cur = ggml_add(ctx0, diff --git a/examples/gpt-j/quantize.cpp b/examples/gpt-j/quantize.cpp new file mode 100644 index 000000000..29b92cfe4 --- /dev/null +++ b/examples/gpt-j/quantize.cpp @@ -0,0 +1,324 @@ +#include "ggml/ggml.h" + +#include "utils.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// default hparams (GPT-J 6B) +struct gptj_hparams { + int32_t n_vocab = 50400; + int32_t n_ctx = 2048; + int32_t n_embd = 4096; + int32_t n_head = 16; + int32_t n_layer = 28; + int32_t n_rot = 64; + int32_t f16 = 1; +}; + +// quantize a model +bool gptj_model_quantize(const std::string & fname_inp, const std::string & fname_out, int itype) { + ggml_type type = GGML_TYPE_Q4_1; + + switch (itype) { + case 2: type = GGML_TYPE_Q4_0; break; + case 3: type = GGML_TYPE_Q4_1; break; + default: fprintf(stderr, "%s: invalid quantization type %d\n", __func__, itype); return 1; + }; + + if (type != GGML_TYPE_Q4_0 && type != GGML_TYPE_Q4_1) { + fprintf(stderr, "%s: invalid quantization type %d\n", __func__, type); + return false; + } + + gpt_vocab vocab; + + printf("%s: loading model from '%s'\n", __func__, fname_inp.c_str()); + + auto finp = std::ifstream(fname_inp, std::ios::binary); + if (!finp) { + fprintf(stderr, "%s: failed to open '%s' for reading\n", __func__, fname_inp.c_str()); + return false; + } + + auto fout = std::ofstream(fname_out, std::ios::binary); + if (!fout) { + fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname_out.c_str()); + return false; + } + + // verify magic + { + uint32_t magic; + finp.read((char *) &magic, sizeof(magic)); + if (magic != 0x67676d6c) { + fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname_inp.c_str()); + return false; + } + + fout.write((char *) &magic, sizeof(magic)); + } + + gptj_hparams hparams; + + // load hparams + { + finp.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab)); + finp.read((char *) &hparams.n_ctx, sizeof(hparams.n_ctx)); + finp.read((char *) &hparams.n_embd, sizeof(hparams.n_embd)); + finp.read((char *) &hparams.n_head, sizeof(hparams.n_head)); + finp.read((char *) &hparams.n_layer, sizeof(hparams.n_layer)); + finp.read((char *) &hparams.n_rot, sizeof(hparams.n_rot)); + finp.read((char *) &hparams.f16, sizeof(hparams.f16)); + + printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab); + printf("%s: n_ctx = %d\n", __func__, hparams.n_ctx); + printf("%s: n_embd = %d\n", __func__, hparams.n_embd); + printf("%s: n_head = %d\n", __func__, hparams.n_head); + printf("%s: n_layer = %d\n", __func__, hparams.n_layer); + printf("%s: f16 = %d\n", __func__, hparams.f16); + + fout.write((char *) &hparams.n_vocab, sizeof(hparams.n_vocab)); + fout.write((char *) &hparams.n_ctx, sizeof(hparams.n_ctx)); + fout.write((char *) &hparams.n_embd, sizeof(hparams.n_embd)); + fout.write((char *) &hparams.n_head, sizeof(hparams.n_head)); + fout.write((char *) &hparams.n_layer, sizeof(hparams.n_layer)); + fout.write((char *) &hparams.n_rot, sizeof(hparams.n_rot)); + fout.write((char *) &itype, sizeof(hparams.f16)); + } + + // load vocab + { + int32_t n_vocab = 0; + finp.read ((char *) &n_vocab, sizeof(n_vocab)); + fout.write((char *) &n_vocab, sizeof(n_vocab)); + + if (n_vocab != hparams.n_vocab) { + fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n", + __func__, fname_inp.c_str(), n_vocab, hparams.n_vocab); + return false; + } + + std::string word; + for (int i = 0; i < n_vocab; i++) { + uint32_t len; + finp.read ((char *) &len, sizeof(len)); + fout.write((char *) &len, sizeof(len)); + + word.resize(len); + finp.read ((char *) word.data(), len); + fout.write((char *) word.data(), len); + + vocab.token_to_id[word] = i; + vocab.id_to_token[i] = word; + } + } + + // load weights + { + size_t total_size_org = 0; + size_t total_size_new = 0; + + std::vector work; + + std::vector data_u8; + std::vector data_f16; + std::vector data_f32; + + std::vector hist_all(1 << 4, 0); + + while (true) { + int32_t n_dims; + int32_t length; + int32_t ftype; + + finp.read(reinterpret_cast(&n_dims), sizeof(n_dims)); + finp.read(reinterpret_cast(&length), sizeof(length)); + finp.read(reinterpret_cast(&ftype), sizeof(ftype)); + + if (finp.eof()) { + break; + } + + int32_t nelements = 1; + int32_t ne[2] = { 1, 1 }; + for (int i = 0; i < n_dims; ++i) { + finp.read (reinterpret_cast(&ne[i]), sizeof(ne[i])); + nelements *= ne[i]; + } + + std::string name(length, 0); + finp.read (&name[0], length); + + { + static const char * ftype_str[] = { "f32", "f16", "q4_0", "q4_1", }; + printf("%48s - [%5d, %5d], type = %6s ", name.data(), ne[0], ne[1], ftype_str[ftype]); + } + + // regexes of tensor names to be quantized + const std::vector k_names = { + ".*weight", + }; + + bool quantize = false; + for (const auto & s : k_names) { + if (std::regex_match(name, std::regex(s))) { + quantize = true; + break; + } + } + + // quantize only 2D tensors + quantize &= (n_dims == 2); + + if (quantize) { + if (ftype != 0 && ftype != 1) { + fprintf(stderr, "%s: unsupported ftype %d for integer quantization\n", __func__, ftype); + return false; + } + + if (ftype == 1) { + data_f16.resize(nelements); + finp.read(reinterpret_cast(data_f16.data()), nelements * sizeof(ggml_fp16_t)); + data_f32.resize(nelements); + for (int i = 0; i < nelements; ++i) { + data_f32[i] = ggml_fp16_to_fp32(data_f16[i]); + } + } else { + data_f32.resize(nelements); + finp.read(reinterpret_cast(data_f32.data()), nelements * sizeof(float)); + } + + ftype = itype; + } else { + const int bpe = (ftype == 0) ? sizeof(float) : sizeof(uint16_t); + + data_u8.resize(nelements*bpe); + finp.read(reinterpret_cast(data_u8.data()), nelements * bpe); + } + + fout.write(reinterpret_cast(&n_dims), sizeof(n_dims)); + fout.write(reinterpret_cast(&length), sizeof(length)); + fout.write(reinterpret_cast(&ftype), sizeof(ftype)); + for (int i = 0; i < n_dims; ++i) { + fout.write(reinterpret_cast(&ne[i]), sizeof(ne[i])); + } + fout.write(&name[0], length); + + if (quantize) { + printf("quantizing .. "); + work.resize(nelements); // for quantization + + size_t cur_size = 0; + std::vector hist_cur(1 << 4, 0); + + switch (type) { + case GGML_TYPE_Q4_0: + { + cur_size = ggml_quantize_q4_0(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data()); + } break; + case GGML_TYPE_Q4_1: + { + cur_size = ggml_quantize_q4_1(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data()); + } break; + default: + { + fprintf(stderr, "%s: unsupported quantization type %d\n", __func__, type); + return false; + } + } + + fout.write(reinterpret_cast(work.data()), cur_size); + total_size_new += cur_size; + + printf("size = %8.2f MB -> %8.2f MB | hist: ", nelements * sizeof(float)/1024.0/1024.0, cur_size/1024.0/1024.0); + for (int i = 0; i < hist_cur.size(); ++i) { + hist_all[i] += hist_cur[i]; + } + + for (int i = 0; i < hist_cur.size(); ++i) { + printf("%5.3f ", hist_cur[i] / (float)nelements); + } + printf("\n"); + } else { + printf("size = %8.3f MB\n", data_u8.size()/1024.0/1024.0); + fout.write(reinterpret_cast(data_u8.data()), data_u8.size()); + total_size_new += data_u8.size(); + } + + total_size_org += nelements * sizeof(float); + } + + printf("%s: model size = %8.2f MB\n", __func__, total_size_org/1024.0/1024.0); + printf("%s: quant size = %8.2f MB\n", __func__, total_size_new/1024.0/1024.0); + + { + int64_t sum_all = 0; + for (int i = 0; i < hist_all.size(); ++i) { + sum_all += hist_all[i]; + } + + printf("%s: hist: ", __func__); + for (int i = 0; i < hist_all.size(); ++i) { + printf("%5.3f ", hist_all[i] / (float)sum_all); + } + printf("\n"); + } + } + + finp.close(); + fout.close(); + + return true; +} + +// usage: +// ./gpt-2-quantize models/gpt-2-117M/ggml-model.bin models/gpt-2-117M/ggml-model-quant.bin type +// +int main(int argc, char ** argv) { + if (argc != 4) { + fprintf(stderr, "usage: %s model-f32.bin model-quant.bin type\n", argv[0]); + fprintf(stderr, " type = 2 - q4_0\n"); + fprintf(stderr, " type = 3 - q4_1\n"); + return 1; + } + + const std::string fname_inp = argv[1]; + const std::string fname_out = argv[2]; + + const int itype = atoi(argv[3]); + + const int64_t t_main_start_us = ggml_time_us(); + + int64_t t_quantize_us = 0; + + // load the model + { + const int64_t t_start_us = ggml_time_us(); + + if (!gptj_model_quantize(fname_inp, fname_out, itype)) { + fprintf(stderr, "%s: failed to quantize model from '%s'\n", __func__, fname_inp.c_str()); + return 1; + } + + t_quantize_us = ggml_time_us() - t_start_us; + } + + // report timing + { + const int64_t t_main_end_us = ggml_time_us(); + + printf("\n"); + printf("%s: quantize time = %8.2f ms\n", __func__, t_quantize_us/1000.0f); + printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f); + } + + return 0; +} diff --git a/examples/utils.h b/examples/utils.h index d091d3db1..b61173ffd 100644 --- a/examples/utils.h +++ b/examples/utils.h @@ -20,7 +20,7 @@ struct gpt_params { // sampling parameters int32_t top_k = 40; float top_p = 0.9f; - float temp = 1.0f; + float temp = 0.9f; int32_t n_batch = 8; // batch size for prompt processing @@ -81,4 +81,3 @@ gpt_vocab::id gpt_sample_top_k_top_p( double top_p, double temp, std::mt19937 & rng); - diff --git a/examples/whisper/CMakeLists.txt b/examples/whisper/CMakeLists.txt index c8fa83a83..c7f5ff54e 100644 --- a/examples/whisper/CMakeLists.txt +++ b/examples/whisper/CMakeLists.txt @@ -13,3 +13,10 @@ set(TEST_TARGET whisper) add_executable(${TEST_TARGET} main.cpp common.cpp) target_link_libraries(${TEST_TARGET} PRIVATE whisper-cpp) target_include_directories(${TEST_TARGET} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/..) + +# +# whisper-quantize + +set(TEST_TARGET whisper-quantize) +add_executable(${TEST_TARGET} quantize.cpp) +target_link_libraries(${TEST_TARGET} PRIVATE ggml ggml_utils) diff --git a/examples/whisper/convert-pt-to-ggml.py b/examples/whisper/convert-pt-to-ggml.py index 9e9b2dceb..749f99c88 100644 --- a/examples/whisper/convert-pt-to-ggml.py +++ b/examples/whisper/convert-pt-to-ggml.py @@ -303,8 +303,9 @@ def bytes_to_unicode(): data = data.astype(np.float32) ftype = 0 else: - data = data.astype(np.float32) - ftype = 0 + if n_dims < 3 and data.dtype != np.float32: + data = data.astype(np.float32) + ftype = 0 #if name.startswith("encoder"): # if name.endswith("mlp.0.weight") or \ diff --git a/examples/whisper/main.cpp b/examples/whisper/main.cpp index b8366b79f..dd30ba4c4 100644 --- a/examples/whisper/main.cpp +++ b/examples/whisper/main.cpp @@ -73,6 +73,7 @@ struct whisper_params { bool output_srt = false; bool output_wts = false; bool output_csv = false; + bool output_jsn = false; bool print_special = false; bool print_colors = false; bool print_progress = false; @@ -80,6 +81,7 @@ struct whisper_params { std::string language = "en"; std::string prompt; + std::string font_path = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf"; std::string model = "models/ggml-base.en.bin"; std::vector fname_inp = {}; @@ -127,7 +129,9 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; } else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; } else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; } + else if (arg == "-fp" || arg == "--font-path") { params.font_path = argv[++i]; } else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; } + else if (arg == "-oj" || arg == "--output-json") { params.output_jsn = true; } else if (arg == "-of" || arg == "--output-file") { params.fname_out.emplace_back(argv[++i]); } else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; } @@ -174,7 +178,9 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false"); fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false"); fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false"); + fprintf(stderr, " -fp, --font-path [%-7s] path to a monospace font for karaoke video\n", params.font_path.c_str()); fprintf(stderr, " -ocsv, --output-csv [%-7s] output result in a CSV file\n", params.output_csv ? "true" : "false"); + fprintf(stderr, " -oj, --output-json [%-7s] output result in a JSON file\n", params.output_jsn ? "true" : "false"); fprintf(stderr, " -of FNAME, --output-file FNAME [%-7s] output file path (without file extension)\n", ""); fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false"); @@ -193,7 +199,7 @@ struct whisper_print_user_data { const std::vector> * pcmf32s; }; -void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, void * user_data) { +void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper_state * /*state*/, int n_new, void * user_data) { const auto & params = *((whisper_print_user_data *) user_data)->params; const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s; @@ -352,28 +358,157 @@ bool output_csv(struct whisper_context * ctx, const char * fname) { fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname); const int n_segments = whisper_full_n_segments(ctx); + fout << "start,end,text\n"; for (int i = 0; i < n_segments; ++i) { const char * text = whisper_full_get_segment_text(ctx, i); const int64_t t0 = whisper_full_get_segment_t0(ctx, i); const int64_t t1 = whisper_full_get_segment_t1(ctx, i); //need to multiply times returned from whisper_full_get_segment_t{0,1}() by 10 to get milliseconds. - fout << 10 * t0 << ", " << 10 * t1 << ", \"" << text << "\"\n"; + fout << 10 * t0 << "," << 10 * t1 << ",\"" << text << "\"\n"; } return true; } +bool output_json(struct whisper_context * ctx, const char * fname, const whisper_params & params) { + std::ofstream fout(fname); + int indent = 0; + + auto doindent = [&]() { + for (int i = 0; i < indent; i++) fout << "\t"; + }; + + auto start_arr = [&](const char *name) { + doindent(); + fout << "\"" << name << "\": [\n"; + indent++; + }; + + auto end_arr = [&](bool end = false) { + indent--; + doindent(); + fout << (end ? "]\n" : "},\n"); + }; + + auto start_obj = [&](const char *name = nullptr) { + doindent(); + if (name) { + fout << "\"" << name << "\": {\n"; + } else { + fout << "{\n"; + } + indent++; + }; + + auto end_obj = [&](bool end = false) { + indent--; + doindent(); + fout << (end ? "}\n" : "},\n"); + }; + + auto start_value = [&](const char *name) { + doindent(); + fout << "\"" << name << "\": "; + }; + + auto value_s = [&](const char *name, const char *val, bool end = false) { + start_value(name); + fout << "\"" << val << (end ? "\"\n" : "\",\n"); + }; + + auto end_value = [&](bool end = false) { + fout << (end ? "\n" : ",\n"); + }; + + auto value_i = [&](const char *name, const int64_t val, bool end = false) { + start_value(name); + fout << val; + end_value(end); + }; + + auto value_b = [&](const char *name, const bool val, bool end = false) { + start_value(name); + fout << (val ? "true" : "false"); + end_value(end); + }; + + if (!fout.is_open()) { + fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname); + return false; + } + + fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname); + start_obj(); + value_s("systeminfo", whisper_print_system_info()); + start_obj("model"); + value_s("type", whisper_model_type_readable(ctx)); + value_b("multilingual", whisper_is_multilingual(ctx)); + value_i("vocab", whisper_model_n_vocab(ctx)); + start_obj("audio"); + value_i("ctx", whisper_model_n_audio_ctx(ctx)); + value_i("state", whisper_model_n_audio_state(ctx)); + value_i("head", whisper_model_n_audio_head(ctx)); + value_i("layer", whisper_model_n_audio_layer(ctx), true); + end_obj(); + start_obj("text"); + value_i("ctx", whisper_model_n_text_ctx(ctx)); + value_i("state", whisper_model_n_text_state(ctx)); + value_i("head", whisper_model_n_text_head(ctx)); + value_i("leyer", whisper_model_n_text_layer(ctx), true); + end_obj(); + value_i("mels", whisper_model_n_mels(ctx)); + value_i("f16", whisper_model_f16(ctx), true); + end_obj(); + start_obj("params"); + value_s("model", params.model.c_str()); + value_s("language", params.language.c_str()); + value_b("translate", params.translate, true); + end_obj(); + start_obj("result"); + value_s("language", whisper_lang_str(whisper_full_lang_id(ctx)), true); + end_obj(); + start_arr("transcription"); + + const int n_segments = whisper_full_n_segments(ctx); + for (int i = 0; i < n_segments; ++i) { + const char * text = whisper_full_get_segment_text(ctx, i); + const int64_t t0 = whisper_full_get_segment_t0(ctx, i); + const int64_t t1 = whisper_full_get_segment_t1(ctx, i); + + start_obj(); + start_obj("timestanps"); + value_s("from", to_timestamp(t0, true).c_str()); + value_s("to", to_timestamp(t1, true).c_str(), true); + end_obj(); + start_obj("offsets"); + value_i("from", t0 * 10); + value_i("to", t1 * 10, true); + end_obj(); + value_s("text", text, true); + end_obj(i == (n_segments - 1)); + } + + end_arr(true); + end_obj(true); + return true; +} + // karaoke video generation // outputs a bash script that uses ffmpeg to generate a video with the subtitles // TODO: font parameter adjustments -bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & /*params*/, float t_sec) { +bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & params, float t_sec) { std::ofstream fout(fname); fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname); - // TODO: become parameter - static const char * font = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf"; + static const char * font = params.font_path.c_str(); + + std::ifstream fin(font); + if (!fin.is_open()) { + fprintf(stderr, "%s: font not found at '%s', please specify a monospace font with -fp\n", __func__, font); + return false; + } fout << "#!/bin/bash" << "\n"; fout << "\n"; @@ -607,7 +742,7 @@ int main(int argc, char ** argv) { { static bool is_aborted = false; // NOTE: this should be atomic to avoid data race - wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, void * user_data) { + wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) { bool is_aborted = *(bool*)user_data; return !is_aborted; }; @@ -653,6 +788,12 @@ int main(int argc, char ** argv) { const auto fname_csv = fname_out + ".csv"; output_csv(ctx, fname_csv.c_str()); } + + // output to JSON file + if (params.output_jsn) { + const auto fname_jsn = fname_out + ".json"; + output_json(ctx, fname_jsn.c_str(), params); + } } } diff --git a/examples/whisper/quantize.cpp b/examples/whisper/quantize.cpp new file mode 100644 index 000000000..8042d69cd --- /dev/null +++ b/examples/whisper/quantize.cpp @@ -0,0 +1,373 @@ +#include "ggml/ggml.h" + +#include "utils.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// default hparams (Whisper tiny) +struct whisper_hparams { + int32_t n_vocab = 51864; + int32_t n_audio_ctx = 1500; + int32_t n_audio_state = 384; + int32_t n_audio_head = 6; + int32_t n_audio_layer = 4; + int32_t n_text_ctx = 448; + int32_t n_text_state = 384; + int32_t n_text_head = 6; + int32_t n_text_layer = 4; + int32_t n_mels = 80; + int32_t f16 = 1; +}; + +struct whisper_filters { + int32_t n_mel; + int32_t n_fft; + + std::vector data; +}; + +// quantize a model +bool whisper_model_quantize(const std::string & fname_inp, const std::string & fname_out, int itype) { + ggml_type type = GGML_TYPE_Q4_1; + + switch (itype) { + case 2: type = GGML_TYPE_Q4_0; break; + case 3: type = GGML_TYPE_Q4_1; break; + default: fprintf(stderr, "%s: invalid quantization type %d\n", __func__, itype); return 1; + }; + + if (type != GGML_TYPE_Q4_0 && type != GGML_TYPE_Q4_1) { + fprintf(stderr, "%s: invalid quantization type %d\n", __func__, type); + return false; + } + + gpt_vocab vocab; + + printf("%s: loading model from '%s'\n", __func__, fname_inp.c_str()); + + auto finp = std::ifstream(fname_inp, std::ios::binary); + if (!finp) { + fprintf(stderr, "%s: failed to open '%s' for reading\n", __func__, fname_inp.c_str()); + return false; + } + + auto fout = std::ofstream(fname_out, std::ios::binary); + if (!fout) { + fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname_out.c_str()); + return false; + } + + // verify magic + { + uint32_t magic; + finp.read((char *) &magic, sizeof(magic)); + if (magic != 0x67676d6c) { + fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname_inp.c_str()); + return false; + } + + fout.write((char *) &magic, sizeof(magic)); + } + + whisper_hparams hparams; + + // load hparams + { + finp.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab)); + finp.read((char *) &hparams.n_audio_ctx, sizeof(hparams.n_audio_ctx)); + finp.read((char *) &hparams.n_audio_state, sizeof(hparams.n_audio_state)); + finp.read((char *) &hparams.n_audio_head, sizeof(hparams.n_audio_head)); + finp.read((char *) &hparams.n_audio_layer, sizeof(hparams.n_audio_layer)); + finp.read((char *) &hparams.n_text_ctx, sizeof(hparams.n_text_ctx)); + finp.read((char *) &hparams.n_text_state, sizeof(hparams.n_text_state)); + finp.read((char *) &hparams.n_text_head, sizeof(hparams.n_text_head)); + finp.read((char *) &hparams.n_text_layer, sizeof(hparams.n_text_layer)); + finp.read((char *) &hparams.n_mels, sizeof(hparams.n_mels)); + finp.read((char *) &hparams.f16, sizeof(hparams.f16)); + + fprintf(stderr, "%s: n_vocab = %d\n", __func__, hparams.n_vocab); + fprintf(stderr, "%s: n_audio_ctx = %d\n", __func__, hparams.n_audio_ctx); + fprintf(stderr, "%s: n_audio_state = %d\n", __func__, hparams.n_audio_state); + fprintf(stderr, "%s: n_audio_head = %d\n", __func__, hparams.n_audio_head); + fprintf(stderr, "%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer); + fprintf(stderr, "%s: n_text_ctx = %d\n", __func__, hparams.n_text_ctx); + fprintf(stderr, "%s: n_text_state = %d\n", __func__, hparams.n_text_state); + fprintf(stderr, "%s: n_text_head = %d\n", __func__, hparams.n_text_head); + fprintf(stderr, "%s: n_text_layer = %d\n", __func__, hparams.n_text_layer); + fprintf(stderr, "%s: n_mels = %d\n", __func__, hparams.n_mels); + fprintf(stderr, "%s: f16 = %d\n", __func__, hparams.f16); + + fout.write((char *) &hparams.n_vocab, sizeof(hparams.n_vocab)); + fout.write((char *) &hparams.n_audio_ctx, sizeof(hparams.n_audio_ctx)); + fout.write((char *) &hparams.n_audio_state, sizeof(hparams.n_audio_state)); + fout.write((char *) &hparams.n_audio_head, sizeof(hparams.n_audio_head)); + fout.write((char *) &hparams.n_audio_layer, sizeof(hparams.n_audio_layer)); + fout.write((char *) &hparams.n_text_ctx, sizeof(hparams.n_text_ctx)); + fout.write((char *) &hparams.n_text_state, sizeof(hparams.n_text_state)); + fout.write((char *) &hparams.n_text_head, sizeof(hparams.n_text_head)); + fout.write((char *) &hparams.n_text_layer, sizeof(hparams.n_text_layer)); + fout.write((char *) &hparams.n_mels, sizeof(hparams.n_mels)); + fout.write((char *) &itype, sizeof(hparams.f16)); + } + + // load mel filters + { + whisper_filters filters; + + finp.read ((char *) &filters.n_mel, sizeof(filters.n_mel)); + fout.write((char *) &filters.n_mel, sizeof(filters.n_mel)); + finp.read ((char *) &filters.n_fft, sizeof(filters.n_fft)); + fout.write((char *) &filters.n_fft, sizeof(filters.n_fft)); + + filters.data.resize(filters.n_mel * filters.n_fft); + finp.read ((char *) filters.data.data(), filters.data.size() * sizeof(float)); + fout.write((char *) filters.data.data(), filters.data.size() * sizeof(float)); + } + + // load vocab + { + int32_t n_vocab = 0; + finp.read ((char *) &n_vocab, sizeof(n_vocab)); + fout.write((char *) &n_vocab, sizeof(n_vocab)); + + //if (n_vocab != hparams.n_vocab) { + // fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n", + // __func__, fname_inp.c_str(), n_vocab, hparams.n_vocab); + // return false; + //} + + std::string word; + for (int i = 0; i < n_vocab; i++) { + uint32_t len; + finp.read ((char *) &len, sizeof(len)); + fout.write((char *) &len, sizeof(len)); + + word.resize(len); + finp.read ((char *) word.data(), len); + fout.write((char *) word.data(), len); + + vocab.token_to_id[word] = i; + vocab.id_to_token[i] = word; + } + } + + // load weights + { + size_t total_size_org = 0; + size_t total_size_new = 0; + + std::vector work; + + std::vector data_u8; + std::vector data_f16; + std::vector data_f32; + + std::vector hist_all(1 << 4, 0); + + while (true) { + int32_t n_dims; + int32_t length; + int32_t ftype; + + finp.read(reinterpret_cast(&n_dims), sizeof(n_dims)); + finp.read(reinterpret_cast(&length), sizeof(length)); + finp.read(reinterpret_cast(&ftype), sizeof(ftype)); + + if (finp.eof()) { + break; + } + + int32_t nelements = 1; + int32_t ne[3] = { 1, 1, 1 }; + for (int i = 0; i < n_dims; ++i) { + finp.read (reinterpret_cast(&ne[i]), sizeof(ne[i])); + nelements *= ne[i]; + } + + std::string name(length, 0); + finp.read (&name[0], length); + + { + static const char * ftype_str[] = { "f32", "f16", "q4_0", "q4_1", }; + printf("%48s - [%5d, %5d, %5d], type = %6s ", name.data(), ne[0], ne[1], ne[2], ftype_str[ftype]); + } + + // regexes of tensor names to not be quantized + const std::vector k_names = { + //"encoder.*", + "encoder.conv1.bias", + "encoder.conv2.bias", + "encoder.positional_embedding", + "decoder.positional_embedding", + }; + + bool quantize = true; + for (const auto & s : k_names) { + if (std::regex_match(name, std::regex(s))) { + quantize = false; + break; + } + } + + // quantize only 2D and 3D tensors + quantize &= (n_dims == 2); + + if (quantize) { + if (ftype != 0 && ftype != 1) { + fprintf(stderr, "%s: unsupported ftype %d for integer quantization\n", __func__, ftype); + return false; + } + + if (ftype == 1) { + data_f16.resize(nelements); + finp.read(reinterpret_cast(data_f16.data()), nelements * sizeof(ggml_fp16_t)); + data_f32.resize(nelements); + for (int i = 0; i < nelements; ++i) { + data_f32[i] = ggml_fp16_to_fp32(data_f16[i]); + } + } else { + data_f32.resize(nelements); + finp.read(reinterpret_cast(data_f32.data()), nelements * sizeof(float)); + } + + ftype = itype; + } else { + const int bpe = (ftype == 0) ? sizeof(float) : sizeof(uint16_t); + + data_u8.resize(nelements*bpe); + finp.read(reinterpret_cast(data_u8.data()), nelements * bpe); + } + + fout.write(reinterpret_cast(&n_dims), sizeof(n_dims)); + fout.write(reinterpret_cast(&length), sizeof(length)); + fout.write(reinterpret_cast(&ftype), sizeof(ftype)); + for (int i = 0; i < n_dims; ++i) { + fout.write(reinterpret_cast(&ne[i]), sizeof(ne[i])); + } + fout.write(&name[0], length); + + if (quantize) { + printf("quantizing .. "); + work.resize(nelements); // for quantization + + size_t cur_size = 0; + std::vector hist_cur(1 << 4, 0); + + switch (type) { + case GGML_TYPE_Q4_0: + { + cur_size = ggml_quantize_q4_0(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data()); + } break; + case GGML_TYPE_Q4_1: + { + cur_size = ggml_quantize_q4_1(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data()); + } break; + default: + { + fprintf(stderr, "%s: unsupported quantization type %d\n", __func__, type); + return false; + } + } + + fout.write(reinterpret_cast(work.data()), cur_size); + total_size_new += cur_size; + + printf("size = %8.3f MB -> %8.3f MB | hist: ", nelements * sizeof(float)/1024.0/1024.0, cur_size/1024.0/1024.0); + for (int i = 0; i < hist_cur.size(); ++i) { + hist_all[i] += hist_cur[i]; + } + + for (int i = 0; i < hist_cur.size(); ++i) { + printf("%5.3f ", hist_cur[i] / (float)nelements); + } + printf("\n"); + } else { + printf("size = %8.3f MB\n", data_u8.size()/1024.0/1024.0); + fout.write(reinterpret_cast(data_u8.data()), data_u8.size()); + total_size_new += data_u8.size(); + } + + total_size_org += nelements * sizeof(float); + } + + printf("%s: model size = %8.2f MB\n", __func__, total_size_org/1024.0/1024.0); + printf("%s: quant size = %8.2f MB\n", __func__, total_size_new/1024.0/1024.0); + + { + int64_t sum_all = 0; + for (int i = 0; i < hist_all.size(); ++i) { + sum_all += hist_all[i]; + } + + printf("%s: hist: ", __func__); + for (int i = 0; i < hist_all.size(); ++i) { + printf("%5.3f ", hist_all[i] / (float)sum_all); + } + printf("\n"); + } + } + + finp.close(); + fout.close(); + + return true; +} + +// usage: +// ./gpt-2-quantize models/gpt-2-117M/ggml-model.bin models/gpt-2-117M/ggml-model-quant.bin type +// +int main(int argc, char ** argv) { + if (argc != 4) { + fprintf(stderr, "usage: %s model-f32.bin model-quant.bin type\n", argv[0]); + fprintf(stderr, " type = 2 - q4_0\n"); + fprintf(stderr, " type = 3 - q4_1\n"); + return 1; + } + + // needed to initialize f16 tables + { + struct ggml_init_params params = { 0, NULL }; + struct ggml_context * ctx = ggml_init(params); + ggml_free(ctx); + } + + const std::string fname_inp = argv[1]; + const std::string fname_out = argv[2]; + + const int itype = atoi(argv[3]); + + const int64_t t_main_start_us = ggml_time_us(); + + int64_t t_quantize_us = 0; + + // load the model + { + const int64_t t_start_us = ggml_time_us(); + + if (!whisper_model_quantize(fname_inp, fname_out, itype)) { + fprintf(stderr, "%s: failed to quantize model from '%s'\n", __func__, fname_inp.c_str()); + return 1; + } + + t_quantize_us = ggml_time_us() - t_start_us; + } + + // report timing + { + const int64_t t_main_end_us = ggml_time_us(); + + printf("\n"); + printf("%s: quantize time = %8.2f ms\n", __func__, t_quantize_us/1000.0f); + printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f); + } + + return 0; +} diff --git a/examples/whisper/whisper.cpp b/examples/whisper/whisper.cpp index 3a21581c6..f44e5034b 100644 --- a/examples/whisper/whisper.cpp +++ b/examples/whisper/whisper.cpp @@ -218,14 +218,14 @@ static const std::map> g_lang = { { "su", { 98, "sundanese", } }, }; -static const size_t MB = 1024*1024; +static const size_t MB = 1ull*1024*1024; static const std::map MEM_REQ_SCRATCH0 = { - { MODEL_TINY, 12ull*MB }, - { MODEL_BASE, 15ull*MB }, - { MODEL_SMALL, 23ull*MB }, - { MODEL_MEDIUM, 31ull*MB }, - { MODEL_LARGE, 38ull*MB }, + { MODEL_TINY, 14ull*MB }, + { MODEL_BASE, 18ull*MB }, + { MODEL_SMALL, 28ull*MB }, + { MODEL_MEDIUM, 36ull*MB }, + { MODEL_LARGE, 44ull*MB }, }; static const std::map MEM_REQ_SCRATCH1 = { @@ -547,13 +547,11 @@ struct whisper_decoder { std::vector tokens_tmp; // used for whisper_decode calls }; -struct whisper_context { - int64_t t_load_us = 0; - int64_t t_mel_us = 0; +struct whisper_state { int64_t t_sample_us = 0; int64_t t_encode_us = 0; int64_t t_decode_us = 0; - int64_t t_start_us = 0; + int64_t t_mel_us = 0; int32_t n_sample = 0; // number of tokens sampled int32_t n_encode = 0; // number of encoder calls @@ -561,16 +559,10 @@ struct whisper_context { int32_t n_fail_p = 0; // number of logprob threshold failures int32_t n_fail_h = 0; // number of entropy threshold failures - ggml_type wtype; // weight type (FP32 or FP16) - - whisper_mel mel; - - whisper_model model; - whisper_vocab vocab; - // cross-attention KV cache for the decoders // shared between all decoders whisper_kv_cache kv_cross; + whisper_mel mel; whisper_decoder decoders[WHISPER_MAX_DECODERS] = {}; @@ -635,6 +627,19 @@ struct whisper_context { } }; +struct whisper_context { + int64_t t_load_us = 0; + int64_t t_start_us = 0; + + ggml_type wtype = ggml_type::GGML_TYPE_F16; // weight type (FP32 or FP16) + + whisper_model model; + whisper_vocab vocab; + whisper_state * state = nullptr; + + std::string path_model; // populated by whisper_init_from_file() +}; + template static void read_safe(whisper_model_loader * loader, T & dest) { loader->read(loader->context, &dest, sizeof(T)); @@ -821,32 +826,8 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con wctx.model.buf = new std::vector(); wctx.model.buf->resize(scale*MEM_REQ_MODEL.at(model.type)); - if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_SELF.at(model.type), wctx.decoders[0].kv_self, wctx.wtype, model.hparams.n_text_ctx)) { - fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__); - return false; - } - - { - const size_t memory_size = ggml_nbytes(wctx.decoders[0].kv_self.k) + ggml_nbytes(wctx.decoders[0].kv_self.v); - fprintf(stderr, "%s: kv self size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0); - } - - if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_CROSS.at(model.type), wctx.kv_cross, wctx.wtype, model.hparams.n_audio_ctx)) { - fprintf(stderr, "%s: kv_cache_init() failed for cross-attention cache\n", __func__); - return false; - } - - { - const size_t memory_size = ggml_nbytes(wctx.kv_cross.k) + ggml_nbytes(wctx.kv_cross.v); - fprintf(stderr, "%s: kv cross size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0); - } - - wctx.buf_compute.resize(scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type))); - - wctx.buf_scratch[0].resize(MEM_REQ_SCRATCH0.at(model.type)); - wctx.buf_scratch[1].resize(MEM_REQ_SCRATCH1.at(model.type)); - wctx.buf_scratch[2].resize(MEM_REQ_SCRATCH2.at(model.type)); - wctx.buf_scratch[3].resize(MEM_REQ_SCRATCH3.at(model.type)); + // we skip initialization of the state until it is needed + // because it might be that state will always be provided externally. } // load mel filters @@ -929,17 +910,6 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con vocab.id_to_token[i] = word; } } - - wctx.logits.reserve(vocab.n_vocab*model.hparams.n_text_ctx); - - wctx.logits_id.reserve(n_vocab); - - // TAGS: WHISPER_DECODER_INIT - wctx.decoders[0].sequence.tokens.reserve(model.hparams.n_text_ctx); - - wctx.decoders[0].probs.reserve (vocab.n_vocab); - wctx.decoders[0].logits.reserve (vocab.n_vocab); - wctx.decoders[0].logprobs.reserve(vocab.n_vocab); } size_t ctx_size = 0; @@ -1339,33 +1309,34 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con } } - wctx.rng = std::mt19937(0); - wctx.t_load_us = ggml_time_us() - t_start_us; return true; } -// evaluate the encoder +// evaluate the encoder with the given state // // given audio recording (more specifically, its log mel spectrogram), runs forward pass of the encoder // part of the transformer model and returns the encoded features // -// - model: the model +// - wctx: the model +// - wstate: the state of the encoder // - n_threads: number of threads to use // - mel_offset: offset in the mel spectrogram (i.e. audio offset) // -static bool whisper_encode( +static bool whisper_encode_internal( whisper_context & wctx, + whisper_state & wstate, const int mel_offset, - const int n_threads) { + const int n_threads){ + const int64_t t_start_us = ggml_time_us(); const auto & model = wctx.model; - const auto & mel_inp = wctx.mel; + const auto & mel_inp = wstate.mel; const auto & hparams = model.hparams; - const int n_ctx = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : hparams.n_audio_ctx; + const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx; const int n_state = hparams.n_audio_state; const int n_head = hparams.n_audio_head; const int n_layer = hparams.n_audio_layer; @@ -1374,12 +1345,12 @@ static bool whisper_encode( assert(mel_inp.n_mel == n_mels); struct ggml_init_params params; - params.mem_size = wctx.buf_compute.size(); - params.mem_buffer = wctx.buf_compute.data(); + params.mem_size = wstate.buf_compute.size(); + params.mem_buffer = wstate.buf_compute.data(); struct ggml_context * ctx0 = ggml_init(params); - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels); assert(mel->type == GGML_TYPE_F32); @@ -1401,30 +1372,30 @@ static bool whisper_encode( // convolution + gelu { - wctx.use_buf(ctx0, 1); + wstate.use_buf(ctx0, 1); cur = ggml_conv_1d_1s(ctx0, model.e_conv_1_w, mel); cur = ggml_add(ctx0, - ggml_repeat(ctx0, - model.e_conv_1_b, - cur), - cur); + ggml_repeat(ctx0, + model.e_conv_1_b, + cur), + cur); cur = ggml_gelu(ctx0, cur); - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); cur = ggml_conv_1d_2s(ctx0, model.e_conv_2_w, cur); cur = ggml_add(ctx0, - ggml_repeat(ctx0, - model.e_conv_2_b, - cur), - cur); + ggml_repeat(ctx0, + model.e_conv_2_b, + cur), + cur); cur = ggml_gelu(ctx0, cur); } - wctx.use_buf(ctx0, 3); + wstate.use_buf(ctx0, 3); // =================================================================== // NOTE: experimenting with partial evaluation of the encoder (ignore) @@ -1459,54 +1430,54 @@ static bool whisper_encode( // norm { - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); cur = ggml_norm(ctx0, inpL); // cur = ln_0_w*cur + ln_0_b cur = ggml_add(ctx0, - ggml_mul(ctx0, - ggml_repeat(ctx0, layer.attn_ln_0_w, cur), - cur), - ggml_repeat(ctx0, layer.attn_ln_0_b, cur)); + ggml_mul(ctx0, + ggml_repeat(ctx0, layer.attn_ln_0_w, cur), + cur), + ggml_repeat(ctx0, layer.attn_ln_0_b, cur)); } // self-attention { - wctx.use_buf(ctx0, 1); + wstate.use_buf(ctx0, 1); struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, - layer.attn_q_w, - cur); + layer.attn_q_w, + cur); Qcur = ggml_add(ctx0, - ggml_repeat(ctx0, - layer.attn_q_b, - Qcur), - Qcur); + ggml_repeat(ctx0, + layer.attn_q_b, + Qcur), + Qcur); //Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25))); // note: no bias for Key struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, - layer.attn_k_w, - cur); + layer.attn_k_w, + cur); //Kcur = ggml_scale(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25))); struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, - layer.attn_v_w, - cur); + layer.attn_v_w, + cur); Vcur = ggml_add(ctx0, - ggml_repeat(ctx0, - layer.attn_v_b, - Vcur), - Vcur); + ggml_repeat(ctx0, + layer.attn_v_b, + Vcur), + Vcur); // ------ - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); #ifdef WHISPER_USE_FLASH_ATTN struct ggml_tensor * Q = @@ -1583,29 +1554,29 @@ static bool whisper_encode( #endif struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); - wctx.use_buf(ctx0, 1); + wstate.use_buf(ctx0, 1); cur = ggml_cpy(ctx0, - KQV_merged, - ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx)); + KQV_merged, + ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx)); } // projection { - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); cur = ggml_mul_mat(ctx0, - layer.attn_ln_1_w, - cur); + layer.attn_ln_1_w, + cur); - wctx.use_buf(ctx0, 1); + wstate.use_buf(ctx0, 1); cur = ggml_add(ctx0, - ggml_repeat(ctx0, layer.attn_ln_1_b, cur), - cur); + ggml_repeat(ctx0, layer.attn_ln_1_b, cur), + cur); } - wctx.use_buf(ctx0, 2); + wstate.use_buf(ctx0, 2); // add the input cur = ggml_add(ctx0, cur, inpL); @@ -1616,61 +1587,61 @@ static bool whisper_encode( { // norm { - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); cur = ggml_norm(ctx0, inpFF); - wctx.use_buf(ctx0, 1); + wstate.use_buf(ctx0, 1); // cur = mlp_ln_w*cur + mlp_ln_b cur = ggml_add(ctx0, - ggml_mul(ctx0, - ggml_repeat(ctx0, layer.mlp_ln_w, cur), - cur), - ggml_repeat(ctx0, layer.mlp_ln_b, cur)); + ggml_mul(ctx0, + ggml_repeat(ctx0, layer.mlp_ln_w, cur), + cur), + ggml_repeat(ctx0, layer.mlp_ln_b, cur)); } #ifdef WHISPER_USE_FLASH_FF - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); cur = ggml_flash_ff(ctx0, - ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wctx.wtype, n_state, n_ctx)), - layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b); + ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wstate.wtype, n_state, n_ctx)), + layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b); #else - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); // fully connected cur = ggml_mul_mat(ctx0, - layer.mlp_0_w, - cur); + layer.mlp_0_w, + cur); - wctx.use_buf(ctx0, 1); + wstate.use_buf(ctx0, 1); cur = ggml_add(ctx0, - ggml_repeat(ctx0, layer.mlp_0_b, cur), - cur); + ggml_repeat(ctx0, layer.mlp_0_b, cur), + cur); - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); // GELU activation cur = ggml_gelu(ctx0, cur); - wctx.use_buf(ctx0, 1); + wstate.use_buf(ctx0, 1); // projection cur = ggml_mul_mat(ctx0, - layer.mlp_1_w, - cur); + layer.mlp_1_w, + cur); - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); cur = ggml_add(ctx0, - ggml_repeat(ctx0, layer.mlp_1_b, cur), - cur); + ggml_repeat(ctx0, layer.mlp_1_b, cur), + cur); #endif } - wctx.use_buf(ctx0, 3); + wstate.use_buf(ctx0, 3); inpL = ggml_add(ctx0, cur, inpFF); } @@ -1679,21 +1650,21 @@ static bool whisper_encode( // norm { - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); cur = ggml_norm(ctx0, cur); - wctx.use_buf(ctx0, 1); + wstate.use_buf(ctx0, 1); // cur = ln_f_g*cur + ln_f_b cur = ggml_add(ctx0, - ggml_mul(ctx0, - ggml_repeat(ctx0, model.e_ln_w, cur), - cur), - ggml_repeat(ctx0, model.e_ln_b, cur)); + ggml_mul(ctx0, + ggml_repeat(ctx0, model.e_ln_w, cur), + cur), + ggml_repeat(ctx0, model.e_ln_b, cur)); } - wctx.use_buf(ctx0, -1); + wstate.use_buf(ctx0, -1); // run the computation { @@ -1701,7 +1672,7 @@ static bool whisper_encode( gf.n_threads = n_threads; ggml_build_forward_expand(&gf, cur); - ggml_graph_compute (ctx0, &gf); + ggml_graph_compute(ctx0, &gf); //ggml_graph_print(&gf); } @@ -1731,34 +1702,34 @@ static bool whisper_encode( cur->src1 = nullptr; for (int il = 0; il < model.hparams.n_text_layer; ++il) { - auto & layer = model.layers_decoder[il]; + auto& layer = model.layers_decoder[il]; - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); - struct ggml_tensor * Kcross = ggml_mul_mat(ctx0, - layer.cross_attn_k_w, - cur); + struct ggml_tensor* Kcross = ggml_mul_mat(ctx0, + layer.cross_attn_k_w, + cur); - Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25))); + Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state) / n_head, -0.25))); - wctx.use_buf(ctx0, 1); + wstate.use_buf(ctx0, 1); - struct ggml_tensor * Vcross = ggml_mul_mat(ctx0, - layer.cross_attn_v_w, - cur); + struct ggml_tensor* Vcross = ggml_mul_mat(ctx0, + layer.cross_attn_v_w, + cur); Vcross = ggml_add(ctx0, - ggml_repeat(ctx0, - layer.cross_attn_v_b, - Vcross), - Vcross); + ggml_repeat(ctx0, + layer.cross_attn_v_b, + Vcross), + Vcross); - wctx.use_buf(ctx0, -1); + wstate.use_buf(ctx0, -1); - //struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_cross.k, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.k)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx)); - //struct ggml_tensor * v = ggml_view_1d(ctx0, wctx.kv_cross.v, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.v)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx)); - struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_cross.k, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.k)*n_state)*(il*n_ctx)); - struct ggml_tensor * v = ggml_view_1d(ctx0, wctx.kv_cross.v, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.v)*n_state)*(il*n_ctx)); + //struct ggml_tensor * k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx)); + //struct ggml_tensor * v = ggml_view_1d(ctx0, wstate.kv_cross.v, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.v)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx)); + struct ggml_tensor* k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx)); + struct ggml_tensor* v = ggml_view_1d(ctx0, wstate.kv_cross.v, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.v)*n_state)*(il*n_ctx)); ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcross, k)); ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcross, v)); @@ -1772,15 +1743,15 @@ static bool whisper_encode( //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__, // ggml_used_mem(ctx0)/1024.0/1024.0, - // wctx.get_buf_max_mem(0)/1024.0/1024.0, - // wctx.get_buf_max_mem(1)/1024.0/1024.0, - // wctx.get_buf_max_mem(2)/1024.0/1024.0, - // wctx.get_buf_max_mem(3)/1024.0/1024.0); + // wstate.get_buf_max_mem(0)/1024.0/1024.0, + // wstate.get_buf_max_mem(1)/1024.0/1024.0, + // wstate.get_buf_max_mem(2)/1024.0/1024.0, + // wstate.get_buf_max_mem(3)/1024.0/1024.0); ggml_free(ctx0); - wctx.t_encode_us += ggml_time_us() - t_start_us; - wctx.n_encode++; + wstate.t_encode_us += ggml_time_us() - t_start_us; + wstate.n_encode++; return true; } @@ -1795,8 +1766,9 @@ static bool whisper_encode( // - n_tokens: number of tokens in the prompt // - n_past: number of past tokens to prefix the prompt with // -static bool whisper_decode( +static bool whisper_decode_internal( whisper_context & wctx, + whisper_state & wstate, whisper_decoder & decoder, const whisper_token * tokens, const int n_tokens, @@ -1811,7 +1783,7 @@ static bool whisper_decode( WHISPER_ASSERT(!!kv_self.ctx); - auto & logits_out = wctx.logits; + auto & logits_out = wstate.logits; const int n_vocab = hparams.n_vocab; @@ -1821,13 +1793,13 @@ static bool whisper_decode( const int n_layer = hparams.n_text_layer; const int N = n_tokens; - const int M = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : hparams.n_audio_ctx; + const int M = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx; //WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx); struct ggml_init_params params; - params.mem_size = wctx.buf_compute.size(); - params.mem_buffer = wctx.buf_compute.data(); + params.mem_size = wstate.buf_compute.size(); + params.mem_buffer = wstate.buf_compute.data(); struct ggml_context * ctx0 = ggml_init(params); @@ -1842,7 +1814,7 @@ static bool whisper_decode( ((int32_t *) position->data)[i] = n_past + i; } - wctx.use_buf(ctx0, 3); + wstate.use_buf(ctx0, 3); // token encoding + position encoding struct ggml_tensor * cur = @@ -1857,7 +1829,7 @@ static bool whisper_decode( // norm { - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); cur = ggml_norm(ctx0, inpL); @@ -1871,8 +1843,6 @@ static bool whisper_decode( // self-attention { - wctx.use_buf(ctx0, 1); - struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.attn_q_w, cur); @@ -1913,7 +1883,7 @@ static bool whisper_decode( // ------ - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); struct ggml_tensor * Q = ggml_permute(ctx0, @@ -1929,13 +1899,11 @@ static bool whisper_decode( n_state/n_head, n_head, n_past + N), 0, 2, 1, 3); - wctx.use_buf(ctx0, 1); + wstate.use_buf(ctx0, 1); // K * Q struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); - wctx.use_buf(ctx0, 0); - //struct ggml_tensor * KQ_scaled = // ggml_scale(ctx0, // KQ, @@ -1944,20 +1912,16 @@ static bool whisper_decode( struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past); - wctx.use_buf(ctx0, 1); - struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked); - wctx.use_buf(ctx0, 0); - struct ggml_tensor * V_trans = - ggml_permute(ctx0, - ggml_reshape_3d(ctx0, - ggml_view_1d(ctx0, kv_self.v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.v)*n_state), - n_state/n_head, n_head, n_past + N), - 1, 2, 0, 3); - - wctx.use_buf(ctx0, 1); + ggml_cpy(ctx0, + ggml_permute(ctx0, + ggml_reshape_3d(ctx0, + ggml_view_1d(ctx0, kv_self.v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.v)*n_state), + n_state/n_head, n_head, n_past + N), + 1, 2, 0, 3), + ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_state/n_head, n_head)); struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max); @@ -1970,32 +1934,30 @@ static bool whisper_decode( // projection { - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); cur = ggml_mul_mat(ctx0, layer.attn_ln_1_w, cur); - wctx.use_buf(ctx0, 1); + wstate.use_buf(ctx0, 1); cur = ggml_add(ctx0, ggml_repeat(ctx0, layer.attn_ln_1_b, cur), cur); } - wctx.use_buf(ctx0, 2); + wstate.use_buf(ctx0, 2); // add the input struct ggml_tensor * inpCA = ggml_add(ctx0, cur, inpL); // norm { - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); cur = ggml_norm(ctx0, inpCA); // note: we use inpCA here - wctx.use_buf(ctx0, 1); - // cur = ln_0_w*cur + ln_0_b cur = ggml_add(ctx0, ggml_mul(ctx0, @@ -2006,8 +1968,6 @@ static bool whisper_decode( // cross-attention { - wctx.use_buf(ctx0, 0); - struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.cross_attn_q_w, cur); @@ -2023,20 +1983,21 @@ static bool whisper_decode( // Kcross is already scaled struct ggml_tensor * Kcross = ggml_reshape_3d(ctx0, - ggml_view_1d(ctx0, wctx.kv_cross.k, M*n_state, il*M*ggml_element_size(wctx.kv_cross.k)*n_state), + ggml_view_1d(ctx0, wstate.kv_cross.k, M*n_state, il*M*ggml_element_size(wstate.kv_cross.k)*n_state), n_state/n_head, n_head, M); struct ggml_tensor * Vcross = ggml_reshape_3d(ctx0, - ggml_view_1d(ctx0, wctx.kv_cross.v, M*n_state, il*M*ggml_element_size(wctx.kv_cross.v)*n_state), + ggml_view_1d(ctx0, wstate.kv_cross.v, M*n_state, il*M*ggml_element_size(wstate.kv_cross.v)*n_state), n_state/n_head, n_head, M); - struct ggml_tensor * V_trans = ggml_permute(ctx0, Vcross, 1, 2, 0, 3); + struct ggml_tensor * V_trans = + ggml_cpy(ctx0, + ggml_permute(ctx0, Vcross, 1, 2, 0, 3), + ggml_new_tensor_3d(ctx0, Vcross->type, M, n_state/n_head, n_head)); // ------ - wctx.use_buf(ctx0, 1); - struct ggml_tensor * Q = ggml_permute(ctx0, ggml_cpy(ctx0, @@ -2046,8 +2007,6 @@ static bool whisper_decode( struct ggml_tensor * K = ggml_permute(ctx0, Kcross, 0, 2, 1, 3); - wctx.use_buf(ctx0, 0); - // K * Q struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); @@ -2060,16 +2019,10 @@ static bool whisper_decode( // no masking for cross-attention //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past); - wctx.use_buf(ctx0, 1); - struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ); - wctx.use_buf(ctx0, 0); - struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max); - wctx.use_buf(ctx0, 1); - struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); // cur = KQV_merged.contiguous().view(n_state, N) @@ -2080,20 +2033,20 @@ static bool whisper_decode( // projection { - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); cur = ggml_mul_mat(ctx0, layer.cross_attn_ln_1_w, cur); - wctx.use_buf(ctx0, 1); + wstate.use_buf(ctx0, 1); cur = ggml_add(ctx0, ggml_repeat(ctx0, layer.cross_attn_ln_1_b, cur), cur); } - wctx.use_buf(ctx0, 2); + wstate.use_buf(ctx0, 2); // add the input cur = ggml_add(ctx0, cur, inpCA); @@ -2104,11 +2057,11 @@ static bool whisper_decode( { // norm { - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); cur = ggml_norm(ctx0, inpFF); - wctx.use_buf(ctx0, 1); + wstate.use_buf(ctx0, 1); // cur = mlp_ln_w*cur + mlp_ln_b cur = ggml_add(ctx0, @@ -2118,39 +2071,39 @@ static bool whisper_decode( ggml_repeat(ctx0, layer.mlp_ln_b, cur)); } - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); // fully connected cur = ggml_mul_mat(ctx0, layer.mlp_0_w, cur); - wctx.use_buf(ctx0, 1); + wstate.use_buf(ctx0, 1); cur = ggml_add(ctx0, ggml_repeat(ctx0, layer.mlp_0_b, cur), cur); - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); // GELU activation cur = ggml_gelu(ctx0, cur); - wctx.use_buf(ctx0, 1); + wstate.use_buf(ctx0, 1); // projection cur = ggml_mul_mat(ctx0, layer.mlp_1_w, cur); - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); cur = ggml_add(ctx0, ggml_repeat(ctx0, layer.mlp_1_b, cur), cur); } - wctx.use_buf(ctx0, 3); + wstate.use_buf(ctx0, 3); inpL = ggml_add(ctx0, cur, inpFF); } @@ -2159,11 +2112,11 @@ static bool whisper_decode( // norm { - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); cur = ggml_norm(ctx0, cur); - wctx.use_buf(ctx0, 1); + wstate.use_buf(ctx0, 1); cur = ggml_add(ctx0, ggml_mul(ctx0, @@ -2172,7 +2125,7 @@ static bool whisper_decode( ggml_repeat(ctx0, model.d_ln_b, cur)); } - wctx.use_buf(ctx0, 0); + wstate.use_buf(ctx0, 0); // compute logits only for the last token // comment this line to compute logits for all N tokens @@ -2181,7 +2134,7 @@ static bool whisper_decode( struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur); - wctx.use_buf(ctx0, -1); + wstate.use_buf(ctx0, -1); // run the computation { @@ -2200,16 +2153,16 @@ static bool whisper_decode( if (N > 1) { //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__, // ggml_used_mem(ctx0)/1024.0/1024.0, - // wctx.get_buf_max_mem(0)/1024.0/1024.0, - // wctx.get_buf_max_mem(1)/1024.0/1024.0, - // wctx.get_buf_max_mem(2)/1024.0/1024.0, - // wctx.get_buf_max_mem(3)/1024.0/1024.0); + // wstate.get_buf_max_mem(0)/1024.0/1024.0, + // wstate.get_buf_max_mem(1)/1024.0/1024.0, + // wstate.get_buf_max_mem(2)/1024.0/1024.0, + // wstate.get_buf_max_mem(3)/1024.0/1024.0); } ggml_free(ctx0); - wctx.t_decode_us += ggml_time_us() - t_start_us; - wctx.n_decode++; + wstate.t_decode_us += ggml_time_us() - t_start_us; + wstate.n_decode++; return true; } @@ -2313,7 +2266,7 @@ static void fft(const std::vector & in, std::vector & out) { // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124 static bool log_mel_spectrogram( - whisper_context & wctx, + whisper_state & wstate, const float * samples, const int n_samples, const int /*sample_rate*/, @@ -2433,7 +2386,7 @@ static bool log_mel_spectrogram( mel.data[i] = (mel.data[i] + 4.0)/4.0; } - wctx.t_mel_us += ggml_time_us() - t_start_us; + wstate.t_mel_us += ggml_time_us() - t_start_us; return true; } @@ -2507,7 +2460,54 @@ static std::vector tokenize(const whisper_vocab & vocab, cons // interface implementation // -struct whisper_context * whisper_init_from_file(const char * path_model) { +struct whisper_state * whisper_init_state(whisper_context * ctx) { + whisper_state * state = new whisper_state; + + const size_t scale = ctx->model.hparams.f16 ? 1 : 2; + + if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_SELF.at(ctx->model.type), state->decoders[0].kv_self, ctx->wtype, ctx->model.hparams.n_text_ctx)) { + fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__); + return nullptr; + } + + { + const size_t memory_size = ggml_nbytes(state->decoders[0].kv_self.k) + ggml_nbytes(state->decoders[0].kv_self.v); + fprintf(stderr, "%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0); + } + + if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_CROSS.at(ctx->model.type), state->kv_cross, ctx->wtype, ctx->model.hparams.n_audio_ctx)) { + fprintf(stderr, "%s: kv_cache_init() failed for cross-attention cache\n", __func__); + return nullptr; + } + + { + const size_t memory_size = ggml_nbytes(state->kv_cross.k) + ggml_nbytes(state->kv_cross.v); + fprintf(stderr, "%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0); + } + + state->logits.reserve(ctx->vocab.n_vocab * ctx->model.hparams.n_text_ctx); + + state->logits_id.reserve(ctx->model.hparams.n_vocab); + + // TAGS: WHISPER_DECODER_INIT + state->decoders[0].sequence.tokens.reserve(ctx->model.hparams.n_text_ctx); + + state->decoders[0].probs.reserve(ctx->vocab.n_vocab); + state->decoders[0].logits.reserve(ctx->vocab.n_vocab); + state->decoders[0].logprobs.reserve(ctx->vocab.n_vocab); + state->buf_compute.resize(scale * std::max(MEM_REQ_ENCODE.at(ctx->model.type), MEM_REQ_DECODE.at(ctx->model.type))); + + state->buf_scratch[0].resize(MEM_REQ_SCRATCH0.at(ctx->model.type)); + state->buf_scratch[1].resize(MEM_REQ_SCRATCH1.at(ctx->model.type)); + state->buf_scratch[2].resize(MEM_REQ_SCRATCH2.at(ctx->model.type)); + state->buf_scratch[3].resize(MEM_REQ_SCRATCH3.at(ctx->model.type)); + + state->rng = std::mt19937(0); + + return state; +} + +struct whisper_context * whisper_init_from_file_no_state(const char * path_model) { whisper_model_loader loader = {}; fprintf(stderr, "%s: loading model from '%s'\n", __func__, path_model); @@ -2535,10 +2535,16 @@ struct whisper_context * whisper_init_from_file(const char * path_model) { fin->close(); }; - return whisper_init(&loader); + auto ctx = whisper_init_no_state(&loader); + + if (ctx) { + ctx->path_model = path_model; + } + + return ctx; } -struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size) { +struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t buffer_size) { struct buf_context { uint8_t* buffer; size_t size; @@ -2571,10 +2577,10 @@ struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_s loader.close = [](void * /*ctx*/) { }; - return whisper_init(&loader); + return whisper_init_no_state(&loader); } -struct whisper_context * whisper_init(struct whisper_model_loader * loader) { +struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loader) { ggml_time_init(); whisper_context * ctx = new whisper_context; @@ -2591,6 +2597,64 @@ struct whisper_context * whisper_init(struct whisper_model_loader * loader) { return ctx; } +struct whisper_context * whisper_init_from_file(const char * path_model) { + whisper_context * ctx = whisper_init_from_file_no_state(path_model); + if (!ctx) { + return nullptr; + } + + ctx->state = whisper_init_state(ctx); + if (!ctx->state) { + whisper_free(ctx); + return nullptr; + } + + return ctx; +} + +struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size) { + whisper_context * ctx = whisper_init_from_buffer_no_state(buffer, buffer_size); + if (!ctx) { + return nullptr; + } + + ctx->state = whisper_init_state(ctx); + if (!ctx->state) { + whisper_free(ctx); + return nullptr; + } + + return ctx; +} + +struct whisper_context * whisper_init(struct whisper_model_loader * loader) { + whisper_context * ctx = whisper_init_no_state(loader); + if (!ctx) { + return nullptr; + } + + ctx->state = whisper_init_state(ctx); + if (!ctx->state) { + whisper_free(ctx); + return nullptr; + } + + return ctx; +} + +void whisper_free_state(struct whisper_state * state) +{ + if (state) { + kv_cache_free(state->kv_cross); + + for (int i = 0; i < WHISPER_MAX_DECODERS; ++i) { + kv_cache_free(state->decoders[i].kv_self); + } + + delete state; + } +} + void whisper_free(struct whisper_context * ctx) { if (ctx) { if (ctx->model.ctx) { @@ -2599,20 +2663,29 @@ void whisper_free(struct whisper_context * ctx) { if (ctx->model.buf) { delete ctx->model.buf; } - if (ctx->kv_cross.ctx) { - ggml_free(ctx->kv_cross.ctx); - } - for (int i = 0; i < WHISPER_MAX_DECODERS; ++i) { - if (ctx->decoders[i].kv_self.ctx) { - ggml_free(ctx->decoders[i].kv_self.ctx); - } - } + + whisper_free_state(ctx->state); + delete ctx; } } +int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) { + if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, state->mel)) { + fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__); + return -1; + } + + return 0; +} + int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) { - if (!log_mel_spectrogram(*ctx, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, ctx->mel)) { + return whisper_pcm_to_mel_with_state(ctx, ctx->state, samples, n_samples, n_threads); +} + +// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 +int whisper_pcm_to_mel_phase_vocoder_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) { + if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, true, state->mel)) { fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__); return -1; } @@ -2622,11 +2695,26 @@ int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int // same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) { - if (!log_mel_spectrogram(*ctx, samples, n_samples, WHISPER_SAMPLE_RATE, 2*WHISPER_N_FFT, 2*WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, true, ctx->mel)) { - fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__); + return whisper_pcm_to_mel_phase_vocoder_with_state(ctx, ctx->state, samples, n_samples, n_threads); +} + +int whisper_set_mel_with_state( + struct whisper_context * /*ctx*/, + struct whisper_state * state, + const float * data, + int n_len, + int n_mel) { + if (n_mel != WHISPER_N_MEL) { + fprintf(stderr, "%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, WHISPER_N_MEL); return -1; } + state->mel.n_len = n_len; + state->mel.n_mel = n_mel; + + state->mel.data.resize(n_len*n_mel); + memcpy(state->mel.data.data(), data, n_len*n_mel*sizeof(float)); + return 0; } @@ -2635,22 +2723,20 @@ int whisper_set_mel( const float * data, int n_len, int n_mel) { - if (n_mel != WHISPER_N_MEL) { - fprintf(stderr, "%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, WHISPER_N_MEL); + return whisper_set_mel_with_state(ctx, ctx->state, data, n_len, n_mel); +} + +int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state * state, int offset, int n_threads) { + if (!whisper_encode_internal(*ctx, *state, offset, n_threads)) { + fprintf(stderr, "%s: failed to eval\n", __func__); return -1; } - ctx->mel.n_len = n_len; - ctx->mel.n_mel = n_mel; - - ctx->mel.data.resize(n_len*n_mel); - memcpy(ctx->mel.data.data(), data, n_len*n_mel*sizeof(float)); - return 0; } int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) { - if (!whisper_encode(*ctx, offset, n_threads)) { + if (!whisper_encode_internal(*ctx, *ctx->state, offset, n_threads)) { fprintf(stderr, "%s: failed to eval\n", __func__); return -1; } @@ -2658,11 +2744,28 @@ int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) { return 0; } +int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state * state, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) { + const int selected_decoder_id = 0; + + if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) { + fprintf(stderr, "%s: failed to eval\n", __func__); + return 1; + } + + return 0; +} + int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) { - // TODO: add selected_decoder_id to context + // TODO: add selected_decoder_id to state const int selected_decoder_id = 0; - if (!whisper_decode(*ctx, ctx->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) { + if (ctx->state == nullptr) { + fprintf(stderr, "%s: ERROR state was not loaded.\n", __func__); + return false; + } + + + if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) { fprintf(stderr, "%s: failed to eval\n", __func__); return 1; } @@ -2720,11 +2823,12 @@ const char * whisper_lang_str(int id) { return nullptr; } -int whisper_lang_auto_detect( +int whisper_lang_auto_detect_with_state( struct whisper_context * ctx, - int offset_ms, - int n_threads, - float * lang_probs) { + struct whisper_state * state, + int offset_ms, + int n_threads, + float * lang_probs) { const int seek = offset_ms/10; if (seek < 0) { @@ -2732,30 +2836,30 @@ int whisper_lang_auto_detect( return -1; } - if (seek >= ctx->mel.n_len) { - fprintf(stderr, "%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, ctx->mel.n_len*10); + if (seek >= state->mel.n_len) { + fprintf(stderr, "%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, state->mel.n_len*10); return -2; } // run the encoder - if (whisper_encode(ctx, seek, n_threads) != 0) { + if (whisper_encode_with_state(ctx, state, seek, n_threads) != 0) { fprintf(stderr, "%s: failed to encode\n", __func__); return -6; } const std::vector prompt = { whisper_token_sot(ctx) }; - if (whisper_decode(ctx, prompt.data(), prompt.size(), 0, n_threads) != 0) { + if (whisper_decode_with_state(ctx, state, prompt.data(), prompt.size(), 0, n_threads) != 0) { fprintf(stderr, "%s: failed to decode\n", __func__); return -7; } - auto & logits_id = ctx->logits_id; + auto & logits_id = state->logits_id; logits_id.clear(); for (const auto & kv : g_lang) { const auto token_lang = whisper_token_lang(ctx, kv.second.first); - logits_id.emplace_back(ctx->logits[token_lang], kv.second.first); + logits_id.emplace_back(state->logits[token_lang], kv.second.first); } // sort descending @@ -2794,8 +2898,85 @@ int whisper_lang_auto_detect( return logits_id[0].second; } +int whisper_lang_auto_detect( + struct whisper_context * ctx, + int offset_ms, + int n_threads, + float * lang_probs) { + return whisper_lang_auto_detect_with_state(ctx, ctx->state, offset_ms, n_threads, lang_probs); +} + +int whisper_model_n_vocab(struct whisper_context * ctx) { + return ctx->model.hparams.n_vocab; +} + +int whisper_model_n_audio_ctx(struct whisper_context * ctx) { + return ctx->model.hparams.n_audio_ctx; +} + +int whisper_model_n_audio_state(struct whisper_context * ctx) { + return ctx->model.hparams.n_audio_state; +} + +int whisper_model_n_audio_head(struct whisper_context * ctx) { + return ctx->model.hparams.n_audio_head; +} + +int whisper_model_n_audio_layer(struct whisper_context * ctx) { + return ctx->model.hparams.n_audio_layer; +} + +int whisper_model_n_text_ctx(struct whisper_context * ctx) { + return ctx->model.hparams.n_text_ctx; +} + +int whisper_model_n_text_state(struct whisper_context * ctx) { + return ctx->model.hparams.n_text_state; +} + +int whisper_model_n_text_head(struct whisper_context * ctx) { + return ctx->model.hparams.n_text_head; +} + +int whisper_model_n_text_layer(struct whisper_context * ctx) { + return ctx->model.hparams.n_text_layer; +} + +int whisper_model_n_mels(struct whisper_context * ctx) { + return ctx->model.hparams.n_mels; +} + +int whisper_model_f16(struct whisper_context * ctx) { + return ctx->model.hparams.f16; +} + +int whisper_model_type(struct whisper_context * ctx) { + return ctx->model.type; +} + +const char *whisper_model_type_readable(struct whisper_context * ctx) { + switch (ctx->model.type) { + case e_model::MODEL_TINY: + return "tiny"; + case e_model::MODEL_BASE: + return "base"; + case e_model::MODEL_SMALL: + return "small"; + case e_model::MODEL_MEDIUM: + return "medium"; + case e_model::MODEL_LARGE: + return "large"; + default: + return "unknown"; + } +} + +int whisper_n_len_from_state(struct whisper_state * state) { + return state->mel.n_len; +} + int whisper_n_len(struct whisper_context * ctx) { - return ctx->mel.n_len; + return ctx->state->mel.n_len; } int whisper_n_vocab(struct whisper_context * ctx) { @@ -2815,7 +2996,12 @@ int whisper_is_multilingual(struct whisper_context * ctx) { } float * whisper_get_logits(struct whisper_context * ctx) { - return ctx->logits.data(); + return ctx->state->logits.data(); +} + + +float * whisper_get_logits_from_state(struct whisper_state * state) { + return state->logits.data(); } const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token) { @@ -2861,24 +3047,29 @@ whisper_token whisper_token_transcribe(void) { void whisper_print_timings(struct whisper_context * ctx) { const int64_t t_end_us = ggml_time_us(); - const int32_t n_sample = std::max(1, ctx->n_sample); - const int32_t n_encode = std::max(1, ctx->n_encode); - const int32_t n_decode = std::max(1, ctx->n_decode); - fprintf(stderr, "\n"); - fprintf(stderr, "%s: fallbacks = %3d p / %3d h\n", __func__, ctx->n_fail_p, ctx->n_fail_h); - fprintf(stderr, "%s: load time = %8.2f ms\n", __func__, ctx->t_load_us/1000.0f); - fprintf(stderr, "%s: mel time = %8.2f ms\n", __func__, ctx->t_mel_us/1000.0f); - fprintf(stderr, "%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f*ctx->t_sample_us, n_sample, 1e-3f*ctx->t_sample_us/n_sample); - fprintf(stderr, "%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f*ctx->t_encode_us, n_encode, 1e-3f*ctx->t_encode_us/n_encode); - fprintf(stderr, "%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f*ctx->t_decode_us, n_decode, 1e-3f*ctx->t_decode_us/n_decode); + fprintf(stderr, "%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f); + if (ctx->state != nullptr) { + + const int32_t n_sample = std::max(1, ctx->state->n_sample); + const int32_t n_encode = std::max(1, ctx->state->n_encode); + const int32_t n_decode = std::max(1, ctx->state->n_decode); + + fprintf(stderr, "%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h); + fprintf(stderr, "%s: mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f); + fprintf(stderr, "%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample); + fprintf(stderr, "%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode); + fprintf(stderr, "%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode); + } fprintf(stderr, "%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f); } void whisper_reset_timings(struct whisper_context * ctx) { - ctx->t_sample_us = 0; - ctx->t_encode_us = 0; - ctx->t_decode_us = 0; + if (ctx->state != nullptr) { + ctx->state->t_sample_us = 0; + ctx->state->t_encode_us = 0; + ctx->state->t_decode_us = 0; + } } const char * whisper_print_system_info(void) { @@ -2913,7 +3104,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.duration_ms =*/ 0, /*.translate =*/ false, - /*.no_context =*/ false, + /*.no_context =*/ true, /*.single_segment =*/ false, /*.print_special =*/ false, /*.print_progress =*/ true, @@ -2942,7 +3133,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.max_initial_ts =*/ 1.0f, /*.length_penalty =*/ -1.0f, - /*.temperature_inc =*/ 0.2f, + /*.temperature_inc =*/ 0.0f, // TODO: temporary disabled until improve performance /*.entropy_thold =*/ 2.4f, /*.logprob_thold =*/ -1.0f, /*.no_speech_thold =*/ 0.6f, @@ -2991,6 +3182,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str static std::vector get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window); static void whisper_exp_compute_token_level_timestamps( struct whisper_context & ctx, + struct whisper_state & state, int i_segment, float thold_pt, float thold_ptsum); @@ -3023,8 +3215,8 @@ static inline bool should_split_on_word(const char * txt, bool split_on_word) { // wrap the last segment to max_len characters // returns the number of new segments -static int whisper_wrap_segment(struct whisper_context & ctx, int max_len, bool split_on_word) { - auto segment = ctx.result_all.back(); +static int whisper_wrap_segment(struct whisper_context & ctx, struct whisper_state & state, int max_len, bool split_on_word) { + auto segment = state.result_all.back(); int res = 1; int acc = 0; @@ -3046,24 +3238,24 @@ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len, bool trim(text); } - ctx.result_all.back().text = std::move(text); - ctx.result_all.back().t1 = token.t0; - ctx.result_all.back().tokens.resize(i); + state.result_all.back().text = std::move(text); + state.result_all.back().t1 = token.t0; + state.result_all.back().tokens.resize(i); - ctx.result_all.push_back({}); - ctx.result_all.back().t0 = token.t0; - ctx.result_all.back().t1 = segment.t1; + state.result_all.push_back({}); + state.result_all.back().t0 = token.t0; + state.result_all.back().t1 = segment.t1; // add tokens [i, end] to the new segment - ctx.result_all.back().tokens.insert( - ctx.result_all.back().tokens.end(), + state.result_all.back().tokens.insert( + state.result_all.back().tokens.end(), segment.tokens.begin() + i, segment.tokens.end()); acc = 0; text = ""; - segment = ctx.result_all.back(); + segment = state.result_all.back(); i = -1; res++; @@ -3076,7 +3268,7 @@ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len, bool if (split_on_word) { trim(text); } - ctx.result_all.back().text = std::move(text); + state.result_all.back().text = std::move(text); return res; } @@ -3093,6 +3285,7 @@ static const std::vector non_speech_tokens = { // - computes logprobs and probs static void whisper_process_logits( struct whisper_context & ctx, + struct whisper_state & state, const struct whisper_full_params params, struct whisper_decoder & decoder, float temperature) { @@ -3111,7 +3304,7 @@ static void whisper_process_logits( auto & logprobs = decoder.logprobs; { logits.resize(n_logits); - memcpy(logits.data(), ctx.logits.data() + (ctx.logits.size() - n_logits), n_logits*sizeof(float)); + memcpy(logits.data(), state.logits.data() + (state.logits.size() - n_logits), n_logits*sizeof(float)); if (temperature > 0.0f) { for (int i = 0; i < n_logits; i++) { @@ -3149,7 +3342,7 @@ static void whisper_process_logits( logits[vocab.token_transcribe] = -INFINITY; if (params.logits_filter_callback) { - params.logits_filter_callback(&ctx, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data); + params.logits_filter_callback(&ctx, &state, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data); } // suppress non-speech tokens @@ -3310,6 +3503,7 @@ static void whisper_process_logits( static whisper_token_data whisper_sample_token( whisper_context & ctx, + whisper_state & state, const whisper_decoder & decoder, bool best) { whisper_token_data result = { @@ -3354,7 +3548,7 @@ static whisper_token_data whisper_sample_token( } else { std::discrete_distribution<> dist(probs.begin(), probs.end()); - result.id = dist(ctx.rng); + result.id = dist(state.rng); result.p = probs[result.id]; result.plog = logprobs[result.id]; } @@ -3364,13 +3558,14 @@ static whisper_token_data whisper_sample_token( result.pt = result.p; } - ctx.n_sample++; + state.n_sample++; return result; } static std::vector whisper_sample_token_topk( whisper_context & ctx, + whisper_state & state, const whisper_decoder & decoder, int k) { const auto & vocab = ctx.vocab; @@ -3381,7 +3576,7 @@ static std::vector whisper_sample_token_topk( const int n_logits = vocab.n_vocab; - auto & logits_id = ctx.logits_id; + auto & logits_id = state.logits_id; logits_id.clear(); for (int i = 0; i < n_logits; ++i) { @@ -3434,7 +3629,7 @@ static std::vector whisper_sample_token_topk( } } - ctx.n_sample++; + state.n_sample++; return result; } @@ -3488,24 +3683,25 @@ static void whisper_sequence_score( } } -int whisper_full( +int whisper_full_with_state( struct whisper_context * ctx, - struct whisper_full_params params, - const float * samples, - int n_samples) { + struct whisper_state * state, + struct whisper_full_params params, + const float * samples, + int n_samples) { // clear old results - auto & result_all = ctx->result_all; + auto & result_all = state->result_all; result_all.clear(); // compute log mel spectrogram if (params.speed_up) { - if (whisper_pcm_to_mel_phase_vocoder(ctx, samples, n_samples, params.n_threads) != 0) { + if (whisper_pcm_to_mel_phase_vocoder_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) { fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__); return -1; } } else { - if (whisper_pcm_to_mel(ctx, samples, n_samples, params.n_threads) != 0) { + if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) { fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__); return -2; } @@ -3515,26 +3711,26 @@ int whisper_full( if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0) { std::vector probs(whisper_lang_max_id() + 1, 0.0f); - const auto lang_id = whisper_lang_auto_detect(ctx, 0, params.n_threads, probs.data()); + const auto lang_id = whisper_lang_auto_detect_with_state(ctx, state, 0, params.n_threads, probs.data()); if (lang_id < 0) { fprintf(stderr, "%s: failed to auto-detect language\n", __func__); return -3; } - ctx->lang_id = lang_id; + state->lang_id = lang_id; params.language = whisper_lang_str(lang_id); fprintf(stderr, "%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]); } if (params.token_timestamps) { - ctx->t_beg = 0; - ctx->t_last = 0; - ctx->tid_last = 0; - ctx->energy = get_signal_energy(samples, n_samples, 32); + state->t_beg = 0; + state->t_last = 0; + state->tid_last = 0; + state->energy = get_signal_energy(samples, n_samples, 32); } const int seek_start = params.offset_ms/10; - const int seek_end = seek_start + (params.duration_ms == 0 ? whisper_n_len(ctx) : params.duration_ms/10); + const int seek_end = seek_start + (params.duration_ms == 0 ? whisper_n_len_from_state(state) : params.duration_ms/10); // if length of spectrogram is less than 1s (100 samples), then return // basically don't process anything that is less than 1s @@ -3572,10 +3768,10 @@ int whisper_full( // TAGS: WHISPER_DECODER_INIT for (int j = 1; j < n_decoders; j++) { - auto & decoder = ctx->decoders[j]; + auto & decoder = state->decoders[j]; if (decoder.kv_self.ctx == nullptr) { - decoder.kv_self = ctx->decoders[0].kv_self; + decoder.kv_self = state->decoders[0].kv_self; if (!kv_cache_reinit(decoder.kv_self)) { fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, j); return -4; @@ -3583,7 +3779,7 @@ int whisper_full( WHISPER_PRINT_DEBUG("%s: initialized self-attention kv cache, decoder %d\n", __func__, j); - decoder.sequence.tokens.reserve(ctx->decoders[0].sequence.tokens.capacity()); + decoder.sequence.tokens.reserve(state->decoders[0].sequence.tokens.capacity()); decoder.probs.resize (ctx->vocab.n_vocab); decoder.logits.resize (ctx->vocab.n_vocab); @@ -3592,7 +3788,7 @@ int whisper_full( } // the accumulated text context so far - auto & prompt_past = ctx->prompt_past; + auto & prompt_past = state->prompt_past; if (params.no_context) { prompt_past.clear(); } @@ -3611,13 +3807,13 @@ int whisper_full( fprintf(stderr, "%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx)); return -5; } - ctx->exp_n_audio_ctx = params.audio_ctx; + state->exp_n_audio_ctx = params.audio_ctx; // these tokens determine the task that will be performed std::vector prompt_init = { whisper_token_sot(ctx) }; if (whisper_is_multilingual(ctx)) { const int lang_id = whisper_lang_id(params.language); - ctx->lang_id = lang_id; + state->lang_id = lang_id; prompt_init.push_back(whisper_token_lang(ctx, lang_id)); if (params.translate) { prompt_init.push_back(whisper_token_translate()); @@ -3669,14 +3865,14 @@ int whisper_full( } if (params.encoder_begin_callback) { - if (params.encoder_begin_callback(ctx, params.encoder_begin_callback_user_data) == false) { + if (params.encoder_begin_callback(ctx, state, params.encoder_begin_callback_user_data) == false) { fprintf(stderr, "%s: encoder_begin_callback returned false - aborting\n", __func__); break; } } // encode audio features starting at offset seek - if (!whisper_encode(*ctx, seek, params.n_threads)) { + if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads)) { fprintf(stderr, "%s: failed to encode\n", __func__); return -6; } @@ -3717,7 +3913,7 @@ int whisper_full( // TAGS: WHISPER_DECODER_INIT for (int j = 0; j < n_decoders_cur; ++j) { - auto & decoder = ctx->decoders[j]; + auto & decoder = state->decoders[j]; decoder.kv_self.n = 0; @@ -3759,7 +3955,7 @@ int whisper_full( } WHISPER_PRINT_DEBUG("\n\n"); - if (!whisper_decode(*ctx, ctx->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads)) { + if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads)) { fprintf(stderr, "%s: failed to decode\n", __func__); return -7; } @@ -3767,24 +3963,24 @@ int whisper_full( { const int64_t t_start_sample_us = ggml_time_us(); - whisper_process_logits(*ctx, params, ctx->decoders[0], t_cur); + whisper_process_logits(*ctx, *state, params, state->decoders[0], t_cur); - ctx->decoders[0].kv_self.n += prompt.size(); + state->decoders[0].kv_self.n += prompt.size(); for (int j = 1; j < n_decoders_cur; ++j) { - auto & decoder = ctx->decoders[j]; + auto & decoder = state->decoders[j]; - memcpy(decoder.kv_self.k->data, ctx->decoders[0].kv_self.k->data, ggml_nbytes(decoder.kv_self.k)); - memcpy(decoder.kv_self.v->data, ctx->decoders[0].kv_self.v->data, ggml_nbytes(decoder.kv_self.v)); + memcpy(decoder.kv_self.k->data, state->decoders[0].kv_self.k->data, ggml_nbytes(decoder.kv_self.k)); + memcpy(decoder.kv_self.v->data, state->decoders[0].kv_self.v->data, ggml_nbytes(decoder.kv_self.v)); decoder.kv_self.n += prompt.size(); - memcpy(decoder.probs.data(), ctx->decoders[0].probs.data(), decoder.probs.size()*sizeof(decoder.probs[0])); - memcpy(decoder.logits.data(), ctx->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0])); - memcpy(decoder.logprobs.data(), ctx->decoders[0].logprobs.data(), decoder.logprobs.size()*sizeof(decoder.logprobs[0])); + memcpy(decoder.probs.data(), state->decoders[0].probs.data(), decoder.probs.size()*sizeof(decoder.probs[0])); + memcpy(decoder.logits.data(), state->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0])); + memcpy(decoder.logprobs.data(), state->decoders[0].logprobs.data(), decoder.logprobs.size()*sizeof(decoder.logprobs[0])); } - ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + state->t_sample_us += ggml_time_us() - t_start_sample_us; } } @@ -3795,7 +3991,7 @@ int whisper_full( if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) { kv_bufs.resize(n_decoders_cur); for (int j = 0; j < n_decoders_cur; ++j) { - auto & decoder = ctx->decoders[j]; + auto & decoder = state->decoders[j]; if (decoder.completed || decoder.failed) { continue; @@ -3813,7 +4009,7 @@ int whisper_full( // generate new sequence candidates for each decoder for (int j = 0; j < n_decoders_cur; ++j) { - auto & decoder = ctx->decoders[j]; + auto & decoder = state->decoders[j]; if (decoder.completed || decoder.failed) { continue; @@ -3823,16 +4019,16 @@ int whisper_full( case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY: { if (t_cur < 1e-6f) { - decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, true)); + decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, *state, decoder, true)); } else { - decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, false)); + decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, *state, decoder, false)); } decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog; } break; case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH: { - const auto tokens_new = whisper_sample_token_topk(*ctx, decoder, params.beam_search.beam_size); + const auto tokens_new = whisper_sample_token_topk(*ctx, *state, decoder, params.beam_search.beam_size); for (const auto & token : tokens_new) { beam_candidates.push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence }); @@ -3857,7 +4053,7 @@ int whisper_full( uint32_t cur_c = 0; for (int j = 0; j < n_decoders_cur; ++j) { - auto & decoder = ctx->decoders[j]; + auto & decoder = state->decoders[j]; if (decoder.completed || decoder.failed) { continue; @@ -3886,7 +4082,7 @@ int whisper_full( // - check if the sequence is failed // - update sliding window based on timestamp tokens for (int j = 0; j < n_decoders_cur; ++j) { - auto & decoder = ctx->decoders[j]; + auto & decoder = state->decoders[j]; if (decoder.completed || decoder.failed) { continue; @@ -3968,7 +4164,7 @@ int whisper_full( bool completed_all = true; for (int j = 0; j < n_decoders_cur; ++j) { - auto & decoder = ctx->decoders[j]; + auto & decoder = state->decoders[j]; if (decoder.completed || decoder.failed) { continue; @@ -3982,11 +4178,11 @@ int whisper_full( } } - ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + state->t_sample_us += ggml_time_us() - t_start_sample_us; // obtain logits for the next token for (int j = 0; j < n_decoders_cur; ++j) { - auto & decoder = ctx->decoders[j]; + auto & decoder = state->decoders[j]; if (decoder.failed || decoder.completed) { continue; @@ -3997,7 +4193,7 @@ int whisper_full( //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta); - if (!whisper_decode(*ctx, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads)) { + if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads)) { fprintf(stderr, "%s: failed to decode\n", __func__); return -8; } @@ -4005,11 +4201,11 @@ int whisper_full( { const int64_t t_start_sample_us = ggml_time_us(); - whisper_process_logits(*ctx, params, decoder, t_cur); + whisper_process_logits(*ctx, *state, params, decoder, t_cur); ++decoder.kv_self.n; - ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + state->t_sample_us += ggml_time_us() - t_start_sample_us; } } } @@ -4019,7 +4215,7 @@ int whisper_full( double best_score = -INFINITY; for (int j = 0; j < n_decoders_cur; ++j) { - auto & decoder = ctx->decoders[j]; + auto & decoder = state->decoders[j]; if (decoder.failed) { continue; @@ -4036,7 +4232,7 @@ int whisper_full( __func__, j, decoder.sequence.entropy, params.entropy_thold); decoder.failed = true; - ctx->n_fail_h++; + state->n_fail_h++; continue; } @@ -4054,11 +4250,11 @@ int whisper_full( { bool success = true; - const auto & decoder = ctx->decoders[best_decoder_id]; + const auto & decoder = state->decoders[best_decoder_id]; if (decoder.failed || decoder.sequence.avg_logprobs < params.logprob_thold) { success = false; - ctx->n_fail_p++; + state->n_fail_p++; } if (success) { @@ -4075,7 +4271,7 @@ int whisper_full( // output results through a user-provided callback { - const auto & best_decoder = ctx->decoders[best_decoder_id]; + const auto & best_decoder = state->decoders[best_decoder_id]; const auto seek_delta = best_decoder.seek_delta; const auto result_len = best_decoder.sequence.result_len; @@ -4138,14 +4334,14 @@ int whisper_full( if (params.token_timestamps) { whisper_exp_compute_token_level_timestamps( - *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum); + *ctx, *state, result_all.size() - 1, params.thold_pt, params.thold_ptsum); if (params.max_len > 0) { - n_new = whisper_wrap_segment(*ctx, params.max_len, params.split_on_word); + n_new = whisper_wrap_segment(*ctx, *state, params.max_len, params.split_on_word); } } if (params.new_segment_callback) { - params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data); + params.new_segment_callback(ctx, state, n_new, params.new_segment_callback_user_data); } } text = ""; @@ -4182,14 +4378,14 @@ int whisper_full( if (params.token_timestamps) { whisper_exp_compute_token_level_timestamps( - *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum); + *ctx, *state, result_all.size() - 1, params.thold_pt, params.thold_ptsum); if (params.max_len > 0) { - n_new = whisper_wrap_segment(*ctx, params.max_len, params.split_on_word); + n_new = whisper_wrap_segment(*ctx, *state, params.max_len, params.split_on_word); } } if (params.new_segment_callback) { - params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data); + params.new_segment_callback(ctx, state, n_new, params.new_segment_callback_user_data); } } } @@ -4204,6 +4400,15 @@ int whisper_full( return 0; } + +int whisper_full( + struct whisper_context * ctx, + struct whisper_full_params params, + const float * samples, + int n_samples) { + return whisper_full_with_state(ctx, ctx->state, params, samples, n_samples); +} + int whisper_full_parallel( struct whisper_context * ctx, struct whisper_full_params params, @@ -4213,40 +4418,10 @@ int whisper_full_parallel( if (n_processors == 1) { return whisper_full(ctx, params, samples, n_samples); } - int ret = 0; - // prepare separate contexts for each thread - std::vector ctxs(n_processors - 1); - - for (int i = 0; i < n_processors - 1; ++i) { - auto & ctx_p = ctxs[i]; - - ctx_p = *ctx; - - ctx_p.logits.reserve(ctx_p.vocab.n_vocab*ctx_p.model.hparams.n_text_ctx); - - ctx_p.logits_id.reserve(ctx_p.vocab.n_vocab); - - if (!kv_cache_reinit(ctx_p.kv_cross)) { - fprintf(stderr, "%s: kv_cache_reinit() failed for cross-attention, processor %d\n", __func__, i); - return false; - } - - // TAGS: WHISPER_DECODER_INIT - for (int j = 0; j < WHISPER_MAX_DECODERS; ++j) { - if (ctx_p.decoders[j].kv_self.ctx && !kv_cache_reinit(ctx_p.decoders[j].kv_self)) { - fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention, decoder %d, processor %d\n", __func__, j, i); - return false; - } - - ctx_p.decoders[j].sequence.tokens.reserve(ctx_p.model.hparams.n_text_ctx); - - ctx_p.decoders[j].probs.reserve (ctx_p.vocab.n_vocab); - ctx_p.decoders[j].logits.reserve (ctx_p.vocab.n_vocab); - ctx_p.decoders[j].logprobs.reserve(ctx_p.vocab.n_vocab); - } - } + // prepare separate states for each thread + std::vector states; const int offset_samples = (WHISPER_SAMPLE_RATE*params.offset_ms)/1000; const int n_samples_per_processor = (n_samples - offset_samples)/n_processors; @@ -4256,6 +4431,9 @@ int whisper_full_parallel( std::vector workers(n_processors - 1); for (int i = 0; i < n_processors - 1; ++i) { + // create a new state for each thread + states.push_back(whisper_init_state(ctx)); + const int start_samples = offset_samples + (i + 1)*n_samples_per_processor; const int n_samples_cur = (i == n_processors - 2) ? n_samples - start_samples : n_samples_per_processor; @@ -4268,13 +4446,17 @@ int whisper_full_parallel( params_cur.new_segment_callback = nullptr; params_cur.new_segment_callback_user_data = nullptr; - workers[i] = std::thread(whisper_full, &ctxs[i], std::move(params_cur), samples + start_samples, n_samples_cur); + workers[i] = std::thread(whisper_full_with_state, ctx, states[i], std::move(params_cur), samples + start_samples, n_samples_cur); } { auto params_cur = params; - ret = whisper_full(ctx, std::move(params_cur), samples, offset_samples + n_samples_per_processor); + // We need to disable the print real-time for this one as well, otherwise it will show only for the first chunk. + params_cur.print_realtime = false; + + // Run the first transformation using default state but only for the first chunk. + ret = whisper_full_with_state(ctx, ctx->state, std::move(params_cur), samples, offset_samples + n_samples_per_processor); } for (int i = 0; i < n_processors - 1; ++i) { @@ -4283,45 +4465,43 @@ int whisper_full_parallel( const int64_t offset_t = (int64_t) params.offset_ms/10.0; - // combine results into ctx->result_all + // combine results into result_state->result_all from all other states for (int i = 0; i < n_processors - 1; ++i) { - auto & results_i = ctxs[i].result_all; + auto& results_i = states[i]->result_all; - for (auto & result : results_i) { + for (auto& result : results_i) { // correct the segment timestamp taking into account the offset - result.t0 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t; - result.t1 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t; + result.t0 += 100 * ((i + 1) * n_samples_per_processor) / WHISPER_SAMPLE_RATE + offset_t; + result.t1 += 100 * ((i + 1) * n_samples_per_processor) / WHISPER_SAMPLE_RATE + offset_t; + // make sure that segments are not overlapping - if (!ctx->result_all.empty()) { - result.t0 = std::max(result.t0, ctx->result_all.back().t1); + if (!ctx->state->result_all.empty()) { + result.t0 = std::max(result.t0, ctx->state->result_all.back().t1); } - ctx->result_all.push_back(std::move(result)); + ctx->state->result_all.push_back(std::move(result)); // call the new_segment_callback for each segment if (params.new_segment_callback) { - params.new_segment_callback(ctx, 1, params.new_segment_callback_user_data); + params.new_segment_callback(ctx, ctx->state, 1, params.new_segment_callback_user_data); } } - ctx->t_mel_us += ctxs[i].t_mel_us; - ctx->t_sample_us += ctxs[i].t_sample_us; - ctx->t_encode_us += ctxs[i].t_encode_us; - ctx->t_decode_us += ctxs[i].t_decode_us; + ctx->state->t_mel_us += states[i]->t_mel_us; - kv_cache_free(ctx->kv_cross); + ctx->state->t_sample_us += states[i]->t_sample_us; + ctx->state->t_encode_us += states[i]->t_encode_us; + ctx->state->t_decode_us += states[i]->t_decode_us; - for (int j = 0; j < WHISPER_MAX_DECODERS; ++j) { - kv_cache_free(ctx->decoders[j].kv_self); - } + whisper_free_state(states[i]); } // average the timings - ctx->t_mel_us /= n_processors; - ctx->t_sample_us /= n_processors; - ctx->t_encode_us /= n_processors; - ctx->t_decode_us /= n_processors; + ctx->state->t_mel_us /= n_processors; + ctx->state->t_sample_us /= n_processors; + ctx->state->t_encode_us /= n_processors; + ctx->state->t_decode_us /= n_processors; // print information about the audio boundaries fprintf(stderr, "\n"); @@ -4334,44 +4514,84 @@ int whisper_full_parallel( return ret; } +int whisper_full_n_segments_from_state(struct whisper_state * state) { + return state->result_all.size(); +} + int whisper_full_n_segments(struct whisper_context * ctx) { - return ctx->result_all.size(); + return ctx->state->result_all.size(); +} + +int whisper_full_lang_id_from_state(struct whisper_state * state) { + return state->lang_id; } int whisper_full_lang_id(struct whisper_context * ctx) { - return ctx->lang_id; + return ctx->state->lang_id; +} + +int64_t whisper_full_get_segment_t0_from_state(struct whisper_state * state, int i_segment) { + return state->result_all[i_segment].t0; } int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) { - return ctx->result_all[i_segment].t0; + return ctx->state->result_all[i_segment].t0; +} + +int64_t whisper_full_get_segment_t1_from_state(struct whisper_state * state, int i_segment) { + return state->result_all[i_segment].t1; } int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment) { - return ctx->result_all[i_segment].t1; + return ctx->state->result_all[i_segment].t1; +} + +const char * whisper_full_get_segment_text_from_state(struct whisper_state * state, int i_segment) { + return state->result_all[i_segment].text.c_str(); } const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment) { - return ctx->result_all[i_segment].text.c_str(); + return ctx->state->result_all[i_segment].text.c_str(); +} + +int whisper_full_n_tokens_from_state(struct whisper_state * state, int i_segment) { + return state->result_all[i_segment].tokens.size(); } int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment) { - return ctx->result_all[i_segment].tokens.size(); + return ctx->state->result_all[i_segment].tokens.size(); +} + +const char * whisper_full_get_token_text_from_state(struct whisper_context * ctx, struct whisper_state * state, int i_segment, int i_token) { + return ctx->vocab.id_to_token[state->result_all[i_segment].tokens[i_token].id].c_str(); +} + +const char* whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token) { + return ctx->vocab.id_to_token[ctx->state->result_all[i_segment].tokens[i_token].id].c_str(); } -const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token) { - return ctx->vocab.id_to_token[ctx->result_all[i_segment].tokens[i_token].id].c_str(); +whisper_token whisper_full_get_token_id_from_state(struct whisper_state * state, int i_segment, int i_token) { + return state->result_all[i_segment].tokens[i_token].id; } whisper_token whisper_full_get_token_id(struct whisper_context * ctx, int i_segment, int i_token) { - return ctx->result_all[i_segment].tokens[i_token].id; + return ctx->state->result_all[i_segment].tokens[i_token].id; +} + +struct whisper_token_data whisper_full_get_token_data_from_state(struct whisper_state * state, int i_segment, int i_token) { + return state->result_all[i_segment].tokens[i_token]; } struct whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token) { - return ctx->result_all[i_segment].tokens[i_token]; + return ctx->state->result_all[i_segment].tokens[i_token]; +} + +float whisper_full_get_token_p_from_state(struct whisper_state * state, int i_segment, int i_token) { + return state->result_all[i_segment].tokens[i_token].p; } float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token) { - return ctx->result_all[i_segment].tokens[i_token].p; + return ctx->state->result_all[i_segment].tokens[i_token].p; } // ================================================================================================= @@ -4382,6 +4602,15 @@ float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int // WHISPER_API int whisper_bench_memcpy(int n_threads) { + fputs(whisper_bench_memcpy_str(n_threads), stderr); + return 0; +} + +WHISPER_API const char * whisper_bench_memcpy_str(int n_threads) { + static std::string s; + s = ""; + char strbuf[256]; + ggml_time_init(); size_t n = 50; @@ -4411,7 +4640,8 @@ WHISPER_API int whisper_bench_memcpy(int n_threads) { src[0] = rand(); } - fprintf(stderr, "memcpy: %.2f GB/s\n", (double) (n*size)/(tsum*1024llu*1024llu*1024llu)); + snprintf(strbuf, sizeof(strbuf), "memcpy: %.2f GB/s\n", (double) (n*size)/(tsum*1024llu*1024llu*1024llu)); + s += strbuf; // needed to prevent the compile from optimizing the memcpy away { @@ -4419,16 +4649,26 @@ WHISPER_API int whisper_bench_memcpy(int n_threads) { for (size_t i = 0; i < size; i++) sum += dst[i]; - fprintf(stderr, "sum: %s %f\n", sum == -536870910.00 ? "ok" : "error", sum); + snprintf(strbuf, sizeof(strbuf), "sum: %s %f\n", sum == -536870910.00 ? "ok" : "error", sum); + s += strbuf; } free(src); free(dst); - return 0; + return s.c_str(); } WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads) { + fputs(whisper_bench_ggml_mul_mat_str(n_threads), stderr); + return 0; +} + +WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) { + static std::string s; + s = ""; + char strbuf[256]; + ggml_time_init(); const int n_max = 128; @@ -4504,11 +4744,12 @@ WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads) { s = ((2.0*N*N*N*n)/tsum)*1e-9; } - fprintf(stderr, "ggml_mul_mat: %5zu x %5zu: F16 %8.1f GFLOPS (%3d runs) / F32 %8.1f GFLOPS (%3d runs)\n", + snprintf(strbuf, sizeof(strbuf), "ggml_mul_mat: %5zu x %5zu: F16 %8.1f GFLOPS (%3d runs) / F32 %8.1f GFLOPS (%3d runs)\n", N, N, s_fp16, n_fp16, s_fp32, n_fp32); + s += strbuf; } - return 0; + return s.c_str(); } // ================================================================================================= @@ -4583,13 +4824,14 @@ static std::vector get_signal_energy(const float * signal, int n_samples, static void whisper_exp_compute_token_level_timestamps( struct whisper_context & ctx, + struct whisper_state & state, int i_segment, float thold_pt, float thold_ptsum) { - auto & segment = ctx.result_all[i_segment]; + auto & segment = state.result_all[i_segment]; auto & tokens = segment.tokens; - const int n_samples = ctx.energy.size(); + const int n_samples = state.energy.size(); if (n_samples == 0) { fprintf(stderr, "%s: no signal data available\n", __func__); @@ -4612,9 +4854,9 @@ static void whisper_exp_compute_token_level_timestamps( return; } - auto & t_beg = ctx.t_beg; - auto & t_last = ctx.t_last; - auto & tid_last = ctx.tid_last; + auto & t_beg = state.t_beg; + auto & t_last = state.t_last; + auto & tid_last = state.tid_last; for (int j = 0; j < n; ++j) { auto & token = tokens[j]; @@ -4737,15 +4979,15 @@ static void whisper_exp_compute_token_level_timestamps( float sum = 0.0f; for (int k = ss0; k < ss1; k++) { - sum += ctx.energy[k]; + sum += state.energy[k]; } const float thold = 0.5*sum/ns; { int k = s0; - if (ctx.energy[k] > thold && j > 0) { - while (k > 0 && ctx.energy[k] > thold) { + if (state.energy[k] > thold && j > 0) { + while (k > 0 && state.energy[k] > thold) { k--; } tokens[j].t0 = sample_to_timestamp(k); @@ -4755,7 +4997,7 @@ static void whisper_exp_compute_token_level_timestamps( s0 = k; } } else { - while (ctx.energy[k] < thold && k < s1) { + while (state.energy[k] < thold && k < s1) { k++; } s0 = k; @@ -4765,8 +5007,8 @@ static void whisper_exp_compute_token_level_timestamps( { int k = s1; - if (ctx.energy[k] > thold) { - while (k < n_samples - 1 && ctx.energy[k] > thold) { + if (state.energy[k] > thold) { + while (k < n_samples - 1 && state.energy[k] > thold) { k++; } tokens[j].t1 = sample_to_timestamp(k); @@ -4776,7 +5018,7 @@ static void whisper_exp_compute_token_level_timestamps( s1 = k; } } else { - while (ctx.energy[k] < thold && k > s0) { + while (state.energy[k] < thold && k > s0) { k--; } s1 = k; diff --git a/examples/whisper/whisper.h b/examples/whisper/whisper.h index 3eb8d0842..fc107108a 100644 --- a/examples/whisper/whisper.h +++ b/examples/whisper/whisper.h @@ -66,6 +66,7 @@ extern "C" { // struct whisper_context; + struct whisper_state; typedef int whisper_token; @@ -101,11 +102,20 @@ extern "C" { WHISPER_API struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size); WHISPER_API struct whisper_context * whisper_init(struct whisper_model_loader * loader); - // Frees all memory allocated by the model. - WHISPER_API void whisper_free(struct whisper_context * ctx); + // These are the same as the above, but the internal state of the context is not allocated automatically + // It is the responsibility of the caller to allocate the state using whisper_init_state() (#523) + WHISPER_API struct whisper_context * whisper_init_from_file_no_state(const char * path_model); + WHISPER_API struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t buffer_size); + WHISPER_API struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loader); + + WHISPER_API struct whisper_state * whisper_init_state(struct whisper_context * ctx); + + // Frees all allocated memory + WHISPER_API void whisper_free (struct whisper_context * ctx); + WHISPER_API void whisper_free_state(struct whisper_state * state); // Convert RAW PCM audio to log mel spectrogram. - // The resulting spectrogram is stored inside the provided whisper context. + // The resulting spectrogram is stored inside the default state of the provided whisper context. // Returns 0 on success WHISPER_API int whisper_pcm_to_mel( struct whisper_context * ctx, @@ -113,17 +123,30 @@ extern "C" { int n_samples, int n_threads); - // Convert RAW PCM audio to log mel spectrogram but applies a Phase Vocoder to speed up the audio x2. - // The resulting spectrogram is stored inside the provided whisper context. + WHISPER_API int whisper_pcm_to_mel_with_state( + struct whisper_context * ctx, + struct whisper_state * state, + const float * samples, + int n_samples, + int n_threads); + + // Convert RAW PCM audio to log mel spectrogram but applies a Phase Vocoder to speed up the audio x2. + // The resulting spectrogram is stored inside the default state of the provided whisper context. // Returns 0 on success WHISPER_API int whisper_pcm_to_mel_phase_vocoder( - struct whisper_context* ctx, - const float* samples, - int n_samples, - int n_threads); - - - // This can be used to set a custom log mel spectrogram inside the provided whisper context. + struct whisper_context * ctx, + const float * samples, + int n_samples, + int n_threads); + + WHISPER_API int whisper_pcm_to_mel_phase_vocoder_with_state( + struct whisper_context * ctx, + struct whisper_state * state, + const float * samples, + int n_samples, + int n_threads); + + // This can be used to set a custom log mel spectrogram inside the default state of the provided whisper context. // Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram. // n_mel must be 80 // Returns 0 on success @@ -133,7 +156,14 @@ extern "C" { int n_len, int n_mel); - // Run the Whisper encoder on the log mel spectrogram stored inside the provided whisper context. + WHISPER_API int whisper_set_mel_with_state( + struct whisper_context * ctx, + struct whisper_state * state, + const float * data, + int n_len, + int n_mel); + + // Run the Whisper encoder on the log mel spectrogram stored inside the default state in the provided whisper context. // Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first. // offset can be used to specify the offset of the first frame in the spectrogram. // Returns 0 on success @@ -142,6 +172,12 @@ extern "C" { int offset, int n_threads); + WHISPER_API int whisper_encode_with_state( + struct whisper_context * ctx, + struct whisper_state * state, + int offset, + int n_threads); + // Run the Whisper decoder to obtain the logits and probabilities for the next token. // Make sure to call whisper_encode() first. // tokens + n_tokens is the provided context for the decoder. @@ -155,6 +191,14 @@ extern "C" { int n_past, int n_threads); + WHISPER_API int whisper_decode_with_state( + struct whisper_context * ctx, + struct whisper_state * state, + const whisper_token * tokens, + int n_tokens, + int n_past, + int n_threads); + // Convert the provided text into tokens. // The tokens pointer must be large enough to hold the resulting tokens. // Returns the number of tokens on success, no more than n_max_tokens @@ -190,20 +234,44 @@ extern "C" { int n_threads, float * lang_probs); - WHISPER_API int whisper_n_len (struct whisper_context * ctx); // mel length - WHISPER_API int whisper_n_vocab (struct whisper_context * ctx); - WHISPER_API int whisper_n_text_ctx (struct whisper_context * ctx); - WHISPER_API int whisper_n_audio_ctx (struct whisper_context * ctx); - WHISPER_API int whisper_is_multilingual(struct whisper_context * ctx); + WHISPER_API int whisper_lang_auto_detect_with_state( + struct whisper_context * ctx, + struct whisper_state * state, + int offset_ms, + int n_threads, + float * lang_probs); + + WHISPER_API int whisper_n_len (struct whisper_context * ctx); // mel length + WHISPER_API int whisper_n_len_from_state(struct whisper_state * state); // mel length + WHISPER_API int whisper_n_vocab (struct whisper_context * ctx); + WHISPER_API int whisper_n_text_ctx (struct whisper_context * ctx); + WHISPER_API int whisper_n_audio_ctx (struct whisper_context * ctx); + WHISPER_API int whisper_is_multilingual (struct whisper_context * ctx); + + WHISPER_API int whisper_model_n_vocab (struct whisper_context * ctx); + WHISPER_API int whisper_model_n_audio_ctx (struct whisper_context * ctx); + WHISPER_API int whisper_model_n_audio_state(struct whisper_context * ctx); + WHISPER_API int whisper_model_n_audio_head (struct whisper_context * ctx); + WHISPER_API int whisper_model_n_audio_layer(struct whisper_context * ctx); + WHISPER_API int whisper_model_n_text_ctx (struct whisper_context * ctx); + WHISPER_API int whisper_model_n_text_state (struct whisper_context * ctx); + WHISPER_API int whisper_model_n_text_head (struct whisper_context * ctx); + WHISPER_API int whisper_model_n_text_layer (struct whisper_context * ctx); + WHISPER_API int whisper_model_n_mels (struct whisper_context * ctx); + WHISPER_API int whisper_model_f16 (struct whisper_context * ctx); + WHISPER_API int whisper_model_type (struct whisper_context * ctx); // Token logits obtained from the last call to whisper_decode() // The logits for the last token are stored in the last row // Rows: n_tokens // Cols: n_vocab - WHISPER_API float * whisper_get_logits(struct whisper_context * ctx); + WHISPER_API float * whisper_get_logits (struct whisper_context * ctx); + WHISPER_API float * whisper_get_logits_from_state(struct whisper_state * state); // Token Id -> String. Uses the vocabulary in the provided context WHISPER_API const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token); + WHISPER_API const char * whisper_model_type_readable(struct whisper_context * ctx); + // Special tokens WHISPER_API whisper_token whisper_token_eot (struct whisper_context * ctx); @@ -218,7 +286,7 @@ extern "C" { WHISPER_API whisper_token whisper_token_translate (void); WHISPER_API whisper_token whisper_token_transcribe(void); - // Performance information + // Performance information from the default state. WHISPER_API void whisper_print_timings(struct whisper_context * ctx); WHISPER_API void whisper_reset_timings(struct whisper_context * ctx); @@ -236,18 +304,19 @@ extern "C" { // Text segment callback // Called on every newly generated text segment // Use the whisper_full_...() functions to obtain the text segments - typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, int n_new, void * user_data); + typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data); // Encoder begin callback // If not NULL, called before the encoder starts // If it returns false, the computation is aborted - typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, void * user_data); + typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, struct whisper_state * state, void * user_data); // Logits filter callback // Can be used to modify the logits before sampling // If not NULL, called after applying temperature to logits typedef void (*whisper_logits_filter_callback)( struct whisper_context * ctx, + struct whisper_state * state, const whisper_token_data * tokens, int n_tokens, float * logits, @@ -334,6 +403,7 @@ extern "C" { WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy); // Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text + // Not thread safe for same context // Uses the specified decoding strategy to obtain the text. WHISPER_API int whisper_full( struct whisper_context * ctx, @@ -341,7 +411,16 @@ extern "C" { const float * samples, int n_samples); - // Split the input audio in chunks and process each chunk separately using whisper_full() + WHISPER_API int whisper_full_with_state( + struct whisper_context * ctx, + struct whisper_state * state, + struct whisper_full_params params, + const float * samples, + int n_samples); + + // Split the input audio in chunks and process each chunk separately using whisper_full_with_state() + // Result is stored in the default state of the context + // Not thread safe if executed in parallel on the same context. // It seems this approach can offer some speedup in some cases. // However, the transcription accuracy can be worse at the beginning and end of each chunk. WHISPER_API int whisper_full_parallel( @@ -351,40 +430,56 @@ extern "C" { int n_samples, int n_processors); - // Number of generated text segments. + // Number of generated text segments // A segment can be a few words, a sentence, or even a paragraph. - WHISPER_API int whisper_full_n_segments(struct whisper_context * ctx); + WHISPER_API int whisper_full_n_segments (struct whisper_context * ctx); + WHISPER_API int whisper_full_n_segments_from_state(struct whisper_state * state); - // Language id associated with the current context + // Language id associated with the context's default state WHISPER_API int whisper_full_lang_id(struct whisper_context * ctx); - // Get the start and end time of the specified segment. - WHISPER_API int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment); - WHISPER_API int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment); + // Language id associated with the provided state + WHISPER_API int whisper_full_lang_id_from_state(struct whisper_state * state); + + // Get the start and end time of the specified segment + WHISPER_API int64_t whisper_full_get_segment_t0 (struct whisper_context * ctx, int i_segment); + WHISPER_API int64_t whisper_full_get_segment_t0_from_state(struct whisper_state * state, int i_segment); + + WHISPER_API int64_t whisper_full_get_segment_t1 (struct whisper_context * ctx, int i_segment); + WHISPER_API int64_t whisper_full_get_segment_t1_from_state(struct whisper_state * state, int i_segment); + + // Get the text of the specified segment + WHISPER_API const char * whisper_full_get_segment_text (struct whisper_context * ctx, int i_segment); + WHISPER_API const char * whisper_full_get_segment_text_from_state(struct whisper_state * state, int i_segment); - // Get the text of the specified segment. - WHISPER_API const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment); + // Get number of tokens in the specified segment + WHISPER_API int whisper_full_n_tokens (struct whisper_context * ctx, int i_segment); + WHISPER_API int whisper_full_n_tokens_from_state(struct whisper_state * state, int i_segment); - // Get number of tokens in the specified segment. - WHISPER_API int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment); + // Get the token text of the specified token in the specified segment + WHISPER_API const char * whisper_full_get_token_text (struct whisper_context * ctx, int i_segment, int i_token); + WHISPER_API const char * whisper_full_get_token_text_from_state(struct whisper_context * ctx, struct whisper_state * state, int i_segment, int i_token); - // Get the token text of the specified token in the specified segment. - WHISPER_API const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token); - WHISPER_API whisper_token whisper_full_get_token_id (struct whisper_context * ctx, int i_segment, int i_token); + WHISPER_API whisper_token whisper_full_get_token_id (struct whisper_context * ctx, int i_segment, int i_token); + WHISPER_API whisper_token whisper_full_get_token_id_from_state(struct whisper_state * state, int i_segment, int i_token); - // Get token data for the specified token in the specified segment. + // Get token data for the specified token in the specified segment // This contains probabilities, timestamps, etc. - WHISPER_API whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token); + WHISPER_API whisper_token_data whisper_full_get_token_data (struct whisper_context * ctx, int i_segment, int i_token); + WHISPER_API whisper_token_data whisper_full_get_token_data_from_state(struct whisper_state * state, int i_segment, int i_token); - // Get the probability of the specified token in the specified segment. - WHISPER_API float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token); + // Get the probability of the specified token in the specified segment + WHISPER_API float whisper_full_get_token_p (struct whisper_context * ctx, int i_segment, int i_token); + WHISPER_API float whisper_full_get_token_p_from_state(struct whisper_state * state, int i_segment, int i_token); //////////////////////////////////////////////////////////////////////////// // Temporary helpers needed for exposing ggml interface WHISPER_API int whisper_bench_memcpy(int n_threads); + WHISPER_API const char * whisper_bench_memcpy_str(int n_threads); WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads); + WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads); #ifdef __cplusplus } diff --git a/include/ggml/ggml.h b/include/ggml/ggml.h index 18f317bec..335230f9f 100644 --- a/include/ggml/ggml.h +++ b/include/ggml/ggml.h @@ -198,6 +198,8 @@ struct ggml_object; struct ggml_context; enum ggml_type { + GGML_TYPE_Q4_0, + GGML_TYPE_Q4_1, GGML_TYPE_I8, GGML_TYPE_I16, GGML_TYPE_I32, @@ -226,7 +228,9 @@ enum ggml_op { GGML_OP_STEP, GGML_OP_RELU, GGML_OP_GELU, + GGML_OP_SILU, GGML_OP_NORM, // normalize + GGML_OP_RMS_NORM, GGML_OP_MUL_MAT, @@ -326,7 +330,10 @@ void ggml_print_objects(const struct ggml_context * ctx); int ggml_nelements(const struct ggml_tensor * tensor); size_t ggml_nbytes (const struct ggml_tensor * tensor); -size_t ggml_type_size (enum ggml_type type); +int ggml_blck_size (enum ggml_type type); +size_t ggml_type_size (enum ggml_type type); // size in bytes for all elements in a block +float ggml_type_sizef(enum ggml_type type); // ggml_type_size()/ggml_blck_size() as float + size_t ggml_element_size(const struct ggml_tensor * tensor); struct ggml_context * ggml_init(struct ggml_init_params params); @@ -336,6 +343,9 @@ size_t ggml_used_mem(const struct ggml_context * ctx); size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch); +bool ggml_mlock_supported(void); +bool ggml_mlock(struct ggml_context * ctx, char ** err_p); + struct ggml_tensor * ggml_new_tensor( struct ggml_context * ctx, enum ggml_type type, @@ -466,12 +476,20 @@ struct ggml_tensor * ggml_gelu( struct ggml_context * ctx, struct ggml_tensor * a); +struct ggml_tensor * ggml_silu( + struct ggml_context * ctx, + struct ggml_tensor * a); + // normalize along rows // TODO: eps is hardcoded to 1e-5 for now struct ggml_tensor * ggml_norm( struct ggml_context * ctx, struct ggml_tensor * a); +struct ggml_tensor * ggml_rms_norm( + struct ggml_context * ctx, + struct ggml_tensor * a); + // A: m rows, n columns // B: p rows, n columns (i.e. we transpose it internally) // result is m columns, p rows @@ -726,6 +744,13 @@ enum ggml_opt_result ggml_opt( struct ggml_opt_params params, struct ggml_tensor * f); +// +// quantization +// + +size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist); +size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist); + // // system info // diff --git a/src/ggml.c b/src/ggml.c index df2235f3c..02675ee67 100644 --- a/src/ggml.c +++ b/src/ggml.c @@ -1,18 +1,23 @@ +// Defines CLOCK_MONOTONIC and asprintf on Linux +#define _GNU_SOURCE + #include "ggml.h" #if defined(_MSC_VER) || defined(__MINGW32__) #include // using malloc.h with MSC/MINGW -#elif !defined(__FreeBSD__) +#elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__) #include #endif #include +#include #include #include #include #include #include #include +#include // if C99 - static_assert is noop // ref: https://stackoverflow.com/a/53923785/4039976 @@ -27,12 +32,8 @@ #else // ref: https://github.com/ggerganov/whisper.cpp/issues/168 #include -#include #endif -// Need this to compile with Visual Studio 2017 -#define restrict __restrict - typedef volatile LONG atomic_int; typedef atomic_int atomic_bool; @@ -78,13 +79,38 @@ static int sched_yield (void) { typedef void* thread_ret_t; #endif +// __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512 +#if defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__)) +#ifndef __FMA__ +#define __FMA__ +#endif +#ifndef __F16C__ +#define __F16C__ +#endif +#ifndef __SSE3__ +#define __SSE3__ +#endif +#endif + #ifdef __HAIKU__ #define static_assert(cond, msg) _Static_assert(cond, msg) #endif +#define GGML_MLOCK_SUPPORT 0 + +#ifdef __has_include + #if __has_include() + #undef GGML_MLOCK_SUPPORT + #define GGML_MLOCK_SUPPORT 1 + #include + #endif +#endif + + /*#define GGML_PERF*/ #define GGML_DEBUG 0 #define GGML_GELU_FP16 +#define GGML_SILU_FP16 #define GGML_SOFT_MAX_UNROLL 4 #define GGML_VEC_DOT_UNROLL 2 @@ -137,10 +163,10 @@ typedef double ggml_float; // #include -#define GGML_COMPUTE_FP16_TO_FP32(x) (x) +#define GGML_COMPUTE_FP16_TO_FP32(x) ((float) (x)) #define GGML_COMPUTE_FP32_TO_FP16(x) (x) -#define GGML_FP16_TO_FP32(x) (x) +#define GGML_FP16_TO_FP32(x) ((float) (x)) #define GGML_FP32_TO_FP16(x) (x) #else @@ -159,8 +185,46 @@ typedef double ggml_float; #ifdef __F16C__ +#ifdef _MSC_VER +#define GGML_COMPUTE_FP16_TO_FP32(x) _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(x))) +#define GGML_COMPUTE_FP32_TO_FP16(x) _mm_extract_epi16(_mm_cvtps_ph(_mm_set_ss(x), 0), 0) +#else #define GGML_COMPUTE_FP16_TO_FP32(x) _cvtsh_ss(x) #define GGML_COMPUTE_FP32_TO_FP16(x) _cvtss_sh(x, 0) +#endif + +#elif defined(__POWER9_VECTOR__) + +#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x) +#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x) +/* the inline asm below is about 12% faster than the lookup method */ +#define GGML_FP16_TO_FP32(x) GGML_COMPUTE_FP16_TO_FP32(x) +#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x) + +static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) { + register float f; + register double d; + __asm__( + "mtfprd %0,%2\n" + "xscvhpdp %0,%0\n" + "frsp %1,%0\n" : + /* temp */ "=d"(d), + /* out */ "=f"(f): + /* in */ "r"(h)); + return f; +} + +static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) { + register double d; + register ggml_fp16_t r; + __asm__( /* xscvdphp can work on double or single precision */ + "xscvdphp %0,%2\n" + "mffprd %1,%0\n" : + /* temp */ "=d"(d), + /* out */ "=r"(r): + /* in */ "f"(f)); + return r; +} #else @@ -248,6 +312,9 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) { // precomputed gelu table for f16 (128 KB) static ggml_fp16_t table_gelu_f16[1 << 16]; +// precomputed silu table for f16 (128 KB) +static ggml_fp16_t table_silu_f16[1 << 16]; + // precomputed exp table for f16 (128 KB) static ggml_fp16_t table_exp_f16[1 << 16]; @@ -256,6 +323,7 @@ static float table_f32_f16[1 << 16]; // On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32, // so we define GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 elsewhere for NEON. +// This is also true for POWER9. #if !defined(GGML_FP16_TO_FP32) || !defined(GGML_FP32_TO_FP16) inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) { @@ -272,7 +340,7 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) { // note: do not use these inside ggml.c // these are meant to be used via the ggml.h API float ggml_fp16_to_fp32(ggml_fp16_t x) { - return GGML_FP16_TO_FP32(x); + return (float) GGML_FP16_TO_FP32(x); } ggml_fp16_t ggml_fp32_to_fp16(float x) { @@ -351,6 +419,673 @@ int64_t ggml_cycles_per_ms(void) { static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float); +// +// quantization +// + +#define QK 32 + +// AVX routines provided by GH user Const-me +// ref: https://github.com/ggerganov/ggml/pull/27#issuecomment-1464934600 +#if __AVX2__ || __AVX512F__ +// Unpack 32 4-bit fields into 32 bytes +// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval +static inline __m256i bytesFromNibbles( const uint8_t* rsi ) +{ + // Load 16 bytes from memory + __m128i tmp = _mm_loadu_si128( ( const __m128i* )rsi ); + + // Expand bytes into uint16_t values + __m256i bytes = _mm256_cvtepu8_epi16( tmp ); + + // Unpack values into individual bytes + const __m256i lowMask = _mm256_set1_epi8( 0xF ); + __m256i high = _mm256_andnot_si256( lowMask, bytes ); + __m256i low = _mm256_and_si256( lowMask, bytes ); + high = _mm256_slli_epi16( high, 4 ); + bytes = _mm256_or_si256( low, high ); + return bytes; +} + +static inline __m128i packNibbles( __m256i bytes ) +{ + // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh + const __m256i lowByte = _mm256_set1_epi16( 0xFF ); + __m256i high = _mm256_andnot_si256( lowByte, bytes ); + __m256i low = _mm256_and_si256( lowByte, bytes ); + high = _mm256_srli_epi16( high, 4 ); + bytes = _mm256_or_si256( low, high ); + + // Compress uint16_t lanes into bytes + __m128i r0 = _mm256_castsi256_si128( bytes ); + __m128i r1 = _mm256_extracti128_si256( bytes, 1 ); + return _mm_packus_epi16( r0, r1 ); +} +#endif + +// method 5 +// blocks of QK elements +// represented with a single float (delta) and QK/2 8-bit ints (i.e QK 4-bit signed integer factors) +typedef struct { + float d; // delta + uint8_t qs[QK / 2]; // nibbles / quants +} block_q4_0; +static_assert(sizeof(block_q4_0) == sizeof(float) + QK / 2, "wrong q4_0 block size/padding"); + +// method 4 +// blocks of QK elements +// represented with 2 floats (delta + min) and QK/2 8-bit ints (i.e QK 4-bit unsigned integer factors) +typedef struct { + float d; + float m; + uint8_t qs[QK / 2]; // nibbles / quants +} block_q4_1; +static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK / 2, "wrong q4_1 block size/padding"); + +// reference implementation for deterministic creation of model files +static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) { + assert(k % QK == 0); + const int nb = k / QK; + + uint8_t pp[QK/2]; + + for (int i = 0; i < nb; i++) { + float amax = 0.0f; // absolute max + + for (int l = 0; l < QK; l++) { + const float v = x[i*QK + l]; + amax = MAX(amax, fabsf(v)); + } + + const float d = amax / ((1 << 3) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = d; + + for (int l = 0; l < QK; l += 2) { + const float v0 = x[i*QK + l + 0]*id; + const float v1 = x[i*QK + l + 1]*id; + + const uint8_t vi0 = (int8_t)roundf(v0) + 8; + const uint8_t vi1 = (int8_t)roundf(v1) + 8; + + assert(vi0 >= 0 && vi0 < 16); + assert(vi1 >= 0 && vi1 < 16); + + pp[l/2] = vi0 | (vi1 << 4); + } + + memcpy(y[i].qs, pp, sizeof(pp)); + } +} + +static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int k) { + assert(k % QK == 0); + const int nb = k / QK; + + block_q4_0 * restrict y = vy; + +#if defined(__POWER9_VECTOR__) + const vector float v85 = vec_splats(8.5f); + for (int i = 0; i < nb; i++) { + float amax = 0.0f; // absolute max + + vector float srcv [8]; + vector float asrcv[8]; + vector float amaxv[8]; + + for (int l = 0; l < 8; l++) srcv[l] = *(vector float *)(x + i*32 + 4*l); + for (int l = 0; l < 8; l++) asrcv[l] = vec_abs(srcv[l]); + + for (int l = 0; l < 4; l++) amaxv[2*l] = vec_max(asrcv[2*l], asrcv[2*l+1]); + //for (int l = 0; l < 2; l++) amaxv[4*l] = vec_max(amaxv[4*l], amaxv[4*l+2]); + amaxv[0] = vec_max(amaxv[0], amaxv[2]); + amaxv[4] = vec_max(amaxv[4], amaxv[6]); + //for (int l = 0; l < 1; l++) amaxv[8*l] = vec_max(amaxv[8*l], amaxv[8*l+4]); + amaxv[0] = vec_max(amaxv[0], amaxv[4]); + + amax = MAX( + MAX(vec_extract(amaxv[0], 0), vec_extract(amaxv[0], 1)), + MAX(vec_extract(amaxv[0], 2), vec_extract(amaxv[0], 3))); + + const float d = amax / ((1 << 3) - 1); + const float id = d ? 1.0/d : 0.0; + + y[i].d = d; + + const vector float vid = vec_splats(id); + uint8_t * restrict pb = y[i].qs; + for (int l = 0; l < 8; l++) { + const vector float vf = vec_madd(srcv[l], vid, v85); + const vector signed int vi = vec_signed(vf); + + pb[2*l + 0] = vec_extract(vi, 0) | (vec_extract(vi, 1) << 4); + pb[2*l + 1] = vec_extract(vi, 2) | (vec_extract(vi, 3) << 4); + } + } +#elif __ARM_NEON + for (int i = 0; i < nb; i++) { + float32x4_t srcv [8]; + float32x4_t asrcv[8]; + float32x4_t amaxv[8]; + + for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*32 + 4*l); + for (int l = 0; l < 8; l++) asrcv[l] = vabsq_f32(srcv[l]); + + for (int l = 0; l < 4; l++) amaxv[2*l] = vmaxq_f32(asrcv[2*l], asrcv[2*l+1]); + for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]); + for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]); + + // absolute max + const float amax = MAX( + MAX(vgetq_lane_f32(amaxv[0], 0), vgetq_lane_f32(amaxv[0], 1)), + MAX(vgetq_lane_f32(amaxv[0], 2), vgetq_lane_f32(amaxv[0], 3))); + + const float d = amax / ((1 << 3) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = d; + + for (int l = 0; l < 8; l++) { + const float32x4_t v = vmulq_n_f32(srcv[l], id); + const float32x4_t vf = vaddq_f32(v, vdupq_n_f32(8.5f)); + const int32x4_t vi = vcvtq_s32_f32(vf); + + y[i].qs[2*l + 0] = vgetq_lane_s32(vi, 0) | (vgetq_lane_s32(vi, 1) << 4); + y[i].qs[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4); + } + } +#elif defined(__AVX2__) + for (int i = 0; i < nb; i++) { + // Load elements into 4 AVX vectors + __m256 v0 = _mm256_loadu_ps( x ); + __m256 v1 = _mm256_loadu_ps( x + 8 ); + __m256 v2 = _mm256_loadu_ps( x + 16 ); + __m256 v3 = _mm256_loadu_ps( x + 24 ); + x += 32; + + // Compute max(abs(e)) for the block + const __m256 signBit = _mm256_set1_ps( -0.0f ); + __m256 maxAbs = _mm256_andnot_ps( signBit, v0 ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) ); + + __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) ); + max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); + max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); + const float maxScalar = _mm_cvtss_f32( max4 ); + + // Quantize these floats + const float d = maxScalar / 7.0f; + y[i].d = d; + const float id = ( maxScalar != 0.0f ) ? 7.0f / maxScalar : 0.0f; + const __m256 mul = _mm256_set1_ps( id ); + + // Apply the multiplier + v0 = _mm256_mul_ps( v0, mul ); + v1 = _mm256_mul_ps( v1, mul ); + v2 = _mm256_mul_ps( v2, mul ); + v3 = _mm256_mul_ps( v3, mul ); + + // Round to nearest integer + v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST ); + v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST ); + v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST ); + v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST ); + + // Convert floats to integers + __m256i i0 = _mm256_cvtps_epi32( v0 ); + __m256i i1 = _mm256_cvtps_epi32( v1 ); + __m256i i2 = _mm256_cvtps_epi32( v2 ); + __m256i i3 = _mm256_cvtps_epi32( v3 ); + + // Convert int32 to int16 + i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 + i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 + // Convert int16 to int8 + i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 + + // We got our precious signed bytes, but the order is now wrong + // These AVX2 pack instructions process 16-byte pieces independently + // The following instruction is fixing the order + const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); + i0 = _mm256_permutevar8x32_epi32( i0, perm ); + + // Apply offset to translate the range from [ -7 .. +7 ] into [ +1 .. +15 ] + const __m256i off = _mm256_set1_epi8( 8 ); + i0 = _mm256_add_epi8( i0, off ); + + // Compress the vector into 4 bit/value, and store + __m128i res = packNibbles( i0 ); + _mm_storeu_si128( ( __m128i* )y[i].qs, res ); + } +#elif defined(__wasm_simd128__) + for (int i = 0; i < nb; i++) { + float amax = 0.0f; // absolute max + + v128_t srcv [8]; + v128_t asrcv[8]; + v128_t amaxv[8]; + + for (int l = 0; l < 8; l++) srcv[l] = wasm_v128_load(x + i*32 + 4*l); + for (int l = 0; l < 8; l++) asrcv[l] = wasm_f32x4_abs(srcv[l]); + + for (int l = 0; l < 4; l++) amaxv[2*l] = wasm_f32x4_max(asrcv[2*l], asrcv[2*l+1]); + for (int l = 0; l < 2; l++) amaxv[4*l] = wasm_f32x4_max(amaxv[4*l], amaxv[4*l+2]); + for (int l = 0; l < 1; l++) amaxv[8*l] = wasm_f32x4_max(amaxv[8*l], amaxv[8*l+4]); + + amax = MAX( + MAX(wasm_f32x4_extract_lane(amaxv[0], 0), wasm_f32x4_extract_lane(amaxv[0], 1)), + MAX(wasm_f32x4_extract_lane(amaxv[0], 2), wasm_f32x4_extract_lane(amaxv[0], 3))); + + const float d = amax / ((1 << 3) - 1); + const float id = d ? 1.0/d : 0.0; + + y[i].d = d; + + for (int l = 0; l < 8; l++) { + const v128_t v = wasm_f32x4_mul(srcv[l], wasm_f32x4_splat(id)); + const v128_t vf = wasm_f32x4_add(v, wasm_f32x4_splat(8.5f)); + const v128_t vi = wasm_i32x4_trunc_sat_f32x4(vf); + + y[i].qs[2*l + 0] = wasm_i32x4_extract_lane(vi, 0) | (wasm_i32x4_extract_lane(vi, 1) << 4); + y[i].qs[2*l + 1] = wasm_i32x4_extract_lane(vi, 2) | (wasm_i32x4_extract_lane(vi, 3) << 4); + } + } +#else + // scalar + quantize_row_q4_0_reference(x, y, k); +#endif +} + +static void quantize_row_q4_1_reference(const float * restrict x, void * restrict vy, int k) { + assert(k % QK == 0); + const int nb = k / QK; + + block_q4_1 * restrict y = vy; + + uint8_t pp[QK/2]; + + for (int i = 0; i < nb; i++) { + float min = FLT_MAX; + float max = -FLT_MAX; + + for (int l = 0; l < QK; l++) { + const float v = x[i*QK + l]; + if (v < min) min = v; + if (v > max) max = v; + } + + const float d = (max - min) / ((1 << 4) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = d; + y[i].m = min; + + for (int l = 0; l < QK; l += 2) { + const float v0 = (x[i*QK + l + 0] - min)*id; + const float v1 = (x[i*QK + l + 1] - min)*id; + + const uint8_t vi0 = roundf(v0); + const uint8_t vi1 = roundf(v1); + + assert(vi0 >= 0 && vi0 < 16); + assert(vi1 >= 0 && vi1 < 16); + + pp[l/2] = vi0 | (vi1 << 4); + } + + memcpy(y[i].qs, pp, sizeof(pp)); + } +} + +static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int k) { + assert(k % QK == 0); + + const int nb = k / QK; + + block_q4_1 * restrict y = vy; + +#if defined(__AVX2__) + for (int i = 0; i < nb; i++) { + // Load elements into 4 AVX vectors + __m256 v0 = _mm256_loadu_ps( x ); + __m256 v1 = _mm256_loadu_ps( x + 8 ); + __m256 v2 = _mm256_loadu_ps( x + 16 ); + __m256 v3 = _mm256_loadu_ps( x + 24 ); + x += 32; + + // Compute max for the block + __m256 vmax; + vmax = _mm256_max_ps( v0, v1 ); + vmax = _mm256_max_ps( vmax, v2 ); + vmax = _mm256_max_ps( vmax, v3 ); + + __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( vmax, 1 ), _mm256_castps256_ps128( vmax ) ); + max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); + max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); + const float maxScalar = _mm_cvtss_f32( max4 ); + + // Compute min for the block + __m256 vmin; + vmin = _mm256_min_ps( v0, v1 ); + vmin = _mm256_min_ps( vmin, v2 ); + vmin = _mm256_min_ps( vmin, v3 ); + + __m128 min4 = _mm_min_ps( _mm256_extractf128_ps( vmin, 1 ), _mm256_castps256_ps128( vmin ) ); + min4 = _mm_min_ps( min4, _mm_movehl_ps( min4, min4 ) ); + min4 = _mm_min_ss( min4, _mm_movehdup_ps( min4 ) ); + const float minScalar = _mm_cvtss_f32( min4 ); + + // Quantize these floats + const float d = (maxScalar - minScalar) / ((1 << 4) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].m = minScalar; + y[i].d = d; + + // x = (x-min)*id + const __m256 mul = _mm256_set1_ps( id ); + const __m256 off = _mm256_set1_ps( minScalar ); + v0 = _mm256_mul_ps( _mm256_sub_ps( v0, off ), mul ); + v1 = _mm256_mul_ps( _mm256_sub_ps( v1, off ), mul ); + v2 = _mm256_mul_ps( _mm256_sub_ps( v2, off ), mul ); + v3 = _mm256_mul_ps( _mm256_sub_ps( v3, off ), mul ); + + // Round to nearest integer + v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST ); + v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST ); + v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST ); + v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST ); + + // Convert floats to integers + __m256i i0 = _mm256_cvtps_epi32( v0 ); + __m256i i1 = _mm256_cvtps_epi32( v1 ); + __m256i i2 = _mm256_cvtps_epi32( v2 ); + __m256i i3 = _mm256_cvtps_epi32( v3 ); + + // Convert int32 to int16 + i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 + i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 + // Convert int16 to int8 + i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 + + // We got our precious signed bytes, but the order is now wrong + // These AVX2 pack instructions process 16-byte pieces independently + // The following instruction is fixing the order + const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); + i0 = _mm256_permutevar8x32_epi32( i0, perm ); + + // Compress the vector into 4 bit/value, and store + __m128i res = packNibbles( i0 ); + _mm_storeu_si128( ( __m128i* )y[i].qs, res ); + } +#elif __ARM_NEON + for (int i = 0; i < nb; i++) { + float32x4_t srcv[8]; + float32x4_t minv[8]; + float32x4_t maxv[8]; + + for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*32 + 4*l); + + for (int l = 0; l < 4; l++) minv[2*l] = vminq_f32(srcv[2*l], srcv[2*l + 1]); + for (int l = 0; l < 2; l++) minv[4*l] = vminq_f32(minv[4*l], minv[4*l + 2]); + for (int l = 0; l < 1; l++) minv[8*l] = vminq_f32(minv[8*l], minv[8*l + 4]); + + for (int l = 0; l < 4; l++) maxv[2*l] = vmaxq_f32(srcv[2*l], srcv[2*l + 1]); + for (int l = 0; l < 2; l++) maxv[4*l] = vmaxq_f32(maxv[4*l], maxv[4*l + 2]); + for (int l = 0; l < 1; l++) maxv[8*l] = vmaxq_f32(maxv[8*l], maxv[8*l + 4]); + + const float min = vminvq_f32(minv[0]); + const float max = vmaxvq_f32(maxv[0]); + + const float d = (max - min) / ((1 << 4) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = d; + y[i].m = min; + + const float32x4_t minv0 = vdupq_n_f32(min); + + for (int l = 0; l < 8; l++) { + const float32x4_t v = vmulq_n_f32(vsubq_f32(srcv[l], minv0), id); + const int32x4_t vi = vcvtq_s32_f32(v); + + y[i].qs[2*l + 0] = vgetq_lane_s32(vi, 0) | (vgetq_lane_s32(vi, 1) << 4); + y[i].qs[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4); + } + } +#else + // scalar + quantize_row_q4_1_reference(x, vy, k); +#endif +} + +static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, int k) { + assert(k % QK == 0); + const int nb = k / QK; + + const block_q4_0 * restrict x = vx; + +#if defined(__AVX2__) + for (int i = 0; i < nb; i++) { + // scale factor + const __m256 d_v = _mm256_broadcast_ss(&x[i].d); + + const uint8_t * restrict pp = x[i].qs; + + for (int l = 0; l < QK; l += 32) { + // Load 32x4-bit integers into 32x8-bit integers + __m256i vx8 = bytesFromNibbles(pp+l/2); + + // Subtract 8 from the integers + vx8 = _mm256_sub_epi8(vx8, _mm256_set1_epi8(8)); + + // Convert to 16-bit int + const __m256i vx16_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 0)); + const __m256i vx16_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 1)); + + // Convert to 32-bit int -> float 32 + const __m256 vf[4] = { + _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 0))), + _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 1))), + _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 0))), + _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 1))) + }; + + // Scale and store + for (int j = 0; j < 4; j++) { + const __m256 result = _mm256_mul_ps(vf[j], d_v); + _mm256_storeu_ps(y + i * QK + l + j*8, result); + } + } + } +#elif defined(__ARM_NEON) + for (int i = 0; i < nb; i++) { + const float32x4_t vd = vdupq_n_f32(x[i].d); + + const uint8_t * restrict pp = x[i].qs; + + for (int l = 0; l < QK; l += 16) { + // Load 16x4-bit integers into 8x8-bit integers + const uint8x8_t v8 = vld1_u8(pp + l/2); + + // Expand 4-bit qs to 8-bit bytes + const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0f)); + const uint8x8_t v1 = vshr_n_u8(v8, 4); + + // Convert to signed 8-bit integers + const int8x8_t vs_0 = vreinterpret_s8_u8(v0); + const int8x8_t vs_1 = vreinterpret_s8_u8(v1); + + // Subtract 8 from each byte + const int8x8_t vb_0 = vsub_s8(vs_0, vdup_n_s8(8)); + const int8x8_t vb_1 = vsub_s8(vs_1, vdup_n_s8(8)); + + // Interleave and combine + const int8x8_t vx_0 = vzip1_s8(vb_0, vb_1); + const int8x8_t vx_1 = vzip2_s8(vb_0, vb_1); + + const int8x16_t vq = vcombine_s8(vx_0, vx_1); + + // convert to 2x int16x8_t + const int16x8_t vi_0 = vmovl_s8(vget_low_s8 (vq)); + const int16x8_t vi_1 = vmovl_s8(vget_high_s8(vq)); + + // convert to 4x float32x4_t + const float32x4_t vf_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vi_0))); + const float32x4_t vf_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vi_0))); + const float32x4_t vf_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vi_1))); + const float32x4_t vf_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vi_1))); + + // Multiply by d + const float32x4_t r0 = vmulq_f32(vf_0, vd); + const float32x4_t r1 = vmulq_f32(vf_1, vd); + const float32x4_t r2 = vmulq_f32(vf_2, vd); + const float32x4_t r3 = vmulq_f32(vf_3, vd); + + // Store + vst1q_f32(y + i*QK + l + 0, r0); + vst1q_f32(y + i*QK + l + 4, r1); + vst1q_f32(y + i*QK + l + 8, r2); + vst1q_f32(y + i*QK + l + 12, r3); + } + } +#else + // scalar + for (int i = 0; i < nb; i++) { + const float d = x[i].d; + + const uint8_t * restrict pp = x[i].qs; + + for (int l = 0; l < QK; l += 2) { + const uint8_t vi = pp[l/2]; + + const int8_t vi0 = vi & 0xf; + const int8_t vi1 = vi >> 4; + + const float v0 = (vi0 - 8)*d; + const float v1 = (vi1 - 8)*d; + + //printf("d = %f, vi = %d, vi0 = %d, vi1 = %d, v0 = %f, v1 = %f\n", d, vi, vi0, vi1, v0, v1); + + y[i*QK + l + 0] = v0; + y[i*QK + l + 1] = v1; + + assert(!isnan(y[i*QK + l + 0])); + assert(!isnan(y[i*QK + l + 1])); + } + } +#endif +} + +static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, int k) { + assert(k % QK == 0); + const int nb = k / QK; + + const block_q4_1 * restrict x = vx; + +#if defined(__AVX2__) + for (int i = 0; i < nb; i++) { + const __m256 d_v = _mm256_broadcast_ss(&x[i].d); + const __m256 d_m = _mm256_broadcast_ss(&x[i].m); + + const uint8_t * restrict pp = x[i].qs; + + for (int l = 0; l < QK; l += 32) { + // Load 32x4-bit integers into 32x8-bit integers + __m256i vx8 = bytesFromNibbles(pp+l/2); + + // Convert to 16-bit int + const __m256i vx16_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 0)); + const __m256i vx16_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 1)); + + // Convert to 32-bit int -> float 32 + const __m256 vf[4] = { + _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 0))), + _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 1))), + _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 0))), + _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 1))) + }; + + // Scale, add m and store + for (int j = 0; j < 4; j++) { + const __m256 result = _mm256_add_ps(_mm256_mul_ps(vf[j], d_v), d_m); + _mm256_storeu_ps(y + i * QK + l + j*8, result); + } + } + } +#elif defined(__ARM_NEON) + for (int i = 0; i < nb; i++) { + const float32x4_t vd = vdupq_n_f32(x[i].d); + const float32x4_t vm = vdupq_n_f32(x[i].m); + + const uint8_t * restrict pp = x[i].qs; + + for (int l = 0; l < QK; l += 16) { + // Load 16x4-bit integers into 8x8-bit integers + const uint8x8_t v8 = vld1_u8(pp + l/2); + + // Expand 4-bit qs to 8-bit bytes + const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0f)); + const uint8x8_t v1 = vshr_n_u8(v8, 4); + + // Interleave and combine + const uint8x8_t vx_0 = vzip1_u8(v0, v1); + const uint8x8_t vx_1 = vzip2_u8(v0, v1); + + const uint8x16_t vq = vcombine_u8(vx_0, vx_1); + + // convert to 2x uint16x8_t + const uint16x8_t vi_0 = vmovl_s8(vget_low_u8 (vq)); + const uint16x8_t vi_1 = vmovl_s8(vget_high_u8(vq)); + + // convert to 4x float32x4_t + const float32x4_t vf_0 = vcvtq_f32_u32(vmovl_u16(vget_low_u16 (vi_0))); + const float32x4_t vf_1 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vi_0))); + const float32x4_t vf_2 = vcvtq_f32_u32(vmovl_u16(vget_low_u16 (vi_1))); + const float32x4_t vf_3 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vi_1))); + + // multiply by d and add m + const float32x4_t r0 = vmlaq_f32(vm, vf_0, vd); + const float32x4_t r1 = vmlaq_f32(vm, vf_1, vd); + const float32x4_t r2 = vmlaq_f32(vm, vf_2, vd); + const float32x4_t r3 = vmlaq_f32(vm, vf_3, vd); + + // Store + vst1q_f32(y + i*QK + l + 0, r0); + vst1q_f32(y + i*QK + l + 4, r1); + vst1q_f32(y + i*QK + l + 8, r2); + vst1q_f32(y + i*QK + l + 12, r3); + } + } +#else + for (int i = 0; i < nb; i++) { + const float d = x[i].d; + const float m = x[i].m; + + const uint8_t * restrict pp = x[i].qs; + + for (int l = 0; l < QK; l += 2) { + const uint8_t vi = pp[l/2]; + + const int8_t vi0 = vi & 0xf; + const int8_t vi1 = vi >> 4; + + const float v0 = vi0*d + m; + const float v1 = vi1*d + m; + + y[i*QK + l + 0] = v0; + y[i*QK + l + 1] = v1; + + assert(!isnan(y[i*QK + l + 0])); + assert(!isnan(y[i*QK + l + 1])); + } + } +#endif +} + // // simd mappings // @@ -443,7 +1178,7 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float); } \ const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 (x[0])); \ const float32x4_t t1 = vcvt_f32_f16(vget_high_f16(x[0])); \ - res = vaddvq_f32(vaddq_f32(t0, t1)); \ + res = (ggml_float) vaddvq_f32(vaddq_f32(t0, t1)); \ } #define GGML_F16_VEC GGML_F16x8 @@ -538,13 +1273,36 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float); #define GGML_F16_EPR 8 // F16 arithmetic is not supported by AVX, so we use F32 instead -// we take advantage of the _mm256_cvt intrinsics to convert F16 <-> F32 #define GGML_F32Cx8 __m256 #define GGML_F32Cx8_ZERO _mm256_setzero_ps() #define GGML_F32Cx8_SET1(x) _mm256_set1_ps(x) + +#if defined(__F16C__) +// the _mm256_cvt intrinsics require F16C #define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((__m128i *)(x))) #define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0)) +#else +static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) { + float tmp[8]; + + for (int i = 0; i < 8; i++) + tmp[i] = GGML_FP16_TO_FP32(x[i]); + + return _mm256_loadu_ps(tmp); +} +static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) { + float arr[8]; + + _mm256_storeu_ps(arr, y); + + for (int i = 0; i < 8; i++) + x[i] = GGML_FP16_TO_FP32(arr[i]); +} +#define GGML_F32Cx8_LOAD(x) __avx_f32cx8_load(x) +#define GGML_F32Cx8_STORE(x, y) __avx_f32cx8_store(x, y) +#endif + #define GGML_F32Cx8_FMA GGML_F32x8_FMA #define GGML_F32Cx8_ADD _mm256_add_ps #define GGML_F32Cx8_MUL _mm256_mul_ps @@ -856,9 +1614,8 @@ inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, co inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; } inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y) { - ggml_float sumf = 0.0; - #ifdef GGML_SIMD + float sumf = 0.0f; const int np = (n & ~(GGML_F32_STEP - 1)); GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO }; @@ -884,14 +1641,46 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float } #else // scalar + ggml_float sumf = 0.0; for (int i = 0; i < n; ++i) { - sumf += x[i]*y[i]; + sumf += (ggml_float)(x[i]*y[i]); } #endif *s = sumf; } +#if __AVX512F__ && QK == 32 +static inline __m512 dot_q4_0_oneblock_avx512( + __m512 acc, + const block_q4_0 * restrict x, + const block_q4_0 * restrict y, + int i +) { + // Compute combined scale for the block + __m512 d = _mm512_set1_ps( x[i].d * y[i].d ); + + __m256i bx = bytesFromNibbles( x[i].qs ); + __m256i by = bytesFromNibbles( y[i].qs ); + + // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. + const __m256i off = _mm256_set1_epi8( 8 ); + bx = _mm256_sub_epi8( bx, off ); + by = _mm256_sub_epi8( by, off ); + + // Sign-extend 16 signed bytes into int16_t + __m512i x32 = _mm512_cvtepi8_epi16( bx ); + __m512i y32 = _mm512_cvtepi8_epi16( by ); + // Compute products of int16_t integers, add pairwise + __m512i i64 = _mm512_madd_epi16( x32, y32 ); + + // Convert int32_t to float + __m512 p = _mm512_cvtepi32_ps( i64 ); + // Apply the scale, and accumulate + return _mm512_fmadd_ps( d, p, acc ); +} +#endif + inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) { ggml_float sumf = 0.0; @@ -917,130 +1706,531 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t // leftovers for (int i = np; i < n; ++i) { - sumf += GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]); + sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i])); } #else for (int i = 0; i < n; ++i) { - sumf += GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]); + sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i])); } #endif *s = sumf; } -// compute GGML_VEC_DOT_UNROLL dot products at once -// xs - x row stride in bytes -inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) { - ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 }; +static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + const int nb = n / QK; - ggml_fp16_t * restrict x[GGML_VEC_DOT_UNROLL]; + assert(n % QK == 0); + assert(nb % 2 == 0); - for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) { - x[i] = (ggml_fp16_t *) ((char *) xv + i*xs); - } + const block_q4_0 * restrict x = vx; + const block_q4_0 * restrict y = vy; -#if defined(GGML_SIMD) - const int np = (n & ~(GGML_F16_STEP - 1)); + ggml_float sumf = 0.0; - GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } }; +#if defined(__ARM_NEON) + float sum0 = 0.0f; + float sum1 = 0.0f; - GGML_F16_VEC ax[GGML_F16_ARR]; - GGML_F16_VEC ay[GGML_F16_ARR]; + for (int i = 0; i < nb; i += 2) { + const block_q4_0 * restrict x0 = &x[i + 0]; + const block_q4_0 * restrict y0 = &y[i + 0]; + const block_q4_0 * restrict x1 = &x[i + 1]; + const block_q4_0 * restrict y1 = &y[i + 1]; - for (int i = 0; i < np; i += GGML_F16_STEP) { - for (int j = 0; j < GGML_F16_ARR; j++) { - ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); + const uint8x16_t m4b = vdupq_n_u8(0xf); + const int8x16_t s8b = vdupq_n_s8(0x8); - for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) { - ax[j] = GGML_F16_VEC_LOAD(x[k] + i + j*GGML_F16_EPR, j); + const uint8x16_t v0_0 = vld1q_u8(x0->qs); + const uint8x16_t v1_0 = vld1q_u8(y0->qs); + const uint8x16_t v0_1 = vld1q_u8(x1->qs); + const uint8x16_t v1_1 = vld1q_u8(y1->qs); - sum[k][j] = GGML_F16_VEC_FMA(sum[k][j], ax[j], ay[j]); - } - } - } + // 4-bit -> 8-bit + const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b)); + const int8x16_t v1_0l = vreinterpretq_s8_u8(vandq_u8(v1_0, m4b)); - // reduce sum0..sum3 to sum0 - for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) { - GGML_F16_VEC_REDUCE(sumf[k], sum[k]); - } + const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); + const int8x16_t v1_0h = vreinterpretq_s8_u8(vshrq_n_u8(v1_0, 4)); - // leftovers - for (int i = np; i < n; ++i) { - for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) { - sumf[j] += GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]); - } - } + const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b)); + const int8x16_t v1_1l = vreinterpretq_s8_u8(vandq_u8(v1_1, m4b)); + + const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); + const int8x16_t v1_1h = vreinterpretq_s8_u8(vshrq_n_u8(v1_1, 4)); + + // sub 8 + const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b); + const int8x16_t v1_0ls = vsubq_s8(v1_0l, s8b); + + const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b); + const int8x16_t v1_0hs = vsubq_s8(v1_0h, s8b); + + const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b); + const int8x16_t v1_1ls = vsubq_s8(v1_1l, s8b); + + const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b); + const int8x16_t v1_1hs = vsubq_s8(v1_1h, s8b); + +#if defined(__ARM_FEATURE_DOTPROD) + // dot product into int16x8_t + int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls); + int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls); + + p_0 = vdotq_s32(p_0, v0_0hs, v1_0hs); + p_1 = vdotq_s32(p_1, v0_1hs, v1_1hs); + + // scalar +#if defined(__ARM_FEATURE_QRDMX) + sum0 += x0->d * y0->d * vaddvq_s32(p_0); + sum1 += x1->d * y1->d * vaddvq_s32(p_1); #else - for (int i = 0; i < n; ++i) { - for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) { - sumf[j] += GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]); - } - } + sum0 += x0->d * y0->d * (vgetq_lane_s32(p_0, 0) + vgetq_lane_s32(p_0, 1) + vgetq_lane_s32(p_0, 2) + vgetq_lane_s32(p_0, 3)); + sum1 += x1->d * y1->d * (vgetq_lane_s32(p_1, 0) + vgetq_lane_s32(p_1, 1) + vgetq_lane_s32(p_1, 2) + vgetq_lane_s32(p_1, 3)); #endif +#else + const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls)); + const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls)); - for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) { - s[i] = sumf[i]; - } -} + const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs)); + const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs)); -inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float * restrict x, const float v) { -#if defined(GGML_SIMD) - const int np = (n & ~(GGML_F32_STEP - 1)); + const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls)); + const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls)); - GGML_F32_VEC vx = GGML_F32_VEC_SET1(v); + const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs)); + const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs)); - GGML_F32_VEC ax[GGML_F32_ARR]; - GGML_F32_VEC ay[GGML_F32_ARR]; + const int16x8_t pl_0 = vaddq_s16(pl0l, pl0h); + const int16x8_t ph_0 = vaddq_s16(ph0l, ph0h); - for (int i = 0; i < np; i += GGML_F32_STEP) { - for (int j = 0; j < GGML_F32_ARR; j++) { - ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR); - ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); - ay[j] = GGML_F32_VEC_FMA(ay[j], ax[j], vx); + const int16x8_t pl_1 = vaddq_s16(pl1l, pl1h); + const int16x8_t ph_1 = vaddq_s16(ph1l, ph1h); - GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); - } - } + const int16x8_t p_0 = vaddq_s16(pl_0, ph_0); + const int16x8_t p_1 = vaddq_s16(pl_1, ph_1); - // leftovers - for (int i = np; i < n; ++i) { - y[i] += x[i]*v; + // scalar +#if defined(__ARM_FEATURE_QRDMX) + sum0 += x0->d * y0->d * vaddvq_s16(p_0); + sum1 += x1->d * y1->d * vaddvq_s16(p_1); +#else + sum0 += x0->d * y0->d * (vgetq_lane_s16(p_0, 0) + vgetq_lane_s16(p_0, 1) + vgetq_lane_s16(p_0, 2) + vgetq_lane_s16(p_0, 3) + vgetq_lane_s16(p_0, 4) + vgetq_lane_s16(p_0, 5) + vgetq_lane_s16(p_0, 6) + vgetq_lane_s16(p_0, 7)); + sum1 += x1->d * y1->d * (vgetq_lane_s16(p_1, 0) + vgetq_lane_s16(p_1, 1) + vgetq_lane_s16(p_1, 2) + vgetq_lane_s16(p_1, 3) + vgetq_lane_s16(p_1, 4) + vgetq_lane_s16(p_1, 5) + vgetq_lane_s16(p_1, 6) + vgetq_lane_s16(p_1, 7)); +#endif +#endif + } + + sumf = (ggml_float)(sum0 + sum1); +#elif defined(__AVX512F__) + // Initialize accumulator with zeros + __m512 acc0 = _mm512_setzero_ps(); + __m512 acc1 = _mm512_setzero_ps(); + + const int superblock_size = 8; + const int superblock_count = nb / superblock_size; + const int remainder = nb % superblock_size; + + for (int superblock_ix = 0; superblock_ix < superblock_count; superblock_ix += 1) { + int i = superblock_ix * superblock_size; + + acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+0 ); + acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+1 ); + acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+2 ); + acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+3 ); + acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+4 ); + acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+5 ); + acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+6 ); + acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+7 ); + } + + // Remainders + for (int i = superblock_count * superblock_size; i < nb; ++i) { + acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i ); + } + + // Horizontal sum of all lanes of the accumulator + sumf = _mm512_reduce_add_ps( acc0 ) + _mm512_reduce_add_ps( acc1 ); +#elif defined(__AVX2__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + + // Main loop + for (int i = 0; i < nb; ++i) { + // Compute combined scale for the block + const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) ); + + // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes + __m256i bx = bytesFromNibbles( x[i].qs ); + __m256i by = bytesFromNibbles( y[i].qs ); + + // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. + const __m256i off = _mm256_set1_epi8( 8 ); + bx = _mm256_sub_epi8( bx, off ); + by = _mm256_sub_epi8( by, off ); + + // Sign-extend first 16 signed bytes into int16_t + __m256i x16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( bx ) ); + __m256i y16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) ); + // Compute products of int16_t integers, add pairwise + __m256i i32 = _mm256_madd_epi16( x16, y16 ); + + // Sign-extend last 16 signed bytes into int16_t vectors + x16 = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( bx, 1 ) ); + y16 = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) ); + // Accumulate products of int16_t integers + i32 = _mm256_add_epi32( i32, _mm256_madd_epi16( x16, y16 ) ); + + // Convert int32_t to float + __m256 p = _mm256_cvtepi32_ps( i32 ); + // Apply the scale, and accumulate + acc = _mm256_fmadd_ps( d, p, acc ); + } + + // Return horizontal sum of the acc vector + __m128 res = _mm256_extractf128_ps( acc, 1 ); + res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) ); + res = _mm_add_ps( res, _mm_movehl_ps( res, res ) ); + res = _mm_add_ss( res, _mm_movehdup_ps( res ) ); + + sumf = _mm_cvtss_f32( res ); +#elif defined(__wasm_simd128__) + // wasm simd + float sum0 = 0.0f; + float sum1 = 0.0f; + + for (int i = 0; i < nb; i += 2) { + const block_q4_0 * restrict x0 = &px[i + 0]; + const block_q4_0 * restrict y0 = &py[i + 0]; + const block_q4_0 * restrict x1 = &px[i + 1]; + const block_q4_0 * restrict y1 = &py[i + 1]; + + const v128_t m4b = wasm_u8x16_splat(0xf); + const v128_t s8b = wasm_i8x16_splat(0x8); + + const v128_t v0_0 = wasm_v128_load(x0.qs); + const v128_t v0_1 = wasm_v128_load(y0.qs); + const v128_t v1_0 = wasm_v128_load(x1.qs); + const v128_t v1_1 = wasm_v128_load(y1.qs); + + // 4-bit -> 8-bit + const v128_t v0_0l = wasm_v128_and(v0_0, m4b); + const v128_t v1_0l = wasm_v128_and(v1_0, m4b); + + const v128_t v0_0h = wasm_u8x16_shr(v0_0, 4); + const v128_t v1_0h = wasm_u8x16_shr(v1_0, 4); + + const v128_t v0_1l = wasm_v128_and(v0_1, m4b); + const v128_t v1_1l = wasm_v128_and(v1_1, m4b); + + const v128_t v0_1h = wasm_u8x16_shr(v0_1, 4); + const v128_t v1_1h = wasm_u8x16_shr(v1_1, 4); + + // sub 8 + const v128_t v0_0ls = wasm_i8x16_sub(v0_0l, s8b); + const v128_t v1_0ls = wasm_i8x16_sub(v1_0l, s8b); + + const v128_t v0_0hs = wasm_i8x16_sub(v0_0h, s8b); + const v128_t v1_0hs = wasm_i8x16_sub(v1_0h, s8b); + + const v128_t v0_1ls = wasm_i8x16_sub(v0_1l, s8b); + const v128_t v1_1ls = wasm_i8x16_sub(v1_1l, s8b); + + const v128_t v0_1hs = wasm_i8x16_sub(v0_1h, s8b); + const v128_t v1_1hs = wasm_i8x16_sub(v1_1h, s8b); + + // dot product into int16x8_t + const v128_t pl0l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_0ls), wasm_i16x8_extend_low_i8x16(v1_0ls)); + const v128_t pl0h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_0ls), wasm_i16x8_extend_high_i8x16(v1_0ls)); + + const v128_t ph0l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_0hs), wasm_i16x8_extend_low_i8x16(v1_0hs)); + const v128_t ph0h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_0hs), wasm_i16x8_extend_high_i8x16(v1_0hs)); + + const v128_t pl1l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_1ls), wasm_i16x8_extend_low_i8x16(v1_1ls)); + const v128_t pl1h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_1ls), wasm_i16x8_extend_high_i8x16(v1_1ls)); + + const v128_t ph1l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_1hs), wasm_i16x8_extend_low_i8x16(v1_1hs)); + const v128_t ph1h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_1hs), wasm_i16x8_extend_high_i8x16(v1_1hs)); + + const v128_t pl_0 = wasm_i16x8_add(pl0l, pl0h); + const v128_t ph_0 = wasm_i16x8_add(ph0l, ph0h); + + const v128_t pl_1 = wasm_i16x8_add(pl1l, pl1h); + const v128_t ph_1 = wasm_i16x8_add(ph1l, ph1h); + + const v128_t p_0 = wasm_i16x8_add(pl_0, ph_0); + const v128_t p_1 = wasm_i16x8_add(pl_1, ph_1); + + sum0 += x0->d * y0->d * ( + wasm_i16x8_extract_lane(p_0, 0) + wasm_i16x8_extract_lane(p_0, 1) + + wasm_i16x8_extract_lane(p_0, 2) + wasm_i16x8_extract_lane(p_0, 3) + + wasm_i16x8_extract_lane(p_0, 4) + wasm_i16x8_extract_lane(p_0, 5) + + wasm_i16x8_extract_lane(p_0, 6) + wasm_i16x8_extract_lane(p_0, 7)); + sum1 += x1->d * y1->d * ( + wasm_i16x8_extract_lane(p_1, 0) + wasm_i16x8_extract_lane(p_1, 1) + + wasm_i16x8_extract_lane(p_1, 2) + wasm_i16x8_extract_lane(p_1, 3) + + wasm_i16x8_extract_lane(p_1, 4) + wasm_i16x8_extract_lane(p_1, 5) + + wasm_i16x8_extract_lane(p_1, 6) + wasm_i16x8_extract_lane(p_1, 7)); } + + sumf = sum0 + sum1; #else // scalar - for (int i = 0; i < n; ++i) { - y[i] += x[i]*v; + for (int i = 0; i < nb; i++) { + const float d0 = x[i].d; + const float d1 = y[i].d; + + const uint8_t * restrict p0 = x[i].qs; + const uint8_t * restrict p1 = y[i].qs; + + for (int j = 0; j < QK/2; j++) { + const uint8_t v0 = p0[j]; + const uint8_t v1 = p1[j]; + + const float f0 = d0*((int8_t) (v0 & 0xf) - 8); + const float f1 = d0*((int8_t) (v0 >> 4) - 8); + + const float f2 = d1*((int8_t) (v1 & 0xf) - 8); + const float f3 = d1*((int8_t) (v1 >> 4) - 8); + + sumf += f0*f2 + f1*f3; + } + } +#endif + + *s = sumf; +} + +static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + const int nb = n / QK; + + const block_q4_1 * restrict x = vx; + const block_q4_1 * restrict y = vy; + + float sumf = 0.0; + +#if defined(__AVX2__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + // Accumulator for constant offsets + float acc_offset = 0.0f; + + // Main loop + for (int i = 0; i < nb; ++i) { + const float * d0 = &x[i].d; + const float * d1 = &y[i].d; + + const float * m0 = &x[i].m; + const float * m1 = &y[i].m; + + const __m256 d0v = _mm256_broadcast_ss( d0 ); + const __m256 d1v = _mm256_broadcast_ss( d1 ); + const __m256 m0v = _mm256_broadcast_ss( m0 ); + const __m256 m1v = _mm256_broadcast_ss( m1 ); + + // Compute combined scale for the block + const __m256 scale_01 = _mm256_mul_ps( d0v, d1v ); + + // Compute cross scales for the block + const __m256 scale_0 = _mm256_mul_ps( d0v, m1v ); + const __m256 scale_1 = _mm256_mul_ps( m0v, d1v ); + const __m256 cross_scales = _mm256_blend_ps( scale_0, scale_1, 0xAA /* 0b10101010 */ ); + + // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes + __m256i bx = bytesFromNibbles( x[i].qs ); + __m256i by = bytesFromNibbles( y[i].qs ); + + // Now we have a vector with bytes in [ 0 .. 15 ] interval. + + // Sign-extend first 16 signed bytes into int16_t + __m256i x16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( bx ) ); + __m256i y16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) ); + // Compute products of int16_t integers, add pairwise + __m256i i32 = _mm256_madd_epi16( x16, y16 ); + + // Sign-extend last 16 signed bytes into int16_t vectors + __m256i x16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( bx, 1 ) ); + __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) ); + // Accumulate products of int16_t integers + i32 = _mm256_add_epi32( i32, _mm256_madd_epi16( x16_h, y16_h ) ); + + // compute sums of unsigned bytes in bx, by in blocks of 8. + // This results in a layout like X100 0000 X200 0000 X300 0000 X400 0000, + // which we then interleave as X100 Y100 X200 Y200 X300 Y300 X400 Y400. + // so if we then cast to 8 singles, we get 8 floats like [ x0_7, y0_7, x8_15, y8_15, x16_23, y16_23, x24_31, y24_31 ] + __m256i xsumi = _mm256_sad_epu8( bx, _mm256_setzero_si256() ); + __m256i ysumi = _mm256_sad_epu8( by, _mm256_setzero_si256() ); + __m256i sumsi = _mm256_or_si256( xsumi, _mm256_slli_si256( ysumi, 4 ) ); + __m256 sums = _mm256_cvtepi32_ps( sumsi ); + + // Convert int32_t to float + __m256 p = _mm256_cvtepi32_ps( i32 ); + // Apply the scale, and accumulate + // acc += d0*d1*x*y + d0*m1*x + d1*m0*y + acc = _mm256_fmadd_ps( scale_01, p, acc ); + acc = _mm256_fmadd_ps( cross_scales, sums, acc ); + // acc_offset += m0*m1 (for each entry in the block) + acc_offset += (*m0)*(*m1); + } + + // Return horizontal sum of the acc vector + __m128 res = _mm256_extractf128_ps( acc, 1 ); + res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) ); + res = _mm_add_ps( res, _mm_movehl_ps( res, res ) ); + res = _mm_add_ss( res, _mm_movehdup_ps( res ) ); + + sumf = _mm_cvtss_f32( res ) + acc_offset * QK; +#elif defined(__ARM_NEON) + float sum00 = 0.0f; + float sum01 = 0.0f; + float sum10 = 0.0f; + float sum11 = 0.0f; + + for (int i = 0; i < nb; ++i) { + const block_q4_1 * restrict x0 = &x[i + 0]; + const block_q4_1 * restrict y0 = &y[i + 0]; + + const uint8x16_t m4b = vdupq_n_u8(0xf); + + const uint8x16_t v0_0 = vld1q_u8(x0->qs); + const uint8x16_t v1_0 = vld1q_u8(y0->qs); + + // and with 0xf + const uint8x16_t v0_0l = vandq_u8(v0_0, m4b); + const uint8x16_t v1_0l = vandq_u8(v1_0, m4b); + + const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4); + const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4); + + // dot product into uint16x8_t + const uint16x8_t pl0l = vmull_u8(vget_low_u8 (v0_0l), vget_low_u8 (v1_0l)); + const uint16x8_t pl0h = vmull_u8(vget_high_u8(v0_0l), vget_high_u8(v1_0l)); + + const uint16x8_t ph0l = vmull_u8(vget_low_u8 (v0_0h), vget_low_u8 (v1_0h)); + const uint16x8_t ph0h = vmull_u8(vget_high_u8(v0_0h), vget_high_u8(v1_0h)); + + const uint16x8_t pl0 = vaddq_u16(pl0l, pl0h); + const uint16x8_t ph0 = vaddq_u16(ph0l, ph0h); + + sum00 += x0->m*y0->m; + sum01 += y0->m*x0->d*(vaddvq_u8(v0_0l) + vaddvq_u8(v0_0h)); + sum10 += x0->m*y0->d*(vaddvq_u8(v1_0l) + vaddvq_u8(v1_0h)); + sum11 += x0->d*y0->d*vaddvq_u16(vaddq_u16(pl0, ph0)); + } + + sumf = QK*sum00 + sum01 + sum10 + sum11; +#else + // scalar + for (int i = 0; i < nb; i++) { + const float d0 = x[i].d; + const float d1 = y[i].d; + + const float m0 = x[i].m; + const float m1 = y[i].m; + + const uint8_t * restrict p0 = x[i].qs; + const uint8_t * restrict p1 = y[i].qs; + + for (int j = 0; j < QK/2; j++) { + const uint8_t v0 = p0[j]; + const uint8_t v1 = p1[j]; + + const float f0 = d0*(v0 & 0xf) + m0; + const float f1 = d0*(v0 >> 4) + m0; + + const float f2 = d1*(v1 & 0xf) + m1; + const float f3 = d1*(v1 >> 4) + m1; + + sumf += f0*f2 + f1*f3; + } } #endif + + *s = sumf; } -inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_fp16_t * restrict x, const float v) { +// compute GGML_VEC_DOT_UNROLL dot products at once +// xs - x row stride in bytes +inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) { + ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 }; + + ggml_fp16_t * restrict x[GGML_VEC_DOT_UNROLL]; + + for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) { + x[i] = (ggml_fp16_t *) ((char *) xv + i*xs); + } + #if defined(GGML_SIMD) const int np = (n & ~(GGML_F16_STEP - 1)); - GGML_F16_VEC vx = GGML_F16_VEC_SET1(v); + GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } }; GGML_F16_VEC ax[GGML_F16_ARR]; GGML_F16_VEC ay[GGML_F16_ARR]; for (int i = 0; i < np; i += GGML_F16_STEP) { for (int j = 0; j < GGML_F16_ARR; j++) { - ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j); ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); - ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx); - GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j); + for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) { + ax[j] = GGML_F16_VEC_LOAD(x[k] + i + j*GGML_F16_EPR, j); + + sum[k][j] = GGML_F16_VEC_FMA(sum[k][j], ax[j], ay[j]); + } + } + } + + // reduce sum0..sum3 to sum0 + for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) { + GGML_F16_VEC_REDUCE(sumf[k], sum[k]); + } + + // leftovers + for (int i = np; i < n; ++i) { + for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) { + sumf[j] += (ggml_float)(GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i])); + } + } +#else + for (int i = 0; i < n; ++i) { + for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) { + sumf[j] += (ggml_float)(GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i])); + } + } +#endif + + for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) { + s[i] = sumf[i]; + } +} + +inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float * restrict x, const float v) { +#if defined(GGML_SIMD) + const int np = (n & ~(GGML_F32_STEP - 1)); + + GGML_F32_VEC vx = GGML_F32_VEC_SET1(v); + + GGML_F32_VEC ax[GGML_F32_ARR]; + GGML_F32_VEC ay[GGML_F32_ARR]; + + for (int i = 0; i < np; i += GGML_F32_STEP) { + for (int j = 0; j < GGML_F32_ARR; j++) { + ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR); + ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); + ay[j] = GGML_F32_VEC_FMA(ay[j], ax[j], vx); + + GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); } } // leftovers for (int i = np; i < n; ++i) { - GGML_ASSERT(false); - y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v); + y[i] += x[i]*v; } #else + // scalar for (int i = 0; i < n; ++i) { - y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v); + y[i] += x[i]*v; } #endif } @@ -1075,19 +2265,19 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { #endif } -inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, x, x); *s = sqrt(*s); } +inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, x, x); *s = sqrtf(*s); } inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; } -inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrt(x[i]); } +inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); } inline static void ggml_vec_abs_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fabsf(x[i]); } inline static void ggml_vec_sgn_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : ((x[i] < 0.f) ? -1.f : 0.f); } inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : 0.f; } inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; } -static const ggml_float GELU_COEF_A = 0.044715; -static const ggml_float SQRT_2_OVER_PI = 0.79788456080286535587989211986876; +static const float GELU_COEF_A = 0.044715f; +static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; inline static float ggml_gelu_f32(float x) { - return 0.5*x*(1.0 + tanh(SQRT_2_OVER_PI*x*(1.0 + GELU_COEF_A*x*x))); + return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); } inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { @@ -1114,11 +2304,40 @@ inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) { } #endif +// Sigmoid Linear Unit (SiLU) function +inline static float ggml_silu_f32(float x) { + return x/(1.0f + expf(-x)); +} + +inline static void ggml_vec_silu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { + const uint16_t * i16 = (const uint16_t *) x; + for (int i = 0; i < n; ++i) { + y[i] = table_silu_f16[i16[i]]; + } +} + +#ifdef GGML_SILU_FP16 +inline static void ggml_vec_silu_f32(const int n, float * y, const float * x) { + uint16_t t; + for (int i = 0; i < n; ++i) { + ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]); + memcpy(&t, &fp16, sizeof(uint16_t)); + y[i] = GGML_FP16_TO_FP32(table_silu_f16[t]); + } +} +#else +inline static void ggml_vec_silu_f32(const int n, float * y, const float * x) { + for (int i = 0; i < n; ++i) { + y[i] = ggml_silu_f32(x[i]); + } +} +#endif + inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) { #ifndef GGML_USE_ACCELERATE ggml_float sum = 0.0; for (int i = 0; i < n; ++i) { - sum += x[i]; + sum += (ggml_float)x[i]; } *s = sum; #else @@ -1128,7 +2347,7 @@ inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) { inline static void ggml_vec_max_f32(const int n, float * s, const float * x) { #ifndef GGML_USE_ACCELERATE - ggml_float max = -INFINITY; + float max = -INFINITY; for (int i = 0; i < n; ++i) { max = MAX(max, x[i]); } @@ -1138,7 +2357,10 @@ inline static void ggml_vec_max_f32(const int n, float * s, const float * x) { #endif } -inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x) { ggml_vec_norm_f32(n, s, x); *s = 1./(*s); } +inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x) { + ggml_vec_norm_f32(n, s, x); + *s = 1.f/(*s); +} // // logging @@ -1168,7 +2390,21 @@ inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x // data types // +static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = { + QK, + QK, + 1, + 1, + 1, + 1, + 1, +}; + +static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_COUNT != 5"); + static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = { + sizeof(block_q4_0), + sizeof(block_q4_1), sizeof(int8_t ), sizeof(int16_t), sizeof(int32_t), @@ -1176,6 +2412,9 @@ static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = { sizeof(float ), }; +// don't forget to update the array above when adding new types +static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_COUNT != 5"); + static const char * GGML_OP_LABEL[GGML_OP_COUNT] = { "NONE", @@ -1195,7 +2434,9 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = { "STEP", "RELU", "GELU", + "SILU", "NORM", + "RMS_NORM", "MUL_MAT", @@ -1216,6 +2457,8 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = { "FLASH_FF", }; +static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35"); + static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1235,7 +2478,9 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "step(x)", "relu(x)", "gelu(x)", + "silu(x)", "norm(x)", + "rms_norm(x)", "X*Y", @@ -1256,6 +2501,8 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "flash_ff(x)", }; +static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35"); + // // ggml object // @@ -1282,6 +2529,7 @@ struct ggml_context { size_t mem_size; void * mem_buffer; bool mem_buffer_owned; + bool mem_buffer_mlocked; int n_objects; @@ -1383,13 +2631,21 @@ int ggml_nrows(const struct ggml_tensor * tensor) { size_t ggml_nbytes(const struct ggml_tensor * tensor) { static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); - return ggml_nelements(tensor)*GGML_TYPE_SIZE[tensor->type]; + return (ggml_nelements(tensor)*GGML_TYPE_SIZE[tensor->type])/GGML_BLCK_SIZE[tensor->type]; +} + +int ggml_blck_size(enum ggml_type type) { + return GGML_BLCK_SIZE[type]; } size_t ggml_type_size(enum ggml_type type) { return GGML_TYPE_SIZE[type]; } +float ggml_type_sizef(enum ggml_type type) { + return ((float)(GGML_TYPE_SIZE[type]))/GGML_BLCK_SIZE[type]; +} + size_t ggml_element_size(const struct ggml_tensor * tensor) { return GGML_TYPE_SIZE[tensor->type]; } @@ -1416,9 +2672,13 @@ static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); return - (t0->ne[0] == t1->ne[0]) && - (t0->ne[2] == t1->ne[2]) && - (t0->ne[3] == t1->ne[3]); + (t0->ne[0] == t1->ne[0]) && + (t0->ne[2] == t1->ne[2]) && + (t0->ne[3] == t1->ne[3]); +} + +static inline bool ggml_is_transposed(const struct ggml_tensor * tensor) { + return tensor->nb[0] > tensor->nb[1]; } static inline bool ggml_is_contiguous(const struct ggml_tensor * tensor) { @@ -1426,7 +2686,7 @@ static inline bool ggml_is_contiguous(const struct ggml_tensor * tensor) { return tensor->nb[0] == GGML_TYPE_SIZE[tensor->type] && - tensor->nb[1] == tensor->nb[0]*tensor->ne[0] && + tensor->nb[1] == (tensor->nb[0]*tensor->ne[0])/GGML_BLCK_SIZE[tensor->type] && tensor->nb[2] == tensor->nb[1]*tensor->ne[1] && tensor->nb[3] == tensor->nb[2]*tensor->ne[2]; } @@ -1477,7 +2737,7 @@ static inline int ggml_up(int n, int m) { // assert that pointer is aligned to GGML_MEM_ALIGN #define ggml_assert_aligned(ptr) \ - assert(((uintptr_t) (ptr))%GGML_MEM_ALIGN == 0) + GGML_ASSERT(((uintptr_t) (ptr))%GGML_MEM_ALIGN == 0) //////////////////////////////////////////////////////////////////////////////// @@ -1491,7 +2751,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) { // initialize time system (required on Windows) ggml_time_init(); - // initialize GELU, EXP and F32 tables + // initialize GELU, SILU and EXP F32 tables { const uint64_t t_start = ggml_time_us(); UNUSED(t_start); @@ -1501,12 +2761,13 @@ struct ggml_context * ggml_init(struct ggml_init_params params) { memcpy(&ii, &ui, sizeof(ii)); const float f = table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(ii); table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f)); - table_exp_f16[i] = GGML_FP32_TO_FP16(exp(f)); + table_silu_f16[i] = GGML_FP32_TO_FP16(ggml_silu_f32(f)); + table_exp_f16[i] = GGML_FP32_TO_FP16(expf(f)); } const uint64_t t_end = ggml_time_us(); UNUSED(t_end); - GGML_PRINT_DEBUG("%s: GELU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f); + GGML_PRINT_DEBUG("%s: GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f); } // initialize g_state @@ -1551,16 +2812,19 @@ struct ggml_context * ggml_init(struct ggml_init_params params) { } *ctx = (struct ggml_context) { - /*.mem_size =*/ params.mem_size, - /*.mem_buffer =*/ params.mem_buffer ? params.mem_buffer : malloc(params.mem_size), - /*.mem_buffer_owned =*/ params.mem_buffer ? false : true, - /*.n_objects =*/ 0, - /*.objects_begin =*/ NULL, - /*.objects_end =*/ NULL, - /*.scratch =*/ { 0, 0, NULL, }, - /*.scratch_save =*/ { 0, 0, NULL, }, + /*.mem_size =*/ params.mem_size, + /*.mem_buffer =*/ params.mem_buffer ? params.mem_buffer : malloc(params.mem_size), + /*.mem_buffer_owned =*/ params.mem_buffer ? false : true, + /*.mem_buffer_mlocked =*/ false, + /*.n_objects =*/ 0, + /*.objects_begin =*/ NULL, + /*.objects_end =*/ NULL, + /*.scratch =*/ { 0, 0, NULL, }, + /*.scratch_save =*/ { 0, 0, NULL, }, }; + GGML_ASSERT(ctx->mem_buffer != NULL); // check for allocation failure + ggml_assert_aligned(ctx->mem_buffer); GGML_PRINT_DEBUG("%s: context initialized\n", __func__); @@ -1583,6 +2847,14 @@ void ggml_free(struct ggml_context * ctx) { GGML_PRINT_DEBUG("%s: context %d with %d objects has been freed. memory used = %zu\n", __func__, i, ctx->n_objects, ctx->objects_end->offs + ctx->objects_end->size); +#if GGML_MLOCK_SUPPORT + if (ctx->mem_buffer_mlocked) { + if (munlock(ctx->mem_buffer, ctx->mem_size)) { + fprintf(stderr, "%s: failed to munlock buffer: %s\n", __func__, strerror(errno)); + } + } +#endif + if (ctx->mem_buffer_owned) { free(ctx->mem_buffer); } @@ -1611,6 +2883,37 @@ size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch) return result; } +bool ggml_mlock_supported(void) { + return GGML_MLOCK_SUPPORT; +} + +#if GGML_MLOCK_SUPPORT +#ifdef __APPLE__ + #define MLOCK_SUGGESTION "Try increasing the sysctl values 'vm.user_wire_limit' and 'vm.global_user_wire_limit' and/or\n" \ + "decreasing 'vm.global_no_user_wire_amount'. Also try increasing RLIMIT_MLOCK (ulimit -l)." +#else + #define MLOCK_SUGGESTION "Try increasing RLIMIT_MLOCK (ulimit -l)." +#endif +bool ggml_mlock(struct ggml_context * ctx, char ** err_p) { + if (ctx->mem_buffer_mlocked) { + return true; + } + if (mlock(ctx->mem_buffer, ctx->mem_size)) { + int ret = asprintf(err_p, "failed to mlock %zu-byte buffer: %s\n" MLOCK_SUGGESTION, + ctx->mem_size, strerror(errno)); + GGML_ASSERT(ret >= 0); + return false; + } + ctx->mem_buffer_mlocked = true; + return true; +} +#else // GGML_MLOCK_SUPPORT +bool ggml_mlock(struct ggml_context * ctx, char ** err_p) { + *err_p = strdup("can't mlock because it's not supported on this system"); + return false; +} +#endif // GGML_MLOCK_SUPPORT + //////////////////////////////////////////////////////////////////////////////// struct ggml_tensor * ggml_new_tensor_impl( @@ -1629,8 +2932,8 @@ struct ggml_tensor * ggml_new_tensor_impl( size_t size_needed = 0; if (data == NULL) { - size_needed += GGML_TYPE_SIZE[type]; - for (int i = 0; i < n_dims; i++) { + size_needed += GGML_TYPE_SIZE[type]*(ne[0]/GGML_BLCK_SIZE[type]); + for (int i = 1; i < n_dims; i++) { size_needed *= ne[i]; } // align to GGML_MEM_ALIGN @@ -1723,7 +3026,8 @@ struct ggml_tensor * ggml_new_tensor_impl( } result->nb[0] = GGML_TYPE_SIZE[type]; - for (int i = 1; i < GGML_MAX_DIMS; i++) { + result->nb[1] = result->nb[0]*(result->ne[0]/GGML_BLCK_SIZE[type]); + for (int i = 2; i < GGML_MAX_DIMS; i++) { result->nb[i] = result->nb[i - 1]*result->ne[i - 1]; } @@ -1820,6 +3124,14 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) { char * const data = tensor->data; switch (tensor->type) { + case GGML_TYPE_Q4_0: + { + GGML_ASSERT(false); + } break; + case GGML_TYPE_Q4_1: + { + GGML_ASSERT(false); + } break; case GGML_TYPE_I8: { assert(tensor->nb[0] == sizeof(int8_t)); @@ -1857,7 +3169,7 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) { } break; case GGML_TYPE_COUNT: { - assert(false); + GGML_ASSERT(false); } break; } @@ -1872,6 +3184,14 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) { char * const data = tensor->data; switch (tensor->type) { + case GGML_TYPE_Q4_0: + { + GGML_ASSERT(false); + } break; + case GGML_TYPE_Q4_1: + { + GGML_ASSERT(false); + } break; case GGML_TYPE_I8: { assert(tensor->nb[0] == sizeof(int8_t)); @@ -1909,7 +3229,7 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) { } break; case GGML_TYPE_COUNT: { - assert(false); + GGML_ASSERT(false); } break; } @@ -1918,6 +3238,14 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) { int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) { switch (tensor->type) { + case GGML_TYPE_Q4_0: + { + GGML_ASSERT(false); + } break; + case GGML_TYPE_Q4_1: + { + GGML_ASSERT(false); + } break; case GGML_TYPE_I8: { GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); @@ -1954,6 +3282,14 @@ int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) { void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) { switch (tensor->type) { + case GGML_TYPE_Q4_0: + { + GGML_ASSERT(false); + } break; + case GGML_TYPE_Q4_1: + { + GGML_ASSERT(false); + } break; case GGML_TYPE_I8: { GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); @@ -1988,6 +3324,14 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) { float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) { switch (tensor->type) { + case GGML_TYPE_Q4_0: + { + GGML_ASSERT(false); + } break; + case GGML_TYPE_Q4_1: + { + GGML_ASSERT(false); + } break; case GGML_TYPE_I8: { GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); @@ -2024,6 +3368,14 @@ float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) { void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) { switch (tensor->type) { + case GGML_TYPE_Q4_0: + { + GGML_ASSERT(false); + } break; + case GGML_TYPE_Q4_1: + { + GGML_ASSERT(false); + } break; case GGML_TYPE_I8: { GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); @@ -2114,7 +3466,7 @@ struct ggml_tensor * ggml_add_impl( struct ggml_tensor * a, struct ggml_tensor * b, bool inplace) { - assert(ggml_are_same_shape(a, b)); + GGML_ASSERT(ggml_are_same_shape(a, b)); bool is_node = false; @@ -2153,7 +3505,7 @@ struct ggml_tensor * ggml_sub_impl( struct ggml_tensor * a, struct ggml_tensor * b, bool inplace) { - assert(ggml_are_same_shape(a, b)); + GGML_ASSERT(ggml_are_same_shape(a, b)); bool is_node = false; @@ -2192,7 +3544,7 @@ struct ggml_tensor * ggml_mul_impl( struct ggml_tensor * a, struct ggml_tensor * b, bool inplace) { - assert(ggml_are_same_shape(a, b)); + GGML_ASSERT(ggml_are_same_shape(a, b)); bool is_node = false; @@ -2201,7 +3553,7 @@ struct ggml_tensor * ggml_mul_impl( } if (inplace) { - assert(is_node == false); + GGML_ASSERT(is_node == false); } struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); @@ -2235,7 +3587,7 @@ struct ggml_tensor * ggml_div_impl( struct ggml_tensor * a, struct ggml_tensor * b, bool inplace) { - assert(ggml_are_same_shape(a, b)); + GGML_ASSERT(ggml_are_same_shape(a, b)); bool is_node = false; @@ -2244,7 +3596,7 @@ struct ggml_tensor * ggml_div_impl( } if (inplace) { - assert(is_node == false); + GGML_ASSERT(is_node == false); } struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); @@ -2368,7 +3720,7 @@ struct ggml_tensor * ggml_mean( bool is_node = false; if (a->grad) { - assert(false); // TODO: implement + GGML_ASSERT(false); // TODO: implement is_node = true; } @@ -2389,7 +3741,7 @@ struct ggml_tensor * ggml_repeat( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b) { - assert(ggml_can_repeat(a, b)); + GGML_ASSERT(ggml_can_repeat(a, b)); bool is_node = false; @@ -2616,80 +3968,148 @@ struct ggml_tensor * ggml_gelu_inplace( return ggml_gelu_impl(ctx, a, true); } -// ggml_norm +// ggml_silu -struct ggml_tensor * ggml_norm_impl( +struct ggml_tensor * ggml_silu_impl( struct ggml_context * ctx, - struct ggml_tensor * a, + struct ggml_tensor * a, bool inplace) { bool is_node = false; if (!inplace && (a->grad)) { - assert(false); // TODO: implement backward is_node = true; } struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - result->op = GGML_OP_NORM; + result->op = GGML_OP_SILU; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src0 = a; - result->src1 = NULL; // TODO: maybe store epsilon here? + result->src1 = NULL; return result; } -struct ggml_tensor * ggml_norm( +struct ggml_tensor * ggml_silu( struct ggml_context * ctx, struct ggml_tensor * a) { - return ggml_norm_impl(ctx, a, false); + return ggml_silu_impl(ctx, a, false); } -struct ggml_tensor * ggml_norm_inplace( +struct ggml_tensor * ggml_silu_inplace( struct ggml_context * ctx, struct ggml_tensor * a) { - return ggml_norm_impl(ctx, a, true); + return ggml_silu_impl(ctx, a, true); } -// ggml_mul_mat +// ggml_norm -struct ggml_tensor * ggml_mul_mat( +struct ggml_tensor * ggml_norm_impl( struct ggml_context * ctx, struct ggml_tensor * a, - struct ggml_tensor * b) { - assert(ggml_can_mul_mat(a, b)); - + bool inplace) { bool is_node = false; - if (a->grad || b->grad) { + if (!inplace && (a->grad)) { + GGML_ASSERT(false); // TODO: implement backward is_node = true; } - const int ne[4] = { a->ne[1], b->ne[1], a->ne[2], b->ne[3] }; - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MIN(a->n_dims, b->n_dims), ne); + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - result->op = GGML_OP_MUL_MAT; + result->op = GGML_OP_NORM; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src0 = a; - result->src1 = b; + result->src1 = NULL; // TODO: maybe store epsilon here? return result; } -// ggml_scale - -struct ggml_tensor * ggml_scale_impl( +struct ggml_tensor * ggml_norm( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_norm_impl(ctx, a, false); +} + +struct ggml_tensor * ggml_norm_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_norm_impl(ctx, a, true); +} + +struct ggml_tensor * ggml_rms_norm_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_RMS_NORM; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; // TODO: maybe store epsilon here? + + return result; +} + +struct ggml_tensor * ggml_rms_norm( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_rms_norm_impl(ctx, a, false); +} + +struct ggml_tensor * ggml_rms_norm_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_rms_norm_impl(ctx, a, true); +} + +// ggml_mul_mat + +struct ggml_tensor * ggml_mul_mat( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + GGML_ASSERT(ggml_can_mul_mat(a, b)); + GGML_ASSERT(!ggml_is_transposed(a)); + + bool is_node = false; + + if (a->grad || b->grad) { + is_node = true; + } + + const int ne[4] = { a->ne[1], b->ne[1], a->ne[2], b->ne[3] }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MIN(a->n_dims, b->n_dims), ne); + + result->op = GGML_OP_MUL_MAT; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + +// ggml_scale + +struct ggml_tensor * ggml_scale_impl( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, bool inplace) { - assert(ggml_is_scalar(b)); - assert(ggml_is_padded_1d(a)); + GGML_ASSERT(ggml_is_scalar(b)); + GGML_ASSERT(ggml_is_padded_1d(a)); bool is_node = false; if (!inplace && (a->grad || b->grad)) { - assert(false); // TODO: implement backward + GGML_ASSERT(false); // TODO: implement backward is_node = true; } @@ -2726,12 +4146,12 @@ struct ggml_tensor * ggml_cpy_impl( struct ggml_tensor * a, struct ggml_tensor * b, bool inplace) { - assert(ggml_nelements(a) == ggml_nelements(b)); + GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b)); bool is_node = false; if (!inplace && (a->grad || b->grad)) { - assert(false); // TODO: implement backward + GGML_ASSERT(false); // TODO: implement backward is_node = true; } @@ -2766,14 +4186,14 @@ struct ggml_tensor * ggml_reshape( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b) { - assert(ggml_is_contiguous(a)); - assert(ggml_is_contiguous(b)); - assert(ggml_nelements(a) == ggml_nelements(b)); + GGML_ASSERT(ggml_is_contiguous(a)); + GGML_ASSERT(ggml_is_contiguous(b)); + GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b)); bool is_node = false; if (a->grad || b->grad) { - assert(false); // TODO: implement backward + GGML_ASSERT(false); // TODO: implement backward is_node = true; } @@ -2792,13 +4212,13 @@ struct ggml_tensor * ggml_reshape_2d( struct ggml_tensor * a, int ne0, int ne1) { - assert(ggml_is_contiguous(a)); - assert(ggml_nelements(a) == ne0*ne1); + GGML_ASSERT(ggml_is_contiguous(a)); + GGML_ASSERT(ggml_nelements(a) == ne0*ne1); bool is_node = false; if (a->grad) { - assert(false); // TODO: implement backward + GGML_ASSERT(false); // TODO: implement backward is_node = true; } @@ -2819,13 +4239,13 @@ struct ggml_tensor * ggml_reshape_3d( int ne0, int ne1, int ne2) { - assert(ggml_is_contiguous(a)); - assert(ggml_nelements(a) == ne0*ne1*ne2); + GGML_ASSERT(ggml_is_contiguous(a)); + GGML_ASSERT(ggml_nelements(a) == ne0*ne1*ne2); bool is_node = false; if (a->grad) { - assert(false); // TODO: implement backward + GGML_ASSERT(false); // TODO: implement backward is_node = true; } @@ -2848,7 +4268,7 @@ struct ggml_tensor * ggml_view_1d( int ne0, size_t offset) { if (a->grad) { - assert(false); // gradient propagation is not supported + GGML_ASSERT(false); // gradient propagation is not supported } struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 1, &ne0, (char *) a->data + offset); @@ -2871,7 +4291,7 @@ struct ggml_tensor * ggml_view_2d( size_t nb1, size_t offset) { if (a->grad) { - assert(false); // gradient propagation is not supported + GGML_ASSERT(false); // gradient propagation is not supported } const int ne[GGML_MAX_DIMS] = { ne0, ne1, 1, 1 }; @@ -2899,22 +4319,22 @@ struct ggml_tensor * ggml_permute( int axis1, int axis2, int axis3) { - assert(axis0 >= 0 && axis0 < GGML_MAX_DIMS); - assert(axis1 >= 0 && axis1 < GGML_MAX_DIMS); - assert(axis2 >= 0 && axis2 < GGML_MAX_DIMS); - assert(axis3 >= 0 && axis3 < GGML_MAX_DIMS); - - assert(axis0 != axis1); - assert(axis0 != axis2); - assert(axis0 != axis3); - assert(axis1 != axis2); - assert(axis1 != axis3); - assert(axis2 != axis3); + GGML_ASSERT(axis0 >= 0 && axis0 < GGML_MAX_DIMS); + GGML_ASSERT(axis1 >= 0 && axis1 < GGML_MAX_DIMS); + GGML_ASSERT(axis2 >= 0 && axis2 < GGML_MAX_DIMS); + GGML_ASSERT(axis3 >= 0 && axis3 < GGML_MAX_DIMS); + + GGML_ASSERT(axis0 != axis1); + GGML_ASSERT(axis0 != axis2); + GGML_ASSERT(axis0 != axis3); + GGML_ASSERT(axis1 != axis2); + GGML_ASSERT(axis1 != axis3); + GGML_ASSERT(axis2 != axis3); bool is_node = false; if (a->grad) { - assert(false); // TODO: implement backward + GGML_ASSERT(false); // TODO: implement backward is_node = true; } @@ -2959,7 +4379,7 @@ struct ggml_tensor * ggml_transpose( bool is_node = false; if (a->grad) { - assert(false); // TODO: implement backward + GGML_ASSERT(false); // TODO: implement backward is_node = true; } @@ -2985,12 +4405,12 @@ struct ggml_tensor * ggml_get_rows( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b) { - assert(ggml_is_matrix(a) && ggml_is_vector(b) && b->type == GGML_TYPE_I32); + GGML_ASSERT(ggml_is_matrix(a) && ggml_is_vector(b) && b->type == GGML_TYPE_I32); bool is_node = false; if (a->grad || b->grad) { - assert(false); // TODO: implement backward + GGML_ASSERT(false); // TODO: implement backward is_node = true; } @@ -3015,7 +4435,7 @@ struct ggml_tensor * ggml_diag_mask_inf( bool is_node = false; if (a->grad) { - assert(false); // TODO: implement backward + GGML_ASSERT(false); // TODO: implement backward is_node = true; } @@ -3040,7 +4460,7 @@ struct ggml_tensor * ggml_soft_max( bool is_node = false; if (a->grad) { - assert(false); // TODO: implement backward + GGML_ASSERT(false); // TODO: implement backward is_node = true; } @@ -3064,11 +4484,11 @@ struct ggml_tensor * ggml_rope( int n_past, int n_dims, int mode) { - assert(n_past >= 0); + GGML_ASSERT(n_past >= 0); bool is_node = false; if (a->grad) { - assert(false); // TODO: implement backward + GGML_ASSERT(false); // TODO: implement backward is_node = true; } @@ -3095,13 +4515,13 @@ struct ggml_tensor * ggml_conv_1d_1s( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b) { - assert(ggml_is_matrix(b)); - assert(a->ne[1] == b->ne[1]); - assert(a->ne[3] == 1); + GGML_ASSERT(ggml_is_matrix(b)); + GGML_ASSERT(a->ne[1] == b->ne[1]); + GGML_ASSERT(a->ne[3] == 1); bool is_node = false; if (a->grad || b->grad) { - assert(false); // TODO: implement backward + GGML_ASSERT(false); // TODO: implement backward is_node = true; } @@ -3122,13 +4542,13 @@ struct ggml_tensor * ggml_conv_1d_2s( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b) { - assert(ggml_is_matrix(b)); - assert(a->ne[1] == b->ne[1]); - assert(a->ne[3] == 1); + GGML_ASSERT(ggml_is_matrix(b)); + GGML_ASSERT(a->ne[1] == b->ne[1]); + GGML_ASSERT(a->ne[3] == 1); bool is_node = false; if (a->grad || b->grad) { - assert(false); // TODO: implement backward + GGML_ASSERT(false); // TODO: implement backward is_node = true; } @@ -3151,7 +4571,7 @@ struct ggml_tensor * ggml_flash_attn( struct ggml_tensor * k, struct ggml_tensor * v, bool masked) { - assert(ggml_can_mul_mat(k, q)); + GGML_ASSERT(ggml_can_mul_mat(k, q)); // TODO: check if vT can be multiplied by (k*qT) bool is_node = false; @@ -3183,7 +4603,7 @@ struct ggml_tensor * ggml_flash_ff( struct ggml_tensor * b1, struct ggml_tensor * c0, struct ggml_tensor * c1) { - assert(ggml_can_mul_mat(b0, a)); + GGML_ASSERT(ggml_can_mul_mat(b0, a)); // TODO: more checks bool is_node = false; @@ -3214,7 +4634,7 @@ void ggml_set_param( struct ggml_tensor * tensor) { tensor->is_param = true; - assert(tensor->grad == NULL); + GGML_ASSERT(tensor->grad == NULL); tensor->grad = ggml_dup_tensor(ctx, tensor); } @@ -3224,9 +4644,9 @@ static void ggml_compute_forward_dup_f16( const struct ggml_compute_params * params, const struct ggml_tensor * src0, struct ggml_tensor * dst) { - assert(params->ith == 0); - assert(ggml_is_contiguous(dst)); - assert(ggml_nelements(dst) == ggml_nelements(src0)); + GGML_ASSERT(params->ith == 0); + GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0)); if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; @@ -3249,7 +4669,7 @@ static void ggml_compute_forward_dup_f16( if (src0->nb[0] == sizeof(ggml_fp16_t)) { if (dst->type == GGML_TYPE_F16) { - int id = 0; + size_t id = 0; const size_t rs = ne00*nb00; for (int i03 = 0; i03 < ne03; i03++) { @@ -3265,7 +4685,7 @@ static void ggml_compute_forward_dup_f16( } } } else if (dst->type == GGML_TYPE_F32) { - int id = 0; + size_t id = 0; float * dst_ptr = (float *) dst->data; for (int i03 = 0; i03 < ne03; i03++) { @@ -3287,7 +4707,7 @@ static void ggml_compute_forward_dup_f16( //printf("%s: this is not optimal - fix me\n", __func__); if (dst->type == GGML_TYPE_F32) { - int id = 0; + size_t id = 0; float * dst_ptr = (float *) dst->data; for (int i03 = 0; i03 < ne03; i03++) { @@ -3303,7 +4723,7 @@ static void ggml_compute_forward_dup_f16( } } } else if (dst->type == GGML_TYPE_F16) { - int id = 0; + size_t id = 0; ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data; for (int i03 = 0; i03 < ne03; i03++) { @@ -3353,7 +4773,7 @@ static void ggml_compute_forward_dup_f32( if (src0->nb[0] == sizeof(float)) { if (dst->type == GGML_TYPE_F32) { - int id = 0; + size_t id = 0; const size_t rs = ne00*nb00; for (int i03 = 0; i03 < ne03; i03++) { @@ -3369,7 +4789,7 @@ static void ggml_compute_forward_dup_f32( } } } else if (dst->type == GGML_TYPE_F16) { - int id = 0; + size_t id = 0; ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data; for (int i03 = 0; i03 < ne03; i03++) { @@ -3391,7 +4811,7 @@ static void ggml_compute_forward_dup_f32( //printf("%s: this is not optimal - fix me\n", __func__); if (dst->type == GGML_TYPE_F32) { - int id = 0; + size_t id = 0; float * dst_ptr = (float *) dst->data; for (int i03 = 0; i03 < ne03; i03++) { @@ -3407,7 +4827,7 @@ static void ggml_compute_forward_dup_f32( } } } else if (dst->type == GGML_TYPE_F16) { - int id = 0; + size_t id = 0; ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data; for (int i03 = 0; i03 < ne03; i03++) { @@ -3441,6 +4861,8 @@ static void ggml_compute_forward_dup( { ggml_compute_forward_dup_f32(params, src0, dst); } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -3516,13 +4938,15 @@ static void ggml_compute_forward_add( { ggml_compute_forward_add_f32(params, src0, src1, dst); } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: case GGML_TYPE_F16: case GGML_TYPE_COUNT: { - assert(false); + GGML_ASSERT(false); } break; } } @@ -3566,13 +4990,15 @@ static void ggml_compute_forward_sub( { ggml_compute_forward_sub_f32(params, src0, src1, dst); } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: case GGML_TYPE_F16: case GGML_TYPE_COUNT: { - assert(false); + GGML_ASSERT(false); } break; } } @@ -3616,13 +5042,15 @@ static void ggml_compute_forward_mul( { ggml_compute_forward_mul_f32(params, src0, src1, dst); } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: case GGML_TYPE_F16: case GGML_TYPE_COUNT: { - assert(false); + GGML_ASSERT(false); } break; } } @@ -3666,13 +5094,15 @@ static void ggml_compute_forward_div( { ggml_compute_forward_div_f32(params, src0, src1, dst); } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: case GGML_TYPE_F16: case GGML_TYPE_COUNT: { - assert(false); + GGML_ASSERT(false); } break; } } @@ -3712,13 +5142,15 @@ static void ggml_compute_forward_sqr( { ggml_compute_forward_sqr_f32(params, src0, dst); } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: case GGML_TYPE_F16: case GGML_TYPE_COUNT: { - assert(false); + GGML_ASSERT(false); } break; } } @@ -3758,13 +5190,15 @@ static void ggml_compute_forward_sqrt( { ggml_compute_forward_sqrt_f32(params, src0, dst); } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: case GGML_TYPE_F16: case GGML_TYPE_COUNT: { - assert(false); + GGML_ASSERT(false); } break; } } @@ -3814,13 +5248,15 @@ static void ggml_compute_forward_sum( { ggml_compute_forward_sum_f32(params, src0, dst); } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: case GGML_TYPE_F16: case GGML_TYPE_COUNT: { - assert(false); + GGML_ASSERT(false); } break; } } @@ -3889,13 +5325,15 @@ static void ggml_compute_forward_mean( { ggml_compute_forward_mean_f32(params, src0, dst); } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: case GGML_TYPE_F16: case GGML_TYPE_COUNT: { - assert(false); + GGML_ASSERT(false); } break; } } @@ -3951,13 +5389,15 @@ static void ggml_compute_forward_repeat( { ggml_compute_forward_repeat_f32(params, src0, dst); } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: case GGML_TYPE_F16: case GGML_TYPE_COUNT: { - assert(false); + GGML_ASSERT(false); } break; } } @@ -3997,13 +5437,15 @@ static void ggml_compute_forward_abs( { ggml_compute_forward_abs_f32(params, src0, dst); } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: case GGML_TYPE_F16: case GGML_TYPE_COUNT: { - assert(false); + GGML_ASSERT(false); } break; } } @@ -4043,13 +5485,15 @@ static void ggml_compute_forward_sgn( { ggml_compute_forward_sgn_f32(params, src0, dst); } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: case GGML_TYPE_F16: case GGML_TYPE_COUNT: { - assert(false); + GGML_ASSERT(false); } break; } } @@ -4089,13 +5533,15 @@ static void ggml_compute_forward_neg( { ggml_compute_forward_neg_f32(params, src0, dst); } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: case GGML_TYPE_F16: case GGML_TYPE_COUNT: { - assert(false); + GGML_ASSERT(false); } break; } } @@ -4135,13 +5581,15 @@ static void ggml_compute_forward_step( { ggml_compute_forward_step_f32(params, src0, dst); } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: case GGML_TYPE_F16: case GGML_TYPE_COUNT: { - assert(false); + GGML_ASSERT(false); } break; } } @@ -4181,13 +5629,15 @@ static void ggml_compute_forward_relu( { ggml_compute_forward_relu_f32(params, src0, dst); } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: case GGML_TYPE_F16: case GGML_TYPE_COUNT: { - assert(false); + GGML_ASSERT(false); } break; } } @@ -4244,17 +5694,87 @@ static void ggml_compute_forward_gelu( { ggml_compute_forward_gelu_f32(params, src0, dst); } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_F16: + case GGML_TYPE_COUNT: + { + GGML_ASSERT(false); + } break; + } + + //printf("XXXXXXXX gelu\n"); +} + +// ggml_compute_forward_silu + +static void ggml_compute_forward_silu_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src0->ne[0]; + const int nr = ggml_nrows(src0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + ggml_vec_silu_f32(nc, + (float *) ((char *) dst->data + i1*( dst->nb[1])), + (float *) ((char *) src0->data + i1*(src0->nb[1]))); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + UNUSED(x); + assert(!isnan(x)); + assert(!isinf(x)); + } +#endif + } +} + +static void ggml_compute_forward_silu( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_silu_f32(params, src0, dst); + } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: case GGML_TYPE_F16: case GGML_TYPE_COUNT: { - assert(false); + GGML_ASSERT(false); } break; } } + // ggml_compute_forward_norm static void ggml_compute_forward_norm_f32( @@ -4285,7 +5805,7 @@ static void ggml_compute_forward_norm_f32( const size_t nb2 = dst->nb[2]; const size_t nb3 = dst->nb[3]; - const ggml_float eps = 1e-5f; // TODO: make this a parameter + const float eps = 1e-5f; // TODO: make this a parameter // TODO: optimize for (int i03 = 0; i03 < ne03; i03++) { @@ -4293,23 +5813,24 @@ static void ggml_compute_forward_norm_f32( for (int i01 = ith; i01 < ne01; i01 += nth) { const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); - ggml_float mean = 0.0; + ggml_float sum = 0.0; for (int i00 = 0; i00 < ne00; i00++) { - mean += x[i00]; + sum += (ggml_float)x[i00]; } - mean /= ne00; + float mean = sum/ne00; float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); ggml_float sum2 = 0.0; for (int i00 = 0; i00 < ne00; i00++) { - ggml_float v = x[i00] - mean; + float v = x[i00] - mean; y[i00] = v; - sum2 += v*v; + sum2 += (ggml_float)(v*v); } - const float scale = 1.0/sqrt(sum2/ne00 + eps); + float variance = sum2/ne00; + const float scale = 1.0f/sqrtf(variance + eps); ggml_vec_scale_f32(ne00, y, scale); } @@ -4326,17 +5847,100 @@ static void ggml_compute_forward_norm( { ggml_compute_forward_norm_f32(params, src0, dst); } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_F16: + case GGML_TYPE_COUNT: + { + GGML_ASSERT(false); + } break; + } +} + +static void ggml_compute_forward_rms_norm_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; + + const size_t nb01 = src0->nb[1]; + const size_t nb02 = src0->nb[2]; + const size_t nb03 = src0->nb[3]; + + const size_t nb1 = dst->nb[1]; + const size_t nb2 = dst->nb[2]; + const size_t nb3 = dst->nb[3]; + + const float eps = 1e-6f; // TODO: make this a parameter + + // TODO: optimize + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = ith; i01 < ne01; i01 += nth) { + const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + + ggml_float sum = 0.0; + for (int i00 = 0; i00 < ne00; i00++) { + sum += (ggml_float)(x[i00] * x[i00]); + } + + float mean = sum/ne00; + + float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); + + memcpy(y, x, ne00 * sizeof(float)); + // for (int i00 = 0; i00 < ne00; i00++) { + // y[i00] = x[i00]; + // } + + const float scale = 1.0f/sqrtf(mean + eps); + + ggml_vec_scale_f32(ne00, y, scale); + } + } + } +} + +static void ggml_compute_forward_rms_norm( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_rms_norm_f32(params, src0, dst); + } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: case GGML_TYPE_F16: case GGML_TYPE_COUNT: { - assert(false); + GGML_ASSERT(false); } break; } } + // ggml_compute_forward_mul_mat #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) @@ -4346,7 +5950,8 @@ static bool ggml_compute_forward_mul_mat_use_blas( const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) { - UNUSED(src0); + //const int ne00 = src0->ne[0]; + //const int ne01 = src0->ne[1]; const int ne10 = src1->ne[0]; @@ -4354,10 +5959,10 @@ static bool ggml_compute_forward_mul_mat_use_blas( const int ne1 = dst->ne[1]; // TODO: find the optimal values for these - if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ( - (ne0 >= 32 && ne1 >= 32 && ne10 >= 32) - )) { - //printf("BLAS: %d %d %d\n", ne0, ne1, ne10); + if (ggml_is_contiguous(src0) && + ggml_is_contiguous(src1) && ((ne0 >= 32 && ne1 >= 32 && ne10 >= 32))) { + + /*printf("BLAS: %d %d %d %d %d\n", ne0, ne1, ne10, ne00, ne01);*/ return true; } @@ -4378,8 +5983,11 @@ static void ggml_compute_forward_mul_mat_f32( const int ne02 = src0->ne[2]; const int ne03 = src0->ne[3]; +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) const int ne10 = src1->ne[0]; +#endif const int ne11 = src1->ne[1]; +#ifndef NDEBUG const int ne12 = src1->ne[2]; const int ne13 = src1->ne[3]; @@ -4387,14 +5995,16 @@ static void ggml_compute_forward_mul_mat_f32( const int ne1 = dst->ne[1]; const int ne2 = dst->ne[2]; const int ne3 = dst->ne[3]; - const int ne = ne0*ne1*ne2*ne3; const int nb00 = src0->nb[0]; +#endif const int nb01 = src0->nb[1]; const int nb02 = src0->nb[2]; const int nb03 = src0->nb[3]; +#ifndef NDEBUG const int nb10 = src1->nb[0]; +#endif const int nb11 = src1->nb[1]; const int nb12 = src1->nb[2]; const int nb13 = src1->nb[3]; @@ -4412,8 +6022,9 @@ static void ggml_compute_forward_mul_mat_f32( assert(ne2 == ne12); assert(ne3 == ne13); - // TODO: we don't support permuted src0 - assert(nb00 == sizeof(float) || nb01 == sizeof(float)); + // we don't support permuted src0 or src1 + assert(nb00 == sizeof(float)); + assert(nb10 == sizeof(float)); // dst cannot be transposed or permuted assert(nb0 == sizeof(float)); @@ -4428,14 +6039,9 @@ static void ggml_compute_forward_mul_mat_f32( // nb01 >= nb00 - src0 is not transposed // compute by src0 rows - // - // nb00 < nb01 - src0 is transposed - // compute by src0 columns #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) { - GGML_ASSERT(nb10 == sizeof(float)); - if (params->ith != 0) { return; } @@ -4450,19 +6056,17 @@ static void ggml_compute_forward_mul_mat_f32( for (int i03 = 0; i03 < ne03; i03++) { for (int i02 = 0; i02 < ne02; i02++) { - const float * x = (float *) (src0->data); + const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03); const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13); float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); // zT = y * xT - { - cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, - ne11, ne01, ne10, - 1.0f, y, ne10, - x, ne10, - 0.0f, d, ne01); - } + cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, + ne11, ne01, ne10, + 1.0f, y, ne10, + x, ne10, + 0.0f, d, ne01); } } @@ -4473,130 +6077,242 @@ static void ggml_compute_forward_mul_mat_f32( #endif if (params->type == GGML_TASK_INIT) { - if (nb01 >= nb00) { - return; - } - - // TODO: fix this memset (wsize is overestimated) - memset(params->wdata, 0, params->wsize); return; } if (params->type == GGML_TASK_FINALIZE) { - if (nb01 >= nb00) { + return; + } + + // parallelize by src0 rows using ggml_vec_dot_f32 + + // total rows in src0 + const int nr = ne01*ne02*ne03; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int ir = ir0; ir < ir1; ++ir) { + // src0 indices + const int i03 = ir/(ne02*ne01); + const int i02 = (ir - i03*ne02*ne01)/ne01; + const int i01 = (ir - i03*ne02*ne01 - i02*ne01); + + for (int ic = 0; ic < ne11; ++ic) { + // src1 indices + const int i13 = i03; + const int i12 = i02; + const int i11 = ic; + + // dst indices + const int i0 = i01; + const int i1 = i11; + const int i2 = i02; + const int i3 = i03; + + ggml_vec_dot_f32(ne00, + (float *) ((char *) dst->data + (i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), + (float *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)), + (float *) ((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13))); + } + } + + //int64_t t1 = ggml_perf_time_us(); + //static int64_t acc = 0; + //acc += t1 - t0; + //if (t1 - t0 > 10) { + // printf("\n"); + // printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03); + // printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03); + // printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13); + // printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13); + + // printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc); + //} +} + +static void ggml_compute_forward_mul_mat_f16_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; + + const int ne10 = src1->ne[0]; + const int ne11 = src1->ne[1]; + const int ne12 = src1->ne[2]; + const int ne13 = src1->ne[3]; + + const int ne0 = dst->ne[0]; + const int ne1 = dst->ne[1]; + const int ne2 = dst->ne[2]; + const int ne3 = dst->ne[3]; + //const int ne = ne0*ne1*ne2*ne3; + + const int nb00 = src0->nb[0]; + const int nb01 = src0->nb[1]; + const int nb02 = src0->nb[2]; + const int nb03 = src0->nb[3]; + + const int nb10 = src1->nb[0]; + const int nb11 = src1->nb[1]; + const int nb12 = src1->nb[2]; + const int nb13 = src1->nb[3]; + + const int nb0 = dst->nb[0]; + const int nb1 = dst->nb[1]; + const int nb2 = dst->nb[2]; + const int nb3 = dst->nb[3]; + + const int ith = params->ith; + const int nth = params->nth; + + GGML_ASSERT(ne02 == ne12); + GGML_ASSERT(ne03 == ne13); + GGML_ASSERT(ne2 == ne12); + GGML_ASSERT(ne3 == ne13); + + // TODO: we don't support permuted src0 + GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + GGML_ASSERT(ne0 == ne01); + GGML_ASSERT(ne1 == ne11); + GGML_ASSERT(ne2 == ne02); + GGML_ASSERT(ne3 == ne03); + + // nb01 >= nb00 - src0 is not transposed + // compute by src0 rows + +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) + if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) { + GGML_ASSERT(nb10 == sizeof(float)); + + if (params->ith != 0) { return; } - // TODO: fix this memset (wsize is overestimated) - //assert(params->wsize == (ggml_nbytes(dst) + CACHE_LINE_SIZE)*nth); + if (params->type == GGML_TASK_INIT) { + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } float * const wdata = params->wdata; - // cols per thread - const int dc = (ne + nth - 1)/nth; + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + { + size_t id = 0; + for (int i01 = 0; i01 < ne01; ++i01) { + for (int i00 = 0; i00 < ne00; ++i00) { + wdata[id++] = GGML_FP16_TO_FP32(*(ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00)); + } + } + } + + const float * x = wdata; + const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13); + + float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); + + // zT = y * xT + cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, + ne11, ne01, ne10, + 1.0f, y, ne10, + x, ne10, + 0.0f, d, ne01); + } + } + + /*printf("CBLAS F16 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/ - // col range for this thread - const int ic0 = dc*ith; - const int ic1 = MIN(ic0 + dc, ne); + return; + } +#endif - ggml_vec_cpy_f32(ic1 - ic0, (float *) dst->data + ic0, wdata + ic0); + if (params->type == GGML_TASK_INIT) { + ggml_fp16_t * const wdata = params->wdata; - for (int k = 1; k < nth; k++) { - ggml_vec_acc_f32(ic1 - ic0, (float *) dst->data + ic0, wdata + (ne + CACHE_LINE_SIZE_F32)*k + ic0); + size_t id = 0; + for (int i13 = 0; i13 < ne13; ++i13) { + for (int i12 = 0; i12 < ne12; ++i12) { + for (int i11 = 0; i11 < ne11; ++i11) { + for (int i10 = 0; i10 < ne10; ++i10) { + wdata[id++] = GGML_FP32_TO_FP16(*(float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10)); + } + } + } } + GGML_ASSERT(id*sizeof(ggml_fp16_t) <= params->wsize); + return; } - if (nb01 >= nb00) { - // TODO: do not support transposed src1 - assert(nb10 == sizeof(float)); + if (params->type == GGML_TASK_FINALIZE) { + return; + } - // parallelize by src0 rows using ggml_vec_dot_f32 + // fp16 -> half the size, so divide by 2 + // TODO: do not support transposed src1 + assert(nb10/2 == sizeof(ggml_fp16_t)); - // total rows in src0 - const int nr = ne01*ne02*ne03; + // parallelize by src0 rows using ggml_vec_dot_f16 - // rows per thread - const int dr = (nr + nth - 1)/nth; + // total rows in src0 + const int nr = ne01*ne02*ne03; - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); + // rows per thread + const int dr = (nr + nth - 1)/nth; - for (int ir = ir0; ir < ir1; ++ir) { - // src0 indices - const int i03 = ir/(ne02*ne01); - const int i02 = (ir - i03*ne02*ne01)/ne01; - const int i01 = (ir - i03*ne02*ne01 - i02*ne01); + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); - for (int ic = 0; ic < ne11; ++ic) { - // src1 indices - const int i13 = i03; - const int i12 = i02; - const int i11 = ic; + ggml_fp16_t * wdata = params->wdata; - // dst indices - const int i0 = i01; - const int i1 = i11; - const int i2 = i02; - const int i3 = i03; - - ggml_vec_dot_f32(ne00, - (float *) ((char *) dst->data + (i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), - (float *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)), - (float *) ((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13))); - } - } - } else { - // parallelize by src1 columns using ggml_vec_mad_f32 - // each thread has its own work data - // during FINALIZE we accumulate all work data into dst + for (int ir = ir0; ir < ir1; ++ir) { + // src0 indices + const int i03 = ir/(ne02*ne01); + const int i02 = (ir - i03*ne02*ne01)/ne01; + const int i01 = (ir - i03*ne02*ne01 - i02*ne01); - // total columns in src1 - const int nc = ne10; + const int i13 = i03; + const int i12 = i02; - // columns per thread - const int dc = (nc + nth - 1)/nth; + const int i0 = i01; + const int i2 = i02; + const int i3 = i03; - // column range for this thread - const int ic0 = dc*ith; - const int ic1 = MIN(ic0 + dc, nc); + ggml_fp16_t * src0_row = (ggml_fp16_t *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)); + ggml_fp16_t * src1_col = wdata + ( 0 + i12*ne11 + i13*ne12*ne11)*ne00; - // work data for thread - const int wo = (ne + CACHE_LINE_SIZE_F32)*ith; - float * const wdata = params->wdata; + float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3)); - for (int i13 = 0; i13 < ne13; ++i13) { - for (int i12 = 0; i12 < ne12; ++i12) { - for (int i11 = 0; i11 < ne11; ++i11) { - for (int ic = ic0; ic < ic1; ++ic) { - // src1 indices - const int i10 = ic; - - // src0 indices - const int i03 = i13; - const int i02 = i12; - const int i00 = ic; - - // dst indices - const int i1 = i11; - const int i2 = i12; - const int i3 = i13; - - assert(sizeof(float)*(wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + ne01) <= params->wsize); - - ggml_vec_mad_f32(ne01, - (float *) (wdata + wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0), - (float *) ((char *) src0->data + (i00*nb00 + i02*nb02 + i03*nb03)), - *(float *) ((char *) src1->data + (i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13))); - } - } - } + for (int ic = 0; ic < ne11; ++ic) { + ggml_vec_dot_f16(ne00, &dst_col[ic*ne0], src0_row, src1_col + ic*ne00); } } - //int64_t t1 = ggml_perf_time_us(); + //int64_t t1 = ggml_time_us(); //static int64_t acc = 0; //acc += t1 - t0; //if (t1 - t0 > 10) { @@ -4604,13 +6320,35 @@ static void ggml_compute_forward_mul_mat_f32( // printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03); // printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03); // printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13); - // printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13); // printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc); //} } -static void ggml_compute_forward_mul_mat_f16_f32( +typedef void (*dequantize_row_q_t)(const void * restrict x, float * restrict y, int k); +typedef void (*quantize_row_q_t)(const float * restrict x, void * restrict y, int k); +typedef void (*vec_dot_q_t)(const int n, float * restrict s, const void * restrict x, const void * restrict y); + +typedef struct { + dequantize_row_q_t dequantize_row_q; + quantize_row_q_t quantize_row_q; + vec_dot_q_t vec_dot_q; +} quantize_fns_t; + +static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = { + [GGML_TYPE_Q4_0] = { + .dequantize_row_q = dequantize_row_q4_0, + .quantize_row_q = quantize_row_q4_0, + .vec_dot_q = ggml_vec_dot_q4_0, + }, + [GGML_TYPE_Q4_1] = { + .dequantize_row_q = dequantize_row_q4_1, + .quantize_row_q = quantize_row_q4_1, + .vec_dot_q = ggml_vec_dot_q4_1, + }, +}; + +static void ggml_compute_forward_mul_mat_q_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, const struct ggml_tensor * src1, @@ -4632,7 +6370,6 @@ static void ggml_compute_forward_mul_mat_f16_f32( const int ne1 = dst->ne[1]; const int ne2 = dst->ne[2]; const int ne3 = dst->ne[3]; - const int ne = ne0*ne1*ne2*ne3; const int nb00 = src0->nb[0]; const int nb01 = src0->nb[1]; @@ -4657,8 +6394,13 @@ static void ggml_compute_forward_mul_mat_f16_f32( GGML_ASSERT(ne2 == ne12); GGML_ASSERT(ne3 == ne13); - // TODO: we don't support permuted src0 - GGML_ASSERT(nb00 == sizeof(ggml_fp16_t) || nb01 == sizeof(ggml_fp16_t)); + const enum ggml_type type = src0->type; + quantize_row_q_t const quantize_row_q = quantize_fns[type].quantize_row_q; + vec_dot_q_t const vec_dot_q = quantize_fns[type].vec_dot_q; + + // we don't support permuted src0 or src1 + GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]); + GGML_ASSERT(nb10 == sizeof(float)); // dst cannot be transposed or permuted GGML_ASSERT(nb0 == sizeof(float)); @@ -4673,14 +6415,9 @@ static void ggml_compute_forward_mul_mat_f16_f32( // nb01 >= nb00 - src0 is not transposed // compute by src0 rows - // - // nb00 < nb01 - src0 is transposed - // compute by src0 columns #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) { - GGML_ASSERT(nb10 == sizeof(float)); - if (params->ith != 0) { return; } @@ -4694,58 +6431,29 @@ static void ggml_compute_forward_mul_mat_f16_f32( } float * const wdata = params->wdata; + dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q; for (int i03 = 0; i03 < ne03; i03++) { for (int i02 = 0; i02 < ne02; i02++) { { - int id = 0; + size_t id = 0; for (int i01 = 0; i01 < ne01; ++i01) { - for (int i00 = 0; i00 < ne00; ++i00) { - wdata[id++] = GGML_FP16_TO_FP32(*(ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00)); - } + dequantize_row_q((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01, wdata + id, ne00); + id += ne00; } } const float * x = wdata; const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13); - // float * z = wdata + ne00*ne01; - - // z = x * yT - //{ - // cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, - // ne01, ne11, ne00, - // 1.0f, x, ne00, - // y, ne00, - // 0.0f, z, ne11); - //} - float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); - // transpose z - //for (int j = 0; j < ne11; ++j) { - // for (int i = 0; i < ne01; ++i) { - // d[j*ne01 + i] = z[i*ne11 + j]; - // } - //} - - { -#if 1 - // zT = y * xT - cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, - ne11, ne01, ne10, - 1.0f, y, ne00, - x, ne00, - 0.0f, d, ne01); -#else - // zT = (xT * y)T - cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, - ne01, ne11, ne10, - 1.0f, x, ne00, - y, ne00, - 0.0f, d, ne01); -#endif - } + // zT = y * xT + cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, + ne11, ne01, ne10, + 1.0f, y, ne10, + x, ne10, + 0.0f, d, ne01); } } @@ -4756,150 +6464,62 @@ static void ggml_compute_forward_mul_mat_f16_f32( #endif if (params->type == GGML_TASK_INIT) { - if (nb01 >= nb00) { - ggml_fp16_t * const wdata = params->wdata; - - int id = 0; - for (int i13 = 0; i13 < ne13; ++i13) { - for (int i12 = 0; i12 < ne12; ++i12) { - for (int i11 = 0; i11 < ne11; ++i11) { - for (int i10 = 0; i10 < ne10; ++i10) { - wdata[id++] = GGML_FP32_TO_FP16(*(float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10)); - } - } + char * wdata = params->wdata; + const size_t row_size = ne10*GGML_TYPE_SIZE[type]/GGML_BLCK_SIZE[type]; + + for (int i13 = 0; i13 < ne13; ++i13) { + for (int i12 = 0; i12 < ne12; ++i12) { + for (int i11 = 0; i11 < ne11; ++i11) { + quantize_row_q((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10); + wdata += row_size; } } - - GGML_ASSERT(id*sizeof(ggml_fp16_t) <= params->wsize); - - return; } - // TODO: fix this memset (wsize is overestimated) - memset(params->wdata, 0, params->wsize); return; } if (params->type == GGML_TASK_FINALIZE) { - if (nb01 >= nb00) { - return; - } - - // TODO: fix this memset (wsize is overestimated) - //assert(params->wsize == (ggml_nbytes(dst) + CACHE_LINE_SIZE)*nth); - - ggml_fp16_t * const wdata = params->wdata; - - // cols per thread - const int dc = (ne + nth - 1)/nth; - - // col range for this thread - const int ic0 = dc*ith; - const int ic1 = MIN(ic0 + dc, ne); - - for (int i = ic0; i < ic1; ++i) { - ((float *) dst->data)[i] = GGML_FP16_TO_FP32(wdata[i]); - } - - for (int k = 1; k < nth; k++) { - for (int i = ic0; i < ic1; ++i) { - ((float *) dst->data)[i] += GGML_FP16_TO_FP32(wdata[(ne + CACHE_LINE_SIZE_F32)*k + i]); - } - } - return; } - if (nb01 >= nb00) { - // fp16 -> half the size, so divide by 2 - // TODO: do not support transposed src1 - assert(nb10/2 == sizeof(ggml_fp16_t)); - - // parallelize by src0 rows using ggml_vec_dot_f16 - - // total rows in src0 - const int nr = ne01*ne02*ne03; - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - ggml_fp16_t * wdata = params->wdata; - - for (int ir = ir0; ir < ir1; ++ir) { - // src0 indices - const int i03 = ir/(ne02*ne01); - const int i02 = (ir - i03*ne02*ne01)/ne01; - const int i01 = (ir - i03*ne02*ne01 - i02*ne01); - - const int i13 = i03; - const int i12 = i02; - - const int i0 = i01; - const int i2 = i02; - const int i3 = i03; - - ggml_fp16_t * src0_row = (ggml_fp16_t *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)); - ggml_fp16_t * src1_col = wdata + ( 0 + i12*ne11 + i13*ne12*ne11)*ne00; - - float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3)); + // parallelize by src0 rows using ggml_vec_dot_q - assert(ne00 % 32 == 0); + // total rows in src0 + const int nr = ne01*ne02*ne03; - for (int ic = 0; ic < ne11; ++ic) { - ggml_vec_dot_f16(ne00, &dst_col[ic*ne0], src0_row, src1_col + ic*ne00); - } - } - } else { - // parallelize by src1 columns using ggml_vec_mad_f16 - // each thread has its own work data - // during FINALIZE we accumulate all work data into dst - - // total columns in src1 - const int nc = ne10; - - // columns per thread - const int dc = (nc + nth - 1)/nth; + // rows per thread + const int dr = (nr + nth - 1)/nth; - // column range for this thread - const int ic0 = dc*ith; - const int ic1 = MIN(ic0 + dc, nc); + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); - // work data for thread - const int wo = (ne + CACHE_LINE_SIZE_F32)*ith; - ggml_fp16_t * const wdata = params->wdata; + void * wdata = params->wdata; + const size_t row_size = ne00*GGML_TYPE_SIZE[type]/GGML_BLCK_SIZE[type]; - for (int i13 = 0; i13 < ne13; ++i13) { - for (int i12 = 0; i12 < ne12; ++i12) { - for (int i11 = 0; i11 < ne11; ++i11) { - // dst indices - const int i1 = i11; - const int i2 = i12; - const int i3 = i13; + for (int ir = ir0; ir < ir1; ++ir) { + // src0 indices + const int i03 = ir/(ne02*ne01); + const int i02 = (ir - i03*ne02*ne01)/ne01; + const int i01 = (ir - i03*ne02*ne01 - i02*ne01); - ggml_fp16_t * dst_row = wdata + wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0; + const int i13 = i03; + const int i12 = i02; - for (int ic = ic0; ic < ic1; ++ic) { - // src1 indices - const int i10 = ic; + const int i0 = i01; + const int i2 = i02; + const int i3 = i03; - // src0 indices - const int i03 = i13; - const int i02 = i12; - const int i00 = ic; + void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)); + char * src1_col = ((char *) wdata + ( (0 + i12*ne11 + i13*ne12*ne11)*row_size)); - assert(sizeof(ggml_fp16_t)*(wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + ne01) <= params->wsize); + float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3)); - ggml_fp16_t * src0_col = (ggml_fp16_t *) ((char *) src0->data + (i00*nb00 + i02*nb02 + i03*nb03)); - float src1_val = * (float *) ((char *) src1->data + (i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); + assert(ne00 % 32 == 0); - ggml_vec_mad_f16(ne01, dst_row, src0_col, src1_val); - } - } - } + for (int ic = 0; ic < ne11; ++ic) { + vec_dot_q(ne00, &dst_col[ic*ne0], src0_row, (void *) (src1_col + ic*row_size)); } } @@ -4922,6 +6542,11 @@ static void ggml_compute_forward_mul_mat( const struct ggml_tensor * src1, struct ggml_tensor * dst) { switch (src0->type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + { + ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst); + } break; case GGML_TYPE_F16: { ggml_compute_forward_mul_mat_f16_f32(params, src0, src1, dst); @@ -4935,9 +6560,37 @@ static void ggml_compute_forward_mul_mat( case GGML_TYPE_I32: case GGML_TYPE_COUNT: { - assert(false); + GGML_ASSERT(false); } break; } + +#if 0 + if (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_Q4_1) { + static int first = 8; + printf("src0: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src0->ne[0], src0->ne[1], src0->ne[2]); + printf("src1: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src1->ne[0], src1->ne[1], src1->ne[2]); + printf("dst: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", dst->ne[0], dst->ne[1], dst->ne[2]); + if (first) { + --first; + } else { + for (int k = 0; k < dst->ne[1]; ++k) { + for (int j = 0; j < dst->ne[0]/16; ++j) { + for (int i = 0; i < 16; ++i) { + printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]); + } + printf("\n"); + } + printf("\n"); + } + printf("\n"); + exit(0); + } + } else { + printf("aaaa src0: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src0->ne[0], src0->ne[1], src0->ne[2]); + printf("aaaa src1: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src1->ne[0], src1->ne[1], src1->ne[2]); + printf("aaaa dst: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", dst->ne[0], dst->ne[1], dst->ne[2]); + } +#endif } // ggml_compute_forward_scale @@ -4987,13 +6640,15 @@ static void ggml_compute_forward_scale( { ggml_compute_forward_scale_f32(params, src0, src1, dst); } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: case GGML_TYPE_F16: case GGML_TYPE_COUNT: { - assert(false); + GGML_ASSERT(false); } break; } } @@ -5051,6 +6706,35 @@ static void ggml_compute_forward_transpose( // ggml_compute_forward_get_rows +static void ggml_compute_forward_get_rows_q( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + assert(params->ith == 0); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int nc = src0->ne[0]; + const int nr = ggml_nelements(src1); + const enum ggml_type type = src0->type; + dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q; + + assert( dst->ne[0] == nc); + assert( dst->ne[1] == nr); + assert(src0->nb[0] == GGML_TYPE_SIZE[type]); + + for (int i = 0; i < nr; ++i) { + const int r = ((int32_t *) src1->data)[i]; + + dequantize_row_q( + (const void *) ((char *) src0->data + r*src0->nb[1]), + (float *) ((char *) dst->data + i*dst->nb[1]), nc); + } +} + static void ggml_compute_forward_get_rows_f16( const struct ggml_compute_params * params, const struct ggml_tensor * src0, @@ -5112,6 +6796,11 @@ static void ggml_compute_forward_get_rows( const struct ggml_tensor * src1, struct ggml_tensor * dst) { switch (src0->type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + { + ggml_compute_forward_get_rows_q(params, src0, src1, dst); + } break; case GGML_TYPE_F16: { ggml_compute_forward_get_rows_f16(params, src0, src1, dst); @@ -5125,9 +6814,27 @@ static void ggml_compute_forward_get_rows( case GGML_TYPE_I32: case GGML_TYPE_COUNT: { - assert(false); + GGML_ASSERT(false); } break; } + + //static bool first = true; + //printf("ne0 = %d, ne1 = %d, ne2 = %d\n", dst->ne[0], dst->ne[1], dst->ne[2]); + //if (first) { + // first = false; + //} else { + // for (int k = 0; k < dst->ne[1]; ++k) { + // for (int j = 0; j < dst->ne[0]/16; ++j) { + // for (int i = 0; i < 16; ++i) { + // printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]); + // } + // printf("\n"); + // } + // printf("\n"); + // } + // printf("\n"); + // exit(0); + //} } // ggml_compute_forward_diag_mask_inf @@ -5178,13 +6885,15 @@ static void ggml_compute_forward_diag_mask_inf( { ggml_compute_forward_diag_mask_inf_f32(params, src0, src1, dst); } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: case GGML_TYPE_F16: case GGML_TYPE_COUNT: { - assert(false); + GGML_ASSERT(false); } break; } } @@ -5223,6 +6932,7 @@ static void ggml_compute_forward_soft_max_f32( #ifndef NDEBUG for (int i = 0; i < nc; ++i) { + //printf("p[%d] = %f\n", i, p[i]); assert(!isnan(p[i])); } #endif @@ -5241,12 +6951,12 @@ static void ggml_compute_forward_soft_max_f32( ggml_fp16_t s = GGML_FP32_TO_FP16(p[i] - max); memcpy(&scvt, &s, sizeof(scvt)); const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]); - sum += val; + sum += (ggml_float)val; p[i] = val; } } - assert(sum > 0.0f); + assert(sum > 0.0); sum = 1.0/sum; ggml_vec_scale_f32(nc, p, sum); @@ -5269,13 +6979,15 @@ static void ggml_compute_forward_soft_max( { ggml_compute_forward_soft_max_f32(params, src0, dst); } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: case GGML_TYPE_F16: case GGML_TYPE_COUNT: { - assert(false); + GGML_ASSERT(false); } break; } } @@ -5320,16 +7032,16 @@ static void ggml_compute_forward_rope_f32( const int p = (mode == 0 ? n_past + i2 : i2); for (int i1 = 0; i1 < ne1; i1++) { for (int i0 = 0; i0 < n_dims; i0 += 2) { - const double theta = pow(10000.0, ((double)-i0)/n_dims); + const float theta = powf(10000.0, ((float)-i0)/n_dims); - const double cos_theta = cos(p*theta); - const double sin_theta = sin(p*theta); + const float cos_theta = cosf(p*theta); + const float sin_theta = sinf(p*theta); const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - double x0 = src[0]; - double x1 = src[1]; + const float x0 = src[0]; + const float x1 = src[1]; dst_data[0] = x0*cos_theta - x1*sin_theta; dst_data[1] = x0*sin_theta + x1*cos_theta; @@ -5339,23 +7051,84 @@ static void ggml_compute_forward_rope_f32( } } +static void ggml_compute_forward_rope_f16( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(src1->type == GGML_TYPE_I32); + assert(ggml_nelements(src1) == 3); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n_past = ((int32_t *) src1->data)[0]; + const int n_dims = ((int32_t *) src1->data)[1]; + const int mode = ((int32_t *) src1->data)[2]; + + //const int ne0 = src0->ne[0]; + const int ne1 = src0->ne[1]; + const int ne2 = src0->ne[2]; + const int ne3 = src0->ne[3]; + + const int nb0 = src0->nb[0]; + const int nb1 = src0->nb[1]; + const int nb2 = src0->nb[2]; + const int nb3 = src0->nb[3]; + + //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3); + //printf("n_past = %d, ne2 = %d\n", n_past, ne2); + + assert(nb0 == sizeof(ggml_fp16_t)); + + for (int i3 = 0; i3 < ne3; i3++) { + for (int i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) { + const int p = (mode == 0 ? n_past + i2 : i2); + for (int i1 = 0; i1 < ne1; i1++) { + for (int i0 = 0; i0 < n_dims; i0 += 2) { + const float theta = powf(10000.0, ((float)-i0)/n_dims); + + const float cos_theta = cosf(p*theta); + const float sin_theta = sinf(p*theta); + + const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + const float x0 = ggml_fp16_to_fp32(src[0]); + const float x1 = ggml_fp16_to_fp32(src[1]); + + dst_data[0] = ggml_fp32_to_fp16(x0*cos_theta - x1*sin_theta); + dst_data[1] = ggml_fp32_to_fp16(x0*sin_theta + x1*cos_theta); + } + } + } + } +} + static void ggml_compute_forward_rope( const struct ggml_compute_params * params, const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) { switch (src0->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_rope_f16(params, src0, src1, dst); + } break; case GGML_TYPE_F32: { ggml_compute_forward_rope_f32(params, src0, src1, dst); } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: - case GGML_TYPE_F16: case GGML_TYPE_COUNT: { - assert(false); + GGML_ASSERT(false); } break; } } @@ -5616,6 +7389,8 @@ static void ggml_compute_forward_conv_1d_1s( { ggml_compute_forward_conv_1d_1s_f32(params, src0, src1, dst); } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -5882,6 +7657,8 @@ static void ggml_compute_forward_conv_1d_2s( { ggml_compute_forward_conv_1d_2s_f32(params, src0, src1, dst); } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -5996,7 +7773,7 @@ static void ggml_compute_forward_flash_attn_f32( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - const float scale = 1.0/sqrt((double) D); + const float scale = 1.0f/sqrtf(D); //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale); @@ -6043,7 +7820,7 @@ static void ggml_compute_forward_flash_attn_f32( float max = -INFINITY; ggml_vec_max_f32(M, &max, S); - float sum = 0.0f; + ggml_float sum = 0.0; { #ifdef GGML_SOFT_MAX_ACCELERATE max = -max; @@ -6064,7 +7841,7 @@ static void ggml_compute_forward_flash_attn_f32( ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max); memcpy(&scvt[j], &s, sizeof(uint16_t)); const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]); - sump[j] += val; + sump[j] += (ggml_float)val; SS[j] = val; } } @@ -6076,7 +7853,7 @@ static void ggml_compute_forward_flash_attn_f32( #endif } - assert(sum > 0.0f); + assert(sum > 0.0); sum = 1.0/sum; ggml_vec_scale_f32(M, S, sum); @@ -6205,7 +7982,7 @@ static void ggml_compute_forward_flash_attn_f16( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - const float scale = 1.0/sqrt((double) D); + const float scale = 1.0f/sqrtf(D); //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale); @@ -6269,7 +8046,7 @@ static void ggml_compute_forward_flash_attn_f16( float max = -INFINITY; ggml_vec_max_f32(M, &max, S); - float sum = 0.0f; + ggml_float sum = 0.0; { #ifdef GGML_SOFT_MAX_ACCELERATE max = -max; @@ -6290,7 +8067,7 @@ static void ggml_compute_forward_flash_attn_f16( ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max); memcpy(&scvt[j], &s, sizeof(uint16_t)); const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]); - sump[j] += val; + sump[j] += (ggml_float)val; SS[j] = val; } } @@ -6302,7 +8079,7 @@ static void ggml_compute_forward_flash_attn_f16( #endif } - assert(sum > 0.0f); + assert(sum > 0.0); sum = 1.0/sum; ggml_vec_scale_f32(M, S, sum); @@ -6365,12 +8142,14 @@ static void ggml_compute_forward_flash_attn( { ggml_compute_forward_flash_attn_f32(params, q, k, v, masked, dst); } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: case GGML_TYPE_COUNT: { - assert(false); + GGML_ASSERT(false); } break; } } @@ -6574,12 +8353,14 @@ static void ggml_compute_forward_flash_ff( { GGML_ASSERT(false); // TODO } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: case GGML_TYPE_COUNT: { - assert(false); + GGML_ASSERT(false); } break; } } @@ -6587,7 +8368,7 @@ static void ggml_compute_forward_flash_ff( ///////////////////////////////// static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) { - assert(params); + GGML_ASSERT(params); switch (tensor->op) { case GGML_OP_DUP: @@ -6654,10 +8435,18 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_gelu(params, tensor->src0, tensor); } break; + case GGML_OP_SILU: + { + ggml_compute_forward_silu(params, tensor->src0, tensor); + } break; case GGML_OP_NORM: { ggml_compute_forward_norm(params, tensor->src0, tensor); } break; + case GGML_OP_RMS_NORM: + { + ggml_compute_forward_rms_norm(params, tensor->src0, tensor); + } break; case GGML_OP_MUL_MAT: { ggml_compute_forward_mul_mat(params, tensor->src0, tensor->src1, tensor); @@ -6835,7 +8624,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor } break; case GGML_OP_MEAN: { - assert(false); // TODO: implement + GGML_ASSERT(false); // TODO: implement } break; case GGML_OP_REPEAT: { @@ -6890,17 +8679,25 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor } break; case GGML_OP_GELU: { - assert(false); // TODO: not implemented + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_SILU: + { + GGML_ASSERT(false); // TODO: not implemented } break; case GGML_OP_NORM: { - assert(false); // TODO: not implemented + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_RMS_NORM: + { + GGML_ASSERT(false); // TODO: not implemented } break; case GGML_OP_MUL_MAT: { if (src0->grad) { // TODO: this requires outer product - ggml_out_prod(ctx, src1, tensor->grad); - assert(false); + GGML_ASSERT(false); } if (src1->grad) { src1->grad = @@ -7016,12 +8813,12 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * if (node->op == GGML_OP_NONE && node->grad == NULL) { // reached a leaf node, not part of the gradient graph (e.g. a constant) - assert(cgraph->n_leafs < GGML_MAX_NODES); + GGML_ASSERT(cgraph->n_leafs < GGML_MAX_NODES); cgraph->leafs[cgraph->n_leafs] = node; cgraph->n_leafs++; } else { - assert(cgraph->n_nodes < GGML_MAX_NODES); + GGML_ASSERT(cgraph->n_nodes < GGML_MAX_NODES); cgraph->nodes[cgraph->n_nodes] = node; cgraph->grads[cgraph->n_nodes] = node->grad; @@ -7045,7 +8842,7 @@ static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_ten if (n_new > 0) { // the last added node should always be starting point - assert(cgraph->nodes[cgraph->n_nodes - 1] == tensor); + GGML_ASSERT(cgraph->nodes[cgraph->n_nodes - 1] == tensor); } } @@ -7076,7 +8873,7 @@ struct ggml_cgraph ggml_build_forward(struct ggml_tensor * tensor) { struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep) { struct ggml_cgraph result = *gf; - assert(gf->n_nodes > 0); + GGML_ASSERT(gf->n_nodes > 0); // if we are keeping the gradient graph, we have to detach the gradient nodes from the original graph if (keep) { @@ -7239,10 +9036,6 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { } void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) { - if (cgraph->n_threads <= 0) { - cgraph->n_threads = 8; - } - const int n_threads = cgraph->n_threads; struct ggml_compute_state_shared state_shared = { @@ -7275,7 +9068,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) }; int rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]); - assert(rc == 0); + GGML_ASSERT(rc == 0); UNUSED(rc); } } @@ -7317,7 +9110,12 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) { node->n_tasks = n_threads; } break; + case GGML_OP_SILU: + { + node->n_tasks = n_threads; + } break; case GGML_OP_NORM: + case GGML_OP_RMS_NORM: { node->n_tasks = n_threads; } break; @@ -7334,32 +9132,35 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) size_t cur = 0; - // TODO: better way to determine if the matrix is transposed - if (node->src0->nb[1] < node->src0->nb[0]) { - cur = ggml_nbytes(node)*node->n_tasks; // TODO: this can become (n_tasks-1) - } else { - if (node->src0->type == GGML_TYPE_F16 && - node->src1->type == GGML_TYPE_F32) { + if (node->src0->type == GGML_TYPE_F16 && node->src1->type == GGML_TYPE_F32) { #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) - if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) { - node->n_tasks = 1; // TODO: this actually is doing nothing - // the threads are still spinning - cur = sizeof(float)*(node->src0->ne[0]*node->src0->ne[1]); - //printf("src0: ne0 = %d, ne1 = %d, ne = %d\n", node->src0->ne[0], node->src0->ne[1], node->src0->ne[0]*node->src0->ne[1]); - //printf("src1: ne0 = %d, ne1 = %d, ne = %d\n", node->src1->ne[0], node->src1->ne[1], node->src1->ne[0]*node->src1->ne[1]); - //printf("cur = %zu\n", cur); - } else { - cur = sizeof(ggml_fp16_t)*ggml_nelements(node->src1); - } + if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) { + node->n_tasks = 1; // TODO: this actually is doing nothing + // the threads are still spinning + cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]); + //printf("src0: ne0 = %d, ne1 = %d, ne = %d\n", node->src0->ne[0], node->src0->ne[1], node->src0->ne[0]*node->src0->ne[1]); + //printf("src1: ne0 = %d, ne1 = %d, ne = %d\n", node->src1->ne[0], node->src1->ne[1], node->src1->ne[0]*node->src1->ne[1]); + //printf("cur = %zu\n", cur); + } else { + cur = GGML_TYPE_SIZE[GGML_TYPE_F16]*ggml_nelements(node->src1); + } #else - cur = sizeof(ggml_fp16_t)*ggml_nelements(node->src1); + cur = GGML_TYPE_SIZE[GGML_TYPE_F16]*ggml_nelements(node->src1); #endif - } else if (node->src0->type == GGML_TYPE_F32 && - node->src1->type == GGML_TYPE_F32) { - cur = 0; - } else { - GGML_ASSERT(false); + } else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) { + cur = 0; + } else if (quantize_fns[node->src0->type].vec_dot_q && node->src1->type == GGML_TYPE_F32) { +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) + if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) { + node->n_tasks = 1; + cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]); + } else +#endif + { + cur = GGML_TYPE_SIZE[node->src0->type]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[node->src0->type]; } + } else { + GGML_ASSERT(false); } work_size = MAX(work_size, cur); @@ -7460,13 +9261,13 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) } break; case GGML_OP_COUNT: { - assert(false); + GGML_ASSERT(false); } break; } } if (cgraph->work != NULL && work_size > cgraph->work_size) { - assert(false); // TODO: better handling + GGML_ASSERT(false); // TODO: better handling } if (work_size > 0 && cgraph->work == NULL) { @@ -7632,7 +9433,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) for (int j = 0; j < n_threads - 1; j++) { int rc = ggml_thread_join(workers[j].thrd, NULL); - assert(rc == 0); + GGML_ASSERT(rc == 0); UNUSED(rc); } @@ -7739,7 +9540,7 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph char color[16]; FILE * fp = fopen(filename, "w"); - assert(fp); + GGML_ASSERT(fp); fprintf(fp, "digraph G {\n"); fprintf(fp, " newrank = true;\n"); @@ -7787,7 +9588,7 @@ label=\"%d [%d, %d] | %s", fprintf(fp, " \"%p\" [ \ style = filled; fillcolor = %s; shape = record; \ label=\"%.1e\"; ]\n", - (void *) node, color, ggml_get_f32_1d(node, 0)); + (void *) node, color, (double)ggml_get_f32_1d(node, 0)); } else { fprintf(fp, " \"%p\" [ \ style = filled; fillcolor = %s; shape = record; \ @@ -7897,7 +9698,7 @@ static enum ggml_opt_result ggml_opt_adam( struct ggml_tensor * f, struct ggml_cgraph * gf, struct ggml_cgraph * gb) { - assert(ggml_is_scalar(f)); + GGML_ASSERT(ggml_is_scalar(f)); gf->n_threads = params.n_threads; gb->n_threads = params.n_threads; @@ -7911,7 +9712,7 @@ static enum ggml_opt_result ggml_opt_adam( if (gf->nodes[i]->is_param) { GGML_PRINT_DEBUG("found param %d: grad->op = %d\n", np, gf->nodes[i]->grad->op); - assert(np < GGML_MAX_PARAMS); + GGML_ASSERT(np < GGML_MAX_PARAMS); ps[np++] = gf->nodes[i]; nx += ggml_nelements(gf->nodes[i]); @@ -8025,7 +9826,7 @@ static enum ggml_opt_result ggml_opt_adam( if (params.past <= t) { const float rate = (pf[t%params.past] - fx)/fx; - if (fabs(rate) < params.delta) { + if (fabsf(rate) < params.delta) { return GGML_OPT_OK; } } @@ -8104,7 +9905,7 @@ static enum ggml_opt_result linesearch_backtracking( const float dec = 0.5f; const float inc = 2.1f; - if (*step <= 0.) { + if (*step <= 0.f) { return GGML_LINESEARCH_INVALID_PARAMETERS; } @@ -8192,7 +9993,7 @@ static enum ggml_opt_result ggml_opt_lbfgs( struct ggml_cgraph * gb) { if (params.lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_WOLFE || params.lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE) { - if (params.lbfgs.wolfe <= params.lbfgs.ftol || 1. <= params.lbfgs.wolfe) { + if (params.lbfgs.wolfe <= params.lbfgs.ftol || 1.f <= params.lbfgs.wolfe) { return GGML_OPT_INVALID_WOLFE; } } @@ -8211,7 +10012,7 @@ static enum ggml_opt_result ggml_opt_lbfgs( if (gf->nodes[i]->is_param) { GGML_PRINT_DEBUG("found param %d: grad->op = %d\n", np, gf->nodes[i]->grad->op); - assert(np < GGML_MAX_PARAMS); + GGML_ASSERT(np < GGML_MAX_PARAMS); ps[np++] = gf->nodes[i]; nx += ggml_nelements(gf->nodes[i]); @@ -8313,8 +10114,8 @@ static enum ggml_opt_result ggml_opt_lbfgs( GGML_PRINT_DEBUG("f = %10.6f\n", ggml_get_f32_1d(f, 0)); - if (xnorm < 1.0) { - xnorm = 1.0; + if (xnorm < 1.0f) { + xnorm = 1.0f; } if (gnorm/xnorm <= params.lbfgs.eps) { // converged @@ -8327,7 +10128,7 @@ static enum ggml_opt_result ggml_opt_lbfgs( if (params.past <= k) { const float rate = (pf[k%params.past] - fx)/fx; - if (fabs(rate) < params.delta) { + if (fabsf(rate) < params.delta) { return GGML_OPT_OK; } } @@ -8523,6 +10324,54 @@ enum ggml_opt_result ggml_opt( //////////////////////////////////////////////////////////////////////////////// +size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist) { + assert(k % QK == 0); + const int nb = k / QK; + + for (int j = 0; j < n; j += k) { + block_q4_0 * restrict y = (block_q4_0 *)dst + j/QK; + + quantize_row_q4_0_reference(src + j, y, k); + + for (int i = 0; i < nb; i++) { + for (int l = 0; l < QK; l += 2) { + const uint8_t vi0 = y[i].qs[l/2] & 0xF; + const uint8_t vi1 = y[i].qs[l/2] >> 4; + + hist[vi0]++; + hist[vi1]++; + } + } + } + + return (n/QK*sizeof(block_q4_0)); +} + +size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist) { + assert(k % QK == 0); + const int nb = k / QK; + + for (int j = 0; j < n; j += k) { + block_q4_1 * restrict y = (block_q4_1 *)dst + j/QK; + + quantize_row_q4_1_reference(src + j, y, k); + + for (int i = 0; i < nb; i++) { + for (int l = 0; l < QK; l += 2) { + const uint8_t vi0 = y[i].qs[l/2] & 0xF; + const uint8_t vi1 = y[i].qs[l/2] >> 4; + + hist[vi0]++; + hist[vi1]++; + } + } + } + + return (n/QK*sizeof(block_q4_1)); +} + +//////////////////////////////////////////////////////////////////////////////// + int ggml_cpu_has_avx(void) { #if defined(__AVX__) return 1; diff --git a/tests/test-mul-mat2.c b/tests/test-mul-mat2.c index bb7dd8d8f..be7b038df 100644 --- a/tests/test-mul-mat2.c +++ b/tests/test-mul-mat2.c @@ -13,8 +13,10 @@ #include -#ifdef __ARM_NEON +#if defined(__ARM_NEON) #include "arm_neon.h" +#elif defined(__AVX__) || defined(__AVX2__) +#include "immintrin.h" #endif #ifndef MIN @@ -26,8 +28,12 @@ const int M = 1280; const int N = 1536; const int K = 1280; -const int QK = 64; -#define QB 7 +//const int M = 64; +//const int N = 64; +//const int K = 64; + +#define QK 64 +#define QB 4 //#define GGML_GQ_USE_FP16_SCALE @@ -41,8 +47,12 @@ const int QK = 64; #define GGML_GQ_TO_FP32(x) (x) #endif -#define gq_quant_t uint64_t #define gq_t_bits 64 +#define gq_quant_t uint64_t + +float frand() { + return (float) rand() / (float) RAND_MAX; +} uint64_t get_time_us() { struct timeval tv; @@ -50,6 +60,47 @@ uint64_t get_time_us() { return tv.tv_sec * 1000000 + tv.tv_usec; } +#if defined(__AVX2__) +// horizontally reduce 8 32-bit integers +static inline uint32_t _mm256_hadd_epi32_gg(__m256i v) { + __m128i v0 = _mm256_extractf128_si256(v, 0); + __m128i v1 = _mm256_extractf128_si256(v, 1); + + v0 = _mm_add_epi32(v0, v1); + + v1 = _mm_shuffle_epi32(v0, 0x0e); + v0 = _mm_add_epi32(v0, v1); + + v1 = _mm_shuffle_epi32(v0, 0x01); + v0 = _mm_add_epi32(v0, v1); + + return _mm_cvtsi128_si32(v0); +} + +//static inline float _mm256_hadd_epi32_gg(__m256i v) { +// const __m256 v0 = _mm256_cvtepi32_ps(v); +// const __m128 t0 = _mm_add_ps(_mm256_castps256_ps128(v0), _mm256_extractf128_ps(v0, 1)); +// const __m128 t1 = _mm_hadd_ps(t0, t0); +// +// return _mm_cvtss_f32(_mm_hadd_ps(t1, t1)); +//} + +// horizontally reduce 32 8-bit integers +static inline int32_t _mm256_hadd_epi8_gg(__m256i v0) { + __m256i v1 = _mm256_maddubs_epi16(v0, _mm256_set1_epi8(1)); + __m256i v2 = _mm256_madd_epi16 (v1, _mm256_set1_epi16(1)); + + return _mm256_hadd_epi32_gg(v2); +} + +static inline float _mm256_hadd_ps_gg(__m256 v) { + const __m128 t0 = _mm_add_ps(_mm256_castps256_ps128(v), _mm256_extractf128_ps(v, 1)); + const __m128 t1 = _mm_hadd_ps(t0, t0); + + return _mm_cvtss_f32(_mm_hadd_ps(t1, t1)); +} +#endif + // // naive implementation // @@ -74,6 +125,21 @@ void mul_mat_f32_naive( // method 1 // +static inline int quantize_1_blocks_per_row(int k) { + return k/QK; +} + +static inline int quantize_1_quants_per_block() { + return QK/gq_t_bits; +} + +static inline int quantize_1_row_size(int k) { + const int nb = quantize_1_blocks_per_row(k); + const int nq = quantize_1_quants_per_block(); + + return nb*(2*sizeof(gq_scale_t) + nq*QB*sizeof(gq_quant_t)); +} + void quantize_1(const float * src, void * dst, int n, int k) { char * p0 = dst; @@ -215,6 +281,7 @@ void mul_mat_gq_1( // // method 2 +// n-bit quantization (2nd attempt) // static inline int quantize_2_blocks_per_row(int k) { @@ -244,15 +311,41 @@ void quantize_2_row(const float * restrict src, void * restrict dst, int k) { gq_quant_t pp[QB]; + static const int32_t sh[32] = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, + }; + for (int i = 0; i < nb; i++) { float min = FLT_MAX; float max = -FLT_MAX; - for (int l = 0; l < QK; l++) { - const float v = src[i*QK + l]; - if (v < min) min = v; - if (v > max) max = v; +#ifdef __ARM_NEON + { + float32x4_t minv = vdupq_n_f32(FLT_MAX); + float32x4_t maxv = vdupq_n_f32(-FLT_MAX); + + for (int l = 0; l < QK; l += 4) { + float32x4_t v = vld1q_f32(src + i*QK + l); + minv = vminq_f32(minv, v); + maxv = vmaxq_f32(maxv, v); + } + + float32x2_t minv32 = vpmin_f32(vget_low_f32(minv), vget_high_f32(minv)); + float32x2_t maxv32 = vpmax_f32(vget_low_f32(maxv), vget_high_f32(maxv)); + + min = MIN(vget_lane_f32(minv32, 0), vget_lane_f32(minv32, 1)); + max = MAX(vget_lane_f32(maxv32, 0), vget_lane_f32(maxv32, 1)); } +#else + { + for (int l = 0; l < QK; l++) { + const float v = src[i*QK + l]; + if (v < min) min = v; + if (v > max) max = v; + } + } +#endif const float d = (max - min) / ((1 << QB) - 1); const float id = d ? 1.0/d : 0.0; @@ -263,18 +356,150 @@ void quantize_2_row(const float * restrict src, void * restrict dst, int k) { for (int s = 0; s < nq; ++s) { memset(pp, 0, sizeof(pp)); +#if 1 for (int l = 0; l < gq_t_bits; l++) { const float v = src[i*QK + s*gq_t_bits + l]; - const uint8_t q = (v - min)*id; + const uint8_t q = (v - min)*id + frand(); for (int b = 0; b < QB; b++) { pp[b] |= q & (1 << b) ? (1ULL << l) : 0; } } +#elif defined(__ARM_NEON) +#if 1 + { + uint32_t ppt[2*4*QB]; + + float32x4_t minv = vdupq_n_f32(min); + float32x4_t idv = vdupq_n_f32(id); + + assert(gq_t_bits % 16 == 0); + + uint32x4_t p0[QB] = { vdupq_n_u32(0) }; + uint32x4_t p1[QB] = { vdupq_n_u32(0) }; + + for (int l = 0; l < gq_t_bits; l += 16) { + float32x4_t v0 = vld1q_f32(src + i*QK + s*gq_t_bits + l + 0); + float32x4_t v1 = vld1q_f32(src + i*QK + s*gq_t_bits + l + 4); + float32x4_t v2 = vld1q_f32(src + i*QK + s*gq_t_bits + l + 8); + float32x4_t v3 = vld1q_f32(src + i*QK + s*gq_t_bits + l + 12); + + v0 = vsubq_f32(v0, minv); + v1 = vsubq_f32(v1, minv); + v2 = vsubq_f32(v2, minv); + v3 = vsubq_f32(v3, minv); + + v0 = vmulq_f32(v0, idv); + v1 = vmulq_f32(v1, idv); + v2 = vmulq_f32(v2, idv); + v3 = vmulq_f32(v3, idv); + +#if 1 + v0[0] += frand(); v0[1] += frand(); v0[2] += frand(); v0[3] += frand(); + v1[0] += frand(); v1[1] += frand(); v1[2] += frand(); v1[3] += frand(); + v2[0] += frand(); v2[1] += frand(); v2[2] += frand(); v2[3] += frand(); + v3[0] += frand(); v3[1] += frand(); v3[2] += frand(); v3[3] += frand(); +#endif + + uint32x4_t q0 = vcvtq_u32_f32(v0); + uint32x4_t q1 = vcvtq_u32_f32(v1); + uint32x4_t q2 = vcvtq_u32_f32(v2); + uint32x4_t q3 = vcvtq_u32_f32(v3); + + for (int b = 0; b < QB; ++b) { + uint32x4_t m = vdupq_n_u32(1 << b); + uint32x4_t r = vdupq_n_u32(-b); + + if (l < 32) { + p0[b] = vorrq_u32(p0[b], vshlq_u32(vshlq_u32(vandq_u32(q0, m), r), vld1q_s32(sh + l + 0))); + p0[b] = vorrq_u32(p0[b], vshlq_u32(vshlq_u32(vandq_u32(q1, m), r), vld1q_s32(sh + l + 4))); + p0[b] = vorrq_u32(p0[b], vshlq_u32(vshlq_u32(vandq_u32(q2, m), r), vld1q_s32(sh + l + 8))); + p0[b] = vorrq_u32(p0[b], vshlq_u32(vshlq_u32(vandq_u32(q3, m), r), vld1q_s32(sh + l + 12))); + } else { + p1[b] = vorrq_u32(p1[b], vshlq_u32(vshlq_u32(vandq_u32(q0, m), r), vld1q_s32(sh + l - 32))); + p1[b] = vorrq_u32(p1[b], vshlq_u32(vshlq_u32(vandq_u32(q1, m), r), vld1q_s32(sh + l - 28))); + p1[b] = vorrq_u32(p1[b], vshlq_u32(vshlq_u32(vandq_u32(q2, m), r), vld1q_s32(sh + l - 24))); + p1[b] = vorrq_u32(p1[b], vshlq_u32(vshlq_u32(vandq_u32(q3, m), r), vld1q_s32(sh + l - 20))); + } + } + } + +#if QB == 4 + vst1q_u32((uint32_t *) ppt + 0, p0[0]); + vst1q_u32((uint32_t *) ppt + 4, p1[0]); + vst1q_u32((uint32_t *) ppt + 8, p0[1]); + vst1q_u32((uint32_t *) ppt + 12, p1[1]); + vst1q_u32((uint32_t *) ppt + 16, p0[2]); + vst1q_u32((uint32_t *) ppt + 20, p1[2]); + vst1q_u32((uint32_t *) ppt + 24, p0[3]); + vst1q_u32((uint32_t *) ppt + 28, p1[3]); + + pp[0] = (ppt[0] | ppt[1] | ppt[2] | ppt[3] ) | ((uint64_t) (ppt[4] | ppt[5] | ppt[6] | ppt[7]) ) << 32; + pp[1] = (ppt[8] | ppt[9] | ppt[10] | ppt[11]) | ((uint64_t) (ppt[12] | ppt[13] | ppt[14] | ppt[15])) << 32; + pp[2] = (ppt[16] | ppt[17] | ppt[18] | ppt[19]) | ((uint64_t) (ppt[20] | ppt[21] | ppt[22] | ppt[23])) << 32; + pp[3] = (ppt[24] | ppt[25] | ppt[26] | ppt[27]) | ((uint64_t) (ppt[28] | ppt[29] | ppt[30] | ppt[31])) << 32; +#else + for (int b = 0; b < QB; ++b) { + vst1q_u32((uint32_t *) ppt + 0, p0[b]); + vst1q_u32((uint32_t *) ppt + 4, p1[b]); + + pp[b] = (ppt[0] | ppt[1] | ppt[2] | ppt[3]) | ((uint64_t) (ppt[4] | ppt[5] | ppt[6] | ppt[7])) << 32; + } +#endif + } +#else + // less optimal SIMD + { + float32x4_t minv = vdupq_n_f32(min); + float32x4_t idv = vdupq_n_f32(id); + + assert(gq_t_bits == 64); + uint8_t qq[gq_t_bits]; + + for (int l = 0; l < gq_t_bits; l += 16) { + float32x4_t v0 = vld1q_f32(src + i*QK + s*gq_t_bits + l + 0); + float32x4_t v1 = vld1q_f32(src + i*QK + s*gq_t_bits + l + 4); + float32x4_t v2 = vld1q_f32(src + i*QK + s*gq_t_bits + l + 8); + float32x4_t v3 = vld1q_f32(src + i*QK + s*gq_t_bits + l + 12); + + v0 = vsubq_f32(v0, minv); + v1 = vsubq_f32(v1, minv); + v2 = vsubq_f32(v2, minv); + v3 = vsubq_f32(v3, minv); + + v0 = vmulq_f32(v0, idv); + v1 = vmulq_f32(v1, idv); + v2 = vmulq_f32(v2, idv); + v3 = vmulq_f32(v3, idv); + +#if 0 + v0[0] += frand(); v0[1] += frand(); v0[2] += frand(); v0[3] += frand(); + v1[0] += frand(); v1[1] += frand(); v1[2] += frand(); v1[3] += frand(); + v2[0] += frand(); v2[1] += frand(); v2[2] += frand(); v2[3] += frand(); + v3[0] += frand(); v3[1] += frand(); v3[2] += frand(); v3[3] += frand(); +#endif + + uint32x4_t q0 = vcvtq_u32_f32(v0); + uint32x4_t q1 = vcvtq_u32_f32(v1); + uint32x4_t q2 = vcvtq_u32_f32(v2); + uint32x4_t q3 = vcvtq_u32_f32(v3); + + // store in qq as uint8_t + vst1_u8(qq + l + 0, vmovn_u16(vcombine_u16(vmovn_u32(q0), vmovn_u32(q1)))); + vst1_u8(qq + l + 8, vmovn_u16(vcombine_u16(vmovn_u32(q2), vmovn_u32(q3)))); + } - for (int b = 0; b < QB; b++) { - pb[i*nq*QB + s*QB + b] = pp[b]; + for (int l = 0; l < gq_t_bits; l++) { + for (int b = 0; b < QB; b++) { + const uint64_t ql = qq[l]; + /*pp[b] |= qq[l] & (1 << b) ? (1ULL << l) : 0;*/ + pp[b] |= ((ql & (1 << b)) >> b) << l; + } + } } +#endif +#endif + memcpy(pb + i*nq*QB + s*QB, pp, sizeof(pp)); } } } @@ -290,9 +515,6 @@ void quantize_2(const float * restrict src, char * restrict dst, int n, int k) { } void vec_dot_gq_2(const int n, float * restrict s, const void * restrict x, const void * restrict y) { - float sumf[(QB + 1)*(QB + 1)]; - memset(sumf, 0, sizeof(sumf)); - const int nb = quantize_2_blocks_per_row(n); const int nq = quantize_2_quants_per_block(); @@ -305,10 +527,9 @@ void vec_dot_gq_2(const int n, float * restrict s, const void * restrict x, cons const gq_quant_t * restrict pb0 = (const gq_quant_t *) (pd0 + nb); const gq_quant_t * restrict pb1 = (const gq_quant_t *) (pd1 + nb); -#if 1 - float s0[QB + 1]; - float s1[QB + 1]; + float sumf = 0.0; +#if 1 for (int i = 0; i < nb; i++) { const float m0 = GGML_GQ_TO_FP32(pm0[i]); const float d0 = GGML_GQ_TO_FP32(pd0[i]); @@ -316,6 +537,99 @@ void vec_dot_gq_2(const int n, float * restrict s, const void * restrict x, cons const float m1 = GGML_GQ_TO_FP32(pm1[i]); const float d1 = GGML_GQ_TO_FP32(pd1[i]); +#if QB == 4 + int isum01 = 0; + int isum10 = 0; + int isum11 = 0; + + for (int s = 0; s < nq; ++s) { + const gq_quant_t * restrict mm0 = pb0 + i*nq*QB + s*QB; + const gq_quant_t * restrict mm1 = pb1 + i*nq*QB + s*QB; + +#define bpcnt(x) __builtin_popcountll(x) + isum01 += (1 << 0)*(bpcnt(mm1[0])); + isum01 += (1 << 1)*(bpcnt(mm1[1])); + isum01 += (1 << 2)*(bpcnt(mm1[2])); + isum01 += (1 << 3)*(bpcnt(mm1[3])); + + isum10 += (1 << 0)*(bpcnt(mm0[0])); + isum10 += (1 << 1)*(bpcnt(mm0[1])); + isum10 += (1 << 2)*(bpcnt(mm0[2])); + isum10 += (1 << 3)*(bpcnt(mm0[3])); + + isum11 += (1 << 0)*(bpcnt(mm0[0] & mm1[0])); + isum11 += (1 << 1)*(bpcnt(mm0[0] & mm1[1]) + bpcnt(mm0[1] & mm1[0])); + isum11 += (1 << 2)*(bpcnt(mm0[0] & mm1[2]) + bpcnt(mm0[1] & mm1[1]) + bpcnt(mm0[2] & mm1[0])); + isum11 += (1 << 3)*(bpcnt(mm0[0] & mm1[3]) + bpcnt(mm0[1] & mm1[2]) + bpcnt(mm0[2] & mm1[1]) + bpcnt(mm0[3] & mm1[0])); + isum11 += (1 << 4)*(bpcnt(mm0[1] & mm1[3]) + bpcnt(mm0[2] & mm1[2]) + bpcnt(mm0[3] & mm1[1])); + isum11 += (1 << 5)*(bpcnt(mm0[2] & mm1[3]) + bpcnt(mm0[3] & mm1[2])); + isum11 += (1 << 6)*(bpcnt(mm0[3] & mm1[3])); +#undef bpcnt + } + + sumf += nq*gq_t_bits*(m0*m1) + isum01*(m0*d1) + isum10*(m1*d0) + isum11*(d0*d1); +#elif QB == 3 + int isum01 = 0; + int isum10 = 0; + int isum11 = 0; + + for (int s = 0; s < nq; ++s) { + const gq_quant_t * restrict mm0 = pb0 + i*nq*QB + s*QB; + const gq_quant_t * restrict mm1 = pb1 + i*nq*QB + s*QB; + +#if gq_t_bits == 32 +#define bpcnt(x) __builtin_popcount(x) +#else +#define bpcnt(x) __builtin_popcountll(x) +#endif + isum01 += (1 << 0)*(bpcnt(mm1[0])); + isum01 += (1 << 1)*(bpcnt(mm1[1])); + isum01 += (1 << 2)*(bpcnt(mm1[2])); + + isum10 += (1 << 0)*(bpcnt(mm0[0])); + isum10 += (1 << 1)*(bpcnt(mm0[1])); + isum10 += (1 << 2)*(bpcnt(mm0[2])); + + isum11 += (1 << 0)*(bpcnt(mm0[0] & mm1[0])); + isum11 += (1 << 1)*(bpcnt(mm0[0] & mm1[1]) + bpcnt(mm0[1] & mm1[0])); + isum11 += (1 << 2)*(bpcnt(mm0[0] & mm1[2]) + bpcnt(mm0[1] & mm1[1]) + bpcnt(mm0[2] & mm1[0])); + isum11 += (1 << 3)*(bpcnt(mm0[1] & mm1[2]) + bpcnt(mm0[2] & mm1[1])); + isum11 += (1 << 4)*(bpcnt(mm0[2] & mm1[2])); +#undef bpcnt + } + + sumf += nq*gq_t_bits*(m0*m1) + isum01*(m0*d1) + isum10*(m1*d0) + isum11*(d0*d1); +#elif QB == 2 + int isum01 = 0; + int isum10 = 0; + int isum11 = 0; + + for (int s = 0; s < nq; ++s) { + const gq_quant_t * restrict mm0 = pb0 + i*nq*QB + s*QB; + const gq_quant_t * restrict mm1 = pb1 + i*nq*QB + s*QB; + +#if gq_t_bits == 32 +#define bpcnt(x) __builtin_popcount(x) +#else +#define bpcnt(x) __builtin_popcountll(x) +#endif + isum01 += (1 << 0)*(bpcnt(mm1[0])); + isum01 += (1 << 1)*(bpcnt(mm1[1])); + + isum10 += (1 << 0)*(bpcnt(mm0[0])); + isum10 += (1 << 1)*(bpcnt(mm0[1])); + + isum11 += (1 << 0)*(bpcnt(mm0[0] & mm1[0])); + isum11 += (1 << 1)*(bpcnt(mm0[0] & mm1[1]) + bpcnt(mm0[1] & mm1[0])); + isum11 += (1 << 2)*(bpcnt(mm0[1] & mm1[1])); +#undef bpcnt + } + + sumf += nq*gq_t_bits*(m0*m1) + isum01*(m0*d1) + isum10*(m1*d0) + isum11*(d0*d1); +#else + float s0[QB + 1]; + float s1[QB + 1]; + s0[0] = m0; s1[0] = m1; @@ -329,36 +643,17 @@ void vec_dot_gq_2(const int n, float * restrict s, const void * restrict x, cons const gq_quant_t mm0 = q0 ? pb0[i*nq*QB + s*QB + q0 - 1] : -1ULL; for (int q1 = 0; q1 < QB + 1; q1++) { const gq_quant_t mm1 = q1 ? pb1[i*nq*QB + s*QB + q1 - 1] : -1ULL; - sumf[q0*(QB + 1) + q1] += s0[q0]*s1[q1]*__builtin_popcountll(mm0 & mm1); + sumf += s0[q0]*s1[q1]*__builtin_popcountll(mm0 & mm1); } } } +#endif } #else - // SIMD-ify with the assumptions: - // - nb is a multiple of 4 - // - gq_scale_t is float - // - gq_quant_t is uint64_t - // - QB == 7 - assert(nb % 4 == 0); - -#ifdef __ARM_NEON -#else - // TODO -#endif - +#error "not implemented" #endif - for (int q0 = 0; q0 < QB + 1; q0++) { - for (int q1 = 1; q1 < QB + 1; q1++) { - sumf[q0*(QB + 1)] += sumf[q0*(QB + 1) + q1]; - } - } - - *s = sumf[0]; - for (int q0 = 1; q0 < QB + 1; q0++) { - *s += sumf[q0*(QB + 1)]; - } + *s = sumf; } // use vec_dot_gq_2 to compute the dot product of two rows @@ -384,83 +679,1904 @@ void mul_mat_gq_2( } } -int main(int argc, const char ** argv) { - assert(sizeof(gq_quant_t)*8 == gq_t_bits); +// +// method 3 +// (does not work) +// - float * src0 = (float *)malloc(sizeof(float)*M*K); - float * src1 = (float *)malloc(sizeof(float)*N*K); - float * dst = (float *)malloc(sizeof(float)*M*N); +static inline int quantize_3_blocks_per_row(int k) { + return k/QK; +} - for (int i = 0; i < M*K; i++) { - src0[i] = rand() / (float)RAND_MAX; - } +static inline int quantize_3_quants_per_block() { + return QK/gq_t_bits; +} - for (int i = 0; i < N*K; i++) { - src1[i] = rand() / (float)RAND_MAX; - } +static inline int quantize_3_row_size(int k) { + const int nb = quantize_3_blocks_per_row(k); + const int nq = quantize_3_quants_per_block(); + + return nb*(sizeof(gq_scale_t) + nq*QB*sizeof(gq_quant_t)); +} - void * src0_gq = calloc(1, quantize_2_row_size(K)*M); - void * src1_gq = calloc(1, quantize_2_row_size(K)*N); +void quantize_3_row(const float * restrict src, void * restrict dst, int k) { + assert(k % QK == 0); - const size_t sizef16 = sizeof(ggml_fp16_t)*M*K + sizeof(ggml_fp16_t)*N*K; - const size_t sizegq = quantize_2_row_size(K)*M + quantize_2_row_size(K)*N; + const int nb = quantize_3_blocks_per_row(k); + const int nq = quantize_3_quants_per_block(); - printf("compression: %f\n", (float)sizegq/sizef16); + gq_scale_t * restrict pd = (gq_scale_t *) (dst); + gq_quant_t * restrict pb = (gq_quant_t *) (pd + nb); - int method = 0; - if (argc > 1) { - method = atoi(argv[1]); - } + gq_quant_t pp[QB]; - // convert fp32 -> gq - { - const uint64_t t_start = get_time_us(); + static const int32_t sh[32] = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, + }; - if (method == 1) { - quantize_1(src0, src0_gq, M, K); - quantize_1(src1, src1_gq, N, K); + for (int i = 0; i < nb; i++) { + float amax = 0.0f; // abs max + +#ifdef __ARM_NEON + { + // min / max + //float32x4_t minv = vdupq_n_f32(FLT_MAX); + //float32x4_t maxv = vdupq_n_f32(-FLT_MAX); + + //for (int l = 0; l < QK; l += 4) { + // float32x4_t v = vld1q_f32(src + i*QK + l); + // minv = vminq_f32(minv, v); + // maxv = vmaxq_f32(maxv, v); + //} + + //float32x2_t minv32 = vpmin_f32(vget_low_f32(minv), vget_high_f32(minv)); + //float32x2_t maxv32 = vpmax_f32(vget_low_f32(maxv), vget_high_f32(maxv)); + + //min = MIN(vget_lane_f32(minv32, 0), vget_lane_f32(minv32, 1)); + //max = MAX(vget_lane_f32(maxv32, 0), vget_lane_f32(maxv32, 1)); + + // abs max + float32x4_t amaxv = vdupq_n_f32(0.0f); + + for (int l = 0; l < QK; l += 4) { + float32x4_t v = vld1q_f32(src + i*QK + l); + amaxv = vmaxq_f32(amaxv, vabsq_f32(v)); + } + + float32x2_t amaxv32 = vpmax_f32(vget_low_f32(amaxv), vget_high_f32(amaxv)); + + amax = MAX(vget_lane_f32(amaxv32, 0), vget_lane_f32(amaxv32, 1)); + } +#else + { + for (int l = 0; l < QK; l++) { + const float v = src[i*QK + l]; + amax = MAX(amax, fabsf(v)); + } } +#endif - if (method == 2) { - quantize_2(src0, src0_gq, M, K); - quantize_2(src1, src1_gq, N, K); + const float d = amax / ((1 << (QB - 1)) - 1); + const float id = d ? 1.0/d : 0.0; + + pd[i] = GGML_FP32_TO_GQ(d); + + for (int s = 0; s < nq; ++s) { + memset(pp, 0, sizeof(pp)); + +#if 0 + for (int l = 0; l < gq_t_bits; l++) { + const float v = src[i*QK + s*gq_t_bits + l]; + const uint8_t q = v*id + frand(); + + for (int b = 0; b < QB; b++) { + pp[b] |= q & (1 << b) ? (1ULL << l) : 0; + } + } +#elif defined(__ARM_NEON) + { + uint32_t ppt[2*4*QB]; + + float32x4_t idv = vdupq_n_f32(id); + + assert(gq_t_bits == 64); + + uint32x4_t p0[QB] = { vdupq_n_u32(0) }; + uint32x4_t p1[QB] = { vdupq_n_u32(0) }; + + for (int l = 0; l < gq_t_bits; l += 16) { + float32x4_t v0 = vld1q_f32(src + i*QK + s*gq_t_bits + l + 0); + float32x4_t v1 = vld1q_f32(src + i*QK + s*gq_t_bits + l + 4); + float32x4_t v2 = vld1q_f32(src + i*QK + s*gq_t_bits + l + 8); + float32x4_t v3 = vld1q_f32(src + i*QK + s*gq_t_bits + l + 12); + + v0 = vmulq_f32(v0, idv); + v1 = vmulq_f32(v1, idv); + v2 = vmulq_f32(v2, idv); + v3 = vmulq_f32(v3, idv); + +#if 1 + v0[0] += frand(); v0[1] += frand(); v0[2] += frand(); v0[3] += frand(); + v1[0] += frand(); v1[1] += frand(); v1[2] += frand(); v1[3] += frand(); + v2[0] += frand(); v2[1] += frand(); v2[2] += frand(); v2[3] += frand(); + v3[0] += frand(); v3[1] += frand(); v3[2] += frand(); v3[3] += frand(); +#endif + + uint32x4_t q0 = vcvtq_u32_f32(v0); + uint32x4_t q1 = vcvtq_u32_f32(v1); + uint32x4_t q2 = vcvtq_u32_f32(v2); + uint32x4_t q3 = vcvtq_u32_f32(v3); + + for (int b = 0; b < QB; ++b) { + uint32x4_t m = vdupq_n_u32(1 << b); + uint32x4_t r = vdupq_n_u32(-b); + + if (l < 32) { + p0[b] = vorrq_u32(p0[b], vshlq_u32(vshlq_u32(vandq_u32(q0, m), r), vld1q_s32(sh + l + 0))); + p0[b] = vorrq_u32(p0[b], vshlq_u32(vshlq_u32(vandq_u32(q1, m), r), vld1q_s32(sh + l + 4))); + p0[b] = vorrq_u32(p0[b], vshlq_u32(vshlq_u32(vandq_u32(q2, m), r), vld1q_s32(sh + l + 8))); + p0[b] = vorrq_u32(p0[b], vshlq_u32(vshlq_u32(vandq_u32(q3, m), r), vld1q_s32(sh + l + 12))); + } else { + p1[b] = vorrq_u32(p1[b], vshlq_u32(vshlq_u32(vandq_u32(q0, m), r), vld1q_s32(sh + l - 32))); + p1[b] = vorrq_u32(p1[b], vshlq_u32(vshlq_u32(vandq_u32(q1, m), r), vld1q_s32(sh + l - 28))); + p1[b] = vorrq_u32(p1[b], vshlq_u32(vshlq_u32(vandq_u32(q2, m), r), vld1q_s32(sh + l - 24))); + p1[b] = vorrq_u32(p1[b], vshlq_u32(vshlq_u32(vandq_u32(q3, m), r), vld1q_s32(sh + l - 20))); + } + } + } + +#if QB == 4 + vst1q_u32((uint32_t *) ppt + 0, p0[0]); + vst1q_u32((uint32_t *) ppt + 4, p1[0]); + vst1q_u32((uint32_t *) ppt + 8, p0[1]); + vst1q_u32((uint32_t *) ppt + 12, p1[1]); + vst1q_u32((uint32_t *) ppt + 16, p0[2]); + vst1q_u32((uint32_t *) ppt + 20, p1[2]); + vst1q_u32((uint32_t *) ppt + 24, p0[3]); + vst1q_u32((uint32_t *) ppt + 28, p1[3]); + + pp[0] = (ppt[0] | ppt[1] | ppt[2] | ppt[3] ) | ((uint64_t) (ppt[4] | ppt[5] | ppt[6] | ppt[7]) ) << 32; + pp[1] = (ppt[8] | ppt[9] | ppt[10] | ppt[11]) | ((uint64_t) (ppt[12] | ppt[13] | ppt[14] | ppt[15])) << 32; + pp[2] = (ppt[16] | ppt[17] | ppt[18] | ppt[19]) | ((uint64_t) (ppt[20] | ppt[21] | ppt[22] | ppt[23])) << 32; + pp[3] = (ppt[24] | ppt[25] | ppt[26] | ppt[27]) | ((uint64_t) (ppt[28] | ppt[29] | ppt[30] | ppt[31])) << 32; +#else + for (int q = 0; q < QB; ++q) { + vst1q_u32((uint32_t *) ppt + 0, p0[q]); + vst1q_u32((uint32_t *) ppt + 4, p1[q]); + + pp[q] = (ppt[0] | ppt[1] | ppt[2] | ppt[3]) | ((uint64_t) (ppt[4] | ppt[5] | ppt[6] | ppt[7])) << 32; + } +#endif + } +#endif + memcpy(pb + i*nq*QB + s*QB, pp, sizeof(pp)); } + } +} - const uint64_t t_end = get_time_us(); - printf("convert time: %f ms / method = %d\n", (t_end - t_start) / 1000.0, method); +// reimplementation of quantize_3 using quantize_3_row +void quantize_3(const float * restrict src, char * restrict dst, int n, int k) { + assert(k % QK == 0); + + for (int j = 0; j < n; j++) { + quantize_3_row(src + j*k, dst, k); + dst = (char *) dst + quantize_3_row_size(k); } +} - const int nIter = 1; +void vec_dot_gq_3(const int n, float * restrict s, const void * restrict x, const void * restrict y) { + float sumf = 0.0f; - const clock_t start = clock(); - const uint64_t start_us = get_time_us(); + const int nb = quantize_3_blocks_per_row(n); + const int nq = quantize_3_quants_per_block(); - double iM = 1.0/M; - double sum = 0.0f; - for (int i = 0; i < nIter; i++) { - if (method == 0) { - mul_mat_f32_naive(src0, src1, dst, M, N, K); - } + const gq_scale_t * restrict pd0 = (const gq_scale_t *) x; + const gq_scale_t * restrict pd1 = (const gq_scale_t *) y; - if (method == 1) { - mul_mat_gq_1(src0_gq, src1_gq, dst, M, N, K); - } + const gq_quant_t * restrict pb0 = (const gq_quant_t *) (pd0 + nb); + const gq_quant_t * restrict pb1 = (const gq_quant_t *) (pd1 + nb); - if (method == 2) { - mul_mat_gq_2(src0_gq, src1_gq, dst, M, N, K); +#if 1 + for (int i = 0; i < nb; i++) { + int isum = 0; + +#if QB == 4 + for (int s = 0; s < nq; ++s) { + const gq_quant_t * restrict m0 = pb0 + i*nq*QB + s*QB; + const gq_quant_t * restrict m1 = pb1 + i*nq*QB + s*QB; + + isum += (1 << 0)*(__builtin_popcountll(m0[0] & m1[0])); + isum += (1 << 1)*(__builtin_popcountll(m0[0] & m1[1]) + __builtin_popcountll(m0[1] & m1[0])); + isum += (1 << 2)*(__builtin_popcountll(m0[0] & m1[2]) + __builtin_popcountll(m0[1] & m1[1]) + __builtin_popcountll(m0[2] & m1[0])); + isum += (1 << 3)*(__builtin_popcountll(m0[0] & m1[3]) + __builtin_popcountll(m0[1] & m1[2]) + __builtin_popcountll(m0[2] & m1[1]) + __builtin_popcountll(m0[3] & m1[0])); + isum += (1 << 4)*(__builtin_popcountll(m0[1] & m1[3]) + __builtin_popcountll(m0[2] & m1[2]) + __builtin_popcountll(m0[3] & m1[1])); + isum += (1 << 5)*(__builtin_popcountll(m0[2] & m1[3]) + __builtin_popcountll(m0[3] & m1[2])); + isum += (1 << 6)*(__builtin_popcountll(m0[3] & m1[3])); } - } +#else + for (int s = 0; s < nq; ++s) { + for (int q0 = 0; q0 < QB; q0++) { + const gq_quant_t mm0 = pb0[i*nq*QB + s*QB + q0]; + for (int q1 = 0; q1 < QB; q1++) { + const gq_quant_t mm1 = pb1[i*nq*QB + s*QB + q1]; + isum += (1 << (q0 + q1))*(__builtin_popcountll(mm0 & mm1)); + } + } + } +#endif - for (int i = 0; i < N; i++) { - sum += dst[i]*iM; + const float d0 = GGML_GQ_TO_FP32(pd0[i]); + const float d1 = GGML_GQ_TO_FP32(pd1[i]); + + sumf += d0*d1*isum; } +#else +#ifdef __ARM_NEON + // gq_quant_t == uint64_t + for (int i = 0; i < nb; i += 4) { + int isum[4] = {0, 0, 0, 0}; + + for (int k = 0; k < 4; ++k) { + for (int s = 0; s < nq; ++s) { + const gq_quant_t * restrict m0 = pb0 + (i+k)*nq*QB + s*QB; + const gq_quant_t * restrict m1 = pb1 + (i+k)*nq*QB + s*QB; + +#if QB == 4 +#define bpcnt(x) __builtin_popcountll(x) + //isum[k] += (1ULL << 0)*(bpcnt(m0[0] & m1[0])) + + // (1ULL << 1)*(bpcnt(m0[0] & m1[1]) + bpcnt(m0[1] & m1[0])) + + // (1ULL << 2)*(bpcnt(m0[0] & m1[2]) + bpcnt(m0[1] & m1[1]) + bpcnt(m0[2] & m1[0])) + + // (1ULL << 3)*(bpcnt(m0[0] & m1[3]) + bpcnt(m0[1] & m1[2]) + bpcnt(m0[2] & m1[1]) + bpcnt(m0[3] & m1[0])) + + // (1ULL << 4)*(bpcnt(m0[1] & m1[3]) + bpcnt(m0[2] & m1[2]) + bpcnt(m0[3] & m1[1])) + + // (1ULL << 5)*(bpcnt(m0[2] & m1[3]) + bpcnt(m0[3] & m1[2])) + + // (1ULL << 6)*(bpcnt(m0[3] & m1[3])); +#undef bpcnt + + const uint8x8_t m00 = vld1_u8((const uint8_t *) (m0 + 0)); + const uint8x8_t m01 = vld1_u8((const uint8_t *) (m0 + 1)); + const uint8x8_t m02 = vld1_u8((const uint8_t *) (m0 + 2)); + const uint8x8_t m03 = vld1_u8((const uint8_t *) (m0 + 3)); + + const uint8x8_t m10 = vld1_u8((const uint8_t *) (m1 + 0)); + const uint8x8_t m11 = vld1_u8((const uint8_t *) (m1 + 1)); + const uint8x8_t m12 = vld1_u8((const uint8_t *) (m1 + 2)); + const uint8x8_t m13 = vld1_u8((const uint8_t *) (m1 + 3)); + + const uint8x8_t m00m10 = vand_u8(m00, m10); + + const uint8x8_t m00m11 = vand_u8(m00, m11); + const uint8x8_t m01m10 = vand_u8(m01, m10); + + const uint8x8_t m00m12 = vand_u8(m00, m12); + const uint8x8_t m01m11 = vand_u8(m01, m11); + const uint8x8_t m02m10 = vand_u8(m02, m10); + + const uint8x8_t m00m13 = vand_u8(m00, m13); + const uint8x8_t m01m12 = vand_u8(m01, m12); + const uint8x8_t m02m11 = vand_u8(m02, m11); + const uint8x8_t m03m10 = vand_u8(m03, m10); + + const uint8x8_t m01m13 = vand_u8(m01, m13); + const uint8x8_t m02m12 = vand_u8(m02, m12); + const uint8x8_t m03m11 = vand_u8(m03, m11); + + const uint8x8_t m02m13 = vand_u8(m02, m13); + const uint8x8_t m03m12 = vand_u8(m03, m12); + + const uint8x8_t m03m13 = vand_u8(m03, m13); + +#define bpcnt(x) vaddv_u8(vcnt_u8(x)) + isum[k] += (1ULL << 0)*(bpcnt(m00m10)) + + (1ULL << 1)*(bpcnt(m00m11) + bpcnt(m01m10)) + + (1ULL << 2)*(bpcnt(m00m12) + bpcnt(m01m11) + bpcnt(m02m10)) + + (1ULL << 3)*(bpcnt(m00m13) + bpcnt(m01m12) + bpcnt(m02m11) + bpcnt(m03m10)) + + (1ULL << 4)*(bpcnt(m01m13) + bpcnt(m02m12) + bpcnt(m03m11)) + + (1ULL << 5)*(bpcnt(m02m13) + bpcnt(m03m12)) + + (1ULL << 6)*(bpcnt(m03m13)); +#undef bpcnt +#else + for (int q0 = 0; q0 < QB; q0++) { + const gq_quant_t mm0 = m0[q0]; + for (int q1 = 0; q1 < QB; q1++) { + const gq_quant_t mm1 = m1[q1]; + isum[k] += (1ULL << (q0 + q1))*(__builtin_popcountll(mm0 & mm1)); + } + } +#endif + } + } - { - const clock_t end = clock(); - const uint64_t end_us = get_time_us(); - printf("%s: elapsed ticks: %ld\n", __func__, end - start); - printf("%s: elapsed us: %d / %f ms\n", __func__, (int)(end_us - start_us), (end_us - start_us) / 1000.0 / nIter); + int32x4_t isumv = vld1q_s32(isum); + + float32x4_t d0v = vld1q_f32(pd0 + i); + float32x4_t d1v = vld1q_f32(pd1 + i); + + float32x4_t sumfv = vmulq_f32(d0v, d1v); + + sumfv = vmulq_f32(sumfv, vcvtq_f32_s32(isumv)); + sumf += vaddvq_f32(sumfv); } +#else +#error "not implemented" +#endif + +#endif + *s = sumf; +} + +// use vec_dot_gq_3 to compute the dot product of two rows +void mul_mat_gq_3( + const void * src0, + const void * src1, // transposed + float * dst, + int m, int n, int k) { + assert(k % QK == 0); + + const int nb = quantize_3_blocks_per_row(k); + const int nq = quantize_3_quants_per_block(); + + for (int ir0 = 0; ir0 < m; ir0++) { + for (int ir1 = 0; ir1 < n; ir1++) { + vec_dot_gq_3(k, dst + ir1, src0, src1); + src1 = (const char *) src1 + quantize_3_row_size(k); + } + src0 = (const char *) src0 + quantize_3_row_size(k); + src1 = (const char *) src1 - n*quantize_3_row_size(k); + + dst = (float *) dst + n; + } +} + +// +// method 4 +// 4-bit quantization +// + +static inline int quantize_4_blocks_per_row(int k) { + return k/QK; +} + +static inline int quantize_4_row_size(int k) { + const int nb = quantize_4_blocks_per_row(k); + + return nb*(2*sizeof(gq_scale_t) + QK/2); +} + +void quantize_4_row(const float * restrict src, void * restrict dst, int k) { + assert(k % QK == 0); + assert(QB == 4); + + const int nb = quantize_4_blocks_per_row(k); + + gq_scale_t * restrict pm = (gq_scale_t *) (dst); + gq_scale_t * restrict pd = (gq_scale_t *) (pm + nb); + uint8_t * restrict pb = (uint8_t *) (pd + nb); + + uint8_t pp[QK/2]; + + for (int i = 0; i < nb; i++) { + memset(pp, 0, sizeof(pp)); + + float min = FLT_MAX; + float max = -FLT_MAX; + +#if defined(__AVX2__) + { + assert(QK == 64); + const int QK8 = QK/8; + + __m256 srcv[QK8]; + __m256 minv[QK8]; + __m256 maxv[QK8]; + + for (int l = 0; l < QK8; l++) { + srcv[l] = _mm256_loadu_ps(src + i*QK + 8*l); + } + + for (int l = 0; l < QK8/2; l++) { + minv[2*l] = _mm256_min_ps(srcv[2*l], srcv[2*l+1]); + maxv[2*l] = _mm256_max_ps(srcv[2*l], srcv[2*l+1]); + } + + for (int l = 0; l < QK8/4; l++) { + minv[4*l] = _mm256_min_ps(minv[4*l], minv[4*l+2]); + maxv[4*l] = _mm256_max_ps(maxv[4*l], maxv[4*l+2]); + } + + for (int l = 0; l < QK8/8; l++) { + minv[8*l] = _mm256_min_ps(minv[8*l], minv[8*l+4]); + maxv[8*l] = _mm256_max_ps(maxv[8*l], maxv[8*l+4]); + } + + //min = MIN(minv[0][0], MIN(minv[0][1], MIN(minv[0][2], MIN(minv[0][3], MIN(minv[0][4], MIN(minv[0][5], MIN(minv[0][6], minv[0][7]))))))); + //max = MAX(maxv[0][0], MAX(maxv[0][1], MAX(maxv[0][2], MAX(maxv[0][3], MAX(maxv[0][4], MAX(maxv[0][5], MAX(maxv[0][6], maxv[0][7]))))))); + + const __m256 minv0_0 = _mm256_permute2f128_ps(minv[0], minv[0], 3); + const __m256 minv0_1 = _mm256_min_ps(minv[0], minv0_0); + const __m256 minv0_2 = _mm256_permute_ps(minv0_1, 0x4e); + const __m256 minv0_3 = _mm256_min_ps(minv0_1, minv0_2); + const __m256 minv0_4 = _mm256_permute_ps(minv0_3, 0xb1); + const __m256 minv0_5 = _mm256_min_ps(minv0_3, minv0_4); + + const __m256 maxv0_0 = _mm256_permute2f128_ps(maxv[0], maxv[0], 3); + const __m256 maxv0_1 = _mm256_max_ps(maxv[0], maxv0_0); + const __m256 maxv0_2 = _mm256_permute_ps(maxv0_1, 0x4e); + const __m256 maxv0_3 = _mm256_max_ps(maxv0_1, maxv0_2); + const __m256 maxv0_4 = _mm256_permute_ps(maxv0_3, 0xb1); + const __m256 maxv0_5 = _mm256_max_ps(maxv0_3, maxv0_4); + + min = _mm256_cvtss_f32(minv0_5); + max = _mm256_cvtss_f32(maxv0_5); + + const float d = (max - min) / ((1 << QB) - 2); + const float id = d ? 1.0/d : 0.0; + + pm[i] = GGML_FP32_TO_GQ(min); + pd[i] = GGML_FP32_TO_GQ(d); + + const __m256 idv = _mm256_set1_ps(id); + + for (int l = 0; l < QK/8; l++) { + __m256 v = _mm256_mul_ps(_mm256_sub_ps(srcv[l], _mm256_set1_ps(min)), idv); +#if 0 + v[0] += frand(); v[1] += frand(); v[2] += frand(); v[3] += frand(); + v[4] += frand(); v[5] += frand(); v[6] += frand(); v[7] += frand(); +#endif + + // convert to uint8 + __m256i vi = _mm256_cvtps_epi32(v); + + uint32_t vi_0 = _mm256_extract_epi32(vi, 0); + uint32_t vi_1 = _mm256_extract_epi32(vi, 1); + uint32_t vi_2 = _mm256_extract_epi32(vi, 2); + uint32_t vi_3 = _mm256_extract_epi32(vi, 3); + + uint32_t vi_4 = _mm256_extract_epi32(vi, 4); + uint32_t vi_5 = _mm256_extract_epi32(vi, 5); + uint32_t vi_6 = _mm256_extract_epi32(vi, 6); + uint32_t vi_7 = _mm256_extract_epi32(vi, 7); + + // convert to 4-bit, 2 consecutive packed into 1 byte + pp[4*l + 0] = vi_0 | (vi_1 << 4); + pp[4*l + 1] = vi_2 | (vi_3 << 4); + pp[4*l + 2] = vi_4 | (vi_5 << 4); + pp[4*l + 3] = vi_6 | (vi_7 << 4); + + //printf("vi: %7d %7d %7d %7d %7d %7d %7d %7d\n", vi_0, vi_1, vi_2, vi_3, vi_4, vi_5, vi_6, vi_7); + //printf("v : %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f\n", v[0], v[1], v[2], v[3], v[4], v[5], v[6], v[7]); + } + + memcpy(pb + i*QK/2, pp, sizeof(pp)); + } +#elif defined(__ARM_NEON) && 0 + { + // TODO + } +#else + { + for (int l = 0; l < QK; l++) { + const float v = src[i*QK + l]; + if (v < min) min = v; + if (v > max) max = v; + } + + const float d = (max - min) / ((1 << QB) - 1); + const float id = d ? 1.0/d : 0.0; + + pm[i] = GGML_FP32_TO_GQ(min); + pd[i] = GGML_FP32_TO_GQ(d); + + for (int l = 0; l < QK; l++) { + const float v = (src[i*QK + l] - min) * id; + const uint8_t vi = (uint8_t) (v + frand()); + pp[l/2] |= (vi & 0xf) << (4*(l & 1)); + } + + memcpy(pb + i*QK/2, pp, sizeof(pp)); + } +#endif + //printf("min %f max %f\n", min, max); + } +} + +// reimplementation of quantize_4 using quantize_4_row +void quantize_4(const float * restrict src, char * restrict dst, int n, int k) { + assert(k % QK == 0); + + for (int j = 0; j < n; j++) { + quantize_4_row(src + j*k, dst, k); + dst = (char *) dst + quantize_4_row_size(k); + } +} + +void vec_dot_gq_4(const int n, float * restrict s, const void * restrict x, const void * restrict y) { + const int nb = quantize_4_blocks_per_row(n); + + const gq_scale_t * restrict pm0 = (const gq_scale_t *) x; + const gq_scale_t * restrict pm1 = (const gq_scale_t *) y; + + const gq_scale_t * restrict pd0 = pm0 + nb; + const gq_scale_t * restrict pd1 = pm1 + nb; + + const uint8_t * restrict pb0 = (const uint8_t *) (pd0 + nb); + const uint8_t * restrict pb1 = (const uint8_t *) (pd1 + nb); + + float sumf = 0.0; + +#if 0 + // scalar + for (int i = 0; i < nb; i++) { + const float m0 = GGML_GQ_TO_FP32(pm0[i]); + const float d0 = GGML_GQ_TO_FP32(pd0[i]); + + const float m1 = GGML_GQ_TO_FP32(pm1[i]); + const float d1 = GGML_GQ_TO_FP32(pd1[i]); + + const uint8_t * restrict p0 = pb0 + i*QK/2; + const uint8_t * restrict p1 = pb1 + i*QK/2; + + for (int j = 0; j < QK/2; j++) { + const uint8_t v0 = p0[j]; + const uint8_t v1 = p1[j]; + + const float f0 = d0*(v0 & 0xf) + m0; + const float f1 = d0*(v0 >> 4) + m0; + + const float f2 = d1*(v1 & 0xf) + m1; + const float f3 = d1*(v1 >> 4) + m1; + + sumf += f0*f2 + f1*f3; + } + } +#else +#if defined(__AVX2__) +#if QK == 64 && 0 + __m256 sumv0 = _mm256_setzero_ps(); + __m256 sumv1 = _mm256_setzero_ps(); + + for (int i = 0; i < nb; i++) { + const float m0 = GGML_GQ_TO_FP32(pm0[i]); + const float d0 = GGML_GQ_TO_FP32(pd0[i]); + + const float m1 = GGML_GQ_TO_FP32(pm1[i]); + const float d1 = GGML_GQ_TO_FP32(pd1[i]); + + const uint8_t * restrict p0 = pb0 + i*QK/2; + const uint8_t * restrict p1 = pb1 + i*QK/2; + + const __m256 m0v = _mm256_set1_ps(m0); + const __m256 d0v = _mm256_set1_ps(d0); + + const __m256 m1v = _mm256_set1_ps(m1); + const __m256 d1v = _mm256_set1_ps(d1); + + const __m256i m4b = _mm256_set1_epi8(0xf); + + __m256i v0 = _mm256_loadu_si256((__m256i *) p0); + + //_mm_prefetch((const char *) (p0 + 32), _MM_HINT_T0); + //_mm_prefetch((const char *) (p1 + 32), _MM_HINT_T0); + //_mm_prefetch((const char *) (pm0 + i + 1), _MM_HINT_T0); + //_mm_prefetch((const char *) (pm1 + i + 1), _MM_HINT_T0); + //_mm_prefetch((const char *) (pd0 + i + 1), _MM_HINT_T0); + //_mm_prefetch((const char *) (pd1 + i + 1), _MM_HINT_T0); + + __m256i v00 = _mm256_and_si256(v0, _mm256_set1_epi32(0x000000FF)); + __m256i v01 = _mm256_srli_epi32(_mm256_and_si256(v0, _mm256_set1_epi32(0x0000FFFF)), 8); + __m256i v02 = _mm256_srli_epi32(_mm256_and_si256(v0, _mm256_set1_epi32(0x00FFFFFF)), 16); + __m256i v03 = _mm256_srli_epi32(v0, 24); + + ////////////////////// + + //{ + // uint32_t vi_0 = _mm256_extract_epi32(v00, 0); + // uint32_t vi_1 = _mm256_extract_epi32(v00, 1); + // uint32_t vi_2 = _mm256_extract_epi32(v00, 2); + // uint32_t vi_3 = _mm256_extract_epi32(v00, 3); + // uint32_t vi_4 = _mm256_extract_epi32(v00, 4); + // uint32_t vi_5 = _mm256_extract_epi32(v00, 5); + // uint32_t vi_6 = _mm256_extract_epi32(v00, 6); + // uint32_t vi_7 = _mm256_extract_epi32(v00, 7); + // printf("v0: %7d %7d %7d %7d %7d %7d %7d %7d\n", vi_0, vi_1, vi_2, vi_3, vi_4, vi_5, vi_6, vi_7); + // printf("p0: %7d %7d %7d %7d %7d %7d %7d %7d\n", p0[0], p0[4], p0[8], p0[12], p0[16], p0[20], p0[24], p0[28]); + // printf("p1: %7d %7d %7d %7d %7d %7d %7d %7d\n", p0[1], p0[5], p0[9], p0[13], p0[17], p0[21], p0[25], p0[29]); + // printf("p2: %7d %7d %7d %7d %7d %7d %7d %7d\n", p0[2], p0[6], p0[10], p0[14], p0[18], p0[22], p0[26], p0[30]); + // printf("p3: %7d %7d %7d %7d %7d %7d %7d %7d\n", p0[3], p0[7], p0[11], p0[15], p0[19], p0[23], p0[27], p0[31]); + //} + + // compute 32 x 4-bit values (low and high) + __m256i v00l = _mm256_and_si256(v00, m4b); + __m256i v01l = _mm256_and_si256(v01, m4b); + __m256i v02l = _mm256_and_si256(v02, m4b); + __m256i v03l = _mm256_and_si256(v03, m4b); + + __m256i v00h = _mm256_srli_epi32(v00, 4); + __m256i v01h = _mm256_srli_epi32(v01, 4); + __m256i v02h = _mm256_srli_epi32(v02, 4); + __m256i v03h = _mm256_srli_epi32(v03, 4); + + //{ + // uint32_t vi_0 = _mm256_extract_epi32(v00l, 0); + // uint32_t vi_1 = _mm256_extract_epi32(v00l, 1); + // uint32_t vi_2 = _mm256_extract_epi32(v00l, 2); + // uint32_t vi_3 = _mm256_extract_epi32(v00l, 3); + // uint32_t vi_4 = _mm256_extract_epi32(v00l, 4); + // uint32_t vi_5 = _mm256_extract_epi32(v00l, 5); + // uint32_t vi_6 = _mm256_extract_epi32(v00l, 6); + // uint32_t vi_7 = _mm256_extract_epi32(v00l, 7); + + // printf("v0l: %7d %7d %7d %7d %7d %7d %7d %7d\n", vi_0, vi_1, vi_2, vi_3, vi_4, vi_5, vi_6, vi_7); + + // vi_0 = _mm256_extract_epi32(v00h, 0); + // vi_1 = _mm256_extract_epi32(v00h, 1); + // vi_2 = _mm256_extract_epi32(v00h, 2); + // vi_3 = _mm256_extract_epi32(v00h, 3); + // vi_4 = _mm256_extract_epi32(v00h, 4); + // vi_5 = _mm256_extract_epi32(v00h, 5); + // vi_6 = _mm256_extract_epi32(v00h, 6); + // vi_7 = _mm256_extract_epi32(v00h, 7); + + // printf("v0h: %7d %7d %7d %7d %7d %7d %7d %7d\n", vi_0, vi_1, vi_2, vi_3, vi_4, vi_5, vi_6, vi_7); + //} + + // convert to float + __m256 vf00l = _mm256_cvtepi32_ps(v00l); + __m256 vf01l = _mm256_cvtepi32_ps(v01l); + __m256 vf02l = _mm256_cvtepi32_ps(v02l); + __m256 vf03l = _mm256_cvtepi32_ps(v03l); + + __m256 vf00h = _mm256_cvtepi32_ps(v00h); + __m256 vf01h = _mm256_cvtepi32_ps(v01h); + __m256 vf02h = _mm256_cvtepi32_ps(v02h); + __m256 vf03h = _mm256_cvtepi32_ps(v03h); + + //{ + // printf("vf00l: %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f\n", vf00l[0], vf00l[1], vf00l[2], vf00l[3], vf00l[4], vf00l[5], vf00l[6], vf00l[7]); + // printf("vf01l: %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f\n", vf01l[0], vf01l[1], vf01l[2], vf01l[3], vf01l[4], vf01l[5], vf01l[6], vf01l[7]); + // printf("vf02l: %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f\n", vf02l[0], vf02l[1], vf02l[2], vf02l[3], vf02l[4], vf02l[5], vf02l[6], vf02l[7]); + // printf("vf03l: %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f\n", vf03l[0], vf03l[1], vf03l[2], vf03l[3], vf03l[4], vf03l[5], vf03l[6], vf03l[7]); + //} + + // multiply by scale and add offset + vf00l = _mm256_fmadd_ps(vf00l, d0v, m0v); + vf01l = _mm256_fmadd_ps(vf01l, d0v, m0v); + vf02l = _mm256_fmadd_ps(vf02l, d0v, m0v); + vf03l = _mm256_fmadd_ps(vf03l, d0v, m0v); + + vf00h = _mm256_fmadd_ps(vf00h, d0v, m0v); + vf01h = _mm256_fmadd_ps(vf01h, d0v, m0v); + vf02h = _mm256_fmadd_ps(vf02h, d0v, m0v); + vf03h = _mm256_fmadd_ps(vf03h, d0v, m0v); + + __m256i v1 = _mm256_loadu_si256((__m256i *) p1); + + __m256i v10 = _mm256_and_si256(v1, _mm256_set1_epi32(0x000000FF)); + __m256i v11 = _mm256_srli_epi32(_mm256_and_si256(v1, _mm256_set1_epi32(0x0000FFFF)), 8); + __m256i v12 = _mm256_srli_epi32(_mm256_and_si256(v1, _mm256_set1_epi32(0x00FFFFFF)), 16); + __m256i v13 = _mm256_srli_epi32(v1, 24); + + __m256i v10l = _mm256_and_si256(v10, m4b); + __m256i v11l = _mm256_and_si256(v11, m4b); + __m256i v12l = _mm256_and_si256(v12, m4b); + __m256i v13l = _mm256_and_si256(v13, m4b); + + __m256i v10h = _mm256_srli_epi32(v10, 4); + __m256i v11h = _mm256_srli_epi32(v11, 4); + __m256i v12h = _mm256_srli_epi32(v12, 4); + __m256i v13h = _mm256_srli_epi32(v13, 4); + + __m256 vf10l = _mm256_cvtepi32_ps(v10l); + __m256 vf11l = _mm256_cvtepi32_ps(v11l); + __m256 vf12l = _mm256_cvtepi32_ps(v12l); + __m256 vf13l = _mm256_cvtepi32_ps(v13l); + + __m256 vf10h = _mm256_cvtepi32_ps(v10h); + __m256 vf11h = _mm256_cvtepi32_ps(v11h); + __m256 vf12h = _mm256_cvtepi32_ps(v12h); + __m256 vf13h = _mm256_cvtepi32_ps(v13h); + + vf10l = _mm256_fmadd_ps(vf10l, d1v, m1v); + vf11l = _mm256_fmadd_ps(vf11l, d1v, m1v); + vf12l = _mm256_fmadd_ps(vf12l, d1v, m1v); + vf13l = _mm256_fmadd_ps(vf13l, d1v, m1v); + + vf10h = _mm256_fmadd_ps(vf10h, d1v, m1v); + vf11h = _mm256_fmadd_ps(vf11h, d1v, m1v); + vf12h = _mm256_fmadd_ps(vf12h, d1v, m1v); + vf13h = _mm256_fmadd_ps(vf13h, d1v, m1v); + + // compute dot product + sumv0 = _mm256_fmadd_ps(vf00l, vf10l, sumv0); + sumv0 = _mm256_fmadd_ps(vf01l, vf11l, sumv0); + sumv0 = _mm256_fmadd_ps(vf02l, vf12l, sumv0); + sumv0 = _mm256_fmadd_ps(vf03l, vf13l, sumv0); + + sumv1 = _mm256_fmadd_ps(vf00h, vf10h, sumv1); + sumv1 = _mm256_fmadd_ps(vf01h, vf11h, sumv1); + sumv1 = _mm256_fmadd_ps(vf02h, vf12h, sumv1); + sumv1 = _mm256_fmadd_ps(vf03h, vf13h, sumv1); + } + + // accumulate (horizontal sum) + const __m256 vdot = _mm256_add_ps(sumv0, sumv1); + const __m128 t0 = _mm_add_ps(_mm256_castps256_ps128(vdot), _mm256_extractf128_ps(vdot, 1)); + const __m128 t1 = _mm_hadd_ps(t0, t0); + + sumf += _mm_cvtss_f32(_mm_hadd_ps(t1, t1)); +#elif QK == 64 && 0 + float sum00 = 0.0f; + float sum01 = 0.0f; + float sum10 = 0.0f; + float sum11 = 0.0f; + + const __m256i m4b = _mm256_set1_epi8(0xf); + + for (int i = 0; i < nb; i++) { + const float m0 = GGML_GQ_TO_FP32(pm0[i]); + const float d0 = GGML_GQ_TO_FP32(pd0[i]); + + const float m1 = GGML_GQ_TO_FP32(pm1[i]); + const float d1 = GGML_GQ_TO_FP32(pd1[i]); + + const uint8_t * restrict p0 = pb0 + i*QK/2; + const uint8_t * restrict p1 = pb1 + i*QK/2; + + // 64 x 4 + const __m256i v0 = _mm256_loadu_si256((__m256i *) p0); + const __m256i v1 = _mm256_loadu_si256((__m256i *) p1); + + // 32 x 8 + const __m256i v0l = _mm256_and_si256(v0, m4b); + const __m256i v1l = _mm256_and_si256(v1, m4b); + + const __m256i v0h = _mm256_and_si256(_mm256_srli_epi16(v0, 4), m4b); + const __m256i v1h = _mm256_and_si256(_mm256_srli_epi16(v1, 4), m4b); + + const __m256i pl = _mm256_maddubs_epi16(v0l, v1l); + const __m256i ph = _mm256_maddubs_epi16(v0h, v1h); + + const __m256i p16 = _mm256_add_epi16(ph, pl); + const __m256i p = _mm256_madd_epi16(_mm256_set1_epi16(1), p16); + + sum00 += m0*m1; + sum01 += m1*d0*(_mm256_hadd_epi8_gg(_mm256_add_epi8(v0l, v0h))); + sum10 += m0*d1*(_mm256_hadd_epi8_gg(_mm256_add_epi8(v1l, v1h))); + sum11 += d0*d1*(_mm256_hadd_epi32_gg(p)); + } + + sumf = 64.0*sum00 + sum01 + sum10 + sum11; +#elif QK == 64 && 1 // this is the best when using min + d + float sum00 = 0.0f; + + __m256 sum01 = _mm256_setzero_ps(); + __m256 sum10 = _mm256_setzero_ps(); + __m256 sum11 = _mm256_setzero_ps(); + + for (int i = 0; i < nb; i++) { + const float m0 = GGML_GQ_TO_FP32(pm0[i]); + const float d0 = GGML_GQ_TO_FP32(pd0[i]); + + const float m1 = GGML_GQ_TO_FP32(pm1[i]); + const float d1 = GGML_GQ_TO_FP32(pd1[i]); + + const uint8_t * restrict p0 = pb0 + i*QK/2; + const uint8_t * restrict p1 = pb1 + i*QK/2; + + const __m256 m0v = _mm256_set1_ps(m0); + const __m256 d0v = _mm256_set1_ps(d0); + + const __m256 m1v = _mm256_set1_ps(m1); + const __m256 d1v = _mm256_set1_ps(d1); + + const __m256 m1d0v = _mm256_mul_ps(m1v, d0v); + const __m256 m0d1v = _mm256_mul_ps(m0v, d1v); + const __m256 d0d1v = _mm256_mul_ps(d0v, d1v); + + const __m256i m4b = _mm256_set1_epi8(0xf); + + // 64 x 4 + const __m256i v0 = _mm256_loadu_si256((__m256i *) p0); + const __m256i v1 = _mm256_loadu_si256((__m256i *) p1); + + // 32 x 8 + const __m256i v0l = _mm256_and_si256(v0, m4b); + const __m256i v1l = _mm256_and_si256(v1, m4b); + + const __m256i v0h = _mm256_and_si256(_mm256_srli_epi16(v0, 4), m4b); + const __m256i v1h = _mm256_and_si256(_mm256_srli_epi16(v1, 4), m4b); + + const __m256i v0a = _mm256_add_epi8(v0l, v0h); + const __m256i v1a = _mm256_add_epi8(v1l, v1h); + + const __m128i v0al = _mm256_extracti128_si256(v0a, 0); + const __m128i v0ah = _mm256_extracti128_si256(v0a, 1); + + const __m128i v1al = _mm256_extracti128_si256(v1a, 0); + const __m128i v1ah = _mm256_extracti128_si256(v1a, 1); + + const __m128i v0as = _mm_add_epi8(v0al, v0ah); + const __m128i v1as = _mm_add_epi8(v1al, v1ah); + + const __m256i v0as_0 = _mm256_cvtepu8_epi32(v0as); + const __m256i v0as_1 = _mm256_cvtepu8_epi32(_mm_srli_si128(v0as, 8)); + + const __m256i v1as_0 = _mm256_cvtepu8_epi32(v1as); + const __m256i v1as_1 = _mm256_cvtepu8_epi32(_mm_srli_si128(v1as, 8)); + + const __m256i v0ass = _mm256_add_epi32(v0as_0, v0as_1); + const __m256i v1ass = _mm256_add_epi32(v1as_0, v1as_1); + + const __m256 v0f = _mm256_cvtepi32_ps(v0ass); + const __m256 v1f = _mm256_cvtepi32_ps(v1ass); + + const __m256i pl = _mm256_maddubs_epi16(v0l, v1l); + const __m256i ph = _mm256_maddubs_epi16(v0h, v1h); + + const __m256i p16 = _mm256_add_epi16(ph, pl); + const __m256i p = _mm256_madd_epi16(_mm256_set1_epi16(1), p16); + + sum00 += m0*m1; + sum01 = _mm256_fmadd_ps(m1d0v, v0f, sum01); + sum10 = _mm256_fmadd_ps(m0d1v, v1f, sum10); + sum11 = _mm256_fmadd_ps(d0d1v, _mm256_cvtepi32_ps(p), sum11); + } + + sumf = 64.0*sum00 + _mm256_hadd_ps_gg(sum01) + _mm256_hadd_ps_gg(sum10) + _mm256_hadd_ps_gg(sum11); +#endif +#elif defined (__ARM_NEON) + float sum00 = 0.0f; + float sum01 = 0.0f; + float sum10 = 0.0f; + float sum11 = 0.0f; + + for (int i = 0; i < nb; i++) { + const float m0 = GGML_GQ_TO_FP32(pm0[i]); + const float d0 = GGML_GQ_TO_FP32(pd0[i]); + + const float m1 = GGML_GQ_TO_FP32(pm1[i]); + const float d1 = GGML_GQ_TO_FP32(pd1[i]); + + const uint8_t * restrict p0 = pb0 + i*QK/2; + const uint8_t * restrict p1 = pb1 + i*QK/2; + + const uint8x16_t m4b = vdupq_n_u8(0xf); + + const uint8x16_t v0_0 = vld1q_u8(p0); + const uint8x16_t v0_1 = vld1q_u8(p0 + 16); + const uint8x16_t v1_0 = vld1q_u8(p1); + const uint8x16_t v1_1 = vld1q_u8(p1 + 16); + + // and with 0xf + const uint8x16_t v0_0l = vandq_u8(v0_0, m4b); + const uint8x16_t v0_1l = vandq_u8(v0_1, m4b); + const uint8x16_t v1_0l = vandq_u8(v1_0, m4b); + const uint8x16_t v1_1l = vandq_u8(v1_1, m4b); + + const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4); + const uint8x16_t v0_1h = vshrq_n_u8(v0_1, 4); + const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4); + const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4); + + // dot product into uint16x8_t + const uint16x8_t pl0l = vmull_u8(vget_low_u8 (v0_0l), vget_low_u8 (v1_0l)); + const uint16x8_t pl0h = vmull_u8(vget_high_u8(v0_0l), vget_high_u8(v1_0l)); + const uint16x8_t pl1l = vmull_u8(vget_low_u8 (v0_1l), vget_low_u8 (v1_1l)); + const uint16x8_t pl1h = vmull_u8(vget_high_u8(v0_1l), vget_high_u8(v1_1l)); + + const uint16x8_t ph0l = vmull_u8(vget_low_u8 (v0_0h), vget_low_u8 (v1_0h)); + const uint16x8_t ph0h = vmull_u8(vget_high_u8(v0_0h), vget_high_u8(v1_0h)); + const uint16x8_t ph1l = vmull_u8(vget_low_u8 (v0_1h), vget_low_u8 (v1_1h)); + const uint16x8_t ph1h = vmull_u8(vget_high_u8(v0_1h), vget_high_u8(v1_1h)); + + const uint16x8_t pl0 = vaddq_u16(pl0l, pl0h); + const uint16x8_t pl1 = vaddq_u16(pl1l, pl1h); + const uint16x8_t ph0 = vaddq_u16(ph0l, ph0h); + const uint16x8_t ph1 = vaddq_u16(ph1l, ph1h); + + const uint16x8_t pl = vaddq_u16(pl0, pl1); + const uint16x8_t ph = vaddq_u16(ph0, ph1); + + sum00 += m0*m1; + sum01 += m1*d0*(vaddvq_u8(v0_0l) + vaddvq_u8(v0_0h) + vaddvq_u8(v0_1l) + vaddvq_u8(v0_1h)); + sum10 += m0*d1*(vaddvq_u8(v1_0l) + vaddvq_u8(v1_0h) + vaddvq_u8(v1_1l) + vaddvq_u8(v1_1h)); + //sum11 += d0*d1*( + // vaddvq_u16(vaddq_u16(vaddq_u16(pl0l, pl0h), vaddq_u16(pl1l, pl1h))) + + // vaddvq_u16(vaddq_u16(vaddq_u16(ph0l, ph0h), vaddq_u16(ph1l, ph1h)))); + sum11 += d0*d1*vaddvq_u16(vaddq_u16(pl, ph)); + } + + sumf = 64.0*sum00 + sum01 + sum10 + sum11; +#endif +#endif + + *s = sumf; +} + +// use vec_dot_gq_4 to compute the dot product of two rows +void mul_mat_gq_4( + const void * src0, + const void * src1, // transposed + float * dst, + int m, int n, int k) { + assert(k % QK == 0); + + const int nb = quantize_4_blocks_per_row(k); + + for (int ir0 = 0; ir0 < m; ir0++) { + for (int ir1 = 0; ir1 < n; ir1++) { + vec_dot_gq_4(k, dst + ir1, src0, src1); + src1 = (const char *) src1 + quantize_4_row_size(k); + } + src0 = (const char *) src0 + quantize_4_row_size(k); + src1 = (const char *) src1 - n*quantize_4_row_size(k); + + dst = (float *) dst + n; + } +} + +// +// method 5 +// 4-bit quantization (without min, only delta) +// + +static inline int quantize_5_blocks_per_row(int k) { + return k/QK; +} + +static inline int quantize_5_row_size(int k) { + const int nb = quantize_5_blocks_per_row(k); + + return nb*(sizeof(gq_scale_t) + QK/2); +} + +void quantize_5_row(const float * restrict src, void * restrict dst, int k) { + assert(k % QK == 0); + assert(QB == 4); + + const int nb = quantize_5_blocks_per_row(k); + + gq_scale_t * restrict pd = (gq_scale_t *) (dst); + uint8_t * restrict pb = (uint8_t *) (pd + nb); + + uint8_t pp[QK/2]; + + for (int i = 0; i < nb; i++) { + memset(pp, 0, sizeof(pp)); + + float amax = 0.0f; // absolute max + +#if defined(__AVX2__) + { + assert(QK == 64); + const int QK8 = QK/8; + + __m256 srcv [QK8]; + __m256 asrcv[QK8]; + __m256 amaxv[QK8]; + + for (int l = 0; l < QK8; l++) { + srcv[l] = _mm256_loadu_ps(src + i*QK + 8*l); + } + + for (int l = 0; l < QK8; l++) { + asrcv[l] = _mm256_and_ps(srcv[l], (__m256) _mm256_set1_epi32(0x7fffffff)); + } + + + for (int l = 0; l < QK8/2; l++) { + amaxv[2*l] = _mm256_max_ps(asrcv[2*l], asrcv[2*l+1]); + } + + for (int l = 0; l < QK8/4; l++) { + amaxv[4*l] = _mm256_max_ps(amaxv[4*l], amaxv[4*l+2]); + } + + for (int l = 0; l < QK8/8; l++) { + amaxv[8*l] = _mm256_max_ps(amaxv[8*l], amaxv[8*l+4]); + } + + //amax = MAX(amaxv[0][0], MAX(amaxv[0][1], MAX(amaxv[0][2], MAX(amaxv[0][3], MAX(amaxv[0][4], MAX(amaxv[0][5], MAX(amaxv[0][6], amaxv[0][7]))))))); + + const __m256 amaxv0_0 = _mm256_permute2f128_ps(amaxv[0], amaxv[0], 3); + const __m256 amaxv0_1 = _mm256_max_ps(amaxv[0], amaxv0_0); + const __m256 amaxv0_2 = _mm256_permute_ps(amaxv0_1, 0x4e); + const __m256 amaxv0_3 = _mm256_max_ps(amaxv0_1, amaxv0_2); + const __m256 amaxv0_4 = _mm256_permute_ps(amaxv0_3, 0xb1); + const __m256 amaxv0_5 = _mm256_max_ps(amaxv0_3, amaxv0_4); + + amax = _mm256_cvtss_f32(amaxv0_5); + + //printf("amax = %f\n", amax); + + const float d = amax / ((1 << (QB - 1)) - 1); + const float id = d ? 1.0/d : 0.0; + + pd[i] = GGML_FP32_TO_GQ(d); + + const __m256 idv = _mm256_set1_ps(id); + + for (int l = 0; l < QK/8; l++) { + __m256 v = _mm256_mul_ps(srcv[l], idv); +#if 0 + v[0] += frand(); v[1] += frand(); v[2] += frand(); v[3] += frand(); + v[4] += frand(); v[5] += frand(); v[6] += frand(); v[7] += frand(); +#endif + + // convert to int8 + __m256i vi = _mm256_cvtps_epi32(v); + vi = _mm256_add_epi32(vi, _mm256_set1_epi32(8)); + + int32_t vi_0 = _mm256_extract_epi32(vi, 0); + int32_t vi_1 = _mm256_extract_epi32(vi, 1); + int32_t vi_2 = _mm256_extract_epi32(vi, 2); + int32_t vi_3 = _mm256_extract_epi32(vi, 3); + + int32_t vi_4 = _mm256_extract_epi32(vi, 4); + int32_t vi_5 = _mm256_extract_epi32(vi, 5); + int32_t vi_6 = _mm256_extract_epi32(vi, 6); + int32_t vi_7 = _mm256_extract_epi32(vi, 7); + + // convert to 4-bit, 2 consecutive packed into 1 byte + pp[4*l + 0] = vi_0 | (vi_1 << 4); + pp[4*l + 1] = vi_2 | (vi_3 << 4); + pp[4*l + 2] = vi_4 | (vi_5 << 4); + pp[4*l + 3] = vi_6 | (vi_7 << 4); + + //printf("vi: %7d %7d %7d %7d %7d %7d %7d %7d\n", vi_0, vi_1, vi_2, vi_3, vi_4, vi_5, vi_6, vi_7); + ////printf("v : %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f\n", v[0], v[1], v[2], v[3], v[4], v[5], v[6], v[7]); + + assert(vi_0 >= 0 && vi_0 < 16); + assert(vi_1 >= 0 && vi_1 < 16); + assert(vi_2 >= 0 && vi_2 < 16); + assert(vi_3 >= 0 && vi_3 < 16); + + assert(vi_4 >= 0 && vi_4 < 16); + assert(vi_5 >= 0 && vi_5 < 16); + assert(vi_6 >= 0 && vi_6 < 16); + assert(vi_7 >= 0 && vi_7 < 16); + } + + memcpy(pb + i*QK/2, pp, sizeof(pp)); + } +#elif defined(__ARM_NEON) && 0 + { + // TODO + } +#else + { + for (int l = 0; l < QK; l++) { + const float v = src[i*QK + l]; + amax = MAX(amax, fabsf(v)); + } + + const float d = amax / ((1 << (QB - 1)) - 1); + const float id = d ? 1.0/d : 0.0; + + pd[i] = GGML_FP32_TO_GQ(d); + + for (int l = 0; l < QK; l++) { + const float v = src[i*QK + l]*id; + const int8_t vi = ((int8_t) (round(v))) + 8; + assert(vi >= 0 && vi < 16); + pp[l/2] |= (vi & 0xf) << (4*(l & 1)); + } + + memcpy(pb + i*QK/2, pp, sizeof(pp)); + } +#endif + //printf("min %f max %f\n", min, max); + } +} + +// reimplementation of quantize_5 using quantize_5_row +void quantize_5(const float * restrict src, char * restrict dst, int n, int k) { + assert(k % QK == 0); + + for (int j = 0; j < n; j++) { + quantize_5_row(src + j*k, dst, k); + dst = (char *) dst + quantize_5_row_size(k); + } +} + +void vec_dot_gq_5(const int n, float * restrict s, const void * restrict x, const void * restrict y) { + const int nb = quantize_5_blocks_per_row(n); + + const gq_scale_t * restrict pd0 = (const gq_scale_t *) x; + const gq_scale_t * restrict pd1 = (const gq_scale_t *) y; + + const uint8_t * restrict pb0 = (const uint8_t *) (pd0 + nb); + const uint8_t * restrict pb1 = (const uint8_t *) (pd1 + nb); + + float sumf = 0.0; + +#if 0 + // scalar + for (int i = 0; i < nb; i++) { + const float d0 = GGML_GQ_TO_FP32(pd0[i]); + const float d1 = GGML_GQ_TO_FP32(pd1[i]); + + const uint8_t * restrict p0 = pb0 + i*QK/2; + const uint8_t * restrict p1 = pb1 + i*QK/2; + + for (int j = 0; j < QK/2; j++) { + const uint8_t v0 = p0[j]; + const uint8_t v1 = p1[j]; + + const float f0 = d0*((int8_t) (v0 & 0xf) - 8); + const float f1 = d0*((int8_t) (v0 >> 4) - 8); + + const float f2 = d1*((int8_t) (v1 & 0xf) - 8); + const float f3 = d1*((int8_t) (v1 >> 4) - 8); + + sumf += f0*f2 + f1*f3; + } + } +#else +#if defined(__AVX2__) +#if QK == 64 && 1 + __m256 sum11 = _mm256_setzero_ps(); + + for (int i = 0; i < nb; i++) { + const float d0 = GGML_GQ_TO_FP32(pd0[i]); + const float d1 = GGML_GQ_TO_FP32(pd1[i]); + + const uint8_t * restrict p0 = pb0 + i*QK/2; + const uint8_t * restrict p1 = pb1 + i*QK/2; + + const __m256 d0v = _mm256_set1_ps(d0); + const __m256 d1v = _mm256_set1_ps(d1); + + const __m256 d0d1v = _mm256_mul_ps(d0v, d1v); + + const __m256i m4b = _mm256_set1_epi8(0xf); + + // 64 x 4 + const __m256i v0 = _mm256_loadu_si256((__m256i *) p0); + const __m256i v1 = _mm256_loadu_si256((__m256i *) p1); + + // 32 x 8 + __m256i v0l = _mm256_and_si256(v0, m4b); + __m256i v1l = _mm256_and_si256(v1, m4b); + + __m256i v0h = _mm256_and_si256(_mm256_srli_epi16(v0, 4), m4b); + __m256i v1h = _mm256_and_si256(_mm256_srli_epi16(v1, 4), m4b); + + // sub 8 + v0l = _mm256_sub_epi8(v0l, _mm256_set1_epi8(8)); + v0h = _mm256_sub_epi8(v0h, _mm256_set1_epi8(8)); + + v1l = _mm256_sub_epi8(v1l, _mm256_set1_epi8(8)); + v1h = _mm256_sub_epi8(v1h, _mm256_set1_epi8(8)); + + // abs + const __m256i v0la = _mm256_sign_epi8(v0l, v0l); + const __m256i v0ha = _mm256_sign_epi8(v0h, v0h); + + // sign + const __m256i v1ls = _mm256_sign_epi8(v1l, v0l); + const __m256i v1hs = _mm256_sign_epi8(v1h, v0h); + + const __m256i pl = _mm256_maddubs_epi16(v0la, v1ls); + const __m256i ph = _mm256_maddubs_epi16(v0ha, v1hs); + + const __m256i p16 = _mm256_add_epi16(ph, pl); + const __m256i p = _mm256_madd_epi16(_mm256_set1_epi16(1), p16); + + sum11 = _mm256_fmadd_ps(d0d1v, _mm256_cvtepi32_ps(p), sum11); + } + + sumf = _mm256_hadd_ps_gg(sum11); +#endif +#elif defined (__ARM_NEON) + float sum11 = 0.0f; + + //float32x4_t sum_0 = vdupq_n_f32(0.0f); + //float32x4_t sum_1 = vdupq_n_f32(0.0f); + + //float16x8_t sum_0 = vdupq_n_f16(0.0f); + //float16x8_t sum_1 = vdupq_n_f16(0.0f); + + for (int i = 0; i < nb; i++) { + const float d0 = GGML_GQ_TO_FP32(pd0[i]); + const float d1 = GGML_GQ_TO_FP32(pd1[i]); + + //float32x4_t d0d1v = vdupq_n_f32(d0*d1); + //float16x8_t d0d1v = vdupq_n_f16(d0*d1); + + const uint8_t * restrict p0 = pb0 + i*QK/2; + const uint8_t * restrict p1 = pb1 + i*QK/2; + + const uint8x16_t m4b = vdupq_n_u8(0xf); + const int8x16_t s8b = vdupq_n_s8(0x8); + + const uint8x16_t v0_0 = vld1q_u8(p0); + const uint8x16_t v0_1 = vld1q_u8(p0 + 16); + const uint8x16_t v1_0 = vld1q_u8(p1); + const uint8x16_t v1_1 = vld1q_u8(p1 + 16); + + // 4-bit -> 8-bit + const uint8x16_t v0_0l = vandq_u8(v0_0, m4b); + const uint8x16_t v0_1l = vandq_u8(v0_1, m4b); + const uint8x16_t v1_0l = vandq_u8(v1_0, m4b); + const uint8x16_t v1_1l = vandq_u8(v1_1, m4b); + + const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4); + const uint8x16_t v0_1h = vshrq_n_u8(v0_1, 4); + const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4); + const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4); + + // sub 8 + const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b); + const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b); + const int8x16_t v1_0ls = vsubq_s8(v1_0l, s8b); + const int8x16_t v1_1ls = vsubq_s8(v1_1l, s8b); + + const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b); + const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b); + const int8x16_t v1_0hs = vsubq_s8(v1_0h, s8b); + const int8x16_t v1_1hs = vsubq_s8(v1_1h, s8b); + + // dot product into int16x8_t + const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls)); + const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls)); + const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls)); + const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls)); + + const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs)); + const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs)); + const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs)); + const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs)); + + const int16x8_t pl0 = vaddq_s16(pl0l, pl0h); + const int16x8_t pl1 = vaddq_s16(pl1l, pl1h); + const int16x8_t ph0 = vaddq_s16(ph0l, ph0h); + const int16x8_t ph1 = vaddq_s16(ph1l, ph1h); + + const int16x8_t pl = vaddq_s16(pl0, pl1); + const int16x8_t ph = vaddq_s16(ph0, ph1); + + //const int8x16_t pl0 = vmulq_s8(v0_0ls, v1_0ls); + //const int8x16_t pl1 = vmulq_s8(v0_1ls, v1_1ls); + //const int8x16_t ph0 = vmulq_s8(v0_0hs, v1_0hs); + //const int8x16_t ph1 = vmulq_s8(v0_1hs, v1_1hs); + + //const int16x8_t pll = vaddl_s8(vget_low_s8(pl0), vget_low_s8(pl1)); + //const int16x8_t plh = vaddl_s8(vget_high_s8(pl0), vget_high_s8(pl1)); + //const int16x8_t phl = vaddl_s8(vget_low_s8(ph0), vget_low_s8(ph1)); + //const int16x8_t phh = vaddl_s8(vget_high_s8(ph0), vget_high_s8(ph1)); + + //const int16x8_t pl = vaddq_s16(pll, plh); + //const int16x8_t ph = vaddq_s16(phl, phh); + + const int16x8_t p = vaddq_s16(pl, ph); + + // convert to float + //const float32x4_t pf0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (p))); + //const float32x4_t pf1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(p))); + + // scalar + sum11 += d0*d1*vaddvq_s16(p); + //sum11 += d0*d1*(vaddvq_s16(pl) + vaddvq_s16(ph)); + //sum11 += d0*d1*vaddvq_s16(vaddq_s16(pl, ph)); + //sum11 += d0*d1*(vaddvq_s8(pl0) + vaddvq_s8(pl1) + vaddvq_s8(ph0) + vaddvq_s8(ph1)); + //sum11 += d0*d1*(vaddvq_s16(pll) + vaddvq_s16(plh) + vaddvq_s16(phl) + vaddvq_s16(phh)); + + //sum_0 = vfmaq_f16(sum_0, d0d1v, vcvtq_f16_s16(p)); + //sum_0 = vfmaq_f16(sum_0, d0d1v, vcvtq_f16_s16(pl)); + //sum_1 = vfmaq_f16(sum_1, d0d1v, vcvtq_f16_s16(ph)); + + // vectorize + //sum_0 = vmlaq_f32(sum_0, d0d1v, pf0); + //sum_1 = vmlaq_f32(sum_1, d0d1v, pf1); + } + + sumf = sum11; + //sumf = vaddvq_f32(sum_0) + vaddvq_f32(sum_1); + //sumf = sum_0[0] + sum_0[1] + sum_0[2] + sum_0[3] + sum_0[4] + sum_0[5] + sum_0[6] + sum_0[7]; + //sum_0 = vaddq_f16(sum_0, sum_1); + //sumf = sum_0[0] + sum_0[1] + sum_0[2] + sum_0[3] + sum_0[4] + sum_0[5] + sum_0[6] + sum_0[7]; +#endif +#endif + + *s = sumf; +} + +// use vec_dot_gq_5 to compute the dot product of two rows +void mul_mat_gq_5( + const void * src0, + const void * src1, // transposed + float * dst, + int m, int n, int k) { + assert(k % QK == 0); + + const int nb = quantize_5_blocks_per_row(k); + + for (int ir0 = 0; ir0 < m; ir0++) { + for (int ir1 = 0; ir1 < n; ir1++) { + vec_dot_gq_5(k, dst + ir1, src0, src1); + src1 = (const char *) src1 + quantize_5_row_size(k); + } + src0 = (const char *) src0 + quantize_5_row_size(k); + src1 = (const char *) src1 - n*quantize_5_row_size(k); + + dst = (float *) dst + n; + } +} + +// +// method 6 +// same as 5 but with 32 element blocks +// + +static inline int quantize_6_blocks_per_row(int k) { + return k/32; +} + +static inline int quantize_6_row_size(int k) { + const int nb = quantize_6_blocks_per_row(k); + + return nb*(sizeof(gq_scale_t) + 16); +} + +void quantize_6_row(const float * restrict src, void * restrict dst, int k) { + assert(k % 32 == 0); + assert(QB == 4); + + const int nb = quantize_6_blocks_per_row(k); + + gq_scale_t * restrict pd = (gq_scale_t *) (dst); + uint8_t * restrict pb = (uint8_t *) (pd + nb); + + uint8_t pp[16]; + + for (int i = 0; i < nb; i++) { + memset(pp, 0, sizeof(pp)); + + float amax = 0.0f; // absolute max + +#if defined(__AVX2__) + { + const int QK8 = 4; + + __m256 srcv [QK8]; + __m256 asrcv[QK8]; + __m256 amaxv[QK8]; + + for (int l = 0; l < QK8; l++) { + srcv[l] = _mm256_loadu_ps(src + i*32 + 8*l); + } + + for (int l = 0; l < QK8; l++) { + asrcv[l] = _mm256_and_ps(srcv[l], (__m256) _mm256_set1_epi32(0x7fffffff)); + } + + for (int l = 0; l < QK8/2; l++) { + amaxv[2*l] = _mm256_max_ps(asrcv[2*l], asrcv[2*l+1]); + } + + for (int l = 0; l < QK8/4; l++) { + amaxv[4*l] = _mm256_max_ps(amaxv[4*l], amaxv[4*l+2]); + } + + const __m256 amaxv0_0 = _mm256_permute2f128_ps(amaxv[0], amaxv[0], 3); + const __m256 amaxv0_1 = _mm256_max_ps(amaxv[0], amaxv0_0); + const __m256 amaxv0_2 = _mm256_permute_ps(amaxv0_1, 0x4e); + const __m256 amaxv0_3 = _mm256_max_ps(amaxv0_1, amaxv0_2); + const __m256 amaxv0_4 = _mm256_permute_ps(amaxv0_3, 0xb1); + const __m256 amaxv0_5 = _mm256_max_ps(amaxv0_3, amaxv0_4); + + amax = _mm256_cvtss_f32(amaxv0_5); + + const float d = amax / ((1 << (QB - 1)) - 1); + const float id = d ? 1.0/d : 0.0; + + pd[i] = GGML_FP32_TO_GQ(d); + + const __m256 idv = _mm256_set1_ps(id); + + for (int l = 0; l < 4; l++) { + __m256 v = _mm256_mul_ps(srcv[l], idv); + + // convert to int8 + __m256i vi = _mm256_cvtps_epi32(v); + vi = _mm256_add_epi32(vi, _mm256_set1_epi32(8)); + + int32_t vi_0 = _mm256_extract_epi32(vi, 0); + int32_t vi_1 = _mm256_extract_epi32(vi, 1); + int32_t vi_2 = _mm256_extract_epi32(vi, 2); + int32_t vi_3 = _mm256_extract_epi32(vi, 3); + + int32_t vi_4 = _mm256_extract_epi32(vi, 4); + int32_t vi_5 = _mm256_extract_epi32(vi, 5); + int32_t vi_6 = _mm256_extract_epi32(vi, 6); + int32_t vi_7 = _mm256_extract_epi32(vi, 7); + + // convert to 4-bit, 2 consecutive packed into 1 byte + pp[4*l + 0] = vi_0 | (vi_1 << 4); + pp[4*l + 1] = vi_2 | (vi_3 << 4); + pp[4*l + 2] = vi_4 | (vi_5 << 4); + pp[4*l + 3] = vi_6 | (vi_7 << 4); + + assert(vi_0 >= 0 && vi_0 < 16); + assert(vi_1 >= 0 && vi_1 < 16); + assert(vi_2 >= 0 && vi_2 < 16); + assert(vi_3 >= 0 && vi_3 < 16); + + assert(vi_4 >= 0 && vi_4 < 16); + assert(vi_5 >= 0 && vi_5 < 16); + assert(vi_6 >= 0 && vi_6 < 16); + assert(vi_7 >= 0 && vi_7 < 16); + } + + memcpy(pb + i*16, pp, sizeof(pp)); + } +#elif defined(__ARM_NEON) + { + float32x4_t srcv [8]; + float32x4_t asrcv[8]; + float32x4_t amaxv[8]; + + for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(src + i*32 + 4*l); + for (int l = 0; l < 8; l++) asrcv[l] = vabsq_f32(srcv[l]); + + for (int l = 0; l < 4; l++) amaxv[2*l] = vmaxq_f32(asrcv[2*l], asrcv[2*l+1]); + for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]); + for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]); + + amax = MAX( + MAX(vgetq_lane_f32(amaxv[0], 0), vgetq_lane_f32(amaxv[0], 1)), + MAX(vgetq_lane_f32(amaxv[0], 2), vgetq_lane_f32(amaxv[0], 3))); + + const float d = amax / ((1 << 3) - 1); + const float id = d ? 1.0/d : 0.0; + + pd[i] = GGML_FP32_TO_GQ(d); + + for (int l = 0; l < 8; l++) { + const float32x4_t v = vmulq_n_f32(srcv[l], id); + const float32x4_t vf = vaddq_f32(v, vdupq_n_f32(8.5f)); + const int32x4_t vi = vcvtq_s32_f32(vf); + + pp[2*l + 0] = vgetq_lane_s32(vi, 0) | (vgetq_lane_s32(vi, 1) << 4); + pp[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4); + } + + memcpy(pb + i*16, pp, sizeof(pp)); + } +#else + { + for (int l = 0; l < 32; l++) { + const float v = src[i*32 + l]; + amax = MAX(amax, fabsf(v)); + } + + const float d = amax / ((1 << (QB - 1)) - 1); + const float id = d ? 1.0/d : 0.0; + + pd[i] = GGML_FP32_TO_GQ(d); + + for (int l = 0; l < 32; l++) { + const float v = src[i*32 + l]*id; + const int8_t vi = ((int8_t) (round(v))) + 8; + assert(vi >= 0 && vi < 16); + pp[l/2] |= (vi & 0xf) << (4*(l & 1)); + } + + memcpy(pb + i*16, pp, sizeof(pp)); + } +#endif + //printf("amax = %f\n", amax); + } +} + +// reimplementation of quantize__6using quantize_6_row +void quantize_6(const float * restrict src, char * restrict dst, int n, int k) { + assert(k % 32 == 0); + + for (int j = 0; j < n; j++) { + quantize_6_row(src + j*k, dst, k); + dst = (char *) dst + quantize_6_row_size(k); + } +} + +void vec_dot_gq_6(const int n, float * restrict s, const void * restrict x, const void * restrict y) { + const int nb = quantize_6_blocks_per_row(n); + + const gq_scale_t * restrict pd0 = (const gq_scale_t *) x; + const gq_scale_t * restrict pd1 = (const gq_scale_t *) y; + + const uint8_t * restrict pb0 = (const uint8_t *) (pd0 + nb); + const uint8_t * restrict pb1 = (const uint8_t *) (pd1 + nb); + + float sumf = 0.0; + +#if 0 + // scalar + for (int i = 0; i < nb; i++) { + const float d0 = GGML_GQ_TO_FP32(pd0[i]); + const float d1 = GGML_GQ_TO_FP32(pd1[i]); + + const uint8_t * restrict p0 = pb0 + i*16; + const uint8_t * restrict p1 = pb1 + i*16; + + for (int j = 0; j < 16; j++) { + const uint8_t v0 = p0[j]; + const uint8_t v1 = p1[j]; + + const float f0 = d0*((int8_t) (v0 & 0xf) - 8); + const float f1 = d0*((int8_t) (v0 >> 4) - 8); + + const float f2 = d1*((int8_t) (v1 & 0xf) - 8); + const float f3 = d1*((int8_t) (v1 >> 4) - 8); + + sumf += f0*f2 + f1*f3; + } + } +#else +#if defined(__AVX2__) + // TODO +#elif defined (__ARM_NEON) +#if 0 + float sum0 = 0.0f; + + for (int i = 0; i < nb; i++) { + const float d0 = GGML_GQ_TO_FP32(pd0[i]); + const float d1 = GGML_GQ_TO_FP32(pd1[i]); + + //float32x4_t d0d1v = vdupq_n_f32(d0*d1); + //float16x8_t d0d1v = vdupq_n_f16(d0*d1); + + const uint8_t * restrict p0 = pb0 + i*16; + const uint8_t * restrict p1 = pb1 + i*16; + + const uint8x16_t m4b = vdupq_n_u8(0xf); + const int8x16_t s8b = vdupq_n_s8(0x8); + + const uint8x16_t v0_0 = vld1q_u8(p0); + const uint8x16_t v1_0 = vld1q_u8(p1); + + // 4-bit -> 8-bit + const uint8x16_t v0_0l = vandq_u8(v0_0, m4b); + const uint8x16_t v1_0l = vandq_u8(v1_0, m4b); + + const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4); + const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4); + + // sub 8 + const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b); + const int8x16_t v1_0ls = vsubq_s8(v1_0l, s8b); + + const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b); + const int8x16_t v1_0hs = vsubq_s8(v1_0h, s8b); + + // dot product into int16x8_t + const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls)); + const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls)); + + const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs)); + const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs)); + + const int16x8_t pl = vaddq_s16(pl0l, pl0h); + const int16x8_t ph = vaddq_s16(ph0l, ph0h); + + const int16x8_t p = vaddq_s16(pl, ph); + + // scalar + sum0 += d0*d1*vaddvq_s16(p); + } + + sumf = sum0; +#elif 1 // this is a bit faster than the above + float sum0 = 0.0f; + float sum1 = 0.0f; + + for (int i = 0; i < nb; i += 2) { + const float d0_0 = GGML_GQ_TO_FP32(pd0[i + 0]); + const float d1_0 = GGML_GQ_TO_FP32(pd1[i + 0]); + const float d0_1 = GGML_GQ_TO_FP32(pd0[i + 1]); + const float d1_1 = GGML_GQ_TO_FP32(pd1[i + 1]); + + const uint8_t * restrict p0 = pb0 + i*16; + const uint8_t * restrict p1 = pb1 + i*16; + + const uint8x16_t m4b = vdupq_n_u8(0xf); + const int8x16_t s8b = vdupq_n_s8(0x8); + + const uint8x16_t v0_0 = vld1q_u8(p0); + const uint8x16_t v0_1 = vld1q_u8(p0 + 16); + const uint8x16_t v1_0 = vld1q_u8(p1); + const uint8x16_t v1_1 = vld1q_u8(p1 + 16); + + // 4-bit -> 8-bit + const uint8x16_t v0_0l = vandq_u8(v0_0, m4b); + const uint8x16_t v1_0l = vandq_u8(v1_0, m4b); + + const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4); + const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4); + + const uint8x16_t v0_1l = vandq_u8(v0_1, m4b); + const uint8x16_t v1_1l = vandq_u8(v1_1, m4b); + + const uint8x16_t v0_1h = vshrq_n_u8(v0_1, 4); + const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4); + + // sub 8 + const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b); + const int8x16_t v1_0ls = vsubq_s8(v1_0l, s8b); + + const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b); + const int8x16_t v1_0hs = vsubq_s8(v1_0h, s8b); + + const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b); + const int8x16_t v1_1ls = vsubq_s8(v1_1l, s8b); + + const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b); + const int8x16_t v1_1hs = vsubq_s8(v1_1h, s8b); + + // dot product into int16x8_t + const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls)); + const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls)); + + const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs)); + const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs)); + + const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls)); + const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls)); + + const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs)); + const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs)); + + const int16x8_t pl_0 = vaddq_s16(pl0l, pl0h); + const int16x8_t ph_0 = vaddq_s16(ph0l, ph0h); + + const int16x8_t pl_1 = vaddq_s16(pl1l, pl1h); + const int16x8_t ph_1 = vaddq_s16(ph1l, ph1h); + + const int16x8_t p_0 = vaddq_s16(pl_0, ph_0); + const int16x8_t p_1 = vaddq_s16(pl_1, ph_1); + + // scalar + sum0 += d0_0*d1_0*vaddvq_s16(p_0); + sum1 += d0_1*d1_1*vaddvq_s16(p_1); + } + + sumf = sum0 + sum1; +#endif +#endif +#endif + + *s = sumf; +} + +// use vec_dot_gq_6 to compute the dot product of two rows +void mul_mat_gq_6( + const void * src0, + const void * src1, // transposed + float * dst, + int m, int n, int k) { + assert(k % 32 == 0); + + const int nb = quantize_6_blocks_per_row(k); + + for (int ir0 = 0; ir0 < m; ir0++) { + for (int ir1 = 0; ir1 < n; ir1++) { + vec_dot_gq_6(k, dst + ir1, src0, src1); + src1 = (const char *) src1 + quantize_6_row_size(k); + } + src0 = (const char *) src0 + quantize_6_row_size(k); + src1 = (const char *) src1 - n*quantize_6_row_size(k); + + dst = (float *) dst + n; + } +} + +int main(int argc, const char ** argv) { + assert(sizeof(gq_quant_t)*8 == gq_t_bits); + + // needed to initialize f16 tables + { + struct ggml_init_params params = { 0, NULL }; + struct ggml_context * ctx = ggml_init(params); + ggml_free(ctx); + } + + int method = 0; + if (argc > 1) { + method = atoi(argv[1]); + } + + float * src0 = (float *)malloc(sizeof(float)*M*K); + float * src1 = (float *)malloc(sizeof(float)*N*K); + float * dst = (float *)malloc(sizeof(float)*M*N); + + // allocate aligned memory + //float * src0 = (float *)aligned_alloc(32, sizeof(float)*M*K); + //float * src1 = (float *)aligned_alloc(32, sizeof(float)*N*K); + //float * dst = (float *)aligned_alloc(32, sizeof(float)*M*N); + + for (int i = 0; i < M*K; i++) { + src0[i] = 0.8 - rand() / (float)RAND_MAX; + /*src0[i] = rand() / (float)RAND_MAX;*/ + /*src0[i] = i % 2;*/ + } + + for (int i = 0; i < N*K; i++) { + src1[i] = 0.8 - rand() / (float)RAND_MAX; + /*src1[i] = rand() / (float)RAND_MAX;*/ + /*src1[i] = i % 3;*/ + } + + void * src0_gq = NULL; + void * src1_gq = NULL; + + size_t sizegq = 0; + + { + if (method == 1) { + src0_gq = calloc(1, quantize_1_row_size(K)*M); + src1_gq = calloc(1, quantize_1_row_size(K)*N); + + sizegq = quantize_1_row_size(K)*M + quantize_1_row_size(K)*N; + } + + if (method == 2) { + src0_gq = calloc(1, quantize_2_row_size(K)*M); + src1_gq = calloc(1, quantize_2_row_size(K)*N); + + sizegq = quantize_2_row_size(K)*M + quantize_2_row_size(K)*N; + } + + if (method == 3) { + src0_gq = calloc(1, quantize_3_row_size(K)*M); + src1_gq = calloc(1, quantize_3_row_size(K)*N); + + sizegq = quantize_3_row_size(K)*M + quantize_3_row_size(K)*N; + } + + if (method == 4) { + src0_gq = calloc(1, quantize_4_row_size(K)*M); + src1_gq = calloc(1, quantize_4_row_size(K)*N); + + sizegq = quantize_4_row_size(K)*M + quantize_4_row_size(K)*N; + } + + if (method == 5) { + src0_gq = calloc(1, quantize_5_row_size(K)*M); + src1_gq = calloc(1, quantize_5_row_size(K)*N); + + sizegq = quantize_5_row_size(K)*M + quantize_5_row_size(K)*N; + } + + if (method == 6) { + src0_gq = calloc(1, quantize_6_row_size(K)*M); + src1_gq = calloc(1, quantize_6_row_size(K)*N); + + sizegq = quantize_6_row_size(K)*M + quantize_6_row_size(K)*N; + } + } + + const size_t sizef16 = sizeof(ggml_fp16_t)*M*K + sizeof(ggml_fp16_t)*N*K; + + printf("compression: %f\n", (float)sizegq/sizef16); + + // convert fp32 -> gq + { + const uint64_t t_start = get_time_us(); + + if (method == 1) { + quantize_1(src0, src0_gq, M, K); + quantize_1(src1, src1_gq, N, K); + } + + if (method == 2) { + quantize_2(src0, src0_gq, M, K); + quantize_2(src1, src1_gq, N, K); + } + + if (method == 3) { + quantize_3(src0, src0_gq, M, K); + quantize_3(src1, src1_gq, N, K); + } + + if (method == 4) { + quantize_4(src0, src0_gq, M, K); + quantize_4(src1, src1_gq, N, K); + } + + if (method == 5) { + quantize_5(src0, src0_gq, M, K); + quantize_5(src1, src1_gq, N, K); + } + + if (method == 6) { + quantize_6(src0, src0_gq, M, K); + quantize_6(src1, src1_gq, N, K); + } + + const uint64_t t_end = get_time_us(); + printf("convert time: %f ms / method = %d\n", (t_end - t_start) / 1000.0, method); + } + + for (int i = 0; i < 16; ++i) { + printf("%f %f\n", src0[i], src1[i]); + } + + const int nIter = 1; + + const clock_t start = clock(); + const uint64_t start_us = get_time_us(); + + double iM = 1.0/M; + double sum = 0.0f; + for (int i = 0; i < nIter; i++) { + if (method == 0) { + mul_mat_f32_naive(src0, src1, dst, M, N, K); + } + + if (method == 1) { + mul_mat_gq_1(src0_gq, src1_gq, dst, M, N, K); + } + + if (method == 2) { + mul_mat_gq_2(src0_gq, src1_gq, dst, M, N, K); + } + + if (method == 3) { + mul_mat_gq_3(src0_gq, src1_gq, dst, M, N, K); + } + + if (method == 4) { + mul_mat_gq_4(src0_gq, src1_gq, dst, M, N, K); + } + + if (method == 5) { + mul_mat_gq_5(src0_gq, src1_gq, dst, M, N, K); + } + + if (method == 6) { + mul_mat_gq_6(src0_gq, src1_gq, dst, M, N, K); + } + } + + for (int i = 0; i < N; i++) { + sum += dst[i]*iM; + } + + { + const clock_t end = clock(); + const uint64_t end_us = get_time_us(); + printf("%s: elapsed ticks: %ld\n", __func__, end - start); + printf("%s: elapsed us: %d / %f ms\n", __func__, (int)(end_us - start_us), (end_us - start_us) / 1000.0 / nIter); + } + +#if 0 + // print src0 + printf("src0:\n"); + for (int i = 0; i < M; i++) { + for (int j = 0; j < K; j++) { + printf("%4.1f ", src0[i*K+j]); + } + printf("\n"); + } + + // print src1 + printf("src1:\n"); + for (int i = 0; i < N; i++) { + for (int j = 0; j < K; j++) { + printf("%4.1f ", src1[i*K+j]); + } + printf("\n"); + } + + printf("dst:\n"); + for (int i = 0; i < M; i++) { + for (int j = 0; j < N; j++) { + printf("%4.1f ", dst[i*N+j]); + } + printf("\n"); + } +#endif printf("%f\n", sum); @@ -468,8 +2584,8 @@ int main(int argc, const char ** argv) { free(src1); free(dst); - free(src0_gq); - free(src1_gq); + if (src0_gq) free(src0_gq); + if (src1_gq) free(src1_gq); return 0; }