Skip to content

Commit

Permalink
coll: MPIR_THREADCOMM_RANK_SIZE to check coll_attr
Browse files Browse the repository at this point in the history
Enhance the macro MPIR_THREADCOMM_RANK_SIZE to check coll_attr for rank
and size.
  • Loading branch information
hzhou committed Aug 12, 2024
1 parent 166dbb7 commit 0749914
Show file tree
Hide file tree
Showing 18 changed files with 22 additions and 23 deletions.
11 changes: 5 additions & 6 deletions src/include/mpir_threadcomm.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,23 +111,22 @@ MPL_STATIC_INLINE_PREFIX
}

#ifdef ENABLE_THREADCOMM
#define MPIR_THREADCOMM_RANK_SIZE(comm, rank_, size_) do { \
#define MPIR_THREADCOMM_RANK_SIZE(comm, coll_attr, rank_, size_) do { \
MPIR_Threadcomm *threadcomm = (comm)->threadcomm; \
if (threadcomm) { \
MPIR_Assert(MPIR_COLL_ATTR_GET_SUBGROUP(coll_attr) == 0); /* for now */ \
int intracomm_size = (comm)->local_size; \
size_ = threadcomm->rank_offset_table[intracomm_size - 1]; \
rank_ = MPIR_THREADCOMM_TID_TO_RANK(threadcomm, MPIR_threadcomm_get_tid(threadcomm)); \
} else { \
rank_ = (comm)->rank; \
size_ = (comm)->local_size; \
MPIR_COLL_INTRA_RANK_SIZE(comm, coll_attr, rank_, size_); \
} \
} while (0)

#else
#define MPIR_THREADCOMM_RANK_SIZE(comm, rank_, size_) do { \
#define MPIR_THREADCOMM_RANK_SIZE(comm, coll_attr, rank_, size_) do { \
MPIR_Assert((comm)->threadcomm == NULL); \
rank_ = (comm)->rank; \
size_ = (comm)->local_size; \
MPIR_COLL_INTRA_RANK_SIZE(comm, coll_attr, rank_, size_); \
} while (0)

#endif
Expand Down
2 changes: 1 addition & 1 deletion src/mpi/coll/allgather/allgather_intra_brucks.c
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ int MPIR_Allgather_intra_brucks(const void *sendbuf,
if (((sendcount == 0) && (sendbuf != MPI_IN_PLACE)) || (recvcount == 0))
goto fn_exit;

MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size);
MPIR_THREADCOMM_RANK_SIZE(comm_ptr, coll_attr, rank, comm_size);

MPIR_Datatype_get_extent_macro(recvtype, recvtype_extent);
MPIR_Datatype_get_size_macro(recvtype, recvtype_sz);
Expand Down
2 changes: 1 addition & 1 deletion src/mpi/coll/allgatherv/allgatherv_intra_brucks.c
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ int MPIR_Allgatherv_intra_brucks(const void *sendbuf,
void *tmp_buf;
MPIR_CHKLMEM_DECL(1);

MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size);
MPIR_THREADCOMM_RANK_SIZE(comm_ptr, coll_attr, rank, comm_size);

total_count = 0;
for (i = 0; i < comm_size; i++)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ int MPIR_Allreduce_intra_recursive_doubling(const void *sendbuf,
MPI_Aint true_extent, true_lb, extent;
void *tmp_buf;

MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size);
MPIR_THREADCOMM_RANK_SIZE(comm_ptr, coll_attr, rank, comm_size);

is_commutative = MPIR_Op_is_commutative(op);

Expand Down
2 changes: 1 addition & 1 deletion src/mpi/coll/alltoall/alltoall_intra_brucks.c
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ int MPIR_Alltoall_intra_brucks(const void *sendbuf,
void *tmp_buf;
MPIR_CHKLMEM_DECL(6);

MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size);
MPIR_THREADCOMM_RANK_SIZE(comm_ptr, coll_attr, rank, comm_size);

#ifdef HAVE_ERROR_CHECKING
MPIR_Assert(sendbuf != MPI_IN_PLACE);
Expand Down
2 changes: 1 addition & 1 deletion src/mpi/coll/alltoallv/alltoallv_intra_scattered.c
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ int MPIR_Alltoallv_intra_scattered(const void *sendbuf, const MPI_Aint * sendcou

MPIR_CHKLMEM_DECL(2);

MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size);
MPIR_THREADCOMM_RANK_SIZE(comm_ptr, coll_attr, rank, comm_size);

/* Get extent of recv type, but send type is only valid if (sendbuf!=MPI_IN_PLACE) */
MPIR_Datatype_get_extent_macro(recvtype, recv_extent);
Expand Down
2 changes: 1 addition & 1 deletion src/mpi/coll/alltoallw/alltoallw_intra_scattered.c
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ int MPIR_Alltoallw_intra_scattered(const void *sendbuf, const MPI_Aint sendcount
MPI_Aint type_size;
MPIR_CHKLMEM_DECL(2);

MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size);
MPIR_THREADCOMM_RANK_SIZE(comm_ptr, coll_attr, rank, comm_size);

