comm/cid: use ibcast to distribute result in intercomm case #2061

Merged
merged 1 commit into from Sep 7, 2016

52 changes: 17 additions & 35 deletions ompi/communicator/comm_cid.c
@@ -103,10 +103,6 @@ struct ompi_comm_allreduce_context_t {
     ompi_comm_cid_context_t *cid_context;
     int *tmpbuf;
 
-    /* for intercomm allreduce */
-    int *rcounts;
-    int *rdisps;
-
     /* for group allreduce */
     int peers_comm[3];
 };
@@ -121,8 +117,6 @@ static void ompi_comm_allreduce_context_construct (ompi_comm_allreduce_context_t
 static void ompi_comm_allreduce_context_destruct (ompi_comm_allreduce_context_t *context)
 {
     free (context->tmpbuf);
-    free (context->rcounts);
-    free (context->rdisps);
 }
 
 OBJ_CLASS_INSTANCE (ompi_comm_allreduce_context_t, opal_object_t,
@@ -602,7 +596,7 @@ static int ompi_comm_allreduce_intra_nb (int *inbuf, int *outbuf, int count, str
 /* Non-blocking version of ompi_comm_allreduce_inter */
 static int ompi_comm_allreduce_inter_leader_exchange (ompi_comm_request_t *request);
 static int ompi_comm_allreduce_inter_leader_reduce (ompi_comm_request_t *request);
-static int ompi_comm_allreduce_inter_allgather (ompi_comm_request_t *request);
+static int ompi_comm_allreduce_inter_bcast (ompi_comm_request_t *request);
 
 static int ompi_comm_allreduce_inter_nb (int *inbuf, int *outbuf,
                                          int count, struct ompi_op_t *op,
@@ -636,18 +630,19 @@ static int ompi_comm_allreduce_inter_nb (int *inbuf, int *outbuf,
     rsize = ompi_comm_remote_size (intercomm);
     local_rank = ompi_comm_rank (intercomm);
 
-    context->tmpbuf = (int *) calloc (count, sizeof(int));
-    context->rdisps = (int *) calloc (rsize, sizeof(int));
-    context->rcounts = (int *) calloc (rsize, sizeof(int));
-    if (OPAL_UNLIKELY (NULL == context->tmpbuf || NULL == context->rdisps || NULL == context->rcounts)) {
-        ompi_comm_request_return (request);
-        return OMPI_ERR_OUT_OF_RESOURCE;
+    if (0 == local_rank) {
+        context->tmpbuf = (int *) calloc (count, sizeof(int));
+        if (OPAL_UNLIKELY (NULL == context->tmpbuf)) {
+            ompi_comm_request_return (request);
+            return OMPI_ERR_OUT_OF_RESOURCE;
+        }
     }
 
     /* Execute the inter-allreduce: the result from the local will be in the buffer of the remote group
      * and vise-versa. */
-    rc = intercomm->c_coll.coll_iallreduce (inbuf, context->tmpbuf, count, MPI_INT, op, intercomm,
-                                            &subreq, intercomm->c_coll.coll_iallreduce_module);
+    rc = intercomm->c_local_comm->c_coll.coll_ireduce (inbuf, context->tmpbuf, count, MPI_INT, op, 0,
+                                                       intercomm->c_local_comm, &subreq,
+                                                       intercomm->c_local_comm->c_coll.coll_ireduce_module);
     if (OPAL_UNLIKELY(OMPI_SUCCESS != rc)) {
         ompi_comm_request_return (request);
         return rc;
@@ -656,7 +651,7 @@
     if (0 == local_rank) {
         ompi_comm_request_schedule_append (request, ompi_comm_allreduce_inter_leader_exchange, &subreq, 1);
     } else {
-        ompi_comm_request_schedule_append (request, ompi_comm_allreduce_inter_allgather, &subreq, 1);
+        ompi_comm_request_schedule_append (request, ompi_comm_allreduce_inter_bcast, &subreq, 1);
     }
 
     ompi_comm_request_start (request);
@@ -696,33 +691,20 @@ static int ompi_comm_allreduce_inter_leader_reduce (ompi_comm_request_t *request
 
     ompi_op_reduce (context->op, context->tmpbuf, context->outbuf, context->count, MPI_INT);
 
-    return ompi_comm_allreduce_inter_allgather (request);
+    return ompi_comm_allreduce_inter_bcast (request);
 }
 
 
-static int ompi_comm_allreduce_inter_allgather (ompi_comm_request_t *request)
+static int ompi_comm_allreduce_inter_bcast (ompi_comm_request_t *request)
 {
     ompi_comm_allreduce_context_t *context = (ompi_comm_allreduce_context_t *) request->context;
-    ompi_communicator_t *intercomm = context->cid_context->comm;
+    ompi_communicator_t *comm = context->cid_context->comm->c_local_comm;
     ompi_request_t *subreq;
     int scount = 0, rc;
 
-    /* distribute the overall result to all processes in the other group.
-       Instead of using bcast, we are using here allgatherv, to avoid the
-       possible deadlock. Else, we need an algorithm to determine,
-       which group sends first in the inter-bcast and which receives
-       the result first.
-    */
-
-    if (0 != ompi_comm_rank (intercomm)) {
-        context->rcounts[0] = context->count;
-    } else {
-        scount = context->count;
-    }
-
-    rc = intercomm->c_coll.coll_iallgatherv (context->outbuf, scount, MPI_INT, context->outbuf,
-                                             context->rcounts, context->rdisps, MPI_INT, intercomm,
-                                             &subreq, intercomm->c_coll.coll_iallgatherv_module);
+    /* both roots have the same result. broadcast to the local group */
+    rc = comm->c_coll.coll_ibcast (context->outbuf, context->count, MPI_INT, 0, comm,
+                                   &subreq, comm->c_coll.coll_ibcast_module);
     if (OMPI_SUCCESS != rc) {
         return rc;
     }
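
For context, the patch replaces the intercomm iallgatherv-based distribution with a three-phase scheme: each group reduces its contribution to its local root (rank 0), the two roots exchange and combine their partial results over the intercommunicator (the existing leader_exchange/leader_reduce steps), and each root then broadcasts the combined value within its own group. Below is a minimal blocking sketch of the same flow written against the public MPI API rather than ompi's internal non-blocking collective framework; the function name and the explicit localcomm parameter are illustrative assumptions (inside ompi the local group is intercomm->c_local_comm).

#include <mpi.h>
#include <stdlib.h>

/* Sketch only: blocking equivalent of the reduce/exchange/bcast scheme this
 * patch adopts. 'localcomm' is the local group of 'intercomm'. */
static int intercomm_allreduce_sketch (int *inbuf, int *outbuf, int count,
                                       MPI_Op op, MPI_Comm localcomm,
                                       MPI_Comm intercomm)
{
    int local_rank;
    int *tmpbuf = NULL;

    MPI_Comm_rank (localcomm, &local_rank);

    /* phase 1: reduce within the local group; only the root needs tmpbuf */
    if (0 == local_rank) {
        tmpbuf = (int *) calloc (count, sizeof (int));
        if (NULL == tmpbuf) {
            return MPI_ERR_NO_MEM;
        }
    }
    MPI_Reduce (inbuf, tmpbuf, count, MPI_INT, op, 0, localcomm);

    if (0 == local_rank) {
        /* phase 2: the two roots swap partial results. On an
         * intercommunicator, point-to-point ranks address the remote
         * group, so rank 0 here is the remote group's root. */
        MPI_Sendrecv (tmpbuf, count, MPI_INT, 0, 0,
                      outbuf, count, MPI_INT, 0, 0,
                      intercomm, MPI_STATUS_IGNORE);

        /* combine: outbuf = tmpbuf op outbuf; for a commutative op such as
         * the MAX used in CID selection, both roots get the same result */
        MPI_Reduce_local (tmpbuf, outbuf, count, MPI_INT, op);
        free (tmpbuf);
    }

    /* phase 3: broadcast within the local group only, which is what
     * ompi_comm_allreduce_inter_bcast now does with ibcast */
    return MPI_Bcast (outbuf, count, MPI_INT, 0, localcomm);
}

Because each broadcast stays inside one local group, neither group's broadcast waits on the other, which sidesteps the send-first/receive-first ordering problem that the removed comment gave as the reason for using allgatherv on the intercommunicator.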