Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Promote Reinterpret Intrinsic into an Reinterpret IR Node #6853

Merged
merged 7 commits into from
Jul 20, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,7 @@ SOURCE_FILES = \
Simplify_And.cpp \
Simplify_Call.cpp \
Simplify_Cast.cpp \
Simplify_Reinterpret.cpp \
Simplify_Div.cpp \
Simplify_EQ.cpp \
Simplify_Exprs.cpp \
Expand Down
14 changes: 14 additions & 0 deletions src/Bounds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,20 @@ class Bounds : public IRVisitor {
interval = Interval::single_point(op);
}

void visit(const Reinterpret *op) override {
TRACK_BOUNDS_INTERVAL;

Type t = op->type.element_of();

if (t.is_handle()) {
interval = Interval::everything();
return;
}

// Just use the bounds of the type
bounds_of_type(t);
}

void visit(const Cast *op) override {
TRACK_BOUNDS_INTERVAL;
op->value.accept(this);
Expand Down
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,7 @@ set(SOURCE_FILES
Simplify_Add.cpp
Simplify_And.cpp
Simplify_Call.cpp
Simplify_Reinterpret.cpp
Simplify_Cast.cpp
Simplify_Div.cpp
Simplify_EQ.cpp
Expand Down
7 changes: 4 additions & 3 deletions src/CodeGen_C.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2071,6 +2071,10 @@ void CodeGen_C::visit(const Cast *op) {
id = print_cast_expr(op->type, op->value);
}

void CodeGen_C::visit(const Reinterpret *op) {
id = print_assignment(op->type, print_reinterpret(op->type, op->value));
}

void CodeGen_C::visit_binop(Type t, const Expr &a, const Expr &b, const char *op) {
string sa = print_expr(a);
string sb = print_expr(b);
Expand Down Expand Up @@ -2294,9 +2298,6 @@ void CodeGen_C::visit(const Call *op) {
} else if (op->is_intrinsic(Call::bitwise_not)) {
internal_assert(op->args.size() == 1);
rhs << "~" << print_expr(op->args[0]);
} else if (op->is_intrinsic(Call::reinterpret)) {
internal_assert(op->args.size() == 1);
rhs << print_reinterpret(op->type, op->args[0]);
} else if (op->is_intrinsic(Call::shift_left)) {
internal_assert(op->args.size() == 2);
if (op->args[1].type().is_uint()) {
Expand Down
1 change: 1 addition & 0 deletions src/CodeGen_C.h
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ class CodeGen_C : public IRPrinter {
void visit(const StringImm *) override;
void visit(const FloatImm *) override;
void visit(const Cast *) override;
void visit(const Reinterpret *) override;
void visit(const Add *) override;
void visit(const Sub *) override;
void visit(const Mul *) override;
Expand Down
111 changes: 56 additions & 55 deletions src/CodeGen_LLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1476,6 +1476,62 @@ void CodeGen_LLVM::visit(const Cast *op) {
}
}

void CodeGen_LLVM::visit(const Reinterpret *op) {
Type dst = op->type;
Type src = op->value.type();
llvm::Type *llvm_dst = llvm_type_of(dst);
value = codegen(op->value);
if (src.is_handle() && !dst.is_handle()) {
internal_assert(dst.is_uint() && dst.bits() == 64);

// Handle -> UInt64
llvm::DataLayout d(module.get());
if (d.getPointerSize() == 4) {
llvm::Type *intermediate = llvm_type_of(UInt(32, dst.lanes()));
value = builder->CreatePtrToInt(value, intermediate);
value = builder->CreateZExt(value, llvm_dst);
} else if (d.getPointerSize() == 8) {
value = builder->CreatePtrToInt(value, llvm_dst);
} else {
internal_error << "Pointer size is neither 4 nor 8 bytes\n";
}

} else if (dst.is_handle() && !src.is_handle()) {
internal_assert(src.is_uint() && src.bits() == 64);

// UInt64 -> Handle
llvm::DataLayout d(module.get());
if (d.getPointerSize() == 4) {
llvm::Type *intermediate = llvm_type_of(UInt(32, src.lanes()));
value = builder->CreateTrunc(value, intermediate);
value = builder->CreateIntToPtr(value, llvm_dst);
} else if (d.getPointerSize() == 8) {
value = builder->CreateIntToPtr(value, llvm_dst);
} else {
internal_error << "Pointer size is neither 4 nor 8 bytes\n";
}

} else {
if (src.is_scalar() && dst.is_vector()) {
// If the source type is a scalar, we promote it to an
// equivalent vector of width one before doing the
// bitcast, because llvm's bitcast operator doesn't
// want to convert between scalars and vectors.
value = create_broadcast(value, 1);
}
if (src.is_vector() && dst.is_scalar()) {
// Similarly, if we're converting from a vector to a
// scalar, convert to a vector of width 1 first, and
// then extract the first lane.
llvm_dst = get_vector_type(llvm_dst, 1);
}
value = builder->CreateBitCast(value, llvm_dst);
if (src.is_vector() && dst.is_scalar()) {
value = builder->CreateExtractElement(value, (uint64_t)0);
}
}
}

void CodeGen_LLVM::visit(const Variable *op) {
value = sym_get(op->name);
}
Expand Down Expand Up @@ -2593,61 +2649,6 @@ void CodeGen_LLVM::visit(const Call *op) {
internal_assert(op->args.size() == 1);
Value *a = codegen(op->args[0]);
value = builder->CreateNot(a);
} else if (op->is_intrinsic(Call::reinterpret)) {
internal_assert(op->args.size() == 1);
Type dst = op->type;
Type src = op->args[0].type();
llvm::Type *llvm_dst = llvm_type_of(dst);
value = codegen(op->args[0]);
if (src.is_handle() && !dst.is_handle()) {
internal_assert(dst.is_uint() && dst.bits() == 64);

// Handle -> UInt64
llvm::DataLayout d(module.get());
if (d.getPointerSize() == 4) {
llvm::Type *intermediate = llvm_type_of(UInt(32, dst.lanes()));
value = builder->CreatePtrToInt(value, intermediate);
value = builder->CreateZExt(value, llvm_dst);
} else if (d.getPointerSize() == 8) {
value = builder->CreatePtrToInt(value, llvm_dst);
} else {
internal_error << "Pointer size is neither 4 nor 8 bytes\n";
}

} else if (dst.is_handle() && !src.is_handle()) {
internal_assert(src.is_uint() && src.bits() == 64);

// UInt64 -> Handle
llvm::DataLayout d(module.get());
if (d.getPointerSize() == 4) {
llvm::Type *intermediate = llvm_type_of(UInt(32, src.lanes()));
value = builder->CreateTrunc(value, intermediate);
value = builder->CreateIntToPtr(value, llvm_dst);
} else if (d.getPointerSize() == 8) {
value = builder->CreateIntToPtr(value, llvm_dst);
} else {
internal_error << "Pointer size is neither 4 nor 8 bytes\n";
}

} else {
if (src.is_scalar() && dst.is_vector()) {
// If the source type is a scalar, we promote it to an
// equivalent vector of width one before doing the
// bitcast, because llvm's bitcast operator doesn't
// want to convert between scalars and vectors.
value = create_broadcast(value, 1);
}
if (src.is_vector() && dst.is_scalar()) {
// Similarly, if we're converting from a vector to a
// scalar, convert to a vector of width 1 first, and
// then extract the first lane.
llvm_dst = get_vector_type(llvm_dst, 1);
}
value = builder->CreateBitCast(value, llvm_dst);
if (src.is_vector() && dst.is_scalar()) {
value = builder->CreateExtractElement(value, (uint64_t)0);
}
LebedevRI marked this conversation as resolved.
Show resolved Hide resolved
}
} else if (op->is_intrinsic(Call::shift_left)) {
internal_assert(op->args.size() == 2);
if (op->args[1].type().is_uint()) {
Expand Down
1 change: 1 addition & 0 deletions src/CodeGen_LLVM.h
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,7 @@ class CodeGen_LLVM : public IRVisitor {
void visit(const FloatImm *) override;
void visit(const StringImm *) override;
void visit(const Cast *) override;
void visit(const Reinterpret *) override;
void visit(const Variable *) override;
void visit(const Add *) override;
void visit(const Sub *) override;
Expand Down
9 changes: 9 additions & 0 deletions src/Deinterleave.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,15 @@ class Deinterleaver : public IRGraphMutator {
}
}

Expr visit(const Reinterpret *op) override {
if (op->type.is_scalar()) {
return op;
} else {
Type t = op->type.with_lanes(new_lanes);
return Reinterpret::make(t, mutate(op->value));
}
}

Expr visit(const Call *op) override {
Type t = op->type.with_lanes(new_lanes);

Expand Down
12 changes: 10 additions & 2 deletions src/Derivative.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class ReverseAccumulationVisitor : public IRVisitor {
void visit(const FloatImm *) override;
void visit(const StringImm *) override;
void visit(const Cast *op) override;
void visit(const Reinterpret *op) override;
void visit(const Variable *op) override;
void visit(const Add *op) override;
void visit(const Sub *op) override;
Expand Down Expand Up @@ -836,6 +837,14 @@ void ReverseAccumulationVisitor::visit(const Cast *op) {
}
}

void ReverseAccumulationVisitor::visit(const Reinterpret *op) {
internal_assert(expr_adjoints.find(op) != expr_adjoints.end());
Expr adjoint = expr_adjoints[op];

// bit manipulation -- has zero derivative.
accumulate(op->value, make_zero(op->type));
}

void ReverseAccumulationVisitor::visit(const Variable *op) {
internal_assert(expr_adjoints.find(op) != expr_adjoints.end());
Expr adjoint = expr_adjoints[op];
Expand Down Expand Up @@ -1169,8 +1178,7 @@ void ReverseAccumulationVisitor::visit(const Call *op) {
accumulate(op->args[1], adjoint);
} else if (op->is_intrinsic(Call::undef)) {
// do nothing
} else if (op->is_intrinsic(Call::reinterpret) ||
op->is_intrinsic(Call::bitwise_and) ||
} else if (op->is_intrinsic(Call::bitwise_and) ||
op->is_intrinsic(Call::bitwise_not) ||
op->is_intrinsic(Call::bitwise_or) ||
op->is_intrinsic(Call::bitwise_xor) ||
Expand Down
2 changes: 2 additions & 0 deletions src/EliminateBoolVectors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,8 @@ class EliminateBoolVectors : public IRMutator {
}
}

// FIXME: what about Reinterpret?
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@abadams what about this one?
I do believe this is a correct translation, but is that the correct handling in reality?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh right, I thought there was another FIXME. The risk is the reinterpret node turning something into a bool vector I guess.

I'm having a hard time imagining what might do this and how we might handle it. Any uses of reinterpret that introduce bool vectors to backends that can't handle them will have to be carefully vetted. Maybe this is fine for now.


Stmt visit(const Store *op) override {
Expr predicate = op->predicate;
if (!is_const_one(predicate)) {
Expand Down
1 change: 1 addition & 0 deletions src/Expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ enum class IRNodeType {
StringImm,
Broadcast,
Cast,
Reinterpret,
Variable,
Add,
Sub,
Expand Down
25 changes: 24 additions & 1 deletion src/IR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,22 @@ Expr Cast::make(Type t, Expr v) {
return node;
}

Expr Reinterpret::make(Type t, Expr v) {
user_assert(v.defined()) << "reinterpret of undefined Expr\n";
int from_bits = v.type().bits() * v.type().lanes();
int to_bits = t.bits() * t.lanes();
user_assert(from_bits == to_bits)
<< "Reinterpret cast from type " << v.type()
<< " which has " << from_bits
<< " bits, to type " << t
<< " which has " << to_bits << " bits\n";

Reinterpret *node = new Reinterpret;
node->type = t;
node->value = std::move(v);
return node;
}

Expr Add::make(Expr a, Expr b) {
internal_assert(a.defined()) << "Add of undefined\n";
internal_assert(b.defined()) << "Add of undefined\n";
Expand Down Expand Up @@ -628,7 +644,6 @@ const char *const intrinsic_op_names[] = {
"promise_clamped",
"random",
"register_destructor",
"reinterpret",
"require",
"require_mask",
"return_second",
Expand Down Expand Up @@ -970,6 +985,10 @@ void ExprNode<Cast>::accept(IRVisitor *v) const {
v->visit((const Cast *)this);
}
template<>
void ExprNode<Reinterpret>::accept(IRVisitor *v) const {
v->visit((const Reinterpret *)this);
}
template<>
void ExprNode<Variable>::accept(IRVisitor *v) const {
v->visit((const Variable *)this);
}
Expand Down Expand Up @@ -1155,6 +1174,10 @@ Expr ExprNode<Cast>::mutate_expr(IRMutator *v) const {
return v->visit((const Cast *)this);
}
template<>
Expr ExprNode<Reinterpret>::mutate_expr(IRMutator *v) const {
return v->visit((const Reinterpret *)this);
}
template<>
Expr ExprNode<Variable>::mutate_expr(IRMutator *v) const {
return v->visit((const Variable *)this);
}
Expand Down
10 changes: 9 additions & 1 deletion src/IR.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,15 @@ struct Cast : public ExprNode<Cast> {
static const IRNodeType _node_type = IRNodeType::Cast;
};

/** Reinterpret a node as another type, without affecting any of the bits. */
struct Reinterpret : public ExprNode<Reinterpret> {
Expr value;

static Expr make(Type t, Expr v);

static const IRNodeType _node_type = IRNodeType::Reinterpret;
};

/** The sum of two expressions */
struct Add : public ExprNode<Add> {
Expr a, b;
Expand Down Expand Up @@ -537,7 +546,6 @@ struct Call : public ExprNode<Call> {
promise_clamped,
random,
register_destructor,
reinterpret,
require,
require_mask,
return_second,
Expand Down
5 changes: 5 additions & 0 deletions src/IREquality.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class IRComparer : public IRVisitor {
void visit(const FloatImm *) override;
void visit(const StringImm *) override;
void visit(const Cast *) override;
void visit(const Reinterpret *) override;
void visit(const Variable *) override;
void visit(const Add *) override;
void visit(const Sub *) override;
Expand Down Expand Up @@ -354,6 +355,10 @@ void IRComparer::visit(const Cast *op) {
compare_expr(expr.as<Cast>()->value, op->value);
}

void IRComparer::visit(const Reinterpret *op) {
compare_expr(expr.as<Reinterpret>()->value, op->value);
}

void IRComparer::visit(const Variable *op) {
const Variable *e = expr.as<Variable>();
compare_names(e->name, op->name);
Expand Down
15 changes: 15 additions & 0 deletions src/IRMatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,16 @@ class IRMatch : public IRVisitor {
}
}

void visit(const Reinterpret *op) override {
const Reinterpret *e = expr.as<Reinterpret>();
if (result && e && types_match(op->type, e->type)) {
expr = e->value;
op->value.accept(this);
} else {
result = false;
}
}

void visit(const Variable *op) override {
if (!result) {
return;
Expand Down Expand Up @@ -432,6 +442,11 @@ bool equal_helper(const BaseExprNode &a, const BaseExprNode &b) noexcept {
// that the types of the values match, so use equal rather
// than equal_helper.
return equal(((const Cast &)a).value, ((const Cast &)b).value);
case IRNodeType::Reinterpret:
// While we know a and b have matching type, we don't know
// that the types of the values match, so use equal rather
// than equal_helper.
return equal(((const Reinterpret &)a).value, ((const Reinterpret &)b).value);
case IRNodeType::Variable:
return ((const Variable &)a).name == ((const Variable &)b).name;
case IRNodeType::Add:
Expand Down
Loading