diff --git a/python_bindings/src/PyFuncTypeChanging.cpp b/python_bindings/src/PyFuncTypeChanging.cpp new file mode 100644 index 000000000000..93b1ac2063e2 --- /dev/null +++ b/python_bindings/src/PyFuncTypeChanging.cpp @@ -0,0 +1,66 @@ +#include "PyFuncTypeChanging.h" + +namespace Halide { +namespace PythonBindings { + +namespace { + +inline Func to_func(const Buffer<> &b) { + return lambda(_, b(_)); +} + +} // namespace + +void define_func_type_changing(py::module &m) { + using namespace FuncTypeChanging; + + py::module bc = m.def_submodule("FuncTypeChanging"); + + py::enum_(bc, "ArgumentKind") + .value("LowestFirst", ChunkOrder::LowestFirst) + .value("HighestFirst", ChunkOrder::HighestFirst) + .value("Default", ChunkOrder::Default); + + bc.def( + "change_type", + [](const ImageParam &im, const Type &dst_type, const Var &dim, + const std::string &name, ChunkOrder chunk_order) -> Func { + return change_type(im, dst_type, dim, name, chunk_order); + }, + py::arg("f"), py::arg("dst_type"), py::arg("dim"), py::arg("name"), + py::arg("chunk_order")); + + bc.def( + "change_type", + [](const Buffer<> &b, const Type &dst_type, const Var &dim, + const std::string &name, ChunkOrder chunk_order) -> Func { + return change_type(b, dst_type, dim, name, chunk_order); + }, + py::arg("f"), py::arg("dst_type"), py::arg("dim"), py::arg("name"), + py::arg("chunk_order")); + + bc.def( + "change_type", + [](const py::object &target, const Type &dst_type, const Var &dim, + const std::string &name, ChunkOrder chunk_order) -> Func { + try { + return change_type(target.cast(), dst_type, dim, name, + chunk_order); + } catch (...) { + // fall thru + } + try { + return change_type(to_func(target.cast>()), dst_type, + dim, name, chunk_order); + } catch (...) { + // fall thru + } + throw py::value_error("Invalid arguments to change_type"); + return Func(); + }, + py::arg("f"), py::arg("dst_type"), py::arg("dim"), py::arg("name"), + py::arg("chunk_order")); +} + +} // namespace PythonBindings +} // namespace Halide diff --git a/python_bindings/src/PyFuncTypeChanging.h b/python_bindings/src/PyFuncTypeChanging.h new file mode 100644 index 000000000000..6ce53de1819c --- /dev/null +++ b/python_bindings/src/PyFuncTypeChanging.h @@ -0,0 +1,14 @@ +#ifndef HALIDE_PYTHON_BINDINGS_PYFUNCTYPECHANGING_H +#define HALIDE_PYTHON_BINDINGS_PYFUNCTYPECHANGING_H + +#include "PyHalide.h" + +namespace Halide { +namespace PythonBindings { + +void define_func_type_changing(py::module &m); + +} // namespace PythonBindings +} // namespace Halide + +#endif // HALIDE_PYTHON_BINDINGS_PYFUNCTYPECHANGING_H diff --git a/python_bindings/src/halide/CMakeLists.txt b/python_bindings/src/halide/CMakeLists.txt index bea2d7b7b885..c3579cc96b3f 100644 --- a/python_bindings/src/halide/CMakeLists.txt +++ b/python_bindings/src/halide/CMakeLists.txt @@ -11,6 +11,7 @@ set(SOURCES PyExternFuncArgument.cpp PyFunc.cpp PyFuncRef.cpp + PyFuncTypeChanging.cpp PyHalide.cpp PyImageParam.cpp PyInlineReductions.cpp diff --git a/python_bindings/src/halide/halide_/PyHalide.cpp b/python_bindings/src/halide/halide_/PyHalide.cpp index d64ca3f9876b..24b46d32d20b 100644 --- a/python_bindings/src/halide/halide_/PyHalide.cpp +++ b/python_bindings/src/halide/halide_/PyHalide.cpp @@ -11,6 +11,7 @@ #include "PyExpr.h" #include "PyExternFuncArgument.h" #include "PyFunc.h" +#include "PyFuncTypeChanging.h" #include "PyIROperator.h" #include "PyImageParam.h" #include "PyInlineReductions.h" @@ -49,6 +50,7 @@ PYBIND11_MODULE(HALIDE_PYBIND_MODULE_NAME, m) { define_tuple(m); define_argument(m); define_boundary_conditions(m); + define_func_type_changing(m); define_buffer(m); define_concise_casts(m); define_error(m); diff --git a/python_bindings/src/halide/halide_/PyIROperator.cpp b/python_bindings/src/halide/halide_/PyIROperator.cpp index 74eedd9e723e..df5b3d0feb91 100644 --- a/python_bindings/src/halide/halide_/PyIROperator.cpp +++ b/python_bindings/src/halide/halide_/PyIROperator.cpp @@ -167,6 +167,10 @@ void define_operators(py::module &m) { m.def("popcount", &popcount); m.def("count_leading_zeros", &count_leading_zeros); m.def("count_trailing_zeros", &count_trailing_zeros); + m.def("extract_high_bits", &extract_high_bits); + m.def("variable_length_extend", &variable_length_extend); + m.def("extract_bits", &extract_bits); + m.def("extract_low_bits", &extract_low_bits); m.def("div_round_to_zero", &div_round_to_zero); m.def("mod_round_to_zero", &mod_round_to_zero); m.def("random_float", (Expr(*)()) & random_float); diff --git a/src/BoundaryConditions.h b/src/BoundaryConditions.h index a01cc14d42bb..0ce50a557c49 100644 --- a/src/BoundaryConditions.h +++ b/src/BoundaryConditions.h @@ -9,6 +9,7 @@ #include "Expr.h" #include "Func.h" +#include "FuncExtras.h" #include "Lambda.h" namespace Halide { @@ -62,15 +63,6 @@ inline HALIDE_NO_USER_CODE_INLINE void collect_region(Region &collected_args, collect_region(collected_args, std::forward(args)...); } -inline const Func &func_like_to_func(const Func &func) { - return func; -} - -template -inline HALIDE_NO_USER_CODE_INLINE Func func_like_to_func(const T &func_like) { - return lambda(_, func_like(_)); -} - } // namespace Internal /** Impose a boundary condition such that a given expression is returned @@ -99,12 +91,12 @@ Func constant_exterior(const Func &source, const Expr &value, template HALIDE_NO_USER_CODE_INLINE Func constant_exterior(const T &func_like, const Tuple &value, const Region &bounds) { - return constant_exterior(Internal::func_like_to_func(func_like), value, bounds); + return constant_exterior(::Halide::Internal::func_like_to_func(func_like), value, bounds); } template HALIDE_NO_USER_CODE_INLINE Func constant_exterior(const T &func_like, const Expr &value, const Region &bounds) { - return constant_exterior(Internal::func_like_to_func(func_like), value, bounds); + return constant_exterior(::Halide::Internal::func_like_to_func(func_like), value, bounds); } template @@ -114,7 +106,7 @@ HALIDE_NO_USER_CODE_INLINE Func constant_exterior(const T &func_like, const Tupl object_bounds.push_back({Expr(func_like.dim(i).min()), Expr(func_like.dim(i).extent())}); } - return constant_exterior(Internal::func_like_to_func(func_like), value, object_bounds); + return constant_exterior(::Halide::Internal::func_like_to_func(func_like), value, object_bounds); } template HALIDE_NO_USER_CODE_INLINE Func constant_exterior(const T &func_like, const Expr &value) { @@ -127,7 +119,7 @@ HALIDE_NO_USER_CODE_INLINE Func constant_exterior(const T &func_like, const Tupl Bounds &&...bounds) { Region collected_bounds; Internal::collect_region(collected_bounds, std::forward(bounds)...); - return constant_exterior(Internal::func_like_to_func(func_like), value, collected_bounds); + return constant_exterior(::Halide::Internal::func_like_to_func(func_like), value, collected_bounds); } template::value>::type * = nullptr> @@ -154,7 +146,7 @@ Func repeat_edge(const Func &source, const Region &bounds); template HALIDE_NO_USER_CODE_INLINE Func repeat_edge(const T &func_like, const Region &bounds) { - return repeat_edge(Internal::func_like_to_func(func_like), bounds); + return repeat_edge(::Halide::Internal::func_like_to_func(func_like), bounds); } template @@ -164,7 +156,7 @@ HALIDE_NO_USER_CODE_INLINE Func repeat_edge(const T &func_like) { object_bounds.push_back({Expr(func_like.dim(i).min()), Expr(func_like.dim(i).extent())}); } - return repeat_edge(Internal::func_like_to_func(func_like), object_bounds); + return repeat_edge(::Halide::Internal::func_like_to_func(func_like), object_bounds); } // @} @@ -185,7 +177,7 @@ Func repeat_image(const Func &source, const Region &bounds); template HALIDE_NO_USER_CODE_INLINE Func repeat_image(const T &func_like, const Region &bounds) { - return repeat_image(Internal::func_like_to_func(func_like), bounds); + return repeat_image(::Halide::Internal::func_like_to_func(func_like), bounds); } template @@ -195,7 +187,7 @@ HALIDE_NO_USER_CODE_INLINE Func repeat_image(const T &func_like) { object_bounds.push_back({Expr(func_like.dim(i).min()), Expr(func_like.dim(i).extent())}); } - return repeat_image(Internal::func_like_to_func(func_like), object_bounds); + return repeat_image(::Halide::Internal::func_like_to_func(func_like), object_bounds); } /** Impose a boundary condition such that the entire coordinate space is @@ -216,7 +208,7 @@ Func mirror_image(const Func &source, const Region &bounds); template HALIDE_NO_USER_CODE_INLINE Func mirror_image(const T &func_like, const Region &bounds) { - return mirror_image(Internal::func_like_to_func(func_like), bounds); + return mirror_image(::Halide::Internal::func_like_to_func(func_like), bounds); } template @@ -226,7 +218,7 @@ HALIDE_NO_USER_CODE_INLINE Func mirror_image(const T &func_like) { object_bounds.push_back({Expr(func_like.dim(i).min()), Expr(func_like.dim(i).extent())}); } - return mirror_image(Internal::func_like_to_func(func_like), object_bounds); + return mirror_image(::Halide::Internal::func_like_to_func(func_like), object_bounds); } // @} @@ -251,7 +243,7 @@ Func mirror_interior(const Func &source, const Region &bounds); template HALIDE_NO_USER_CODE_INLINE Func mirror_interior(const T &func_like, const Region &bounds) { - return mirror_interior(Internal::func_like_to_func(func_like), bounds); + return mirror_interior(::Halide::Internal::func_like_to_func(func_like), bounds); } template @@ -261,7 +253,7 @@ HALIDE_NO_USER_CODE_INLINE Func mirror_interior(const T &func_like) { object_bounds.push_back({Expr(func_like.dim(i).min()), Expr(func_like.dim(i).extent())}); } - return mirror_interior(Internal::func_like_to_func(func_like), object_bounds); + return mirror_interior(::Halide::Internal::func_like_to_func(func_like), object_bounds); } // @} diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index c31e37c32a20..7ed887df893d 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -70,6 +70,8 @@ set(HEADER_FILES FlattenNestedRamps.h Float16.h Func.h + FuncExtras.h + FuncTypeChanging.h Function.h FunctionPtr.h FuseGPUThreadLoops.h @@ -234,6 +236,7 @@ set(SOURCE_FILES FlattenNestedRamps.cpp Float16.cpp Func.cpp + FuncTypeChanging.cpp Function.cpp FuseGPUThreadLoops.cpp FuzzFloatStores.cpp diff --git a/src/FuncExtras.h b/src/FuncExtras.h new file mode 100644 index 000000000000..92f92cdd4b62 --- /dev/null +++ b/src/FuncExtras.h @@ -0,0 +1,24 @@ +#ifndef HALIDE_FUNC_EXTRAS_H +#define HALIDE_FUNC_EXTRAS_H + +#include "Func.h" +#include "Lambda.h" + +namespace Halide { + +namespace Internal { + +inline const Func &func_like_to_func(const Func &func) { + return func; +} + +template +inline HALIDE_NO_USER_CODE_INLINE Func func_like_to_func(const T &func_like) { + return lambda(_, func_like(_)); +} + +} // namespace Internal + +} // namespace Halide + +#endif diff --git a/src/FuncTypeChanging.cpp b/src/FuncTypeChanging.cpp new file mode 100644 index 000000000000..be843203dae0 --- /dev/null +++ b/src/FuncTypeChanging.cpp @@ -0,0 +1,130 @@ +#include "FuncTypeChanging.h" + +namespace Halide { + +static bool operator==(const Var &a, const Var &b) { + return a.same_as(b); +} + +namespace FuncTypeChanging { + +// NOTE: Precondition: `chunk_idx u< num_chunks`! +static Expr get_nth_chunk(const Expr &value, const Type &chunk_type, + const Expr &chunk_idx, ChunkOrder chunk_order) { + int num_bits_total = value.type().bits(); + int num_bits_per_chunk = chunk_type.bits(); + int num_chunks = num_bits_total / num_bits_per_chunk; + user_assert(num_bits_total > num_bits_per_chunk && + num_bits_total % num_bits_per_chunk == 0 && num_chunks > 1) + << "Input value must evenly partition into several chunks.\n"; + + Expr low_chunk_idx = chunk_order == ChunkOrder::LowestFirst ? + chunk_idx : + (num_chunks - 1) - chunk_idx; + Expr num_low_padding_bits = num_bits_per_chunk * low_chunk_idx; + Expr chunk_bits = extract_bits(value, num_low_padding_bits, + make_unsigned(num_bits_per_chunk)); + return cast(chunk_type, chunk_bits); +} + +static Expr concatenate_chunks(std::vector chunks, + ChunkOrder chunk_order) { + const Type chunk_type = chunks.front().type(); + const int chunk_width = chunk_type.bits(); + Type final_type = chunk_type.with_bits(chunk_width * chunks.size()); + + if (chunk_order != ChunkOrder::LowestFirst) { + std::reverse(std::begin(chunks), std::end(chunks)); + } + + Expr res = Internal::make_zero(final_type); + for (size_t chunk_idx = 0; chunk_idx != chunks.size(); ++chunk_idx) { + Expr wide_chunk = cast(final_type, chunks[chunk_idx]); // zero ext + Expr positioned_chunk = wide_chunk << (chunk_width * chunk_idx); + res = res | positioned_chunk; + } + + return res; +} + +static Func narrow(const Func &wide_input, const Type &dst_type, int num_chunks, + const Var &dim, const std::string &name, + ChunkOrder chunk_order) { + const std::vector dims = wide_input.args(); + user_assert(count(begin(dims), end(dims), dim) == 1) + << "Expected dimension " << dim << " to represent " + << "exactly one function argument!\n"; + + Expr wide_elt_idx = dim / num_chunks; + Expr chunk_idx = make_unsigned(dim % num_chunks); + + std::vector args; + args.reserve(dims.size()); + std::transform(dims.begin(), dims.end(), std::back_inserter(args), + [dim, wide_elt_idx](const Var &input_dim) { + return input_dim.same_as(dim) ? wide_elt_idx : input_dim; + }); + + Func narrowed(name); + narrowed(dims) = + get_nth_chunk(wide_input(args), dst_type, chunk_idx, chunk_order); + + return narrowed; +} + +static Func widen(const Func &narrow_input, const Type &dst_type, + int num_chunks, const Var &dim, const std::string &name, + ChunkOrder chunk_order) { + const std::vector dims = narrow_input.args(); + user_assert(count(begin(dims), end(dims), dim) == 1) + << "Expected dimension " << dim << " to represent " + << "exactly one function argument!\n"; + + auto dim_index = + std::distance(begin(dims), std::find(begin(dims), end(dims), dim)); + + std::vector baseline_args; + baseline_args.reserve(dims.size()); + std::transform(dims.begin(), dims.end(), std::back_inserter(baseline_args), + [](const Var &input_dim) { return input_dim; }); + + std::vector chunks; + chunks.reserve(num_chunks); + std::generate_n( + std::back_inserter(chunks), num_chunks, + [&chunks, baseline_args, dim_index, num_chunks, dim, narrow_input]() { + int chunk_idx = chunks.size(); + std::vector args = baseline_args; + args[dim_index] = (num_chunks * dim) + chunk_idx; + return narrow_input(args); + }); + + Func widened(name); + widened(dims) = concatenate_chunks(chunks, chunk_order); + + return widened; +} + +Func change_type(const Func &input, const Type &dst_type, const Var &dim, + const std::string &name, ChunkOrder chunk_order) { + const Type &src_type = input.output_type(); + int src_width = src_type.bits(); + int dst_width = dst_type.bits(); + bool is_widening = dst_width > src_width; + auto [min_width, max_width] = std::minmax(src_width, dst_width); + int num_chunks = max_width / min_width; + user_assert(src_type.with_bits(dst_width) == dst_type && + src_type.is_uint() && src_width != dst_width && + max_width % min_width == 0 && num_chunks > 1) + << "The source type " << src_type << " and destination type " + << dst_type << " must be similar uint types with different widths, " + << "larger width must be an integral multiple of the smaller width.\n"; + + return is_widening ? + widen(input, dst_type, num_chunks, dim, name, chunk_order) : + narrow(input, dst_type, num_chunks, dim, name, chunk_order); +} + +} // namespace FuncTypeChanging + +} // namespace Halide diff --git a/src/FuncTypeChanging.h b/src/FuncTypeChanging.h new file mode 100644 index 000000000000..7285e75469d7 --- /dev/null +++ b/src/FuncTypeChanging.h @@ -0,0 +1,53 @@ +#ifndef HALIDE_FUNC_TYPE_CHANGING_H +#define HALIDE_FUNC_TYPE_CHANGING_H + +/** \file + * Support for changing the function's return type by fusing a number of + * consequtive elements, or splitting a single element into parts, + * along a certain dimension. + */ + +#include "Func.h" +#include "FuncExtras.h" + +namespace Halide { + +namespace FuncTypeChanging { + +enum class ChunkOrder { + // Example: + // i32 0x0D0C0B0A -> 4xi8 -> { 0x0A, 0x0B, 0x0C, 0x0D } + // i32 0x0D0C0B0A -> 2xi16 -> { 0x0B0A, 0x0D0C } + // 4xi8 { 0x0A, 0x0B, 0x0C, 0x0D } -> i32 -> 0x0D0C0B0A + // 2xi16 { 0x0B0A, 0x0D0C } -> i32 -> 0x0D0C0B0A + // 2xi16 { 0x0D0C, 0x0B0A } -> i32 -> 0x0B0A0D0C + LowestFirst, + + // Example: + // i32 0x0D0C0B0A -> 4xi8 -> { 0x0D, 0x0C, 0x0B, 0x0A } + // i32 0x0D0C0B0A -> 2xi16t -> { 0x0D0C, 0x0B0A } + // 4xi8 { 0x0A, 0x0B, 0x0C, 0x0D } -> i32 -> 0x0A0B0C0D + // 2xi16 { 0x0B0A, 0x0D0C } -> i32 -> 0x0B0A0D0C + // 2xi16 { 0x0D0C, 0x0B0A } -> i32 -> 0x0D0C0B0A + HighestFirst, + + Default = LowestFirst // DO NOT CHANGE. +}; + +Func change_type(const Func &input, const Type &dst_type, const Var &dim, + const std::string &name, + ChunkOrder chunk_order = ChunkOrder::Default); + +template +HALIDE_NO_USER_CODE_INLINE Func change_type( + const T &func_like, const Type &dst_type, const Var &dim, + const std::string &name, ChunkOrder chunk_order = ChunkOrder::Default) { + return change_type(Internal::func_like_to_func(func_like), dst_type, dim, + name, chunk_order); +} + +} // namespace FuncTypeChanging + +} // namespace Halide + +#endif diff --git a/src/IROperator.cpp b/src/IROperator.cpp index 4693060a8d45..eedf62f4954b 100644 --- a/src/IROperator.cpp +++ b/src/IROperator.cpp @@ -1941,6 +1941,21 @@ Expr cast(Type t, Expr a) { return Internal::Cast::make(t, std::move(a)); } +Expr make_unsigned(Expr a) { + user_assert(a.defined()) << "cast of undefined Expr\n"; + + Type ty = a.type(); + + user_assert(ty.is_int_or_uint()) + << "The input must be of an integer type, got " << ty << "\n"; + + if (ty.is_uint()) { + return a; + } + + return cast(ty.with_code(halide_type_uint), a); +} + Expr clamp(Expr a, const Expr &min_val, const Expr &max_val) { user_assert(a.defined() && min_val.defined() && max_val.defined()) << "clamp of undefined Expr\n"; @@ -2531,6 +2546,99 @@ Expr count_trailing_zeros(Expr x) { {std::move(x)}, Internal::Call::PureIntrinsic); } +// FIXME: this only handles the sane case when at least one, +// but not more than all, bits are extracted. +// Do we need to handle the general case instead? +Expr extract_high_bits(const Expr &val, const Expr &num_high_bits) { + user_assert(val.defined()) << "extract_high_bits with undefined val"; + user_assert(num_high_bits.defined()) + << "extract_high_bits with undefined num_high_bits"; + + Type ty = val.type(); + user_assert(ty.is_int_or_uint()) + << "extract_high_bits: val must be of an integer type, got " << ty + << "\n"; + + Type shamt_ty = num_high_bits.type(); + user_assert(shamt_ty.is_uint()) << "extract_high_bits: num_high_bits must " + "be of an unsigned integer type, got " + << shamt_ty << "\n"; + + Expr num_low_padding_bits = ty.bits() - num_high_bits; + // The sign bit is already positioned, just perform the right-shift. + // We'll either pad with zeros (if uint) or replicate sign bit (if int). + return val >> cast(UInt(ty.bits()), num_low_padding_bits); +} + +// FIXME: this only handles the sane case when at least one, +// but not more than all, bits are extracted. +// Do we need to handle the general case instead? +Expr variable_length_extend(Expr val, const Expr &num_low_bits) { + user_assert(val.defined()) << "variable_length_extend with undefined val"; + user_assert(num_low_bits.defined()) + << "variable_length_extend with undefined num_low_bits"; + + Type ty = val.type(); + user_assert(ty.is_int_or_uint()) + << "variable_length_extend: val must be of an integer type, got " << ty + << "\n"; + + Type shamt_ty = num_low_bits.type(); + user_assert(shamt_ty.is_uint()) + << "variable_length_extend: num_low_bits must be of an unsigned " + "integer type, got " + << shamt_ty << "\n"; + + Expr num_high_padding_bits = ty.bits() - num_low_bits; + // First, left-shift the variable-sized input so that it's highest (sign) + // bit is positioned in the highest (sign) bit of the containment type. + val = val << cast(UInt(ty.bits()), num_high_padding_bits); + // And then let the `extract_high_bits()` deal with it. + return extract_high_bits(val, /*num_high_bits=*/num_low_bits); +} + +// FIXME: this only handles the sane case when at least one, +// but not more than all, bits are extracted. +// Do we need to handle the general case instead? +Expr extract_bits(Expr val, const Expr &num_low_padding_bits, + const Expr &num_bits) { + user_assert(val.defined()) << "extract_bits with undefined val"; + user_assert(num_low_padding_bits.defined()) + << "extract_bits with undefined num_low_padding_bits"; + user_assert(num_bits.defined()) << "extract_bits with undefined num_bits"; + + Type ty = val.type(); + user_assert(ty.is_int_or_uint()) + << "extract_bits: val must be of an integer type, got " << ty << "\n"; + + Type shamt_ty = num_low_padding_bits.type(); + user_assert(shamt_ty.is_uint()) + << "extract_bits: num_low_padding_bits must be of an unsigned integer " + "type, got " + << shamt_ty << "\n"; + + shamt_ty = num_bits.type(); + user_assert(shamt_ty.is_uint()) + << "extract_bits: num_bits must be of an unsigned integer type, got " + << shamt_ty << "\n"; + + Expr num_high_padding_bits = (ty.bits() - num_low_padding_bits) - num_bits; + // First, left-shift the variable-sized input so that it's highest (sign) + // bit is positioned in the highest (sign) bit of the containment type. + val = val << cast(UInt(ty.bits()), num_high_padding_bits); + // And then let the `extract_high_bits()` deal with it. + return extract_high_bits(val, /*num_high_bits=*/num_bits); +} + +// FIXME: this only handles the sane case when at least one, +// but not more than all, bits are extracted. +// Do we need to handle the general case instead? +Expr extract_low_bits(Expr val, const Expr &num_low_bits) { + // Let `extract_bits()` deal with everything. + return extract_bits(std::move(val), /*num_low_padding_bits=*/Expr(0U), + /*num_bits=*/num_low_bits); +} + Expr div_round_to_zero(Expr x, Expr y) { user_assert(x.defined()) << "div_round_to_zero of undefined dividend\n"; user_assert(y.defined()) << "div_round_to_zero of undefined divisor\n"; diff --git a/src/IROperator.h b/src/IROperator.h index ed0b11bb4fef..581c152058fd 100644 --- a/src/IROperator.h +++ b/src/IROperator.h @@ -394,6 +394,9 @@ inline Expr cast(Expr a) { /** Cast an expression to a new type. */ Expr cast(Type t, Expr a); +/** Cast an integer-typed expression to an unsigned integer type. */ +Expr make_unsigned(Expr a); + /** Return the sum of two expressions, doing any necessary type * coercion using \ref Internal::match_types */ Expr operator+(Expr a, Expr b); @@ -1208,6 +1211,43 @@ Expr count_leading_zeros(Expr x); * zero, the result is the number of bits in the type. */ Expr count_trailing_zeros(Expr x); +/** + * Extract \p num_high_bits high bits of \p val. + * (with either zero-extension, if \p val is uint, + * or sign-extension, if \p val is int) + * + * NOTE: Precondition: `num_high_bits != 0 && num_high_bits u<= bitwidth(val)`! + */ +Expr extract_high_bits(const Expr &val, const Expr &num_high_bits); + +/** + * Extend (either zero-extend if \p val is uint or sign-extend, if \p is int) + * low \p num_low_bits bits of \p val to the whole \p val underlying type. + * + * NOTE: Precondition: `num_low_bits != 0 && num_low_bits u<= bitwidth(val)`! + */ +Expr variable_length_extend(Expr val, const Expr &num_low_bits); + +/** + * Extract \p num_bits bits starting with the \p low_bit_offset bit. + * (with either zero-extension, if \p val is uint, + * or sign-extension, if \p val is int) + * + * NOTE: Precondition: + * `num_bits != 0 && (num_bits + low_bit_offset) u<= bitwidth(val)`! + */ +Expr extract_bits(Expr val, const Expr &num_low_padding_bits, + const Expr &num_bits); + +/** + * Extract \p num_high_bits high bits of \p val. + * (with either zero-extension, if \p val is uint, + * or sign-extension, if \p val is int) + * + * NOTE: Precondition: `num_low_bits != 0 && num_low_bits u<= bitwidth(val)`! + */ +Expr extract_low_bits(Expr val, const Expr &num_low_bits); + /** Divide two integers, rounding towards zero. This is the typical * behavior of most hardware architectures, which differs from * Halide's division operator, which is Euclidean (rounds towards diff --git a/test/correctness/CMakeLists.txt b/test/correctness/CMakeLists.txt index 62a93e25a982..79e08a8dc2d7 100644 --- a/test/correctness/CMakeLists.txt +++ b/test/correctness/CMakeLists.txt @@ -108,6 +108,7 @@ tests(GROUPS correctness force_onto_stack.cpp func_lifetime.cpp func_lifetime_2.cpp + func_type_changing.cpp fuse.cpp fuse_gpu_threads.cpp fused_where_inner_extent_is_zero.cpp @@ -310,6 +311,7 @@ tests(GROUPS correctness unsafe_promises.cpp unused_func.cpp update_chunk.cpp + variable_bit_length_ops.cpp vector_bounds_inference.cpp vector_cast.cpp vector_extern.cpp diff --git a/test/correctness/func_type_changing.cpp b/test/correctness/func_type_changing.cpp new file mode 100644 index 000000000000..1c752f659a42 --- /dev/null +++ b/test/correctness/func_type_changing.cpp @@ -0,0 +1,307 @@ +#include "Halide.h" +#include +#include + +#include + +using namespace Halide; +using namespace Halide::FuncTypeChanging; + +template +bool expect_eq(Buffer actual, Buffer expected) { + bool eq = true; + expected.for_each_value( + [&](const T &expected_val, const T &actual_val) { + if (actual_val != expected_val) { + eq = false; + fprintf(stderr, "Failed: expected %d, actual %d\n", + (int)expected_val, (int)actual_val); + } + }, + actual); + return eq; +} + +template +auto gen_random_chunks(std::initializer_list dims) { + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution dist( + std::numeric_limits::min(), + std::numeric_limits::max()); + + Buffer buf(dims); + buf.for_each_value([&](CHUNK_TYPE &v) { v = dist(gen); }); + + return buf; +} + +template +bool test_1d_rowwise_with_n_times_chunk_type(int num_chunks, Target t) { + const int width = 256; + Buffer input_buf = gen_random_chunks({width}); + + using WIDE_STORAGE_TYPE = uint64_t; + const int CHUNK_WIDTH = 8 * sizeof(CHUNK_TYPE); + const int WIDE_TYPE_WIDTH = CHUNK_WIDTH * num_chunks; + + Var x("x"); + + int wide_width = width / num_chunks; + + auto forward = [wide_width, WIDE_TYPE_WIDTH, x, t](const Func &input, + ChunkOrder chunk_order) { + Buffer wide(wide_width); + Func widen = change_type(input, UInt(WIDE_TYPE_WIDTH), x, "widener", + chunk_order); + Func store("store"); + store(x) = cast(widen(x)); + store.realize(wide, t); + return wide; + }; + + auto forward_naive = [wide_width, num_chunks](Buffer input_buf, + ChunkOrder chunk_order) { + Buffer wide(wide_width); + for (int32_t x = 0; x < wide_width; x++) { + WIDE_STORAGE_TYPE &v = wide(x); + v = 0; + for (int chunk = 0; chunk != num_chunks; ++chunk) { + int chunk_idx = chunk_order == ChunkOrder::HighestFirst ? + chunk : + (num_chunks - 1) - chunk; + v <<= CHUNK_WIDTH; + v |= (WIDE_STORAGE_TYPE)input_buf(num_chunks * x + chunk_idx); + } + } + return wide; + }; + + auto backward = [t, width, x, WIDE_TYPE_WIDTH]( + const Buffer &actual_widened_result, + ChunkOrder chunk_order) { + Buffer narrow(width); + Func load("load"); + load(x) = cast(UInt(WIDE_TYPE_WIDTH), actual_widened_result(x)); + Func narrown = + change_type(load, UInt(CHUNK_WIDTH), x, "narrower", chunk_order); + narrown.realize(narrow, t); + return narrow; + }; + + Func input("input"); + input(x) = input_buf(x); + + bool success = true; + for (ChunkOrder chunk_order : + {ChunkOrder::LowestFirst, ChunkOrder::HighestFirst}) { + const auto wide_actual = forward(input, chunk_order); + const auto wide_expected = forward_naive(input_buf, chunk_order); + success &= expect_eq(wide_actual, wide_expected); + + const auto narrow_actual = backward(wide_actual, chunk_order); + success &= expect_eq(narrow_actual, input_buf); + } + + return success; +} + +template +bool test_2d_rowwise_with_n_times_chunk_type(int num_chunks, Target t) { + const int width = 256; + const int height = 16; + Buffer input_buf = + gen_random_chunks({width, height}); + + using WIDE_STORAGE_TYPE = uint64_t; + const int CHUNK_WIDTH = 8 * sizeof(CHUNK_TYPE); + const int WIDE_TYPE_WIDTH = CHUNK_WIDTH * num_chunks; + + Var x("x"), y("y"); + + int wide_width = width / num_chunks; + + auto forward = [wide_width, WIDE_TYPE_WIDTH, x, y, + t](const Func &input, ChunkOrder chunk_order) { + Buffer wide({wide_width, height}); + Func widen = change_type(input, UInt(WIDE_TYPE_WIDTH), x, "widener", + chunk_order); + Func store("store"); + store(x, y) = cast(widen(x, y)); + store.realize(wide, t); + return wide; + }; + + auto forward_naive = [wide_width, num_chunks](Buffer input_buf, + ChunkOrder chunk_order) { + Buffer wide({wide_width, height}); + for (int32_t y = 0; y < height; y++) { + for (int32_t x = 0; x < wide_width; x++) { + WIDE_STORAGE_TYPE &v = wide(x, y); + v = 0; + for (int chunk = 0; chunk != num_chunks; ++chunk) { + int chunk_idx = chunk_order == ChunkOrder::HighestFirst ? + chunk : + (num_chunks - 1) - chunk; + v <<= CHUNK_WIDTH; + v |= (WIDE_STORAGE_TYPE)input_buf( + num_chunks * x + chunk_idx, y); + } + } + } + return wide; + }; + + auto backward = [t, width, height, x, y, WIDE_TYPE_WIDTH]( + const Buffer &actual_widened_result, + ChunkOrder chunk_order) { + Buffer narrow({width, height}); + Func load("load"); + load(x, y) = cast(UInt(WIDE_TYPE_WIDTH), actual_widened_result(x, y)); + Func narrown = + change_type(load, UInt(CHUNK_WIDTH), x, "narrower", chunk_order); + narrown.realize(narrow, t); + return narrow; + }; + + Func input("input"); + input(x, y) = input_buf(x, y); + + bool success = true; + for (ChunkOrder chunk_order : + {ChunkOrder::LowestFirst, ChunkOrder::HighestFirst}) { + const auto wide_actual = forward(input, chunk_order); + const auto wide_expected = forward_naive(input_buf, chunk_order); + success &= expect_eq(wide_actual, wide_expected); + + const auto narrow_actual = backward(wide_actual, chunk_order); + success &= expect_eq(narrow_actual, input_buf); + } + + return success; +} + +template +bool test_2d_colwise_with_n_times_chunk_type(int num_chunks, Target t) { + const int width = 16; + const int height = 256; + Buffer input_buf = + gen_random_chunks({width, height}); + + using WIDE_STORAGE_TYPE = uint64_t; + const int CHUNK_WIDTH = 8 * sizeof(CHUNK_TYPE); + const int WIDE_TYPE_WIDTH = CHUNK_WIDTH * num_chunks; + + Var x("x"), y("y"); + + int wide_height = height / num_chunks; + + auto forward = [wide_height, WIDE_TYPE_WIDTH, x, y, + t](const Func &input, ChunkOrder chunk_order) { + Buffer wide({width, wide_height}); + Func widen = change_type(input, UInt(WIDE_TYPE_WIDTH), y, "widener", + chunk_order); + Func store("store"); + store(x, y) = cast(widen(x, y)); + store.realize(wide, t); + return wide; + }; + + auto forward_naive = [wide_height, num_chunks](Buffer input_buf, + ChunkOrder chunk_order) { + Buffer wide({width, wide_height}); + for (int32_t y = 0; y < wide_height; y++) { + for (int32_t x = 0; x < width; x++) { + WIDE_STORAGE_TYPE &v = wide(x, y); + v = 0; + for (int chunk = 0; chunk != num_chunks; ++chunk) { + int chunk_idx = chunk_order == ChunkOrder::HighestFirst ? + chunk : + (num_chunks - 1) - chunk; + v <<= CHUNK_WIDTH; + v |= (WIDE_STORAGE_TYPE)input_buf(x, num_chunks * y + + chunk_idx); + } + } + } + return wide; + }; + + auto backward = [t, width, height, x, y, WIDE_TYPE_WIDTH]( + const Buffer &actual_widened_result, + ChunkOrder chunk_order) { + Buffer narrow({width, height}); + Func load("load"); + load(x, y) = cast(UInt(WIDE_TYPE_WIDTH), actual_widened_result(x, y)); + Func narrown = + change_type(load, UInt(CHUNK_WIDTH), y, "narrower", chunk_order); + narrown.realize(narrow, t); + return narrow; + }; + + Func input("input"); + input(x, y) = input_buf(x, y); + + bool success = true; + for (ChunkOrder chunk_order : + {ChunkOrder::LowestFirst, ChunkOrder::HighestFirst}) { + const auto wide_actual = forward(input, chunk_order); + const auto wide_expected = forward_naive(input_buf, chunk_order); + success &= expect_eq(wide_actual, wide_expected); + + const auto narrow_actual = backward(wide_actual, chunk_order); + success &= expect_eq(narrow_actual, input_buf); + } + + return success; +} + +template +bool test_with_n_times_chunk_type(int num_chunks, Target t) { + bool success = true; + + success &= + test_1d_rowwise_with_n_times_chunk_type(num_chunks, t); + success &= + test_2d_rowwise_with_n_times_chunk_type(num_chunks, t); + success &= + test_2d_colwise_with_n_times_chunk_type(num_chunks, t); + + return success; +} + +template +bool test_with_chunk_type(Target t) { + bool success = true; + + const int CHUNK_WIDTH = 8 * sizeof(CHUNK_TYPE); + for (int num_chunks = 2; CHUNK_WIDTH * num_chunks <= 64; num_chunks *= 2) { + success &= test_with_n_times_chunk_type(num_chunks, t); + } + + return success; +} + +bool test_all(Target t) { + bool success = true; + + success &= test_with_chunk_type(t); + success &= test_with_chunk_type(t); + success &= test_with_chunk_type(t); + + return success; +} + +int main(int argc, char **argv) { + Target target = get_jit_target_from_environment(); + + bool success = test_all(target); + + if (!success) { + fprintf(stderr, "Failed!\n"); + return -1; + } + + printf("Success!\n"); + return 0; +} diff --git a/test/correctness/variable_bit_length_ops.cpp b/test/correctness/variable_bit_length_ops.cpp new file mode 100644 index 000000000000..ed6bce2dcb76 --- /dev/null +++ b/test/correctness/variable_bit_length_ops.cpp @@ -0,0 +1,317 @@ +#include "Halide.h" +#include +#include + +#include + +using namespace Halide; +using namespace Halide::FuncTypeChanging; + +template +static constexpr int local_bitwidth() { + return 8 * sizeof(T); +} + +template +static T local_extract_high_bits(T val, int num_high_bits) { + int num_low_padding_bits = local_bitwidth() - num_high_bits; + if ((unsigned)num_low_padding_bits >= (unsigned)local_bitwidth()) { + return 42; + } + // The sign bit is already positioned, just perform the right-shift. + // We'll either pad with zeros (if uint) or replicate sign bit (if int). + assert((unsigned)num_low_padding_bits < (unsigned)local_bitwidth()); + return val >> num_low_padding_bits; +} + +template +static T local_variable_length_extend(T val, int num_low_bits) { + int num_high_padding_bits = local_bitwidth() - num_low_bits; + if ((unsigned)num_high_padding_bits >= (unsigned)local_bitwidth()) { + return 42; + } + // First, left-shift the variable-sized input so that it's highest (sign) + // bit is positioned in the highest (sign) bit of the containment type, + assert((unsigned)num_high_padding_bits < (unsigned)local_bitwidth()); + val <<= num_high_padding_bits; + return local_extract_high_bits(val, /*num_high_bits=*/num_low_bits); +} + +template +static T local_extract_bits(T val, int num_low_padding_bits, int num_bits) { + if (num_bits == 0) { + return 42; + } + int num_high_padding_bits = + (local_bitwidth() - num_low_padding_bits) - num_bits; + if ((unsigned)num_high_padding_bits >= (unsigned)local_bitwidth()) { + return 42; + } + // First, left-shift the variable-sized input so that it's highest (sign) + // bit is positioned in the highest (sign) bit of the containment type, + assert((unsigned)num_high_padding_bits < (unsigned)local_bitwidth()); + val <<= num_high_padding_bits; + return local_extract_high_bits(val, /*num_high_bits=*/num_bits); +} + +template +static T local_extract_low_bits(T val, int num_low_bits) { + return local_extract_bits(val, /*num_low_padding_bits=*/0, num_low_bits); +} + +template +static bool expect_eq(Buffer actual, Buffer expected) { + bool eq = true; + expected.for_each_value( + [&](const T &expected_val, const T &actual_val) { + if (actual_val != expected_val) { + eq = false; + fprintf(stderr, "Failed: expected %d, actual %d\n", + (int)expected_val, (int)actual_val); + } + }, + actual); + return eq; +} + +template +static auto gen_random_input(std::initializer_list dims) { + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution dist(std::numeric_limits::min(), + std::numeric_limits::max()); + + Buffer buf(dims); + buf.for_each_value([&](T &v) { v = dist(gen); }); + + return buf; +} + +template +static bool test_extract_high_bits(Target t) { + const int width = 8192; + Buffer input_buf = gen_random_input({width}); + + constexpr int T_BITS = 8 * sizeof(T); + constexpr int MAX_BITS = 2 + T_BITS; + + Var x("x"); + + auto actual = [&]() { + Buffer res(MAX_BITS * width); + Func fun("f"); + Expr input_idx = x / MAX_BITS; + Expr num_high_bits = x % MAX_BITS; + Expr num_low_padding_bits = local_bitwidth() - num_high_bits; + // `extract_high_bits()` is not defined for OOB or 0 num_high_bits. + fun(x) = select(make_unsigned(num_low_padding_bits) >= + make_unsigned(local_bitwidth()), + 42, + extract_high_bits(input_buf(input_idx), + make_unsigned(num_high_bits))); + fun.realize(res, t); + return res; + }; + + auto expected = [&]() { + Buffer res(MAX_BITS * width); + for (int x = 0; x != res.width(); ++x) { + int input_idx = x / MAX_BITS; + int num_high_bits = x % MAX_BITS; + res(x) = + local_extract_high_bits(input_buf(input_idx), num_high_bits); + } + return res; + }; + + bool success = true; + + const auto res_actual = actual(); + const auto res_expected = expected(); + success &= expect_eq(res_actual, res_expected); + + return success; +} + +template +static bool test_variable_length_extend(Target t) { + const int width = 8192; + Buffer input_buf = gen_random_input({width}); + + constexpr int T_BITS = 8 * sizeof(T); + constexpr int MAX_BITS = 2 + T_BITS; + + Var x("x"); + + auto actual = [&]() { + Buffer res(MAX_BITS * width); + Func fun("f"); + Expr input_idx = x / MAX_BITS; + Expr num_low_bits = x % MAX_BITS; + Expr num_high_padding_bits = local_bitwidth() - num_low_bits; + // `variable_length_extend()` is not defined for OOB or 0 num_low_bits. + fun(x) = select(make_unsigned(num_high_padding_bits) >= + make_unsigned(local_bitwidth()), + 42, + variable_length_extend(input_buf(input_idx), + make_unsigned(num_low_bits))); + fun.realize(res, t); + return res; + }; + + auto expected = [&]() { + Buffer res(MAX_BITS * width); + for (int x = 0; x != res.width(); ++x) { + int input_idx = x / MAX_BITS; + int num_low_bits = x % MAX_BITS; + res(x) = local_variable_length_extend(input_buf(input_idx), + num_low_bits); + } + return res; + }; + + bool success = true; + + const auto res_actual = actual(); + const auto res_expected = expected(); + success &= expect_eq(res_actual, res_expected); + + return success; +} + +template +static bool test_extract_bits(Target t) { + const int width = 256; + Buffer input_buf = gen_random_input({width}); + + constexpr int T_BITS = 8 * sizeof(T); + constexpr int MAX_BITS = 2 + T_BITS; + + Var x("x"); + + auto actual = [&]() { + Buffer res((MAX_BITS * MAX_BITS) * width); + Func fun("f"); + Expr input_idx = x / (MAX_BITS * MAX_BITS); + Expr num_low_padding_bits = (x / MAX_BITS) % MAX_BITS; + Expr num_bits = x % MAX_BITS; + Expr num_high_padding_bits = + (local_bitwidth() - num_low_padding_bits) - num_bits; + // `extract_bits()` is not defined for 0 or OOB num_bits. + fun(x) = select(num_bits == 0 || make_unsigned(num_high_padding_bits) >= + make_unsigned(local_bitwidth()), + 42, + extract_bits(input_buf(input_idx), + make_unsigned(num_low_padding_bits), + make_unsigned(num_bits))); + fun.realize(res, t); + return res; + }; + + auto expected = [&]() { + Buffer res((MAX_BITS * MAX_BITS) * width); + for (int x = 0; x != res.width(); ++x) { + int input_idx = x / (MAX_BITS * MAX_BITS); + int num_low_padding_bits = (x / MAX_BITS) % MAX_BITS; + int num_bits = x % MAX_BITS; + res(x) = local_extract_bits(input_buf(input_idx), + num_low_padding_bits, num_bits); + } + return res; + }; + + bool success = true; + + const auto res_actual = actual(); + const auto res_expected = expected(); + success &= expect_eq(res_actual, res_expected); + + return success; +} + +template +static bool test_extract_low_bits(Target t) { + const int width = 8192; + Buffer input_buf = gen_random_input({width}); + + constexpr int T_BITS = 8 * sizeof(T); + constexpr int MAX_BITS = 2 + T_BITS; + + Var x("x"); + + auto actual = [&]() { + Buffer res(MAX_BITS * width); + Func fun("f"); + Expr input_idx = x / MAX_BITS; + Expr num_low_bits = x % MAX_BITS; + Expr num_high_padding_bits = local_bitwidth() - num_low_bits; + // `extract_low_bits()` is not defined for OOB or 0 num_low_bits. + fun(x) = select(make_unsigned(num_high_padding_bits) >= + make_unsigned(local_bitwidth()), + 42, + extract_low_bits(input_buf(input_idx), + make_unsigned(num_low_bits))); + fun.realize(res, t); + return res; + }; + + auto expected = [&]() { + Buffer res(MAX_BITS * width); + for (int x = 0; x != res.width(); ++x) { + int input_idx = x / MAX_BITS; + int num_low_bits = x % MAX_BITS; + res(x) = local_extract_low_bits(input_buf(input_idx), num_low_bits); + } + return res; + }; + + bool success = true; + + const auto res_actual = actual(); + const auto res_expected = expected(); + success &= expect_eq(res_actual, res_expected); + + return success; +} + +template +static bool test_with_type(Target t) { + bool success = true; + + success &= test_extract_high_bits(t); + success &= test_variable_length_extend(t); + success &= test_extract_bits(t); + success &= test_extract_low_bits(t); + + return success; +} + +static bool test_all(Target t) { + bool success = true; + + success &= test_with_type(t); + success &= test_with_type(t); + success &= test_with_type(t); + success &= test_with_type(t); + + success &= test_with_type(t); + success &= test_with_type(t); + success &= test_with_type(t); + success &= test_with_type(t); + + return success; +} + +int main(int argc, char **argv) { + Target target = get_jit_target_from_environment(); + + bool success = test_all(target); + + if (!success) { + fprintf(stderr, "Failed!\n"); + return -1; + } + + printf("Success!\n"); + return 0; +}