diff --git a/prov/efa/src/efa_shm.c b/prov/efa/src/efa_shm.c index c6fd79a429f..3b0f5dab9dd 100644 --- a/prov/efa/src/efa_shm.c +++ b/prov/efa/src/efa_shm.c @@ -118,6 +118,7 @@ void efa_shm_info_create(const struct fi_info *app_info, struct fi_info **shm_in shm_hints->domain_attr->mr_mode |= FI_MR_HMEM; } + shm_hints->domain_attr->threading = app_info->domain_attr->threading; shm_hints->domain_attr->av_type = FI_AV_TABLE; shm_hints->domain_attr->caps |= FI_LOCAL_COMM; shm_hints->tx_attr->msg_order = FI_ORDER_SAS; diff --git a/prov/efa/test/efa_unit_test_info.c b/prov/efa/test/efa_unit_test_info.c index 92defe115a8..bec039a6462 100644 --- a/prov/efa/test/efa_unit_test_info.c +++ b/prov/efa/test/efa_unit_test_info.c @@ -130,6 +130,11 @@ static void test_info_check_shm_info_from_hints(struct fi_info *hints) assert_true(efa_domain->shm_info->tx_attr->op_flags == info->tx_attr->op_flags); assert_true(efa_domain->shm_info->rx_attr->op_flags == info->rx_attr->op_flags); + + if (hints->domain_attr->threading) { + assert_true(hints->domain_attr->threading == info->domain_attr->threading); + assert_true(hints->domain_attr->threading == efa_domain->shm_info->domain_attr->threading); + } } fi_close(&domain->fid); @@ -143,7 +148,7 @@ static void test_info_check_shm_info_from_hints(struct fi_info *hints) * @brief Check shm info created by efa_domain() has correct caps. * */ -void test_info_check_shm_info() +void test_info_check_shm_info_hmem() { struct fi_info *hints; @@ -154,6 +159,13 @@ void test_info_check_shm_info() hints->caps &= ~FI_HMEM; test_info_check_shm_info_from_hints(hints); +} + +void test_info_check_shm_info_op_flags() +{ + struct fi_info *hints; + + hints = efa_unit_test_alloc_hints(FI_EP_RDM); hints->tx_attr->op_flags |= FI_COMPLETION; hints->rx_attr->op_flags |= FI_COMPLETION; @@ -162,8 +174,16 @@ void test_info_check_shm_info() hints->tx_attr->op_flags |= FI_DELIVERY_COMPLETE; hints->rx_attr->op_flags |= FI_MULTI_RECV; test_info_check_shm_info_from_hints(hints); +} +void test_info_check_shm_info_threading() +{ + struct fi_info *hints; + hints = efa_unit_test_alloc_hints(FI_EP_RDM); + + hints->domain_attr->threading = FI_THREAD_DOMAIN; + test_info_check_shm_info_from_hints(hints); } /** diff --git a/prov/efa/test/efa_unit_tests.c b/prov/efa/test/efa_unit_tests.c index 6b7a54ad8a2..1091d4540ff 100644 --- a/prov/efa/test/efa_unit_tests.c +++ b/prov/efa/test/efa_unit_tests.c @@ -103,7 +103,9 @@ int main(void) cmocka_unit_test_setup_teardown(test_rdm_fallback_to_ibv_create_cq_ex_cq_read_ignore_forgotton_peer, efa_unit_test_mocks_setup, efa_unit_test_mocks_teardown), cmocka_unit_test_setup_teardown(test_info_open_ep_with_wrong_info, efa_unit_test_mocks_setup, efa_unit_test_mocks_teardown), cmocka_unit_test_setup_teardown(test_info_open_ep_with_api_1_1_info, efa_unit_test_mocks_setup, efa_unit_test_mocks_teardown), - cmocka_unit_test_setup_teardown(test_info_check_shm_info, efa_unit_test_mocks_setup, efa_unit_test_mocks_teardown), + cmocka_unit_test_setup_teardown(test_info_check_shm_info_hmem, efa_unit_test_mocks_setup, efa_unit_test_mocks_teardown), + cmocka_unit_test_setup_teardown(test_info_check_shm_info_op_flags, efa_unit_test_mocks_setup, efa_unit_test_mocks_teardown), + cmocka_unit_test_setup_teardown(test_info_check_shm_info_threading, efa_unit_test_mocks_setup, efa_unit_test_mocks_teardown), cmocka_unit_test_setup_teardown(test_info_check_hmem_cuda_support_on_api_lt_1_18, NULL, NULL), cmocka_unit_test_setup_teardown(test_info_check_hmem_cuda_support_on_api_ge_1_18, NULL, NULL), cmocka_unit_test_setup_teardown(test_info_check_no_hmem_support_when_not_requested, NULL, NULL), diff --git a/prov/efa/test/efa_unit_tests.h b/prov/efa/test/efa_unit_tests.h index b757f328cb7..115fea3fb2f 100644 --- a/prov/efa/test/efa_unit_tests.h +++ b/prov/efa/test/efa_unit_tests.h @@ -111,7 +111,9 @@ void test_rdm_fallback_to_ibv_create_cq_ex_cq_read_ignore_forgotton_peer(); void test_ibv_cq_ex_read_ignore_removed_peer(); void test_info_open_ep_with_wrong_info(); void test_info_open_ep_with_api_1_1_info(); -void test_info_check_shm_info(); +void test_info_check_shm_info_hmem(); +void test_info_check_shm_info_op_flags(); +void test_info_check_shm_info_threading(); void test_info_check_hmem_cuda_support_on_api_lt_1_18(); void test_info_check_hmem_cuda_support_on_api_ge_1_18(); void test_info_check_no_hmem_support_when_not_requested();