Skip to content

Commit

Permalink
Merge pull request pmodels#6674 from raffenet/misc-pmi
Browse files Browse the repository at this point in the history
misc: PMI fixes and improvements

Approved-by: Hui Zhou <hzhou321@anl.gov>
  • Loading branch information
raffenet authored Sep 19, 2023
2 parents ef5f196 + 555c9f8 commit d1a428e
Show file tree
Hide file tree
Showing 7 changed files with 101 additions and 121 deletions.
62 changes: 2 additions & 60 deletions src/mpi/info/info_impl.c
Original file line number Diff line number Diff line change
Expand Up @@ -210,17 +210,14 @@ int MPIR_Info_create_env_impl(int argc, char **argv, MPIR_Info ** new_info_ptr)
goto fn_exit;
}

static int hex_encode(char *str, const void *value, int len);
static int hex_decode(const char *str, void *buf, int len);

int MPIR_Info_set_hex_impl(MPIR_Info * info_ptr, const char *key, const void *value, int value_size)
{
int mpi_errno = MPI_SUCCESS;

char value_buf[1024];
MPIR_Assertp(value_size * 2 + 1 < 1024);

hex_encode(value_buf, value, value_size);
MPL_hex_encode(value_size, value, value_buf);

mpi_errno = MPIR_Info_set_impl(info_ptr, key, value_buf);

Expand All @@ -231,66 +228,11 @@ int MPIR_Info_decode_hex(const char *str, void *buf, int len)
{
int mpi_errno = MPI_SUCCESS;

int rc = hex_decode(str, buf, len);
int rc = MPL_hex_decode(len, str, buf);
MPIR_ERR_CHKANDJUMP(rc, mpi_errno, MPI_ERR_OTHER, "**infohexinvalid");

fn_exit:
return mpi_errno;
fn_fail:
goto fn_exit;
}

/* ---- internal utility ---- */

/* Simple hex encoding of binary data as a hexadecimal string. For example,
 * the 4 bytes 0x12, 0x34, 0x56, 0x78 are encoded as the ASCII string
 * "12345678". The encoded string is exactly twice as long as the binary
 * data, plus a terminating NUL.
 */

/* Convert a single hexadecimal digit character to its numeric value.
 *
 * Accepts '0'-'9', 'a'-'f' and 'A'-'F' and returns 0..15.
 * Any other character yields -1.
 */
static int hex_val(char c)
{
    switch (c) {
        case '0': case '1': case '2': case '3': case '4':
        case '5': case '6': case '7': case '8': case '9':
            return c - '0';
        case 'a': case 'b': case 'c': case 'd': case 'e': case 'f':
            return c - 'a' + 10;
        case 'A': case 'B': case 'C': case 'D': case 'E': case 'F':
            return c - 'A' + 10;
        default:
            return -1;
    }
}

/* Write the lowercase hexadecimal representation of the `len` bytes at
 * `value` into `str`, two characters per byte. sprintf leaves a NUL after
 * the last pair written. The caller is responsible for making `str` large
 * enough (2 * len + 1 bytes). Always returns 0.
 */
static int hex_encode(char *str, const void *value, int len)
{
    const unsigned char *bytes = value;
    char *out = str;

    for (int remaining = len; remaining > 0; remaining--) {
        sprintf(out, "%02x", *bytes);
        bytes++;
        out += 2;
    }

    return 0;
}

/* Decode the hexadecimal string `str` into `len` bytes stored at `buf`.
 *
 * Returns 0 on success. Returns 1 when the string length is not exactly
 * 2 * len, or when a non-hexadecimal character is encountered (bytes
 * decoded before the offending pair have already been written).
 */
static int hex_decode(const char *str, void *buf, int len)
{
    /* exactly two digits are required for every output byte */
    int text_len = strlen(str);
    if (text_len != len * 2) {
        return 1;
    }

    unsigned char *out = buf;
    const char *pair = str;
    for (int k = 0; k < len; k++, pair += 2) {
        int hi = hex_val(pair[0]);
        int lo = hex_val(pair[1]);
        if (hi < 0 || lo < 0) {
            return 1;
        }
        out[k] = (unsigned char) ((hi << 4) + lo);
    }

    return 0;
}
3 changes: 3 additions & 0 deletions src/mpl/include/mpl_misc.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,7 @@ int mkstemp(char *template);
int MPL_mkstemp(char *template);
#endif

int MPL_hex_encode(int size, const char *src, char *dest);
int MPL_hex_decode(int size, const char *src, char *dest);

#endif /* MPL_MISC_H_INCLUDED */
52 changes: 52 additions & 0 deletions src/mpl/src/misc/mpl_misc.c
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,55 @@ int MPL_get_nprocs(void)
return 1;
#endif
}

