From a108da776b6eee03dd5722086ba5af00fd08bc1a Mon Sep 17 00:00:00 2001 From: Alexander Kalistratov Date: Mon, 7 Oct 2024 17:20:10 +0200 Subject: [PATCH] Implementing histogramdd --- .../extensions/statistics/CMakeLists.txt | 1 + dpnp/backend/extensions/statistics/common.hpp | 94 +++-- .../statistics/histogram_common.cpp | 55 +-- .../statistics/histogram_common.hpp | 43 +- .../extensions/statistics/histogramdd.cpp | 357 +++++++++++++++++ .../extensions/statistics/histogramdd.hpp | 68 ++++ .../extensions/statistics/statistics_py.cpp | 2 + dpnp/dpnp_iface_histograms.py | 376 ++++++++++++++++-- tests/test_histogram.py | 182 +++++++++ tests/test_sycl_queue.py | 26 ++ tests/test_usm_type.py | 14 + .../cupy/statistics_tests/test_histogram.py | 21 +- 12 files changed, 1120 insertions(+), 119 deletions(-) create mode 100644 dpnp/backend/extensions/statistics/histogramdd.cpp create mode 100644 dpnp/backend/extensions/statistics/histogramdd.hpp diff --git a/dpnp/backend/extensions/statistics/CMakeLists.txt b/dpnp/backend/extensions/statistics/CMakeLists.txt index 2b784555630..593d6aca3e0 100644 --- a/dpnp/backend/extensions/statistics/CMakeLists.txt +++ b/dpnp/backend/extensions/statistics/CMakeLists.txt @@ -29,6 +29,7 @@ set(_module_src ${CMAKE_CURRENT_SOURCE_DIR}/common.cpp ${CMAKE_CURRENT_SOURCE_DIR}/bincount.cpp ${CMAKE_CURRENT_SOURCE_DIR}/histogram.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/histogramdd.cpp ${CMAKE_CURRENT_SOURCE_DIR}/histogram_common.cpp ${CMAKE_CURRENT_SOURCE_DIR}/statistics_py.cpp ) diff --git a/dpnp/backend/extensions/statistics/common.hpp b/dpnp/backend/extensions/statistics/common.hpp index 99a31002ea1..2ea2983f87e 100644 --- a/dpnp/backend/extensions/statistics/common.hpp +++ b/dpnp/backend/extensions/statistics/common.hpp @@ -34,8 +34,30 @@ // so sycl.hpp must be included before math_utils.hpp #include #include "utils/math_utils.hpp" +#include "utils/type_utils.hpp" // clang-format on +namespace dpctl +{ +namespace tensor +{ +namespace type_utils +{ +// Upstream to dpctl +template +struct is_complex> : std::true_type +{ +}; + +template +constexpr bool is_complex_v = is_complex::value; + +} // namespace type_utils +} // namespace tensor +} // namespace dpctl + +namespace type_utils = dpctl::tensor::type_utils; + namespace statistics { namespace common @@ -56,24 +78,20 @@ constexpr auto Align(N n, D d) template struct AtomicOp { - static void add(T &lhs, const T value) + static void add(T &lhs, const T &value) { - sycl::atomic_ref lh(lhs); - lh += value; - } -}; + if constexpr (type_utils::is_complex_v) { + using vT = typename T::value_type; + vT *_lhs = reinterpret_cast(lhs); + const vT *_val = reinterpret_cast(value); -template -struct AtomicOp, Order, Scope> -{ - static void add(std::complex &lhs, const std::complex value) - { - T *_lhs = reinterpret_cast(lhs); - const T *_val = reinterpret_cast(value); - sycl::atomic_ref lh0(_lhs[0]); - lh0 += _val[0]; - sycl::atomic_ref lh1(_lhs[1]); - lh1 += _val[1]; + AtomicOp::add(_lhs[0], _val[0]); + AtomicOp::add(_lhs[1], _val[1]); + } + else { + sycl::atomic_ref lh(lhs); + lh += value; + } } }; @@ -82,17 +100,12 @@ struct Less { bool operator()(const T &lhs, const T &rhs) const { - return std::less{}(lhs, rhs); - } -}; - -template -struct Less> -{ - bool operator()(const std::complex &lhs, - const std::complex &rhs) const - { - return dpctl::tensor::math_utils::less_complex(lhs, rhs); + if constexpr (type_utils::is_complex_v) { + return dpctl::tensor::math_utils::less_complex(lhs, rhs); + } + else { + return std::less{}(lhs, rhs); + } } }; @@ -101,26 +114,25 @@ struct IsNan { static bool isnan(const T &v) { - if constexpr (std::is_floating_point_v || - std::is_same_v) { - return sycl::isnan(v); + if constexpr (type_utils::is_complex_v) { + const auto real1 = std::real(v); + const auto imag1 = std::imag(v); + + using vT = typename T::value_type; + + return IsNan::isnan(real1) || IsNan::isnan(imag1); + } + else { + if constexpr (std::is_floating_point_v || + std::is_same_v) { + return sycl::isnan(v); + } } return false; } }; -template -struct IsNan> -{ - static bool isnan(const std::complex &v) - { - T real1 = std::real(v); - T imag1 = std::imag(v); - return sycl::isnan(real1) || sycl::isnan(imag1); - } -}; - size_t get_max_local_size(const sycl::device &device); size_t get_max_local_size(const sycl::device &device, int cpu_local_size_limit, diff --git a/dpnp/backend/extensions/statistics/histogram_common.cpp b/dpnp/backend/extensions/statistics/histogram_common.cpp index e2445b78bb3..a9ba1dca7d2 100644 --- a/dpnp/backend/extensions/statistics/histogram_common.cpp +++ b/dpnp/backend/extensions/statistics/histogram_common.cpp @@ -137,12 +137,6 @@ void validate(const usm_ndarray &sample, " parameter must have at least 1 element"); } - if (histogram.get_ndim() != 1) { - throw py::value_error(get_name(&histogram) + - " parameter must be 1d. Actual " + - std::to_string(histogram.get_ndim()) + "d"); - } - if (weights_ptr) { if (weights_ptr->get_ndim() != 1) { throw py::value_error( @@ -150,9 +144,9 @@ void validate(const usm_ndarray &sample, std::to_string(weights_ptr->get_ndim()) + "d"); } - auto sample_size = sample.get_size(); + auto sample_size = sample.get_shape(0); auto weights_size = weights_ptr->get_size(); - if (sample.get_size() != weights_ptr->get_size()) { + if (sample_size != weights_ptr->get_size()) { throw py::value_error( get_name(&sample) + " size (" + std::to_string(sample_size) + ") and " + get_name(weights_ptr) + " size (" + @@ -168,42 +162,37 @@ void validate(const usm_ndarray &sample, } if (sample.get_ndim() == 1) { - if (bins_ptr != nullptr && bins_ptr->get_ndim() != 1) { + if (histogram.get_ndim() != 1) { throw py::value_error(get_name(&sample) + " parameter is 1d, but " + - get_name(bins_ptr) + " is " + - std::to_string(bins_ptr->get_ndim()) + "d"); + get_name(&histogram) + " is " + + std::to_string(histogram.get_ndim()) + "d"); + } + + if (bins_ptr && histogram.get_size() != bins_ptr->get_size() - 1) { + auto hist_size = histogram.get_size(); + auto bins_size = bins_ptr->get_size(); + throw py::value_error( + get_name(&histogram) + " parameter and " + get_name(bins_ptr) + + " parameters shape mismatch. " + get_name(&histogram) + + " size is " + std::to_string(hist_size) + get_name(bins_ptr) + + " must have size " + std::to_string(hist_size + 1) + + " but have " + std::to_string(bins_size)); } } else if (sample.get_ndim() == 2) { auto sample_count = sample.get_shape(0); auto expected_dims = sample.get_shape(1); - if (bins_ptr != nullptr && bins_ptr->get_ndim() != expected_dims) { - throw py::value_error(get_name(&sample) + " parameter has shape {" + - std::to_string(sample_count) + "x" + - std::to_string(expected_dims) + "}" + - ", so " + get_name(bins_ptr) + + if (histogram.get_ndim() != expected_dims) { + throw py::value_error(get_name(&sample) + " parameter has shape (" + + std::to_string(sample_count) + ", " + + std::to_string(expected_dims) + ")" + + ", so " + get_name(&histogram) + " parameter expected to be " + std::to_string(expected_dims) + "d. " "Actual " + - std::to_string(bins->get_ndim()) + "d"); - } - } - - if (bins_ptr != nullptr) { - py::ssize_t expected_hist_size = 1; - for (int i = 0; i < bins_ptr->get_ndim(); ++i) { - expected_hist_size *= (bins_ptr->get_shape(i) - 1); - } - - if (histogram.get_size() != expected_hist_size) { - throw py::value_error( - get_name(&histogram) + " and " + get_name(bins_ptr) + - " shape mismatch. " + get_name(&histogram) + - " expected to have size = " + - std::to_string(expected_hist_size) + ". Actual " + - std::to_string(histogram.get_size())); + std::to_string(histogram.get_ndim()) + "d"); } } diff --git a/dpnp/backend/extensions/statistics/histogram_common.hpp b/dpnp/backend/extensions/statistics/histogram_common.hpp index e7503fe9877..c5aa06191cf 100644 --- a/dpnp/backend/extensions/statistics/histogram_common.hpp +++ b/dpnp/backend/extensions/statistics/histogram_common.hpp @@ -52,12 +52,15 @@ template struct CachedData { static constexpr bool const sync_after_init = true; - using pointer_type = T *; + using Shape = sycl::range; + using value_type = T; + using pointer_type = value_type *; + static constexpr auto dims = Dims; - using ncT = typename std::remove_const::type; + using ncT = typename std::remove_const::type; using LocalData = sycl::local_accessor; - CachedData(T *global_data, sycl::range shape, sycl::handler &cgh) + CachedData(T *global_data, Shape shape, sycl::handler &cgh) { this->global_data = global_data; local_data = LocalData(shape, cgh); @@ -87,9 +90,20 @@ struct CachedData return local_data.size(); } + T &operator[](const sycl::id &id) const + { + return local_data[id]; + } + + template > + T &operator[](const size_t id) const + { + return local_data[id]; + } + private: LocalData local_data; - T *global_data = nullptr; + value_type *global_data = nullptr; }; template @@ -97,7 +111,9 @@ struct UncachedData { static constexpr bool const sync_after_init = false; using Shape = sycl::range; - using pointer_type = T *; + using value_type = T; + using pointer_type = value_type *; + static constexpr auto dims = Dims; UncachedData(T *global_data, const Shape &shape, sycl::handler &) { @@ -120,6 +136,17 @@ struct UncachedData return _shape.size(); } + T &operator[](const sycl::id &id) const + { + return global_data[id]; + } + + template > + T &operator[](const size_t id) const + { + return global_data[id]; + } + private: T *global_data = nullptr; Shape _shape; @@ -290,9 +317,9 @@ class histogram_kernel; template void submit_histogram(const T *in, - size_t size, - size_t dims, - uint32_t WorkPI, + const size_t size, + const size_t dims, + const uint32_t WorkPI, const HistImpl &hist, const Edges &edges, const Weights &weights, diff --git a/dpnp/backend/extensions/statistics/histogramdd.cpp b/dpnp/backend/extensions/statistics/histogramdd.cpp new file mode 100644 index 00000000000..283dbb25b11 --- /dev/null +++ b/dpnp/backend/extensions/statistics/histogramdd.cpp @@ -0,0 +1,357 @@ +//***************************************************************************** +// Copyright (c) 2024, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** + +#include + +#include +#include + +#include "histogram_common.hpp" +#include "histogramdd.hpp" + +using dpctl::tensor::usm_ndarray; + +using namespace statistics::histogram; +using namespace statistics::common; + +namespace +{ + +template +struct EdgesDd +{ + static constexpr bool const sync_after_init = true; + using EdgesT = typename DataStorage::value_type; + using EdgesCountT = typename EdgesCountStorage::value_type; + using ncT = typename std::remove_const::type; + using LocalData = sycl::local_accessor; + using boundsT = std::tuple; + + EdgesDd(const EdgesT *global_edges, + const size_t total_size, + const EdgesCountT *global_edges_sizes, + const size_t dims, + sycl::handler &cgh) + : edges(global_edges, sycl::range<1>(total_size), cgh), + edges_size(global_edges_sizes, sycl::range<1>(dims), cgh) + { + min = LocalData(dims, cgh); + max = LocalData(dims, cgh); + } + + template + void init(const sycl::nd_item<_Dims> &item) const + { + auto group = item.get_group(); + edges.init(item); + edges_size.init(item); + + if constexpr (DataStorage::sync_after_init) { + sycl::group_barrier(group, sycl::memory_scope::work_group); + } + + const auto *edges_ptr = edges.get_ptr(); + auto *min_ptr = &min[0]; + auto *max_ptr = &max[0]; + + if (group.leader()) { + for (uint32_t i = 0; i < edges_size.size(); ++i) { + const auto size = edges_size[i]; + min_ptr[i] = edges_ptr[0]; + max_ptr[i] = edges_ptr[size - 1]; + edges_ptr += size; + } + } + } + + boundsT get_bounds() const + { + return {&min[0], &max[0]}; + } + + auto get_bin_for_dim(const EdgesT &val, + const EdgesT *edges_data, + const uint32_t edges_count) const + { + const uint32_t bins_count = edges_count - 1; + + uint32_t bin = std::upper_bound(edges_data, edges_data + edges_count, + val, Less{}) - + edges_data - 1; + bin = std::min(bin, bins_count - 1); + + return bin; + } + + template + auto get_bin(const sycl::nd_item<_Dims> &, + const dT *val, + const boundsT &) const + { + uint32_t resulting_bin = 0; + const auto *edges_ptr = &edges[0]; + const uint32_t dims = edges_size.size(); + + for (uint32_t i = 0; i < dims; ++i) { + const uint32_t curr_edges_count = edges_size[i]; + + const auto bin_id = + get_bin_for_dim(val[i], edges_ptr, curr_edges_count); + resulting_bin = resulting_bin * (curr_edges_count - 1) + bin_id; + edges_ptr += curr_edges_count; + } + + return resulting_bin; + } + + template + bool in_bounds(const dT *val, const boundsT &bounds) const + { + const EdgesT *min = std::get<0>(bounds); + const EdgesT *max = std::get<1>(bounds); + const uint32_t dims = edges_size.size(); + + auto in_bounds = true; + for (uint32_t i = 0; i < dims; ++i) { + in_bounds &= check_in_bounds(val[i], min[i], max[i]); + } + + return in_bounds; + } + +private: + DataStorage edges; + EdgesCountStorage edges_size; + LocalData min; + LocalData max; +}; + +template +using CachedEdgesDd = EdgesDd, CachedData>; + +template +using UncachedEdgesDd = + EdgesDd, CachedData>; + +template +struct HistogramddF +{ + static sycl::event impl(sycl::queue &exec_q, + const void *vin, + const void *vbins_edges, + const void *vbins_edges_count, + const void *vweights, + void *vout, + const size_t bins_count, + const size_t size, + const size_t total_edges, + const size_t dims, + const std::vector &depends) + { + const T *in = static_cast(vin); + const BinsT *bins_edges = static_cast(vbins_edges); + const HistType *weights = static_cast(vweights); + const EdgesCountT *bins_edges_count = + static_cast(vbins_edges_count); + HistType *out = static_cast(vout); + + auto device = exec_q.get_device(); + + const uint32_t local_size = get_max_local_size(exec_q); + constexpr uint32_t WorkPI = 128; // empirically found number + + const auto nd_range = make_ndrange(size, local_size, WorkPI); + + return exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + auto dispatch_edges = [&](const uint32_t local_mem, + const auto &weights, auto &hist) { + if (device.is_gpu() && (local_mem >= total_edges)) { + auto edges = CachedEdgesDd( + bins_edges, total_edges, bins_edges_count, dims, cgh); + submit_histogram(in, size, dims, WorkPI, hist, edges, + weights, nd_range, cgh); + } + else { + auto edges = UncachedEdgesDd( + bins_edges, total_edges, bins_edges_count, dims, cgh); + submit_histogram(in, size, dims, WorkPI, hist, edges, + weights, nd_range, cgh); + } + }; + + auto dispatch_bins = [&](const auto &weights) { + auto local_mem_size = get_local_mem_size_in_items(device); + local_mem_size -= 2 * dims; // for min-max values + local_mem_size -= CeilDiv(dims * sizeof(EdgesCountT), + sizeof(T)); // for edges count + + if (local_mem_size >= bins_count) { + const auto local_hist_count = get_local_hist_copies_count( + local_mem_size, local_size, bins_count); + + auto hist = HistWithLocalCopies( + out, bins_count, local_hist_count, cgh); + const uint32_t free_local_mem = + local_mem_size - hist.size(); + + dispatch_edges(free_local_mem, weights, hist); + } + else { + auto hist = HistGlobalMemory(out); + auto edges = UncachedEdgesDd( + bins_edges, total_edges, bins_edges_count, dims, cgh); + submit_histogram(in, size, dims, WorkPI, hist, edges, + weights, nd_range, cgh); + } + }; + + if (weights) { + auto _weights = Weights(weights); + dispatch_bins(_weights); + } + else { + auto _weights = NoWeights(); + dispatch_bins(_weights); + } + }); + } +}; + +template +using HistogramddF2 = HistogramddF; + +using SupportedTypes = + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple>, + std::tuple>, + std::tuple>, + std::tuple>, + std::tuple, + std::tuple, + std::tuple>, + std::tuple>, + std::tuple, float>, + std::tuple, double>, + std::tuple, std::complex>, + std::tuple, std::complex>>; +} // namespace + +Histogramdd::Histogramdd() : dispatch_table("sample", "histogram") +{ + dispatch_table.populate_dispatch_table(); +} + +std::tuple Histogramdd::call( + const dpctl::tensor::usm_ndarray &sample, + const dpctl::tensor::usm_ndarray &bins_edges, + const dpctl::tensor::usm_ndarray &bins_edges_count, + const std::optional &weights, + dpctl::tensor::usm_ndarray &histogram, + const std::vector &depends) +{ + validate(sample, bins_edges, weights, histogram); + + const int sample_typenum = sample.get_typenum(); + const int hist_typenum = histogram.get_typenum(); + + auto histogram_func = dispatch_table.get(sample_typenum, hist_typenum); + + auto exec_q = sample.get_queue(); + + void *weights_ptr = + weights.has_value() ? weights.value().get_data() : nullptr; + + if (sample.get_shape(1) != bins_edges_count.get_size()) { + throw py::value_error("'sample' parameter has shape (" + + std::to_string(sample.get_shape(0)) + ", " + + std::to_string(sample.get_shape(1)) + ")" + + " so array of bins edges must be of size " + + std::to_string(sample.get_shape(1)) + + " but actually is " + + std::to_string(bins_edges_count.get_shape(0))); + } + + auto ev = histogram_func(exec_q, sample.get_data(), bins_edges.get_data(), + bins_edges_count.get_data(), weights_ptr, + histogram.get_data(), histogram.get_size(), + sample.get_shape(0), bins_edges.get_shape(0), + sample.get_shape(1), depends); + + sycl::event args_ev; + if (weights.has_value()) { + args_ev = dpctl::utils::keep_args_alive( + exec_q, + {sample, bins_edges, bins_edges_count, weights.value(), histogram}, + {ev}); + } + else { + args_ev = dpctl::utils::keep_args_alive( + exec_q, {sample, bins_edges, bins_edges_count, histogram}, {ev}); + } + + return {args_ev, ev}; +} + +std::unique_ptr histdd; + +void statistics::histogram::populate_histogramdd(py::module_ m) +{ + using namespace std::placeholders; + + histdd.reset(new Histogramdd()); + + auto hist_func = + [histp = histdd.get()]( + const dpctl::tensor::usm_ndarray &sample, + const dpctl::tensor::usm_ndarray &bins, + const dpctl::tensor::usm_ndarray &bins_count, + const std::optional &weights, + dpctl::tensor::usm_ndarray &histogram, + const std::vector &depends) { + return histp->call(sample, bins, bins_count, weights, histogram, + depends); + }; + + m.def("histogramdd", hist_func, + "Compute the multidimensional histogram of some data.", + py::arg("sample"), py::arg("bins"), py::arg("bins_count"), + py::arg("weights"), py::arg("histogram"), + py::arg("depends") = py::list()); + + auto histogramdd_dtypes = [histp = histdd.get()]() { + return histp->dispatch_table.get_all_supported_types(); + }; + + m.def("histogramdd_dtypes", histogramdd_dtypes, + "Get the supported data types for histogramdd."); +} diff --git a/dpnp/backend/extensions/statistics/histogramdd.hpp b/dpnp/backend/extensions/statistics/histogramdd.hpp new file mode 100644 index 00000000000..7eb72c68672 --- /dev/null +++ b/dpnp/backend/extensions/statistics/histogramdd.hpp @@ -0,0 +1,68 @@ +//***************************************************************************** +// Copyright (c) 2024, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** + +#pragma once + +#include + +// dpctl tensor headers +#include "dpctl4pybind11.hpp" + +#include "dispatch_table.hpp" + +namespace statistics +{ +namespace histogram +{ +struct Histogramdd +{ + using FnT = sycl::event (*)(sycl::queue &, + const void *, + const void *, + const void *, + const void *, + void *, + const size_t, + const size_t, + const size_t, + const size_t, + const std::vector &); + + common::DispatchTable2 dispatch_table; + + Histogramdd(); + + std::tuple + call(const dpctl::tensor::usm_ndarray &input, + const dpctl::tensor::usm_ndarray &bins_edges, + const dpctl::tensor::usm_ndarray &bins_edges_count, + const std::optional &weights, + dpctl::tensor::usm_ndarray &output, + const std::vector &depends); +}; + +void populate_histogramdd(py::module_ m); +} // namespace histogram +} // namespace statistics diff --git a/dpnp/backend/extensions/statistics/statistics_py.cpp b/dpnp/backend/extensions/statistics/statistics_py.cpp index 2f3bf6a901c..ee647c0b1fb 100644 --- a/dpnp/backend/extensions/statistics/statistics_py.cpp +++ b/dpnp/backend/extensions/statistics/statistics_py.cpp @@ -32,9 +32,11 @@ #include "bincount.hpp" #include "histogram.hpp" +#include "histogramdd.hpp" PYBIND11_MODULE(_statistics_impl, m) { statistics::histogram::populate_bincount(m); statistics::histogram::populate_histogram(m); + statistics::histogram::populate_histogramdd(m); } diff --git a/dpnp/dpnp_iface_histograms.py b/dpnp/dpnp_iface_histograms.py index 7981f350b87..f30c7f1b9cf 100644 --- a/dpnp/dpnp_iface_histograms.py +++ b/dpnp/dpnp_iface_histograms.py @@ -38,6 +38,7 @@ """ import operator +from collections.abc import Iterable import dpctl.utils as dpu import numpy @@ -56,6 +57,7 @@ "digitize", "histogram", "histogram_bin_edges", + "histogramdd", ] # range is a keyword argument to many functions, so save the builtin so they can @@ -63,8 +65,8 @@ _range = range -def _result_type_for_device(dtype1, dtype2, device): - rt = dpnp.result_type(dtype1, dtype2) +def _result_type_for_device(dtypes, device): + rt = dpnp.result_type(*dtypes) return map_dtype_to_device(rt, device) @@ -72,7 +74,7 @@ def _align_dtypes(a_dtype, bins_dtype, ntype, supported_types, device): has_fp64 = device.has_aspect_fp64 has_fp16 = device.has_aspect_fp16 - a_bin_dtype = _result_type_for_device(a_dtype, bins_dtype, device) + a_bin_dtype = _result_type_for_device([a_dtype, bins_dtype], device) # histogram implementation doesn't support uint64 as histogram type # we can use int64 instead. Result would be correct even in case of overflow @@ -129,12 +131,18 @@ def _get_outer_edges(a, range): """ + def _is_finite(a): + if dpnp.is_supported_array_type(a): + return dpnp.isfinite(a) + + return numpy.isfinite(a) + if range is not None: first_edge, last_edge = range if first_edge > last_edge: raise ValueError("max must be larger than min in range parameter.") - if not (numpy.isfinite(first_edge) and numpy.isfinite(last_edge)): + if not (_is_finite(first_edge) and _is_finite(last_edge)): raise ValueError( f"supplied range of [{first_edge}, {last_edge}] is not finite" ) @@ -145,7 +153,7 @@ def _get_outer_edges(a, range): else: first_edge, last_edge = a.min(), a.max() - if not (dpnp.isfinite(first_edge) and dpnp.isfinite(last_edge)): + if not (_is_finite(first_edge) and _is_finite(last_edge)): raise ValueError( f"autodetected range of [{first_edge}, {last_edge}] " "is not finite" @@ -227,6 +235,32 @@ def _get_bin_edges(a, bins, range, usm_type): return bin_edges, None +def _normalize_array(a, dtype, usm_type=None): + if usm_type is None: + usm_type = a.usm_type + + try: + return dpnp.asarray( + a, + dtype=dtype, + usm_type=usm_type, + sycl_queue=a.sycl_queue, + order="C", + copy=False, + ) + except ValueError: + pass + + return dpnp.asarray( + a, + dtype=dtype, + usm_type=usm_type, + sycl_queue=a.sycl_queue, + order="C", + copy=True, + ) + + def _bincount_validate(x, weights, minlength): if x.ndim > 1: raise ValueError("object too deep for desired array") @@ -392,20 +426,16 @@ def bincount(x, weights=None, minlength=None): "supported types" ) - x_casted = dpnp.astype(x, dtype=x_casted_dtype, copy=False) + x_casted = _normalize_array(x, dtype=x_casted_dtype) if weights is not None: - weights_casted = dpnp.astype(weights, dtype=ntype_casted, copy=False) + weights_casted = _normalize_array(weights, dtype=ntype_casted) n_casted = _bincount_run_native( x_casted, weights_casted, minlength, ntype_casted, usm_type ) - n_usm_type = n_casted.usm_type - if usm_type != n_usm_type: - n = dpnp.asarray(n_casted, dtype=ntype, usm_type=usm_type) - else: - n = dpnp.astype(n_casted, ntype, copy=False) + n = _normalize_array(n_casted, dtype=ntype, usm_type=usm_type) return n @@ -613,14 +643,10 @@ def histogram(a, bins=10, range=None, density=None, weights=None): "supported types" ) - a_casted = dpnp.astype(a, a_bin_dtype, order="C", copy=False) - bin_edges_casted = dpnp.astype( - bin_edges, a_bin_dtype, order="C", copy=False - ) + a_casted = _normalize_array(a, a_bin_dtype) + bin_edges_casted = _normalize_array(bin_edges, a_bin_dtype) weights_casted = ( - dpnp.astype(weights, hist_dtype, order="C", copy=False) - if weights is not None - else None + _normalize_array(weights, hist_dtype) if weights is not None else None ) # histogram implementation uses atomics, but atomics doesn't work with @@ -655,10 +681,7 @@ def histogram(a, bins=10, range=None, density=None, weights=None): ) _manager.add_event_pair(mem_ev, ht_ev) - if usm_type != n_usm_type: - n = dpnp.asarray(n_casted, dtype=ntype, usm_type=usm_type) - else: - n = dpnp.astype(n_casted, ntype, copy=False) + n = _normalize_array(n_casted, dtype=ntype, usm_type=usm_type) if density: db = dpnp.astype( @@ -752,3 +775,310 @@ def histogram_bin_edges(a, bins=10, range=None, weights=None): a, weights, usm_type = _ravel_check_a_and_weights(a, weights) bin_edges, _ = _get_bin_edges(a, bins, range, usm_type) return bin_edges + + +def _histdd_validate_bins(bins): + for i, b in enumerate(bins): + if numpy.ndim(b) == 0: + if b < 1: + raise ValueError( + f"'bins[{i}' must be positive, when an integer" + ) + elif numpy.ndim(b) == 1: + # will check for monotonicity later + pass + else: + raise ValueError( + f"'bins[{i}]' must be either scalar or 1d array-like," + + f" but it is {type(b)}" + ) + + +def _histdd_check_monotonicity(bins): + for i, b in enumerate(bins): + if dpnp.any(b[:-1] > b[1:]): + raise ValueError( + f"bins[{i}] must increase monotonically, when an array" + ) + + +def _histdd_normalize_bins(bins, ndims): + if not isinstance(bins, Iterable): + if not isinstance(bins, int): + raise ValueError("'bins' must be an integer, when a scalar") + + bins = [bins] * ndims + + if len(bins) != ndims: + raise ValueError( + f"The dimension of bins ({len(bins)}) must be equal" + + f" to the dimension of the sample ({ndims})." + ) + + _histdd_validate_bins(bins) + + return bins + + +def _histdd_normalize_range(range, ndims): + if range is None: + range = [None] * ndims + + if len(range) != ndims: + raise ValueError( + f"range argument length ({len(range)}) must match" + + f" number of dimensions ({ndims})" + ) + + return range + + +def _histdd_make_edges(sample, bins, range, usm_type): + bedges_list = [] + for i, (r, _bins) in enumerate(zip(range, bins)): + bedges, _ = _get_bin_edges(sample[:, i], _bins, r, usm_type) + bedges_list.append(bedges) + + return bedges_list + + +def _histdd_flatten_binedges(bedges_list, edges_count_list, dtype): + queue = bedges_list[0].sycl_queue + usm_type = bedges_list[0].usm_type + total_edges_size = numpy.sum(edges_count_list) + + bin_edges_flat = dpnp.empty( + shape=total_edges_size, dtype=dtype, sycl_queue=queue, usm_type=usm_type + ) + + offset = numpy.pad(numpy.cumsum(edges_count_list), (1, 0)) + bin_edges_view_list = [] + for start, end, bedges in zip(offset[:-1], offset[1:], bedges_list): + edges_slice = bin_edges_flat[start:end] + bin_edges_view_list.append(edges_slice) + edges_slice[:] = bedges + + return bin_edges_flat, bin_edges_view_list + + +def _histdd_run_native( + sample, weights, hist_dtype, bin_edges, edges_count_list, usm_type +): + queue = sample.sycl_queue + + hist_shape = [ec - 1 for ec in edges_count_list] + bin_edges_count = dpnp.asarray( + edges_count_list, dtype=dpnp.int64, sycl_queue=queue + ) + + n_usm_type = "device" if usm_type == "host" else usm_type + n = dpnp.zeros( + shape=hist_shape, + dtype=hist_dtype, + sycl_queue=queue, + usm_type=n_usm_type, + ) + + sample_usm = dpnp.get_usm_ndarray(sample) + weights_usm = dpnp.get_usm_ndarray(weights) if weights is not None else None + edges_usm = dpnp.get_usm_ndarray(bin_edges) + edges_count_usm = dpnp.get_usm_ndarray(bin_edges_count) + n_usm = dpnp.get_usm_ndarray(n) + + _manager = dpu.SequentialOrderManager[queue] + + mem_ev, hdd_ev = statistics_ext.histogramdd( + sample_usm, + edges_usm, + edges_count_usm, + weights_usm, + n_usm, + depends=_manager.submitted_events, + ) + + _manager.add_event_pair(mem_ev, hdd_ev) + + return n + + +def _histdd_hist_dtype(queue, weights): + hist_dtype = dpnp.default_float_type(sycl_queue=queue) + device = queue.sycl_device + + if weights is not None: + # hist_dtype is either float or complex, so it is ok + # to calculate it as result type between default_float and + # weights.dtype + hist_dtype = _result_type_for_device( + [hist_dtype, weights.dtype], device + ) + + return hist_dtype + + +def _histdd_sample_dtype(queue, sample, bin_edges_list): + device = queue.sycl_device + + dtypes_ = [bin_edges.dtype for bin_edges in bin_edges_list] + dtypes_.append(sample.dtype) + + return _result_type_for_device(dtypes_, device) + + +def _histdd_supported_dtypes(sample, bin_edges_list, weights): + queue = sample.sycl_queue + device = queue.sycl_device + + hist_dtype = _histdd_hist_dtype(queue, weights) + sample_dtype = _histdd_sample_dtype(queue, sample, bin_edges_list) + + supported_types = statistics_ext.histogramdd_dtypes() + + # passing sample_dtype twice as we already + # aligned sample_dtype and bins dtypes + sample_dtype, hist_dtype = _align_dtypes( + sample_dtype, sample_dtype, hist_dtype, supported_types, device + ) + + return sample_dtype, hist_dtype + + +def _histdd_extract_arrays(sample, weights, bins): + all_arrays = [sample] + if weights is not None: + all_arrays.append(weights) + + if isinstance(bins, Iterable): + all_arrays.extend([b for b in bins if dpnp.is_supported_array_type(b)]) + + return all_arrays + + +def histogramdd(sample, bins=10, range=None, weights=None, density=False): + """ + Compute the multidimensional histogram of some data. + + For full documentation refer to :obj:`numpy.histogramdd`. + + Parameters + ---------- + sample : {dpnp.ndarray, usm_ndarray} + Input (N, D)-shaped array to be histogrammed. + + bins : {sequence, int}, optional + The bin specification: + * A sequence of arrays describing the monotonically increasing bin + edges along each dimension. + * The number of bins for each dimension (nx, ny, ... =bins) + * The number of bins for all dimensions (nx=ny=...=bins). + Default: ``10`` + range : {None, sequence}, optional + A sequence of length D, each an optional (lower, upper) tuple giving + the outer bin edges to be used if the edges are not given explicitly in + `bins`. + An entry of None in the sequence results in the minimum and maximum + values being used for the corresponding dimension. + None is equivalent to passing a tuple of D None values. + Default: ``None`` + weights : {dpnp.ndarray, usm_ndarray}, optional + An (N,)-shaped array of values `w_i` weighing each sample + `(x_i, y_i, z_i, ...)`. + Weights are normalized to 1 if density is True. If density is False, + the values of the returned histogram are equal to the sum of the + weights belonging to the samples falling into each bin. + Default: ``None`` + density : {bool}, optional + If ``False``, the default, returns the number of samples in each bin. + If ``True``, returns the probability *density* function at the bin, + ``bin_count / sample_count / bin_volume``. + Default: ``False`` + + Returns + ------- + H : {dpnp.ndarray} + The multidimensional histogram of sample x. See density and weights + for the different possible semantics. + edges : {list of ndarrays} + A list of D arrays describing the bin edges for each dimension. + + See Also + -------- + :obj:`dpnp.histogram`: 1-D histogram + :obj:`dpnp.histogram2d`: 2-D histogram + + Examples + -------- + >>> import dpnp as np + >>> r = np.random.normal(size=(100,3)) + >>> H, edges = np.histogramdd(r, bins = (5, 8, 4)) + >>> H.shape, edges[0].size, edges[1].size, edges[2].size + ((5, 8, 4), 6, 9, 5) + + """ + + if not dpnp.is_supported_array_type(sample): + raise ValueError("sample must be dpnp.ndarray or usm_ndarray") + + if weights is not None and not dpnp.is_supported_array_type(weights): + raise ValueError("weights must be dpnp.ndarray or usm_ndarray") + + if sample.ndim == 0 and sample.size == 1: + sample = dpnp.reshape(sample, (1, 1)) + elif sample.ndim == 1: + sample = dpnp.reshape(sample, (sample.size, 1)) + elif sample.ndim > 2: + raise ValueError("sample must have no more than 2 dimensions") + + ndim = sample.shape[1] if sample.size > 0 else 1 + + _arrays = _histdd_extract_arrays(sample, weights, bins) + usm_type = dpu.get_coerced_usm_type([a.usm_type for a in _arrays]) + queue = dpu.get_execution_queue([a.sycl_queue for a in _arrays]) + + assert usm_type is not None + + if queue is None: + raise ValueError("all arrays must be allocated on the same SYCL queue") + + bins = _histdd_normalize_bins(bins, ndim) + range = _histdd_normalize_range(range, ndim) + + bin_edges_list = _histdd_make_edges(sample, bins, range, usm_type) + sample_dtype, hist_dtype = _histdd_supported_dtypes( + sample, bin_edges_list, weights + ) + + edges_count_list = [bin_edges.size for bin_edges in bin_edges_list] + bin_edges_flat, bin_edges_view_list = _histdd_flatten_binedges( + bin_edges_list, edges_count_list, sample_dtype + ) + + _histdd_check_monotonicity(bin_edges_view_list) + + sample_ = _normalize_array(sample, sample_dtype) + weights_ = ( + _normalize_array(weights, hist_dtype) if weights is not None else None + ) + n = _histdd_run_native( + sample_, + weights_, + hist_dtype, + bin_edges_flat, + edges_count_list, + usm_type, + ) + + expexted_hist_dtype = _histdd_hist_dtype(queue, weights) + n = _normalize_array(n, expexted_hist_dtype, usm_type) + + if density: + # calculate the probability density function + s = n.sum() + for i in _range(ndim): + diff = dpnp.diff(bin_edges_view_list[i]) + shape = [1] * ndim + shape[i] = diff.size + n = n / dpnp.reshape(diff, shape=shape) + n /= s + + return n, bin_edges_view_list diff --git a/tests/test_histogram.py b/tests/test_histogram.py index 92abe99526a..925196c5129 100644 --- a/tests/test_histogram.py +++ b/tests/test_histogram.py @@ -610,3 +610,185 @@ def test_bincount_weights(self, array, weights): expected = numpy.bincount(np_a, weights=np_weights) result = dpnp.bincount(dpnp_a, weights=dpnp_weights) assert_allclose(expected, result) + + +class TestHistogramDd: + @pytest.mark.usefixtures("suppress_complex_warning") + @pytest.mark.parametrize( + "dtype", get_all_dtypes(no_none=True, no_bool=True) + ) + def test_rand_data(self, dtype): + n = 100 + dims = 3 + v = numpy.random.rand(n, dims).astype(dtype=dtype) + iv = dpnp.array(v, dtype=dtype) + + expected_hist, _ = numpy.histogramdd(v) + result_hist, _ = dpnp.histogramdd(iv) + assert_array_equal(result_hist, expected_hist) + + @pytest.mark.usefixtures("suppress_complex_warning") + @pytest.mark.parametrize( + "dtype", get_all_dtypes(no_none=True, no_bool=True) + ) + def test_linspace_data(self, dtype): + n = 100 + dims = 3 + v = numpy.linspace(0, 10, n * dims, dtype=dtype).reshape(n, dims) + iv = dpnp.array(v) + + expected_hist, _ = numpy.histogramdd(v) + result_hist, _ = dpnp.histogramdd(iv) + assert_array_equal(result_hist, expected_hist) + + @pytest.mark.parametrize("xp", [numpy, dpnp]) + def test_invalid_bin(self, xp): + a = xp.array([[1, 2]]) + assert_raises(ValueError, xp.histogramdd, a, bins=0.1) + + @pytest.mark.parametrize( + "bins", + [ + 11, + [11] * 3, + [[0, 20, 40, 60, 80, 100]] * 3, + [[0, 20, 40, 60, 80, 300]] * 3, + ], + ) + def test_bins(self, bins): + n = 100 + dims = 3 + v = numpy.arange(100 * 3).reshape(n, dims) + iv = dpnp.array(v) + + bins_dpnp = bins + if isinstance(bins, list): + if isinstance(bins[0], list): + bins = [numpy.array(b) for b in bins] + bins_dpnp = [dpnp.array(b) for b in bins] + + expected_hist, expected_edges = numpy.histogramdd(v, bins) + result_hist, result_edges = dpnp.histogramdd(iv, bins_dpnp) + assert_allclose(result_hist, expected_hist) + assert_allclose(result_edges, expected_edges) + + def test_no_side_effects(self): + v = dpnp.array([[1.3, 2.5, 2.3]]) + copy_v = v.copy() + + # check that ensures that values passed to ``histogramdd`` are unchanged + _, _ = dpnp.histogramdd(v) + assert (v == copy_v).all() + + @pytest.mark.parametrize("data", [[], 1, [0, 1, 1, 3, 3]]) + def test_01d(self, data): + a = numpy.array(data) + ia = dpnp.array(a) + + expected_hist, expected_edges = numpy.histogramdd(a) + result_hist, result_edges = dpnp.histogramdd(ia) + + assert_allclose(result_hist, expected_hist) + assert_allclose(result_edges, expected_edges) + + def test_3d(self): + a = dpnp.ones((10, 10, 10)) + + with assert_raises_regex(ValueError, "no more than 2 dimensions"): + dpnp.histogramdd(a) + + @pytest.mark.parametrize("xp", [numpy, dpnp]) + def test_finite_range(self, xp): + vals = xp.linspace(0.0, 1.0, num=100) + + # normal ranges should be fine + _, _ = xp.histogramdd(vals, range=[[0.25, 0.75]]) + assert_raises(ValueError, xp.histogramdd, vals, range=[[xp.nan, 0.75]]) + assert_raises(ValueError, xp.histogramdd, vals, range=[[0.25, xp.inf]]) + + @pytest.mark.parametrize("xp", [numpy, dpnp]) + def test_invalid_range(self, xp): + # start of range must be < end of range + vals = xp.linspace(0.0, 1.0, num=100) + with assert_raises_regex(ValueError, "max must be larger than"): + xp.histogramdd(vals, range=[[0.1, 0.01]]) + + @pytest.mark.parametrize("xp", [numpy, dpnp]) + @pytest.mark.parametrize("inf_val", [-numpy.inf, numpy.inf]) + def test_infinite_edge(self, xp, inf_val): + v = xp.array([0.5, 1.5, inf_val]) + min, max = v.min(), v.max() + + # both first and last ranges must be finite + with assert_raises_regex( + ValueError, + f"autodetected range of \\[{min}, {max}\\] is not finite", + ): + xp.histogramdd(v) + + @pytest.mark.parametrize("xp", [numpy, dpnp]) + def test_unsigned_monotonicity_check(self, xp): + # bins must increase monotonically when bins contain unsigned values + arr = xp.array([2]) + bins = xp.array([1, 3, 1], dtype="uint64") + with assert_raises(ValueError): + xp.histogramdd(arr, bins=bins) + + def test_nan_values(self): + one_nan = numpy.array([0, 1, numpy.nan]) + all_nan = numpy.array([numpy.nan, numpy.nan]) + + ione_nan = dpnp.array(one_nan) + iall_nan = dpnp.array(all_nan) + + # NaN is not counted + expected_hist, expected_edges = numpy.histogramdd( + one_nan, bins=[[0, 1]] + ) + result_hist, result_edges = dpnp.histogramdd(ione_nan, bins=[[0, 1]]) + assert_allclose(result_hist, expected_hist) + assert_allclose(result_edges, expected_edges) + + # NaN is not counted + expected_hist, expected_edges = numpy.histogramdd( + all_nan, bins=[[0, 1]] + ) + result_hist, result_edges = dpnp.histogramdd(iall_nan, bins=[[0, 1]]) + assert_allclose(result_hist, expected_hist) + assert_allclose(result_edges, expected_edges) + + def test_bins_another_sycl_queue(self): + v = dpnp.arange(7, 12, sycl_queue=dpctl.SyclQueue()) + bins = dpnp.arange(4, sycl_queue=dpctl.SyclQueue()) + with assert_raises(ValueError): + dpnp.histogramdd(v, bins=[bins]) + + def test_sample_array_like(self): + v = [0, 1, 2, 3, 4] + with assert_raises(ValueError): + dpnp.histogramdd(v) + + def test_weights_array_like(self): + v = dpnp.arange(5) + w = [1, 2, 3, 4, 5] + with assert_raises(ValueError): + dpnp.histogramdd(v, weights=w) + + def test_weights_another_sycl_queue(self): + v = dpnp.arange(5, sycl_queue=dpctl.SyclQueue()) + w = dpnp.arange(7, 12, sycl_queue=dpctl.SyclQueue()) + with assert_raises(ValueError): + dpnp.histogramdd(v, weights=w) + + @pytest.mark.parametrize( + "bins_count", + [10, 10**2, 10**3, 10**4, 10**5, 10**6], + ) + def test_different_bins_amount(self, bins_count): + v = numpy.linspace(0, bins_count, bins_count, dtype=numpy.float32) + iv = dpnp.array(v) + + expected_hist, expected_edges = numpy.histogramdd(v, bins=[bins_count]) + result_hist, result_edges = dpnp.histogramdd(iv, bins=[bins_count]) + assert_array_equal(result_hist, expected_hist) + assert_allclose(result_edges, expected_edges) diff --git a/tests/test_sycl_queue.py b/tests/test_sycl_queue.py index 87e13dcb658..dafb0b58912 100644 --- a/tests/test_sycl_queue.py +++ b/tests/test_sycl_queue.py @@ -2633,6 +2633,32 @@ def test_histogram(weights, device): assert_sycl_queue_equal(edges_queue, iv.sycl_queue) +@pytest.mark.parametrize("weights", [None, numpy.arange(7, 12)]) +@pytest.mark.parametrize( + "device", + valid_devices, + ids=[device.filter_string for device in valid_devices], +) +def test_histogramdd(weights, device): + v = numpy.arange(5) + w = weights + + iv = dpnp.array(v, device=device) + iw = None if weights is None else dpnp.array(w, sycl_queue=iv.sycl_queue) + + expected_hist, expected_edges = numpy.histogramdd(v, weights=w) + result_hist, result_edges = dpnp.histogramdd(iv, weights=iw) + assert_array_equal(result_hist, expected_hist) + for result_edge, expected_edge in zip(result_edges, expected_edges): + assert_dtype_allclose(result_edge, expected_edge) + + hist_queue = result_hist.sycl_queue + assert_sycl_queue_equal(hist_queue, iv.sycl_queue) + for edge in result_edges: + edges_queue = edge.sycl_queue + assert_sycl_queue_equal(edges_queue, iv.sycl_queue) + + @pytest.mark.parametrize( "func", ["tril_indices_from", "triu_indices_from", "diag_indices_from"] ) diff --git a/tests/test_usm_type.py b/tests/test_usm_type.py index d6926ab16a4..f31e94adffb 100644 --- a/tests/test_usm_type.py +++ b/tests/test_usm_type.py @@ -1616,6 +1616,20 @@ def test_bincount(usm_type_v, usm_type_w): assert hist.usm_type == du.get_coerced_usm_type([usm_type_v, usm_type_w]) +@pytest.mark.parametrize("usm_type_v", list_of_usm_types, ids=list_of_usm_types) +@pytest.mark.parametrize("usm_type_w", list_of_usm_types, ids=list_of_usm_types) +def test_histogramdd(usm_type_v, usm_type_w): + v = dp.arange(5, usm_type=usm_type_v) + w = dp.arange(7, 12, usm_type=usm_type_w) + + hist, edges = dp.histogramdd(v, weights=w) + assert v.usm_type == usm_type_v + assert w.usm_type == usm_type_w + assert hist.usm_type == du.get_coerced_usm_type([usm_type_v, usm_type_w]) + for e in edges: + assert e.usm_type == du.get_coerced_usm_type([usm_type_v, usm_type_w]) + + @pytest.mark.parametrize( "func", ["tril_indices_from", "triu_indices_from", "diag_indices_from"] ) diff --git a/tests/third_party/cupy/statistics_tests/test_histogram.py b/tests/third_party/cupy/statistics_tests/test_histogram.py index 29c02c2ea5d..671a1845faf 100644 --- a/tests/third_party/cupy/statistics_tests/test_histogram.py +++ b/tests/third_party/cupy/statistics_tests/test_histogram.py @@ -453,19 +453,18 @@ def test_digitize_nd_bins(self): xp.digitize(x, bins) -@pytest.mark.skip("histogramdd() is not implemented yet") @testing.parameterize( *testing.product( { "weights": [None, 1, 2], - "weights_dtype": [numpy.int32, numpy.float64], + "weights_dtype": [numpy.int32, numpy.float32], "density": [True, False], "bins": [ 10, - (8, 16, 12), - (16, 8, 12), - (16, 12, 8), - (12, 8, 16), + (9, 17, 13), + (17, 9, 13), + (17, 13, 8), + (13, 9, 17), "array_list", ], "range": [None, ((20, 50), (10, 100), (0, 40))], @@ -474,7 +473,7 @@ def test_digitize_nd_bins(self): ) class TestHistogramdd: @testing.for_all_dtypes(no_bool=True, no_complex=True) - @testing.numpy_cupy_allclose(atol=1e-7, rtol=1e-7) + @testing.numpy_cupy_allclose(atol=1e-3, rtol=1e-3, type_check=False) def test_histogramdd(self, xp, dtype): x = testing.shaped_random((100, 3), xp, dtype, scale=100) if self.bins == "array_list": @@ -485,6 +484,7 @@ def test_histogramdd(self, xp, dtype): weights = xp.ones((x.shape[0],), dtype=self.weights_dtype) else: weights = None + y, bin_edges = xp.histogramdd( x, bins=bins, @@ -497,7 +497,6 @@ def test_histogramdd(self, xp, dtype): ] + [e for e in bin_edges] -@pytest.mark.skip("histogramdd() is not implemented yet") class TestHistogramddErrors(unittest.TestCase): def test_histogramdd_invalid_bins(self): for xp in (numpy, cupy): @@ -536,12 +535,6 @@ def test_histogramdd_invalid_range(self): with pytest.raises(ValueError): y, bin_edges = xp.histogramdd(x, range=r) - def test_histogramdd_disallow_arraylike_bins(self): - x = testing.shaped_random((16, 2), cupy, scale=100) - bins = [[0, 10, 20, 50, 90]] * 2 # too many dimensions - with pytest.raises(ValueError): - y, bin_edges = cupy.histogramdd(x, bins=bins) - @pytest.mark.skip("histogram2d() is not implemented yet") @testing.parameterize(