Support CUDA without cuBLAS (#82)
This change introduces tinyBLAS, which lets us make modest use of the GPU when cuBLAS isn't available.
mrdomino authored Dec 12, 2023
1 parent 1f17930 commit 72e1c72
Showing 5 changed files with 392 additions and 19 deletions.
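
For orientation before the diff: the wiring below makes tinyBLAS a drop-in for the four cuBLAS GEMM entry points by #define-aliasing them, and (per the CUBLAS_HANDLE macro added in ggml-cuda.cu) passes a raw cudaStream_t where cuBLAS would take a cublasHandle_t. The shipped tinyblas.cu is not part of this diff, so the following is only a minimal sketch of the idea under those assumptions; the names sketchOperation_t, sketchSgemm, and sgemm_tn are hypothetical stand-ins, and the naive kernel is illustrative, not the actual implementation.

// Sketch only: a naive column-major SGEMM behind a cuBLAS-shaped entry
// point whose "handle" is just a CUDA stream, as the CUBLAS_HANDLE macro
// in the diff suggests. Not the real tinyblas.cu code.
#include <cuda_runtime.h>

enum sketchOperation_t { SKETCH_OP_N, SKETCH_OP_T }; // hypothetical stand-in

// Computes C = alpha * A^T * B + beta * C (the OP_T/OP_N case ggml uses).
// A is k x m, B is k x n, C is m x n, all column-major.
static __global__ void sgemm_tn(int m, int n, int k, float alpha,
                                const float *A, int lda,
                                const float *B, int ldb,
                                float beta, float *C, int ldc) {
    int row = blockIdx.x * blockDim.x + threadIdx.x; // 0..m-1
    int col = blockIdx.y * blockDim.y + threadIdx.y; // 0..n-1
    if (row >= m || col >= n) return;
    float acc = 0.0f;
    for (int i = 0; i < k; ++i)
        acc += A[i + row * lda] * B[i + col * ldb]; // A^T(row,i) * B(i,col)
    C[row + col * ldc] = alpha * acc + beta * C[row + col * ldc];
}

// The "handle" parameter is just the stream, mirroring CUBLAS_HANDLE(id).
static cudaError_t sketchSgemm(cudaStream_t stream,
                               sketchOperation_t transa, sketchOperation_t transb,
                               int m, int n, int k,
                               const float *alpha, const float *A, int lda,
                               const float *B, int ldb,
                               const float *beta, float *C, int ldc) {
    if (transa != SKETCH_OP_T || transb != SKETCH_OP_N)
        return cudaErrorNotSupported; // only the case the diff exercises
    dim3 block(16, 16);
    dim3 grid((m + block.x - 1) / block.x, (n + block.y - 1) / block.y);
    sgemm_tn<<<grid, block, 0, stream>>>(m, n, k, *alpha, A, lda,
                                         B, ldb, *beta, C, ldc);
    return cudaGetLastError();
}
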
50 changes: 39 additions & 11 deletions llama.cpp/ggml-cuda.cu
@@ -10,6 +10,7 @@
#include <assert.h>

#if defined(GGML_USE_HIPBLAS)
+
#include <hip/hip_runtime.h>
#include <hipblas/hipblas.h>
#include <hip/hip_fp16.h>
@@ -75,11 +76,32 @@
#define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
#define cudaStream_t hipStream_t
#define cudaSuccess hipSuccess
+
+#elif defined(GGML_USE_TINYBLAS)
+
+#include "tinyblas.cu"
+#define cublasSgemm tinyblasSgemm
+#define cublasGemmEx tinyblasGemmEx
+#define cublasGemmBatchedEx tinyblasGemmBatchedEx
+#define cublasGemmStridedBatchedEx tinyblasGemmStridedBatchedEx
+
#else
+
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <cuda_fp16.h>
-#endif // defined(GGML_USE_HIPBLAS)
+
+#endif // defined(GGML_USE_HIPBLAS) || defined(GGML_USE_TINYBLAS)
+
+#if defined(GGML_USE_TINYBLAS)
+#define CUBLAS_ENTRY() cudaStream_t _ggml_stream = nullptr
+#define CUBLAS_SET_STREAM(_, stream) do _ggml_stream = (stream); while (0)
+#define CUBLAS_HANDLE(_) _ggml_stream
+#else
+#define CUBLAS_ENTRY()
+#define CUBLAS_SET_STREAM(id, stream) CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], stream))
+#define CUBLAS_HANDLE(id) g_cublas_handles[id]
+#endif

#include "ggml-cuda.h"
#include "ggml.h"
@@ -421,7 +443,7 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
} \
} while (0)

-#if CUDART_VERSION >= 12000
+#if CUDART_VERSION >= 12000 && !defined(GGML_USE_TINYBLAS)
#define CUBLAS_CHECK(err) \
do { \
cublasStatus_t err_ = (err); \
@@ -446,7 +468,7 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
exit(1); \
} \
} while (0)
-#endif // CUDART_VERSION >= 11
+#endif // CUDART_VERSION >= 11 && !defined(GGML_USE_TINYBLAS)

#if CUDART_VERSION >= 11100
#define GGML_CUDA_ASSUME(x) __builtin_assume(x)
@@ -724,7 +746,9 @@ static void * g_scratch_buffer = nullptr;
static size_t g_scratch_size = 0; // disabled by default
static size_t g_scratch_offset = 0;

+#if !defined(GGML_USE_TINYBLAS)
static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
+#endif // !defined(GGML_USE_TINYBLAS)

static __global__ void add_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
@@ -6091,9 +6115,11 @@ void ggml_init_cublas() {
CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams[id][is], cudaStreamNonBlocking));
}

+#if !defined(GGML_USE_TINYBLAS)
// create cublas handle
CUBLAS_CHECK(cublasCreate(&g_cublas_handles[id]));
CUBLAS_CHECK(cublasSetMathMode(g_cublas_handles[id], CUBLAS_TF32_TENSOR_OP_MATH));
+#endif // !defined(GGML_USE_TINYBLAS)
}

// configure logging to stdout
@@ -6662,6 +6688,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
const int64_t src1_padded_row_size, const cudaStream_t & stream) {

+CUBLAS_ENTRY();
GGML_ASSERT(src0_dd_i != nullptr);
GGML_ASSERT(src1_ddf_i != nullptr);
GGML_ASSERT(dst_dd_i != nullptr);
@@ -6712,9 +6739,9 @@ inline void ggml_cuda_op_mul_mat_cublas(
const half alpha_f16 = 1.0f;
const half beta_f16 = 0.0f;

-CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], stream));
+CUBLAS_SET_STREAM(id, stream);
CUBLAS_CHECK(
-cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
+cublasGemmEx(CUBLAS_HANDLE(id), CUBLAS_OP_T, CUBLAS_OP_N,
row_diff, src1_ncols, ne10,
&alpha_f16, src0_ptr, CUDA_R_16F, ne00,
src1_ptr, CUDA_R_16F, ne10,
@@ -6750,9 +6777,9 @@ inline void ggml_cuda_op_mul_mat_cublas(
const float alpha = 1.0f;
const float beta = 0.0f;

-CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], stream));
+CUBLAS_SET_STREAM(id, stream);
CUBLAS_CHECK(
-cublasSgemm(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
+cublasSgemm(CUBLAS_HANDLE(id), CUBLAS_OP_T, CUBLAS_OP_N,
row_diff, src1_ncols, ne10,
&alpha, src0_ddf_i, ne00,
src1_ddf_i, ne10,
@@ -7512,6 +7539,7 @@ __global__ void k_compute_batched_ptrs(
}

static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+CUBLAS_ENTRY();
GGML_ASSERT(!ggml_is_transposed(src0));
GGML_ASSERT(!ggml_is_transposed(src1));

@@ -7545,7 +7573,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const

int id;
CUDA_CHECK(cudaGetDevice(&id));
-CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], main_stream));
+CUBLAS_SET_STREAM(id, main_stream);

ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
void * src0_ddq = src0_extra->data_device[g_main_device];
@@ -7587,7 +7615,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
int i02 = i12 / r2;

CUBLAS_CHECK(
-cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
+cublasGemmEx(CUBLAS_HANDLE(id), CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
&alpha_f16, (const char *) src0_as_f16 + i02*src0->nb[2] + i03*src0->nb[3] , CUDA_R_16F, nb01/sizeof(half),
(const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F, nb11/sizeof(float),
@@ -7602,7 +7630,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
// there is no broadcast and src0, src1 are contiguous across dims 2, 3
// use cublasGemmStridedBatchedEx
CUBLAS_CHECK(
-cublasGemmStridedBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
+cublasGemmStridedBatchedEx(CUBLAS_HANDLE(id), CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
&alpha_f16, (const char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half), src0->nb[2]/sizeof(half), // strideA
(const char *) src1_as_f16, CUDA_R_16F, nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB
@@ -7636,7 +7664,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
CUDA_CHECK(cudaGetLastError());

CUBLAS_CHECK(
-cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
+cublasGemmBatchedEx(CUBLAS_HANDLE(id), CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
&alpha_f16, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
(const void **) (ptrs_src + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
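
A note on the CUBLAS_ENTRY / CUBLAS_SET_STREAM / CUBLAS_HANDLE layer above: it keeps every call site cuBLAS-shaped while swapping what flows through it. Under GGML_USE_TINYBLAS there is no cublasHandle_t at all, so the "handle" that reaches the GEMM entry point is just the stream stashed by CUBLAS_SET_STREAM. A self-contained sketch of that branch follows; the three macros are copied from the hunk above, while gemm_stub is a hypothetical stand-in for a tinyBLAS entry point.

// Demo of the GGML_USE_TINYBLAS macro expansion: the "handle" handed to
// the GEMM entry point is just the CUDA stream.
#include <cuda_runtime.h>
#include <cstdio>

// Macros as added by this commit (GGML_USE_TINYBLAS branch).
#define CUBLAS_ENTRY() cudaStream_t _ggml_stream = nullptr
#define CUBLAS_SET_STREAM(_, stream) do _ggml_stream = (stream); while (0)
#define CUBLAS_HANDLE(_) _ggml_stream

// Hypothetical stub standing in for a tinyBLAS entry point.
static void gemm_stub(cudaStream_t handle) {
    printf("GEMM would launch on stream %p\n", (void *) handle);
}

int main() {
    cudaStream_t s;
    if (cudaStreamCreate(&s) != cudaSuccess) return 1;
    CUBLAS_ENTRY();          // declares the local _ggml_stream
    CUBLAS_SET_STREAM(0, s); // the device id argument is ignored here
    gemm_stub(CUBLAS_HANDLE(0));
    cudaStreamDestroy(s);
    return 0;
}
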
4 changes: 2 additions & 2 deletions llama.cpp/ggml.c
@@ -16183,7 +16183,7 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
size_t cur = 0;
const enum ggml_type vec_dot_type = type_traits[node->src[0]->type].vec_dot_type;

-#if defined(GGML_USE_CUBLAS)
+#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_TINYBLAS)
if (ggml_cuda_can_mul_mat(node->src[0], node->src[1], node)) {
n_tasks = 1; // TODO: this actually is doing nothing
// the threads are still spinning
@@ -19607,7 +19607,7 @@ int ggml_cpu_has_wasm_simd(void) {
}

int ggml_cpu_has_blas(void) {
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_TINYBLAS)
return 1;
#else
return 0;
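
The ggml.c side is what makes the new backend visible to the rest of the codebase: with GGML_USE_TINYBLAS added to both conditionals, a tinyBLAS build schedules mul_mat through the CUDA path in graph planning and reports BLAS support through the existing probe, so callers need no new API. A hedged usage sketch, assuming only the ggml_cpu_has_blas declaration shown in the diff:

#include <stdio.h>
#include "ggml.h" // declares int ggml_cpu_has_blas(void);

int main(void) {
    // After this commit, a GGML_USE_TINYBLAS build answers 1 here,
    // just like cuBLAS, OpenBLAS, Accelerate, or CLBlast builds.
    printf("blas: %d\n", ggml_cpu_has_blas());
    return 0;
}
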