Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REFACTOR][TIR] Provide->ProducerStore, Realize->ProducerRealize. #5750

Merged
merged 1 commit into from
Jun 10, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 68 additions & 102 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -334,44 +334,92 @@ class BufferRealize : public Stmt {
};

/*!
* \brief Store value into mult-dimensional array defined by func.
* \brief Store value into mult-dimensional array that will be read by the consumer
* of the producer.
*
* \note Deprecated, move to BufferStore in the future.
* \note This node only appears in high-level DSLs that are built on top of the TIR.
* It should not appear in a valid TIR PrimFunc. A high-level DSL needs to lower
* this node before TIR transformations.
*
* \sa DataProducer
*/
class ProvideNode : public StmtNode {
class ProducerStoreNode : public StmtNode {
public:
/*! \brief The function to be updated. */
FunctionRef func;
/*! \brief The output value index if func's value is a tuple. */
int value_index{0};
/*! \brief The producer to store the results into. */
DataProducer producer;
/*! \brief The value to be stored. */
PrimExpr value;
/*! \brief The index arguments of the function. */
Array<PrimExpr> args;
Array<PrimExpr> indices;

void VisitAttrs(AttrVisitor* v) {
v->Visit("func", &func);
v->Visit("value_index", &value_index);
v->Visit("producer", &producer);
v->Visit("value", &value);
v->Visit("args", &args);
v->Visit("indices", &indices);
}

bool SEqualReduce(const ProvideNode* other, SEqualReducer equal) const {
return equal(func, other->func) && equal(value_index, other->value_index) &&
equal(value, other->value) && equal(args, other->args);
bool SEqualReduce(const ProducerStoreNode* other, SEqualReducer equal) const {
return equal(producer, other->producer) && equal(value, other->value) &&
equal(indices, other->indices);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(func);
hash_reduce(value_index);
hash_reduce(producer);
hash_reduce(value);
hash_reduce(args);
hash_reduce(indices);
}

TVM_DLL static Stmt make(DataProducer producer, PrimExpr value, Array<PrimExpr> indices);

static constexpr const char* _type_key = "ProducerStore";
TVM_DECLARE_FINAL_OBJECT_INFO(ProducerStoreNode, StmtNode);
};

/*!
* \brief Annotate the bounds where the data produced by the producer
* need to be written and read in body.
* We will need to allocate space for the corresponding regions.
*
* \note This node only appears in high-level DSLs that are built on top of the TIR.
* It should not appear in a valid TIR PrimFunc. A high-level DSL needs to lower
* this node before TIR transformations.
*
* \sa DataProducer
*/
class ProducerRealizeNode : public StmtNode {
public:
/*! \brief The producer that produces the data. */
DataProducer producer;
/*! \brief Bounds to be realized. */
Region bounds;
/*! \brief Only realize if condition holds. */
PrimExpr condition;
/*! \brief The body of realization. */
Stmt body;

void VisitAttrs(AttrVisitor* v) {
v->Visit("producer", &producer);
v->Visit("bounds", &bounds);
v->Visit("condition", &condition);
v->Visit("body", &body);
}

TVM_DLL static Stmt make(DataProducer producer, Region bounds, PrimExpr condition, Stmt body);

bool SEqualReduce(const ProducerRealizeNode* other, SEqualReducer equal) const {
return equal(producer, other->producer) && equal(bounds, other->bounds) &&
equal(condition, other->condition) && equal(body, other->body);
}

TVM_DLL static Stmt make(FunctionRef func, int value_index, PrimExpr value, Array<PrimExpr> args);
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(producer);
hash_reduce(bounds);
hash_reduce(condition);
hash_reduce(body);
}

static constexpr const char* _type_key = "Provide";
TVM_DECLARE_FINAL_OBJECT_INFO(ProvideNode, StmtNode);
static constexpr const char* _type_key = "ProducerRealize";
TVM_DECLARE_FINAL_OBJECT_INFO(ProducerRealizeNode, StmtNode);
};

/*!
Expand Down Expand Up @@ -453,58 +501,6 @@ class FreeNode : public StmtNode {
TVM_DECLARE_FINAL_OBJECT_INFO(FreeNode, StmtNode);
};

/*!
* \brief Annotate the bounds where func need to be written and read in body.
* We will need to allocate space for the corresponding regions.
*
* \note Deprecated, move to BufferRealize in the future.
*/
class RealizeNode : public StmtNode {
public:
/*! \brief The function to be realized. */
FunctionRef func;
/*! \brief The output value index if func's value is a tuple. */
int value_index;
/*! \brief The data type of the array. */
DataType dtype;
/*! \brief Bounds to be realized. */
Region bounds;
/*! \brief Only realize if condition holds. */
PrimExpr condition;
/*! \brief The body of realization. */
Stmt body;

void VisitAttrs(AttrVisitor* v) {
v->Visit("func", &func);
v->Visit("value_index", &value_index);
v->Visit("dtype", &dtype);
v->Visit("bounds", &bounds);
v->Visit("condition", &condition);
v->Visit("body", &body);
}

TVM_DLL static Stmt make(FunctionRef func, int value_index, DataType dtype, Region bounds,
PrimExpr condition, Stmt body);

bool SEqualReduce(const RealizeNode* other, SEqualReducer equal) const {
return equal(func, other->func) && equal(value_index, other->value_index) &&
equal(dtype, other->dtype) && equal(bounds, other->bounds) &&
equal(condition, other->condition) && equal(body, other->body);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(func);
hash_reduce(value_index);
hash_reduce(dtype);
hash_reduce(bounds);
hash_reduce(condition);
hash_reduce(body);
}

static constexpr const char* _type_key = "Realize";
TVM_DECLARE_FINAL_OBJECT_INFO(RealizeNode, StmtNode);
};

/*!
* \brief The container of seq statement.
* Represent a sequence of statements.
Expand Down Expand Up @@ -777,23 +773,6 @@ class Prefetch : public Stmt {
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Prefetch, Stmt, PrefetchNode);
};

/*!
* \brief Auxiliary data structure used in IR Pass to indicate a tensor.
*/
struct TensorKey {
FunctionRef f;
int value_index;

inline bool operator==(const TensorKey& other) const {
return f == other.f && value_index == other.value_index;
}
inline std::string GetName() const {
if (f->num_outputs() == 1) return f->func_name();
std::ostringstream os;
os << f->func_name() << ".v" << value_index;
return os.str();
}
};

/*! \brief namespace of possible attribute sin AttrStmt.attr_key */
namespace attr {
Expand Down Expand Up @@ -933,17 +912,4 @@ TVM_DLL std::ostream& operator<<(std::ostream& os, ForType for_type);

} // namespace tir
} // namespace tvm

