diff --git a/src/Simplify_Shuffle.cpp b/src/Simplify_Shuffle.cpp
index d03bf5064a5f..55469cd32817 100644
--- a/src/Simplify_Shuffle.cpp
+++ b/src/Simplify_Shuffle.cpp
@@ -47,72 +47,6 @@ Expr Simplify::visit(const Shuffle *op, ExprInfo *bounds) {
         new_vectors.push_back(new_vector);
     }
 
-    // If any of the args are narrowing casts, convert them to shuffles of
-    // reinterpret casts so we can fold them into this shuffle. This all assumes
-    // little-endianness, so if we ever support a big-endian backend we'll have
-    // to switch on the target here.
-    for (Expr &v : new_vectors) {
-        if (!(v.type().is_int() || v.type().is_uint()) || !v.as<Cast>()) {
-            continue;
-        }
-
-        auto x_is_16_bit = is_int(x, 16) || is_uint(x, 16);
-        auto x_is_32_bit = is_int(x, 32) || is_uint(x, 32);
-        auto x_is_64_bit = is_int(x, 64) || is_uint(x, 64);
-
-        auto rewrite = IRMatcher::rewriter(v, op->type);
-
-        auto t8 = v.type().with_bits(8);
-        auto t16 = v.type().with_bits(16);
-        auto t32 = v.type().with_bits(32);
-
-        // Shifts have been canonicalized to divisions provided they are less
-        // than 32 bit, so the patterns below switch from division to shifting
-        // at 32-bits.
-        int stride = 0, start = 0;
-
-        if (rewrite(cast(t8, x / (1 << 8)), x, x_is_16_bit) ||
-            rewrite(cast(t16, x / (1 << 16)), x, x_is_32_bit) ||
-            rewrite(cast(t32, shift_right(x, 32)), x, x_is_64_bit)) {
-            // Extract high half
-            stride = 2;
-            start = 1;
-        } else if (rewrite(cast(t8, x), x, x_is_16_bit) ||
-                   rewrite(cast(t16, x), x, x_is_32_bit) ||
-                   rewrite(cast(t32, x), x, x_is_64_bit)) {
-            // Extract low half
-            stride = 2;
-            start = 0;
-        } else if (rewrite(cast(t8, x / (1 << 24)), x, x_is_32_bit) ||
-                   rewrite(cast(t16, shift_right(x, 48)), x, x_is_64_bit)) {
-            // Extract 4th quarter
-            stride = 4;
-            start = 3;
-        } else if (rewrite(cast(t8, x / (1 << 16)), x, x_is_32_bit) ||
-                   rewrite(cast(t16, shift_right(x, 32)), x, x_is_64_bit)) {
-            // Extract 3rd quarter
-            stride = 4;
-            start = 2;
-        } else if (rewrite(cast(t8, x / (1 << 8)), x, x_is_32_bit) ||
-                   rewrite(cast(t16, x / (1 << 16)), x, x_is_64_bit)) {
-            // Extract 2nd quarter
-            stride = 4;
-            start = 1;
-        } else if (rewrite(cast(t8, x), x, x_is_32_bit) ||
-                   rewrite(cast(t16, x), x, x_is_64_bit)) {
-            // Extract low quarter
-            stride = 4;
-            start = 0;
-        } else {
-            continue;
-        }
-
-        int lanes = v.type().lanes();
-        v = reinterpret(v.type().with_lanes(lanes * stride), rewrite.result);
-        v = Shuffle::make_slice(v, start, stride, lanes);
-        changed = true;
-    }
-
     // Try to convert a load with shuffled indices into a
     // shuffle of a dense load.
     if (const Load *first_load = new_vectors[0].as<Load>()) {
@@ -257,7 +191,6 @@ Expr Simplify::visit(const Shuffle *op, ExprInfo *bounds) {
                 }
             }
         }
-
     } else if (op->is_concat()) {
         // Try to collapse a concat of ramps into a single ramp.
         const Ramp *r = new_vectors[0].as<Ramp>();
diff --git a/test/correctness/CMakeLists.txt b/test/correctness/CMakeLists.txt
index d9461850500a..60016b67d4c3 100644
--- a/test/correctness/CMakeLists.txt
+++ b/test/correctness/CMakeLists.txt
@@ -272,7 +272,6 @@ tests(GROUPS correctness
         reduction_non_rectangular.cpp
         reduction_schedule.cpp
         register_shuffle.cpp
-        reinterpret_vector.cpp
         reorder_rvars.cpp
         reorder_storage.cpp
         require.cpp
diff --git a/test/correctness/reinterpret_vector.cpp b/test/correctness/reinterpret_vector.cpp
deleted file mode 100644
index fc6b5602af3f..000000000000
--- a/test/correctness/reinterpret_vector.cpp
+++ /dev/null
@@ -1,90 +0,0 @@
-#include "Halide.h"
-
-using namespace Halide;
-using namespace Halide::Internal;
-
-class CheckNoVectorMath : public IRMutator {
-public:
-    using IRMutator::mutate;
-    Expr mutate(const Expr &e) override {
-        IRMutator::mutate(e);
-        // An allow-list of IR nodes we are OK with
-        if (e.type().is_vector() &&
-            !(Call::as_intrinsic(e, {Call::reinterpret}) ||
-              e.as<Shuffle>() ||
-              e.as<Load>() ||
-              e.as<Ramp>() ||
-              e.as<Broadcast>())) {
-            std::cout << "Unexpected vector expression: " << e << "\n";
-            exit(-1);
-        }
-
-        return e;
-    }
-};
-
-int main(int argc, char **argv) {
-    // Check we can treat a vector of a wide type as a wider vector of a
-    // narrower type for free.
-    Var x, y, c;
-
-    // Treat a 32-bit image as a twice-as-wide 16-bit image
-    {
-        Func narrow, wide;
-        wide(x, y) = cast<uint32_t>(x + y);
-        narrow(x, y) = select(x % 2 == 0,
-                              cast<uint16_t>(wide(x / 2, y)),
-                              cast<uint16_t>(wide(x / 2, y) >> 16));
-        wide.compute_root();
-        narrow.align_bounds(x, 16).vectorize(x, 16);
-        CheckNoVectorMath checker;
-        narrow.add_custom_lowering_pass(&checker, nullptr);
-
-        Buffer<uint16_t> out = narrow.realize({1024, 1024});
-
-        for (int y = 0; y < out.height(); y++) {
-            for (int x = 0; x < out.width(); x++) {
-                int correct = ((x % 2 == 0) ? x / 2 + y : (x / 2 + y) >> 16);
-                if (out(x, y) != correct) {
-                    printf("out(%d, %d) = %d instead of %d\n", x, y, out(x, y), correct);
-                    return -1;
-                }
-            }
-        }
-    }
-
-    // Treat 2-dimensional 32-bit values representing rgba as 3-dimensional rgba
-    {
-        Func rgba_packed, rgba;
-        rgba_packed(x, y) = cast<uint32_t>(x + y);
-        rgba(c, x, y) = mux(c, {cast<uint8_t>(rgba_packed(x, y)),
-                                cast<uint8_t>(rgba_packed(x, y) >> 8),
-                                cast<uint8_t>(rgba_packed(x, y) >> 16),
-                                cast<uint8_t>(rgba_packed(x, y) >> 24)});
-        rgba_packed.compute_root();
-        rgba.align_bounds(x, 16).vectorize(x, 16).bound(c, 0, 4).unroll(c);
-        rgba.output_buffer().dim(1).set_stride(4);
-        CheckNoVectorMath checker;
-        rgba.add_custom_lowering_pass(&checker, nullptr);
-
-        Buffer<uint8_t> out = rgba.realize({3, 1024, 1024});
-
-        for (int y = 0; y < out.dim(2).extent(); y++) {
-            for (int x = 0; x < out.dim(1).extent(); x++) {
-                for (int c = 0; c < out.dim(0).extent(); c++) {
-                    uint8_t correct = (c == 0) ? (x + y) :
-                                      (c == 1) ? (x + y) >> 8 :
-                                      (c == 2) ? (x + y) >> 16 :
-                                                 (x + y) >> 24;
-                    if (out(c, x, y) != correct) {
-                        printf("out(%d, %d, %d) = %d instead of %d\n", c, x, y, out(c, x, y), correct);
-                        return -1;
-                    }
-                }
-            }
-        }
-    }
-
-    printf("Success!\n");
-    return 0;
-}
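
For reference, the block removed from Simplify_Shuffle.cpp rewrote a narrowing cast of a wide integer vector into a reinterpret of the same bits followed by a strided slice, relying on little-endian lane order. The standalone sketch below illustrates that equivalence for one case (extracting the high byte of a uint16x8 vector) using the public reinterpret() and Shuffle::make_slice() helpers; the variable name v, the 16-bit example, and the choice of helpers are illustrative assumptions made here, not text from the patch.

#include "Halide.h"

#include <iostream>

using namespace Halide;
using namespace Halide::Internal;

int main() {
    // A stand-in for one argument of a shuffle: a vector of 8 uint16 lanes.
    Expr v = Variable::make(UInt(16, 8), "v");

    // A narrowing cast that extracts the high byte of every lane. Below
    // 32 bits the simplifier canonicalizes shifts to divisions, hence / 256.
    Expr high_bytes = Cast::make(UInt(8, 8), v / 256);

    // What the removed rule produced instead, assuming a little-endian target:
    // reinterpret the same bits as 16 uint8 lanes, then slice out every second
    // lane starting at lane 1 (the high byte of each uint16 lane).
    Expr equivalent = Shuffle::make_slice(reinterpret(UInt(8, 16), v), 1, 2, 8);

    std::cout << high_bytes << "\n"
              << equivalent << "\n";
    return 0;
}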