Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

coll: add collattr to collective interface #6350

Closed
wants to merge 11 commits into from
25 changes: 17 additions & 8 deletions maint/gen_coll.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,8 @@ def dump_mpir_impl_persistent(name):
dump_split(0, "int MPIR_%s_impl(%s)" % (Name, func_params))
dump_open('{')
G.out.append("int mpi_errno = MPI_SUCCESS;")
if not re.match(r'Neighbor_', Name):
G.out.append("int collattr = 0;")
G.out.append("")
G.out.append("MPIR_Request *req = MPIR_Request_create(MPIR_REQUEST_KIND__PREQUEST_COLL);")
G.out.append("MPIR_ERR_CHKANDJUMP(!req, mpi_errno, MPI_ERR_OTHER, \"**nomem\");")
Expand Down Expand Up @@ -646,12 +648,13 @@ def get_algo_args(args, algo, kind):
if 'extra_params' in algo:
algo_args += ", " + get_algo_extra_args(algo, kind)

if not re.match(r'i?neighbor', algo['func-commkind']):
algo_args += ', collattr'

if algo['name'].startswith('tsp_'):
algo_args += ", *sched_p"
elif algo['func-commkind'].startswith('i'):
algo_args += ", *sched_p"
elif not algo['func-commkind'].startswith('neighbor_'):
algo_args += ", errflag"

return algo_args

Expand All @@ -660,12 +663,13 @@ def get_algo_params(params, algo):
if 'extra_params' in algo:
algo_params += ", " + get_algo_extra_params(algo)

if not re.match(r'i?neighbor_', algo['func-commkind']):
algo_params += ', int collattr'

if algo['name'].startswith('tsp_'):
algo_params += ", MPIR_TSP_sched_t sched"
elif algo['func-commkind'].startswith('i'):
algo_params += ", MPIR_Sched_t s"
elif not algo['func-commkind'].startswith('neighbor_'):
algo_params += ", MPIR_Errflag_t errflag"

return algo_params

Expand All @@ -680,9 +684,12 @@ def get_algo_name(algo):

def get_func_params(params, name, kind):
func_params = params

if not (name.startswith('neighbor_') or kind == "persistent"):
func_params += ', int collattr'

if kind == "blocking":
if not name.startswith('neighbor_'):
func_params += ", MPIR_Errflag_t errflag"
pass
elif kind == "nonblocking":
func_params += ", MPIR_Request ** request"
elif kind == "persistent":
Expand All @@ -700,9 +707,11 @@ def get_func_params(params, name, kind):

def get_func_args(args, name, kind):
func_args = args
if not (name.startswith('neighbor_') or kind == "persistent"):
func_args += ", collattr"

if kind == "blocking":
if not name.startswith('neighbor_'):
func_args += ", errflag"
pass
elif kind == "nonblocking":
func_args += ", request"
elif kind == "persistent":
Expand Down
28 changes: 18 additions & 10 deletions maint/local_python/binding_c.py
Original file line number Diff line number Diff line change
Expand Up @@ -1366,9 +1366,14 @@ def push_impl_decl(func, impl_name=None):
if func['_impl_param_list']:
params = ', '.join(func['_impl_param_list'])
if func['dir'] == 'coll':
# block collective use an extra errflag
if not RE.match(r'MPI_(I.*|Neighbor.*|.*_init)$', func['name']):
params = params + ", MPIR_Errflag_t errflag"
if RE.match(r'MPI_(Neighbor.*|Ineighbor.*|.*_init)$', func['name']):
pass
elif RE.match(r'MPI_I\w+', func['name']):
# non-blocking
params = re.sub(r'MPIR_Request', 'int collattr, MPIR_Request', params)
else:
# blocking
params = params + ", int collattr"
else:
params="void"

Expand All @@ -1378,9 +1383,6 @@ def push_impl_decl(func, impl_name=None):
G.impl_declares.append("int %s(%s);" % (mpir_name, params))
# dump MPIR_Xxx_impl(...)
G.impl_declares.append("int %s(%s);" % (impl_name, params))
if func['dir'] == 'coll':
mpir_name = re.sub(r'^MPIX?_', 'MPIR_', func['name'])
G.impl_declares.append("int %s(%s);" % (mpir_name, params))