/* Simple hex encoding of binary data as a hexadecimal string. For example,
 * the 4 bytes 0x12, 0x34, 0x56, 0x78 are encoded as the ASCII string
 * "12345678". The encoded string is exactly twice as long as the binary
 * data, plus a terminating NUL.
 */

/* Translate one hexadecimal digit character ([0-9a-fA-F]) to its value.
 *
 * Returns 0..15 for a valid digit and -1 otherwise. Invalid input is
 * reported through the return value rather than by asserting, so that
 * MPL_hex_decode can detect malformed strings and return an error to its
 * caller instead of aborting the process.
 */
static int hex(unsigned char c)
{
    if (c >= '0' && c <= '9') {
        return c - '0';
    } else if (c >= 'a' && c <= 'f') {
        return 10 + c - 'a';
    } else if (c >= 'A' && c <= 'F') {
        return 10 + c - 'A';
    } else {
        return -1;
    }
}

/* Encodes `size` bytes of `src` as uppercase hexadecimal characters.
 *
 * dest must have room for 2 * size + 1 bytes. The output is always
 * NUL-terminated, including when size is 0 (dest becomes the empty
 * string), so callers can safely treat dest as a C string.
 *
 * Always returns 0.
 */
int MPL_hex_encode(int size, const char *src, char *dest)
{
    static const char digits[] = "0123456789ABCDEF";

    for (int i = 0; i < size; i++) {
        /* go through unsigned char so bytes >= 0x80 index the table correctly */
        unsigned char byte = (unsigned char) src[i];
        *dest++ = digits[byte >> 4];
        *dest++ = digits[byte & 0x0F];
    }
    /* terminate explicitly; with size == 0 the loop writes nothing */
    *dest = '\0';

    return 0;
}

/* Decodes a hex encoded string back into the original binary data.
 *
 * src must be a NUL-terminated string of exactly 2 * size hexadecimal
 * characters; dest must have room for size bytes.
 *
 * Returns 0 on success. Returns 1 if size is negative, if the length of
 * src is not 2 * size, or if hex() reports a character as invalid (bytes
 * decoded before the offending pair have already been written to dest).
 */
int MPL_hex_decode(int size, const char *src, char *dest)
{
    /* compare in size_t so that neither strlen's result nor size * 2 is
     * squeezed through int */
    if (size < 0 || strlen(src) != (size_t) size * 2) {
        return 1;
    }

    for (int i = 0; i < size; i++) {
        /* decode each digit once and reuse the result */
        int hi = hex(src[0]);
        int lo = hex(src[1]);
        if (hi < 0 || lo < 0) {
            return 1;
        }
        /* combine as an unsigned byte first, then store */
        *dest = (char) (unsigned char) ((hi << 4) | lo);
        src += 2;
        dest++;
    }

    return 0;
}
23 changes: 17 additions & 6 deletions src/pm/hydra/proxy/pmip_pmi.c
Original file line number Diff line number Diff line change
Expand Up @@ -354,12 +354,8 @@ HYD_status fn_get_my_kvsname(struct pmip_downstream *p, struct PMIU_cmd *pmi)
goto fn_exit;
}

HYD_status fn_get_usize(struct pmip_downstream *p, struct PMIU_cmd *pmi)
static int get_universe_size(struct pmip_downstream *p)
{
HYD_status status = HYD_SUCCESS;
int pmi_errno;
HYDU_FUNC_ENTER();

int universe_size;
if (HYD_pmcd_pmip.user_global.usize == HYD_USIZE_SYSTEM) {
universe_size = PMIP_pg_from_downstream(p)->global_core_map.global_count;
Expand All @@ -368,6 +364,16 @@ HYD_status fn_get_usize(struct pmip_downstream *p, struct PMIU_cmd *pmi)
} else {
universe_size = HYD_pmcd_pmip.user_global.usize;
}
return universe_size;
}

