Skip to content

Commit

Permalink
Implemented bounded analyzer which traverses tree and for reduce/for
Browse files Browse the repository at this point in the history
statements binds the loop/iteration-variable ranges into the analyzer. Later this is used to
simplify expressions. Inspired by ir_mutator_with_analyzer

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
kimishpatel committed Sep 23, 2019
1 parent 8e19998 commit cb8614a
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 73 deletions.
43 changes: 43 additions & 0 deletions include/tvm/bounded_analyzer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#ifndef TVM_BOUNDED_ANALYZER_H_
#define TVM_BOUNDED_ANALYZER_H_

#include <tvm/arithmetic.h>
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>

namespace tvm {
namespace ir {

/*!
 * \brief Visitor that traverses an IR tree and binds the iteration bounds
 *  of For loops, thread/virtual-thread extents and Reduce axes into an
 *  arith::Analyzer.
 *
 *  After Visit() has been run over a statement, the embedded analyzer can
 *  be used to simplify expressions involving the bound variables.
 *  Inspired by ir_mutator_with_analyzer.
 */
class BoundedAnalyzer final : public IRVisitor {
 public:
  void Visit_(const For* op) override {
    // The loop variable ranges over [min, min + extent).
    analyzer.Bind(op->loop_var,
                  Range::make_by_min_extent(op->min, op->extent));
    IRVisitor::Visit_(op);
  }

  void Visit_(const AttrStmt* op) override {
    if (op->attr_key == attr::thread_extent ||
        op->attr_key == attr::virtual_thread) {
      IterVar iv(op->node.node_);
      CHECK_NE(iv->thread_tag.length(), 0U);
      // A thread index ranges over [0, extent).
      analyzer.Bind(iv->var,
                    Range::make_by_min_extent(0, op->value));
    }
    IRVisitor::Visit_(op);
  }

  void Visit_(const Reduce* op) override {
    // Setup the domain information before simplification.
    for (const IterVar& iv : op->axis) {
      analyzer.Bind(iv->var, iv->dom);
    }
    // Recursively call simplification when necessary.
    IRVisitor::Visit_(op);
  }

  /*! \brief internal analyzer field. */
  arith::Analyzer analyzer;
};

}  // namespace ir
}  // namespace tvm

#endif  // TVM_BOUNDED_ANALYZER_H_
38 changes: 0 additions & 38 deletions include/tvm/shape_expr_mutator.h

This file was deleted.

40 changes: 15 additions & 25 deletions src/op/tensorize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
*/
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/shape_expr_mutator.h>
#include <tvm/ir_pass.h>
#include <tvm/api_registry.h>
#include "op_util.h"
Expand Down Expand Up @@ -213,6 +212,15 @@ class TensorIntrinMatcher final : public IRMutator {
const TensorIntrin& intrin,
Map<Var, Range>* compute_intrin_iter_space) {
CHECK(self == stage->op.get());

for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) {
IterVar iv = stage->leaf_iter_vars[i];
auto vit = dom_map.find(iv);
CHECK(vit != dom_map.end());
const Range vrange = vit->second;
compute_intrin_iter_space->Set(iv->var, vrange);
}

// input remap.
Array<Tensor> inputs = self->InputTensors();
CHECK_EQ(inputs.size(), intrin->inputs.size());
Expand All @@ -224,22 +232,12 @@ class TensorIntrinMatcher final : public IRMutator {
// Enable fuzzy matching, to match [1, n, m] to [n, m]
e.start = e.region.size() - e.tensor.ndim();
for (size_t j = 0; j < e.start; ++j) {
if(!is_one(e.region[j]->extent)) {
// Is this safe to do?
// This essentially finds variables in min expr and replaces
// them with some const so as to make it easy to simplify.
IndexVarFinder lhs_var_finder;
lhs_var_finder.Visit(e.region[j]->min);
IndexVarReplacer rhs_var_replacer;
rhs_var_replacer.Init(lhs_var_finder.var_map());
auto new_extent = rhs_var_replacer.Mutate(e.region[j]->extent);
auto canonical_extent = Simplify(new_extent);
CHECK(is_one(canonical_extent))
<< "Tensorize " << intrin->name << ":"
<< " Input dimension mismatch with tensor intrin "
<< " expected shape=" << e.tensor->shape
<< ", given region=" << e.region;
}
auto canonical_extent = Simplify(e.region[j]->extent, *compute_intrin_iter_space);
CHECK(is_one(canonical_extent))
<< "Tensorize " << intrin->name << ":"
<< " Input dimension mismatch with tensor intrin "
<< " expected shape=" << e.tensor->shape
<< ", given region=" << e.region;
}
in_remap_[inputs[i]] = e;
}
Expand Down Expand Up @@ -289,14 +287,6 @@ class TensorIntrinMatcher final : public IRMutator {
axis_remap_[iv] = target_iv;
compute_intrin_iter_space->Set(target_iv->var, target_iv->dom);
}

