Skip to content

Commit

Permalink
ch4/mpidig: cleanup and fix anysrc receive
Browse files Browse the repository at this point in the history
We should call MPIDI_anysrc_try_cancel_partner upon matching, and call
MPIDI_anysrc_free_partner upon completion.

* In MPIDIG_do_irecv, we called MPIDI_anysrc_try_cancel_partner upon
matching an unexpected request, but only in the has_request branch. We
should do that for both branches.

* MPIDI_anysrc_free_partner needs to copy the status when the partner
request is the user-visible request. Thus we should only free the partner
when rreq completes. Unfortunately, we have multiple completion
branches:
  1. Eager unexpected path
  2. recv_target_cmpl_cb
  3. MPIDIG_tag_recv_complete
  • Loading branch information
hzhou committed Oct 27, 2024
1 parent 99b49e4 commit fa1728b
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 30 deletions.
15 changes: 3 additions & 12 deletions src/mpid/ch4/src/ch4_recv.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,18 +38,9 @@ MPL_STATIC_INLINE_PREFIX int anysource_irecv(void *buf, MPI_Aint count, MPI_Data
mpi_errno = MPIDI_NM_mpi_irecv(buf, count, datatype, rank, tag, comm, attr,
av, &nm_rreq, *request);
MPIR_ERR_CHECK(mpi_errno);
(*request)->dev.anysrc_partner = nm_rreq;

/* cancel the shm request if netmod/am handles the request from unexpected queue. */
if (MPIR_Request_is_complete(nm_rreq)) {
mpi_errno = MPIDI_SHM_mpi_cancel_recv(*request);
if (MPIR_STATUS_GET_CANCEL_BIT((*request)->status)) {
(*request)->status = nm_rreq->status;
}
/* nm_rreq will be freed here. User-layer will have a completed (cancelled)
* request with correct status. */
MPIDI_CH4_REQUEST_FREE(nm_rreq);
goto fn_exit;
/* if netmod recv is not matched yet, attach it to the shm request's partner */
if (nm_rreq->dev.anysrc_partner) {
(*request)->dev.anysrc_partner = nm_rreq;
}
}
fn_exit:
Expand Down
9 changes: 6 additions & 3 deletions src/mpid/ch4/src/mpidig_pt2pt_callbacks.c
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ int MPIDIG_tag_recv_complete(MPIR_Request * rreq, MPI_Status * status)
MPIR_STATUS_COPY_COUNT(rreq->status, *status);

MPIR_Datatype_release_if_not_builtin(MPIDIG_REQUEST(rreq, datatype));
#ifndef MPIDI_CH4_DIRECT_NETMOD
MPIDI_anysrc_free_partner(rreq);
#endif
MPID_Request_complete(rreq);

MPIR_FUNC_EXIT;
Expand Down Expand Up @@ -161,6 +164,9 @@ static int recv_target_cmpl_cb(MPIR_Request * rreq)
/* Free the unexpected request on behalf of the user */
MPIDI_CH4_REQUEST_FREE(rreq);
}
#ifndef MPIDI_CH4_DIRECT_NETMOD
MPIDI_anysrc_free_partner(rreq);
#endif
MPID_Request_complete(rreq);
fn_exit:
MPIR_FUNC_EXIT;
Expand Down Expand Up @@ -422,9 +428,6 @@ int MPIDIG_send_target_msg_cb(void *am_hdr, void *data, MPI_Aint in_data_sz,
mpi_errno = MPIDIG_reply_ssend(rreq);
MPIR_ERR_CHECK(mpi_errno);
}
#ifndef MPIDI_CH4_DIRECT_NETMOD
MPIDI_anysrc_free_partner(rreq);
#endif

MPIDIG_recv_type_init(hdr->data_sz, rreq);