HYD_status fn_get_usize(struct pmip_downstream * p, struct PMIU_cmd * pmi)
{
HYD_status status = HYD_SUCCESS;
int pmi_errno;
HYDU_FUNC_ENTER();

int universe_size = get_universe_size(p);

struct PMIU_cmd pmi_response;
pmi_errno = PMIU_msg_set_response_universe(pmi, &pmi_response, is_static, universe_size);
Expand All @@ -390,6 +396,11 @@ static const char *get_jobattr(struct pmip_downstream *p, const char *key)
return PMIP_pg_from_downstream(p)->pmi_process_mapping;
} else if (!strcmp(key, "PMI_hwloc_xmlfile")) {
return HYD_pmip_get_hwloc_xmlfile();
} else if (!strcmp(key, "universeSize")) {
static char universe_str[64];
int universe_size = get_universe_size(p);
snprintf(universe_str, 64, "%d", universe_size);
return universe_str;
}
return NULL;
}
Expand All @@ -406,7 +417,7 @@ HYD_status fn_get(struct pmip_downstream * p, struct PMIU_cmd * pmi)

bool found = false;
const char *val;
if (strncmp(key, "PMI_", 4) == 0) {
if (strncmp(key, "PMI_", 4) == 0 || strcmp(key, "universeSize") == 0) {
val = get_jobattr(p, key);
if (val) {
found = true;
Expand Down
31 changes: 18 additions & 13 deletions src/pmi/src/pmi_v2.c
Original file line number Diff line number Diff line change
Expand Up @@ -519,26 +519,31 @@ PMI_API_PUBLIC int PMI2_Info_GetJobAttr(const char name[], char value[], int val
{
int pmi_errno = PMI2_SUCCESS;

struct PMIU_cmd pmicmd;
PMIU_msg_set_query_get(&pmicmd, USE_WIRE_VER, no_static, NULL, name);
if (PMI_initialized > SINGLETON_INIT_BUT_NO_PM) {
struct PMIU_cmd pmicmd;
PMIU_msg_set_query_get(&pmicmd, USE_WIRE_VER, no_static, NULL, name);

pmi_errno = PMIU_cmd_get_response(PMI_fd, &pmicmd);
pmi_errno = PMIU_cmd_get_response(PMI_fd, &pmicmd);

bool found;
const char *tmp_val;
if (pmi_errno == PMIU_SUCCESS) {
pmi_errno = PMIU_msg_get_response_get(&pmicmd, &tmp_val, &found);
}
bool found;
const char *tmp_val;
if (pmi_errno == PMIU_SUCCESS) {
pmi_errno = PMIU_msg_get_response_get(&pmicmd, &tmp_val, &found);
}

if (!pmi_errno && found) {
MPL_strncpy(value, tmp_val, valuelen);
*flag = 1;
if (!pmi_errno && found) {
MPL_strncpy(value, tmp_val, valuelen);
*flag = 1;
} else {
*flag = 0;
pmi_errno = PMIU_SUCCESS;
}

PMIU_cmd_free_buf(&pmicmd);
} else {
*flag = 0;
pmi_errno = PMIU_SUCCESS;
}

PMIU_cmd_free_buf(&pmicmd);
return pmi_errno;
}

Expand Down
45 changes: 4 additions & 41 deletions src/util/mpir_pmi.c
Original file line number Diff line number Diff line change
Expand Up @@ -317,10 +317,6 @@ int MPIR_pmi_barrier_local(void)
return mpi_errno;
}

/* declare static functions used in bcast/allgather */
static void encode(int size, const char *src, char *dest);
static void decode(int size, const char *src, char *dest);

/* is_local is a hint that we optimize for node local access when we can */
static int optimized_put(const char *key, const char *val, int is_local)
{
Expand Down Expand Up @@ -353,7 +349,7 @@ static int put_ex_segs(const char *key, const void *buf, int bufsize, int is_loc
* (depends on pmi implementations, and may not be sufficient) */
int segsize = (pmi_max_val_size - 2) / 2;
if (bufsize < segsize) {
encode(bufsize, buf, val);
MPL_hex_encode(bufsize, buf, val);
mpi_errno = optimized_put(key, val, is_local);
MPIR_ERR_CHECK(mpi_errno);
} else {
Expand All @@ -371,7 +367,7 @@ static int put_ex_segs(const char *key, const void *buf, int bufsize, int is_loc
if (i == num_segs - 1) {
n = bufsize - segsize * (num_segs - 1);
}
encode(n, (char *) buf + i * segsize, val);
MPL_hex_encode(n, (char *) buf + i * segsize, val);
mpi_errno = optimized_put(seg_key, val, is_local);
MPIR_ERR_CHECK(mpi_errno);
}
Expand Down Expand Up @@ -410,12 +406,12 @@ static int get_ex_segs(int src, const char *key, void *buf, int *p_size, int is_
} else {
MPIR_Assert(n <= segsize);
}
decode(n, val, (char *) buf + i * segsize);
MPL_hex_decode(n, val, (char *) buf + i * segsize);
got_size += n;
}
} else {
int n = strlen(val) / 2; /* 2-to-1 decode */
decode(n, val, (char *) buf);
MPL_hex_decode(n, val, (char *) buf);
got_size = n;
}
MPIR_Assert(got_size <= bufsize);
Expand Down Expand Up @@ -724,39 +720,6 @@ int MPIR_pmi_unpublish(const char name[])

/* ---- static functions ---- */

/* similar to functions in mpl/src/str/mpl_argstr.c, but much simpler */
/* Returns the value (0..15) of the hexadecimal digit `c`. For any other
 * character it hits MPIR_Assert(0) and, if that returns, yields -1. */
static int hex(unsigned char c)
{
    if (c >= '0' && c <= '9') {
        return c - '0';
    }
    if (c >= 'a' && c <= 'f') {
        return 10 + c - 'a';
    }
    if (c >= 'A' && c <= 'F') {
        return 10 + c - 'A';
    }

    /* not a hexadecimal digit */
    MPIR_Assert(0);
    return -1;
}

/* Write `size` bytes of `src` into `dest` as uppercase hexadecimal, two
 * characters per byte. Each snprintf call also stores a NUL after the pair
 * it writes, so dest needs 2 * size + 1 bytes when size > 0. */
static void encode(int size, const char *src, char *dest)
{
    for (int k = 0; k < size; k++) {
        unsigned char byte = (unsigned char) src[k];
        snprintf(dest + 2 * k, 3, "%02X", byte);
    }
}

/* Turn 2 * size hexadecimal characters from `src` into `size` bytes at
 * `dest`. No validation is done here beyond what hex() itself performs. */
static void decode(int size, const char *src, char *dest)
{
    for (int k = 0; k < size; k++) {
        int hi = hex(src[2 * k]);
        int lo = hex(src[2 * k + 1]);
        dest[k] = (char) (hi << 4) + lo;
    }
}

/* static functions used in MPIR_pmi_spawn_multiple */

static int mpi_to_pmi_keyvals(MPIR_Info * info_ptr, INFO_TYPE ** kv_ptr, int *nkeys_ptr)
Expand Down
6 changes: 5 additions & 1 deletion src/util/mpir_pmix.inc
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ static int pmix_get_binary(int src, const char *key, char *buf, int *p_size, int
int mpi_errno = MPI_SUCCESS;
int pmi_errno;

int bufsize = *p_size;
int bufsize ATTRIBUTE((unused)) = *p_size;
pmix_value_t *pvalue;
if (src < 0) {
pmi_errno = PMIx_Get(NULL, key, NULL, 0, &pvalue);
Expand Down Expand Up @@ -338,6 +338,10 @@ static int pmix_get_universe_size(int *universe_size)
pmix_value_t *pvalue = NULL;

pmi_errno = PMIx_Get(&pmix_wcproc, PMIX_UNIV_SIZE, NULL, 0, &pvalue);
if (pmi_errno == PMIX_ERR_NOT_FOUND) {
*universe_size = MPIR_UNIVERSE_SIZE_NOT_AVAILABLE;
goto fn_exit;
}
MPIR_ERR_CHKANDJUMP1(pmi_errno != PMIX_SUCCESS, mpi_errno, MPI_ERR_OTHER,
"**pmix_get", "**pmix_get %d", pmi_errno);
*universe_size = pvalue->data.uint32;
Expand Down

0 comments on commit d1a428e

Please sign in to comment.