for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) {
IterVar iv = stage->leaf_iter_vars[i];
auto vit = dom_map.find(iv);
CHECK(vit != dom_map.end());
const Range vrange = vit->second;
compute_intrin_iter_space->Set(iv->var, vrange);
}
}

private:
Expand Down
34 changes: 24 additions & 10 deletions src/pass/storage_flatten.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,13 @@
*/
// Flattens storage from multi-dimensional array to 1D
// buffer access as in Halide pipeline.
#include <tvm/arithmetic.h>
#include <tvm/bounded_analyzer.h>
#include <tvm/ir.h>
#include <tvm/expr.h>
#include <tvm/operation.h>
#include <tvm/ir_mutator.h>
#include <tvm/shape_expr_mutator.h>
#include <tvm/ir_visitor.h>
#include <tvm/expr_operator.h>
#include <tvm/ir_pass.h>
#include <tvm/buffer.h>
Expand All @@ -50,8 +52,10 @@ using intrinsic::tvm_address_of;
class StorageFlattener : public IRMutator {
public:
explicit StorageFlattener(Map<Tensor, Buffer> extern_buffer,
int cache_line_size, bool create_bound_attributes)
: create_bound_attributes_(create_bound_attributes) {
int cache_line_size, bool create_bound_attributes,
const std::shared_ptr<BoundedAnalyzer>& bounded_analyzer)
: create_bound_attributes_(create_bound_attributes),
bounded_analyzer_(bounded_analyzer) {
for (auto kv : extern_buffer) {
BufferEntry e;
e.buffer = kv.second;
Expand Down Expand Up @@ -419,12 +423,8 @@ class StorageFlattener : public IRMutator {
}
} else {
for (size_t i = 0; i < tuple->args.size(); i += 2) {
IndexVarFinder begins_var_finder;
begins_var_finder.Visit(tuple->args[i]);
IndexVarReplacer extent_var_replacer;
extent_var_replacer.Init(begins_var_finder.var_map());
auto new_extent = Simplify(extent_var_replacer.Mutate(tuple->args[i+1]));
begins.push_back(tuple->args[i]);
auto new_extent = bounded_analyzer_->analyzer.Simplify(tuple->args[i+1]);
extents.push_back(new_extent);
}
}
Expand Down Expand Up @@ -516,6 +516,9 @@ class StorageFlattener : public IRMutator {
std::vector<ThreadScope> curr_thread_scope_;
// Collects shapes.
std::vector<std::pair<VarExpr, Array<Expr>>> shape_collector_;
// Bounds populator. We really need the analyzer from it; however,
// Analyzer is not safely copyable/movable, so we hold it via shared_ptr.
std::shared_ptr<BoundedAnalyzer> bounded_analyzer_;
// The size of cacheline
int cache_line_size_;
// The current stage is an OpenGL shader.
Expand All @@ -526,9 +529,20 @@ class StorageFlattener : public IRMutator {

/*!
 * \brief Flatten multi-dimensional buffer access into 1D accesses.
 * \param stmt The statement to transform.
 * \param extern_buffer Map from external tensors to their backing buffers.
 * \param cache_line_size The size of a cache line, used for alignment.
 * \param create_bound_attributes Whether to emit buffer-bound attributes.
 * \return The transformed statement.
 */
Stmt StorageFlatten(Stmt stmt, Map<Tensor, Buffer> extern_buffer,
                    int cache_line_size, bool create_bound_attributes) {
  /*
   * Unfortunately we have to resort to shared_ptr because the Analyzer used
   * by BoundedAnalyzer has other analyzers in it, e.g. canonical.
   * Those allocate an impl and destroy it on destructor calls.
   * However there is no copy/move operator on Analyzer that safely copies
   * or moves the data. Perhaps we should disable the copy operator and
   * implement a move operator.
   */
  std::shared_ptr<BoundedAnalyzer> bounded_analyzer =
      std::make_shared<BoundedAnalyzer>();
  // Pre-populate variable bounds so the flattener can simplify extents.
  bounded_analyzer->Visit(stmt);
  stmt = StorageFlattener(extern_buffer, cache_line_size,
                          create_bound_attributes, bounded_analyzer)
             .Mutate(stmt);
  return stmt;
}

Expand Down

0 comments on commit cb8614a

Please sign in to comment.