
ggml-cuda : use graph allocator #2684


Merged 2 commits on Aug 22, 2023
1 change: 0 additions & 1 deletion common/common.cpp
@@ -289,7 +289,6 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
break;
}
params.n_batch = std::stoi(argv[i]);
params.n_batch = std::min(512, params.n_batch);
} else if (arg == "--keep") {
if (++i >= argc) {
invalid_param = true;
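Note on the common.cpp change above (a reviewer-style reading, not something stated in the patch): dropping the std::min(512, params.n_batch) clamp means --batch-size is no longer silently capped at 512, presumably so the graph allocator can size buffers for larger batches; the kernel launch-configuration changes below appear to exist for the same reason. As a purely illustrative example (model path and prompt are placeholders), a command like

./main -m <model path> -b 1024 -p "<prompt>"

would now actually run with a batch size of 1024 instead of being clamped to 512.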
75 changes: 56 additions & 19 deletions ggml-cuda.cu
@@ -3886,13 +3886,13 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
// rope == RoPE == rotary positional embedding
static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float p0,
const float p_delta, const int p_delta_rows, const float theta_scale) {
const int col = 2*(blockDim.x*blockIdx.x + threadIdx.x);
const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);

if (col >= ncols) {
return;
}

const int row = blockDim.y*blockIdx.y + threadIdx.y;
const int row = blockDim.x*blockIdx.x + threadIdx.x;
const int i = row*ncols + col;

const float theta = (p0 + p_delta * (row/p_delta_rows))*powf(theta_scale, col/2);
@@ -3941,8 +3941,8 @@ static __global__ void rope_glm_f32(const float * x, float * dst, const int ncol
}

static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) {
const int col = blockDim.x*blockIdx.x + threadIdx.x;
const int row = blockDim.y*blockIdx.y + threadIdx.y;
const int col = blockDim.y*blockIdx.y + threadIdx.y;
const int row = blockDim.x*blockIdx.x + threadIdx.x;

if (col >= ncols) {
return;
@@ -3958,9 +3958,9 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
// values are also not normalized to the maximum value by subtracting it in the exponential function
// theoretically these changes could cause problems with rounding error and arithmetic overflow but for LLaMa it seems to be fine
static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) {
const int row = blockDim.y*blockIdx.y + threadIdx.y;
const int block_size = blockDim.x;
const int tid = threadIdx.x;
const int row = blockDim.x*blockIdx.x + threadIdx.x;
const int block_size = blockDim.y;
const int tid = threadIdx.y;

float tmp = 0.0;

@@ -4752,9 +4752,9 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons
static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0,
const float p_delta, const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
GGML_ASSERT(nrows % 2 == 0);
const dim3 block_dims(2*CUDA_ROPE_BLOCK_SIZE, 1, 1);
const dim3 block_dims(1, 2*CUDA_ROPE_BLOCK_SIZE, 1);
const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
const dim3 block_nums(num_blocks_x, nrows, 1);
const dim3 block_nums(nrows, num_blocks_x, 1);
rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale);
}

@@ -4767,15 +4767,15 @@ static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, con
}

static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, cudaStream_t stream) {
const dim3 block_dims(CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1, 1);
const dim3 block_dims(1, CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1);
const int block_num_x = (ncols_x + CUDA_DIAG_MASK_INF_BLOCK_SIZE - 1) / CUDA_DIAG_MASK_INF_BLOCK_SIZE;
const dim3 block_nums(block_num_x, nrows_x, 1);
const dim3 block_nums(nrows_x, block_num_x, 1);
diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
}

static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, cudaStream_t stream) {
const dim3 block_dims(WARP_SIZE, 1, 1);
const dim3 block_nums(1, nrows_x, 1);
const dim3 block_dims(1, WARP_SIZE, 1);
const dim3 block_nums(nrows_x, 1, 1);
soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
}
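
A note on the launch-configuration changes above (my reading, not something stated in the patch): in each of these kernels the row index moves from the grid's y-dimension to its x-dimension, and the launchers swap block_dims/block_nums to match. CUDA limits gridDim.y and gridDim.z to 65535 blocks while gridDim.x may be up to 2^31 - 1, so putting the row count on x lets these kernels handle more than 65535 rows, which becomes reachable once the 512-token batch cap is removed. A minimal sketch of the pattern with hypothetical names (scale_rows_f32 and BLOCK_COLS are placeholders, not code from this PR):

// Sketch only: rows indexed on grid x (large limit), columns tiled along grid y.
#define BLOCK_COLS 256

static __global__ void scale_rows_f32(const float * x, float * dst, const int ncols) {
    const int row = blockDim.x*blockIdx.x + threadIdx.x; // blockDim.x == 1, so row == blockIdx.x
    const int col = blockDim.y*blockIdx.y + threadIdx.y;
    if (col >= ncols) {
        return;
    }
    dst[row*ncols + col] = 2.0f*x[row*ncols + col];
}

static void scale_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    const dim3 block_dims(1, BLOCK_COLS, 1);
    const int num_blocks_y = (ncols + BLOCK_COLS - 1) / BLOCK_COLS;
    const dim3 block_nums(nrows, num_blocks_y, 1); // nrows may now exceed 65535
    scale_rows_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
}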

@@ -6240,7 +6240,7 @@ static struct ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() {
return extra;
}

void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace) {
void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace, bool no_alloc) {
if (scratch && g_scratch_size == 0) {
return;
}
@@ -6249,14 +6249,19 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bo
if (tensor->src[0] != nullptr && tensor->src[0]->backend == GGML_BACKEND_CPU) {
const ggml_op src0_op = tensor->src[0]->op;
if (src0_op == GGML_OP_RESHAPE || src0_op == GGML_OP_TRANSPOSE || src0_op == GGML_OP_VIEW || src0_op == GGML_OP_PERMUTE) {
ggml_cuda_assign_buffers_impl(tensor->src[0], scratch, force_inplace);
ggml_cuda_assign_buffers_impl(tensor->src[0], scratch, force_inplace, no_alloc);
}
}
if (tensor->op == GGML_OP_CPY && tensor->src[1]->backend == GGML_BACKEND_CPU) {
ggml_cuda_assign_buffers_impl(tensor->src[1], scratch, force_inplace);
ggml_cuda_assign_buffers_impl(tensor->src[1], scratch, force_inplace, no_alloc);
}

tensor->backend = GGML_BACKEND_GPU;

if (scratch && no_alloc) {
return;
}

struct ggml_tensor_extra_gpu * extra;

const bool inplace = (tensor->src[0] != nullptr && tensor->src[0]->data == tensor->data) ||
@@ -6308,16 +6313,48 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bo
tensor->extra = extra;
}

void ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset) {
if (g_scratch_size == 0) {
return;
}
if (g_scratch_buffer == nullptr) {
CUDA_CHECK(cudaMalloc(&g_scratch_buffer, g_scratch_size));
}

struct ggml_tensor_extra_gpu * extra = ggml_cuda_alloc_temp_tensor_extra();

const bool inplace = (tensor->src[0] != nullptr && tensor->src[0]->data == tensor->data) ||
tensor->op == GGML_OP_VIEW;

if (inplace && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT)) {
struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src[0]->extra;
char * src0_ddc = (char *) src0_extra->data_device[g_main_device];
size_t view_offset = 0;
if (tensor->op == GGML_OP_VIEW) {
memcpy(&view_offset, tensor->op_params, sizeof(size_t));
}
extra->data_device[g_main_device] = src0_ddc + view_offset;
} else {
extra->data_device[g_main_device] = (char *) g_scratch_buffer + offset;
}

tensor->extra = extra;
}

void ggml_cuda_assign_buffers(struct ggml_tensor * tensor) {
ggml_cuda_assign_buffers_impl(tensor, true, false);
ggml_cuda_assign_buffers_impl(tensor, true, false, false);
}

void ggml_cuda_assign_buffers_no_alloc(struct ggml_tensor * tensor) {
ggml_cuda_assign_buffers_impl(tensor, true, false, true);
}

void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor) {
ggml_cuda_assign_buffers_impl(tensor, false, false);
ggml_cuda_assign_buffers_impl(tensor, false, false, false);
}

void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor) {
ggml_cuda_assign_buffers_impl(tensor, false, true);
ggml_cuda_assign_buffers_impl(tensor, false, true, false);
}

void ggml_cuda_set_main_device(int main_device) {
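The new ggml_cuda_assign_buffers_no_alloc / ggml_cuda_assign_scratch_offset pair splits buffer assignment into two phases so that an external graph allocator can decide tensor placement: the first call only marks a tensor as a GPU tensor (returning early when scratch && no_alloc, before any scratch allocation), and the second later binds the tensor to a caller-chosen offset inside the global scratch buffer. A rough usage sketch follows; assign_graph_to_cuda and compute_scratch_offset are hypothetical stand-ins for the caller and for whatever allocator (e.g. ggml-alloc) computes the offsets, and are not code from this PR:

// Illustrative sketch only, not part of this PR.
#include "ggml.h"
#include "ggml-cuda.h"

// Hypothetical: whatever graph allocator decides each tensor's place in the scratch buffer.
static size_t compute_scratch_offset(const struct ggml_cgraph * gf, int i);

static void assign_graph_to_cuda(struct ggml_cgraph * gf) {
    // Phase 1: mark every node as a GPU tensor without allocating any memory yet.
    for (int i = 0; i < gf->n_nodes; ++i) {
        ggml_cuda_assign_buffers_no_alloc(gf->nodes[i]);
    }
    // Phase 2: once the allocator has laid out the graph, bind each node to its offset.
    for (int i = 0; i < gf->n_nodes; ++i) {
        ggml_cuda_assign_scratch_offset(gf->nodes[i], compute_scratch_offset(gf, i));
    }
}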
5 changes: 5 additions & 0 deletions ggml-cuda.h
@@ -16,9 +16,14 @@ GGML_API bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const str
GGML_API void ggml_cuda_set_tensor_split(const float * tensor_split);
GGML_API void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor);
GGML_API void ggml_cuda_free_data(struct ggml_tensor * tensor);

GGML_API void ggml_cuda_assign_buffers(struct ggml_tensor * tensor);
GGML_API void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor);
GGML_API void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor);

GGML_API void ggml_cuda_assign_buffers_no_alloc(struct ggml_tensor * tensor);
GGML_API void ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset);

GGML_API void ggml_cuda_set_main_device(int main_device);
GGML_API void ggml_cuda_set_mul_mat_q(bool mul_mat_q);
GGML_API void ggml_cuda_set_scratch_size(size_t scratch_size);