diff --git a/include/tvm/build_module.h b/include/tvm/build_module.h
index ba340166339b1..8bb3345a5eb39 100644
--- a/include/tvm/build_module.h
+++ b/include/tvm/build_module.h
@@ -220,6 +220,9 @@ class BuildConfigNode : public Node {
   /*! \brief Whether to dump the IR of each pass (only when building from python) */
   bool dump_pass_ir = false;
 
+  /*! \brief Whether to instrument loads and stores with out-of-bounds checks. */
+  bool instrument_bound_checkers = false;
+
   void VisitAttrs(AttrVisitor* v) final {
     v->Visit("data_alignment", &data_alignment);
     v->Visit("offset_factor", &offset_factor);
@@ -232,6 +235,7 @@ class BuildConfigNode : public Node {
     v->Visit("detect_global_barrier", &detect_global_barrier);
     v->Visit("partition_const_loop", &partition_const_loop);
     v->Visit("dump_pass_ir", &dump_pass_ir);
+    v->Visit("instrument_bound_checkers", &instrument_bound_checkers);
   }
 
   static constexpr const char* _type_key = "BuildConfig";
diff --git a/include/tvm/ir.h b/include/tvm/ir.h
index 212234303c616..adaffa77dae67 100644
--- a/include/tvm/ir.h
+++ b/include/tvm/ir.h
@@ -206,6 +206,8 @@ constexpr const char* scan_init_scope = "scan_init_scope";
  * This gives hint to require stride of dim to be k * align + offset.
  */
 constexpr const char* buffer_dim_align = "buffer_dim_align";
+/*! \brief Mark stores/loads with their bounds. */
+constexpr const char* buffer_bound = "buffer_bound";
 /*!
  * \brief Bind the buffer specification to the region of the op
  * When this scope occurs, the stmt.node is a Array<NodeRef> = [buffer, tensor]
diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h
index 332becb7aa389..c19bd208ae98a 100644
--- a/include/tvm/ir_pass.h
+++ b/include/tvm/ir_pass.h
@@ -181,11 +181,12 @@ Stmt Inline(Stmt stmt,
  * \param extern_buffer Map specifies external
  *  buffer assignment of input and outputs.
  * \param cache_line_size The size of CPU cache line.
+ * \param create_bound_attribute Whether to create bound attributes.
  * \return Transformed stmt.
  */
-Stmt StorageFlatten(Stmt stmt,
-                    Map<Tensor, Buffer> extern_buffer,
-                    int cache_line_size);
+Stmt StorageFlatten(Stmt stmt, Map<Tensor, Buffer> extern_buffer,
+                    int cache_line_size,
+                    bool create_bound_attribute = false);
 
 /*!
  * \brief Remove No Op from the Stmt.
@@ -234,6 +235,13 @@
  */
 Stmt VectorizeLoop(Stmt stmt);
 
+/*!
+ * \brief Instrument bound checkers.
+ * \param stmt The statement to be instrumented.
+ * \return Instrumented Stmt.
+ */
+Stmt InstrumentBoundCheckers(Stmt stmt);
+
 /*!
  * \brief Inject virtual thread loops into stmt.
  * \param stmt The statment to be transformed.
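Note: the two entry points above are exposed to Python as tvm.ir_pass.StorageFlatten (now taking a fourth create_bound_attribute argument) and tvm.ir_pass.InstrumentBoundCheckers. A minimal sketch of driving them by hand, modeled on the lower() helper added in the new test file below; the tensors and the schedule here are only illustrative:

    import tvm

    A = tvm.placeholder((16,), name='A')
    B = tvm.compute((16,), lambda i: A[i] + 1, name='B')
    s = tvm.create_schedule(B.op).normalize()
    binds = {A: tvm.decl_buffer(A.shape, A.dtype, name='A_buf'),
             B: tvm.decl_buffer(B.shape, B.dtype, name='B_buf')}
    bounds = tvm.schedule.InferBound(s)
    stmt = tvm.schedule.ScheduleOps(s, bounds)
    # Ask StorageFlatten to attach the new buffer_bound attributes ...
    stmt = tvm.ir_pass.StorageFlatten(stmt, binds, 64, True)
    stmt = tvm.ir_pass.Simplify(stmt)
    # ... which InstrumentBoundCheckers turns into guarded loads/stores.
    stmt = tvm.ir_pass.InstrumentBoundCheckers(stmt)
    print(stmt)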
diff --git a/python/tvm/build_module.py b/python/tvm/build_module.py
index 2bb7442bab765..a2a33be946f6c 100755
--- a/python/tvm/build_module.py
+++ b/python/tvm/build_module.py
@@ -125,7 +125,8 @@ class BuildConfig(NodeBase):
         "data_alignment": -1,
         "restricted_func": True,
         "double_buffer_split_loop": 1,
-        "dump_pass_ir": False
+        "dump_pass_ir": False,
+        "instrument_bound_checkers": False
     }
     _dump_ir = DumpIR()
@@ -349,7 +350,7 @@ def lower(sch,
     for f in lower_phase0:
         stmt = f(stmt)
     # Phase 1
-    stmt = ir_pass.StorageFlatten(stmt, binds, 64)
+    stmt = ir_pass.StorageFlatten(stmt, binds, 64, cfg.instrument_bound_checkers)
     stmt = ir_pass.CanonicalSimplify(stmt)
     for f in lower_phase1:
         stmt = f(stmt)
@@ -375,6 +376,9 @@ def lower(sch,
     stmt = ir_pass.RewriteUnsafeSelect(stmt)
     for f in lower_phase3:
         stmt = f(stmt)
+    # Instrument BoundCheckers
+    if cfg.instrument_bound_checkers:
+        stmt = ir_pass.InstrumentBoundCheckers(stmt)
     if simple_mode:
         return stmt
     return ir_pass.MakeAPI(stmt, name, arg_list, 0, cfg.restricted_func)
diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc
index 575535f26e81f..cd672ec8fca32 100644
--- a/src/api/api_pass.cc
+++ b/src/api/api_pass.cc
@@ -66,6 +66,11 @@ TVM_REGISTER_API("ir_pass.Equal")
     }
   });
 
+TVM_REGISTER_API("ir_pass.StorageFlatten")
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+    *ret = StorageFlatten(args[0], args[1], args[2],
+                          args.size() == 3 ? false : args[3]);
+  });
 
 TVM_REGISTER_API("ir_pass.AttrsEqual")
 .set_body_typed<bool(const NodeRef&, const NodeRef&)>([](const NodeRef& lhs, const NodeRef& rhs) {
@@ -126,7 +131,6 @@ REGISTER_PASS1(ConvertSSA);
 REGISTER_PASS1(VerifySSA);
 REGISTER_PASS1(RewriteUnsafeSelect);
 REGISTER_PASS4(Inline);
-REGISTER_PASS3(StorageFlatten);
 REGISTER_PASS4(IRTransform);
 REGISTER_PASS1(VectorizeLoop);
 REGISTER_PASS5(UnrollLoop);
@@ -155,5 +159,6 @@ REGISTER_PASS1(CombineContextCall);
 REGISTER_PASS2(VerifyMemory);
 REGISTER_PASS2(VerifyGPUCode);
 REGISTER_PASS1(DecorateDeviceScope);
+REGISTER_PASS1(InstrumentBoundCheckers);
 }  // namespace ir
 }  // namespace tvm
diff --git a/src/codegen/build_module.cc b/src/codegen/build_module.cc
index c5c14d711df76..0659a07f2520a 100644
--- a/src/codegen/build_module.cc
+++ b/src/codegen/build_module.cc
@@ -364,7 +364,8 @@ Stmt BuildStmt(Schedule sch,
   stmt = ir::InjectPrefetch(stmt);
 
   // Phase 1
-  stmt = ir::StorageFlatten(stmt, out_binds, 64);
+  stmt = ir::StorageFlatten(stmt, out_binds, 64,
+                            config->instrument_bound_checkers);
   stmt = ir::CanonicalSimplify(stmt);
   if (loop_partition) {
     stmt = ir::LoopPartition(stmt, config->partition_const_loop);
@@ -382,6 +383,9 @@ Stmt BuildStmt(Schedule sch,
   stmt = ir::RemoveNoOp(stmt);
   stmt = ir::RewriteUnsafeSelect(stmt);
 
+  if (config->instrument_bound_checkers)
+    stmt = ir::InstrumentBoundCheckers(stmt);
+
   return stmt;
 }
diff --git a/src/pass/bound_checker.cc b/src/pass/bound_checker.cc
new file mode 100644
index 0000000000000..fd3a2c7a80d5f
--- /dev/null
+++ b/src/pass/bound_checker.cc
@@ -0,0 +1,193 @@
+/*!
+ *  Copyright (c) 2018 by Contributors
+ * \file bound_checker.cc
+ */
+// Instrument checkers for out-of-bounds access.
+
+#include <tvm/ir.h>
+#include <tvm/ir_mutator.h>
+#include <tvm/ir_pass.h>
+#include <tvm/ir_visitor.h>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace ir {
+
+class BoundCollector : public IRVisitor {
+ public:
+  BoundCollector() {}
+
+  void Visit_(const AttrStmt *op) {
+    if (op->attr_key == ir::attr::buffer_bound) {
+      if (const Variable *key = op->node.as<Variable>()) {
+        mem_to_shape[key] = op->value;
+      }
+    }
+    IRVisitor::Visit_(op);
+  }
+  // Hashtable which maps buffer_var to shape.
+  std::unordered_map<const Variable *, Expr> mem_to_shape;
+};
+
+class BoundChecker : public IRMutator {
+ public:
+  explicit BoundChecker(
+      const std::unordered_map<const Variable *, Expr> &mem_to_shape)
+      : mem_to_shape_(mem_to_shape) {}
+
+  Stmt Mutate_(const Allocate *op, const Stmt &s) final {
+    // If the shape was updated we should update the hashtable.
+    if (UpdateIsNeeded(op->buffer_var)) {
+      Update(op->buffer_var, op->extents, op->type);
+    }
+    return IRMutator::Mutate_(op, s);
+  }
+
+  Expr Mutate_(const Call *op, const Expr &ex) final {
+    if (process_store_ && op->is_intrinsic(intrinsic::tvm_if_then_else)) {
+      unsafe_rewritten_ = true;
+    }
+    return IRMutator::Mutate_(op, ex);
+  }
+
+  Stmt Mutate_(const Store *op, const Stmt &s) final {
+    store_scope_bound_collector_.clear();
+    process_store_ = true;
+    unsafe_rewritten_ = false;
+    IRMutator::Mutate_(op, s);
+    process_store_ = false;
+    if (CanInstrument(op->index, op->buffer_var)) {
+      Collect(op->index, op->buffer_var);
+    }
+    // The collector should have at least one item.
+    if (store_scope_bound_collector_.size()) {
+      Expr condition = MakeCondition();
+      if (!condition.as<StringImm>()) {
+        Stmt nop = Evaluate::make(1);
+        Stmt then_case =
+            Store::make(op->buffer_var, op->value, op->index, op->predicate);
+        Stmt else_case =
+            AssertStmt::make(condition, StringImm::make(error_message_), nop);
+        Stmt body = IfThenElse::make(condition, then_case, else_case);
+        return body;
+      }
+    }
+    return s;
+  }
+
+  Expr Mutate_(const Load *op, const Expr &ex) final {
+    if (CanInstrument(op->index, op->buffer_var)) {
+      Collect(op->index, op->buffer_var);
+    }
+    return IRMutator::Mutate_(op, ex);
+  }
+
+ private:
+  bool UpdateIsNeeded(const VarExpr &buffer_var) const {
+    return (buffer_var.defined() && mem_to_shape_.count(buffer_var.get()));
+  }
+
+  void Update(const VarExpr &buffer_var, const Array<Expr> &new_shape,
+              const Type &type) {
+    // Sanity check at first.
+    if (!new_shape.size())
+      return;
+
+    for (size_t i = 0; i < new_shape.size(); ++i) {
+      if (!new_shape[i].defined() || !new_shape[i].type().is_scalar() ||
+          is_negative_const(new_shape[i])) {
+        return;
+      }
+    }
+
+    // Scalarize the shape.
+    Expr shape = Mul::make(make_const(UInt(64), type.lanes()),
+                           Cast::make(UInt(64), new_shape[0]));
+    for (size_t i = 1; i < new_shape.size(); ++i) {
+      // Cast to unsigned first to avoid integer overflow.
+      shape = Mul::make(shape, Mul::make(make_const(UInt(64), type.lanes()),
+                                         Cast::make(UInt(64), new_shape[i])));
+    }
+    mem_to_shape_[buffer_var.get()] = shape;
+  }
+
+  bool IndexIsValid(const Expr &index) const {
+    if (!index.defined())
+      return false;
+
+    if (const Ramp *ramp_index = index.as<Ramp>()) {
+      return ramp_index->base.defined() &&
+             ramp_index->base.type().is_scalar() &&
+             ramp_index->stride.defined() &&
+             ramp_index->stride.type().is_scalar() && (ramp_index->lanes > 0);
+    }
+    return true;
+  }
+
+  bool CanInstrument(const Expr &index, const VarExpr &buffer_var) const {
+    return buffer_var.defined() && mem_to_shape_.count(buffer_var.get()) &&
+           IndexIsValid(index) && !unsafe_rewritten_;
+  }
+
+  void Collect(Expr index, VarExpr buffer_var) {
+    store_scope_bound_collector_.push_back(
+        std::make_pair(index, mem_to_shape_[buffer_var.get()]));
+  }
+
+  Expr MakeCondition() {
+    Expr condition;
+    for (size_t i = 0; i < store_scope_bound_collector_.size(); ++i) {
+      std::pair<Expr, Expr> buffer_to_mem = store_scope_bound_collector_[i];
+      Expr index = buffer_to_mem.first;
+      Expr upper_bound = buffer_to_mem.second;
+
+      if (const Ramp *ramp_index = index.as<Ramp>()) {
+        // In case the index is a ramp (base + stride * i), check the last lane.
+        // Non-inclusive range.
+        index = Add::make(
+            ramp_index->base,
+            Mul::make(ramp_index->stride, make_const(ramp_index->stride.type(),
+                                                     ramp_index->lanes - 1)));
+      }
+
+      // Try to simplify index and bound.
+      index = ir::Simplify(index);
+      upper_bound = ir::Simplify(upper_bound);
+
+      // Cast to the same signed type to be able to check the lower bound.
+      index = Cast::make(Int(64), index);
+      upper_bound = Cast::make(Int(64), upper_bound);
+
+      // Looks like the lower bound should always be zero after normalization.
+      Expr lower_bound = make_zero(Int(64));
+
+      Expr current_condition =
+          And::make(GE::make(index, lower_bound), LT::make(index, upper_bound));
+      condition =
+          !i ? current_condition : And::make(condition, current_condition);
+    }
+    return condition;
+  }
+
+  // Whether we process a store value recursively.
+  bool process_store_{false};
+  // Whether we encountered the tvm_if_then_else intrinsic.
+  bool unsafe_rewritten_{false};
+  // Pool which collects the pair of index and shape for a specific store/load.
+  std::vector<std::pair<Expr, Expr>> store_scope_bound_collector_;
+  // Error message.
+  const char *const error_message_ = "OUT OF THE BOUNDS";
+  // Hashtable which maps buffer_var to shape.
+  std::unordered_map<const Variable *, Expr> mem_to_shape_;
+};
+
+Stmt InstrumentBoundCheckers(Stmt stmt) {
+  BoundCollector bound_collector;
+  // At first walk recursively and collect bound attributes.
+  bound_collector.Visit(stmt);
+  return BoundChecker(bound_collector.mem_to_shape).Mutate(stmt);
+}
+}  // namespace ir
+}  // namespace tvm
diff --git a/src/pass/storage_flatten.cc b/src/pass/storage_flatten.cc
index 8c2105829839b..3753a81aafdfb 100644
--- a/src/pass/storage_flatten.cc
+++ b/src/pass/storage_flatten.cc
@@ -31,7 +31,8 @@ using intrinsic::tvm_address_of;
 class StorageFlattener : public IRMutator {
  public:
   explicit StorageFlattener(Map<Tensor, Buffer> extern_buffer,
-                            int cache_line_size) {
+                            int cache_line_size, bool create_bound_attributes)
+      : create_bound_attributes_(create_bound_attributes) {
     for (auto kv : extern_buffer) {
       BufferEntry e;
       e.buffer = kv.second;
@@ -101,6 +102,8 @@ class StorageFlattener : public IRMutator {
   }
 
   Stmt Mutate_(const Provide* op, const Stmt& s) final {
+    if (create_bound_attributes_)
+      shape_collector_.clear();
     Stmt stmt = IRMutator::Mutate_(op, s);
     op = stmt.as<Provide>();
     TensorKey key{op->func, op->value_index};
@@ -117,7 +120,20 @@
           {e.buffer->data, op->value},
           Call::Intrinsic));
     } else {
-      return e.buffer.vstore(e.RelIndex(op->args), op->value);
+      Stmt body = e.buffer.vstore(e.RelIndex(op->args), op->value);
+      if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) {
+        shape_collector_.push_back(
+            std::make_pair(e.buffer->data, e.buffer->shape));
+      }
+      // To create bound attributes, the collector should have at least one item.
+      if (create_bound_attributes_ && shape_collector_.size()) {
+        for (size_t i = 0; i < shape_collector_.size(); ++i) {
+          body = AttrStmt::make(
+              shape_collector_[i].first, ir::attr::buffer_bound,
+              MakeBound(e.buffer->dtype, shape_collector_[i].second), body);
+        }
+      }
+      return body;
     }
   }
@@ -216,6 +232,11 @@
       ret = AttrStmt::make(
           e.buffer->data, attr::storage_scope,
           StringImm::make(e.buffer->scope), ret);
+
+      if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) {
+        ret = AttrStmt::make(e.buffer->data, ir::attr::buffer_bound,
+                             MakeBound(e.buffer->dtype, e.buffer->shape), ret);
+      }
       return ret;
     }
   }
@@ -254,6 +275,11 @@
       const BufferEntry& e = it->second;
       CHECK(!e.released)
           << "Read a buffer that is already out of scope";
+
+      if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) {
+        shape_collector_.push_back(
+            std::make_pair(e.buffer->data, e.buffer->shape));
+      }
       return e.buffer.vload(e.RelIndex(op->args), e.buffer->dtype);
     } else {
       return expr;
     }
@@ -429,6 +455,30 @@
       }
     }
   };
+
+  bool ShapeIsValid(const Array<Expr> &shape) {
+    if (!shape.size())
+      return false;
+
+    for (size_t i = 0; i < shape.size(); ++i) {
+      if (!shape[i].defined() || !shape[i].type().is_scalar() ||
+          is_negative_const(shape[i])) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  Expr MakeBound(const Type &type, const Array<Expr> &shape) {
+    // We have already checked that the shape size is greater than 0.
+    Expr bound = Mul::make(make_const(shape[0].type(), type.lanes()), shape[0]);
+    for (size_t i = 1; i < shape.size(); ++i) {
+      bound = Mul::make(
+          bound, Mul::make(make_const(bound.type(), type.lanes()), shape[i]));
+    }
+    return bound;
+  }
+
   // The buffer assignment map
   // Variable remap
   std::unordered_map<const Variable*, Expr> var_remap_;
@@ -440,16 +490,21 @@
   std::unordered_map<const Node*, std::string> storage_scope_;
   // The current thread scope.
   std::vector<ThreadScope> curr_thread_scope_;
+  // Collects shapes.
+  std::vector<std::pair<VarExpr, Array<Expr>>> shape_collector_;
   // The size of cacheline
   int cache_line_size_;
   // The current stage is an OpenGL shader.
   bool is_opengl_{false};
+  // Whether to mark loads/stores with their bounds.
+  bool create_bound_attributes_{false};
 };
 
-Stmt StorageFlatten(Stmt stmt,
-                    Map<Tensor, Buffer> extern_buffer,
-                    int cache_line_size) {
-  stmt = StorageFlattener(extern_buffer, cache_line_size).Mutate(stmt);
+Stmt StorageFlatten(Stmt stmt, Map<Tensor, Buffer> extern_buffer,
+                    int cache_line_size, bool create_bound_attributes) {
+  stmt =
+      StorageFlattener(extern_buffer, cache_line_size, create_bound_attributes)
+          .Mutate(stmt);
   return stmt;
 }
diff --git a/tests/python/unittest/test_pass_bound_checkers.py b/tests/python/unittest/test_pass_bound_checkers.py
new file mode 100644
index 0000000000000..3eb6783c357cb
--- /dev/null
+++ b/tests/python/unittest/test_pass_bound_checkers.py
@@ -0,0 +1,561 @@
+from nose.tools import raises
+import tvm
+import numpy as np
+def collect_visit(stmt, f):
+    ret = []
+    tvm.ir_pass.PostOrderVisit(stmt, lambda x: ret.append(f(x)))
+    return ret
+
+def lower(sch, args):
+    binds = {}
+    arg_list = []
+    for x in args:
+        if isinstance(x, tvm.tensor.Tensor):
+            buf = tvm.decl_buffer(x.shape, dtype=x.dtype, name=x.name)
+            assert x not in binds
+            binds[x] = buf
+            arg_list.append(buf)
+        else:
+            raise ValueError("args must be Tensor, Buffer or Var")
+    sch = sch.normalize()
+    bounds = tvm.schedule.InferBound(sch)
+    stmt = tvm.schedule.ScheduleOps(sch, bounds)
+    stmt = tvm.ir_pass.LoopPartition(stmt, True)
+    stmt = tvm.ir_pass.StorageFlatten(stmt, binds, 64, True)
+    stmt = tvm.ir_pass.CanonicalSimplify(stmt)
+    stmt = tvm.ir_pass.VectorizeLoop(stmt)
+    stmt = tvm.ir_pass.Simplify(stmt)
+    return stmt
+
+@raises(Exception)
+def test_out_of_bounds_llvm(index_a, index_b):
+    n = tvm.var("n")
+    A = tvm.placeholder ((n,), name='A')
+    B = tvm.placeholder ((n,), name='B')
+    C = tvm.compute(A.shape, lambda i: A[i + index_a] + B[i + index_b], name='C')
+    s = tvm.create_schedule (C.op)
+    tgt = "llvm"
+    tgt_host = "llvm"
+    stmt = tvm.lower (s, [A, B, C], simple_mode=True)
+    print (stmt)
+    fadd = tvm.build (s, [A, B, C], tgt, target_host=tgt_host, name="myadd")
+    ctx = tvm.context(tgt, 0)
+    a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx)
+    b = tvm.nd.array(np.random.uniform(size=1024).astype(B.dtype), ctx)
+    c = tvm.nd.array(np.zeros(1024, dtype=C.dtype), ctx)
+    fadd (a, b, c)
+
+def test_in_bounds_llvm():
+    n = tvm.var("n")
+    A = tvm.placeholder ((n,), name='A')
+    B = tvm.placeholder ((n,), name='B')
+    C = tvm.compute(A.shape, lambda i: A[i] + B[i], name='C')
+    s = tvm.create_schedule (C.op)
+    tgt = "llvm"
+    tgt_host = "llvm"
+    stmt = tvm.lower (s, [A, B, C], simple_mode=True)
+    print (stmt)
+    fadd = tvm.build (s, [A, B, C], tgt, target_host=tgt_host, name="myadd")
+    ctx = tvm.context(tgt, 0)
+    a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx)
+    b = tvm.nd.array(np.random.uniform(size=1024).astype(B.dtype), ctx)
+    c = tvm.nd.array(np.zeros(1024, dtype=C.dtype), ctx)
+    fadd (a, b, c)
+
+@raises(Exception)
+def test_out_of_bounds_vectorize_llvm(nn, index_a, index_b):
+    n = tvm.convert(nn)
+    a = tvm.placeholder((n), name='a')
+    b = tvm.placeholder((n), name='b')
+    c = tvm.compute((n,), lambda i: a[i + index_a] + b[i + index_b], name='c')
+    s = tvm.create_schedule(c.op)
+    xo, xi = s[c].split(c.op.axis[0], factor=8)
+    s[c].parallel(xo)
+    s[c].vectorize(xi)
+    tgt = "llvm"
+    tgt_host = "llvm"
+    stmt = tvm.lower (s, [a, b, c], simple_mode=True)
+    print (stmt)
+    f = tvm.build(s, [a, b, c], tgt, target_host=tgt_host, name="myaddvec")
+    ctx = tvm.cpu(0)
+    n = nn
+    a = tvm.nd.array(np.random.uniform(size=(n)).astype(a.dtype), ctx)
+    b = tvm.nd.array(np.random.uniform(size=(n)).astype(a.dtype),
ctx) + c = tvm.nd.array(np.zeros(n, dtype=c.dtype), ctx) + f(a, b, c) + +def test_in_bounds_vectorize_llvm(): + n = 512 + lanes = 2 + A = tvm.placeholder((n,), name='A', dtype="float32x%d" % lanes) + B = tvm.compute((n,), lambda i: A[i], name='B') + C = tvm.compute((n,), lambda i: B[i] + tvm.const(1, A.dtype), name='C') + s = tvm.create_schedule(C.op) + xo, xi = s[C].split(C.op.axis[0], nparts=2) + _, xi = s[C].split(xi, factor=2) + s[C].parallel(xo) + s[C].vectorize(xi) + s[B].compute_at(s[C], xo) + xo, xi = s[B].split(B.op.axis[0], factor=2) + s[B].vectorize(xi) + # build and invoke the kernel. + lowered_func = tvm.lower (s, [A, C], "llvm", simple_mode=False) + print (lowered_func.body) + f = tvm.build(s, [A, C], "llvm") + ctx = tvm.cpu(0) + # launch the kernel. + a = tvm.nd.empty((n,), A.dtype).copyfrom( + np.random.uniform(size=(n, lanes))) + c = tvm.nd.empty((n,), C.dtype, ctx) + f(a, c) + tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + 1) + +def test_in_bounds_loop_partition_basic_llvm(): + n = tvm.var('n') + A = tvm.placeholder((n, ), name='A') + B = tvm.placeholder((n, ), name='B') + + T = tvm.compute((n, ), lambda i: A[i]+B[i]) + s = tvm.create_schedule(T.op) + xo, xi = s[T].split(T.op.axis[0], factor=4) + lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False) + print (lowered_func.body) + ctx = tvm.cpu(0) + + f = tvm.build(s, [A, B, T], "llvm") + a = tvm.nd.array(np.random.uniform(size=(32,)).astype(A.dtype), ctx) + b = tvm.nd.array(np.random.uniform(size=(32,)).astype(B.dtype), ctx) + t = tvm.nd.empty((32,), T.dtype, ctx) + f(a, b, t) + +@raises(Exception) +def test_out_of_bounds_loop_partition_basic_llvm(index_a, index_b): + n = tvm.var('n') + A = tvm.placeholder((n, ), name='A') + B = tvm.placeholder((n, ), name='B') + + T = tvm.compute((n, ), lambda i: A[i + index_a]+B[i + index_b]) + s = tvm.create_schedule(T.op) + xo, xi = s[T].split(T.op.axis[0], factor=4) + lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False) + print (lowered_func.body) + ctx = tvm.cpu(0) + + f = tvm.build(s, [A, B, T], "llvm") + a = tvm.nd.array(np.random.uniform(size=(32,)).astype(A.dtype), ctx) + b = tvm.nd.array(np.random.uniform(size=(32,)).astype(B.dtype), ctx) + t = tvm.nd.empty((32,), T.dtype, ctx) + f(a, b, t) + +def test_in_bounds_const_loop_partition_ir(): + def check_attr_stmt (x): + if isinstance(x, tvm.stmt.AttrStmt) and x.attr_key == "buffer_bound" and str(x.value) == str(n): + return True + return False + + def check_branch_stmt (x): + if isinstance(x, tvm.stmt.IfThenElse): + return True + return False + + def assert_bound_instrumentation(stmt, f, nums): + count = 0 + for i in collect_visit(stmt, f): + if i is True: + count = count + 1 + assert (count == nums) + + def collect_branch_stmt (x): + if isinstance(x, tvm.stmt.IfThenElse): + branch_collector.append(x) + + n = 21 + A = tvm.placeholder((n, ), name='A') + B = tvm.placeholder((n, ), name='B') + + T = tvm.compute((n, ), lambda i: A[i]+B[i]) + s = tvm.create_schedule(T.op) + xo, xi = s[T].split(T.op.axis[0], factor=4) + + bounds = tvm.schedule.InferBound(s) + stmt = lower (s, [A, B, T]) + # num_attributes = num_buffers * num_splits = 2 * 3 + # before instrumentation + assert_bound_instrumentation(stmt, check_attr_stmt, 2 * 3) + assert_bound_instrumentation(stmt, check_branch_stmt, 0) + stmt = tvm.ir_pass.InstrumentBoundCheckers(stmt) + # after instrumentation + assert_bound_instrumentation(stmt, check_attr_stmt, 2 * 3) + assert_bound_instrumentation(stmt, check_branch_stmt, 2) + print (stmt) + 
branch_collector = list() + collect_visit(stmt, collect_branch_stmt) + assert(len(branch_collector) == 2) + print (branch_collector[0].condition) + print (branch_collector[1].condition) + +def test_in_bounds_const_loop_partition_llvm(): + with tvm.build_config(instrument_bound_checkers=True, partition_const_loop=True): + n = 21 + A = tvm.placeholder((n, ), name='A') + B = tvm.placeholder((n, ), name='B') + + T = tvm.compute((n, ), lambda i: A[i]+B[i]) + s = tvm.create_schedule(T.op) + xo, xi = s[T].split(T.op.axis[0], factor=4) + lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False) + print (lowered_func.body) + ctx = tvm.cpu(0) + + f = tvm.build(s, [A, B, T], "llvm") + a = tvm.nd.array(np.random.uniform(size=(n,)).astype(A.dtype), ctx) + b = tvm.nd.array(np.random.uniform(size=(n,)).astype(B.dtype), ctx) + t = tvm.nd.empty((n,), T.dtype, ctx) + f(a, b, t) + +@raises(Exception) +def test_out_of_bounds_const_loop_partition_llvm(index_a, index_b): + with tvm.build_config(instrument_bound_checkers=True, partition_const_loop=True): + n = 21 + A = tvm.placeholder((n, ), name='A') + B = tvm.placeholder((n, ), name='B') + + T = tvm.compute((n, ), lambda i: A[i + index_a]+B[i + index_b]) + s = tvm.create_schedule(T.op) + xo, xi = s[T].split(T.op.axis[0], factor=4) + lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False) + print (lowered_func.body) + ctx = tvm.cpu(0) + + f = tvm.build(s, [A, B, T], "llvm") + a = tvm.nd.array(np.random.uniform(size=(n,)).astype(A.dtype), ctx) + b = tvm.nd.array(np.random.uniform(size=(n,)).astype(B.dtype), ctx) + t = tvm.nd.empty((n,), T.dtype, ctx) + f(a, b, t) + +def test_in_bounds_conv_llvm(loop_tiling=False): + HSTR = WSTR = 1 + in_channel = 128 + kernel_height = kernel_width = 3 + out_channel = 64 + batch_size = 1 + in_height = in_width = 64 + out_height = out_width = in_height - kernel_height + 1 + data = tvm.placeholder((batch_size, in_channel, in_height, in_width), name='data') + kernel = tvm.placeholder((kernel_height, kernel_width, in_channel, + out_channel), name='kernel') + ic = tvm.reduce_axis((0, in_channel), name='ic') + kh = tvm.reduce_axis((0, kernel_height), name='kh') + kw = tvm.reduce_axis((0, kernel_width), name='kw') + conv = tvm.compute((batch_size, out_channel, out_height, out_width), + lambda n, oc, oh, ow: tvm.sum(data[n, ic, oh*HSTR + kh, ow*WSTR + kw] * + kernel[kh, kw, ic, oc], + axis=[ic, kh, kw]), + name="conv2d") + s = tvm.create_schedule(conv.op) + + n, oc, oh, ow = conv.op.axis + if loop_tiling: + oho, owo, ohi, owi = s[conv].tile(oh, ow, 16, 16) + lowered_func = tvm.lower(s, [data, kernel, conv], simple_mode=True) + print (lowered_func.body) + ctx = tvm.cpu (0) + + f = tvm.build(s, [data, kernel, conv], "llvm") + data_input = tvm.nd.array(np.random.uniform( + size=(batch_size, in_channel, in_height, in_width)).astype(tvm.float32), ctx) + kernel_input = tvm.nd.array(np.random.uniform( + size=(kernel_height, kernel_width, in_channel, out_channel)).astype(tvm.float32), ctx) + conv_out = tvm.nd.empty ((batch_size, out_channel, out_height, out_width), tvm.float32, ctx) + f(data_input, kernel_input, conv_out) + +@raises(Exception) +def test_out_of_bounds_conv_llvm(data_offsets, kernel_offsets, loop_tiling=False): + HSTR = WSTR = 1 + in_channel = 128 + kernel_height = kernel_width = 3 + out_channel = 64 + batch_size = 1 + in_height = in_width = 64 + out_height = out_width = in_height - kernel_height + 1 + data = tvm.placeholder((batch_size, in_channel, in_height, in_width), name='data') + kernel = 
tvm.placeholder((kernel_height, kernel_width, in_channel, + out_channel), name='kernel') + ic = tvm.reduce_axis((0, in_channel), name='ic') + kh = tvm.reduce_axis((0, kernel_height), name='kh') + kw = tvm.reduce_axis((0, kernel_width), name='kw') + conv = tvm.compute((batch_size, out_channel, out_height, out_width), + lambda n, oc, oh, ow: tvm.sum(data[n + data_offsets[0], + ic + data_offsets[1], + oh*HSTR + kh + data_offsets[2], + ow*WSTR + kw + data_offsets[3]] + * + kernel[kh + kernel_offsets[0], + kw + kernel_offsets[1], + ic + kernel_offsets[2], + oc + kernel_offsets[3]], + axis=[ic, kh, kw]), + name="conv2d") + s = tvm.create_schedule(conv.op) + + n, oc, oh, ow = conv.op.axis + if loop_tiling: + oho, owo, ohi, owi = s[conv].tile(oh, ow, 16, 16) + lowered_func = tvm.lower(s, [data, kernel, conv], simple_mode=True) + print (lowered_func.body) + ctx = tvm.cpu (0) + + f = tvm.build(s, [data, kernel, conv], "llvm") + data_input = tvm.nd.array(np.random.uniform( + size=(batch_size, in_channel, in_height, in_width)).astype(tvm.float32), ctx) + kernel_input = tvm.nd.array(np.random.uniform( + size=(kernel_height, kernel_width, in_channel, out_channel)).astype(tvm.float32), ctx) + conv_out = tvm.nd.empty ((batch_size, out_channel, out_height, out_width), tvm.float32, ctx) + f(data_input, kernel_input, conv_out) + +def test_in_bounds_tensors_with_same_shapes1D_llvm(): + n = tvm.var('n') + k = tvm.var('k') + m = tvm.var('m') + A = tvm.placeholder((n, ), name='A') + B = tvm.placeholder((k, ), name='B') + + T = tvm.compute((m, ), lambda i: A[i]*B[i]) + s = tvm.create_schedule(T.op) + lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False) + print (lowered_func.body) + ctx = tvm.cpu(0) + + f = tvm.build(s, [A, B, T], "llvm") + a = tvm.nd.array(np.random.uniform(size=(32, )).astype(A.dtype), ctx) + b = tvm.nd.array(np.random.uniform(size=(32,)).astype(B.dtype), ctx) + t = tvm.nd.empty((32,), T.dtype, ctx) + f(a, b, t) + +@raises(Exception) +def test_out_of_bounds_tensors_with_diff_shapes1D_llvm(a_shape, b_shape, c_shape): + n = tvm.var('n') + k = tvm.var('k') + m = tvm.var('m') + A = tvm.placeholder((n, ), name='A') + B = tvm.placeholder((k, ), name='B') + + T = tvm.compute((m, ), lambda i: A[i]*B[i]) + s = tvm.create_schedule(T.op) + lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False) + print (lowered_func.body) + ctx = tvm.cpu(0) + + f = tvm.build(s, [A, B, T], "llvm") + a = tvm.nd.array(np.random.uniform(size=(a_shape,)).astype(A.dtype), ctx) + b = tvm.nd.array(np.random.uniform(size=(b_shape,)).astype(B.dtype), ctx) + t = tvm.nd.empty((c_shape,), T.dtype, ctx) + f(a, b, t) + +def test_in_bounds_tensors_with_same_shapes2D_llvm(): + n = tvm.var('n') + k = tvm.var('k') + m = tvm.var('m') + A = tvm.placeholder((n, n), name='A') + B = tvm.placeholder((k, k), name='B') + + T = tvm.compute((m, m), lambda i, j: A[i][j]*B[i][j]) + s = tvm.create_schedule(T.op) + lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False) + print (lowered_func.body) + ctx = tvm.cpu(0) + + f = tvm.build(s, [A, B, T], "llvm") + a = tvm.nd.array(np.random.uniform(size=(32, 32)).astype(A.dtype), ctx) + b = tvm.nd.array(np.random.uniform(size=(32, 32)).astype(B.dtype), ctx) + t = tvm.nd.empty((32, 32), T.dtype, ctx) + f(a, b, t) + +@raises(Exception) +def test_out_of_bounds_tensors_with_diff_shapes2D_llvm(a_shape, b_shape, c_shape): + n = tvm.var('n') + k = tvm.var('k') + m = tvm.var('m') + A = tvm.placeholder((n, n), name='A') + B = tvm.placeholder((k, k), name='B') + + T = tvm.compute((m, 
m), lambda i, j: A[i][j]*B[i][j]) + s = tvm.create_schedule(T.op) + lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False) + print (lowered_func.body) + ctx = tvm.cpu(0) + + f = tvm.build(s, [A, B, T], "llvm") + a = tvm.nd.array(np.random.uniform(size=(a_shape[0],a_shape[1])).astype(A.dtype), ctx) + b = tvm.nd.array(np.random.uniform(size=(b_shape[0],b_shape[1])).astype(B.dtype), ctx) + t = tvm.nd.empty((c_shape[0],c_shape[1]), T.dtype, ctx) + f(a, b, t) + +def test_in_bounds_tensors_with_same_shapes3D_llvm(): + n = tvm.var('n') + k = tvm.var('k') + m = tvm.var('m') + A = tvm.placeholder((n, n, n), name='A') + B = tvm.placeholder((k, k, k), name='B') + + T = tvm.compute((m, m, m), lambda i, j, p: A[i][j][p]*B[i][j][p]) + s = tvm.create_schedule(T.op) + lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False) + print (lowered_func.body) + ctx = tvm.cpu(0) + + f = tvm.build(s, [A, B, T], "llvm") + a = tvm.nd.array(np.random.uniform(size=(32,32,32)).astype(A.dtype), ctx) + b = tvm.nd.array(np.random.uniform(size=(32,32,32)).astype(B.dtype), ctx) + t = tvm.nd.empty((32, 32, 32), T.dtype, ctx) + f(a, b, t) + +@raises(Exception) +def test_out_of_bounds_tensors_with_diff_shapes3D_llvm(a_shape, b_shape, c_shape): + n = tvm.var('n') + k = tvm.var('k') + m = tvm.var('m') + A = tvm.placeholder((n, n, n), name='A') + B = tvm.placeholder((k, k, k), name='B') + + T = tvm.compute((m, m, m), lambda i, j, p: A[i][j][p]*B[i][j][p]) + s = tvm.create_schedule(T.op) + lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False) + print (lowered_func.body) + ctx = tvm.cpu(0) + + f = tvm.build(s, [A, B, T], "llvm") + a = tvm.nd.array(np.random.uniform(size=(a_shape[0],a_shape[1], c_shape[2])).astype(A.dtype), ctx) + b = tvm.nd.array(np.random.uniform(size=(b_shape[0],b_shape[1], b_shape[2])).astype(B.dtype), ctx) + t = tvm.nd.empty((c_shape[0],c_shape[1],c_shape[2]), T.dtype, ctx) + f(a, b, t) + +def test_in_bounds_tensors_with_zero_shape_llvm(): + A = tvm.placeholder((), name='A') + B = tvm.placeholder((), name='B') + C = tvm.compute((), lambda : A + B + 1) + s = tvm.create_schedule(C.op) + lowered_func = tvm.lower(s, [A, B, C], simple_mode=False) + print (lowered_func.body) + f = tvm.build(s, [A, B, C], "llvm") + ctx = tvm.cpu(0) + a = tvm.nd.array( + np.random.randint(0, 2, size=()).astype(A.dtype), ctx) + b = tvm.nd.array( + np.random.randint(0, 2, size=()).astype(B.dtype), ctx) + c = tvm.nd.empty((), C.dtype, ctx) + f(a, b, c) + c_np = np.sum(a.asnumpy()) + b.asnumpy() + 1 + tvm.testing.assert_allclose(c.asnumpy(), c_np) + +@raises(Exception) +def test_out_of_bounds_tensors_with_zero_shape_op_with_not_zero_shape_llvm(): + n = 64 + A = tvm.placeholder((n, ), name='A') + scale = tvm.placeholder((), name='scale') + k = tvm.reduce_axis((0, n), name="k") + C = tvm.compute((), lambda : tvm.sum(A[k + k + k] * scale, axis=k), name="C") + D = tvm.compute((), lambda : C + 1) + s = tvm.create_schedule(D.op) + stmt = tvm.lower (s, [A, scale, D], simple_mode=True) + print (stmt) + # build and invoke the kernel. + f = tvm.build(s, [A, scale, D], "llvm") + ctx = tvm.cpu(0) + # launch the kernel. 
+ a = tvm.nd.array(np.random.randint(0, 2, size=(n,)).astype(A.dtype), ctx) + sc = tvm.nd.array( + np.random.randint(0, 2, size=()).astype(scale.dtype), ctx) + d = tvm.nd.empty((), D.dtype, ctx) + f(a, sc, d) + d_np = np.sum(a.asnumpy()) * sc.asnumpy() + 1 + tvm.testing.assert_allclose(d.asnumpy(), d_np) + +if __name__ == "__main__": + with tvm.build_config(instrument_bound_checkers=True): + # zero scale + test_out_of_bounds_tensors_with_zero_shape_op_with_not_zero_shape_llvm() + test_in_bounds_tensors_with_zero_shape_llvm() + # in bound + test_in_bounds_llvm() + # upper bound + test_out_of_bounds_llvm(1, 0) + test_out_of_bounds_llvm(0, 1) + test_out_of_bounds_llvm(1, 1) + test_out_of_bounds_llvm(10000, 0) + test_out_of_bounds_llvm(0, 10000) + test_out_of_bounds_llvm(10000, 10000) + # lower bound + test_out_of_bounds_llvm(-1, 0) + test_out_of_bounds_llvm(0, -1) + test_out_of_bounds_llvm(-1, -1) + test_out_of_bounds_llvm(-10000, 0) + test_out_of_bounds_llvm(0, -10000) + test_out_of_bounds_llvm(-10000, -10000) + # vectorize in bound + test_in_bounds_vectorize_llvm() + # vectorization upper bound + test_out_of_bounds_vectorize_llvm(1024, 1000, 0) + test_out_of_bounds_vectorize_llvm(1024, 0, 10000) + # vectorization lower bound + test_out_of_bounds_vectorize_llvm(1024, -1000, 0) + test_out_of_bounds_vectorize_llvm(1024, 0, -10000) + test_in_bounds_const_loop_partition_llvm() + test_out_of_bounds_const_loop_partition_llvm(1, 0) + test_out_of_bounds_const_loop_partition_llvm(0, 1) + test_out_of_bounds_const_loop_partition_llvm(-1, 0) + test_out_of_bounds_const_loop_partition_llvm(0, -1) + test_in_bounds_loop_partition_basic_llvm() + test_out_of_bounds_loop_partition_basic_llvm(32, 0) + test_out_of_bounds_loop_partition_basic_llvm(0, 32) + test_out_of_bounds_loop_partition_basic_llvm(-32, 0) + test_out_of_bounds_loop_partition_basic_llvm(0, -32) + # conv + test_in_bounds_conv_llvm() + test_out_of_bounds_conv_llvm([1, 0, 0, 0], [0, 0, 0, 0]) + test_out_of_bounds_conv_llvm([0, 1, 0, 0], [0, 0, 0, 0]) + test_out_of_bounds_conv_llvm([0, 0, 1, 0], [0, 0, 0, 0]) + test_out_of_bounds_conv_llvm([0, 0, 0, 1], [0, 0, 0, 0]) + test_out_of_bounds_conv_llvm([-1, 0, 0, 0], [0, 0, 0, 0]) + test_out_of_bounds_conv_llvm([0, -1, 0, 0], [0, 0, 0, 0]) + test_out_of_bounds_conv_llvm([0, 0, -1, 0], [0, 0, 0, 0]) + test_out_of_bounds_conv_llvm([0, 0, 0, -1], [0, 0, 0, 0]) + test_out_of_bounds_conv_llvm([0, 0, 0, 0], [1, 0, 0, 0]) + test_out_of_bounds_conv_llvm([0, 0, 0, 0], [0, 1, 0, 0]) + test_out_of_bounds_conv_llvm([0, 0, 0, 0], [0, 0, 1, 0]) + test_out_of_bounds_conv_llvm([0, 0, 0, 0], [0, 0, 0, 1]) + test_out_of_bounds_conv_llvm([0, 0, 0, 0], [-1, 0, 0, 0]) + test_out_of_bounds_conv_llvm([0, 0, 0, 0], [0, -1, 0, 0]) + test_out_of_bounds_conv_llvm([0, 0, 0, 0], [0, 0, -1, 0]) + test_out_of_bounds_conv_llvm([0, 0, 0, 0], [0, 0, 0, -1]) + # loop tiling + test_in_bounds_conv_llvm(True) + test_out_of_bounds_conv_llvm([1, 0, 0, 0], [0, 0, 0, 0], True) + test_out_of_bounds_conv_llvm([0, 1, 0, 0], [0, 0, 0, 0], True) + test_out_of_bounds_conv_llvm([0, 0, 1, 0], [0, 0, 0, 0], True) + test_out_of_bounds_conv_llvm([0, 0, 0, 1], [0, 0, 0, 0], True) + test_out_of_bounds_conv_llvm([-1, 0, 0, 0], [0, 0, 0, 0], True) + test_out_of_bounds_conv_llvm([0, -1, 0, 0], [0, 0, 0, 0], True) + test_out_of_bounds_conv_llvm([0, 0, -1, 0], [0, 0, 0, 0], True) + test_out_of_bounds_conv_llvm([0, 0, 0, -1], [0, 0, 0, 0], True) + test_out_of_bounds_conv_llvm([0, 0, 0, 0], [1, 0, 0, 0], True) + test_out_of_bounds_conv_llvm([0, 0, 0, 0], [0, 1, 0, 
0], True) + test_out_of_bounds_conv_llvm([0, 0, 0, 0], [0, 0, 1, 0], True) + test_out_of_bounds_conv_llvm([0, 0, 0, 0], [0, 0, 0, 1], True) + test_out_of_bounds_conv_llvm([0, 0, 0, 0], [-1, 0, 0, 0], True) + test_out_of_bounds_conv_llvm([0, 0, 0, 0], [0, -1, 0, 0], True) + test_out_of_bounds_conv_llvm([0, 0, 0, 0], [0, 0, -1, 0], True) + test_out_of_bounds_conv_llvm([0, 0, 0, 0], [0, 0, 0, -1], True) + # tensors with diff shapes basic operation such as mul + test_out_of_bounds_tensors_with_diff_shapes1D_llvm (32, 64, 64) + test_out_of_bounds_tensors_with_diff_shapes1D_llvm (64, 32, 64) + test_out_of_bounds_tensors_with_diff_shapes2D_llvm([64, 64], [32, 32], [64, 64]) + test_out_of_bounds_tensors_with_diff_shapes2D_llvm([32, 32], [64, 64], [64, 64]) + test_out_of_bounds_tensors_with_diff_shapes3D_llvm([64, 64, 64], [32, 32, 32], [64, 64, 64]) + test_out_of_bounds_tensors_with_diff_shapes3D_llvm([32, 32, 32], [64, 64, 64], [64, 64, 64]) + # check tensors with the same shapes + test_in_bounds_tensors_with_same_shapes1D_llvm() + test_in_bounds_tensors_with_same_shapes2D_llvm() + test_in_bounds_tensors_with_same_shapes3D_llvm() + # ir tests + test_in_bounds_const_loop_partition_ir()
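End-to-end, the checker is driven through BuildConfig rather than by calling the passes manually. A condensed sketch of the intended user-facing behaviour, in the spirit of test_out_of_bounds_llvm above (the tensor names, sizes, and the deliberately off-by-one index are arbitrary): with instrument_bound_checkers enabled, the out-of-range load trips the generated assertion at run time instead of silently reading past the buffer.

    import numpy as np
    import tvm

    n = 1024
    A = tvm.placeholder((n,), name='A')
    B = tvm.compute((n,), lambda i: A[i + 1], name='B')  # reads one element past A
    s = tvm.create_schedule(B.op)

    with tvm.build_config(instrument_bound_checkers=True):
        f = tvm.build(s, [A, B], "llvm")

    ctx = tvm.cpu(0)
    a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
    b = tvm.nd.array(np.zeros(n, dtype=B.dtype), ctx)
    try:
        f(a, b)
    except Exception as err:  # the failed bound check surfaces as a TVMError
        print("bound check fired:", err)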