From 0749914220c15d32f3ab2d134b6ce1fda71dd74d Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Mon, 12 Aug 2024 17:27:49 -0500 Subject: [PATCH] coll: MPIR_THREADCOMM_RANK_SIZE to check coll_attr Enhance the macro MPIR_THREADCOMM_RANK_SIZE to check coll_attr for rank and size. --- src/include/mpir_threadcomm.h | 11 +++++------ src/mpi/coll/allgather/allgather_intra_brucks.c | 2 +- src/mpi/coll/allgatherv/allgatherv_intra_brucks.c | 2 +- .../allreduce/allreduce_intra_recursive_doubling.c | 2 +- src/mpi/coll/alltoall/alltoall_intra_brucks.c | 2 +- src/mpi/coll/alltoallv/alltoallv_intra_scattered.c | 2 +- src/mpi/coll/alltoallw/alltoallw_intra_scattered.c | 2 +- src/mpi/coll/barrier/barrier_intra_k_dissemination.c | 2 +- src/mpi/coll/bcast/bcast_intra_binomial.c | 2 +- src/mpi/coll/exscan/exscan_intra_recursive_doubling.c | 2 +- src/mpi/coll/gather/gather_intra_binomial.c | 2 +- src/mpi/coll/gatherv/gatherv_allcomm_linear.c | 2 +- src/mpi/coll/reduce/reduce_intra_binomial.c | 2 +- .../reduce_scatter_intra_recursive_halving.c | 2 +- .../reduce_scatter_block_intra_recursive_halving.c | 2 +- src/mpi/coll/scan/scan_intra_recursive_doubling.c | 2 +- src/mpi/coll/scatter/scatter_intra_binomial.c | 2 +- src/mpi/coll/scatterv/scatterv_allcomm_linear.c | 2 +- 18 files changed, 22 insertions(+), 23 deletions(-) diff --git a/src/include/mpir_threadcomm.h b/src/include/mpir_threadcomm.h index cda298f1f9e..02c974795ef 100644 --- a/src/include/mpir_threadcomm.h +++ b/src/include/mpir_threadcomm.h @@ -111,23 +111,22 @@ MPL_STATIC_INLINE_PREFIX } #ifdef ENABLE_THREADCOMM -#define MPIR_THREADCOMM_RANK_SIZE(comm, rank_, size_) do { \ +#define MPIR_THREADCOMM_RANK_SIZE(comm, coll_attr, rank_, size_) do { \ MPIR_Threadcomm *threadcomm = (comm)->threadcomm; \ if (threadcomm) { \ + MPIR_Assert(MPIR_COLL_ATTR_GET_SUBGROUP(coll_attr) == 0); /* for now */ \ int intracomm_size = (comm)->local_size; \ size_ = threadcomm->rank_offset_table[intracomm_size - 1]; \ rank_ = MPIR_THREADCOMM_TID_TO_RANK(threadcomm, MPIR_threadcomm_get_tid(threadcomm)); \ } else { \ - rank_ = (comm)->rank; \ - size_ = (comm)->local_size; \ + MPIR_COLL_INTRA_RANK_SIZE(comm, coll_attr, rank_, size_); \ } \ } while (0) #else -#define MPIR_THREADCOMM_RANK_SIZE(comm, rank_, size_) do { \ +#define MPIR_THREADCOMM_RANK_SIZE(comm, coll_attr, rank_, size_) do { \ MPIR_Assert((comm)->threadcomm == NULL); \ - rank_ = (comm)->rank; \ - size_ = (comm)->local_size; \ + MPIR_COLL_INTRA_RANK_SIZE(comm, coll_attr, rank_, size_); \ } while (0) #endif diff --git a/src/mpi/coll/allgather/allgather_intra_brucks.c b/src/mpi/coll/allgather/allgather_intra_brucks.c index e4f4bc85537..d96647a72fc 100644 --- a/src/mpi/coll/allgather/allgather_intra_brucks.c +++ b/src/mpi/coll/allgather/allgather_intra_brucks.c @@ -34,7 +34,7 @@ int MPIR_Allgather_intra_brucks(const void *sendbuf, if (((sendcount == 0) && (sendbuf != MPI_IN_PLACE)) || (recvcount == 0)) goto fn_exit; - MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size); + MPIR_THREADCOMM_RANK_SIZE(comm_ptr, coll_attr, rank, comm_size); MPIR_Datatype_get_extent_macro(recvtype, recvtype_extent); MPIR_Datatype_get_size_macro(recvtype, recvtype_sz); diff --git a/src/mpi/coll/allgatherv/allgatherv_intra_brucks.c b/src/mpi/coll/allgatherv/allgatherv_intra_brucks.c index f9a1f85de57..c61b2d54905 100644 --- a/src/mpi/coll/allgatherv/allgatherv_intra_brucks.c +++ b/src/mpi/coll/allgatherv/allgatherv_intra_brucks.c @@ -34,7 +34,7 @@ int MPIR_Allgatherv_intra_brucks(const void *sendbuf, void *tmp_buf; MPIR_CHKLMEM_DECL(1); - MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size); + MPIR_THREADCOMM_RANK_SIZE(comm_ptr, coll_attr, rank, comm_size); total_count = 0; for (i = 0; i < comm_size; i++) diff --git a/src/mpi/coll/allreduce/allreduce_intra_recursive_doubling.c b/src/mpi/coll/allreduce/allreduce_intra_recursive_doubling.c index 7cb1d6461e5..5679ee1155d 100644 --- a/src/mpi/coll/allreduce/allreduce_intra_recursive_doubling.c +++ b/src/mpi/coll/allreduce/allreduce_intra_recursive_doubling.c @@ -32,7 +32,7 @@ int MPIR_Allreduce_intra_recursive_doubling(const void *sendbuf, MPI_Aint true_extent, true_lb, extent; void *tmp_buf; - MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size); + MPIR_THREADCOMM_RANK_SIZE(comm_ptr, coll_attr, rank, comm_size); is_commutative = MPIR_Op_is_commutative(op); diff --git a/src/mpi/coll/alltoall/alltoall_intra_brucks.c b/src/mpi/coll/alltoall/alltoall_intra_brucks.c index 815ae7ced9c..e5eeeaba9ac 100644 --- a/src/mpi/coll/alltoall/alltoall_intra_brucks.c +++ b/src/mpi/coll/alltoall/alltoall_intra_brucks.c @@ -37,7 +37,7 @@ int MPIR_Alltoall_intra_brucks(const void *sendbuf, void *tmp_buf; MPIR_CHKLMEM_DECL(6); - MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size); + MPIR_THREADCOMM_RANK_SIZE(comm_ptr, coll_attr, rank, comm_size); #ifdef HAVE_ERROR_CHECKING MPIR_Assert(sendbuf != MPI_IN_PLACE); diff --git a/src/mpi/coll/alltoallv/alltoallv_intra_scattered.c b/src/mpi/coll/alltoallv/alltoallv_intra_scattered.c index 475caf0bb5a..c7b2d0d4cbf 100644 --- a/src/mpi/coll/alltoallv/alltoallv_intra_scattered.c +++ b/src/mpi/coll/alltoallv/alltoallv_intra_scattered.c @@ -37,7 +37,7 @@ int MPIR_Alltoallv_intra_scattered(const void *sendbuf, const MPI_Aint * sendcou MPIR_CHKLMEM_DECL(2); - MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size); + MPIR_THREADCOMM_RANK_SIZE(comm_ptr, coll_attr, rank, comm_size); /* Get extent of recv type, but send type is only valid if (sendbuf!=MPI_IN_PLACE) */ MPIR_Datatype_get_extent_macro(recvtype, recv_extent); diff --git a/src/mpi/coll/alltoallw/alltoallw_intra_scattered.c b/src/mpi/coll/alltoallw/alltoallw_intra_scattered.c index 77443fb8a8b..4233ccc6c42 100644 --- a/src/mpi/coll/alltoallw/alltoallw_intra_scattered.c +++ b/src/mpi/coll/alltoallw/alltoallw_intra_scattered.c @@ -36,7 +36,7 @@ int MPIR_Alltoallw_intra_scattered(const void *sendbuf, const MPI_Aint sendcount MPI_Aint type_size; MPIR_CHKLMEM_DECL(2); - MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size); + MPIR_THREADCOMM_RANK_SIZE(comm_ptr, coll_attr, rank, comm_size); #ifdef HAVE_ERROR_CHECKING /* When MPI_IN_PLACE, we use pair-wise sendrecv_replace in order to conserve memory usage, diff --git a/src/mpi/coll/barrier/barrier_intra_k_dissemination.c b/src/mpi/coll/barrier/barrier_intra_k_dissemination.c index 859c3c961de..2ef13ce7894 100644 --- a/src/mpi/coll/barrier/barrier_intra_k_dissemination.c +++ b/src/mpi/coll/barrier/barrier_intra_k_dissemination.c @@ -21,7 +21,7 @@ int MPIR_Barrier_intra_dissemination(MPIR_Comm * comm_ptr, int coll_attr) int size, rank, src, dst, mask, mpi_errno = MPI_SUCCESS; int mpi_errno_ret = MPI_SUCCESS; - MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, size); + MPIR_THREADCOMM_RANK_SIZE(comm_ptr, coll_attr, rank, size); mask = 0x1; while (mask < size) { diff --git a/src/mpi/coll/bcast/bcast_intra_binomial.c b/src/mpi/coll/bcast/bcast_intra_binomial.c index 01e780dc0e7..ba46680963f 100644 --- a/src/mpi/coll/bcast/bcast_intra_binomial.c +++ b/src/mpi/coll/bcast/bcast_intra_binomial.c @@ -32,7 +32,7 @@ int MPIR_Bcast_intra_binomial(void *buffer, void *tmp_buf = NULL; MPIR_CHKLMEM_DECL(1); - MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size); + MPIR_THREADCOMM_RANK_SIZE(comm_ptr, coll_attr, rank, comm_size); if (HANDLE_IS_BUILTIN(datatype)) is_contig = 1; diff --git a/src/mpi/coll/exscan/exscan_intra_recursive_doubling.c b/src/mpi/coll/exscan/exscan_intra_recursive_doubling.c index 4aa11821e2a..517146c7e25 100644 --- a/src/mpi/coll/exscan/exscan_intra_recursive_doubling.c +++ b/src/mpi/coll/exscan/exscan_intra_recursive_doubling.c @@ -59,7 +59,7 @@ int MPIR_Exscan_intra_recursive_doubling(const void *sendbuf, void *partial_scan, *tmp_buf; MPIR_CHKLMEM_DECL(2); - MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size); + MPIR_THREADCOMM_RANK_SIZE(comm_ptr, coll_attr, rank, comm_size); is_commutative = MPIR_Op_is_commutative(op); diff --git a/src/mpi/coll/gather/gather_intra_binomial.c b/src/mpi/coll/gather/gather_intra_binomial.c index fd9bbab5297..aa9b8a07a10 100644 --- a/src/mpi/coll/gather/gather_intra_binomial.c +++ b/src/mpi/coll/gather/gather_intra_binomial.c @@ -58,7 +58,7 @@ int MPIR_Gather_intra_binomial(const void *sendbuf, MPI_Aint sendcount, MPI_Data MPIR_CHKLMEM_DECL(1); - MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size); + MPIR_THREADCOMM_RANK_SIZE(comm_ptr, coll_attr, rank, comm_size); /* Use binomial tree algorithm. */ diff --git a/src/mpi/coll/gatherv/gatherv_allcomm_linear.c b/src/mpi/coll/gatherv/gatherv_allcomm_linear.c index 0a3b3fdd1ab..7fe5cc13455 100644 --- a/src/mpi/coll/gatherv/gatherv_allcomm_linear.c +++ b/src/mpi/coll/gatherv/gatherv_allcomm_linear.c @@ -54,7 +54,7 @@ int MPIR_Gatherv_allcomm_linear(const void *sendbuf, MPI_Status *starray; MPIR_CHKLMEM_DECL(2); - MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size); + MPIR_THREADCOMM_RANK_SIZE(comm_ptr, coll_attr, rank, comm_size); /* If rank == root, then I recv lots, otherwise I send */ if (((comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM) && (root == rank)) || diff --git a/src/mpi/coll/reduce/reduce_intra_binomial.c b/src/mpi/coll/reduce/reduce_intra_binomial.c index 9c10e483115..d29310d95b6 100644 --- a/src/mpi/coll/reduce/reduce_intra_binomial.c +++ b/src/mpi/coll/reduce/reduce_intra_binomial.c @@ -24,7 +24,7 @@ int MPIR_Reduce_intra_binomial(const void *sendbuf, void *tmp_buf; MPIR_CHKLMEM_DECL(2); - MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size); + MPIR_THREADCOMM_RANK_SIZE(comm_ptr, coll_attr, rank, comm_size); /* Create a temporary buffer */ diff --git a/src/mpi/coll/reduce_scatter/reduce_scatter_intra_recursive_halving.c b/src/mpi/coll/reduce_scatter/reduce_scatter_intra_recursive_halving.c index 956ab8cfdfd..fdf7711939e 100644 --- a/src/mpi/coll/reduce_scatter/reduce_scatter_intra_recursive_halving.c +++ b/src/mpi/coll/reduce_scatter/reduce_scatter_intra_recursive_halving.c @@ -50,7 +50,7 @@ int MPIR_Reduce_scatter_intra_recursive_halving(const void *sendbuf, void *recvb int pof2, old_i, newrank; MPIR_CHKLMEM_DECL(5); - MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size); + MPIR_THREADCOMM_RANK_SIZE(comm_ptr, coll_attr, rank, comm_size); #ifdef HAVE_ERROR_CHECKING { diff --git a/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_recursive_halving.c b/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_recursive_halving.c index 1669f1b5657..b7a1918528b 100644 --- a/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_recursive_halving.c +++ b/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_recursive_halving.c @@ -53,7 +53,7 @@ int MPIR_Reduce_scatter_block_intra_recursive_halving(const void *sendbuf, int pof2, old_i, newrank; MPIR_CHKLMEM_DECL(5); - MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size); + MPIR_THREADCOMM_RANK_SIZE(comm_ptr, coll_attr, rank, comm_size); #ifdef HAVE_ERROR_CHECKING { diff --git a/src/mpi/coll/scan/scan_intra_recursive_doubling.c b/src/mpi/coll/scan/scan_intra_recursive_doubling.c index 55064face49..5d4c9cbf111 100644 --- a/src/mpi/coll/scan/scan_intra_recursive_doubling.c +++ b/src/mpi/coll/scan/scan_intra_recursive_doubling.c @@ -55,7 +55,7 @@ int MPIR_Scan_intra_recursive_doubling(const void *sendbuf, void *partial_scan, *tmp_buf; MPIR_CHKLMEM_DECL(2); - MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size); + MPIR_THREADCOMM_RANK_SIZE(comm_ptr, coll_attr, rank, comm_size); is_commutative = MPIR_Op_is_commutative(op); diff --git a/src/mpi/coll/scatter/scatter_intra_binomial.c b/src/mpi/coll/scatter/scatter_intra_binomial.c index de95ab55dd6..940df5de1f7 100644 --- a/src/mpi/coll/scatter/scatter_intra_binomial.c +++ b/src/mpi/coll/scatter/scatter_intra_binomial.c @@ -42,7 +42,7 @@ int MPIR_Scatter_intra_binomial(const void *sendbuf, MPI_Aint sendcount, MPI_Dat int mpi_errno_ret = MPI_SUCCESS; MPIR_CHKLMEM_DECL(4); - MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size); + MPIR_THREADCOMM_RANK_SIZE(comm_ptr, coll_attr, rank, comm_size); if (rank == root) MPIR_Datatype_get_extent_macro(sendtype, extent); diff --git a/src/mpi/coll/scatterv/scatterv_allcomm_linear.c b/src/mpi/coll/scatterv/scatterv_allcomm_linear.c index b0ca7f3184e..6031f21b87d 100644 --- a/src/mpi/coll/scatterv/scatterv_allcomm_linear.c +++ b/src/mpi/coll/scatterv/scatterv_allcomm_linear.c @@ -30,7 +30,7 @@ int MPIR_Scatterv_allcomm_linear(const void *sendbuf, const MPI_Aint * sendcount MPI_Status *starray; MPIR_CHKLMEM_DECL(2); - MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size); + MPIR_THREADCOMM_RANK_SIZE(comm_ptr, coll_attr, rank, comm_size); /* If I'm the root, then scatter */ if (((comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM) && (root == rank)) ||