[PASS] InstrumentBoundCheckers pass
This pass instruments bound checkers before memory
accesses (load/store), which makes it possible to
detect invalid memory accesses at runtime.

The patch is related to the following issue:
https://discuss.tvm.ai/t/array-bounds-checking/944
denis0x0D committed Nov 28, 2018
1 parent 644a15c commit aa0745c
Showing 9 changed files with 849 additions and 13 deletions.
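
For context, a minimal usage sketch of the new flag (not part of this commit), assuming the TVM 0.x Python API of this era; the shapes and the deliberate overrun are illustrative:

    import numpy as np
    import tvm

    n = 21
    A = tvm.placeholder((n,), name="A")
    # B deliberately reads past the end of A.
    B = tvm.compute((n + 4,), lambda i: A[i] * 2.0, name="B")
    s = tvm.create_schedule(B.op)

    # Enable the new pass through the build config.
    with tvm.build_config(instrument_bound_checkers=True):
        f = tvm.build(s, [A, B], "llvm")

    a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype))
    b = tvm.nd.array(np.zeros(n + 4, dtype=B.dtype))
    try:
        f(a, b)  # the instrumented check should fire on the overrun
    except tvm.TVMError as err:
        print("out-of-bounds access detected:", err)

Without the flag, the same kernel would silently read past the buffer.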
4 changes: 4 additions & 0 deletions include/tvm/build_module.h
@@ -220,6 +220,9 @@ class BuildConfigNode : public Node {
/*! \brief Whether to dump the IR of each pass (only when building from python) */
bool dump_pass_ir = false;

/*! \brief Whether to instrument loads and stores with checks for out-of-bounds accesses. */
bool instrument_bound_checkers = false;

void VisitAttrs(AttrVisitor* v) final {
v->Visit("data_alignment", &data_alignment);
v->Visit("offset_factor", &offset_factor);
@@ -232,6 +235,7 @@ class BuildConfigNode : public Node {
v->Visit("detect_global_barrier", &detect_global_barrier);
v->Visit("partition_const_loop", &partition_const_loop);
v->Visit("dump_pass_ir", &dump_pass_ir);
v->Visit("instrument_bound_checkers", &instrument_bound_checkers);
}

static constexpr const char* _type_key = "BuildConfig";
2 changes: 2 additions & 0 deletions include/tvm/ir.h
@@ -206,6 +206,8 @@ constexpr const char* scan_init_scope = "scan_init_scope";
* This gives hint to require stride of dim to be k * align + offset.
*/
constexpr const char* buffer_dim_align = "buffer_dim_align";
/*! \brief Mark stores/loads with their bounds. */
constexpr const char* buffer_bound = "buffer_bound";
/*!
* \brief Bind the buffer specification to the region of the op
* When this scope occurs, the stmt.node is a Array<NodeRef> = [buffer, tensor]
14 changes: 11 additions & 3 deletions include/tvm/ir_pass.h
@@ -181,11 +181,12 @@ Stmt Inline(Stmt stmt,
* \param extern_buffer Map specifies external
* buffer assignment of input and outputs.
* \param cache_line_size The size of CPU cache line.
* \param create_bound_attribute Whether to create bound attributes.
* \return Transformed stmt.
*/
Stmt StorageFlatten(Stmt stmt,
Map<Tensor, Buffer> extern_buffer,
int cache_line_size);
Stmt StorageFlatten(Stmt stmt, Map<Tensor, Buffer> extern_buffer,
int cache_line_size,
bool create_bound_attribute = false);

/*!
* \brief Remove No Op from the Stmt.
@@ -234,6 +235,13 @@ Stmt UnrollLoop(Stmt stmt,
*/
Stmt VectorizeLoop(Stmt stmt);

/*!
 * \brief Instrument bound checkers.
 * \param stmt The statement to be instrumented.
* \return Instrumented Stmt.
*/
Stmt InstrumentBoundCheckers(Stmt stmt);

/*!
* \brief Inject virtual thread loops into stmt.
 * \param stmt The statement to be transformed.
8 changes: 6 additions & 2 deletions python/tvm/build_module.py
@@ -125,7 +125,8 @@ class BuildConfig(NodeBase):
"data_alignment": -1,
"restricted_func": True,
"double_buffer_split_loop": 1,
"dump_pass_ir": False
"dump_pass_ir": False,
"instrument_bound_checkers": False
}
_dump_ir = DumpIR()

@@ -349,7 +350,7 @@ def lower(sch,
for f in lower_phase0:
stmt = f(stmt)
# Phase 1
stmt = ir_pass.StorageFlatten(stmt, binds, 64)
stmt = ir_pass.StorageFlatten(stmt, binds, 64, cfg.instrument_bound_checkers)
stmt = ir_pass.CanonicalSimplify(stmt)
for f in lower_phase1:
stmt = f(stmt)
@@ -375,6 +376,9 @@ def lower(sch,
stmt = ir_pass.RewriteUnsafeSelect(stmt)
for f in lower_phase3:
stmt = f(stmt)
# Instrument BoundCheckers
if cfg.instrument_bound_checkers:
stmt = ir_pass.InstrumentBoundCheckers(stmt)
if simple_mode:
return stmt
return ir_pass.MakeAPI(stmt, name, arg_list, 0, cfg.restricted_func)
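A sketch of how the effect of the pass can be inspected at this point in the pipeline; since the instrumentation above runs before the `simple_mode` early return, the printed stmt should already contain the inserted checks (same era API assumed):

    import tvm

    n = tvm.var("n")
    A = tvm.placeholder((n,), name="A")
    B = tvm.compute((n,), lambda i: A[i] + 1.0, name="B")
    s = tvm.create_schedule(B.op)

    with tvm.build_config(instrument_bound_checkers=True):
        # The lowered stmt should show the if/assert structure
        # produced by InstrumentBoundCheckers.
        print(tvm.lower(s, [A, B], simple_mode=True))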
7 changes: 6 additions & 1 deletion src/api/api_pass.cc
@@ -66,6 +66,11 @@ TVM_REGISTER_API("ir_pass.Equal")
}
});

TVM_REGISTER_API("ir_pass.StorageFlatten")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = StorageFlatten(args[0], args[1], args[2],
args.size() == 3 ? false : args[3]);
});

TVM_REGISTER_API("ir_pass.AttrsEqual")
.set_body_typed<bool(const NodeRef&, const NodeRef&)>([](const NodeRef& lhs, const NodeRef& rhs) {
@@ -126,7 +131,6 @@ REGISTER_PASS1(ConvertSSA);
REGISTER_PASS1(VerifySSA);
REGISTER_PASS1(RewriteUnsafeSelect);
REGISTER_PASS4(Inline);
REGISTER_PASS3(StorageFlatten);
REGISTER_PASS4(IRTransform);
REGISTER_PASS1(VectorizeLoop);
REGISTER_PASS5(UnrollLoop);
@@ -155,5 +159,6 @@ REGISTER_PASS1(CombineContextCall);
REGISTER_PASS2(VerifyMemory);
REGISTER_PASS2(VerifyGPUCode);
REGISTER_PASS1(DecorateDeviceScope);
REGISTER_PASS1(InstrumentBoundCheckers);
} // namespace ir
} // namespace tvm
6 changes: 5 additions & 1 deletion src/codegen/build_module.cc
@@ -364,7 +364,8 @@ Stmt BuildStmt(Schedule sch,
stmt = ir::InjectPrefetch(stmt);

// Phase 1
stmt = ir::StorageFlatten(stmt, out_binds, 64);
stmt = ir::StorageFlatten(stmt, out_binds, 64,
config->instrument_bound_checkers);
stmt = ir::CanonicalSimplify(stmt);
if (loop_partition) {
stmt = ir::LoopPartition(stmt, config->partition_const_loop);
@@ -382,6 +383,9 @@ Stmt BuildStmt(Schedule sch,
stmt = ir::RemoveNoOp(stmt);
stmt = ir::RewriteUnsafeSelect(stmt);

if (config->instrument_bound_checkers)
stmt = ir::InstrumentBoundCheckers(stmt);

return stmt;
}

193 changes: 193 additions & 0 deletions src/pass/bound_checker.cc
@@ -0,0 +1,193 @@
/*!
* Copyright (c) 2018 by Contributors
 * \file bound_checker.cc
*/
// Instrument checkers for out-of-bounds accesses.

#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <vector>
#include <unordered_map>
#include <utility>

namespace tvm {
namespace ir {

class BoundCollector : public IRVisitor {
public:
BoundCollector() {}

void Visit_(const AttrStmt *op) {
if (op->attr_key == ir::attr::buffer_bound) {
if (const Variable *key = op->node.as<Variable>()) {
mem_to_shape[key] = op->value;
}
}
IRVisitor::Visit_(op);
}
// Hashtable which maps buffer_var to shape.
std::unordered_map<const Variable *, Expr> mem_to_shape;
};

class BoundChecker : public IRMutator {
public:
explicit BoundChecker(
const std::unordered_map<const Variable *, Expr> &mem_to_shape)
: mem_to_shape_(mem_to_shape) {}

Stmt Mutate_(const Allocate *op, const Stmt &s) final {
// If the shape was updated, we should update the hashtable.
if (UpdateIsNeeded(op->buffer_var)) {
Update(op->buffer_var, op->extents, op->type);
}
return IRMutator::Mutate_(op, s);
}

Expr Mutate_(const Call *op, const Expr &ex) final {
if (process_store_ && op->is_intrinsic(intrinsic::tvm_if_then_else)) {
unsafe_rewritten_ = true;
}
return IRMutator::Mutate_(op, ex);
}

Stmt Mutate_(const Store *op, const Stmt &s) final {
store_scope_bound_collector_.clear();
process_store_ = true;
unsafe_rewritten_ = false;
IRMutator::Mutate_(op, s);
process_store_ = false;
if (CanInstrument(op->index, op->buffer_var)) {
Collect(op->index, op->buffer_var);
}
// The collector should have at least one item.
if (store_scope_bound_collector_.size()) {
Expr condition = MakeCondition();
if (!condition.as<StringImm>()) {
Stmt nop = Evaluate::make(1);
Stmt then_case =
Store::make(op->buffer_var, op->value, op->index, op->predicate);
Stmt else_case =
AssertStmt::make(condition, StringImm::make(error_message_), nop);
Stmt body = IfThenElse::make(condition, then_case, else_case);
return body;
}
}
return s;
}

Expr Mutate_(const Load *op, const Expr &ex) final {
if (CanInstrument(op->index, op->buffer_var)) {
Collect(op->index, op->buffer_var);
}
return IRMutator::Mutate_(op, ex);
}

private:
bool UpdateIsNeeded(const VarExpr &buffer_var) const {
return (buffer_var.defined() && mem_to_shape_.count(buffer_var.get()));
}

void Update(const VarExpr &buffer_var, const Array<Expr> &new_shape,
const Type &type) {
// Sanity check first.
if (!new_shape.size())
return;

for (size_t i = 0; i < new_shape.size(); ++i) {
if (!new_shape[i].defined() || !new_shape[i].type().is_scalar() ||
is_negative_const(new_shape[i])) {
return;
}
}

// Scalarize the shape.
Expr shape = Mul::make(make_const(UInt(64), type.lanes()),
Cast::make(UInt(64), new_shape[0]));
for (size_t i = 1; i < new_shape.size(); ++i) {
// Cast to unsigned first to avoid integer overflow.
shape = Mul::make(shape, Mul::make(make_const(UInt(64), type.lanes()),
Cast::make(UInt(64), new_shape[i])));
}
mem_to_shape_[buffer_var.get()] = shape;
}

bool IndexIsValid(const Expr &index) const {
if (!index.defined())
return false;

if (const Ramp *ramp_index = index.as<Ramp>()) {
return ramp_index->base.defined() &&
ramp_index->base.type().is_scalar() &&
ramp_index->stride.defined() &&
ramp_index->stride.type().is_scalar() && (ramp_index->lanes > 0);
}
return true;
}

bool CanInstrument(const Expr &index, const VarExpr &buffer_var) const {
return buffer_var.defined() && mem_to_shape_.count(buffer_var.get()) &&
IndexIsValid(index) && !unsafe_rewritten_;
}

void Collect(Expr index, VarExpr buffer_var) {
store_scope_bound_collector_.push_back(
std::make_pair(index, mem_to_shape_[buffer_var.get()]));
}

Expr MakeCondition() {
Expr condition;
for (size_t i = 0; i < store_scope_bound_collector_.size(); ++i) {
std::pair<Expr, Expr> buffer_to_mem = store_scope_bound_collector_[i];
Expr index = buffer_to_mem.first;
Expr upper_bound = buffer_to_mem.second;

if (const Ramp *ramp_index = index.as<Ramp>()) {
// In case the index is base + stride * i.
// Non-inclusive range.
index = Add::make(
ramp_index->base,
Mul::make(ramp_index->stride, make_const(ramp_index->stride.type(),
ramp_index->lanes - 1)));
}

// Try to simplify index and bound.
index = ir::Simplify(index);
upper_bound = ir::Simplify(upper_bound);

// Cast both to the same signed type to be able to check the lower bound.
index = Cast::make(Int(64), index);
upper_bound = Cast::make(Int(64), upper_bound);

// After normalization, the lower bound should always be zero.
Expr lower_bound = make_zero(Int(64));

Expr current_condition =
And::make(GE::make(index, lower_bound), LT::make(index, upper_bound));
condition =
!i ? current_condition : And::make(condition, current_condition);
}
return condition;
}

// Whether we are processing a store value recursively.
bool process_store_{false};
// Whether we encountered the tvm_if_then_else intrinsic.
bool unsafe_rewritten_{false};
// Pool which collects the (index, shape) pair for a specific store/load.
std::vector<std::pair<Expr, Expr>> store_scope_bound_collector_;
// Error message.
const char *const error_message_ = "OUT OF THE BOUNDS";
// Hashtable which maps buffer_var to shape.
std::unordered_map<const Variable *, Expr> mem_to_shape_;
};

Stmt InstrumentBoundCheckers(Stmt stmt) {
BoundCollector bound_collector;
// First, walk recursively and collect bound attributes.
bound_collector.Visit(stmt);
return BoundChecker(bound_collector.mem_to_shape).Mutate(stmt);
}
} // namespace ir
} // namespace tvm
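
To make the pass ordering explicit: `InstrumentBoundCheckers` consumes the `buffer_bound` attributes that `StorageFlatten` emits when `create_bound_attribute` is true, so a hand-rolled pipeline must flatten first. A sketch under that assumption, using the pass bindings registered above (era 0.x API):

    import tvm

    n = tvm.var("n")
    A = tvm.placeholder((n,), name="A")
    B = tvm.compute((n,), lambda i: A[i] + 1.0, name="B")
    s = tvm.create_schedule(B.op).normalize()

    bounds = tvm.schedule.InferBound(s)
    stmt = tvm.schedule.ScheduleOps(s, bounds)

    # Bind the tensors to buffers for StorageFlatten.
    Ab = tvm.decl_buffer(A.shape, A.dtype, name="A")
    Bb = tvm.decl_buffer(B.shape, B.dtype, name="B")

    # create_bound_attribute=True makes StorageFlatten emit the
    # buffer_bound AttrStmts that BoundCollector later picks up.
    stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64, True)
    stmt = tvm.ir_pass.CanonicalSimplify(stmt)
    stmt = tvm.ir_pass.InstrumentBoundCheckers(stmt)
    print(stmt)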