diff --git a/source/adios2/toolkit/sst/dp/rdma_dp.c b/source/adios2/toolkit/sst/dp/rdma_dp.c index 5de750abeb..3ec3139503 100644 --- a/source/adios2/toolkit/sst/dp/rdma_dp.c +++ b/source/adios2/toolkit/sst/dp/rdma_dp.c @@ -74,12 +74,16 @@ int sst_fi_mr_reg( CP_Services Svcs, void *CP_Stream, /* regular fi_mir_reg() parameters*/ struct fid_domain *domain, const void *buf, size_t len, uint64_t acs, uint64_t offset, - uint64_t requested_key, uint64_t flags, struct fid_mr **mr, void *context, + uint64_t *requested_key, uint64_t flags, struct fid_mr **mr, void *context, /* additional parameters for binding the mr to the endpoint*/ struct fid_ep *endpoint, int mr_mode) { *mr = NULL; - int res = fi_mr_reg(domain, buf, len, acs, offset, requested_key, flags, mr, context); + int res = fi_mr_reg(domain, buf, len, acs, offset, *requested_key, flags, mr, context); + if (*requested_key != 0) + { + ++*requested_key; + } int is_mr_endpoint = (mr_mode & FI_MR_ENDPOINT) != 0; if (!is_mr_endpoint) { @@ -281,6 +285,7 @@ struct fabric_state #endif /* SST_HAVE_CRAY_DRC */ struct cq_manual_progress *cq_manual_progress; pthread_t pthread_id; + uint64_t mr_key; }; // Wrapper for fi_cq_sread to be called in its stead from the main thread. @@ -591,15 +596,19 @@ static void init_fabric(struct fabric_state *fabric, struct _SstParams *Params, * (1) It does not support FI_MR_VIRT_ADDR. * (2) It requires use of FI_MR_ENDPOINT. * + * Some other providers again (e.g. psm3) don't support FI_MR_PROV_KEY. + * * So we propagate the bit value currently contained in the mr_mode * for these flags. */ if (info->domain_attr->mr_mode != FI_MR_BASIC) { - info->domain_attr->mr_mode = FI_MR_ALLOCATED | FI_MR_PROV_KEY | FI_MR_LOCAL | + info->domain_attr->mr_mode = FI_MR_ALLOCATED | FI_MR_LOCAL | + (FI_MR_PROV_KEY & info->domain_attr->mr_mode) | (FI_MR_ENDPOINT & info->domain_attr->mr_mode) | (FI_MR_VIRT_ADDR & info->domain_attr->mr_mode); fabric->mr_virt_addr = info->domain_attr->mr_mode & FI_MR_VIRT_ADDR ? 1 : 0; + fabric->mr_key = info->domain_attr->mr_mode & FI_MR_PROV_KEY ? 0 : 1; } else { @@ -1729,8 +1738,8 @@ static DP_WSR_Stream RdmaInitWriterPerReader(CP_Services Svcs, DP_WS_Stream WS_S ReaderRollHandle = &ContactInfo->ReaderRollHandle; ReaderRollHandle->Block = calloc(readerCohortSize, sizeof(struct _RdmaBuffer)); sst_fi_mr_reg(Svcs, WS_Stream->CP_Stream, Fabric->domain, ReaderRollHandle->Block, - readerCohortSize * sizeof(struct _RdmaBuffer), FI_REMOTE_WRITE, 0, 0, 0, - &WSR_Stream->rrmr, Fabric->ctx, Fabric->signal, + readerCohortSize * sizeof(struct _RdmaBuffer), FI_REMOTE_WRITE, 0, + &Fabric->mr_key, 0, &WSR_Stream->rrmr, Fabric->ctx, Fabric->signal, Fabric->info->domain_attr->mr_mode); ReaderRollHandle->Key = fi_mr_key(WSR_Stream->rrmr); @@ -1869,8 +1878,8 @@ static ssize_t PostRead(CP_Services Svcs, Rdma_RS_Stream RS_Stream, int Rank, lo if (Fabric->local_mr_req) { // register dest buffer - sst_fi_mr_reg(Svcs, RS_Stream->CP_Stream, Fabric->domain, Buffer, Length, FI_READ, 0, 0, 0, - &ret->LocalMR, Fabric->ctx, Fabric->signal, + sst_fi_mr_reg(Svcs, RS_Stream->CP_Stream, Fabric->domain, Buffer, Length, FI_READ, 0, + &Fabric->mr_key, 0, &ret->LocalMR, Fabric->ctx, Fabric->signal, Fabric->info->domain_attr->mr_mode); LocalDesc = fi_mr_desc(ret->LocalMR); } @@ -2188,8 +2197,8 @@ static void RdmaProvideTimestep(CP_Services Svcs, DP_WS_Stream Stream_v, struct Entry->Desc = NULL; sst_fi_mr_reg(Svcs, Stream->CP_Stream, Fabric->domain, Data->block, Data->DataSize, - FI_WRITE | FI_REMOTE_READ, 0, 0, 0, &Entry->mr, Fabric->ctx, Fabric->signal, - Fabric->info->domain_attr->mr_mode); + FI_WRITE | FI_REMOTE_READ, 0, &Fabric->mr_key, 0, &Entry->mr, Fabric->ctx, + Fabric->signal, Fabric->info->domain_attr->mr_mode); Entry->Key = fi_mr_key(Entry->mr); if (Fabric->local_mr_req) { @@ -2706,16 +2715,17 @@ static void PostPreload(CP_Services Svcs, Rdma_RS_Stream Stream, long Timestep) PreloadBuffer->BufferLen = 2 * StepLog->BufferSize; PreloadBuffer->Handle.Block = malloc(PreloadBuffer->BufferLen); sst_fi_mr_reg(Svcs, Stream->CP_Stream, Fabric->domain, PreloadBuffer->Handle.Block, - PreloadBuffer->BufferLen, FI_REMOTE_WRITE, 0, 0, 0, &Stream->pbmr, Fabric->ctx, - Fabric->signal, Fabric->info->domain_attr->mr_mode); + PreloadBuffer->BufferLen, FI_REMOTE_WRITE, 0, &Fabric->mr_key, 0, &Stream->pbmr, + Fabric->ctx, Fabric->signal, Fabric->info->domain_attr->mr_mode); PreloadKey = fi_mr_key(Stream->pbmr); SBSize = sizeof(*SendBuffer) * StepLog->WRanks; SendBuffer = malloc(SBSize); if (Fabric->local_mr_req) { - sst_fi_mr_reg(Svcs, Stream->CP_Stream, Fabric->domain, SendBuffer, SBSize, FI_WRITE, 0, 0, - 0, &sbmr, Fabric->ctx, Fabric->signal, Fabric->info->domain_attr->mr_mode); + sst_fi_mr_reg(Svcs, Stream->CP_Stream, Fabric->domain, SendBuffer, SBSize, FI_WRITE, 0, + &Fabric->mr_key, 0, &sbmr, Fabric->ctx, Fabric->signal, + Fabric->info->domain_attr->mr_mode); sbdesc = fi_mr_desc(sbmr); } @@ -2724,7 +2734,7 @@ static void PostPreload(CP_Services Svcs, Rdma_RS_Stream Stream, long Timestep) RBLen = 2 * StepLog->Entries * DP_DATA_RECV_SIZE; Stream->RecvDataBuffer = malloc(RBLen); sst_fi_mr_reg(Svcs, Stream->CP_Stream, Fabric->domain, Stream->RecvDataBuffer, RBLen, - FI_RECV, 0, 0, 0, &Stream->rbmr, Fabric->ctx, Fabric->signal, + FI_RECV, 0, &Fabric->mr_key, 0, &Stream->rbmr, Fabric->ctx, Fabric->signal, Fabric->info->domain_attr->mr_mode); Stream->rbdesc = fi_mr_desc(Stream->rbmr); RecvBuffer = (uint8_t *)Stream->RecvDataBuffer; @@ -2750,8 +2760,8 @@ static void PostPreload(CP_Services Svcs, Rdma_RS_Stream Stream, long Timestep) RankLog->Buffer = (void *)RawPLBuffer; sst_fi_mr_reg(Svcs, Stream->CP_Stream, Fabric->domain, RankLog->ReqLog, (sizeof(struct _RdmaBuffer) * RankLog->Entries) + sizeof(uint64_t), - FI_REMOTE_READ, 0, 0, 0, &RankLog->preqbmr, Fabric->ctx, Fabric->signal, - Fabric->info->domain_attr->mr_mode); + FI_REMOTE_READ, 0, &Fabric->mr_key, 0, &RankLog->preqbmr, Fabric->ctx, + Fabric->signal, Fabric->info->domain_attr->mr_mode); for (j = 0; j < RankLog->Entries; j++) { ReqLog = &RankLog->ReqLog[j]; @@ -2894,8 +2904,8 @@ static void PullSelection(CP_Services Svcs, Rdma_WSR_Stream Stream) if (Fabric->local_mr_req) { sst_fi_mr_reg(Svcs, WS_Stream->CP_Stream, Fabric->domain, ReqBuffer.Handle.Block, - ReqBuffer.BufferLen, FI_READ, 0, 0, 0, &rrmr, Fabric->ctx, Fabric->signal, - Fabric->info->domain_attr->mr_mode); + ReqBuffer.BufferLen, FI_READ, 0, &Fabric->mr_key, 0, &rrmr, Fabric->ctx, + Fabric->signal, Fabric->info->domain_attr->mr_mode); rrdesc = fi_mr_desc(rrmr); }