Skip to content

Commit

Permalink
Unbreak handling of reinterpret
Browse files Browse the repository at this point in the history
  • Loading branch information
LebedevRI committed Jul 13, 2022
1 parent 02ee591 commit 656eb43
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 16 deletions.
19 changes: 10 additions & 9 deletions src/CodeGen_LLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1513,19 +1513,20 @@ void CodeGen_LLVM::visit(const Reinterpret *op) {

} else {
if (src.is_scalar() && dst.is_vector()) {
// If we're converting from a scalar to a vector, first produce
// a temporary vector of the required size by splatting
// the source scalar. Note that the scalar size may be different.
value = create_broadcast(value, (dst.bits() * dst.lanes()) / src.bits());
// If the source type is a scalar, we promote it to an
// equivalent vector of width one before doing the
// bitcast, because llvm's bitcast operator doesn't
// want to convert between scalars and vectors.
value = create_broadcast(value, 1);
}
if (src.is_vector() && dst.is_scalar()) {
// If we're converting from a vector to a scalar, first change
// element count/type, and then extract the first lane.
Type tmp = dst.with_lanes((src.bits() * src.lanes()) / dst.bits());
llvm_dst = llvm_type_of(tmp);
// Similarly, if we're converting from a vector to a
// scalar, convert to a vector of width 1 first, and
// then extract the first lane.
llvm_dst = get_vector_type(llvm_dst, 1);
}
value = builder->CreateBitCast(value, llvm_dst);
if (src.is_vector() && dst.is_scalar() && value->getType()->isVectorTy()) {
if (src.is_vector() && dst.is_scalar()) {
value = builder->CreateExtractElement(value, (uint64_t)0);
}
}
Expand Down
9 changes: 9 additions & 0 deletions src/Deinterleave.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,15 @@ class Deinterleaver : public IRGraphMutator {
}
}

Expr visit(const Reinterpret *op) override {
if (op->type.is_scalar()) {
return op;
} else {
Type t = op->type.with_lanes(new_lanes);
return Reinterpret::make(t, mutate(op->value));
}
}

Expr visit(const Call *op) override {
Type t = op->type.with_lanes(new_lanes);

Expand Down
8 changes: 1 addition & 7 deletions src/IR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,8 @@ Expr Cast::make(Type t, Expr v) {
Expr Reinterpret::make(Type t, Expr v) {
user_assert(v.defined()) << "reinterpret of undefined Expr\n";
int from_bits = v.type().bits() * v.type().lanes();
int from_vector = v.type().is_vector();
int to_bits = t.bits() * t.lanes();
int to_vector = t.is_vector();
user_assert((from_bits == to_bits) ||
(from_vector && !to_vector && from_bits > to_bits &&
from_bits % to_bits == 0) ||
(to_vector && !from_vector && to_bits > from_bits &&
to_bits % from_bits == 0))
user_assert(from_bits == to_bits)
<< "Reinterpret cast from type " << v.type()
<< " which has " << from_bits
<< " bits, to type " << t
Expand Down
10 changes: 10 additions & 0 deletions src/VectorizeLoops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,16 @@ class VectorSubs : public IRMutator {
}
}

Expr visit(const Reinterpret *op) override {
Expr value = mutate(op->value);
if (value.same_as(op->value)) {
return op;
} else {
Type t = op->type.with_lanes(value.type().lanes());
return Reinterpret::make(t, value);
}
}

string get_widened_var_name(const string &name) {
return name + ".widened." + vectorized_vars.back().name;
}
Expand Down

0 comments on commit 656eb43

Please sign in to comment.