diff --git a/ompi/mca/coll/hcoll/coll_hcoll.h b/ompi/mca/coll/hcoll/coll_hcoll.h index 980f0ebb2d..e6c1698286 100644 --- a/ompi/mca/coll/hcoll/coll_hcoll.h +++ b/ompi/mca/coll/hcoll/coll_hcoll.h @@ -49,6 +49,11 @@ typedef struct mca_coll_hcoll_ops_t { int (*hcoll_barrier)(void *); } mca_coll_hcoll_ops_t; +typedef struct { + opal_free_list_item_t super; + dte_data_representation_t type; +} mca_coll_hcoll_dtype_t; +OBJ_CLASS_DECLARATION(mca_coll_hcoll_dtype_t); struct mca_coll_hcoll_component_t { /** Base coll component */ @@ -89,6 +94,8 @@ struct mca_coll_hcoll_component_t { /* FCA global stuff */ mca_coll_hcoll_ops_t hcoll_ops; opal_free_list_t requests; + opal_free_list_t dtypes; + int derived_types_support_enabled; }; typedef struct mca_coll_hcoll_component_t mca_coll_hcoll_component_t; diff --git a/ompi/mca/coll/hcoll/coll_hcoll_component.c b/ompi/mca/coll/hcoll/coll_hcoll_component.c index 9b457de8e3..02d390ba3b 100644 --- a/ompi/mca/coll/hcoll/coll_hcoll_component.c +++ b/ompi/mca/coll/hcoll/coll_hcoll_component.c @@ -17,6 +17,7 @@ #include "coll_hcoll.h" #include "opal/mca/installdirs/installdirs.h" +#include "coll_hcoll_dtypes.h" /* * Public string showing the coll ompi_hcol component version number @@ -207,7 +208,15 @@ static int hcoll_register(void) 1, &mca_coll_hcoll_component.hcoll_datatype_fallback, 0)); - +#if HCOLL_API >= HCOLL_VERSION(3,6) + CHECK(reg_int("dts",NULL, + "[1|0|] Enable/Disable derived types support", + 1, + &mca_coll_hcoll_component.derived_types_support_enabled, + 0)); +#else + mca_coll_hcoll_component.derived_types_support_enabled = 0; +#endif mca_coll_hcoll_component.compiletime_version = HCOLL_VERNO_STRING; mca_base_component_var_register(&mca_coll_hcoll_component.super.collm_version, MCA_COMPILETIME_VER, @@ -278,7 +287,7 @@ static int hcoll_close(void) HCOL_VERBOSE(5,"HCOLL FINALIZE"); rc = hcoll_finalize(); - + OBJ_DESTRUCT(&cm->dtypes); opal_progress_unregister(mca_coll_hcoll_progress); if (HCOLL_SUCCESS != rc){ 
HCOL_VERBOSE(1,"Hcol library finalize failed"); diff --git a/ompi/mca/coll/hcoll/coll_hcoll_dtypes.h b/ompi/mca/coll/hcoll/coll_hcoll_dtypes.h index e40d9a5e63..93a5a08b4d 100644 --- a/ompi/mca/coll/hcoll/coll_hcoll_dtypes.h +++ b/ompi/mca/coll/hcoll/coll_hcoll_dtypes.h @@ -6,8 +6,10 @@ It is used to extract allreduce bcol functions where the arrhythmetics has to be done*/ #include "ompi/datatype/ompi_datatype.h" +#include "ompi/datatype/ompi_datatype_internal.h" #include "ompi/mca/op/op.h" #include "hcoll/api/hcoll_dte.h" +extern int hcoll_type_attr_keyval; /*to keep this at hand: Ids of the basic opal_datatypes: #define OPAL_DATATYPE_INT1 4 @@ -31,9 +33,7 @@ total 15 types */ - - -static dte_data_representation_t* ompi_datatype_2_dte_data_rep[OPAL_DATATYPE_MAX_PREDEFINED] = { +static dte_data_representation_t* ompi_datatype_2_dte_data_rep[OMPI_DATATYPE_MAX_PREDEFINED] = { &DTE_ZERO, /*OPAL_DATATYPE_LOOP 0 */ &DTE_ZERO, /*OPAL_DATATYPE_END_LOOP 1 */ &DTE_ZERO, /*OPAL_DATATYPE_LB 2 */ @@ -53,34 +53,114 @@ static dte_data_representation_t* ompi_datatype_2_dte_data_rep[OPAL_DATATYPE_MAX &DTE_FLOAT64, /*OPAL_DATATYPE_FLOAT8 16 */ &DTE_FLOAT96, /*OPAL_DATATYPE_FLOAT12 17 */ &DTE_FLOAT128, /*OPAL_DATATYPE_FLOAT16 18 */ -#if defined(DTE_FLOAT32_COMPLEX) && defined(DTE_FLOAT64_COMPLEX) +#if defined(DTE_FLOAT32_COMPLEX) &DTE_FLOAT32_COMPLEX, /*OPAL_DATATYPE_COMPLEX8 19 */ - &DTE_FLOAT64_COMPLEX, /*OPAL_DATATYPE_COMPLEX16 20 */ #else - &DTE_ZERO, /*OPAL_DATATYPE_COMPLEX8 19 */ - &DTE_ZERO, /*OPAL_DATATYPE_COMPLEX16 20 */ + &DTE_ZERO, +#endif +#if defined(DTE_FLOAT64_COMPLEX) + &DTE_FLOAT64_COMPLEX, /*OPAL_DATATYPE_COMPLEX16 20 */ +#else + &DTE_ZERO, +#endif +#if defined(DTE_FLOAT128_COMPLEX) + &DTE_FLOAT128_COMPLEX, /*OPAL_DATATYPE_COMPLEX32 21 */ +#else + &DTE_ZERO, #endif - &DTE_ZERO, /*OPAL_DATATYPE_COMPLEX32 21 */ &DTE_ZERO, /*OPAL_DATATYPE_BOOL 22 */ &DTE_ZERO, /*OPAL_DATATYPE_WCHAR 23 */ &DTE_ZERO /*OPAL_DATATYPE_UNAVAILABLE 24 */ }; -static dte_data_representation_t 
ompi_dtype_2_dte_dtype(ompi_datatype_t *dtype){ +enum { + TRY_FIND_DERIVED, + NO_DERIVED +}; + + +#if HCOLL_API >= HCOLL_VERSION(3,6) +static inline +int hcoll_map_derived_type(ompi_datatype_t *dtype, dte_data_representation_t *new_dte) +{ + int rc; + if (NULL == dtype->args) { + /* predefined type, shouldn't call this */ + return OMPI_SUCCESS; + } + rc = hcoll_create_mpi_type((void*)dtype, new_dte); + return rc == HCOLL_SUCCESS ? OMPI_SUCCESS : OMPI_ERROR; +} + +static dte_data_representation_t find_derived_mapping(ompi_datatype_t *dtype){ + dte_data_representation_t dte = DTE_ZERO; + mca_coll_hcoll_dtype_t *hcoll_dtype; + if (mca_coll_hcoll_component.derived_types_support_enabled) { + int map_found = 0; + ompi_attr_get_c(dtype->d_keyhash, hcoll_type_attr_keyval, + (void**)&hcoll_dtype, &map_found); + if (!map_found) + hcoll_map_derived_type(dtype, &dte); + else + dte = hcoll_dtype->type; + } + + return dte; +} + + + +static inline dte_data_representation_t +ompi_predefined_derived_2_hcoll(int ompi_id) { + switch(ompi_id) { + case OMPI_DATATYPE_MPI_FLOAT_INT: + return DTE_FLOAT_INT; + case OMPI_DATATYPE_MPI_DOUBLE_INT: + return DTE_DOUBLE_INT; + case OMPI_DATATYPE_MPI_LONG_INT: + return DTE_LONG_INT; + case OMPI_DATATYPE_MPI_SHORT_INT: + return DTE_SHORT_INT; + case OMPI_DATATYPE_MPI_LONG_DOUBLE_INT: + return DTE_LONG_DOUBLE_INT; + case OMPI_DATATYPE_MPI_2INT: + return DTE_2INT; + default: + break; + } + return DTE_ZERO; +} +#endif + +static dte_data_representation_t +ompi_dtype_2_hcoll_dtype( ompi_datatype_t *dtype, + const int mode) +{ int ompi_type_id = dtype->id; int opal_type_id = dtype->super.id; - dte_data_representation_t dte_data_rep; - if (!(dtype->super.flags & OPAL_DATATYPE_FLAG_NO_GAPS)) { - ompi_type_id = -1; + dte_data_representation_t dte_data_rep = DTE_ZERO; + + if (ompi_type_id < OMPI_DATATYPE_MPI_MAX_PREDEFINED && + dtype->super.flags & OMPI_DATATYPE_FLAG_PREDEFINED) { + if (opal_type_id > 0 && opal_type_id < OPAL_DATATYPE_MAX_PREDEFINED) { + 
dte_data_rep = *ompi_datatype_2_dte_data_rep[opal_type_id]; + } +#if HCOLL_API >= HCOLL_VERSION(3,6) + else if (TRY_FIND_DERIVED == mode){ + dte_data_rep = ompi_predefined_derived_2_hcoll(ompi_type_id); + } + } else { + if (TRY_FIND_DERIVED == mode) + dte_data_rep = find_derived_mapping(dtype); +#endif } - if (OPAL_UNLIKELY( ompi_type_id < 0 || - ompi_type_id >= OPAL_DATATYPE_MAX_PREDEFINED)){ + if (HCOL_DTE_IS_ZERO(dte_data_rep) && TRY_FIND_DERIVED == mode && + !mca_coll_hcoll_component.hcoll_datatype_fallback) { dte_data_rep = DTE_ZERO; dte_data_rep.rep.in_line_rep.data_handle.in_line.in_line = 0; dte_data_rep.rep.in_line_rep.data_handle.pointer_to_handle = (uint64_t ) &dtype->super; - return dte_data_rep; } - return *ompi_datatype_2_dte_data_rep[opal_type_id]; + return dte_data_rep; } static hcoll_dte_op_t* ompi_op_2_hcoll_op[OMPI_OP_BASE_FORTRAN_OP_MAX + 1] = { @@ -108,4 +188,27 @@ static hcoll_dte_op_t* ompi_op_2_hcolrte_op(ompi_op_t *op) { return ompi_op_2_hcoll_op[op->o_f_to_c_index]; } + +#if HCOLL_API >= HCOLL_VERSION(3,6) +static int hcoll_type_attr_del_fn(MPI_Datatype type, int keyval, void *attr_val, void *extra) { + int ret = OMPI_SUCCESS; + mca_coll_hcoll_dtype_t *dtype = + (mca_coll_hcoll_dtype_t*) attr_val; + + assert(dtype); + if (HCOLL_SUCCESS != (ret = hcoll_dt_destroy(dtype->type))) { + HCOL_ERROR("failed to delete type attr: hcoll_dte_destroy returned %d",ret); + return OMPI_ERROR; + } + opal_free_list_return(&mca_coll_hcoll_component.dtypes, + &dtype->super); + + return OMPI_SUCCESS; +} +#else +static int hcoll_type_attr_del_fn(MPI_Datatype type, int keyval, void *attr_val, void *extra) { + /*Do nothing - it's an old version of hcoll w/o dtypes support */ + return OMPI_SUCCESS; +} +#endif #endif /* COLL_HCOLL_DTYPES_H */ diff --git a/ompi/mca/coll/hcoll/coll_hcoll_module.c b/ompi/mca/coll/hcoll/coll_hcoll_module.c index cc96d230c3..097205afb7 100644 --- a/ompi/mca/coll/hcoll/coll_hcoll_module.c +++ b/ompi/mca/coll/hcoll/coll_hcoll_module.c @@ 
-10,8 +10,10 @@ #include "ompi_config.h" #include "coll_hcoll.h" +#include "coll_hcoll_dtypes.h" int hcoll_comm_attr_keyval; +int hcoll_type_attr_keyval; /* * Initial query function that is invoked during MPI_INIT, allowing @@ -240,6 +242,10 @@ int mca_coll_hcoll_progress(void) } +OBJ_CLASS_INSTANCE(mca_coll_hcoll_dtype_t, + opal_free_list_item_t, + NULL,NULL); + /* * Invoked when there's a new communicator that has been created. * Look at the communicator and decide which set of functions and @@ -317,6 +323,24 @@ mca_coll_hcoll_comm_query(struct ompi_communicator_t *comm, int *priority) HCOL_ERROR("Hcol comm keyval create failed"); return NULL; } + + if (mca_coll_hcoll_component.derived_types_support_enabled) { + copy_fn.attr_datatype_copy_fn = (MPI_Type_internal_copy_attr_function *) MPI_TYPE_NULL_COPY_FN; + del_fn.attr_datatype_delete_fn = hcoll_type_attr_del_fn; + err = ompi_attr_create_keyval(TYPE_ATTR, copy_fn, del_fn, &hcoll_type_attr_keyval, NULL ,0, NULL); + if (OMPI_SUCCESS != err) { + cm->hcoll_enable = 0; + hcoll_finalize(); + opal_progress_unregister(mca_coll_hcoll_progress); + HCOL_ERROR("Hcol type keyval create failed"); + return NULL; + } + } + OBJ_CONSTRUCT(&cm->dtypes, opal_free_list_t); + opal_free_list_init(&cm->dtypes, sizeof(mca_coll_hcoll_dtype_t), + 8, OBJ_CLASS(mca_coll_hcoll_dtype_t), 0, 0, + 32, -1, 32, NULL, 0, NULL, NULL, NULL); + } hcoll_module = OBJ_NEW(mca_coll_hcoll_module_t); diff --git a/ompi/mca/coll/hcoll/coll_hcoll_ops.c b/ompi/mca/coll/hcoll/coll_hcoll_ops.c index f246196442..cc829f86df 100644 --- a/ompi/mca/coll/hcoll/coll_hcoll_ops.c +++ b/ompi/mca/coll/hcoll/coll_hcoll_ops.c @@ -44,9 +44,9 @@ int mca_coll_hcoll_bcast(void *buff, int count, int rc; HCOL_VERBOSE(20,"RUNNING HCOL BCAST"); mca_coll_hcoll_module_t *hcoll_module = (mca_coll_hcoll_module_t*)module; - dtype = ompi_dtype_2_dte_dtype(datatype); - if (OPAL_UNLIKELY((HCOL_DTE_IS_ZERO(dtype) || HCOL_DTE_IS_COMPLEX(dtype))) - && 
mca_coll_hcoll_component.hcoll_datatype_fallback){ + dtype = ompi_dtype_2_hcoll_dtype(datatype, TRY_FIND_DERIVED); + + if (OPAL_UNLIKELY(HCOL_DTE_IS_ZERO(dtype))) { /*If we are here then datatype is not simple predefined datatype */ /*In future we need to add more complex mapping to the dte_data_representation_t */ /* Now use fallback */ @@ -76,11 +76,12 @@ int mca_coll_hcoll_allgather(void *sbuf, int scount, int rc; HCOL_VERBOSE(20,"RUNNING HCOL ALLGATHER"); mca_coll_hcoll_module_t *hcoll_module = (mca_coll_hcoll_module_t*)module; - stype = ompi_dtype_2_dte_dtype(sdtype); - rtype = ompi_dtype_2_dte_dtype(rdtype); - if (OPAL_UNLIKELY((HCOL_DTE_IS_ZERO(stype) || HCOL_DTE_IS_ZERO(rtype) - || HCOL_DTE_IS_COMPLEX(stype) || HCOL_DTE_IS_COMPLEX(rtype))) - && mca_coll_hcoll_component.hcoll_datatype_fallback){ + stype = ompi_dtype_2_hcoll_dtype(sdtype, TRY_FIND_DERIVED); + rtype = ompi_dtype_2_hcoll_dtype(rdtype, TRY_FIND_DERIVED); + if (sbuf == MPI_IN_PLACE) { + stype = rtype; + } + if (OPAL_UNLIKELY(HCOL_DTE_IS_ZERO(stype) || HCOL_DTE_IS_ZERO(rtype))) { /*If we are here then datatype is not simple predefined datatype */ /*In future we need to add more complex mapping to the dte_data_representation_t */ /* Now use fallback */ @@ -117,11 +118,9 @@ int mca_coll_hcoll_allgatherv(const void *sbuf, int scount, int rc; HCOL_VERBOSE(20,"RUNNING HCOL ALLGATHERV"); mca_coll_hcoll_module_t *hcoll_module = (mca_coll_hcoll_module_t*)module; - stype = ompi_dtype_2_dte_dtype(sdtype); - rtype = ompi_dtype_2_dte_dtype(rdtype); - if (OPAL_UNLIKELY((HCOL_DTE_IS_ZERO(stype) || HCOL_DTE_IS_ZERO(rtype) - || HCOL_DTE_IS_COMPLEX(stype) || HCOL_DTE_IS_COMPLEX(rtype))) - && mca_coll_hcoll_component.hcoll_datatype_fallback){ + stype = ompi_dtype_2_hcoll_dtype(sdtype, NO_DERIVED); + rtype = ompi_dtype_2_hcoll_dtype(rdtype, NO_DERIVED); + if (OPAL_UNLIKELY(HCOL_DTE_IS_ZERO(stype) || HCOL_DTE_IS_ZERO(rtype))) { /*If we are here then datatype is not simple predefined datatype */ /*In future we need to 
add more complex mapping to the dte_data_representation_t */ /* Now use fallback */ @@ -161,11 +160,9 @@ int mca_coll_hcoll_gather(const void *sbuf, int scount, int rc; HCOL_VERBOSE(20,"RUNNING HCOL GATHER"); mca_coll_hcoll_module_t *hcoll_module = (mca_coll_hcoll_module_t*)module; - stype = ompi_dtype_2_dte_dtype(sdtype); - rtype = ompi_dtype_2_dte_dtype(rdtype); - if (OPAL_UNLIKELY((HCOL_DTE_IS_ZERO(stype) || HCOL_DTE_IS_ZERO(rtype) - || HCOL_DTE_IS_COMPLEX(stype) || HCOL_DTE_IS_COMPLEX(rtype))) - && mca_coll_hcoll_component.hcoll_datatype_fallback){ + stype = ompi_dtype_2_hcoll_dtype(sdtype, NO_DERIVED); + rtype = ompi_dtype_2_hcoll_dtype(rdtype, NO_DERIVED); + if (OPAL_UNLIKELY(HCOL_DTE_IS_ZERO(stype) || HCOL_DTE_IS_ZERO(rtype))) { /*If we are here then datatype is not simple predefined datatype */ /*In future we need to add more complex mapping to the dte_data_representation_t */ /* Now use fallback */ @@ -201,9 +198,8 @@ int mca_coll_hcoll_allreduce(void *sbuf, void *rbuf, int count, int rc; HCOL_VERBOSE(20,"RUNNING HCOL ALLREDUCE"); mca_coll_hcoll_module_t *hcoll_module = (mca_coll_hcoll_module_t*)module; - Dtype = ompi_dtype_2_dte_dtype(dtype); - if (OPAL_UNLIKELY((HCOL_DTE_IS_ZERO(Dtype) || HCOL_DTE_IS_COMPLEX(Dtype))) - && mca_coll_hcoll_component.hcoll_datatype_fallback){ + Dtype = ompi_dtype_2_hcoll_dtype(dtype, NO_DERIVED); + if (OPAL_UNLIKELY(HCOL_DTE_IS_ZERO(Dtype))){ /*If we are here then datatype is not simple predefined datatype */ /*In future we need to add more complex mapping to the dte_data_representation_t */ /* Now use fallback */ @@ -250,9 +246,8 @@ int mca_coll_hcoll_reduce(const void *sbuf, void *rbuf, int count, int rc; HCOL_VERBOSE(20,"RUNNING HCOL REDUCE"); mca_coll_hcoll_module_t *hcoll_module = (mca_coll_hcoll_module_t*)module; - Dtype = ompi_dtype_2_dte_dtype(dtype); - if (OPAL_UNLIKELY((HCOL_DTE_IS_ZERO(Dtype) || HCOL_DTE_IS_COMPLEX(Dtype))) - && mca_coll_hcoll_component.hcoll_datatype_fallback){ + Dtype = 
ompi_dtype_2_hcoll_dtype(dtype, NO_DERIVED); + if (OPAL_UNLIKELY(HCOL_DTE_IS_ZERO(Dtype))){ /*If we are here then datatype is not simple predefined datatype */ /*In future we need to add more complex mapping to the dte_data_representation_t */ /* Now use fallback */ @@ -302,11 +297,9 @@ int mca_coll_hcoll_alltoall(const void *sbuf, int scount, int rc; HCOL_VERBOSE(20,"RUNNING HCOL ALLTOALL"); mca_coll_hcoll_module_t *hcoll_module = (mca_coll_hcoll_module_t*)module; - stype = ompi_dtype_2_dte_dtype(sdtype); - rtype = ompi_dtype_2_dte_dtype(rdtype); - if (OPAL_UNLIKELY((HCOL_DTE_IS_ZERO(stype) || HCOL_DTE_IS_ZERO(rtype) - || HCOL_DTE_IS_COMPLEX(stype) || HCOL_DTE_IS_COMPLEX(rtype))) - && mca_coll_hcoll_component.hcoll_datatype_fallback){ + stype = ompi_dtype_2_hcoll_dtype(sdtype, NO_DERIVED); + rtype = ompi_dtype_2_hcoll_dtype(rdtype, NO_DERIVED); + if (OPAL_UNLIKELY(HCOL_DTE_IS_ZERO(stype) || HCOL_DTE_IS_ZERO(rtype))) { /*If we are here then datatype is not simple predefined datatype */ /*In future we need to add more complex mapping to the dte_data_representation_t */ /* Now use fallback */ @@ -342,11 +335,9 @@ int mca_coll_hcoll_alltoallv(void *sbuf, int *scounts, int *sdisps, int rc; HCOL_VERBOSE(20,"RUNNING HCOL ALLTOALLV"); mca_coll_hcoll_module_t *hcoll_module = (mca_coll_hcoll_module_t*)module; - stype = ompi_dtype_2_dte_dtype(sdtype); - rtype = ompi_dtype_2_dte_dtype(rdtype); - if (OPAL_UNLIKELY((HCOL_DTE_IS_ZERO(stype) || HCOL_DTE_IS_ZERO(rtype) - || HCOL_DTE_IS_COMPLEX(stype) || HCOL_DTE_IS_COMPLEX(rtype))) - && mca_coll_hcoll_component.hcoll_datatype_fallback){ + stype = ompi_dtype_2_hcoll_dtype(sdtype, NO_DERIVED); + rtype = ompi_dtype_2_hcoll_dtype(rdtype, NO_DERIVED); + if (OPAL_UNLIKELY(HCOL_DTE_IS_ZERO(stype) || HCOL_DTE_IS_ZERO(rtype))) { HCOL_VERBOSE(20,"Ompi_datatype is not supported: sdtype = %s, rdtype = %s; calling fallback alltoallv;", sdtype->super.name, rdtype->super.name); @@ -380,11 +371,9 @@ int mca_coll_hcoll_gatherv(void* sbuf, int 
scount, int rc; HCOL_VERBOSE(20,"RUNNING HCOL GATHERV"); mca_coll_hcoll_module_t *hcoll_module = (mca_coll_hcoll_module_t*)module; - stype = ompi_dtype_2_dte_dtype(sdtype); - rtype = ompi_dtype_2_dte_dtype(rdtype); - if (OPAL_UNLIKELY((HCOL_DTE_IS_ZERO(stype) || HCOL_DTE_IS_ZERO(rtype) - || HCOL_DTE_IS_COMPLEX(stype) || HCOL_DTE_IS_COMPLEX(rtype))) - && mca_coll_hcoll_component.hcoll_datatype_fallback){ + stype = ompi_dtype_2_hcoll_dtype(sdtype, NO_DERIVED); + rtype = ompi_dtype_2_hcoll_dtype(rdtype, NO_DERIVED); + if (OPAL_UNLIKELY(HCOL_DTE_IS_ZERO(stype) || HCOL_DTE_IS_ZERO(rtype))) { /*If we are here then datatype is not simple predefined datatype */ /*In future we need to add more complex mapping to the dte_data_representation_t */ /* Now use fallback */ @@ -436,9 +425,8 @@ int mca_coll_hcoll_ibcast(void *buff, int count, HCOL_VERBOSE(20,"RUNNING HCOL NON-BLOCKING BCAST"); mca_coll_hcoll_module_t *hcoll_module = (mca_coll_hcoll_module_t*)module; rt_handle = (void**) request; - dtype = ompi_dtype_2_dte_dtype(datatype); - if (OPAL_UNLIKELY((HCOL_DTE_IS_ZERO(dtype) || HCOL_DTE_IS_COMPLEX(dtype))) - && mca_coll_hcoll_component.hcoll_datatype_fallback){ + dtype = ompi_dtype_2_hcoll_dtype(datatype, TRY_FIND_DERIVED); + if (OPAL_UNLIKELY(HCOL_DTE_IS_ZERO(dtype))){ /*If we are here then datatype is not simple predefined datatype */ /*In future we need to add more complex mapping to the dte_data_representation_t */ /* Now use fallback */ @@ -471,11 +459,9 @@ int mca_coll_hcoll_iallgather(void *sbuf, int scount, HCOL_VERBOSE(20,"RUNNING HCOL NON-BLOCKING ALLGATHER"); mca_coll_hcoll_module_t *hcoll_module = (mca_coll_hcoll_module_t*)module; rt_handle = (void**) request; - stype = ompi_dtype_2_dte_dtype(sdtype); - rtype = ompi_dtype_2_dte_dtype(rdtype); - if (OPAL_UNLIKELY((HCOL_DTE_IS_ZERO(stype) || HCOL_DTE_IS_ZERO(rtype) - || HCOL_DTE_IS_COMPLEX(stype) || HCOL_DTE_IS_COMPLEX(rtype))) - && mca_coll_hcoll_component.hcoll_datatype_fallback){ + stype = 
ompi_dtype_2_hcoll_dtype(sdtype, TRY_FIND_DERIVED); + rtype = ompi_dtype_2_hcoll_dtype(rdtype, TRY_FIND_DERIVED); + if (OPAL_UNLIKELY(HCOL_DTE_IS_ZERO(stype) || HCOL_DTE_IS_ZERO(rtype))) { /*If we are here then datatype is not simple predefined datatype */ /*In future we need to add more complex mapping to the dte_data_representation_t */ /* Now use fallback */ @@ -516,12 +502,10 @@ int mca_coll_hcoll_iallgatherv(const void *sbuf, int scount, int rc; HCOL_VERBOSE(20,"RUNNING HCOL NON-BLOCKING ALLGATHERV"); mca_coll_hcoll_module_t *hcoll_module = (mca_coll_hcoll_module_t*)module; - stype = ompi_dtype_2_dte_dtype(sdtype); - rtype = ompi_dtype_2_dte_dtype(rdtype); + stype = ompi_dtype_2_hcoll_dtype(sdtype, NO_DERIVED); + rtype = ompi_dtype_2_hcoll_dtype(rdtype, NO_DERIVED); void **rt_handle = (void **) request; - if (OPAL_UNLIKELY((HCOL_DTE_IS_ZERO(stype) || HCOL_DTE_IS_ZERO(rtype) - || HCOL_DTE_IS_COMPLEX(stype) || HCOL_DTE_IS_COMPLEX(rtype))) - && mca_coll_hcoll_component.hcoll_datatype_fallback){ + if (OPAL_UNLIKELY(HCOL_DTE_IS_ZERO(stype) || HCOL_DTE_IS_ZERO(rtype))) { /*If we are here then datatype is not simple predefined datatype */ /*In future we need to add more complex mapping to the dte_data_representation_t */ /* Now use fallback */ @@ -566,9 +550,8 @@ int mca_coll_hcoll_iallreduce(const void *sbuf, void *rbuf, int count, HCOL_VERBOSE(20,"RUNNING HCOL NON-BLOCKING ALLREDUCE"); mca_coll_hcoll_module_t *hcoll_module = (mca_coll_hcoll_module_t*)module; rt_handle = (void**) request; - Dtype = ompi_dtype_2_dte_dtype(dtype); - if (OPAL_UNLIKELY((HCOL_DTE_IS_ZERO(Dtype) || HCOL_DTE_IS_COMPLEX(Dtype))) - && mca_coll_hcoll_component.hcoll_datatype_fallback){ + Dtype = ompi_dtype_2_hcoll_dtype(dtype, NO_DERIVED); + if (OPAL_UNLIKELY(HCOL_DTE_IS_ZERO(Dtype))){ /*If we are here then datatype is not simple predefined datatype */ /*In future we need to add more complex mapping to the dte_data_representation_t */ /* Now use fallback */ @@ -616,10 +599,9 @@ int 
mca_coll_hcoll_ireduce(const void *sbuf, void *rbuf, int count, int rc; HCOL_VERBOSE(20,"RUNNING HCOL NON-BLOCKING REDUCE"); mca_coll_hcoll_module_t *hcoll_module = (mca_coll_hcoll_module_t*)module; - Dtype = ompi_dtype_2_dte_dtype(dtype); + Dtype = ompi_dtype_2_hcoll_dtype(dtype, NO_DERIVED); void **rt_handle = (void**) request; - if (OPAL_UNLIKELY((HCOL_DTE_IS_ZERO(Dtype) || HCOL_DTE_IS_COMPLEX(Dtype))) - && mca_coll_hcoll_component.hcoll_datatype_fallback){ + if (OPAL_UNLIKELY(HCOL_DTE_IS_ZERO(Dtype))){ /*If we are here then datatype is not simple predefined datatype */ /*In future we need to add more complex mapping to the dte_data_representation_t */ /* Now use fallback */ @@ -676,11 +658,9 @@ int mca_coll_hcoll_igatherv(const void* sbuf, int scount, HCOL_VERBOSE(20,"RUNNING HCOL IGATHERV"); mca_coll_hcoll_module_t *hcoll_module = (mca_coll_hcoll_module_t*)module; rt_handle = (void**) request; - stype = ompi_dtype_2_dte_dtype(sdtype); - rtype = ompi_dtype_2_dte_dtype(rdtype); - if (OPAL_UNLIKELY((HCOL_DTE_IS_ZERO(stype) || HCOL_DTE_IS_ZERO(rtype) - || HCOL_DTE_IS_COMPLEX(stype) || HCOL_DTE_IS_COMPLEX(rtype))) - && mca_coll_hcoll_component.hcoll_datatype_fallback){ + stype = ompi_dtype_2_hcoll_dtype(sdtype, NO_DERIVED); + rtype = ompi_dtype_2_hcoll_dtype(rdtype, NO_DERIVED); + if (OPAL_UNLIKELY(HCOL_DTE_IS_ZERO(stype) || HCOL_DTE_IS_ZERO(rtype))) { /*If we are here then datatype is not simple predefined datatype */ /*In future we need to add more complex mapping to the dte_data_representation_t */ /* Now use fallback */ diff --git a/ompi/mca/coll/hcoll/coll_hcoll_rte.c b/ompi/mca/coll/hcoll/coll_hcoll_rte.c index 1ac0177c1e..f86846e527 100644 --- a/ompi/mca/coll/hcoll/coll_hcoll_rte.c +++ b/ompi/mca/coll/hcoll/coll_hcoll_rte.c @@ -44,6 +44,7 @@ #include "hcoll/api/hcoll_dte.h" #include "hcoll/api/hcoll_api.h" #include "hcoll/api/hcoll_constants.h" +#include "coll_hcoll_dtypes.h" /* * Local functions */ @@ -99,6 +100,22 @@ static int group_id(rte_grp_handle_t 
group); static int world_rank(rte_grp_handle_t grp_h, rte_ec_handle_t ec); /* Module Constructors */ +#if HCOLL_API >= HCOLL_VERSION(3,6) +static int get_mpi_type_envelope(void *mpi_type, int *num_integers, + int *num_addresses, int *num_datatypes, + hcoll_mpi_type_combiner_t *combiner); +static int get_mpi_type_contents(void *mpi_type, int max_integers, int max_addresses, + int max_datatypes, int *array_of_integers, + void *array_of_addresses, void *array_of_datatypes); +static int get_hcoll_type(void *mpi_type, dte_data_representation_t *hcoll_type); +static int set_hcoll_type(void *mpi_type, dte_data_representation_t hcoll_type); +static int get_mpi_constants(size_t *mpi_datatype_size, + int *mpi_order_c, int *mpi_order_fortran, + int *mpi_distribute_block, + int *mpi_distribute_cyclic, + int *mpi_distribute_none, + int *mpi_distribute_dflt_darg); +#endif static void init_module_fns(void){ hcoll_rte_functions.send_fn = send_nb; @@ -118,6 +135,13 @@ static void init_module_fns(void){ hcoll_rte_functions.rte_coll_handle_complete_fn = coll_handle_complete; hcoll_rte_functions.rte_group_id_fn = group_id; hcoll_rte_functions.rte_world_rank_fn = world_rank; +#if HCOLL_API >= HCOLL_VERSION(3,6) + hcoll_rte_functions.rte_get_mpi_type_envelope_fn = get_mpi_type_envelope; + hcoll_rte_functions.rte_get_mpi_type_contents_fn = get_mpi_type_contents; + hcoll_rte_functions.rte_get_hcoll_type_fn = get_hcoll_type; + hcoll_rte_functions.rte_set_hcoll_type_fn = set_hcoll_type; + hcoll_rte_functions.rte_get_mpi_constants_fn = get_mpi_constants; +#endif } @@ -146,22 +170,6 @@ void hcoll_rte_fns_setup(void) ); } -/* This one converts dte_general_representation data into regular iovec array which is - used in rml - */ - -static inline int count_total_dte_repeat_entries(struct dte_data_representation_t *data){ - unsigned int i; - - struct dte_generalized_iovec_t * dte_iovec = - data->rep.general_rep->data_representation.data; - int total_entries_number = 0; - for (i=0; i< 
dte_iovec->repeat_count; i++){ - total_entries_number += dte_iovec->repeat[i].n_elements; - } - return total_entries_number; -} - static int recv_nb(struct dte_data_representation_t data, uint32_t count , void *buffer, @@ -177,56 +185,27 @@ static int recv_nb(struct dte_data_representation_t data, "ec_h.handle = %p, ec_h.rank = %d\n",ec_h.handle,ec_h.rank); return 1; } - if (HCOL_DTE_IS_INLINE(data)){ - /*do inline nb recv*/ - size_t size; - ompi_request_t *ompi_req; - opal_free_list_item_t *item; - - if (!buffer && !HCOL_DTE_IS_ZERO(data)) { - fprintf(stderr, "***Error in hcolrte_rml_recv_nb: buffer pointer is NULL" - " for non DTE_ZERO INLINE data representation\n"); - return 1; - } - size = (size_t)data.rep.in_line_rep.data_handle.in_line.packed_size*count/8; - - HCOL_VERBOSE(30,"PML_IRECV: dest = %d: buf = %p: size = %u: comm = %p", - ec_h.rank, buffer, (unsigned int)size, (void *)comm); - if (MCA_PML_CALL(irecv(buffer,size,&(ompi_mpi_unsigned_char.dt),ec_h.rank, - tag,comm,&ompi_req))) - { - return 1; - } - req->data = (void *)ompi_req; - req->status = HCOLRTE_REQUEST_ACTIVE; - }else{ - /*do iovec nb recv*/ - int total_entries_number; - int i; - unsigned int j; - void *buf; - uint64_t len; - int repeat_count; - struct dte_struct_t * repeat; - if (NULL != buffer) { - /* We have a full data description & buffer pointer simultaneously. - It is ambiguous. 
Throw a warning since the user might have made a - mistake with data reps*/ - fprintf(stderr,"Warning: buffer_pointer != NULL for NON-inline data representation: buffer_pointer is ignored.\n"); - } - total_entries_number = count_total_dte_repeat_entries(&data); - repeat = data.rep.general_rep->data_representation.data->repeat; - repeat_count = data.rep.general_rep->data_representation.data->repeat_count; - for (i=0; i< repeat_count; i++){ - for (j=0; jdata = (void *)ompi_req; + req->status = HCOLRTE_REQUEST_ACTIVE; return HCOLL_SUCCESS; } @@ -247,51 +226,25 @@ static int send_nb( dte_data_representation_t data, "ec_h.handle = %p, ec_h.rank = %d\n",ec_h.handle,ec_h.rank); return 1; } - if (HCOL_DTE_IS_INLINE(data)){ - /*do inline nb recv*/ - size_t size; - ompi_request_t *ompi_req; - if (!buffer && !HCOL_DTE_IS_ZERO(data)) { - fprintf(stderr, "***Error in hcolrte_rml_send_nb: buffer pointer is NULL" - " for non DTE_ZERO INLINE data representation\n"); - return 1; - } - size = (size_t)data.rep.in_line_rep.data_handle.in_line.packed_size*count/8; - HCOL_VERBOSE(30,"PML_ISEND: dest = %d: buf = %p: size = %u: comm = %p", - ec_h.rank, buffer, (unsigned int)size, (void *)comm); - if (MCA_PML_CALL(isend(buffer,size,&(ompi_mpi_unsigned_char.dt),ec_h.rank, - tag,MCA_PML_BASE_SEND_STANDARD,comm,&ompi_req))) - { - return 1; - } - req->data = (void *)ompi_req; - req->status = HCOLRTE_REQUEST_ACTIVE; - }else{ - int total_entries_number; - int i; - unsigned int j; - void *buf; - uint64_t len; - int repeat_count; - struct dte_struct_t * repeat; - if (NULL != buffer) { - /* We have a full data description & buffer pointer simultaneously. - It is ambiguous. 
Throw a warning since the user might have made a - mistake with data reps*/ - fprintf(stderr,"Warning: buffer_pointer != NULL for NON-inline data representation: buffer_pointer is ignored.\n"); - } - total_entries_number = count_total_dte_repeat_entries(&data); - repeat = data.rep.general_rep->data_representation.data->repeat; - repeat_count = data.rep.general_rep->data_representation.data->repeat_count; - for (i=0; i< repeat_count; i++){ - for (j=0; jdata = (void *)ompi_req; + req->status = HCOLRTE_REQUEST_ACTIVE; return HCOLL_SUCCESS; } @@ -305,7 +258,7 @@ static int test( rte_request_handle_t * request , } /*ompi_request_test(&ompi_req,completed,MPI_STATUS_IGNORE); */ - *completed = ompi_req->req_complete; + *completed = REQUEST_COMPLETE(ompi_req); if (*completed){ ompi_request_free(&ompi_req); request->status = HCOLRTE_REQUEST_DONE; @@ -413,7 +366,7 @@ static void* get_coll_handle(void) static int coll_handle_test(void* handle) { ompi_request_t *ompi_req = (ompi_request_t *)handle; - return ompi_req->req_complete; + return REQUEST_COMPLETE(ompi_req);; } static void coll_handle_free(void *handle){ @@ -433,3 +386,108 @@ static int world_rank(rte_grp_handle_t grp_h, rte_ec_handle_t ec){ ompi_proc_t *proc = (ompi_proc_t *)ec.handle; return ((ompi_process_name_t*)&proc->super.proc_name)->vpid; } + +#if HCOLL_API >= HCOLL_VERSION(3,6) +hcoll_mpi_type_combiner_t ompi_combiner_2_hcoll_combiner(int ompi_combiner) { + switch (ompi_combiner) + { + case MPI_COMBINER_CONTIGUOUS: + return HCOLL_MPI_COMBINER_CONTIGUOUS; + case MPI_COMBINER_VECTOR: + return HCOLL_MPI_COMBINER_VECTOR; + case MPI_COMBINER_HVECTOR: + return HCOLL_MPI_COMBINER_HVECTOR; + case MPI_COMBINER_INDEXED: + return HCOLL_MPI_COMBINER_INDEXED; + case MPI_COMBINER_HINDEXED_INTEGER: + case MPI_COMBINER_HINDEXED: + return HCOLL_MPI_COMBINER_HINDEXED; + case MPI_COMBINER_DUP: + return HCOLL_MPI_COMBINER_DUP; + case MPI_COMBINER_INDEXED_BLOCK: + return HCOLL_MPI_COMBINER_INDEXED_BLOCK; + case 
MPI_COMBINER_HINDEXED_BLOCK: + return HCOLL_MPI_COMBINER_HINDEXED_BLOCK; + case MPI_COMBINER_SUBARRAY: + return HCOLL_MPI_COMBINER_SUBARRAY; + case MPI_COMBINER_DARRAY: + return HCOLL_MPI_COMBINER_DARRAY; + case MPI_COMBINER_F90_REAL: + return HCOLL_MPI_COMBINER_F90_REAL; + case MPI_COMBINER_F90_COMPLEX: + return HCOLL_MPI_COMBINER_F90_COMPLEX; + case MPI_COMBINER_F90_INTEGER: + return HCOLL_MPI_COMBINER_F90_INTEGER; + case MPI_COMBINER_RESIZED: + return HCOLL_MPI_COMBINER_RESIZED; + case MPI_COMBINER_STRUCT: + case MPI_COMBINER_STRUCT_INTEGER: + return HCOLL_MPI_COMBINER_STRUCT; + default: + break; + } + return HCOLL_MPI_COMBINER_LAST; +} + + +static int get_mpi_type_envelope(void *mpi_type, int *num_integers, + int *num_addresses, int *num_datatypes, + hcoll_mpi_type_combiner_t *combiner) { + int ompi_combiner, rc; + rc = ompi_datatype_get_args( (ompi_datatype_t*)mpi_type, 0, num_integers, NULL, + num_addresses, NULL, + num_datatypes, NULL, &ompi_combiner); + *combiner = ompi_combiner_2_hcoll_combiner(ompi_combiner); + return rc == OMPI_SUCCESS ? HCOLL_SUCCESS : HCOLL_ERROR; +} + +static int get_mpi_type_contents(void *mpi_type, int max_integers, int max_addresses, + int max_datatypes, int *array_of_integers, + void *array_of_addresses, void *array_of_datatypes) { + int rc; + rc = ompi_datatype_get_args( (ompi_datatype_t*)mpi_type, 1, &max_integers, array_of_integers, + &max_addresses, array_of_addresses, + &max_datatypes, array_of_datatypes, NULL ); + return rc == OMPI_SUCCESS ? HCOLL_SUCCESS : HCOLL_ERROR; +} + +static int get_hcoll_type(void *mpi_type, dte_data_representation_t *hcoll_type) { + *hcoll_type = ompi_dtype_2_hcoll_dtype((ompi_datatype_t*)mpi_type, TRY_FIND_DERIVED); + return HCOL_DTE_IS_ZERO((*hcoll_type)) ? 
HCOLL_ERR_NOT_FOUND : HCOLL_SUCCESS; +} + +static int set_hcoll_type(void *mpi_type, dte_data_representation_t hcoll_type) { + int rc; + mca_coll_hcoll_dtype_t *hcoll_dtype = (mca_coll_hcoll_dtype_t*) + opal_free_list_get(&mca_coll_hcoll_component.dtypes); + ompi_datatype_t *dtype = (ompi_datatype_t*)mpi_type; + hcoll_dtype->type = hcoll_type; + rc = ompi_attr_set_c(TYPE_ATTR, (void*)dtype, &(dtype->d_keyhash), hcoll_type_attr_keyval, (void *)hcoll_dtype, false); + if (OMPI_SUCCESS != rc) { + HCOL_VERBOSE(1,"hcoll ompi_attr_set_c failed for derived dtype"); + goto Cleanup; + } + return HCOLL_SUCCESS; +Cleanup: + opal_free_list_return(&mca_coll_hcoll_component.dtypes, + &hcoll_dtype->super); + return rc; +} + +static int get_mpi_constants(size_t *mpi_datatype_size, + int *mpi_order_c, int *mpi_order_fortran, + int *mpi_distribute_block, + int *mpi_distribute_cyclic, + int *mpi_distribute_none, + int *mpi_distribute_dflt_darg) { + *mpi_datatype_size = sizeof(MPI_Datatype); + *mpi_order_c = MPI_ORDER_C; + *mpi_order_fortran = MPI_ORDER_FORTRAN; + *mpi_distribute_block = MPI_DISTRIBUTE_BLOCK; + *mpi_distribute_cyclic = MPI_DISTRIBUTE_CYCLIC; + *mpi_distribute_none = MPI_DISTRIBUTE_NONE; + *mpi_distribute_dflt_darg = MPI_DISTRIBUTE_DFLT_DARG; + return HCOLL_SUCCESS; +} + +#endif