Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CUDA: enable peer access between devices #2470

Merged
merged 1 commit into from
Sep 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ set(LLAMA_CUDA_DMMV_X "32" CACHE STRING "llama: x stride for dmmv CUDA kern
set(LLAMA_CUDA_MMV_Y "1" CACHE STRING "llama: y block size for mmv CUDA kernels")
option(LLAMA_CUDA_F16 "llama: use 16 bit floats for some calculations" OFF)
set(LLAMA_CUDA_KQUANTS_ITER "2" CACHE STRING "llama: iters./thread per block for Q2_K/Q6_K")
set(LLAMA_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING
"llama: max. batch size for using peer access")
option(LLAMA_HIPBLAS "llama: use hipBLAS" OFF)
option(LLAMA_CLBLAST "llama: use CLBlast" OFF)
option(LLAMA_METAL "llama: use Metal" ${LLAMA_METAL_DEFAULT})
Expand Down Expand Up @@ -304,6 +306,7 @@ if (LLAMA_CUBLAS)
add_compile_definitions(GGML_CUDA_F16)
endif()
add_compile_definitions(K_QUANTS_PER_ITERATION=${LLAMA_CUDA_KQUANTS_ITER})
add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${LLAMA_CUDA_PEER_MAX_BATCH_SIZE})

if (LLAMA_STATIC)
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
Expand Down
5 changes: 5 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,11 @@ ifdef LLAMA_CUDA_KQUANTS_ITER
else
NVCCFLAGS += -DK_QUANTS_PER_ITERATION=2
endif
ifdef LLAMA_CUDA_PEER_MAX_BATCH_SIZE
NVCCFLAGS += -DGGML_CUDA_PEER_MAX_BATCH_SIZE=$(LLAMA_CUDA_PEER_MAX_BATCH_SIZE)
else
NVCCFLAGS += -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128
endif # LLAMA_CUDA_PEER_MAX_BATCH_SIZE
#ifdef LLAMA_CUDA_CUBLAS
# NVCCFLAGS += -DGGML_CUDA_CUBLAS
#endif # LLAMA_CUDA_CUBLAS
Expand Down
15 changes: 8 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -391,13 +391,14 @@ Building the program with BLAS support may lead to some performance improvements
<!---
| LLAMA_CUDA_CUBLAS | Boolean | false | Use cuBLAS instead of custom CUDA kernels for prompt processing. Faster for all quantization formats except for q4_0 and q8_0, especially for k-quants. Increases VRAM usage (700 MiB for 7b, 970 MiB for 13b, 1430 MiB for 33b). |
--->
| Option | Legal values | Default | Description |
|-------------------------|------------------------|---------|-------------|
| LLAMA_CUDA_FORCE_DMMV | Boolean | false | Force the use of dequantization + matrix vector multiplication kernels instead of using kernels that do matrix vector multiplication on quantized data. By default the decision is made based on compute capability (MMVQ for 6.1/Pascal/GTX 1000 or higher). Does not affect k-quants. |
| LLAMA_CUDA_DMMV_X | Positive integer >= 32 | 32 | Number of values in x direction processed by the CUDA dequantization + matrix vector multiplication kernel per iteration. Increasing this value can improve performance on fast GPUs. Power of 2 heavily recommended. Does not affect k-quants. |
| LLAMA_CUDA_MMV_Y | Positive integer | 1 | Block size in y direction for the CUDA mul mat vec kernels. Increasing this value can improve performance on fast GPUs. Power of 2 recommended. |
| LLAMA_CUDA_F16 | Boolean | false | If enabled, use half-precision floating point arithmetic for the CUDA dequantization + mul mat vec kernels and for the q4_1 and q5_1 matrix matrix multiplication kernels. Can improve performance on relatively recent GPUs. |
| LLAMA_CUDA_KQUANTS_ITER | 1 or 2 | 2 | Number of values processed per iteration and per CUDA thread for Q2_K and Q6_K quantization formats. Setting this value to 1 can improve performance for slow GPUs. |
| Option | Legal values | Default | Description |
|--------------------------------|------------------------|---------|-------------|
| LLAMA_CUDA_FORCE_DMMV | Boolean | false | Force the use of dequantization + matrix vector multiplication kernels instead of using kernels that do matrix vector multiplication on quantized data. By default the decision is made based on compute capability (MMVQ for 6.1/Pascal/GTX 1000 or higher). Does not affect k-quants. |
| LLAMA_CUDA_DMMV_X | Positive integer >= 32 | 32 | Number of values in x direction processed by the CUDA dequantization + matrix vector multiplication kernel per iteration. Increasing this value can improve performance on fast GPUs. Power of 2 heavily recommended. Does not affect k-quants. |
| LLAMA_CUDA_MMV_Y | Positive integer | 1 | Block size in y direction for the CUDA mul mat vec kernels. Increasing this value can improve performance on fast GPUs. Power of 2 recommended. |
| LLAMA_CUDA_F16 | Boolean | false | If enabled, use half-precision floating point arithmetic for the CUDA dequantization + mul mat vec kernels and for the q4_1 and q5_1 matrix matrix multiplication kernels. Can improve performance on relatively recent GPUs. |
| LLAMA_CUDA_KQUANTS_ITER | 1 or 2 | 2 | Number of values processed per iteration and per CUDA thread for Q2_K and Q6_K quantization formats. Setting this value to 1 can improve performance for slow GPUs. |
| LLAMA_CUDA_PEER_MAX_BATCH_SIZE | Positive integer | 128 | Maximum batch size for which to enable peer access between multiple GPUs. Peer access requires either Linux or NVLink. When using NVLink enabling peer access for larger batch sizes is potentially beneficial. |

