diff --git a/include/ofi_mr.h b/include/ofi_mr.h index a353dc4935b..ca5b2702ada 100644 --- a/include/ofi_mr.h +++ b/include/ofi_mr.h @@ -311,6 +311,17 @@ void ofi_mr_get_iov_from_dmabuf(struct iovec *iov, } } +static inline +void ofi_mr_info_get_iov_from_mr_attr(struct ofi_mr_info *info, + const struct fi_mr_attr *attr, + uint64_t flags) +{ + if (flags & FI_MR_DMABUF) + ofi_mr_get_iov_from_dmabuf(&info->iov, attr->dmabuf, 1); + else + info->iov = *attr->mr_iov; +} + void ofi_mr_update_attr(uint32_t user_version, uint64_t caps, const struct fi_mr_attr *user_attr, struct fi_mr_attr *cur_abi_attr, @@ -428,9 +439,10 @@ int ofi_mr_cache_search(struct ofi_mr_cache *cache, * with the cache. */ struct ofi_mr_entry *ofi_mr_cache_find(struct ofi_mr_cache *cache, - const struct fi_mr_attr *attr); + const struct fi_mr_attr *attr, + uint64_t flags); int ofi_mr_cache_reg(struct ofi_mr_cache *cache, const struct fi_mr_attr *attr, - struct ofi_mr_entry **entry); + struct ofi_mr_entry **entry, uint64_t flags); void ofi_mr_cache_delete(struct ofi_mr_cache *cache, struct ofi_mr_entry *entry); diff --git a/prov/efa/src/efa_mr.c b/prov/efa/src/efa_mr.c index 4a6be4d054a..6068af624a7 100644 --- a/prov/efa/src/efa_mr.c +++ b/prov/efa/src/efa_mr.c @@ -358,7 +358,7 @@ static int efa_mr_cache_regattr(struct fid *fid, const struct fi_mr_attr *attr, util_domain.domain_fid.fid); assert(attr->iov_count == 1); - info.iov = *attr->mr_iov; + ofi_mr_info_get_iov_from_mr_attr(&info, attr, flags); info.iface = attr->iface; info.device = attr->device.reserved; ret = ofi_mr_cache_search(domain->cache, &info, &entry); diff --git a/prov/util/src/util_mr_cache.c b/prov/util/src/util_mr_cache.c index dd00cd537bf..f2148e56267 100644 --- a/prov/util/src/util_mr_cache.c +++ b/prov/util/src/util_mr_cache.c @@ -389,7 +389,8 @@ int ofi_mr_cache_search(struct ofi_mr_cache *cache, const struct ofi_mr_info *in } struct ofi_mr_entry *ofi_mr_cache_find(struct ofi_mr_cache *cache, - const struct fi_mr_attr *attr) + const struct fi_mr_attr *attr, + uint64_t flags) { struct ofi_mr_info info; struct ofi_mr_entry *entry; @@ -410,7 +411,7 @@ struct ofi_mr_entry *ofi_mr_cache_find(struct ofi_mr_cache *cache, cache->search_cnt++; info.peer_id = 0; - info.iov = *attr->mr_iov; + ofi_mr_info_get_iov_from_mr_attr(&info, attr, flags); entry = ofi_mr_rbt_find(&cache->tree, &info); if (!entry) { goto unlock; @@ -436,7 +437,7 @@ struct ofi_mr_entry *ofi_mr_cache_find(struct ofi_mr_cache *cache, } int ofi_mr_cache_reg(struct ofi_mr_cache *cache, const struct fi_mr_attr *attr, - struct ofi_mr_entry **entry) + struct ofi_mr_entry **entry, uint64_t flags) { int ret; @@ -453,7 +454,7 @@ int ofi_mr_cache_reg(struct ofi_mr_cache *cache, const struct fi_mr_attr *attr, cache->uncached_size += attr->mr_iov->iov_len; pthread_mutex_unlock(&mm_lock); - (*entry)->info.iov = *attr->mr_iov; + ofi_mr_info_get_iov_from_mr_attr(&(*entry)->info, attr, flags); (*entry)->use_cnt = 1; (*entry)->node = NULL; diff --git a/prov/verbs/src/verbs_mr.c b/prov/verbs/src/verbs_mr.c index d6c9771d55f..b941ab3e502 100644 --- a/prov/verbs/src/verbs_mr.c +++ b/prov/verbs/src/verbs_mr.c @@ -333,12 +333,12 @@ vrb_mr_cache_reg(struct vrb_domain *domain, const void *buf, size_t len, attr.iface = iface; attr.device.reserved = device; assert(attr.iov_count == 1); - info.iov = iov; + ofi_mr_info_get_iov_from_mr_attr(&info, &attr, flags); info.iface = iface; info.device = device; ret = (flags & OFI_MR_NOCACHE) ? - ofi_mr_cache_reg(&domain->cache, &attr, &entry) : + ofi_mr_cache_reg(&domain->cache, &attr, &entry, flags) : ofi_mr_cache_search(&domain->cache, &info, &entry); if (OFI_UNLIKELY(ret)) return ret;