
Commit

Added functions for querying reduction atomic support per type per function
ndgrigorian committed Oct 25, 2023
1 parent 89a10cc commit fff36a1
Showing 6 changed files with 245 additions and 104 deletions.
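Before this commit, the reduction entry points received two generic checkers, `check_atomic_support_size4` and `check_atomic_support_size8`, keyed only on the destination type's width. The commit replaces them with per-reduction dispatch vectors: each reduction populates an array of `td_ns::num_types` predicate function pointers at module initialization, and the entry point indexes that array by the destination dtype's type number. A minimal sketch of the lookup pattern, with illustrative names (`num_types` is hard-coded here; `sum_atomic_support_vector` and `dst_supports_atomics` are not dpctl identifiers):

```cpp
#include <CL/sycl.hpp>

// Predicate signature shared by all entries (matches the commit's
// atomic_support_fn_ptr_t).
typedef bool (*atomic_support_fn_ptr_t)(const sycl::queue &, sycl::usm::alloc);

// Stand-in for td_ns::num_types (the number of supported dtypes).
constexpr int num_types = 14;

// One predicate per dtype, filled once at module initialization.
static atomic_support_fn_ptr_t sum_atomic_support_vector[num_types];

// At call time the destination dtype's type number selects the predicate,
// which is then evaluated for the concrete queue and USM allocation kind.
bool dst_supports_atomics(int dst_typeid,
                          const sycl::queue &exec_q,
                          sycl::usm::alloc dst_alloc_kind)
{
    return sum_atomic_support_vector[dst_typeid](exec_q, dst_alloc_kind);
}
```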
30 changes: 21 additions & 9 deletions dpctl/tensor/libtensor/source/reductions/max.cpp
@@ -30,9 +30,11 @@
#include <vector>

#include "kernels/reductions.hpp"
#include "reduction_over_axis.hpp"
#include "utils/type_dispatch.hpp"

#include "reduction_atomic_support.hpp"
#include "reduction_over_axis.hpp"

namespace py = pybind11;

namespace dpctl
@@ -71,8 +73,6 @@ static reduction_contig_impl_fn_ptr

void populate_max_over_axis_dispatch_tables(void)
{
using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr;
using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr;
using td_ns::DispatchTableBuilder;

using dpctl::tensor::kernels::MaxOverAxisAtomicStridedFactory;
@@ -112,6 +112,20 @@ void populate_max_over_axis_dispatch_tables(void)
dtb6.populate_dispatch_table(max_over_axis0_contig_temps_dispatch_table);
}

using atomic_support::atomic_support_fn_ptr_t;
static atomic_support_fn_ptr_t max_atomic_support_vector[td_ns::num_types];

void populate_max_atomic_support_dispatch_vector(void)
{
using td_ns::DispatchVectorBuilder;

using atomic_support::MaxAtomicSupportFactory;
DispatchVectorBuilder<atomic_support_fn_ptr_t, MaxAtomicSupportFactory,
td_ns::num_types>
dvb;
dvb.populate_dispatch_vector(max_atomic_support_vector);
}

} // namespace impl

void init_max(py::module_ m)
@@ -128,11 +142,9 @@ void init_max(py::module_ m)
using impl::max_over_axis_strided_atomic_dispatch_table;
using impl::max_over_axis_strided_temps_dispatch_table;

using dpctl::tensor::py_internal::check_atomic_support;
const auto &check_atomic_support_size4 =
check_atomic_support</*require_atomic64*/ false>;
const auto &check_atomic_support_size8 =
check_atomic_support</*require_atomic64*/ true>;
using impl::populate_max_atomic_support_dispatch_vector;
populate_max_atomic_support_dispatch_vector();
using impl::max_atomic_support_vector;

auto max_pyapi = [&](const arrayT &src, int trailing_dims_to_reduce,
const arrayT &dst, sycl::queue &exec_q,
@@ -146,7 +158,7 @@
max_over_axis_strided_temps_dispatch_table,
max_over_axis0_contig_temps_dispatch_table,
max_over_axis1_contig_temps_dispatch_table,
check_atomic_support_size4, check_atomic_support_size8);
max_atomic_support_vector);
};
m.def("_max_over_axis", max_pyapi, "", py::arg("src"),
py::arg("trailing_dims_to_reduce"), py::arg("dst"),
28 changes: 21 additions & 7 deletions dpctl/tensor/libtensor/source/reductions/min.cpp
@@ -30,9 +30,11 @@
#include <vector>

#include "kernels/reductions.hpp"
#include "reduction_over_axis.hpp"
#include "utils/type_dispatch.hpp"

#include "reduction_atomic_support.hpp"
#include "reduction_over_axis.hpp"

namespace py = pybind11;

namespace dpctl
@@ -112,6 +114,20 @@
dtb6.populate_dispatch_table(min_over_axis0_contig_temps_dispatch_table);
}

using atomic_support::atomic_support_fn_ptr_t;
static atomic_support_fn_ptr_t min_atomic_support_vector[td_ns::num_types];

void populate_min_atomic_support_dispatch_vector(void)
{
using td_ns::DispatchVectorBuilder;

using atomic_support::MinAtomicSupportFactory;
DispatchVectorBuilder<atomic_support_fn_ptr_t, MinAtomicSupportFactory,
td_ns::num_types>
dvb;
dvb.populate_dispatch_vector(min_atomic_support_vector);
}

} // namespace impl

void init_min(py::module_ m)
@@ -128,11 +144,9 @@ void init_min(py::module_ m)
using impl::min_over_axis_strided_atomic_dispatch_table;
using impl::min_over_axis_strided_temps_dispatch_table;

using dpctl::tensor::py_internal::check_atomic_support;
const auto &check_atomic_support_size4 =
check_atomic_support</*require_atomic64*/ false>;
const auto &check_atomic_support_size8 =
check_atomic_support</*require_atomic64*/ true>;
using impl::populate_min_atomic_support_dispatch_vector;
populate_min_atomic_support_dispatch_vector();
using impl::min_atomic_support_vector;

auto min_pyapi = [&](const arrayT &src, int trailing_dims_to_reduce,
const arrayT &dst, sycl::queue &exec_q,
Expand All @@ -146,7 +160,7 @@ void init_min(py::module_ m)
min_over_axis_strided_temps_dispatch_table,
min_over_axis0_contig_temps_dispatch_table,
min_over_axis1_contig_temps_dispatch_table,
check_atomic_support_size4, check_atomic_support_size8);
min_atomic_support_vector);
};
m.def("_min_over_axis", min_pyapi, "", py::arg("src"),
py::arg("trailing_dims_to_reduce"), py::arg("dst"),
30 changes: 22 additions & 8 deletions dpctl/tensor/libtensor/source/reductions/prod.cpp
@@ -30,9 +30,11 @@
#include <vector>

#include "kernels/reductions.hpp"
#include "reduction_over_axis.hpp"
#include "utils/type_dispatch.hpp"

#include "reduction_atomic_support.hpp"
#include "reduction_over_axis.hpp"

namespace py = pybind11;

namespace dpctl
@@ -112,6 +114,20 @@
dtb6.populate_dispatch_table(prod_over_axis0_contig_temps_dispatch_table);
}

using atomic_support::atomic_support_fn_ptr_t;
static atomic_support_fn_ptr_t prod_atomic_support_vector[td_ns::num_types];

void populate_prod_atomic_support_dispatch_vector(void)
{
using td_ns::DispatchVectorBuilder;

using atomic_support::ProductAtomicSupportFactory;
DispatchVectorBuilder<atomic_support_fn_ptr_t, ProductAtomicSupportFactory,
td_ns::num_types>
dvb;
dvb.populate_dispatch_vector(prod_atomic_support_vector);
}

} // namespace impl

void init_prod(py::module_ m)
@@ -128,11 +144,9 @@ void init_prod(py::module_ m)
using impl::prod_over_axis_strided_atomic_dispatch_table;
using impl::prod_over_axis_strided_temps_dispatch_table;

using dpctl::tensor::py_internal::check_atomic_support;
const auto &check_atomic_support_size4 =
check_atomic_support</*require_atomic64*/ false>;
const auto &check_atomic_support_size8 =
check_atomic_support</*require_atomic64*/ true>;
using impl::populate_prod_atomic_support_dispatch_vector;
populate_prod_atomic_support_dispatch_vector();
using impl::prod_atomic_support_vector;

auto prod_pyapi = [&](const arrayT &src, int trailing_dims_to_reduce,
const arrayT &dst, sycl::queue &exec_q,
@@ -146,7 +160,7 @@
prod_over_axis_strided_temps_dispatch_table,
prod_over_axis0_contig_temps_dispatch_table,
prod_over_axis1_contig_temps_dispatch_table,
check_atomic_support_size4, check_atomic_support_size8);
prod_atomic_support_vector);
};
m.def("_prod_over_axis", prod_pyapi, "", py::arg("src"),
py::arg("trailing_dims_to_reduce"), py::arg("dst"),
@@ -160,7 +174,7 @@ void init_prod(py::module_ m)
input_dtype, output_dtype, dst_usm_type, q,
prod_over_axis_strided_atomic_dispatch_table,
prod_over_axis_strided_temps_dispatch_table,
check_atomic_support_size4, check_atomic_support_size8);
prod_atomic_support_vector);
};
m.def("_prod_over_axis_dtype_supported", prod_dtype_supported, "",
py::arg("arg_dtype"), py::arg("out_dtype"),
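All three populate functions are instances of the same pattern: a factory template is instantiated once per dtype and its `get()` result is written into the vector. A self-contained miniature, with simplified stand-ins for dpctl's `td_ns::DispatchVectorBuilder` and type list (a sketch, not dpctl's actual builder):

```cpp
#include <cstddef>
#include <cstdio>

typedef bool (*support_fn_ptr_t)();

// Stand-in predicate: only 4- and 8-byte types qualify.
template <typename T> bool check_support()
{
    return sizeof(T) == 4 || sizeof(T) == 8;
}

// Factory with the same shape as MaxAtomicSupportFactory and friends.
template <typename fnT, typename T> struct SupportFactory
{
    fnT get() { return check_support<T>; }
};

// Simplified analogue of DispatchVectorBuilder: instantiate the factory
// for every type in the list and record the function pointers in order.
template <typename fnT, template <typename, typename> class FactoryT,
          typename... Ts>
struct VectorBuilder
{
    void populate(fnT *vec) const
    {
        std::size_t i = 0;
        ((vec[i++] = FactoryT<fnT, Ts>{}.get()), ...);
    }
};

int main()
{
    support_fn_ptr_t vec[3];
    VectorBuilder<support_fn_ptr_t, SupportFactory, char, int, double>{}
        .populate(vec);
    std::printf("char: %d, int: %d, double: %d\n", vec[0](), vec[1](),
                vec[2]()); // prints "char: 0, int: 1, double: 1"
    return 0;
}
```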
143 changes: 143 additions & 0 deletions dpctl/tensor/libtensor/source/reductions/reduction_atomic_support.hpp
@@ -0,0 +1,143 @@
//===-- ------------ Implementation of _tensor_impl module ----*-C++-*-/===//
//
// Data Parallel Control (dpctl)
//
// Copyright 2020-2023 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
//===--------------------------------------------------------------------===//
///
/// \file
/// This file defines functions of dpctl.tensor._tensor_impl extensions
//===--------------------------------------------------------------------===//

#pragma once
#include <CL/sycl.hpp>
#include <complex>
#include <type_traits>

#include "utils/type_utils.hpp"

namespace dpctl
{
namespace tensor
{
namespace py_internal
{
namespace atomic_support
{

typedef bool (*atomic_support_fn_ptr_t)(const sycl::queue &, sycl::usm::alloc);

/*! @brief Function which returns a constant value for atomic support */
template <bool return_value>
bool fixed_decision(const sycl::queue &, sycl::usm::alloc)
{
return return_value;
}

/*! @brief Template for querying atomic support for a type on a device */
template <typename T>
bool check_atomic_support(const sycl::queue &exec_q,
sycl::usm::alloc usm_alloc_type)
{
constexpr bool atomic32 = (sizeof(T) == 4);
constexpr bool atomic64 = (sizeof(T) == 8);
using dpctl::tensor::type_utils::is_complex;
if constexpr ((!atomic32 && !atomic64) || is_complex<T>::value) {
return fixed_decision<false>(exec_q, usm_alloc_type);
}
else {
bool supports_atomics = false;
const sycl::device &dev = exec_q.get_device();
if constexpr (atomic64) {
if (!dev.has(sycl::aspect::atomic64)) {
return false;
}
}
switch (usm_alloc_type) {
case sycl::usm::alloc::shared:
supports_atomics =
dev.has(sycl::aspect::usm_atomic_shared_allocations);
break;
case sycl::usm::alloc::host:
supports_atomics =
dev.has(sycl::aspect::usm_atomic_host_allocations);
break;
case sycl::usm::alloc::device:
supports_atomics = true;
break;
default:
supports_atomics = false;
}
return supports_atomics;
}
}

template <typename fnT, typename T> struct MaxAtomicSupportFactory
{
fnT get()
{
if constexpr (std::is_floating_point_v<T>) {
return fixed_decision<false>;
}
else {
return check_atomic_support<T>;
}
}
};

template <typename fnT, typename T> struct MinAtomicSupportFactory
{
fnT get()
{
if constexpr (std::is_floating_point_v<T>) {
return fixed_decision<false>;
}
else {
return check_atomic_support<T>;
}
}
};

template <typename fnT, typename T> struct SumAtomicSupportFactory
{
fnT get()
{
if constexpr (std::is_floating_point_v<T>) {
return fixed_decision<false>;
}
else {
return check_atomic_support<T>;
}
}
};

template <typename fnT, typename T> struct ProductAtomicSupportFactory
{
fnT get()
{
if constexpr (std::is_floating_point_v<T>) {
return fixed_decision<false>;
}
else {
return check_atomic_support<T>;
}
}
};

} // namespace atomic_support
} // namespace py_internal
} // namespace tensor
} // namespace dpctl
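The header centralizes the per-type decision that was previously spread over the two boolean checkers: complex types and types that are neither 4 nor 8 bytes wide never use atomics, 8-byte types additionally require the `sycl::aspect::atomic64` device aspect, and the USM allocation kind must advertise atomic support. All four factories also pin floating-point types to `fixed_decision<false>`, presumably to keep those reductions on the deterministic temporaries-based kernels. A hedged sketch of how a caller can combine the vector with the two kernel tables (the name `select_impl` and the one-argument kernel signature are illustrative, not dpctl's):

```cpp
#include <CL/sycl.hpp>

typedef void (*reduction_fn_ptr_t)(sycl::queue &);
typedef bool (*atomic_support_fn_ptr_t)(const sycl::queue &, sycl::usm::alloc);

// Pick the atomic kernel only when the destination type supports device
// atomics for this queue and USM allocation kind; otherwise fall back to
// the temporaries-based kernel. dpctl's real tables are indexed by both
// source and destination type numbers; one index is used here for brevity.
reduction_fn_ptr_t select_impl(int dst_typeid,
                               const sycl::queue &exec_q,
                               sycl::usm::alloc dst_alloc_kind,
                               reduction_fn_ptr_t *atomic_table,
                               reduction_fn_ptr_t *temps_table,
                               atomic_support_fn_ptr_t *support_vector)
{
    const bool use_atomics =
        support_vector[dst_typeid](exec_q, dst_alloc_kind);
    return use_atomics ? atomic_table[dst_typeid] : temps_table[dst_typeid];
}
```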