def dump_CHECKENUM(var, errname, t, type="ENUM"):
val_list = t.split()
Expand All @@ -1401,6 +1403,15 @@ def dump_body_coll(func):
mpir_name = re.sub(r'^MPIX?_', 'MPIR_', func['name'])

args = ", ".join(func['_impl_arg_list'])
if RE.match(r'mpi_(i?neighbor_.*|.*_init)$', func['name'], re.IGNORECASE):
pass
elif RE.match(r'mpi_i', func['name'], re.IGNORECASE):
# non-blocking
args = re.sub(r'&request_ptr', 'MPIR_COLL_ATTR_NONE, &request_ptr', args)
else:
# blocking
args += ", MPIR_COLL_ATTR_NONE"


if RE.match(r'MPI_(I.*|.*_init)$', func['name'], re.IGNORECASE):
# non-blocking collectives
Expand All @@ -1411,12 +1422,9 @@ def dump_body_coll(func):
G.out.append(" request_ptr = MPIR_Request_create_complete(MPIR_REQUEST_KIND__COLL);")
G.out.append("}")
G.out.append("*request = request_ptr->handle;")
elif RE.match(r'mpi_neighbor_', func['name'], re.IGNORECASE):
dump_line_with_break("mpi_errno = %s(%s);" % (mpir_name, args))
dump_error_check("")
else:
# blocking collectives
dump_line_with_break("mpi_errno = %s(%s, MPIR_ERR_NONE);" % (mpir_name, args))
dump_line_with_break("mpi_errno = %s(%s);" % (mpir_name, args))
dump_error_check("")

def dump_coll_v_swap(func):
Expand Down
70 changes: 57 additions & 13 deletions src/include/mpir_coll.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,47 @@
#include "coll_impl.h"
#include "coll_algos.h"

/* collective attr bits allocation:
* 0-7: errflag
* 8-15: subcomm type
* 16-23: subcomm index
*/

#define MPIR_COLL_ATTR_NONE 0
#define MPIR_COLL_ATTR_GET_ERRFLAG(attr) ((attr) & 0xff)
#define MPIR_COLL_ATTR_GET_SUBCOMM_TYPE(attr) (((attr) >> 8) & 0xff)
#define MPIR_COLL_ATTR_GET_SUBCOMM_INDEX(attr) (((attr) >> 16) & 0xff)

#define MPIR_COLL_SUBCOMM_TYPE_NONE 0
#define MPIR_COLL_SUBCOMM_TYPE_CHILD 1
#define MPIR_COLL_SUBCOMM_TYPE_ROOTS 2

#define MPIR_COLL_GET_RANK_SIZE(comm_ptr, collattr, rank_, size_) \
do { \
int subcomm_type = MPIR_COLL_ATTR_GET_SUBCOMM_TYPE(collattr); \
if (subcomm_type) { \
int subcomm_index = MPIR_COLL_ATTR_GET_SUBCOMM_INDEX(collattr); \
switch(subcomm_type) { \
case MPIR_COLL_SUBCOMM_TYPE_CHILD: \
rank_ = comm_ptr->child_subcomm[subcomm_index].rank; \
size_ = comm_ptr->child_subcomm[subcomm_index].size; \
break; \
case MPIR_COLL_SUBCOMM_TYPE_ROOTS: \
rank_ = comm_ptr->roots_subcomm[subcomm_index].rank; \
size_ = comm_ptr->roots_subcomm[subcomm_index].size; \
break; \
default: \
MPIR_Assert(0); \
rank_ = -1; \
size_ = 0; \
} \
} else { \
rank_ = (comm_ptr)->rank; \
size_ = (comm_ptr)->local_size; \
} \
} while (0)