namespace std {
template <>
struct hash<::tvm::tir::TensorKey> {
std::size_t operator()(const ::tvm::tir::TensorKey& k) const {
size_t lhs = ::tvm::ObjectPtrHash()(k.f);
size_t rhs = static_cast<size_t>(k.value_index);
lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
return lhs;
}
};
} // namespace std

#endif // TVM_TIR_STMT_H_
16 changes: 8 additions & 8 deletions include/tvm/tir/stmt_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
virtual R VisitStmt_(const BufferRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const FreeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const AssertStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const ProvideNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const RealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const ProducerStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const ProducerRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const PrefetchNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const SeqStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const EvaluateNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
Expand All @@ -114,8 +114,8 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
IR_STMT_FUNCTOR_DISPATCH(StoreNode);
IR_STMT_FUNCTOR_DISPATCH(FreeNode);
IR_STMT_FUNCTOR_DISPATCH(AssertStmtNode);
IR_STMT_FUNCTOR_DISPATCH(ProvideNode);
IR_STMT_FUNCTOR_DISPATCH(RealizeNode);
IR_STMT_FUNCTOR_DISPATCH(ProducerStoreNode);
IR_STMT_FUNCTOR_DISPATCH(ProducerRealizeNode);
IR_STMT_FUNCTOR_DISPATCH(PrefetchNode);
IR_STMT_FUNCTOR_DISPATCH(SeqStmtNode);
IR_STMT_FUNCTOR_DISPATCH(EvaluateNode);
Expand Down Expand Up @@ -156,8 +156,8 @@ class TVM_DLL StmtVisitor : protected StmtFunctor<void(const Stmt&)> {
void VisitStmt_(const BufferRealizeNode* op) override;
void VisitStmt_(const FreeNode* op) override;
void VisitStmt_(const AssertStmtNode* op) override;
void VisitStmt_(const ProvideNode* op) override;
void VisitStmt_(const RealizeNode* op) override;
void VisitStmt_(const ProducerStoreNode* op) override;
void VisitStmt_(const ProducerRealizeNode* op) override;
void VisitStmt_(const PrefetchNode* op) override;
void VisitStmt_(const SeqStmtNode* op) override;
void VisitStmt_(const EvaluateNode* op) override;
Expand Down Expand Up @@ -248,8 +248,8 @@ class TVM_DLL StmtMutator : protected StmtFunctor<Stmt(const Stmt&)> {
Stmt VisitStmt_(const BufferRealizeNode* op) override;
Stmt VisitStmt_(const FreeNode* op) override;
Stmt VisitStmt_(const AssertStmtNode* op) override;
Stmt VisitStmt_(const ProvideNode* op) override;
Stmt VisitStmt_(const RealizeNode* op) override;
Stmt VisitStmt_(const ProducerStoreNode* op) override;
Stmt VisitStmt_(const ProducerRealizeNode* op) override;
Stmt VisitStmt_(const PrefetchNode* op) override;
Stmt VisitStmt_(const SeqStmtNode* op) override;
Stmt VisitStmt_(const EvaluateNode* op) override;
Expand Down
8 changes: 4 additions & 4 deletions python/tvm/te/hybrid/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def wrap_up_realize(self, node, body):
_domain = [Range.make_by_min_extent(0, i) for i in _buf.shape]
_dtype = _buf.dtype
_true = tvm.runtime.convert(True)
body = tvm.tir.Realize(_buf.op, 0, _dtype, _domain, _true, body)
body = tvm.tir.ProducerRealize(_buf, _domain, _true, body)
body = tvm.tir.AttrStmt(_buf.op, 'realize_scope', tvm.runtime.convert(_scope), body)

for elem in to_pop:
Expand Down Expand Up @@ -307,7 +307,7 @@ def visit_AugAssign(self, node):
read = tvm.tir.ProducerLoad(buf, args)
value = HybridParser._binop_maker[type(node.op)](read, rhs)

return tvm.tir.Provide(buf.op, 0, value, args)
return tvm.tir.ProducerStore(buf, value, args)


def visit_Assign(self, node):
Expand Down Expand Up @@ -358,13 +358,13 @@ def visit_Assign(self, node):
lhs = self.visit(lhs_)
if lhs is not None:
buf, args = lhs
return tvm.tir.Provide(buf.op, 0, rhs, args)
return tvm.tir.ProducerStore(buf, rhs, args)
return util.make_nop()

lhs, args = self.visit(lhs)
_internal_assert(isinstance(lhs, Tensor), \
"An array access's LHS is expected to be a expr.Call!")
res = tvm.tir.Provide(lhs.op, lhs.value_index, rhs, args)
res = tvm.tir.ProducerStore(lhs, rhs, args)
return res


Expand Down
8 changes: 4 additions & 4 deletions python/tvm/te/hybrid/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,15 +75,15 @@ def replace_io(body, rmap):
from tvm.tir import stmt_functor

def replace(op):
if isinstance(op, _stmt.Provide) and op.func in rmap.keys():
buf = rmap[op.func]
return _stmt.Provide(buf.op, op.value_index, op.value, op.args)
if isinstance(op, _stmt.ProducerStore) and op.producer.op in rmap.keys():
buf = rmap[op.producer.op]
return _stmt.ProducerStore(buf, op.value, op.indices)
if isinstance(op, _expr.ProducerLoad) and op.producer.op in rmap.keys():
buf = rmap[op.producer.op]
return _expr.ProducerLoad(buf, op.indices)
return None

return stmt_functor.ir_transform(body, None, replace, ['Provide', 'Call'])
return stmt_functor.ir_transform(body, None, replace, ['ProducerStore', 'ProducerLoad'])


def _is_tvm_arg_types(args):
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
from .expr import IterVar, Any

from .stmt import Stmt, LetStmt, AssertStmt, For
from .stmt import BufferStore, BufferRealize, Store, Provide, Allocate, AttrStmt
from .stmt import Free, Realize, SeqStmt
from .stmt import BufferStore, BufferRealize, Store, ProducerStore, Allocate, AttrStmt
from .stmt import Free, ProducerRealize, SeqStmt
from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list

from .function import PrimFunc
Expand Down
40 changes: 14 additions & 26 deletions python/tvm/tir/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,26 +184,23 @@ def __init__(self, buffer, bounds, condition, body):


@tvm._ffi.register_object
class Provide(Stmt):
"""Provide node.
class ProducerStore(Stmt):
"""ProducerStore node.

Parameters
----------
func : Operation
The operation to create the function.

value_index : int
The output value index
producer : DataProducer
The data producer.

value : PrimExpr
The value to be stored.

args : list of Expr
The index arguments of the Provide.
indices : list of Expr
The index arguments of the store.
"""
def __init__(self, func, value_index, value, args):
def __init__(self, producer, value, indices):
self.__init_handle_by_constructor__(
_ffi_api.Provide, func, value_index, value, args)
_ffi_api.ProducerStore, producer, value, indices)


@tvm._ffi.register_object
Expand Down Expand Up @@ -276,19 +273,13 @@ def __init__(self, buffer_var):


@tvm._ffi.register_object
class Realize(Stmt):
"""Realize node.
class ProducerRealize(Stmt):
"""ProducerRealize node.

Parameters
----------
func : Operation
The operation to create the function.

value_index : int
The output value index

dtype : str
The data type of the operation.
producer : DataProducer
The data producer.

bounds : list of range
The bound of realize
Expand All @@ -300,15 +291,12 @@ class Realize(Stmt):
The realize body
"""
def __init__(self,
func,
value_index,
dtype,
producer,
bounds,
condition,
body):
self.__init_handle_by_constructor__(
_ffi_api.Realize, func, value_index, dtype,
bounds, condition, body)
_ffi_api.ProducerRealize, producer, bounds, condition, body)


@tvm._ffi.register_object
Expand Down
Loading