ggml: add map_ternary_f32() #1482

Closed · wants to merge 1 commit
ggml.c: 115 additions, 2 deletions
@@ -3465,9 +3465,10 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {

"MAP_UNARY",
"MAP_BINARY",
"MAP_TERNARY",
};

-static_assert(GGML_OP_COUNT == 50, "GGML_OP_COUNT != 50");
+static_assert(GGML_OP_COUNT == 51, "GGML_OP_COUNT != 51");

static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@@ -3527,7 +3528,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"f(x,y)",
};

-static_assert(GGML_OP_COUNT == 50, "GGML_OP_COUNT != 50");
+static_assert(GGML_OP_COUNT == 51, "GGML_OP_COUNT != 51");

static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
@@ -4034,6 +4035,7 @@ struct ggml_tensor * ggml_new_tensor_impl(
/*.grad =*/ NULL,
/*.src0 =*/ NULL,
/*.src1 =*/ NULL,
/*.src2 =*/ NULL,
Owner:

Can you make it use opt instead of src2?

The problem with src2 is that there is logic for constructing computation graphs that currently considers src0, src1 and opt. So it would need to be updated if we introduce src2, but I want to do that at a later stage

Contributor (author):

sure, I will change that

/*.opt =*/ { NULL },
/*.n_tasks =*/ 0,
/*.perf_runs =*/ 0,
@@ -6421,6 +6423,56 @@ struct ggml_tensor * ggml_map_binary_inplace_f32(
return ggml_map_binary_impl_f32(ctx, a, b, fun, true);
}

// ggml_map_ternary

struct ggml_tensor * ggml_map_ternary_impl_f32(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
struct ggml_tensor * c,
const ggml_ternary_op_f32_t fun,
bool inplace) {
GGML_ASSERT(ggml_are_same_shape(a, b));
GGML_ASSERT(ggml_are_same_shape(b, c));

bool is_node = false;

if (!inplace && (a->grad || b->grad || c->grad)) {
is_node = true;
}

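// pack the callback into a 1-D I32 tensor just big enough to hold a function
// pointer; ggml_compute_forward() later recovers it from opt[0]->data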
struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t));
*((void (**)(void))addr_tensor->data) = (void (*)(void))fun;
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);

result->op = GGML_OP_MAP_TERNARY;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src0 = a;
result->src1 = b;
result->src2 = c;
result->opt[0] = addr_tensor;

return result;
}
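
As agreed in the review thread above, a follow-up could route the third operand through opt[] instead of the new src2 field, since the existing graph-construction logic only walks src0, src1 and opt. A hypothetical sketch of that alternative wiring (not what this commit does):

// hypothetical alternative to the wiring above: no src2 field needed
result->op     = GGML_OP_MAP_TERNARY;
result->grad   = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src0   = a;
result->src1   = b;
result->opt[0] = c;           // third operand rides in opt[] instead of src2
result->opt[1] = addr_tensor; // wrapped callback moves from opt[0] to opt[1]

The compute dispatch would then fetch the operand from tensor->opt[0] and the callback from tensor->opt[1]->data.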

struct ggml_tensor * ggml_map_ternary_f32(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
struct ggml_tensor * c,
const ggml_ternary_op_f32_t fun) {
return ggml_map_ternary_impl_f32(ctx, a, b, c, fun, false);
}

struct ggml_tensor * ggml_map_ternary_inplace_f32(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
struct ggml_tensor * c,
const ggml_ternary_op_f32_t fun) {
return ggml_map_ternary_impl_f32(ctx, a, b, c, fun, true);
}
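
For context, a minimal usage sketch (assumed, not part of this diff): the callback receives the column count, one destination row and three source rows, matching the ggml_ternary_op_f32_t typedef, so an element-wise fused multiply-add d = a*b + c looks like the following. fma_rows and the buffer sizes are illustrative, not from this PR.

#include <stdio.h>
#include "ggml.h"

// row-wise callback: n columns, one dst row, three src rows
static void fma_rows(const int n, float * dst,
                     const float * a, const float * b, const float * c) {
    for (int i = 0; i < n; i++) {
        dst[i] = a[i]*b[i] + c[i];
    }
}

int main(void) {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 16*1024*1024, // illustrative arena size
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
    struct ggml_tensor * c = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
    ggml_set_f32(a, 2.0f);
    ggml_set_f32(b, 3.0f);
    ggml_set_f32(c, 1.0f);

    struct ggml_tensor * d = ggml_map_ternary_f32(ctx, a, b, c, fma_rows);

    struct ggml_cgraph gf = ggml_build_forward(d);
    ggml_graph_compute(ctx, &gf);

    printf("d[0] = %f\n", ggml_get_f32_1d(d, 0)); // 2*3 + 1 = 7

    ggml_free(ctx);
    return 0;
}

Since the scheduling hunk below pins GGML_OP_MAP_TERNARY to n_tasks = 1, the callback always runs single-threaded over the rows.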

////////////////////////////////////////////////////////////////////////////////

void ggml_set_param(
@@ -12628,6 +12680,59 @@ static void ggml_compute_forward_map_binary(
}
}

// ggml_compute_forward_map_ternary

static void ggml_compute_forward_map_ternary_f32(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
const struct ggml_tensor * src2,
struct ggml_tensor * dst,
const ggml_ternary_op_f32_t fun) {
assert(params->ith == 0);
assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src1, src2) && ggml_are_same_shape(src0, dst));

if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return;
}

const int n = ggml_nrows(src0);
const int nc = src0->ne[0];

assert( dst->nb[0] == sizeof(float));
assert(src0->nb[0] == sizeof(float));
assert(src1->nb[0] == sizeof(float));
assert(src2->nb[0] == sizeof(float));

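// invoke the callback once per row; nb[1] is the row stride in bytes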
for (int i = 0; i < n; i++) {
fun(nc,
(float *) ((char *) dst->data + i*( dst->nb[1])),
(float *) ((char *) src0->data + i*(src0->nb[1])),
(float *) ((char *) src1->data + i*(src1->nb[1])),
(float *) ((char *) src2->data + i*(src2->nb[1])));
}
}


static void ggml_compute_forward_map_ternary(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
const struct ggml_tensor * src2,
struct ggml_tensor * dst,
const ggml_ternary_op_f32_t fun) {
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_map_ternary_f32(params, src0, src1, src2, dst, fun);
} break;
default:
{
GGML_ASSERT(false);
} break;
}
}

/////////////////////////////////

static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
@@ -12837,6 +12942,12 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
ggml_compute_forward_map_binary(params, tensor->src0, tensor->src1, tensor, fun);
}
break;
case GGML_OP_MAP_TERNARY:
{
const ggml_ternary_op_f32_t fun = *((ggml_ternary_op_f32_t *)tensor->opt[0]->data);
ggml_compute_forward_map_ternary(params, tensor->src0, tensor->src1, tensor->src2, tensor, fun);
}
break;
case GGML_OP_NONE:
{
// nop
@@ -13517,6 +13628,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
} break;
case GGML_OP_MAP_UNARY:
case GGML_OP_MAP_BINARY:
case GGML_OP_MAP_TERNARY:
{
GGML_ASSERT(false); // not supported
} break;
@@ -14062,6 +14174,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
} break;
case GGML_OP_MAP_UNARY:
case GGML_OP_MAP_BINARY:
case GGML_OP_MAP_TERNARY:
{
node->n_tasks = 1;
} break;
ggml.h: 11 additions, 1 deletion
@@ -321,6 +321,7 @@ extern "C" {

GGML_OP_MAP_UNARY,
GGML_OP_MAP_BINARY,
GGML_OP_MAP_TERNARY,

GGML_OP_COUNT,
};
@@ -358,6 +359,7 @@ extern "C" {
struct ggml_tensor * grad;
struct ggml_tensor * src0;
struct ggml_tensor * src1;
struct ggml_tensor * src2;
struct ggml_tensor * opt[GGML_MAX_OPT];

// thread scheduling
@@ -372,7 +374,7 @@ extern "C" {

char name[32];

-char padding[16];
+char padding[8];
};
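
(The new src2 pointer adds 8 bytes on 64-bit targets, which is why the trailing padding shrinks from 16 to 8: sizeof(struct ggml_tensor) stays a multiple of GGML_MEM_ALIGN, as the static_asserts in ggml.c require.)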

// computation graph
@@ -931,6 +933,7 @@ extern "C" {
// Mapping operations
typedef void (*ggml_unary_op_f32_t)(const int, float *, const float *);
typedef void (*ggml_binary_op_f32_t)(const int, float *, const float *, const float *);
typedef void (*ggml_ternary_op_f32_t)(const int, float *, const float *, const float *, const float *);

GGML_API struct ggml_tensor * ggml_map_unary_f32(
struct ggml_context * ctx,
@@ -943,6 +946,13 @@ extern "C" {
struct ggml_tensor * b,
ggml_binary_op_f32_t fun);

GGML_API struct ggml_tensor * ggml_map_ternary_f32(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
struct ggml_tensor * c,
const ggml_ternary_op_f32_t fun);
Contributor:

warning: parameter 'fun' is const-qualified in the function declaration; const-qualification of parameters only has an effect in function definitions [readability-avoid-const-params-in-decls]

Suggested change:
-    const ggml_ternary_op_f32_t fun);
+    ggml_ternary_op_f32_t fun);


//
// automatic differentiation
//