Skip to content

Commit

Permalink
ggml : add ggml_conv_2d() (WIP)
Browse files Browse the repository at this point in the history
  • Loading branch information
ggerganov committed Apr 10, 2023
1 parent 8fbc4ce commit d7a0b07
Show file tree
Hide file tree
Showing 3 changed files with 356 additions and 8 deletions.
16 changes: 10 additions & 6 deletions examples/sam/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -270,10 +270,12 @@ bool sam_model_load(const std::string & fname, sam_model & model) {
break;
}

int32_t nelements = 1;
int32_t ne[2] = { 1, 1 };
int64_t nelements = 1;
int64_t ne[2] = { 1, 1 };
for (int i = 0; i < n_dims; ++i) {
fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
int32_t ne_cur;
fin.read(reinterpret_cast<char *>(&ne_cur), sizeof(ne_cur));
ne[i] = ne_cur;
nelements *= ne[i];
}

Expand All @@ -293,13 +295,13 @@ bool sam_model_load(const std::string & fname, sam_model & model) {

if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {
fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
__func__, name.data(), tensor->ne[0], tensor->ne[1], ne[0], ne[1]);
__func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) ne[0], (int) ne[1]);
return false;
}

if (0) {
static const char * ftype_str[] = { "f32", "f16", "q4_0", "q4_1", };
fprintf(stderr, "%24s - [%5d, %5d], type = %6s, %6.2f MB, %9zu bytes\n", name.data(), ne[0], ne[1], ftype_str[ftype], ggml_nbytes(tensor)/1024.0/1024.0, ggml_nbytes(tensor));
fprintf(stderr, "%24s - [%5d, %5d], type = %6s, %6.2f MB, %9zu bytes\n", name.data(), (int) ne[0], (int) ne[1], ftype_str[ftype], ggml_nbytes(tensor)/1024.0/1024.0, ggml_nbytes(tensor));
}

size_t bpe = 0;
Expand All @@ -318,7 +320,7 @@ bool sam_model_load(const std::string & fname, sam_model & model) {

if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
__func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
__func__, name.data(), ggml_nbytes(tensor), (size_t) nelements*bpe);
return false;
}

Expand Down Expand Up @@ -376,6 +378,7 @@ int main(int argc, char ** argv) {

fprintf(stderr, "%s: preprocessed image (%d x %d)\n", __func__, img1.nx, img1.ny);

#if 0
{
const int n = 128;
fprintf(stderr, "%s: first %d diagonal pixels:\n", __func__, n);
Expand All @@ -384,6 +387,7 @@ int main(int argc, char ** argv) {
fprintf(stderr, "%s: %d: %f %f %f\n", __func__, i, img1.data[3*ii + 0], img1.data[3*ii + 1], img1.data[3*ii + 2]);
}
}
#endif

int64_t t_load_us = 0;

Expand Down
7 changes: 7 additions & 0 deletions include/ggml/ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ enum ggml_op {
GGML_OP_ROPE,
GGML_OP_CONV_1D_1S,
GGML_OP_CONV_1D_2S,
GGML_OP_CONV_2D,

GGML_OP_FLASH_ATTN,
GGML_OP_FLASH_FF,
Expand Down Expand Up @@ -635,6 +636,12 @@ struct ggml_tensor * ggml_conv_1d_2s(
struct ggml_tensor * a,
struct ggml_tensor * b);

// stride is equal to kernel size
struct ggml_tensor * ggml_conv_2d(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);

struct ggml_tensor * ggml_flash_attn(
struct ggml_context * ctx,
struct ggml_tensor * q,
Expand Down
Loading

0 comments on commit d7a0b07

Please sign in to comment.