From 8fcf31b4f1b98a602ae323c23b82121c192ce9e3 Mon Sep 17 00:00:00 2001
From: Kamil Tomsik
Date: Tue, 16 May 2023 17:43:42 +0200
Subject: [PATCH] ggml: add map_ternary_f32()

Add a three-operand counterpart to the existing map_unary/map_binary
custom operators: ggml_map_ternary_f32() and
ggml_map_ternary_inplace_f32(), together with the new
GGML_OP_MAP_TERNARY op and the src2 field on ggml_tensor that it
requires. As with the unary and binary variants, backward is not
supported.
---
 ggml.c | 117 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++-
 ggml.h |  12 +++++-
 2 files changed, 126 insertions(+), 3 deletions(-)

diff --git a/ggml.c b/ggml.c
index 4311ce7cf9dbe..d458a6e559ecd 100644
--- a/ggml.c
+++ b/ggml.c
@@ -3465,9 +3465,10 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
 
     "MAP_UNARY",
     "MAP_BINARY",
+    "MAP_TERNARY",
 };
 
-static_assert(GGML_OP_COUNT == 50, "GGML_OP_COUNT != 50");
+static_assert(GGML_OP_COUNT == 51, "GGML_OP_COUNT != 51");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -3527,7 +3528,8 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "f(x,y)",
+    "f(x,y,z)",
 };
 
-static_assert(GGML_OP_COUNT == 50, "GGML_OP_COUNT != 50");
+static_assert(GGML_OP_COUNT == 51, "GGML_OP_COUNT != 51");
 
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
 static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
@@ -4034,6 +4036,7 @@ struct ggml_tensor * ggml_new_tensor_impl(
         /*.grad         =*/ NULL,
         /*.src0         =*/ NULL,
         /*.src1         =*/ NULL,
+        /*.src2         =*/ NULL,
         /*.opt          =*/ { NULL },
         /*.n_tasks      =*/ 0,
         /*.perf_runs    =*/ 0,
@@ -6421,6 +6424,56 @@ struct ggml_tensor * ggml_map_binary_inplace_f32(
     return ggml_map_binary_impl_f32(ctx, a, b, fun, true);
 }
 
+// ggml_map_ternary
+
+struct ggml_tensor * ggml_map_ternary_impl_f32(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b,
+        struct ggml_tensor * c,
+        const ggml_ternary_op_f32_t fun,
+        bool inplace) {
+    GGML_ASSERT(ggml_are_same_shape(a, b));
+    GGML_ASSERT(ggml_are_same_shape(b, c));
+
+    bool is_node = false;
+
+    if (!inplace && (a->grad || b->grad || c->grad)) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t));
+    *((void (**)(void))addr_tensor->data) = (void (*)(void))fun;
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_MAP_TERNARY;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+    result->src2 = c;
+    result->opt[0] = addr_tensor;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_map_ternary_f32(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b,
+        struct ggml_tensor * c,
+        const ggml_ternary_op_f32_t fun) {
+    return ggml_map_ternary_impl_f32(ctx, a, b, c, fun, false);
+}
+
+struct ggml_tensor * ggml_map_ternary_inplace_f32(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b,
+        struct ggml_tensor * c,
+        const ggml_ternary_op_f32_t fun) {
+    return ggml_map_ternary_impl_f32(ctx, a, b, c, fun, true);
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 
 void ggml_set_param(
@@ -12628,6 +12681,58 @@
     }
 }
 
+// ggml_compute_forward_map_ternary
+
+static void ggml_compute_forward_map_ternary_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        const struct ggml_tensor * src2,
+        struct ggml_tensor * dst,
+        const ggml_ternary_op_f32_t fun) {
+    assert(params->ith == 0);
+    assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src1, src2) && ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert( dst->nb[0] == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+    assert(src1->nb[0] == sizeof(float));
+    assert(src2->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        fun(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])),
+                (float *) ((char *) src1->data + i*(src1->nb[1])),
+                (float *) ((char *) src2->data + i*(src2->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_map_ternary(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        const struct ggml_tensor * src2,
+        struct ggml_tensor * dst,
+        const ggml_ternary_op_f32_t fun) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_map_ternary_f32(params, src0, src1, src2, dst, fun);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 /////////////////////////////////
 
 static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
@@ -12837,6 +12942,12 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
                 ggml_compute_forward_map_binary(params, tensor->src0, tensor->src1, tensor, fun);
             }
             break;
+        case GGML_OP_MAP_TERNARY:
+            {
+                const ggml_ternary_op_f32_t fun = *((ggml_ternary_op_f32_t *)tensor->opt[0]->data);
+                ggml_compute_forward_map_ternary(params, tensor->src0, tensor->src1, tensor->src2, tensor, fun);
+            }
+            break;
         case GGML_OP_NONE:
             {
                 // nop
@@ -13517,6 +13628,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             } break;
         case GGML_OP_MAP_UNARY:
         case GGML_OP_MAP_BINARY:
+        case GGML_OP_MAP_TERNARY:
             {
                 GGML_ASSERT(false); // not supported
             } break;
@@ -14062,6 +14174,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                     } break;
                 case GGML_OP_MAP_UNARY:
                 case GGML_OP_MAP_BINARY:
+                case GGML_OP_MAP_TERNARY:
                     {
                         node->n_tasks = 1;
                     } break;
diff --git a/ggml.h b/ggml.h
index 255541d0257e3..4c51290f0e13c 100644
--- a/ggml.h
+++ b/ggml.h
@@ -321,6 +321,7 @@ extern "C" {
 
         GGML_OP_MAP_UNARY,
         GGML_OP_MAP_BINARY,
+        GGML_OP_MAP_TERNARY,
 
         GGML_OP_COUNT,
     };
@@ -358,6 +359,7 @@ extern "C" {
         struct ggml_tensor * grad;
         struct ggml_tensor * src0;
         struct ggml_tensor * src1;
+        struct ggml_tensor * src2;
         struct ggml_tensor * opt[GGML_MAX_OPT];
 
         // thread scheduling
@@ -372,7 +374,7 @@ extern "C" {
 
         char name[32];
 
-        char padding[16];
+        char padding[8];
     };
 
     // computation graph
@@ -931,6 +933,7 @@ extern "C" {
     // Mapping operations
     typedef void (*ggml_unary_op_f32_t)(const int, float *, const float *);
    typedef void (*ggml_binary_op_f32_t)(const int, float *, const float *, const float *);
+    typedef void (*ggml_ternary_op_f32_t)(const int, float *, const float *, const float *, const float *);
 
     GGML_API struct ggml_tensor * ggml_map_unary_f32(
             struct ggml_context * ctx,
@@ -943,6 +946,13 @@ extern "C" {
             struct ggml_tensor * b,
             ggml_binary_op_f32_t fun);
 
+    GGML_API struct ggml_tensor * ggml_map_ternary_f32(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a,
+            struct ggml_tensor * b,
+            struct ggml_tensor * c,
+            ggml_ternary_op_f32_t fun);
+
     //
     // automatic differentiation
     //
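
Usage note (not part of the patch): below is a minimal sketch of how the
new operator is intended to be called, written against the ggml API of
this revision (ggml_build_forward() / ggml_graph_compute()). The kernel
fused_mul_add and the tensor sizes are made up for illustration; error
handling is omitted.

    #include <stdio.h>
    #include "ggml.h"

    // Hypothetical row kernel matching ggml_ternary_op_f32_t:
    // computes dst[i] = a[i] + b[i]*c[i] over the n elements of one row.
    static void fused_mul_add(const int n, float * dst,
                              const float * a, const float * b, const float * c) {
        for (int i = 0; i < n; i++) {
            dst[i] = a[i] + b[i]*c[i];
        }
    }

    int main(void) {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 16*1024*1024,  // small scratch arena for this toy graph
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(params);

        // the op requires three tensors of the same shape
        struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
        struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
        struct ggml_tensor * c = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
        for (int i = 0; i < 4; i++) {
            ((float *) a->data)[i] = 1.0f;
            ((float *) b->data)[i] = 2.0f;
            ((float *) c->data)[i] = (float) i;
        }

        struct ggml_tensor * out = ggml_map_ternary_f32(ctx, a, b, c, fused_mul_add);

        struct ggml_cgraph gf = ggml_build_forward(out);
        gf.n_threads = 1; // the op runs with n_tasks = 1 regardless
        ggml_graph_compute(ctx, &gf);

        for (int i = 0; i < 4; i++) {
            printf("%f\n", ((float *) out->data)[i]); // expect 1 + 2*i
        }

        ggml_free(ctx);
        return 0;
    }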