Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MPIX_Continue: the basic implementation #7164

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions maint/local_python/binding_c.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,7 @@ def process_func_parameters(func):
if p['length']:
length = p['length']
if length == '*':
if RE.match(r'MPI_(Test|Wait|Request_get_status_)all', func_name, re.IGNORECASE):
if RE.match(r'MPIX?_(Test|Wait|Request_get_status_|Continue)all', func_name, re.IGNORECASE):
length = "count"
elif RE.match(r'MPI_(Test|Wait|Request_get_status_)some', func_name, re.IGNORECASE):
length = "incount"
Expand All @@ -595,7 +595,7 @@ def process_func_parameters(func):
if kind == "REQUEST":
if RE.match(r'mpi_startall', func_name, re.IGNORECASE):
do_handle_ptr = 3
elif RE.match(r'mpix?_(wait|test|request_get_status)', func_name, re.IGNORECASE):
elif RE.match(r'mpix?_(wait|test|request_get_status|continue)', func_name, re.IGNORECASE):
do_handle_ptr = 3
elif kind == "RANK":
validation_list.append({'kind': "RANK-ARRAY", 'name': name})
Expand Down Expand Up @@ -652,6 +652,8 @@ def process_func_parameters(func):
p['can_be_null'] = "MPI_INFO_NULL"
elif kind == "REQUEST" and RE.match(r'mpix?_(wait|test|request_get_status|parrived)', func_name, re.IGNORECASE):
p['can_be_null'] = "MPI_REQUEST_NULL"
elif kind == "REQUEST" and RE.match(r'mpix_(continue|continueall)', func_name, re.IGNORECASE) and name == "cont_request":
p['can_be_null'] = "MPI_REQUEST_NULL"
elif kind == "STREAM" and RE.match(r'mpix?_(stream_(comm_create|progress)|async_(start|spawn))', func_name, re.IGNORECASE):
p['can_be_null'] = "MPIX_STREAM_NULL"
elif kind == "COMMUNICATOR" and RE.match(r'mpi_comm_get_name', func_name, re.IGNORECASE):
Expand Down Expand Up @@ -740,7 +742,6 @@ def process_func_parameters(func):
validation_list.append({'kind': "ARGNULL", 'name': name})
else:
print("Missing error checking: func=%s, name=%s, kind=%s" % (func_name, name, kind), file=sys.stderr)

if do_handle_ptr == 1:
if p['param_direction'] == 'inout':
# assume only one such parameter
Expand All @@ -762,7 +763,7 @@ def process_func_parameters(func):
if kind == "REQUEST":
ptrs_name = "request_ptrs"
p['_ptrs_name'] = ptrs_name
if RE.match(r'mpi_startall', func['name'], re.IGNORECASE):
if RE.match(r'mpix?_(start|continue)all', func['name'], re.IGNORECASE):
impl_arg_list.append(ptrs_name)
impl_param_list.append("MPIR_Request **%s" % ptrs_name)
else:
Expand Down
2 changes: 1 addition & 1 deletion maint/local_python/binding_f08.py
Original file line number Diff line number Diff line change
Expand Up @@ -1345,7 +1345,7 @@ def process_func_parameters(func):
def check_func_directives(func):
if 'dir' in func and func['dir'] == "mpit":
func['_skip_fortran'] = 1
elif RE.match(r'mpix_(grequest_|type_iov|async_)', func['name'], re.IGNORECASE):
elif RE.match(r'mpix_(grequest_|type_iov|async_|continue)', func['name'], re.IGNORECASE):
func['_skip_fortran'] = 1
elif RE.match(r'mpi_attr_', func['name'], re.IGNORECASE):
func['_skip_fortran'] = 1
Expand Down
2 changes: 1 addition & 1 deletion maint/local_python/binding_f77.py
Original file line number Diff line number Diff line change
Expand Up @@ -1262,7 +1262,7 @@ def dump_fortran_line(s):
def check_func_directives(func):
if 'dir' in func and func['dir'] == "mpit":
func['_skip_fortran'] = 1
elif RE.match(r'mpix_(grequest_|type_iov|async_|(comm|file|win|session)_create_errhandler_x|op_create_x)', func['name'], re.IGNORECASE):
elif RE.match(r'mpix_(grequest_|type_iov|async_|(comm|file|win|session)_create_errhandler_x|op_create_x|continue)', func['name'], re.IGNORECASE):
func['_skip_fortran'] = 1
elif RE.match(r'mpi_\w+_(f|f08|c)2(f|f08|c)$', func['name'], re.IGNORECASE):
# implemented in mpi_f08_types.f90
Expand Down
2 changes: 1 addition & 1 deletion maint/local_python/binding_f90.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def dump_f90_sizeofs():
def check_func_directives(func):
if 'dir' in func and func['dir'] == "mpit":
func['_skip_fortran'] = 1
elif RE.match(r'mpix_(grequest_|type_iov|async_)', func['name'], re.IGNORECASE):
elif RE.match(r'mpix_(grequest_|type_iov|async_|continue)', func['name'], re.IGNORECASE):
func['_skip_fortran'] = 1
elif RE.match(r'mpi_attr_', func['name'], re.IGNORECASE):
func['_skip_fortran'] = 1
Expand Down
46 changes: 46 additions & 0 deletions src/binding/c/continue_api.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# vim: set ft=c:

