Commit

add threshold

MjieYu committed Mar 21, 2024
1 parent f90a603 commit 4df1f4b
Showing 3 changed files with 41 additions and 9 deletions.
39 changes: 35 additions & 4 deletions ggml-cuda.cu
@@ -597,6 +597,15 @@ static __global__ void relu_f32(const float * x, float * dst, const int k) {
dst[i] = fmaxf(x[i], 0);
}

static __global__ void threshold_f32(const float * x, float * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;

if (i >= k) {
return;
}
dst[i] = x[i] > 0.0f ? 1.0f : 0.0f;
}

static __global__ void sqr_f32(const float * x, float * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;

@@ -5196,6 +5205,11 @@ static void relu_f32_cuda(const float * x, float * dst, const int k, cudaStream_
relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}

static void threshold_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
threshold_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}

static void sqr_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_SQR_BLOCK_SIZE - 1) / CUDA_SQR_BLOCK_SIZE;
sqr_f32<<<num_blocks, CUDA_SQR_BLOCK_SIZE, 0, stream>>>(x, dst, k);
@@ -6756,6 +6770,21 @@ inline void ggml_cuda_op_relu(
(void) src1_dd;
}

inline void ggml_cuda_op_threshold(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {

GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);

threshold_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream);

(void) src1;
(void) dst;
(void) src1_dd;
}


inline void ggml_cuda_op_sqr(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
@@ -8275,6 +8304,10 @@ static void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, g
static void ggml_cuda_relu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_relu);
}
static void ggml_cuda_threshold(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
// return ;
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_threshold);
}

static void ggml_cuda_sqr(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_sqr);
@@ -8541,9 +8574,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
ggml_cuda_pool_free(src1_as_f16, src1_as);
ggml_cuda_pool_free(dst_f16, dst_as);
}
void ggml_cuda_threshold(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
printf("%s \n",__func__);
}



void ggml_cuda_create_by_rdma(const ggml_tensor * src, const ggml_tensor * gpu_bucket, ggml_tensor * dst)
@@ -9332,7 +9363,7 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_

ggml_cuda_func_t func;
const bool src0_on_device = tensor->src[0] != nullptr && (tensor->src[0]->backend != GGML_BACKEND_CPU);
const bool any_on_device = tensor->backend == GGML_BACKEND_GPU || src0_on_device
const bool any_on_device = tensor->backend == GGML_BACKEND_GPU || src0_on_device // at least one tensor is on the GPU
|| (tensor->src[1] != nullptr && tensor->src[1]->backend == GGML_BACKEND_GPU);

// when src0 (weights) is not on device, we compute on CPU with sparsity
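For reference, the kernel added above implements an element-wise step function: dst[i] is 1.0f when x[i] is strictly positive and 0.0f otherwise, launched with the same block geometry as relu_f32. The standalone CUDA sketch below is illustrative only and not part of this commit; BLOCK_SIZE stands in for CUDA_RELU_BLOCK_SIZE and the input values are made up.

#include <cstdio>
#include <cuda_runtime.h>

#define BLOCK_SIZE 256  // stands in for CUDA_RELU_BLOCK_SIZE

// Same element-wise step function as threshold_f32 in the diff above.
static __global__ void threshold_step(const float * x, float * dst, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;
    if (i >= k) {
        return;
    }
    dst[i] = x[i] > 0.0f ? 1.0f : 0.0f;
}

int main() {
    const int k = 8;
    const float h_x[k] = { -2.0f, -0.5f, 0.0f, 0.1f, 1.0f, 3.0f, -1e-6f, 42.0f };
    float h_dst[k];

    float * d_x   = nullptr;
    float * d_dst = nullptr;
    cudaMalloc(&d_x,   k * sizeof(float));
    cudaMalloc(&d_dst, k * sizeof(float));
    cudaMemcpy(d_x, h_x, k * sizeof(float), cudaMemcpyHostToDevice);

    // Same block/grid computation as threshold_f32_cuda.
    const int num_blocks = (k + BLOCK_SIZE - 1) / BLOCK_SIZE;
    threshold_step<<<num_blocks, BLOCK_SIZE>>>(d_x, d_dst, k);
    cudaMemcpy(h_dst, d_dst, k * sizeof(float), cudaMemcpyDeviceToHost);

    for (int i = 0; i < k; ++i) {
        printf("%g ", h_dst[i]);   // expected: 0 0 0 1 1 1 0 1
    }
    printf("\n");

    cudaFree(d_x);
    cudaFree(d_dst);
    return 0;
}

Compiled with nvcc, this prints a 0/1 mask marking which inputs are strictly positive.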
6 changes: 4 additions & 2 deletions ggml.c
@@ -15311,7 +15311,8 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
ggml_compute_forward_add(params, tensor->src[0], tensor->src[1], tensor);
} break;
case GGML_OP_THRESHOLD:
{ printf("GGML_OP_TEST\n");
{ printf("GGML_OP_THRESHOLD\n");
return ;
ggml_compute_forward_add(params, tensor->src[0], tensor->src[1], tensor);
} break;
case GGML_OP_ADD1:
@@ -15400,7 +15401,8 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
} break;

case GGML_OP_MUL_MAT_SPARSE:
{
{

GGML_ASSERT(tensor->src[2] != NULL && "sparsity index is required for MUL_MAT_SPARSE");

// MUL_MAT_SPARSE is the first operation in the FFN block, and
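Note that the CPU dispatch for GGML_OP_THRESHOLD above is still a stub: it prints and returns before the (unrelated) ggml_compute_forward_add call, so only the CUDA path produces real output. A matching CPU reference with the same semantics as the CUDA kernel, assuming contiguous F32 data and ignoring ggml's threading and stride conventions, could look like this sketch:

#include <stddef.h>

// CPU reference of the threshold/step op: dst[i] = 1.0f if x[i] > 0, else 0.0f.
// Assumes contiguous float buffers of n elements (illustrative, not the ggml API).
static void threshold_f32_ref(const float * x, float * dst, size_t n) {
    for (size_t i = 0; i < n; ++i) {
        dst[i] = x[i] > 0.0f ? 1.0f : 0.0f;
    }
}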
5 changes: 2 additions & 3 deletions llama.cpp
@@ -4719,12 +4719,11 @@ static struct ggml_tensor * llm_build_ffn_sparse_rdma(
idx = ggml_relu(ctx, idx);
cb(idx, "mlp_pre_relu");
idx = ggml_mul_mat_pre_w2(ctx, pre_w2, idx, up, gpu_bucket, up_gpu, il, layer);
(full_gpu ? cb : cb_outer)(idx, "mlp_pre_out");

ggml_tensor *idx_threshold = ggml_threshold(ctx, idx);
cb(idx_threshold, "idx_threshold");
// printf("idx: %d\n", idx->backend);
// If the FFN layer is not fully offloaded, we need to transfer the sparsity index
// back to the CPU to avoid synchronization issues.
(full_gpu ? cb : cb_outer)(idx, "mlp_pre_out");
// printf_dim(up_gpu, "up_gpu");
// printf("il: %d\n",il);
struct ggml_tensor * tensor = ((struct ggml_tensor **) rdma_idx_vec)[il];
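Conceptually, the binarized predictor output (idx_threshold) serves as a per-neuron activity mask for the sparse FFN matmul that follows; the ggml.c hunk above asserts that a "sparsity index is required for MUL_MAT_SPARSE". The sketch below is a purely illustrative CPU analogue of using such a 0/1 mask to skip inactive rows; it is not the repository's MUL_MAT_SPARSE implementation, and all names and shapes here are made up:

#include <vector>

// Illustrative only: skip rows whose mask entry is 0, mimicking how a 0/1
// sparsity index lets a sparse matmul avoid work on predicted-inactive neurons.
static std::vector<float> masked_matvec(const std::vector<float> & w,    // rows x cols, row-major
                                        const std::vector<float> & x,    // cols
                                        const std::vector<float> & mask, // rows, entries 0 or 1
                                        int rows, int cols) {
    std::vector<float> y(rows, 0.0f);
    for (int r = 0; r < rows; ++r) {
        if (mask[r] <= 0.0f) {
            continue; // predicted inactive: output stays 0, dot product skipped
        }
        for (int c = 0; c < cols; ++c) {
            y[r] += w[r * cols + c] * x[c];
        }
    }
    return y;
}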
