ggml_cuda_set_device

JohannesGaessler committed Sep 11, 2023
1 parent bd79c94 commit 866b502
Showing 1 changed file with 25 additions and 15 deletions.

ggml-cuda.cu: 25 additions & 15 deletions
@@ -409,6 +409,16 @@ struct ggml_tensor_extra_gpu {
     cudaEvent_t events[GGML_CUDA_MAX_DEVICES][MAX_STREAMS]; // events for synchronizing multiple GPUs
 };
 
+cudaError_t ggml_cuda_set_device(int device) {
+    static int current_device = -1; // device most recently made current via this helper
+
+    if (device == current_device) {
+        return cudaSuccess; // already current: skip the redundant runtime call
+    }
+    current_device = device;
+    return cudaSetDevice(device);
+}
+
 static int g_device_count = -1;
 static int g_main_device = 0;
 static int g_compute_capabilities[GGML_CUDA_MAX_DEVICES];
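
Note: the helper memoizes the most recently requested device in a function-local static, so a request for the device that is already current returns cudaSuccess without touching the CUDA runtime. A minimal standalone sketch of the same caching pattern; the names (set_device_cached, the demo main) are hypothetical and not part of this commit:

    // demo.cu: hypothetical standalone demo of the caching pattern
    #include <cuda_runtime.h>
    #include <cstdio>

    // Same idea as ggml_cuda_set_device: remember which device this wrapper
    // last made current and skip redundant cudaSetDevice calls.
    static cudaError_t set_device_cached(int device) {
        static int current_device = -1;
        if (device == current_device) {
            return cudaSuccess; // cache hit: no runtime call
        }
        current_device = device;
        return cudaSetDevice(device);
    }

    int main(void) {
        int count = 0;
        if (cudaGetDeviceCount(&count) != cudaSuccess || count == 0) {
            return 0; // no CUDA device available
        }
        set_device_cached(0); // reaches cudaSetDevice
        set_device_cached(0); // returns immediately from the cache
        printf("%d device(s) visible\n", count);
        return 0;
    }

The cache stays valid only while every device switch goes through the wrapper: a direct cudaSetDevice elsewhere, or concurrent calls from multiple host threads (the static is unsynchronized), would leave it stale.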
@@ -5151,7 +5161,7 @@ void ggml_init_cublas() {
     }
 
     for (int64_t id = 0; id < g_device_count; ++id) {
-        CUDA_CHECK(cudaSetDevice(id));
+        CUDA_CHECK(ggml_cuda_set_device(id));
 
         // create cuda streams
         for (int64_t is = 0; is < MAX_STREAMS; ++is) {
@@ -5795,7 +5805,7 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s
     size_t src1_asf = 0;
     size_t dst_asf = 0;
 
-    cudaSetDevice(g_main_device);
+    ggml_cuda_set_device(g_main_device);
     const cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
 
     if (src0_on_device) {
@@ -5940,7 +5950,7 @@ static void ggml_cuda_op_mul_mat(
        const bool src1_on_device = src1->backend == GGML_BACKEND_GPU && id == g_main_device;
        const bool dst_on_device = dst->backend == GGML_BACKEND_GPU && id == g_main_device;
 
-        cudaSetDevice(id);
+        ggml_cuda_set_device(id);
        const cudaStream_t stream = g_cudaStreams[id][0];
 
        if (src0_on_device && src0_is_contiguous) {
@@ -5976,7 +5986,7 @@ static void ggml_cuda_op_mul_mat(
     // if multiple devices are used they need to wait for the main device
     // here an event is recorded that signals that the main device has finished calculating the input data
     if (split && g_device_count > 1) {
-        CUDA_CHECK(cudaSetDevice(g_main_device));
+        CUDA_CHECK(ggml_cuda_set_device(g_main_device));
         CUDA_CHECK(cudaEventRecord(src0_extra->events[g_main_device][0], g_cudaStreams[g_main_device][0]));
     }
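
Note: the comment above describes CUDA's usual record-then-wait handoff: the producing device records an event on its stream, and consuming streams on other devices insert a wait on that event, so the ordering is enforced on the GPU timeline instead of blocking the host. A hedged sketch of the pattern in isolation; the function and parameter names are illustrative, not the commit's:

    #include <cuda_runtime.h>

    // Assumes `event` was created with cudaEventCreateWithFlags(&event,
    // cudaEventDisableTiming), the cheap choice for pure synchronization.
    static void handoff(int src_device, cudaStream_t src_stream,
                        int dst_device, cudaStream_t dst_stream,
                        cudaEvent_t event) {
        cudaSetDevice(src_device);
        cudaEventRecord(event, src_stream);        // marks completion of all prior work on src_stream

        cudaSetDevice(dst_device);
        cudaStreamWaitEvent(dst_stream, event, 0); // dst_stream stalls until the event fires
    }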

@@ -5994,7 +6004,7 @@ static void ggml_cuda_op_mul_mat(
            const bool dst_on_device = dst->backend == GGML_BACKEND_GPU && id == g_main_device;
            const int64_t row_diff = row_high[id] - row_low[id];
 
-            cudaSetDevice(id);
+            ggml_cuda_set_device(id);
            const cudaStream_t stream = g_cudaStreams[id][is];
 
            // wait for main GPU data if necessary
@@ -6096,7 +6106,7 @@ static void ggml_cuda_op_mul_mat(
     }
 
     for (int64_t id = 0; id < g_device_count; ++id) {
-        CUDA_CHECK(cudaSetDevice(id));
+        CUDA_CHECK(ggml_cuda_set_device(id));
 
         // free buffers again when done
         if (src0_as[id] > 0) {
@@ -6118,7 +6128,7 @@ static void ggml_cuda_op_mul_mat(
        int64_t is_max = (ne11 + MUL_MAT_SRC1_COL_STRIDE - 1) / MUL_MAT_SRC1_COL_STRIDE;
        is_max = is_max <= MAX_STREAMS ? is_max : MAX_STREAMS;
 
-        CUDA_CHECK(cudaSetDevice(g_main_device));
+        CUDA_CHECK(ggml_cuda_set_device(g_main_device));
        for (int64_t id = 0; id < g_device_count; ++id) {
            for (int64_t is = 0; is < is_max; ++is) {
                CUDA_CHECK(cudaStreamWaitEvent(g_cudaStreams[g_main_device][0], src0_extra->events[id][is]));
@@ -6127,7 +6137,7 @@ static void ggml_cuda_op_mul_mat(
     }
 
     if (dst->backend == GGML_BACKEND_CPU) {
-        CUDA_CHECK(cudaSetDevice(g_main_device));
+        CUDA_CHECK(ggml_cuda_set_device(g_main_device));
         CUDA_CHECK(cudaDeviceSynchronize());
     }
 }
@@ -6187,7 +6197,7 @@ void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * sr
 
     const int64_t ne12 = src1->ne[2];
 
-    CUDA_CHECK(cudaSetDevice(g_main_device));
+    CUDA_CHECK(ggml_cuda_set_device(g_main_device));
     cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
 
     struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
@@ -6218,7 +6228,7 @@ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1
     const int64_t nb01 = src0->nb[1];
     const int64_t nb02 = src0->nb[2];
 
-    CUDA_CHECK(cudaSetDevice(g_main_device));
+    CUDA_CHECK(ggml_cuda_set_device(g_main_device));
     cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
 
     struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
@@ -6310,7 +6320,7 @@ void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens
     const int64_t nb11 = src1->nb[1];
     const int64_t nb12 = src1->nb[2];
 
-    CUDA_CHECK(cudaSetDevice(g_main_device));
+    CUDA_CHECK(ggml_cuda_set_device(g_main_device));
     cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
 
     const struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
@@ -6376,7 +6386,7 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
            continue;
        }
 
-        cudaSetDevice(id);
+        ggml_cuda_set_device(id);
 
        int64_t row_low, row_high;
        if (backend == GGML_BACKEND_GPU) {
@@ -6446,13 +6456,13 @@ void ggml_cuda_free_data(struct ggml_tensor * tensor) {
 
     for (int64_t id = 0; id < g_device_count; ++id) {
         if (extra->data_device[id] != nullptr) {
-            CUDA_CHECK(cudaSetDevice(id));
+            CUDA_CHECK(ggml_cuda_set_device(id));
             CUDA_CHECK(cudaFree(extra->data_device[id]));
         }
 
         for (int64_t is = 0; is < MAX_STREAMS; ++is) {
             if (extra->events[id][is] != nullptr) {
-                CUDA_CHECK(cudaSetDevice(id));
+                CUDA_CHECK(ggml_cuda_set_device(id));
                 CUDA_CHECK(cudaEventDestroy(extra->events[id][is]));
             }
         }
@@ -6506,7 +6516,7 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bo
        force_inplace;
    const size_t size = ggml_nbytes(tensor);
 
-    CUDA_CHECK(cudaSetDevice(g_main_device));
+    CUDA_CHECK(ggml_cuda_set_device(g_main_device));
    if (inplace && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT)) {
        struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src[0]->extra;
        char * src0_ddc = (char *) src0_extra->data_device[g_main_device];