/* During init, not all algorithms are safe to use. For example, the csel
* may not have been initialized. We define a set of fallback routines that
* are safe to use during init. They are all intra algorithms.
Expand All @@ -28,40 +69,43 @@ int MPIC_Wait(MPIR_Request * request_ptr);
int MPIC_Probe(int source, int tag, MPI_Comm comm, MPI_Status * status);

int MPIC_Send(const void *buf, MPI_Aint count, MPI_Datatype datatype, int dest, int tag,
MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag);
MPIR_Comm * comm_ptr, int collattr);
int MPIC_Recv(void *buf, MPI_Aint count, MPI_Datatype datatype, int source, int tag,
MPIR_Comm * comm_ptr, MPI_Status * status);
MPIR_Comm * comm_ptr, int collattr, MPI_Status * status);
int MPIC_Ssend(const void *buf, MPI_Aint count, MPI_Datatype datatype, int dest, int tag,
MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag);
MPIR_Comm * comm_ptr, int collattr);
int MPIC_Sendrecv(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype,
int dest, int sendtag, void *recvbuf, MPI_Aint recvcount,
MPI_Datatype recvtype, int source, int recvtag,
MPIR_Comm * comm_ptr, MPI_Status * status, MPIR_Errflag_t errflag);
MPIR_Comm * comm_ptr, MPI_Status * status, int collattr);
int MPIC_Sendrecv_replace(void *buf, MPI_Aint count, MPI_Datatype datatype,
int dest, int sendtag,
int source, int recvtag,
MPIR_Comm * comm_ptr, MPI_Status * status, MPIR_Errflag_t errflag);
MPIR_Comm * comm_ptr, MPI_Status * status, int collattr);
int MPIC_Isend(const void *buf, MPI_Aint count, MPI_Datatype datatype, int dest, int tag,
MPIR_Comm * comm_ptr, MPIR_Request ** request, MPIR_Errflag_t errflag);
MPIR_Comm * comm_ptr, MPIR_Request ** request, int collattr);
int MPIC_Issend(const void *buf, MPI_Aint count, MPI_Datatype datatype, int dest, int tag,
MPIR_Comm * comm_ptr, MPIR_Request ** request, MPIR_Errflag_t errflag);
MPIR_Comm * comm_ptr, MPIR_Request ** request, int collattr);
int MPIC_Irecv(void *buf, MPI_Aint count, MPI_Datatype datatype, int source,
int tag, MPIR_Comm * comm_ptr, MPIR_Request ** request);
int tag, MPIR_Comm * comm_ptr, int collattr, MPIR_Request ** request);
int MPIC_Waitall(int numreq, MPIR_Request * requests[], MPI_Status statuses[]);

int MPIR_Reduce_local(const void *inbuf, void *inoutbuf, MPI_Aint count, MPI_Datatype datatype,
MPI_Op op);

int MPIR_Barrier_intra_dissemination(MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag);
int MPIR_Barrier_intra_dissemination(MPIR_Comm * comm_ptr, int collattr);

/* TSP auto */
int MPIR_TSP_Iallreduce_sched_intra_tsp_auto(const void *sendbuf, void *recvbuf, MPI_Aint count,
MPI_Datatype datatype, MPI_Op op,
MPIR_Comm * comm, MPIR_TSP_sched_t sched);
MPIR_Comm * comm, int collattr,
MPIR_TSP_sched_t sched);
int MPIR_TSP_Ibcast_sched_intra_tsp_auto(void *buffer, MPI_Aint count, MPI_Datatype datatype,
int root, MPIR_Comm * comm_ptr, MPIR_TSP_sched_t sched);
int MPIR_TSP_Ibarrier_sched_intra_tsp_auto(MPIR_Comm * comm, MPIR_TSP_sched_t sched);
int root, MPIR_Comm * comm_ptr, int collattr,
MPIR_TSP_sched_t sched);
int MPIR_TSP_Ibarrier_sched_intra_tsp_auto(MPIR_Comm * comm, int collattr, MPIR_TSP_sched_t sched);
int MPIR_TSP_Ireduce_sched_intra_tsp_auto(const void *sendbuf, void *recvbuf, MPI_Aint count,
MPI_Datatype datatype, MPI_Op op, int root,
MPIR_Comm * comm_ptr, MPIR_TSP_sched_t sched);
MPIR_Comm * comm_ptr, int collattr,
MPIR_TSP_sched_t sched);
#endif /* MPIR_COLL_H_INCLUDED */
14 changes: 14 additions & 0 deletions src/include/mpir_comm.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,16 @@ enum MPIR_COMM_HINT_PREDEFINED_t {
MPIR_COMM_HINT_PREDEFINED_COUNT
};

