Skip to content

Commit

Permalink
MPIX_Continue: the basic implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
JiakunYan committed Oct 16, 2024
1 parent 5517088 commit 1f09a56
Show file tree
Hide file tree
Showing 20 changed files with 712 additions and 15 deletions.
9 changes: 5 additions & 4 deletions maint/local_python/binding_c.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,7 @@ def process_func_parameters(func):
if p['length']:
length = p['length']
if length == '*':
if RE.match(r'MPI_(Test|Wait|Request_get_status_)all', func_name, re.IGNORECASE):
if RE.match(r'MPI_(Test|Wait|Request_get_status_|Continue)all', func_name, re.IGNORECASE):
length = "count"
elif RE.match(r'MPI_(Test|Wait|Request_get_status_)some', func_name, re.IGNORECASE):
length = "incount"
Expand All @@ -595,7 +595,7 @@ def process_func_parameters(func):
if kind == "REQUEST":
if RE.match(r'mpi_startall', func_name, re.IGNORECASE):
do_handle_ptr = 3
elif RE.match(r'mpix?_(wait|test|request_get_status)', func_name, re.IGNORECASE):
elif RE.match(r'mpix?_(wait|test|request_get_status|continue)', func_name, re.IGNORECASE):
do_handle_ptr = 3
elif kind == "RANK":
validation_list.append({'kind': "RANK-ARRAY", 'name': name})
Expand Down Expand Up @@ -652,6 +652,8 @@ def process_func_parameters(func):
p['can_be_null'] = "MPI_INFO_NULL"
elif kind == "REQUEST" and RE.match(r'mpix?_(wait|test|request_get_status|parrived)', func_name, re.IGNORECASE):
p['can_be_null'] = "MPI_REQUEST_NULL"
elif kind == "REQUEST" and RE.match(r'mpix_(continue|continueall)', func_name, re.IGNORECASE) and name == "cont_request":
p['can_be_null'] = "MPI_REQUEST_NULL"
elif kind == "STREAM" and RE.match(r'mpix?_(stream_(comm_create|progress)|async_(start|spawn))', func_name, re.IGNORECASE):
p['can_be_null'] = "MPIX_STREAM_NULL"
elif kind == "COMMUNICATOR" and RE.match(r'mpi_comm_get_name', func_name, re.IGNORECASE):
Expand Down Expand Up @@ -740,7 +742,6 @@ def process_func_parameters(func):
validation_list.append({'kind': "ARGNULL", 'name': name})
else:
print("Missing error checking: func=%s, name=%s, kind=%s" % (func_name, name, kind), file=sys.stderr)

if do_handle_ptr == 1:
if p['param_direction'] == 'inout':
# assume only one such parameter
Expand All @@ -762,7 +763,7 @@ def process_func_parameters(func):
if kind == "REQUEST":
ptrs_name = "request_ptrs"
p['_ptrs_name'] = ptrs_name
if RE.match(r'mpi_startall', func['name'], re.IGNORECASE):
if RE.match(r'mpix?_(start|continue)all', func['name'], re.IGNORECASE):
impl_arg_list.append(ptrs_name)
impl_param_list.append("MPIR_Request **%s" % ptrs_name)
else:
Expand Down
46 changes: 46 additions & 0 deletions src/binding/c/continue_api.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# vim: set ft=c:

MPIX_Continue_cb_function:
.return: NOTHING
error_code: ERROR_CODE
user_data: BUFFER

MPIX_Continue_init:
.desc: Creates a new continuation request
flags: ARRAY_LENGTH [flags]
max_poll: ARRAY_LENGTH_NNI [maximum number of continuations to execute when
testing, or 0 for no limit]
info: INFO, [info argument]
cont_req: REQUEST, direction=out, [continuation request created]