MPIX_Continue_cb_function:
.return: NOTHING
error_code: ERROR_CODE
user_data: BUFFER

MPIX_Continue_init:
.desc: Creates a new continuation request
flags: ARRAY_LENGTH [flags]
max_poll: ARRAY_LENGTH_NNI [maximum number of continuations to execute when
testing, or 0 for no limit]
info: INFO, [info argument]
cont_req: REQUEST, direction=out, [continuation request created]

MPIX_Continue:
.desc: Attach a continuation to the operation represented by the request
op_request and register it with the continuation request cont_request
op_request: REQUEST, direction=inout, [the request associated with the active operation]
cb: FUNCTION, func_type=MPIX_Continue_cb_function, [callback to be invoked once the operation is complete]
cb_data: BUFFER, [pointer to a user-controlled buffer]
flags: ARRAY_LENGTH, [flags controlling aspects of the continuation]
status: STATUS, direction=inout, [status object]
cont_request: REQUEST, [continuation request]

MPIX_Continueall:
.desc: Attach a continuation callback to a set of operation requests
count: ARRAY_LENGTH_NNI, [lists length]
array_of_op_requests: REQUEST, direction=inout, length=count, [array of requests]
cb: FUNCTION, func_type=MPIX_Continue_cb_function, [the continuation callback function]
cb_data: BUFFER, [the argument passed to the callback]
flags: ARRAY_LENGTH, [flags controlling aspects of the continuation]
array_of_statuses: STATUS, direction=out, length=*, pointer=False, [array of status objects]
cont_request: REQUEST, [the continuation request]
{
mpi_errno = MPIR_Continueall_impl(count, request_ptrs, cb, cb_data, flags, array_of_statuses,
cont_request_ptr);
if (mpi_errno) {
goto fn_fail;
}
if (!(flags & MPIX_CONT_REQBUF_VOLATILE)) {
for (int i = 0; i < count; ++i) {
array_of_op_requests[i] = MPI_REQUEST_NULL;
}
}
}
2 changes: 1 addition & 1 deletion src/include/Makefile.mk
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ noinst_HEADERS += \
src/include/mpir_refcount_global.h \
src/include/mpir_refcount_vci.h \
src/include/mpir_refcount_single.h \
src/include/mpir_refcount.h \
src/include/mpir_atomic_flag.h \
src/include/mpir_assert.h \
src/include/mpir_misc_post.h \
src/include/mpir_type_defs.h \
Expand Down
11 changes: 11 additions & 0 deletions src/include/mpi.h.in
Original file line number Diff line number Diff line change
Expand Up @@ -763,6 +763,14 @@ enum MPIR_Combiner_enum {
#define MPIX_GPU_SUPPORT_ZE (1)
#define MPIX_GPU_SUPPORT_HIP (2)

/* Continue flags */
#define MPIX_CONT_REQBUF_VOLATILE 1<<0
#define MPIX_CONT_PERSISTENT 1<<1
#define MPIX_CONT_POLL_ONLY 1<<2
#define MPIX_CONT_DEFER_COMPLETE 1<<3
#define MPIX_CONT_INVOKE_FAILED 1<<4
#define MPIX_CONT_IMMEDIATE 1<<5

/* feature advertisement */
#define MPIIMPL_ADVERTISES_FEATURES 1
#define MPIIMPL_HAVE_MPI_INFO 1
Expand Down Expand Up @@ -843,6 +851,9 @@ typedef int (MPI_Datarep_extent_function)(MPI_Datatype datatype, MPI_Aint *,
typedef int (MPI_Datarep_conversion_function_c)(void *, MPI_Datatype, MPI_Count,
void *, MPI_Offset, void *);

/* Typedefs for continuation callback */
typedef int (MPIX_Continue_cb_function)(int error_code, void *user_data);

/* Make the C names for the dup function mixed case.
This is required for systems that use all uppercase names for Fortran
externals. */
Expand Down
1 change: 1 addition & 0 deletions src/include/mpiimpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ typedef struct MPIR_Stream MPIR_Stream;
#include "mpir_assert.h"
#include "mpir_pointers.h"
#include "mpir_refcount.h"
#include "mpir_atomic_flag.h"
#include "mpir_mem.h"
#include "mpir_info.h"
#include "mpir_errcodes.h"
Expand Down
69 changes: 69 additions & 0 deletions src/include/mpir_atomic_flag.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* Copyright (C) by Argonne National Laboratory
* See COPYRIGHT in top-level directory
*/

#ifndef MPIR_ATOMIC_FLAG_H_INCLUDED
#define MPIR_ATOMIC_FLAG_H_INCLUDED

#include "mpi.h"
#include "mpichconf.h"

#if MPICH_THREAD_LEVEL == MPI_THREAD_MULTIPLE && \
MPICH_THREAD_GRANULARITY == MPICH_THREAD_GRANULARITY__VCI

typedef MPL_atomic_int_t MPIR_atomic_flag_t;

static inline void MPIR_atomic_flag_set(MPIR_atomic_flag_t * flag_ptr, int val)
{
MPL_atomic_relaxed_store_int(flag_ptr, val);
}

static inline int MPIR_atomic_flag_get(MPIR_atomic_flag_t * flag_ptr)
{
return MPL_atomic_relaxed_load_int(flag_ptr);
}

static inline int MPIR_atomic_flag_swap(MPIR_atomic_flag_t * flag_ptr, int val)
{
return MPL_atomic_swap_int(flag_ptr, val);
}

static inline int MPIR_atomic_flag_cas(MPIR_atomic_flag_t * flag_ptr, int old_val, int new_val)
{
return MPL_atomic_cas_int(flag_ptr, old_val, new_val);
}

#else

typedef int MPIR_atomic_flag_t;

static inline void MPIR_atomic_flag_set(MPIR_atomic_flag_t * flag_ptr, int val)
{
*flag_ptr = val;
}

static inline int MPIR_atomic_flag_get(MPIR_atomic_flag_t * flag_ptr)
{
return *flag_ptr;
}

static inline int MPIR_atomic_flag_swap(MPIR_atomic_flag_t * flag_ptr, int val)
{
int ret = *flag_ptr;
*flag_ptr = val;
return ret;
}

static inline int MPIR_atomic_flag_cas(MPIR_atomic_flag_t * flag_ptr, int old_val, int new_val)
{
int ret = *flag_ptr;
if (*flag_ptr == old_val) {
*flag_ptr = new_val;
}
return ret;
}

#endif

#endif /* MPIR_ATOMIC_FLAG_H_INCLUDED */
8 changes: 5 additions & 3 deletions src/include/mpir_err.h
Original file line number Diff line number Diff line change
Expand Up @@ -371,9 +371,7 @@ void MPIR_Handle_fatal_error(struct MPIR_Comm *comm_ptr, const char fcname[], in
}

#define MPIR_ERRTEST_STARTREQ(reqp,err) \
if ((reqp)->kind != MPIR_REQUEST_KIND__PREQUEST_SEND && (reqp)->kind != MPIR_REQUEST_KIND__PREQUEST_RECV \
&& (reqp)->kind != MPIR_REQUEST_KIND__PREQUEST_COLL \
&& (reqp)->kind != MPIR_REQUEST_KIND__PART_SEND && (reqp)->kind != MPIR_REQUEST_KIND__PART_RECV) { \
if (!MPIR_Request_is_persistent(reqp)) { \
err = MPIR_Err_create_code(MPI_SUCCESS, MPIR_ERR_RECOVERABLE, __func__, __LINE__, \
MPI_ERR_REQUEST, "**requestinvalidstart", 0); \
goto fn_fail; \
Expand All @@ -394,6 +392,10 @@ void MPIR_Handle_fatal_error(struct MPIR_Comm *comm_ptr, const char fcname[], in
err = MPIR_Err_create_code(MPI_SUCCESS, MPIR_ERR_RECOVERABLE, __func__, __LINE__, \
MPI_ERR_REQUEST, "**requestpartactive", 0); \
goto fn_fail; \
} else if (((reqp)->kind == MPIR_REQUEST_KIND__CONTINUE) && MPIR_Cont_request_is_active(reqp)) { \
err = MPIR_Err_create_code(MPI_SUCCESS, MPIR_ERR_RECOVERABLE, __func__, __LINE__, \
MPI_ERR_REQUEST, "**requestpartactive", 0); \
goto fn_fail; \
}

#define MPIR_ERRTEST_PREADYREQ(reqp,err) \
Expand Down
Loading