diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 90607405a8ab59..1551517b54a60b 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -409,6 +409,16 @@ struct ggml_tensor_extra_gpu {
     cudaEvent_t events[GGML_CUDA_MAX_DEVICES][MAX_STREAMS]; // events for synchronizing multiple GPUs
 };
 
+cudaError_t ggml_cuda_set_device(int device) {
+    static int current_device = -1;
+
+    if (device == current_device) {
+        return cudaSuccess;
+    }
+
+    return cudaSetDevice(device);
+}
+
 static int g_device_count = -1;
 static int g_main_device = 0;
 static int g_compute_capabilities[GGML_CUDA_MAX_DEVICES];
@@ -5151,7 +5161,7 @@ void ggml_init_cublas() {
     }
 
     for (int64_t id = 0; id < g_device_count; ++id) {
-        CUDA_CHECK(cudaSetDevice(id));
+        CUDA_CHECK(ggml_cuda_set_device(id));
 
         // create cuda streams
         for (int64_t is = 0; is < MAX_STREAMS; ++is) {
@@ -5795,7 +5805,7 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s
     size_t src1_asf = 0;
     size_t dst_asf = 0;
 
-    cudaSetDevice(g_main_device);
+    ggml_cuda_set_device(g_main_device);
     const cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
 
     if (src0_on_device) {
@@ -5940,7 +5950,7 @@ static void ggml_cuda_op_mul_mat(
         const bool src1_on_device = src1->backend == GGML_BACKEND_GPU && id == g_main_device;
         const bool dst_on_device = dst->backend == GGML_BACKEND_GPU && id == g_main_device;
 
-        cudaSetDevice(id);
+        ggml_cuda_set_device(id);
         const cudaStream_t stream = g_cudaStreams[id][0];
 
         if (src0_on_device && src0_is_contiguous) {
@@ -5976,7 +5986,7 @@ static void ggml_cuda_op_mul_mat(
 
     // if multiple devices are used they need to wait for the main device
    // here an event is recorded that signals that the main device has finished calculating the input data
     if (split && g_device_count > 1) {
-        CUDA_CHECK(cudaSetDevice(g_main_device));
+        CUDA_CHECK(ggml_cuda_set_device(g_main_device));
         CUDA_CHECK(cudaEventRecord(src0_extra->events[g_main_device][0], g_cudaStreams[g_main_device][0]));
     }
@@ -5994,7 +6004,7 @@ static void ggml_cuda_op_mul_mat(
             const bool dst_on_device = dst->backend == GGML_BACKEND_GPU && id == g_main_device;
             const int64_t row_diff = row_high[id] - row_low[id];
 
-            cudaSetDevice(id);
+            ggml_cuda_set_device(id);
             const cudaStream_t stream = g_cudaStreams[id][is];
 
             // wait for main GPU data if necessary
@@ -6096,7 +6106,7 @@ static void ggml_cuda_op_mul_mat(
     }
 
     for (int64_t id = 0; id < g_device_count; ++id) {
-        CUDA_CHECK(cudaSetDevice(id));
+        CUDA_CHECK(ggml_cuda_set_device(id));
 
         // free buffers again when done
         if (src0_as[id] > 0) {
@@ -6118,7 +6128,7 @@ static void ggml_cuda_op_mul_mat(
         int64_t is_max = (ne11 + MUL_MAT_SRC1_COL_STRIDE - 1) / MUL_MAT_SRC1_COL_STRIDE;
         is_max = is_max <= MAX_STREAMS ? is_max : MAX_STREAMS;
 
-        CUDA_CHECK(cudaSetDevice(g_main_device));
+        CUDA_CHECK(ggml_cuda_set_device(g_main_device));
         for (int64_t id = 0; id < g_device_count; ++id) {
             for (int64_t is = 0; is < is_max; ++is) {
                 CUDA_CHECK(cudaStreamWaitEvent(g_cudaStreams[g_main_device][0], src0_extra->events[id][is]));
@@ -6127,7 +6137,7 @@ static void ggml_cuda_op_mul_mat(
     }
 
     if (dst->backend == GGML_BACKEND_CPU) {
-        CUDA_CHECK(cudaSetDevice(g_main_device));
+        CUDA_CHECK(ggml_cuda_set_device(g_main_device));
         CUDA_CHECK(cudaDeviceSynchronize());
     }
 }
@@ -6187,7 +6197,7 @@ void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * sr
 
     const int64_t ne12 = src1->ne[2];
 
-    CUDA_CHECK(cudaSetDevice(g_main_device));
+    CUDA_CHECK(ggml_cuda_set_device(g_main_device));
     cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
 
     struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
@@ -6218,7 +6228,7 @@ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1
     const int64_t nb01 = src0->nb[1];
     const int64_t nb02 = src0->nb[2];
 
-    CUDA_CHECK(cudaSetDevice(g_main_device));
+    CUDA_CHECK(ggml_cuda_set_device(g_main_device));
     cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
 
     struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
@@ -6310,7 +6320,7 @@ void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens
     const int64_t nb11 = src1->nb[1];
     const int64_t nb12 = src1->nb[2];
 
-    CUDA_CHECK(cudaSetDevice(g_main_device));
+    CUDA_CHECK(ggml_cuda_set_device(g_main_device));
     cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
 
     const struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
@@ -6376,7 +6386,7 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
             continue;
         }
 
-        cudaSetDevice(id);
+        ggml_cuda_set_device(id);
 
         int64_t row_low, row_high;
         if (backend == GGML_BACKEND_GPU) {
@@ -6446,13 +6456,13 @@ void ggml_cuda_free_data(struct ggml_tensor * tensor) {
 
     for (int64_t id = 0; id < g_device_count; ++id) {
         if (extra->data_device[id] != nullptr) {
-            CUDA_CHECK(cudaSetDevice(id));
+            CUDA_CHECK(ggml_cuda_set_device(id));
             CUDA_CHECK(cudaFree(extra->data_device[id]));
         }
 
         for (int64_t is = 0; is < MAX_STREAMS; ++is) {
             if (extra->events[id][is] != nullptr) {
-                CUDA_CHECK(cudaSetDevice(id));
+                CUDA_CHECK(ggml_cuda_set_device(id));
                 CUDA_CHECK(cudaEventDestroy(extra->events[id][is]));
             }
         }
@@ -6506,7 +6516,7 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bo
         force_inplace;
     const size_t size = ggml_nbytes(tensor);
 
-    CUDA_CHECK(cudaSetDevice(g_main_device));
+    CUDA_CHECK(ggml_cuda_set_device(g_main_device));
     if (inplace && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT)) {
         struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src[0]->extra;
         char * src0_ddc = (char *) src0_extra->data_device[g_main_device];
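
Note on the new helper: the intent of ggml_cuda_set_device appears to be to skip redundant cudaSetDevice calls, but as the hunk stands the static current_device is initialized to -1 and never assigned afterwards, so the early-return path can never fire and every call still reaches cudaSetDevice. The sketch below is not part of this patch; it shows one way the cache could actually be updated, under the assumption that the CUDA backend is driven from a single thread (the static cache is not thread-safe), and the name ggml_cuda_set_device_cached is hypothetical.

#include <cuda_runtime.h>

// Sketch only: a variant of the wrapper that remembers the last device that
// was successfully set, so repeated calls with the same id skip the runtime call.
cudaError_t ggml_cuda_set_device_cached(int device) {
    static int current_device = -1;   // last device successfully made active

    if (device == current_device) {
        return cudaSuccess;           // already active, nothing to do
    }

    const cudaError_t err = cudaSetDevice(device);
    if (err == cudaSuccess) {
        current_device = device;      // record the switch only on success
    }
    return err;
}

Callers would use it exactly like the patch uses ggml_cuda_set_device, e.g. CUDA_CHECK(ggml_cuda_set_device_cached(g_main_device)); an alternative that avoids the static entirely is to query cudaGetDevice before deciding whether to call cudaSetDevice.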