Skip to content

Commit

Permalink
[Opt] Constant folding for BitExtractStmt (taichi-dev#1307)
Browse files Browse the repository at this point in the history
* [Opt] Constant folding for BitExtractStmt

* fix tests

* Let BitExtractStmt support all integral types
  • Loading branch information
xumingkuan authored and Rullec committed Jun 26, 2020
1 parent 4afa294 commit 47de864
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 3 deletions.
27 changes: 25 additions & 2 deletions taichi/transforms/constant_fold.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ class ConstantFold : public BasicStmtVisitor {
auto evaluated =
Stmt::make<ConstStmt>(LaneAttribute<TypedConstant>(new_constant));
stmt->replace_with(evaluated.get());
modifier.insert_before(stmt, VecStatement(std::move(evaluated)));
modifier.insert_before(stmt, std::move(evaluated));
modifier.erase(stmt);
}
}
Expand All @@ -177,11 +177,34 @@ class ConstantFold : public BasicStmtVisitor {
auto evaluated =
Stmt::make<ConstStmt>(LaneAttribute<TypedConstant>(new_constant));
stmt->replace_with(evaluated.get());
modifier.insert_before(stmt, VecStatement(std::move(evaluated)));
modifier.insert_before(stmt, std::move(evaluated));
modifier.erase(stmt);
}
}

void visit(BitExtractStmt *stmt) override {
auto input = stmt->input->cast<ConstStmt>();
if (!input)
return;
if (stmt->width() != 1)
return;
std::unique_ptr<Stmt> result_stmt;
if (is_signed(input->val[0].dt)) {
auto result = (input->val[0].val_int() >> stmt->bit_begin) &
((1LL << (stmt->bit_end - stmt->bit_begin)) - 1);
result_stmt = Stmt::make<ConstStmt>(LaneAttribute<TypedConstant>(
TypedConstant(input->val[0].dt, result)));
} else {
auto result = (input->val[0].val_uint() >> stmt->bit_begin) &
((1LL << (stmt->bit_end - stmt->bit_begin)) - 1);
result_stmt = Stmt::make<ConstStmt>(LaneAttribute<TypedConstant>(
TypedConstant(input->val[0].dt, result)));
}
stmt->replace_with(result_stmt.get());
modifier.insert_before(stmt, std::move(result_stmt));
modifier.erase(stmt);
}

static bool run(IRNode *node) {
ConstantFold folder;
bool modified = false;
Expand Down
2 changes: 1 addition & 1 deletion taichi/transforms/type_check.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ class TypeCheck : public IRVisitor {
}

void visit(BitExtractStmt *stmt) {
stmt->ret_type.data_type = DataType::i32;
stmt->ret_type = stmt->input->ret_type;
}

void visit(LinearizeStmt *stmt) {
Expand Down

0 comments on commit 47de864

Please sign in to comment.