Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OMPI/OSC/UCX: fix issue in impl of MPI_Win_create_dynamic/MPI_Win_attach/MPI_Win_detach #5094

Merged
merged 1 commit into from
May 2, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions ompi/mca/osc/ucx/osc_ucx.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@
#include "ompi/communicator/communicator.h"

#define OMPI_OSC_UCX_POST_PEER_MAX 32
#define OMPI_OSC_UCX_ATTACH_MAX 32
#define OMPI_OSC_UCX_RKEY_BUF_MAX 1024

/* Per-peer remote-access information: the key and base address used to
 * reach a peer's memory.  One entry per peer is kept in the module's
 * win_info_array and state_info_array. */
typedef struct ompi_osc_ucx_win_info {
/* remote key passed to UCX operations that target this peer */
ucp_rkey_h rkey;
/* base address on the peer; callers add displacements/offsets to it */
uint64_t addr;
/* true while rkey holds a key obtained from ucp_ep_rkey_unpack() that has
 * to be released with ucp_rkey_destroy() before being replaced */
bool rkey_init;
} ompi_osc_ucx_win_info_t;

typedef struct ompi_osc_ucx_component {
Expand Down Expand Up @@ -59,6 +62,18 @@ typedef struct ompi_osc_ucx_epoch_type {
#define OSC_UCX_STATE_COMPLETE_COUNT_OFFSET (sizeof(uint64_t) * 3)
#define OSC_UCX_STATE_POST_INDEX_OFFSET (sizeof(uint64_t) * 4)
#define OSC_UCX_STATE_POST_STATE_OFFSET (sizeof(uint64_t) * 5)
#define OSC_UCX_STATE_DYNAMIC_WIN_CNT_OFFSET (sizeof(uint64_t) * (5 + OMPI_OSC_UCX_POST_PEER_MAX))

/* Descriptor of one attached region of a dynamic window.  An array of these
 * is embedded in ompi_osc_ucx_state_t and is fetched as raw bytes by remote
 * peers, so the field layout must be identical on every process. */
typedef struct ompi_osc_dynamic_win_info {
/* presumably the start address of the attached region - confirm against
 * the attach implementation */
uint64_t base;
/* presumably the length of the attached region in bytes - confirm */
size_t size;
/* packed remote key for the region; peers hand this buffer to
 * ucp_ep_rkey_unpack() to obtain a usable rkey */
char rkey_buffer[OMPI_OSC_UCX_RKEY_BUF_MAX];
} ompi_osc_dynamic_win_info_t;

/* Process-local bookkeeping for one attached region; the module keeps an
 * array of OMPI_OSC_UCX_ATTACH_MAX of these (local_dynamic_win_info). */
typedef struct ompi_osc_local_dynamic_win_info {
/* NOTE(review): presumably the UCX memory handle of the registered
 * region - confirm against the attach/detach implementation */
ucp_mem_h memh;
/* NOTE(review): presumably the number of outstanding attach calls that
 * reference this region - confirm */
int refcnt;
} ompi_osc_local_dynamic_win_info_t;

typedef struct ompi_osc_ucx_state {
volatile uint64_t lock;
Expand All @@ -67,12 +82,16 @@ typedef struct ompi_osc_ucx_state {
volatile uint64_t complete_count; /* # msgs received from complete processes */
volatile uint64_t post_index;
volatile uint64_t post_state[OMPI_OSC_UCX_POST_PEER_MAX];
volatile uint64_t dynamic_win_count;
volatile ompi_osc_dynamic_win_info_t dynamic_wins[OMPI_OSC_UCX_ATTACH_MAX];
} ompi_osc_ucx_state_t;

typedef struct ompi_osc_ucx_module {
ompi_osc_base_module_t super;
struct ompi_communicator_t *comm;
ucp_mem_h memh; /* remote accessible memory */
int flavor;
size_t size;
ucp_mem_h state_memh;
ompi_osc_ucx_win_info_t *win_info_array;
ompi_osc_ucx_win_info_t *state_info_array;
Expand All @@ -82,6 +101,7 @@ typedef struct ompi_osc_ucx_module {
int *disp_units;

ompi_osc_ucx_state_t state; /* remote accessible flags */
ompi_osc_local_dynamic_win_info_t local_dynamic_win_info[OMPI_OSC_UCX_ATTACH_MAX];
ompi_osc_ucx_epoch_type_t epoch_type;
ompi_group_t *start_group;
ompi_group_t *post_group;
Expand Down Expand Up @@ -184,6 +204,10 @@ int ompi_osc_ucx_flush_all(struct ompi_win_t *win);
int ompi_osc_ucx_flush_local(int target, struct ompi_win_t *win);
int ompi_osc_ucx_flush_local_all(struct ompi_win_t *win);

/*
 * Search dynamic_wins between min_index and max_index for the attached
 * region that matches the range described by base and len.
 *
 * The RMA paths call this with (table, 0, win_count, remote_addr, 1, &insert)
 * and use the return value as an index into the table.
 *
 * NOTE(review): the definition is not visible from this header; presumably a
 * negative value is returned when no region contains the range and *insert
 * then receives the position where such a region would be inserted - confirm
 * against the implementation.
 */
int ompi_osc_find_attached_region_position(ompi_osc_dynamic_win_info_t *dynamic_wins,
int min_index, int max_index,
uint64_t base, size_t len, int *insert);

void req_completion(void *request, ucs_status_t status);
void internal_req_init(void *request);

Expand Down
123 changes: 117 additions & 6 deletions ompi/mca/osc/ucx/osc_ucx_comm.c
Original file line number Diff line number Diff line change
Expand Up @@ -325,13 +325,68 @@ static inline int end_atomicity(ompi_osc_ucx_module_t *module, ucp_ep_h ep, int
return OMPI_SUCCESS;
}

/**
 * Refresh the cached remote key used to access a dynamic window on a target.
 *
 * Fetches the target's dynamic-window table (a 64-bit attached-region count
 * followed by OMPI_OSC_UCX_ATTACH_MAX region descriptors) from the target's
 * state area, looks up the region containing remote_addr and unpacks that
 * region's packed rkey into module->win_info_array[target].rkey.  Any rkey
 * previously unpacked for the target is destroyed first.
 *
 * @param remote_addr  address on the target that is about to be accessed
 * @param module       OSC UCX module that owns the window
 * @param ep           UCX endpoint connected to the target
 * @param target       index of the target peer in the module's info arrays
 *
 * @return UCS_OK (0) on success, OMPI_ERROR on any failure.  Callers compare
 *         the result against UCS_OK.
 */
static inline int get_dynamic_win_info(uint64_t remote_addr, ompi_osc_ucx_module_t *module,
                                       ucp_ep_h ep, int target) {
    ucp_rkey_h state_rkey = (module->state_info_array)[target].rkey;
    uint64_t remote_state_addr = (module->state_info_array)[target].addr + OSC_UCX_STATE_DYNAMIC_WIN_CNT_OFFSET;
    size_t len = sizeof(uint64_t) + sizeof(ompi_osc_dynamic_win_info_t) * OMPI_OSC_UCX_ATTACH_MAX;
    char *temp_buf = NULL;
    ompi_osc_dynamic_win_info_t *temp_dynamic_wins;
    /* The remote counter is a uint64_t; it must be read into an object of
     * the same width, otherwise the copy overruns the destination. */
    uint64_t win_count = 0;
    int contain, insert = -1;
    int ret = OMPI_ERROR;
    ucs_status_t status;

    temp_buf = malloc(len);
    if (NULL == temp_buf) {
        opal_output_verbose(1, ompi_osc_base_framework.framework_output,
                            "%s:%d: malloc failed\n",
                            __FILE__, __LINE__);
        return OMPI_ERROR;
    }

    /* Drop the stale key and mark the slot empty right away, so that a
     * failure below can never lead to the same key being destroyed twice. */
    if ((module->win_info_array[target]).rkey_init == true) {
        ucp_rkey_destroy((module->win_info_array[target]).rkey);
        (module->win_info_array[target]).rkey_init = false;
    }

    status = ucp_get_nbi(ep, (void *)temp_buf, len, remote_state_addr, state_rkey);
    if (status != UCS_OK && status != UCS_INPROGRESS) {
        opal_output_verbose(1, ompi_osc_base_framework.framework_output,
                            "%s:%d: ucp_get_nbi failed: %d\n",
                            __FILE__, __LINE__, status);
        goto cleanup;
    }

    /* The get is non-blocking: flush before looking at the fetched bytes. */
    status = ucp_ep_flush(ep);
    if (status != UCS_OK) {
        opal_output_verbose(1, ompi_osc_base_framework.framework_output,
                            "%s:%d: ucp_ep_flush failed: %d\n",
                            __FILE__, __LINE__, status);
        goto cleanup;
    }

    memcpy(&win_count, temp_buf, sizeof(win_count));
    /* Checked at run time (not only via assert) because the value comes from
     * remote memory and is used to bound the table lookup below. */
    if (win_count == 0 || win_count > OMPI_OSC_UCX_ATTACH_MAX) {
        opal_output_verbose(1, ompi_osc_base_framework.framework_output,
                            "%s:%d: invalid dynamic window count: %lu\n",
                            __FILE__, __LINE__, (unsigned long)win_count);
        goto cleanup;
    }

    temp_dynamic_wins = (ompi_osc_dynamic_win_info_t *)(temp_buf + sizeof(uint64_t));
    contain = ompi_osc_find_attached_region_position(temp_dynamic_wins, 0, (int)win_count,
                                                     remote_addr, 1, &insert);
    if (contain < 0 || contain >= (int)win_count) {
        opal_output_verbose(1, ompi_osc_base_framework.framework_output,
                            "%s:%d: target address is not inside an attached region\n",
                            __FILE__, __LINE__);
        goto cleanup;
    }

    status = ucp_ep_rkey_unpack(ep, temp_dynamic_wins[contain].rkey_buffer,
                                &((module->win_info_array[target]).rkey));
    if (status != UCS_OK) {
        opal_output_verbose(1, ompi_osc_base_framework.framework_output,
                            "%s:%d: ucp_ep_rkey_unpack failed: %d\n",
                            __FILE__, __LINE__, status);
        goto cleanup;
    }

    (module->win_info_array[target]).rkey_init = true;
    ret = UCS_OK;

cleanup:
    /* Single exit point so the scratch buffer is released on every path. */
    free(temp_buf);

    return ret;
}

int ompi_osc_ucx_put(const void *origin_addr, int origin_count, struct ompi_datatype_t *origin_dt,
int target, ptrdiff_t target_disp, int target_count,
struct ompi_datatype_t *target_dt, struct ompi_win_t *win) {
ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module;
ucp_ep_h ep = OSC_UCX_GET_EP(module->comm, target);
uint64_t remote_addr = (module->win_info_array[target]).addr + target_disp * OSC_UCX_GET_DISP(module, target);
ucp_rkey_h rkey = (module->win_info_array[target]).rkey;
ucp_rkey_h rkey;
bool is_origin_contig = false, is_target_contig = false;
ptrdiff_t origin_lb, origin_extent, target_lb, target_extent;
ucs_status_t status;
Expand All @@ -342,6 +397,15 @@ int ompi_osc_ucx_put(const void *origin_addr, int origin_count, struct ompi_data
return ret;
}

if (module->flavor == MPI_WIN_FLAVOR_DYNAMIC) {
status = get_dynamic_win_info(remote_addr, module, ep, target);
if (status != UCS_OK) {
return OMPI_ERROR;
}
}

rkey = (module->win_info_array[target]).rkey;

ompi_datatype_get_true_extent(origin_dt, &origin_lb, &origin_extent);
ompi_datatype_get_true_extent(target_dt, &target_lb, &target_extent);

Expand Down Expand Up @@ -378,7 +442,7 @@ int ompi_osc_ucx_get(void *origin_addr, int origin_count,
ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module;
ucp_ep_h ep = OSC_UCX_GET_EP(module->comm, target);
uint64_t remote_addr = (module->win_info_array[target]).addr + target_disp * OSC_UCX_GET_DISP(module, target);
ucp_rkey_h rkey = (module->win_info_array[target]).rkey;
ucp_rkey_h rkey;
ptrdiff_t origin_lb, origin_extent, target_lb, target_extent;
bool is_origin_contig = false, is_target_contig = false;
ucs_status_t status;
Expand All @@ -389,6 +453,15 @@ int ompi_osc_ucx_get(void *origin_addr, int origin_count,
return ret;
}

if (module->flavor == MPI_WIN_FLAVOR_DYNAMIC) {
status = get_dynamic_win_info(remote_addr, module, ep, target);
if (status != UCS_OK) {
return OMPI_ERROR;
}
}

rkey = (module->win_info_array[target]).rkey;

ompi_datatype_get_true_extent(origin_dt, &origin_lb, &origin_extent);
ompi_datatype_get_true_extent(target_dt, &target_lb, &target_extent);

Expand Down Expand Up @@ -557,10 +630,11 @@ int ompi_osc_ucx_compare_and_swap(const void *origin_addr, const void *compare_a
ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t *)win->w_osc_module;
ucp_ep_h ep = OSC_UCX_GET_EP(module->comm, target);
uint64_t remote_addr = (module->win_info_array[target]).addr + target_disp * OSC_UCX_GET_DISP(module, target);
ucp_rkey_h rkey = (module->win_info_array[target]).rkey;
ucp_rkey_h rkey;
size_t dt_bytes;
ompi_osc_ucx_internal_request_t *req = NULL;
int ret = OMPI_SUCCESS;
ucs_status_t status;

ret = check_sync_state(module, target, false);
if (ret != OMPI_SUCCESS) {
Expand All @@ -572,6 +646,15 @@ int ompi_osc_ucx_compare_and_swap(const void *origin_addr, const void *compare_a
return ret;
}

if (module->flavor == MPI_WIN_FLAVOR_DYNAMIC) {
status = get_dynamic_win_info(remote_addr, module, ep, target);
if (status != UCS_OK) {
return OMPI_ERROR;
}
}

rkey = (module->win_info_array[target]).rkey;

ompi_datatype_type_size(dt, &dt_bytes);
memcpy(result_addr, origin_addr, dt_bytes);
req = ucp_atomic_fetch_nb(ep, UCP_ATOMIC_FETCH_OP_CSWAP, *(uint64_t *)compare_addr,
Expand Down Expand Up @@ -604,17 +687,27 @@ int ompi_osc_ucx_fetch_and_op(const void *origin_addr, void *result_addr,
op == &ompi_mpi_op_sum.op) {
ucp_ep_h ep = OSC_UCX_GET_EP(module->comm, target);
uint64_t remote_addr = (module->win_info_array[target]).addr + target_disp * OSC_UCX_GET_DISP(module, target);
ucp_rkey_h rkey = (module->win_info_array[target]).rkey;
ucp_rkey_h rkey;
uint64_t value = *(uint64_t *)origin_addr;
ucp_atomic_fetch_op_t opcode;
size_t dt_bytes;
ompi_osc_ucx_internal_request_t *req = NULL;
ucs_status_t status;

ret = start_atomicity(module, ep, target);
if (ret != OMPI_SUCCESS) {
return ret;
}

if (module->flavor == MPI_WIN_FLAVOR_DYNAMIC) {
status = get_dynamic_win_info(remote_addr, module, ep, target);
if (status != UCS_OK) {
return OMPI_ERROR;
}
}

rkey = (module->win_info_array[target]).rkey;

ompi_datatype_type_size(dt, &dt_bytes);

if (op == &ompi_mpi_op_replace.op) {
Expand Down Expand Up @@ -789,7 +882,7 @@ int ompi_osc_ucx_rput(const void *origin_addr, int origin_count,
ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module;
ucp_ep_h ep = OSC_UCX_GET_EP(module->comm, target);
uint64_t remote_addr = (module->state_info_array[target]).addr + OSC_UCX_STATE_REQ_FLAG_OFFSET;
ucp_rkey_h rkey = (module->state_info_array[target]).rkey;
ucp_rkey_h rkey;
ompi_osc_ucx_request_t *ucx_req = NULL;
ompi_osc_ucx_internal_request_t *internal_req = NULL;
ucs_status_t status;
Expand All @@ -800,6 +893,15 @@ int ompi_osc_ucx_rput(const void *origin_addr, int origin_count,
return ret;
}

if (module->flavor == MPI_WIN_FLAVOR_DYNAMIC) {
status = get_dynamic_win_info(remote_addr, module, ep, target);
if (status != UCS_OK) {
return OMPI_ERROR;
}
}

rkey = (module->win_info_array[target]).rkey;

OMPI_OSC_UCX_REQUEST_ALLOC(win, ucx_req);
if (NULL == ucx_req) {
return OMPI_ERR_TEMP_OUT_OF_RESOURCE;
Expand Down Expand Up @@ -843,7 +945,7 @@ int ompi_osc_ucx_rget(void *origin_addr, int origin_count,
ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module;
ucp_ep_h ep = OSC_UCX_GET_EP(module->comm, target);
uint64_t remote_addr = (module->state_info_array[target]).addr + OSC_UCX_STATE_REQ_FLAG_OFFSET;
ucp_rkey_h rkey = (module->state_info_array[target]).rkey;
ucp_rkey_h rkey;
ompi_osc_ucx_request_t *ucx_req = NULL;
ompi_osc_ucx_internal_request_t *internal_req = NULL;
ucs_status_t status;
Expand All @@ -854,6 +956,15 @@ int ompi_osc_ucx_rget(void *origin_addr, int origin_count,
return ret;
}

if (module->flavor == MPI_WIN_FLAVOR_DYNAMIC) {
status = get_dynamic_win_info(remote_addr, module, ep, target);
if (status != UCS_OK) {
return OMPI_ERROR;
}
}

rkey = (module->win_info_array[target]).rkey;

OMPI_OSC_UCX_REQUEST_ALLOC(win, ucx_req);
if (NULL == ucx_req) {
return OMPI_ERR_TEMP_OUT_OF_RESOURCE;
Expand Down
Loading