From b26e97871e95b6ce5e960eb50c43f9f9c5deccbd Mon Sep 17 00:00:00 2001 From: Roman Lebedev Date: Wed, 18 May 2022 15:58:13 +0300 Subject: [PATCH] Narrowing/widening the Func along a dimension As discussed, this implements a utility to make it easier to write code that wants to essentially operate either on smaller chunks than the whole type, or operate on several consecutive elements as a large element. The bigger picture is that it is desirable to perform load widening in such situations, and while this doesn't do that, at least having a common interface should be a step in that direction. I've taken the liberty to add/expose some QoL variable-bit-length operations in IROperator, while there. I believe, this has sufficient test coverage now. While implementing, I stumbled upon & fixed https://github.com/halide/Halide/pull/6782 Refs. #6756 --- python_bindings/src/halide/CMakeLists.txt | 1 + .../src/halide/halide_/PyFuncTypeChanging.cpp | 66 ++++ .../src/halide/halide_/PyFuncTypeChanging.h | 14 + .../src/halide/halide_/PyHalide.cpp | 2 + .../src/halide/halide_/PyIROperator.cpp | 4 + src/BoundaryConditions.h | 34 +- src/CMakeLists.txt | 3 + src/FuncExtras.h | 24 ++ src/FuncTypeChanging.cpp | 130 +++++++ src/FuncTypeChanging.h | 53 +++ src/IROperator.cpp | 108 ++++++ src/IROperator.h | 40 +++ test/correctness/CMakeLists.txt | 2 + test/correctness/func_type_changing.cpp | 307 +++++++++++++++++ test/correctness/variable_bit_length_ops.cpp | 317 ++++++++++++++++++ 15 files changed, 1084 insertions(+), 21 deletions(-) create mode 100644 python_bindings/src/halide/halide_/PyFuncTypeChanging.cpp create mode 100644 python_bindings/src/halide/halide_/PyFuncTypeChanging.h create mode 100644 src/FuncExtras.h create mode 100644 src/FuncTypeChanging.cpp create mode 100644 src/FuncTypeChanging.h create mode 100644 test/correctness/func_type_changing.cpp create mode 100644 test/correctness/variable_bit_length_ops.cpp diff --git a/python_bindings/src/halide/CMakeLists.txt
b/python_bindings/src/halide/CMakeLists.txt index bea2d7b7b885..c3579cc96b3f 100644 --- a/python_bindings/src/halide/CMakeLists.txt +++ b/python_bindings/src/halide/CMakeLists.txt @@ -11,6 +11,7 @@ set(SOURCES PyExternFuncArgument.cpp PyFunc.cpp PyFuncRef.cpp + PyFuncTypeChanging.cpp PyHalide.cpp PyImageParam.cpp PyInlineReductions.cpp diff --git a/python_bindings/src/halide/halide_/PyFuncTypeChanging.cpp b/python_bindings/src/halide/halide_/PyFuncTypeChanging.cpp new file mode 100644 index 000000000000..93b1ac2063e2 --- /dev/null +++ b/python_bindings/src/halide/halide_/PyFuncTypeChanging.cpp @@ -0,0 +1,66 @@ +#include "PyFuncTypeChanging.h" + +namespace Halide { +namespace PythonBindings { + +namespace { + +inline Func to_func(const Buffer<> &b) { + return lambda(_, b(_)); +} + +} // namespace + +void define_func_type_changing(py::module &m) { + using namespace FuncTypeChanging; + + py::module bc = m.def_submodule("FuncTypeChanging"); + + py::enum_(bc, "ArgumentKind") + .value("LowestFirst", ChunkOrder::LowestFirst) + .value("HighestFirst", ChunkOrder::HighestFirst) + .value("Default", ChunkOrder::Default); + + bc.def( + "change_type", + [](const ImageParam &im, const Type &dst_type, const Var &dim, + const std::string &name, ChunkOrder chunk_order) -> Func { + return change_type(im, dst_type, dim, name, chunk_order); + }, + py::arg("f"), py::arg("dst_type"), py::arg("dim"), py::arg("name"), + py::arg("chunk_order")); + + bc.def( + "change_type", + [](const Buffer<> &b, const Type &dst_type, const Var &dim, + const std::string &name, ChunkOrder chunk_order) -> Func { + return change_type(b, dst_type, dim, name, chunk_order); + }, + py::arg("f"), py::arg("dst_type"), py::arg("dim"), py::arg("name"), + py::arg("chunk_order")); + + bc.def( + "change_type", + [](const py::object &target, const Type &dst_type, const Var &dim, + const std::string &name, ChunkOrder chunk_order) -> Func { + try { + return change_type(target.cast(), dst_type, dim, name, + chunk_order); + 
} catch (...) { + // fall thru + } + try { + return change_type(to_func(target.cast>()), dst_type, + dim, name, chunk_order); + } catch (...) { + // fall thru + } + throw py::value_error("Invalid arguments to change_type"); + return Func(); + }, + py::arg("f"), py::arg("dst_type"), py::arg("dim"), py::arg("name"), + py::arg("chunk_order")); +} + +} // namespace PythonBindings +} // namespace Halide diff --git a/python_bindings/src/halide/halide_/PyFuncTypeChanging.h b/python_bindings/src/halide/halide_/PyFuncTypeChanging.h new file mode 100644 index 000000000000..6ce53de1819c --- /dev/null +++ b/python_bindings/src/halide/halide_/PyFuncTypeChanging.h @@ -0,0 +1,14 @@ +#ifndef HALIDE_PYTHON_BINDINGS_PYFUNCTYPECHANGING_H +#define HALIDE_PYTHON_BINDINGS_PYFUNCTYPECHANGING_H + +#include "PyHalide.h" + +namespace Halide { +namespace PythonBindings { + +void define_func_type_changing(py::module &m); + +} // namespace PythonBindings +} // namespace Halide + +#endif // HALIDE_PYTHON_BINDINGS_PYFUNCTYPECHANGING_H diff --git a/python_bindings/src/halide/halide_/PyHalide.cpp b/python_bindings/src/halide/halide_/PyHalide.cpp index d64ca3f9876b..24b46d32d20b 100644 --- a/python_bindings/src/halide/halide_/PyHalide.cpp +++ b/python_bindings/src/halide/halide_/PyHalide.cpp @@ -11,6 +11,7 @@ #include "PyExpr.h" #include "PyExternFuncArgument.h" #include "PyFunc.h" +#include "PyFuncTypeChanging.h" #include "PyIROperator.h" #include "PyImageParam.h" #include "PyInlineReductions.h" @@ -49,6 +50,7 @@ PYBIND11_MODULE(HALIDE_PYBIND_MODULE_NAME, m) { define_tuple(m); define_argument(m); define_boundary_conditions(m); + define_func_type_changing(m); define_buffer(m); define_concise_casts(m); define_error(m); diff --git a/python_bindings/src/halide/halide_/PyIROperator.cpp b/python_bindings/src/halide/halide_/PyIROperator.cpp index 74eedd9e723e..df5b3d0feb91 100644 --- a/python_bindings/src/halide/halide_/PyIROperator.cpp +++ b/python_bindings/src/halide/halide_/PyIROperator.cpp @@ -167,6 
+167,10 @@ void define_operators(py::module &m) { m.def("popcount", &popcount); m.def("count_leading_zeros", &count_leading_zeros); m.def("count_trailing_zeros", &count_trailing_zeros); + m.def("extract_high_bits", &extract_high_bits); + m.def("variable_length_extend", &variable_length_extend); + m.def("extract_bits", &extract_bits); + m.def("extract_low_bits", &extract_low_bits); m.def("div_round_to_zero", &div_round_to_zero); m.def("mod_round_to_zero", &mod_round_to_zero); m.def("random_float", (Expr(*)()) & random_float); diff --git a/src/BoundaryConditions.h b/src/BoundaryConditions.h index a01cc14d42bb..0ce50a557c49 100644 --- a/src/BoundaryConditions.h +++ b/src/BoundaryConditions.h @@ -9,6 +9,7 @@ #include "Expr.h" #include "Func.h" +#include "FuncExtras.h" #include "Lambda.h" namespace Halide { @@ -62,15 +63,6 @@ inline HALIDE_NO_USER_CODE_INLINE void collect_region(Region &collected_args, collect_region(collected_args, std::forward(args)...); } -inline const Func &func_like_to_func(const Func &func) { - return func; -} - -template -inline HALIDE_NO_USER_CODE_INLINE Func func_like_to_func(const T &func_like) { - return lambda(_, func_like(_)); -} - } // namespace Internal /** Impose a boundary condition such that a given expression is returned @@ -99,12 +91,12 @@ Func constant_exterior(const Func &source, const Expr &value, template HALIDE_NO_USER_CODE_INLINE Func constant_exterior(const T &func_like, const Tuple &value, const Region &bounds) { - return constant_exterior(Internal::func_like_to_func(func_like), value, bounds); + return constant_exterior(::Halide::Internal::func_like_to_func(func_like), value, bounds); } template HALIDE_NO_USER_CODE_INLINE Func constant_exterior(const T &func_like, const Expr &value, const Region &bounds) { - return constant_exterior(Internal::func_like_to_func(func_like), value, bounds); + return constant_exterior(::Halide::Internal::func_like_to_func(func_like), value, bounds); } template @@ -114,7 +106,7 @@ 
HALIDE_NO_USER_CODE_INLINE Func constant_exterior(const T &func_like, const Tupl object_bounds.push_back({Expr(func_like.dim(i).min()), Expr(func_like.dim(i).extent())}); } - return constant_exterior(Internal::func_like_to_func(func_like), value, object_bounds); + return constant_exterior(::Halide::Internal::func_like_to_func(func_like), value, object_bounds); } template HALIDE_NO_USER_CODE_INLINE Func constant_exterior(const T &func_like, const Expr &value) { @@ -127,7 +119,7 @@ HALIDE_NO_USER_CODE_INLINE Func constant_exterior(const T &func_like, const Tupl Bounds &&...bounds) { Region collected_bounds; Internal::collect_region(collected_bounds, std::forward(bounds)...); - return constant_exterior(Internal::func_like_to_func(func_like), value, collected_bounds); + return constant_exterior(::Halide::Internal::func_like_to_func(func_like), value, collected_bounds); } template::value>::type * = nullptr> @@ -154,7 +146,7 @@ Func repeat_edge(const Func &source, const Region &bounds); template HALIDE_NO_USER_CODE_INLINE Func repeat_edge(const T &func_like, const Region &bounds) { - return repeat_edge(Internal::func_like_to_func(func_like), bounds); + return repeat_edge(::Halide::Internal::func_like_to_func(func_like), bounds); } template @@ -164,7 +156,7 @@ HALIDE_NO_USER_CODE_INLINE Func repeat_edge(const T &func_like) { object_bounds.push_back({Expr(func_like.dim(i).min()), Expr(func_like.dim(i).extent())}); } - return repeat_edge(Internal::func_like_to_func(func_like), object_bounds); + return repeat_edge(::Halide::Internal::func_like_to_func(func_like), object_bounds); } // @} @@ -185,7 +177,7 @@ Func repeat_image(const Func &source, const Region &bounds); template HALIDE_NO_USER_CODE_INLINE Func repeat_image(const T &func_like, const Region &bounds) { - return repeat_image(Internal::func_like_to_func(func_like), bounds); + return repeat_image(::Halide::Internal::func_like_to_func(func_like), bounds); } template @@ -195,7 +187,7 @@ HALIDE_NO_USER_CODE_INLINE Func 
repeat_image(const T &func_like) { object_bounds.push_back({Expr(func_like.dim(i).min()), Expr(func_like.dim(i).extent())}); } - return repeat_image(Internal::func_like_to_func(func_like), object_bounds); + return repeat_image(::Halide::Internal::func_like_to_func(func_like), object_bounds); } /** Impose a boundary condition such that the entire coordinate space is @@ -216,7 +208,7 @@ Func mirror_image(const Func &source, const Region &bounds); template HALIDE_NO_USER_CODE_INLINE Func mirror_image(const T &func_like, const Region &bounds) { - return mirror_image(Internal::func_like_to_func(func_like), bounds); + return mirror_image(::Halide::Internal::func_like_to_func(func_like), bounds); } template @@ -226,7 +218,7 @@ HALIDE_NO_USER_CODE_INLINE Func mirror_image(const T &func_like) { object_bounds.push_back({Expr(func_like.dim(i).min()), Expr(func_like.dim(i).extent())}); } - return mirror_image(Internal::func_like_to_func(func_like), object_bounds); + return mirror_image(::Halide::Internal::func_like_to_func(func_like), object_bounds); } // @} @@ -251,7 +243,7 @@ Func mirror_interior(const Func &source, const Region &bounds); template HALIDE_NO_USER_CODE_INLINE Func mirror_interior(const T &func_like, const Region &bounds) { - return mirror_interior(Internal::func_like_to_func(func_like), bounds); + return mirror_interior(::Halide::Internal::func_like_to_func(func_like), bounds); } template @@ -261,7 +253,7 @@ HALIDE_NO_USER_CODE_INLINE Func mirror_interior(const T &func_like) { object_bounds.push_back({Expr(func_like.dim(i).min()), Expr(func_like.dim(i).extent())}); } - return mirror_interior(Internal::func_like_to_func(func_like), object_bounds); + return mirror_interior(::Halide::Internal::func_like_to_func(func_like), object_bounds); } // @} diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index c31e37c32a20..7ed887df893d 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -70,6 +70,8 @@ set(HEADER_FILES FlattenNestedRamps.h Float16.h Func.h + 
FuncExtras.h + FuncTypeChanging.h Function.h FunctionPtr.h FuseGPUThreadLoops.h @@ -234,6 +236,7 @@ set(SOURCE_FILES FlattenNestedRamps.cpp Float16.cpp Func.cpp + FuncTypeChanging.cpp Function.cpp FuseGPUThreadLoops.cpp FuzzFloatStores.cpp diff --git a/src/FuncExtras.h b/src/FuncExtras.h new file mode 100644 index 000000000000..92f92cdd4b62 --- /dev/null +++ b/src/FuncExtras.h @@ -0,0 +1,24 @@ +#ifndef HALIDE_FUNC_EXTRAS_H +#define HALIDE_FUNC_EXTRAS_H + +#include "Func.h" +#include "Lambda.h" + +namespace Halide { + +namespace Internal { + +inline const Func &func_like_to_func(const Func &func) { + return func; +} + +template +inline HALIDE_NO_USER_CODE_INLINE Func func_like_to_func(const T &func_like) { + return lambda(_, func_like(_)); +} + +} // namespace Internal + +} // namespace Halide + +#endif diff --git a/src/FuncTypeChanging.cpp b/src/FuncTypeChanging.cpp new file mode 100644 index 000000000000..be843203dae0 --- /dev/null +++ b/src/FuncTypeChanging.cpp @@ -0,0 +1,130 @@ +#include "FuncTypeChanging.h" + +namespace Halide { + +static bool operator==(const Var &a, const Var &b) { + return a.same_as(b); +} + +namespace FuncTypeChanging { + +// NOTE: Precondition: `chunk_idx u< num_chunks`! +static Expr get_nth_chunk(const Expr &value, const Type &chunk_type, + const Expr &chunk_idx, ChunkOrder chunk_order) { + int num_bits_total = value.type().bits(); + int num_bits_per_chunk = chunk_type.bits(); + int num_chunks = num_bits_total / num_bits_per_chunk; + user_assert(num_bits_total > num_bits_per_chunk && + num_bits_total % num_bits_per_chunk == 0 && num_chunks > 1) + << "Input value must evenly partition into several chunks.\n"; + + Expr low_chunk_idx = chunk_order == ChunkOrder::LowestFirst ? 
+ chunk_idx : + (num_chunks - 1) - chunk_idx; + Expr num_low_padding_bits = num_bits_per_chunk * low_chunk_idx; + Expr chunk_bits = extract_bits(value, num_low_padding_bits, + make_unsigned(num_bits_per_chunk)); + return cast(chunk_type, chunk_bits); +} + +static Expr concatenate_chunks(std::vector chunks, + ChunkOrder chunk_order) { + const Type chunk_type = chunks.front().type(); + const int chunk_width = chunk_type.bits(); + Type final_type = chunk_type.with_bits(chunk_width * chunks.size()); + + if (chunk_order != ChunkOrder::LowestFirst) { + std::reverse(std::begin(chunks), std::end(chunks)); + } + + Expr res = Internal::make_zero(final_type); + for (size_t chunk_idx = 0; chunk_idx != chunks.size(); ++chunk_idx) { + Expr wide_chunk = cast(final_type, chunks[chunk_idx]); // zero ext + Expr positioned_chunk = wide_chunk << (chunk_width * chunk_idx); + res = res | positioned_chunk; + } + + return res; +} + +static Func narrow(const Func &wide_input, const Type &dst_type, int num_chunks, + const Var &dim, const std::string &name, + ChunkOrder chunk_order) { + const std::vector dims = wide_input.args(); + user_assert(count(begin(dims), end(dims), dim) == 1) + << "Expected dimension " << dim << " to represent " + << "exactly one function argument!\n"; + + Expr wide_elt_idx = dim / num_chunks; + Expr chunk_idx = make_unsigned(dim % num_chunks); + + std::vector args; + args.reserve(dims.size()); + std::transform(dims.begin(), dims.end(), std::back_inserter(args), + [dim, wide_elt_idx](const Var &input_dim) { + return input_dim.same_as(dim) ? 
wide_elt_idx : input_dim; + }); + + Func narrowed(name); + narrowed(dims) = + get_nth_chunk(wide_input(args), dst_type, chunk_idx, chunk_order); + + return narrowed; +} + +static Func widen(const Func &narrow_input, const Type &dst_type, + int num_chunks, const Var &dim, const std::string &name, + ChunkOrder chunk_order) { + const std::vector dims = narrow_input.args(); + user_assert(count(begin(dims), end(dims), dim) == 1) + << "Expected dimension " << dim << " to represent " + << "exactly one function argument!\n"; + + auto dim_index = + std::distance(begin(dims), std::find(begin(dims), end(dims), dim)); + + std::vector baseline_args; + baseline_args.reserve(dims.size()); + std::transform(dims.begin(), dims.end(), std::back_inserter(baseline_args), + [](const Var &input_dim) { return input_dim; }); + + std::vector chunks; + chunks.reserve(num_chunks); + std::generate_n( + std::back_inserter(chunks), num_chunks, + [&chunks, baseline_args, dim_index, num_chunks, dim, narrow_input]() { + int chunk_idx = chunks.size(); + std::vector args = baseline_args; + args[dim_index] = (num_chunks * dim) + chunk_idx; + return narrow_input(args); + }); + + Func widened(name); + widened(dims) = concatenate_chunks(chunks, chunk_order); + + return widened; +} + +Func change_type(const Func &input, const Type &dst_type, const Var &dim, + const std::string &name, ChunkOrder chunk_order) { + const Type &src_type = input.output_type(); + int src_width = src_type.bits(); + int dst_width = dst_type.bits(); + bool is_widening = dst_width > src_width; + auto [min_width, max_width] = std::minmax(src_width, dst_width); + int num_chunks = max_width / min_width; + user_assert(src_type.with_bits(dst_width) == dst_type && + src_type.is_uint() && src_width != dst_width && + max_width % min_width == 0 && num_chunks > 1) + << "The source type " << src_type << " and destination type " + << dst_type << " must be similar uint types with different widths, " + << "larger width must be an integral 
multiple of the smaller width.\n"; + + return is_widening ? + widen(input, dst_type, num_chunks, dim, name, chunk_order) : + narrow(input, dst_type, num_chunks, dim, name, chunk_order); +} + +} // namespace FuncTypeChanging + +} // namespace Halide diff --git a/src/FuncTypeChanging.h b/src/FuncTypeChanging.h new file mode 100644 index 000000000000..7285e75469d7 --- /dev/null +++ b/src/FuncTypeChanging.h @@ -0,0 +1,53 @@ +#ifndef HALIDE_FUNC_TYPE_CHANGING_H +#define HALIDE_FUNC_TYPE_CHANGING_H + +/** \file + * Support for changing the function's return type by fusing a number of + * consequtive elements, or splitting a single element into parts, + * along a certain dimension. + */ + +#include "Func.h" +#include "FuncExtras.h" + +namespace Halide { + +namespace FuncTypeChanging { + +enum class ChunkOrder { + // Example: + // i32 0x0D0C0B0A -> 4xi8 -> { 0x0A, 0x0B, 0x0C, 0x0D } + // i32 0x0D0C0B0A -> 2xi16 -> { 0x0B0A, 0x0D0C } + // 4xi8 { 0x0A, 0x0B, 0x0C, 0x0D } -> i32 -> 0x0D0C0B0A + // 2xi16 { 0x0B0A, 0x0D0C } -> i32 -> 0x0D0C0B0A + // 2xi16 { 0x0D0C, 0x0B0A } -> i32 -> 0x0B0A0D0C + LowestFirst, + + // Example: + // i32 0x0D0C0B0A -> 4xi8 -> { 0x0D, 0x0C, 0x0B, 0x0A } + // i32 0x0D0C0B0A -> 2xi16t -> { 0x0D0C, 0x0B0A } + // 4xi8 { 0x0A, 0x0B, 0x0C, 0x0D } -> i32 -> 0x0A0B0C0D + // 2xi16 { 0x0B0A, 0x0D0C } -> i32 -> 0x0B0A0D0C + // 2xi16 { 0x0D0C, 0x0B0A } -> i32 -> 0x0D0C0B0A + HighestFirst, + + Default = LowestFirst // DO NOT CHANGE. 
+}; + +Func change_type(const Func &input, const Type &dst_type, const Var &dim, + const std::string &name, + ChunkOrder chunk_order = ChunkOrder::Default); + +template +HALIDE_NO_USER_CODE_INLINE Func change_type( + const T &func_like, const Type &dst_type, const Var &dim, + const std::string &name, ChunkOrder chunk_order = ChunkOrder::Default) { + return change_type(Internal::func_like_to_func(func_like), dst_type, dim, + name, chunk_order); +} + +} // namespace FuncTypeChanging + +} // namespace Halide + +#endif diff --git a/src/IROperator.cpp b/src/IROperator.cpp index 4693060a8d45..eedf62f4954b 100644 --- a/src/IROperator.cpp +++ b/src/IROperator.cpp @@ -1941,6 +1941,21 @@ Expr cast(Type t, Expr a) { return Internal::Cast::make(t, std::move(a)); } +Expr make_unsigned(Expr a) { + user_assert(a.defined()) << "cast of undefined Expr\n"; + + Type ty = a.type(); + + user_assert(ty.is_int_or_uint()) + << "The input must be of an integer type, got " << ty << "\n"; + + if (ty.is_uint()) { + return a; + } + + return cast(ty.with_code(halide_type_uint), a); +} + Expr clamp(Expr a, const Expr &min_val, const Expr &max_val) { user_assert(a.defined() && min_val.defined() && max_val.defined()) << "clamp of undefined Expr\n"; @@ -2531,6 +2546,99 @@ Expr count_trailing_zeros(Expr x) { {std::move(x)}, Internal::Call::PureIntrinsic); } +// FIXME: this only handles the sane case when at least one, +// but not more than all, bits are extracted. +// Do we need to handle the general case instead? 
+Expr extract_high_bits(const Expr &val, const Expr &num_high_bits) { + user_assert(val.defined()) << "extract_high_bits with undefined val"; + user_assert(num_high_bits.defined()) + << "extract_high_bits with undefined num_high_bits"; + + Type ty = val.type(); + user_assert(ty.is_int_or_uint()) + << "extract_high_bits: val must be of an integer type, got " << ty + << "\n"; + + Type shamt_ty = num_high_bits.type(); + user_assert(shamt_ty.is_uint()) << "extract_high_bits: num_high_bits must " + "be of an unsigned integer type, got " + << shamt_ty << "\n"; + + Expr num_low_padding_bits = ty.bits() - num_high_bits; + // The sign bit is already positioned, just perform the right-shift. + // We'll either pad with zeros (if uint) or replicate sign bit (if int). + return val >> cast(UInt(ty.bits()), num_low_padding_bits); +} + +// FIXME: this only handles the sane case when at least one, +// but not more than all, bits are extracted. +// Do we need to handle the general case instead? +Expr variable_length_extend(Expr val, const Expr &num_low_bits) { + user_assert(val.defined()) << "variable_length_extend with undefined val"; + user_assert(num_low_bits.defined()) + << "variable_length_extend with undefined num_low_bits"; + + Type ty = val.type(); + user_assert(ty.is_int_or_uint()) + << "variable_length_extend: val must be of an integer type, got " << ty + << "\n"; + + Type shamt_ty = num_low_bits.type(); + user_assert(shamt_ty.is_uint()) + << "variable_length_extend: num_low_bits must be of an unsigned " + "integer type, got " + << shamt_ty << "\n"; + + Expr num_high_padding_bits = ty.bits() - num_low_bits; + // First, left-shift the variable-sized input so that it's highest (sign) + // bit is positioned in the highest (sign) bit of the containment type. + val = val << cast(UInt(ty.bits()), num_high_padding_bits); + // And then let the `extract_high_bits()` deal with it. 
+ return extract_high_bits(val, /*num_high_bits=*/num_low_bits); +} + +// FIXME: this only handles the sane case when at least one, +// but not more than all, bits are extracted. +// Do we need to handle the general case instead? +Expr extract_bits(Expr val, const Expr &num_low_padding_bits, + const Expr &num_bits) { + user_assert(val.defined()) << "extract_bits with undefined val"; + user_assert(num_low_padding_bits.defined()) + << "extract_bits with undefined num_low_padding_bits"; + user_assert(num_bits.defined()) << "extract_bits with undefined num_bits"; + + Type ty = val.type(); + user_assert(ty.is_int_or_uint()) + << "extract_bits: val must be of an integer type, got " << ty << "\n"; + + Type shamt_ty = num_low_padding_bits.type(); + user_assert(shamt_ty.is_uint()) + << "extract_bits: num_low_padding_bits must be of an unsigned integer " + "type, got " + << shamt_ty << "\n"; + + shamt_ty = num_bits.type(); + user_assert(shamt_ty.is_uint()) + << "extract_bits: num_bits must be of an unsigned integer type, got " + << shamt_ty << "\n"; + + Expr num_high_padding_bits = (ty.bits() - num_low_padding_bits) - num_bits; + // First, left-shift the variable-sized input so that it's highest (sign) + // bit is positioned in the highest (sign) bit of the containment type. + val = val << cast(UInt(ty.bits()), num_high_padding_bits); + // And then let the `extract_high_bits()` deal with it. + return extract_high_bits(val, /*num_high_bits=*/num_bits); +} + +// FIXME: this only handles the sane case when at least one, +// but not more than all, bits are extracted. +// Do we need to handle the general case instead? +Expr extract_low_bits(Expr val, const Expr &num_low_bits) { + // Let `extract_bits()` deal with everything. 
+ return extract_bits(std::move(val), /*num_low_padding_bits=*/Expr(0U), + /*num_bits=*/num_low_bits); +} + Expr div_round_to_zero(Expr x, Expr y) { user_assert(x.defined()) << "div_round_to_zero of undefined dividend\n"; user_assert(y.defined()) << "div_round_to_zero of undefined divisor\n"; diff --git a/src/IROperator.h b/src/IROperator.h index ed0b11bb4fef..581c152058fd 100644 --- a/src/IROperator.h +++ b/src/IROperator.h @@ -394,6 +394,9 @@ inline Expr cast(Expr a) { /** Cast an expression to a new type. */ Expr cast(Type t, Expr a); +/** Cast an integer-typed expression to an unsigned integer type. */ +Expr make_unsigned(Expr a); + /** Return the sum of two expressions, doing any necessary type * coercion using \ref Internal::match_types */ Expr operator+(Expr a, Expr b); @@ -1208,6 +1211,43 @@ Expr count_leading_zeros(Expr x); * zero, the result is the number of bits in the type. */ Expr count_trailing_zeros(Expr x); +/** + * Extract \p num_high_bits high bits of \p val. + * (with either zero-extension, if \p val is uint, + * or sign-extension, if \p val is int) + * + * NOTE: Precondition: `num_high_bits != 0 && num_high_bits u<= bitwidth(val)`! + */ +Expr extract_high_bits(const Expr &val, const Expr &num_high_bits); + +/** + * Extend (either zero-extend if \p val is uint or sign-extend, if \p is int) + * low \p num_low_bits bits of \p val to the whole \p val underlying type. + * + * NOTE: Precondition: `num_low_bits != 0 && num_low_bits u<= bitwidth(val)`! + */ +Expr variable_length_extend(Expr val, const Expr &num_low_bits); + +/** + * Extract \p num_bits bits starting with the \p low_bit_offset bit. + * (with either zero-extension, if \p val is uint, + * or sign-extension, if \p val is int) + * + * NOTE: Precondition: + * `num_bits != 0 && (num_bits + low_bit_offset) u<= bitwidth(val)`! + */ +Expr extract_bits(Expr val, const Expr &num_low_padding_bits, + const Expr &num_bits); + +/** + * Extract \p num_high_bits high bits of \p val. 
+ * (with either zero-extension, if \p val is uint, + * or sign-extension, if \p val is int) + * + * NOTE: Precondition: `num_low_bits != 0 && num_low_bits u<= bitwidth(val)`! + */ +Expr extract_low_bits(Expr val, const Expr &num_low_bits); + /** Divide two integers, rounding towards zero. This is the typical * behavior of most hardware architectures, which differs from * Halide's division operator, which is Euclidean (rounds towards diff --git a/test/correctness/CMakeLists.txt b/test/correctness/CMakeLists.txt index 62a93e25a982..79e08a8dc2d7 100644 --- a/test/correctness/CMakeLists.txt +++ b/test/correctness/CMakeLists.txt @@ -108,6 +108,7 @@ tests(GROUPS correctness force_onto_stack.cpp func_lifetime.cpp func_lifetime_2.cpp + func_type_changing.cpp fuse.cpp fuse_gpu_threads.cpp fused_where_inner_extent_is_zero.cpp @@ -310,6 +311,7 @@ tests(GROUPS correctness unsafe_promises.cpp unused_func.cpp update_chunk.cpp + variable_bit_length_ops.cpp vector_bounds_inference.cpp vector_cast.cpp vector_extern.cpp diff --git a/test/correctness/func_type_changing.cpp b/test/correctness/func_type_changing.cpp new file mode 100644 index 000000000000..1c752f659a42 --- /dev/null +++ b/test/correctness/func_type_changing.cpp @@ -0,0 +1,307 @@ +#include "Halide.h" +#include +#include + +#include + +using namespace Halide; +using namespace Halide::FuncTypeChanging; + +template +bool expect_eq(Buffer actual, Buffer expected) { + bool eq = true; + expected.for_each_value( + [&](const T &expected_val, const T &actual_val) { + if (actual_val != expected_val) { + eq = false; + fprintf(stderr, "Failed: expected %d, actual %d\n", + (int)expected_val, (int)actual_val); + } + }, + actual); + return eq; +} + +template +auto gen_random_chunks(std::initializer_list dims) { + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution dist( + std::numeric_limits::min(), + std::numeric_limits::max()); + + Buffer buf(dims); + buf.for_each_value([&](CHUNK_TYPE &v) { v = 
dist(gen); }); + + return buf; +} + +template +bool test_1d_rowwise_with_n_times_chunk_type(int num_chunks, Target t) { + const int width = 256; + Buffer input_buf = gen_random_chunks({width}); + + using WIDE_STORAGE_TYPE = uint64_t; + const int CHUNK_WIDTH = 8 * sizeof(CHUNK_TYPE); + const int WIDE_TYPE_WIDTH = CHUNK_WIDTH * num_chunks; + + Var x("x"); + + int wide_width = width / num_chunks; + + auto forward = [wide_width, WIDE_TYPE_WIDTH, x, t](const Func &input, + ChunkOrder chunk_order) { + Buffer wide(wide_width); + Func widen = change_type(input, UInt(WIDE_TYPE_WIDTH), x, "widener", + chunk_order); + Func store("store"); + store(x) = cast(widen(x)); + store.realize(wide, t); + return wide; + }; + + auto forward_naive = [wide_width, num_chunks](Buffer input_buf, + ChunkOrder chunk_order) { + Buffer wide(wide_width); + for (int32_t x = 0; x < wide_width; x++) { + WIDE_STORAGE_TYPE &v = wide(x); + v = 0; + for (int chunk = 0; chunk != num_chunks; ++chunk) { + int chunk_idx = chunk_order == ChunkOrder::HighestFirst ? 
+ chunk : + (num_chunks - 1) - chunk; + v <<= CHUNK_WIDTH; + v |= (WIDE_STORAGE_TYPE)input_buf(num_chunks * x + chunk_idx); + } + } + return wide; + }; + + auto backward = [t, width, x, WIDE_TYPE_WIDTH]( + const Buffer &actual_widened_result, + ChunkOrder chunk_order) { + Buffer narrow(width); + Func load("load"); + load(x) = cast(UInt(WIDE_TYPE_WIDTH), actual_widened_result(x)); + Func narrown = + change_type(load, UInt(CHUNK_WIDTH), x, "narrower", chunk_order); + narrown.realize(narrow, t); + return narrow; + }; + + Func input("input"); + input(x) = input_buf(x); + + bool success = true; + for (ChunkOrder chunk_order : + {ChunkOrder::LowestFirst, ChunkOrder::HighestFirst}) { + const auto wide_actual = forward(input, chunk_order); + const auto wide_expected = forward_naive(input_buf, chunk_order); + success &= expect_eq(wide_actual, wide_expected); + + const auto narrow_actual = backward(wide_actual, chunk_order); + success &= expect_eq(narrow_actual, input_buf); + } + + return success; +} + +template +bool test_2d_rowwise_with_n_times_chunk_type(int num_chunks, Target t) { + const int width = 256; + const int height = 16; + Buffer input_buf = + gen_random_chunks({width, height}); + + using WIDE_STORAGE_TYPE = uint64_t; + const int CHUNK_WIDTH = 8 * sizeof(CHUNK_TYPE); + const int WIDE_TYPE_WIDTH = CHUNK_WIDTH * num_chunks; + + Var x("x"), y("y"); + + int wide_width = width / num_chunks; + + auto forward = [wide_width, WIDE_TYPE_WIDTH, x, y, + t](const Func &input, ChunkOrder chunk_order) { + Buffer wide({wide_width, height}); + Func widen = change_type(input, UInt(WIDE_TYPE_WIDTH), x, "widener", + chunk_order); + Func store("store"); + store(x, y) = cast(widen(x, y)); + store.realize(wide, t); + return wide; + }; + + auto forward_naive = [wide_width, num_chunks](Buffer input_buf, + ChunkOrder chunk_order) { + Buffer wide({wide_width, height}); + for (int32_t y = 0; y < height; y++) { + for (int32_t x = 0; x < wide_width; x++) { + WIDE_STORAGE_TYPE &v = wide(x, 
y); + v = 0; + for (int chunk = 0; chunk != num_chunks; ++chunk) { + int chunk_idx = chunk_order == ChunkOrder::HighestFirst ? + chunk : + (num_chunks - 1) - chunk; + v <<= CHUNK_WIDTH; + v |= (WIDE_STORAGE_TYPE)input_buf( + num_chunks * x + chunk_idx, y); + } + } + } + return wide; + }; + + auto backward = [t, width, height, x, y, WIDE_TYPE_WIDTH]( + const Buffer &actual_widened_result, + ChunkOrder chunk_order) { + Buffer narrow({width, height}); + Func load("load"); + load(x, y) = cast(UInt(WIDE_TYPE_WIDTH), actual_widened_result(x, y)); + Func narrown = + change_type(load, UInt(CHUNK_WIDTH), x, "narrower", chunk_order); + narrown.realize(narrow, t); + return narrow; + }; + + Func input("input"); + input(x, y) = input_buf(x, y); + + bool success = true; + for (ChunkOrder chunk_order : + {ChunkOrder::LowestFirst, ChunkOrder::HighestFirst}) { + const auto wide_actual = forward(input, chunk_order); + const auto wide_expected = forward_naive(input_buf, chunk_order); + success &= expect_eq(wide_actual, wide_expected); + + const auto narrow_actual = backward(wide_actual, chunk_order); + success &= expect_eq(narrow_actual, input_buf); + } + + return success; +} + +template +bool test_2d_colwise_with_n_times_chunk_type(int num_chunks, Target t) { + const int width = 16; + const int height = 256; + Buffer input_buf = + gen_random_chunks({width, height}); + + using WIDE_STORAGE_TYPE = uint64_t; + const int CHUNK_WIDTH = 8 * sizeof(CHUNK_TYPE); + const int WIDE_TYPE_WIDTH = CHUNK_WIDTH * num_chunks; + + Var x("x"), y("y"); + + int wide_height = height / num_chunks; + + auto forward = [wide_height, WIDE_TYPE_WIDTH, x, y, + t](const Func &input, ChunkOrder chunk_order) { + Buffer wide({width, wide_height}); + Func widen = change_type(input, UInt(WIDE_TYPE_WIDTH), y, "widener", + chunk_order); + Func store("store"); + store(x, y) = cast(widen(x, y)); + store.realize(wide, t); + return wide; + }; + + auto forward_naive = [wide_height, num_chunks](Buffer input_buf, + ChunkOrder 
chunk_order) { + Buffer wide({width, wide_height}); + for (int32_t y = 0; y < wide_height; y++) { + for (int32_t x = 0; x < width; x++) { + WIDE_STORAGE_TYPE &v = wide(x, y); + v = 0; + for (int chunk = 0; chunk != num_chunks; ++chunk) { + int chunk_idx = chunk_order == ChunkOrder::HighestFirst ? + chunk : + (num_chunks - 1) - chunk; + v <<= CHUNK_WIDTH; + v |= (WIDE_STORAGE_TYPE)input_buf(x, num_chunks * y + + chunk_idx); + } + } + } + return wide; + }; + + auto backward = [t, width, height, x, y, WIDE_TYPE_WIDTH]( + const Buffer &actual_widened_result, + ChunkOrder chunk_order) { + Buffer narrow({width, height}); + Func load("load"); + load(x, y) = cast(UInt(WIDE_TYPE_WIDTH), actual_widened_result(x, y)); + Func narrown = + change_type(load, UInt(CHUNK_WIDTH), y, "narrower", chunk_order); + narrown.realize(narrow, t); + return narrow; + }; + + Func input("input"); + input(x, y) = input_buf(x, y); + + bool success = true; + for (ChunkOrder chunk_order : + {ChunkOrder::LowestFirst, ChunkOrder::HighestFirst}) { + const auto wide_actual = forward(input, chunk_order); + const auto wide_expected = forward_naive(input_buf, chunk_order); + success &= expect_eq(wide_actual, wide_expected); + + const auto narrow_actual = backward(wide_actual, chunk_order); + success &= expect_eq(narrow_actual, input_buf); + } + + return success; +} + +template +bool test_with_n_times_chunk_type(int num_chunks, Target t) { + bool success = true; + + success &= + test_1d_rowwise_with_n_times_chunk_type(num_chunks, t); + success &= + test_2d_rowwise_with_n_times_chunk_type(num_chunks, t); + success &= + test_2d_colwise_with_n_times_chunk_type(num_chunks, t); + + return success; +} + +template +bool test_with_chunk_type(Target t) { + bool success = true; + + const int CHUNK_WIDTH = 8 * sizeof(CHUNK_TYPE); + for (int num_chunks = 2; CHUNK_WIDTH * num_chunks <= 64; num_chunks *= 2) { + success &= test_with_n_times_chunk_type(num_chunks, t); + } + + return success; +} + +bool test_all(Target t) { 
+ bool success = true; + + success &= test_with_chunk_type(t); + success &= test_with_chunk_type(t); + success &= test_with_chunk_type(t); + + return success; +} + +int main(int argc, char **argv) { + Target target = get_jit_target_from_environment(); + + bool success = test_all(target); + + if (!success) { + fprintf(stderr, "Failed!\n"); + return -1; + } + + printf("Success!\n"); + return 0; +} diff --git a/test/correctness/variable_bit_length_ops.cpp b/test/correctness/variable_bit_length_ops.cpp new file mode 100644 index 000000000000..ed6bce2dcb76 --- /dev/null +++ b/test/correctness/variable_bit_length_ops.cpp @@ -0,0 +1,317 @@ +#include "Halide.h" +#include +#include + +#include + +using namespace Halide; +using namespace Halide::FuncTypeChanging; + +template +static constexpr int local_bitwidth() { + return 8 * sizeof(T); +} + +template +static T local_extract_high_bits(T val, int num_high_bits) { + int num_low_padding_bits = local_bitwidth() - num_high_bits; + if ((unsigned)num_low_padding_bits >= (unsigned)local_bitwidth()) { + return 42; + } + // The sign bit is already positioned, just perform the right-shift. + // We'll either pad with zeros (if uint) or replicate sign bit (if int). 
+ assert((unsigned)num_low_padding_bits < (unsigned)local_bitwidth()); + return val >> num_low_padding_bits; +} + +template +static T local_variable_length_extend(T val, int num_low_bits) { + int num_high_padding_bits = local_bitwidth() - num_low_bits; + if ((unsigned)num_high_padding_bits >= (unsigned)local_bitwidth()) { + return 42; + } + // First, left-shift the variable-sized input so that it's highest (sign) + // bit is positioned in the highest (sign) bit of the containment type, + assert((unsigned)num_high_padding_bits < (unsigned)local_bitwidth()); + val <<= num_high_padding_bits; + return local_extract_high_bits(val, /*num_high_bits=*/num_low_bits); +} + +template +static T local_extract_bits(T val, int num_low_padding_bits, int num_bits) { + if (num_bits == 0) { + return 42; + } + int num_high_padding_bits = + (local_bitwidth() - num_low_padding_bits) - num_bits; + if ((unsigned)num_high_padding_bits >= (unsigned)local_bitwidth()) { + return 42; + } + // First, left-shift the variable-sized input so that it's highest (sign) + // bit is positioned in the highest (sign) bit of the containment type, + assert((unsigned)num_high_padding_bits < (unsigned)local_bitwidth()); + val <<= num_high_padding_bits; + return local_extract_high_bits(val, /*num_high_bits=*/num_bits); +} + +template +static T local_extract_low_bits(T val, int num_low_bits) { + return local_extract_bits(val, /*num_low_padding_bits=*/0, num_low_bits); +} + +template +static bool expect_eq(Buffer actual, Buffer expected) { + bool eq = true; + expected.for_each_value( + [&](const T &expected_val, const T &actual_val) { + if (actual_val != expected_val) { + eq = false; + fprintf(stderr, "Failed: expected %d, actual %d\n", + (int)expected_val, (int)actual_val); + } + }, + actual); + return eq; +} + +template +static auto gen_random_input(std::initializer_list dims) { + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution dist(std::numeric_limits::min(), + 
std::numeric_limits::max()); + + Buffer buf(dims); + buf.for_each_value([&](T &v) { v = dist(gen); }); + + return buf; +} + +template +static bool test_extract_high_bits(Target t) { + const int width = 8192; + Buffer input_buf = gen_random_input({width}); + + constexpr int T_BITS = 8 * sizeof(T); + constexpr int MAX_BITS = 2 + T_BITS; + + Var x("x"); + + auto actual = [&]() { + Buffer res(MAX_BITS * width); + Func fun("f"); + Expr input_idx = x / MAX_BITS; + Expr num_high_bits = x % MAX_BITS; + Expr num_low_padding_bits = local_bitwidth() - num_high_bits; + // `extract_high_bits()` is not defined for OOB or 0 num_high_bits. + fun(x) = select(make_unsigned(num_low_padding_bits) >= + make_unsigned(local_bitwidth()), + 42, + extract_high_bits(input_buf(input_idx), + make_unsigned(num_high_bits))); + fun.realize(res, t); + return res; + }; + + auto expected = [&]() { + Buffer res(MAX_BITS * width); + for (int x = 0; x != res.width(); ++x) { + int input_idx = x / MAX_BITS; + int num_high_bits = x % MAX_BITS; + res(x) = + local_extract_high_bits(input_buf(input_idx), num_high_bits); + } + return res; + }; + + bool success = true; + + const auto res_actual = actual(); + const auto res_expected = expected(); + success &= expect_eq(res_actual, res_expected); + + return success; +} + +template +static bool test_variable_length_extend(Target t) { + const int width = 8192; + Buffer input_buf = gen_random_input({width}); + + constexpr int T_BITS = 8 * sizeof(T); + constexpr int MAX_BITS = 2 + T_BITS; + + Var x("x"); + + auto actual = [&]() { + Buffer res(MAX_BITS * width); + Func fun("f"); + Expr input_idx = x / MAX_BITS; + Expr num_low_bits = x % MAX_BITS; + Expr num_high_padding_bits = local_bitwidth() - num_low_bits; + // `variable_length_extend()` is not defined for OOB or 0 num_low_bits. 
+ fun(x) = select(make_unsigned(num_high_padding_bits) >= + make_unsigned(local_bitwidth()), + 42, + variable_length_extend(input_buf(input_idx), + make_unsigned(num_low_bits))); + fun.realize(res, t); + return res; + }; + + auto expected = [&]() { + Buffer res(MAX_BITS * width); + for (int x = 0; x != res.width(); ++x) { + int input_idx = x / MAX_BITS; + int num_low_bits = x % MAX_BITS; + res(x) = local_variable_length_extend(input_buf(input_idx), + num_low_bits); + } + return res; + }; + + bool success = true; + + const auto res_actual = actual(); + const auto res_expected = expected(); + success &= expect_eq(res_actual, res_expected); + + return success; +} + +template +static bool test_extract_bits(Target t) { + const int width = 256; + Buffer input_buf = gen_random_input({width}); + + constexpr int T_BITS = 8 * sizeof(T); + constexpr int MAX_BITS = 2 + T_BITS; + + Var x("x"); + + auto actual = [&]() { + Buffer res((MAX_BITS * MAX_BITS) * width); + Func fun("f"); + Expr input_idx = x / (MAX_BITS * MAX_BITS); + Expr num_low_padding_bits = (x / MAX_BITS) % MAX_BITS; + Expr num_bits = x % MAX_BITS; + Expr num_high_padding_bits = + (local_bitwidth() - num_low_padding_bits) - num_bits; + // `extract_bits()` is not defined for 0 or OOB num_bits. 
+ fun(x) = select(num_bits == 0 || make_unsigned(num_high_padding_bits) >= + make_unsigned(local_bitwidth()), + 42, + extract_bits(input_buf(input_idx), + make_unsigned(num_low_padding_bits), + make_unsigned(num_bits))); + fun.realize(res, t); + return res; + }; + + auto expected = [&]() { + Buffer res((MAX_BITS * MAX_BITS) * width); + for (int x = 0; x != res.width(); ++x) { + int input_idx = x / (MAX_BITS * MAX_BITS); + int num_low_padding_bits = (x / MAX_BITS) % MAX_BITS; + int num_bits = x % MAX_BITS; + res(x) = local_extract_bits(input_buf(input_idx), + num_low_padding_bits, num_bits); + } + return res; + }; + + bool success = true; + + const auto res_actual = actual(); + const auto res_expected = expected(); + success &= expect_eq(res_actual, res_expected); + + return success; +} + +template +static bool test_extract_low_bits(Target t) { + const int width = 8192; + Buffer input_buf = gen_random_input({width}); + + constexpr int T_BITS = 8 * sizeof(T); + constexpr int MAX_BITS = 2 + T_BITS; + + Var x("x"); + + auto actual = [&]() { + Buffer res(MAX_BITS * width); + Func fun("f"); + Expr input_idx = x / MAX_BITS; + Expr num_low_bits = x % MAX_BITS; + Expr num_high_padding_bits = local_bitwidth() - num_low_bits; + // `extract_low_bits()` is not defined for OOB or 0 num_low_bits. 
+ fun(x) = select(make_unsigned(num_high_padding_bits) >= + make_unsigned(local_bitwidth()), + 42, + extract_low_bits(input_buf(input_idx), + make_unsigned(num_low_bits))); + fun.realize(res, t); + return res; + }; + + auto expected = [&]() { + Buffer res(MAX_BITS * width); + for (int x = 0; x != res.width(); ++x) { + int input_idx = x / MAX_BITS; + int num_low_bits = x % MAX_BITS; + res(x) = local_extract_low_bits(input_buf(input_idx), num_low_bits); + } + return res; + }; + + bool success = true; + + const auto res_actual = actual(); + const auto res_expected = expected(); + success &= expect_eq(res_actual, res_expected); + + return success; +} + +template +static bool test_with_type(Target t) { + bool success = true; + + success &= test_extract_high_bits(t); + success &= test_variable_length_extend(t); + success &= test_extract_bits(t); + success &= test_extract_low_bits(t); + + return success; +} + +static bool test_all(Target t) { + bool success = true; + + success &= test_with_type(t); + success &= test_with_type(t); + success &= test_with_type(t); + success &= test_with_type(t); + + success &= test_with_type(t); + success &= test_with_type(t); + success &= test_with_type(t); + success &= test_with_type(t); + + return success; +} + +int main(int argc, char **argv) { + Target target = get_jit_target_from_environment(); + + bool success = test_all(target); + + if (!success) { + fprintf(stderr, "Failed!\n"); + return -1; + } + + printf("Success!\n"); + return 0; +}