diff --git a/maint/gen_coll.py b/maint/gen_coll.py index 10fd49d9086..cf77d6fea5a 100644 --- a/maint/gen_coll.py +++ b/maint/gen_coll.py @@ -94,6 +94,8 @@ def dump_allcomm_auto_blocking(name): dump_open("MPIR_Csel_coll_sig_s coll_sig = {") G.out.append(".coll_type = MPIR_CSEL_COLL_TYPE__%s," % NAME) G.out.append(".comm_ptr = comm_ptr,") + if not re.match(r'i?neighbor_', func_name, re.IGNORECASE): + G.out.append(".coll_group = coll_group,") for p in func['parameters']: if not re.match(r'comm$', p['name']): G.out.append(".u.%s.%s = %s," % (func_name, p['name'], p['name'])) @@ -163,12 +165,16 @@ def dump_allcomm_sched_auto(name): dump_split(0, "int MPIR_%s_allcomm_sched_auto(%s)" % (Name, func_params)) dump_open('{') G.out.append("int mpi_errno = MPI_SUCCESS;") + if re.match(r'Ineighbor_', Name): + G.out.append("int coll_group = MPIR_SUBGROUP_NONE;") G.out.append("") # -- Csel_search dump_open("MPIR_Csel_coll_sig_s coll_sig = {") G.out.append(".coll_type = MPIR_CSEL_COLL_TYPE__%s," % NAME) G.out.append(".comm_ptr = comm_ptr,") + if not re.match(r'i?neighbor_', func_name, re.IGNORECASE): + G.out.append(".coll_group = coll_group,") for p in func['parameters']: if not re.match(r'comm$', p['name']): G.out.append(".u.%s.%s = %s," % (func_name, p['name'], p['name'])) @@ -363,6 +369,8 @@ def dump_cases(commkind): dump_split(0, "int MPIR_%s_sched_impl(%s)" % (Name, func_params)) dump_open('{') G.out.append("int mpi_errno = MPI_SUCCESS;") + if re.match(r'Ineighbor_', Name): + G.out.append("int coll_group = MPIR_SUBGROUP_NONE;") G.out.append("") dump_open("if (comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM) {") @@ -552,20 +560,22 @@ def dump_fallback(algo): elif a == "noinplace": cond_list.append("sendbuf != MPI_IN_PLACE") elif a == "power-of-two": - cond_list.append("comm_ptr->local_size == comm_ptr->coll.pof2") + cond_list.append("MPL_is_pof2(MPIR_Coll_size(comm_ptr, coll_group))") elif a == "size-ge-pof2": - cond_list.append("count >= comm_ptr->coll.pof2") + cond_list.append("count >= MPL_pof2(MPIR_Coll_size(comm_ptr, coll_group))") elif a == "commutative": cond_list.append("MPIR_Op_is_commutative(op)") elif a== "builtin-op": cond_list.append("HANDLE_IS_BUILTIN(op)") elif a == "parent-comm": - cond_list.append("MPIR_Comm_is_parent_comm(comm_ptr)") + cond_list.append("MPIR_Comm_is_parent_comm(comm_ptr, coll_group)") elif a == "node-consecutive": cond_list.append("MPII_Comm_is_node_consecutive(comm_ptr)") elif a == "displs-ordered": # assume it's allgatherv cond_list.append("MPII_Iallgatherv_is_displs_ordered(comm_ptr->local_size, recvcounts, displs)") + elif a == "nogroup": + cond_list.append("coll_group == MPIR_SUBGROUP_NONE") else: raise Exception("Unsupported restrictions - %s" % a) (func_name, commkind) = algo['func-commkind'].split('-') @@ -644,6 +654,9 @@ def get_algo_extra_params(algo): # additional wrappers def get_algo_args(args, algo, kind): algo_args = args + if not re.match(r'i?neighbor_', algo['func-commkind']): + algo_args += ", coll_group" + if 'extra_params' in algo: algo_args += ", " + get_algo_extra_args(algo, kind) @@ -658,6 +671,9 @@ def get_algo_args(args, algo, kind): def get_algo_params(params, algo): algo_params = params + if not re.match(r'i?neighbor_', algo['func-commkind']): + algo_params += ", int coll_group" + if 'extra_params' in algo: algo_params += ", " + get_algo_extra_params(algo) @@ -681,6 +697,8 @@ def get_algo_name(algo): def get_func_params(params, name, kind): func_params = params + if not name.startswith('neighbor_'): + func_params += ", int coll_group" if kind == "blocking": if not name.startswith('neighbor_'): func_params += ", MPIR_Errflag_t errflag" @@ -701,6 +719,8 @@ def get_func_params(params, name, kind): def get_func_args(args, name, kind): func_args = args + if not name.startswith('neighbor_'): + func_args += ", coll_group" if kind == "blocking": if not name.startswith('neighbor_'): func_args += ", errflag" diff --git a/maint/local_python/binding_c.py b/maint/local_python/binding_c.py index f219ee4e194..f712618232a 100644 --- a/maint/local_python/binding_c.py +++ b/maint/local_python/binding_c.py @@ -1686,6 +1686,8 @@ def push_impl_decl(func, impl_name=None): if func['_impl_param_list']: params = ', '.join(func['_impl_param_list']) if func['dir'] == 'coll': + if not RE.match(r'MPI_(Ineighbor|Neighbor)', func['name']): + params = params.replace('comm_ptr', 'comm_ptr, int coll_group') # block collective use an extra errflag if not RE.match(r'MPI_(I.*|Neighbor.*|.*_init)$', func['name']): params = params + ", MPIR_Errflag_t errflag" @@ -1726,6 +1728,8 @@ def dump_body_coll(func): mpir_name = re.sub(r'^MPIX?_', 'MPIR_', func['name']) args = ", ".join(func['_impl_arg_list']) + if not RE.match(r'MPI_(Ineighbor|Neighbor)', func['name']): + args = args.replace('comm_ptr', 'comm_ptr, MPIR_SUBGROUP_NONE') if RE.match(r'MPI_(I.*|.*_init)$', func['name'], re.IGNORECASE): # non-blocking collectives @@ -1956,6 +1960,7 @@ def dump_body_reduce_equal(func): args = ", ".join(func['_impl_arg_list']) args = re.sub(r'recvbuf, ', '', args) args = re.sub(r'op, ', 'recvbuf, ', args) + args += ", MPIR_SUBGROUP_NONE" dump_line_with_break("mpi_errno = %s(%s);" % (impl, args)) dump_error_check("") diff --git a/src/binding/c/comm_api.txt b/src/binding/c/comm_api.txt index 58dbabf4169..a827c66e7ff 100644 --- a/src/binding/c/comm_api.txt +++ b/src/binding/c/comm_api.txt @@ -301,7 +301,7 @@ MPI_Intercomm_merge: * error to make */ acthigh = high ? 1 : 0; /* Clamp high into 1 or 0 */ mpi_errno = MPIR_Allreduce(MPI_IN_PLACE, &acthigh, 1, MPI_INT, - MPI_SUM, intercomm_ptr->local_comm, MPIR_ERR_NONE); + MPI_SUM, intercomm_ptr->local_comm, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); /* acthigh must either == 0 or the size of the local comm */ if (acthigh != 0 && acthigh != intercomm_ptr->local_size) { diff --git a/src/include/mpir_coll.h b/src/include/mpir_coll.h index 4038a272017..8e8cd0111cc 100644 --- a/src/include/mpir_coll.h +++ b/src/include/mpir_coll.h @@ -8,6 +8,52 @@ #include "coll_impl.h" #include "coll_algos.h" +#include "mpir_threadcomm.h" + +#ifdef ENABLE_THREADCOMM +#define MPIR_THREADCOMM_RANK_SIZE(comm, rank_, size_) do { \ + MPIR_Threadcomm *threadcomm = (comm)->threadcomm; \ + MPIR_Assert(threadcomm); \ + int intracomm_size = (comm)->local_size; \ + size_ = threadcomm->rank_offset_table[intracomm_size - 1]; \ + rank_ = MPIR_THREADCOMM_TID_TO_RANK(threadcomm, MPIR_threadcomm_get_tid(threadcomm)); \ + } while (0) +#else +#define MPIR_THREADCOMM_RANK_SIZE(comm, rank_, size_) do { \ + MPIR_Assert(0); \ + size_ = 0; \ + rank_ = -1; \ + } while (0) +#endif + +#define MPIR_COLL_RANK_SIZE(comm, coll_group, rank_, size_) do { \ + if (coll_group == MPIR_SUBGROUP_NONE) { \ + rank_ = (comm)->rank; \ + size_ = (comm)->local_size; \ + } else if (coll_group == MPIR_SUBGROUP_THREADCOMM) { \ + MPIR_THREADCOMM_RANK_SIZE(comm, rank_, size_); \ + } else { \ + rank_ = (comm)->subgroups[coll_group].rank; \ + size_ = (comm)->subgroups[coll_group].size; \ + } \ + } while (0) + +/* sometime it is convenient to just get the rank or size */ +static inline int MPIR_Coll_size(MPIR_Comm * comm, int coll_group) +{ + int rank, size; + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, size); + (void) rank; + return size; +} + +static inline int MPIR_Coll_rank(MPIR_Comm * comm, int coll_group) +{ + int rank, size; + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, size); + (void) size; + return rank; +} /* During init, not all algorithms are safe to use. For example, the csel * may not have been initialized. We define a set of fallback routines that @@ -28,36 +74,41 @@ int MPIC_Wait(MPIR_Request * request_ptr); int MPIC_Probe(int source, int tag, MPI_Comm comm, MPI_Status * status); int MPIC_Send(const void *buf, MPI_Aint count, MPI_Datatype datatype, int dest, int tag, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag); + MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag); int MPIC_Recv(void *buf, MPI_Aint count, MPI_Datatype datatype, int source, int tag, - MPIR_Comm * comm_ptr, MPI_Status * status); + MPIR_Comm * comm_ptr, int coll_group, MPI_Status * status); int MPIC_Sendrecv(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, int dest, int sendtag, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int source, int recvtag, - MPIR_Comm * comm_ptr, MPI_Status * status, MPIR_Errflag_t errflag); -int MPIC_Sendrecv_replace(void *buf, MPI_Aint count, MPI_Datatype datatype, - int dest, int sendtag, - int source, int recvtag, - MPIR_Comm * comm_ptr, MPI_Status * status, MPIR_Errflag_t errflag); + MPIR_Comm * comm_ptr, int coll_group, MPI_Status * status, + MPIR_Errflag_t errflag); +int MPIC_Sendrecv_replace(void *buf, MPI_Aint count, MPI_Datatype datatype, int dest, int sendtag, + int source, int recvtag, MPIR_Comm * comm_ptr, int coll_group, + MPI_Status * status, MPIR_Errflag_t errflag); int MPIC_Isend(const void *buf, MPI_Aint count, MPI_Datatype datatype, int dest, int tag, - MPIR_Comm * comm_ptr, MPIR_Request ** request, MPIR_Errflag_t errflag); -int MPIC_Irecv(void *buf, MPI_Aint count, MPI_Datatype datatype, int source, - int tag, MPIR_Comm * comm_ptr, MPIR_Request ** request); + MPIR_Comm * comm_ptr, int coll_group, MPIR_Request ** request, + MPIR_Errflag_t errflag); +int MPIC_Irecv(void *buf, MPI_Aint count, MPI_Datatype datatype, int source, int tag, + MPIR_Comm * comm_ptr, int coll_group, MPIR_Request ** request); int MPIC_Waitall(int numreq, MPIR_Request * requests[], MPI_Status * statuses); int MPIR_Reduce_local(const void *inbuf, void *inoutbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op); -int MPIR_Barrier_intra_dissemination(MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag); +int MPIR_Barrier_intra_dissemination(MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag); /* TSP auto */ int MPIR_TSP_Iallreduce_sched_intra_tsp_auto(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm, MPIR_TSP_sched_t sched); + MPIR_Comm * comm, int coll_group, + MPIR_TSP_sched_t sched); int MPIR_TSP_Ibcast_sched_intra_tsp_auto(void *buffer, MPI_Aint count, MPI_Datatype datatype, - int root, MPIR_Comm * comm_ptr, MPIR_TSP_sched_t sched); -int MPIR_TSP_Ibarrier_sched_intra_tsp_auto(MPIR_Comm * comm, MPIR_TSP_sched_t sched); + int root, MPIR_Comm * comm_ptr, int coll_group, + MPIR_TSP_sched_t sched); +int MPIR_TSP_Ibarrier_sched_intra_tsp_auto(MPIR_Comm * comm, int coll_group, + MPIR_TSP_sched_t sched); int MPIR_TSP_Ireduce_sched_intra_tsp_auto(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, int root, - MPIR_Comm * comm_ptr, MPIR_TSP_sched_t sched); + MPIR_Comm * comm_ptr, int coll_group, + MPIR_TSP_sched_t sched); #endif /* MPIR_COLL_H_INCLUDED */ diff --git a/src/include/mpir_comm.h b/src/include/mpir_comm.h index 58cb5b2af85..1a2292199f0 100644 --- a/src/include/mpir_comm.h +++ b/src/include/mpir_comm.h @@ -101,6 +101,51 @@ enum MPIR_COMM_HINT_PREDEFINED_t { MPIR_COMM_HINT_PREDEFINED_COUNT }; +/* MPIR_Subgroup is similar to MPIR_Group, but only used to describe subgroups within + * an intra communicator. The proc_table refers to ranks within the communicator. + * It is only used internally for group collectives. + */ +typedef struct MPIR_Subgroup { + int size; + int rank; + int *proc_table; /* can be NULL if the group is trivial */ +} MPIR_Subgroup; + +#define MPIR_MAX_SUBGROUPS 16 + +/* reserved subgroup indexes */ +enum { + MPIR_SUBGROUP_THREADCOMM = -1, + MPIR_SUBGROUP_NONE = 0, + MPIR_SUBGROUP_NODE, /* i.e. nodecomm */ + MPIR_SUBGROUP_NODE_CROSS, /* node_roots_comm, node_rank_1_comm, ... */ + MPIR_SUBGROUP_NUMA1, /* 1-level below node in topology */ + MPIR_SUBGROUP_NUMA1_CROSS, /* cross-link group at NUMA1 within NODE */ + MPIR_SUBGROUP_NUMA2, /* and so on */ + MPIR_SUBGROUP_NUMA2_CROSS, + MPIR_SUBGROUP_NUM_RESERVED, +}; + +/* macros to create dynamic subgroups. + * It is expected to fillout the proc_table after MPIR_COMM_PUSH_SUBGROUP. + */ +#define MPIR_COMM_PUSH_SUBGROUP(comm, _size, _rank, newgrp, proc_table_out) \ + do { \ + (newgrp) = (comm)->num_subgroups++; \ + MPIR_Assert((comm)->num_subgroups < MPIR_MAX_SUBGROUPS); \ + (comm)->subgroups[newgrp].size = _size; \ + (comm)->subgroups[newgrp].rank = _rank; \ + (proc_table_out) = MPL_malloc((_size) * sizeof(int), MPL_MEM_OTHER); \ + (comm)->subgroups[newgrp].proc_table = (proc_table_out); \ + } while (0) + +#define MPIR_COMM_POP_SUBGROUP(comm) \ + do { \ + int i = --(comm)->num_subgroups; \ + MPIR_Assert(i > 0); \ + MPL_free((comm)->subgroups[i].proc_table); \ + } while (0) + /*S MPIR_Comm - Description of the Communicator data structure @@ -187,7 +232,8 @@ struct MPIR_Comm { int *internode_table; /* internode_table[i] gives the rank in * node_roots_comm of rank i in this comm. * It is of size 'local_size'. */ - int node_count; /* number of nodes this comm is spread over */ + int num_local; /* number of procs in this comm on local node */ + int num_external; /* number of nodes this comm is spread over */ int is_low_group; /* For intercomms only, this boolean is * set for all members of one of the @@ -196,6 +242,8 @@ struct MPIR_Comm { * intercommunicator collective operations * that wish to use half-duplex operations * to implement a full-duplex operation */ + MPIR_Subgroup subgroups[MPIR_MAX_SUBGROUPS]; + int num_subgroups; struct MPIR_Comm *comm_next; /* Provides a chain through all active * communicators */ @@ -222,9 +270,6 @@ struct MPIR_Comm { * use int array for fast access */ struct { - int pof2; /* Nearest (smaller than or equal to) power of 2 - * to the number of ranks in the communicator. - * To be used during collective communication */ int pofk[MAX_RADIX - 1]; int k[MAX_RADIX - 1]; int step1_sendto[MAX_RADIX - 1]; @@ -234,18 +279,9 @@ struct MPIR_Comm { int **step2_nbrs[MAX_RADIX - 1]; int nbrs_defined[MAX_RADIX - 1]; void **recexch_allreduce_nbr_buffer; - int topo_aware_tree_root; - int topo_aware_tree_k; - MPIR_Treealgo_tree_t *topo_aware_tree; - int topo_aware_k_tree_root; - int topo_aware_k_tree_k; - MPIR_Treealgo_tree_t *topo_aware_k_tree; - int topo_wave_tree_root; - int topo_wave_tree_overhead; - int topo_wave_tree_lat_diff_groups; - int topo_wave_tree_lat_diff_switches; - int topo_wave_tree_lat_same_switches; - MPIR_Treealgo_tree_t *topo_wave_tree; + + MPIR_Treealgo_tree_t *cached_tree; + MPIR_Treealgo_param_t cached_tree_param; } coll; void *csel_comm; /* collective selector handle */ @@ -374,7 +410,7 @@ int MPIR_Comm_create_inter(MPIR_Comm * comm_ptr, MPIR_Group * group_ptr, MPIR_Co int MPIR_Comm_create_subcomms(MPIR_Comm * comm); int MPIR_Comm_commit(MPIR_Comm *); -int MPIR_Comm_is_parent_comm(MPIR_Comm *); +int MPIR_Comm_is_parent_comm(MPIR_Comm * comm, int coll_group); /* peer intercomm is an internal 1-to-1 intercomm used for connecting dynamic processes */ int MPIR_peer_intercomm_create(MPIR_Context_id_t context_id, MPIR_Context_id_t recvcontext_id, diff --git a/src/include/mpir_csel.h b/src/include/mpir_csel.h index 07f061e98ef..c25193bb2fe 100644 --- a/src/include/mpir_csel.h +++ b/src/include/mpir_csel.h @@ -60,6 +60,7 @@ typedef enum { typedef struct { MPIR_Csel_coll_type_e coll_type; MPIR_Comm *comm_ptr; + int coll_group; union { struct { diff --git a/src/include/mpir_nbc.h b/src/include/mpir_nbc.h index eb08995fe04..710521d9a5d 100644 --- a/src/include/mpir_nbc.h +++ b/src/include/mpir_nbc.h @@ -45,7 +45,7 @@ /* Open question: should tag allocation be rolled into Sched_start? Keeping it * separate potentially allows more parallelism in the future, but it also * pushes more work onto the clients of this interface. */ -int MPIR_Sched_next_tag(MPIR_Comm * comm_ptr, int *tag); +int MPIR_Sched_next_tag(MPIR_Comm * comm_ptr, int coll_group, int *tag); void MPIR_Sched_set_tag(MPIR_Sched_t s, int tag); /* the device must provide a typedef for MPIR_Sched_t in mpidpre.h */ @@ -69,13 +69,13 @@ int MPIR_Sched_start(MPIR_Sched_t s, MPIR_Comm * comm, MPIR_Request ** req); /* send and recv take a comm ptr to enable hierarchical collectives */ int MPIR_Sched_send(const void *buf, MPI_Aint count, MPI_Datatype datatype, int dest, - MPIR_Comm * comm, MPIR_Sched_t s); + MPIR_Comm * comm, int coll_group, MPIR_Sched_t s); int MPIR_Sched_recv(void *buf, MPI_Aint count, MPI_Datatype datatype, int src, MPIR_Comm * comm, - MPIR_Sched_t s); + int coll_group, MPIR_Sched_t s); /* just like MPI_Issend, can't complete until the matching recv is posted */ int MPIR_Sched_ssend(const void *buf, MPI_Aint count, MPI_Datatype datatype, int dest, - MPIR_Comm * comm, MPIR_Sched_t s); + MPIR_Comm * comm, int coll_group, MPIR_Sched_t s); int MPIR_Sched_reduce(const void *inbuf, void *inoutbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, MPIR_Sched_t s); @@ -104,12 +104,12 @@ int MPIR_Sched_barrier(MPIR_Sched_t s); * is no known use case. The recv count is just an upper bound, not an exact * amount to be received, so an oversized recv is used instead of deferral. */ int MPIR_Sched_send_defer(const void *buf, const MPI_Aint * count, MPI_Datatype datatype, int dest, - MPIR_Comm * comm, MPIR_Sched_t s); + MPIR_Comm * comm, int coll_group, MPIR_Sched_t s); /* Just like MPIR_Sched_recv except it populates the given status object with * the received count and error information, much like a normal recv. Often * useful in conjunction with MPIR_Sched_send_defer. */ int MPIR_Sched_recv_status(void *buf, MPI_Aint count, MPI_Datatype datatype, int src, - MPIR_Comm * comm, MPI_Status * status, MPIR_Sched_t s); + MPIR_Comm * comm, int coll_group, MPI_Status * status, MPIR_Sched_t s); /* buffer management, fancy reductions, etc */ int MPIR_Sched_cb(MPIR_Sched_cb_t * cb_p, void *cb_state, MPIR_Sched_t s); diff --git a/src/include/mpir_op.h b/src/include/mpir_op.h index ec04103b3e1..c940aacc06d 100644 --- a/src/include/mpir_op.h +++ b/src/include/mpir_op.h @@ -235,8 +235,8 @@ int MPIR_Op_is_commutative(MPI_Op); MPI_Datatype MPIR_Op_get_alt_datatype(MPI_Op op, MPI_Datatype datatype); int MPIR_Reduce_equal(const void *sendbuf, MPI_Aint count, MPI_Datatype datatype, - int *is_equal, int root, MPIR_Comm * comm_ptr); + int *is_equal, int root, MPIR_Comm * comm_ptr, int coll_group); int MPIR_Allreduce_equal(const void *sendbuf, MPI_Aint count, MPI_Datatype datatype, - int *is_equal, MPIR_Comm * comm_ptr); + int *is_equal, MPIR_Comm * comm_ptr, int coll_group); #endif /* MPIR_OP_H_INCLUDED */ diff --git a/src/include/mpir_threadcomm.h b/src/include/mpir_threadcomm.h index cda298f1f9e..90aefcf0a9b 100644 --- a/src/include/mpir_threadcomm.h +++ b/src/include/mpir_threadcomm.h @@ -110,28 +110,6 @@ MPL_STATIC_INLINE_PREFIX #endif /* ENABLE_THREADCOMM */ } -#ifdef ENABLE_THREADCOMM -#define MPIR_THREADCOMM_RANK_SIZE(comm, rank_, size_) do { \ - MPIR_Threadcomm *threadcomm = (comm)->threadcomm; \ - if (threadcomm) { \ - int intracomm_size = (comm)->local_size; \ - size_ = threadcomm->rank_offset_table[intracomm_size - 1]; \ - rank_ = MPIR_THREADCOMM_TID_TO_RANK(threadcomm, MPIR_threadcomm_get_tid(threadcomm)); \ - } else { \ - rank_ = (comm)->rank; \ - size_ = (comm)->local_size; \ - } \ - } while (0) - -#else -#define MPIR_THREADCOMM_RANK_SIZE(comm, rank_, size_) do { \ - MPIR_Assert((comm)->threadcomm == NULL); \ - rank_ = (comm)->rank; \ - size_ = (comm)->local_size; \ - } while (0) - -#endif - #ifdef ENABLE_THREADCOMM typedef struct MPIR_threadcomm_tls_t { MPIR_Threadcomm *threadcomm; diff --git a/src/mpi/coll/algorithms/recexchalgo/recexchalgo.c b/src/mpi/coll/algorithms/recexchalgo/recexchalgo.c index 61042c75671..759be0cd6e2 100644 --- a/src/mpi/coll/algorithms/recexchalgo/recexchalgo.c +++ b/src/mpi/coll/algorithms/recexchalgo/recexchalgo.c @@ -26,19 +26,7 @@ int MPII_Recexchalgo_comm_init(MPIR_Comm * comm) } comm->coll.recexch_allreduce_nbr_buffer = NULL; - comm->coll.topo_aware_tree_root = -1; - comm->coll.topo_aware_tree_k = 0; - comm->coll.topo_aware_tree = NULL; - comm->coll.topo_aware_k_tree_root = -1; - comm->coll.topo_aware_k_tree_k = 0; - comm->coll.topo_aware_k_tree = NULL; - comm->coll.topo_wave_tree_root = -1; - comm->coll.topo_wave_tree = NULL; - comm->coll.topo_wave_tree_overhead = 0; - comm->coll.topo_wave_tree_lat_diff_groups = 0; - comm->coll.topo_wave_tree_lat_diff_switches = 0; - comm->coll.topo_wave_tree_lat_same_switches = 0; - + comm->coll.cached_tree = NULL; return mpi_errno; } @@ -66,22 +54,10 @@ int MPII_Recexchalgo_comm_cleanup(MPIR_Comm * comm) MPL_free(comm->coll.recexch_allreduce_nbr_buffer); } - if (comm->coll.topo_aware_tree) { - MPIR_Treealgo_tree_free(comm->coll.topo_aware_tree); - MPL_free(comm->coll.topo_aware_tree); - comm->coll.topo_aware_tree = NULL; - } - - if (comm->coll.topo_aware_k_tree) { - MPIR_Treealgo_tree_free(comm->coll.topo_aware_k_tree); - MPL_free(comm->coll.topo_aware_k_tree); - comm->coll.topo_aware_k_tree = NULL; - } - - if (comm->coll.topo_wave_tree) { - MPIR_Treealgo_tree_free(comm->coll.topo_wave_tree); - MPL_free(comm->coll.topo_wave_tree); - comm->coll.topo_wave_tree = NULL; + if (comm->coll.cached_tree) { + MPIR_Treealgo_tree_free(comm->coll.cached_tree); + MPL_free(comm->coll.cached_tree); + comm->coll.cached_tree = NULL; } return mpi_errno; diff --git a/src/mpi/coll/algorithms/recexchalgo/recexchalgo.h b/src/mpi/coll/algorithms/recexchalgo/recexchalgo.h index 8b8fb8fc40d..eab92d1decb 100644 --- a/src/mpi/coll/algorithms/recexchalgo/recexchalgo.h +++ b/src/mpi/coll/algorithms/recexchalgo/recexchalgo.h @@ -27,15 +27,15 @@ int MPIR_TSP_Iallgatherv_sched_intra_recexch_step2(int step1_sendto, int step2_n size_t recv_extent, const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, int is_dist_halving, MPIR_Comm * comm, - MPIR_TSP_sched_t sched); + int coll_group, MPIR_TSP_sched_t sched); int MPIR_TSP_Ireduce_scatter_sched_intra_recexch_step2(void *tmp_results, void *tmp_recvbuf, const MPI_Aint * recvcounts, MPI_Aint * displs, MPI_Datatype datatype, MPI_Op op, size_t extent, int tag, - MPIR_Comm * comm, int k, int is_dist_halving, - int step2_nphases, int **step2_nbrs, - int rank, int nranks, int sink_id, - int is_out_vtcs, int *reduce_id_, - MPIR_TSP_sched_t sched); + MPIR_Comm * comm, int coll_group, int k, + int is_dist_halving, int step2_nphases, + int **step2_nbrs, int rank, int nranks, + int sink_id, int is_out_vtcs, + int *reduce_id_, MPIR_TSP_sched_t sched); #endif /* RECEXCHALGO_H_INCLUDED */ diff --git a/src/mpi/coll/algorithms/treealgo/treealgo.c b/src/mpi/coll/algorithms/treealgo/treealgo.c index 25a291c5f1c..d806e3b8f51 100644 --- a/src/mpi/coll/algorithms/treealgo/treealgo.c +++ b/src/mpi/coll/algorithms/treealgo/treealgo.c @@ -33,6 +33,62 @@ int MPII_Treealgo_comm_cleanup(MPIR_Comm * comm) return mpi_errno; } +static bool match_param_topo_aware(MPIR_Treealgo_param_t * param, int coll_group, int root, int k) +{ + return (param->type == MPIR_TREE_TYPE_TOPOLOGY_AWARE && + param->coll_group == coll_group && coll_group < MPIR_SUBGROUP_NUM_RESERVED && + param->root == root && param->u.topo_aware.k == k); +} + +static void set_param_topo_aware(MPIR_Treealgo_param_t * param, int coll_group, int root, int k) +{ + param->type = MPIR_TREE_TYPE_TOPOLOGY_AWARE; + param->coll_group = coll_group; + param->root = root; + param->u.topo_aware.k = k; +} + +static bool match_param_topo_aware_k(MPIR_Treealgo_param_t * param, int coll_group, int root, int k) +{ + return (param->type == MPIR_TREE_TYPE_TOPOLOGY_AWARE_K && + param->coll_group == coll_group && coll_group < MPIR_SUBGROUP_NUM_RESERVED && + param->root == root && param->u.topo_aware.k == k); +} + +static void set_param_topo_aware_k(MPIR_Treealgo_param_t * param, int coll_group, int root, int k) +{ + param->type = MPIR_TREE_TYPE_TOPOLOGY_AWARE_K; + param->coll_group = coll_group; + param->root = root; + param->u.topo_aware.k = k; +} + +static inline bool match_param_topo_wave(MPIR_Treealgo_param_t * param, int coll_group, + int root, int overhead, int lat_diff_groups, + int lat_diff_switches, int lat_same_switches) +{ + return (param->type == MPIR_TREE_TYPE_TOPOLOGY_WAVE && + param->coll_group == coll_group && coll_group < MPIR_SUBGROUP_NUM_RESERVED && + param->root == root && + param->u.topo_wave.overhead == overhead && + param->u.topo_wave.lat_diff_groups == lat_diff_groups && + param->u.topo_wave.lat_diff_switches == lat_diff_switches && + param->u.topo_wave.lat_same_switches == lat_same_switches); +} + +static inline void set_param_topo_wave(MPIR_Treealgo_param_t * param, int coll_group, + int root, int overhead, int lat_diff_groups, + int lat_diff_switches, int lat_same_switches) +{ + param->type = MPIR_TREE_TYPE_TOPOLOGY_WAVE; + param->coll_group = coll_group; + param->root = root; + param->u.topo_wave.overhead = overhead; + param->u.topo_wave.lat_diff_groups = lat_diff_groups; + param->u.topo_wave.lat_diff_switches = lat_diff_switches; + param->u.topo_wave.lat_same_switches = lat_same_switches; +} + int MPIR_Treealgo_tree_create(int rank, int nranks, int tree_type, int k, int root, MPIR_Treealgo_tree_t * ct) @@ -75,7 +131,8 @@ int MPIR_Treealgo_tree_create(int rank, int nranks, int tree_type, int k, int ro } -int MPIR_Treealgo_tree_create_topo_aware(MPIR_Comm * comm, int tree_type, int k, int root, +int MPIR_Treealgo_tree_create_topo_aware(MPIR_Comm * comm, int coll_group, int tree_type, + int k, int root, bool enable_reorder, MPIR_Treealgo_tree_t * ct) { int mpi_errno = MPI_SUCCESS; @@ -84,56 +141,53 @@ int MPIR_Treealgo_tree_create_topo_aware(MPIR_Comm * comm, int tree_type, int k, switch (tree_type) { case MPIR_TREE_TYPE_TOPOLOGY_AWARE: - if (!comm->coll.topo_aware_tree || root != comm->coll.topo_aware_tree_root - || k != comm->coll.topo_aware_tree_k) { - if (comm->coll.topo_aware_tree) { - MPIR_Treealgo_tree_free(comm->coll.topo_aware_tree); + if (!comm->coll.cached_tree || + !match_param_topo_aware(&comm->coll.cached_tree_param, coll_group, root, k)) { + if (comm->coll.cached_tree) { + MPIR_Treealgo_tree_free(comm->coll.cached_tree); } else { - comm->coll.topo_aware_tree = + comm->coll.cached_tree = (MPIR_Treealgo_tree_t *) MPL_malloc(sizeof(MPIR_Treealgo_tree_t), MPL_MEM_BUFFER); } mpi_errno = - MPII_Treeutil_tree_topology_aware_init(comm, k, root, enable_reorder, - comm->coll.topo_aware_tree); + MPII_Treeutil_tree_topology_aware_init(comm, coll_group, k, root, + enable_reorder, comm->coll.cached_tree); MPIR_ERR_CHECK(mpi_errno); - *ct = *comm->coll.topo_aware_tree; - comm->coll.topo_aware_tree_root = root; - comm->coll.topo_aware_tree_k = k; + *ct = *comm->coll.cached_tree; + set_param_topo_aware(&comm->coll.cached_tree_param, coll_group, root, k); } - *ct = *comm->coll.topo_aware_tree; + *ct = *comm->coll.cached_tree; utarray_new(ct->children, &ut_int_icd, MPL_MEM_COLL); for (int i = 0; i < ct->num_children; i++) { utarray_push_back(ct->children, - &ut_int_array(comm->coll.topo_aware_tree->children)[i], - MPL_MEM_COLL); + &ut_int_array(comm->coll.cached_tree->children)[i], MPL_MEM_COLL); } break; case MPIR_TREE_TYPE_TOPOLOGY_AWARE_K: - if (!comm->coll.topo_aware_k_tree || root != comm->coll.topo_aware_k_tree_root - || k != comm->coll.topo_aware_k_tree_k) { - if (comm->coll.topo_aware_k_tree) { - MPIR_Treealgo_tree_free(comm->coll.topo_aware_k_tree); + if (!comm->coll.cached_tree || + !match_param_topo_aware_k(&comm->coll.cached_tree_param, coll_group, root, k)) { + if (comm->coll.cached_tree) { + MPIR_Treealgo_tree_free(comm->coll.cached_tree); } else { - comm->coll.topo_aware_k_tree = + comm->coll.cached_tree = (MPIR_Treealgo_tree_t *) MPL_malloc(sizeof(MPIR_Treealgo_tree_t), MPL_MEM_BUFFER); } mpi_errno = - MPII_Treeutil_tree_topology_aware_k_init(comm, k, root, enable_reorder, - comm->coll.topo_aware_k_tree); + MPII_Treeutil_tree_topology_aware_k_init(comm, coll_group, k, root, + enable_reorder, + comm->coll.cached_tree); MPIR_ERR_CHECK(mpi_errno); - *ct = *comm->coll.topo_aware_k_tree; - comm->coll.topo_aware_k_tree_root = root; - comm->coll.topo_aware_k_tree_k = k; + *ct = *comm->coll.cached_tree; + set_param_topo_aware_k(&comm->coll.cached_tree_param, coll_group, root, k); } - *ct = *comm->coll.topo_aware_k_tree; + *ct = *comm->coll.cached_tree; utarray_new(ct->children, &ut_int_icd, MPL_MEM_COLL); for (int i = 0; i < ct->num_children; i++) { utarray_push_back(ct->children, - &ut_int_array(comm->coll.topo_aware_k_tree->children)[i], - MPL_MEM_COLL); + &ut_int_array(comm->coll.cached_tree->children)[i], MPL_MEM_COLL); } break; @@ -155,7 +209,7 @@ int MPIR_Treealgo_tree_create_topo_aware(MPIR_Comm * comm, int tree_type, int k, } -int MPIR_Treealgo_tree_create_topo_wave(MPIR_Comm * comm, int k, int root, +int MPIR_Treealgo_tree_create_topo_wave(MPIR_Comm * comm, int coll_group, int k, int root, bool enable_reorder, int overhead, int lat_diff_groups, int lat_diff_switches, int lat_same_switches, MPIR_Treealgo_tree_t * ct) @@ -164,34 +218,29 @@ int MPIR_Treealgo_tree_create_topo_wave(MPIR_Comm * comm, int k, int root, MPIR_FUNC_ENTER; - if (!comm->coll.topo_wave_tree || root != comm->coll.topo_wave_tree_root - || overhead != comm->coll.topo_wave_tree_overhead - || lat_diff_groups != comm->coll.topo_wave_tree_lat_diff_groups - || lat_diff_switches != comm->coll.topo_wave_tree_lat_diff_switches - || lat_same_switches != comm->coll.topo_wave_tree_lat_same_switches) { - if (comm->coll.topo_wave_tree) { - MPIR_Treealgo_tree_free(comm->coll.topo_wave_tree); + if (!comm->coll.cached_tree || + !match_param_topo_wave(&comm->coll.cached_tree_param, coll_group, root, + overhead, lat_diff_groups, lat_diff_switches, lat_same_switches)) { + if (comm->coll.cached_tree) { + MPIR_Treealgo_tree_free(comm->coll.cached_tree); } else { - comm->coll.topo_wave_tree = + comm->coll.cached_tree = (MPIR_Treealgo_tree_t *) MPL_malloc(sizeof(MPIR_Treealgo_tree_t), MPL_MEM_BUFFER); } - mpi_errno = MPII_Treeutil_tree_topology_wave_init(comm, k, root, enable_reorder, overhead, - lat_diff_groups, lat_diff_switches, - lat_same_switches, - comm->coll.topo_wave_tree); + mpi_errno = + MPII_Treeutil_tree_topology_wave_init(comm, coll_group, k, root, enable_reorder, + overhead, lat_diff_groups, lat_diff_switches, + lat_same_switches, comm->coll.cached_tree); MPIR_ERR_CHECK(mpi_errno); - *ct = *comm->coll.topo_wave_tree; - comm->coll.topo_wave_tree_root = root; - comm->coll.topo_wave_tree_overhead = overhead; - comm->coll.topo_wave_tree_lat_diff_groups = lat_diff_groups; - comm->coll.topo_wave_tree_lat_diff_switches = lat_diff_switches; - comm->coll.topo_wave_tree_lat_same_switches = lat_same_switches; + *ct = *comm->coll.cached_tree; + set_param_topo_wave(&comm->coll.cached_tree_param, coll_group, root, overhead, + lat_diff_groups, lat_diff_switches, lat_same_switches); } - *ct = *comm->coll.topo_wave_tree; + *ct = *comm->coll.cached_tree; utarray_new(ct->children, &ut_int_icd, MPL_MEM_COLL); for (int i = 0; i < ct->num_children; i++) { utarray_push_back(ct->children, - &ut_int_array(comm->coll.topo_wave_tree->children)[i], MPL_MEM_COLL); + &ut_int_array(comm->coll.cached_tree->children)[i], MPL_MEM_COLL); } MPIR_FUNC_EXIT; diff --git a/src/mpi/coll/algorithms/treealgo/treealgo.h b/src/mpi/coll/algorithms/treealgo/treealgo.h index 60bac96806d..50e473f2b94 100644 --- a/src/mpi/coll/algorithms/treealgo/treealgo.h +++ b/src/mpi/coll/algorithms/treealgo/treealgo.h @@ -13,9 +13,10 @@ int MPII_Treealgo_comm_init(MPIR_Comm * comm); int MPII_Treealgo_comm_cleanup(MPIR_Comm * comm); int MPIR_Treealgo_tree_create(int rank, int nranks, int tree_type, int k, int root, MPIR_Treealgo_tree_t * ct); -int MPIR_Treealgo_tree_create_topo_aware(MPIR_Comm * comm, int tree_type, int k, int root, +int MPIR_Treealgo_tree_create_topo_aware(MPIR_Comm * comm, int coll_group, int tree_type, + int k, int root, bool enable_reorder, MPIR_Treealgo_tree_t * ct); -int MPIR_Treealgo_tree_create_topo_wave(MPIR_Comm * comm, int k, int root, +int MPIR_Treealgo_tree_create_topo_wave(MPIR_Comm * comm, int coll_group, int k, int root, bool enable_reorder, int overhead, int lat_diff_groups, int lat_diff_switches, int lat_same_switches, MPIR_Treealgo_tree_t * ct); diff --git a/src/mpi/coll/algorithms/treealgo/treealgo_types.h b/src/mpi/coll/algorithms/treealgo/treealgo_types.h index 5db2c5ae931..646c63f866b 100644 --- a/src/mpi/coll/algorithms/treealgo/treealgo_types.h +++ b/src/mpi/coll/algorithms/treealgo/treealgo_types.h @@ -8,6 +8,16 @@ #include +/* enumerator for different tree types */ +typedef enum MPIR_Tree_type_t { + MPIR_TREE_TYPE_KARY = 0, + MPIR_TREE_TYPE_KNOMIAL_1, + MPIR_TREE_TYPE_KNOMIAL_2, + MPIR_TREE_TYPE_TOPOLOGY_AWARE, + MPIR_TREE_TYPE_TOPOLOGY_AWARE_K, + MPIR_TREE_TYPE_TOPOLOGY_WAVE, +} MPIR_Tree_type_t; + typedef struct { int rank; int nranks; @@ -16,4 +26,21 @@ typedef struct { UT_array *children; } MPIR_Treealgo_tree_t; +typedef struct { + MPIR_Tree_type_t type; + int coll_group; + int root; + union { + struct { + int k; + } topo_aware; + struct { + int overhead; + int lat_diff_groups; + int lat_diff_switches; + int lat_same_switches; + } topo_wave; + } u; +} MPIR_Treealgo_param_t; + #endif /* TREEALGO_TYPES_H_INCLUDED */ diff --git a/src/mpi/coll/algorithms/treealgo/treeutil.c b/src/mpi/coll/algorithms/treealgo/treeutil.c index ad9b003c369..1c55d713e41 100644 --- a/src/mpi/coll/algorithms/treealgo/treeutil.c +++ b/src/mpi/coll/algorithms/treealgo/treeutil.c @@ -472,8 +472,8 @@ static void MPII_Treeutil_hierarchy_reorder(UT_array * hierarchy, int rank) } /* tree init function is for building hierarchy of MPIR_Process::coords_dims */ -static int MPII_Treeutil_hierarchy_populate(MPIR_Comm * comm, int rank, int nranks, int root, - bool enable_reorder, UT_array * hierarchy) +static int MPII_Treeutil_hierarchy_populate(MPIR_Comm * comm, int coll_group, int rank, int nranks, + int root, bool enable_reorder, UT_array * hierarchy) { int mpi_errno = MPI_SUCCESS; @@ -504,8 +504,12 @@ static int MPII_Treeutil_hierarchy_populate(MPIR_Comm * comm, int rank, int nran MPIR_Assert(upper_level != NULL); /* Get wrank from the communicator as the coords are stored with wrank */ + int comm_rank = r; + if (coll_group > 0) { + comm_rank = comm->subgroups[coll_group].proc_table[r]; + } uint64_t temp = 0; - MPID_Comm_get_lpid(comm, r, &temp, FALSE); + MPID_Comm_get_lpid(comm, comm_rank, &temp, FALSE); int wrank = (int) temp; if (wrank < 0) goto fn_fail; @@ -600,12 +604,13 @@ static int MPII_Treeutil_hierarchy_populate(MPIR_Comm * comm, int rank, int nran * build the hierarchy of the topology-aware tree. * For the mentioned cases see tags 'goto fn_fallback;'. */ -int MPII_Treeutil_tree_topology_aware_init(MPIR_Comm * comm, int k, int root, bool enable_reorder, - MPIR_Treealgo_tree_t * ct) +int MPII_Treeutil_tree_topology_aware_init(MPIR_Comm * comm, int coll_group, int k, int root, + bool enable_reorder, MPIR_Treealgo_tree_t * ct) { int mpi_errno = MPI_SUCCESS; - int rank = comm->rank; - int nranks = comm->local_size; + + int rank, nranks; + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, nranks); UT_array hierarchy[MAX_HIERARCHY_DEPTH]; int dim = MPIR_Process.coords_dims - 1; @@ -613,7 +618,8 @@ int MPII_Treeutil_tree_topology_aware_init(MPIR_Comm * comm, int k, int root, bo tree_ut_hierarchy_init(&hierarchy[dim]); if (k <= 0 || - 0 != MPII_Treeutil_hierarchy_populate(comm, rank, nranks, root, enable_reorder, hierarchy)) + 0 != MPII_Treeutil_hierarchy_populate(comm, coll_group, rank, nranks, root, enable_reorder, + hierarchy)) goto fn_fallback; ct->rank = rank; @@ -695,16 +701,18 @@ int MPII_Treeutil_tree_topology_aware_init(MPIR_Comm * comm, int k, int root, bo } /* Implementation of 'Topology aware' algorithm with the branching factor k */ -int MPII_Treeutil_tree_topology_aware_k_init(MPIR_Comm * comm, int k, int root, bool enable_reorder, - MPIR_Treealgo_tree_t * ct) +int MPII_Treeutil_tree_topology_aware_k_init(MPIR_Comm * comm, int coll_group, int k, int root, + bool enable_reorder, MPIR_Treealgo_tree_t * ct) { int mpi_errno = MPI_SUCCESS; - int rank = comm->rank; - int nranks = comm->local_size; + + int rank, nranks; + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, nranks); /* fall back to MPII_Treeutil_tree_topology_aware_init if k is less or equal to 2 */ if (k <= 2) { - return MPII_Treeutil_tree_topology_aware_init(comm, k, root, enable_reorder, ct); + return MPII_Treeutil_tree_topology_aware_init(comm, coll_group, k, root, enable_reorder, + ct); } int *num_childrens = NULL; @@ -719,7 +727,9 @@ int MPII_Treeutil_tree_topology_aware_k_init(MPIR_Comm * comm, int k, int root, for (dim = MPIR_Process.coords_dims - 1; dim >= 0; --dim) tree_ut_hierarchy_init(&hierarchy[dim]); - if (0 != MPII_Treeutil_hierarchy_populate(comm, rank, nranks, root, enable_reorder, hierarchy)) + if (0 != + MPII_Treeutil_hierarchy_populate(comm, coll_group, rank, nranks, root, enable_reorder, + hierarchy)) goto fn_fallback; ct->rank = rank; @@ -758,7 +768,7 @@ int MPII_Treeutil_tree_topology_aware_k_init(MPIR_Comm * comm, int k, int root, /* Do an allgather to know the current num_children on each rank */ MPIR_Errflag_t errflag = MPIR_ERR_NONE; MPIR_Allgather_impl(&(ct->num_children), 1, MPI_INT, num_childrens, 1, MPI_INT, - comm, errflag); + comm, coll_group, errflag); if (mpi_errno) { goto fn_fail; } @@ -1111,13 +1121,12 @@ static int init_root_switch(const UT_array * hierarchy, heap_vector * minHeaps, } /* 'Topology Wave' implementation */ -int MPII_Treeutil_tree_topology_wave_init(MPIR_Comm * comm, int k, int root, bool enable_reorder, - int overhead, int lat_diff_groups, int lat_diff_switches, - int lat_same_switches, MPIR_Treealgo_tree_t * ct) +int MPII_Treeutil_tree_topology_wave_init(MPIR_Comm * comm, int coll_group, int k, int root, + bool enable_reorder, int overhead, int lat_diff_groups, + int lat_diff_switches, int lat_same_switches, + MPIR_Treealgo_tree_t * ct) { int mpi_errno = MPI_SUCCESS; - int rank = comm->rank; - int nranks = comm->local_size; int root_gr_sorted_idx = 0; int root_sw_sorted_idx = 0; int group_offset = 0; @@ -1126,6 +1135,9 @@ int MPII_Treeutil_tree_topology_wave_init(MPIR_Comm * comm, int k, int root, boo UT_array hierarchy[MAX_HIERARCHY_DEPTH]; UT_array *unv_set = NULL; + int rank, nranks; + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, nranks); + heap_vector minHeaps; heap_vector_init(&minHeaps); @@ -1135,7 +1147,8 @@ int MPII_Treeutil_tree_topology_wave_init(MPIR_Comm * comm, int k, int root, boo tree_ut_hierarchy_init(&hierarchy[dim]); if (overhead <= 0 || lat_diff_groups <= 0 || lat_diff_switches <= 0 || lat_same_switches <= 0 || - 0 != MPII_Treeutil_hierarchy_populate(comm, rank, nranks, root, enable_reorder, hierarchy)) + 0 != MPII_Treeutil_hierarchy_populate(comm, coll_group, rank, nranks, root, enable_reorder, + hierarchy)) goto fn_fallback; UT_icd intpair_icd = { sizeof(pair), NULL, NULL, NULL }; diff --git a/src/mpi/coll/algorithms/treealgo/treeutil.h b/src/mpi/coll/algorithms/treealgo/treeutil.h index c628f162ca6..51864938f4d 100644 --- a/src/mpi/coll/algorithms/treealgo/treeutil.h +++ b/src/mpi/coll/algorithms/treealgo/treeutil.h @@ -123,15 +123,16 @@ int MPII_Treeutil_tree_knomial_2_init(int rank, int nranks, int k, int root, MPIR_Treealgo_tree_t * ct); /* Generate topology_aware tree information */ -int MPII_Treeutil_tree_topology_aware_init(MPIR_Comm * comm, int k, int root, bool enable_reorder, - MPIR_Treealgo_tree_t * ct); +int MPII_Treeutil_tree_topology_aware_init(MPIR_Comm * comm, int coll_group, int k, int root, + bool enable_reorder, MPIR_Treealgo_tree_t * ct); -int MPII_Treeutil_tree_topology_aware_k_init(MPIR_Comm * comm, int k, int root, bool enable_reorder, - MPIR_Treealgo_tree_t * ct); +int MPII_Treeutil_tree_topology_aware_k_init(MPIR_Comm * comm, int coll_group, int k, int root, + bool enable_reorder, MPIR_Treealgo_tree_t * ct); /* Generate topology_wave tree information */ -int MPII_Treeutil_tree_topology_wave_init(MPIR_Comm * comm, int k, int root, bool enable_reorder, - int overhead, int lat_diff_groups, int lat_diff_switches, - int lat_same_switches, MPIR_Treealgo_tree_t * ct); +int MPII_Treeutil_tree_topology_wave_init(MPIR_Comm * comm, int coll_group, int k, int root, + bool enable_reorder, int overhead, int lat_diff_groups, + int lat_diff_switches, int lat_same_switches, + MPIR_Treealgo_tree_t * ct); #endif /* TREEUTIL_H_INCLUDED */ diff --git a/src/mpi/coll/allgather/allgather_allcomm_nb.c b/src/mpi/coll/allgather/allgather_allcomm_nb.c index 37800564381..8a818710365 100644 --- a/src/mpi/coll/allgather/allgather_allcomm_nb.c +++ b/src/mpi/coll/allgather/allgather_allcomm_nb.c @@ -7,7 +7,7 @@ int MPIR_Allgather_allcomm_nb(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_Request *req_ptr = NULL; @@ -15,7 +15,7 @@ int MPIR_Allgather_allcomm_nb(const void *sendbuf, MPI_Aint sendcount, MPI_Datat /* just call the nonblocking version and wait on it */ mpi_errno = MPIR_Iallgather(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm_ptr, - &req_ptr); + coll_group, &req_ptr); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIC_Wait(req_ptr); diff --git a/src/mpi/coll/allgather/allgather_inter_local_gather_remote_bcast.c b/src/mpi/coll/allgather/allgather_inter_local_gather_remote_bcast.c index fedb0ff358e..4b0932aa5db 100644 --- a/src/mpi/coll/allgather/allgather_inter_local_gather_remote_bcast.c +++ b/src/mpi/coll/allgather/allgather_inter_local_gather_remote_bcast.c @@ -15,7 +15,8 @@ int MPIR_Allgather_inter_local_gather_remote_bcast(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int rank, local_size, remote_size, mpi_errno = MPI_SUCCESS, root; MPI_Aint sendtype_sz; @@ -47,7 +48,7 @@ int MPIR_Allgather_inter_local_gather_remote_bcast(const void *sendbuf, MPI_Aint if (sendcount != 0) { mpi_errno = MPIR_Gather(sendbuf, sendcount, sendtype, tmp_buf, sendcount * sendtype_sz, - MPI_BYTE, 0, newcomm_ptr, errflag); + MPI_BYTE, 0, newcomm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } @@ -58,7 +59,7 @@ int MPIR_Allgather_inter_local_gather_remote_bcast(const void *sendbuf, MPI_Aint if (sendcount != 0) { root = (rank == 0) ? MPI_ROOT : MPI_PROC_NULL; mpi_errno = MPIR_Bcast(tmp_buf, sendcount * sendtype_sz * local_size, - MPI_BYTE, root, comm_ptr, errflag); + MPI_BYTE, root, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } @@ -66,7 +67,7 @@ int MPIR_Allgather_inter_local_gather_remote_bcast(const void *sendbuf, MPI_Aint if (recvcount != 0) { root = 0; mpi_errno = MPIR_Bcast(recvbuf, recvcount * remote_size, - recvtype, root, comm_ptr, errflag); + recvtype, root, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } } else { @@ -74,7 +75,7 @@ int MPIR_Allgather_inter_local_gather_remote_bcast(const void *sendbuf, MPI_Aint if (recvcount != 0) { root = 0; mpi_errno = MPIR_Bcast(recvbuf, recvcount * remote_size, - recvtype, root, comm_ptr, errflag); + recvtype, root, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } @@ -82,7 +83,7 @@ int MPIR_Allgather_inter_local_gather_remote_bcast(const void *sendbuf, MPI_Aint if (sendcount != 0) { root = (rank == 0) ? MPI_ROOT : MPI_PROC_NULL; mpi_errno = MPIR_Bcast(tmp_buf, sendcount * sendtype_sz * local_size, - MPI_BYTE, root, comm_ptr, errflag); + MPI_BYTE, root, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } } diff --git a/src/mpi/coll/allgather/allgather_intra_brucks.c b/src/mpi/coll/allgather/allgather_intra_brucks.c index 4e22866ebeb..21d7b1655db 100644 --- a/src/mpi/coll/allgather/allgather_intra_brucks.c +++ b/src/mpi/coll/allgather/allgather_intra_brucks.c @@ -19,7 +19,8 @@ int MPIR_Allgather_intra_brucks(const void *sendbuf, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, - MPI_Datatype recvtype, MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPI_Datatype recvtype, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int comm_size, rank; int mpi_errno = MPI_SUCCESS; @@ -33,7 +34,7 @@ int MPIR_Allgather_intra_brucks(const void *sendbuf, if (((sendcount == 0) && (sendbuf != MPI_IN_PLACE)) || (recvcount == 0)) goto fn_exit; - MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size); + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); MPIR_Datatype_get_extent_macro(recvtype, recvtype_extent); MPIR_Datatype_get_size_macro(recvtype, recvtype_sz); @@ -66,7 +67,8 @@ int MPIR_Allgather_intra_brucks(const void *sendbuf, MPIR_ALLGATHER_TAG, ((char *) tmp_buf + curr_cnt * recvtype_sz), curr_cnt * recvtype_sz, MPI_BYTE, - src, MPIR_ALLGATHER_TAG, comm_ptr, MPI_STATUS_IGNORE, errflag); + src, MPIR_ALLGATHER_TAG, comm_ptr, coll_group, MPI_STATUS_IGNORE, + errflag); MPIR_ERR_CHECK(mpi_errno); curr_cnt *= 2; pof2 *= 2; @@ -83,7 +85,8 @@ int MPIR_Allgather_intra_brucks(const void *sendbuf, dst, MPIR_ALLGATHER_TAG, ((char *) tmp_buf + curr_cnt * recvtype_sz), rem * recvcount * recvtype_sz, MPI_BYTE, - src, MPIR_ALLGATHER_TAG, comm_ptr, MPI_STATUS_IGNORE, errflag); + src, MPIR_ALLGATHER_TAG, comm_ptr, coll_group, MPI_STATUS_IGNORE, + errflag); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/allgather/allgather_intra_k_brucks.c b/src/mpi/coll/allgather/allgather_intra_k_brucks.c index 010a5a8567d..cf138bf5e96 100644 --- a/src/mpi/coll/allgather/allgather_intra_k_brucks.c +++ b/src/mpi/coll/allgather/allgather_intra_k_brucks.c @@ -22,16 +22,17 @@ int MPIR_Allgather_intra_k_brucks(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, - MPI_Aint recvcount, MPI_Datatype recvtype, MPIR_Comm * comm, int k, - MPIR_Errflag_t errflag) + MPI_Aint recvcount, MPI_Datatype recvtype, MPIR_Comm * comm, + int coll_group, int k, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; int i, j; int nphases = 0; int src, dst, p_of_k = 0; /* Largest power of k that is smaller than 'size' */ - int rank = MPIR_Comm_rank(comm); - int size = MPIR_Comm_size(comm); + int rank, size; + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, size); + int is_inplace = (sendbuf == MPI_IN_PLACE); int max = size - 1; MPIR_Request **reqs; @@ -140,7 +141,7 @@ MPIR_Allgather_intra_k_brucks(const void *sendbuf, MPI_Aint sendcount, /* Receive at the exact location. */ mpi_errno = MPIC_Irecv((char *) tmp_recvbuf + j * recvcount * delta * recvtype_extent, - count, recvtype, src, MPIR_ALLGATHER_TAG, comm, + count, recvtype, src, MPIR_ALLGATHER_TAG, comm, coll_group, &reqs[num_reqs++]); MPIR_ERR_CHECK(mpi_errno); @@ -152,7 +153,7 @@ MPIR_Allgather_intra_k_brucks(const void *sendbuf, MPI_Aint sendcount, /* Send from the start of recv till `count` amount of data. */ mpi_errno = - MPIC_Isend(tmp_recvbuf, count, recvtype, dst, MPIR_ALLGATHER_TAG, comm, + MPIC_Isend(tmp_recvbuf, count, recvtype, dst, MPIR_ALLGATHER_TAG, comm, coll_group, &reqs[num_reqs++], errflag); MPIR_ERR_CHECK(mpi_errno); diff --git a/src/mpi/coll/allgather/allgather_intra_recexch.c b/src/mpi/coll/allgather/allgather_intra_recexch.c index 20a2f0501b7..ff6c566de69 100644 --- a/src/mpi/coll/allgather/allgather_intra_recexch.c +++ b/src/mpi/coll/allgather/allgather_intra_recexch.c @@ -18,7 +18,7 @@ * */ int MPIR_Allgather_intra_recexch(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, - MPI_Datatype recvtype, MPIR_Comm * comm, + MPI_Datatype recvtype, MPIR_Comm * comm, int coll_group, int recexch_type, int k, int single_phase_recv, MPIR_Errflag_t errflag) { @@ -36,9 +36,11 @@ int MPIR_Allgather_intra_recexch(const void *sendbuf, MPI_Aint sendcount, MPIR_Request *rreqs[MAX_RADIX * 2], *sreqs[MAX_RADIX * 2]; MPIR_Request **recv_reqs = NULL, **send_reqs = NULL; + /* it caches data in comm */ + MPIR_Assert(coll_group == MPIR_SUBGROUP_NONE); + is_inplace = (sendbuf == MPI_IN_PLACE); - nranks = MPIR_Comm_size(comm); - rank = MPIR_Comm_rank(comm); + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, nranks); MPIR_Datatype_get_extent_macro(recvtype, recv_extent); MPIR_Type_get_true_extent_impl(recvtype, &recv_lb, &true_extent); @@ -117,7 +119,7 @@ int MPIR_Allgather_intra_recexch(const void *sendbuf, MPI_Aint sendcount, buf_to_send = (void *) sendbuf; mpi_errno = MPIC_Send(buf_to_send, recvcount, recvtype, step1_sendto, MPIR_ALLGATHER_TAG, comm, - errflag); + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } else { if (step1_nrecvs) { @@ -125,7 +127,7 @@ int MPIR_Allgather_intra_recexch(const void *sendbuf, MPI_Aint sendcount, for (i = 0; i < step1_nrecvs; i++) { /* participating rank gets the data from non-participating rank */ recv_offset = step1_recvfrom[i] * recv_extent * recvcount; mpi_errno = MPIC_Irecv(((char *) recvbuf + recv_offset), recvcount, recvtype, - step1_recvfrom[i], MPIR_ALLGATHER_TAG, comm, + step1_recvfrom[i], MPIR_ALLGATHER_TAG, comm, coll_group, &recv_reqs[num_rreq++]); MPIR_ERR_CHECK(mpi_errno); } @@ -159,8 +161,8 @@ int MPIR_Allgather_intra_recexch(const void *sendbuf, MPI_Aint sendcount, MPIC_Sendrecv(((char *) recvbuf + send_offset), send_count * recvcount, recvtype, partner, MPIR_ALLGATHER_TAG, ((char *) recvbuf + recv_offset), recv_count * recvcount, - recvtype, partner, MPIR_ALLGATHER_TAG, comm, MPI_STATUS_IGNORE, - errflag); + recvtype, partner, MPIR_ALLGATHER_TAG, comm, coll_group, + MPI_STATUS_IGNORE, errflag); MPIR_ERR_CHECK(mpi_errno); } } @@ -191,7 +193,7 @@ int MPIR_Allgather_intra_recexch(const void *sendbuf, MPI_Aint sendcount, recv_offset = offset * recv_extent * recvcount; mpi_errno = MPIC_Irecv(((char *) recvbuf + recv_offset), count * recvcount, recvtype, nbr, - MPIR_ALLGATHER_TAG, comm, &recv_reqs[num_rreq++]); + MPIR_ALLGATHER_TAG, comm, coll_group, &recv_reqs[num_rreq++]); MPIR_ERR_CHECK(mpi_errno); } if (recexch_type == MPIR_ALLGATHER_RECEXCH_TYPE_DISTANCE_HALVING) @@ -210,7 +212,8 @@ int MPIR_Allgather_intra_recexch(const void *sendbuf, MPI_Aint sendcount, MPII_Recexchalgo_get_count_and_offset(rank_for_offset, j, k, nranks, &count, &offset); send_offset = offset * recv_extent * recvcount; mpi_errno = MPIC_Isend(((char *) recvbuf + send_offset), count * recvcount, recvtype, - nbr, MPIR_ALLGATHER_TAG, comm, &send_reqs[num_sreq++], errflag); + nbr, MPIR_ALLGATHER_TAG, comm, coll_group, + &send_reqs[num_sreq++], errflag); MPIR_ERR_CHECK(mpi_errno); } /* wait on prev recvs */ @@ -236,7 +239,7 @@ int MPIR_Allgather_intra_recexch(const void *sendbuf, MPI_Aint sendcount, send_offset = offset * recv_extent * recvcount; mpi_errno = MPIC_Isend(((char *) recvbuf + send_offset), count * recvcount, - recvtype, nbr, MPIR_ALLGATHER_TAG, comm, + recvtype, nbr, MPIR_ALLGATHER_TAG, comm, coll_group, &send_reqs[num_sreq++], errflag); MPIR_ERR_CHECK(mpi_errno); } @@ -260,13 +263,14 @@ int MPIR_Allgather_intra_recexch(const void *sendbuf, MPI_Aint sendcount, if (step1_sendto != -1) { mpi_errno = MPIC_Recv(recvbuf, recvcount * nranks, recvtype, step1_sendto, MPIR_ALLGATHER_TAG, - comm, MPI_STATUS_IGNORE); + comm, coll_group, MPI_STATUS_IGNORE); MPIR_ERR_CHECK(mpi_errno); } for (i = 0; i < step1_nrecvs; i++) { mpi_errno = MPIC_Isend(recvbuf, recvcount * nranks, recvtype, step1_recvfrom[i], - MPIR_ALLGATHER_TAG, comm, &recv_reqs[num_rreq++], errflag); + MPIR_ALLGATHER_TAG, comm, coll_group, &recv_reqs[num_rreq++], + errflag); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/allgather/allgather_intra_recursive_doubling.c b/src/mpi/coll/allgather/allgather_intra_recursive_doubling.c index 3dd37b22ea5..be80c8774b5 100644 --- a/src/mpi/coll/allgather/allgather_intra_recursive_doubling.c +++ b/src/mpi/coll/allgather/allgather_intra_recursive_doubling.c @@ -23,7 +23,8 @@ int MPIR_Allgather_intra_recursive_doubling(const void *sendbuf, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int comm_size, rank; int mpi_errno = MPI_SUCCESS; @@ -34,8 +35,7 @@ int MPIR_Allgather_intra_recursive_doubling(const void *sendbuf, MPI_Status status; int mask, dst_tree_root, my_tree_root, nprocs_completed, k, tmp_mask, tree_root; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); #ifdef HAVE_ERROR_CHECKING /* Currently this algorithm can only handle power-of-2 comm_size. @@ -81,7 +81,7 @@ int MPIR_Allgather_intra_recursive_doubling(const void *sendbuf, ((char *) recvbuf + recv_offset), (comm_size - dst_tree_root) * recvcount, recvtype, dst, - MPIR_ALLGATHER_TAG, comm_ptr, &status, errflag); + MPIR_ALLGATHER_TAG, comm_ptr, coll_group, &status, errflag); MPIR_ERR_CHECK(mpi_errno); MPIR_Get_count_impl(&status, recvtype, &last_recv_cnt); curr_cnt += last_recv_cnt; @@ -135,7 +135,8 @@ int MPIR_Allgather_intra_recursive_doubling(const void *sendbuf, && (dst >= tree_root + nprocs_completed)) { mpi_errno = MPIC_Send(((char *) recvbuf + offset), last_recv_cnt, - recvtype, dst, MPIR_ALLGATHER_TAG, comm_ptr, errflag); + recvtype, dst, MPIR_ALLGATHER_TAG, comm_ptr, coll_group, + errflag); MPIR_ERR_CHECK(mpi_errno); } /* recv only if this proc. doesn't have data and sender @@ -145,7 +146,8 @@ int MPIR_Allgather_intra_recursive_doubling(const void *sendbuf, (rank >= tree_root + nprocs_completed)) { mpi_errno = MPIC_Recv(((char *) recvbuf + offset), (comm_size - (my_tree_root + mask)) * recvcount, - recvtype, dst, MPIR_ALLGATHER_TAG, comm_ptr, &status); + recvtype, dst, MPIR_ALLGATHER_TAG, comm_ptr, coll_group, + &status); MPIR_ERR_CHECK(mpi_errno); /* nprocs_completed is also equal to the * no. of processes whose data we don't have */ diff --git a/src/mpi/coll/allgather/allgather_intra_ring.c b/src/mpi/coll/allgather/allgather_intra_ring.c index 12f19b0b427..d763d70942e 100644 --- a/src/mpi/coll/allgather/allgather_intra_ring.c +++ b/src/mpi/coll/allgather/allgather_intra_ring.c @@ -25,7 +25,8 @@ int MPIR_Allgather_intra_ring(const void *sendbuf, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, - MPI_Datatype recvtype, MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPI_Datatype recvtype, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int comm_size, rank; int mpi_errno = MPI_SUCCESS; @@ -33,8 +34,7 @@ int MPIR_Allgather_intra_ring(const void *sendbuf, int j, i; int left, right, jnext; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); MPIR_Datatype_get_extent_macro(recvtype, recvtype_extent); @@ -63,7 +63,8 @@ int MPIR_Allgather_intra_ring(const void *sendbuf, ((char *) recvbuf + jnext * recvcount * recvtype_extent), recvcount, recvtype, left, - MPIR_ALLGATHER_TAG, comm_ptr, MPI_STATUS_IGNORE, errflag); + MPIR_ALLGATHER_TAG, comm_ptr, coll_group, MPI_STATUS_IGNORE, + errflag); MPIR_ERR_CHECK(mpi_errno); j = jnext; jnext = (comm_size + jnext - 1) % comm_size; diff --git a/src/mpi/coll/allgatherv/allgatherv_allcomm_nb.c b/src/mpi/coll/allgatherv/allgatherv_allcomm_nb.c index 1a7fc1430b3..c2ecc46466b 100644 --- a/src/mpi/coll/allgatherv/allgatherv_allcomm_nb.c +++ b/src/mpi/coll/allgatherv/allgatherv_allcomm_nb.c @@ -7,7 +7,8 @@ int MPIR_Allgatherv_allcomm_nb(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, - MPI_Datatype recvtype, MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPI_Datatype recvtype, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_Request *req_ptr = NULL; @@ -15,7 +16,7 @@ int MPIR_Allgatherv_allcomm_nb(const void *sendbuf, MPI_Aint sendcount, MPI_Data /* just call the nonblocking version and wait on it */ mpi_errno = MPIR_Iallgatherv(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, - comm_ptr, &req_ptr); + comm_ptr, coll_group, &req_ptr); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIC_Wait(req_ptr); diff --git a/src/mpi/coll/allgatherv/allgatherv_inter_remote_gather_local_bcast.c b/src/mpi/coll/allgatherv/allgatherv_inter_remote_gather_local_bcast.c index a5c05fae9a2..4f4f1b3a15f 100644 --- a/src/mpi/coll/allgatherv/allgatherv_inter_remote_gather_local_bcast.c +++ b/src/mpi/coll/allgatherv/allgatherv_inter_remote_gather_local_bcast.c @@ -19,7 +19,8 @@ int MPIR_Allgatherv_inter_remote_gather_local_bcast(const void *sendbuf, MPI_Ain MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int remote_size, mpi_errno, root, rank; MPIR_Comm *newcomm_ptr = NULL; @@ -34,23 +35,23 @@ int MPIR_Allgatherv_inter_remote_gather_local_bcast(const void *sendbuf, MPI_Ain /* gatherv from right group */ root = (rank == 0) ? MPI_ROOT : MPI_PROC_NULL; mpi_errno = MPIR_Gatherv(sendbuf, sendcount, sendtype, recvbuf, - recvcounts, displs, recvtype, root, comm_ptr, errflag); + recvcounts, displs, recvtype, root, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); /* gatherv to right group */ root = 0; mpi_errno = MPIR_Gatherv(sendbuf, sendcount, sendtype, recvbuf, - recvcounts, displs, recvtype, root, comm_ptr, errflag); + recvcounts, displs, recvtype, root, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } else { /* gatherv to left group */ root = 0; mpi_errno = MPIR_Gatherv(sendbuf, sendcount, sendtype, recvbuf, - recvcounts, displs, recvtype, root, comm_ptr, errflag); + recvcounts, displs, recvtype, root, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); /* gatherv from left group */ root = (rank == 0) ? MPI_ROOT : MPI_PROC_NULL; mpi_errno = MPIR_Gatherv(sendbuf, sendcount, sendtype, recvbuf, - recvcounts, displs, recvtype, root, comm_ptr, errflag); + recvcounts, displs, recvtype, root, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } @@ -71,7 +72,7 @@ int MPIR_Allgatherv_inter_remote_gather_local_bcast(const void *sendbuf, MPI_Ain mpi_errno = MPIR_Type_commit_impl(&newtype); MPIR_ERR_CHECK(mpi_errno); - mpi_errno = MPIR_Bcast_allcomm_auto(recvbuf, 1, newtype, 0, newcomm_ptr, errflag); + mpi_errno = MPIR_Bcast_allcomm_auto(recvbuf, 1, newtype, 0, newcomm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); MPIR_Type_free_impl(&newtype); diff --git a/src/mpi/coll/allgatherv/allgatherv_intra_brucks.c b/src/mpi/coll/allgatherv/allgatherv_intra_brucks.c index 99f867f732d..b838fb4eb2f 100644 --- a/src/mpi/coll/allgatherv/allgatherv_intra_brucks.c +++ b/src/mpi/coll/allgatherv/allgatherv_intra_brucks.c @@ -23,7 +23,7 @@ int MPIR_Allgatherv_intra_brucks(const void *sendbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int comm_size, rank, j, i; int mpi_errno = MPI_SUCCESS; @@ -34,7 +34,7 @@ int MPIR_Allgatherv_intra_brucks(const void *sendbuf, void *tmp_buf; MPIR_CHKLMEM_DECL(1); - MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size); + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); total_count = 0; for (i = 0; i < comm_size; i++) @@ -76,7 +76,7 @@ int MPIR_Allgatherv_intra_brucks(const void *sendbuf, MPIR_ALLGATHERV_TAG, ((char *) tmp_buf + curr_cnt * recvtype_sz), (total_count - curr_cnt) * recvtype_sz, MPI_BYTE, - src, MPIR_ALLGATHERV_TAG, comm_ptr, &status, errflag); + src, MPIR_ALLGATHERV_TAG, comm_ptr, coll_group, &status, errflag); MPIR_ERR_CHECK(mpi_errno); if (mpi_errno) { recv_cnt = 0; @@ -103,7 +103,8 @@ int MPIR_Allgatherv_intra_brucks(const void *sendbuf, dst, MPIR_ALLGATHERV_TAG, ((char *) tmp_buf + curr_cnt * recvtype_sz), (total_count - curr_cnt) * recvtype_sz, MPI_BYTE, - src, MPIR_ALLGATHERV_TAG, comm_ptr, MPI_STATUS_IGNORE, errflag); + src, MPIR_ALLGATHERV_TAG, comm_ptr, coll_group, MPI_STATUS_IGNORE, + errflag); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/allgatherv/allgatherv_intra_recursive_doubling.c b/src/mpi/coll/allgatherv/allgatherv_intra_recursive_doubling.c index d083b43e411..0603924a960 100644 --- a/src/mpi/coll/allgatherv/allgatherv_intra_recursive_doubling.c +++ b/src/mpi/coll/allgatherv/allgatherv_intra_recursive_doubling.c @@ -25,7 +25,8 @@ int MPIR_Allgatherv_intra_recursive_doubling(const void *sendbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int comm_size, rank, j, i; int mpi_errno = MPI_SUCCESS; @@ -38,8 +39,7 @@ int MPIR_Allgatherv_intra_recursive_doubling(const void *sendbuf, MPI_Aint position, send_offset, recv_offset, offset; MPIR_CHKLMEM_DECL(1); - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); #ifdef HAVE_ERROR_CHECKING /* Currently this algorithm can only handle power-of-2 comm_size. @@ -112,7 +112,7 @@ int MPIR_Allgatherv_intra_recursive_doubling(const void *sendbuf, MPIR_ALLGATHERV_TAG, ((char *) tmp_buf + recv_offset * recvtype_sz), (total_count - recv_offset) * recvtype_sz, MPI_BYTE, dst, - MPIR_ALLGATHERV_TAG, comm_ptr, &status, errflag); + MPIR_ALLGATHERV_TAG, comm_ptr, coll_group, &status, errflag); MPIR_ERR_CHECK(mpi_errno); if (mpi_errno) { last_recv_cnt = 0; @@ -175,7 +175,8 @@ int MPIR_Allgatherv_intra_recursive_doubling(const void *sendbuf, mpi_errno = MPIC_Send(((char *) tmp_buf + offset * recvtype_sz), last_recv_cnt * recvtype_sz, - MPI_BYTE, dst, MPIR_ALLGATHERV_TAG, comm_ptr, errflag); + MPI_BYTE, dst, MPIR_ALLGATHERV_TAG, comm_ptr, coll_group, + errflag); MPIR_ERR_CHECK(mpi_errno); /* last_recv_cnt was set in the previous * receive. that's the amount of data to be @@ -193,7 +194,7 @@ int MPIR_Allgatherv_intra_recursive_doubling(const void *sendbuf, mpi_errno = MPIC_Recv(((char *) tmp_buf + offset * recvtype_sz), (total_count - offset) * recvtype_sz, MPI_BYTE, - dst, MPIR_ALLGATHERV_TAG, comm_ptr, &status); + dst, MPIR_ALLGATHERV_TAG, comm_ptr, coll_group, &status); MPIR_ERR_CHECK(mpi_errno); if (mpi_errno) { last_recv_cnt = 0; diff --git a/src/mpi/coll/allgatherv/allgatherv_intra_ring.c b/src/mpi/coll/allgatherv/allgatherv_intra_ring.c index 016c35b77c7..18c55b5172d 100644 --- a/src/mpi/coll/allgatherv/allgatherv_intra_ring.c +++ b/src/mpi/coll/allgatherv/allgatherv_intra_ring.c @@ -28,7 +28,8 @@ int MPIR_Allgatherv_intra_ring(const void *sendbuf, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, - MPI_Datatype recvtype, MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPI_Datatype recvtype, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int comm_size, rank, i, left, right; int mpi_errno = MPI_SUCCESS; @@ -36,8 +37,7 @@ int MPIR_Allgatherv_intra_ring(const void *sendbuf, MPI_Aint recvtype_extent; MPI_Aint total_count; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); total_count = 0; for (i = 0; i < comm_size; i++) @@ -108,19 +108,19 @@ int MPIR_Allgatherv_intra_ring(const void *sendbuf, /* Don't do anything. This case is possible if two * consecutive processes contribute 0 bytes each. */ } else if (!sendnow) { /* If there's no data to send, just do a recv call */ - mpi_errno = - MPIC_Recv(rbuf, recvnow, recvtype, left, MPIR_ALLGATHERV_TAG, comm_ptr, &status); + mpi_errno = MPIC_Recv(rbuf, recvnow, recvtype, left, MPIR_ALLGATHERV_TAG, + comm_ptr, coll_group, &status); MPIR_ERR_CHECK(mpi_errno); torecv -= recvnow; } else if (!recvnow) { /* If there's no data to receive, just do a send call */ - mpi_errno = - MPIC_Send(sbuf, sendnow, recvtype, right, MPIR_ALLGATHERV_TAG, comm_ptr, errflag); + mpi_errno = MPIC_Send(sbuf, sendnow, recvtype, right, MPIR_ALLGATHERV_TAG, + comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); tosend -= sendnow; } else { /* There's data to be sent and received */ mpi_errno = MPIC_Sendrecv(sbuf, sendnow, recvtype, right, MPIR_ALLGATHERV_TAG, rbuf, recvnow, recvtype, left, MPIR_ALLGATHERV_TAG, - comm_ptr, &status, errflag); + comm_ptr, coll_group, &status, errflag); MPIR_ERR_CHECK(mpi_errno); tosend -= sendnow; torecv -= recvnow; diff --git a/src/mpi/coll/allreduce/allreduce_allcomm_nb.c b/src/mpi/coll/allreduce/allreduce_allcomm_nb.c index c076b2bcd8e..bc7decc7290 100644 --- a/src/mpi/coll/allreduce/allreduce_allcomm_nb.c +++ b/src/mpi/coll/allreduce/allreduce_allcomm_nb.c @@ -7,13 +7,14 @@ int MPIR_Allreduce_allcomm_nb(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, - MPIR_Errflag_t errflag) + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_Request *req_ptr = NULL; /* just call the nonblocking version and wait on it */ - mpi_errno = MPIR_Iallreduce(sendbuf, recvbuf, count, datatype, op, comm_ptr, &req_ptr); + mpi_errno = + MPIR_Iallreduce(sendbuf, recvbuf, count, datatype, op, comm_ptr, coll_group, &req_ptr); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIC_Wait(req_ptr); diff --git a/src/mpi/coll/allreduce/allreduce_inter_reduce_exchange_bcast.c b/src/mpi/coll/allreduce/allreduce_inter_reduce_exchange_bcast.c index ab199653c20..cb594f415e1 100644 --- a/src/mpi/coll/allreduce/allreduce_inter_reduce_exchange_bcast.c +++ b/src/mpi/coll/allreduce/allreduce_inter_reduce_exchange_bcast.c @@ -15,7 +15,8 @@ int MPIR_Allreduce_inter_reduce_exchange_bcast(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno; MPI_Aint true_extent, true_lb, extent; @@ -39,19 +40,20 @@ int MPIR_Allreduce_inter_reduce_exchange_bcast(const void *sendbuf, void *recvbu newcomm_ptr = comm_ptr->local_comm; /* Do a local reduce on this intracommunicator */ - mpi_errno = MPIR_Reduce(sendbuf, tmp_buf, count, datatype, op, 0, newcomm_ptr, errflag); + mpi_errno = MPIR_Reduce(sendbuf, tmp_buf, count, datatype, op, 0, newcomm_ptr, coll_group, + errflag); MPIR_ERR_CHECK(mpi_errno); /* Do a exchange between local and remote rank 0 on this intercommunicator */ if (comm_ptr->rank == 0) { mpi_errno = MPIC_Sendrecv(tmp_buf, count, datatype, 0, MPIR_REDUCE_TAG, recvbuf, count, datatype, 0, MPIR_REDUCE_TAG, - comm_ptr, MPI_STATUS_IGNORE, errflag); + comm_ptr, coll_group, MPI_STATUS_IGNORE, errflag); MPIR_ERR_CHECK(mpi_errno); } /* Do a local broadcast on this intracommunicator */ - mpi_errno = MPIR_Bcast(recvbuf, count, datatype, 0, newcomm_ptr, errflag); + mpi_errno = MPIR_Bcast(recvbuf, count, datatype, 0, newcomm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); fn_exit: diff --git a/src/mpi/coll/allreduce/allreduce_intra_k_reduce_scatter_allgather.c b/src/mpi/coll/allreduce/allreduce_intra_k_reduce_scatter_allgather.c index cb46472865c..3428178baa8 100644 --- a/src/mpi/coll/allreduce/allreduce_intra_k_reduce_scatter_allgather.c +++ b/src/mpi/coll/allreduce/allreduce_intra_k_reduce_scatter_allgather.c @@ -14,8 +14,9 @@ int MPIR_Allreduce_intra_k_reduce_scatter_allgather(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm, int k, - int single_phase_recv, MPIR_Errflag_t errflag) + MPI_Op op, MPIR_Comm * comm, int coll_group, + int k, int single_phase_recv, + MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; int rank, nranks, nbr; @@ -35,9 +36,11 @@ int MPIR_Allreduce_intra_k_reduce_scatter_allgather(const void *sendbuf, MPIR_CHKLMEM_DECL(2); MPIR_Assert(k > 1); + /* This algorithm uses cached data in comm, thus it won't work with coll_group */ + MPIR_Assert(coll_group == MPIR_SUBGROUP_NONE); + + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, nranks); - rank = comm->rank; - nranks = comm->local_size; MPIR_Assert(MPIR_Op_is_commutative(op)); /* need to allocate temporary buffer to store incoming data */ @@ -106,13 +109,14 @@ int MPIR_Allreduce_intra_k_reduce_scatter_allgather(const void *sendbuf, if (!in_step2) { /* even */ /* non-participating rank sends the data to a participating rank */ mpi_errno = MPIC_Send(recvbuf, count, - datatype, step1_sendto, MPIR_ALLREDUCE_TAG, comm, errflag); + datatype, step1_sendto, MPIR_ALLREDUCE_TAG, comm, coll_group, + errflag); MPIR_ERR_CHECK(mpi_errno); } else { /* odd */ for (i = 0; i < step1_nrecvs; i++) { /* participating rank gets data from non-partcipating ranks */ mpi_errno = MPIC_Recv(tmp_recvbuf, count, datatype, step1_recvfrom[i], MPIR_ALLREDUCE_TAG, comm, - MPI_STATUS_IGNORE); + coll_group, MPI_STATUS_IGNORE); MPIR_ERR_CHECK(mpi_errno); /* Do reduction of reduced data */ mpi_errno = MPIR_Reduce_local(tmp_recvbuf, recvbuf, count, datatype, op); @@ -162,8 +166,8 @@ int MPIR_Allreduce_intra_k_reduce_scatter_allgather(const void *sendbuf, send_cnt += cnts[offset + x]; mpi_errno = MPIC_Isend((char *) recvbuf + send_offset, send_cnt, - datatype, dst, MPIR_ALLREDUCE_TAG, comm, &recv_reqs[num_rreq++], - errflag); + datatype, dst, MPIR_ALLREDUCE_TAG, comm, coll_group, + &recv_reqs[num_rreq++], errflag); MPIR_ERR_CHECK(mpi_errno); rank_for_offset = MPII_Recexchalgo_reverse_digits_step2(rank, nranks, k); @@ -176,7 +180,7 @@ int MPIR_Allreduce_intra_k_reduce_scatter_allgather(const void *sendbuf, recv_cnt += cnts[offset + x]; mpi_errno = MPIC_Irecv((char *) tmp_recvbuf + recv_offset, recv_cnt, datatype, - dst, MPIR_ALLREDUCE_TAG, comm, &recv_reqs[num_rreq++]); + dst, MPIR_ALLREDUCE_TAG, comm, coll_group, &recv_reqs[num_rreq++]); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIC_Waitall(num_rreq, recv_reqs, MPI_STATUSES_IGNORE); @@ -209,7 +213,8 @@ int MPIR_Allreduce_intra_k_reduce_scatter_allgather(const void *sendbuf, for (x = 0; x < current_cnt; x++) recv_count += cnts[offset + x]; mpi_errno = MPIC_Irecv(((char *) recvbuf + recv_offset), recv_count, datatype, - nbr, MPIR_ALLREDUCE_TAG, comm, &recv_reqs[num_rreq++]); + nbr, MPIR_ALLREDUCE_TAG, comm, coll_group, + &recv_reqs[num_rreq++]); MPIR_ERR_CHECK(mpi_errno); } recv_phase--; @@ -225,8 +230,8 @@ int MPIR_Allreduce_intra_k_reduce_scatter_allgather(const void *sendbuf, for (x = 0; x < current_cnt; x++) send_count += cnts[offset + x]; mpi_errno = MPIC_Isend(((char *) recvbuf + send_offset), send_count, datatype, - nbr, MPIR_ALLREDUCE_TAG, comm, &send_reqs[num_sreq++], - errflag); + nbr, MPIR_ALLREDUCE_TAG, comm, coll_group, + &send_reqs[num_sreq++], errflag); MPIR_ERR_CHECK(mpi_errno); } /* wait on prev recvs */ @@ -249,7 +254,8 @@ int MPIR_Allreduce_intra_k_reduce_scatter_allgather(const void *sendbuf, send_count += cnts[offset + x]; mpi_errno = MPIC_Isend(((char *) recvbuf + send_offset), send_count, datatype, nbr, - MPIR_ALLREDUCE_TAG, comm, &send_reqs[num_sreq++], errflag); + MPIR_ALLREDUCE_TAG, comm, coll_group, &send_reqs[num_sreq++], + errflag); MPIR_ERR_CHECK(mpi_errno); } /* wait on prev recvs */ @@ -267,15 +273,15 @@ int MPIR_Allreduce_intra_k_reduce_scatter_allgather(const void *sendbuf, /* Step 3: This is reverse of Step 1. Rans that participated in Step 2 * send the data to non-partcipating rans */ if (step1_sendto != -1) { /* I am a Step 2 non-participating rank */ - mpi_errno = MPIC_Recv(recvbuf, count, datatype, step1_sendto, MPIR_ALLREDUCE_TAG, comm, - MPI_STATUS_IGNORE); + mpi_errno = MPIC_Recv(recvbuf, count, datatype, step1_sendto, MPIR_ALLREDUCE_TAG, + comm, coll_group, MPI_STATUS_IGNORE); MPIR_ERR_CHECK(mpi_errno); } else { if (step1_nrecvs > 0) { for (i = 0; i < step1_nrecvs; i++) { mpi_errno = MPIC_Isend(recvbuf, count, datatype, step1_recvfrom[i], MPIR_ALLREDUCE_TAG, - comm, &send_reqs[i], errflag); + comm, coll_group, &send_reqs[i], errflag); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/allreduce/allreduce_intra_recexch.c b/src/mpi/coll/allreduce/allreduce_intra_recexch.c index 5503a54ec91..f304d8e349d 100644 --- a/src/mpi/coll/allreduce/allreduce_intra_recexch.c +++ b/src/mpi/coll/allreduce/allreduce_intra_recexch.c @@ -17,8 +17,8 @@ int MPIR_Allreduce_intra_recexch(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm, int k, int single_phase_recv, - MPIR_Errflag_t errflag) + MPI_Op op, MPIR_Comm * comm, int coll_group, int k, + int single_phase_recv, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; int is_commutative, rank, nranks, nbr, myidx; @@ -34,8 +34,11 @@ int MPIR_Allreduce_intra_recexch(const void *sendbuf, MPIR_Request **send_reqs = NULL, **recv_reqs = NULL; int send_nreq = 0, recv_nreq = 0, total_phases = 0; - rank = comm->rank; - nranks = comm->local_size; + /* uses cached data in comm */ + MPIR_Assert(coll_group == MPIR_SUBGROUP_NONE); + + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, nranks); + is_commutative = MPIR_Op_is_commutative(op); bool is_float; @@ -154,14 +157,16 @@ int MPIR_Allreduce_intra_recexch(const void *sendbuf, if (!in_step2) { /* even */ /* non-participating rank sends the data to a participating rank */ mpi_errno = MPIC_Send(recvbuf, count, - datatype, step1_sendto, MPIR_ALLREDUCE_TAG, comm, errflag); + datatype, step1_sendto, MPIR_ALLREDUCE_TAG, comm, coll_group, + errflag); MPIR_ERR_CHECK(mpi_errno); } else { /* odd */ if (step1_nrecvs) { for (i = 0; i < step1_nrecvs; i++) { /* participating rank gets data from non-partcipating ranks */ mpi_errno = MPIC_Irecv(nbr_buffer[i], count, datatype, step1_recvfrom[i], - MPIR_ALLREDUCE_TAG, comm, &recv_reqs[recv_nreq++]); + MPIR_ALLREDUCE_TAG, comm, coll_group, + &recv_reqs[recv_nreq++]); MPIR_ERR_CHECK(mpi_errno); } mpi_errno = MPIC_Waitall(recv_nreq, recv_reqs, MPI_STATUSES_IGNORE); @@ -187,7 +192,7 @@ int MPIR_Allreduce_intra_recexch(const void *sendbuf, nbr = step2_nbrs[phase + j][i]; mpi_errno = MPIC_Irecv(nbr_buffer[buf++], count, datatype, nbr, MPIR_ALLREDUCE_TAG, - comm, &recv_reqs[recv_nreq++]); + comm, coll_group, &recv_reqs[recv_nreq++]); MPIR_ERR_CHECK(mpi_errno); } } @@ -196,11 +201,9 @@ int MPIR_Allreduce_intra_recexch(const void *sendbuf, /* send data to all the neighbors */ for (i = 0; i < k - 1; i++) { nbr = step2_nbrs[phase][i]; - mpi_errno = MPIC_Isend(recvbuf, count, datatype, nbr, MPIR_ALLREDUCE_TAG, comm, - &send_reqs[send_nreq++], errflag); + mpi_errno = MPIC_Isend(recvbuf, count, datatype, nbr, MPIR_ALLREDUCE_TAG, + comm, coll_group, &send_reqs[send_nreq++], errflag); MPIR_ERR_CHECK(mpi_errno); - if (rank > nbr) { - } } mpi_errno = MPIC_Waitall(send_nreq, send_reqs, MPI_STATUSES_IGNORE); @@ -227,7 +230,7 @@ int MPIR_Allreduce_intra_recexch(const void *sendbuf, mpi_errno = MPIC_Isend(recvbuf, count, datatype, nbr, MPIR_ALLREDUCE_TAG, comm, - &send_reqs[send_nreq++], errflag); + coll_group, &send_reqs[send_nreq++], errflag); MPIR_ERR_CHECK(mpi_errno); } @@ -251,14 +254,14 @@ int MPIR_Allreduce_intra_recexch(const void *sendbuf, /* Step 3: This is reverse of Step 1. Rans that participated in Step 2 * send the data to non-partcipating rans */ if (step1_sendto != -1) { /* I am a Step 2 non-participating rank */ - mpi_errno = MPIC_Recv(recvbuf, count, datatype, step1_sendto, MPIR_ALLREDUCE_TAG, comm, - MPI_STATUS_IGNORE); + mpi_errno = MPIC_Recv(recvbuf, count, datatype, step1_sendto, MPIR_ALLREDUCE_TAG, + comm, coll_group, MPI_STATUS_IGNORE); MPIR_ERR_CHECK(mpi_errno); } else { for (i = 0; i < step1_nrecvs; i++) { mpi_errno = MPIC_Isend(recvbuf, count, datatype, step1_recvfrom[i], MPIR_ALLREDUCE_TAG, - comm, &send_reqs[i], errflag); + comm, coll_group, &send_reqs[i], errflag); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/allreduce/allreduce_intra_recursive_doubling.c b/src/mpi/coll/allreduce/allreduce_intra_recursive_doubling.c index 896a8d5359a..85f555466a9 100644 --- a/src/mpi/coll/allreduce/allreduce_intra_recursive_doubling.c +++ b/src/mpi/coll/allreduce/allreduce_intra_recursive_doubling.c @@ -22,7 +22,8 @@ int MPIR_Allreduce_intra_recursive_doubling(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPI_Op op, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { MPIR_CHKLMEM_DECL(1); int comm_size, rank; @@ -31,7 +32,7 @@ int MPIR_Allreduce_intra_recursive_doubling(const void *sendbuf, MPI_Aint true_extent, true_lb, extent; void *tmp_buf; - MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size); + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); is_commutative = MPIR_Op_is_commutative(op); @@ -65,7 +66,8 @@ int MPIR_Allreduce_intra_recursive_doubling(const void *sendbuf, if (rank < 2 * rem) { if (rank % 2 == 0) { /* even */ mpi_errno = MPIC_Send(recvbuf, count, - datatype, rank + 1, MPIR_ALLREDUCE_TAG, comm_ptr, errflag); + datatype, rank + 1, MPIR_ALLREDUCE_TAG, comm_ptr, coll_group, + errflag); MPIR_ERR_CHECK(mpi_errno); /* temporarily set the rank to -1 so that this @@ -75,7 +77,7 @@ int MPIR_Allreduce_intra_recursive_doubling(const void *sendbuf, } else { /* odd */ mpi_errno = MPIC_Recv(tmp_buf, count, datatype, rank - 1, - MPIR_ALLREDUCE_TAG, comm_ptr, MPI_STATUS_IGNORE); + MPIR_ALLREDUCE_TAG, comm_ptr, coll_group, MPI_STATUS_IGNORE); MPIR_ERR_CHECK(mpi_errno); /* do the reduction on received data. since the @@ -111,7 +113,8 @@ int MPIR_Allreduce_intra_recursive_doubling(const void *sendbuf, mpi_errno = MPIC_Sendrecv(recvbuf, count, datatype, dst, MPIR_ALLREDUCE_TAG, tmp_buf, count, datatype, dst, - MPIR_ALLREDUCE_TAG, comm_ptr, MPI_STATUS_IGNORE, errflag); + MPIR_ALLREDUCE_TAG, comm_ptr, coll_group, MPI_STATUS_IGNORE, + errflag); MPIR_ERR_CHECK(mpi_errno); /* tmp_buf contains data received in this step. @@ -139,11 +142,12 @@ int MPIR_Allreduce_intra_recursive_doubling(const void *sendbuf, if (rank < 2 * rem) { if (rank % 2) /* odd */ mpi_errno = MPIC_Send(recvbuf, count, - datatype, rank - 1, MPIR_ALLREDUCE_TAG, comm_ptr, errflag); + datatype, rank - 1, MPIR_ALLREDUCE_TAG, comm_ptr, coll_group, + errflag); else /* even */ mpi_errno = MPIC_Recv(recvbuf, count, datatype, rank + 1, - MPIR_ALLREDUCE_TAG, comm_ptr, MPI_STATUS_IGNORE); + MPIR_ALLREDUCE_TAG, comm_ptr, coll_group, MPI_STATUS_IGNORE); MPIR_ERR_CHECK(mpi_errno); } fn_exit: diff --git a/src/mpi/coll/allreduce/allreduce_intra_reduce_scatter_allgather.c b/src/mpi/coll/allreduce/allreduce_intra_reduce_scatter_allgather.c index 327148196fd..c072067b870 100644 --- a/src/mpi/coll/allreduce/allreduce_intra_reduce_scatter_allgather.c +++ b/src/mpi/coll/allreduce/allreduce_intra_reduce_scatter_allgather.c @@ -43,7 +43,8 @@ int MPIR_Allreduce_intra_reduce_scatter_allgather(const void *sendbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { MPIR_CHKLMEM_DECL(3); int comm_size, rank; @@ -52,8 +53,7 @@ int MPIR_Allreduce_intra_reduce_scatter_allgather(const void *sendbuf, MPI_Aint true_extent, true_lb, extent; void *tmp_buf; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); /* need to allocate temporary buffer to store incoming data */ MPIR_Type_get_true_extent_impl(datatype, &true_lb, &true_extent); @@ -72,7 +72,7 @@ int MPIR_Allreduce_intra_reduce_scatter_allgather(const void *sendbuf, } /* get nearest power-of-two less than or equal to comm_size */ - pof2 = comm_ptr->coll.pof2; + pof2 = MPL_pof2(comm_size); rem = comm_size - pof2; @@ -85,7 +85,8 @@ int MPIR_Allreduce_intra_reduce_scatter_allgather(const void *sendbuf, if (rank < 2 * rem) { if (rank % 2 == 0) { /* even */ mpi_errno = MPIC_Send(recvbuf, count, - datatype, rank + 1, MPIR_ALLREDUCE_TAG, comm_ptr, errflag); + datatype, rank + 1, MPIR_ALLREDUCE_TAG, comm_ptr, coll_group, + errflag); MPIR_ERR_CHECK(mpi_errno); /* temporarily set the rank to -1 so that this @@ -95,7 +96,7 @@ int MPIR_Allreduce_intra_reduce_scatter_allgather(const void *sendbuf, } else { /* odd */ mpi_errno = MPIC_Recv(tmp_buf, count, datatype, rank - 1, - MPIR_ALLREDUCE_TAG, comm_ptr, MPI_STATUS_IGNORE); + MPIR_ALLREDUCE_TAG, comm_ptr, coll_group, MPI_STATUS_IGNORE); MPIR_ERR_CHECK(mpi_errno); /* do the reduction on received data. since the @@ -175,7 +176,8 @@ int MPIR_Allreduce_intra_reduce_scatter_allgather(const void *sendbuf, (char *) tmp_buf + disps[recv_idx] * extent, recv_cnt, datatype, dst, - MPIR_ALLREDUCE_TAG, comm_ptr, MPI_STATUS_IGNORE, errflag); + MPIR_ALLREDUCE_TAG, comm_ptr, coll_group, MPI_STATUS_IGNORE, + errflag); MPIR_ERR_CHECK(mpi_errno); /* tmp_buf contains data received in this step. @@ -234,7 +236,8 @@ int MPIR_Allreduce_intra_reduce_scatter_allgather(const void *sendbuf, (char *) recvbuf + disps[recv_idx] * extent, recv_cnt, datatype, dst, - MPIR_ALLREDUCE_TAG, comm_ptr, MPI_STATUS_IGNORE, errflag); + MPIR_ALLREDUCE_TAG, comm_ptr, coll_group, MPI_STATUS_IGNORE, + errflag); MPIR_ERR_CHECK(mpi_errno); if (newrank > newdst) @@ -249,11 +252,12 @@ int MPIR_Allreduce_intra_reduce_scatter_allgather(const void *sendbuf, if (rank < 2 * rem) { if (rank % 2) /* odd */ mpi_errno = MPIC_Send(recvbuf, count, - datatype, rank - 1, MPIR_ALLREDUCE_TAG, comm_ptr, errflag); + datatype, rank - 1, MPIR_ALLREDUCE_TAG, comm_ptr, coll_group, + errflag); else /* even */ mpi_errno = MPIC_Recv(recvbuf, count, datatype, rank + 1, - MPIR_ALLREDUCE_TAG, comm_ptr, MPI_STATUS_IGNORE); + MPIR_ALLREDUCE_TAG, comm_ptr, coll_group, MPI_STATUS_IGNORE); MPIR_ERR_CHECK(mpi_errno); } fn_exit: diff --git a/src/mpi/coll/allreduce/allreduce_intra_ring.c b/src/mpi/coll/allreduce/allreduce_intra_ring.c index ca87f50b9ce..88e2c2e695d 100644 --- a/src/mpi/coll/allreduce/allreduce_intra_ring.c +++ b/src/mpi/coll/allreduce/allreduce_intra_ring.c @@ -11,7 +11,7 @@ int MPIR_Allreduce_intra_ring(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm, MPIR_Errflag_t errflag) + MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; int i, src, dst; @@ -25,8 +25,7 @@ int MPIR_Allreduce_intra_ring(const void *sendbuf, void *recvbuf, MPI_Aint count MPIR_Request *reqs[2]; /* one send and one recv per transfer */ is_inplace = (sendbuf == MPI_IN_PLACE); - nranks = MPIR_Comm_size(comm); - rank = MPIR_Comm_rank(comm); + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, nranks); MPIR_Datatype_get_extent_macro(datatype, extent); MPIR_Type_get_true_extent_impl(datatype, &lb, &true_extent); @@ -75,14 +74,15 @@ int MPIR_Allreduce_intra_ring(const void *sendbuf, void *recvbuf, MPI_Aint count send_rank = (nranks + rank - 1 - i) % nranks; /* get a new tag to prevent out of order messages */ - mpi_errno = MPIR_Sched_next_tag(comm, &tag); + mpi_errno = MPIR_Sched_next_tag(comm, coll_group, &tag); MPIR_ERR_CHECK(mpi_errno); - mpi_errno = MPIC_Irecv(tmpbuf, cnts[recv_rank], datatype, src, tag, comm, &reqs[0]); + mpi_errno = MPIC_Irecv(tmpbuf, cnts[recv_rank], datatype, src, tag, + comm, coll_group, &reqs[0]); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIC_Isend((char *) recvbuf + displs[send_rank] * extent, cnts[send_rank], - datatype, dst, tag, comm, &reqs[1], errflag); + datatype, dst, tag, comm, coll_group, &reqs[1], errflag); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIC_Waitall(2, reqs, MPI_STATUSES_IGNORE); @@ -96,7 +96,7 @@ int MPIR_Allreduce_intra_ring(const void *sendbuf, void *recvbuf, MPI_Aint count /* Phase 3: Allgatherv ring, so everyone has the reduced data */ mpi_errno = MPIR_Allgatherv_intra_ring(MPI_IN_PLACE, -1, MPI_DATATYPE_NULL, recvbuf, cnts, - displs, datatype, comm, errflag); + displs, datatype, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); MPL_free(cnts); diff --git a/src/mpi/coll/allreduce/allreduce_intra_smp.c b/src/mpi/coll/allreduce/allreduce_intra_smp.c index 24d1ef57a47..08818cf1900 100644 --- a/src/mpi/coll/allreduce/allreduce_intra_smp.c +++ b/src/mpi/coll/allreduce/allreduce_intra_smp.c @@ -6,27 +6,31 @@ #include "mpiimpl.h" int MPIR_Allreduce_intra_smp(const void *sendbuf, void *recvbuf, MPI_Aint count, - MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, + MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; + MPIR_Assert(MPIR_Comm_is_parent_comm(comm_ptr, coll_group)); + int local_rank = comm_ptr->subgroups[MPIR_SUBGROUP_NODE].rank; + int local_size = comm_ptr->subgroups[MPIR_SUBGROUP_NODE].size; + /* on each node, do a reduce to the local root */ - if (comm_ptr->node_comm != NULL) { + if (local_size > 1) { /* take care of the MPI_IN_PLACE case. For reduce, * MPI_IN_PLACE is specified only on the root; * for allreduce it is specified on all processes. */ - if ((sendbuf == MPI_IN_PLACE) && (comm_ptr->node_comm->rank != 0)) { + if ((sendbuf == MPI_IN_PLACE) && (local_rank != 0)) { /* IN_PLACE and not root of reduce. Data supplied to this * allreduce is in recvbuf. Pass that as the sendbuf to reduce. */ - mpi_errno = - MPIR_Reduce(recvbuf, NULL, count, datatype, op, 0, comm_ptr->node_comm, errflag); + mpi_errno = MPIR_Reduce(recvbuf, NULL, count, datatype, op, 0, + comm_ptr, MPIR_SUBGROUP_NODE, errflag); MPIR_ERR_CHECK(mpi_errno); } else { - mpi_errno = - MPIR_Reduce(sendbuf, recvbuf, count, datatype, op, 0, comm_ptr->node_comm, errflag); + mpi_errno = MPIR_Reduce(sendbuf, recvbuf, count, datatype, op, 0, + comm_ptr, MPIR_SUBGROUP_NODE, errflag); MPIR_ERR_CHECK(mpi_errno); } } else { @@ -38,16 +42,15 @@ int MPIR_Allreduce_intra_smp(const void *sendbuf, void *recvbuf, MPI_Aint count, } /* now do an IN_PLACE allreduce among the local roots of all nodes */ - if (comm_ptr->node_roots_comm != NULL) { - mpi_errno = - MPIR_Allreduce(MPI_IN_PLACE, recvbuf, count, datatype, op, comm_ptr->node_roots_comm, - errflag); + if (local_rank == 0) { + mpi_errno = MPIR_Allreduce(MPI_IN_PLACE, recvbuf, count, datatype, op, + comm_ptr, MPIR_SUBGROUP_NODE_CROSS, errflag); MPIR_ERR_CHECK(mpi_errno); } /* now broadcast the result among local processes */ - if (comm_ptr->node_comm != NULL) { - mpi_errno = MPIR_Bcast(recvbuf, count, datatype, 0, comm_ptr->node_comm, errflag); + if (local_size > 1) { + mpi_errno = MPIR_Bcast(recvbuf, count, datatype, 0, comm_ptr, MPIR_SUBGROUP_NODE, errflag); MPIR_ERR_CHECK(mpi_errno); } goto fn_exit; diff --git a/src/mpi/coll/allreduce/allreduce_intra_tree.c b/src/mpi/coll/allreduce/allreduce_intra_tree.c index 7a6bb4f9709..73915a72484 100644 --- a/src/mpi/coll/allreduce/allreduce_intra_tree.c +++ b/src/mpi/coll/allreduce/allreduce_intra_tree.c @@ -14,7 +14,7 @@ int MPIR_Allreduce_intra_tree(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm_ptr, + MPI_Op op, MPIR_Comm * comm_ptr, int coll_group, int tree_type, int k, int chunk_size, int buffer_per_child, MPIR_Errflag_t errflag) { @@ -37,8 +37,7 @@ int MPIR_Allreduce_intra_tree(const void *sendbuf, MPIR_Request **reqs; int num_reqs = 0; - comm_size = MPIR_Comm_size(comm_ptr); - rank = MPIR_Comm_rank(comm_ptr); + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); MPIR_Datatype_get_size_macro(datatype, type_size); MPIR_Datatype_get_extent_macro(datatype, extent); @@ -67,12 +66,13 @@ int MPIR_Allreduce_intra_tree(const void *sendbuf, /* initialize the tree */ if (tree_type == MPIR_TREE_TYPE_TOPOLOGY_AWARE || tree_type == MPIR_TREE_TYPE_TOPOLOGY_AWARE_K) { mpi_errno = - MPIR_Treealgo_tree_create_topo_aware(comm_ptr, tree_type, k, root, + MPIR_Treealgo_tree_create_topo_aware(comm_ptr, coll_group, tree_type, k, root, MPIR_CVAR_ALLREDUCE_TOPO_REORDER_ENABLE, &my_tree); } else if (tree_type == MPIR_TREE_TYPE_TOPOLOGY_WAVE) { MPIR_Csel_coll_sig_s coll_sig = { .coll_type = MPIR_CSEL_COLL_TYPE__ALLREDUCE, .comm_ptr = comm_ptr, + .coll_group = coll_group, .u.allreduce.sendbuf = sendbuf, .u.allreduce.recvbuf = recvbuf, .u.allreduce.count = count, @@ -96,7 +96,7 @@ int MPIR_Allreduce_intra_tree(const void *sendbuf, } mpi_errno = - MPIR_Treealgo_tree_create_topo_wave(comm_ptr, k, root, + MPIR_Treealgo_tree_create_topo_wave(comm_ptr, coll_group, k, root, MPIR_CVAR_ALLREDUCE_TOPO_REORDER_ENABLE, overhead, lat_diff_groups, lat_diff_switches, lat_same_switches, &my_tree); @@ -139,7 +139,7 @@ int MPIR_Allreduce_intra_tree(const void *sendbuf, void *reduce_address = (char *) reduce_buffer + offset * extent; MPIR_ERR_CHKANDJUMP(!reduce_address, mpi_errno, MPI_ERR_OTHER, "**nomem"); - mpi_errno = MPIR_Sched_next_tag(comm_ptr, &tag); + mpi_errno = MPIR_Sched_next_tag(comm_ptr, coll_group, &tag); MPIR_ERR_CHECK(mpi_errno); for (i = 0; i < num_children; i++) { @@ -150,7 +150,7 @@ int MPIR_Allreduce_intra_tree(const void *sendbuf, mpi_errno = MPIC_Recv(recv_address, msgsize, datatype, child, MPIR_ALLREDUCE_TAG, comm_ptr, - MPI_STATUS_IGNORE); + coll_group, MPI_STATUS_IGNORE); /* for communication errors, just record the error but continue */ MPIR_ERR_CHECK(mpi_errno); @@ -172,14 +172,14 @@ int MPIR_Allreduce_intra_tree(const void *sendbuf, if (rank != root) { /* send data to the parent */ mpi_errno = MPIC_Isend(reduce_address, msgsize, datatype, my_tree.parent, MPIR_ALLREDUCE_TAG, - comm_ptr, &reqs[num_reqs++], errflag); + comm_ptr, coll_group, &reqs[num_reqs++], errflag); MPIR_ERR_CHECK(mpi_errno); } if (my_tree.parent != -1) { mpi_errno = MPIC_Recv(reduce_address, msgsize, datatype, my_tree.parent, MPIR_ALLREDUCE_TAG, comm_ptr, - MPI_STATUS_IGNORE); + coll_group, MPI_STATUS_IGNORE); MPIR_ERR_CHECK(mpi_errno); } if (num_children) { @@ -189,7 +189,8 @@ int MPIR_Allreduce_intra_tree(const void *sendbuf, MPIR_Assert(child != 0); mpi_errno = MPIC_Isend(reduce_address, msgsize, datatype, child, - MPIR_ALLREDUCE_TAG, comm_ptr, &reqs[num_reqs++], errflag); + MPIR_ALLREDUCE_TAG, comm_ptr, coll_group, &reqs[num_reqs++], + errflag); MPIR_ERR_CHECK(mpi_errno); } } diff --git a/src/mpi/coll/allreduce_group/allreduce_group.c b/src/mpi/coll/allreduce_group/allreduce_group.c index 68a959f17f7..0fa313ededa 100644 --- a/src/mpi/coll/allreduce_group/allreduce_group.c +++ b/src/mpi/coll/allreduce_group/allreduce_group.c @@ -67,7 +67,8 @@ int MPII_Allreduce_group_intra(void *sendbuf, void *recvbuf, MPI_Aint count, if (group_rank < 2 * rem) { if (group_rank % 2 == 0) { /* even */ to_comm_rank(cdst, group_rank + 1); - mpi_errno = MPIC_Send(recvbuf, count, datatype, cdst, tag, comm_ptr, errflag); + mpi_errno = MPIC_Send(recvbuf, count, datatype, cdst, tag, comm_ptr, + MPIR_SUBGROUP_NONE, errflag); MPIR_ERR_CHECK(mpi_errno); /* temporarily set the rank to -1 so that this @@ -76,7 +77,8 @@ int MPII_Allreduce_group_intra(void *sendbuf, void *recvbuf, MPI_Aint count, newrank = -1; } else { /* odd */ to_comm_rank(csrc, group_rank - 1); - mpi_errno = MPIC_Recv(tmp_buf, count, datatype, csrc, tag, comm_ptr, MPI_STATUS_IGNORE); + mpi_errno = MPIC_Recv(tmp_buf, count, datatype, csrc, tag, comm_ptr, + MPIR_SUBGROUP_NONE, MPI_STATUS_IGNORE); MPIR_ERR_CHECK(mpi_errno); /* do the reduction on received data. since the @@ -116,7 +118,8 @@ int MPII_Allreduce_group_intra(void *sendbuf, void *recvbuf, MPI_Aint count, mpi_errno = MPIC_Sendrecv(recvbuf, count, datatype, cdst, tag, tmp_buf, count, datatype, cdst, - tag, comm_ptr, MPI_STATUS_IGNORE, errflag); + tag, comm_ptr, MPIR_SUBGROUP_NONE, MPI_STATUS_IGNORE, + errflag); MPIR_ERR_CHECK(mpi_errno); if (!mpi_errno) { /* tmp_buf contains data received in this step. @@ -197,7 +200,8 @@ int MPII_Allreduce_group_intra(void *sendbuf, void *recvbuf, MPI_Aint count, (char *) tmp_buf + disps[recv_idx] * extent, recv_cnt, datatype, cdst, - tag, comm_ptr, MPI_STATUS_IGNORE, errflag); + tag, comm_ptr, MPIR_SUBGROUP_NONE, MPI_STATUS_IGNORE, + errflag); MPIR_ERR_CHECK(mpi_errno); /* tmp_buf contains data received in this step. @@ -257,7 +261,8 @@ int MPII_Allreduce_group_intra(void *sendbuf, void *recvbuf, MPI_Aint count, (char *) recvbuf + disps[recv_idx] * extent, recv_cnt, datatype, cdst, - tag, comm_ptr, MPI_STATUS_IGNORE, errflag); + tag, comm_ptr, MPIR_SUBGROUP_NONE, MPI_STATUS_IGNORE, + errflag); MPIR_ERR_CHECK(mpi_errno); if (newrank > newdst) @@ -274,10 +279,14 @@ int MPII_Allreduce_group_intra(void *sendbuf, void *recvbuf, MPI_Aint count, if (group_rank < 2 * rem) { if (group_rank % 2) { /* odd */ to_comm_rank(cdst, group_rank - 1); - mpi_errno = MPIC_Send(recvbuf, count, datatype, cdst, tag, comm_ptr, errflag); + mpi_errno = + MPIC_Send(recvbuf, count, datatype, cdst, tag, comm_ptr, MPIR_SUBGROUP_NONE, + errflag); } else { /* even */ to_comm_rank(csrc, group_rank + 1); - mpi_errno = MPIC_Recv(recvbuf, count, datatype, csrc, tag, comm_ptr, MPI_STATUS_IGNORE); + mpi_errno = + MPIC_Recv(recvbuf, count, datatype, csrc, tag, comm_ptr, MPIR_SUBGROUP_NONE, + MPI_STATUS_IGNORE); } MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/alltoall/alltoall_allcomm_nb.c b/src/mpi/coll/alltoall/alltoall_allcomm_nb.c index ecb74cd135f..ec3a545e412 100644 --- a/src/mpi/coll/alltoall/alltoall_allcomm_nb.c +++ b/src/mpi/coll/alltoall/alltoall_allcomm_nb.c @@ -7,7 +7,7 @@ int MPIR_Alltoall_allcomm_nb(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_Request *req_ptr = NULL; @@ -15,7 +15,7 @@ int MPIR_Alltoall_allcomm_nb(const void *sendbuf, MPI_Aint sendcount, MPI_Dataty /* just call the nonblocking version and wait on it */ mpi_errno = MPIR_Ialltoall(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm_ptr, - &req_ptr); + coll_group, &req_ptr); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIC_Wait(req_ptr); diff --git a/src/mpi/coll/alltoall/alltoall_inter_pairwise_exchange.c b/src/mpi/coll/alltoall/alltoall_inter_pairwise_exchange.c index 932d7965c50..0af0ce90fb1 100644 --- a/src/mpi/coll/alltoall/alltoall_inter_pairwise_exchange.c +++ b/src/mpi/coll/alltoall/alltoall_inter_pairwise_exchange.c @@ -19,7 +19,7 @@ int MPIR_Alltoall_inter_pairwise_exchange(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, MPIR_Comm * comm_ptr, - MPIR_Errflag_t errflag) + int coll_group, MPIR_Errflag_t errflag) { int local_size, remote_size, max_size, i; MPI_Aint sendtype_extent, recvtype_extent; @@ -57,7 +57,7 @@ int MPIR_Alltoall_inter_pairwise_exchange(const void *sendbuf, MPI_Aint sendcoun mpi_errno = MPIC_Sendrecv(sendaddr, sendcount, sendtype, dst, MPIR_ALLTOALL_TAG, recvaddr, recvcount, recvtype, src, - MPIR_ALLTOALL_TAG, comm_ptr, &status, errflag); + MPIR_ALLTOALL_TAG, comm_ptr, coll_group, &status, errflag); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/alltoall/alltoall_intra_brucks.c b/src/mpi/coll/alltoall/alltoall_intra_brucks.c index 2aea8b1860d..ecbaef9990c 100644 --- a/src/mpi/coll/alltoall/alltoall_intra_brucks.c +++ b/src/mpi/coll/alltoall/alltoall_intra_brucks.c @@ -23,7 +23,8 @@ int MPIR_Alltoall_intra_brucks(const void *sendbuf, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, - MPI_Datatype recvtype, MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPI_Datatype recvtype, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int comm_size, i, pof2; MPI_Aint sendtype_extent, recvtype_extent; @@ -36,7 +37,7 @@ int MPIR_Alltoall_intra_brucks(const void *sendbuf, void *tmp_buf; MPIR_CHKLMEM_DECL(6); - MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size); + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); #ifdef HAVE_ERROR_CHECKING MPIR_Assert(sendbuf != MPI_IN_PLACE); @@ -106,7 +107,8 @@ int MPIR_Alltoall_intra_brucks(const void *sendbuf, mpi_errno = MPIC_Sendrecv(tmp_buf, newtype_sz, MPI_BYTE, dst, MPIR_ALLTOALL_TAG, recvbuf, 1, newtype, - src, MPIR_ALLTOALL_TAG, comm_ptr, MPI_STATUS_IGNORE, errflag); + src, MPIR_ALLTOALL_TAG, comm_ptr, coll_group, MPI_STATUS_IGNORE, + errflag); MPIR_ERR_CHECK(mpi_errno); MPIR_Type_free_impl(&newtype); diff --git a/src/mpi/coll/alltoall/alltoall_intra_k_brucks.c b/src/mpi/coll/alltoall/alltoall_intra_k_brucks.c index 9286b2a5ee4..cbe766b9008 100644 --- a/src/mpi/coll/alltoall/alltoall_intra_k_brucks.c +++ b/src/mpi/coll/alltoall/alltoall_intra_k_brucks.c @@ -108,7 +108,7 @@ int MPIR_Alltoall_intra_k_brucks(const void *sendbuf, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcnt, - MPI_Datatype recvtype, MPIR_Comm * comm, int k, + MPI_Datatype recvtype, MPIR_Comm * comm, int coll_group, int k, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -134,8 +134,7 @@ int MPIR_Alltoall_intra_k_brucks(const void *sendbuf, is_inplace = (sendbuf == MPI_IN_PLACE); - rank = MPIR_Comm_rank(comm); - size = MPIR_Comm_size(comm); + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, size); nphases = 0; max = size - 1; @@ -251,12 +250,12 @@ int MPIR_Alltoall_intra_k_brucks(const void *sendbuf, mpi_errno = MPIC_Irecv(tmp_rbuf[j - 1], packsize, MPI_BYTE, src, MPIR_ALLTOALL_TAG, comm, - &reqs[num_reqs++]); + coll_group, &reqs[num_reqs++]); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIC_Isend(tmp_sbuf[j - 1], packsize, MPI_BYTE, dst, MPIR_ALLTOALL_TAG, comm, - &reqs[num_reqs++], errflag); + coll_group, &reqs[num_reqs++], errflag); if (mpi_errno) { MPIR_ERR_POP(mpi_errno); } diff --git a/src/mpi/coll/alltoall/alltoall_intra_pairwise.c b/src/mpi/coll/alltoall/alltoall_intra_pairwise.c index 28dcd7ed7d6..736f71c6750 100644 --- a/src/mpi/coll/alltoall/alltoall_intra_pairwise.c +++ b/src/mpi/coll/alltoall/alltoall_intra_pairwise.c @@ -28,15 +28,14 @@ int MPIR_Alltoall_intra_pairwise(const void *sendbuf, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int comm_size, i; MPI_Aint sendtype_extent, recvtype_extent; int mpi_errno = MPI_SUCCESS, src, dst, rank; MPI_Status status; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); #ifdef HAVE_ERROR_CHECKING MPIR_Assert(sendbuf != MPI_IN_PLACE); @@ -75,7 +74,7 @@ int MPIR_Alltoall_intra_pairwise(const void *sendbuf, ((char *) recvbuf + src * recvcount * recvtype_extent), recvcount, recvtype, src, - MPIR_ALLTOALL_TAG, comm_ptr, &status, errflag); + MPIR_ALLTOALL_TAG, comm_ptr, coll_group, &status, errflag); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/alltoall/alltoall_intra_pairwise_sendrecv_replace.c b/src/mpi/coll/alltoall/alltoall_intra_pairwise_sendrecv_replace.c index 22604189e30..9c83d5af935 100644 --- a/src/mpi/coll/alltoall/alltoall_intra_pairwise_sendrecv_replace.c +++ b/src/mpi/coll/alltoall/alltoall_intra_pairwise_sendrecv_replace.c @@ -25,15 +25,15 @@ int MPIR_Alltoall_intra_pairwise_sendrecv_replace(const void *sendbuf, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int comm_size, i, j; MPI_Aint recvtype_extent; int mpi_errno = MPI_SUCCESS, rank; MPI_Status status; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); /* Get extent of send and recv types */ MPIR_Datatype_get_extent_macro(recvtype, recvtype_extent); @@ -55,14 +55,16 @@ int MPIR_Alltoall_intra_pairwise_sendrecv_replace(const void *sendbuf, mpi_errno = MPIC_Sendrecv_replace(((char *) recvbuf + j * recvcount * recvtype_extent), recvcount, recvtype, j, MPIR_ALLTOALL_TAG, j, - MPIR_ALLTOALL_TAG, comm_ptr, &status, errflag); + MPIR_ALLTOALL_TAG, comm_ptr, coll_group, &status, + errflag); MPIR_ERR_CHECK(mpi_errno); } else if (rank == j) { /* same as above with i/j args reversed */ mpi_errno = MPIC_Sendrecv_replace(((char *) recvbuf + i * recvcount * recvtype_extent), recvcount, recvtype, i, MPIR_ALLTOALL_TAG, i, - MPIR_ALLTOALL_TAG, comm_ptr, &status, errflag); + MPIR_ALLTOALL_TAG, comm_ptr, coll_group, &status, + errflag); MPIR_ERR_CHECK(mpi_errno); } } diff --git a/src/mpi/coll/alltoall/alltoall_intra_scattered.c b/src/mpi/coll/alltoall/alltoall_intra_scattered.c index f1986f99ce6..ee31a35400a 100644 --- a/src/mpi/coll/alltoall/alltoall_intra_scattered.c +++ b/src/mpi/coll/alltoall/alltoall_intra_scattered.c @@ -33,7 +33,7 @@ int MPIR_Alltoall_intra_scattered(const void *sendbuf, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int comm_size, i; MPI_Aint sendtype_extent, recvtype_extent; @@ -42,8 +42,7 @@ int MPIR_Alltoall_intra_scattered(const void *sendbuf, MPI_Status *starray; MPIR_CHKLMEM_DECL(6); - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); #ifdef HAVE_ERROR_CHECKING MPIR_Assert(sendbuf != MPI_IN_PLACE); @@ -72,7 +71,7 @@ int MPIR_Alltoall_intra_scattered(const void *sendbuf, mpi_errno = MPIC_Irecv((char *) recvbuf + dst * recvcount * recvtype_extent, recvcount, recvtype, dst, - MPIR_ALLTOALL_TAG, comm_ptr, &reqarray[i]); + MPIR_ALLTOALL_TAG, comm_ptr, coll_group, &reqarray[i]); MPIR_ERR_CHECK(mpi_errno); } @@ -81,7 +80,8 @@ int MPIR_Alltoall_intra_scattered(const void *sendbuf, mpi_errno = MPIC_Isend((char *) sendbuf + dst * sendcount * sendtype_extent, sendcount, sendtype, dst, - MPIR_ALLTOALL_TAG, comm_ptr, &reqarray[i + ss], errflag); + MPIR_ALLTOALL_TAG, comm_ptr, coll_group, &reqarray[i + ss], + errflag); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/alltoallv/alltoallv_allcomm_nb.c b/src/mpi/coll/alltoallv/alltoallv_allcomm_nb.c index 40854e91c20..288427a910a 100644 --- a/src/mpi/coll/alltoallv/alltoallv_allcomm_nb.c +++ b/src/mpi/coll/alltoallv/alltoallv_allcomm_nb.c @@ -8,7 +8,8 @@ int MPIR_Alltoallv_allcomm_nb(const void *sendbuf, const MPI_Aint * sendcounts, const MPI_Aint * sdispls, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * rdispls, - MPI_Datatype recvtype, MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPI_Datatype recvtype, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_Request *req_ptr = NULL; @@ -16,7 +17,7 @@ int MPIR_Alltoallv_allcomm_nb(const void *sendbuf, const MPI_Aint * sendcounts, /* just call the nonblocking version and wait on it */ mpi_errno = MPIR_Ialltoallv(sendbuf, sendcounts, sdispls, sendtype, recvbuf, recvcounts, rdispls, - recvtype, comm_ptr, &req_ptr); + recvtype, comm_ptr, coll_group, &req_ptr); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIC_Wait(req_ptr); diff --git a/src/mpi/coll/alltoallv/alltoallv_inter_pairwise_exchange.c b/src/mpi/coll/alltoallv/alltoallv_inter_pairwise_exchange.c index ea4cb5d1962..341ff3153f0 100644 --- a/src/mpi/coll/alltoallv/alltoallv_inter_pairwise_exchange.c +++ b/src/mpi/coll/alltoallv/alltoallv_inter_pairwise_exchange.c @@ -23,7 +23,8 @@ int MPIR_Alltoallv_inter_pairwise_exchange(const void *sendbuf, const MPI_Aint * const MPI_Aint * sdispls, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * rdispls, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int local_size, remote_size, max_size, i; MPI_Aint send_extent, recv_extent; @@ -65,7 +66,8 @@ int MPIR_Alltoallv_inter_pairwise_exchange(const void *sendbuf, const MPI_Aint * mpi_errno = MPIC_Sendrecv(sendaddr, sendcount, sendtype, dst, MPIR_ALLTOALLV_TAG, recvaddr, recvcount, - recvtype, src, MPIR_ALLTOALLV_TAG, comm_ptr, &status, errflag); + recvtype, src, MPIR_ALLTOALLV_TAG, comm_ptr, coll_group, &status, + errflag); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/alltoallv/alltoallv_intra_pairwise_sendrecv_replace.c b/src/mpi/coll/alltoallv/alltoallv_intra_pairwise_sendrecv_replace.c index 7f6cc8d4814..bbe2ba5a1cc 100644 --- a/src/mpi/coll/alltoallv/alltoallv_intra_pairwise_sendrecv_replace.c +++ b/src/mpi/coll/alltoallv/alltoallv_intra_pairwise_sendrecv_replace.c @@ -22,7 +22,8 @@ int MPIR_Alltoallv_intra_pairwise_sendrecv_replace(const void *sendbuf, const MP const MPI_Aint * sdispls, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * rdispls, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int comm_size, i, j; MPI_Aint recv_extent; @@ -30,8 +31,7 @@ int MPIR_Alltoallv_intra_pairwise_sendrecv_replace(const void *sendbuf, const MP MPI_Status status; int rank; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); /* Get extent of recv type, but send type is only valid if (sendbuf!=MPI_IN_PLACE) */ MPIR_Datatype_get_extent_macro(recvtype, recv_extent); @@ -58,7 +58,7 @@ int MPIR_Alltoallv_intra_pairwise_sendrecv_replace(const void *sendbuf, const MP recvcounts[j], recvtype, j, MPIR_ALLTOALLV_TAG, j, MPIR_ALLTOALLV_TAG, - comm_ptr, &status, errflag); + comm_ptr, coll_group, &status, errflag); MPIR_ERR_CHECK(mpi_errno); } else if (rank == j) { @@ -67,7 +67,7 @@ int MPIR_Alltoallv_intra_pairwise_sendrecv_replace(const void *sendbuf, const MP recvcounts[i], recvtype, i, MPIR_ALLTOALLV_TAG, i, MPIR_ALLTOALLV_TAG, - comm_ptr, &status, errflag); + comm_ptr, coll_group, &status, errflag); MPIR_ERR_CHECK(mpi_errno); } } diff --git a/src/mpi/coll/alltoallv/alltoallv_intra_scattered.c b/src/mpi/coll/alltoallv/alltoallv_intra_scattered.c index 4adaeb83681..2c0a34728fe 100644 --- a/src/mpi/coll/alltoallv/alltoallv_intra_scattered.c +++ b/src/mpi/coll/alltoallv/alltoallv_intra_scattered.c @@ -24,7 +24,7 @@ int MPIR_Alltoallv_intra_scattered(const void *sendbuf, const MPI_Aint * sendcounts, const MPI_Aint * sdispls, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * rdispls, - MPI_Datatype recvtype, MPIR_Comm * comm_ptr, + MPI_Datatype recvtype, MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int comm_size, i; @@ -37,7 +37,7 @@ int MPIR_Alltoallv_intra_scattered(const void *sendbuf, const MPI_Aint * sendcou MPIR_CHKLMEM_DECL(2); - MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size); + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); /* Get extent of recv type, but send type is only valid if (sendbuf!=MPI_IN_PLACE) */ MPIR_Datatype_get_extent_macro(recvtype, recv_extent); @@ -71,7 +71,8 @@ int MPIR_Alltoallv_intra_scattered(const void *sendbuf, const MPI_Aint * sendcou if (type_size) { mpi_errno = MPIC_Irecv((char *) recvbuf + rdispls[dst] * recv_extent, recvcounts[dst], recvtype, dst, - MPIR_ALLTOALLV_TAG, comm_ptr, &reqarray[req_cnt]); + MPIR_ALLTOALLV_TAG, comm_ptr, coll_group, + &reqarray[req_cnt]); MPIR_ERR_CHECK(mpi_errno); req_cnt++; } @@ -86,7 +87,7 @@ int MPIR_Alltoallv_intra_scattered(const void *sendbuf, const MPI_Aint * sendcou if (type_size) { mpi_errno = MPIC_Isend((char *) sendbuf + sdispls[dst] * send_extent, sendcounts[dst], sendtype, dst, - MPIR_ALLTOALLV_TAG, comm_ptr, + MPIR_ALLTOALLV_TAG, comm_ptr, coll_group, &reqarray[req_cnt], errflag); MPIR_ERR_CHECK(mpi_errno); req_cnt++; diff --git a/src/mpi/coll/alltoallw/alltoallw_allcomm_nb.c b/src/mpi/coll/alltoallw/alltoallw_allcomm_nb.c index e3e55da8a89..ca12d33031a 100644 --- a/src/mpi/coll/alltoallw/alltoallw_allcomm_nb.c +++ b/src/mpi/coll/alltoallw/alltoallw_allcomm_nb.c @@ -8,7 +8,7 @@ int MPIR_Alltoallw_allcomm_nb(const void *sendbuf, const MPI_Aint sendcounts[], const MPI_Aint sdispls[], const MPI_Datatype sendtypes[], void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint rdispls[], - const MPI_Datatype recvtypes[], MPIR_Comm * comm_ptr, + const MPI_Datatype recvtypes[], MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -17,7 +17,7 @@ int MPIR_Alltoallw_allcomm_nb(const void *sendbuf, const MPI_Aint sendcounts[], /* just call the nonblocking version and wait on it */ mpi_errno = MPIR_Ialltoallw(sendbuf, sendcounts, sdispls, sendtypes, recvbuf, recvcounts, rdispls, - recvtypes, comm_ptr, &req_ptr); + recvtypes, comm_ptr, coll_group, &req_ptr); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIC_Wait(req_ptr); diff --git a/src/mpi/coll/alltoallw/alltoallw_inter_pairwise_exchange.c b/src/mpi/coll/alltoallw/alltoallw_inter_pairwise_exchange.c index c1918c7be0a..f7d6d6ba967 100644 --- a/src/mpi/coll/alltoallw/alltoallw_inter_pairwise_exchange.c +++ b/src/mpi/coll/alltoallw/alltoallw_inter_pairwise_exchange.c @@ -23,7 +23,8 @@ int MPIR_Alltoallw_inter_pairwise_exchange(const void *sendbuf, const MPI_Aint s const MPI_Aint sdispls[], const MPI_Datatype sendtypes[], void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint rdispls[], const MPI_Datatype recvtypes[], - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int local_size, remote_size, max_size, i; int mpi_errno = MPI_SUCCESS; @@ -66,7 +67,7 @@ int MPIR_Alltoallw_inter_pairwise_exchange(const void *sendbuf, const MPI_Aint s mpi_errno = MPIC_Sendrecv(sendaddr, sendcount, sendtype, dst, MPIR_ALLTOALLW_TAG, recvaddr, recvcount, recvtype, src, - MPIR_ALLTOALLW_TAG, comm_ptr, &status, errflag); + MPIR_ALLTOALLW_TAG, comm_ptr, coll_group, &status, errflag); MPIR_ERR_CHECK(mpi_errno); } fn_exit: diff --git a/src/mpi/coll/alltoallw/alltoallw_intra_pairwise_sendrecv_replace.c b/src/mpi/coll/alltoallw/alltoallw_intra_pairwise_sendrecv_replace.c index d0ce2ccc10c..2d3423f2aa4 100644 --- a/src/mpi/coll/alltoallw/alltoallw_intra_pairwise_sendrecv_replace.c +++ b/src/mpi/coll/alltoallw/alltoallw_intra_pairwise_sendrecv_replace.c @@ -23,15 +23,15 @@ int MPIR_Alltoallw_intra_pairwise_sendrecv_replace(const void *sendbuf, const MP const MPI_Aint recvcounts[], const MPI_Aint rdispls[], const MPI_Datatype recvtypes[], - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int comm_size, i, j; int mpi_errno = MPI_SUCCESS; MPI_Status status; int rank; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); #ifdef HAVE_ERROR_CHECKING MPIR_Assert(sendbuf == MPI_IN_PLACE); @@ -55,7 +55,7 @@ int MPIR_Alltoallw_intra_pairwise_sendrecv_replace(const void *sendbuf, const MP recvcounts[j], recvtypes[j], j, MPIR_ALLTOALLW_TAG, j, MPIR_ALLTOALLW_TAG, - comm_ptr, &status, errflag); + comm_ptr, coll_group, &status, errflag); MPIR_ERR_CHECK(mpi_errno); } else if (rank == j) { /* same as above with i/j args reversed */ @@ -63,7 +63,7 @@ int MPIR_Alltoallw_intra_pairwise_sendrecv_replace(const void *sendbuf, const MP recvcounts[i], recvtypes[i], i, MPIR_ALLTOALLW_TAG, i, MPIR_ALLTOALLW_TAG, - comm_ptr, &status, errflag); + comm_ptr, coll_group, &status, errflag); MPIR_ERR_CHECK(mpi_errno); } } diff --git a/src/mpi/coll/alltoallw/alltoallw_intra_scattered.c b/src/mpi/coll/alltoallw/alltoallw_intra_scattered.c index f0063d4ad91..936fcf9e040 100644 --- a/src/mpi/coll/alltoallw/alltoallw_intra_scattered.c +++ b/src/mpi/coll/alltoallw/alltoallw_intra_scattered.c @@ -23,7 +23,7 @@ int MPIR_Alltoallw_intra_scattered(const void *sendbuf, const MPI_Aint sendcount const MPI_Aint sdispls[], const MPI_Datatype sendtypes[], void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint rdispls[], const MPI_Datatype recvtypes[], - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int comm_size, i; int mpi_errno = MPI_SUCCESS; @@ -35,7 +35,7 @@ int MPIR_Alltoallw_intra_scattered(const void *sendbuf, const MPI_Aint sendcount MPI_Aint type_size; MPIR_CHKLMEM_DECL(2); - MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size); + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); #ifdef HAVE_ERROR_CHECKING /* When MPI_IN_PLACE, we use pair-wise sendrecv_replace in order to conserve memory usage, @@ -68,7 +68,7 @@ int MPIR_Alltoallw_intra_scattered(const void *sendbuf, const MPI_Aint sendcount if (type_size) { mpi_errno = MPIC_Irecv((char *) recvbuf + rdispls[dst], recvcounts[dst], recvtypes[dst], dst, - MPIR_ALLTOALLW_TAG, comm_ptr, + MPIR_ALLTOALLW_TAG, comm_ptr, coll_group, &reqarray[outstanding_requests]); MPIR_ERR_CHECK(mpi_errno); @@ -84,7 +84,7 @@ int MPIR_Alltoallw_intra_scattered(const void *sendbuf, const MPI_Aint sendcount if (type_size) { mpi_errno = MPIC_Isend((char *) sendbuf + sdispls[dst], sendcounts[dst], sendtypes[dst], dst, - MPIR_ALLTOALLW_TAG, comm_ptr, + MPIR_ALLTOALLW_TAG, comm_ptr, coll_group, &reqarray[outstanding_requests], errflag); MPIR_ERR_CHECK(mpi_errno); diff --git a/src/mpi/coll/barrier/barrier_allcomm_nb.c b/src/mpi/coll/barrier/barrier_allcomm_nb.c index 72a579949fd..7c1f26956a6 100644 --- a/src/mpi/coll/barrier/barrier_allcomm_nb.c +++ b/src/mpi/coll/barrier/barrier_allcomm_nb.c @@ -5,13 +5,13 @@ #include "mpiimpl.h" -int MPIR_Barrier_allcomm_nb(MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) +int MPIR_Barrier_allcomm_nb(MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_Request *req_ptr = NULL; /* just call the nonblocking version and wait on it */ - mpi_errno = MPIR_Ibarrier(comm_ptr, &req_ptr); + mpi_errno = MPIR_Ibarrier(comm_ptr, coll_group, &req_ptr); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIC_Wait(req_ptr); diff --git a/src/mpi/coll/barrier/barrier_inter_bcast.c b/src/mpi/coll/barrier/barrier_inter_bcast.c index e1d81c23443..3d775eb1272 100644 --- a/src/mpi/coll/barrier/barrier_inter_bcast.c +++ b/src/mpi/coll/barrier/barrier_inter_bcast.c @@ -17,7 +17,7 @@ * group. */ -int MPIR_Barrier_inter_bcast(MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) +int MPIR_Barrier_inter_bcast(MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int rank, mpi_errno = MPI_SUCCESS, root; int i = 0; @@ -34,28 +34,28 @@ int MPIR_Barrier_inter_bcast(MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) newcomm_ptr = comm_ptr->local_comm; /* do a barrier on the local intracommunicator */ - mpi_errno = MPIR_Barrier(newcomm_ptr, errflag); + mpi_errno = MPIR_Barrier(newcomm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); if (comm_ptr->is_low_group) { /* bcast to right */ root = (rank == 0) ? MPI_ROOT : MPI_PROC_NULL; - mpi_errno = MPIR_Bcast(&i, 1, MPI_BYTE, root, comm_ptr, errflag); + mpi_errno = MPIR_Bcast(&i, 1, MPI_BYTE, root, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); /* receive bcast from right */ root = 0; - mpi_errno = MPIR_Bcast(&i, 1, MPI_BYTE, root, comm_ptr, errflag); + mpi_errno = MPIR_Bcast(&i, 1, MPI_BYTE, root, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } else { /* receive bcast from left */ root = 0; - mpi_errno = MPIR_Bcast(&i, 1, MPI_BYTE, root, comm_ptr, errflag); + mpi_errno = MPIR_Bcast(&i, 1, MPI_BYTE, root, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); /* bcast to left */ root = (rank == 0) ? MPI_ROOT : MPI_PROC_NULL; - mpi_errno = MPIR_Bcast(&i, 1, MPI_BYTE, root, comm_ptr, errflag); + mpi_errno = MPIR_Bcast(&i, 1, MPI_BYTE, root, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } fn_exit: diff --git a/src/mpi/coll/barrier/barrier_intra_k_dissemination.c b/src/mpi/coll/barrier/barrier_intra_k_dissemination.c index 927c62843c4..13c54dccc3d 100644 --- a/src/mpi/coll/barrier/barrier_intra_k_dissemination.c +++ b/src/mpi/coll/barrier/barrier_intra_k_dissemination.c @@ -16,11 +16,11 @@ * process i sends to process (i + 2^k) % p and receives from process * (i - 2^k + p) % p. */ -int MPIR_Barrier_intra_dissemination(MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) +int MPIR_Barrier_intra_dissemination(MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int size, rank, src, dst, mask, mpi_errno = MPI_SUCCESS; - MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, size); + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, size); mask = 0x1; while (mask < size) { @@ -28,7 +28,8 @@ int MPIR_Barrier_intra_dissemination(MPIR_Comm * comm_ptr, MPIR_Errflag_t errfla src = (rank - mask + size) % size; mpi_errno = MPIC_Sendrecv(NULL, 0, MPI_BYTE, dst, MPIR_BARRIER_TAG, NULL, 0, MPI_BYTE, - src, MPIR_BARRIER_TAG, comm_ptr, MPI_STATUS_IGNORE, errflag); + src, MPIR_BARRIER_TAG, comm_ptr, coll_group, MPI_STATUS_IGNORE, + errflag); MPIR_ERR_CHECK(mpi_errno); mask <<= 1; } @@ -42,18 +43,18 @@ int MPIR_Barrier_intra_dissemination(MPIR_Comm * comm_ptr, MPIR_Errflag_t errfla /* Algorithm: high radix dissemination * Similar to dissemination algorithm, but generalized with high radix k */ -int MPIR_Barrier_intra_k_dissemination(MPIR_Comm * comm, int k, MPIR_Errflag_t errflag) +int MPIR_Barrier_intra_k_dissemination(MPIR_Comm * comm, int coll_group, int k, + MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; int i, j, nranks, rank; int p_of_k; /* minimum power of k that is greater than or equal to number of ranks */ int shift, to, from; int nphases = 0; - MPIR_Request *sreqs[MAX_RADIX], *rreqs[MAX_RADIX * 2]; + MPIR_Request *static_sreqs[MAX_RADIX], *static_rreqs[MAX_RADIX * 2]; MPIR_Request **send_reqs = NULL, **recv_reqs = NULL; - nranks = MPIR_Comm_size(comm); - rank = MPIR_Comm_rank(comm); + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, nranks); if (nranks == 1) goto fn_exit; @@ -62,7 +63,7 @@ int MPIR_Barrier_intra_k_dissemination(MPIR_Comm * comm, int k, MPIR_Errflag_t e k = nranks; if (k == 2) { - return MPIR_Barrier_intra_dissemination(comm, errflag); + return MPIR_Barrier_intra_dissemination(comm, coll_group, errflag); } /* If k value is greater than the maximum radix defined by MAX_RADIX macro, @@ -75,8 +76,8 @@ int MPIR_Barrier_intra_k_dissemination(MPIR_Comm * comm, int k, MPIR_Errflag_t e send_reqs = (MPIR_Request **) MPL_malloc((k - 1) * sizeof(MPIR_Request *), MPL_MEM_BUFFER); MPIR_ERR_CHKANDJUMP(!send_reqs, mpi_errno, MPI_ERR_OTHER, "**nomem"); } else { - send_reqs = sreqs; - recv_reqs = rreqs; + send_reqs = static_sreqs; + recv_reqs = static_rreqs; } p_of_k = 1; @@ -85,6 +86,8 @@ int MPIR_Barrier_intra_k_dissemination(MPIR_Comm * comm, int k, MPIR_Errflag_t e nphases++; } + MPIR_Request **rreqs = recv_reqs; + MPIR_Request **prev_rreqs = recv_reqs + (k - 1); shift = 1; for (i = 0; i < nphases; i++) { for (j = 1; j < k; j++) { @@ -96,29 +99,29 @@ int MPIR_Barrier_intra_k_dissemination(MPIR_Comm * comm, int k, MPIR_Errflag_t e MPIR_Assert(to >= 0 && to < nranks); /* recv from (k-1) nbrs */ - mpi_errno = - MPIC_Irecv(NULL, 0, MPI_BYTE, from, MPIR_BARRIER_TAG, comm, - &recv_reqs[(j - 1) + ((k - 1) * (i & 1))]); + mpi_errno = MPIC_Irecv(NULL, 0, MPI_BYTE, from, MPIR_BARRIER_TAG, comm, coll_group, + &rreqs[j - 1]); MPIR_ERR_CHECK(mpi_errno); /* wait on recvs from prev phase */ if (i > 0 && j == 1) { - mpi_errno = - MPIC_Waitall(k - 1, &recv_reqs[((k - 1) * ((i - 1) & 1))], MPI_STATUSES_IGNORE); + mpi_errno = MPIC_Waitall(k - 1, prev_rreqs, MPI_STATUSES_IGNORE); MPIR_ERR_CHECK(mpi_errno); } - mpi_errno = - MPIC_Isend(NULL, 0, MPI_BYTE, to, MPIR_BARRIER_TAG, comm, &send_reqs[j - 1], - errflag); + mpi_errno = MPIC_Isend(NULL, 0, MPI_BYTE, to, MPIR_BARRIER_TAG, comm, coll_group, + &send_reqs[j - 1], errflag); MPIR_ERR_CHECK(mpi_errno); } mpi_errno = MPIC_Waitall(k - 1, send_reqs, MPI_STATUSES_IGNORE); MPIR_ERR_CHECK(mpi_errno); shift *= k; + + MPIR_Request **tmp = rreqs; + rreqs = prev_rreqs; + prev_rreqs = tmp; } - mpi_errno = - MPIC_Waitall(k - 1, recv_reqs + ((k - 1) * ((nphases - 1) & 1)), MPI_STATUSES_IGNORE); + mpi_errno = MPIC_Waitall(k - 1, prev_rreqs, MPI_STATUSES_IGNORE); MPIR_ERR_CHECK(mpi_errno); fn_exit: diff --git a/src/mpi/coll/barrier/barrier_intra_recexch.c b/src/mpi/coll/barrier/barrier_intra_recexch.c index a46a6e25d8e..72e0ef918e3 100644 --- a/src/mpi/coll/barrier/barrier_intra_recexch.c +++ b/src/mpi/coll/barrier/barrier_intra_recexch.c @@ -8,13 +8,13 @@ /* Algorithm: call Allreduce's recursive exchange algorithm */ -int MPIR_Barrier_intra_recexch(MPIR_Comm * comm, int k, int single_phase_recv, +int MPIR_Barrier_intra_recexch(MPIR_Comm * comm, int coll_group, int k, int single_phase_recv, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Allreduce_intra_recexch(MPI_IN_PLACE, NULL, 0, - MPI_BYTE, MPI_SUM, comm, + MPI_BYTE, MPI_SUM, comm, coll_group, k, single_phase_recv, errflag); MPIR_ERR_CHECK(mpi_errno); diff --git a/src/mpi/coll/barrier/barrier_intra_smp.c b/src/mpi/coll/barrier/barrier_intra_smp.c index f723be96165..cef13c91059 100644 --- a/src/mpi/coll/barrier/barrier_intra_smp.c +++ b/src/mpi/coll/barrier/barrier_intra_smp.c @@ -5,30 +5,32 @@ #include "mpiimpl.h" -int MPIR_Barrier_intra_smp(MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) +int MPIR_Barrier_intra_smp(MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; - MPIR_Assert(MPIR_Comm_is_parent_comm(comm_ptr)); + MPIR_Assert(MPIR_Comm_is_parent_comm(comm_ptr, coll_group)); + int local_rank = comm_ptr->subgroups[MPIR_SUBGROUP_NODE].rank; + int local_size = comm_ptr->subgroups[MPIR_SUBGROUP_NODE].size; /* do the intranode barrier on all nodes */ - if (comm_ptr->node_comm != NULL) { - mpi_errno = MPIR_Barrier(comm_ptr->node_comm, errflag); + if (local_size > 1) { + mpi_errno = MPIR_Barrier(comm_ptr, MPIR_SUBGROUP_NODE, errflag); MPIR_ERR_CHECK(mpi_errno); } /* do the barrier across roots of all nodes */ - if (comm_ptr->node_roots_comm != NULL) { - mpi_errno = MPIR_Barrier(comm_ptr->node_roots_comm, errflag); + if (local_rank == 0) { + mpi_errno = MPIR_Barrier(comm_ptr, MPIR_SUBGROUP_NODE_CROSS, errflag); MPIR_ERR_CHECK(mpi_errno); } /* release the local processes on each node with a 1-byte * broadcast (0-byte broadcast just returns without doing * anything) */ - if (comm_ptr->node_comm != NULL) { + if (local_size > 1) { int i = 0; - mpi_errno = MPIR_Bcast(&i, 1, MPI_BYTE, 0, comm_ptr->node_comm, errflag); + mpi_errno = MPIR_Bcast(&i, 1, MPI_BYTE, 0, comm_ptr, MPIR_SUBGROUP_NODE, errflag); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/bcast/bcast.h b/src/mpi/coll/bcast/bcast.h index 23d15fc1325..5e2d5b3194e 100644 --- a/src/mpi/coll/bcast/bcast.h +++ b/src/mpi/coll/bcast/bcast.h @@ -9,7 +9,7 @@ #include "mpiimpl.h" int MPII_Scatter_for_bcast(void *buffer, MPI_Aint count, MPI_Datatype datatype, - int root, MPIR_Comm * comm_ptr, MPI_Aint nbytes, void *tmp_buf, - int is_contig, MPIR_Errflag_t errflag); + int root, MPIR_Comm * comm_ptr, int coll_group, MPI_Aint nbytes, + void *tmp_buf, int is_contig, MPIR_Errflag_t errflag); #endif /* BCAST_H_INCLUDED */ diff --git a/src/mpi/coll/bcast/bcast_allcomm_nb.c b/src/mpi/coll/bcast/bcast_allcomm_nb.c index 99c615ca658..5f2f19ff1f4 100644 --- a/src/mpi/coll/bcast/bcast_allcomm_nb.c +++ b/src/mpi/coll/bcast/bcast_allcomm_nb.c @@ -6,13 +6,13 @@ #include "mpiimpl.h" int MPIR_Bcast_allcomm_nb(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_Request *req_ptr = NULL; /* just call the nonblocking version and wait on it */ - mpi_errno = MPIR_Ibcast(buffer, count, datatype, root, comm_ptr, &req_ptr); + mpi_errno = MPIR_Ibcast(buffer, count, datatype, root, comm_ptr, coll_group, &req_ptr); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIC_Wait(req_ptr); diff --git a/src/mpi/coll/bcast/bcast_inter_remote_send_local_bcast.c b/src/mpi/coll/bcast/bcast_inter_remote_send_local_bcast.c index b22a916ff32..266c48eb02a 100644 --- a/src/mpi/coll/bcast/bcast_inter_remote_send_local_bcast.c +++ b/src/mpi/coll/bcast/bcast_inter_remote_send_local_bcast.c @@ -14,21 +14,23 @@ int MPIR_Bcast_inter_remote_send_local_bcast(void *buffer, MPI_Aint count, MPI_Datatype datatype, - int root, MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + int root, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int rank, mpi_errno; MPI_Status status; MPIR_Comm *newcomm_ptr = NULL; MPIR_FUNC_ENTER; - + MPIR_Assert(coll_group == MPIR_SUBGROUP_NONE); if (root == MPI_PROC_NULL) { /* local processes other than root do nothing */ mpi_errno = MPI_SUCCESS; } else if (root == MPI_ROOT) { /* root sends to rank 0 on remote group and returns */ - mpi_errno = MPIC_Send(buffer, count, datatype, 0, MPIR_BCAST_TAG, comm_ptr, errflag); + mpi_errno = + MPIC_Send(buffer, count, datatype, 0, MPIR_BCAST_TAG, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } else { /* remote group. rank 0 on remote group receives from root */ @@ -36,7 +38,8 @@ int MPIR_Bcast_inter_remote_send_local_bcast(void *buffer, rank = comm_ptr->rank; if (rank == 0) { - mpi_errno = MPIC_Recv(buffer, count, datatype, root, MPIR_BCAST_TAG, comm_ptr, &status); + mpi_errno = MPIC_Recv(buffer, count, datatype, root, MPIR_BCAST_TAG, + comm_ptr, coll_group, &status); MPIR_ERR_CHECK(mpi_errno); } @@ -50,7 +53,8 @@ int MPIR_Bcast_inter_remote_send_local_bcast(void *buffer, /* now do the usual broadcast on this intracommunicator * with rank 0 as root. */ - mpi_errno = MPIR_Bcast_allcomm_auto(buffer, count, datatype, 0, newcomm_ptr, errflag); + mpi_errno = MPIR_Bcast_allcomm_auto(buffer, count, datatype, 0, newcomm_ptr, + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/bcast/bcast_intra_binomial.c b/src/mpi/coll/bcast/bcast_intra_binomial.c index 77d3bd73d24..e154ce259c6 100644 --- a/src/mpi/coll/bcast/bcast_intra_binomial.c +++ b/src/mpi/coll/bcast/bcast_intra_binomial.c @@ -13,7 +13,8 @@ int MPIR_Bcast_intra_binomial(void *buffer, MPI_Aint count, MPI_Datatype datatype, - int root, MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + int root, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int rank, comm_size, src, dst; int relative_rank, mask; @@ -32,7 +33,7 @@ int MPIR_Bcast_intra_binomial(void *buffer, void *tmp_buf = NULL; MPIR_CHKLMEM_DECL(1); - MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size); + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); if (HANDLE_IS_BUILTIN(datatype)) is_contig = 1; @@ -91,10 +92,10 @@ int MPIR_Bcast_intra_binomial(void *buffer, src += comm_size; if (!is_contig) mpi_errno = MPIC_Recv(tmp_buf, nbytes, MPI_BYTE, src, - MPIR_BCAST_TAG, comm_ptr, status_p); + MPIR_BCAST_TAG, comm_ptr, coll_group, status_p); else mpi_errno = MPIC_Recv(buffer, count, datatype, src, - MPIR_BCAST_TAG, comm_ptr, status_p); + MPIR_BCAST_TAG, comm_ptr, coll_group, status_p); MPIR_ERR_CHECK(mpi_errno); #ifdef HAVE_ERROR_CHECKING /* check that we received as much as we expected */ @@ -128,10 +129,10 @@ int MPIR_Bcast_intra_binomial(void *buffer, dst -= comm_size; if (!is_contig) mpi_errno = MPIC_Send(tmp_buf, nbytes, MPI_BYTE, dst, - MPIR_BCAST_TAG, comm_ptr, errflag); + MPIR_BCAST_TAG, comm_ptr, coll_group, errflag); else mpi_errno = MPIC_Send(buffer, count, datatype, dst, - MPIR_BCAST_TAG, comm_ptr, errflag); + MPIR_BCAST_TAG, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } mask >>= 1; diff --git a/src/mpi/coll/bcast/bcast_intra_pipelined_tree.c b/src/mpi/coll/bcast/bcast_intra_pipelined_tree.c index f42d0515053..f551de11ded 100644 --- a/src/mpi/coll/bcast/bcast_intra_pipelined_tree.c +++ b/src/mpi/coll/bcast/bcast_intra_pipelined_tree.c @@ -14,7 +14,7 @@ int MPIR_Bcast_intra_pipelined_tree(void *buffer, MPI_Aint count, MPI_Datatype datatype, - int root, MPIR_Comm * comm_ptr, int tree_type, + int root, MPIR_Comm * comm_ptr, int coll_group, int tree_type, int branching_factor, int is_nb, int chunk_size, int recv_pre_posted, MPIR_Errflag_t errflag) { @@ -31,8 +31,7 @@ int MPIR_Bcast_intra_pipelined_tree(void *buffer, MPIR_Treealgo_tree_t my_tree; MPIR_CHKLMEM_DECL(3); - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); /* If there is only one process, return */ if (comm_size == 1) @@ -75,11 +74,12 @@ int MPIR_Bcast_intra_pipelined_tree(void *buffer, if (tree_type == MPIR_TREE_TYPE_TOPOLOGY_AWARE || tree_type == MPIR_TREE_TYPE_TOPOLOGY_AWARE_K) { mpi_errno = - MPIR_Treealgo_tree_create_topo_aware(comm_ptr, tree_type, branching_factor, root, + MPIR_Treealgo_tree_create_topo_aware(comm_ptr, coll_group, tree_type, + branching_factor, root, MPIR_CVAR_BCAST_TOPO_REORDER_ENABLE, &my_tree); } else if (tree_type == MPIR_TREE_TYPE_TOPOLOGY_WAVE) { mpi_errno = - MPIR_Treealgo_tree_create_topo_wave(comm_ptr, branching_factor, root, + MPIR_Treealgo_tree_create_topo_wave(comm_ptr, coll_group, branching_factor, root, MPIR_CVAR_BCAST_TOPO_REORDER_ENABLE, MPIR_CVAR_BCAST_TOPO_OVERHEAD, MPIR_CVAR_BCAST_TOPO_DIFF_GROUPS, @@ -118,7 +118,7 @@ int MPIR_Bcast_intra_pipelined_tree(void *buffer, if (src != -1) { /* post receive from parent */ mpi_errno = MPIC_Irecv((char *) sendbuf + offset, msgsize, MPI_BYTE, - src, MPIR_BCAST_TAG, comm_ptr, &reqs[num_req++]); + src, MPIR_BCAST_TAG, comm_ptr, coll_group, &reqs[num_req++]); MPIR_ERR_CHECK(mpi_errno); } offset += msgsize; @@ -131,7 +131,7 @@ int MPIR_Bcast_intra_pipelined_tree(void *buffer, if (src != -1) { mpi_errno = MPIC_Irecv((char *) sendbuf + offset, msgsize, MPI_BYTE, - src, MPIR_BCAST_TAG, comm_ptr, &reqs[num_req++]); + src, MPIR_BCAST_TAG, comm_ptr, coll_group, &reqs[num_req++]); MPIR_ERR_CHECK(mpi_errno); } offset += msgsize; @@ -170,7 +170,7 @@ int MPIR_Bcast_intra_pipelined_tree(void *buffer, if (src != -1) { mpi_errno = MPIC_Recv((char *) sendbuf + offset, msgsize, MPI_BYTE, - src, MPIR_BCAST_TAG, comm_ptr, &status); + src, MPIR_BCAST_TAG, comm_ptr, coll_group, &status); MPIR_ERR_CHECK(mpi_errno); MPIR_Get_count_impl(&status, MPI_BYTE, &recvd_size); MPIR_ERR_CHKANDJUMP2(recvd_size != nbytes, mpi_errno, MPI_ERR_OTHER, @@ -191,11 +191,11 @@ int MPIR_Bcast_intra_pipelined_tree(void *buffer, if (!is_nb) { mpi_errno = MPIC_Send((char *) sendbuf + offset, msgsize, MPI_BYTE, dst, - MPIR_BCAST_TAG, comm_ptr, errflag); + MPIR_BCAST_TAG, comm_ptr, coll_group, errflag); } else { mpi_errno = MPIC_Isend((char *) sendbuf + offset, msgsize, MPI_BYTE, dst, - MPIR_BCAST_TAG, comm_ptr, &reqs[num_req++], errflag); + MPIR_BCAST_TAG, comm_ptr, coll_group, &reqs[num_req++], errflag); } MPIR_ERR_CHECK(mpi_errno); @@ -207,11 +207,11 @@ int MPIR_Bcast_intra_pipelined_tree(void *buffer, dst = *p; if (!is_nb) { mpi_errno = MPIC_Send((char *) sendbuf + offset, msgsize, MPI_BYTE, dst, - MPIR_BCAST_TAG, comm_ptr, errflag); + MPIR_BCAST_TAG, comm_ptr, coll_group, errflag); } else { mpi_errno = MPIC_Isend((char *) sendbuf + offset, msgsize, MPI_BYTE, dst, - MPIR_BCAST_TAG, comm_ptr, &reqs[num_req++], errflag); + MPIR_BCAST_TAG, comm_ptr, coll_group, &reqs[num_req++], errflag); } MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/bcast/bcast_intra_scatter_recursive_doubling_allgather.c b/src/mpi/coll/bcast/bcast_intra_scatter_recursive_doubling_allgather.c index 95f23dd51e5..4867712378d 100644 --- a/src/mpi/coll/bcast/bcast_intra_scatter_recursive_doubling_allgather.c +++ b/src/mpi/coll/bcast/bcast_intra_scatter_recursive_doubling_allgather.c @@ -29,7 +29,7 @@ int MPIR_Bcast_intra_scatter_recursive_doubling_allgather(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, - MPIR_Comm * comm_ptr, + MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { MPI_Status status; @@ -45,8 +45,7 @@ int MPIR_Bcast_intra_scatter_recursive_doubling_allgather(void *buffer, MPI_Aint true_extent, true_lb; void *tmp_buf; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); relative_rank = (rank >= root) ? rank - root : rank - root + comm_size; if (HANDLE_IS_BUILTIN(datatype)) @@ -78,7 +77,7 @@ int MPIR_Bcast_intra_scatter_recursive_doubling_allgather(void *buffer, MPI_Aint scatter_size; scatter_size = (nbytes + comm_size - 1) / comm_size; /* ceiling division */ - mpi_errno = MPII_Scatter_for_bcast(buffer, count, datatype, root, comm_ptr, + mpi_errno = MPII_Scatter_for_bcast(buffer, count, datatype, root, comm_ptr, coll_group, nbytes, tmp_buf, is_contig, errflag); MPIR_ERR_CHECK(mpi_errno); @@ -119,12 +118,10 @@ int MPIR_Bcast_intra_scatter_recursive_doubling_allgather(void *buffer, curr_size, MPI_BYTE, dst, MPIR_BCAST_TAG, ((char *) tmp_buf + recv_offset), (nbytes - recv_offset < 0 ? 0 : nbytes - recv_offset), - MPI_BYTE, dst, MPIR_BCAST_TAG, comm_ptr, &status, errflag); + MPI_BYTE, dst, MPIR_BCAST_TAG, comm_ptr, coll_group, &status, + errflag); MPIR_ERR_CHECK(mpi_errno); - if (mpi_errno) { - recv_size = 0; - } else - MPIR_Get_count_impl(&status, MPI_BYTE, &recv_size); + MPIR_Get_count_impl(&status, MPI_BYTE, &recv_size); curr_size += recv_size; } @@ -184,7 +181,7 @@ int MPIR_Bcast_intra_scatter_recursive_doubling_allgather(void *buffer, * fflush(stdout); */ mpi_errno = MPIC_Send(((char *) tmp_buf + offset), recv_size, MPI_BYTE, dst, - MPIR_BCAST_TAG, comm_ptr, errflag); + MPIR_BCAST_TAG, comm_ptr, coll_group, errflag); /* recv_size was set in the previous * receive. that's the amount of data to be * sent now. */ @@ -199,7 +196,8 @@ int MPIR_Bcast_intra_scatter_recursive_doubling_allgather(void *buffer, * relative_rank, dst); */ mpi_errno = MPIC_Recv(((char *) tmp_buf + offset), nbytes - offset < 0 ? 0 : nbytes - offset, - MPI_BYTE, dst, MPIR_BCAST_TAG, comm_ptr, &status); + MPI_BYTE, dst, MPIR_BCAST_TAG, comm_ptr, coll_group, + &status); /* nprocs_completed is also equal to the no. of processes * whose data we don't have */ MPIR_ERR_CHECK(mpi_errno); diff --git a/src/mpi/coll/bcast/bcast_intra_scatter_ring_allgather.c b/src/mpi/coll/bcast/bcast_intra_scatter_ring_allgather.c index 41e17b99e3c..4c4641c347a 100644 --- a/src/mpi/coll/bcast/bcast_intra_scatter_ring_allgather.c +++ b/src/mpi/coll/bcast/bcast_intra_scatter_ring_allgather.c @@ -24,7 +24,8 @@ int MPIR_Bcast_intra_scatter_ring_allgather(void *buffer, MPI_Aint count, MPI_Datatype datatype, - int root, MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + int root, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int rank, comm_size; int mpi_errno = MPI_SUCCESS; @@ -38,8 +39,7 @@ int MPIR_Bcast_intra_scatter_ring_allgather(void *buffer, MPI_Aint true_extent, true_lb; MPIR_CHKLMEM_DECL(1); - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); if (HANDLE_IS_BUILTIN(datatype)) is_contig = 1; @@ -69,7 +69,7 @@ int MPIR_Bcast_intra_scatter_ring_allgather(void *buffer, scatter_size = (nbytes + comm_size - 1) / comm_size; /* ceiling division */ - mpi_errno = MPII_Scatter_for_bcast(buffer, count, datatype, root, comm_ptr, + mpi_errno = MPII_Scatter_for_bcast(buffer, count, datatype, root, comm_ptr, coll_group, nbytes, tmp_buf, is_contig, errflag); MPIR_ERR_CHECK(mpi_errno); @@ -103,7 +103,8 @@ int MPIR_Bcast_intra_scatter_ring_allgather(void *buffer, mpi_errno = MPIC_Sendrecv((char *) tmp_buf + right_disp, right_count, MPI_BYTE, right, MPIR_BCAST_TAG, (char *) tmp_buf + left_disp, left_count, - MPI_BYTE, left, MPIR_BCAST_TAG, comm_ptr, &status, errflag); + MPI_BYTE, left, MPIR_BCAST_TAG, comm_ptr, coll_group, &status, + errflag); MPIR_ERR_CHECK(mpi_errno); MPIR_Get_count_impl(&status, MPI_BYTE, &recvd_size); curr_size += recvd_size; diff --git a/src/mpi/coll/bcast/bcast_intra_smp.c b/src/mpi/coll/bcast/bcast_intra_smp.c index 9ff9e684e54..d875772bf56 100644 --- a/src/mpi/coll/bcast/bcast_intra_smp.c +++ b/src/mpi/coll/bcast/bcast_intra_smp.c @@ -5,28 +5,67 @@ #include "mpiimpl.h" +/* TODO: move this to commutil.c */ +static void MPIR_Comm_construct_internode_roots_group(MPIR_Comm * comm, int root, + int *group_p, int *root_rank_p) +{ + int inter_size = comm->num_external; + int inter_rank = comm->internode_table[comm->rank]; + + int inter_group, *proc_table; + MPIR_COMM_PUSH_SUBGROUP(comm, inter_size, inter_rank, inter_group, proc_table); + + for (int i = 0; i < inter_size; i++) { + proc_table[i] = -1; + } + for (int i = 0; i < comm->local_size; i++) { + int r = comm->internode_table[i]; + if (proc_table[r] == -1) { + proc_table[r] = i; + } + } + int inter_root_rank = comm->internode_table[root]; + proc_table[inter_root_rank] = root; + + comm->subgroups[inter_group].proc_table = proc_table; + + *group_p = inter_group; + *root_rank_p = inter_root_rank; +} + /* FIXME This function uses some heuristsics based off of some testing on a * cluster at Argonne. We need a better system for detrmining and controlling * the cutoff points for these algorithms. If I've done this right, you should * be able to make changes along these lines almost exclusively in this function * and some new functions. [goodell@ 2008/01/07] */ int MPIR_Bcast_intra_smp(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPI_Aint type_size, nbytes = 0; - MPI_Status *status_p; -#ifdef HAVE_ERROR_CHECKING - MPI_Status status; - status_p = &status; - MPI_Aint recvd_size; -#else - status_p = MPI_STATUS_IGNORE; -#endif #ifdef HAVE_ERROR_CHECKING - MPIR_Assert(MPIR_Comm_is_parent_comm(comm_ptr)); + MPIR_Assert(MPIR_Comm_is_parent_comm(comm_ptr, coll_group)); #endif + int comm_size = comm_ptr->local_size; + + int node_group = 0, node_roots_group = 0; + int local_rank, local_size, local_root, inter_root = -1; + + node_group = MPIR_SUBGROUP_NODE; +#define NODEGROUP(field) comm_ptr->subgroups[node_group].field + + local_rank = NODEGROUP(rank); + local_size = NODEGROUP(size); + local_root = MPIR_Get_intranode_rank(comm_ptr, root); + if (local_root < 0) { + /* non-root node use local rank 0 as local root */ + local_root = 0; + } + if (local_rank == local_root) { + MPIR_Comm_construct_internode_roots_group(comm_ptr, root, &node_roots_group, &inter_root); + MPIR_Assert(node_roots_group > 0); + } MPIR_Datatype_get_size_macro(datatype, type_size); @@ -34,92 +73,33 @@ int MPIR_Bcast_intra_smp(void *buffer, MPI_Aint count, MPI_Datatype datatype, in if (nbytes == 0) goto fn_exit; /* nothing to do */ - if ((nbytes < MPIR_CVAR_BCAST_SHORT_MSG_SIZE) || - (comm_ptr->local_size < MPIR_CVAR_BCAST_MIN_PROCS)) { - /* send to intranode-rank 0 on the root's node */ - if (comm_ptr->node_comm != NULL && MPIR_Get_intranode_rank(comm_ptr, root) > 0) { /* is not the node root (0) and is on our node (!-1) */ - if (root == comm_ptr->rank) { - mpi_errno = MPIC_Send(buffer, count, datatype, 0, - MPIR_BCAST_TAG, comm_ptr->node_comm, errflag); - MPIR_ERR_CHECK(mpi_errno); - } else if (0 == comm_ptr->node_comm->rank) { - mpi_errno = - MPIC_Recv(buffer, count, datatype, MPIR_Get_intranode_rank(comm_ptr, root), - MPIR_BCAST_TAG, comm_ptr->node_comm, status_p); - MPIR_ERR_CHECK(mpi_errno); -#ifdef HAVE_ERROR_CHECKING - /* check that we received as much as we expected */ - MPIR_Get_count_impl(status_p, MPI_BYTE, &recvd_size); - MPIR_ERR_CHKANDJUMP2(recvd_size != nbytes, mpi_errno, MPI_ERR_OTHER, - "**collective_size_mismatch", - "**collective_size_mismatch %d %d", - (int) recvd_size, (int) nbytes); -#endif - } - - } - - /* perform the internode broadcast */ - if (comm_ptr->node_roots_comm != NULL) { - mpi_errno = MPIR_Bcast(buffer, count, datatype, - MPIR_Get_internode_rank(comm_ptr, root), - comm_ptr->node_roots_comm, errflag); + if ((nbytes < MPIR_CVAR_BCAST_SHORT_MSG_SIZE) || (comm_size < MPIR_CVAR_BCAST_MIN_PROCS) || + (nbytes < MPIR_CVAR_BCAST_LONG_MSG_SIZE && MPL_is_pof2(comm_size))) { + /* local roots perform the internode broadcast */ + if (local_rank == local_root) { + mpi_errno = MPIR_Bcast(buffer, count, datatype, inter_root, + comm_ptr, node_roots_group, errflag); MPIR_ERR_CHECK(mpi_errno); } - /* perform the intranode broadcast on all except for the root's node */ - if (comm_ptr->node_comm != NULL) { - mpi_errno = MPIR_Bcast(buffer, count, datatype, 0, comm_ptr->node_comm, errflag); - MPIR_ERR_CHECK(mpi_errno); - } - } else { /* (nbytes > MPIR_CVAR_BCAST_SHORT_MSG_SIZE) && (comm_ptr->size >= MPIR_CVAR_BCAST_MIN_PROCS) */ - - /* supposedly... - * smp+doubling good for pof2 - * reg+ring better for non-pof2 */ - if (nbytes < MPIR_CVAR_BCAST_LONG_MSG_SIZE && MPL_is_pof2(comm_ptr->local_size)) { - /* medium-sized msg and pof2 np */ - - /* perform the intranode broadcast on the root's node */ - if (comm_ptr->node_comm != NULL && MPIR_Get_intranode_rank(comm_ptr, root) > 0) { /* is not the node root (0) and is on our node (!-1) */ - /* FIXME binomial may not be the best algorithm for on-node - * bcast. We need a more comprehensive system for selecting the - * right algorithms here. */ - mpi_errno = MPIR_Bcast(buffer, count, datatype, - MPIR_Get_intranode_rank(comm_ptr, root), - comm_ptr->node_comm, errflag); - MPIR_ERR_CHECK(mpi_errno); - } - - /* perform the internode broadcast */ - if (comm_ptr->node_roots_comm != NULL) { - mpi_errno = MPIR_Bcast(buffer, count, datatype, - MPIR_Get_internode_rank(comm_ptr, root), - comm_ptr->node_roots_comm, errflag); - MPIR_ERR_CHECK(mpi_errno); - } - - /* perform the intranode broadcast on all except for the root's node */ - if (comm_ptr->node_comm != NULL && MPIR_Get_intranode_rank(comm_ptr, root) <= 0) { /* 0 if root was local root too, -1 if different node than root */ - /* FIXME binomial may not be the best algorithm for on-node - * bcast. We need a more comprehensive system for selecting the - * right algorithms here. */ - mpi_errno = MPIR_Bcast(buffer, count, datatype, 0, comm_ptr->node_comm, errflag); - MPIR_ERR_CHECK(mpi_errno); - } - } else { /* large msg or non-pof2 */ - - /* FIXME It would be good to have an SMP-aware version of this - * algorithm that (at least approximately) minimized internode - * communication. */ - mpi_errno = - MPIR_Bcast_intra_scatter_ring_allgather(buffer, count, datatype, root, comm_ptr, - errflag); + /* perform the intranode broadcast */ + if (local_size > 1) { + mpi_errno = MPIR_Bcast(buffer, count, datatype, 0, comm_ptr, node_group, errflag); MPIR_ERR_CHECK(mpi_errno); } + } else { + /* FIXME It would be good to have an SMP-aware version of this + * algorithm that (at least approximately) minimized internode + * communication. */ + mpi_errno = MPIR_Bcast_intra_scatter_ring_allgather(buffer, count, datatype, root, comm_ptr, + coll_group, errflag); + MPIR_ERR_CHECK(mpi_errno); } fn_exit: + if (node_roots_group) { + MPIR_COMM_POP_SUBGROUP(comm_ptr); + } return mpi_errno; fn_fail: goto fn_exit; diff --git a/src/mpi/coll/bcast/bcast_intra_tree.c b/src/mpi/coll/bcast/bcast_intra_tree.c index 6cd2f02ba7e..5cc2261b33a 100644 --- a/src/mpi/coll/bcast/bcast_intra_tree.c +++ b/src/mpi/coll/bcast/bcast_intra_tree.c @@ -12,10 +12,10 @@ int MPIR_Bcast_intra_tree(void *buffer, MPI_Aint count, MPI_Datatype datatype, - int root, MPIR_Comm * comm_ptr, int tree_type, + int root, MPIR_Comm * comm_ptr, int coll_group, int tree_type, int branching_factor, int is_nb, MPIR_Errflag_t errflag) { - int rank, comm_size, src, dst, *p, j, k, lrank = -1, is_contig; + int rank, comm_size, src, dst, *p, j, k, is_contig; int parent = -1, num_children = 0, num_req = 0, is_root = 0; int mpi_errno = MPI_SUCCESS; MPI_Aint nbytes = 0, type_size, actual_packed_unpacked_bytes, recvd_size; @@ -29,8 +29,7 @@ int MPIR_Bcast_intra_tree(void *buffer, MPIR_Treealgo_tree_t my_tree; MPIR_CHKLMEM_DECL(3); - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); /* If there is only one process, return */ if (comm_size == 1) @@ -64,6 +63,7 @@ int MPIR_Bcast_intra_tree(void *buffer, dtype = MPI_BYTE; } + int lrank = 0; if (tree_type == MPIR_TREE_TYPE_KARY) { if (rank == root) is_root = 1; @@ -76,12 +76,14 @@ int MPIR_Bcast_intra_tree(void *buffer, if (tree_type == MPIR_TREE_TYPE_TOPOLOGY_AWARE || tree_type == MPIR_TREE_TYPE_TOPOLOGY_AWARE_K) { mpi_errno = - MPIR_Treealgo_tree_create_topo_aware(comm_ptr, tree_type, branching_factor, root, + MPIR_Treealgo_tree_create_topo_aware(comm_ptr, coll_group, tree_type, + branching_factor, root, MPIR_CVAR_BCAST_TOPO_REORDER_ENABLE, &my_tree); } else if (tree_type == MPIR_TREE_TYPE_TOPOLOGY_WAVE) { MPIR_Csel_coll_sig_s coll_sig = { .coll_type = MPIR_CSEL_COLL_TYPE__BCAST, .comm_ptr = comm_ptr, + .coll_group = coll_group, .u.bcast.buffer = buffer, .u.bcast.count = count, .u.bcast.datatype = datatype, @@ -104,7 +106,7 @@ int MPIR_Bcast_intra_tree(void *buffer, } mpi_errno = - MPIR_Treealgo_tree_create_topo_wave(comm_ptr, branching_factor, root, + MPIR_Treealgo_tree_create_topo_wave(comm_ptr, coll_group, branching_factor, root, MPIR_CVAR_BCAST_TOPO_REORDER_ENABLE, overhead, lat_diff_groups, lat_diff_switches, lat_same_switches, &my_tree); @@ -129,7 +131,8 @@ int MPIR_Bcast_intra_tree(void *buffer, if ((parent != -1 && tree_type != MPIR_TREE_TYPE_KARY) || (!is_root && tree_type == MPIR_TREE_TYPE_KARY)) { src = parent; - mpi_errno = MPIC_Recv(send_buf, count, dtype, src, MPIR_BCAST_TAG, comm_ptr, &status); + mpi_errno = MPIC_Recv(send_buf, count, dtype, src, MPIR_BCAST_TAG, + comm_ptr, coll_group, &status); MPIR_ERR_CHECK(mpi_errno); /* check that we received as much as we expected */ MPIR_Get_count_impl(&status, MPI_BYTE, &recvd_size); @@ -147,10 +150,12 @@ int MPIR_Bcast_intra_tree(void *buffer, if (!is_nb) { mpi_errno = - MPIC_Send(send_buf, count, dtype, dst, MPIR_BCAST_TAG, comm_ptr, errflag); + MPIC_Send(send_buf, count, dtype, dst, MPIR_BCAST_TAG, comm_ptr, coll_group, + errflag); } else { mpi_errno = MPIC_Isend(send_buf, count, dtype, dst, - MPIR_BCAST_TAG, comm_ptr, &reqs[num_req++], errflag); + MPIR_BCAST_TAG, comm_ptr, coll_group, &reqs[num_req++], + errflag); } MPIR_ERR_CHECK(mpi_errno); } @@ -161,10 +166,12 @@ int MPIR_Bcast_intra_tree(void *buffer, if (!is_nb) { mpi_errno = - MPIC_Send(send_buf, count, dtype, dst, MPIR_BCAST_TAG, comm_ptr, errflag); + MPIC_Send(send_buf, count, dtype, dst, MPIR_BCAST_TAG, comm_ptr, coll_group, + errflag); } else { mpi_errno = MPIC_Isend(send_buf, count, dtype, dst, - MPIR_BCAST_TAG, comm_ptr, &reqs[num_req++], errflag); + MPIR_BCAST_TAG, comm_ptr, coll_group, &reqs[num_req++], + errflag); } MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/bcast/bcast_utils.c b/src/mpi/coll/bcast/bcast_utils.c index 4d5900385da..8b614ae6df3 100644 --- a/src/mpi/coll/bcast/bcast_utils.c +++ b/src/mpi/coll/bcast/bcast_utils.c @@ -20,7 +20,7 @@ int MPII_Scatter_for_bcast(void *buffer ATTRIBUTE((unused)), MPI_Aint count ATTRIBUTE((unused)), MPI_Datatype datatype ATTRIBUTE((unused)), int root, - MPIR_Comm * comm_ptr, + MPIR_Comm * comm_ptr, int coll_group, MPI_Aint nbytes, void *tmp_buf, int is_contig, MPIR_Errflag_t errflag) { MPI_Status status; @@ -30,8 +30,8 @@ int MPII_Scatter_for_bcast(void *buffer ATTRIBUTE((unused)), MPI_Aint scatter_size, recv_size = 0; MPI_Aint curr_size, send_size; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); + relative_rank = (rank >= root) ? rank - root : rank - root + comm_size; /* use long message algorithm: binomial tree scatter followed by an allgather */ @@ -66,7 +66,8 @@ int MPII_Scatter_for_bcast(void *buffer ATTRIBUTE((unused)), } else { mpi_errno = MPIC_Recv(((char *) tmp_buf + relative_rank * scatter_size), - recv_size, MPI_BYTE, src, MPIR_BCAST_TAG, comm_ptr, &status); + recv_size, MPI_BYTE, src, MPIR_BCAST_TAG, comm_ptr, + coll_group, &status); MPIR_ERR_CHECK(mpi_errno); /* query actual size of data received */ MPIR_Get_count_impl(&status, MPI_BYTE, &curr_size); @@ -93,7 +94,8 @@ int MPII_Scatter_for_bcast(void *buffer ATTRIBUTE((unused)), dst -= comm_size; mpi_errno = MPIC_Send(((char *) tmp_buf + scatter_size * (relative_rank + mask)), - send_size, MPI_BYTE, dst, MPIR_BCAST_TAG, comm_ptr, errflag); + send_size, MPI_BYTE, dst, MPIR_BCAST_TAG, comm_ptr, + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); curr_size -= send_size; diff --git a/src/mpi/coll/coll_algorithms.txt b/src/mpi/coll/coll_algorithms.txt index 4251999ff7f..23e2a1d032b 100644 --- a/src/mpi/coll/coll_algorithms.txt +++ b/src/mpi/coll/coll_algorithms.txt @@ -174,10 +174,12 @@ allgather-intra: func_name: recexch extra_params: recexch_type=MPIR_ALLGATHER_RECEXCH_TYPE_DISTANCE_DOUBLING, k, single_phase_recv cvar_params: -, RECEXCH_KVAL, RECEXCH_SINGLE_PHASE_RECV + restrictions: nogroup recexch_halving func_name: recexch extra_params: recexch_type=MPIR_ALLGATHER_RECEXCH_TYPE_DISTANCE_HALVING, k, single_phase_recv cvar_params: -, RECEXCH_KVAL, RECEXCH_SINGLE_PHASE_RECV + restrictions: nogroup allgather-inter: local_gather_remote_bcast iallgather-intra: @@ -346,10 +348,11 @@ allreduce-intra: recexch extra_params: k, single_phase_recv cvar_params: RECEXCH_KVAL, RECEXCH_SINGLE_PHASE_RECV + restrictions: nogroup ring restrictions: commutative k_reduce_scatter_allgather - restrictions: commutative + restrictions: commutative, nogroup extra_params: k, single_phase_recv cvar_params: RECEXCH_KVAL, RECEXCH_SINGLE_PHASE_RECV allreduce-inter: diff --git a/src/mpi/coll/exscan/exscan_allcomm_nb.c b/src/mpi/coll/exscan/exscan_allcomm_nb.c index a1050eb428f..317745d2ad9 100644 --- a/src/mpi/coll/exscan/exscan_allcomm_nb.c +++ b/src/mpi/coll/exscan/exscan_allcomm_nb.c @@ -6,14 +6,14 @@ #include "mpiimpl.h" int MPIR_Exscan_allcomm_nb(const void *sendbuf, void *recvbuf, MPI_Aint count, - MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, + MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_Request *req_ptr = NULL; /* just call the nonblocking version and wait on it */ - mpi_errno = MPIR_Iexscan(sendbuf, recvbuf, count, datatype, op, comm_ptr, &req_ptr); + mpi_errno = MPIR_Iexscan(sendbuf, recvbuf, count, datatype, op, comm_ptr, coll_group, &req_ptr); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIC_Wait(req_ptr); diff --git a/src/mpi/coll/exscan/exscan_intra_recursive_doubling.c b/src/mpi/coll/exscan/exscan_intra_recursive_doubling.c index b21c91636cd..5f57f2a91b6 100644 --- a/src/mpi/coll/exscan/exscan_intra_recursive_doubling.c +++ b/src/mpi/coll/exscan/exscan_intra_recursive_doubling.c @@ -48,7 +48,8 @@ int MPIR_Exscan_intra_recursive_doubling(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPI_Op op, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { MPI_Status status; int rank, comm_size; @@ -58,7 +59,7 @@ int MPIR_Exscan_intra_recursive_doubling(const void *sendbuf, void *partial_scan, *tmp_buf; MPIR_CHKLMEM_DECL(2); - MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size); + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); is_commutative = MPIR_Op_is_commutative(op); @@ -92,7 +93,7 @@ int MPIR_Exscan_intra_recursive_doubling(const void *sendbuf, mpi_errno = MPIC_Sendrecv(partial_scan, count, datatype, dst, MPIR_EXSCAN_TAG, tmp_buf, count, datatype, dst, - MPIR_EXSCAN_TAG, comm_ptr, &status, errflag); + MPIR_EXSCAN_TAG, comm_ptr, coll_group, &status, errflag); MPIR_ERR_CHECK(mpi_errno); if (rank > dst) { diff --git a/src/mpi/coll/gather/gather_allcomm_nb.c b/src/mpi/coll/gather/gather_allcomm_nb.c index 91f4f71b68a..6a81234d6cb 100644 --- a/src/mpi/coll/gather/gather_allcomm_nb.c +++ b/src/mpi/coll/gather/gather_allcomm_nb.c @@ -7,7 +7,7 @@ int MPIR_Gather_allcomm_nb(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_Request *req_ptr = NULL; @@ -15,7 +15,7 @@ int MPIR_Gather_allcomm_nb(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype /* just call the nonblocking version and wait on it */ mpi_errno = MPIR_Igather(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm_ptr, - &req_ptr); + coll_group, &req_ptr); mpi_errno = MPIC_Wait(req_ptr); MPIR_ERR_CHECK(mpi_errno); diff --git a/src/mpi/coll/gather/gather_inter_linear.c b/src/mpi/coll/gather/gather_inter_linear.c index fbf29f904e5..2b6b76b6dbb 100644 --- a/src/mpi/coll/gather/gather_inter_linear.c +++ b/src/mpi/coll/gather/gather_inter_linear.c @@ -15,7 +15,7 @@ int MPIR_Gather_inter_linear(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int remote_size, mpi_errno = MPI_SUCCESS; int i; @@ -33,14 +33,13 @@ int MPIR_Gather_inter_linear(const void *sendbuf, MPI_Aint sendcount, MPI_Dataty MPIR_Datatype_get_extent_macro(recvtype, extent); for (i = 0; i < remote_size; i++) { - mpi_errno = - MPIC_Recv(((char *) recvbuf + recvcount * i * extent), recvcount, recvtype, i, - MPIR_GATHER_TAG, comm_ptr, &status); + mpi_errno = MPIC_Recv(((char *) recvbuf + recvcount * i * extent), recvcount, recvtype, + i, MPIR_GATHER_TAG, comm_ptr, coll_group, &status); MPIR_ERR_CHECK(mpi_errno); } } else { - mpi_errno = - MPIC_Send(sendbuf, sendcount, sendtype, root, MPIR_GATHER_TAG, comm_ptr, errflag); + mpi_errno = MPIC_Send(sendbuf, sendcount, sendtype, root, MPIR_GATHER_TAG, + comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/gather/gather_inter_local_gather_remote_send.c b/src/mpi/coll/gather/gather_inter_local_gather_remote_send.c index 934aca48fa3..6f35bb27dde 100644 --- a/src/mpi/coll/gather/gather_inter_local_gather_remote_send.c +++ b/src/mpi/coll/gather/gather_inter_local_gather_remote_send.c @@ -16,7 +16,8 @@ int MPIR_Gather_inter_local_gather_remote_send(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int rank, local_size, remote_size, mpi_errno = MPI_SUCCESS; MPI_Status status; @@ -35,7 +36,7 @@ int MPIR_Gather_inter_local_gather_remote_send(const void *sendbuf, MPI_Aint sen /* root receives data from rank 0 on remote group */ mpi_errno = MPIC_Recv(recvbuf, recvcount * remote_size, recvtype, 0, MPIR_GATHER_TAG, comm_ptr, - &status); + coll_group, &status); MPIR_ERR_CHECK(mpi_errno); } else { /* remote group. Rank 0 allocates temporary buffer, does @@ -67,12 +68,12 @@ int MPIR_Gather_inter_local_gather_remote_send(const void *sendbuf, MPI_Aint sen /* now do the a local gather on this intracommunicator */ mpi_errno = MPIR_Gather(sendbuf, sendcount, sendtype, tmp_buf, sendcount * sendtype_sz, MPI_BYTE, 0, newcomm_ptr, - errflag); + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); if (rank == 0) { mpi_errno = MPIC_Send(tmp_buf, sendcount * local_size * sendtype_sz, MPI_BYTE, - root, MPIR_GATHER_TAG, comm_ptr, errflag); + root, MPIR_GATHER_TAG, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } } diff --git a/src/mpi/coll/gather/gather_intra_binomial.c b/src/mpi/coll/gather/gather_intra_binomial.c index d8915452cda..876b4529e64 100644 --- a/src/mpi/coll/gather/gather_intra_binomial.c +++ b/src/mpi/coll/gather/gather_intra_binomial.c @@ -39,7 +39,7 @@ */ int MPIR_Gather_intra_binomial(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int comm_size, rank; int mpi_errno = MPI_SUCCESS; @@ -57,7 +57,7 @@ int MPIR_Gather_intra_binomial(const void *sendbuf, MPI_Aint sendcount, MPI_Data MPIR_CHKLMEM_DECL(1); - MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size); + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); /* Use binomial tree algorithm. */ @@ -137,13 +137,15 @@ int MPIR_Gather_intra_binomial(const void *sendbuf, MPI_Aint sendcount, MPI_Data (((rank + mask) % comm_size) * (MPI_Aint) recvcount * extent)), (MPI_Aint) recvblks * recvcount, - recvtype, src, MPIR_GATHER_TAG, comm_ptr, &status); + recvtype, src, MPIR_GATHER_TAG, comm_ptr, coll_group, + &status); MPIR_ERR_CHECK(mpi_errno); } else if (nbytes < MPIR_CVAR_GATHER_VSMALL_MSG_SIZE) { /* small transfer size case. cast ok */ MPIR_Assert(recvblks * nbytes == (int) (recvblks * nbytes)); mpi_errno = MPIC_Recv(tmp_buf, (int) (recvblks * nbytes), - MPI_BYTE, src, MPIR_GATHER_TAG, comm_ptr, &status); + MPI_BYTE, src, MPIR_GATHER_TAG, comm_ptr, coll_group, + &status); MPIR_ERR_CHECK(mpi_errno); copy_offset = rank + mask; copy_blks = recvblks; @@ -163,7 +165,7 @@ int MPIR_Gather_intra_binomial(const void *sendbuf, MPI_Aint sendcount, MPI_Data MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIC_Recv(recvbuf, 1, tmp_type, src, - MPIR_GATHER_TAG, comm_ptr, &status); + MPIR_GATHER_TAG, comm_ptr, coll_group, &status); MPIR_ERR_CHECK(mpi_errno); MPIR_Type_free_impl(&tmp_type); @@ -184,7 +186,7 @@ int MPIR_Gather_intra_binomial(const void *sendbuf, MPI_Aint sendcount, MPI_Data offset = (mask - 1) * nbytes; mpi_errno = MPIC_Recv(((char *) tmp_buf + offset), recvblks * nbytes, MPI_BYTE, src, - MPIR_GATHER_TAG, comm_ptr, &status); + MPIR_GATHER_TAG, comm_ptr, coll_group, &status); MPIR_ERR_CHECK(mpi_errno); curr_cnt += (recvblks * nbytes); } @@ -196,11 +198,11 @@ int MPIR_Gather_intra_binomial(const void *sendbuf, MPI_Aint sendcount, MPI_Data if (!tmp_buf_size) { /* leaf nodes send directly from sendbuf */ mpi_errno = MPIC_Send(sendbuf, sendcount, sendtype, dst, - MPIR_GATHER_TAG, comm_ptr, errflag); + MPIR_GATHER_TAG, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } else if (nbytes < MPIR_CVAR_GATHER_VSMALL_MSG_SIZE) { mpi_errno = MPIC_Send(tmp_buf, curr_cnt, MPI_BYTE, dst, - MPIR_GATHER_TAG, comm_ptr, errflag); + MPIR_GATHER_TAG, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } else { MPI_Aint blocks[2]; @@ -225,7 +227,7 @@ int MPIR_Gather_intra_binomial(const void *sendbuf, MPI_Aint sendcount, MPI_Data MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIC_Send(MPI_BOTTOM, 1, tmp_type, dst, - MPIR_GATHER_TAG, comm_ptr, errflag); + MPIR_GATHER_TAG, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); MPIR_Type_free_impl(&tmp_type); if (types[1] != MPI_BYTE) diff --git a/src/mpi/coll/gatherv/gatherv_allcomm_linear.c b/src/mpi/coll/gatherv/gatherv_allcomm_linear.c index cabf1ef9bb8..62a3fda6ef2 100644 --- a/src/mpi/coll/gatherv/gatherv_allcomm_linear.c +++ b/src/mpi/coll/gatherv/gatherv_allcomm_linear.c @@ -22,7 +22,8 @@ int MPIR_Gatherv_allcomm_linear(const void *sendbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, - int root, MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + int root, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int comm_size, rank; int mpi_errno = MPI_SUCCESS; @@ -32,14 +33,17 @@ int MPIR_Gatherv_allcomm_linear(const void *sendbuf, MPI_Status *starray; MPIR_CHKLMEM_DECL(2); - MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size); + if (comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM) { + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); + } else { + MPIR_Assert(coll_group == MPIR_SUBGROUP_NONE); + rank = comm_ptr->rank; + comm_size = comm_ptr->remote_size; + } /* If rank == root, then I recv lots, otherwise I send */ if (((comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM) && (root == rank)) || ((comm_ptr->comm_kind == MPIR_COMM_KIND__INTERCOMM) && (root == MPI_ROOT))) { - if (comm_ptr->comm_kind == MPIR_COMM_KIND__INTERCOMM) - comm_size = comm_ptr->remote_size; - MPIR_Datatype_get_extent_macro(recvtype, extent); MPIR_CHKLMEM_MALLOC(reqarray, MPIR_Request **, comm_size * sizeof(MPIR_Request *), @@ -60,7 +64,8 @@ int MPIR_Gatherv_allcomm_linear(const void *sendbuf, } else { mpi_errno = MPIC_Irecv(((char *) recvbuf + displs[i] * extent), recvcounts[i], recvtype, i, - MPIR_GATHERV_TAG, comm_ptr, &reqarray[reqs++]); + MPIR_GATHERV_TAG, comm_ptr, coll_group, + &reqarray[reqs++]); MPIR_ERR_CHECK(mpi_errno); } } @@ -73,7 +78,7 @@ int MPIR_Gatherv_allcomm_linear(const void *sendbuf, else if (root != MPI_PROC_NULL) { /* non-root nodes, and in the intercomm. case, non-root nodes on remote side */ if (sendcount) { mpi_errno = MPIC_Send(sendbuf, sendcount, sendtype, root, - MPIR_GATHERV_TAG, comm_ptr, errflag); + MPIR_GATHERV_TAG, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } } diff --git a/src/mpi/coll/gatherv/gatherv_allcomm_nb.c b/src/mpi/coll/gatherv/gatherv_allcomm_nb.c index 3a49ce11a4e..433035a0850 100644 --- a/src/mpi/coll/gatherv/gatherv_allcomm_nb.c +++ b/src/mpi/coll/gatherv/gatherv_allcomm_nb.c @@ -7,7 +7,7 @@ int MPIR_Gatherv_allcomm_nb(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, - MPI_Datatype recvtype, int root, MPIR_Comm * comm_ptr, + MPI_Datatype recvtype, int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -16,7 +16,7 @@ int MPIR_Gatherv_allcomm_nb(const void *sendbuf, MPI_Aint sendcount, MPI_Datatyp /* just call the nonblocking version and wait on it */ mpi_errno = MPIR_Igatherv(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, root, - comm_ptr, &req_ptr); + comm_ptr, coll_group, &req_ptr); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIC_Wait(req_ptr); diff --git a/src/mpi/coll/helper_fns.c b/src/mpi/coll/helper_fns.c index 86a4b7b6e62..225a54d96c6 100644 --- a/src/mpi/coll/helper_fns.c +++ b/src/mpi/coll/helper_fns.c @@ -14,39 +14,53 @@ sends/receives by setting the context offset MPIR_CONTEXT_COLL_OFFSET. */ +static int get_coll_group_rank(MPIR_Comm * comm, int coll_group, int group_rank) +{ + if (coll_group > 0) { + return comm->subgroups[coll_group].proc_table[group_rank]; + } else { + return group_rank; + } +} + #ifdef ENABLE_THREADCOMM -#define DO_MPID_ISEND(buf, count, datatype, dest, tag, comm_ptr, attr, req) \ +#define DO_MPID_ISEND(buf, count, datatype, dest, tag, comm_ptr, coll_group, attr, req) \ do { \ + int rank = get_coll_group_rank(comm_ptr, coll_group, dest); \ if (comm_ptr->threadcomm) { \ - mpi_errno = MPIR_Threadcomm_isend_attr(buf, count, datatype, dest, tag, \ + mpi_errno = MPIR_Threadcomm_isend_attr(buf, count, datatype, rank, tag, \ comm_ptr->threadcomm, attr, req); \ } else { \ - mpi_errno = MPID_Isend(buf, count, datatype, dest, tag, comm_ptr, attr, req); \ + mpi_errno = MPID_Isend(buf, count, datatype, rank, tag, comm_ptr, attr, req); \ } \ } while (0) -#define DO_MPID_IRECV(buf, count, datatype, source, tag, comm_ptr, attr, req) \ +#define DO_MPID_IRECV(buf, count, datatype, source, tag, comm_ptr, coll_group, attr, req) \ do { \ + int rank = get_coll_group_rank(comm_ptr, coll_group, source); \ if (comm_ptr->threadcomm) { \ - mpi_errno = MPIR_Threadcomm_irecv_attr(buf, count, datatype, source, tag, \ + mpi_errno = MPIR_Threadcomm_irecv_attr(buf, count, datatype, rank, tag, \ comm_ptr->threadcomm, attr, req, true); \ } else { \ - mpi_errno = MPID_Irecv(buf, count, datatype, source, tag, comm_ptr, attr, req); \ + mpi_errno = MPID_Irecv(buf, count, datatype, rank, tag, comm_ptr, attr, req); \ } \ } while (0) #else -#define DO_MPID_ISEND(buf, count, datatype, dest, tag, comm_ptr, attr, req) \ +#define DO_MPID_ISEND(buf, count, datatype, dest, tag, comm_ptr, coll_group, attr, req) \ do { \ - mpi_errno = MPID_Isend(buf, count, datatype, dest, tag, comm_ptr, attr, req); \ + int rank = get_coll_group_rank(comm_ptr, coll_group, dest); \ + mpi_errno = MPID_Isend(buf, count, datatype, rank, tag, comm_ptr, attr, req); \ } while (0) -#define DO_MPID_IRECV(buf, count, datatype, source, tag, comm_ptr, attr, req) \ +#define DO_MPID_IRECV(buf, count, datatype, source, tag, comm_ptr, coll_group, attr, req) \ do { \ - mpi_errno = MPID_Irecv(buf, count, datatype, source, tag, comm_ptr, attr, req); \ + int rank = get_coll_group_rank(comm_ptr, coll_group, source); \ + mpi_errno = MPID_Irecv(buf, count, datatype, rank, tag, comm_ptr, attr, req); \ } while (0) #endif +/* NOTE: MPIC_Probe is never used group collectives */ int MPIC_Probe(int source, int tag, MPI_Comm comm, MPI_Status * status) { int mpi_errno = MPI_SUCCESS; @@ -127,7 +141,7 @@ int MPIC_Wait(MPIR_Request * request_ptr) this is OK since there is no data that can be received corrupted. */ int MPIC_Send(const void *buf, MPI_Aint count, MPI_Datatype datatype, int dest, int tag, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; int attr = 0; @@ -146,7 +160,7 @@ int MPIC_Send(const void *buf, MPI_Aint count, MPI_Datatype datatype, int dest, MPIR_PT2PT_ATTR_SET_CONTEXT_OFFSET(attr, MPIR_CONTEXT_COLL_OFFSET); MPIR_PT2PT_ATTR_SET_ERRFLAG(attr, errflag); - DO_MPID_ISEND(buf, count, datatype, dest, tag, comm_ptr, attr, &request_ptr); + DO_MPID_ISEND(buf, count, datatype, dest, tag, comm_ptr, coll_group, attr, &request_ptr); MPIR_ERR_CHECK(mpi_errno); if (request_ptr) { mpi_errno = MPIC_Wait(request_ptr); @@ -168,7 +182,7 @@ int MPIC_Send(const void *buf, MPI_Aint count, MPI_Datatype datatype, int dest, } int MPIC_Recv(void *buf, MPI_Aint count, MPI_Datatype datatype, int source, int tag, - MPIR_Comm * comm_ptr, MPI_Status * status) + MPIR_Comm * comm_ptr, int coll_group, MPI_Status * status) { int mpi_errno = MPI_SUCCESS; int attr = 0; @@ -191,7 +205,7 @@ int MPIC_Recv(void *buf, MPI_Aint count, MPI_Datatype datatype, int source, int if (status == MPI_STATUS_IGNORE) status = &mystatus; - DO_MPID_IRECV(buf, count, datatype, source, tag, comm_ptr, attr, &request_ptr); + DO_MPID_IRECV(buf, count, datatype, source, tag, comm_ptr, coll_group, attr, &request_ptr); MPIR_ERR_CHECK(mpi_errno); if (request_ptr) { mpi_errno = MPIC_Wait(request_ptr); @@ -218,7 +232,7 @@ int MPIC_Recv(void *buf, MPI_Aint count, MPI_Datatype datatype, int source, int int MPIC_Sendrecv(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, int dest, int sendtag, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int source, int recvtag, - MPIR_Comm * comm_ptr, MPI_Status * status, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, MPI_Status * status, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; int attr = 0; @@ -244,7 +258,8 @@ int MPIC_Sendrecv(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype "**nomemreq"); MPIR_Status_set_procnull(&recv_req_ptr->status); } else { - DO_MPID_IRECV(recvbuf, recvcount, recvtype, source, recvtag, comm_ptr, attr, &recv_req_ptr); + DO_MPID_IRECV(recvbuf, recvcount, recvtype, source, recvtag, comm_ptr, coll_group, attr, + &recv_req_ptr); MPIR_ERR_CHECK(mpi_errno); } @@ -255,7 +270,8 @@ int MPIC_Sendrecv(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype "**nomemreq"); } else { MPIR_PT2PT_ATTR_SET_ERRFLAG(attr, errflag); - DO_MPID_ISEND(sendbuf, sendcount, sendtype, dest, sendtag, comm_ptr, attr, &send_req_ptr); + DO_MPID_ISEND(sendbuf, sendcount, sendtype, dest, sendtag, comm_ptr, coll_group, attr, + &send_req_ptr); MPIR_ERR_CHECK(mpi_errno); } @@ -297,7 +313,8 @@ int MPIC_Sendrecv(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype int MPIC_Sendrecv_replace(void *buf, MPI_Aint count, MPI_Datatype datatype, int dest, int sendtag, int source, int recvtag, - MPIR_Comm * comm_ptr, MPI_Status * status, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, MPI_Status * status, + MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPI_Status mystatus; @@ -345,7 +362,7 @@ int MPIC_Sendrecv_replace(void *buf, MPI_Aint count, MPI_Datatype datatype, MPIR_ERR_CHKANDSTMT(rreq == NULL, mpi_errno, MPIX_ERR_NOREQ, goto fn_fail, "**nomemreq"); MPIR_Status_set_procnull(&rreq->status); } else { - DO_MPID_IRECV(buf, count, datatype, source, recvtag, comm_ptr, attr, &rreq); + DO_MPID_IRECV(buf, count, datatype, source, recvtag, comm_ptr, coll_group, attr, &rreq); MPIR_ERR_CHECK(mpi_errno); } @@ -355,7 +372,8 @@ int MPIC_Sendrecv_replace(void *buf, MPI_Aint count, MPI_Datatype datatype, MPIR_ERR_CHKANDSTMT(sreq == NULL, mpi_errno, MPIX_ERR_NOREQ, goto fn_fail, "**nomemreq"); } else { MPIR_PT2PT_ATTR_SET_ERRFLAG(attr, errflag); - DO_MPID_ISEND(tmpbuf, actual_pack_bytes, MPI_PACKED, dest, sendtag, comm_ptr, attr, &sreq); + DO_MPID_ISEND(tmpbuf, actual_pack_bytes, MPI_PACKED, dest, sendtag, comm_ptr, coll_group, + attr, &sreq); MPIR_ERR_CHECK(mpi_errno); if (mpi_errno != MPI_SUCCESS) { /* --BEGIN ERROR HANDLING-- */ @@ -400,7 +418,8 @@ int MPIC_Sendrecv_replace(void *buf, MPI_Aint count, MPI_Datatype datatype, } int MPIC_Isend(const void *buf, MPI_Aint count, MPI_Datatype datatype, int dest, int tag, - MPIR_Comm * comm_ptr, MPIR_Request ** request_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Request ** request_ptr, + MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; int attr = 0; @@ -421,7 +440,7 @@ int MPIC_Isend(const void *buf, MPI_Aint count, MPI_Datatype datatype, int dest, MPIR_PT2PT_ATTR_SET_CONTEXT_OFFSET(attr, MPIR_CONTEXT_COLL_OFFSET); MPIR_PT2PT_ATTR_SET_ERRFLAG(attr, errflag); - DO_MPID_ISEND(buf, count, datatype, dest, tag, comm_ptr, attr, request_ptr); + DO_MPID_ISEND(buf, count, datatype, dest, tag, comm_ptr, coll_group, attr, request_ptr); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -434,7 +453,7 @@ int MPIC_Isend(const void *buf, MPI_Aint count, MPI_Datatype datatype, int dest, } int MPIC_Irecv(void *buf, MPI_Aint count, MPI_Datatype datatype, int source, - int tag, MPIR_Comm * comm_ptr, MPIR_Request ** request_ptr) + int tag, MPIR_Comm * comm_ptr, int coll_group, MPIR_Request ** request_ptr) { int mpi_errno = MPI_SUCCESS; int attr = 0; @@ -455,7 +474,7 @@ int MPIC_Irecv(void *buf, MPI_Aint count, MPI_Datatype datatype, int source, MPIR_PT2PT_ATTR_SET_CONTEXT_OFFSET(attr, MPIR_CONTEXT_COLL_OFFSET); - DO_MPID_IRECV(buf, count, datatype, source, tag, comm_ptr, attr, request_ptr); + DO_MPID_IRECV(buf, count, datatype, source, tag, comm_ptr, coll_group, attr, request_ptr); MPIR_ERR_CHECK(mpi_errno); fn_exit: diff --git a/src/mpi/coll/iallgather/iallgather_inter_sched_local_gather_remote_bcast.c b/src/mpi/coll/iallgather/iallgather_inter_sched_local_gather_remote_bcast.c index 62e702e4679..2920955fd69 100644 --- a/src/mpi/coll/iallgather/iallgather_inter_sched_local_gather_remote_bcast.c +++ b/src/mpi/coll/iallgather/iallgather_inter_sched_local_gather_remote_bcast.c @@ -14,7 +14,8 @@ int MPIR_Iallgather_inter_sched_local_gather_remote_bcast(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int rank, local_size, remote_size, root; @@ -46,7 +47,7 @@ int MPIR_Iallgather_inter_sched_local_gather_remote_bcast(const void *sendbuf, M if (sendcount != 0) { mpi_errno = MPIR_Igather_intra_sched_auto(sendbuf, sendcount, sendtype, tmp_buf, sendcount * sendtype_sz, MPI_BYTE, 0, - newcomm_ptr, s); + newcomm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } @@ -58,7 +59,7 @@ int MPIR_Iallgather_inter_sched_local_gather_remote_bcast(const void *sendbuf, M if (sendcount != 0) { root = (rank == 0) ? MPI_ROOT : MPI_PROC_NULL; mpi_errno = MPIR_Ibcast_inter_sched_auto(tmp_buf, sendcount * local_size * sendtype_sz, - MPI_BYTE, root, comm_ptr, s); + MPI_BYTE, root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } @@ -68,7 +69,7 @@ int MPIR_Iallgather_inter_sched_local_gather_remote_bcast(const void *sendbuf, M if (recvcount != 0) { root = 0; mpi_errno = MPIR_Ibcast_inter_sched_auto(recvbuf, recvcount * remote_size, - recvtype, root, comm_ptr, s); + recvtype, root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } MPIR_SCHED_BARRIER(s); @@ -77,7 +78,7 @@ int MPIR_Iallgather_inter_sched_local_gather_remote_bcast(const void *sendbuf, M if (recvcount != 0) { root = 0; mpi_errno = MPIR_Ibcast_inter_sched_auto(recvbuf, recvcount * remote_size, - recvtype, root, comm_ptr, s); + recvtype, root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } @@ -87,7 +88,7 @@ int MPIR_Iallgather_inter_sched_local_gather_remote_bcast(const void *sendbuf, M if (sendcount != 0) { root = (rank == 0) ? MPI_ROOT : MPI_PROC_NULL; mpi_errno = MPIR_Ibcast_inter_sched_auto(tmp_buf, sendcount * local_size * sendtype_sz, - MPI_BYTE, root, comm_ptr, s); + MPI_BYTE, root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } MPIR_SCHED_BARRIER(s); diff --git a/src/mpi/coll/iallgather/iallgather_intra_sched_brucks.c b/src/mpi/coll/iallgather/iallgather_intra_sched_brucks.c index 955a1447ce5..44a5019b115 100644 --- a/src/mpi/coll/iallgather/iallgather_intra_sched_brucks.c +++ b/src/mpi/coll/iallgather/iallgather_intra_sched_brucks.c @@ -16,7 +16,8 @@ */ int MPIR_Iallgather_intra_sched_brucks(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, - MPI_Datatype recvtype, MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPI_Datatype recvtype, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int pof2, rem, src, dst; @@ -24,8 +25,7 @@ int MPIR_Iallgather_intra_sched_brucks(const void *sendbuf, MPI_Aint sendcount, MPI_Aint recvtype_extent, recvtype_sz; void *tmp_buf = NULL; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); MPIR_Datatype_get_extent_macro(recvtype, recvtype_extent); /* allocate a temporary buffer of the same size as recvbuf. */ @@ -56,11 +56,13 @@ int MPIR_Iallgather_intra_sched_brucks(const void *sendbuf, MPI_Aint sendcount, src = (rank + pof2) % comm_size; dst = (rank - pof2 + comm_size) % comm_size; - mpi_errno = MPIR_Sched_send(tmp_buf, curr_cnt * recvtype_sz, MPI_BYTE, dst, comm_ptr, s); + mpi_errno = + MPIR_Sched_send(tmp_buf, curr_cnt * recvtype_sz, MPI_BYTE, dst, comm_ptr, coll_group, + s); MPIR_ERR_CHECK(mpi_errno); /* logically sendrecv, so no barrier here */ mpi_errno = MPIR_Sched_recv(((char *) tmp_buf + curr_cnt * recvtype_sz), - curr_cnt * recvtype_sz, MPI_BYTE, src, comm_ptr, s); + curr_cnt * recvtype_sz, MPI_BYTE, src, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); @@ -76,11 +78,13 @@ int MPIR_Iallgather_intra_sched_brucks(const void *sendbuf, MPI_Aint sendcount, dst = (rank - pof2 + comm_size) % comm_size; mpi_errno = - MPIR_Sched_send(tmp_buf, rem * recvcount * recvtype_sz, MPI_BYTE, dst, comm_ptr, s); + MPIR_Sched_send(tmp_buf, rem * recvcount * recvtype_sz, MPI_BYTE, dst, comm_ptr, + coll_group, s); MPIR_ERR_CHECK(mpi_errno); /* logically sendrecv, so no barrier here */ mpi_errno = MPIR_Sched_recv((char *) tmp_buf + curr_cnt * recvtype_sz, - rem * recvcount * recvtype_sz, MPI_BYTE, src, comm_ptr, s); + rem * recvcount * recvtype_sz, MPI_BYTE, src, comm_ptr, + coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } diff --git a/src/mpi/coll/iallgather/iallgather_intra_sched_recursive_doubling.c b/src/mpi/coll/iallgather/iallgather_intra_sched_recursive_doubling.c index dd0ba4c217a..42bcb719edd 100644 --- a/src/mpi/coll/iallgather/iallgather_intra_sched_recursive_doubling.c +++ b/src/mpi/coll/iallgather/iallgather_intra_sched_recursive_doubling.c @@ -46,7 +46,8 @@ static int reset_shared_state(MPIR_Comm * comm, int tag, void *state) int MPIR_Iallgather_intra_sched_recursive_doubling(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; struct shared_state *ss = NULL; @@ -56,8 +57,7 @@ int MPIR_Iallgather_intra_sched_recursive_doubling(const void *sendbuf, MPI_Aint int dst_tree_root, my_tree_root, tree_root; MPI_Aint recvtype_extent; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); #ifdef HAVE_ERROR_CHECKING /* Currently this algorithm can only handle power-of-2 comm_size. @@ -104,12 +104,13 @@ int MPIR_Iallgather_intra_sched_recursive_doubling(const void *sendbuf, MPI_Aint if (dst < comm_size) { mpi_errno = MPIR_Sched_send_defer(((char *) recvbuf + send_offset), - &ss->curr_count, recvtype, dst, comm_ptr, s); + &ss->curr_count, recvtype, dst, comm_ptr, coll_group, + s); MPIR_ERR_CHECK(mpi_errno); /* send-recv, no sched barrier here */ mpi_errno = MPIR_Sched_recv_status(((char *) recvbuf + recv_offset), ((comm_size - dst_tree_root) * recvcount), - recvtype, dst, comm_ptr, &ss->status, s); + recvtype, dst, comm_ptr, coll_group, &ss->status, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); @@ -168,7 +169,7 @@ int MPIR_Iallgather_intra_sched_recursive_doubling(const void *sendbuf, MPI_Aint * sent now. */ mpi_errno = MPIR_Sched_send_defer(((char *) recvbuf + offset), &ss->last_recv_count, - recvtype, dst, comm_ptr, s); + recvtype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } @@ -182,7 +183,8 @@ int MPIR_Iallgather_intra_sched_recursive_doubling(const void *sendbuf, MPI_Aint mpi_errno = MPIR_Sched_recv_status(((char *) recvbuf + offset), ((comm_size - (my_tree_root + mask)) * recvcount), - recvtype, dst, comm_ptr, &ss->status, s); + recvtype, dst, comm_ptr, coll_group, + &ss->status, s); MPIR_SCHED_BARRIER(s); mpi_errno = MPIR_Sched_cb(&get_count, ss, s); MPIR_ERR_CHECK(mpi_errno); diff --git a/src/mpi/coll/iallgather/iallgather_intra_sched_ring.c b/src/mpi/coll/iallgather/iallgather_intra_sched_ring.c index 78bbbacce62..da715eb1c1a 100644 --- a/src/mpi/coll/iallgather/iallgather_intra_sched_ring.c +++ b/src/mpi/coll/iallgather/iallgather_intra_sched_ring.c @@ -22,15 +22,15 @@ */ int MPIR_Iallgather_intra_sched_ring(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, - MPI_Datatype recvtype, MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPI_Datatype recvtype, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int rank, comm_size; int i, j, jnext, left, right; MPI_Aint recvtype_extent; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); MPIR_Datatype_get_extent_macro(recvtype, recvtype_extent); @@ -52,11 +52,11 @@ int MPIR_Iallgather_intra_sched_ring(const void *sendbuf, MPI_Aint sendcount, MP jnext = left; for (i = 1; i < comm_size; i++) { mpi_errno = MPIR_Sched_send(((char *) recvbuf + j * recvcount * recvtype_extent), - recvcount, recvtype, right, comm_ptr, s); + recvcount, recvtype, right, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); /* concurrent, no barrier here */ mpi_errno = MPIR_Sched_recv(((char *) recvbuf + jnext * recvcount * recvtype_extent), - recvcount, recvtype, left, comm_ptr, s); + recvcount, recvtype, left, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); diff --git a/src/mpi/coll/iallgather/iallgather_tsp_brucks.c b/src/mpi/coll/iallgather/iallgather_tsp_brucks.c index 8e526ac45b3..72ac971069f 100644 --- a/src/mpi/coll/iallgather/iallgather_tsp_brucks.c +++ b/src/mpi/coll/iallgather/iallgather_tsp_brucks.c @@ -10,7 +10,8 @@ int MPIR_TSP_Iallgather_sched_intra_brucks(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm, int k, MPIR_TSP_sched_t sched) + MPIR_Comm * comm, int coll_group, int k, + MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; int i, j; @@ -20,8 +21,9 @@ MPIR_TSP_Iallgather_sched_intra_brucks(const void *sendbuf, MPI_Aint sendcount, int src, dst, p_of_k = 0; /* Largest power of k that is (strictly) smaller than 'size' */ MPIR_Errflag_t errflag ATTRIBUTE((unused)) = MPIR_ERR_NONE; - int rank = MPIR_Comm_rank(comm); - int size = MPIR_Comm_size(comm); + int rank, size; + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, size); + int is_inplace = (sendbuf == MPI_IN_PLACE); int max = size - 1; int vtx_id; @@ -38,7 +40,7 @@ MPIR_TSP_Iallgather_sched_intra_brucks(const void *sendbuf, MPI_Aint sendcount, /* For correctness, transport based collectives need to get the * tag from the same pool as schedule based collectives */ - mpi_errno = MPIR_Sched_next_tag(comm, &tag); + mpi_errno = MPIR_Sched_next_tag(comm, coll_group, &tag); MPIR_ERR_CHECK(mpi_errno); MPIR_FUNC_ENTER; @@ -116,18 +118,20 @@ MPIR_TSP_Iallgather_sched_intra_brucks(const void *sendbuf, MPI_Aint sendcount, /* Receive at the exact location. */ mpi_errno = MPIR_TSP_sched_irecv((char *) tmp_recvbuf + j * recvcount * delta * recvtype_extent, - count, recvtype, src, tag, comm, sched, 0, NULL, &vtx_id); + count, recvtype, src, tag, comm, coll_group, sched, 0, NULL, + &vtx_id); MPIR_ERR_CHECK(mpi_errno); recv_id[i_recv++] = vtx_id; /* Send from the start of recv till `count` amount of data. */ if (i == 0) mpi_errno = - MPIR_TSP_sched_isend(tmp_recvbuf, count, recvtype, dst, tag, comm, sched, 0, - NULL, &vtx_id); + MPIR_TSP_sched_isend(tmp_recvbuf, count, recvtype, dst, tag, comm, coll_group, + sched, 0, NULL, &vtx_id); else mpi_errno = MPIR_TSP_sched_isend(tmp_recvbuf, count, recvtype, dst, tag, - comm, sched, n_invtcs, recv_id, &vtx_id); + comm, coll_group, sched, n_invtcs, recv_id, + &vtx_id); MPIR_ERR_CHECK(mpi_errno); } n_invtcs += (k - 1); diff --git a/src/mpi/coll/iallgather/iallgather_tsp_recexch.c b/src/mpi/coll/iallgather/iallgather_tsp_recexch.c index 6f554bacecc..bb3dfcacdc1 100644 --- a/src/mpi/coll/iallgather/iallgather_tsp_recexch.c +++ b/src/mpi/coll/iallgather/iallgather_tsp_recexch.c @@ -12,7 +12,7 @@ static int MPIR_TSP_Iallgather_sched_intra_recexch_data_exchange(int rank, int n MPI_Datatype recvtype, size_t recv_extent, MPI_Aint recvcount, int tag, - MPIR_Comm * comm, + MPIR_Comm * comm, int coll_group, MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; @@ -33,7 +33,7 @@ static int MPIR_TSP_Iallgather_sched_intra_recexch_data_exchange(int rank, int n /* send my data to partner */ mpi_errno = MPIR_TSP_sched_isend(((char *) recvbuf + send_offset), count * recvcount, recvtype, - partner, tag, comm, sched, 0, NULL, &vtx_id); + partner, tag, comm, coll_group, sched, 0, NULL, &vtx_id); MPIR_ERR_CHECK(mpi_errno); @@ -46,7 +46,7 @@ static int MPIR_TSP_Iallgather_sched_intra_recexch_data_exchange(int rank, int n /* recv data from my partner */ mpi_errno = MPIR_TSP_sched_irecv(((char *) recvbuf + recv_offset), count * recvcount, recvtype, - partner, tag, comm, sched, 0, NULL, &vtx_id); + partner, tag, comm, coll_group, sched, 0, NULL, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } @@ -65,7 +65,7 @@ static int MPIR_TSP_Iallgather_sched_intra_recexch_step1(int step1_sendto, int * void *recvbuf, size_t recv_extent, MPI_Aint recvcount, MPI_Datatype recvtype, int n_invtcs, int *invtx, MPIR_Comm * comm, - MPIR_TSP_sched_t sched) + int coll_group, MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; int i; @@ -83,15 +83,15 @@ static int MPIR_TSP_Iallgather_sched_intra_recexch_step1(int step1_sendto, int * else buf_to_send = (void *) sendbuf; mpi_errno = - MPIR_TSP_sched_isend(buf_to_send, recvcount, recvtype, step1_sendto, tag, comm, sched, - 0, NULL, &vtx_id); + MPIR_TSP_sched_isend(buf_to_send, recvcount, recvtype, step1_sendto, tag, comm, + coll_group, sched, 0, NULL, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } else { for (i = 0; i < step1_nrecvs; i++) { /* participating rank gets the data from non-participating rank */ MPI_Aint recv_offset = step1_recvfrom[i] * recv_extent * recvcount; mpi_errno = MPIR_TSP_sched_irecv(((char *) recvbuf + recv_offset), recvcount, recvtype, - step1_recvfrom[i], tag, comm, sched, n_invtcs, invtx, - &vtx_id); + step1_recvfrom[i], tag, comm, coll_group, sched, + n_invtcs, invtx, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } } @@ -111,7 +111,7 @@ static int MPIR_TSP_Iallgather_sched_intra_recexch_step2(int step1_sendto, int s void *recvbuf, size_t recv_extent, MPI_Aint recvcount, MPI_Datatype recvtype, int is_dist_halving, MPIR_Comm * comm, - MPIR_TSP_sched_t sched) + int coll_group, MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; int phase, i, j, count, nbr, offset, rank_for_offset; @@ -140,7 +140,7 @@ static int MPIR_TSP_Iallgather_sched_intra_recexch_step2(int step1_sendto, int s MPI_Aint send_offset = offset * recv_extent * recvcount; mpi_errno = MPIR_TSP_sched_isend(((char *) recvbuf + send_offset), count * recvcount, recvtype, - nbr, tag, comm, sched, nrecvs, recv_id, &vtx_id); + nbr, tag, comm, coll_group, sched, nrecvs, recv_id, &vtx_id); MPIR_ERR_CHECK(mpi_errno); MPL_DBG_MSG_FMT(MPIR_DBG_COLL, VERBOSE, (MPL_DBG_FDEST, @@ -158,7 +158,7 @@ static int MPIR_TSP_Iallgather_sched_intra_recexch_step2(int step1_sendto, int s MPI_Aint recv_offset = offset * recv_extent * recvcount; mpi_errno = MPIR_TSP_sched_irecv(((char *) recvbuf + recv_offset), count * recvcount, recvtype, - nbr, tag, comm, sched, 0, NULL, &vtx_id); + nbr, tag, comm, coll_group, sched, 0, NULL, &vtx_id); MPIR_ERR_CHECK(mpi_errno); @@ -191,7 +191,7 @@ static int MPIR_TSP_Iallgather_sched_intra_recexch_step3(int step1_sendto, int * int nranks, int k, int nrecvs, int *recv_id, int tag, MPI_Datatype recvtype, MPIR_Comm * comm, - MPIR_TSP_sched_t sched) + int coll_group, MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; int i, vtx_id; @@ -202,13 +202,13 @@ static int MPIR_TSP_Iallgather_sched_intra_recexch_step3(int step1_sendto, int * if (step1_sendto != -1) { mpi_errno = MPIR_TSP_sched_irecv(recvbuf, recvcount * nranks, recvtype, step1_sendto, tag, comm, - sched, 0, NULL, &vtx_id); + coll_group, sched, 0, NULL, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } for (i = 0; i < step1_nrecvs; i++) { mpi_errno = MPIR_TSP_sched_isend(recvbuf, recvcount * nranks, recvtype, step1_recvfrom[i], - tag, comm, sched, nrecvs, recv_id, &vtx_id); + tag, comm, coll_group, sched, nrecvs, recv_id, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } @@ -232,8 +232,8 @@ static int MPIR_TSP_Iallgather_sched_intra_recexch_step3(int step1_sendto, int * int MPIR_TSP_Iallgather_sched_intra_recexch(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm, int is_dist_halving, int k, - MPIR_TSP_sched_t sched) + MPIR_Comm * comm, int coll_group, int is_dist_halving, + int k, MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; int is_inplace, i; @@ -255,12 +255,11 @@ int MPIR_TSP_Iallgather_sched_intra_recexch(const void *sendbuf, MPI_Aint sendco /* For correctness, transport based collectives need to get the * tag from the same pool as schedule based collectives */ - mpi_errno = MPIR_Sched_next_tag(comm, &tag); + mpi_errno = MPIR_Sched_next_tag(comm, coll_group, &tag); MPIR_ERR_CHECK(mpi_errno); is_inplace = (sendbuf == MPI_IN_PLACE); - nranks = MPIR_Comm_size(comm); - rank = MPIR_Comm_rank(comm); + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, nranks); MPIR_Datatype_get_extent_macro(recvtype, recv_extent); MPIR_Type_get_true_extent_impl(recvtype, &recv_lb, &true_extent); @@ -292,7 +291,7 @@ int MPIR_TSP_Iallgather_sched_intra_recexch(const void *sendbuf, MPI_Aint sendco MPIR_TSP_Iallgather_sched_intra_recexch_step1(step1_sendto, step1_recvfrom, step1_nrecvs, is_inplace, rank, tag, sendbuf, recvbuf, recv_extent, recvcount, recvtype, n_invtcs, - &invtx, comm, sched); + &invtx, comm, coll_group, sched); mpi_errno = MPIR_TSP_sched_fence(sched); MPIR_ERR_CHECK(mpi_errno); @@ -302,7 +301,8 @@ int MPIR_TSP_Iallgather_sched_intra_recexch(const void *sendbuf, MPI_Aint sendco if (step1_sendto == -1) { MPIR_TSP_Iallgather_sched_intra_recexch_data_exchange(rank, nranks, k, p_of_k, log_pofk, T, recvbuf, recvtype, recv_extent, - recvcount, tag, comm, sched); + recvcount, tag, comm, coll_group, + sched); mpi_errno = MPIR_TSP_sched_fence(sched); MPIR_ERR_CHECK(mpi_errno); } @@ -312,13 +312,14 @@ int MPIR_TSP_Iallgather_sched_intra_recexch(const void *sendbuf, MPI_Aint sendco MPIR_TSP_Iallgather_sched_intra_recexch_step2(step1_sendto, step2_nphases, step2_nbrs, rank, nranks, k, p_of_k, log_pofk, T, &nrecvs, &recv_id, tag, recvbuf, recv_extent, recvcount, recvtype, - is_dist_halving, comm, sched); + is_dist_halving, comm, coll_group, sched); /* Step 3: This is reverse of Step 1. Ranks that participated in Step 2 * send the data to non-partcipating ranks */ MPIR_TSP_Iallgather_sched_intra_recexch_step3(step1_sendto, step1_recvfrom, step1_nrecvs, step2_nphases, recvbuf, recvcount, nranks, k, - nrecvs, recv_id, tag, recvtype, comm, sched); + nrecvs, recv_id, tag, recvtype, comm, coll_group, + sched); /* free the memory */ for (i = 0; i < step2_nphases; i++) diff --git a/src/mpi/coll/iallgather/iallgather_tsp_ring.c b/src/mpi/coll/iallgather/iallgather_tsp_ring.c index f9b725ba9fc..7b5e34ce881 100644 --- a/src/mpi/coll/iallgather/iallgather_tsp_ring.c +++ b/src/mpi/coll/iallgather/iallgather_tsp_ring.c @@ -9,15 +9,16 @@ int MPIR_TSP_Iallgather_sched_intra_ring(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm, MPIR_TSP_sched_t sched) + MPIR_Comm * comm, int coll_group, MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; int i, src, dst, copy_dst; /* Temporary buffers to execute the ring algorithm */ void *buf1, *buf2, *data_buf, *rbuf, *sbuf; - int size = MPIR_Comm_size(comm); - int rank = MPIR_Comm_rank(comm); + int rank, size; + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, size); + int is_inplace = (sendbuf == MPI_IN_PLACE); int tag; int vtx_id; @@ -82,7 +83,7 @@ int MPIR_TSP_Iallgather_sched_intra_ring(const void *sendbuf, MPI_Aint sendcount int recv_id[3] = { 0 }; /* warning fix: icc: maybe used before set */ for (i = 0; i < size - 1; i++) { /* Get new tag for each cycle so that the send-recv pairs are matched correctly */ - mpi_errno = MPIR_Sched_next_tag(comm, &tag); + mpi_errno = MPIR_Sched_next_tag(comm, coll_group, &tag); MPIR_ERR_CHECK(mpi_errno); int vtcs[3], nvtcs; @@ -90,14 +91,16 @@ int MPIR_TSP_Iallgather_sched_intra_ring(const void *sendbuf, MPI_Aint sendcount nvtcs = 1; vtcs[0] = dtcopy_id[0]; mpi_errno = MPIR_TSP_sched_isend((char *) sbuf, recvcount, recvtype, - dst, tag, comm, sched, nvtcs, vtcs, &send_id[0]); + dst, tag, comm, coll_group, sched, nvtcs, vtcs, + &send_id[0]); nvtcs = 0; } else { nvtcs = 2; vtcs[0] = recv_id[(i - 1) % 3]; vtcs[1] = send_id[(i - 1) % 3]; mpi_errno = MPIR_TSP_sched_isend((char *) sbuf, recvcount, recvtype, - dst, tag, comm, sched, nvtcs, vtcs, &send_id[i % 3]); + dst, tag, comm, coll_group, sched, nvtcs, vtcs, + &send_id[i % 3]); if (i == 1) { nvtcs = 2; vtcs[0] = send_id[0]; @@ -112,7 +115,8 @@ int MPIR_TSP_Iallgather_sched_intra_ring(const void *sendbuf, MPI_Aint sendcount MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_TSP_sched_irecv((char *) rbuf, recvcount, recvtype, - src, tag, comm, sched, nvtcs, vtcs, &recv_id[i % 3]); + src, tag, comm, coll_group, sched, nvtcs, vtcs, + &recv_id[i % 3]); MPIR_ERR_CHECK(mpi_errno); diff --git a/src/mpi/coll/iallgatherv/iallgatherv_inter_sched_remote_gather_local_bcast.c b/src/mpi/coll/iallgatherv/iallgatherv_inter_sched_remote_gather_local_bcast.c index 575dbf7bedb..70d273e1074 100644 --- a/src/mpi/coll/iallgatherv/iallgatherv_inter_sched_remote_gather_local_bcast.c +++ b/src/mpi/coll/iallgatherv/iallgatherv_inter_sched_remote_gather_local_bcast.c @@ -21,7 +21,8 @@ int MPIR_Iallgatherv_inter_sched_remote_gather_local_bcast(const void *sendbuf, const MPI_Aint recvcounts[], const MPI_Aint displs[], MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int remote_size, root, rank; @@ -37,23 +38,27 @@ int MPIR_Iallgatherv_inter_sched_remote_gather_local_bcast(const void *sendbuf, /* gatherv from right group */ root = (rank == 0) ? MPI_ROOT : MPI_PROC_NULL; mpi_errno = MPIR_Igatherv_inter_sched_auto(sendbuf, sendcount, sendtype, recvbuf, - recvcounts, displs, recvtype, root, comm_ptr, s); + recvcounts, displs, recvtype, root, comm_ptr, + coll_group, s); MPIR_ERR_CHECK(mpi_errno); /* gatherv to right group */ root = 0; mpi_errno = MPIR_Igatherv_inter_sched_auto(sendbuf, sendcount, sendtype, recvbuf, - recvcounts, displs, recvtype, root, comm_ptr, s); + recvcounts, displs, recvtype, root, comm_ptr, + coll_group, s); MPIR_ERR_CHECK(mpi_errno); } else { /* gatherv to left group */ root = 0; mpi_errno = MPIR_Igatherv_inter_sched_auto(sendbuf, sendcount, sendtype, recvbuf, - recvcounts, displs, recvtype, root, comm_ptr, s); + recvcounts, displs, recvtype, root, comm_ptr, + coll_group, s); MPIR_ERR_CHECK(mpi_errno); /* gatherv from left group */ root = (rank == 0) ? MPI_ROOT : MPI_PROC_NULL; mpi_errno = MPIR_Igatherv_inter_sched_auto(sendbuf, sendcount, sendtype, recvbuf, - recvcounts, displs, recvtype, root, comm_ptr, s); + recvcounts, displs, recvtype, root, comm_ptr, + coll_group, s); MPIR_ERR_CHECK(mpi_errno); } @@ -76,7 +81,7 @@ int MPIR_Iallgatherv_inter_sched_remote_gather_local_bcast(const void *sendbuf, mpi_errno = MPIR_Type_commit_impl(&newtype); MPIR_ERR_CHECK(mpi_errno); - mpi_errno = MPIR_Ibcast_intra_sched_auto(recvbuf, 1, newtype, 0, newcomm_ptr, s); + mpi_errno = MPIR_Ibcast_intra_sched_auto(recvbuf, 1, newtype, 0, newcomm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_Type_free_impl(&newtype); diff --git a/src/mpi/coll/iallgatherv/iallgatherv_intra_sched_brucks.c b/src/mpi/coll/iallgatherv/iallgatherv_intra_sched_brucks.c index a849dd3522e..e978058c7d7 100644 --- a/src/mpi/coll/iallgatherv/iallgatherv_intra_sched_brucks.c +++ b/src/mpi/coll/iallgatherv/iallgatherv_intra_sched_brucks.c @@ -8,7 +8,8 @@ int MPIR_Iallgatherv_intra_sched_brucks(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint displs[], - MPI_Datatype recvtype, MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPI_Datatype recvtype, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int comm_size, rank, j, i; @@ -16,8 +17,7 @@ int MPIR_Iallgatherv_intra_sched_brucks(const void *sendbuf, MPI_Aint sendcount, int dst, pof2, src, rem; void *tmp_buf; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); MPIR_Datatype_get_extent_macro(recvtype, recvtype_extent); MPIR_Datatype_get_size_macro(recvtype, recvtype_sz); @@ -69,11 +69,14 @@ int MPIR_Iallgatherv_intra_sched_brucks(const void *sendbuf, MPI_Aint sendcount, incoming_count += recvcounts[(src + i) % comm_size]; } - mpi_errno = MPIR_Sched_send(tmp_buf, curr_count * recvtype_sz, MPI_BYTE, dst, comm_ptr, s); + mpi_errno = + MPIR_Sched_send(tmp_buf, curr_count * recvtype_sz, MPI_BYTE, dst, comm_ptr, coll_group, + s); MPIR_ERR_CHECK(mpi_errno); /* sendrecv, no barrier */ mpi_errno = MPIR_Sched_recv(((char *) tmp_buf + curr_count * recvtype_sz), - incoming_count * recvtype_sz, MPI_BYTE, src, comm_ptr, s); + incoming_count * recvtype_sz, MPI_BYTE, src, comm_ptr, + coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); @@ -92,12 +95,13 @@ int MPIR_Iallgatherv_intra_sched_brucks(const void *sendbuf, MPI_Aint sendcount, for (i = 0; i < rem; i++) cnt += recvcounts[(rank + i) % comm_size]; - mpi_errno = MPIR_Sched_send(tmp_buf, cnt * recvtype_sz, MPI_BYTE, dst, comm_ptr, s); + mpi_errno = + MPIR_Sched_send(tmp_buf, cnt * recvtype_sz, MPI_BYTE, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); /* sendrecv, no barrier */ mpi_errno = MPIR_Sched_recv(((char *) tmp_buf + curr_count * recvtype_sz), (total_count - curr_count) * recvtype_sz, MPI_BYTE, - src, comm_ptr, s); + src, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } diff --git a/src/mpi/coll/iallgatherv/iallgatherv_intra_sched_recursive_doubling.c b/src/mpi/coll/iallgatherv/iallgatherv_intra_sched_recursive_doubling.c index a178c94799d..da0f616f048 100644 --- a/src/mpi/coll/iallgatherv/iallgatherv_intra_sched_recursive_doubling.c +++ b/src/mpi/coll/iallgatherv/iallgatherv_intra_sched_recursive_doubling.c @@ -9,7 +9,8 @@ int MPIR_Iallgatherv_intra_sched_recursive_doubling(const void *sendbuf, MPI_Ain MPI_Datatype sendtype, void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint displs[], MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int comm_size, rank, i, j, k; @@ -17,8 +18,7 @@ int MPIR_Iallgatherv_intra_sched_recursive_doubling(const void *sendbuf, MPI_Ain MPI_Aint recvtype_extent, recvtype_sz, position, offset; void *tmp_buf = NULL; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); #ifdef HAVE_ERROR_CHECKING /* Currently this algorithm can only handle power-of-2 comm_size. @@ -108,11 +108,13 @@ int MPIR_Iallgatherv_intra_sched_recursive_doubling(const void *sendbuf, MPI_Ain incoming_count += recvcounts[j]; mpi_errno = MPIR_Sched_send(((char *) tmp_buf + send_offset * recvtype_sz), - curr_count * recvtype_sz, MPI_BYTE, dst, comm_ptr, s); + curr_count * recvtype_sz, MPI_BYTE, dst, comm_ptr, + coll_group, s); MPIR_ERR_CHECK(mpi_errno); /* sendrecv, no barrier here */ mpi_errno = MPIR_Sched_recv(((char *) tmp_buf + recv_offset * recvtype_sz), - incoming_count * recvtype_sz, MPI_BYTE, dst, comm_ptr, s); + incoming_count * recvtype_sz, MPI_BYTE, dst, comm_ptr, + coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); @@ -177,7 +179,7 @@ int MPIR_Iallgatherv_intra_sched_recursive_doubling(const void *sendbuf, MPI_Ain * sent now. */ mpi_errno = MPIR_Sched_send(((char *) tmp_buf + offset), incoming_count * recvtype_sz, MPI_BYTE, - dst, comm_ptr, s); + dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } @@ -199,7 +201,7 @@ int MPIR_Iallgatherv_intra_sched_recursive_doubling(const void *sendbuf, MPI_Ain mpi_errno = MPIR_Sched_recv(((char *) tmp_buf + offset * recvtype_sz), incoming_count * recvtype_sz, MPI_BYTE, - dst, comm_ptr, s); + dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); curr_count += incoming_count; diff --git a/src/mpi/coll/iallgatherv/iallgatherv_intra_sched_ring.c b/src/mpi/coll/iallgatherv/iallgatherv_intra_sched_ring.c index 76b390f94b3..5eefd3de298 100644 --- a/src/mpi/coll/iallgatherv/iallgatherv_intra_sched_ring.c +++ b/src/mpi/coll/iallgatherv/iallgatherv_intra_sched_ring.c @@ -8,7 +8,8 @@ int MPIR_Iallgatherv_intra_sched_ring(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint displs[], - MPI_Datatype recvtype, MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPI_Datatype recvtype, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int i; @@ -18,8 +19,8 @@ int MPIR_Iallgatherv_intra_sched_ring(const void *sendbuf, MPI_Aint sendcount, char *sbuf = NULL; char *rbuf = NULL; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); + MPIR_Datatype_get_extent_macro(recvtype, recvtype_extent); total_count = 0; @@ -77,12 +78,12 @@ int MPIR_Iallgatherv_intra_sched_ring(const void *sendbuf, MPI_Aint sendcount, /* Communicate */ if (recvnow) { /* If there's no data to send, just do a recv call */ - mpi_errno = MPIR_Sched_recv(rbuf, recvnow, recvtype, left, comm_ptr, s); + mpi_errno = MPIR_Sched_recv(rbuf, recvnow, recvtype, left, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); torecv -= recvnow; } if (sendnow) { /* If there's no data to receive, just do a send call */ - mpi_errno = MPIR_Sched_send(sbuf, sendnow, recvtype, right, comm_ptr, s); + mpi_errno = MPIR_Sched_send(sbuf, sendnow, recvtype, right, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); tosend -= sendnow; } diff --git a/src/mpi/coll/iallgatherv/iallgatherv_tsp_brucks.c b/src/mpi/coll/iallgatherv/iallgatherv_tsp_brucks.c index 896cc02847e..0a72891cf7d 100644 --- a/src/mpi/coll/iallgatherv/iallgatherv_tsp_brucks.c +++ b/src/mpi/coll/iallgatherv/iallgatherv_tsp_brucks.c @@ -29,7 +29,7 @@ int MPIR_TSP_Iallgatherv_sched_intra_brucks(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint displs[], - MPI_Datatype recvtype, MPIR_Comm * comm, + MPI_Datatype recvtype, MPIR_Comm * comm, int coll_group, int k, MPIR_TSP_sched_t sched) { int i, j, l; @@ -64,12 +64,11 @@ MPIR_TSP_Iallgatherv_sched_intra_brucks(const void *sendbuf, MPI_Aint sendcount, /* For correctness, transport based collectives need to get the * tag from the same pool as schedule based collectives */ - mpi_errno = MPIR_Sched_next_tag(comm, &tag); + mpi_errno = MPIR_Sched_next_tag(comm, coll_group, &tag); MPIR_ERR_CHECK(mpi_errno); is_inplace = (sendbuf == MPI_IN_PLACE); - rank = MPIR_Comm_rank(comm); - size = MPIR_Comm_size(comm); + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, size); max = size - 1; if (is_inplace) { @@ -218,8 +217,8 @@ MPIR_TSP_Iallgatherv_sched_intra_brucks(const void *sendbuf, MPI_Aint sendcount, /* Recv at the exact location */ mpi_errno = MPIR_TSP_sched_irecv((char *) tmp_recvbuf + recv_index[idx] * recvtype_extent, - r_counts[i][j - 1], recvtype, src, tag, comm, sched, 0, NULL, - &vtx_id); + r_counts[i][j - 1], recvtype, src, tag, comm, coll_group, + sched, 0, NULL, &vtx_id); MPIR_ERR_CHECK(mpi_errno); recv_id[idx] = vtx_id; @@ -228,7 +227,7 @@ MPIR_TSP_Iallgatherv_sched_intra_brucks(const void *sendbuf, MPI_Aint sendcount, /* Send from the start of recv till the count amount of data */ mpi_errno = MPIR_TSP_sched_isend(tmp_recvbuf, s_counts[i][j - 1], recvtype, dst, tag, comm, - sched, n_invtcs, recv_id, &vtx_id); + coll_group, sched, n_invtcs, recv_id, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } n_invtcs += (k - 1); diff --git a/src/mpi/coll/iallgatherv/iallgatherv_tsp_recexch.c b/src/mpi/coll/iallgatherv/iallgatherv_tsp_recexch.c index 48505646110..1751ac4c091 100644 --- a/src/mpi/coll/iallgatherv/iallgatherv_tsp_recexch.c +++ b/src/mpi/coll/iallgatherv/iallgatherv_tsp_recexch.c @@ -13,7 +13,7 @@ static int MPIR_TSP_Iallgatherv_sched_intra_recexch_data_exchange(int rank, int size_t recv_extent, const MPI_Aint * recvcounts, const MPI_Aint * displs, int tag, - MPIR_Comm * comm, + MPIR_Comm * comm, int coll_group, MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; @@ -40,7 +40,7 @@ static int MPIR_TSP_Iallgatherv_sched_intra_recexch_data_exchange(int rank, int /* send my data to partner */ mpi_errno = MPIR_TSP_sched_isend(((char *) recvbuf + send_offset), send_count, recvtype, partner, - tag, comm, sched, 0, NULL, &vtx_id); + tag, comm, coll_group, sched, 0, NULL, &vtx_id); MPIR_ERR_CHECK(mpi_errno); /* calculate offset and count of the data to be received from the partner */ @@ -54,7 +54,7 @@ static int MPIR_TSP_Iallgatherv_sched_intra_recexch_data_exchange(int rank, int recv_offset, recv_count)); /* recv data from my partner */ mpi_errno = MPIR_TSP_sched_irecv(((char *) recvbuf + recv_offset), recv_count, recvtype, - partner, tag, comm, sched, 0, NULL, &vtx_id); + partner, tag, comm, coll_group, sched, 0, NULL, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } @@ -74,7 +74,7 @@ static int MPIR_TSP_Iallgatherv_sched_intra_recexch_step1(int step1_sendto, int const MPI_Aint * displs, MPI_Datatype recvtype, int n_invtcs, int *invtx, MPIR_Comm * comm, - MPIR_TSP_sched_t sched) + int coll_group, MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; int i, vtx_id; @@ -92,7 +92,7 @@ static int MPIR_TSP_Iallgatherv_sched_intra_recexch_step1(int step1_sendto, int buf_to_send = (void *) sendbuf; mpi_errno = MPIR_TSP_sched_isend(buf_to_send, recvcounts[rank], recvtype, step1_sendto, tag, comm, - sched, 0, NULL, &vtx_id); + coll_group, sched, 0, NULL, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } else { for (i = 0; i < step1_nrecvs; i++) { /* participating rank gets the data from non-participating rank */ @@ -100,7 +100,7 @@ static int MPIR_TSP_Iallgatherv_sched_intra_recexch_step1(int step1_sendto, int mpi_errno = MPIR_TSP_sched_irecv(((char *) recvbuf + recv_offset), recvcounts[step1_recvfrom[i]], recvtype, step1_recvfrom[i], - tag, comm, sched, n_invtcs, invtx, &vtx_id); + tag, comm, coll_group, sched, n_invtcs, invtx, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } } @@ -120,7 +120,7 @@ int MPIR_TSP_Iallgatherv_sched_intra_recexch_step2(int step1_sendto, int step2_n size_t recv_extent, const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, int is_dist_halving, MPIR_Comm * comm, - MPIR_TSP_sched_t sched) + int coll_group, MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; int phase, i, j, count, nbr, offset, rank_for_offset; @@ -152,7 +152,8 @@ int MPIR_TSP_Iallgatherv_sched_intra_recexch_step2(int step1_sendto, int step2_n for (x = 0; x < count; x++) send_count += recvcounts[offset + x]; mpi_errno = MPIR_TSP_sched_isend(((char *) recvbuf + send_offset), send_count, recvtype, - nbr, tag, comm, sched, nrecvs, recv_id, &vtx_id); + nbr, tag, comm, coll_group, sched, nrecvs, recv_id, + &vtx_id); MPIR_ERR_CHECK(mpi_errno); MPL_DBG_MSG_FMT(MPIR_DBG_COLL, VERBOSE, (MPL_DBG_FDEST, @@ -173,7 +174,7 @@ int MPIR_TSP_Iallgatherv_sched_intra_recexch_step2(int step1_sendto, int step2_n recv_count += recvcounts[offset + x]; mpi_errno = MPIR_TSP_sched_irecv(((char *) recvbuf + recv_offset), recv_count, recvtype, - nbr, tag, comm, sched, 0, NULL, &vtx_id); + nbr, tag, comm, coll_group, sched, 0, NULL, &vtx_id); MPIR_ERR_CHECK(mpi_errno); recv_id[j * (k - 1) + i] = vtx_id; @@ -206,7 +207,7 @@ static int MPIR_TSP_Iallgatherv_sched_intra_recexch_step3(int step1_sendto, int const MPI_Aint * recvcounts, int nranks, int k, int nrecvs, int *recv_id, int tag, MPI_Datatype recvtype, MPIR_Comm * comm, - MPIR_TSP_sched_t sched) + int coll_group, MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; MPI_Aint total_count = 0; @@ -221,14 +222,14 @@ static int MPIR_TSP_Iallgatherv_sched_intra_recexch_step3(int step1_sendto, int if (step1_sendto != -1) { mpi_errno = - MPIR_TSP_sched_irecv(recvbuf, total_count, recvtype, step1_sendto, tag, comm, sched, 0, - NULL, &vtx_id); + MPIR_TSP_sched_irecv(recvbuf, total_count, recvtype, step1_sendto, tag, comm, + coll_group, sched, 0, NULL, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } for (i = 0; i < step1_nrecvs; i++) { mpi_errno = MPIR_TSP_sched_isend(recvbuf, total_count, recvtype, step1_recvfrom[i], - tag, comm, sched, nrecvs, recv_id, &vtx_id); + tag, comm, coll_group, sched, nrecvs, recv_id, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } @@ -244,7 +245,8 @@ int MPIR_TSP_Iallgatherv_sched_intra_recexch(const void *sendbuf, MPI_Aint sendc MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, MPIR_Comm * comm, - int is_dist_halving, int k, MPIR_TSP_sched_t sched) + int coll_group, int is_dist_halving, int k, + MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; int is_inplace, i; @@ -263,8 +265,7 @@ int MPIR_TSP_Iallgatherv_sched_intra_recexch(const void *sendbuf, MPI_Aint sendc MPIR_FUNC_ENTER; is_inplace = (sendbuf == MPI_IN_PLACE); - nranks = MPIR_Comm_size(comm); - rank = MPIR_Comm_rank(comm); + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, nranks); MPIR_Datatype_get_extent_macro(recvtype, recv_extent); MPIR_Type_get_true_extent_impl(recvtype, &recv_lb, &true_extent); @@ -272,7 +273,7 @@ int MPIR_TSP_Iallgatherv_sched_intra_recexch(const void *sendbuf, MPI_Aint sendc /* For correctness, transport based collectives need to get the * tag from the same pool as schedule based collectives */ - mpi_errno = MPIR_Sched_next_tag(comm, &tag); + mpi_errno = MPIR_Sched_next_tag(comm, coll_group, &tag); MPIR_ERR_CHECK(mpi_errno); /* get the neighbors, the function allocates the required memory */ @@ -300,7 +301,7 @@ int MPIR_TSP_Iallgatherv_sched_intra_recexch(const void *sendbuf, MPI_Aint sendc MPIR_TSP_Iallgatherv_sched_intra_recexch_step1(step1_sendto, step1_recvfrom, step1_nrecvs, is_inplace, rank, tag, sendbuf, recvbuf, recv_extent, recvcounts, displs, recvtype, - n_invtcs, &invtx, comm, sched); + n_invtcs, &invtx, comm, coll_group, sched); mpi_errno = MPIR_TSP_sched_fence(sched); MPIR_ERR_CHECK(mpi_errno); @@ -311,7 +312,7 @@ int MPIR_TSP_Iallgatherv_sched_intra_recexch(const void *sendbuf, MPI_Aint sendc MPIR_TSP_Iallgatherv_sched_intra_recexch_data_exchange(rank, nranks, k, p_of_k, log_pofk, T, recvbuf, recvtype, recv_extent, recvcounts, displs, - tag, comm, sched); + tag, comm, coll_group, sched); mpi_errno = MPIR_TSP_sched_fence(sched); MPIR_ERR_CHECK(mpi_errno); } @@ -321,13 +322,15 @@ int MPIR_TSP_Iallgatherv_sched_intra_recexch(const void *sendbuf, MPI_Aint sendc MPIR_TSP_Iallgatherv_sched_intra_recexch_step2(step1_sendto, step2_nphases, step2_nbrs, rank, nranks, k, p_of_k, log_pofk, T, &nrecvs, &recv_id, tag, recvbuf, recv_extent, recvcounts, - displs, recvtype, is_dist_halving, comm, sched); + displs, recvtype, is_dist_halving, comm, + coll_group, sched); /* Step 3: This is reverse of Step 1. Ranks that participated in Step 2 * send the data to non-partcipating ranks */ MPIR_TSP_Iallgatherv_sched_intra_recexch_step3(step1_sendto, step1_recvfrom, step1_nrecvs, step2_nphases, recvbuf, recvcounts, nranks, k, - nrecvs, recv_id, tag, recvtype, comm, sched); + nrecvs, recv_id, tag, recvtype, comm, coll_group, + sched); fn_exit: /* free the memory */ diff --git a/src/mpi/coll/iallgatherv/iallgatherv_tsp_ring.c b/src/mpi/coll/iallgatherv/iallgatherv_tsp_ring.c index dfcdf0e81c1..d014fa777a2 100644 --- a/src/mpi/coll/iallgatherv/iallgatherv_tsp_ring.c +++ b/src/mpi/coll/iallgatherv/iallgatherv_tsp_ring.c @@ -10,7 +10,7 @@ int MPIR_TSP_Iallgatherv_sched_intra_ring(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, - MPI_Datatype recvtype, MPIR_Comm * comm, + MPI_Datatype recvtype, MPIR_Comm * comm, int coll_group, MPIR_TSP_sched_t sched) { size_t extent; @@ -26,8 +26,7 @@ int MPIR_TSP_Iallgatherv_sched_intra_ring(const void *sendbuf, MPI_Aint sendcoun MPIR_FUNC_ENTER; is_inplace = (sendbuf == MPI_IN_PLACE); - nranks = MPIR_Comm_size(comm); - rank = MPIR_Comm_rank(comm); + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, nranks); /* find out the buffer which has the send data and point data_buf to it */ if (is_inplace) { @@ -85,7 +84,7 @@ int MPIR_TSP_Iallgatherv_sched_intra_ring(const void *sendbuf, MPI_Aint sendcoun send_rank = (rank - i + nranks) % nranks; /* Rank whose data you're sending */ /* New tag for each send-recv pair. */ - mpi_errno = MPIR_Sched_next_tag(comm, &tag); + mpi_errno = MPIR_Sched_next_tag(comm, coll_group, &tag); MPIR_ERR_CHECK(mpi_errno); int nvtcs, vtcs[3]; @@ -94,8 +93,8 @@ int MPIR_TSP_Iallgatherv_sched_intra_ring(const void *sendbuf, MPI_Aint sendcoun vtcs[0] = dtcopy_id[0]; mpi_errno = - MPIR_TSP_sched_isend(sbuf, recvcounts[send_rank], recvtype, dst, tag, comm, sched, - nvtcs, vtcs, &send_id[i % 3]); + MPIR_TSP_sched_isend(sbuf, recvcounts[send_rank], recvtype, dst, tag, comm, + coll_group, sched, nvtcs, vtcs, &send_id[i % 3]); MPIR_ERR_CHECK(mpi_errno); nvtcs = 0; } else { @@ -104,8 +103,8 @@ int MPIR_TSP_Iallgatherv_sched_intra_ring(const void *sendbuf, MPI_Aint sendcoun vtcs[1] = send_id[(i - 1) % 3]; mpi_errno = - MPIR_TSP_sched_isend(sbuf, recvcounts[send_rank], recvtype, dst, tag, comm, sched, - nvtcs, vtcs, &send_id[i % 3]); + MPIR_TSP_sched_isend(sbuf, recvcounts[send_rank], recvtype, dst, tag, comm, + coll_group, sched, nvtcs, vtcs, &send_id[i % 3]); MPIR_ERR_CHECK(mpi_errno); if (i == 1) { @@ -121,8 +120,8 @@ int MPIR_TSP_Iallgatherv_sched_intra_ring(const void *sendbuf, MPI_Aint sendcoun } mpi_errno = - MPIR_TSP_sched_irecv(rbuf, recvcounts[recv_rank], recvtype, src, tag, comm, sched, - nvtcs, vtcs, &recv_id[i % 3]); + MPIR_TSP_sched_irecv(rbuf, recvcounts[recv_rank], recvtype, src, tag, comm, coll_group, + sched, nvtcs, vtcs, &recv_id[i % 3]); MPIR_ERR_CHECK(mpi_errno); /* Copy to correct position in recvbuf */ mpi_errno = diff --git a/src/mpi/coll/iallreduce/iallreduce_inter_sched_remote_reduce_local_bcast.c b/src/mpi/coll/iallreduce/iallreduce_inter_sched_remote_reduce_local_bcast.c index 1e4e312c83f..6ce03a2255c 100644 --- a/src/mpi/coll/iallreduce/iallreduce_inter_sched_remote_reduce_local_bcast.c +++ b/src/mpi/coll/iallreduce/iallreduce_inter_sched_remote_reduce_local_bcast.c @@ -19,7 +19,7 @@ int MPIR_Iallreduce_inter_sched_remote_reduce_local_bcast(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, - MPIR_Sched_t s) + int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int rank, root; @@ -35,7 +35,8 @@ int MPIR_Iallreduce_inter_sched_remote_reduce_local_bcast(const void *sendbuf, v /* reduce from right group to rank 0 */ root = (rank == 0) ? MPI_ROOT : MPI_PROC_NULL; mpi_errno = - MPIR_Ireduce_inter_sched_auto(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, s); + MPIR_Ireduce_inter_sched_auto(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, + coll_group, s); MPIR_ERR_CHECK(mpi_errno); /* no barrier, these reductions can be concurrent */ @@ -43,13 +44,15 @@ int MPIR_Iallreduce_inter_sched_remote_reduce_local_bcast(const void *sendbuf, v /* reduce to rank 0 of right group */ root = 0; mpi_errno = - MPIR_Ireduce_inter_sched_auto(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, s); + MPIR_Ireduce_inter_sched_auto(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, + coll_group, s); MPIR_ERR_CHECK(mpi_errno); } else { /* reduce to rank 0 of left group */ root = 0; mpi_errno = - MPIR_Ireduce_inter_sched_auto(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, s); + MPIR_Ireduce_inter_sched_auto(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, + coll_group, s); MPIR_ERR_CHECK(mpi_errno); /* no barrier, these reductions can be concurrent */ @@ -57,7 +60,8 @@ int MPIR_Iallreduce_inter_sched_remote_reduce_local_bcast(const void *sendbuf, v /* reduce from right group to rank 0 */ root = (rank == 0) ? MPI_ROOT : MPI_PROC_NULL; mpi_errno = - MPIR_Ireduce_inter_sched_auto(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, s); + MPIR_Ireduce_inter_sched_auto(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, + coll_group, s); MPIR_ERR_CHECK(mpi_errno); } @@ -72,7 +76,7 @@ int MPIR_Iallreduce_inter_sched_remote_reduce_local_bcast(const void *sendbuf, v } lcomm_ptr = comm_ptr->local_comm; - mpi_errno = MPIR_Ibcast_intra_sched_auto(recvbuf, count, datatype, 0, lcomm_ptr, s); + mpi_errno = MPIR_Ibcast_intra_sched_auto(recvbuf, count, datatype, 0, lcomm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); fn_exit: diff --git a/src/mpi/coll/iallreduce/iallreduce_intra_sched_naive.c b/src/mpi/coll/iallreduce/iallreduce_intra_sched_naive.c index 36a26b6e42c..eb5544c70bf 100644 --- a/src/mpi/coll/iallreduce/iallreduce_intra_sched_naive.c +++ b/src/mpi/coll/iallreduce/iallreduce_intra_sched_naive.c @@ -8,26 +8,30 @@ /* implements the naive intracomm allreduce, that is, reduce followed by bcast */ int MPIR_Iallreduce_intra_sched_naive(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, - MPIR_Sched_t s) + int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; - int rank; + int rank, comm_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); + if (comm_size == 1) + goto fn_exit; if ((sendbuf == MPI_IN_PLACE) && (rank != 0)) { mpi_errno = - MPIR_Ireduce_intra_sched_auto(recvbuf, NULL, count, datatype, op, 0, comm_ptr, s); + MPIR_Ireduce_intra_sched_auto(recvbuf, NULL, count, datatype, op, 0, comm_ptr, + coll_group, s); MPIR_ERR_CHECK(mpi_errno); } else { mpi_errno = - MPIR_Ireduce_intra_sched_auto(sendbuf, recvbuf, count, datatype, op, 0, comm_ptr, s); + MPIR_Ireduce_intra_sched_auto(sendbuf, recvbuf, count, datatype, op, 0, comm_ptr, + coll_group, s); MPIR_ERR_CHECK(mpi_errno); } MPIR_SCHED_BARRIER(s); - mpi_errno = MPIR_Ibcast_intra_sched_auto(recvbuf, count, datatype, 0, comm_ptr, s); + mpi_errno = MPIR_Ibcast_intra_sched_auto(recvbuf, count, datatype, 0, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); fn_exit: diff --git a/src/mpi/coll/iallreduce/iallreduce_intra_sched_recursive_doubling.c b/src/mpi/coll/iallreduce/iallreduce_intra_sched_recursive_doubling.c index b0a08613efd..bdb79e3fd0a 100644 --- a/src/mpi/coll/iallreduce/iallreduce_intra_sched_recursive_doubling.c +++ b/src/mpi/coll/iallreduce/iallreduce_intra_sched_recursive_doubling.c @@ -7,7 +7,8 @@ int MPIR_Iallreduce_intra_sched_recursive_doubling(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int pof2, rem, comm_size, is_commutative, rank; @@ -15,8 +16,7 @@ int MPIR_Iallreduce_intra_sched_recursive_doubling(const void *sendbuf, void *re MPI_Aint true_lb, true_extent, extent; void *tmp_buf = NULL; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); is_commutative = MPIR_Op_is_commutative(op); @@ -38,7 +38,7 @@ int MPIR_Iallreduce_intra_sched_recursive_doubling(const void *sendbuf, void *re } /* get nearest power-of-two less than or equal to comm_size */ - pof2 = comm_ptr->coll.pof2; + pof2 = MPL_pof2(comm_size); rem = comm_size - pof2; @@ -50,7 +50,8 @@ int MPIR_Iallreduce_intra_sched_recursive_doubling(const void *sendbuf, void *re if (rank < 2 * rem) { if (rank % 2 == 0) { /* even */ - mpi_errno = MPIR_Sched_send(recvbuf, count, datatype, rank + 1, comm_ptr, s); + mpi_errno = + MPIR_Sched_send(recvbuf, count, datatype, rank + 1, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); @@ -59,7 +60,8 @@ int MPIR_Iallreduce_intra_sched_recursive_doubling(const void *sendbuf, void *re * doubling */ newrank = -1; } else { /* odd */ - mpi_errno = MPIR_Sched_recv(tmp_buf, count, datatype, rank - 1, comm_ptr, s); + mpi_errno = + MPIR_Sched_recv(tmp_buf, count, datatype, rank - 1, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); @@ -85,10 +87,10 @@ int MPIR_Iallreduce_intra_sched_recursive_doubling(const void *sendbuf, void *re /* Send the most current data, which is in recvbuf. Recv * into tmp_buf */ - mpi_errno = MPIR_Sched_recv(tmp_buf, count, datatype, dst, comm_ptr, s); + mpi_errno = MPIR_Sched_recv(tmp_buf, count, datatype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); /* sendrecv, no barrier here */ - mpi_errno = MPIR_Sched_send(recvbuf, count, datatype, dst, comm_ptr, s); + mpi_errno = MPIR_Sched_send(recvbuf, count, datatype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); @@ -120,10 +122,12 @@ int MPIR_Iallreduce_intra_sched_recursive_doubling(const void *sendbuf, void *re * (rank-1), the ranks who didn't participate above. */ if (rank < 2 * rem) { if (rank % 2) { /* odd */ - mpi_errno = MPIR_Sched_send(recvbuf, count, datatype, rank - 1, comm_ptr, s); + mpi_errno = + MPIR_Sched_send(recvbuf, count, datatype, rank - 1, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } else { /* even */ - mpi_errno = MPIR_Sched_recv(recvbuf, count, datatype, rank + 1, comm_ptr, s); + mpi_errno = + MPIR_Sched_recv(recvbuf, count, datatype, rank + 1, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } } diff --git a/src/mpi/coll/iallreduce/iallreduce_intra_sched_reduce_scatter_allgather.c b/src/mpi/coll/iallreduce/iallreduce_intra_sched_reduce_scatter_allgather.c index 75bab1d9a84..803961e4bb4 100644 --- a/src/mpi/coll/iallreduce/iallreduce_intra_sched_reduce_scatter_allgather.c +++ b/src/mpi/coll/iallreduce/iallreduce_intra_sched_reduce_scatter_allgather.c @@ -9,7 +9,7 @@ int MPIR_Iallreduce_intra_sched_reduce_scatter_allgather(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, - MPIR_Sched_t s) + int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int comm_size, rank, newrank, pof2, rem; @@ -24,8 +24,7 @@ int MPIR_Iallreduce_intra_sched_reduce_scatter_allgather(const void *sendbuf, vo MPIR_Assert(HANDLE_IS_BUILTIN(op)); #endif - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); /* need to allocate temporary buffer to store incoming data */ MPIR_Type_get_true_extent_impl(datatype, &true_lb, &true_extent); @@ -45,7 +44,7 @@ int MPIR_Iallreduce_intra_sched_reduce_scatter_allgather(const void *sendbuf, vo } /* get nearest power-of-two less than or equal to comm_size */ - pof2 = comm_ptr->coll.pof2; + pof2 = MPL_pof2(comm_size); rem = comm_size - pof2; @@ -57,7 +56,8 @@ int MPIR_Iallreduce_intra_sched_reduce_scatter_allgather(const void *sendbuf, vo if (rank < 2 * rem) { if (rank % 2 == 0) { /* even */ - mpi_errno = MPIR_Sched_send(recvbuf, count, datatype, rank + 1, comm_ptr, s); + mpi_errno = + MPIR_Sched_send(recvbuf, count, datatype, rank + 1, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); @@ -66,7 +66,8 @@ int MPIR_Iallreduce_intra_sched_reduce_scatter_allgather(const void *sendbuf, vo * doubling */ newrank = -1; } else { /* odd */ - mpi_errno = MPIR_Sched_recv(tmp_buf, count, datatype, rank - 1, comm_ptr, s); + mpi_errno = + MPIR_Sched_recv(tmp_buf, count, datatype, rank - 1, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); @@ -133,11 +134,11 @@ int MPIR_Iallreduce_intra_sched_reduce_scatter_allgather(const void *sendbuf, vo /* Send data from recvbuf. Recv into tmp_buf */ mpi_errno = MPIR_Sched_recv(((char *) tmp_buf + disps[recv_idx] * extent), - recv_cnt, datatype, dst, comm_ptr, s); + recv_cnt, datatype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); /* sendrecv, no barrier here */ mpi_errno = MPIR_Sched_send(((char *) recvbuf + disps[send_idx] * extent), - send_cnt, datatype, dst, comm_ptr, s); + send_cnt, datatype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); @@ -192,11 +193,11 @@ int MPIR_Iallreduce_intra_sched_reduce_scatter_allgather(const void *sendbuf, vo } mpi_errno = MPIR_Sched_recv(((char *) recvbuf + disps[recv_idx] * extent), - recv_cnt, datatype, dst, comm_ptr, s); + recv_cnt, datatype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); /* sendrecv, no barrier here */ mpi_errno = MPIR_Sched_send(((char *) recvbuf + disps[send_idx] * extent), - send_cnt, datatype, dst, comm_ptr, s); + send_cnt, datatype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); @@ -212,10 +213,12 @@ int MPIR_Iallreduce_intra_sched_reduce_scatter_allgather(const void *sendbuf, vo * (rank-1), the ranks who didn't participate above. */ if (rank < 2 * rem) { if (rank % 2) { /* odd */ - mpi_errno = MPIR_Sched_send(recvbuf, count, datatype, rank - 1, comm_ptr, s); + mpi_errno = + MPIR_Sched_send(recvbuf, count, datatype, rank - 1, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } else { /* even */ - mpi_errno = MPIR_Sched_recv(recvbuf, count, datatype, rank + 1, comm_ptr, s); + mpi_errno = + MPIR_Sched_recv(recvbuf, count, datatype, rank + 1, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } } diff --git a/src/mpi/coll/iallreduce/iallreduce_intra_sched_smp.c b/src/mpi/coll/iallreduce/iallreduce_intra_sched_smp.c index 20fb4fa88c9..31d4767108a 100644 --- a/src/mpi/coll/iallreduce/iallreduce_intra_sched_smp.c +++ b/src/mpi/coll/iallreduce/iallreduce_intra_sched_smp.c @@ -8,17 +8,14 @@ int MPIR_Iallreduce_intra_sched_smp(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, - MPIR_Sched_t s) + int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int is_commutative; - MPIR_Comm *nc; - MPIR_Comm *nrc; - MPIR_Assert(MPIR_Comm_is_parent_comm(comm_ptr)); - - nc = comm_ptr->node_comm; - nrc = comm_ptr->node_roots_comm; + MPIR_Assert(MPIR_Comm_is_parent_comm(comm_ptr, coll_group)); + int local_rank = comm_ptr->subgroups[MPIR_SUBGROUP_NODE].rank; + int local_size = comm_ptr->subgroups[MPIR_SUBGROUP_NODE].size; is_commutative = MPIR_Op_is_commutative(op); @@ -26,24 +23,26 @@ int MPIR_Iallreduce_intra_sched_smp(const void *sendbuf, void *recvbuf, MPI_Aint if (!is_commutative) { /* use flat fallback */ mpi_errno = - MPIR_Iallreduce_intra_sched_auto(sendbuf, recvbuf, count, datatype, op, comm_ptr, s); + MPIR_Iallreduce_intra_sched_auto(sendbuf, recvbuf, count, datatype, op, comm_ptr, + coll_group, s); MPIR_ERR_CHECK(mpi_errno); goto fn_exit; } /* on each node, do a reduce to the local root */ - if (nc != NULL) { + if (local_size > 1) { /* take care of the MPI_IN_PLACE case. For reduce, * MPI_IN_PLACE is specified only on the root; * for allreduce it is specified on all processes. */ - if ((sendbuf == MPI_IN_PLACE) && (comm_ptr->node_comm->rank != 0)) { + if ((sendbuf == MPI_IN_PLACE) && (local_rank != 0)) { /* IN_PLACE and not root of reduce. Data supplied to this * allreduce is in recvbuf. Pass that as the sendbuf to reduce. */ - mpi_errno = MPIR_Ireduce_intra_sched_auto(recvbuf, NULL, count, datatype, op, 0, nc, s); + mpi_errno = MPIR_Ireduce_intra_sched_auto(recvbuf, NULL, count, datatype, op, 0, + comm_ptr, MPIR_SUBGROUP_NODE, s); MPIR_ERR_CHECK(mpi_errno); } else { - mpi_errno = - MPIR_Ireduce_intra_sched_auto(sendbuf, recvbuf, count, datatype, op, 0, nc, s); + mpi_errno = MPIR_Ireduce_intra_sched_auto(sendbuf, recvbuf, count, datatype, op, 0, + comm_ptr, MPIR_SUBGROUP_NODE, s); MPIR_ERR_CHECK(mpi_errno); } MPIR_SCHED_BARRIER(s); @@ -57,16 +56,17 @@ int MPIR_Iallreduce_intra_sched_smp(const void *sendbuf, void *recvbuf, MPI_Aint } /* now do an IN_PLACE allreduce among the local roots of all nodes */ - if (nrc != NULL) { - mpi_errno = - MPIR_Iallreduce_intra_sched_auto(MPI_IN_PLACE, recvbuf, count, datatype, op, nrc, s); + if (local_rank == 0) { + mpi_errno = MPIR_Iallreduce_intra_sched_auto(MPI_IN_PLACE, recvbuf, count, datatype, op, + comm_ptr, MPIR_SUBGROUP_NODE_CROSS, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } /* now broadcast the result among local processes */ - if (comm_ptr->node_comm != NULL) { - mpi_errno = MPIR_Ibcast_intra_sched_auto(recvbuf, count, datatype, 0, nc, s); + if (local_size > 1) { + mpi_errno = MPIR_Ibcast_intra_sched_auto(recvbuf, count, datatype, 0, + comm_ptr, MPIR_SUBGROUP_NODE, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } diff --git a/src/mpi/coll/iallreduce/iallreduce_tsp_auto.c b/src/mpi/coll/iallreduce/iallreduce_tsp_auto.c index be266211063..47de5324fe6 100644 --- a/src/mpi/coll/iallreduce/iallreduce_tsp_auto.c +++ b/src/mpi/coll/iallreduce/iallreduce_tsp_auto.c @@ -10,7 +10,8 @@ /* Routine to schedule a pipelined tree based allreduce */ int MPIR_TSP_Iallreduce_sched_intra_tsp_auto(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm, MPIR_TSP_sched_t sched) + MPIR_Comm * comm, int coll_group, + MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; int is_commutative = MPIR_Op_is_commutative(op); @@ -21,6 +22,7 @@ int MPIR_TSP_Iallreduce_sched_intra_tsp_auto(const void *sendbuf, void *recvbuf, MPIR_Csel_coll_sig_s coll_sig = { .coll_type = MPIR_CSEL_COLL_TYPE__IALLREDUCE, .comm_ptr = comm, + .coll_group = coll_group, .u.iallreduce.sendbuf = sendbuf, .u.iallreduce.recvbuf = recvbuf, @@ -35,7 +37,7 @@ int MPIR_TSP_Iallreduce_sched_intra_tsp_auto(const void *sendbuf, void *recvbuf, case MPIR_CVAR_IALLREDUCE_INTRA_ALGORITHM_tsp_recexch_single_buffer: mpi_errno = MPIR_TSP_Iallreduce_sched_intra_recexch(sendbuf, recvbuf, count, - datatype, op, comm, + datatype, op, comm, coll_group, MPIR_IALLREDUCE_RECEXCH_TYPE_SINGLE_BUFFER, MPIR_CVAR_IALLREDUCE_RECEXCH_KVAL, sched); break; @@ -43,7 +45,7 @@ int MPIR_TSP_Iallreduce_sched_intra_tsp_auto(const void *sendbuf, void *recvbuf, case MPIR_CVAR_IALLREDUCE_INTRA_ALGORITHM_tsp_recexch_multiple_buffer: mpi_errno = MPIR_TSP_Iallreduce_sched_intra_recexch(sendbuf, recvbuf, count, - datatype, op, comm, + datatype, op, comm, coll_group, MPIR_IALLREDUCE_RECEXCH_TYPE_MULTIPLE_BUFFER, MPIR_CVAR_IALLREDUCE_RECEXCH_KVAL, sched); break; @@ -56,7 +58,7 @@ int MPIR_TSP_Iallreduce_sched_intra_tsp_auto(const void *sendbuf, void *recvbuf, "Iallreduce gentran_tree cannot be applied.\n"); mpi_errno = MPIR_TSP_Iallreduce_sched_intra_tree(sendbuf, recvbuf, count, datatype, op, - comm, MPIR_Iallreduce_tree_type, + comm, coll_group, MPIR_Iallreduce_tree_type, MPIR_CVAR_IALLREDUCE_TREE_KVAL, MPIR_CVAR_IALLREDUCE_TREE_PIPELINE_CHUNK_SIZE, MPIR_CVAR_IALLREDUCE_TREE_BUFFER_PER_CHILD, @@ -68,7 +70,7 @@ int MPIR_TSP_Iallreduce_sched_intra_tsp_auto(const void *sendbuf, void *recvbuf, "Iallreduce gentran_ring cannot be applied.\n"); mpi_errno = MPIR_TSP_Iallreduce_sched_intra_ring(sendbuf, recvbuf, count, datatype, - op, comm, sched); + op, comm, coll_group, sched); break; case MPIR_CVAR_IALLREDUCE_INTRA_ALGORITHM_tsp_recexch_reduce_scatter_recexch_allgatherv: /* This algorithm will work for commutative @@ -86,6 +88,7 @@ int MPIR_TSP_Iallreduce_sched_intra_tsp_auto(const void *sendbuf, void *recvbuf, datatype, op, comm, + coll_group, MPIR_CVAR_IALLREDUCE_RECEXCH_KVAL, sched); break; @@ -97,7 +100,7 @@ int MPIR_TSP_Iallreduce_sched_intra_tsp_auto(const void *sendbuf, void *recvbuf, case MPII_CSEL_CONTAINER_TYPE__ALGORITHM__MPIR_Iallreduce_intra_tsp_recexch_single_buffer: mpi_errno = MPIR_TSP_Iallreduce_sched_intra_recexch(sendbuf, recvbuf, count, - datatype, op, comm, + datatype, op, comm, coll_group, MPIR_IALLREDUCE_RECEXCH_TYPE_SINGLE_BUFFER, cnt->u. iallreduce.intra_tsp_recexch_single_buffer. @@ -107,7 +110,7 @@ int MPIR_TSP_Iallreduce_sched_intra_tsp_auto(const void *sendbuf, void *recvbuf, case MPII_CSEL_CONTAINER_TYPE__ALGORITHM__MPIR_Iallreduce_intra_tsp_recexch_multiple_buffer: mpi_errno = MPIR_TSP_Iallreduce_sched_intra_recexch(sendbuf, recvbuf, count, - datatype, op, comm, + datatype, op, comm, coll_group, MPIR_IALLREDUCE_RECEXCH_TYPE_MULTIPLE_BUFFER, cnt->u. iallreduce.intra_tsp_recexch_single_buffer. @@ -117,7 +120,7 @@ int MPIR_TSP_Iallreduce_sched_intra_tsp_auto(const void *sendbuf, void *recvbuf, case MPII_CSEL_CONTAINER_TYPE__ALGORITHM__MPIR_Iallreduce_intra_tsp_tree: mpi_errno = MPIR_TSP_Iallreduce_sched_intra_tree(sendbuf, recvbuf, count, datatype, op, - comm, + comm, coll_group, cnt->u.iallreduce. intra_tsp_tree.tree_type, cnt->u.iallreduce.intra_tsp_tree.k, @@ -131,13 +134,13 @@ int MPIR_TSP_Iallreduce_sched_intra_tsp_auto(const void *sendbuf, void *recvbuf, case MPII_CSEL_CONTAINER_TYPE__ALGORITHM__MPIR_Iallreduce_intra_tsp_ring: mpi_errno = MPIR_TSP_Iallreduce_sched_intra_ring(sendbuf, recvbuf, count, datatype, op, - comm, sched); + comm, coll_group, sched); break; case MPII_CSEL_CONTAINER_TYPE__ALGORITHM__MPIR_Iallreduce_intra_tsp_recexch_reduce_scatter_recexch_allgatherv: mpi_errno = MPIR_TSP_Iallreduce_sched_intra_recexch_reduce_scatter_recexch_allgatherv - (sendbuf, recvbuf, count, datatype, op, comm, + (sendbuf, recvbuf, count, datatype, op, comm, coll_group, cnt->u.iallreduce.intra_tsp_recexch_reduce_scatter_recexch_allgatherv.k, sched); break; @@ -155,7 +158,7 @@ int MPIR_TSP_Iallreduce_sched_intra_tsp_auto(const void *sendbuf, void *recvbuf, fallback: mpi_errno = MPIR_TSP_Iallreduce_sched_intra_recexch(sendbuf, recvbuf, count, - datatype, op, comm, + datatype, op, comm, coll_group, MPIR_IALLREDUCE_RECEXCH_TYPE_MULTIPLE_BUFFER, MPIR_CVAR_IALLREDUCE_RECEXCH_KVAL, sched); fn_exit: diff --git a/src/mpi/coll/iallreduce/iallreduce_tsp_recexch.c b/src/mpi/coll/iallreduce/iallreduce_tsp_recexch.c index 91aa69af236..38679c4907a 100644 --- a/src/mpi/coll/iallreduce/iallreduce_tsp_recexch.c +++ b/src/mpi/coll/iallreduce/iallreduce_tsp_recexch.c @@ -11,8 +11,8 @@ /* Routine to schedule a recursive exchange based allreduce */ int MPIR_TSP_Iallreduce_sched_intra_recexch(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm, int per_nbr_buffer, int k, - MPIR_TSP_sched_t sched) + MPIR_Comm * comm, int coll_group, int per_nbr_buffer, + int k, MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; int is_inplace, i, j; @@ -39,8 +39,7 @@ int MPIR_TSP_Iallreduce_sched_intra_recexch(const void *sendbuf, void *recvbuf, MPIR_FUNC_ENTER; is_inplace = (sendbuf == MPI_IN_PLACE); - nranks = MPIR_Comm_size(comm); - rank = MPIR_Comm_rank(comm); + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, nranks); MPIR_Datatype_get_extent_macro(datatype, extent); MPIR_Type_get_true_extent_impl(datatype, &lb, &true_extent); @@ -51,7 +50,7 @@ int MPIR_TSP_Iallreduce_sched_intra_recexch(const void *sendbuf, void *recvbuf, /* For correctness, transport based collectives need to get the * tag from the same pool as schedule based collectives */ - mpi_errno = MPIR_Sched_next_tag(comm, &tag); + mpi_errno = MPIR_Sched_next_tag(comm, coll_group, &tag); /* get the neighbors, the function allocates the required memory */ MPII_Recexchalgo_get_neighbors(rank, nranks, &k, &step1_sendto, @@ -76,7 +75,7 @@ int MPIR_TSP_Iallreduce_sched_intra_recexch(const void *sendbuf, void *recvbuf, tag, extent, dtcopy_id, recv_id, reduce_id, vtcs, is_inplace, step1_sendto, in_step2, step1_nrecvs, step1_recvfrom, per_nbr_buffer, &step1_recvbuf, - comm, sched); + comm, coll_group, sched); mpi_errno = MPIR_TSP_sched_sink(sched, &step1_id); /* sink for all the tasks up to end of Step 1 */ if (mpi_errno) @@ -152,8 +151,8 @@ int MPIR_TSP_Iallreduce_sched_intra_recexch(const void *sendbuf, void *recvbuf, nbr = step2_nbrs[phase][i]; mpi_errno = - MPIR_TSP_sched_isend(tmp_buf, count, datatype, nbr, tag, comm, sched, nvtcs, vtcs, - &send_id[i]); + MPIR_TSP_sched_isend(tmp_buf, count, datatype, nbr, tag, comm, coll_group, sched, + nvtcs, vtcs, &send_id[i]); MPIR_ERR_CHECK(mpi_errno); if (rank > nbr) { myidx = i + 1; @@ -168,8 +167,8 @@ int MPIR_TSP_Iallreduce_sched_intra_recexch(const void *sendbuf, void *recvbuf, vtcs[nvtcs++] = (counter == 0) ? reduce_id[k - 2] : reduce_id[counter - 1]; } mpi_errno = - MPIR_TSP_sched_irecv(nbr_buffer[buf], count, datatype, nbr, tag, comm, sched, nvtcs, - vtcs, &recv_id[buf]); + MPIR_TSP_sched_irecv(nbr_buffer[buf], count, datatype, nbr, tag, comm, coll_group, + sched, nvtcs, vtcs, &recv_id[buf]); MPIR_ERR_CHECK(mpi_errno); if (count != 0) { @@ -196,8 +195,8 @@ int MPIR_TSP_Iallreduce_sched_intra_recexch(const void *sendbuf, void *recvbuf, vtcs[nvtcs++] = (counter == 0) ? reduce_id[k - 2] : reduce_id[counter - 1]; } mpi_errno = - MPIR_TSP_sched_irecv(nbr_buffer[buf], count, datatype, nbr, tag, comm, sched, nvtcs, - vtcs, &recv_id[buf]); + MPIR_TSP_sched_irecv(nbr_buffer[buf], count, datatype, nbr, tag, comm, coll_group, + sched, nvtcs, vtcs, &recv_id[buf]); MPIR_ERR_CHECK(mpi_errno); @@ -233,8 +232,8 @@ int MPIR_TSP_Iallreduce_sched_intra_recexch(const void *sendbuf, void *recvbuf, * send the data to non-partcipating ranks */ if (step1_sendto != -1) { /* I am a Step 2 non-participating rank */ mpi_errno = - MPIR_TSP_sched_irecv(recvbuf, count, datatype, step1_sendto, tag, comm, sched, 0, NULL, - &vtx_id); + MPIR_TSP_sched_irecv(recvbuf, count, datatype, step1_sendto, tag, comm, coll_group, + sched, 0, NULL, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } else { for (i = 0; i < step1_nrecvs; i++) { @@ -253,8 +252,8 @@ int MPIR_TSP_Iallreduce_sched_intra_recexch(const void *sendbuf, void *recvbuf, vtcs[0] = reduce_id[k - 2]; } mpi_errno = - MPIR_TSP_sched_isend(recvbuf, count, datatype, step1_recvfrom[i], tag, comm, sched, - nvtcs, vtcs, &vtx_id); + MPIR_TSP_sched_isend(recvbuf, count, datatype, step1_recvfrom[i], tag, comm, + coll_group, sched, nvtcs, vtcs, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } } diff --git a/src/mpi/coll/iallreduce/iallreduce_tsp_recexch_reduce_scatter_recexch_allgatherv.c b/src/mpi/coll/iallreduce/iallreduce_tsp_recexch_reduce_scatter_recexch_allgatherv.c index b741e42df12..2ecde2af9d9 100644 --- a/src/mpi/coll/iallreduce/iallreduce_tsp_recexch_reduce_scatter_recexch_allgatherv.c +++ b/src/mpi/coll/iallreduce/iallreduce_tsp_recexch_reduce_scatter_recexch_allgatherv.c @@ -15,7 +15,7 @@ int MPIR_TSP_Iallreduce_sched_intra_recexch_reduce_scatter_recexch_allgatherv(co MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, - int k, + int coll_group, int k, MPIR_TSP_sched_t sched) { @@ -45,8 +45,7 @@ int MPIR_TSP_Iallreduce_sched_intra_recexch_reduce_scatter_recexch_allgatherv(co MPIR_FUNC_ENTER; is_inplace = (sendbuf == MPI_IN_PLACE); - nranks = MPIR_Comm_size(comm); - rank = MPIR_Comm_rank(comm); + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, nranks); MPIR_Datatype_get_extent_macro(datatype, extent); MPIR_Type_get_true_extent_impl(datatype, &lb, &true_extent); @@ -55,7 +54,7 @@ int MPIR_TSP_Iallreduce_sched_intra_recexch_reduce_scatter_recexch_allgatherv(co /* For correctness, transport based collectives need to get the * tag from the same pool as schedule based collectives */ - mpi_errno = MPIR_Sched_next_tag(comm, &tag); + mpi_errno = MPIR_Sched_next_tag(comm, coll_group, &tag); MPIR_ERR_CHECK(mpi_errno); /* get the neighbors, the function allocates the required memory */ @@ -86,7 +85,7 @@ int MPIR_TSP_Iallreduce_sched_intra_recexch_reduce_scatter_recexch_allgatherv(co tag, extent, dtcopy_id, recv_id, reduce_id, vtcs, is_inplace, step1_sendto, in_step2, step1_nrecvs, step1_recvfrom, per_nbr_buffer, &step1_recvbuf, - comm, sched); + comm, coll_group, sched); mpi_errno = MPIR_TSP_sched_sink(sched, &sink_id); /* sink for all the tasks up to end of Step 1 */ if (mpi_errno) @@ -119,7 +118,7 @@ int MPIR_TSP_Iallreduce_sched_intra_recexch_reduce_scatter_recexch_allgatherv(co MPIR_TSP_Ireduce_scatter_sched_intra_recexch_step2(recvbuf, tmp_recvbuf, cnts, displs, datatype, op, extent, tag, - comm, k, redscat_algo_type, + comm, coll_group, k, redscat_algo_type, step2_nphases, step2_nbrs, rank, nranks, sink_id, 0, NULL, sched); @@ -128,7 +127,8 @@ int MPIR_TSP_Iallreduce_sched_intra_recexch_reduce_scatter_recexch_allgatherv(co MPIR_TSP_Iallgatherv_sched_intra_recexch_step2(step1_sendto, step2_nphases, step2_nbrs, rank, nranks, k, p_of_k, log_pofk, T, &nvtcs, &recv_id, tag, recvbuf, extent, cnts, displs, - datatype, allgather_algo_type, comm, sched); + datatype, allgather_algo_type, comm, + coll_group, sched); } @@ -138,14 +138,14 @@ int MPIR_TSP_Iallreduce_sched_intra_recexch_reduce_scatter_recexch_allgatherv(co * send the data to non-partcipating ranks */ if (step1_sendto != -1) { /* I am a Step 2 non-participating rank */ mpi_errno = - MPIR_TSP_sched_irecv(recvbuf, count, datatype, step1_sendto, tag, comm, sched, 1, - &sink_id, &vtx_id); + MPIR_TSP_sched_irecv(recvbuf, count, datatype, step1_sendto, tag, comm, coll_group, + sched, 1, &sink_id, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } else { for (i = 0; i < step1_nrecvs; i++) { mpi_errno = - MPIR_TSP_sched_isend(recvbuf, count, datatype, step1_recvfrom[i], tag, comm, sched, - nvtcs, recv_id, &vtx_id); + MPIR_TSP_sched_isend(recvbuf, count, datatype, step1_recvfrom[i], tag, comm, + coll_group, sched, nvtcs, recv_id, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } } diff --git a/src/mpi/coll/iallreduce/iallreduce_tsp_recursive_exchange_common.c b/src/mpi/coll/iallreduce/iallreduce_tsp_recursive_exchange_common.c index cb6a631cb6d..9e6632673fe 100644 --- a/src/mpi/coll/iallreduce/iallreduce_tsp_recursive_exchange_common.c +++ b/src/mpi/coll/iallreduce/iallreduce_tsp_recursive_exchange_common.c @@ -46,7 +46,7 @@ int MPIR_TSP_Iallreduce_sched_intra_recexch_step1(const void *sendbuf, int step1_sendto, bool in_step2, int step1_nrecvs, int *step1_recvfrom, int per_nbr_buffer, void ***step1_recvbuf_, MPIR_Comm * comm, - MPIR_TSP_sched_t sched) + int coll_group, MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; int i, nvtcs, vtx_id; @@ -62,8 +62,8 @@ int MPIR_TSP_Iallreduce_sched_intra_recexch_step1(const void *sendbuf, else buf_to_send = sendbuf; mpi_errno = - MPIR_TSP_sched_isend(buf_to_send, count, datatype, step1_sendto, tag, comm, sched, 0, - NULL, &vtx_id); + MPIR_TSP_sched_isend(buf_to_send, count, datatype, step1_sendto, tag, comm, coll_group, + sched, 0, NULL, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } else { /* Step 2 participating rank */ step1_recvbuf = *step1_recvbuf_ = @@ -89,8 +89,8 @@ int MPIR_TSP_Iallreduce_sched_intra_recexch_step1(const void *sendbuf, reduce_id[i - 1])); } mpi_errno = MPIR_TSP_sched_irecv(step1_recvbuf[i], count, datatype, - step1_recvfrom[i], tag, comm, sched, nvtcs, vtcs, - &recv_id[i]); + step1_recvfrom[i], tag, comm, coll_group, sched, nvtcs, + vtcs, &recv_id[i]); MPIR_ERR_CHECK(mpi_errno); if (count != 0) { /* Reduce only if data is present */ /* setup reduce dependencies */ diff --git a/src/mpi/coll/iallreduce/iallreduce_tsp_recursive_exchange_common.h b/src/mpi/coll/iallreduce/iallreduce_tsp_recursive_exchange_common.h index 47964cb2729..1e362a4c8c6 100644 --- a/src/mpi/coll/iallreduce/iallreduce_tsp_recursive_exchange_common.h +++ b/src/mpi/coll/iallreduce/iallreduce_tsp_recursive_exchange_common.h @@ -16,5 +16,5 @@ int MPIR_TSP_Iallreduce_sched_intra_recexch_step1(const void *sendbuf, int step1_sendto, bool in_step2, int step1_nrecvs, int *step1_recvfrom, int per_nbr_buffer, void ***step1_recvbuf_, MPIR_Comm * comm, - MPIR_TSP_sched_t sched); + int coll_group, MPIR_TSP_sched_t sched); #endif /* IALLREDUCE_TSP_RECURSIVE_EXCHANGE_COMMON_H_INCLUDED */ diff --git a/src/mpi/coll/iallreduce/iallreduce_tsp_ring.c b/src/mpi/coll/iallreduce/iallreduce_tsp_ring.c index fb136764b87..5906a710024 100644 --- a/src/mpi/coll/iallreduce/iallreduce_tsp_ring.c +++ b/src/mpi/coll/iallreduce/iallreduce_tsp_ring.c @@ -12,7 +12,7 @@ * explained here: http://andrew.gibiansky.com/ */ int MPIR_TSP_Iallreduce_sched_intra_ring(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm, MPIR_TSP_sched_t sched) + MPIR_Comm * comm, int coll_group, MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; int i, src, dst; @@ -30,8 +30,7 @@ int MPIR_TSP_Iallreduce_sched_intra_ring(const void *sendbuf, void *recvbuf, MPI MPIR_CHKLMEM_DECL(4); is_inplace = (sendbuf == MPI_IN_PLACE); - nranks = MPIR_Comm_size(comm); - rank = MPIR_Comm_rank(comm); + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, nranks); MPIR_Datatype_get_extent_macro(datatype, extent); MPIR_Type_get_true_extent_impl(datatype, &lb, &true_extent); @@ -82,14 +81,14 @@ int MPIR_TSP_Iallreduce_sched_intra_ring(const void *sendbuf, void *recvbuf, MPI send_rank = (nranks + rank - 1 - i) % nranks; /* get a new tag to prevent out of order messages */ - mpi_errno = MPIR_Sched_next_tag(comm, &tag); + mpi_errno = MPIR_Sched_next_tag(comm, coll_group, &tag); MPIR_ERR_CHECK(mpi_errno); nvtcs = (i == 0) ? 0 : 1; vtcs = (i == 0) ? 0 : reduce_id[(i - 1) % 2]; mpi_errno = - MPIR_TSP_sched_irecv(tmpbuf, cnts[recv_rank], datatype, src, tag, comm, sched, nvtcs, - &vtcs, &recv_id); + MPIR_TSP_sched_irecv(tmpbuf, cnts[recv_rank], datatype, src, tag, comm, coll_group, + sched, nvtcs, &vtcs, &recv_id); MPIR_ERR_CHECK(mpi_errno); mpi_errno = @@ -101,7 +100,8 @@ int MPIR_TSP_Iallreduce_sched_intra_ring(const void *sendbuf, void *recvbuf, MPI mpi_errno = MPIR_TSP_sched_isend((char *) recvbuf + displs[send_rank] * extent, cnts[send_rank], - datatype, dst, tag, comm, sched, nvtcs, &vtcs, &vtx_id); + datatype, dst, tag, comm, coll_group, sched, nvtcs, &vtcs, + &vtx_id); MPIR_ERR_CHECK(mpi_errno); } MPIR_CHKLMEM_MALLOC(reduce_id, int *, 2 * sizeof(int), mpi_errno, "reduce_id", MPL_MEM_COLL); @@ -111,7 +111,7 @@ int MPIR_TSP_Iallreduce_sched_intra_ring(const void *sendbuf, void *recvbuf, MPI /* Phase 3: Allgatherv ring, so everyone has the reduced data */ MPIR_TSP_Iallgatherv_sched_intra_ring(MPI_IN_PLACE, -1, MPI_DATATYPE_NULL, recvbuf, cnts, - displs, datatype, comm, sched); + displs, datatype, comm, coll_group, sched); MPIR_CHKLMEM_FREEALL(); diff --git a/src/mpi/coll/iallreduce/iallreduce_tsp_tree.c b/src/mpi/coll/iallreduce/iallreduce_tsp_tree.c index 64c84cc6af1..a0ea3ddd7dd 100644 --- a/src/mpi/coll/iallreduce/iallreduce_tsp_tree.c +++ b/src/mpi/coll/iallreduce/iallreduce_tsp_tree.c @@ -10,8 +10,9 @@ /* Routine to schedule a pipelined tree based allreduce */ int MPIR_TSP_Iallreduce_sched_intra_tree(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm, int tree_type, int k, int chunk_size, - int buffer_per_child, MPIR_TSP_sched_t sched) + MPIR_Comm * comm, int coll_group, int tree_type, int k, + int chunk_size, int buffer_per_child, + MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; int i, j, t; @@ -37,8 +38,7 @@ int MPIR_TSP_Iallreduce_sched_intra_tree(const void *sendbuf, void *recvbuf, MPI MPIR_FUNC_ENTER; - size = MPIR_Comm_size(comm); - rank = MPIR_Comm_rank(comm); + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, size); MPIR_Datatype_get_size_macro(datatype, type_size); MPIR_Datatype_get_extent_macro(datatype, extent); @@ -117,7 +117,7 @@ int MPIR_TSP_Iallreduce_sched_intra_tree(const void *sendbuf, void *recvbuf, MPI /* For correctness, transport based collectives need to get the * tag from the same pool as schedule based collectives */ - mpi_errno = MPIR_Sched_next_tag(comm, &tag); + mpi_errno = MPIR_Sched_next_tag(comm, coll_group, &tag); MPIR_ERR_CHECK(mpi_errno); for (i = 0; i < num_children; i++) { @@ -138,8 +138,9 @@ int MPIR_TSP_Iallreduce_sched_intra_tree(const void *sendbuf, void *recvbuf, MPI nvtcs = 1; } - mpi_errno = MPIR_TSP_sched_irecv(recv_address, msgsize, datatype, child, tag, comm, - sched, nvtcs, vtcs, &recv_id[i]); + mpi_errno = + MPIR_TSP_sched_irecv(recv_address, msgsize, datatype, child, tag, comm, coll_group, + sched, nvtcs, vtcs, &recv_id[i]); MPIR_ERR_CHECK(mpi_errno); /* Setup dependencies for reduction. Reduction depends on the corresponding recv to complete */ @@ -186,7 +187,7 @@ int MPIR_TSP_Iallreduce_sched_intra_tree(const void *sendbuf, void *recvbuf, MPI if (rank != root) { mpi_errno = MPIR_TSP_sched_isend(reduce_address, msgsize, datatype, my_tree.parent, tag, comm, - sched, nvtcs, vtcs, &vtx_id); + coll_group, sched, nvtcs, vtcs, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } @@ -200,7 +201,8 @@ int MPIR_TSP_Iallreduce_sched_intra_tree(const void *sendbuf, void *recvbuf, MPI if (my_tree.parent != -1) { mpi_errno = MPIR_TSP_sched_irecv(reduce_address, msgsize, datatype, - my_tree.parent, tag, comm, sched, 1, &sink_id, &bcast_recv_id); + my_tree.parent, tag, comm, coll_group, sched, 1, &sink_id, + &bcast_recv_id); MPIR_ERR_CHECK(mpi_errno); } @@ -210,7 +212,7 @@ int MPIR_TSP_Iallreduce_sched_intra_tree(const void *sendbuf, void *recvbuf, MPI vtcs[0] = bcast_recv_id; mpi_errno = MPIR_TSP_sched_imcast(reduce_address, msgsize, datatype, ut_int_array(my_tree.children), num_children, tag, - comm, sched, nvtcs, vtcs, &vtx_id); + comm, coll_group, sched, nvtcs, vtcs, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/ialltoall/ialltoall_inter_sched_pairwise_exchange.c b/src/mpi/coll/ialltoall/ialltoall_inter_sched_pairwise_exchange.c index b6c46c5b54a..271e0721b11 100644 --- a/src/mpi/coll/ialltoall/ialltoall_inter_sched_pairwise_exchange.c +++ b/src/mpi/coll/ialltoall/ialltoall_inter_sched_pairwise_exchange.c @@ -19,7 +19,8 @@ int MPIR_Ialltoall_inter_sched_pairwise_exchange(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int local_size, remote_size, max_size, i; @@ -53,9 +54,9 @@ int MPIR_Ialltoall_inter_sched_pairwise_exchange(const void *sendbuf, MPI_Aint s sendaddr = (char *) sendbuf + dst * sendcount * sendtype_extent; } - mpi_errno = MPIR_Sched_send(sendaddr, sendcount, sendtype, dst, comm_ptr, s); + mpi_errno = MPIR_Sched_send(sendaddr, sendcount, sendtype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); - mpi_errno = MPIR_Sched_recv(recvaddr, recvcount, recvtype, src, comm_ptr, s); + mpi_errno = MPIR_Sched_recv(recvaddr, recvcount, recvtype, src, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } diff --git a/src/mpi/coll/ialltoall/ialltoall_intra_sched_brucks.c b/src/mpi/coll/ialltoall/ialltoall_intra_sched_brucks.c index 154176f55ff..25584afd8aa 100644 --- a/src/mpi/coll/ialltoall/ialltoall_intra_sched_brucks.c +++ b/src/mpi/coll/ialltoall/ialltoall_intra_sched_brucks.c @@ -20,7 +20,8 @@ */ int MPIR_Ialltoall_intra_sched_brucks(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, - MPI_Datatype recvtype, MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPI_Datatype recvtype, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int i; @@ -37,8 +38,7 @@ int MPIR_Ialltoall_intra_sched_brucks(const void *sendbuf, MPI_Aint sendcount, MPIR_Assert(sendbuf != MPI_IN_PLACE); /* we do not handle in-place */ #endif - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); MPIR_Datatype_get_extent_macro(sendtype, sendtype_extent); MPIR_Datatype_get_size_macro(recvtype, recvtype_sz); @@ -107,9 +107,9 @@ int MPIR_Ialltoall_intra_sched_brucks(const void *sendbuf, MPI_Aint sendcount, MPIR_SCHED_BARRIER(s); /* now send and recv in parallel */ - mpi_errno = MPIR_Sched_send(tmp_buf, newtype_size, MPI_BYTE, dst, comm_ptr, s); + mpi_errno = MPIR_Sched_send(tmp_buf, newtype_size, MPI_BYTE, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); - mpi_errno = MPIR_Sched_recv(recvbuf, 1, newtype, src, comm_ptr, s); + mpi_errno = MPIR_Sched_recv(recvbuf, 1, newtype, src, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); diff --git a/src/mpi/coll/ialltoall/ialltoall_intra_sched_inplace.c b/src/mpi/coll/ialltoall/ialltoall_intra_sched_inplace.c index a634d337005..dbeab81ec5d 100644 --- a/src/mpi/coll/ialltoall/ialltoall_intra_sched_inplace.c +++ b/src/mpi/coll/ialltoall/ialltoall_intra_sched_inplace.c @@ -19,7 +19,8 @@ * scenario. */ int MPIR_Ialltoall_intra_sched_inplace(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, - MPI_Datatype recvtype, MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPI_Datatype recvtype, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; void *tmp_buf = NULL; @@ -33,8 +34,8 @@ int MPIR_Ialltoall_intra_sched_inplace(const void *sendbuf, MPI_Aint sendcount, MPIR_Assert(sendbuf == MPI_IN_PLACE); #endif - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); + MPIR_Datatype_get_size_macro(recvtype, recvtype_size); MPIR_Datatype_get_extent_macro(recvtype, recvtype_extent); nbytes = recvtype_size * recvcount; @@ -60,10 +61,11 @@ int MPIR_Ialltoall_intra_sched_inplace(const void *sendbuf, MPI_Aint sendcount, MPIR_SCHED_BARRIER(s); /* now simultaneously send from tmp_buf and recv to recvbuf */ - mpi_errno = MPIR_Sched_send(tmp_buf, nbytes, MPI_BYTE, peer, comm_ptr, s); + mpi_errno = + MPIR_Sched_send(tmp_buf, nbytes, MPI_BYTE, peer, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_Sched_recv(((char *) recvbuf + peer * recvcount * recvtype_extent), - recvcount, recvtype, peer, comm_ptr, s); + recvcount, recvtype, peer, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } diff --git a/src/mpi/coll/ialltoall/ialltoall_intra_sched_pairwise.c b/src/mpi/coll/ialltoall/ialltoall_intra_sched_pairwise.c index 1f953411065..54580b0588a 100644 --- a/src/mpi/coll/ialltoall/ialltoall_intra_sched_pairwise.c +++ b/src/mpi/coll/ialltoall/ialltoall_intra_sched_pairwise.c @@ -24,7 +24,8 @@ */ int MPIR_Ialltoall_intra_sched_pairwise(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, - MPI_Datatype recvtype, MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPI_Datatype recvtype, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int i; @@ -36,8 +37,7 @@ int MPIR_Ialltoall_intra_sched_pairwise(const void *sendbuf, MPI_Aint sendcount, MPIR_Assert(sendbuf != MPI_IN_PLACE); /* we do not handle in-place */ #endif - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); MPIR_Datatype_get_extent_macro(sendtype, sendtype_extent); MPIR_Datatype_get_extent_macro(recvtype, recvtype_extent); @@ -62,10 +62,10 @@ int MPIR_Ialltoall_intra_sched_pairwise(const void *sendbuf, MPI_Aint sendcount, } mpi_errno = MPIR_Sched_send(((char *) sendbuf + dst * sendcount * sendtype_extent), - sendcount, sendtype, dst, comm_ptr, s); + sendcount, sendtype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_Sched_recv(((char *) recvbuf + src * recvcount * recvtype_extent), - recvcount, recvtype, src, comm_ptr, s); + recvcount, recvtype, src, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } diff --git a/src/mpi/coll/ialltoall/ialltoall_intra_sched_permuted_sendrecv.c b/src/mpi/coll/ialltoall/ialltoall_intra_sched_permuted_sendrecv.c index f8e03a093bd..23c909a4bcc 100644 --- a/src/mpi/coll/ialltoall/ialltoall_intra_sched_permuted_sendrecv.c +++ b/src/mpi/coll/ialltoall/ialltoall_intra_sched_permuted_sendrecv.c @@ -16,7 +16,8 @@ int MPIR_Ialltoall_intra_sched_permuted_sendrecv(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int i; @@ -28,8 +29,7 @@ int MPIR_Ialltoall_intra_sched_permuted_sendrecv(const void *sendbuf, MPI_Aint s MPIR_Assert(sendbuf != MPI_IN_PLACE); /* we do not handle in-place */ #endif - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); MPIR_Datatype_get_extent_macro(sendtype, sendtype_extent); MPIR_Datatype_get_extent_macro(recvtype, recvtype_extent); @@ -44,14 +44,14 @@ int MPIR_Ialltoall_intra_sched_permuted_sendrecv(const void *sendbuf, MPI_Aint s for (i = 0; i < ss; i++) { dst = (rank + i + ii) % comm_size; mpi_errno = MPIR_Sched_recv(((char *) recvbuf + dst * recvcount * recvtype_extent), - recvcount, recvtype, dst, comm_ptr, s); + recvcount, recvtype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } for (i = 0; i < ss; i++) { dst = (rank - i - ii + comm_size) % comm_size; mpi_errno = MPIR_Sched_send(((char *) sendbuf + dst * sendcount * sendtype_extent), - sendcount, sendtype, dst, comm_ptr, s); + sendcount, sendtype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/ialltoall/ialltoall_tsp_brucks.c b/src/mpi/coll/ialltoall/ialltoall_tsp_brucks.c index 4d19c6fdf08..aabd475f413 100644 --- a/src/mpi/coll/ialltoall/ialltoall_tsp_brucks.c +++ b/src/mpi/coll/ialltoall/ialltoall_tsp_brucks.c @@ -117,7 +117,7 @@ brucks_sched_pup(int pack, void *rbuf, void *pupbuf, MPI_Datatype rtype, MPI_Ain int MPIR_TSP_Ialltoall_sched_intra_brucks(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, - MPI_Datatype recvtype, MPIR_Comm * comm, + MPI_Datatype recvtype, MPIR_Comm * comm, int coll_group, int k, int buffer_per_phase, MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; @@ -146,7 +146,7 @@ MPIR_TSP_Ialltoall_sched_intra_brucks(const void *sendbuf, MPI_Aint sendcount, /* For correctness, transport based collectives need to get the * tag from the same pool as schedule based collectives */ - mpi_errno = MPIR_Sched_next_tag(comm, &tag); + mpi_errno = MPIR_Sched_next_tag(comm, coll_group, &tag); MPIR_ERR_CHECK(mpi_errno); MPIR_CHKLMEM_MALLOC(pack_invtcs, int *, sizeof(int) * k, mpi_errno, "pack_invtcs", @@ -159,8 +159,7 @@ MPIR_TSP_Ialltoall_sched_intra_brucks(const void *sendbuf, MPI_Aint sendcount, is_inplace = (sendbuf == MPI_IN_PLACE); - rank = MPIR_Comm_rank(comm); - size = MPIR_Comm_size(comm); + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, size); max = size - 1; @@ -287,7 +286,7 @@ MPIR_TSP_Ialltoall_sched_intra_brucks(const void *sendbuf, MPI_Aint sendcount, mpi_errno = MPIR_TSP_sched_isend(tmp_sbuf[i][j - 1], packsize, MPI_BYTE, dst, tag, - comm, sched, 1, &packids[j - 1], &sendids[j - 1]); + comm, coll_group, sched, 1, &packids[j - 1], &sendids[j - 1]); MPIR_ERR_CHECK(mpi_errno); if (i != 0 && buffer_per_phase == 0) { /* this dependency holds only when we don't have dedicated recv buffer per phase */ @@ -296,7 +295,7 @@ MPIR_TSP_Ialltoall_sched_intra_brucks(const void *sendbuf, MPI_Aint sendcount, } mpi_errno = MPIR_TSP_sched_irecv(tmp_rbuf[i][j - 1], packsize, MPI_BYTE, - src, tag, comm, sched, recv_ninvtcs, recv_invtcs, + src, tag, comm, coll_group, sched, recv_ninvtcs, recv_invtcs, &recvids[j - 1]); MPIR_ERR_CHECK(mpi_errno); diff --git a/src/mpi/coll/ialltoall/ialltoall_tsp_ring.c b/src/mpi/coll/ialltoall/ialltoall_tsp_ring.c index 7d79dab6394..29e34d080d6 100644 --- a/src/mpi/coll/ialltoall/ialltoall_tsp_ring.c +++ b/src/mpi/coll/ialltoall/ialltoall_tsp_ring.c @@ -35,7 +35,7 @@ copy (buf1)<--recv (buf1) send (buf2) / /* Routine to schedule a ring based allgather */ int MPIR_TSP_Ialltoall_sched_intra_ring(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, - MPI_Datatype recvtype, MPIR_Comm * comm, + MPI_Datatype recvtype, MPIR_Comm * comm, int coll_group, MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; @@ -46,8 +46,9 @@ int MPIR_TSP_Ialltoall_sched_intra_ring(const void *sendbuf, MPI_Aint sendcount, void *buf1, *buf2, *data_buf, *sbuf, *rbuf; int tag, vtx_id; - int size = MPIR_Comm_size(comm); - int rank = MPIR_Comm_rank(comm); + int size, rank; + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, size); + int is_inplace = (sendbuf == MPI_IN_PLACE); MPI_Aint recvtype_lb, recvtype_extent; @@ -116,7 +117,7 @@ int MPIR_TSP_Ialltoall_sched_intra_ring(const void *sendbuf, MPI_Aint sendcount, for (i = 0; i < size - 1; i++) { /* For correctness, transport based collectives need to get the * tag from the same pool as schedule based collectives */ - mpi_errno = MPIR_Sched_next_tag(comm, &tag); + mpi_errno = MPIR_Sched_next_tag(comm, coll_group, &tag); MPIR_ERR_CHECK(mpi_errno); int vtcs[3], nvtcs; @@ -131,8 +132,8 @@ int MPIR_TSP_Ialltoall_sched_intra_ring(const void *sendbuf, MPI_Aint sendcount, } mpi_errno = - MPIR_TSP_sched_isend((char *) sbuf, size * recvcount, recvtype, dst, tag, comm, sched, - nvtcs, vtcs, &send_id[i % 3]); + MPIR_TSP_sched_isend((char *) sbuf, size * recvcount, recvtype, dst, tag, comm, + coll_group, sched, nvtcs, vtcs, &send_id[i % 3]); MPIR_ERR_CHECK(mpi_errno); /* schedule recv */ if (i == 0) @@ -149,8 +150,8 @@ int MPIR_TSP_Ialltoall_sched_intra_ring(const void *sendbuf, MPI_Aint sendcount, } mpi_errno = - MPIR_TSP_sched_irecv((char *) rbuf, size * recvcount, recvtype, src, tag, comm, sched, - nvtcs, vtcs, &recv_id[i % 3]); + MPIR_TSP_sched_irecv((char *) rbuf, size * recvcount, recvtype, src, tag, comm, + coll_group, sched, nvtcs, vtcs, &recv_id[i % 3]); MPIR_ERR_CHECK(mpi_errno); /* destination offset of the copy */ diff --git a/src/mpi/coll/ialltoall/ialltoall_tsp_scattered.c b/src/mpi/coll/ialltoall/ialltoall_tsp_scattered.c index bf6f6406eb4..506a5546d5c 100644 --- a/src/mpi/coll/ialltoall/ialltoall_tsp_scattered.c +++ b/src/mpi/coll/ialltoall/ialltoall_tsp_scattered.c @@ -37,8 +37,8 @@ int MPIR_TSP_Ialltoall_sched_intra_scattered(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm, int batch_size, int bblock, - MPIR_TSP_sched_t sched) + MPIR_Comm * comm, int coll_group, int batch_size, + int bblock, MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; int src, dst; @@ -58,11 +58,10 @@ int MPIR_TSP_Ialltoall_sched_intra_scattered(const void *sendbuf, MPI_Aint sendc /* For correctness, transport based collectives need to get the * tag from the same pool as schedule based collectives */ - mpi_errno = MPIR_Sched_next_tag(comm, &tag); + mpi_errno = MPIR_Sched_next_tag(comm, coll_group, &tag); MPIR_ERR_CHECK(mpi_errno); - size = MPIR_Comm_size(comm); - rank = MPIR_Comm_rank(comm); + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, size); is_inplace = (sendbuf == MPI_IN_PLACE); /* vtcs is twice the batch size to store both send and recv ids */ @@ -110,13 +109,15 @@ int MPIR_TSP_Ialltoall_sched_intra_scattered(const void *sendbuf, MPI_Aint sendc src = (rank + i) % size; mpi_errno = MPIR_TSP_sched_irecv((char *) recvbuf + src * recvcount * recvtype_extent, - recvcount, recvtype, src, tag, comm, sched, 0, NULL, &recv_id[i]); + recvcount, recvtype, src, tag, comm, coll_group, sched, 0, NULL, + &recv_id[i]); MPIR_ERR_CHECK(mpi_errno); dst = (rank - i + size) % size; mpi_errno = MPIR_TSP_sched_isend((char *) data_buf + dst * sendcount * sendtype_extent, - sendcount, sendtype, dst, tag, comm, sched, 0, NULL, &send_id[i]); + sendcount, sendtype, dst, tag, comm, coll_group, sched, 0, NULL, + &send_id[i]); MPIR_ERR_CHECK(mpi_errno); } @@ -137,15 +138,15 @@ int MPIR_TSP_Ialltoall_sched_intra_scattered(const void *sendbuf, MPI_Aint sendc src = (rank + i + j) % size; mpi_errno = MPIR_TSP_sched_irecv((char *) recvbuf + src * recvcount * recvtype_extent, - recvcount, recvtype, src, tag, comm, sched, 1, &invtcs, - &recv_id[(i + j) % bblock]); + recvcount, recvtype, src, tag, comm, coll_group, sched, 1, + &invtcs, &recv_id[(i + j) % bblock]); MPIR_ERR_CHECK(mpi_errno); dst = (rank - i - j + size) % size; mpi_errno = MPIR_TSP_sched_isend((char *) data_buf + dst * sendcount * sendtype_extent, - sendcount, sendtype, dst, tag, comm, sched, 1, &invtcs, - &send_id[(i + j) % bblock]); + sendcount, sendtype, dst, tag, comm, coll_group, sched, 1, + &invtcs, &send_id[(i + j) % bblock]); MPIR_ERR_CHECK(mpi_errno); } } diff --git a/src/mpi/coll/ialltoallv/ialltoallv_inter_sched_pairwise_exchange.c b/src/mpi/coll/ialltoallv/ialltoallv_inter_sched_pairwise_exchange.c index b3436310334..6c933b56841 100644 --- a/src/mpi/coll/ialltoallv/ialltoallv_inter_sched_pairwise_exchange.c +++ b/src/mpi/coll/ialltoallv/ialltoallv_inter_sched_pairwise_exchange.c @@ -9,7 +9,8 @@ int MPIR_Ialltoallv_inter_sched_pairwise_exchange(const void *sendbuf, const MPI const MPI_Aint sdispls[], MPI_Datatype sendtype, void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint rdispls[], MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { /* Intercommunicator alltoallv. We use a pairwise exchange algorithm * similar to the one used in intracommunicator alltoallv. Since the @@ -66,9 +67,9 @@ int MPIR_Ialltoallv_inter_sched_pairwise_exchange(const void *sendbuf, const MPI if (recvcount * recvtype_size == 0) src = MPI_PROC_NULL; - mpi_errno = MPIR_Sched_send(sendaddr, sendcount, sendtype, dst, comm_ptr, s); + mpi_errno = MPIR_Sched_send(sendaddr, sendcount, sendtype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); - mpi_errno = MPIR_Sched_recv(recvaddr, recvcount, recvtype, src, comm_ptr, s); + mpi_errno = MPIR_Sched_recv(recvaddr, recvcount, recvtype, src, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_Sched_barrier(s); MPIR_ERR_CHECK(mpi_errno); diff --git a/src/mpi/coll/ialltoallv/ialltoallv_intra_sched_blocked.c b/src/mpi/coll/ialltoallv/ialltoallv_intra_sched_blocked.c index 0c5b926111c..5cdbf428f09 100644 --- a/src/mpi/coll/ialltoallv/ialltoallv_intra_sched_blocked.c +++ b/src/mpi/coll/ialltoallv/ialltoallv_intra_sched_blocked.c @@ -9,7 +9,7 @@ int MPIR_Ialltoallv_intra_sched_blocked(const void *sendbuf, const MPI_Aint send const MPI_Aint sdispls[], MPI_Datatype sendtype, void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint rdispls[], MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int comm_size; @@ -22,8 +22,7 @@ int MPIR_Ialltoallv_intra_sched_blocked(const void *sendbuf, const MPI_Aint send MPIR_Assert(sendbuf != MPI_IN_PLACE); #endif /* HAVE_ERROR_CHECKING */ - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); /* Get extent and size of recvtype, don't look at sendtype for MPI_IN_PLACE */ MPIR_Datatype_get_extent_macro(recvtype, recv_extent); @@ -46,7 +45,8 @@ int MPIR_Ialltoallv_intra_sched_blocked(const void *sendbuf, const MPI_Aint send dst = (rank + i + ii) % comm_size; if (recvcounts[dst] && recvtype_size) { mpi_errno = MPIR_Sched_recv((char *) recvbuf + rdispls[dst] * recv_extent, - recvcounts[dst], recvtype, dst, comm_ptr, s); + recvcounts[dst], recvtype, dst, comm_ptr, coll_group, + s); MPIR_ERR_CHECK(mpi_errno); } } @@ -55,7 +55,8 @@ int MPIR_Ialltoallv_intra_sched_blocked(const void *sendbuf, const MPI_Aint send dst = (rank - i - ii + comm_size) % comm_size; if (sendcounts[dst] && sendtype_size) { mpi_errno = MPIR_Sched_send((char *) sendbuf + sdispls[dst] * send_extent, - sendcounts[dst], sendtype, dst, comm_ptr, s); + sendcounts[dst], sendtype, dst, comm_ptr, coll_group, + s); MPIR_ERR_CHECK(mpi_errno); } } diff --git a/src/mpi/coll/ialltoallv/ialltoallv_intra_sched_inplace.c b/src/mpi/coll/ialltoallv/ialltoallv_intra_sched_inplace.c index 8f039334481..3b8444bf5be 100644 --- a/src/mpi/coll/ialltoallv/ialltoallv_intra_sched_inplace.c +++ b/src/mpi/coll/ialltoallv/ialltoallv_intra_sched_inplace.c @@ -9,7 +9,7 @@ int MPIR_Ialltoallv_intra_sched_inplace(const void *sendbuf, const MPI_Aint send const MPI_Aint sdispls[], MPI_Datatype sendtype, void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint rdispls[], MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Sched_t s) { void *tmp_buf = NULL; int mpi_errno = MPI_SUCCESS; @@ -18,8 +18,7 @@ int MPIR_Ialltoallv_intra_sched_inplace(const void *sendbuf, const MPI_Aint send MPI_Aint recvtype_extent, recvtype_sz; int dst, rank; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); /* Get extent and size of recvtype, don't look at sendtype for MPI_IN_PLACE */ MPIR_Datatype_get_extent_macro(recvtype, recvtype_extent); @@ -58,10 +57,11 @@ int MPIR_Ialltoallv_intra_sched_inplace(const void *sendbuf, const MPI_Aint send dst = i; mpi_errno = MPIR_Sched_send(((char *) recvbuf + rdispls[dst] * recvtype_extent), - recvcounts[dst], recvtype, dst, comm_ptr, s); + recvcounts[dst], recvtype, dst, comm_ptr, coll_group, + s); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_Sched_recv(tmp_buf, recvcounts[dst] * recvtype_sz, MPI_BYTE, - dst, comm_ptr, s); + dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); diff --git a/src/mpi/coll/ialltoallv/ialltoallv_tsp_blocked.c b/src/mpi/coll/ialltoallv/ialltoallv_tsp_blocked.c index 67ab8288b01..1bb6902f2ba 100644 --- a/src/mpi/coll/ialltoallv/ialltoallv_tsp_blocked.c +++ b/src/mpi/coll/ialltoallv/ialltoallv_tsp_blocked.c @@ -11,7 +11,8 @@ int MPIR_TSP_Ialltoallv_sched_intra_blocked(const void *sendbuf, const MPI_Aint const MPI_Aint sdispls[], MPI_Datatype sendtype, void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint rdispls[], MPI_Datatype recvtype, - MPIR_Comm * comm, int bblock, MPIR_TSP_sched_t sched) + MPIR_Comm * comm, int coll_group, int bblock, + MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; size_t recv_extent, send_extent, sendtype_size, recvtype_size; @@ -27,11 +28,10 @@ int MPIR_TSP_Ialltoallv_sched_intra_blocked(const void *sendbuf, const MPI_Aint /* For correctness, transport based collectives need to get the * tag from the same pool as schedule based collectives */ - mpi_errno = MPIR_Sched_next_tag(comm, &tag); + mpi_errno = MPIR_Sched_next_tag(comm, coll_group, &tag); MPIR_ERR_CHECK(mpi_errno); - nranks = MPIR_Comm_size(comm); - rank = MPIR_Comm_rank(comm); + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, nranks); MPIR_Datatype_get_extent_macro(recvtype, recv_extent); MPIR_Type_get_true_extent_impl(recvtype, &recv_lb, &true_extent); @@ -55,8 +55,8 @@ int MPIR_TSP_Ialltoallv_sched_intra_blocked(const void *sendbuf, const MPI_Aint dst = (rank + j + i) % nranks; if (recvcounts[dst] && recvtype_size) { mpi_errno = MPIR_TSP_sched_irecv((char *) recvbuf + rdispls[dst] * recv_extent, - recvcounts[dst], recvtype, dst, tag, comm, sched, - 0, NULL, &vtx_id); + recvcounts[dst], recvtype, dst, tag, comm, + coll_group, sched, 0, NULL, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } } @@ -65,8 +65,8 @@ int MPIR_TSP_Ialltoallv_sched_intra_blocked(const void *sendbuf, const MPI_Aint dst = (rank - j - i + nranks) % nranks; if (sendcounts[dst] && sendtype_size) { mpi_errno = MPIR_TSP_sched_isend((char *) sendbuf + sdispls[dst] * send_extent, - sendcounts[dst], sendtype, dst, tag, comm, sched, - 0, NULL, &vtx_id); + sendcounts[dst], sendtype, dst, tag, comm, + coll_group, sched, 0, NULL, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } } diff --git a/src/mpi/coll/ialltoallv/ialltoallv_tsp_inplace.c b/src/mpi/coll/ialltoallv/ialltoallv_tsp_inplace.c index 25566a0041b..039e1e4c693 100644 --- a/src/mpi/coll/ialltoallv/ialltoallv_tsp_inplace.c +++ b/src/mpi/coll/ialltoallv/ialltoallv_tsp_inplace.c @@ -11,7 +11,8 @@ int MPIR_TSP_Ialltoallv_sched_intra_inplace(const void *sendbuf, const MPI_Aint const MPI_Aint sdispls[], MPI_Datatype sendtype, void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint rdispls[], MPI_Datatype recvtype, - MPIR_Comm * comm, MPIR_TSP_sched_t sched) + MPIR_Comm * comm, int coll_group, + MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; size_t recv_extent; @@ -26,11 +27,10 @@ int MPIR_TSP_Ialltoallv_sched_intra_inplace(const void *sendbuf, const MPI_Aint /* For correctness, transport based collectives need to get the * tag from the same pool as schedule based collectives */ - mpi_errno = MPIR_Sched_next_tag(comm, &tag); + mpi_errno = MPIR_Sched_next_tag(comm, coll_group, &tag); MPIR_ERR_CHECK(mpi_errno); - nranks = MPIR_Comm_size(comm); - rank = MPIR_Comm_rank(comm); + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, nranks); MPIR_Datatype_get_extent_macro(recvtype, recv_extent); MPIR_Type_get_true_extent_impl(recvtype, &recv_lb, &true_extent); @@ -51,12 +51,12 @@ int MPIR_TSP_Ialltoallv_sched_intra_inplace(const void *sendbuf, const MPI_Aint vtcs[0] = dtcopy_id; mpi_errno = MPIR_TSP_sched_isend((char *) recvbuf + rdispls[dst] * recv_extent, - recvcounts[dst], recvtype, dst, tag, comm, + recvcounts[dst], recvtype, dst, tag, comm, coll_group, sched, nvtcs, vtcs, &send_id); MPIR_ERR_CHECK(mpi_errno); mpi_errno = - MPIR_TSP_sched_irecv(tmp_buf, recvcounts[dst], recvtype, dst, tag, comm, + MPIR_TSP_sched_irecv(tmp_buf, recvcounts[dst], recvtype, dst, tag, comm, coll_group, sched, nvtcs, vtcs, &recv_id); MPIR_ERR_CHECK(mpi_errno); diff --git a/src/mpi/coll/ialltoallv/ialltoallv_tsp_scattered.c b/src/mpi/coll/ialltoallv/ialltoallv_tsp_scattered.c index 4266aec8f32..8bf0b6cedfa 100644 --- a/src/mpi/coll/ialltoallv/ialltoallv_tsp_scattered.c +++ b/src/mpi/coll/ialltoallv/ialltoallv_tsp_scattered.c @@ -11,8 +11,8 @@ int MPIR_TSP_Ialltoallv_sched_intra_scattered(const void *sendbuf, const MPI_Ain const MPI_Aint sdispls[], MPI_Datatype sendtype, void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint rdispls[], MPI_Datatype recvtype, - MPIR_Comm * comm, int batch_size, int bblock, - MPIR_TSP_sched_t sched) + MPIR_Comm * comm, int coll_group, int batch_size, + int bblock, MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; int src, dst; @@ -25,8 +25,8 @@ int MPIR_TSP_Ialltoallv_sched_intra_scattered(const void *sendbuf, const MPI_Ain MPIR_Assert(!(sendbuf == MPI_IN_PLACE)); - int size = MPIR_Comm_size(comm); - int rank = MPIR_Comm_rank(comm); + int rank, size; + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, size); MPI_Aint recvtype_lb, recvtype_extent; MPI_Aint sendtype_lb, sendtype_extent; @@ -55,7 +55,7 @@ int MPIR_TSP_Ialltoallv_sched_intra_scattered(const void *sendbuf, const MPI_Ain MPIR_Type_get_true_extent_impl(sendtype, &sendtype_lb, &sendtype_true_extent); sendtype_extent = MPL_MAX(sendtype_extent, sendtype_true_extent); - mpi_errno = MPIR_Sched_next_tag(comm, &tag); + mpi_errno = MPIR_Sched_next_tag(comm, coll_group, &tag); MPIR_ERR_CHECK(mpi_errno); /* First, post bblock number of sends/recvs */ @@ -63,15 +63,15 @@ int MPIR_TSP_Ialltoallv_sched_intra_scattered(const void *sendbuf, const MPI_Ain src = (rank + i) % size; mpi_errno = MPIR_TSP_sched_irecv((char *) recvbuf + rdispls[src] * recvtype_extent, - recvcounts[src], recvtype, src, tag, comm, sched, 0, NULL, - &recv_id[i]); + recvcounts[src], recvtype, src, tag, comm, coll_group, sched, 0, + NULL, &recv_id[i]); MPIR_ERR_CHECK(mpi_errno); dst = (rank - i + size) % size; mpi_errno = MPIR_TSP_sched_isend((char *) sendbuf + sdispls[dst] * sendtype_extent, - sendcounts[dst], sendtype, dst, tag, comm, sched, 0, NULL, - &send_id[i]); + sendcounts[dst], sendtype, dst, tag, comm, coll_group, sched, 0, + NULL, &send_id[i]); MPIR_ERR_CHECK(mpi_errno); } @@ -93,15 +93,15 @@ int MPIR_TSP_Ialltoallv_sched_intra_scattered(const void *sendbuf, const MPI_Ain src = (rank + i + j) % size; mpi_errno = MPIR_TSP_sched_irecv((char *) recvbuf + rdispls[src] * recvtype_extent, - recvcounts[src], recvtype, src, tag, comm, sched, 1, &invtcs, - &recv_id[(i + j) % bblock]); + recvcounts[src], recvtype, src, tag, comm, coll_group, sched, + 1, &invtcs, &recv_id[(i + j) % bblock]); MPIR_ERR_CHECK(mpi_errno); dst = (rank - i - j + size) % size; mpi_errno = MPIR_TSP_sched_isend((char *) sendbuf + sdispls[dst] * sendtype_extent, - sendcounts[dst], sendtype, dst, tag, comm, sched, 1, &invtcs, - &send_id[(i + j) % bblock]); + sendcounts[dst], sendtype, dst, tag, comm, coll_group, sched, + 1, &invtcs, &send_id[(i + j) % bblock]); MPIR_ERR_CHECK(mpi_errno); } } diff --git a/src/mpi/coll/ialltoallw/ialltoallw_inter_sched_pairwise_exchange.c b/src/mpi/coll/ialltoallw/ialltoallw_inter_sched_pairwise_exchange.c index 163aaf79f34..9ff9cfb9868 100644 --- a/src/mpi/coll/ialltoallw/ialltoallw_inter_sched_pairwise_exchange.c +++ b/src/mpi/coll/ialltoallw/ialltoallw_inter_sched_pairwise_exchange.c @@ -11,7 +11,8 @@ int MPIR_Ialltoallw_inter_sched_pairwise_exchange(const void *sendbuf, const MPI const MPI_Aint recvcounts[], const MPI_Aint rdispls[], const MPI_Datatype recvtypes[], - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { /* Intercommunicator alltoallw. We use a pairwise exchange algorithm similar to the one used in intracommunicator alltoallw. Since the local and @@ -59,10 +60,10 @@ int MPIR_Ialltoallw_inter_sched_pairwise_exchange(const void *sendbuf, const MPI sendtype = sendtypes[dst]; } - mpi_errno = MPIR_Sched_send(sendaddr, sendcount, sendtype, dst, comm_ptr, s); + mpi_errno = MPIR_Sched_send(sendaddr, sendcount, sendtype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); /* sendrecv, no barrier here */ - mpi_errno = MPIR_Sched_recv(recvaddr, recvcount, recvtype, src, comm_ptr, s); + mpi_errno = MPIR_Sched_recv(recvaddr, recvcount, recvtype, src, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } diff --git a/src/mpi/coll/ialltoallw/ialltoallw_intra_sched_blocked.c b/src/mpi/coll/ialltoallw/ialltoallw_intra_sched_blocked.c index 1523dd78f0d..fe99d505cc2 100644 --- a/src/mpi/coll/ialltoallw/ialltoallw_intra_sched_blocked.c +++ b/src/mpi/coll/ialltoallw/ialltoallw_intra_sched_blocked.c @@ -23,7 +23,7 @@ int MPIR_Ialltoallw_intra_sched_blocked(const void *sendbuf, const MPI_Aint send const MPI_Aint sdispls[], const MPI_Datatype sendtypes[], void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint rdispls[], const MPI_Datatype recvtypes[], - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int comm_size, i; @@ -34,8 +34,7 @@ int MPIR_Ialltoallw_intra_sched_blocked(const void *sendbuf, const MPI_Aint send MPIR_Assert(sendbuf != MPI_IN_PLACE); #endif /* HAVE_ERROR_CHECKING */ - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); bblock = MPIR_CVAR_ALLTOALL_THROTTLE; if (bblock == 0) @@ -53,7 +52,8 @@ int MPIR_Ialltoallw_intra_sched_blocked(const void *sendbuf, const MPI_Aint send MPIR_Datatype_get_size_macro(recvtypes[dst], type_size); if (type_size) { mpi_errno = MPIR_Sched_recv((char *) recvbuf + rdispls[dst], - recvcounts[dst], recvtypes[dst], dst, comm_ptr, s); + recvcounts[dst], recvtypes[dst], dst, comm_ptr, + coll_group, s); MPIR_ERR_CHECK(mpi_errno); } } @@ -66,7 +66,8 @@ int MPIR_Ialltoallw_intra_sched_blocked(const void *sendbuf, const MPI_Aint send MPIR_Datatype_get_size_macro(sendtypes[dst], type_size); if (type_size) { mpi_errno = MPIR_Sched_send((char *) sendbuf + sdispls[dst], - sendcounts[dst], sendtypes[dst], dst, comm_ptr, s); + sendcounts[dst], sendtypes[dst], dst, comm_ptr, + coll_group, s); MPIR_ERR_CHECK(mpi_errno); } } diff --git a/src/mpi/coll/ialltoallw/ialltoallw_intra_sched_inplace.c b/src/mpi/coll/ialltoallw/ialltoallw_intra_sched_inplace.c index 23bdca07055..7b48e76247e 100644 --- a/src/mpi/coll/ialltoallw/ialltoallw_intra_sched_inplace.c +++ b/src/mpi/coll/ialltoallw/ialltoallw_intra_sched_inplace.c @@ -21,7 +21,7 @@ int MPIR_Ialltoallw_intra_sched_inplace(const void *sendbuf, const MPI_Aint send const MPI_Aint sdispls[], const MPI_Datatype sendtypes[], void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint rdispls[], const MPI_Datatype recvtypes[], - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int comm_size, i, j; @@ -29,8 +29,7 @@ int MPIR_Ialltoallw_intra_sched_inplace(const void *sendbuf, const MPI_Aint send MPI_Aint recvtype_sz; void *tmp_buf = NULL; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); /* The regular MPI_Alltoallw handles MPI_IN_PLACE using pairwise * sendrecv_replace calls. We don't have a sendrecv_replace, so just @@ -67,10 +66,11 @@ int MPIR_Ialltoallw_intra_sched_inplace(const void *sendbuf, const MPI_Aint send MPIR_Datatype_get_size_macro(recvtypes[dst], recvtype_sz); mpi_errno = MPIR_Sched_send(((char *) recvbuf + rdispls[dst]), - recvcounts[dst], recvtypes[dst], dst, comm_ptr, s); + recvcounts[dst], recvtypes[dst], dst, comm_ptr, + coll_group, s); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_Sched_recv(tmp_buf, recvcounts[dst] * recvtype_sz, MPI_BYTE, - dst, comm_ptr, s); + dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); diff --git a/src/mpi/coll/ialltoallw/ialltoallw_tsp_blocked.c b/src/mpi/coll/ialltoallw/ialltoallw_tsp_blocked.c index 21c1667b1d1..922a7cd4a00 100644 --- a/src/mpi/coll/ialltoallw/ialltoallw_tsp_blocked.c +++ b/src/mpi/coll/ialltoallw/ialltoallw_tsp_blocked.c @@ -12,7 +12,7 @@ int MPIR_TSP_Ialltoallw_sched_intra_blocked(const void *sendbuf, const MPI_Aint const MPI_Datatype sendtypes[], void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint rdispls[], const MPI_Datatype recvtypes[], MPIR_Comm * comm, - int bblock, MPIR_TSP_sched_t sched) + int coll_group, int bblock, MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; int tag, vtx_id; @@ -25,15 +25,14 @@ int MPIR_TSP_Ialltoallw_sched_intra_blocked(const void *sendbuf, const MPI_Aint MPIR_Assert(sendbuf != MPI_IN_PLACE); - nranks = MPIR_Comm_size(comm); - rank = MPIR_Comm_rank(comm); + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, nranks); if (bblock == 0) bblock = nranks; /* For correctness, transport based collectives need to get the * tag from the same pool as schedule based collectives */ - mpi_errno = MPIR_Sched_next_tag(comm, &tag); + mpi_errno = MPIR_Sched_next_tag(comm, coll_group, &tag); MPIR_ERR_CHECK(mpi_errno); /* post only bblock isends/irecvs at a time as suggested by Tony Ladd */ @@ -48,7 +47,7 @@ int MPIR_TSP_Ialltoallw_sched_intra_blocked(const void *sendbuf, const MPI_Aint if (recvtype_size) { mpi_errno = MPIR_TSP_sched_irecv((char *) recvbuf + rdispls[dst], recvcounts[dst], recvtypes[dst], dst, tag, - comm, sched, 0, NULL, &vtx_id); + comm, coll_group, sched, 0, NULL, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } } @@ -61,7 +60,7 @@ int MPIR_TSP_Ialltoallw_sched_intra_blocked(const void *sendbuf, const MPI_Aint if (sendtype_size) { mpi_errno = MPIR_TSP_sched_isend((char *) sendbuf + sdispls[dst], sendcounts[dst], sendtypes[dst], dst, tag, - comm, sched, 0, NULL, &vtx_id); + comm, coll_group, sched, 0, NULL, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } } diff --git a/src/mpi/coll/ialltoallw/ialltoallw_tsp_inplace.c b/src/mpi/coll/ialltoallw/ialltoallw_tsp_inplace.c index fdac5344e56..8bebe075c03 100644 --- a/src/mpi/coll/ialltoallw/ialltoallw_tsp_inplace.c +++ b/src/mpi/coll/ialltoallw/ialltoallw_tsp_inplace.c @@ -12,7 +12,7 @@ int MPIR_TSP_Ialltoallw_sched_intra_inplace(const void *sendbuf, const MPI_Aint const MPI_Datatype sendtypes[], void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint rdispls[], const MPI_Datatype recvtypes[], MPIR_Comm * comm, - MPIR_TSP_sched_t sched) + int coll_group, MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; int tag; @@ -27,12 +27,11 @@ int MPIR_TSP_Ialltoallw_sched_intra_inplace(const void *sendbuf, const MPI_Aint MPIR_Assert(sendbuf == MPI_IN_PLACE); - nranks = MPIR_Comm_size(comm); - rank = MPIR_Comm_rank(comm); + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, nranks); /* For correctness, transport based collectives need to get the * tag from the same pool as schedule based collectives */ - mpi_errno = MPIR_Sched_next_tag(comm, &tag); + mpi_errno = MPIR_Sched_next_tag(comm, coll_group, &tag); MPIR_ERR_CHECK(mpi_errno); /* FIXME: Here we allocate tmp_buf using extent and send/recv with datatype directly, @@ -62,12 +61,12 @@ int MPIR_TSP_Ialltoallw_sched_intra_inplace(const void *sendbuf, const MPI_Aint adj_tmp_buf = (void *) ((char *) tmp_buf - true_lb); mpi_errno = MPIR_TSP_sched_isend((char *) recvbuf + rdispls[dst], - recvcounts[dst], recvtypes[dst], dst, tag, comm, sched, - nvtcs, vtcs, &send_id); + recvcounts[dst], recvtypes[dst], dst, tag, comm, + coll_group, sched, nvtcs, vtcs, &send_id); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_TSP_sched_irecv(adj_tmp_buf, recvcounts[dst], recvtypes[dst], dst, tag, comm, - sched, nvtcs, vtcs, &recv_id); + coll_group, sched, nvtcs, vtcs, &recv_id); MPIR_ERR_CHECK(mpi_errno); nvtcs = 2; diff --git a/src/mpi/coll/ibarrier/ibarrier_inter_sched_bcast.c b/src/mpi/coll/ibarrier/ibarrier_inter_sched_bcast.c index e98bff426f0..c74d09e7b0c 100644 --- a/src/mpi/coll/ibarrier/ibarrier_inter_sched_bcast.c +++ b/src/mpi/coll/ibarrier/ibarrier_inter_sched_bcast.c @@ -5,7 +5,7 @@ #include "mpiimpl.h" -int MPIR_Ibarrier_inter_sched_bcast(MPIR_Comm * comm_ptr, MPIR_Sched_t s) +int MPIR_Ibarrier_inter_sched_bcast(MPIR_Comm * comm_ptr, int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int rank, root; @@ -23,7 +23,7 @@ int MPIR_Ibarrier_inter_sched_bcast(MPIR_Comm * comm_ptr, MPIR_Sched_t s) /* do a barrier on the local intracommunicator */ if (comm_ptr->local_size != 1) { - mpi_errno = MPIR_Ibarrier_intra_sched_auto(comm_ptr->local_comm, s); + mpi_errno = MPIR_Ibarrier_intra_sched_auto(comm_ptr->local_comm, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } @@ -40,26 +40,26 @@ int MPIR_Ibarrier_inter_sched_bcast(MPIR_Comm * comm_ptr, MPIR_Sched_t s) * left group */ if (comm_ptr->is_low_group) { root = (rank == 0) ? MPI_ROOT : MPI_PROC_NULL; - mpi_errno = MPIR_Ibcast_inter_sched_auto(buf, 1, MPI_BYTE, root, comm_ptr, s); + mpi_errno = MPIR_Ibcast_inter_sched_auto(buf, 1, MPI_BYTE, root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); /* receive bcast from right */ root = 0; - mpi_errno = MPIR_Ibcast_inter_sched_auto(buf, 1, MPI_BYTE, root, comm_ptr, s); + mpi_errno = MPIR_Ibcast_inter_sched_auto(buf, 1, MPI_BYTE, root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } else { /* receive bcast from left */ root = 0; - mpi_errno = MPIR_Ibcast_inter_sched_auto(buf, 1, MPI_BYTE, root, comm_ptr, s); + mpi_errno = MPIR_Ibcast_inter_sched_auto(buf, 1, MPI_BYTE, root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); /* bcast to left */ root = (rank == 0) ? MPI_ROOT : MPI_PROC_NULL; - mpi_errno = MPIR_Ibcast_inter_sched_auto(buf, 1, MPI_BYTE, root, comm_ptr, s); + mpi_errno = MPIR_Ibcast_inter_sched_auto(buf, 1, MPI_BYTE, root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/ibarrier/ibarrier_intra_sched_recursive_doubling.c b/src/mpi/coll/ibarrier/ibarrier_intra_sched_recursive_doubling.c index 243b54fc9f1..8fba9c192d5 100644 --- a/src/mpi/coll/ibarrier/ibarrier_intra_sched_recursive_doubling.c +++ b/src/mpi/coll/ibarrier/ibarrier_intra_sched_recursive_doubling.c @@ -18,25 +18,25 @@ * process i sends to process (i + 2^k) % p and receives from process * (i - 2^k + p) % p. */ -int MPIR_Ibarrier_intra_sched_recursive_doubling(MPIR_Comm * comm_ptr, MPIR_Sched_t s) +int MPIR_Ibarrier_intra_sched_recursive_doubling(MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int size, rank, src, dst, mask; MPIR_Assert(comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM); - size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, size); mask = 0x1; while (mask < size) { dst = (rank + mask) % size; src = (rank - mask + size) % size; - mpi_errno = MPIR_Sched_send(NULL, 0, MPI_BYTE, dst, comm_ptr, s); + mpi_errno = MPIR_Sched_send(NULL, 0, MPI_BYTE, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); - mpi_errno = MPIR_Sched_recv(NULL, 0, MPI_BYTE, src, comm_ptr, s); + mpi_errno = MPIR_Sched_recv(NULL, 0, MPI_BYTE, src, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_Sched_barrier(s); diff --git a/src/mpi/coll/ibarrier/ibarrier_intra_tsp_dissem.c b/src/mpi/coll/ibarrier/ibarrier_intra_tsp_dissem.c index bb59afc82be..323c41eb3cb 100644 --- a/src/mpi/coll/ibarrier/ibarrier_intra_tsp_dissem.c +++ b/src/mpi/coll/ibarrier/ibarrier_intra_tsp_dissem.c @@ -6,7 +6,8 @@ #include "mpiimpl.h" /* Routine to schedule a disdem based barrier with radix k */ -int MPIR_TSP_Ibarrier_sched_intra_k_dissemination(MPIR_Comm * comm, int k, MPIR_TSP_sched_t sched) +int MPIR_TSP_Ibarrier_sched_intra_k_dissemination(MPIR_Comm * comm, int coll_group, int k, + MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; MPIR_Errflag_t errflag ATTRIBUTE((unused)) = MPIR_ERR_NONE; @@ -19,10 +20,9 @@ int MPIR_TSP_Ibarrier_sched_intra_k_dissemination(MPIR_Comm * comm, int k, MPIR_ MPIR_FUNC_ENTER; - nranks = MPIR_Comm_size(comm); - rank = MPIR_Comm_rank(comm); + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, nranks); - mpi_errno = MPIR_Sched_next_tag(comm, &tag); + mpi_errno = MPIR_Sched_next_tag(comm, coll_group, &tag); if (mpi_errno) MPIR_ERR_POP(mpi_errno); @@ -50,15 +50,15 @@ int MPIR_TSP_Ibarrier_sched_intra_k_dissemination(MPIR_Comm * comm, int k, MPIR_ MPL_DBG_MSG_FMT(MPIR_DBG_COLL, VERBOSE, (MPL_DBG_FDEST, "dissem barrier - scheduling recv from %d\n", from)); mpi_errno = - MPIR_TSP_sched_irecv(NULL, 0, MPI_BYTE, from, tag, comm, sched, 0, NULL, + MPIR_TSP_sched_irecv(NULL, 0, MPI_BYTE, from, tag, comm, coll_group, sched, 0, NULL, &recv_ids[i * (k - 1) + j - 1]); MPIR_ERR_CHECK(mpi_errno); MPL_DBG_MSG_FMT(MPIR_DBG_COLL, VERBOSE, (MPL_DBG_FDEST, "dissem barrier - scheduling send to %d\n", to)); mpi_errno = - MPIR_TSP_sched_isend(NULL, 0, MPI_BYTE, to, tag, comm, sched, i * (k - 1), recv_ids, - &vtx_id); + MPIR_TSP_sched_isend(NULL, 0, MPI_BYTE, to, tag, comm, coll_group, sched, + i * (k - 1), recv_ids, &vtx_id); MPIR_ERR_CHECK(mpi_errno); MPL_DBG_MSG_FMT(MPIR_DBG_COLL, VERBOSE, diff --git a/src/mpi/coll/ibarrier/ibarrier_intra_tsp_recexch.c b/src/mpi/coll/ibarrier/ibarrier_intra_tsp_recexch.c index 31e2f1de567..29b814e9e56 100644 --- a/src/mpi/coll/ibarrier/ibarrier_intra_tsp_recexch.c +++ b/src/mpi/coll/ibarrier/ibarrier_intra_tsp_recexch.c @@ -6,7 +6,8 @@ #include "mpiimpl.h" /* Routine to schedule a disdem based barrier with radix k */ -int MPIR_TSP_Ibarrier_sched_intra_recexch(MPIR_Comm * comm, int k, MPIR_TSP_sched_t sched) +int MPIR_TSP_Ibarrier_sched_intra_recexch(MPIR_Comm * comm, int coll_group, int k, + MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; void *recvbuf = NULL; @@ -14,7 +15,7 @@ int MPIR_TSP_Ibarrier_sched_intra_recexch(MPIR_Comm * comm, int k, MPIR_TSP_sche mpi_errno = MPIR_TSP_Iallreduce_sched_intra_recexch(MPI_IN_PLACE, recvbuf, 0, MPI_BYTE, MPI_SUM, - comm, + comm, coll_group, MPIR_IALLREDUCE_RECEXCH_TYPE_MULTIPLE_BUFFER, k, sched); MPIR_ERR_CHECK(mpi_errno); diff --git a/src/mpi/coll/ibarrier/ibarrier_tsp_auto.c b/src/mpi/coll/ibarrier/ibarrier_tsp_auto.c index d98497b9ff0..d9a685a2c35 100644 --- a/src/mpi/coll/ibarrier/ibarrier_tsp_auto.c +++ b/src/mpi/coll/ibarrier/ibarrier_tsp_auto.c @@ -6,13 +6,14 @@ #include "mpiimpl.h" /* sched version of CVAR and json based collective selection. Meant only for gentran scheduler */ -int MPIR_TSP_Ibarrier_sched_intra_tsp_auto(MPIR_Comm * comm, MPIR_TSP_sched_t sched) +int MPIR_TSP_Ibarrier_sched_intra_tsp_auto(MPIR_Comm * comm, int coll_group, MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; MPIR_Csel_coll_sig_s coll_sig = { .coll_type = MPIR_CSEL_COLL_TYPE__IBARRIER, .comm_ptr = comm, + .coll_group = coll_group, }; MPII_Csel_container_s *cnt; void *recvbuf = NULL; @@ -23,14 +24,15 @@ int MPIR_TSP_Ibarrier_sched_intra_tsp_auto(MPIR_Comm * comm, MPIR_TSP_sched_t sc case MPIR_CVAR_IBARRIER_INTRA_ALGORITHM_tsp_recexch: mpi_errno = MPIR_TSP_Iallreduce_sched_intra_recexch(MPI_IN_PLACE, recvbuf, 0, MPI_BYTE, MPI_SUM, - comm, + comm, coll_group, MPIR_IALLREDUCE_RECEXCH_TYPE_MULTIPLE_BUFFER, MPIR_CVAR_IBARRIER_RECEXCH_KVAL, sched); break; case MPIR_CVAR_IBARRIER_INTRA_ALGORITHM_tsp_k_dissemination: mpi_errno = - MPIR_TSP_Ibarrier_sched_intra_k_dissemination(comm, MPIR_CVAR_IBARRIER_DISSEM_KVAL, + MPIR_TSP_Ibarrier_sched_intra_k_dissemination(comm, coll_group, + MPIR_CVAR_IBARRIER_DISSEM_KVAL, sched); break; @@ -42,7 +44,7 @@ int MPIR_TSP_Ibarrier_sched_intra_tsp_auto(MPIR_Comm * comm, MPIR_TSP_sched_t sc case MPII_CSEL_CONTAINER_TYPE__ALGORITHM__MPIR_Ibarrier_intra_tsp_recexch: mpi_errno = MPIR_TSP_Iallreduce_sched_intra_recexch(MPI_IN_PLACE, recvbuf, 0, MPI_BYTE, - MPI_SUM, comm, + MPI_SUM, comm, coll_group, MPIR_IALLREDUCE_RECEXCH_TYPE_MULTIPLE_BUFFER, cnt->u.ibarrier.intra_tsp_recexch.k, sched); @@ -50,7 +52,7 @@ int MPIR_TSP_Ibarrier_sched_intra_tsp_auto(MPIR_Comm * comm, MPIR_TSP_sched_t sc case MPII_CSEL_CONTAINER_TYPE__ALGORITHM__MPIR_Ibarrier_intra_tsp_k_dissemination: mpi_errno = - MPIR_TSP_Ibarrier_sched_intra_k_dissemination(comm, + MPIR_TSP_Ibarrier_sched_intra_k_dissemination(comm, coll_group, cnt->u. ibarrier.intra_tsp_k_dissemination. k, sched); @@ -68,7 +70,8 @@ int MPIR_TSP_Ibarrier_sched_intra_tsp_auto(MPIR_Comm * comm, MPIR_TSP_sched_t sc fallback: mpi_errno = MPIR_TSP_Iallreduce_sched_intra_recexch(MPI_IN_PLACE, NULL, 0, - MPI_BYTE, MPI_SUM, comm, 0, 2, sched); + MPI_BYTE, MPI_SUM, comm, coll_group, 0, 2, + sched); fn_exit: return mpi_errno; diff --git a/src/mpi/coll/ibcast/ibcast.h b/src/mpi/coll/ibcast/ibcast.h index cc21d5645ff..e13e3f3f1c2 100644 --- a/src/mpi/coll/ibcast/ibcast.h +++ b/src/mpi/coll/ibcast/ibcast.h @@ -20,7 +20,7 @@ int MPII_Ibcast_sched_test_length(MPIR_Comm * comm, int tag, void *state); int MPII_Ibcast_sched_test_curr_length(MPIR_Comm * comm, int tag, void *state); int MPII_Ibcast_sched_init_length(MPIR_Comm * comm, int tag, void *state); int MPII_Ibcast_sched_add_length(MPIR_Comm * comm, int tag, void *state); -int MPII_Iscatter_for_bcast_sched(void *tmp_buf, int root, MPIR_Comm * comm_ptr, MPI_Aint nbytes, - MPIR_Sched_t s); +int MPII_Iscatter_for_bcast_sched(void *tmp_buf, int root, MPIR_Comm * comm_ptr, int coll_group, + MPI_Aint nbytes, MPIR_Sched_t s); #endif /* IBCAST_H_INCLUDED */ diff --git a/src/mpi/coll/ibcast/ibcast_inter_sched_flat.c b/src/mpi/coll/ibcast/ibcast_inter_sched_flat.c index d7f7108c970..2af50df7ded 100644 --- a/src/mpi/coll/ibcast/ibcast_inter_sched_flat.c +++ b/src/mpi/coll/ibcast/ibcast_inter_sched_flat.c @@ -7,7 +7,7 @@ #include "ibcast.h" int MPIR_Ibcast_inter_sched_flat(void *buffer, MPI_Aint count, MPI_Datatype datatype, - int root, MPIR_Comm * comm_ptr, MPIR_Sched_t s) + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; @@ -21,12 +21,12 @@ int MPIR_Ibcast_inter_sched_flat(void *buffer, MPI_Aint count, MPI_Datatype data mpi_errno = MPI_SUCCESS; } else if (root == MPI_ROOT) { /* root sends to rank 0 on remote group and returns */ - mpi_errno = MPIR_Sched_send(buffer, count, datatype, 0, comm_ptr, s); + mpi_errno = MPIR_Sched_send(buffer, count, datatype, 0, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } else { /* remote group. rank 0 on remote group receives from root */ if (comm_ptr->rank == 0) { - mpi_errno = MPIR_Sched_recv(buffer, count, datatype, root, comm_ptr, s); + mpi_errno = MPIR_Sched_recv(buffer, count, datatype, root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } @@ -39,7 +39,8 @@ int MPIR_Ibcast_inter_sched_flat(void *buffer, MPI_Aint count, MPI_Datatype data /* now do the usual broadcast on this intracommunicator * with rank 0 as root. */ mpi_errno = - MPIR_Ibcast_intra_sched_auto(buffer, count, datatype, root, comm_ptr->local_comm, s); + MPIR_Ibcast_intra_sched_auto(buffer, count, datatype, root, comm_ptr->local_comm, + coll_group, s); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/ibcast/ibcast_intra_sched_binomial.c b/src/mpi/coll/ibcast/ibcast_intra_sched_binomial.c index 0a6a6d3d038..0b4d309bf27 100644 --- a/src/mpi/coll/ibcast/ibcast_intra_sched_binomial.c +++ b/src/mpi/coll/ibcast/ibcast_intra_sched_binomial.c @@ -13,7 +13,7 @@ * to build up a larger hierarchical broadcast from multiple invocations of this * function. */ int MPIR_Ibcast_intra_sched_binomial(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int mask; @@ -25,8 +25,7 @@ int MPIR_Ibcast_intra_sched_binomial(void *buffer, MPI_Aint count, MPI_Datatype struct MPII_Ibcast_state *ibcast_state; void *tmp_buf = NULL; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); MPIR_Datatype_is_contig(datatype, &is_contig); MPIR_Datatype_get_size_macro(datatype, type_size); @@ -92,10 +91,10 @@ int MPIR_Ibcast_intra_sched_binomial(void *buffer, MPI_Aint count, MPI_Datatype src += comm_size; if (!is_contig) mpi_errno = MPIR_Sched_recv_status(tmp_buf, nbytes, MPI_BYTE, src, - comm_ptr, &ibcast_state->status, s); + comm_ptr, coll_group, &ibcast_state->status, s); else mpi_errno = MPIR_Sched_recv_status(buffer, count, datatype, src, - comm_ptr, &ibcast_state->status, s); + comm_ptr, coll_group, &ibcast_state->status, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); @@ -125,9 +124,10 @@ int MPIR_Ibcast_intra_sched_binomial(void *buffer, MPI_Aint count, MPI_Datatype if (dst >= comm_size) dst -= comm_size; if (!is_contig) - mpi_errno = MPIR_Sched_send(tmp_buf, nbytes, MPI_BYTE, dst, comm_ptr, s); + mpi_errno = + MPIR_Sched_send(tmp_buf, nbytes, MPI_BYTE, dst, comm_ptr, coll_group, s); else - mpi_errno = MPIR_Sched_send(buffer, count, datatype, dst, comm_ptr, s); + mpi_errno = MPIR_Sched_send(buffer, count, datatype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); /* NOTE: This is departure from MPIR_Bcast_intra_binomial. A true analog diff --git a/src/mpi/coll/ibcast/ibcast_intra_sched_scatter_recursive_doubling_allgather.c b/src/mpi/coll/ibcast/ibcast_intra_sched_scatter_recursive_doubling_allgather.c index 0307947da03..3ece3d2a562 100644 --- a/src/mpi/coll/ibcast/ibcast_intra_sched_scatter_recursive_doubling_allgather.c +++ b/src/mpi/coll/ibcast/ibcast_intra_sched_scatter_recursive_doubling_allgather.c @@ -49,7 +49,7 @@ int MPIR_Ibcast_intra_sched_scatter_recursive_doubling_allgather(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, MPIR_Comm * comm_ptr, - MPIR_Sched_t s) + int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int rank, comm_size, dst; @@ -61,8 +61,8 @@ int MPIR_Ibcast_intra_sched_scatter_recursive_doubling_allgather(void *buffer, M void *tmp_buf; struct MPII_Ibcast_state *ibcast_state; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); + relative_rank = (rank >= root) ? rank - root : rank - root + comm_size; #ifdef HAVE_ERROR_CHECKING @@ -110,7 +110,7 @@ int MPIR_Ibcast_intra_sched_scatter_recursive_doubling_allgather(void *buffer, M } - mpi_errno = MPII_Iscatter_for_bcast_sched(tmp_buf, root, comm_ptr, nbytes, s); + mpi_errno = MPII_Iscatter_for_bcast_sched(tmp_buf, root, comm_ptr, coll_group, nbytes, s); MPIR_ERR_CHECK(mpi_errno); MPI_Aint scatter_size, curr_size, incoming_count; @@ -162,12 +162,13 @@ int MPIR_Ibcast_intra_sched_scatter_recursive_doubling_allgather(void *buffer, M incoming_count = 0; mpi_errno = MPIR_Sched_send(((char *) tmp_buf + send_offset), - curr_size, MPI_BYTE, dst, comm_ptr, s); + curr_size, MPI_BYTE, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); /* sendrecv, no barrier */ mpi_errno = MPIR_Sched_recv_status(((char *) tmp_buf + recv_offset), incoming_count, - MPI_BYTE, dst, comm_ptr, &ibcast_state->status, s); + MPI_BYTE, dst, comm_ptr, coll_group, + &ibcast_state->status, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); mpi_errno = MPIR_Sched_cb(&MPII_Ibcast_sched_add_length, ibcast_state, s); @@ -228,7 +229,8 @@ int MPIR_Ibcast_intra_sched_scatter_recursive_doubling_allgather(void *buffer, M * receive. that's the amount of data to be * sent now. */ mpi_errno = MPIR_Sched_send(((char *) tmp_buf + offset), - incoming_count, MPI_BYTE, dst, comm_ptr, s); + incoming_count, MPI_BYTE, dst, comm_ptr, coll_group, + s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } @@ -248,7 +250,7 @@ int MPIR_Ibcast_intra_sched_scatter_recursive_doubling_allgather(void *buffer, M * whose data we don't have */ mpi_errno = MPIR_Sched_recv_status(((char *) tmp_buf + offset), incoming_count, MPI_BYTE, dst, comm_ptr, - &ibcast_state->status, s); + coll_group, &ibcast_state->status, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); mpi_errno = MPIR_Sched_cb(&MPII_Ibcast_sched_add_length, ibcast_state, s); diff --git a/src/mpi/coll/ibcast/ibcast_intra_sched_scatter_ring_allgather.c b/src/mpi/coll/ibcast/ibcast_intra_sched_scatter_ring_allgather.c index 0418f6e28b8..3dc2d9bb6bf 100644 --- a/src/mpi/coll/ibcast/ibcast_intra_sched_scatter_ring_allgather.c +++ b/src/mpi/coll/ibcast/ibcast_intra_sched_scatter_ring_allgather.c @@ -26,7 +26,8 @@ */ int MPIR_Ibcast_intra_sched_scatter_ring_allgather(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int comm_size, rank; @@ -37,8 +38,7 @@ int MPIR_Ibcast_intra_sched_scatter_ring_allgather(void *buffer, MPI_Aint count, struct MPII_Ibcast_state *ibcast_state; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); if (HANDLE_IS_BUILTIN(datatype)) is_contig = 1; @@ -78,7 +78,7 @@ int MPIR_Ibcast_intra_sched_scatter_ring_allgather(void *buffer, MPI_Aint count, } } - mpi_errno = MPII_Iscatter_for_bcast_sched(tmp_buf, root, comm_ptr, nbytes, s); + mpi_errno = MPII_Iscatter_for_bcast_sched(tmp_buf, root, comm_ptr, coll_group, nbytes, s); MPIR_ERR_CHECK(mpi_errno); MPI_Aint scatter_size, curr_size; @@ -119,11 +119,11 @@ int MPIR_Ibcast_intra_sched_scatter_ring_allgather(void *buffer, MPI_Aint count, right_disp = rel_j * scatter_size; mpi_errno = MPIR_Sched_send(((char *) tmp_buf + right_disp), - right_count, MPI_BYTE, right, comm_ptr, s); + right_count, MPI_BYTE, right, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); /* sendrecv, no barrier here */ mpi_errno = MPIR_Sched_recv_status(((char *) tmp_buf + left_disp), - left_count, MPI_BYTE, left, comm_ptr, + left_count, MPI_BYTE, left, comm_ptr, coll_group, &ibcast_state->status, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); diff --git a/src/mpi/coll/ibcast/ibcast_intra_sched_smp.c b/src/mpi/coll/ibcast/ibcast_intra_sched_smp.c index d734e7d8c2d..70ae42fb5ab 100644 --- a/src/mpi/coll/ibcast/ibcast_intra_sched_smp.c +++ b/src/mpi/coll/ibcast/ibcast_intra_sched_smp.c @@ -28,15 +28,19 @@ static int sched_test_length(MPIR_Comm * comm, int tag, void *state) * currently make any decision about which particular algorithm to use for any * subcommunicator. */ int MPIR_Ibcast_intra_sched_smp(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; MPI_Aint type_size; struct MPII_Ibcast_state *ibcast_state; #ifdef HAVE_ERROR_CHECKING - MPIR_Assert(MPIR_Comm_is_parent_comm(comm_ptr)); + MPIR_Assert(MPIR_Comm_is_parent_comm(comm_ptr, coll_group)); #endif + int local_rank = comm_ptr->subgroups[MPIR_SUBGROUP_NODE].rank; + int local_size = comm_ptr->subgroups[MPIR_SUBGROUP_NODE].size; + int local_root = MPIR_Get_intranode_rank(comm_ptr, root); + ibcast_state = MPIR_Sched_alloc_state(s, sizeof(struct MPII_Ibcast_state)); MPIR_ERR_CHKANDJUMP(!ibcast_state, mpi_errno, MPI_ERR_OTHER, "**nomem"); @@ -46,15 +50,15 @@ int MPIR_Ibcast_intra_sched_smp(void *buffer, MPI_Aint count, MPI_Datatype datat /* TODO insert packing here */ /* send to intranode-rank 0 on the root's node */ - if (comm_ptr->node_comm != NULL && MPIR_Get_intranode_rank(comm_ptr, root) > 0) { /* is not the node root (0) *//* and is on our node (!-1) */ - if (root == comm_ptr->rank) { - mpi_errno = MPIR_Sched_send(buffer, count, datatype, 0, comm_ptr->node_comm, s); - MPIR_ERR_CHECK(mpi_errno); - } else if (0 == comm_ptr->node_comm->rank) { + if (local_size > 1 && local_root > 0) { /* is not the node root (0) *//* and is on our node (!-1) */ + if (local_rank == local_root) { mpi_errno = - MPIR_Sched_recv_status(buffer, count, datatype, - MPIR_Get_intranode_rank(comm_ptr, root), comm_ptr->node_comm, - &ibcast_state->status, s); + MPIR_Sched_send(buffer, count, datatype, 0, comm_ptr, MPIR_SUBGROUP_NODE, s); + MPIR_ERR_CHECK(mpi_errno); + } else if (local_rank == 0) { + mpi_errno = MPIR_Sched_recv_status(buffer, count, datatype, local_root, + comm_ptr, MPIR_SUBGROUP_NODE, &ibcast_state->status, + s); MPIR_ERR_CHECK(mpi_errno); #ifdef HAVE_ERROR_CHECKING MPIR_SCHED_BARRIER(s); @@ -66,19 +70,19 @@ int MPIR_Ibcast_intra_sched_smp(void *buffer, MPI_Aint count, MPI_Datatype datat } /* perform the internode broadcast */ - if (comm_ptr->node_roots_comm != NULL) { - mpi_errno = MPIR_Ibcast_intra_sched_auto(buffer, count, datatype, - MPIR_Get_internode_rank(comm_ptr, root), - comm_ptr->node_roots_comm, s); + if (local_rank == 0) { + int inter_root = MPIR_Get_internode_rank(comm_ptr, root); + mpi_errno = MPIR_Ibcast_intra_sched_auto(buffer, count, datatype, inter_root, + comm_ptr, MPIR_SUBGROUP_NODE_CROSS, s); MPIR_ERR_CHECK(mpi_errno); /* don't allow the local ops for the intranode phase to start until this has completed */ MPIR_SCHED_BARRIER(s); } /* perform the intranode broadcast on all except for the root's node */ - if (comm_ptr->node_comm != NULL) { - mpi_errno = - MPIR_Ibcast_intra_sched_auto(buffer, count, datatype, 0, comm_ptr->node_comm, s); + if (local_size > 1) { + mpi_errno = MPIR_Ibcast_intra_sched_auto(buffer, count, datatype, 0, + comm_ptr, MPIR_SUBGROUP_NODE, s); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/ibcast/ibcast_tsp_auto.c b/src/mpi/coll/ibcast/ibcast_tsp_auto.c index 52c9016c083..edc9f4aad85 100644 --- a/src/mpi/coll/ibcast/ibcast_tsp_auto.c +++ b/src/mpi/coll/ibcast/ibcast_tsp_auto.c @@ -13,30 +13,34 @@ /* Remove this function when gentran algos are in json file */ static int MPIR_Ibcast_sched_intra_tsp_flat_auto(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, - MPIR_Comm * comm_ptr, MPIR_TSP_sched_t sched) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; - int comm_size; + int comm_size, rank; MPI_Aint type_size, nbytes; int tree_type = MPIR_TREE_TYPE_KNOMIAL_1; int radix = 2, scatterv_k = 2, allgatherv_k = 2, block_size = 0; MPIR_Assert(comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM); - comm_size = comm_ptr->local_size; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); + (void) rank; /* silence unused variable warning */ + MPIR_Datatype_get_size_macro(datatype, type_size); nbytes = type_size * count; /* simplistic implementation for now */ if ((nbytes < MPIR_CVAR_BCAST_SHORT_MSG_SIZE) || (comm_size < MPIR_CVAR_BCAST_MIN_PROCS)) { /* gentran tree with knomial tree type, radix 2 and pipeline block size 0 */ - mpi_errno = MPIR_TSP_Ibcast_sched_intra_tree(buffer, count, datatype, root, comm_ptr, - tree_type, radix, block_size, sched); + mpi_errno = + MPIR_TSP_Ibcast_sched_intra_tree(buffer, count, datatype, root, comm_ptr, coll_group, + tree_type, radix, block_size, sched); } else { /* gentran scatterv recexch allgather with radix 2 */ mpi_errno = MPIR_TSP_Ibcast_sched_intra_scatterv_allgatherv(buffer, count, datatype, root, - comm_ptr, + comm_ptr, coll_group, MPIR_CVAR_IALLGATHERV_INTRA_ALGORITHM_tsp_recexch_doubling, scatterv_k, allgatherv_k, sched); } @@ -51,7 +55,8 @@ static int MPIR_Ibcast_sched_intra_tsp_flat_auto(void *buffer, MPI_Aint count, /* sched version of CVAR and json based collective selection. Meant only for gentran scheduler */ int MPIR_TSP_Ibcast_sched_intra_tsp_auto(void *buffer, MPI_Aint count, MPI_Datatype datatype, - int root, MPIR_Comm * comm_ptr, MPIR_TSP_sched_t sched) + int root, MPIR_Comm * comm_ptr, int coll_group, + MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; @@ -72,7 +77,7 @@ int MPIR_TSP_Ibcast_sched_intra_tsp_auto(void *buffer, MPI_Aint count, MPI_Datat case MPIR_CVAR_IBCAST_INTRA_ALGORITHM_tsp_tree: mpi_errno = MPIR_TSP_Ibcast_sched_intra_tree(buffer, count, datatype, root, comm_ptr, - MPIR_Ibcast_tree_type, + coll_group, MPIR_Ibcast_tree_type, MPIR_CVAR_IBCAST_TREE_KVAL, MPIR_CVAR_IBCAST_TREE_PIPELINE_CHUNK_SIZE, sched); break; @@ -80,7 +85,7 @@ int MPIR_TSP_Ibcast_sched_intra_tsp_auto(void *buffer, MPI_Aint count, MPI_Datat case MPIR_CVAR_IBCAST_INTRA_ALGORITHM_tsp_scatterv_recexch_allgatherv: mpi_errno = MPIR_TSP_Ibcast_sched_intra_scatterv_allgatherv(buffer, count, datatype, - root, comm_ptr, + root, comm_ptr, coll_group, MPIR_CVAR_IALLGATHERV_INTRA_ALGORITHM_tsp_recexch_doubling, MPIR_CVAR_IBCAST_SCATTERV_KVAL, MPIR_CVAR_IBCAST_ALLGATHERV_RECEXCH_KVAL, @@ -90,13 +95,14 @@ int MPIR_TSP_Ibcast_sched_intra_tsp_auto(void *buffer, MPI_Aint count, MPI_Datat case MPIR_CVAR_IBCAST_INTRA_ALGORITHM_tsp_scatterv_ring_allgatherv: mpi_errno = MPIR_TSP_Ibcast_sched_intra_scatterv_ring_allgatherv(buffer, count, datatype, - root, comm_ptr, 1, sched); + root, comm_ptr, coll_group, 1, + sched); break; case MPIR_CVAR_IBCAST_INTRA_ALGORITHM_tsp_ring: mpi_errno = MPIR_TSP_Ibcast_sched_intra_tree(buffer, count, datatype, root, comm_ptr, - MPIR_TREE_TYPE_KARY, 1, + coll_group, MPIR_TREE_TYPE_KARY, 1, MPIR_CVAR_IBCAST_RING_CHUNK_SIZE, sched); break; @@ -108,6 +114,7 @@ int MPIR_TSP_Ibcast_sched_intra_tsp_auto(void *buffer, MPI_Aint count, MPI_Datat case MPII_CSEL_CONTAINER_TYPE__ALGORITHM__MPIR_Ibcast_intra_tsp_tree: mpi_errno = MPIR_TSP_Ibcast_sched_intra_tree(buffer, count, datatype, root, comm_ptr, + coll_group, cnt->u.ibcast.intra_tsp_tree.tree_type, cnt->u.ibcast.intra_tsp_tree.k, cnt->u.ibcast.intra_tsp_tree.chunk_size, @@ -116,7 +123,7 @@ int MPIR_TSP_Ibcast_sched_intra_tsp_auto(void *buffer, MPI_Aint count, MPI_Datat case MPII_CSEL_CONTAINER_TYPE__ALGORITHM__MPIR_Ibcast_intra_tsp_scatterv_recexch_allgatherv: mpi_errno = MPIR_TSP_Ibcast_sched_intra_scatterv_allgatherv(buffer, count, datatype, - root, comm_ptr, + root, comm_ptr, coll_group, MPIR_CVAR_IALLGATHERV_INTRA_ALGORITHM_tsp_recexch_doubling, cnt->u. ibcast.intra_tsp_scatterv_recexch_allgatherv.scatterv_k, @@ -129,13 +136,14 @@ int MPIR_TSP_Ibcast_sched_intra_tsp_auto(void *buffer, MPI_Aint count, MPI_Datat mpi_errno = MPIR_TSP_Ibcast_sched_intra_scatterv_ring_allgatherv(buffer, count, datatype, root, - comm_ptr, 1, sched); + comm_ptr, coll_group, + 1, sched); break; case MPII_CSEL_CONTAINER_TYPE__ALGORITHM__MPIR_Ibcast_intra_tsp_ring: mpi_errno = MPIR_TSP_Ibcast_sched_intra_tree(buffer, count, datatype, root, comm_ptr, - MPIR_TREE_TYPE_KARY, 1, + coll_group, MPIR_TREE_TYPE_KARY, 1, cnt->u.ibcast.intra_tsp_tree.chunk_size, sched); break; @@ -150,7 +158,7 @@ int MPIR_TSP_Ibcast_sched_intra_tsp_auto(void *buffer, MPI_Aint count, MPI_Datat fallback: mpi_errno = MPIR_Ibcast_sched_intra_tsp_flat_auto(buffer, count, datatype, root, - comm_ptr, sched); + comm_ptr, coll_group, sched); fn_exit: return mpi_errno; diff --git a/src/mpi/coll/ibcast/ibcast_tsp_scatterv_allgatherv.c b/src/mpi/coll/ibcast/ibcast_tsp_scatterv_allgatherv.c index f7601720547..30cff266f12 100644 --- a/src/mpi/coll/ibcast/ibcast_tsp_scatterv_allgatherv.c +++ b/src/mpi/coll/ibcast/ibcast_tsp_scatterv_allgatherv.c @@ -10,9 +10,9 @@ /* Routine to schedule a scatter followed by recursive exchange based broadcast */ int MPIR_TSP_Ibcast_sched_intra_scatterv_allgatherv(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, - MPIR_Comm * comm, int allgatherv_algo, - int scatterv_k, int allgatherv_k, - MPIR_TSP_sched_t sched) + MPIR_Comm * comm, int coll_group, + int allgatherv_algo, int scatterv_k, + int allgatherv_k, MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; size_t extent, type_size; @@ -32,19 +32,18 @@ int MPIR_TSP_Ibcast_sched_intra_scatterv_allgatherv(void *buffer, MPI_Aint count /* For correctness, transport based collectives need to get the * tag from the same pool as schedule based collectives */ - mpi_errno = MPIR_Sched_next_tag(comm, &tag); + mpi_errno = MPIR_Sched_next_tag(comm, coll_group, &tag); MPIR_ERR_CHECK(mpi_errno); MPIR_FUNC_ENTER; + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, size); + lrank = (rank - root + size) % size; /* logical rank when root is non-zero */ + MPL_DBG_MSG_FMT(MPIR_DBG_COLL, VERBOSE, (MPL_DBG_FDEST, "Scheduling scatter followed by recursive exchange allgather based broadcast on %d ranks, root=%d\n", - MPIR_Comm_size(comm), root)); - - size = MPIR_Comm_size(comm); - rank = MPIR_Comm_rank(comm); - lrank = (rank - root + size) % size; /* logical rank when root is non-zero */ + size, root)); MPIR_Datatype_get_size_macro(datatype, type_size); MPIR_Datatype_get_extent_macro(datatype, extent); @@ -150,14 +149,14 @@ int MPIR_TSP_Ibcast_sched_intra_scatterv_allgatherv(void *buffer, MPI_Aint count ibcast_state->n_bytes = recv_size; mpi_errno = MPIR_TSP_sched_irecv_status((char *) tmp_buf + displs[rank], recv_size, MPI_BYTE, - my_tree.parent, tag, comm, &ibcast_state->status, sched, 0, - NULL, &recv_id); + my_tree.parent, tag, comm, coll_group, + &ibcast_state->status, sched, 0, NULL, &recv_id); MPIR_TSP_sched_cb(&MPII_Ibcast_sched_test_length, ibcast_state, sched, 1, &recv_id, &vtx_id); #else mpi_errno = MPIR_TSP_sched_irecv((char *) tmp_buf + displs[rank], recv_size, MPI_BYTE, - my_tree.parent, tag, comm, sched, 0, NULL, &recv_id); + my_tree.parent, tag, comm, coll_group, sched, 0, NULL, &recv_id); #endif MPIR_ERR_CHECK(mpi_errno); MPL_DBG_MSG_FMT(MPIR_DBG_COLL, VERBOSE, (MPL_DBG_FDEST, "rank:%d posts recv", rank)); @@ -174,8 +173,8 @@ int MPIR_TSP_Ibcast_sched_intra_scatterv_allgatherv(void *buffer, MPI_Aint count mpi_errno = MPIR_TSP_sched_isend((char *) tmp_buf + displs[child], child_subtree_size[i], MPI_BYTE, - child, tag, comm, sched, num_send_dependencies, &recv_id, - &vtx_id); + child, tag, comm, coll_group, sched, num_send_dependencies, + &recv_id, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } @@ -188,13 +187,13 @@ int MPIR_TSP_Ibcast_sched_intra_scatterv_allgatherv(void *buffer, MPI_Aint count /* Schedule Allgatherv ring */ mpi_errno = MPIR_TSP_Iallgatherv_sched_intra_ring(MPI_IN_PLACE, cnts[rank], MPI_BYTE, tmp_buf, - cnts, displs, MPI_BYTE, comm, sched); + cnts, displs, MPI_BYTE, comm, coll_group, sched); else /* Schedule Allgatherv recexch */ mpi_errno = MPIR_TSP_Iallgatherv_sched_intra_recexch(MPI_IN_PLACE, cnts[rank], MPI_BYTE, tmp_buf, - cnts, displs, MPI_BYTE, comm, 0, allgatherv_k, - sched); + cnts, displs, MPI_BYTE, comm, coll_group, 0, + allgatherv_k, sched); MPIR_ERR_CHECK(mpi_errno); if (!is_contig) { diff --git a/src/mpi/coll/ibcast/ibcast_tsp_scatterv_ring_allgatherv.c b/src/mpi/coll/ibcast/ibcast_tsp_scatterv_ring_allgatherv.c index 25016705afa..1ea3e9f0f50 100644 --- a/src/mpi/coll/ibcast/ibcast_tsp_scatterv_ring_allgatherv.c +++ b/src/mpi/coll/ibcast/ibcast_tsp_scatterv_ring_allgatherv.c @@ -9,15 +9,15 @@ /* Routine to schedule a scatter followed by ring based broadcast */ int MPIR_TSP_Ibcast_sched_intra_scatterv_ring_allgatherv(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, - MPIR_Comm * comm, int scatterv_k, - MPIR_TSP_sched_t sched) + MPIR_Comm * comm, int coll_group, + int scatterv_k, MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; mpi_errno = MPIR_TSP_Ibcast_sched_intra_scatterv_allgatherv(buffer, count, datatype, root, - comm, + comm, coll_group, MPIR_CVAR_IALLGATHERV_INTRA_ALGORITHM_tsp_ring, scatterv_k, 0, sched); MPIR_ERR_CHECK(mpi_errno); diff --git a/src/mpi/coll/ibcast/ibcast_tsp_tree.c b/src/mpi/coll/ibcast/ibcast_tsp_tree.c index f308714e294..4ec7336c063 100644 --- a/src/mpi/coll/ibcast/ibcast_tsp_tree.c +++ b/src/mpi/coll/ibcast/ibcast_tsp_tree.c @@ -10,8 +10,8 @@ /* Routine to schedule a pipelined tree based broadcast */ int MPIR_TSP_Ibcast_sched_intra_tree(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, - MPIR_Comm * comm, int tree_type, int k, int chunk_size, - MPIR_TSP_sched_t sched) + MPIR_Comm * comm, int coll_group, int tree_type, int k, + int chunk_size, MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; int i; @@ -29,8 +29,7 @@ int MPIR_TSP_Ibcast_sched_intra_tree(void *buffer, MPI_Aint count, MPI_Datatype MPIR_FUNC_ENTER; - size = MPIR_Comm_size(comm); - rank = MPIR_Comm_rank(comm); + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, size); MPIR_Datatype_get_size_macro(datatype, type_size); MPIR_Datatype_get_extent_macro(datatype, extent); @@ -62,7 +61,7 @@ int MPIR_TSP_Ibcast_sched_intra_tree(void *buffer, MPI_Aint count, MPI_Datatype /* For correctness, transport based collectives need to get the * tag from the same pool as schedule based collectives */ - mpi_errno = MPIR_Sched_next_tag(comm, &tag); + mpi_errno = MPIR_Sched_next_tag(comm, coll_group, &tag); MPIR_ERR_CHECK(mpi_errno); /* Receive message from parent */ @@ -70,7 +69,7 @@ int MPIR_TSP_Ibcast_sched_intra_tree(void *buffer, MPI_Aint count, MPI_Datatype #ifdef HAVE_ERROR_CHECKING mpi_errno = MPIR_TSP_sched_irecv_status((char *) buffer + offset * extent, msgsize, - datatype, my_tree.parent, tag, comm, + datatype, my_tree.parent, tag, comm, coll_group, &ibcast_state->status, sched, 0, NULL, &recv_id); MPIR_ERR_CHECK(mpi_errno); MPIR_TSP_sched_cb(&MPII_Ibcast_sched_test_length, ibcast_state, sched, 1, &recv_id, @@ -78,7 +77,8 @@ int MPIR_TSP_Ibcast_sched_intra_tree(void *buffer, MPI_Aint count, MPI_Datatype #else mpi_errno = MPIR_TSP_sched_irecv((char *) buffer + offset * extent, msgsize, datatype, - my_tree.parent, tag, comm, sched, 0, NULL, &recv_id); + my_tree.parent, tag, comm, coll_group, sched, 0, NULL, + &recv_id); MPIR_ERR_CHECK(mpi_errno); #endif } @@ -87,8 +87,8 @@ int MPIR_TSP_Ibcast_sched_intra_tree(void *buffer, MPI_Aint count, MPI_Datatype /* Multicast data to the children */ mpi_errno = MPIR_TSP_sched_imcast((char *) buffer + offset * extent, msgsize, datatype, ut_int_array(my_tree.children), num_children, tag, - comm, sched, (my_tree.parent != -1) ? 1 : 0, &recv_id, - &vtx_id); + comm, coll_group, sched, + (my_tree.parent != -1) ? 1 : 0, &recv_id, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } offset += msgsize; diff --git a/src/mpi/coll/ibcast/ibcast_utils.c b/src/mpi/coll/ibcast/ibcast_utils.c index 9bfaf925b3a..09d035ac87a 100644 --- a/src/mpi/coll/ibcast/ibcast_utils.c +++ b/src/mpi/coll/ibcast/ibcast_utils.c @@ -68,16 +68,16 @@ int MPII_Ibcast_sched_add_length(MPIR_Comm * comm, int tag, void *state) /* This is a binomial scatter operation, but it does *not* take * typical scatter arguments. At the moment this function always * scatters a buffer of nbytes starting at tmp_buf address. */ -int MPII_Iscatter_for_bcast_sched(void *tmp_buf, int root, MPIR_Comm * comm_ptr, MPI_Aint nbytes, - MPIR_Sched_t s) +int MPII_Iscatter_for_bcast_sched(void *tmp_buf, int root, MPIR_Comm * comm_ptr, int coll_group, + MPI_Aint nbytes, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int rank, comm_size, src, dst; int relative_rank, mask; MPI_Aint scatter_size, curr_size, recv_size, send_size; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); + relative_rank = (rank >= root) ? rank - root : rank - root + comm_size; /* The scatter algorithm divides the buffer into nprocs pieces and @@ -110,7 +110,7 @@ int MPII_Iscatter_for_bcast_sched(void *tmp_buf, int root, MPIR_Comm * comm_ptr, if (recv_size > 0) { mpi_errno = MPIR_Sched_recv(((char *) tmp_buf + relative_rank * scatter_size), - recv_size, MPI_BYTE, src, comm_ptr, s); + recv_size, MPI_BYTE, src, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } @@ -135,7 +135,7 @@ int MPII_Iscatter_for_bcast_sched(void *tmp_buf, int root, MPIR_Comm * comm_ptr, dst -= comm_size; mpi_errno = MPIR_Sched_send(((char *) tmp_buf + scatter_size * (relative_rank + mask)), - send_size, MPI_BYTE, dst, comm_ptr, s); + send_size, MPI_BYTE, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); curr_size -= send_size; diff --git a/src/mpi/coll/iexscan/iexscan_intra_sched_recursive_doubling.c b/src/mpi/coll/iexscan/iexscan_intra_sched_recursive_doubling.c index 62148d159e4..799ad718131 100644 --- a/src/mpi/coll/iexscan/iexscan_intra_sched_recursive_doubling.c +++ b/src/mpi/coll/iexscan/iexscan_intra_sched_recursive_doubling.c @@ -50,7 +50,8 @@ */ int MPIR_Iexscan_intra_sched_recursive_doubling(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int rank, comm_size; @@ -58,8 +59,7 @@ int MPIR_Iexscan_intra_sched_recursive_doubling(const void *sendbuf, void *recvb MPI_Aint true_extent, true_lb, extent; void *partial_scan, *tmp_buf; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); is_commutative = MPIR_Op_is_commutative(op); @@ -89,10 +89,11 @@ int MPIR_Iexscan_intra_sched_recursive_doubling(const void *sendbuf, void *recvb dst = rank ^ mask; if (dst < comm_size) { /* Send partial_scan to dst. Recv into tmp_buf */ - mpi_errno = MPIR_Sched_send(partial_scan, count, datatype, dst, comm_ptr, s); + mpi_errno = + MPIR_Sched_send(partial_scan, count, datatype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); /* sendrecv, no barrier here */ - mpi_errno = MPIR_Sched_recv(tmp_buf, count, datatype, dst, comm_ptr, s); + mpi_errno = MPIR_Sched_recv(tmp_buf, count, datatype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); diff --git a/src/mpi/coll/igather/igather_inter_sched_long.c b/src/mpi/coll/igather/igather_inter_sched_long.c index fade46c3594..6da72eb5823 100644 --- a/src/mpi/coll/igather/igather_inter_sched_long.c +++ b/src/mpi/coll/igather/igather_inter_sched_long.c @@ -14,7 +14,7 @@ */ int MPIR_Igather_inter_sched_long(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm_ptr, MPIR_Sched_t s) + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; MPI_Aint remote_size; @@ -29,11 +29,11 @@ int MPIR_Igather_inter_sched_long(const void *sendbuf, MPI_Aint sendcount, MPI_D for (i = 0; i < remote_size; i++) { mpi_errno = MPIR_Sched_recv(((char *) recvbuf + recvcount * i * extent), - recvcount, recvtype, i, comm_ptr, s); + recvcount, recvtype, i, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } } else { - mpi_errno = MPIR_Sched_send(sendbuf, sendcount, sendtype, root, comm_ptr, s); + mpi_errno = MPIR_Sched_send(sendbuf, sendcount, sendtype, root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/igather/igather_inter_sched_short.c b/src/mpi/coll/igather/igather_inter_sched_short.c index 81dc2bc2ddf..f6d9b55f412 100644 --- a/src/mpi/coll/igather/igather_inter_sched_short.c +++ b/src/mpi/coll/igather/igather_inter_sched_short.c @@ -15,7 +15,7 @@ */ int MPIR_Igather_inter_sched_short(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm_ptr, MPIR_Sched_t s) + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int rank; @@ -30,7 +30,8 @@ int MPIR_Igather_inter_sched_short(const void *sendbuf, MPI_Aint sendcount, MPI_ mpi_errno = MPI_SUCCESS; } else if (root == MPI_ROOT) { /* root receives data from rank 0 on remote group */ - mpi_errno = MPIR_Sched_recv(recvbuf, recvcount * remote_size, recvtype, 0, comm_ptr, s); + mpi_errno = + MPIR_Sched_recv(recvbuf, recvcount * remote_size, recvtype, 0, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } else { /* remote group. Rank 0 allocates temporary buffer, does @@ -61,12 +62,12 @@ int MPIR_Igather_inter_sched_short(const void *sendbuf, MPI_Aint sendcount, MPI_ /* now do the a local gather on this intracommunicator */ mpi_errno = MPIR_Igather_intra_sched_auto(sendbuf, sendcount, sendtype, tmp_buf, sendcount * sendtype_sz, MPI_BYTE, 0, - newcomm_ptr, s); + newcomm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); if (rank == 0) { mpi_errno = MPIR_Sched_send(tmp_buf, sendcount * local_size * sendtype_sz, MPI_BYTE, - root, comm_ptr, s); + root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } } diff --git a/src/mpi/coll/igather/igather_intra_sched_binomial.c b/src/mpi/coll/igather/igather_intra_sched_binomial.c index 0d15129bba7..b0a0e15c957 100644 --- a/src/mpi/coll/igather/igather_intra_sched_binomial.c +++ b/src/mpi/coll/igather/igather_intra_sched_binomial.c @@ -28,7 +28,7 @@ int MPIR_Igather_intra_sched_binomial(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, MPIR_Comm * comm_ptr, - MPIR_Sched_t s) + int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int comm_size, rank; @@ -42,8 +42,7 @@ int MPIR_Igather_intra_sched_binomial(const void *sendbuf, MPI_Aint sendcount, int copy_offset = 0, copy_blks = 0; MPI_Datatype types[2], tmp_type; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); MPIR_Assert(comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM); @@ -93,15 +92,14 @@ int MPIR_Igather_intra_sched_binomial(const void *sendbuf, MPI_Aint sendcount, if (rank == root) { if (sendbuf != MPI_IN_PLACE) { mpi_errno = MPIR_Sched_copy(sendbuf, sendcount, sendtype, - ((char *) recvbuf + extent * recvcount * rank), + ((char *) recvbuf + extent * recvcount * rank), recvcount, recvtype, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } } else if (tmp_buf_size && (nbytes < MPIR_CVAR_GATHER_VSMALL_MSG_SIZE)) { /* copy from sendbuf into tmp_buf */ - mpi_errno = MPIR_Sched_copy(sendbuf, sendcount, sendtype, - tmp_buf, nbytes, MPI_BYTE, s); + mpi_errno = MPIR_Sched_copy(sendbuf, sendcount, sendtype, tmp_buf, nbytes, MPI_BYTE, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } @@ -128,14 +126,15 @@ int MPIR_Igather_intra_sched_binomial(const void *sendbuf, MPI_Aint sendcount, char *rp = (char *) recvbuf + (((rank + mask) % comm_size) * recvcount * extent); mpi_errno = - MPIR_Sched_recv(rp, (recvblks * recvcount), recvtype, src, comm_ptr, s); + MPIR_Sched_recv(rp, (recvblks * recvcount), recvtype, src, comm_ptr, + coll_group, s); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_Sched_barrier(s); MPIR_ERR_CHECK(mpi_errno); } else if (nbytes < MPIR_CVAR_GATHER_VSMALL_MSG_SIZE) { mpi_errno = MPIR_Sched_recv(tmp_buf, (recvblks * nbytes), MPI_BYTE, src, - comm_ptr, s); + comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_Sched_barrier(s); MPIR_ERR_CHECK(mpi_errno); @@ -154,7 +153,8 @@ int MPIR_Igather_intra_sched_binomial(const void *sendbuf, MPI_Aint sendcount, mpi_errno = MPIR_Type_commit_impl(&tmp_type); MPIR_ERR_CHECK(mpi_errno); - mpi_errno = MPIR_Sched_recv(recvbuf, 1, tmp_type, src, comm_ptr, s); + mpi_errno = + MPIR_Sched_recv(recvbuf, 1, tmp_type, src, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_Sched_barrier(s); MPIR_ERR_CHECK(mpi_errno); @@ -177,7 +177,7 @@ int MPIR_Igather_intra_sched_binomial(const void *sendbuf, MPI_Aint sendcount, offset = (mask - 1) * nbytes; mpi_errno = MPIR_Sched_recv(((char *) tmp_buf + offset), (recvblks * nbytes), - MPI_BYTE, src, comm_ptr, s); + MPI_BYTE, src, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_Sched_barrier(s); MPIR_ERR_CHECK(mpi_errno); @@ -190,12 +190,14 @@ int MPIR_Igather_intra_sched_binomial(const void *sendbuf, MPI_Aint sendcount, if (!tmp_buf_size) { /* leaf nodes send directly from sendbuf */ - mpi_errno = MPIR_Sched_send(sendbuf, sendcount, sendtype, dst, comm_ptr, s); + mpi_errno = + MPIR_Sched_send(sendbuf, sendcount, sendtype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_Sched_barrier(s); MPIR_ERR_CHECK(mpi_errno); } else if (nbytes < MPIR_CVAR_GATHER_VSMALL_MSG_SIZE) { - mpi_errno = MPIR_Sched_send(tmp_buf, curr_cnt, MPI_BYTE, dst, comm_ptr, s); + mpi_errno = + MPIR_Sched_send(tmp_buf, curr_cnt, MPI_BYTE, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_Sched_barrier(s); MPIR_ERR_CHECK(mpi_errno); @@ -220,7 +222,7 @@ int MPIR_Igather_intra_sched_binomial(const void *sendbuf, MPI_Aint sendcount, mpi_errno = MPIR_Type_commit_impl(&tmp_type); MPIR_ERR_CHECK(mpi_errno); - mpi_errno = MPIR_Sched_send(MPI_BOTTOM, 1, tmp_type, dst, comm_ptr, s); + mpi_errno = MPIR_Sched_send(MPI_BOTTOM, 1, tmp_type, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); diff --git a/src/mpi/coll/igather/igather_tsp_tree.c b/src/mpi/coll/igather/igather_tsp_tree.c index 45b4e6517ef..bbb5dee7059 100644 --- a/src/mpi/coll/igather/igather_tsp_tree.c +++ b/src/mpi/coll/igather/igather_tsp_tree.c @@ -11,7 +11,7 @@ int MPIR_TSP_Igather_sched_intra_tree(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, MPIR_Comm * comm, - int k, MPIR_TSP_sched_t sched) + int coll_group, int k, MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; int size, rank, lrank; @@ -32,8 +32,7 @@ int MPIR_TSP_Igather_sched_intra_tree(const void *sendbuf, MPI_Aint sendcount, MPIR_FUNC_ENTER; - size = MPIR_Comm_size(comm); - rank = MPIR_Comm_rank(comm); + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, size); lrank = (rank - root + size) % size; /* logical rank when root is non-zero */ if (rank == root) @@ -46,7 +45,7 @@ int MPIR_TSP_Igather_sched_intra_tree(const void *sendbuf, MPI_Aint sendcount, /* For correctness, transport based collectives need to get the * tag from the same pool as schedule based collectives */ - mpi_errno = MPIR_Sched_next_tag(comm, &tag); + mpi_errno = MPIR_Sched_next_tag(comm, coll_group, &tag); MPIR_ERR_CHECK(mpi_errno); if (rank == root && is_inplace) { @@ -135,8 +134,8 @@ int MPIR_TSP_Igather_sched_intra_tree(const void *sendbuf, MPI_Aint sendcount, /* Leaf nodes send to parent */ if (num_children == 0) { mpi_errno = - MPIR_TSP_sched_isend(tmp_buf, sendcount, sendtype, my_tree.parent, tag, comm, sched, 0, - NULL, &vtx_id); + MPIR_TSP_sched_isend(tmp_buf, sendcount, sendtype, my_tree.parent, tag, comm, + coll_group, sched, 0, NULL, &vtx_id); MPIR_ERR_CHECK(mpi_errno); MPL_DBG_MSG_FMT(MPIR_DBG_COLL, VERBOSE, (MPL_DBG_FDEST, "rank:%d posts recv\n", rank)); } else { @@ -160,13 +159,14 @@ int MPIR_TSP_Igather_sched_intra_tree(const void *sendbuf, MPI_Aint sendcount, mpi_errno = MPIR_TSP_sched_irecv((char *) tmp_buf + child_data_offset[i] * recvtype_extent, child_subtree_size[i] * recvcount, recvtype, child, tag, comm, - sched, num_dependencies, &dtcopy_id, &recv_id[i]); + coll_group, sched, num_dependencies, &dtcopy_id, &recv_id[i]); MPIR_ERR_CHECK(mpi_errno); } if (my_tree.parent != -1) { mpi_errno = MPIR_TSP_sched_isend(tmp_buf, recv_size, recvtype, my_tree.parent, - tag, comm, sched, num_children, recv_id, &vtx_id); + tag, comm, coll_group, sched, num_children, recv_id, + &vtx_id); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/igatherv/igatherv_allcomm_sched_linear.c b/src/mpi/coll/igatherv/igatherv_allcomm_sched_linear.c index e2f6cf21dc5..9660596ca43 100644 --- a/src/mpi/coll/igatherv/igatherv_allcomm_sched_linear.c +++ b/src/mpi/coll/igatherv/igatherv_allcomm_sched_linear.c @@ -15,22 +15,24 @@ int MPIR_Igatherv_allcomm_sched_linear(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint displs[], MPI_Datatype recvtype, int root, MPIR_Comm * comm_ptr, - MPIR_Sched_t s) + int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int i; int comm_size, rank; MPI_Aint extent; - rank = comm_ptr->rank; + if (comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM) { + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); + } else { + MPIR_Assert(coll_group == MPIR_SUBGROUP_NONE); + rank = comm_ptr->rank; + comm_size = comm_ptr->remote_size; + } /* If rank == root, then I recv lots, otherwise I send */ if (((comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM) && (root == rank)) || ((comm_ptr->comm_kind == MPIR_COMM_KIND__INTERCOMM) && (root == MPI_ROOT))) { - if (comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM) - comm_size = comm_ptr->local_size; - else - comm_size = comm_ptr->remote_size; MPIR_Datatype_get_extent_macro(recvtype, extent); @@ -45,7 +47,8 @@ int MPIR_Igatherv_allcomm_sched_linear(const void *sendbuf, MPI_Aint sendcount, } } else { mpi_errno = MPIR_Sched_recv(((char *) recvbuf + displs[i] * extent), - recvcounts[i], recvtype, i, comm_ptr, s); + recvcounts[i], recvtype, i, comm_ptr, coll_group, + s); MPIR_ERR_CHECK(mpi_errno); } } @@ -53,7 +56,8 @@ int MPIR_Igatherv_allcomm_sched_linear(const void *sendbuf, MPI_Aint sendcount, } else if (root != MPI_PROC_NULL) { /* non-root nodes, and in the intercomm. case, non-root nodes on remote side */ if (sendcount) { - mpi_errno = MPIR_Sched_send(sendbuf, sendcount, sendtype, root, comm_ptr, s); + mpi_errno = + MPIR_Sched_send(sendbuf, sendcount, sendtype, root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } } diff --git a/src/mpi/coll/igatherv/igatherv_tsp_linear.c b/src/mpi/coll/igatherv/igatherv_tsp_linear.c index 932efd3414d..f3a941a22da 100644 --- a/src/mpi/coll/igatherv/igatherv_tsp_linear.c +++ b/src/mpi/coll/igatherv/igatherv_tsp_linear.c @@ -21,7 +21,7 @@ int MPIR_TSP_Igatherv_sched_allcomm_linear(const void *sendbuf, MPI_Aint sendcou MPI_Datatype sendtype, void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint displs[], MPI_Datatype recvtype, int root, MPIR_Comm * comm_ptr, - MPIR_TSP_sched_t sched) + int coll_group, MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; int i, vtx_id; @@ -30,19 +30,19 @@ int MPIR_TSP_Igatherv_sched_allcomm_linear(const void *sendbuf, MPI_Aint sendcou int tag; MPIR_Errflag_t errflag ATTRIBUTE((unused)) = MPIR_ERR_NONE; - rank = comm_ptr->rank; + if (comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM) { + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); + } else { + rank = comm_ptr->rank; + comm_size = comm_ptr->remote_size; + } - mpi_errno = MPIR_Sched_next_tag(comm_ptr, &tag); + mpi_errno = MPIR_Sched_next_tag(comm_ptr, coll_group, &tag); MPIR_ERR_CHECK(mpi_errno); /* If rank == root, then I recv lots, otherwise I send */ if (((comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM) && (root == rank)) || ((comm_ptr->comm_kind == MPIR_COMM_KIND__INTERCOMM) && (root == MPI_ROOT))) { - if (comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM) - comm_size = comm_ptr->local_size; - else - comm_size = comm_ptr->remote_size; - MPIR_Datatype_get_extent_macro(recvtype, extent); for (i = 0; i < comm_size; i++) { @@ -58,7 +58,7 @@ int MPIR_TSP_Igatherv_sched_allcomm_linear(const void *sendbuf, MPI_Aint sendcou } else { mpi_errno = MPIR_TSP_sched_irecv(((char *) recvbuf + displs[i] * extent), recvcounts[i], recvtype, i, tag, comm_ptr, - sched, 0, NULL, &vtx_id); + coll_group, sched, 0, NULL, &vtx_id); } MPIR_ERR_CHECK(mpi_errno); } @@ -67,7 +67,7 @@ int MPIR_TSP_Igatherv_sched_allcomm_linear(const void *sendbuf, MPI_Aint sendcou /* non-root nodes, and in the intercomm. case, non-root nodes on remote side */ if (sendcount) { mpi_errno = MPIR_TSP_sched_isend(sendbuf, sendcount, sendtype, root, tag, comm_ptr, - sched, 0, NULL, &vtx_id); + coll_group, sched, 0, NULL, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } } diff --git a/src/mpi/coll/include/coll_impl.h b/src/mpi/coll/include/coll_impl.h index b5b576b1b33..cef96e59aeb 100644 --- a/src/mpi/coll/include/coll_impl.h +++ b/src/mpi/coll/include/coll_impl.h @@ -75,7 +75,7 @@ int MPII_Coll_finalize(void); mpi_errno = MPIR_Sched_create(&s, sched_kind); \ MPIR_ERR_CHECK(mpi_errno); \ int tag = -1; \ - mpi_errno = MPIR_Sched_next_tag(comm_ptr, &tag); \ + mpi_errno = MPIR_Sched_next_tag(comm_ptr, coll_group, &tag); \ MPIR_ERR_CHECK(mpi_errno); \ MPIR_Sched_set_tag(s, tag); \ *sched_type_p = MPIR_SCHED_NORMAL; \ diff --git a/src/mpi/coll/include/coll_types.h b/src/mpi/coll/include/coll_types.h index a32ce6c551d..22fbad4716b 100644 --- a/src/mpi/coll/include/coll_types.h +++ b/src/mpi/coll/include/coll_types.h @@ -13,16 +13,6 @@ #define MPIR_COLL_FLAG_REDUCE_L 1 #define MPIR_COLL_FLAG_REDUCE_R 0 -/* enumerator for different tree types */ -typedef enum MPIR_Tree_type_t { - MPIR_TREE_TYPE_KARY = 0, - MPIR_TREE_TYPE_KNOMIAL_1, - MPIR_TREE_TYPE_KNOMIAL_2, - MPIR_TREE_TYPE_TOPOLOGY_AWARE, - MPIR_TREE_TYPE_TOPOLOGY_AWARE_K, - MPIR_TREE_TYPE_TOPOLOGY_WAVE, -} MPIR_Tree_type_t; - /* enumerator for different recexch types */ enum { MPIR_IALLREDUCE_RECEXCH_TYPE_SINGLE_BUFFER = 0, diff --git a/src/mpi/coll/ineighbor_allgather/ineighbor_allgather_allcomm_sched_linear.c b/src/mpi/coll/ineighbor_allgather/ineighbor_allgather_allcomm_sched_linear.c index c4037e5b4e2..6d4a0dd84b4 100644 --- a/src/mpi/coll/ineighbor_allgather/ineighbor_allgather_allcomm_sched_linear.c +++ b/src/mpi/coll/ineighbor_allgather/ineighbor_allgather_allcomm_sched_linear.c @@ -36,13 +36,15 @@ int MPIR_Ineighbor_allgather_allcomm_sched_linear(const void *sendbuf, MPI_Aint MPIR_ERR_CHECK(mpi_errno); for (k = 0; k < outdegree; ++k) { - mpi_errno = MPIR_Sched_send(sendbuf, sendcount, sendtype, dsts[k], comm_ptr, s); + mpi_errno = + MPIR_Sched_send(sendbuf, sendcount, sendtype, dsts[k], comm_ptr, MPIR_SUBGROUP_NONE, s); MPIR_ERR_CHECK(mpi_errno); } for (l = 0; l < indegree; ++l) { char *rb = ((char *) recvbuf) + l * recvcount * recvtype_extent; - mpi_errno = MPIR_Sched_recv(rb, recvcount, recvtype, srcs[l], comm_ptr, s); + mpi_errno = + MPIR_Sched_recv(rb, recvcount, recvtype, srcs[l], comm_ptr, MPIR_SUBGROUP_NONE, s); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/ineighbor_allgather/ineighbor_allgather_tsp_linear.c b/src/mpi/coll/ineighbor_allgather/ineighbor_allgather_tsp_linear.c index 237275ce321..c0717d70ddd 100644 --- a/src/mpi/coll/ineighbor_allgather/ineighbor_allgather_tsp_linear.c +++ b/src/mpi/coll/ineighbor_allgather/ineighbor_allgather_tsp_linear.c @@ -38,21 +38,21 @@ int MPIR_TSP_Ineighbor_allgather_sched_allcomm_linear(const void *sendbuf, MPI_A /* For correctness, transport based collectives need to get the * tag from the same pool as schedule based collectives */ - mpi_errno = MPIR_Sched_next_tag(comm_ptr, &tag); + mpi_errno = MPIR_Sched_next_tag(comm_ptr, MPIR_SUBGROUP_NONE, &tag); MPIR_ERR_CHECK(mpi_errno); for (k = 0; k < outdegree; ++k) { mpi_errno = - MPIR_TSP_sched_isend(sendbuf, sendcount, sendtype, dsts[k], tag, comm_ptr, sched, 0, - NULL, &vtx_id); + MPIR_TSP_sched_isend(sendbuf, sendcount, sendtype, dsts[k], tag, comm_ptr, + MPIR_SUBGROUP_NONE, sched, 0, NULL, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } for (l = 0; l < indegree; ++l) { char *rb = ((char *) recvbuf) + l * recvcount * recvtype_extent; mpi_errno = - MPIR_TSP_sched_irecv(rb, recvcount, recvtype, srcs[l], tag, comm_ptr, sched, 0, NULL, - &vtx_id); + MPIR_TSP_sched_irecv(rb, recvcount, recvtype, srcs[l], tag, comm_ptr, + MPIR_SUBGROUP_NONE, sched, 0, NULL, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/ineighbor_allgatherv/ineighbor_allgatherv_allcomm_sched_linear.c b/src/mpi/coll/ineighbor_allgatherv/ineighbor_allgatherv_allcomm_sched_linear.c index b4c51feb741..d61720abedb 100644 --- a/src/mpi/coll/ineighbor_allgatherv/ineighbor_allgatherv_allcomm_sched_linear.c +++ b/src/mpi/coll/ineighbor_allgatherv/ineighbor_allgatherv_allcomm_sched_linear.c @@ -37,13 +37,15 @@ int MPIR_Ineighbor_allgatherv_allcomm_sched_linear(const void *sendbuf, MPI_Aint MPIR_ERR_CHECK(mpi_errno); for (k = 0; k < outdegree; ++k) { - mpi_errno = MPIR_Sched_send(sendbuf, sendcount, sendtype, dsts[k], comm_ptr, s); + mpi_errno = + MPIR_Sched_send(sendbuf, sendcount, sendtype, dsts[k], comm_ptr, MPIR_SUBGROUP_NONE, s); MPIR_ERR_CHECK(mpi_errno); } for (l = 0; l < indegree; ++l) { char *rb = ((char *) recvbuf) + displs[l] * recvtype_extent; - mpi_errno = MPIR_Sched_recv(rb, recvcounts[l], recvtype, srcs[l], comm_ptr, s); + mpi_errno = + MPIR_Sched_recv(rb, recvcounts[l], recvtype, srcs[l], comm_ptr, MPIR_SUBGROUP_NONE, s); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/ineighbor_allgatherv/ineighbor_allgatherv_tsp_linear.c b/src/mpi/coll/ineighbor_allgatherv/ineighbor_allgatherv_tsp_linear.c index f33f1cf7693..863b86e4973 100644 --- a/src/mpi/coll/ineighbor_allgatherv/ineighbor_allgatherv_tsp_linear.c +++ b/src/mpi/coll/ineighbor_allgatherv/ineighbor_allgatherv_tsp_linear.c @@ -39,21 +39,21 @@ int MPIR_TSP_Ineighbor_allgatherv_sched_allcomm_linear(const void *sendbuf, MPI_ /* For correctness, transport based collectives need to get the * tag from the same pool as schedule based collectives */ - mpi_errno = MPIR_Sched_next_tag(comm_ptr, &tag); + mpi_errno = MPIR_Sched_next_tag(comm_ptr, MPIR_SUBGROUP_NONE, &tag); MPIR_ERR_CHECK(mpi_errno); for (k = 0; k < outdegree; ++k) { mpi_errno = - MPIR_TSP_sched_isend(sendbuf, sendcount, sendtype, dsts[k], tag, comm_ptr, sched, 0, - NULL, &vtx_id); + MPIR_TSP_sched_isend(sendbuf, sendcount, sendtype, dsts[k], tag, comm_ptr, + MPIR_SUBGROUP_NONE, sched, 0, NULL, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } for (l = 0; l < indegree; ++l) { char *rb = ((char *) recvbuf) + displs[l] * recvtype_extent; mpi_errno = - MPIR_TSP_sched_irecv(rb, recvcounts[l], recvtype, srcs[l], tag, comm_ptr, sched, 0, - NULL, &vtx_id); + MPIR_TSP_sched_irecv(rb, recvcounts[l], recvtype, srcs[l], tag, comm_ptr, + MPIR_SUBGROUP_NONE, sched, 0, NULL, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/ineighbor_alltoall/ineighbor_alltoall_allcomm_sched_linear.c b/src/mpi/coll/ineighbor_alltoall/ineighbor_alltoall_allcomm_sched_linear.c index c593a5fe200..0c1a6333baf 100644 --- a/src/mpi/coll/ineighbor_alltoall/ineighbor_alltoall_allcomm_sched_linear.c +++ b/src/mpi/coll/ineighbor_alltoall/ineighbor_alltoall_allcomm_sched_linear.c @@ -38,7 +38,8 @@ int MPIR_Ineighbor_alltoall_allcomm_sched_linear(const void *sendbuf, MPI_Aint s for (k = 0; k < outdegree; ++k) { char *sb = ((char *) sendbuf) + k * sendcount * sendtype_extent; - mpi_errno = MPIR_Sched_send(sb, sendcount, sendtype, dsts[k], comm_ptr, s); + mpi_errno = + MPIR_Sched_send(sb, sendcount, sendtype, dsts[k], comm_ptr, MPIR_SUBGROUP_NONE, s); MPIR_ERR_CHECK(mpi_errno); } @@ -57,7 +58,8 @@ int MPIR_Ineighbor_alltoall_allcomm_sched_linear(const void *sendbuf, MPI_Aint s */ for (l = indegree - 1; l >= 0; l--) { char *rb = ((char *) recvbuf) + l * recvcount * recvtype_extent; - mpi_errno = MPIR_Sched_recv(rb, recvcount, recvtype, srcs[l], comm_ptr, s); + mpi_errno = + MPIR_Sched_recv(rb, recvcount, recvtype, srcs[l], comm_ptr, MPIR_SUBGROUP_NONE, s); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/ineighbor_alltoall/ineighbor_alltoall_tsp_linear.c b/src/mpi/coll/ineighbor_alltoall/ineighbor_alltoall_tsp_linear.c index 3131e163f0d..4b427bfbe73 100644 --- a/src/mpi/coll/ineighbor_alltoall/ineighbor_alltoall_tsp_linear.c +++ b/src/mpi/coll/ineighbor_alltoall/ineighbor_alltoall_tsp_linear.c @@ -39,15 +39,15 @@ int MPIR_TSP_Ineighbor_alltoall_sched_allcomm_linear(const void *sendbuf, MPI_Ai /* For correctness, transport based collectives need to get the * tag from the same pool as schedule based collectives */ - mpi_errno = MPIR_Sched_next_tag(comm_ptr, &tag); + mpi_errno = MPIR_Sched_next_tag(comm_ptr, MPIR_SUBGROUP_NONE, &tag); MPIR_ERR_CHECK(mpi_errno); for (k = 0; k < outdegree; ++k) { char *sb = ((char *) sendbuf) + k * sendcount * sendtype_extent; mpi_errno = - MPIR_TSP_sched_isend(sb, sendcount, sendtype, dsts[k], tag, comm_ptr, sched, 0, NULL, - &vtx_id); + MPIR_TSP_sched_isend(sb, sendcount, sendtype, dsts[k], tag, comm_ptr, + MPIR_SUBGROUP_NONE, sched, 0, NULL, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } @@ -57,8 +57,8 @@ int MPIR_TSP_Ineighbor_alltoall_sched_allcomm_linear(const void *sendbuf, MPI_Ai for (l = indegree - 1; l >= 0; l--) { char *rb = ((char *) recvbuf) + l * recvcount * recvtype_extent; mpi_errno = - MPIR_TSP_sched_irecv(rb, recvcount, recvtype, srcs[l], tag, comm_ptr, sched, 0, NULL, - &vtx_id); + MPIR_TSP_sched_irecv(rb, recvcount, recvtype, srcs[l], tag, comm_ptr, + MPIR_SUBGROUP_NONE, sched, 0, NULL, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/ineighbor_alltoallv/ineighbor_alltoallv_allcomm_sched_linear.c b/src/mpi/coll/ineighbor_alltoallv/ineighbor_alltoallv_allcomm_sched_linear.c index 95713065499..27f200a3933 100644 --- a/src/mpi/coll/ineighbor_alltoallv/ineighbor_alltoallv_allcomm_sched_linear.c +++ b/src/mpi/coll/ineighbor_alltoallv/ineighbor_alltoallv_allcomm_sched_linear.c @@ -39,7 +39,8 @@ int MPIR_Ineighbor_alltoallv_allcomm_sched_linear(const void *sendbuf, const MPI for (k = 0; k < outdegree; ++k) { char *sb = ((char *) sendbuf) + sdispls[k] * sendtype_extent; - mpi_errno = MPIR_Sched_send(sb, sendcounts[k], sendtype, dsts[k], comm_ptr, s); + mpi_errno = + MPIR_Sched_send(sb, sendcounts[k], sendtype, dsts[k], comm_ptr, MPIR_SUBGROUP_NONE, s); MPIR_ERR_CHECK(mpi_errno); } @@ -48,7 +49,8 @@ int MPIR_Ineighbor_alltoallv_allcomm_sched_linear(const void *sendbuf, const MPI * ref. ineighbor_alltoall_allcomm_sched_linear.c */ for (l = indegree - 1; l >= 0; l--) { char *rb = ((char *) recvbuf) + rdispls[l] * recvtype_extent; - mpi_errno = MPIR_Sched_recv(rb, recvcounts[l], recvtype, srcs[l], comm_ptr, s); + mpi_errno = + MPIR_Sched_recv(rb, recvcounts[l], recvtype, srcs[l], comm_ptr, MPIR_SUBGROUP_NONE, s); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/ineighbor_alltoallv/ineighbor_alltoallv_tsp_linear.c b/src/mpi/coll/ineighbor_alltoallv/ineighbor_alltoallv_tsp_linear.c index cdf14c9a6d7..87eead1a457 100644 --- a/src/mpi/coll/ineighbor_alltoallv/ineighbor_alltoallv_tsp_linear.c +++ b/src/mpi/coll/ineighbor_alltoallv/ineighbor_alltoallv_tsp_linear.c @@ -43,14 +43,14 @@ int MPIR_TSP_Ineighbor_alltoallv_sched_allcomm_linear(const void *sendbuf, /* For correctness, transport based collectives need to get the * tag from the same pool as schedule based collectives */ - mpi_errno = MPIR_Sched_next_tag(comm_ptr, &tag); + mpi_errno = MPIR_Sched_next_tag(comm_ptr, MPIR_SUBGROUP_NONE, &tag); MPIR_ERR_CHECK(mpi_errno); for (k = 0; k < outdegree; ++k) { char *sb = ((char *) sendbuf) + sdispls[k] * sendtype_extent; mpi_errno = - MPIR_TSP_sched_isend(sb, sendcounts[k], sendtype, dsts[k], tag, comm_ptr, sched, 0, - NULL, &vtx_id); + MPIR_TSP_sched_isend(sb, sendcounts[k], sendtype, dsts[k], tag, comm_ptr, + MPIR_SUBGROUP_NONE, sched, 0, NULL, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } @@ -60,8 +60,8 @@ int MPIR_TSP_Ineighbor_alltoallv_sched_allcomm_linear(const void *sendbuf, for (l = indegree - 1; l >= 0; l--) { char *rb = ((char *) recvbuf) + rdispls[l] * recvtype_extent; mpi_errno = - MPIR_TSP_sched_irecv(rb, recvcounts[l], recvtype, srcs[l], tag, comm_ptr, sched, 0, - NULL, &vtx_id); + MPIR_TSP_sched_irecv(rb, recvcounts[l], recvtype, srcs[l], tag, comm_ptr, + MPIR_SUBGROUP_NONE, sched, 0, NULL, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/ineighbor_alltoallw/ineighbor_alltoallw_allcomm_sched_linear.c b/src/mpi/coll/ineighbor_alltoallw/ineighbor_alltoallw_allcomm_sched_linear.c index dd9f143bcf2..5ccde6b4a7f 100644 --- a/src/mpi/coll/ineighbor_alltoallw/ineighbor_alltoallw_allcomm_sched_linear.c +++ b/src/mpi/coll/ineighbor_alltoallw/ineighbor_alltoallw_allcomm_sched_linear.c @@ -39,7 +39,9 @@ int MPIR_Ineighbor_alltoallw_allcomm_sched_linear(const void *sendbuf, const MPI char *sb; sb = ((char *) sendbuf) + sdispls[k]; - mpi_errno = MPIR_Sched_send(sb, sendcounts[k], sendtypes[k], dsts[k], comm_ptr, s); + mpi_errno = + MPIR_Sched_send(sb, sendcounts[k], sendtypes[k], dsts[k], comm_ptr, MPIR_SUBGROUP_NONE, + s); MPIR_ERR_CHECK(mpi_errno); } @@ -50,7 +52,9 @@ int MPIR_Ineighbor_alltoallw_allcomm_sched_linear(const void *sendbuf, const MPI char *rb; rb = ((char *) recvbuf) + rdispls[l]; - mpi_errno = MPIR_Sched_recv(rb, recvcounts[l], recvtypes[l], srcs[l], comm_ptr, s); + mpi_errno = + MPIR_Sched_recv(rb, recvcounts[l], recvtypes[l], srcs[l], comm_ptr, MPIR_SUBGROUP_NONE, + s); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/ineighbor_alltoallw/ineighbor_alltoallw_tsp_linear.c b/src/mpi/coll/ineighbor_alltoallw/ineighbor_alltoallw_tsp_linear.c index ee5b5b3872e..711f95a220d 100644 --- a/src/mpi/coll/ineighbor_alltoallw/ineighbor_alltoallw_tsp_linear.c +++ b/src/mpi/coll/ineighbor_alltoallw/ineighbor_alltoallw_tsp_linear.c @@ -39,7 +39,7 @@ int MPIR_TSP_Ineighbor_alltoallw_sched_allcomm_linear(const void *sendbuf, /* For correctness, transport based collectives need to get the * tag from the same pool as schedule based collectives */ - mpi_errno = MPIR_Sched_next_tag(comm_ptr, &tag); + mpi_errno = MPIR_Sched_next_tag(comm_ptr, MPIR_SUBGROUP_NONE, &tag); MPIR_ERR_CHECK(mpi_errno); for (k = 0; k < outdegree; ++k) { @@ -47,8 +47,8 @@ int MPIR_TSP_Ineighbor_alltoallw_sched_allcomm_linear(const void *sendbuf, sb = ((char *) sendbuf) + sdispls[k]; mpi_errno = - MPIR_TSP_sched_isend(sb, sendcounts[k], sendtypes[k], dsts[k], tag, comm_ptr, sched, 0, - NULL, &vtx_id); + MPIR_TSP_sched_isend(sb, sendcounts[k], sendtypes[k], dsts[k], tag, comm_ptr, + MPIR_SUBGROUP_NONE, sched, 0, NULL, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } @@ -60,8 +60,8 @@ int MPIR_TSP_Ineighbor_alltoallw_sched_allcomm_linear(const void *sendbuf, rb = ((char *) recvbuf) + rdispls[l]; mpi_errno = - MPIR_TSP_sched_irecv(rb, recvcounts[l], recvtypes[l], srcs[l], tag, comm_ptr, sched, 0, - NULL, &vtx_id); + MPIR_TSP_sched_irecv(rb, recvcounts[l], recvtypes[l], srcs[l], tag, comm_ptr, + MPIR_SUBGROUP_NONE, sched, 0, NULL, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/ireduce/ireduce_inter_sched_local_reduce_remote_send.c b/src/mpi/coll/ireduce/ireduce_inter_sched_local_reduce_remote_send.c index 683e18f3e6d..46af4fbb22f 100644 --- a/src/mpi/coll/ireduce/ireduce_inter_sched_local_reduce_remote_send.c +++ b/src/mpi/coll/ireduce/ireduce_inter_sched_local_reduce_remote_send.c @@ -14,7 +14,7 @@ int MPIR_Ireduce_inter_sched_local_reduce_remote_send(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, int root, MPIR_Comm * comm_ptr, - MPIR_Sched_t s) + int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int rank; @@ -30,7 +30,7 @@ int MPIR_Ireduce_inter_sched_local_reduce_remote_send(const void *sendbuf, void if (root == MPI_ROOT) { /* root receives data from rank 0 on remote group */ - mpi_errno = MPIR_Sched_recv(recvbuf, count, datatype, 0, comm_ptr, s); + mpi_errno = MPIR_Sched_recv(recvbuf, count, datatype, 0, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_Sched_barrier(s); MPIR_ERR_CHECK(mpi_errno); @@ -57,13 +57,13 @@ int MPIR_Ireduce_inter_sched_local_reduce_remote_send(const void *sendbuf, void mpi_errno = MPIR_Ireduce_intra_sched_auto(sendbuf, tmp_buf, count, datatype, op, 0, - comm_ptr->local_comm, s); + comm_ptr->local_comm, coll_group, s); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_Sched_barrier(s); MPIR_ERR_CHECK(mpi_errno); if (rank == 0) { - mpi_errno = MPIR_Sched_send(tmp_buf, count, datatype, root, comm_ptr, s); + mpi_errno = MPIR_Sched_send(tmp_buf, count, datatype, root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_Sched_barrier(s); MPIR_ERR_CHECK(mpi_errno); diff --git a/src/mpi/coll/ireduce/ireduce_intra_sched_binomial.c b/src/mpi/coll/ireduce/ireduce_intra_sched_binomial.c index 5ed414d2c09..d803296b1be 100644 --- a/src/mpi/coll/ireduce/ireduce_intra_sched_binomial.c +++ b/src/mpi/coll/ireduce/ireduce_intra_sched_binomial.c @@ -7,7 +7,7 @@ int MPIR_Ireduce_intra_sched_binomial(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, int root, - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int comm_size, rank, is_commutative; @@ -17,8 +17,7 @@ int MPIR_Ireduce_intra_sched_binomial(const void *sendbuf, void *recvbuf, MPI_Ai MPIR_Assert(comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM); - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); /* Create a temporary buffer */ @@ -93,7 +92,8 @@ int MPIR_Ireduce_intra_sched_binomial(const void *sendbuf, void *recvbuf, MPI_Ai source = (relrank | mask); if (source < comm_size) { source = (source + lroot) % comm_size; - mpi_errno = MPIR_Sched_recv(tmp_buf, count, datatype, source, comm_ptr, s); + mpi_errno = + MPIR_Sched_recv(tmp_buf, count, datatype, source, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_Sched_barrier(s); MPIR_ERR_CHECK(mpi_errno); @@ -119,7 +119,7 @@ int MPIR_Ireduce_intra_sched_binomial(const void *sendbuf, void *recvbuf, MPI_Ai /* I've received all that I'm going to. Send my result to * my parent */ source = ((relrank & (~mask)) + lroot) % comm_size; - mpi_errno = MPIR_Sched_send(recvbuf, count, datatype, source, comm_ptr, s); + mpi_errno = MPIR_Sched_send(recvbuf, count, datatype, source, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_Sched_barrier(s); MPIR_ERR_CHECK(mpi_errno); @@ -131,12 +131,12 @@ int MPIR_Ireduce_intra_sched_binomial(const void *sendbuf, void *recvbuf, MPI_Ai if (!is_commutative && (root != 0)) { if (rank == 0) { - mpi_errno = MPIR_Sched_send(recvbuf, count, datatype, root, comm_ptr, s); + mpi_errno = MPIR_Sched_send(recvbuf, count, datatype, root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_Sched_barrier(s); MPIR_ERR_CHECK(mpi_errno); } else if (rank == root) { - mpi_errno = MPIR_Sched_recv(recvbuf, count, datatype, 0, comm_ptr, s); + mpi_errno = MPIR_Sched_recv(recvbuf, count, datatype, 0, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_Sched_barrier(s); MPIR_ERR_CHECK(mpi_errno); diff --git a/src/mpi/coll/ireduce/ireduce_intra_sched_reduce_scatter_gather.c b/src/mpi/coll/ireduce/ireduce_intra_sched_reduce_scatter_gather.c index 2f35e5d6f93..501a45bd077 100644 --- a/src/mpi/coll/ireduce/ireduce_intra_sched_reduce_scatter_gather.c +++ b/src/mpi/coll/ireduce/ireduce_intra_sched_reduce_scatter_gather.c @@ -34,7 +34,8 @@ */ int MPIR_Ireduce_intra_sched_reduce_scatter_gather(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - int root, MPIR_Comm * comm_ptr, MPIR_Sched_t s) + int root, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int i, j, comm_size, rank, pof2, is_commutative ATTRIBUTE((unused)); @@ -44,8 +45,7 @@ int MPIR_Ireduce_intra_sched_reduce_scatter_gather(const void *sendbuf, void *re MPI_Aint true_lb, true_extent, extent; MPIR_CHKLMEM_DECL(2); - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); /* NOTE: this algorithm is currently only correct for commutative operations */ is_commutative = MPIR_Op_is_commutative(op); @@ -64,7 +64,7 @@ int MPIR_Ireduce_intra_sched_reduce_scatter_gather(const void *sendbuf, void *re tmp_buf = (void *) ((char *) tmp_buf - true_lb); /* get nearest power-of-two less than or equal to comm_size */ - pof2 = comm_ptr->coll.pof2; + pof2 = MPL_pof2(comm_size); #ifdef HAVE_ERROR_CHECKING MPIR_Assert(HANDLE_IS_BUILTIN(op)); @@ -104,7 +104,8 @@ int MPIR_Ireduce_intra_sched_reduce_scatter_gather(const void *sendbuf, void *re if (rank < 2 * rem) { if (rank % 2 != 0) { /* odd */ - mpi_errno = MPIR_Sched_send(recvbuf, count, datatype, rank - 1, comm_ptr, s); + mpi_errno = + MPIR_Sched_send(recvbuf, count, datatype, rank - 1, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); @@ -113,7 +114,8 @@ int MPIR_Ireduce_intra_sched_reduce_scatter_gather(const void *sendbuf, void *re * doubling */ newrank = -1; } else { /* even */ - mpi_errno = MPIR_Sched_recv(tmp_buf, count, datatype, rank + 1, comm_ptr, s); + mpi_errno = + MPIR_Sched_recv(tmp_buf, count, datatype, rank + 1, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); @@ -182,11 +184,11 @@ int MPIR_Ireduce_intra_sched_reduce_scatter_gather(const void *sendbuf, void *re /* Send data from recvbuf. Recv into tmp_buf */ mpi_errno = MPIR_Sched_send(((char *) recvbuf + disps[send_idx] * extent), - send_cnt, datatype, dst, comm_ptr, s); + send_cnt, datatype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); /* sendrecv, no barrier here */ mpi_errno = MPIR_Sched_recv(((char *) tmp_buf + disps[recv_idx] * extent), - recv_cnt, datatype, dst, comm_ptr, s); + recv_cnt, datatype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); @@ -231,7 +233,7 @@ int MPIR_Ireduce_intra_sched_reduce_scatter_gather(const void *sendbuf, void *re for (i = 1; i < pof2; i++) disps[i] = disps[i - 1] + cnts[i - 1]; - mpi_errno = MPIR_Sched_recv(recvbuf, cnts[0], datatype, 0, comm_ptr, s); + mpi_errno = MPIR_Sched_recv(recvbuf, cnts[0], datatype, 0, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); @@ -239,7 +241,8 @@ int MPIR_Ireduce_intra_sched_reduce_scatter_gather(const void *sendbuf, void *re send_idx = 0; last_idx = 2; } else if (newrank == 0) { /* send */ - mpi_errno = MPIR_Sched_send(recvbuf, cnts[0], datatype, root, comm_ptr, s); + mpi_errno = + MPIR_Sched_send(recvbuf, cnts[0], datatype, root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); newrank = -1; @@ -304,14 +307,14 @@ int MPIR_Ireduce_intra_sched_reduce_scatter_gather(const void *sendbuf, void *re /* send and exit */ /* Send data from recvbuf. Recv into tmp_buf */ mpi_errno = MPIR_Sched_send(((char *) recvbuf + disps[send_idx] * extent), - send_cnt, datatype, dst, comm_ptr, s); + send_cnt, datatype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); break; } else { /* recv and continue */ mpi_errno = MPIR_Sched_recv(((char *) recvbuf + disps[recv_idx] * extent), - recv_cnt, datatype, dst, comm_ptr, s); + recv_cnt, datatype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } diff --git a/src/mpi/coll/ireduce/ireduce_intra_sched_smp.c b/src/mpi/coll/ireduce/ireduce_intra_sched_smp.c index a1a8a63c72f..1e00f6ea3f8 100644 --- a/src/mpi/coll/ireduce/ireduce_intra_sched_smp.c +++ b/src/mpi/coll/ireduce/ireduce_intra_sched_smp.c @@ -7,32 +7,31 @@ int MPIR_Ireduce_intra_sched_smp(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, int root, MPIR_Comm * comm_ptr, - MPIR_Sched_t s) + int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int is_commutative; MPI_Aint true_lb, true_extent, extent; void *tmp_buf = NULL; - MPIR_Comm *nc; - MPIR_Comm *nrc; - MPIR_Assert(MPIR_Comm_is_parent_comm(comm_ptr)); + MPIR_Assert(MPIR_Comm_is_parent_comm(comm_ptr, coll_group)); MPIR_Assert(comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM); - - nc = comm_ptr->node_comm; - nrc = comm_ptr->node_roots_comm; + int local_rank = comm_ptr->subgroups[MPIR_SUBGROUP_NODE].rank; + int local_size = comm_ptr->subgroups[MPIR_SUBGROUP_NODE].size; + int local_root = MPIR_Get_intranode_rank(comm_ptr, root); /* is the op commutative? We do SMP optimizations only if it is. */ is_commutative = MPIR_Op_is_commutative(op); if (!is_commutative) { mpi_errno = - MPIR_Ireduce_intra_sched_auto(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, s); + MPIR_Ireduce_intra_sched_auto(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, + coll_group, s); MPIR_ERR_CHECK(mpi_errno); goto fn_exit; } /* Create a temporary buffer on local roots of all nodes */ - if (nrc != NULL) { + if (local_rank == 0) { MPIR_Type_get_true_extent_impl(datatype, &true_lb, &true_extent); MPIR_Datatype_get_extent_macro(datatype, extent); @@ -43,32 +42,32 @@ int MPIR_Ireduce_intra_sched_smp(const void *sendbuf, void *recvbuf, MPI_Aint co } /* do the intranode reduce on all nodes other than the root's node */ - if (nc != NULL && MPIR_Get_intranode_rank(comm_ptr, root) == -1) { - mpi_errno = MPIR_Ireduce_intra_sched_auto(sendbuf, tmp_buf, count, datatype, op, 0, nc, s); + if (local_size > 1 && local_root == -1) { + mpi_errno = MPIR_Ireduce_intra_sched_auto(sendbuf, tmp_buf, count, datatype, op, 0, + comm_ptr, MPIR_SUBGROUP_NODE, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } /* do the internode reduce to the root's node */ - if (nrc != NULL) { - if (nrc->rank != MPIR_Get_internode_rank(comm_ptr, root)) { + if (local_rank == 0) { + int inter_root = MPIR_Get_internode_rank(comm_ptr, root); + if (local_root < 0) { /* I am not on root's node. Use tmp_buf if we * participated in the first reduce, otherwise use sendbuf */ - const void *buf = (nc == NULL ? sendbuf : tmp_buf); - mpi_errno = MPIR_Ireduce_intra_sched_auto(buf, NULL, count, datatype, - op, MPIR_Get_internode_rank(comm_ptr, root), - nrc, s); + const void *buf = (local_size > 1 ? tmp_buf : sendbuf); + mpi_errno = MPIR_Ireduce_intra_sched_auto(buf, NULL, count, datatype, op, inter_root, + comm_ptr, MPIR_SUBGROUP_NODE_CROSS, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } else { /* I am on root's node. I have not participated in the earlier reduce. */ - if (comm_ptr->rank != root) { + if (local_rank != local_root) { /* I am not the root though. I don't have a valid recvbuf. * Use tmp_buf as recvbuf. */ mpi_errno = MPIR_Ireduce_intra_sched_auto(sendbuf, tmp_buf, count, datatype, - op, MPIR_Get_internode_rank(comm_ptr, - root), nrc, - s); + op, inter_root, comm_ptr, + MPIR_SUBGROUP_NODE_CROSS, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); @@ -78,9 +77,8 @@ int MPIR_Ireduce_intra_sched_smp(const void *sendbuf, void *recvbuf, MPI_Aint co /* I am the root. in_place is automatically handled. */ mpi_errno = MPIR_Ireduce_intra_sched_auto(sendbuf, recvbuf, count, datatype, - op, MPIR_Get_internode_rank(comm_ptr, - root), nrc, - s); + op, inter_root, comm_ptr, + MPIR_SUBGROUP_NODE_CROSS, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); @@ -91,10 +89,9 @@ int MPIR_Ireduce_intra_sched_smp(const void *sendbuf, void *recvbuf, MPI_Aint co } /* do the intranode reduce on the root's node */ - if (nc != NULL && MPIR_Get_intranode_rank(comm_ptr, root) != -1) { + if (local_size > 1 && MPIR_Get_intranode_rank(comm_ptr, root) != -1) { mpi_errno = MPIR_Ireduce_intra_sched_auto(sendbuf, recvbuf, count, datatype, - op, MPIR_Get_intranode_rank(comm_ptr, root), nc, - s); + op, local_root, comm_ptr, MPIR_SUBGROUP_NODE, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } diff --git a/src/mpi/coll/ireduce/ireduce_tsp_auto.c b/src/mpi/coll/ireduce/ireduce_tsp_auto.c index 1666e2dd029..f86966ccaf6 100644 --- a/src/mpi/coll/ireduce/ireduce_tsp_auto.c +++ b/src/mpi/coll/ireduce/ireduce_tsp_auto.c @@ -11,7 +11,7 @@ /* Remove this function when gentran algos are in json file */ static int MPIR_Ireduce_sched_intra_tsp_flat_auto(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - int root, MPIR_Comm * comm_ptr, + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; @@ -25,7 +25,7 @@ static int MPIR_Ireduce_sched_intra_tsp_flat_auto(const void *sendbuf, void *rec * gentran_tree algo */ /* gentran tree with knomial tree type, radix 2 and pipeline block size 0 */ mpi_errno = MPIR_TSP_Ireduce_sched_intra_tree(sendbuf, recvbuf, count, - datatype, op, root, comm_ptr, + datatype, op, root, comm_ptr, coll_group, tree_type, radix, block_size, buffer_per_child, sched); if (mpi_errno) @@ -40,13 +40,15 @@ static int MPIR_Ireduce_sched_intra_tsp_flat_auto(const void *sendbuf, void *rec /* sched version of CVAR and json based collective selection. Meant only for gentran scheduler */ int MPIR_TSP_Ireduce_sched_intra_tsp_auto(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, int root, - MPIR_Comm * comm_ptr, MPIR_TSP_sched_t sched) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; MPIR_Csel_coll_sig_s coll_sig = { .coll_type = MPIR_CSEL_COLL_TYPE__IREDUCE, .comm_ptr = comm_ptr, + .coll_group = coll_group, .u.ireduce.sendbuf = sendbuf, .u.ireduce.recvbuf = recvbuf, @@ -59,15 +61,21 @@ int MPIR_TSP_Ireduce_sched_intra_tsp_auto(const void *sendbuf, void *recvbuf, MP MPIR_Assert(comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM); + int rank, comm_size; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); + if (comm_size == 1) { + goto fn_exit; + } + switch (MPIR_CVAR_IREDUCE_INTRA_ALGORITHM) { case MPIR_CVAR_IREDUCE_INTRA_ALGORITHM_tsp_tree: /*Only knomial_1 tree supports non-commutative operations */ - MPII_COLLECTIVE_FALLBACK_CHECK(comm_ptr->rank, MPIR_Op_is_commutative(op) || + MPII_COLLECTIVE_FALLBACK_CHECK(rank, MPIR_Op_is_commutative(op) || MPIR_Ireduce_tree_type == MPIR_TREE_TYPE_KNOMIAL_1, mpi_errno, "Ireduce gentran_tree cannot be applied.\n"); mpi_errno = MPIR_TSP_Ireduce_sched_intra_tree(sendbuf, recvbuf, count, datatype, op, root, - comm_ptr, MPIR_Ireduce_tree_type, + comm_ptr, coll_group, MPIR_Ireduce_tree_type, MPIR_CVAR_IREDUCE_TREE_KVAL, MPIR_CVAR_IREDUCE_TREE_PIPELINE_CHUNK_SIZE, MPIR_CVAR_IREDUCE_TREE_BUFFER_PER_CHILD, sched); @@ -76,7 +84,7 @@ int MPIR_TSP_Ireduce_sched_intra_tsp_auto(const void *sendbuf, void *recvbuf, MP case MPIR_CVAR_IREDUCE_INTRA_ALGORITHM_tsp_ring: mpi_errno = MPIR_TSP_Ireduce_sched_intra_tree(sendbuf, recvbuf, count, datatype, op, root, - comm_ptr, MPIR_TREE_TYPE_KARY, 1, + comm_ptr, coll_group, MPIR_TREE_TYPE_KARY, 1, MPIR_CVAR_IREDUCE_RING_CHUNK_SIZE, MPIR_CVAR_IREDUCE_TREE_BUFFER_PER_CHILD, sched); break; @@ -89,7 +97,7 @@ int MPIR_TSP_Ireduce_sched_intra_tsp_auto(const void *sendbuf, void *recvbuf, MP case MPII_CSEL_CONTAINER_TYPE__ALGORITHM__MPIR_Ireduce_intra_tsp_tree: mpi_errno = MPIR_TSP_Ireduce_sched_intra_tree(sendbuf, recvbuf, count, datatype, op, - root, comm_ptr, + root, comm_ptr, coll_group, cnt->u.ireduce.intra_tsp_tree.tree_type, cnt->u.ireduce.intra_tsp_tree.k, cnt->u.ireduce.intra_tsp_tree.chunk_size, @@ -100,7 +108,8 @@ int MPIR_TSP_Ireduce_sched_intra_tsp_auto(const void *sendbuf, void *recvbuf, MP case MPII_CSEL_CONTAINER_TYPE__ALGORITHM__MPIR_Ireduce_intra_tsp_ring: mpi_errno = MPIR_TSP_Ireduce_sched_intra_tree(sendbuf, recvbuf, count, datatype, op, - root, comm_ptr, MPIR_TREE_TYPE_KARY, 1, + root, comm_ptr, coll_group, + MPIR_TREE_TYPE_KARY, 1, cnt->u.ireduce.intra_tsp_ring.chunk_size, cnt->u.ireduce. intra_tsp_ring.buffer_per_child, sched); @@ -110,7 +119,8 @@ int MPIR_TSP_Ireduce_sched_intra_tsp_auto(const void *sendbuf, void *recvbuf, MP /* Replace this call with MPIR_Assert(0) when json files have gentran algos */ mpi_errno = MPIR_Ireduce_sched_intra_tsp_flat_auto(sendbuf, recvbuf, count, - datatype, op, root, comm_ptr, sched); + datatype, op, root, comm_ptr, + coll_group, sched); break; } } @@ -120,7 +130,7 @@ int MPIR_TSP_Ireduce_sched_intra_tsp_auto(const void *sendbuf, void *recvbuf, MP fallback: mpi_errno = MPIR_TSP_Ireduce_sched_intra_tree(sendbuf, recvbuf, count, datatype, op, root, - comm_ptr, MPIR_TREE_TYPE_KARY, 1, + comm_ptr, coll_group, MPIR_TREE_TYPE_KARY, 1, MPIR_CVAR_IREDUCE_RING_CHUNK_SIZE, MPIR_CVAR_IREDUCE_TREE_BUFFER_PER_CHILD, sched); diff --git a/src/mpi/coll/ireduce/ireduce_tsp_tree.c b/src/mpi/coll/ireduce/ireduce_tsp_tree.c index 0ce1f25852a..4fbd0fe2ae5 100644 --- a/src/mpi/coll/ireduce/ireduce_tsp_tree.c +++ b/src/mpi/coll/ireduce/ireduce_tsp_tree.c @@ -10,8 +10,8 @@ /* Routine to schedule a pipelined tree based reduce */ int MPIR_TSP_Ireduce_sched_intra_tree(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, int root, - MPIR_Comm * comm, int tree_type, int k, int chunk_size, - int buffer_per_child, MPIR_TSP_sched_t sched) + MPIR_Comm * comm, int coll_group, int tree_type, int k, + int chunk_size, int buffer_per_child, MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; int i, j, t; @@ -42,8 +42,7 @@ int MPIR_TSP_Ireduce_sched_intra_tree(const void *sendbuf, void *recvbuf, MPI_Ai MPIR_FUNC_ENTER; - size = MPIR_Comm_size(comm); - rank = MPIR_Comm_rank(comm); + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, size); is_root = (rank == root); /* main algorithm */ @@ -70,12 +69,13 @@ int MPIR_TSP_Ireduce_sched_intra_tree(const void *sendbuf, void *recvbuf, MPI_Ai my_tree.children = NULL; if (tree_type == MPIR_TREE_TYPE_TOPOLOGY_AWARE || tree_type == MPIR_TREE_TYPE_TOPOLOGY_AWARE_K) { mpi_errno = - MPIR_Treealgo_tree_create_topo_aware(comm, tree_type, k, tree_root, + MPIR_Treealgo_tree_create_topo_aware(comm, coll_group, tree_type, k, tree_root, MPIR_CVAR_IREDUCE_TOPO_REORDER_ENABLE, &my_tree); } else if (tree_type == MPIR_TREE_TYPE_TOPOLOGY_WAVE) { MPIR_Csel_coll_sig_s coll_sig = { .coll_type = MPIR_CSEL_COLL_TYPE__IREDUCE, .comm_ptr = comm, + .coll_group = coll_group, .u.ireduce.sendbuf = sendbuf, .u.ireduce.recvbuf = recvbuf, .u.ireduce.count = count, @@ -100,7 +100,7 @@ int MPIR_TSP_Ireduce_sched_intra_tree(const void *sendbuf, void *recvbuf, MPI_Ai } mpi_errno = - MPIR_Treealgo_tree_create_topo_wave(comm, k, tree_root, + MPIR_Treealgo_tree_create_topo_wave(comm, coll_group, k, tree_root, MPIR_CVAR_IREDUCE_TOPO_REORDER_ENABLE, overhead, lat_diff_groups, lat_diff_switches, lat_same_switches, &my_tree); @@ -193,7 +193,7 @@ int MPIR_TSP_Ireduce_sched_intra_tree(const void *sendbuf, void *recvbuf, MPI_Ai /* For correctness, transport based collectives need to get the * tag from the same pool as schedule based collectives */ - mpi_errno = MPIR_Sched_next_tag(comm, &tag); + mpi_errno = MPIR_Sched_next_tag(comm, coll_group, &tag); MPIR_ERR_CHECK(mpi_errno); for (i = 0; i < num_children; i++) { @@ -213,8 +213,9 @@ int MPIR_TSP_Ireduce_sched_intra_tree(const void *sendbuf, void *recvbuf, MPI_Ai nvtcs = 1; } - mpi_errno = MPIR_TSP_sched_irecv(recv_address, msgsize, datatype, child, tag, comm, - sched, nvtcs, vtcs, &recv_id[i]); + mpi_errno = + MPIR_TSP_sched_irecv(recv_address, msgsize, datatype, child, tag, comm, coll_group, + sched, nvtcs, vtcs, &recv_id[i]); MPIR_ERR_CHECK(mpi_errno); /* Setup dependencies for reduction. Reduction depends on the corresponding recv to complete */ @@ -260,7 +261,7 @@ int MPIR_TSP_Ireduce_sched_intra_tree(const void *sendbuf, void *recvbuf, MPI_Ai if (!is_tree_root) { mpi_errno = MPIR_TSP_sched_isend(reduce_address, msgsize, datatype, my_tree.parent, tag, comm, - sched, nvtcs, vtcs, &vtx_id); + coll_group, sched, nvtcs, vtcs, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } @@ -268,12 +269,12 @@ int MPIR_TSP_Ireduce_sched_intra_tree(const void *sendbuf, void *recvbuf, MPI_Ai if (tree_root != root) { if (is_tree_root) { /* tree_root sends data to root */ mpi_errno = - MPIR_TSP_sched_isend(reduce_address, msgsize, datatype, root, tag, comm, sched, - nvtcs, vtcs, &vtx_id); + MPIR_TSP_sched_isend(reduce_address, msgsize, datatype, root, tag, comm, + coll_group, sched, nvtcs, vtcs, &vtx_id); } else if (is_root) { /* root receives data from tree_root */ mpi_errno = MPIR_TSP_sched_irecv((char *) recvbuf + offset * extent, msgsize, datatype, - tree_root, tag, comm, sched, 0, NULL, &vtx_id); + tree_root, tag, comm, coll_group, sched, 0, NULL, &vtx_id); } MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/ireduce_scatter/ireduce_scatter_inter_sched_remote_reduce_local_scatterv.c b/src/mpi/coll/ireduce_scatter/ireduce_scatter_inter_sched_remote_reduce_local_scatterv.c index f015ae31d5c..41c4d159d67 100644 --- a/src/mpi/coll/ireduce_scatter/ireduce_scatter_inter_sched_remote_reduce_local_scatterv.c +++ b/src/mpi/coll/ireduce_scatter/ireduce_scatter_inter_sched_remote_reduce_local_scatterv.c @@ -17,7 +17,7 @@ int MPIR_Ireduce_scatter_inter_sched_remote_reduce_local_scatterv(const void *se const MPI_Aint recvcounts[], MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, - MPIR_Sched_t s) + int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int rank, root, local_size, total_count, i; @@ -62,7 +62,7 @@ int MPIR_Ireduce_scatter_inter_sched_remote_reduce_local_scatterv(const void *se /* reduce from right group to rank 0 */ root = (rank == 0) ? MPI_ROOT : MPI_PROC_NULL; mpi_errno = MPIR_Ireduce_inter_sched_auto(sendbuf, tmp_buf, total_count, - datatype, op, root, comm_ptr, s); + datatype, op, root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); /* sched barrier intentionally omitted here to allow both reductions to @@ -71,13 +71,13 @@ int MPIR_Ireduce_scatter_inter_sched_remote_reduce_local_scatterv(const void *se /* reduce to rank 0 of right group */ root = 0; mpi_errno = MPIR_Ireduce_inter_sched_auto(sendbuf, tmp_buf, total_count, - datatype, op, root, comm_ptr, s); + datatype, op, root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } else { /* reduce to rank 0 of right group */ root = 0; mpi_errno = MPIR_Ireduce_inter_sched_auto(sendbuf, tmp_buf, total_count, - datatype, op, root, comm_ptr, s); + datatype, op, root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); /* sched barrier intentionally omitted here to allow both reductions to @@ -86,7 +86,7 @@ int MPIR_Ireduce_scatter_inter_sched_remote_reduce_local_scatterv(const void *se /* reduce from right group to rank 0 */ root = (rank == 0) ? MPI_ROOT : MPI_PROC_NULL; mpi_errno = MPIR_Ireduce_inter_sched_auto(sendbuf, tmp_buf, total_count, - datatype, op, root, comm_ptr, s); + datatype, op, root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } MPIR_SCHED_BARRIER(s); @@ -101,7 +101,7 @@ int MPIR_Ireduce_scatter_inter_sched_remote_reduce_local_scatterv(const void *se mpi_errno = MPIR_Iscatterv_intra_sched_auto(tmp_buf, recvcounts, disps, datatype, recvbuf, recvcounts[rank], datatype, 0, newcomm_ptr, - s); + coll_group, s); MPIR_ERR_CHECK(mpi_errno); fn_exit: diff --git a/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_noncommutative.c b/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_noncommutative.c index 863467231c9..5d23863a6e6 100644 --- a/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_noncommutative.c +++ b/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_noncommutative.c @@ -23,11 +23,11 @@ int MPIR_Ireduce_scatter_intra_sched_noncommutative(const void *sendbuf, void *recvbuf, const MPI_Aint recvcounts[], MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; - int comm_size = comm_ptr->local_size; - int rank = comm_ptr->rank; + int comm_size, rank; int log2_comm_size; int i, k; MPI_Aint true_extent, true_lb; @@ -36,6 +36,8 @@ int MPIR_Ireduce_scatter_intra_sched_noncommutative(const void *sendbuf, void *r void *tmp_buf1; void *result_ptr; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); + MPIR_Type_get_true_extent_impl(datatype, &true_lb, &true_extent); #ifdef HAVE_ERROR_CHECKING @@ -98,10 +100,10 @@ int MPIR_Ireduce_scatter_intra_sched_noncommutative(const void *sendbuf, void *r } mpi_errno = MPIR_Sched_send((outgoing_data + send_offset * true_extent), - size, datatype, peer, comm_ptr, s); + size, datatype, peer, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_Sched_recv((incoming_data + recv_offset * true_extent), - size, datatype, peer, comm_ptr, s); + size, datatype, peer, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); diff --git a/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_pairwise.c b/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_pairwise.c index 7b720f4951f..71a2756c623 100644 --- a/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_pairwise.c +++ b/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_pairwise.c @@ -14,7 +14,8 @@ */ int MPIR_Ireduce_scatter_intra_sched_pairwise(const void *sendbuf, void *recvbuf, const MPI_Aint recvcounts[], MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPI_Op op, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int rank, comm_size, i; @@ -23,8 +24,7 @@ int MPIR_Ireduce_scatter_intra_sched_pairwise(const void *sendbuf, void *recvbuf void *tmp_recvbuf; int src, dst; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); MPIR_Datatype_get_extent_macro(datatype, extent); MPIR_Type_get_true_extent_impl(datatype, &true_lb, &true_extent); @@ -72,14 +72,15 @@ int MPIR_Ireduce_scatter_intra_sched_pairwise(const void *sendbuf, void *recvbuf * needs from src into tmp_recvbuf */ if (sendbuf != MPI_IN_PLACE) { mpi_errno = MPIR_Sched_send(((char *) sendbuf + disps[dst] * extent), - recvcounts[dst], datatype, dst, comm_ptr, s); + recvcounts[dst], datatype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } else { mpi_errno = MPIR_Sched_send(((char *) recvbuf + disps[dst] * extent), - recvcounts[dst], datatype, dst, comm_ptr, s); + recvcounts[dst], datatype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } - mpi_errno = MPIR_Sched_recv(tmp_recvbuf, recvcounts[rank], datatype, src, comm_ptr, s); + mpi_errno = + MPIR_Sched_recv(tmp_recvbuf, recvcounts[rank], datatype, src, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); diff --git a/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_recursive_doubling.c b/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_recursive_doubling.c index 6b05fd0195a..2f1dc4de1db 100644 --- a/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_recursive_doubling.c +++ b/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_recursive_doubling.c @@ -19,7 +19,8 @@ int MPIR_Ireduce_scatter_intra_sched_recursive_doubling(const void *sendbuf, void *recvbuf, const MPI_Aint recvcounts[], MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int rank, comm_size, i; @@ -31,8 +32,7 @@ int MPIR_Ireduce_scatter_intra_sched_recursive_doubling(const void *sendbuf, voi MPI_Datatype sendtype, recvtype; int nprocs_completed, tmp_mask, tree_root, is_commutative; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); MPIR_Datatype_get_extent_macro(datatype, extent); MPIR_Type_get_true_extent_impl(datatype, &true_lb, &true_extent); @@ -140,9 +140,9 @@ int MPIR_Ireduce_scatter_intra_sched_recursive_doubling(const void *sendbuf, voi * received in tmp_recvbuf and then accumulated into * tmp_results. accumulation is done later below. */ - mpi_errno = MPIR_Sched_send(tmp_results, 1, sendtype, dst, comm_ptr, s); + mpi_errno = MPIR_Sched_send(tmp_results, 1, sendtype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); - mpi_errno = MPIR_Sched_recv(tmp_recvbuf, 1, recvtype, dst, comm_ptr, s); + mpi_errno = MPIR_Sched_recv(tmp_recvbuf, 1, recvtype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); received = 1; @@ -183,7 +183,8 @@ int MPIR_Ireduce_scatter_intra_sched_recursive_doubling(const void *sendbuf, voi if ((dst > rank) && (rank < tree_root + nprocs_completed) && (dst >= tree_root + nprocs_completed)) { /* send the current result */ - mpi_errno = MPIR_Sched_send(tmp_recvbuf, 1, recvtype, dst, comm_ptr, s); + mpi_errno = + MPIR_Sched_send(tmp_recvbuf, 1, recvtype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } @@ -192,7 +193,8 @@ int MPIR_Ireduce_scatter_intra_sched_recursive_doubling(const void *sendbuf, voi else if ((dst < rank) && (dst < tree_root + nprocs_completed) && (rank >= tree_root + nprocs_completed)) { - mpi_errno = MPIR_Sched_recv(tmp_recvbuf, 1, recvtype, dst, comm_ptr, s); + mpi_errno = + MPIR_Sched_recv(tmp_recvbuf, 1, recvtype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); received = 1; diff --git a/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_recursive_halving.c b/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_recursive_halving.c index 69814169823..aa5b1c2ad09 100644 --- a/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_recursive_halving.c +++ b/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_recursive_halving.c @@ -36,7 +36,8 @@ int MPIR_Ireduce_scatter_intra_sched_recursive_halving(const void *sendbuf, void *recvbuf, const MPI_Aint recvcounts[], MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int rank, comm_size, i; @@ -47,8 +48,7 @@ int MPIR_Ireduce_scatter_intra_sched_recursive_halving(const void *sendbuf, void int rem, newdst, send_idx, recv_idx, last_idx; int pof2, old_i, newrank; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); MPIR_Datatype_get_extent_macro(datatype, extent); MPIR_Type_get_true_extent_impl(datatype, &true_lb, &true_extent); @@ -95,7 +95,7 @@ int MPIR_Ireduce_scatter_intra_sched_recursive_halving(const void *sendbuf, void MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); - pof2 = comm_ptr->coll.pof2; + pof2 = MPL_pof2(comm_size); rem = comm_size - pof2; @@ -107,7 +107,9 @@ int MPIR_Ireduce_scatter_intra_sched_recursive_halving(const void *sendbuf, void if (rank < 2 * rem) { if (rank % 2 == 0) { /* even */ - mpi_errno = MPIR_Sched_send(tmp_results, total_count, datatype, rank + 1, comm_ptr, s); + mpi_errno = + MPIR_Sched_send(tmp_results, total_count, datatype, rank + 1, comm_ptr, coll_group, + s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); @@ -116,7 +118,9 @@ int MPIR_Ireduce_scatter_intra_sched_recursive_halving(const void *sendbuf, void * doubling */ newrank = -1; } else { /* odd */ - mpi_errno = MPIR_Sched_recv(tmp_recvbuf, total_count, datatype, rank - 1, comm_ptr, s); + mpi_errno = + MPIR_Sched_recv(tmp_recvbuf, total_count, datatype, rank - 1, comm_ptr, coll_group, + s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); @@ -190,10 +194,10 @@ int MPIR_Ireduce_scatter_intra_sched_recursive_halving(const void *sendbuf, void int recv_dst = (recv_cnt ? dst : MPI_PROC_NULL); mpi_errno = MPIR_Sched_send(((char *) tmp_results + newdisps[send_idx] * extent), - send_cnt, datatype, send_dst, comm_ptr, s); + send_cnt, datatype, send_dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_Sched_recv(((char *) tmp_recvbuf + newdisps[recv_idx] * extent), - recv_cnt, datatype, recv_dst, comm_ptr, s); + recv_cnt, datatype, recv_dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } @@ -231,14 +235,16 @@ int MPIR_Ireduce_scatter_intra_sched_recursive_halving(const void *sendbuf, void if (rank % 2) { /* odd */ if (recvcounts[rank - 1]) { mpi_errno = MPIR_Sched_send(((char *) tmp_results + disps[rank - 1] * extent), - recvcounts[rank - 1], datatype, rank - 1, comm_ptr, s); + recvcounts[rank - 1], datatype, rank - 1, comm_ptr, + coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } } else { /* even */ if (recvcounts[rank]) { mpi_errno = - MPIR_Sched_recv(recvbuf, recvcounts[rank], datatype, rank + 1, comm_ptr, s); + MPIR_Sched_recv(recvbuf, recvcounts[rank], datatype, rank + 1, comm_ptr, + coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } diff --git a/src/mpi/coll/ireduce_scatter/ireduce_scatter_tsp_recexch.c b/src/mpi/coll/ireduce_scatter/ireduce_scatter_tsp_recexch.c index c1e3bbc582b..1d0e23bf409 100644 --- a/src/mpi/coll/ireduce_scatter/ireduce_scatter_tsp_recexch.c +++ b/src/mpi/coll/ireduce_scatter/ireduce_scatter_tsp_recexch.c @@ -41,11 +41,11 @@ int MPIR_TSP_Ireduce_scatter_sched_intra_recexch_step2(void *tmp_results, void * const MPI_Aint * recvcounts, MPI_Aint * displs, MPI_Datatype datatype, MPI_Op op, size_t extent, int tag, - MPIR_Comm * comm, int k, int is_dist_halving, - int step2_nphases, int **step2_nbrs, - int rank, int nranks, int sink_id, - int is_out_vtcs, int *reduce_id_, - MPIR_TSP_sched_t sched) + MPIR_Comm * comm, int coll_group, int k, + int is_dist_halving, int step2_nphases, + int **step2_nbrs, int rank, int nranks, + int sink_id, int is_out_vtcs, + int *reduce_id_, MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; int x, i, j, phase; @@ -86,7 +86,8 @@ int MPIR_TSP_Ireduce_scatter_sched_intra_recexch_step2(void *tmp_results, void * send_offset, send_cnt)); mpi_errno = MPIR_TSP_sched_isend((char *) tmp_results + send_offset, send_cnt, - datatype, dst, tag, comm, sched, nvtcs, vtcs, &send_id); + datatype, dst, tag, comm, coll_group, sched, nvtcs, vtcs, + &send_id); MPIR_ERR_CHECK(mpi_errno); rank_for_offset = @@ -103,7 +104,8 @@ int MPIR_TSP_Ireduce_scatter_sched_intra_recexch_step2(void *tmp_results, void * recv_offset, recv_cnt)); mpi_errno = MPIR_TSP_sched_irecv((char *) tmp_recvbuf + recv_offset, recv_cnt, - datatype, dst, tag, comm, sched, nvtcs, vtcs, &recv_id); + datatype, dst, tag, comm, coll_group, sched, nvtcs, vtcs, + &recv_id); MPIR_ERR_CHECK(mpi_errno); @@ -132,8 +134,8 @@ int MPIR_TSP_Ireduce_scatter_sched_intra_recexch_step2(void *tmp_results, void * /* Routine to schedule a recursive exchange based reduce_scatter with distance halving in each phase */ int MPIR_TSP_Ireduce_scatter_sched_intra_recexch(const void *sendbuf, void *recvbuf, const MPI_Aint * recvcounts, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm, int is_dist_halving, - int k, MPIR_TSP_sched_t sched) + MPI_Op op, MPIR_Comm * comm, int coll_group, + int is_dist_halving, int k, MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; int is_inplace; @@ -157,11 +159,10 @@ int MPIR_TSP_Ireduce_scatter_sched_intra_recexch(const void *sendbuf, void *recv /* For correctness, transport based collectives need to get the * tag from the same pool as schedule based collectives */ - mpi_errno = MPIR_Sched_next_tag(comm, &tag); + mpi_errno = MPIR_Sched_next_tag(comm, coll_group, &tag); is_inplace = (sendbuf == MPI_IN_PLACE); - nranks = MPIR_Comm_size(comm); - rank = MPIR_Comm_rank(comm); + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, nranks); MPIR_Datatype_get_extent_macro(datatype, extent); MPIR_Type_get_true_extent_impl(datatype, &lb, &true_extent); @@ -216,8 +217,8 @@ int MPIR_TSP_Ireduce_scatter_sched_intra_recexch(const void *sendbuf, void *recv else buf_to_send = (void *) sendbuf; mpi_errno = - MPIR_TSP_sched_isend(buf_to_send, total_count, datatype, step1_sendto, tag, comm, sched, - 0, NULL, &vtx_id); + MPIR_TSP_sched_isend(buf_to_send, total_count, datatype, step1_sendto, tag, comm, + coll_group, sched, 0, NULL, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } else { /* Step 2 participating rank */ @@ -226,8 +227,8 @@ int MPIR_TSP_Ireduce_scatter_sched_intra_recexch(const void *sendbuf, void *recv nvtcs = 1; vtcs[0] = (i == 0) ? dtcopy_id : reduce_id; mpi_errno = MPIR_TSP_sched_irecv(tmp_recvbuf, total_count, datatype, - step1_recvfrom[i], tag, comm, sched, nvtcs, vtcs, - &recv_id); + step1_recvfrom[i], tag, comm, coll_group, sched, nvtcs, + vtcs, &recv_id); MPIR_ERR_CHECK(mpi_errno); nvtcs++; vtcs[1] = recv_id; @@ -246,9 +247,10 @@ int MPIR_TSP_Ireduce_scatter_sched_intra_recexch(const void *sendbuf, void *recv if (in_step2) { MPIR_TSP_Ireduce_scatter_sched_intra_recexch_step2(tmp_results, tmp_recvbuf, recvcounts, displs, datatype, op, extent, - tag, comm, k, is_dist_halving, - step2_nphases, step2_nbrs, rank, nranks, - sink_id, 1, &reduce_id, sched); + tag, comm, coll_group, k, + is_dist_halving, step2_nphases, + step2_nbrs, rank, nranks, sink_id, 1, + &reduce_id, sched); /* copy data from tmp_results buffer correct position into recvbuf for all participating ranks */ nvtcs = 1; vtcs[0] = reduce_id; /* This assignment will also be used in step3 sends */ @@ -265,7 +267,7 @@ int MPIR_TSP_Ireduce_scatter_sched_intra_recexch(const void *sendbuf, void *recv if (step1_sendto != -1) { /* I am a Step 2 non-participating rank */ mpi_errno = MPIR_TSP_sched_irecv(recvbuf, recvcounts[rank], datatype, step1_sendto, tag, comm, - sched, 1, &sink_id, &vtx_id); + coll_group, sched, 1, &sink_id, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } for (i = 0; i < step1_nrecvs; i++) { @@ -273,7 +275,7 @@ int MPIR_TSP_Ireduce_scatter_sched_intra_recexch(const void *sendbuf, void *recv /* vtcs will be assigned to last reduce_id in step2 function */ mpi_errno = MPIR_TSP_sched_isend((char *) tmp_results + displs[step1_recvfrom[i]] * extent, recvcounts[step1_recvfrom[i]], datatype, step1_recvfrom[i], - tag, comm, sched, nvtcs, vtcs, &vtx_id); + tag, comm, coll_group, sched, nvtcs, vtcs, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_inter_sched_remote_reduce_local_scatterv.c b/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_inter_sched_remote_reduce_local_scatterv.c index 793db9a8bd8..8c9629fe642 100644 --- a/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_inter_sched_remote_reduce_local_scatterv.c +++ b/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_inter_sched_remote_reduce_local_scatterv.c @@ -18,6 +18,7 @@ int MPIR_Ireduce_scatter_block_inter_sched_remote_reduce_local_scatterv(const vo MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, + int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; @@ -51,7 +52,7 @@ int MPIR_Ireduce_scatter_block_inter_sched_remote_reduce_local_scatterv(const vo /* reduce from right group to rank 0 */ root = (rank == 0) ? MPI_ROOT : MPI_PROC_NULL; mpi_errno = MPIR_Ireduce_inter_sched_auto(sendbuf, tmp_buf, total_count, - datatype, op, root, comm_ptr, s); + datatype, op, root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); /* sched barrier intentionally omitted here to allow both reductions to @@ -60,13 +61,13 @@ int MPIR_Ireduce_scatter_block_inter_sched_remote_reduce_local_scatterv(const vo /* reduce to rank 0 of right group */ root = 0; mpi_errno = MPIR_Ireduce_inter_sched_auto(sendbuf, tmp_buf, total_count, - datatype, op, root, comm_ptr, s); + datatype, op, root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } else { /* reduce to rank 0 of right group */ root = 0; mpi_errno = MPIR_Ireduce_inter_sched_auto(sendbuf, tmp_buf, total_count, - datatype, op, root, comm_ptr, s); + datatype, op, root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); /* sched barrier intentionally omitted here to allow both reductions to @@ -75,7 +76,7 @@ int MPIR_Ireduce_scatter_block_inter_sched_remote_reduce_local_scatterv(const vo /* reduce from right group to rank 0 */ root = (rank == 0) ? MPI_ROOT : MPI_PROC_NULL; mpi_errno = MPIR_Ireduce_inter_sched_auto(sendbuf, tmp_buf, total_count, - datatype, op, root, comm_ptr, s); + datatype, op, root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } MPIR_SCHED_BARRIER(s); @@ -89,7 +90,8 @@ int MPIR_Ireduce_scatter_block_inter_sched_remote_reduce_local_scatterv(const vo newcomm_ptr = comm_ptr->local_comm; mpi_errno = MPIR_Iscatter_intra_sched_auto(tmp_buf, recvcount, datatype, - recvbuf, recvcount, datatype, 0, newcomm_ptr, s); + recvbuf, recvcount, datatype, 0, newcomm_ptr, + coll_group, s); MPIR_ERR_CHECK(mpi_errno); fn_exit: diff --git a/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_intra_sched_noncommutative.c b/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_intra_sched_noncommutative.c index 910c04314fe..080e957245d 100644 --- a/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_intra_sched_noncommutative.c +++ b/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_intra_sched_noncommutative.c @@ -12,11 +12,10 @@ int MPIR_Ireduce_scatter_block_intra_sched_noncommutative(const void *sendbuf, void *recvbuf, MPI_Aint recvcount, MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, - MPIR_Sched_t s) + int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; - int comm_size = comm_ptr->local_size; - int rank = comm_ptr->rank; + int comm_size, rank; int log2_comm_size; int i, k; MPI_Aint true_extent, true_lb; @@ -25,6 +24,8 @@ int MPIR_Ireduce_scatter_block_intra_sched_noncommutative(const void *sendbuf, v void *tmp_buf1; void *result_ptr; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); + MPIR_Type_get_true_extent_impl(datatype, &true_lb, &true_extent); #ifdef HAVE_ERROR_CHECKING @@ -84,10 +85,10 @@ int MPIR_Ireduce_scatter_block_intra_sched_noncommutative(const void *sendbuf, v } mpi_errno = MPIR_Sched_send((outgoing_data + send_offset * true_extent), - size, datatype, peer, comm_ptr, s); + size, datatype, peer, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_Sched_recv((incoming_data + recv_offset * true_extent), - size, datatype, peer, comm_ptr, s); + size, datatype, peer, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); diff --git a/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_intra_sched_pairwise.c b/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_intra_sched_pairwise.c index b5091c80817..0f03670c6a5 100644 --- a/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_intra_sched_pairwise.c +++ b/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_intra_sched_pairwise.c @@ -9,7 +9,8 @@ * commutative op and is intended for use with large messages. */ int MPIR_Ireduce_scatter_block_intra_sched_pairwise(const void *sendbuf, void *recvbuf, MPI_Aint recvcount, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPI_Op op, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int rank, comm_size, i; @@ -19,8 +20,7 @@ int MPIR_Ireduce_scatter_block_intra_sched_pairwise(const void *sendbuf, void *r int src, dst; MPI_Aint total_count; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); MPIR_Datatype_get_extent_macro(datatype, extent); MPIR_Type_get_true_extent_impl(datatype, &true_lb, &true_extent); @@ -62,14 +62,14 @@ int MPIR_Ireduce_scatter_block_intra_sched_pairwise(const void *sendbuf, void *r * needs from src into tmp_recvbuf */ if (sendbuf != MPI_IN_PLACE) { mpi_errno = MPIR_Sched_send(((char *) sendbuf + disps[dst] * extent), - recvcount, datatype, dst, comm_ptr, s); + recvcount, datatype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } else { mpi_errno = MPIR_Sched_send(((char *) recvbuf + disps[dst] * extent), - recvcount, datatype, dst, comm_ptr, s); + recvcount, datatype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } - mpi_errno = MPIR_Sched_recv(tmp_recvbuf, recvcount, datatype, src, comm_ptr, s); + mpi_errno = MPIR_Sched_recv(tmp_recvbuf, recvcount, datatype, src, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); diff --git a/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_intra_sched_recursive_doubling.c b/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_intra_sched_recursive_doubling.c index 071682730f4..3554a64d6a5 100644 --- a/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_intra_sched_recursive_doubling.c +++ b/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_intra_sched_recursive_doubling.c @@ -10,7 +10,8 @@ int MPIR_Ireduce_scatter_block_intra_sched_recursive_doubling(const void *sendbuf, void *recvbuf, MPI_Aint recvcount, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int rank, comm_size, i; @@ -22,8 +23,7 @@ int MPIR_Ireduce_scatter_block_intra_sched_recursive_doubling(const void *sendbu MPI_Datatype sendtype, recvtype; int nprocs_completed, tmp_mask, tree_root, is_commutative; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); MPIR_Datatype_get_extent_macro(datatype, extent); MPIR_Type_get_true_extent_impl(datatype, &true_lb, &true_extent); @@ -123,9 +123,9 @@ int MPIR_Ireduce_scatter_block_intra_sched_recursive_doubling(const void *sendbu /* tmp_results contains data to be sent in each step. Data is * received in tmp_recvbuf and then accumulated into * tmp_results. accumulation is done later below. */ - mpi_errno = MPIR_Sched_send(tmp_results, 1, sendtype, dst, comm_ptr, s); + mpi_errno = MPIR_Sched_send(tmp_results, 1, sendtype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); - mpi_errno = MPIR_Sched_recv(tmp_recvbuf, 1, recvtype, dst, comm_ptr, s); + mpi_errno = MPIR_Sched_recv(tmp_recvbuf, 1, recvtype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); received = 1; @@ -166,7 +166,8 @@ int MPIR_Ireduce_scatter_block_intra_sched_recursive_doubling(const void *sendbu if ((dst > rank) && (rank < tree_root + nprocs_completed) && (dst >= tree_root + nprocs_completed)) { /* send the current result */ - mpi_errno = MPIR_Sched_send(tmp_recvbuf, 1, recvtype, dst, comm_ptr, s); + mpi_errno = + MPIR_Sched_send(tmp_recvbuf, 1, recvtype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } @@ -175,7 +176,8 @@ int MPIR_Ireduce_scatter_block_intra_sched_recursive_doubling(const void *sendbu else if ((dst < rank) && (dst < tree_root + nprocs_completed) && (rank >= tree_root + nprocs_completed)) { - mpi_errno = MPIR_Sched_recv(tmp_recvbuf, 1, recvtype, dst, comm_ptr, s); + mpi_errno = + MPIR_Sched_recv(tmp_recvbuf, 1, recvtype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); received = 1; diff --git a/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_intra_sched_recursive_halving.c b/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_intra_sched_recursive_halving.c index 2a092a14721..34c3b68db34 100644 --- a/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_intra_sched_recursive_halving.c +++ b/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_intra_sched_recursive_halving.c @@ -10,7 +10,8 @@ int MPIR_Ireduce_scatter_block_intra_sched_recursive_halving(const void *sendbuf, void *recvbuf, MPI_Aint recvcount, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int rank, comm_size, i; @@ -21,8 +22,7 @@ int MPIR_Ireduce_scatter_block_intra_sched_recursive_halving(const void *sendbuf int rem, newdst, send_idx, recv_idx, last_idx, send_cnt, recv_cnt; int pof2, old_i, newrank; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); MPIR_Datatype_get_extent_macro(datatype, extent); MPIR_Type_get_true_extent_impl(datatype, &true_lb, &true_extent); @@ -76,7 +76,9 @@ int MPIR_Ireduce_scatter_block_intra_sched_recursive_halving(const void *sendbuf if (rank < 2 * rem) { if (rank % 2 == 0) { /* even */ - mpi_errno = MPIR_Sched_send(tmp_results, total_count, datatype, rank + 1, comm_ptr, s); + mpi_errno = + MPIR_Sched_send(tmp_results, total_count, datatype, rank + 1, comm_ptr, coll_group, + s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); @@ -85,7 +87,9 @@ int MPIR_Ireduce_scatter_block_intra_sched_recursive_halving(const void *sendbuf * doubling */ newrank = -1; } else { /* odd */ - mpi_errno = MPIR_Sched_recv(tmp_recvbuf, total_count, datatype, rank - 1, comm_ptr, s); + mpi_errno = + MPIR_Sched_recv(tmp_recvbuf, total_count, datatype, rank - 1, comm_ptr, coll_group, + s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); @@ -158,10 +162,10 @@ int MPIR_Ireduce_scatter_block_intra_sched_recursive_halving(const void *sendbuf int recv_dst = (recv_cnt ? dst : MPI_PROC_NULL); mpi_errno = MPIR_Sched_send(((char *) tmp_results + newdisps[send_idx] * extent), - send_cnt, datatype, send_dst, comm_ptr, s); + send_cnt, datatype, send_dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_Sched_recv(((char *) tmp_recvbuf + newdisps[recv_idx] * extent), - recv_cnt, datatype, recv_dst, comm_ptr, s); + recv_cnt, datatype, recv_dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } @@ -196,11 +200,12 @@ int MPIR_Ireduce_scatter_block_intra_sched_recursive_halving(const void *sendbuf if (rank < 2 * rem) { if (rank % 2) { /* odd */ mpi_errno = MPIR_Sched_send(((char *) tmp_results + disps[rank - 1] * extent), - recvcount, datatype, rank - 1, comm_ptr, s); + recvcount, datatype, rank - 1, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } else { /* even */ - mpi_errno = MPIR_Sched_recv(recvbuf, recvcount, datatype, rank + 1, comm_ptr, s); + mpi_errno = + MPIR_Sched_recv(recvbuf, recvcount, datatype, rank + 1, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } diff --git a/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_tsp_recexch.c b/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_tsp_recexch.c index d0c2839ec81..70aa2681896 100644 --- a/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_tsp_recexch.c +++ b/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_tsp_recexch.c @@ -10,8 +10,8 @@ /* Routine to schedule a recursive exchange based reduce_scatter with distance halving in each phase */ int MPIR_TSP_Ireduce_scatter_block_sched_intra_recexch(const void *sendbuf, void *recvbuf, MPI_Aint recvcount, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm, int k, - MPIR_TSP_sched_t sched) + MPI_Op op, MPIR_Comm * comm, int coll_group, + int k, MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; int is_inplace; @@ -32,11 +32,10 @@ int MPIR_TSP_Ireduce_scatter_block_sched_intra_recexch(const void *sendbuf, void /* For correctness, transport based collectives need to get the * tag from the same pool as schedule based collectives */ - mpi_errno = MPIR_Sched_next_tag(comm, &tag); + mpi_errno = MPIR_Sched_next_tag(comm, coll_group, &tag); is_inplace = (sendbuf == MPI_IN_PLACE); - nranks = MPIR_Comm_size(comm); - rank = MPIR_Comm_rank(comm); + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, nranks); MPIR_Datatype_get_extent_macro(datatype, extent); MPIR_Type_get_true_extent_impl(datatype, &lb, &true_extent); @@ -75,8 +74,8 @@ int MPIR_TSP_Ireduce_scatter_block_sched_intra_recexch(const void *sendbuf, void else buf_to_send = (void *) sendbuf; mpi_errno = - MPIR_TSP_sched_isend(buf_to_send, total_count, datatype, step1_sendto, tag, comm, sched, - 0, NULL, &vtx_id); + MPIR_TSP_sched_isend(buf_to_send, total_count, datatype, step1_sendto, tag, comm, + coll_group, sched, 0, NULL, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } else { /* Step 2 participating rank */ for (i = 0; i < step1_nrecvs; i++) { /* participating rank gets data from non-partcipating ranks */ @@ -84,8 +83,8 @@ int MPIR_TSP_Ireduce_scatter_block_sched_intra_recexch(const void *sendbuf, void nvtcs = 1; vtcs[0] = (i == 0) ? dtcopy_id : reduce_id; mpi_errno = MPIR_TSP_sched_irecv(tmp_recvbuf, total_count, datatype, - step1_recvfrom[i], tag, comm, sched, nvtcs, vtcs, - &recv_id); + step1_recvfrom[i], tag, comm, coll_group, sched, nvtcs, + vtcs, &recv_id); MPIR_ERR_CHECK(mpi_errno); nvtcs++; vtcs[1] = recv_id; @@ -118,7 +117,8 @@ int MPIR_TSP_Ireduce_scatter_block_sched_intra_recexch(const void *sendbuf, void mpi_errno = MPIR_TSP_sched_isend((char *) tmp_results + send_offset, send_cnt * recvcount, - datatype, dst, tag, comm, sched, nvtcs, vtcs, &send_id); + datatype, dst, tag, comm, coll_group, sched, nvtcs, vtcs, + &send_id); MPIR_ERR_CHECK(mpi_errno); MPII_Recexchalgo_get_count_and_offset(rank, phase, k, nranks, &recv_cnt, &offset); @@ -126,7 +126,8 @@ int MPIR_TSP_Ireduce_scatter_block_sched_intra_recexch(const void *sendbuf, void mpi_errno = MPIR_TSP_sched_irecv(tmp_recvbuf, recv_cnt * recvcount, - datatype, dst, tag, comm, sched, nvtcs, vtcs, &recv_id); + datatype, dst, tag, comm, coll_group, sched, nvtcs, vtcs, + &recv_id); MPIR_ERR_CHECK(mpi_errno); nvtcs = 2; @@ -155,8 +156,8 @@ int MPIR_TSP_Ireduce_scatter_block_sched_intra_recexch(const void *sendbuf, void * send the data to non-partcipating ranks */ if (step1_sendto != -1) { /* I am a Step 2 non-participating rank */ mpi_errno = - MPIR_TSP_sched_irecv(recvbuf, recvcount, datatype, step1_sendto, tag, comm, sched, 1, - &step1_id, &vtx_id); + MPIR_TSP_sched_irecv(recvbuf, recvcount, datatype, step1_sendto, tag, comm, coll_group, + sched, 1, &step1_id, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } for (i = 0; i < step1_nrecvs; i++) { @@ -164,8 +165,8 @@ int MPIR_TSP_Ireduce_scatter_block_sched_intra_recexch(const void *sendbuf, void vtcs[0] = reduce_id; mpi_errno = MPIR_TSP_sched_isend((char *) tmp_results + recvcount * step1_recvfrom[i] * extent, - recvcount, datatype, step1_recvfrom[i], tag, comm, sched, nvtcs, - vtcs, &vtx_id); + recvcount, datatype, step1_recvfrom[i], tag, comm, coll_group, + sched, nvtcs, vtcs, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/iscan/iscan_intra_sched_recursive_doubling.c b/src/mpi/coll/iscan/iscan_intra_sched_recursive_doubling.c index 4a914802380..be5eb268dae 100644 --- a/src/mpi/coll/iscan/iscan_intra_sched_recursive_doubling.c +++ b/src/mpi/coll/iscan/iscan_intra_sched_recursive_doubling.c @@ -7,7 +7,7 @@ int MPIR_Iscan_intra_sched_recursive_doubling(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; MPI_Aint true_extent, true_lb, extent; @@ -16,8 +16,7 @@ int MPIR_Iscan_intra_sched_recursive_doubling(const void *sendbuf, void *recvbuf void *partial_scan = NULL; void *tmp_buf = NULL; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); is_commutative = MPIR_Op_is_commutative(op); @@ -56,10 +55,11 @@ int MPIR_Iscan_intra_sched_recursive_doubling(const void *sendbuf, void *recvbuf dst = rank ^ mask; if (dst < comm_size) { /* Send partial_scan to dst. Recv into tmp_buf */ - mpi_errno = MPIR_Sched_send(partial_scan, count, datatype, dst, comm_ptr, s); + mpi_errno = + MPIR_Sched_send(partial_scan, count, datatype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); /* sendrecv, no barrier here */ - mpi_errno = MPIR_Sched_recv(tmp_buf, count, datatype, dst, comm_ptr, s); + mpi_errno = MPIR_Sched_recv(tmp_buf, count, datatype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); diff --git a/src/mpi/coll/iscan/iscan_intra_sched_smp.c b/src/mpi/coll/iscan/iscan_intra_sched_smp.c index 3cd74b72ed6..20e5bbdc445 100644 --- a/src/mpi/coll/iscan/iscan_intra_sched_smp.c +++ b/src/mpi/coll/iscan/iscan_intra_sched_smp.c @@ -8,17 +8,19 @@ int MPIR_Iscan_intra_sched_smp(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, - MPIR_Sched_t s) + int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; - int rank = comm_ptr->rank; - MPIR_Comm *node_comm; - MPIR_Comm *roots_comm; MPI_Aint true_extent, true_lb, extent; void *tempbuf = NULL; void *prefulldata = NULL; void *localfulldata = NULL; + MPIR_Assert(MPIR_Comm_is_parent_comm(comm_ptr, coll_group)); + int local_rank = comm_ptr->subgroups[MPIR_SUBGROUP_NODE].rank; + int local_size = comm_ptr->subgroups[MPIR_SUBGROUP_NODE].size; + int inter_rank = MPIR_Get_internode_rank(comm_ptr, comm_ptr->rank); + /* In order to use the SMP-aware algorithm, the "op" can be * either commutative or non-commutative, but we require a * communicator in which all the nodes contain processes with @@ -28,12 +30,9 @@ int MPIR_Iscan_intra_sched_smp(const void *sendbuf, void *recvbuf, MPI_Aint coun /* We can't use the SMP-aware algorithm, use the non-SMP-aware * one */ return MPIR_Iscan_intra_sched_recursive_doubling(sendbuf, recvbuf, count, datatype, op, - comm_ptr, s); + comm_ptr, coll_group, s); } - node_comm = comm_ptr->node_comm; - roots_comm = comm_ptr->node_roots_comm; - MPIR_Type_get_true_extent_impl(datatype, &true_lb, &true_extent); MPIR_Datatype_get_extent_macro(datatype, extent); @@ -42,12 +41,12 @@ int MPIR_Iscan_intra_sched_smp(const void *sendbuf, void *recvbuf, MPI_Aint coun tempbuf = (void *) ((char *) tempbuf - true_lb); /* Create prefulldata and localfulldata on local roots of all nodes */ - if (comm_ptr->node_roots_comm != NULL) { + if (local_rank == 0) { prefulldata = MPIR_Sched_alloc_state(s, count * (MPL_MAX(extent, true_extent))); MPIR_ERR_CHKANDJUMP(!prefulldata, mpi_errno, MPI_ERR_OTHER, "**nomem"); prefulldata = (void *) ((char *) prefulldata - true_lb); - if (node_comm != NULL) { + if (local_size > 1) { localfulldata = MPIR_Sched_alloc_state(s, count * (MPL_MAX(extent, true_extent))); MPIR_ERR_CHKANDJUMP(!localfulldata, mpi_errno, MPI_ERR_OTHER, "**nomem"); localfulldata = (void *) ((char *) localfulldata - true_lb); @@ -56,9 +55,9 @@ int MPIR_Iscan_intra_sched_smp(const void *sendbuf, void *recvbuf, MPI_Aint coun /* perform intranode scan to get temporary result in recvbuf. if there is only * one process, just copy the raw data. */ - if (node_comm != NULL) { - mpi_errno = - MPIR_Iscan_intra_sched_auto(sendbuf, recvbuf, count, datatype, op, node_comm, s); + if (local_size > 1) { + mpi_errno = MPIR_Iscan_intra_sched_auto(sendbuf, recvbuf, count, datatype, op, + comm_ptr, MPIR_SUBGROUP_NODE, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } else if (sendbuf != MPI_IN_PLACE) { @@ -71,17 +70,16 @@ int MPIR_Iscan_intra_sched_smp(const void *sendbuf, void *recvbuf, MPI_Aint coun * contains the reduce result of the whole node. Name it as * localfulldata. For example, localfulldata from node 1 contains * reduced data of rank 1,2,3. */ - if (roots_comm != NULL && node_comm != NULL) { - mpi_errno = MPIR_Sched_recv(localfulldata, count, datatype, - (node_comm->local_size - 1), node_comm, s); + if (local_rank == 0 && local_size > 1) { + mpi_errno = MPIR_Sched_recv(localfulldata, count, datatype, (local_size - 1), + comm_ptr, MPIR_SUBGROUP_NODE, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); - } else if (roots_comm == NULL && node_comm != NULL && - node_comm->rank == node_comm->local_size - 1) { - mpi_errno = MPIR_Sched_send(recvbuf, count, datatype, 0, node_comm, s); + } else if (local_rank != 0 && local_size > 1 && local_rank == local_size - 1) { + mpi_errno = MPIR_Sched_send(recvbuf, count, datatype, 0, comm_ptr, MPIR_SUBGROUP_NODE, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); - } else if (roots_comm != NULL) { + } else if (local_rank == 0) { localfulldata = recvbuf; } @@ -89,25 +87,23 @@ int MPIR_Iscan_intra_sched_smp(const void *sendbuf, void *recvbuf, MPI_Aint coun * prefulldata on rank 4 contains reduce result of ranks * 1,2,3,4,5,6. it will be sent to rank 7 which is the * main process of node 3. */ - if (roots_comm != NULL) { - /* FIXME just use roots_comm->rank instead */ - int roots_rank = MPIR_Get_internode_rank(comm_ptr, rank); - MPIR_Assert(roots_rank == roots_comm->rank); - - mpi_errno = - MPIR_Iscan_intra_sched_auto(localfulldata, prefulldata, count, datatype, op, roots_comm, - s); + if (local_rank == 0) { + int inter_size = comm_ptr->subgroups[MPIR_SUBGROUP_NODE_CROSS].size; + + mpi_errno = MPIR_Iscan_intra_sched_auto(localfulldata, prefulldata, count, datatype, op, + comm_ptr, MPIR_SUBGROUP_NODE_CROSS, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); - if (roots_rank != roots_comm->local_size - 1) { - mpi_errno = - MPIR_Sched_send(prefulldata, count, datatype, (roots_rank + 1), roots_comm, s); + if (inter_rank != inter_size - 1) { + mpi_errno = MPIR_Sched_send(prefulldata, count, datatype, (inter_rank + 1), + comm_ptr, MPIR_SUBGROUP_NODE_CROSS, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } - if (roots_rank != 0) { - mpi_errno = MPIR_Sched_recv(tempbuf, count, datatype, (roots_rank - 1), roots_comm, s); + if (inter_rank != 0) { + mpi_errno = MPIR_Sched_recv(tempbuf, count, datatype, (inter_rank - 1), + comm_ptr, MPIR_SUBGROUP_NODE_CROSS, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } @@ -119,12 +115,13 @@ int MPIR_Iscan_intra_sched_smp(const void *sendbuf, void *recvbuf, MPI_Aint coun * then we should broadcast this result in the local node, and * reduce it with recvbuf to get final result if necessary. */ - if (MPIR_Get_internode_rank(comm_ptr, rank) != 0) { + if (inter_rank != 0) { /* we aren't on "node 0", so our node leader (possibly us) received * "prefulldata" from another leader into "tempbuf" */ - if (node_comm != NULL) { - mpi_errno = MPIR_Ibcast_intra_sched_auto(tempbuf, count, datatype, 0, node_comm, s); + if (local_size > 1) { + mpi_errno = MPIR_Ibcast_intra_sched_auto(tempbuf, count, datatype, 0, + comm_ptr, MPIR_SUBGROUP_NODE, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } diff --git a/src/mpi/coll/iscan/iscan_tsp_recursive_doubling.c b/src/mpi/coll/iscan/iscan_tsp_recursive_doubling.c index b6010400be0..34a24b05a5c 100644 --- a/src/mpi/coll/iscan/iscan_tsp_recursive_doubling.c +++ b/src/mpi/coll/iscan/iscan_tsp_recursive_doubling.c @@ -9,7 +9,8 @@ /* Routine to schedule a recursive exchange based scan */ int MPIR_TSP_Iscan_sched_intra_recursive_doubling(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm, MPIR_TSP_sched_t sched) + MPIR_Comm * comm, int coll_group, + MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; MPI_Aint extent, true_extent; @@ -27,11 +28,10 @@ int MPIR_TSP_Iscan_sched_intra_recursive_doubling(const void *sendbuf, void *rec /* For correctness, transport based collectives need to get the * tag from the same pool as schedule based collectives */ - mpi_errno = MPIR_Sched_next_tag(comm, &tag); + mpi_errno = MPIR_Sched_next_tag(comm, coll_group, &tag); MPIR_ERR_CHECK(mpi_errno); - nranks = MPIR_Comm_size(comm); - rank = MPIR_Comm_rank(comm); + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, nranks); is_commutative = MPIR_Op_is_commutative(op); @@ -74,8 +74,8 @@ int MPIR_TSP_Iscan_sched_intra_recursive_doubling(const void *sendbuf, void *rec nvtcs = 1; vtcs[0] = (loop_count == 0) ? dtcopy_id : reduce_id; mpi_errno = - MPIR_TSP_sched_isend(partial_scan, count, datatype, dst, tag, comm, sched, nvtcs, - vtcs, &send_id); + MPIR_TSP_sched_isend(partial_scan, count, datatype, dst, tag, comm, coll_group, + sched, nvtcs, vtcs, &send_id); MPIR_ERR_CHECK(mpi_errno); @@ -84,8 +84,8 @@ int MPIR_TSP_Iscan_sched_intra_recursive_doubling(const void *sendbuf, void *rec vtcs[1] = recv_reduce; } mpi_errno = - MPIR_TSP_sched_irecv(tmp_buf, count, datatype, dst, tag, comm, sched, nvtcs, vtcs, - &recv_id); + MPIR_TSP_sched_irecv(tmp_buf, count, datatype, dst, tag, comm, coll_group, sched, + nvtcs, vtcs, &recv_id); MPIR_ERR_CHECK(mpi_errno); diff --git a/src/mpi/coll/iscatter/iscatter_inter_sched_linear.c b/src/mpi/coll/iscatter/iscatter_inter_sched_linear.c index 93a564d543d..0c9a97581b0 100644 --- a/src/mpi/coll/iscatter/iscatter_inter_sched_linear.c +++ b/src/mpi/coll/iscatter/iscatter_inter_sched_linear.c @@ -15,7 +15,7 @@ int MPIR_Iscatter_inter_sched_linear(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm_ptr, MPIR_Sched_t s) + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int remote_size; @@ -34,12 +34,12 @@ int MPIR_Iscatter_inter_sched_linear(const void *sendbuf, MPI_Aint sendcount, MP for (i = 0; i < remote_size; i++) { mpi_errno = MPIR_Sched_send(((char *) sendbuf + sendcount * i * extent), sendcount, sendtype, i, - comm_ptr, s); + comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } MPIR_SCHED_BARRIER(s); } else { - mpi_errno = MPIR_Sched_recv(recvbuf, recvcount, recvtype, root, comm_ptr, s); + mpi_errno = MPIR_Sched_recv(recvbuf, recvcount, recvtype, root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } diff --git a/src/mpi/coll/iscatter/iscatter_inter_sched_remote_send_local_scatter.c b/src/mpi/coll/iscatter/iscatter_inter_sched_remote_send_local_scatter.c index 1d0f7a3439d..7b381fc5fee 100644 --- a/src/mpi/coll/iscatter/iscatter_inter_sched_remote_send_local_scatter.c +++ b/src/mpi/coll/iscatter/iscatter_inter_sched_remote_send_local_scatter.c @@ -18,7 +18,7 @@ int MPIR_Iscatter_inter_sched_remote_send_local_scatter(const void *sendbuf, MPI MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, MPIR_Comm * comm_ptr, - MPIR_Sched_t s) + int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int rank, local_size, remote_size; @@ -34,7 +34,8 @@ int MPIR_Iscatter_inter_sched_remote_send_local_scatter(const void *sendbuf, MPI if (root == MPI_ROOT) { /* root sends all data to rank 0 on remote group and returns */ - mpi_errno = MPIR_Sched_send(sendbuf, sendcount * remote_size, sendtype, 0, comm_ptr, s); + mpi_errno = + MPIR_Sched_send(sendbuf, sendcount * remote_size, sendtype, 0, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); goto fn_exit; @@ -53,7 +54,7 @@ int MPIR_Iscatter_inter_sched_remote_send_local_scatter(const void *sendbuf, MPI mpi_errno = MPIR_Sched_recv(tmp_buf, recvcount * local_size * recvtype_sz, MPI_BYTE, - root, comm_ptr, s); + root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } else { @@ -70,7 +71,8 @@ int MPIR_Iscatter_inter_sched_remote_send_local_scatter(const void *sendbuf, MPI /* now do the usual scatter on this intracommunicator */ mpi_errno = MPIR_Iscatter_intra_sched_auto(tmp_buf, recvcount * recvtype_sz, MPI_BYTE, - recvbuf, recvcount, recvtype, 0, newcomm_ptr, s); + recvbuf, recvcount, recvtype, 0, newcomm_ptr, coll_group, + s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } diff --git a/src/mpi/coll/iscatter/iscatter_intra_sched_binomial.c b/src/mpi/coll/iscatter/iscatter_intra_sched_binomial.c index 49cc7fc0159..5f17deed25c 100644 --- a/src/mpi/coll/iscatter/iscatter_intra_sched_binomial.c +++ b/src/mpi/coll/iscatter/iscatter_intra_sched_binomial.c @@ -66,7 +66,7 @@ static int calc_curr_count(MPIR_Comm * comm, int tag, void *state) int MPIR_Iscatter_intra_sched_binomial(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, MPIR_Comm * comm_ptr, - MPIR_Sched_t s) + int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; MPI_Aint extent = 0; @@ -77,8 +77,7 @@ int MPIR_Iscatter_intra_sched_binomial(const void *sendbuf, MPI_Aint sendcount, void *tmp_buf = NULL; struct shared_state *ss = NULL; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); ss = MPIR_Sched_alloc_state(s, sizeof(struct shared_state)); MPIR_ERR_CHKANDJUMP(!ss, mpi_errno, MPI_ERR_OTHER, "**nomem"); @@ -158,7 +157,8 @@ int MPIR_Iscatter_intra_sched_binomial(const void *sendbuf, MPI_Aint sendcount, * they don't have to forward data to anyone. Others * receive data into a temporary buffer. */ if (relative_rank % 2) { - mpi_errno = MPIR_Sched_recv(recvbuf, recvcount, recvtype, src, comm_ptr, s); + mpi_errno = + MPIR_Sched_recv(recvbuf, recvcount, recvtype, src, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } else { @@ -167,7 +167,7 @@ int MPIR_Iscatter_intra_sched_binomial(const void *sendbuf, MPI_Aint sendcount, * some cases. query amount of data actually received */ mpi_errno = MPIR_Sched_recv_status(tmp_buf, tmp_buf_size, MPI_BYTE, src, comm_ptr, - &ss->status, s); + coll_group, &ss->status, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); mpi_errno = MPIR_Sched_cb(&get_count, ss, s); @@ -205,7 +205,8 @@ int MPIR_Iscatter_intra_sched_binomial(const void *sendbuf, MPI_Aint sendcount, /* mask is also the size of this process's subtree */ mpi_errno = MPIR_Sched_send_defer(((char *) sendbuf + extent * sendcount * mask), - &ss->send_subtree_count, sendtype, dst, comm_ptr, s); + &ss->send_subtree_count, sendtype, dst, comm_ptr, + coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } else { @@ -218,7 +219,7 @@ int MPIR_Iscatter_intra_sched_binomial(const void *sendbuf, MPI_Aint sendcount, /* mask is also the size of this process's subtree */ mpi_errno = MPIR_Sched_send_defer(((char *) tmp_buf + ss->nbytes * mask), &ss->send_subtree_count, MPI_BYTE, dst, - comm_ptr, s); + comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } diff --git a/src/mpi/coll/iscatter/iscatter_tsp_tree.c b/src/mpi/coll/iscatter/iscatter_tsp_tree.c index b92ed46017d..1c8aa9121e0 100644 --- a/src/mpi/coll/iscatter/iscatter_tsp_tree.c +++ b/src/mpi/coll/iscatter/iscatter_tsp_tree.c @@ -10,7 +10,7 @@ int MPIR_TSP_Iscatter_sched_intra_tree(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, MPIR_Comm * comm, - int k, MPIR_TSP_sched_t sched) + int coll_group, int k, MPIR_TSP_sched_t sched) { MPIR_FUNC_ENTER; @@ -34,8 +34,7 @@ int MPIR_TSP_Iscatter_sched_intra_tree(const void *sendbuf, MPI_Aint sendcount, MPIR_Errflag_t errflag ATTRIBUTE((unused)) = MPIR_ERR_NONE; MPIR_CHKLMEM_DECL(2); - size = MPIR_Comm_size(comm); - rank = MPIR_Comm_rank(comm); + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, size); lrank = (rank - root + size) % size; /* logical rank when root is non-zero */ if (rank == root) @@ -48,7 +47,7 @@ int MPIR_TSP_Iscatter_sched_intra_tree(const void *sendbuf, MPI_Aint sendcount, /* For correctness, transport based collectives need to get the * tag from the same pool as schedule based collectives */ - mpi_errno = MPIR_Sched_next_tag(comm, &tag); + mpi_errno = MPIR_Sched_next_tag(comm, coll_group, &tag); MPIR_ERR_CHECK(mpi_errno); if (rank == root && is_inplace) { @@ -148,7 +147,7 @@ int MPIR_TSP_Iscatter_sched_intra_tree(const void *sendbuf, MPI_Aint sendcount, /* receive data from the parent */ if (my_tree.parent != -1) { mpi_errno = MPIR_TSP_sched_irecv(tmp_buf, recv_size, recvtype, my_tree.parent, - tag, comm, sched, 0, NULL, &recv_id); + tag, comm, coll_group, sched, 0, NULL, &recv_id); MPIR_ERR_CHECK(mpi_errno); MPL_DBG_MSG_FMT(MPIR_DBG_COLL, VERBOSE, (MPL_DBG_FDEST, "rank:%d posts recv", rank)); } @@ -158,7 +157,7 @@ int MPIR_TSP_Iscatter_sched_intra_tree(const void *sendbuf, MPI_Aint sendcount, int child = *(int *) utarray_eltptr(my_tree.children, i); mpi_errno = MPIR_TSP_sched_isend((char *) tmp_buf + child_data_offset[i] * sendtype_extent, child_subtree_size[i] * sendcount, sendtype, - child, tag, comm, sched, num_send_dependencies, + child, tag, comm, coll_group, sched, num_send_dependencies, (lrank == 0) ? dtcopy_id : &recv_id, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/iscatterv/iscatterv_allcomm_sched_linear.c b/src/mpi/coll/iscatterv/iscatterv_allcomm_sched_linear.c index 0257a6df1ec..4f19b399152 100644 --- a/src/mpi/coll/iscatterv/iscatterv_allcomm_sched_linear.c +++ b/src/mpi/coll/iscatterv/iscatterv_allcomm_sched_linear.c @@ -18,22 +18,25 @@ int MPIR_Iscatterv_allcomm_sched_linear(const void *sendbuf, const MPI_Aint sendcounts[], const MPI_Aint displs[], MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm_ptr, MPIR_Sched_t s) + int root, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int rank, comm_size; MPI_Aint extent; int i; - rank = comm_ptr->rank; + if (comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM) { + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); + } else { + MPIR_Assert(coll_group == MPIR_SUBGROUP_NONE); + rank = comm_ptr->rank; + comm_size = comm_ptr->remote_size; + } /* If I'm the root, then scatter */ if (((comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM) && (root == rank)) || ((comm_ptr->comm_kind == MPIR_COMM_KIND__INTERCOMM) && (root == MPI_ROOT))) { - if (comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM) - comm_size = comm_ptr->local_size; - else - comm_size = comm_ptr->remote_size; MPIR_Datatype_get_extent_macro(sendtype, extent); @@ -48,7 +51,8 @@ int MPIR_Iscatterv_allcomm_sched_linear(const void *sendbuf, const MPI_Aint send } } else { mpi_errno = MPIR_Sched_send(((char *) sendbuf + displs[i] * extent), - sendcounts[i], sendtype, i, comm_ptr, s); + sendcounts[i], sendtype, i, comm_ptr, coll_group, + s); MPIR_ERR_CHECK(mpi_errno); } } @@ -58,7 +62,8 @@ int MPIR_Iscatterv_allcomm_sched_linear(const void *sendbuf, const MPI_Aint send else if (root != MPI_PROC_NULL) { /* non-root nodes, and in the intercomm. case, non-root nodes on remote side */ if (recvcount) { - mpi_errno = MPIR_Sched_recv(recvbuf, recvcount, recvtype, root, comm_ptr, s); + mpi_errno = + MPIR_Sched_recv(recvbuf, recvcount, recvtype, root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } } diff --git a/src/mpi/coll/iscatterv/iscatterv_tsp_linear.c b/src/mpi/coll/iscatterv/iscatterv_tsp_linear.c index dbc0669299a..39fece505dc 100644 --- a/src/mpi/coll/iscatterv/iscatterv_tsp_linear.c +++ b/src/mpi/coll/iscatterv/iscatterv_tsp_linear.c @@ -11,7 +11,7 @@ int MPIR_TSP_Iscatterv_sched_allcomm_linear(const void *sendbuf, const MPI_Aint const MPI_Aint displs[], MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, MPIR_Comm * comm_ptr, - MPIR_TSP_sched_t sched) + int coll_group, MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; int rank, comm_size; @@ -22,20 +22,22 @@ int MPIR_TSP_Iscatterv_sched_allcomm_linear(const void *sendbuf, const MPI_Aint MPIR_FUNC_ENTER; - rank = comm_ptr->rank; + if (comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM) { + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); + } else { + MPIR_Assert(coll_group == MPIR_SUBGROUP_NONE); + rank = comm_ptr->rank; + comm_size = comm_ptr->remote_size; + } /* For correctness, transport based collectives need to get the * tag from the same pool as schedule based collectives */ - mpi_errno = MPIR_Sched_next_tag(comm_ptr, &tag); + mpi_errno = MPIR_Sched_next_tag(comm_ptr, coll_group, &tag); MPIR_ERR_CHECK(mpi_errno); /* If I'm the root, then scatter */ if (((comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM) && (root == rank)) || ((comm_ptr->comm_kind == MPIR_COMM_KIND__INTERCOMM) && (root == MPI_ROOT))) { - if (comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM) - comm_size = comm_ptr->local_size; - else - comm_size = comm_ptr->remote_size; MPIR_Datatype_get_extent_macro(sendtype, extent); /* We need a check to ensure extent will fit in a @@ -57,7 +59,7 @@ int MPIR_TSP_Iscatterv_sched_allcomm_linear(const void *sendbuf, const MPI_Aint } else { mpi_errno = MPIR_TSP_sched_isend(((char *) sendbuf + displs[i] * extent), sendcounts[i], sendtype, i, tag, comm_ptr, - sched, 0, NULL, &vtx_id); + coll_group, sched, 0, NULL, &vtx_id); } } MPIR_ERR_CHECK(mpi_errno); @@ -68,8 +70,8 @@ int MPIR_TSP_Iscatterv_sched_allcomm_linear(const void *sendbuf, const MPI_Aint /* non-root nodes, and in the intercomm. case, non-root nodes on remote side */ if (recvcount) { mpi_errno = - MPIR_TSP_sched_irecv(recvbuf, recvcount, recvtype, root, tag, comm_ptr, sched, 0, - NULL, &vtx_id); + MPIR_TSP_sched_irecv(recvbuf, recvcount, recvtype, root, tag, comm_ptr, coll_group, + sched, 0, NULL, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } } diff --git a/src/mpi/coll/mpir_coll_sched_auto.c b/src/mpi/coll/mpir_coll_sched_auto.c index 8b118bda349..8ad4e037144 100644 --- a/src/mpi/coll/mpir_coll_sched_auto.c +++ b/src/mpi/coll/mpir_coll_sched_auto.c @@ -12,50 +12,55 @@ * defining them here. */ -int MPIR_Ibarrier_intra_sched_auto(MPIR_Comm * comm_ptr, MPIR_Sched_t s) +int MPIR_Ibarrier_intra_sched_auto(MPIR_Comm * comm_ptr, int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; - mpi_errno = MPIR_Ibarrier_intra_sched_recursive_doubling(comm_ptr, s); + mpi_errno = MPIR_Ibarrier_intra_sched_recursive_doubling(comm_ptr, coll_group, s); return mpi_errno; } /* It will choose between several different algorithms based on the given * parameters. */ -int MPIR_Ibarrier_inter_sched_auto(MPIR_Comm * comm_ptr, MPIR_Sched_t s) +int MPIR_Ibarrier_inter_sched_auto(MPIR_Comm * comm_ptr, int coll_group, MPIR_Sched_t s) { int mpi_errno; - mpi_errno = MPIR_Ibarrier_inter_sched_bcast(comm_ptr, s); + mpi_errno = MPIR_Ibarrier_inter_sched_bcast(comm_ptr, coll_group, s); return mpi_errno; } int MPIR_Ibcast_intra_sched_auto(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; - int comm_size; MPI_Aint type_size, nbytes; MPIR_Assert(comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM); - if (comm_ptr->hierarchy_kind == MPIR_COMM_HIERARCHY_KIND__PARENT) { - mpi_errno = MPIR_Ibcast_intra_sched_smp(buffer, count, datatype, root, comm_ptr, s); + if (MPIR_Comm_is_parent_comm(comm_ptr, coll_group)) { + mpi_errno = + MPIR_Ibcast_intra_sched_smp(buffer, count, datatype, root, comm_ptr, coll_group, s); if (mpi_errno) MPIR_ERR_POP(mpi_errno); goto fn_exit; } - comm_size = comm_ptr->local_size; + int comm_size, rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); + (void) rank; /* silence unused variable warnings */ + MPIR_Datatype_get_size_macro(datatype, type_size); nbytes = type_size * count; /* simplistic implementation for now */ if ((nbytes < MPIR_CVAR_BCAST_SHORT_MSG_SIZE) || (comm_size < MPIR_CVAR_BCAST_MIN_PROCS)) { - mpi_errno = MPIR_Ibcast_intra_sched_binomial(buffer, count, datatype, root, comm_ptr, s); + mpi_errno = + MPIR_Ibcast_intra_sched_binomial(buffer, count, datatype, root, comm_ptr, coll_group, + s); MPIR_ERR_CHECK(mpi_errno); } else { /* (nbytes >= MPIR_CVAR_BCAST_SHORT_MSG_SIZE) && (comm_size >= MPIR_CVAR_BCAST_MIN_PROCS) */ @@ -63,12 +68,13 @@ int MPIR_Ibcast_intra_sched_auto(void *buffer, MPI_Aint count, MPI_Datatype data mpi_errno = MPIR_Ibcast_intra_sched_scatter_recursive_doubling_allgather(buffer, count, datatype, root, - comm_ptr, s); + comm_ptr, coll_group, + s); MPIR_ERR_CHECK(mpi_errno); } else { mpi_errno = MPIR_Ibcast_intra_sched_scatter_ring_allgather(buffer, count, datatype, root, - comm_ptr, s); + comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } } @@ -83,24 +89,25 @@ int MPIR_Ibcast_intra_sched_auto(void *buffer, MPI_Aint count, MPI_Datatype data * know anything about hierarchy. It will choose between several * different algorithms based on the given parameters. */ int MPIR_Ibcast_inter_sched_auto(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; - mpi_errno = MPIR_Ibcast_inter_sched_flat(buffer, count, datatype, root, comm_ptr, s); + mpi_errno = + MPIR_Ibcast_inter_sched_flat(buffer, count, datatype, root, comm_ptr, coll_group, s); return mpi_errno; } int MPIR_Igather_intra_sched_auto(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm_ptr, MPIR_Sched_t s) + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Igather_intra_sched_binomial(sendbuf, sendcount, sendtype, recvbuf, recvcount, - recvtype, root, comm_ptr, s); + recvtype, root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -111,7 +118,7 @@ int MPIR_Igather_intra_sched_auto(const void *sendbuf, MPI_Aint sendcount, MPI_D int MPIR_Igather_inter_sched_auto(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm_ptr, MPIR_Sched_t s) + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; MPI_Aint local_size, remote_size; @@ -137,11 +144,11 @@ int MPIR_Igather_inter_sched_auto(const void *sendbuf, MPI_Aint sendcount, MPI_D if (nbytes < MPIR_CVAR_GATHER_INTER_SHORT_MSG_SIZE) { mpi_errno = MPIR_Igather_inter_sched_short(sendbuf, sendcount, sendtype, recvbuf, recvcount, - recvtype, root, comm_ptr, s); + recvtype, root, comm_ptr, coll_group, s); } else { mpi_errno = MPIR_Igather_inter_sched_long(sendbuf, sendcount, sendtype, recvbuf, recvcount, - recvtype, root, comm_ptr, s); + recvtype, root, comm_ptr, coll_group, s); } fn_exit: @@ -151,13 +158,13 @@ int MPIR_Igather_inter_sched_auto(const void *sendbuf, MPI_Aint sendcount, MPI_D int MPIR_Igatherv_intra_sched_auto(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint displs[], MPI_Datatype recvtype, int root, - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Igatherv_allcomm_sched_linear(sendbuf, sendcount, sendtype, recvbuf, recvcounts, - displs, recvtype, root, comm_ptr, s); + displs, recvtype, root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -170,13 +177,13 @@ int MPIR_Igatherv_intra_sched_auto(const void *sendbuf, MPI_Aint sendcount, MPI_ int MPIR_Igatherv_inter_sched_auto(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint displs[], MPI_Datatype recvtype, int root, - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Igatherv_allcomm_sched_linear(sendbuf, sendcount, sendtype, recvbuf, recvcounts, - displs, recvtype, root, comm_ptr, s); + displs, recvtype, root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -188,13 +195,13 @@ int MPIR_Igatherv_inter_sched_auto(const void *sendbuf, MPI_Aint sendcount, MPI_ int MPIR_Iscatter_intra_sched_auto(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm_ptr, MPIR_Sched_t s) + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Iscatter_intra_sched_binomial(sendbuf, sendcount, sendtype, recvbuf, recvcount, - recvtype, root, comm_ptr, s); + recvtype, root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); @@ -206,7 +213,7 @@ int MPIR_Iscatter_intra_sched_auto(const void *sendbuf, MPI_Aint sendcount, MPI_ int MPIR_Iscatter_inter_sched_auto(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm_ptr, MPIR_Sched_t s) + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Sched_t s) { int local_size, remote_size; MPI_Aint sendtype_size, recvtype_size, nbytes; @@ -227,11 +234,11 @@ int MPIR_Iscatter_inter_sched_auto(const void *sendbuf, MPI_Aint sendcount, MPI_ mpi_errno = MPIR_Iscatter_inter_sched_remote_send_local_scatter(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, - comm_ptr, s); + comm_ptr, coll_group, s); } else { mpi_errno = MPIR_Iscatter_inter_sched_linear(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm_ptr, - s); + coll_group, s); } MPIR_ERR_CHECK(mpi_errno); @@ -245,13 +252,13 @@ int MPIR_Iscatter_inter_sched_auto(const void *sendbuf, MPI_Aint sendcount, MPI_ int MPIR_Iscatterv_intra_sched_auto(const void *sendbuf, const MPI_Aint sendcounts[], const MPI_Aint displs[], MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Iscatterv_allcomm_sched_linear(sendbuf, sendcounts, displs, sendtype, recvbuf, - recvcount, recvtype, root, comm_ptr, s); + recvcount, recvtype, root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -264,13 +271,13 @@ int MPIR_Iscatterv_intra_sched_auto(const void *sendbuf, const MPI_Aint sendcoun int MPIR_Iscatterv_inter_sched_auto(const void *sendbuf, const MPI_Aint sendcounts[], const MPI_Aint displs[], MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Iscatterv_allcomm_sched_linear(sendbuf, sendcounts, displs, sendtype, recvbuf, - recvcount, recvtype, root, comm_ptr, s); + recvcount, recvtype, root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -282,13 +289,14 @@ int MPIR_Iscatterv_inter_sched_auto(const void *sendbuf, const MPI_Aint sendcoun int MPIR_Iallgather_intra_sched_auto(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; - int comm_size; MPI_Aint recvtype_size, tot_bytes; - comm_size = comm_ptr->local_size; + int comm_size, rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); + (void) rank; /* silence unused variable warnings */ MPIR_Datatype_get_size_macro(recvtype, recvtype_size); tot_bytes = (MPI_Aint) recvcount *comm_size * recvtype_size; @@ -296,15 +304,16 @@ int MPIR_Iallgather_intra_sched_auto(const void *sendbuf, MPI_Aint sendcount, MP if ((tot_bytes < MPIR_CVAR_ALLGATHER_LONG_MSG_SIZE) && !(comm_size & (comm_size - 1))) { mpi_errno = MPIR_Iallgather_intra_sched_recursive_doubling(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm_ptr, s); + recvcount, recvtype, comm_ptr, + coll_group, s); } else if (tot_bytes < MPIR_CVAR_ALLGATHER_SHORT_MSG_SIZE) { mpi_errno = MPIR_Iallgather_intra_sched_brucks(sendbuf, sendcount, sendtype, recvbuf, recvcount, - recvtype, comm_ptr, s); + recvtype, comm_ptr, coll_group, s); } else { mpi_errno = MPIR_Iallgather_intra_sched_ring(sendbuf, sendcount, sendtype, recvbuf, recvcount, - recvtype, comm_ptr, s); + recvtype, comm_ptr, coll_group, s); } MPIR_ERR_CHECK(mpi_errno); @@ -316,13 +325,15 @@ int MPIR_Iallgather_intra_sched_auto(const void *sendbuf, MPI_Aint sendcount, MP int MPIR_Iallgather_inter_sched_auto(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, - MPI_Datatype recvtype, MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPI_Datatype recvtype, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Iallgather_inter_sched_local_gather_remote_bcast(sendbuf, sendcount, sendtype, recvbuf, recvcount, - recvtype, comm_ptr, s); + recvtype, comm_ptr, + coll_group, s); return mpi_errno; } @@ -330,13 +341,17 @@ int MPIR_Iallgather_inter_sched_auto(const void *sendbuf, MPI_Aint sendcount, int MPIR_Iallgatherv_intra_sched_auto(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint displs[], - MPI_Datatype recvtype, MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPI_Datatype recvtype, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; - int i, comm_size; + int i; MPI_Aint total_count, recvtype_size; - comm_size = comm_ptr->local_size; + int comm_size, rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); + (void) rank; /* silence unused variable warnings */ + MPIR_Datatype_get_size_macro(recvtype, recvtype_size); total_count = 0; @@ -353,21 +368,21 @@ int MPIR_Iallgatherv_intra_sched_auto(const void *sendbuf, MPI_Aint sendcount, mpi_errno = MPIR_Iallgatherv_intra_sched_recursive_doubling(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm_ptr, - s); + coll_group, s); MPIR_ERR_CHECK(mpi_errno); } else if (total_count * recvtype_size < MPIR_CVAR_ALLGATHER_SHORT_MSG_SIZE) { /* Short message and non-power-of-two no. of processes. Use * Bruck algorithm (see description above). */ mpi_errno = MPIR_Iallgatherv_intra_sched_brucks(sendbuf, sendcount, sendtype, recvbuf, recvcounts, - displs, recvtype, comm_ptr, s); + displs, recvtype, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } else { /* long message or medium-size message and non-power-of-two * no. of processes. Use ring algorithm. */ mpi_errno = MPIR_Iallgatherv_intra_sched_ring(sendbuf, sendcount, sendtype, recvbuf, recvcounts, - displs, recvtype, comm_ptr, s); + displs, recvtype, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } @@ -380,27 +395,29 @@ int MPIR_Iallgatherv_intra_sched_auto(const void *sendbuf, MPI_Aint sendcount, int MPIR_Iallgatherv_inter_sched_auto(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint displs[], - MPI_Datatype recvtype, MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPI_Datatype recvtype, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Iallgatherv_inter_sched_remote_gather_local_bcast(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, - comm_ptr, s); + comm_ptr, coll_group, s); return mpi_errno; } int MPIR_Ialltoall_intra_sched_auto(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; - int comm_size; MPI_Aint nbytes, sendtype_size; - comm_size = comm_ptr->local_size; + int comm_size, rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); + (void) rank; /* silence unused variable warnings */ MPIR_Datatype_get_size_macro(sendtype, sendtype_size); nbytes = sendtype_size * sendcount; @@ -408,19 +425,20 @@ int MPIR_Ialltoall_intra_sched_auto(const void *sendbuf, MPI_Aint sendcount, MPI if (sendbuf == MPI_IN_PLACE) { mpi_errno = MPIR_Ialltoall_intra_sched_inplace(sendbuf, sendcount, sendtype, recvbuf, recvcount, - recvtype, comm_ptr, s); + recvtype, comm_ptr, coll_group, s); } else if ((nbytes <= MPIR_CVAR_ALLTOALL_SHORT_MSG_SIZE) && (comm_size >= 8)) { mpi_errno = MPIR_Ialltoall_intra_sched_brucks(sendbuf, sendcount, sendtype, recvbuf, recvcount, - recvtype, comm_ptr, s); + recvtype, comm_ptr, coll_group, s); } else if (nbytes <= MPIR_CVAR_ALLTOALL_MEDIUM_MSG_SIZE) { mpi_errno = MPIR_Ialltoall_intra_sched_permuted_sendrecv(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm_ptr, s); + recvcount, recvtype, comm_ptr, coll_group, + s); } else { mpi_errno = MPIR_Ialltoall_intra_sched_pairwise(sendbuf, sendcount, sendtype, recvbuf, recvcount, - recvtype, comm_ptr, s); + recvtype, comm_ptr, coll_group, s); } MPIR_ERR_CHECK(mpi_errno); @@ -433,13 +451,14 @@ int MPIR_Ialltoall_intra_sched_auto(const void *sendbuf, MPI_Aint sendcount, MPI int MPIR_Ialltoall_inter_sched_auto(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, - MPI_Datatype recvtype, MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPI_Datatype recvtype, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Ialltoall_inter_sched_pairwise_exchange(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, - comm_ptr, s); + comm_ptr, coll_group, s); return mpi_errno; } @@ -447,7 +466,8 @@ int MPIR_Ialltoall_inter_sched_auto(const void *sendbuf, MPI_Aint sendcount, MPI int MPIR_Ialltoallv_intra_sched_auto(const void *sendbuf, const MPI_Aint sendcounts[], const MPI_Aint sdispls[], MPI_Datatype sendtype, void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint rdispls[], - MPI_Datatype recvtype, MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPI_Datatype recvtype, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; @@ -456,11 +476,11 @@ int MPIR_Ialltoallv_intra_sched_auto(const void *sendbuf, const MPI_Aint sendcou if (sendbuf == MPI_IN_PLACE) { mpi_errno = MPIR_Ialltoallv_intra_sched_inplace(sendbuf, sendcounts, sdispls, sendtype, recvbuf, recvcounts, - rdispls, recvtype, comm_ptr, s); + rdispls, recvtype, comm_ptr, coll_group, s); } else { mpi_errno = MPIR_Ialltoallv_intra_sched_blocked(sendbuf, sendcounts, sdispls, sendtype, recvbuf, recvcounts, - rdispls, recvtype, comm_ptr, s); + rdispls, recvtype, comm_ptr, coll_group, s); } return mpi_errno; @@ -469,13 +489,15 @@ int MPIR_Ialltoallv_intra_sched_auto(const void *sendbuf, const MPI_Aint sendcou int MPIR_Ialltoallv_inter_sched_auto(const void *sendbuf, const MPI_Aint sendcounts[], const MPI_Aint sdispls[], MPI_Datatype sendtype, void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint rdispls[], - MPI_Datatype recvtype, MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPI_Datatype recvtype, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Ialltoallv_inter_sched_pairwise_exchange(sendbuf, sendcounts, sdispls, sendtype, recvbuf, recvcounts, - rdispls, recvtype, comm_ptr, s); + rdispls, recvtype, comm_ptr, + coll_group, s); return mpi_errno; } @@ -484,18 +506,20 @@ int MPIR_Ialltoallw_intra_sched_auto(const void *sendbuf, const MPI_Aint sendcou const MPI_Aint sdispls[], const MPI_Datatype sendtypes[], void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint rdispls[], const MPI_Datatype recvtypes[], - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; if (sendbuf == MPI_IN_PLACE) { mpi_errno = MPIR_Ialltoallw_intra_sched_inplace(sendbuf, sendcounts, sdispls, sendtypes, recvbuf, recvcounts, - rdispls, recvtypes, comm_ptr, s); + rdispls, recvtypes, comm_ptr, coll_group, + s); } else { mpi_errno = MPIR_Ialltoallw_intra_sched_blocked(sendbuf, sendcounts, sdispls, sendtypes, recvbuf, recvcounts, - rdispls, recvtypes, comm_ptr, s); + rdispls, recvtypes, comm_ptr, coll_group, + s); } return mpi_errno; @@ -505,20 +529,21 @@ int MPIR_Ialltoallw_inter_sched_auto(const void *sendbuf, const MPI_Aint sendcou const MPI_Aint sdispls[], const MPI_Datatype sendtypes[], void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint rdispls[], const MPI_Datatype recvtypes[], - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Ialltoallw_inter_sched_pairwise_exchange(sendbuf, sendcounts, sdispls, sendtypes, recvbuf, recvcounts, - rdispls, recvtypes, comm_ptr, s); + rdispls, recvtypes, comm_ptr, + coll_group, s); return mpi_errno; } int MPIR_Ireduce_intra_sched_auto(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, int root, MPIR_Comm * comm_ptr, - MPIR_Sched_t s) + int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int pof2; @@ -526,9 +551,9 @@ int MPIR_Ireduce_intra_sched_auto(const void *sendbuf, void *recvbuf, MPI_Aint c MPIR_Assert(comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM); - if (comm_ptr->hierarchy_kind == MPIR_COMM_HIERARCHY_KIND__PARENT && MPIR_Op_is_commutative(op)) { + if (MPIR_Comm_is_parent_comm(comm_ptr, coll_group) && MPIR_Op_is_commutative(op)) { mpi_errno = MPIR_Ireduce_intra_sched_smp(sendbuf, recvbuf, count, - datatype, op, root, comm_ptr, s); + datatype, op, root, comm_ptr, coll_group, s); if (mpi_errno) MPIR_ERR_POP(mpi_errno); @@ -538,20 +563,20 @@ int MPIR_Ireduce_intra_sched_auto(const void *sendbuf, void *recvbuf, MPI_Aint c MPIR_Datatype_get_size_macro(datatype, type_size); /* get nearest power-of-two less than or equal to number of ranks in the communicator */ - pof2 = comm_ptr->coll.pof2; + pof2 = MPL_pof2(MPIR_Coll_size(comm_ptr, coll_group)); if ((count * type_size > MPIR_CVAR_REDUCE_SHORT_MSG_SIZE) && (HANDLE_IS_BUILTIN(op)) && (count >= pof2)) { /* do a reduce-scatter followed by gather to root. */ mpi_errno = MPIR_Ireduce_intra_sched_reduce_scatter_gather(sendbuf, recvbuf, count, datatype, op, - root, comm_ptr, s); + root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } else { /* use a binomial tree algorithm */ mpi_errno = MPIR_Ireduce_intra_sched_binomial(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, - s); + coll_group, s); MPIR_ERR_CHECK(mpi_errno); } @@ -563,28 +588,30 @@ int MPIR_Ireduce_intra_sched_auto(const void *sendbuf, void *recvbuf, MPI_Aint c int MPIR_Ireduce_inter_sched_auto(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, int root, - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Ireduce_inter_sched_local_reduce_remote_send(sendbuf, recvbuf, count, - datatype, op, root, comm_ptr, s); + datatype, op, root, comm_ptr, + coll_group, s); return mpi_errno; } int MPIR_Iallreduce_intra_sched_auto(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, - MPIR_Sched_t s) + int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int pof2; MPIR_Assert(comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM); - if (comm_ptr->hierarchy_kind == MPIR_COMM_HIERARCHY_KIND__PARENT && MPIR_Op_is_commutative(op)) { + if (MPIR_Comm_is_parent_comm(comm_ptr, coll_group) && MPIR_Op_is_commutative(op)) { mpi_errno = - MPIR_Iallreduce_intra_sched_smp(sendbuf, recvbuf, count, datatype, op, comm_ptr, s); + MPIR_Iallreduce_intra_sched_smp(sendbuf, recvbuf, count, datatype, op, comm_ptr, + coll_group, s); if (mpi_errno) MPIR_ERR_POP(mpi_errno); @@ -595,7 +622,7 @@ int MPIR_Iallreduce_intra_sched_auto(const void *sendbuf, void *recvbuf, MPI_Ain MPIR_Datatype_get_size_macro(datatype, type_size); /* get nearest power-of-two less than or equal to number of ranks in the communicator */ - pof2 = comm_ptr->coll.pof2; + pof2 = MPL_pof2(MPIR_Coll_size(comm_ptr, coll_group)); /* If op is user-defined or count is less than pof2, use * recursive doubling algorithm. Otherwise do a reduce-scatter @@ -611,13 +638,13 @@ int MPIR_Iallreduce_intra_sched_auto(const void *sendbuf, void *recvbuf, MPI_Ain /* use recursive doubling */ mpi_errno = MPIR_Iallreduce_intra_sched_recursive_doubling(sendbuf, recvbuf, count, datatype, op, - comm_ptr, s); + comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } else { /* do a reduce-scatter followed by allgather */ mpi_errno = MPIR_Iallreduce_intra_sched_reduce_scatter_allgather(sendbuf, recvbuf, count, datatype, - op, comm_ptr, s); + op, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } @@ -629,29 +656,33 @@ int MPIR_Iallreduce_intra_sched_auto(const void *sendbuf, void *recvbuf, MPI_Ain int MPIR_Iallreduce_inter_sched_auto(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, - MPIR_Sched_t s) + int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Iallreduce_inter_sched_remote_reduce_local_bcast(sendbuf, recvbuf, count, - datatype, op, comm_ptr, s); + datatype, op, comm_ptr, + coll_group, s); return mpi_errno; } int MPIR_Ireduce_scatter_intra_sched_auto(const void *sendbuf, void *recvbuf, const MPI_Aint recvcounts[], MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPI_Op op, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int i; int is_commutative; MPI_Aint total_count, type_size, nbytes; - int comm_size; is_commutative = MPIR_Op_is_commutative(op); - comm_size = comm_ptr->local_size; + int comm_size, rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); + (void) rank; /* silence unused variable warnings */ + total_count = 0; for (i = 0; i < comm_size; i++) { total_count += recvcounts[i]; @@ -666,12 +697,13 @@ int MPIR_Ireduce_scatter_intra_sched_auto(const void *sendbuf, void *recvbuf, if (is_commutative && (nbytes < MPIR_CVAR_REDUCE_SCATTER_COMMUTATIVE_LONG_MSG_SIZE)) { mpi_errno = MPIR_Ireduce_scatter_intra_sched_recursive_halving(sendbuf, recvbuf, recvcounts, - datatype, op, comm_ptr, s); + datatype, op, comm_ptr, coll_group, + s); MPIR_ERR_CHECK(mpi_errno); } else if (is_commutative && (nbytes >= MPIR_CVAR_REDUCE_SCATTER_COMMUTATIVE_LONG_MSG_SIZE)) { mpi_errno = MPIR_Ireduce_scatter_intra_sched_pairwise(sendbuf, recvbuf, recvcounts, datatype, op, - comm_ptr, s); + comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } else { /* (!is_commutative) */ @@ -687,13 +719,15 @@ int MPIR_Ireduce_scatter_intra_sched_auto(const void *sendbuf, void *recvbuf, /* noncommutative, pof2 size, and block regular */ mpi_errno = MPIR_Ireduce_scatter_intra_sched_noncommutative(sendbuf, recvbuf, recvcounts, - datatype, op, comm_ptr, s); + datatype, op, comm_ptr, coll_group, + s); MPIR_ERR_CHECK(mpi_errno); } else { /* noncommutative and (non-pof2 or block irregular), use recursive doubling. */ mpi_errno = MPIR_Ireduce_scatter_intra_sched_recursive_doubling(sendbuf, recvbuf, recvcounts, - datatype, op, comm_ptr, s); + datatype, op, comm_ptr, + coll_group, s); MPIR_ERR_CHECK(mpi_errno); } } @@ -706,29 +740,34 @@ int MPIR_Ireduce_scatter_intra_sched_auto(const void *sendbuf, void *recvbuf, int MPIR_Ireduce_scatter_inter_sched_auto(const void *sendbuf, void *recvbuf, const MPI_Aint recvcounts[], MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPI_Op op, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Ireduce_scatter_inter_sched_remote_reduce_local_scatterv(sendbuf, recvbuf, recvcounts, - datatype, op, comm_ptr, s); + datatype, op, comm_ptr, + coll_group, s); return mpi_errno; } int MPIR_Ireduce_scatter_block_intra_sched_auto(const void *sendbuf, void *recvbuf, MPI_Aint recvcount, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPI_Op op, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int is_commutative; MPI_Aint total_count, type_size, nbytes; - int comm_size; is_commutative = MPIR_Op_is_commutative(op); - comm_size = comm_ptr->local_size; + int comm_size, rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); + (void) rank; /* silence unused variable warnings */ + total_count = recvcount * comm_size; if (total_count == 0) { goto fn_exit; @@ -740,12 +779,13 @@ int MPIR_Ireduce_scatter_block_intra_sched_auto(const void *sendbuf, void *recvb if (is_commutative && (nbytes < MPIR_CVAR_REDUCE_SCATTER_COMMUTATIVE_LONG_MSG_SIZE)) { mpi_errno = MPIR_Ireduce_scatter_block_intra_sched_recursive_halving(sendbuf, recvbuf, recvcount, - datatype, op, comm_ptr, s); + datatype, op, comm_ptr, + coll_group, s); MPIR_ERR_CHECK(mpi_errno); } else if (is_commutative && (nbytes >= MPIR_CVAR_REDUCE_SCATTER_COMMUTATIVE_LONG_MSG_SIZE)) { mpi_errno = MPIR_Ireduce_scatter_block_intra_sched_pairwise(sendbuf, recvbuf, recvcount, datatype, - op, comm_ptr, s); + op, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } else { /* (!is_commutative) */ @@ -753,14 +793,15 @@ int MPIR_Ireduce_scatter_block_intra_sched_auto(const void *sendbuf, void *recvb /* noncommutative, pof2 size */ mpi_errno = MPIR_Ireduce_scatter_block_intra_sched_noncommutative(sendbuf, recvbuf, recvcount, - datatype, op, comm_ptr, s); + datatype, op, comm_ptr, + coll_group, s); MPIR_ERR_CHECK(mpi_errno); } else { /* noncommutative and non-pof2, use recursive doubling. */ mpi_errno = MPIR_Ireduce_scatter_block_intra_sched_recursive_doubling(sendbuf, recvbuf, recvcount, datatype, op, - comm_ptr, s); + comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } } @@ -774,30 +815,34 @@ int MPIR_Ireduce_scatter_block_intra_sched_auto(const void *sendbuf, void *recvb int MPIR_Ireduce_scatter_block_inter_sched_auto(const void *sendbuf, void *recvbuf, MPI_Aint recvcount, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPI_Op op, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Ireduce_scatter_block_inter_sched_remote_reduce_local_scatterv(sendbuf, recvbuf, recvcount, datatype, op, - comm_ptr, s); + comm_ptr, coll_group, + s); return mpi_errno; } int MPIR_Iscan_intra_sched_auto(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, - MPIR_Sched_t s) + int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; - if (comm_ptr->hierarchy_kind == MPIR_COMM_HIERARCHY_KIND__PARENT) { - mpi_errno = MPIR_Iscan_intra_sched_smp(sendbuf, recvbuf, count, datatype, op, comm_ptr, s); + if (MPIR_Comm_is_parent_comm(comm_ptr, coll_group)) { + mpi_errno = + MPIR_Iscan_intra_sched_smp(sendbuf, recvbuf, count, datatype, op, comm_ptr, coll_group, + s); } else { mpi_errno = MPIR_Iscan_intra_sched_recursive_doubling(sendbuf, recvbuf, count, datatype, op, - comm_ptr, s); + comm_ptr, coll_group, s); } return mpi_errno; @@ -805,13 +850,13 @@ int MPIR_Iscan_intra_sched_auto(const void *sendbuf, void *recvbuf, MPI_Aint cou int MPIR_Iexscan_intra_sched_auto(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, - MPIR_Sched_t s) + int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Iexscan_intra_sched_recursive_doubling(sendbuf, recvbuf, count, datatype, op, comm_ptr, - s); + coll_group, s); MPIR_ERR_CHECK(mpi_errno); fn_exit: diff --git a/src/mpi/coll/op/opequal.c b/src/mpi/coll/op/opequal.c index cf6c7cb846f..fd9d57da163 100644 --- a/src/mpi/coll/op/opequal.c +++ b/src/mpi/coll/op/opequal.c @@ -55,7 +55,7 @@ int MPIR_EQUAL_check_dtype(MPI_Datatype type) MPIR_Assert(actual_pack_bytes == count * type_sz) int MPIR_Reduce_equal(const void *sendbuf, MPI_Aint count, MPI_Datatype datatype, - int *is_equal, int root, MPIR_Comm * comm_ptr) + int *is_equal, int root, MPIR_Comm * comm_ptr, int coll_group) { int mpi_errno = MPI_SUCCESS; @@ -64,10 +64,12 @@ int MPIR_Reduce_equal(const void *sendbuf, MPI_Aint count, MPI_Datatype datatype /* Not all algorithm will work. In particular, we can't split the message */ if (comm_ptr->rank == root) { mpi_errno = MPIR_Reduce_intra_binomial(MPI_IN_PLACE, local_buf, byte_count, MPI_BYTE, - MPIX_EQUAL, root, comm_ptr, MPIR_ERR_NONE); + MPIX_EQUAL, root, comm_ptr, coll_group, + MPIR_ERR_NONE); } else { mpi_errno = MPIR_Reduce_intra_binomial(local_buf, NULL, byte_count, MPI_BYTE, - MPIX_EQUAL, root, comm_ptr, MPIR_ERR_NONE); + MPIX_EQUAL, root, comm_ptr, coll_group, + MPIR_ERR_NONE); } MPIR_ERR_CHECK(mpi_errno); @@ -84,7 +86,7 @@ int MPIR_Reduce_equal(const void *sendbuf, MPI_Aint count, MPI_Datatype datatype int MPIR_Allreduce_equal(const void *sendbuf, MPI_Aint count, MPI_Datatype datatype, - int *is_equal, MPIR_Comm * comm_ptr) + int *is_equal, MPIR_Comm * comm_ptr, int coll_group) { int mpi_errno = MPI_SUCCESS; @@ -93,7 +95,8 @@ int MPIR_Allreduce_equal(const void *sendbuf, MPI_Aint count, MPI_Datatype datat /* Not all algorithm will work. In particular, we can't split the message */ mpi_errno = MPIR_Allreduce_intra_recursive_doubling(MPI_IN_PLACE, local_buf, byte_count, MPI_BYTE, - MPIX_EQUAL, comm_ptr, MPIR_ERR_NONE); + MPIX_EQUAL, comm_ptr, coll_group, + MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); *is_equal = local_buf->is_equal; diff --git a/src/mpi/coll/reduce/reduce_allcomm_nb.c b/src/mpi/coll/reduce/reduce_allcomm_nb.c index 2ab95bc494e..9c31c443cfe 100644 --- a/src/mpi/coll/reduce/reduce_allcomm_nb.c +++ b/src/mpi/coll/reduce/reduce_allcomm_nb.c @@ -7,13 +7,14 @@ int MPIR_Reduce_allcomm_nb(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, int root, MPIR_Comm * comm_ptr, - MPIR_Errflag_t errflag) + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_Request *req_ptr = NULL; /* just call the nonblocking version and wait on it */ - mpi_errno = MPIR_Ireduce(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, &req_ptr); + mpi_errno = + MPIR_Ireduce(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, coll_group, &req_ptr); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIC_Wait(req_ptr); diff --git a/src/mpi/coll/reduce/reduce_inter_local_reduce_remote_send.c b/src/mpi/coll/reduce/reduce_inter_local_reduce_remote_send.c index ad54cce357b..7b111d611c2 100644 --- a/src/mpi/coll/reduce/reduce_inter_local_reduce_remote_send.c +++ b/src/mpi/coll/reduce/reduce_inter_local_reduce_remote_send.c @@ -19,7 +19,8 @@ int MPIR_Reduce_inter_local_reduce_remote_send(const void *sendbuf, MPI_Datatype datatype, MPI_Op op, int root, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int rank, mpi_errno; MPI_Status status; @@ -35,7 +36,8 @@ int MPIR_Reduce_inter_local_reduce_remote_send(const void *sendbuf, if (root == MPI_ROOT) { /* root receives data from rank 0 on remote group */ - mpi_errno = MPIC_Recv(recvbuf, count, datatype, 0, MPIR_REDUCE_TAG, comm_ptr, &status); + mpi_errno = MPIC_Recv(recvbuf, count, datatype, 0, MPIR_REDUCE_TAG, + comm_ptr, coll_group, &status); MPIR_ERR_CHECK(mpi_errno); } else { /* remote group. Rank 0 allocates temporary buffer, does @@ -63,12 +65,13 @@ int MPIR_Reduce_inter_local_reduce_remote_send(const void *sendbuf, newcomm_ptr = comm_ptr->local_comm; /* now do a local reduce on this intracommunicator */ - mpi_errno = MPIR_Reduce(sendbuf, tmp_buf, count, datatype, op, 0, newcomm_ptr, errflag); + mpi_errno = MPIR_Reduce(sendbuf, tmp_buf, count, datatype, op, 0, newcomm_ptr, + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); if (rank == 0) { mpi_errno = MPIC_Send(tmp_buf, count, datatype, root, - MPIR_REDUCE_TAG, comm_ptr, errflag); + MPIR_REDUCE_TAG, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } } diff --git a/src/mpi/coll/reduce/reduce_intra_binomial.c b/src/mpi/coll/reduce/reduce_intra_binomial.c index 4363b5a7031..93f6e233199 100644 --- a/src/mpi/coll/reduce/reduce_intra_binomial.c +++ b/src/mpi/coll/reduce/reduce_intra_binomial.c @@ -13,7 +13,8 @@ int MPIR_Reduce_intra_binomial(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, - MPI_Op op, int root, MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPI_Op op, int root, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPI_Status status; @@ -23,7 +24,7 @@ int MPIR_Reduce_intra_binomial(const void *sendbuf, void *tmp_buf; MPIR_CHKLMEM_DECL(2); - MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size); + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); /* Create a temporary buffer */ @@ -96,7 +97,7 @@ int MPIR_Reduce_intra_binomial(const void *sendbuf, if (source < comm_size) { source = (source + lroot) % comm_size; mpi_errno = MPIC_Recv(tmp_buf, count, datatype, source, - MPIR_REDUCE_TAG, comm_ptr, &status); + MPIR_REDUCE_TAG, comm_ptr, coll_group, &status); MPIR_ERR_CHECK(mpi_errno); /* The sender is above us, so the received buffer must be @@ -117,7 +118,7 @@ int MPIR_Reduce_intra_binomial(const void *sendbuf, * my parent */ source = ((relrank & (~mask)) + lroot) % comm_size; mpi_errno = MPIC_Send(recvbuf, count, datatype, - source, MPIR_REDUCE_TAG, comm_ptr, errflag); + source, MPIR_REDUCE_TAG, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); break; } @@ -127,9 +128,11 @@ int MPIR_Reduce_intra_binomial(const void *sendbuf, if (!is_commutative && (root != 0)) { if (rank == 0) { mpi_errno = MPIC_Send(recvbuf, count, datatype, root, - MPIR_REDUCE_TAG, comm_ptr, errflag); + MPIR_REDUCE_TAG, comm_ptr, coll_group, errflag); } else if (rank == root) { - mpi_errno = MPIC_Recv(recvbuf, count, datatype, 0, MPIR_REDUCE_TAG, comm_ptr, &status); + mpi_errno = + MPIC_Recv(recvbuf, count, datatype, 0, MPIR_REDUCE_TAG, comm_ptr, coll_group, + &status); } MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/reduce/reduce_intra_reduce_scatter_gather.c b/src/mpi/coll/reduce/reduce_intra_reduce_scatter_gather.c index a8113ce6658..a5a2b1c4f98 100644 --- a/src/mpi/coll/reduce/reduce_intra_reduce_scatter_gather.c +++ b/src/mpi/coll/reduce/reduce_intra_reduce_scatter_gather.c @@ -37,7 +37,8 @@ int MPIR_Reduce_intra_reduce_scatter_gather(const void *sendbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - int root, MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + int root, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; int comm_size, rank, pof2, rem, newrank; @@ -49,8 +50,7 @@ int MPIR_Reduce_intra_reduce_scatter_gather(const void *sendbuf, MPIR_CHKLMEM_DECL(4); - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); /* Create a temporary buffer */ @@ -77,7 +77,7 @@ int MPIR_Reduce_intra_reduce_scatter_gather(const void *sendbuf, } /* get nearest power-of-two less than or equal to comm_size */ - pof2 = comm_ptr->coll.pof2; + pof2 = MPL_pof2(comm_size); #ifdef HAVE_ERROR_CHECKING MPIR_Assert(HANDLE_IS_BUILTIN(op)); @@ -105,7 +105,8 @@ int MPIR_Reduce_intra_reduce_scatter_gather(const void *sendbuf, if (rank < 2 * rem) { if (rank % 2 != 0) { /* odd */ mpi_errno = MPIC_Send(recvbuf, count, - datatype, rank - 1, MPIR_REDUCE_TAG, comm_ptr, errflag); + datatype, rank - 1, MPIR_REDUCE_TAG, comm_ptr, coll_group, + errflag); MPIR_ERR_CHECK(mpi_errno); /* temporarily set the rank to -1 so that this @@ -114,7 +115,8 @@ int MPIR_Reduce_intra_reduce_scatter_gather(const void *sendbuf, newrank = -1; } else { /* even */ mpi_errno = MPIC_Recv(tmp_buf, count, - datatype, rank + 1, MPIR_REDUCE_TAG, comm_ptr, MPI_STATUS_IGNORE); + datatype, rank + 1, MPIR_REDUCE_TAG, comm_ptr, coll_group, + MPI_STATUS_IGNORE); MPIR_ERR_CHECK(mpi_errno); /* do the reduction on received data. */ @@ -189,7 +191,8 @@ int MPIR_Reduce_intra_reduce_scatter_gather(const void *sendbuf, (char *) tmp_buf + disps[recv_idx] * extent, recv_cnt, datatype, dst, - MPIR_REDUCE_TAG, comm_ptr, MPI_STATUS_IGNORE, errflag); + MPIR_REDUCE_TAG, comm_ptr, coll_group, MPI_STATUS_IGNORE, + errflag); MPIR_ERR_CHECK(mpi_errno); /* tmp_buf contains data received in this step. @@ -236,14 +239,14 @@ int MPIR_Reduce_intra_reduce_scatter_gather(const void *sendbuf, disps[i] = disps[i - 1] + cnts[i - 1]; mpi_errno = MPIC_Recv(recvbuf, cnts[0], datatype, - 0, MPIR_REDUCE_TAG, comm_ptr, MPI_STATUS_IGNORE); + 0, MPIR_REDUCE_TAG, comm_ptr, coll_group, MPI_STATUS_IGNORE); MPIR_ERR_CHECK(mpi_errno); newrank = 0; send_idx = 0; last_idx = 2; } else if (newrank == 0) { /* send */ mpi_errno = MPIC_Send(recvbuf, cnts[0], datatype, - root, MPIR_REDUCE_TAG, comm_ptr, errflag); + root, MPIR_REDUCE_TAG, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); newrank = -1; } @@ -309,7 +312,8 @@ int MPIR_Reduce_intra_reduce_scatter_gather(const void *sendbuf, /* Send data from recvbuf. Recv into tmp_buf */ mpi_errno = MPIC_Send((char *) recvbuf + disps[send_idx] * extent, - send_cnt, datatype, dst, MPIR_REDUCE_TAG, comm_ptr, errflag); + send_cnt, datatype, dst, MPIR_REDUCE_TAG, comm_ptr, + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); break; } else { @@ -319,7 +323,7 @@ int MPIR_Reduce_intra_reduce_scatter_gather(const void *sendbuf, mpi_errno = MPIC_Recv((char *) recvbuf + disps[recv_idx] * extent, recv_cnt, datatype, dst, - MPIR_REDUCE_TAG, comm_ptr, MPI_STATUS_IGNORE); + MPIR_REDUCE_TAG, comm_ptr, coll_group, MPI_STATUS_IGNORE); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/reduce/reduce_intra_smp.c b/src/mpi/coll/reduce/reduce_intra_smp.c index 6ff46543185..a22fc8ed956 100644 --- a/src/mpi/coll/reduce/reduce_intra_smp.c +++ b/src/mpi/coll/reduce/reduce_intra_smp.c @@ -7,7 +7,7 @@ int MPIR_Reduce_intra_smp(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, int root, MPIR_Comm * comm_ptr, - MPIR_Errflag_t errflag) + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; void *tmp_buf = NULL; @@ -19,11 +19,17 @@ int MPIR_Reduce_intra_smp(const void *sendbuf, void *recvbuf, MPI_Aint count, int is_commutative; is_commutative = MPIR_Op_is_commutative(op); MPIR_Assertp(is_commutative); + MPIR_Assert(MPIR_Comm_is_parent_comm(comm_ptr, coll_group)); } #endif /* HAVE_ERROR_CHECKING */ + int local_rank = comm_ptr->subgroups[MPIR_SUBGROUP_NODE].rank; + int local_size = comm_ptr->subgroups[MPIR_SUBGROUP_NODE].size; + int local_root = MPIR_Get_intranode_rank(comm_ptr, root); + int inter_root = MPIR_Get_internode_rank(comm_ptr, root); + /* Create a temporary buffer on local roots of all nodes */ - if (comm_ptr->node_roots_comm != NULL) { + if (local_rank == 0) { MPIR_Type_get_true_extent_impl(datatype, &true_lb, &true_extent); MPIR_Datatype_get_extent_macro(datatype, extent); @@ -35,30 +41,29 @@ int MPIR_Reduce_intra_smp(const void *sendbuf, void *recvbuf, MPI_Aint count, } /* do the intranode reduce on all nodes other than the root's node */ - if (comm_ptr->node_comm != NULL && MPIR_Get_intranode_rank(comm_ptr, root) == -1) { - mpi_errno = MPIR_Reduce(sendbuf, tmp_buf, count, datatype, - op, 0, comm_ptr->node_comm, errflag); + if (local_size > 1 && local_root == -1) { + mpi_errno = MPIR_Reduce(sendbuf, tmp_buf, count, datatype, op, 0, + comm_ptr, MPIR_SUBGROUP_NODE, errflag); MPIR_ERR_CHECK(mpi_errno); } /* do the internode reduce to the root's node */ - if (comm_ptr->node_roots_comm != NULL) { - if (comm_ptr->node_roots_comm->rank != MPIR_Get_internode_rank(comm_ptr, root)) { + if (local_rank == 0) { + if (local_root == -1) { /* I am not on root's node. Use tmp_buf if we * participated in the first reduce, otherwise use sendbuf */ - const void *buf = (comm_ptr->node_comm == NULL ? sendbuf : tmp_buf); + const void *buf = (local_size > 1 ? tmp_buf : sendbuf); mpi_errno = MPIR_Reduce(buf, NULL, count, datatype, - op, MPIR_Get_internode_rank(comm_ptr, root), - comm_ptr->node_roots_comm, errflag); + op, inter_root, comm_ptr, MPIR_SUBGROUP_NODE_CROSS, errflag); MPIR_ERR_CHECK(mpi_errno); } else { /* I am on root's node. I have not participated in the earlier reduce. */ - if (comm_ptr->rank != root) { + if (local_root != 0) { /* I am not the root though. I don't have a valid recvbuf. * Use tmp_buf as recvbuf. */ mpi_errno = MPIR_Reduce(sendbuf, tmp_buf, count, datatype, - op, MPIR_Get_internode_rank(comm_ptr, root), - comm_ptr->node_roots_comm, errflag); + op, inter_root, + comm_ptr, MPIR_SUBGROUP_NODE_CROSS, errflag); MPIR_ERR_CHECK(mpi_errno); /* point sendbuf at tmp_buf to make final intranode reduce easy */ @@ -67,8 +72,8 @@ int MPIR_Reduce_intra_smp(const void *sendbuf, void *recvbuf, MPI_Aint count, /* I am the root. in_place is automatically handled. */ mpi_errno = MPIR_Reduce(sendbuf, recvbuf, count, datatype, - op, MPIR_Get_internode_rank(comm_ptr, root), - comm_ptr->node_roots_comm, errflag); + op, inter_root, + comm_ptr, MPIR_SUBGROUP_NODE_CROSS, errflag); MPIR_ERR_CHECK(mpi_errno); /* set sendbuf to MPI_IN_PLACE to make final intranode reduce easy. */ @@ -79,10 +84,9 @@ int MPIR_Reduce_intra_smp(const void *sendbuf, void *recvbuf, MPI_Aint count, } /* do the intranode reduce on the root's node */ - if (comm_ptr->node_comm != NULL && MPIR_Get_intranode_rank(comm_ptr, root) != -1) { + if (local_size > 1 && local_root != -1) { mpi_errno = MPIR_Reduce(sendbuf, recvbuf, count, datatype, - op, MPIR_Get_intranode_rank(comm_ptr, root), - comm_ptr->node_comm, errflag); + op, local_root, comm_ptr, MPIR_SUBGROUP_NODE, errflag); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/reduce_scatter/reduce_scatter_allcomm_nb.c b/src/mpi/coll/reduce_scatter/reduce_scatter_allcomm_nb.c index a29d82db025..51ae8ca3120 100644 --- a/src/mpi/coll/reduce_scatter/reduce_scatter_allcomm_nb.c +++ b/src/mpi/coll/reduce_scatter/reduce_scatter_allcomm_nb.c @@ -7,14 +7,15 @@ int MPIR_Reduce_scatter_allcomm_nb(const void *sendbuf, void *recvbuf, const MPI_Aint recvcounts[], MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, - MPIR_Errflag_t errflag) + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_Request *req_ptr = NULL; /* just call the nonblocking version and wait on it */ mpi_errno = - MPIR_Ireduce_scatter(sendbuf, recvbuf, recvcounts, datatype, op, comm_ptr, &req_ptr); + MPIR_Ireduce_scatter(sendbuf, recvbuf, recvcounts, datatype, op, comm_ptr, coll_group, + &req_ptr); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIC_Wait(req_ptr); diff --git a/src/mpi/coll/reduce_scatter/reduce_scatter_inter_remote_reduce_local_scatter.c b/src/mpi/coll/reduce_scatter/reduce_scatter_inter_remote_reduce_local_scatter.c index eba1c515df0..9b7770fccfd 100644 --- a/src/mpi/coll/reduce_scatter/reduce_scatter_inter_remote_reduce_local_scatter.c +++ b/src/mpi/coll/reduce_scatter/reduce_scatter_inter_remote_reduce_local_scatter.c @@ -15,7 +15,7 @@ int MPIR_Reduce_scatter_inter_remote_reduce_local_scatter(const void *sendbuf, void *recvbuf, const MPI_Aint recvcounts[], MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, + MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int rank, mpi_errno, root, local_size, total_count, i; @@ -61,25 +61,25 @@ int MPIR_Reduce_scatter_inter_remote_reduce_local_scatter(const void *sendbuf, v /* reduce from right group to rank 0 */ root = (rank == 0) ? MPI_ROOT : MPI_PROC_NULL; mpi_errno = MPIR_Reduce_allcomm_auto(sendbuf, tmp_buf, total_count, datatype, op, - root, comm_ptr, errflag); + root, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); /* reduce to rank 0 of right group */ root = 0; mpi_errno = MPIR_Reduce_allcomm_auto(sendbuf, tmp_buf, total_count, datatype, op, - root, comm_ptr, errflag); + root, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } else { /* reduce to rank 0 of left group */ root = 0; mpi_errno = MPIR_Reduce_allcomm_auto(sendbuf, tmp_buf, total_count, datatype, op, - root, comm_ptr, errflag); + root, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); /* reduce from right group to rank 0 */ root = (rank == 0) ? MPI_ROOT : MPI_PROC_NULL; mpi_errno = MPIR_Reduce_allcomm_auto(sendbuf, tmp_buf, total_count, datatype, op, - root, comm_ptr, errflag); + root, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } @@ -92,7 +92,7 @@ int MPIR_Reduce_scatter_inter_remote_reduce_local_scatter(const void *sendbuf, v newcomm_ptr = comm_ptr->local_comm; mpi_errno = MPIR_Scatterv(tmp_buf, recvcounts, disps, datatype, recvbuf, - recvcounts[rank], datatype, 0, newcomm_ptr, errflag); + recvcounts[rank], datatype, 0, newcomm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); fn_exit: diff --git a/src/mpi/coll/reduce_scatter/reduce_scatter_intra_noncommutative.c b/src/mpi/coll/reduce_scatter/reduce_scatter_intra_noncommutative.c index d3085b27355..f2f169d5c0e 100644 --- a/src/mpi/coll/reduce_scatter/reduce_scatter_intra_noncommutative.c +++ b/src/mpi/coll/reduce_scatter/reduce_scatter_intra_noncommutative.c @@ -22,7 +22,7 @@ */ int MPIR_Reduce_scatter_intra_noncommutative(const void *sendbuf, void *recvbuf, const MPI_Aint recvcounts[], MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm_ptr, + MPI_Op op, MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -37,6 +37,8 @@ int MPIR_Reduce_scatter_intra_noncommutative(const void *sendbuf, void *recvbuf, void *result_ptr; MPIR_CHKLMEM_DECL(3); + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); + MPIR_Type_get_true_extent_impl(datatype, &true_lb, &true_extent); #ifdef HAVE_ERROR_CHECKING @@ -101,7 +103,7 @@ int MPIR_Reduce_scatter_intra_noncommutative(const void *sendbuf, void *recvbuf, size, datatype, peer, MPIR_REDUCE_SCATTER_TAG, incoming_data + recv_offset * true_extent, size, datatype, peer, MPIR_REDUCE_SCATTER_TAG, - comm_ptr, MPI_STATUS_IGNORE, errflag); + comm_ptr, coll_group, MPI_STATUS_IGNORE, errflag); MPIR_ERR_CHECK(mpi_errno); /* always perform the reduction at recv_offset, the data at send_offset * is now our peer's responsibility */ diff --git a/src/mpi/coll/reduce_scatter/reduce_scatter_intra_pairwise.c b/src/mpi/coll/reduce_scatter/reduce_scatter_intra_pairwise.c index 0ae8b879377..a3367dd6705 100644 --- a/src/mpi/coll/reduce_scatter/reduce_scatter_intra_pairwise.c +++ b/src/mpi/coll/reduce_scatter/reduce_scatter_intra_pairwise.c @@ -15,7 +15,8 @@ */ int MPIR_Reduce_scatter_intra_pairwise(const void *sendbuf, void *recvbuf, const MPI_Aint recvcounts[], MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPI_Op op, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int rank, comm_size, i; MPI_Aint extent, true_extent, true_lb; @@ -25,8 +26,7 @@ int MPIR_Reduce_scatter_intra_pairwise(const void *sendbuf, void *recvbuf, int src, dst; MPIR_CHKLMEM_DECL(5); - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); MPIR_Datatype_get_extent_macro(datatype, extent); MPIR_Type_get_true_extent_impl(datatype, &true_lb, &true_extent); @@ -81,14 +81,14 @@ int MPIR_Reduce_scatter_intra_pairwise(const void *sendbuf, void *recvbuf, recvcounts[dst], datatype, dst, MPIR_REDUCE_SCATTER_TAG, tmp_recvbuf, recvcounts[rank], datatype, src, - MPIR_REDUCE_SCATTER_TAG, comm_ptr, + MPIR_REDUCE_SCATTER_TAG, comm_ptr, coll_group, MPI_STATUS_IGNORE, errflag); else mpi_errno = MPIC_Sendrecv(((char *) recvbuf + disps[dst] * extent), recvcounts[dst], datatype, dst, MPIR_REDUCE_SCATTER_TAG, tmp_recvbuf, recvcounts[rank], datatype, src, - MPIR_REDUCE_SCATTER_TAG, comm_ptr, + MPIR_REDUCE_SCATTER_TAG, comm_ptr, coll_group, MPI_STATUS_IGNORE, errflag); MPIR_ERR_CHECK(mpi_errno); diff --git a/src/mpi/coll/reduce_scatter/reduce_scatter_intra_recursive_doubling.c b/src/mpi/coll/reduce_scatter/reduce_scatter_intra_recursive_doubling.c index 6ab1d986302..972d1a13fbe 100644 --- a/src/mpi/coll/reduce_scatter/reduce_scatter_intra_recursive_doubling.c +++ b/src/mpi/coll/reduce_scatter/reduce_scatter_intra_recursive_doubling.c @@ -19,7 +19,7 @@ */ int MPIR_Reduce_scatter_intra_recursive_doubling(const void *sendbuf, void *recvbuf, const MPI_Aint recvcounts[], MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm_ptr, + MPI_Op op, MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int rank, comm_size, i; @@ -34,8 +34,7 @@ int MPIR_Reduce_scatter_intra_recursive_doubling(const void *sendbuf, void *recv int nprocs_completed, tmp_mask, tree_root, is_commutative; MPIR_CHKLMEM_DECL(5); - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); MPIR_Datatype_get_extent_macro(datatype, extent); MPIR_Type_get_true_extent_impl(datatype, &true_lb, &true_extent); @@ -148,7 +147,7 @@ int MPIR_Reduce_scatter_intra_recursive_doubling(const void *sendbuf, void *recv mpi_errno = MPIC_Sendrecv(tmp_results, 1, sendtype, dst, MPIR_REDUCE_SCATTER_TAG, tmp_recvbuf, 1, recvtype, dst, - MPIR_REDUCE_SCATTER_TAG, comm_ptr, + MPIR_REDUCE_SCATTER_TAG, comm_ptr, coll_group, MPI_STATUS_IGNORE, errflag); received = 1; MPIR_ERR_CHECK(mpi_errno); @@ -190,7 +189,8 @@ int MPIR_Reduce_scatter_intra_recursive_doubling(const void *sendbuf, void *recv && (dst >= tree_root + nprocs_completed)) { /* send the current result */ mpi_errno = MPIC_Send(tmp_recvbuf, 1, recvtype, - dst, MPIR_REDUCE_SCATTER_TAG, comm_ptr, errflag); + dst, MPIR_REDUCE_SCATTER_TAG, comm_ptr, coll_group, + errflag); MPIR_ERR_CHECK(mpi_errno); } /* recv only if this proc. doesn't have data and sender @@ -199,7 +199,8 @@ int MPIR_Reduce_scatter_intra_recursive_doubling(const void *sendbuf, void *recv (dst < tree_root + nprocs_completed) && (rank >= tree_root + nprocs_completed)) { mpi_errno = MPIC_Recv(tmp_recvbuf, 1, recvtype, dst, - MPIR_REDUCE_SCATTER_TAG, comm_ptr, MPI_STATUS_IGNORE); + MPIR_REDUCE_SCATTER_TAG, comm_ptr, coll_group, + MPI_STATUS_IGNORE); received = 1; MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/reduce_scatter/reduce_scatter_intra_recursive_halving.c b/src/mpi/coll/reduce_scatter/reduce_scatter_intra_recursive_halving.c index 03f58c3d7c2..4af12cc5135 100644 --- a/src/mpi/coll/reduce_scatter/reduce_scatter_intra_recursive_halving.c +++ b/src/mpi/coll/reduce_scatter/reduce_scatter_intra_recursive_halving.c @@ -36,7 +36,7 @@ */ int MPIR_Reduce_scatter_intra_recursive_halving(const void *sendbuf, void *recvbuf, const MPI_Aint recvcounts[], MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm_ptr, + MPI_Op op, MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int rank, comm_size, i; @@ -50,7 +50,7 @@ int MPIR_Reduce_scatter_intra_recursive_halving(const void *sendbuf, void *recvb int pof2, old_i, newrank; MPIR_CHKLMEM_DECL(5); - MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size); + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); #ifdef HAVE_ERROR_CHECKING { @@ -113,7 +113,8 @@ int MPIR_Reduce_scatter_intra_recursive_halving(const void *sendbuf, void *recvb if (rank < 2 * rem) { if (rank % 2 == 0) { /* even */ mpi_errno = MPIC_Send(tmp_results, total_count, - datatype, rank + 1, MPIR_REDUCE_SCATTER_TAG, comm_ptr, errflag); + datatype, rank + 1, MPIR_REDUCE_SCATTER_TAG, comm_ptr, coll_group, + errflag); MPIR_ERR_CHECK(mpi_errno); /* temporarily set the rank to -1 so that this @@ -123,7 +124,7 @@ int MPIR_Reduce_scatter_intra_recursive_halving(const void *sendbuf, void *recvb } else { /* odd */ mpi_errno = MPIC_Recv(tmp_recvbuf, total_count, datatype, rank - 1, - MPIR_REDUCE_SCATTER_TAG, comm_ptr, MPI_STATUS_IGNORE); + MPIR_REDUCE_SCATTER_TAG, comm_ptr, coll_group, MPI_STATUS_IGNORE); MPIR_ERR_CHECK(mpi_errno); /* do the reduction on received data. since the @@ -199,18 +200,19 @@ int MPIR_Reduce_scatter_intra_recursive_halving(const void *sendbuf, void *recvb (char *) tmp_recvbuf + newdisps[recv_idx] * extent, recv_cnt, datatype, dst, - MPIR_REDUCE_SCATTER_TAG, comm_ptr, + MPIR_REDUCE_SCATTER_TAG, comm_ptr, coll_group, MPI_STATUS_IGNORE, errflag); else if ((send_cnt == 0) && (recv_cnt != 0)) mpi_errno = MPIC_Recv((char *) tmp_recvbuf + newdisps[recv_idx] * extent, recv_cnt, datatype, dst, - MPIR_REDUCE_SCATTER_TAG, comm_ptr, MPI_STATUS_IGNORE); + MPIR_REDUCE_SCATTER_TAG, comm_ptr, coll_group, + MPI_STATUS_IGNORE); else if ((recv_cnt == 0) && (send_cnt != 0)) mpi_errno = MPIC_Send((char *) tmp_results + newdisps[send_idx] * extent, send_cnt, datatype, - dst, MPIR_REDUCE_SCATTER_TAG, comm_ptr, errflag); + dst, MPIR_REDUCE_SCATTER_TAG, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); @@ -250,14 +252,15 @@ int MPIR_Reduce_scatter_intra_recursive_halving(const void *sendbuf, void *recvb mpi_errno = MPIC_Send((char *) tmp_results + disps[rank - 1] * extent, recvcounts[rank - 1], datatype, rank - 1, - MPIR_REDUCE_SCATTER_TAG, comm_ptr, errflag); + MPIR_REDUCE_SCATTER_TAG, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } } else { /* even */ if (recvcounts[rank]) { mpi_errno = MPIC_Recv(recvbuf, recvcounts[rank], datatype, rank + 1, - MPIR_REDUCE_SCATTER_TAG, comm_ptr, MPI_STATUS_IGNORE); + MPIR_REDUCE_SCATTER_TAG, comm_ptr, coll_group, + MPI_STATUS_IGNORE); MPIR_ERR_CHECK(mpi_errno); } } diff --git a/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_allcomm_nb.c b/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_allcomm_nb.c index 42a37790218..adf50cf765d 100644 --- a/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_allcomm_nb.c +++ b/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_allcomm_nb.c @@ -7,14 +7,15 @@ int MPIR_Reduce_scatter_block_allcomm_nb(const void *sendbuf, void *recvbuf, MPI_Aint recvcount, MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, - MPIR_Errflag_t errflag) + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_Request *req_ptr = NULL; /* just call the nonblocking version and wait on it */ mpi_errno = - MPIR_Ireduce_scatter_block(sendbuf, recvbuf, recvcount, datatype, op, comm_ptr, &req_ptr); + MPIR_Ireduce_scatter_block(sendbuf, recvbuf, recvcount, datatype, op, comm_ptr, coll_group, + &req_ptr); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIC_Wait(req_ptr); diff --git a/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_inter_remote_reduce_local_scatter.c b/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_inter_remote_reduce_local_scatter.c index ed15eafb02f..84d1b45758e 100644 --- a/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_inter_remote_reduce_local_scatter.c +++ b/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_inter_remote_reduce_local_scatter.c @@ -18,6 +18,7 @@ int MPIR_Reduce_scatter_block_inter_remote_reduce_local_scatter(const void *send MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, + int coll_group, MPIR_Errflag_t errflag) { int rank, mpi_errno, root, local_size; @@ -51,25 +52,25 @@ int MPIR_Reduce_scatter_block_inter_remote_reduce_local_scatter(const void *send /* reduce from right group to rank 0 */ root = (rank == 0) ? MPI_ROOT : MPI_PROC_NULL; mpi_errno = MPIR_Reduce_allcomm_auto(sendbuf, tmp_buf, total_count, datatype, op, - root, comm_ptr, errflag); + root, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); /* reduce to rank 0 of right group */ root = 0; mpi_errno = MPIR_Reduce_allcomm_auto(sendbuf, tmp_buf, total_count, datatype, op, - root, comm_ptr, errflag); + root, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } else { /* reduce to rank 0 of left group */ root = 0; mpi_errno = MPIR_Reduce_allcomm_auto(sendbuf, tmp_buf, total_count, datatype, op, - root, comm_ptr, errflag); + root, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); /* reduce from right group to rank 0 */ root = (rank == 0) ? MPI_ROOT : MPI_PROC_NULL; mpi_errno = MPIR_Reduce_allcomm_auto(sendbuf, tmp_buf, total_count, datatype, op, - root, comm_ptr, errflag); + root, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } @@ -80,7 +81,7 @@ int MPIR_Reduce_scatter_block_inter_remote_reduce_local_scatter(const void *send newcomm_ptr = comm_ptr->local_comm; mpi_errno = MPIR_Scatter(tmp_buf, recvcount, datatype, recvbuf, - recvcount, datatype, 0, newcomm_ptr, errflag); + recvcount, datatype, 0, newcomm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); fn_exit: diff --git a/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_noncommutative.c b/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_noncommutative.c index b4edac03945..6451bcb4cee 100644 --- a/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_noncommutative.c +++ b/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_noncommutative.c @@ -25,7 +25,8 @@ int MPIR_Reduce_scatter_block_intra_noncommutative(const void *sendbuf, MPI_Aint recvcount, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; int comm_size = comm_ptr->local_size; @@ -39,6 +40,8 @@ int MPIR_Reduce_scatter_block_intra_noncommutative(const void *sendbuf, void *result_ptr; MPIR_CHKLMEM_DECL(3); + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); + MPIR_Type_get_true_extent_impl(datatype, &true_lb, &true_extent); #ifdef HAVE_ERROR_CHECKING @@ -99,7 +102,7 @@ int MPIR_Reduce_scatter_block_intra_noncommutative(const void *sendbuf, size, datatype, peer, MPIR_REDUCE_SCATTER_BLOCK_TAG, incoming_data + recv_offset * true_extent, size, datatype, peer, MPIR_REDUCE_SCATTER_BLOCK_TAG, - comm_ptr, MPI_STATUS_IGNORE, errflag); + comm_ptr, coll_group, MPI_STATUS_IGNORE, errflag); MPIR_ERR_CHECK(mpi_errno); /* always perform the reduction at recv_offset, the data at send_offset * is now our peer's responsibility */ diff --git a/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_pairwise.c b/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_pairwise.c index f7e5e636906..132d16218bc 100644 --- a/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_pairwise.c +++ b/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_pairwise.c @@ -25,7 +25,8 @@ int MPIR_Reduce_scatter_block_intra_pairwise(const void *sendbuf, MPI_Aint recvcount, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int rank, comm_size, i; MPI_Aint extent, true_extent, true_lb; @@ -34,8 +35,7 @@ int MPIR_Reduce_scatter_block_intra_pairwise(const void *sendbuf, int src, dst; MPIR_CHKLMEM_DECL(5); - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); MPIR_Datatype_get_extent_macro(datatype, extent); MPIR_Type_get_true_extent_impl(datatype, &true_lb, &true_extent); @@ -83,14 +83,14 @@ int MPIR_Reduce_scatter_block_intra_pairwise(const void *sendbuf, recvcount, datatype, dst, MPIR_REDUCE_SCATTER_BLOCK_TAG, tmp_recvbuf, recvcount, datatype, src, - MPIR_REDUCE_SCATTER_BLOCK_TAG, comm_ptr, + MPIR_REDUCE_SCATTER_BLOCK_TAG, comm_ptr, coll_group, MPI_STATUS_IGNORE, errflag); else mpi_errno = MPIC_Sendrecv(((char *) recvbuf + disps[dst] * extent), recvcount, datatype, dst, MPIR_REDUCE_SCATTER_BLOCK_TAG, tmp_recvbuf, recvcount, datatype, src, - MPIR_REDUCE_SCATTER_BLOCK_TAG, comm_ptr, + MPIR_REDUCE_SCATTER_BLOCK_TAG, comm_ptr, coll_group, MPI_STATUS_IGNORE, errflag); MPIR_ERR_CHECK(mpi_errno); diff --git a/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_recursive_doubling.c b/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_recursive_doubling.c index d666ee884a1..82d2d20c692 100644 --- a/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_recursive_doubling.c +++ b/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_recursive_doubling.c @@ -25,7 +25,8 @@ int MPIR_Reduce_scatter_block_intra_recursive_doubling(const void *sendbuf, MPI_Aint recvcount, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int rank, comm_size, i; MPI_Aint extent, true_extent, true_lb; @@ -38,8 +39,7 @@ int MPIR_Reduce_scatter_block_intra_recursive_doubling(const void *sendbuf, int nprocs_completed, tmp_mask, tree_root, is_commutative; MPIR_CHKLMEM_DECL(5); - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); MPIR_Datatype_get_extent_macro(datatype, extent); MPIR_Type_get_true_extent_impl(datatype, &true_lb, &true_extent); @@ -145,7 +145,7 @@ int MPIR_Reduce_scatter_block_intra_recursive_doubling(const void *sendbuf, mpi_errno = MPIC_Sendrecv(tmp_results, 1, sendtype, dst, MPIR_REDUCE_SCATTER_BLOCK_TAG, tmp_recvbuf, 1, recvtype, dst, - MPIR_REDUCE_SCATTER_BLOCK_TAG, comm_ptr, + MPIR_REDUCE_SCATTER_BLOCK_TAG, comm_ptr, coll_group, MPI_STATUS_IGNORE, errflag); received = 1; MPIR_ERR_CHECK(mpi_errno); @@ -187,7 +187,8 @@ int MPIR_Reduce_scatter_block_intra_recursive_doubling(const void *sendbuf, && (dst >= tree_root + nprocs_completed)) { /* send the current result */ mpi_errno = MPIC_Send(tmp_recvbuf, 1, recvtype, - dst, MPIR_REDUCE_SCATTER_BLOCK_TAG, comm_ptr, errflag); + dst, MPIR_REDUCE_SCATTER_BLOCK_TAG, comm_ptr, coll_group, + errflag); MPIR_ERR_CHECK(mpi_errno); } /* recv only if this proc. doesn't have data and sender @@ -197,7 +198,7 @@ int MPIR_Reduce_scatter_block_intra_recursive_doubling(const void *sendbuf, (rank >= tree_root + nprocs_completed)) { mpi_errno = MPIC_Recv(tmp_recvbuf, 1, recvtype, dst, MPIR_REDUCE_SCATTER_BLOCK_TAG, - comm_ptr, MPI_STATUS_IGNORE); + comm_ptr, coll_group, MPI_STATUS_IGNORE); received = 1; MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_recursive_halving.c b/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_recursive_halving.c index 05fede37670..85b10866140 100644 --- a/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_recursive_halving.c +++ b/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_recursive_halving.c @@ -40,7 +40,8 @@ int MPIR_Reduce_scatter_block_intra_recursive_halving(const void *sendbuf, MPI_Aint recvcount, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int rank, comm_size, i; MPI_Aint extent, true_extent, true_lb; @@ -52,7 +53,7 @@ int MPIR_Reduce_scatter_block_intra_recursive_halving(const void *sendbuf, int pof2, old_i, newrank; MPIR_CHKLMEM_DECL(5); - MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size); + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); #ifdef HAVE_ERROR_CHECKING { @@ -114,7 +115,7 @@ int MPIR_Reduce_scatter_block_intra_recursive_halving(const void *sendbuf, if (rank % 2 == 0) { /* even */ mpi_errno = MPIC_Send(tmp_results, total_count, datatype, rank + 1, - MPIR_REDUCE_SCATTER_BLOCK_TAG, comm_ptr, errflag); + MPIR_REDUCE_SCATTER_BLOCK_TAG, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); /* temporarily set the rank to -1 so that this @@ -124,7 +125,8 @@ int MPIR_Reduce_scatter_block_intra_recursive_halving(const void *sendbuf, } else { /* odd */ mpi_errno = MPIC_Recv(tmp_recvbuf, total_count, datatype, rank - 1, - MPIR_REDUCE_SCATTER_BLOCK_TAG, comm_ptr, MPI_STATUS_IGNORE); + MPIR_REDUCE_SCATTER_BLOCK_TAG, comm_ptr, coll_group, + MPI_STATUS_IGNORE); MPIR_ERR_CHECK(mpi_errno); /* do the reduction on received data. since the @@ -202,18 +204,20 @@ int MPIR_Reduce_scatter_block_intra_recursive_halving(const void *sendbuf, (char *) tmp_recvbuf + newdisps[recv_idx] * extent, recv_cnt, datatype, dst, - MPIR_REDUCE_SCATTER_BLOCK_TAG, comm_ptr, + MPIR_REDUCE_SCATTER_BLOCK_TAG, comm_ptr, coll_group, MPI_STATUS_IGNORE, errflag); else if ((send_cnt == 0) && (recv_cnt != 0)) mpi_errno = MPIC_Recv((char *) tmp_recvbuf + newdisps[recv_idx] * extent, recv_cnt, datatype, dst, - MPIR_REDUCE_SCATTER_BLOCK_TAG, comm_ptr, MPI_STATUS_IGNORE); + MPIR_REDUCE_SCATTER_BLOCK_TAG, comm_ptr, coll_group, + MPI_STATUS_IGNORE); else if ((recv_cnt == 0) && (send_cnt != 0)) mpi_errno = MPIC_Send((char *) tmp_results + newdisps[send_idx] * extent, send_cnt, datatype, - dst, MPIR_REDUCE_SCATTER_BLOCK_TAG, comm_ptr, errflag); + dst, MPIR_REDUCE_SCATTER_BLOCK_TAG, comm_ptr, coll_group, + errflag); MPIR_ERR_CHECK(mpi_errno); @@ -248,11 +252,12 @@ int MPIR_Reduce_scatter_block_intra_recursive_halving(const void *sendbuf, mpi_errno = MPIC_Send((char *) tmp_results + disps[rank - 1] * extent, recvcount, datatype, rank - 1, - MPIR_REDUCE_SCATTER_BLOCK_TAG, comm_ptr, errflag); + MPIR_REDUCE_SCATTER_BLOCK_TAG, comm_ptr, coll_group, errflag); } else { /* even */ mpi_errno = MPIC_Recv(recvbuf, recvcount, datatype, rank + 1, - MPIR_REDUCE_SCATTER_BLOCK_TAG, comm_ptr, MPI_STATUS_IGNORE); + MPIR_REDUCE_SCATTER_BLOCK_TAG, comm_ptr, coll_group, + MPI_STATUS_IGNORE); } MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/scan/scan_allcomm_nb.c b/src/mpi/coll/scan/scan_allcomm_nb.c index 0528d599fa7..e957aa40664 100644 --- a/src/mpi/coll/scan/scan_allcomm_nb.c +++ b/src/mpi/coll/scan/scan_allcomm_nb.c @@ -6,13 +6,13 @@ #include "mpiimpl.h" int MPIR_Scan_allcomm_nb(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPI_Op op, MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_Request *req_ptr = NULL; /* just call the nonblocking version and wait on it */ - mpi_errno = MPIR_Iscan(sendbuf, recvbuf, count, datatype, op, comm_ptr, &req_ptr); + mpi_errno = MPIR_Iscan(sendbuf, recvbuf, count, datatype, op, comm_ptr, coll_group, &req_ptr); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIC_Wait(req_ptr); diff --git a/src/mpi/coll/scan/scan_intra_recursive_doubling.c b/src/mpi/coll/scan/scan_intra_recursive_doubling.c index 6afdb2dfd34..31109b14d89 100644 --- a/src/mpi/coll/scan/scan_intra_recursive_doubling.c +++ b/src/mpi/coll/scan/scan_intra_recursive_doubling.c @@ -44,7 +44,8 @@ int MPIR_Scan_intra_recursive_doubling(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPI_Op op, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { MPI_Status status; int rank, comm_size; @@ -54,7 +55,7 @@ int MPIR_Scan_intra_recursive_doubling(const void *sendbuf, void *partial_scan, *tmp_buf; MPIR_CHKLMEM_DECL(2); - MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size); + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); is_commutative = MPIR_Op_is_commutative(op); @@ -96,7 +97,7 @@ int MPIR_Scan_intra_recursive_doubling(const void *sendbuf, mpi_errno = MPIC_Sendrecv(partial_scan, count, datatype, dst, MPIR_SCAN_TAG, tmp_buf, count, datatype, dst, - MPIR_SCAN_TAG, comm_ptr, &status, errflag); + MPIR_SCAN_TAG, comm_ptr, coll_group, &status, errflag); MPIR_ERR_CHECK(mpi_errno); if (rank > dst) { diff --git a/src/mpi/coll/scan/scan_intra_smp.c b/src/mpi/coll/scan/scan_intra_smp.c index 9ea89a81786..1b798979181 100644 --- a/src/mpi/coll/scan/scan_intra_smp.c +++ b/src/mpi/coll/scan/scan_intra_smp.c @@ -7,18 +7,21 @@ int MPIR_Scan_intra_smp(const void *sendbuf, void *recvbuf, MPI_Aint count, - MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, + MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_CHKLMEM_DECL(3); - int rank = comm_ptr->rank; MPI_Status status; void *tempbuf = NULL, *localfulldata = NULL, *prefulldata = NULL; MPI_Aint true_lb, true_extent, extent; int noneed = 1; /* noneed=1 means no need to bcast tempbuf and * reduce tempbuf & recvbuf */ + MPIR_Assert(MPIR_Comm_is_parent_comm(comm_ptr, coll_group)); + int local_rank = comm_ptr->subgroups[MPIR_SUBGROUP_NODE].rank; + int local_size = comm_ptr->subgroups[MPIR_SUBGROUP_NODE].size; + MPIR_Type_get_true_extent_impl(datatype, &true_lb, &true_extent); MPIR_Datatype_get_extent_macro(datatype, extent); @@ -28,12 +31,12 @@ int MPIR_Scan_intra_smp(const void *sendbuf, void *recvbuf, MPI_Aint count, tempbuf = (void *) ((char *) tempbuf - true_lb); /* Create prefulldata and localfulldata on local roots of all nodes */ - if (comm_ptr->node_roots_comm != NULL) { + if (local_rank == 0) { MPIR_CHKLMEM_MALLOC(prefulldata, void *, count * (MPL_MAX(extent, true_extent)), mpi_errno, "prefulldata for scan", MPL_MEM_BUFFER); prefulldata = (void *) ((char *) prefulldata - true_lb); - if (comm_ptr->node_comm != NULL) { + if (local_size > 1) { MPIR_CHKLMEM_MALLOC(localfulldata, void *, count * (MPL_MAX(extent, true_extent)), mpi_errno, "localfulldata for scan", MPL_MEM_BUFFER); localfulldata = (void *) ((char *) localfulldata - true_lb); @@ -42,8 +45,9 @@ int MPIR_Scan_intra_smp(const void *sendbuf, void *recvbuf, MPI_Aint count, /* perform intranode scan to get temporary result in recvbuf. if there is only * one process, just copy the raw data. */ - if (comm_ptr->node_comm != NULL) { - mpi_errno = MPIR_Scan(sendbuf, recvbuf, count, datatype, op, comm_ptr->node_comm, errflag); + if (local_size > 1) { + mpi_errno = MPIR_Scan(sendbuf, recvbuf, count, datatype, op, + comm_ptr, MPIR_SUBGROUP_NODE, errflag); MPIR_ERR_CHECK(mpi_errno); } else if (sendbuf != MPI_IN_PLACE) { mpi_errno = MPIR_Localcopy(sendbuf, count, datatype, recvbuf, count, datatype); @@ -54,18 +58,15 @@ int MPIR_Scan_intra_smp(const void *sendbuf, void *recvbuf, MPI_Aint count, * contains the reduce result of the whole node. Name it as * localfulldata. For example, localfulldata from node 1 contains * reduced data of rank 1,2,3. */ - if (comm_ptr->node_roots_comm != NULL && comm_ptr->node_comm != NULL) { + if (local_rank == 0 && local_size > 1) { mpi_errno = MPIC_Recv(localfulldata, count, datatype, - comm_ptr->node_comm->local_size - 1, MPIR_SCAN_TAG, - comm_ptr->node_comm, &status); + local_size - 1, MPIR_SCAN_TAG, comm_ptr, MPIR_SUBGROUP_NODE, &status); MPIR_ERR_CHECK(mpi_errno); - } else if (comm_ptr->node_roots_comm == NULL && - comm_ptr->node_comm != NULL && - MPIR_Get_intranode_rank(comm_ptr, rank) == comm_ptr->node_comm->local_size - 1) { + } else if (local_rank > 0 && local_size > 1 && local_rank == local_size - 1) { mpi_errno = MPIC_Send(recvbuf, count, datatype, - 0, MPIR_SCAN_TAG, comm_ptr->node_comm, errflag); + 0, MPIR_SCAN_TAG, comm_ptr, MPIR_SUBGROUP_NODE, errflag); MPIR_ERR_CHECK(mpi_errno); - } else if (comm_ptr->node_roots_comm != NULL) { + } else if (local_rank == 0) { localfulldata = recvbuf; } @@ -73,21 +74,23 @@ int MPIR_Scan_intra_smp(const void *sendbuf, void *recvbuf, MPI_Aint count, * prefulldata on rank 4 contains reduce result of ranks * 1,2,3,4,5,6. it will be sent to rank 7 which is the * main process of node 3. */ - if (comm_ptr->node_roots_comm != NULL) { + if (local_rank == 0) { mpi_errno = MPIR_Scan(localfulldata, prefulldata, count, datatype, - op, comm_ptr->node_roots_comm, errflag); + op, comm_ptr, MPIR_SUBGROUP_NODE_CROSS, errflag); MPIR_ERR_CHECK(mpi_errno); - if (MPIR_Get_internode_rank(comm_ptr, rank) != comm_ptr->node_roots_comm->local_size - 1) { + int inter_rank = comm_ptr->subgroups[MPIR_SUBGROUP_NODE_CROSS].rank; + int inter_size = comm_ptr->subgroups[MPIR_SUBGROUP_NODE_CROSS].size; + if (inter_rank != inter_size - 1) { mpi_errno = MPIC_Send(prefulldata, count, datatype, - MPIR_Get_internode_rank(comm_ptr, rank) + 1, - MPIR_SCAN_TAG, comm_ptr->node_roots_comm, errflag); + inter_rank + 1, + MPIR_SCAN_TAG, comm_ptr, MPIR_SUBGROUP_NODE_CROSS, errflag); MPIR_ERR_CHECK(mpi_errno); } - if (MPIR_Get_internode_rank(comm_ptr, rank) != 0) { + if (inter_rank != 0) { mpi_errno = MPIC_Recv(tempbuf, count, datatype, - MPIR_Get_internode_rank(comm_ptr, rank) - 1, - MPIR_SCAN_TAG, comm_ptr->node_roots_comm, &status); + inter_rank - 1, + MPIR_SCAN_TAG, comm_ptr, MPIR_SUBGROUP_NODE_CROSS, &status); noneed = 0; MPIR_ERR_CHECK(mpi_errno); } @@ -99,14 +102,15 @@ int MPIR_Scan_intra_smp(const void *sendbuf, void *recvbuf, MPI_Aint count, * then we should broadcast this result in the local node, and * reduce it with recvbuf to get final result if necessary. */ - if (comm_ptr->node_comm != NULL) { - mpi_errno = MPIR_Bcast(&noneed, 1, MPI_INT, 0, comm_ptr->node_comm, errflag); + if (local_size > 1) { + mpi_errno = MPIR_Bcast(&noneed, 1, MPI_INT, 0, comm_ptr, MPIR_SUBGROUP_NODE, errflag); MPIR_ERR_CHECK(mpi_errno); } if (noneed == 0) { - if (comm_ptr->node_comm != NULL) { - mpi_errno = MPIR_Bcast(tempbuf, count, datatype, 0, comm_ptr->node_comm, errflag); + if (local_size > 1) { + mpi_errno = MPIR_Bcast(tempbuf, count, datatype, 0, + comm_ptr, MPIR_SUBGROUP_NODE, errflag); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/scatter/scatter_allcomm_nb.c b/src/mpi/coll/scatter/scatter_allcomm_nb.c index e344a0ed25c..16a53460c1a 100644 --- a/src/mpi/coll/scatter/scatter_allcomm_nb.c +++ b/src/mpi/coll/scatter/scatter_allcomm_nb.c @@ -7,7 +7,7 @@ int MPIR_Scatter_allcomm_nb(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_Request *req_ptr = NULL; @@ -15,7 +15,7 @@ int MPIR_Scatter_allcomm_nb(const void *sendbuf, MPI_Aint sendcount, MPI_Datatyp /* just call the nonblocking version and wait on it */ mpi_errno = MPIR_Iscatter(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm_ptr, - &req_ptr); + coll_group, &req_ptr); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIC_Wait(req_ptr); diff --git a/src/mpi/coll/scatter/scatter_inter_linear.c b/src/mpi/coll/scatter/scatter_inter_linear.c index 96e616b104e..0fcd47a06a5 100644 --- a/src/mpi/coll/scatter/scatter_inter_linear.c +++ b/src/mpi/coll/scatter/scatter_inter_linear.c @@ -14,7 +14,7 @@ int MPIR_Scatter_inter_linear(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int remote_size, mpi_errno = MPI_SUCCESS; int i; @@ -33,12 +33,12 @@ int MPIR_Scatter_inter_linear(const void *sendbuf, MPI_Aint sendcount, MPI_Datat for (i = 0; i < remote_size; i++) { mpi_errno = MPIC_Send(((char *) sendbuf + sendcount * i * extent), sendcount, sendtype, i, - MPIR_SCATTER_TAG, comm_ptr, errflag); + MPIR_SCATTER_TAG, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } } else { - mpi_errno = - MPIC_Recv(recvbuf, recvcount, recvtype, root, MPIR_SCATTER_TAG, comm_ptr, &status); + mpi_errno = MPIC_Recv(recvbuf, recvcount, recvtype, root, MPIR_SCATTER_TAG, + comm_ptr, coll_group, &status); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/scatter/scatter_inter_remote_send_local_scatter.c b/src/mpi/coll/scatter/scatter_inter_remote_send_local_scatter.c index 81f5ce30245..def8d1ade0f 100644 --- a/src/mpi/coll/scatter/scatter_inter_remote_send_local_scatter.c +++ b/src/mpi/coll/scatter/scatter_inter_remote_send_local_scatter.c @@ -16,7 +16,7 @@ int MPIR_Scatter_inter_remote_send_local_scatter(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm_ptr, + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int rank, local_size, remote_size, mpi_errno = MPI_SUCCESS; @@ -36,7 +36,7 @@ int MPIR_Scatter_inter_remote_send_local_scatter(const void *sendbuf, MPI_Aint s /* root sends all data to rank 0 on remote group and returns */ mpi_errno = MPIC_Send(sendbuf, sendcount * remote_size, sendtype, 0, MPIR_SCATTER_TAG, comm_ptr, - errflag); + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); goto fn_exit; } else { @@ -54,7 +54,7 @@ int MPIR_Scatter_inter_remote_send_local_scatter(const void *sendbuf, MPI_Aint s "tmp_buf", MPL_MEM_BUFFER); mpi_errno = MPIC_Recv(tmp_buf, recvcount * local_size * recvtype_sz, MPI_BYTE, - root, MPIR_SCATTER_TAG, comm_ptr, &status); + root, MPIR_SCATTER_TAG, comm_ptr, coll_group, &status); MPIR_ERR_CHECK(mpi_errno); } else { /* silience -Wmaybe-uninitialized due to MPIR_Scatter by non-zero ranks */ @@ -69,7 +69,7 @@ int MPIR_Scatter_inter_remote_send_local_scatter(const void *sendbuf, MPI_Aint s /* now do the usual scatter on this intracommunicator */ mpi_errno = MPIR_Scatter(tmp_buf, recvcount * recvtype_sz, MPI_BYTE, - recvbuf, recvcount, recvtype, 0, newcomm_ptr, errflag); + recvbuf, recvcount, recvtype, 0, newcomm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/scatter/scatter_intra_binomial.c b/src/mpi/coll/scatter/scatter_intra_binomial.c index dd211126582..1db7390e44a 100644 --- a/src/mpi/coll/scatter/scatter_intra_binomial.c +++ b/src/mpi/coll/scatter/scatter_intra_binomial.c @@ -28,7 +28,7 @@ /* not declared static because a machine-specific function may call this one in some cases */ int MPIR_Scatter_intra_binomial(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { MPI_Status status; MPI_Aint extent = 0; @@ -41,7 +41,7 @@ int MPIR_Scatter_intra_binomial(const void *sendbuf, MPI_Aint sendcount, MPI_Dat int mpi_errno = MPI_SUCCESS; MPIR_CHKLMEM_DECL(4); - MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size); + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); if (rank == root) MPIR_Datatype_get_extent_macro(sendtype, extent); @@ -116,11 +116,11 @@ int MPIR_Scatter_intra_binomial(const void *sendbuf, MPI_Aint sendcount, MPI_Dat * receive data into a temporary buffer. */ if (relative_rank % 2) { mpi_errno = MPIC_Recv(recvbuf, recvcount, recvtype, - src, MPIR_SCATTER_TAG, comm_ptr, &status); + src, MPIR_SCATTER_TAG, comm_ptr, coll_group, &status); MPIR_ERR_CHECK(mpi_errno); } else { mpi_errno = MPIC_Recv(tmp_buf, tmp_buf_size, MPI_BYTE, src, - MPIR_SCATTER_TAG, comm_ptr, &status); + MPIR_SCATTER_TAG, comm_ptr, coll_group, &status); MPIR_ERR_CHECK(mpi_errno); if (mpi_errno) { curr_cnt = 0; @@ -152,14 +152,16 @@ int MPIR_Scatter_intra_binomial(const void *sendbuf, MPI_Aint sendcount, MPI_Dat mpi_errno = MPIC_Send(((char *) sendbuf + extent * sendcount * mask), send_subtree_cnt, - sendtype, dst, MPIR_SCATTER_TAG, comm_ptr, errflag); + sendtype, dst, MPIR_SCATTER_TAG, comm_ptr, coll_group, + errflag); } else { /* non-zero root and others */ send_subtree_cnt = curr_cnt - nbytes * mask; /* mask is also the size of this process's subtree */ mpi_errno = MPIC_Send(((char *) tmp_buf + nbytes * mask), send_subtree_cnt, - MPI_BYTE, dst, MPIR_SCATTER_TAG, comm_ptr, errflag); + MPI_BYTE, dst, MPIR_SCATTER_TAG, comm_ptr, coll_group, + errflag); } MPIR_ERR_CHECK(mpi_errno); curr_cnt -= send_subtree_cnt; diff --git a/src/mpi/coll/scatterv/scatterv_allcomm_linear.c b/src/mpi/coll/scatterv/scatterv_allcomm_linear.c index 8612a34b09f..223027e6128 100644 --- a/src/mpi/coll/scatterv/scatterv_allcomm_linear.c +++ b/src/mpi/coll/scatterv/scatterv_allcomm_linear.c @@ -20,7 +20,7 @@ int MPIR_Scatterv_allcomm_linear(const void *sendbuf, const MPI_Aint * sendcounts, const MPI_Aint * displs, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int rank, comm_size, mpi_errno = MPI_SUCCESS; MPI_Aint extent; @@ -29,14 +29,17 @@ int MPIR_Scatterv_allcomm_linear(const void *sendbuf, const MPI_Aint * sendcount MPI_Status *starray; MPIR_CHKLMEM_DECL(2); - MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size); + if (comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM) { + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); + } else { + MPIR_Assert(coll_group == MPIR_SUBGROUP_NONE); + rank = comm_ptr->rank; + comm_size = comm_ptr->remote_size; + } /* If I'm the root, then scatter */ if (((comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM) && (root == rank)) || ((comm_ptr->comm_kind == MPIR_COMM_KIND__INTERCOMM) && (root == MPI_ROOT))) { - if (comm_ptr->comm_kind == MPIR_COMM_KIND__INTERCOMM) - comm_size = comm_ptr->remote_size; - MPIR_Datatype_get_extent_macro(sendtype, extent); MPIR_CHKLMEM_MALLOC(reqarray, MPIR_Request **, comm_size * sizeof(MPIR_Request *), @@ -57,7 +60,8 @@ int MPIR_Scatterv_allcomm_linear(const void *sendbuf, const MPI_Aint * sendcount } else { mpi_errno = MPIC_Isend(((char *) sendbuf + displs[i] * extent), sendcounts[i], sendtype, i, - MPIR_SCATTERV_TAG, comm_ptr, &reqarray[reqs++], errflag); + MPIR_SCATTERV_TAG, comm_ptr, coll_group, + &reqarray[reqs++], errflag); MPIR_ERR_CHECK(mpi_errno); } } @@ -70,7 +74,7 @@ int MPIR_Scatterv_allcomm_linear(const void *sendbuf, const MPI_Aint * sendcount else if (root != MPI_PROC_NULL) { /* non-root nodes, and in the intercomm. case, non-root nodes on remote side */ if (recvcount) { mpi_errno = MPIC_Recv(recvbuf, recvcount, recvtype, root, - MPIR_SCATTERV_TAG, comm_ptr, MPI_STATUS_IGNORE); + MPIR_SCATTERV_TAG, comm_ptr, coll_group, MPI_STATUS_IGNORE); MPIR_ERR_CHECK(mpi_errno); } } diff --git a/src/mpi/coll/scatterv/scatterv_allcomm_nb.c b/src/mpi/coll/scatterv/scatterv_allcomm_nb.c index 953bb60b819..53bc8904361 100644 --- a/src/mpi/coll/scatterv/scatterv_allcomm_nb.c +++ b/src/mpi/coll/scatterv/scatterv_allcomm_nb.c @@ -8,7 +8,7 @@ int MPIR_Scatterv_allcomm_nb(const void *sendbuf, const MPI_Aint * sendcounts, const MPI_Aint * displs, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_Request *req_ptr = NULL; @@ -16,7 +16,7 @@ int MPIR_Scatterv_allcomm_nb(const void *sendbuf, const MPI_Aint * sendcounts, /* just call the nonblocking version and wait on it */ mpi_errno = MPIR_Iscatterv(sendbuf, sendcounts, displs, sendtype, recvbuf, recvcount, recvtype, root, - comm_ptr, &req_ptr); + comm_ptr, coll_group, &req_ptr); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIC_Wait(req_ptr); diff --git a/src/mpi/coll/src/coll_impl.c b/src/mpi/coll/src/coll_impl.c index b4b70a10b7a..46cb38c4e5a 100644 --- a/src/mpi/coll/src/coll_impl.c +++ b/src/mpi/coll/src/coll_impl.c @@ -224,8 +224,6 @@ int MPIR_Coll_comm_init(MPIR_Comm * comm) { int mpi_errno = MPI_SUCCESS; - comm->coll.pof2 = MPL_pof2(comm->local_size); - /* initialize any stub algo related data structures */ mpi_errno = MPII_Stubalgo_comm_init(comm); MPIR_ERR_CHECK(mpi_errno); diff --git a/src/mpi/coll/src/csel.c b/src/mpi/coll/src/csel.c index 1253f847f66..016c8e18a35 100644 --- a/src/mpi/coll/src/csel.c +++ b/src/mpi/coll/src/csel.c @@ -630,8 +630,8 @@ static csel_node_s *prune_tree(csel_node_s * root, MPIR_Comm * comm_ptr) break; case CSEL_NODE_TYPE__OPERATOR__COMM_SIZE_NODE_COMM_SIZE: - if (comm_ptr->node_comm != NULL && - MPIR_Comm_size(comm_ptr) == MPIR_Comm_size(comm_ptr->node_comm)) + /* comm_size equal to node_comm_size just mean the size inter-node is 1 */ + if (comm_ptr->num_external == 1) node = node->success; else node = node->failure; @@ -643,14 +643,14 @@ static csel_node_s *prune_tree(csel_node_s * root, MPIR_Comm * comm_ptr) else node = node->success; break; - +/* case CSEL_NODE_TYPE__OPERATOR__COMM_HIERARCHY: if (comm_ptr->hierarchy_kind == node->u.comm_hierarchy.val) node = node->success; else node = node->failure; break; - +*/ case CSEL_NODE_TYPE__OPERATOR__IS_NODE_CONSECUTIVE: if (MPII_Comm_is_node_consecutive(comm_ptr) == node->u.is_node_consecutive.val) node = node->success; @@ -659,14 +659,14 @@ static csel_node_s *prune_tree(csel_node_s * root, MPIR_Comm * comm_ptr) break; case CSEL_NODE_TYPE__OPERATOR__COMM_AVG_PPN_LE: - if (comm_ptr->local_size <= node->u.comm_avg_ppn_le.val * comm_ptr->node_count) + if (comm_ptr->local_size <= node->u.comm_avg_ppn_le.val * comm_ptr->num_external) node = node->success; else node = node->failure; break; case CSEL_NODE_TYPE__OPERATOR__COMM_AVG_PPN_LT: - if (comm_ptr->local_size < node->u.comm_avg_ppn_le.val * comm_ptr->node_count) + if (comm_ptr->local_size < node->u.comm_avg_ppn_le.val * comm_ptr->num_external) node = node->success; else node = node->failure; @@ -924,7 +924,10 @@ static inline MPI_Aint get_count(MPIR_Csel_coll_sig_s coll_info) { MPI_Aint count = 0; int i = 0; - int comm_size = coll_info.comm_ptr->local_size; + + int comm_size, rank; + MPIR_COLL_RANK_SIZE(coll_info.comm_ptr, coll_info.coll_group, rank, comm_size); + (void) rank; /* silence unused variable warnings */ switch (coll_info.coll_type) { case MPIR_CSEL_COLL_TYPE__BCAST: @@ -981,7 +984,10 @@ static inline MPI_Aint get_count(MPIR_Csel_coll_sig_s coll_info) static inline MPI_Aint get_total_msgsize(MPIR_Csel_coll_sig_s coll_info) { MPI_Aint total_bytes = 0, i = 0, count = 0, typesize = 0; - int comm_size = coll_info.comm_ptr->local_size; + + int comm_size, rank; + MPIR_COLL_RANK_SIZE(coll_info.comm_ptr, coll_info.coll_group, rank, comm_size); + (void) rank; /* silence unused variable warnings */ switch (coll_info.coll_type) { case MPIR_CSEL_COLL_TYPE__ALLREDUCE: @@ -1182,6 +1188,7 @@ void *MPIR_Csel_search(void *csel_, MPIR_Csel_coll_sig_s coll_info) csel_s *csel = (csel_s *) csel_; csel_node_s *node = NULL; MPIR_Comm *comm_ptr = coll_info.comm_ptr; + int coll_group = coll_info.coll_group; MPIR_Assert(csel_); @@ -1229,8 +1236,7 @@ void *MPIR_Csel_search(void *csel_, MPIR_Csel_coll_sig_s coll_info) break; case CSEL_NODE_TYPE__OPERATOR__COMM_SIZE_NODE_COMM_SIZE: - if (comm_ptr->node_comm != NULL && - MPIR_Comm_size(comm_ptr) == MPIR_Comm_size(comm_ptr->node_comm)) + if (comm_ptr->num_external == 1) node = node->success; else node = node->failure; @@ -1286,7 +1292,8 @@ void *MPIR_Csel_search(void *csel_, MPIR_Csel_coll_sig_s coll_info) break; case CSEL_NODE_TYPE__OPERATOR__COUNT_LT_POW2: - if (get_count(coll_info) < coll_info.comm_ptr->coll.pof2) + if (get_count(coll_info) < + MPL_pof2(MPIR_Coll_size(coll_info.comm_ptr, coll_info.coll_group))) node = node->success; else node = node->failure; @@ -1329,21 +1336,22 @@ void *MPIR_Csel_search(void *csel_, MPIR_Csel_coll_sig_s coll_info) break; case CSEL_NODE_TYPE__OPERATOR__COMM_AVG_PPN_LE: - if (comm_ptr->local_size <= node->u.comm_avg_ppn_le.val * comm_ptr->node_count) + if (comm_ptr->local_size <= node->u.comm_avg_ppn_le.val * comm_ptr->num_external) node = node->success; else node = node->failure; break; case CSEL_NODE_TYPE__OPERATOR__COMM_AVG_PPN_LT: - if (comm_ptr->local_size < node->u.comm_avg_ppn_le.val * comm_ptr->node_count) + if (comm_ptr->local_size < node->u.comm_avg_ppn_le.val * comm_ptr->num_external) node = node->success; else node = node->failure; break; case CSEL_NODE_TYPE__OPERATOR__COMM_HIERARCHY: - if (coll_info.comm_ptr->hierarchy_kind == node->u.comm_hierarchy.val) + if (node->u.comm_hierarchy.val == MPIR_COMM_HIERARCHY_KIND__PARENT && + MPIR_Comm_is_parent_comm(comm_ptr, coll_group)) node = node->success; else node = node->failure; diff --git a/src/mpi/coll/transports/gentran/gentran_types.h b/src/mpi/coll/transports/gentran/gentran_types.h index 52063db1b5f..24a8891b7c3 100644 --- a/src/mpi/coll/transports/gentran/gentran_types.h +++ b/src/mpi/coll/transports/gentran/gentran_types.h @@ -49,6 +49,7 @@ typedef struct MPII_Genutil_vtx_t { int dest; int tag; MPIR_Comm *comm; + int coll_group; MPIR_Request *req; } isend; struct { @@ -58,6 +59,7 @@ typedef struct MPII_Genutil_vtx_t { int src; int tag; MPIR_Comm *comm; + int coll_group; MPIR_Request *req; } irecv; struct { @@ -67,6 +69,7 @@ typedef struct MPII_Genutil_vtx_t { int src; int tag; MPIR_Comm *comm; + int coll_group; MPIR_Request *req; MPI_Status *status; } irecv_status; @@ -78,6 +81,7 @@ typedef struct MPII_Genutil_vtx_t { int num_dests; int tag; MPIR_Comm *comm; + int coll_group; MPIR_Request **req; int last_complete; } imcast; diff --git a/src/mpi/coll/transports/gentran/gentran_utils.c b/src/mpi/coll/transports/gentran/gentran_utils.c index 94c75c56ecf..18f27da9b56 100644 --- a/src/mpi/coll/transports/gentran/gentran_utils.c +++ b/src/mpi/coll/transports/gentran/gentran_utils.c @@ -43,8 +43,8 @@ static int vtx_issue(int vtxid, MPII_Genutil_vtx_t * vtxp, MPII_Genutil_sched_t vtxp->u.isend.count, vtxp->u.isend.dt, vtxp->u.isend.dest, - vtxp->u.isend.tag, vtxp->u.isend.comm, &vtxp->u.isend.req, - r->u.nbc.errflag); + vtxp->u.isend.tag, vtxp->u.isend.comm, vtxp->u.isend.coll_group, + &vtxp->u.isend.req, r->u.nbc.errflag); if (MPIR_Request_is_complete(vtxp->u.isend.req)) { MPIR_Request_free(vtxp->u.isend.req); @@ -75,7 +75,7 @@ static int vtx_issue(int vtxid, MPII_Genutil_vtx_t * vtxp, MPII_Genutil_sched_t vtxp->u.irecv.count, vtxp->u.irecv.dt, vtxp->u.irecv.src, vtxp->u.irecv.tag, vtxp->u.irecv.comm, - &vtxp->u.irecv.req); + vtxp->u.irecv.coll_group, &vtxp->u.irecv.req); if (MPIR_Request_is_complete(vtxp->u.irecv.req)) { MPIR_Request_free(vtxp->u.irecv.req); @@ -104,7 +104,8 @@ static int vtx_issue(int vtxid, MPII_Genutil_vtx_t * vtxp, MPII_Genutil_sched_t vtxp->u.irecv_status.count, vtxp->u.irecv_status.dt, vtxp->u.irecv_status.src, vtxp->u.irecv_status.tag, - vtxp->u.irecv_status.comm, &vtxp->u.irecv_status.req); + vtxp->u.irecv_status.comm, vtxp->u.irecv_status.coll_group, + &vtxp->u.irecv_status.req); if (MPIR_Request_is_complete(vtxp->u.irecv_status.req)) { if (vtxp->u.irecv_status.status != MPI_STATUS_IGNORE) { @@ -143,7 +144,8 @@ static int vtx_issue(int vtxid, MPII_Genutil_vtx_t * vtxp, MPII_Genutil_sched_t vtxp->u.imcast.count, vtxp->u.imcast.dt, dests[i], - vtxp->u.imcast.tag, vtxp->u.imcast.comm, &vtxp->u.imcast.req[i], + vtxp->u.imcast.tag, vtxp->u.imcast.comm, + vtxp->u.imcast.coll_group, &vtxp->u.imcast.req[i], r->u.nbc.errflag); MPL_DBG_MSG_FMT(MPIR_DBG_COLL, VERBOSE, diff --git a/src/mpi/coll/transports/gentran/tsp_gentran.c b/src/mpi/coll/transports/gentran/tsp_gentran.c index 76c20855847..6aa253d5239 100644 --- a/src/mpi/coll/transports/gentran/tsp_gentran.c +++ b/src/mpi/coll/transports/gentran/tsp_gentran.c @@ -187,8 +187,8 @@ int MPIR_TSP_sched_isend(const void *buf, MPI_Datatype dt, int dest, int tag, - MPIR_Comm * comm_ptr, MPIR_TSP_sched_t s, int n_in_vtcs, int *in_vtcs, - int *vtx_id) + MPIR_Comm * comm_ptr, int coll_group, MPIR_TSP_sched_t s, int n_in_vtcs, + int *in_vtcs, int *vtx_id) { MPII_Genutil_sched_t *sched = s; vtx_t *vtxp; @@ -205,6 +205,7 @@ int MPIR_TSP_sched_isend(const void *buf, vtxp->u.isend.dest = dest; vtxp->u.isend.tag = tag; vtxp->u.isend.comm = comm_ptr; + vtxp->u.isend.coll_group = coll_group; /* the user may free the comm & type after initiating but before the * underlying send is actually posted, so we must add a reference here and @@ -224,8 +225,8 @@ int MPIR_TSP_sched_irecv(void *buf, MPI_Datatype dt, int source, int tag, - MPIR_Comm * comm_ptr, MPIR_TSP_sched_t s, int n_in_vtcs, int *in_vtcs, - int *vtx_id) + MPIR_Comm * comm_ptr, int coll_group, MPIR_TSP_sched_t s, int n_in_vtcs, + int *in_vtcs, int *vtx_id) { MPII_Genutil_sched_t *sched = s; vtx_t *vtxp; @@ -243,6 +244,7 @@ int MPIR_TSP_sched_irecv(void *buf, vtxp->u.irecv.src = source; vtxp->u.irecv.tag = tag; vtxp->u.irecv.comm = comm_ptr; + vtxp->u.irecv.coll_group = coll_group; MPIR_Comm_add_ref(comm_ptr); MPIR_Datatype_add_ref_if_not_builtin(dt); @@ -258,7 +260,7 @@ int MPIR_TSP_sched_irecv_status(void *buf, MPI_Datatype dt, int source, int tag, - MPIR_Comm * comm_ptr, MPI_Status * status, + MPIR_Comm * comm_ptr, int coll_group, MPI_Status * status, MPIR_TSP_sched_t sched, int n_in_vtcs, int *in_vtcs, int *vtx_id) { vtx_t *vtxp; @@ -277,6 +279,7 @@ int MPIR_TSP_sched_irecv_status(void *buf, vtxp->u.irecv_status.src = source; vtxp->u.irecv_status.tag = tag; vtxp->u.irecv_status.comm = comm_ptr; + vtxp->u.irecv_status.coll_group = coll_group; vtxp->u.irecv_status.status = status; MPIR_Comm_add_ref(comm_ptr); @@ -296,8 +299,8 @@ int MPIR_TSP_sched_imcast(const void *buf, int *dests, int num_dests, int tag, - MPIR_Comm * comm_ptr, MPIR_TSP_sched_t s, int n_in_vtcs, int *in_vtcs, - int *vtx_id) + MPIR_Comm * comm_ptr, int coll_group, MPIR_TSP_sched_t s, int n_in_vtcs, + int *in_vtcs, int *vtx_id) { MPII_Genutil_sched_t *sched = s; vtx_t *vtxp; @@ -317,6 +320,7 @@ int MPIR_TSP_sched_imcast(const void *buf, memcpy(ut_int_array(&vtxp->u.imcast.dests), dests, num_dests * sizeof(int)); vtxp->u.imcast.tag = tag; vtxp->u.imcast.comm = comm_ptr; + vtxp->u.imcast.coll_group = coll_group; vtxp->u.imcast.req = (struct MPIR_Request **) MPL_malloc(sizeof(struct MPIR_Request *) * num_dests, MPL_MEM_COLL); diff --git a/src/mpi/coll/transports/tsp_impl.h b/src/mpi/coll/transports/tsp_impl.h index fb83f6d6a74..d18828dfc6c 100644 --- a/src/mpi/coll/transports/tsp_impl.h +++ b/src/mpi/coll/transports/tsp_impl.h @@ -49,8 +49,8 @@ int MPIR_TSP_sched_isend(const void *buf, MPI_Datatype dt, int dest, int tag, - MPIR_Comm * comm_ptr, MPIR_TSP_sched_t sched, int n_in_vtcs, int *in_vtcs, - int *vtx_id); + MPIR_Comm * comm_ptr, int coll_group, MPIR_TSP_sched_t sched, + int n_in_vtcs, int *in_vtcs, int *vtx_id); /* Transport function to schedule an irecv vertex */ int MPIR_TSP_sched_irecv(void *buf, @@ -58,8 +58,8 @@ int MPIR_TSP_sched_irecv(void *buf, MPI_Datatype dt, int source, int tag, - MPIR_Comm * comm_ptr, MPIR_TSP_sched_t sched, int n_in_vtcs, int *in_vtcs, - int *vtx_id); + MPIR_Comm * comm_ptr, int coll_group, MPIR_TSP_sched_t sched, + int n_in_vtcs, int *in_vtcs, int *vtx_id); /* Transport function to schedule a irecv with status vertex */ int MPIR_TSP_sched_irecv_status(void *buf, @@ -67,7 +67,7 @@ int MPIR_TSP_sched_irecv_status(void *buf, MPI_Datatype dt, int source, int tag, - MPIR_Comm * comm_ptr, MPI_Status * status, + MPIR_Comm * comm_ptr, int coll_group, MPI_Status * status, MPIR_TSP_sched_t sched, int n_in_vtcs, int *in_vtcs, int *vtx_id); /* Transport function to schedule an imcast vertex */ @@ -77,7 +77,7 @@ int MPIR_TSP_sched_imcast(const void *buf, int *dests, int num_dests, int tag, - MPIR_Comm * comm_ptr, + MPIR_Comm * comm_ptr, int coll_group, MPIR_TSP_sched_t sched, int n_in_vtcs, int *in_vtcs, int *vtx_id); diff --git a/src/mpi/comm/comm_impl.c b/src/mpi/comm/comm_impl.c index 7fcb9828f74..6094f82b4d2 100644 --- a/src/mpi/comm/comm_impl.c +++ b/src/mpi/comm/comm_impl.c @@ -471,8 +471,8 @@ int MPIR_Comm_create_inter(MPIR_Comm * comm_ptr, MPIR_Group * group_ptr, MPIR_Co info[1] = group_ptr->size; mpi_errno = MPIC_Sendrecv(info, 2, MPI_INT, 0, 0, - rinfo, 2, MPI_INT, 0, 0, comm_ptr, MPI_STATUS_IGNORE, - MPIR_ERR_NONE); + rinfo, 2, MPI_INT, 0, 0, comm_ptr, MPIR_SUBGROUP_NONE, + MPI_STATUS_IGNORE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); if (*newcomm_ptr != NULL) { (*newcomm_ptr)->context_id = rinfo[0]; @@ -486,19 +486,23 @@ int MPIR_Comm_create_inter(MPIR_Comm * comm_ptr, MPIR_Group * group_ptr, MPIR_Co /* Populate and exchange the ranks */ mpi_errno = MPIC_Sendrecv(mapping, group_ptr->size, MPI_INT, 0, 0, remote_mapping, remote_size, MPI_INT, 0, 0, - comm_ptr, MPI_STATUS_IGNORE, MPIR_ERR_NONE); + comm_ptr, MPIR_SUBGROUP_NONE, MPI_STATUS_IGNORE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); /* Broadcast to the other members of the local group */ - mpi_errno = MPIR_Bcast(rinfo, 2, MPI_INT, 0, comm_ptr->local_comm, MPIR_ERR_NONE); + mpi_errno = + MPIR_Bcast(rinfo, 2, MPI_INT, 0, comm_ptr->local_comm, MPIR_SUBGROUP_NONE, + MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_Bcast(remote_mapping, remote_size, MPI_INT, 0, - comm_ptr->local_comm, MPIR_ERR_NONE); + comm_ptr->local_comm, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); } else { /* The other processes */ /* Broadcast to the other members of the local group */ - mpi_errno = MPIR_Bcast(rinfo, 2, MPI_INT, 0, comm_ptr->local_comm, MPIR_ERR_NONE); + mpi_errno = + MPIR_Bcast(rinfo, 2, MPI_INT, 0, comm_ptr->local_comm, MPIR_SUBGROUP_NONE, + MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); if (*newcomm_ptr != NULL) { (*newcomm_ptr)->context_id = rinfo[0]; @@ -508,7 +512,7 @@ int MPIR_Comm_create_inter(MPIR_Comm * comm_ptr, MPIR_Group * group_ptr, MPIR_Co remote_size * sizeof(int), mpi_errno, "remote_mapping", MPL_MEM_ADDRESS); mpi_errno = MPIR_Bcast(remote_mapping, remote_size, MPI_INT, 0, - comm_ptr->local_comm, MPIR_ERR_NONE); + comm_ptr->local_comm, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); } @@ -1029,7 +1033,7 @@ int MPIR_Intercomm_create_impl(MPIR_Comm * local_comm_ptr, int local_leader, mpi_errno = MPIC_Sendrecv(&recvcontext_id, 1, MPIR_CONTEXT_ID_T_DATATYPE, remote_leader, tag, &remote_context_id, 1, MPIR_CONTEXT_ID_T_DATATYPE, remote_leader, tag, - peer_comm_ptr, MPI_STATUS_IGNORE, MPIR_ERR_NONE); + peer_comm_ptr, MPIR_SUBGROUP_NONE, MPI_STATUS_IGNORE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); final_context_id = remote_context_id; @@ -1038,14 +1042,18 @@ int MPIR_Intercomm_create_impl(MPIR_Comm * local_comm_ptr, int local_leader, * along with the final context id */ comm_info[0] = final_context_id; MPL_DBG_MSG(MPIR_DBG_COMM, VERBOSE, "About to bcast on local_comm"); - mpi_errno = MPIR_Bcast(comm_info, 1, MPI_INT, local_leader, local_comm_ptr, MPIR_ERR_NONE); + mpi_errno = + MPIR_Bcast(comm_info, 1, MPI_INT, local_leader, local_comm_ptr, MPIR_SUBGROUP_NONE, + MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); MPL_DBG_MSG_D(MPIR_DBG_COMM, VERBOSE, "end of bcast on local_comm of size %d", local_comm_ptr->local_size); } else { /* we're the other processes */ MPL_DBG_MSG(MPIR_DBG_COMM, VERBOSE, "About to receive bcast on local_comm"); - mpi_errno = MPIR_Bcast(comm_info, 1, MPI_INT, local_leader, local_comm_ptr, MPIR_ERR_NONE); + mpi_errno = + MPIR_Bcast(comm_info, 1, MPI_INT, local_leader, local_comm_ptr, MPIR_SUBGROUP_NONE, + MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); /* Extract the context and group sign information */ @@ -1198,7 +1206,7 @@ int MPIR_Intercomm_merge_impl(MPIR_Comm * comm_ptr, int high, MPIR_Comm ** new_i /* This routine allows use to use the collective communication * context rather than the point-to-point context. */ mpi_errno = MPIC_Sendrecv(&local_high, 1, MPI_INT, 0, 0, - &remote_high, 1, MPI_INT, 0, 0, comm_ptr, + &remote_high, 1, MPI_INT, 0, 0, comm_ptr, MPIR_SUBGROUP_NONE, MPI_STATUS_IGNORE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); @@ -1217,7 +1225,9 @@ int MPIR_Intercomm_merge_impl(MPIR_Comm * comm_ptr, int high, MPIR_Comm ** new_i * value of local_high, which may have changed if both groups * of processes had the same value for high */ - mpi_errno = MPIR_Bcast(&local_high, 1, MPI_INT, 0, comm_ptr->local_comm, MPIR_ERR_NONE); + mpi_errno = + MPIR_Bcast(&local_high, 1, MPI_INT, 0, comm_ptr->local_comm, MPIR_SUBGROUP_NONE, + MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); /* diff --git a/src/mpi/comm/comm_split.c b/src/mpi/comm/comm_split.c index 94f722f4ef9..f8685836054 100644 --- a/src/mpi/comm/comm_split.c +++ b/src/mpi/comm/comm_split.c @@ -114,7 +114,8 @@ int MPIR_Comm_split_impl(MPIR_Comm * comm_ptr, int color, int key, MPIR_Comm ** } /* Gather information on the local group of processes */ mpi_errno = - MPIR_Allgather(MPI_IN_PLACE, 2, MPI_INT, table, 2, MPI_INT, local_comm_ptr, MPIR_ERR_NONE); + MPIR_Allgather(MPI_IN_PLACE, 2, MPI_INT, table, 2, MPI_INT, local_comm_ptr, + MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); /* Step 2: How many processes have our same color? */ @@ -161,7 +162,7 @@ int MPIR_Comm_split_impl(MPIR_Comm * comm_ptr, int color, int key, MPIR_Comm ** mypair.color = color; mypair.key = key; mpi_errno = MPIR_Allgather(&mypair, 2, MPI_INT, remotetable, 2, MPI_INT, - comm_ptr, MPIR_ERR_NONE); + comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); /* Each process can now match its color with the entries in the table */ @@ -216,11 +217,12 @@ int MPIR_Comm_split_impl(MPIR_Comm * comm_ptr, int color, int key, MPIR_Comm ** if (comm_ptr->rank == 0) { mpi_errno = MPIC_Sendrecv(&new_context_id, 1, MPIR_CONTEXT_ID_T_DATATYPE, 0, 0, &remote_context_id, 1, MPIR_CONTEXT_ID_T_DATATYPE, - 0, 0, comm_ptr, MPI_STATUS_IGNORE, MPIR_ERR_NONE); + 0, 0, comm_ptr, MPIR_SUBGROUP_NONE, MPI_STATUS_IGNORE, + MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_Bcast(&remote_context_id, 1, MPIR_CONTEXT_ID_T_DATATYPE, 0, local_comm_ptr, - MPIR_ERR_NONE); + MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); if (!in_newcomm) { @@ -230,7 +232,7 @@ int MPIR_Comm_split_impl(MPIR_Comm * comm_ptr, int color, int key, MPIR_Comm ** /* Broadcast to the other members of the local group */ mpi_errno = MPIR_Bcast(&remote_context_id, 1, MPIR_CONTEXT_ID_T_DATATYPE, 0, local_comm_ptr, - MPIR_ERR_NONE); + MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); } } diff --git a/src/mpi/comm/comm_split_type_nbhd.c b/src/mpi/comm/comm_split_type_nbhd.c index 9e76911aafb..9fbaa730ca0 100644 --- a/src/mpi/comm/comm_split_type_nbhd.c +++ b/src/mpi/comm/comm_split_type_nbhd.c @@ -277,7 +277,7 @@ static int network_split_by_minsize(MPIR_Comm * comm_ptr, int key, int subcomm_m /* Send the count to processes */ mpi_errno = MPIR_Allreduce(MPI_IN_PLACE, num_processes_at_node, num_nodes, MPI_INT, - MPI_SUM, comm_ptr, MPIR_ERR_NONE); + MPI_SUM, comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); if (topo_type == MPIR_NETTOPO_TYPE__FAT_TREE || topo_type == MPIR_NETTOPO_TYPE__CLOS_NETWORK) { @@ -377,7 +377,7 @@ static int network_split_by_minsize(MPIR_Comm * comm_ptr, int key, int subcomm_m /* get min tree depth to all processes */ MPIR_Allreduce(&tree_depth, &min_tree_depth, 1, MPI_INT, MPI_MIN, node_comm, - MPIR_ERR_NONE); + MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); if (min_tree_depth) { int num_hwloc_objs_at_depth; @@ -391,7 +391,7 @@ static int network_split_by_minsize(MPIR_Comm * comm_ptr, int key, int subcomm_m /* get parent_idx to all processes */ MPIR_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, parent_idx, 1, MPI_INT, - node_comm, MPIR_ERR_NONE); + node_comm, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); /* reorder parent indices */ for (i = 0; i < num_procs - 1; i++) { @@ -474,12 +474,7 @@ static int network_split_by_min_memsize(MPIR_Comm * comm_ptr, int key, long min_ if (min_mem_size == 0 || topo_type == MPIR_NETTOPO_TYPE__INVALID) { *newcomm_ptr = NULL; } else { - int num_ranks_node; - if (MPIR_Process.comm_world->node_comm != NULL) { - num_ranks_node = MPIR_Comm_size(MPIR_Process.comm_world->node_comm); - } else { - num_ranks_node = 1; - } + int num_ranks_node = MPIR_Process.local_size; memory_per_process = total_memory_size / num_ranks_node; mpi_errno = network_split_by_minsize(comm_ptr, key, min_mem_size / memory_per_process, newcomm_ptr); diff --git a/src/mpi/comm/commutil.c b/src/mpi/comm/commutil.c index e930b943e33..3c9a2a69193 100644 --- a/src/mpi/comm/commutil.c +++ b/src/mpi/comm/commutil.c @@ -309,6 +309,7 @@ int MPII_Comm_init(MPIR_Comm * comm_p) comm_p->node_roots_comm = NULL; comm_p->intranode_table = NULL; comm_p->internode_table = NULL; + comm_p->num_subgroups = 0; /* abstractions bleed a bit here... :(*/ comm_p->next_sched_tag = MPIR_FIRST_NBC_TAG; @@ -523,58 +524,6 @@ int MPIR_Comm_map_free(MPIR_Comm * comm) return mpi_errno; } -static int get_node_count(MPIR_Comm * comm, int *node_count) -{ - int mpi_errno = MPI_SUCCESS; - struct uniq_nodes { - int id; - UT_hash_handle hh; - } *node_list = NULL; - struct uniq_nodes *s, *tmp; - - if (comm->comm_kind != MPIR_COMM_KIND__INTRACOMM) { - *node_count = comm->local_size; - goto fn_exit; - } else if (comm->hierarchy_kind == MPIR_COMM_HIERARCHY_KIND__NODE) { - *node_count = 1; - goto fn_exit; - } else if (comm->hierarchy_kind == MPIR_COMM_HIERARCHY_KIND__NODE_ROOTS) { - *node_count = comm->local_size; - goto fn_exit; - } - - /* go through the list of ranks and add the unique ones to the - * node_list array */ - for (int i = 0; i < comm->local_size; i++) { - int node; - - mpi_errno = MPID_Get_node_id(comm, i, &node); - MPIR_ERR_CHECK(mpi_errno); - - HASH_FIND_INT(node_list, &node, s); - if (s == NULL) { - s = (struct uniq_nodes *) MPL_malloc(sizeof(struct uniq_nodes), MPL_MEM_COLL); - MPIR_Assert(s); - s->id = node; - HASH_ADD_INT(node_list, id, s, MPL_MEM_COLL); - } - } - - /* the final size of our hash table is our node count */ - *node_count = HASH_COUNT(node_list); - - /* free up everything */ - HASH_ITER(hh, node_list, s, tmp) { - HASH_DEL(node_list, s); - MPL_free(s); - } - - fn_exit: - return mpi_errno; - fn_fail: - goto fn_exit; -} - static int MPIR_Comm_commit_internal(MPIR_Comm * comm) { int mpi_errno = MPI_SUCCESS; @@ -584,9 +533,6 @@ static int MPIR_Comm_commit_internal(MPIR_Comm * comm) mpi_errno = MPID_Comm_commit_pre_hook(comm); MPIR_ERR_CHECK(mpi_errno); - mpi_errno = get_node_count(comm, &comm->node_count); - MPIR_ERR_CHECK(mpi_errno); - MPIR_Comm_map_free(comm); fn_exit: @@ -665,6 +611,25 @@ int MPIR_Comm_create_subcomms(MPIR_Comm * comm) MPIR_Assert(num_local > 1 || external_rank >= 0); MPIR_Assert(external_rank < 0 || external_procs != NULL); + comm->num_local = num_local; + comm->num_external = num_external; + + /* node */ +#define NODE_GROUP(field) comm->subgroups[MPIR_SUBGROUP_NODE].field + NODE_GROUP(rank) = local_rank; + NODE_GROUP(size) = num_local; + NODE_GROUP(proc_table) = MPL_malloc(num_local * sizeof(int), MPL_MEM_OTHER); + for (int i = 0; i < num_local; i++) { + NODE_GROUP(proc_table)[i] = local_procs[i]; + } +#define NODE_CROSS_GROUP(field) comm->subgroups[MPIR_SUBGROUP_NODE_CROSS].field + NODE_CROSS_GROUP(rank) = external_rank; + NODE_CROSS_GROUP(size) = num_external; + NODE_CROSS_GROUP(proc_table) = MPL_malloc(num_external * sizeof(int), MPL_MEM_OTHER); + for (int i = 0; i < num_external; i++) { + NODE_CROSS_GROUP(proc_table)[i] = external_procs[i]; + } + /* if the node_roots_comm and comm would be the same size, then creating * the second communicator is useless and wasteful. */ if (num_external == comm->remote_size) { @@ -758,7 +723,8 @@ static int init_comm_seq(MPIR_Comm * comm) /* Every rank need share the same seq from root. NOTE: it is possible for * different communicators to have the same seq. It is only used as an * opportunistic optimization */ - mpi_errno = MPIR_Bcast_allcomm_auto(&tmp, 1, MPI_INT, 0, comm, MPIR_ERR_NONE); + mpi_errno = + MPIR_Bcast_allcomm_auto(&tmp, 1, MPI_INT, 0, comm, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); comm->seq = tmp; @@ -789,6 +755,14 @@ int MPIR_Comm_commit(MPIR_Comm * comm) MPIR_FUNC_ENTER; + /* preset reserved subgroups */ + comm->num_subgroups = MPIR_SUBGROUP_NUM_RESERVED; + for (int i = 0; i < comm->num_subgroups; i++) { + comm->subgroups[i].rank = -1; + comm->subgroups[i].size = 0; + comm->subgroups[i].proc_table = NULL; + } + /* It's OK to relax these assertions, but we should do so very * intentionally. For now this function is the only place that we create * our hierarchy of communicators */ @@ -847,9 +821,11 @@ int MPIR_Comm_commit(MPIR_Comm * comm) /* Returns true if the given communicator is aware of node topology information, false otherwise. Such information could be used to implement more efficient collective communication, for example. */ -int MPIR_Comm_is_parent_comm(MPIR_Comm * comm) +int MPIR_Comm_is_parent_comm(MPIR_Comm * comm, int coll_group) { - return (comm->hierarchy_kind == MPIR_COMM_HIERARCHY_KIND__PARENT); + return (coll_group == MPIR_SUBGROUP_NONE && + comm->num_external > 1 && comm->num_external != comm->remote_size && + comm->hierarchy_kind == MPIR_COMM_HIERARCHY_KIND__PARENT); } /* Returns true if the communicator is node-aware and processes in all the nodes @@ -860,7 +836,7 @@ int MPII_Comm_is_node_consecutive(MPIR_Comm * comm) int i = 0, curr_nodeidx = 0; int *internode_table = comm->internode_table; - if (!MPIR_Comm_is_parent_comm(comm)) + if (!MPIR_Comm_is_parent_comm(comm, MPIR_SUBGROUP_NONE)) return 0; for (; i < comm->local_size; i++) { @@ -1220,6 +1196,12 @@ int MPIR_Comm_delete_internal(MPIR_Comm * comm_ptr) MPL_free(comm_ptr->intranode_table); MPL_free(comm_ptr->internode_table); + /* free subgroups */ + for (int i = 0; i < comm_ptr->num_subgroups; i++) { + MPL_free(comm_ptr->subgroups[i].proc_table); + } + comm_ptr->num_subgroups = 0; + MPIR_stream_comm_free(comm_ptr); /* Free the context value. This should come after freeing the @@ -1303,11 +1285,14 @@ int MPII_collect_info_key(MPIR_Comm * comm_ptr, MPIR_Info * info_ptr, const char } int is_equal; - mpi_errno = MPIR_Allreduce_equal(&hint_str_size, 1, MPI_INT, &is_equal, comm_ptr); + mpi_errno = + MPIR_Allreduce_equal(&hint_str_size, 1, MPI_INT, &is_equal, comm_ptr, MPIR_SUBGROUP_NONE); MPIR_ERR_CHECK(mpi_errno); if (is_equal && hint_str_size > 0) { - mpi_errno = MPIR_Allreduce_equal(hint_str, hint_str_size, MPI_CHAR, &is_equal, comm_ptr); + mpi_errno = + MPIR_Allreduce_equal(hint_str, hint_str_size, MPI_CHAR, &is_equal, comm_ptr, + MPIR_SUBGROUP_NONE); MPIR_ERR_CHECK(mpi_errno); } @@ -1334,7 +1319,7 @@ int MPII_Comm_is_node_balanced(MPIR_Comm * comm, int *num_nodes, bool * node_bal MPIR_CHKPMEM_DECL(1); - if (!MPIR_Comm_is_parent_comm(comm)) { + if (!MPIR_Comm_is_parent_comm(comm, MPIR_SUBGROUP_NONE)) { *node_balanced = false; goto fn_exit; } diff --git a/src/mpi/comm/contextid.c b/src/mpi/comm/contextid.c index 8bee3890caf..d10b9833d6f 100644 --- a/src/mpi/comm/contextid.c +++ b/src/mpi/comm/contextid.c @@ -462,7 +462,8 @@ int MPIR_Get_contextid_sparse_group(MPIR_Comm * comm_ptr, MPIR_Group * group_ptr MPIR_ERR_NONE); } else { mpi_errno = MPIR_Allreduce_impl(MPI_IN_PLACE, st.local_mask, MPIR_MAX_CONTEXT_MASK + 1, - MPI_INT, MPI_BAND, comm_ptr, MPIR_ERR_NONE); + MPI_INT, MPI_BAND, comm_ptr, MPIR_SUBGROUP_NONE, + MPIR_ERR_NONE); } MPIR_ERR_CHECK(mpi_errno); @@ -562,7 +563,8 @@ int MPIR_Get_contextid_sparse_group(MPIR_Comm * comm_ptr, MPIR_Group * group_ptr comm_ptr, group_ptr, coll_tag, MPIR_ERR_NONE); } else { mpi_errno = MPIR_Allreduce_impl(MPI_IN_PLACE, &minfree, 1, MPI_INT, - MPI_MIN, comm_ptr, MPIR_ERR_NONE); + MPI_MIN, comm_ptr, MPIR_SUBGROUP_NONE, + MPIR_ERR_NONE); } if (minfree > 0) { @@ -644,18 +646,18 @@ static int sched_cb_gcn_bcast(MPIR_Comm * comm, int tag, void *state) if (st->comm_ptr_inter->rank == 0) { mpi_errno = MPIR_Sched_recv(st->ctx1, 1, MPIR_CONTEXT_ID_T_DATATYPE, 0, st->comm_ptr_inter, - st->s); + MPIR_SUBGROUP_NONE, st->s); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_Sched_send(st->ctx0, 1, MPIR_CONTEXT_ID_T_DATATYPE, 0, st->comm_ptr_inter, - st->s); + MPIR_SUBGROUP_NONE, st->s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(st->s); } mpi_errno = MPIR_Ibcast_intra_sched_auto(st->ctx1, 1, MPIR_CONTEXT_ID_T_DATATYPE, 0, st->comm_ptr, - st->s); + MPIR_SUBGROUP_NONE, st->s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(st->s); } @@ -733,7 +735,7 @@ static int sched_cb_gcn_allocate_cid(MPIR_Comm * comm, int tag, void *state) */ /* FIXME: study and resolve */ /* - * mpi_errno = MPIR_Allreduce(MPI_IN_PLACE, &minfree, 1, MPI_INT, MPI_MIN, st->comm_ptr, MPIR_ERR_NONE); + * mpi_errno = MPIR_Allreduce(MPI_IN_PLACE, &minfree, 1, MPI_INT, MPI_MIN, st->comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); * MPIR_ERR_CHECK(mpi_errno); */ if (minfree > 0) { @@ -760,7 +762,7 @@ static int sched_cb_gcn_allocate_cid(MPIR_Comm * comm, int tag, void *state) * are not necessarily completed in the same order as they are issued, also on the * same communicator. To avoid deadlocks, we cannot add the elements to the * list bevfore the first iallreduce is completed. The "tag" is created for the - * scheduling - by calling MPIR_Sched_next_tag(comm_ptr, &tag) - and the same + * scheduling - by calling MPIR_Sched_next_tag(comm_ptr, MPIR_SUBGROUP_NONE, &tag) - and the same * for a idup operation on all processes. So we use it here. */ /* FIXME I'm not sure if there can be an overflows for this tag */ st->tag = (uint64_t) tag + MPIR_Process.attrs.tag_ub; @@ -837,7 +839,7 @@ static int sched_cb_gcn_copy_mask(MPIR_Comm * comm, int tag, void *state) mpi_errno = MPIR_Iallreduce_intra_sched_auto(MPI_IN_PLACE, st->local_mask, MPIR_MAX_CONTEXT_MASK + 1, MPI_UINT32_T, MPI_BAND, - st->comm_ptr, st->s); + st->comm_ptr, MPIR_SUBGROUP_NONE, st->s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(st->s); @@ -943,7 +945,7 @@ int MPIR_Get_contextid_nonblock(MPIR_Comm * comm_ptr, MPIR_Comm * newcommp, MPIR MPIR_FUNC_ENTER; /* now create a schedule */ - mpi_errno = MPIR_Sched_next_tag(comm_ptr, &tag); + mpi_errno = MPIR_Sched_next_tag(comm_ptr, MPIR_SUBGROUP_NONE, &tag); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_Sched_create(&s, MPIR_SCHED_KIND_GENERALIZED); MPIR_ERR_CHECK(mpi_errno); @@ -984,7 +986,7 @@ int MPIR_Get_intercomm_contextid_nonblock(MPIR_Comm * comm_ptr, MPIR_Comm * newc } /* now create a schedule */ - mpi_errno = MPIR_Sched_next_tag(comm_ptr, &tag); + mpi_errno = MPIR_Sched_next_tag(comm_ptr, MPIR_SUBGROUP_NONE, &tag); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_Sched_create(&s, MPIR_SCHED_KIND_GENERALIZED); MPIR_ERR_CHECK(mpi_errno); @@ -1056,14 +1058,14 @@ int MPIR_Get_intercomm_contextid(MPIR_Comm * comm_ptr, MPIR_Context_id_t * conte if (comm_ptr->rank == 0) { mpi_errno = MPIC_Sendrecv(&mycontext_id, 1, MPIR_CONTEXT_ID_T_DATATYPE, 0, tag, &remote_context_id, 1, MPIR_CONTEXT_ID_T_DATATYPE, 0, tag, - comm_ptr, MPI_STATUS_IGNORE, MPIR_ERR_NONE); + comm_ptr, MPIR_SUBGROUP_NONE, MPI_STATUS_IGNORE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); } /* Make sure that all of the local processes now have this * id */ mpi_errno = MPIR_Bcast_impl(&remote_context_id, 1, MPIR_CONTEXT_ID_T_DATATYPE, - 0, comm_ptr->local_comm, MPIR_ERR_NONE); + 0, comm_ptr->local_comm, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); /* The recvcontext_id must be the one that was allocated out of the local * group, not the remote group. Otherwise we could end up posting two diff --git a/src/mpi/init/init_async.c b/src/mpi/init/init_async.c index a7f4e1f4807..2e8c82de94a 100644 --- a/src/mpi/init/init_async.c +++ b/src/mpi/init/init_async.c @@ -179,17 +179,14 @@ static int get_thread_affinity(bool * apply_affinity, int **p_thread_affinity, i } global_rank = MPIR_Process.rank; - local_rank = - (MPIR_Process.comm_world->node_comm) ? MPIR_Process.comm_world->node_comm->rank : 0; + local_rank = MPIR_Process.local_rank; if (have_cliques) { - /* If local cliques > 1, using local_size from node_comm will have conflict on thread idx. + /* If local cliques > 1, using local_size will have conflict on thread idx. * In multiple nodes case, this would cost extra memory for allocating thread affinity on every * node, but it is okay to solve progress thread oversubscription. */ local_size = MPIR_Process.comm_world->local_size; } else { - local_size = - (MPIR_Process.comm_world->node_comm) ? MPIR_Process.comm_world-> - node_comm->local_size : 1; + local_size = MPIR_Process.local_size; } async_threads_per_node = local_size; diff --git a/src/mpi/stream/stream_enqueue.c b/src/mpi/stream/stream_enqueue.c index fe50b4b693f..b8367ed5599 100644 --- a/src/mpi/stream/stream_enqueue.c +++ b/src/mpi/stream/stream_enqueue.c @@ -611,8 +611,9 @@ static void allreduce_enqueue_cb(void *data) } } - mpi_errno = MPIR_Allreduce(sendbuf, recvbuf, p->count, p->datatype, p->op, p->comm_ptr, - MPIR_ERR_NONE); + mpi_errno = + MPIR_Allreduce(sendbuf, recvbuf, p->count, p->datatype, p->op, p->comm_ptr, + MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_Assertp(mpi_errno == MPI_SUCCESS); if (p->host_recvbuf) { diff --git a/src/mpi/stream/stream_impl.c b/src/mpi/stream/stream_impl.c index 1223ab8a5d5..069288eb8cf 100644 --- a/src/mpi/stream/stream_impl.c +++ b/src/mpi/stream/stream_impl.c @@ -269,7 +269,8 @@ int MPIR_Stream_comm_create_impl(MPIR_Comm * comm_ptr, MPIR_Stream * stream_ptr, MPIR_ERR_CHKANDJUMP(!vci_table, mpi_errno, MPI_ERR_OTHER, "**nomem"); mpi_errno = - MPIR_Allgather_impl(&vci, 1, MPI_INT, vci_table, 1, MPI_INT, comm_ptr, MPIR_ERR_NONE); + MPIR_Allgather_impl(&vci, 1, MPI_INT, vci_table, 1, MPI_INT, comm_ptr, MPIR_SUBGROUP_NONE, + MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); (*newcomm_ptr)->stream_comm_type = MPIR_STREAM_COMM_SINGLE; @@ -313,7 +314,8 @@ int MPIR_Stream_comm_create_multiplex_impl(MPIR_Comm * comm_ptr, MPI_Aint num_tmp = num_streams; mpi_errno = MPIR_Allgather_impl(&num_tmp, 1, MPI_AINT, - num_table, 1, MPI_AINT, comm_ptr, MPIR_ERR_NONE); + num_table, 1, MPI_AINT, comm_ptr, MPIR_SUBGROUP_NONE, + MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); MPI_Aint num_total = 0; @@ -347,7 +349,7 @@ int MPIR_Stream_comm_create_multiplex_impl(MPIR_Comm * comm_ptr, mpi_errno = MPIR_Allgatherv_impl(local_vcis, num_streams, MPI_INT, vci_table, num_table, displs, MPI_INT, comm_ptr, - MPIR_ERR_NONE); + MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); (*newcomm_ptr)->stream_comm_type = MPIR_STREAM_COMM_MULTIPLEX; diff --git a/src/mpi/threadcomm/threadcomm_coll_impl.c b/src/mpi/threadcomm/threadcomm_coll_impl.c index 23951e5ccbf..529b415d17d 100644 --- a/src/mpi/threadcomm/threadcomm_coll_impl.c +++ b/src/mpi/threadcomm/threadcomm_coll_impl.c @@ -34,7 +34,7 @@ int MPIR_Threadcomm_barrier_impl(MPIR_Comm * comm) if (comm->local_size == 1) { thread_barrier(comm->threadcomm); } else { - mpi_errno = MPIR_Barrier_intra_dissemination(comm, MPIR_ERR_NONE); + mpi_errno = MPIR_Barrier_intra_dissemination(comm, MPIR_SUBGROUP_THREADCOMM, MPIR_ERR_NONE); } return mpi_errno; @@ -45,7 +45,9 @@ int MPIR_Threadcomm_bcast_impl(void *buffer, MPI_Aint count, MPI_Datatype dataty { int mpi_errno = MPI_SUCCESS; - mpi_errno = MPIR_Bcast_intra_binomial(buffer, count, datatype, root, comm, MPIR_ERR_NONE); + mpi_errno = + MPIR_Bcast_intra_binomial(buffer, count, datatype, root, comm, MPIR_SUBGROUP_THREADCOMM, + MPIR_ERR_NONE); return mpi_errno; } @@ -57,7 +59,8 @@ int MPIR_Threadcomm_gather_impl(const void *sendbuf, MPI_Aint sendcount, MPI_Dat int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Gather_intra_binomial(sendbuf, sendcount, sendtype, - recvbuf, recvcount, recvtype, root, comm, MPIR_ERR_NONE); + recvbuf, recvcount, recvtype, root, comm, + MPIR_SUBGROUP_THREADCOMM, MPIR_ERR_NONE); return mpi_errno; } @@ -71,7 +74,7 @@ int MPIR_Threadcomm_gatherv_impl(const void *sendbuf, MPI_Aint sendcount, MPI_Da mpi_errno = MPIR_Gatherv_allcomm_linear(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, root, - comm, MPIR_ERR_NONE); + comm, MPIR_SUBGROUP_THREADCOMM, MPIR_ERR_NONE); return mpi_errno; } @@ -84,7 +87,7 @@ int MPIR_Threadcomm_scatter_impl(const void *sendbuf, MPI_Aint sendcount, MPI_Da mpi_errno = MPIR_Scatter_intra_binomial(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, - comm, MPIR_ERR_NONE); + comm, MPIR_SUBGROUP_THREADCOMM, MPIR_ERR_NONE); return mpi_errno; } @@ -98,7 +101,7 @@ int MPIR_Threadcomm_scatterv_impl(const void *sendbuf, const MPI_Aint * sendcoun mpi_errno = MPIR_Scatterv_allcomm_linear(sendbuf, sendcounts, displs, sendtype, recvbuf, recvcount, recvtype, root, - comm, MPIR_ERR_NONE); + comm, MPIR_SUBGROUP_THREADCOMM, MPIR_ERR_NONE); return mpi_errno; } @@ -110,7 +113,8 @@ int MPIR_Threadcomm_allgather_impl(const void *sendbuf, MPI_Aint sendcount, MPI_ int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Allgather_intra_brucks(sendbuf, sendcount, sendtype, - recvbuf, recvcount, recvtype, comm, MPIR_ERR_NONE); + recvbuf, recvcount, recvtype, comm, + MPIR_SUBGROUP_THREADCOMM, MPIR_ERR_NONE); return mpi_errno; } @@ -124,7 +128,7 @@ int MPIR_Threadcomm_allgatherv_impl(const void *sendbuf, MPI_Aint sendcount, MPI mpi_errno = MPIR_Allgatherv_intra_brucks(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, - comm, MPIR_ERR_NONE); + comm, MPIR_SUBGROUP_THREADCOMM, MPIR_ERR_NONE); return mpi_errno; } @@ -137,7 +141,8 @@ int MPIR_Threadcomm_alltoall_impl(const void *sendbuf, MPI_Aint sendcount, MPI_D MPIR_Assert(sendbuf != MPI_IN_PLACE); mpi_errno = MPIR_Alltoall_intra_brucks(sendbuf, sendcount, sendtype, - recvbuf, recvcount, recvtype, comm, MPIR_ERR_NONE); + recvbuf, recvcount, recvtype, comm, + MPIR_SUBGROUP_THREADCOMM, MPIR_ERR_NONE); return mpi_errno; } @@ -153,7 +158,7 @@ int MPIR_Threadcomm_alltoallv_impl(const void *sendbuf, const MPI_Aint * sendcou MPIR_Assert(sendbuf != MPI_IN_PLACE); mpi_errno = MPIR_Alltoallv_intra_scattered(sendbuf, sendcounts, sdispls, sendtype, recvbuf, recvcounts, rdispls, recvtype, - comm, MPIR_ERR_NONE); + comm, MPIR_SUBGROUP_THREADCOMM, MPIR_ERR_NONE); return mpi_errno; } @@ -169,7 +174,7 @@ int MPIR_Threadcomm_alltoallw_impl(const void *sendbuf, const MPI_Aint * sendcou MPIR_Assert(sendbuf != MPI_IN_PLACE); mpi_errno = MPIR_Alltoallw_intra_scattered(sendbuf, sendcounts, sdispls, sendtypes, recvbuf, recvcounts, rdispls, recvtypes, - comm, MPIR_ERR_NONE); + comm, MPIR_SUBGROUP_THREADCOMM, MPIR_ERR_NONE); return mpi_errno; } @@ -181,7 +186,8 @@ int MPIR_Threadcomm_allreduce_impl(const void *sendbuf, void *recvbuf, int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Allreduce_intra_recursive_doubling(sendbuf, recvbuf, count, datatype, op, - comm, MPIR_ERR_NONE); + comm, MPIR_SUBGROUP_THREADCOMM, + MPIR_ERR_NONE); return mpi_errno; } @@ -193,7 +199,7 @@ int MPIR_Threadcomm_reduce_impl(const void *sendbuf, void *recvbuf, int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Reduce_intra_binomial(sendbuf, recvbuf, count, datatype, op, root, - comm, MPIR_ERR_NONE); + comm, MPIR_SUBGROUP_THREADCOMM, MPIR_ERR_NONE); return mpi_errno; } @@ -206,7 +212,9 @@ int MPIR_Threadcomm_reduce_scatter_impl(const void *sendbuf, void *recvbuf, MPIR_Assert(MPIR_Op_is_commutative(op)); mpi_errno = MPIR_Reduce_scatter_intra_recursive_halving(sendbuf, recvbuf, recvcounts, - datatype, op, comm, MPIR_ERR_NONE); + datatype, op, comm, + MPIR_SUBGROUP_THREADCOMM, + MPIR_ERR_NONE); return mpi_errno; } @@ -220,7 +228,8 @@ int MPIR_Threadcomm_reduce_scatter_block_impl(const void *sendbuf, void *recvbuf MPIR_Assert(MPIR_Op_is_commutative(op)); mpi_errno = MPIR_Reduce_scatter_block_intra_recursive_halving(sendbuf, recvbuf, recvcount, datatype, op, - comm, MPIR_ERR_NONE); + comm, MPIR_SUBGROUP_THREADCOMM, + MPIR_ERR_NONE); return mpi_errno; } @@ -231,7 +240,7 @@ int MPIR_Threadcomm_scan_impl(const void *sendbuf, void *recvbuf, int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Scan_intra_recursive_doubling(sendbuf, recvbuf, count, datatype, op, - comm, MPIR_ERR_NONE); + comm, MPIR_SUBGROUP_THREADCOMM, MPIR_ERR_NONE); return mpi_errno; } @@ -242,7 +251,7 @@ int MPIR_Threadcomm_exscan_impl(const void *sendbuf, void *recvbuf, int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Exscan_intra_recursive_doubling(sendbuf, recvbuf, count, datatype, op, - comm, MPIR_ERR_NONE); + comm, MPIR_SUBGROUP_THREADCOMM, MPIR_ERR_NONE); return mpi_errno; } diff --git a/src/mpi/threadcomm/threadcomm_impl.c b/src/mpi/threadcomm/threadcomm_impl.c index 8161ecad5fa..8b96ddcf95c 100644 --- a/src/mpi/threadcomm/threadcomm_impl.c +++ b/src/mpi/threadcomm/threadcomm_impl.c @@ -34,8 +34,9 @@ int MPIR_Threadcomm_init_impl(MPIR_Comm * comm, int num_threads, MPIR_Comm ** co threads_table = MPL_malloc(comm_size * sizeof(int), MPL_MEM_OTHER); MPIR_ERR_CHKANDJUMP(!threads_table, mpi_errno, MPI_ERR_OTHER, "**nomem"); - mpi_errno = MPIR_Allgather_impl(&num_threads, 1, MPI_INT, threads_table, 1, MPI_INT, comm, - MPIR_ERR_NONE); + mpi_errno = + MPIR_Allgather_impl(&num_threads, 1, MPI_INT, threads_table, 1, MPI_INT, comm, + MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); int *rank_offset_table;; diff --git a/src/mpi/topo/dist_graph_create.c b/src/mpi/topo/dist_graph_create.c index ffe9c27c850..07b7270fdee 100644 --- a/src/mpi/topo/dist_graph_create.c +++ b/src/mpi/topo/dist_graph_create.c @@ -133,7 +133,8 @@ int MPIR_Dist_graph_create_impl(MPIR_Comm * comm_ptr, /* compute the number of peers I will recv from */ int in_out_peers[2] = { -1, 1 }; mpi_errno = - MPIR_Reduce_scatter_block(rs, in_out_peers, 2, MPI_INT, MPI_SUM, comm_ptr, MPIR_ERR_NONE); + MPIR_Reduce_scatter_block(rs, in_out_peers, 2, MPI_INT, MPI_SUM, comm_ptr, + MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); MPIR_Assert(in_out_peers[0] <= comm_size && in_out_peers[0] >= 0); @@ -150,14 +151,14 @@ int MPIR_Dist_graph_create_impl(MPIR_Comm * comm_ptr, /* send edges where i is a destination to process i */ mpi_errno = MPIC_Isend(&rin[i][0], rin_sizes[i], MPI_INT, i, MPIR_TOPO_A_TAG, comm_ptr, - &reqs[idx++], MPIR_ERR_NONE); + MPIR_SUBGROUP_NONE, &reqs[idx++], MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); } if (rout_sizes[i]) { /* send edges where i is a source to process i */ mpi_errno = MPIC_Isend(&rout[i][0], rout_sizes[i], MPI_INT, i, MPIR_TOPO_B_TAG, comm_ptr, - &reqs[idx++], MPIR_ERR_NONE); + MPIR_SUBGROUP_NONE, &reqs[idx++], MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); } } @@ -203,7 +204,7 @@ int MPIR_Dist_graph_create_impl(MPIR_Comm * comm_ptr, MPIR_ERR_CHKANDJUMP(!buf, mpi_errno, MPI_ERR_OTHER, "**nomem"); mpi_errno = MPIC_Recv(buf, count, MPI_INT, MPI_ANY_SOURCE, MPIR_TOPO_A_TAG, - comm_ptr, MPI_STATUS_IGNORE); + comm_ptr, MPIR_SUBGROUP_NONE, MPI_STATUS_IGNORE); MPIR_ERR_CHECK(mpi_errno); for (int j = 0; j < count / 2; ++j) { @@ -236,7 +237,7 @@ int MPIR_Dist_graph_create_impl(MPIR_Comm * comm_ptr, MPIR_ERR_CHKANDJUMP(!buf, mpi_errno, MPI_ERR_OTHER, "**nomem"); mpi_errno = MPIC_Recv(buf, count, MPI_INT, MPI_ANY_SOURCE, MPIR_TOPO_B_TAG, - comm_ptr, MPI_STATUS_IGNORE); + comm_ptr, MPIR_SUBGROUP_NONE, MPI_STATUS_IGNORE); MPIR_ERR_CHECK(mpi_errno); for (int j = 0; j < count / 2; ++j) { diff --git a/src/mpid/ch3/channels/nemesis/src/ch3_win_fns.c b/src/mpid/ch3/channels/nemesis/src/ch3_win_fns.c index ca4039a9d9d..1760f00fce8 100644 --- a/src/mpid/ch3/channels/nemesis/src/ch3_win_fns.c +++ b/src/mpid/ch3/channels/nemesis/src/ch3_win_fns.c @@ -201,7 +201,7 @@ static int MPIDI_CH3I_SHM_Wins_match(MPIR_Win ** win_ptr, MPIR_Win ** matched_wi base_shm_offs[node_rank] = (MPI_Aint) ((*win_ptr)->base) - (MPI_Aint) (shm_win->shm_base_addr); mpi_errno = MPIR_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, - base_shm_offs, 1, MPI_AINT, node_comm_ptr, MPIR_ERR_NONE); + base_shm_offs, 1, MPI_AINT, node_comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); base_diff = 0; @@ -345,12 +345,12 @@ static int MPIDI_CH3I_Win_gather_info(void *base, MPI_Aint size, int disp_unit, MPIR_ERR_CHECK(mpi_errno); mpi_errno = - MPIR_Bcast(serialized_hnd_ptr, MPL_SHM_GHND_SZ, MPI_CHAR, 0, node_comm_ptr, + MPIR_Bcast(serialized_hnd_ptr, MPL_SHM_GHND_SZ, MPI_CHAR, 0, node_comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); /* wait for other processes to attach to win */ - mpi_errno = MPIR_Barrier(node_comm_ptr, MPIR_ERR_NONE); + mpi_errno = MPIR_Barrier(node_comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); /* unlink shared memory region so it gets deleted when all processes exit */ @@ -362,7 +362,7 @@ static int MPIDI_CH3I_Win_gather_info(void *base, MPI_Aint size, int disp_unit, /* get serialized handle from rank 0 and deserialize it */ mpi_errno = - MPIR_Bcast(serialized_hnd, MPL_SHM_GHND_SZ, MPI_CHAR, 0, node_comm_ptr, + MPIR_Bcast(serialized_hnd, MPL_SHM_GHND_SZ, MPI_CHAR, 0, node_comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); @@ -376,7 +376,7 @@ static int MPIDI_CH3I_Win_gather_info(void *base, MPI_Aint size, int disp_unit, &(*win_ptr)->info_shm_base_addr, 0); MPIR_ERR_CHKANDJUMP(mpl_err, mpi_errno, MPI_ERR_OTHER, "**attach_shar_mem"); - mpi_errno = MPIR_Barrier(node_comm_ptr, MPIR_ERR_NONE); + mpi_errno = MPIR_Barrier(node_comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); } @@ -391,7 +391,7 @@ static int MPIDI_CH3I_Win_gather_info(void *base, MPI_Aint size, int disp_unit, tmp_buf[4 * comm_rank + 3] = (MPI_Aint) (*win_ptr)->handle; mpi_errno = MPIR_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, tmp_buf, 4, MPI_AINT, - (*win_ptr)->comm_ptr, MPIR_ERR_NONE); + (*win_ptr)->comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); if (node_rank == 0) { @@ -406,7 +406,7 @@ static int MPIDI_CH3I_Win_gather_info(void *base, MPI_Aint size, int disp_unit, } /* Make sure that all local processes see the results written by node_rank == 0 */ - mpi_errno = MPIR_Barrier(node_comm_ptr, MPIR_ERR_NONE); + mpi_errno = MPIR_Barrier(node_comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -479,7 +479,7 @@ static int MPIDI_CH3I_Win_allocate_shm(MPI_Aint size, int disp_unit, MPIR_Info * mpi_errno = MPIR_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, node_sizes, sizeof(MPI_Aint), MPI_BYTE, - node_comm_ptr, MPIR_ERR_NONE); + node_comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_T_PVAR_TIMER_END(RMA, rma_wincreate_allgather); MPIR_ERR_CHECK(mpi_errno); @@ -518,12 +518,12 @@ static int MPIDI_CH3I_Win_allocate_shm(MPI_Aint size, int disp_unit, MPIR_Info * MPIR_ERR_CHECK(mpi_errno); mpi_errno = - MPIR_Bcast(serialized_hnd_ptr, MPL_SHM_GHND_SZ, MPI_CHAR, 0, node_comm_ptr, + MPIR_Bcast(serialized_hnd_ptr, MPL_SHM_GHND_SZ, MPI_CHAR, 0, node_comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); /* wait for other processes to attach to win */ - mpi_errno = MPIR_Barrier(node_comm_ptr, MPIR_ERR_NONE); + mpi_errno = MPIR_Barrier(node_comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); /* unlink shared memory region so it gets deleted when all processes exit */ @@ -535,7 +535,7 @@ static int MPIDI_CH3I_Win_allocate_shm(MPI_Aint size, int disp_unit, MPIR_Info * /* get serialized handle from rank 0 and deserialize it */ mpi_errno = - MPIR_Bcast(serialized_hnd, MPL_SHM_GHND_SZ, MPI_CHAR, 0, node_comm_ptr, + MPIR_Bcast(serialized_hnd, MPL_SHM_GHND_SZ, MPI_CHAR, 0, node_comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); @@ -550,7 +550,7 @@ static int MPIDI_CH3I_Win_allocate_shm(MPI_Aint size, int disp_unit, MPIR_Info * &(*win_ptr)->shm_base_addr, 0); MPIR_ERR_CHKANDJUMP(mpl_err, mpi_errno, MPI_ERR_OTHER, "**attach_shar_mem"); - mpi_errno = MPIR_Barrier(node_comm_ptr, MPIR_ERR_NONE); + mpi_errno = MPIR_Barrier(node_comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); } @@ -577,12 +577,12 @@ static int MPIDI_CH3I_Win_allocate_shm(MPI_Aint size, int disp_unit, MPIR_Info * MPIR_ERR_CHECK(mpi_errno); mpi_errno = - MPIR_Bcast(serialized_hnd_ptr, MPL_SHM_GHND_SZ, MPI_CHAR, 0, node_comm_ptr, + MPIR_Bcast(serialized_hnd_ptr, MPL_SHM_GHND_SZ, MPI_CHAR, 0, node_comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); /* wait for other processes to attach to win */ - mpi_errno = MPIR_Barrier(node_comm_ptr, MPIR_ERR_NONE); + mpi_errno = MPIR_Barrier(node_comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); /* unlink shared memory region so it gets deleted when all processes exit */ @@ -594,7 +594,7 @@ static int MPIDI_CH3I_Win_allocate_shm(MPI_Aint size, int disp_unit, MPIR_Info * /* get serialized handle from rank 0 and deserialize it */ mpi_errno = - MPIR_Bcast(serialized_hnd, MPL_SHM_GHND_SZ, MPI_CHAR, 0, node_comm_ptr, + MPIR_Bcast(serialized_hnd, MPL_SHM_GHND_SZ, MPI_CHAR, 0, node_comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); @@ -609,7 +609,7 @@ static int MPIDI_CH3I_Win_allocate_shm(MPI_Aint size, int disp_unit, MPIR_Info * (void **) &(*win_ptr)->shm_mutex, 0); MPIR_ERR_CHKANDJUMP(mpl_err, mpi_errno, MPI_ERR_OTHER, "**attach_shar_mem"); - mpi_errno = MPIR_Barrier(node_comm_ptr, MPIR_ERR_NONE); + mpi_errno = MPIR_Barrier(node_comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpid/ch3/include/mpid_coll.h b/src/mpid/ch3/include/mpid_coll.h index 7de2103a3f9..7a2c47cb23a 100644 --- a/src/mpid/ch3/include/mpid_coll.h +++ b/src/mpid/ch3/include/mpid_coll.h @@ -11,39 +11,39 @@ #include "../../common/hcoll/hcoll.h" #endif -static inline int MPID_Barrier(MPIR_Comm * comm, MPIR_Errflag_t errflag) +static inline int MPID_Barrier(MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { #ifdef HAVE_HCOLL if (MPI_SUCCESS == hcoll_Barrier(comm, errflag)) return MPI_SUCCESS; #endif - return MPIR_Barrier_impl(comm, errflag); + return MPIR_Barrier_impl(comm, coll_group, errflag); } static inline int MPID_Bcast(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, - MPIR_Comm * comm, MPIR_Errflag_t errflag) + MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { #ifdef HAVE_HCOLL if (MPI_SUCCESS == hcoll_Bcast(buffer, count, datatype, root, comm, errflag)) return MPI_SUCCESS; #endif - return MPIR_Bcast_impl(buffer, count, datatype, root, comm, errflag); + return MPIR_Bcast_impl(buffer, count, datatype, root, comm, coll_group, errflag); } static inline int MPID_Allreduce(const void *sendbuf, void *recvbuf, MPI_Aint count, - MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, + MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { #ifdef HAVE_HCOLL if (MPI_SUCCESS == hcoll_Allreduce(sendbuf, recvbuf, count, datatype, op, comm, errflag)) return MPI_SUCCESS; #endif - return MPIR_Allreduce_impl(sendbuf, recvbuf, count, datatype, op, comm, errflag); + return MPIR_Allreduce_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, errflag); } static inline int MPID_Allgather(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm, MPIR_Errflag_t errflag) + MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { #ifdef HAVE_HCOLL if (MPI_SUCCESS == hcoll_Allgather(sendbuf, sendcount, sendtype, recvbuf, @@ -51,17 +51,17 @@ static inline int MPID_Allgather(const void *sendbuf, MPI_Aint sendcount, MPI_Da return MPI_SUCCESS; #endif return MPIR_Allgather_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm, errflag); + recvcount, recvtype, comm, coll_group, errflag); } static inline int MPID_Allgatherv(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, - MPI_Datatype recvtype, MPIR_Comm * comm, MPIR_Errflag_t errflag) + MPI_Datatype recvtype, MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Allgatherv_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcounts, displs, recvtype, comm, + recvcounts, displs, recvtype, comm, coll_group, errflag); return mpi_errno; @@ -69,25 +69,25 @@ static inline int MPID_Allgatherv(const void *sendbuf, MPI_Aint sendcount, MPI_D static inline int MPID_Scatter(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm, MPIR_Errflag_t errflag) + int root, MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Scatter_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, root, comm, errflag); + recvcount, recvtype, root, comm, coll_group, errflag); return mpi_errno; } static inline int MPID_Scatterv(const void *sendbuf, const MPI_Aint * sendcounts, const MPI_Aint * displs, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, - MPI_Datatype recvtype, int root, MPIR_Comm * comm, + MPI_Datatype recvtype, int root, MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Scatterv_impl(sendbuf, sendcounts, displs, sendtype, - recvbuf, recvcount, recvtype, root, comm, + recvbuf, recvcount, recvtype, root, comm, coll_group, errflag); return mpi_errno; @@ -95,25 +95,25 @@ static inline int MPID_Scatterv(const void *sendbuf, const MPI_Aint * sendcounts static inline int MPID_Gather(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm, MPIR_Errflag_t errflag) + int root, MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Gather_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, root, comm, errflag); + recvcount, recvtype, root, comm, coll_group, errflag); return mpi_errno; } static inline int MPID_Gatherv(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, - MPI_Datatype recvtype, int root, MPIR_Comm * comm, + MPI_Datatype recvtype, int root, MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Gatherv_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcounts, displs, recvtype, root, comm, + recvcounts, displs, recvtype, root, comm, coll_group, errflag); return mpi_errno; @@ -121,26 +121,26 @@ static inline int MPID_Gatherv(const void *sendbuf, MPI_Aint sendcount, MPI_Data static inline int MPID_Alltoall(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm, MPIR_Errflag_t errflag) + MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Alltoall_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm, errflag); + recvcount, recvtype, comm, coll_group, errflag); return mpi_errno; } static inline int MPID_Alltoallv(const void *sendbuf, const MPI_Aint * sendcounts, const MPI_Aint * sdispls, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, - const MPI_Aint * rdispls, MPI_Datatype recvtype, MPIR_Comm * comm, + const MPI_Aint * rdispls, MPI_Datatype recvtype, MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Alltoallv_impl(sendbuf, sendcounts, sdispls, sendtype, recvbuf, recvcounts, rdispls, recvtype, - comm, errflag); + comm, coll_group, errflag); return mpi_errno; } @@ -148,75 +148,75 @@ static inline int MPID_Alltoallv(const void *sendbuf, const MPI_Aint * sendcount static inline int MPID_Alltoallw(const void *sendbuf, const MPI_Aint sendcounts[], const MPI_Aint sdispls[], const MPI_Datatype sendtypes[], void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint rdispls[], - const MPI_Datatype recvtypes[], MPIR_Comm * comm_ptr, + const MPI_Datatype recvtypes[], MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Alltoallw_impl(sendbuf, sendcounts, sdispls, sendtypes, recvbuf, recvcounts, rdispls, recvtypes, - comm_ptr, errflag); + comm_ptr, coll_group, errflag); return mpi_errno; } static inline int MPID_Reduce(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, int root, - MPIR_Comm * comm, MPIR_Errflag_t errflag) + MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Reduce_impl(sendbuf, recvbuf, count, datatype, op, root, - comm, errflag); + comm, coll_group, errflag); return mpi_errno; } static inline int MPID_Reduce_scatter(const void *sendbuf, void *recvbuf, const MPI_Aint recvcounts[], - MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, + MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Reduce_scatter_impl(sendbuf, recvbuf, recvcounts, - datatype, op, comm_ptr, errflag); + datatype, op, comm_ptr, coll_group, errflag); return mpi_errno; } static inline int MPID_Reduce_scatter_block(const void *sendbuf, void *recvbuf, MPI_Aint recvcount, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm_ptr, + MPI_Op op, MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Reduce_scatter_block_impl(sendbuf, recvbuf, recvcount, - datatype, op, comm_ptr, + datatype, op, comm_ptr, coll_group, errflag); return mpi_errno; } static inline int MPID_Scan(const void *sendbuf, void *recvbuf, MPI_Aint count, - MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, + MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; - mpi_errno = MPIR_Scan_impl(sendbuf, recvbuf, count, datatype, op, comm, + mpi_errno = MPIR_Scan_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, errflag); return mpi_errno; } static inline int MPID_Exscan(const void *sendbuf, void *recvbuf, MPI_Aint count, - MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, + MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; - mpi_errno = MPIR_Exscan_impl(sendbuf, recvbuf, count, datatype, op, comm, + mpi_errno = MPIR_Exscan_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, errflag); return mpi_errno; @@ -366,70 +366,70 @@ static inline int MPID_Ineighbor_alltoallw(const void *sendbuf, const MPI_Aint s return mpi_errno; } -static inline int MPID_Ibarrier(MPIR_Comm * comm, MPIR_Request **request) +static inline int MPID_Ibarrier(MPIR_Comm * comm, int coll_group, MPIR_Request **request) { int mpi_errno = MPI_SUCCESS; - mpi_errno = MPIR_Ibarrier_impl(comm, request); + mpi_errno = MPIR_Ibarrier_impl(comm, coll_group, request); return mpi_errno; } static inline int MPID_Ibcast(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, - MPIR_Comm * comm, MPIR_Request **request) + MPIR_Comm * comm, int coll_group, MPIR_Request **request) { int mpi_errno = MPI_SUCCESS; - mpi_errno = MPIR_Ibcast_impl(buffer, count, datatype, root, comm, request); + mpi_errno = MPIR_Ibcast_impl(buffer, count, datatype, root, comm, coll_group, request); return mpi_errno; } static inline int MPID_Iallgather(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm, MPIR_Request **request) + MPIR_Comm * comm, int coll_group, MPIR_Request **request) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Iallgather_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm, request); + recvcount, recvtype, comm, coll_group, request); return mpi_errno; } static inline int MPID_Iallgatherv(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, - MPI_Datatype recvtype, MPIR_Comm * comm, MPIR_Request **request) + MPI_Datatype recvtype, MPIR_Comm * comm, int coll_group, MPIR_Request **request) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Iallgatherv_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcounts, displs, recvtype, comm, + recvcounts, displs, recvtype, comm, coll_group, request); return mpi_errno; } static inline int MPID_Iallreduce(const void *sendbuf, void *recvbuf, MPI_Aint count, - MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, + MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Request **request) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Iallreduce_impl(sendbuf, recvbuf, count, datatype, op, - comm, request); + comm, coll_group, request); return mpi_errno; } static inline int MPID_Ialltoall(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm, MPIR_Request **request) + MPIR_Comm * comm, int coll_group, MPIR_Request **request) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Ialltoall_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm, request); + recvcount, recvtype, comm, coll_group, request); return mpi_errno; } @@ -438,13 +438,13 @@ static inline int MPID_Ialltoallv(const void *sendbuf, const MPI_Aint sendcounts const MPI_Aint sdispls[], MPI_Datatype sendtype, void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint rdispls[], MPI_Datatype recvtype, - MPIR_Comm * comm, MPIR_Request **request) + MPIR_Comm * comm, int coll_group, MPIR_Request **request) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Ialltoallv_impl(sendbuf, sendcounts, sdispls, sendtype, recvbuf, recvcounts, rdispls, recvtype, - comm, request); + comm, coll_group, request); return mpi_errno; } @@ -453,24 +453,24 @@ static inline int MPID_Ialltoallw(const void *sendbuf, const MPI_Aint sendcounts const MPI_Aint sdispls[], const MPI_Datatype sendtypes[], void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint rdispls[], const MPI_Datatype recvtypes[], - MPIR_Comm * comm, MPIR_Request **request) + MPIR_Comm * comm, int coll_group, MPIR_Request **request) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Ialltoallw_impl(sendbuf, sendcounts, sdispls, sendtypes, recvbuf, recvcounts, rdispls, recvtypes, - comm, request); + comm, coll_group, request); return mpi_errno; } static inline int MPID_Iexscan(const void *sendbuf, void *recvbuf, MPI_Aint count, - MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, + MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Request **request) { int mpi_errno = MPI_SUCCESS; - mpi_errno = MPIR_Iexscan_impl(sendbuf, recvbuf, count, datatype, op, comm, + mpi_errno = MPIR_Iexscan_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, request); return mpi_errno; @@ -478,71 +478,71 @@ static inline int MPID_Iexscan(const void *sendbuf, void *recvbuf, MPI_Aint coun static inline int MPID_Igather(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm, MPIR_Request **request) + int root, MPIR_Comm * comm, int coll_group, MPIR_Request **request) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Igather_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, root, comm, request); + recvcount, recvtype, root, comm, coll_group, request); return mpi_errno; } static inline int MPID_Igatherv(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, - MPI_Datatype recvtype, int root, MPIR_Comm * comm, + MPI_Datatype recvtype, int root, MPIR_Comm * comm, int coll_group, MPIR_Request **request) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Igatherv_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcounts, displs, recvtype, root, comm, + recvcounts, displs, recvtype, root, comm, coll_group, request); return mpi_errno; } static inline int MPID_Ireduce_scatter_block(const void *sendbuf, void *recvbuf, MPI_Aint recvcount, - MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, + MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Request **request) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Ireduce_scatter_block_impl(sendbuf, recvbuf, recvcount, - datatype, op, comm, request); + datatype, op, comm, coll_group, request); return mpi_errno; } static inline int MPID_Ireduce_scatter(const void *sendbuf, void *recvbuf, const MPI_Aint recvcounts[], - MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, + MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Request **request) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Ireduce_scatter_impl(sendbuf, recvbuf, recvcounts, - datatype, op, comm, request); + datatype, op, comm, coll_group, request); return mpi_errno; } static inline int MPID_Ireduce(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, - MPI_Op op, int root, MPIR_Comm * comm, MPIR_Request **request) + MPI_Op op, int root, MPIR_Comm * comm, int coll_group, MPIR_Request **request) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Ireduce_impl(sendbuf, recvbuf, count, datatype, op, root, - comm, request); + comm, coll_group, request); return mpi_errno; } static inline int MPID_Iscan(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm, MPIR_Request **request) + MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Request **request) { int mpi_errno = MPI_SUCCESS; - mpi_errno = MPIR_Iscan_impl(sendbuf, recvbuf, count, datatype, op, comm, + mpi_errno = MPIR_Iscan_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, request); return mpi_errno; @@ -550,12 +550,12 @@ static inline int MPID_Iscan(const void *sendbuf, void *recvbuf, MPI_Aint count, static inline int MPID_Iscatter(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm, MPIR_Request **request) + int root, MPIR_Comm * comm, int coll_group, MPIR_Request **request) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Iscatter_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, root, comm, request); + recvcount, recvtype, root, comm, coll_group, request); return mpi_errno; } @@ -563,12 +563,12 @@ static inline int MPID_Iscatter(const void *sendbuf, MPI_Aint sendcount, MPI_Dat static inline int MPID_Iscatterv(const void *sendbuf, const MPI_Aint * sendcounts, const MPI_Aint * displs, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm, MPIR_Request **request) + int root, MPIR_Comm * comm, int coll_group, MPIR_Request **request) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Iscatterv_impl(sendbuf, sendcounts, displs, sendtype, - recvbuf, recvcount, recvtype, root, comm, + recvbuf, recvcount, recvtype, root, comm, coll_group, request); return mpi_errno; diff --git a/src/mpid/ch3/include/mpidpre.h b/src/mpid/ch3/include/mpidpre.h index 9a98093d328..63835586fdf 100644 --- a/src/mpid/ch3/include/mpidpre.h +++ b/src/mpid/ch3/include/mpidpre.h @@ -614,71 +614,71 @@ int MPID_Recv_init( void *buf, MPI_Aint count, MPI_Datatype datatype, int MPID_Startall(int count, MPIR_Request *requests[]); int MPID_Bcast_init(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, - MPIR_Comm *comm_ptr, MPIR_Info* info_ptr, MPIR_Request **request); + MPIR_Comm *comm_ptr, int coll_group, MPIR_Info* info_ptr, MPIR_Request **request); int MPID_Allreduce_init(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm_ptr, MPIR_Info * info_ptr, + MPI_Op op, MPIR_Comm * comm_ptr, int coll_group, MPIR_Info * info_ptr, MPIR_Request ** request); int MPID_Reduce_init(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, - MPI_Op op, int root, MPIR_Comm *comm_ptr, MPIR_Info* info_ptr, + MPI_Op op, int root, MPIR_Comm *comm_ptr, int coll_group, MPIR_Info* info_ptr, MPIR_Request **request); int MPID_Alltoall_init(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm *comm_ptr, MPIR_Info* info_ptr, MPIR_Request** request); + MPIR_Comm *comm_ptr, int coll_group, MPIR_Info* info_ptr, MPIR_Request** request); int MPID_Alltoallv_init(const void *sendbuf, const MPI_Aint sendcounts[], const MPI_Aint sdispls[], MPI_Datatype sendtype, void *recvbuf, const MPI_Aint recvcounts[], - const MPI_Aint rdispls[], MPI_Datatype recvtype, MPIR_Comm * comm_ptr, + const MPI_Aint rdispls[], MPI_Datatype recvtype, MPIR_Comm * comm_ptr, int coll_group, MPIR_Info * info_ptr, MPIR_Request ** request); int MPID_Alltoallw_init(const void *sendbuf, const MPI_Aint sendcounts[], const MPI_Aint sdispls[], const MPI_Datatype sendtypes[], void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint rdispls[], const MPI_Datatype recvtypes[], - MPIR_Comm * comm_ptr, MPIR_Info * info_ptr, MPIR_Request ** request); + MPIR_Comm * comm_ptr, int coll_group, MPIR_Info * info_ptr, MPIR_Request ** request); int MPID_Allgather_init(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm *comm_ptr, MPIR_Info* info_ptr, MPIR_Request** request); + MPIR_Comm *comm_ptr, int coll_group, MPIR_Info* info_ptr, MPIR_Request** request); int MPID_Allgatherv_init(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, - MPI_Datatype recvtype, MPIR_Comm *comm_ptr, MPIR_Info* info_ptr, + MPI_Datatype recvtype, MPIR_Comm *comm_ptr, int coll_group, MPIR_Info* info_ptr, MPIR_Request** request); int MPID_Reduce_scatter_block_init(const void *sendbuf, void *recvbuf, MPI_Aint recvcount, - MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, + MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Info * info, MPIR_Request ** request); int MPID_Reduce_scatter_init(const void *sendbuf, void *recvbuf, const MPI_Aint recvcounts[], - MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, MPIR_Info * info, + MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Info * info, MPIR_Request ** request); int MPID_Scan_init(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm, MPIR_Info * info, MPIR_Request ** request); + MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Info * info, MPIR_Request ** request); int MPID_Gather_init(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, - MPI_Aint recvcount, MPI_Datatype recvtype, int root, MPIR_Comm * comm, + MPI_Aint recvcount, MPI_Datatype recvtype, int root, MPIR_Comm * comm, int coll_group, MPIR_Info * info, MPIR_Request ** request); int MPID_Gatherv_init(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint displs[], MPI_Datatype recvtype, - int root, MPIR_Comm * comm, MPIR_Info * info, MPIR_Request ** request); + int root, MPIR_Comm * comm, int coll_group, MPIR_Info * info, MPIR_Request ** request); int MPID_Scatter_init(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, - MPI_Aint recvcount, MPI_Datatype recvtype, int root, MPIR_Comm * comm, + MPI_Aint recvcount, MPI_Datatype recvtype, int root, MPIR_Comm * comm, int coll_group, MPIR_Info * info, MPIR_Request ** request); int MPID_Scatterv_init(const void *sendbuf, const MPI_Aint sendcounts[], const MPI_Aint displs[], MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, - MPI_Datatype recvtype, int root, MPIR_Comm * comm, MPIR_Info * info, + MPI_Datatype recvtype, int root, MPIR_Comm * comm, int coll_group, MPIR_Info * info, MPIR_Request ** request); -int MPID_Barrier_init(MPIR_Comm * comm, MPIR_Info * info, MPIR_Request ** request); +int MPID_Barrier_init(MPIR_Comm * comm, int coll_group, MPIR_Info * info, MPIR_Request ** request); int MPID_Exscan_init(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm, MPIR_Info * info, MPIR_Request ** request); + MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Info * info, MPIR_Request ** request); int MPID_Neighbor_allgather_init(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, diff --git a/src/mpid/ch3/src/ch3u_comm_spawn_multiple.c b/src/mpid/ch3/src/ch3u_comm_spawn_multiple.c index 7c1feb92dae..c1c9d94c940 100644 --- a/src/mpid/ch3/src/ch3u_comm_spawn_multiple.c +++ b/src/mpid/ch3/src/ch3u_comm_spawn_multiple.c @@ -75,13 +75,13 @@ int MPIDI_Comm_spawn_multiple(int count, char **commands, } if (errcodes != MPI_ERRCODES_IGNORE) { - mpi_errno = MPIR_Bcast(&should_accept, 1, MPI_INT, root, comm_ptr, MPIR_ERR_NONE); + mpi_errno = MPIR_Bcast(&should_accept, 1, MPI_INT, root, comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); - mpi_errno = MPIR_Bcast(&total_num_processes, 1, MPI_INT, root, comm_ptr, MPIR_ERR_NONE); + mpi_errno = MPIR_Bcast(&total_num_processes, 1, MPI_INT, root, comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); - mpi_errno = MPIR_Bcast(errcodes, total_num_processes, MPI_INT, root, comm_ptr, MPIR_ERR_NONE); + mpi_errno = MPIR_Bcast(errcodes, total_num_processes, MPI_INT, root, comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpid/ch3/src/ch3u_port.c b/src/mpid/ch3/src/ch3u_port.c index 3ef43d48aab..390eedfd0f3 100644 --- a/src/mpid/ch3/src/ch3u_port.c +++ b/src/mpid/ch3/src/ch3u_port.c @@ -646,7 +646,7 @@ int MPIDI_Comm_connect(const char *port_name, MPIR_Info *info, int root, send_ints[0], send_ints[1], send_ints[2])); mpi_errno = MPIC_Sendrecv(send_ints, 3, MPI_INT, 0, sendtag++, recv_ints, 3, MPI_INT, - 0, recvtag++, tmp_comm, + 0, recvtag++, tmp_comm, MPIR_SUBGROUP_NONE, MPI_STATUS_IGNORE, MPIR_ERR_NONE); if (mpi_errno != MPI_SUCCESS) { /* this is a no_port error because we may fail to connect @@ -657,7 +657,7 @@ int MPIDI_Comm_connect(const char *port_name, MPIR_Info *info, int root, /* broadcast the received info to local processes */ MPL_DBG_MSG(MPIDI_CH3_DBG_CONNECT,VERBOSE,"broadcasting the received 3 ints"); - mpi_errno = MPIR_Bcast_allcomm_auto(recv_ints, 3, MPI_INT, root, comm_ptr, MPIR_ERR_NONE); + mpi_errno = MPIR_Bcast_allcomm_auto(recv_ints, 3, MPI_INT, root, comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); /* check if root was unable to connect to the port */ @@ -689,7 +689,7 @@ int MPIDI_Comm_connect(const char *port_name, MPIR_Info *info, int root, mpi_errno = MPIC_Sendrecv(local_translation, local_comm_size * 2, MPI_INT, 0, sendtag++, remote_translation, remote_comm_size * 2, - MPI_INT, 0, recvtag++, tmp_comm, + MPI_INT, 0, recvtag++, tmp_comm, MPIR_SUBGROUP_NONE, MPI_STATUS_IGNORE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); @@ -711,7 +711,7 @@ int MPIDI_Comm_connect(const char *port_name, MPIR_Info *info, int root, /* Broadcast out the remote rank translation array */ MPL_DBG_MSG(MPIDI_CH3_DBG_CONNECT,VERBOSE,"Broadcasting remote translation"); mpi_errno = MPIR_Bcast_allcomm_auto(remote_translation, remote_comm_size * 2, MPI_INT, - root, comm_ptr, MPIR_ERR_NONE); + root, comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); #ifdef MPICH_DBG_OUTPUT @@ -740,7 +740,7 @@ int MPIDI_Comm_connect(const char *port_name, MPIR_Info *info, int root, MPL_DBG_MSG(MPIDI_CH3_DBG_CONNECT,VERBOSE,"sync with peer"); mpi_errno = MPIC_Sendrecv(&i, 0, MPI_INT, 0, sendtag++, &j, 0, MPI_INT, - 0, recvtag++, tmp_comm, + 0, recvtag++, tmp_comm, MPIR_SUBGROUP_NONE, MPI_STATUS_IGNORE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); @@ -749,7 +749,7 @@ int MPIDI_Comm_connect(const char *port_name, MPIR_Info *info, int root, } /*printf("connect:barrier\n");fflush(stdout);*/ - mpi_errno = MPIR_Barrier_allcomm_auto(comm_ptr, MPIR_ERR_NONE); + mpi_errno = MPIR_Barrier_allcomm_auto(comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); /* Free new_vc. It was explicitly allocated in MPIDI_CH3_Connect_to_root.*/ @@ -795,7 +795,7 @@ int MPIDI_Comm_connect(const char *port_name, MPIR_Info *info, int root, /* notify other processes to return an error */ MPL_DBG_MSG(MPIDI_CH3_DBG_CONNECT,VERBOSE,"broadcasting 3 ints: error case"); - mpi_errno2 = MPIR_Bcast_allcomm_auto(recv_ints, 3, MPI_INT, root, comm_ptr, MPIR_ERR_NONE); + mpi_errno2 = MPIR_Bcast_allcomm_auto(recv_ints, 3, MPI_INT, root, comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); if (mpi_errno2) MPIR_ERR_ADD(mpi_errno, mpi_errno2); goto fn_fail; } @@ -928,7 +928,7 @@ static int ReceivePGAndDistribute( MPIR_Comm *tmp_comm, MPIR_Comm *comm_ptr, if (rank == root) { /* First, receive the pg description from the partner */ mpi_errno = MPIC_Recv(&j, 1, MPI_INT, 0, recvtag++, - tmp_comm, MPI_STATUS_IGNORE); + tmp_comm, MPIR_SUBGROUP_NONE, MPI_STATUS_IGNORE); *recvtag_p = recvtag; MPIR_ERR_CHECK(mpi_errno); pg_str = (char*)MPL_malloc(j, MPL_MEM_DYNAMIC); @@ -936,14 +936,14 @@ static int ReceivePGAndDistribute( MPIR_Comm *tmp_comm, MPIR_Comm *comm_ptr, MPIR_ERR_POP(mpi_errno); } mpi_errno = MPIC_Recv(pg_str, j, MPI_CHAR, 0, recvtag++, - tmp_comm, MPI_STATUS_IGNORE); + tmp_comm, MPIR_SUBGROUP_NONE, MPI_STATUS_IGNORE); *recvtag_p = recvtag; MPIR_ERR_CHECK(mpi_errno); } /* Broadcast the size and data to the local communicator */ /*printf("accept:broadcasting 1 int\n");fflush(stdout);*/ - mpi_errno = MPIR_Bcast_allcomm_auto(&j, 1, MPI_INT, root, comm_ptr, MPIR_ERR_NONE); + mpi_errno = MPIR_Bcast_allcomm_auto(&j, 1, MPI_INT, root, comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); if (rank != root) { @@ -954,7 +954,7 @@ static int ReceivePGAndDistribute( MPIR_Comm *tmp_comm, MPIR_Comm *comm_ptr, } } /*printf("accept:broadcasting string of length %d\n", j);fflush(stdout);*/ - mpi_errno = MPIR_Bcast_allcomm_auto(pg_str, j, MPI_CHAR, root, comm_ptr, MPIR_ERR_NONE); + mpi_errno = MPIR_Bcast_allcomm_auto(pg_str, j, MPI_CHAR, root, comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); /* Then reconstruct the received process group. This step also initializes the created process group */ @@ -998,7 +998,7 @@ int MPID_PG_BCast( MPIR_Comm *peercomm_p, MPIR_Comm *comm_p, int root ) } /* Now, broadcast the number of local pgs */ - mpi_errno = MPIR_Bcast( &n_local_pgs, 1, MPI_INT, root, comm_p, MPIR_ERR_NONE); + mpi_errno = MPIR_Bcast( &n_local_pgs, 1, MPI_INT, root, comm_p, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); pg_list = pg_head; @@ -1018,7 +1018,7 @@ int MPID_PG_BCast( MPIR_Comm *peercomm_p, MPIR_Comm *comm_p, int root ) len = pg_list->lenStr; pg_list = pg_list->next; } - mpi_errno = MPIR_Bcast( &len, 1, MPI_INT, root, comm_p, MPIR_ERR_NONE); + mpi_errno = MPIR_Bcast( &len, 1, MPI_INT, root, comm_p, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); if (rank != root) { pg_str = (char *)MPL_malloc(len, MPL_MEM_DYNAMIC); @@ -1027,7 +1027,7 @@ int MPID_PG_BCast( MPIR_Comm *peercomm_p, MPIR_Comm *comm_p, int root ) goto fn_exit; } } - mpi_errno = MPIR_Bcast( pg_str, len, MPI_CHAR, root, comm_p, MPIR_ERR_NONE); + mpi_errno = MPIR_Bcast( pg_str, len, MPI_CHAR, root, comm_p, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); if (mpi_errno) { if (rank != root) MPL_free( pg_str ); @@ -1083,13 +1083,13 @@ static int SendPGtoPeerAndFree( MPIR_Comm *tmp_comm, int *sendtag_p, pg_iter = pg_list; i = pg_iter->lenStr; /*printf("connect:sending 1 int: %d\n", i);fflush(stdout);*/ - mpi_errno = MPIC_Send(&i, 1, MPI_INT, 0, sendtag++, tmp_comm, MPIR_ERR_NONE); + mpi_errno = MPIC_Send(&i, 1, MPI_INT, 0, sendtag++, tmp_comm, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); *sendtag_p = sendtag; MPIR_ERR_CHECK(mpi_errno); /* printf("connect:sending string length %d\n", i);fflush(stdout); */ mpi_errno = MPIC_Send(pg_iter->str, i, MPI_CHAR, 0, sendtag++, - tmp_comm, MPIR_ERR_NONE); + tmp_comm, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); *sendtag_p = sendtag; MPIR_ERR_CHECK(mpi_errno); @@ -1182,14 +1182,14 @@ int MPIDI_Comm_accept(const char *port_name, MPIR_Info *info, int root, /*printf("accept:sending 3 ints, %d, %d, %d, and receiving 2 ints\n", send_ints[0], send_ints[1], send_ints[2]);fflush(stdout);*/ mpi_errno = MPIC_Sendrecv(send_ints, 3, MPI_INT, 0, sendtag++, recv_ints, 3, MPI_INT, - 0, recvtag++, tmp_comm, + 0, recvtag++, tmp_comm, MPIR_SUBGROUP_NONE, MPI_STATUS_IGNORE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); } /* broadcast the received info to local processes */ /*printf("accept:broadcasting 2 ints - %d and %d\n", recv_ints[0], recv_ints[1]);fflush(stdout);*/ - mpi_errno = MPIR_Bcast_allcomm_auto(recv_ints, 3, MPI_INT, root, comm_ptr, MPIR_ERR_NONE); + mpi_errno = MPIR_Bcast_allcomm_auto(recv_ints, 3, MPI_INT, root, comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); @@ -1221,7 +1221,7 @@ int MPIDI_Comm_accept(const char *port_name, MPIR_Info *info, int root, mpi_errno = MPIC_Sendrecv(local_translation, local_comm_size * 2, MPI_INT, 0, sendtag++, remote_translation, remote_comm_size * 2, - MPI_INT, 0, recvtag++, tmp_comm, + MPI_INT, 0, recvtag++, tmp_comm, MPIR_SUBGROUP_NONE, MPI_STATUS_IGNORE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); @@ -1244,7 +1244,7 @@ int MPIDI_Comm_accept(const char *port_name, MPIR_Info *info, int root, /* Broadcast out the remote rank translation array */ MPL_DBG_MSG(MPIDI_CH3_DBG_CONNECT,VERBOSE,"Broadcast remote_translation"); mpi_errno = MPIR_Bcast_allcomm_auto(remote_translation, remote_comm_size * 2, MPI_INT, - root, comm_ptr, MPIR_ERR_NONE); + root, comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); #ifdef MPICH_DBG_OUTPUT MPL_DBG_MSG_D(MPIDI_CH3_DBG_OTHER,TERSE,"[%d]accept:Received remote_translation after broadcast:\n", rank); @@ -1271,7 +1271,7 @@ int MPIDI_Comm_accept(const char *port_name, MPIR_Info *info, int root, MPL_DBG_MSG(MPIDI_CH3_DBG_CONNECT,VERBOSE,"sync with peer"); mpi_errno = MPIC_Sendrecv(&i, 0, MPI_INT, 0, sendtag++, &j, 0, MPI_INT, - 0, recvtag++, tmp_comm, + 0, recvtag++, tmp_comm, MPIR_SUBGROUP_NONE, MPI_STATUS_IGNORE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); @@ -1280,7 +1280,7 @@ int MPIDI_Comm_accept(const char *port_name, MPIR_Info *info, int root, } MPL_DBG_MSG(MPIDI_CH3_DBG_CONNECT,VERBOSE,"Barrier"); - mpi_errno = MPIR_Barrier_allcomm_auto(comm_ptr, MPIR_ERR_NONE); + mpi_errno = MPIR_Barrier_allcomm_auto(comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); /* Free new_vc once the connection is completed. It was explicitly @@ -1360,7 +1360,7 @@ static int SetupNewIntercomm( MPIR_Comm *comm_ptr, int remote_comm_size, MPIR_ERR_CHECK(mpi_errno); MPL_DBG_MSG(MPIDI_CH3_DBG_CONNECT,VERBOSE,"Barrier"); - mpi_errno = MPIR_Barrier_allcomm_auto(comm_ptr, MPIR_ERR_NONE); + mpi_errno = MPIR_Barrier_allcomm_auto(comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); fn_exit: diff --git a/src/mpid/ch3/src/ch3u_recvq.c b/src/mpid/ch3/src/ch3u_recvq.c index d30938ae36a..10742ae4e98 100644 --- a/src/mpid/ch3/src/ch3u_recvq.c +++ b/src/mpid/ch3/src/ch3u_recvq.c @@ -921,7 +921,7 @@ int MPIDI_CH3U_Clean_recvq(MPIR_Comm *comm_ptr) } } - if (MPIR_Comm_is_parent_comm(comm_ptr)) { + if (MPIR_Comm_is_parent_comm(comm_ptr, MPIR_SUBGROUP_NONE)) { /* node_comm pt2pt */ match.parts.context_id = comm_ptr->recvcontext_id + MPIR_CONTEXT_INTRANODE_OFFSET; @@ -1014,7 +1014,7 @@ int MPIDI_CH3U_Clean_recvq(MPIR_Comm *comm_ptr) } } - if (MPIR_Comm_is_parent_comm(comm_ptr)) { + if (MPIR_Comm_is_parent_comm(comm_ptr, MPIR_SUBGROUP_NONE)) { /* node_comm coll */ match.parts.context_id = comm_ptr->recvcontext_id + MPIR_CONTEXT_INTRANODE_OFFSET; diff --git a/src/mpid/ch3/src/ch3u_rma_sync.c b/src/mpid/ch3/src/ch3u_rma_sync.c index 6463c17ab4a..8510fb6a10d 100644 --- a/src/mpid/ch3/src/ch3u_rma_sync.c +++ b/src/mpid/ch3/src/ch3u_rma_sync.c @@ -489,11 +489,11 @@ int MPID_Win_fence(int assert, MPIR_Win * win_ptr) if (win_ptr->shm_allocated == TRUE) { MPIR_Comm *node_comm_ptr = win_ptr->comm_ptr->node_comm; - mpi_errno = MPIR_Barrier(node_comm_ptr, MPIR_ERR_NONE); + mpi_errno = MPIR_Barrier(node_comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); } - mpi_errno = MPIR_Ibarrier(win_ptr->comm_ptr, &fence_sync_req_ptr); + mpi_errno = MPIR_Ibarrier(win_ptr->comm_ptr, MPIR_SUBGROUP_NONE, &fence_sync_req_ptr); MPIR_ERR_CHECK(mpi_errno); if (fence_sync_req_ptr == NULL) { @@ -539,7 +539,7 @@ int MPID_Win_fence(int assert, MPIR_Win * win_ptr) win_ptr->at_completion_counter += comm_size; mpi_errno = MPIR_Reduce_scatter_block(MPI_IN_PLACE, rma_target_marks, 1, - MPI_INT, MPI_SUM, win_ptr->comm_ptr, MPIR_ERR_NONE); + MPI_INT, MPI_SUM, win_ptr->comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); win_ptr->at_completion_counter -= comm_size; @@ -579,7 +579,7 @@ int MPID_Win_fence(int assert, MPIR_Win * win_ptr) MPIR_ERR_CHECK(mpi_errno); if (scalable_fence_enabled) { - mpi_errno = MPIR_Barrier(win_ptr->comm_ptr, MPIR_ERR_NONE); + mpi_errno = MPIR_Barrier(win_ptr->comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); /* Set window access state properly. */ @@ -604,7 +604,7 @@ int MPID_Win_fence(int assert, MPIR_Win * win_ptr) MPIR_Request* fence_sync_req_ptr; /* Prepare for the next possible epoch */ - mpi_errno = MPIR_Ibarrier(win_ptr->comm_ptr, &fence_sync_req_ptr); + mpi_errno = MPIR_Ibarrier(win_ptr->comm_ptr, MPIR_SUBGROUP_NONE, &fence_sync_req_ptr); MPIR_ERR_CHECK(mpi_errno); if (fence_sync_req_ptr == NULL) { @@ -629,7 +629,7 @@ int MPID_Win_fence(int assert, MPIR_Win * win_ptr) if (win_ptr->shm_allocated == TRUE) { MPIR_Comm *node_comm_ptr = win_ptr->comm_ptr->node_comm; - mpi_errno = MPIR_Barrier(node_comm_ptr, MPIR_ERR_NONE); + mpi_errno = MPIR_Barrier(node_comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); } } diff --git a/src/mpid/ch3/src/ch3u_win_fns.c b/src/mpid/ch3/src/ch3u_win_fns.c index 0891b21d723..5e1b3083016 100644 --- a/src/mpid/ch3/src/ch3u_win_fns.c +++ b/src/mpid/ch3/src/ch3u_win_fns.c @@ -62,7 +62,7 @@ int MPIDI_CH3U_Win_gather_info(void *base, MPI_Aint size, int disp_unit, tmp_buf[4 * rank + 3] = (MPI_Aint) (*win_ptr)->handle; mpi_errno = MPIR_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, - tmp_buf, 4, MPI_AINT, (*win_ptr)->comm_ptr, MPIR_ERR_NONE); + tmp_buf, 4, MPI_AINT, (*win_ptr)->comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_T_PVAR_TIMER_END(RMA, rma_wincreate_allgather); MPIR_ERR_CHECK(mpi_errno); diff --git a/src/mpid/ch3/src/mpid_comm_get_all_failed_procs.c b/src/mpid/ch3/src/mpid_comm_get_all_failed_procs.c index 0c1ed9377c2..588512a9d23 100644 --- a/src/mpid/ch3/src/mpid_comm_get_all_failed_procs.c +++ b/src/mpid/ch3/src/mpid_comm_get_all_failed_procs.c @@ -107,7 +107,7 @@ int MPID_Comm_get_all_failed_procs(MPIR_Comm *comm_ptr, MPIR_Group **failed_grou for (i = 1; i < comm_ptr->local_size; i++) { /* Get everyone's list of failed processes to aggregate */ ret_errno = MPIC_Recv(remote_bitarray, bitarray_size, MPI_INT, - i, tag, comm_ptr, MPI_STATUS_IGNORE); + i, tag, comm_ptr, MPIR_SUBGROUP_NONE, MPI_STATUS_IGNORE); if (ret_errno) continue; /* Combine the received bitarray with my own */ @@ -121,7 +121,7 @@ int MPID_Comm_get_all_failed_procs(MPIR_Comm *comm_ptr, MPIR_Group **failed_grou for (i = 1; i < comm_ptr->local_size; i++) { /* Send the list to each rank to be processed locally */ ret_errno = MPIC_Send(bitarray, bitarray_size, MPI_INT, i, - tag, comm_ptr, MPIR_ERR_NONE); + tag, comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); if (ret_errno) continue; } @@ -130,11 +130,11 @@ int MPID_Comm_get_all_failed_procs(MPIR_Comm *comm_ptr, MPIR_Group **failed_grou } else { /* Send my bitarray to rank 0 */ mpi_errno = MPIC_Send(bitarray, bitarray_size, MPI_INT, 0, - tag, comm_ptr, MPIR_ERR_NONE); + tag, comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); /* Get the resulting bitarray back from rank 0 */ mpi_errno = MPIC_Recv(remote_bitarray, bitarray_size, MPI_INT, 0, - tag, comm_ptr, MPI_STATUS_IGNORE); + tag, comm_ptr, MPIR_SUBGROUP_NONE, MPI_STATUS_IGNORE); /* Convert the bitarray into a group */ *failed_group = bitarray_to_group(comm_ptr, remote_bitarray); diff --git a/src/mpid/ch3/src/mpid_startall.c b/src/mpid/ch3/src/mpid_startall.c index cba93847a43..034d0c95bb1 100644 --- a/src/mpid/ch3/src/mpid_startall.c +++ b/src/mpid/ch3/src/mpid_startall.c @@ -317,12 +317,12 @@ int MPID_Recv_init(void * buf, MPI_Aint count, MPI_Datatype datatype, int rank, } int MPID_Bcast_init(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, - MPIR_Comm *comm_ptr, MPIR_Info* info_ptr, MPIR_Request **request) + MPIR_Comm *comm_ptr, int coll_group, MPIR_Info* info_ptr, MPIR_Request **request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Bcast_init_impl(buffer, count, datatype, root, comm_ptr, info_ptr, request); + mpi_errno = MPIR_Bcast_init_impl(buffer, count, datatype, root, comm_ptr, coll_group, info_ptr, request); MPIR_ERR_CHECK(mpi_errno); MPIDI_Request_set_type(*request, MPIDI_REQUEST_TYPE_PERSISTENT_COLL); @@ -334,13 +334,13 @@ int MPID_Bcast_init(void *buffer, MPI_Aint count, MPI_Datatype datatype, int roo } int MPID_Allreduce_init(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm_ptr, MPIR_Info * info_ptr, + MPI_Op op, MPIR_Comm * comm_ptr, int coll_group, MPIR_Info * info_ptr, MPIR_Request ** request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Allreduce_init_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, info_ptr, + mpi_errno = MPIR_Allreduce_init_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, coll_group, info_ptr, request); MPIR_ERR_CHECK(mpi_errno); MPIDI_Request_set_type(*request, MPIDI_REQUEST_TYPE_PERSISTENT_COLL); @@ -353,13 +353,13 @@ int MPID_Allreduce_init(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_ } int MPID_Reduce_init(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, - MPI_Op op, int root, MPIR_Comm *comm_ptr, MPIR_Info* info_ptr, + MPI_Op op, int root, MPIR_Comm *comm_ptr, int coll_group, MPIR_Info* info_ptr, MPIR_Request **request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Reduce_init_impl(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, + mpi_errno = MPIR_Reduce_init_impl(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, coll_group, info_ptr, request); MPIR_ERR_CHECK(mpi_errno); MPIDI_Request_set_type(*request, MPIDI_REQUEST_TYPE_PERSISTENT_COLL); @@ -373,13 +373,13 @@ int MPID_Reduce_init(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Dat int MPID_Alltoall_init(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm *comm_ptr, MPIR_Info* info_ptr, MPIR_Request** request) + MPIR_Comm *comm_ptr, int coll_group, MPIR_Info* info_ptr, MPIR_Request** request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; mpi_errno = MPIR_Alltoall_init_impl(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, - comm_ptr, info_ptr, request); + comm_ptr, coll_group, info_ptr, request); MPIR_ERR_CHECK(mpi_errno); MPIDI_Request_set_type(*request, MPIDI_REQUEST_TYPE_PERSISTENT_COLL); @@ -392,14 +392,14 @@ int MPID_Alltoall_init(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sen int MPID_Alltoallv_init(const void *sendbuf, const MPI_Aint sendcounts[], const MPI_Aint sdispls[], MPI_Datatype sendtype, void *recvbuf, const MPI_Aint recvcounts[], - const MPI_Aint rdispls[], MPI_Datatype recvtype, MPIR_Comm * comm_ptr, + const MPI_Aint rdispls[], MPI_Datatype recvtype, MPIR_Comm * comm_ptr, int coll_group, MPIR_Info * info_ptr, MPIR_Request ** request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; mpi_errno = MPIR_Alltoallv_init_impl(sendbuf, sendcounts, sdispls, sendtype, recvbuf, - recvcounts, rdispls, recvtype, comm_ptr, info_ptr, + recvcounts, rdispls, recvtype, comm_ptr, coll_group, info_ptr, request); MPIR_ERR_CHECK(mpi_errno); MPIDI_Request_set_type(*request, MPIDI_REQUEST_TYPE_PERSISTENT_COLL); @@ -414,13 +414,13 @@ int MPID_Alltoallv_init(const void *sendbuf, const MPI_Aint sendcounts[], const int MPID_Alltoallw_init(const void *sendbuf, const MPI_Aint sendcounts[], const MPI_Aint sdispls[], const MPI_Datatype sendtypes[], void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint rdispls[], const MPI_Datatype recvtypes[], - MPIR_Comm * comm_ptr, MPIR_Info * info_ptr, MPIR_Request ** request) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Info * info_ptr, MPIR_Request ** request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; mpi_errno = MPIR_Alltoallw_init_impl(sendbuf, sendcounts, sdispls, sendtypes, recvbuf, - recvcounts, rdispls, recvtypes, comm_ptr, info_ptr, + recvcounts, rdispls, recvtypes, comm_ptr, coll_group, info_ptr, request); MPIR_ERR_CHECK(mpi_errno); MPIDI_Request_set_type(*request, MPIDI_REQUEST_TYPE_PERSISTENT_COLL); @@ -434,13 +434,13 @@ int MPID_Alltoallw_init(const void *sendbuf, const MPI_Aint sendcounts[], const int MPID_Allgather_init(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm *comm_ptr, MPIR_Info* info_ptr, MPIR_Request** request) + MPIR_Comm *comm_ptr, int coll_group, MPIR_Info* info_ptr, MPIR_Request** request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; mpi_errno = MPIR_Allgather_init_impl(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, - comm_ptr, info_ptr, request); + comm_ptr, coll_group, info_ptr, request); MPIR_ERR_CHECK(mpi_errno); MPIDI_Request_set_type(*request, MPIDI_REQUEST_TYPE_PERSISTENT_COLL); @@ -453,14 +453,14 @@ int MPID_Allgather_init(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype se int MPID_Allgatherv_init(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, - MPI_Datatype recvtype, MPIR_Comm *comm_ptr, MPIR_Info* info_ptr, + MPI_Datatype recvtype, MPIR_Comm *comm_ptr, int coll_group, MPIR_Info* info_ptr, MPIR_Request** request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; mpi_errno = MPIR_Allgatherv_init_impl(sendbuf, sendcount, sendtype, recvbuf, recvcounts, - displs, recvtype, comm_ptr, info_ptr, request); + displs, recvtype, comm_ptr, coll_group, info_ptr, request); MPIR_ERR_CHECK(mpi_errno); MPIDI_Request_set_type(*request, MPIDI_REQUEST_TYPE_PERSISTENT_COLL); @@ -472,13 +472,13 @@ int MPID_Allgatherv_init(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype s } int MPID_Reduce_scatter_block_init(const void *sendbuf, void *recvbuf, MPI_Aint recvcount, - MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, + MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Info * info, MPIR_Request ** request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Reduce_scatter_block_init_impl(sendbuf, recvbuf, recvcount, datatype, op, comm, + mpi_errno = MPIR_Reduce_scatter_block_init_impl(sendbuf, recvbuf, recvcount, datatype, op, comm, coll_group, info, request); MPIR_ERR_CHECK(mpi_errno); MPIDI_Request_set_type(*request, MPIDI_REQUEST_TYPE_PERSISTENT_COLL); @@ -491,13 +491,13 @@ int MPID_Reduce_scatter_block_init(const void *sendbuf, void *recvbuf, MPI_Aint } int MPID_Reduce_scatter_init(const void *sendbuf, void *recvbuf, const MPI_Aint recvcounts[], - MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, MPIR_Info * info, + MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Info * info, MPIR_Request ** request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Reduce_scatter_init_impl(sendbuf, recvbuf, recvcounts, datatype, op, comm, + mpi_errno = MPIR_Reduce_scatter_init_impl(sendbuf, recvbuf, recvcounts, datatype, op, comm, coll_group, info, request); MPIR_ERR_CHECK(mpi_errno); MPIDI_Request_set_type(*request, MPIDI_REQUEST_TYPE_PERSISTENT_COLL); @@ -510,12 +510,12 @@ int MPID_Reduce_scatter_init(const void *sendbuf, void *recvbuf, const MPI_Aint } int MPID_Scan_init(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm, MPIR_Info * info, MPIR_Request ** request) + MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Info * info, MPIR_Request ** request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Scan_init_impl(sendbuf, recvbuf, count, datatype, op, comm, info, request); + mpi_errno = MPIR_Scan_init_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, info, request); MPIR_ERR_CHECK(mpi_errno); MPIDI_Request_set_type(*request, MPIDI_REQUEST_TYPE_PERSISTENT_COLL); @@ -527,14 +527,14 @@ int MPID_Scan_init(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datat } int MPID_Gather_init(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, - MPI_Aint recvcount, MPI_Datatype recvtype, int root, MPIR_Comm * comm, + MPI_Aint recvcount, MPI_Datatype recvtype, int root, MPIR_Comm * comm, int coll_group, MPIR_Info * info, MPIR_Request ** request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; mpi_errno = MPIR_Gather_init_impl(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, - root, comm, info, request); + root, comm, coll_group, info, request); MPIR_ERR_CHECK(mpi_errno); MPIDI_Request_set_type(*request, MPIDI_REQUEST_TYPE_PERSISTENT_COLL); @@ -547,13 +547,13 @@ int MPID_Gather_init(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendt int MPID_Gatherv_init(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint displs[], MPI_Datatype recvtype, - int root, MPIR_Comm * comm, MPIR_Info * info, MPIR_Request ** request) + int root, MPIR_Comm * comm, int coll_group, MPIR_Info * info, MPIR_Request ** request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; mpi_errno = MPIR_Gatherv_init_impl(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, - recvtype, root, comm, info, request); + recvtype, root, comm, coll_group, info, request); MPIR_ERR_CHECK(mpi_errno); MPIDI_Request_set_type(*request, MPIDI_REQUEST_TYPE_PERSISTENT_COLL); @@ -565,14 +565,14 @@ int MPID_Gatherv_init(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype send } int MPID_Scatter_init(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, - MPI_Aint recvcount, MPI_Datatype recvtype, int root, MPIR_Comm * comm, + MPI_Aint recvcount, MPI_Datatype recvtype, int root, MPIR_Comm * comm, int coll_group, MPIR_Info * info, MPIR_Request ** request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; mpi_errno = MPIR_Scatter_init_impl(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, - root, comm, info, request); + root, comm, coll_group, info, request); MPIR_ERR_CHECK(mpi_errno); MPIDI_Request_set_type(*request, MPIDI_REQUEST_TYPE_PERSISTENT_COLL); @@ -585,14 +585,14 @@ int MPID_Scatter_init(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype send int MPID_Scatterv_init(const void *sendbuf, const MPI_Aint sendcounts[], const MPI_Aint displs[], MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, - MPI_Datatype recvtype, int root, MPIR_Comm * comm, MPIR_Info * info, + MPI_Datatype recvtype, int root, MPIR_Comm * comm, int coll_group, MPIR_Info * info, MPIR_Request ** request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; mpi_errno = MPIR_Scatterv_init_impl(sendbuf, sendcounts, displs, sendtype, recvbuf, recvcount, - recvtype, root, comm, info, request); + recvtype, root, comm, coll_group, info, request); MPIR_ERR_CHECK(mpi_errno); MPIDI_Request_set_type(*request, MPIDI_REQUEST_TYPE_PERSISTENT_COLL); @@ -603,12 +603,12 @@ int MPID_Scatterv_init(const void *sendbuf, const MPI_Aint sendcounts[], const M goto fn_exit; } -int MPID_Barrier_init(MPIR_Comm * comm, MPIR_Info * info, MPIR_Request ** request) +int MPID_Barrier_init(MPIR_Comm * comm, int coll_group, MPIR_Info * info, MPIR_Request ** request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Barrier_init_impl(comm, info, request); + mpi_errno = MPIR_Barrier_init_impl(comm, coll_group, info, request); MPIR_ERR_CHECK(mpi_errno); MPIDI_Request_set_type(*request, MPIDI_REQUEST_TYPE_PERSISTENT_COLL); @@ -620,12 +620,12 @@ int MPID_Barrier_init(MPIR_Comm * comm, MPIR_Info * info, MPIR_Request ** reques } int MPID_Exscan_init(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm, MPIR_Info * info, MPIR_Request ** request) + MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Info * info, MPIR_Request ** request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Exscan_init_impl(sendbuf, recvbuf, count, datatype, op, comm, info, request); + mpi_errno = MPIR_Exscan_init_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, info, request); MPIR_ERR_CHECK(mpi_errno); MPIDI_Request_set_type(*request, MPIDI_REQUEST_TYPE_PERSISTENT_COLL); diff --git a/src/mpid/ch3/src/mpid_vc.c b/src/mpid/ch3/src/mpid_vc.c index 81cb71c91e6..3fd867ab269 100644 --- a/src/mpid/ch3/src/mpid_vc.c +++ b/src/mpid/ch3/src/mpid_vc.c @@ -492,7 +492,7 @@ int MPID_Intercomm_exchange_map(MPIR_Comm *local_comm_ptr, int local_leader, remote_leader, cts_tag, remote_size, 1, MPI_INT, remote_leader, cts_tag, - peer_comm_ptr, MPI_STATUS_IGNORE, MPIR_ERR_NONE ); + peer_comm_ptr, MPIR_SUBGROUP_NONE, MPI_STATUS_IGNORE, MPIR_ERR_NONE ); MPIR_ERR_CHECK(mpi_errno); MPL_DBG_MSG_FMT(MPIDI_CH3_DBG_OTHER,VERBOSE,(MPL_DBG_FDEST, "local size = %d, remote size = %d", local_size, @@ -511,7 +511,7 @@ int MPID_Intercomm_exchange_map(MPIR_Comm *local_comm_ptr, int local_leader, mpi_errno = MPIC_Sendrecv( local_gpids, local_size*sizeof(MPIDI_Gpid), MPI_BYTE, remote_leader, cts_tag, remote_gpids, (*remote_size)*sizeof(MPIDI_Gpid), MPI_BYTE, - remote_leader, cts_tag, peer_comm_ptr, + remote_leader, cts_tag, peer_comm_ptr, MPIR_SUBGROUP_NONE, MPI_STATUS_IGNORE, MPIR_ERR_NONE ); MPIR_ERR_CHECK(mpi_errno); @@ -554,10 +554,10 @@ int MPID_Intercomm_exchange_map(MPIR_Comm *local_comm_ptr, int local_leader, comm_info[0] = *remote_size; comm_info[1] = *is_low_group; MPL_DBG_MSG(MPIDI_CH3_DBG_OTHER,VERBOSE,"About to bcast on local_comm"); - mpi_errno = MPIR_Bcast( comm_info, 2, MPI_INT, local_leader, local_comm_ptr, MPIR_ERR_NONE ); + mpi_errno = MPIR_Bcast( comm_info, 2, MPI_INT, local_leader, local_comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE ); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_Bcast( remote_gpids, (*remote_size)*sizeof(MPIDI_Gpid), MPI_BYTE, local_leader, - local_comm_ptr, MPIR_ERR_NONE ); + local_comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE ); MPIR_ERR_CHECK(mpi_errno); MPL_DBG_MSG_D(MPIDI_CH3_DBG_OTHER,VERBOSE,"end of bcast on local_comm of size %d", local_comm_ptr->local_size ); @@ -566,13 +566,13 @@ int MPID_Intercomm_exchange_map(MPIR_Comm *local_comm_ptr, int local_leader, { /* we're the other processes */ MPL_DBG_MSG(MPIDI_CH3_DBG_OTHER,VERBOSE,"About to receive bcast on local_comm"); - mpi_errno = MPIR_Bcast( comm_info, 2, MPI_INT, local_leader, local_comm_ptr, MPIR_ERR_NONE ); + mpi_errno = MPIR_Bcast( comm_info, 2, MPI_INT, local_leader, local_comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE ); MPIR_ERR_CHECK(mpi_errno); *remote_size = comm_info[0]; MPIR_CHKLMEM_MALLOC(remote_gpids,MPIDI_Gpid*,(*remote_size)*sizeof(MPIDI_Gpid), mpi_errno,"remote_gpids", MPL_MEM_DYNAMIC); *remote_lpids = (uint64_t*) MPL_malloc((*remote_size)*sizeof(uint64_t), MPL_MEM_ADDRESS); mpi_errno = MPIR_Bcast( remote_gpids, (*remote_size)*sizeof(MPIDI_Gpid), MPI_BYTE, local_leader, - local_comm_ptr, MPIR_ERR_NONE ); + local_comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE ); MPIR_ERR_CHECK(mpi_errno); /* Extract the context and group sign information */ @@ -736,7 +736,7 @@ int MPIDI_PG_ForwardPGInfo( MPIR_Comm *peer_ptr, MPIR_Comm *comm_ptr, } /* See if everyone is happy */ - mpi_errno = MPIR_Allreduce( MPI_IN_PLACE, &allfound, 1, MPI_INT, MPI_LAND, comm_ptr, MPIR_ERR_NONE ); + mpi_errno = MPIR_Allreduce( MPI_IN_PLACE, &allfound, 1, MPI_INT, MPI_LAND, comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE ); MPIR_ERR_CHECK(mpi_errno); if (allfound) return MPI_SUCCESS; diff --git a/src/mpid/ch3/src/mpidi_rma.c b/src/mpid/ch3/src/mpidi_rma.c index 8d86d47ff91..2715e87d8b7 100644 --- a/src/mpid/ch3/src/mpidi_rma.c +++ b/src/mpid/ch3/src/mpidi_rma.c @@ -164,7 +164,7 @@ int MPID_Win_free(MPIR_Win ** win_ptr) MPIR_ERR_CHECK(mpi_errno); } - mpi_errno = MPIR_Barrier((*win_ptr)->comm_ptr, MPIR_ERR_NONE); + mpi_errno = MPIR_Barrier((*win_ptr)->comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); /* Free window resources in lower layer. */ diff --git a/src/mpid/ch4/ch4_api.txt b/src/mpid/ch4/ch4_api.txt index 96ace1800b8..dbff9925425 100644 --- a/src/mpid/ch4/ch4_api.txt +++ b/src/mpid/ch4/ch4_api.txt @@ -275,56 +275,56 @@ Native API: rank_is_local : int NM*: target, comm mpi_barrier : int - NM*: comm, errflag - SHM*: comm, errflag + NM*: comm, coll_group, errflag + SHM*: comm, coll_group, errflag mpi_bcast : int - NM*: buffer, count, datatype, root, comm, errflag - SHM*: buffer, count, datatype, root, comm, errflag + NM*: buffer, count, datatype, root, comm, coll_group, errflag + SHM*: buffer, count, datatype, root, comm, coll_group, errflag mpi_allreduce : int - NM*: sendbuf, recvbuf, count, datatype, op, comm, errflag - SHM*: sendbuf, recvbuf, count, datatype, op, comm, errflag + NM*: sendbuf, recvbuf, count, datatype, op, comm, coll_group, errflag + SHM*: sendbuf, recvbuf, count, datatype, op, comm, coll_group, errflag mpi_allgather : int - NM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, errflag - SHM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, errflag + NM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, coll_group, errflag + SHM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, coll_group, errflag mpi_allgatherv : int - NM*: sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm, errflag - SHM*: sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm, errflag + NM*: sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm, coll_group, errflag + SHM*: sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm, coll_group, errflag mpi_scatter : int - NM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm, errflag - SHM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm, errflag + NM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm, coll_group, errflag + SHM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm, coll_group, errflag mpi_scatterv : int - NM*: sendbuf, sendcounts, displs, sendtype, recvbuf, recvcount, recvtype, root, comm_ptr, errflag - SHM*: sendbuf, sendcounts, displs, sendtype, recvbuf, recvcount, recvtype, root, comm_ptr, errflag + NM*: sendbuf, sendcounts, displs, sendtype, recvbuf, recvcount, recvtype, root, comm, coll_group, errflag + SHM*: sendbuf, sendcounts, displs, sendtype, recvbuf, recvcount, recvtype, root, comm, coll_group, errflag mpi_gather : int - NM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm, errflag - SHM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm, errflag + NM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm, coll_group, errflag + SHM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm, coll_group, errflag mpi_gatherv : int - NM*: sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, root, comm, errflag - SHM*: sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, root, comm, errflag + NM*: sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, root, comm, coll_group, errflag + SHM*: sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, root, comm, coll_group, errflag mpi_alltoall : int - NM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, errflag - SHM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, errflag + NM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, coll_group, errflag + SHM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, coll_group, errflag mpi_alltoallv : int - NM*: sendbuf, sendcounts, sdispls, sendtype, recvbuf, recvcounts, rdispls, recvtype, comm, errflag - SHM*: sendbuf, sendcounts, sdispls, sendtype, recvbuf, recvcounts, rdispls, recvtype, comm, errflag + NM*: sendbuf, sendcounts, sdispls, sendtype, recvbuf, recvcounts, rdispls, recvtype, comm, coll_group, errflag + SHM*: sendbuf, sendcounts, sdispls, sendtype, recvbuf, recvcounts, rdispls, recvtype, comm, coll_group, errflag mpi_alltoallw : int - NM*: sendbuf, sendcounts, sdispls, sendtypes, recvbuf, recvcounts, rdispls, recvtypes, comm, errflag - SHM*: sendbuf, sendcounts, sdispls, sendtypes, recvbuf, recvcounts, rdispls, recvtypes, comm, errflag + NM*: sendbuf, sendcounts, sdispls, sendtypes, recvbuf, recvcounts, rdispls, recvtypes, comm, coll_group, errflag + SHM*: sendbuf, sendcounts, sdispls, sendtypes, recvbuf, recvcounts, rdispls, recvtypes, comm, coll_group, errflag mpi_reduce : int - NM*: sendbuf, recvbuf, count, datatype, op, root, comm_ptr, errflag - SHM*: sendbuf, recvbuf, count, datatype, op, root, comm_ptr, errflag + NM*: sendbuf, recvbuf, count, datatype, op, root, comm, coll_group, errflag + SHM*: sendbuf, recvbuf, count, datatype, op, root, comm, coll_group, errflag mpi_reduce_scatter : int - NM*: sendbuf, recvbuf, recvcounts, datatype, op, comm_ptr, errflag - SHM*: sendbuf, recvbuf, recvcounts, datatype, op, comm_ptr, errflag + NM*: sendbuf, recvbuf, recvcounts, datatype, op, comm, coll_group, errflag + SHM*: sendbuf, recvbuf, recvcounts, datatype, op, comm, coll_group, errflag mpi_reduce_scatter_block : int - NM*: sendbuf, recvbuf, recvcount, datatype, op, comm_ptr, errflag - SHM*: sendbuf, recvbuf, recvcount, datatype, op, comm_ptr, errflag + NM*: sendbuf, recvbuf, recvcount, datatype, op, comm, coll_group, errflag + SHM*: sendbuf, recvbuf, recvcount, datatype, op, comm, coll_group, errflag mpi_scan : int - NM*: sendbuf, recvbuf, count, datatype, op, comm, errflag - SHM*: sendbuf, recvbuf, count, datatype, op, comm, errflag + NM*: sendbuf, recvbuf, count, datatype, op, comm, coll_group, errflag + SHM*: sendbuf, recvbuf, count, datatype, op, comm, coll_group, errflag mpi_exscan : int - NM*: sendbuf, recvbuf, count, datatype, op, comm, errflag - SHM*: sendbuf, recvbuf, count, datatype, op, comm, errflag + NM*: sendbuf, recvbuf, count, datatype, op, comm, coll_group, errflag + SHM*: sendbuf, recvbuf, count, datatype, op, comm, coll_group, errflag mpi_neighbor_allgather : int NM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm SHM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm @@ -356,56 +356,56 @@ Native API: NM*: sendbuf, sendcounts, sdispls-2, sendtypes, recvbuf, recvcounts, rdispls-2, recvtypes, comm, req_p SHM*: sendbuf, sendcounts, sdispls-2, sendtypes, recvbuf, recvcounts, rdispls-2, recvtypes, comm, req_p mpi_ibarrier : int - NM*: comm, req_p - SHM*: comm, req_p + NM*: comm, coll_group, req_p + SHM*: comm, coll_group, req_p mpi_ibcast : int - NM*: buffer, count, datatype, root, comm, req_p - SHM*: buffer, count, datatype, root, comm, req_p + NM*: buffer, count, datatype, root, comm, coll_group, req_p + SHM*: buffer, count, datatype, root, comm, coll_group, req_p mpi_iallgather : int - NM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, req_p - SHM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, req_p + NM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, coll_group, req_p + SHM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, coll_group, req_p mpi_iallgatherv : int - NM*: sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm, req_p - SHM*: sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm, req_p + NM*: sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm, coll_group, req_p + SHM*: sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm, coll_group, req_p mpi_iallreduce : int - NM*: sendbuf, recvbuf, count, datatype, op, comm, req_p - SHM*: sendbuf, recvbuf, count, datatype, op, comm, req_p + NM*: sendbuf, recvbuf, count, datatype, op, comm, coll_group, req_p + SHM*: sendbuf, recvbuf, count, datatype, op, comm, coll_group, req_p mpi_ialltoall : int - NM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, req_p - SHM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, req_p + NM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, coll_group, req_p + SHM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, coll_group, req_p mpi_ialltoallv : int - NM*: sendbuf, sendcounts, sdispls, sendtype, recvbuf, recvcounts, rdispls, recvtype, comm, req_p - SHM*: sendbuf, sendcounts, sdispls, sendtype, recvbuf, recvcounts, rdispls, recvtype, comm, req_p + NM*: sendbuf, sendcounts, sdispls, sendtype, recvbuf, recvcounts, rdispls, recvtype, comm, coll_group, req_p + SHM*: sendbuf, sendcounts, sdispls, sendtype, recvbuf, recvcounts, rdispls, recvtype, comm, coll_group, req_p mpi_ialltoallw : int - NM*: sendbuf, sendcounts, sdispls, sendtypes, recvbuf, recvcounts, rdispls, recvtypes, comm, req_p - SHM*: sendbuf, sendcounts, sdispls, sendtypes, recvbuf, recvcounts, rdispls, recvtypes, comm, req_p + NM*: sendbuf, sendcounts, sdispls, sendtypes, recvbuf, recvcounts, rdispls, recvtypes, comm, coll_group, req_p + SHM*: sendbuf, sendcounts, sdispls, sendtypes, recvbuf, recvcounts, rdispls, recvtypes, comm, coll_group, req_p mpi_iexscan : int - NM*: sendbuf, recvbuf, count, datatype, op, comm, req_p - SHM*: sendbuf, recvbuf, count, datatype, op, comm, req_p + NM*: sendbuf, recvbuf, count, datatype, op, comm, coll_group, req_p + SHM*: sendbuf, recvbuf, count, datatype, op, comm, coll_group, req_p mpi_igather : int - NM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm, req_p - SHM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm, req_p + NM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm, coll_group, req_p + SHM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm, coll_group, req_p mpi_igatherv : int - NM*: sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, root, comm, req_p - SHM*: sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, root, comm, req_p + NM*: sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, root, comm, coll_group, req_p + SHM*: sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, root, comm, coll_group, req_p mpi_ireduce_scatter_block : int - NM*: sendbuf, recvbuf, recvcount, datatype, op, comm, req_p - SHM*: sendbuf, recvbuf, recvcount, datatype, op, comm, req_p + NM*: sendbuf, recvbuf, recvcount, datatype, op, comm, coll_group, req_p + SHM*: sendbuf, recvbuf, recvcount, datatype, op, comm, coll_group, req_p mpi_ireduce_scatter : int - NM*: sendbuf, recvbuf, recvcounts, datatype, op, comm, req_p - SHM*: sendbuf, recvbuf, recvcounts, datatype, op, comm, req_p + NM*: sendbuf, recvbuf, recvcounts, datatype, op, comm, coll_group, req_p + SHM*: sendbuf, recvbuf, recvcounts, datatype, op, comm, coll_group, req_p mpi_ireduce : int - NM*: sendbuf, recvbuf, count, datatype, op, root, comm_ptr, req_p - SHM*: sendbuf, recvbuf, count, datatype, op, root, comm_ptr, req_p + NM*: sendbuf, recvbuf, count, datatype, op, root, comm, coll_group, req_p + SHM*: sendbuf, recvbuf, count, datatype, op, root, comm, coll_group, req_p mpi_iscan : int - NM*: sendbuf, recvbuf, count, datatype, op, comm, req_p - SHM*: sendbuf, recvbuf, count, datatype, op, comm, req_p + NM*: sendbuf, recvbuf, count, datatype, op, comm, coll_group, req_p + SHM*: sendbuf, recvbuf, count, datatype, op, comm, coll_group, req_p mpi_iscatter : int - NM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm, req_p - SHM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm, req_p + NM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm, coll_group, req_p + SHM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm, coll_group, req_p mpi_iscatterv : int - NM*: sendbuf, sendcounts, displs, sendtype, recvbuf, recvcount, recvtype, root, comm_ptr, req_p - SHM*: sendbuf, sendcounts, displs, sendtype, recvbuf, recvcount, recvtype, root, comm_ptr, req_p + NM*: sendbuf, sendcounts, displs, sendtype, recvbuf, recvcount, recvtype, root, comm, coll_group, req_p + SHM*: sendbuf, sendcounts, displs, sendtype, recvbuf, recvcount, recvtype, root, comm, coll_group, req_p mpi_type_commit_hook : int NM : datatype_p SHM : type @@ -440,6 +440,7 @@ PARAM: buf: const void * buf-2: void * buffer: void * + coll_group: int comm: MPIR_Comm * comm_ptr: MPIR_Comm * compare_addr: const void * diff --git a/src/mpid/ch4/include/mpidch4.h b/src/mpid/ch4/include/mpidch4.h index 3dd3528efbc..4fc7ad8e999 100644 --- a/src/mpid/ch4/include/mpidch4.h +++ b/src/mpid/ch4/include/mpidch4.h @@ -175,52 +175,62 @@ int MPID_Comm_set_hints(MPIR_Comm *, MPIR_Info *); int MPID_Comm_commit_post_hook(MPIR_Comm *); int MPID_Stream_create_hook(MPIR_Stream * stream); int MPID_Stream_free_hook(MPIR_Stream * stream); -MPL_STATIC_INLINE_PREFIX int MPID_Barrier(MPIR_Comm *, MPIR_Errflag_t) MPL_STATIC_INLINE_SUFFIX; +MPL_STATIC_INLINE_PREFIX int MPID_Barrier(MPIR_Comm *, int coll_group, + MPIR_Errflag_t) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Bcast(void *, MPI_Aint, MPI_Datatype, int, MPIR_Comm *, - MPIR_Errflag_t) MPL_STATIC_INLINE_SUFFIX; + int coll_group, MPIR_Errflag_t) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Allreduce(const void *, void *, MPI_Aint, MPI_Datatype, MPI_Op, - MPIR_Comm *, MPIR_Errflag_t) MPL_STATIC_INLINE_SUFFIX; + MPIR_Comm *, int coll_group, + MPIR_Errflag_t) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Allgather(const void *, MPI_Aint, MPI_Datatype, void *, MPI_Aint, - MPI_Datatype, MPIR_Comm *, + MPI_Datatype, MPIR_Comm *, int coll_group, MPIR_Errflag_t) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Allgatherv(const void *, MPI_Aint, MPI_Datatype, void *, const MPI_Aint *, const MPI_Aint *, MPI_Datatype, - MPIR_Comm *, MPIR_Errflag_t) MPL_STATIC_INLINE_SUFFIX; + MPIR_Comm *, int coll_group, + MPIR_Errflag_t) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Scatter(const void *, MPI_Aint, MPI_Datatype, void *, MPI_Aint, - MPI_Datatype, int, MPIR_Comm *, + MPI_Datatype, int, MPIR_Comm *, int coll_group, MPIR_Errflag_t) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Scatterv(const void *, const MPI_Aint *, const MPI_Aint *, MPI_Datatype, void *, MPI_Aint, MPI_Datatype, int, - MPIR_Comm *, MPIR_Errflag_t) MPL_STATIC_INLINE_SUFFIX; + MPIR_Comm *, int coll_group, + MPIR_Errflag_t) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Gather(const void *, MPI_Aint, MPI_Datatype, void *, MPI_Aint, - MPI_Datatype, int, MPIR_Comm *, + MPI_Datatype, int, MPIR_Comm *, int coll_group, MPIR_Errflag_t) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Gatherv(const void *, MPI_Aint, MPI_Datatype, void *, const MPI_Aint *, const MPI_Aint *, MPI_Datatype, int, - MPIR_Comm *, MPIR_Errflag_t) MPL_STATIC_INLINE_SUFFIX; + MPIR_Comm *, int coll_group, + MPIR_Errflag_t) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Alltoall(const void *, MPI_Aint, MPI_Datatype, void *, MPI_Aint, - MPI_Datatype, MPIR_Comm *, + MPI_Datatype, MPIR_Comm *, int coll_group, MPIR_Errflag_t) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Alltoallv(const void *, const MPI_Aint *, const MPI_Aint *, MPI_Datatype, void *, const MPI_Aint *, const MPI_Aint *, MPI_Datatype, MPIR_Comm *, + int coll_group, MPIR_Errflag_t) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Alltoallw(const void *, const MPI_Aint[], const MPI_Aint[], const MPI_Datatype[], void *, const MPI_Aint[], const MPI_Aint[], const MPI_Datatype[], MPIR_Comm *, + int coll_group, MPIR_Errflag_t) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Reduce(const void *, void *, MPI_Aint, MPI_Datatype, MPI_Op, int, - MPIR_Comm *, MPIR_Errflag_t) MPL_STATIC_INLINE_SUFFIX; + MPIR_Comm *, int coll_group, + MPIR_Errflag_t) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Reduce_scatter(const void *, void *, const MPI_Aint[], - MPI_Datatype, MPI_Op, MPIR_Comm *, + MPI_Datatype, MPI_Op, MPIR_Comm *, int coll_group, MPIR_Errflag_t) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Reduce_scatter_block(const void *, void *, MPI_Aint, MPI_Datatype, - MPI_Op, MPIR_Comm *, + MPI_Op, MPIR_Comm *, int coll_group, MPIR_Errflag_t) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Scan(const void *, void *, MPI_Aint, MPI_Datatype, MPI_Op, - MPIR_Comm *, MPIR_Errflag_t) MPL_STATIC_INLINE_SUFFIX; + MPIR_Comm *, int coll_group, + MPIR_Errflag_t) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Exscan(const void *, void *, MPI_Aint, MPI_Datatype, MPI_Op, - MPIR_Comm *, MPIR_Errflag_t) MPL_STATIC_INLINE_SUFFIX; + MPIR_Comm *, int coll_group, + MPIR_Errflag_t) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Neighbor_allgather(const void *, MPI_Aint, MPI_Datatype, void *, MPI_Aint, MPI_Datatype, MPIR_Comm *) MPL_STATIC_INLINE_SUFFIX; @@ -261,118 +271,119 @@ MPL_STATIC_INLINE_PREFIX int MPID_Ineighbor_alltoallw(const void *, const MPI_Ai void *, const MPI_Aint[], const MPI_Aint[], const MPI_Datatype[], MPIR_Comm *, MPIR_Request **) MPL_STATIC_INLINE_SUFFIX; -MPL_STATIC_INLINE_PREFIX int MPID_Ibarrier(MPIR_Comm *, MPIR_Request **) MPL_STATIC_INLINE_SUFFIX; +MPL_STATIC_INLINE_PREFIX int MPID_Ibarrier(MPIR_Comm *, int coll_group, + MPIR_Request **) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Ibcast(void *, MPI_Aint, MPI_Datatype, int, MPIR_Comm *, - MPIR_Request **) MPL_STATIC_INLINE_SUFFIX; + int coll_group, MPIR_Request **) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Iallgather(const void *, MPI_Aint, MPI_Datatype, void *, MPI_Aint, - MPI_Datatype, MPIR_Comm *, + MPI_Datatype, MPIR_Comm *, int coll_group, MPIR_Request **) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Iallgatherv(const void *, MPI_Aint, MPI_Datatype, void *, const MPI_Aint *, const MPI_Aint *, MPI_Datatype, - MPIR_Comm *, + MPIR_Comm *, int coll_group, MPIR_Request **) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Iallreduce(const void *, void *, MPI_Aint, MPI_Datatype, MPI_Op, - MPIR_Comm *, MPIR_Request **) MPL_STATIC_INLINE_SUFFIX; + MPIR_Comm *, int coll_group, + MPIR_Request **) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Ialltoall(const void *, MPI_Aint, MPI_Datatype, void *, MPI_Aint, - MPI_Datatype, MPIR_Comm *, + MPI_Datatype, MPIR_Comm *, int coll_group, MPIR_Request **) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Ialltoallv(const void *, const MPI_Aint[], const MPI_Aint[], MPI_Datatype, void *, const MPI_Aint[], const MPI_Aint[], MPI_Datatype, MPIR_Comm *, + int coll_group, MPIR_Request **) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Ialltoallw(const void *, const MPI_Aint[], const MPI_Aint[], const MPI_Datatype[], void *, const MPI_Aint[], const MPI_Aint[], const MPI_Datatype[], MPIR_Comm *, + int coll_group, MPIR_Request **) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Iexscan(const void *, void *, MPI_Aint, MPI_Datatype, MPI_Op, - MPIR_Comm *, MPIR_Request **) MPL_STATIC_INLINE_SUFFIX; + MPIR_Comm *, int coll_group, + MPIR_Request **) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Igather(const void *, MPI_Aint, MPI_Datatype, void *, MPI_Aint, - MPI_Datatype, int, MPIR_Comm *, + MPI_Datatype, int, MPIR_Comm *, int coll_group, MPIR_Request **) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Igatherv(const void *, MPI_Aint, MPI_Datatype, void *, const MPI_Aint *, const MPI_Aint *, MPI_Datatype, int, - MPIR_Comm *, MPIR_Request **) MPL_STATIC_INLINE_SUFFIX; + MPIR_Comm *, int coll_group, + MPIR_Request **) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Ireduce_scatter_block(const void *, void *, MPI_Aint, MPI_Datatype, MPI_Op, MPIR_Comm *, + int coll_group, MPIR_Request **) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Ireduce_scatter(const void *, void *, const MPI_Aint[], - MPI_Datatype, MPI_Op, MPIR_Comm *, + MPI_Datatype, MPI_Op, MPIR_Comm *, int coll_group, MPIR_Request **) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Ireduce(const void *, void *, MPI_Aint, MPI_Datatype, MPI_Op, int, - MPIR_Comm *, MPIR_Request **) MPL_STATIC_INLINE_SUFFIX; + MPIR_Comm *, int coll_group, + MPIR_Request **) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Iscan(const void *, void *, MPI_Aint, MPI_Datatype, MPI_Op, - MPIR_Comm *, MPIR_Request **) MPL_STATIC_INLINE_SUFFIX; + MPIR_Comm *, int coll_group, + MPIR_Request **) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Iscatter(const void *, MPI_Aint, MPI_Datatype, void *, MPI_Aint, - MPI_Datatype, int, MPIR_Comm *, + MPI_Datatype, int, MPIR_Comm *, int coll_group, MPIR_Request **) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Iscatterv(const void *, const MPI_Aint *, const MPI_Aint *, MPI_Datatype, void *, MPI_Aint, MPI_Datatype, int, - MPIR_Comm *, MPIR_Request **) MPL_STATIC_INLINE_SUFFIX; -int MPID_Bcast_init(void *buffer, MPI_Aint count, MPI_Datatype datatype, - int root, MPIR_Comm * comm_ptr, MPIR_Info * info_ptr, MPIR_Request ** request); -int MPID_Allreduce_init(const void *sendbuf, void *recvbuf, MPI_Aint count, - MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, MPIR_Info * info_ptr, MPIR_Request ** request); -int MPID_Reduce_init(const void *sendbuf, void *recvbuf, MPI_Aint count, - MPI_Datatype datatype, MPI_Op op, int root, - MPIR_Comm * comm_ptr, MPIR_Info * info_ptr, MPIR_Request ** request); -int MPID_Alltoall_init(const void *sendbuf, MPI_Aint sendcount, - MPI_Datatype sendtype, void *recvbuf, - MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Info * info_ptr, MPIR_Request ** request); -int MPID_Alltoallv_init(const void *sendbuf, const MPI_Aint sendcounts[], - const MPI_Aint sdispls[], MPI_Datatype sendtype, - void *recvbuf, const MPI_Aint recvcounts[], - const MPI_Aint rdispls[], MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Info * info_ptr, MPIR_Request ** request); -int MPID_Alltoallw_init(const void *sendbuf, const MPI_Aint sendcounts[], - const MPI_Aint sdispls[], - const MPI_Datatype sendtypes[], - void *recvbuf, const MPI_Aint recvcounts[], - const MPI_Aint rdispls[], - const MPI_Datatype recvtypes[], - MPIR_Comm * comm_ptr, MPIR_Info * info_ptr, MPIR_Request ** request); -int MPID_Allgather_init(const void *sendbuf, MPI_Aint sendcount, - MPI_Datatype sendtype, void *recvbuf, - MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Info * info_ptr, MPIR_Request ** request); -int MPID_Allgatherv_init(const void *sendbuf, MPI_Aint sendcount, - MPI_Datatype sendtype, void *recvbuf, - const MPI_Aint * recvcounts, - const MPI_Aint * displs, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Info * info_ptr, MPIR_Request ** request); -int MPID_Reduce_scatter_block_init(const void *sendbuf, void *recvbuf, - MPI_Aint recvcount, - MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm, MPIR_Info * info, MPIR_Request ** request); -int MPID_Reduce_scatter_init(const void *sendbuf, void *recvbuf, - const MPI_Aint recvcounts[], - MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm, MPIR_Info * info, MPIR_Request ** request); -int MPID_Scan_init(const void *sendbuf, void *recvbuf, MPI_Aint count, - MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, - MPIR_Info * info, MPIR_Request ** request); -int MPID_Gather_init(const void *sendbuf, MPI_Aint sendcount, - MPI_Datatype sendtype, void *recvbuf, - MPI_Aint recvcount, MPI_Datatype recvtype, int root, - MPIR_Comm * comm, MPIR_Info * info, MPIR_Request ** request); -int MPID_Gatherv_init(const void *sendbuf, MPI_Aint sendcount, - MPI_Datatype sendtype, void *recvbuf, - const MPI_Aint recvcounts[], const MPI_Aint displs[], - MPI_Datatype recvtype, int root, MPIR_Comm * comm, - MPIR_Info * info, MPIR_Request ** request); -int MPID_Scatter_init(const void *sendbuf, MPI_Aint sendcount, - MPI_Datatype sendtype, void *recvbuf, - MPI_Aint recvcount, MPI_Datatype recvtype, int root, - MPIR_Comm * comm, MPIR_Info * info, MPIR_Request ** request); -int MPID_Scatterv_init(const void *sendbuf, const MPI_Aint sendcounts[], - const MPI_Aint displs[], MPI_Datatype sendtype, - void *recvbuf, MPI_Aint recvcount, - MPI_Datatype recvtype, int root, MPIR_Comm * comm, + MPIR_Comm *, int coll_group, + MPIR_Request **) MPL_STATIC_INLINE_SUFFIX; +int MPID_Bcast_init(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, + MPIR_Comm * comm_ptr, int coll_group, MPIR_Info * info_ptr, + MPIR_Request ** request); +int MPID_Allreduce_init(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, + MPI_Op op, MPIR_Comm * comm_ptr, int coll_group, MPIR_Info * info_ptr, + MPIR_Request ** request); +int MPID_Reduce_init(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, + MPI_Op op, int root, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Info * info_ptr, MPIR_Request ** request); +int MPID_Alltoall_init(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, + void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, + MPIR_Comm * comm_ptr, int coll_group, MPIR_Info * info_ptr, + MPIR_Request ** request); +int MPID_Alltoallv_init(const void *sendbuf, const MPI_Aint sendcounts[], const MPI_Aint sdispls[], + MPI_Datatype sendtype, void *recvbuf, const MPI_Aint recvcounts[], + const MPI_Aint rdispls[], MPI_Datatype recvtype, MPIR_Comm * comm_ptr, + int coll_group, MPIR_Info * info_ptr, MPIR_Request ** request); +int MPID_Alltoallw_init(const void *sendbuf, const MPI_Aint sendcounts[], const MPI_Aint sdispls[], + const MPI_Datatype sendtypes[], void *recvbuf, const MPI_Aint recvcounts[], + const MPI_Aint rdispls[], const MPI_Datatype recvtypes[], + MPIR_Comm * comm_ptr, int coll_group, MPIR_Info * info_ptr, + MPIR_Request ** request); +int MPID_Allgather_init(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, + void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, + MPIR_Comm * comm_ptr, int coll_group, MPIR_Info * info_ptr, + MPIR_Request ** request); +int MPID_Allgatherv_init(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, + void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, + MPI_Datatype recvtype, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Info * info_ptr, MPIR_Request ** request); +int MPID_Reduce_scatter_block_init(const void *sendbuf, void *recvbuf, MPI_Aint recvcount, + MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, + int coll_group, MPIR_Info * info, MPIR_Request ** request); +int MPID_Reduce_scatter_init(const void *sendbuf, void *recvbuf, const MPI_Aint recvcounts[], + MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, int coll_group, + MPIR_Info * info, MPIR_Request ** request); +int MPID_Scan_init(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, + MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Info * info, + MPIR_Request ** request); +int MPID_Gather_init(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, + MPI_Aint recvcount, MPI_Datatype recvtype, int root, MPIR_Comm * comm, + int coll_group, MPIR_Info * info, MPIR_Request ** request); +int MPID_Gatherv_init(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, + const MPI_Aint recvcounts[], const MPI_Aint displs[], MPI_Datatype recvtype, + int root, MPIR_Comm * comm, int coll_group, MPIR_Info * info, + MPIR_Request ** request); +int MPID_Scatter_init(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, + MPI_Aint recvcount, MPI_Datatype recvtype, int root, MPIR_Comm * comm, + int coll_group, MPIR_Info * info, MPIR_Request ** request); +int MPID_Scatterv_init(const void *sendbuf, const MPI_Aint sendcounts[], const MPI_Aint displs[], + MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, + MPI_Datatype recvtype, int root, MPIR_Comm * comm, int coll_group, MPIR_Info * info, MPIR_Request ** request); -int MPID_Barrier_init(MPIR_Comm * comm, MPIR_Info * info, MPIR_Request ** request); +int MPID_Barrier_init(MPIR_Comm * comm, int coll_group, MPIR_Info * info, MPIR_Request ** request); int MPID_Exscan_init(const void *sendbuf, void *recvbuf, MPI_Aint count, - MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, + MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Info * info, MPIR_Request ** request); int MPID_Neighbor_allgather_init(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, diff --git a/src/mpid/ch4/netmod/include/netmod_am_fallback_coll.h b/src/mpid/ch4/netmod/include/netmod_am_fallback_coll.h index 9ea0e1f6048..7515d06e583 100644 --- a/src/mpid/ch4/netmod/include/netmod_am_fallback_coll.h +++ b/src/mpid/ch4/netmod/include/netmod_am_fallback_coll.h @@ -6,93 +6,99 @@ #ifndef NETMOD_AM_FALLBACK_COLL_H_INCLUDED #define NETMOD_AM_FALLBACK_COLL_H_INCLUDED -MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_barrier(MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) +MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_barrier(MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { - return MPIR_Barrier_impl(comm_ptr, errflag); + return MPIR_Barrier_impl(comm_ptr, coll_group, errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_bcast(void *buffer, MPI_Aint count, MPI_Datatype datatype, - int root, MPIR_Comm * comm_ptr, + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { - return MPIR_Bcast_impl(buffer, count, datatype, root, comm_ptr, errflag); + return MPIR_Bcast_impl(buffer, count, datatype, root, comm_ptr, coll_group, errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_allreduce(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm_ptr, + MPI_Op op, MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { - return MPIR_Allreduce_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, errflag); + return MPIR_Allreduce_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, coll_group, + errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_allgather(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { return MPIR_Allgather_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm_ptr, errflag); + recvcount, recvtype, comm_ptr, coll_group, errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_allgatherv(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { return MPIR_Allgatherv_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcounts, displs, recvtype, comm_ptr, errflag); + recvcounts, displs, recvtype, comm_ptr, coll_group, errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_gather(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm_ptr, + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { return MPIR_Gather_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, root, comm_ptr, errflag); + recvcount, recvtype, root, comm_ptr, coll_group, errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_gatherv(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, - int root, MPIR_Comm * comm_ptr, + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { return MPIR_Gatherv_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcounts, displs, recvtype, root, comm_ptr, errflag); + recvcounts, displs, recvtype, root, comm_ptr, coll_group, errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_scatter(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm_ptr, + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { return MPIR_Scatter_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, root, comm_ptr, errflag); + recvcount, recvtype, root, comm_ptr, coll_group, errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_scatterv(const void *sendbuf, const MPI_Aint * sendcounts, const MPI_Aint * displs, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { return MPIR_Scatterv_impl(sendbuf, sendcounts, displs, sendtype, - recvbuf, recvcount, recvtype, root, comm_ptr, errflag); + recvbuf, recvcount, recvtype, root, comm_ptr, coll_group, errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_alltoall(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { return MPIR_Alltoall_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm_ptr, errflag); + recvcount, recvtype, comm_ptr, coll_group, errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_alltoallv(const void *sendbuf, @@ -100,10 +106,12 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_alltoallv(const void *sendbuf, const MPI_Aint * sdispls, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * rdispls, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { return MPIR_Alltoallv_impl(sendbuf, sendcounts, sdispls, sendtype, - recvbuf, recvcounts, rdispls, recvtype, comm_ptr, errflag); + recvbuf, recvcounts, rdispls, recvtype, comm_ptr, coll_group, + errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_alltoallw(const void *sendbuf, @@ -113,50 +121,57 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_alltoallw(const void *sendbuf, const MPI_Aint recvcounts[], const MPI_Aint rdispls[], const MPI_Datatype recvtypes[], - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { return MPIR_Alltoallw_impl(sendbuf, sendcounts, sdispls, sendtypes, - recvbuf, recvcounts, rdispls, recvtypes, comm_ptr, errflag); + recvbuf, recvcounts, rdispls, recvtypes, comm_ptr, coll_group, + errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_reduce(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, int root, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { - return MPIR_Reduce_impl(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, errflag); + return MPIR_Reduce_impl(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, coll_group, + errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_reduce_scatter(const void *sendbuf, void *recvbuf, const MPI_Aint recvcounts[], MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, + MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { - return MPIR_Reduce_scatter_impl(sendbuf, recvbuf, recvcounts, datatype, op, comm_ptr, errflag); + return MPIR_Reduce_scatter_impl(sendbuf, recvbuf, recvcounts, datatype, op, comm_ptr, + coll_group, errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_reduce_scatter_block(const void *sendbuf, void *recvbuf, MPI_Aint recvcount, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, + MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { return MPIR_Reduce_scatter_block_impl(sendbuf, recvbuf, recvcount, - datatype, op, comm_ptr, errflag); + datatype, op, comm_ptr, coll_group, errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_scan(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { - return MPIR_Scan_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, errflag); + return MPIR_Scan_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, coll_group, errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_exscan(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { - return MPIR_Exscan_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, errflag); + return MPIR_Exscan_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, coll_group, errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_neighbor_allgather(const void *sendbuf, @@ -288,25 +303,28 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ineighbor_alltoallw(const void *sendbu rdispls, recvtypes, comm_ptr, req); } -MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ibarrier(MPIR_Comm * comm_ptr, MPIR_Request ** req) +MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ibarrier(MPIR_Comm * comm_ptr, int coll_group, + MPIR_Request ** req) { - return MPIR_Ibarrier_impl(comm_ptr, req); + return MPIR_Ibarrier_impl(comm_ptr, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ibcast(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, - MPIR_Comm * comm_ptr, MPIR_Request ** req) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Request ** req) { - return MPIR_Ibcast_impl(buffer, count, datatype, root, comm_ptr, req); + return MPIR_Ibcast_impl(buffer, count, datatype, root, comm_ptr, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_iallgather(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Request ** req) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Request ** req) { return MPIR_Iallgather_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm_ptr, req); + recvcount, recvtype, comm_ptr, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_iallgatherv(const void *sendbuf, MPI_Aint sendcount, @@ -314,27 +332,28 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_iallgatherv(const void *sendbuf, MPI_A const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, MPIR_Comm * comm_ptr, - MPIR_Request ** req) + int coll_group, MPIR_Request ** req) { return MPIR_Iallgatherv_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcounts, displs, recvtype, comm_ptr, req); + recvcounts, displs, recvtype, comm_ptr, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_iallreduce(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm, + MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Request ** request) { - return MPIR_Iallreduce_impl(sendbuf, recvbuf, count, datatype, op, comm, request); + return MPIR_Iallreduce_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, request); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ialltoall(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Request ** req) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Request ** req) { return MPIR_Ialltoall_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm_ptr, req); + recvcount, recvtype, comm_ptr, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ialltoallv(const void *sendbuf, @@ -344,10 +363,10 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ialltoallv(const void *sendbuf, const MPI_Aint * recvcounts, const MPI_Aint * rdispls, MPI_Datatype recvtype, MPIR_Comm * comm_ptr, - MPIR_Request ** req) + int coll_group, MPIR_Request ** req) { return MPIR_Ialltoallv_impl(sendbuf, sendcounts, sdispls, sendtype, - recvbuf, recvcounts, rdispls, recvtype, comm_ptr, req); + recvbuf, recvcounts, rdispls, recvtype, comm_ptr, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ialltoallw(const void *sendbuf, @@ -357,81 +376,87 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ialltoallw(const void *sendbuf, const MPI_Aint * recvcounts, const MPI_Aint * rdispls, const MPI_Datatype recvtypes[], - MPIR_Comm * comm_ptr, MPIR_Request ** req) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Request ** req) { return MPIR_Ialltoallw_impl(sendbuf, sendcounts, sdispls, sendtypes, - recvbuf, recvcounts, rdispls, recvtypes, comm_ptr, req); + recvbuf, recvcounts, rdispls, recvtypes, comm_ptr, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_iexscan(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, MPIR_Request ** req) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Request ** req) { - return MPIR_Iexscan_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, req); + return MPIR_Iexscan_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_igather(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm_ptr, + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Request ** req) { return MPIR_Igather_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, root, comm_ptr, req); + recvcount, recvtype, root, comm_ptr, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_igatherv(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, - int root, MPIR_Comm * comm_ptr, + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Request ** req) { return MPIR_Igatherv_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcounts, displs, recvtype, root, comm_ptr, req); + recvcounts, displs, recvtype, root, comm_ptr, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ireduce_scatter_block(const void *sendbuf, void *recvbuf, MPI_Aint recvcount, MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, - MPIR_Request ** req) + int coll_group, MPIR_Request ** req) { return MPIR_Ireduce_scatter_block_impl(sendbuf, recvbuf, recvcount, - datatype, op, comm_ptr, req); + datatype, op, comm_ptr, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ireduce_scatter(const void *sendbuf, void *recvbuf, const MPI_Aint recvcounts[], MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, MPIR_Request ** req) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Request ** req) { - return MPIR_Ireduce_scatter_impl(sendbuf, recvbuf, recvcounts, datatype, op, comm_ptr, req); + return MPIR_Ireduce_scatter_impl(sendbuf, recvbuf, recvcounts, datatype, op, comm_ptr, + coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ireduce(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - int root, MPIR_Comm * comm_ptr, + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Request ** req) { - return MPIR_Ireduce_impl(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, req); + return MPIR_Ireduce_impl(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, coll_group, + req); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_iscan(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, MPIR_Request ** req) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Request ** req) { - return MPIR_Iscan_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, req); + return MPIR_Iscan_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_iscatter(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm, + int root, MPIR_Comm * comm, int coll_group, MPIR_Request ** request) { return MPIR_Iscatter_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, root, comm, request); + recvcount, recvtype, root, comm, coll_group, request); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_iscatterv(const void *sendbuf, @@ -439,10 +464,11 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_iscatterv(const void *sendbuf, const MPI_Aint * displs, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, - MPIR_Comm * comm, MPIR_Request ** request) + MPIR_Comm * comm, int coll_group, + MPIR_Request ** request) { return MPIR_Iscatterv_impl(sendbuf, sendcounts, displs, sendtype, - recvbuf, recvcount, recvtype, root, comm, request); + recvbuf, recvcount, recvtype, root, comm, coll_group, request); } #endif /* NETMOD_AM_FALLBACK_COLL_H_INCLUDED */ diff --git a/src/mpid/ch4/netmod/ofi/coll/ofi_bcast_tree_rma.h b/src/mpid/ch4/netmod/ofi/coll/ofi_bcast_tree_rma.h index cd32092ccd7..acf822b005e 100644 --- a/src/mpid/ch4/netmod/ofi/coll/ofi_bcast_tree_rma.h +++ b/src/mpid/ch4/netmod/ofi/coll/ofi_bcast_tree_rma.h @@ -29,7 +29,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_Bcast_intra_triggered_rma(void *buffer, int count, MPI_Datatype datatype, int root, MPIR_Comm * comm_ptr, - int tree_type, + int coll_group, int tree_type, int branching_factor) { int mpi_errno = MPI_SUCCESS; @@ -53,13 +53,13 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_Bcast_intra_triggered_rma(void *buffer, i /* Invoke the helper function to perform one-sided knomial tree-based Ibcast */ mpi_errno = MPIDI_OFI_Ibcast_knomial_triggered_rma(buffer, count, datatype, root, comm_ptr, - tree_type, branching_factor, &num_children, - &snd_cntr, &rcv_cntr, &r_mr, &works, &my_tree, - myrank, nranks, &num_works); + coll_group, tree_type, branching_factor, + &num_children, &snd_cntr, &rcv_cntr, &r_mr, + &works, &my_tree, myrank, nranks, &num_works); } else { /* Invoke the helper function to perform one-sided kary tree-based Ibcast */ mpi_errno = - MPIDI_OFI_Ibcast_kary_triggered_rma(buffer, count, datatype, root, comm_ptr, + MPIDI_OFI_Ibcast_kary_triggered_rma(buffer, count, datatype, root, comm_ptr, coll_group, branching_factor, &leaf, &num_children, &snd_cntr, &rcv_cntr, &r_mr, &works, myrank, nranks, &num_works); diff --git a/src/mpid/ch4/netmod/ofi/coll/ofi_bcast_tree_tagged.h b/src/mpid/ch4/netmod/ofi/coll/ofi_bcast_tree_tagged.h index 46fdc79a163..a1077eb93d7 100644 --- a/src/mpid/ch4/netmod/ofi/coll/ofi_bcast_tree_tagged.h +++ b/src/mpid/ch4/netmod/ofi/coll/ofi_bcast_tree_tagged.h @@ -28,7 +28,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_Bcast_intra_triggered_tagged(void *buffer, int count, MPI_Datatype datatype, int root, MPIR_Comm * comm_ptr, - int tree_type, + int coll_group, int tree_type, int branching_factor) { int mpi_errno = MPI_SUCCESS; @@ -52,16 +52,16 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_Bcast_intra_triggered_tagged(void *buffer if (tree_type == MPIR_TREE_TYPE_KNOMIAL_1 || tree_type == MPIR_TREE_TYPE_KNOMIAL_2) { mpi_errno = MPIDI_OFI_Ibcast_knomial_triggered_tagged(buffer, count, datatype, root, comm_ptr, - tree_type, branching_factor, &num_children, - &snd_cntr, &rcv_cntr, &works, &my_tree, - myrank, nranks, &num_works); + coll_group, tree_type, branching_factor, + &num_children, &snd_cntr, &rcv_cntr, &works, + &my_tree, myrank, nranks, &num_works); } else { /* Invoke the helper function to perform kary tree-based Ibcast */ mpi_errno = MPIDI_OFI_Ibcast_kary_triggered_tagged(buffer, count, datatype, root, comm_ptr, - branching_factor, &leaf, &num_children, - &snd_cntr, &rcv_cntr, &works, myrank, nranks, - &num_works); + coll_group, branching_factor, &leaf, + &num_children, &snd_cntr, &rcv_cntr, &works, + myrank, nranks, &num_works); } /* Wait for the completion counters to reach their desired values */ diff --git a/src/mpid/ch4/netmod/ofi/coll/ofi_coll_util.h b/src/mpid/ch4/netmod/ofi/coll/ofi_coll_util.h index 43edc7fe6cf..65ad84d638b 100644 --- a/src/mpid/ch4/netmod/ofi/coll/ofi_coll_util.h +++ b/src/mpid/ch4/netmod/ofi/coll/ofi_coll_util.h @@ -36,6 +36,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_Ibcast_knomial_triggered_tagged(void *buf MPI_Datatype datatype, int root, MPIR_Comm * comm_ptr, + int coll_group, int tree_type, int branching_factor, int *num_children, @@ -89,7 +90,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_Ibcast_knomial_triggered_tagged(void *buf sizeof(struct fi_deferred_work), MPL_MEM_BUFFER); MPIR_ERR_CHKANDSTMT(*works == NULL, mpi_errno, MPI_ERR_NO_MEM, goto fn_fail, "**nomem"); - mpi_errno = MPIR_Sched_next_tag(comm_ptr, &rtr_tag); + mpi_errno = MPIR_Sched_next_tag(comm_ptr, coll_group, &rtr_tag); if (mpi_errno) MPIR_ERR_POP(mpi_errno); @@ -109,7 +110,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_Ibcast_knomial_triggered_tagged(void *buf } i = i + j; - mpi_errno = MPIR_Sched_next_tag(comm_ptr, &tag); + mpi_errno = MPIR_Sched_next_tag(comm_ptr, coll_group, &tag); if (mpi_errno) MPIR_ERR_POP(mpi_errno); @@ -194,6 +195,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_Ibcast_knomial_triggered_tagged(void *buf MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_Ibcast_kary_triggered_tagged(void *buffer, int count, MPI_Datatype datatype, int root, MPIR_Comm * comm_ptr, + int coll_group, int branching_factor, int *is_leaf, int *children, struct fid_cntr **snd_cntr, @@ -238,7 +240,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_Ibcast_kary_triggered_tagged(void *buffer MPIR_ERR_CHKANDSTMT(*works == NULL, mpi_errno, MPI_ERR_NO_MEM, goto fn_fail, "**nomem"); - mpi_errno = MPIR_Sched_next_tag(comm_ptr, &rtr_tag); + mpi_errno = MPIR_Sched_next_tag(comm_ptr, coll_group, &rtr_tag); if (mpi_errno) MPIR_ERR_POP(mpi_errno); @@ -259,7 +261,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_Ibcast_kary_triggered_tagged(void *buffer } i = i + j; - mpi_errno = MPIR_Sched_next_tag(comm_ptr, &tag); + mpi_errno = MPIR_Sched_next_tag(comm_ptr, coll_group, &tag); if (mpi_errno) MPIR_ERR_POP(mpi_errno); @@ -349,7 +351,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_Ibcast_kary_triggered_tagged(void *buffer MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_Ibcast_knomial_triggered_rma(void *buffer, int count, MPI_Datatype datatype, int root, MPIR_Comm * comm_ptr, - int tree_type, + int coll_group, int tree_type, int branching_factor, int *num_children, struct fid_cntr **snd_cntr, @@ -419,7 +421,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_Ibcast_knomial_triggered_rma(void *buffer MPIR_ERR_CHKANDJUMP1(*works == NULL, mpi_errno, MPI_ERR_OTHER, "**nomem", "**nomem %s", "Triggered bcast deferred work alloc"); - mpi_errno = MPIR_Sched_next_tag(comm_ptr, &rtr_tag); + mpi_errno = MPIR_Sched_next_tag(comm_ptr, coll_group, &rtr_tag); if (mpi_errno) MPIR_ERR_POP(mpi_errno); @@ -505,6 +507,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_Ibcast_knomial_triggered_rma(void *buffer MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_Ibcast_kary_triggered_rma(void *buffer, int count, MPI_Datatype datatype, int root, MPIR_Comm * comm_ptr, + int coll_group, int branching_factor, int *is_leaf, int *children, struct fid_cntr **snd_cntr, @@ -564,7 +567,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_Ibcast_kary_triggered_rma(void *buffer, i sizeof(struct fi_deferred_work), MPL_MEM_BUFFER); MPIR_ERR_CHKANDSTMT(*works == NULL, mpi_errno, MPI_ERR_NO_MEM, goto fn_fail, "**nomem"); - mpi_errno = MPIR_Sched_next_tag(comm_ptr, &rtr_tag); + mpi_errno = MPIR_Sched_next_tag(comm_ptr, coll_group, &rtr_tag); if (mpi_errno) MPIR_ERR_POP(mpi_errno); diff --git a/src/mpid/ch4/netmod/ofi/init_addrxchg.c b/src/mpid/ch4/netmod/ofi/init_addrxchg.c index f323565b663..5a3fec99785 100644 --- a/src/mpid/ch4/netmod/ofi/init_addrxchg.c +++ b/src/mpid/ch4/netmod/ofi/init_addrxchg.c @@ -222,7 +222,8 @@ int MPIDI_OFI_addr_exchange_all_ctx(void) MPIR_CHKLMEM_MALLOC(all_num_vcis, void *, sizeof(int) * size, mpi_errno, "all_num_vcis", MPL_MEM_ADDRESS); mpi_errno = MPIR_Allgather_fallback(&MPIDI_OFI_global.num_vcis, 1, MPI_INT, - all_num_vcis, 1, MPI_INT, comm, MPIR_ERR_NONE); + all_num_vcis, 1, MPI_INT, comm, MPIR_SUBGROUP_NONE, + MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); max_vcis = 0; @@ -261,7 +262,8 @@ int MPIDI_OFI_addr_exchange_all_ctx(void) } /* Allgather */ mpi_errno = MPIR_Allgather_fallback(MPI_IN_PLACE, 0, MPI_BYTE, - all_names, my_len, MPI_BYTE, comm, MPIR_ERR_NONE); + all_names, my_len, MPI_BYTE, comm, MPIR_SUBGROUP_NONE, + MPIR_ERR_NONE); /* Step 2: insert and store non-root nic/vci on the root context */ int root_ctx_idx = MPIDI_OFI_get_ctx_index(0, 0); @@ -335,7 +337,7 @@ int MPIDI_OFI_addr_exchange_all_ctx(void) } } } - mpi_errno = MPIR_Barrier_fallback(comm, MPIR_ERR_NONE); + mpi_errno = MPIR_Barrier_fallback(comm, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); /* check */ diff --git a/src/mpid/ch4/netmod/ofi/ofi_coll.h b/src/mpid/ch4/netmod/ofi/ofi_coll.h index ddc7279272f..c9a13eeb107 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_coll.h +++ b/src/mpid/ch4/netmod/ofi/ofi_coll.h @@ -30,13 +30,14 @@ === END_MPI_T_CVAR_INFO_BLOCK === */ -MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_barrier(MPIR_Comm * comm, MPIR_Errflag_t errflag) +MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_barrier(MPIR_Comm * comm, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Barrier_impl(comm, errflag); + mpi_errno = MPIR_Barrier_impl(comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); @@ -49,11 +50,12 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_barrier(MPIR_Comm * comm, MPIR_Errflag } static inline int MPIDI_OFI_bcast_json(void *buffer, MPI_Aint count, MPI_Datatype datatype, - int root, MPIR_Comm * comm, MPIR_Errflag_t errflag) + int root, MPIR_Comm * comm, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; - mpi_errno = MPIR_Bcast_impl(buffer, count, datatype, root, comm, errflag); + mpi_errno = MPIR_Bcast_impl(buffer, count, datatype, root, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -63,7 +65,8 @@ static inline int MPIDI_OFI_bcast_json(void *buffer, MPI_Aint count, MPI_Datatyp } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_bcast(void *buffer, MPI_Aint count, MPI_Datatype datatype, - int root, MPIR_Comm * comm, MPIR_Errflag_t errflag) + int root, MPIR_Comm * comm, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; enum fi_datatype fi_dt; @@ -79,7 +82,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_bcast(void *buffer, MPI_Aint count, MP "Bcast triggered_tagged cannot be applied.\n"); mpi_errno = MPIDI_OFI_Bcast_intra_triggered_tagged(buffer, count, datatype, root, comm, - MPIR_Bcast_tree_type, + coll_group, MPIR_Bcast_tree_type, MPIR_CVAR_BCAST_TREE_KVAL); break; case MPIR_CVAR_BCAST_OFI_INTRA_ALGORITHM_trigger_tree_rma: @@ -89,7 +92,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_bcast(void *buffer, MPI_Aint count, MP NULL) != -1, mpi_errno, "Bcast triggered_rma cannot be applied.\n"); mpi_errno = - MPIDI_OFI_Bcast_intra_triggered_rma(buffer, count, datatype, root, comm, + MPIDI_OFI_Bcast_intra_triggered_rma(buffer, count, datatype, root, comm, coll_group, MPIR_Bcast_tree_type, MPIR_CVAR_BCAST_TREE_KVAL); break; @@ -97,7 +100,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_bcast(void *buffer, MPI_Aint count, MP goto fallback; break; case MPIR_CVAR_BCAST_OFI_INTRA_ALGORITHM_auto: - mpi_errno = MPIDI_OFI_bcast_json(buffer, count, datatype, root, comm, errflag); + mpi_errno = + MPIDI_OFI_bcast_json(buffer, count, datatype, root, comm, coll_group, errflag); break; default: MPIR_Assert(0); @@ -106,7 +110,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_bcast(void *buffer, MPI_Aint count, MP goto fn_exit; fallback: - mpi_errno = MPIR_Bcast_impl(buffer, count, datatype, root, comm, errflag); + mpi_errno = MPIR_Bcast_impl(buffer, count, datatype, root, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); MPIR_FUNC_EXIT; @@ -119,14 +123,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_bcast(void *buffer, MPI_Aint count, MP MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_allreduce(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm, + MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Allreduce_impl(sendbuf, recvbuf, count, datatype, op, comm, errflag); + mpi_errno = + MPIR_Allreduce_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); @@ -141,14 +146,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_allreduce(const void *sendbuf, void *r MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_allgather(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm, MPIR_Errflag_t errflag) + MPIR_Comm * comm, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; mpi_errno = MPIR_Allgather_impl(sendbuf, sendcount, sendtype, - recvbuf, recvcount, recvtype, comm, errflag); + recvbuf, recvcount, recvtype, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); @@ -164,14 +170,16 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_allgatherv(const void *sendbuf, MPI_Ai MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, - MPIR_Comm * comm, MPIR_Errflag_t errflag) + MPIR_Comm * comm, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; mpi_errno = MPIR_Allgatherv_impl(sendbuf, sendcount, sendtype, - recvbuf, recvcounts, displs, recvtype, comm, errflag); + recvbuf, recvcounts, displs, recvtype, comm, coll_group, + errflag); MPIR_ERR_CHECK(mpi_errno); @@ -186,14 +194,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_allgatherv(const void *sendbuf, MPI_Ai MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_gather(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm, MPIR_Errflag_t errflag) + int root, MPIR_Comm * comm, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; mpi_errno = MPIR_Gather_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, root, comm, errflag); + recvcount, recvtype, root, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); @@ -209,7 +218,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_gatherv(const void *sendbuf, MPI_Aint MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, - int root, MPIR_Comm * comm, + int root, MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -217,7 +226,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_gatherv(const void *sendbuf, MPI_Aint MPIR_FUNC_ENTER; mpi_errno = MPIR_Gatherv_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcounts, displs, recvtype, root, comm, errflag); + recvcounts, displs, recvtype, root, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); @@ -233,7 +242,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_gatherv(const void *sendbuf, MPI_Aint MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_scatter(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm, + int root, MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -241,7 +250,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_scatter(const void *sendbuf, MPI_Aint MPIR_FUNC_ENTER; mpi_errno = MPIR_Scatter_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, root, comm, errflag); + recvcount, recvtype, root, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); @@ -257,14 +266,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_scatterv(const void *sendbuf, const MP const MPI_Aint * displs, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, - MPIR_Comm * comm, MPIR_Errflag_t errflag) + MPIR_Comm * comm, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; mpi_errno = MPIR_Scatterv_impl(sendbuf, sendcounts, displs, sendtype, recvbuf, - recvcount, recvtype, root, comm, errflag); + recvcount, recvtype, root, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); @@ -279,14 +289,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_scatterv(const void *sendbuf, const MP MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_alltoall(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm, MPIR_Errflag_t errflag) + MPIR_Comm * comm, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; mpi_errno = MPIR_Alltoall_impl(sendbuf, sendcount, sendtype, - recvbuf, recvcount, recvtype, comm, errflag); + recvbuf, recvcount, recvtype, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); @@ -303,7 +314,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_alltoallv(const void *sendbuf, const MPI_Aint * sdispls, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * rdispls, MPI_Datatype recvtype, - MPIR_Comm * comm, MPIR_Errflag_t errflag) + MPIR_Comm * comm, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -311,7 +323,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_alltoallv(const void *sendbuf, mpi_errno = MPIR_Alltoallv_impl(sendbuf, sendcounts, sdispls, sendtype, recvbuf, recvcounts, - rdispls, recvtype, comm, errflag); + rdispls, recvtype, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); @@ -330,7 +342,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_alltoallw(const void *sendbuf, const MPI_Aint recvcounts[], const MPI_Aint rdispls[], const MPI_Datatype recvtypes[], - MPIR_Comm * comm, MPIR_Errflag_t errflag) + MPIR_Comm * comm, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -338,7 +351,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_alltoallw(const void *sendbuf, mpi_errno = MPIR_Alltoallw_impl(sendbuf, sendcounts, sdispls, sendtypes, recvbuf, recvcounts, - rdispls, recvtypes, comm, errflag); + rdispls, recvtypes, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); @@ -352,13 +365,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_alltoallw(const void *sendbuf, MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_reduce(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, int root, - MPIR_Comm * comm, MPIR_Errflag_t errflag) + MPIR_Comm * comm, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Reduce_impl(sendbuf, recvbuf, count, datatype, op, root, comm, errflag); + mpi_errno = + MPIR_Reduce_impl(sendbuf, recvbuf, count, datatype, op, root, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); @@ -373,13 +388,16 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_reduce(const void *sendbuf, void *recv MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_reduce_scatter(const void *sendbuf, void *recvbuf, const MPI_Aint recvcounts[], MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm, MPIR_Errflag_t errflag) + MPIR_Comm * comm, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Reduce_scatter_impl(sendbuf, recvbuf, recvcounts, datatype, op, comm, errflag); + mpi_errno = + MPIR_Reduce_scatter_impl(sendbuf, recvbuf, recvcounts, datatype, op, comm, coll_group, + errflag); MPIR_ERR_CHECK(mpi_errno); @@ -394,14 +412,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_reduce_scatter(const void *sendbuf, vo MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_reduce_scatter_block(const void *sendbuf, void *recvbuf, MPI_Aint recvcount, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm, + MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; - MPIR_Reduce_scatter_block_impl(sendbuf, recvbuf, recvcount, datatype, op, comm, errflag); + MPIR_Reduce_scatter_block_impl(sendbuf, recvbuf, recvcount, datatype, op, comm, coll_group, + errflag); MPIR_ERR_CHECK(mpi_errno); @@ -415,13 +434,13 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_reduce_scatter_block(const void *sendb MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_scan(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, - MPIR_Errflag_t errflag) + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Scan_impl(sendbuf, recvbuf, count, datatype, op, comm, errflag); + mpi_errno = MPIR_Scan_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); @@ -435,13 +454,13 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_scan(const void *sendbuf, void *recvbu MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_exscan(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, - MPIR_Errflag_t errflag) + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Exscan_impl(sendbuf, recvbuf, count, datatype, op, comm, errflag); + mpi_errno = MPIR_Exscan_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); @@ -635,12 +654,13 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ineighbor_alltoallw(const void *sendbu return mpi_errno; } -MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ibarrier(MPIR_Comm * comm, MPIR_Request ** req) +MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ibarrier(MPIR_Comm * comm, int coll_group, + MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Ibarrier_impl(comm, req); + mpi_errno = MPIR_Ibarrier_impl(comm, coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -648,12 +668,12 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ibarrier(MPIR_Comm * comm, MPIR_Reques MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ibcast(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, MPIR_Comm * comm, - MPIR_Request ** req) + int coll_group, MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Ibcast_impl(buffer, count, datatype, root, comm, req); + mpi_errno = MPIR_Ibcast_impl(buffer, count, datatype, root, comm, coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -662,13 +682,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ibcast(void *buffer, MPI_Aint count, MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_iallgather(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm, MPIR_Request ** req) + MPIR_Comm * comm, int coll_group, + MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Iallgather_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm, req); + recvcount, recvtype, comm, coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -679,13 +700,13 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_iallgatherv(const void *sendbuf, MPI_A const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, MPIR_Comm * comm, - MPIR_Request ** req) + int coll_group, MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Iallgatherv_impl(sendbuf, sendcount, sendtype, - recvbuf, recvcounts, displs, recvtype, comm, req); + recvbuf, recvcounts, displs, recvtype, comm, coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -693,13 +714,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_iallgatherv(const void *sendbuf, MPI_A MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_iallreduce(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm, + MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Request ** request) { int mpi_errno; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Iallreduce_impl(sendbuf, recvbuf, count, datatype, op, comm, request); + mpi_errno = + MPIR_Iallreduce_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, request); MPIR_FUNC_EXIT; return mpi_errno; @@ -708,13 +730,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_iallreduce(const void *sendbuf, void * MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ialltoall(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm, MPIR_Request ** req) + MPIR_Comm * comm, int coll_group, + MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Ialltoall_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm, req); + recvcount, recvtype, comm, coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -727,13 +750,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ialltoallv(const void *sendbuf, const MPI_Aint * recvcounts, const MPI_Aint * rdispls, MPI_Datatype recvtype, MPIR_Comm * comm, - MPIR_Request ** req) + int coll_group, MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Ialltoallv_impl(sendbuf, sendcounts, sdispls, - sendtype, recvbuf, recvcounts, rdispls, recvtype, comm, req); + sendtype, recvbuf, recvcounts, rdispls, recvtype, comm, + coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -746,13 +770,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ialltoallw(const void *sendbuf, const MPI_Aint * recvcounts, const MPI_Aint * rdispls, const MPI_Datatype recvtypes[], - MPIR_Comm * comm, MPIR_Request ** req) + MPIR_Comm * comm, int coll_group, + MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Ialltoallw_impl(sendbuf, sendcounts, sdispls, - sendtypes, recvbuf, recvcounts, rdispls, recvtypes, comm, req); + sendtypes, recvbuf, recvcounts, rdispls, recvtypes, comm, + coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -760,12 +786,13 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ialltoallw(const void *sendbuf, MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_iexscan(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm, MPIR_Request ** req) + MPIR_Comm * comm, int coll_group, + MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Iexscan_impl(sendbuf, recvbuf, count, datatype, op, comm, req); + mpi_errno = MPIR_Iexscan_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -774,13 +801,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_iexscan(const void *sendbuf, void *rec MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_igather(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm, MPIR_Request ** req) + int root, MPIR_Comm * comm, int coll_group, + MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Igather_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, root, comm, req); + recvcount, recvtype, root, comm, coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -790,13 +818,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_igatherv(const void *sendbuf, MPI_Aint MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, - int root, MPIR_Comm * comm, MPIR_Request ** req) + int root, MPIR_Comm * comm, int coll_group, + MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Igatherv_impl(sendbuf, sendcount, sendtype, - recvbuf, recvcounts, displs, recvtype, root, comm, req); + recvbuf, recvcounts, displs, recvtype, root, comm, coll_group, + req); MPIR_FUNC_EXIT; return mpi_errno; @@ -805,14 +835,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_igatherv(const void *sendbuf, MPI_Aint MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ireduce_scatter_block(const void *sendbuf, void *recvbuf, MPI_Aint recvcount, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm, + MPIR_Comm * comm, int coll_group, MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Ireduce_scatter_block_impl(sendbuf, recvbuf, recvcount, - datatype, op, comm, req); + datatype, op, comm, coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -821,12 +851,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ireduce_scatter_block(const void *send MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ireduce_scatter(const void *sendbuf, void *recvbuf, const MPI_Aint recvcounts[], MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm, MPIR_Request ** req) + MPIR_Comm * comm, int coll_group, + MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Ireduce_scatter_impl(sendbuf, recvbuf, recvcounts, datatype, op, comm, req); + mpi_errno = + MPIR_Ireduce_scatter_impl(sendbuf, recvbuf, recvcounts, datatype, op, comm, coll_group, + req); MPIR_FUNC_EXIT; return mpi_errno; @@ -834,12 +867,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ireduce_scatter(const void *sendbuf, v MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ireduce(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - int root, MPIR_Comm * comm, MPIR_Request ** req) + int root, MPIR_Comm * comm, int coll_group, + MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Ireduce_impl(sendbuf, recvbuf, count, datatype, op, root, comm, req); + mpi_errno = + MPIR_Ireduce_impl(sendbuf, recvbuf, count, datatype, op, root, comm, coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -847,12 +882,12 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ireduce(const void *sendbuf, void *rec MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_iscan(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, - MPIR_Request ** req) + int coll_group, MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Iscan_impl(sendbuf, recvbuf, count, datatype, op, comm, req); + mpi_errno = MPIR_Iscan_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -861,14 +896,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_iscan(const void *sendbuf, void *recvb MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_iscatter(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm, + int root, MPIR_Comm * comm, int coll_group, MPIR_Request ** request) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Iscatter_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, root, comm, request); + recvcount, recvtype, root, comm, coll_group, request); MPIR_FUNC_EXIT; return mpi_errno; @@ -879,13 +914,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_iscatterv(const void *sendbuf, const MPI_Aint * displs, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, - MPIR_Comm * comm, MPIR_Request ** request) + MPIR_Comm * comm, int coll_group, + MPIR_Request ** request) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Iscatterv_impl(sendbuf, sendcounts, displs, sendtype, - recvbuf, recvcount, recvtype, root, comm, request); + recvbuf, recvcount, recvtype, root, comm, coll_group, request); MPIR_FUNC_EXIT; return mpi_errno; diff --git a/src/mpid/ch4/netmod/ofi/ofi_comm.c b/src/mpid/ch4/netmod/ofi/ofi_comm.c index 57b9cb131de..984b2412a31 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_comm.c +++ b/src/mpid/ch4/netmod/ofi/ofi_comm.c @@ -106,7 +106,8 @@ static int update_nic_preferences(MPIR_Comm * comm) /* Collect the NIC IDs set for the other ranks. We always expect to receive a single * NIC id from each rank, i.e., one MPI_INT. */ mpi_errno = MPIR_Allgather_allcomm_auto(MPI_IN_PLACE, 0, MPI_INT, - pref_nic_copy, 1, MPI_INT, comm, MPIR_ERR_NONE); + pref_nic_copy, 1, MPI_INT, comm, + MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); if (MPIDI_OFI_COMM(comm).pref_nic == NULL) { diff --git a/src/mpid/ch4/netmod/ofi/ofi_events.c b/src/mpid/ch4/netmod/ofi/ofi_events.c index 84046946d11..b2ef6c8ebdf 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_events.c +++ b/src/mpid/ch4/netmod/ofi/ofi_events.c @@ -189,6 +189,7 @@ static int pipeline_recv_event(struct fi_cq_tagged_entry *wc, MPIR_Request * r, chunk_req->buf = host_buf; int ret = 0; if (!MPIDI_OFI_global.gpu_recv_queue && host_buf) { + /* FIXME: error handling */ ret = fi_trecv (MPIDI_OFI_global.ctx [MPIDI_OFI_REQUEST(rreq, pipeline_info.ctx_idx)].rx, diff --git a/src/mpid/ch4/netmod/ofi/ofi_init.c b/src/mpid/ch4/netmod/ofi/ofi_init.c index 674bcc45ead..88132fadaff 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_init.c +++ b/src/mpid/ch4/netmod/ofi/ofi_init.c @@ -895,7 +895,8 @@ static int check_num_nics(void) /* Confirm that all processes have the same number of NICs */ mpi_errno = MPIR_Allreduce_allcomm_auto(&tmp_num_nics, &num_nics, 1, MPI_INT, - MPI_MIN, MPIR_Process.comm_world, MPIR_ERR_NONE); + MPI_MIN, MPIR_Process.comm_world, MPIR_SUBGROUP_NONE, + MPIR_ERR_NONE); MPIDI_OFI_global.num_vcis = tmp_num_vcis; MPIDI_OFI_global.num_nics = tmp_num_nics; MPIR_ERR_CHECK(mpi_errno); diff --git a/src/mpid/ch4/netmod/ofi/ofi_recv.h b/src/mpid/ch4/netmod/ofi/ofi_recv.h index ffd66c98c35..a33018e3bcb 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_recv.h +++ b/src/mpid/ch4/netmod/ofi/ofi_recv.h @@ -271,6 +271,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_do_irecv(void *buf, chunk_req->parent = rreq; chunk_req->buf = host_buf; int ret = 0; + /* FIXME: handle error */ ret = fi_trecv(MPIDI_OFI_global.ctx[ctx_idx].rx, host_buf, MPIR_CVAR_CH4_OFI_GPU_PIPELINE_BUFFER_SZ, diff --git a/src/mpid/ch4/netmod/ofi/ofi_win.c b/src/mpid/ch4/netmod/ofi/ofi_win.c index 7dc0016f963..83bb0c5d054 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_win.c +++ b/src/mpid/ch4/netmod/ofi/ofi_win.c @@ -137,7 +137,8 @@ static int win_allgather(MPIR_Win * win, void *base, int disp_unit) * available to the processes involved in the RMA window. Use the current maximum + 1 * to ensure that the key is available for all processes. */ mpi_errno = MPIR_Allreduce(&MPIDI_OFI_global.global_max_optimized_mr_key, &local_key, 1, - MPI_UNSIGNED, MPI_MAX, comm_ptr, MPIR_ERR_NONE); + MPI_UNSIGNED, MPI_MAX, comm_ptr, MPIR_SUBGROUP_NONE, + MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); if (local_key + 1 < MPIDI_OFI_NUM_OPTIMIZED_MEMORY_REGIONS) { @@ -220,7 +221,7 @@ static int win_allgather(MPIR_Win * win, void *base, int disp_unit) } /* Check if any process fails to register. If so, release local MR and force AM path. */ - MPIR_Allreduce(&rc, &allrc, 1, MPI_INT, MPI_MIN, comm_ptr, MPIR_ERR_NONE); + MPIR_Allreduce(&rc, &allrc, 1, MPI_INT, MPI_MIN, comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); if (allrc < 0) { if (rc >= 0 && MPIDI_OFI_WIN(win).mr) MPIDI_OFI_CALL(fi_close(&MPIDI_OFI_WIN(win).mr->fid), fi_close); @@ -244,7 +245,8 @@ static int win_allgather(MPIR_Win * win, void *base, int disp_unit) mpi_errno = MPIR_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, - winfo, sizeof(*winfo), MPI_BYTE, comm_ptr, MPIR_ERR_NONE); + winfo, sizeof(*winfo), MPI_BYTE, comm_ptr, MPIR_SUBGROUP_NONE, + MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); if (!MPIDI_OFI_ENABLE_MR_PROV_KEY && !MPIDI_OFI_ENABLE_MR_VIRT_ADDRESS) { @@ -969,7 +971,7 @@ int MPIDI_OFI_mpi_win_attach_hook(MPIR_Win * win, void *base, MPI_Aint size) } /* Check if any process fails to register. If so, release local MR and force AM path. */ - MPIR_Allreduce(&rc, &allrc, 1, MPI_INT, MPI_MIN, comm_ptr, MPIR_ERR_NONE); + MPIR_Allreduce(&rc, &allrc, 1, MPI_INT, MPI_MIN, comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); if (allrc < 0) { if (rc >= 0) MPIDI_OFI_CALL(fi_close(&mr->fid), fi_close); @@ -995,7 +997,7 @@ int MPIDI_OFI_mpi_win_attach_hook(MPIR_Win * win, void *base, MPI_Aint size) mpi_errno = MPIR_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, target_mrs, sizeof(dwin_target_mr_t), MPI_BYTE, comm_ptr, - MPIR_ERR_NONE); + MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); /* Insert each remote MR which will be searched when issuing an RMA operation @@ -1053,7 +1055,7 @@ int MPIDI_OFI_mpi_win_detach_hook(MPIR_Win * win, const void *base) target_bases[comm_ptr->rank] = base; mpi_errno = MPIR_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, target_bases, sizeof(const void *), MPI_BYTE, comm_ptr, - MPIR_ERR_NONE); + MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); /* Search and delete each remote MR */ diff --git a/src/mpid/ch4/netmod/ucx/ucx_coll.h b/src/mpid/ch4/netmod/ucx/ucx_coll.h index 6a1d4759958..dde47058a15 100644 --- a/src/mpid/ch4/netmod/ucx/ucx_coll.h +++ b/src/mpid/ch4/netmod/ucx/ucx_coll.h @@ -11,7 +11,8 @@ #include "../../../common/hcoll/hcoll.h" #endif -MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_barrier(MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) +MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_barrier(MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno; MPIR_FUNC_ENTER; @@ -21,7 +22,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_barrier(MPIR_Comm * comm_ptr, MPIR_Err if (mpi_errno != MPI_SUCCESS) #endif { - mpi_errno = MPIR_Barrier_impl(comm_ptr, errflag); + mpi_errno = MPIR_Barrier_impl(comm_ptr, coll_group, errflag); } MPIR_FUNC_EXIT; @@ -29,7 +30,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_barrier(MPIR_Comm * comm_ptr, MPIR_Err } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_bcast(void *buffer, MPI_Aint count, MPI_Datatype datatype, - int root, MPIR_Comm * comm_ptr, + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno; @@ -40,7 +41,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_bcast(void *buffer, MPI_Aint count, MP if (mpi_errno != MPI_SUCCESS) #endif { - mpi_errno = MPIR_Bcast_impl(buffer, count, datatype, root, comm_ptr, errflag); + mpi_errno = MPIR_Bcast_impl(buffer, count, datatype, root, comm_ptr, coll_group, errflag); } MPIR_FUNC_EXIT; @@ -49,7 +50,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_bcast(void *buffer, MPI_Aint count, MP MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_allreduce(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm_ptr, + MPI_Op op, MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno; @@ -60,7 +61,9 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_allreduce(const void *sendbuf, void *r if (mpi_errno != MPI_SUCCESS) #endif { - mpi_errno = MPIR_Allreduce_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, errflag); + mpi_errno = + MPIR_Allreduce_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, coll_group, + errflag); } MPIR_FUNC_EXIT; @@ -70,7 +73,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_allreduce(const void *sendbuf, void *r MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_allgather(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno; MPIR_FUNC_ENTER; @@ -82,7 +86,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_allgather(const void *sendbuf, MPI_Ain #endif { mpi_errno = MPIR_Allgather_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm_ptr, errflag); + recvcount, recvtype, comm_ptr, coll_group, errflag); } MPIR_FUNC_EXIT; @@ -93,13 +97,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_allgatherv(const void *sendbuf, MPI_Ai MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Allgatherv_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcounts, displs, recvtype, comm_ptr, errflag); + recvcounts, displs, recvtype, comm_ptr, coll_group, errflag); MPIR_FUNC_EXIT; return mpi_errno; @@ -108,14 +113,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_allgatherv(const void *sendbuf, MPI_Ai MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_gather(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm_ptr, + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Gather_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, root, comm_ptr, errflag); + recvcount, recvtype, root, comm_ptr, coll_group, errflag); MPIR_FUNC_EXIT; return mpi_errno; @@ -125,14 +130,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_gatherv(const void *sendbuf, MPI_Aint MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, - int root, MPIR_Comm * comm_ptr, + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Gatherv_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcounts, displs, recvtype, root, comm_ptr, errflag); + recvcounts, displs, recvtype, root, comm_ptr, coll_group, + errflag); MPIR_FUNC_EXIT; return mpi_errno; @@ -141,14 +147,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_gatherv(const void *sendbuf, MPI_Aint MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_scatter(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm_ptr, + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Scatter_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, root, comm_ptr, errflag); + recvcount, recvtype, root, comm_ptr, coll_group, errflag); MPIR_FUNC_EXIT; return mpi_errno; @@ -158,13 +164,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_scatterv(const void *sendbuf, const MP const MPI_Aint * displs, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Scatterv_impl(sendbuf, sendcounts, displs, sendtype, - recvbuf, recvcount, recvtype, root, comm_ptr, errflag); + recvbuf, recvcount, recvtype, root, comm_ptr, coll_group, + errflag); MPIR_FUNC_EXIT; return mpi_errno; @@ -173,7 +181,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_scatterv(const void *sendbuf, const MP MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_alltoall(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno; MPIR_FUNC_ENTER; @@ -185,7 +194,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_alltoall(const void *sendbuf, MPI_Aint #endif { mpi_errno = MPIR_Alltoall_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm_ptr, errflag); + recvcount, recvtype, comm_ptr, coll_group, errflag); } MPIR_FUNC_EXIT; return mpi_errno; @@ -196,7 +205,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_alltoallv(const void *sendbuf, const MPI_Aint * sdispls, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * rdispls, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno; MPIR_FUNC_ENTER; @@ -208,7 +218,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_alltoallv(const void *sendbuf, #endif { mpi_errno = MPIR_Alltoallv_impl(sendbuf, sendcounts, sdispls, sendtype, - recvbuf, recvcounts, rdispls, recvtype, comm_ptr, errflag); + recvbuf, recvcounts, rdispls, recvtype, comm_ptr, + coll_group, errflag); } MPIR_FUNC_EXIT; return mpi_errno; @@ -221,13 +232,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_alltoallw(const void *sendbuf, const MPI_Aint recvcounts[], const MPI_Aint rdispls[], const MPI_Datatype recvtypes[], - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Alltoallw_impl(sendbuf, sendcounts, sdispls, sendtypes, - recvbuf, recvcounts, rdispls, recvtypes, comm_ptr, errflag); + recvbuf, recvcounts, rdispls, recvtypes, comm_ptr, coll_group, + errflag); MPIR_FUNC_EXIT; return mpi_errno; @@ -235,7 +248,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_alltoallw(const void *sendbuf, MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_reduce(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, int root, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno; MPIR_FUNC_ENTER; @@ -245,8 +259,9 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_reduce(const void *sendbuf, void *recv if (mpi_errno != MPI_SUCCESS) #endif { - mpi_errno = MPIR_Reduce_impl(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, - errflag); + mpi_errno = + MPIR_Reduce_impl(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, coll_group, + errflag); } MPIR_FUNC_EXIT; return mpi_errno; @@ -255,14 +270,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_reduce(const void *sendbuf, void *recv MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_reduce_scatter(const void *sendbuf, void *recvbuf, const MPI_Aint recvcounts[], MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, + MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Reduce_scatter_impl(sendbuf, recvbuf, recvcounts, - datatype, op, comm_ptr, errflag); + datatype, op, comm_ptr, coll_group, errflag); MPIR_FUNC_EXIT; return mpi_errno; @@ -271,14 +286,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_reduce_scatter(const void *sendbuf, vo MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_reduce_scatter_block(const void *sendbuf, void *recvbuf, MPI_Aint recvcount, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, + MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Reduce_scatter_block_impl(sendbuf, recvbuf, recvcount, - datatype, op, comm_ptr, errflag); + datatype, op, comm_ptr, coll_group, errflag); MPIR_FUNC_EXIT; return mpi_errno; @@ -286,12 +301,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_reduce_scatter_block(const void *sendb MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_scan(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Scan_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, errflag); + mpi_errno = + MPIR_Scan_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, coll_group, errflag); MPIR_FUNC_EXIT; return mpi_errno; @@ -299,12 +316,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_scan(const void *sendbuf, void *recvbu MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_exscan(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Exscan_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, errflag); + mpi_errno = + MPIR_Exscan_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, coll_group, errflag); MPIR_FUNC_EXIT; return mpi_errno; @@ -501,12 +520,13 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ineighbor_alltoallw(const void *sendbu return mpi_errno; } -MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ibarrier(MPIR_Comm * comm_ptr, MPIR_Request ** req) +MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ibarrier(MPIR_Comm * comm_ptr, int coll_group, + MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Ibarrier_impl(comm_ptr, req); + mpi_errno = MPIR_Ibarrier_impl(comm_ptr, coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -514,12 +534,13 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ibarrier(MPIR_Comm * comm_ptr, MPIR_Re MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ibcast(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, - MPIR_Comm * comm_ptr, MPIR_Request ** req) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Ibcast_impl(buffer, count, datatype, root, comm_ptr, req); + mpi_errno = MPIR_Ibcast_impl(buffer, count, datatype, root, comm_ptr, coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -528,13 +549,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ibcast(void *buffer, MPI_Aint count, MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_iallgather(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Request ** req) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Iallgather_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm_ptr, req); + recvcount, recvtype, comm_ptr, coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -545,13 +567,13 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_iallgatherv(const void *sendbuf, MPI_A const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, MPIR_Comm * comm_ptr, - MPIR_Request ** req) + int coll_group, MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Iallgatherv_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcounts, displs, recvtype, comm_ptr, req); + recvcounts, displs, recvtype, comm_ptr, coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -559,13 +581,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_iallgatherv(const void *sendbuf, MPI_A MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_iallreduce(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm, + MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Request ** request) { int mpi_errno; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Iallreduce_impl(sendbuf, recvbuf, count, datatype, op, comm, request); + mpi_errno = + MPIR_Iallreduce_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, request); MPIR_FUNC_EXIT; return mpi_errno; @@ -574,13 +597,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_iallreduce(const void *sendbuf, void * MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ialltoall(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Request ** req) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Ialltoall_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm_ptr, req); + recvcount, recvtype, comm_ptr, coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -593,13 +617,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ialltoallv(const void *sendbuf, const MPI_Aint * recvcounts, const MPI_Aint * rdispls, MPI_Datatype recvtype, MPIR_Comm * comm_ptr, - MPIR_Request ** req) + int coll_group, MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Ialltoallv_impl(sendbuf, sendcounts, sdispls, sendtype, - recvbuf, recvcounts, rdispls, recvtype, comm_ptr, req); + recvbuf, recvcounts, rdispls, recvtype, comm_ptr, coll_group, + req); MPIR_FUNC_EXIT; return mpi_errno; @@ -612,13 +637,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ialltoallw(const void *sendbuf, const MPI_Aint * recvcounts, const MPI_Aint * rdispls, const MPI_Datatype recvtypes[], - MPIR_Comm * comm_ptr, MPIR_Request ** req) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Ialltoallw_impl(sendbuf, sendcounts, sdispls, sendtypes, - recvbuf, recvcounts, rdispls, recvtypes, comm_ptr, req); + recvbuf, recvcounts, rdispls, recvtypes, comm_ptr, coll_group, + req); MPIR_FUNC_EXIT; return mpi_errno; @@ -626,12 +653,13 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ialltoallw(const void *sendbuf, MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_iexscan(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, MPIR_Request ** req) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Iexscan_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, req); + mpi_errno = MPIR_Iexscan_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -640,14 +668,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_iexscan(const void *sendbuf, void *rec MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_igather(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm_ptr, + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Igather_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, root, comm_ptr, req); + recvcount, recvtype, root, comm_ptr, coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -657,14 +685,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_igatherv(const void *sendbuf, MPI_Aint MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, - int root, MPIR_Comm * comm_ptr, + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Igatherv_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcounts, displs, recvtype, root, comm_ptr, req); + recvcounts, displs, recvtype, root, comm_ptr, coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -674,13 +702,13 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ireduce_scatter_block(const void *send MPI_Aint recvcount, MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, - MPIR_Request ** req) + int coll_group, MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Ireduce_scatter_block_impl(sendbuf, recvbuf, recvcount, - datatype, op, comm_ptr, req); + datatype, op, comm_ptr, coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -689,13 +717,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ireduce_scatter_block(const void *send MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ireduce_scatter(const void *sendbuf, void *recvbuf, const MPI_Aint recvcounts[], MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, MPIR_Request ** req) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Ireduce_scatter_impl(sendbuf, recvbuf, recvcounts, - datatype, op, comm_ptr, req); + datatype, op, comm_ptr, coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -703,13 +732,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ireduce_scatter(const void *sendbuf, v MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ireduce(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - int root, MPIR_Comm * comm_ptr, + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Ireduce_impl(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, req); + mpi_errno = + MPIR_Ireduce_impl(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -717,12 +747,13 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ireduce(const void *sendbuf, void *rec MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_iscan(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, MPIR_Request ** req) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Iscan_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, req); + mpi_errno = MPIR_Iscan_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -731,14 +762,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_iscan(const void *sendbuf, void *recvb MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_iscatter(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm, + int root, MPIR_Comm * comm, int coll_group, MPIR_Request ** request) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Iscatter_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, root, comm, request); + recvcount, recvtype, root, comm, coll_group, request); MPIR_FUNC_EXIT; return mpi_errno; @@ -749,13 +780,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_iscatterv(const void *sendbuf, const MPI_Aint * displs, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, - MPIR_Comm * comm, MPIR_Request ** request) + MPIR_Comm * comm, int coll_group, + MPIR_Request ** request) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Iscatterv_impl(sendbuf, sendcounts, displs, sendtype, - recvbuf, recvcount, recvtype, root, comm, request); + recvbuf, recvcount, recvtype, root, comm, coll_group, request); MPIR_FUNC_EXIT; return mpi_errno; diff --git a/src/mpid/ch4/netmod/ucx/ucx_init.c b/src/mpid/ch4/netmod/ucx/ucx_init.c index 56701043f1d..0a727e44ce7 100644 --- a/src/mpid/ch4/netmod/ucx/ucx_init.c +++ b/src/mpid/ch4/netmod/ucx/ucx_init.c @@ -170,7 +170,8 @@ static int all_vcis_address_exchange(void) /* Allgather */ MPIR_Comm *comm = MPIR_Process.comm_world; mpi_errno = MPIR_Allgather_allcomm_auto(MPI_IN_PLACE, 0, MPI_BYTE, - all_names, my_len, MPI_BYTE, comm, MPIR_ERR_NONE); + all_names, my_len, MPI_BYTE, comm, MPIR_SUBGROUP_NONE, + MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); /* insert the addresses */ diff --git a/src/mpid/ch4/netmod/ucx/ucx_win.c b/src/mpid/ch4/netmod/ucx/ucx_win.c index 6a4a51fa85c..837a748e219 100644 --- a/src/mpid/ch4/netmod/ucx/ucx_win.c +++ b/src/mpid/ch4/netmod/ucx/ucx_win.c @@ -83,7 +83,8 @@ static int win_allgather(MPIR_Win * win, size_t length, uint32_t disp_unit, void rkey_sizes = (MPI_Aint *) MPL_malloc(sizeof(MPI_Aint) * comm_ptr->local_size, MPL_MEM_OTHER); rkey_sizes[comm_ptr->rank] = (MPI_Aint) rkey_size; mpi_errno = - MPIR_Allgather(MPI_IN_PLACE, 1, MPI_AINT, rkey_sizes, 1, MPI_AINT, comm_ptr, MPIR_ERR_NONE); + MPIR_Allgather(MPI_IN_PLACE, 1, MPI_AINT, rkey_sizes, 1, MPI_AINT, comm_ptr, + MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); @@ -100,7 +101,7 @@ static int win_allgather(MPIR_Win * win, size_t length, uint32_t disp_unit, void /* allgather */ mpi_errno = MPIR_Allgatherv(rkey_buffer, rkey_size, MPI_BYTE, rkey_recv_buff, rkey_sizes, recv_disps, MPI_BYTE, comm_ptr, - MPIR_ERR_NONE); + MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); @@ -141,7 +142,8 @@ static int win_allgather(MPIR_Win * win, size_t length, uint32_t disp_unit, void mpi_errno = MPIR_Allgather(MPI_IN_PLACE, sizeof(struct ucx_share), MPI_BYTE, share_data, - sizeof(struct ucx_share), MPI_BYTE, comm_ptr, MPIR_ERR_NONE); + sizeof(struct ucx_share), MPI_BYTE, comm_ptr, MPIR_SUBGROUP_NONE, + MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); for (i = 0; i < comm_ptr->local_size; i++) { diff --git a/src/mpid/ch4/shm/ipc/src/ipc_win.c b/src/mpid/ch4/shm/ipc/src/ipc_win.c index 9e5e7769a11..42731cba68a 100644 --- a/src/mpid/ch4/shm/ipc/src/ipc_win.c +++ b/src/mpid/ch4/shm/ipc/src/ipc_win.c @@ -154,7 +154,8 @@ int MPIDI_IPC_mpi_win_create_hook(MPIR_Win * win) 0, MPI_DATATYPE_NULL, ipc_shared_table, - sizeof(win_shared_info_t), MPI_BYTE, shm_comm_ptr, MPIR_ERR_NONE); + sizeof(win_shared_info_t), MPI_BYTE, shm_comm_ptr, + MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_T_PVAR_TIMER_END(RMA, rma_wincreate_allgather); MPIR_ERR_CHECK(mpi_errno); diff --git a/src/mpid/ch4/shm/posix/posix_coll.h b/src/mpid/ch4/shm/posix/posix_coll.h index ae1c7e97eee..2131e1824db 100644 --- a/src/mpid/ch4/shm/posix/posix_coll.h +++ b/src/mpid/ch4/shm/posix/posix_coll.h @@ -148,12 +148,14 @@ */ -MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_barrier(MPIR_Comm * comm, MPIR_Errflag_t errflag) +MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_barrier(MPIR_Comm * comm, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_Csel_coll_sig_s coll_sig = { .coll_type = MPIR_CSEL_COLL_TYPE__BARRIER, .comm_ptr = comm, + .coll_group = coll_group, }; MPIDI_POSIX_csel_container_s *cnt; @@ -163,7 +165,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_barrier(MPIR_Comm * comm, MPIR_Errf case MPIR_CVAR_BARRIER_POSIX_INTRA_ALGORITHM_release_gather: MPII_COLLECTIVE_FALLBACK_CHECK(comm->rank, !MPIR_IS_THREADED, mpi_errno, "Barrier release_gather cannot be applied.\n"); - mpi_errno = MPIDI_POSIX_mpi_barrier_release_gather(comm, errflag); + mpi_errno = MPIDI_POSIX_mpi_barrier_release_gather(comm, coll_group, errflag); break; case MPIR_CVAR_BARRIER_POSIX_INTRA_ALGORITHM_mpir: @@ -177,7 +179,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_barrier(MPIR_Comm * comm, MPIR_Errf switch (cnt->id) { case MPIDI_POSIX_CSEL_CONTAINER_TYPE__ALGORITHM__MPIDI_POSIX_mpi_barrier_release_gather: mpi_errno = - MPIDI_POSIX_mpi_barrier_release_gather(comm, errflag); + MPIDI_POSIX_mpi_barrier_release_gather(comm, coll_group, errflag); break; case MPIDI_POSIX_CSEL_CONTAINER_TYPE__ALGORITHM__MPIR_Barrier_impl: goto fallback; @@ -194,7 +196,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_barrier(MPIR_Comm * comm, MPIR_Errf goto fn_exit; fallback: - mpi_errno = MPIR_Barrier_impl(comm, errflag); + mpi_errno = MPIR_Barrier_impl(comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -206,7 +208,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_barrier(MPIR_Comm * comm, MPIR_Errf MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_bcast(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, - MPIR_Comm * comm, MPIR_Errflag_t errflag) + MPIR_Comm * comm, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_Csel_coll_sig_s coll_sig = { @@ -226,12 +229,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_bcast(void *buffer, MPI_Aint count, MPII_COLLECTIVE_FALLBACK_CHECK(comm->rank, !MPIR_IS_THREADED, mpi_errno, "Bcast release_gather cannot be applied.\n"); mpi_errno = - MPIDI_POSIX_mpi_bcast_release_gather(buffer, count, datatype, root, comm, errflag); + MPIDI_POSIX_mpi_bcast_release_gather(buffer, count, datatype, root, comm, + coll_group, errflag); break; case MPIR_CVAR_BCAST_POSIX_INTRA_ALGORITHM_ipc_read: mpi_errno = - MPIDI_POSIX_mpi_bcast_gpu_ipc_read(buffer, count, datatype, root, comm, errflag); + MPIDI_POSIX_mpi_bcast_gpu_ipc_read(buffer, count, datatype, root, comm, coll_group, + errflag); break; case MPIR_CVAR_BCAST_POSIX_INTRA_ALGORITHM_mpir: @@ -240,15 +245,13 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_bcast(void *buffer, MPI_Aint count, case MPIR_CVAR_BCAST_POSIX_INTRA_ALGORITHM_auto: if (MPIR_CVAR_COLL_HYBRID_MEMORY) { cnt = MPIR_Csel_search(MPIDI_POSIX_COMM(comm, csel_comm), coll_sig); - } - else { + } else { /* In no hybird case, local memory type can be used to select algorithm */ MPL_pointer_attr_t pointer_attr; MPIR_GPU_query_pointer_attr(buffer, &pointer_attr); if (pointer_attr.type == MPL_GPU_POINTER_DEV) { cnt = MPIR_Csel_search(MPIDI_POSIX_COMM(comm, csel_comm_gpu), coll_sig); - } - else { + } else { cnt = MPIR_Csel_search(MPIDI_POSIX_COMM(comm, csel_comm), coll_sig); } } @@ -259,12 +262,12 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_bcast(void *buffer, MPI_Aint count, case MPIDI_POSIX_CSEL_CONTAINER_TYPE__ALGORITHM__MPIDI_POSIX_mpi_bcast_release_gather: mpi_errno = MPIDI_POSIX_mpi_bcast_release_gather(buffer, count, datatype, root, comm, - errflag); + coll_group, errflag); break; case MPIDI_POSIX_CSEL_CONTAINER_TYPE__ALGORITHM__MPIDI_POSIX_mpi_bcast_ipc_read: mpi_errno = MPIDI_POSIX_mpi_bcast_gpu_ipc_read(buffer, count, datatype, root, comm, - errflag); + coll_group, errflag); break; case MPIDI_POSIX_CSEL_CONTAINER_TYPE__ALGORITHM__MPIR_Bcast_impl: goto fallback; @@ -281,7 +284,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_bcast(void *buffer, MPI_Aint count, goto fn_exit; fallback: - mpi_errno = MPIR_Bcast_impl(buffer, count, datatype, root, comm, errflag); + mpi_errno = MPIR_Bcast_impl(buffer, count, datatype, root, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -293,7 +296,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_bcast(void *buffer, MPI_Aint count, MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_allreduce(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm, + MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -317,7 +320,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_allreduce(const void *sendbuf, void "Allreduce release_gather cannot be applied.\n"); mpi_errno = MPIDI_POSIX_mpi_allreduce_release_gather(sendbuf, recvbuf, count, datatype, op, - comm, errflag); + comm, coll_group, errflag); break; case MPIR_CVAR_ALLREDUCE_POSIX_INTRA_ALGORITHM_mpir: @@ -332,7 +335,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_allreduce(const void *sendbuf, void case MPIDI_POSIX_CSEL_CONTAINER_TYPE__ALGORITHM__MPIDI_POSIX_mpi_allreduce_release_gather: mpi_errno = MPIDI_POSIX_mpi_allreduce_release_gather(sendbuf, recvbuf, count, datatype, - op, comm, errflag); + op, comm, coll_group, errflag); break; case MPIDI_POSIX_CSEL_CONTAINER_TYPE__ALGORITHM__MPIR_Allreduce_impl: @@ -351,7 +354,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_allreduce(const void *sendbuf, void goto fn_exit; fallback: - mpi_errno = MPIR_Allreduce_impl(sendbuf, recvbuf, count, datatype, op, comm, errflag); + mpi_errno = + MPIR_Allreduce_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -364,7 +368,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_allreduce(const void *sendbuf, void MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_allgather(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm, MPIR_Errflag_t errflag) + MPIR_Comm * comm, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -374,7 +379,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_allgather(const void *sendbuf, MPI_ case MPIR_CVAR_ALLGATHER_POSIX_INTRA_ALGORITHM_ipc_read: mpi_errno = MPIDI_POSIX_mpi_allgather_gpu_ipc_read(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, - comm, errflag); + comm, coll_group, errflag); break; case MPIR_CVAR_ALLGATHER_POSIX_INTRA_ALGORITHM_mpir: @@ -389,7 +394,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_allgather(const void *sendbuf, MPI_ fallback: mpi_errno = MPIR_Allgather_impl(sendbuf, sendcount, sendtype, - recvbuf, recvcount, recvtype, comm, errflag); + recvbuf, recvcount, recvtype, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -404,7 +409,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_allgatherv(const void *sendbuf, MPI const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, MPIR_Comm * comm, - MPIR_Errflag_t errflag) + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -414,7 +419,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_allgatherv(const void *sendbuf, MPI case MPIR_CVAR_ALLGATHERV_POSIX_INTRA_ALGORITHM_ipc_read: mpi_errno = MPIDI_POSIX_mpi_allgatherv_gpu_ipc_read(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, - recvtype, comm, errflag); + recvtype, comm, coll_group, + errflag); break; case MPIR_CVAR_ALLGATHERV_POSIX_INTRA_ALGORITHM_mpir: @@ -429,7 +435,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_allgatherv(const void *sendbuf, MPI fallback: mpi_errno = MPIR_Allgatherv_impl(sendbuf, sendcount, sendtype, - recvbuf, recvcounts, displs, recvtype, comm, errflag); + recvbuf, recvcounts, displs, recvtype, comm, coll_group, + errflag); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -442,7 +449,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_allgatherv(const void *sendbuf, MPI MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_gather(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm, + int root, MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -450,7 +457,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_gather(const void *sendbuf, MPI_Ain MPIR_FUNC_ENTER; mpi_errno = MPIR_Gather_impl(sendbuf, sendcount, sendtype, recvbuf, recvcount, - recvtype, root, comm, errflag); + recvtype, root, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); @@ -466,7 +473,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_gatherv(const void *sendbuf, MPI_Ai MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, - int root, MPIR_Comm * comm, + int root, MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -474,7 +481,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_gatherv(const void *sendbuf, MPI_Ai MPIR_FUNC_ENTER; mpi_errno = MPIR_Gatherv_impl(sendbuf, sendcount, sendtype, recvbuf, recvcounts, - displs, recvtype, root, comm, errflag); + displs, recvtype, root, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); @@ -490,7 +497,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_gatherv(const void *sendbuf, MPI_Ai MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_scatter(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm, + int root, MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -498,7 +505,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_scatter(const void *sendbuf, MPI_Ai MPIR_FUNC_ENTER; mpi_errno = MPIR_Scatter_impl(sendbuf, sendcount, sendtype, recvbuf, recvcount, - recvtype, root, comm, errflag); + recvtype, root, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); @@ -515,7 +522,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_scatterv(const void *sendbuf, const MPI_Aint * displs, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm, + int root, MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -523,7 +530,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_scatterv(const void *sendbuf, MPIR_FUNC_ENTER; mpi_errno = MPIR_Scatterv_impl(sendbuf, sendcounts, displs, sendtype, - recvbuf, recvcount, recvtype, root, comm, errflag); + recvbuf, recvcount, recvtype, root, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); @@ -538,7 +545,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_scatterv(const void *sendbuf, MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_alltoall(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm, MPIR_Errflag_t errflag) + MPIR_Comm * comm, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -548,7 +556,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_alltoall(const void *sendbuf, MPI_A case MPIR_CVAR_ALLTOALL_POSIX_INTRA_ALGORITHM_ipc_read: mpi_errno = MPIDI_POSIX_mpi_alltoall_gpu_ipc_read(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, - comm, errflag); + comm, coll_group, errflag); break; case MPIR_CVAR_ALLTOALL_POSIX_INTRA_ALGORITHM_mpir: @@ -563,7 +571,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_alltoall(const void *sendbuf, MPI_A fallback: mpi_errno = MPIR_Alltoall_impl(sendbuf, sendcount, sendtype, - recvbuf, recvcount, recvtype, comm, errflag); + recvbuf, recvcount, recvtype, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -580,7 +588,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_alltoallv(const void *sendbuf, const MPI_Aint * recvcounts, const MPI_Aint * rdispls, MPI_Datatype recvtype, MPIR_Comm * comm, - MPIR_Errflag_t errflag) + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno; @@ -588,7 +596,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_alltoallv(const void *sendbuf, mpi_errno = MPIR_Alltoallv_impl(sendbuf, sendcounts, sdispls, sendtype, recvbuf, recvcounts, - rdispls, recvtype, comm, errflag); + rdispls, recvtype, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); @@ -607,7 +615,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_alltoallw(const void *sendbuf, void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint rdispls[], const MPI_Datatype recvtypes[], - MPIR_Comm * comm, MPIR_Errflag_t errflag) + MPIR_Comm * comm, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno; @@ -615,7 +624,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_alltoallw(const void *sendbuf, mpi_errno = MPIR_Alltoallw_impl(sendbuf, sendcounts, sdispls, sendtypes, recvbuf, recvcounts, - rdispls, recvtypes, comm, errflag); + rdispls, recvtypes, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); @@ -630,7 +639,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_alltoallw(const void *sendbuf, MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_reduce(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, int root, MPIR_Comm * comm, - MPIR_Errflag_t errflag) + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_Csel_coll_sig_s coll_sig = { @@ -654,7 +663,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_reduce(const void *sendbuf, void *r "Reduce release_gather cannot be applied.\n"); mpi_errno = MPIDI_POSIX_mpi_reduce_release_gather(sendbuf, recvbuf, count, datatype, op, root, - comm, errflag); + comm, coll_group, errflag); break; case MPIR_CVAR_REDUCE_POSIX_INTRA_ALGORITHM_mpir: @@ -669,7 +678,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_reduce(const void *sendbuf, void *r case MPIDI_POSIX_CSEL_CONTAINER_TYPE__ALGORITHM__MPIDI_POSIX_mpi_reduce_release_gather: mpi_errno = MPIDI_POSIX_mpi_reduce_release_gather(sendbuf, recvbuf, count, datatype, op, - root, comm, errflag); + root, comm, coll_group, errflag); break; case MPIDI_POSIX_CSEL_CONTAINER_TYPE__ALGORITHM__MPIR_Reduce_impl: @@ -688,7 +697,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_reduce(const void *sendbuf, void *r goto fn_exit; fallback: - mpi_errno = MPIR_Reduce_impl(sendbuf, recvbuf, count, datatype, op, root, comm, errflag); + mpi_errno = + MPIR_Reduce_impl(sendbuf, recvbuf, count, datatype, op, root, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -701,14 +711,16 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_reduce(const void *sendbuf, void *r MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_reduce_scatter(const void *sendbuf, void *recvbuf, const MPI_Aint recvcounts[], MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm, + MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Reduce_scatter_impl(sendbuf, recvbuf, recvcounts, datatype, op, comm, errflag); + mpi_errno = + MPIR_Reduce_scatter_impl(sendbuf, recvbuf, recvcounts, datatype, op, comm, coll_group, + errflag); MPIR_ERR_CHECK(mpi_errno); @@ -723,14 +735,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_reduce_scatter(const void *sendbuf, MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_reduce_scatter_block(const void *sendbuf, void *recvbuf, MPI_Aint recvcount, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm, + MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; - MPIR_Reduce_scatter_block_impl(sendbuf, recvbuf, recvcount, datatype, op, comm, errflag); + MPIR_Reduce_scatter_block_impl(sendbuf, recvbuf, recvcount, datatype, op, comm, coll_group, + errflag); MPIR_ERR_CHECK(mpi_errno); @@ -744,13 +757,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_reduce_scatter_block(const void *se MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_scan(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm, MPIR_Errflag_t errflag) + MPIR_Comm * comm, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Scan_impl(sendbuf, recvbuf, count, datatype, op, comm, errflag); + mpi_errno = MPIR_Scan_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); @@ -764,14 +778,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_scan(const void *sendbuf, void *rec MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_exscan(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm, + MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Exscan_impl(sendbuf, recvbuf, count, datatype, op, comm, errflag); + mpi_errno = MPIR_Exscan_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); @@ -975,12 +989,13 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_ineighbor_alltoallw(const void *sen return mpi_errno; } -MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_ibarrier(MPIR_Comm * comm, MPIR_Request ** req) +MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_ibarrier(MPIR_Comm * comm, int coll_group, + MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Ibarrier_impl(comm, req); + mpi_errno = MPIR_Ibarrier_impl(comm, coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -988,12 +1003,13 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_ibarrier(MPIR_Comm * comm, MPIR_Req MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_ibcast(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, - MPIR_Comm * comm, MPIR_Request ** req) + MPIR_Comm * comm, int coll_group, + MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Ibcast_impl(buffer, count, datatype, root, comm, req); + mpi_errno = MPIR_Ibcast_impl(buffer, count, datatype, root, comm, coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -1002,13 +1018,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_ibcast(void *buffer, MPI_Aint count MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_iallgather(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm, MPIR_Request ** req) + MPIR_Comm * comm, int coll_group, + MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Iallgather_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm, req); + recvcount, recvtype, comm, coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -1019,13 +1036,13 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_iallgatherv(const void *sendbuf, MP const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, MPIR_Comm * comm, - MPIR_Request ** req) + int coll_group, MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Iallgatherv_impl(sendbuf, sendcount, sendtype, - recvbuf, recvcounts, displs, recvtype, comm, req); + recvbuf, recvcounts, displs, recvtype, comm, coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -1034,13 +1051,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_iallgatherv(const void *sendbuf, MP MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_ialltoall(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm, MPIR_Request ** req) + MPIR_Comm * comm, int coll_group, + MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Ialltoall_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm, req); + recvcount, recvtype, comm, coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -1053,13 +1071,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_ialltoallv(const void *sendbuf, const MPI_Aint * recvcounts, const MPI_Aint * rdispls, MPI_Datatype recvtype, MPIR_Comm * comm, - MPIR_Request ** req) + int coll_group, MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Ialltoallv_impl(sendbuf, sendcounts, sdispls, - sendtype, recvbuf, recvcounts, rdispls, recvtype, comm, req); + sendtype, recvbuf, recvcounts, rdispls, recvtype, comm, + coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -1072,13 +1091,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_ialltoallw(const void *sendbuf, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * rdispls, const MPI_Datatype recvtypes[], - MPIR_Comm * comm, MPIR_Request ** req) + MPIR_Comm * comm, int coll_group, + MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Ialltoallw_impl(sendbuf, sendcounts, sdispls, - sendtypes, recvbuf, recvcounts, rdispls, recvtypes, comm, req); + sendtypes, recvbuf, recvcounts, rdispls, recvtypes, comm, + coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -1086,13 +1107,13 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_ialltoallw(const void *sendbuf, MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_iexscan(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm, + MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Iexscan_impl(sendbuf, recvbuf, count, datatype, op, comm, req); + mpi_errno = MPIR_Iexscan_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -1101,14 +1122,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_iexscan(const void *sendbuf, void * MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_igather(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm, + int root, MPIR_Comm * comm, int coll_group, MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Igather_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, root, comm, req); + recvcount, recvtype, root, comm, coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -1119,13 +1140,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_igatherv(const void *sendbuf, MPI_A const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, int root, - MPIR_Comm * comm, MPIR_Request ** req) + MPIR_Comm * comm, int coll_group, + MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Igatherv_impl(sendbuf, sendcount, sendtype, - recvbuf, recvcounts, displs, recvtype, root, comm, req); + recvbuf, recvcounts, displs, recvtype, root, comm, coll_group, + req); MPIR_FUNC_EXIT; return mpi_errno; @@ -1135,14 +1158,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_ireduce_scatter_block(const void *s void *recvbuf, MPI_Aint recvcount, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm, + MPIR_Comm * comm, int coll_group, MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = - MPIR_Ireduce_scatter_block_impl(sendbuf, recvbuf, recvcount, datatype, op, comm, req); + MPIR_Ireduce_scatter_block_impl(sendbuf, recvbuf, recvcount, datatype, op, comm, coll_group, + req); MPIR_FUNC_EXIT; return mpi_errno; @@ -1151,12 +1175,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_ireduce_scatter_block(const void *s MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_ireduce_scatter(const void *sendbuf, void *recvbuf, const MPI_Aint recvcounts[], MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm, MPIR_Request ** req) + MPIR_Comm * comm, int coll_group, + MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Ireduce_scatter_impl(sendbuf, recvbuf, recvcounts, datatype, op, comm, req); + mpi_errno = + MPIR_Ireduce_scatter_impl(sendbuf, recvbuf, recvcounts, datatype, op, comm, coll_group, + req); MPIR_FUNC_EXIT; return mpi_errno; @@ -1165,12 +1192,13 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_ireduce_scatter(const void *sendbuf MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_ireduce(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, int root, MPIR_Comm * comm, - MPIR_Request ** req) + int coll_group, MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Ireduce_impl(sendbuf, recvbuf, count, datatype, op, root, comm, req); + mpi_errno = + MPIR_Ireduce_impl(sendbuf, recvbuf, count, datatype, op, root, comm, coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -1178,13 +1206,13 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_ireduce(const void *sendbuf, void * MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_iallreduce(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm, + MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Iallreduce_impl(sendbuf, recvbuf, count, datatype, op, comm, req); + mpi_errno = MPIR_Iallreduce_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -1192,12 +1220,13 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_iallreduce(const void *sendbuf, voi MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_iscan(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm, MPIR_Request ** req) + MPIR_Comm * comm, int coll_group, + MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Iscan_impl(sendbuf, recvbuf, count, datatype, op, comm, req); + mpi_errno = MPIR_Iscan_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -1206,14 +1235,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_iscan(const void *sendbuf, void *re MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_iscatter(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm, + int root, MPIR_Comm * comm, int coll_group, MPIR_Request ** request) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Iscatter_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, root, comm, request); + recvcount, recvtype, root, comm, coll_group, request); MPIR_FUNC_EXIT; return mpi_errno; @@ -1224,14 +1253,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_iscatterv(const void *sendbuf, const MPI_Aint * displs, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm, + int root, MPIR_Comm * comm, int coll_group, MPIR_Request ** request) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Iscatterv_impl(sendbuf, sendcounts, displs, sendtype, - recvbuf, recvcount, recvtype, root, comm, request); + recvbuf, recvcount, recvtype, root, comm, coll_group, request); MPIR_FUNC_EXIT; return mpi_errno; diff --git a/src/mpid/ch4/shm/posix/posix_coll_gpu_ipc.h b/src/mpid/ch4/shm/posix/posix_coll_gpu_ipc.h index ee5a0f0b672..a27bb3b2d6f 100644 --- a/src/mpid/ch4/shm/posix/posix_coll_gpu_ipc.h +++ b/src/mpid/ch4/shm/posix/posix_coll_gpu_ipc.h @@ -101,7 +101,7 @@ static int allgather_ipc_handles(const void *buf, MPI_Aint count, MPI_Datatype d /* allgather is needed to exchange all the IPC handles */ mpi_errno = MPIR_Allgather_impl(&my_ipc_handle, sizeof(MPIDI_IPCI_ipc_handle_t), MPI_BYTE, ipc_handles, sizeof(MPIDI_IPCI_ipc_handle_t), MPI_BYTE, - comm, MPIR_ERR_NONE); + comm, coll_group, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); /* check the ipc_handles to make sure all the buffers are on GPU */ @@ -131,6 +131,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_bcast_gpu_ipc_read(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, MPIR_Comm * comm_ptr, + int coll_group, MPIR_Errflag_t errflag) { MPIR_FUNC_ENTER; @@ -186,7 +187,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_bcast_gpu_ipc_read(void *buffer, goto fn_exit; fallback: /* Fall back to other algorithms as gpu ipc bcast cannot be used */ - mpi_errno = MPIR_Bcast_impl(buffer, count, datatype, root, comm_ptr, errflag); + mpi_errno = MPIR_Bcast_impl(buffer, count, datatype, root, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); goto fn_exit; } @@ -198,6 +199,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_alltoall_gpu_ipc_read(const void *s MPI_Aint recvcount, MPI_Datatype recvtype, MPIR_Comm * comm_ptr, + int coll_group, MPIR_Errflag_t errflag) { MPIR_FUNC_ENTER; @@ -280,7 +282,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_alltoall_gpu_ipc_read(const void *s fallback: /* Fall back to other algorithms as gpu ipc alltoall cannot be used */ mpi_errno = MPIR_Alltoall_impl(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, - comm_ptr, errflag); + comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); goto fn_exit; } @@ -292,6 +294,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_allgather_gpu_ipc_read(const void * MPI_Aint recvcount, MPI_Datatype recvtype, MPIR_Comm * comm_ptr, + int coll_group, MPIR_Errflag_t errflag) { MPIR_FUNC_ENTER; @@ -374,7 +377,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_allgather_gpu_ipc_read(const void * fallback: /* Fall back to other algorithms as gpu ipc allgather cannot be used */ mpi_errno = MPIR_Allgather_impl(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, - comm_ptr, errflag); + comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); goto fn_exit; } @@ -387,6 +390,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_allgatherv_gpu_ipc_read(const void const MPI_Aint * displs, MPI_Datatype recvtype, MPIR_Comm * comm_ptr, + int coll_group, MPIR_Errflag_t errflag) { MPIR_FUNC_ENTER; @@ -473,7 +477,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_allgatherv_gpu_ipc_read(const void fallback: /* Fall back to other algorithms as gpu ipc allgatherv cannot be used */ mpi_errno = MPIR_Allgatherv_impl(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, - recvtype, comm_ptr, errflag); + recvtype, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); goto fn_exit; } @@ -483,9 +487,10 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_bcast_gpu_ipc_read(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, MPIR_Comm * comm_ptr, + int coll_group, MPIR_Errflag_t errflag) { - return MPIR_Bcast_impl(buffer, count, datatype, root, comm_ptr, errflag); + return MPIR_Bcast_impl(buffer, count, datatype, root, comm_ptr, coll_group, errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_alltoall_gpu_ipc_read(const void *sendbuf, @@ -495,10 +500,11 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_alltoall_gpu_ipc_read(const void *s MPI_Aint recvcount, MPI_Datatype recvtype, MPIR_Comm * comm_ptr, + int coll_group, MPIR_Errflag_t errflag) { return MPIR_Alltoall_impl(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, - comm_ptr, errflag); + comm_ptr, coll_group, errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_allgather_gpu_ipc_read(const void *sendbuf, @@ -508,10 +514,11 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_allgather_gpu_ipc_read(const void * MPI_Aint recvcount, MPI_Datatype recvtype, MPIR_Comm * comm_ptr, + int coll_group, MPIR_Errflag_t errflag) { return MPIR_Allgather_impl(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, - comm_ptr, errflag); + comm_ptr, coll_group, errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_allgatherv_gpu_ipc_read(const void *sendbuf, @@ -522,10 +529,11 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_allgatherv_gpu_ipc_read(const void const MPI_Aint * displs, MPI_Datatype recvtype, MPIR_Comm * comm_ptr, + int coll_group, MPIR_Errflag_t errflag) { return MPIR_Allgatherv_impl(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, - comm_ptr, errflag); + comm_ptr, coll_group, errflag); } #endif /* !MPIDI_CH4_SHM_ENABLE_GPU */ diff --git a/src/mpid/ch4/shm/posix/posix_coll_nb_release_gather.h b/src/mpid/ch4/shm/posix/posix_coll_nb_release_gather.h index b1cd8b83d8b..63e7e2ed786 100644 --- a/src/mpid/ch4/shm/posix/posix_coll_nb_release_gather.h +++ b/src/mpid/ch4/shm/posix/posix_coll_nb_release_gather.h @@ -19,7 +19,7 @@ */ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_ibcast_release_gather(void *buffer, int count, MPI_Datatype datatype, int root, - MPIR_Comm * comm_ptr, + MPIR_Comm * comm_ptr, int coll_group, MPIR_TSP_sched_t sched) { MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_POSIX_MPI_IBCAST_RELEASE_GATHER); @@ -69,6 +69,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_ireduce_release_gather(const void *send MPI_Datatype datatype, MPI_Op op, int root, MPIR_Comm * comm_ptr, + int coll_group, MPIR_TSP_sched_t sched) { MPIR_FUNC_ENTER; diff --git a/src/mpid/ch4/shm/posix/posix_coll_release_gather.h b/src/mpid/ch4/shm/posix/posix_coll_release_gather.h index 72e53cc89c4..0c8b8f7fd50 100644 --- a/src/mpid/ch4/shm/posix/posix_coll_release_gather.h +++ b/src/mpid/ch4/shm/posix/posix_coll_release_gather.h @@ -39,6 +39,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_bcast_release_gather(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, MPIR_Comm * comm_ptr, + int coll_group, MPIR_Errflag_t errflag) { MPIR_FUNC_ENTER; @@ -152,7 +153,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_bcast_release_gather(void *buffer, goto fn_exit; fallback: /* Fall back to other algo as release_gather based bcast cannot be used */ - mpi_errno = MPIR_Bcast_impl(buffer, count, datatype, root, comm_ptr, errflag); + mpi_errno = MPIR_Bcast_impl(buffer, count, datatype, root, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); goto fn_exit; } @@ -168,6 +169,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_reduce_release_gather(const void *s MPI_Datatype datatype, MPI_Op op, int root, MPIR_Comm * comm_ptr, + int coll_group, MPIR_Errflag_t errflag) { int i; @@ -251,7 +253,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_reduce_release_gather(const void *s goto fn_exit; fallback: /* Fall back to other algo as release_gather algo cannot be used */ - mpi_errno = MPIR_Reduce_impl(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, errflag); + mpi_errno = MPIR_Reduce_impl(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); goto fn_exit; } @@ -266,6 +269,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_allreduce_release_gather(const void MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, + int coll_group, MPIR_Errflag_t errflag) { int i; @@ -346,7 +350,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_allreduce_release_gather(const void goto fn_exit; fallback: - mpi_errno = MPIR_Allreduce_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, errflag); + mpi_errno = MPIR_Allreduce_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); goto fn_exit; } @@ -355,6 +360,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_allreduce_release_gather(const void * framework. */ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_barrier_release_gather(MPIR_Comm * comm_ptr, + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -395,7 +401,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_barrier_release_gather(MPIR_Comm * goto fn_exit; fallback: - mpi_errno = MPIR_Barrier_impl(comm_ptr, errflag); + mpi_errno = MPIR_Barrier_impl(comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); goto fn_exit; } diff --git a/src/mpid/ch4/shm/posix/posix_init.c b/src/mpid/ch4/shm/posix/posix_init.c index dbdff1c4c8c..19372ff3e90 100644 --- a/src/mpid/ch4/shm/posix/posix_init.c +++ b/src/mpid/ch4/shm/posix/posix_init.c @@ -128,8 +128,7 @@ static void *create_container(struct json_object *obj) cnt->id = MPIDI_POSIX_CSEL_CONTAINER_TYPE__ALGORITHM__MPIDI_POSIX_mpi_bcast_release_gather; else if (!strcmp(ckey, "algorithm=MPIDI_POSIX_mpi_bcast_ipc_read")) - cnt->id = - MPIDI_POSIX_CSEL_CONTAINER_TYPE__ALGORITHM__MPIDI_POSIX_mpi_bcast_ipc_read; + cnt->id = MPIDI_POSIX_CSEL_CONTAINER_TYPE__ALGORITHM__MPIDI_POSIX_mpi_bcast_ipc_read; else if (!strcmp(ckey, "algorithm=MPIDI_POSIX_mpi_barrier_release_gather")) cnt->id = MPIDI_POSIX_CSEL_CONTAINER_TYPE__ALGORITHM__MPIDI_POSIX_mpi_barrier_release_gather; @@ -300,7 +299,8 @@ int MPIDI_POSIX_post_init(void) memset(local_rank_topo, 0, MPIDI_POSIX_global.num_local * topo_info_size); mpi_errno = MPIR_Allgather_fallback(&MPIDI_POSIX_global.topo, topo_info_size, MPI_BYTE, local_rank_topo, topo_info_size, MPI_BYTE, - MPIR_Process.comm_world->node_comm, MPIR_ERR_NONE); + MPIR_Process.comm_world->node_comm, MPIR_SUBGROUP_NONE, + MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); for (int i = 0; i < MPIDI_POSIX_global.num_local; i++) { if (local_rank_topo[i].l3_cache_id == -1 || local_rank_topo[i].numa_id == -1) { @@ -397,13 +397,15 @@ int MPIDI_POSIX_coll_init(int rank, int size) } MPIR_ERR_CHECK(mpi_errno); - /* Initialize collective selection for gpu*/ + /* Initialize collective selection for gpu */ if (!strcmp(MPIR_CVAR_CH4_POSIX_COLL_SELECTION_TUNING_JSON_FILE_GPU, "")) { mpi_errno = MPIR_Csel_create_from_buf(MPIDI_POSIX_coll_generic_json, - create_container, &MPIDI_global.shm.posix.csel_root_gpu); + create_container, + &MPIDI_global.shm.posix.csel_root_gpu); } else { - mpi_errno = MPIR_Csel_create_from_file(MPIR_CVAR_CH4_POSIX_COLL_SELECTION_TUNING_JSON_FILE_GPU, - create_container, &MPIDI_global.shm.posix.csel_root_gpu); + mpi_errno = + MPIR_Csel_create_from_file(MPIR_CVAR_CH4_POSIX_COLL_SELECTION_TUNING_JSON_FILE_GPU, + create_container, &MPIDI_global.shm.posix.csel_root_gpu); } MPIR_ERR_CHECK(mpi_errno); diff --git a/src/mpid/ch4/shm/posix/release_gather/nb_bcast_release_gather.h b/src/mpid/ch4/shm/posix/release_gather/nb_bcast_release_gather.h index 472abbd50ac..16b245a0c6f 100644 --- a/src/mpid/ch4/shm/posix/release_gather/nb_bcast_release_gather.h +++ b/src/mpid/ch4/shm/posix/release_gather/nb_bcast_release_gather.h @@ -94,12 +94,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_NB_RG_root_datacopy_completion(void *v, /* Root sends data to rank 0 */ if (rank == root) { MPIC_Isend(per_call_data->local_buf, per_call_data->count, per_call_data->datatype, - 0, per_call_data->tag, comm_ptr, &(per_call_data->sreq), MPIR_ERR_NONE); + 0, per_call_data->tag, comm_ptr, MPIR_SUBGROUP_NONE, + &(per_call_data->sreq), MPIR_ERR_NONE); *done = 1; } else if (rank == 0) { MPIC_Irecv(MPIDI_POSIX_RELEASE_GATHER_NB_IBCAST_DATA_ADDR(segment), per_call_data->count, per_call_data->datatype, per_call_data->root, - per_call_data->tag, comm_ptr, &(per_call_data->rreq)); + per_call_data->tag, comm_ptr, MPIR_SUBGROUP_NONE, + &(per_call_data->rreq)); *done = 1; } } else { @@ -353,6 +355,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_nb_release_gather_ibcast_impl(void *loc MPI_Aint type_size, nbytes, true_lb, true_extent; void *ori_local_buf = local_buf; MPI_Datatype ori_datatype = datatype; + int coll_group = MPIR_SUBGROUP_NONE; MPIR_CHKLMEM_DECL(1); /* Register the vertices */ @@ -423,7 +426,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_nb_release_gather_ibcast_impl(void *loc MPIR_TSP_sched_malloc(sizeof(MPIDI_POSIX_per_call_ibcast_info_t), sched); MPIR_ERR_CHKANDJUMP(!data, mpi_errno, MPI_ERR_OTHER, "**nomem"); - mpi_errno = MPIR_Sched_next_tag(comm_ptr, &tag); + mpi_errno = MPIR_Sched_next_tag(comm_ptr, coll_group, &tag); if (mpi_errno) MPIR_ERR_POP(mpi_errno); diff --git a/src/mpid/ch4/shm/posix/release_gather/nb_reduce_release_gather.h b/src/mpid/ch4/shm/posix/release_gather/nb_reduce_release_gather.h index 19f85901e92..cd6b2f239b2 100644 --- a/src/mpid/ch4/shm/posix/release_gather/nb_reduce_release_gather.h +++ b/src/mpid/ch4/shm/posix/release_gather/nb_reduce_release_gather.h @@ -248,11 +248,12 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_NB_RG_reduce_start_sendrecv_completion( if (root != 0) { if (rank == root) { MPIC_Irecv(per_call_data->recv_buf, per_call_data->count, per_call_data->datatype, - 0, per_call_data->tag, comm_ptr, &(per_call_data->rreq)); + 0, per_call_data->tag, comm_ptr, MPIR_SUBGROUP_NONE, &(per_call_data->rreq)); } else if (rank == 0) { MPIC_Isend(MPIDI_POSIX_RELEASE_GATHER_NB_REDUCE_DATA_ADDR(rank, segment), per_call_data->count, per_call_data->datatype, per_call_data->root, - per_call_data->tag, comm_ptr, &(per_call_data->sreq), MPIR_ERR_NONE); + per_call_data->tag, comm_ptr, MPIR_SUBGROUP_NONE, &(per_call_data->sreq), + MPIR_ERR_NONE); } } @@ -363,6 +364,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_nb_release_gather_ireduce_impl(void *se MPI_Aint num_chunks, chunk_count_floor, chunk_count_ceil; MPI_Aint true_extent, type_size, lb, extent; int offset = 0, is_contig; + int coll_group = MPIR_SUBGROUP_NONE; /* Register the vertices */ reserve_buf_type_id = MPIR_TSP_sched_new_type(sched, MPIDI_POSIX_NB_RG_rank0_hold_buf_issue, @@ -417,7 +419,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_nb_release_gather_ireduce_impl(void *se data->seq_no = MPIDI_POSIX_COMM(comm_ptr, nb_reduce_seq_no); - mpi_errno = MPIR_Sched_next_tag(comm_ptr, &tag); + mpi_errno = MPIR_Sched_next_tag(comm_ptr, coll_group, &tag); if (mpi_errno) MPIR_ERR_POP(mpi_errno); diff --git a/src/mpid/ch4/shm/posix/release_gather/nb_release_gather.c b/src/mpid/ch4/shm/posix/release_gather/nb_release_gather.c index c8f457a1b30..00874742d63 100644 --- a/src/mpid/ch4/shm/posix/release_gather/nb_release_gather.c +++ b/src/mpid/ch4/shm/posix/release_gather/nb_release_gather.c @@ -121,17 +121,19 @@ int MPIDI_POSIX_nb_release_gather_comm_init(MPIR_Comm * comm_ptr, to other algorithms.\n"); } fallback = 1; - MPIR_Bcast_impl(&fallback, 1, MPI_INT, 0, comm_ptr, errflag); + MPIR_Bcast_impl(&fallback, 1, MPI_INT, 0, comm_ptr, MPIR_SUBGROUP_NONE, errflag); MPIR_ERR_SETANDJUMP(mpi_errno, MPI_ERR_NO_MEM, "**nomem"); } else { /* More shm can be created, update the shared counter */ MPL_atomic_fetch_add_uint64(MPIDI_POSIX_shm_limit_counter, memory_to_be_allocated); fallback = 0; - mpi_errno = MPIR_Bcast_impl(&fallback, 1, MPI_INT, 0, comm_ptr, errflag); + mpi_errno = MPIR_Bcast_impl(&fallback, 1, MPI_INT, 0, comm_ptr, MPIR_SUBGROUP_NONE, + errflag); MPIR_ERR_CHECK(mpi_errno); } } else { - mpi_errno = MPIR_Bcast_impl(&fallback, 1, MPI_INT, 0, comm_ptr, errflag); + mpi_errno = MPIR_Bcast_impl(&fallback, 1, MPI_INT, 0, comm_ptr, MPIR_SUBGROUP_NONE, + errflag); MPIR_ERR_CHECK(mpi_errno); if (fallback) { MPIR_ERR_SETANDJUMP(mpi_errno, MPI_ERR_NO_MEM, "**nomem"); @@ -168,7 +170,7 @@ int MPIDI_POSIX_nb_release_gather_comm_init(MPIR_Comm * comm_ptr, topotree_fail[1] = -1; } mpi_errno = MPIR_Allreduce_impl(MPI_IN_PLACE, topotree_fail, 2, MPI_INT, - MPI_MAX, comm_ptr, errflag); + MPI_MAX, comm_ptr, MPIR_SUBGROUP_NONE, errflag); } else { topotree_fail[0] = -1; topotree_fail[1] = -1; @@ -266,7 +268,7 @@ int MPIDI_POSIX_nb_release_gather_comm_init(MPIR_Comm * comm_ptr, if (initialize_ibcast_buf || initialize_ireduce_buf) { /* Make sure all the flags are set before ranks start reading each other's flags from shm */ - mpi_errno = MPIR_Barrier_impl(comm_ptr, errflag); + mpi_errno = MPIR_Barrier_impl(comm_ptr, MPIR_SUBGROUP_NONE, errflag); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpid/ch4/shm/posix/release_gather/release_gather.c b/src/mpid/ch4/shm/posix/release_gather/release_gather.c index 8bcc1301332..6bb1507c886 100644 --- a/src/mpid/ch4/shm/posix/release_gather/release_gather.c +++ b/src/mpid/ch4/shm/posix/release_gather/release_gather.c @@ -303,17 +303,19 @@ int MPIDI_POSIX_mpi_release_gather_comm_init(MPIR_Comm * comm_ptr, } fallback = 1; - MPIR_Bcast_impl(&fallback, 1, MPI_INT, 0, comm_ptr, errflag); + MPIR_Bcast_impl(&fallback, 1, MPI_INT, 0, comm_ptr, MPIR_SUBGROUP_NONE, errflag); MPIR_ERR_SETANDJUMP(mpi_errno, MPI_ERR_NO_MEM, "**nomem"); } else { /* More shm can be created, update the shared counter */ MPL_atomic_fetch_add_uint64(MPIDI_POSIX_shm_limit_counter, memory_to_be_allocated); fallback = 0; - mpi_errno = MPIR_Bcast_impl(&fallback, 1, MPI_INT, 0, comm_ptr, errflag); + mpi_errno = MPIR_Bcast_impl(&fallback, 1, MPI_INT, 0, comm_ptr, MPIR_SUBGROUP_NONE, + errflag); MPIR_ERR_CHECK(mpi_errno); } } else { - mpi_errno = MPIR_Bcast_impl(&fallback, 1, MPI_INT, 0, comm_ptr, errflag); + mpi_errno = MPIR_Bcast_impl(&fallback, 1, MPI_INT, 0, comm_ptr, MPIR_SUBGROUP_NONE, + errflag); MPIR_ERR_CHECK(mpi_errno); if (fallback) { MPIR_ERR_SETANDJUMP(mpi_errno, MPI_ERR_NO_MEM, "**nomem"); @@ -359,7 +361,7 @@ int MPIDI_POSIX_mpi_release_gather_comm_init(MPIR_Comm * comm_ptr, topotree_fail[1] = -1; } mpi_errno = MPIR_Allreduce_impl(MPI_IN_PLACE, topotree_fail, 2, MPI_INT, - MPI_MAX, comm_ptr, errflag); + MPI_MAX, comm_ptr, MPIR_SUBGROUP_NONE, errflag); MPIR_ERR_CHECK(mpi_errno); } else { topotree_fail[0] = -1; @@ -423,7 +425,7 @@ int MPIDI_POSIX_mpi_release_gather_comm_init(MPIR_Comm * comm_ptr, release_gather_info_ptr->release_state); /* Make sure all the flags are set before ranks start reading each other's flags from shm */ - mpi_errno = MPIR_Barrier_impl(comm_ptr, errflag); + mpi_errno = MPIR_Barrier_impl(comm_ptr, MPIR_SUBGROUP_NONE, errflag); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpid/ch4/shm/posix/release_gather/release_gather.h b/src/mpid/ch4/shm/posix/release_gather/release_gather.h index a6dd2a07541..cb427a862dc 100644 --- a/src/mpid/ch4/shm/posix/release_gather/release_gather.h +++ b/src/mpid/ch4/shm/posix/release_gather/release_gather.h @@ -103,8 +103,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_release_gather_release(void *local_ if (root != 0) { /* Root sends data to rank 0 */ if (rank == root) { - mpi_errno = - MPIC_Send(local_buf, count, datatype, 0, MPIR_BCAST_TAG, comm_ptr, errflag); + mpi_errno = MPIC_Send(local_buf, count, datatype, 0, MPIR_BCAST_TAG, + comm_ptr, MPIR_SUBGROUP_NONE, errflag); MPIR_ERR_CHECK(mpi_errno); } else if (rank == 0) { #ifdef HAVE_ERROR_CHECKING @@ -116,8 +116,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_release_gather_release(void *local_ MPI_Aint recv_bytes; mpi_errno = MPIC_Recv((char *) bcast_data_addr + 2 * MPIDU_SHM_CACHE_LINE_LEN, count, - datatype, root, MPIR_BCAST_TAG, comm_ptr, &status); - MPIR_ERR_CHECK(mpi_errno); + datatype, root, MPIR_BCAST_TAG, comm_ptr, MPIR_SUBGROUP_NONE, + &status); MPIR_Get_count_impl(&status, MPI_BYTE, &recv_bytes); MPIR_Typerep_copy(bcast_data_addr, &recv_bytes, sizeof(int), MPIR_TYPEREP_FLAG_NONE); @@ -135,7 +135,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_release_gather_release(void *local_ /* When error checking is disabled, MPI_STATUS_IGNORE is used */ mpi_errno = MPIC_Recv(bcast_data_addr, count, datatype, root, MPIR_BCAST_TAG, comm_ptr, - MPI_STATUS_IGNORE); + MPIR_SUBGROUP_NONE, MPI_STATUS_IGNORE); MPIR_ERR_CHECK(mpi_errno); #endif } @@ -371,13 +371,13 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_release_gather_gather(const void *i if (rank == root) { mpi_errno = MPIC_Recv(outbuf, count, datatype, 0, MPIR_REDUCE_TAG, comm_ptr, - MPI_STATUS_IGNORE); + MPIR_SUBGROUP_NONE, MPI_STATUS_IGNORE); MPIR_ERR_CHECK(mpi_errno); } else if (rank == 0) { MPIR_ERR_CHKANDJUMP(!reduce_data_addr, mpi_errno, MPI_ERR_OTHER, "**nomem"); mpi_errno = MPIC_Send((void *) reduce_data_addr, count, datatype, root, MPIR_REDUCE_TAG, - comm_ptr, errflag); + comm_ptr, MPIR_SUBGROUP_NONE, errflag); MPIR_ERR_CHECK(mpi_errno); } } diff --git a/src/mpid/ch4/shm/src/shm_am_fallback_coll.h b/src/mpid/ch4/shm/src/shm_am_fallback_coll.h index b9e4d53ef40..58c9de5524c 100644 --- a/src/mpid/ch4/shm/src/shm_am_fallback_coll.h +++ b/src/mpid/ch4/shm/src/shm_am_fallback_coll.h @@ -6,33 +6,37 @@ #ifndef SHM_AM_FALLBACK_COLL_H_INCLUDED #define SHM_AM_FALLBACK_COLL_H_INCLUDED -MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_barrier(MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) +MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_barrier(MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { - return MPIR_Barrier_impl(comm_ptr, errflag); + return MPIR_Barrier_impl(comm_ptr, coll_group, errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_bcast(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { - return MPIR_Bcast_impl(buffer, count, datatype, root, comm_ptr, errflag); + return MPIR_Bcast_impl(buffer, count, datatype, root, comm_ptr, coll_group, errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_allreduce(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, - MPIR_Errflag_t errflag) + int coll_group, MPIR_Errflag_t errflag) { - return MPIR_Allreduce_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, errflag); + return MPIR_Allreduce_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, coll_group, + errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_allgather(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { return MPIR_Allgather_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm_ptr, errflag); + recvcount, recvtype, comm_ptr, coll_group, errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_allgatherv(const void *sendbuf, MPI_Aint sendcount, @@ -40,41 +44,41 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_allgatherv(const void *sendbuf, MPI_A const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, MPIR_Comm * comm_ptr, - MPIR_Errflag_t errflag) + int coll_group, MPIR_Errflag_t errflag) { return MPIR_Allgatherv_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcounts, displs, recvtype, comm_ptr, errflag); + recvcounts, displs, recvtype, comm_ptr, coll_group, errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_gather(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm_ptr, + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { return MPIR_Gather_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, root, comm_ptr, errflag); + recvcount, recvtype, root, comm_ptr, coll_group, errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_gatherv(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, - int root, MPIR_Comm * comm_ptr, + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { return MPIR_Gatherv_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcounts, displs, recvtype, root, comm_ptr, errflag); + recvcounts, displs, recvtype, root, comm_ptr, coll_group, errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_scatter(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm_ptr, + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { return MPIR_Scatter_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, root, comm_ptr, errflag); + recvcount, recvtype, root, comm_ptr, coll_group, errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_scatterv(const void *sendbuf, @@ -82,19 +86,21 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_scatterv(const void *sendbuf, const MPI_Aint * displs, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { return MPIR_Scatterv_impl(sendbuf, sendcounts, displs, sendtype, - recvbuf, recvcount, recvtype, root, comm_ptr, errflag); + recvbuf, recvcount, recvtype, root, comm_ptr, coll_group, errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_alltoall(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { return MPIR_Alltoall_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm_ptr, errflag); + recvcount, recvtype, comm_ptr, coll_group, errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_alltoallv(const void *sendbuf, @@ -104,10 +110,11 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_alltoallv(const void *sendbuf, const MPI_Aint * recvcounts, const MPI_Aint * rdispls, MPI_Datatype recvtype, MPIR_Comm * comm_ptr, - MPIR_Errflag_t errflag) + int coll_group, MPIR_Errflag_t errflag) { return MPIR_Alltoallv_impl(sendbuf, sendcounts, sdispls, sendtype, - recvbuf, recvcounts, rdispls, recvtype, comm_ptr, errflag); + recvbuf, recvcounts, rdispls, recvtype, comm_ptr, coll_group, + errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_alltoallw(const void *sendbuf, @@ -117,51 +124,58 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_alltoallw(const void *sendbuf, const MPI_Aint recvcounts[], const MPI_Aint rdispls[], const MPI_Datatype recvtypes[], - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { return MPIR_Alltoallw_impl(sendbuf, sendcounts, sdispls, sendtypes, - recvbuf, recvcounts, rdispls, recvtypes, comm_ptr, errflag); + recvbuf, recvcounts, rdispls, recvtypes, comm_ptr, coll_group, + errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_reduce(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - int root, MPIR_Comm * comm_ptr, + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { - return MPIR_Reduce_impl(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, errflag); + return MPIR_Reduce_impl(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, coll_group, + errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_reduce_scatter(const void *sendbuf, void *recvbuf, const MPI_Aint recvcounts[], MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, + MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { - return MPIR_Reduce_scatter_impl(sendbuf, recvbuf, recvcounts, datatype, op, comm_ptr, errflag); + return MPIR_Reduce_scatter_impl(sendbuf, recvbuf, recvcounts, datatype, op, comm_ptr, + coll_group, errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_reduce_scatter_block(const void *sendbuf, void *recvbuf, MPI_Aint recvcount, MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, + int coll_group, MPIR_Errflag_t errflag) { return MPIR_Reduce_scatter_block_impl(sendbuf, recvbuf, recvcount, - datatype, op, comm_ptr, errflag); + datatype, op, comm_ptr, coll_group, errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_scan(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { - return MPIR_Scan_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, errflag); + return MPIR_Scan_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, coll_group, errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_exscan(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { - return MPIR_Exscan_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, errflag); + return MPIR_Exscan_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, coll_group, errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_neighbor_allgather(const void *sendbuf, @@ -295,25 +309,28 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_ineighbor_alltoallw(const void *sendb rdispls, recvtypes, comm_ptr, req); } -MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_ibarrier(MPIR_Comm * comm_ptr, MPIR_Request ** req) +MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_ibarrier(MPIR_Comm * comm_ptr, int coll_group, + MPIR_Request ** req) { - return MPIR_Ibarrier_impl(comm_ptr, req); + return MPIR_Ibarrier_impl(comm_ptr, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_ibcast(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, - MPIR_Comm * comm_ptr, MPIR_Request ** req) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Request ** req) { - return MPIR_Ibcast_impl(buffer, count, datatype, root, comm_ptr, req); + return MPIR_Ibcast_impl(buffer, count, datatype, root, comm_ptr, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_iallgather(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Request ** req) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Request ** req) { return MPIR_Iallgather_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm_ptr, req); + recvcount, recvtype, comm_ptr, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_iallgatherv(const void *sendbuf, MPI_Aint sendcount, @@ -321,27 +338,28 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_iallgatherv(const void *sendbuf, MPI_ const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, MPIR_Comm * comm_ptr, - MPIR_Request ** req) + int coll_group, MPIR_Request ** req) { return MPIR_Iallgatherv_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcounts, displs, recvtype, comm_ptr, req); + recvcounts, displs, recvtype, comm_ptr, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_iallreduce(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm, + MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Request ** request) { - return MPIR_Iallreduce_impl(sendbuf, recvbuf, count, datatype, op, comm, request); + return MPIR_Iallreduce_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, request); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_ialltoall(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Request ** req) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Request ** req) { return MPIR_Ialltoall_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm_ptr, req); + recvcount, recvtype, comm_ptr, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_ialltoallv(const void *sendbuf, @@ -351,10 +369,10 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_ialltoallv(const void *sendbuf, const MPI_Aint * recvcounts, const MPI_Aint * rdispls, MPI_Datatype recvtype, MPIR_Comm * comm_ptr, - MPIR_Request ** req) + int coll_group, MPIR_Request ** req) { return MPIR_Ialltoallv_impl(sendbuf, sendcounts, sdispls, sendtype, - recvbuf, recvcounts, rdispls, recvtype, comm_ptr, req); + recvbuf, recvcounts, rdispls, recvtype, comm_ptr, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_ialltoallw(const void *sendbuf, @@ -364,82 +382,88 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_ialltoallw(const void *sendbuf, const MPI_Aint * recvcounts, const MPI_Aint * rdispls, const MPI_Datatype recvtypes[], - MPIR_Comm * comm_ptr, MPIR_Request ** req) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Request ** req) { return MPIR_Ialltoallw_impl(sendbuf, sendcounts, sdispls, sendtypes, - recvbuf, recvcounts, rdispls, recvtypes, comm_ptr, req); + recvbuf, recvcounts, rdispls, recvtypes, comm_ptr, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_iexscan(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, MPIR_Request ** req) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Request ** req) { - return MPIR_Iexscan_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, req); + return MPIR_Iexscan_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_igather(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm_ptr, + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Request ** req) { return MPIR_Igather_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, root, comm_ptr, req); + recvcount, recvtype, root, comm_ptr, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_igatherv(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, - int root, MPIR_Comm * comm_ptr, + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Request ** req) { return MPIR_Igatherv_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcounts, displs, recvtype, root, comm_ptr, req); + recvcounts, displs, recvtype, root, comm_ptr, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_ireduce_scatter_block(const void *sendbuf, void *recvbuf, MPI_Aint recvcount, MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, + int coll_group, MPIR_Request ** req) { return MPIR_Ireduce_scatter_block_impl(sendbuf, recvbuf, recvcount, - datatype, op, comm_ptr, req); + datatype, op, comm_ptr, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_ireduce_scatter(const void *sendbuf, void *recvbuf, const MPI_Aint recvcounts[], MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, + MPIR_Comm * comm_ptr, int coll_group, MPIR_Request ** req) { - return MPIR_Ireduce_scatter_impl(sendbuf, recvbuf, recvcounts, datatype, op, comm_ptr, req); + return MPIR_Ireduce_scatter_impl(sendbuf, recvbuf, recvcounts, datatype, op, comm_ptr, + coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_ireduce(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - int root, MPIR_Comm * comm_ptr, + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Request ** req) { - return MPIR_Ireduce_impl(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, req); + return MPIR_Ireduce_impl(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, coll_group, + req); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_iscan(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, MPIR_Request ** req) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Request ** req) { - return MPIR_Iscan_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, req); + return MPIR_Iscan_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_iscatter(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm, + int root, MPIR_Comm * comm, int coll_group, MPIR_Request ** request) { return MPIR_Iscatter_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, root, comm, request); + recvcount, recvtype, root, comm, coll_group, request); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_iscatterv(const void *sendbuf, @@ -447,10 +471,11 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_iscatterv(const void *sendbuf, const MPI_Aint * displs, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, - MPIR_Comm * comm, MPIR_Request ** request) + MPIR_Comm * comm, int coll_group, + MPIR_Request ** request) { return MPIR_Iscatterv_impl(sendbuf, sendcounts, displs, sendtype, - recvbuf, recvcount, recvtype, root, comm, request); + recvbuf, recvcount, recvtype, root, comm, coll_group, request); } #endif /* SHM_AM_FALLBACK_COLL_H_INCLUDED */ diff --git a/src/mpid/ch4/shm/src/shm_coll.h b/src/mpid/ch4/shm/src/shm_coll.h index 737cc921de2..dbde9186e24 100644 --- a/src/mpid/ch4/shm/src/shm_coll.h +++ b/src/mpid/ch4/shm/src/shm_coll.h @@ -9,13 +9,14 @@ #include #include "../posix/shm_inline.h" -MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_barrier(MPIR_Comm * comm, MPIR_Errflag_t errflag) +MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_barrier(MPIR_Comm * comm, int coll_group, + MPIR_Errflag_t errflag) { int ret; MPIR_FUNC_ENTER; - ret = MPIDI_POSIX_mpi_barrier(comm, errflag); + ret = MPIDI_POSIX_mpi_barrier(comm, coll_group, errflag); MPIR_FUNC_EXIT; return ret; @@ -23,13 +24,13 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_barrier(MPIR_Comm * comm, MPIR_Errfla MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_bcast(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, MPIR_Comm * comm, - MPIR_Errflag_t errflag) + int coll_group, MPIR_Errflag_t errflag) { int ret; MPIR_FUNC_ENTER; - ret = MPIDI_POSIX_mpi_bcast(buffer, count, datatype, root, comm, errflag); + ret = MPIDI_POSIX_mpi_bcast(buffer, count, datatype, root, comm, coll_group, errflag); MPIR_FUNC_EXIT; return ret; @@ -37,14 +38,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_bcast(void *buffer, MPI_Aint count, MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_allreduce(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm, + MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int ret; MPIR_FUNC_ENTER; - ret = MPIDI_POSIX_mpi_allreduce(sendbuf, recvbuf, count, datatype, op, comm, errflag); + ret = + MPIDI_POSIX_mpi_allreduce(sendbuf, recvbuf, count, datatype, op, comm, coll_group, errflag); MPIR_FUNC_EXIT; return ret; @@ -53,14 +55,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_allreduce(const void *sendbuf, void * MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_allgather(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm, MPIR_Errflag_t errflag) + MPIR_Comm * comm, int coll_group, + MPIR_Errflag_t errflag) { int ret; MPIR_FUNC_ENTER; ret = MPIDI_POSIX_mpi_allgather(sendbuf, sendcount, sendtype, recvbuf, recvcount, - recvtype, comm, errflag); + recvtype, comm, coll_group, errflag); MPIR_FUNC_EXIT; return ret; @@ -71,14 +74,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_allgatherv(const void *sendbuf, MPI_A const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, MPIR_Comm * comm, - MPIR_Errflag_t errflag) + int coll_group, MPIR_Errflag_t errflag) { int ret; MPIR_FUNC_ENTER; ret = MPIDI_POSIX_mpi_allgatherv(sendbuf, sendcount, sendtype, recvbuf, recvcounts, - displs, recvtype, comm, errflag); + displs, recvtype, comm, coll_group, errflag); MPIR_FUNC_EXIT; return ret; @@ -87,7 +90,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_allgatherv(const void *sendbuf, MPI_A MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_scatter(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm, + int root, MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int ret; @@ -95,7 +98,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_scatter(const void *sendbuf, MPI_Aint MPIR_FUNC_ENTER; ret = MPIDI_POSIX_mpi_scatter(sendbuf, sendcount, sendtype, recvbuf, recvcount, - recvtype, root, comm, errflag); + recvtype, root, comm, coll_group, errflag); MPIR_FUNC_EXIT; return ret; @@ -106,14 +109,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_scatterv(const void *sendbuf, const MPI_Aint * displs, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int ret; MPIR_FUNC_ENTER; ret = MPIDI_POSIX_mpi_scatterv(sendbuf, sendcounts, displs, sendtype, recvbuf, - recvcount, recvtype, root, comm_ptr, errflag); + recvcount, recvtype, root, comm_ptr, coll_group, errflag); MPIR_FUNC_EXIT; return ret; @@ -122,7 +126,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_scatterv(const void *sendbuf, MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_gather(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm, + int root, MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int ret; @@ -130,7 +134,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_gather(const void *sendbuf, MPI_Aint MPIR_FUNC_ENTER; ret = MPIDI_POSIX_mpi_gather(sendbuf, sendcount, sendtype, recvbuf, recvcount, - recvtype, root, comm, errflag); + recvtype, root, comm, coll_group, errflag); MPIR_FUNC_EXIT; return ret; @@ -140,7 +144,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_gatherv(const void *sendbuf, MPI_Aint MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, - int root, MPIR_Comm * comm, + int root, MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int ret; @@ -148,7 +152,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_gatherv(const void *sendbuf, MPI_Aint MPIR_FUNC_ENTER; ret = MPIDI_POSIX_mpi_gatherv(sendbuf, sendcount, sendtype, recvbuf, recvcounts, - displs, recvtype, root, comm, errflag); + displs, recvtype, root, comm, coll_group, errflag); MPIR_FUNC_EXIT; return ret; @@ -157,14 +161,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_gatherv(const void *sendbuf, MPI_Aint MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_alltoall(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm, MPIR_Errflag_t errflag) + MPIR_Comm * comm, int coll_group, + MPIR_Errflag_t errflag) { int ret; MPIR_FUNC_ENTER; ret = MPIDI_POSIX_mpi_alltoall(sendbuf, sendcount, sendtype, recvbuf, recvcount, - recvtype, comm, errflag); + recvtype, comm, coll_group, errflag); MPIR_FUNC_EXIT; return ret; @@ -177,14 +182,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_alltoallv(const void *sendbuf, const MPI_Aint * recvcounts, const MPI_Aint * rdispls, MPI_Datatype recvtype, MPIR_Comm * comm, - MPIR_Errflag_t errflag) + int coll_group, MPIR_Errflag_t errflag) { int ret; MPIR_FUNC_ENTER; ret = MPIDI_POSIX_mpi_alltoallv(sendbuf, sendcounts, sdispls, sendtype, recvbuf, - recvcounts, rdispls, recvtype, comm, errflag); + recvcounts, rdispls, recvtype, comm, coll_group, errflag); MPIR_FUNC_EXIT; return ret; @@ -197,14 +202,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_alltoallw(const void *sendbuf, const MPI_Aint * recvcounts, const MPI_Aint * rdispls, const MPI_Datatype recvtypes[], - MPIR_Comm * comm, MPIR_Errflag_t errflag) + MPIR_Comm * comm, int coll_group, + MPIR_Errflag_t errflag) { int ret; MPIR_FUNC_ENTER; ret = MPIDI_POSIX_mpi_alltoallw(sendbuf, sendcounts, sdispls, sendtypes, recvbuf, - recvcounts, rdispls, recvtypes, comm, errflag); + recvcounts, rdispls, recvtypes, comm, coll_group, errflag); MPIR_FUNC_EXIT; return ret; @@ -212,14 +218,16 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_alltoallw(const void *sendbuf, MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_reduce(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - int root, MPIR_Comm * comm_ptr, + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int ret; MPIR_FUNC_ENTER; - ret = MPIDI_POSIX_mpi_reduce(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, errflag); + ret = + MPIDI_POSIX_mpi_reduce(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, coll_group, + errflag); MPIR_FUNC_EXIT; return ret; @@ -228,7 +236,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_reduce(const void *sendbuf, void *rec MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_reduce_scatter(const void *sendbuf, void *recvbuf, const MPI_Aint * recvcounts, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, + MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int ret; @@ -236,7 +244,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_reduce_scatter(const void *sendbuf, v MPIR_FUNC_ENTER; ret = MPIDI_POSIX_mpi_reduce_scatter(sendbuf, recvbuf, recvcounts, datatype, op, - comm_ptr, errflag); + comm_ptr, coll_group, errflag); MPIR_FUNC_EXIT; return ret; @@ -246,6 +254,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_reduce_scatter_block(const void *send void *recvbuf, MPI_Aint recvcount, MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, + int coll_group, MPIR_Errflag_t errflag) { int ret; @@ -253,7 +262,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_reduce_scatter_block(const void *send MPIR_FUNC_ENTER; ret = MPIDI_POSIX_mpi_reduce_scatter_block(sendbuf, recvbuf, recvcount, datatype, - op, comm_ptr, errflag); + op, comm_ptr, coll_group, errflag); MPIR_FUNC_EXIT; return ret; @@ -261,13 +270,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_reduce_scatter_block(const void *send MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_scan(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm, MPIR_Errflag_t errflag) + MPIR_Comm * comm, int coll_group, + MPIR_Errflag_t errflag) { int ret; MPIR_FUNC_ENTER; - ret = MPIDI_POSIX_mpi_scan(sendbuf, recvbuf, count, datatype, op, comm, errflag); + ret = MPIDI_POSIX_mpi_scan(sendbuf, recvbuf, count, datatype, op, comm, coll_group, errflag); MPIR_FUNC_EXIT; return ret; @@ -275,13 +285,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_scan(const void *sendbuf, void *recvb MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_exscan(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm, MPIR_Errflag_t errflag) + MPIR_Comm * comm, int coll_group, + MPIR_Errflag_t errflag) { int ret; MPIR_FUNC_ENTER; - ret = MPIDI_POSIX_mpi_exscan(sendbuf, recvbuf, count, datatype, op, comm, errflag); + ret = MPIDI_POSIX_mpi_exscan(sendbuf, recvbuf, count, datatype, op, comm, coll_group, errflag); MPIR_FUNC_EXIT; return ret; @@ -486,13 +497,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_ineighbor_alltoallw(const void *sendb return ret; } -MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_ibarrier(MPIR_Comm * comm, MPIR_Request ** req) +MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_ibarrier(MPIR_Comm * comm, int coll_group, + MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; - ret = MPIDI_POSIX_mpi_ibarrier(comm, req); + ret = MPIDI_POSIX_mpi_ibarrier(comm, coll_group, req); MPIR_FUNC_EXIT; return ret; @@ -500,13 +512,13 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_ibarrier(MPIR_Comm * comm, MPIR_Reque MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_ibcast(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, MPIR_Comm * comm, - MPIR_Request ** req) + int coll_group, MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; - ret = MPIDI_POSIX_mpi_ibcast(buffer, count, datatype, root, comm, req); + ret = MPIDI_POSIX_mpi_ibcast(buffer, count, datatype, root, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; @@ -515,14 +527,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_ibcast(void *buffer, MPI_Aint count, MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_iallgather(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm, MPIR_Request ** req) + MPIR_Comm * comm, int coll_group, + MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; ret = MPIDI_POSIX_mpi_iallgather(sendbuf, sendcount, sendtype, recvbuf, recvcount, - recvtype, comm, req); + recvtype, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; @@ -533,14 +546,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_iallgatherv(const void *sendbuf, MPI_ const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, MPIR_Comm * comm, - MPIR_Request ** req) + int coll_group, MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; ret = MPIDI_POSIX_mpi_iallgatherv(sendbuf, sendcount, sendtype, recvbuf, recvcounts, - displs, recvtype, comm, req); + displs, recvtype, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; @@ -548,14 +561,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_iallgatherv(const void *sendbuf, MPI_ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_iallreduce(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm, + MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; - ret = MPIDI_POSIX_mpi_iallreduce(sendbuf, recvbuf, count, datatype, op, comm, req); + ret = MPIDI_POSIX_mpi_iallreduce(sendbuf, recvbuf, count, datatype, op, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; @@ -564,14 +577,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_iallreduce(const void *sendbuf, void MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_ialltoall(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm, MPIR_Request ** req) + MPIR_Comm * comm, int coll_group, + MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; ret = MPIDI_POSIX_mpi_ialltoall(sendbuf, sendcount, sendtype, recvbuf, recvcount, - recvtype, comm, req); + recvtype, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; @@ -584,14 +598,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_ialltoallv(const void *sendbuf, const MPI_Aint * recvcounts, const MPI_Aint * rdispls, MPI_Datatype recvtype, MPIR_Comm * comm, - MPIR_Request ** req) + int coll_group, MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; ret = MPIDI_POSIX_mpi_ialltoallv(sendbuf, sendcounts, sdispls, sendtype, recvbuf, - recvcounts, rdispls, recvtype, comm, req); + recvcounts, rdispls, recvtype, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; @@ -604,14 +618,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_ialltoallw(const void *sendbuf, const MPI_Aint * recvcounts, const MPI_Aint * rdispls, const MPI_Datatype recvtypes[], - MPIR_Comm * comm, MPIR_Request ** req) + MPIR_Comm * comm, int coll_group, + MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; ret = MPIDI_POSIX_mpi_ialltoallw(sendbuf, sendcounts, sdispls, sendtypes, recvbuf, - recvcounts, rdispls, recvtypes, comm, req); + recvcounts, rdispls, recvtypes, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; @@ -619,13 +634,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_ialltoallw(const void *sendbuf, MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_iexscan(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm, MPIR_Request ** req) + MPIR_Comm * comm, int coll_group, + MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; - ret = MPIDI_POSIX_mpi_iexscan(sendbuf, recvbuf, count, datatype, op, comm, req); + ret = MPIDI_POSIX_mpi_iexscan(sendbuf, recvbuf, count, datatype, op, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; @@ -634,14 +650,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_iexscan(const void *sendbuf, void *re MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_igather(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm, MPIR_Request ** req) + int root, MPIR_Comm * comm, int coll_group, + MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; ret = MPIDI_POSIX_mpi_igather(sendbuf, sendcount, sendtype, recvbuf, recvcount, - recvtype, root, comm, req); + recvtype, root, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; @@ -651,14 +668,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_igatherv(const void *sendbuf, MPI_Ain MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, - int root, MPIR_Comm * comm, MPIR_Request ** req) + int root, MPIR_Comm * comm, int coll_group, + MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; ret = MPIDI_POSIX_mpi_igatherv(sendbuf, sendcount, sendtype, recvbuf, recvcounts, - displs, recvtype, root, comm, req); + displs, recvtype, root, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; @@ -667,7 +685,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_igatherv(const void *sendbuf, MPI_Ain MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_ireduce_scatter_block(const void *sendbuf, void *recvbuf, MPI_Aint recvcount, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm, + MPIR_Comm * comm, int coll_group, MPIR_Request ** req) { int ret; @@ -675,7 +693,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_ireduce_scatter_block(const void *sen MPIR_FUNC_ENTER; ret = MPIDI_POSIX_mpi_ireduce_scatter_block(sendbuf, recvbuf, recvcount, datatype, - op, comm, req); + op, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; @@ -684,13 +702,16 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_ireduce_scatter_block(const void *sen MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_ireduce_scatter(const void *sendbuf, void *recvbuf, const MPI_Aint * recvcounts, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm, MPIR_Request ** req) + MPIR_Comm * comm, int coll_group, + MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; - ret = MPIDI_POSIX_mpi_ireduce_scatter(sendbuf, recvbuf, recvcounts, datatype, op, comm, req); + ret = + MPIDI_POSIX_mpi_ireduce_scatter(sendbuf, recvbuf, recvcounts, datatype, op, comm, + coll_group, req); MPIR_FUNC_EXIT; return ret; @@ -698,14 +719,16 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_ireduce_scatter(const void *sendbuf, MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_ireduce(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - int root, MPIR_Comm * comm_ptr, + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; - ret = MPIDI_POSIX_mpi_ireduce(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, req); + ret = + MPIDI_POSIX_mpi_ireduce(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, coll_group, + req); MPIR_FUNC_EXIT; return ret; @@ -713,13 +736,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_ireduce(const void *sendbuf, void *re MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_iscan(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm, MPIR_Request ** req) + MPIR_Comm * comm, int coll_group, + MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; - ret = MPIDI_POSIX_mpi_iscan(sendbuf, recvbuf, count, datatype, op, comm, req); + ret = MPIDI_POSIX_mpi_iscan(sendbuf, recvbuf, count, datatype, op, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; @@ -728,14 +752,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_iscan(const void *sendbuf, void *recv MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_iscatter(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm, MPIR_Request ** req) + int root, MPIR_Comm * comm, int coll_group, + MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; ret = MPIDI_POSIX_mpi_iscatter(sendbuf, sendcount, sendtype, recvbuf, recvcount, - recvtype, root, comm, req); + recvtype, root, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; @@ -746,14 +771,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_iscatterv(const void *sendbuf, const MPI_Aint * displs, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, - MPIR_Comm * comm_ptr, MPIR_Request ** req) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; ret = MPIDI_POSIX_mpi_iscatterv(sendbuf, sendcounts, displs, sendtype, recvbuf, - recvcount, recvtype, root, comm_ptr, req); + recvcount, recvtype, root, comm_ptr, coll_group, req); MPIR_FUNC_EXIT; return ret; diff --git a/src/mpid/ch4/shm/src/topotree.c b/src/mpid/ch4/shm/src/topotree.c index 2f71547962c..5ccc6fd62be 100644 --- a/src/mpid/ch4/shm/src/topotree.c +++ b/src/mpid/ch4/shm/src/topotree.c @@ -500,7 +500,7 @@ int MPIDI_SHM_topology_tree_init(MPIR_Comm * comm_ptr, int root, int bcast_k, in shared_region_ptr[rank][depth++] = MPIR_hwtopo_get_lid(gid); gid = MPIR_hwtopo_get_ancestor(gid, topo_depth - depth - 1); } - mpi_errno = MPIR_Barrier_impl(comm_ptr, errflag); + mpi_errno = MPIR_Barrier_impl(comm_ptr, MPIR_SUBGROUP_NONE, errflag); MPIR_ERR_CHECK(mpi_errno); /* STEP 3. Root has all the bind_map information, now build tree */ @@ -558,7 +558,7 @@ int MPIDI_SHM_topology_tree_init(MPIR_Comm * comm_ptr, int root, int bcast_k, in 0 /*left_skewed */ , bcast_tree_type); MPIR_ERR_CHECK(mpi_errno); } - mpi_errno = MPIR_Barrier_impl(comm_ptr, errflag); + mpi_errno = MPIR_Barrier_impl(comm_ptr, MPIR_SUBGROUP_NONE, errflag); MPIR_ERR_CHECK(mpi_errno); /* Every rank copies their tree out from shared memory */ @@ -567,7 +567,7 @@ int MPIDI_SHM_topology_tree_init(MPIR_Comm * comm_ptr, int root, int bcast_k, in MPIDI_SHM_print_topotree_file("BCAST", comm_ptr->context_id, rank, bcast_tree); /* Wait until shared memory is available */ - mpi_errno = MPIR_Barrier_impl(comm_ptr, errflag); + mpi_errno = MPIR_Barrier_impl(comm_ptr, MPIR_SUBGROUP_NONE, errflag); MPIR_ERR_CHECK(mpi_errno); /* Generate the reduce tree */ /* For Reduce, package leaders are added after the package local ranks, and the per_package @@ -581,7 +581,7 @@ int MPIDI_SHM_topology_tree_init(MPIR_Comm * comm_ptr, int root, int bcast_k, in MPIR_ERR_CHECK(mpi_errno); } - mpi_errno = MPIR_Barrier_impl(comm_ptr, errflag); + mpi_errno = MPIR_Barrier_impl(comm_ptr, MPIR_SUBGROUP_NONE, errflag); MPIR_ERR_CHECK(mpi_errno); /* each rank copy the reduce tree out */ @@ -590,7 +590,7 @@ int MPIDI_SHM_topology_tree_init(MPIR_Comm * comm_ptr, int root, int bcast_k, in if (MPIDI_SHM_TOPOTREE_DEBUG) MPIDI_SHM_print_topotree_file("REDUCE", comm_ptr->context_id, rank, reduce_tree); /* Wait for all ranks to copy out the tree */ - mpi_errno = MPIR_Barrier_impl(comm_ptr, errflag); + mpi_errno = MPIR_Barrier_impl(comm_ptr, MPIR_SUBGROUP_NONE, errflag); MPIR_ERR_CHECK(mpi_errno); /* Cleanup */ if (rank == root) { diff --git a/src/mpid/ch4/src/ch4_coll.h b/src/mpid/ch4/src/ch4_coll.h index 8251157856c..1979b53bd59 100644 --- a/src/mpid/ch4/src/ch4_coll.h +++ b/src/mpid/ch4/src/ch4_coll.h @@ -100,6 +100,7 @@ */ MPL_STATIC_INLINE_PREFIX int MPIDI_Barrier_allcomm_composition_json(MPIR_Comm * comm, + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -108,28 +109,31 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Barrier_allcomm_composition_json(MPIR_Comm * MPIR_Csel_coll_sig_s coll_sig = { .coll_type = MPIR_CSEL_COLL_TYPE__BARRIER, .comm_ptr = comm, + .coll_group = coll_group, }; cnt = MPIR_Csel_search(MPIDI_COMM(comm, csel_comm), coll_sig); if (cnt == NULL) { - mpi_errno = MPIR_Barrier_impl(comm, errflag); - MPIR_ERR_CHECK(mpi_errno); - goto fn_exit; + goto fallback; } switch (cnt->id) { case MPIDI_CSEL_CONTAINER_TYPE__COMPOSITION__MPIDI_Barrier_intra_composition_alpha: - mpi_errno = MPIDI_Barrier_intra_composition_alpha(comm, errflag); + mpi_errno = MPIDI_Barrier_intra_composition_alpha(comm, coll_group, errflag); break; case MPIDI_CSEL_CONTAINER_TYPE__COMPOSITION__MPIDI_Barrier_intra_composition_beta: - mpi_errno = MPIDI_Barrier_intra_composition_beta(comm, errflag); + mpi_errno = MPIDI_Barrier_intra_composition_beta(comm, coll_group, errflag); break; default: MPIR_Assert(0); } MPIR_ERR_CHECK(mpi_errno); + goto fn_exit; + + fallback: + mpi_errno = MPIR_Barrier_impl(comm, coll_group, errflag); fn_exit: return mpi_errno; @@ -137,30 +141,32 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Barrier_allcomm_composition_json(MPIR_Comm * goto fn_exit; } -MPL_STATIC_INLINE_PREFIX int MPID_Barrier(MPIR_Comm * comm, MPIR_Errflag_t errflag) +MPL_STATIC_INLINE_PREFIX int MPID_Barrier(MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } switch (MPIR_CVAR_BARRIER_COMPOSITION) { case 1: MPII_COLLECTIVE_FALLBACK_CHECK(comm->rank, (comm->comm_kind == MPIR_COMM_KIND__INTRACOMM) && - (comm->hierarchy_kind == - MPIR_COMM_HIERARCHY_KIND__PARENT), mpi_errno, + MPIR_Comm_is_parent_comm(comm, coll_group), mpi_errno, "Barrier composition alpha cannot be applied.\n"); - mpi_errno = MPIDI_Barrier_intra_composition_alpha(comm, errflag); + mpi_errno = MPIDI_Barrier_intra_composition_alpha(comm, coll_group, errflag); break; case 2: MPII_COLLECTIVE_FALLBACK_CHECK(comm->rank, comm->comm_kind == MPIR_COMM_KIND__INTRACOMM, mpi_errno, "Barrier composition beta cannot be applied.\n"); - mpi_errno = MPIDI_Barrier_intra_composition_beta(comm, errflag); + mpi_errno = MPIDI_Barrier_intra_composition_beta(comm, coll_group, errflag); break; default: - mpi_errno = MPIDI_Barrier_allcomm_composition_json(comm, errflag); + mpi_errno = MPIDI_Barrier_allcomm_composition_json(comm, coll_group, errflag); break; } @@ -168,10 +174,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Barrier(MPIR_Comm * comm, MPIR_Errflag_t errfl goto fn_exit; fallback: - if (comm->comm_kind == MPIR_COMM_KIND__INTERCOMM) - mpi_errno = MPIR_Barrier_impl(comm, errflag); - else - mpi_errno = MPIDI_Barrier_intra_composition_beta(comm, errflag); + mpi_errno = MPIR_Barrier_impl(comm, coll_group, errflag); fn_exit: MPIR_FUNC_EXIT; @@ -182,13 +185,14 @@ MPL_STATIC_INLINE_PREFIX int MPID_Barrier(MPIR_Comm * comm, MPIR_Errflag_t errfl MPL_STATIC_INLINE_PREFIX int MPIDI_Bcast_allcomm_composition_json(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, - MPIR_Comm * comm, + MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_Csel_coll_sig_s coll_sig = { .coll_type = MPIR_CSEL_COLL_TYPE__BCAST, .comm_ptr = comm, + .coll_group = coll_group, .u.bcast.buffer = buffer, .u.bcast.count = count, .u.bcast.datatype = datatype, @@ -199,46 +203,50 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Bcast_allcomm_composition_json(void *buffer, if (MPIR_CVAR_COLL_HYBRID_MEMORY) { cnt = MPIR_Csel_search(MPIDI_COMM(comm, csel_comm), coll_sig); - } - else { + } else { /* In no hybird case, local memory type can be used to select algorithm */ MPL_pointer_attr_t pointer_attr; MPIR_GPU_query_pointer_attr(buffer, &pointer_attr); if (pointer_attr.type == MPL_GPU_POINTER_DEV) { cnt = MPIR_Csel_search(MPIDI_COMM(comm, csel_comm_gpu), coll_sig); - } - else { + } else { cnt = MPIR_Csel_search(MPIDI_COMM(comm, csel_comm), coll_sig); } } if (cnt == NULL) { - mpi_errno = MPIR_Bcast_impl(buffer, count, datatype, root, comm, errflag); - MPIR_ERR_CHECK(mpi_errno); - goto fn_exit; + goto fallback; } switch (cnt->id) { case MPIDI_CSEL_CONTAINER_TYPE__COMPOSITION__MPIDI_Bcast_intra_composition_alpha: mpi_errno = - MPIDI_Bcast_intra_composition_alpha(buffer, count, datatype, root, comm, errflag); + MPIDI_Bcast_intra_composition_alpha(buffer, count, datatype, root, comm, coll_group, + errflag); break; case MPIDI_CSEL_CONTAINER_TYPE__COMPOSITION__MPIDI_Bcast_intra_composition_beta: mpi_errno = - MPIDI_Bcast_intra_composition_beta(buffer, count, datatype, root, comm, errflag); + MPIDI_Bcast_intra_composition_beta(buffer, count, datatype, root, comm, coll_group, + errflag); break; case MPIDI_CSEL_CONTAINER_TYPE__COMPOSITION__MPIDI_Bcast_intra_composition_gamma: mpi_errno = - MPIDI_Bcast_intra_composition_gamma(buffer, count, datatype, root, comm, errflag); + MPIDI_Bcast_intra_composition_gamma(buffer, count, datatype, root, comm, coll_group, + errflag); break; case MPIDI_CSEL_CONTAINER_TYPE__COMPOSITION__MPIDI_Bcast_intra_composition_delta: mpi_errno = - MPIDI_Bcast_intra_composition_delta(buffer, count, datatype, root, comm, errflag); + MPIDI_Bcast_intra_composition_delta(buffer, count, datatype, root, comm, coll_group, + errflag); break; default: MPIR_Assert(0); } MPIR_ERR_CHECK(mpi_errno); + goto fn_exit; + + fallback: + mpi_errno = MPIR_Bcast_impl(buffer, count, datatype, root, comm, coll_group, errflag); fn_exit: return mpi_errno; @@ -247,50 +255,57 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Bcast_allcomm_composition_json(void *buffer, } MPL_STATIC_INLINE_PREFIX int MPID_Bcast(void *buffer, MPI_Aint count, MPI_Datatype datatype, - int root, MPIR_Comm * comm, MPIR_Errflag_t errflag) + int root, MPIR_Comm * comm, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + switch (MPIR_CVAR_BCAST_COMPOSITION) { case 1: MPII_COLLECTIVE_FALLBACK_CHECK(comm->rank, (comm->comm_kind == MPIR_COMM_KIND__INTRACOMM) && - (comm->hierarchy_kind == - MPIR_COMM_HIERARCHY_KIND__PARENT), mpi_errno, + MPIR_Comm_is_parent_comm(comm, coll_group), mpi_errno, "Bcast composition alpha cannot be applied.\n"); mpi_errno = - MPIDI_Bcast_intra_composition_alpha(buffer, count, datatype, root, comm, errflag); + MPIDI_Bcast_intra_composition_alpha(buffer, count, datatype, root, comm, coll_group, + errflag); break; case 2: MPII_COLLECTIVE_FALLBACK_CHECK(comm->rank, (comm->comm_kind == MPIR_COMM_KIND__INTRACOMM) && - (comm->hierarchy_kind == - MPIR_COMM_HIERARCHY_KIND__PARENT), mpi_errno, + MPIR_Comm_is_parent_comm(comm, coll_group), mpi_errno, "Bcast composition beta cannot be applied.\n"); mpi_errno = - MPIDI_Bcast_intra_composition_beta(buffer, count, datatype, root, comm, errflag); + MPIDI_Bcast_intra_composition_beta(buffer, count, datatype, root, comm, coll_group, + errflag); break; case 3: MPII_COLLECTIVE_FALLBACK_CHECK(comm->rank, comm->comm_kind == MPIR_COMM_KIND__INTRACOMM, mpi_errno, "Bcast composition gamma cannot be applied.\n"); mpi_errno = - MPIDI_Bcast_intra_composition_gamma(buffer, count, datatype, root, comm, errflag); + MPIDI_Bcast_intra_composition_gamma(buffer, count, datatype, root, comm, coll_group, + errflag); break; case 4: MPII_COLLECTIVE_FALLBACK_CHECK(comm->rank, (comm->comm_kind == MPIR_COMM_KIND__INTRACOMM) && - (comm->hierarchy_kind == - MPIR_COMM_HIERARCHY_KIND__PARENT), mpi_errno, + MPIR_Comm_is_parent_comm(comm, coll_group), mpi_errno, "Bcast composition delta cannot be applied.\n"); mpi_errno = - MPIDI_Bcast_intra_composition_delta(buffer, count, datatype, root, comm, errflag); + MPIDI_Bcast_intra_composition_delta(buffer, count, datatype, root, comm, coll_group, + errflag); break; default: mpi_errno = - MPIDI_Bcast_allcomm_composition_json(buffer, count, datatype, root, comm, errflag); + MPIDI_Bcast_allcomm_composition_json(buffer, count, datatype, root, comm, + coll_group, errflag); break; } @@ -298,11 +313,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Bcast(void *buffer, MPI_Aint count, MPI_Dataty goto fn_exit; fallback: - if (comm->comm_kind == MPIR_COMM_KIND__INTERCOMM) - mpi_errno = MPIR_Bcast_impl(buffer, count, datatype, root, comm, errflag); - else - mpi_errno = - MPIDI_Bcast_intra_composition_gamma(buffer, count, datatype, root, comm, errflag); + mpi_errno = MPIR_Bcast_impl(buffer, count, datatype, root, comm, coll_group, errflag); fn_exit: MPIR_FUNC_EXIT; @@ -311,7 +322,8 @@ MPL_STATIC_INLINE_PREFIX int MPID_Bcast(void *buffer, MPI_Aint count, MPI_Dataty goto fn_exit; } -MPL_STATIC_INLINE_PREFIX void MPIDI_Allreduce_fill_multi_leads_info(MPIR_Comm * comm) +MPL_STATIC_INLINE_PREFIX void MPIDI_Allreduce_fill_multi_leads_info(MPIR_Comm * comm, + int coll_group) { int node_comm_size = 0, num_nodes; bool node_balanced = false; @@ -346,6 +358,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allreduce_allcomm_composition_json(const void void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -355,6 +368,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allreduce_allcomm_composition_json(const void MPIR_Csel_coll_sig_s coll_sig = { .coll_type = MPIR_CSEL_COLL_TYPE__ALLREDUCE, .comm_ptr = comm, + .coll_group = coll_group, .u.allreduce.sendbuf = sendbuf, .u.allreduce.recvbuf = recvbuf, @@ -366,30 +380,28 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allreduce_allcomm_composition_json(const void cnt = MPIR_Csel_search(MPIDI_COMM(comm, csel_comm), coll_sig); if (cnt == NULL) { - mpi_errno = MPIR_Allreduce_impl(sendbuf, recvbuf, count, datatype, op, comm, errflag); - MPIR_ERR_CHECK(mpi_errno); - goto fn_exit; + goto fallback; } switch (cnt->id) { case MPIDI_CSEL_CONTAINER_TYPE__COMPOSITION__MPIDI_Allreduce_intra_composition_alpha: mpi_errno = MPIDI_Allreduce_intra_composition_alpha(sendbuf, recvbuf, count, datatype, op, - comm, errflag); + comm, coll_group, errflag); break; case MPIDI_CSEL_CONTAINER_TYPE__COMPOSITION__MPIDI_Allreduce_intra_composition_beta: mpi_errno = MPIDI_Allreduce_intra_composition_beta(sendbuf, recvbuf, count, datatype, op, - comm, errflag); + comm, coll_group, errflag); break; case MPIDI_CSEL_CONTAINER_TYPE__COMPOSITION__MPIDI_Allreduce_intra_composition_gamma: mpi_errno = MPIDI_Allreduce_intra_composition_gamma(sendbuf, recvbuf, count, datatype, op, - comm, errflag); + comm, coll_group, errflag); break; case MPIDI_CSEL_CONTAINER_TYPE__COMPOSITION__MPIDI_Allreduce_intra_composition_delta: if (comm->comm_kind == MPIR_COMM_KIND__INTRACOMM) { - MPIDI_Allreduce_fill_multi_leads_info(comm); + MPIDI_Allreduce_fill_multi_leads_info(comm, coll_group); if (comm->node_comm) node_comm_size = MPIR_Comm_size(comm->node_comm); /* Reset number of leaders, so that (node_comm_size % num_leads) is zero. The new number of @@ -404,16 +416,22 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allreduce_allcomm_composition_json(const void count >= num_leads && MPIR_Op_is_commutative(op)) { mpi_errno = MPIDI_Allreduce_intra_composition_delta(sendbuf, recvbuf, count, datatype, op, - num_leads, comm, errflag); + num_leads, comm, coll_group, errflag); } else mpi_errno = - MPIR_Allreduce_impl(sendbuf, recvbuf, count, datatype, op, comm, errflag); + MPIR_Allreduce_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, + errflag); break; default: MPIR_Assert(0); } MPIR_ERR_CHECK(mpi_errno); + goto fn_exit; + + fallback: + mpi_errno = MPIR_Allreduce_impl(sendbuf, recvbuf, count, datatype, op, + comm, coll_group, errflag); fn_exit: return mpi_errno; @@ -423,7 +441,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allreduce_allcomm_composition_json(const void MPL_STATIC_INLINE_PREFIX int MPID_Allreduce(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, - MPIR_Errflag_t errflag) + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; int is_commutative = -1; @@ -431,19 +449,22 @@ MPL_STATIC_INLINE_PREFIX int MPID_Allreduce(const void *sendbuf, void *recvbuf, MPIR_FUNC_ENTER; + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + is_commutative = MPIR_Op_is_commutative(op); switch (MPIR_CVAR_ALLREDUCE_COMPOSITION) { case 1: MPII_COLLECTIVE_FALLBACK_CHECK(comm->rank, (comm->comm_kind == MPIR_COMM_KIND__INTRACOMM) && - (comm->hierarchy_kind == - MPIR_COMM_HIERARCHY_KIND__PARENT) && + MPIR_Comm_is_parent_comm(comm, coll_group) && is_commutative, mpi_errno, "Allreduce composition alpha cannot be applied.\n"); mpi_errno = MPIDI_Allreduce_intra_composition_alpha(sendbuf, recvbuf, count, datatype, op, comm, - errflag); + coll_group, errflag); break; case 2: MPII_COLLECTIVE_FALLBACK_CHECK(comm->rank, @@ -451,7 +472,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Allreduce(const void *sendbuf, void *recvbuf, "Allreduce composition beta cannot be applied.\n"); mpi_errno = MPIDI_Allreduce_intra_composition_beta(sendbuf, recvbuf, count, datatype, op, comm, - errflag); + coll_group, errflag); break; case 3: MPII_COLLECTIVE_FALLBACK_CHECK(comm->rank, @@ -462,11 +483,11 @@ MPL_STATIC_INLINE_PREFIX int MPID_Allreduce(const void *sendbuf, void *recvbuf, "Allreduce composition gamma cannot be applied.\n"); mpi_errno = MPIDI_Allreduce_intra_composition_gamma(sendbuf, recvbuf, count, datatype, op, comm, - errflag); + coll_group, errflag); break; case 4: if (comm->comm_kind == MPIR_COMM_KIND__INTRACOMM) { - MPIDI_Allreduce_fill_multi_leads_info(comm); + MPIDI_Allreduce_fill_multi_leads_info(comm, coll_group); if (comm->node_comm) node_comm_size = MPIR_Comm_size(comm->node_comm); /* Reset number of leaders, so that (node_comm_size % num_leads) is zero. The new number of @@ -487,13 +508,13 @@ MPL_STATIC_INLINE_PREFIX int MPID_Allreduce(const void *sendbuf, void *recvbuf, mpi_errno = MPIDI_Allreduce_intra_composition_delta(sendbuf, recvbuf, count, datatype, op, - num_leads, comm, errflag); + num_leads, comm, coll_group, errflag); break; default: mpi_errno = MPIDI_Allreduce_allcomm_composition_json(sendbuf, recvbuf, count, datatype, op, - comm, errflag); + comm, coll_group, errflag); break; } @@ -502,11 +523,12 @@ MPL_STATIC_INLINE_PREFIX int MPID_Allreduce(const void *sendbuf, void *recvbuf, fallback: if (comm->comm_kind == MPIR_COMM_KIND__INTERCOMM) - mpi_errno = MPIR_Allreduce_impl(sendbuf, recvbuf, count, datatype, op, comm, errflag); + mpi_errno = + MPIR_Allreduce_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, errflag); else mpi_errno = MPIDI_Allreduce_intra_composition_beta(sendbuf, recvbuf, count, datatype, op, comm, - errflag); + coll_group, errflag); fn_exit: MPIR_FUNC_EXIT; @@ -515,7 +537,8 @@ MPL_STATIC_INLINE_PREFIX int MPID_Allreduce(const void *sendbuf, void *recvbuf, goto fn_exit; } -MPL_STATIC_INLINE_PREFIX void MPIDI_Allgather_fill_multi_leads_info(MPIR_Comm * comm) +MPL_STATIC_INLINE_PREFIX void MPIDI_Allgather_fill_multi_leads_info(MPIR_Comm * comm, + int coll_group) { int node_comm_size = 0, num_nodes; bool node_balanced = false; @@ -550,6 +573,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allgather_allcomm_composition_json(const void MPI_Aint recvcount, MPI_Datatype recvtype, MPIR_Comm * comm, + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -558,15 +582,16 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allgather_allcomm_composition_json(const void if (sendbuf != MPI_IN_PLACE) { MPIR_Datatype_get_size_macro(sendtype, type_size); - data_size = sendcount * type_size; + data_size = sendcount * type_size; } else { MPIR_Datatype_get_size_macro(recvtype, type_size); - data_size = recvcount * type_size; + data_size = recvcount * type_size; } MPIR_Csel_coll_sig_s coll_sig = { .coll_type = MPIR_CSEL_COLL_TYPE__ALLGATHER, .comm_ptr = comm, + .coll_group = coll_group, .u.allgather.sendbuf = sendbuf, .u.allgather.sendcount = sendcount, @@ -579,26 +604,22 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allgather_allcomm_composition_json(const void cnt = MPIR_Csel_search(MPIDI_COMM(comm, csel_comm), coll_sig); if (cnt == NULL) { - mpi_errno = - MPIR_Allgather_impl(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, - errflag); - MPIR_ERR_CHECK(mpi_errno); - goto fn_exit; + goto fallback; } switch (cnt->id) { case MPIDI_CSEL_CONTAINER_TYPE__COMPOSITION__MPIDI_Allgather_intra_composition_alpha: /* make sure that the algo can be run */ if (comm->comm_kind == MPIR_COMM_KIND__INTRACOMM) - MPIDI_Allgather_fill_multi_leads_info(comm); + MPIDI_Allgather_fill_multi_leads_info(comm, coll_group); MPII_COLLECTIVE_FALLBACK_CHECK(comm->rank, MPIDI_COMM_ALLGATHER(comm, use_multi_leads) == 1 && data_size <= MPIR_CVAR_ALLGATHER_SHM_PER_RANK, mpi_errno, "Allgather composition alpha cannot be applied.\n"); - mpi_errno = - MPIDI_Allgather_intra_composition_alpha(sendbuf, sendcount, sendtype, - recvbuf, recvcount, recvtype, - comm, errflag); + mpi_errno = + MPIDI_Allgather_intra_composition_alpha(sendbuf, sendcount, sendtype, + recvbuf, recvcount, recvtype, + comm, coll_group, errflag); break; case MPIDI_CSEL_CONTAINER_TYPE__COMPOSITION__MPIDI_Allgather_intra_composition_beta: MPII_COLLECTIVE_FALLBACK_CHECK(comm->rank, @@ -606,7 +627,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allgather_allcomm_composition_json(const void "Allgather composition beta cannot be applied.\n"); mpi_errno = MPIDI_Allgather_intra_composition_beta(sendbuf, sendcount, sendtype, - recvbuf, recvcount, recvtype, comm, errflag); + recvbuf, recvcount, recvtype, comm, + coll_group, errflag); break; default: MPIR_Assert(0); @@ -616,14 +638,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allgather_allcomm_composition_json(const void goto fn_exit; fallback: - if (comm->comm_kind == MPIR_COMM_KIND__INTERCOMM) - mpi_errno = - MPIR_Allgather_impl(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, - errflag); - else - mpi_errno = - MPIDI_Allgather_intra_composition_beta(sendbuf, sendcount, sendtype, recvbuf, recvcount, - recvtype, comm, errflag); + mpi_errno = MPIR_Allgather_impl(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, + comm, coll_group, errflag); fn_exit: return mpi_errno; @@ -634,25 +650,30 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allgather_allcomm_composition_json(const void MPL_STATIC_INLINE_PREFIX int MPID_Allgather(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm, MPIR_Errflag_t errflag) + MPIR_Comm * comm, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPI_Aint type_size, data_size; MPIR_FUNC_ENTER; + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + if (sendbuf != MPI_IN_PLACE) { MPIR_Datatype_get_size_macro(sendtype, type_size); - data_size = sendcount * type_size; + data_size = sendcount * type_size; } else { MPIR_Datatype_get_size_macro(recvtype, type_size); - data_size = recvcount * type_size; + data_size = recvcount * type_size; } switch (MPIR_CVAR_ALLGATHER_COMPOSITION) { case 1: if (comm->comm_kind == MPIR_COMM_KIND__INTRACOMM) - MPIDI_Allgather_fill_multi_leads_info(comm); + MPIDI_Allgather_fill_multi_leads_info(comm, coll_group); MPII_COLLECTIVE_FALLBACK_CHECK(comm->rank, MPIDI_COMM_ALLGATHER(comm, use_multi_leads) == 1 && @@ -666,7 +687,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Allgather(const void *sendbuf, MPI_Aint sendco mpi_errno = MPIDI_Allgather_intra_composition_alpha(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, - comm, errflag); + comm, coll_group, errflag); break; case 2: MPII_COLLECTIVE_FALLBACK_CHECK(comm->rank, @@ -674,12 +695,13 @@ MPL_STATIC_INLINE_PREFIX int MPID_Allgather(const void *sendbuf, MPI_Aint sendco "Allgather composition beta cannot be applied.\n"); mpi_errno = MPIDI_Allgather_intra_composition_beta(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm, errflag); + recvcount, recvtype, comm, coll_group, + errflag); break; default: mpi_errno = MPIDI_Allgather_allcomm_composition_json(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, - errflag); + coll_group, errflag); break; } @@ -687,14 +709,8 @@ MPL_STATIC_INLINE_PREFIX int MPID_Allgather(const void *sendbuf, MPI_Aint sendco goto fn_exit; fallback: - if (comm->comm_kind == MPIR_COMM_KIND__INTERCOMM) - mpi_errno = - MPIR_Allgather_impl(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, - errflag); - else - mpi_errno = - MPIDI_Allgather_intra_composition_beta(sendbuf, sendcount, sendtype, recvbuf, recvcount, - recvtype, comm, errflag); + mpi_errno = MPIR_Allgather_impl(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, + comm, coll_group, errflag); fn_exit: MPIR_FUNC_ENTER; @@ -707,14 +723,19 @@ MPL_STATIC_INLINE_PREFIX int MPID_Allgatherv(const void *sendbuf, MPI_Aint sendc MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, MPIR_Comm * comm, - MPIR_Errflag_t errflag) + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; const MPIDI_Csel_container_s *cnt = NULL; + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + MPIR_Csel_coll_sig_s coll_sig = { .coll_type = MPIR_CSEL_COLL_TYPE__ALLGATHERV, .comm_ptr = comm, + .coll_group = coll_group, .u.allgatherv.sendbuf = sendbuf, .u.allgatherv.sendcount = sendcount, @@ -730,11 +751,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Allgatherv(const void *sendbuf, MPI_Aint sendc cnt = MPIR_Csel_search(MPIDI_COMM(comm, csel_comm), coll_sig); if (cnt == NULL) { - mpi_errno = - MPIR_Allgatherv_impl(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, - recvtype, comm, errflag); - MPIR_ERR_CHECK(mpi_errno); - goto fn_exit; + goto fallback; } switch (cnt->id) { @@ -742,14 +759,18 @@ MPL_STATIC_INLINE_PREFIX int MPID_Allgatherv(const void *sendbuf, MPI_Aint sendc mpi_errno = MPIDI_Allgatherv_intra_composition_alpha(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, - recvtype, comm, errflag); + recvtype, comm, coll_group, errflag); break; default: MPIR_Assert(0); } MPIR_ERR_CHECK(mpi_errno); + goto fn_exit; + fallback: + mpi_errno = MPIR_Allgatherv_impl(sendbuf, sendcount, sendtype, recvbuf, recvcounts, + displs, recvtype, comm, coll_group, errflag); fn_exit: MPIR_FUNC_EXIT; return mpi_errno; @@ -760,14 +781,19 @@ MPL_STATIC_INLINE_PREFIX int MPID_Allgatherv(const void *sendbuf, MPI_Aint sendc MPL_STATIC_INLINE_PREFIX int MPID_Scatter(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, MPIR_Comm * comm, - MPIR_Errflag_t errflag) + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; const MPIDI_Csel_container_s *cnt = NULL; + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + MPIR_Csel_coll_sig_s coll_sig = { .coll_type = MPIR_CSEL_COLL_TYPE__SCATTER, .comm_ptr = comm, + .coll_group = coll_group, .u.scatter.sendbuf = sendbuf, .u.scatter.sendcount = sendcount, @@ -783,23 +809,26 @@ MPL_STATIC_INLINE_PREFIX int MPID_Scatter(const void *sendbuf, MPI_Aint sendcoun cnt = MPIR_Csel_search(MPIDI_COMM(comm, csel_comm), coll_sig); if (cnt == NULL) { - MPIR_Scatter_impl(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm, - errflag); - MPIR_ERR_CHECK(mpi_errno); - goto fn_exit; + goto fallback; } switch (cnt->id) { case MPIDI_CSEL_CONTAINER_TYPE__COMPOSITION__MPIDI_Scatter_intra_composition_alpha: mpi_errno = MPIDI_Scatter_intra_composition_alpha(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, root, comm, errflag); + recvcount, recvtype, root, comm, coll_group, + errflag); break; default: MPIR_Assert(0); } MPIR_ERR_CHECK(mpi_errno); + goto fn_exit; + + fallback: + mpi_errno = MPIR_Scatter_impl(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, + comm, coll_group, errflag); fn_exit: MPIR_FUNC_EXIT; @@ -811,14 +840,20 @@ MPL_STATIC_INLINE_PREFIX int MPID_Scatter(const void *sendbuf, MPI_Aint sendcoun MPL_STATIC_INLINE_PREFIX int MPID_Scatterv(const void *sendbuf, const MPI_Aint * sendcounts, const MPI_Aint * displs, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm, MPIR_Errflag_t errflag) + int root, MPIR_Comm * comm, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; const MPIDI_Csel_container_s *cnt = NULL; + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + MPIR_Csel_coll_sig_s coll_sig = { .coll_type = MPIR_CSEL_COLL_TYPE__SCATTERV, .comm_ptr = comm, + .coll_group = coll_group, .u.scatterv.sendbuf = sendbuf, .u.scatterv.sendcounts = sendcounts, @@ -835,10 +870,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Scatterv(const void *sendbuf, const MPI_Aint * cnt = MPIR_Csel_search(MPIDI_COMM(comm, csel_comm), coll_sig); if (cnt == NULL) { - MPIR_Scatterv_impl(sendbuf, sendcounts, displs, sendtype, recvbuf, recvcount, recvtype, - root, comm, errflag); - MPIR_ERR_CHECK(mpi_errno); - goto fn_exit; + goto fallback; } switch (cnt->id) { @@ -846,13 +878,18 @@ MPL_STATIC_INLINE_PREFIX int MPID_Scatterv(const void *sendbuf, const MPI_Aint * mpi_errno = MPIDI_Scatterv_intra_composition_alpha(sendbuf, sendcounts, displs, sendtype, recvbuf, recvcount, recvtype, root, - comm, errflag); + comm, coll_group, errflag); break; default: MPIR_Assert(0); } MPIR_ERR_CHECK(mpi_errno); + goto fn_exit; + + fallback: + mpi_errno = MPIR_Scatterv_impl(sendbuf, sendcounts, displs, sendtype, recvbuf, recvcount, + recvtype, root, comm, coll_group, errflag); fn_exit: MPIR_FUNC_EXIT; @@ -864,14 +901,19 @@ MPL_STATIC_INLINE_PREFIX int MPID_Scatterv(const void *sendbuf, const MPI_Aint * MPL_STATIC_INLINE_PREFIX int MPID_Gather(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, MPIR_Comm * comm, - MPIR_Errflag_t errflag) + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; const MPIDI_Csel_container_s *cnt = NULL; + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + MPIR_Csel_coll_sig_s coll_sig = { .coll_type = MPIR_CSEL_COLL_TYPE__GATHER, .comm_ptr = comm, + .coll_group = coll_group, .u.gather.sendbuf = sendbuf, .u.gather.sendcount = sendcount, @@ -887,24 +929,26 @@ MPL_STATIC_INLINE_PREFIX int MPID_Gather(const void *sendbuf, MPI_Aint sendcount cnt = MPIR_Csel_search(MPIDI_COMM(comm, csel_comm), coll_sig); if (cnt == NULL) { - mpi_errno = - MPIR_Gather_impl(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm, - errflag); - MPIR_ERR_CHECK(mpi_errno); - goto fn_exit; + goto fallback; } switch (cnt->id) { case MPIDI_CSEL_CONTAINER_TYPE__COMPOSITION__MPIDI_Gather_intra_composition_alpha: mpi_errno = MPIDI_Gather_intra_composition_alpha(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, root, comm, errflag); + recvcount, recvtype, root, comm, coll_group, + errflag); break; default: MPIR_Assert(0); } MPIR_ERR_CHECK(mpi_errno); + goto fn_exit; + + fallback: + mpi_errno = MPIR_Gather_impl(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, + comm, coll_group, errflag); fn_exit: MPIR_FUNC_EXIT; @@ -917,14 +961,19 @@ MPL_STATIC_INLINE_PREFIX int MPID_Gatherv(const void *sendbuf, MPI_Aint sendcoun MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, int root, MPIR_Comm * comm, - MPIR_Errflag_t errflag) + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; const MPIDI_Csel_container_s *cnt = NULL; + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + MPIR_Csel_coll_sig_s coll_sig = { .coll_type = MPIR_CSEL_COLL_TYPE__GATHERV, .comm_ptr = comm, + .coll_group = coll_group, .u.gatherv.sendbuf = sendbuf, .u.gatherv.sendcount = sendcount, @@ -941,10 +990,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Gatherv(const void *sendbuf, MPI_Aint sendcoun cnt = MPIR_Csel_search(MPIDI_COMM(comm, csel_comm), coll_sig); if (cnt == NULL) { - mpi_errno = MPIR_Gatherv_impl(sendbuf, sendcount, sendtype, recvbuf, recvcounts, - displs, recvtype, root, comm, errflag); - MPIR_ERR_CHECK(mpi_errno); - goto fn_exit; + goto fallback; } switch (cnt->id) { @@ -952,13 +998,18 @@ MPL_STATIC_INLINE_PREFIX int MPID_Gatherv(const void *sendbuf, MPI_Aint sendcoun mpi_errno = MPIDI_Gatherv_intra_composition_alpha(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, root, - comm, errflag); + comm, coll_group, errflag); break; default: MPIR_Assert(0); } MPIR_ERR_CHECK(mpi_errno); + goto fn_exit; + + fallback: + mpi_errno = MPIR_Gatherv_impl(sendbuf, sendcount, sendtype, recvbuf, recvcounts, + displs, recvtype, root, comm, coll_group, errflag); fn_exit: MPIR_FUNC_EXIT; @@ -967,7 +1018,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Gatherv(const void *sendbuf, MPI_Aint sendcoun goto fn_exit; } -MPL_STATIC_INLINE_PREFIX void MPIDI_Alltoall_fill_multi_leads_info(MPIR_Comm * comm) +MPL_STATIC_INLINE_PREFIX void MPIDI_Alltoall_fill_multi_leads_info(MPIR_Comm * comm, int coll_group) { int node_comm_size = 0, num_nodes; bool node_balanced = false; @@ -1002,6 +1053,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Alltoall_allcomm_composition_json(const void MPI_Aint recvcount, MPI_Datatype recvtype, MPIR_Comm * comm, + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -1009,10 +1061,10 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Alltoall_allcomm_composition_json(const void if (sendbuf != MPI_IN_PLACE) { MPIR_Datatype_get_size_macro(sendtype, type_size); - data_size = sendcount * type_size; + data_size = sendcount * type_size; } else { MPIR_Datatype_get_size_macro(recvtype, type_size); - data_size = recvcount * type_size; + data_size = recvcount * type_size; } const MPIDI_Csel_container_s *cnt = NULL; @@ -1020,6 +1072,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Alltoall_allcomm_composition_json(const void MPIR_Csel_coll_sig_s coll_sig = { .coll_type = MPIR_CSEL_COLL_TYPE__ALLTOALL, .comm_ptr = comm, + .coll_group = coll_group, .u.alltoall.sendbuf = sendbuf, .u.alltoall.sendcount = sendcount, @@ -1032,24 +1085,21 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Alltoall_allcomm_composition_json(const void cnt = MPIR_Csel_search(MPIDI_COMM(comm, csel_comm), coll_sig); if (cnt == NULL) { - mpi_errno = MPIR_Alltoall_impl(sendbuf, sendcount, sendtype, - recvbuf, recvcount, recvtype, comm, errflag); - MPIR_ERR_CHECK(mpi_errno); - goto fn_exit; + goto fallback; } switch (cnt->id) { case MPIDI_CSEL_CONTAINER_TYPE__COMPOSITION__MPIDI_Alltoall_intra_composition_alpha: if (comm->comm_kind == MPIR_COMM_KIND__INTRACOMM) - MPIDI_Alltoall_fill_multi_leads_info(comm); + MPIDI_Alltoall_fill_multi_leads_info(comm, coll_group); MPII_COLLECTIVE_FALLBACK_CHECK(comm->rank, comm->comm_kind == MPIR_COMM_KIND__INTRACOMM && MPIDI_COMM_ALLTOALL(comm, use_multi_leads) == 1 && data_size <= MPIR_CVAR_ALLTOALL_SHM_PER_RANK, mpi_errno, "Alltoall composition alpha cannot be applied.\n"); - mpi_errno = - MPIDI_Alltoall_intra_composition_alpha(sendbuf, sendcount, sendtype, - recvbuf, recvcount, recvtype, - comm, errflag); + mpi_errno = + MPIDI_Alltoall_intra_composition_alpha(sendbuf, sendcount, sendtype, + recvbuf, recvcount, recvtype, + comm, coll_group, errflag); break; case MPIDI_CSEL_CONTAINER_TYPE__COMPOSITION__MPIDI_Alltoall_intra_composition_beta: @@ -1058,7 +1108,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Alltoall_allcomm_composition_json(const void "Alltoall composition beta cannot be applied.\n"); mpi_errno = MPIDI_Alltoall_intra_composition_beta(sendbuf, sendcount, sendtype, - recvbuf, recvcount, recvtype, comm, errflag); + recvbuf, recvcount, recvtype, comm, + coll_group, errflag); break; default: @@ -1069,13 +1120,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Alltoall_allcomm_composition_json(const void goto fn_exit; fallback: - if (comm->comm_kind == MPIR_COMM_KIND__INTERCOMM) - mpi_errno = - MPIR_Alltoall_impl(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, - errflag); - else - mpi_errno = MPIDI_Alltoall_intra_composition_beta(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm, errflag); + mpi_errno = MPIR_Alltoall_impl(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, + comm, coll_group, errflag); fn_exit: return mpi_errno; @@ -1085,7 +1131,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Alltoall_allcomm_composition_json(const void MPL_STATIC_INLINE_PREFIX int MPID_Alltoall(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, - MPI_Datatype recvtype, MPIR_Comm * comm, + MPI_Datatype recvtype, MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -1093,18 +1139,22 @@ MPL_STATIC_INLINE_PREFIX int MPID_Alltoall(const void *sendbuf, MPI_Aint sendcou MPIR_FUNC_ENTER; + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + if (sendbuf != MPI_IN_PLACE) { MPIR_Datatype_get_size_macro(sendtype, type_size); - data_size = sendcount * type_size; + data_size = sendcount * type_size; } else { MPIR_Datatype_get_size_macro(recvtype, type_size); - data_size = recvcount * type_size; + data_size = recvcount * type_size; } switch (MPIR_CVAR_ALLTOALL_COMPOSITION) { case 1: if (comm->comm_kind == MPIR_COMM_KIND__INTRACOMM) - MPIDI_Alltoall_fill_multi_leads_info(comm); + MPIDI_Alltoall_fill_multi_leads_info(comm, coll_group); MPII_COLLECTIVE_FALLBACK_CHECK(comm->rank, comm->comm_kind == MPIR_COMM_KIND__INTRACOMM && MPIDI_COMM_ALLTOALL(comm, use_multi_leads) == 1 && @@ -1117,7 +1167,8 @@ MPL_STATIC_INLINE_PREFIX int MPID_Alltoall(const void *sendbuf, MPI_Aint sendcou mpi_errno = MPIDI_Alltoall_intra_composition_alpha(sendbuf, sendcount, sendtype, - recvbuf, recvcount, recvtype, comm, errflag); + recvbuf, recvcount, recvtype, comm, + coll_group, errflag); break; case 2: MPII_COLLECTIVE_FALLBACK_CHECK(comm->rank, @@ -1125,12 +1176,13 @@ MPL_STATIC_INLINE_PREFIX int MPID_Alltoall(const void *sendbuf, MPI_Aint sendcou "Alltoall composition beta cannot be applied.\n"); mpi_errno = MPIDI_Alltoall_intra_composition_beta(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm, errflag); + recvcount, recvtype, comm, coll_group, + errflag); break; default: mpi_errno = MPIDI_Alltoall_allcomm_composition_json(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, - errflag); + coll_group, errflag); break; } @@ -1138,13 +1190,8 @@ MPL_STATIC_INLINE_PREFIX int MPID_Alltoall(const void *sendbuf, MPI_Aint sendcou goto fn_exit; fallback: - if (comm->comm_kind == MPIR_COMM_KIND__INTERCOMM) - mpi_errno = - MPIR_Alltoall_impl(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, - errflag); - else - mpi_errno = MPIDI_Alltoall_intra_composition_beta(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm, errflag); + mpi_errno = MPIR_Alltoall_impl(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, + comm, coll_group, errflag); fn_exit: MPIR_FUNC_ENTER; @@ -1157,14 +1204,20 @@ MPL_STATIC_INLINE_PREFIX int MPID_Alltoallv(const void *sendbuf, const MPI_Aint const MPI_Aint * sdispls, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * rdispls, MPI_Datatype recvtype, - MPIR_Comm * comm, MPIR_Errflag_t errflag) + MPIR_Comm * comm, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; const MPIDI_Csel_container_s *cnt = NULL; + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + MPIR_Csel_coll_sig_s coll_sig = { .coll_type = MPIR_CSEL_COLL_TYPE__ALLTOALLV, .comm_ptr = comm, + .coll_group = coll_group, .u.alltoallv.sendbuf = sendbuf, .u.alltoallv.sendcounts = sendcounts, @@ -1181,11 +1234,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Alltoallv(const void *sendbuf, const MPI_Aint cnt = MPIR_Csel_search(MPIDI_COMM(comm, csel_comm), coll_sig); if (cnt == NULL) { - mpi_errno = MPIR_Alltoallv_impl(sendbuf, sendcounts, sdispls, - sendtype, recvbuf, recvcounts, - rdispls, recvtype, comm, errflag); - MPIR_ERR_CHECK(mpi_errno); - goto fn_exit; + goto fallback; } switch (cnt->id) { @@ -1193,13 +1242,19 @@ MPL_STATIC_INLINE_PREFIX int MPID_Alltoallv(const void *sendbuf, const MPI_Aint mpi_errno = MPIDI_Alltoallv_intra_composition_alpha(sendbuf, sendcounts, sdispls, sendtype, recvbuf, recvcounts, - rdispls, recvtype, comm, errflag); + rdispls, recvtype, comm, coll_group, + errflag); break; default: MPIR_Assert(0); } MPIR_ERR_CHECK(mpi_errno); + goto fn_exit; + + fallback: + mpi_errno = MPIR_Alltoallv_impl(sendbuf, sendcounts, sdispls, sendtype, recvbuf, recvcounts, + rdispls, recvtype, comm, coll_group, errflag); fn_exit: MPIR_FUNC_EXIT; @@ -1213,14 +1268,19 @@ MPL_STATIC_INLINE_PREFIX int MPID_Alltoallw(const void *sendbuf, const MPI_Aint const MPI_Datatype sendtypes[], void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint rdispls[], const MPI_Datatype recvtypes[], MPIR_Comm * comm, - MPIR_Errflag_t errflag) + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; const MPIDI_Csel_container_s *cnt = NULL; + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + MPIR_Csel_coll_sig_s coll_sig = { .coll_type = MPIR_CSEL_COLL_TYPE__ALLTOALLW, .comm_ptr = comm, + .coll_group = coll_group, .u.alltoallw.sendbuf = sendbuf, .u.alltoallw.sendcounts = sendcounts, @@ -1237,11 +1297,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Alltoallw(const void *sendbuf, const MPI_Aint cnt = MPIR_Csel_search(MPIDI_COMM(comm, csel_comm), coll_sig); if (cnt == NULL) { - mpi_errno = MPIR_Alltoallw_impl(sendbuf, sendcounts, sdispls, - sendtypes, recvbuf, recvcounts, - rdispls, recvtypes, comm, errflag); - MPIR_ERR_CHECK(mpi_errno); - goto fn_exit; + goto fallback; } switch (cnt->id) { @@ -1249,13 +1305,19 @@ MPL_STATIC_INLINE_PREFIX int MPID_Alltoallw(const void *sendbuf, const MPI_Aint mpi_errno = MPIDI_Alltoallw_intra_composition_alpha(sendbuf, sendcounts, sdispls, sendtypes, recvbuf, recvcounts, - rdispls, recvtypes, comm, errflag); + rdispls, recvtypes, comm, coll_group, + errflag); break; default: MPIR_Assert(0); } MPIR_ERR_CHECK(mpi_errno); + goto fn_exit; + + fallback: + mpi_errno = MPIR_Alltoallw_impl(sendbuf, sendcounts, sdispls, sendtypes, recvbuf, recvcounts, + rdispls, recvtypes, comm, coll_group, errflag); fn_exit: MPIR_FUNC_EXIT; @@ -1268,6 +1330,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Reduce_allcomm_composition_json(const void *s void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, int root, MPIR_Comm * comm, + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -1276,6 +1339,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Reduce_allcomm_composition_json(const void *s MPIR_Csel_coll_sig_s coll_sig = { .coll_type = MPIR_CSEL_COLL_TYPE__REDUCE, .comm_ptr = comm, + .coll_group = coll_group, .u.reduce.sendbuf = sendbuf, .u.reduce.recvbuf = recvbuf, @@ -1288,7 +1352,9 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Reduce_allcomm_composition_json(const void *s cnt = MPIR_Csel_search(MPIDI_COMM(comm, csel_comm), coll_sig); if (cnt == NULL) { - mpi_errno = MPIR_Reduce_impl(sendbuf, recvbuf, count, datatype, op, root, comm, errflag); + mpi_errno = + MPIR_Reduce_impl(sendbuf, recvbuf, count, datatype, op, root, comm, coll_group, + errflag); MPIR_ERR_CHECK(mpi_errno); goto fn_exit; } @@ -1297,17 +1363,17 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Reduce_allcomm_composition_json(const void *s case MPIDI_CSEL_CONTAINER_TYPE__COMPOSITION__MPIDI_Reduce_intra_composition_alpha: mpi_errno = MPIDI_Reduce_intra_composition_alpha(sendbuf, recvbuf, count, datatype, op, - root, comm, errflag); + root, comm, coll_group, errflag); break; case MPIDI_CSEL_CONTAINER_TYPE__COMPOSITION__MPIDI_Reduce_intra_composition_beta: mpi_errno = MPIDI_Reduce_intra_composition_beta(sendbuf, recvbuf, count, datatype, op, - root, comm, errflag); + root, comm, coll_group, errflag); break; case MPIDI_CSEL_CONTAINER_TYPE__COMPOSITION__MPIDI_Reduce_intra_composition_gamma: mpi_errno = MPIDI_Reduce_intra_composition_gamma(sendbuf, recvbuf, count, datatype, op, - root, comm, errflag); + root, comm, coll_group, errflag); break; default: MPIR_Assert(0); @@ -1323,32 +1389,35 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Reduce_allcomm_composition_json(const void *s MPL_STATIC_INLINE_PREFIX int MPID_Reduce(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - int root, MPIR_Comm * comm, MPIR_Errflag_t errflag) + int root, MPIR_Comm * comm, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + switch (MPIR_CVAR_REDUCE_COMPOSITION) { case 1: MPII_COLLECTIVE_FALLBACK_CHECK(comm->rank, comm->comm_kind == MPIR_COMM_KIND__INTRACOMM - && comm->hierarchy_kind == - MPIR_COMM_HIERARCHY_KIND__PARENT && + && MPIR_Comm_is_parent_comm(comm, coll_group) && MPIR_Op_is_commutative(op), mpi_errno, "Reduce composition alpha cannot be applied.\n"); mpi_errno = MPIDI_Reduce_intra_composition_alpha(sendbuf, recvbuf, count, datatype, op, root, - comm, errflag); + comm, coll_group, errflag); break; case 2: MPII_COLLECTIVE_FALLBACK_CHECK(comm->rank, comm->comm_kind == MPIR_COMM_KIND__INTRACOMM - && comm->hierarchy_kind == - MPIR_COMM_HIERARCHY_KIND__PARENT && + && MPIR_Comm_is_parent_comm(comm, coll_group) && MPIR_Op_is_commutative(op), mpi_errno, "Reduce composition beta cannot be applied.\n"); mpi_errno = MPIDI_Reduce_intra_composition_beta(sendbuf, recvbuf, count, datatype, op, root, - comm, errflag); + comm, coll_group, errflag); break; case 3: MPII_COLLECTIVE_FALLBACK_CHECK(comm->rank, @@ -1356,12 +1425,12 @@ MPL_STATIC_INLINE_PREFIX int MPID_Reduce(const void *sendbuf, void *recvbuf, "Reduce composition gamma cannot be applied.\n"); mpi_errno = MPIDI_Reduce_intra_composition_gamma(sendbuf, recvbuf, count, datatype, op, root, - comm, errflag); + comm, coll_group, errflag); break; default: mpi_errno = MPIDI_Reduce_allcomm_composition_json(sendbuf, recvbuf, count, datatype, op, - root, comm, errflag); + root, comm, coll_group, errflag); break; } @@ -1370,11 +1439,13 @@ MPL_STATIC_INLINE_PREFIX int MPID_Reduce(const void *sendbuf, void *recvbuf, fallback: if (comm->comm_kind == MPIR_COMM_KIND__INTERCOMM) - mpi_errno = MPIR_Reduce_impl(sendbuf, recvbuf, count, datatype, op, root, comm, errflag); + mpi_errno = + MPIR_Reduce_impl(sendbuf, recvbuf, count, datatype, op, root, comm, coll_group, + errflag); else mpi_errno = MPIDI_Reduce_intra_composition_gamma(sendbuf, recvbuf, count, datatype, op, root, comm, - errflag); + coll_group, errflag); fn_exit: MPIR_FUNC_EXIT; @@ -1385,15 +1456,20 @@ MPL_STATIC_INLINE_PREFIX int MPID_Reduce(const void *sendbuf, void *recvbuf, MPL_STATIC_INLINE_PREFIX int MPID_Reduce_scatter(const void *sendbuf, void *recvbuf, const MPI_Aint recvcounts[], MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm, + MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; const MPIDI_Csel_container_s *cnt = NULL; + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + MPIR_Csel_coll_sig_s coll_sig = { .coll_type = MPIR_CSEL_COLL_TYPE__REDUCE_SCATTER, .comm_ptr = comm, + .coll_group = coll_group, .u.reduce_scatter.sendbuf = sendbuf, .u.reduce_scatter.recvbuf = recvbuf, @@ -1407,23 +1483,26 @@ MPL_STATIC_INLINE_PREFIX int MPID_Reduce_scatter(const void *sendbuf, void *recv cnt = MPIR_Csel_search(MPIDI_COMM(comm, csel_comm), coll_sig); if (cnt == NULL) { - mpi_errno = - MPIR_Reduce_scatter_impl(sendbuf, recvbuf, recvcounts, datatype, op, comm, errflag); - MPIR_ERR_CHECK(mpi_errno); - goto fn_exit; + goto fallback; } switch (cnt->id) { case MPIDI_CSEL_CONTAINER_TYPE__COMPOSITION__MPIDI_Reduce_scatter_intra_composition_alpha: mpi_errno = MPIDI_Reduce_scatter_intra_composition_alpha(sendbuf, recvbuf, recvcounts, - datatype, op, comm, errflag); + datatype, op, comm, coll_group, + errflag); break; default: MPIR_Assert(0); } MPIR_ERR_CHECK(mpi_errno); + goto fn_exit; + + fallback: + mpi_errno = MPIR_Reduce_scatter_impl(sendbuf, recvbuf, recvcounts, datatype, op, + comm, coll_group, errflag); fn_exit: MPIR_FUNC_EXIT; @@ -1434,15 +1513,20 @@ MPL_STATIC_INLINE_PREFIX int MPID_Reduce_scatter(const void *sendbuf, void *recv MPL_STATIC_INLINE_PREFIX int MPID_Reduce_scatter_block(const void *sendbuf, void *recvbuf, MPI_Aint recvcount, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm, + MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; const MPIDI_Csel_container_s *cnt = NULL; + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + MPIR_Csel_coll_sig_s coll_sig = { .coll_type = MPIR_CSEL_COLL_TYPE__REDUCE_SCATTER_BLOCK, .comm_ptr = comm, + .coll_group = coll_group, .u.reduce_scatter_block.sendbuf = sendbuf, .u.reduce_scatter_block.recvbuf = recvbuf, @@ -1456,24 +1540,26 @@ MPL_STATIC_INLINE_PREFIX int MPID_Reduce_scatter_block(const void *sendbuf, void cnt = MPIR_Csel_search(MPIDI_COMM(comm, csel_comm), coll_sig); if (cnt == NULL) { - mpi_errno = - MPIR_Reduce_scatter_block_impl(sendbuf, recvbuf, recvcount, datatype, op, comm, - errflag); - MPIR_ERR_CHECK(mpi_errno); - goto fn_exit; + goto fallback; } switch (cnt->id) { case MPIDI_CSEL_CONTAINER_TYPE__COMPOSITION__MPIDI_Reduce_scatter_block_intra_composition_alpha: mpi_errno = MPIDI_Reduce_scatter_block_intra_composition_alpha(sendbuf, recvbuf, recvcount, - datatype, op, comm, errflag); + datatype, op, comm, coll_group, + errflag); break; default: MPIR_Assert(0); } MPIR_ERR_CHECK(mpi_errno); + goto fn_exit; + + fallback: + mpi_errno = MPIR_Reduce_scatter_block_impl(sendbuf, recvbuf, recvcount, datatype, op, + comm, coll_group, errflag); fn_exit: MPIR_FUNC_EXIT; @@ -1484,14 +1570,19 @@ MPL_STATIC_INLINE_PREFIX int MPID_Reduce_scatter_block(const void *sendbuf, void MPL_STATIC_INLINE_PREFIX int MPID_Scan(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, - MPIR_Errflag_t errflag) + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; const MPIDI_Csel_container_s *cnt = NULL; + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + MPIR_Csel_coll_sig_s coll_sig = { .coll_type = MPIR_CSEL_COLL_TYPE__SCAN, .comm_ptr = comm, + .coll_group = coll_group, .u.scan.sendbuf = sendbuf, .u.scan.recvbuf = recvbuf, @@ -1505,27 +1596,29 @@ MPL_STATIC_INLINE_PREFIX int MPID_Scan(const void *sendbuf, void *recvbuf, MPI_A cnt = MPIR_Csel_search(MPIDI_COMM(comm, csel_comm), coll_sig); if (cnt == NULL) { - mpi_errno = MPIR_Scan_impl(sendbuf, recvbuf, count, datatype, op, comm, errflag); - MPIR_ERR_CHECK(mpi_errno); - goto fn_exit; + goto fallback; } switch (cnt->id) { case MPIDI_CSEL_CONTAINER_TYPE__COMPOSITION__MPIDI_Scan_intra_composition_alpha: mpi_errno = MPIDI_Scan_intra_composition_alpha(sendbuf, recvbuf, count, - datatype, op, comm, errflag); + datatype, op, comm, coll_group, errflag); break; case MPIDI_CSEL_CONTAINER_TYPE__COMPOSITION__MPIDI_Scan_intra_composition_beta: mpi_errno = MPIDI_Scan_intra_composition_beta(sendbuf, recvbuf, count, - datatype, op, comm, errflag); + datatype, op, comm, coll_group, errflag); break; default: MPIR_Assert(0); } MPIR_ERR_CHECK(mpi_errno); + goto fn_exit; + + fallback: + mpi_errno = MPIR_Scan_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, errflag); fn_exit: MPIR_FUNC_EXIT; @@ -1536,14 +1629,19 @@ MPL_STATIC_INLINE_PREFIX int MPID_Scan(const void *sendbuf, void *recvbuf, MPI_A MPL_STATIC_INLINE_PREFIX int MPID_Exscan(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, - MPIR_Errflag_t errflag) + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; const MPIDI_Csel_container_s *cnt = NULL; + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + MPIR_Csel_coll_sig_s coll_sig = { .coll_type = MPIR_CSEL_COLL_TYPE__EXSCAN, .comm_ptr = comm, + .coll_group = coll_group, .u.exscan.sendbuf = sendbuf, .u.exscan.recvbuf = recvbuf, @@ -1557,22 +1655,24 @@ MPL_STATIC_INLINE_PREFIX int MPID_Exscan(const void *sendbuf, void *recvbuf, MPI cnt = MPIR_Csel_search(MPIDI_COMM(comm, csel_comm), coll_sig); if (cnt == NULL) { - mpi_errno = MPIR_Exscan_impl(sendbuf, recvbuf, count, datatype, op, comm, errflag);; - MPIR_ERR_CHECK(mpi_errno); - goto fn_exit; + goto fallback; } switch (cnt->id) { case MPIDI_CSEL_CONTAINER_TYPE__COMPOSITION__MPIDI_Exscan_intra_composition_alpha: mpi_errno = MPIDI_Exscan_intra_composition_alpha(sendbuf, recvbuf, count, - datatype, op, comm, errflag); + datatype, op, comm, coll_group, errflag); break; default: MPIR_Assert(0); } MPIR_ERR_CHECK(mpi_errno); + goto fn_exit; + + fallback: + mpi_errno = MPIR_Exscan_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, errflag); fn_exit: MPIR_FUNC_EXIT; @@ -1761,109 +1861,163 @@ MPL_STATIC_INLINE_PREFIX int MPID_Ineighbor_alltoallw(const void *sendbuf, return ret; } -MPL_STATIC_INLINE_PREFIX int MPID_Ibarrier(MPIR_Comm * comm, MPIR_Request ** req) +MPL_STATIC_INLINE_PREFIX int MPID_Ibarrier(MPIR_Comm * comm, int coll_group, MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; - ret = MPIDI_NM_mpi_ibarrier(comm, req); + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + + ret = MPIDI_NM_mpi_ibarrier(comm, coll_group, req); MPIR_FUNC_EXIT; return ret; + + fallback: + return MPIR_Ibarrier_impl(comm, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPID_Ibcast(void *buffer, MPI_Aint count, MPI_Datatype datatype, - int root, MPIR_Comm * comm, MPIR_Request ** req) + int root, MPIR_Comm * comm, int coll_group, + MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; - ret = MPIDI_NM_mpi_ibcast(buffer, count, datatype, root, comm, req); + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + + ret = MPIDI_NM_mpi_ibcast(buffer, count, datatype, root, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; + + fallback: + return MPIR_Ibcast_impl(buffer, count, datatype, root, comm, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPID_Iallgather(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm, MPIR_Request ** req) + MPIR_Comm * comm, int coll_group, MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + ret = MPIDI_NM_mpi_iallgather(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm, req); + recvcount, recvtype, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; + + fallback: + return MPIR_Iallgather_impl(sendbuf, sendcount, sendtype, recvbuf, + recvcount, recvtype, comm, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPID_Iallgatherv(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, MPIR_Comm * comm, - MPIR_Request ** req) + int coll_group, MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + ret = MPIDI_NM_mpi_iallgatherv(sendbuf, sendcount, sendtype, recvbuf, - recvcounts, displs, recvtype, comm, req); + recvcounts, displs, recvtype, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; + + fallback: + return MPIR_Iallgatherv_impl(sendbuf, sendcount, sendtype, recvbuf, + recvcounts, displs, recvtype, comm, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPID_Iallreduce(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, - MPIR_Request ** req) + int coll_group, MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; - ret = MPIDI_NM_mpi_iallreduce(sendbuf, recvbuf, count, datatype, op, comm, req); + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + + ret = MPIDI_NM_mpi_iallreduce(sendbuf, recvbuf, count, datatype, op, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; + + fallback: + return MPIR_Iallreduce_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPID_Ialltoall(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm, MPIR_Request ** req) + MPIR_Comm * comm, int coll_group, MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + ret = MPIDI_NM_mpi_ialltoall(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm, req); + recvcount, recvtype, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; + + fallback: + return MPIR_Ialltoall_impl(sendbuf, sendcount, sendtype, recvbuf, + recvcount, recvtype, comm, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPID_Ialltoallv(const void *sendbuf, const MPI_Aint * sendcounts, const MPI_Aint * sdispls, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * rdispls, MPI_Datatype recvtype, - MPIR_Comm * comm, MPIR_Request ** req) + MPIR_Comm * comm, int coll_group, MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + ret = MPIDI_NM_mpi_ialltoallv(sendbuf, sendcounts, sdispls, sendtype, - recvbuf, recvcounts, rdispls, recvtype, comm, req); + recvbuf, recvcounts, rdispls, recvtype, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; + + fallback: + return MPIR_Ialltoallv_impl(sendbuf, sendcounts, sdispls, sendtype, + recvbuf, recvcounts, rdispls, recvtype, comm, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPID_Ialltoallw(const void *sendbuf, const MPI_Aint * sendcounts, @@ -1871,155 +2025,237 @@ MPL_STATIC_INLINE_PREFIX int MPID_Ialltoallw(const void *sendbuf, const MPI_Aint const MPI_Datatype * sendtypes, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * rdispls, const MPI_Datatype * recvtypes, MPIR_Comm * comm, - MPIR_Request ** req) + int coll_group, MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + ret = MPIDI_NM_mpi_ialltoallw(sendbuf, sendcounts, sdispls, sendtypes, - recvbuf, recvcounts, rdispls, recvtypes, comm, req); + recvbuf, recvcounts, rdispls, recvtypes, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; + + fallback: + return MPIR_Ialltoallw(sendbuf, sendcounts, sdispls, sendtypes, + recvbuf, recvcounts, rdispls, recvtypes, comm, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPID_Iexscan(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, - MPIR_Request ** req) + int coll_group, MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; - ret = MPIDI_NM_mpi_iexscan(sendbuf, recvbuf, count, datatype, op, comm, req); + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + + ret = MPIDI_NM_mpi_iexscan(sendbuf, recvbuf, count, datatype, op, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; + + fallback: + return MPIR_Iexscan_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPID_Igather(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, MPIR_Comm * comm, - MPIR_Request ** req) + int coll_group, MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + ret = MPIDI_NM_mpi_igather(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, root, comm, req); + recvcount, recvtype, root, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; + + fallback: + return MPIR_Igather_impl(sendbuf, sendcount, sendtype, recvbuf, + recvcount, recvtype, root, comm, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPID_Igatherv(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, int root, MPIR_Comm * comm, - MPIR_Request ** req) + int coll_group, MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + ret = MPIDI_NM_mpi_igatherv(sendbuf, sendcount, sendtype, recvbuf, - recvcounts, displs, recvtype, root, comm, req); + recvcounts, displs, recvtype, root, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; + + fallback: + return MPIR_Igatherv(sendbuf, sendcount, sendtype, recvbuf, + recvcounts, displs, recvtype, root, comm, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPID_Ireduce_scatter_block(const void *sendbuf, void *recvbuf, MPI_Aint recvcount, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm, + MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; - ret = MPIDI_NM_mpi_ireduce_scatter_block(sendbuf, recvbuf, recvcount, datatype, op, comm, req); + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + + ret = + MPIDI_NM_mpi_ireduce_scatter_block(sendbuf, recvbuf, recvcount, datatype, op, comm, + coll_group, req); MPIR_FUNC_EXIT; return ret; + + fallback: + return MPIR_Ireduce_scatter_block_impl(sendbuf, recvbuf, recvcount, datatype, op, + comm, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPID_Ireduce_scatter(const void *sendbuf, void *recvbuf, const MPI_Aint * recvcounts, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm, MPIR_Request ** req) + MPIR_Comm * comm, int coll_group, + MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; - ret = MPIDI_NM_mpi_ireduce_scatter(sendbuf, recvbuf, recvcounts, datatype, op, comm, req); + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + + ret = + MPIDI_NM_mpi_ireduce_scatter(sendbuf, recvbuf, recvcounts, datatype, op, comm, coll_group, + req); MPIR_FUNC_EXIT; return ret; + + fallback: + return MPIR_Ireduce_scatter_impl(sendbuf, recvbuf, recvcounts, datatype, op, + comm, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPID_Ireduce(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, int root, - MPIR_Comm * comm, MPIR_Request ** req) + MPIR_Comm * comm, int coll_group, MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; - ret = MPIDI_NM_mpi_ireduce(sendbuf, recvbuf, count, datatype, op, root, comm, req); + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + + ret = MPIDI_NM_mpi_ireduce(sendbuf, recvbuf, count, datatype, op, root, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; + + fallback: + return MPIR_Ireduce_impl(sendbuf, recvbuf, count, datatype, op, root, comm, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPID_Iscan(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, - MPIR_Request ** req) + int coll_group, MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; - ret = MPIDI_NM_mpi_iscan(sendbuf, recvbuf, count, datatype, op, comm, req); + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + + ret = MPIDI_NM_mpi_iscan(sendbuf, recvbuf, count, datatype, op, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; + + fallback: + return MPIR_Iscan_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPID_Iscatter(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, MPIR_Comm * comm, - MPIR_Request ** req) + int coll_group, MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + ret = MPIDI_NM_mpi_iscatter(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, root, comm, req); + recvcount, recvtype, root, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; + + fallback: + return MPIR_Iscatter_impl(sendbuf, sendcount, sendtype, recvbuf, + recvcount, recvtype, root, comm, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPID_Iscatterv(const void *sendbuf, const MPI_Aint * sendcounts, const MPI_Aint * displs, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, MPIR_Comm * comm, - MPIR_Request ** req) + int coll_group, MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + ret = MPIDI_NM_mpi_iscatterv(sendbuf, sendcounts, displs, sendtype, - recvbuf, recvcount, recvtype, root, comm, req); + recvbuf, recvcount, recvtype, root, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; + + fallback: + return MPIR_Iscatterv_impl(sendbuf, sendcounts, displs, sendtype, + recvbuf, recvcount, recvtype, root, comm, coll_group, req); } #endif /* CH4_COLL_H_INCLUDED */ diff --git a/src/mpid/ch4/src/ch4_coll_impl.h b/src/mpid/ch4/src/ch4_coll_impl.h index ca626076491..5b9e6bf08b5 100644 --- a/src/mpid/ch4/src/ch4_coll_impl.h +++ b/src/mpid/ch4/src/ch4_coll_impl.h @@ -162,7 +162,7 @@ static void MPIDI_Coll_calculate_size_shift(MPI_Aint count, MPI_Datatype datatyp } } -MPL_STATIC_INLINE_PREFIX int MPIDI_Barrier_intra_composition_alpha(MPIR_Comm * comm, +MPL_STATIC_INLINE_PREFIX int MPIDI_Barrier_intra_composition_alpha(MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -170,17 +170,17 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Barrier_intra_composition_alpha(MPIR_Comm * c /* do the intranode barrier on all nodes */ if (comm->node_comm != NULL) { #ifndef MPIDI_CH4_DIRECT_NETMOD - mpi_errno = MPIDI_SHM_mpi_barrier(comm->node_comm, errflag); + mpi_errno = MPIDI_SHM_mpi_barrier(comm->node_comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); #else - mpi_errno = MPIDI_NM_mpi_barrier(comm->node_comm, errflag); + mpi_errno = MPIDI_NM_mpi_barrier(comm->node_comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); #endif /* MPIDI_CH4_DIRECT_NETMOD */ } /* do the barrier across roots of all nodes */ if (comm->node_roots_comm != NULL) { - mpi_errno = MPIDI_NM_mpi_barrier(comm->node_roots_comm, errflag); + mpi_errno = MPIDI_NM_mpi_barrier(comm->node_roots_comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } @@ -190,10 +190,10 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Barrier_intra_composition_alpha(MPIR_Comm * c if (comm->node_comm != NULL) { int i = 0; #ifndef MPIDI_CH4_DIRECT_NETMOD - mpi_errno = MPIDI_SHM_mpi_bcast(&i, 1, MPI_BYTE, 0, comm->node_comm, errflag); + mpi_errno = MPIDI_SHM_mpi_bcast(&i, 1, MPI_BYTE, 0, comm->node_comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); #else - mpi_errno = MPIDI_NM_mpi_bcast(&i, 1, MPI_BYTE, 0, comm->node_comm, errflag); + mpi_errno = MPIDI_NM_mpi_bcast(&i, 1, MPI_BYTE, 0, comm->node_comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); #endif /* MPIDI_CH4_DIRECT_NETMOD */ } @@ -204,12 +204,12 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Barrier_intra_composition_alpha(MPIR_Comm * c goto fn_exit; } -MPL_STATIC_INLINE_PREFIX int MPIDI_Barrier_intra_composition_beta(MPIR_Comm * comm, +MPL_STATIC_INLINE_PREFIX int MPIDI_Barrier_intra_composition_beta(MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; - mpi_errno = MPIDI_NM_mpi_barrier(comm, errflag); + mpi_errno = MPIDI_NM_mpi_barrier(comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -221,6 +221,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Barrier_intra_composition_beta(MPIR_Comm * co MPL_STATIC_INLINE_PREFIX int MPIDI_Bcast_intra_composition_alpha(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, MPIR_Comm * comm, + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -240,7 +241,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Bcast_intra_composition_alpha(void *buffer, M /* root sends message to local leader (node_comm rank 0) */ if (comm->rank == root) { mpi_errno = MPIC_Send(buffer, count, datatype, 0, MPIR_BCAST_TAG, - comm->node_comm, errflag); + comm->node_comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } /* local leader receives message from root */ @@ -248,12 +249,12 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Bcast_intra_composition_alpha(void *buffer, M #ifndef HAVE_ERROR_CHECKING mpi_errno = MPIC_Recv(buffer, count, datatype, intra_root, MPIR_BCAST_TAG, comm->node_comm, - MPI_STATUS_IGNORE); + coll_group, MPI_STATUS_IGNORE); MPIR_ERR_CHECK(mpi_errno); #else mpi_errno = MPIC_Recv(buffer, count, datatype, intra_root, MPIR_BCAST_TAG, comm->node_comm, - &status); + coll_group, &status); MPIR_ERR_CHECK(mpi_errno); MPIR_Datatype_get_size_macro(datatype, type_size); @@ -281,17 +282,18 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Bcast_intra_composition_alpha(void *buffer, M } if (comm->node_roots_comm != NULL) { - mpi_errno = - MPIDI_NM_mpi_bcast(buffer, count, datatype, MPIR_Get_internode_rank(comm, root), - comm->node_roots_comm, errflag); + mpi_errno = MPIDI_NM_mpi_bcast(buffer, count, datatype, MPIR_Get_internode_rank(comm, root), + comm->node_roots_comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } if (comm->node_comm != NULL) { #ifndef MPIDI_CH4_DIRECT_NETMOD - mpi_errno = MPIDI_SHM_mpi_bcast(buffer, count, datatype, 0, comm->node_comm, errflag); + mpi_errno = MPIDI_SHM_mpi_bcast(buffer, count, datatype, 0, comm->node_comm, + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); #else - mpi_errno = MPIDI_NM_mpi_bcast(buffer, count, datatype, 0, comm->node_comm, errflag); + mpi_errno = MPIDI_NM_mpi_bcast(buffer, count, datatype, 0, comm->node_comm, + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); #endif /* MPIDI_CH4_DIRECT_NETMOD */ } @@ -315,6 +317,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Bcast_intra_composition_alpha(void *buffer, M MPL_STATIC_INLINE_PREFIX int MPIDI_Bcast_intra_composition_beta(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, MPIR_Comm * comm, + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -340,27 +343,29 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Bcast_intra_composition_beta(void *buffer, MP #ifndef MPIDI_CH4_DIRECT_NETMOD mpi_errno = MPIDI_SHM_mpi_bcast(buffer, count, datatype, MPIR_Get_intranode_rank(comm, root), - comm->node_comm, errflag); + comm->node_comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); #else mpi_errno = MPIDI_NM_mpi_bcast(buffer, count, datatype, MPIR_Get_intranode_rank(comm, root), - comm->node_comm, errflag); + comm->node_comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); #endif /* MPIDI_CH4_DIRECT_NETMOD */ } if (comm->node_roots_comm != NULL) { mpi_errno = MPIDI_NM_mpi_bcast(buffer, count, datatype, MPIR_Get_internode_rank(comm, root), - comm->node_roots_comm, errflag); + comm->node_roots_comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } if (comm->node_comm != NULL && MPIR_Get_intranode_rank(comm, root) <= 0) { #ifndef MPIDI_CH4_DIRECT_NETMOD - mpi_errno = MPIDI_SHM_mpi_bcast(buffer, count, datatype, 0, comm->node_comm, errflag); + mpi_errno = MPIDI_SHM_mpi_bcast(buffer, count, datatype, 0, comm->node_comm, + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); #else - mpi_errno = MPIDI_NM_mpi_bcast(buffer, count, datatype, 0, comm->node_comm, errflag); + mpi_errno = MPIDI_NM_mpi_bcast(buffer, count, datatype, 0, comm->node_comm, + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); #endif /* MPIDI_CH4_DIRECT_NETMOD */ } @@ -384,6 +389,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Bcast_intra_composition_beta(void *buffer, MP MPL_STATIC_INLINE_PREFIX int MPIDI_Bcast_intra_composition_gamma(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, MPIR_Comm * comm, + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -405,7 +411,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Bcast_intra_composition_gamma(void *buffer, M } } - mpi_errno = MPIDI_NM_mpi_bcast(buffer, count, datatype, root, comm, errflag); + mpi_errno = MPIDI_NM_mpi_bcast(buffer, count, datatype, root, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); if (host_buffer != NULL && comm->rank != root) { @@ -435,6 +441,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Bcast_intra_composition_gamma(void *buffer, M MPL_STATIC_INLINE_PREFIX int MPIDI_Bcast_intra_composition_delta(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, MPIR_Comm * comm, + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -454,7 +461,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Bcast_intra_composition_delta(void *buffer, M /* root sends message to local leader (node_comm rank 0) */ if (comm->rank == root) { mpi_errno = MPIC_Send(buffer, count, datatype, 0, MPIR_BCAST_TAG, - comm->node_comm, errflag); + comm->node_comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } /* local leader receives message from root */ @@ -462,12 +469,12 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Bcast_intra_composition_delta(void *buffer, M #ifndef HAVE_ERROR_CHECKING mpi_errno = MPIC_Recv(buffer, count, datatype, intra_root, MPIR_BCAST_TAG, comm->node_comm, - MPI_STATUS_IGNORE); + coll_group, MPI_STATUS_IGNORE); MPIR_ERR_CHECK(mpi_errno); #else mpi_errno = MPIC_Recv(buffer, count, datatype, intra_root, MPIR_BCAST_TAG, comm->node_comm, - &status); + coll_group, &status); MPIR_ERR_CHECK(mpi_errno); MPIR_Datatype_get_size_macro(datatype, type_size); @@ -499,7 +506,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Bcast_intra_composition_delta(void *buffer, M if (comm->node_roots_comm != NULL) { mpi_errno = MPIDI_NM_mpi_bcast(buffer, count, datatype, MPIR_Get_internode_rank(comm, root), - comm->node_roots_comm, errflag); + comm->node_roots_comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); /* Node leaders copy data to GPU */ @@ -517,10 +524,12 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Bcast_intra_composition_delta(void *buffer, M /* intra-node Bcast */ if (comm->node_comm != NULL) { #ifndef MPIDI_CH4_DIRECT_NETMOD - mpi_errno = MPIDI_SHM_mpi_bcast(buffer, count, datatype, 0, comm->node_comm, errflag); + mpi_errno = MPIDI_SHM_mpi_bcast(buffer, count, datatype, 0, comm->node_comm, + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); #else - mpi_errno = MPIDI_NM_mpi_bcast(buffer, count, datatype, 0, comm->node_comm, errflag); + mpi_errno = MPIDI_NM_mpi_bcast(buffer, count, datatype, 0, comm->node_comm, + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); #endif /* MPIDI_CH4_DIRECT_NETMOD */ } @@ -536,6 +545,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allreduce_intra_composition_alpha(const void MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -566,24 +576,24 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allreduce_intra_composition_alpha(const void #ifndef MPIDI_CH4_DIRECT_NETMOD mpi_errno = MPIDI_SHM_mpi_reduce(recvbuf, NULL, count, datatype, op, 0, comm->node_comm, - errflag); + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); #else mpi_errno = MPIDI_NM_mpi_reduce(recvbuf, NULL, count, datatype, op, 0, comm->node_comm, - errflag); + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); #endif /* MPIDI_CH4_DIRECT_NETMOD */ } else { #ifndef MPIDI_CH4_DIRECT_NETMOD mpi_errno = MPIDI_SHM_mpi_reduce(sendbuf, recvbuf, count, datatype, op, 0, comm->node_comm, - errflag); + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); #else mpi_errno = MPIDI_NM_mpi_reduce(sendbuf, recvbuf, count, datatype, op, 0, comm->node_comm, - errflag); + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); #endif /* MPIDI_CH4_DIRECT_NETMOD */ } @@ -597,16 +607,18 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allreduce_intra_composition_alpha(const void if (comm->node_roots_comm != NULL) { mpi_errno = MPIDI_NM_mpi_allreduce(MPI_IN_PLACE, recvbuf, count, datatype, op, - comm->node_roots_comm, errflag); + comm->node_roots_comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } if (comm->node_comm != NULL) { #ifndef MPIDI_CH4_DIRECT_NETMOD - mpi_errno = MPIDI_SHM_mpi_bcast(recvbuf, count, datatype, 0, comm->node_comm, errflag); + mpi_errno = MPIDI_SHM_mpi_bcast(recvbuf, count, datatype, 0, comm->node_comm, + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); #else - mpi_errno = MPIDI_NM_mpi_bcast(recvbuf, count, datatype, 0, comm->node_comm, errflag); + mpi_errno = MPIDI_NM_mpi_bcast(recvbuf, count, datatype, 0, comm->node_comm, + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); #endif } @@ -632,6 +644,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allreduce_intra_composition_beta(const void * MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -657,7 +670,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allreduce_intra_composition_beta(const void * recvbuf = host_recvbuf; } - mpi_errno = MPIDI_NM_mpi_allreduce(sendbuf, recvbuf, count, datatype, op, comm, errflag); + mpi_errno = + MPIDI_NM_mpi_allreduce(sendbuf, recvbuf, count, datatype, op, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); if (host_recvbuf != NULL) { @@ -681,6 +695,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allreduce_intra_composition_gamma(const void MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -706,9 +721,11 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allreduce_intra_composition_gamma(const void recvbuf = host_recvbuf; } #ifndef MPIDI_CH4_DIRECT_NETMOD - mpi_errno = MPIDI_SHM_mpi_allreduce(sendbuf, recvbuf, count, datatype, op, comm, errflag); + mpi_errno = + MPIDI_SHM_mpi_allreduce(sendbuf, recvbuf, count, datatype, op, comm, coll_group, errflag); #else - mpi_errno = MPIDI_NM_mpi_allreduce(sendbuf, recvbuf, count, datatype, op, comm, errflag); + mpi_errno = + MPIDI_NM_mpi_allreduce(sendbuf, recvbuf, count, datatype, op, comm, coll_group, errflag); #endif MPIR_ERR_CHECK(mpi_errno); @@ -741,6 +758,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allreduce_intra_composition_delta(const void MPI_Op op, int num_leads, MPIR_Comm * comm_ptr, + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -817,9 +835,9 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allreduce_intra_composition_delta(const void /* Step 0: Barrier to make sure the shm_buffer can be reused after the previous call */ #ifndef MPIDI_CH4_DIRECT_NETMOD - mpi_errno = MPIDI_SHM_mpi_barrier(comm_ptr->node_comm, errflag); + mpi_errno = MPIDI_SHM_mpi_barrier(comm_ptr->node_comm, coll_group, errflag); #else - mpi_errno = MPIDI_NM_mpi_barrier(comm_ptr->node_comm, errflag); + mpi_errno = MPIDI_NM_mpi_barrier(comm_ptr->node_comm, coll_group, errflag); #endif MPIR_ERR_CHECK(mpi_errno); @@ -830,12 +848,13 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allreduce_intra_composition_delta(const void MPIDI_SHM_mpi_reduce((char *) sendbuf + offset * extent, (char *) shm_addr + my_leader_rank * shm_size_per_lead, chunk_count, datatype, op, 0, MPIDI_COMM(comm_ptr, sub_node_comm), - errflag); + MPIR_SUBGROUP_NONE, errflag); #else mpi_errno = MPIDI_NM_mpi_reduce((char *) sendbuf + offset * extent, (char *) shm_addr + my_leader_rank * shm_size_per_lead, chunk_count, - datatype, op, 0, MPIDI_COMM(comm_ptr, sub_node_comm), errflag); + datatype, op, 0, MPIDI_COMM(comm_ptr, sub_node_comm), + MPIR_SUBGROUP_NONE, errflag); #endif MPIR_ERR_CHECK(mpi_errno); @@ -843,9 +862,11 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allreduce_intra_composition_delta(const void * buffers. */ if (MPIDI_COMM(comm_ptr, intra_node_leads_comm) != NULL) { #ifndef MPIDI_CH4_DIRECT_NETMOD - mpi_errno = MPIDI_SHM_mpi_barrier(MPIDI_COMM(comm_ptr, intra_node_leads_comm), errflag); + mpi_errno = MPIDI_SHM_mpi_barrier(MPIDI_COMM(comm_ptr, intra_node_leads_comm), + MPIR_SUBGROUP_NONE, errflag); #else - mpi_errno = MPIDI_NM_mpi_barrier(MPIDI_COMM(comm_ptr, intra_node_leads_comm), errflag); + mpi_errno = MPIDI_NM_mpi_barrier(MPIDI_COMM(comm_ptr, intra_node_leads_comm), + MPIR_SUBGROUP_NONE, errflag); #endif MPIR_ERR_CHECK(mpi_errno); } @@ -892,16 +913,16 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allreduce_intra_composition_delta(const void extent), per_leader_count, datatype, op, MPIDI_COMM(comm_ptr, inter_node_leads_comm), - errflag); + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } /* Step 5: Barrier to make sure non-leaders wait for leaders to finish reducing the data * from other nodes */ #ifndef MPIDI_CH4_DIRECT_NETMOD - mpi_errno = MPIDI_SHM_mpi_barrier(comm_ptr->node_comm, errflag); + mpi_errno = MPIDI_SHM_mpi_barrier(comm_ptr->node_comm, coll_group, errflag); #else - mpi_errno = MPIDI_NM_mpi_barrier(comm_ptr->node_comm, errflag); + mpi_errno = MPIDI_NM_mpi_barrier(comm_ptr->node_comm, coll_group, errflag); #endif MPIR_ERR_CHECK(mpi_errno); @@ -941,7 +962,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Reduce_intra_composition_alpha(const void *se void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, int root, - MPIR_Comm * comm, + MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -978,11 +999,11 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Reduce_intra_composition_alpha(const void *se #ifndef MPIDI_CH4_DIRECT_NETMOD mpi_errno = MPIDI_SHM_mpi_reduce(intra_sendbuf, recvbuf, count, datatype, op, 0, comm->node_comm, - errflag); + coll_group, errflag); #else mpi_errno = MPIDI_NM_mpi_reduce(intra_sendbuf, recvbuf, count, datatype, op, 0, comm->node_comm, - errflag); + coll_group, errflag); #endif /* MPIDI_CH4_DIRECT_NETMOD */ MPIR_ERR_CHECK(mpi_errno); @@ -1002,16 +1023,17 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Reduce_intra_composition_alpha(const void *se } mpi_errno = MPIDI_NM_mpi_reduce(inter_sendbuf, recvbuf, count, datatype, op, 0, - comm->node_roots_comm, errflag); + comm->node_roots_comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } /* Send data to root via point-to-point message if root is not rank 0 in comm */ if (root != 0) { if (comm->rank == 0) { - MPIC_Send(recvbuf, count, datatype, root, MPIR_REDUCE_TAG, comm, errflag); + MPIC_Send(recvbuf, count, datatype, root, MPIR_REDUCE_TAG, comm, coll_group, errflag); } else if (comm->rank == root) { - MPIC_Recv(ori_recvbuf, count, datatype, 0, MPIR_REDUCE_TAG, comm, MPI_STATUS_IGNORE); + MPIC_Recv(ori_recvbuf, count, datatype, 0, MPIR_REDUCE_TAG, comm, coll_group, + MPI_STATUS_IGNORE); } } @@ -1026,7 +1048,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Reduce_intra_composition_beta(const void *sen void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, int root, - MPIR_Comm * comm, + MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -1055,10 +1077,10 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Reduce_intra_composition_beta(const void *sen #ifndef MPIDI_CH4_DIRECT_NETMOD mpi_errno = MPIDI_SHM_mpi_reduce(sendbuf, tmp_buf, count, datatype, op, 0, comm->node_comm, - errflag); + coll_group, errflag); #else - mpi_errno = - MPIDI_NM_mpi_reduce(sendbuf, tmp_buf, count, datatype, op, 0, comm->node_comm, errflag); + mpi_errno = MPIDI_NM_mpi_reduce(sendbuf, tmp_buf, count, datatype, op, 0, comm->node_comm, + coll_group, errflag); #endif /* MPIDI_CH4_DIRECT_NETMOD */ MPIR_ERR_CHECK(mpi_errno); } @@ -1072,7 +1094,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Reduce_intra_composition_beta(const void *sen mpi_errno = MPIDI_NM_mpi_reduce(buf, NULL, count, datatype, op, MPIR_Get_internode_rank(comm, root), - comm->node_roots_comm, errflag); + comm->node_roots_comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } else { /* I am on root's node. I have not participated in the earlier reduce. */ if (comm->rank != root) { @@ -1081,7 +1103,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Reduce_intra_composition_beta(const void *sen mpi_errno = MPIDI_NM_mpi_reduce(sendbuf, tmp_buf, count, datatype, op, MPIR_Get_internode_rank(comm, root), - comm->node_roots_comm, errflag); + comm->node_roots_comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); @@ -1092,7 +1114,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Reduce_intra_composition_beta(const void *sen mpi_errno = MPIDI_NM_mpi_reduce(sendbuf, recvbuf, count, datatype, op, MPIR_Get_internode_rank(comm, root), - comm->node_roots_comm, errflag); + comm->node_roots_comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); /* set sendbuf to MPI_IN_PLACE to make final intranode reduce easy. */ @@ -1107,11 +1129,13 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Reduce_intra_composition_beta(const void *sen #ifndef MPIDI_CH4_DIRECT_NETMOD mpi_errno = MPIDI_SHM_mpi_reduce(sendbuf, recvbuf, count, datatype, - op, MPIR_Get_intranode_rank(comm, root), comm->node_comm, errflag); + op, MPIR_Get_intranode_rank(comm, root), comm->node_comm, + coll_group, errflag); #else mpi_errno = MPIDI_NM_mpi_reduce(sendbuf, recvbuf, count, datatype, - op, MPIR_Get_intranode_rank(comm, root), comm->node_comm, errflag); + op, MPIR_Get_intranode_rank(comm, root), comm->node_comm, + coll_group, errflag); #endif /* MPIDI_CH4_DIRECT_NETMOD */ MPIR_ERR_CHECK(mpi_errno); } @@ -1129,12 +1153,13 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Reduce_intra_composition_gamma(const void *se void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, int root, - MPIR_Comm * comm, + MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; - mpi_errno = MPIDI_NM_mpi_reduce(sendbuf, recvbuf, count, datatype, op, root, comm, errflag); + mpi_errno = + MPIDI_NM_mpi_reduce(sendbuf, recvbuf, count, datatype, op, root, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -1153,6 +1178,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Alltoall_intra_composition_alpha(const void * int recvcount, MPI_Datatype recvtype, MPIR_Comm * comm_ptr, + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -1196,9 +1222,9 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Alltoall_intra_composition_alpha(const void * /* Barrier to make sure that the shm buffer can be reused after the previous call to Alltoall */ #ifndef MPIDI_CH4_DIRECT_NETMOD - mpi_errno = MPIDI_SHM_mpi_barrier(comm_ptr->node_comm, errflag); + mpi_errno = MPIDI_SHM_mpi_barrier(comm_ptr->node_comm, coll_group, errflag); #else - mpi_errno = MPIDI_NM_mpi_barrier(comm_ptr->node_comm, errflag); + mpi_errno = MPIDI_NM_mpi_barrier(comm_ptr->node_comm, coll_group, errflag); #endif MPIR_ERR_CHECK(mpi_errno); @@ -1226,9 +1252,9 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Alltoall_intra_composition_alpha(const void * /* Barrier to make sure each rank has copied the data to the shm buf */ #ifndef MPIDI_CH4_DIRECT_NETMOD - mpi_errno = MPIDI_SHM_mpi_barrier(comm_ptr->node_comm, errflag); + mpi_errno = MPIDI_SHM_mpi_barrier(comm_ptr->node_comm, coll_group, errflag); #else - mpi_errno = MPIDI_NM_mpi_barrier(comm_ptr->node_comm, errflag); + mpi_errno = MPIDI_NM_mpi_barrier(comm_ptr->node_comm, coll_group, errflag); #endif MPIR_ERR_CHECK(mpi_errno); @@ -1244,7 +1270,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Alltoall_intra_composition_alpha(const void * my_node_comm_rank * num_nodes * node_comm_size * type_size * sendcount), node_comm_size * sendcount, sendtype, recvbuf, sendcount * node_comm_size, sendtype, - MPIDI_COMM(comm_ptr, multi_leads_comm), errflag); + MPIDI_COMM(comm_ptr, multi_leads_comm), + MPIR_SUBGROUP_NONE, errflag); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -1260,6 +1287,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Alltoall_intra_composition_beta(const void *s MPI_Aint recvcount, MPI_Datatype recvtype, MPIR_Comm * comm_ptr, + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -1273,17 +1301,17 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Alltoall_intra_composition_beta(const void *s #ifndef MPIDI_CH4_DIRECT_NETMOD mpi_errno = MPIDI_SHM_mpi_alltoall(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm_ptr, errflag); + recvcount, recvtype, comm_ptr, coll_group, errflag); #else mpi_errno = MPIDI_NM_mpi_alltoall(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm_ptr, errflag); + recvcount, recvtype, comm_ptr, coll_group, errflag); #endif /* MPIDI_CH4_DIRECT_NETMOD */ MPIR_ERR_CHECK(mpi_errno); } else { mpi_errno = MPIDI_NM_mpi_alltoall(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm_ptr, errflag); + recvcount, recvtype, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } @@ -1302,13 +1330,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Alltoallv_intra_composition_alpha(const void const MPI_Aint * rdispls, MPI_Datatype recvtype, MPIR_Comm * comm_ptr, + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIDI_NM_mpi_alltoallv(sendbuf, sendcounts, sdispls, - sendtype, recvbuf, recvcounts, rdispls, recvtype, comm_ptr, errflag); + sendtype, recvbuf, recvcounts, rdispls, recvtype, comm_ptr, + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -1328,6 +1358,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Alltoallw_intra_composition_alpha(const void const MPI_Datatype recvtypes[], MPIR_Comm * comm_ptr, + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -1335,7 +1366,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Alltoallw_intra_composition_alpha(const void mpi_errno = MPIDI_NM_mpi_alltoallw(sendbuf, sendcounts, sdispls, sendtypes, recvbuf, recvcounts, - rdispls, recvtypes, comm_ptr, errflag); + rdispls, recvtypes, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -1351,6 +1382,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allgather_intra_composition_alpha(const void int recvcount, MPI_Datatype recvtype, MPIR_Comm * comm_ptr, + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -1408,9 +1440,9 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allgather_intra_composition_alpha(const void /* Barrier to make sure that the shm buffer can be reused after the previous call to Allgather */ #ifndef MPIDI_CH4_DIRECT_NETMOD - mpi_errno = MPIDI_SHM_mpi_barrier(comm_ptr->node_comm, errflag); + mpi_errno = MPIDI_SHM_mpi_barrier(comm_ptr->node_comm, coll_group, errflag); #else - mpi_errno = MPIDI_NM_mpi_barrier(comm_ptr->node_comm, errflag); + mpi_errno = MPIDI_NM_mpi_barrier(comm_ptr->node_comm, coll_group, errflag); #endif MPIR_ERR_CHECK(mpi_errno); @@ -1424,9 +1456,9 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allgather_intra_composition_alpha(const void /* Barrier to make sure all the ranks in a node_comm copied data to shm buffer */ #ifndef MPIDI_CH4_DIRECT_NETMOD - mpi_errno = MPIDI_SHM_mpi_barrier(comm_ptr->node_comm, errflag); + mpi_errno = MPIDI_SHM_mpi_barrier(comm_ptr->node_comm, coll_group, errflag); #else - mpi_errno = MPIDI_NM_mpi_barrier(comm_ptr->node_comm, errflag); + mpi_errno = MPIDI_NM_mpi_barrier(comm_ptr->node_comm, coll_group, errflag); #endif /* Perform inter-node allgather on the multi leader comms */ @@ -1434,7 +1466,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allgather_intra_composition_alpha(const void MPIDI_NM_mpi_allgather((char *) MPIDI_COMM_ALLGATHER(comm_ptr, shm_addr), sendcount * node_comm_size, sendtype, recvbuf, recvcount * node_comm_size, recvtype, - MPIDI_COMM(comm_ptr, multi_leads_comm), errflag); + MPIDI_COMM(comm_ptr, multi_leads_comm), MPIR_SUBGROUP_NONE, errflag); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -1450,6 +1482,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allgather_intra_composition_beta(const void * MPI_Aint recvcount, MPI_Datatype recvtype, MPIR_Comm * comm_ptr, + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -1463,17 +1496,17 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allgather_intra_composition_beta(const void * #ifndef MPIDI_CH4_DIRECT_NETMOD mpi_errno = MPIDI_SHM_mpi_allgather(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm_ptr, errflag); + recvcount, recvtype, comm_ptr, coll_group, errflag); #else mpi_errno = MPIDI_NM_mpi_allgather(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm_ptr, errflag); + recvcount, recvtype, comm_ptr, coll_group, errflag); #endif /* MPIDI_CH4_DIRECT_NETMOD */ MPIR_ERR_CHECK(mpi_errno); } else { mpi_errno = MPIDI_NM_mpi_allgather(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm_ptr, errflag); + recvcount, recvtype, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } @@ -1491,6 +1524,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allgatherv_intra_composition_alpha(const void const MPI_Aint * displs, MPI_Datatype recvtype, MPIR_Comm * comm_ptr, + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -1504,17 +1538,17 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allgatherv_intra_composition_alpha(const void #ifndef MPIDI_CH4_DIRECT_NETMOD mpi_errno = MPIDI_SHM_mpi_allgatherv(sendbuf, sendcount, sendtype, recvbuf, - recvcounts, displs, recvtype, comm_ptr, errflag); + recvcounts, displs, recvtype, comm_ptr, coll_group, errflag); #else mpi_errno = MPIDI_NM_mpi_allgatherv(sendbuf, sendcount, sendtype, recvbuf, - recvcounts, displs, recvtype, comm_ptr, errflag); + recvcounts, displs, recvtype, comm_ptr, coll_group, errflag); #endif /* MPIDI_CH4_DIRECT_NETMOD */ MPIR_ERR_CHECK(mpi_errno); } else { mpi_errno = MPIDI_NM_mpi_allgatherv(sendbuf, sendcount, sendtype, recvbuf, - recvcounts, displs, recvtype, comm_ptr, errflag); + recvcounts, displs, recvtype, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } @@ -1530,13 +1564,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Gather_intra_composition_alpha(const void *se void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, MPIR_Comm * comm, + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIDI_NM_mpi_gather(sendbuf, sendcount, sendtype, recvbuf, recvcount, - recvtype, root, comm, errflag); + recvtype, root, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -1553,13 +1588,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Gatherv_intra_composition_alpha(const void *s const MPI_Aint * displs, MPI_Datatype recvtype, int root, MPIR_Comm * comm, + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIDI_NM_mpi_gatherv(sendbuf, sendcount, sendtype, recvbuf, recvcounts, - displs, recvtype, root, comm, errflag); + displs, recvtype, root, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -1575,13 +1611,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Scatter_intra_composition_alpha(const void *s MPI_Aint recvcount, MPI_Datatype recvtype, int root, MPIR_Comm * comm, + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIDI_NM_mpi_scatter(sendbuf, sendcount, sendtype, recvbuf, recvcount, - recvtype, root, comm, errflag); + recvtype, root, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -1598,13 +1635,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Scatterv_intra_composition_alpha(const void * MPI_Aint recvcount, MPI_Datatype recvtype, int root, MPIR_Comm * comm, + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIDI_NM_mpi_scatterv(sendbuf, sendcounts, displs, sendtype, recvbuf, - recvcount, recvtype, root, comm, errflag); + recvcount, recvtype, root, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -1620,12 +1658,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Reduce_scatter_intra_composition_alpha(const MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; mpi_errno = - MPIDI_NM_mpi_reduce_scatter(sendbuf, recvbuf, recvcounts, datatype, op, comm_ptr, errflag); + MPIDI_NM_mpi_reduce_scatter(sendbuf, recvbuf, recvcounts, datatype, op, comm_ptr, + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -1643,6 +1683,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Reduce_scatter_block_intra_composition_alpha( MPI_Op op, MPIR_Comm * comm_ptr, + int coll_group, MPIR_Errflag_t errflag) { @@ -1650,7 +1691,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Reduce_scatter_block_intra_composition_alpha( mpi_errno = MPIDI_NM_mpi_reduce_scatter_block(sendbuf, recvbuf, recvcount, datatype, - op, comm_ptr, errflag); + op, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -1665,6 +1706,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Scan_intra_composition_alpha(const void *send MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -1706,12 +1748,12 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Scan_intra_composition_alpha(const void *send * one process, just copy the raw data. */ if (comm_ptr->node_comm != NULL) { #ifndef MPIDI_CH4_DIRECT_NETMOD - mpi_errno = - MPIDI_SHM_mpi_scan(sendbuf, recvbuf, count, datatype, op, comm_ptr->node_comm, errflag); + mpi_errno = MPIDI_SHM_mpi_scan(sendbuf, recvbuf, count, datatype, op, comm_ptr->node_comm, + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); #else - mpi_errno = - MPIDI_NM_mpi_scan(sendbuf, recvbuf, count, datatype, op, comm_ptr->node_comm, errflag); + mpi_errno = MPIDI_NM_mpi_scan(sendbuf, recvbuf, count, datatype, op, comm_ptr->node_comm, + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); #endif /* MPIDI_CH4_DIRECT_NETMOD */ } else if (sendbuf != MPI_IN_PLACE) { @@ -1725,13 +1767,13 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Scan_intra_composition_alpha(const void *send if (comm_ptr->node_roots_comm != NULL && comm_ptr->node_comm != NULL) { mpi_errno = MPIC_Recv(localfulldata, count, datatype, comm_ptr->node_comm->local_size - 1, MPIR_SCAN_TAG, - comm_ptr->node_comm, &status); + comm_ptr->node_comm, coll_group, &status); MPIR_ERR_CHECK(mpi_errno); } else if (comm_ptr->node_roots_comm == NULL && comm_ptr->node_comm != NULL && MPIR_Get_intranode_rank(comm_ptr, rank) == comm_ptr->node_comm->local_size - 1) { mpi_errno = MPIC_Send(recvbuf, count, datatype, - 0, MPIR_SCAN_TAG, comm_ptr->node_comm, errflag); + 0, MPIR_SCAN_TAG, comm_ptr->node_comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } else if (comm_ptr->node_roots_comm != NULL) { localfulldata = recvbuf; @@ -1743,19 +1785,19 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Scan_intra_composition_alpha(const void *send if (comm_ptr->node_roots_comm != NULL) { mpi_errno = MPIDI_NM_mpi_scan(localfulldata, prefulldata, count, datatype, - op, comm_ptr->node_roots_comm, errflag); + op, comm_ptr->node_roots_comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); if (MPIR_Get_internode_rank(comm_ptr, rank) != comm_ptr->node_roots_comm->local_size - 1) { mpi_errno = MPIC_Send(prefulldata, count, datatype, MPIR_Get_internode_rank(comm_ptr, rank) + 1, - MPIR_SCAN_TAG, comm_ptr->node_roots_comm, errflag); + MPIR_SCAN_TAG, comm_ptr->node_roots_comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } if (MPIR_Get_internode_rank(comm_ptr, rank) != 0) { mpi_errno = MPIC_Recv(tempbuf, count, datatype, MPIR_Get_internode_rank(comm_ptr, rank) - 1, - MPIR_SCAN_TAG, comm_ptr->node_roots_comm, &status); + MPIR_SCAN_TAG, comm_ptr->node_roots_comm, coll_group, &status); noneed = 0; MPIR_ERR_CHECK(mpi_errno); } @@ -1769,10 +1811,12 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Scan_intra_composition_alpha(const void *send if (comm_ptr->node_comm != NULL) { #ifndef MPIDI_CH4_DIRECT_NETMOD - mpi_errno = MPIDI_SHM_mpi_bcast(&noneed, 1, MPI_INT, 0, comm_ptr->node_comm, errflag); + mpi_errno = MPIDI_SHM_mpi_bcast(&noneed, 1, MPI_INT, 0, comm_ptr->node_comm, + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); #else - mpi_errno = MPIDI_NM_mpi_bcast(&noneed, 1, MPI_INT, 0, comm_ptr->node_comm, errflag); + mpi_errno = MPIDI_NM_mpi_bcast(&noneed, 1, MPI_INT, 0, comm_ptr->node_comm, + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); #endif /* MPIDI_CH4_DIRECT_NETMOD */ } @@ -1780,12 +1824,12 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Scan_intra_composition_alpha(const void *send if (noneed == 0) { if (comm_ptr->node_comm != NULL) { #ifndef MPIDI_CH4_DIRECT_NETMOD - mpi_errno = - MPIDI_SHM_mpi_bcast(tempbuf, count, datatype, 0, comm_ptr->node_comm, errflag); + mpi_errno = MPIDI_SHM_mpi_bcast(tempbuf, count, datatype, 0, comm_ptr->node_comm, + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); #else - mpi_errno = - MPIDI_NM_mpi_bcast(tempbuf, count, datatype, 0, comm_ptr->node_comm, errflag); + mpi_errno = MPIDI_NM_mpi_bcast(tempbuf, count, datatype, 0, comm_ptr->node_comm, + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); #endif /* MPIDI_CH4_DIRECT_NETMOD */ } @@ -1806,12 +1850,13 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Scan_intra_composition_beta(const void *sendb MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, + MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; - mpi_errno = MPIDI_NM_mpi_scan(sendbuf, recvbuf, count, datatype, op, comm_ptr, errflag); + mpi_errno = + MPIDI_NM_mpi_scan(sendbuf, recvbuf, count, datatype, op, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -1826,11 +1871,13 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Exscan_intra_composition_alpha(const void *se MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; - mpi_errno = MPIDI_NM_mpi_exscan(sendbuf, recvbuf, count, datatype, op, comm_ptr, errflag); + mpi_errno = + MPIDI_NM_mpi_exscan(sendbuf, recvbuf, count, datatype, op, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); fn_exit: diff --git a/src/mpid/ch4/src/ch4_comm.c b/src/mpid/ch4/src/ch4_comm.c index 808d6f6e21b..a925231244f 100644 --- a/src/mpid/ch4/src/ch4_comm.c +++ b/src/mpid/ch4/src/ch4_comm.c @@ -450,8 +450,8 @@ int MPID_Intercomm_exchange_map(MPIR_Comm * local_comm, int local_leader, MPIR_C mpi_errno = MPIC_Sendrecv(&local_size_send, 1, MPI_INT, remote_leader, cts_tag, &remote_size_recv, 1, MPI_INT, - remote_leader, cts_tag, peer_comm, MPI_STATUS_IGNORE, - MPIR_ERR_NONE); + remote_leader, cts_tag, peer_comm, MPIR_SUBGROUP_NONE, + MPI_STATUS_IGNORE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); if (remote_size_recv & MPIDI_DYNPROC_MASK) @@ -488,7 +488,8 @@ int MPID_Intercomm_exchange_map(MPIR_Comm * local_comm, int local_leader, MPIR_C remote_leader, cts_tag, remote_upid_size, *remote_size, MPI_INT, remote_leader, cts_tag, - peer_comm, MPI_STATUS_IGNORE, MPIR_ERR_NONE); + peer_comm, MPIR_SUBGROUP_NONE, MPI_STATUS_IGNORE, + MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); upid_send_size = 0; for (i = 0; i < local_size; i++) @@ -502,7 +503,8 @@ int MPID_Intercomm_exchange_map(MPIR_Comm * local_comm, int local_leader, MPIR_C remote_leader, cts_tag, remote_upids, upid_recv_size, MPI_BYTE, remote_leader, cts_tag, - peer_comm, MPI_STATUS_IGNORE, MPIR_ERR_NONE); + peer_comm, MPIR_SUBGROUP_NONE, MPI_STATUS_IGNORE, + MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); /* Stage 1.2 convert remote UPID to GPID and get GPID for local group */ @@ -513,7 +515,8 @@ int MPID_Intercomm_exchange_map(MPIR_Comm * local_comm, int local_leader, MPIR_C remote_leader, cts_tag, *remote_gpids, *remote_size, MPI_UINT64_T, remote_leader, cts_tag, - peer_comm, MPI_STATUS_IGNORE, MPIR_ERR_NONE); + peer_comm, MPIR_SUBGROUP_NONE, MPI_STATUS_IGNORE, + MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); } /* Stage 1.3 check if local/remote groups are disjoint */ @@ -622,23 +625,28 @@ int MPIDIU_Intercomm_map_bcast_intra(MPIR_Comm * local_comm, int local_leader, i map_info[2] = *is_low_group; map_info[3] = pure_intracomm; mpi_errno = - MPIR_Bcast_allcomm_auto(map_info, 4, MPI_INT, local_leader, local_comm, MPIR_ERR_NONE); + MPIR_Bcast_allcomm_auto(map_info, 4, MPI_INT, local_leader, local_comm, + MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); if (!pure_intracomm) { mpi_errno = MPIR_Bcast_allcomm_auto(remote_upid_size, *remote_size, MPI_INT, - local_leader, local_comm, MPIR_ERR_NONE); + local_leader, local_comm, MPIR_SUBGROUP_NONE, + MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_Bcast_allcomm_auto(remote_upids, upid_recv_size, MPI_BYTE, - local_leader, local_comm, MPIR_ERR_NONE); + local_leader, local_comm, MPIR_SUBGROUP_NONE, + MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); } else { mpi_errno = MPIR_Bcast_allcomm_auto(*remote_gpids, *remote_size, MPI_UINT64_T, - local_leader, local_comm, MPIR_ERR_NONE); + local_leader, local_comm, MPIR_SUBGROUP_NONE, + MPIR_ERR_NONE); } } else { mpi_errno = - MPIR_Bcast_allcomm_auto(map_info, 4, MPI_INT, local_leader, local_comm, MPIR_ERR_NONE); + MPIR_Bcast_allcomm_auto(map_info, 4, MPI_INT, local_leader, local_comm, + MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); *remote_size = map_info[0]; upid_recv_size = map_info[1]; @@ -651,18 +659,21 @@ int MPIDIU_Intercomm_map_bcast_intra(MPIR_Comm * local_comm, int local_leader, i MPIR_CHKLMEM_MALLOC(_remote_upid_size, int *, (*remote_size) * sizeof(int), mpi_errno, "_remote_upid_size", MPL_MEM_COMM); mpi_errno = MPIR_Bcast_allcomm_auto(_remote_upid_size, *remote_size, MPI_INT, - local_leader, local_comm, MPIR_ERR_NONE); + local_leader, local_comm, MPIR_SUBGROUP_NONE, + MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); MPIR_CHKLMEM_MALLOC(_remote_upids, char *, upid_recv_size * sizeof(char), mpi_errno, "_remote_upids", MPL_MEM_COMM); mpi_errno = MPIR_Bcast_allcomm_auto(_remote_upids, upid_recv_size, MPI_BYTE, - local_leader, local_comm, MPIR_ERR_NONE); + local_leader, local_comm, MPIR_SUBGROUP_NONE, + MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); MPIDIU_upids_to_gpids(*remote_size, _remote_upid_size, _remote_upids, *remote_gpids); } else { mpi_errno = MPIR_Bcast_allcomm_auto(*remote_gpids, *remote_size, MPI_UINT64_T, - local_leader, local_comm, MPIR_ERR_NONE); + local_leader, local_comm, MPIR_SUBGROUP_NONE, + MPIR_ERR_NONE); } } @@ -780,8 +791,6 @@ int MPIDI_Comm_create_multi_leaders(MPIR_Comm * comm) MPIDI_COMM(comm, multi_leads_comm)); MPIDI_COMM(comm, multi_leads_comm)->local_size = num_external; - MPIDI_COMM(comm, multi_leads_comm)->coll.pof2 = - MPL_pof2(MPIDI_COMM(comm, multi_leads_comm)->local_size); MPIDI_COMM(comm, multi_leads_comm)->remote_size = num_external; MPIR_Comm_map_irregular(MPIDI_COMM(comm, multi_leads_comm), comm, diff --git a/src/mpid/ch4/src/ch4_init.c b/src/mpid/ch4/src/ch4_init.c index 365a12b37ad..45fb824be43 100644 --- a/src/mpid/ch4/src/ch4_init.c +++ b/src/mpid/ch4/src/ch4_init.c @@ -704,7 +704,7 @@ int MPIDI_world_post_init(void) mpi_errno = MPIR_Allgather_fallback(&MPIDI_global.n_vcis, 1, MPI_INT, MPIDI_global.all_num_vcis, 1, MPI_INT, - MPIR_Process.comm_world, MPIR_ERR_NONE); + MPIR_Process.comm_world, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); #endif diff --git a/src/mpid/ch4/src/ch4_persist.c b/src/mpid/ch4/src/ch4_persist.c index 20ffa647d43..457dd6827cb 100644 --- a/src/mpid/ch4/src/ch4_persist.c +++ b/src/mpid/ch4/src/ch4_persist.c @@ -168,12 +168,15 @@ int MPID_Recv_init(void *buf, } int MPID_Bcast_init(void *buffer, MPI_Aint count, MPI_Datatype datatype, - int root, MPIR_Comm * comm_ptr, MPIR_Info * info_ptr, MPIR_Request ** request) + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Info * info_ptr, + MPIR_Request ** request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Bcast_init_impl(buffer, count, datatype, root, comm_ptr, info_ptr, request); + mpi_errno = + MPIR_Bcast_init_impl(buffer, count, datatype, root, comm_ptr, coll_group, info_ptr, + request); MPIR_FUNC_EXIT; return mpi_errno; @@ -181,14 +184,16 @@ int MPID_Bcast_init(void *buffer, MPI_Aint count, MPI_Datatype datatype, int MPID_Allreduce_init(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, MPIR_Info * info_ptr, MPIR_Request ** request) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Info * info_ptr, + MPIR_Request ** request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Allreduce_init_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, info_ptr, - request); + mpi_errno = + MPIR_Allreduce_init_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, coll_group, + info_ptr, request); MPIR_FUNC_EXIT; return mpi_errno; @@ -196,14 +201,16 @@ int MPID_Allreduce_init(const void *sendbuf, void *recvbuf, MPI_Aint count, int MPID_Reduce_init(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, int root, - MPIR_Comm * comm_ptr, MPIR_Info * info_ptr, MPIR_Request ** request) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Info * info_ptr, + MPIR_Request ** request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Reduce_init_impl(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, - info_ptr, request); + mpi_errno = + MPIR_Reduce_init_impl(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, coll_group, + info_ptr, request); MPIR_FUNC_EXIT; return mpi_errno; @@ -212,14 +219,15 @@ int MPID_Reduce_init(const void *sendbuf, void *recvbuf, MPI_Aint count, int MPID_Alltoall_init(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Info * info_ptr, MPIR_Request ** request) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Info * info_ptr, + MPIR_Request ** request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; mpi_errno = MPIR_Alltoall_init_impl(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, - comm_ptr, info_ptr, request); + comm_ptr, coll_group, info_ptr, request); MPIR_FUNC_EXIT; return mpi_errno; @@ -229,14 +237,15 @@ int MPID_Alltoallv_init(const void *sendbuf, const MPI_Aint sendcounts[], const MPI_Aint sdispls[], MPI_Datatype sendtype, void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint rdispls[], MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Info * info_ptr, MPIR_Request ** request) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Info * info_ptr, + MPIR_Request ** request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; mpi_errno = MPIR_Alltoallv_init_impl(sendbuf, sendcounts, sdispls, sendtype, recvbuf, - recvcounts, rdispls, recvtype, comm_ptr, info_ptr, - request); + recvcounts, rdispls, recvtype, comm_ptr, coll_group, + info_ptr, request); MPIR_FUNC_EXIT; return mpi_errno; @@ -248,14 +257,15 @@ int MPID_Alltoallw_init(const void *sendbuf, const MPI_Aint sendcounts[], void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint rdispls[], const MPI_Datatype recvtypes[], - MPIR_Comm * comm_ptr, MPIR_Info * info_ptr, MPIR_Request ** request) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Info * info_ptr, + MPIR_Request ** request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; mpi_errno = MPIR_Alltoallw_init_impl(sendbuf, sendcounts, sdispls, sendtypes, recvbuf, - recvcounts, rdispls, recvtypes, comm_ptr, info_ptr, - request); + recvcounts, rdispls, recvtypes, comm_ptr, coll_group, + info_ptr, request); MPIR_FUNC_EXIT; return mpi_errno; @@ -264,13 +274,14 @@ int MPID_Alltoallw_init(const void *sendbuf, const MPI_Aint sendcounts[], int MPID_Allgather_init(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Info * info_ptr, MPIR_Request ** request) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Info * info_ptr, + MPIR_Request ** request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; mpi_errno = MPIR_Allgather_init_impl(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, - comm_ptr, info_ptr, request); + comm_ptr, coll_group, info_ptr, request); MPIR_FUNC_EXIT; return mpi_errno; @@ -280,13 +291,15 @@ int MPID_Allgatherv_init(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Info * info_ptr, MPIR_Request ** request) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Info * info_ptr, + MPIR_Request ** request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; mpi_errno = MPIR_Allgatherv_init_impl(sendbuf, sendcount, sendtype, recvbuf, recvcounts, - displs, recvtype, comm_ptr, info_ptr, request); + displs, recvtype, comm_ptr, coll_group, info_ptr, + request); return mpi_errno; } @@ -294,13 +307,15 @@ int MPID_Allgatherv_init(const void *sendbuf, MPI_Aint sendcount, int MPID_Reduce_scatter_block_init(const void *sendbuf, void *recvbuf, MPI_Aint recvcount, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm, MPIR_Info * info, MPIR_Request ** request) + MPIR_Comm * comm, int coll_group, MPIR_Info * info, + MPIR_Request ** request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Reduce_scatter_block_init_impl(sendbuf, recvbuf, recvcount, datatype, op, comm, - info, request); + mpi_errno = + MPIR_Reduce_scatter_block_init_impl(sendbuf, recvbuf, recvcount, datatype, op, comm, + coll_group, info, request); MPIR_FUNC_EXIT; return mpi_errno; @@ -309,26 +324,29 @@ int MPID_Reduce_scatter_block_init(const void *sendbuf, void *recvbuf, int MPID_Reduce_scatter_init(const void *sendbuf, void *recvbuf, const MPI_Aint recvcounts[], MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm, MPIR_Info * info, MPIR_Request ** request) + MPIR_Comm * comm, int coll_group, MPIR_Info * info, + MPIR_Request ** request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Reduce_scatter_init_impl(sendbuf, recvbuf, recvcounts, datatype, op, comm, - info, request); + mpi_errno = + MPIR_Reduce_scatter_init_impl(sendbuf, recvbuf, recvcounts, datatype, op, comm, coll_group, + info, request); MPIR_FUNC_EXIT; return mpi_errno; } int MPID_Scan_init(const void *sendbuf, void *recvbuf, MPI_Aint count, - MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, + MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Info * info, MPIR_Request ** request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Scan_init_impl(sendbuf, recvbuf, count, datatype, op, comm, info, request); + mpi_errno = + MPIR_Scan_init_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, info, request); MPIR_FUNC_EXIT; return mpi_errno; @@ -337,13 +355,13 @@ int MPID_Scan_init(const void *sendbuf, void *recvbuf, MPI_Aint count, int MPID_Gather_init(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, - MPIR_Comm * comm, MPIR_Info * info, MPIR_Request ** request) + MPIR_Comm * comm, int coll_group, MPIR_Info * info, MPIR_Request ** request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; mpi_errno = MPIR_Gather_init_impl(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, - root, comm, info, request); + root, comm, coll_group, info, request); MPIR_FUNC_EXIT; return mpi_errno; @@ -352,14 +370,14 @@ int MPID_Gather_init(const void *sendbuf, MPI_Aint sendcount, int MPID_Gatherv_init(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint displs[], - MPI_Datatype recvtype, int root, MPIR_Comm * comm, + MPI_Datatype recvtype, int root, MPIR_Comm * comm, int coll_group, MPIR_Info * info, MPIR_Request ** request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; mpi_errno = MPIR_Gatherv_init_impl(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, - recvtype, root, comm, info, request); + recvtype, root, comm, coll_group, info, request); MPIR_FUNC_EXIT; return mpi_errno; @@ -368,13 +386,13 @@ int MPID_Gatherv_init(const void *sendbuf, MPI_Aint sendcount, int MPID_Scatter_init(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, - MPIR_Comm * comm, MPIR_Info * info, MPIR_Request ** request) + MPIR_Comm * comm, int coll_group, MPIR_Info * info, MPIR_Request ** request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; mpi_errno = MPIR_Scatter_init_impl(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, - root, comm, info, request); + root, comm, coll_group, info, request); MPIR_FUNC_EXIT; return mpi_errno; @@ -383,38 +401,40 @@ int MPID_Scatter_init(const void *sendbuf, MPI_Aint sendcount, int MPID_Scatterv_init(const void *sendbuf, const MPI_Aint sendcounts[], const MPI_Aint displs[], MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, - MPI_Datatype recvtype, int root, MPIR_Comm * comm, + MPI_Datatype recvtype, int root, MPIR_Comm * comm, int coll_group, MPIR_Info * info, MPIR_Request ** request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; mpi_errno = MPIR_Scatterv_init_impl(sendbuf, sendcounts, displs, sendtype, recvbuf, recvcount, - recvtype, root, comm, info, request); + recvtype, root, comm, coll_group, info, request); MPIR_FUNC_EXIT; return mpi_errno; } -int MPID_Barrier_init(MPIR_Comm * comm, MPIR_Info * info, MPIR_Request ** request) +int MPID_Barrier_init(MPIR_Comm * comm, int coll_group, MPIR_Info * info, MPIR_Request ** request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Barrier_init_impl(comm, info, request); + mpi_errno = MPIR_Barrier_init_impl(comm, coll_group, info, request); MPIR_FUNC_EXIT; return mpi_errno; } int MPID_Exscan_init(const void *sendbuf, void *recvbuf, MPI_Aint count, - MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, + MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Info * info, MPIR_Request ** request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Exscan_init_impl(sendbuf, recvbuf, count, datatype, op, comm, info, request); + mpi_errno = + MPIR_Exscan_init_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, info, + request); MPIR_FUNC_EXIT; return mpi_errno; diff --git a/src/mpid/ch4/src/ch4_spawn.c b/src/mpid/ch4/src/ch4_spawn.c index 7928fe6feec..dade6274d62 100644 --- a/src/mpid/ch4/src/ch4_spawn.c +++ b/src/mpid/ch4/src/ch4_spawn.c @@ -76,7 +76,8 @@ int MPID_Comm_spawn_multiple(int count, char *commands[], char **argvs[], const bcast_ints[0] = total_num_processes; bcast_ints[1] = spawn_error; } - mpi_errno = MPIR_Bcast(bcast_ints, 2, MPI_INT, root, comm_ptr, MPIR_ERR_NONE); + mpi_errno = + MPIR_Bcast(bcast_ints, 2, MPI_INT, root, comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); if (comm_ptr->rank != root) { total_num_processes = bcast_ints[0]; @@ -90,7 +91,8 @@ int MPID_Comm_spawn_multiple(int count, char *commands[], char **argvs[], const int should_accept = 1; if (errcodes != MPI_ERRCODES_IGNORE) { mpi_errno = - MPIR_Bcast(pmi_errcodes, total_num_processes, MPI_INT, root, comm_ptr, MPIR_ERR_NONE); + MPIR_Bcast(pmi_errcodes, total_num_processes, MPI_INT, root, comm_ptr, + MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); for (int i = 0; i < total_num_processes; i++) { @@ -391,12 +393,16 @@ static int dynamic_intercomm_create(const char *port_name, MPIR_Info * info, int bcast_tag_and_errno: bcast_ints[0] = tag; bcast_ints[1] = mpi_errno; - mpi_errno = MPIR_Bcast_allcomm_auto(bcast_ints, 2, MPI_INT, root, comm_ptr, MPIR_ERR_NONE); + mpi_errno = + MPIR_Bcast_allcomm_auto(bcast_ints, 2, MPI_INT, root, comm_ptr, MPIR_SUBGROUP_NONE, + MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); mpi_errno = bcast_ints[1]; MPIR_ERR_CHECK(mpi_errno); } else { - mpi_errno = MPIR_Bcast_allcomm_auto(bcast_ints, 2, MPI_INT, root, comm_ptr, MPIR_ERR_NONE); + mpi_errno = + MPIR_Bcast_allcomm_auto(bcast_ints, 2, MPI_INT, root, comm_ptr, MPIR_SUBGROUP_NONE, + MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); if (bcast_ints[1]) { /* errno from root cannot be directly returned */ diff --git a/src/mpid/ch4/src/init_comm.c b/src/mpid/ch4/src/init_comm.c index e546337bd6f..1ff8135e2c8 100644 --- a/src/mpid/ch4/src/init_comm.c +++ b/src/mpid/ch4/src/init_comm.c @@ -32,7 +32,6 @@ int MPIDI_create_init_comm(MPIR_Comm ** comm) init_comm->rank = node_roots_comm_rank; init_comm->remote_size = node_roots_comm_size; init_comm->local_size = node_roots_comm_size; - init_comm->coll.pof2 = MPL_pof2(node_roots_comm_size); MPIDI_COMM(init_comm, map).mode = MPIDI_RANK_MAP_LUT_INTRA; mpi_errno = MPIDIU_alloc_lut(&lut, node_roots_comm_size); MPIR_ERR_CHECK(mpi_errno); @@ -47,8 +46,8 @@ int MPIDI_create_init_comm(MPIR_Comm ** comm) mpi_errno = MPIDIG_init_comm(init_comm); MPIR_ERR_CHECK(mpi_errno); /* hacky, consider a separate MPIDI_{NM,SHM}_init_comm_hook - * to initialize the init_comm, e.g. to eliminate potential - * runtime features for stability during init */ + * to initialize the init_comm, e.g. to eliminate potential + * runtime features for stability during init */ mpi_errno = MPIDI_NM_mpi_comm_commit_pre_hook(init_comm); MPIR_ERR_CHECK(mpi_errno); diff --git a/src/mpid/ch4/src/mpidig_win.c b/src/mpid/ch4/src/mpidig_win.c index 1438c198de0..2b3a5c0f32b 100644 --- a/src/mpid/ch4/src/mpidig_win.c +++ b/src/mpid/ch4/src/mpidig_win.c @@ -393,7 +393,7 @@ static int win_init(MPI_Aint length, int disp_unit, MPIR_Win ** win_ptr, MPIR_In no_local = true; mpi_errno = MPIR_Allreduce(&no_local, &all_no_local, 1, MPI_C_BOOL, - MPI_LAND, comm_ptr, MPIR_ERR_NONE); + MPI_LAND, comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); if (all_no_local) MPIDI_WIN(win, winattr) |= MPIDI_WINATTR_ACCU_NO_SHM; @@ -524,7 +524,7 @@ static int win_shm_alloc_impl(MPI_Aint size, int disp_unit, MPIR_Comm * comm_ptr MPI_DATATYPE_NULL, shared_table, sizeof(MPIDIG_win_shared_info_t), MPI_BYTE, shm_comm_ptr, - MPIR_ERR_NONE); + MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_T_PVAR_TIMER_END(RMA, rma_wincreate_allgather); if (mpi_errno != MPI_SUCCESS) goto fn_fail; @@ -561,7 +561,7 @@ static int win_shm_alloc_impl(MPI_Aint size, int disp_unit, MPIR_Comm * comm_ptr * - user sets alloc_shared_noncontig=true, thus we can internally make * the size aligned on each process. */ mpi_errno = MPIR_Allreduce(&symheap_flag, &global_symheap_flag, 1, MPI_C_BOOL, - MPI_LAND, comm_ptr, MPIR_ERR_NONE); + MPI_LAND, comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); } else global_symheap_flag = false; @@ -692,7 +692,7 @@ int MPIDIG_mpi_win_set_info(MPIR_Win * win, MPIR_Info * info) /* Do not update winattr except for info set at window creation. * Because it will change RMA's behavior which requires collective synchronization. */ - mpi_errno = MPIR_Barrier(win->comm_ptr, MPIR_ERR_NONE); + mpi_errno = MPIR_Barrier(win->comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); fn_exit: MPIR_FUNC_EXIT; return mpi_errno; @@ -857,7 +857,7 @@ int MPIDIG_mpi_win_free(MPIR_Win ** win_ptr) MPIDIG_ACCESS_EPOCH_CHECK_NONE(win, mpi_errno, return mpi_errno); MPIDIG_EXPOSURE_EPOCH_CHECK_NONE(win, mpi_errno, return mpi_errno); - mpi_errno = MPIR_Barrier(win->comm_ptr, MPIR_ERR_NONE); + mpi_errno = MPIR_Barrier(win->comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); if (mpi_errno != MPI_SUCCESS) goto fn_fail; @@ -894,7 +894,7 @@ int MPIDIG_mpi_win_create(void *base, MPI_Aint length, int disp_unit, MPIR_Info MPIR_ERR_CHECK(mpi_errno); #endif - mpi_errno = MPIR_Barrier(win->comm_ptr, MPIR_ERR_NONE); + mpi_errno = MPIR_Barrier(win->comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); if (mpi_errno != MPI_SUCCESS) goto fn_fail; @@ -962,7 +962,7 @@ int MPIDIG_mpi_win_allocate_shared(MPI_Aint size, int disp_unit, MPIR_Info * inf MPIR_ERR_CHECK(mpi_errno); #endif - mpi_errno = MPIR_Barrier(comm_ptr, MPIR_ERR_NONE); + mpi_errno = MPIR_Barrier(comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); fn_exit: MPIR_FUNC_EXIT; @@ -1026,7 +1026,7 @@ int MPIDIG_mpi_win_allocate(MPI_Aint size, int disp_unit, MPIR_Info * info, MPIR MPIR_ERR_CHECK(mpi_errno); #endif - mpi_errno = MPIR_Barrier(comm, MPIR_ERR_NONE); + mpi_errno = MPIR_Barrier(comm, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); if (mpi_errno != MPI_SUCCESS) goto fn_fail; @@ -1065,7 +1065,7 @@ int MPIDIG_mpi_win_create_dynamic(MPIR_Info * info, MPIR_Comm * comm, MPIR_Win * MPIR_ERR_CHECK(mpi_errno); #endif - mpi_errno = MPIR_Barrier(comm, MPIR_ERR_NONE); + mpi_errno = MPIR_Barrier(comm, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); fn_exit: MPIR_FUNC_EXIT; diff --git a/src/mpid/ch4/src/mpidig_win.h b/src/mpid/ch4/src/mpidig_win.h index 6353bf7def3..4e5936a5262 100644 --- a/src/mpid/ch4/src/mpidig_win.h +++ b/src/mpid/ch4/src/mpidig_win.h @@ -522,7 +522,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_mpi_win_fence(int massert, MPIR_Win * win) * the VCI lock internally. */ MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); need_unlock = 0; - mpi_errno = MPIR_Barrier(win->comm_ptr, MPIR_ERR_NONE); + mpi_errno = MPIR_Barrier(win->comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); fn_exit: if (need_unlock) { diff --git a/src/mpid/common/bc/mpidu_bc.c b/src/mpid/common/bc/mpidu_bc.c index 587cd12b2fb..c84d95ee730 100644 --- a/src/mpid/common/bc/mpidu_bc.c +++ b/src/mpid/common/bc/mpidu_bc.c @@ -107,7 +107,8 @@ int MPIDU_bc_allgather(MPIR_Comm * allgather_comm, void *bc, int bc_len, int sam void *recv_buf = segment + local_size * recv_bc_len; if (rank == node_root) { MPIR_Allgatherv_fallback(segment, local_size * recv_bc_len, MPI_BYTE, recv_buf, - recv_cnts, recv_offs, MPI_BYTE, allgather_comm, MPIR_ERR_NONE); + recv_cnts, recv_offs, MPI_BYTE, allgather_comm, MPIR_SUBGROUP_NONE, + MPIR_ERR_NONE); } diff --git a/src/mpid/common/sched/mpidu_sched.c b/src/mpid/common/sched/mpidu_sched.c index 891ced6b46c..0298b928c41 100644 --- a/src/mpid/common/sched/mpidu_sched.c +++ b/src/mpid/common/sched/mpidu_sched.c @@ -148,7 +148,7 @@ int MPIDU_Sched_are_pending(void) return (all_schedules.head != NULL); } -int MPIDU_Sched_next_tag(MPIR_Comm * comm_ptr, int *tag) +int MPIDU_Sched_next_tag(MPIR_Comm * comm_ptr, int coll_group, int *tag) { int mpi_errno = MPI_SUCCESS; /* TODO there should be an internal accessor/utility macro for getting the @@ -162,6 +162,10 @@ int MPIDU_Sched_next_tag(MPIR_Comm * comm_ptr, int *tag) MPIR_FUNC_ENTER; *tag = comm_ptr->next_sched_tag; + if (coll_group != MPIR_SUBGROUP_NONE) { + /* subgroup collectives use the same tag within a parent collective */ + goto fn_exit; + } ++comm_ptr->next_sched_tag; #if defined(HAVE_ERROR_CHECKING) @@ -191,11 +195,13 @@ int MPIDU_Sched_next_tag(MPIR_Comm * comm_ptr, int *tag) if (comm_ptr->next_sched_tag == tag_ub) { comm_ptr->next_sched_tag = MPIR_FIRST_NBC_TAG; } + fn_exit: + MPIR_FUNC_EXIT; + return mpi_errno; #if defined(HAVE_ERROR_CHECKING) fn_fail: + goto fn_exit; #endif - MPIR_FUNC_EXIT; - return mpi_errno; } void MPIDU_Sched_set_tag(struct MPIDU_Sched *s, int tag) @@ -226,12 +232,12 @@ static int MPIDU_Sched_start_entry(struct MPIDU_Sched *s, size_t idx, struct MPI * &send.count, but this requires patching up the pointers * during realloc of entries, so this is easier */ ret_errno = MPIC_Isend(e->u.send.buf, *e->u.send.count_p, e->u.send.datatype, - e->u.send.dest, s->tag, comm, &e->u.send.sreq, - r->u.nbc.errflag); + e->u.send.dest, s->tag, comm, e->u.send.coll_group, + &e->u.send.sreq, r->u.nbc.errflag); } else { ret_errno = MPIC_Isend(e->u.send.buf, e->u.send.count, e->u.send.datatype, - e->u.send.dest, s->tag, comm, &e->u.send.sreq, - r->u.nbc.errflag); + e->u.send.dest, s->tag, comm, e->u.send.coll_group, + &e->u.send.sreq, r->u.nbc.errflag); } /* Check if the error is actually fatal to the NBC or we can continue. */ if (unlikely(ret_errno)) { @@ -256,7 +262,8 @@ static int MPIDU_Sched_start_entry(struct MPIDU_Sched *s, size_t idx, struct MPI MPL_DBG_MSG_D(MPIR_DBG_COMM, VERBOSE, "starting RECV entry %d\n", (int) idx); comm = e->u.recv.comm; ret_errno = MPIC_Irecv(e->u.recv.buf, e->u.recv.count, e->u.recv.datatype, - e->u.recv.src, s->tag, comm, &e->u.recv.rreq); + e->u.recv.src, s->tag, comm, e->u.recv.coll_group, + &e->u.recv.rreq); /* Check if the error is actually fatal to the NBC or we can continue. */ if (unlikely(ret_errno)) { if (MPIR_ERR_NONE == r->u.nbc.errflag) { @@ -655,7 +662,7 @@ static int MPIDU_Sched_add_entry(struct MPIDU_Sched *s, int *idx, struct MPIDU_S /* do these ops need an entry handle returned? */ int MPIDU_Sched_send(const void *buf, MPI_Aint count, MPI_Datatype datatype, int dest, - MPIR_Comm * comm, MPIR_Sched_t s) + MPIR_Comm * comm, int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; struct MPIDU_Sched_entry *e = NULL; @@ -674,6 +681,7 @@ int MPIDU_Sched_send(const void *buf, MPI_Aint count, MPI_Datatype datatype, int e->u.send.dest = dest; e->u.send.sreq = NULL; /* will be populated by _start_entry */ e->u.send.comm = comm; + e->u.send.coll_group = coll_group; /* the user may free the comm & type after initiating but before the * underlying send is actually posted, so we must add a reference here and @@ -711,6 +719,7 @@ int MPIDU_Sched_pt2pt_send(const void *buf, MPI_Aint count, MPI_Datatype datatyp e->u.send.dest = dest; e->u.send.sreq = NULL; /* will be populated by _start_entry */ e->u.send.comm = comm; + e->u.send.coll_group = MPIR_SUBGROUP_NONE; e->u.send.tag = tag; /* the user may free the comm & type after initiating but before the @@ -730,7 +739,7 @@ int MPIDU_Sched_pt2pt_send(const void *buf, MPI_Aint count, MPI_Datatype datatyp } int MPIDU_Sched_send_defer(const void *buf, const MPI_Aint * count, MPI_Datatype datatype, int dest, - MPIR_Comm * comm, MPIR_Sched_t s) + MPIR_Comm * comm, int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; struct MPIDU_Sched_entry *e = NULL; @@ -749,6 +758,7 @@ int MPIDU_Sched_send_defer(const void *buf, const MPI_Aint * count, MPI_Datatype e->u.send.dest = dest; e->u.send.sreq = NULL; /* will be populated by _start_entry */ e->u.send.comm = comm; + e->u.send.coll_group = coll_group; /* the user may free the comm & type after initiating but before the * underlying send is actually posted, so we must add a reference here and @@ -767,7 +777,7 @@ int MPIDU_Sched_send_defer(const void *buf, const MPI_Aint * count, MPI_Datatype } int MPIDU_Sched_recv_status(void *buf, MPI_Aint count, MPI_Datatype datatype, int src, - MPIR_Comm * comm, MPI_Status * status, MPIR_Sched_t s) + MPIR_Comm * comm, int coll_group, MPI_Status * status, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; struct MPIDU_Sched_entry *e = NULL; @@ -785,6 +795,7 @@ int MPIDU_Sched_recv_status(void *buf, MPI_Aint count, MPI_Datatype datatype, in e->u.recv.src = src; e->u.recv.rreq = NULL; /* will be populated by _start_entry */ e->u.recv.comm = comm; + e->u.recv.coll_group = coll_group; e->u.recv.status = status; status->MPI_ERROR = MPI_SUCCESS; MPIR_Comm_add_ref(comm); @@ -801,7 +812,7 @@ int MPIDU_Sched_recv_status(void *buf, MPI_Aint count, MPI_Datatype datatype, in } int MPIDU_Sched_recv(void *buf, MPI_Aint count, MPI_Datatype datatype, int src, MPIR_Comm * comm, - MPIR_Sched_t s) + int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; struct MPIDU_Sched_entry *e = NULL; @@ -819,6 +830,7 @@ int MPIDU_Sched_recv(void *buf, MPI_Aint count, MPI_Datatype datatype, int src, e->u.recv.src = src; e->u.recv.rreq = NULL; /* will be populated by _start_entry */ e->u.recv.comm = comm; + e->u.recv.coll_group = coll_group; e->u.recv.status = MPI_STATUS_IGNORE; MPIR_Comm_add_ref(comm); @@ -853,6 +865,7 @@ int MPIDU_Sched_pt2pt_recv(void *buf, MPI_Aint count, MPI_Datatype datatype, e->u.recv.src = src; e->u.recv.rreq = NULL; /* will be populated by _start_entry */ e->u.recv.comm = comm; + e->u.send.coll_group = MPIR_SUBGROUP_NONE; e->u.recv.status = MPI_STATUS_IGNORE; e->u.recv.tag = tag; diff --git a/src/mpid/common/sched/mpidu_sched.h b/src/mpid/common/sched/mpidu_sched.h index 2454ad60c84..90d43b75392 100644 --- a/src/mpid/common/sched/mpidu_sched.h +++ b/src/mpid/common/sched/mpidu_sched.h @@ -43,6 +43,7 @@ struct MPIDU_Sched_send { int tag; /* only used for _PT2PT_SEND */ int dest; struct MPIR_Comm *comm; + int coll_group; struct MPIR_Request *sreq; }; @@ -53,6 +54,7 @@ struct MPIDU_Sched_recv { int tag; /* only used for _PT2PT_RECV */ int src; struct MPIR_Comm *comm; + int coll_group; struct MPIR_Request *rreq; MPI_Status *status; }; @@ -132,7 +134,7 @@ struct MPIDU_Sched { /* prototypes */ int MPIDU_Sched_progress(int vci, int *made_progress); int MPIDU_Sched_are_pending(void); -int MPIDU_Sched_next_tag(struct MPIR_Comm *comm_ptr, int *tag); +int MPIDU_Sched_next_tag(struct MPIR_Comm *comm_ptr, int coll_group, int *tag); void MPIDU_Sched_set_tag(MPIR_Sched_t s, int tag); int MPIDU_Sched_create(MPIR_Sched_t * sp, enum MPIR_Sched_kind kind); int MPIDU_Sched_clone(MPIR_Sched_t orig, MPIR_Sched_t * cloned); @@ -141,9 +143,9 @@ int MPIDU_Sched_reset(MPIR_Sched_t s); void *MPIDU_Sched_alloc_state(MPIR_Sched_t s, MPI_Aint size); int MPIDU_Sched_start(MPIR_Sched_t sp, struct MPIR_Comm *comm, struct MPIR_Request **req); int MPIDU_Sched_send(const void *buf, MPI_Aint count, MPI_Datatype datatype, int dest, - struct MPIR_Comm *comm, MPIR_Sched_t s); + struct MPIR_Comm *comm, int coll_group, MPIR_Sched_t s); int MPIDU_Sched_recv(void *buf, MPI_Aint count, MPI_Datatype datatype, int src, - struct MPIR_Comm *comm, MPIR_Sched_t s); + struct MPIR_Comm *comm, int coll_group, MPIR_Sched_t s); int MPIDU_Sched_pt2pt_send(const void *buf, MPI_Aint count, MPI_Datatype datatype, int tag, int dest, struct MPIR_Comm *comm, MPIR_Sched_t s); int MPIDU_Sched_pt2pt_recv(void *buf, MPI_Aint count, MPI_Datatype datatype, diff --git a/src/mpid/common/shm/mpidu_shm_alloc.c b/src/mpid/common/shm/mpidu_shm_alloc.c index 72fdaa0fa55..7720a1134cd 100644 --- a/src/mpid/common/shm/mpidu_shm_alloc.c +++ b/src/mpid/common/shm/mpidu_shm_alloc.c @@ -227,7 +227,7 @@ static int allreduce_maxloc(size_t mysz, int myloc, MPIR_Comm * comm, size_t * m mpi_errno = MPIR_Allreduce(&maxloc, &maxloc_result, 1, maxloc_type, maxloc_op->handle, comm, - MPIR_ERR_NONE); + MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); *maxsz_loc = maxloc_result.loc; @@ -282,21 +282,25 @@ static int map_symm_shm(MPIR_Comm * shm_comm_ptr, MPIDU_shm_seg_t * shm_seg, int root_sync: /* broadcast the mapping result on rank 0 */ - mpi_errno = MPIR_Bcast(map_result_ptr, 1, MPI_INT, 0, shm_comm_ptr, MPIR_ERR_NONE); + mpi_errno = + MPIR_Bcast(map_result_ptr, 1, MPI_INT, 0, shm_comm_ptr, MPIR_SUBGROUP_NONE, + MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); if (*map_result_ptr != SYMSHM_SUCCESS) goto map_fail; mpi_errno = MPIR_Bcast(serialized_hnd, MPL_SHM_GHND_SZ, MPI_BYTE, 0, - shm_comm_ptr, MPIR_ERR_NONE); + shm_comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); } else { char serialized_hnd[MPL_SHM_GHND_SZ] = { 0 }; /* receive the mapping result of rank 0 */ - mpi_errno = MPIR_Bcast(map_result_ptr, 1, MPI_INT, 0, shm_comm_ptr, MPIR_ERR_NONE); + mpi_errno = + MPIR_Bcast(map_result_ptr, 1, MPI_INT, 0, shm_comm_ptr, MPIR_SUBGROUP_NONE, + MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); if (*map_result_ptr != SYMSHM_SUCCESS) @@ -306,7 +310,7 @@ static int map_symm_shm(MPIR_Comm * shm_comm_ptr, MPIDU_shm_seg_t * shm_seg, int /* get serialized handle from rank 0 and deserialize it */ mpi_errno = MPIR_Bcast(serialized_hnd, MPL_SHM_GHND_SZ, MPI_BYTE, 0, - shm_comm_ptr, MPIR_ERR_NONE); + shm_comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); mpl_err = @@ -331,7 +335,7 @@ static int map_symm_shm(MPIR_Comm * shm_comm_ptr, MPIDU_shm_seg_t * shm_seg, int * return SYMSHM_OTHER_FAIL if anyone reports it (max result == 2). * Otherwise return SYMSHM_MAP_FAIL (max result == 1). */ mpi_errno = MPIR_Allreduce(map_result_ptr, &all_map_result, 1, MPI_INT, - MPI_MAX, shm_comm_ptr, MPIR_ERR_NONE); + MPI_MAX, shm_comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); if (all_map_result != SYMSHM_SUCCESS) @@ -423,8 +427,9 @@ static int shm_alloc_symm_all(MPIR_Comm * comm_ptr, size_t offset, MPIDU_shm_seg map_pointer = generate_random_addr(shm_seg->segment_len); /* broadcast fixed address to the other processes in comm */ - mpi_errno = MPIR_Bcast(&map_pointer, sizeof(char *), MPI_CHAR, maxsz_loc, comm_ptr, - MPIR_ERR_NONE); + mpi_errno = + MPIR_Bcast(&map_pointer, sizeof(char *), MPI_CHAR, maxsz_loc, comm_ptr, + MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); /* optimization: make sure every process memory in the shared segment is mapped @@ -442,7 +447,7 @@ static int shm_alloc_symm_all(MPIR_Comm * comm_ptr, size_t offset, MPIDU_shm_seg /* check if any mapping failure occurs */ mpi_errno = MPIR_Allreduce(&map_result, &all_map_result, 1, MPI_INT, - MPI_MAX, comm_ptr, MPIR_ERR_NONE); + MPI_MAX, comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); /* cleanup local shm segment if mapping failed on other process */ @@ -492,8 +497,9 @@ static int shm_alloc(MPIR_Comm * shm_comm_ptr, MPIDU_shm_seg_t * shm_seg) if (shm_fail_flag) serialized_hnd = &mpl_err_hnd[0]; - mpi_errno = MPIR_Bcast_impl(serialized_hnd, MPL_SHM_GHND_SZ, MPI_BYTE, 0, shm_comm_ptr, - MPIR_ERR_NONE); + mpi_errno = + MPIR_Bcast_impl(serialized_hnd, MPL_SHM_GHND_SZ, MPI_BYTE, 0, shm_comm_ptr, + MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); if (shm_fail_flag) @@ -501,7 +507,7 @@ static int shm_alloc(MPIR_Comm * shm_comm_ptr, MPIDU_shm_seg_t * shm_seg) /* ensure all other processes have mapped successfully */ mpi_errno = MPIR_Allreduce_impl(&shm_fail_flag, &any_shm_fail_flag, 1, MPI_C_BOOL, - MPI_LOR, shm_comm_ptr, MPIR_ERR_NONE); + MPI_LOR, shm_comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); /* unlink shared memory region so it gets deleted when all processes exit */ @@ -516,7 +522,7 @@ static int shm_alloc(MPIR_Comm * shm_comm_ptr, MPIDU_shm_seg_t * shm_seg) /* get serialized handle from rank 0 and deserialize it */ mpi_errno = MPIR_Bcast_impl(serialized_hnd, MPL_SHM_GHND_SZ, MPI_CHAR, 0, - shm_comm_ptr, MPIR_ERR_NONE); + shm_comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); /* empty handler means root fails */ @@ -539,7 +545,7 @@ static int shm_alloc(MPIR_Comm * shm_comm_ptr, MPIDU_shm_seg_t * shm_seg) result_sync: mpi_errno = MPIR_Allreduce_impl(&shm_fail_flag, &any_shm_fail_flag, 1, MPI_C_BOOL, - MPI_LOR, shm_comm_ptr, MPIR_ERR_NONE); + MPI_LOR, shm_comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); if (any_shm_fail_flag) diff --git a/src/util/mpir_nodemap.c b/src/util/mpir_nodemap.c index a3d5fea409a..e8f6d1ee6ba 100644 --- a/src/util/mpir_nodemap.c +++ b/src/util/mpir_nodemap.c @@ -436,14 +436,14 @@ int MPIR_nodeid_init(void) utarray_resize(MPIR_Process.node_hostnames, MPIR_Process.num_nodes, MPL_MEM_OTHER); char *allhostnames = (char *) utarray_eltptr(MPIR_Process.node_hostnames, 0); - if (MPIR_Process.local_rank == 0) { - MPIR_Comm *node_roots_comm = MPIR_Process.comm_world->node_roots_comm; - if (node_roots_comm == NULL) { - /* num_external == comm->remote_size */ - node_roots_comm = MPIR_Process.comm_world; - } + MPIR_Comm *world_comm = MPIR_Process.comm_world; + int local_rank = world_comm->subgroups[MPIR_SUBGROUP_NODE].rank; + int local_size = world_comm->subgroups[MPIR_SUBGROUP_NODE].size; + + if (local_rank == 0) { + int inter_rank = world_comm->subgroups[MPIR_SUBGROUP_NODE_CROSS].rank; - char *my_hostname = allhostnames + MAX_HOSTNAME_LEN * node_roots_comm->rank; + char *my_hostname = allhostnames + MAX_HOSTNAME_LEN * inter_rank; int ret = gethostname(my_hostname, MAX_HOSTNAME_LEN); char strerrbuf[MPIR_STRERROR_BUF_SIZE] ATTRIBUTE((unused)); MPIR_ERR_CHKANDJUMP2(ret == -1, mpi_errno, MPI_ERR_OTHER, @@ -453,14 +453,13 @@ int MPIR_nodeid_init(void) mpi_errno = MPIR_Allgather_impl(MPI_IN_PLACE, MAX_HOSTNAME_LEN, MPI_CHAR, allhostnames, MAX_HOSTNAME_LEN, MPI_CHAR, - node_roots_comm, MPIR_ERR_NONE); + world_comm, MPIR_SUBGROUP_NODE_CROSS, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); } - MPIR_Comm *node_comm = MPIR_Process.comm_world->node_comm; - if (node_comm) { + if (local_size > 1) { mpi_errno = MPIR_Bcast_impl(allhostnames, MAX_HOSTNAME_LEN * MPIR_Process.num_nodes, - MPI_CHAR, 0, node_comm, MPIR_ERR_NONE); + MPI_CHAR, 0, world_comm, MPIR_SUBGROUP_NODE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); }