Skip to content

Commit

Permalink
added gather op
Browse files Browse the repository at this point in the history
  • Loading branch information
balisujohn committed Feb 1, 2024
1 parent 7801a58 commit 1282153
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 5 deletions.
26 changes: 23 additions & 3 deletions examples/tortoise/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -826,11 +826,17 @@ struct ggml_cgraph * autoregressive_graph(


ggml_tensor * next_token_logits = ggml_cont(ctx0,ggml_view_4d(ctx0, cur, 8194, 1, 4, 1, cur->nb[1], cur->nb[2], cur->nb[3], 17 * sizeof(float) * 8194 ));

next_token_logits = ggml_reshape_4d(ctx0, next_token_logits, 8194, 4, 1,1);

mel_transformer_inputs = ggml_reshape_4d(ctx0, mel_transformer_inputs, 18, 4, 1, 1);

ggml_tensor * score = ggml_gather(ctx0, next_token_logits, mel_transformer_inputs, 1);


std::cout << "didn't reach here" << std::endl;

ggml_build_forward_expand(gf, next_token_logits);
ggml_build_forward_expand(gf, score);

std::cout << "reached end graph build" << std::endl;

Expand Down Expand Up @@ -968,7 +974,7 @@ int main(int argc, char ** argv) {
for (int c = 0; c < elements ; c++)
{

if (c < 3 || c > elements-4 || c == 1024*18-1|| c == 1024*18-2|| c == 1024*18 || c == 1024*18+2 )
if (c < 3 || c > elements-4 || c == 1024*18-1|| c == 1024*18-2|| c == 1024*18 || c == 1024*18+2 || c == 17)
{

std::cout << (test_read.data()[c])<< std::endl;
Expand All @@ -988,13 +994,27 @@ int main(int argc, char ** argv) {
std::cout << ggml_fp16_to_fp32(test_read.data()[c])<< std::endl;
}
}
}
else if(test->type == GGML_TYPE_I32){
std::vector<int32_t> test_read( elements);
ggml_backend_tensor_get(test,test_read.data(), 0 ,sizeof(int32_t)* elements);
//
for (int c = 0; c < elements ; c++)
{
if (c < 3 || c > elements-4)
{

std::cout << test_read.data()[c]<< std::endl;
}
}
}




}

//ggml_graph_print (gf);
// ggml_graph_print (gf);


//std::cout << (float * )test->data << std::endl;
Expand Down
8 changes: 8 additions & 0 deletions include/ggml/ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,7 @@ extern "C" {
GGML_OP_REPEAT,
GGML_OP_REPEAT_BACK,
GGML_OP_CONCAT,
GGML_OP_GATHER,
GGML_OP_SILU_BACK,
GGML_OP_NORM, // normalize
GGML_OP_RMS_NORM,
Expand Down Expand Up @@ -1247,6 +1248,13 @@ extern "C" {
struct ggml_tensor * b,
struct ggml_tensor * c);

// gather: builds a GGML_OP_GATHER node whose result is an F32 tensor with the same shape as index
//   a     - tensor the values are read from (the CUDA implementation asserts F32 and 2d)
//   index - positions to read (the CUDA implementation asserts I32, 2d, and the same ne[1] as a)
//   dim   - NOTE(review): accepted but not read by the current implementation; the CUDA kernel
//           always selects along ne[0] within each row - confirm the intended axis convention
// the CPU backend asserts "not implemented" for this op
GGML_API struct ggml_tensor * ggml_gather(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * index,
int dim
);

GGML_API struct ggml_tensor * ggml_diag(
struct ggml_context * ctx,
struct ggml_tensor * a);
Expand Down
70 changes: 70 additions & 0 deletions src/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,21 @@ static __global__ void add_f32(const float * x, const float * y, float * dst, co
dst[i] = x[i] + y[i%ky];
}


// Per-row gather: for every flat position i in [0, ky)
//     dst[i] = x[(i / ne10) * ne00 + y[i]]
// i.e. element i of the index tensor y picks a column out of the matching row of x.
//
//   x    - source values, rows of length ne00
//   y    - int32 column indices, rows of length ne10 (values are NOT range-checked here)
//   dst  - output, same element count and layout as y
//   ky   - number of elements in y / dst
//   kx, ne01, ne11 - accepted for symmetry with the launcher but not used by the kernel
static __global__ void gather_f32(const float * x, const int32_t * y, float * dst, const int kx, const int ky,
        const int ne00, const int ne01, const int ne10, const int ne11) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    // The launcher rounds the grid up to whole blocks, so threads with i >= ky exist.
    // They must leave before y[i] is read, otherwise they read past the end of y.
    if (i >= ky) {
        return;
    }

    const int row = i / ne10; // row of the index tensor == row of x
    const int col = y[i];     // column to fetch from that row of x

    dst[i] = x[row * ne00 + col];
}

static __global__ void concat_f32(const float * x, const float * y, float * dst, const int dst_size, const int src0_size, const int src0_dim0,const int src0_dim1,const int src0_dim2,const int src0_dim3, const int src1_dim2) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;

Expand Down Expand Up @@ -4679,6 +4694,11 @@ static void add_f32_cuda(const float * x, const float * y, float * dst, const in
add_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
}

// Host-side launcher for gather_f32 on the given stream.
// One thread is scheduled per element of y (ky elements); the block count is rounded up so a
// partially filled last block covers the tail. All arguments are forwarded unchanged.
static void gather_cuda(const float * x, const int32_t * y, float * dst, const int kx, const int ky, const int ne00, const int ne01, const int ne10, const int ne11, cudaStream_t stream) {
    const int threads_per_block = CUDA_ADD_BLOCK_SIZE;
    const int block_count       = (ky + threads_per_block - 1) / threads_per_block;

    gather_f32<<<block_count, threads_per_block, 0, stream>>>(
        x, y, dst,
        kx, ky,
        ne00, ne01,
        ne10, ne11);
}


static void concat_f32_cuda(const float * x, const float * y, float * dst, const int src0_dim0,const int src0_dim1,const int src0_dim2,const int src0_dim3, const int combined_dim_2, cudaStream_t stream) {
const int dst_size = src0_dim0 * src0_dim1 * combined_dim_2 * src0_dim3;
Expand Down Expand Up @@ -6058,6 +6078,45 @@ inline void ggml_cuda_op_add(
(void) dst;
}


// CUDA implementation of GGML_OP_GATHER.
//
//   src0            - F32 source tensor, 2d (ne[2] == ne[3] == 1)
//   src1            - I32 index tensor, 2d, with the same number of rows (ne[1]) as src0
//   dst             - F32 output with exactly the shape of src1
//   src0_dd/src1_dd/dst_dd - device pointers for the three tensors; src1_dd is typed float*
//                     by the op signature but is reinterpreted as int32 (src1 is asserted I32)
//   main_stream     - stream the kernel is enqueued on
//
// Every precondition is enforced with GGML_ASSERT before the kernel is launched.
inline void ggml_cuda_op_gather(
    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {

    GGML_ASSERT(src1->type == GGML_TYPE_I32);
    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    // the output mirrors the index tensor element for element
    GGML_ASSERT(src1->ne[0] == dst->ne[0]);
    GGML_ASSERT(src1->ne[1] == dst->ne[1]);
    GGML_ASSERT(src1->ne[2] == dst->ne[2]);
    GGML_ASSERT(src1->ne[3] == dst->ne[3]);

    GGML_ASSERT(src1->ne[1] == src0->ne[1]); // need to have the same number of rows

    GGML_ASSERT(src1->ne[2] == 1); // only set up for 2d tensors
    GGML_ASSERT(src1->ne[3] == 1);

    GGML_ASSERT(src0->ne[2] == 1);
    GGML_ASSERT(src0->ne[3] == 1);

    const int64_t ne00 = src0->ne[0];
    const int64_t ne01 = src0->ne[1];
    const int64_t ne10 = src1->ne[0];
    const int64_t ne11 = src1->ne[1];

    const int64_t n_src0 = ggml_nelements(src0);
    const int64_t n_src1 = ggml_nelements(src1);

    // the launcher and kernel index with int; refuse sizes that would be truncated
    GGML_ASSERT((int64_t) (int) n_src0 == n_src0);
    GGML_ASSERT((int64_t) (int) n_src1 == n_src1);

    const int32_t * src1_i32_dd = (const int32_t *) src1_dd;

    gather_cuda(src0_dd, src1_i32_dd, dst_dd,
        (int) n_src0, (int) n_src1,
        (int) ne00, (int) ne01,
        (int) ne10, (int) ne11,
        main_stream);
}

inline void ggml_cuda_op_mul(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
Expand Down Expand Up @@ -7471,11 +7530,19 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, gg
(void) dst;
}





// GGML_OP_DUP on CUDA: implemented as a copy of src0 into dst via ggml_cuda_cpy
// (dst is passed in ggml_cuda_cpy's second slot, nullptr in its third).
static void ggml_cuda_dup(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
    (void) src1; // present only to match the common op signature
    ggml_cuda_cpy(src0, dst, nullptr);
}

// Entry point selected for GGML_OP_GATHER in ggml_cuda_compute_forward:
// hands the three tensors to ggml_cuda_op_flatten together with ggml_cuda_op_gather as the op callback.
static void ggml_cuda_gather(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_gather);
}

// Hands the three tensors to ggml_cuda_op_flatten together with ggml_cuda_op_diag_mask_inf as the op callback.
static void ggml_cuda_diag_mask_inf(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_diag_mask_inf);
}
Expand Down Expand Up @@ -7795,6 +7862,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
case GGML_OP_CONCAT:
func = ggml_cuda_concat;
break;
case GGML_OP_GATHER:
func = ggml_cuda_gather;
break;
case GGML_OP_GET_ROWS:
func = ggml_cuda_get_rows;
break;
Expand Down
36 changes: 34 additions & 2 deletions src/ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -4027,6 +4027,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"REPEAT",
"REPEAT_BACK",
"CONCAT",
"GATHER",
"SILU_BACK",
"NORM",
"RMS_NORM",
Expand Down Expand Up @@ -4092,7 +4093,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"CROSS_ENTROPY_LOSS_BACK",
};

