Skip to content

Commit

Permalink
Support for missing FI_MR_PROV_KEY
Browse files Browse the repository at this point in the history
  • Loading branch information
Poeschel committed Jul 30, 2024
1 parent 4b63824 commit 5e8ae3c
Showing 1 changed file with 28 additions and 18 deletions.
46 changes: 28 additions & 18 deletions source/adios2/toolkit/sst/dp/rdma_dp.c
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,16 @@ int sst_fi_mr_reg(
CP_Services Svcs, void *CP_Stream,
/* regular fi_mir_reg() parameters*/
struct fid_domain *domain, const void *buf, size_t len, uint64_t acs, uint64_t offset,
uint64_t requested_key, uint64_t flags, struct fid_mr **mr, void *context,
uint64_t *requested_key, uint64_t flags, struct fid_mr **mr, void *context,
/* additional parameters for binding the mr to the endpoint*/
struct fid_ep *endpoint, int mr_mode)
{
*mr = NULL;
int res = fi_mr_reg(domain, buf, len, acs, offset, requested_key, flags, mr, context);
int res = fi_mr_reg(domain, buf, len, acs, offset, *requested_key, flags, mr, context);
if (*requested_key != 0)
{
++*requested_key;
}
int is_mr_endpoint = (mr_mode & FI_MR_ENDPOINT) != 0;
if (!is_mr_endpoint)
{
Expand Down Expand Up @@ -281,6 +285,7 @@ struct fabric_state
#endif /* SST_HAVE_CRAY_DRC */
struct cq_manual_progress *cq_manual_progress;
pthread_t pthread_id;
uint64_t mr_key;
};

// Wrapper for fi_cq_sread to be called in its stead from the main thread.
Expand Down Expand Up @@ -591,15 +596,19 @@ static void init_fabric(struct fabric_state *fabric, struct _SstParams *Params,
* (1) It does not support FI_MR_VIRT_ADDR.
* (2) It requires use of FI_MR_ENDPOINT.
*
* Some other providers again (e.g. psm3) don't support FI_MR_PROV_KEY.
*
* So we propagate the bit value currently contained in the mr_mode
* for these flags.
*/
if (info->domain_attr->mr_mode != FI_MR_BASIC)
{
info->domain_attr->mr_mode = FI_MR_ALLOCATED | FI_MR_PROV_KEY | FI_MR_LOCAL |
info->domain_attr->mr_mode = FI_MR_ALLOCATED | FI_MR_LOCAL |
(FI_MR_PROV_KEY & info->domain_attr->mr_mode) |
(FI_MR_ENDPOINT & info->domain_attr->mr_mode) |
(FI_MR_VIRT_ADDR & info->domain_attr->mr_mode);
fabric->mr_virt_addr = info->domain_attr->mr_mode & FI_MR_VIRT_ADDR ? 1 : 0;
fabric->mr_key = info->domain_attr->mr_mode & FI_MR_PROV_KEY ? 0 : 1;
}
else
{
Expand Down Expand Up @@ -1729,8 +1738,8 @@ static DP_WSR_Stream RdmaInitWriterPerReader(CP_Services Svcs, DP_WS_Stream WS_S
ReaderRollHandle = &ContactInfo->ReaderRollHandle;
ReaderRollHandle->Block = calloc(readerCohortSize, sizeof(struct _RdmaBuffer));
sst_fi_mr_reg(Svcs, WS_Stream->CP_Stream, Fabric->domain, ReaderRollHandle->Block,
readerCohortSize * sizeof(struct _RdmaBuffer), FI_REMOTE_WRITE, 0, 0, 0,
&WSR_Stream->rrmr, Fabric->ctx, Fabric->signal,
readerCohortSize * sizeof(struct _RdmaBuffer), FI_REMOTE_WRITE, 0,
&Fabric->mr_key, 0, &WSR_Stream->rrmr, Fabric->ctx, Fabric->signal,
Fabric->info->domain_attr->mr_mode);
ReaderRollHandle->Key = fi_mr_key(WSR_Stream->rrmr);

Expand Down Expand Up @@ -1869,8 +1878,8 @@ static ssize_t PostRead(CP_Services Svcs, Rdma_RS_Stream RS_Stream, int Rank, lo
if (Fabric->local_mr_req)
{
// register dest buffer
sst_fi_mr_reg(Svcs, RS_Stream->CP_Stream, Fabric->domain, Buffer, Length, FI_READ, 0, 0, 0,
&ret->LocalMR, Fabric->ctx, Fabric->signal,
sst_fi_mr_reg(Svcs, RS_Stream->CP_Stream, Fabric->domain, Buffer, Length, FI_READ, 0,
&Fabric->mr_key, 0, &ret->LocalMR, Fabric->ctx, Fabric->signal,
Fabric->info->domain_attr->mr_mode);
LocalDesc = fi_mr_desc(ret->LocalMR);
}
Expand Down Expand Up @@ -2188,8 +2197,8 @@ static void RdmaProvideTimestep(CP_Services Svcs, DP_WS_Stream Stream_v, struct
Entry->Desc = NULL;

sst_fi_mr_reg(Svcs, Stream->CP_Stream, Fabric->domain, Data->block, Data->DataSize,
FI_WRITE | FI_REMOTE_READ, 0, 0, 0, &Entry->mr, Fabric->ctx, Fabric->signal,
Fabric->info->domain_attr->mr_mode);
FI_WRITE | FI_REMOTE_READ, 0, &Fabric->mr_key, 0, &Entry->mr, Fabric->ctx,
Fabric->signal, Fabric->info->domain_attr->mr_mode);
Entry->Key = fi_mr_key(Entry->mr);
if (Fabric->local_mr_req)
{
Expand Down Expand Up @@ -2706,16 +2715,17 @@ static void PostPreload(CP_Services Svcs, Rdma_RS_Stream Stream, long Timestep)
PreloadBuffer->BufferLen = 2 * StepLog->BufferSize;
PreloadBuffer->Handle.Block = malloc(PreloadBuffer->BufferLen);
sst_fi_mr_reg(Svcs, Stream->CP_Stream, Fabric->domain, PreloadBuffer->Handle.Block,
PreloadBuffer->BufferLen, FI_REMOTE_WRITE, 0, 0, 0, &Stream->pbmr, Fabric->ctx,
Fabric->signal, Fabric->info->domain_attr->mr_mode);
PreloadBuffer->BufferLen, FI_REMOTE_WRITE, 0, &Fabric->mr_key, 0, &Stream->pbmr,
Fabric->ctx, Fabric->signal, Fabric->info->domain_attr->mr_mode);
PreloadKey = fi_mr_key(Stream->pbmr);

SBSize = sizeof(*SendBuffer) * StepLog->WRanks;
SendBuffer = malloc(SBSize);
if (Fabric->local_mr_req)
{
sst_fi_mr_reg(Svcs, Stream->CP_Stream, Fabric->domain, SendBuffer, SBSize, FI_WRITE, 0, 0,
0, &sbmr, Fabric->ctx, Fabric->signal, Fabric->info->domain_attr->mr_mode);
sst_fi_mr_reg(Svcs, Stream->CP_Stream, Fabric->domain, SendBuffer, SBSize, FI_WRITE, 0,
&Fabric->mr_key, 0, &sbmr, Fabric->ctx, Fabric->signal,
Fabric->info->domain_attr->mr_mode);
sbdesc = fi_mr_desc(sbmr);
}

Expand All @@ -2724,7 +2734,7 @@ static void PostPreload(CP_Services Svcs, Rdma_RS_Stream Stream, long Timestep)
RBLen = 2 * StepLog->Entries * DP_DATA_RECV_SIZE;
Stream->RecvDataBuffer = malloc(RBLen);
sst_fi_mr_reg(Svcs, Stream->CP_Stream, Fabric->domain, Stream->RecvDataBuffer, RBLen,
FI_RECV, 0, 0, 0, &Stream->rbmr, Fabric->ctx, Fabric->signal,
FI_RECV, 0, &Fabric->mr_key, 0, &Stream->rbmr, Fabric->ctx, Fabric->signal,
Fabric->info->domain_attr->mr_mode);
Stream->rbdesc = fi_mr_desc(Stream->rbmr);
RecvBuffer = (uint8_t *)Stream->RecvDataBuffer;
Expand All @@ -2750,8 +2760,8 @@ static void PostPreload(CP_Services Svcs, Rdma_RS_Stream Stream, long Timestep)
RankLog->Buffer = (void *)RawPLBuffer;
sst_fi_mr_reg(Svcs, Stream->CP_Stream, Fabric->domain, RankLog->ReqLog,
(sizeof(struct _RdmaBuffer) * RankLog->Entries) + sizeof(uint64_t),
FI_REMOTE_READ, 0, 0, 0, &RankLog->preqbmr, Fabric->ctx, Fabric->signal,
Fabric->info->domain_attr->mr_mode);
FI_REMOTE_READ, 0, &Fabric->mr_key, 0, &RankLog->preqbmr, Fabric->ctx,
Fabric->signal, Fabric->info->domain_attr->mr_mode);
for (j = 0; j < RankLog->Entries; j++)
{
ReqLog = &RankLog->ReqLog[j];
Expand Down Expand Up @@ -2894,8 +2904,8 @@ static void PullSelection(CP_Services Svcs, Rdma_WSR_Stream Stream)
if (Fabric->local_mr_req)
{
sst_fi_mr_reg(Svcs, WS_Stream->CP_Stream, Fabric->domain, ReqBuffer.Handle.Block,
ReqBuffer.BufferLen, FI_READ, 0, 0, 0, &rrmr, Fabric->ctx, Fabric->signal,
Fabric->info->domain_attr->mr_mode);
ReqBuffer.BufferLen, FI_READ, 0, &Fabric->mr_key, 0, &rrmr, Fabric->ctx,
Fabric->signal, Fabric->info->domain_attr->mr_mode);
rrdesc = fi_mr_desc(rrmr);
}

Expand Down

0 comments on commit 5e8ae3c

Please sign in to comment.