diff --git a/examples/training/CMakeLists.txt b/examples/training/CMakeLists.txt index 64afe6ddc647a..08d7ab2479055 100644 --- a/examples/training/CMakeLists.txt +++ b/examples/training/CMakeLists.txt @@ -3,3 +3,9 @@ add_executable(${TARGET} finetune.cpp) install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) target_compile_features(${TARGET} PRIVATE cxx_std_11) + +set(TARGET llama-finetune-lora) +add_executable(${TARGET} finetune-lora.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_11) \ No newline at end of file diff --git a/examples/training/README.md b/examples/training/README.md index df425279266e4..ed255a0e1af3d 100644 --- a/examples/training/README.md +++ b/examples/training/README.md @@ -1,5 +1,6 @@ # llama.cpp/examples/training +## finetune This directory contains examples related to language model training using llama.cpp/GGML. So far finetuning is technically functional (for FP32 models and limited hardware setups) but the code is very much WIP. Finetuning of Stories 260K and LLaMA 3.2 1b seems to work with 24 GB of memory. @@ -15,3 +16,67 @@ export model_name=llama_3.2-1b && export quantization=f32 ``` The perplexity value of the finetuned model should be lower after training on the test set for 2 epochs. + + +## finetune-lora + +LoRA (Low-Rank Adaptation) fine-tuning for efficient model training. This approach trains only a small set of additional parameters while keeping +the base model frozen, making it memory-efficient. + +### Basic Usage + +```sh +# Create new LoRA adapter with default settings (rank=8, alpha=16, attention modules) +./build/bin/llama-finetune-lora -m model.gguf -f dataset.txt -ngl 999 -c 512 -b 512 -ub 512 + +# Custom LoRA parameters(creates new lora adapter and trains it from scratch) +./build/bin/llama-finetune-lora -m model.gguf -f dataset.txt -ngl 999 -c 512 -b 512 -ub 512 \ + --lora-rank 16 --lora-alpha 32 --lora-modules "attn_q,attn_k,attn_v,attn_o" + +# Fine-tune existing LoRA adapter +./build/bin/llama-finetune-lora -m base_model.gguf -f dataset.txt --lora existing_adapter.gguf \ + --output-adapter improved_adapter.gguf -ngl 999 -c 512 -b 512 -ub 512 +``` + + +### Parameters + +#### LoRA Configuration +- `--lora-rank N` - LoRA rank (default: 8) + - Lower rank = smaller adapter, less capacity + - Higher rank = larger adapter, more capacity +- `--lora-alpha N` - LoRA alpha scaling factor (default: 16.0) + - Controls adaptation strength + - Common rule: alpha = 2 × rank +- `--lora-modules MODULES` - Target modules as comma-separated list + - Available: `attn_q`, `attn_k`, `attn_v`, `attn_o`, `ffn_gate`, `ffn_up`, `ffn_down`, `embed`, `output`, `all` + - Default: `attn_q,attn_k,attn_v,attn_o` (attention modules) +- `--output-adapter PATH` - Output adapter filename (default: auto-generated) + +#### Standard Parameters +- `-m MODEL` - Base model file (.gguf) +- `-f FILE` - Training dataset +- `-ngl N` - GPU layers (use 999 for full GPU training) +- `-c N` - Context length (512 recommended for mobile) + + +### Using Trained Adapters + +After training, you'll get a small adapter file. Use it with the original base model: + +```sh +./build/bin/llama-cli -m base_model.gguf --lora trained_adapter.gguf -ngl 999 +``` + +### Troubleshooting + +- **Out of memory**: Reduce context length (`-c 256`), lower rank, or use fewer target modules +- **Poor quality**: Increase rank, add more target modules, or train longer +- **Large adapter**: Reduce rank or limit target modules + +### Help + +Run with `--help` or `-h` to see all available parameters: +```sh +./build/bin/llama-finetune-lora --help +``` diff --git a/examples/training/finetune-lora.cpp b/examples/training/finetune-lora.cpp new file mode 100644 index 0000000000000..8e3a1026b6c91 --- /dev/null +++ b/examples/training/finetune-lora.cpp @@ -0,0 +1,262 @@ +#include "arg.h" +#include "common.h" +#include "log.h" +#include "llama.h" + +#include +#include +#include +#include +#include +#include + +#if defined(_MSC_VER) +#pragma warning(disable: 4244 4267) // possible loss of data +#endif + + +static uint32_t parse_lora_modules(const std::string& modules_str) { + if (modules_str.empty()) { + return LLAMA_LORA_TARGET_ATTN_Q | LLAMA_LORA_TARGET_ATTN_K | LLAMA_LORA_TARGET_ATTN_V | LLAMA_LORA_TARGET_ATTN_O; + } + + static const std::map module_map = { + {"attn_q", LLAMA_LORA_TARGET_ATTN_Q}, + {"attn_k", LLAMA_LORA_TARGET_ATTN_K}, + {"attn_v", LLAMA_LORA_TARGET_ATTN_V}, + {"attn_o", LLAMA_LORA_TARGET_ATTN_O}, + {"ffn_gate", LLAMA_LORA_TARGET_FFN_GATE}, + {"ffn_up", LLAMA_LORA_TARGET_FFN_UP}, + {"ffn_down", LLAMA_LORA_TARGET_FFN_DOWN}, + {"output", LLAMA_LORA_TARGET_OUTPUT}, + {"all", LLAMA_LORA_TARGET_ALL} + }; + + uint32_t target_modules = 0; + std::stringstream ss(modules_str); + std::string module; + + while (std::getline(ss, module, ',')) { + module.erase(0, module.find_first_not_of(" \t")); + module.erase(module.find_last_not_of(" \t") + 1); + + auto it = module_map.find(module); + if (it != module_map.end()) { + target_modules |= it->second; + LOG_INF("Added target module: %s\n", module.c_str()); + } else { + LOG_ERR("Unknown LoRA target module: %s\n", module.c_str()); + LOG_ERR("Available modules: attn_q, attn_k, attn_v, attn_o, ffn_gate, ffn_up, ffn_down, output, all\n"); + return 0; + } + } + + return target_modules; +} + +static void print_lora_usage() { + printf("\nLoRA Fine-tuning Parameters:\n"); + printf(" --lora-rank N LoRA rank (default: 8, range: 1-512)\n"); + printf(" --lora-alpha N LoRA alpha scaling factor (default: 16.0, range: 0.1-1000.0)\n"); + printf(" --lora-modules MODULES Target modules as comma-separated list (default: attn_q,attn_k,attn_v,attn_o)\n"); + printf(" Available modules: attn_q, attn_k, attn_v, attn_o, ffn_gate, ffn_up, ffn_down, output, all\n"); + printf(" Examples: \"attn_q,attn_v\" or \"all\" or \"attn_q,attn_k,attn_v,attn_o,ffn_gate,ffn_up,ffn_down\"\n"); + printf(" --output-adapter PATH Output path for trained adapter (default: auto-generated)\n"); + printf("\nExamples:\n"); + printf(" # Train with rank=16, alpha=32, all attention modules\n"); + printf(" %s -m model.gguf -f dataset.txt --lora-rank 16 --lora-alpha 32 --lora-modules attn_q,attn_k,attn_v,attn_o\n", "finetune-lora"); + printf("\n # Fine-tune existing adapter with all modules\n"); + printf(" %s -m model.gguf -f dataset.txt --lora existing.gguf --output-adapter improved.gguf\n", "finetune-lora"); + printf("\n"); +} + +int main(int argc, char ** argv) { + common_params params; + + int32_t lora_rank = 8; + float lora_alpha = 16.0f; + std::string lora_modules_str; + std::string output_adapter_path; + + params.escape = false; + + auto remove_arg_pair = [&](int i) { + for (int j = i; j < argc - 2; j++) { + argv[j] = argv[j + 2]; + } + argc -= 2; + }; + + for (int i = 1; i < argc - 1; i++) { + if (strcmp(argv[i], "--lora-rank") == 0) { + lora_rank = std::atoi(argv[i + 1]); + remove_arg_pair(i); + i--; + } else if (strcmp(argv[i], "--lora-alpha") == 0) { + lora_alpha = std::atof(argv[i + 1]); + remove_arg_pair(i); + i--; + } else if (strcmp(argv[i], "--lora-modules") == 0) { + lora_modules_str = argv[i + 1]; + remove_arg_pair(i); + i--; + } else if (strcmp(argv[i], "--output-adapter") == 0) { + output_adapter_path = argv[i + 1]; + remove_arg_pair(i); + i--; + } + } + + LOG_INF("Using LoRA parameters: rank=%d, alpha=%.1f\n", lora_rank, lora_alpha); + + for (int i = 1; i < argc; i++) { + if (strcmp(argv[i], "-h") == 0 || strcmp(argv[i], "--help") == 0) { + print_lora_usage(); + } + } + + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PERPLEXITY)) { + print_lora_usage(); + return 1; + } + + if (params.use_mmap) { + LOG_INF("%s: force disabling memory mapping because it would result in-read-only pointers to the weights\n", __func__); + params.use_mmap = false; + } + if (params.cache_type_k != GGML_TYPE_F32) { + LOG_INF("%s: force changing k cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__); + params.cache_type_k = GGML_TYPE_F32; + } + if (params.cache_type_v != GGML_TYPE_F32) { + LOG_INF("%s: force changing v cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__); + params.cache_type_v = GGML_TYPE_F32; + } + + common_init(); + llama_backend_init(); + llama_numa_init(params.numa); + + common_init_result llama_init = common_init_from_params(params); + llama_model_ptr & model = llama_init.model; + llama_context_ptr & ctx = llama_init.context; + + if (model == NULL) { + LOG_ERR("%s: unable to load model\n", __func__); + return 1; + } + + { + LOG_INF("\n"); + LOG_INF("%s\n", common_params_get_system_info(params).c_str()); + } + + uint32_t target_modules = parse_lora_modules(lora_modules_str); + if (target_modules == 0) { + return 1; + } + + struct llama_lora_training_params lora_params = { + /*target_modules =*/ target_modules, + /*rank =*/ lora_rank, + /*alpha =*/ lora_alpha, + /*dropout =*/ 0.0f, + /*init_std =*/ 0.02f, + }; + + bool has_existing_lora = !params.lora_adapters.empty(); + struct llama_adapter_lora * trained_adapter = nullptr; + + if (has_existing_lora) { + LOG_INF("Finetuning existing LoRA adapters\n"); + LOG_INF("Found %zu existing LoRA adapters to train\n", params.lora_adapters.size());\ + trained_adapter = params.lora_adapters[0].ptr; + if (!trained_adapter) { + LOG_ERR("Existing LoRA adapter is null\n"); + return 1; + } + } else { + LOG_INF("Target modules: Q=%s, K=%s, V=%s, O=%s, GATE=%s, UP=%s, DOWN=%s, OUTPUT=%s\n", + (lora_params.target_modules & LLAMA_LORA_TARGET_ATTN_Q) ? "yes" : "no", + (lora_params.target_modules & LLAMA_LORA_TARGET_ATTN_K) ? "yes" : "no", + (lora_params.target_modules & LLAMA_LORA_TARGET_ATTN_V) ? "yes" : "no", + (lora_params.target_modules & LLAMA_LORA_TARGET_ATTN_O) ? "yes" : "no", + (lora_params.target_modules & LLAMA_LORA_TARGET_FFN_GATE) ? "yes" : "no", + (lora_params.target_modules & LLAMA_LORA_TARGET_FFN_UP) ? "yes" : "no", + (lora_params.target_modules & LLAMA_LORA_TARGET_FFN_DOWN) ? "yes" : "no", + (lora_params.target_modules & LLAMA_LORA_TARGET_OUTPUT) ? "yes" : "no"); + + LOG_INF("LoRA configuration: rank=%d, alpha=%.1f (scaling=%.3f)\n", + lora_params.rank, lora_params.alpha, lora_params.alpha / lora_params.rank); + + trained_adapter = llama_lora_training_init(ctx.get(), model.get(), &lora_params); + if (!trained_adapter) { + LOG_ERR("%s: LoRA training initialization failed\n", __func__); + return 1; + } + } + + constexpr float val_split = 0.05f; + + std::vector tokens = common_tokenize(ctx.get(), params.prompt, true); + ggml_opt_dataset_t dataset = common_opt_dataset_init(ctx.get(), tokens, llama_n_ctx(ctx.get())/2); + + struct ggml_opt_optimizer_params optimizer_params = ggml_opt_get_default_optimizer_params(nullptr); + optimizer_params.adamw.alpha = 1e-5f; // learning rate + + struct llama_opt_params lopt_params { + /*n_ctx_train =*/ 0, + /*param_filter =*/ llama_opt_param_filter_lora, + /*param_filter_ud =*/ nullptr, + /*get_opt_pars =*/ ggml_opt_get_constant_optimizer_params, + /*get_opt_pars_ud =*/ &optimizer_params, + }; + llama_opt_init(ctx.get(), model.get(), lopt_params); + + const int64_t idata_split = ggml_opt_dataset_ndata(dataset) * (1.0f - val_split); + + ggml_opt_result_t result_train = ggml_opt_result_init(); + ggml_opt_result_t result_eval = ggml_opt_result_init(); + + for (int epoch = 0; epoch < 2; ++epoch) { + llama_opt_epoch(ctx.get(), dataset, result_train, result_eval, idata_split, + ggml_opt_epoch_callback_progress_bar, ggml_opt_epoch_callback_progress_bar); + fprintf(stderr, "\n"); + + ggml_opt_result_reset(result_train); + ggml_opt_result_reset(result_eval); + } + ggml_opt_result_free(result_train); + ggml_opt_result_free(result_eval); + + std::string adapter_filename; + if (!output_adapter_path.empty()) { + adapter_filename = output_adapter_path; + } else if (has_existing_lora) { + adapter_filename = "finetuned-lora-adapter.gguf"; + LOG_INF("Finetuned existing lora adapter, saving as: %s\n", adapter_filename.c_str()); + } else { + adapter_filename = "trained-lora-adapter.gguf"; + LOG_INF("Saving new lora adapter: %s\n", adapter_filename.c_str()); + } + + if (trained_adapter) { + if (llama_lora_save_adapter(trained_adapter, adapter_filename.c_str(), model.get())) { + std::ifstream adapter_file(adapter_filename, std::ios::binary | std::ios::ate); + if (adapter_file.is_open()) { + std::streamsize adapter_size = adapter_file.tellg(); + LOG_INF("LoRA adapter saved: %s (%.2f MB)\n", + adapter_filename.c_str(), adapter_size / (1024.0 * 1024.0)); + adapter_file.close(); + } + } else { + LOG_ERR("Failed to save LoRA adapter\n"); + } + } else { + LOG_ERR("No trained adapter available for saving\n"); + } + + llama_backend_free(); + + return 0; +} diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 8a8775be36583..3ae9d7c4bf0bd 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -457,6 +457,7 @@ extern "C" { GGML_OP_REPEAT_BACK, GGML_OP_CONCAT, GGML_OP_SILU_BACK, + GGML_OP_GEGLU_BACK, GGML_OP_NORM, // normalize GGML_OP_RMS_NORM, GGML_OP_RMS_NORM_BACK, @@ -1097,6 +1098,12 @@ extern "C" { struct ggml_tensor * a, struct ggml_tensor * b); + GGML_API struct ggml_tensor * ggml_geglu_back( + struct ggml_context * ctx, + struct ggml_tensor * grad, + struct ggml_tensor * x, + struct ggml_tensor * g); + // hardswish(x) = x * relu6(x + 3) / 6 GGML_API struct ggml_tensor * ggml_hardswish( struct ggml_context * ctx, diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index c5271b7757228..f5270179d210c 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -1746,6 +1746,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_silu_back(params, tensor); } break; + case GGML_OP_GEGLU_BACK: + { + ggml_compute_forward_geglu_back(params, tensor); + } break; case GGML_OP_NORM: { ggml_compute_forward_norm(params, tensor); @@ -2182,6 +2186,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { } break; case GGML_OP_SILU_BACK: + case GGML_OP_GEGLU_BACK: case GGML_OP_MUL: case GGML_OP_DIV: case GGML_OP_NORM: diff --git a/ggml/src/ggml-cpu/ggml-cpu.cpp b/ggml/src/ggml-cpu/ggml-cpu.cpp index c9daa4c39e83e..91b1004b5cf3c 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.cpp +++ b/ggml/src/ggml-cpu/ggml-cpu.cpp @@ -442,7 +442,7 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st case GGML_OP_GET_ROWS_BACK: return src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16; case GGML_OP_OUT_PROD: - return (src0->type == GGML_TYPE_F32 || (ggml_is_quantized(src0->type) && src0->ne[2] == src1->ne[2] && src0->ne[3] == src1->ne[3])) && + return (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || (ggml_is_quantized(src0->type) && src0->ne[2] == src1->ne[2] && src0->ne[3] == src1->ne[3])) && src1->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; default: return true; diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 6581d27adde2e..a92f46e151b01 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -3185,6 +3185,71 @@ void ggml_compute_forward_silu_back( } } +static void ggml_compute_forward_geglu_back_f32( + const ggml_compute_params * params, + const struct ggml_tensor * grad, + const struct ggml_tensor * x, + const struct ggml_tensor * g, + struct ggml_tensor * dst) { + + GGML_ASSERT(ggml_can_repeat(grad, dst)); + GGML_ASSERT(ggml_are_same_shape(x, g)); + GGML_ASSERT(grad->type == GGML_TYPE_F32); + GGML_ASSERT(x->type == GGML_TYPE_F32); + GGML_ASSERT(g->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = dst->ne[0]; + + GGML_ASSERT(nc % 2 == 0); + + const size_t nb1 = dst->nb[1]; + const size_t nb2 = dst->nb[2]; + const size_t nb3 = dst->nb[3]; + + for (int i3 = 0; i3 < dst->ne[3]; i3++) { + for (int i2 = 0; i2 < dst->ne[2]; i2++) { + for (int i1 = ith; i1 < dst->ne[1]; i1 += nth) { + float * dst_ptr = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); + const float * grad_ptr = (const float *)((char *) grad->data + i3*grad->nb[3] + i2*grad->nb[2] + i1*grad->nb[1]); + const float * x_ptr = (const float *)((char *) x->data + i3*x->nb[3] + i2*x->nb[2] + i1*x->nb[1]); + const float * g_ptr = (const float *)((char *) g->data + i3*g->nb[3] + i2*g->nb[2] + i1*g->nb[1]); + + const int half = nc / 2; + ggml_vec_gelu_f32(half, dst_ptr, g_ptr); + ggml_vec_mul_f32(half, dst_ptr, dst_ptr, grad_ptr); + float * temp = (float *)alloca(half * sizeof(float)); + ggml_vec_gelu_backward_f32(half, temp, g_ptr, grad_ptr); + ggml_vec_mul_f32(half, dst_ptr + half, temp, x_ptr); + } + } + } +} + +void ggml_compute_forward_geglu_back( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const struct ggml_tensor * grad = dst->src[0]; + const struct ggml_tensor * x = dst->src[1]; + const struct ggml_tensor * g = dst->src[2]; + + switch (dst->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_geglu_back_f32(params, grad, x, g, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + + // ggml_compute_forward_reglu static void ggml_compute_forward_reglu_f32( @@ -4498,6 +4563,107 @@ static void ggml_compute_forward_out_prod_f32( } } +static void ggml_compute_forward_out_prod_f16_f32( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + + const int ith = params->ith; + const int nth = params->nth; + + GGML_ASSERT(ne0 == ne00); + GGML_ASSERT(ne1 == ne10); + GGML_ASSERT(ne2 == ne12); + GGML_ASSERT(ne3 == ne13); + + GGML_ASSERT(ne2 % ne02 == 0); + GGML_ASSERT(ne3 % ne03 == 0); + + // we don't support permuted src0 or src1 + GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + // GGML_ASSERT(nb0 <= nb1); + // GGML_ASSERT(nb1 <= nb2); + // GGML_ASSERT(nb2 <= nb3); + + // nb01 >= nb00 - src0 is not transposed + // compute by src0 rows + + if (ith == 0) { + ggml_vec_set_f32(ne0*ne1*ne2*ne3, (float *)dst->data, 0); + } + ggml_barrier(params->threadpool); + + // dst[:,:,:,:] = 0 + // for i2,i3: + // for i1: + // for i01: + // for i0: + // dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3] + + // parallelize by last three dimensions + + // total rows in dst + const int64_t nr = ne1*ne2*ne3; + + // rows per thread + const int64_t dr = (nr + nth - 1)/nth; + + // row range for this thread + const int64_t ir0 = dr*ith; + const int64_t ir1 = MIN(ir0 + dr, nr); + + // block-tiling attempt + const int64_t blck_0 = MAX(GGML_VEC_MAD_UNROLL, 32); + const int64_t blck_1 = 16; + + // dps == dst per src0, used for group query attention + const int64_t dps2 = ne2 / ne02; + const int64_t dps3 = ne3 / ne03; + + for (int64_t bir = ir0; bir < ir1; bir += blck_1) { + const int64_t bir1 = MIN(bir + blck_1, ir1); + for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) { + const int64_t bne01 = MIN(bi01 + blck_0, ne01); + for (int64_t ir = bir; ir < bir1; ++ir) { + // dst indices + const int64_t i3 = ir/(ne2*ne1); + const int64_t i2 = (ir - i3*ne2*ne1)/ne1; + const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1); + + const int64_t i02 = i2 / dps2; + const int64_t i03 = i3 / dps3; + + //const int64_t i10 = i1; + const int64_t i12 = i2; + const int64_t i13 = i3; + + for (int64_t i01 = bi01; i01 < bne01; ++i01) { + const int64_t i11 = i01; + + ggml_fp16_t * s0 = (ggml_fp16_t *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)); + float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); + float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); + + for (int i = 0; i < ne0; ++i) { + d[i] += GGML_CPU_FP16_TO_FP32(s0[i])*(*s1); + } + } + } + } + } +} + static void ggml_compute_forward_out_prod_q_f32( const ggml_compute_params * params, ggml_tensor * dst) { @@ -4620,9 +4786,8 @@ void ggml_compute_forward_out_prod( } break; case GGML_TYPE_F16: { - GGML_ABORT("fatal error"); // todo - // ggml_compute_forward_out_prod_f16_f32(params, dst); - } + ggml_compute_forward_out_prod_f16_f32(params, dst); + } break; case GGML_TYPE_F32: { ggml_compute_forward_out_prod_f32(params, dst); diff --git a/ggml/src/ggml-cpu/ops.h b/ggml/src/ggml-cpu/ops.h index 3a32ec20dba2b..d1f88ba53eca0 100644 --- a/ggml/src/ggml-cpu/ops.h +++ b/ggml/src/ggml-cpu/ops.h @@ -40,6 +40,7 @@ void ggml_compute_forward_repeat(const struct ggml_compute_params * params, stru void ggml_compute_forward_repeat_back(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_concat(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_silu_back(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_geglu_back(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_norm(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_rms_norm(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_rms_norm_back(const struct ggml_compute_params * params, struct ggml_tensor * dst); diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index d18783a00a1a5..7ad6d44ea8c50 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -944,6 +944,32 @@ inline static void ggml_vec_silu_backward_f16(const int n, ggml_fp16_t * dx, con } } +inline static float ggml_gelu_backward_f32(float x, float dy) { + const float tanh_arg = SQRT_2_OVER_PI * x * (1.0f + GELU_COEF_A * x * x); + const float tanh_val = tanhf(tanh_arg); + const float sech2_val = 1.0f - tanh_val * tanh_val; + const float dtanh_dx = SQRT_2_OVER_PI * (1.0f + 3.0f * GELU_COEF_A * x * x) * sech2_val; + return dy * 0.5f * (1.0f + tanh_val + x * dtanh_dx); +} + +inline static void ggml_vec_gelu_backward_f32(const int n, float * dx, const float * x, const float * dy) { + for (int i = 0; i < n; ++i) { + dx[i] = ggml_gelu_backward_f32(x[i], dy[i]); + } +} + +inline static void ggml_vec_gelu_backward_f16(const int n, ggml_fp16_t * dx, const ggml_fp16_t * x, const ggml_fp16_t * dy) { + for (int i = 0; i < n; ++i) { + float xi = GGML_CPU_FP16_TO_FP32(x[i]); + float tanh_arg = SQRT_2_OVER_PI * xi * (1.0f + GELU_COEF_A * xi * xi); + float tanh_val = tanhf(tanh_arg); + float sech2_val = 1.0f - tanh_val * tanh_val; + float dtanh_dx = SQRT_2_OVER_PI * (1.0f + 3.0f * GELU_COEF_A * xi * xi) * sech2_val; + + dx[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(dy[i]) * 0.5f * (1.0f + tanh_val + xi * dtanh_dx)); + } +} + inline static void ggml_vec_reglu_f32 (const int n, float * y, const float * x, const float * g) { for (int i = 0; i < n; ++i) { y[i] = (x[i] > 0.f) ? x[i] * g[i] : 0.f; diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 50a977c30762c..b0039b93a933a 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3202,7 +3202,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g } } break; case GGML_OP_OUT_PROD: - return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32; + return op->type == GGML_TYPE_F32; case GGML_OP_GET_ROWS: { switch (op->src[0]->type) { diff --git a/ggml/src/ggml-cuda/out-prod.cu b/ggml/src/ggml-cuda/out-prod.cu index c9b2b699c6a55..be73bcb89c9dd 100644 --- a/ggml/src/ggml-cuda/out-prod.cu +++ b/ggml/src/ggml-cuda/out-prod.cu @@ -1,4 +1,5 @@ #include "out-prod.cuh" +#include "convert.cuh" #include @@ -8,10 +9,56 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { GGML_TENSOR_BINARY_OP_LOCALS - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT(src1->type == GGML_TYPE_F32); + const bool src0_is_quantized = (src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16); + const bool src1_is_quantized = (src1->type != GGML_TYPE_F32 && src1->type != GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + cudaStream_t stream = ctx.stream(); + ggml_cuda_pool & pool = ctx.pool(); + + // temp buffers + float * src0_f32 = nullptr; + float * src1_f32 = nullptr; + bool allocated_src0 = false; + bool allocated_src1 = false; + ggml_cuda_pool_alloc src0_alloc(pool); + ggml_cuda_pool_alloc src1_alloc(pool); + + if (src0_is_quantized) { + const size_t src0_size = ggml_nelements(src0); + src0_alloc.alloc(src0_size); + src0_f32 = src0_alloc.ptr; + allocated_src0 = true; + + // Dequantize + auto dequantize_fn = ggml_get_to_fp32_cuda(src0->type); + if (dequantize_fn) { + dequantize_fn(src0->data, src0_f32, ggml_nelements(src0), stream); + } else { + GGML_ABORT("Unsupported quant type for src0"); + } + } else { + src0_f32 = (float *) src0->data; + } + + if (src1_is_quantized) { + const size_t src1_size = ggml_nelements(src1); + src1_alloc.alloc(src1_size); + src1_f32 = src1_alloc.ptr; + allocated_src1 = true; + + auto dequantize_fn = ggml_get_to_fp32_cuda(src1->type); + if (dequantize_fn) { + dequantize_fn(src1->data, src1_f32, ggml_nelements(src1), stream); + } else { + GGML_ABORT("Unsupported quant type for src1"); + } + } else { + src1_f32 = (float *) src1->data; + } + + GGML_ASSERT(ne01 == ne11); GGML_ASSERT(ne0 == ne00); GGML_ASSERT(ne1 == ne10); @@ -22,11 +69,11 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { GGML_ASSERT(ne2 == src1->ne[2]); GGML_ASSERT(ne3 == src1->ne[3]); - const float * src0_d = (const float *) src0->data; - const float * src1_d = (const float *) src1->data; + // Use dequantized data + const float * src0_d = src0_f32; + const float * src1_d = src1_f32; float * dst_d = (float *) dst->data; - cudaStream_t stream = ctx.stream(); cublasHandle_t handle = ctx.cublas_handle(); const float alpha = 1.0f; @@ -34,19 +81,25 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { CUBLAS_CHECK(cublasSetStream(handle, stream)); - const int64_t lda = nb01 / sizeof(float); + const int64_t lda = allocated_src0 ? ne00 : (nb01 / sizeof(float)); const int64_t ldc = nb1 / sizeof(float); const bool src1_T = ggml_is_transposed(src1); const cublasOperation_t src1_cublas_op = src1_T ? CUBLAS_OP_N : CUBLAS_OP_T; - const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float); - GGML_ASSERT( (src1_T ? nb11 : nb10) == sizeof(float)); + const int64_t ldb = allocated_src1 ? + (src1_T ? ne10 : ne11) : + ((src1_T ? nb10 : nb11) / sizeof(float)); + + // Only assert for non dequantized src1 + if (!allocated_src1) { + GGML_ASSERT((src1_T ? nb11 : nb10) == sizeof(float)); + } // data strides in dimensions 2/3 - const size_t s02 = nb02 / sizeof(float); - const size_t s03 = nb03 / sizeof(float); - const size_t s12 = nb12 / sizeof(float); - const size_t s13 = nb13 / sizeof(float); + const size_t s02 = allocated_src0 ? (ne00 * ne01) : nb02 / sizeof(float); + const size_t s03 = allocated_src0 ? (ne00 * ne01 * ne02): nb03 / sizeof(float); + const size_t s12 = allocated_src1 ? (ne10 * ne11) : nb12 / sizeof(float); + const size_t s13 = allocated_src1 ? (ne10 * ne11 * ne12) : nb13 / sizeof(float); const size_t s2 = nb2 / sizeof(float); const size_t s3 = nb3 / sizeof(float); @@ -65,4 +118,5 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { &beta, dst_d + i3 *s3 + i2 *s2, ldc)); } } + } diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 3019a545d58ed..9d8091feeca9f 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -239,6 +239,17 @@ enum FaHeadSizes { FA_HEAD_SIZE_COUNT = FA_HEAD_SIZE_UNSUPPORTED, }; +// XXX: Use value queried from the driver +#if 1 +const uint64_t MAX_ADDRESS_SPACE_SIZE = 1 << 27; +const uint64_t MAX_ADDRESS_SPACE_SIZE_MUL_MAT = 1 << 27; +const uint64_t MAX_ADDRESS_SPACE_SIZE_OUT_PROD = 1 << 27; +#else +const uint64_t MAX_ADDRESS_SPACE_SIZE = 1 << 26; +const uint64_t MAX_ADDRESS_SPACE_SIZE_MUL_MAT = 1 << 26; +const uint64_t MAX_ADDRESS_SPACE_SIZE_OUT_PROD = 1 << 26; +#endif + static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& device) { vk::PhysicalDeviceProperties props = device.getProperties(); @@ -463,7 +474,9 @@ struct vk_device_struct { vk_pipeline pipeline_leaky_relu_f32; vk_pipeline pipeline_silu_back_f32; + vk_pipeline pipeline_geglu_back_f32; vk_pipeline pipeline_diag_mask_inf_f32; + vk_pipeline pipeline_cross_entropy_loss_back_f32; vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16; vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512; vk_pipeline pipeline_soft_max_back_f32; @@ -473,6 +486,10 @@ struct vk_device_struct { vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16; vk_pipeline pipeline_argsort_f32; vk_pipeline pipeline_sum_rows_f32; + vk_pipeline pipeline_out_prod_f32; + vk_pipeline pipeline_out_prod_f16_f32; + vk_pipeline pipeline_out_prod_q4_0; + vk_pipeline pipeline_out_prod_q8_0; vk_pipeline pipeline_argmax_f32; vk_pipeline pipeline_count_equal_i32; vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16; @@ -569,8 +586,8 @@ struct vk_buffer_struct { } VK_LOG_DEBUG("~vk_buffer_struct(" << buffer << ", " << size << ")"); - device->device.freeMemory(device_memory); device->device.destroyBuffer(buffer); + device->device.freeMemory(device_memory); } }; @@ -1020,7 +1037,10 @@ struct ggml_backend_vk_context { size_t semaphore_idx, event_idx; ggml_vk_garbage_collector gc; size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k; + size_t prealloc_size_tile; vk_buffer prealloc_x, prealloc_y, prealloc_split_k; + vk_buffer prealloc_tile; + vk_buffer prealloc_tile_debug; vk::Fence fence, almost_ready_fence; bool almost_ready_fence_pending {}; @@ -1237,11 +1257,13 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin vk::PipelineRobustnessCreateInfoEXT rci; +#if 1 if (device->pipeline_robustness && disable_robustness) { rci.storageBuffers = vk::PipelineRobustnessBufferBehaviorEXT::eDisabled; rci.uniformBuffers = vk::PipelineRobustnessBufferBehaviorEXT::eDisabled; compute_pipeline_create_info.setPNext(&rci); } +#endif try { pipeline->pipeline = device->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value; @@ -1535,7 +1557,8 @@ static void ggml_vk_command_pool_cleanup(vk_device& device, vk_command_pool& p) VK_LOG_DEBUG("ggml_vk_command_pool_cleanup()"); // Requires command buffers to be done - device->device.resetCommandPool(p.pool); + device->device.resetCommandPool(p.pool, vk::CommandPoolResetFlagBits::eReleaseResources); + p.cmd_buffer_idx = 0; } @@ -1659,6 +1682,7 @@ static vk_buffer ggml_vk_create_buffer_check(vk_device& device, size_t size, vk: } static vk_buffer ggml_vk_create_buffer_device(vk_device& device, size_t size) { + VK_LOG_MEMORY("ggml_vk_create_buffer_device(" << size << ")"); vk_buffer buf; try { if (device->prefer_host_memory) { @@ -2442,6 +2466,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1], matmul_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0], matmul_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_TQ2_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_TQ2_0], matmul_tq2_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K], matmul_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K], matmul_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); @@ -2629,6 +2654,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_1_f32_f32_len, mul_mat_vec_q5_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q8_0_f32_f32_len, mul_mat_vec_q8_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_TQ2_0][i], "mul_mat_vec_tq2_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_tq2_0_f32_f32_len, mul_mat_vec_tq2_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q2_k_f32_f32_len, mul_mat_vec_q2_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q3_k_f32_f32_len, mul_mat_vec_q3_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); @@ -2698,6 +2724,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_0], "dequant_q5_0", dequant_q5_0_len, dequant_q5_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_1], "dequant_q5_1", dequant_q5_1_len, dequant_q5_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q8_0], "dequant_q8_0", dequant_q8_0_len, dequant_q8_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_TQ2_0], "dequant_tq2_0", dequant_tq2_0_len, dequant_tq2_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q2_K], "dequant_q2_k", dequant_q2_k_len, dequant_q2_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q3_K], "dequant_q3_k", dequant_q3_k_len, dequant_q3_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_K], "dequant_q4_k", dequant_q4_k_len, dequant_q4_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); @@ -2909,8 +2936,12 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_geglu_back_f32, "geglu_back_f32", geglu_back_f32_len, geglu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_cross_entropy_loss_back_f32, "cross_entropy_loss_back_f32", cross_entropy_loss_back_f32_len, cross_entropy_loss_back_f32_data, "main", 4, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1); ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); @@ -2934,6 +2965,13 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_len, rope_vision_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); } + // TODO: should we have device->subgroup_size here or 0? + ggml_vk_create_pipeline(device, device->pipeline_out_prod_f32, "out_prod_f32", out_prod_f32_len, out_prod_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, { 0 }, 1); + ggml_vk_create_pipeline(device, device->pipeline_out_prod_q4_0, "out_prod_q4_0", out_prod_q4_0_len, out_prod_q4_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, { 0 }, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_out_prod_q8_0, "out_prod_q8_0", out_prod_q8_0_len, out_prod_q8_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, { 0 }, 1, true); + + ggml_vk_create_pipeline(device, device->pipeline_out_prod_f16_f32, "out_prod_f16_f32", out_prod_f16_f32_len, out_prod_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, { 0 }, 1); + ggml_vk_create_pipeline(device, device->pipeline_argsort_f32, "argsort_f32", argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); @@ -3539,10 +3577,10 @@ static vk_device ggml_vk_get_device(size_t idx) { break; #endif default: - device->mul_mat_l[i] = true; + device->mul_mat_l[i] = false; device->mul_mat_m[i] = true; device->mul_mat_s[i] = true; - device->mul_mat_id_l[i] = true; + device->mul_mat_id_l[i] = false; device->mul_mat_id_m[i] = true; device->mul_mat_id_s[i] = true; break; @@ -3953,6 +3991,7 @@ static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) { ctx->prealloc_size_x = 0; ctx->prealloc_size_y = 0; ctx->prealloc_size_split_k = 0; + ctx->prealloc_size_tile = 0; ctx->fence = ctx->device->device.createFence({}); ctx->almost_ready_fence = ctx->device->device.createFence({}); @@ -3977,6 +4016,7 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: + case GGML_TYPE_TQ2_0: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -4047,6 +4087,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: + case GGML_TYPE_TQ2_0: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -4090,6 +4131,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: + case GGML_TYPE_TQ2_0: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -4144,6 +4186,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: + case GGML_TYPE_TQ2_0: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -4179,6 +4222,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: + case GGML_TYPE_TQ2_0: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -4371,11 +4415,13 @@ static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context& const uint32_t wg0 = CEIL_DIV(elements[0], pipeline->wg_denoms[0]); const uint32_t wg1 = CEIL_DIV(elements[1], pipeline->wg_denoms[1]); const uint32_t wg2 = CEIL_DIV(elements[2], pipeline->wg_denoms[2]); - VK_LOG_DEBUG("ggml_vk_dispatch_pipeline(" << pipeline->name << ", {"; +#if 0 + std::cerr << "ggml_vk_dispatch_pipeline(" << pipeline->name << ", {"; for (auto& buffer : descriptor_buffer_infos) { std::cerr << "(" << buffer.buffer << ", " << buffer.offset << ", " << buffer.range << "), "; } - std::cerr << "}, (" << wg0 << "," << wg1 << "," << wg2 << "))"); + std::cerr << "}, (" << wg0 << "," << wg1 << "," << wg2 << "))" << std::endl; +#endif GGML_ASSERT(ctx->descriptor_set_idx < ctx->descriptor_sets.size()); GGML_ASSERT(descriptor_buffer_infos.size() <= MAX_PARAMETER_COUNT); @@ -4727,7 +4773,7 @@ static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_ } static void ggml_vk_buffer_copy_async(vk_context& ctx, vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) { - VK_LOG_DEBUG("ggml_vk_buffer_copy_async(" << size << ")"); + VK_LOG_DEBUG("ggml_vk_buffer_copy_async(" << "dst=" << dst << ", dst_offset=" << dst_offset << ", dst_size=" << dst->size << ", src=" << src << ", src_offset=" << src_offset << ", src_size=" << src->size << ", copy_size=" << size << ")"); // Make sure both buffers are on same device GGML_ASSERT(src->device == dst->device); @@ -4737,6 +4783,10 @@ static void ggml_vk_buffer_copy_async(vk_context& ctx, vk_buffer& dst, size_t ds } static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) { +#if 0 + std::cerr << "ggml_vk_buffer(" << dst << "dst=" << dst << ", dst_offset=" << dst_offset << ", src=" + << src << ", src_offset=" << src_offset << ", size=" << size << ")" << std::endl; +#endif if (src->device == dst->device) { std::lock_guard guard(src->device->mutex); VK_LOG_DEBUG("ggml_vk_buffer_copy(SINGLE_DEVICE, " << size << ")"); @@ -4852,7 +4902,7 @@ static void ggml_vk_matmul( uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d, uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3, uint32_t padded_n) { - VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ", padded_n: " << padded_n << ")"); + VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ", padded_n: " << padded_n << ")"); ggml_vk_sync_buffers(subctx); if (split_k == 1) { const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3, padded_n }; @@ -5066,10 +5116,146 @@ static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& sub ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, std::array{ne}, { ne, 1, 1 }); } +// XXX: leave only this one, we dont need the other two functions +static void ggml_vk_copy_2d_to_2d(vk_context& subctx, vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t width, size_t height, size_t spitch, size_t dpitch) { +#if 1 + VK_LOG_DEBUG("ggml_vk_copy_2d_to_2d(dst=" << dst << ", dst_offset=" << dst_offset + << ", src=" << src << ", src_offset=" << src_offset + << ", width=" << width << ", height=" << height + << ", spitch=" << spitch + << ", dpitch=" << dpitch << ")"); +#endif + + // XXX put this back in + std::lock_guard guard(src->device->mutex); + +#if 1 // TEST + ggml_vk_ctx_end(subctx); + ggml_vk_submit(subctx, src->device->fence); + VK_CHECK(src->device->device.waitForFences({ src->device->fence }, true, UINT64_MAX), "vk wait"); + src->device->device.resetFences({ src->device->fence }); + ggml_vk_command_pool_cleanup(src->device, *subctx->p); + ggml_vk_ctx_begin(src->device, subctx); +#endif + + // Copy within the device + + VkMemoryBarrier memoryBarrier = {}; + memoryBarrier.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER; + memoryBarrier.srcAccessMask = VK_ACCESS_MEMORY_WRITE_BIT; + memoryBarrier.dstAccessMask = VK_ACCESS_MEMORY_READ_BIT; + + // XXX + vkCmdPipelineBarrier(subctx->s->buffer, + VK_PIPELINE_STAGE_ALL_COMMANDS_BIT, + VK_PIPELINE_STAGE_ALL_COMMANDS_BIT, + 0, + 1, &memoryBarrier, + 0, NULL, + 0, NULL + ); + + //int group_size = 1024; + int group_size = height; + int groups = CEIL_DIV(height, group_size); + // TODO: measure this + for (int i = 0; i < groups; i++) { + int group_height = std::min(group_size, (int) height - i * group_size); + std::vector copy_regions; + copy_regions.reserve(group_height); + // TODO: measure this + for (size_t row = i * group_size; row < i * group_size + group_height; ++row) { + size_t row_src_offset = src_offset + row * spitch; + size_t row_dst_offset = dst_offset + row * dpitch; + + // Make sure both buffers are on same device + GGML_ASSERT(src->device == dst->device); + + VkBufferCopy bc{ row_src_offset, row_dst_offset, width }; + copy_regions.push_back(bc); + } + vkCmdCopyBuffer(subctx->s->buffer, (VkBuffer)src->buffer, (VkBuffer)dst->buffer, copy_regions.size(), copy_regions.data()); + } + + // XXX + vkCmdPipelineBarrier(subctx->s->buffer, + VK_PIPELINE_STAGE_ALL_COMMANDS_BIT, + VK_PIPELINE_STAGE_ALL_COMMANDS_BIT, + 0, + 1, &memoryBarrier, + 0, NULL, + 0, NULL + ); + +#if 1 // TEST + ggml_vk_ctx_end(subctx); + ggml_vk_submit(subctx, src->device->fence); + VK_CHECK(src->device->device.waitForFences({ src->device->fence }, true, UINT64_MAX), "vk wait"); + src->device->device.resetFences({ src->device->fence }); + ggml_vk_command_pool_cleanup(src->device, *subctx->p); + ggml_vk_ctx_begin(src->device, subctx); +#endif +} + +// XXX: either specify that this is only for {a, b, d} ops, or make it work for ops with different # of srcs +static uint64_t sum_buffer_sizes(uint64_t m, uint64_t n, uint64_t k, enum ggml_type m_type, enum ggml_type n_type, enum ggml_type d_type) { + uint64_t a_size = CEIL_DIV(m*k, ggml_blck_size(m_type)) * ggml_type_size(m_type); + uint64_t b_size = CEIL_DIV(n*k, ggml_blck_size(n_type)) * ggml_type_size(n_type); + uint64_t d_size = CEIL_DIV(m*n, ggml_blck_size(d_type)) * ggml_type_size(d_type); + + return a_size + b_size + d_size; +} + +// XXX: either specify that this is only for {a, b, d} ops, or make it work for ops with different # of srcs +static void calculate_tile_dims(uint64_t m, uint64_t n, uint64_t k, uint64_t *tile_m, uint64_t *tile_n, uint64_t *m_tiles, uint64_t *n_tiles, uint64_t *num_dispatches, enum ggml_type m_type = GGML_TYPE_F32, enum ggml_type n_type = GGML_TYPE_F32, enum ggml_type d_type = GGML_TYPE_F32) { + // XXX minimum tile size + const uint64_t step = 32; + + // Set starting tile size + uint64_t mt = std::min(step, m); + uint64_t nt = std::min(step, n); + + // XXX The minimum tile size might be already too large (if it is, it's likely because of `k`) + GGML_ASSERT(sum_buffer_sizes(mt, nt, k, m_type, n_type, d_type) < MAX_ADDRESS_SPACE_SIZE); + + while (nt < n || mt < m) { + bool nt_stopped = false; + bool mt_stopped = false; + + if (nt < n) { + uint64_t next_nt = std::min(n, nt + step); + if (sum_buffer_sizes(mt, next_nt, k, m_type, n_type, d_type) < MAX_ADDRESS_SPACE_SIZE) { + nt = next_nt; + } else { + nt_stopped = true; + } + } + + if (mt < m) { + uint64_t next_mt = std::min(m, mt + step); + if (sum_buffer_sizes(next_mt, nt, k, m_type, n_type, d_type) < MAX_ADDRESS_SPACE_SIZE) { + mt = next_mt; + } else { + mt_stopped = true; + } + } + + if ((nt_stopped || nt >= n) && (mt_stopped || mt >= m)) { + break; + } + } + + *tile_m = mt; + *tile_n = nt; + *m_tiles = CEIL_DIV(m, mt); + *n_tiles = CEIL_DIV(n, nt); + *num_dispatches = (*m_tiles) * (*n_tiles); +} + static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { - VK_LOG_DEBUG("ggml_vk_mul_mat_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; - std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; - std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; + VK_LOG_DEBUG("ggml_vk_mul_mat_q_f16((" << src0 << ", name=" << src0->name << ", type=" << ggml_type_name(src0->type) << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; + std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << ggml_type_name(src1->type) << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; + std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << ggml_type_name(dst->type) << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; std::cerr << "), " << (dryrun ? "dryrun" : "") << ")"); GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16); // NOLINT GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT @@ -5086,6 +5272,8 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub const uint64_t ne20 = dst->ne[0]; const uint64_t ne21 = dst->ne[1]; + const uint64_t ne22 = dst->ne[2]; + const uint64_t ne23 = dst->ne[3]; const uint64_t r2 = ne12 / ne02; const uint64_t r3 = ne13 / ne03; @@ -5121,7 +5309,8 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig; - bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0; + // TODO: understand what this means + bool quantize_y = false; //ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0; // Check for mmq first vk_matmul_pipeline mmp = quantize_y ? ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, GGML_TYPE_Q8_1, (ggml_prec)dst->op_params[0]) : nullptr; @@ -5141,7 +5330,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub } // Not implemented - GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT + GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type))); const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && ne11 > 8; @@ -5162,6 +5351,8 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub const uint64_t y_sz = quantize_y ? (y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne); const uint64_t d_sz = sizeof(float) * d_ne; + const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; + vk_pipeline to_fp16_vk_0 = nullptr; vk_pipeline to_fp16_vk_1 = nullptr; vk_pipeline to_q8_1 = nullptr; @@ -5183,10 +5374,122 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1); } + // XXX: Cleanup + auto getenv_u32 = [](const char *name, uint32_t defval) -> uint32_t { + const char *v = getenv(name); + if (!v || !*v) return defval; + char *endp = nullptr; + unsigned long x = strtoul(v, &endp, 10); + if (endp == v) return defval; + if (x > 0xfffffffful) x = 0xfffffffful; + return (uint32_t)x; + }; + + // XXX: Cleanup. + uint32_t a_bytes_per_block, a_bytes_per_unit, a_elems_per_block; + if (qx_needs_dequant) { + a_bytes_per_block = (uint32_t) ggml_type_size(f16_type); + a_elems_per_block = (uint32_t) ggml_blck_size(f16_type); + a_bytes_per_unit = (uint32_t) a_bytes_per_block / a_elems_per_block; + } else { + a_bytes_per_block = (uint32_t) ggml_type_size(src0->type); + a_elems_per_block = (uint32_t) ggml_blck_size(src0->type); + a_bytes_per_unit = (uint32_t) a_bytes_per_block / a_elems_per_block; + } + + uint32_t b_bytes_per_block, b_bytes_per_unit, b_elems_per_block; + if (quantize_y) { + b_bytes_per_block = ggml_type_size(GGML_TYPE_Q8_1); + b_elems_per_block = ggml_blck_size(GGML_TYPE_Q8_1); + b_bytes_per_unit = b_bytes_per_block / b_elems_per_block; + } else if (y_f32_kernel) { + b_bytes_per_block = ggml_type_size(GGML_TYPE_F32); + b_elems_per_block = ggml_blck_size(GGML_TYPE_F32); + b_bytes_per_unit = b_bytes_per_block / b_elems_per_block; + } else if (qy_needs_dequant) { + b_bytes_per_block = ggml_type_size(f16_type); + b_elems_per_block = ggml_blck_size(f16_type); + b_bytes_per_unit = b_bytes_per_block / b_elems_per_block; + } else { + b_bytes_per_block = ggml_type_size(src1->type); + b_elems_per_block = ggml_blck_size(src1->type); + b_bytes_per_unit = b_bytes_per_block / b_elems_per_block; + } + + uint32_t d_bytes_per_block, d_bytes_per_unit, d_elems_per_block; + d_bytes_per_block = ggml_type_size(GGML_TYPE_F32); + d_elems_per_block = ggml_blck_size(GGML_TYPE_F32); + d_bytes_per_unit = d_bytes_per_block / d_elems_per_block; + + // XXX + const uint64_t tiling_threshold = MAX_ADDRESS_SPACE_SIZE_MUL_MAT; + + // XXX: Cleanup. + const bool tiling_debug = getenv_u32("GGML_TILING_DEBUG", 0); + const bool tiling_enabled = getenv_u32("GGML_TILING_ENABLE", 0); + + // XXX: Cleanup. + bool do_tiling = + ne02 == 1 && ne03 == 1 && ne12 == 1 && ne13 == 1 && // XXX: DEBUG + tiling_enabled && +#if 0 + ((vk_tensor_offset(src0) + src0->view_offs + x_sz * ne02 * ne03) >= tiling_threshold || + (vk_tensor_offset(src1) + src1->view_offs + y_sz * ne12 * ne13) >= tiling_threshold || + (d_buf_offset + d_sz * ne12 * ne13) >= tiling_threshold); +#else + (x_sz * ne02 * ne03 + y_sz * ne12 * ne13 + + d_sz * ne12 * ne13 >= tiling_threshold); +#endif + + if (tiling_debug) { + fprintf(stderr, "tiling enabled ? %d (%lu > %lu ?)\n", do_tiling, x_sz * ne02 * ne03 + y_sz * ne12 * ne13 + d_sz * ne12 * ne13, tiling_threshold); + } + + // XXX + bool do_splitting = false; +#if 0 + if (do_tiling && (x_sz * ne02 * ne03) + (y_sz * ne12 * ne13) + (d_sz * ne12 * ne13) < tiling_threshold) { + do_splitting = true; + do_tiling = false; + } +#endif + if (do_splitting && tiling_debug) { + fprintf(stderr, "[VK] [MUL_MAT] [SPLITTING] (total sum %lu + %lu + %lu) >= %lu\n", (x_sz * ne02 * ne03), (y_sz * ne12 * ne13), (d_sz * ne12 * ne13), tiling_threshold); + } + + uint64_t tile_m = 0, tile_n = 0; + uint64_t m_tiles = 0, n_tiles = 0; + uint64_t num_dispatches = 0; + if (do_tiling) { + // XXX + GGML_ASSERT(ne02 == 1); + GGML_ASSERT(ne03 == 1); + GGML_ASSERT(ne12 == 1); + GGML_ASSERT(ne13 == 1); + + calculate_tile_dims(ne01, ne11, ne00, &tile_m, &tile_n, &m_tiles, &n_tiles, &num_dispatches); + if (tiling_debug) { + fprintf(stderr, "[VK] [TILING] tile_m=%lu, tile_n=%lu, m_tiles=%lu, n_tiles=%lu, num_dispatches=%lu\n", tile_m, tile_n, m_tiles, n_tiles, num_dispatches); + } + } + if (dryrun) { + // XXX: Cleanup + // Allocate buffers for tiling (for buffers that are too large to bind fully) + if (do_splitting || do_tiling) { + ctx->prealloc_size_tile = MAX_ADDRESS_SPACE_SIZE; + } + + if (do_tiling) { + ggml_pipeline_request_descriptor_sets(ctx, pipeline, num_dispatches); + } else { + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + } + const uint64_t x_sz_upd = x_sz * ne02 * ne03; const uint64_t y_sz_upd = y_sz * ne12 * ne13; const uint64_t split_k_size = split_k > 1 ? d_sz * ne12 * ne13 * split_k : 0; + GGML_ASSERT(split_k <= 1); if ( (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) || (qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size) || @@ -5203,8 +5506,6 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub ctx->prealloc_size_split_k = split_k_size; } - // Request descriptor sets - ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); if (qx_needs_dequant) { ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1); } @@ -5214,6 +5515,8 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub if (quantize_y) { ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1); } + // XXX: TODO: test split_k + // XXX: TODO: maybe add !do_splitting here as well if (split_k > 1) { ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, 1); } @@ -5221,7 +5524,6 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub } vk_buffer d_D = dst_buf_ctx->dev_buffer; - const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; GGML_ASSERT(d_D != nullptr); GGML_ASSERT(d_D->size >= d_buf_offset + d_sz * ne02 * ne03); vk_buffer d_X; @@ -5283,15 +5585,232 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub stride_batch_y = src1->nb[0] / ggml_type_size(src1->type); } - // compute - ggml_vk_matmul( - ctx, subctx, pipeline, - { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 }, - { d_D, d_buf_offset, d_sz * ne12 * ne13 }, { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k }, - ne01, ne11, ne10, - ne10, ne10, ne01, stride_batch_x, stride_batch_y, ne20*ne21, - split_k, ne12*ne13, ne02, ne12, r2, r3, padded_n - ); // NOLINT + if (!do_tiling && !do_splitting) { + ggml_vk_matmul( + ctx, subctx, pipeline, + { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 }, + { d_D, d_buf_offset, d_sz * ne12 * ne13 }, { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k }, + ne01, ne11, ne10, + ne10, ne10, ne01, stride_batch_x, stride_batch_y, ne20*ne21, + split_k, ne12*ne13, ne02, ne12, r2, r3, padded_n + ); + } else if (do_splitting) { + if (tiling_debug) { + std::cerr << "[VK] splitting!!" << std::endl; + } + GGML_ASSERT(false); + + const uint64_t a_size = x_sz * ne02 * ne03; + const uint64_t b_size = y_sz * ne12 * ne13; + const uint64_t d_size = d_sz * ne12 * ne13; + const uint64_t a_off = 0; + const uint64_t b_off = a_size; + const uint64_t d_off = a_size + b_size; + ggml_vk_copy_2d_to_2d(subctx, ctx->prealloc_tile, 0, d_X, x_buf_offset, a_size, 1, 1, 1); + ggml_vk_copy_2d_to_2d(subctx, ctx->prealloc_tile, b_off, d_Y, y_buf_offset, b_size, 1, 1, 1); + + // XXX TODO: account for split_k in the total buffer size sum?? + GGML_ASSERT(split_k <= 1); + +#if 0 + vk_context splitctx = ggml_vk_create_temporary_context(ctx->device->transfer_queue.cmd_pool); + ggml_vk_ctx_begin(ctx->device, splitctx); + + ggml_vk_begin_submission(ctx->device, *splitctx->p); + + ggml_vk_matmul( + ctx, splitctx, pipeline, + { ctx->prealloc_tile, 0, a_size }, { ctx->prealloc_tile, b_off, b_size }, + { ctx->prealloc_tile, d_off, d_size }, { nullptr, 0, 0 }, + ne01, ne11, ne10, + ne10, ne10, ne01, stride_batch_x, stride_batch_y, ne20*ne21, + 1u, ne12*ne13, ne02, ne12, r2, r3, padded_n + ); + + vk::SubmitInfo submit_info{}; + submit_info.commandBufferCount = 1; + submit_info.pCommandBuffers = &splitctx->s->buffer; + + ggml_vk_ctx_end(splitctx); + + splitctx->p->q->queue.submit(submit_info, ctx->matmul_fence); + + VK_CHECK(ctx->device->device.waitForFences({ ctx->matmul_fence }, true, UINT64_MAX), "matmul_fence"); + ctx->device->device.resetFences({ ctx->matmul_fence }); + + ggml_vk_copy_2d_to_2d(subctx, d_D, d_buf_offset, ctx->prealloc_tile, d_off, d_size, 1, 1, 1); +#endif + } else { // tiling + GGML_ASSERT(ne02 == 1); + GGML_ASSERT(ne03 == 1); + GGML_ASSERT(ne12 == 1); + GGML_ASSERT(ne13 == 1); + + // XXX TODO: account for split_k in the total buffer size sum?? + GGML_ASSERT(split_k <= 1); + + // XXX which value to use? is 1 good enough? + const uint32_t align = 1; +#if 0 + const uint32_t off_align = 256; +#else + const uint32_t off_align = 1; +#endif + + if (tiling_debug) { + std::cerr << "[VK] mul_mat tiled: M=" << ne01 << " N=" << ne11 << " K=" << ne10 << " tile_m=" << tile_m << " tile_n=" << tile_n << " batches=" << (ne12*ne13) << " split_k=1" << std::endl; + } + + // XXX make it work for n12 > 1 + for (uint32_t n0 = 0; n0 < ne11; n0 += tile_n) { + const uint32_t nt = (uint32_t) std::min(tile_n, ne11 - n0); + // XXX I think these shouldn't have CEIL_DIV + const uint64_t b_off_bytes = y_buf_offset + CEIL_DIV(((uint64_t)n0 * (uint64_t)ne10), b_elems_per_block) * (uint64_t)b_bytes_per_block; + const uint64_t d_off_bytes_n = d_buf_offset + CEIL_DIV(((uint64_t)n0 * (uint64_t)ne20), d_elems_per_block) * (uint64_t)d_bytes_per_block; + + // How many are safe to read within the current n-tile. + // XXX +#if 1 + const uint32_t padded_n_tile = 0; +#else + const uint32_t padded_n_tile = nt; +#endif + + // XXX make it work for n02 > 1 + for (uint32_t m0 = 0; m0 < ne01; m0 += tile_m) { + const uint32_t mt = (uint32_t) std::min(tile_m, ne01 - m0); + // XXX I think these shouldn't have CEIL_DIV + const uint64_t a_off_bytes = x_buf_offset + CEIL_DIV(((uint64_t)m0 * (uint64_t)ne00), a_elems_per_block) * (uint64_t)a_bytes_per_block; + const uint64_t d_off_bytes = d_off_bytes_n + CEIL_DIV((uint64_t)m0, d_elems_per_block) * (uint64_t)d_bytes_per_block; + + if (tiling_debug) { + std::cerr << "[VK] tile m=[" << m0 << "," << (m0+mt) << ") n=[" << n0 << "," << (n0+nt) << ") padded_n_tile=" << padded_n_tile << std::endl; + } + + const uint64_t a_k_bytes = CEIL_DIV(ne00, a_elems_per_block) * a_bytes_per_block; + const uint64_t b_k_bytes = CEIL_DIV(ne10, b_elems_per_block) * b_bytes_per_block; + const uint64_t orig_stride_a_bytes = a_k_bytes; + const uint64_t orig_stride_b_bytes = b_k_bytes; + const uint64_t orig_stride_d_bytes = CEIL_DIV(ne20, d_elems_per_block) * d_bytes_per_block; + const uint64_t tile_stride_a_bytes = ggml_vk_align_size(orig_stride_a_bytes, align); + const uint64_t tile_stride_b_bytes = ggml_vk_align_size(orig_stride_b_bytes, align); + const uint64_t tile_stride_d_bytes = ggml_vk_align_size(CEIL_DIV(mt, d_elems_per_block) * d_bytes_per_block, align); + + const uint32_t tile_stride_a_elems = (tile_stride_a_bytes / a_bytes_per_block) * a_elems_per_block; + const uint32_t tile_stride_b_elems = (tile_stride_b_bytes / b_bytes_per_block) * b_elems_per_block; + const uint32_t tile_stride_d_elems = (tile_stride_d_bytes / d_bytes_per_block) * d_elems_per_block; + const uint64_t dst_copy_row_size = CEIL_DIV(mt, d_elems_per_block) * d_bytes_per_block; + + const uint64_t a_size_elems = tile_stride_a_elems * mt; + const uint64_t b_size_elems = tile_stride_b_elems * nt; + const uint64_t d_size_elems = tile_stride_d_elems * nt; + + const uint64_t a_size_bytes = tile_stride_a_bytes * mt; + const uint64_t b_size_bytes = tile_stride_b_bytes * nt; + const uint64_t d_size_bytes = tile_stride_d_bytes * nt; +#if 0 + const uint64_t a_off = ggml_vk_align_size(a_off, off_align); + const uint64_t b_off = ggml_vk_align_size(a_off + a_size_bytes, off_align); + const uint64_t d_off = ggml_vk_align_size(b_off + b_size_bytes, off_align); +#else + const uint64_t a_off = 0; + const uint64_t b_off = a_off + a_size_bytes; + const uint64_t d_off = b_off + b_size_bytes; +#endif + + GGML_ASSERT(a_size_bytes + b_size_bytes + d_size_bytes < MAX_ADDRESS_SPACE_SIZE_MUL_MAT); + GGML_ASSERT(d_off + d_size_bytes < MAX_ADDRESS_SPACE_SIZE_MUL_MAT); + + //ggml_vk_sync_buffers(subctx); + + // XXX + VkMemoryBarrier memoryBarrier = {}; + memoryBarrier.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER; + memoryBarrier.srcAccessMask = VK_ACCESS_MEMORY_WRITE_BIT; + memoryBarrier.dstAccessMask = VK_ACCESS_MEMORY_READ_BIT; + + // XXX + vkCmdPipelineBarrier(subctx->s->buffer, + VK_PIPELINE_STAGE_ALL_COMMANDS_BIT, + VK_PIPELINE_STAGE_ALL_COMMANDS_BIT, + 0, + 1, &memoryBarrier, + 0, NULL, + 0, NULL + ); + + // Copy data to tile buffers + ggml_vk_sync_buffers(subctx); + ggml_vk_copy_2d_to_2d(subctx, ctx->prealloc_tile, a_off, d_X, a_off_bytes, a_k_bytes, mt, orig_stride_a_bytes, tile_stride_a_bytes); + ggml_vk_sync_buffers(subctx); + ggml_vk_copy_2d_to_2d(subctx, ctx->prealloc_tile, b_off, d_Y, b_off_bytes, b_k_bytes, nt, orig_stride_b_bytes, tile_stride_b_bytes); + ggml_vk_sync_buffers(subctx); + + if (tiling_debug) { + std::cerr << "[VK] " + << " tile_stride_a_elems = " << tile_stride_a_elems << " | tile_stride_a_bytes = " << tile_stride_a_bytes << "\n" + << " tile_stride_b_elems = " << tile_stride_b_elems << " | tile_stride_b_bytes = " << tile_stride_b_bytes << "\n" + << " tile_stride_d_elems = " << tile_stride_d_elems << " | tile_stride_d_bytes = " << tile_stride_d_bytes << "\n" + << " m = " << ne01 << " | n = " << ne11 << " | k = " << ne00 << "\n" + << " mt = " << mt << " | nt = " << nt << " | kt = " << ne00 << "\n" + << " a_orig_off = " << a_off_bytes << " | b_orig_off = " << b_off_bytes << " | d_orig_off = " << d_off_bytes << "\n" + << " a_tile_off = " << a_off << " | b_tile_off = " << b_off << " | d_tile_off = " << d_off << "\n" + << " a_tile_size = " << a_size_bytes << " | b_tile_size = " << b_size_bytes << " | d_tile_size = " << d_size_bytes << "\n" + << " a_k_bytes = " << a_k_bytes << " | b_k_bytes = " << b_k_bytes << "\n" + << " dst_copy_row_size = " << dst_copy_row_size << "\n" + << std::endl; + } + + // XXX + vkCmdPipelineBarrier(subctx->s->buffer, + VK_PIPELINE_STAGE_ALL_COMMANDS_BIT, + VK_PIPELINE_STAGE_ALL_COMMANDS_BIT, + 0, + 1, &memoryBarrier, + 0, NULL, + 0, NULL + ); + + // XXX + ggml_vk_matmul( + ctx, subctx, pipeline, + { ctx->prealloc_tile, a_off, a_size_bytes }, { ctx->prealloc_tile, b_off, b_size_bytes }, + { ctx->prealloc_tile, d_off, d_size_bytes }, { nullptr, 0, 0 }, + mt, nt, (uint32_t)ne00, + tile_stride_a_elems, tile_stride_b_elems, tile_stride_d_elems, + tile_stride_a_elems*mt, tile_stride_b_elems*nt, tile_stride_d_elems*nt, + 1u, (uint32_t)(ne12*ne13), (uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3, padded_n_tile + ); + + // XXX + vkCmdPipelineBarrier(subctx->s->buffer, + VK_PIPELINE_STAGE_ALL_COMMANDS_BIT, + VK_PIPELINE_STAGE_ALL_COMMANDS_BIT, + 0, + 1, &memoryBarrier, + 0, NULL, + 0, NULL + ); + + ggml_vk_sync_buffers(subctx); + + // Copy results back to dst buffer + ggml_vk_copy_2d_to_2d(subctx, d_D, d_off_bytes, ctx->prealloc_tile, d_off, dst_copy_row_size, nt, tile_stride_d_bytes, orig_stride_d_bytes); + + ggml_vk_sync_buffers(subctx); + + // XXX + vkCmdPipelineBarrier(subctx->s->buffer, + VK_PIPELINE_STAGE_ALL_COMMANDS_BIT, + VK_PIPELINE_STAGE_ALL_COMMANDS_BIT, + 0, + 1, &memoryBarrier, + 0, NULL, + 0, NULL + ); + } + } + } } static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { @@ -5353,7 +5872,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig; // Not implemented - GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT + GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT const uint64_t x_ne = ne01 * ne00; const uint64_t y_ne = ne11 * ne10; @@ -5467,7 +5986,9 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& if (ne01 > max_groups_x) { groups_z = 64; + //groups_z = 96; groups_x = CEIL_DIV(groups_x, groups_z); + GGML_ASSERT(max_groups_x > groups_x); } // compute @@ -6613,6 +7134,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_silu_back_f32; } return nullptr; + case GGML_OP_GEGLU_BACK: + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_geglu_back_f32; + } + return nullptr; case GGML_OP_NORM: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_norm_f32; @@ -6691,6 +7217,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_diag_mask_inf_f32; } return nullptr; + case GGML_OP_CROSS_ENTROPY_LOSS_BACK: + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && src2->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_cross_entropy_loss_back_f32; + } + return nullptr; case GGML_OP_SOFT_MAX: GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); @@ -6745,6 +7276,16 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const } return nullptr; } + case GGML_OP_OUT_PROD: + if (dst->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { + if (src0->type == GGML_TYPE_F32) return ctx->device->pipeline_out_prod_f32; + if (src0->type == GGML_TYPE_Q4_0) return ctx->device->pipeline_out_prod_q4_0; + if (src0->type == GGML_TYPE_Q8_0) return ctx->device->pipeline_out_prod_q8_0; + } + if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_out_prod_f16_f32; + } + return nullptr; case GGML_OP_ARGSORT: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) { return ctx->device->pipeline_argsort_f32; @@ -6829,6 +7370,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) { switch (op) { case GGML_OP_CPY: case GGML_OP_GET_ROWS: + case GGML_OP_OUT_PROD: case GGML_OP_ADD: case GGML_OP_SUB: case GGML_OP_MUL: @@ -6915,7 +7457,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co } std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; std::cerr << "), " << ggml_op_name(op) << ", " << (dryrun ? "dryrun" : "") << ")"); - GGML_ASSERT(op == GGML_OP_GET_ROWS || op == GGML_OP_CPY || (!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type)))); // NOLINT + GGML_ASSERT(op == GGML_OP_GET_ROWS || op == GGML_OP_CPY || op == GGML_OP_OUT_PROD || (!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type)))); // NOLINT GGML_ASSERT(ggml_vk_op_supports_incontiguous(op) || ggml_vk_dim01_contiguous(src0)); // NOLINT GGML_ASSERT(dst->buffer != nullptr); const uint64_t ne00 = src0->ne[0]; @@ -6945,6 +7487,19 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co const uint64_t ned3 = dst->ne[3]; const uint64_t ned = ned0 * ned1; + + // XXX: Cleanup + auto getenv_u32 = [](const char *name, uint32_t defval) -> uint32_t { + const char *v = getenv(name); + if (!v || !*v) return defval; + char *endp = nullptr; + unsigned long x = strtoul(v, &endp, 10); + if (endp == v) return defval; + if (x > 0xfffffffful) x = 0xfffffffful; + return (uint32_t)x; + }; + + init_pushconst_fastdiv(pc); vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, dst, op); @@ -6959,6 +7514,9 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co } if (dryrun) { + // XXX should've been short-circuited in build_graph() + GGML_ASSERT(op != GGML_OP_OUT_PROD); + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); return; } @@ -7030,6 +7588,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co z_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1); d_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1); + // XXX: should we worry about incontiguous? if (op_supports_incontiguous) { x_sz = ggml_nbytes(src0); y_sz = use_src1 ? ggml_nbytes(src1) : 0; @@ -7055,6 +7614,72 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co // Single call if dimension 2 is contiguous GGML_ASSERT(op_supports_incontiguous || (ggml_is_contiguous(src0) && (src1 == nullptr || ggml_is_contiguous(src1)))); + + // XXX: Cleanup. + // XXX: TODO: be aware of integer division, and where this is used + // XXX: TODO: make this work for use_src2 ops and for !use_src1 ops + uint32_t a_bytes_per_block, a_bytes_per_unit, a_elems_per_block; + a_bytes_per_block = (uint32_t) ggml_type_size(src0->type); + a_elems_per_block = (uint32_t) ggml_blck_size(src0->type); + a_bytes_per_unit = (uint32_t) a_bytes_per_block / a_elems_per_block; + + uint32_t b_bytes_per_block = 0, b_bytes_per_unit = 0, b_elems_per_block = 0; + if (use_src1) { + b_bytes_per_block = ggml_type_size(src1->type); + b_elems_per_block = ggml_blck_size(src1->type); + b_bytes_per_unit = b_bytes_per_block / b_elems_per_block; + } + + uint32_t d_bytes_per_block, d_bytes_per_unit, d_elems_per_block; + d_bytes_per_block = ggml_type_size(dst->type); + d_elems_per_block = ggml_blck_size(dst->type); + d_bytes_per_unit = d_bytes_per_block / d_elems_per_block; + + const uint64_t tiling_threshold = MAX_ADDRESS_SPACE_SIZE_OUT_PROD; + + // XXX: Cleanup. + const bool tiling_debug = getenv_u32("GGML_TILING_DEBUG", 0); + const bool tiling_enabled = getenv_u32("GGML_TILING_ENABLE", 0); + + // XXX: Cleanup. + bool do_tiling = + (op == GGML_OP_OUT_PROD && tiling_enabled && use_src1 && !use_src2) && +#if 0 + ((vk_tensor_offset(src0) + src0->view_offs + x_sz * ne02 * ne03) >= tiling_threshold || + (vk_tensor_offset(src1) + src1->view_offs + y_sz * ne12 * ne13) >= tiling_threshold || + (d_buf_offset + d_sz * ne12 * ne13) >= tiling_threshold); +#else + ((x_sz * ne02 * ne03) + (y_sz * ne12 * ne13) + + (d_sz * ne12 * ne13) >= tiling_threshold); +#endif + + // XXX + bool do_splitting = false; +#if 0 + if (use_src1 && !use_src2 && do_tiling && (x_sz * ne02 * ne03) + (y_sz * ne12 * ne13) + (d_sz * ne12 * ne13) < tiling_threshold) { + do_splitting = true; + do_tiling = false; + } +#endif + if (do_splitting && tiling_debug) { + fprintf(stderr, "[VK] [%s] [SPLITTING] (total sum %lu) >= %lu\n", ggml_op_name(op), (x_sz * ne02 * ne03) + (y_sz * ne12 * ne13) + (d_sz * ne12 * ne13), tiling_threshold); + } + + // XXX: TODO: enable this for other operators + do_tiling = do_tiling && op == GGML_OP_OUT_PROD; + do_splitting = do_splitting && op == GGML_OP_OUT_PROD; + + // XXX: Cleanup. + uint64_t tile_m = 0, tile_n = 0; + uint64_t m_tiles = 0, n_tiles = 0; + uint64_t num_dispatches = 0; + if (do_tiling) { + calculate_tile_dims(ne00, ne10, ne01, &tile_m, &tile_n, &m_tiles, &n_tiles, &num_dispatches, src0->type); + if (tiling_debug) { + fprintf(stderr, "[VK] %s [TILING] [tile_and_dispatch] tile_m=%lu, tile_n=%lu, m_tiles=%lu, n_tiles=%lu, num_dispatches=%lu\n", ggml_op_name(op), tile_m, tile_n, m_tiles, n_tiles, num_dispatches); + } + } + switch (op) { case GGML_OP_NORM: case GGML_OP_RMS_NORM_BACK: @@ -7073,6 +7698,18 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co elements = { nr, 1, 1 }; } } break; + case GGML_OP_CROSS_ENTROPY_LOSS_BACK: + { + // For cross entropy loss back, we need one workgroup per row of logits (src1) + const uint32_t nr = ggml_nrows(src1); + if (nr > 262144) { + elements = { 512, 512, CEIL_DIV(nr, 262144) }; + } else if (nr > 512) { + elements = { 512, CEIL_DIV(nr, 512), 1 }; + } else { + elements = { nr, 1, 1 }; + } + } break; case GGML_OP_RMS_NORM: elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 }; break; @@ -7149,6 +7786,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co case GGML_OP_UPSCALE: case GGML_OP_UNARY: case GGML_OP_GLU: + case GGML_OP_OUT_PROD: case GGML_OP_CONV_2D_DW: { uint32_t ne = ggml_nelements(dst); @@ -7214,7 +7852,257 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co } } - if (op == GGML_OP_SOFT_MAX || op == GGML_OP_GLU) { + // XXX which value to use? + const uint32_t align = 1; + + // XXX: TODO: enable this for other ops + if (do_splitting && op == GGML_OP_OUT_PROD) { + if (tiling_debug) { + std::cerr << "[VK] " << ggml_op_name(op) << " splitting!!" << std::endl; + } + + GGML_ASSERT(ne02 == 1); + GGML_ASSERT(ne03 == 1); + GGML_ASSERT(ne12 == 1); + GGML_ASSERT(ne13 == 1); + GGML_ASSERT(ned2 == 1); + GGML_ASSERT(ned3 == 1); + + GGML_ASSERT(false); + + // XXX +#if 0 + // XXX: check this for ne02/ne03/... > 1 + const uint64_t a_size = CEIL_DIV(ne0, ggml_blck_size(src0->type)) * ggml_type_size(src0->type); + const uint64_t b_size = ggml_type_size(src1->type) * ne1; + const uint64_t d_size = ggml_type_size(dst->type) * ned; + const uint64_t a_off = 0; + const uint64_t b_off = a_size; + const uint64_t d_off = a_size + b_size; + ggml_vk_copy_2d_to_2d(subctx, ctx->prealloc_tile, 0, d_X, x_buf_offset, a_size, 1, 1, 1); + ggml_vk_copy_2d_to_2d(subctx, ctx->prealloc_tile, b_off, d_Y, y_buf_offset, b_size, 1, 1, 1); + + vk_context splitctx = ggml_vk_create_temporary_context(ctx->device->transfer_queue.cmd_pool); + ggml_vk_ctx_begin(ctx->device, splitctx); + + ggml_vk_begin_submission(ctx->device, *splitctx->p); + + ggml_vk_dispatch_pipeline(ctx, splitctx, pipeline, + { vk_subbuffer{ ctx->prealloc_tile, 0, a_size }, + vk_subbuffer{ ctx->prealloc_tile, b_off, b_size }, + vk_subbuffer{ ctx->prealloc_tile, d_off, d_size } }, pc, elements); + + vk::SubmitInfo submit_info{}; + submit_info.commandBufferCount = 1; + submit_info.pCommandBuffers = &splitctx->s->buffer; + + ggml_vk_ctx_end(splitctx); + + splitctx->p->q->queue.submit(submit_info, ctx->split_fence); + + VK_CHECK(ctx->device->device.waitForFences({ ctx->split_fence }, true, UINT64_MAX), "split_fence"); + ctx->device->device.resetFences({ ctx->split_fence }); + + ggml_vk_copy_2d_to_2d(subctx, d_D, d_buf_offset, ctx->prealloc_tile, d_off, d_size, 1, 1, 1); +#endif + // XXX: TODO: enable this for other ops + } else if (do_tiling && op == GGML_OP_OUT_PROD) { + if constexpr (std::is_same_v) { + // XXX + GGML_ASSERT(ne02 == 1); + GGML_ASSERT(ne03 == 1); + GGML_ASSERT(ne12 == 1); + GGML_ASSERT(ne13 == 1); + GGML_ASSERT(ned2 == 1); + GGML_ASSERT(ned3 == 1); + + // XXX make it work for n12 > 1 + for (uint32_t n0 = 0; n0 < ne10; n0 += tile_n) { + const uint32_t nt = (uint32_t) std::min(tile_n, ne10 - n0); + + uint64_t b_off_bytes = y_buf_offset + ((uint64_t)n0 / b_elems_per_block) * (uint64_t)b_bytes_per_block; + const uint64_t d_off_bytes_n = d_buf_offset + CEIL_DIV(((uint64_t)n0 * (uint64_t)ned0), d_elems_per_block) * (uint64_t)d_bytes_per_block; + + // XXX make it work for n02 > 1 + for (uint32_t m0 = 0; m0 < ne00; m0 += tile_m) { + const uint32_t mt = (uint32_t) std::min(tile_m, ne00 - m0); + uint64_t a_off_bytes = x_buf_offset + ((uint64_t)m0 / a_elems_per_block) * (uint64_t)a_bytes_per_block; + uint64_t d_off_bytes = d_off_bytes_n + CEIL_DIV((uint64_t)m0, d_elems_per_block) * (uint64_t)d_bytes_per_block; + + if (tiling_debug) { + std::cerr << "[VK] tile m=[" << m0 << "," << (m0+mt) << ") n=[" << n0 << "," << (n0+nt) << ")" << std::endl; + } + + const uint64_t orig_stride_a_bytes = CEIL_DIV(ne00, a_elems_per_block) * a_bytes_per_block; + const uint64_t orig_stride_b_bytes = CEIL_DIV(ne10, b_elems_per_block) * b_bytes_per_block; + const uint64_t orig_stride_d_bytes = CEIL_DIV(ned0, d_elems_per_block) * d_bytes_per_block; + + const uint64_t tile_stride_a_bytes = ggml_vk_align_size(CEIL_DIV(mt, a_elems_per_block) * a_bytes_per_block, align); + const uint64_t tile_stride_b_bytes = ggml_vk_align_size(CEIL_DIV(nt, b_elems_per_block) * b_bytes_per_block, align); + const uint64_t tile_stride_d_bytes = ggml_vk_align_size(CEIL_DIV(mt, d_elems_per_block) * d_bytes_per_block, align); + + const uint64_t a_mt_bytes = CEIL_DIV(mt, a_elems_per_block) * a_bytes_per_block; + const uint64_t b_nt_bytes = CEIL_DIV(nt, b_elems_per_block) * b_bytes_per_block; + + const uint64_t a_size = tile_stride_a_bytes * ne01; + const uint64_t b_size = tile_stride_b_bytes * ne11; + const uint64_t d_size = tile_stride_d_bytes * nt; + const uint64_t a_off = 0; + const uint64_t b_off = a_size; + const uint64_t d_off = a_size + b_size; + + const uint32_t tile_stride_a_elems = (tile_stride_a_bytes / a_bytes_per_block) * a_elems_per_block; + const uint32_t tile_stride_b_elems = (tile_stride_b_bytes / b_bytes_per_block) * b_elems_per_block; + const uint32_t tile_stride_d_elems = (tile_stride_d_bytes / d_bytes_per_block) * d_elems_per_block; + const uint64_t dst_copy_row_size = CEIL_DIV(mt, d_elems_per_block) * d_bytes_per_block; + + GGML_ASSERT(a_size + b_size + d_size < MAX_ADDRESS_SPACE_SIZE_OUT_PROD); + GGML_ASSERT(d_off + d_size < MAX_ADDRESS_SPACE_SIZE_OUT_PROD); + + GGML_ASSERT(ne01 == ne11); + + // XXX + VkMemoryBarrier memoryBarrier = {}; + memoryBarrier.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER; + memoryBarrier.srcAccessMask = VK_ACCESS_MEMORY_WRITE_BIT; + memoryBarrier.dstAccessMask = VK_ACCESS_MEMORY_READ_BIT; + + // XXX + vkCmdPipelineBarrier(subctx->s->buffer, + VK_PIPELINE_STAGE_ALL_COMMANDS_BIT, + VK_PIPELINE_STAGE_ALL_COMMANDS_BIT, + 0, + 1, &memoryBarrier, + 0, NULL, + 0, NULL + ); + + if (tiling_debug) { + std::cerr << "[VK] " + << " tile_stride_a_elems = " << tile_stride_a_elems << " | tile_stride_a_bytes = " << tile_stride_a_bytes << "\n" + << " tile_stride_b_elems = " << tile_stride_b_elems << " | tile_stride_b_bytes = " << tile_stride_b_bytes << "\n" + << " tile_stride_d_elems = " << tile_stride_d_elems << " | tile_stride_d_bytes = " << tile_stride_d_bytes << "\n" + << " m = " << ne00 << " | n = " << ne10 << " | k = " << ne01 << "\n" + << " mt = " << mt << " | nt = " << nt << " | kt = " << ne11 << "\n" + << " a_off_bytes = " << a_off_bytes << " | b_off_bytes = " << b_off_bytes << " | d_off_bytes = " << d_off_bytes << "\n" + << " a_mt_bytes = " << a_mt_bytes << " | b_nt_bytes = " << b_nt_bytes << "\n" + << " orig_stride_a_bytes = " << orig_stride_a_bytes << " | orig_stride_b_bytes = " << orig_stride_b_bytes << " | orig_stride_d_bytes = " << orig_stride_d_bytes << "\n" + << " dst_copy_row_size = " << dst_copy_row_size << "\n" + << std::endl; + } + + // Copy data to tile buffers + ggml_vk_copy_2d_to_2d(subctx, ctx->prealloc_tile, a_off, d_X, a_off_bytes, a_mt_bytes, ne01, orig_stride_a_bytes, tile_stride_a_bytes); + ggml_vk_copy_2d_to_2d(subctx, ctx->prealloc_tile, b_off, d_Y, b_off_bytes, b_nt_bytes, ne11, orig_stride_b_bytes, tile_stride_b_bytes); + + + // XXX: TODO: account for batch sizes? + uint32_t ne = std::min((uint32_t) ggml_nelements(dst), mt * nt); + if (ne > 262144) { + elements = { 512, 512, CEIL_DIV(ne, 262144) }; + } else if (ne > 512) { + elements = { 512, CEIL_DIV(ne, 512), 1 }; + } else { + elements = { ne, 1, 1 }; + } + + GGML_ASSERT(ne02 == 1); + GGML_ASSERT(ne03 == 1); + GGML_ASSERT(ne12 == 1); + GGML_ASSERT(ne13 == 1); + GGML_ASSERT(ne01 == ne11); + GGML_ASSERT(pc.param2 >= 0.9f && pc.param2 <= 1.1f); + GGML_ASSERT(pc.param3 == 1); + + uint32_t a_row_bytes = CEIL_DIV(mt, a_elems_per_block) * a_bytes_per_block; + // XXX? + uint32_t a_row_elems = a_row_bytes / a_bytes_per_unit; + + // XXX: account for batches + pc.ne = mt * nt; + + // XXX: account for batches + pc.ne00 = mt; pc.ne01 = ne01; +#if 0 + pc.nb00 = a_bytes_per_unit; + pc.nb01 = a_row_bytes; +#else + pc.nb00 = pc.nb00; + //pc.nb01 = a_row_elems; + //pc.nb01 = mt; + pc.nb01 = mt / a_elems_per_block; +#endif + pc.nb02 = pc.nb01 * pc.ne01; + pc.nb03 = pc.nb02 * pc.ne02; + + // XXX: account for batches + pc.ne10 = nt; pc.ne11 = ne11; + pc.nb10 = pc.nb10; + pc.nb11 = pc.nb10 * pc.ne10; + pc.nb12 = pc.nb11 * pc.ne11; + pc.nb13 = pc.nb12 * pc.ne12; + + // XXX: account for non-src2 dst? + // XXX: account for batches + // XXX: do we need to divide nt by b_elems_per_block? + pc.ne20 = mt; pc.ne21 = nt; + pc.nb20 = pc.nb20; + pc.nb21 = pc.nb20 * pc.ne20; + pc.nb22 = pc.nb21 * pc.ne21; + pc.nb23 = pc.nb22 * pc.ne22; + + ggml_vk_sync_buffers(subctx); + + // XXX + vkCmdPipelineBarrier(subctx->s->buffer, + VK_PIPELINE_STAGE_ALL_COMMANDS_BIT, + VK_PIPELINE_STAGE_ALL_COMMANDS_BIT, + 0, + 1, &memoryBarrier, + 0, NULL, + 0, NULL + ); + + vk_subbuffer a = { ctx->prealloc_tile, 0, a_size }; + vk_subbuffer b = { ctx->prealloc_tile, b_off, b_size }; + vk_subbuffer d = { ctx->prealloc_tile, d_off, d_size }; + if (tiling_debug) { + fprintf(stderr, "[VK] [DEBUG] a = (0x%p, %lu, %lu)\n", ctx->prealloc_tile->buffer, 0, a_size); + fprintf(stderr, " b = (0x%p, %lu, %lu)\n", ctx->prealloc_tile->buffer, b_off, b_size); + fprintf(stderr, " d = (0x%p, %lu, %lu)\n", ctx->prealloc_tile->buffer, d_off, d_size); + } + ggml_vk_dispatch_pipeline( + ctx, subctx, pipeline, { a, b, d }, pc, elements + ); + + // XXX + vkCmdPipelineBarrier(subctx->s->buffer, + VK_PIPELINE_STAGE_ALL_COMMANDS_BIT, + VK_PIPELINE_STAGE_ALL_COMMANDS_BIT, + 0, + 1, &memoryBarrier, + 0, NULL, + 0, NULL + ); + + ggml_vk_sync_buffers(subctx); + + // Copy results back to dst buffer + ggml_vk_copy_2d_to_2d(subctx, d_D, d_off_bytes, ctx->prealloc_tile, d_off, dst_copy_row_size, nt, tile_stride_d_bytes, orig_stride_d_bytes); + + // XXX + vkCmdPipelineBarrier(subctx->s->buffer, + VK_PIPELINE_STAGE_ALL_COMMANDS_BIT, + VK_PIPELINE_STAGE_ALL_COMMANDS_BIT, + 0, + 1, &memoryBarrier, + 0, NULL, + 0, NULL + ); + } + } + } + } else if (op == GGML_OP_SOFT_MAX || op == GGML_OP_GLU) { // Empty src1 is possible in soft_max, but the shader needs a buffer vk_subbuffer subbuf_y; if (use_src1) { @@ -7717,6 +8605,10 @@ static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SILU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun); } +static void ggml_vk_geglu_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GEGLU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun); +} + static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { float * op_params = (float *)dst->op_params; @@ -7787,6 +8679,18 @@ static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& sub ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] }, dryrun); } +static void ggml_vk_cross_entropy_loss_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) { + const int64_t nclasses = src1->ne[0]; + const int64_t nrows = ggml_nrows(src1); + + ggml_vk_op_f32(ctx, subctx, src0, src1, src2, dst, GGML_OP_CROSS_ENTROPY_LOSS_BACK, { + (uint32_t)nclasses, + (uint32_t)nrows, + 0.0f, + 0.0f + }, dryrun); +} + static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { float * op_params = (float *)dst->op_params; @@ -7894,6 +8798,24 @@ static void ggml_vk_count_equal(ggml_backend_vk_context * ctx, vk_context& subct ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_COUNT_EQUAL, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun); } +static void ggml_vk_out_prod(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + const int64_t r2 = src1->ne[2] / src0->ne[2]; + const int64_t r3 = src1->ne[3] / src0->ne[3]; + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_OUT_PROD, { + (uint32_t)ggml_nelements(dst), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, (float) r2, (int32_t) r3 + }, dryrun); +} + static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { const int32_t s0 = dst->op_params[0]; const int32_t s1 = dst->op_params[1]; @@ -8962,6 +9884,22 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) { } ctx->prealloc_split_k = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_split_k); } + if (ctx->prealloc_tile_debug == nullptr || (ctx->prealloc_size_tile > 0 && ctx->prealloc_tile_debug->size < ctx->prealloc_size_tile)) { + VK_LOG_MEMORY("ggml_vk_preallocate_buffers(debug_tile_size: " << ctx->prealloc_size_tile * 4 << ")"); + // Resize buffer + if (ctx->prealloc_tile_debug != nullptr) { + ggml_vk_destroy_buffer(ctx->prealloc_tile_debug); + } + ctx->prealloc_tile_debug = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_tile * 4); + } + if (ctx->prealloc_tile == nullptr || (ctx->prealloc_size_tile > 0 && ctx->prealloc_tile->size < ctx->prealloc_size_tile)) { + VK_LOG_MEMORY("ggml_vk_preallocate_buffers(tile_size: " << ctx->prealloc_size_tile << ")"); + // Resize buffer + if (ctx->prealloc_tile != nullptr) { + ggml_vk_destroy_buffer(ctx->prealloc_tile); + } + ctx->prealloc_tile = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_tile); + } } static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_cgraph * cgraph, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready); @@ -8981,6 +9919,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr const ggml_tensor * src1 = node->src[1]; const ggml_tensor * src2 = node->src[2]; const ggml_tensor * src3 = node->src[3]; + const ggml_tensor * dst = node; switch (node->op) { // Return on empty ops to avoid generating a compute_ctx and setting exit_tensor @@ -9038,18 +9977,21 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_CONT: case GGML_OP_DUP: case GGML_OP_SILU_BACK: + case GGML_OP_GEGLU_BACK: case GGML_OP_NORM: case GGML_OP_GROUP_NORM: case GGML_OP_RMS_NORM: case GGML_OP_RMS_NORM_BACK: case GGML_OP_L2_NORM: case GGML_OP_DIAG_MASK_INF: + case GGML_OP_CROSS_ENTROPY_LOSS_BACK: case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX_BACK: case GGML_OP_ROPE: case GGML_OP_ROPE_BACK: case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: + case GGML_OP_OUT_PROD: case GGML_OP_ARGSORT: case GGML_OP_SUM: case GGML_OP_SUM_ROWS: @@ -9105,6 +10047,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_CONT: case GGML_OP_DUP: case GGML_OP_SILU_BACK: + case GGML_OP_GEGLU_BACK: case GGML_OP_NORM: case GGML_OP_GROUP_NORM: case GGML_OP_RMS_NORM: @@ -9117,6 +10060,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_SOFT_MAX_BACK: case GGML_OP_ROPE: case GGML_OP_ROPE_BACK: + case GGML_OP_OUT_PROD: case GGML_OP_ARGSORT: case GGML_OP_SUM: case GGML_OP_SUM_ROWS: @@ -9131,8 +10075,112 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr { // These operations all go through ggml_vk_op_f32, so short-circuit and // do the only thing needed for the dryrun. + vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, node, node->op); + + const bool use_src1 = src1 != nullptr; + const bool use_src2 = src2 != nullptr; + + // XXX: currently not accounting for use_src2 ops + // XXX: currently only really enabled for GGML_OP_OUT_PROD + if (use_src1 && !use_src2 && node->op == GGML_OP_OUT_PROD) { + + const uint64_t ne00 = src0->ne[0]; + const uint64_t ne01 = src0->ne[1]; + const uint64_t ne02 = src0->ne[2]; + const uint64_t ne03 = src0->ne[3]; + const uint64_t ne0 = ne00 * ne01; + + const uint64_t ne10 = use_src1 ? src1->ne[0] : 0; + const uint64_t ne11 = use_src1 ? src1->ne[1] : 0; + const uint64_t ne12 = use_src1 ? src1->ne[2] : 0; + const uint64_t ne13 = use_src1 ? src1->ne[3] : 0; + const uint64_t ne1 = ne10 * ne11; + + const uint64_t ne20 = use_src2 ? src2->ne[0] : 0; + const uint64_t ne21 = use_src2 ? src2->ne[1] : 0; + const uint64_t ne22 = use_src2 ? src2->ne[2] : 0; + const uint64_t ne23 = use_src2 ? src2->ne[3] : 0; + const uint64_t ne2 = ne20 * ne21; + + const uint64_t ned0 = dst->ne[0]; + const uint64_t ned1 = dst->ne[1]; + const uint64_t ned2 = dst->ne[2]; + const uint64_t ned3 = dst->ne[3]; + const uint64_t ned = ned0 * ned1; + + uint64_t x_sz = CEIL_DIV(ne0, ggml_blck_size(src0->type)) * ggml_type_size(src0->type); + uint64_t y_sz = use_src1 ? ggml_type_size(src1->type) * ne1 : 0; + uint64_t z_sz = use_src2 ? ggml_type_size(src2->type) * ne2 : 0; + uint64_t d_sz = ggml_type_size(dst->type) * ned; + + const uint64_t tiling_threshold = MAX_ADDRESS_SPACE_SIZE_OUT_PROD; + + // XXX: Cleanup + auto getenv_u32 = [](const char *name, uint32_t defval) -> uint32_t { + const char *v = getenv(name); + if (!v || !*v) return defval; + char *endp = nullptr; + unsigned long x = strtoul(v, &endp, 10); + if (endp == v) return defval; + if (x > 0xfffffffful) x = 0xfffffffful; + return (uint32_t)x; + }; + + // XXX: Cleanup. + const bool tiling_debug = getenv_u32("GGML_TILING_DEBUG", 0); + const bool tiling_enabled = getenv_u32("GGML_TILING_ENABLE", 0); + + // XXX: TODO: use a better way to detect high buffer range? + bool do_tiling = + tiling_enabled && +#if 0 + ((vk_tensor_offset(src0) + src0->view_offs + x_sz * ne02 * ne03) >= tiling_threshold || + (vk_tensor_offset(src1) + src1->view_offs + y_sz * ne12 * ne13) >= tiling_threshold || + (vk_tensor_offset(dst) + dst->view_offs + d_sz * ne12 * ne13) >= tiling_threshold); +#else + (x_sz * ne02 * ne03 + y_sz * ne12 * ne13 + + d_sz * ne12 * ne13 >= tiling_threshold); +#endif + + bool do_splitting = false; +#if 0 + if (do_tiling && (x_sz * ne02 * ne03) + (y_sz * ne12 * ne13) + (d_sz * ne12 * ne13) < tiling_threshold) { + do_splitting = true; + do_tiling = false; + } +#endif + if (do_splitting && tiling_debug) { + fprintf(stderr, "[VK] [%s] [SPLITTING] (total sum %lu) >= %lu\n", ggml_op_name(node->op), (x_sz * ne02 * ne03) + (y_sz * ne12 * ne13) + (d_sz * ne12 * ne13), tiling_threshold); + } + + uint64_t tile_m = 0, tile_n = 0; + uint64_t m_tiles = 0, n_tiles = 0; + uint64_t num_dispatches = 0; + if (do_tiling) { + calculate_tile_dims(ne00, ne10, ne01, &tile_m, &tile_n, &m_tiles, &n_tiles, &num_dispatches, src0->type); + if (tiling_debug) { + fprintf(stderr, "[VK] %s [TILING] [request_descriptor_sets] tile_m=%lu, tile_n=%lu, m_tiles=%lu, n_tiles=%lu, num_dispatches=%lu\n", ggml_op_name(node->op), tile_m, tile_n, m_tiles, n_tiles, num_dispatches); + } + } + + // XXX: Cleanup + // Allocate buffers for tiling (for buffers that are too large to bind fully) + if (node->op == GGML_OP_OUT_PROD) { + if (do_splitting) { + ctx->prealloc_size_tile = MAX_ADDRESS_SPACE_SIZE; + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + return false; + } else if (do_tiling) { + ctx->prealloc_size_tile = MAX_ADDRESS_SPACE_SIZE; + ggml_pipeline_request_descriptor_sets(ctx, pipeline, num_dispatches); + return false; + } + } + } + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + return false; } default: @@ -9156,6 +10204,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_GET_ROWS: ggml_vk_get_rows(ctx, compute_ctx, src0, src1, node, dryrun); + break; + case GGML_OP_OUT_PROD: + ggml_vk_out_prod(ctx, compute_ctx, src0, src1, node, dryrun); + break; case GGML_OP_ADD: ggml_vk_add(ctx, compute_ctx, src0, src1, node, dryrun); @@ -9222,6 +10274,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_SILU_BACK: ggml_vk_silu_back(ctx, compute_ctx, src0, src1, node, dryrun); + break; + case GGML_OP_GEGLU_BACK: + ggml_vk_geglu_back(ctx, compute_ctx, src0, src1, node, dryrun); + break; case GGML_OP_NORM: ggml_vk_norm(ctx, compute_ctx, src0, node, dryrun); @@ -9280,6 +10336,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_DIAG_MASK_INF: ggml_vk_diag_mask_inf(ctx, compute_ctx, src0, node, dryrun); + break; + + case GGML_OP_CROSS_ENTROPY_LOSS_BACK: + ggml_vk_cross_entropy_loss_back(ctx, compute_ctx, src0, src1, src2, node, dryrun); + break; case GGML_OP_SOFT_MAX: ggml_vk_soft_max(ctx, compute_ctx, src0, src1, node, dryrun); @@ -9438,12 +10499,14 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * case GGML_OP_CONT: case GGML_OP_DUP: case GGML_OP_SILU_BACK: + case GGML_OP_GEGLU_BACK: case GGML_OP_NORM: case GGML_OP_GROUP_NORM: case GGML_OP_RMS_NORM: case GGML_OP_RMS_NORM_BACK: case GGML_OP_L2_NORM: case GGML_OP_DIAG_MASK_INF: + case GGML_OP_CROSS_ENTROPY_LOSS_BACK: case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX_BACK: case GGML_OP_ROPE: @@ -9457,6 +10520,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * case GGML_OP_SUM: case GGML_OP_SUM_ROWS: case GGML_OP_ARGMAX: + case GGML_OP_OUT_PROD: case GGML_OP_COUNT_EQUAL: case GGML_OP_IM2COL: case GGML_OP_TIMESTEP_EMBEDDING: @@ -9604,6 +10668,8 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) { ggml_vk_destroy_buffer(ctx->prealloc_x); ggml_vk_destroy_buffer(ctx->prealloc_y); ggml_vk_destroy_buffer(ctx->prealloc_split_k); + ggml_vk_destroy_buffer(ctx->prealloc_tile); + ggml_vk_destroy_buffer(ctx->prealloc_tile_debug); for (auto& buffer : ctx->buffer_pool) { ggml_vk_destroy_buffer(buffer); @@ -10366,6 +11432,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: + case GGML_TYPE_TQ2_0: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -10480,6 +11547,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: + case GGML_TYPE_TQ2_0: case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ1_M: case GGML_TYPE_IQ2_XXS: @@ -10505,6 +11573,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: + case GGML_TYPE_TQ2_0: case GGML_TYPE_IQ4_NL: return true; default: @@ -10580,14 +11649,32 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_GROUP_NORM: case GGML_OP_L2_NORM: return ggml_is_contiguous(op->src[0]); + case GGML_OP_OUT_PROD: { + const ggml_type t0 = op->src[0]->type; + const ggml_type t1 = op->src[1]->type; + const ggml_type td = op->type; + if (td != GGML_TYPE_F32 || t1 != GGML_TYPE_F32) { + return false; + } + switch (t0) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q8_0: + return true; + default: + return false; + } + } case GGML_OP_ADD: case GGML_OP_SUB: case GGML_OP_MUL: case GGML_OP_DIV: - return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && + return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && (op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16) && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16); case GGML_OP_SILU_BACK: + case GGML_OP_GEGLU_BACK: case GGML_OP_RMS_NORM_BACK: case GGML_OP_SQR: case GGML_OP_SIN: @@ -10619,6 +11706,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm return true; case GGML_OP_CONV_TRANSPOSE_1D: return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32; + case GGML_OP_CROSS_ENTROPY_LOSS_BACK: + return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->src[2]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; default: return false; } @@ -10892,7 +11981,8 @@ size_t comp_nb[GGML_MAX_DIMS]; size_t check_counter = 0; static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) { ggml_tensor * tensor = cgraph->nodes[tensor_idx]; - if (tensor->op == GGML_OP_TRANSPOSE) { + if (tensor->op == GGML_OP_TRANSPOSE || tensor->op == GGML_OP_SET_ROWS) { + fprintf(stderr, "[WARNING] ggml_vk_check_results_0 unimplemented op %s\n", ggml_op_name(tensor->op)); return; } @@ -11004,6 +12094,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * } else { tensor_clone = ggml_mul(ggml_ctx, src_clone[0], src_clone[1]); } + } else if (tensor->op == GGML_OP_CROSS_ENTROPY_LOSS_BACK) { + tensor_clone = ggml_cross_entropy_loss_back(ggml_ctx, src_clone[0], src_clone[1], src_clone[2]); } else if (tensor->op == GGML_OP_DIV) { tensor_clone = ggml_div(ggml_ctx, src_clone[0], src_clone[1]); } else if (tensor->op == GGML_OP_CONCAT) { @@ -11012,7 +12104,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * tensor_clone = ggml_upscale_ext(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], (ggml_scale_mode) tensor->op_params[0]); } else if (tensor->op == GGML_OP_SCALE) { const float * params = (const float *)tensor->op_params; - tensor_clone = ggml_scale(ggml_ctx, src_clone[0], params[0]); + tensor_clone = ggml_scale_bias(ggml_ctx, src_clone[0], params[0], params[1]); } else if (tensor->op == GGML_OP_SQR) { tensor_clone = ggml_sqr(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_SIN) { @@ -11030,6 +12122,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * tensor_clone = ggml_repeat_back(ggml_ctx, src_clone[0], tensor); } else if (tensor->op == GGML_OP_ADD) { tensor_clone = ggml_add(ggml_ctx, src_clone[0], src_clone[1]); + } else if (tensor->op == GGML_OP_OUT_PROD) { + tensor_clone = ggml_out_prod(ggml_ctx, src_clone[0], src_clone[1]); } else if (tensor->op == GGML_OP_ACC) { tensor_clone = ggml_acc(ggml_ctx, src_clone[0], src_clone[1], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]); } else if (tensor->op == GGML_OP_NORM) { @@ -11123,8 +12217,6 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * } else { tensor_clone = ggml_cpy(ggml_ctx, src_clone[0], src_clone[1]); } - } else if (tensor->op == GGML_OP_SET_ROWS) { - tensor_clone = ggml_set_rows(ggml_ctx, src_clone[0], src_clone[1]); } else if (tensor->op == GGML_OP_CONT) { tensor_clone = ggml_cont_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); } else if (tensor->op == GGML_OP_RESHAPE) { @@ -11224,7 +12316,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) { ggml_tensor * tensor = cgraph->nodes[tensor_idx]; - if (tensor->op == GGML_OP_TRANSPOSE) { + if (tensor->op == GGML_OP_TRANSPOSE || tensor->op == GGML_OP_SET_ROWS) { + fprintf(stderr, "[WARNING] ggml_vk_check_results_1 unimplemented op %s\n", ggml_op_name(tensor->op)); return; } bool fused_rms_norm_mul = false; @@ -11284,6 +12377,9 @@ static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * } else if (tensor->type == GGML_TYPE_F16) { correct = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0])); result = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0])); + } else if (tensor->type == GGML_TYPE_BF16) { + correct = ggml_bf16_to_fp32(*(ggml_bf16_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0])); + result = ggml_bf16_to_fp32(*(ggml_bf16_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0])); } else if (tensor->type == GGML_TYPE_I32) { correct = *(int32_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]); result = *(int32_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]); @@ -11323,6 +12419,27 @@ static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * ggml_vk_print_graph_origin(tensor, done); GGML_ABORT("fatal error"); } + + // XXX just for debugging, on release builds this might not be an actual error + if (std::isnan(result)) { + std::cerr << std::endl << "[VK] [ERROR] NAN Result:" << std::endl; + ggml_vk_print_tensor_area(tensor, tensor_data, i0, i1, i2, i3); + std::cerr << std::endl; + std::vector done; + ggml_vk_print_graph_origin(tensor, done); + GGML_ABORT("fatal error"); + } + + // XXX just for debugging, on release builds this might not be an actual error + if (std::isinf(result)) { + std::cerr << std::endl << "[VK] [ERROR] INF Result:" << std::endl; + ggml_vk_print_tensor_area(tensor, tensor_data, i0, i1, i2, i3); + std::cerr << std::endl; + std::vector done; + ggml_vk_print_graph_origin(tensor, done); + GGML_ABORT("fatal error"); + } + const double denom = std::fabs(correct) > 1.0f ? (std::fabs(correct) > 1e-8 ? std::fabs(correct) : 1e-8) : 1.0f; if (first_error[0] == -1 && std::fabs(correct - result) / denom > 0.5) { first_error[0] = i0; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/cross_entropy_loss_back.comp b/ggml/src/ggml-vulkan/vulkan-shaders/cross_entropy_loss_back.comp new file mode 100644 index 0000000000000..920279aee314f --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/cross_entropy_loss_back.comp @@ -0,0 +1,92 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable + +#include "generic_head.comp" +#include "types.comp" + +#define FLOAT_TYPE float + +layout(constant_id = 0) const uint BLOCK_SIZE = 32; +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; // Grad(scalar) +layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; // logits => raw model outputs(unnormalized scored) +layout (binding = 2) readonly buffer C {C_TYPE data_c[];}; // true labels(one hot encoded) +layout (binding = 3) writeonly buffer D {D_TYPE data_d[];}; // output gradients + +shared FLOAT_TYPE vals[BLOCK_SIZE]; + +void main() { + const uint nclasses = p.KX; + const uint nrows = p.KY; + + const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; + if (row >= nrows) { + return; + } + + const uint tid = gl_LocalInvocationID.x; + const uint warp_size = gl_WorkGroupSize.x; + + const uint logits_offset = row * nclasses; + const uint labels_offset = row * nclasses; + const uint dst_offset = row * nclasses; + + // Gradient scaling (grad / batch_size) + const FLOAT_TYPE d_by_nrows = FLOAT_TYPE(data_a[0]) / FLOAT_TYPE(nrows); + + // Get max value per thread + FLOAT_TYPE thread_max = FLOAT_TYPE(uintBitsToFloat(0xFF800000)); // -INFINITY + for (uint i = tid; i < nclasses; i += warp_size) { + FLOAT_TYPE val = FLOAT_TYPE(data_b[logits_offset + i]); + thread_max = max(thread_max, val); + } + + vals[tid] = thread_max; + barrier(); + + // Get global maximum for the row(batch) + [[unroll]] + for (uint s = warp_size / 2; s > 0; s >>= 1) { + if (tid < s) { + vals[tid] = max(vals[tid], vals[tid + s]); + } + barrier(); + } + + const FLOAT_TYPE row_max = vals[0]; + barrier(); + + // Compute sum of exp(logits - max) for softmax normalization + FLOAT_TYPE thread_sum = FLOAT_TYPE(0.0); + for (uint i = tid; i < nclasses; i += warp_size) { + FLOAT_TYPE val = FLOAT_TYPE(data_b[logits_offset + i]); + thread_sum += exp(val - row_max); + } + + vals[tid] = thread_sum; + barrier(); + + [[unroll]] + for (uint s = warp_size / 2; s > 0; s >>= 1) { + if (tid < s) { + vals[tid] += vals[tid + s]; + } + barrier(); + } + + const FLOAT_TYPE row_sum = vals[0]; + const FLOAT_TYPE sm_scale = FLOAT_TYPE(1.0) / row_sum; + barrier(); + + // Compute final gradients: (softmax - labels) * d_by_nrows + for (uint i = tid; i < nclasses; i += warp_size) { + FLOAT_TYPE logit = FLOAT_TYPE(data_b[logits_offset + i]); + FLOAT_TYPE softmax_val = exp(logit - row_max) * sm_scale; + + FLOAT_TYPE label = FLOAT_TYPE(data_c[labels_offset + i]); + + data_d[dst_offset + i] = D_TYPE((softmax_val - label) * d_by_nrows); + } +} \ No newline at end of file diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp index 0d9739d40609a..3bd7144c1dda7 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp @@ -35,8 +35,9 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) { return (vec2(vui & 0xF, vui >> 4) - 8.0f); } vec4 dequantize4(uint ib, uint iqs, uint a_offset) { - const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]); - return (vec4(vui & 0xF, (vui >> 4) & 0xF, (vui >> 8) & 0xF, vui >> 12) - 8.0f); + const vec2 v01 = dequantize(ib, iqs, a_offset); + const vec2 v23 = dequantize(ib, iqs + 1, a_offset); + return vec4(v01.x, v01.y, v23.x, v23.y); } #endif @@ -46,8 +47,9 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) { return vec2(vui & 0xF, vui >> 4); } vec4 dequantize4(uint ib, uint iqs, uint a_offset) { - const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]); - return vec4(vui & 0xF, (vui >> 4) & 0xF, (vui >> 8) & 0xF, vui >> 12); + const vec2 v01 = dequantize(ib, iqs, a_offset); + const vec2 v23 = dequantize(ib, iqs + 1, a_offset); + return vec4(v01.x, v01.y, v23.x, v23.y); } #endif @@ -88,9 +90,9 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) { return vec2(int(data_a[a_offset + ib].qs[iqs]), int(data_a[a_offset + ib].qs[iqs + 1])); } vec4 dequantize4(uint ib, uint iqs, uint a_offset) { - const i8vec2 v0 = unpack8(int32_t(data_a_packed16[a_offset + ib].qs[iqs/2])).xy; // vec4 used due to #12147 - const i8vec2 v1 = unpack8(int32_t(data_a_packed16[a_offset + ib].qs[iqs/2 + 1])).xy; - return vec4(v0.x, v0.y, v1.x, v1.y); + const vec2 v01 = dequantize(ib, iqs, a_offset); + const vec2 v23 = dequantize(ib, iqs + 2, a_offset); + return vec4(v01.x, v01.y, v23.x, v23.y); } #endif @@ -434,6 +436,30 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) { } #endif +#if defined(DATA_A_TQ2_0) +// TQ2_0 ternary dequantization: {0,1,2} -> {-1,0,+1} via (q-1) mapping +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint vui = uint(data_a[a_offset + ib].qs[iqs]); + const uint c0 = (vui >> 0) & 3; + const uint c1 = (vui >> 2) & 3; + const float q0 = float(c0) - 1.0f; + const float q1 = float(c1) - 1.0f; + return vec2(q0, q1); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint vui = uint(data_a[a_offset + ib].qs[iqs]); + const uint c0 = (vui >> 0) & 3; + const uint c1 = (vui >> 2) & 3; + const uint c2 = (vui >> 4) & 3; + const uint c3 = (vui >> 6) & 3; + const float q0 = float(c0) - 1.0f; + const float q1 = float(c1) - 1.0f; + const float q2 = float(c2) - 1.0f; + const float q3 = float(c3) - 1.0f; + return vec4(q0, q1, q2, q3); +} +#endif + #if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16) vec2 get_dm(uint ib, uint a_offset) { return vec2(0, 0); @@ -449,7 +475,7 @@ vec2 get_dm(uint ib, uint a_offset) { } #endif -#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL) +#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_TQ2_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL) vec2 get_dm(uint ib, uint a_offset) { return vec2(float(data_a[a_offset + ib].d), 0); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp index 9cb7da2daab5d..47b03babb07d7 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp @@ -654,6 +654,25 @@ float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoor } #endif +#if defined(DATA_A_TQ2_0) +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufTQ2_0 { + block_tq2_0 block; +}; + +float16_t dequantFuncTQ2_0(const in decodeBufTQ2_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + const uint iqs = idx / 4; + const uint iqs_offset = idx % 4; + const uint vui = uint(bl.block.qs[iqs]); + const uint c = (vui >> (2 * iqs_offset)) & 3; + const float q = float(c) - 1.0f; + float16_t ret = d * float16_t(q); + return ret; +} +#endif + #if defined(DATA_A_Q4_0) #define dequantFuncA dequantFuncQ4_0 #elif defined(DATA_A_Q4_1) @@ -696,4 +715,6 @@ float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoor #define dequantFuncA dequantFuncIQ4_XS #elif defined(DATA_A_IQ4_NL) #define dequantFuncA dequantFuncIQ4_NL +#elif defined(DATA_A_TQ2_0) +#define dequantFuncA dequantFuncTQ2_0 #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_tq2_0.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_tq2_0.comp new file mode 100644 index 0000000000000..f2fafcb3d49e1 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_tq2_0.comp @@ -0,0 +1,36 @@ +#version 450 + +#extension GL_EXT_shader_16bit_storage : require + +#include "types.comp" + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +layout (push_constant) uniform parameter { + uint ne; +} p; + +layout (local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint i = gl_GlobalInvocationID.x * 4; + + if (i >= p.ne) { + return; + } + + const uint ib = i / QUANT_K; // block index + const uint iqs = (i % QUANT_K) / 4; // quant index within block (byte index) + const uint bit_pos_base = (i % 4) * 2; // bit position within byte + + const float d = float(data_a[ib].d); + + for (uint j = 0; j < 4 && (i + j) < p.ne; ++j) { + const uint local_iqs = ((i + j) % QUANT_K) / 4; // byte index for this element + const uint bit_pos = ((i + j) % 4) * 2; // bit position for this element + const uint vui = uint(data_a[ib].qs[local_iqs]); + const uint q = (vui >> bit_pos) & 3; + data_b[i + j] = D_TYPE(d * (float(q) - 1.0f)); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp index 7defe72b403b5..fc6767214f3be 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp @@ -62,26 +62,79 @@ layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[ #endif #if defined(DATA_A_Q4_0) +#define BINDING_IDX_K 0 +#define BINDING_IDX_V 1 +layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE data[];} kv_packed[2]; #define BLOCK_BYTE_SIZE 18 vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { - uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); - uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); + uint v00 = + uint(kv_packed[binding_idx].data[a_offset + ib].qs[(iqs & 0xF) + 0]); + uint v01 = + uint(kv_packed[binding_idx].data[a_offset + ib].qs[(iqs & 0xF) + 1]); + uint v10 = + uint(kv_packed[binding_idx].data[a_offset + ib].qs[(iqs & 0xF) + 2]); + uint v11 = + uint(kv_packed[binding_idx].data[a_offset + ib].qs[(iqs & 0xF) + 3]); + uint shift = (iqs & 0x10) >> 2; - vui_lo >>= shift; - vui_hi >>= shift; + v00 >>= shift; + v01 >>= shift; + v10 >>= shift; + v11 >>= shift; - return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f); + v00 = v00 & 0xF; + v01 = v01 & 0xF; + v10 = v10 & 0xF; + v11 = v11 & 0xF; + + return float(kv_packed[binding_idx].data[a_offset + ib].d) * (vec4(v00, v01, v10, v11) - 8.0f); +} +#endif + +#if defined(DATA_A_Q4_1) +#define BINDING_IDX_K 0 +#define BINDING_IDX_V 1 +layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE data[];} kv_packed[2]; +#define BLOCK_BYTE_SIZE 20 + +vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { + uint v00 = + uint(kv_packed[binding_idx].data[a_offset + ib].qs[(iqs & 0xF) + 0]); + uint v01 = + uint(kv_packed[binding_idx].data[a_offset + ib].qs[(iqs & 0xF) + 1]); + uint v10 = + uint(kv_packed[binding_idx].data[a_offset + ib].qs[(iqs & 0xF) + 2]); + uint v11 = + uint(kv_packed[binding_idx].data[a_offset + ib].qs[(iqs & 0xF) + 3]); + + uint shift = (iqs & 0x10) >> 2; + v00 >>= shift; + v01 >>= shift; + v10 >>= shift; + v11 >>= shift; + + v00 = v00 & 0xF; + v01 = v01 & 0xF; + v10 = v10 & 0xF; + v11 = v11 & 0xF; + + return float(kv_packed[binding_idx].data[a_offset + ib].d) * (vec4(v00, v01, v10, v11) - 8.0f); } #endif #if defined(DATA_A_Q8_0) +#define BINDING_IDX_K 0 +#define BINDING_IDX_V 1 +layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE data[];} kv_packed[2]; #define BLOCK_BYTE_SIZE 34 vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { - const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147 - const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy; + const vec2 v0 = vec2(int(kv_packed[binding_idx].data[a_offset + ib].qs[iqs]), + int(kv_packed[binding_idx].data[a_offset + ib].qs[iqs + 1])); + const vec2 v1 = vec2(int(kv_packed[binding_idx].data[a_offset + ib].qs[iqs + 2]), + int(kv_packed[binding_idx].data[a_offset + ib].qs[iqs + 3])); - return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y); + return float(kv_packed[binding_idx].data[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y); } #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/geglu_back.comp b/ggml/src/ggml-vulkan/vulkan-shaders/geglu_back.comp new file mode 100644 index 0000000000000..3c301596c2279 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/geglu_back.comp @@ -0,0 +1,53 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer GRAD {A_TYPE data_grad[];}; +layout (binding = 1) readonly buffer X {B_TYPE data_x[];}; +layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; + +float gelu(float x) { + const float c = 0.797884560802865; // sqrt(2/pi) + const float a = 0.044715; + const float inner = c * (x + a * x * x * x); + return 0.5 * x * (1.0 + tanh(inner)); +} + +float gelu_derivative(float x) { + const float c = 0.797884560802865; // sqrt(2/pi) + const float a = 0.044715; + const float x_squared = x * x; + const float x_cubed = x_squared * x; + const float inner = c * (x + a * x_cubed); + const float tanh_val = tanh(inner); + const float sech2_val = 1.0 - tanh_val * tanh_val; + const float dtanh_dx = c * (1.0 + 3.0 * a * x_squared) * sech2_val; + return 0.5 * (1.0 + tanh_val + x * dtanh_dx); +} + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + const uint half_size = p.KX / 2; + + if (i < half_size) { + const float grad_val = float(data_grad[i]); + const float g_val = float(data_x[i + half_size]); + data_d[i] = D_TYPE(grad_val * gelu(g_val)); + } else { + const uint idx = i - half_size; + const float grad_val = float(data_grad[idx]); + const float x_val = float(data_x[idx]); + const float g_val = float(data_x[i]); + data_d[i] = D_TYPE(grad_val * x_val * gelu_derivative(g_val)); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp index bb429dd594588..621191d4e754d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp @@ -24,13 +24,39 @@ void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const #if K_PER_ITER == 8 #if QUANT_R == 2 - const vec4 bv02 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4]); - const vec4 bv13 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs + y_offset) / 4]); - const vec4 bv0 = vec4(bv02.x, bv13.x, bv02.y, bv13.y); - const vec4 bv1 = vec4(bv02.z, bv13.z, bv02.w, bv13.w); + // Replicate the original data_b_v4 indexing with /4 rounding + uint idx1 = (j*p.batch_stride_b + b_offset + iybs + iqs); + uint idx2 = (j*p.batch_stride_b + b_offset + iybs + iqs + y_offset); + uint base1 = (idx1 / 4) * 4; // Round down to nearest multiple of 4 + uint base2 = (idx2 / 4) * 4; // Round down to nearest multiple of 4 + + const FLOAT_TYPE bv02_x = FLOAT_TYPE(data_b[base1 + 0]); + const FLOAT_TYPE bv02_y = FLOAT_TYPE(data_b[base1 + 1]); + const FLOAT_TYPE bv02_z = FLOAT_TYPE(data_b[base1 + 2]); + const FLOAT_TYPE bv02_w = FLOAT_TYPE(data_b[base1 + 3]); + const FLOAT_TYPE bv13_x = FLOAT_TYPE(data_b[base2 + 0]); + const FLOAT_TYPE bv13_y = FLOAT_TYPE(data_b[base2 + 1]); + const FLOAT_TYPE bv13_z = FLOAT_TYPE(data_b[base2 + 2]); + const FLOAT_TYPE bv13_w = FLOAT_TYPE(data_b[base2 + 3]); + // XXX this is not guaranteed to be used for Q4, so make sure it works for everything else +#if 1 + const vec4 bv0 = vec4(bv02_x, bv13_x, bv02_y, bv13_y); + const vec4 bv1 = vec4(bv02_z, bv13_z, bv02_w, bv13_w); #else - const vec4 bv0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4]); - const vec4 bv1 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4 + 1]); + const vec4 bv0 = vec4(1.0, 1.0, 1.0, 1.0); + const vec4 bv1 = vec4(1.0, 1.0, 1.0, 1.0); +#endif +#else + const FLOAT_TYPE bv00 = FLOAT_TYPE(data_b[(j*p.batch_stride_b + b_offset + iybs + iqs) ]); + const FLOAT_TYPE bv01 = FLOAT_TYPE(data_b[(j*p.batch_stride_b + b_offset + iybs + iqs) + 1]); + const FLOAT_TYPE bv02 = FLOAT_TYPE(data_b[(j*p.batch_stride_b + b_offset + iybs + iqs) + 2]); + const FLOAT_TYPE bv03 = FLOAT_TYPE(data_b[(j*p.batch_stride_b + b_offset + iybs + iqs) + 3]); + const FLOAT_TYPE bv10 = FLOAT_TYPE(data_b[(j*p.batch_stride_b + b_offset + iybs + iqs) + 4]); + const FLOAT_TYPE bv11 = FLOAT_TYPE(data_b[(j*p.batch_stride_b + b_offset + iybs + iqs) + 5]); + const FLOAT_TYPE bv12 = FLOAT_TYPE(data_b[(j*p.batch_stride_b + b_offset + iybs + iqs) + 6]); + const FLOAT_TYPE bv13 = FLOAT_TYPE(data_b[(j*p.batch_stride_b + b_offset + iybs + iqs) + 7]); + const vec4 bv0 = vec4(bv00, bv01, bv02, bv03); + const vec4 bv1 = vec4(bv10, bv11, bv12, bv13); #endif #else // Check if the second of the pair of elements is OOB, and don't fetch B or diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp index 903753c7e2ec5..da83e0dd3c893 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp @@ -1,4 +1,5 @@ #extension GL_EXT_control_flow_attributes : enable +//#extension GL_EXT_integer_dot_product : require #extension GL_EXT_shader_16bit_storage : require #extension GL_EXT_shader_8bit_storage : require @@ -10,8 +11,10 @@ layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; +#if !defined(DATA_A_Q8_0) && !defined(DATA_A_Q4_0) && !defined(DATA_A_Q4_1) && !defined(DATA_A_Q6_K) layout (binding = 1) readonly buffer BV2 {B_TYPE_VEC2 data_b_v2[];}; layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];}; +#endif layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; #ifdef MUL_MAT_ID @@ -92,15 +95,15 @@ shared FLOAT_TYPE tmpsh[NUM_COLS][NUM_ROWS][BLOCK_SIZE]; void reduce_result(const in FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offset, const in uint32_t first_row, const in uint32_t num_rows, const in uint32_t tid) { // sum up partial sums and write back result - [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { - [[unroll]] for (uint n = 0; n < num_rows; ++n) { + for (uint j = 0; j < NUM_COLS; ++j) { + for (uint n = 0; n < num_rows; ++n) { tmpsh[j][n][tid] = temp[j][n]; } } barrier(); - [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) { + for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) { if (tid < s) { - [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + for (uint j = 0; j < NUM_COLS; ++j) { [[unroll]] for (uint n = 0; n < num_rows; ++n) { tmpsh[j][n][tid] += tmpsh[j][n][tid + s]; } @@ -109,8 +112,8 @@ void reduce_result(const in FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32 barrier(); } if (tid == 0) { - [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { - [[unroll]] for (uint n = 0; n < num_rows; ++n) { + for (uint j = 0; j < NUM_COLS; ++j) { + for (uint n = 0; n < num_rows; ++n) { data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(tmpsh[j][n][0]); } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp index d53d9ee0a2723..590504b799cd9 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp @@ -14,7 +14,7 @@ uint csel = 0; void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint ix, const uint ql_offset, const uint qh_offset, const uint s_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) { const uint y_idx = i * QUANT_K + y_offset; - [[unroll]] for (uint n = 0; n < num_rows; ++n) { + for (uint n = 0; n < num_rows; ++n) { const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; csel ^= 1; @@ -27,15 +27,39 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, continue; } - const uint32_t ql0_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 1]) << 16); - const uint32_t ql32_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 16]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 17]) << 16); +#if 0 + const uint32_t ql0_u32 = + uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2]) | + (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 1]) << 16); + const uint32_t ql32_u32 = + uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 16]) | + (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 17]) << 16); + const uint32_t qh_u32 = + uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2]) | + (uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2 + 1]) << 16); +#else + const uint32_t ql0_u32 = + uint32_t(data_a[ib0 + i].ql[ql_offset]) | + (uint32_t(data_a[ib0 + i].ql[ql_offset + 1]) << 8) | + (uint32_t(data_a[ib0 + i].ql[ql_offset + 2]) << 16) | + (uint32_t(data_a[ib0 + i].ql[ql_offset + 3]) << 24); + const uint32_t ql32_u32 = + uint32_t(data_a[ib0 + i].ql[ql_offset + 32]) | + (uint32_t(data_a[ib0 + i].ql[ql_offset + 33]) << 8) | + (uint32_t(data_a[ib0 + i].ql[ql_offset + 34]) << 16) | + (uint32_t(data_a[ib0 + i].ql[ql_offset + 35]) << 24); + const uint32_t qh_u32 = + uint32_t(data_a[ib0 + i].qh[qh_offset + 0]) | + (uint32_t(data_a[ib0 + i].qh[qh_offset + 1]) << 8) | + (uint32_t(data_a[ib0 + i].qh[qh_offset + 2]) << 16) | + (uint32_t(data_a[ib0 + i].qh[qh_offset + 3]) << 24); +#endif const uint32_t ql0_u32_lo4 = ql0_u32 & 0x0F0F0F0F; const uint32_t ql0_u32_hi4 = (ql0_u32 >> 4) & 0x0F0F0F0F; const uint32_t ql32_u32_lo4 = ql32_u32 & 0x0F0F0F0F; const uint32_t ql32_u32_hi4 = (ql32_u32 >> 4) & 0x0F0F0F0F; - const uint32_t qh_u32 = uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2 + 1]) << 16); const uint32_t qh0_u32 = (qh_u32 & 0x03030303) << 4; const uint32_t qh2_u32 = (qh_u32 & 0x0C0C0C0C) << 2; const uint32_t qh4_u32 = (qh_u32 & 0x30303030); @@ -46,10 +70,17 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint32_t q2_u32 = ql0_u32_hi4 | qh4_u32; const uint32_t q3_u32 = ql32_u32_hi4 | qh6_u32; +#if 0 const vec4 q0 = vec4(unpack8(q0_u32)) - 32; const vec4 q1 = vec4(unpack8(q1_u32)) - 32; const vec4 q2 = vec4(unpack8(q2_u32)) - 32; const vec4 q3 = vec4(unpack8(q3_u32)) - 32; +#else + const vec4 q0 = vec4(float(q0_u32 & 0xFF), float((q0_u32 >> 8) & 0xFF), float((q0_u32 >> 16) & 0xFF), float(q0_u32 >> 24)) - 32; + const vec4 q1 = vec4(float(q1_u32 & 0xFF), float((q1_u32 >> 8) & 0xFF), float((q1_u32 >> 16) & 0xFF), float(q1_u32 >> 24)) - 32; + const vec4 q2 = vec4(float(q2_u32 & 0xFF), float((q2_u32 >> 8) & 0xFF), float((q2_u32 >> 16) & 0xFF), float(q2_u32 >> 24)) - 32; + const vec4 q3 = vec4(float(q3_u32 & 0xFF), float((q3_u32 >> 8) & 0xFF), float((q3_u32 >> 16) & 0xFF), float(q3_u32 >> 24)) - 32; +#endif if (all_threads) { sccache[csel][ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]); @@ -58,14 +89,38 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); - [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + for (uint j = 0; j < NUM_COLS; ++j) { + +#if 0 vec4 by0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 ]); vec4 by32 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 8]); vec4 by64 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 16]); vec4 by96 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 24]); +#else + vec4 by0 = + vec4(data_b[(j*p.batch_stride_b + b_offset + y_idx) + 0], + data_b[(j*p.batch_stride_b + b_offset + y_idx) + 1], + data_b[(j*p.batch_stride_b + b_offset + y_idx) + 2], + data_b[(j*p.batch_stride_b + b_offset + y_idx) + 3]); + vec4 by32 = + vec4(data_b[(j*p.batch_stride_b + b_offset + y_idx) + 4 * 8], + data_b[(j*p.batch_stride_b + b_offset + y_idx) + 4 * 8 + 1], + data_b[(j*p.batch_stride_b + b_offset + y_idx) + 4 * 8 + 2], + data_b[(j*p.batch_stride_b + b_offset + y_idx) + 4 * 8 + 3]); + vec4 by64 = + vec4(data_b[(j*p.batch_stride_b + b_offset + y_idx) + 4 * 16], + data_b[(j*p.batch_stride_b + b_offset + y_idx) + 4 * 16 + 1], + data_b[(j*p.batch_stride_b + b_offset + y_idx) + 4 * 16 + 2], + data_b[(j*p.batch_stride_b + b_offset + y_idx) + 4 * 16 + 3]); + vec4 by96 = + vec4(data_b[(j*p.batch_stride_b + b_offset + y_idx) + 4 * 24], + data_b[(j*p.batch_stride_b + b_offset + y_idx) + 4 * 24 + 1], + data_b[(j*p.batch_stride_b + b_offset + y_idx) + 4 * 24 + 2], + data_b[(j*p.batch_stride_b + b_offset + y_idx) + 4 * 24 + 3]); +#endif FLOAT_TYPE sum[4] = {0, 0, 0, 0}; - [[unroll]] for (uint l = 0; l < 4; ++l) { + for (uint l = 0; l < 4; ++l) { sum[0] = fma(FLOAT_TYPE(by0[l]), q0[l], sum[0]); sum[1] = fma(FLOAT_TYPE(by32[l]), q1[l], sum[1]); sum[2] = fma(FLOAT_TYPE(by64[l]), q2[l], sum[2]); @@ -99,8 +154,8 @@ void compute_outputs(const uint first_row, const uint num_rows) { const uint s_offset = 8*v_im + is; const uint y_offset = 128*v_im + l0; - [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { - [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + for (uint j = 0; j < NUM_COLS; ++j) { + for (uint i = 0; i < NUM_ROWS; ++i) { temp[j][i] = FLOAT_TYPE(0); } } @@ -108,7 +163,7 @@ void compute_outputs(const uint first_row, const uint num_rows) { const uint nbr_par_th = num_blocks_per_row%it_size; const uint nbr_all_th = num_blocks_per_row - nbr_par_th; uint i0 = 0; - [[unroll]] for (; i0 < nbr_all_th; i0 += it_size) + for (; i0 < nbr_all_th; i0 += it_size) calc_superblock(a_offset, b_offset, itid, ix, ql_offset, qh_offset, s_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, true); calc_superblock(a_offset, b_offset, itid, ix, ql_offset, qh_offset, s_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, false); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_tq2_0.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_tq2_0.comp new file mode 100644 index 0000000000000..e49f8f3139b91 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_tq2_0.comp @@ -0,0 +1,66 @@ +#version 450 +#extension GL_EXT_shader_explicit_arithmetic_types : require + +#include "mul_mat_vec_base.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + const uint tid = gl_LocalInvocationID.x; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + [[unroll]] for (uint i = tid; i < num_blocks_per_row; i += gl_WorkGroupSize.x) { + + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib0 = a_offset / QUANT_K + (first_row + n) * num_blocks_per_row; + const float d = float(data_a[ib0 + i].d); + + [[unroll]] for (uint j = 0; j < 64; j += 32) { + [[unroll]] for (uint l = 0; l < 4; ++l) { + [[unroll]] for (uint k = 0; k < 32; ++k) { + // Extract quantized value: ((x[i].qs[j + k] >> (l*2)) & 3) - 1 + const uint q_byte = uint(data_a[ib0 + i].qs[j + k]); + const uint shift = l * 2; + const uint q = (q_byte >> shift) & 3; + const FLOAT_TYPE dequant_val = FLOAT_TYPE(d * (float(q) - 1.0f)); // CPU kernel: (q-1)*d + + // y-data access pattern: y[i].qs[j*4 + l*32 + k] + const uint b_idx = i * QUANT_K + j * 4 + l * 32 + k; + if (b_idx < p.ncols) { + [[unroll]] for (uint jcol = 0; jcol < NUM_COLS; ++jcol) { + temp[jcol][n] += dequant_val * FLOAT_TYPE(data_b[jcol * p.batch_stride_b + b_offset + b_idx]); + } + } + } + } + } + } + } + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index f481549911b92..a610d7dcef946 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -117,6 +117,41 @@ shared uint _ne1_sh; shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS]; #endif +#if defined(DATA_A_Q8_0) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + return vec2(int(data_a[a_offset + ib].qs[iqs]), int(data_a[a_offset + ib].qs[iqs + 1])); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const vec2 v01 = dequantize(ib, iqs, a_offset); + const vec2 v23 = dequantize(ib, iqs + 2, a_offset); + return vec4(v01.x, v01.y, v23.x, v23.y); +} +#endif + +#if defined(DATA_A_Q4_0) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint vui = uint(data_a[a_offset + ib].qs[iqs]); + return (vec2(vui & 0xF, vui >> 4) - 8.0f); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const vec2 v01 = dequantize(ib, iqs, a_offset); + const vec2 v23 = dequantize(ib, iqs + 1, a_offset); + return vec4(v01.x, v01.y, v23.x, v23.y); +} +#endif + +#if defined(DATA_A_Q4_1) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint vui = uint(data_a[a_offset + ib].qs[iqs]); + return vec2(vui & 0xF, vui >> 4); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const vec2 v01 = dequantize(ib, iqs, a_offset); + const vec2 v23 = dequantize(ib, iqs + 1, a_offset); + return vec4(v01.x, v01.y, v23.x, v23.y); +} +#endif + void main() { #ifdef NEEDS_INIT_IQ_SHMEM init_iq_shmem(gl_WorkGroupSize); @@ -324,6 +359,7 @@ void main() { const uint ib = idx / 4; const uint iqs = idx & 0x03; +#if 0 const float d = float(data_a_packed16[ib].d); const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16); const vec4 v0 = (vec4(unpack8(vui & 0x0F0F0F0F)) - 8.0f) * d; @@ -337,6 +373,20 @@ void main() { buf_a[buf_idx + 17] = FLOAT_TYPE(v1.y); buf_a[buf_idx + 18] = FLOAT_TYPE(v1.z); buf_a[buf_idx + 19] = FLOAT_TYPE(v1.w); +#else + const float d = float(data_a[ib].d); + const vec4 vxy = dequantize4(ib, 4*iqs, 0) * d; + const vec4 vzw = dequantize4(ib, 4*iqs + 2, 0) * d; + + buf_a[buf_idx ] = FLOAT_TYPE(vxy.x); + buf_a[buf_idx + 1 ] = FLOAT_TYPE(vxy.z); + buf_a[buf_idx + 2 ] = FLOAT_TYPE(vzw.x); + buf_a[buf_idx + 3 ] = FLOAT_TYPE(vzw.z); + buf_a[buf_idx + 16] = FLOAT_TYPE(vxy.y); + buf_a[buf_idx + 17] = FLOAT_TYPE(vxy.w); + buf_a[buf_idx + 18] = FLOAT_TYPE(vzw.y); + buf_a[buf_idx + 19] = FLOAT_TYPE(vzw.w); +#endif #elif defined(DATA_A_Q4_1) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 4 * loadr_a; @@ -344,6 +394,7 @@ void main() { const uint ib = idx / 4; const uint iqs = idx & 0x03; +#if 0 const float d = float(data_a_packed16[ib].d); const float m = float(data_a_packed16[ib].m); const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16); @@ -358,6 +409,21 @@ void main() { buf_a[buf_idx + 17] = FLOAT_TYPE(v1.y); buf_a[buf_idx + 18] = FLOAT_TYPE(v1.z); buf_a[buf_idx + 19] = FLOAT_TYPE(v1.w); +#else + const float d = float(data_a[ib].d); + const float m = float(data_a[ib].m); + const vec4 vxy = dequantize4(ib, 4*iqs, 0) * d + m; + const vec4 vzw = dequantize4(ib, 4*iqs + 2, 0) * d + m; + + buf_a[buf_idx ] = FLOAT_TYPE(vxy.x); + buf_a[buf_idx + 1 ] = FLOAT_TYPE(vxy.z); + buf_a[buf_idx + 2 ] = FLOAT_TYPE(vzw.x); + buf_a[buf_idx + 3 ] = FLOAT_TYPE(vzw.z); + buf_a[buf_idx + 16] = FLOAT_TYPE(vxy.y); + buf_a[buf_idx + 17] = FLOAT_TYPE(vxy.w); + buf_a[buf_idx + 18] = FLOAT_TYPE(vzw.y); + buf_a[buf_idx + 19] = FLOAT_TYPE(vzw.w); +#endif #elif defined(DATA_A_Q5_0) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a; @@ -404,10 +470,17 @@ void main() { const uint ib = idx / 8; const uint iqs = idx & 0x07; +#if 0 const float d = float(data_a_packed16[ib].d); const i8vec2 v0 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs])).xy; // vec4 used due to #12147 const i8vec2 v1 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs + 1])).xy; const vec4 v = vec4(v0.x, v0.y, v1.x, v1.y) * d; +#else + const float d = float(data_a[ib].d); + const vec2 v0 = dequantize(ib, 2*(2*iqs), 0); + const vec2 v1 = dequantize(ib, 2*(2*iqs + 1), 0); + const vec4 v = vec4(v0.x, v0.y, v1.x, v1.y) * d; +#endif buf_a[buf_idx ] = FLOAT_TYPE(v.x); buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp index 83de90eb7e0f2..80e2f63f6c9f9 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp @@ -24,7 +24,11 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; +#if defined(DATA_A_Q8_0) || defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +#else layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];}; +#endif #if defined(A_TYPE_PACKED32) layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];}; #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp index 63b15471bd3aa..25707e869d641 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp @@ -8,12 +8,32 @@ #if defined(DATA_A_Q4_0) i32vec2 repack(uint ib, uint iqs) { +#if 0 // Use 2-byte loads since a q4_0 block (18 bytes) is not divisible by 4 const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2 ], data_a[ib].qs[iqs * 2 + 1]); const uint32_t vui = pack32(quants); return i32vec2( vui & 0x0F0F0F0F, (vui >> 4) & 0x0F0F0F0F); +#else + int32_t u0 = int32_t(uint(data_a[ib].qs[iqs * 4])); + int32_t u1 = int32_t(uint(data_a[ib].qs[iqs * 4 + 1])); + int32_t u2 = int32_t(uint(data_a[ib].qs[iqs * 4 + 2])); + int32_t u3 = int32_t(uint(data_a[ib].qs[iqs * 4 + 3])); + + int32_t v0 = int32_t( + (u0 & 0xF) | + ((u1 & 0xF) << 8) | + ((u2 & 0xF) << 16) | + ((u3 & 0xF) << 24)); + int32_t v1 = int32_t( + ((u0 >> 4) & 0xF) | + (((u1 >> 4) & 0xF) << 8) | + (((u2 >> 4) & 0xF) << 16) | + (((u3 >> 4) & 0xF) << 24)); + + return i32vec2(v0, v1); +#endif } ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) { @@ -23,10 +43,30 @@ ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) { #if defined(DATA_A_Q4_1) i32vec2 repack(uint ib, uint iqs) { +#if 0 // Use 4-byte loads since a q4_1 block (20 bytes) is divisible by 4 const uint32_t vui = data_a_packed32[ib].qs[iqs]; return i32vec2( vui & 0x0F0F0F0F, (vui >> 4) & 0x0F0F0F0F); +#else + int32_t u0 = int32_t(uint(data_a[ib].qs[iqs * 4])); + int32_t u1 = int32_t(uint(data_a[ib].qs[iqs * 4 + 1])); + int32_t u2 = int32_t(uint(data_a[ib].qs[iqs * 4 + 2])); + int32_t u3 = int32_t(uint(data_a[ib].qs[iqs * 4 + 3])); + + int32_t v0 = int32_t( + (u0 & 0xF) | + ((u1 & 0xF) << 8) | + ((u2 & 0xF) << 16) | + ((u3 & 0xF) << 24)); + int32_t v1 = int32_t( + ((u0 >> 4) & 0xF) | + (((u1 >> 4) & 0xF) << 8) | + (((u2 >> 4) & 0xF) << 16) | + (((u3 >> 4) & 0xF) << 24)); + + return i32vec2(v0, v1); +#endif } ACC_TYPE mul_q8_1(int32_t q_sum, vec2 dma, vec2 dsb) { @@ -76,9 +116,20 @@ ACC_TYPE mul_q8_1(int32_t q_sum, vec2 dma, vec2 dsb) { #if defined(DATA_A_Q8_0) int32_t repack(uint ib, uint iqs) { +#if 0 // Use 2-byte loads since a q8_0 block (34 bytes) is not divisible by 4 return pack32(i16vec2(data_a[ib].qs[iqs * 2 ], data_a[ib].qs[iqs * 2 + 1])); +#else + int32_t u0 = int32_t(uint(data_a[ib].qs[iqs * 2]) & 0xFFu); + int32_t u1 = int32_t(uint(data_a[ib].qs[iqs * 2 + 1]) & 0xFFu); + int32_t u2 = int32_t(uint(data_a[ib].qs[iqs * 2 + 2]) & 0xFFu); + int32_t u3 = int32_t(uint(data_a[ib].qs[iqs * 2 + 3]) & 0xFFu); + + int32_t packed32 = (u3 << 24) | (u2 << 16) | (u1 << 8) | u0; + + return packed32; +#endif } ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) { @@ -92,8 +143,15 @@ FLOAT_TYPE get_d(uint ib) { } #endif -#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1) +#if defined(DATA_A_Q4_1) +FLOAT_TYPE_VEC2 get_dm(uint ib) { + return FLOAT_TYPE_VEC2(data_a[ib].d, data_a[ib].m); +} +#endif + +#if defined(DATA_A_Q5_1) FLOAT_TYPE_VEC2 get_dm(uint ib) { return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm); } #endif + diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/out_prod.comp b/ggml/src/ggml-vulkan/vulkan-shaders/out_prod.comp new file mode 100644 index 0000000000000..31a7f40db694a --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/out_prod.comp @@ -0,0 +1,56 @@ +#version 450 + +#extension GL_EXT_shader_16bit_storage : require + +#include "generic_binary_head.comp" +#include "types.comp" + +const uint num_threads = 256; + +layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in; + +void get_dst_indices(uint idx, out uint i20, out uint i21, out uint i22, out uint i23) { + i23 = fastdiv(idx, (p.ne22*p.ne21*p.ne20)); + const uint i23_offset = i23 * p.ne22*p.ne21*p.ne20; + i22 = fastdiv((idx - i23_offset), (p.ne21*p.ne20)); + const uint i22_offset = i22*p.ne21*p.ne20; + i21 = (idx - i23_offset - i22_offset) / p.ne20; + i20 = idx - i23_offset - i22_offset - i21*p.ne20; +} + +void main() { + // num_threads * num_iter must equal 512 to match the wg_denoms and get_idx + const uint num_iter = 2; + + const uint broadcast2 = uint(p.param2); + const uint broadcast3 = p.param3; + + uint idx = get_idx(); + + [[unroll]] for (uint i = 0; i < num_iter; ++i) { + if (idx >= p.ne) { + continue; + } + + uint i0, i1, i2, i3; + get_dst_indices(idx, i0, i1, i2, i3); + + FLOAT_TYPE acc = FLOAT_TYPE(0.0); + + for (uint i01 = 0; i01 < p.ne01; ++i01) { + uint a_idx = src0_idx(i0, i01, i2 / broadcast2, i3 / broadcast3); + uint b_idx = src1_idx(i1, i01, i2, i3); + + FLOAT_TYPE a_val = FLOAT_TYPE(data_a[get_aoffset() + a_idx]); + FLOAT_TYPE b_val = FLOAT_TYPE(data_b[get_boffset() + b_idx]); + + acc += a_val * b_val; + } + + uint d_idx = dst_idx(i0, i1, i2, i3); + data_d[get_doffset() + d_idx] = D_TYPE(acc); + + idx += num_threads; + } +} + diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/out_prod_q4_0.comp b/ggml/src/ggml-vulkan/vulkan-shaders/out_prod_q4_0.comp new file mode 100644 index 0000000000000..1d83771b1d910 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/out_prod_q4_0.comp @@ -0,0 +1,57 @@ +#version 450 + +#include "types.comp" +#include "generic_binary_head.comp" +#include "dequant_funcs.comp" + +const uint num_threads = 256; +layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in; + +void get_dst_indices(uint idx, out uint i20, out uint i21, out uint i22, out uint i23) { + i23 = fastdiv(idx, (p.ne22*p.ne21*p.ne20)); + const uint i23_offset = i23 * p.ne22*p.ne21*p.ne20; + i22 = fastdiv((idx - i23_offset), (p.ne21*p.ne20)); + const uint i22_offset = i22*p.ne21*p.ne20; + i21 = (idx - i23_offset - i22_offset) / p.ne20; + i20 = idx - i23_offset - i22_offset - i21*p.ne20; +} + +void main() { + // num_threads * num_iter must equal 512 to match the wg_denoms and get_idx + const uint num_iter = 2; + + const uint broadcast2 = uint(p.param2); + const uint broadcast3 = p.param3; + + uint idx = get_idx(); + + [[unroll]] for (uint it = 0; it < num_iter; ++it) { + if (idx < p.ne) { + uint i0, i1, i2, i3; + get_dst_indices(idx, i0, i1, i2, i3); + + float acc = 0.0f; + + for (uint k = 0; k < p.ne01; k += 1) { + const uint a_block_base = get_aoffset() + (i3 / broadcast3) * p.nb03 + (i2 / broadcast2) * p.nb02 + k * p.nb01; + const uint ib = a_block_base + (i0 / QUANT_K); + const uint iqs = i0 % (QUANT_K / QUANT_R); + const uint upper = (i0 % QUANT_K) / (QUANT_K / QUANT_R); + const uint lower = 1 - upper; + + const vec2 v = dequantize(ib, iqs, 0); + const vec2 dm = get_dm(ib, 0); + + const float a_val = (v.x * lower + v.y * upper) * dm.x + dm.y; + + const uint b_idx = src1_idx(i1, k, i2, i3); + const float b = data_b[get_boffset() + b_idx]; + acc += a_val * b; + } + + uint d_idx = dst_idx(i0, i1, i2, i3); + data_d[get_doffset() + d_idx] = acc; + } + idx += num_threads; + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/out_prod_q8_0.comp b/ggml/src/ggml-vulkan/vulkan-shaders/out_prod_q8_0.comp new file mode 100644 index 0000000000000..58acaae127622 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/out_prod_q8_0.comp @@ -0,0 +1,54 @@ +#version 450 + +#include "types.comp" +#include "generic_binary_head.comp" +#include "dequant_funcs.comp" + +const uint num_threads = 256; +layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in; + +void get_dst_indices(uint idx, out uint i20, out uint i21, out uint i22, out uint i23) { + i23 = fastdiv(idx, (p.ne22*p.ne21*p.ne20)); + const uint i23_offset = i23 * p.ne22*p.ne21*p.ne20; + i22 = fastdiv((idx - i23_offset), (p.ne21*p.ne20)); + const uint i22_offset = i22*p.ne21*p.ne20; + i21 = (idx - i23_offset - i22_offset) / p.ne20; + i20 = idx - i23_offset - i22_offset - i21*p.ne20; +} + +void main() { + // num_threads * num_iter must equal 512 to match the wg_denoms and get_idx + const uint num_iter = 2; + + const uint broadcast2 = uint(p.param2); + const uint broadcast3 = p.param3; + + uint idx = get_idx(); + + [[unroll]] for (uint it = 0; it < num_iter; ++it) { + if (idx < p.ne) { + uint i0, i1, i2, i3; + get_dst_indices(idx, i0, i1, i2, i3); + + float acc = 0.0f; + + for (uint k = 0; k < p.ne01; k += 1) { + const uint a_block_base = get_aoffset() + (i3 / broadcast3) * p.nb03 + (i2 / broadcast2) * p.nb02 + k * p.nb01; + const uint ib = a_block_base + (i0 / QUANT_K) * p.nb00; + const uint iqs = (i0 % QUANT_K) / QUANT_R; + + const vec2 v = dequantize(ib, iqs, 0); + const vec2 dm = get_dm(ib, 0); + const float a_val = v.x * dm.x + dm.y; + + const uint b_idx = src1_idx(i1, k, i2, i3); + const float b = data_b[get_boffset() + b_idx]; + acc += a_val * b; + } + + uint d_idx = dst_idx(i0, i1, i2, i3); + data_d[get_doffset() + d_idx] = acc; + } + idx += num_threads; + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/types.comp b/ggml/src/ggml-vulkan/vulkan-shaders/types.comp index 3bde717832b45..4d5f3a41550d9 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/types.comp @@ -65,7 +65,7 @@ struct block_q4_0_packed16 #define QUANT_R QUANT_R_Q4_0 #define QUANT_AUXF 1 #define A_TYPE block_q4_0 -#define A_TYPE_PACKED16 block_q4_0_packed16 +//#define A_TYPE_PACKED16 block_q4_0_packed16 #endif #define QUANT_K_Q4_1 32 @@ -96,8 +96,8 @@ struct block_q4_1_packed32 #define QUANT_R QUANT_R_Q4_1 #define QUANT_AUXF 2 #define A_TYPE block_q4_1 -#define A_TYPE_PACKED16 block_q4_1_packed16 -#define A_TYPE_PACKED32 block_q4_1_packed32 +//#define A_TYPE_PACKED16 block_q4_1_packed16 +//#define A_TYPE_PACKED32 block_q4_1_packed32 #endif #define QUANT_K_Q5_0 32 @@ -184,8 +184,8 @@ struct block_q8_0_packed32 #define QUANT_R QUANT_R_Q8_0 #define QUANT_AUXF 1 #define A_TYPE block_q8_0 -#define A_TYPE_PACKED16 block_q8_0_packed16 -#define A_TYPE_PACKED32 block_q8_0_packed32 +//#define A_TYPE_PACKED16 block_q8_0_packed16 +//#define A_TYPE_PACKED32 block_q8_0_packed32 #endif #define QUANT_K_Q8_1 32 @@ -347,7 +347,7 @@ struct block_q6_K_packed16 #if defined(DATA_A_Q6_K) #define QUANT_K QUANT_K_Q6_K #define A_TYPE block_q6_K -#define A_TYPE_PACKED16 block_q6_K_packed16 +//#define A_TYPE_PACKED16 block_q6_K_packed16 #endif // IQuants @@ -1337,6 +1337,22 @@ struct block_iq4_nl_packed16 #define A_TYPE_PACKED16 block_iq4_nl_packed16 #endif +// TQ2_0 +#define QUANT_K_TQ2_0 256 +#define QUANT_R_TQ2_0 4 + +struct block_tq2_0 +{ + uint8_t qs[QUANT_K_TQ2_0/QUANT_R_TQ2_0]; // 256/4 = 64 bytes + float16_t d; +}; + +#if defined(DATA_A_TQ2_0) +#define QUANT_K QUANT_K_TQ2_0 +#define QUANT_R QUANT_R_TQ2_0 +#define A_TYPE block_tq2_0 +#endif + #if defined(DATA_A_IQ4_NL) || defined(DATA_A_IQ4_XS) const int8_t kvalues_iq4nl_const[16] = { int8_t(-127), int8_t(-104), int8_t(-83), int8_t(-65), int8_t(-49), int8_t(-35), int8_t(-22), int8_t(-10), diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 809c0bd9bd305..f7f2b03e50225 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -50,6 +50,7 @@ const std::vector type_names = { "q5_0", "q5_1", "q8_0", + "tq2_0", "q2_k", "q3_k", "q4_k", @@ -467,6 +468,9 @@ void process_shaders() { // mul mat vec std::string data_a_key = "DATA_A_" + to_uppercase(tname); std::string shader = (string_ends_with(tname, "_k") || string_starts_with(tname, "iq1_") || string_starts_with(tname, "iq2_") || string_starts_with(tname, "iq3_")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp"; + if (tname == "tq2_0") { + shader = "mul_mat_vec_tq2_0.comp"; + } string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}})); string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}})); @@ -610,9 +614,12 @@ void process_shaders() { string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("geglu_back_f32", "geglu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("cross_entropy_loss_back_f32", "cross_entropy_loss_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"C_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("soft_max_f32", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("soft_max_f32_f16", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}})); string_to_spv("soft_max_back_f32", "soft_max_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); @@ -633,6 +640,11 @@ void process_shaders() { string_to_spv("rope_vision_f16", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("rope_vision_f16_rte", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}); + string_to_spv("out_prod_f32", "out_prod.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("out_prod_f16_f32", "out_prod.comp", merge_maps(base_dict, {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("out_prod_q4_0", "out_prod_q4_0.comp", merge_maps(base_dict, {{"DATA_A_Q4_0", "1"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("out_prod_q8_0", "out_prod_q8_0.comp", merge_maps(base_dict, {{"DATA_A_Q8_0", "1"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}}); string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}})); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 5ae1c527df639..a26645c213be1 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -936,6 +936,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "REPEAT_BACK", "CONCAT", "SILU_BACK", + "GEGLU_BACK", "NORM", "RMS_NORM", "RMS_NORM_BACK", @@ -1010,7 +1011,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "GLU", }; -static_assert(GGML_OP_COUNT == 86, "GGML_OP_COUNT != 86"); +static_assert(GGML_OP_COUNT == 87, "GGML_OP_COUNT != 87"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1036,6 +1037,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "repeat_back(x)", "concat(x, y)", "silu_back(x)", + "geglu_back(x)", "norm(x)", "rms_norm(x)", "rms_norm_back(x)", @@ -1110,7 +1112,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "glu(x)", }; -static_assert(GGML_OP_COUNT == 86, "GGML_OP_COUNT != 86"); +static_assert(GGML_OP_COUNT == 87, "GGML_OP_COUNT != 87"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -2632,6 +2634,22 @@ struct ggml_tensor * ggml_silu_back( return result; } +// ggml_geglu_back +struct ggml_tensor * ggml_geglu_back( + struct ggml_context * ctx, + struct ggml_tensor * grad, + struct ggml_tensor * x, + struct ggml_tensor * g) { + struct ggml_tensor * result = ggml_dup_tensor(ctx, x); + + result->op = GGML_OP_GEGLU_BACK; + result->src[0] = grad; + result->src[1] = x; + result->src[2] = g; + + return result; +} + // ggml hardswish struct ggml_tensor * ggml_hardswish( @@ -6123,6 +6141,16 @@ static void ggml_compute_backward( ggml_add_or_set(ctx, cgraph, isrc1, ggml_mul(ctx, ggml_silu(ctx, src0), grad)); } } break; + case GGML_GLU_OP_GEGLU: { + if (src0_needs_grads) { + GGML_ASSERT(src1 && "backward pass only implemented for split geglu"); + ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, ggml_gelu(ctx, src1))); + } + if (src1_needs_grads) { + struct ggml_tensor * grad_mul_src0 = ggml_mul(ctx, grad, src0); + ggml_add_or_set(ctx, cgraph, isrc1, ggml_geglu_back(ctx, grad_mul_src0, src1, src1)); + } + } break; default: { GGML_ABORT("unsupported glu op for backward pass: %s", ggml_glu_op_name(ggml_get_glu_op(tensor))); } //break; diff --git a/include/llama.h b/include/llama.h index 1c3a1cd1b4e7d..5accb65e5a0e3 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1442,6 +1442,44 @@ extern "C" { int64_t idata_split, ggml_opt_epoch_callback callback_train, ggml_opt_epoch_callback callback_eval); + + // LoRA training parameters + enum llama_lora_target_module { + LLAMA_LORA_TARGET_ATTN_Q = 1 << 0, + LLAMA_LORA_TARGET_ATTN_K = 1 << 1, + LLAMA_LORA_TARGET_ATTN_V = 1 << 2, + LLAMA_LORA_TARGET_ATTN_O = 1 << 3, + LLAMA_LORA_TARGET_FFN_GATE = 1 << 4, + LLAMA_LORA_TARGET_FFN_UP = 1 << 5, + LLAMA_LORA_TARGET_FFN_DOWN = 1 << 6, + LLAMA_LORA_TARGET_OUTPUT = 1 << 7, + LLAMA_LORA_TARGET_ALL = 0x1FF, + }; + + struct llama_lora_training_params { + uint32_t target_modules; + int32_t rank; + float alpha; + float dropout; + float init_std; + }; + + // Initialize LoRA training with the given parameters + // Creates LoRA tensors and adds them to the model context + LLAMA_API struct llama_adapter_lora * llama_lora_training_init( + struct llama_context * ctx, + struct llama_model * model, + const struct llama_lora_training_params * params + ); + + // LoRA parameter filter (returns true for LoRA tensors only) + LLAMA_API bool llama_opt_param_filter_lora(const struct ggml_tensor * tensor, void * userdata); + + LLAMA_API bool llama_lora_save_adapter( + const struct llama_adapter_lora * adapter, + const char * filename, + const struct llama_model * model + ); #ifdef __cplusplus } diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 8f9cd652447ab..6aaac7875203d 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -22,6 +22,7 @@ add_library(llama llama-io.cpp llama-kv-cache-unified.cpp llama-kv-cache-unified-iswa.cpp + llama-lora-training.cpp llama-memory.cpp llama-memory-hybrid.cpp llama-memory-recurrent.cpp diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 1af19caa39dab..3eeb7dae5fe03 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -2030,6 +2030,23 @@ void llama_context::opt_init(struct llama_model * model, struct llama_opt_params llama_set_param(reinterpret_cast(&layer)[i], param_filter, param_filter_ud); } } + + // Set LoRA params as trainable if any? + for (const auto & adapter_pair : loras) { + llama_adapter_lora * adapter = adapter_pair.first; + if (adapter) { + // Register lora tensors as params for training + for (const auto & tensor_pair : adapter->ab_map) { + const llama_adapter_lora_weight & weight = tensor_pair.second; + if (weight.a) { + llama_set_param(weight.a, param_filter, param_filter_ud); + } + if (weight.b) { + llama_set_param(weight.b, param_filter, param_filter_ud); + } + } + } + } } void llama_context::opt_epoch_iter( diff --git a/src/llama-lora-training.cpp b/src/llama-lora-training.cpp new file mode 100644 index 0000000000000..e7db81d591f1f --- /dev/null +++ b/src/llama-lora-training.cpp @@ -0,0 +1,359 @@ +#include "llama-lora-training.h" + +#include +#include +#include +#include +#include +#include + + +ggml_context * llama_lora_create_context(size_t mem_size) { + struct ggml_init_params init_params = { + /*.mem_size =*/ mem_size, + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + return ggml_init(init_params); +} + +bool llama_lora_validate_training_params(const struct llama_lora_training_params * params) { + if (!params) { + LLAMA_LOG_ERROR("LoRA training validation: params is null\n"); + return false; + } + + if (params->rank <= 0 || params->rank > 1024) { + LLAMA_LOG_ERROR("LoRA training validation: invalid rank %d (must be 1-1024)\n", params->rank); + return false; + } + + if (params->alpha <= 0.0f) { + LLAMA_LOG_ERROR("LoRA training validation: invalid alpha %f (must be > 0)\n", params->alpha); + return false; + } + + if (params->dropout < 0.0f || params->dropout > 1.0f) { + LLAMA_LOG_ERROR("LoRA training validation: invalid dropout %f (must be [0, 1])\n", params->dropout); + return false; + } + + if (params->init_std <= 0.0f || params->init_std > 1.0f) { + LLAMA_LOG_ERROR("LoRA training validation: invalid init_std %f (must be (0, 1])\n", params->init_std); + return false; + } + + if (params->target_modules == 0) { + LLAMA_LOG_ERROR("LoRA training validation: no target modules specified\n"); + return false; + } + + return true; +} + +bool llama_lora_create_tensor_pair( + struct ggml_context * lora_ctx, + const char * base_name, + const struct ggml_tensor * base_tensor, + int32_t rank, + struct ggml_tensor ** lora_a, + struct ggml_tensor ** lora_b) { + + if (!lora_ctx || !base_name || !base_tensor || !lora_a || !lora_b) { + return false; + } + + // Get base tensor dim + const int64_t d0 = base_tensor->ne[0]; // input dim + const int64_t d1 = base_tensor->ne[1]; // output dim + + char lora_a_name[256], lora_b_name[256]; + snprintf(lora_a_name, sizeof(lora_a_name), "%s.lora_a", base_name); + snprintf(lora_b_name, sizeof(lora_b_name), "%s.lora_b", base_name); + + // LoRA A: [d0, rank] - projects input to low rank + *lora_a = ggml_new_tensor_2d(lora_ctx, GGML_TYPE_F32, d0, rank); + ggml_set_name(*lora_a, lora_a_name); + + // LoRA B: [rank, d1] - projects from low rank to output + *lora_b = ggml_new_tensor_2d(lora_ctx, GGML_TYPE_F32, rank, d1); + ggml_set_name(*lora_b, lora_b_name); + + return true; +} + +static bool is_tensor_on_device(const struct ggml_tensor * tensor) { + return tensor->buffer && !ggml_backend_buffer_is_host(tensor->buffer); +} + +static void init_tensor_guassian(struct ggml_tensor * tensor, float std_dev) { + const size_t n_elements = ggml_nelements(tensor); + std::vector data(n_elements); + + std::random_device rd; + std::mt19937 gen(rd()); + std::normal_distribution dist(0.0f, std_dev); + + for (size_t i = 0; i < n_elements; i++) { + data[i] = dist(gen); + } + + if (is_tensor_on_device(tensor)) { + ggml_backend_tensor_set(tensor, data.data(), 0, n_elements * sizeof(float)); + } else { + std::copy(data.begin(), data.end(), (float *)tensor->data); + } +} + +static void init_tensor_zeros(struct ggml_tensor * tensor) { + const size_t n_elements = ggml_nelements(tensor); + + if (is_tensor_on_device(tensor)) { + std::vector zeros(n_elements, 0.0f); + ggml_backend_tensor_set(tensor, zeros.data(), 0, n_elements * sizeof(float)); + } else { + std::fill_n((float *)tensor->data, n_elements, 0.0f); + } +} + +void llama_lora_init_tensor_weights(struct ggml_tensor * lora_a, struct ggml_tensor * lora_b, float init_std) { + if (!lora_a || !lora_b) return; + + // LoRA initialization: A ~ N(0, init_std), B = 0 + init_tensor_guassian(lora_a, init_std); + init_tensor_zeros(lora_b); +} + +bool llama_lora_allocate_buffers( + struct llama_adapter_lora * adapter, + struct llama_model * model) { + + if (!adapter || !model) { + return false; + } + + std::map ctx_map; + + ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type(); // fallback to CPU + + // Find any layer tensor to determine the correct backend + for (const auto & tensor_pair : model->tensors_by_name) { + const std::string & name = tensor_pair.first; + struct ggml_tensor * tensor = tensor_pair.second; + + if (name.find("blk.") != std::string::npos && tensor && tensor->buffer) { + buft = ggml_backend_buffer_get_type(tensor->buffer); + break; + } + } + + if (adapter->ctxs.empty()) { + LLAMA_LOG_ERROR("No contexts found in adapter\n"); + return false; + } + ggml_context * lora_ctx = adapter->ctxs[0].get(); + + ggml_backend_buffer_ptr buf { ggml_backend_alloc_ctx_tensors_from_buft(lora_ctx, buft) }; + if (!buf) { + LLAMA_LOG_ERROR("Failed to allocate buffer for LoRA adapter\n"); + return false; + } + LLAMA_LOG_INFO("LoRA buffer size = %.2f MiB\n", ggml_backend_buffer_get_size(buf.get())/1024.0/1024.0); + adapter->bufs.emplace_back(std::move(buf)); + + return true; +} + +struct llama_adapter_lora * llama_lora_create_adapter( + struct llama_model * model, + const struct llama_lora_training_params * params) { + + // Create a new LoRA adapter instance + llama_adapter_lora * adapter = new llama_adapter_lora(); + try { + adapter->alpha = params->alpha; + + // Create LoRA tensors and populate ab_map + // Create GGML context for LoRA tensors + const size_t estimated_lora_mem = 256 * 1024 * 1024; // 256MB should be enough for most LoRA configs + ggml_context * lora_ctx = llama_lora_create_context(estimated_lora_mem); + if (!lora_ctx) { + throw std::runtime_error("Failed to create LoRA context"); + } + + adapter->ctxs.emplace_back(lora_ctx); + int created_count = 0; + + for (const auto & tensor_pair : model->tensors_by_name) { + const std::string & tensor_name = tensor_pair.first; + struct ggml_tensor * base_tensor = tensor_pair.second; + + if (!base_tensor) { + continue; + } + + bool should_create_lora = false; + if (tensor_name.find("blk.") != std::string::npos) { + if ((params->target_modules & LLAMA_LORA_TARGET_ATTN_Q) && tensor_name.find("attn_q") != std::string::npos) { + should_create_lora = true; + } else if ((params->target_modules & LLAMA_LORA_TARGET_ATTN_K) && tensor_name.find("attn_k") != std::string::npos) { + should_create_lora = true; + } else if ((params->target_modules & LLAMA_LORA_TARGET_ATTN_V) && tensor_name.find("attn_v") != std::string::npos) { + should_create_lora = true; + } else if ((params->target_modules & LLAMA_LORA_TARGET_ATTN_O) && tensor_name.find("attn_output") != std::string::npos) { + should_create_lora = true; + } else if ((params->target_modules & LLAMA_LORA_TARGET_FFN_GATE) && tensor_name.find("ffn_gate") != std::string::npos) { + should_create_lora = true; + } else if ((params->target_modules & LLAMA_LORA_TARGET_FFN_UP) && tensor_name.find("ffn_up") != std::string::npos) { + should_create_lora = true; + } else if ((params->target_modules & LLAMA_LORA_TARGET_FFN_DOWN) && tensor_name.find("ffn_down") != std::string::npos) { + should_create_lora = true; + } + } else if ((params->target_modules & LLAMA_LORA_TARGET_OUTPUT) && tensor_name.find("output") != std::string::npos) { + should_create_lora = true; + } + + if (should_create_lora && base_tensor->ne[1] > 0) { + struct ggml_tensor * lora_a = nullptr; + struct ggml_tensor * lora_b = nullptr; + + if (llama_lora_create_tensor_pair(lora_ctx, tensor_name.c_str(), base_tensor, params->rank, &lora_a, &lora_b)) { + if (!lora_a || !lora_b) { + throw std::runtime_error("Created null LoRA tensors for " + tensor_name); + } + created_count++; + adapter->ab_map[tensor_name] = llama_adapter_lora_weight(lora_a, lora_b); + } else { + throw std::runtime_error("Failed to create LoRA tensor pair for " + tensor_name); + } + } + } + + if (created_count == 0) { + throw std::runtime_error("No suitable tensors found for LoRA adaptation"); + } + + if (!llama_lora_allocate_buffers(adapter, model)) { + throw std::runtime_error("Failed to allocate LoRA buffers"); + } + + for (const auto & ab_pair : adapter->ab_map) { + const std::string & tensor_name = ab_pair.first; + const llama_adapter_lora_weight & weight = ab_pair.second; + + if (weight.a && weight.b && weight.a->data && weight.b->data) { + llama_lora_init_tensor_weights(weight.a, weight.b, params->init_std); + } else { + throw std::runtime_error("LoRA tensor initialization failed for " + tensor_name); + } + } + return adapter; + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("Failed to create LoRA adapter: %s\n", err.what()); + delete adapter; + return nullptr; + } +} + +struct llama_adapter_lora * llama_lora_training_init( + struct llama_context * ctx, + struct llama_model * model, + const struct llama_lora_training_params * params) { + + if (!ctx || !model || !params) { + LLAMA_LOG_ERROR("LoRA training init: invalid parameters\n"); + return nullptr; + } + + if (!llama_lora_validate_training_params(params)) { + return nullptr; + } + + struct llama_adapter_lora * adapter = llama_lora_create_adapter(model, params); + if (!adapter) { + return nullptr; + } + + llama_clear_adapter_lora(ctx); + + if (llama_set_adapter_lora(ctx, adapter, 1.0f) < 0) { + LLAMA_LOG_ERROR("Failed to apply LoRA adapter to context\n"); + delete adapter; + return nullptr; + } + + LLAMA_LOG_INFO("LoRA adapter contains %zu tensor pairs and is now registered with context\n", adapter->ab_map.size()); + + return adapter; +} + +bool llama_opt_param_filter_lora(const struct ggml_tensor * tensor, void * userdata) { + (void) userdata; // Unused param + + if (!tensor) { + return false; + } + + const char * name = tensor->name; + + // Check if tensor is LoRA A or B + // LoRA tensor naming convention: blk.{layer}.{module}.lora_a or .lora_b + if (strstr(name, ".lora_a") || strstr(name, ".lora_b")) { + LLAMA_LOG_DEBUG("LoRA filter: including trainable params '%s'\n", name); + return true; + } + + return false; +} + +bool llama_lora_save_adapter( + const struct llama_adapter_lora * adapter, + const char * filename, + const struct llama_model * model) { + + if (!adapter || !filename || !model) { + LLAMA_LOG_ERROR("llama_lora_save_adapter: invalid parameters\n"); + return false; + } + + struct gguf_context * gguf_ctx = gguf_init_empty(); + if (!gguf_ctx) { + LLAMA_LOG_ERROR("llama_lora_save_adapter: failed to create GGUF context\n"); + return false; + } + + std::string arch_name = model->arch_name(); + if (arch_name.empty()) { + LLAMA_LOG_ERROR("llama_lora_save_adapter: failed to get model architecture\n"); + gguf_free(gguf_ctx); + return false; + } + + gguf_set_val_str(gguf_ctx, "general.architecture", arch_name.c_str()); + gguf_set_val_str(gguf_ctx, "general.type", "adapter"); + gguf_set_val_str(gguf_ctx, "general.name", "LoRA Adapter"); + gguf_set_val_str(gguf_ctx, "adapter.type", "lora"); + gguf_set_val_f32(gguf_ctx, "adapter.lora.alpha", adapter->alpha); + + int tensor_count = 0; + for (const auto & kv : adapter->ab_map) { + const auto & lora_weight = kv.second; + + if (lora_weight.a && lora_weight.b) { + gguf_add_tensor(gguf_ctx, lora_weight.a); + gguf_add_tensor(gguf_ctx, lora_weight.b); + tensor_count += 2; + } + } + + bool success = gguf_write_to_file(gguf_ctx, filename, false); + if (success) { + LLAMA_LOG_INFO("Successfully saved LoRA adapter with %d tensors to: %s\n", + tensor_count, filename); + } else { + LLAMA_LOG_ERROR("Failed to write LoRA adapter to: %s\n", filename); + } + + gguf_free(gguf_ctx); + return success; +} diff --git a/src/llama-lora-training.h b/src/llama-lora-training.h new file mode 100644 index 0000000000000..ed777be7b36f7 --- /dev/null +++ b/src/llama-lora-training.h @@ -0,0 +1,34 @@ +#pragma once + +#include "llama.h" +#include "llama-model.h" +#include "llama-adapter.h" +#include "llama-impl.h" +#include "ggml.h" + + +bool llama_lora_validate_training_params(const struct llama_lora_training_params * params); + +ggml_context * llama_lora_create_context(size_t mem_size); + +bool llama_lora_create_tensor_pair( + struct ggml_context * lora_ctx, + const char * base_name, + const struct ggml_tensor * base_tensor, + int32_t rank, + struct ggml_tensor ** lora_a, + struct ggml_tensor ** lora_b); + +void llama_lora_init_tensor_weights( + struct ggml_tensor * lora_a, + struct ggml_tensor * lora_b, + float init_std); + +struct llama_adapter_lora * llama_lora_create_adapter( + struct llama_model * model, + const struct llama_lora_training_params * params); + +bool llama_lora_allocate_buffers( + struct llama_adapter_lora * adapter, + struct llama_model * model); + diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index a3d68fba046cf..53c82ff8d0b23 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1154,10 +1154,10 @@ struct test_case { double err = nmse(f1.data(), f2.data(), f1.size()); if (err > ud->max_err) { printf("[%s] NMSE = %.9f > %.9f ", ggml_op_desc(t1), err, ud->max_err); - //for (int i = 0; i < (int) f1.size(); i++) { - // printf("%5d %9.6f %9.6f, diff = %9.6f\n", i, f1[i], f2[i], f1[i] - f2[i]); - //} - //printf("\n"); + for (int i = 0; i < (int) f1.size(); i++) { + printf("%5d %9.6f %9.6f, diff = %9.6f\n", i, f1[i], f2[i], f1[i] - f2[i]); + } + printf("\n"); //exit(1); ud->ok = false; } @@ -2952,6 +2952,55 @@ struct test_mul_mat : public test_case { return out; } + +#if 0 + void initialize_tensors(ggml_context * ctx) override { + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + if (t->type == GGML_TYPE_F32) { + if (t->name[0] == 'o') { + for (int64_t r = 0; r < ggml_nrows(t); r++) { + std::vector data(t->ne[0]); + for (int i = 0; i < t->ne[0]; i++) { + data[i] = -13.0f; + } + ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(float)); + } + } else if (t->name[0] == 'a') { + for (int64_t r = 0; r < ggml_nrows(t); r++) { + std::vector data(t->ne[0]); + for (int i = 0; i < t->ne[0]; i++) { + data[i] = (3.50f * i + r) / (t->ne[0] + t->ne[1] * 0.5); + } + ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(float)); + } + } else { + for (int64_t r = 0; r < ggml_nrows(t); r++) { + std::vector data(t->ne[0]); + for (int i = 0; i < t->ne[0]; i++) { + //data[i] = (1.33f * i + r) / (t->ne[0] + t->ne[1] * 1.2); + data[i] = 1.0f; + } + ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(float)); + } + } + } /* else if (t->type == GGML_TYPE_Q8_0) { + if (t->name[0] == 'a') { + for (int64_t r = 0; r < ggml_nrows(t); r++) { + std::vector data(t->ne[0]); + for (int i = 0; i < t->ne[0]; i++) { + //data[i] = (4 * i + r) / (t->ne[0] + t->ne[1] * 0.5); + //data[i] = std::min(i, 32); + data[i] = 128; + } + ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(uint8_t)); + } + } + } */ else { + init_tensor_uniform(t); + } + } + } +#endif }; // GGML_OP_MUL_MAT_ID @@ -3073,6 +3122,56 @@ struct test_out_prod : public test_case { return out; } + +#if 0 + void initialize_tensors(ggml_context * ctx) override { + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + if (t->type == GGML_TYPE_F32) { + if (t->name[0] == 'o') { + for (int64_t r = 0; r < ggml_nrows(t); r++) { + std::vector data(t->ne[0]); + for (int i = 0; i < t->ne[0]; i++) { + data[i] = -13.0f; + } + ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(float)); + } + } else if (t->name[0] == 'a') { + for (int64_t r = 0; r < ggml_nrows(t); r++) { + std::vector data(t->ne[0]); + for (int i = 0; i < t->ne[0]; i++) { + data[i] = (3.50f * i + r) / (t->ne[0] + t->ne[1] * 0.5); + } + ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(float)); + } + } else { + for (int64_t r = 0; r < ggml_nrows(t); r++) { + std::vector data(t->ne[0]); + for (int i = 0; i < t->ne[0]; i++) { + //data[i] = (1.33f * i + r) / (t->ne[0] + t->ne[1] * 1.2); + data[i] = i+1; + } + ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(float)); + } + } + } else if (t->type == GGML_TYPE_Q8_0) { + if (t->name[0] == 'a') { + for (int64_t r = 0; r < ggml_nrows(t); r++) { + std::vector data(t->ne[0]); + for (int i = 0; i < t->ne[0]; i++) { + //data[i] = (4 * i + r) / (t->ne[0] + t->ne[1] * 0.5); + //data[i] = std::min(i, 32); + data[i] = 100 + i; + } + ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(uint8_t)); + } + } + } else { + init_tensor_uniform(t); + } + } + } +#endif + }; // GGML_OP_SQR @@ -4821,17 +4920,18 @@ struct test_falcon : public test_llm { // ## Section 3: GGML Op Test Instantiation ## // ########################################### static const ggml_type all_types[] = { - GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, - GGML_TYPE_Q4_0, GGML_TYPE_Q4_1, - GGML_TYPE_Q5_0, GGML_TYPE_Q5_1, + GGML_TYPE_F32, GGML_TYPE_F16, // GGML_TYPE_BF16, + GGML_TYPE_Q4_0, + GGML_TYPE_Q4_1, +// GGML_TYPE_Q5_0, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0, - GGML_TYPE_Q2_K, GGML_TYPE_Q3_K, - GGML_TYPE_Q4_K, GGML_TYPE_Q5_K, +// GGML_TYPE_Q2_K, GGML_TYPE_Q3_K, +// GGML_TYPE_Q4_K, GGML_TYPE_Q5_K, GGML_TYPE_Q6_K, // GGML_TYPE_TQ1_0, GGML_TYPE_TQ2_0, // TODO: implement for all backends - GGML_TYPE_IQ2_XXS, GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S, - GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M, - GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS, +// GGML_TYPE_IQ2_XXS, GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S, +// GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M, +// GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS, }; static const ggml_type base_types[] = { @@ -4839,22 +4939,23 @@ static const ggml_type base_types[] = { GGML_TYPE_Q8_0, // for I8MM tests GGML_TYPE_Q4_0, GGML_TYPE_Q4_1, // for I8MM tests - GGML_TYPE_Q4_K, - GGML_TYPE_IQ2_XXS +// GGML_TYPE_Q4_K, +// GGML_TYPE_IQ2_XXS }; static const ggml_type other_types[] = { + GGML_TYPE_Q4_0, GGML_TYPE_Q4_1, - GGML_TYPE_Q5_0, GGML_TYPE_Q5_1, +// GGML_TYPE_Q5_0, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0, - GGML_TYPE_Q2_K, GGML_TYPE_Q3_K, - GGML_TYPE_Q5_K, +// GGML_TYPE_Q2_K, GGML_TYPE_Q3_K, +// GGML_TYPE_Q5_K, GGML_TYPE_Q6_K, // GGML_TYPE_TQ1_0, GGML_TYPE_TQ2_0, // TODO: implement for all backends - GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S, - GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M, - GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS, - GGML_TYPE_BF16, +// GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S, +// GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M, +// GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS, +// GGML_TYPE_BF16, }; // Test cases for evaluation: should try to cover edge cases while using small input sizes to keep the runtime low @@ -5197,6 +5298,60 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 32, 4)); test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 128, 4)); + test_cases.emplace_back(new test_out_prod(GGML_TYPE_F32, GGML_TYPE_F32, 4096*20, 256, 1024, {1, 1}, {1, 1})); + test_cases.emplace_back(new test_out_prod(GGML_TYPE_Q8_0, GGML_TYPE_F32, 4096*20, 256, 1024, {1, 1}, {1, 1})); + test_cases.emplace_back(new test_out_prod(GGML_TYPE_F32, GGML_TYPE_F32, 1024, 256, 4096*40, {1, 1}, {1, 1})); + test_cases.emplace_back(new test_out_prod(GGML_TYPE_Q8_0, GGML_TYPE_F32, 1024, 256, 4096*40, {1, 1}, {1, 1})); + + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q6_K, GGML_TYPE_F32, 151936, 1, 1024, {1, 1}, {1, 1})); + +#if 0 + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 2, 2, 32, {1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 4, 4, 32, {1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q8_0, GGML_TYPE_F32, 2, 2, 32, {1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q8_0, GGML_TYPE_F32, 4, 4, 32, {1, 1}, {1, 1})); + + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q8_0, GGML_TYPE_F32, 4096*20, 256, 1024, {1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 4096*20, 256, 1024, {1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q8_0, GGML_TYPE_F32, 1024, 256, 4096*20, {1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 1024, 256, 4096*20, {1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q8_0, GGML_TYPE_F32, 1024, 256, 4096*40, {1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 1024, 256, 4096*40, {1, 1}, {1, 1})); +#endif + +#if 0 + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 2, 2, 1024, {1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 1024, 2, 2, {1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 2, 1024, 2, {1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 2, 1, 4096, {1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 1, 2, 4096, {1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 2, 2, 4096, {1, 1}, {1, 1})); + + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q8_0, GGML_TYPE_F32, 2, 2, 32, {1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q8_0, GGML_TYPE_F32, 4, 4, 32, {1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q8_0, GGML_TYPE_F32, 4, 4, 64, {1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q8_0, GGML_TYPE_F32, 8, 8, 128, {1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 32, 2, 2, {1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 64, 16, 2, {1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 16, 64, 2, {1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 64, 16, 4, {1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 16, 64, 4, {1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 4, 16, 32, {1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 16, 4, 32, {1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 4, 16, 64, {1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 16, 4, 64, {1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 32, 32, 32, {1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 32, 64, 64, {1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 64, 32, 64, {1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 64, 64, 32, {1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 64, 64, 64, {1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 128, 64, 2, {1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 256, 128, 32, {1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 256, 128, 64, {1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 512, 256, 1024, {1, 1}, {1, 1})); +#endif + +#if 1 for (ggml_type type_a : all_types) { for (int i = 1; i < 10; ++i) { test_cases.emplace_back(new test_mul_mat(type_a, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1})); @@ -5306,6 +5461,7 @@ static std::vector> make_test_cases_eval() { // however this case needs to alloc more memory which may fail in some devices (Intel Arc770, etc.) // this case is verified (pass) in Intel(R) Data Center GPU Max 1100 (sycl backend) and NV A30 (cuda backend) // test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F16, 512, 262144, 9216, {1, 1}, {1, 1})); +#endif // test large experts*tokens for (bool b : {false, true}) { @@ -5604,6 +5760,7 @@ static std::vector> make_test_cases_perf() { test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1})); test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32000, 512, 1, 1})); +#if 0 test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16416, 1, 128, {8, 1}, {4, 1}, {0, 2, 1, 3})); test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 1, 16416, {8, 1}, {4, 1}, {0, 1, 2, 3}, true)); @@ -5614,6 +5771,7 @@ static std::vector> make_test_cases_perf() { } } } +#endif for (int K : {3, 5}) { for (int IC : {256, 2560}) {