Skip to content

Commit

Permalink
Implementing histogramdd
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderKalistratov committed Nov 15, 2024
1 parent afa6980 commit a108da7
Show file tree
Hide file tree
Showing 12 changed files with 1,120 additions and 119 deletions.
1 change: 1 addition & 0 deletions dpnp/backend/extensions/statistics/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ set(_module_src
${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
${CMAKE_CURRENT_SOURCE_DIR}/bincount.cpp
${CMAKE_CURRENT_SOURCE_DIR}/histogram.cpp
${CMAKE_CURRENT_SOURCE_DIR}/histogramdd.cpp
${CMAKE_CURRENT_SOURCE_DIR}/histogram_common.cpp
${CMAKE_CURRENT_SOURCE_DIR}/statistics_py.cpp
)
Expand Down
94 changes: 53 additions & 41 deletions dpnp/backend/extensions/statistics/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,30 @@
// so sycl.hpp must be included before math_utils.hpp
#include <sycl/sycl.hpp>
#include "utils/math_utils.hpp"
#include "utils/type_utils.hpp"
// clang-format on

// Upstream to dpctl: extend dpctl's type_utils so that const-qualified
// std::complex types are also recognized by the is_complex trait, and
// expose a variable-template shorthand mirroring the std::*_v convention.
namespace dpctl::tensor::type_utils
{
template <class T>
struct is_complex<const std::complex<T>> : std::true_type
{
};

// Shorthand: is_complex_v<T> == is_complex<T>::value.
template <typename T>
constexpr bool is_complex_v = is_complex<T>::value;

} // namespace dpctl::tensor::type_utils

namespace type_utils = dpctl::tensor::type_utils;

namespace statistics
{
namespace common
Expand All @@ -56,24 +78,20 @@ constexpr auto Align(N n, D d)
// Atomic addition helper that works uniformly for real and complex types.
//
// For std::complex<T> there is no sycl::atomic_ref specialization, so the
// complex value is reinterpreted as its two contiguous real components
// (std::complex guarantees array-oriented access: real at [0], imag at [1])
// and each component is added atomically. NOTE(review): the two component
// updates are individually atomic but not atomic as a pair — concurrent
// readers may observe a half-updated complex value; acceptable for
// accumulate-then-read histogram usage.
template <typename T, sycl::memory_order Order, sycl::memory_scope Scope>
struct AtomicOp
{
    static void add(T &lhs, const T &value)
    {
        if constexpr (type_utils::is_complex_v<T>) {
            using vT = typename T::value_type;
            vT *_lhs = reinterpret_cast<vT(&)[2]>(lhs);
            const vT *_val = reinterpret_cast<const vT(&)[2]>(value);

            // Recurse on the real component type (hits the scalar branch).
            AtomicOp<vT, Order, Scope>::add(_lhs[0], _val[0]);
            AtomicOp<vT, Order, Scope>::add(_lhs[1], _val[1]);
        }
        else {
            sycl::atomic_ref<T, Order, Scope> lh(lhs);
            lh += value;
        }
    }
};

Expand All @@ -82,17 +100,12 @@ struct Less
{
bool operator()(const T &lhs, const T &rhs) const
{
return std::less{}(lhs, rhs);
}
};

template <typename T>
struct Less<std::complex<T>>
{
bool operator()(const std::complex<T> &lhs,
const std::complex<T> &rhs) const
{
return dpctl::tensor::math_utils::less_complex(lhs, rhs);
if constexpr (type_utils::is_complex_v<T>) {
return dpctl::tensor::math_utils::less_complex(lhs, rhs);
}
else {
return std::less{}(lhs, rhs);
}
}
};

Expand All @@ -101,26 +114,25 @@ struct IsNan
{
static bool isnan(const T &v)
{
if constexpr (std::is_floating_point_v<T> ||
std::is_same_v<T, sycl::half>) {
return sycl::isnan(v);
if constexpr (type_utils::is_complex_v<T>) {
const auto real1 = std::real(v);
const auto imag1 = std::imag(v);

using vT = typename T::value_type;

return IsNan<vT>::isnan(real1) || IsNan<vT>::isnan(imag1);
}
else {
if constexpr (std::is_floating_point_v<T> ||
std::is_same_v<T, sycl::half>) {
return sycl::isnan(v);
}
}

return false;
}
};

template <typename T>
struct IsNan<std::complex<T>>
{
static bool isnan(const std::complex<T> &v)
{
T real1 = std::real(v);
T imag1 = std::imag(v);
return sycl::isnan(real1) || sycl::isnan(imag1);
}
};

size_t get_max_local_size(const sycl::device &device);
size_t get_max_local_size(const sycl::device &device,
int cpu_local_size_limit,
Expand Down
55 changes: 22 additions & 33 deletions dpnp/backend/extensions/statistics/histogram_common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,22 +137,16 @@ void validate(const usm_ndarray &sample,
" parameter must have at least 1 element");
}

if (histogram.get_ndim() != 1) {
throw py::value_error(get_name(&histogram) +
" parameter must be 1d. Actual " +
std::to_string(histogram.get_ndim()) + "d");
}

if (weights_ptr) {
if (weights_ptr->get_ndim() != 1) {
throw py::value_error(
get_name(weights_ptr) + " parameter must be 1d. Actual " +
std::to_string(weights_ptr->get_ndim()) + "d");
}

auto sample_size = sample.get_size();
auto sample_size = sample.get_shape(0);
auto weights_size = weights_ptr->get_size();
if (sample.get_size() != weights_ptr->get_size()) {
if (sample_size != weights_ptr->get_size()) {
throw py::value_error(
get_name(&sample) + " size (" + std::to_string(sample_size) +
") and " + get_name(weights_ptr) + " size (" +
Expand All @@ -168,42 +162,37 @@ void validate(const usm_ndarray &sample,
}

if (sample.get_ndim() == 1) {
if (bins_ptr != nullptr && bins_ptr->get_ndim() != 1) {
if (histogram.get_ndim() != 1) {
throw py::value_error(get_name(&sample) + " parameter is 1d, but " +
get_name(bins_ptr) + " is " +
std::to_string(bins_ptr->get_ndim()) + "d");
get_name(&histogram) + " is " +
std::to_string(histogram.get_ndim()) + "d");
}

if (bins_ptr && histogram.get_size() != bins_ptr->get_size() - 1) {
auto hist_size = histogram.get_size();
auto bins_size = bins_ptr->get_size();
throw py::value_error(
get_name(&histogram) + " parameter and " + get_name(bins_ptr) +
" parameters shape mismatch. " + get_name(&histogram) +
" size is " + std::to_string(hist_size) + get_name(bins_ptr) +
" must have size " + std::to_string(hist_size + 1) +
" but have " + std::to_string(bins_size));
}
}
else if (sample.get_ndim() == 2) {
auto sample_count = sample.get_shape(0);
auto expected_dims = sample.get_shape(1);

if (bins_ptr != nullptr && bins_ptr->get_ndim() != expected_dims) {
throw py::value_error(get_name(&sample) + " parameter has shape {" +
std::to_string(sample_count) + "x" +
std::to_string(expected_dims) + "}" +
", so " + get_name(bins_ptr) +
if (histogram.get_ndim() != expected_dims) {
throw py::value_error(get_name(&sample) + " parameter has shape (" +
std::to_string(sample_count) + ", " +
std::to_string(expected_dims) + ")" +
", so " + get_name(&histogram) +
" parameter expected to be " +
std::to_string(expected_dims) +
"d. "
"Actual " +
std::to_string(bins->get_ndim()) + "d");
}
}

if (bins_ptr != nullptr) {
py::ssize_t expected_hist_size = 1;
for (int i = 0; i < bins_ptr->get_ndim(); ++i) {
expected_hist_size *= (bins_ptr->get_shape(i) - 1);
}

if (histogram.get_size() != expected_hist_size) {
throw py::value_error(
get_name(&histogram) + " and " + get_name(bins_ptr) +
" shape mismatch. " + get_name(&histogram) +
" expected to have size = " +
std::to_string(expected_hist_size) + ". Actual " +
std::to_string(histogram.get_size()));
std::to_string(histogram.get_ndim()) + "d");
}
}

Expand Down
43 changes: 35 additions & 8 deletions dpnp/backend/extensions/statistics/histogram_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,15 @@ template <typename T, int Dims>
struct CachedData
{
static constexpr bool const sync_after_init = true;
using pointer_type = T *;
using Shape = sycl::range<Dims>;
using value_type = T;
using pointer_type = value_type *;
static constexpr auto dims = Dims;

using ncT = typename std::remove_const<T>::type;
using ncT = typename std::remove_const<value_type>::type;
using LocalData = sycl::local_accessor<ncT, Dims>;

CachedData(T *global_data, sycl::range<Dims> shape, sycl::handler &cgh)
CachedData(T *global_data, Shape shape, sycl::handler &cgh)
{
this->global_data = global_data;
local_data = LocalData(shape, cgh);
Expand Down Expand Up @@ -87,17 +90,30 @@ struct CachedData
return local_data.size();
}

T &operator[](const sycl::id<Dims> &id) const
{
return local_data[id];
}

template <typename = std::enable_if_t<Dims == 1>>
T &operator[](const size_t id) const
{
return local_data[id];
}

private:
LocalData local_data;
T *global_data = nullptr;
value_type *global_data = nullptr;
};

template <typename T, int Dims>
struct UncachedData
{
static constexpr bool const sync_after_init = false;
using Shape = sycl::range<Dims>;
using pointer_type = T *;
using value_type = T;
using pointer_type = value_type *;
static constexpr auto dims = Dims;

UncachedData(T *global_data, const Shape &shape, sycl::handler &)
{
Expand All @@ -120,6 +136,17 @@ struct UncachedData
return _shape.size();
}

T &operator[](const sycl::id<Dims> &id) const
{
return global_data[id];
}

template <typename = std::enable_if_t<Dims == 1>>
T &operator[](const size_t id) const
{
return global_data[id];
}

private:
T *global_data = nullptr;
Shape _shape;
Expand Down Expand Up @@ -290,9 +317,9 @@ class histogram_kernel;

template <typename T, typename HistImpl, typename Edges, typename Weights>
void submit_histogram(const T *in,
size_t size,
size_t dims,
uint32_t WorkPI,
const size_t size,
const size_t dims,
const uint32_t WorkPI,
const HistImpl &hist,
const Edges &edges,
const Weights &weights,
Expand Down
Loading

0 comments on commit a108da7

Please sign in to comment.