From 656eb43e93e0fecf3e9d1cbc328a1bba0ea2ed77 Mon Sep 17 00:00:00 2001 From: Roman Lebedev Date: Thu, 14 Jul 2022 01:39:11 +0300 Subject: [PATCH] Unbreak handling of reinterpret --- src/CodeGen_LLVM.cpp | 19 ++++++++++--------- src/Deinterleave.cpp | 9 +++++++++ src/IR.cpp | 8 +------- src/VectorizeLoops.cpp | 10 ++++++++++ 4 files changed, 30 insertions(+), 16 deletions(-) diff --git a/src/CodeGen_LLVM.cpp b/src/CodeGen_LLVM.cpp index e0f7635e462c..74dda7e5fd34 100644 --- a/src/CodeGen_LLVM.cpp +++ b/src/CodeGen_LLVM.cpp @@ -1513,19 +1513,20 @@ void CodeGen_LLVM::visit(const Reinterpret *op) { } else { if (src.is_scalar() && dst.is_vector()) { - // If we're converting from a scalar to a vector, first produce - // a temporary vector of the required size by splatting - // the source scalar. Note that the scalar size may be different. - value = create_broadcast(value, (dst.bits() * dst.lanes()) / src.bits()); + // If the source type is a scalar, we promote it to an + // equivalent vector of width one before doing the + // bitcast, because llvm's bitcast operator doesn't + // want to convert between scalars and vectors. + value = create_broadcast(value, 1); } if (src.is_vector() && dst.is_scalar()) { - // If we're converting from a vector to a scalar, first change - // element count/type, and then extract the first lane. - Type tmp = dst.with_lanes((src.bits() * src.lanes()) / dst.bits()); - llvm_dst = llvm_type_of(tmp); + // Similarly, if we're converting from a vector to a + // scalar, convert to a vector of width 1 first, and + // then extract the first lane. + llvm_dst = get_vector_type(llvm_dst, 1); } value = builder->CreateBitCast(value, llvm_dst); - if (src.is_vector() && dst.is_scalar() && value->getType()->isVectorTy()) { + if (src.is_vector() && dst.is_scalar()) { value = builder->CreateExtractElement(value, (uint64_t)0); } } diff --git a/src/Deinterleave.cpp b/src/Deinterleave.cpp index 5d46d60bf09e..e368d851d615 100644 --- a/src/Deinterleave.cpp +++ b/src/Deinterleave.cpp @@ -322,6 +322,15 @@ class Deinterleaver : public IRGraphMutator { } } + Expr visit(const Reinterpret *op) override { + if (op->type.is_scalar()) { + return op; + } else { + Type t = op->type.with_lanes(new_lanes); + return Reinterpret::make(t, mutate(op->value)); + } + } + Expr visit(const Call *op) override { Type t = op->type.with_lanes(new_lanes); diff --git a/src/IR.cpp b/src/IR.cpp index ffded0cbade1..740234b8e31f 100644 --- a/src/IR.cpp +++ b/src/IR.cpp @@ -23,14 +23,8 @@ Expr Cast::make(Type t, Expr v) { Expr Reinterpret::make(Type t, Expr v) { user_assert(v.defined()) << "reinterpret of undefined Expr\n"; int from_bits = v.type().bits() * v.type().lanes(); - int from_vector = v.type().is_vector(); int to_bits = t.bits() * t.lanes(); - int to_vector = t.is_vector(); - user_assert((from_bits == to_bits) || - (from_vector && !to_vector && from_bits > to_bits && - from_bits % to_bits == 0) || - (to_vector && !from_vector && to_bits > from_bits && - to_bits % from_bits == 0)) + user_assert(from_bits == to_bits) << "Reinterpret cast from type " << v.type() << " which has " << from_bits << " bits, to type " << t diff --git a/src/VectorizeLoops.cpp b/src/VectorizeLoops.cpp index e2960739ed86..7dcd79d24664 100644 --- a/src/VectorizeLoops.cpp +++ b/src/VectorizeLoops.cpp @@ -527,6 +527,16 @@ class VectorSubs : public IRMutator { } } + Expr visit(const Reinterpret *op) override { + Expr value = mutate(op->value); + if (value.same_as(op->value)) { + return op; + } else { + Type t = op->type.with_lanes(value.type().lanes()); + return Reinterpret::make(t, value); + } + } + string get_widened_var_name(const string &name) { return name + ".widened." + vectorized_vars.back().name; }