GPU-accelerated token generation (new quantization format) #1412

Merged: 9 commits, May 13, 2023
8 changes: 8 additions & 0 deletions examples/common.cpp
@@ -277,6 +277,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
params.use_color = true;
} else if (arg == "--mlock") {
params.use_mlock = true;
} else if (arg == "--gpu-layers") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.gpu_layers = std::stoi(argv[i]);
} else if (arg == "--no-mmap") {
params.use_mmap = false;
} else if (arg == "--mtest") {
@@ -421,6 +427,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
if (llama_mmap_supported()) {
fprintf(stderr, " --no-mmap do not memory-map model (slower load but may reduce pageouts if not using mlock)\n");
}
fprintf(stderr, " --gpu-layers number of layers to store in VRAM\n");
fprintf(stderr, " --mtest compute maximum memory usage\n");
fprintf(stderr, " --verbose-prompt print prompt before generation\n");
fprintf(stderr, " --lora FNAME apply LoRA adapter (implies --no-mmap)\n");
@@ -469,6 +476,7 @@ struct llama_context * llama_init_from_gpt_params(const gpt_params & params) {
lparams.f16_kv = params.memory_f16;
lparams.use_mmap = params.use_mmap;
lparams.use_mlock = params.use_mlock;
lparams.gpu_layers = params.gpu_layers;
lparams.logits_all = params.perplexity;
lparams.embedding = params.embedding;

1 change: 1 addition & 0 deletions examples/common.h
@@ -69,6 +69,7 @@ struct gpt_params {
bool perplexity = false; // compute perplexity over the prompt
bool use_mmap = true; // use mmap for faster loads
bool use_mlock = false; // use mlock to keep model in memory
int gpu_layers = 0; // number of layers to store in VRAM
bool mem_test = false; // compute maximum memory usage
bool verbose_prompt = false; // print prompt tokens before generation
};
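Taken together, the common.cpp and common.h changes just add one parsed integer; it only takes effect once llama_init_from_gpt_params copies it into llama_context_params. A minimal, hypothetical caller (not part of this diff; build details omitted) that exercises the new flag could look like this:

// Hypothetical usage sketch, assuming the examples/common.h API shown above.
#include "common.h"
#include "llama.h"

int main(int argc, char ** argv) {
    gpt_params params;
    if (!gpt_params_parse(argc, argv, params)) {   // accepts e.g. "--gpu-layers 32"
        return 1;
    }
    // llama_init_from_gpt_params copies params.gpu_layers into lparams.gpu_layers,
    // which controls how many layers the loader keeps in VRAM.
    llama_context * ctx = llama_init_from_gpt_params(params);
    if (ctx == nullptr) {
        return 1;
    }
    llama_free(ctx);
    return 0;
}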
185 changes: 162 additions & 23 deletions ggml-cuda.cu
@@ -32,7 +32,9 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
} \
} while (0)

typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, float & v0, float & v1);
typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
typedef void (*dequantize_mul_mat_vec_cuda_t)(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream);

#define QK4_0 32
typedef struct {
@@ -73,6 +75,37 @@ typedef struct {
} block_q8_0;
static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");

#define CUDA_DMMV_BLOCK_SIZE 32

static __device__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
const block_q4_0 * x = (const block_q4_0 *) vx;

const float d = x[ib].d;

const uint8_t vui = x[ib].qs[iqs];

const int8_t vi0 = vui & 0xF;
const int8_t vi1 = vui >> 4;

v0 = (vi0 - 8)*d;
v1 = (vi1 - 8)*d;
}

static __device__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, float & v0, float & v1){
const block_q4_1 * x = (const block_q4_1 *) vx;

const float d = x[ib].d;
const float m = x[ib].m;

const uint8_t vui = x[ib].qs[iqs];

const int8_t vi0 = vui & 0xF;
const int8_t vi1 = vui >> 4;

v0 = vi0*d + m;
v1 = vi1*d + m;
}
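
For readers new to the quantization format, each block_q4_0 stores one fp32 scale d plus QK4_0/2 bytes, where byte j packs element j in its low nibble and element j + QK4_0/2 in its high nibble. A plain CPU-side sketch of a whole-block dequantization (illustrative only; it mirrors the per-pair device function above) is:

// CPU-side reference sketch (not part of this PR): dequantize one q4_0 block
// into QK4_0 floats, using the same layout the device functions above assume.
static void dequantize_block_q4_0_ref(const block_q4_0 * b, float * out) {
    const float d = b->d;                   // per-block scale
    for (int j = 0; j < QK4_0/2; ++j) {
        const uint8_t vui = b->qs[j];       // two 4-bit quants per byte
        const int8_t vi0 = vui & 0xF;       // low nibble  -> element j
        const int8_t vi1 = vui >> 4;        // high nibble -> element j + QK4_0/2
        out[j]           = (vi0 - 8)*d;     // quants are stored with an offset of 8
        out[j + QK4_0/2] = (vi1 - 8)*d;
    }
}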

static __global__ void dequantize_block_q4_0(const void * vx, float * y) {
static const int qk = QK4_0;

@@ -173,6 +206,41 @@ static __global__ void dequantize_block_q8_0(const void * vx, float * y) {
}
}

template <int block_size, int qk, dequantize_kernel_t dequantize_kernel> static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst, const int ncols) {
Review comment (collaborator): Could the template<...> be on a separate line? It's getting very long otherwise.

const int row = blockIdx.x;
const int tid = threadIdx.x;

__shared__ float tmp[block_size]; // separate sum for each thread
tmp[tid] = 0;

for (int i = 0; i < ncols/block_size; i += 2) {
const int col = i*block_size + 2*tid;
const int ib = (row*ncols + col)/qk; // block index
const int iqs = (col%qk)/2; // quant index
const int iybs = col - col%qk; // y block start index

// dequantize
float v0, v1;
dequantize_kernel(vx, ib, iqs, v0, v1);

// matrix multiplication
tmp[tid] += v0 * y[iybs + iqs + 0];
tmp[tid] += v1 * y[iybs + iqs + qk/2];
}

// sum up partial sums and write back result
__syncthreads();
for (int s=block_size/2; s>0; s>>=1) {
if (tid < s) {
tmp[tid] += tmp[tid + s];
}
__syncthreads();
}
if (tid == 0) {
dst[row] = tmp[0];
}
}
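
The kernel above assigns one thread block per output row: each thread accumulates a strided slice of the row's dot product into shared memory, and the final loop is a standard shared-memory tree reduction. A stripped-down version of just that reduction pattern (illustrative only; assumes block_size is a power of two, as CUDA_DMMV_BLOCK_SIZE is) is:

// Reduction sketch only (not part of this PR): sum block_size values per block.
template <int block_size>
static __global__ void block_sum(const float * x, float * out) {
    __shared__ float tmp[block_size];
    const int tid = threadIdx.x;
    tmp[tid] = x[blockIdx.x*block_size + tid];
    __syncthreads();
    // halve the number of active threads each step; the barrier after each step
    // makes the partial sums visible to every thread before the next step reads them
    for (int s = block_size/2; s > 0; s >>= 1) {
        if (tid < s) {
            tmp[tid] += tmp[tid + s];
        }
        __syncthreads();
    }
    if (tid == 0) {
        out[blockIdx.x] = tmp[0];   // one result per block, like dst[row] above
    }
}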

static void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
const int nb = k / QK4_0;
dequantize_block_q4_0<<<nb, 1, 0, stream>>>(vx, y);
@@ -198,6 +266,16 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStre
dequantize_block_q8_0<<<nb, 1, 0, stream>>>(vx, y);
}

static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK4_0, dequantize_q4_0><<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
}

static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK4_1, dequantize_q4_1><<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
}
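
These wrappers launch one thread block of CUDA_DMMV_BLOCK_SIZE threads per matrix row. A self-contained host-side sketch of how the q4_0 wrapper could be driven outside of ggml (hypothetical test harness, error checking omitted, default stream only) is:

// Hypothetical standalone driver (not part of this PR) for the q4_0 fused kernel.
static void dmmv_q4_0_host_test(const block_q4_0 * h_x, const float * h_y, float * h_dst,
                                const int ncols, const int nrows) {
    void * d_x; float * d_y; float * d_dst;
    const size_t x_bytes = (size_t) nrows * (ncols/QK4_0) * sizeof(block_q4_0);
    cudaMalloc(&d_x,   x_bytes);
    cudaMalloc(&d_y,   ncols * sizeof(float));
    cudaMalloc(&d_dst, nrows * sizeof(float));
    cudaMemcpy(d_x, h_x, x_bytes,               cudaMemcpyHostToDevice);
    cudaMemcpy(d_y, h_y, ncols * sizeof(float), cudaMemcpyHostToDevice);
    // dst[row] = dot(dequantized row of x, y); the default stream keeps copies and kernel ordered
    dequantize_mul_mat_vec_q4_0_cuda(d_x, d_y, d_dst, ncols, nrows, 0);
    cudaMemcpy(h_dst, d_dst, nrows * sizeof(float), cudaMemcpyDeviceToHost);
    cudaFree(d_x); cudaFree(d_y); cudaFree(d_dst);
}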

// TODO: optimize
static __global__ void convert_fp16_to_fp32(const void * vx, float * y) {
const half * x = (const half *) vx;
@@ -230,8 +308,19 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
}
}

static dequantize_mul_mat_vec_cuda_t ggml_get_dequantize_mul_mat_vec_cuda(ggml_type type) {
switch (type) {
case GGML_TYPE_Q4_0:
return dequantize_mul_mat_vec_q4_0_cuda;
case GGML_TYPE_Q4_1:
return dequantize_mul_mat_vec_q4_1_cuda;
default:
return nullptr;
}
}

// buffer pool for cuda
#define MAX_CUDA_BUFFERS 16
#define MAX_CUDA_BUFFERS 256

struct scoped_spin_lock {
std::atomic_flag& lock;
@@ -538,12 +627,16 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
const size_t q_sz = ggml_type_size(type) * x_ne / ggml_blck_size(type);

size_t x_size, y_size, d_size, q_size;
float * d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size);
float * d_X;
if (ne11 > 1) {
d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size);
}
float * d_Y = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * y_ne, &y_size);
float * d_D = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * d_ne, &d_size);
char * d_Q = (char *) ggml_cuda_pool_malloc(n_mm * q_sz, &q_size);

const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(type);
dequantize_mul_mat_vec_cuda_t dmmv = ggml_get_dequantize_mul_mat_vec_cuda(type);
GGML_ASSERT(to_fp32_cuda != nullptr);

for (int64_t i03 = 0; i03 < ne03; i03++) {
@@ -553,31 +646,54 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
cudaStream_t cudaStream2 = g_cudaStreams2[i % GGML_CUDA_MAX_STREAMS];
cudaEvent_t cudaEvent = g_cudaEvents[i % GGML_CUDA_MAX_EVENTS];

float * c_X = d_X + i * x_ne;
float * c_Y = d_Y + i * y_ne;
float * c_D = d_D + i * d_ne;
char * c_Q = d_Q + i * q_sz;

// copy src0 and convert to fp32 on device
CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Q, src0, i03, i02, cudaStream2));
to_fp32_cuda(c_Q, c_X, x_ne, cudaStream2);
CUDA_CHECK(cudaGetLastError());
CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
// copy src0 to device if necessary
if (src0->backend == GGML_BACKEND_CPU) {
CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Q, src0, i03, i02, cudaStream2));
} else if (src0->backend == GGML_BACKEND_CUDA) {
c_Q = ((char *) src0->data) + i * q_sz;
} else {
GGML_ASSERT(false);
}
if (ne11 == 1) {
CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));

// copy src1 to device
CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));
// copy src1 to device
CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));

// wait for conversion
CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
// wait for data
CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));

// compute
CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream));
CUBLAS_CHECK(
cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
&alpha, c_X, ne00,
c_Y, ne10,
&beta, c_D, ne01));
// compute
dmmv(c_Q, c_Y, c_D, ne00, ne01, cudaStream);
CUDA_CHECK(cudaGetLastError());

} else {
float * c_X = d_X + i * x_ne;

// convert src0 to fp32 on device
to_fp32_cuda(c_Q, c_X, x_ne, cudaStream2);
CUDA_CHECK(cudaGetLastError());
CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));

// copy src1 to device
CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));

// wait for conversion
CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));

// compute
CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream));
CUBLAS_CHECK(
cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
&alpha, c_X, ne00,
c_Y, ne10,
&beta, c_D, ne01));
}

// copy dst to host
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
@@ -586,7 +702,9 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
}

CUDA_CHECK(cudaDeviceSynchronize());
ggml_cuda_pool_free(d_X, x_size);
if (ne11 > 1) {
ggml_cuda_pool_free(d_X, x_size);
}
ggml_cuda_pool_free(d_Y, y_size);
ggml_cuda_pool_free(d_D, d_size);
ggml_cuda_pool_free(d_Q, q_size);
@@ -602,8 +720,7 @@ bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_te
if ((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
src1->type == GGML_TYPE_F32 &&
dst->type == GGML_TYPE_F32 &&
(ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) {

((ne0 >= 32 && ne1 >= 32 && ne10 >= 32) || src0->backend == GGML_BACKEND_CUDA)) {
return true;
}

@@ -655,3 +772,25 @@ size_t ggml_cuda_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct
return 0;
}
}

void ggml_cuda_transform_tensor(ggml_tensor * tensor) {
const int64_t ne0 = tensor->ne[0];
const int64_t ne1 = tensor->ne[1];
const int64_t ne2 = tensor->ne[2];
const int64_t ne3 = tensor->ne[3];

const ggml_type type = tensor->type;
const size_t q_sz = ggml_type_size(type) * ne0 * ne1 * ne2 * ne3 / ggml_blck_size(type);

size_t q_size;
char * d_Q = (char *) ggml_cuda_pool_malloc(q_sz, &q_size);

cudaStream_t cudaStream2 = g_cudaStreams2[0];

// copy tensor to device
CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Q, tensor, 0, 0, cudaStream2));
CUDA_CHECK(cudaDeviceSynchronize());

tensor->data = d_Q;
tensor->backend = GGML_BACKEND_CUDA;
}
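
ggml_cuda_transform_tensor uploads a tensor's data once and retags it as GGML_BACKEND_CUDA, which is what lets ggml_cuda_can_mul_mat accept the small ne11 == 1 multiplications used during token generation. A hedged sketch of how a model loader might use it to honor --gpu-layers (the llama.cpp side is not part of this excerpt) is:

// Hypothetical caller sketch (not shown in this diff): keep the first gpu_layers
// layers' weight tensors in VRAM by retagging them as GGML_BACKEND_CUDA.
static void offload_layers(struct ggml_tensor ** layer_weights, int n_layers, int gpu_layers) {
    for (int il = 0; il < n_layers && il < gpu_layers; ++il) {
        // after this call layer_weights[il]->data points at device memory
        ggml_cuda_transform_tensor(layer_weights[il]);
    }
}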
2 changes: 2 additions & 0 deletions ggml-cuda.h
@@ -14,6 +14,8 @@ void ggml_cuda_mul_mat(const struct ggml_tensor * src0, const struct ggml_tens
void * ggml_cuda_host_malloc(size_t size);
void ggml_cuda_host_free(void * ptr);

void ggml_cuda_transform_tensor(struct ggml_tensor * tensor);

#ifdef __cplusplus
}
#endif
1 change: 1 addition & 0 deletions ggml.c
@@ -3702,6 +3702,7 @@ struct ggml_tensor * ggml_new_tensor_impl(

*result = (struct ggml_tensor) {
/*.type =*/ type,
/*.backend =*/ GGML_BACKEND_CPU,
/*.n_dims =*/ n_dims,
/*.ne =*/ { 1, 1, 1, 1 },
/*.nb =*/ { 0, 0, 0, 0 },
8 changes: 7 additions & 1 deletion ggml.h
@@ -243,6 +243,11 @@ extern "C" {
GGML_TYPE_COUNT,
};

enum ggml_backend {
GGML_BACKEND_CPU = 0,
GGML_BACKEND_CUDA = 1,
};
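
Every tensor now carries a backend tag (see the struct ggml_tensor change below), so code that touches tensor->data can branch on where the data lives. Illustrative only:

// Illustrative sketch (not part of this PR): dispatch on the new backend field.
switch (tensor->backend) {
    case GGML_BACKEND_CPU:  /* tensor->data is host memory   */ break;
    case GGML_BACKEND_CUDA: /* tensor->data is device memory */ break;
}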

// model file types
enum ggml_ftype {
GGML_FTYPE_UNKNOWN = -1,
@@ -322,6 +327,7 @@ extern "C" {
// n-dimensional tensor
struct ggml_tensor {
enum ggml_type type;
enum ggml_backend backend;

int n_dims;
int64_t ne[GGML_MAX_DIMS]; // number of elements
@@ -352,7 +358,7 @@ extern "C" {

char name[32];

char padding[8]; // TODO: remove and add padding to name?
char padding[9]; // TODO: remove and add padding to name?
};

// computation graph