/* lightweight comm struct defines a sub-comm for collectives */
typedef struct MPIR_sub_comm {
int rank; /* rank in this sub-comm */
int size; /* size of this sub-comm */
int *ranklist; /* A list of ranks (as in parent comm) in this sub-comm.
* NULL is interpreted as identity array */
} MPIR_sub_comm;

#define MPIR_MAX_SUBCOMM_DEPTH 1

/*S
MPIR_Comm - Description of the Communicator data structure

Expand Down Expand Up @@ -180,6 +190,9 @@ struct MPIR_Comm {
struct MPIR_Comm *node_comm; /* Comm of processes in this comm that are on
* the same node as this process. */
struct MPIR_Comm *node_roots_comm; /* Comm of root processes for other nodes. */
int subcomm_depth;
MPIR_sub_comm child_subcomm[MPIR_MAX_SUBCOMM_DEPTH];
MPIR_sub_comm roots_subcomm[MPIR_MAX_SUBCOMM_DEPTH];
int *intranode_table; /* intranode_table[i] gives the rank in
* node_comm of rank i in this comm or -1 if i
* is not in this process' node_comm.
Expand Down Expand Up @@ -264,6 +277,7 @@ struct MPIR_Comm {
MPID_DEV_COMM_DECL
#endif
};

extern MPIR_Object_alloc_t MPIR_Comm_mem;

/* this function should not be called by normal code! */
Expand Down
10 changes: 5 additions & 5 deletions src/include/mpir_nbc.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,13 @@ int MPIR_Sched_start(MPIR_Sched_t s, MPIR_Comm * comm, MPIR_Request ** req);

/* send and recv take a comm ptr to enable hierarchical collectives */
int MPIR_Sched_send(const void *buf, MPI_Aint count, MPI_Datatype datatype, int dest,
MPIR_Comm * comm, MPIR_Sched_t s);
MPIR_Comm * comm, int collattr, MPIR_Sched_t s);
int MPIR_Sched_recv(void *buf, MPI_Aint count, MPI_Datatype datatype, int src, MPIR_Comm * comm,
MPIR_Sched_t s);
int collattr, MPIR_Sched_t s);

/* just like MPI_Issend, can't complete until the matching recv is posted */
int MPIR_Sched_ssend(const void *buf, MPI_Aint count, MPI_Datatype datatype, int dest,
MPIR_Comm * comm, MPIR_Sched_t s);
MPIR_Comm * comm, int collattr, MPIR_Sched_t s);

int MPIR_Sched_reduce(const void *inbuf, void *inoutbuf, MPI_Aint count, MPI_Datatype datatype,
MPI_Op op, MPIR_Sched_t s);
Expand Down Expand Up @@ -104,12 +104,12 @@ int MPIR_Sched_barrier(MPIR_Sched_t s);
* is no known use case. The recv count is just an upper bound, not an exact
* amount to be received, so an oversized recv is used instead of deferral. */
int MPIR_Sched_send_defer(const void *buf, const MPI_Aint * count, MPI_Datatype datatype, int dest,
MPIR_Comm * comm, MPIR_Sched_t s);
MPIR_Comm * comm, int collattr, MPIR_Sched_t s);
/* Just like MPIR_Sched_recv except it populates the given status object with
* the received count and error information, much like a normal recv. Often
* useful in conjunction with MPIR_Sched_send_defer. */
int MPIR_Sched_recv_status(void *buf, MPI_Aint count, MPI_Datatype datatype, int src,
MPIR_Comm * comm, MPI_Status * status, MPIR_Sched_t s);
MPIR_Comm * comm, MPI_Status * status, int collattr, MPIR_Sched_t s);

/* buffer management, fancy reductions, etc */
int MPIR_Sched_cb(MPIR_Sched_cb_t * cb_p, void *cb_state, MPIR_Sched_t s);
Expand Down
4 changes: 2 additions & 2 deletions src/mpi/coll/algorithms/recexchalgo/recexchalgo.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ int MPIR_TSP_Iallgatherv_sched_intra_recexch_step2(int step1_sendto, int step2_n
size_t recv_extent, const MPI_Aint * recvcounts,
const MPI_Aint * displs, MPI_Datatype recvtype,
int is_dist_halving, MPIR_Comm * comm,
MPIR_TSP_sched_t sched);
int collattr, MPIR_TSP_sched_t sched);
int MPIR_TSP_Ireduce_scatter_sched_intra_recexch_step2(void *tmp_results, void *tmp_recvbuf,
const MPI_Aint * recvcounts,
MPI_Aint * displs, MPI_Datatype datatype,
Expand All @@ -36,6 +36,6 @@ int MPIR_TSP_Ireduce_scatter_sched_intra_recexch_step2(void *tmp_results, void *
int step2_nphases, int **step2_nbrs,
int rank, int nranks, int sink_id,
int is_out_vtcs, int *reduce_id_,
MPIR_TSP_sched_t sched);
int collattr, MPIR_TSP_sched_t sched);

