Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REFACTOR][TE][TIR] Call::Halide => ProducerLoad, DSL/TIR decouple. #5743

Merged
merged 1 commit into from
Jun 7, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions include/tvm/te/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,11 @@ class OperationNode;
* \brief Tensor structure representing a possible input,
* or intermediate computation result.
*/
class Tensor : public ObjectRef {
class Tensor : public DataProducer {
public:
/*! \brief default constructor, used internally */
Tensor() {}
explicit Tensor(ObjectPtr<Object> n) : ObjectRef(n) {}
explicit Tensor(ObjectPtr<Object> n) : DataProducer(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
Expand Down Expand Up @@ -157,7 +157,7 @@ class Operation : public tir::FunctionRef {
};

/*! \brief Node to represent a tensor */
class TensorNode : public Object {
class TensorNode : public DataProducerNode {
public:
/*! \brief The shape of the tensor */
Array<PrimExpr> shape;
Expand All @@ -176,10 +176,17 @@ class TensorNode : public Object {
v->Visit("op", &op);
v->Visit("value_index", &value_index);
}

Array<PrimExpr> GetShape() const final { return shape; }

DataType GetDataType() const final { return dtype; }

TVM_DLL String GetNameHint() const final;

TVM_DLL static Tensor make(Array<PrimExpr> shape, DataType dtype, Operation op, int value_index);

static constexpr const char* _type_key = "Tensor";
TVM_DECLARE_FINAL_OBJECT_INFO(TensorNode, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(TensorNode, DataProducerNode);
};

// Implementations of inline functions
Expand Down
55 changes: 55 additions & 0 deletions include/tvm/tir/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,61 @@ inline const BufferNode* Buffer::operator->() const {
*/
TVM_DLL Buffer decl_buffer(Array<PrimExpr> shape, DataType dtype = DataType::Float(32),
std::string name = "buffer");

/*!
* \brief Base node for data producers.
*
* A DataProducer stores necessary information(e.g. a tensor expression) to produce
* a multi-dimensional array. The stored information is opaque to the TIR.
* DataProducer can appear in high-level DSLs that are built on top of the TIR.
*
* A valid TIR PrimFunc should not contain any DataProducer, high level DSLs should lower
* all DataProducers to Buffers before TIR transformations.
*
* \sa tvm::te::Tensor
*/
class DataProducerNode : public Object {
public:
/*! \brief destructor. */
virtual ~DataProducerNode() {}
tqchen marked this conversation as resolved.
Show resolved Hide resolved
/*!
* \brief Get the shape of the result.
* \return The shape.
*/
virtual Array<PrimExpr> GetShape() const = 0;
/*!
* \brief Get the data type of the result.
* \return The data type.
*/
virtual DataType GetDataType() const = 0;
/*!
* \brief Get the name hint of the data producer.
* \return The data type.
*/
virtual String GetNameHint() const = 0;

bool SEqualReduce(const DataProducerNode* other, SEqualReducer equal) const {
// because buffer producer is opaque, we just do pointer equality.
return this == other;
}

void SHashReduce(SHashReducer hash_reduce) const {}

static constexpr const char* _type_key = "DataProducer";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_BASE_OBJECT_INFO(DataProducerNode, Object);
};

/*!
* \brief Managed reference to DataProducerNode.
* \sa DataProducerNode
*/
class DataProducer : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(DataProducer, ObjectRef, DataProducerNode);
};

} // namespace tir
} // namespace tvm
#endif // TVM_TIR_BUFFER_H_
81 changes: 55 additions & 26 deletions include/tvm/tir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -449,12 +449,64 @@ class BufferLoadNode : public PrimExprNode {
TVM_DECLARE_FINAL_OBJECT_INFO(BufferLoadNode, PrimExprNode);
};

/*!
* \brief Managed reference to BufferLoadNode.
* \sa BufferLoadNode
*/
class BufferLoad : public PrimExpr {
public:
TVM_DLL explicit BufferLoad(Buffer buffer, Array<PrimExpr> indices);
TVM_DEFINE_OBJECT_REF_METHODS(BufferLoad, PrimExpr, BufferLoadNode);
};

/*!
* \brief Load value from the result produced by the producer.
*
* \note This node only appears in high-level DSLs that are built on top of the TIR.
* It should not appear in a valid TIR PrimFunc. A high-level DSL needs to lower
* this node before TIR transformations.
*
* \sa ProducerLoad, DataProducerNode
*/
class ProducerLoadNode : public PrimExprNode {
public:
/*! \brief The buffer producer. */
DataProducer producer;
/*! \brief The location arguments. */
Array<PrimExpr> indices;

void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &(this->dtype));
v->Visit("producer", &producer);
v->Visit("indices", &indices);
}

bool SEqualReduce(const ProducerLoadNode* other, SEqualReducer equal) const {
return equal(dtype, other->dtype) && equal(producer, other->producer) &&
equal(indices, other->indices);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(dtype);
hash_reduce(producer);
hash_reduce(indices);
}

tqchen marked this conversation as resolved.
Show resolved Hide resolved
static constexpr const char* _type_key = "ProducerLoad";
TVM_DECLARE_FINAL_OBJECT_INFO(ProducerLoadNode, PrimExprNode);
};

/*!
* \brief Managed reference to ProducerLoadNode.
* \sa ProducerLoadNode
*/
class ProducerLoad : public PrimExpr {
public:
TVM_DLL explicit ProducerLoad(DataProducer producer, Array<PrimExpr> indices);

TVM_DEFINE_OBJECT_REF_METHODS(ProducerLoad, PrimExpr, ProducerLoadNode);
};

/*!
* \brief Load the value from buffer_var.
*
Expand Down Expand Up @@ -661,11 +713,6 @@ class CallNode : public PrimExprNode {
ExternCPlusPlus = 1,
/*! \brief Extern "C" without side-effect. */
PureExtern = 2,
/*!
* \brief Halide-style call, evaluates func(args).
* \note Deprecated, move to BufferLoad in the future.
*/
Halide = 3,
/*! \brief Intrinsic functions. */
Intrinsic = 4,
/*! \brief Intrinsic functions that are pure. */
Expand All @@ -677,49 +724,31 @@ class CallNode : public PrimExprNode {
Array<PrimExpr> args;
/*! \brief Type of calls. */
CallType call_type;
/*!
* \brief The function to be called.
* \note Deprecated, move to BufferLoad in the future.
*/
FunctionRef func;
/*!
* \brief The output value index if func's value is a tuple.
* \note Deprecated, move to BufferLoad in the future.
*/
int value_index{0};

void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &dtype);
v->Visit("name", &name);
v->Visit("args", &args);
v->Visit("call_type", &call_type);
v->Visit("func", &func);
v->Visit("value_index", &value_index);
}

