Skip to content

Commit

Permalink
util: fix defects when gathering coordinates
Browse files Browse the repository at this point in the history
This commit addresses the following issues:
- Resource leaks in pmix_init()
- Error handling issues in parse_coord_file()
- Insecure data handling in parse_coord_file()
  • Loading branch information
dycz0fx committed Dec 7, 2023
1 parent f131256 commit 00f5183
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 18 deletions.
24 changes: 14 additions & 10 deletions src/util/mpir_pmi.c
Original file line number Diff line number Diff line change
Expand Up @@ -855,7 +855,7 @@ static void free_pmi_keyvals(INFO_TYPE ** kv, int size, int *counts)
static int parse_coord_file(const char *filename)
{
int mpi_errno = MPI_SUCCESS;
int i, j, rank;
int i, j, rank, coords_dims;
FILE *coords_file;
int fields_scanned;

Expand All @@ -864,31 +864,36 @@ static int parse_coord_file(const char *filename)
"**filenoexist", "**filenoexist %s", filename);

/* Skip the first line */
fields_scanned = fscanf(coords_file, "%*[^\n]\n");
MPIR_Process.coords_dims = 0;
int fields_scanned = fscanf(coords_file, "%*[^\n]\n");
MPIR_ERR_CHKANDSTMT2(0 != fields_scanned, mpi_errno, MPI_ERR_FILE, goto fn_fail_read,
"**read_file", "**read_file %s %s", filename, strerror(errno));
coords_dims = 0;
fields_scanned = fscanf(coords_file, "%d:", &rank);
MPIR_ERR_CHKANDSTMT2(1 != fields_scanned, mpi_errno, MPI_ERR_FILE, goto fn_fail_read,
"**read_file", "**read_file %s %s", filename, strerror(errno));
while (!feof(coords_file)) {
int temp = 0;
if (fscanf(coords_file, "%d", &temp) == 1)
++MPIR_Process.coords_dims;
++coords_dims;
else
break;
if (fgetc(coords_file) == '\n')
break;
}

MPIR_Assert(MPIR_Process.coords_dims == 3);
MPIR_Assert(coords_dims == 3);
MPIR_Process.coords_dims = coords_dims;
rewind(coords_file);
/* Skip the first line */
fields_scanned = fscanf(coords_file, "%*[^\n]\n");
MPIR_ERR_CHKANDSTMT2(0 != fields_scanned, mpi_errno, MPI_ERR_FILE, goto fn_fail_read,
"**read_file", "**read_file %s %s", filename, strerror(errno));

if (MPIR_Process.coords == NULL) {
MPIR_Process.coords =
MPL_malloc(MPIR_Process.coords_dims * sizeof(int) * MPIR_Process.size, MPL_MEM_COLL);
MPL_malloc(coords_dims * sizeof(int) * MPIR_Process.size, MPL_MEM_COLL);
}
memset(MPIR_Process.coords, -1, MPIR_Process.coords_dims * sizeof(int) * MPIR_Process.size);
memset(MPIR_Process.coords, -1, coords_dims * sizeof(int) * MPIR_Process.size);

for (i = 0; i < MPIR_Process.size; ++i) {
fields_scanned = fscanf(coords_file, "%d:", &rank);
Expand All @@ -900,12 +905,11 @@ static int parse_coord_file(const char *filename)
rank, MPIR_Process.size);
continue;
}
for (j = 0; j < MPIR_Process.coords_dims; ++j) {
for (j = 0; j < coords_dims; ++j) {
/* MPIR_Process.coords stores the coords in this order: port number, switch_id, group_id */
fields_scanned =
fscanf(coords_file, "%d",
&MPIR_Process.coords[rank * MPIR_Process.coords_dims +
MPIR_Process.coords_dims - 1 - j]);
&MPIR_Process.coords[rank * coords_dims + coords_dims - 1 - j]);
MPIR_ERR_CHKANDSTMT2(1 != fields_scanned, mpi_errno, MPI_ERR_FILE, goto fn_fail_read,
"**read_file", "**read_file %s %s", filename, strerror(errno));
}
Expand Down
15 changes: 7 additions & 8 deletions src/util/mpir_pmix.inc
Original file line number Diff line number Diff line change
Expand Up @@ -98,18 +98,16 @@ static int pmix_init(int *has_parent, int *rank, int *size, int *appnum)
MPIR_Assert(pvalue->data.uint32 <= INT_MAX); /* overflow check */
*appnum = (int) pvalue->data.uint32;
PMIX_VALUE_RELEASE(pvalue);
pmix_value_t value;
pmix_value_t *val = &value;

for (int i = 0; i < *size; i++) {
pmix_proc.rank = i;
pmi_errno = PMIx_Get(&pmix_proc, PMIX_FABRIC_COORDINATES, NULL, 0, &val);
pmi_errno = PMIx_Get(&pmix_proc, PMIX_FABRIC_COORDINATES, NULL, 0, &pvalue);
if (pmi_errno != PMIX_SUCCESS) {
MPIR_Process.coords = NULL;
break;
}
MPIR_Assert(val->data.coord->dims <= INT_MAX);
MPIR_Process.coords_dims = (int) val->data.coord->dims;
MPIR_Assert(pvalue->data.coord->dims <= INT_MAX);
MPIR_Process.coords_dims = (int) pvalue->data.coord->dims;
MPIR_Assert(MPIR_Process.coords_dims == 3);

if (i == 0) {
Expand All @@ -118,13 +116,14 @@ static int pmix_init(int *has_parent, int *rank, int *size, int *appnum)
MPIR_ERR_CHKANDJUMP(!MPIR_Process.coords, mpi_errno, MPI_ERR_OTHER, "**nomem");
}

if (PMIX_COORD == val->type) {
if (PMIX_COORD == pvalue->type) {
for (int j = 0; j < MPIR_Process.coords_dims; j++) {
/* MPIR_Process.coords stores the coords in this order: port number, switch_id, group_id */
/* MPIR_Process.coords is in this order: port number, switch_id, group_id */
MPIR_Process.coords[i * MPIR_Process.coords_dims + j] =
val->data.coord->coord[MPIR_Process.coords_dims - 1 - j];
pvalue->data.coord->coord[MPIR_Process.coords_dims - 1 - j];
}
}
PMIX_VALUE_RELEASE(pvalue);
}

fn_exit:
Expand Down

0 comments on commit 00f5183

Please sign in to comment.