Skip to content

Commit

Permalink
[SYCL] Optimize sub-group group_load via BlockRead in simple cases
Browse files Browse the repository at this point in the history
  • Loading branch information
aelovikov-intel committed May 8, 2024
1 parent f56d7d7 commit 98e088f
Show file tree
Hide file tree
Showing 4 changed files with 386 additions and 179 deletions.
7 changes: 7 additions & 0 deletions sycl/include/sycl/detail/generic_type_traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,13 @@ template <typename T>
using make_unsinged_integer_t =
make_type_t<T, gtl::scalar_unsigned_integer_list>;

// Maps a size in bytes (1/2/4/8) to the corresponding unsigned OpenCL scalar
// type (uchar/ushort/uint/ulong). Note: any Size other than 1, 2 or 4 falls
// through to cl_ulong — callers are expected to pass a valid size.
template <int Size>
using cl_unsigned = std::conditional_t<
    Size == 1, opencl::cl_uchar,
    std::conditional_t<
        Size == 2, opencl::cl_ushort,
        std::conditional_t<Size == 4, opencl::cl_uint, opencl::cl_ulong>>>;

// select_apply_cl_scalar_t selects from T8/T16/T32/T64 basing on
// sizeof(IN). expected to handle scalar types.
template <typename T, typename T8, typename T16, typename T32, typename T64>
Expand Down
1 change: 1 addition & 0 deletions sycl/include/sycl/detail/helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ void loop_impl(std::integer_sequence<size_t, Inds...>, F &&f) {
// Invokes the callable once per index in [0, count), by expanding a
// compile-time index sequence through loop_impl.
template <size_t count, class F> void loop(F &&func) {
  using Indices = std::make_index_sequence<count>;
  loop_impl(Indices{}, std::forward<F>(func));
}
// Returns true iff `x` is a positive power of two. The bare bit trick
// `(x & (x - 1)) == 0` alone would also classify 0 (and some negative
// values, e.g. INT_MIN) as powers of two, so guard non-positive inputs.
inline constexpr bool is_power_of_two(int x) {
  return x > 0 && (x & (x - 1)) == 0;
}
} // namespace detail

} // namespace _V1
Expand Down
169 changes: 165 additions & 4 deletions sycl/include/sycl/ext/oneapi/experimental/group_load_store.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
#include <sycl/ext/oneapi/properties/properties.hpp>
#include <sycl/sycl_span.hpp>

#include <cstring>

