diff --git a/ompi/mca/coll/base/coll_base_allreduce.c b/ompi/mca/coll/base/coll_base_allreduce.c
index 05a2ca0d561..c6380b23866 100644
--- a/ompi/mca/coll/base/coll_base_allreduce.c
+++ b/ompi/mca/coll/base/coll_base_allreduce.c
@@ -141,6 +141,7 @@ ompi_coll_base_allreduce_intra_recursivedoubling(const void *sbuf, void *rbuf,
     int ret, line, rank, size, adjsize, remote, distance;
     int newrank, newremote, extra_ranks;
     char *tmpsend = NULL, *tmprecv = NULL, *tmpswap = NULL, *inplacebuf_free = NULL, *inplacebuf;
+    char *recvbuf = NULL;
     ptrdiff_t span, gap = 0;
 
     size = ompi_comm_size(comm);
@@ -158,22 +159,64 @@ ompi_coll_base_allreduce_intra_recursivedoubling(const void *sbuf, void *rbuf,
         return MPI_SUCCESS;
     }
 
-    /* Allocate and initialize temporary send buffer */
+    /* get the device for sbuf and rbuf and where the op would like to execute */
+    int sendbuf_dev, recvbuf_dev, op_dev;
+    uint64_t sendbuf_flags, recvbuf_flags;
+    ompi_coll_base_select_device(op, sbuf, rbuf, count, dtype, &sendbuf_dev, &recvbuf_dev,
+                                 &sendbuf_flags, &recvbuf_flags, &op_dev);
     span = opal_datatype_span(&dtype->super, count, &gap);
-    inplacebuf_free = (char*) malloc(span);
+    inplacebuf_free = ompi_coll_base_allocate_on_device(op_dev, span, module);
     if (NULL == inplacebuf_free) { ret = -1; line = __LINE__; goto error_hndl; }
     inplacebuf = inplacebuf_free - gap;
+    //printf("allreduce ring count %d sbuf_dev %d rbuf_dev %d op_dev %d\n", count, sendbuf_dev, recvbuf_dev, op_dev);
 
-    if (MPI_IN_PLACE == sbuf) {
-        ret = ompi_datatype_copy_content_same_ddt(dtype, count, inplacebuf, (char*)rbuf);
-        if (ret < 0) { line = __LINE__; goto error_hndl; }
-    } else {
-        ret = ompi_datatype_copy_content_same_ddt(dtype, count, inplacebuf, (char*)sbuf);
+    opal_accelerator_stream_t *stream = NULL;
+    if (op_dev >= 0) {
+        stream = MCA_ACCELERATOR_STREAM_DEFAULT;
+    }
+
+    tmpsend = (char*) sbuf;
+    if (op_dev != recvbuf_dev) {
+        /* copy data to where the op wants it to be */
+        if (MPI_IN_PLACE == sbuf) {
+            ret = ompi_datatype_copy_content_same_ddt_stream(dtype, count, inplacebuf, (char*)rbuf, stream);
+            if (ret < 0) { line = __LINE__; goto error_hndl; }
+        }
+        /* only copy if op is on the device or we cannot access the sendbuf on the host */
+        else if (op_dev != MCA_ACCELERATOR_NO_DEVICE_ID ||
+                 0 == (sendbuf_flags & MCA_ACCELERATOR_FLAGS_UNIFIED_MEMORY)) {
+            ret = ompi_datatype_copy_content_same_ddt_stream(dtype, count, inplacebuf, (char*)sbuf, stream);
+            if (ret < 0) { line = __LINE__; goto error_hndl; }
+        }
+        tmpsend = (char*) inplacebuf;
+    } else if (MPI_IN_PLACE == sbuf) {
+        ret = ompi_datatype_copy_content_same_ddt_stream(dtype, count, inplacebuf, (char*)rbuf, stream);
         if (ret < 0) { line = __LINE__; goto error_hndl; }
+        tmpsend = (char*) inplacebuf;
     }
-    tmpsend = (char*) inplacebuf;
-    tmprecv = (char*) rbuf;
+    /* Handle MPI_IN_PLACE */
+    bool use_sbuf = (MPI_IN_PLACE != sbuf);
+    /* allocate temporary recv buffer if the tmpbuf above is on a different device than the rbuf
+     * and the op is on the device or we cannot access the recv buffer on the host */
+    recvbuf = rbuf;
+    bool free_recvbuf = false;
+    if (op_dev != recvbuf_dev &&
+        (op_dev != MCA_ACCELERATOR_NO_DEVICE_ID ||
+         0 == (recvbuf_flags & MCA_ACCELERATOR_FLAGS_UNIFIED_MEMORY))) {
+        recvbuf = ompi_coll_base_allocate_on_device(op_dev, span, module);
+        free_recvbuf = true;
+        if (use_sbuf) {
+            /* copy from sbuf */
+            ompi_datatype_copy_content_same_ddt_stream(dtype, count, (char*)recvbuf, (char*)sbuf, stream);
+        } else {
+            /* copy from rbuf */
+            ompi_datatype_copy_content_same_ddt_stream(dtype, count, (char*)recvbuf, (char*)rbuf, stream);
+        }
+        use_sbuf = false;
+    }
+
+    tmprecv = (char*) recvbuf;
 
     /* Determine nearest power of two less than or equal to size */
     adjsize = opal_next_poweroftwo (size);
@@ -189,6 +232,11 @@ ompi_coll_base_allreduce_intra_recursivedoubling(const void *sbuf, void *rbuf,
     extra_ranks = size - adjsize;
     if (rank < (2 * extra_ranks)) {
         if (0 == (rank % 2)) {
+            /* wait for above copies to complete */
+            if (NULL != stream) {
+                opal_accelerator.wait_stream(stream);
+            }
+            /* wait for tmpsend to be copied */
             ret = MCA_PML_CALL(send(tmpsend, count, dtype, (rank + 1),
                                     MCA_COLL_BASE_TAG_ALLREDUCE,
                                     MCA_PML_BASE_SEND_STANDARD, comm));
@@ -199,8 +247,14 @@ ompi_coll_base_allreduce_intra_recursivedoubling(const void *sbuf, void *rbuf,
                                     MCA_COLL_BASE_TAG_ALLREDUCE, comm,
                                     MPI_STATUS_IGNORE));
             if (MPI_SUCCESS != ret) { line = __LINE__; goto error_hndl; }
-            /* tmpsend = tmprecv (op) tmpsend */
-            ompi_op_reduce(op, tmprecv, tmpsend, count, dtype);
+            if (tmpsend == sbuf) {
+                tmpsend = inplacebuf;
+                /* tmpsend = tmprecv (op) sbuf */
+                ompi_3buff_op_reduce_stream(op, sbuf, tmprecv, tmpsend, count, dtype, op_dev, stream);
+            } else {
+                /* tmpsend = tmprecv (op) tmpsend */
+                ompi_op_reduce_stream(op, tmprecv, tmpsend, count, dtype, op_dev, stream);
+            }
             newrank = rank >> 1;
         }
     } else {
@@ -219,6 +273,12 @@ ompi_coll_base_allreduce_intra_recursivedoubling(const void *sbuf, void *rbuf,
         remote = (newremote < extra_ranks)?
             (newremote * 2 + 1):(newremote + extra_ranks);
 
+        bool have_next_iter = ((distance << 1) < adjsize);
+
+        /* wait for previous ops to complete */
+        if (NULL != stream) {
+            opal_accelerator.wait_stream(stream);
+        }
         /* Exchange the data */
         ret = ompi_coll_base_sendrecv_actual(tmpsend, count, dtype, remote,
                                              MCA_COLL_BASE_TAG_ALLREDUCE,
@@ -229,14 +289,47 @@ ompi_coll_base_allreduce_intra_recursivedoubling(const void *sbuf, void *rbuf,
 
         /* Apply operation */
         if (rank < remote) {
-            /* tmprecv = tmpsend (op) tmprecv */
-            ompi_op_reduce(op, tmpsend, tmprecv, count, dtype);
-            tmpswap = tmprecv;
-            tmprecv = tmpsend;
-            tmpsend = tmpswap;
+            if (tmpsend == sbuf) {
+                /* special case: 1st iteration takes one input from the sbuf */
+                /* tmprecv = sbuf (op) tmprecv */
+                ompi_op_reduce_stream(op, sbuf, tmprecv, count, dtype, op_dev, stream);
+                /* send the current recv buffer, and use the tmp buffer to receive */
+                tmpsend = tmprecv;
+                tmprecv = inplacebuf;
+            } else if (have_next_iter || tmprecv == recvbuf) {
+                /* All iterations, and the last if tmprecv is the recv buffer */
+                /* tmprecv = tmpsend (op) tmprecv */
+                ompi_op_reduce_stream(op, tmpsend, tmprecv, count, dtype, op_dev, stream);
+                /* swap send and receive buffers */
+                tmpswap = tmprecv;
+                tmprecv = tmpsend;
+                tmpsend = tmpswap;
+            } else {
+                /* Last iteration if tmprecv is not the recv buffer, then tmpsend is */
+                /* Make sure we reduce into the receive buffer
+                 * tmpsend = tmprecv (op) tmpsend */
+                ompi_op_reduce_stream(op, tmprecv, tmpsend, count, dtype, op_dev, stream);
+            }
         } else {
-            /* tmpsend = tmprecv (op) tmpsend */
-            ompi_op_reduce(op, tmprecv, tmpsend, count, dtype);
+            if (tmpsend == sbuf) {
+                /* First iteration: use input from sbuf */
+                /* tmpsend = tmprecv (op) sbuf */
+                tmpsend = inplacebuf;
+                if (have_next_iter || tmpsend == recvbuf) {
+                    ompi_3buff_op_reduce_stream(op, tmprecv, sbuf, tmpsend, count, dtype, op_dev, stream);
+                } else {
+                    ompi_op_reduce_stream(op, sbuf, tmprecv, count, dtype, op_dev, stream);
+                    tmpsend = tmprecv;
+                }
+            } else if (have_next_iter || tmpsend == recvbuf) {
+                /* All other iterations: reduce into tmpsend for next iteration */
+                /* tmpsend = tmprecv (op) tmpsend */
+                ompi_op_reduce_stream(op, tmprecv, tmpsend, count, dtype, op_dev, stream);
+            } else {
+                /* Last iteration: reduce into rbuf and set tmpsend to rbuf (needed at the end) */
+                ompi_op_reduce_stream(op, tmpsend, tmprecv, count, dtype, op_dev, stream);
+                tmpsend = tmprecv;
+            }
         }
     }
@@ -253,6 +346,10 @@ ompi_coll_base_allreduce_intra_recursivedoubling(const void *sbuf, void *rbuf,
             if (MPI_SUCCESS != ret) { line = __LINE__; goto error_hndl; }
             tmpsend = (char*)rbuf;
         } else {
+            /* wait for previous ops to complete */
+            if (NULL != stream) {
+                opal_accelerator.wait_stream(stream);
+            }
             ret = MCA_PML_CALL(send(tmpsend, count, dtype, (rank - 1),
                                     MCA_COLL_BASE_TAG_ALLREDUCE,
                                     MCA_PML_BASE_SEND_STANDARD, comm));
@@ -262,18 +359,31 @@ ompi_coll_base_allreduce_intra_recursivedoubling(const void *sbuf, void *rbuf,
 
     /* Ensure that the final result is in rbuf */
     if (tmpsend != rbuf) {
-        ret = ompi_datatype_copy_content_same_ddt(dtype, count, (char*)rbuf, tmpsend);
+        /* TODO: catch this case in the 3buf selection above. Maybe already caught? */
+        ret = ompi_datatype_copy_content_same_ddt_stream(dtype, count, (char*)rbuf, tmpsend, stream);
         if (ret < 0) { line = __LINE__; goto error_hndl; }
     }
 
-    if (NULL != inplacebuf_free) free(inplacebuf_free);
+    /* wait for previous ops to complete */
+    if (NULL != stream) {
+        opal_accelerator.wait_stream(stream);
+    }
+    ompi_coll_base_free_tmpbuf(inplacebuf_free, op_dev, module);
+
+    if (free_recvbuf) {
+        ompi_coll_base_free_tmpbuf(recvbuf, op_dev, module);
+    }
     return MPI_SUCCESS;
 
 error_hndl:
     OPAL_OUTPUT((ompi_coll_base_framework.framework_output, "%s:%4d\tRank %d Error occurred %d\n",
                  __FILE__, line, rank, ret));
     (void)line;  // silence compiler warning
-    if (NULL != inplacebuf_free) free(inplacebuf_free);
+    ompi_coll_base_free_tmpbuf(inplacebuf_free, op_dev, module);
+
+    if (op_dev != recvbuf_dev) {
+        ompi_coll_base_free_tmpbuf(recvbuf, op_dev, module);
+    }
     return ret;
 }
 
@@ -352,6 +462,7 @@ ompi_coll_base_allreduce_intra_ring(const void *sbuf, void *rbuf, size_t count,
     int early_segcount, late_segcount, split_rank, max_segcount;
     size_t typelng;
     char *tmpsend = NULL, *tmprecv = NULL, *inbuf[2] = {NULL, NULL};
+    void *recvbuf = NULL;
     ptrdiff_t true_lb, true_extent, lb, extent;
     ptrdiff_t block_offset, max_real_segsize;
     ompi_request_t *reqs[2] = {MPI_REQUEST_NULL, MPI_REQUEST_NULL};
@@ -400,18 +511,36 @@ ompi_coll_base_allreduce_intra_ring(const void *sbuf, void *rbuf, size_t count,
     max_segcount = early_segcount;
     max_real_segsize = true_extent + (max_segcount - 1) * extent;
-
-    inbuf[0] = (char*)malloc(max_real_segsize);
-    if (NULL == inbuf[0]) { ret = -1; line = __LINE__; goto error_hndl; }
+    /* get the device for sbuf and rbuf and where the op would like to execute */
+    int sendbuf_dev, recvbuf_dev, op_dev;
+    uint64_t sendbuf_flags, recvbuf_flags;
+    ompi_coll_base_select_device(op, sbuf, rbuf, count, dtype, &sendbuf_dev, &recvbuf_dev,
+                                 &sendbuf_flags, &recvbuf_flags, &op_dev);
     if (size > 2) {
-        inbuf[1] = (char*)malloc(max_real_segsize);
-        if (NULL == inbuf[1]) { ret = -1; line = __LINE__; goto error_hndl; }
+        inbuf[0] = ompi_coll_base_allocate_on_device(op_dev, 2*max_real_segsize, module);
+        if (NULL == inbuf[0]) { ret = -1; line = __LINE__; goto error_hndl; }
+        inbuf[1] = inbuf[0] + max_real_segsize;
+    } else {
+        inbuf[0] = ompi_coll_base_allocate_on_device(op_dev, max_real_segsize, module);
+        if (NULL == inbuf[0]) { ret = -1; line = __LINE__; goto error_hndl; }
     }
 
     /* Handle MPI_IN_PLACE */
-    if (MPI_IN_PLACE != sbuf) {
-        ret = ompi_datatype_copy_content_same_ddt(dtype, count, (char*)rbuf, (char*)sbuf);
-        if (ret < 0) { line = __LINE__; goto error_hndl; }
+    bool use_sbuf = (MPI_IN_PLACE != sbuf);
+    /* allocate temporary recv buffer if the tmpbuf above is on a different device than the rbuf */
+    recvbuf = rbuf;
+    if (op_dev != recvbuf_dev &&
+        /* only copy if op is on the device or the recvbuffer cannot be accessed on the host */
+        (op_dev != MCA_ACCELERATOR_NO_DEVICE_ID || 0 == (MCA_ACCELERATOR_FLAGS_UNIFIED_MEMORY & recvbuf_flags))) {
+        recvbuf = ompi_coll_base_allocate_on_device(op_dev, typelng*count, module);
+        if (use_sbuf) {
+            /* copy from sbuf */
+            ompi_datatype_copy_content_same_ddt(dtype, count, (char*)recvbuf, (char*)sbuf);
+        } else {
+            /* copy from rbuf */
+            ompi_datatype_copy_content_same_ddt(dtype, count, (char*)recvbuf, (char*)rbuf);
+        }
+        use_sbuf = false;
     }
 
     /* Computation loop */
@@ -444,7 +573,7 @@ ompi_coll_base_allreduce_intra_ring(const void *sbuf, void *rbuf, size_t count,
                         ((ptrdiff_t)rank * (ptrdiff_t)early_segcount) :
                         ((ptrdiff_t)rank * (ptrdiff_t)late_segcount + split_rank));
         block_count = ((rank < split_rank)? early_segcount : late_segcount);
-        tmpsend = ((char*)rbuf) + block_offset * extent;
+        tmpsend = ((use_sbuf) ? ((char*)sbuf) : ((char*)recvbuf)) + block_offset * extent;
         ret = MCA_PML_CALL(send(tmpsend, block_count, dtype, send_to,
                                 MCA_COLL_BASE_TAG_ALLREDUCE,
                                 MCA_PML_BASE_SEND_STANDARD, comm));
@@ -471,8 +600,17 @@ ompi_coll_base_allreduce_intra_ring(const void *sbuf, void *rbuf, size_t count,
                         ((ptrdiff_t)prevblock * early_segcount) :
                         ((ptrdiff_t)prevblock * late_segcount + split_rank));
         block_count = ((prevblock < split_rank)? early_segcount : late_segcount);
-        tmprecv = ((char*)rbuf) + (ptrdiff_t)block_offset * extent;
-        ompi_op_reduce(op, inbuf[inbi ^ 0x1], tmprecv, block_count, dtype);
+        tmprecv = ((char*)recvbuf) + (ptrdiff_t)block_offset * extent;
+        if (use_sbuf) {
+            void *tmpsbuf = ((char*)sbuf) + (ptrdiff_t)block_offset * extent;
+            /* tmprecv = inbuf[inbi ^ 0x1] (op) sbuf */
+            ompi_3buff_op_reduce_stream(op, inbuf[inbi ^ 0x1], tmpsbuf, tmprecv, block_count,
+                                        dtype, op_dev, NULL);
+        } else {
+            /* tmprecv = inbuf[inbi ^ 0x1] (op) tmprecv */
+            ompi_op_reduce_stream(op, inbuf[inbi ^ 0x1], tmprecv, block_count,
+                                  dtype, op_dev, NULL);
+        }
 
         /* send previous block to send_to */
         ret = MCA_PML_CALL(send(tmprecv, block_count, dtype, send_to,
@@ -492,8 +630,8 @@ ompi_coll_base_allreduce_intra_ring(const void *sbuf, void *rbuf, size_t count,
                     ((ptrdiff_t)recv_from * early_segcount) :
                     ((ptrdiff_t)recv_from * late_segcount + split_rank));
     block_count = ((recv_from < split_rank)? early_segcount : late_segcount);
-    tmprecv = ((char*)rbuf) + (ptrdiff_t)block_offset * extent;
-    ompi_op_reduce(op, inbuf[inbi], tmprecv, block_count, dtype);
+    tmprecv = ((char*)recvbuf) + (ptrdiff_t)block_offset * extent;
+    ompi_op_reduce_stream(op, inbuf[inbi], tmprecv, block_count, dtype, op_dev, NULL);
 
     /* Distribution loop - variation of ring allgather */
     send_to = (rank + 1) % size;
@@ -512,8 +650,8 @@ ompi_coll_base_allreduce_intra_ring(const void *sbuf, void *rbuf, size_t count,
         block_count = ((send_data_from < split_rank)?
                        early_segcount : late_segcount);
-        tmprecv = (char*)rbuf + (ptrdiff_t)recv_block_offset * extent;
-        tmpsend = (char*)rbuf + (ptrdiff_t)send_block_offset * extent;
+        tmprecv = (char*)recvbuf + (ptrdiff_t)recv_block_offset * extent;
+        tmpsend = (char*)recvbuf + (ptrdiff_t)send_block_offset * extent;
 
         ret = ompi_coll_base_sendrecv(tmpsend, block_count, dtype, send_to,
                                       MCA_COLL_BASE_TAG_ALLREDUCE,
@@ -521,11 +659,14 @@ ompi_coll_base_allreduce_intra_ring(const void *sbuf, void *rbuf, size_t count,
                                       MCA_COLL_BASE_TAG_ALLREDUCE, comm,
                                       MPI_STATUS_IGNORE, rank);
         if (MPI_SUCCESS != ret) { line = __LINE__; goto error_hndl;}
-
     }
 
-    if (NULL != inbuf[0]) free(inbuf[0]);
-    if (NULL != inbuf[1]) free(inbuf[1]);
+    ompi_coll_base_free_tmpbuf(inbuf[0], op_dev, module);
+    if (recvbuf != rbuf) {
+        /* copy to final rbuf and release temporary recvbuf */
+        ompi_datatype_copy_content_same_ddt(dtype, count, (char*)rbuf, (char*)recvbuf);
+        ompi_coll_base_free_tmpbuf(recvbuf, op_dev, module);
+    }
 
     return MPI_SUCCESS;
 
@@ -534,8 +675,12 @@ ompi_coll_base_allreduce_intra_ring(const void *sbuf, void *rbuf, size_t count,
                 __FILE__, line, rank, ret));
     ompi_coll_base_free_reqs(reqs, 2);
     (void)line;  // silence compiler warning
-    if (NULL != inbuf[0]) free(inbuf[0]);
-    if (NULL != inbuf[1]) free(inbuf[1]);
+    ompi_coll_base_free_tmpbuf(inbuf[0], op_dev, module);
+    if (NULL != recvbuf && recvbuf != rbuf) {
+        /* copy to final rbuf and release temporary recvbuf */
+        ompi_datatype_copy_content_same_ddt(dtype, count, (char*)rbuf, (char*)recvbuf);
+        ompi_coll_base_free_tmpbuf(recvbuf, op_dev, module);
+    }
     return ret;
 }
 
@@ -688,16 +833,21 @@ ompi_coll_base_allreduce_intra_ring_segmented(const void *sbuf, void *rbuf, size
     if (MPI_SUCCESS != ret) { line = __LINE__; goto error_hndl; }
     max_real_segsize = opal_datatype_span(&dtype->super, max_segcount, &gap);
+    int sendbuf_dev, recvbuf_dev, op_dev;
+    uint64_t sendbuf_flags, recvbuf_flags;
+    ompi_coll_base_select_device(op, sbuf, rbuf, count, dtype, &sendbuf_dev, &recvbuf_dev,
+                                 &sendbuf_flags, &recvbuf_flags, &op_dev);
 
     /* Allocate and initialize temporary buffers */
-    inbuf[0] = (char*)malloc(max_real_segsize);
+    inbuf[0] = ompi_coll_base_allocate_on_device(op_dev, max_real_segsize, module);
     if (NULL == inbuf[0]) { ret = -1; line = __LINE__; goto error_hndl; }
     if (size > 2) {
-        inbuf[1] = (char*)malloc(max_real_segsize);
+        inbuf[1] = ompi_coll_base_allocate_on_device(op_dev, max_real_segsize, module);
         if (NULL == inbuf[1]) { ret = -1; line = __LINE__; goto error_hndl; }
     }
 
     /* Handle MPI_IN_PLACE */
     if (MPI_IN_PLACE != sbuf) {
+        /* TODO: can we avoid this copy? */
         ret = ompi_datatype_copy_content_same_ddt(dtype, count, (char*)rbuf, (char*)sbuf);
         if (ret < 0) { line = __LINE__; goto error_hndl; }
     }
 
@@ -783,7 +933,8 @@ ompi_coll_base_allreduce_intra_ring_segmented(const void *sbuf, void *rbuf, size
                                ((ptrdiff_t)phase * (ptrdiff_t)early_phase_segcount) :
                                ((ptrdiff_t)phase * (ptrdiff_t)late_phase_segcount + split_phase));
                tmprecv = ((char*)rbuf) + (ptrdiff_t)(block_offset + phase_offset) * extent;
-               ompi_op_reduce(op, inbuf[inbi ^ 0x1], tmprecv, phase_count, dtype);
+               ompi_op_reduce_stream(op, inbuf[inbi ^ 0x1], tmprecv, phase_count,
+                                     dtype, op_dev, NULL);
 
                /* send previous block to send_to */
                ret = MCA_PML_CALL(send(tmprecv, phase_count, dtype, send_to,
@@ -812,7 +963,8 @@ ompi_coll_base_allreduce_intra_ring_segmented(const void *sbuf, void *rbuf, size
                            ((ptrdiff_t)phase * (ptrdiff_t)early_phase_segcount) :
                            ((ptrdiff_t)phase * (ptrdiff_t)late_phase_segcount + split_phase));
            tmprecv = ((char*)rbuf) + (ptrdiff_t)(block_offset + phase_offset) * extent;
-           ompi_op_reduce(op, inbuf[inbi], tmprecv, phase_count, dtype);
+           ompi_op_reduce_stream(op, inbuf[inbi], tmprecv, phase_count,
+                                 dtype, op_dev, NULL);
        }
 
        /* Distribution loop - variation of ring allgather */
@@ -844,8 +996,8 @@ ompi_coll_base_allreduce_intra_ring_segmented(const void *sbuf, void *rbuf, size
 
     }
 
-    if (NULL != inbuf[0]) free(inbuf[0]);
-    if (NULL != inbuf[1]) free(inbuf[1]);
+    ompi_coll_base_free_tmpbuf(inbuf[0], op_dev, module);
+    ompi_coll_base_free_tmpbuf(inbuf[1], op_dev, module);
 
     return MPI_SUCCESS;
 
@@ -854,8 +1006,8 @@ ompi_coll_base_allreduce_intra_ring_segmented(const void *sbuf, void *rbuf, size
                 __FILE__, line, rank, ret));
     ompi_coll_base_free_reqs(reqs, 2);
     (void)line;  // silence compiler warning
-    if (NULL != inbuf[0]) free(inbuf[0]);
-    if (NULL != inbuf[1]) free(inbuf[1]);
+    ompi_coll_base_free_tmpbuf(inbuf[0], op_dev, module);
+    ompi_coll_base_free_tmpbuf(inbuf[1], op_dev, module);
     return ret;
 }
 
@@ -984,7 +1136,14 @@ int ompi_coll_base_allreduce_intra_redscat_allgather(
                 "coll:base:allreduce_intra_redscat_allgather: rank %d/%d",
                 rank, comm_size));
 
-    if (!ompi_op_is_commute(op)) {
+    /* Find nearest power-of-two less than or equal to comm_size */
+    int nsteps = opal_hibit(comm_size, comm->c_cube_dim + 1);   /* ilog2(comm_size) */
+    if (-1 == nsteps) {
+        return MPI_ERR_ARG;
+    }
+    int nprocs_pof2 = 1 << nsteps;                              /* flp2(comm_size) */
+
+    if (count < nprocs_pof2 || !ompi_op_is_commute(op)) {
         OPAL_OUTPUT((ompi_coll_base_framework.framework_output,
                     "coll:base:allreduce_intra_redscat_allgather: rank %d/%d "
                     "count %zu switching to basic linear allreduce",
@@ -993,28 +1152,32 @@ int ompi_coll_base_allreduce_intra_redscat_allgather(
                                                            op, comm, module);
     }
 
-    /* Find nearest power-of-two less than or equal to comm_size */
-    int nsteps = opal_hibit(comm_size, comm->c_cube_dim + 1);   /* ilog2(comm_size) */
-    if (-1 == nsteps) {
-        return MPI_ERR_ARG;
-    }
-    int nprocs_pof2 = 1 << nsteps;                              /* flp2(comm_size) */
-
     int err = MPI_SUCCESS;
     ptrdiff_t lb, extent, dsize, gap = 0;
     ompi_datatype_get_extent(dtype, &lb, &extent);
     dsize = opal_datatype_span(&dtype->super, count, &gap);
 
+    /* get the device for sbuf and rbuf and where the op would like to execute */
+    int sendbuf_dev, recvbuf_dev, op_dev;
+    uint64_t sendbuf_flags, recvbuf_flags;
+    ompi_coll_base_select_device(op, sbuf, rbuf, count, dtype, &sendbuf_dev, &recvbuf_dev,
+                                 &sendbuf_flags, &recvbuf_flags, &op_dev);
+
     /* Temporary buffer for receiving messages */
     char *tmp_buf = NULL;
-    char *tmp_buf_raw = (char *)malloc(dsize);
+    char *tmp_buf_raw = ompi_coll_base_allocate_on_device(op_dev, dsize, module);
     if (NULL == tmp_buf_raw)
         return OMPI_ERR_OUT_OF_RESOURCE;
     tmp_buf = tmp_buf_raw - gap;
 
-    if (sbuf != MPI_IN_PLACE) {
-        err = ompi_datatype_copy_content_same_ddt(dtype, count, (char *)rbuf,
-                                                  (char *)sbuf);
-        if (MPI_SUCCESS != err) { goto cleanup_and_return; }
+    char *recvbuf = rbuf;
+    if (op_dev != recvbuf_dev && 0 == (MCA_ACCELERATOR_FLAGS_UNIFIED_MEMORY & recvbuf_flags)) {
+        recvbuf = ompi_coll_base_allocate_on_device(op_dev, dsize, module);
+    }
+    if (op_dev != sendbuf_dev && 0 == (MCA_ACCELERATOR_FLAGS_UNIFIED_MEMORY & sendbuf_flags) && sbuf != MPI_IN_PLACE) {
+        /* move the data into the recvbuf and set sbuf to MPI_IN_PLACE */
+        ompi_datatype_copy_content_same_ddt(dtype, count, (char*)recvbuf, (char*)sbuf);
+        sbuf = MPI_IN_PLACE;
     }
 
     /*
@@ -1037,9 +1200,18 @@ int ompi_coll_base_allreduce_intra_redscat_allgather(
     int vrank, step, wsize;
     int nprocs_rem = comm_size - nprocs_pof2;
 
+    opal_accelerator_stream_t *stream = NULL;
+    if (op_dev >= 0) {
+        stream = MCA_ACCELERATOR_STREAM_DEFAULT;
+    }
+
     if (rank < 2 * nprocs_rem) {
         int count_lhalf = count / 2;
         int count_rhalf = count - count_lhalf;
+        const void *send_buf = sbuf;
+        if (MPI_IN_PLACE == sbuf) {
+            send_buf = recvbuf;
+        }
 
         if (rank % 2 != 0) {
             /*
@@ -1047,7 +1219,7 @@ int ompi_coll_base_allreduce_intra_redscat_allgather(
              * Send the left half of the input vector to the left neighbor,
              * Recv the right half of the input vector from the left neighbor
              */
-            err = ompi_coll_base_sendrecv(rbuf, count_lhalf, dtype, rank - 1,
+            err = ompi_coll_base_sendrecv((void*)send_buf, count_lhalf, dtype, rank - 1,
                                           MCA_COLL_BASE_TAG_ALLREDUCE,
                                           (char *)tmp_buf + (ptrdiff_t)count_lhalf * extent,
                                           count_rhalf, dtype, rank - 1,
@@ -1055,12 +1227,24 @@ int ompi_coll_base_allreduce_intra_redscat_allgather(
                                           MPI_STATUS_IGNORE, rank);
             if (MPI_SUCCESS != err) { goto cleanup_and_return; }
 
-            /* Reduce on the right half of the buffers (result in rbuf) */
-            ompi_op_reduce(op, (char *)tmp_buf + (ptrdiff_t)count_lhalf * extent,
-                           (char *)rbuf + count_lhalf * extent, count_rhalf, dtype);
+            /* Reduce on the right half of the buffers (result in rbuf)
+             * We're not using a stream here, the reduction will make sure that the result is available upon return */
+            if (MPI_IN_PLACE != sbuf) {
+                /* rbuf = sbuf (op) tmp_buf */
+                ompi_3buff_op_reduce_stream(op,
+                                            (char *)tmp_buf + (ptrdiff_t)count_lhalf * extent,
+                                            (char *)sbuf + (ptrdiff_t)count_lhalf * extent,
+                                            (char *)recvbuf + count_lhalf * extent,
+                                            count_rhalf, dtype, op_dev, NULL);
+            } else {
+                /* rbuf = rbuf (op) tmp_buf */
+                ompi_op_reduce_stream(op, (char *)tmp_buf + (ptrdiff_t)count_lhalf * extent,
+                                      (char *)recvbuf + count_lhalf * extent, count_rhalf,
+                                      dtype, op_dev, NULL);
+            }
 
             /* Send the right half to the left neighbor */
-            err = MCA_PML_CALL(send((char *)rbuf + (ptrdiff_t)count_lhalf * extent,
+            err = MCA_PML_CALL(send((char *)recvbuf + (ptrdiff_t)count_lhalf * extent,
                                     count_rhalf, dtype, rank - 1,
                                     MCA_COLL_BASE_TAG_ALLREDUCE,
                                     MCA_PML_BASE_SEND_STANDARD, comm));
@@ -1075,7 +1259,7 @@ int ompi_coll_base_allreduce_intra_redscat_allgather(
              * Send the right half of the input vector to the right neighbor,
              * Recv the left half of the input vector from the right neighbor
              */
-            err = ompi_coll_base_sendrecv((char *)rbuf + (ptrdiff_t)count_lhalf * extent,
+            err = ompi_coll_base_sendrecv((char *)send_buf + (ptrdiff_t)count_lhalf * extent,
                                           count_rhalf, dtype, rank + 1,
                                           MCA_COLL_BASE_TAG_ALLREDUCE,
                                           tmp_buf, count_lhalf, dtype, rank + 1,
@@ -1084,21 +1268,35 @@ int ompi_coll_base_allreduce_intra_redscat_allgather(
             if (MPI_SUCCESS != err) { goto cleanup_and_return; }
 
             /* Reduce on the right half of the buffers (result in rbuf) */
-            ompi_op_reduce(op, tmp_buf, rbuf, count_lhalf, dtype);
+            if (MPI_IN_PLACE != sbuf) {
+                /* rbuf = sbuf (op) tmp_buf */
+                ompi_3buff_op_reduce_stream(op, sbuf, tmp_buf, recvbuf, count_lhalf, dtype, op_dev, stream);
+            } else {
+                /* rbuf = rbuf (op) tmp_buf */
+                ompi_op_reduce_stream(op, tmp_buf, recvbuf, count_lhalf, dtype, op_dev, stream);
+            }
+
             /* Recv the right half from the right neighbor */
-            err = MCA_PML_CALL(recv((char *)rbuf + (ptrdiff_t)count_lhalf * extent,
+            err = MCA_PML_CALL(recv((char *)recvbuf + (ptrdiff_t)count_lhalf * extent,
                                     count_rhalf, dtype, rank + 1,
                                     MCA_COLL_BASE_TAG_ALLREDUCE, comm,
                                     MPI_STATUS_IGNORE));
             if (MPI_SUCCESS != err) { goto cleanup_and_return; }
 
+            /* wait for the op to complete */
+            if (NULL != stream) {
+                opal_accelerator.wait_stream(stream);
+            }
+
             vrank = rank / 2;
         }
     } else { /* rank >= 2 * nprocs_rem */
         vrank = rank - nprocs_rem;
     }
 
+    /* At this point the input data has been accumulated into the rbuf */
+
     /*
      * Step 2. Reduce-scatter implemented with recursive vector halving and
      *         recursive distance doubling. We have p' = 2^{\floor{\log_2 p}}
@@ -1155,7 +1353,7 @@ int ompi_coll_base_allreduce_intra_redscat_allgather(
         }
 
         /* Send part of data from the rbuf, recv into the tmp_buf */
-        err = ompi_coll_base_sendrecv((char *)rbuf + (ptrdiff_t)sindex[step] * extent,
+        err = ompi_coll_base_sendrecv((char *)recvbuf + (ptrdiff_t)sindex[step] * extent,
                                       scount[step], dtype, dest,
                                       MCA_COLL_BASE_TAG_ALLREDUCE,
                                       (char *)tmp_buf + (ptrdiff_t)rindex[step] * extent,
@@ -1165,9 +1363,9 @@ int ompi_coll_base_allreduce_intra_redscat_allgather(
         if (MPI_SUCCESS != err) { goto cleanup_and_return; }
 
         /* Local reduce: rbuf[] = tmp_buf[] <op> rbuf[] */
-        ompi_op_reduce(op, (char *)tmp_buf + (ptrdiff_t)rindex[step] * extent,
-                       (char *)rbuf + (ptrdiff_t)rindex[step] * extent,
-                       rcount[step], dtype);
+        ompi_op_reduce_stream(op, (char *)tmp_buf + (ptrdiff_t)rindex[step] * extent,
+                              (char *)recvbuf + (ptrdiff_t)rindex[step] * extent,
+                              rcount[step], dtype, op_dev, NULL);
 
         /* Move the current window to the received message */
         if (step + 1 < nsteps) {
@@ -1201,10 +1399,10 @@ int ompi_coll_base_allreduce_intra_redscat_allgather(
          * Send rcount[step] elements from rbuf[rindex[step]...]
          * Recv scount[step] elements to rbuf[sindex[step]...]
          */
-        err = ompi_coll_base_sendrecv((char *)rbuf + (ptrdiff_t)rindex[step] * extent,
+        err = ompi_coll_base_sendrecv((char *)recvbuf + (ptrdiff_t)rindex[step] * extent,
                                       rcount[step], dtype, dest,
                                       MCA_COLL_BASE_TAG_ALLREDUCE,
-                                      (char *)rbuf + (ptrdiff_t)sindex[step] * extent,
+                                      (char *)recvbuf + (ptrdiff_t)sindex[step] * extent,
                                       scount[step], dtype, dest,
                                       MCA_COLL_BASE_TAG_ALLREDUCE, comm,
                                       MPI_STATUS_IGNORE, rank);
@@ -1216,6 +1414,7 @@ int ompi_coll_base_allreduce_intra_redscat_allgather(
     /*
      * Step 4. Send total result to excluded odd ranks.
      */
+    bool recvbuf_need_copy = true;
     if (rank < 2 * nprocs_rem) {
         if (rank % 2 != 0) {
             /* Odd process -- recv result from rank - 1 */
@@ -1223,19 +1422,28 @@ int ompi_coll_base_allreduce_intra_redscat_allgather(
                                     MCA_COLL_BASE_TAG_ALLREDUCE, comm,
                                     MPI_STATUS_IGNORE));
             if (OMPI_SUCCESS != err) { goto cleanup_and_return; }
+            recvbuf_need_copy = false;
         } else {   /* Even process -- send result to rank + 1 */
-            err = MCA_PML_CALL(send(rbuf, count, dtype, rank + 1,
+            err = MCA_PML_CALL(send(recvbuf, count, dtype, rank + 1,
                                     MCA_COLL_BASE_TAG_ALLREDUCE,
                                     MCA_PML_BASE_SEND_STANDARD, comm));
             if (MPI_SUCCESS != err) { goto cleanup_and_return; }
         }
     }
 
+    if (recvbuf != rbuf) {
+        /* copy into final rbuf */
+        if (recvbuf_need_copy) {
+            ompi_datatype_copy_content_same_ddt(dtype, count, (char*)rbuf, (char*)recvbuf);
+        }
+        ompi_coll_base_free_tmpbuf(recvbuf, op_dev, module);
+    }
+
 cleanup_and_return:
-    if (NULL != tmp_buf_raw)
-        free(tmp_buf_raw);
+
+    ompi_coll_base_free_tmpbuf(tmp_buf_raw, op_dev, module);
     if (NULL != rindex)
         free(rindex);
     if (NULL != sindex)
diff --git a/ompi/mca/coll/base/coll_base_frame.c b/ompi/mca/coll/base/coll_base_frame.c
index 07b7f85cf92..36e1c428ab3 100644
--- a/ompi/mca/coll/base/coll_base_frame.c
+++ b/ompi/mca/coll/base/coll_base_frame.c
@@ -32,6 +32,9 @@
 #include "opal/util/output.h"
 #include "opal/mca/base/base.h"
 #include "opal/mca/base/mca_base_alias.h"
+#include "opal/mca/accelerator/accelerator.h"
+
+
 #include "ompi/mca/coll/coll.h"
 #include "ompi/mca/coll/base/base.h"
 #include "ompi/mca/coll/base/coll_base_functions.h"
@@ -70,6 +73,8 @@ static void
 coll_base_comm_construct(mca_coll_base_comm_t *data)
 {
     memset ((char *) data + sizeof (data->super), 0, sizeof (*data) - sizeof (data->super));
+    data->device_allocators = NULL;
+    data->num_device_allocators = 0;
 }
 
 static void
@@ -108,6 +113,16 @@ coll_base_comm_destruct(mca_coll_base_comm_t *data)
     if (data->cached_in_order_bintree) { /* destroy in order bintree if defined */
         ompi_coll_base_topo_destroy_tree (&data->cached_in_order_bintree);
     }
+
+    if (NULL != data->device_allocators) {
+        for (int i = 0; i < data->num_device_allocators; ++i) {
+            if (NULL != data->device_allocators[i]) {
+                data->device_allocators[i]->alc_finalize(data->device_allocators[i]);
+            }
+        }
+        free(data->device_allocators);
+        data->device_allocators = NULL;
+    }
 }
 
 OBJ_CLASS_INSTANCE(mca_coll_base_comm_t, opal_object_t,
diff --git a/ompi/mca/coll/base/coll_base_functions.h b/ompi/mca/coll/base/coll_base_functions.h
index ae924de5d31..2303aa8665d 100644
--- a/ompi/mca/coll/base/coll_base_functions.h
+++ b/ompi/mca/coll/base/coll_base_functions.h
@@ -40,6 +40,8 @@
 /* need to include our own topo prototypes so we can malloc data on the comm correctly */
 #include "coll_base_topo.h"
 
+#include "opal/mca/allocator/allocator.h"
+
 /* some fixed value index vars to simplify certain operations */
 typedef enum COLLTYPE {
     ALLGATHER = 0,       /*  0 */
@@ -516,6 +518,10 @@ struct mca_coll_base_comm_t {
 
     /* in-order binary tree (root of the in-order binary tree is rank 0) */
     ompi_coll_tree_t *cached_in_order_bintree;
+
+    /* pointer to per-device memory cache */
+    mca_allocator_base_module_t **device_allocators;
+    int num_device_allocators;
 };
 typedef struct mca_coll_base_comm_t mca_coll_base_comm_t;
 OMPI_DECLSPEC OBJ_CLASS_DECLARATION(mca_coll_base_comm_t);
diff --git a/ompi/mca/coll/base/coll_base_util.c b/ompi/mca/coll/base/coll_base_util.c
index ae9010497d7..9dbad4e9f0f 100644
--- a/ompi/mca/coll/base/coll_base_util.c
+++ b/ompi/mca/coll/base/coll_base_util.c
@@ -31,6 +31,7 @@
#include "ompi/mca/pml/pml.h" #include "coll_base_util.h" #include "coll_base_functions.h" +#include "opal/mca/allocator/base/base.h" #include int ompi_coll_base_sendrecv_actual( const void* sendbuf, size_t scount, @@ -603,3 +604,58 @@ const char* mca_coll_base_colltype_to_str(int collid) } return colltype_translation_table[collid]; } + +static void* ompi_coll_base_device_allocate_cb(void *ctx, size_t *size) { + int dev_id = (intptr_t)ctx; + void *ptr = NULL; + opal_accelerator.mem_alloc(dev_id, &ptr, *size); + return ptr; +} + +static void ompi_coll_base_device_release_cb(void *ctx, void* ptr) { + int dev_id = (intptr_t)ctx; + opal_accelerator.mem_release(dev_id, ptr); +} + +void *ompi_coll_base_allocate_on_device(int device, size_t size, + mca_coll_base_module_t *module) +{ + mca_allocator_base_module_t *allocator_module; + if (device < 0) { + return malloc(size); + } + + if (module->base_data->num_device_allocators <= device) { + int num_dev; + opal_accelerator.num_devices(&num_dev); + if (num_dev < device+1) num_dev = device+1; + module->base_data->device_allocators = realloc(module->base_data->device_allocators, num_dev * sizeof(mca_allocator_base_module_t *)); + for (int i = module->base_data->num_device_allocators; i < num_dev; ++i) { + module->base_data->device_allocators[i] = NULL; + } + module->base_data->num_device_allocators = num_dev; + } + if (NULL == (allocator_module = module->base_data->device_allocators[device])) { + mca_allocator_base_component_t *allocator_component; + allocator_component = mca_allocator_component_lookup("devicebucket"); + assert(allocator_component != NULL); + allocator_module = allocator_component->allocator_init(false, ompi_coll_base_device_allocate_cb, + ompi_coll_base_device_release_cb, + (void*)(intptr_t)device); + assert(allocator_module != NULL); + module->base_data->device_allocators[device] = allocator_module; + } + return allocator_module->alc_alloc(allocator_module, size, 0); +} + +void ompi_coll_base_free_on_device(int device, void *ptr, mca_coll_base_module_t *module) +{ + mca_allocator_base_module_t *allocator_module; + if (device < 0) { + free(ptr); + } else { + assert(NULL != module->base_data->device_allocators); + allocator_module = module->base_data->device_allocators[device]; + allocator_module->alc_free(allocator_module, ptr); + } +} diff --git a/ompi/mca/coll/base/coll_base_util.h b/ompi/mca/coll/base/coll_base_util.h index 852abcedefa..dd2ecdee1c7 100644 --- a/ompi/mca/coll/base/coll_base_util.h +++ b/ompi/mca/coll/base/coll_base_util.h @@ -31,6 +31,7 @@ #include "ompi/mca/coll/base/coll_tags.h" #include "ompi/op/op.h" #include "ompi/mca/pml/pml.h" +#include "opal/mca/accelerator/accelerator.h" BEGIN_C_DECLS @@ -200,5 +201,48 @@ int ompi_coll_base_file_peek_next_char_is(FILE *fptr, int *fileline, int expecte const char* mca_coll_base_colltype_to_str(int collid); int mca_coll_base_name_to_colltype(const char* name); +/* device/host memory allocation functions */ + + +void *ompi_coll_base_allocate_on_device(int device, size_t size, + mca_coll_base_module_t *module); + +void ompi_coll_base_free_on_device(int device, void *ptr, mca_coll_base_module_t *module); + + +static inline +void ompi_coll_base_select_device( + struct ompi_op_t *op, + const void *sendbuf, + const void *recvbuf, + size_t count, + struct ompi_datatype_t *dtype, + int *sendbuf_device, + int *recvbuf_device, + uint64_t *sendbuf_flags, + uint64_t *recvbuf_flags, + int *op_device) +{ + *recvbuf_device = MCA_ACCELERATOR_NO_DEVICE_ID; + *sendbuf_device = 
+    *sendbuf_device = MCA_ACCELERATOR_NO_DEVICE_ID;
+    if (sendbuf != NULL && sendbuf != MPI_IN_PLACE) opal_accelerator.check_addr(sendbuf, sendbuf_device, sendbuf_flags);
+    if (recvbuf != NULL) opal_accelerator.check_addr(recvbuf, recvbuf_device, recvbuf_flags);
+    ompi_op_preferred_device(op, *recvbuf_device, *sendbuf_device, count, dtype, op_device);
+}
+
+/**
+ * Frees memory allocated through ompi_coll_base_allocate_on_device
+ * (or malloc, for host-side allocations).
+ */
+static inline
+void ompi_coll_base_free_tmpbuf(void *tmpbuf, int device, mca_coll_base_module_t *module) {
+    if (-1 == device) {
+        free(tmpbuf);
+    } else if (NULL != tmpbuf) {
+        ompi_coll_base_free_on_device(device, tmpbuf, module);
+    }
+}
+
+
 END_C_DECLS
 
 #endif /* MCA_COLL_BASE_UTIL_EXPORT_H */
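
For reference, the pattern that the four allreduce algorithms above now share, written out as a stand-alone sketch. The helper names, the *_stream op entry points, and MCA_ACCELERATOR_STREAM_DEFAULT are the ones introduced in this diff; reduce_local_on_preferred_device itself is a hypothetical caller, not part of the patch, shown only to illustrate how the pieces compose (placement query, per-device scratch allocation, streamed reduction, synchronization, release).

/* Hypothetical caller -- illustrative only, not part of the patch. */
static int reduce_local_on_preferred_device(struct ompi_op_t *op,
                                            const void *sbuf, void *rbuf,
                                            size_t count,
                                            struct ompi_datatype_t *dtype,
                                            mca_coll_base_module_t *module)
{
    int sendbuf_dev, recvbuf_dev, op_dev;
    uint64_t sendbuf_flags, recvbuf_flags;

    /* Where do the buffers live, and where does the op prefer to run? */
    ompi_coll_base_select_device(op, sbuf, rbuf, count, dtype,
                                 &sendbuf_dev, &recvbuf_dev,
                                 &sendbuf_flags, &recvbuf_flags, &op_dev);

    /* Device ops are enqueued on the default accelerator stream;
     * host ops (op_dev == MCA_ACCELERATOR_NO_DEVICE_ID) run synchronously. */
    opal_accelerator_stream_t *stream = NULL;
    if (op_dev >= 0) {
        stream = MCA_ACCELERATOR_STREAM_DEFAULT;
    }

    /* Scratch space comes from the per-device allocator cached on the
     * communicator (plain malloc when op_dev is the host). */
    ptrdiff_t gap = 0;
    ptrdiff_t span = opal_datatype_span(&dtype->super, count, &gap);
    char *tmp_free = ompi_coll_base_allocate_on_device(op_dev, span, module);
    if (NULL == tmp_free) {
        return OMPI_ERR_OUT_OF_RESOURCE;
    }
    char *tmp = tmp_free - gap;

    /* tmp = sbuf, then rbuf = tmp (op) rbuf, both on the selected device. */
    ompi_datatype_copy_content_same_ddt_stream(dtype, count, tmp, (char *)sbuf, stream);
    ompi_op_reduce_stream(op, tmp, rbuf, count, dtype, op_dev, stream);

    /* Synchronize before the result is handed to the PML or reused. */
    if (NULL != stream) {
        opal_accelerator.wait_stream(stream);
    }
    ompi_coll_base_free_tmpbuf(tmp_free, op_dev, module);
    return MPI_SUCCESS;
}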