
Commit b1ee8f5

cuBLAS: non-contiguous tensor support (#1215)
* Cuda: non-contiguous tensor support
* remove extra stuff
* rename
* fix error
* more fixes, now OpenBLAS and CLBlast build too
* now then?
1 parent 36d19a6 commit b1ee8f5

File tree

3 files changed (+44 -11 lines changed):

  ggml-cuda.cu
  ggml-cuda.h
  ggml.c

Diff for: ggml-cuda.cu (+28)

@@ -302,3 +302,31 @@ void ggml_init_cublas(void) {
         // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL));
     }
 }
+
+cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, cudaStream_t stream) {
+    const uint64_t ne0 = src->ne[0];
+    const uint64_t ne1 = src->ne[1];
+    const uint64_t nb0 = src->nb[0];
+    const uint64_t nb1 = src->nb[1];
+    const uint64_t nb2 = src->nb[2];
+    const uint64_t nb3 = src->nb[3];
+    const enum ggml_type type = src->type;
+    const size_t ts = ggml_type_size(type);
+    const size_t bs = ggml_blck_size(type);
+
+    const void * x = (const void *) ((const char *) src->data + i2*nb2 + i3*nb3);
+    if (nb0 == ts && nb1 == ts*ne0/bs) {
+        return cudaMemcpyAsync(dst, x, ne1*nb1, cudaMemcpyHostToDevice, stream);
+    } else if (nb0 == ts) {
+        return cudaMemcpy2DAsync(dst, ts*ne0/bs, x, nb1, ts*ne0/bs, ne1, cudaMemcpyHostToDevice, stream);
+    } else {
+        for (uint64_t i1 = 0; i1 < ne1; i1++) {
+            const void * rx = (const void *) ((const char *) x + i1*nb1);
+            void * rd = (void *) ((char *) dst + i1*ts*ne0/bs);
+            // pretend the row is a matrix with cols=1
+            cudaError_t r = cudaMemcpy2DAsync(rd, ts/bs, rx, nb0, ts/bs, ne0, cudaMemcpyHostToDevice, stream);
+            if (r != cudaSuccess) return r;
+        }
+        return cudaSuccess;
+    }
+}
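
The helper selects one of three host-to-device copy strategies from the tensor's byte strides: a single bulk cudaMemcpyAsync when the i2/i3 slice is fully contiguous, one strided cudaMemcpy2DAsync when each row is contiguous but rows are separated by a pitch nb1, and a per-row cudaMemcpy2DAsync that treats the row as a one-column matrix to gather element by element when even nb0 differs from the type size. A minimal host-only sketch of the same branching, with plain memcpy standing in for the CUDA calls, block size fixed at 1, and the function name invented for illustration:

#include <stdint.h>
#include <string.h>

// Copy one ne0 x ne1 slice of a strided source into a dense destination.
// nb0/nb1 are byte strides between consecutive elements/rows; ts is the
// element size in bytes (block size assumed 1 for this sketch).
static void h2h_tensor_2d_sketch(char * dst, const char * src,
                                 uint64_t ne0, uint64_t ne1,
                                 uint64_t nb0, uint64_t nb1, size_t ts) {
    if (nb0 == ts && nb1 == ts*ne0) {
        // case 1: fully contiguous, one bulk copy (cudaMemcpyAsync upstream)
        memcpy(dst, src, ne1*nb1);
    } else if (nb0 == ts) {
        // case 2: contiguous rows with a pitch, copy row by row
        // (a single cudaMemcpy2DAsync upstream)
        for (uint64_t i1 = 0; i1 < ne1; i1++) {
            memcpy(dst + i1*ts*ne0, src + i1*nb1, ts*ne0);
        }
    } else {
        // case 3: fully strided, gather element by element
        // (one cudaMemcpy2DAsync per row upstream, rows as cols=1 matrices)
        for (uint64_t i1 = 0; i1 < ne1; i1++) {
            for (uint64_t i0 = 0; i0 < ne0; i0++) {
                memcpy(dst + (i1*ne0 + i0)*ts, src + i1*nb1 + i0*nb0, ts);
            }
        }
    }
}

Only the first two cases hit the fast paths; the third still issues ne1 separate async copies upstream, so fully strided operands remain comparatively expensive.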

Diff for: ggml-cuda.h (+3)

@@ -1,5 +1,6 @@
 #include <cublas_v2.h>
 #include <cuda_runtime.h>
+#include "ggml.h"
 
 #ifdef __cplusplus
 extern "C" {
@@ -38,6 +39,8 @@ void dequantize_row_q5_0_cuda(const void * vx, float * y, int k, cudaStream_t stream);
 void dequantize_row_q5_1_cuda(const void * vx, float * y, int k, cudaStream_t stream);
 void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t stream);
 
+cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, cudaStream_t stream);
+
 #ifdef __cplusplus
 }
 #endif

Diff for: ggml.c (+13 -11)

@@ -7930,8 +7930,12 @@ static bool ggml_compute_forward_mul_mat_use_blas(
     const int64_t ne1 = dst->ne[1];
 
     // TODO: find the optimal values for these
-    if (ggml_is_contiguous(src0) &&
-        ggml_is_contiguous(src1) && ((ne0 >= 32 && ne1 >= 32 && ne10 >= 32))) {
+    if (
+#if !defined(GGML_USE_CUBLAS)
+        ggml_is_contiguous(src0) &&
+        ggml_is_contiguous(src1) &&
+#endif
+        ((ne0 >= 32 && ne1 >= 32 && ne10 >= 32))) {
 
         /*printf("BLAS: %d %d %d %d %d\n", ne0, ne1, ne10, ne00, ne01);*/
         return true;
@@ -8041,15 +8045,16 @@ static void ggml_compute_forward_mul_mat_f32(
 
     for (int64_t i03 = 0; i03 < ne03; i03++) {
         for (int64_t i02 = 0; i02 < ne02; i02++) {
+#if !defined(GGML_USE_CUBLAS)
             const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
             const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
-
+#endif
             float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
 
 #if defined(GGML_USE_CUBLAS)
             // copy data to device
-            CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, g_cudaStream));
-            CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, g_cudaStream));
+            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_X, src0, i03, i02, g_cudaStream));
+            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Y, src1, i03, i02, g_cudaStream));
 
             // compute
             CUBLAS_CHECK(
@@ -8269,13 +8274,12 @@ static void ggml_compute_forward_mul_mat_f16_f32(
 #endif
 
 #if defined(GGML_USE_CUBLAS)
-            const ggml_fp16_t * x = (ggml_fp16_t *) ((char *) src0->data + i02*nb02 + i03*nb03);
             const ggml_fp16_t * y = (ggml_fp16_t *) wdata;
 
             float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
 
             // copy data to device
-            CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(ggml_fp16_t) * x_ne, cudaMemcpyHostToDevice, g_cudaStream));
+            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_X, src0, i03, i02, g_cudaStream));
             CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(ggml_fp16_t) * y_ne, cudaMemcpyHostToDevice, g_cudaStream));
 
             // compute
@@ -8539,9 +8543,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
 
 #if defined(GGML_USE_CUBLAS)
             // copy and dequantize on device
-            CUDA_CHECK(
-                cudaMemcpyAsync(d_Q, (char *) src0->data + i03*nb03 + i02*nb02,
-                    GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], cudaMemcpyHostToDevice, g_cudaStream));
+            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Q, src0, i03, i02, g_cudaStream));
 
             dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, g_cudaStream);
             CUDA_CHECK(cudaGetLastError());
@@ -8561,7 +8563,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
 
 #if defined(GGML_USE_CUBLAS)
             // copy data to device
-            CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, g_cudaStream));
+            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Y, src1, i03, i02, g_cudaStream));
 
             // compute
             CUBLAS_CHECK(
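
Compiling the contiguity checks out of ggml_compute_forward_mul_mat_use_blas means non-contiguous views (for example transposes, which only permute the nb strides) can now reach the cuBLAS matmul path without a host-side ggml_cpy to a dense buffer first. A hypothetical caller, written against the ggml API of this era (ggml_build_forward / ggml_graph_compute; exact signatures vary across versions):

#include "ggml.h"

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 32);
    struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 32, 64);

    // ggml_transpose swaps ne[0]/ne[1] and their strides but leaves the
    // data untouched, so bt is a non-contiguous view
    struct ggml_tensor * bt = ggml_transpose(ctx, b);

    // previously the BLAS fast path required both operands contiguous;
    // with GGML_USE_CUBLAS the strided-copy helper handles bt directly
    struct ggml_tensor * c = ggml_mul_mat(ctx, a, bt);

    struct ggml_cgraph gf = ggml_build_forward(c);
    ggml_graph_compute(ctx, &gf);

    ggml_free(ctx);
    return 0;
}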