namespace sycl {
inline namespace _V1 {
namespace ext::oneapi::experimental {
Expand Down Expand Up @@ -106,6 +108,116 @@ int get_mem_idx(GroupTy g, int vec_or_array_idx) {
return g.get_local_linear_id() +
g.get_local_linear_range() * vec_or_array_idx;
}

// SPIR-V extension:
// https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV_INTEL_subgroups.asciidoc,
// however it doesn't describe limitations/requirements. Those seem to be
// listed in the Intel OpenCL extensions for sub-groups:
// https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroups.html
// https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroups_char.html
// https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroups_long.html
// https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroups_short.html
// Reads require 4-byte alignment, writes 16-byte alignment. Supported
// sizes:
//
// +--------+-------------+
// | uchar | 1,2,4,8,16 |
// | ushort | 1,2,4,8 |
// | uint | 1,2,4,8 |
// | ulong | 1,2,4,8 |
// +--------+-------------+
//
// Utility type traits below are used to map user type to one of the block
// read/write types above.

// Describes the shape of a block transfer for the given iterator/element
// count/layout combination: the size in bytes of one block, the number of
// blocks per work-item, and whether that shape is covered by the block
// read/write builtins (see the supported-sizes table above).
template <typename IteratorT, std::size_t ElementsPerWorkItem, bool blocked>
struct BlockInfo {
  using value_type =
      remove_decoration_t<typename std::iterator_traits<IteratorT>::value_type>;

  // In the blocked layout one work-item's elements are contiguous and are
  // fused into a single wider block; otherwise each element is transferred
  // as its own block.
  static constexpr int block_size =
      sizeof(value_type) * (blocked ? ElementsPerWorkItem : 1);
  static constexpr int num_blocks = blocked ? 1 : ElementsPerWorkItem;
  // Mirrors the table above: power-of-two block size of at most 8 bytes
  // (uchar/ushort/uint/ulong), a power-of-two number of blocks of at most 8,
  // plus the one extra case of 16 single-byte blocks (uchar16).
  static constexpr bool has_builtin =
      detail::is_power_of_two(block_size) &&
      detail::is_power_of_two(num_blocks) && block_size <= 8 &&
      (num_blocks <= 8 || (num_blocks == 16 && block_size == 1));
};

// Maps a legal BlockInfo shape onto the concrete types used when invoking
// the block read/write builtins. Only specialized for BlockInfo.
template <typename BlockInfoTy> struct BlockTypeInfo;

template <typename IteratorT, std::size_t ElementsPerWorkItem, bool blocked>
struct BlockTypeInfo<BlockInfo<IteratorT, ElementsPerWorkItem, blocked>> {
  using BlockInfoTy = BlockInfo<IteratorT, ElementsPerWorkItem, blocked>;
  // Instantiating this for an unsupported shape is a bug in the caller.
  static_assert(BlockInfoTy::has_builtin);

  // Unsigned OpenCL scalar type of the same size as one block.
  using block_type = detail::cl_unsigned<BlockInfoTy::block_size>;

  // Preserve const-ness of the source data (deduced from the iterator's
  // reference type) when forming the pointer's element type.
  using block_pointer_elem_type = std::conditional_t<
      std::is_const_v<std::remove_reference_t<
          typename std::iterator_traits<IteratorT>::reference>>,
      std::add_const_t<block_type>, block_type>;

  // The builtins take a decorated global-address-space pointer.
  using block_pointer_type = typename detail::DecoratedType<
      block_pointer_elem_type, access::address_space::global_space>::type *;
  // Result type of a block read: a scalar when a work-item reads a single
  // block, an OpenCL vector type when it reads several.
  using block_load_type = std::conditional_t<
      BlockInfoTy::num_blocks == 1, block_type,
      detail::ConvertToOpenCLType_t<vec<block_type, BlockInfoTy::num_blocks>>>;
};

// Returns either a pointer suitable to use in a block read/write builtin or
// nullptr if some legality conditions aren't satisfied.
//
// The failure mode is encoded in the *type* of the result: when the builtin
// can never be used for this instantiation the function returns
// std::nullptr_t (so callers can prune the builtin path with `if constexpr`
// on decltype); when only a runtime check (alignment, generic->global cast)
// fails, the result is a null value of block_pointer_type.
template <int RequiredAlign, std::size_t ElementsPerWorkItem,
          typename IteratorT, typename Properties>
auto get_block_op_ptr(IteratorT iter, [[maybe_unused]] Properties props) {
  using value_type =
      remove_decoration_t<typename std::iterator_traits<IteratorT>::value_type>;
  using iter_no_cv = std::remove_cv_t<IteratorT>;

  constexpr bool blocked = detail::isBlocked(props);
  using BlkInfo = BlockInfo<IteratorT, ElementsPerWorkItem, blocked>;

#if defined(__SPIR__)
  // TODO: What about non-Intel SPIR-V devices?
  constexpr bool is_spir = true;
#else
  constexpr bool is_spir = false;
#endif

  if constexpr (!is_spir || !BlkInfo::has_builtin) {
    // Not compiling for SPIR-V, or the requested shape has no builtin.
    return nullptr;
  } else if constexpr (!props.template has_property<full_group_key>()) {
    // NOTE(review): builtin path is gated on the full_group property —
    // presumably the block builtins require all work-items to participate.
    return nullptr;
  } else if constexpr (detail::is_multi_ptr_v<IteratorT>) {
    // Unwrap multi_ptr to its decorated raw pointer and re-run the checks.
    return get_block_op_ptr<RequiredAlign, ElementsPerWorkItem>(
        iter.get_decorated(), props);
  } else if constexpr (!std::is_pointer_v<iter_no_cv>) {
    // Generic iterator: only treat it as a pointer when the user asserted
    // contiguous memory; otherwise fall back.
    if constexpr (props.template has_property<contiguous_memory_key>())
      return get_block_op_ptr<RequiredAlign, ElementsPerWorkItem>(&*iter,
                                                                  props);
    else
      return nullptr;
  } else {
    __builtin_assume(iter != nullptr);
    static_assert(BlkInfo::has_builtin);
    // Alignment may be provable at compile time from the element type's
    // natural alignment; otherwise check the actual address at runtime.
    bool aligned = alignof(value_type) >= RequiredAlign ||
                   reinterpret_cast<uintptr_t>(iter) % RequiredAlign == 0;

    constexpr auto AS = detail::deduce_AS<iter_no_cv>::value;
    using block_pointer_type =
        typename BlockTypeInfo<BlkInfo>::block_pointer_type;
    if constexpr (AS == access::address_space::global_space) {
      return aligned ? reinterpret_cast<block_pointer_type>(iter) : nullptr;
    } else if constexpr (AS == access::address_space::generic_space) {
      // Generic pointer: dynamically cast to global; the cast itself yields
      // null when the pointer doesn't refer to global memory.
      return aligned ? reinterpret_cast<block_pointer_type>(
                           __SYCL_GenericCastToPtrExplicit_ToGlobal<value_type>(
                               iter))
                     : nullptr;
    } else {
      // Local/private address spaces: builtins operate on global memory only.
      return nullptr;
    }
  }
}
} // namespace detail

// Load API span overload.
Expand All @@ -117,17 +229,66 @@ std::enable_if_t<detail::verify_load_types<InputIteratorT, OutputT> &&
group_load(Group g, InputIteratorT in_ptr,
span<OutputT, ElementsPerWorkItem> out, Properties props = {}) {
constexpr bool blocked = detail::isBlocked(props);
using use_naive =
detail::merged_properties_t<Properties,
decltype(properties(detail::naive))>;

if constexpr (props.template has_property<detail::naive_key>()) {
group_barrier(g);
for (int i = 0; i < out.size(); ++i)
out[i] = in_ptr[detail::get_mem_idx<blocked, ElementsPerWorkItem>(g, i)];
group_barrier(g);
} else {
using use_naive =
detail::merged_properties_t<Properties,
decltype(properties(detail::naive))>;
return;
} else if constexpr (!std::is_same_v<Group, sycl::sub_group>) {
return group_load(g, in_ptr, out, use_naive{});
} else {
auto ptr =
detail::get_block_op_ptr<4 /* load align */, ElementsPerWorkItem>(
in_ptr, props);
if (!ptr)
return group_load(g, in_ptr, out, use_naive{});

if constexpr (!std::is_same_v<std::nullptr_t, decltype(ptr)>) {
// Do optimized load.
using value_type = remove_decoration_t<
typename std::iterator_traits<InputIteratorT>::value_type>;

auto load = __spirv_SubgroupBlockReadINTEL<
typename detail::BlockTypeInfo<detail::BlockInfo<
InputIteratorT, ElementsPerWorkItem, blocked>>::block_load_type>(
ptr);

// TODO: accessor_iterator's value_type is weird, so we need
// `std::remove_const_t` below:
//
// static_assert(
// std::is_same_v<
// typename std::iterator_traits<
// sycl::detail::accessor_iterator<const int, 1>>::value_type,
// const int>);
//
// yet
//
// static_assert(
// std::is_same_v<
// typename std::iterator_traits<const int *>::value_type, int>);

if constexpr (std::is_same_v<std::remove_const_t<value_type>, OutputT>) {
static_assert(sizeof(load) == out.size_bytes());
std::memcpy(out.begin(), &load, out.size_bytes());
} else {
std::remove_const_t<value_type> values[ElementsPerWorkItem];
static_assert(sizeof(load) == sizeof(values));
std::memcpy(values, &load, sizeof(values));

// Note: can't `memcpy` directly into `out` because that might bypass
// an implicit conversion required by the specification.
for (int i = 0; i < ElementsPerWorkItem; ++i)
out[i] = values[i];
}

return;
}
}
}

Expand Down
Loading

0 comments on commit 98e088f

Please sign in to comment.