Skip to content

Commit

Permalink
Device op: pass device to lower-level op to avoid recurring queries
Browse files Browse the repository at this point in the history
  • Loading branch information
devreal committed Jun 28, 2023
1 parent bf09316 commit 955849b
Show file tree
Hide file tree
Showing 17 changed files with 683 additions and 412 deletions.
22 changes: 11 additions & 11 deletions ompi/mca/coll/base/coll_base_allreduce.c
Original file line number Diff line number Diff line change
Expand Up @@ -239,10 +239,10 @@ ompi_coll_base_allreduce_intra_recursivedoubling(const void *sbuf, void *rbuf,
if (tmpsend == sbuf) {
tmpsend = inplacebuf;
/* tmpsend = tmprecv (op) sbuf */
ompi_3buff_op_reduce_stream(op, sbuf, tmprecv, tmpsend, count, dtype, stream);
ompi_3buff_op_reduce_stream(op, sbuf, tmprecv, tmpsend, count, dtype, op_dev, stream);
} else {
/* tmpsend = tmprecv (op) tmpsend */
ompi_op_reduce_stream(op, tmprecv, tmpsend, count, dtype, stream);
ompi_op_reduce_stream(op, tmprecv, tmpsend, count, dtype, op_dev, stream);
}
newrank = rank >> 1;
}
Expand Down Expand Up @@ -281,14 +281,14 @@ ompi_coll_base_allreduce_intra_recursivedoubling(const void *sbuf, void *rbuf,
if (tmpsend == sbuf) {
/* special case: 1st iteration takes one input from the sbuf */
/* tmprecv = sbuf (op) tmprecv */
ompi_op_reduce_stream(op, sbuf, tmprecv, count, dtype, stream);
ompi_op_reduce_stream(op, sbuf, tmprecv, count, dtype, op_dev, stream);
/* send the current recv buffer, and use the tmp buffer to receive */
tmpsend = tmprecv;
tmprecv = inplacebuf;
} else if (have_next_iter || tmprecv == recvbuf) {
/* All iterations, and the last if tmprecv is the recv buffer */
/* tmprecv = tmpsend (op) tmprecv */
ompi_op_reduce_stream(op, tmpsend, tmprecv, count, dtype, stream);
ompi_op_reduce_stream(op, tmpsend, tmprecv, count, dtype, op_dev, stream);
/* swap send and receive buffers */
tmpswap = tmprecv;
tmprecv = tmpsend;
Expand All @@ -297,26 +297,26 @@ ompi_coll_base_allreduce_intra_recursivedoubling(const void *sbuf, void *rbuf,
/* Last iteration: if tmprecv is not the recv buffer, then tmpsend is */
/* Make sure we reduce into the receive buffer
* tmpsend = tmprecv (op) tmpsend */
ompi_op_reduce_stream(op, tmprecv, tmpsend, count, dtype, stream);
ompi_op_reduce_stream(op, tmprecv, tmpsend, count, dtype, op_dev, stream);
}
} else {
if (tmpsend == sbuf) {
/* First iteration: use input from sbuf */
/* tmpsend = tmprecv (op) sbuf */
tmpsend = inplacebuf;
if (have_next_iter || tmpsend == recvbuf) {
ompi_3buff_op_reduce_stream(op, tmprecv, sbuf, tmpsend, count, dtype, stream);
ompi_3buff_op_reduce_stream(op, tmprecv, sbuf, tmpsend, count, dtype, op_dev, stream);
} else {
ompi_op_reduce_stream(op, sbuf, tmprecv, count, dtype, stream);
ompi_op_reduce_stream(op, sbuf, tmprecv, count, dtype, op_dev, stream);
tmpsend = tmprecv;
}
} else if (have_next_iter || tmpsend == recvbuf) {
/* All other iterations: reduce into tmpsend for next iteration */
/* tmpsend = tmprecv (op) tmpsend */
ompi_op_reduce_stream(op, tmprecv, tmpsend, count, dtype, stream);
ompi_op_reduce_stream(op, tmprecv, tmpsend, count, dtype, op_dev, stream);
} else {
/* Last iteration: reduce into rbuf and set tmpsend to rbuf (needed at the end) */
ompi_op_reduce_stream(op, tmpsend, tmprecv, count, dtype, stream);
ompi_op_reduce_stream(op, tmpsend, tmprecv, count, dtype, op_dev, stream);
tmpsend = tmprecv;
}
}
Expand Down Expand Up @@ -1253,11 +1253,11 @@ int ompi_coll_base_allreduce_intra_redscat_allgather(
/* Reduce on the right half of the buffers (result in rbuf) */
if (MPI_IN_PLACE != sbuf) {
/* rbuf = sbuf (op) tmp_buf */
ompi_3buff_op_reduce_stream(op, sbuf, tmp_buf, recvbuf, count_lhalf, dtype, stream);
ompi_3buff_op_reduce_stream(op, sbuf, tmp_buf, recvbuf, count_lhalf, dtype, op_dev, stream);

} else {
/* rbuf = rbuf (op) tmp_buf */
ompi_op_reduce_stream(op, tmp_buf, recvbuf, count_lhalf, dtype, stream);
ompi_op_reduce_stream(op, tmp_buf, recvbuf, count_lhalf, dtype, op_dev, stream);
}


Expand Down
2 changes: 1 addition & 1 deletion ompi/mca/coll/base/coll_base_reduce.c
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ int ompi_coll_base_reduce_generic( const void* sendbuf, void* recvbuf, int origi

/* If this is a non-commutative operation we must copy
sendbuf to the accumbuf, in order to simplify the loops */

if (!ompi_op_is_commute(op) && MPI_IN_PLACE != sendbuf) {
ompi_datatype_copy_content_same_ddt(datatype, original_count,
(char*)accumbuf,
Expand Down
1 change: 1 addition & 0 deletions ompi/mca/op/cuda/op_cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ typedef struct {
CUcontext *cu_ctx;
#endif // 0
int *cu_max_threads_per_block;
int *cu_max_blocks;
CUdevice *cu_devices;
int cu_num_devices;
} ompi_op_cuda_component_t;
Expand Down
24 changes: 12 additions & 12 deletions ompi/mca/op/cuda/op_cuda_component.c
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ ompi_op_cuda_component_t mca_op_cuda_component = {
.opc_op_query = cuda_component_op_query,
},
.cu_max_threads_per_block = NULL,
.cu_max_blocks = NULL,
.cu_devices = NULL,
.cu_num_devices = 0,
};
Expand Down Expand Up @@ -92,6 +93,8 @@ static int cuda_component_close(void)
//cuStreamDestroy(mca_op_cuda_component.cu_stream);
free(mca_op_cuda_component.cu_max_threads_per_block);
mca_op_cuda_component.cu_max_threads_per_block = NULL;
free(mca_op_cuda_component.cu_max_blocks);
mca_op_cuda_component.cu_max_blocks = NULL;
free(mca_op_cuda_component.cu_devices);
mca_op_cuda_component.cu_devices = NULL;
mca_op_cuda_component.cu_num_devices = 0;
Expand Down Expand Up @@ -127,27 +130,24 @@ cuda_component_init_query(bool enable_progress_threads,
CHECK(cuDeviceGetCount, (&num_devices));
mca_op_cuda_component.cu_num_devices = num_devices;
mca_op_cuda_component.cu_devices = (CUdevice*)malloc(num_devices*sizeof(CUdevice));
#if 0
mca_op_cuda_component.cu_ctx = (CUcontext*)malloc(num_devices*sizeof(CUcontext));
#endif // 0
mca_op_cuda_component.cu_max_threads_per_block = (int*)malloc(num_devices*sizeof(int));
mca_op_cuda_component.cu_max_blocks = (int*)malloc(num_devices*sizeof(int));
for (int i = 0; i < num_devices; ++i) {
CHECK(cuDeviceGet, (&mca_op_cuda_component.cu_devices[i], i));
#if 0
rc = cuCtxCreate(&mca_op_cuda_component.cu_ctx[i],
0, mca_op_cuda_component.cu_devices[i]);
if (CUDA_SUCCESS != rc) {
CHECK(cuDevicePrimaryCtxRetain,
(&mca_op_cuda_component.cu_ctx[i], mca_op_cuda_component.cu_devices[i]));
}
#endif // 0
rc = cuDeviceGetAttribute(&mca_op_cuda_component.cu_max_threads_per_block[i],
CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK,
CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X,
mca_op_cuda_component.cu_devices[i]);
if (CUDA_SUCCESS != rc) {
/* fall back to a value that should work on every device */
mca_op_cuda_component.cu_max_threads_per_block[i] = 512;
}
rc = cuDeviceGetAttribute(&mca_op_cuda_component.cu_max_blocks[i],
CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X,
mca_op_cuda_component.cu_devices[i]);
if (CUDA_SUCCESS != rc) {
/* fall back to a value that should work on every device */
mca_op_cuda_component.cu_max_blocks[i] = 512;
}
}

#if 0
Expand Down
Loading

0 comments on commit 955849b

Please sign in to comment.