- #### hipBLAS

Expand Down
50 changes: 47 additions & 3 deletions ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@
#define cublasSetStream hipblasSetStream
#define cublasSgemm hipblasSgemm
#define cublasStatus_t hipblasStatus_t
#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
#define cudaDeviceProp hipDeviceProp_t
#define cudaDeviceSynchronize hipDeviceSynchronize
#define cudaError_t hipError_t
Expand Down Expand Up @@ -424,6 +427,10 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
#endif

#ifndef GGML_CUDA_PEER_MAX_BATCH_SIZE
#define GGML_CUDA_PEER_MAX_BATCH_SIZE 128
#endif // GGML_CUDA_PEER_MAX_BATCH_SIZE

#define MUL_MAT_SRC1_COL_STRIDE 128

#define MAX_STREAMS 8
Expand Down Expand Up @@ -6258,6 +6265,41 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s
}
}

void ggml_cuda_set_peer_access(const int n_tokens) {
static bool peer_access_enabled = false;

const bool enable_peer_access = n_tokens <= GGML_CUDA_PEER_MAX_BATCH_SIZE;

if (peer_access_enabled == enable_peer_access) {
return;
}

#ifdef NDEBUG
slaren marked this conversation as resolved.
Show resolved Hide resolved
for (int id = 0; id < g_device_count; ++id) {
CUDA_CHECK(ggml_cuda_set_device(id));

for (int id_other = 0; id_other < g_device_count; ++id_other) {
if (id == id_other) {
continue;
}
if (id != g_main_device && id_other != g_main_device) {
continue;
}

int canAccessPeer;
CUDA_CHECK(cudaDeviceCanAccessPeer(&canAccessPeer, id, id_other));
if (enable_peer_access) {
CUDA_CHECK(cudaDeviceEnablePeerAccess(id_other, 0));
} else {
CUDA_CHECK(cudaDeviceDisablePeerAccess(id_other));
}
}
}
#endif // NDEBUG

peer_access_enabled = enable_peer_access;
}

static void ggml_cuda_op_mul_mat(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, ggml_cuda_op_mul_mat_t op,
const bool convert_src1_to_q8_1) {
Expand All @@ -6282,6 +6324,8 @@ static void ggml_cuda_op_mul_mat(
const int nb2 = dst->nb[2];
const int nb3 = dst->nb[3];

ggml_cuda_set_peer_access(ne11);

GGML_ASSERT(dst->backend != GGML_BACKEND_GPU_SPLIT);
GGML_ASSERT(src1->backend != GGML_BACKEND_GPU_SPLIT);

Expand Down Expand Up @@ -7009,7 +7053,7 @@ void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor) {
ggml_cuda_assign_buffers_impl(tensor, false, true, false);
}

void ggml_cuda_set_main_device(int main_device) {
void ggml_cuda_set_main_device(const int main_device) {
if (main_device >= g_device_count) {
fprintf(stderr, "warning: cannot set main_device=%d because there are only %d devices. Using device %d instead.\n",
main_device, g_device_count, g_main_device);
Expand All @@ -7023,11 +7067,11 @@ void ggml_cuda_set_main_device(int main_device) {
}
}

void ggml_cuda_set_mul_mat_q(bool mul_mat_q) {
void ggml_cuda_set_mul_mat_q(const bool mul_mat_q) {
g_mul_mat_q = mul_mat_q;
}

void ggml_cuda_set_scratch_size(size_t scratch_size) {
void ggml_cuda_set_scratch_size(const size_t scratch_size) {
g_scratch_size = scratch_size;
}

Expand Down