Conv2D: Add CPU version #14320

Draft · wants to merge 1 commit into master
12 changes: 12 additions & 0 deletions ggml/include/ggml.h
@@ -481,6 +481,7 @@ extern "C" {
GGML_OP_CONV_TRANSPOSE_1D,
GGML_OP_IM2COL,
GGML_OP_IM2COL_BACK,
GGML_OP_CONV_2D,
GGML_OP_CONV_2D_DW,
GGML_OP_CONV_TRANSPOSE_2D,
GGML_OP_POOL_1D,
@@ -1723,6 +1724,17 @@ extern "C" {
struct ggml_tensor * b,
int stride);

GGML_API struct ggml_tensor * ggml_conv_2d_direct(
struct ggml_context * ctx,
struct ggml_tensor * a, // convolution kernel [KW, KH, IC, OC]
struct ggml_tensor * b, // input data [W, H, C, N]
int s0, // stride dimension 0
int s1, // stride dimension 1
int p0, // padding dimension 0
int p1, // padding dimension 1
int d0, // dilation dimension 0
int d1); // dilation dimension 1

enum ggml_op_pool {
GGML_OP_POOL_MAX,
GGML_OP_POOL_AVG,
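For reference, a minimal usage sketch (not part of this PR) of how the new `ggml_conv_2d_direct` call could be driven on the CPU backend. The tensor shapes, memory budget, and thread count are illustrative assumptions; depending on the ggml revision, `ggml_graph_compute_with_ctx` may be declared in `ggml-cpu.h` rather than `ggml.h`.

```c
#include "ggml.h"
#include "ggml-cpu.h"   // assumed location of ggml_graph_compute_with_ctx on recent revisions

int main(void) {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 64u*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    // kernel [KW, KH, IC, OC] and input [W, H, C, N]; types must match (F32 here)
    struct ggml_tensor * kernel = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 3, 3, 8, 16);
    struct ggml_tensor * input  = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 32, 32, 8, 1);

    // stride 1, padding 1, dilation 1 -> output [32, 32, 16, 1]
    struct ggml_tensor * out = ggml_conv_2d_direct(ctx, kernel, input, 1, 1, 1, 1, 1, 1);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, out);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads=*/4);

    ggml_free(ctx);
    return 0;
}
```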
5 changes: 5 additions & 0 deletions ggml/src/ggml-cpu/ggml-cpu.c
@@ -1858,6 +1858,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_im2col_back_f32(params, tensor);
} break;
case GGML_OP_CONV_2D:
{
ggml_compute_forward_conv_2d(params, tensor);
} break;
case GGML_OP_CONV_2D_DW:
{
ggml_compute_forward_conv_2d_dw(params, tensor);
@@ -2203,6 +2207,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
} break;
case GGML_OP_IM2COL:
case GGML_OP_IM2COL_BACK:
case GGML_OP_CONV_2D:
case GGML_OP_CONV_2D_DW:
case GGML_OP_CONV_TRANSPOSE_1D:
case GGML_OP_CONV_TRANSPOSE_2D:
157 changes: 157 additions & 0 deletions ggml/src/ggml-cpu/ops.cpp
@@ -6058,6 +6058,163 @@ void ggml_compute_forward_im2col_back_f32(
}
}

// ggml_compute_forward_conv_2d

static void ggml_compute_forward_conv_2d_f32(
const ggml_compute_params * params,
const ggml_tensor * kernel, // [KW, KH, IC, OC]
const ggml_tensor * src, // [W, H, C, N]
ggml_tensor * dst) { // [OW, OH, OC, N]

const int32_t s0 = ggml_get_op_params_i32(dst, 0);
const int32_t s1 = ggml_get_op_params_i32(dst, 1);
const int32_t p0 = ggml_get_op_params_i32(dst, 2);
const int32_t p1 = ggml_get_op_params_i32(dst, 3);
const int32_t d0 = ggml_get_op_params_i32(dst, 4);
const int32_t d1 = ggml_get_op_params_i32(dst, 5);

const int64_t OW = dst->ne[0];
const int64_t OH = dst->ne[1];
const int64_t OC = dst->ne[2];
const int64_t N = dst->ne[3];

const int64_t IW = src->ne[0];
const int64_t IH = src->ne[1];
const int64_t IC = src->ne[2];

const int64_t KW = kernel->ne[0];
const int64_t KH = kernel->ne[1];

const float * kernel_data = (const float *)kernel->data;
const float * src_data = (const float *)src->data;
float * dst_data = (float *)dst->data;

const int64_t rows_total = OH * N;
const int64_t rows_per_thread = (rows_total + params->nth - 1) / params->nth;
const int64_t row_start = params->ith * rows_per_thread;
const int64_t row_end = MIN(row_start + rows_per_thread, rows_total);

for (int64_t row = row_start; row < row_end; ++row) {
const int64_t oh = row % OH;
const int64_t n = row / OH;
const float * src_batch = src_data + n * IW * IH * IC;

for (int64_t ow = 0; ow < OW; ++ow) {
for (int64_t oc = 0; oc < OC; ++oc) {
float sum = 0.0f;
const float * kernel_channel = kernel_data + oc * KW * KH * IC;

for (int64_t kh = 0; kh < KH; ++kh) {
const int64_t ih = oh * s1 - p1 + kh * d1;
if (ih < 0 || ih >= IH) continue;

for (int64_t kw = 0; kw < KW; ++kw) {
const int64_t iw = ow * s0 - p0 + kw * d0;
if (iw < 0 || iw >= IW) continue;

#pragma omp simd
for (int64_t ic = 0; ic < IC; ++ic) {
const float * kernel_ptr = kernel_channel + (kh * KW + kw) + ic * KW * KH;
const float * src_ptr = src_batch + (ih * IW + iw) + ic * IW * IH;
sum += (*kernel_ptr) * (*src_ptr);
}
}
}

dst_data[((n * OC + oc) * OH + oh) * OW + ow] = sum;
}
}
}
}

static void ggml_compute_forward_conv_2d_f16(
const ggml_compute_params * params,
const ggml_tensor * kernel, // [KW, KH, IC, OC]
const ggml_tensor * src, // [W, H, C, N]
ggml_tensor * dst) { // [OW, OH, OC, N]

const int32_t s0 = ggml_get_op_params_i32(dst, 0);
const int32_t s1 = ggml_get_op_params_i32(dst, 1);
const int32_t p0 = ggml_get_op_params_i32(dst, 2);
const int32_t p1 = ggml_get_op_params_i32(dst, 3);
const int32_t d0 = ggml_get_op_params_i32(dst, 4);
const int32_t d1 = ggml_get_op_params_i32(dst, 5);

const int64_t OW = dst->ne[0];
const int64_t OH = dst->ne[1];
const int64_t OC = dst->ne[2];
const int64_t N = dst->ne[3];

const int64_t IW = src->ne[0];
const int64_t IH = src->ne[1];
const int64_t IC = src->ne[2];

const int64_t KW = kernel->ne[0];
const int64_t KH = kernel->ne[1];

const ggml_fp16_t * kernel_data = (const ggml_fp16_t *)kernel->data;
const ggml_fp16_t * src_data = (const ggml_fp16_t *)src->data;
ggml_fp16_t * dst_data = (ggml_fp16_t *)dst->data;

const int64_t rows_total = OH * N;
const int64_t rows_per_thread = (rows_total + params->nth - 1) / params->nth;
const int64_t row_start = params->ith * rows_per_thread;
const int64_t row_end = MIN(row_start + rows_per_thread, rows_total);

for (int64_t row = row_start; row < row_end; ++row) {
const int64_t oh = row % OH;
const int64_t n = row / OH;
const ggml_fp16_t * src_batch = src_data + n * IW * IH * IC;

for (int64_t ow = 0; ow < OW; ++ow) {
for (int64_t oc = 0; oc < OC; ++oc) {
float sum = 0.0f;
const ggml_fp16_t * kernel_channel = kernel_data + oc * KW * KH * IC;
for (int64_t kh = 0; kh < KH; ++kh) {
const int64_t ih = oh * s1 - p1 + kh * d1;
if (ih < 0 || ih >= IH) continue;

for (int64_t kw = 0; kw < KW; ++kw) {
const int64_t iw = ow * s0 - p0 + kw * d0;
if (iw < 0 || iw >= IW) continue;

for (int64_t ic = 0; ic < IC; ++ic) {
const ggml_fp16_t * kernel_ptr = kernel_channel + (kh * KW + kw) + ic * KW * KH;
const ggml_fp16_t * src_ptr = src_batch + (ih * IW + iw) + ic * IW * IH;
sum += GGML_FP16_TO_FP32(*kernel_ptr) * GGML_FP16_TO_FP32(*src_ptr);
}
}
}

dst_data[((n * OC + oc) * OH + oh) * OW + ow] = GGML_FP32_TO_FP16(sum);
}
}
}
}

void ggml_compute_forward_conv_2d(
const ggml_compute_params * params,
ggml_tensor * dst) {

const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];

switch (src0->type) {
case GGML_TYPE_F16:
{
ggml_compute_forward_conv_2d_f16(params, src0, src1, dst);
} break;
case GGML_TYPE_F32:
{
ggml_compute_forward_conv_2d_f32(params, src0, src1, dst);
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}

// ggml_compute_forward_conv_transpose_2d

void ggml_compute_forward_conv_transpose_2d(
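Both the F32 and F16 paths parallelize over output rows: the OH·N row space is split into contiguous chunks of ceil(rows_total / nth) rows per thread. A standalone sketch of that partitioning arithmetic, with illustrative values that are assumptions rather than anything taken from the PR:

```c
#include <stdint.h>

#define MIN(a, b) ((a) < (b) ? (a) : (b))

// Illustrative sizes: 16 output rows, batch of 2, 8 threads, looking at thread index 3.
static void conv2d_row_partition_example(void) {
    const int64_t OH = 16, N = 2, nth = 8, ith = 3;

    const int64_t rows_total      = OH * N;                       // 32 rows of work
    const int64_t rows_per_thread = (rows_total + nth - 1) / nth; // ceil(32/8) = 4
    const int64_t row_start       = ith * rows_per_thread;        // thread 3 starts at row 12
    const int64_t row_end         = MIN(row_start + rows_per_thread, rows_total); // stops before 16

    // each flat row index maps back to an (output row, batch) pair
    for (int64_t row = row_start; row < row_end; ++row) {
        const int64_t oh = row % OH; // output row within the image: 12..15
        const int64_t n  = row / OH; // batch index: 0
        (void) oh; (void) n;
    }
}
```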
1 change: 1 addition & 0 deletions ggml/src/ggml-cpu/ops.h
@@ -64,6 +64,7 @@ void ggml_compute_forward_clamp(const struct ggml_compute_params * params, struc
void ggml_compute_forward_conv_transpose_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_im2col(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_im2col_back_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_conv_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_conv_transpose_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_conv_2d_dw(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_pool_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
43 changes: 41 additions & 2 deletions ggml/src/ggml.c
@@ -986,7 +986,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"OPT_STEP_ADAMW",
};

static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
static_assert(GGML_OP_COUNT == 84, "GGML_OP_COUNT != 84");

static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@@ -1043,6 +1043,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"conv_transpose_1d(x)",
"im2col(x)",
"im2col_back(x)",
"conv_2d(x)",
"conv_2d_dw(x)",
"conv_transpose_2d(x)",
"pool_1d(x)",
@@ -1082,7 +1083,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"adamw(x)",
};

static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
static_assert(GGML_OP_COUNT == 84, "GGML_OP_COUNT != 84");

static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");

@@ -4131,6 +4132,44 @@ struct ggml_tensor * ggml_conv_2d_dw_direct(
return result;
}

// ggml_conv_2d_direct

struct ggml_tensor * ggml_conv_2d_direct(
struct ggml_context * ctx,
struct ggml_tensor * a, // convolution kernel [KW, KH, IC, OC]
struct ggml_tensor * b, // input data [W, H, C, N]
int s0, // stride dimension 0
int s1, // stride dimension 1
int p0, // padding dimension 0
int p1, // padding dimension 1
int d0, // dilation dimension 0
int d1) { // dilation dimension 1

GGML_ASSERT(a->ne[2] == b->ne[2]);
GGML_ASSERT(a->type == b->type);

int64_t ne[4];
ne[0] = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
ne[1] = ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1);
ne[2] = a->ne[3];
ne[3] = b->ne[3];

struct ggml_tensor * result = ggml_new_tensor(ctx, b->type, 4, ne);

ggml_set_op_params_i32(result, 0, s0);
ggml_set_op_params_i32(result, 1, s1);
ggml_set_op_params_i32(result, 2, p0);
ggml_set_op_params_i32(result, 3, p1);
ggml_set_op_params_i32(result, 4, d0);
ggml_set_op_params_i32(result, 5, d1);

result->op = GGML_OP_CONV_2D;
result->src[0] = a;
result->src[1] = b;

return result;
}

// ggml_conv_transpose_2d_p0

static int64_t ggml_calc_conv_transpose_output_size(int64_t ins, int64_t ks, int s, int p) {
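`ggml_conv_2d_direct` relies on `ggml_calc_conv_output_size` (defined elsewhere in ggml.c) for the output spatial dimensions. As a hedged reference, a sketch of the standard per-dimension formula it implements, with an illustrative worked example; the helper name `conv_out_size` is made up for the sketch:

```c
#include <stdint.h>

// Standard convolution output-size formula, applied per spatial dimension:
//     out = (in + 2*p - d*(k - 1) - 1) / s + 1
//
// Example with illustrative numbers: in = 32, k = 3, s = 1, p = 1, d = 1
//     out = (32 + 2 - 2 - 1) / 1 + 1 = 32   (same-size output)
static int64_t conv_out_size(int64_t in, int64_t k, int s, int p, int d) {
    return (in + 2 * p - d * (k - 1) - 1) / s + 1;
}
```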