Implements dpctl.tensor.logsumexp and ``dpctl.tensor.reduce_hypot`` (#1446)

* Implements logsumexp and reduce_hypot

* Implements dedicated kernels for temp reductions over axes 1 and 0 in contiguous arrays

* logsumexp and reduce_hypot no longer use atomics

This change was made to improve the accuracy of these functions

* Adds tests for reduce_hypot and logsumexp

* Arithmetic reductions no longer use atomics for inexact types

This change is intended to improve the numerical stability of sum and prod

* Removed support for atomic reduction for max and min

* Adds new tests for reductions

* Split reductions into multiple source files

* Remove unnecessary imports of reduction init functions

* Added functions for querying reduction atomic support per type per function

* Corrected ``min`` contig variant typo

These variants were using ``sycl::maximum`` rather than ``sycl::minimum``

* Removes _tree_reduction_over_axis

Use lambdas to ignore atomic-specific arguments to hypot and logsumexp dtype_supported functions

* Always use atomic implementation for min/max if available

For add/multiplies reductions, use tree reduction for real and complex floating-point
types to get better round-off accumulation properties (see the pairwise-summation sketch below).

* ``logaddexp`` implementation moved to math_utils

Reduces code repetition between logsumexp and logaddexp (a Python rendering of the shared formula follows below)
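
A minimal pure-Python sketch (not dpctl's SYCL kernel) of why a tree-shaped reduction accumulates less round-off than serial, atomic-order accumulation; NumPy float32 is used here only to make the drift visible:

    import numpy as np

    def pairwise_sum(xs):
        # Recursively sum by halves, mirroring a tree reduction:
        # rounding error grows roughly O(log n) instead of O(n).
        if len(xs) == 1:
            return xs[0]
        mid = len(xs) // 2
        return pairwise_sum(xs[:mid]) + pairwise_sum(xs[mid:])

    x = np.full(2**20, 0.1, dtype=np.float32)
    serial = np.float32(0.0)
    for v in x:  # serial accumulation into a single float32
        serial += v
    # serial typically drifts from the exact 104857.6 by far more
    # than pairwise_sum(x) does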
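
And a Python rendering of the shared ``logaddexp`` formula (the helper itself is C++ in dpctl's math_utils; this mirrors the two-branch scheme of the kernel-local impl removed in the diff below):

    import math

    def logaddexp(a, b):
        # log(exp(a) + exp(b)) evaluated without overflowing exp()
        if a == b:  # also covers equal (signed) infinities
            return a + math.log(2.0)
        mx, mn = (a, b) if a > b else (b, a)
        # factor out the max so exp() only sees a non-positive argument
        return mx + math.log1p(math.exp(mn - mx))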

---------

Co-authored-by: Oleksandr Pavlyk <oleksandr.pavlyk@intel.com>
ndgrigorian and oleksandr-pavlyk authored Oct 26, 2023
1 parent 2eba93e commit 03fd737
Showing 32 changed files with 6,735 additions and 2,402 deletions.
16 changes: 14 additions & 2 deletions dpctl/tensor/CMakeLists.txt
@@ -102,6 +102,17 @@ set(_elementwise_sources
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/true_divide.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/trunc.cpp
)
set(_reduction_sources
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/reduction_common.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/argmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/argmin.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/logsumexp.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/max.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/min.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/prod.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/reduce_hypot.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/sum.cpp
)
set(_tensor_impl_sources
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_py.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators.cpp
@@ -120,11 +131,11 @@ set(_tensor_impl_sources
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/boolean_reductions.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/device_support_queries.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/repeat.cpp
-    ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reduction_over_axis.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/clip.cpp
)
list(APPEND _tensor_impl_sources
${_elementwise_sources}
${_reduction_sources}
)

set(python_module_name _tensor_impl)
@@ -138,12 +149,13 @@ endif()
set(_no_fast_math_sources
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/full_ctor.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/linear_sequences.cpp
-    ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reduction_over_axis.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/clip.cpp
)
list(APPEND _no_fast_math_sources
${_elementwise_sources}
${_reduction_sources}
)

foreach(_src_fn ${_no_fast_math_sources})
get_source_file_property(_cmpl_options_prop ${_src_fn} COMPILE_OPTIONS)
set(_combined_options_prop ${_cmpl_options_prop} "${_clang_prefix}-fno-fast-math")
13 changes: 12 additions & 1 deletion dpctl/tensor/__init__.py
@@ -165,7 +165,16 @@
tanh,
trunc,
)
-from ._reduction import argmax, argmin, max, min, prod, sum
+from ._reduction import (
+    argmax,
+    argmin,
+    logsumexp,
+    max,
+    min,
+    prod,
+    reduce_hypot,
+    sum,
+)
from ._testing import allclose

__all__ = [
@@ -324,4 +333,6 @@
"copysign",
"rsqrt",
"clip",
"logsumexp",
"reduce_hypot",
]
159 changes: 148 additions & 11 deletions dpctl/tensor/_reduction.py
@@ -52,6 +52,28 @@ def _default_reduction_dtype(inp_dt, q):
return res_dt


def _default_reduction_dtype_fp_types(inp_dt, q):
"""Gives default output data type for given input data
type `inp_dt` when reduction is performed on queue `q`
and the reduction supports only floating-point data types
"""
inp_kind = inp_dt.kind
if inp_kind in "biu":
res_dt = dpt.dtype(ti.default_device_fp_type(q))
can_cast_v = dpt.can_cast(inp_dt, res_dt)
if not can_cast_v:
_fp64 = q.sycl_device.has_aspect_fp64
res_dt = dpt.float64 if _fp64 else dpt.float32
elif inp_kind in "f":
res_dt = dpt.dtype(ti.default_device_fp_type(q))
if res_dt.itemsize < inp_dt.itemsize:
res_dt = inp_dt
elif inp_kind in "c":
raise TypeError("reduction not defined for complex types")

return res_dt
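
As a rough illustration of the mapping this helper encodes (a hypothetical snippet using only the public API; the resolved type depends on the device's fp64 support):

    import dpctl.tensor as dpt

    x = dpt.ones(8, dtype="i4")
    # boolean/integral inputs resolve to the device's default
    # floating-point type (float64, or float32 without fp64 support)
    assert dpt.logsumexp(x).dtype in (dpt.float32, dpt.float64)
    # complex-typed inputs raise TypeError, as in the branch above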


def _reduction_over_axis(
x,
axis,
@@ -91,12 +113,15 @@ def _reduction_over_axis(
res_shape = res_shape + (1,) * red_nd
inv_perm = sorted(range(nd), key=lambda d: perm[d])
res_shape = tuple(res_shape[i] for i in inv_perm)
-        return dpt.full(
-            res_shape,
-            _identity,
-            dtype=res_dt,
-            usm_type=res_usm_type,
-            sycl_queue=q,
+        return dpt.astype(
+            dpt.full(
+                res_shape,
+                _identity,
+                dtype=_default_reduction_type_fn(inp_dt, q),
+                usm_type=res_usm_type,
+                sycl_queue=q,
+            ),
+            res_dt,
)
if red_nd == 0:
return dpt.astype(x, res_dt, copy=False)
@@ -116,7 +141,7 @@
"Automatically determined reduction data type does not "
"have direct implementation"
)
-            tmp_dt = _default_reduction_dtype(inp_dt, q)
+            tmp_dt = _default_reduction_type_fn(inp_dt, q)
tmp = dpt.empty(
res_shape, dtype=tmp_dt, usm_type=res_usm_type, sycl_queue=q
)
@@ -161,13 +186,13 @@ def sum(x, axis=None, dtype=None, keepdims=False):
the returned array will have the default real-valued
floating-point data type for the device where input
array `x` is allocated.
-          * If x` has signed integral data type, the returned array
+          * If `x` has signed integral data type, the returned array
will have the default signed integral type for the device
where input array `x` is allocated.
* If `x` has unsigned integral data type, the returned array
will have the default unsigned integral type for the device
where input array `x` is allocated.
-          * If `x` has a complex-valued floating-point data typee,
+          * If `x` has a complex-valued floating-point data type,
the returned array will have the default complex-valued
floating-point data type for the device where input
array `x` is allocated.
Expand Down Expand Up @@ -222,13 +247,13 @@ def prod(x, axis=None, dtype=None, keepdims=False):
the returned array will have the default real-valued
floating-point data type for the device where input
array `x` is allocated.
-          * If x` has signed integral data type, the returned array
+          * If `x` has signed integral data type, the returned array
will have the default signed integral type for the device
where input array `x` is allocated.
* If `x` has unsigned integral data type, the returned array
will have the default unsigned integral type for the device
where input array `x` is allocated.
-          * If `x` has a complex-valued floating-point data typee,
+          * If `x` has a complex-valued floating-point data type,
the returned array will have the default complex-valued
floating-point data type for the device where input
array `x` is allocated.
@@ -263,6 +288,118 @@
)


def logsumexp(x, axis=None, dtype=None, keepdims=False):
"""logsumexp(x, axis=None, dtype=None, keepdims=False)
Calculates the logarithm of the sum of exponentials of elements in the
input array `x`.
Args:
x (usm_ndarray):
input array.
axis (Optional[int, Tuple[int, ...]]):
axis or axes along which values must be computed. If a tuple
of unique integers, values are computed over multiple axes.
If `None`, the result is computed over the entire array.
Default: `None`.
dtype (Optional[dtype]):
data type of the returned array. If `None`, the default data
type is inferred from the "kind" of the input array data type.
* If `x` has a real-valued floating-point data type,
the returned array will have the default real-valued
floating-point data type for the device where input
array `x` is allocated.
* If `x` has a boolean or integral data type, the returned array
will have the default floating point data type for the device
where input array `x` is allocated.
* If `x` has a complex-valued floating-point data type,
an error is raised.
If the data type (either specified or resolved) differs from the
data type of `x`, the input array elements are cast to the
specified data type before computing the result. Default: `None`.
keepdims (Optional[bool]):
if `True`, the reduced axes (dimensions) are included in the result
as singleton dimensions, so that the returned array remains
compatible with the input arrays according to Array Broadcasting
rules. Otherwise, if `False`, the reduced axes are not included in
the returned array. Default: `False`.
Returns:
usm_ndarray:
an array containing the results. If the result was computed over
the entire array, a zero-dimensional array is returned. The returned
array has the data type as described in the `dtype` parameter
description above.
"""
return _reduction_over_axis(
x,
axis,
dtype,
keepdims,
ti._logsumexp_over_axis,
lambda inp_dt, res_dt, *_: ti._logsumexp_over_axis_dtype_supported(
inp_dt, res_dt
),
_default_reduction_dtype_fp_types,
_identity=-dpt.inf,
)
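
A hedged usage sketch of the new function (assumes a default SYCL device is available; the printed value is approximate):

    import dpctl.tensor as dpt

    x = dpt.asarray([100.0, 100.0], dtype="f4")
    # exp(100) overflows float32, but the reduction factors out the max:
    # log(exp(100) + exp(100)) = 100 + log(2) ~= 100.6931
    print(dpt.logsumexp(x))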


def reduce_hypot(x, axis=None, dtype=None, keepdims=False):
"""reduce_hypot(x, axis=None, dtype=None, keepdims=False)
Calculates the square root of the sum of squares of elements in the input
array `x`.
Args:
x (usm_ndarray):
input array.
axis (Optional[int, Tuple[int, ...]]):
axis or axes along which values must be computed. If a tuple
of unique integers, values are computed over multiple axes.
If `None`, the result is computed over the entire array.
Default: `None`.
dtype (Optional[dtype]):
data type of the returned array. If `None`, the default data
type is inferred from the "kind" of the input array data type.
* If `x` has a real-valued floating-point data type,
the returned array will have the default real-valued
floating-point data type for the device where input
array `x` is allocated.
* If `x` has a boolean or integral data type, the returned array
will have the default floating point data type for the device
where input array `x` is allocated.
* If `x` has a complex-valued floating-point data type,
an error is raised.
If the data type (either specified or resolved) differs from the
data type of `x`, the input array elements are cast to the
specified data type before computing the result. Default: `None`.
keepdims (Optional[bool]):
if `True`, the reduced axes (dimensions) are included in the result
as singleton dimensions, so that the returned array remains
compatible with the input arrays according to Array Broadcasting
rules. Otherwise, if `False`, the reduced axes are not included in
the returned array. Default: `False`.
Returns:
usm_ndarray:
an array containing the results. If the result was computed over
the entire array, a zero-dimensional array is returned. The returned
array has the data type as described in the `dtype` parameter
description above.
"""
return _reduction_over_axis(
x,
axis,
dtype,
keepdims,
ti._hypot_over_axis,
lambda inp_dt, res_dt, *_: ti._hypot_over_axis_dtype_supported(
inp_dt, res_dt
),
_default_reduction_dtype_fp_types,
_identity=0,
)
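
And a matching sketch for ``reduce_hypot`` (same device assumptions as the logsumexp example above):

    import dpctl.tensor as dpt

    y = dpt.asarray([3.0e20, 4.0e20], dtype="f4")
    # squaring either element would overflow float32; the hypot-style
    # reduction avoids the intermediate squares: result ~= 5.0e20
    print(dpt.reduce_hypot(y))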


def _comparison_over_axis(x, axis, keepdims, _reduction_fn):
if not isinstance(x, dpt.usm_ndarray):
raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
@@ -31,6 +31,7 @@
#include <limits>
#include <type_traits>

#include "utils/math_utils.hpp"
#include "utils/offset_utils.hpp"
#include "utils/type_dispatch.hpp"
#include "utils/type_utils.hpp"
@@ -61,7 +62,8 @@ template <typename argT1, typename argT2, typename resT> struct LogAddExpFunctor

resT operator()(const argT1 &in1, const argT2 &in2) const
{
-        return impl<resT>(in1, in2);
+        using dpctl::tensor::math_utils::logaddexp;
+        return logaddexp<resT>(in1, in2);
}

template <int vec_sz>
@@ -79,34 +81,15 @@ template <typename argT1, typename argT2, typename resT> struct LogAddExpFunctor
impl_finite<resT>(-std::abs(diff[i]));
}
else {
-                res[i] = impl<resT>(in1[i], in2[i]);
+                using dpctl::tensor::math_utils::logaddexp;
+                res[i] = logaddexp<resT>(in1[i], in2[i]);
}
}

return res;
}

private:
-    template <typename T> T impl(T const &in1, T const &in2) const
-    {
-        if (in1 == in2) { // handle signed infinities
-            const T log2 = std::log(T(2));
-            return in1 + log2;
-        }
-        else {
-            const T tmp = in1 - in2;
-            if (tmp > 0) {
-                return in1 + std::log1p(std::exp(-tmp));
-            }
-            else if (tmp <= 0) {
-                return in2 + std::log1p(std::exp(tmp));
-            }
-            else {
-                return std::numeric_limits<T>::quiet_NaN();
-            }
-        }
-    }

template <typename T> T impl_finite(T const &in) const
{
return (in > 0) ? (in + std::log1p(std::exp(-in)))