Skip to content
This repository was archived by the owner on Sep 30, 2022. It is now read-only.

coll/hcoll mpi datatypes support #1333

Merged
merged 2 commits into from
Aug 23, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions ompi/mca/coll/hcoll/coll_hcoll.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ typedef struct mca_coll_hcoll_ops_t {
int (*hcoll_barrier)(void *);
} mca_coll_hcoll_ops_t;

/*
 * Free-list element that carries one hcoll datatype representation.
 * Elements of this type live on mca_coll_hcoll_component.dtypes and are
 * handed back to that list by hcoll_type_attr_del_fn().
 */
typedef struct {
/* Free-list linkage; this is the member passed to opal_free_list_return(). */
opal_free_list_item_t super;
/* The hcoll representation; read by find_derived_mapping() and released
 * through hcoll_dt_destroy() in the attribute delete callback. */
dte_data_representation_t type;
} mca_coll_hcoll_dtype_t;
/* Declares the OPAL class for the type above; the matching
 * OBJ_CLASS_INSTANCE lives in coll_hcoll_module.c. */
OBJ_CLASS_DECLARATION(mca_coll_hcoll_dtype_t);

struct mca_coll_hcoll_component_t {
/** Base coll component */
Expand Down Expand Up @@ -89,6 +94,8 @@ struct mca_coll_hcoll_component_t {
/* FCA global stuff */
mca_coll_hcoll_ops_t hcoll_ops;
opal_free_list_t requests;
opal_free_list_t dtypes;
int derived_types_support_enabled;
};
typedef struct mca_coll_hcoll_component_t mca_coll_hcoll_component_t;

Expand Down
13 changes: 11 additions & 2 deletions ompi/mca/coll/hcoll/coll_hcoll_component.c
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#include "coll_hcoll.h"
#include "opal/mca/installdirs/installdirs.h"
#include "coll_hcoll_dtypes.h"

/*
* Public string showing the coll ompi_hcol component version number
Expand Down Expand Up @@ -207,7 +208,15 @@ static int hcoll_register(void)
1,
&mca_coll_hcoll_component.hcoll_datatype_fallback,
0));

#if HCOLL_API >= HCOLL_VERSION(3,6)
CHECK(reg_int("dts",NULL,
"[1|0|] Enable/Disable derived types support",
1,
&mca_coll_hcoll_component.derived_types_support_enabled,
0));
#else
mca_coll_hcoll_component.derived_types_support_enabled = 0;
#endif
mca_coll_hcoll_component.compiletime_version = HCOLL_VERNO_STRING;
mca_base_component_var_register(&mca_coll_hcoll_component.super.collm_version,
MCA_COMPILETIME_VER,
Expand Down Expand Up @@ -278,7 +287,7 @@ static int hcoll_close(void)

HCOL_VERBOSE(5,"HCOLL FINALIZE");
rc = hcoll_finalize();

OBJ_DESTRUCT(&cm->dtypes);
opal_progress_unregister(mca_coll_hcoll_progress);
if (HCOLL_SUCCESS != rc){
HCOL_VERBOSE(1,"Hcol library finalize failed");
Expand Down
135 changes: 119 additions & 16 deletions ompi/mca/coll/hcoll/coll_hcoll_dtypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
It is used to extract allreduce bcol functions where the arithmetic has to be done*/

#include "ompi/datatype/ompi_datatype.h"
#include "ompi/datatype/ompi_datatype_internal.h"
#include "ompi/mca/op/op.h"
#include "hcoll/api/hcoll_dte.h"
extern int hcoll_type_attr_keyval;

/*to keep this at hand: Ids of the basic opal_datatypes:
#define OPAL_DATATYPE_INT1 4
Expand All @@ -31,9 +33,7 @@
total 15 types
*/



static dte_data_representation_t* ompi_datatype_2_dte_data_rep[OPAL_DATATYPE_MAX_PREDEFINED] = {
static dte_data_representation_t* ompi_datatype_2_dte_data_rep[OMPI_DATATYPE_MAX_PREDEFINED] = {
&DTE_ZERO, /*OPAL_DATATYPE_LOOP 0 */
&DTE_ZERO, /*OPAL_DATATYPE_END_LOOP 1 */
&DTE_ZERO, /*OPAL_DATATYPE_LB 2 */
Expand All @@ -53,34 +53,114 @@ static dte_data_representation_t* ompi_datatype_2_dte_data_rep[OPAL_DATATYPE_MAX
&DTE_FLOAT64, /*OPAL_DATATYPE_FLOAT8 16 */
&DTE_FLOAT96, /*OPAL_DATATYPE_FLOAT12 17 */
&DTE_FLOAT128, /*OPAL_DATATYPE_FLOAT16 18 */
#if defined(DTE_FLOAT32_COMPLEX) && defined(DTE_FLOAT64_COMPLEX)
#if defined(DTE_FLOAT32_COMPLEX)
&DTE_FLOAT32_COMPLEX, /*OPAL_DATATYPE_COMPLEX8 19 */
&DTE_FLOAT64_COMPLEX, /*OPAL_DATATYPE_COMPLEX16 20 */
#else
&DTE_ZERO, /*OPAL_DATATYPE_COMPLEX8 19 */
&DTE_ZERO, /*OPAL_DATATYPE_COMPLEX16 20 */
&DTE_ZERO,
#endif
#if defined(DTE_FLOAT64_COMPLEX)
&DTE_FLOAT64_COMPLEX, /*OPAL_DATATYPE_COMPLEX32 20 */
#else
&DTE_ZERO,
#endif
#if defined(DTE_FLOAT128_COMPLEX)
&DTE_FLOAT128_COMPLEX, /*OPAL_DATATYPE_COMPLEX64 21 */
#else
&DTE_ZERO,
#endif
&DTE_ZERO, /*OPAL_DATATYPE_COMPLEX32 21 */
&DTE_ZERO, /*OPAL_DATATYPE_BOOL 22 */
&DTE_ZERO, /*OPAL_DATATYPE_WCHAR 23 */
&DTE_ZERO /*OPAL_DATATYPE_UNAVAILABLE 24 */
};

static dte_data_representation_t ompi_dtype_2_dte_dtype(ompi_datatype_t *dtype){
/* Values for the `mode` argument of ompi_dtype_2_hcoll_dtype(). */
enum {
/* Allow the lookup to fall through to the derived-type mapping paths
 * (ompi_predefined_derived_2_hcoll() / find_derived_mapping()). */
TRY_FIND_DERIVED,
/* Any mode other than TRY_FIND_DERIVED skips those paths; presumably used
 * by callers that can only handle basic types - not referenced in the
 * visible code, confirm against callers. */
NO_DERIVED
};


#if HCOLL_API >= HCOLL_VERSION(3,6)
/**
 * Ask hcoll to create its own representation of an OMPI datatype.
 *
 * @param dtype    OMPI datatype to map.
 * @param new_dte  Output location handed to hcoll_create_mpi_type().
 *
 * @return OMPI_SUCCESS when hcoll reports HCOLL_SUCCESS, or when the
 *         datatype has no construction arguments (in which case hcoll is
 *         not called and *new_dte is left untouched); OMPI_ERROR otherwise.
 */
static inline
int hcoll_map_derived_type(ompi_datatype_t *dtype, dte_data_representation_t *new_dte)
{
    /* A datatype without args is a predefined one; callers are not expected
     * to pass those here, so there is simply nothing to create. */
    if (NULL != dtype->args) {
        if (HCOLL_SUCCESS != hcoll_create_mpi_type((void*)dtype, new_dte)) {
            return OMPI_ERROR;
        }
    }

    return OMPI_SUCCESS;
}

/**
 * Resolve the hcoll representation of a derived OMPI datatype.
 *
 * The datatype's attribute table is consulted first (under
 * hcoll_type_attr_keyval); if nothing is attached there, hcoll is asked to
 * build a mapping via hcoll_map_derived_type().
 *
 * @param dtype  OMPI datatype to resolve.
 *
 * @return The attached or newly created representation, or DTE_ZERO when
 *         derived-type support is disabled, when the attached attribute
 *         value is NULL, or when creating the mapping fails.
 */
static dte_data_representation_t find_derived_mapping(ompi_datatype_t *dtype){
    dte_data_representation_t dte = DTE_ZERO;
    mca_coll_hcoll_dtype_t *hcoll_dtype = NULL;
    int map_found = 0;

    if (!mca_coll_hcoll_component.derived_types_support_enabled) {
        return dte;
    }

    ompi_attr_get_c(dtype->d_keyhash, hcoll_type_attr_keyval,
                    (void**)&hcoll_dtype, &map_found);

    if (map_found) {
        /* Guard against a found-but-empty attribute instead of
         * dereferencing a NULL pointer. */
        if (NULL != hcoll_dtype) {
            dte = hcoll_dtype->type;
        }
    } else if (OMPI_SUCCESS != hcoll_map_derived_type(dtype, &dte)) {
        /* Creation failed: do not hand back whatever hcoll may have left in
         * dte; DTE_ZERO lets the caller treat this as "no mapping". */
        dte = DTE_ZERO;
    }

    return dte;
}



/**
 * Translate the id of a predefined "pair" MPI datatype into the matching
 * hcoll representation.
 *
 * @param ompi_id  OMPI datatype id (one of the OMPI_DATATYPE_MPI_* values).
 *
 * @return The corresponding DTE_* representation for FLOAT_INT, DOUBLE_INT,
 *         LONG_INT, SHORT_INT, LONG_DOUBLE_INT and 2INT; DTE_ZERO for every
 *         other id.
 */
static inline dte_data_representation_t
ompi_predefined_derived_2_hcoll(int ompi_id) {
    dte_data_representation_t mapped = DTE_ZERO;

    switch (ompi_id) {
    case OMPI_DATATYPE_MPI_FLOAT_INT:
        mapped = DTE_FLOAT_INT;
        break;
    case OMPI_DATATYPE_MPI_DOUBLE_INT:
        mapped = DTE_DOUBLE_INT;
        break;
    case OMPI_DATATYPE_MPI_LONG_INT:
        mapped = DTE_LONG_INT;
        break;
    case OMPI_DATATYPE_MPI_SHORT_INT:
        mapped = DTE_SHORT_INT;
        break;
    case OMPI_DATATYPE_MPI_LONG_DOUBLE_INT:
        mapped = DTE_LONG_DOUBLE_INT;
        break;
    case OMPI_DATATYPE_MPI_2INT:
        mapped = DTE_2INT;
        break;
    default:
        /* No hcoll counterpart: keep DTE_ZERO. */
        break;
    }

    return mapped;
}
#endif

static dte_data_representation_t
ompi_dtype_2_hcoll_dtype( ompi_datatype_t *dtype,
const int mode)
{
int ompi_type_id = dtype->id;
int opal_type_id = dtype->super.id;
dte_data_representation_t dte_data_rep;
if (!(dtype->super.flags & OPAL_DATATYPE_FLAG_NO_GAPS)) {
ompi_type_id = -1;
dte_data_representation_t dte_data_rep = DTE_ZERO;

if (ompi_type_id < OMPI_DATATYPE_MPI_MAX_PREDEFINED &&
dtype->super.flags & OMPI_DATATYPE_FLAG_PREDEFINED) {
if (opal_type_id > 0 && opal_type_id < OPAL_DATATYPE_MAX_PREDEFINED) {
dte_data_rep = *ompi_datatype_2_dte_data_rep[opal_type_id];
}
#if HCOLL_API >= HCOLL_VERSION(3,6)
else if (TRY_FIND_DERIVED == mode){
dte_data_rep = ompi_predefined_derived_2_hcoll(ompi_type_id);
}
} else {
if (TRY_FIND_DERIVED == mode)
dte_data_rep = find_derived_mapping(dtype);
#endif
}
if (OPAL_UNLIKELY( ompi_type_id < 0 ||
ompi_type_id >= OPAL_DATATYPE_MAX_PREDEFINED)){
if (HCOL_DTE_IS_ZERO(dte_data_rep) && TRY_FIND_DERIVED == mode &&
!mca_coll_hcoll_component.hcoll_datatype_fallback) {
dte_data_rep = DTE_ZERO;
dte_data_rep.rep.in_line_rep.data_handle.in_line.in_line = 0;
dte_data_rep.rep.in_line_rep.data_handle.pointer_to_handle = (uint64_t ) &dtype->super;
return dte_data_rep;
}
return *ompi_datatype_2_dte_data_rep[opal_type_id];
return dte_data_rep;
}

static hcoll_dte_op_t* ompi_op_2_hcoll_op[OMPI_OP_BASE_FORTRAN_OP_MAX + 1] = {
Expand Down Expand Up @@ -108,4 +188,27 @@ static hcoll_dte_op_t* ompi_op_2_hcolrte_op(ompi_op_t *op) {
return ompi_op_2_hcoll_op[op->o_f_to_c_index];
}


#if HCOLL_API >= HCOLL_VERSION(3,6)
/**
 * Datatype attribute delete callback (hcoll API >= 3.6).
 *
 * Destroys the hcoll datatype stored in the attribute value and, on
 * success, hands the wrapper item back to the component's dtypes free list.
 *
 * @param type      Datatype the attribute is attached to (unused).
 * @param keyval    Attribute keyval (unused).
 * @param attr_val  The attribute value; a mca_coll_hcoll_dtype_t pointer.
 * @param extra     Extra state (unused).
 *
 * @return OMPI_SUCCESS when hcoll_dt_destroy() succeeds, OMPI_ERROR
 *         otherwise (the item is then NOT returned to the free list).
 */
static int hcoll_type_attr_del_fn(MPI_Datatype type, int keyval, void *attr_val, void *extra) {
    mca_coll_hcoll_dtype_t *item = (mca_coll_hcoll_dtype_t*) attr_val;
    int rc;

    assert(item);

    rc = hcoll_dt_destroy(item->type);
    if (HCOLL_SUCCESS == rc) {
        opal_free_list_return(&mca_coll_hcoll_component.dtypes,
                              &item->super);
        return OMPI_SUCCESS;
    }

    HCOL_ERROR("failed to delete type attr: hcoll_dte_destroy returned %d",rc);
    return OMPI_ERROR;
}
#else
/**
 * Datatype attribute delete callback for hcoll builds older than 3.6.
 *
 * Those versions have no derived-datatype support, so there is nothing to
 * release; the callback unconditionally reports success.
 *
 * @return OMPI_SUCCESS, always.
 */
static int hcoll_type_attr_del_fn(MPI_Datatype type, int keyval, void *attr_val, void *extra) {
    /* All arguments are intentionally ignored. */
    (void) type;
    (void) keyval;
    (void) attr_val;
    (void) extra;

    return OMPI_SUCCESS;
}
#endif
#endif /* COLL_HCOLL_DTYPES_H */
24 changes: 24 additions & 0 deletions ompi/mca/coll/hcoll/coll_hcoll_module.c
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@

#include "ompi_config.h"
#include "coll_hcoll.h"
#include "coll_hcoll_dtypes.h"

/* Keyval for the communicator attribute; presumably created in
 * mca_coll_hcoll_comm_query() (only its failure path is visible here) -
 * confirm against the full file. */
int hcoll_comm_attr_keyval;
/* Keyval for the datatype attribute (TYPE_ATTR). Created in
 * mca_coll_hcoll_comm_query() when derived-type support is enabled and
 * looked up by find_derived_mapping(); declared extern in
 * coll_hcoll_dtypes.h. */
int hcoll_type_attr_keyval;

/*
* Initial query function that is invoked during MPI_INIT, allowing
Expand Down Expand Up @@ -240,6 +242,10 @@ int mca_coll_hcoll_progress(void)
}


/* Class instance for mca_coll_hcoll_dtype_t with opal_free_list_item_t as
 * its parent. The two NULL arguments are presumably the constructor and
 * destructor slots (none supplied) - confirm against opal_object.h. */
OBJ_CLASS_INSTANCE(mca_coll_hcoll_dtype_t,
opal_free_list_item_t,
NULL,NULL);

/*
* Invoked when there's a new communicator that has been created.
* Look at the communicator and decide which set of functions and
Expand Down Expand Up @@ -317,6 +323,24 @@ mca_coll_hcoll_comm_query(struct ompi_communicator_t *comm, int *priority)
HCOL_ERROR("Hcol comm keyval create failed");
return NULL;
}

if (mca_coll_hcoll_component.derived_types_support_enabled) {
copy_fn.attr_datatype_copy_fn = (MPI_Type_internal_copy_attr_function *) MPI_TYPE_NULL_COPY_FN;
del_fn.attr_datatype_delete_fn = hcoll_type_attr_del_fn;
err = ompi_attr_create_keyval(TYPE_ATTR, copy_fn, del_fn, &hcoll_type_attr_keyval, NULL ,0, NULL);
if (OMPI_SUCCESS != err) {
cm->hcoll_enable = 0;
hcoll_finalize();
opal_progress_unregister(mca_coll_hcoll_progress);
HCOL_ERROR("Hcol type keyval create failed");
return NULL;
}
}
OBJ_CONSTRUCT(&cm->dtypes, opal_free_list_t);
opal_free_list_init(&cm->dtypes, sizeof(mca_coll_hcoll_dtype_t),
8, OBJ_CLASS(mca_coll_hcoll_dtype_t), 0, 0,
32, -1, 32, NULL, 0, NULL, NULL, NULL);

}

hcoll_module = OBJ_NEW(mca_coll_hcoll_module_t);
Expand Down
Loading