MPIX_Continue:
.desc: Attach a continuation to the operation represented by the request
op_request and register it with the continuation request cont_request
op_request: REQUEST, direction=inout, [the request associated with the active operation]
cb: FUNCTION, func_type=MPIX_Continue_cb_function, [callback to be invoked once the operation is complete]
cb_data: BUFFER, [pointer to a user-controlled buffer]
flags: ARRAY_LENGTH, [flags controlling aspects of the continuation]
status: STATUS, direction=inout, [status object]
cont_request: REQUEST, [continuation request]

MPIX_Continueall:
.desc: Attach a continuation callback to a set of operation requests
count: ARRAY_LENGTH_NNI, [lists length]
array_of_op_requests: REQUEST, direction=inout, length=count, [array of requests]
cb: FUNCTION, func_type=MPIX_Continue_cb_function, [the continuation callback function]
cb_data: BUFFER, [the argument passed to the callback]
flags: ARRAY_LENGTH, [flags controlling aspects of the continuation]
array_of_statuses: STATUS, direction=inout, length=*, pointer=False, [array of status objects]
cont_request: REQUEST, [the continuation request]
{
mpi_errno = MPIR_Continueall_impl(count, request_ptrs, cb, cb_data, flags, array_of_statuses,
cont_request_ptr);
if (mpi_errno) {
goto fn_fail;
}
if (!(flags & MPIX_CONT_REQBUF_VOLATILE)) {
for (int i = 0; i < count; ++i) {
array_of_op_requests[i] = MPI_REQUEST_NULL;
}
}
}
2 changes: 1 addition & 1 deletion src/include/Makefile.mk
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ noinst_HEADERS += \
src/include/mpir_refcount_global.h \
src/include/mpir_refcount_vci.h \
src/include/mpir_refcount_single.h \
src/include/mpir_refcount.h \
src/include/mpir_atomic_flag.h \
src/include/mpir_assert.h \
src/include/mpir_misc_post.h \
src/include/mpir_type_defs.h \
Expand Down
12 changes: 12 additions & 0 deletions src/include/mpi.h.in
Original file line number Diff line number Diff line change
Expand Up @@ -763,6 +763,15 @@ enum MPIR_Combiner_enum {
#define MPIX_GPU_SUPPORT_ZE (1)
#define MPIX_GPU_SUPPORT_HIP (2)

/* Continue flags */
#define MPIX_CONT_REQBUF_VOLATILE 1<<0
#define MPIX_CONT_PERSISTENT 1<<1
#define MPIX_CONT_POLL_ONLY 1<<2
#define MPIX_CONT_DEFER_COMPLETE 1<<3
#define MPIX_CONT_INVOKE_FAILED 1<<4
#define MPIX_CONT_IMMEDIATE 1<<5
#define MPIX_CONT_FORGET 1<<6

/* feature advertisement */
#define MPIIMPL_ADVERTISES_FEATURES 1
#define MPIIMPL_HAVE_MPI_INFO 1
Expand Down Expand Up @@ -843,6 +852,9 @@ typedef int (MPI_Datarep_extent_function)(MPI_Datatype datatype, MPI_Aint *,
typedef int (MPI_Datarep_conversion_function_c)(void *, MPI_Datatype, MPI_Count,
void *, MPI_Offset, void *);

/* Typedefs for continuation callback */
typedef int (MPIX_Continue_cb_function)(int error_code, void *user_data);

/* Make the C names for the dup function mixed case.
This is required for systems that use all uppercase names for Fortran
externals. */
Expand Down
1 change: 1 addition & 0 deletions src/include/mpiimpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ typedef struct MPIR_Stream MPIR_Stream;
#include "mpir_assert.h"
#include "mpir_pointers.h"
#include "mpir_refcount.h"
#include "mpir_atomic_flag.h"
#include "mpir_mem.h"
#include "mpir_info.h"
#include "mpir_errcodes.h"
Expand Down
69 changes: 69 additions & 0 deletions src/include/mpir_atomic_flag.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* Copyright (C) by Argonne National Laboratory
* See COPYRIGHT in top-level directory
*/

#ifndef MPIR_ATOMIC_FLAG_H_INCLUDED
#define MPIR_ATOMIC_FLAG_H_INCLUDED

#include "mpi.h"
#include "mpichconf.h"

#if MPICH_THREAD_LEVEL == MPI_THREAD_MULTIPLE && \
MPICH_THREAD_GRANULARITY == MPICH_THREAD_GRANULARITY__VCI

typedef MPL_atomic_int_t MPIR_atomic_flag_t;

static inline void MPIR_atomic_flag_set(MPIR_atomic_flag_t * flag_ptr, int val)
{
MPL_atomic_relaxed_store_int(flag_ptr, val);
}

static inline int MPIR_atomic_flag_get(MPIR_atomic_flag_t * flag_ptr)
{
return MPL_atomic_relaxed_load_int(flag_ptr);
}

static inline int MPIR_atomic_flag_swap(MPIR_atomic_flag_t * flag_ptr, int val)
{
return MPL_atomic_swap_int(flag_ptr, val);
}

static inline int MPIR_atomic_flag_cas(MPIR_atomic_flag_t * flag_ptr, int old_val, int new_val)
{
return MPL_atomic_cas_int(flag_ptr, old_val, new_val);
}

#else

typedef int MPIR_atomic_flag_t;

static inline void MPIR_atomic_flag_set(MPIR_atomic_flag_t * flag_ptr, int val)
{
*flag_ptr = val;
}

static inline int MPIR_atomic_flag_get(MPIR_atomic_flag_t * flag_ptr)
{
return *flag_ptr;
}

static inline int MPIR_atomic_flag_swap(MPIR_atomic_flag_t * flag_ptr, int val)
{
int ret = *flag_ptr;
*flag_ptr = val;
return ret;
}

static inline int MPIR_atomic_flag_cas(MPIR_atomic_flag_t * flag_ptr, int old_val, int new_val)
{
int ret = *flag_ptr;
if (*flag_ptr == old_val) {
*flag_ptr = new_val;
}
return ret;
}

#endif

#endif /* MPIR_ATOMIC_FLAG_H_INCLUDED */
8 changes: 5 additions & 3 deletions src/include/mpir_err.h
Original file line number Diff line number Diff line change
Expand Up @@ -371,9 +371,7 @@ void MPIR_Handle_fatal_error(struct MPIR_Comm *comm_ptr, const char fcname[], in
}

#define MPIR_ERRTEST_STARTREQ(reqp,err) \
if ((reqp)->kind != MPIR_REQUEST_KIND__PREQUEST_SEND && (reqp)->kind != MPIR_REQUEST_KIND__PREQUEST_RECV \
&& (reqp)->kind != MPIR_REQUEST_KIND__PREQUEST_COLL \
&& (reqp)->kind != MPIR_REQUEST_KIND__PART_SEND && (reqp)->kind != MPIR_REQUEST_KIND__PART_RECV) { \
if (!MPIR_Request_is_persistent(reqp)) { \
err = MPIR_Err_create_code(MPI_SUCCESS, MPIR_ERR_RECOVERABLE, __func__, __LINE__, \
MPI_ERR_REQUEST, "**requestinvalidstart", 0); \
goto fn_fail; \
Expand All @@ -394,6 +392,10 @@ void MPIR_Handle_fatal_error(struct MPIR_Comm *comm_ptr, const char fcname[], in
err = MPIR_Err_create_code(MPI_SUCCESS, MPIR_ERR_RECOVERABLE, __func__, __LINE__, \
MPI_ERR_REQUEST, "**requestpartactive", 0); \
goto fn_fail; \
} else if (((reqp)->kind == MPIR_REQUEST_KIND__CONTINUE) && MPIR_Cont_request_is_active(reqp)) { \
err = MPIR_Err_create_code(MPI_SUCCESS, MPIR_ERR_RECOVERABLE, __func__, __LINE__, \
MPI_ERR_REQUEST, "**requestpartactive", 0); \
goto fn_fail; \
}

#define MPIR_ERRTEST_PREADYREQ(reqp,err) \
Expand Down
Loading

0 comments on commit 1f09a56

Please sign in to comment.