#ifdef HAVE_ERROR_CHECKING
/* When MPI_IN_PLACE, we use pair-wise sendrecv_replace in order to conserve memory usage,
Expand Down
2 changes: 1 addition & 1 deletion src/mpi/coll/barrier/barrier_intra_k_dissemination.c
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ int MPIR_Barrier_intra_dissemination(MPIR_Comm * comm_ptr, int coll_attr)
int size, rank, src, dst, mask, mpi_errno = MPI_SUCCESS;
int mpi_errno_ret = MPI_SUCCESS;

MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, size);
MPIR_THREADCOMM_RANK_SIZE(comm_ptr, coll_attr, rank, size);

mask = 0x1;
while (mask < size) {
Expand Down
2 changes: 1 addition & 1 deletion src/mpi/coll/bcast/bcast_intra_binomial.c
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ int MPIR_Bcast_intra_binomial(void *buffer,
void *tmp_buf = NULL;
MPIR_CHKLMEM_DECL(1);

MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size);
MPIR_THREADCOMM_RANK_SIZE(comm_ptr, coll_attr, rank, comm_size);

if (HANDLE_IS_BUILTIN(datatype))
is_contig = 1;
Expand Down
2 changes: 1 addition & 1 deletion src/mpi/coll/exscan/exscan_intra_recursive_doubling.c
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ int MPIR_Exscan_intra_recursive_doubling(const void *sendbuf,
void *partial_scan, *tmp_buf;
MPIR_CHKLMEM_DECL(2);

MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size);
MPIR_THREADCOMM_RANK_SIZE(comm_ptr, coll_attr, rank, comm_size);

is_commutative = MPIR_Op_is_commutative(op);

Expand Down
2 changes: 1 addition & 1 deletion src/mpi/coll/gather/gather_intra_binomial.c
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ int MPIR_Gather_intra_binomial(const void *sendbuf, MPI_Aint sendcount, MPI_Data
MPIR_CHKLMEM_DECL(1);


MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size);
MPIR_THREADCOMM_RANK_SIZE(comm_ptr, coll_attr, rank, comm_size);

/* Use binomial tree algorithm. */

Expand Down
2 changes: 1 addition & 1 deletion src/mpi/coll/gatherv/gatherv_allcomm_linear.c
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ int MPIR_Gatherv_allcomm_linear(const void *sendbuf,
MPI_Status *starray;
MPIR_CHKLMEM_DECL(2);

MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size);
MPIR_THREADCOMM_RANK_SIZE(comm_ptr, coll_attr, rank, comm_size);

/* If rank == root, then I recv lots, otherwise I send */
if (((comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM) && (root == rank)) ||
Expand Down
2 changes: 1 addition & 1 deletion src/mpi/coll/reduce/reduce_intra_binomial.c
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ int MPIR_Reduce_intra_binomial(const void *sendbuf,
void *tmp_buf;
MPIR_CHKLMEM_DECL(2);

MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size);
MPIR_THREADCOMM_RANK_SIZE(comm_ptr, coll_attr, rank, comm_size);

/* Create a temporary buffer */

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ int MPIR_Reduce_scatter_intra_recursive_halving(const void *sendbuf, void *recvb
int pof2, old_i, newrank;
MPIR_CHKLMEM_DECL(5);

MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size);
MPIR_THREADCOMM_RANK_SIZE(comm_ptr, coll_attr, rank, comm_size);

#ifdef HAVE_ERROR_CHECKING
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ int MPIR_Reduce_scatter_block_intra_recursive_halving(const void *sendbuf,
int pof2, old_i, newrank;
MPIR_CHKLMEM_DECL(5);

MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size);
MPIR_THREADCOMM_RANK_SIZE(comm_ptr, coll_attr, rank, comm_size);

#ifdef HAVE_ERROR_CHECKING
{
Expand Down
2 changes: 1 addition & 1 deletion src/mpi/coll/scan/scan_intra_recursive_doubling.c
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ int MPIR_Scan_intra_recursive_doubling(const void *sendbuf,
void *partial_scan, *tmp_buf;
MPIR_CHKLMEM_DECL(2);

MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size);
MPIR_THREADCOMM_RANK_SIZE(comm_ptr, coll_attr, rank, comm_size);

is_commutative = MPIR_Op_is_commutative(op);

Expand Down
2 changes: 1 addition & 1 deletion src/mpi/coll/scatter/scatter_intra_binomial.c
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ int MPIR_Scatter_intra_binomial(const void *sendbuf, MPI_Aint sendcount, MPI_Dat
int mpi_errno_ret = MPI_SUCCESS;
MPIR_CHKLMEM_DECL(4);

MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size);
MPIR_THREADCOMM_RANK_SIZE(comm_ptr, coll_attr, rank, comm_size);

if (rank == root)
MPIR_Datatype_get_extent_macro(sendtype, extent);
Expand Down
2 changes: 1 addition & 1 deletion src/mpi/coll/scatterv/scatterv_allcomm_linear.c
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ int MPIR_Scatterv_allcomm_linear(const void *sendbuf, const MPI_Aint * sendcount
MPI_Status *starray;
MPIR_CHKLMEM_DECL(2);

MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size);
MPIR_THREADCOMM_RANK_SIZE(comm_ptr, coll_attr, rank, comm_size);

/* If I'm the root, then scatter */
if (((comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM) && (root == rank)) ||
Expand Down

0 comments on commit 0749914

Please sign in to comment.