Expand Down
34 changes: 19 additions & 15 deletions src/mpid/ch4/src/mpidig_recv.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,9 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_handle_unexpected(void *buf, MPI_Aint count,
MPIR_ERR_CHECK(mpi_errno);
MPIDIG_REQUEST(rreq, req->status) &= ~MPIDIG_REQ_UNEXPECTED;

#ifndef MPIDI_CH4_DIRECT_NETMOD
MPIDI_anysrc_free_partner(rreq);
#endif
MPID_Request_complete(rreq);
} else {
/* This is the path for async data copy still need to happen. The request will be completed
Expand Down Expand Up @@ -182,7 +185,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_handle_unexp_mrecv(MPIR_Request * rreq)
MPL_STATIC_INLINE_PREFIX int MPIDIG_do_irecv(void *buf, MPI_Aint count, MPI_Datatype datatype,
int rank, int tag, MPIR_Comm * comm,
int context_offset, int vci, MPIR_Request ** request,
bool is_local, uint64_t flags)
bool is_local, MPIR_Request * partner, uint64_t flags)
{
int mpi_errno = MPI_SUCCESS;
MPIR_Request *rreq = NULL, *unexp_req = NULL;
Expand Down Expand Up @@ -225,17 +228,6 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_do_irecv(void *buf, MPI_Aint count, MPI_Data
* Record the passed `*request` to `match_req` so that we can complete it
* later when `unexp_req` completes.
* See MPIDI_recv_target_cmpl_cb for actual completion handler. */
#ifndef MPIDI_CH4_DIRECT_NETMOD
MPIR_Request *match_req = *request;
int is_cancelled;
mpi_errno = MPIDI_anysrc_try_cancel_partner(match_req, &is_cancelled);
MPIR_ERR_CHECK(mpi_errno);
/* since we will always progress shm first, when unexpected
* message match, the NM partner wouldn't have progressed yet, so the cancel
* should always succeed. */
MPIR_Assert(is_cancelled);
MPIDI_anysrc_free_partner(match_req);
#endif
MPIDIG_REQUEST(unexp_req, req->rreq.match_req) = *request;
MPIDIG_REQUEST(*request, req->remote_vci) = MPIDIG_REQUEST(unexp_req, req->remote_vci);
/* the tag and source in status are set at the time of receiving RTS, copy it from unexp_req */
Expand All @@ -244,6 +236,19 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_do_irecv(void *buf, MPI_Aint count, MPI_Data
MPIDIG_REQUEST(unexp_req, req->status) |= MPIDIG_REQ_MATCHED;
MPIDIG_REQUEST(*request, req->status) |= MPIDIG_REQ_IN_PROGRESS;

#ifndef MPIDI_CH4_DIRECT_NETMOD
rreq = *request;
MPIDI_REQUEST_SET_LOCAL(rreq, is_local, partner);

int is_cancelled;
mpi_errno = MPIDI_anysrc_try_cancel_partner(rreq, &is_cancelled);
MPIR_ERR_CHECK(mpi_errno);
/* since we will always progress shm first, when unexpected
* message match, the NM partner wouldn't have progressed yet, so the cancel
* should always succeed. */
MPIR_Assert(is_cancelled);
#endif

if (MPIDIG_REQUEST(unexp_req, req->status) & MPIDIG_REQ_BUSY) {
/* Nothing to do here. MPIDIG_handle_unexpected etc. in mpidig_pt2pt_callbacks.c */
} else if (MPIDIG_REQUEST(unexp_req, req->status) & MPIDIG_REQ_RTS) {
Expand Down Expand Up @@ -279,6 +284,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_do_irecv(void *buf, MPI_Aint count, MPI_Data
}

*request = rreq;
MPIDI_REQUEST_SET_LOCAL(rreq, is_local, partner);

MPIR_Datatype_add_ref_if_not_builtin(datatype);
MPIDIG_prepare_recv_req(rank, tag, context_id, buf, count, datatype, rreq);
Expand Down Expand Up @@ -351,11 +357,9 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_mpi_irecv(void *buf,
MPIR_FUNC_ENTER;

mpi_errno = MPIDIG_do_irecv(buf, count, datatype, rank, tag, comm, context_offset, vci,
request, is_local, 0ULL);
request, is_local, partner, 0ULL);
MPIR_ERR_CHECK(mpi_errno);

MPIDI_REQUEST_SET_LOCAL(*request, is_local, partner);

fn_exit:
MPIR_FUNC_EXIT;
return mpi_errno;
Expand Down

0 comments on commit fa1728b

Please sign in to comment.