Skip to content

Commit

Permalink
all: use internal datatypes internally
Browse files Browse the repository at this point in the history
The external builtin datatypes, e.g. MPI_INT, may be reconfigured at
runtime. This won't be the case practically, but it is possibility by
design, so that all MPI builtin datatypes, MPI_INT or MPI_INTEGER, are
treated the same.
  • Loading branch information
hzhou committed Jan 15, 2025
1 parent caac7ab commit 87da97d
Show file tree
Hide file tree
Showing 9 changed files with 37 additions and 30 deletions.
4 changes: 2 additions & 2 deletions src/mpid/ch4/netmod/ofi/init_addrxchg.c
Original file line number Diff line number Diff line change
Expand Up @@ -225,8 +225,8 @@ int MPIDI_OFI_addr_exchange_all_ctx(void)
/* Allgather num_vcis */
MPIR_CHKLMEM_MALLOC(all_num_vcis, void *, sizeof(int) * size,
mpi_errno, "all_num_vcis", MPL_MEM_ADDRESS);
mpi_errno = MPIR_Allgather_fallback(&MPIDI_OFI_global.num_vcis, 1, MPI_INT,
all_num_vcis, 1, MPI_INT, comm, MPIR_ERR_NONE);
mpi_errno = MPIR_Allgather_fallback(&MPIDI_OFI_global.num_vcis, 1, MPI_INT_INTERNAL,
all_num_vcis, 1, MPI_INT_INTERNAL, comm, MPIR_ERR_NONE);
MPIR_ERR_CHECK(mpi_errno);

max_vcis = 0;
Expand Down
5 changes: 3 additions & 2 deletions src/mpid/ch4/netmod/ofi/ofi_comm.c
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,9 @@ static int update_nic_preferences(MPIR_Comm * comm)

/* Collect the NIC IDs set for the other ranks. We always expect to receive a single
* NIC id from each rank, i.e., one MPI_INT. */
mpi_errno = MPIR_Allgather_allcomm_auto(MPI_IN_PLACE, 0, MPI_INT,
pref_nic_copy, 1, MPI_INT, comm, MPIR_ERR_NONE);
mpi_errno = MPIR_Allgather_allcomm_auto(MPI_IN_PLACE, 0, MPI_INT_INTERNAL,
pref_nic_copy, 1, MPI_INT_INTERNAL, comm,
MPIR_ERR_NONE);
MPIR_ERR_CHECK(mpi_errno);

