Moved to use IterVar bounds map instead of expr mutator.
Much more reliable and safe.

kimishpatel committed Sep 22, 2019
1 parent 8e19998 commit 517bbe9
Showing 8 changed files with 38 additions and 80 deletions.
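In short, the change threads the IterVar ranges that schedule.InferBound already computes through StorageFlatten and the tensorize matcher, and lets Simplify use them directly, rather than rewriting index variables with the now-deleted shape_expr_mutator. A minimal C++ sketch of the pattern (it follows the code in the diff below; bounds and extent stand in for whatever map and expression the calling pass holds):

    // bounds : Map<IterVar, Range>, as produced by schedule::InferBound.
    // Simplify expects ranges keyed by the underlying loop Var, so re-key first.
    Map<Var, Range> vrange;
    for (const auto& kv : bounds) {
      vrange.Set(kv.first->var, kv.second);
    }
    // With the loop ranges known, an extent that is provably one simplifies away.
    Expr canonical_extent = Simplify(extent, vrange);
    if (is_one(canonical_extent)) {
      // e.g. accept the fuzzy shape match or the prefetch region as-is
    }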
3 changes: 2 additions & 1 deletion include/tvm/ir_pass.h
@@ -205,7 +205,8 @@ Stmt Inline(Stmt stmt,
Stmt StorageFlatten(Stmt stmt,
Map<Tensor, Buffer> extern_buffer,
int cache_line_size,
bool create_bound_attribute = false);
bool create_bound_attribute = false,
const Map<IterVar, Range>& bounds = {});

/*!
* \brief Remove No Op from the Stmt.
38 changes: 0 additions & 38 deletions include/tvm/shape_expr_mutator.h

This file was deleted.

2 changes: 1 addition & 1 deletion python/tvm/autotvm/feature.py
@@ -42,7 +42,7 @@ def ana_lower(sch, args,
# Phase 0
bounds = schedule.InferBound(sch)
stmt = schedule.ScheduleOps(sch, bounds, True)
stmt = ir_pass.StorageFlatten(stmt, binds, 64)
stmt = ir_pass.StorageFlatten(stmt, binds, 64, False, bounds)
stmt = ir_pass.CanonicalSimplify(stmt)
assert simple_mode
return stmt
7 changes: 4 additions & 3 deletions python/tvm/build_module.py
@@ -325,7 +325,7 @@ def form_body(sch):
bounds = schedule.InferBound(sch)
stmt = schedule.ScheduleOps(sch, bounds)
stmt = ir_pass.InjectPrefetch(stmt)
return stmt
return stmt, bounds


def lower(sch,
@@ -372,13 +372,14 @@ def lower(sch,
lower_phase3 = [x[1] for x in add_lower_pass if x[0] > 2]

# Phase 0
bounds = None
if isinstance(sch, schedule.Schedule):
stmt = form_body(sch)
stmt, bounds = form_body(sch)

for f in lower_phase0:
stmt = f(stmt)
# Phase 1
stmt = ir_pass.StorageFlatten(stmt, binds, 64, cfg.instrument_bound_checkers)
stmt = ir_pass.StorageFlatten(stmt, binds, 64, cfg.instrument_bound_checkers, bounds)
stmt = ir_pass.CanonicalSimplify(stmt)
for f in lower_phase1:
stmt = f(stmt)
4 changes: 3 additions & 1 deletion src/api/api_pass.cc
@@ -89,8 +89,10 @@ TVM_REGISTER_API("ir_pass.StorageFlatten")
.set_body([](TVMArgs args, TVMRetValue *ret) {
if (args.size() <= 3) {
*ret = StorageFlatten(args[0], args[1], args[2]);
} else {
} else if (args.size() <= 4) {
*ret = StorageFlatten(args[0], args[1], args[2], args[3]);
} else {
*ret = StorageFlatten(args[0], args[1], args[2], args[3], args[4]);
}
});

2 changes: 1 addition & 1 deletion src/codegen/build_module.cc
@@ -392,7 +392,7 @@ Stmt BuildStmt(Schedule sch,

// Phase 1
stmt = ir::StorageFlatten(stmt, out_binds, 64,
config->instrument_bound_checkers);
config->instrument_bound_checkers, bounds);
stmt = ir::CanonicalSimplify(stmt);
if (loop_partition) {
stmt = ir::LoopPartition(stmt, config->partition_const_loop);
40 changes: 15 additions & 25 deletions src/op/tensorize.cc
@@ -24,7 +24,6 @@
*/
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/shape_expr_mutator.h>
#include <tvm/ir_pass.h>
#include <tvm/api_registry.h>
#include "op_util.h"
@@ -213,6 +212,15 @@ class TensorIntrinMatcher final : public IRMutator {
const TensorIntrin& intrin,
Map<Var, Range>* compute_intrin_iter_space) {
CHECK(self == stage->op.get());

for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) {
IterVar iv = stage->leaf_iter_vars[i];
auto vit = dom_map.find(iv);
CHECK(vit != dom_map.end());
const Range vrange = vit->second;
compute_intrin_iter_space->Set(iv->var, vrange);
}

// input remap.
Array<Tensor> inputs = self->InputTensors();
CHECK_EQ(inputs.size(), intrin->inputs.size());
@@ -224,22 +224,12 @@
// Enable fuzzy matching, to match [1, n, m] to [n, m]
e.start = e.region.size() - e.tensor.ndim();
for (size_t j = 0; j < e.start; ++j) {
if(!is_one(e.region[j]->extent)) {
// Is this safe to do?
// This essentially finds variables in min expr and replaces
// them with some const so as to make it easy to simplify.
IndexVarFinder lhs_var_finder;
lhs_var_finder.Visit(e.region[j]->min);
IndexVarReplacer rhs_var_replacer;
rhs_var_replacer.Init(lhs_var_finder.var_map());
auto new_extent = rhs_var_replacer.Mutate(e.region[j]->extent);
auto canonical_extent = Simplify(new_extent);
CHECK(is_one(canonical_extent))
<< "Tensorize " << intrin->name << ":"
<< " Input dimension mismatch with tensor intrin "
<< " expected shape=" << e.tensor->shape
<< ", given region=" << e.region;
}
auto canonical_extent = Simplify(e.region[j]->extent, *compute_intrin_iter_space);
CHECK(is_one(canonical_extent))
<< "Tensorize " << intrin->name << ":"
<< " Input dimension mismatch with tensor intrin "
<< " expected shape=" << e.tensor->shape
<< ", given region=" << e.region;
}
in_remap_[inputs[i]] = e;
}
@@ -289,14 +287,6 @@ class TensorIntrinMatcher final : public IRMutator {
axis_remap_[iv] = target_iv;
compute_intrin_iter_space->Set(target_iv->var, target_iv->dom);
}

for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) {
IterVar iv = stage->leaf_iter_vars[i];
auto vit = dom_map.find(iv);
CHECK(vit != dom_map.end());
const Range vrange = vit->second;
compute_intrin_iter_space->Set(iv->var, vrange);
}
}

private:
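Note that the loop recording the stage's leaf IterVar ranges now runs at the top of the matcher, before the input remap, so compute_intrin_iter_space already holds those ranges when the fuzzy-match check calls Simplify. A hedged sketch of what that check amounts to (names mirror the diff; the shapes in the comment are illustrative only):

    // Fuzzy matching a region such as [1, n, m] against an intrinsic tensor of
    // shape [n, m]: the leading e.start extents must be provably 1 under the
    // known iteration bounds, otherwise the match is rejected.
    for (size_t j = 0; j < e.start; ++j) {
      Expr canonical = Simplify(e.region[j]->extent, *compute_intrin_iter_space);
      CHECK(is_one(canonical))
          << "Tensorize " << intrin->name
          << ": leading dimension " << j << " does not reduce to extent 1";
    }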
22 changes: 12 additions & 10 deletions src/pass/storage_flatten.cc
@@ -27,7 +27,6 @@
#include <tvm/expr.h>
#include <tvm/operation.h>
#include <tvm/ir_mutator.h>
#include <tvm/shape_expr_mutator.h>
#include <tvm/expr_operator.h>
#include <tvm/ir_pass.h>
#include <tvm/buffer.h>
@@ -50,14 +49,18 @@ using intrinsic::tvm_address_of;
class StorageFlattener : public IRMutator {
public:
explicit StorageFlattener(Map<Tensor, Buffer> extern_buffer,
int cache_line_size, bool create_bound_attributes)
int cache_line_size, bool create_bound_attributes,
const Map<IterVar, Range>& bounds)
: create_bound_attributes_(create_bound_attributes) {
for (auto kv : extern_buffer) {
BufferEntry e;
e.buffer = kv.second;
e.external = true;
buf_map_[TensorKey{kv.first->op, kv.first->value_index}] = e;
}
for (const auto& kv : bounds) {
iter_var_bounds_.Set(kv.first->var, kv.second);
}
cache_line_size_ = cache_line_size;
}

@@ -419,11 +422,7 @@
}
} else {
for (size_t i = 0; i < tuple->args.size(); i += 2) {
IndexVarFinder begins_var_finder;
begins_var_finder.Visit(tuple->args[i]);
IndexVarReplacer extent_var_replacer;
extent_var_replacer.Init(begins_var_finder.var_map());
auto new_extent = Simplify(extent_var_replacer.Mutate(tuple->args[i+1]));
auto new_extent = Simplify(tuple->args[i+1], iter_var_bounds_);
begins.push_back(tuple->args[i]);
extents.push_back(new_extent);
}
@@ -516,6 +515,8 @@
std::vector<ThreadScope> curr_thread_scope_;
// Collects shapes.
std::vector<std::pair<VarExpr, Array<Expr>>> shape_collector_;
// Bounds of iteration variables, keyed by their underlying Var.
Map<Var, Range> iter_var_bounds_;
// The size of cacheline
int cache_line_size_;
// The current stage is an OpenGL shader.
@@ -525,10 +526,11 @@
};

Stmt StorageFlatten(Stmt stmt, Map<Tensor, Buffer> extern_buffer,
int cache_line_size, bool create_bound_attributes) {
int cache_line_size, bool create_bound_attributes,
const Map<IterVar, Range>& bounds) {
stmt =
StorageFlattener(extern_buffer, cache_line_size, create_bound_attributes)
.Mutate(stmt);
StorageFlattener(extern_buffer, cache_line_size,
create_bound_attributes, bounds).Mutate(stmt);
return stmt;
}

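Taken together with the Python changes above, callers now compute the bounds once and thread the same map through scheduling and flattening; a sketch of the resulting call sequence, mirroring BuildStmt in src/codegen/build_module.cc (out_binds and config as in that function):

    Map<IterVar, Range> bounds = schedule::InferBound(sch);
    Stmt stmt = schedule::ScheduleOps(sch, bounds, false);
    // ... phase-0 lowering passes ...
    stmt = ir::StorageFlatten(stmt, out_binds, 64,
                              config->instrument_bound_checkers, bounds);
    stmt = ir::CanonicalSimplify(stmt);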