#endif /* RECEXCHALGO_H_INCLUDED */
7 changes: 3 additions & 4 deletions src/mpi/coll/allgather/allgather_allcomm_nb.c
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,14 @@

int MPIR_Allgather_allcomm_nb(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype,
void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype,
MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag)
MPIR_Comm * comm_ptr, int collattr)
{
int mpi_errno = MPI_SUCCESS;
MPIR_Request *req_ptr = NULL;

/* just call the nonblocking version and wait on it */
mpi_errno =
MPIR_Iallgather(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm_ptr,
&req_ptr);
mpi_errno = MPIR_Iallgather(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype,
comm_ptr, collattr, &req_ptr);
MPIR_ERR_CHECK(mpi_errno);

mpi_errno = MPIC_Wait(req_ptr);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@
int MPIR_Allgather_inter_local_gather_remote_bcast(const void *sendbuf, MPI_Aint sendcount,
MPI_Datatype sendtype, void *recvbuf,
MPI_Aint recvcount, MPI_Datatype recvtype,
MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag)
MPIR_Comm * comm_ptr, int collattr)
{
int rank, local_size, remote_size, mpi_errno = MPI_SUCCESS, root;
int mpi_errno_ret = MPI_SUCCESS;
int errflag = 0;
MPI_Aint sendtype_sz;
void *tmp_buf = NULL;
MPIR_Comm *newcomm_ptr = NULL;
Expand Down Expand Up @@ -48,7 +49,7 @@ int MPIR_Allgather_inter_local_gather_remote_bcast(const void *sendbuf, MPI_Aint

if (sendcount != 0) {
mpi_errno = MPIR_Gather(sendbuf, sendcount, sendtype, tmp_buf, sendcount * sendtype_sz,
MPI_BYTE, 0, newcomm_ptr, errflag);
MPI_BYTE, 0, newcomm_ptr, collattr | errflag);
MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, errflag, mpi_errno_ret);
}

Expand All @@ -59,31 +60,31 @@ int MPIR_Allgather_inter_local_gather_remote_bcast(const void *sendbuf, MPI_Aint
if (sendcount != 0) {
root = (rank == 0) ? MPI_ROOT : MPI_PROC_NULL;
mpi_errno = MPIR_Bcast(tmp_buf, sendcount * sendtype_sz * local_size,
MPI_BYTE, root, comm_ptr, errflag);
MPI_BYTE, root, comm_ptr, collattr | errflag);
MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, errflag, mpi_errno_ret);
}

/* receive bcast from right */
if (recvcount != 0) {
root = 0;
mpi_errno = MPIR_Bcast(recvbuf, recvcount * remote_size,
recvtype, root, comm_ptr, errflag);
recvtype, root, comm_ptr, collattr | errflag);
MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, errflag, mpi_errno_ret);
}
} else {
/* receive bcast from left */
if (recvcount != 0) {
root = 0;
mpi_errno = MPIR_Bcast(recvbuf, recvcount * remote_size,
recvtype, root, comm_ptr, errflag);
recvtype, root, comm_ptr, collattr | errflag);
MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, errflag, mpi_errno_ret);
}

/* bcast to left */
if (sendcount != 0) {
root = (rank == 0) ? MPI_ROOT : MPI_PROC_NULL;
mpi_errno = MPIR_Bcast(tmp_buf, sendcount * sendtype_sz * local_size,
MPI_BYTE, root, comm_ptr, errflag);
MPI_BYTE, root, comm_ptr, collattr | errflag);
MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, errflag, mpi_errno_ret);
}
}
Expand Down
Loading