Commit 7f82912

[PASS] Basic storage flatten (#13)

tqchen authored Jan 16, 2017
1 parent 0992873 commit 7f82912
Showing 7 changed files with 198 additions and 4 deletions.
2 changes: 1 addition & 1 deletion python/tvm/_ctypes/_api.py
@@ -225,7 +225,7 @@ def func(*args):
     """TVM function"""
     cargs = []
     for x in args:
-        if isinstance(x, (list, tuple, SliceBase)):
+        if isinstance(x, (list, tuple, dict, SliceBase)):
             cargs.append(convert(x))
         else:
             cargs.append(x)
3 changes: 2 additions & 1 deletion python/tvm/function.py
@@ -133,7 +133,8 @@ def compute(shape, fcompute, name="compute"):


 def Buffer(shape, dtype=None,
-           name="buffer", ptr=None,
+           name="buffer",
+           ptr=None,
            strides=None):
     """Create a new buffer
1 change: 1 addition & 0 deletions src/c_api/c_api_pass.cc
@@ -36,6 +36,7 @@ REGISTER_PASS1(ConvertSSA);
 REGISTER_PASS1(VerifySSA);
 REGISTER_PASS4(Inline);
 REGISTER_PASS2(ScheduleOps);
+REGISTER_PASS2(StorageFlatten);

 } // namespace ir
 } // namespace tvm
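
The numeric suffix on these REGISTER_PASSn macros tracks the argument count of the registered pass: StorageFlatten(Stmt, Map<Tensor, Buffer>) takes two arguments, hence REGISTER_PASS2. A minimal sketch of the arity-wrapping pattern this implies (not TVM's actual macro; the int-based registry here is purely illustrative):

#include <cassert>
#include <functional>
#include <map>
#include <string>
#include <vector>

// Registry mapping a pass name to a uniform packed signature.
using PackedPass = std::function<int(const std::vector<int>&)>;
static std::map<std::string, PackedPass> pass_registry;

// Arity-suffixed registration: wrap a 2-argument pass into the
// packed form by unpacking its arguments from the vector.
#define REGISTER_PASS2(Pass)                              \
  static bool reg_##Pass = (pass_registry[#Pass] =        \
      [](const std::vector<int>& a) { return Pass(a[0], a[1]); }, true)

int AddConst(int stmt, int delta) { return stmt + delta; }  // toy "pass"
REGISTER_PASS2(AddConst);

int main() {
  assert(pass_registry["AddConst"]({40, 2}) == 42);
  return 0;
}
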
2 changes: 1 addition & 1 deletion src/lang/buffer.cc
@@ -51,7 +51,7 @@ Expr Buffer::MakeLoad(Array<Expr> index) const {
 Stmt Buffer::MakeStore(Array<Expr> index, Expr value) const {
   const BufferNode* n = operator->();
   CHECK_EQ(value.type(), n->dtype);
-  return ir::Store::make(n->ptr, BufferOffset(n, index), value);
+  return ir::Store::make(n->ptr, value, BufferOffset(n, index));
 }

 Buffer BufferNode::make(std::string name,
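
The one-line fix above transposes the last two arguments: ir::Store::make expects the value before the index, and because both are Expr the swapped call compiled without complaint. A standalone sketch (not TVM code) of why this class of bug is silent:

#include <cassert>
#include <string>

struct Expr { std::string repr; };  // stand-in for an IR expression

// Hypothetical constructor mirroring the shape of ir::Store::make:
// buffer name, then the value to store, then the index expression.
Expr MakeStore(const std::string& buf, Expr value, Expr index) {
  return Expr{buf + "[" + index.repr + "] = " + value.repr};
}

int main() {
  Expr value{"42"}, index{"i"};
  // Correct argument order produces the intended store.
  assert(MakeStore("A", value, index).repr == "A[i] = 42");
  // The pre-fix order type-checks too, because value and index
  // share the type Expr -- the compiler cannot flag the swap.
  assert(MakeStore("A", index, value).repr == "A[42] = i");
  return 0;
}
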
2 changes: 1 addition & 1 deletion src/pass/ir_mutator.cc
@@ -83,7 +83,7 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
         body.same_as(op->body)) {
       return s;
     } else {
-      return AttrStmt::make(op->node, op->type_key, op->value, op->body);
+      return AttrStmt::make(op->node, op->type_key, value, body);
     }
   });

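The bug fixed here is easy to miss: the rebuild branch constructed the new AttrStmt from the original op->value and op->body, silently discarding the mutated value and body computed just above. A minimal sketch (not TVM code) of the reuse-or-rebuild idiom the fix restores:

#include <cassert>
#include <memory>

struct Node;
using Ref = std::shared_ptr<const Node>;
struct Node { Ref child; int tag; };

// Toy mutator: bump the tag of every leaf node.
Ref Mutate(const Ref& n) {
  if (!n->child) {
    return std::make_shared<const Node>(Node{nullptr, n->tag + 1});
  }
  Ref child = Mutate(n->child);
  if (child == n->child) {
    return n;  // nothing below changed: reuse the original node
  }
  // Rebuild from the *mutated* child; rebuilding from n->child here
  // would silently drop every rewrite made in the subtree -- the
  // exact shape of the AttrStmt bug above.
  return std::make_shared<const Node>(Node{child, n->tag});
}

int main() {
  Ref leaf = std::make_shared<const Node>(Node{nullptr, 1});
  Ref root = std::make_shared<const Node>(Node{leaf, 0});
  Ref out = Mutate(root);
  assert(out->child->tag == 2);  // the leaf rewrite survived
  return 0;
}
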
168 changes: 168 additions & 0 deletions src/pass/storage_flatten.cc
@@ -0,0 +1,168 @@
/*!
 * Copyright (c) 2016 by Contributors
 * \file storage_flatten.cc
 */
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <sstream>
#include <unordered_map>

namespace tvm {
namespace ir {

// Key identifying one output tensor of a function: (function, value_index).
struct TensorKey {
  FunctionRef f;
  int value_index;

  inline bool operator==(const TensorKey& other) const {
    return f == other.f && value_index == other.value_index;
  }
  inline std::string GetName() const {
    if (f->num_outputs() == 1) return f->func_name();
    std::ostringstream os;
    os << f->func_name() << ".v" << value_index;
    return os.str();
  }
};

} // namespace ir
} // namespace tvm

namespace std {
template <>
struct hash<::tvm::ir::TensorKey> {
  std::size_t operator()(const ::tvm::ir::TensorKey& k) const {
    size_t lhs = k.f.hash();
    size_t rhs = static_cast<size_t>(k.value_index);
    lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
    return lhs;
  }
};
} // namespace std

namespace tvm {
namespace ir {

using Halide::Internal::Region;

// Mutator that flattens multi-dimensional tensor accesses
// (Realize/Provide/Call) into linear buffer allocations,
// stores, and loads.
class StorageFlattener : public IRMutator {
 public:
  explicit StorageFlattener(Map<Tensor, Buffer> extern_buffer) {
    for (auto kv : extern_buffer) {
      BufferEntry e;
      e.buffer = kv.second;
      e.external = true;
      buf_map_[TensorKey{kv.first->op, kv.first->value_index}] = e;
    }
  }
  Expr Mutate(Expr expr) final {
    expr = IRMutator::Mutate(expr);
    const Call* op = expr.as<Call>();
    if (op != nullptr && op->call_type == Call::Halide) {
      TensorKey key{op->func, op->value_index};
      auto it = buf_map_.find(key);
      CHECK(it != buf_map_.end())
          << "Cannot find allocated buffer for " << key.f;
      const BufferEntry& e = it->second;
      CHECK(!e.released)
          << "Read a buffer that is already out of scope";
      return e.buffer.MakeLoad(e.RelIndex(op->args));
    } else {
      return expr;
    }
  }

  Stmt Mutate(Stmt stmt) final {
    const Realize* realize = stmt.as<Realize>();
    if (realize != nullptr) {
      return HandleRealize(realize);
    } else if (stmt.as<Provide>()) {
      return HandleProvide(stmt);
    } else {
      return IRMutator::Mutate(stmt);
    }
  }

 private:
  // The buffer entry in the flatten map
  struct BufferEntry {
    // the buffer of storage
    Buffer buffer;
    // the bounds of realization, can be null
    Region bounds;
    // Whether the buffer is external
    bool external{false};
    // Whether we are out of allocation bounds and the buffer gets released.
    bool released{false};
    // TODO(tqchen) allow permutation and inference of index dimension.
    // relative index
    inline Array<Expr> RelIndex(Array<Expr> args) const {
      if (bounds.size() != 0) {
        Array<Expr> index;
        CHECK_EQ(bounds.size(), args.size());
        for (size_t i = 0; i < bounds.size(); ++i) {
          index.push_back(args[i] - bounds[i]->min);
        }
        return index;
      } else {
        return args;
      }
    }
  };

  // The buffer assignment map
  std::unordered_map<TensorKey, BufferEntry> buf_map_;

  Stmt HandleRealize(const Realize* op) {
    TensorKey key{op->func, op->value_index};
    if (buf_map_.count(key)) {
      CHECK(buf_map_.at(key).external);
      return this->Mutate(op->body);
    } else {
      // create a buffer entry
      // TODO(tqchen) allow permutation and inference of index dimension.
      BufferEntry e;
      e.bounds = op->bounds;
      Array<Expr> shape;
      for (auto r : e.bounds) {
        shape.push_back(r->extent);
      }
      e.buffer = Buffer(shape, op->type, key.GetName());

      buf_map_[key] = e;
      Stmt body = this->Mutate(op->body);
      buf_map_[key].released = true;

      return Allocate::make(
          e.buffer->ptr, e.buffer->dtype, e.buffer->shape,
          make_const(Bool(e.buffer->dtype.lanes()), true), body);
    }
  }

  Stmt HandleProvide(Stmt stmt) {
    stmt = IRMutator::Mutate(stmt);
    const Provide* op = stmt.as<Provide>();
    TensorKey key{op->func, op->value_index};
    auto it = buf_map_.find(key);
    CHECK(it != buf_map_.end())
        << "Cannot find allocated buffer for " << key.f;
    const BufferEntry& e = it->second;
    CHECK(!e.released)
        << "Write to a buffer that is already out of scope";
    return e.buffer.MakeStore(e.RelIndex(op->args), op->value);
  }
};


Stmt StorageFlatten(Stmt stmt,
                    Map<Tensor, Buffer> extern_buffer) {
  stmt = StorageFlattener(extern_buffer).Mutate(stmt);
  return stmt;
}

} // namespace ir
} // namespace tvm
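
A note on the std::hash specialization for TensorKey: it uses the boost-style hash_combine step, folding the second hash into the first together with the golden-ratio constant 0x9e3779b9 and two shifted copies of the accumulator, so keys that differ only in a small value_index still scatter well. A self-contained version of that mixing step:

#include <cstddef>
#include <cstdio>
#include <functional>
#include <string>

// The same mixing step as the TensorKey specialization above.
inline void HashCombine(std::size_t& seed, std::size_t v) {
  seed ^= v + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}

int main() {
  // Hash a (name, value_index) pair the way TensorKey does,
  // with std::hash standing in for FunctionRef::hash.
  std::size_t h0 = std::hash<std::string>()("minmax");
  HashCombine(h0, 0);
  std::size_t h1 = std::hash<std::string>()("minmax");
  HashCombine(h1, 1);
  std::printf("%zu %zu\n", h0, h1);  // well separated despite rhs = 0 vs 1
  return 0;
}
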
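A note on the index arithmetic: HandleRealize sizes the buffer by each bound's extent while RelIndex rebases every access by the bound's min, so a tensor realized over [min, min + extent) per dimension lands in a dense buffer indexed from zero. A worked sketch, assuming row-major linearization (which is what MakeLoad/MakeStore's offset computation amounts to for a buffer created without explicit strides):

#include <cassert>
#include <cstddef>
#include <vector>

struct Range { int min; int extent; };  // stand-in for Halide's Range

// Rebase by each bound's min (RelIndex), then linearize row-major
// against the extents that HandleRealize turned into the shape.
int FlattenIndex(const std::vector<Range>& bounds,
                 const std::vector<int>& args) {
  int offset = 0;
  for (std::size_t i = 0; i < bounds.size(); ++i) {
    offset = offset * bounds[i].extent + (args[i] - bounds[i].min);
  }
  return offset;
}

int main() {
  // A tensor realized over [2, 10) x [0, 16): the access (5, 3)
  // becomes relative index (3, 3), i.e. flat offset 3 * 16 + 3 = 51.
  std::vector<Range> bounds = {{2, 8}, {0, 16}};
  assert(FlattenIndex(bounds, {5, 3}) == 51);
  return 0;
}
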
24 changes: 24 additions & 0 deletions tests/python/test_pass_storage_flatten.py
@@ -0,0 +1,24 @@
import tvm

def test_flatten2():
    m = tvm.Var('m')
    l = tvm.Var('l')
    A = tvm.placeholder((m, l), name='A')
    A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
    A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')

    s = tvm.Schedule(A2.op)
    xo, xi = s[A2].split(A2.op.axis[0], 8)
    s[A1].compute_at(s[A2], xo)
    bounds = tvm.schedule.InferBound(s)
    assert isinstance(bounds, tvm.collections.Map)
    stmt = tvm.ir_pass.ScheduleOps(s, bounds)

    print(stmt)
    Ab = tvm.Buffer(A.shape, A.dtype, name='A')
    A2b = tvm.Buffer(A2.shape, A2.dtype, name='A2')
    stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b})
    print(stmt)

if __name__ == "__main__":
    test_flatten2()