if (MPIDI_OFI_COMM(comm).pref_nic == NULL) {
Expand Down
2 changes: 1 addition & 1 deletion src/mpid/ch4/netmod/ofi/ofi_init.c
Original file line number Diff line number Diff line change
Expand Up @@ -866,7 +866,7 @@ static int check_num_nics(void)
MPIDI_OFI_global.num_vcis = MPIDI_OFI_global.num_nics = 1;

/* Confirm that all processes have the same number of NICs */
mpi_errno = MPIR_Allreduce_allcomm_auto(&tmp_num_nics, &num_nics, 1, MPI_INT,
mpi_errno = MPIR_Allreduce_allcomm_auto(&tmp_num_nics, &num_nics, 1, MPI_INT_INTERNAL,
MPI_MIN, MPIR_Process.comm_world, MPIR_ERR_NONE);
MPIDI_OFI_global.num_vcis = tmp_num_vcis;
MPIDI_OFI_global.num_nics = tmp_num_nics;
Expand Down
4 changes: 2 additions & 2 deletions src/mpid/ch4/netmod/ofi/ofi_win.c
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ static int win_allgather(MPIR_Win * win, void *base, int disp_unit)
}

/* Check if any process fails to register. If so, release local MR and force AM path. */
MPIR_Allreduce(&rc, &allrc, 1, MPI_INT, MPI_MIN, comm_ptr, MPIR_ERR_NONE);
MPIR_Allreduce(&rc, &allrc, 1, MPI_INT_INTERNAL, MPI_MIN, comm_ptr, MPIR_ERR_NONE);
if (allrc < 0) {
if (rc >= 0 && MPIDI_OFI_WIN(win).mr)
MPIDI_OFI_CALL(fi_close(&MPIDI_OFI_WIN(win).mr->fid), fi_close);
Expand Down Expand Up @@ -969,7 +969,7 @@ int MPIDI_OFI_mpi_win_attach_hook(MPIR_Win * win, void *base, MPI_Aint size)
}

/* Check if any process fails to register. If so, release local MR and force AM path. */
MPIR_Allreduce(&rc, &allrc, 1, MPI_INT, MPI_MIN, comm_ptr, MPIR_ERR_NONE);
MPIR_Allreduce(&rc, &allrc, 1, MPI_INT_INTERNAL, MPI_MIN, comm_ptr, MPIR_ERR_NONE);
if (allrc < 0) {
if (rc >= 0)
MPIDI_OFI_CALL(fi_close(&mr->fid), fi_close);
Expand Down
6 changes: 4 additions & 2 deletions src/mpid/ch4/src/ch4_coll_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1769,10 +1769,12 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Scan_intra_composition_alpha(const void *send

if (comm_ptr->node_comm != NULL) {
#ifndef MPIDI_CH4_DIRECT_NETMOD
mpi_errno = MPIDI_SHM_mpi_bcast(&noneed, 1, MPI_INT, 0, comm_ptr->node_comm, errflag);
mpi_errno = MPIDI_SHM_mpi_bcast(&noneed, 1, MPI_INT_INTERNAL, 0, comm_ptr->node_comm,
errflag);
MPIR_ERR_CHECK(mpi_errno);
#else
mpi_errno = MPIDI_NM_mpi_bcast(&noneed, 1, MPI_INT, 0, comm_ptr->node_comm, errflag);
mpi_errno = MPIDI_NM_mpi_bcast(&noneed, 1, MPI_INT_INTERNAL, 0, comm_ptr->node_comm,
errflag);
MPIR_ERR_CHECK(mpi_errno);
#endif /* MPIDI_CH4_DIRECT_NETMOD */
}
Expand Down
20 changes: 10 additions & 10 deletions src/mpid/ch4/src/ch4_comm.c
Original file line number Diff line number Diff line change
Expand Up @@ -447,9 +447,9 @@ int MPID_Intercomm_exchange_map(MPIR_Comm * local_comm, int local_leader, MPIR_C
MPL_DBG_MSG_FMT(MPIDI_CH4_DBG_COMM, VERBOSE,
(MPL_DBG_FDEST, "rank %d sendrecv to rank %d",
peer_comm->rank, remote_leader));
mpi_errno = MPIC_Sendrecv(&local_size_send, 1, MPI_INT,
mpi_errno = MPIC_Sendrecv(&local_size_send, 1, MPI_INT_INTERNAL,
remote_leader, cts_tag,
&remote_size_recv, 1, MPI_INT,
&remote_size_recv, 1, MPI_INT_INTERNAL,
remote_leader, cts_tag, peer_comm, MPI_STATUS_IGNORE,
MPIR_ERR_NONE);
MPIR_ERR_CHECK(mpi_errno);
Expand Down Expand Up @@ -484,9 +484,9 @@ int MPID_Intercomm_exchange_map(MPIR_Comm * local_comm, int local_leader, MPIR_C

mpi_errno = MPIDI_NM_get_local_upids(local_comm, &local_upid_size, &local_upids);
MPIR_ERR_CHECK(mpi_errno);
mpi_errno = MPIC_Sendrecv(local_upid_size, local_size, MPI_INT,
mpi_errno = MPIC_Sendrecv(local_upid_size, local_size, MPI_INT_INTERNAL,
remote_leader, cts_tag,
remote_upid_size, *remote_size, MPI_INT,
remote_upid_size, *remote_size, MPI_INT_INTERNAL,
remote_leader, cts_tag,
peer_comm, MPI_STATUS_IGNORE, MPIR_ERR_NONE);
MPIR_ERR_CHECK(mpi_errno);
Expand Down Expand Up @@ -621,12 +621,12 @@ int MPIDIU_Intercomm_map_bcast_intra(MPIR_Comm * local_comm, int local_leader, i
map_info[1] = upid_recv_size;
map_info[2] = *is_low_group;
map_info[3] = pure_intracomm;
mpi_errno =
MPIR_Bcast_allcomm_auto(map_info, 4, MPI_INT, local_leader, local_comm, MPIR_ERR_NONE);
mpi_errno = MPIR_Bcast_allcomm_auto(map_info, 4, MPI_INT_INTERNAL,
local_leader, local_comm, MPIR_ERR_NONE);
MPIR_ERR_CHECK(mpi_errno);

if (!pure_intracomm) {
mpi_errno = MPIR_Bcast_allcomm_auto(remote_upid_size, *remote_size, MPI_INT,
mpi_errno = MPIR_Bcast_allcomm_auto(remote_upid_size, *remote_size, MPI_INT_INTERNAL,
local_leader, local_comm, MPIR_ERR_NONE);
MPIR_ERR_CHECK(mpi_errno);
mpi_errno = MPIR_Bcast_allcomm_auto(remote_upids, upid_recv_size, MPI_BYTE,
Expand All @@ -637,8 +637,8 @@ int MPIDIU_Intercomm_map_bcast_intra(MPIR_Comm * local_comm, int local_leader, i
local_leader, local_comm, MPIR_ERR_NONE);
}
} else {
mpi_errno =
MPIR_Bcast_allcomm_auto(map_info, 4, MPI_INT, local_leader, local_comm, MPIR_ERR_NONE);
mpi_errno = MPIR_Bcast_allcomm_auto(map_info, 4, MPI_INT_INTERNAL,
local_leader, local_comm, MPIR_ERR_NONE);
MPIR_ERR_CHECK(mpi_errno);
*remote_size = map_info[0];
upid_recv_size = map_info[1];
Expand All @@ -650,7 +650,7 @@ int MPIDIU_Intercomm_map_bcast_intra(MPIR_Comm * local_comm, int local_leader, i
if (!pure_intracomm) {
MPIR_CHKLMEM_MALLOC(_remote_upid_size, int *, (*remote_size) * sizeof(int),
mpi_errno, "_remote_upid_size", MPL_MEM_COMM);
mpi_errno = MPIR_Bcast_allcomm_auto(_remote_upid_size, *remote_size, MPI_INT,
mpi_errno = MPIR_Bcast_allcomm_auto(_remote_upid_size, *remote_size, MPI_INT_INTERNAL,
local_leader, local_comm, MPIR_ERR_NONE);
MPIR_ERR_CHECK(mpi_errno);
MPIR_CHKLMEM_MALLOC(_remote_upids, char *, upid_recv_size * sizeof(char),
Expand Down
4 changes: 2 additions & 2 deletions src/mpid/ch4/src/ch4_init.c
Original file line number Diff line number Diff line change
Expand Up @@ -702,8 +702,8 @@ int MPIDI_world_post_init(void)
MPIDI_global.n_reserved_vcis -= diff;
}

mpi_errno = MPIR_Allgather_fallback(&MPIDI_global.n_vcis, 1, MPI_INT,
MPIDI_global.all_num_vcis, 1, MPI_INT,
mpi_errno = MPIR_Allgather_fallback(&MPIDI_global.n_vcis, 1, MPI_INT_INTERNAL,
MPIDI_global.all_num_vcis, 1, MPI_INT_INTERNAL,
MPIR_Process.comm_world, MPIR_ERR_NONE);
MPIR_ERR_CHECK(mpi_errno);
#endif
Expand Down
12 changes: 7 additions & 5 deletions src/mpid/ch4/src/ch4_spawn.c
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ int MPID_Comm_spawn_multiple(int count, char *commands[], char **argvs[], const
bcast_ints[0] = total_num_processes;
bcast_ints[1] = mpi_errno;
}
mpi_errno = MPIR_Bcast(bcast_ints, 2, MPI_INT, root, comm_ptr, MPIR_ERR_NONE);
mpi_errno = MPIR_Bcast(bcast_ints, 2, MPI_INT_INTERNAL, root, comm_ptr, MPIR_ERR_NONE);
MPIR_ERR_CHECK(mpi_errno);
if (comm_ptr->rank != root) {
total_num_processes = bcast_ints[0];
Expand All @@ -89,8 +89,8 @@ int MPID_Comm_spawn_multiple(int count, char *commands[], char **argvs[], const

int should_accept = 1;
if (errcodes != MPI_ERRCODES_IGNORE) {
mpi_errno =
MPIR_Bcast(pmi_errcodes, total_num_processes, MPI_INT, root, comm_ptr, MPIR_ERR_NONE);
mpi_errno = MPIR_Bcast(pmi_errcodes, total_num_processes, MPI_INT_INTERNAL,
root, comm_ptr, MPIR_ERR_NONE);
MPIR_ERR_CHECK(mpi_errno);

for (int i = 0; i < total_num_processes; i++) {
Expand Down Expand Up @@ -391,12 +391,14 @@ static int dynamic_intercomm_create(const char *port_name, MPIR_Info * info, int
bcast_tag_and_errno:
bcast_ints[0] = tag;
bcast_ints[1] = mpi_errno;
mpi_errno = MPIR_Bcast_allcomm_auto(bcast_ints, 2, MPI_INT, root, comm_ptr, MPIR_ERR_NONE);
mpi_errno = MPIR_Bcast_allcomm_auto(bcast_ints, 2, MPI_INT_INTERNAL, root, comm_ptr,
MPIR_ERR_NONE);
MPIR_ERR_CHECK(mpi_errno);
mpi_errno = bcast_ints[1];
MPIR_ERR_CHECK(mpi_errno);
} else {
mpi_errno = MPIR_Bcast_allcomm_auto(bcast_ints, 2, MPI_INT, root, comm_ptr, MPIR_ERR_NONE);
mpi_errno = MPIR_Bcast_allcomm_auto(bcast_ints, 2, MPI_INT_INTERNAL, root, comm_ptr,
MPIR_ERR_NONE);
MPIR_ERR_CHECK(mpi_errno);
if (bcast_ints[1]) {
/* errno from root cannot be directly returned */
Expand Down
10 changes: 6 additions & 4 deletions src/mpid/common/shm/mpidu_shm_alloc.c
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,8 @@ static int map_symm_shm(MPIR_Comm * shm_comm_ptr, MPIDU_shm_seg_t * shm_seg, int

root_sync:
/* broadcast the mapping result on rank 0 */
mpi_errno = MPIR_Bcast(map_result_ptr, 1, MPI_INT, 0, shm_comm_ptr, MPIR_ERR_NONE);
mpi_errno = MPIR_Bcast(map_result_ptr, 1, MPI_INT_INTERNAL, 0, shm_comm_ptr,
MPIR_ERR_NONE);
MPIR_ERR_CHECK(mpi_errno);

if (*map_result_ptr != SYMSHM_SUCCESS)
Expand All @@ -296,7 +297,8 @@ static int map_symm_shm(MPIR_Comm * shm_comm_ptr, MPIDU_shm_seg_t * shm_seg, int
char serialized_hnd[MPL_SHM_GHND_SZ] = { 0 };

/* receive the mapping result of rank 0 */
mpi_errno = MPIR_Bcast(map_result_ptr, 1, MPI_INT, 0, shm_comm_ptr, MPIR_ERR_NONE);
mpi_errno = MPIR_Bcast(map_result_ptr, 1, MPI_INT_INTERNAL, 0, shm_comm_ptr,
MPIR_ERR_NONE);
MPIR_ERR_CHECK(mpi_errno);

if (*map_result_ptr != SYMSHM_SUCCESS)
Expand Down Expand Up @@ -330,7 +332,7 @@ static int map_symm_shm(MPIR_Comm * shm_comm_ptr, MPIDU_shm_seg_t * shm_seg, int
/* check results of all processes. If any failure happens (max result > 0),
* return SYMSHM_OTHER_FAIL if anyone reports it (max result == 2).
* Otherwise return SYMSHM_MAP_FAIL (max result == 1). */
mpi_errno = MPIR_Allreduce(map_result_ptr, &all_map_result, 1, MPI_INT,
mpi_errno = MPIR_Allreduce(map_result_ptr, &all_map_result, 1, MPI_INT_INTERNAL,
MPI_MAX, shm_comm_ptr, MPIR_ERR_NONE);
MPIR_ERR_CHECK(mpi_errno);

Expand Down Expand Up @@ -441,7 +443,7 @@ static int shm_alloc_symm_all(MPIR_Comm * comm_ptr, size_t offset, MPIDU_shm_seg
MPIR_ERR_CHECK(mpi_errno);

/* check if any mapping failure occurs */
mpi_errno = MPIR_Allreduce(&map_result, &all_map_result, 1, MPI_INT,
mpi_errno = MPIR_Allreduce(&map_result, &all_map_result, 1, MPI_INT_INTERNAL,
MPI_MAX, comm_ptr, MPIR_ERR_NONE);
MPIR_ERR_CHECK(mpi_errno);

Expand Down

0 comments on commit 87da97d

Please sign in to comment.