Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementation of correlate #2180

Merged
merged 9 commits into from
Dec 9, 2024
5 changes: 4 additions & 1 deletion dpnp/backend/extensions/statistics/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,15 @@

set(python_module_name _statistics_impl)
set(_module_src
${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
${CMAKE_CURRENT_SOURCE_DIR}/bincount.cpp
${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
${CMAKE_CURRENT_SOURCE_DIR}/histogram.cpp
${CMAKE_CURRENT_SOURCE_DIR}/histogramdd.cpp
${CMAKE_CURRENT_SOURCE_DIR}/histogram_common.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sliding_dot_product1d.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sliding_window1d.cpp
${CMAKE_CURRENT_SOURCE_DIR}/statistics_py.cpp
${CMAKE_CURRENT_SOURCE_DIR}/validation_utils.cpp
)

pybind11_add_module(${python_module_name} MODULE ${_module_src})
Expand Down
7 changes: 2 additions & 5 deletions dpnp/backend/extensions/statistics/bincount.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,7 @@

namespace dpctl_td_ns = dpctl::tensor::type_dispatch;

namespace statistics
{
namespace histogram
namespace statistics::histogram
{
struct Bincount
{
Expand Down Expand Up @@ -63,5 +61,4 @@ struct Bincount
};

void populate_bincount(py::module_ m);
} // namespace histogram
} // namespace statistics
} // namespace statistics::histogram
8 changes: 2 additions & 6 deletions dpnp/backend/extensions/statistics/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,8 @@

namespace dpctl_td_ns = dpctl::tensor::type_dispatch;

namespace statistics
namespace statistics::common
{
namespace common
{

size_t get_max_local_size(const sycl::device &device)
{
constexpr const int default_max_cpu_local_size = 256;
Expand Down Expand Up @@ -120,5 +117,4 @@ pybind11::dtype dtype_from_typenum(int dst_typenum)
}
}

} // namespace common
} // namespace statistics
} // namespace statistics::common
7 changes: 2 additions & 5 deletions dpnp/backend/extensions/statistics/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,7 @@

namespace type_utils = dpctl::tensor::type_utils;

namespace statistics
{
namespace common
namespace statistics::common
{

template <typename N, typename D>
Expand Down Expand Up @@ -187,5 +185,4 @@ sycl::nd_range<1>
// headers of dpctl.
pybind11::dtype dtype_from_typenum(int dst_typenum);

} // namespace common
} // namespace statistics
} // namespace statistics::common
106 changes: 100 additions & 6 deletions dpnp/backend/extensions/statistics/dispatch_table.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,8 @@
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
namespace py = pybind11;

namespace statistics
namespace statistics::common
{
namespace common
{

template <typename T, typename Rest>
struct one_of
{
Expand Down Expand Up @@ -97,6 +94,32 @@ using DTypePair = std::pair<DType, DType>;
using SupportedDTypeList = std::vector<DType>;
using SupportedDTypeList2 = std::vector<DTypePair>;

template <typename FnT,
typename SupportedTypes,
template <typename>
typename Func>
struct TableBuilder
{
template <typename _FnT, typename T>
struct impl
{
static constexpr bool is_defined = one_of_v<T, SupportedTypes>;

_FnT get()
{
if constexpr (is_defined) {
return Func<T>::impl;
}
else {
return nullptr;
}
}
};

using type =
dpctl_td_ns::DispatchVectorBuilder<FnT, impl, dpctl_td_ns::num_types>;
};

template <typename FnT,
typename SupportedTypes,
template <typename, typename>
Expand Down Expand Up @@ -124,6 +147,78 @@ struct TableBuilder2
dpctl_td_ns::DispatchTableBuilder<FnT, impl, dpctl_td_ns::num_types>;
};

template <typename FnT>
class DispatchTable
{
public:
DispatchTable(std::string name) : name(name) {}

template <typename SupportedTypes, template <typename> typename Func>
void populate_dispatch_table()
{
using TBulder = typename TableBuilder<FnT, SupportedTypes, Func>::type;
TBulder builder;

builder.populate_dispatch_vector(table);
populate_supported_types();
}

FnT get_unsafe(int _typenum) const
{
auto array_types = dpctl_td_ns::usm_ndarray_types();
const int type_id = array_types.typenum_to_lookup_id(_typenum);

return table[type_id];
}

FnT get(int _typenum) const
{
auto fn = get_unsafe(_typenum);

if (fn == nullptr) {
auto array_types = dpctl_td_ns::usm_ndarray_types();
const int _type_id = array_types.typenum_to_lookup_id(_typenum);

py::dtype _dtype = dtype_from_typenum(_type_id);
auto _type_pos = std::find(supported_types.begin(),
supported_types.end(), _dtype);
if (_type_pos == supported_types.end()) {
py::str types = py::str(py::cast(supported_types));
py::str dtype = py::str(_dtype);

py::str err_msg =
py::str("'" + name + "' has unsupported type '") + dtype +
py::str("'."
" Supported types are: ") +
types;

throw py::value_error(static_cast<std::string>(err_msg));
}
}

return fn;
}

const SupportedDTypeList &get_all_supported_types() const
{
return supported_types;
}

private:
void populate_supported_types()
{
for (int i = 0; i < dpctl_td_ns::num_types; ++i) {
if (table[i] != nullptr) {
supported_types.emplace_back(dtype_from_typenum(i));
}
}
}

std::string name;
SupportedDTypeList supported_types;
Table<FnT> table;
};

template <typename FnT>
class DispatchTable2
{
Expand Down Expand Up @@ -288,5 +383,4 @@ class DispatchTable2
Table2<FnT> table;
};

} // namespace common
} // namespace statistics
} // namespace statistics::common
4 changes: 1 addition & 3 deletions dpnp/backend/extensions/statistics/histogram.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,7 @@
#include <algorithm>
#include <complex>
#include <memory>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <tuple>
#include <vector>

#include <pybind11/pybind11.h>
Expand Down
9 changes: 2 additions & 7 deletions dpnp/backend/extensions/statistics/histogram.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,7 @@
#include "dispatch_table.hpp"
#include "dpctl4pybind11.hpp"

namespace dpctl_td_ns = dpctl::tensor::type_dispatch;

namespace statistics
{
namespace histogram
namespace statistics::histogram
{
struct Histogram
{
Expand All @@ -61,5 +57,4 @@ struct Histogram
};

void populate_histogram(py::module_ m);
} // namespace histogram
} // namespace statistics
} // namespace statistics::histogram
Loading
Loading