bool SEqualReduce(const CallNode* other, SEqualReducer equal) const {
return equal(dtype, other->dtype) && equal(name, other->name) && equal(args, other->args) &&
equal(call_type, other->call_type) && equal(func, other->func) &&
equal(value_index, other->value_index);
equal(call_type, other->call_type);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(dtype);
hash_reduce(name);
hash_reduce(args);
hash_reduce(call_type);
hash_reduce(func);
hash_reduce(value_index);
}

TVM_DLL static PrimExpr make(DataType dtype, std::string name, Array<PrimExpr> args,
CallType call_type, FunctionRef func = FunctionRef(),
int value_index = 0);
CallType call_type);

/*! \return Whether call node is pure. */
bool is_pure() const {
return (call_type == PureExtern || call_type == PureIntrinsic || call_type == Halide);
}
bool is_pure() const { return (call_type == PureExtern || call_type == PureIntrinsic); }

/*!
* \return Whether call node corresponds to a defined intrinsic.
Expand Down
4 changes: 4 additions & 0 deletions include/tvm/tir/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ class ExprFunctor<R(const PrimExpr& n, Args...)> {
return VisitExpr_(static_cast<const VarNode*>(op), std::forward<Args>(args)...);
}
virtual R VisitExpr_(const BufferLoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const ProducerLoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const LoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const LetNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const CallNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
Expand Down Expand Up @@ -163,6 +164,7 @@ class ExprFunctor<R(const PrimExpr& n, Args...)> {
IR_EXPR_FUNCTOR_DISPATCH(SizeVarNode);
IR_EXPR_FUNCTOR_DISPATCH(LoadNode);
IR_EXPR_FUNCTOR_DISPATCH(BufferLoadNode);
IR_EXPR_FUNCTOR_DISPATCH(ProducerLoadNode);
IR_EXPR_FUNCTOR_DISPATCH(LetNode);
IR_EXPR_FUNCTOR_DISPATCH(CallNode);
IR_EXPR_FUNCTOR_DISPATCH(AddNode);
Expand Down Expand Up @@ -213,6 +215,7 @@ class TVM_DLL ExprVisitor : public ExprFunctor<void(const PrimExpr&)> {
void VisitExpr_(const SizeVarNode* op) override;
void VisitExpr_(const LoadNode* op) override;
void VisitExpr_(const BufferLoadNode* op) override;
void VisitExpr_(const ProducerLoadNode* op) override;
void VisitExpr_(const LetNode* op) override;
void VisitExpr_(const CallNode* op) override;
void VisitExpr_(const AddNode* op) override;
Expand Down Expand Up @@ -258,6 +261,7 @@ class TVM_DLL ExprMutator : protected ExprFunctor<PrimExpr(const PrimExpr&)> {
PrimExpr VisitExpr_(const SizeVarNode* op) override;
PrimExpr VisitExpr_(const LoadNode* op) override;
PrimExpr VisitExpr_(const BufferLoadNode* op) override;
PrimExpr VisitExpr_(const ProducerLoadNode* op) override;
PrimExpr VisitExpr_(const LetNode* op) override;
PrimExpr VisitExpr_(const CallNode* op) override;
PrimExpr VisitExpr_(const AddNode* op) override;
Expand Down
8 changes: 4 additions & 4 deletions python/tvm/autotvm/task/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,11 +495,11 @@ def _count_flop(exp):
if isinstance(exp, expr.Select):
return _count_flop(exp.condition) + max(_count_flop(exp.true_value),
_count_flop(exp.false_value))
if isinstance(exp, expr.Call):
if exp.call_type == expr.Call.Halide:
# Ignore flops from indexing expressions.
return 0
if isinstance(exp, expr.ProducerLoad):
# Ignore flops from indexing expressions.
return 0

if isinstance(exp, expr.Call):
return sum([_count_flop(x) for x in exp.args])

raise FlopCalculationError("Found unsupported operator in the compute expr")
Expand Down
6 changes: 3 additions & 3 deletions python/tvm/target/datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def register_op(lower_func, op_name, target, type_name, src_type_name=None):

op_name : str
The name of the operation which the function computes, given by its
Halide::Internal class name (e.g. Add, LE, Cast).
class name (e.g. Add, LE, Cast).

target : str
The name of codegen target.
Expand Down Expand Up @@ -136,8 +136,8 @@ def lower(op):
dtype += "x" + str(t.lanes)
if isinstance(op, (_Cast, _FloatImm)):
return _Call(dtype, extern_func_name, convert([op.value]),
_Call.Extern, None, 0)
_Call.Extern)
return _Call(dtype, extern_func_name, convert([op.a, op.b]),
_Call.Extern, None, 0)
_Call.Extern)

return lower
8 changes: 3 additions & 5 deletions python/tvm/te/hybrid/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,8 +272,7 @@ def visit_Name(self, node):
return entry if isinstance(node.ctx, ast.Load) else None
if ty is Symbol.BufferVar:
if isinstance(node.ctx, ast.Load):
return tvm.tir.Call(entry.dtype, entry.name, [tvm.runtime.const(0, 'int32')], \
_expr.Call.Halide, entry.op, entry.value_index)
return tvm.tir.ProducerLoad(entry, [tvm.runtime.const(0, 'int32')])
return entry, [tvm.runtime.const(0, 'int32')]
# Do I need any assertion here?
return entry
Expand Down Expand Up @@ -305,7 +304,7 @@ def visit_AugAssign(self, node):
args = [tvm.runtime.const(0, 'int32')]
_internal_assert(isinstance(buf, Tensor), "LHS is supposed to be Tensor!")

read = tvm.tir.Call(buf.dtype, buf.name, args, _expr.Call.Halide, buf.op, buf.value_index)
read = tvm.tir.ProducerLoad(buf, args)
value = HybridParser._binop_maker[type(node.op)](read, rhs)

return tvm.tir.Provide(buf.op, 0, value, args)
Expand Down Expand Up @@ -392,8 +391,7 @@ def visit_Subscript(self, node):
arr = arr[i.value]
return arr
if isinstance(node.ctx, ast.Load):
return tvm.tir.Call(arr.dtype, arr.name, args,
_expr.Call.Halide, arr.op, arr.value_index)
return tvm.tir.ProducerLoad(arr, args)
return arr, args

def visit_With(self, node):
Expand Down
7 changes: 3 additions & 4 deletions python/tvm/te/hybrid/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,9 @@ def replace(op):
if isinstance(op, _stmt.Provide) and op.func in rmap.keys():
buf = rmap[op.func]
return _stmt.Provide(buf.op, op.value_index, op.value, op.args)
if isinstance(op, _expr.Call) and op.func in rmap.keys():
buf = rmap[op.func]
return _expr.Call(buf.dtype, buf.name, op.args, \
_expr.Call.Halide, buf.op, buf.value_index)
if isinstance(op, _expr.ProducerLoad) and op.producer.op in rmap.keys():
buf = rmap[op.producer.op]
return _expr.ProducerLoad(buf, op.indices)
return None

return stmt_functor.ir_transform(body, None, replace, ['Provide', 'Call'])
Expand Down
9 changes: 4 additions & 5 deletions python/tvm/te/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import tvm._ffi

from tvm.runtime import Object, ObjectGeneric, convert_to_object
from tvm.tir import expr as _expr
from tvm.tir import expr as _expr, DataProducer

from . import _ffi_api

Expand Down Expand Up @@ -52,7 +52,7 @@ class TensorIntrinCall(Object):


@tvm._ffi.register_object
class Tensor(Object, _expr.ExprOp):
class Tensor(DataProducer, _expr.ExprOp):
"""Tensor object, to construct, see function.Tensor"""

def __call__(self, *indices):
Expand All @@ -69,9 +69,8 @@ def __call__(self, *indices):
else:
raise ValueError("The indices must be expression")

return _expr.Call(self.dtype, self.op.name,
args, _expr.Call.Halide,
self.op, self.value_index)
return _expr.ProducerLoad(self, args)


def __getitem__(self, indices):
return TensorSlice(self, indices)
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@
from tvm.ir import PrimExpr
from tvm.runtime import const

from .buffer import Buffer, decl_buffer
from .buffer import Buffer, decl_buffer, DataProducer
from .data_layout import Layout, BijectiveLayout, bijective_layout, layout
from .expr import Var, SizeVar, Reduce, FloatImm, IntImm, StringImm, Cast
from .expr import Add, Sub, Mul, Div, Mod, FloorDiv, FloorMod
from .expr import Min, Max, EQ, NE, LT, LE, GT, GE, And, Or, Not
from .expr import Select, BufferLoad, Load, Ramp, Broadcast, Shuffle, Call, Let
from .expr import Select, BufferLoad, ProducerLoad, Load, Ramp, Broadcast, Shuffle, Call, Let
from .expr import IterVar, Any

from .stmt import Stmt, LetStmt, AssertStmt, For
Expand Down
5 changes: 5 additions & 0 deletions python/tvm/tir/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,3 +245,8 @@ def decl_buffer(shape,
return _ffi_api.Buffer(
data, dtype, shape, strides, elem_offset, name, scope,
data_alignment, offset_factor, buffer_type)


@tvm._ffi.register_object
class DataProducer(Object):
pass
Loading