static_assert(GGML_OP_COUNT == 73, "GGML_OP_COUNT != 73");
static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");

static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
Expand All @@ -4114,6 +4115,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"repeat(x)",
"repeat_back(x)",
"concat(x, y)",
"gather(x, y)",
"silu_back(x)",
"norm(x)",
"rms_norm(x)",
Expand Down Expand Up @@ -4179,7 +4181,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"cross_entropy_loss_back(x,y)",
};

static_assert(GGML_OP_COUNT == 73, "GGML_OP_COUNT != 73");
static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");

static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");

Expand Down Expand Up @@ -7164,6 +7166,31 @@ struct ggml_tensor * ggml_get_rows_back(
return result;
}


struct ggml_tensor * ggml_gather(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * index,
int dim
){

bool is_node = false;

if (a->grad) {
is_node = true;
}

struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, index->ne[0], index->ne[1], index->ne[2], index->ne[3]);

result->op = GGML_OP_GATHER;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
result->src[1] = index;

return result;

}

// ggml_diag

struct ggml_tensor * ggml_diag(
Expand Down Expand Up @@ -16948,6 +16975,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_concat(params, tensor->src[0], tensor->src[1], tensor);
} break;
case GGML_OP_GATHER:
{
GGML_ASSERT(false);// not implemented for CPU
} break;
case GGML_OP_SILU_BACK:
{
ggml_compute_forward_silu_back(params, tensor->src[0], tensor->src[1], tensor);
Expand Down Expand Up @@ -18768,6 +18799,7 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
n_tasks = n_threads;
} break;
case GGML_OP_CONCAT:
case GGML_OP_GATHER:
case GGML_OP_MUL_MAT:
{
n_tasks = n_threads;
Expand Down

0 comments on commit 1282153

Please sign in to comment.