From 3ab33719b30888460b7cfae90c08f3ae2f73dade Mon Sep 17 00:00:00 2001 From: Joseph Schuchart Date: Thu, 25 Apr 2024 10:53:07 -0400 Subject: [PATCH] Add CUDA/HIP implementations of reduction operators The operators are generated from macros. Function pointers to kernel launch functions are stored inside the ompi_op_t as a pointer to a struct that is filled if accelerator support is available. The ompi_op* API is extended to include versions taking streams and device IDs to allow enqueuing operators on streams. The old functions map to the stream versions with a NULL stream. Signed-off-by: Joseph Schuchart --- config/opal_check_cudart.m4 | 120 ++ ompi/mca/op/base/op_base_frame.c | 4 +- ompi/mca/op/base/op_base_op_select.c | 60 +- ompi/mca/op/cuda/Makefile.am | 84 + ompi/mca/op/cuda/configure.m4 | 41 + ompi/mca/op/cuda/help-ompi-mca-op-cuda.txt | 15 + ompi/mca/op/cuda/op_cuda.h | 80 + ompi/mca/op/cuda/op_cuda_component.c | 195 ++ ompi/mca/op/cuda/op_cuda_functions.c | 1897 +++++++++++++++++++ ompi/mca/op/cuda/op_cuda_impl.cu | 1080 +++++++++++ ompi/mca/op/cuda/op_cuda_impl.h | 695 +++++++ ompi/mca/op/op.h | 66 +- ompi/mca/op/rocm/Makefile.am | 82 + ompi/mca/op/rocm/configure.m4 | 36 + ompi/mca/op/rocm/help-ompi-mca-op-rocm.txt | 15 + ompi/mca/op/rocm/op_rocm.h | 86 + ompi/mca/op/rocm/op_rocm_component.c | 219 +++ ompi/mca/op/rocm/op_rocm_functions.c | 1908 ++++++++++++++++++++ ompi/mca/op/rocm/op_rocm_impl.h | 706 ++++++++ ompi/mca/op/rocm/op_rocm_impl.hip | 1085 +++++++++++ ompi/op/Makefile.am | 2 + ompi/op/help-ompi-op.txt | 15 + ompi/op/op.c | 16 + ompi/op/op.h | 249 ++- 24 files changed, 8708 insertions(+), 48 deletions(-) create mode 100644 config/opal_check_cudart.m4 create mode 100644 ompi/mca/op/cuda/Makefile.am create mode 100644 ompi/mca/op/cuda/configure.m4 create mode 100644 ompi/mca/op/cuda/help-ompi-mca-op-cuda.txt create mode 100644 ompi/mca/op/cuda/op_cuda.h create mode 100644 ompi/mca/op/cuda/op_cuda_component.c create mode 100644 
ompi/mca/op/cuda/op_cuda_functions.c create mode 100644 ompi/mca/op/cuda/op_cuda_impl.cu create mode 100644 ompi/mca/op/cuda/op_cuda_impl.h create mode 100644 ompi/mca/op/rocm/Makefile.am create mode 100644 ompi/mca/op/rocm/configure.m4 create mode 100644 ompi/mca/op/rocm/help-ompi-mca-op-rocm.txt create mode 100644 ompi/mca/op/rocm/op_rocm.h create mode 100644 ompi/mca/op/rocm/op_rocm_component.c create mode 100644 ompi/mca/op/rocm/op_rocm_functions.c create mode 100644 ompi/mca/op/rocm/op_rocm_impl.h create mode 100644 ompi/mca/op/rocm/op_rocm_impl.hip create mode 100644 ompi/op/help-ompi-op.txt diff --git a/config/opal_check_cudart.m4 b/config/opal_check_cudart.m4 new file mode 100644 index 00000000000..0e3fced8065 --- /dev/null +++ b/config/opal_check_cudart.m4 @@ -0,0 +1,120 @@ +dnl -*- autoconf -*- +dnl +dnl Copyright (c) 2004-2010 The Trustees of Indiana University and Indiana +dnl University Research and Technology +dnl Corporation. All rights reserved. +dnl Copyright (c) 2004-2005 The University of Tennessee and The University +dnl of Tennessee Research Foundation. All rights +dnl reserved. +dnl Copyright (c) 2004-2005 High Performance Computing Center Stuttgart, +dnl University of Stuttgart. All rights reserved. +dnl Copyright (c) 2004-2005 The Regents of the University of California. +dnl All rights reserved. +dnl Copyright (c) 2006-2016 Cisco Systems, Inc. All rights reserved. +dnl Copyright (c) 2007 Sun Microsystems, Inc. All rights reserved. +dnl Copyright (c) 2009 IBM Corporation. All rights reserved. +dnl Copyright (c) 2009 Los Alamos National Security, LLC. All rights +dnl reserved. +dnl Copyright (c) 2009-2011 Oak Ridge National Labs. All rights reserved. +dnl Copyright (c) 2011-2015 NVIDIA Corporation. All rights reserved. +dnl Copyright (c) 2015 Research Organization for Information Science +dnl and Technology (RIST). All rights reserved. +dnl Copyright (c) 2022 Amazon.com, Inc. or its affiliates. All Rights reserved. 
+dnl $COPYRIGHT$ +dnl +dnl Additional copyrights may follow +dnl +dnl $HEADER$ +dnl + + +# OPAL_CHECK_CUDART(prefix, [action-if-found], [action-if-not-found]) +# -------------------------------------------------------- +# check if CUDA runtime library support can be found. sets prefix_{CPPFLAGS, +# LDFLAGS, LIBS} as needed and runs action-if-found if there is +# support, otherwise executes action-if-not-found + +# +# Check for CUDA support +# +AC_DEFUN([OPAL_CHECK_CUDART],[ +OPAL_VAR_SCOPE_PUSH([cudart_save_CPPFLAGS cudart_save_LDFLAGS cudart_save_LIBS]) + +cudart_save_CPPFLAGS="$CPPFLAGS" +cudart_save_LDFLAGS="$LDFLAGS" +cudart_save_LIBS="$LIBS" + +# +# Check to see if the user provided paths for CUDART +# +AC_ARG_WITH([cudart], + [AS_HELP_STRING([--with-cudart=DIR], + [Path to the CUDA runtime library and header files])]) +AC_MSG_CHECKING([if --with-cudart is set]) +AC_ARG_WITH([cudart-libdir], + [AS_HELP_STRING([--with-cudart-libdir=DIR], + [Search for CUDA runtime libraries in DIR])]) + +#################################### +#### Check for CUDA runtime library +#################################### +AS_IF([test "x$with_cudart" != "xno" || test "x$with_cudart" = "x"], + [opal_check_cudart_happy=no + AC_MSG_RESULT([not set (--with-cudart=$with_cudart)])], + [AS_IF([test ! 
-d "$with_cudart"], + [AC_MSG_RESULT([not found]) + AC_MSG_WARN([Directory $with_cudart not found])] + [AS_IF([test "x`ls $with_cudart/include/cuda_runtime.h 2> /dev/null`" = "x"] + [AC_MSG_RESULT([not found]) + AC_MSG_WARN([Could not find cuda_runtime.h in $with_cudart/include])] + [opal_check_cudart_happy=yes + opal_cudart_incdir="$with_cudart/include"])])]) + +AS_IF([test "$opal_check_cudart_happy" = "no" && test "$with_cudart" != "no"], + [AC_PATH_PROG([nvcc_bin], [nvcc], ["not-found"]) + AS_IF([test "$nvcc_bin" = "not-found"], + [AC_MSG_WARN([Could not find nvcc binary])], + [nvcc_dirname=`AS_DIRNAME([$nvcc_bin])` + with_cudart=$nvcc_dirname/../ + opal_cudart_incdir=$nvcc_dirname/../include + opal_check_cudart_happy=yes]) + ] + []) + +AS_IF([test x"$with_cudart_libdir" = "x"], + [with_cudart_libdir=$with_cudart/lib64/] + []) + +AS_IF([test "$opal_check_cudart_happy" = "yes"], + [OAC_CHECK_PACKAGE([cudart], + [$1], + [cuda_runtime.h], + [cudart], + [cudaMalloc], + [opal_check_cudart_happy="yes"], + [opal_check_cudart_happy="no"])], + []) + + +AC_MSG_CHECKING([if have cuda runtime library support]) +if test "$opal_check_cudart_happy" = "yes"; then + AC_MSG_RESULT([yes (-I$opal_cudart_incdir)]) + CUDART_SUPPORT=1 + common_cudart_CPPFLAGS="-I$opal_cudart_incdir" + AC_SUBST([common_cudart_CPPFLAGS]) +else + AC_MSG_RESULT([no]) + CUDART_SUPPORT=0 +fi + + +OPAL_SUMMARY_ADD([Accelerators], [CUDART support], [], [$opal_check_cudart_happy]) +AM_CONDITIONAL([OPAL_cudart_support], [test "x$CUDART_SUPPORT" = "x1"]) +AC_DEFINE_UNQUOTED([OPAL_CUDART_SUPPORT],$CUDART_SUPPORT, + [Whether we have cuda runtime library support]) + +CPPFLAGS=${cudart_save_CPPFLAGS} +LDFLAGS=${cudart_save_LDFLAGS} +LIBS=${cudart_save_LIBS} +OPAL_VAR_SCOPE_POP +])dnl diff --git a/ompi/mca/op/base/op_base_frame.c b/ompi/mca/op/base/op_base_frame.c index 90167300851..1a7d6dc1320 100644 --- a/ompi/mca/op/base/op_base_frame.c +++ b/ompi/mca/op/base/op_base_frame.c @@ -2,7 +2,7 @@ * Copyright (c) 
2004-2005 The Trustees of Indiana University and Indiana * University Research and Technology * Corporation. All rights reserved. - * Copyright (c) 2004-2005 The University of Tennessee and The University + * Copyright (c) 2004-2023 The University of Tennessee and The University * of Tennessee Research Foundation. All rights * reserved. * Copyright (c) 2004-2005 High Performance Computing Center Stuttgart, @@ -42,6 +42,7 @@ static void module_constructor(ompi_op_base_module_t *m) { m->opm_enable = NULL; m->opm_op = NULL; + m->opm_device_enabled = false; memset(&(m->opm_fns), 0, sizeof(m->opm_fns)); memset(&(m->opm_3buff_fns), 0, sizeof(m->opm_3buff_fns)); } @@ -50,6 +51,7 @@ static void module_constructor_1_0_0(ompi_op_base_module_1_0_0_t *m) { m->opm_enable = NULL; m->opm_op = NULL; + m->opm_device_enabled = false; memset(&(m->opm_fns), 0, sizeof(m->opm_fns)); memset(&(m->opm_3buff_fns), 0, sizeof(m->opm_3buff_fns)); } diff --git a/ompi/mca/op/base/op_base_op_select.c b/ompi/mca/op/base/op_base_op_select.c index 53754ce5668..534a1d63267 100644 --- a/ompi/mca/op/base/op_base_op_select.c +++ b/ompi/mca/op/base/op_base_op_select.c @@ -3,7 +3,7 @@ * Copyright (c) 2004-2005 The Trustees of Indiana University and Indiana * University Research and Technology * Corporation. All rights reserved. - * Copyright (c) 2004-2009 The University of Tennessee and The University + * Copyright (c) 2004-2023 The University of Tennessee and The University * of Tennessee Research Foundation. All rights * reserved. 
* Copyright (c) 2004-2005 High Performance Computing Center Stuttgart, @@ -152,22 +152,50 @@ int ompi_op_base_op_select(ompi_op_t *op) } /* Copy over the non-NULL pointers */ - for (i = 0; i < OMPI_OP_BASE_TYPE_MAX; ++i) { - /* 2-buffer variants */ - if (NULL != avail->ao_module->opm_fns[i]) { - OBJ_RELEASE(op->o_func.intrinsic.modules[i]); - op->o_func.intrinsic.fns[i] = avail->ao_module->opm_fns[i]; - op->o_func.intrinsic.modules[i] = avail->ao_module; - OBJ_RETAIN(avail->ao_module); + if (avail->ao_module->opm_device_enabled) { + if (NULL == op->o_device_op) { + op->o_device_op = calloc(1, sizeof(*op->o_device_op)); } - - /* 3-buffer variants */ - if (NULL != avail->ao_module->opm_3buff_fns[i]) { - OBJ_RELEASE(op->o_func.intrinsic.modules[i]); - op->o_3buff_intrinsic.fns[i] = - avail->ao_module->opm_3buff_fns[i]; - op->o_3buff_intrinsic.modules[i] = avail->ao_module; - OBJ_RETAIN(avail->ao_module); + for (i = 0; i < OMPI_OP_BASE_TYPE_MAX; ++i) { + /* 2-buffer variants */ + if (NULL != avail->ao_module->opm_stream_fns[i]) { + if (NULL != op->o_device_op->do_intrinsic.modules[i]) { + OBJ_RELEASE(op->o_device_op->do_intrinsic.modules[i]); + } + op->o_device_op->do_intrinsic.fns[i] = avail->ao_module->opm_stream_fns[i]; + op->o_device_op->do_intrinsic.modules[i] = avail->ao_module; + OBJ_RETAIN(avail->ao_module); + } + + /* 3-buffer variants */ + if (NULL != avail->ao_module->opm_3buff_stream_fns[i]) { + if (NULL != op->o_device_op->do_3buff_intrinsic.modules[i]) { + OBJ_RELEASE(op->o_device_op->do_3buff_intrinsic.modules[i]); + } + op->o_device_op->do_3buff_intrinsic.fns[i] = + avail->ao_module->opm_3buff_stream_fns[i]; + op->o_device_op->do_3buff_intrinsic.modules[i] = avail->ao_module; + OBJ_RETAIN(avail->ao_module); + } + } + } else { + for (i = 0; i < OMPI_OP_BASE_TYPE_MAX; ++i) { + /* 2-buffer variants */ + if (NULL != avail->ao_module->opm_fns[i]) { + OBJ_RELEASE(op->o_func.intrinsic.modules[i]); + op->o_func.intrinsic.fns[i] = avail->ao_module->opm_fns[i]; + 
op->o_func.intrinsic.modules[i] = avail->ao_module; + OBJ_RETAIN(avail->ao_module); + } + + /* 3-buffer variants */ + if (NULL != avail->ao_module->opm_3buff_fns[i]) { + OBJ_RELEASE(op->o_3buff_intrinsic.modules[i]); + op->o_3buff_intrinsic.fns[i] = + avail->ao_module->opm_3buff_fns[i]; + op->o_3buff_intrinsic.modules[i] = avail->ao_module; + OBJ_RETAIN(avail->ao_module); + } } } diff --git a/ompi/mca/op/cuda/Makefile.am b/ompi/mca/op/cuda/Makefile.am new file mode 100644 index 00000000000..7075d26301c --- /dev/null +++ b/ompi/mca/op/cuda/Makefile.am @@ -0,0 +1,84 @@ +# +# Copyright (c) 2023 The University of Tennessee and The University +# of Tennessee Research Foundation. All rights +# reserved. +# $COPYRIGHT$ +# +# Additional copyrights may follow +# +# $HEADER$ +# + +# This component provides support for offloading reduce ops to CUDA devices. +# +# See https://github.com/open-mpi/ompi/wiki/devel-CreateComponent +# for more details on how to make Open MPI components. + +# First, list all .h and .c sources. It is necessary to list all .h +# files so that they will be picked up in the distribution tarball. + +AM_CPPFLAGS = $(op_cuda_CPPFLAGS) $(op_cudart_CPPFLAGS) + +dist_ompidata_DATA = help-ompi-mca-op-cuda.txt + +sources = op_cuda_component.c op_cuda.h op_cuda_functions.c op_cuda_impl.h +#sources_extended = op_cuda_functions.cu +cu_sources = op_cuda_impl.cu + +NVCC = nvcc -g +NVCCFLAGS= --std c++17 --gpu-architecture=compute_52 + +.cu.l$(OBJEXT): + $(LIBTOOL) $(AM_V_lt) --tag=CC $(AM_LIBTOOLFLAGS) \ + $(LIBTOOLFLAGS) --mode=compile $(NVCC) -prefer-non-pic $(NVCCFLAGS) -Wc,-Xcompiler,-fPIC,-g -c $< + +# -o $($@.o:.lo) + +# Open MPI components can be compiled two ways: +# +# 1. As a standalone dynamic shared object (DSO), sometimes called a +# dynamically loadable library (DLL). +# +# 2. As a static library that is slurped up into the upper-level +# libmpi library (regardless of whether libmpi is a static or dynamic +# library). 
This is called a "Libtool convenience library". +# +# The component needs to create an output library in this top-level +# component directory, and named either mca__.la (for DSO +# builds) or libmca__.la (for static builds). The OMPI +# build system will have set the +# MCA_BUILD_ompi___DSO AM_CONDITIONAL to indicate +# which way this component should be built. + +if MCA_BUILD_ompi_op_cuda_DSO +component_install = mca_op_cuda.la +else +component_install = +component_noinst = libmca_op_cuda.la +endif + +# Specific information for DSO builds. +# +# The DSO should install itself in $(ompilibdir) (by default, +# $prefix/lib/openmpi). + +mcacomponentdir = $(ompilibdir) +mcacomponent_LTLIBRARIES = $(component_install) +mca_op_cuda_la_SOURCES = $(sources) +mca_op_cuda_la_LIBADD = $(cu_sources:.cu=.lo) +mca_op_cuda_la_LDFLAGS = -module -avoid-version $(top_builddir)/ompi/lib@OMPI_LIBMPI_NAME@.la \ + $(op_cuda_LIBS) $(op_cudart_LDFLAGS) $(op_cudart_LIBS) +EXTRA_mca_op_cuda_la_SOURCES = $(cu_sources) + +# Specific information for static builds. +# +# Note that we *must* "noinst"; the upper-layer Makefile.am's will +# slurp in the resulting .la library into libmpi. + +noinst_LTLIBRARIES = $(component_noinst) +libmca_op_cuda_la_SOURCES = $(sources) +libmca_op_cuda_la_LIBADD = $(cu_sources:.cu=.lo) +libmca_op_cuda_la_LDFLAGS = -module -avoid-version\ + $(op_cuda_LIBS) $(op_cudart_LDFLAGS) $(op_cudart_LIBS) +EXTRA_libmca_op_cuda_la_SOURCES = $(cu_sources) + diff --git a/ompi/mca/op/cuda/configure.m4 b/ompi/mca/op/cuda/configure.m4 new file mode 100644 index 00000000000..0974e3aaf31 --- /dev/null +++ b/ompi/mca/op/cuda/configure.m4 @@ -0,0 +1,41 @@ +# -*- shell-script -*- +# +# Copyright (c) 2011-2013 NVIDIA Corporation. All rights reserved. +# Copyright (c) 2023 The University of Tennessee and The University +# of Tennessee Research Foundation. All rights +# reserved. +# Copyright (c) 2022 Amazon.com, Inc. or its affiliates. +# All Rights reserved. 
+# $COPYRIGHT$ +# +# Additional copyrights may follow +# +# $HEADER$ +# + +# +# If CUDA support was requested, then build the CUDA support library. +# This code checks makes sure the check was done earlier by the +# opal_check_cuda.m4 code. It also copies the flags and libs under +# opal_cuda_CPPFLAGS, opal_cuda_LDFLAGS, and opal_cuda_LIBS + +AC_DEFUN([MCA_ompi_op_cuda_CONFIG],[ + + AC_CONFIG_FILES([ompi/mca/op/cuda/Makefile]) + + OPAL_CHECK_CUDA([op_cuda]) + OPAL_CHECK_CUDART([op_cudart]) + + AS_IF([test "x$CUDA_SUPPORT" = "x1"], + [$1], + [$2]) + + AC_SUBST([op_cuda_CPPFLAGS]) + AC_SUBST([op_cuda_LDFLAGS]) + AC_SUBST([op_cuda_LIBS]) + + AC_SUBST([op_cudart_CPPFLAGS]) + AC_SUBST([op_cudart_LDFLAGS]) + AC_SUBST([op_cudart_LIBS]) + +])dnl diff --git a/ompi/mca/op/cuda/help-ompi-mca-op-cuda.txt b/ompi/mca/op/cuda/help-ompi-mca-op-cuda.txt new file mode 100644 index 00000000000..f999ebc939c --- /dev/null +++ b/ompi/mca/op/cuda/help-ompi-mca-op-cuda.txt @@ -0,0 +1,15 @@ +# -*- text -*- +# +# Copyright (c) 2023 The University of Tennessee and The University +# of Tennessee Research Foundation. All rights +# reserved. +# $COPYRIGHT$ +# +# Additional copyrights may follow +# +# $HEADER$ +# +# This is the US/English help file for Open MPI's CUDA operator component +# +[CUDA call failed] +"CUDA call %s failed: %s: %s\n" diff --git a/ompi/mca/op/cuda/op_cuda.h b/ompi/mca/op/cuda/op_cuda.h new file mode 100644 index 00000000000..ab349d48ee4 --- /dev/null +++ b/ompi/mca/op/cuda/op_cuda.h @@ -0,0 +1,80 @@ +/* + * Copyright (c) 2019-2023 The University of Tennessee and The University + * of Tennessee Research Foundation. All rights + * reserved. 
+ * $COPYRIGHT$ + * + * Additional copyrights may follow + * + * $HEADER$ + */ + +#ifndef MCA_OP_CUDA_EXPORT_H +#define MCA_OP_CUDA_EXPORT_H + +#include "ompi_config.h" + +#include "ompi/mca/mca.h" +#include "opal/class/opal_object.h" + +#include "ompi/mca/op/op.h" +#include "ompi/runtime/mpiruntime.h" + +#include +#include + +BEGIN_C_DECLS + + +#define xstr(x) #x +#define str(x) xstr(x) + +#define CHECK(fn, args) \ + do { \ + cudaError_t err = fn args; \ + if (err != cudaSuccess) { \ + opal_show_help("help-ompi-mca-op-cuda.txt", \ + "CUDA call failed", true, \ + str(fn), cudaGetErrorName(err), \ + cudaGetErrorString(err)); \ + ompi_mpi_abort(MPI_COMM_WORLD, 1); \ + } \ + } while (0) + + +/** + * Derive a struct from the base op component struct, allowing us to + * cache some component-specific information on our well-known + * component struct. + */ +typedef struct { + /** The base op component struct */ + ompi_op_base_component_1_0_0_t super; + int cu_max_num_blocks; + int cu_max_num_threads; + int *cu_max_threads_per_block; + int *cu_max_blocks; + CUdevice *cu_devices; + int cu_num_devices; +} ompi_op_cuda_component_t; + +/** + * Globally exported variable. Note that it is a *cuda* component + * (defined above), which has the ompi_op_base_component_t as its + * first member. Hence, the MCA/op framework will find the data that + * it expects in the first memory locations, but then the component + * itself can cache additional information after that that can be used + * by both the component and modules. 
+ */ +OMPI_DECLSPEC extern ompi_op_cuda_component_t + mca_op_cuda_component; + +OMPI_DECLSPEC extern +ompi_op_base_stream_handler_fn_t ompi_op_cuda_functions[OMPI_OP_BASE_FORTRAN_OP_MAX][OMPI_OP_BASE_TYPE_MAX]; + +OMPI_DECLSPEC extern +ompi_op_base_3buff_stream_handler_fn_t ompi_op_cuda_3buff_functions[OMPI_OP_BASE_FORTRAN_OP_MAX][OMPI_OP_BASE_TYPE_MAX]; + +END_C_DECLS + +#endif /* MCA_OP_CUDA_EXPORT_H */ diff --git a/ompi/mca/op/cuda/op_cuda_component.c b/ompi/mca/op/cuda/op_cuda_component.c new file mode 100644 index 00000000000..3ead710bd1d --- /dev/null +++ b/ompi/mca/op/cuda/op_cuda_component.c @@ -0,0 +1,195 @@ +/* + * Copyright (c) 2019-2023 The University of Tennessee and The University + * of Tennessee Research Foundation. All rights + * reserved. + * Copyright (c) 2020 Research Organization for Information Science + * and Technology (RIST). All rights reserved. + * Copyright (c) 2021 Cisco Systems, Inc. All rights reserved. + * $COPYRIGHT$ + * + * Additional copyrights may follow + * + * $HEADER$ + */ + +/** @file + * + * This is the "cuda" op component source code. 
+ * + */ + +#include "ompi_config.h" + +#include "opal/util/printf.h" + +#include "ompi/constants.h" +#include "ompi/op/op.h" +#include "ompi/mca/op/op.h" +#include "ompi/mca/op/base/base.h" +#include "ompi/mca/op/cuda/op_cuda.h" + +#include + +static int cuda_component_open(void); +static int cuda_component_close(void); +static int cuda_component_init_query(bool enable_progress_threads, + bool enable_mpi_thread_multiple); +static struct ompi_op_base_module_1_0_0_t * + cuda_component_op_query(struct ompi_op_t *op, int *priority); +static int cuda_component_register(void); + +ompi_op_cuda_component_t mca_op_cuda_component = { + { + .opc_version = { + OMPI_OP_BASE_VERSION_1_0_0, + + .mca_component_name = "cuda", + MCA_BASE_MAKE_VERSION(component, OMPI_MAJOR_VERSION, OMPI_MINOR_VERSION, + OMPI_RELEASE_VERSION), + .mca_open_component = cuda_component_open, + .mca_close_component = cuda_component_close, + .mca_register_component_params = cuda_component_register, + }, + .opc_data = { + /* The component is checkpoint ready */ + MCA_BASE_METADATA_PARAM_CHECKPOINT + }, + + .opc_init_query = cuda_component_init_query, + .opc_op_query = cuda_component_op_query, + }, + .cu_max_num_blocks = -1, + .cu_max_num_threads = -1, + .cu_max_threads_per_block = NULL, + .cu_max_blocks = NULL, + .cu_devices = NULL, + .cu_num_devices = 0, +}; + +/* + * Component open + */ +static int cuda_component_open(void) +{ + return OMPI_SUCCESS; +} + +/* + * Component close + */ +static int cuda_component_close(void) +{ + if (mca_op_cuda_component.cu_num_devices > 0) { + free(mca_op_cuda_component.cu_max_threads_per_block); + mca_op_cuda_component.cu_max_threads_per_block = NULL; + free(mca_op_cuda_component.cu_max_blocks); + mca_op_cuda_component.cu_max_blocks = NULL; + free(mca_op_cuda_component.cu_devices); + mca_op_cuda_component.cu_devices = NULL; + mca_op_cuda_component.cu_num_devices = 0; + } + + return OMPI_SUCCESS; +} + +/* + * Register MCA params. 
+ */ +static int +cuda_component_register(void) +{ + mca_base_var_enum_flag_t *new_enum_flag = NULL; + (void) mca_base_component_var_register(&mca_op_cuda_component.super.opc_version, + "max_num_blocks", + "Maximum number of thread blocks in kernels (-1: device limit)", + MCA_BASE_VAR_TYPE_INT, + &(new_enum_flag->super), 0, 0, + OPAL_INFO_LVL_4, + MCA_BASE_VAR_SCOPE_LOCAL, + &mca_op_cuda_component.cu_max_num_blocks); + + (void) mca_base_component_var_register(&mca_op_cuda_component.super.opc_version, + "max_num_threads", + "Maximum number of threads per block in kernels (-1: device limit)", + MCA_BASE_VAR_TYPE_INT, + &(new_enum_flag->super), 0, 0, + OPAL_INFO_LVL_4, + MCA_BASE_VAR_SCOPE_LOCAL, + &mca_op_cuda_component.cu_max_num_threads); + + return OMPI_SUCCESS; +} + + +/* + * Query whether this component wants to be used in this process. + */ +static int +cuda_component_init_query(bool enable_progress_threads, + bool enable_mpi_thread_multiple) +{ + int num_devices; + int rc; + // TODO: is this init needed here? 
+ cuInit(0); + CHECK(cuDeviceGetCount, (&num_devices)); + mca_op_cuda_component.cu_num_devices = num_devices; + mca_op_cuda_component.cu_devices = (CUdevice*)malloc(num_devices*sizeof(CUdevice)); + mca_op_cuda_component.cu_max_threads_per_block = (int*)malloc(num_devices*sizeof(int)); + mca_op_cuda_component.cu_max_blocks = (int*)malloc(num_devices*sizeof(int)); + for (int i = 0; i < num_devices; ++i) { + CHECK(cuDeviceGet, (&mca_op_cuda_component.cu_devices[i], i)); + rc = cuDeviceGetAttribute(&mca_op_cuda_component.cu_max_threads_per_block[i], + CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X, + mca_op_cuda_component.cu_devices[i]); + if (CUDA_SUCCESS != rc) { + /* fall-back to value that should work on every device */ + mca_op_cuda_component.cu_max_threads_per_block[i] = 512; + } + if (-1 < mca_op_cuda_component.cu_max_num_threads) { + if (mca_op_cuda_component.cu_max_threads_per_block[i] >= mca_op_cuda_component.cu_max_num_threads) { + mca_op_cuda_component.cu_max_threads_per_block[i] = mca_op_cuda_component.cu_max_num_threads; + } + } + + rc = cuDeviceGetAttribute(&mca_op_cuda_component.cu_max_blocks[i], + CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X, + mca_op_cuda_component.cu_devices[i]); + if (CUDA_SUCCESS != rc) { + /* fall-back to value that should work on every device */ + mca_op_cuda_component.cu_max_blocks[i] = 512; + } + if (-1 < mca_op_cuda_component.cu_max_num_blocks) { + if (mca_op_cuda_component.cu_max_blocks[i] >= mca_op_cuda_component.cu_max_num_blocks) { + mca_op_cuda_component.cu_max_blocks[i] = mca_op_cuda_component.cu_max_num_blocks; + } + } + } + + return OMPI_SUCCESS; +} + +/* + * Query whether this component can be used for a specific op + */ +static struct ompi_op_base_module_1_0_0_t* +cuda_component_op_query(struct ompi_op_t *op, int *priority) +{ + ompi_op_base_module_t *module = NULL; + + module = OBJ_NEW(ompi_op_base_module_t); + module->opm_device_enabled = true; + for (int i = 0; i < OMPI_OP_BASE_TYPE_MAX; ++i) { + module->opm_stream_fns[i] = 
ompi_op_cuda_functions[op->o_f_to_c_index][i]; + module->opm_3buff_stream_fns[i] = ompi_op_cuda_3buff_functions[op->o_f_to_c_index][i]; + + if( NULL != module->opm_fns[i] ) { + OBJ_RETAIN(module); + } + if( NULL != module->opm_3buff_fns[i] ) { + OBJ_RETAIN(module); + } + } + *priority = 50; + return (ompi_op_base_module_1_0_0_t *) module; +} diff --git a/ompi/mca/op/cuda/op_cuda_functions.c b/ompi/mca/op/cuda/op_cuda_functions.c new file mode 100644 index 00000000000..904595147cb --- /dev/null +++ b/ompi/mca/op/cuda/op_cuda_functions.c @@ -0,0 +1,1897 @@ +/* + * Copyright (c) 2019-2023 The University of Tennessee and The University + * of Tennessee Research Foundation. All rights + * reserved. + * Copyright (c) 2020 Research Organization for Information Science + * and Technology (RIST). All rights reserved. + * $COPYRIGHT$ + * + * Additional copyrights may follow + * + * $HEADER$ + */ + +#include "ompi_config.h" + +#ifdef HAVE_SYS_TYPES_H +#include +#endif +#include "opal/util/output.h" + + +#include "ompi/op/op.h" +#include "ompi/mca/op/op.h" +#include "ompi/mca/op/base/base.h" +#include "ompi/mca/op/cuda/op_cuda.h" +#include "opal/mca/accelerator/accelerator.h" + +#include "ompi/mca/op/cuda/op_cuda.h" +#include "ompi/mca/op/cuda/op_cuda_impl.h" + +/** + * Disable warning about empty macro var-args. + * We use varargs to suppress expansion of typenames + * (e.g., int32_t -> int) which could lead to collisions + * for similar base types. 
*/ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wgnu-zero-variadic-macro-arguments" + +static inline void device_op_pre(const void *orig_source1, + void **source1, + int *source1_device, + const void *orig_source2, + void **source2, + int *source2_device, + void *orig_target, + void **target, + int *target_device, + int count, + struct ompi_datatype_t *dtype, + int *threads_per_block, + int *max_blocks, + int *device, + opal_accelerator_stream_t *stream) +{ + uint64_t target_flags = -1, source1_flags = -1, source2_flags = -1; + int target_rc, source1_rc, source2_rc = -1; + + *target = orig_target; + *source1 = (void*)orig_source1; + if (NULL != orig_source2) { + *source2 = (void*)orig_source2; + } + + if (*device != MCA_ACCELERATOR_NO_DEVICE_ID) { + /* we got the device from the caller, just adjust the output parameters */ + *target_device = *device; + *source1_device = *device; + if (NULL != source2_device) { + *source2_device = *device; + } + } else { + + target_rc = opal_accelerator.check_addr(*target, target_device, &target_flags); + source1_rc = opal_accelerator.check_addr(*source1, source1_device, &source1_flags); + *device = *target_device; + + if (NULL != orig_source2) { + source2_rc = opal_accelerator.check_addr(*source2, source2_device, &source2_flags); + } + + if (0 == target_rc && 0 == source1_rc && 0 == source2_rc) { + /* no buffers are on any device, select device 0 */ + *device = 0; + } else if (*target_device == -1) { + if (*source1_device == -1 && NULL != orig_source2) { + *device = *source2_device; + } else { + *device = *source1_device; + } + } + + if (0 == target_rc || 0 == source1_rc || *target_device != *source1_device) { + size_t nbytes; + ompi_datatype_type_size(dtype, &nbytes); + nbytes *= count; + + if (0 == target_rc) { + // allocate memory on the device for the target buffer + opal_accelerator.mem_alloc_stream(*device, target, nbytes, stream); + CHECK(cuMemcpyHtoDAsync, ((CUdeviceptr)*target, orig_target, nbytes, 
*(CUstream*)stream->stream)); + *target_device = -1; // mark target device as host + } + + if (0 == source1_rc || *device != *source1_device) { + // allocate memory on the device for the source buffer + opal_accelerator.mem_alloc_stream(*device, source1, nbytes, stream); + if (0 == source1_rc) { + /* copy from host to device */ + CHECK(cuMemcpyHtoDAsync, ((CUdeviceptr)*source1, orig_source1, nbytes, *(CUstream*)stream->stream)); + } else { + /* copy from one device to another device */ + /* TODO: does this actually work? Can we enable P2P? */ + CHECK(cuMemcpyDtoDAsync, ((CUdeviceptr)*source1, (CUdeviceptr)orig_source1, nbytes, *(CUstream*)stream->stream)); + } + } + + } + if (NULL != source2_device && *target_device != *source2_device) { + // allocate memory on the device for the source buffer + size_t nbytes; + ompi_datatype_type_size(dtype, &nbytes); + nbytes *= count; + + opal_accelerator.mem_alloc_stream(*device, source2, nbytes, stream); + if (0 == source2_rc) { + /* copy from host to device */ + //printf("copying source from host to device %d\n", *device); + CHECK(cuMemcpyHtoDAsync, ((CUdeviceptr)*source2, orig_source2, nbytes, *(CUstream*)stream->stream)); + } else { + /* copy from one device to another device */ + /* TODO: does this actually work? Can we enable P2P? 
*/ + CHECK(cuMemcpyDtoDAsync, ((CUdeviceptr)*source2, (CUdeviceptr)orig_source2, nbytes, *(CUstream*)stream->stream)); + } + } + } + *threads_per_block = mca_op_cuda_component.cu_max_threads_per_block[*device]; + *max_blocks = mca_op_cuda_component.cu_max_blocks[*device]; +} + +static inline void device_op_post(void *source1, + int source1_device, + void *source2, + int source2_device, + void *orig_target, + void *target, + int target_device, + int count, + struct ompi_datatype_t *dtype, + int device, + opal_accelerator_stream_t *stream) +{ + if (MCA_ACCELERATOR_NO_DEVICE_ID == target_device) { + + size_t nbytes; + ompi_datatype_type_size(dtype, &nbytes); + nbytes *= count; + + CHECK(cuMemcpyDtoHAsync, (orig_target, (CUdeviceptr)target, nbytes, *(CUstream *)stream->stream)); + } + + if (MCA_ACCELERATOR_NO_DEVICE_ID == target_device) { + opal_accelerator.mem_release_stream(device, target, stream); + //CHECK(cuMemFreeAsync, ((CUdeviceptr)target, mca_op_cuda_component.cu_stream)); + } + if (source1_device != device) { + opal_accelerator.mem_release_stream(device, source1, stream); + //CHECK(cuMemFreeAsync, ((CUdeviceptr)source, mca_op_cuda_component.cu_stream)); + } + if (NULL != source2 && source2_device != device) { + opal_accelerator.mem_release_stream(device, source2, stream); + //CHECK(cuMemFreeAsync, ((CUdeviceptr)source, mca_op_cuda_component.cu_stream)); + } +} + +#define FUNC(name, type_name, type) \ + static \ + void ompi_op_cuda_2buff_##name##_##type_name(const void *in, void *inout, int *count, \ + struct ompi_datatype_t **dtype, \ + int device, \ + opal_accelerator_stream_t *stream, \ + struct ompi_op_base_module_1_0_0_t *module) __opal_attribute_unused__; \ + static \ + void ompi_op_cuda_2buff_##name##_##type_name(const void *in, void *inout, int *count, \ + struct ompi_datatype_t **dtype, \ + int device, \ + opal_accelerator_stream_t *stream, \ + struct ompi_op_base_module_1_0_0_t *module) { \ + int threads_per_block, max_blocks; \ + int source_device, 
target_device; \ + type *source, *target; \ + int n = *count; \ + device_op_pre(in, (void**)&source, &source_device, NULL, NULL, NULL, \ + inout, (void**)&target, &target_device, \ + n, *dtype, \ + &threads_per_block, &max_blocks, &device, stream); \ + CUstream *custream = (CUstream*)stream->stream; \ + ompi_op_cuda_2buff_##name##_##type_name##_submit(source, target, n, threads_per_block, max_blocks, *custream); \ + device_op_post(source, source_device, NULL, -1, inout, target, target_device, n, *dtype, device, stream); \ + } + +#define OP_FUNC(name, type_name, type, ...) FUNC(name, __VA_ARGS__##type_name, __VA_ARGS__##type) + +/* reuse the macro above, no work is actually done so we don't care about the func */ +#define FUNC_FUNC(name, type_name, type, ...) FUNC(name, __VA_ARGS__##type_name, __VA_ARGS__##type) + +/* + * Since all the functions in this file are essentially identical, we + * use a macro to substitute in names and types. The core operation + * in all functions that use this macro is the same. 
+ * + * This macro is for minloc and maxloc + */ +#define LOC_FUNC(name, type_name) FUNC(name, type_name, ompi_op_predefined_##type_name##_t) + +/* Dispatch Fortran types to C types */ +#define FORT_INT_FUNC(name, type_name, type) \ + static \ + void ompi_op_cuda_2buff_##name##_##type_name(const void *in, void *inout, int *count, \ + struct ompi_datatype_t **dtype, \ + int device, \ + opal_accelerator_stream_t *stream, \ + struct ompi_op_base_module_1_0_0_t *module) { \ + \ + _Static_assert(sizeof(type) >= sizeof(int8_t) && sizeof(type) <= sizeof(int64_t), \ + "Unsupported Fortran integer size"); \ + switch(sizeof(type)) { \ + case sizeof(int8_t): \ + ompi_op_cuda_2buff_##name##_int8_t(in, inout, count, dtype, device, stream, module); \ + break; \ + case sizeof(int16_t): \ + ompi_op_cuda_2buff_##name##_int16_t(in, inout, count, dtype, device, stream, module); \ + break; \ + case sizeof(int32_t): \ + ompi_op_cuda_2buff_##name##_int32_t(in, inout, count, dtype, device, stream, module); \ + break; \ + case sizeof(int64_t): \ + ompi_op_cuda_2buff_##name##_int64_t(in, inout, count, dtype, device, stream, module); \ + break; \ + } \ + } + +#define FORT_FLOAT_FUNC(name, type_name, type) \ + static \ + void ompi_op_cuda_2buff_##name##_##type_name(const void *in, void *inout, int *count, \ + struct ompi_datatype_t **dtype, \ + int device, \ + opal_accelerator_stream_t *stream, \ + struct ompi_op_base_module_1_0_0_t *module) { \ + _Static_assert(sizeof(type) >= sizeof(float) && sizeof(type) <= sizeof(long double), \ + "Unsupported Fortran float size"); \ + switch(sizeof(type)) { \ + case sizeof(float): \ + ompi_op_cuda_2buff_##name##_float(in, inout, count, dtype, device, stream, module); \ + break; \ + case sizeof(double): \ + ompi_op_cuda_2buff_##name##_double(in, inout, count, dtype, device, stream, module); \ + break; \ + case sizeof(long double): \ + ompi_op_cuda_2buff_##name##_long_double(in, inout, count, dtype, device, stream, module); \ + break; \ + } \ + } + 
+#define FORT_LOC_INT_FUNC(name, type_name, type) \ + static \ + void ompi_op_cuda_2buff_##name##_##type_name(const void *in, void *inout, int *count, \ + struct ompi_datatype_t **dtype, \ + int device, \ + opal_accelerator_stream_t *stream, \ + struct ompi_op_base_module_1_0_0_t *module) { \ + \ + _Static_assert(sizeof(type) >= sizeof(int8_t) && sizeof(type) <= sizeof(int64_t), \ + "Unsupported Fortran integer type"); \ + switch(sizeof(type)) { \ + case sizeof(int8_t): \ + ompi_op_cuda_2buff_##name##_2int8(in, inout, count, dtype, device, stream, module); \ + break; \ + case sizeof(int16_t): \ + ompi_op_cuda_2buff_##name##_2int16(in, inout, count, dtype, device, stream, module); \ + break; \ + case sizeof(int32_t): \ + ompi_op_cuda_2buff_##name##_2int32(in, inout, count, dtype, device, stream, module); \ + break; \ + case sizeof(int64_t): \ + ompi_op_cuda_2buff_##name##_2int64(in, inout, count, dtype, device, stream, module); \ + break; \ + } \ + } + +#define FORT_LOC_FLOAT_FUNC(name, type_name, type) \ + static \ + void ompi_op_cuda_2buff_##name##_##type_name(const void *in, void *inout, int *count, \ + struct ompi_datatype_t **dtype, \ + int device, \ + opal_accelerator_stream_t *stream, \ + struct ompi_op_base_module_1_0_0_t *module) { \ + _Static_assert(sizeof(type) >= sizeof(float) && sizeof(type) <= sizeof(double), \ + "Unsupported Fortran float size"); \ + switch(sizeof(type)) { \ + case sizeof(float): \ + ompi_op_cuda_2buff_##name##_2float(in, inout, count, dtype, device, stream, module); \ + break; \ + case sizeof(double): \ + ompi_op_cuda_2buff_##name##_2double(in, inout, count, dtype, device, stream, module); \ + break; \ + } \ + } + +/************************************************************************* + * Max + *************************************************************************/ + +/* C integer */ +FUNC_FUNC(max, int8_t, int8_t) +FUNC_FUNC(max, uint8_t, uint8_t) +FUNC_FUNC(max, int16_t, int16_t) +FUNC_FUNC(max, uint16_t, uint16_t) 
+FUNC_FUNC(max, int32_t, int32_t) +FUNC_FUNC(max, uint32_t, uint32_t) +FUNC_FUNC(max, int64_t, int64_t) +FUNC_FUNC(max, uint64_t, uint64_t) +FUNC_FUNC(max, long, long) +FUNC_FUNC(max, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC(max, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC(max, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC(max, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC(max, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC(max, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC(max, fortran_integer16, ompi_fortran_integer16_t) +#endif + +#if 0 +/* Floating point */ +#if defined(HAVE_SHORT_FLOAT) +FUNC_FUNC(max, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +FUNC_FUNC(max, short_float, opal_short_float_t) +#endif +#endif // 0 +FUNC_FUNC(max, float, float) +FUNC_FUNC(max, double, double) +FUNC_FUNC(max, long_double, long double) +#if OMPI_HAVE_FORTRAN_REAL +FORT_FLOAT_FUNC(max, fortran_real, ompi_fortran_real_t) +#endif +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +FORT_FLOAT_FUNC(max, fortran_double_precision, ompi_fortran_double_precision_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL2 +FORT_FLOAT_FUNC(max, fortran_real2, ompi_fortran_real2_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL4 +FORT_FLOAT_FUNC(max, fortran_real4, ompi_fortran_real4_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL8 +FORT_FLOAT_FUNC(max, fortran_real8, ompi_fortran_real8_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL16 && OMPI_REAL16_MATCHES_C +FORT_FLOAT_FUNC(max, fortran_real16, ompi_fortran_real16_t) +#endif + + +/************************************************************************* + * Min + *************************************************************************/ + +/* C integer */ 
+FUNC_FUNC(min, int8_t, int8_t) +FUNC_FUNC(min, uint8_t, uint8_t) +FUNC_FUNC(min, int16_t, int16_t) +FUNC_FUNC(min, uint16_t, uint16_t) +FUNC_FUNC(min, int32_t, int32_t) +FUNC_FUNC(min, uint32_t, uint32_t) +FUNC_FUNC(min, int64_t, int64_t) +FUNC_FUNC(min, uint64_t, uint64_t) +FUNC_FUNC(min, long, long) +FUNC_FUNC(min, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC(min, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC(min, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC(min, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC(min, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC(min, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC(min, fortran_integer16, ompi_fortran_integer16_t) +#endif + +#if 0 +/* Floating point */ +#if defined(HAVE_SHORT_FLOAT) +FUNC_FUNC(min, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +FUNC_FUNC(min, short_float, opal_short_float_t) +#endif +#endif // 0 + +FUNC_FUNC(min, float, float) +FUNC_FUNC(min, double, double) +FUNC_FUNC(min, long_double, long double) +#if OMPI_HAVE_FORTRAN_REAL +FORT_FLOAT_FUNC(min, fortran_real, ompi_fortran_real_t) +#endif +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +FORT_FLOAT_FUNC(min, fortran_double_precision, ompi_fortran_double_precision_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL2 +FORT_FLOAT_FUNC(min, fortran_real2, ompi_fortran_real2_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL4 +FORT_FLOAT_FUNC(min, fortran_real4, ompi_fortran_real4_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL8 +FORT_FLOAT_FUNC(min, fortran_real8, ompi_fortran_real8_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL16 && OMPI_REAL16_MATCHES_C +FORT_FLOAT_FUNC(min, fortran_real16, ompi_fortran_real16_t) +#endif + 
+/************************************************************************* + * Sum + *************************************************************************/ + +/* C integer */ +OP_FUNC(sum, int8_t, int8_t) +OP_FUNC(sum, uint8_t, uint8_t) +OP_FUNC(sum, int16_t, int16_t) +OP_FUNC(sum, uint16_t, uint16_t) +OP_FUNC(sum, int32_t, int32_t) +OP_FUNC(sum, uint32_t, uint32_t) +OP_FUNC(sum, int64_t, int64_t) +OP_FUNC(sum, uint64_t, uint64_t) +OP_FUNC(sum, long, long) +OP_FUNC(sum, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC(sum, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC(sum, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC(sum, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC(sum, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC(sum, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC(sum, fortran_integer16, ompi_fortran_integer16_t) +#endif + +#if 0 +/* Floating point */ +#if defined(HAVE_SHORT_FLOAT) +OP_FUNC(sum, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +OP_FUNC(sum, short_float, opal_short_float_t) +#endif +#endif // 0 + +OP_FUNC(sum, float, float) +OP_FUNC(sum, double, double) +OP_FUNC(sum, long_double, long double) +#if OMPI_HAVE_FORTRAN_REAL +FORT_FLOAT_FUNC(sum, fortran_real, ompi_fortran_real_t) +#endif +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +FORT_FLOAT_FUNC(sum, fortran_double_precision, ompi_fortran_double_precision_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL2 +FORT_FLOAT_FUNC(sum, fortran_real2, ompi_fortran_real2_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL4 +FORT_FLOAT_FUNC(sum, fortran_real4, ompi_fortran_real4_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL8 +FORT_FLOAT_FUNC(sum, fortran_real8, ompi_fortran_real8_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL16 && 
OMPI_REAL16_MATCHES_C +FORT_FLOAT_FUNC(sum, fortran_real16, ompi_fortran_real16_t) +#endif +/* Complex */ +#if 0 +#if defined(HAVE_SHORT_FLOAT__COMPLEX) +OP_FUNC(sum, c_short_float_complex, short float _Complex) +#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) +COMPLEX_SUM_FUNC(c_short_float_complex, opal_short_float_t) +#endif +OP_FUNC(sum, c_long_double_complex, long double _Complex) +#endif // 0 + +FUNC_FUNC(sum, c_float_complex, cuFloatComplex) +FUNC_FUNC(sum, c_double_complex, cuDoubleComplex) + +/************************************************************************* + * Product + *************************************************************************/ + +/* C integer */ +OP_FUNC(prod, int8_t, int8_t) +OP_FUNC(prod, uint8_t, uint8_t) +OP_FUNC(prod, int16_t, int16_t) +OP_FUNC(prod, uint16_t, uint16_t) +OP_FUNC(prod, int32_t, int32_t) +OP_FUNC(prod, uint32_t, uint32_t) +OP_FUNC(prod, int64_t, int64_t) +OP_FUNC(prod, uint64_t, uint64_t) +OP_FUNC(prod, long, long) +OP_FUNC(prod, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC(prod, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC(prod, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC(prod, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC(prod, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC(prod, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC(prod, fortran_integer16, ompi_fortran_integer16_t) +#endif +/* Floating point */ + +#if 0 +#if defined(HAVE_SHORT_FLOAT) +OP_FUNC(prod, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +OP_FUNC(prod, short_float, opal_short_float_t) +#endif +#endif // 0 + +OP_FUNC(prod, float, float) +OP_FUNC(prod, double, double) +OP_FUNC(prod, long_double, long double) +#if 
OMPI_HAVE_FORTRAN_REAL +FORT_FLOAT_FUNC(prod, fortran_real, ompi_fortran_real_t) +#endif +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +FORT_FLOAT_FUNC(prod, fortran_double_precision, ompi_fortran_double_precision_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL2 +FORT_FLOAT_FUNC(prod, fortran_real2, ompi_fortran_real2_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL4 +FORT_FLOAT_FUNC(prod, fortran_real4, ompi_fortran_real4_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL8 +FORT_FLOAT_FUNC(prod, fortran_real8, ompi_fortran_real8_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL16 && OMPI_REAL16_MATCHES_C +FORT_FLOAT_FUNC(prod, fortran_real16, ompi_fortran_real16_t) +#endif +/* Complex */ +#if 0 +#if defined(HAVE_SHORT_FLOAT__COMPLEX) +OP_FUNC(prod, c_short_float_complex, short float _Complex) +#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) +COMPLEX_PROD_FUNC(c_short_float_complex, opal_short_float_t) +#endif +OP_FUNC(prod, c_long_double_complex, long double _Complex) +#endif // 0 + +FUNC_FUNC(prod, c_float_complex, cuFloatComplex) +FUNC_FUNC(prod, c_double_complex, cuDoubleComplex) + +/************************************************************************* + * Logical AND + *************************************************************************/ + +/* C integer */ +FUNC_FUNC(land, int8_t, int8_t) +FUNC_FUNC(land, uint8_t, uint8_t) +FUNC_FUNC(land, int16_t, int16_t) +FUNC_FUNC(land, uint16_t, uint16_t) +FUNC_FUNC(land, int32_t, int32_t) +FUNC_FUNC(land, uint32_t, uint32_t) +FUNC_FUNC(land, int64_t, int64_t) +FUNC_FUNC(land, uint64_t, uint64_t) +FUNC_FUNC(land, long, long) +FUNC_FUNC(land, ulong, unsigned long) + +/* Logical */ +#if OMPI_HAVE_FORTRAN_LOGICAL +FORT_INT_FUNC(land, fortran_logical, ompi_fortran_logical_t) +#endif +/* C++ bool */ +FUNC_FUNC(land, bool, bool) + +/************************************************************************* + * Logical OR + *************************************************************************/ + +/* C integer */ +FUNC_FUNC(lor, int8_t, int8_t) +FUNC_FUNC(lor, 
uint8_t, uint8_t) +FUNC_FUNC(lor, int16_t, int16_t) +FUNC_FUNC(lor, uint16_t, uint16_t) +FUNC_FUNC(lor, int32_t, int32_t) +FUNC_FUNC(lor, uint32_t, uint32_t) +FUNC_FUNC(lor, int64_t, int64_t) +FUNC_FUNC(lor, uint64_t, uint64_t) +FUNC_FUNC(lor, long, long) +FUNC_FUNC(lor, ulong, unsigned long) + +/* Logical */ +#if OMPI_HAVE_FORTRAN_LOGICAL +FORT_INT_FUNC(lor, fortran_logical, ompi_fortran_logical_t) +#endif +/* C++ bool */ +FUNC_FUNC(lor, bool, bool) + +/************************************************************************* + * Logical XOR + *************************************************************************/ + +/* C integer */ +FUNC_FUNC(lxor, int8_t, int8_t) +FUNC_FUNC(lxor, uint8_t, uint8_t) +FUNC_FUNC(lxor, int16_t, int16_t) +FUNC_FUNC(lxor, uint16_t, uint16_t) +FUNC_FUNC(lxor, int32_t, int32_t) +FUNC_FUNC(lxor, uint32_t, uint32_t) +FUNC_FUNC(lxor, int64_t, int64_t) +FUNC_FUNC(lxor, uint64_t, uint64_t) +FUNC_FUNC(lxor, long, long) +FUNC_FUNC(lxor, ulong, unsigned long) + + +/* Logical */ +#if OMPI_HAVE_FORTRAN_LOGICAL +FORT_INT_FUNC(lxor, fortran_logical, ompi_fortran_logical_t) +#endif +/* C++ bool */ +FUNC_FUNC(lxor, bool, bool) + +/************************************************************************* + * Bitwise AND + *************************************************************************/ + +/* C integer */ +FUNC_FUNC(band, int8_t, int8_t) +FUNC_FUNC(band, uint8_t, uint8_t) +FUNC_FUNC(band, int16_t, int16_t) +FUNC_FUNC(band, uint16_t, uint16_t) +FUNC_FUNC(band, int32_t, int32_t) +FUNC_FUNC(band, uint32_t, uint32_t) +FUNC_FUNC(band, int64_t, int64_t) +FUNC_FUNC(band, uint64_t, uint64_t) +FUNC_FUNC(band, long, long) +FUNC_FUNC(band, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC(band, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC(band, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC(band, 
fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC(band, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC(band, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC(band, fortran_integer16, ompi_fortran_integer16_t) +#endif +/* Byte */ +FUNC_FUNC(band, byte, char) + +/************************************************************************* + * Bitwise OR + *************************************************************************/ + +/* C integer */ +FUNC_FUNC(bor, int8_t, int8_t) +FUNC_FUNC(bor, uint8_t, uint8_t) +FUNC_FUNC(bor, int16_t, int16_t) +FUNC_FUNC(bor, uint16_t, uint16_t) +FUNC_FUNC(bor, int32_t, int32_t) +FUNC_FUNC(bor, uint32_t, uint32_t) +FUNC_FUNC(bor, int64_t, int64_t) +FUNC_FUNC(bor, uint64_t, uint64_t) +FUNC_FUNC(bor, long, long) +FUNC_FUNC(bor, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC(bor, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC(bor, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC(bor, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC(bor, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC(bor, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC(bor, fortran_integer16, ompi_fortran_integer16_t) +#endif +/* Byte */ +FUNC_FUNC(bor, byte, char) + +/************************************************************************* + * Bitwise XOR + *************************************************************************/ + +/* C integer */ +FUNC_FUNC(bxor, int8_t, int8_t) +FUNC_FUNC(bxor, uint8_t, uint8_t) +FUNC_FUNC(bxor, int16_t, int16_t) +FUNC_FUNC(bxor, uint16_t, uint16_t) +FUNC_FUNC(bxor, int32_t, int32_t) 
+FUNC_FUNC(bxor, uint32_t, uint32_t) +FUNC_FUNC(bxor, int64_t, int64_t) +FUNC_FUNC(bxor, uint64_t, uint64_t) +FUNC_FUNC(bxor, long, long) +FUNC_FUNC(bxor, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC(bxor, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC(bxor, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC(bxor, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC(bxor, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC(bxor, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC(bxor, fortran_integer16, ompi_fortran_integer16_t) +#endif +/* Byte */ +FUNC_FUNC(bxor, byte, char) + +/************************************************************************* + * Max location + *************************************************************************/ + +LOC_FUNC(maxloc, float_int) +LOC_FUNC(maxloc, double_int) +LOC_FUNC(maxloc, long_int) +LOC_FUNC(maxloc, 2int) +LOC_FUNC(maxloc, short_int) +LOC_FUNC(maxloc, long_double_int) + +/* Fortran compat types */ +LOC_FUNC(maxloc, 2float) +LOC_FUNC(maxloc, 2double) +LOC_FUNC(maxloc, 2int8) +LOC_FUNC(maxloc, 2int16) +LOC_FUNC(maxloc, 2int32) +LOC_FUNC(maxloc, 2int64) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_LOC_INT_FUNC(maxloc, 2integer, ompi_fortran_integer_t) +#endif +/* Fortran float */ +#if OMPI_HAVE_FORTRAN_REAL +FORT_LOC_FLOAT_FUNC(maxloc, 2real, ompi_fortran_real_t) +#endif +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +FORT_LOC_FLOAT_FUNC(maxloc, 2double_precision, ompi_fortran_double_precision_t) +#endif + +/************************************************************************* + * Min location + *************************************************************************/ + +LOC_FUNC(minloc, float_int) +LOC_FUNC(minloc, 
double_int) +LOC_FUNC(minloc, long_int) +LOC_FUNC(minloc, 2int) +LOC_FUNC(minloc, short_int) +LOC_FUNC(minloc, long_double_int) + +/* Fortran compat types */ +LOC_FUNC(minloc, 2float) +LOC_FUNC(minloc, 2double) +LOC_FUNC(minloc, 2int8) +LOC_FUNC(minloc, 2int16) +LOC_FUNC(minloc, 2int32) +LOC_FUNC(minloc, 2int64) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_LOC_INT_FUNC(minloc, 2integer, ompi_fortran_integer_t) +#endif +/* Fortran float */ +#if OMPI_HAVE_FORTRAN_REAL +FORT_LOC_FLOAT_FUNC(minloc, 2real, ompi_fortran_real_t) +#endif +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +FORT_LOC_FLOAT_FUNC(minloc, 2double_precision, ompi_fortran_double_precision_t) +#endif + +/* + * This is a three buffer (2 input and 1 output) version of the reduction + * routines, needed for some optimizations. + */ +#define FUNC_3BUF(name, type_name, type) \ + static \ + void ompi_op_cuda_3buff_##name##_##type_name(const void *in1, const void *in2, void *out, int *count, \ + struct ompi_datatype_t **dtype, \ + int device, \ + opal_accelerator_stream_t *stream, \ + struct ompi_op_base_module_1_0_0_t *module) { \ + int threads_per_block, max_blocks; \ + int source1_device, source2_device, target_device; \ + type *source1, *source2, *target; \ + int n = *count; \ + device_op_pre(in1, (void**)&source1, &source1_device, \ + in2, (void**)&source2, &source2_device, \ + out, (void**)&target, &target_device, \ + n, *dtype, \ + &threads_per_block, &max_blocks, &device, stream); \ + CUstream *custream = (CUstream*)stream->stream; \ + ompi_op_cuda_3buff_##name##_##type_name##_submit(source1, source2, target, n, threads_per_block, max_blocks, *custream);\ + device_op_post(source1, source1_device, source2, source2_device, out, target, target_device, n, *dtype, device, stream);\ + } + + +#define OP_FUNC_3BUF(name, type_name, type, ...) 
FUNC_3BUF(name, __VA_ARGS__##type_name, __VA_ARGS__##type) + +/* reuse the macro above, no work is actually done so we don't care about the func */ +#define FUNC_FUNC_3BUF(name, type_name, type, ...) FUNC_3BUF(name, __VA_ARGS__##type_name, __VA_ARGS__##type) + +/* + * Since all the functions in this file are essentially identical, we + * use a macro to substitute in names and types. The core operation + * in all functions that use this macro is the same. + * + * This macro is for minloc and maxloc + */ +#define LOC_FUNC_3BUF(name, type_name) FUNC_3BUF(name, type_name, ompi_op_predefined_##type_name##_t) + + +/* Dispatch Fortran types to C types */ +#define FORT_INT_FUNC_3BUF(name, type_name, type) \ + static \ + void ompi_op_cuda_3buff_##name##_##type_name(const void *in1, const void *in2, void *out, int *count, \ + struct ompi_datatype_t **dtype, \ + int device, \ + opal_accelerator_stream_t *stream, \ + struct ompi_op_base_module_1_0_0_t *module) { \ + \ + _Static_assert(sizeof(type) >= sizeof(int8_t) && sizeof(type) <= sizeof(int64_t), \ + "Unsupported Fortran integer size"); \ + switch(sizeof(type)) { \ + case sizeof(int8_t): \ + ompi_op_cuda_3buff_##name##_int8_t(in1, in2, out, count, dtype, device, stream, module); \ + break; \ + case sizeof(int16_t): \ + ompi_op_cuda_3buff_##name##_int16_t(in1, in2, out, count, dtype, device, stream, module); \ + break; \ + case sizeof(int32_t): \ + ompi_op_cuda_3buff_##name##_int32_t(in1, in2, out, count, dtype, device, stream, module); \ + break; \ + case sizeof(int64_t): \ + ompi_op_cuda_3buff_##name##_int64_t(in1, in2, out, count, dtype, device, stream, module); \ + break; \ + } \ + } + +/* Dispatch Fortran types to C types */ +#define FORT_FLOAT_FUNC_3BUF(name, type_name, type) \ + static \ + void ompi_op_cuda_3buff_##name##_##type_name(const void *in1, const void *in2, void *out, int *count, \ + struct ompi_datatype_t **dtype, \ + int device, \ + opal_accelerator_stream_t *stream, \ + struct ompi_op_base_module_1_0_0_t 
*module) { \ + _Static_assert(sizeof(type) >= sizeof(float) && sizeof(type) <= sizeof(long double), \ + "Unsupported Fortran float size"); \ + switch(sizeof(type)) { \ + case sizeof(float): \ + ompi_op_cuda_3buff_##name##_float(in1, in2, out, count, dtype, device, stream, module); \ + break; \ + case sizeof(double): \ + ompi_op_cuda_3buff_##name##_double(in1, in2, out, count, dtype, device, stream, module); \ + break; \ + case sizeof(long double): \ + ompi_op_cuda_3buff_##name##_long_double(in1, in2, out, count, dtype, device, stream, module); \ + break; \ + } \ + } + +#define FORT_LOC_INT_FUNC_3BUF(name, type_name, type) \ + static \ + void ompi_op_cuda_3buff_##name##_##type_name(const void *in1, const void *in2, void *inout, int *count, \ + struct ompi_datatype_t **dtype, \ + int device, \ + opal_accelerator_stream_t *stream, \ + struct ompi_op_base_module_1_0_0_t *module) { \ + \ + _Static_assert(sizeof(type) >= sizeof(int8_t) && sizeof(type) <= sizeof(int64_t), \ + "Unsupported Fortran integer size"); \ + switch(sizeof(type)) { \ + case sizeof(int8_t): \ + ompi_op_cuda_3buff_##name##_2int8(in1, in2, inout, count, dtype, device, stream, module); \ + break; \ + case sizeof(int16_t): \ + ompi_op_cuda_3buff_##name##_2int16(in1, in2, inout, count, dtype, device, stream, module); \ + break; \ + case sizeof(int32_t): \ + ompi_op_cuda_3buff_##name##_2int32(in1, in2, inout, count, dtype, device, stream, module); \ + break; \ + case sizeof(int64_t): \ + ompi_op_cuda_3buff_##name##_2int64(in1, in2, inout, count, dtype, device, stream, module); \ + break; \ + } \ + } + +#define FORT_LOC_FLOAT_FUNC_3BUF(name, type_name, type) \ + static \ + void ompi_op_cuda_3buff_##name##_##type_name(const void *in1, const void *in2, void *inout, int *count, \ + struct ompi_datatype_t **dtype, \ + int device, \ + opal_accelerator_stream_t *stream, \ + struct ompi_op_base_module_1_0_0_t *module) { \ + _Static_assert(sizeof(type) >= sizeof(float) && sizeof(type) <= sizeof(double), \ + 
"IUnsuported Fortran float size"); \ + switch(sizeof(type)) { \ + case sizeof(float): \ + ompi_op_cuda_3buff_##name##_2float(in1, in2, inout, count, dtype, device, stream, module); \ + break; \ + case sizeof(double): \ + ompi_op_cuda_3buff_##name##_2double(in1, in2, inout, count, dtype, device, stream, module); \ + break; \ + } \ + } + + +/************************************************************************* + * Max + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_3BUF(max, int8_t, int8_t) +FUNC_FUNC_3BUF(max, uint8_t, uint8_t) +FUNC_FUNC_3BUF(max, int16_t, int16_t) +FUNC_FUNC_3BUF(max, uint16_t, uint16_t) +FUNC_FUNC_3BUF(max, int32_t, int32_t) +FUNC_FUNC_3BUF(max, uint32_t, uint32_t) +FUNC_FUNC_3BUF(max, int64_t, int64_t) +FUNC_FUNC_3BUF(max, uint64_t, uint64_t) +FUNC_FUNC_3BUF(max, long, long) +FUNC_FUNC_3BUF(max, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC_3BUF(max, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC_3BUF(max, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC_3BUF(max, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC_3BUF(max, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC_3BUF(max, fortran_integer8, ompi_fortran_integer8_t) +#endif +/* Floating point */ +#if 0 +#if defined(HAVE_SHORT_FLOAT) +FUNC_FUNC_3BUF(max, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +FUNC_FUNC_3BUF(max, short_float, opal_short_float_t) +#endif +#endif // 0 +FUNC_FUNC_3BUF(max, float, float) +FUNC_FUNC_3BUF(max, double, double) +FUNC_FUNC_3BUF(max, long_double, long double) +#if OMPI_HAVE_FORTRAN_REAL +FORT_FLOAT_FUNC_3BUF(max, fortran_real, ompi_fortran_real_t) +#endif +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +FORT_FLOAT_FUNC_3BUF(max, 
fortran_double_precision, ompi_fortran_double_precision_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL2 +FORT_FLOAT_FUNC_3BUF(max, fortran_real2, ompi_fortran_real2_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL4 +FORT_FLOAT_FUNC_3BUF(max, fortran_real4, ompi_fortran_real4_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL8 +FORT_FLOAT_FUNC_3BUF(max, fortran_real8, ompi_fortran_real8_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL16 && OMPI_REAL16_MATCHES_C +FORT_FLOAT_FUNC_3BUF(max, fortran_real16, ompi_fortran_real16_t) +#endif + + +/************************************************************************* + * Min + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_3BUF(min, int8_t, int8_t) +FUNC_FUNC_3BUF(min, uint8_t, uint8_t) +FUNC_FUNC_3BUF(min, int16_t, int16_t) +FUNC_FUNC_3BUF(min, uint16_t, uint16_t) +FUNC_FUNC_3BUF(min, int32_t, int32_t) +FUNC_FUNC_3BUF(min, uint32_t, uint32_t) +FUNC_FUNC_3BUF(min, int64_t, int64_t) +FUNC_FUNC_3BUF(min, uint64_t, uint64_t) +FUNC_FUNC_3BUF(min, long, long) +FUNC_FUNC_3BUF(min, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC_3BUF(min, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC_3BUF(min, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC_3BUF(min, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC_3BUF(min, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC_3BUF(min, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC_3BUF(min, fortran_integer16, ompi_fortran_integer16_t) +#endif +/* Floating point */ +#if 0 +#if defined(HAVE_SHORT_FLOAT) +FUNC_FUNC_3BUF(min, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +FUNC_FUNC_3BUF(min, short_float, opal_short_float_t) +#endif +#endif // 0 +FUNC_FUNC_3BUF(min, float, 
float) +FUNC_FUNC_3BUF(min, double, double) +FUNC_FUNC_3BUF(min, long_double, long double) +#if OMPI_HAVE_FORTRAN_REAL +FORT_FLOAT_FUNC_3BUF(min, fortran_real, ompi_fortran_real_t) +#endif +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +FORT_FLOAT_FUNC_3BUF(min, fortran_double_precision, ompi_fortran_double_precision_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL2 +FORT_FLOAT_FUNC_3BUF(min, fortran_real2, ompi_fortran_real2_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL4 +FORT_FLOAT_FUNC_3BUF(min, fortran_real4, ompi_fortran_real4_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL8 +FORT_FLOAT_FUNC_3BUF(min, fortran_real8, ompi_fortran_real8_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL16 && OMPI_REAL16_MATCHES_C +FORT_FLOAT_FUNC_3BUF(min, fortran_real16, ompi_fortran_real16_t) +#endif + +/************************************************************************* + * Sum + *************************************************************************/ + +/* C integer */ +OP_FUNC_3BUF(sum, int8_t, int8_t) +OP_FUNC_3BUF(sum, uint8_t, uint8_t) +OP_FUNC_3BUF(sum, int16_t, int16_t) +OP_FUNC_3BUF(sum, uint16_t, uint16_t) +OP_FUNC_3BUF(sum, int32_t, int32_t) +OP_FUNC_3BUF(sum, uint32_t, uint32_t) +OP_FUNC_3BUF(sum, int64_t, int64_t) +OP_FUNC_3BUF(sum, uint64_t, uint64_t) +OP_FUNC_3BUF(sum, long, long) +OP_FUNC_3BUF(sum, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC_3BUF(sum, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC_3BUF(sum, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC_3BUF(sum, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC_3BUF(sum, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC_3BUF(sum, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC_3BUF(sum, fortran_integer16, ompi_fortran_integer16_t) +#endif +/* Floating 
point */ +#if 0 +#if defined(HAVE_SHORT_FLOAT) +OP_FUNC_3BUF(sum, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +OP_FUNC_3BUF(sum, short_float, opal_short_float_t) +#endif +#endif // 0 +OP_FUNC_3BUF(sum, float, float) +OP_FUNC_3BUF(sum, double, double) +OP_FUNC_3BUF(sum, long_double, long double) +#if OMPI_HAVE_FORTRAN_REAL +FORT_FLOAT_FUNC_3BUF(sum, fortran_real, ompi_fortran_real_t) +#endif +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +FORT_FLOAT_FUNC_3BUF(sum, fortran_double_precision, ompi_fortran_double_precision_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL2 +FORT_FLOAT_FUNC_3BUF(sum, fortran_real2, ompi_fortran_real2_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL4 +FORT_FLOAT_FUNC_3BUF(sum, fortran_real4, ompi_fortran_real4_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL8 +FORT_FLOAT_FUNC_3BUF(sum, fortran_real8, ompi_fortran_real8_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL16 && OMPI_REAL16_MATCHES_C +FORT_FLOAT_FUNC_3BUF(sum, fortran_real16, ompi_fortran_real16_t) +#endif +/* Complex */ +#if 0 +#if defined(HAVE_SHORT_FLOAT__COMPLEX) +OP_FUNC_3BUF(sum, c_short_float_complex, short float _Complex) +#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) +COMPLEX_SUM_FUNC_3BUF(c_short_float_complex, opal_short_float_t) +#endif +OP_FUNC_3BUF(sum, c_float_complex, float _Complex) +OP_FUNC_3BUF(sum, c_double_complex, double _Complex) +OP_FUNC_3BUF(sum, c_long_double_complex, long double _Complex) +#endif // 0 + +FUNC_FUNC_3BUF(sum, c_float_complex, cuFloatComplex) +FUNC_FUNC_3BUF(sum, c_double_complex, cuDoubleComplex) + +/************************************************************************* + * Product + *************************************************************************/ + +/* C integer */ +OP_FUNC_3BUF(prod, int8_t, int8_t) +OP_FUNC_3BUF(prod, uint8_t, uint8_t) +OP_FUNC_3BUF(prod, int16_t, int16_t) +OP_FUNC_3BUF(prod, uint16_t, uint16_t) +OP_FUNC_3BUF(prod, int32_t, int32_t) +OP_FUNC_3BUF(prod, uint32_t, uint32_t) +OP_FUNC_3BUF(prod, int64_t, int64_t) +OP_FUNC_3BUF(prod, 
uint64_t, uint64_t) +OP_FUNC_3BUF(prod, long, long) +OP_FUNC_3BUF(prod, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC_3BUF(prod, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC_3BUF(prod, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC_3BUF(prod, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC_3BUF(prod, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC_3BUF(prod, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC_3BUF(prod, fortran_integer16, ompi_fortran_integer16_t) +#endif +/* Floating point */ +#if 0 +#if defined(HAVE_SHORT_FLOAT) +FORT_FLOAT_FUNC_3BUF(prod, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +FORT_FLOAT_FUNC_3BUF(prod, short_float, opal_short_float_t) +#endif +#endif // 0 +OP_FUNC_3BUF(prod, float, float) +OP_FUNC_3BUF(prod, double, double) +OP_FUNC_3BUF(prod, long_double, long double) +#if OMPI_HAVE_FORTRAN_REAL +FORT_FLOAT_FUNC_3BUF(prod, fortran_real, ompi_fortran_real_t) +#endif +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +FORT_FLOAT_FUNC_3BUF(prod, fortran_double_precision, ompi_fortran_double_precision_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL2 +FORT_FLOAT_FUNC_3BUF(prod, fortran_real2, ompi_fortran_real2_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL4 +FORT_FLOAT_FUNC_3BUF(prod, fortran_real4, ompi_fortran_real4_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL8 +FORT_FLOAT_FUNC_3BUF(prod, fortran_real8, ompi_fortran_real8_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL16 && OMPI_REAL16_MATCHES_C +FORT_FLOAT_FUNC_3BUF(prod, fortran_real16, ompi_fortran_real16_t) +#endif +/* Complex */ +#if 0 +#if defined(HAVE_SHORT_FLOAT__COMPLEX) +OP_FUNC_3BUF(prod, c_short_float_complex, short float _Complex) +#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) 
+COMPLEX_PROD_FUNC_3BUF(c_short_float_complex, opal_short_float_t) +#endif +OP_FUNC_3BUF(prod, c_long_double_complex, long double _Complex) +#endif // 0 + +FUNC_FUNC_3BUF(prod, c_float_complex, cuFloatComplex) +FUNC_FUNC_3BUF(prod, c_double_complex, cuDoubleComplex) + +/************************************************************************* + * Logical AND + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_3BUF(land, int8_t, int8_t) +FUNC_FUNC_3BUF(land, uint8_t, uint8_t) +FUNC_FUNC_3BUF(land, int16_t, int16_t) +FUNC_FUNC_3BUF(land, uint16_t, uint16_t) +FUNC_FUNC_3BUF(land, int32_t, int32_t) +FUNC_FUNC_3BUF(land, uint32_t, uint32_t) +FUNC_FUNC_3BUF(land, int64_t, int64_t) +FUNC_FUNC_3BUF(land, uint64_t, uint64_t) +FUNC_FUNC_3BUF(land, long, long) +FUNC_FUNC_3BUF(land, ulong, unsigned long) + +/* Logical */ +#if OMPI_HAVE_FORTRAN_LOGICAL +FORT_INT_FUNC_3BUF(land, fortran_logical, ompi_fortran_logical_t) +#endif +/* C++ bool */ +FUNC_FUNC_3BUF(land, bool, bool) + +/************************************************************************* + * Logical OR + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_3BUF(lor, int8_t, int8_t) +FUNC_FUNC_3BUF(lor, uint8_t, uint8_t) +FUNC_FUNC_3BUF(lor, int16_t, int16_t) +FUNC_FUNC_3BUF(lor, uint16_t, uint16_t) +FUNC_FUNC_3BUF(lor, int32_t, int32_t) +FUNC_FUNC_3BUF(lor, uint32_t, uint32_t) +FUNC_FUNC_3BUF(lor, int64_t, int64_t) +FUNC_FUNC_3BUF(lor, uint64_t, uint64_t) +FUNC_FUNC_3BUF(lor, long, long) +FUNC_FUNC_3BUF(lor, ulong, unsigned long) + +/* Logical */ +#if OMPI_HAVE_FORTRAN_LOGICAL +FORT_INT_FUNC_3BUF(lor, fortran_logical, ompi_fortran_logical_t) +#endif +/* C++ bool */ +FUNC_FUNC_3BUF(lor, bool, bool) + +/************************************************************************* + * Logical XOR + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_3BUF(lxor, 
int8_t, int8_t) +FUNC_FUNC_3BUF(lxor, uint8_t, uint8_t) +FUNC_FUNC_3BUF(lxor, int16_t, int16_t) +FUNC_FUNC_3BUF(lxor, uint16_t, uint16_t) +FUNC_FUNC_3BUF(lxor, int32_t, int32_t) +FUNC_FUNC_3BUF(lxor, uint32_t, uint32_t) +FUNC_FUNC_3BUF(lxor, int64_t, int64_t) +FUNC_FUNC_3BUF(lxor, uint64_t, uint64_t) +FUNC_FUNC_3BUF(lxor, long, long) +FUNC_FUNC_3BUF(lxor, ulong, unsigned long) + +/* Logical */ +#if OMPI_HAVE_FORTRAN_LOGICAL +FORT_INT_FUNC_3BUF(lxor, fortran_logical, ompi_fortran_logical_t) +#endif +/* C++ bool */ +FUNC_FUNC_3BUF(lxor, bool, bool) + +/************************************************************************* + * Bitwise AND + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_3BUF(band, int8_t, int8_t) +FUNC_FUNC_3BUF(band, uint8_t, uint8_t) +FUNC_FUNC_3BUF(band, int16_t, int16_t) +FUNC_FUNC_3BUF(band, uint16_t, uint16_t) +FUNC_FUNC_3BUF(band, int32_t, int32_t) +FUNC_FUNC_3BUF(band, uint32_t, uint32_t) +FUNC_FUNC_3BUF(band, int64_t, int64_t) +FUNC_FUNC_3BUF(band, uint64_t, uint64_t) +FUNC_FUNC_3BUF(band, long, long) +FUNC_FUNC_3BUF(band, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC_3BUF(band, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC_3BUF(band, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC_3BUF(band, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC_3BUF(band, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC_3BUF(band, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC_3BUF(band, fortran_integer16, ompi_fortran_integer16_t) +#endif +/* Byte */ +FORT_INT_FUNC_3BUF(band, byte, char) + +/************************************************************************* + * Bitwise OR + 
*************************************************************************/ + +/* C integer */ +FUNC_FUNC_3BUF(bor, int8_t, int8_t) +FUNC_FUNC_3BUF(bor, uint8_t, uint8_t) +FUNC_FUNC_3BUF(bor, int16_t, int16_t) +FUNC_FUNC_3BUF(bor, uint16_t, uint16_t) +FUNC_FUNC_3BUF(bor, int32_t, int32_t) +FUNC_FUNC_3BUF(bor, uint32_t, uint32_t) +FUNC_FUNC_3BUF(bor, int64_t, int64_t) +FUNC_FUNC_3BUF(bor, uint64_t, uint64_t) +FUNC_FUNC_3BUF(bor, long, long) +FUNC_FUNC_3BUF(bor, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC_3BUF(bor, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC_3BUF(bor, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC_3BUF(bor, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC_3BUF(bor, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC_3BUF(bor, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC_3BUF(bor, fortran_integer16, ompi_fortran_integer16_t) +#endif +/* Byte */ +FORT_INT_FUNC_3BUF(bor, byte, char) + +/************************************************************************* + * Bitwise XOR + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_3BUF(bxor, int8_t, int8_t) +FUNC_FUNC_3BUF(bxor, uint8_t, uint8_t) +FUNC_FUNC_3BUF(bxor, int16_t, int16_t) +FUNC_FUNC_3BUF(bxor, uint16_t, uint16_t) +FUNC_FUNC_3BUF(bxor, int32_t, int32_t) +FUNC_FUNC_3BUF(bxor, uint32_t, uint32_t) +FUNC_FUNC_3BUF(bxor, int64_t, int64_t) +FUNC_FUNC_3BUF(bxor, uint64_t, uint64_t) +FUNC_FUNC_3BUF(bxor, long, long) +FUNC_FUNC_3BUF(bxor, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC_3BUF(bxor, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC_3BUF(bxor, 
fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC_3BUF(bxor, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC_3BUF(bxor, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC_3BUF(bxor, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC_3BUF(bxor, fortran_integer16, ompi_fortran_integer16_t) +#endif +/* Byte */ +FORT_INT_FUNC_3BUF(bxor, byte, char) + +/************************************************************************* + * Max location + *************************************************************************/ + +LOC_FUNC_3BUF(maxloc, float_int) +LOC_FUNC_3BUF(maxloc, double_int) +LOC_FUNC_3BUF(maxloc, long_int) +LOC_FUNC_3BUF(maxloc, 2int) +LOC_FUNC_3BUF(maxloc, short_int) +LOC_FUNC_3BUF(maxloc, long_double_int) + +/* Fortran compat types */ +LOC_FUNC_3BUF(maxloc, 2float) +LOC_FUNC_3BUF(maxloc, 2double) +LOC_FUNC_3BUF(maxloc, 2int8) +LOC_FUNC_3BUF(maxloc, 2int16) +LOC_FUNC_3BUF(maxloc, 2int32) +LOC_FUNC_3BUF(maxloc, 2int64) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_LOC_INT_FUNC_3BUF(maxloc, 2integer, ompi_fortran_integer_t) +#endif +/* Fortran float */ +#if OMPI_HAVE_FORTRAN_REAL +FORT_LOC_FLOAT_FUNC_3BUF(maxloc, 2real, ompi_fortran_real_t) +#endif +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +FORT_LOC_FLOAT_FUNC_3BUF(maxloc, 2double_precision, ompi_fortran_double_precision_t) +#endif + + +/************************************************************************* + * Min location + *************************************************************************/ + +LOC_FUNC_3BUF(minloc, float_int) +LOC_FUNC_3BUF(minloc, double_int) +LOC_FUNC_3BUF(minloc, long_int) +LOC_FUNC_3BUF(minloc, 2int) +LOC_FUNC_3BUF(minloc, short_int) +LOC_FUNC_3BUF(minloc, long_double_int) + +/* Fortran compat types */ +LOC_FUNC_3BUF(minloc, 2float) +LOC_FUNC_3BUF(minloc, 2double) 
+LOC_FUNC_3BUF(minloc, 2int8) +LOC_FUNC_3BUF(minloc, 2int16) +LOC_FUNC_3BUF(minloc, 2int32) +LOC_FUNC_3BUF(minloc, 2int64) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_LOC_INT_FUNC_3BUF(minloc, 2integer, ompi_fortran_integer_t) +#endif +/* Fortran float */ +#if OMPI_HAVE_FORTRAN_REAL +FORT_LOC_FLOAT_FUNC_3BUF(minloc, 2real, ompi_fortran_real_t) +#endif +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +FORT_LOC_FLOAT_FUNC_3BUF(minloc, 2double_precision, ompi_fortran_double_precision_t) +#endif + +/* + * Helpful defines, because there's soooo many names! + * + * **NOTE** These #define's used to be strictly ordered but the use of + * designated initializers removed this restrictions. When adding new + * operators ALWAYS use a designated initializer! + */ + +/** C integer ***********************************************************/ +#define C_INTEGER(name, ftype) \ + [OMPI_OP_BASE_TYPE_INT8_T] = ompi_op_cuda_##ftype##_##name##_int8_t, \ + [OMPI_OP_BASE_TYPE_UINT8_T] = ompi_op_cuda_##ftype##_##name##_uint8_t, \ + [OMPI_OP_BASE_TYPE_INT16_T] = ompi_op_cuda_##ftype##_##name##_int16_t, \ + [OMPI_OP_BASE_TYPE_UINT16_T] = ompi_op_cuda_##ftype##_##name##_uint16_t, \ + [OMPI_OP_BASE_TYPE_INT32_T] = ompi_op_cuda_##ftype##_##name##_int32_t, \ + [OMPI_OP_BASE_TYPE_UINT32_T] = ompi_op_cuda_##ftype##_##name##_uint32_t, \ + [OMPI_OP_BASE_TYPE_INT64_T] = ompi_op_cuda_##ftype##_##name##_int64_t, \ + [OMPI_OP_BASE_TYPE_LONG] = ompi_op_cuda_##ftype##_##name##_long, \ + [OMPI_OP_BASE_TYPE_UNSIGNED_LONG] = ompi_op_cuda_##ftype##_##name##_ulong, \ + [OMPI_OP_BASE_TYPE_UINT64_T] = ompi_op_cuda_##ftype##_##name##_uint64_t + +/** All the Fortran integers ********************************************/ + +#if OMPI_HAVE_FORTRAN_INTEGER +#define FORTRAN_INTEGER_PLAIN(name, ftype) ompi_op_cuda_##ftype##_##name##_fortran_integer +#else +#define FORTRAN_INTEGER_PLAIN(name, ftype) NULL +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +#define FORTRAN_INTEGER1(name, ftype) 
ompi_op_cuda_##ftype##_##name##_fortran_integer1 +#else +#define FORTRAN_INTEGER1(name, ftype) NULL +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +#define FORTRAN_INTEGER2(name, ftype) ompi_op_cuda_##ftype##_##name##_fortran_integer2 +#else +#define FORTRAN_INTEGER2(name, ftype) NULL +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +#define FORTRAN_INTEGER4(name, ftype) ompi_op_cuda_##ftype##_##name##_fortran_integer4 +#else +#define FORTRAN_INTEGER4(name, ftype) NULL +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +#define FORTRAN_INTEGER8(name, ftype) ompi_op_cuda_##ftype##_##name##_fortran_integer8 +#else +#define FORTRAN_INTEGER8(name, ftype) NULL +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +#define FORTRAN_INTEGER16(name, ftype) ompi_op_cuda_##ftype##_##name##_fortran_integer16 +#else +#define FORTRAN_INTEGER16(name, ftype) NULL +#endif + +#define FORTRAN_INTEGER(name, ftype) \ + [OMPI_OP_BASE_TYPE_INTEGER] = FORTRAN_INTEGER_PLAIN(name, ftype), \ + [OMPI_OP_BASE_TYPE_INTEGER1] = FORTRAN_INTEGER1(name, ftype), \ + [OMPI_OP_BASE_TYPE_INTEGER2] = FORTRAN_INTEGER2(name, ftype), \ + [OMPI_OP_BASE_TYPE_INTEGER4] = FORTRAN_INTEGER4(name, ftype), \ + [OMPI_OP_BASE_TYPE_INTEGER8] = FORTRAN_INTEGER8(name, ftype), \ + [OMPI_OP_BASE_TYPE_INTEGER16] = FORTRAN_INTEGER16(name, ftype) + +/** All the Fortran reals ***********************************************/ + +#if OMPI_HAVE_FORTRAN_REAL +#define FLOATING_POINT_FORTRAN_REAL_PLAIN(name, ftype) ompi_op_cuda_##ftype##_##name##_fortran_real +#else +#define FLOATING_POINT_FORTRAN_REAL_PLAIN(name, ftype) NULL +#endif +#if OMPI_HAVE_FORTRAN_REAL2 +#define FLOATING_POINT_FORTRAN_REAL2(name, ftype) ompi_op_cuda_##ftype##_##name##_fortran_real2 +#else +#define FLOATING_POINT_FORTRAN_REAL2(name, ftype) NULL +#endif +#if OMPI_HAVE_FORTRAN_REAL4 +#define FLOATING_POINT_FORTRAN_REAL4(name, ftype) ompi_op_cuda_##ftype##_##name##_fortran_real4 +#else +#define FLOATING_POINT_FORTRAN_REAL4(name, ftype) NULL +#endif +#if OMPI_HAVE_FORTRAN_REAL8 +#define 
FLOATING_POINT_FORTRAN_REAL8(name, ftype) ompi_op_cuda_##ftype##_##name##_fortran_real8 +#else +#define FLOATING_POINT_FORTRAN_REAL8(name, ftype) NULL +#endif +/* If: + - we have fortran REAL*16and* + - fortran REAL*16 matches the bit representation of the + corresponding C type + Only then do we put in function pointers for REAL*16 reductions. + Otherwise, just put in NULL. */ +#if OMPI_HAVE_FORTRAN_REAL16 && OMPI_REAL16_MATCHES_C +#define FLOATING_POINT_FORTRAN_REAL16(name, ftype) ompi_op_cuda_##ftype##_##name##_fortran_real16 +#else +#define FLOATING_POINT_FORTRAN_REAL16(name, ftype) NULL +#endif + +#define FLOATING_POINT_FORTRAN_REAL(name, ftype) \ + [OMPI_OP_BASE_TYPE_REAL] = FLOATING_POINT_FORTRAN_REAL_PLAIN(name, ftype), \ + [OMPI_OP_BASE_TYPE_REAL2] = FLOATING_POINT_FORTRAN_REAL2(name, ftype), \ + [OMPI_OP_BASE_TYPE_REAL4] = FLOATING_POINT_FORTRAN_REAL4(name, ftype), \ + [OMPI_OP_BASE_TYPE_REAL8] = FLOATING_POINT_FORTRAN_REAL8(name, ftype), \ + [OMPI_OP_BASE_TYPE_REAL16] = FLOATING_POINT_FORTRAN_REAL16(name, ftype) + +/** Fortran double precision ********************************************/ + +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +#define FLOATING_POINT_FORTRAN_DOUBLE_PRECISION(name, ftype) \ + ompi_op_cuda_##ftype##_##name##_fortran_double_precision +#else +#define FLOATING_POINT_FORTRAN_DOUBLE_PRECISION(name, ftype) NULL +#endif + +/** Floating point, including all the Fortran reals *********************/ + +//#if defined(HAVE_SHORT_FLOAT) || defined(HAVE_OPAL_SHORT_FLOAT_T) +//#define SHORT_FLOAT(name, ftype) ompi_op_cuda_##ftype##_##name##_short_float +//#else +#define SHORT_FLOAT(name, ftype) NULL +//#endif +#define FLOAT(name, ftype) ompi_op_cuda_##ftype##_##name##_float +#define DOUBLE(name, ftype) ompi_op_cuda_##ftype##_##name##_double +#define LONG_DOUBLE(name, ftype) ompi_op_cuda_##ftype##_##name##_long_double + +#define FLOATING_POINT(name, ftype) \ + [OMPI_OP_BASE_TYPE_SHORT_FLOAT] = SHORT_FLOAT(name, ftype), \ + [OMPI_OP_BASE_TYPE_FLOAT] = 
FLOAT(name, ftype), \ + [OMPI_OP_BASE_TYPE_DOUBLE] = DOUBLE(name, ftype), \ + FLOATING_POINT_FORTRAN_REAL(name, ftype), \ + [OMPI_OP_BASE_TYPE_DOUBLE_PRECISION] = FLOATING_POINT_FORTRAN_DOUBLE_PRECISION(name, ftype), \ + [OMPI_OP_BASE_TYPE_LONG_DOUBLE] = LONG_DOUBLE(name, ftype) + +/** Fortran logical *****************************************************/ + +#if OMPI_HAVE_FORTRAN_LOGICAL +#define FORTRAN_LOGICAL(name, ftype) \ + ompi_op_cuda_##ftype##_##name##_fortran_logical /* OMPI_OP_CUDA_TYPE_LOGICAL */ +#else +#define FORTRAN_LOGICAL(name, ftype) NULL +#endif + +#define LOGICAL(name, ftype) \ + [OMPI_OP_BASE_TYPE_LOGICAL] = FORTRAN_LOGICAL(name, ftype), \ + [OMPI_OP_BASE_TYPE_BOOL] = ompi_op_cuda_##ftype##_##name##_bool + +/** Complex *****************************************************/ +#if 0 + +#if defined(HAVE_SHORT_FLOAT__COMPLEX) || defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) +#define SHORT_FLOAT_COMPLEX(name, ftype) ompi_op_cuda_##ftype##_##name##_c_short_float_complex +#else +#define SHORT_FLOAT_COMPLEX(name, ftype) NULL +#endif +#define LONG_DOUBLE_COMPLEX(name, ftype) ompi_op_cuda_##ftype##_##name##_c_long_double_complex +#else +#define SHORT_FLOAT_COMPLEX(name, ftype) NULL +#define LONG_DOUBLE_COMPLEX(name, ftype) NULL +#endif // 0 +#define FLOAT_COMPLEX(name, ftype) ompi_op_cuda_##ftype##_##name##_c_float_complex +#define DOUBLE_COMPLEX(name, ftype) ompi_op_cuda_##ftype##_##name##_c_double_complex + +#define COMPLEX(name, ftype) \ + [OMPI_OP_BASE_TYPE_C_SHORT_FLOAT_COMPLEX] = SHORT_FLOAT_COMPLEX(name, ftype), \ + [OMPI_OP_BASE_TYPE_C_FLOAT_COMPLEX] = FLOAT_COMPLEX(name, ftype), \ + [OMPI_OP_BASE_TYPE_C_DOUBLE_COMPLEX] = DOUBLE_COMPLEX(name, ftype), \ + [OMPI_OP_BASE_TYPE_C_LONG_DOUBLE_COMPLEX] = LONG_DOUBLE_COMPLEX(name, ftype) + +/** Byte ****************************************************************/ + +#define BYTE(name, ftype) \ + [OMPI_OP_BASE_TYPE_BYTE] = ompi_op_cuda_##ftype##_##name##_byte + +/** Fortran complex 
*****************************************************/ +/** Fortran "2" types ***************************************************/ + +#if OMPI_HAVE_FORTRAN_REAL +#define TWOLOC_FORTRAN_2REAL(name, ftype) ompi_op_cuda_##ftype##_##name##_2real +#else +#define TWOLOC_FORTRAN_2REAL(name, ftype) NULL +#endif + +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +#define TWOLOC_FORTRAN_2DOUBLE_PRECISION(name, ftype) ompi_op_cuda_##ftype##_##name##_2double_precision +#else +#define TWOLOC_FORTRAN_2DOUBLE_PRECISION(name, ftype) NULL +#endif +#if OMPI_HAVE_FORTRAN_INTEGER +#define TWOLOC_FORTRAN_2INTEGER(name, ftype) ompi_op_cuda_##ftype##_##name##_2integer +#else +#define TWOLOC_FORTRAN_2INTEGER(name, ftype) NULL +#endif + +/** All "2" types *******************************************************/ + +#define TWOLOC(name, ftype) \ + [OMPI_OP_BASE_TYPE_2REAL] = TWOLOC_FORTRAN_2REAL(name, ftype), \ + [OMPI_OP_BASE_TYPE_2DOUBLE_PRECISION] = TWOLOC_FORTRAN_2DOUBLE_PRECISION(name, ftype), \ + [OMPI_OP_BASE_TYPE_2INTEGER] = TWOLOC_FORTRAN_2INTEGER(name, ftype), \ + [OMPI_OP_BASE_TYPE_FLOAT_INT] = ompi_op_cuda_##ftype##_##name##_float_int, \ + [OMPI_OP_BASE_TYPE_DOUBLE_INT] = ompi_op_cuda_##ftype##_##name##_double_int, \ + [OMPI_OP_BASE_TYPE_LONG_INT] = ompi_op_cuda_##ftype##_##name##_long_int, \ + [OMPI_OP_BASE_TYPE_2INT] = ompi_op_cuda_##ftype##_##name##_2int, \ + [OMPI_OP_BASE_TYPE_SHORT_INT] = ompi_op_cuda_##ftype##_##name##_short_int, \ + [OMPI_OP_BASE_TYPE_LONG_DOUBLE_INT] = ompi_op_cuda_##ftype##_##name##_long_double_int + +/* + * MPI_OP_NULL + * All types + */ +#define FLAGS_NO_FLOAT \ + (OMPI_OP_FLAGS_INTRINSIC | OMPI_OP_FLAGS_ASSOC | OMPI_OP_FLAGS_COMMUTE) +#define FLAGS \ + (OMPI_OP_FLAGS_INTRINSIC | OMPI_OP_FLAGS_ASSOC | \ + OMPI_OP_FLAGS_FLOAT_ASSOC | OMPI_OP_FLAGS_COMMUTE) + +ompi_op_base_stream_handler_fn_t ompi_op_cuda_functions[OMPI_OP_BASE_FORTRAN_OP_MAX][OMPI_OP_BASE_TYPE_MAX] = + { + /* Corresponds to MPI_OP_NULL */ + [OMPI_OP_BASE_FORTRAN_NULL] = { + /* Leaving this empty 
puts in NULL for all entries */ + NULL, + }, + /* Corresponds to MPI_MAX */ + [OMPI_OP_BASE_FORTRAN_MAX] = { + C_INTEGER(max, 2buff), + FORTRAN_INTEGER(max, 2buff), + FLOATING_POINT(max, 2buff), + }, + /* Corresponds to MPI_MIN */ + [OMPI_OP_BASE_FORTRAN_MIN] = { + C_INTEGER(min, 2buff), + FORTRAN_INTEGER(min, 2buff), + FLOATING_POINT(min, 2buff), + }, + /* Corresponds to MPI_SUM */ + [OMPI_OP_BASE_FORTRAN_SUM] = { + C_INTEGER(sum, 2buff), + FORTRAN_INTEGER(sum, 2buff), + FLOATING_POINT(sum, 2buff), + COMPLEX(sum, 2buff), + }, + /* Corresponds to MPI_PROD */ + [OMPI_OP_BASE_FORTRAN_PROD] = { + C_INTEGER(prod, 2buff), + FORTRAN_INTEGER(prod, 2buff), + FLOATING_POINT(prod, 2buff), + COMPLEX(prod, 2buff), + }, + /* Corresponds to MPI_LAND */ + [OMPI_OP_BASE_FORTRAN_LAND] = { + C_INTEGER(land, 2buff), + LOGICAL(land, 2buff), + }, + /* Corresponds to MPI_BAND */ + [OMPI_OP_BASE_FORTRAN_BAND] = { + C_INTEGER(band, 2buff), + FORTRAN_INTEGER(band, 2buff), + BYTE(band, 2buff), + }, + /* Corresponds to MPI_LOR */ + [OMPI_OP_BASE_FORTRAN_LOR] = { + C_INTEGER(lor, 2buff), + LOGICAL(lor, 2buff), + }, + /* Corresponds to MPI_BOR */ + [OMPI_OP_BASE_FORTRAN_BOR] = { + C_INTEGER(bor, 2buff), + FORTRAN_INTEGER(bor, 2buff), + BYTE(bor, 2buff), + }, + /* Corresponds to MPI_LXOR */ + [OMPI_OP_BASE_FORTRAN_LXOR] = { + C_INTEGER(lxor, 2buff), + LOGICAL(lxor, 2buff), + }, + /* Corresponds to MPI_BXOR */ + [OMPI_OP_BASE_FORTRAN_BXOR] = { + C_INTEGER(bxor, 2buff), + FORTRAN_INTEGER(bxor, 2buff), + BYTE(bxor, 2buff), + }, + /* Corresponds to MPI_MAXLOC */ + [OMPI_OP_BASE_FORTRAN_MAXLOC] = { + TWOLOC(maxloc, 2buff), + }, + /* Corresponds to MPI_MINLOC */ + [OMPI_OP_BASE_FORTRAN_MINLOC] = { + TWOLOC(minloc, 2buff), + }, + /* Corresponds to MPI_REPLACE */ + [OMPI_OP_BASE_FORTRAN_REPLACE] = { + /* (MPI_ACCUMULATE is handled differently than the other + reductions, so just zero out its function + implementations here to ensure that users don't invoke + MPI_REPLACE with any reduction operations 
other than + ACCUMULATE) */ + NULL, + }, + + }; + +ompi_op_base_3buff_stream_handler_fn_t ompi_op_cuda_3buff_functions[OMPI_OP_BASE_FORTRAN_OP_MAX][OMPI_OP_BASE_TYPE_MAX] = + { + /* Corresponds to MPI_OP_NULL */ + [OMPI_OP_BASE_FORTRAN_NULL] = { + /* Leaving this empty puts in NULL for all entries */ + NULL, + }, + /* Corresponds to MPI_MAX */ + [OMPI_OP_BASE_FORTRAN_MAX] = { + C_INTEGER(max, 3buff), + FORTRAN_INTEGER(max, 3buff), + FLOATING_POINT(max, 3buff), + }, + /* Corresponds to MPI_MIN */ + [OMPI_OP_BASE_FORTRAN_MIN] = { + C_INTEGER(min, 3buff), + FORTRAN_INTEGER(min, 3buff), + FLOATING_POINT(min, 3buff), + }, + /* Corresponds to MPI_SUM */ + [OMPI_OP_BASE_FORTRAN_SUM] = { + C_INTEGER(sum, 3buff), + FORTRAN_INTEGER(sum, 3buff), + FLOATING_POINT(sum, 3buff), + COMPLEX(sum, 3buff), + }, + /* Corresponds to MPI_PROD */ + [OMPI_OP_BASE_FORTRAN_PROD] = { + C_INTEGER(prod, 3buff), + FORTRAN_INTEGER(prod, 3buff), + FLOATING_POINT(prod, 3buff), + COMPLEX(prod, 3buff), + }, + /* Corresponds to MPI_LAND */ + [OMPI_OP_BASE_FORTRAN_LAND] ={ + C_INTEGER(land, 3buff), + LOGICAL(land, 3buff), + }, + /* Corresponds to MPI_BAND */ + [OMPI_OP_BASE_FORTRAN_BAND] = { + C_INTEGER(band, 3buff), + FORTRAN_INTEGER(band, 3buff), + BYTE(band, 3buff), + }, + /* Corresponds to MPI_LOR */ + [OMPI_OP_BASE_FORTRAN_LOR] = { + C_INTEGER(lor, 3buff), + LOGICAL(lor, 3buff), + }, + /* Corresponds to MPI_BOR */ + [OMPI_OP_BASE_FORTRAN_BOR] = { + C_INTEGER(bor, 3buff), + FORTRAN_INTEGER(bor, 3buff), + BYTE(bor, 3buff), + }, + /* Corresponds to MPI_LXOR */ + [OMPI_OP_BASE_FORTRAN_LXOR] = { + C_INTEGER(lxor, 3buff), + LOGICAL(lxor, 3buff), + }, + /* Corresponds to MPI_BXOR */ + [OMPI_OP_BASE_FORTRAN_BXOR] = { + C_INTEGER(bxor, 3buff), + FORTRAN_INTEGER(bxor, 3buff), + BYTE(bxor, 3buff), + }, + /* Corresponds to MPI_MAXLOC */ + [OMPI_OP_BASE_FORTRAN_MAXLOC] = { + TWOLOC(maxloc, 3buff), + }, + /* Corresponds to MPI_MINLOC */ + [OMPI_OP_BASE_FORTRAN_MINLOC] = { + TWOLOC(minloc, 3buff), + }, + /* 
Corresponds to MPI_REPLACE */ + [OMPI_OP_BASE_FORTRAN_REPLACE] = { + /* MPI_ACCUMULATE is handled differently than the other + reductions, so just zero out its function + implementations here to ensure that users don't invoke + MPI_REPLACE with any reduction operations other than + ACCUMULATE */ + NULL, + }, + }; diff --git a/ompi/mca/op/cuda/op_cuda_impl.cu b/ompi/mca/op/cuda/op_cuda_impl.cu new file mode 100644 index 00000000000..3daf7f56fbb --- /dev/null +++ b/ompi/mca/op/cuda/op_cuda_impl.cu @@ -0,0 +1,1080 @@ +/* + * Copyright (c) 2019-2023 The University of Tennessee and The University + * of Tennessee Research Foundation. All rights + * reserved. + * Copyright (c) 2020 Research Organization for Information Science + * and Technology (RIST). All rights reserved. + * $COPYRIGHT$ + * + * Additional copyrights may follow + * + * $HEADER$ + */ + +#include "op_cuda_impl.h" + +#include + +#include + +#define ISSIGNED(x) std::is_signed_v + +template +static inline __device__ constexpr T tmax(T a, T b) { + return (a > b) ? a : b; +} + +template +static inline __device__ constexpr T tmin(T a, T b) { + return (a < b) ? 
a : b; +} + +template +static inline __device__ constexpr T tsum(T a, T b) { + return a+b; +} + +template +static inline __device__ constexpr T tprod(T a, T b) { + return a*b; +} + +template +static inline __device__ T vmax(const T& a, const T& b) { + return T{tmax(a.x, b.x), tmax(a.y, b.y), tmax(a.z, b.z), tmax(a.w, b.w)}; +} + +template +static inline __device__ T vmin(const T& a, const T& b) { + return T{tmin(a.x, b.x), tmin(a.y, b.y), tmin(a.z, b.z), tmin(a.w, b.w)}; +} + +template +static inline __device__ T vsum(const T& a, const T& b) { + return T{tsum(a.x, b.x), tsum(a.y, b.y), tsum(a.z, b.z), tsum(a.w, b.w)}; +} + +template +static inline __device__ T vprod(const T& a, const T& b) { + return T{(a.x * b.x), (a.y * b.y), (a.z * b.z), (a.w * b.w)}; +} + + +/* TODO: missing support for + * - short float (conditional on whether short float is available) + * - some Fortran types + * - some complex types + */ + +#define USE_VECTORS 1 + +#define OP_FUNC(name, type_name, type, op) \ + static __global__ void \ + ompi_op_cuda_2buff_##name##_##type_name##_kernel(const type *__restrict__ in, \ + type *__restrict__ inout, int n) { \ + const int index = blockIdx.x * blockDim.x + threadIdx.x; \ + const int stride = blockDim.x * gridDim.x; \ + for (int i = index; i < n; i += stride) { \ + /*if (index < n) { int i = index;*/ \ + inout[i] = inout[i] op in[i]; \ + } \ + } \ + void ompi_op_cuda_2buff_##name##_##type_name##_submit(const type *in, \ + type *inout, \ + int count, \ + int threads_per_block, \ + int max_blocks, \ + CUstream stream) { \ + int threads = min(count, threads_per_block); \ + int blocks = min((count + threads-1) / threads, max_blocks); \ + int n = count; \ + CUstream s = stream; \ + ompi_op_cuda_2buff_##name##_##type_name##_kernel<<>>(in, inout, n); \ + } + + +#if defined(USE_VECTORS) +#define OPV_FUNC(name, type_name, type, vtype, vlen, op) \ + static __global__ void \ + ompi_op_cuda_2buff_##name##_##type_name##_kernel(const type *__restrict__ in, \ + 
type *__restrict__ inout, int n) { \ + const int index = blockIdx.x * blockDim.x + threadIdx.x; \ + const int stride = blockDim.x * gridDim.x; \ + for (int i = index; i < n/vlen; i += stride) { \ + vtype vin = ((vtype*)in)[i]; \ + vtype vinout = ((vtype*)inout)[i]; \ + vinout.x = vinout.x op vin.x; \ + vinout.y = vinout.y op vin.y; \ + vinout.z = vinout.z op vin.z; \ + vinout.w = vinout.w op vin.w; \ + ((vtype*)inout)[i] = vinout; \ + } \ + int remainder = n%vlen; \ + if (index == (n/vlen) && remainder != 0) { \ + while(remainder) { \ + int idx = n - remainder--; \ + inout[idx] = inout[idx] op in[idx]; \ + } \ + } \ + } \ + void ompi_op_cuda_2buff_##name##_##type_name##_submit(const type *in, \ + type *inout, \ + int count, \ + int threads_per_block, \ + int max_blocks, \ + CUstream stream) { \ + int vcount = (count + vlen-1)/vlen; \ + int threads = min(threads_per_block, vcount); \ + int blocks = min((vcount + threads-1) / threads, max_blocks); \ + int n = count; \ + CUstream s = stream; \ + ompi_op_cuda_2buff_##name##_##type_name##_kernel<<>>(in, inout, n); \ + } +#else // USE_VECTORS +#define OPV_FUNC(name, type_name, type, vtype, vlen, op) OP_FUNC(name, type_name, type, op) +#endif // USE_VECTORS + +#define FUNC_FUNC_FN(name, type_name, type, fn) \ + static __global__ void \ + ompi_op_cuda_2buff_##name##_##type_name##_kernel(const type *__restrict__ in, \ + type *__restrict__ inout, int n) { \ + const int index = blockIdx.x * blockDim.x + threadIdx.x; \ + const int stride = blockDim.x * gridDim.x; \ + for (int i = index; i < n; i += stride) { \ + inout[i] = fn(inout[i], in[i]); \ + } \ + } \ + void \ + ompi_op_cuda_2buff_##name##_##type_name##_submit(const type *in, \ + type *inout, \ + int count, \ + int threads_per_block, \ + int max_blocks, \ + CUstream stream) { \ + int threads = min(count, threads_per_block); \ + int blocks = min((count + threads-1) / threads, max_blocks); \ + int n = count; \ + CUstream s = stream; \ + 
ompi_op_cuda_2buff_##name##_##type_name##_kernel<<<blocks, threads, 0, s>>>(in, inout, n); \ + } + +#define FUNC_FUNC(name, type_name, type) FUNC_FUNC_FN(name, type_name, type, current_func) + +#if defined(USE_VECTORS) +#define VFUNC_FUNC(name, type_name, type, vtype, vlen, vfn, fn) \ + static __global__ void \ + ompi_op_cuda_2buff_##name##_##type_name##_kernel(const type *__restrict__ in, \ + type *__restrict__ inout, int n) { \ + const int index = blockIdx.x * blockDim.x + threadIdx.x; \ + const int stride = blockDim.x * gridDim.x; \ + for (int i = index; i < n/vlen; i += stride) { \ + ((vtype*)inout)[i] = vfn(((vtype*)inout)[i], ((vtype*)in)[i]); \ + } \ + int remainder = n%vlen; \ + if (index == (n/vlen) && remainder != 0) { \ + while(remainder) { \ + int idx = n - remainder--; \ + inout[idx] = fn(inout[idx], in[idx]); \ + } \ + } \ + } \ + static void \ + ompi_op_cuda_2buff_##name##_##type_name##_submit(const type *in, \ + type *inout, \ + int count, \ + int threads_per_block, \ + int max_blocks, \ + CUstream stream) { \ + int vcount = (count + vlen-1)/vlen; \ + int threads = min(threads_per_block, vcount); \ + int blocks = min((vcount + threads-1) / threads, max_blocks); \ + int n = count; \ + CUstream s = stream; \ + ompi_op_cuda_2buff_##name##_##type_name##_kernel<<<blocks, threads, 0, s>>>(in, inout, n); \ + }
+ * + * This macro is for minloc and maxloc + */ + +#define LOC_FUNC(name, type_name, op) \ + static __global__ void \ + ompi_op_cuda_2buff_##name##_##type_name##_kernel(const ompi_op_predefined_##type_name##_t *__restrict__ in, \ + ompi_op_predefined_##type_name##_t *__restrict__ inout, \ + int n) \ + { \ + const int index = blockIdx.x * blockDim.x + threadIdx.x; \ + const int stride = blockDim.x * gridDim.x; \ + for (int i = index; i < n; i += stride) { \ + const ompi_op_predefined_##type_name##_t *a = &in[i]; \ + ompi_op_predefined_##type_name##_t *b = &inout[i]; \ + if (a->v op b->v) { \ + b->v = a->v; \ + b->k = a->k; \ + } else if (a->v == b->v) { \ + b->k = (b->k < a->k ? b->k : a->k); \ + } \ + } \ + } \ + void \ + ompi_op_cuda_2buff_##name##_##type_name##_submit(const ompi_op_predefined_##type_name##_t *a, \ + ompi_op_predefined_##type_name##_t *b, \ + int count, \ + int threads_per_block, \ + int max_blocks, \ + CUstream stream) { \ + int threads = min(count, threads_per_block); \ + int blocks = min((count + threads-1) / threads, max_blocks); \ + CUstream s = stream; \ + ompi_op_cuda_2buff_##name##_##type_name##_kernel<<>>(a, b, count); \ + } + +#define OPV_DISPATCH(name, type_name, type) \ + void ompi_op_cuda_2buff_##name##_##type_name##_submit(const type *in, \ + type *inout, \ + int count, \ + int threads_per_block, \ + int max_blocks, \ + CUstream stream) { \ + static_assert(sizeof(type_name) <= sizeof(unsigned long long), "Unknown size type"); \ + if constexpr(!ISSIGNED(type)) { \ + if constexpr(sizeof(type_name) == sizeof(unsigned char)) { \ + ompi_op_cuda_2buff_##name##_uchar_submit((const unsigned char*)in, (unsigned char*)inout, count, \ + threads_per_block, \ + max_blocks, stream); \ + } else if constexpr(sizeof(type_name) == sizeof(unsigned short)) { \ + ompi_op_cuda_2buff_##name##_ushort_submit((const unsigned short*)in, (unsigned short*)inout, count, \ + threads_per_block, \ + max_blocks, stream); \ + } else if constexpr(sizeof(type_name) == 
sizeof(unsigned int)) { \ + ompi_op_cuda_2buff_##name##_uint_submit((const unsigned int*)in, (unsigned int*)inout, count, \ + threads_per_block, \ + max_blocks, stream); \ + } else if constexpr(sizeof(type_name) == sizeof(unsigned long)) { \ + ompi_op_cuda_2buff_##name##_ulong_submit((const unsigned long*)in, (unsigned long*)inout, count, \ + threads_per_block, \ + max_blocks, stream); \ + } else if constexpr(sizeof(type_name) == sizeof(unsigned long long)) { \ + ompi_op_cuda_2buff_##name##_ulonglong_submit((const unsigned long long*)in, (unsigned long long*)inout, count, \ + threads_per_block, \ + max_blocks, stream); \ + } \ + } else { \ + if constexpr(sizeof(type_name) == sizeof(char)) { \ + ompi_op_cuda_2buff_##name##_char_submit((const char*)in, (char*)inout, count, \ + threads_per_block, \ + max_blocks, stream); \ + } else if constexpr(sizeof(type_name) == sizeof(short)) { \ + ompi_op_cuda_2buff_##name##_short_submit((const short*)in, (short*)inout, count, \ + threads_per_block, \ + max_blocks, stream); \ + } else if constexpr(sizeof(type_name) == sizeof(int)) { \ + ompi_op_cuda_2buff_##name##_int_submit((const int*)in, (int*)inout, count, \ + threads_per_block, \ + max_blocks, stream); \ + } else if constexpr(sizeof(type_name) == sizeof(long)) { \ + ompi_op_cuda_2buff_##name##_long_submit((const long*)in, (long*)inout, count, \ + threads_per_block, \ + max_blocks, stream); \ + } else if constexpr(sizeof(type_name) == sizeof(long long)) { \ + ompi_op_cuda_2buff_##name##_longlong_submit((const long long*)in, (long long*)inout, count,\ + threads_per_block, \ + max_blocks, stream); \ + } \ + } \ + } + +/************************************************************************* + * Max + *************************************************************************/ + +/* C integer */ +VFUNC_FUNC(max, char, char, char4, 4, vmax, max) +VFUNC_FUNC(max, uchar, unsigned char, uchar4, 4, vmax, max) +VFUNC_FUNC(max, short, short, short4, 4, vmax, max) +VFUNC_FUNC(max, 
ushort, unsigned short, ushort4, 4, vmax, max) +VFUNC_FUNC(max, int, int, int4, 4, vmax, max) +VFUNC_FUNC(max, uint, unsigned int, uint4, 4, vmax, max) + +#undef current_func +#define current_func(a, b) max(a, b) +FUNC_FUNC(max, long, long) +FUNC_FUNC(max, ulong, unsigned long) +FUNC_FUNC(max, longlong, long long) +FUNC_FUNC(max, ulonglong, unsigned long long) + +/* dispatch fixed-size types */ +OPV_DISPATCH(max, int8_t, int8_t) +OPV_DISPATCH(max, uint8_t, uint8_t) +OPV_DISPATCH(max, int16_t, int16_t) +OPV_DISPATCH(max, uint16_t, uint16_t) +OPV_DISPATCH(max, int32_t, int32_t) +OPV_DISPATCH(max, uint32_t, uint32_t) +OPV_DISPATCH(max, int64_t, int64_t) +OPV_DISPATCH(max, uint64_t, uint64_t) + +#undef current_func +#define current_func(a, b) ((a) > (b) ? (a) : (b)) +FUNC_FUNC(max, long_double, long double) + +#if !defined(DO_NOT_USE_INTRINSICS) +#undef current_func +#define current_func(a, b) fmaxf(a, b) +#endif // DO_NOT_USE_INTRINSICS +FUNC_FUNC(max, float, float) + +#if !defined(DO_NOT_USE_INTRINSICS) +#undef current_func +#define current_func(a, b) fmax(a, b) +#endif // DO_NOT_USE_INTRINSICS +FUNC_FUNC(max, double, double) + +// __CUDA_ARCH__ is only defined when compiling device code +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530 +#undef current_func +#define current_func(a, b) __hmax2(a, b) +//VFUNC_FUNC(max, halfx, half, half2, 2, __hmax2, __hmax) +#endif // __CUDA_ARCH__ + +/************************************************************************* + * Min + *************************************************************************/ + +/* C integer */ +VFUNC_FUNC(min, char, char, char4, 4, vmin, min) +VFUNC_FUNC(min, uchar, unsigned char, uchar4, 4, vmin, min) +VFUNC_FUNC(min, short, short, short4, 4, vmin, min) +VFUNC_FUNC(min, ushort, unsigned short, ushort4, 4, vmin, min) +VFUNC_FUNC(min, int, int, int4, 4, vmin, min) +VFUNC_FUNC(min, uint, unsigned int, uint4, 4, vmin, min) + +#undef current_func +#define current_func(a, b) min(a, b) +FUNC_FUNC(min, 
long, long) +FUNC_FUNC(min, ulong, unsigned long) +FUNC_FUNC(min, longlong, long long) +FUNC_FUNC(min, ulonglong, unsigned long long) +OPV_DISPATCH(min, int8_t, int8_t) +OPV_DISPATCH(min, uint8_t, uint8_t) +OPV_DISPATCH(min, int16_t, int16_t) +OPV_DISPATCH(min, uint16_t, uint16_t) +OPV_DISPATCH(min, int32_t, int32_t) +OPV_DISPATCH(min, uint32_t, uint32_t) +OPV_DISPATCH(min, int64_t, int64_t) +OPV_DISPATCH(min, uint64_t, uint64_t) + + + +#if !defined(DO_NOT_USE_INTRINSICS) +#undef current_func +#define current_func(a, b) fminf(a, b) +#endif // DO_NOT_USE_INTRINSICS +FUNC_FUNC(min, float, float) + +#if !defined(DO_NOT_USE_INTRINSICS) +#undef current_func +#define current_func(a, b) fmin(a, b) +#endif // DO_NOT_USE_INTRINSICS +FUNC_FUNC(min, double, double) + +#undef current_func +#define current_func(a, b) ((a) < (b) ? (a) : (b)) +FUNC_FUNC(min, long_double, long double) + +// __CUDA_ARCH__ is only defined when compiling device code +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530 +#undef current_func +#define current_func(a, b) __hmin2(a, b) +//VFUNC_FUNC(min, half, half, half2, 2, __hmin2, __hmin) +#endif // __CUDA_ARCH__ + +/************************************************************************* + * Sum + *************************************************************************/ + +/* C integer */ +VFUNC_FUNC(sum, char, char, char4, 4, vsum, tsum) +VFUNC_FUNC(sum, uchar, unsigned char, uchar4, 4, vsum, tsum) +VFUNC_FUNC(sum, short, short, short4, 4, vsum, tsum) +VFUNC_FUNC(sum, ushort, unsigned short, ushort4, 4, vsum, tsum) +VFUNC_FUNC(sum, int, int, int4, 4, vsum, tsum) +VFUNC_FUNC(sum, uint, unsigned int, uint4, 4, vsum, tsum) + +#undef current_func +#define current_func(a, b) tsum(a, b) +FUNC_FUNC(sum, long, long) +FUNC_FUNC(sum, ulong, unsigned long) +FUNC_FUNC(sum, longlong, long long) +FUNC_FUNC(sum, ulonglong, unsigned long long) + +OPV_DISPATCH(sum, int8_t, int8_t) +OPV_DISPATCH(sum, uint8_t, uint8_t) +OPV_DISPATCH(sum, int16_t, int16_t) 
+OPV_DISPATCH(sum, uint16_t, uint16_t) +OPV_DISPATCH(sum, int32_t, int32_t) +OPV_DISPATCH(sum, uint32_t, uint32_t) +OPV_DISPATCH(sum, int64_t, int64_t) +OPV_DISPATCH(sum, uint64_t, uint64_t) + +OPV_FUNC(sum, float, float, float4, 4, +) +OPV_FUNC(sum, double, double, double4, 4, +) +OP_FUNC(sum, long_double, long double, +) + +// __CUDA_ARCH__ is only defined when compiling device code +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530 +#undef current_func +#define current_func(a, b) __hadd2(a, b) +//VFUNC_FUNC(sum, half, half, half2, 2, __hadd2, __hadd) +#endif // __CUDA_ARCH__ + +/* Complex */ +#if 0 +#if defined(HAVE_SHORT_FLOAT__COMPLEX) +OP_FUNC(sum, c_short_float_complex, short float _Complex, +=) +#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) +COMPLEX_SUM_FUNC(c_short_float_complex, opal_short_float_t) +OP_FUNC(sum, c_long_double_complex, cuLongDoubleComplex, +=) +#endif +#endif // 0 +#undef current_func +#define current_func(a, b) (cuCaddf(a,b)) +FUNC_FUNC(sum, c_float_complex, cuFloatComplex) +#undef current_func +#define current_func(a, b) (cuCadd(a,b)) +FUNC_FUNC(sum, c_double_complex, cuDoubleComplex) + +/************************************************************************* + * Product + *************************************************************************/ + +/* C integer */ +#undef current_func +#define current_func(a, b) tprod(a, b) +FUNC_FUNC(prod, char, char) +FUNC_FUNC(prod, uchar, unsigned char) +FUNC_FUNC(prod, short, short) +FUNC_FUNC(prod, ushort, unsigned short) +FUNC_FUNC(prod, int, int) +FUNC_FUNC(prod, uint, unsigned int) +FUNC_FUNC(prod, long, long) +FUNC_FUNC(prod, ulong, unsigned long) +FUNC_FUNC(prod, longlong, long long) +FUNC_FUNC(prod, ulonglong, unsigned long long) + +OPV_DISPATCH(prod, int8_t, int8_t) +OPV_DISPATCH(prod, uint8_t, uint8_t) +OPV_DISPATCH(prod, int16_t, int16_t) +OPV_DISPATCH(prod, uint16_t, uint16_t) +OPV_DISPATCH(prod, int32_t, int32_t) +OPV_DISPATCH(prod, uint32_t, uint32_t) +OPV_DISPATCH(prod, int64_t, 
int64_t) +OPV_DISPATCH(prod, uint64_t, uint64_t) + + +OPV_FUNC(prod, float, float, float4, 4, *) +OPV_FUNC(prod, double, double, double4, 4, *) +OP_FUNC(prod, long_double, long double, *) + +// __CUDA_ARCH__ is only defined when compiling device code +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530 +#undef current_func +#define current_func(a, b) __hmul2(a, b) +//VFUNC_FUNC(prod, half, half, half2, 2, __hmul2, __hmul) +#endif // __CUDA_ARCH__ + +/* Complex */ +#if 0 +#if defined(HAVE_SHORT_FLOAT__COMPLEX) +OP_FUNC(prod, c_short_float_complex, short float _Complex, *=) +#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) +COMPLEX_PROD_FUNC(c_short_float_complex, opal_short_float_t) +#endif +OP_FUNC(prod, c_long_double_complex, long double _Complex, *=) +#endif // 0 +#undef current_func +#define current_func(a, b) (cuCmulf(a,b)) +FUNC_FUNC(prod, c_float_complex, cuFloatComplex) +#undef current_func +#define current_func(a, b) (cuCmul(a,b)) +FUNC_FUNC(prod, c_double_complex, cuDoubleComplex) + +/************************************************************************* + * Logical AND + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) && (b)) +/* C integer */ +FUNC_FUNC(land, int8_t, int8_t) +FUNC_FUNC(land, uint8_t, uint8_t) +FUNC_FUNC(land, int16_t, int16_t) +FUNC_FUNC(land, uint16_t, uint16_t) +FUNC_FUNC(land, int32_t, int32_t) +FUNC_FUNC(land, uint32_t, uint32_t) +FUNC_FUNC(land, int64_t, int64_t) +FUNC_FUNC(land, uint64_t, uint64_t) +FUNC_FUNC(land, long, long) +FUNC_FUNC(land, ulong, unsigned long) + +/* C++ bool */ +FUNC_FUNC(land, bool, bool) + +/************************************************************************* + * Logical OR + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) || (b)) +/* C integer */ +FUNC_FUNC(lor, int8_t, int8_t) +FUNC_FUNC(lor, uint8_t, uint8_t) +FUNC_FUNC(lor, int16_t, 
int16_t) +FUNC_FUNC(lor, uint16_t, uint16_t) +FUNC_FUNC(lor, int32_t, int32_t) +FUNC_FUNC(lor, uint32_t, uint32_t) +FUNC_FUNC(lor, int64_t, int64_t) +FUNC_FUNC(lor, uint64_t, uint64_t) +FUNC_FUNC(lor, long, long) +FUNC_FUNC(lor, ulong, unsigned long) + +/* C++ bool */ +FUNC_FUNC(lor, bool, bool) + +/************************************************************************* + * Logical XOR + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a ? 1 : 0) ^ (b ? 1: 0)) +/* C integer */ +FUNC_FUNC(lxor, int8_t, int8_t) +FUNC_FUNC(lxor, uint8_t, uint8_t) +FUNC_FUNC(lxor, int16_t, int16_t) +FUNC_FUNC(lxor, uint16_t, uint16_t) +FUNC_FUNC(lxor, int32_t, int32_t) +FUNC_FUNC(lxor, uint32_t, uint32_t) +FUNC_FUNC(lxor, int64_t, int64_t) +FUNC_FUNC(lxor, uint64_t, uint64_t) +FUNC_FUNC(lxor, long, long) +FUNC_FUNC(lxor, ulong, unsigned long) + +/* C++ bool */ +FUNC_FUNC(lxor, bool, bool) + +/************************************************************************* + * Bitwise AND + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) & (b)) +/* C integer */ +FUNC_FUNC(band, int8_t, int8_t) +FUNC_FUNC(band, uint8_t, uint8_t) +FUNC_FUNC(band, int16_t, int16_t) +FUNC_FUNC(band, uint16_t, uint16_t) +FUNC_FUNC(band, int32_t, int32_t) +FUNC_FUNC(band, uint32_t, uint32_t) +FUNC_FUNC(band, int64_t, int64_t) +FUNC_FUNC(band, uint64_t, uint64_t) +FUNC_FUNC(band, long, long) +FUNC_FUNC(band, ulong, unsigned long) + +/* Byte */ +FUNC_FUNC(band, byte, char) + +/************************************************************************* + * Bitwise OR + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) | (b)) +/* C integer */ +FUNC_FUNC(bor, int8_t, int8_t) +FUNC_FUNC(bor, uint8_t, uint8_t) +FUNC_FUNC(bor, int16_t, int16_t) +FUNC_FUNC(bor, uint16_t, uint16_t) 
+FUNC_FUNC(bor, int32_t, int32_t) +FUNC_FUNC(bor, uint32_t, uint32_t) +FUNC_FUNC(bor, int64_t, int64_t) +FUNC_FUNC(bor, uint64_t, uint64_t) +FUNC_FUNC(bor, long, long) +FUNC_FUNC(bor, ulong, unsigned long) + +/* Byte */ +FUNC_FUNC(bor, byte, char) + +/************************************************************************* + * Bitwise XOR + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) ^ (b)) +/* C integer */ +FUNC_FUNC(bxor, int8_t, int8_t) +FUNC_FUNC(bxor, uint8_t, uint8_t) +FUNC_FUNC(bxor, int16_t, int16_t) +FUNC_FUNC(bxor, uint16_t, uint16_t) +FUNC_FUNC(bxor, int32_t, int32_t) +FUNC_FUNC(bxor, uint32_t, uint32_t) +FUNC_FUNC(bxor, int64_t, int64_t) +FUNC_FUNC(bxor, uint64_t, uint64_t) +FUNC_FUNC(bxor, long, long) +FUNC_FUNC(bxor, ulong, unsigned long) + +/* Byte */ +FUNC_FUNC(bxor, byte, char) + +/************************************************************************* + * Max location + *************************************************************************/ + +LOC_FUNC(maxloc, float_int, >) +LOC_FUNC(maxloc, double_int, >) +LOC_FUNC(maxloc, long_int, >) +LOC_FUNC(maxloc, 2int, >) +LOC_FUNC(maxloc, short_int, >) +LOC_FUNC(maxloc, long_double_int, >) + +/* Fortran compat types */ +LOC_FUNC(maxloc, 2float, >) +LOC_FUNC(maxloc, 2double, >) +LOC_FUNC(maxloc, 2int8, >) +LOC_FUNC(maxloc, 2int16, >) +LOC_FUNC(maxloc, 2int32, >) +LOC_FUNC(maxloc, 2int64, >) + +/************************************************************************* + * Min location + *************************************************************************/ + +LOC_FUNC(minloc, float_int, <) +LOC_FUNC(minloc, double_int, <) +LOC_FUNC(minloc, long_int, <) +LOC_FUNC(minloc, 2int, <) +LOC_FUNC(minloc, short_int, <) +LOC_FUNC(minloc, long_double_int, <) + +/* Fortran compat types */ +LOC_FUNC(minloc, 2float, <) +LOC_FUNC(minloc, 2double, <) +LOC_FUNC(minloc, 2int8, <) +LOC_FUNC(minloc, 2int16, <) +LOC_FUNC(minloc, 
2int32, <) +LOC_FUNC(minloc, 2int64, <) + +/* + * This is a three buffer (2 input and 1 output) version of the reduction + * routines, needed for some optimizations. + */ +#define OP_FUNC_3BUF(name, type_name, type, op) \ + static __global__ void \ + ompi_op_cuda_3buff_##name##_##type_name##_kernel(const type *__restrict__ in1, \ + const type *__restrict__ in2, \ + type *__restrict__ out, int n) { \ + const int index = blockIdx.x * blockDim.x + threadIdx.x; \ + const int stride = blockDim.x * gridDim.x; \ + for (int i = index; i < n; i += stride) { \ + out[i] = in1[i] op in2[i]; \ + } \ + } \ + void ompi_op_cuda_3buff_##name##_##type_name##_submit(const type *in1, const type *in2, \ + type *out, int count, \ + int threads_per_block, \ + int max_blocks, \ + CUstream stream) { \ + int threads = min(count, threads_per_block); \ + int blocks = min((count + threads-1) / threads, max_blocks); \ + ompi_op_cuda_3buff_##name##_##type_name##_kernel<<<blocks, threads, 0, stream>>>(in1, in2, out, count); \ + } + + +/* + * Since all the functions in this file are essentially identical, we + * use a macro to substitute in names and types. The core operation + * in all functions that use this macro is the same.
+ * + * This macro is for (out = op(in1, in2)) + */ +#define FUNC_FUNC_3BUF(name, type_name, type) \ + static __global__ void \ + ompi_op_cuda_3buff_##name##_##type_name##_kernel(const type *__restrict__ in1, \ + const type *__restrict__ in2, \ + type *__restrict__ out, int n) { \ + const int index = blockIdx.x * blockDim.x + threadIdx.x; \ + const int stride = blockDim.x * gridDim.x; \ + for (int i = index; i < n; i += stride) { \ + out[i] = current_func(in1[i], in2[i]); \ + } \ + } \ + void \ + ompi_op_cuda_3buff_##name##_##type_name##_submit(const type *in1, const type *in2, \ + type *out, int count, \ + int threads_per_block, \ + int max_blocks, \ + CUstream stream) { \ + int threads = min(count, threads_per_block); \ + int blocks = min((count + threads-1) / threads, max_blocks); \ + ompi_op_cuda_3buff_##name##_##type_name##_kernel<<<blocks, threads, 0, stream>>>(in1, in2, out, count); \ + } + +/* + * Since all the functions in this file are essentially identical, we + * use a macro to substitute in names and types. The core operation + * in all functions that use this macro is the same. + * + * This macro is for minloc and maxloc + */ +#define LOC_FUNC_3BUF(name, type_name, op) \ + static __global__ void \ + ompi_op_cuda_3buff_##name##_##type_name##_kernel(const ompi_op_predefined_##type_name##_t *__restrict__ in1, \ + const ompi_op_predefined_##type_name##_t *__restrict__ in2, \ + ompi_op_predefined_##type_name##_t *__restrict__ out, \ + int n) \ + { \ + const int index = blockIdx.x * blockDim.x + threadIdx.x; \ + const int stride = blockDim.x * gridDim.x; \ + for (int i = index; i < n; i += stride) { \ + const ompi_op_predefined_##type_name##_t *a1 = &in1[i]; \ + const ompi_op_predefined_##type_name##_t *a2 = &in2[i]; \ + ompi_op_predefined_##type_name##_t *b = &out[i]; \ + if (a1->v op a2->v) { \ + b->v = a1->v; \ + b->k = a1->k; \ + } else if (a1->v == a2->v) { \ + b->v = a1->v; \ + b->k = (a2->k < a1->k ?
a2->k : a1->k); \ + } else { \ + b->v = a2->v; \ + b->k = a2->k; \ + } \ + } \ + } \ + void \ + ompi_op_cuda_3buff_##name##_##type_name##_submit(const ompi_op_predefined_##type_name##_t *in1, \ + const ompi_op_predefined_##type_name##_t *in2, \ + ompi_op_predefined_##type_name##_t *out, \ + int count, \ + int threads_per_block, \ + int max_blocks, \ + CUstream stream) \ + { \ + int threads = min(count, threads_per_block); \ + int blocks = min((count + threads-1) / threads, max_blocks); \ + ompi_op_cuda_3buff_##name##_##type_name##_kernel<<<blocks, threads, 0, stream>>>(in1, in2, out, count); \ + } + + +/************************************************************************* + * Max + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) > (b) ? (a) : (b)) +/* C integer */ +FUNC_FUNC_3BUF(max, int8_t, int8_t) +FUNC_FUNC_3BUF(max, uint8_t, uint8_t) +FUNC_FUNC_3BUF(max, int16_t, int16_t) +FUNC_FUNC_3BUF(max, uint16_t, uint16_t) +FUNC_FUNC_3BUF(max, int32_t, int32_t) +FUNC_FUNC_3BUF(max, uint32_t, uint32_t) +FUNC_FUNC_3BUF(max, int64_t, int64_t) +FUNC_FUNC_3BUF(max, uint64_t, uint64_t) +FUNC_FUNC_3BUF(max, long, long) +FUNC_FUNC_3BUF(max, ulong, unsigned long) + +/* Floating point */ +#if defined(HAVE_SHORT_FLOAT) +FUNC_FUNC_3BUF(max, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +FUNC_FUNC_3BUF(max, short_float, opal_short_float_t) +#endif +FUNC_FUNC_3BUF(max, float, float) +FUNC_FUNC_3BUF(max, double, double) +FUNC_FUNC_3BUF(max, long_double, long double) + + +/************************************************************************* + * Min + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) < (b) ?
(a) : (b)) +/* C integer */ +FUNC_FUNC_3BUF(min, int8_t, int8_t) +FUNC_FUNC_3BUF(min, uint8_t, uint8_t) +FUNC_FUNC_3BUF(min, int16_t, int16_t) +FUNC_FUNC_3BUF(min, uint16_t, uint16_t) +FUNC_FUNC_3BUF(min, int32_t, int32_t) +FUNC_FUNC_3BUF(min, uint32_t, uint32_t) +FUNC_FUNC_3BUF(min, int64_t, int64_t) +FUNC_FUNC_3BUF(min, uint64_t, uint64_t) +FUNC_FUNC_3BUF(min, long, long) +FUNC_FUNC_3BUF(min, ulong, unsigned long) + +/* Floating point */ +#if defined(HAVE_SHORT_FLOAT) +FUNC_FUNC_3BUF(min, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +FUNC_FUNC_3BUF(min, short_float, opal_short_float_t) +#endif +FUNC_FUNC_3BUF(min, float, float) +FUNC_FUNC_3BUF(min, double, double) +FUNC_FUNC_3BUF(min, long_double, long double) + +/************************************************************************* + * Sum + *************************************************************************/ + +/* C integer */ +OP_FUNC_3BUF(sum, int8_t, int8_t, +) +OP_FUNC_3BUF(sum, uint8_t, uint8_t, +) +OP_FUNC_3BUF(sum, int16_t, int16_t, +) +OP_FUNC_3BUF(sum, uint16_t, uint16_t, +) +OP_FUNC_3BUF(sum, int32_t, int32_t, +) +OP_FUNC_3BUF(sum, uint32_t, uint32_t, +) +OP_FUNC_3BUF(sum, int64_t, int64_t, +) +OP_FUNC_3BUF(sum, uint64_t, uint64_t, +) +OP_FUNC_3BUF(sum, long, long, +) +OP_FUNC_3BUF(sum, ulong, unsigned long, +) + +/* Floating point */ +#if defined(HAVE_SHORT_FLOAT) +OP_FUNC_3BUF(sum, short_float, short float, +) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +OP_FUNC_3BUF(sum, short_float, opal_short_float_t, +) +#endif +OP_FUNC_3BUF(sum, float, float, +) +OP_FUNC_3BUF(sum, double, double, +) +OP_FUNC_3BUF(sum, long_double, long double, +) + +/* Complex */ +#if 0 +#if defined(HAVE_SHORT_FLOAT__COMPLEX) +OP_FUNC_3BUF(sum, c_short_float_complex, short float _Complex, +) +#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) +COMPLEX_SUM_FUNC_3BUF(c_short_float_complex, opal_short_float_t) +#endif +OP_FUNC_3BUF(sum, c_long_double_complex, cuLongDoubleComplex, +) +#endif // 0 +#undef 
current_func +#define current_func(a, b) (cuCaddf(a,b)) +FUNC_FUNC_3BUF(sum, c_float_complex, cuFloatComplex) +#undef current_func +#define current_func(a, b) (cuCadd(a,b)) +FUNC_FUNC_3BUF(sum, c_double_complex, cuDoubleComplex) + +/************************************************************************* + * Product + *************************************************************************/ + +/* C integer */ +OP_FUNC_3BUF(prod, int8_t, int8_t, *) +OP_FUNC_3BUF(prod, uint8_t, uint8_t, *) +OP_FUNC_3BUF(prod, int16_t, int16_t, *) +OP_FUNC_3BUF(prod, uint16_t, uint16_t, *) +OP_FUNC_3BUF(prod, int32_t, int32_t, *) +OP_FUNC_3BUF(prod, uint32_t, uint32_t, *) +OP_FUNC_3BUF(prod, int64_t, int64_t, *) +OP_FUNC_3BUF(prod, uint64_t, uint64_t, *) +OP_FUNC_3BUF(prod, long, long, *) +OP_FUNC_3BUF(prod, ulong, unsigned long, *) + +/* Floating point */ +#if defined(HAVE_SHORT_FLOAT) +OP_FUNC_3BUF(prod, short_float, short float, *) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +OP_FUNC_3BUF(prod, short_float, opal_short_float_t, *) +#endif +OP_FUNC_3BUF(prod, float, float, *) +OP_FUNC_3BUF(prod, double, double, *) +OP_FUNC_3BUF(prod, long_double, long double, *) + +/* Complex */ +#if 0 +#if defined(HAVE_SHORT_FLOAT__COMPLEX) +OP_FUNC_3BUF(prod, c_short_float_complex, short float _Complex, *) +#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) +COMPLEX_PROD_FUNC_3BUF(c_short_float_complex, opal_short_float_t) +#endif +OP_FUNC_3BUF(prod, c_long_double_complex, long double _Complex, *) +#endif // 0 +#undef current_func +#define current_func(a, b) (cuCmulf(a,b)) +FUNC_FUNC_3BUF(prod, c_float_complex, cuFloatComplex) +#undef current_func +#define current_func(a, b) (cuCmul(a,b)) +FUNC_FUNC_3BUF(prod, c_double_complex, cuDoubleComplex) + +/************************************************************************* + * Logical AND + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) && (b)) +/* C integer */ 
+FUNC_FUNC_3BUF(land, int8_t, int8_t) +FUNC_FUNC_3BUF(land, uint8_t, uint8_t) +FUNC_FUNC_3BUF(land, int16_t, int16_t) +FUNC_FUNC_3BUF(land, uint16_t, uint16_t) +FUNC_FUNC_3BUF(land, int32_t, int32_t) +FUNC_FUNC_3BUF(land, uint32_t, uint32_t) +FUNC_FUNC_3BUF(land, int64_t, int64_t) +FUNC_FUNC_3BUF(land, uint64_t, uint64_t) +FUNC_FUNC_3BUF(land, long, long) +FUNC_FUNC_3BUF(land, ulong, unsigned long) + +/* C++ bool */ +FUNC_FUNC_3BUF(land, bool, bool) + +/************************************************************************* + * Logical OR + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) || (b)) +/* C integer */ +FUNC_FUNC_3BUF(lor, int8_t, int8_t) +FUNC_FUNC_3BUF(lor, uint8_t, uint8_t) +FUNC_FUNC_3BUF(lor, int16_t, int16_t) +FUNC_FUNC_3BUF(lor, uint16_t, uint16_t) +FUNC_FUNC_3BUF(lor, int32_t, int32_t) +FUNC_FUNC_3BUF(lor, uint32_t, uint32_t) +FUNC_FUNC_3BUF(lor, int64_t, int64_t) +FUNC_FUNC_3BUF(lor, uint64_t, uint64_t) +FUNC_FUNC_3BUF(lor, long, long) +FUNC_FUNC_3BUF(lor, ulong, unsigned long) + +/* C++ bool */ +FUNC_FUNC_3BUF(lor, bool, bool) + +/************************************************************************* + * Logical XOR + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a ? 1 : 0) ^ (b ? 
1: 0)) +/* C integer */ +FUNC_FUNC_3BUF(lxor, int8_t, int8_t) +FUNC_FUNC_3BUF(lxor, uint8_t, uint8_t) +FUNC_FUNC_3BUF(lxor, int16_t, int16_t) +FUNC_FUNC_3BUF(lxor, uint16_t, uint16_t) +FUNC_FUNC_3BUF(lxor, int32_t, int32_t) +FUNC_FUNC_3BUF(lxor, uint32_t, uint32_t) +FUNC_FUNC_3BUF(lxor, int64_t, int64_t) +FUNC_FUNC_3BUF(lxor, uint64_t, uint64_t) +FUNC_FUNC_3BUF(lxor, long, long) +FUNC_FUNC_3BUF(lxor, ulong, unsigned long) + +/* C++ bool */ +FUNC_FUNC_3BUF(lxor, bool, bool) + +/************************************************************************* + * Bitwise AND + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) & (b)) +/* C integer */ +FUNC_FUNC_3BUF(band, int8_t, int8_t) +FUNC_FUNC_3BUF(band, uint8_t, uint8_t) +FUNC_FUNC_3BUF(band, int16_t, int16_t) +FUNC_FUNC_3BUF(band, uint16_t, uint16_t) +FUNC_FUNC_3BUF(band, int32_t, int32_t) +FUNC_FUNC_3BUF(band, uint32_t, uint32_t) +FUNC_FUNC_3BUF(band, int64_t, int64_t) +FUNC_FUNC_3BUF(band, uint64_t, uint64_t) +FUNC_FUNC_3BUF(band, long, long) +FUNC_FUNC_3BUF(band, ulong, unsigned long) + +/* Byte */ +FUNC_FUNC_3BUF(band, byte, char) + +/************************************************************************* + * Bitwise OR + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) | (b)) +/* C integer */ +FUNC_FUNC_3BUF(bor, int8_t, int8_t) +FUNC_FUNC_3BUF(bor, uint8_t, uint8_t) +FUNC_FUNC_3BUF(bor, int16_t, int16_t) +FUNC_FUNC_3BUF(bor, uint16_t, uint16_t) +FUNC_FUNC_3BUF(bor, int32_t, int32_t) +FUNC_FUNC_3BUF(bor, uint32_t, uint32_t) +FUNC_FUNC_3BUF(bor, int64_t, int64_t) +FUNC_FUNC_3BUF(bor, uint64_t, uint64_t) +FUNC_FUNC_3BUF(bor, long, long) +FUNC_FUNC_3BUF(bor, ulong, unsigned long) + +/* Byte */ +FUNC_FUNC_3BUF(bor, byte, char) + +/************************************************************************* + * Bitwise XOR + 
*************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) ^ (b)) +/* C integer */ +FUNC_FUNC_3BUF(bxor, int8_t, int8_t) +FUNC_FUNC_3BUF(bxor, uint8_t, uint8_t) +FUNC_FUNC_3BUF(bxor, int16_t, int16_t) +FUNC_FUNC_3BUF(bxor, uint16_t, uint16_t) +FUNC_FUNC_3BUF(bxor, int32_t, int32_t) +FUNC_FUNC_3BUF(bxor, uint32_t, uint32_t) +FUNC_FUNC_3BUF(bxor, int64_t, int64_t) +FUNC_FUNC_3BUF(bxor, uint64_t, uint64_t) +FUNC_FUNC_3BUF(bxor, long, long) +FUNC_FUNC_3BUF(bxor, ulong, unsigned long) + +/* Byte */ +FUNC_FUNC_3BUF(bxor, byte, char) + +/************************************************************************* + * Max location + *************************************************************************/ + +LOC_FUNC_3BUF(maxloc, float_int, >) +LOC_FUNC_3BUF(maxloc, double_int, >) +LOC_FUNC_3BUF(maxloc, long_int, >) +LOC_FUNC_3BUF(maxloc, 2int, >) +LOC_FUNC_3BUF(maxloc, short_int, >) +LOC_FUNC_3BUF(maxloc, long_double_int, >) + +/* Fortran compat types */ +LOC_FUNC_3BUF(maxloc, 2float, >) +LOC_FUNC_3BUF(maxloc, 2double, >) +LOC_FUNC_3BUF(maxloc, 2int8, >) +LOC_FUNC_3BUF(maxloc, 2int16, >) +LOC_FUNC_3BUF(maxloc, 2int32, >) +LOC_FUNC_3BUF(maxloc, 2int64, >) + +/************************************************************************* + * Min location + *************************************************************************/ + +LOC_FUNC_3BUF(minloc, float_int, <) +LOC_FUNC_3BUF(minloc, double_int, <) +LOC_FUNC_3BUF(minloc, long_int, <) +LOC_FUNC_3BUF(minloc, 2int, <) +LOC_FUNC_3BUF(minloc, short_int, <) +LOC_FUNC_3BUF(minloc, long_double_int, <) + +/* Fortran compat types */ +LOC_FUNC_3BUF(minloc, 2float, <) +LOC_FUNC_3BUF(minloc, 2double, <) +LOC_FUNC_3BUF(minloc, 2int8, <) +LOC_FUNC_3BUF(minloc, 2int16, <) +LOC_FUNC_3BUF(minloc, 2int32, <) +LOC_FUNC_3BUF(minloc, 2int64, <) diff --git a/ompi/mca/op/cuda/op_cuda_impl.h b/ompi/mca/op/cuda/op_cuda_impl.h new file mode 100644 index 00000000000..43209581bab 
--- /dev/null +++ b/ompi/mca/op/cuda/op_cuda_impl.h @@ -0,0 +1,695 @@ +/* + * Copyright (c) 2019-2023 The University of Tennessee and The University + * of Tennessee Research Foundation. All rights + * reserved. + * Copyright (c) 2020 Research Organization for Information Science + * and Technology (RIST). All rights reserved. + * $COPYRIGHT$ + * + * Additional copyrights may follow + * + * $HEADER$ + */ + +#include <cuda.h> + +#include <stdint.h> +#include <stdbool.h> +#include <cuComplex.h> + +#ifndef BEGIN_C_DECLS +#if defined(c_plusplus) || defined(__cplusplus) +# define BEGIN_C_DECLS extern "C" { +# define END_C_DECLS } +#else +# define BEGIN_C_DECLS /* empty */ +# define END_C_DECLS /* empty */ +#endif +#endif + +BEGIN_C_DECLS + +#define OP_FUNC_SIG(name, type_name, type) \ + void ompi_op_cuda_2buff_##name##_##type_name##_submit(const type *in, \ + type *inout, \ + int count, \ + int threads_per_block, \ + int max_blocks, \ + CUstream stream); + +#define FUNC_FUNC_SIG(name, type_name, type) \ + void ompi_op_cuda_2buff_##name##_##type_name##_submit(const type *in, \ + type *inout, \ + int count, \ + int threads_per_block, \ + int max_blocks, \ + CUstream stream); + +/* + * Since all the functions in this file are essentially identical, we + * use a macro to substitute in names and types. The core operation + * in all functions that use this macro is the same.
+ * + * This macro is for minloc and maxloc + */ +#define LOC_STRUCT(type_name, type1, type2) \ + typedef struct { \ + type1 v; \ + type2 k; \ + } ompi_op_predefined_##type_name##_t; + +#define LOC_FUNC_SIG(name, type_name) \ + void ompi_op_cuda_2buff_##name##_##type_name##_submit(const ompi_op_predefined_##type_name##_t *a, \ + ompi_op_predefined_##type_name##_t *b, \ + int count, \ + int threads_per_block, \ + int max_blocks, \ + CUstream stream); + +/************************************************************************* + * Max + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_SIG(max, int8_t, int8_t) +FUNC_FUNC_SIG(max, uint8_t, uint8_t) +FUNC_FUNC_SIG(max, int16_t, int16_t) +FUNC_FUNC_SIG(max, uint16_t, uint16_t) +FUNC_FUNC_SIG(max, int32_t, int32_t) +FUNC_FUNC_SIG(max, uint32_t, uint32_t) +FUNC_FUNC_SIG(max, int64_t, int64_t) +FUNC_FUNC_SIG(max, uint64_t, uint64_t) +FUNC_FUNC_SIG(max, long, long) +FUNC_FUNC_SIG(max, ulong, unsigned long) + +#if 0 +/* Floating point */ +#if defined(HAVE_SHORT_FLOAT) +FUNC_FUNC_SIG(max, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +FUNC_FUNC_SIG(max, short_float, opal_short_float_t) +#endif +#endif // 0 + +FUNC_FUNC_SIG(max, float, float) +FUNC_FUNC_SIG(max, double, double) +FUNC_FUNC_SIG(max, long_double, long double) + +/************************************************************************* + * Min + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_SIG(min, int8_t, int8_t) +FUNC_FUNC_SIG(min, uint8_t, uint8_t) +FUNC_FUNC_SIG(min, int16_t, int16_t) +FUNC_FUNC_SIG(min, uint16_t, uint16_t) +FUNC_FUNC_SIG(min, int32_t, int32_t) +FUNC_FUNC_SIG(min, uint32_t, uint32_t) +FUNC_FUNC_SIG(min, int64_t, int64_t) +FUNC_FUNC_SIG(min, uint64_t, uint64_t) +FUNC_FUNC_SIG(min, long, long) +FUNC_FUNC_SIG(min, ulong, unsigned long) + +#if 0 +/* Floating point */ +#if defined(HAVE_SHORT_FLOAT) 
+FUNC_FUNC_SIG(min, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +FUNC_FUNC_SIG(min, short_float, opal_short_float_t) +#endif +#endif // 0 + +FUNC_FUNC_SIG(min, float, float) +FUNC_FUNC_SIG(min, double, double) +FUNC_FUNC_SIG(min, long_double, long double) + +/************************************************************************* + * Sum + *************************************************************************/ + +/* C integer */ +OP_FUNC_SIG(sum, int8_t, int8_t) +OP_FUNC_SIG(sum, uint8_t, uint8_t) +OP_FUNC_SIG(sum, int16_t, int16_t) +OP_FUNC_SIG(sum, uint16_t, uint16_t) +OP_FUNC_SIG(sum, int32_t, int32_t) +OP_FUNC_SIG(sum, uint32_t, uint32_t) +OP_FUNC_SIG(sum, int64_t, int64_t) +OP_FUNC_SIG(sum, uint64_t, uint64_t) +OP_FUNC_SIG(sum, long, long) +OP_FUNC_SIG(sum, ulong, unsigned long) + +//#if __CUDA_ARCH__ >= 530 +//OP_FUNC_SIG(sum, half, half) +//#endif // __CUDA_ARCH__ + +#if 0 +/* Floating point */ +#if defined(HAVE_SHORT_FLOAT) +OP_FUNC_SIG(sum, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +OP_FUNC_SIG(sum, short_float, opal_short_float_t) +#endif +#endif // 0 + +OP_FUNC_SIG(sum, float, float) +OP_FUNC_SIG(sum, double, double) +OP_FUNC_SIG(sum, long_double, long double) + +/* Complex */ +#if 0 +#if defined(HAVE_SHORT_FLOAT__COMPLEX) +OP_FUNC_SIG(sum, c_short_float_complex, short float _Complex) +#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) +COMPLEX_SUM_FUNC(c_short_float_complex, opal_short_float_t) +OP_FUNC_SIG(sum, c_long_double_complex, long double _Complex) +#endif +#endif // 0 +FUNC_FUNC_SIG(sum, c_float_complex, cuFloatComplex) +FUNC_FUNC_SIG(sum, c_double_complex, cuDoubleComplex) + +/************************************************************************* + * Product + *************************************************************************/ + +/* C integer */ +OP_FUNC_SIG(prod, int8_t, int8_t) +OP_FUNC_SIG(prod, uint8_t, uint8_t) +OP_FUNC_SIG(prod, int16_t, int16_t) +OP_FUNC_SIG(prod, uint16_t, 
uint16_t) +OP_FUNC_SIG(prod, int32_t, int32_t) +OP_FUNC_SIG(prod, uint32_t, uint32_t) +OP_FUNC_SIG(prod, int64_t, int64_t) +OP_FUNC_SIG(prod, uint64_t, uint64_t) +OP_FUNC_SIG(prod, long, long) +OP_FUNC_SIG(prod, ulong, unsigned long) + +#if 0 +/* Floating point */ +#if defined(HAVE_SHORT_FLOAT) +OP_FUNC_SIG(prod, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +OP_FUNC_SIG(prod, short_float, opal_short_float_t) +#endif +#endif // 0 + +OP_FUNC_SIG(prod, float, float) +OP_FUNC_SIG(prod, float, float) +OP_FUNC_SIG(prod, double, double) +OP_FUNC_SIG(prod, long_double, long double) + +/* Complex */ +#if 0 +#if defined(HAVE_SHORT_FLOAT__COMPLEX) +OP_FUNC_SIG(prod, c_short_float_complex, short float _Complex) +#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) +COMPLEX_PROD_FUNC(c_short_float_complex, opal_short_float_t) +#endif +OP_FUNC_SIG(prod, c_long_double_complex, long double _Complex) +#endif // 0 + +FUNC_FUNC_SIG(prod, c_float_complex, cuFloatComplex) +FUNC_FUNC_SIG(prod, c_double_complex, cuDoubleComplex) + +/************************************************************************* + * Logical AND + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_SIG(land, int8_t, int8_t) +FUNC_FUNC_SIG(land, uint8_t, uint8_t) +FUNC_FUNC_SIG(land, int16_t, int16_t) +FUNC_FUNC_SIG(land, uint16_t, uint16_t) +FUNC_FUNC_SIG(land, int32_t, int32_t) +FUNC_FUNC_SIG(land, uint32_t, uint32_t) +FUNC_FUNC_SIG(land, int64_t, int64_t) +FUNC_FUNC_SIG(land, uint64_t, uint64_t) +FUNC_FUNC_SIG(land, long, long) +FUNC_FUNC_SIG(land, ulong, unsigned long) + +/* C++ bool */ +FUNC_FUNC_SIG(land, bool, bool) + +/************************************************************************* + * Logical OR + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_SIG(lor, int8_t, int8_t) +FUNC_FUNC_SIG(lor, uint8_t, uint8_t) +FUNC_FUNC_SIG(lor, int16_t, int16_t) +FUNC_FUNC_SIG(lor, 
uint16_t, uint16_t) +FUNC_FUNC_SIG(lor, int32_t, int32_t) +FUNC_FUNC_SIG(lor, uint32_t, uint32_t) +FUNC_FUNC_SIG(lor, int64_t, int64_t) +FUNC_FUNC_SIG(lor, uint64_t, uint64_t) +FUNC_FUNC_SIG(lor, long, long) +FUNC_FUNC_SIG(lor, ulong, unsigned long) + +/* C++ bool */ +FUNC_FUNC_SIG(lor, bool, bool) + +/************************************************************************* + * Logical XOR + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_SIG(lxor, int8_t, int8_t) +FUNC_FUNC_SIG(lxor, uint8_t, uint8_t) +FUNC_FUNC_SIG(lxor, int16_t, int16_t) +FUNC_FUNC_SIG(lxor, uint16_t, uint16_t) +FUNC_FUNC_SIG(lxor, int32_t, int32_t) +FUNC_FUNC_SIG(lxor, uint32_t, uint32_t) +FUNC_FUNC_SIG(lxor, int64_t, int64_t) +FUNC_FUNC_SIG(lxor, uint64_t, uint64_t) +FUNC_FUNC_SIG(lxor, long, long) +FUNC_FUNC_SIG(lxor, ulong, unsigned long) + +/* C++ bool */ +FUNC_FUNC_SIG(lxor, bool, bool) + +/************************************************************************* + * Bitwise AND + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_SIG(band, int8_t, int8_t) +FUNC_FUNC_SIG(band, uint8_t, uint8_t) +FUNC_FUNC_SIG(band, int16_t, int16_t) +FUNC_FUNC_SIG(band, uint16_t, uint16_t) +FUNC_FUNC_SIG(band, int32_t, int32_t) +FUNC_FUNC_SIG(band, uint32_t, uint32_t) +FUNC_FUNC_SIG(band, int64_t, int64_t) +FUNC_FUNC_SIG(band, uint64_t, uint64_t) +FUNC_FUNC_SIG(band, long, long) +FUNC_FUNC_SIG(band, ulong, unsigned long) + +/* Byte */ +FUNC_FUNC_SIG(band, byte, char) + +/************************************************************************* + * Bitwise OR + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_SIG(bor, int8_t, int8_t) +FUNC_FUNC_SIG(bor, uint8_t, uint8_t) +FUNC_FUNC_SIG(bor, int16_t, int16_t) +FUNC_FUNC_SIG(bor, uint16_t, uint16_t) +FUNC_FUNC_SIG(bor, int32_t, int32_t) +FUNC_FUNC_SIG(bor, uint32_t, uint32_t) 
+FUNC_FUNC_SIG(bor, int64_t, int64_t) +FUNC_FUNC_SIG(bor, uint64_t, uint64_t) +FUNC_FUNC_SIG(bor, long, long) +FUNC_FUNC_SIG(bor, ulong, unsigned long) + +/* Byte */ +FUNC_FUNC_SIG(bor, byte, char) + +/************************************************************************* + * Bitwise XOR + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_SIG(bxor, int8_t, int8_t) +FUNC_FUNC_SIG(bxor, uint8_t, uint8_t) +FUNC_FUNC_SIG(bxor, int16_t, int16_t) +FUNC_FUNC_SIG(bxor, uint16_t, uint16_t) +FUNC_FUNC_SIG(bxor, int32_t, int32_t) +FUNC_FUNC_SIG(bxor, uint32_t, uint32_t) +FUNC_FUNC_SIG(bxor, int64_t, int64_t) +FUNC_FUNC_SIG(bxor, uint64_t, uint64_t) +FUNC_FUNC_SIG(bxor, long, long) +FUNC_FUNC_SIG(bxor, ulong, unsigned long) + +/* Byte */ +FUNC_FUNC_SIG(bxor, byte, char) + +/************************************************************************* + * Min and max location "pair" datatypes + *************************************************************************/ + +LOC_STRUCT(float_int, float, int) +LOC_STRUCT(double_int, double, int) +LOC_STRUCT(long_int, long, int) +LOC_STRUCT(2int, int, int) +LOC_STRUCT(short_int, short, int) +LOC_STRUCT(long_double_int, long double, int) +LOC_STRUCT(ulong, unsigned long, int) +/* compat types for Fortran */ +LOC_STRUCT(2float, float, float) +LOC_STRUCT(2double, double, double) +LOC_STRUCT(2int8, int8_t, int8_t) +LOC_STRUCT(2int16, int16_t, int16_t) +LOC_STRUCT(2int32, int32_t, int32_t) +LOC_STRUCT(2int64, int64_t, int64_t) + +/************************************************************************* + * Max location + *************************************************************************/ + +LOC_FUNC_SIG(maxloc, 2float) +LOC_FUNC_SIG(maxloc, 2double) +LOC_FUNC_SIG(maxloc, 2int8) +LOC_FUNC_SIG(maxloc, 2int16) +LOC_FUNC_SIG(maxloc, 2int32) +LOC_FUNC_SIG(maxloc, 2int64) + +LOC_FUNC_SIG(maxloc, float_int) +LOC_FUNC_SIG(maxloc, double_int) +LOC_FUNC_SIG(maxloc, long_int) 
+LOC_FUNC_SIG(maxloc, 2int) +LOC_FUNC_SIG(maxloc, short_int) +LOC_FUNC_SIG(maxloc, long_double_int) + +/************************************************************************* + * Min location + *************************************************************************/ + +LOC_FUNC_SIG(minloc, 2float) +LOC_FUNC_SIG(minloc, 2double) +LOC_FUNC_SIG(minloc, 2int8) +LOC_FUNC_SIG(minloc, 2int16) +LOC_FUNC_SIG(minloc, 2int32) +LOC_FUNC_SIG(minloc, 2int64) + +LOC_FUNC_SIG(minloc, float_int) +LOC_FUNC_SIG(minloc, double_int) +LOC_FUNC_SIG(minloc, long_int) +LOC_FUNC_SIG(minloc, 2int) +LOC_FUNC_SIG(minloc, short_int) +LOC_FUNC_SIG(minloc, long_double_int) + + + +#define OP_FUNC_3BUF_SIG(name, type_name, type) \ + void ompi_op_cuda_3buff_##name##_##type_name##_submit(const type *in1, \ + const type *in2, \ + type *inout, \ + int count, \ + int threads_per_block, \ + int max_blocks, \ + CUstream stream); + +#define FUNC_FUNC_3BUF_SIG(name, type_name, type) \ + void ompi_op_cuda_3buff_##name##_##type_name##_submit(const type *in1, \ + const type *in2, \ + type *inout, \ + int count, \ + int threads_per_block, \ + int max_blocks, \ + CUstream stream); + +#define LOC_FUNC_3BUF_SIG(name, type_name) \ + void ompi_op_cuda_3buff_##name##_##type_name##_submit(const ompi_op_predefined_##type_name##_t *a1, \ + const ompi_op_predefined_##type_name##_t *a2, \ + ompi_op_predefined_##type_name##_t *b, \ + int count, \ + int threads_per_block, \ + int max_blocks, \ + CUstream stream); + + +/************************************************************************* + * Max + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_3BUF_SIG(max, int8_t, int8_t) +FUNC_FUNC_3BUF_SIG(max, uint8_t, uint8_t) +FUNC_FUNC_3BUF_SIG(max, int16_t, int16_t) +FUNC_FUNC_3BUF_SIG(max, uint16_t, uint16_t) +FUNC_FUNC_3BUF_SIG(max, int32_t, int32_t) +FUNC_FUNC_3BUF_SIG(max, uint32_t, uint32_t) +FUNC_FUNC_3BUF_SIG(max, int64_t, int64_t) +FUNC_FUNC_3BUF_SIG(max, 
uint64_t, uint64_t) +FUNC_FUNC_3BUF_SIG(max, long, long) +FUNC_FUNC_3BUF_SIG(max, ulong, unsigned long) + +/* Floating point */ +#if defined(HAVE_SHORT_FLOAT) +FUNC_FUNC_3BUF_SIG(max, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +FUNC_FUNC_3BUF_SIG(max, short_float, opal_short_float_t) +#endif +FUNC_FUNC_3BUF_SIG(max, float, float) +FUNC_FUNC_3BUF_SIG(max, double, double) +FUNC_FUNC_3BUF_SIG(max, long_double, long double) + +/************************************************************************* + * Min + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_3BUF_SIG(min, int8_t, int8_t) +FUNC_FUNC_3BUF_SIG(min, uint8_t, uint8_t) +FUNC_FUNC_3BUF_SIG(min, int16_t, int16_t) +FUNC_FUNC_3BUF_SIG(min, uint16_t, uint16_t) +FUNC_FUNC_3BUF_SIG(min, int32_t, int32_t) +FUNC_FUNC_3BUF_SIG(min, uint32_t, uint32_t) +FUNC_FUNC_3BUF_SIG(min, int64_t, int64_t) +FUNC_FUNC_3BUF_SIG(min, uint64_t, uint64_t) +FUNC_FUNC_3BUF_SIG(min, long, long) +FUNC_FUNC_3BUF_SIG(min, ulong, unsigned long) + +/* Floating point */ +#if defined(HAVE_SHORT_FLOAT) +FUNC_FUNC_3BUF_SIG(min, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +FUNC_FUNC_3BUF_SIG(min, short_float, opal_short_float_t) +#endif +FUNC_FUNC_3BUF_SIG(min, float, float) +FUNC_FUNC_3BUF_SIG(min, double, double) +FUNC_FUNC_3BUF_SIG(min, long_double, long double) + +/************************************************************************* + * Sum + *************************************************************************/ + +/* C integer */ +OP_FUNC_3BUF_SIG(sum, int8_t, int8_t) +OP_FUNC_3BUF_SIG(sum, uint8_t, uint8_t) +OP_FUNC_3BUF_SIG(sum, int16_t, int16_t) +OP_FUNC_3BUF_SIG(sum, uint16_t, uint16_t) +OP_FUNC_3BUF_SIG(sum, int32_t, int32_t) +OP_FUNC_3BUF_SIG(sum, uint32_t, uint32_t) +OP_FUNC_3BUF_SIG(sum, int64_t, int64_t) +OP_FUNC_3BUF_SIG(sum, uint64_t, uint64_t) +OP_FUNC_3BUF_SIG(sum, long, long) +OP_FUNC_3BUF_SIG(sum, ulong, unsigned long) + 
+/* Floating point */ +#if defined(HAVE_SHORT_FLOAT) +OP_FUNC_3BUF_SIG(sum, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +OP_FUNC_3BUF_SIG(sum, short_float, opal_short_float_t) +#endif +OP_FUNC_3BUF_SIG(sum, float, float) +OP_FUNC_3BUF_SIG(sum, double, double) +OP_FUNC_3BUF_SIG(sum, long_double, long double) + +/* Complex */ +#if 0 +#if defined(HAVE_SHORT_FLOAT__COMPLEX) +OP_FUNC_3BUF_SIG(sum, c_short_float_complex, short float _Complex) +#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) +COMPLEX_SUM_FUNC_3BUF(c_short_float_complex, opal_short_float_t) +#endif +OP_FUNC_3BUF_SIG(sum, c_long_double_complex, long double _Complex) +#endif // 0 +FUNC_FUNC_3BUF_SIG(sum, c_float_complex, cuFloatComplex) +FUNC_FUNC_3BUF_SIG(sum, c_double_complex, cuDoubleComplex) + +/************************************************************************* + * Product + *************************************************************************/ + +/* C integer */ +OP_FUNC_3BUF_SIG(prod, int8_t, int8_t) +OP_FUNC_3BUF_SIG(prod, uint8_t, uint8_t) +OP_FUNC_3BUF_SIG(prod, int16_t, int16_t) +OP_FUNC_3BUF_SIG(prod, uint16_t, uint16_t) +OP_FUNC_3BUF_SIG(prod, int32_t, int32_t) +OP_FUNC_3BUF_SIG(prod, uint32_t, uint32_t) +OP_FUNC_3BUF_SIG(prod, int64_t, int64_t) +OP_FUNC_3BUF_SIG(prod, uint64_t, uint64_t) +OP_FUNC_3BUF_SIG(prod, long, long) +OP_FUNC_3BUF_SIG(prod, ulong, unsigned long) + +/* Floating point */ +#if defined(HAVE_SHORT_FLOAT) +OP_FUNC_3BUF_SIG(prod, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +OP_FUNC_3BUF_SIG(prod, short_float, opal_short_float_t) +#endif +OP_FUNC_3BUF_SIG(prod, float, float) +OP_FUNC_3BUF_SIG(prod, double, double) +OP_FUNC_3BUF_SIG(prod, long_double, long double) + +/* Complex */ +#if 0 +#if defined(HAVE_SHORT_FLOAT__COMPLEX) +OP_FUNC_3BUF_SIG(prod, c_short_float_complex, short float _Complex) +#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) +COMPLEX_PROD_FUNC_3BUF(c_short_float_complex, opal_short_float_t) +#endif 
+OP_FUNC_3BUF_SIG(prod, c_float_complex, float _Complex) +OP_FUNC_3BUF_SIG(prod, c_double_complex, double _Complex) +OP_FUNC_3BUF_SIG(prod, c_long_double_complex, long double _Complex) +#endif // 0 +FUNC_FUNC_3BUF_SIG(prod, c_float_complex, cuFloatComplex) +FUNC_FUNC_3BUF_SIG(prod, c_double_complex, cuDoubleComplex) + +/************************************************************************* + * Logical AND + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_3BUF_SIG(land, int8_t, int8_t) +FUNC_FUNC_3BUF_SIG(land, uint8_t, uint8_t) +FUNC_FUNC_3BUF_SIG(land, int16_t, int16_t) +FUNC_FUNC_3BUF_SIG(land, uint16_t, uint16_t) +FUNC_FUNC_3BUF_SIG(land, int32_t, int32_t) +FUNC_FUNC_3BUF_SIG(land, uint32_t, uint32_t) +FUNC_FUNC_3BUF_SIG(land, int64_t, int64_t) +FUNC_FUNC_3BUF_SIG(land, uint64_t, uint64_t) +FUNC_FUNC_3BUF_SIG(land, long, long) +FUNC_FUNC_3BUF_SIG(land, ulong, unsigned long) + +/* C++ bool */ +FUNC_FUNC_3BUF_SIG(land, bool, bool) + +/************************************************************************* + * Logical OR + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_3BUF_SIG(lor, int8_t, int8_t) +FUNC_FUNC_3BUF_SIG(lor, uint8_t, uint8_t) +FUNC_FUNC_3BUF_SIG(lor, int16_t, int16_t) +FUNC_FUNC_3BUF_SIG(lor, uint16_t, uint16_t) +FUNC_FUNC_3BUF_SIG(lor, int32_t, int32_t) +FUNC_FUNC_3BUF_SIG(lor, uint32_t, uint32_t) +FUNC_FUNC_3BUF_SIG(lor, int64_t, int64_t) +FUNC_FUNC_3BUF_SIG(lor, uint64_t, uint64_t) +FUNC_FUNC_3BUF_SIG(lor, long, long) +FUNC_FUNC_3BUF_SIG(lor, ulong, unsigned long) + +/* C++ bool */ +FUNC_FUNC_3BUF_SIG(lor, bool, bool) + +/************************************************************************* + * Logical XOR + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_3BUF_SIG(lxor, int8_t, int8_t) +FUNC_FUNC_3BUF_SIG(lxor, uint8_t, uint8_t) +FUNC_FUNC_3BUF_SIG(lxor, int16_t, 
int16_t) +FUNC_FUNC_3BUF_SIG(lxor, uint16_t, uint16_t) +FUNC_FUNC_3BUF_SIG(lxor, int32_t, int32_t) +FUNC_FUNC_3BUF_SIG(lxor, uint32_t, uint32_t) +FUNC_FUNC_3BUF_SIG(lxor, int64_t, int64_t) +FUNC_FUNC_3BUF_SIG(lxor, uint64_t, uint64_t) +FUNC_FUNC_3BUF_SIG(lxor, long, long) +FUNC_FUNC_3BUF_SIG(lxor, ulong, unsigned long) + +/* C++ bool */ +FUNC_FUNC_3BUF_SIG(lxor, bool, bool) + +/************************************************************************* + * Bitwise AND + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_3BUF_SIG(band, int8_t, int8_t) +FUNC_FUNC_3BUF_SIG(band, uint8_t, uint8_t) +FUNC_FUNC_3BUF_SIG(band, int16_t, int16_t) +FUNC_FUNC_3BUF_SIG(band, uint16_t, uint16_t) +FUNC_FUNC_3BUF_SIG(band, int32_t, int32_t) +FUNC_FUNC_3BUF_SIG(band, uint32_t, uint32_t) +FUNC_FUNC_3BUF_SIG(band, int64_t, int64_t) +FUNC_FUNC_3BUF_SIG(band, uint64_t, uint64_t) +FUNC_FUNC_3BUF_SIG(band, long, long) +FUNC_FUNC_3BUF_SIG(band, ulong, unsigned long) + +/* Byte */ +FUNC_FUNC_3BUF_SIG(band, byte, char) + +/************************************************************************* + * Bitwise OR + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_3BUF_SIG(bor, int8_t, int8_t) +FUNC_FUNC_3BUF_SIG(bor, uint8_t, uint8_t) +FUNC_FUNC_3BUF_SIG(bor, int16_t, int16_t) +FUNC_FUNC_3BUF_SIG(bor, uint16_t, uint16_t) +FUNC_FUNC_3BUF_SIG(bor, int32_t, int32_t) +FUNC_FUNC_3BUF_SIG(bor, uint32_t, uint32_t) +FUNC_FUNC_3BUF_SIG(bor, int64_t, int64_t) +FUNC_FUNC_3BUF_SIG(bor, uint64_t, uint64_t) +FUNC_FUNC_3BUF_SIG(bor, long, long) +FUNC_FUNC_3BUF_SIG(bor, ulong, unsigned long) + +/* Byte */ +FUNC_FUNC_3BUF_SIG(bor, byte, char) + +/************************************************************************* + * Bitwise XOR + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_3BUF_SIG(bxor, int8_t, int8_t) +FUNC_FUNC_3BUF_SIG(bxor, uint8_t, 
uint8_t) +FUNC_FUNC_3BUF_SIG(bxor, int16_t, int16_t) +FUNC_FUNC_3BUF_SIG(bxor, uint16_t, uint16_t) +FUNC_FUNC_3BUF_SIG(bxor, int32_t, int32_t) +FUNC_FUNC_3BUF_SIG(bxor, uint32_t, uint32_t) +FUNC_FUNC_3BUF_SIG(bxor, int64_t, int64_t) +FUNC_FUNC_3BUF_SIG(bxor, uint64_t, uint64_t) +FUNC_FUNC_3BUF_SIG(bxor, long, long) +FUNC_FUNC_3BUF_SIG(bxor, ulong, unsigned long) + +/* Byte */ +FUNC_FUNC_3BUF_SIG(bxor, byte, char) + +/************************************************************************* + * Max location + *************************************************************************/ + +LOC_FUNC_3BUF_SIG(maxloc, float_int) +LOC_FUNC_3BUF_SIG(maxloc, double_int) +LOC_FUNC_3BUF_SIG(maxloc, long_int) +LOC_FUNC_3BUF_SIG(maxloc, 2int) +LOC_FUNC_3BUF_SIG(maxloc, short_int) +LOC_FUNC_3BUF_SIG(maxloc, long_double_int) + +LOC_FUNC_3BUF_SIG(maxloc, 2float) +LOC_FUNC_3BUF_SIG(maxloc, 2double) +LOC_FUNC_3BUF_SIG(maxloc, 2int8) +LOC_FUNC_3BUF_SIG(maxloc, 2int16) +LOC_FUNC_3BUF_SIG(maxloc, 2int32) +LOC_FUNC_3BUF_SIG(maxloc, 2int64) + +/************************************************************************* + * Min location + *************************************************************************/ + +LOC_FUNC_3BUF_SIG(minloc, float_int) +LOC_FUNC_3BUF_SIG(minloc, double_int) +LOC_FUNC_3BUF_SIG(minloc, long_int) +LOC_FUNC_3BUF_SIG(minloc, 2int) +LOC_FUNC_3BUF_SIG(minloc, short_int) +LOC_FUNC_3BUF_SIG(minloc, long_double_int) + +LOC_FUNC_3BUF_SIG(minloc, 2float) +LOC_FUNC_3BUF_SIG(minloc, 2double) +LOC_FUNC_3BUF_SIG(minloc, 2int8) +LOC_FUNC_3BUF_SIG(minloc, 2int16) +LOC_FUNC_3BUF_SIG(minloc, 2int32) +LOC_FUNC_3BUF_SIG(minloc, 2int64) + + +END_C_DECLS diff --git a/ompi/mca/op/op.h b/ompi/mca/op/op.h index 34d26376ab9..097c2a109b4 100644 --- a/ompi/mca/op/op.h +++ b/ompi/mca/op/op.h @@ -85,6 +85,7 @@ #include "ompi_config.h" #include "opal/class/opal_object.h" +#include "opal/mca/accelerator/accelerator.h" #include "ompi/mca/mca.h" /* @@ -266,6 +267,22 @@ typedef void 
(*ompi_op_base_handler_fn_1_0_0_t)(const void *, void *, int *, typedef ompi_op_base_handler_fn_1_0_0_t ompi_op_base_handler_fn_t; +/** + * Typedef for 2-buffer op functions on streams/devices. + * + * We don't use MPI_User_function because this would create a + * confusing dependency loop between this file and mpi.h. So this is + * repeated code, but it's better this way (and this typedef will + * never change, so there's not much of a maintenance worry). + */ +typedef void (*ompi_op_base_stream_handler_fn_1_0_0_t)(const void *, void *, int *, + struct ompi_datatype_t **, + int device, + opal_accelerator_stream_t *stream, + struct ompi_op_base_module_1_0_0_t *); + +typedef ompi_op_base_stream_handler_fn_1_0_0_t ompi_op_base_stream_handler_fn_t; + /* * Typedef for 3-buffer (two input and one output) op functions. */ @@ -277,6 +294,19 @@ typedef void (*ompi_op_base_3buff_handler_fn_1_0_0_t)(const void *, typedef ompi_op_base_3buff_handler_fn_1_0_0_t ompi_op_base_3buff_handler_fn_t; +/* + * Typedef for 3-buffer (two input and one output) op functions on streams. 
+ */ +typedef void (*ompi_op_base_3buff_stream_handler_fn_1_0_0_t)(const void *, + const void *, + void *, int *, + struct ompi_datatype_t **, + int device, + opal_accelerator_stream_t*, + struct ompi_op_base_module_1_0_0_t *); + +typedef ompi_op_base_3buff_stream_handler_fn_1_0_0_t ompi_op_base_3buff_stream_handler_fn_t; + /** * Op component initialization * @@ -376,10 +406,18 @@ typedef struct ompi_op_base_module_1_0_0_t { is being used for */ struct ompi_op_t *opm_op; + bool opm_device_enabled; + /** Function pointers for all the different datatypes to be used with the MPI_Op that this module is used with */ - ompi_op_base_handler_fn_1_0_0_t opm_fns[OMPI_OP_BASE_TYPE_MAX]; - ompi_op_base_3buff_handler_fn_1_0_0_t opm_3buff_fns[OMPI_OP_BASE_TYPE_MAX]; + union { + ompi_op_base_handler_fn_1_0_0_t opm_fns[OMPI_OP_BASE_TYPE_MAX]; + ompi_op_base_stream_handler_fn_1_0_0_t opm_stream_fns[OMPI_OP_BASE_TYPE_MAX]; + }; + union { + ompi_op_base_3buff_handler_fn_1_0_0_t opm_3buff_fns[OMPI_OP_BASE_TYPE_MAX]; + ompi_op_base_3buff_stream_handler_fn_1_0_0_t opm_3buff_stream_fns[OMPI_OP_BASE_TYPE_MAX]; + }; } ompi_op_base_module_1_0_0_t; /** @@ -404,6 +442,18 @@ typedef struct ompi_op_base_op_fns_1_0_0_t { typedef ompi_op_base_op_fns_1_0_0_t ompi_op_base_op_fns_t; +/** + * Struct that is used in op.h to hold all the function pointers and + * pointers to the corresopnding modules (so that we can properly + * RETAIN/RELEASE them) + */ +typedef struct ompi_op_base_op_stream_fns_1_0_0_t { + ompi_op_base_stream_handler_fn_1_0_0_t fns[OMPI_OP_BASE_TYPE_MAX]; + ompi_op_base_module_t *modules[OMPI_OP_BASE_TYPE_MAX]; +} ompi_op_base_op_stream_fns_1_0_0_t; + +typedef ompi_op_base_op_stream_fns_1_0_0_t ompi_op_base_op_stream_fns_t; + /** * Struct that is used in op.h to hold all the function pointers and * pointers to the corresopnding modules (so that we can properly @@ -416,6 +466,18 @@ typedef struct ompi_op_base_op_3buff_fns_1_0_0_t { typedef ompi_op_base_op_3buff_fns_1_0_0_t 
ompi_op_base_op_3buff_fns_t; +/** + * Struct that is used in op.h to hold all the function pointers and + * pointers to the corresopnding modules (so that we can properly + * RETAIN/RELEASE them) + */ +typedef struct ompi_op_base_op_3buff_stream_fns_1_0_0_t { + ompi_op_base_3buff_stream_handler_fn_1_0_0_t fns[OMPI_OP_BASE_TYPE_MAX]; + ompi_op_base_module_t *modules[OMPI_OP_BASE_TYPE_MAX]; +} ompi_op_base_op_3buff_stream_fns_1_0_0_t; + +typedef ompi_op_base_op_3buff_stream_fns_1_0_0_t ompi_op_base_op_3buff_stream_fns_t; + /* * Macro for use in modules that are of type op v2.0.0 */ diff --git a/ompi/mca/op/rocm/Makefile.am b/ompi/mca/op/rocm/Makefile.am new file mode 100644 index 00000000000..a4d999e25f9 --- /dev/null +++ b/ompi/mca/op/rocm/Makefile.am @@ -0,0 +1,82 @@ +# +# Copyright (c) 2023 The University of Tennessee and The University +# of Tennessee Research Foundation. All rights +# reserved. +# $COPYRIGHT$ +# +# Additional copyrights may follow +# +# $HEADER$ +# + +# This component provides support for offloading reduce ops to ROCM devices. +# +# See https://github.com/open-mpi/ompi/wiki/devel-CreateComponent +# for more details on how to make Open MPI components. + +# First, list all .h and .c sources. It is necessary to list all .h +# files so that they will be picked up in the distribution tarball. + +AM_CPPFLAGS = $(op_rocm_CPPFLAGS) + +dist_ompidata_DATA = help-ompi-mca-op-rocm.txt + +sources = op_rocm_component.c op_rocm.h op_rocm_functions.c op_rocm_impl.h +rocm_sources = op_rocm_impl.hip + +HIPCC = hipcc + +.cpp.l$(OBJEXT): + $(LIBTOOL) $(AM_V_lt) --tag=CC $(AM_LIBTOOLFLAGS) \ + $(LIBTOOLFLAGS) --mode=compile $(HIPCC) -O2 -std=c++17 -fvectorize -prefer-non-pic -Wc,-fPIC,-g -c $< + +# -o $($@.o:.lo) + +# Open MPI components can be compiled two ways: +# +# 1. As a standalone dynamic shared object (DSO), sometimes called a +# dynamically loadable library (DLL). +# +# 2. 
As a static library that is slurped up into the upper-level +# libmpi library (regardless of whether libmpi is a static or dynamic +# library). This is called a "Libtool convenience library". +# +# The component needs to create an output library in this top-level +# component directory, and named either mca__.la (for DSO +# builds) or libmca__.la (for static builds). The OMPI +# build system will have set the +# MCA_BUILD_ompi___DSO AM_CONDITIONAL to indicate +# which way this component should be built. + +if MCA_BUILD_ompi_op_rocm_DSO +component_install = mca_op_rocm.la +else +component_install = +component_noinst = libmca_op_rocm.la +endif + +# Specific information for DSO builds. +# +# The DSO should install itself in $(ompilibdir) (by default, +# $prefix/lib/openmpi). + +mcacomponentdir = $(ompilibdir) +mcacomponent_LTLIBRARIES = $(component_install) +mca_op_rocm_la_SOURCES = $(sources) +mca_op_rocm_la_LIBADD = $(rocm_sources:.cpp=.lo) +mca_op_rocm_la_LDFLAGS = -module -avoid-version $(top_builddir)/ompi/lib@OMPI_LIBMPI_NAME@.la \ + $(op_rocm_LIBS) +EXTRA_mca_op_rocm_la_SOURCES = $(rocm_sources) + +# Specific information for static builds. +# +# Note that we *must* "noinst"; the upper-layer Makefile.am's will +# slurp in the resulting .la library into libmpi. + +noinst_LTLIBRARIES = $(component_noinst) +libmca_op_rocm_la_SOURCES = $(sources) +libmca_op_rocm_la_LIBADD = $(rocm_sources:.cpp=.lo) +libmca_op_rocm_la_LDFLAGS = -module -avoid-version\ + $(op_rocm_LIBS) +EXTRA_libmca_op_rocm_la_SOURCES = $(rocm_sources) + diff --git a/ompi/mca/op/rocm/configure.m4 b/ompi/mca/op/rocm/configure.m4 new file mode 100644 index 00000000000..ffd88698be0 --- /dev/null +++ b/ompi/mca/op/rocm/configure.m4 @@ -0,0 +1,36 @@ +# -*- shell-script -*- +# +# Copyright (c) 2011-2013 NVIDIA Corporation. All rights reserved. +# Copyright (c) 2023 The University of Tennessee and The University +# of Tennessee Research Foundation. All rights +# reserved. 
+# Copyright (c) 2022 Amazon.com, Inc. or its affiliates. +# All Rights reserved. +# $COPYRIGHT$ +# +# Additional copyrights may follow +# +# $HEADER$ +# + +# +# If ROCm support was requested, then build the ROCm support library. +# This code checks makes sure the check was done earlier by the +# opal_check_rocm.m4 code. It also copies the flags and libs under +# opal_rocm_CPPFLAGS, opal_rocm_LDFLAGS, and opal_rocm_LIBS + +AC_DEFUN([MCA_ompi_op_rocm_CONFIG],[ + + AC_CONFIG_FILES([ompi/mca/op/rocm/Makefile]) + + OPAL_CHECK_ROCM([op_rocm]) + + AS_IF([test "x$ROCM_SUPPORT" = "x1"], + [$1], + [$2]) + + AC_SUBST([op_rocm_CPPFLAGS]) + AC_SUBST([op_rocm_LDFLAGS]) + AC_SUBST([op_rocm_LIBS]) + +])dnl diff --git a/ompi/mca/op/rocm/help-ompi-mca-op-rocm.txt b/ompi/mca/op/rocm/help-ompi-mca-op-rocm.txt new file mode 100644 index 00000000000..848afbb663d --- /dev/null +++ b/ompi/mca/op/rocm/help-ompi-mca-op-rocm.txt @@ -0,0 +1,15 @@ +# -*- text -*- +# +# Copyright (c) 2023 The University of Tennessee and The University +# of Tennessee Research Foundation. All rights +# reserved. +# $COPYRIGHT$ +# +# Additional copyrights may follow +# +# $HEADER$ +# +# This is the US/English help file for Open MPI's HIP operator component +# +[HIP call failed] +"HIP call %s failed: %s: %s\n" diff --git a/ompi/mca/op/rocm/op_rocm.h b/ompi/mca/op/rocm/op_rocm.h new file mode 100644 index 00000000000..0dfeabf689b --- /dev/null +++ b/ompi/mca/op/rocm/op_rocm.h @@ -0,0 +1,86 @@ +/* + * Copyright (c) 2019-2023 The University of Tennessee and The University + * of Tennessee Research Foundation. All rights + * reserved. 
+ * $COPYRIGHT$ + * + * Additional copyrights may follow + * + * $HEADER$ + */ + +#ifndef MCA_OP_ROCM_EXPORT_H +#define MCA_OP_ROCM_EXPORT_H + +#include "ompi_config.h" + +#include "ompi/mca/mca.h" +#include "opal/class/opal_object.h" + +#include "ompi/mca/op/op.h" +#include "ompi/runtime/mpiruntime.h" + +#include <hip/hip_runtime.h> +#include <hip/hip_runtime_api.h> + +BEGIN_C_DECLS + + +#define xstr(x) #x +#define str(x) xstr(x) + +#define CHECK(fn, args) \ + do { \ + hipError_t err = fn args; \ + if (err != hipSuccess) { \ + opal_show_help("help-ompi-mca-op-rocm.txt", \ + "HIP call failed", true, \ + str(fn), hipGetErrorName(err), \ + hipGetErrorString(err)); \ + ompi_mpi_abort(MPI_COMM_WORLD, 1); \ + } \ + } while (0) + + +/** + * Derive a struct from the base op component struct, allowing us to + * cache some component-specific information on our well-known + * component struct. + */ +typedef struct { + /** The base op component struct */ + ompi_op_base_component_1_0_0_t super; + +#if 0 + /* a stream on which to schedule kernel calls */ + hipStream_t ro_stream; + hipCtx_t *ro_ctx; +#endif // 0 + int ro_max_num_blocks; + int ro_max_num_threads; + int *ro_max_threads_per_block; + int *ro_max_blocks; + hipDevice_t *ro_devices; + int ro_num_devices; +} ompi_op_rocm_component_t; + +/** + * Globally exported variable. Note that it is a *rocm* component + * (defined above), which has the ompi_op_base_component_t as its + * first member. Hence, the MCA/op framework will find the data that + * it expects in the first memory locations, but then the component + * itself can cache additional information after that that can be used + * by both the component and modules.
+ */ +OMPI_DECLSPEC extern ompi_op_rocm_component_t + mca_op_rocm_component; + +OMPI_DECLSPEC extern +ompi_op_base_stream_handler_fn_t ompi_op_rocm_functions[OMPI_OP_BASE_FORTRAN_OP_MAX][OMPI_OP_BASE_TYPE_MAX]; + +OMPI_DECLSPEC extern +ompi_op_base_3buff_stream_handler_fn_t ompi_op_rocm_3buff_functions[OMPI_OP_BASE_FORTRAN_OP_MAX][OMPI_OP_BASE_TYPE_MAX]; + +END_C_DECLS + +#endif /* MCA_OP_ROCM_EXPORT_H */ diff --git a/ompi/mca/op/rocm/op_rocm_component.c b/ompi/mca/op/rocm/op_rocm_component.c new file mode 100644 index 00000000000..e363bf94385 --- /dev/null +++ b/ompi/mca/op/rocm/op_rocm_component.c @@ -0,0 +1,219 @@ +/* + * Copyright (c) 2019-2023 The University of Tennessee and The University + * of Tennessee Research Foundation. All rights + * reserved. + * Copyright (c) 2020 Research Organization for Information Science + * and Technology (RIST). All rights reserved. + * Copyright (c) 2021 Cisco Systems, Inc. All rights reserved. + * $COPYRIGHT$ + * + * Additional copyrights may follow + * + * $HEADER$ + */ + +/** @file + * + * This is the "rocm" op component source code.
+ * + */ + +#include "ompi_config.h" + +#include "opal/util/printf.h" + +#include "ompi/constants.h" +#include "ompi/op/op.h" +#include "ompi/mca/op/op.h" +#include "ompi/mca/op/base/base.h" +#include "ompi/mca/op/rocm/op_rocm.h" + +#include + +static int rocm_component_open(void); +static int rocm_component_close(void); +static int rocm_component_init_query(bool enable_progress_threads, + bool enable_mpi_thread_multiple); +static struct ompi_op_base_module_1_0_0_t * + rocm_component_op_query(struct ompi_op_t *op, int *priority); +static int rocm_component_register(void); + +ompi_op_rocm_component_t mca_op_rocm_component = { + { + .opc_version = { + OMPI_OP_BASE_VERSION_1_0_0, + + .mca_component_name = "rocm", + MCA_BASE_MAKE_VERSION(component, OMPI_MAJOR_VERSION, OMPI_MINOR_VERSION, + OMPI_RELEASE_VERSION), + .mca_open_component = rocm_component_open, + .mca_close_component = rocm_component_close, + .mca_register_component_params = rocm_component_register, + }, + .opc_data = { + /* The component is checkpoint ready */ + MCA_BASE_METADATA_PARAM_CHECKPOINT + }, + + .opc_init_query = rocm_component_init_query, + .opc_op_query = rocm_component_op_query, + }, + .ro_max_num_blocks = -1, + .ro_max_num_threads = -1, + .ro_max_threads_per_block = NULL, + .ro_max_blocks = NULL, + .ro_devices = NULL, + .ro_num_devices = 0, +}; + +/* + * Component open + */ +static int rocm_component_open(void) +{ + /* We checked the flags during register, so if they are set to + * zero either the architecture is not suitable or the user disabled + * AVX support. + * + * A first level check to see what level of AVX is available on the + * hardware. + * + * Note that if this function returns non-OMPI_SUCCESS, then this + * component won't even be shown in ompi_info output (which is + * probably not what you want). 
+ */ + return OMPI_SUCCESS; +} + +/* + * Component close + */ +static int rocm_component_close(void) +{ + if (mca_op_rocm_component.ro_num_devices > 0) { + //hipStreamDestroy(mca_op_rocm_component.ro_stream); + free(mca_op_rocm_component.ro_max_threads_per_block); + mca_op_rocm_component.ro_max_threads_per_block = NULL; + free(mca_op_rocm_component.ro_max_blocks); + mca_op_rocm_component.ro_max_blocks = NULL; + free(mca_op_rocm_component.ro_devices); + mca_op_rocm_component.ro_devices = NULL; + mca_op_rocm_component.ro_num_devices = 0; + } + + return OMPI_SUCCESS; +} + +/* + * Register MCA params. + */ +static int +rocm_component_register(void) +{ + /* TODO: add mca parameters */ + + /* plain integer parameters take no var enumerator; pass NULL below */ + (void) mca_base_component_var_register(&mca_op_rocm_component.super.opc_version, + "max_num_blocks", + "Maximum number of thread blocks in kernels (-1: device limit)", + MCA_BASE_VAR_TYPE_INT, + NULL, 0, 0, + OPAL_INFO_LVL_4, + MCA_BASE_VAR_SCOPE_LOCAL, + &mca_op_rocm_component.ro_max_num_blocks); + + (void) mca_base_component_var_register(&mca_op_rocm_component.super.opc_version, + "max_num_threads", + "Maximum number of threads per block in kernels (-1: device limit)", + MCA_BASE_VAR_TYPE_INT, + NULL, 0, 0, + OPAL_INFO_LVL_4, + MCA_BASE_VAR_SCOPE_LOCAL, + &mca_op_rocm_component.ro_max_num_threads); + + return OMPI_SUCCESS; +} + + +/* + * Query whether this component wants to be used in this process.
+ */ +static int +rocm_component_init_query(bool enable_progress_threads, + bool enable_mpi_thread_multiple) +{ + int num_devices; + int rc; + int prio_lo, prio_hi; + //memset(&mca_op_rocm_component, 0, sizeof(mca_op_rocm_component)); + hipInit(0); + CHECK(hipGetDeviceCount, (&num_devices)); + mca_op_rocm_component.ro_num_devices = num_devices; + mca_op_rocm_component.ro_devices = (hipDevice_t*)malloc(num_devices*sizeof(hipDevice_t)); + mca_op_rocm_component.ro_max_threads_per_block = (int*)malloc(num_devices*sizeof(int)); + mca_op_rocm_component.ro_max_blocks = (int*)malloc(num_devices*sizeof(int)); + for (int i = 0; i < num_devices; ++i) { + CHECK(hipDeviceGet, (&mca_op_rocm_component.ro_devices[i], i)); + rc = hipDeviceGetAttribute(&mca_op_rocm_component.ro_max_threads_per_block[i], + hipDeviceAttributeMaxBlockDimX, + mca_op_rocm_component.ro_devices[i]); + if (hipSuccess != rc) { + /* fall-back to value that should work on every device */ + mca_op_rocm_component.ro_max_threads_per_block[i] = 512; + } + if (-1 < mca_op_rocm_component.ro_max_num_threads) { + if (mca_op_rocm_component.ro_max_threads_per_block[i] > mca_op_rocm_component.ro_max_num_threads) { + mca_op_rocm_component.ro_max_threads_per_block[i] = mca_op_rocm_component.ro_max_num_threads; + } + } + + rc = hipDeviceGetAttribute(&mca_op_rocm_component.ro_max_blocks[i], + hipDeviceAttributeMaxGridDimX, + mca_op_rocm_component.ro_devices[i]); + if (hipSuccess != rc) { + /* we'll try to max out the blocks */ + mca_op_rocm_component.ro_max_blocks[i] = 512; + } + if (-1 < mca_op_rocm_component.ro_max_num_blocks) { + if (mca_op_rocm_component.ro_max_blocks[i] > mca_op_rocm_component.ro_max_num_blocks) { + mca_op_rocm_component.ro_max_blocks[i] = mca_op_rocm_component.ro_max_num_blocks; + } + } + } + +#if 0 + /* try to create a high-priority stream */ + rc = hipDeviceGetStreamPriorityRange(&prio_lo, &prio_hi); + if (hipSuccess != rc) { + hipStreamCreateWithPriority(&mca_op_rocm_component.ro_stream, 
hipStreamNonBlocking, prio_hi); + } else { + mca_op_rocm_component.ro_stream = 0; + } +#endif // 0 + return OMPI_SUCCESS; +} + +/* + * Query whether this component can be used for a specific op + */ +static struct ompi_op_base_module_1_0_0_t* +rocm_component_op_query(struct ompi_op_t *op, int *priority) +{ + ompi_op_base_module_t *module = NULL; + + module = OBJ_NEW(ompi_op_base_module_t); + module->opm_device_enabled = true; + for (int i = 0; i < OMPI_OP_BASE_TYPE_MAX; ++i) { + module->opm_stream_fns[i] = ompi_op_rocm_functions[op->o_f_to_c_index][i]; + module->opm_3buff_stream_fns[i] = ompi_op_rocm_3buff_functions[op->o_f_to_c_index][i]; + + if( NULL != module->opm_fns[i] ) { + OBJ_RETAIN(module); + } + if( NULL != module->opm_3buff_fns[i] ) { + OBJ_RETAIN(module); + } + } + *priority = 50; + return (ompi_op_base_module_1_0_0_t *) module; +} diff --git a/ompi/mca/op/rocm/op_rocm_functions.c b/ompi/mca/op/rocm/op_rocm_functions.c new file mode 100644 index 00000000000..dc9a08b35db --- /dev/null +++ b/ompi/mca/op/rocm/op_rocm_functions.c @@ -0,0 +1,1908 @@ +/* + * Copyright (c) 2019-2023 The University of Tennessee and The University + * of Tennessee Research Foundation. All rights + * reserved. + * Copyright (c) 2020 Research Organization for Information Science + * and Technology (RIST). All rights reserved. + * $COPYRIGHT$ + * + * Additional copyrights may follow + * + * $HEADER$ + */ + +#include "ompi_config.h" + +#ifdef HAVE_SYS_TYPES_H +#include +#endif +#include "opal/util/output.h" + + +#include "ompi/op/op.h" +#include "ompi/mca/op/op.h" +#include "ompi/mca/op/base/base.h" +#include "ompi/mca/op/rocm/op_rocm.h" +#include "opal/mca/accelerator/accelerator.h" + +#include "ompi/mca/op/rocm/op_rocm.h" +#include "ompi/mca/op/rocm/op_rocm_impl.h" + +/** + * Disable warning about empty macro var-args. + * We use varargs to suppress expansion of typenames + * (e.g., int32_t -> int) which could lead to collisions + * for similar base types. 
*/ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wgnu-zero-variadic-macro-arguments" + +static inline void device_op_pre(const void *orig_source1, + void **source1, + int *source1_device, + const void *orig_source2, + void **source2, + int *source2_device, + void *orig_target, + void **target, + int *target_device, + int count, + struct ompi_datatype_t *dtype, + int *threads_per_block, + int *max_blocks, + int *device, + opal_accelerator_stream_t *stream) +{ + uint64_t target_flags = -1, source1_flags = -1, source2_flags = -1; + int target_rc, source1_rc, source2_rc = -1; + + *target = orig_target; + *source1 = (void*)orig_source1; + if (NULL != orig_source2) { + *source2 = (void*)orig_source2; + } + + if (*device != MCA_ACCELERATOR_NO_DEVICE_ID) { + /* we got the device from the caller, just adjust the output parameters */ + *target_device = *device; + *source1_device = *device; + if (NULL != source2_device) { + *source2_device = *device; + } + } else { + + target_rc = opal_accelerator.check_addr(*target, target_device, &target_flags); + source1_rc = opal_accelerator.check_addr(*source1, source1_device, &source1_flags); + *device = *target_device; + + if (NULL != orig_source2) { + source2_rc = opal_accelerator.check_addr(*source2, source2_device, &source2_flags); + } + + if (0 == target_rc && 0 == source1_rc && 0 == source2_rc) { + /* no buffers are on any device, select device 0 */ + *device = 0; + } else if (*target_device == -1) { + if (*source1_device == -1 && NULL != orig_source2) { + *device = *source2_device; + } else { + *device = *source1_device; + } + } + + if (0 == target_rc || 0 == source1_rc || *target_device != *source1_device) { + size_t nbytes; + ompi_datatype_type_size(dtype, &nbytes); + nbytes *= count; + + if (0 == target_rc) { + // allocate memory on the device for the target buffer + opal_accelerator.mem_alloc_stream(*device, target, nbytes, stream); + CHECK(hipMemcpyHtoDAsync, ((hipDeviceptr_t)*target, orig_target, nbytes, 
*(hipStream_t*)stream->stream)); + *target_device = -1; // mark target device as host + } + + if (0 == source1_rc || *device != *source1_device) { + // allocate memory on the device for the source buffer + opal_accelerator.mem_alloc_stream(*device, source1, nbytes, stream); + if (0 == source1_rc) { + /* copy from host to device */ + CHECK(hipMemcpyHtoDAsync, ((hipDeviceptr_t)*source1, (void*)orig_source1, nbytes, *(hipStream_t*)stream->stream)); + } else { + /* copy from one device to another device */ + /* TODO: does this actually work? Can we enable P2P? */ + CHECK(hipMemcpyDtoDAsync, ((hipDeviceptr_t)*source1, (hipDeviceptr_t)orig_source1, nbytes, *(hipStream_t*)stream->stream)); + } + } + + } + if (NULL != source2_device && *target_device != *source2_device) { + // allocate memory on the device for the source buffer + //printf("allocating source on device %d\n", *device); + size_t nbytes; + ompi_datatype_type_size(dtype, &nbytes); + nbytes *= count; + + opal_accelerator.mem_alloc_stream(*device, source2, nbytes, stream); + if (0 == source2_rc) { + /* copy from host to device */ + //printf("copying source from host to device %d\n", *device); + CHECK(hipMemcpyHtoDAsync, ((hipDeviceptr_t)*source2, (void*)orig_source2, nbytes, *(hipStream_t*)stream->stream)); + } else { + /* copy from one device to another device */ + /* TODO: does this actually work? Can we enable P2P? 
*/ + //printf("attempting cross-device copy for source\n"); + CHECK(hipMemcpyDtoDAsync, ((hipDeviceptr_t)*source2, (hipDeviceptr_t)orig_source2, nbytes, *(hipStream_t*)stream->stream)); + } + } + } + + *threads_per_block = mca_op_rocm_component.ro_max_threads_per_block[*device]; + *max_blocks = mca_op_rocm_component.ro_max_blocks[*device]; + +} + +static inline void device_op_post(void *source1, + int source1_device, + void *source2, + int source2_device, + void *orig_target, + void *target, + int target_device, + int count, + struct ompi_datatype_t *dtype, + int device, + opal_accelerator_stream_t *stream) +{ + if (MCA_ACCELERATOR_NO_DEVICE_ID == target_device) { + + size_t nbytes; + ompi_datatype_type_size(dtype, &nbytes); + nbytes *= count; + + CHECK(hipMemcpyDtoHAsync, (orig_target, (hipDeviceptr_t)target, nbytes, *(hipStream_t *)stream->stream)); + } + + if (MCA_ACCELERATOR_NO_DEVICE_ID == target_device) { + opal_accelerator.mem_release_stream(device, target, stream); + } + if (source1_device != device) { + opal_accelerator.mem_release_stream(device, source1, stream); + } + if (NULL != source2 && source2_device != device) { + opal_accelerator.mem_release_stream(device, source2, stream); + } +} + +#define FUNC(name, type_name, type) \ + static \ + void ompi_op_rocm_2buff_##name##_##type_name(const void *in, void *inout, int *count, \ + struct ompi_datatype_t **dtype, \ + int device, \ + opal_accelerator_stream_t *stream, \ + struct ompi_op_base_module_1_0_0_t *module) __opal_attribute_unused__; \ + static \ + void ompi_op_rocm_2buff_##name##_##type_name(const void *in, void *inout, int *count, \ + struct ompi_datatype_t **dtype, \ + int device, \ + opal_accelerator_stream_t *stream, \ + struct ompi_op_base_module_1_0_0_t *module) { \ + int threads_per_block, max_blocks; \ + int source_device, target_device; \ + type *source, *target; \ + int n = *count; \ + device_op_pre(in, (void**)&source, &source_device, NULL, NULL, NULL, \ + inout, (void**)&target, 
&target_device, \ + n, *dtype, \ + &threads_per_block, &max_blocks, &device, stream); \ + hipStream_t *custream = (hipStream_t*)stream->stream; \ + ompi_op_rocm_2buff_##name##_##type_name##_submit(source, target, n, threads_per_block, max_blocks, *custream);\ + device_op_post(source, source_device, NULL, -1, inout, target, target_device, n, *dtype, device, stream); \ + } + +/* concatenate type_name and type to avoid expansion (e.g., int32_t -> int) */ +#define OP_FUNC(name, type_name, type, ...) FUNC(name, __VA_ARGS__##type_name, __VA_ARGS__##type) + +/* reuse the macro above */ +#define FUNC_FUNC(name, type_name, type, ...) FUNC(name, __VA_ARGS__##type_name, __VA_ARGS__##type) + +/* + * Since all the functions in this file are essentially identical, we + * use a macro to substitute in names and types. The core operation + * in all functions that use this macro is the same. + * + * This macro is for minloc and maxloc + */ +#define LOC_FUNC(name, type_name) FUNC(name, type_name, ompi_op_predefined_##type_name##_t) + +/* Dispatch Fortran types to C types */ +#define FORT_INT_FUNC(name, type_name, type) \ + static \ + void ompi_op_rocm_2buff_##name##_##type_name(const void *in, void *inout, int *count, \ + struct ompi_datatype_t **dtype, \ + int device, \ + opal_accelerator_stream_t *stream, \ + struct ompi_op_base_module_1_0_0_t *module) { \ + \ + _Static_assert(sizeof(type) >= sizeof(int8_t) && sizeof(type) <= sizeof(int64_t), "Unsupported integer size (<1B, >8B)"); \ + switch(sizeof(type)) { \ + case sizeof(int8_t): \ + ompi_op_rocm_2buff_##name##_int8_t(in, inout, count, dtype, device, stream, module); \ + break; \ + case sizeof(int16_t): \ + ompi_op_rocm_2buff_##name##_int16_t(in, inout, count, dtype, device, stream, module); \ + break; \ + case sizeof(int32_t): \ + ompi_op_rocm_2buff_##name##_int32_t(in, inout, count, dtype, device, stream, module); \ + break; \ + case sizeof(int64_t): \ + ompi_op_rocm_2buff_##name##_int64_t(in, inout, count, dtype, device, 
stream, module); \ + break; \ + } \ + } + +/* Dispatch Fortran types to C types */ +#define FORT_FLOAT_FUNC(name, type_name, type) \ + static \ + void ompi_op_rocm_2buff_##name##_##type_name(const void *in, void *inout, int *count, \ + struct ompi_datatype_t **dtype, \ + int device, \ + opal_accelerator_stream_t *stream, \ + struct ompi_op_base_module_1_0_0_t *module) { \ + _Static_assert(sizeof(type) >= sizeof(float) && sizeof(type) <= sizeof(long double), "Unsupported float size (<4B, >8B)"); \ + switch(sizeof(type)) { \ + case sizeof(float): \ + ompi_op_rocm_2buff_##name##_float(in, inout, count, dtype, device, stream, module); \ + break; \ + case sizeof(double): \ + ompi_op_rocm_2buff_##name##_double(in, inout, count, dtype, device, stream, module); \ + break; \ + case sizeof(long double): \ + ompi_op_rocm_2buff_##name##_long_double(in, inout, count, dtype, device, stream, module); \ + break; \ + } \ + } + + +#define FORT_LOC_INT_FUNC(name, type_name, type) \ + static \ + void ompi_op_rocm_2buff_##name##_##type_name(const void *in, void *inout, int *count, \ + struct ompi_datatype_t **dtype, \ + int device, \ + opal_accelerator_stream_t *stream, \ + struct ompi_op_base_module_1_0_0_t *module) { \ + \ + _Static_assert(sizeof(type) >= sizeof(int8_t) && sizeof(type) <= sizeof(int64_t), \ + "Unsupported Fortran integer size"); \ + switch(sizeof(type)) { \ + case sizeof(int8_t): \ + ompi_op_rocm_2buff_##name##_2int8(in, inout, count, dtype, device, stream, module); \ + break; \ + case sizeof(int16_t): \ + ompi_op_rocm_2buff_##name##_2int16(in, inout, count, dtype, device, stream, module); \ + break; \ + case sizeof(int32_t): \ + ompi_op_rocm_2buff_##name##_2int32(in, inout, count, dtype, device, stream, module); \ + break; \ + case sizeof(int64_t): \ + ompi_op_rocm_2buff_##name##_2int64(in, inout, count, dtype, device, stream, module); \ + break; \ + } \ + } + +#define FORT_LOC_FLOAT_FUNC(name, type_name, type) \ + static \ + void 
ompi_op_rocm_2buff_##name##_##type_name(const void *in, void *inout, int *count, \ + struct ompi_datatype_t **dtype, \ + int device, \ + opal_accelerator_stream_t *stream, \ + struct ompi_op_base_module_1_0_0_t *module) { \ + _Static_assert(sizeof(type) >= sizeof(float) && sizeof(type) <= sizeof(double), \ + "Unsupported Fortran float size"); \ + switch(sizeof(type)) { \ + case sizeof(float): \ + ompi_op_rocm_2buff_##name##_2float(in, inout, count, dtype, device, stream, module); \ + break; \ + case sizeof(double): \ + ompi_op_rocm_2buff_##name##_2double(in, inout, count, dtype, device, stream, module); \ + break; \ + } \ + } + + +/************************************************************************* + * Max + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) > (b) ? (a) : (b)) +/* C integer */ +FUNC_FUNC(max, int8_t, int8_t) +FUNC_FUNC(max, uint8_t, uint8_t) +FUNC_FUNC(max, int16_t, int16_t) +FUNC_FUNC(max, uint16_t, uint16_t) +FUNC_FUNC(max, int32_t, int32_t) +FUNC_FUNC(max, uint32_t, uint32_t) +FUNC_FUNC(max, int64_t, int64_t) +FUNC_FUNC(max, uint64_t, uint64_t) +FUNC_FUNC(max, long, long) +FUNC_FUNC(max, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC(max, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC(max, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC(max, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC(max, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC(max, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC(max, fortran_integer16, ompi_fortran_integer16_t) +#endif + +FUNC_FUNC(max, float, float) +FUNC_FUNC(max, double, double) +FUNC_FUNC(max, long_double, long double) +#if 
OMPI_HAVE_FORTRAN_REAL +FORT_FLOAT_FUNC(max, fortran_real, ompi_fortran_real_t) +#endif +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +FORT_FLOAT_FUNC(max, fortran_double_precision, ompi_fortran_double_precision_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL2 +FORT_FLOAT_FUNC(max, fortran_real2, ompi_fortran_real2_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL4 +FORT_FLOAT_FUNC(max, fortran_real4, ompi_fortran_real4_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL8 +FORT_FLOAT_FUNC(max, fortran_real8, ompi_fortran_real8_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL16 && OMPI_REAL16_MATCHES_C +FORT_FLOAT_FUNC(max, fortran_real16, ompi_fortran_real16_t) +#endif + + +/************************************************************************* + * Min + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) < (b) ? (a) : (b)) +/* C integer */ +FUNC_FUNC(min, int8_t, int8_t) +FUNC_FUNC(min, uint8_t, uint8_t) +FUNC_FUNC(min, int16_t, int16_t) +FUNC_FUNC(min, uint16_t, uint16_t) +FUNC_FUNC(min, int32_t, int32_t) +FUNC_FUNC(min, uint32_t, uint32_t) +FUNC_FUNC(min, int64_t, int64_t) +FUNC_FUNC(min, uint64_t, uint64_t) +FUNC_FUNC(min, long, long) +FUNC_FUNC(min, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC(min, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC(min, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC(min, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC(min, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC(min, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC(min, fortran_integer16, ompi_fortran_integer16_t) +#endif + +FUNC_FUNC(min, float, float) +FUNC_FUNC(min, double, double) +FUNC_FUNC(min, long_double, long double) +#if 
OMPI_HAVE_FORTRAN_REAL +FORT_FLOAT_FUNC(min, fortran_real, ompi_fortran_real_t) +#endif +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +FORT_FLOAT_FUNC(min, fortran_double_precision, ompi_fortran_double_precision_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL2 +FORT_FLOAT_FUNC(min, fortran_real2, ompi_fortran_real2_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL4 +FORT_FLOAT_FUNC(min, fortran_real4, ompi_fortran_real4_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL8 +FORT_FLOAT_FUNC(min, fortran_real8, ompi_fortran_real8_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL16 && OMPI_REAL16_MATCHES_C +FORT_FLOAT_FUNC(min, fortran_real16, ompi_fortran_real16_t) +#endif + +/************************************************************************* + * Sum + *************************************************************************/ + +/* C integer */ +OP_FUNC(sum, int8_t, int8_t) +OP_FUNC(sum, uint8_t, uint8_t) +OP_FUNC(sum, int16_t, int16_t) +OP_FUNC(sum, uint16_t, uint16_t) +OP_FUNC(sum, int32_t, int32_t) +OP_FUNC(sum, uint32_t, uint32_t) +OP_FUNC(sum, int64_t, int64_t) +OP_FUNC(sum, uint64_t, uint64_t) +OP_FUNC(sum, long, long) +OP_FUNC(sum, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC(sum, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC(sum, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC(sum, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC(sum, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC(sum, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC(sum, fortran_integer16, ompi_fortran_integer16_t) +#endif + +OP_FUNC(sum, float, float) +OP_FUNC(sum, double, double) +OP_FUNC(sum, long_double, long double) +#if OMPI_HAVE_FORTRAN_REAL +FORT_FLOAT_FUNC(sum, fortran_real, ompi_fortran_real_t) +#endif +#if 
OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +FORT_FLOAT_FUNC(sum, fortran_double_precision, ompi_fortran_double_precision_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL2 +FORT_FLOAT_FUNC(sum, fortran_real2, ompi_fortran_real2_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL4 +FORT_FLOAT_FUNC(sum, fortran_real4, ompi_fortran_real4_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL8 +FORT_FLOAT_FUNC(sum, fortran_real8, ompi_fortran_real8_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL16 && OMPI_REAL16_MATCHES_C +FORT_FLOAT_FUNC(sum, fortran_real16, ompi_fortran_real16_t) +#endif +/* Complex */ +#if 0 +#if defined(HAVE_SHORT_FLOAT__COMPLEX) +OP_FUNC(sum, c_short_float_complex, short float _Complex) +#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) +COMPLEX_SUM_FUNC(c_short_float_complex, opal_short_float_t) +#endif +OP_FUNC(sum, c_long_double_complex, long double _Complex) +#endif // 0 +#undef current_func +#define current_func(a, b) (hipCaddf(a,b)) +FUNC_FUNC(sum, c_float_complex, hipFloatComplex) +#undef current_func +#define current_func(a, b) (hipCadd(a,b)) +FUNC_FUNC(sum, c_double_complex, hipDoubleComplex) + +/************************************************************************* + * Product + *************************************************************************/ + +/* C integer */ +OP_FUNC(prod, int8_t, int8_t) +OP_FUNC(prod, uint8_t, uint8_t) +OP_FUNC(prod, int16_t, int16_t) +OP_FUNC(prod, uint16_t, uint16_t) +OP_FUNC(prod, int32_t, int32_t) +OP_FUNC(prod, uint32_t, uint32_t) +OP_FUNC(prod, int64_t, int64_t) +OP_FUNC(prod, uint64_t, uint64_t) +OP_FUNC(prod, long, long) +OP_FUNC(prod, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC(prod, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC(prod, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC(prod, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC(prod, fortran_integer4, 
ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC(prod, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC(prod, fortran_integer16, ompi_fortran_integer16_t) +#endif +/* Floating point */ + +OP_FUNC(prod, float, float) +OP_FUNC(prod, double, double) +OP_FUNC(prod, long_double, long double) +#if OMPI_HAVE_FORTRAN_REAL +FORT_FLOAT_FUNC(prod, fortran_real, ompi_fortran_real_t) +#endif +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +FORT_FLOAT_FUNC(prod, fortran_double_precision, ompi_fortran_double_precision_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL2 +FORT_FLOAT_FUNC(prod, fortran_real2, ompi_fortran_real2_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL4 +FORT_FLOAT_FUNC(prod, fortran_real4, ompi_fortran_real4_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL8 +FORT_FLOAT_FUNC(prod, fortran_real8, ompi_fortran_real8_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL16 && OMPI_REAL16_MATCHES_C +FORT_FLOAT_FUNC(prod, fortran_real16, ompi_fortran_real16_t) +#endif + +/* Complex */ +#if 0 +#if defined(HAVE_SHORT_FLOAT__COMPLEX) +OP_FUNC(prod, c_short_float_complex, short float _Complex) +#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) +COMPLEX_PROD_FUNC(c_short_float_complex, opal_short_float_t) +#endif +OP_FUNC(prod, c_long_double_complex, long double _Complex) +#endif // 0 +#undef current_func +#define current_func(a, b) (hipCmulf(a,b)) +FUNC_FUNC(prod, c_float_complex, hipFloatComplex) +#undef current_func +#define current_func(a, b) (hipCmul(a,b)) +FUNC_FUNC(prod, c_double_complex, hipDoubleComplex) + +/************************************************************************* + * Logical AND + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) && (b)) +/* C integer */ +FUNC_FUNC(land, int8_t, int8_t) +FUNC_FUNC(land, uint8_t, uint8_t) +FUNC_FUNC(land, int16_t, int16_t) +FUNC_FUNC(land, uint16_t, uint16_t) +FUNC_FUNC(land, int32_t, int32_t) 
+FUNC_FUNC(land, uint32_t, uint32_t) +FUNC_FUNC(land, int64_t, int64_t) +FUNC_FUNC(land, uint64_t, uint64_t) +FUNC_FUNC(land, long, long) +FUNC_FUNC(land, ulong, unsigned long) + +/* Logical */ +#if OMPI_HAVE_FORTRAN_LOGICAL +FORT_INT_FUNC(land, fortran_logical, ompi_fortran_logical_t) +#endif +/* C++ bool */ +FUNC_FUNC(land, bool, bool) + +/************************************************************************* + * Logical OR + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) || (b)) +/* C integer */ +FUNC_FUNC(lor, int8_t, int8_t) +FUNC_FUNC(lor, uint8_t, uint8_t) +FUNC_FUNC(lor, int16_t, int16_t) +FUNC_FUNC(lor, uint16_t, uint16_t) +FUNC_FUNC(lor, int32_t, int32_t) +FUNC_FUNC(lor, uint32_t, uint32_t) +FUNC_FUNC(lor, int64_t, int64_t) +FUNC_FUNC(lor, uint64_t, uint64_t) +FUNC_FUNC(lor, long, long) +FUNC_FUNC(lor, ulong, unsigned long) + +/* Logical */ +#if OMPI_HAVE_FORTRAN_LOGICAL +FORT_INT_FUNC(lor, fortran_logical, ompi_fortran_logical_t) +#endif +/* C++ bool */ +FUNC_FUNC(lor, bool, bool) + +/************************************************************************* + * Logical XOR + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a ? 1 : 0) ^ (b ? 
1: 0)) +/* C integer */ +FUNC_FUNC(lxor, int8_t, int8_t) +FUNC_FUNC(lxor, uint8_t, uint8_t) +FUNC_FUNC(lxor, int16_t, int16_t) +FUNC_FUNC(lxor, uint16_t, uint16_t) +FUNC_FUNC(lxor, int32_t, int32_t) +FUNC_FUNC(lxor, uint32_t, uint32_t) +FUNC_FUNC(lxor, int64_t, int64_t) +FUNC_FUNC(lxor, uint64_t, uint64_t) +FUNC_FUNC(lxor, long, long) +FUNC_FUNC(lxor, ulong, unsigned long) + + +/* Logical */ +#if OMPI_HAVE_FORTRAN_LOGICAL +FORT_INT_FUNC(lxor, fortran_logical, ompi_fortran_logical_t) +#endif +/* C++ bool */ +FUNC_FUNC(lxor, bool, bool) + +/************************************************************************* + * Bitwise AND + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) & (b)) +/* C integer */ +FUNC_FUNC(band, int8_t, int8_t) +FUNC_FUNC(band, uint8_t, uint8_t) +FUNC_FUNC(band, int16_t, int16_t) +FUNC_FUNC(band, uint16_t, uint16_t) +FUNC_FUNC(band, int32_t, int32_t) +FUNC_FUNC(band, uint32_t, uint32_t) +FUNC_FUNC(band, int64_t, int64_t) +FUNC_FUNC(band, uint64_t, uint64_t) +FUNC_FUNC(band, long, long) +FUNC_FUNC(band, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC(band, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC(band, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC(band, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC(band, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC(band, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC(band, fortran_integer16, ompi_fortran_integer16_t) +#endif +/* Byte */ +FUNC_FUNC(band, byte, char) + +/************************************************************************* + * Bitwise OR + 
*************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) | (b)) +/* C integer */ +FUNC_FUNC(bor, int8_t, int8_t) +FUNC_FUNC(bor, uint8_t, uint8_t) +FUNC_FUNC(bor, int16_t, int16_t) +FUNC_FUNC(bor, uint16_t, uint16_t) +FUNC_FUNC(bor, int32_t, int32_t) +FUNC_FUNC(bor, uint32_t, uint32_t) +FUNC_FUNC(bor, int64_t, int64_t) +FUNC_FUNC(bor, uint64_t, uint64_t) +FUNC_FUNC(bor, long, long) +FUNC_FUNC(bor, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC(bor, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC(bor, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC(bor, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC(bor, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC(bor, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC(bor, fortran_integer16, ompi_fortran_integer16_t) +#endif +/* Byte */ +FUNC_FUNC(bor, byte, char) + +/************************************************************************* + * Bitwise XOR + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) ^ (b)) +/* C integer */ +FUNC_FUNC(bxor, int8_t, int8_t) +FUNC_FUNC(bxor, uint8_t, uint8_t) +FUNC_FUNC(bxor, int16_t, int16_t) +FUNC_FUNC(bxor, uint16_t, uint16_t) +FUNC_FUNC(bxor, int32_t, int32_t) +FUNC_FUNC(bxor, uint32_t, uint32_t) +FUNC_FUNC(bxor, int64_t, int64_t) +FUNC_FUNC(bxor, uint64_t, uint64_t) +FUNC_FUNC(bxor, long, long) +FUNC_FUNC(bxor, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC(bxor, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC(bxor, fortran_integer1, 
ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC(bxor, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC(bxor, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC(bxor, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC(bxor, fortran_integer16, ompi_fortran_integer16_t) +#endif +/* Byte */ +FUNC_FUNC(bxor, byte, char) + +/************************************************************************* + * Max location + *************************************************************************/ + +LOC_FUNC(maxloc, float_int) +LOC_FUNC(maxloc, double_int) +LOC_FUNC(maxloc, long_int) +LOC_FUNC(maxloc, 2int) +LOC_FUNC(maxloc, short_int) +LOC_FUNC(maxloc, long_double_int) + +/* Fortran compat types */ +LOC_FUNC(maxloc, 2float) +LOC_FUNC(maxloc, 2double) +LOC_FUNC(maxloc, 2int8) +LOC_FUNC(maxloc, 2int16) +LOC_FUNC(maxloc, 2int32) +LOC_FUNC(maxloc, 2int64) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_LOC_INT_FUNC(maxloc, 2integer, ompi_fortran_integer_t) +#endif +/* Fortran float */ +#if OMPI_HAVE_FORTRAN_REAL +FORT_LOC_FLOAT_FUNC(maxloc, 2real, ompi_fortran_real_t) +#endif +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +FORT_LOC_FLOAT_FUNC(maxloc, 2double_precision, ompi_fortran_double_precision_t) +#endif + + +/************************************************************************* + * Min location + *************************************************************************/ + +LOC_FUNC(minloc, float_int) +LOC_FUNC(minloc, double_int) +LOC_FUNC(minloc, long_int) +LOC_FUNC(minloc, 2int) +LOC_FUNC(minloc, short_int) +LOC_FUNC(minloc, long_double_int) + +/* Fortran compat types */ +LOC_FUNC(minloc, 2float) +LOC_FUNC(minloc, 2double) +LOC_FUNC(minloc, 2int8) +LOC_FUNC(minloc, 2int16) +LOC_FUNC(minloc, 2int32) +LOC_FUNC(minloc, 2int64) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER 
+FORT_LOC_INT_FUNC(minloc, 2integer, ompi_fortran_integer_t) +#endif +/* Fortran float */ +#if OMPI_HAVE_FORTRAN_REAL +FORT_LOC_FLOAT_FUNC(minloc, 2real, ompi_fortran_real_t) +#endif +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +FORT_LOC_FLOAT_FUNC(minloc, 2double_precision, ompi_fortran_double_precision_t) +#endif + +/* + * This is a three buffer (2 input and 1 output) version of the reduction + * routines, needed for some optimizations. + */ +#define FUNC_3BUF(name, type_name, type) \ + static \ + void ompi_op_rocm_3buff_##name##_##type_name(const void *in1, const void *in2, void *out, int *count, \ + struct ompi_datatype_t **dtype, \ + int device, \ + opal_accelerator_stream_t *stream, \ + struct ompi_op_base_module_1_0_0_t *module) { \ + int threads_per_block, max_blocks; \ + int source1_device, source2_device, target_device; \ + type *source1, *source2, *target; \ + int n = *count; \ + device_op_pre(in1, (void**)&source1, &source1_device, \ + in2, (void**)&source2, &source2_device, \ + out, (void**)&target, &target_device, \ + n, *dtype, \ + &threads_per_block, &max_blocks, &device, stream); \ + hipStream_t *hipstream = (hipStream_t*)stream->stream; \ + ompi_op_rocm_3buff_##name##_##type_name##_submit(source1, source2, target, n, threads_per_block, max_blocks, *hipstream);\ + device_op_post(source1, source1_device, source2, source2_device, out, target, target_device, n, *dtype, device, stream);\ + } + + +#define OP_FUNC_3BUF(name, type_name, type, ...) FUNC_3BUF(name, __VA_ARGS__##type_name, __VA_ARGS__##type) + +/* reuse the macro above, no work is actually done so we don't care about the func */ +#define FUNC_FUNC_3BUF(name, type_name, type, ...) FUNC_3BUF(name, __VA_ARGS__##type_name, __VA_ARGS__##type) + +/* + * Since all the functions in this file are essentially identical, we + * use a macro to substitute in names and types. The core operation + * in all functions that use this macro is the same. 
+ * + * This macro is for minloc and maxloc + */ +#define LOC_FUNC_3BUF(name, type_name) FUNC_3BUF(name, type_name, ompi_op_predefined_##type_name##_t) + + +/* Dispatch Fortran types to C types */ +#define FORT_INT_FUNC_3BUF(name, type_name, type) \ + static \ + void ompi_op_rocm_3buff_##name##_##type_name(const void *in1, const void *in2, void *out, int *count, \ + struct ompi_datatype_t **dtype, \ + int device, \ + opal_accelerator_stream_t *stream, \ + struct ompi_op_base_module_1_0_0_t *module) { \ + \ + _Static_assert(sizeof(type) >= sizeof(int8_t) && sizeof(type) <= sizeof(int64_t), "Unsupported integer size (<1B, >8B)"); \ + switch(sizeof(type)) { \ + case sizeof(int8_t): \ + ompi_op_rocm_3buff_##name##_int8_t(in1, in2, out, count, dtype, device, stream, module); \ + break; \ + case sizeof(int16_t): \ + ompi_op_rocm_3buff_##name##_int16_t(in1, in2, out, count, dtype, device, stream, module); \ + break; \ + case sizeof(int32_t): \ + ompi_op_rocm_3buff_##name##_int32_t(in1, in2, out, count, dtype, device, stream, module); \ + break; \ + case sizeof(int64_t): \ + ompi_op_rocm_3buff_##name##_int64_t(in1, in2, out, count, dtype, device, stream, module); \ + break; \ + } \ + } + +/* Dispatch Fortran types to C types */ +#define FORT_FLOAT_FUNC_3BUF(name, type_name, type) \ + static \ + void ompi_op_rocm_3buff_##name##_##type_name(const void *in1, const void *in2, void *out, int *count, \ + struct ompi_datatype_t **dtype, \ + int device, \ + opal_accelerator_stream_t *stream, \ + struct ompi_op_base_module_1_0_0_t *module) { \ + _Static_assert(sizeof(type) >= sizeof(float) && sizeof(type) <= sizeof(long double), "Unsupported float size (<4B, >8B)"); \ + switch(sizeof(type)) { \ + case sizeof(float): \ + ompi_op_rocm_3buff_##name##_float(in1, in2, out, count, dtype, device, stream, module); \ + break; \ + case sizeof(double): \ + ompi_op_rocm_3buff_##name##_double(in1, in2, out, count, dtype, device, stream, module); \ + break; \ + case sizeof(long double): \ + 
ompi_op_rocm_3buff_##name##_long_double(in1, in2, out, count, dtype, device, stream, module); \ + break; \ + } \ + } + +#define FORT_LOC_INT_FUNC_3BUF(name, type_name, type) \ + static \ + void ompi_op_rocm_3buff_##name##_##type_name(const void *in1, const void *in2, void *inout, int *count, \ + struct ompi_datatype_t **dtype, \ + int device, \ + opal_accelerator_stream_t *stream, \ + struct ompi_op_base_module_1_0_0_t *module) { \ + \ + _Static_assert(sizeof(type) >= sizeof(int8_t) && sizeof(type) <= sizeof(int64_t), \ + "Unsupported Fortran integer size"); \ + switch(sizeof(type)) { \ + case sizeof(int8_t): \ + ompi_op_rocm_3buff_##name##_2int8(in1, in2, inout, count, dtype, device, stream, module); \ + break; \ + case sizeof(int16_t): \ + ompi_op_rocm_3buff_##name##_2int16(in1, in2, inout, count, dtype, device, stream, module); \ + break; \ + case sizeof(int32_t): \ + ompi_op_rocm_3buff_##name##_2int32(in1, in2, inout, count, dtype, device, stream, module); \ + break; \ + case sizeof(int64_t): \ + ompi_op_rocm_3buff_##name##_2int64(in1, in2, inout, count, dtype, device, stream, module); \ + break; \ + } \ + } + +#define FORT_LOC_FLOAT_FUNC_3BUF(name, type_name, type) \ + static \ + void ompi_op_rocm_3buff_##name##_##type_name(const void *in1, const void *in2, void *inout, int *count, \ + struct ompi_datatype_t **dtype, \ + int device, \ + opal_accelerator_stream_t *stream, \ + struct ompi_op_base_module_1_0_0_t *module) { \ + _Static_assert(sizeof(type) >= sizeof(float) && sizeof(type) <= sizeof(double), \ + "Unsupported Fortran float size"); \ + switch(sizeof(type)) { \ + case sizeof(float): \ + ompi_op_rocm_3buff_##name##_2float(in1, in2, inout, count, dtype, device, stream, module); \ + break; \ + case sizeof(double): \ + ompi_op_rocm_3buff_##name##_2double(in1, in2, inout, count, dtype, device, stream, module); \ + break; \ + } \ + } + + +/************************************************************************* + * Max + 
*************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) > (b) ? (a) : (b)) +/* C integer */ +FUNC_FUNC_3BUF(max, int8_t, int8_t) +FUNC_FUNC_3BUF(max, uint8_t, uint8_t) +FUNC_FUNC_3BUF(max, int16_t, int16_t) +FUNC_FUNC_3BUF(max, uint16_t, uint16_t) +FUNC_FUNC_3BUF(max, int32_t, int32_t) +FUNC_FUNC_3BUF(max, uint32_t, uint32_t) +FUNC_FUNC_3BUF(max, int64_t, int64_t) +FUNC_FUNC_3BUF(max, uint64_t, uint64_t) +FUNC_FUNC_3BUF(max, long, long) +FUNC_FUNC_3BUF(max, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC_3BUF(max, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC_3BUF(max, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC_3BUF(max, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC_3BUF(max, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC_3BUF(max, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC_3BUF(max, fortran_integer16, ompi_fortran_integer16_t) +#endif +/* Floating point */ +#if 0 +#if defined(HAVE_SHORT_FLOAT) +FUNC_FUNC_3BUF(max, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +FUNC_FUNC_3BUF(max, short_float, opal_short_float_t) +#endif +#endif // 0 +FUNC_FUNC_3BUF(max, float, float) +FUNC_FUNC_3BUF(max, double, double) +FUNC_FUNC_3BUF(max, long_double, long double) +#if OMPI_HAVE_FORTRAN_REAL +FORT_FLOAT_FUNC_3BUF(max, fortran_real, ompi_fortran_real_t) +#endif +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +FORT_FLOAT_FUNC_3BUF(max, fortran_double_precision, ompi_fortran_double_precision_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL2 +FORT_FLOAT_FUNC_3BUF(max, fortran_real2, ompi_fortran_real2_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL4 +FORT_FLOAT_FUNC_3BUF(max, fortran_real4, ompi_fortran_real4_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL8 +FORT_FLOAT_FUNC_3BUF(max, fortran_real8, ompi_fortran_real8_t) +#endif 
+#if OMPI_HAVE_FORTRAN_REAL16 && OMPI_REAL16_MATCHES_C +FORT_FLOAT_FUNC_3BUF(max, fortran_real16, ompi_fortran_real16_t) +#endif + + +/************************************************************************* + * Min + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) < (b) ? (a) : (b)) +/* C integer */ +FUNC_FUNC_3BUF(min, int8_t, int8_t) +FUNC_FUNC_3BUF(min, uint8_t, uint8_t) +FUNC_FUNC_3BUF(min, int16_t, int16_t) +FUNC_FUNC_3BUF(min, uint16_t, uint16_t) +FUNC_FUNC_3BUF(min, int32_t, int32_t) +FUNC_FUNC_3BUF(min, uint32_t, uint32_t) +FUNC_FUNC_3BUF(min, int64_t, int64_t) +FUNC_FUNC_3BUF(min, uint64_t, uint64_t) +FUNC_FUNC_3BUF(min, long, long) +FUNC_FUNC_3BUF(min, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC_3BUF(min, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC_3BUF(min, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC_3BUF(min, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC_3BUF(min, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC_3BUF(min, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC_3BUF(min, fortran_integer16, ompi_fortran_integer16_t) +#endif +/* Floating point */ +#if 0 +#if defined(HAVE_SHORT_FLOAT) +FUNC_FUNC_3BUF(min, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +FUNC_FUNC_3BUF(min, short_float, opal_short_float_t) +#endif +#endif // 0 +FUNC_FUNC_3BUF(min, float, float) +FUNC_FUNC_3BUF(min, double, double) +FUNC_FUNC_3BUF(min, long_double, long double) +#if OMPI_HAVE_FORTRAN_REAL +FORT_FLOAT_FUNC_3BUF(min, fortran_real, ompi_fortran_real_t) +#endif +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +FORT_FLOAT_FUNC_3BUF(min, fortran_double_precision, 
ompi_fortran_double_precision_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL2 +FORT_FLOAT_FUNC_3BUF(min, fortran_real2, ompi_fortran_real2_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL4 +FORT_FLOAT_FUNC_3BUF(min, fortran_real4, ompi_fortran_real4_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL8 +FORT_FLOAT_FUNC_3BUF(min, fortran_real8, ompi_fortran_real8_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL16 && OMPI_REAL16_MATCHES_C +FORT_FLOAT_FUNC_3BUF(min, fortran_real16, ompi_fortran_real16_t) +#endif + +/************************************************************************* + * Sum + *************************************************************************/ + +/* C integer */ +OP_FUNC_3BUF(sum, int8_t, int8_t) +OP_FUNC_3BUF(sum, uint8_t, uint8_t) +OP_FUNC_3BUF(sum, int16_t, int16_t) +OP_FUNC_3BUF(sum, uint16_t, uint16_t) +OP_FUNC_3BUF(sum, int32_t, int32_t) +OP_FUNC_3BUF(sum, uint32_t, uint32_t) +OP_FUNC_3BUF(sum, int64_t, int64_t) +OP_FUNC_3BUF(sum, uint64_t, uint64_t) +OP_FUNC_3BUF(sum, long, long) +OP_FUNC_3BUF(sum, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC_3BUF(sum, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC_3BUF(sum, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC_3BUF(sum, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC_3BUF(sum, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC_3BUF(sum, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC_3BUF(sum, fortran_integer16, ompi_fortran_integer16_t) +#endif +/* Floating point */ +#if 0 +#if defined(HAVE_SHORT_FLOAT) +OP_FUNC_3BUF(sum, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +OP_FUNC_3BUF(sum, short_float, opal_short_float_t) +#endif +#endif // 0 +OP_FUNC_3BUF(sum, float, float) +OP_FUNC_3BUF(sum, double, double) 
+OP_FUNC_3BUF(sum, long_double, long double) +#if OMPI_HAVE_FORTRAN_REAL +FORT_FLOAT_FUNC_3BUF(sum, fortran_real, ompi_fortran_real_t) +#endif +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +FORT_FLOAT_FUNC_3BUF(sum, fortran_double_precision, ompi_fortran_double_precision_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL2 +FORT_FLOAT_FUNC_3BUF(sum, fortran_real2, ompi_fortran_real2_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL4 +FORT_FLOAT_FUNC_3BUF(sum, fortran_real4, ompi_fortran_real4_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL8 +FORT_FLOAT_FUNC_3BUF(sum, fortran_real8, ompi_fortran_real8_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL16 && OMPI_REAL16_MATCHES_C +FORT_FLOAT_FUNC_3BUF(sum, fortran_real16, ompi_fortran_real16_t) +#endif +/* Complex */ +#if 0 +#if defined(HAVE_SHORT_FLOAT__COMPLEX) +OP_FUNC_3BUF(sum, c_short_float_complex, short float _Complex) +#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) +COMPLEX_SUM_FUNC_3BUF(c_short_float_complex, opal_short_float_t) +#endif +OP_FUNC_3BUF(sum, c_long_double_complex, long double _Complex) +#endif // 0 +#undef current_func +#define current_func(a, b) (hipCaddf(a,b)) +FUNC_FUNC_3BUF(sum, c_float_complex, hipFloatComplex) +#undef current_func +#define current_func(a, b) (hipCadd(a,b)) +FUNC_FUNC_3BUF(sum, c_double_complex, hipDoubleComplex) + +/************************************************************************* + * Product + *************************************************************************/ + +/* C integer */ +OP_FUNC_3BUF(prod, int8_t, int8_t) +OP_FUNC_3BUF(prod, uint8_t, uint8_t) +OP_FUNC_3BUF(prod, int16_t, int16_t) +OP_FUNC_3BUF(prod, uint16_t, uint16_t) +OP_FUNC_3BUF(prod, int32_t, int32_t) +OP_FUNC_3BUF(prod, uint32_t, uint32_t) +OP_FUNC_3BUF(prod, int64_t, int64_t) +OP_FUNC_3BUF(prod, uint64_t, uint64_t) +OP_FUNC_3BUF(prod, long, long) +OP_FUNC_3BUF(prod, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC_3BUF(prod, fortran_integer, ompi_fortran_integer_t) +#endif +#if 
OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC_3BUF(prod, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC_3BUF(prod, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC_3BUF(prod, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC_3BUF(prod, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC_3BUF(prod, fortran_integer16, ompi_fortran_integer16_t) +#endif +/* Floating point */ +#if 0 +#if defined(HAVE_SHORT_FLOAT) +FORT_FLOAT_FUNC_3BUF(prod, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +FORT_FLOAT_FUNC_3BUF(prod, short_float, opal_short_float_t) +#endif +#endif // 0 +OP_FUNC_3BUF(prod, float, float) +OP_FUNC_3BUF(prod, double, double) +OP_FUNC_3BUF(prod, long_double, long double) +#if OMPI_HAVE_FORTRAN_REAL +FORT_FLOAT_FUNC_3BUF(prod, fortran_real, ompi_fortran_real_t) +#endif +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +FORT_FLOAT_FUNC_3BUF(prod, fortran_double_precision, ompi_fortran_double_precision_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL2 +FORT_FLOAT_FUNC_3BUF(prod, fortran_real2, ompi_fortran_real2_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL4 +FORT_FLOAT_FUNC_3BUF(prod, fortran_real4, ompi_fortran_real4_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL8 +FORT_FLOAT_FUNC_3BUF(prod, fortran_real8, ompi_fortran_real8_t) +#endif +#if OMPI_HAVE_FORTRAN_REAL16 && OMPI_REAL16_MATCHES_C +FORT_FLOAT_FUNC_3BUF(prod, fortran_real16, ompi_fortran_real16_t) +#endif +/* Complex */ +#if 0 +#if defined(HAVE_SHORT_FLOAT__COMPLEX) +OP_FUNC_3BUF(prod, c_short_float_complex, short float _Complex) +#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) +COMPLEX_PROD_FUNC_3BUF(c_short_float_complex, opal_short_float_t) +#endif +OP_FUNC_3BUF(prod, c_long_double_complex, long double _Complex) +#endif // 0 +#undef current_func +#define current_func(a, b) (hipCmulf(a,b)) +FUNC_FUNC_3BUF(prod, c_float_complex, 
hipFloatComplex) +#undef current_func +#define current_func(a, b) (hipCmul(a,b)) +FUNC_FUNC_3BUF(prod, c_double_complex, hipDoubleComplex) + +/************************************************************************* + * Logical AND + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) && (b)) +/* C integer */ +FUNC_FUNC_3BUF(land, int8_t, int8_t) +FUNC_FUNC_3BUF(land, uint8_t, uint8_t) +FUNC_FUNC_3BUF(land, int16_t, int16_t) +FUNC_FUNC_3BUF(land, uint16_t, uint16_t) +FUNC_FUNC_3BUF(land, int32_t, int32_t) +FUNC_FUNC_3BUF(land, uint32_t, uint32_t) +FUNC_FUNC_3BUF(land, int64_t, int64_t) +FUNC_FUNC_3BUF(land, uint64_t, uint64_t) +FUNC_FUNC_3BUF(land, long, long) +FUNC_FUNC_3BUF(land, ulong, unsigned long) + +/* Logical */ +#if OMPI_HAVE_FORTRAN_LOGICAL +FORT_INT_FUNC_3BUF(land, fortran_logical, ompi_fortran_logical_t) +#endif +/* C++ bool */ +FUNC_FUNC_3BUF(land, bool, bool) + +/************************************************************************* + * Logical OR + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) || (b)) +/* C integer */ +FUNC_FUNC_3BUF(lor, int8_t, int8_t) +FUNC_FUNC_3BUF(lor, uint8_t, uint8_t) +FUNC_FUNC_3BUF(lor, int16_t, int16_t) +FUNC_FUNC_3BUF(lor, uint16_t, uint16_t) +FUNC_FUNC_3BUF(lor, int32_t, int32_t) +FUNC_FUNC_3BUF(lor, uint32_t, uint32_t) +FUNC_FUNC_3BUF(lor, int64_t, int64_t) +FUNC_FUNC_3BUF(lor, uint64_t, uint64_t) +FUNC_FUNC_3BUF(lor, long, long) +FUNC_FUNC_3BUF(lor, ulong, unsigned long) + +/* Logical */ +#if OMPI_HAVE_FORTRAN_LOGICAL +FORT_INT_FUNC_3BUF(lor, fortran_logical, ompi_fortran_logical_t) +#endif +/* C++ bool */ +FUNC_FUNC_3BUF(lor, bool, bool) + +/************************************************************************* + * Logical XOR + *************************************************************************/ + +#undef current_func +#define current_func(a, 
b) ((a ? 1 : 0) ^ (b ? 1: 0)) +/* C integer */ +FUNC_FUNC_3BUF(lxor, int8_t, int8_t) +FUNC_FUNC_3BUF(lxor, uint8_t, uint8_t) +FUNC_FUNC_3BUF(lxor, int16_t, int16_t) +FUNC_FUNC_3BUF(lxor, uint16_t, uint16_t) +FUNC_FUNC_3BUF(lxor, int32_t, int32_t) +FUNC_FUNC_3BUF(lxor, uint32_t, uint32_t) +FUNC_FUNC_3BUF(lxor, int64_t, int64_t) +FUNC_FUNC_3BUF(lxor, uint64_t, uint64_t) +FUNC_FUNC_3BUF(lxor, long, long) +FUNC_FUNC_3BUF(lxor, ulong, unsigned long) + +/* Logical */ +#if OMPI_HAVE_FORTRAN_LOGICAL +FORT_INT_FUNC_3BUF(lxor, fortran_logical, ompi_fortran_logical_t) +#endif +/* C++ bool */ +FUNC_FUNC_3BUF(lxor, bool, bool) + +/************************************************************************* + * Bitwise AND + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) & (b)) +/* C integer */ +FUNC_FUNC_3BUF(band, int8_t, int8_t) +FUNC_FUNC_3BUF(band, uint8_t, uint8_t) +FUNC_FUNC_3BUF(band, int16_t, int16_t) +FUNC_FUNC_3BUF(band, uint16_t, uint16_t) +FUNC_FUNC_3BUF(band, int32_t, int32_t) +FUNC_FUNC_3BUF(band, uint32_t, uint32_t) +FUNC_FUNC_3BUF(band, int64_t, int64_t) +FUNC_FUNC_3BUF(band, uint64_t, uint64_t) +FUNC_FUNC_3BUF(band, long, long) +FUNC_FUNC_3BUF(band, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC_3BUF(band, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC_3BUF(band, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC_3BUF(band, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC_3BUF(band, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC_3BUF(band, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC_3BUF(band, fortran_integer16, ompi_fortran_integer16_t) +#endif +/* Byte */ +FORT_INT_FUNC_3BUF(band, byte, 
char) + +/************************************************************************* + * Bitwise OR + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) | (b)) +/* C integer */ +FUNC_FUNC_3BUF(bor, int8_t, int8_t) +FUNC_FUNC_3BUF(bor, uint8_t, uint8_t) +FUNC_FUNC_3BUF(bor, int16_t, int16_t) +FUNC_FUNC_3BUF(bor, uint16_t, uint16_t) +FUNC_FUNC_3BUF(bor, int32_t, int32_t) +FUNC_FUNC_3BUF(bor, uint32_t, uint32_t) +FUNC_FUNC_3BUF(bor, int64_t, int64_t) +FUNC_FUNC_3BUF(bor, uint64_t, uint64_t) +FUNC_FUNC_3BUF(bor, long, long) +FUNC_FUNC_3BUF(bor, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC_3BUF(bor, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC_3BUF(bor, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC_3BUF(bor, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC_3BUF(bor, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC_3BUF(bor, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC_3BUF(bor, fortran_integer16, ompi_fortran_integer16_t) +#endif +/* Byte */ +FORT_INT_FUNC_3BUF(bor, byte, char) + +/************************************************************************* + * Bitwise XOR + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) ^ (b)) +/* C integer */ +FUNC_FUNC_3BUF(bxor, int8_t, int8_t) +FUNC_FUNC_3BUF(bxor, uint8_t, uint8_t) +FUNC_FUNC_3BUF(bxor, int16_t, int16_t) +FUNC_FUNC_3BUF(bxor, uint16_t, uint16_t) +FUNC_FUNC_3BUF(bxor, int32_t, int32_t) +FUNC_FUNC_3BUF(bxor, uint32_t, uint32_t) +FUNC_FUNC_3BUF(bxor, int64_t, int64_t) +FUNC_FUNC_3BUF(bxor, uint64_t, uint64_t) +FUNC_FUNC_3BUF(bxor, long, long) 
+FUNC_FUNC_3BUF(bxor, ulong, unsigned long) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_INT_FUNC_3BUF(bxor, fortran_integer, ompi_fortran_integer_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +FORT_INT_FUNC_3BUF(bxor, fortran_integer1, ompi_fortran_integer1_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +FORT_INT_FUNC_3BUF(bxor, fortran_integer2, ompi_fortran_integer2_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +FORT_INT_FUNC_3BUF(bxor, fortran_integer4, ompi_fortran_integer4_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +FORT_INT_FUNC_3BUF(bxor, fortran_integer8, ompi_fortran_integer8_t) +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +FORT_INT_FUNC_3BUF(bxor, fortran_integer16, ompi_fortran_integer16_t) +#endif +/* Byte */ +FORT_INT_FUNC_3BUF(bxor, byte, char) + +/************************************************************************* + * Max location + *************************************************************************/ + +LOC_FUNC_3BUF(maxloc, float_int) +LOC_FUNC_3BUF(maxloc, double_int) +LOC_FUNC_3BUF(maxloc, long_int) +LOC_FUNC_3BUF(maxloc, 2int) +LOC_FUNC_3BUF(maxloc, short_int) +LOC_FUNC_3BUF(maxloc, long_double_int) + +/* Fortran compat types */ +LOC_FUNC_3BUF(maxloc, 2float) +LOC_FUNC_3BUF(maxloc, 2double) +LOC_FUNC_3BUF(maxloc, 2int8) +LOC_FUNC_3BUF(maxloc, 2int16) +LOC_FUNC_3BUF(maxloc, 2int32) +LOC_FUNC_3BUF(maxloc, 2int64) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_LOC_INT_FUNC_3BUF(maxloc, 2integer, ompi_fortran_integer_t) +#endif +/* Fortran float */ +#if OMPI_HAVE_FORTRAN_REAL +FORT_LOC_FLOAT_FUNC_3BUF(maxloc, 2real, ompi_fortran_real_t) +#endif +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +FORT_LOC_FLOAT_FUNC_3BUF(maxloc, 2double_precision, ompi_fortran_double_precision_t) +#endif + +/************************************************************************* + * Min location + *************************************************************************/ + +LOC_FUNC_3BUF(minloc, float_int) +LOC_FUNC_3BUF(minloc, double_int) 
+LOC_FUNC_3BUF(minloc, long_int) +LOC_FUNC_3BUF(minloc, 2int) +LOC_FUNC_3BUF(minloc, short_int) +LOC_FUNC_3BUF(minloc, long_double_int) + +/* Fortran compat types */ +LOC_FUNC_3BUF(minloc, 2float) +LOC_FUNC_3BUF(minloc, 2double) +LOC_FUNC_3BUF(minloc, 2int8) +LOC_FUNC_3BUF(minloc, 2int16) +LOC_FUNC_3BUF(minloc, 2int32) +LOC_FUNC_3BUF(minloc, 2int64) + +/* Fortran integer */ +#if OMPI_HAVE_FORTRAN_INTEGER +FORT_LOC_INT_FUNC_3BUF(minloc, 2integer, ompi_fortran_integer_t) +#endif +/* Fortran float */ +#if OMPI_HAVE_FORTRAN_REAL +FORT_LOC_FLOAT_FUNC_3BUF(minloc, 2real, ompi_fortran_real_t) +#endif +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +FORT_LOC_FLOAT_FUNC_3BUF(minloc, 2double_precision, ompi_fortran_double_precision_t) +#endif + + +/* + * Helpful defines, because there's soooo many names! + * + * **NOTE** These #define's used to be strictly ordered but the use of + * designated initializers removed this restrictions. When adding new + * operators ALWAYS use a designated initializer! + */ + +/** C integer ***********************************************************/ +#define C_INTEGER(name, ftype) \ + [OMPI_OP_BASE_TYPE_INT8_T] = ompi_op_rocm_##ftype##_##name##_int8_t, \ + [OMPI_OP_BASE_TYPE_UINT8_T] = ompi_op_rocm_##ftype##_##name##_uint8_t, \ + [OMPI_OP_BASE_TYPE_INT16_T] = ompi_op_rocm_##ftype##_##name##_int16_t, \ + [OMPI_OP_BASE_TYPE_UINT16_T] = ompi_op_rocm_##ftype##_##name##_uint16_t, \ + [OMPI_OP_BASE_TYPE_INT32_T] = ompi_op_rocm_##ftype##_##name##_int32_t, \ + [OMPI_OP_BASE_TYPE_UINT32_T] = ompi_op_rocm_##ftype##_##name##_uint32_t, \ + [OMPI_OP_BASE_TYPE_INT64_T] = ompi_op_rocm_##ftype##_##name##_int64_t, \ + [OMPI_OP_BASE_TYPE_LONG] = ompi_op_rocm_##ftype##_##name##_long, \ + [OMPI_OP_BASE_TYPE_UNSIGNED_LONG] = ompi_op_rocm_##ftype##_##name##_ulong, \ + [OMPI_OP_BASE_TYPE_UINT64_T] = ompi_op_rocm_##ftype##_##name##_uint64_t + +/** All the Fortran integers ********************************************/ + +#if OMPI_HAVE_FORTRAN_INTEGER +#define 
FORTRAN_INTEGER_PLAIN(name, ftype) ompi_op_rocm_##ftype##_##name##_fortran_integer +#else +#define FORTRAN_INTEGER_PLAIN(name, ftype) NULL +#endif +#if OMPI_HAVE_FORTRAN_INTEGER1 +#define FORTRAN_INTEGER1(name, ftype) ompi_op_rocm_##ftype##_##name##_fortran_integer1 +#else +#define FORTRAN_INTEGER1(name, ftype) NULL +#endif +#if OMPI_HAVE_FORTRAN_INTEGER2 +#define FORTRAN_INTEGER2(name, ftype) ompi_op_rocm_##ftype##_##name##_fortran_integer2 +#else +#define FORTRAN_INTEGER2(name, ftype) NULL +#endif +#if OMPI_HAVE_FORTRAN_INTEGER4 +#define FORTRAN_INTEGER4(name, ftype) ompi_op_rocm_##ftype##_##name##_fortran_integer4 +#else +#define FORTRAN_INTEGER4(name, ftype) NULL +#endif +#if OMPI_HAVE_FORTRAN_INTEGER8 +#define FORTRAN_INTEGER8(name, ftype) ompi_op_rocm_##ftype##_##name##_fortran_integer8 +#else +#define FORTRAN_INTEGER8(name, ftype) NULL +#endif +#if OMPI_HAVE_FORTRAN_INTEGER16 +#define FORTRAN_INTEGER16(name, ftype) ompi_op_rocm_##ftype##_##name##_fortran_integer16 +#else +#define FORTRAN_INTEGER16(name, ftype) NULL +#endif + +#define FORTRAN_INTEGER(name, ftype) \ + [OMPI_OP_BASE_TYPE_INTEGER] = FORTRAN_INTEGER_PLAIN(name, ftype), \ + [OMPI_OP_BASE_TYPE_INTEGER1] = FORTRAN_INTEGER1(name, ftype), \ + [OMPI_OP_BASE_TYPE_INTEGER2] = FORTRAN_INTEGER2(name, ftype), \ + [OMPI_OP_BASE_TYPE_INTEGER4] = FORTRAN_INTEGER4(name, ftype), \ + [OMPI_OP_BASE_TYPE_INTEGER8] = FORTRAN_INTEGER8(name, ftype), \ + [OMPI_OP_BASE_TYPE_INTEGER16] = FORTRAN_INTEGER16(name, ftype) + +/** All the Fortran reals ***********************************************/ + +#if OMPI_HAVE_FORTRAN_REAL +#define FLOATING_POINT_FORTRAN_REAL_PLAIN(name, ftype) ompi_op_rocm_##ftype##_##name##_fortran_real +#else +#define FLOATING_POINT_FORTRAN_REAL_PLAIN(name, ftype) NULL +#endif +#if OMPI_HAVE_FORTRAN_REAL2 +#define FLOATING_POINT_FORTRAN_REAL2(name, ftype) ompi_op_rocm_##ftype##_##name##_fortran_real2 +#else +#define FLOATING_POINT_FORTRAN_REAL2(name, ftype) NULL +#endif +#if OMPI_HAVE_FORTRAN_REAL4 
+#define FLOATING_POINT_FORTRAN_REAL4(name, ftype) ompi_op_rocm_##ftype##_##name##_fortran_real4 +#else +#define FLOATING_POINT_FORTRAN_REAL4(name, ftype) NULL +#endif +#if OMPI_HAVE_FORTRAN_REAL8 +#define FLOATING_POINT_FORTRAN_REAL8(name, ftype) ompi_op_rocm_##ftype##_##name##_fortran_real8 +#else +#define FLOATING_POINT_FORTRAN_REAL8(name, ftype) NULL +#endif +/* If: + - we have fortran REAL*16, *and* + - fortran REAL*16 matches the bit representation of the + corresponding C type + Only then do we put in function pointers for REAL*16 reductions. + Otherwise, just put in NULL. */ +#if OMPI_HAVE_FORTRAN_REAL16 && OMPI_REAL16_MATCHES_C +#define FLOATING_POINT_FORTRAN_REAL16(name, ftype) ompi_op_rocm_##ftype##_##name##_fortran_real16 +#else +#define FLOATING_POINT_FORTRAN_REAL16(name, ftype) NULL +#endif + +#define FLOATING_POINT_FORTRAN_REAL(name, ftype) \ + [OMPI_OP_BASE_TYPE_REAL] = FLOATING_POINT_FORTRAN_REAL_PLAIN(name, ftype), \ + [OMPI_OP_BASE_TYPE_REAL2] = FLOATING_POINT_FORTRAN_REAL2(name, ftype), \ + [OMPI_OP_BASE_TYPE_REAL4] = FLOATING_POINT_FORTRAN_REAL4(name, ftype), \ + [OMPI_OP_BASE_TYPE_REAL8] = FLOATING_POINT_FORTRAN_REAL8(name, ftype), \ + [OMPI_OP_BASE_TYPE_REAL16] = FLOATING_POINT_FORTRAN_REAL16(name, ftype) + +/** Fortran double precision ********************************************/ + +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +#define FLOATING_POINT_FORTRAN_DOUBLE_PRECISION(name, ftype) \ + ompi_op_rocm_##ftype##_##name##_fortran_double_precision +#else +#define FLOATING_POINT_FORTRAN_DOUBLE_PRECISION(name, ftype) NULL +#endif + +/** Floating point, including all the Fortran reals *********************/ + +//#if defined(HAVE_SHORT_FLOAT) || defined(HAVE_OPAL_SHORT_FLOAT_T) +//#define SHORT_FLOAT(name, ftype) ompi_op_rocm_##ftype##_##name##_short_float +//#else +#define SHORT_FLOAT(name, ftype) NULL +//#endif +#define FLOAT(name, ftype) ompi_op_rocm_##ftype##_##name##_float +#define DOUBLE(name, ftype) ompi_op_rocm_##ftype##_##name##_double +#define 
LONG_DOUBLE(name, ftype) ompi_op_rocm_##ftype##_##name##_long_double + +#define FLOATING_POINT(name, ftype) \ + [OMPI_OP_BASE_TYPE_SHORT_FLOAT] = SHORT_FLOAT(name, ftype), \ + [OMPI_OP_BASE_TYPE_FLOAT] = FLOAT(name, ftype), \ + [OMPI_OP_BASE_TYPE_DOUBLE] = DOUBLE(name, ftype), \ + FLOATING_POINT_FORTRAN_REAL(name, ftype), \ + [OMPI_OP_BASE_TYPE_DOUBLE_PRECISION] = FLOATING_POINT_FORTRAN_DOUBLE_PRECISION(name, ftype), \ + [OMPI_OP_BASE_TYPE_LONG_DOUBLE] = LONG_DOUBLE(name, ftype) + +/** Fortran logical *****************************************************/ + +#if OMPI_HAVE_FORTRAN_LOGICAL +#define FORTRAN_LOGICAL(name, ftype) \ + ompi_op_rocm_##ftype##_##name##_fortran_logical /* OMPI_OP_ROCM_TYPE_LOGICAL */ +#else +#define FORTRAN_LOGICAL(name, ftype) NULL +#endif + +#define LOGICAL(name, ftype) \ + [OMPI_OP_BASE_TYPE_LOGICAL] = FORTRAN_LOGICAL(name, ftype), \ + [OMPI_OP_BASE_TYPE_BOOL] = ompi_op_rocm_##ftype##_##name##_bool + +/** Complex *****************************************************/ +#if 0 + +#if defined(HAVE_SHORT_FLOAT__COMPLEX) || defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) +#define SHORT_FLOAT_COMPLEX(name, ftype) ompi_op_rocm_##ftype##_##name##_c_short_float_complex +#else +#define SHORT_FLOAT_COMPLEX(name, ftype) NULL +#endif +#define LONG_DOUBLE_COMPLEX(name, ftype) ompi_op_rocm_##ftype##_##name##_c_long_double_complex +#else +#define SHORT_FLOAT_COMPLEX(name, ftype) NULL +#define LONG_DOUBLE_COMPLEX(name, ftype) NULL +#endif // 0 +#define FLOAT_COMPLEX(name, ftype) ompi_op_rocm_##ftype##_##name##_c_float_complex +#define DOUBLE_COMPLEX(name, ftype) ompi_op_rocm_##ftype##_##name##_c_double_complex + +#define COMPLEX(name, ftype) \ + [OMPI_OP_BASE_TYPE_C_SHORT_FLOAT_COMPLEX] = SHORT_FLOAT_COMPLEX(name, ftype), \ + [OMPI_OP_BASE_TYPE_C_FLOAT_COMPLEX] = FLOAT_COMPLEX(name, ftype), \ + [OMPI_OP_BASE_TYPE_C_DOUBLE_COMPLEX] = DOUBLE_COMPLEX(name, ftype), \ + [OMPI_OP_BASE_TYPE_C_LONG_DOUBLE_COMPLEX] = LONG_DOUBLE_COMPLEX(name, ftype) + +/** Byte 
****************************************************************/ + +#define BYTE(name, ftype) \ + [OMPI_OP_BASE_TYPE_BYTE] = ompi_op_rocm_##ftype##_##name##_byte + +/** Fortran complex *****************************************************/ +/** Fortran "2" types ***************************************************/ + +#if OMPI_HAVE_FORTRAN_REAL +#define TWOLOC_FORTRAN_2REAL(name, ftype) ompi_op_rocm_##ftype##_##name##_2real +#else +#define TWOLOC_FORTRAN_2REAL(name, ftype) NULL +#endif + +#if OMPI_HAVE_FORTRAN_DOUBLE_PRECISION +#define TWOLOC_FORTRAN_2DOUBLE_PRECISION(name, ftype) ompi_op_rocm_##ftype##_##name##_2double_precision +#else +#define TWOLOC_FORTRAN_2DOUBLE_PRECISION(name, ftype) NULL +#endif +#if OMPI_HAVE_FORTRAN_INTEGER +#define TWOLOC_FORTRAN_2INTEGER(name, ftype) ompi_op_rocm_##ftype##_##name##_2integer +#else +#define TWOLOC_FORTRAN_2INTEGER(name, ftype) NULL +#endif + +/** All "2" types *******************************************************/ + +#define TWOLOC(name, ftype) \ + [OMPI_OP_BASE_TYPE_2REAL] = TWOLOC_FORTRAN_2REAL(name, ftype), \ + [OMPI_OP_BASE_TYPE_2DOUBLE_PRECISION] = TWOLOC_FORTRAN_2DOUBLE_PRECISION(name, ftype), \ + [OMPI_OP_BASE_TYPE_2INTEGER] = TWOLOC_FORTRAN_2INTEGER(name, ftype), \ + [OMPI_OP_BASE_TYPE_FLOAT_INT] = ompi_op_rocm_##ftype##_##name##_float_int, \ + [OMPI_OP_BASE_TYPE_DOUBLE_INT] = ompi_op_rocm_##ftype##_##name##_double_int, \ + [OMPI_OP_BASE_TYPE_LONG_INT] = ompi_op_rocm_##ftype##_##name##_long_int, \ + [OMPI_OP_BASE_TYPE_2INT] = ompi_op_rocm_##ftype##_##name##_2int, \ + [OMPI_OP_BASE_TYPE_SHORT_INT] = ompi_op_rocm_##ftype##_##name##_short_int, \ + [OMPI_OP_BASE_TYPE_LONG_DOUBLE_INT] = ompi_op_rocm_##ftype##_##name##_long_double_int + +/* + * MPI_OP_NULL + * All types + */ +#define FLAGS_NO_FLOAT \ + (OMPI_OP_FLAGS_INTRINSIC | OMPI_OP_FLAGS_ASSOC | OMPI_OP_FLAGS_COMMUTE) +#define FLAGS \ + (OMPI_OP_FLAGS_INTRINSIC | OMPI_OP_FLAGS_ASSOC | \ + OMPI_OP_FLAGS_FLOAT_ASSOC | OMPI_OP_FLAGS_COMMUTE) + 
+ompi_op_base_stream_handler_fn_t ompi_op_rocm_functions[OMPI_OP_BASE_FORTRAN_OP_MAX][OMPI_OP_BASE_TYPE_MAX] = + { + /* Corresponds to MPI_OP_NULL */ + [OMPI_OP_BASE_FORTRAN_NULL] = { + /* Leaving this empty puts in NULL for all entries */ + NULL, + }, + /* Corresponds to MPI_MAX */ + [OMPI_OP_BASE_FORTRAN_MAX] = { + C_INTEGER(max, 2buff), + FORTRAN_INTEGER(max, 2buff), + FLOATING_POINT(max, 2buff), + }, + /* Corresponds to MPI_MIN */ + [OMPI_OP_BASE_FORTRAN_MIN] = { + C_INTEGER(min, 2buff), + FORTRAN_INTEGER(min, 2buff), + FLOATING_POINT(min, 2buff), + }, + /* Corresponds to MPI_SUM */ + [OMPI_OP_BASE_FORTRAN_SUM] = { + C_INTEGER(sum, 2buff), + FORTRAN_INTEGER(sum, 2buff), + FLOATING_POINT(sum, 2buff), + COMPLEX(sum, 2buff), + }, + /* Corresponds to MPI_PROD */ + [OMPI_OP_BASE_FORTRAN_PROD] = { + C_INTEGER(prod, 2buff), + FORTRAN_INTEGER(prod, 2buff), + FLOATING_POINT(prod, 2buff), + COMPLEX(prod, 2buff), + }, + /* Corresponds to MPI_LAND */ + [OMPI_OP_BASE_FORTRAN_LAND] = { + C_INTEGER(land, 2buff), + LOGICAL(land, 2buff), + }, + /* Corresponds to MPI_BAND */ + [OMPI_OP_BASE_FORTRAN_BAND] = { + C_INTEGER(band, 2buff), + FORTRAN_INTEGER(band, 2buff), + BYTE(band, 2buff), + }, + /* Corresponds to MPI_LOR */ + [OMPI_OP_BASE_FORTRAN_LOR] = { + C_INTEGER(lor, 2buff), + LOGICAL(lor, 2buff), + }, + /* Corresponds to MPI_BOR */ + [OMPI_OP_BASE_FORTRAN_BOR] = { + C_INTEGER(bor, 2buff), + FORTRAN_INTEGER(bor, 2buff), + BYTE(bor, 2buff), + }, + /* Corresponds to MPI_LXOR */ + [OMPI_OP_BASE_FORTRAN_LXOR] = { + C_INTEGER(lxor, 2buff), + LOGICAL(lxor, 2buff), + }, + /* Corresponds to MPI_BXOR */ + [OMPI_OP_BASE_FORTRAN_BXOR] = { + C_INTEGER(bxor, 2buff), + FORTRAN_INTEGER(bxor, 2buff), + BYTE(bxor, 2buff), + }, + /* Corresponds to MPI_MAXLOC */ + [OMPI_OP_BASE_FORTRAN_MAXLOC] = { + TWOLOC(maxloc, 2buff), + }, + /* Corresponds to MPI_MINLOC */ + [OMPI_OP_BASE_FORTRAN_MINLOC] = { + TWOLOC(minloc, 2buff), + }, + /* Corresponds to MPI_REPLACE */ + [OMPI_OP_BASE_FORTRAN_REPLACE] = 
{ + /* (MPI_ACCUMULATE is handled differently than the other + reductions, so just zero out its function + implementations here to ensure that users don't invoke + MPI_REPLACE with any reduction operations other than + ACCUMULATE) */ + NULL, + }, + + }; + +ompi_op_base_3buff_stream_handler_fn_t ompi_op_rocm_3buff_functions[OMPI_OP_BASE_FORTRAN_OP_MAX][OMPI_OP_BASE_TYPE_MAX] = + { + /* Corresponds to MPI_OP_NULL */ + [OMPI_OP_BASE_FORTRAN_NULL] = { + /* Leaving this empty puts in NULL for all entries */ + NULL, + }, + /* Corresponds to MPI_MAX */ + [OMPI_OP_BASE_FORTRAN_MAX] = { + C_INTEGER(max, 3buff), + FORTRAN_INTEGER(max, 3buff), + FLOATING_POINT(max, 3buff), + }, + /* Corresponds to MPI_MIN */ + [OMPI_OP_BASE_FORTRAN_MIN] = { + C_INTEGER(min, 3buff), + FORTRAN_INTEGER(min, 3buff), + FLOATING_POINT(min, 3buff), + }, + /* Corresponds to MPI_SUM */ + [OMPI_OP_BASE_FORTRAN_SUM] = { + C_INTEGER(sum, 3buff), + FORTRAN_INTEGER(sum, 3buff), + FLOATING_POINT(sum, 3buff), + COMPLEX(sum, 3buff), + }, + /* Corresponds to MPI_PROD */ + [OMPI_OP_BASE_FORTRAN_PROD] = { + C_INTEGER(prod, 3buff), + FORTRAN_INTEGER(prod, 3buff), + FLOATING_POINT(prod, 3buff), + COMPLEX(prod, 3buff), + }, + /* Corresponds to MPI_LAND */ + [OMPI_OP_BASE_FORTRAN_LAND] ={ + C_INTEGER(land, 3buff), + LOGICAL(land, 3buff), + }, + /* Corresponds to MPI_BAND */ + [OMPI_OP_BASE_FORTRAN_BAND] = { + C_INTEGER(band, 3buff), + FORTRAN_INTEGER(band, 3buff), + BYTE(band, 3buff), + }, + /* Corresponds to MPI_LOR */ + [OMPI_OP_BASE_FORTRAN_LOR] = { + C_INTEGER(lor, 3buff), + LOGICAL(lor, 3buff), + }, + /* Corresponds to MPI_BOR */ + [OMPI_OP_BASE_FORTRAN_BOR] = { + C_INTEGER(bor, 3buff), + FORTRAN_INTEGER(bor, 3buff), + BYTE(bor, 3buff), + }, + /* Corresponds to MPI_LXOR */ + [OMPI_OP_BASE_FORTRAN_LXOR] = { + C_INTEGER(lxor, 3buff), + LOGICAL(lxor, 3buff), + }, + /* Corresponds to MPI_BXOR */ + [OMPI_OP_BASE_FORTRAN_BXOR] = { + C_INTEGER(bxor, 3buff), + FORTRAN_INTEGER(bxor, 3buff), + BYTE(bxor, 3buff), + }, + 
/* Corresponds to MPI_MAXLOC */
+        [OMPI_OP_BASE_FORTRAN_MAXLOC] = {
+            TWOLOC(maxloc, 3buff),
+        },
+        /* Corresponds to MPI_MINLOC */
+        [OMPI_OP_BASE_FORTRAN_MINLOC] = {
+            TWOLOC(minloc, 3buff),
+        },
+        /* Corresponds to MPI_REPLACE */
+        [OMPI_OP_BASE_FORTRAN_REPLACE] = {
+            /* MPI_ACCUMULATE is handled differently than the other
+               reductions, so just zero out its function
+               implementations here to ensure that users don't invoke
+               MPI_REPLACE with any reduction operations other than
+               ACCUMULATE */
+            NULL,
+        },
+
+    };
diff --git a/ompi/mca/op/rocm/op_rocm_impl.h b/ompi/mca/op/rocm/op_rocm_impl.h
new file mode 100644
index 00000000000..907a19fd4fa
--- /dev/null
+++ b/ompi/mca/op/rocm/op_rocm_impl.h
@@ -0,0 +1,706 @@
+/*
+ * Copyright (c) 2019-2023 The University of Tennessee and The University
+ *                         of Tennessee Research Foundation.  All rights
+ *                         reserved.
+ * Copyright (c) 2020      Research Organization for Information Science
+ *                         and Technology (RIST).  All rights reserved.
+ * $COPYRIGHT$
+ *
+ * Additional copyrights may follow
+ *
+ * $HEADER$
+ */
+
+#include <hip/hip_runtime.h>
+
+#include <stdint.h>
+#include <stdbool.h>
+
+#ifndef BEGIN_C_DECLS
+#if defined(c_plusplus) || defined(__cplusplus)
+#    define BEGIN_C_DECLS extern "C" {
+#    define END_C_DECLS }
+#else
+#    define BEGIN_C_DECLS /* empty */
+#    define END_C_DECLS   /* empty */
+#endif
+#endif
+
+BEGIN_C_DECLS
+
+#define OP_FUNC_SIG(name, type_name, type)                                       \
+    void ompi_op_rocm_2buff_##name##_##type_name##_submit(const type *in,        \
+                                                          type *inout,           \
+                                                          int count,             \
+                                                          int threads_per_block, \
+                                                          int max_blocks,        \
+                                                          hipStream_t stream);
+
+#define FUNC_FUNC_SIG(name, type_name, type)                                     \
+    void ompi_op_rocm_2buff_##name##_##type_name##_submit(const type *in,        \
+                                                          type *inout,           \
+                                                          int count,             \
+                                                          int threads_per_block, \
+                                                          int max_blocks,        \
+                                                          hipStream_t stream);
+
+/*
+ * Since all the functions in this file are essentially identical, we
+ * use a macro to substitute in names and types.  The core operation
+ * in all functions that use this macro is the same.
+ * + * This macro is for minloc and maxloc + */ +#define LOC_STRUCT(type_name, type1, type2) \ + typedef struct { \ + type1 v; \ + type2 k; \ + } ompi_op_predefined_##type_name##_t; + +#define LOC_FUNC_SIG(name, type_name) \ + void ompi_op_rocm_2buff_##name##_##type_name##_submit(const ompi_op_predefined_##type_name##_t *a, \ + ompi_op_predefined_##type_name##_t *b, \ + int count, \ + int threads_per_block, \ + int max_blocks, \ + hipStream_t stream); + +/************************************************************************* + * Max + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_SIG(max, int8_t, int8_t) +FUNC_FUNC_SIG(max, uint8_t, uint8_t) +FUNC_FUNC_SIG(max, int16_t, int16_t) +FUNC_FUNC_SIG(max, uint16_t, uint16_t) +FUNC_FUNC_SIG(max, int32_t, int32_t) +FUNC_FUNC_SIG(max, uint32_t, uint32_t) +FUNC_FUNC_SIG(max, int64_t, int64_t) +FUNC_FUNC_SIG(max, uint64_t, uint64_t) +FUNC_FUNC_SIG(max, long, long) +FUNC_FUNC_SIG(max, ulong, unsigned long) + +#if 0 +/* Floating point */ +#if defined(HAVE_SHORT_FLOAT) +FUNC_FUNC_SIG(max, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +FUNC_FUNC_SIG(max, short_float, opal_short_float_t) +#endif +#endif // 0 + +FUNC_FUNC_SIG(max, float, float) +FUNC_FUNC_SIG(max, double, double) +FUNC_FUNC_SIG(max, long_double, long double) + +/************************************************************************* + * Min + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_SIG(min, int8_t, int8_t) +FUNC_FUNC_SIG(min, uint8_t, uint8_t) +FUNC_FUNC_SIG(min, int16_t, int16_t) +FUNC_FUNC_SIG(min, uint16_t, uint16_t) +FUNC_FUNC_SIG(min, int32_t, int32_t) +FUNC_FUNC_SIG(min, uint32_t, uint32_t) +FUNC_FUNC_SIG(min, int64_t, int64_t) +FUNC_FUNC_SIG(min, uint64_t, uint64_t) +FUNC_FUNC_SIG(min, long, long) +FUNC_FUNC_SIG(min, ulong, unsigned long) + +#if 0 +/* Floating point */ +#if defined(HAVE_SHORT_FLOAT) 
+FUNC_FUNC_SIG(min, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +FUNC_FUNC_SIG(min, short_float, opal_short_float_t) +#endif +#endif // 0 + +FUNC_FUNC_SIG(min, float, float) +FUNC_FUNC_SIG(min, double, double) +FUNC_FUNC_SIG(min, long_double, long double) + +/************************************************************************* + * Sum + *************************************************************************/ + +/* C integer */ +OP_FUNC_SIG(sum, int8_t, int8_t) +OP_FUNC_SIG(sum, uint8_t, uint8_t) +OP_FUNC_SIG(sum, int16_t, int16_t) +OP_FUNC_SIG(sum, uint16_t, uint16_t) +OP_FUNC_SIG(sum, int32_t, int32_t) +OP_FUNC_SIG(sum, uint32_t, uint32_t) +OP_FUNC_SIG(sum, int64_t, int64_t) +OP_FUNC_SIG(sum, uint64_t, uint64_t) +OP_FUNC_SIG(sum, long, long) +OP_FUNC_SIG(sum, ulong, unsigned long) + +#if 0 +/* Floating point */ +#if defined(HAVE_SHORT_FLOAT) +OP_FUNC_SIG(sum, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +OP_FUNC_SIG(sum, short_float, opal_short_float_t) +#endif +#endif // 0 + +OP_FUNC_SIG(sum, float, float) +OP_FUNC_SIG(sum, double, double) +OP_FUNC_SIG(sum, long_double, long double) + +/* Complex */ +#if 0 +#if defined(HAVE_SHORT_FLOAT__COMPLEX) +OP_FUNC_SIG(sum, c_short_float_complex, short float _Complex) +#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) +COMPLEX_SUM_FUNC(c_short_float_complex, opal_short_float_t) +#endif +#endif // 0 +FUNC_FUNC_SIG(sum, c_float_complex, hipFloatComplex) +FUNC_FUNC_SIG(sum, c_double_complex, hipDoubleComplex) +//OP_FUNC_SIG(sum, c_float_complex, float _Complex) +//OP_FUNC_SIG(sum, c_double_complex, double _Complex) +//OP_FUNC_SIG(sum, c_long_double_complex, long double _Complex) + +/************************************************************************* + * Product + *************************************************************************/ + +/* C integer */ +OP_FUNC_SIG(prod, int8_t, int8_t) +OP_FUNC_SIG(prod, uint8_t, uint8_t) +OP_FUNC_SIG(prod, int16_t, int16_t) 
+OP_FUNC_SIG(prod, uint16_t, uint16_t) +OP_FUNC_SIG(prod, int32_t, int32_t) +OP_FUNC_SIG(prod, uint32_t, uint32_t) +OP_FUNC_SIG(prod, int64_t, int64_t) +OP_FUNC_SIG(prod, uint64_t, uint64_t) +OP_FUNC_SIG(prod, long, long) +OP_FUNC_SIG(prod, ulong, unsigned long) + +#if 0 +/* Floating point */ +#if defined(HAVE_SHORT_FLOAT) +OP_FUNC_SIG(prod, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +OP_FUNC_SIG(prod, short_float, opal_short_float_t) +#endif +#endif // 0 + +OP_FUNC_SIG(prod, float, float) +OP_FUNC_SIG(prod, double, double) +OP_FUNC_SIG(prod, long_double, long double) + +/* Complex */ +#if 0 +#if defined(HAVE_SHORT_FLOAT__COMPLEX) +OP_FUNC_SIG(prod, c_short_float_complex, short float _Complex) +#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) +COMPLEX_PROD_FUNC(c_short_float_complex, opal_short_float_t) +#endif +OP_FUNC_SIG(prod, c_long_double_complex, long double _Complex) +#endif // 0 +FUNC_FUNC_SIG(prod, c_float_complex, hipFloatComplex) +FUNC_FUNC_SIG(prod, c_double_complex, hipDoubleComplex) + +/************************************************************************* + * Logical AND + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) && (b)) +/* C integer */ +FUNC_FUNC_SIG(land, int8_t, int8_t) +FUNC_FUNC_SIG(land, uint8_t, uint8_t) +FUNC_FUNC_SIG(land, int16_t, int16_t) +FUNC_FUNC_SIG(land, uint16_t, uint16_t) +FUNC_FUNC_SIG(land, int32_t, int32_t) +FUNC_FUNC_SIG(land, uint32_t, uint32_t) +FUNC_FUNC_SIG(land, int64_t, int64_t) +FUNC_FUNC_SIG(land, uint64_t, uint64_t) +FUNC_FUNC_SIG(land, long, long) +FUNC_FUNC_SIG(land, ulong, unsigned long) + +/* C++ bool */ +FUNC_FUNC_SIG(land, bool, bool) + +/************************************************************************* + * Logical OR + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) || (b)) +/* C integer */ +FUNC_FUNC_SIG(lor, 
int8_t, int8_t) +FUNC_FUNC_SIG(lor, uint8_t, uint8_t) +FUNC_FUNC_SIG(lor, int16_t, int16_t) +FUNC_FUNC_SIG(lor, uint16_t, uint16_t) +FUNC_FUNC_SIG(lor, int32_t, int32_t) +FUNC_FUNC_SIG(lor, uint32_t, uint32_t) +FUNC_FUNC_SIG(lor, int64_t, int64_t) +FUNC_FUNC_SIG(lor, uint64_t, uint64_t) +FUNC_FUNC_SIG(lor, long, long) +FUNC_FUNC_SIG(lor, ulong, unsigned long) + +/* C++ bool */ +FUNC_FUNC_SIG(lor, bool, bool) + +/************************************************************************* + * Logical XOR + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a ? 1 : 0) ^ (b ? 1: 0)) +/* C integer */ +FUNC_FUNC_SIG(lxor, int8_t, int8_t) +FUNC_FUNC_SIG(lxor, uint8_t, uint8_t) +FUNC_FUNC_SIG(lxor, int16_t, int16_t) +FUNC_FUNC_SIG(lxor, uint16_t, uint16_t) +FUNC_FUNC_SIG(lxor, int32_t, int32_t) +FUNC_FUNC_SIG(lxor, uint32_t, uint32_t) +FUNC_FUNC_SIG(lxor, int64_t, int64_t) +FUNC_FUNC_SIG(lxor, uint64_t, uint64_t) +FUNC_FUNC_SIG(lxor, long, long) +FUNC_FUNC_SIG(lxor, ulong, unsigned long) + +/* C++ bool */ +FUNC_FUNC_SIG(lxor, bool, bool) + +/************************************************************************* + * Bitwise AND + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) & (b)) +/* C integer */ +FUNC_FUNC_SIG(band, int8_t, int8_t) +FUNC_FUNC_SIG(band, uint8_t, uint8_t) +FUNC_FUNC_SIG(band, int16_t, int16_t) +FUNC_FUNC_SIG(band, uint16_t, uint16_t) +FUNC_FUNC_SIG(band, int32_t, int32_t) +FUNC_FUNC_SIG(band, uint32_t, uint32_t) +FUNC_FUNC_SIG(band, int64_t, int64_t) +FUNC_FUNC_SIG(band, uint64_t, uint64_t) +FUNC_FUNC_SIG(band, long, long) +FUNC_FUNC_SIG(band, ulong, unsigned long) + +/* Byte */ +FUNC_FUNC_SIG(band, byte, char) + +/************************************************************************* + * Bitwise OR + *************************************************************************/ + +#undef 
current_func +#define current_func(a, b) ((a) | (b)) +/* C integer */ +FUNC_FUNC_SIG(bor, int8_t, int8_t) +FUNC_FUNC_SIG(bor, uint8_t, uint8_t) +FUNC_FUNC_SIG(bor, int16_t, int16_t) +FUNC_FUNC_SIG(bor, uint16_t, uint16_t) +FUNC_FUNC_SIG(bor, int32_t, int32_t) +FUNC_FUNC_SIG(bor, uint32_t, uint32_t) +FUNC_FUNC_SIG(bor, int64_t, int64_t) +FUNC_FUNC_SIG(bor, uint64_t, uint64_t) +FUNC_FUNC_SIG(bor, long, long) +FUNC_FUNC_SIG(bor, ulong, unsigned long) + +/* Byte */ +FUNC_FUNC_SIG(bor, byte, char) + +/************************************************************************* + * Bitwise XOR + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) ^ (b)) +/* C integer */ +FUNC_FUNC_SIG(bxor, int8_t, int8_t) +FUNC_FUNC_SIG(bxor, uint8_t, uint8_t) +FUNC_FUNC_SIG(bxor, int16_t, int16_t) +FUNC_FUNC_SIG(bxor, uint16_t, uint16_t) +FUNC_FUNC_SIG(bxor, int32_t, int32_t) +FUNC_FUNC_SIG(bxor, uint32_t, uint32_t) +FUNC_FUNC_SIG(bxor, int64_t, int64_t) +FUNC_FUNC_SIG(bxor, uint64_t, uint64_t) +FUNC_FUNC_SIG(bxor, long, long) +FUNC_FUNC_SIG(bxor, ulong, unsigned long) + +/* Byte */ +FUNC_FUNC_SIG(bxor, byte, char) + +/************************************************************************* + * Min and max location "pair" datatypes + *************************************************************************/ + +LOC_STRUCT(float_int, float, int) +LOC_STRUCT(double_int, double, int) +LOC_STRUCT(long_int, long, int) +LOC_STRUCT(2int, int, int) +LOC_STRUCT(short_int, short, int) +LOC_STRUCT(long_double_int, long double, int) +LOC_STRUCT(ulong, unsigned long, int) +/* compat types for Fortran */ +LOC_STRUCT(2float, float, float) +LOC_STRUCT(2double, double, double) +LOC_STRUCT(2int8, int8_t, int8_t) +LOC_STRUCT(2int16, int16_t, int16_t) +LOC_STRUCT(2int32, int32_t, int32_t) +LOC_STRUCT(2int64, int64_t, int64_t) + +/************************************************************************* + * Max location + 
*************************************************************************/ + +LOC_FUNC_SIG(maxloc, 2float) +LOC_FUNC_SIG(maxloc, 2double) +LOC_FUNC_SIG(maxloc, 2int8) +LOC_FUNC_SIG(maxloc, 2int16) +LOC_FUNC_SIG(maxloc, 2int32) +LOC_FUNC_SIG(maxloc, 2int64) + +LOC_FUNC_SIG(maxloc, float_int) +LOC_FUNC_SIG(maxloc, double_int) +LOC_FUNC_SIG(maxloc, long_int) +LOC_FUNC_SIG(maxloc, 2int) +LOC_FUNC_SIG(maxloc, short_int) +LOC_FUNC_SIG(maxloc, long_double_int) + +/************************************************************************* + * Min location + *************************************************************************/ + +LOC_FUNC_SIG(minloc, 2float) +LOC_FUNC_SIG(minloc, 2double) +LOC_FUNC_SIG(minloc, 2int8) +LOC_FUNC_SIG(minloc, 2int16) +LOC_FUNC_SIG(minloc, 2int32) +LOC_FUNC_SIG(minloc, 2int64) + +LOC_FUNC_SIG(minloc, float_int) +LOC_FUNC_SIG(minloc, double_int) +LOC_FUNC_SIG(minloc, long_int) +LOC_FUNC_SIG(minloc, 2int) +LOC_FUNC_SIG(minloc, short_int) +LOC_FUNC_SIG(minloc, long_double_int) + + + +#define OP_FUNC_3BUF_SIG(name, type_name, type) \ + void ompi_op_rocm_3buff_##name##_##type_name##_submit(const type *in1, \ + const type *in2, \ + type *inout, \ + int count, \ + int threads_per_block, \ + int max_blocks, \ + hipStream_t stream); + +#define FUNC_FUNC_3BUF_SIG(name, type_name, type) \ + void ompi_op_rocm_3buff_##name##_##type_name##_submit(const type *in1, \ + const type *in2, \ + type *inout, \ + int count, \ + int threads_per_block, \ + int max_blocks, \ + hipStream_t stream); + +#define LOC_FUNC_3BUF_SIG(name, type_name) \ + void ompi_op_rocm_3buff_##name##_##type_name##_submit(const ompi_op_predefined_##type_name##_t *a1, \ + const ompi_op_predefined_##type_name##_t *a2, \ + ompi_op_predefined_##type_name##_t *b, \ + int count, \ + int threads_per_block, \ + int max_blocks, \ + hipStream_t stream); + + +/************************************************************************* + * Max + 
*************************************************************************/ + +/* C integer */ +FUNC_FUNC_3BUF_SIG(max, int8_t, int8_t) +FUNC_FUNC_3BUF_SIG(max, uint8_t, uint8_t) +FUNC_FUNC_3BUF_SIG(max, int16_t, int16_t) +FUNC_FUNC_3BUF_SIG(max, uint16_t, uint16_t) +FUNC_FUNC_3BUF_SIG(max, int32_t, int32_t) +FUNC_FUNC_3BUF_SIG(max, uint32_t, uint32_t) +FUNC_FUNC_3BUF_SIG(max, int64_t, int64_t) +FUNC_FUNC_3BUF_SIG(max, uint64_t, uint64_t) +FUNC_FUNC_3BUF_SIG(max, long, long) +FUNC_FUNC_3BUF_SIG(max, ulong, unsigned long) + +/* Floating point */ +#if 0 +#if defined(HAVE_SHORT_FLOAT) +FUNC_FUNC_3BUF_SIG(max, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +FUNC_FUNC_3BUF_SIG(max, short_float, opal_short_float_t) +#endif +#endif // 0 +FUNC_FUNC_3BUF_SIG(max, float, float) +FUNC_FUNC_3BUF_SIG(max, double, double) +FUNC_FUNC_3BUF_SIG(max, long_double, long double) + +/************************************************************************* + * Min + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_3BUF_SIG(min, int8_t, int8_t) +FUNC_FUNC_3BUF_SIG(min, uint8_t, uint8_t) +FUNC_FUNC_3BUF_SIG(min, int16_t, int16_t) +FUNC_FUNC_3BUF_SIG(min, uint16_t, uint16_t) +FUNC_FUNC_3BUF_SIG(min, int32_t, int32_t) +FUNC_FUNC_3BUF_SIG(min, uint32_t, uint32_t) +FUNC_FUNC_3BUF_SIG(min, int64_t, int64_t) +FUNC_FUNC_3BUF_SIG(min, uint64_t, uint64_t) +FUNC_FUNC_3BUF_SIG(min, long, long) +FUNC_FUNC_3BUF_SIG(min, ulong, unsigned long) + +/* Floating point */ +#if 0 +#if defined(HAVE_SHORT_FLOAT) +FUNC_FUNC_3BUF_SIG(min, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +FUNC_FUNC_3BUF_SIG(min, short_float, opal_short_float_t) +#endif +#endif // 0 +FUNC_FUNC_3BUF_SIG(min, float, float) +FUNC_FUNC_3BUF_SIG(min, double, double) +FUNC_FUNC_3BUF_SIG(min, long_double, long double) + +/************************************************************************* + * Sum + 
*************************************************************************/ + +/* C integer */ +OP_FUNC_3BUF_SIG(sum, int8_t, int8_t) +OP_FUNC_3BUF_SIG(sum, uint8_t, uint8_t) +OP_FUNC_3BUF_SIG(sum, int16_t, int16_t) +OP_FUNC_3BUF_SIG(sum, uint16_t, uint16_t) +OP_FUNC_3BUF_SIG(sum, int32_t, int32_t) +OP_FUNC_3BUF_SIG(sum, uint32_t, uint32_t) +OP_FUNC_3BUF_SIG(sum, int64_t, int64_t) +OP_FUNC_3BUF_SIG(sum, uint64_t, uint64_t) +OP_FUNC_3BUF_SIG(sum, long, long) +OP_FUNC_3BUF_SIG(sum, ulong, unsigned long) + +/* Floating point */ +#if 0 +#if defined(HAVE_SHORT_FLOAT) +OP_FUNC_3BUF_SIG(sum, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +OP_FUNC_3BUF_SIG(sum, short_float, opal_short_float_t) +#endif +#endif // 0 +OP_FUNC_3BUF_SIG(sum, float, float) +OP_FUNC_3BUF_SIG(sum, double, double) +OP_FUNC_3BUF_SIG(sum, long_double, long double) +/* Complex */ +#if 0 +#if defined(HAVE_SHORT_FLOAT__COMPLEX) +OP_FUNC_3BUF_SIG(sum, c_short_float_complex, short float _Complex) +#endif +#endif // 0 +FUNC_FUNC_3BUF_SIG(sum, c_float_complex, hipFloatComplex) +FUNC_FUNC_3BUF_SIG(sum, c_double_complex, hipDoubleComplex) + +/************************************************************************* + * Product + *************************************************************************/ + +/* C integer */ +OP_FUNC_3BUF_SIG(prod, int8_t, int8_t) +OP_FUNC_3BUF_SIG(prod, uint8_t, uint8_t) +OP_FUNC_3BUF_SIG(prod, int16_t, int16_t) +OP_FUNC_3BUF_SIG(prod, uint16_t, uint16_t) +OP_FUNC_3BUF_SIG(prod, int32_t, int32_t) +OP_FUNC_3BUF_SIG(prod, uint32_t, uint32_t) +OP_FUNC_3BUF_SIG(prod, int64_t, int64_t) +OP_FUNC_3BUF_SIG(prod, uint64_t, uint64_t) +OP_FUNC_3BUF_SIG(prod, long, long) +OP_FUNC_3BUF_SIG(prod, ulong, unsigned long) + +/* Floating point */ +#if 0 +#if defined(HAVE_SHORT_FLOAT) +OP_FUNC_3BUF_SIG(prod, short_float, short float) +#endif +#endif // 0 +OP_FUNC_3BUF_SIG(prod, float, float) +OP_FUNC_3BUF_SIG(prod, double, double) +OP_FUNC_3BUF_SIG(prod, long_double, long double) 
+ +/* Complex */ +#if 0 +#if defined(HAVE_SHORT_FLOAT__COMPLEX) +OP_FUNC_3BUF_SIG(prod, c_short_float_complex, short float _Complex) +#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) +COMPLEX_PROD_FUNC_3BUF(c_short_float_complex, opal_short_float_t) +#endif +OP_FUNC_3BUF_SIG(prod, c_float_complex, float _Complex) +OP_FUNC_3BUF_SIG(prod, c_double_complex, double _Complex) +OP_FUNC_3BUF_SIG(prod, c_long_double_complex, long double _Complex) +#if defined(HAVE_SHORT_FLOAT__COMPLEX) +OP_FUNC_3BUF_SIG(prod, c_short_float_complex, short float _Complex) +#endif +#endif // 0 +FUNC_FUNC_3BUF_SIG(prod, c_float_complex, hipFloatComplex) +FUNC_FUNC_3BUF_SIG(prod, c_double_complex, hipDoubleComplex) + +/************************************************************************* + * Logical AND + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_3BUF_SIG(land, int8_t, int8_t) +FUNC_FUNC_3BUF_SIG(land, uint8_t, uint8_t) +FUNC_FUNC_3BUF_SIG(land, int16_t, int16_t) +FUNC_FUNC_3BUF_SIG(land, uint16_t, uint16_t) +FUNC_FUNC_3BUF_SIG(land, int32_t, int32_t) +FUNC_FUNC_3BUF_SIG(land, uint32_t, uint32_t) +FUNC_FUNC_3BUF_SIG(land, int64_t, int64_t) +FUNC_FUNC_3BUF_SIG(land, uint64_t, uint64_t) +FUNC_FUNC_3BUF_SIG(land, long, long) +FUNC_FUNC_3BUF_SIG(land, ulong, unsigned long) + +/* C++ bool */ +FUNC_FUNC_3BUF_SIG(land, bool, bool) + +/************************************************************************* + * Logical OR + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_3BUF_SIG(lor, int8_t, int8_t) +FUNC_FUNC_3BUF_SIG(lor, uint8_t, uint8_t) +FUNC_FUNC_3BUF_SIG(lor, int16_t, int16_t) +FUNC_FUNC_3BUF_SIG(lor, uint16_t, uint16_t) +FUNC_FUNC_3BUF_SIG(lor, int32_t, int32_t) +FUNC_FUNC_3BUF_SIG(lor, uint32_t, uint32_t) +FUNC_FUNC_3BUF_SIG(lor, int64_t, int64_t) +FUNC_FUNC_3BUF_SIG(lor, uint64_t, uint64_t) +FUNC_FUNC_3BUF_SIG(lor, long, long) +FUNC_FUNC_3BUF_SIG(lor, ulong, unsigned 
long) + +/* C++ bool */ +FUNC_FUNC_3BUF_SIG(lor, bool, bool) + +/************************************************************************* + * Logical XOR + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_3BUF_SIG(lxor, int8_t, int8_t) +FUNC_FUNC_3BUF_SIG(lxor, uint8_t, uint8_t) +FUNC_FUNC_3BUF_SIG(lxor, int16_t, int16_t) +FUNC_FUNC_3BUF_SIG(lxor, uint16_t, uint16_t) +FUNC_FUNC_3BUF_SIG(lxor, int32_t, int32_t) +FUNC_FUNC_3BUF_SIG(lxor, uint32_t, uint32_t) +FUNC_FUNC_3BUF_SIG(lxor, int64_t, int64_t) +FUNC_FUNC_3BUF_SIG(lxor, uint64_t, uint64_t) +FUNC_FUNC_3BUF_SIG(lxor, long, long) +FUNC_FUNC_3BUF_SIG(lxor, ulong, unsigned long) + +/* C++ bool */ +FUNC_FUNC_3BUF_SIG(lxor, bool, bool) + +/************************************************************************* + * Bitwise AND + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_3BUF_SIG(band, int8_t, int8_t) +FUNC_FUNC_3BUF_SIG(band, uint8_t, uint8_t) +FUNC_FUNC_3BUF_SIG(band, int16_t, int16_t) +FUNC_FUNC_3BUF_SIG(band, uint16_t, uint16_t) +FUNC_FUNC_3BUF_SIG(band, int32_t, int32_t) +FUNC_FUNC_3BUF_SIG(band, uint32_t, uint32_t) +FUNC_FUNC_3BUF_SIG(band, int64_t, int64_t) +FUNC_FUNC_3BUF_SIG(band, uint64_t, uint64_t) +FUNC_FUNC_3BUF_SIG(band, long, long) +FUNC_FUNC_3BUF_SIG(band, ulong, unsigned long) + +/* Byte */ +FUNC_FUNC_3BUF_SIG(band, byte, char) + +/************************************************************************* + * Bitwise OR + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_3BUF_SIG(bor, int8_t, int8_t) +FUNC_FUNC_3BUF_SIG(bor, uint8_t, uint8_t) +FUNC_FUNC_3BUF_SIG(bor, int16_t, int16_t) +FUNC_FUNC_3BUF_SIG(bor, uint16_t, uint16_t) +FUNC_FUNC_3BUF_SIG(bor, int32_t, int32_t) +FUNC_FUNC_3BUF_SIG(bor, uint32_t, uint32_t) +FUNC_FUNC_3BUF_SIG(bor, int64_t, int64_t) +FUNC_FUNC_3BUF_SIG(bor, uint64_t, uint64_t) +FUNC_FUNC_3BUF_SIG(bor, 
long, long) +FUNC_FUNC_3BUF_SIG(bor, ulong, unsigned long) + +/* Byte */ +FUNC_FUNC_3BUF_SIG(bor, byte, char) + +/************************************************************************* + * Bitwise XOR + *************************************************************************/ + +/* C integer */ +FUNC_FUNC_3BUF_SIG(bxor, int8_t, int8_t) +FUNC_FUNC_3BUF_SIG(bxor, uint8_t, uint8_t) +FUNC_FUNC_3BUF_SIG(bxor, int16_t, int16_t) +FUNC_FUNC_3BUF_SIG(bxor, uint16_t, uint16_t) +FUNC_FUNC_3BUF_SIG(bxor, int32_t, int32_t) +FUNC_FUNC_3BUF_SIG(bxor, uint32_t, uint32_t) +FUNC_FUNC_3BUF_SIG(bxor, int64_t, int64_t) +FUNC_FUNC_3BUF_SIG(bxor, uint64_t, uint64_t) +FUNC_FUNC_3BUF_SIG(bxor, long, long) +FUNC_FUNC_3BUF_SIG(bxor, ulong, unsigned long) + +/* Byte */ +FUNC_FUNC_3BUF_SIG(bxor, byte, char) + +/************************************************************************* + * Max location + *************************************************************************/ + +LOC_FUNC_3BUF_SIG(maxloc, float_int) +LOC_FUNC_3BUF_SIG(maxloc, double_int) +LOC_FUNC_3BUF_SIG(maxloc, long_int) +LOC_FUNC_3BUF_SIG(maxloc, 2int) +LOC_FUNC_3BUF_SIG(maxloc, short_int) +LOC_FUNC_3BUF_SIG(maxloc, long_double_int) + +LOC_FUNC_3BUF_SIG(maxloc, 2float) +LOC_FUNC_3BUF_SIG(maxloc, 2double) +LOC_FUNC_3BUF_SIG(maxloc, 2int8) +LOC_FUNC_3BUF_SIG(maxloc, 2int16) +LOC_FUNC_3BUF_SIG(maxloc, 2int32) +LOC_FUNC_3BUF_SIG(maxloc, 2int64) + +/************************************************************************* + * Min location + *************************************************************************/ + +LOC_FUNC_3BUF_SIG(minloc, float_int) +LOC_FUNC_3BUF_SIG(minloc, double_int) +LOC_FUNC_3BUF_SIG(minloc, long_int) +LOC_FUNC_3BUF_SIG(minloc, 2int) +LOC_FUNC_3BUF_SIG(minloc, short_int) +LOC_FUNC_3BUF_SIG(minloc, long_double_int) + +LOC_FUNC_3BUF_SIG(minloc, 2float) +LOC_FUNC_3BUF_SIG(minloc, 2double) +LOC_FUNC_3BUF_SIG(minloc, 2int8) +LOC_FUNC_3BUF_SIG(minloc, 2int16) +LOC_FUNC_3BUF_SIG(minloc, 2int32) 
+LOC_FUNC_3BUF_SIG(minloc, 2int64)
+
+END_C_DECLS
diff --git a/ompi/mca/op/rocm/op_rocm_impl.hip b/ompi/mca/op/rocm/op_rocm_impl.hip
new file mode 100644
index 00000000000..45a6eee4349
--- /dev/null
+++ b/ompi/mca/op/rocm/op_rocm_impl.hip
@@ -0,0 +1,1085 @@
+/*
+ * Copyright (c) 2019-2023 The University of Tennessee and The University
+ *                         of Tennessee Research Foundation.  All rights
+ *                         reserved.
+ * Copyright (c) 2020      Research Organization for Information Science
+ *                         and Technology (RIST).  All rights reserved.
+ * $COPYRIGHT$
+ *
+ * Additional copyrights may follow
+ *
+ * $HEADER$
+ */
+
+#include "hip/hip_runtime.h"
+#include <stdio.h>
+#include <stdlib.h>
+
+#include <hip/hip_complex.h>
+
+#include "op_rocm_impl.h"
+
+//#define DO_NOT_USE_INTRINSICS 1
+#define USE_VECTORS 1
+
+#include <type_traits>
+
+#define ISSIGNED(x) std::is_signed_v<x>
+
+template <typename T>
+static inline __device__ constexpr T tmax(T a, T b) {
+    return (a > b) ? a : b;
+}
+
+template <typename T>
+static inline __device__ constexpr T tmin(T a, T b) {
+    return (a < b) ? a : b;
+}
+
+template <typename T>
+static inline __device__ constexpr T tsum(T a, T b) {
+    return a+b;
+}
+
+template <typename T>
+static inline __device__ constexpr T tprod(T a, T b) {
+    return a*b;
+}
+
+template <typename T>
+static inline __device__ T vmax(const T& a, const T& b) {
+    return T{tmax(a.x, b.x), tmax(a.y, b.y), tmax(a.z, b.z), tmax(a.w, b.w)};
+}
+
+template <typename T>
+static inline __device__ T vmin(const T& a, const T& b) {
+    return T{tmin(a.x, b.x), tmin(a.y, b.y), tmin(a.z, b.z), tmin(a.w, b.w)};
+}
+
+template <typename T>
+static inline __device__ T vsum(const T& a, const T& b) {
+    return T{tsum(a.x, b.x), tsum(a.y, b.y), tsum(a.z, b.z), tsum(a.w, b.w)};
+}
+
+template <typename T>
+static inline __device__ T vprod(const T& a, const T& b) {
+    return T{(a.x * b.x), (a.y * b.y), (a.z * b.z), (a.w * b.w)};
+}
+
+
+/* TODO: missing support for
+ * - short float (conditional on whether short float is available)
+ * - complex
+ */
+
+#define VECLEN 2
+#define VECTYPE(t) t##VECLEN
+
+#define OP_FUNC(name, type_name, type, op)                                    \
+    static __global__ void                                                    \
+    
ompi_op_rocm_2buff_##name##_##type_name##_kernel(const type *__restrict__ in, \ + type *__restrict__ inout, int n) { \ + const int index = blockIdx.x * blockDim.x + threadIdx.x; \ + const int stride = blockDim.x * gridDim.x; \ + for (int i = index; i < n; i += stride) { \ + inout[i] = inout[i] op in[i]; \ + } \ + } \ + void ompi_op_rocm_2buff_##name##_##type_name##_submit(const type *in, \ + type *inout, \ + int count, \ + int threads_per_block, \ + int max_blocks, \ + hipStream_t stream) { \ + int threads = min(threads_per_block, count); \ + int blocks = min((count + threads-1) / threads, max_blocks); \ + int n = count; \ + hipStream_t s = stream; \ + hipLaunchKernelGGL(ompi_op_rocm_2buff_##name##_##type_name##_kernel, \ + dim3(blocks), dim3(threads), 0, s, \ + in, inout, n); \ + } + +#if defined(USE_VECTORS) +#define OPV_FUNC(name, type_name, type, vtype, vlen, op) \ + static __global__ void \ + ompi_op_rocm_2buff_##name##_##type_name##_kernel(const type *__restrict__ in, \ + type *__restrict__ inout, int n) { \ + const int index = blockIdx.x * blockDim.x + threadIdx.x; \ + const int stride = blockDim.x * gridDim.x; \ + for (int i = index; i < n/vlen; i += stride) { \ + ((vtype*)inout)[i] = ((vtype*)inout)[i] op ((vtype*)in)[i]; \ + } \ + int remainder = n%vlen; \ + if (index == (n/vlen) && remainder != 0) { \ + while(remainder) { \ + int idx = n - remainder--; \ + inout[idx] = inout[idx] op in[idx]; \ + } \ + } \ + } \ + void ompi_op_rocm_2buff_##name##_##type_name##_submit(const type *in, \ + type *inout, \ + int count, \ + int threads_per_block, \ + int max_blocks, \ + hipStream_t stream) { \ + int vcount = (count + vlen-1)/vlen; \ + int threads = min(threads_per_block, vcount); \ + int blocks = min((vcount + threads-1) / threads, max_blocks); \ + int n = count; \ + hipStream_t s = stream; \ + hipLaunchKernelGGL(ompi_op_rocm_2buff_##name##_##type_name##_kernel, \ + dim3(blocks), dim3(threads), 0, s, \ + in, inout, n); \ + } +#else // USE_VECTORS +#define 
OPV_FUNC(name, type_name, type, vtype, vlen, op) OP_FUNC(name, type_name, type, op) +#endif // USE_VECTORS + + +#define FUNC_FUNC(name, type_name, type) \ + static __global__ void \ + ompi_op_rocm_2buff_##name##_##type_name##_kernel(const type *__restrict__ in, \ + type *__restrict__ inout, int n) { \ + const int index = blockIdx.x * blockDim.x + threadIdx.x; \ + const int stride = blockDim.x * gridDim.x; \ + for (int i = index; i < n; i += stride) { \ + inout[i] = current_func(inout[i], in[i]); \ + } \ + } \ + void \ + ompi_op_rocm_2buff_##name##_##type_name##_submit(const type *in, \ + type *inout, \ + int count, \ + int threads_per_block, \ + int max_blocks, \ + hipStream_t stream) { \ + int threads = min(threads_per_block, count); \ + int blocks = min((count + threads-1) / threads, max_blocks); \ + int n = count; \ + hipStream_t s = stream; \ + hipLaunchKernelGGL(ompi_op_rocm_2buff_##name##_##type_name##_kernel, \ + dim3(blocks), dim3(threads), 0, s, \ + in, inout, n); \ + } + + +#if defined(USE_VECTORS) +#define VFUNC_FUNC(name, type_name, type, vtype, vlen, vfn, fn) \ + static __global__ void \ + ompi_op_rocm_2buff_##name##_##type_name##_kernel(const type *__restrict__ in, \ + type *__restrict__ inout, int n) { \ + const int index = blockIdx.x * blockDim.x + threadIdx.x; \ + const int stride = blockDim.x * gridDim.x; \ + for (int i = index; i < n/vlen; i += stride) { \ + ((vtype*)inout)[i] = vfn(((vtype*)inout)[i], ((vtype*)in)[i]); \ + } \ + int remainder = n%vlen; \ + if (index == (n/vlen) && remainder != 0) { \ + while(remainder) { \ + int idx = n - remainder--; \ + inout[idx] = fn(inout[idx], in[idx]); \ + } \ + } \ + } \ + static void \ + ompi_op_rocm_2buff_##name##_##type_name##_submit(const type *in, \ + type *inout, \ + int count, \ + int threads_per_block, \ + int max_blocks, \ + hipStream_t stream) { \ + int vcount = (count + vlen-1)/vlen; \ + int threads = min(threads_per_block, vcount); \ + int blocks = min((vcount + threads-1) / threads, 
max_blocks);         \
+        int n = count;                                                        \
+        hipStream_t s = stream;                                               \
+        ompi_op_rocm_2buff_##name##_##type_name##_kernel<<<blocks, threads, 0, s>>>(in, inout, n); \
+    }
+#else
+#define VFUNC_FUNC(name, type_name, type, vtype, vlen, vfn, fn) FUNC_FUNC_FN(name, type_name, type, fn)
+#endif // defined(USE_VECTORS)
+
+#define FUNC_FUNC_FN(name, type_name, type, fn)                               \
+    static __global__ void                                                    \
+    ompi_op_rocm_2buff_##name##_##type_name##_kernel(const type *__restrict__ in, \
+                                                     type *__restrict__ inout, int n) { \
+        const int index = blockIdx.x * blockDim.x + threadIdx.x;              \
+        const int stride = blockDim.x * gridDim.x;                            \
+        for (int i = index; i < n; i += stride) {                             \
+            inout[i] = fn(inout[i], in[i]);                                   \
+        }                                                                     \
+    }                                                                         \
+    void                                                                      \
+    ompi_op_rocm_2buff_##name##_##type_name##_submit(const type *in,          \
+                                                     type *inout,             \
+                                                     int count,               \
+                                                     int threads_per_block,   \
+                                                     int max_blocks,          \
+                                                     hipStream_t stream) {    \
+        int threads = min(count, threads_per_block);                          \
+        int blocks = min((count + threads-1) / threads, max_blocks);          \
+        int n = count;                                                        \
+        hipStream_t s = stream;                                               \
+        ompi_op_rocm_2buff_##name##_##type_name##_kernel<<<blocks, threads, 0, s>>>(in, inout, n); \
+    }
+
+/*
+ * Since all the functions in this file are essentially identical, we
+ * use a macro to substitute in names and types.  The core operation
+ * in all functions that use this macro is the same.
+ *
+ * This macro is for minloc and maxloc
+ */
+
+#define LOC_FUNC(name, type_name, op)                                         \
+    static __global__ void                                                    \
+    ompi_op_rocm_2buff_##name##_##type_name##_kernel(const ompi_op_predefined_##type_name##_t *__restrict__ in, \
+                                                     ompi_op_predefined_##type_name##_t *__restrict__ inout, \
+                                                     int n)                   \
+    {                                                                         \
+        const int index = blockIdx.x * blockDim.x + threadIdx.x;              \
+        const int stride = blockDim.x * gridDim.x;                            \
+        for (int i = index; i < n; i += stride) {                             \
+            const ompi_op_predefined_##type_name##_t *a = &in[i];             \
+            ompi_op_predefined_##type_name##_t *b = &inout[i];                \
+            if (a->v op b->v) {                                               \
+                b->v = a->v;                                                  \
+                b->k = a->k;                                                  \
+            } else if (a->v == b->v) {                                        \
+                b->k = (b->k < a->k ?
b->k : a->k); \ + } \ + } \ + } \ + void \ + ompi_op_rocm_2buff_##name##_##type_name##_submit(const ompi_op_predefined_##type_name##_t *a, \ + ompi_op_predefined_##type_name##_t *b, \ + int count, \ + int threads_per_block, \ + int max_blocks, \ + hipStream_t stream) { \ + int threads = min(threads_per_block, count); \ + int blocks = min((count + threads-1) / threads, max_blocks); \ + hipStream_t s = stream; \ + hipLaunchKernelGGL(ompi_op_rocm_2buff_##name##_##type_name##_kernel, \ + dim3(blocks), dim3(threads), 0, s, \ + a, b, count); \ + } + + +#define OPV_DISPATCH(name, type_name, type) \ + void ompi_op_rocm_2buff_##name##_##type_name##_submit(const type *in, \ + type *inout, \ + int count, \ + int threads_per_block, \ + int max_blocks, \ + hipStream_t stream) { \ + static_assert(sizeof(type_name) <= sizeof(unsigned long long), "Unknown size type"); \ + if constexpr(!ISSIGNED(type)) { \ + if constexpr(sizeof(type_name) == sizeof(unsigned char)) { \ + ompi_op_rocm_2buff_##name##_uchar_submit((const unsigned char*)in, (unsigned char*)inout, count, \ + threads_per_block, \ + max_blocks, stream); \ + } else if constexpr(sizeof(type_name) == sizeof(unsigned short)) { \ + ompi_op_rocm_2buff_##name##_ushort_submit((const unsigned short*)in, (unsigned short*)inout, count, \ + threads_per_block, \ + max_blocks, stream); \ + } else if constexpr(sizeof(type_name) == sizeof(unsigned int)) { \ + ompi_op_rocm_2buff_##name##_uint_submit((const unsigned int*)in, (unsigned int*)inout, count, \ + threads_per_block, \ + max_blocks, stream); \ + } else if constexpr(sizeof(type_name) == sizeof(unsigned long)) { \ + ompi_op_rocm_2buff_##name##_ulong_submit((const unsigned long*)in, (unsigned long*)inout, count, \ + threads_per_block, \ + max_blocks, stream); \ + } else if constexpr(sizeof(type_name) == sizeof(unsigned long long)) { \ + ompi_op_rocm_2buff_##name##_ulonglong_submit((const unsigned long long*)in, (unsigned long long*)inout, count, \ + threads_per_block, \ + max_blocks, 
stream); \ + } \ + } else { \ + if constexpr(sizeof(type_name) == sizeof(char)) { \ + ompi_op_rocm_2buff_##name##_char_submit((const char*)in, (char*)inout, count, \ + threads_per_block, \ + max_blocks, stream); \ + } else if constexpr(sizeof(type_name) == sizeof(short)) { \ + ompi_op_rocm_2buff_##name##_short_submit((const short*)in, (short*)inout, count, \ + threads_per_block, \ + max_blocks, stream); \ + } else if constexpr(sizeof(type_name) == sizeof(int)) { \ + ompi_op_rocm_2buff_##name##_int_submit((const int*)in, (int*)inout, count, \ + threads_per_block, \ + max_blocks, stream); \ + } else if constexpr(sizeof(type_name) == sizeof(long)) { \ + ompi_op_rocm_2buff_##name##_long_submit((const long*)in, (long*)inout, count, \ + threads_per_block, \ + max_blocks, stream); \ + } else if constexpr(sizeof(type_name) == sizeof(long long)) { \ + ompi_op_rocm_2buff_##name##_longlong_submit((const long long*)in, (long long*)inout, count,\ + threads_per_block, \ + max_blocks, stream); \ + } \ + } \ + } + +/************************************************************************* + * Max + *************************************************************************/ + +/* C integer */ +VFUNC_FUNC(max, char, char, char4, 4, vmax, max) +VFUNC_FUNC(max, uchar, unsigned char, uchar4, 4, vmax, max) +VFUNC_FUNC(max, short, short, short4, 4, vmax, max) +VFUNC_FUNC(max, ushort, unsigned short, ushort4, 4, vmax, max) +VFUNC_FUNC(max, int, int, int4, 4, vmax, max) +VFUNC_FUNC(max, uint, unsigned int, uint4, 4, vmax, max) + +#undef current_func +#define current_func(a, b) max(a, b) +FUNC_FUNC(max, long, long) +FUNC_FUNC(max, ulong, unsigned long) +FUNC_FUNC(max, longlong, long long) +FUNC_FUNC(max, ulonglong, unsigned long long) + + +/* dispatch fixed-size types */ +OPV_DISPATCH(max, int8_t, int8_t) +OPV_DISPATCH(max, uint8_t, uint8_t) +OPV_DISPATCH(max, int16_t, int16_t) +OPV_DISPATCH(max, uint16_t, uint16_t) +OPV_DISPATCH(max, int32_t, int32_t) +OPV_DISPATCH(max, uint32_t, uint32_t) 
+OPV_DISPATCH(max, int64_t, int64_t) +OPV_DISPATCH(max, uint64_t, uint64_t) + +#if !defined(DO_NOT_USE_INTRINSICS) +#undef current_func +#define current_func(a, b) fmaxf(a, b) +#endif // DO_NOT_USE_INTRINSICS +FUNC_FUNC(max, float, float) + +#if !defined(DO_NOT_USE_INTRINSICS) +#undef current_func +#define current_func(a, b) fmax(a, b) +#endif // DO_NOT_USE_INTRINSICS +FUNC_FUNC(max, double, double) + +#undef current_func +#define current_func(a, b) ((a) > (b) ? (a) : (b)) +FUNC_FUNC(max, long_double, long double) + +/************************************************************************* + * Min + *************************************************************************/ + +/* C integer */ +VFUNC_FUNC(min, char, char, char4, 4, vmin, min) +VFUNC_FUNC(min, uchar, unsigned char, uchar4, 4, vmin, min) +VFUNC_FUNC(min, short, short, short4, 4, vmin, min) +VFUNC_FUNC(min, ushort, unsigned short, ushort4, 4, vmin, min) +VFUNC_FUNC(min, int, int, int4, 4, vmin, min) +VFUNC_FUNC(min, uint, unsigned int, uint4, 4, vmin, min) + +#undef current_func +#define current_func(a, b) min(a, b) +FUNC_FUNC(min, long, long) +FUNC_FUNC(min, ulong, unsigned long) +FUNC_FUNC(min, longlong, long long) +FUNC_FUNC(min, ulonglong, unsigned long long) +OPV_DISPATCH(min, int8_t, int8_t) +OPV_DISPATCH(min, uint8_t, uint8_t) +OPV_DISPATCH(min, int16_t, int16_t) +OPV_DISPATCH(min, uint16_t, uint16_t) +OPV_DISPATCH(min, int32_t, int32_t) +OPV_DISPATCH(min, uint32_t, uint32_t) +OPV_DISPATCH(min, int64_t, int64_t) +OPV_DISPATCH(min, uint64_t, uint64_t) + +#if !defined(DO_NOT_USE_INTRINSICS) +#undef current_func +#define current_func(a, b) fminf(a, b) +#endif // DO_NOT_USE_INTRINSICS +FUNC_FUNC(min, float, float) + +#if !defined(DO_NOT_USE_INTRINSICS) +#undef current_func +#define current_func(a, b) fmin(a, b) +#endif // DO_NOT_USE_INTRINSICS +FUNC_FUNC(min, double, double) + +#undef current_func +#define current_func(a, b) ((a) < (b) ? 
(a) : (b)) +FUNC_FUNC(min, long_double, long double) + + +/************************************************************************* + * Sum + *************************************************************************/ + +/* C integer */ +VFUNC_FUNC(sum, char, char, char4, 4, vsum, tsum) +VFUNC_FUNC(sum, uchar, unsigned char, uchar4, 4, vsum, tsum) +VFUNC_FUNC(sum, short, short, short4, 4, vsum, tsum) +VFUNC_FUNC(sum, ushort, unsigned short, ushort4, 4, vsum, tsum) +VFUNC_FUNC(sum, int, int, int4, 4, vsum, tsum) +VFUNC_FUNC(sum, uint, unsigned int, uint4, 4, vsum, tsum) + +#undef current_func +#define current_func(a, b) tsum(a, b) +FUNC_FUNC(sum, long, long) +FUNC_FUNC(sum, ulong, unsigned long) +FUNC_FUNC(sum, longlong, long long) +FUNC_FUNC(sum, ulonglong, unsigned long long) + +OPV_DISPATCH(sum, int8_t, int8_t) +OPV_DISPATCH(sum, uint8_t, uint8_t) +OPV_DISPATCH(sum, int16_t, int16_t) +OPV_DISPATCH(sum, uint16_t, uint16_t) +OPV_DISPATCH(sum, int32_t, int32_t) +OPV_DISPATCH(sum, uint32_t, uint32_t) +OPV_DISPATCH(sum, int64_t, int64_t) +OPV_DISPATCH(sum, uint64_t, uint64_t) + +OPV_FUNC(sum, float, float, float4, 4, +) +OPV_FUNC(sum, double, double, double4, 4, +) +OP_FUNC(sum, long_double, long double, +) + +/* Complex */ +#if 0 +#if defined(HAVE_SHORT_FLOAT__COMPLEX) +OP_FUNC(sum, c_short_float_complex, short float _Complex, +=) +#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) +COMPLEX_SUM_FUNC(c_short_float_complex, opal_short_float_t) +#endif +OP_FUNC(sum, c_long_double_complex, cuLongDoubleComplex, +=) +#endif // 0 +#undef current_func +#define current_func(a, b) (hipCaddf(a,b)) +FUNC_FUNC(sum, c_float_complex, hipFloatComplex) +#undef current_func +#define current_func(a, b) (hipCadd(a,b)) +FUNC_FUNC(sum, c_double_complex, hipDoubleComplex) + +/************************************************************************* + * Product + *************************************************************************/ + +/* C integer */ +#undef current_func +#define 
current_func(a, b) tprod(a, b) +FUNC_FUNC(prod, char, char) +FUNC_FUNC(prod, uchar, unsigned char) +FUNC_FUNC(prod, short, short) +FUNC_FUNC(prod, ushort, unsigned short) +FUNC_FUNC(prod, int, int) +FUNC_FUNC(prod, uint, unsigned int) +FUNC_FUNC(prod, long, long) +FUNC_FUNC(prod, ulong, unsigned long) +FUNC_FUNC(prod, longlong, long long) +FUNC_FUNC(prod, ulonglong, unsigned long long) + +OPV_DISPATCH(prod, int8_t, int8_t) +OPV_DISPATCH(prod, uint8_t, uint8_t) +OPV_DISPATCH(prod, int16_t, int16_t) +OPV_DISPATCH(prod, uint16_t, uint16_t) +OPV_DISPATCH(prod, int32_t, int32_t) +OPV_DISPATCH(prod, uint32_t, uint32_t) +OPV_DISPATCH(prod, int64_t, int64_t) +OPV_DISPATCH(prod, uint64_t, uint64_t) + + +OPV_FUNC(prod, float, float, float4, 4, *) +OPV_FUNC(prod, double, double, double4, 4, *) +OP_FUNC(prod, long_double, long double, *) + +/* Complex */ +#if 0 +#if defined(HAVE_SHORT_FLOAT__COMPLEX) +OP_FUNC(sum, c_short_float_complex, short float _Complex, +=) +#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) +COMPLEX_SUM_FUNC(c_short_float_complex, opal_short_float_t) +#endif +OP_FUNC(sum, c_long_double_complex, cuLongDoubleComplex, +=) +#endif // 0 +#undef current_func +#define current_func(a, b) (hipCmulf(a,b)) +FUNC_FUNC(prod, c_float_complex, hipFloatComplex) +#undef current_func +#define current_func(a, b) (hipCmul(a,b)) +FUNC_FUNC(prod, c_double_complex, hipDoubleComplex) + +/************************************************************************* + * Logical AND + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) && (b)) +/* C integer */ +FUNC_FUNC(land, int8_t, int8_t) +FUNC_FUNC(land, uint8_t, uint8_t) +FUNC_FUNC(land, int16_t, int16_t) +FUNC_FUNC(land, uint16_t, uint16_t) +FUNC_FUNC(land, int32_t, int32_t) +FUNC_FUNC(land, uint32_t, uint32_t) +FUNC_FUNC(land, int64_t, int64_t) +FUNC_FUNC(land, uint64_t, uint64_t) +FUNC_FUNC(land, long, long) +FUNC_FUNC(land, ulong, unsigned long) + +/* 
C++ bool */ +FUNC_FUNC(land, bool, bool) + +/************************************************************************* + * Logical OR + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) || (b)) +/* C integer */ +FUNC_FUNC(lor, int8_t, int8_t) +FUNC_FUNC(lor, uint8_t, uint8_t) +FUNC_FUNC(lor, int16_t, int16_t) +FUNC_FUNC(lor, uint16_t, uint16_t) +FUNC_FUNC(lor, int32_t, int32_t) +FUNC_FUNC(lor, uint32_t, uint32_t) +FUNC_FUNC(lor, int64_t, int64_t) +FUNC_FUNC(lor, uint64_t, uint64_t) +FUNC_FUNC(lor, long, long) +FUNC_FUNC(lor, ulong, unsigned long) + +/* C++ bool */ +FUNC_FUNC(lor, bool, bool) + +/************************************************************************* + * Logical XOR + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a ? 1 : 0) ^ (b ? 1: 0)) +/* C integer */ +FUNC_FUNC(lxor, int8_t, int8_t) +FUNC_FUNC(lxor, uint8_t, uint8_t) +FUNC_FUNC(lxor, int16_t, int16_t) +FUNC_FUNC(lxor, uint16_t, uint16_t) +FUNC_FUNC(lxor, int32_t, int32_t) +FUNC_FUNC(lxor, uint32_t, uint32_t) +FUNC_FUNC(lxor, int64_t, int64_t) +FUNC_FUNC(lxor, uint64_t, uint64_t) +FUNC_FUNC(lxor, long, long) +FUNC_FUNC(lxor, ulong, unsigned long) + +/* C++ bool */ +FUNC_FUNC(lxor, bool, bool) + +/************************************************************************* + * Bitwise AND + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) & (b)) +/* C integer */ +FUNC_FUNC(band, int8_t, int8_t) +FUNC_FUNC(band, uint8_t, uint8_t) +FUNC_FUNC(band, int16_t, int16_t) +FUNC_FUNC(band, uint16_t, uint16_t) +FUNC_FUNC(band, int32_t, int32_t) +FUNC_FUNC(band, uint32_t, uint32_t) +FUNC_FUNC(band, int64_t, int64_t) +FUNC_FUNC(band, uint64_t, uint64_t) +FUNC_FUNC(band, long, long) +FUNC_FUNC(band, ulong, unsigned long) + +/* Byte */ +FUNC_FUNC(band, byte, char) + 
+/************************************************************************* + * Bitwise OR + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) | (b)) +/* C integer */ +FUNC_FUNC(bor, int8_t, int8_t) +FUNC_FUNC(bor, uint8_t, uint8_t) +FUNC_FUNC(bor, int16_t, int16_t) +FUNC_FUNC(bor, uint16_t, uint16_t) +FUNC_FUNC(bor, int32_t, int32_t) +FUNC_FUNC(bor, uint32_t, uint32_t) +FUNC_FUNC(bor, int64_t, int64_t) +FUNC_FUNC(bor, uint64_t, uint64_t) +FUNC_FUNC(bor, long, long) +FUNC_FUNC(bor, ulong, unsigned long) + +/* Byte */ +FUNC_FUNC(bor, byte, char) + +/************************************************************************* + * Bitwise XOR + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) ^ (b)) +/* C integer */ +FUNC_FUNC(bxor, int8_t, int8_t) +FUNC_FUNC(bxor, uint8_t, uint8_t) +FUNC_FUNC(bxor, int16_t, int16_t) +FUNC_FUNC(bxor, uint16_t, uint16_t) +FUNC_FUNC(bxor, int32_t, int32_t) +FUNC_FUNC(bxor, uint32_t, uint32_t) +FUNC_FUNC(bxor, int64_t, int64_t) +FUNC_FUNC(bxor, uint64_t, uint64_t) +FUNC_FUNC(bxor, long, long) +FUNC_FUNC(bxor, ulong, unsigned long) + +/* Byte */ +FUNC_FUNC(bxor, byte, char) + +/************************************************************************* + * Max location + *************************************************************************/ + +LOC_FUNC(maxloc, float_int, >) +LOC_FUNC(maxloc, double_int, >) +LOC_FUNC(maxloc, long_int, >) +LOC_FUNC(maxloc, 2int, >) +LOC_FUNC(maxloc, short_int, >) +LOC_FUNC(maxloc, long_double_int, >) + +/* Fortran compat types */ +LOC_FUNC(maxloc, 2float, >) +LOC_FUNC(maxloc, 2double, >) +LOC_FUNC(maxloc, 2int8, >) +LOC_FUNC(maxloc, 2int16, >) +LOC_FUNC(maxloc, 2int32, >) +LOC_FUNC(maxloc, 2int64, >) + +/************************************************************************* + * Min location + 
*************************************************************************/ + +LOC_FUNC(minloc, float_int, <) +LOC_FUNC(minloc, double_int, <) +LOC_FUNC(minloc, long_int, <) +LOC_FUNC(minloc, 2int, <) +LOC_FUNC(minloc, short_int, <) +LOC_FUNC(minloc, long_double_int, <) + +/* Fortran compat types */ +LOC_FUNC(minloc, 2float, <) +LOC_FUNC(minloc, 2double, <) +LOC_FUNC(minloc, 2int8, <) +LOC_FUNC(minloc, 2int16, <) +LOC_FUNC(minloc, 2int32, <) +LOC_FUNC(minloc, 2int64, <) + + +/* + * This is a three buffer (2 input and 1 output) version of the reduction + * routines, needed for some optimizations. + */ +#define OP_FUNC_3BUF(name, type_name, type, op) \ + static __global__ void \ + ompi_op_rocm_3buff_##name##_##type_name##_kernel(const type *__restrict__ in1, \ + const type *__restrict__ in2, \ + type *__restrict__ out, int n) { \ + const int index = blockIdx.x * blockDim.x + threadIdx.x; \ + const int stride = blockDim.x * gridDim.x; \ + for (int i = index; i < n; i += stride) { \ + out[i] = in1[i] op in2[i]; \ + } \ + } \ + void ompi_op_rocm_3buff_##name##_##type_name##_submit(const type *in1, const type *in2, \ + type *out, int count, \ + int threads_per_block, \ + int max_blocks, \ + hipStream_t stream) { \ + int threads = min(threads_per_block, count); \ + int blocks = min((count + threads-1) / threads, max_blocks); \ + hipLaunchKernelGGL(ompi_op_rocm_3buff_##name##_##type_name##_kernel, \ + dim3(blocks), dim3(threads), 0, stream, \ + in1, in2, out, count); \ + } + + +/* + * Since all the functions in this file are essentially identical, we + * use a macro to substitute in names and types. The core operation + * in all functions that use this macro is the same. 
+ * + * This macro is for (out = op(in1, in2)) + */ +#define FUNC_FUNC_3BUF(name, type_name, type) \ + static __global__ void \ + ompi_op_rocm_3buff_##name##_##type_name##_kernel(const type *__restrict__ in1, \ + const type *__restrict__ in2, \ + type *__restrict__ out, int n) { \ + const int index = blockIdx.x * blockDim.x + threadIdx.x; \ + const int stride = blockDim.x * gridDim.x; \ + for (int i = index; i < n; i += stride) { \ + out[i] = current_func(in1[i], in2[i]); \ + } \ + } \ + void \ + ompi_op_rocm_3buff_##name##_##type_name##_submit(const type *in1, const type *in2, \ + type *out, int count, \ + int threads_per_block, \ + int max_blocks, \ + hipStream_t stream) { \ + int threads = min(threads_per_block, count); \ + int blocks = min((count + threads-1) / threads, max_blocks); \ + hipLaunchKernelGGL(ompi_op_rocm_3buff_##name##_##type_name##_kernel, \ + dim3(blocks), dim3(threads), 0, stream, \ + in1, in2, out, count); \ + } + +/* + * Since all the functions in this file are essentially identical, we + * use a macro to substitute in names and types. The core operation + * in all functions that use this macro is the same. 
+ * + * This macro is for minloc and maxloc + */ +/* +#define LOC_STRUCT(type_name, type1, type2) \ + typedef struct { \ + type1 v; \ + type2 k; \ + } ompi_op_predefined_##type_name##_t; +*/ + +#define LOC_FUNC_3BUF(name, type_name, op) \ + static __global__ void \ + ompi_op_rocm_3buff_##name##_##type_name##_kernel(const ompi_op_predefined_##type_name##_t *__restrict__ in1, \ + const ompi_op_predefined_##type_name##_t *__restrict__ in2, \ + ompi_op_predefined_##type_name##_t *__restrict__ out, \ + int n) \ + { \ + const int index = blockIdx.x * blockDim.x + threadIdx.x; \ + const int stride = blockDim.x * gridDim.x; \ + for (int i = index; i < n; i += stride) { \ + const ompi_op_predefined_##type_name##_t *a1 = &in1[i]; \ + const ompi_op_predefined_##type_name##_t *a2 = &in2[i]; \ + ompi_op_predefined_##type_name##_t *b = &out[i]; \ + if (a1->v op a2->v) { \ + b->v = a1->v; \ + b->k = a1->k; \ + } else if (a1->v == a2->v) { \ + b->v = a1->v; \ + b->k = (a2->k < a1->k ? a2->k : a1->k); \ + } else { \ + b->v = a2->v; \ + b->k = a2->k; \ + } \ + } \ + } \ + void \ + ompi_op_rocm_3buff_##name##_##type_name##_submit(const ompi_op_predefined_##type_name##_t *__restrict__ in1, \ + const ompi_op_predefined_##type_name##_t *__restrict__ in2, \ + ompi_op_predefined_##type_name##_t *__restrict__ out, \ + int count, \ + int threads_per_block, \ + int max_blocks, \ + hipStream_t stream) \ + { \ + int threads = min(threads_per_block, count); \ + int blocks = min((count + threads-1) / threads, max_blocks); \ + hipLaunchKernelGGL(ompi_op_rocm_3buff_##name##_##type_name##_kernel, \ + dim3(blocks), dim3(threads), 0, stream, \ + in1, in2, out, count); \ + } + + +/************************************************************************* + * Max + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) > (b) ? 
(a) : (b)) +/* C integer */ +FUNC_FUNC_3BUF(max, int8_t, int8_t) +FUNC_FUNC_3BUF(max, uint8_t, uint8_t) +FUNC_FUNC_3BUF(max, int16_t, int16_t) +FUNC_FUNC_3BUF(max, uint16_t, uint16_t) +FUNC_FUNC_3BUF(max, int32_t, int32_t) +FUNC_FUNC_3BUF(max, uint32_t, uint32_t) +FUNC_FUNC_3BUF(max, int64_t, int64_t) +FUNC_FUNC_3BUF(max, uint64_t, uint64_t) +FUNC_FUNC_3BUF(max, long, long) +FUNC_FUNC_3BUF(max, ulong, unsigned long) + +/* Floating point */ +#if defined(HAVE_SHORT_FLOAT) +FUNC_FUNC_3BUF(max, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +FUNC_FUNC_3BUF(max, short_float, opal_short_float_t) +#endif +FUNC_FUNC_3BUF(max, float, float) +FUNC_FUNC_3BUF(max, double, double) +FUNC_FUNC_3BUF(max, long_double, long double) + +/************************************************************************* + * Min + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) < (b) ? (a) : (b)) +/* C integer */ +FUNC_FUNC_3BUF(min, int8_t, int8_t) +FUNC_FUNC_3BUF(min, uint8_t, uint8_t) +FUNC_FUNC_3BUF(min, int16_t, int16_t) +FUNC_FUNC_3BUF(min, uint16_t, uint16_t) +FUNC_FUNC_3BUF(min, int32_t, int32_t) +FUNC_FUNC_3BUF(min, uint32_t, uint32_t) +FUNC_FUNC_3BUF(min, int64_t, int64_t) +FUNC_FUNC_3BUF(min, uint64_t, uint64_t) +FUNC_FUNC_3BUF(min, long, long) +FUNC_FUNC_3BUF(min, ulong, unsigned long) + +/* Floating point */ +#if defined(HAVE_SHORT_FLOAT) +FUNC_FUNC_3BUF(min, short_float, short float) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +FUNC_FUNC_3BUF(min, short_float, opal_short_float_t) +#endif +FUNC_FUNC_3BUF(min, float, float) +FUNC_FUNC_3BUF(min, double, double) +FUNC_FUNC_3BUF(min, long_double, long double) + +/************************************************************************* + * Sum + *************************************************************************/ + +/* C integer */ +OP_FUNC_3BUF(sum, int8_t, int8_t, +) +OP_FUNC_3BUF(sum, uint8_t, uint8_t, +) +OP_FUNC_3BUF(sum, 
int16_t, int16_t, +) +OP_FUNC_3BUF(sum, uint16_t, uint16_t, +) +OP_FUNC_3BUF(sum, int32_t, int32_t, +) +OP_FUNC_3BUF(sum, uint32_t, uint32_t, +) +OP_FUNC_3BUF(sum, int64_t, int64_t, +) +OP_FUNC_3BUF(sum, uint64_t, uint64_t, +) +OP_FUNC_3BUF(sum, long, long, +) +OP_FUNC_3BUF(sum, ulong, unsigned long, +) + +/* Floating point */ +#if defined(HAVE_SHORT_FLOAT) +OP_FUNC_3BUF(sum, short_float, short float, +) +#elif defined(HAVE_OPAL_SHORT_FLOAT_T) +OP_FUNC_3BUF(sum, short_float, opal_short_float_t, +) +#endif +OP_FUNC_3BUF(sum, float, float, +) +OP_FUNC_3BUF(sum, double, double, +) +OP_FUNC_3BUF(sum, long_double, long double, +) + +/* Complex */ +#if 0 +#if defined(HAVE_SHORT_FLOAT__COMPLEX) +OP_FUNC_3BUF(sum, c_short_float_complex, short float _Complex, +) +#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) +COMPLEX_SUM_FUNC_3BUF(c_short_float_complex, opal_short_float_t) +#endif +OP_FUNC_3BUF(sum, c_long_double_complex, cuLongDoubleComplex, +) +#endif // 0 +#undef current_func +#define current_func(a, b) (hipCaddf(a,b)) +FUNC_FUNC_3BUF(sum, c_float_complex, hipFloatComplex) +#undef current_func +#define current_func(a, b) (hipCadd(a,b)) +FUNC_FUNC_3BUF(sum, c_double_complex, hipDoubleComplex) + +/************************************************************************* + * Product + *************************************************************************/ + +/* C integer */ +OP_FUNC_3BUF(prod, int8_t, int8_t, *) +OP_FUNC_3BUF(prod, uint8_t, uint8_t, *) +OP_FUNC_3BUF(prod, int16_t, int16_t, *) +OP_FUNC_3BUF(prod, uint16_t, uint16_t, *) +OP_FUNC_3BUF(prod, int32_t, int32_t, *) +OP_FUNC_3BUF(prod, uint32_t, uint32_t, *) +OP_FUNC_3BUF(prod, int64_t, int64_t, *) +OP_FUNC_3BUF(prod, uint64_t, uint64_t, *) +OP_FUNC_3BUF(prod, long, long, *) +OP_FUNC_3BUF(prod, ulong, unsigned long, *) + +/* Complex */ +#if 0 +#if defined(HAVE_SHORT_FLOAT__COMPLEX) +OP_FUNC_3BUF(prod, c_short_float_complex, short float _Complex, *) +#elif defined(HAVE_OPAL_SHORT_FLOAT_COMPLEX_T) 
+COMPLEX_PROD_FUNC_3BUF(c_short_float_complex, opal_short_float_t) +#endif +OP_FUNC_3BUF(prod, c_long_double_complex, long double _Complex, *) +#endif // 0 +#undef current_func +#define current_func(a, b) (hipCmulf(a,b)) +FUNC_FUNC_3BUF(prod, c_float_complex, hipFloatComplex) +#undef current_func +#define current_func(a, b) (hipCmul(a,b)) +FUNC_FUNC_3BUF(prod, c_double_complex, hipDoubleComplex) + +/************************************************************************* + * Logical AND + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) && (b)) +/* C integer */ +FUNC_FUNC_3BUF(land, int8_t, int8_t) +FUNC_FUNC_3BUF(land, uint8_t, uint8_t) +FUNC_FUNC_3BUF(land, int16_t, int16_t) +FUNC_FUNC_3BUF(land, uint16_t, uint16_t) +FUNC_FUNC_3BUF(land, int32_t, int32_t) +FUNC_FUNC_3BUF(land, uint32_t, uint32_t) +FUNC_FUNC_3BUF(land, int64_t, int64_t) +FUNC_FUNC_3BUF(land, uint64_t, uint64_t) +FUNC_FUNC_3BUF(land, long, long) +FUNC_FUNC_3BUF(land, ulong, unsigned long) + +/* C++ bool */ +FUNC_FUNC_3BUF(land, bool, bool) + +/************************************************************************* + * Logical OR + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) || (b)) +/* C integer */ +FUNC_FUNC_3BUF(lor, int8_t, int8_t) +FUNC_FUNC_3BUF(lor, uint8_t, uint8_t) +FUNC_FUNC_3BUF(lor, int16_t, int16_t) +FUNC_FUNC_3BUF(lor, uint16_t, uint16_t) +FUNC_FUNC_3BUF(lor, int32_t, int32_t) +FUNC_FUNC_3BUF(lor, uint32_t, uint32_t) +FUNC_FUNC_3BUF(lor, int64_t, int64_t) +FUNC_FUNC_3BUF(lor, uint64_t, uint64_t) +FUNC_FUNC_3BUF(lor, long, long) +FUNC_FUNC_3BUF(lor, ulong, unsigned long) + +/* C++ bool */ +FUNC_FUNC_3BUF(lor, bool, bool) + +/************************************************************************* + * Logical XOR + *************************************************************************/ + +#undef current_func 
+#define current_func(a, b) ((a ? 1 : 0) ^ (b ? 1: 0)) +/* C integer */ +FUNC_FUNC_3BUF(lxor, int8_t, int8_t) +FUNC_FUNC_3BUF(lxor, uint8_t, uint8_t) +FUNC_FUNC_3BUF(lxor, int16_t, int16_t) +FUNC_FUNC_3BUF(lxor, uint16_t, uint16_t) +FUNC_FUNC_3BUF(lxor, int32_t, int32_t) +FUNC_FUNC_3BUF(lxor, uint32_t, uint32_t) +FUNC_FUNC_3BUF(lxor, int64_t, int64_t) +FUNC_FUNC_3BUF(lxor, uint64_t, uint64_t) +FUNC_FUNC_3BUF(lxor, long, long) +FUNC_FUNC_3BUF(lxor, ulong, unsigned long) + +/* C++ bool */ +FUNC_FUNC_3BUF(lxor, bool, bool) + +/************************************************************************* + * Bitwise AND + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) & (b)) +/* C integer */ +FUNC_FUNC_3BUF(band, int8_t, int8_t) +FUNC_FUNC_3BUF(band, uint8_t, uint8_t) +FUNC_FUNC_3BUF(band, int16_t, int16_t) +FUNC_FUNC_3BUF(band, uint16_t, uint16_t) +FUNC_FUNC_3BUF(band, int32_t, int32_t) +FUNC_FUNC_3BUF(band, uint32_t, uint32_t) +FUNC_FUNC_3BUF(band, int64_t, int64_t) +FUNC_FUNC_3BUF(band, uint64_t, uint64_t) +FUNC_FUNC_3BUF(band, long, long) +FUNC_FUNC_3BUF(band, ulong, unsigned long) + +/* Byte */ +FUNC_FUNC_3BUF(band, byte, char) + +/************************************************************************* + * Bitwise OR + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) | (b)) +/* C integer */ +FUNC_FUNC_3BUF(bor, int8_t, int8_t) +FUNC_FUNC_3BUF(bor, uint8_t, uint8_t) +FUNC_FUNC_3BUF(bor, int16_t, int16_t) +FUNC_FUNC_3BUF(bor, uint16_t, uint16_t) +FUNC_FUNC_3BUF(bor, int32_t, int32_t) +FUNC_FUNC_3BUF(bor, uint32_t, uint32_t) +FUNC_FUNC_3BUF(bor, int64_t, int64_t) +FUNC_FUNC_3BUF(bor, uint64_t, uint64_t) +FUNC_FUNC_3BUF(bor, long, long) +FUNC_FUNC_3BUF(bor, ulong, unsigned long) + +/* Byte */ +FUNC_FUNC_3BUF(bor, byte, char) + +/************************************************************************* + * 
Bitwise XOR + *************************************************************************/ + +#undef current_func +#define current_func(a, b) ((a) ^ (b)) +/* C integer */ +FUNC_FUNC_3BUF(bxor, int8_t, int8_t) +FUNC_FUNC_3BUF(bxor, uint8_t, uint8_t) +FUNC_FUNC_3BUF(bxor, int16_t, int16_t) +FUNC_FUNC_3BUF(bxor, uint16_t, uint16_t) +FUNC_FUNC_3BUF(bxor, int32_t, int32_t) +FUNC_FUNC_3BUF(bxor, uint32_t, uint32_t) +FUNC_FUNC_3BUF(bxor, int64_t, int64_t) +FUNC_FUNC_3BUF(bxor, uint64_t, uint64_t) +FUNC_FUNC_3BUF(bxor, long, long) +FUNC_FUNC_3BUF(bxor, ulong, unsigned long) + +/* Byte */ +FUNC_FUNC_3BUF(bxor, byte, char) + +/************************************************************************* + * Max location + *************************************************************************/ + +LOC_FUNC_3BUF(maxloc, float_int, >) +LOC_FUNC_3BUF(maxloc, double_int, >) +LOC_FUNC_3BUF(maxloc, long_int, >) +LOC_FUNC_3BUF(maxloc, 2int, >) +LOC_FUNC_3BUF(maxloc, short_int, >) +LOC_FUNC_3BUF(maxloc, long_double_int, >) + +/* Fortran compat types */ +LOC_FUNC_3BUF(maxloc, 2float, >) +LOC_FUNC_3BUF(maxloc, 2double, >) +LOC_FUNC_3BUF(maxloc, 2int8, >) +LOC_FUNC_3BUF(maxloc, 2int16, >) +LOC_FUNC_3BUF(maxloc, 2int32, >) +LOC_FUNC_3BUF(maxloc, 2int64, >) + +/************************************************************************* + * Min location + *************************************************************************/ + +LOC_FUNC_3BUF(minloc, float_int, <) +LOC_FUNC_3BUF(minloc, double_int, <) +LOC_FUNC_3BUF(minloc, long_int, <) +LOC_FUNC_3BUF(minloc, 2int, <) +LOC_FUNC_3BUF(minloc, short_int, <) +LOC_FUNC_3BUF(minloc, long_double_int, <) + +/* Fortran compat types */ +LOC_FUNC_3BUF(minloc, 2float, <) +LOC_FUNC_3BUF(minloc, 2double, <) +LOC_FUNC_3BUF(minloc, 2int8, <) +LOC_FUNC_3BUF(minloc, 2int16, <) +LOC_FUNC_3BUF(minloc, 2int32, <) +LOC_FUNC_3BUF(minloc, 2int64, <) diff --git a/ompi/op/Makefile.am b/ompi/op/Makefile.am index 5599c31311b..f0ba89c5618 100644 --- a/ompi/op/Makefile.am 
+++ b/ompi/op/Makefile.am @@ -22,6 +22,8 @@ # This makefile.am does not stand on its own - it is included from # ompi/Makefile.am +dist_ompidata_DATA += op/help-ompi-op.txt + headers += op/op.h lib@OMPI_LIBMPI_NAME@_la_SOURCES += op/op.c diff --git a/ompi/op/help-ompi-op.txt b/ompi/op/help-ompi-op.txt new file mode 100644 index 00000000000..5cfb60b8f9f --- /dev/null +++ b/ompi/op/help-ompi-op.txt @@ -0,0 +1,15 @@ +# -*- text -*- +# +# Copyright (c) 2004-2023 The University of Tennessee and The University +# of Tennessee Research Foundation. All rights +# reserved. +# $COPYRIGHT$ +# +# Additional copyrights may follow +# +# $HEADER$ +# +# This is the US/English help file for Open MPI's allocator bucket support +# +[missing implementation] +ERROR: No suitable module for op %s on type %s found for device memory! diff --git a/ompi/op/op.c b/ompi/op/op.c index 3977fa8b97b..a75d6b33d5b 100644 --- a/ompi/op/op.c +++ b/ompi/op/op.c @@ -475,6 +475,7 @@ static void ompi_op_construct(ompi_op_t *new_op) new_op->o_3buff_intrinsic.fns[i] = NULL; new_op->o_3buff_intrinsic.modules[i] = NULL; } + new_op->o_device_op = NULL; } @@ -506,4 +507,19 @@ static void ompi_op_destruct(ompi_op_t *op) op->o_3buff_intrinsic.modules[i] = NULL; } } + + if (op->o_device_op != NULL) { + for (i = 0; i < OMPI_OP_BASE_TYPE_MAX; ++i) { + if( NULL != op->o_device_op->do_intrinsic.modules[i] ) { + OBJ_RELEASE(op->o_device_op->do_intrinsic.modules[i]); + op->o_device_op->do_intrinsic.modules[i] = NULL; + } + if( NULL != op->o_device_op->do_3buff_intrinsic.modules[i] ) { + OBJ_RELEASE(op->o_device_op->do_3buff_intrinsic.modules[i]); + op->o_device_op->do_3buff_intrinsic.modules[i] = NULL; + } + } + free(op->o_device_op); + op->o_device_op = NULL; + } } diff --git a/ompi/op/op.h b/ompi/op/op.h index f3cf5b53636..05a4c0c89e3 100644 --- a/ompi/op/op.h +++ b/ompi/op/op.h @@ -3,7 +3,7 @@ * Copyright (c) 2004-2006 The Trustees of Indiana University and Indiana * University Research and Technology * Corporation. 
All rights reserved. - * Copyright (c) 2004-2007 The University of Tennessee and The University + * Copyright (c) 2004-2023 The University of Tennessee and The University * of Tennessee Research Foundation. All rights * reserved. * Copyright (c) 2004-2007 High Performance Computing Center Stuttgart, @@ -44,6 +44,7 @@ #include "opal/class/opal_object.h" #include "opal/util/printf.h" +#include "opal/util/show_help.h" #include "ompi/datatype/ompi_datatype.h" #include "ompi/mpi/fortran/base/fint_2_int.h" @@ -122,6 +123,15 @@ enum ompi_op_type { OMPI_OP_REPLACE, OMPI_OP_NUM_OF_TYPES }; + +/* device op information */ +struct ompi_device_op_t { + opal_accelerator_stream_t *do_stream; + ompi_op_base_op_stream_fns_t do_intrinsic; + ompi_op_base_op_3buff_stream_fns_t do_3buff_intrinsic; +}; +typedef struct ompi_device_op_t ompi_device_op_t; + /** * Back-end type of MPI_Op */ @@ -167,6 +177,10 @@ struct ompi_op_t { /** 3-buffer functions, which is only for intrinsic ops. No need for the C/C++/Fortran user-defined functions. */ ompi_op_base_op_3buff_fns_t o_3buff_intrinsic; + + /** device functions, only for intrinsic ops. + Provided if device support is detected. */ + ompi_device_op_t *o_device_op; }; /** @@ -376,7 +390,7 @@ OMPI_DECLSPEC void ompi_op_set_java_callback(ompi_op_t *op, void *jnienv, * this function is provided to hide the internal structure field * names. */ -static inline bool ompi_op_is_intrinsic(ompi_op_t * op) +static inline bool ompi_op_is_intrinsic(const ompi_op_t * op) { return (bool) (0 != (op->o_flags & OMPI_OP_FLAGS_INTRINSIC)); } @@ -500,9 +514,11 @@ static inline bool ompi_op_is_valid(ompi_op_t * op, ompi_datatype_t * ddt, * optimization). If you give it an intrinsic op with a datatype that * is not defined to have that operation, it is likely to seg fault. 
*/ -static inline void ompi_op_reduce(ompi_op_t * op, const void *source, - void *target, size_t full_count, - ompi_datatype_t * dtype) +static inline void ompi_op_reduce_stream(ompi_op_t * op, const void *source, + void *target, size_t full_count, + ompi_datatype_t * dtype, + int device, + opal_accelerator_stream_t *stream) { MPI_Fint f_dtype, f_count; int count = full_count; @@ -531,7 +547,7 @@ static inline void ompi_op_reduce(ompi_op_t * op, const void *source, } shift = done_count * ext; // Recurse one level in iterations of 'int' - ompi_op_reduce(op, (const char*)source + shift, (char*)target + shift, iter_count, dtype); + ompi_op_reduce_stream(op, (char*)source + shift, (char*)target + shift, iter_count, dtype, device, stream); done_count += iter_count; } return; @@ -560,6 +576,44 @@ static inline void ompi_op_reduce(ompi_op_t * op, const void *source, * :-) */ + bool use_device_op = false; + /* check if either of the buffers is on a device and if so make sure we can + * access handle it properly */ + if (device != MCA_ACCELERATOR_NO_DEVICE_ID && + ompi_datatype_is_predefined(dtype) && + 0 != (op->o_flags & OMPI_OP_FLAGS_INTRINSIC) && + NULL != op->o_device_op) { + use_device_op = true; + } + + if (!use_device_op) { + /* query the accelerator for whether we can still execute */ + int source_dev_id, target_dev_id; + uint64_t source_flags, target_flags; + int target_check_addr = opal_accelerator.check_addr(target, &target_dev_id, &target_flags); + int source_check_addr = opal_accelerator.check_addr(source, &source_dev_id, &source_flags); + if (target_check_addr > 0 && + source_check_addr > 0 && + ompi_datatype_is_predefined(dtype) && + 0 != (op->o_flags & OMPI_OP_FLAGS_INTRINSIC) && + NULL != op->o_device_op) { + use_device_op = true; + if (target_dev_id == source_dev_id) { + /* both inputs are on the same device; if not the op will take of that */ + device = target_dev_id; + } + } else { + /* check whether we can access the memory from the host */ + if 
((target_check_addr == 0 || (target_flags & MCA_ACCELERATOR_FLAGS_UNIFIED_MEMORY)) && + (source_check_addr == 0 || (source_flags & MCA_ACCELERATOR_FLAGS_UNIFIED_MEMORY))) { + /* nothing to be done, we won't need device-capable ops */ + } else { + opal_show_help("help-ompi-op.txt", "missing implementation", true, op->o_name, dtype->name); + abort(); + } + } + } + /* For intrinsics, we also pass the corresponding op module */ if (0 != (op->o_flags & OMPI_OP_FLAGS_INTRINSIC)) { int dtype_id; @@ -569,9 +623,28 @@ static inline void ompi_op_reduce(ompi_op_t * op, const void *source, } else { dtype_id = ompi_op_ddt_map[dtype->id]; } - op->o_func.intrinsic.fns[dtype_id](source, target, - &count, &dtype, - op->o_func.intrinsic.modules[dtype_id]); + if (use_device_op) { + if (NULL == op->o_device_op) { + fprintf(stderr, "no suitable device op module found!"); + abort(); // TODO: be more graceful! + } + opal_accelerator_stream_t *actual_stream = stream; + bool flush_stream = false; + if (NULL == stream) { + actual_stream = MCA_ACCELERATOR_STREAM_DEFAULT; + flush_stream = true; + } + op->o_device_op->do_intrinsic.fns[dtype_id]((void*)source, target, + &count, &dtype, device, actual_stream, + op->o_device_op->do_intrinsic.modules[dtype_id]); + if (flush_stream) { + opal_accelerator.sync_stream(actual_stream); + } + } else { + op->o_func.intrinsic.fns[dtype_id]((void*)source, target, + &count, &dtype, + op->o_func.intrinsic.modules[dtype_id]); + } return; } @@ -579,24 +652,31 @@ static inline void ompi_op_reduce(ompi_op_t * op, const void *source, if (0 != (op->o_flags & OMPI_OP_FLAGS_FORTRAN_FUNC)) { f_dtype = OMPI_INT_2_FINT(dtype->d_f_to_c_index); f_count = OMPI_INT_2_FINT(count); - op->o_func.fort_fn(source, target, &f_count, &f_dtype); + op->o_func.fort_fn((void*)source, target, &f_count, &f_dtype); return; } else if (0 != (op->o_flags & OMPI_OP_FLAGS_JAVA_FUNC)) { - op->o_func.java_data.intercept_fn(source, target, &count, &dtype, + 
op->o_func.java_data.intercept_fn((void*)source, target, &count, &dtype, op->o_func.java_data.baseType, op->o_func.java_data.jnienv, op->o_func.java_data.object); return; } - op->o_func.c_fn(source, target, &count, &dtype); + op->o_func.c_fn((void*)source, target, &count, &dtype); return; } -static inline void ompi_3buff_op_user (ompi_op_t *op, void * restrict source1, void * restrict source2, +static inline void ompi_op_reduce(ompi_op_t * op, const void *source, + void *target, size_t full_count, + ompi_datatype_t * dtype) +{ + ompi_op_reduce_stream(op, source, target, full_count, dtype, MCA_ACCELERATOR_NO_DEVICE_ID, NULL); +} + +static inline void ompi_3buff_op_user (ompi_op_t *op, const void * source1, const void * source2, void * restrict result, int count, struct ompi_datatype_t *dtype) { - ompi_datatype_copy_content_same_ddt (dtype, count, (char*)result, (char*)source1); - op->o_func.c_fn (source2, result, &count, &dtype); + ompi_datatype_copy_content_same_ddt (dtype, count, result, (void*)source1); + op->o_func.c_fn ((void*)source2, result, &count, &dtype); } /** @@ -622,24 +702,135 @@ static inline void ompi_3buff_op_user (ompi_op_t *op, void * restrict source1, v * * Otherwise, this function is the same as ompi_op_reduce. 
*/ -static inline void ompi_3buff_op_reduce(ompi_op_t * op, void *source1, - void *source2, void *target, - int count, ompi_datatype_t * dtype) +static inline void ompi_3buff_op_reduce_stream(ompi_op_t * op, const void *source1, + const void *source2, void *target, + int count, ompi_datatype_t * dtype, + int device, + opal_accelerator_stream_t *stream) { - void *restrict src1; - void *restrict src2; - void *restrict tgt; - src1 = source1; - src2 = source2; - tgt = target; + bool use_device_op = false; + if (OPAL_UNLIKELY(!ompi_op_is_intrinsic (op))) { + /* no 3buff variants for user-defined ops */ + ompi_3buff_op_user (op, source1, source2, target, count, dtype); + return; + } + + if (device != MCA_ACCELERATOR_NO_DEVICE_ID && + ompi_datatype_is_predefined(dtype) && + op->o_flags & OMPI_OP_FLAGS_INTRINSIC && + NULL != op->o_device_op) { + use_device_op = true; + } + if (!use_device_op) { + int source1_dev_id, source2_dev_id, target_dev_id; + uint64_t source1_flags, source2_flags, target_flags; + int target_check_addr = opal_accelerator.check_addr(target, &target_dev_id, &target_flags); + int source1_check_addr = opal_accelerator.check_addr(source1, &source1_dev_id, &source1_flags); + int source2_check_addr = opal_accelerator.check_addr(source2, &source2_dev_id, &source2_flags); + /* check if either of the buffers is on a device and if so make sure we can + * handle it properly */ + if (target_check_addr > 0 || source1_check_addr > 0 || source2_check_addr > 0) { + if (ompi_datatype_is_predefined(dtype) && + op->o_flags & OMPI_OP_FLAGS_INTRINSIC && + NULL != op->o_device_op) { + use_device_op = true; + device = target_dev_id; + } else { + /* check whether we can access the memory from the host */ + if ((target_check_addr == 0 || (target_flags & MCA_ACCELERATOR_FLAGS_UNIFIED_MEMORY)) && + (source1_check_addr == 0 || (source1_flags & MCA_ACCELERATOR_FLAGS_UNIFIED_MEMORY)) && + (source2_check_addr == 0 || (source2_flags & MCA_ACCELERATOR_FLAGS_UNIFIED_MEMORY))) {
+ /* nothing to be done, we won't need device-capable ops */ + } else { + fprintf(stderr, "3buff op: no suitable op module found for device memory!\n"); + abort(); + } + } + } + } + + /* For intrinsics, we also pass the corresponding op module */ + if (0 != (op->o_flags & OMPI_OP_FLAGS_INTRINSIC)) { + int dtype_id; + if (!ompi_datatype_is_predefined(dtype)) { + ompi_datatype_t *dt = ompi_datatype_get_single_predefined_type_from_args(dtype); + dtype_id = ompi_op_ddt_map[dt->id]; + } else { + dtype_id = ompi_op_ddt_map[dtype->id]; + } + if (use_device_op) { + opal_accelerator_stream_t *actual_stream = stream; + bool flush_stream = false; + if (NULL == stream) { + actual_stream = MCA_ACCELERATOR_STREAM_DEFAULT; + flush_stream = true; + } + op->o_device_op->do_3buff_intrinsic.fns[dtype_id]((void*)source1, (void*)source2, target, + &count, &dtype, device, actual_stream, + op->o_device_op->do_3buff_intrinsic.modules[dtype_id]); + if (flush_stream) { + opal_accelerator.sync_stream(actual_stream); + } + } else { + op->o_3buff_intrinsic.fns[dtype_id]((void*)source1, (void*)source2, target, + &count, &dtype, + op->o_func.intrinsic.modules[dtype_id]); + } + } +} + + +static inline void ompi_3buff_op_reduce(ompi_op_t * op, const void *source1, + const void *source2, void *target, + int count, ompi_datatype_t * dtype) +{ if (OPAL_LIKELY(ompi_op_is_intrinsic (op))) { - op->o_3buff_intrinsic.fns[ompi_op_ddt_map[dtype->id]](src1, src2, - tgt, &count, - &dtype, - op->o_3buff_intrinsic.modules[ompi_op_ddt_map[dtype->id]]); + ompi_3buff_op_reduce_stream(op, source1, source2, target, count, dtype, MCA_ACCELERATOR_NO_DEVICE_ID, NULL); } else { - ompi_3buff_op_user (op, src1, src2, tgt, count, dtype); + ompi_3buff_op_user (op, source1, source2, target, count, dtype); + } +} + +static inline void ompi_op_preferred_device(ompi_op_t *op, int source_dev, + int target_dev, size_t count, + ompi_datatype_t *dtype, int *op_device) +{ + /* default to host */ + *op_device = -1; + if 
(!ompi_op_is_intrinsic (op)) { + return; + } + /* quick check: can we execute on the device? */ + int dtype_id = ompi_op_ddt_map[dtype->id]; + if (NULL == op->o_device_op || NULL == op->o_device_op->do_intrinsic.fns[dtype_id]) { + /* not available on the gpu, must select host */ + return; + } + + size_t size_type; + ompi_datatype_type_size(dtype, &size_type); + + float device_bw; + if (target_dev >= 0) { + opal_accelerator.get_mem_bw(target_dev, &device_bw); + } else if (source_dev >= 0) { + opal_accelerator.get_mem_bw(source_dev, &device_bw); + } + + // assume we reach 50% of theoretical peak on the device + device_bw /= 2.0; + + // TODO: determine at runtime (?) + const float host_bw = 10.0; // 10GB/s + + float host_startup_cost = 0.0; // host has no startup cost + float host_compute_cost = (count*size_type) / (host_bw*1024); // assume 10GB/s memory bandwidth on host + float device_startup_cost = 10.0; // 10us startup cost on device + float device_compute_cost = (count*size_type) / (device_bw*1024); + + if ((host_startup_cost + host_compute_cost) > (device_startup_cost + device_compute_cost)) { + *op_device = (target_dev >= 0) ? target_dev : source_dev; } }