Skip to content

Commit

Permalink
fix the progress semantics for continuation request polling
Browse files Browse the repository at this point in the history
  • Loading branch information
JiakunYan committed Oct 16, 2024
1 parent 1f09a56 commit ec921dc
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 23 deletions.
3 changes: 2 additions & 1 deletion src/include/mpir_request.h
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ struct MPIR_Request {
struct MPIR_Continue *head, *tail;
MPID_Thread_mutex_t lock;
} ready_poll_only_cont_list;
MPID_Progress_state_cnt *state;
} cont;
/* Reserve space for local usages. For example, threadcomm, the actual struct
* is defined locally and is used via casting */
Expand Down Expand Up @@ -731,7 +732,7 @@ MPL_STATIC_INLINE_PREFIX void MPIR_Request_cb_init(MPIR_Request * req)
MPL_STATIC_INLINE_PREFIX void MPIR_Request_cb_free(MPIR_Request * req)
{
int err;
MPID_Thread_mutex_destroy(&req->cbs_lock, &err)
MPID_Thread_mutex_destroy(&req->cbs_lock, &err);
MPIR_Assert(!err);
// free all the persistent callbacks
while (req->cbs.head) {
Expand Down
33 changes: 24 additions & 9 deletions src/mpi/continue/continue_impl.c
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ __thread struct {
struct MPIR_Continue *head, *tail;
} tls_deferred_cont_list = {NULL, NULL};

void complete_op_request(MPIR_Request *op_request, bool in_cs, void *cb_context, bool defer_complete);
void complete_op_request(MPIR_Request *op_request, bool in_cs, void *cb_context, bool defer_complete, bool in_request_callback);
void MPIR_Continue_callback(MPIR_Request *op_request, bool in_cs, void *cb_context);
void attach_continue_context(MPIR_Continue_context *context_ptr, bool defer_complete);

Expand Down Expand Up @@ -80,12 +80,17 @@ int MPIR_Continue_init_impl(int flags, int max_poll,
cont_req->u.cont.ready_poll_only_cont_list.tail = NULL;
cont_req->u.cont.is_pool_only = flags & MPIX_CONT_POLL_ONLY;
cont_req->u.cont.max_poll = max_poll;
cont_req->u.cont.state = (MPID_Progress_state_cnt *) MPL_malloc(sizeof(MPID_Progress_state_cnt), MPL_MEM_OTHER);
for (int i = 0; i < MPIDI_CH4_MAX_VCIS; ++i) {
MPL_atomic_release_store_uint64(&cont_req->u.cont.state->vci_refcount[i].val, 0);
}
*cont_req_ptr = cont_req;
return MPI_SUCCESS;
}

void MPIR_Continue_destroy_impl(MPIR_Request *cont_req)
{
MPL_free(cont_req->u.cont.state);
MPIR_Assert(cont_req->kind == MPIR_REQUEST_KIND__CONTINUE);
{
int err;
Expand Down Expand Up @@ -122,10 +127,15 @@ int MPIR_Continue_start(MPIR_Request * cont_request_ptr)
}

void attach_continue_context(MPIR_Continue_context *context_ptr, bool defer_complete) {
/* record the corresponding VCI for the continuation request to progress */
if (context_ptr->continue_ptr->cont_req) {
int vci = MPIDI_Request_get_vci(context_ptr->op_request);
MPL_atomic_fetch_add_int(&context_ptr->continue_ptr->cont_req->u.cont.state->vci_refcount[vci].val, 1);
}
/* Attach the continue context to the op request */
if (!MPIR_Register_callback(context_ptr->op_request, MPIR_Continue_callback, context_ptr, false)) {
/* the request has already been completed. */
complete_op_request(context_ptr->op_request, false, context_ptr, defer_complete);
complete_op_request(context_ptr->op_request, false, context_ptr, defer_complete, false);
}
}

Expand Down Expand Up @@ -207,23 +217,28 @@ void execute_continue(MPIR_Continue *continue_ptr, bool in_cs, int which_cs)
continue_ptr->cb(MPI_SUCCESS, continue_ptr->cb_data);
MPL_free(continue_ptr);
/* Signal the continuation request */
/* TODO: Find a suitable request complete function */
/* TODO: Find a suitable request complete function for continuation requests */
if (cont_req_ptr) {
int incomplete;
MPIR_cc_decr(cont_req_ptr->cc_ptr, &incomplete);
if (!incomplete) {
/* All the continue callbacks associated with this continuation request have completed */
/* TODO: reason about how to invoke the callback for continuation request */
/* TODO: reason about the safety of invoking the callback for continuation request here*/
// MPIR_Invoke_callback(cont_req_ptr, false);
MPIR_Request_free_with_safety(cont_req_ptr, !(in_cs && MPIR_REQUEST_POOL(cont_req_ptr) == which_cs));
MPIR_Request_free_with_safety(cont_req_ptr, !(in_cs && MPIR_REQUEST_POOL(cont_req_ptr) == which_cs), NULL);
}
}
}

void complete_op_request(MPIR_Request *op_request, bool in_cs, void *cb_context, bool defer_complete)
void complete_op_request(MPIR_Request *op_request, bool in_cs, void *cb_context, bool defer_complete, bool in_request_callback)
{
MPIR_Continue_context *context_ptr = (MPIR_Continue_context *) cb_context;
MPIR_Continue *continue_ptr = context_ptr->continue_ptr;
/* Decrease the continuation request VCI counter */
MPIR_Request *cont_req_ptr = continue_ptr->cont_req;
if (cont_req_ptr) {
int vci = MPIDI_Request_get_vci(op_request);
MPL_atomic_fetch_sub_int(&cont_req_ptr->u.cont.state->vci_refcount[vci].val, 1);
}
/* Complete this operation request */
/* FIXME: MPIR_Request_completion_processing can call MPIR_Request_free,
* which might lead to deadlock */
Expand All @@ -232,7 +247,7 @@ void complete_op_request(MPIR_Request *op_request, bool in_cs, void *cb_context,
if (context_ptr->status_ptr != MPI_STATUS_IGNORE)
context_ptr->status_ptr->MPI_ERROR = rc;
if (!MPIR_Request_is_persistent(op_request)) {
MPIR_Request_free_with_safety(op_request, !in_cs);
MPIR_Request_free_with_safety(op_request, !in_cs, NULL);
}
MPL_free(context_ptr);
/* Signal the continue callback */
Expand Down Expand Up @@ -271,7 +286,7 @@ void complete_op_request(MPIR_Request *op_request, bool in_cs, void *cb_context,

void MPIR_Continue_callback(MPIR_Request *op_request, bool in_cs, void *cb_context)
{
complete_op_request(op_request, in_cs, cb_context, false);
complete_op_request(op_request, in_cs, cb_context, false, true);
}

int MPIR_Continue_progress_tls()
Expand Down
9 changes: 9 additions & 0 deletions src/mpid/ch4/include/mpidpre.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,15 @@ typedef struct {
uint8_t vci[MPIDI_CH4_MAX_VCIS]; /* list of vcis that need progress */
} MPID_Progress_state;

typedef struct {
MPL_atomic_int64_t val;
char padding[56];
} MPL_padded_atomic_int64_t;

typedef struct {
MPL_padded_atomic_int64_t vci_refcount[MPIDI_CH4_MAX_VCIS]; /* list of vcis that need progress */
} MPID_Progress_state_cnt;

typedef enum {
MPIDI_PTYPE_RECV,
MPIDI_PTYPE_SEND,
Expand Down
51 changes: 38 additions & 13 deletions src/mpid/ch4/src/ch4_wait.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,52 @@

#include "ch4_impl.h"

MPL_STATIC_INLINE_PREFIX void MPIDI_add_vci_to_state(int vci,
MPID_Progress_state * state)
{
MPIR_Assert(vci < MPIDI_CH4_MAX_VCIS);
for (int i = 0; i < state->vci_count; ++i) {
if (state->vci[i] == vci) {
return;
}
}
MPIR_Assert(state->vci_count < MPIDI_CH4_MAX_VCIS);
state->vci[state->vci_count++] = vci;
}

MPL_STATIC_INLINE_PREFIX void MPIDI_add_progress_vci_cont(MPIR_Request * req,
MPID_Progress_state * state)
{
MPIR_Assert(req->kind == MPIR_REQUEST_KIND__CONTINUE);
for (int i = 0; i < MPIDI_CH4_MAX_VCIS; ++i) {
if (MPL_atomic_relaxed_load_int64(&req->u.cont.state->vci_refcount[i].val) > 0) {
MPIDI_add_vci_to_state(i, state);
}
}
}

MPL_STATIC_INLINE_PREFIX void MPIDI_set_progress_vci(MPIR_Request * req,
MPID_Progress_state * state)
{
state->flag = MPIDI_PROGRESS_ALL; /* TODO: check request is_local/anysource */

int vci = MPIDI_Request_get_vci(req);
state->vci_count = 0;
if (req->kind == MPIR_REQUEST_KIND__CONTINUE) {
MPIDI_add_progress_vci_cont(req, state);
} else {
int vci = MPIDI_Request_get_vci(req);

state->vci_count = 1;
state->vci[0] = vci;
state->vci_count = 1;
state->vci[0] = vci;
}
}

MPL_STATIC_INLINE_PREFIX void MPIDI_set_progress_vci_n(int n, MPIR_Request ** reqs,
MPID_Progress_state * state)
{
state->flag = MPIDI_PROGRESS_ALL; /* TODO: check request is_local/anysource */

state->vci_count = 0;
int idx = 0;
for (int i = 0; i < n; i++) {
if (!MPIR_Request_is_active(reqs[i])) {
Expand All @@ -34,16 +64,11 @@ MPL_STATIC_INLINE_PREFIX void MPIDI_set_progress_vci_n(int n, MPIR_Request ** re
continue;
}

int vci = MPIDI_Request_get_vci(reqs[i]);
int found = 0;
for (int j = 0; j < idx; j++) {
if (state->vci[j] == vci) {
found = 1;
break;
}
}
if (!found) {
state->vci[idx++] = vci;
if (reqs[i]->kind == MPIR_REQUEST_KIND__CONTINUE) {
MPIDI_add_progress_vci_cont(reqs[i], state);
} else {
int vci = MPIDI_Request_get_vci(reqs[i]);
MPIDI_add_vci_to_state(vci, state);
}
}
state->vci_count = idx;
Expand Down

0 comments on commit ec921dc

Please sign in to comment.