diff --git a/python/tvm/_ctypes/_api.py b/python/tvm/_ctypes/_api.py
index 4de64f9db4b6..cf5df619001c 100644
--- a/python/tvm/_ctypes/_api.py
+++ b/python/tvm/_ctypes/_api.py
@@ -225,7 +225,7 @@ def func(*args):
         """TVM function"""
         cargs = []
         for x in args:
-            if isinstance(x, (list, tuple, SliceBase)):
+            if isinstance(x, (list, tuple, dict, SliceBase)):
                 cargs.append(convert(x))
             else:
                 cargs.append(x)
diff --git a/python/tvm/function.py b/python/tvm/function.py
index 78491404d7b1..72ec0d2680de 100644
--- a/python/tvm/function.py
+++ b/python/tvm/function.py
@@ -133,7 +133,8 @@ def compute(shape, fcompute, name="compute"):
 
 
 def Buffer(shape, dtype=None,
-           name="buffer", ptr=None,
+           name="buffer",
+           ptr=None,
            strides=None):
     """Create a new buffer
 
diff --git a/src/c_api/c_api_pass.cc b/src/c_api/c_api_pass.cc
index 2d4cb6e3fb55..e05f696bd35b 100644
--- a/src/c_api/c_api_pass.cc
+++ b/src/c_api/c_api_pass.cc
@@ -36,6 +36,7 @@ REGISTER_PASS1(ConvertSSA);
 REGISTER_PASS1(VerifySSA);
 REGISTER_PASS4(Inline);
 REGISTER_PASS2(ScheduleOps);
+REGISTER_PASS2(StorageFlatten);
 
 }  // namespace ir
 }  // namespace tvm
diff --git a/src/lang/buffer.cc b/src/lang/buffer.cc
index 02cd05224e53..b44bca783834 100644
--- a/src/lang/buffer.cc
+++ b/src/lang/buffer.cc
@@ -51,7 +51,7 @@ Expr Buffer::MakeLoad(Array<Expr> index) const {
 Stmt Buffer::MakeStore(Array<Expr> index, Expr value) const {
   const BufferNode* n = operator->();
   CHECK_EQ(value.type(), n->dtype);
-  return ir::Store::make(n->ptr, BufferOffset(n, index), value);
+  return ir::Store::make(n->ptr, value, BufferOffset(n, index));
 }
 
 Buffer BufferNode::make(std::string name,
diff --git a/src/pass/ir_mutator.cc b/src/pass/ir_mutator.cc
index 2c534a6c1b28..b2572b88e44d 100644
--- a/src/pass/ir_mutator.cc
+++ b/src/pass/ir_mutator.cc
@@ -83,7 +83,7 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
         body.same_as(op->body)) {
       return s;
     } else {
-      return AttrStmt::make(op->node, op->type_key, op->value, op->body);
+      return AttrStmt::make(op->node, op->type_key, value, body);
     }
   });
 
diff --git a/src/pass/storage_flatten.cc b/src/pass/storage_flatten.cc
new file mode 100644
index 000000000000..6058b6907fe7
--- /dev/null
+++ b/src/pass/storage_flatten.cc
@@ -0,0 +1,168 @@
+/*!
+ * Copyright (c) 2016 by Contributors
+ * \file storage_flatten.cc
+ */
+#include <tvm/ir.h>
+#include <tvm/ir_mutator.h>
+#include <tvm/ir_pass.h>
+#include <unordered_map>
+
+namespace tvm {
+namespace ir {
+
+// Key identifying one output buffer of a function.
+struct TensorKey {
+  FunctionRef f;
+  int value_index;
+
+  inline bool operator==(const TensorKey& other) const {
+    return f == other.f && value_index == other.value_index;
+  }
+  inline std::string GetName() const {
+    if (f->num_outputs() == 1) return f->func_name();
+    std::ostringstream os;
+    os << f->func_name() << ".v" << value_index;
+    return os.str();
+  }
+};
+
+}  // namespace ir
+}  // namespace tvm
+
+namespace std {
+template <>
+struct hash<::tvm::ir::TensorKey> {
+  std::size_t operator()(const ::tvm::ir::TensorKey& k) const {
+    size_t lhs = k.f.hash();
+    size_t rhs = static_cast<size_t>(k.value_index);
+    // boost-style hash_combine of the function hash and the value index
+    lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
+    return lhs;
+  }
+};
+}  // namespace std
+
+namespace tvm {
+namespace ir {
+
+using Halide::Internal::Region;
+
+// Flatten Realize/Provide/Halide-call nodes on multi-dimensional
+// tensors into Allocate/Store/Load on flat buffers.
+class StorageFlattener : public IRMutator {
+ public:
+  explicit StorageFlattener(Map<Tensor, Buffer> extern_buffer) {
+    for (auto kv : extern_buffer) {
+      BufferEntry e;
+      e.buffer = kv.second;
+      e.external = true;
+      buf_map_[TensorKey{kv.first->op, kv.first->value_index}] = e;
+    }
+  }
+  Expr Mutate(Expr expr) final {
+    expr = IRMutator::Mutate(expr);
+    const Call* op = expr.as<Call>();
+    if (op != nullptr && op->call_type == Call::Halide) {
+      TensorKey key{op->func, op->value_index};
+      auto it = buf_map_.find(key);
+      CHECK(it != buf_map_.end())
+          << "Cannot find allocated buffer for " << key.f;
+      const BufferEntry& e = it->second;
+      CHECK(!e.released)
+          << "Read a buffer that is already out of scope";
+      return e.buffer.MakeLoad(e.RelIndex(op->args));
+    } else {
+      return expr;
+    }
+  }
+
+  Stmt Mutate(Stmt stmt) final {
+    const Realize* realize = stmt.as<Realize>();
+    if (realize != nullptr) {
+      return HandleRealize(realize);
+    } else if (stmt.as<Provide>()) {
+      return HandleProvide(stmt);
+    } else {
+      return IRMutator::Mutate(stmt);
+    }
+  }
+
+ private:
+  // The buffer entry in the flatten map
+  struct BufferEntry {
+    // the buffer of storage
+    Buffer buffer;
+    // the bounds of realization, can be null
+    Region bounds;
+    // Whether the buffer is external
+    bool external{false};
+    // Whether the realization scope has ended and the buffer is released.
+    bool released{false};
+    // TODO(tqchen) allow permutation and inference of index dimension.
+    // Index relative to the realization bounds.
+    inline Array<Expr> RelIndex(Array<Expr> args) const {
+      if (bounds.size() != 0) {
+        Array<Expr> index;
+        CHECK_EQ(bounds.size(), args.size());
+        for (size_t i = 0; i < bounds.size(); ++i) {
+          index.push_back(args[i] - bounds[i]->min);
+        }
+        return index;
+      } else {
+        return args;
+      }
+    }
+  };
+
+  // The buffer assignment map
+  std::unordered_map<TensorKey, BufferEntry> buf_map_;
+
+  Stmt HandleRealize(const Realize* op) {
+    TensorKey key{op->func, op->value_index};
+    if (buf_map_.count(key)) {
+      CHECK(buf_map_.at(key).external);
+      return this->Mutate(op->body);
+    } else {
+      // create a buffer entry
+      // TODO(tqchen) allow permutation and inference of index dimension.
+      BufferEntry e;
+      e.bounds = op->bounds;
+      Array<Expr> shape;
+      for (auto r : e.bounds) {
+        shape.push_back(r->extent);
+      }
+      e.buffer = Buffer(shape, op->type, key.GetName());
+
+      buf_map_[key] = e;
+      Stmt body = this->Mutate(op->body);
+      buf_map_[key].released = true;
+
+      return Allocate::make(
+          e.buffer->ptr, e.buffer->dtype, e.buffer->shape,
+          make_const(Bool(e.buffer->dtype.lanes()), true), body);
+    }
+  }
+
+  Stmt HandleProvide(Stmt stmt) {
+    stmt = IRMutator::Mutate(stmt);
+    const Provide* op = stmt.as<Provide>();
+    TensorKey key{op->func, op->value_index};
+    auto it = buf_map_.find(key);
+    CHECK(it != buf_map_.end())
+        << "Cannot find allocated buffer for " << key.f;
+    const BufferEntry& e = it->second;
+    CHECK(!e.released)
+        << "Write to a buffer that is already out of scope";
+    return e.buffer.MakeStore(e.RelIndex(op->args), op->value);
+  }
+};
+
+
+Stmt StorageFlatten(Stmt stmt,
+                    Map<Tensor, Buffer> extern_buffer) {
+  stmt = StorageFlattener(extern_buffer).Mutate(stmt);
+  return stmt;
+}
+
+}  // namespace ir
+}  // namespace tvm
diff --git a/tests/python/test_pass_storage_flatten.py b/tests/python/test_pass_storage_flatten.py
new file mode 100644
index 000000000000..b7dff05d0f6f
--- /dev/null
+++ b/tests/python/test_pass_storage_flatten.py
@@ -0,0 +1,24 @@
+import tvm
+
+def test_flatten2():
+    m = tvm.Var('m')
+    l = tvm.Var('l')
+    A = tvm.placeholder((m, l), name='A')
+    A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
+    A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
+
+    s = tvm.Schedule(A2.op)
+    xo, xi = s[A2].split(A2.op.axis[0], 8)
+    s[A1].compute_at(s[A2], xo)
+    bounds = tvm.schedule.InferBound(s)
+    assert isinstance(bounds, tvm.collections.Map)
+    stmt = tvm.ir_pass.ScheduleOps(s, bounds)
+
+    print(stmt)
+    Ab = tvm.Buffer(A.shape, A.dtype, name='A')
+    A2b = tvm.Buffer(A2.shape, A2.dtype, name='A2')
+    stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b})
+    print(stmt)
+
+if __name__ == "__main__":
+    test_flatten2()
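
Usage note (not part of the diff): a minimal sketch of driving the new pass on a single-stage schedule, using only the Python APIs already exercised by test_pass_storage_flatten.py above. The function name flatten_single, the one-dimensional shapes, and the trivial compute are illustrative choices, not part of this change.

    import tvm

    def flatten_single():
        # Hypothetical single-stage example; every call below appears in the diff.
        m = tvm.Var('m')
        A = tvm.placeholder((m,), name='A')
        A1 = tvm.compute((m,), lambda i: A[i] + 1, name='A1')

        s = tvm.Schedule(A1.op)
        bounds = tvm.schedule.InferBound(s)
        stmt = tvm.ir_pass.ScheduleOps(s, bounds)

        # Bind external buffers for the input and output tensors.
        # Passing a dict here relies on the _ctypes/_api.py change that
        # converts dict arguments.
        Ab = tvm.Buffer(A.shape, A.dtype, name='A')
        A1b = tvm.Buffer(A1.shape, A1.dtype, name='A1')
        # The Provide on A1 becomes a flat Store into A1b; the Halide
        # call on A becomes a flat Load from Ab. Only non-external
        # realizations would get an Allocate.
        stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, A1: A1b})
        print(stmt)

    if __name__ == "__main__":
        flatten_single()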