Track, for every continuation request, how many attached operations are still
pending on each VCI, and use those counts to choose the VCIs to progress when
a continuation request is passed to MPIDI_set_progress_vci{,_n}.

Note: whitespace inside the hunks (indentation, blank context lines) and the
closing-brace context line that ends the last ch4_wait.h hunk were
reconstructed; verify them against the tree before applying.

diff --git a/src/include/mpir_request.h b/src/include/mpir_request.h
index da6713f1d3f..3af36b560e1 100644
--- a/src/include/mpir_request.h
+++ b/src/include/mpir_request.h
@@ -244,6 +244,7 @@ struct MPIR_Request {
                 struct MPIR_Continue *head, *tail;
                 MPID_Thread_mutex_t lock;
             } ready_poll_only_cont_list;
+            MPID_Progress_state_cnt *state;     /* per-vci refcounts, owned by this request */
         } cont;
         /* Reserve space for local usages. For example, threadcomm, the actual struct
          * is defined locally and is used via casting */
@@ -731,7 +732,7 @@ MPL_STATIC_INLINE_PREFIX void MPIR_Request_cb_init(MPIR_Request * req)
 MPL_STATIC_INLINE_PREFIX void MPIR_Request_cb_free(MPIR_Request * req)
 {
     int err;
-    MPID_Thread_mutex_destroy(&req->cbs_lock, &err)
+    MPID_Thread_mutex_destroy(&req->cbs_lock, &err);
     MPIR_Assert(!err);
     // free all the persistent callbacks
     while (req->cbs.head) {
diff --git a/src/mpi/continue/continue_impl.c b/src/mpi/continue/continue_impl.c
index 7954ae5689c..b580162af32 100644
--- a/src/mpi/continue/continue_impl.c
+++ b/src/mpi/continue/continue_impl.c
@@ -35,7 +35,7 @@ __thread struct {
     struct MPIR_Continue *head, *tail;
 } tls_deferred_cont_list = {NULL, NULL};
 
-void complete_op_request(MPIR_Request *op_request, bool in_cs, void *cb_context, bool defer_complete);
+void complete_op_request(MPIR_Request *op_request, bool in_cs, void *cb_context, bool defer_complete, bool in_request_callback);
 void MPIR_Continue_callback(MPIR_Request *op_request, bool in_cs, void *cb_context);
 void attach_continue_context(MPIR_Continue_context *context_ptr, bool defer_complete);
 
@@ -80,12 +80,18 @@ int MPIR_Continue_init_impl(int flags, int max_poll,
     cont_req->u.cont.ready_poll_only_cont_list.tail = NULL;
     cont_req->u.cont.is_pool_only = flags & MPIX_CONT_POLL_ONLY;
     cont_req->u.cont.max_poll = max_poll;
+    cont_req->u.cont.state = (MPID_Progress_state_cnt *) MPL_malloc(sizeof(MPID_Progress_state_cnt), MPL_MEM_OTHER);
+    MPIR_Assert(cont_req->u.cont.state != NULL);
+    for (int i = 0; i < MPIDI_CH4_MAX_VCIS; ++i) {
+        MPL_atomic_release_store_int64(&cont_req->u.cont.state->vci_refcount[i].val, 0);
+    }
     *cont_req_ptr = cont_req;
     return MPI_SUCCESS;
 }
 
 void MPIR_Continue_destroy_impl(MPIR_Request *cont_req)
 {
     MPIR_Assert(cont_req->kind == MPIR_REQUEST_KIND__CONTINUE);
+    MPL_free(cont_req->u.cont.state);
     {
         int err;
@@ -122,10 +128,15 @@ int MPIR_Continue_start(MPIR_Request * cont_request_ptr)
 }
 
 void attach_continue_context(MPIR_Continue_context *context_ptr, bool defer_complete)
 {
+    /* record the corresponding VCI for the continuation request to progress */
+    if (context_ptr->continue_ptr->cont_req) {
+        int vci = MPIDI_Request_get_vci(context_ptr->op_request);
+        MPL_atomic_fetch_add_int64(&context_ptr->continue_ptr->cont_req->u.cont.state->vci_refcount[vci].val, 1);
+    }
     /* Attach the continue context to the op request */
     if (!MPIR_Register_callback(context_ptr->op_request, MPIR_Continue_callback, context_ptr, false)) {
         /* the request has already been completed. */
-        complete_op_request(context_ptr->op_request, false, context_ptr, defer_complete);
+        complete_op_request(context_ptr->op_request, false, context_ptr, defer_complete, false);
     }
 }
@@ -207,23 +218,28 @@ void execute_continue(MPIR_Continue *continue_ptr, bool in_cs, int which_cs)
     continue_ptr->cb(MPI_SUCCESS, continue_ptr->cb_data);
     MPL_free(continue_ptr);
     /* Signal the continuation request */
-    /* TODO: Find a suitable request complete function */
+    /* TODO: Find a suitable request complete function for continuation requests */
     if (cont_req_ptr) {
         int incomplete;
         MPIR_cc_decr(cont_req_ptr->cc_ptr, &incomplete);
         if (!incomplete) {
-            /* All the continue callbacks associated with this continuation request have completed */
-            /* TODO: reason about how to invoke the callback for continuation request */
+            /* TODO: reason about the safety of invoking the callback for continuation request here */
             // MPIR_Invoke_callback(cont_req_ptr, false);
-            MPIR_Request_free_with_safety(cont_req_ptr, !(in_cs && MPIR_REQUEST_POOL(cont_req_ptr) == which_cs));
+            MPIR_Request_free_with_safety(cont_req_ptr, !(in_cs && MPIR_REQUEST_POOL(cont_req_ptr) == which_cs), NULL);
         }
     }
 }
 
-void complete_op_request(MPIR_Request *op_request, bool in_cs, void *cb_context, bool defer_complete)
+void complete_op_request(MPIR_Request *op_request, bool in_cs, void *cb_context, bool defer_complete, bool in_request_callback)
 {
     MPIR_Continue_context *context_ptr = (MPIR_Continue_context *) cb_context;
     MPIR_Continue *continue_ptr = context_ptr->continue_ptr;
+    /* Decrease the continuation request VCI counter */
+    MPIR_Request *cont_req_ptr = continue_ptr->cont_req;
+    if (cont_req_ptr) {
+        int vci = MPIDI_Request_get_vci(op_request);
+        MPL_atomic_fetch_sub_int64(&cont_req_ptr->u.cont.state->vci_refcount[vci].val, 1);
+    }
     /* Complete this operation request */
     /* FIXME: MPIR_Request_completion_processing can call MPIR_Request_free,
      * which might lead to deadlock */
@@ -232,7 +248,7 @@ void complete_op_request(MPIR_Request *op_request, bool in_cs, void *cb_context,
     if (context_ptr->status_ptr != MPI_STATUS_IGNORE)
         context_ptr->status_ptr->MPI_ERROR = rc;
     if (!MPIR_Request_is_persistent(op_request)) {
-        MPIR_Request_free_with_safety(op_request, !in_cs);
+        MPIR_Request_free_with_safety(op_request, !in_cs, NULL);
     }
     MPL_free(context_ptr);
     /* Signal the continue callback */
@@ -271,7 +287,7 @@ void complete_op_request(MPIR_Request *op_request, bool in_cs, void *cb_context,
 
 void MPIR_Continue_callback(MPIR_Request *op_request, bool in_cs, void *cb_context)
 {
-    complete_op_request(op_request, in_cs, cb_context, false);
+    complete_op_request(op_request, in_cs, cb_context, false, true);
 }
 
 int MPIR_Continue_progress_tls()
diff --git a/src/mpid/ch4/include/mpidpre.h b/src/mpid/ch4/include/mpidpre.h
index 0b9ec6e1df9..dca858f7cde 100644
--- a/src/mpid/ch4/include/mpidpre.h
+++ b/src/mpid/ch4/include/mpidpre.h
@@ -57,6 +57,20 @@ typedef struct {
     uint8_t vci[MPIDI_CH4_MAX_VCIS];    /* list of vcis that need progress */
 } MPID_Progress_state;
 
+/* A 64-bit atomic counter followed by 56 bytes of padding, presumably so that
+ * neighbouring array elements land on separate cache lines (this assumes an
+ * 8-byte atomic and a 64-byte line -- TODO confirm for all targets) */
+typedef struct {
+    MPL_atomic_int64_t val;
+    char padding[56];
+} MPL_padded_atomic_int64_t;
+
+/* Per-VCI reference counts kept by a continuation request */
+typedef struct {
+    /* number of attached operations on each vci that have not completed yet */
+    MPL_padded_atomic_int64_t vci_refcount[MPIDI_CH4_MAX_VCIS];
+} MPID_Progress_state_cnt;
+
 typedef enum {
     MPIDI_PTYPE_RECV,
     MPIDI_PTYPE_SEND,
diff --git a/src/mpid/ch4/src/ch4_wait.h b/src/mpid/ch4/src/ch4_wait.h
index d6456a26c2b..36cc450d286 100644
--- a/src/mpid/ch4/src/ch4_wait.h
+++ b/src/mpid/ch4/src/ch4_wait.h
@@ -8,15 +8,47 @@
 
 #include "ch4_impl.h"
 
+/* Append vci to state->vci (bumping state->vci_count) unless it is already listed */
+MPL_STATIC_INLINE_PREFIX void MPIDI_add_vci_to_state(int vci,
+                                                     MPID_Progress_state * state)
+{
+    MPIR_Assert(vci < MPIDI_CH4_MAX_VCIS);
+    for (int i = 0; i < state->vci_count; ++i) {
+        if (state->vci[i] == vci) {
+            return;
+        }
+    }
+    MPIR_Assert(state->vci_count < MPIDI_CH4_MAX_VCIS);
+    state->vci[state->vci_count++] = vci;
+}
+
+/* Add to state every vci on which the continuation request req currently has a
+ * positive refcount */
+MPL_STATIC_INLINE_PREFIX void MPIDI_add_progress_vci_cont(MPIR_Request * req,
+                                                          MPID_Progress_state * state)
+{
+    MPIR_Assert(req->kind == MPIR_REQUEST_KIND__CONTINUE);
+    for (int i = 0; i < MPIDI_CH4_MAX_VCIS; ++i) {
+        if (MPL_atomic_relaxed_load_int64(&req->u.cont.state->vci_refcount[i].val) > 0) {
+            MPIDI_add_vci_to_state(i, state);
+        }
+    }
+}
+
 MPL_STATIC_INLINE_PREFIX void MPIDI_set_progress_vci(MPIR_Request * req,
                                                      MPID_Progress_state * state)
 {
     state->flag = MPIDI_PROGRESS_ALL;   /* TODO: check request is_local/anysource */
 
-    int vci = MPIDI_Request_get_vci(req);
+    state->vci_count = 0;
+    if (req->kind == MPIR_REQUEST_KIND__CONTINUE) {
+        MPIDI_add_progress_vci_cont(req, state);
+    } else {
+        int vci = MPIDI_Request_get_vci(req);
 
-    state->vci_count = 1;
-    state->vci[0] = vci;
+        state->vci_count = 1;
+        state->vci[0] = vci;
+    }
 }
 
 MPL_STATIC_INLINE_PREFIX void MPIDI_set_progress_vci_n(int n, MPIR_Request ** reqs,
@@ -24,6 +56,6 @@ MPL_STATIC_INLINE_PREFIX void MPIDI_set_progress_vci_n(int n, MPIR_Request ** re
 {
     state->flag = MPIDI_PROGRESS_ALL;   /* TODO: check request is_local/anysource */
 
-    int idx = 0;
+    state->vci_count = 0;       /* grown by MPIDI_add_vci_to_state() in the loop below */
     for (int i = 0; i < n; i++) {
         if (!MPIR_Request_is_active(reqs[i])) {
@@ -34,17 +66,11 @@ MPL_STATIC_INLINE_PREFIX void MPIDI_set_progress_vci_n(int n, MPIR_Request ** re
             continue;
         }
 
-        int vci = MPIDI_Request_get_vci(reqs[i]);
-        int found = 0;
-        for (int j = 0; j < idx; j++) {
-            if (state->vci[j] == vci) {
-                found = 1;
-                break;
-            }
-        }
-        if (!found) {
-            state->vci[idx++] = vci;
+        if (reqs[i]->kind == MPIR_REQUEST_KIND__CONTINUE) {
+            MPIDI_add_progress_vci_cont(reqs[i], state);
+        } else {
+            int vci = MPIDI_Request_get_vci(reqs[i]);
+            MPIDI_add_vci_to_state(vci, state);
         }
     }
-    state->vci_count = idx;
 }