diff --git a/include/tvm/te/tensor.h b/include/tvm/te/tensor.h index e2d847f26da6..045d186f58f5 100644 --- a/include/tvm/te/tensor.h +++ b/include/tvm/te/tensor.h @@ -49,11 +49,11 @@ class OperationNode; * \brief Tensor structure representing a possible input, * or intermediate computation result. */ -class Tensor : public ObjectRef { +class Tensor : public DataProducer { public: /*! \brief default constructor, used internally */ Tensor() {} - explicit Tensor(ObjectPtr n) : ObjectRef(n) {} + explicit Tensor(ObjectPtr n) : DataProducer(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container @@ -157,7 +157,7 @@ class Operation : public tir::FunctionRef { }; /*! \brief Node to represent a tensor */ -class TensorNode : public Object { +class TensorNode : public DataProducerNode { public: /*! \brief The shape of the tensor */ Array shape; @@ -176,10 +176,17 @@ class TensorNode : public Object { v->Visit("op", &op); v->Visit("value_index", &value_index); } + + Array GetShape() const final { return shape; } + + DataType GetDataType() const final { return dtype; } + + TVM_DLL String GetNameHint() const final; + TVM_DLL static Tensor make(Array shape, DataType dtype, Operation op, int value_index); static constexpr const char* _type_key = "Tensor"; - TVM_DECLARE_FINAL_OBJECT_INFO(TensorNode, Object); + TVM_DECLARE_FINAL_OBJECT_INFO(TensorNode, DataProducerNode); }; // Implementations of inline functions diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h index 5d4e86026b39..6904f2a4ed40 100644 --- a/include/tvm/tir/buffer.h +++ b/include/tvm/tir/buffer.h @@ -203,6 +203,61 @@ inline const BufferNode* Buffer::operator->() const { */ TVM_DLL Buffer decl_buffer(Array shape, DataType dtype = DataType::Float(32), std::string name = "buffer"); + +/*! + * \brief Base node for data producers. + * + * A DataProducer stores necessary information(e.g. a tensor expression) to produce + * a multi-dimensional array. The stored information is opaque to the TIR. + * DataProducer can appear in high-level DSLs that are built on top of the TIR. + * + * A valid TIR PrimFunc should not contain any DataProducer, high level DSLs should lower + * all DataProducers to Buffers before TIR transformations. + * + * \sa tvm::te::Tensor + */ +class DataProducerNode : public Object { + public: + /*! \brief destructor. */ + virtual ~DataProducerNode() {} + /*! + * \brief Get the shape of the result. + * \return The shape. + */ + virtual Array GetShape() const = 0; + /*! + * \brief Get the data type of the result. + * \return The data type. + */ + virtual DataType GetDataType() const = 0; + /*! + * \brief Get the name hint of the data producer. + * \return The data type. + */ + virtual String GetNameHint() const = 0; + + bool SEqualReduce(const DataProducerNode* other, SEqualReducer equal) const { + // because buffer producer is opaque, we just do pointer equality. + return this == other; + } + + void SHashReduce(SHashReducer hash_reduce) const {} + + static constexpr const char* _type_key = "DataProducer"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_BASE_OBJECT_INFO(DataProducerNode, Object); +}; + +/*! + * \brief Managed reference to DataProducerNode. + * \sa DataProducerNode + */ +class DataProducer : public ObjectRef { + public: + TVM_DEFINE_OBJECT_REF_METHODS(DataProducer, ObjectRef, DataProducerNode); +}; + } // namespace tir } // namespace tvm #endif // TVM_TIR_BUFFER_H_ diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index 5909a24ad3e9..d34165ef955d 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -449,12 +449,64 @@ class BufferLoadNode : public PrimExprNode { TVM_DECLARE_FINAL_OBJECT_INFO(BufferLoadNode, PrimExprNode); }; +/*! + * \brief Managed reference to BufferLoadNode. + * \sa BufferLoadNode + */ class BufferLoad : public PrimExpr { public: TVM_DLL explicit BufferLoad(Buffer buffer, Array indices); TVM_DEFINE_OBJECT_REF_METHODS(BufferLoad, PrimExpr, BufferLoadNode); }; +/*! + * \brief Load value from the result produced by the producer. + * + * \note This node only appears in high-level DSLs that are built on top of the TIR. + * It should not appear in a valid TIR PrimFunc. A high-level DSL needs to lower + * this node before TIR transformations. + * + * \sa ProducerLoad, DataProducerNode + */ +class ProducerLoadNode : public PrimExprNode { + public: + /*! \brief The buffer producer. */ + DataProducer producer; + /*! \brief The location arguments. */ + Array indices; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("dtype", &(this->dtype)); + v->Visit("producer", &producer); + v->Visit("indices", &indices); + } + + bool SEqualReduce(const ProducerLoadNode* other, SEqualReducer equal) const { + return equal(dtype, other->dtype) && equal(producer, other->producer) && + equal(indices, other->indices); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(dtype); + hash_reduce(producer); + hash_reduce(indices); + } + + static constexpr const char* _type_key = "ProducerLoad"; + TVM_DECLARE_FINAL_OBJECT_INFO(ProducerLoadNode, PrimExprNode); +}; + +/*! + * \brief Managed reference to ProducerLoadNode. + * \sa ProducerLoadNode + */ +class ProducerLoad : public PrimExpr { + public: + TVM_DLL explicit ProducerLoad(DataProducer producer, Array indices); + + TVM_DEFINE_OBJECT_REF_METHODS(ProducerLoad, PrimExpr, ProducerLoadNode); +}; + /*! * \brief Load the value from buffer_var. * @@ -661,11 +713,6 @@ class CallNode : public PrimExprNode { ExternCPlusPlus = 1, /*! \brief Extern "C" without side-effect. */ PureExtern = 2, - /*! - * \brief Halide-style call, evaluates func(args). - * \note Deprecated, move to BufferLoad in the future. - */ - Halide = 3, /*! \brief Intrinsic functions. */ Intrinsic = 4, /*! \brief Intrinsic functions that are pure. */ @@ -677,30 +724,17 @@ class CallNode : public PrimExprNode { Array args; /*! \brief Type of calls. */ CallType call_type; - /*! - * \brief The function to be called. - * \note Deprecated, move to BufferLoad in the future. - */ - FunctionRef func; - /*! - * \brief The output value index if func's value is a tuple. - * \note Deprecated, move to BufferLoad in the future. - */ - int value_index{0}; void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &dtype); v->Visit("name", &name); v->Visit("args", &args); v->Visit("call_type", &call_type); - v->Visit("func", &func); - v->Visit("value_index", &value_index); } bool SEqualReduce(const CallNode* other, SEqualReducer equal) const { return equal(dtype, other->dtype) && equal(name, other->name) && equal(args, other->args) && - equal(call_type, other->call_type) && equal(func, other->func) && - equal(value_index, other->value_index); + equal(call_type, other->call_type); } void SHashReduce(SHashReducer hash_reduce) const { @@ -708,18 +742,13 @@ class CallNode : public PrimExprNode { hash_reduce(name); hash_reduce(args); hash_reduce(call_type); - hash_reduce(func); - hash_reduce(value_index); } TVM_DLL static PrimExpr make(DataType dtype, std::string name, Array args, - CallType call_type, FunctionRef func = FunctionRef(), - int value_index = 0); + CallType call_type); /*! \return Whether call node is pure. */ - bool is_pure() const { - return (call_type == PureExtern || call_type == PureIntrinsic || call_type == Halide); - } + bool is_pure() const { return (call_type == PureExtern || call_type == PureIntrinsic); } /*! * \return Whether call node corresponds to a defined intrinsic. diff --git a/include/tvm/tir/expr_functor.h b/include/tvm/tir/expr_functor.h index 15ec3d2ae0bf..a6c90b36a49c 100644 --- a/include/tvm/tir/expr_functor.h +++ b/include/tvm/tir/expr_functor.h @@ -119,6 +119,7 @@ class ExprFunctor { return VisitExpr_(static_cast(op), std::forward(args)...); } virtual R VisitExpr_(const BufferLoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const ProducerLoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const LoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const LetNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const CallNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; @@ -163,6 +164,7 @@ class ExprFunctor { IR_EXPR_FUNCTOR_DISPATCH(SizeVarNode); IR_EXPR_FUNCTOR_DISPATCH(LoadNode); IR_EXPR_FUNCTOR_DISPATCH(BufferLoadNode); + IR_EXPR_FUNCTOR_DISPATCH(ProducerLoadNode); IR_EXPR_FUNCTOR_DISPATCH(LetNode); IR_EXPR_FUNCTOR_DISPATCH(CallNode); IR_EXPR_FUNCTOR_DISPATCH(AddNode); @@ -213,6 +215,7 @@ class TVM_DLL ExprVisitor : public ExprFunctor { void VisitExpr_(const SizeVarNode* op) override; void VisitExpr_(const LoadNode* op) override; void VisitExpr_(const BufferLoadNode* op) override; + void VisitExpr_(const ProducerLoadNode* op) override; void VisitExpr_(const LetNode* op) override; void VisitExpr_(const CallNode* op) override; void VisitExpr_(const AddNode* op) override; @@ -258,6 +261,7 @@ class TVM_DLL ExprMutator : protected ExprFunctor { PrimExpr VisitExpr_(const SizeVarNode* op) override; PrimExpr VisitExpr_(const LoadNode* op) override; PrimExpr VisitExpr_(const BufferLoadNode* op) override; + PrimExpr VisitExpr_(const ProducerLoadNode* op) override; PrimExpr VisitExpr_(const LetNode* op) override; PrimExpr VisitExpr_(const CallNode* op) override; PrimExpr VisitExpr_(const AddNode* op) override; diff --git a/python/tvm/autotvm/task/task.py b/python/tvm/autotvm/task/task.py index 00b667670c65..b7cd6f2b04ed 100644 --- a/python/tvm/autotvm/task/task.py +++ b/python/tvm/autotvm/task/task.py @@ -495,11 +495,11 @@ def _count_flop(exp): if isinstance(exp, expr.Select): return _count_flop(exp.condition) + max(_count_flop(exp.true_value), _count_flop(exp.false_value)) - if isinstance(exp, expr.Call): - if exp.call_type == expr.Call.Halide: - # Ignore flops from indexing expressions. - return 0 + if isinstance(exp, expr.ProducerLoad): + # Ignore flops from indexing expressions. + return 0 + if isinstance(exp, expr.Call): return sum([_count_flop(x) for x in exp.args]) raise FlopCalculationError("Found unsupported operator in the compute expr") diff --git a/python/tvm/target/datatype.py b/python/tvm/target/datatype.py index 328568a360bc..e42ac6b37806 100644 --- a/python/tvm/target/datatype.py +++ b/python/tvm/target/datatype.py @@ -88,7 +88,7 @@ def register_op(lower_func, op_name, target, type_name, src_type_name=None): op_name : str The name of the operation which the function computes, given by its - Halide::Internal class name (e.g. Add, LE, Cast). + class name (e.g. Add, LE, Cast). target : str The name of codegen target. @@ -136,8 +136,8 @@ def lower(op): dtype += "x" + str(t.lanes) if isinstance(op, (_Cast, _FloatImm)): return _Call(dtype, extern_func_name, convert([op.value]), - _Call.Extern, None, 0) + _Call.Extern) return _Call(dtype, extern_func_name, convert([op.a, op.b]), - _Call.Extern, None, 0) + _Call.Extern) return lower diff --git a/python/tvm/te/hybrid/parser.py b/python/tvm/te/hybrid/parser.py index 765efa0b976c..75300ab405e9 100644 --- a/python/tvm/te/hybrid/parser.py +++ b/python/tvm/te/hybrid/parser.py @@ -272,8 +272,7 @@ def visit_Name(self, node): return entry if isinstance(node.ctx, ast.Load) else None if ty is Symbol.BufferVar: if isinstance(node.ctx, ast.Load): - return tvm.tir.Call(entry.dtype, entry.name, [tvm.runtime.const(0, 'int32')], \ - _expr.Call.Halide, entry.op, entry.value_index) + return tvm.tir.ProducerLoad(entry, [tvm.runtime.const(0, 'int32')]) return entry, [tvm.runtime.const(0, 'int32')] # Do I need any assertion here? return entry @@ -305,7 +304,7 @@ def visit_AugAssign(self, node): args = [tvm.runtime.const(0, 'int32')] _internal_assert(isinstance(buf, Tensor), "LHS is supposed to be Tensor!") - read = tvm.tir.Call(buf.dtype, buf.name, args, _expr.Call.Halide, buf.op, buf.value_index) + read = tvm.tir.ProducerLoad(buf, args) value = HybridParser._binop_maker[type(node.op)](read, rhs) return tvm.tir.Provide(buf.op, 0, value, args) @@ -392,8 +391,7 @@ def visit_Subscript(self, node): arr = arr[i.value] return arr if isinstance(node.ctx, ast.Load): - return tvm.tir.Call(arr.dtype, arr.name, args, - _expr.Call.Halide, arr.op, arr.value_index) + return tvm.tir.ProducerLoad(arr, args) return arr, args def visit_With(self, node): diff --git a/python/tvm/te/hybrid/util.py b/python/tvm/te/hybrid/util.py index 01eeeec16142..35c59f11be70 100644 --- a/python/tvm/te/hybrid/util.py +++ b/python/tvm/te/hybrid/util.py @@ -78,10 +78,9 @@ def replace(op): if isinstance(op, _stmt.Provide) and op.func in rmap.keys(): buf = rmap[op.func] return _stmt.Provide(buf.op, op.value_index, op.value, op.args) - if isinstance(op, _expr.Call) and op.func in rmap.keys(): - buf = rmap[op.func] - return _expr.Call(buf.dtype, buf.name, op.args, \ - _expr.Call.Halide, buf.op, buf.value_index) + if isinstance(op, _expr.ProducerLoad) and op.producer.op in rmap.keys(): + buf = rmap[op.producer.op] + return _expr.ProducerLoad(buf, op.indices) return None return stmt_functor.ir_transform(body, None, replace, ['Provide', 'Call']) diff --git a/python/tvm/te/tensor.py b/python/tvm/te/tensor.py index 739268aba4a5..7d73bf42ab7d 100644 --- a/python/tvm/te/tensor.py +++ b/python/tvm/te/tensor.py @@ -19,7 +19,7 @@ import tvm._ffi from tvm.runtime import Object, ObjectGeneric, convert_to_object -from tvm.tir import expr as _expr +from tvm.tir import expr as _expr, DataProducer from . import _ffi_api @@ -52,7 +52,7 @@ class TensorIntrinCall(Object): @tvm._ffi.register_object -class Tensor(Object, _expr.ExprOp): +class Tensor(DataProducer, _expr.ExprOp): """Tensor object, to construct, see function.Tensor""" def __call__(self, *indices): @@ -69,9 +69,8 @@ def __call__(self, *indices): else: raise ValueError("The indices must be expression") - return _expr.Call(self.dtype, self.op.name, - args, _expr.Call.Halide, - self.op, self.value_index) + return _expr.ProducerLoad(self, args) + def __getitem__(self, indices): return TensorSlice(self, indices) diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 07e0c9ca0f27..9aec24a77f6f 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -19,12 +19,12 @@ from tvm.ir import PrimExpr from tvm.runtime import const -from .buffer import Buffer, decl_buffer +from .buffer import Buffer, decl_buffer, DataProducer from .data_layout import Layout, BijectiveLayout, bijective_layout, layout from .expr import Var, SizeVar, Reduce, FloatImm, IntImm, StringImm, Cast from .expr import Add, Sub, Mul, Div, Mod, FloorDiv, FloorMod from .expr import Min, Max, EQ, NE, LT, LE, GT, GE, And, Or, Not -from .expr import Select, BufferLoad, Load, Ramp, Broadcast, Shuffle, Call, Let +from .expr import Select, BufferLoad, ProducerLoad, Load, Ramp, Broadcast, Shuffle, Call, Let from .expr import IterVar, Any from .stmt import Stmt, LetStmt, AssertStmt, For diff --git a/python/tvm/tir/buffer.py b/python/tvm/tir/buffer.py index 0c7753e4d8ec..e4dec5f30950 100644 --- a/python/tvm/tir/buffer.py +++ b/python/tvm/tir/buffer.py @@ -245,3 +245,8 @@ def decl_buffer(shape, return _ffi_api.Buffer( data, dtype, shape, strides, elem_offset, name, scope, data_alignment, offset_factor, buffer_type) + + +@tvm._ffi.register_object +class DataProducer(Object): + pass diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py index aca5e5a377fb..d55370e8bdfa 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tir/expr.py @@ -144,7 +144,7 @@ def __rxor__(self, other): def __invert__(self): if _dtype_is_float(self): raise RuntimeError("Cannot use ~ operator on float type Expr.") - return _ffi_api.Call(self.dtype, "bitwise_not", [self], Call.PureIntrinsic, None, 0) + return _ffi_api.Call(self.dtype, "bitwise_not", [self], Call.PureIntrinsic) def __lt__(self, other): return _ffi_api._OpLT(self, other) @@ -888,6 +888,23 @@ def __init__(self, buffer, indices): _ffi_api.BufferLoad, buffer, indices) +@tvm._ffi.register_object +class ProducerLoad(PrimExprWithOp): + """Producer load node. + + Parameters + ---------- + producer : DataProducer + The buffer to be loaded. + + indices : List[PrimExpr] + The buffer indices. + """ + def __init__(self, producer, indices): + self.__init_handle_by_constructor__( + _ffi_api.ProducerLoad, producer, indices) + + @tvm._ffi.register_object class Ramp(PrimExprWithOp): """Ramp node. @@ -959,22 +976,15 @@ class Call(PrimExprWithOp): call_type : int The type of the call - - func : Operation, optional - Operation if call_type is Halide - - value_index : int - The output value index """ Extern = 0 ExternCPlusPlus = 1 PureExtern = 2 - Halide = 3 Intrinsic = 4 PureIntrinsic = 5 - def __init__(self, dtype, name, args, call_type, func, value_index): + def __init__(self, dtype, name, args, call_type): self.__init_handle_by_constructor__( - _ffi_api.Call, dtype, name, args, call_type, func, value_index) + _ffi_api.Call, dtype, name, args, call_type) @tvm._ffi.register_object diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py index 4dd541e9bdbb..47ba2e2c805c 100644 --- a/python/tvm/tir/ir_builder.py +++ b/python/tvm/tir/ir_builder.py @@ -380,7 +380,7 @@ def likely(self, expr): The expression will likely tag. """ return _expr.Call(expr.dtype, "likely", [expr], - _expr.Call.PureIntrinsic, None, 0) + _expr.Call.PureIntrinsic) def get(self): """Return the builded IR. diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index b87db19738b9..929d422ccc43 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -30,9 +30,9 @@ def _pack_buffer(buf): """ assert buf.shape shape = Call("handle", "tvm_stack_make_shape", buf.shape, - Call.Intrinsic, None, 0) + Call.Intrinsic) strides = Call("handle", "tvm_stack_make_shape", buf.strides, - Call.Intrinsic, None, 0) if buf.strides else 0 + Call.Intrinsic) if buf.strides else 0 pack_args = [buf.data, shape, strides, @@ -40,7 +40,7 @@ def _pack_buffer(buf): const(0, dtype=buf.dtype), buf.elem_offset] return Call("handle", "tvm_stack_make_array", - pack_args, Call.Intrinsic, None, 0) + pack_args, Call.Intrinsic) def call_packed(*args): """Build expression by call an external packed function. @@ -68,7 +68,7 @@ def call_packed(*args): """ call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args] return Call( - "int32", "tvm_call_packed", call_args, Call.Intrinsic, None, 0) + "int32", "tvm_call_packed", call_args, Call.Intrinsic) def call_pure_intrin(dtype, func_name, *args): @@ -95,7 +95,7 @@ def call_pure_intrin(dtype, func_name, *args): """ args = convert(args) return Call( - dtype, func_name, convert(args), Call.PureIntrinsic, None, 0) + dtype, func_name, convert(args), Call.PureIntrinsic) def call_intrin(dtype, func_name, *args): @@ -122,7 +122,7 @@ def call_intrin(dtype, func_name, *args): """ args = convert(args) return Call( - dtype, func_name, convert(args), Call.Intrinsic, None, 0) + dtype, func_name, convert(args), Call.Intrinsic) def call_pure_extern(dtype, func_name, *args): @@ -145,7 +145,7 @@ def call_pure_extern(dtype, func_name, *args): The call expression. """ return Call( - dtype, func_name, convert(args), Call.PureExtern, None, 0) + dtype, func_name, convert(args), Call.PureExtern) def call_extern(dtype, func_name, *args): @@ -168,7 +168,7 @@ def call_extern(dtype, func_name, *args): The call expression. """ return Call( - dtype, func_name, convert(args), Call.Extern, None, 0) + dtype, func_name, convert(args), Call.Extern) def call_llvm_intrin(dtype, name, *args): @@ -278,7 +278,7 @@ def trace(args, trace_action="tvm.default_trace_action"): call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args] call_args.insert(0, trace_action) return tvm.tir.Call( - args[-1].dtype, "tvm_call_trace_packed", call_args, tvm.tir.Call.Intrinsic, None, 0) + args[-1].dtype, "tvm_call_trace_packed", call_args, tvm.tir.Call.Intrinsic) diff --git a/src/contrib/hybrid/codegen_hybrid.cc b/src/contrib/hybrid/codegen_hybrid.cc index f61ad33190c4..706252057c98 100644 --- a/src/contrib/hybrid/codegen_hybrid.cc +++ b/src/contrib/hybrid/codegen_hybrid.cc @@ -202,18 +202,21 @@ void CodeGenHybrid::VisitExpr_(const NotNode* op, std::ostream& os) { // NOLINT PrintExpr(op->a, os); } +void CodeGenHybrid::VisitExpr_(const ProducerLoadNode* op, std::ostream& os) { // NOLINT(*) + auto tensor = Downcast(op->producer); + + os << GetTensorID(tensor->op, tensor->value_index); + os << "["; + for (size_t i = 0; i < op->indices.size(); ++i) { + if (i) os << ", "; + std::stringstream idx; + PrintExpr(op->indices[i], idx); + os << idx.str(); + } + os << "]"; +} void CodeGenHybrid::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) - if (op->call_type == CallNode::Halide) { - os << GetTensorID(op->func, op->value_index); - os << "["; - for (size_t i = 0; i < op->args.size(); ++i) { - if (i) os << ", "; - std::stringstream idx; - PrintExpr(op->args[i], idx); - os << idx.str(); - } - os << "]"; - } else if (op->is_intrinsic(CallNode::bitwise_and)) { + if (op->is_intrinsic(CallNode::bitwise_and)) { PrintBinaryIntrinsitc(op, "&", os, this); } else if (op->is_intrinsic(CallNode::bitwise_xor)) { PrintBinaryIntrinsitc(op, "^", os, this); diff --git a/src/contrib/hybrid/codegen_hybrid.h b/src/contrib/hybrid/codegen_hybrid.h index 78a22b55dae7..8a31e0902aea 100644 --- a/src/contrib/hybrid/codegen_hybrid.h +++ b/src/contrib/hybrid/codegen_hybrid.h @@ -90,6 +90,7 @@ class CodeGenHybrid : public ExprFunctor, void VisitExpr_(const LoadNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const LetNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const CallNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const ProducerLoadNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const AddNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const SubNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const MulNode* op, std::ostream& os) override; // NOLINT(*) diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h index fdf14d9a587a..c7b2b31019ae 100644 --- a/src/printer/text_printer.h +++ b/src/printer/text_printer.h @@ -286,6 +286,7 @@ class TIRTextPrinter : public StmtFunctor, Doc VisitExpr_(const NotNode* op) override; Doc VisitExpr_(const SelectNode* op) override; Doc VisitExpr_(const BufferLoadNode* op) override; + Doc VisitExpr_(const ProducerLoadNode* op) override; Doc VisitExpr_(const LoadNode* op) override; Doc VisitExpr_(const RampNode* op) override; Doc VisitExpr_(const BroadcastNode* op) override; diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc index 4d22cbb68a9b..29927379f17d 100644 --- a/src/printer/tir_text_printer.cc +++ b/src/printer/tir_text_printer.cc @@ -291,6 +291,13 @@ Doc TIRTextPrinter::VisitExpr_(const BufferLoadNode* op) { return doc; } +Doc TIRTextPrinter::VisitExpr_(const ProducerLoadNode* op) { + // TODO(tvm-team): consider make a better text format for producer. + Doc doc; + doc << op->producer->GetNameHint() << Print(op->indices); + return doc; +} + Doc TIRTextPrinter::VisitExpr_(const LoadNode* op) { Doc doc; doc << "(" << PrintDType(op->dtype) << "*)" << Print(op->buffer_var) << "[" << Print(op->index) @@ -327,8 +334,6 @@ inline const char* CallType2String(CallNode::CallType t) { return "extern_cpp"; case CallNode::PureExtern: return "pure_extern"; - case CallNode::Halide: - return "halide"; case CallNode::Intrinsic: return "intrin"; case CallNode::PureIntrinsic: @@ -346,8 +351,7 @@ Doc TIRTextPrinter::VisitExpr_(const CallNode* op) { args.push_back(Print(arg)); } doc << PrintSep(args, Doc::Text(", ")) << ", dtype=" << PrintDType(op->dtype) - << ", type=" << Doc::StrLiteral(CallType2String(op->call_type)) - << ", index=" << op->value_index << ")"; + << ", type=" << Doc::StrLiteral(CallType2String(op->call_type)) << ")"; return doc; } diff --git a/src/te/autodiff/jacobian.cc b/src/te/autodiff/jacobian.cc index f770169e06e7..ecddf5e7b8d0 100644 --- a/src/te/autodiff/jacobian.cc +++ b/src/te/autodiff/jacobian.cc @@ -78,21 +78,24 @@ class JacobianMutator : public ExprMutator { PrimExpr VisitExpr_(const LoadNode* op) NOT_IMPLEMENTED; PrimExpr VisitExpr_(const LetNode* op) NOT_IMPLEMENTED; + PrimExpr VisitExpr_(const ProducerLoadNode* op) final { + auto tensor = Downcast(op->producer); + if (input_.get() && tensor == input_) { + // Tensor(indices) + CHECK_EQ(indices_.size(), op->indices.size()); + PrimExpr condition = const_true(); + for (size_t i = 0; i < input_.ndim(); ++i) { + condition = AndNode::make(condition, EQNode::make(indices_[i], op->indices[i])); + } + return CastNode::make(op->dtype, condition); + } else { + return make_zero(op->dtype); + } + } + PrimExpr VisitExpr_(const CallNode* op) { PrimExpr expr = GetRef(op); - if (op->call_type == CallNode::CallType::Halide) { - if (input_.get() && op->func.same_as(input_->op) && op->value_index == input_->value_index) { - // Tensor(indices) - CHECK_EQ(indices_.size(), op->args.size()); - PrimExpr condition = const_true(); - for (size_t i = 0; i < input_.ndim(); ++i) { - condition = AndNode::make(condition, EQNode::make(indices_[i], op->args[i])); - } - return CastNode::make(op->dtype, condition); - } else { - return make_zero(op->dtype); - } - } else if (op->call_type == CallNode::CallType::PureIntrinsic) { + if (op->call_type == CallNode::CallType::PureIntrinsic) { static std::unordered_set piecewise_const = {"floor", "ceil", "trunc", "round"}; if (op->name == "exp") { return MulNode::make(Mutate(op->args[0]), expr); @@ -116,8 +119,7 @@ class JacobianMutator : public ExprMutator { FloatImm(type, 1.0), FloatImm(type, -1.0))); } else if (op->name == intrinsic::tvm_if_then_else) { Array new_args = {op->args[0], Mutate(op->args[1]), Mutate(op->args[2])}; - return CallNode::make(op->dtype, op->name, new_args, op->call_type, op->func, - op->value_index); + return CallNode::make(op->dtype, op->name, new_args, op->call_type); } else if (piecewise_const.count(op->name)) { return FloatImm(expr.dtype(), 0.0); } else { diff --git a/src/te/operation/compute_op.cc b/src/te/operation/compute_op.cc index 048285d83360..5ff9e12e53f6 100644 --- a/src/te/operation/compute_op.cc +++ b/src/te/operation/compute_op.cc @@ -154,9 +154,8 @@ Array ComputeOpNode::InputTensors() const { std::unordered_set visited; for (auto& e : body) { tir::PostOrderVisit(e, [&ret, &visited](const ObjectRef& n) { - const tir::CallNode* call = n.as(); - if (call != nullptr && call->func.defined()) { - Tensor t = Downcast(call->func).output(call->value_index); + if (auto* pload = n.as()) { + Tensor t = Downcast(pload->producer); if (!visited.count(t)) { ret.push_back(t); visited.insert(t); @@ -203,9 +202,8 @@ void ComputeOpNode::PropBoundToInputs(const Operation& self, arith::Analyzer* an std::unordered_map* out_dom_map) const { CHECK_EQ(self.operator->(), this); auto fvisit = [&dom_map, out_dom_map, analyzer](const ObjectRef& n) { - auto* call = n.as(); - if (call != nullptr && call->func.defined()) { - Tensor t = Downcast(call->func).output(call->value_index); + if (auto* pload = n.as()) { + Tensor t = Downcast(pload->producer); if (t->op.defined() && out_dom_map->count(t)) { TensorDom& dom = out_dom_map->at(t); for (size_t i = 0; i < t.ndim(); ++i) { @@ -213,7 +211,7 @@ void ComputeOpNode::PropBoundToInputs(const Operation& self, arith::Analyzer* an // undefined behaviour), so we can intersect the estimated set of the argument with the // range expected by the tensor. However, intersection may result in overly complex // expressions, so we perform a more relaxed form of intersection. - IntSet arg_intset = analyzer->int_set(call->args[i], ConvertDomMap(dom_map)); + IntSet arg_intset = analyzer->int_set(pload->indices[i], ConvertDomMap(dom_map)); const arith::IntervalSetNode* arg_interval = arg_intset.as(); if (arg_interval) { PrimExpr shape_i_min_value = make_zero(t->shape[i].dtype()); diff --git a/src/te/operation/hybrid_op.cc b/src/te/operation/hybrid_op.cc index 7ee583335f18..55996a5afe77 100644 --- a/src/te/operation/hybrid_op.cc +++ b/src/te/operation/hybrid_op.cc @@ -86,9 +86,8 @@ Array HybridOpNode::InputTensors() const { std::unordered_set visited; Array curr_inputs; tir::PostOrderVisit(body, [&curr_inputs, &orig_inputs, &visited](const ObjectRef& n) { - const tir::CallNode* call = n.as(); - if (call != nullptr && call->func.defined()) { - Tensor t = Downcast(call->func).output(call->value_index); + if (auto* pload = n.as()) { + Tensor t = Downcast(pload->producer); if (orig_inputs.count(t) && !visited.count(t)) { curr_inputs.push_back(t); visited.insert(t); diff --git a/src/te/operation/op_util.cc b/src/te/operation/op_util.cc index 5b200ac0ce94..eb28a64c3a45 100644 --- a/src/te/operation/op_util.cc +++ b/src/te/operation/op_util.cc @@ -207,18 +207,19 @@ class TensorReplacer : public tir::StmtExprMutator { public: explicit TensorReplacer(const std::unordered_map& vmap) : vmap_(vmap) {} - PrimExpr VisitExpr_(const tir::CallNode* op) final { - if (op->call_type == tir::CallNode::Halide) { - Tensor t = Downcast(op->func).output(op->value_index); - auto it = vmap_.find(t); - if (it != vmap_.end()) { - PrimExpr ret = tir::CallNode::make(op->dtype, it->second->op->name, op->args, op->call_type, - it->second->op, it->second->value_index); - found = true; - return this->VisitExpr(ret); - } + PrimExpr VisitExpr_(const tir::ProducerLoadNode* op) final { + PrimExpr expr = StmtExprMutator::VisitExpr_(op); + op = expr.as(); + CHECK(op != nullptr); + + Tensor t = Downcast(op->producer); + auto it = vmap_.find(t); + if (it != vmap_.end()) { + found = true; + return tir::ProducerLoad(it->second, op->indices); + } else { + return expr; } - return StmtExprMutator::VisitExpr_(op); } // whether it is found. diff --git a/src/te/operation/tensorize.cc b/src/te/operation/tensorize.cc index f322e12f8db1..ddc05951909f 100644 --- a/src/te/operation/tensorize.cc +++ b/src/te/operation/tensorize.cc @@ -156,22 +156,19 @@ void VerifyTensorizeLoopNest(const ComputeOpNode* self, const Stage& stage, // Remap the tensor placeholder, index and inline things. class TensorIntrinMatcher final : public StmtExprMutator { public: - PrimExpr VisitExpr_(const CallNode* op) final { + PrimExpr VisitExpr_(const ProducerLoadNode* op) final { PrimExpr expr = StmtExprMutator::VisitExpr_(op); - op = expr.as(); - if (op->call_type == CallNode::Halide) { - Tensor t = Downcast(op->func).output(op->value_index); - auto it = in_remap_.find(t); - if (it != in_remap_.end()) { - const InputEntry& e = it->second; - CHECK_EQ(op->args.size(), e.region.size()); - Array args; - for (size_t i = e.start; i < e.region.size(); ++i) { - args.push_back(op->args[i] - e.region[i]->min); - } - return CallNode::make(op->dtype, e.tensor->op->name, args, op->call_type, e.tensor->op, - e.tensor->value_index); + op = expr.as(); + auto t = Downcast(op->producer); + auto it = in_remap_.find(t); + if (it != in_remap_.end()) { + const InputEntry& e = it->second; + CHECK_EQ(op->indices.size(), e.region.size()); + Array indices; + for (size_t i = e.start; i < e.region.size(); ++i) { + indices.push_back(op->indices[i] - e.region[i]->min); } + return ProducerLoad(e.tensor, indices); } return expr; } diff --git a/src/te/schedule/graph.cc b/src/te/schedule/graph.cc index bcde6807ad18..62557ed8573f 100644 --- a/src/te/schedule/graph.cc +++ b/src/te/schedule/graph.cc @@ -40,8 +40,6 @@ struct TensorDimKey { int value_index; int dim; TensorDimKey() {} - TensorDimKey(const tir::CallNode* op, int dim) - : f(op->func), value_index(op->value_index), dim(dim) {} TensorDimKey(const Tensor& t, int dim) : f(t->op), value_index(t->value_index), dim(dim) {} TensorDimKey(const Tensor& t, size_t dim) : f(t->op), value_index(t->value_index), dim(static_cast(dim)) {} @@ -240,11 +238,11 @@ ReachGraph GetReachGraph(const Array& ops) { reach[TensorDimKey(t, i)] = {}; } auto fvisit = [&vmap, &reach, &bset](const ObjectRef& n) { - const tir::CallNode* call = n.as(); - if (call != nullptr && call->func.defined()) { - if (!bset.count(call->func.get())) return; - for (size_t i = 0; i < call->args.size(); ++i) { - TensorDimKey dkey(call, static_cast(i)); + if (auto* pload = n.as()) { + Tensor t = Downcast(pload->producer); + if (!bset.count(t->op.get())) return; + for (size_t i = 0; i < pload->indices.size(); ++i) { + TensorDimKey dkey(t, static_cast(i)); auto fpush = [&dkey, &vmap, &reach](const ObjectRef& node) { const VarNode* v = node.as(); auto it = vmap.find(v); @@ -252,7 +250,7 @@ ReachGraph GetReachGraph(const Array& ops) { reach[it->second].push_back(dkey); } }; - tir::PostOrderVisit(call->args[i], fpush); + tir::PostOrderVisit(pload->indices[i], fpush); } } }; @@ -328,11 +326,11 @@ Map ScanFixPointAnalysis(const Operation& scan_op) { vmap[axis[i]->var.get()] = std::move(keys); } auto fvisit = [&vmap, &f_merge_key, &exact_reach, &fail_set](const ObjectRef& n) { - const tir::CallNode* call = n.as(); - if (call != nullptr && call->func.defined()) { - for (size_t i = 0; i < call->args.size(); ++i) { - auto it = vmap.find(call->args[i].get()); - TensorDimKey src(call, static_cast(i)); + if (auto* pload = n.as()) { + Tensor t = Downcast(pload->producer); + for (size_t i = 0; i < pload->indices.size(); ++i) { + auto it = vmap.find(pload->indices[i].get()); + TensorDimKey src(t, static_cast(i)); if (it != vmap.end()) { const std::vector& keys = it->second; for (const auto& key : keys) { diff --git a/src/te/schedule/operation_inline.cc b/src/te/schedule/operation_inline.cc index 8c8f092b7008..8a130e98a5ef 100644 --- a/src/te/schedule/operation_inline.cc +++ b/src/te/schedule/operation_inline.cc @@ -42,27 +42,28 @@ class OperationInliner final : public StmtExprMutator { OperationInliner(Operation op, Array args, PrimExpr body) : operation_(op), args_(args), body_(body) {} - PrimExpr VisitExpr_(const CallNode* op) final { + PrimExpr VisitExpr_(const ProducerLoadNode* op) final { PrimExpr expr = StmtExprMutator::VisitExpr_(op); - op = expr.as(); + op = expr.as(); + auto tensor = Downcast(op->producer); - if (op->func.same_as(operation_)) { - CHECK_EQ(op->value_index, 0); + if (tensor->op.same_as(operation_)) { + CHECK_EQ(tensor->value_index, 0); expr = body_; - CHECK_EQ(args_.size(), op->args.size()); + CHECK_EQ(args_.size(), op->indices.size()); bool has_side_effect = false; - for (size_t i = 0; i < op->args.size(); ++i) { - if (HasSideEffect(op->args[i])) has_side_effect = true; + for (size_t i = 0; i < op->indices.size(); ++i) { + if (HasSideEffect(op->indices[i])) has_side_effect = true; } if (has_side_effect) { for (size_t i = 0; i < args_.size(); ++i) { - expr = LetNode::make(args_[i], op->args[i], expr); + expr = LetNode::make(args_[i], op->indices[i], expr); } } else { Map vmap; for (size_t i = 0; i < args_.size(); ++i) { - vmap.Set(args_[i], op->args[i]); + vmap.Set(args_[i], op->indices[i]); } expr = Substitute(EvaluateNode::make(expr), vmap).as()->value; } diff --git a/src/te/schedule/schedule_ops.cc b/src/te/schedule/schedule_ops.cc index 6cc04d984484..10f1ed3326ab 100644 --- a/src/te/schedule/schedule_ops.cc +++ b/src/te/schedule/schedule_ops.cc @@ -245,18 +245,21 @@ class SchedulePostProc : public StmtExprMutator { } } - PrimExpr VisitExpr_(const CallNode* op) final { - if (op->call_type == CallNode::Halide) { - TensorKey key{op->func, op->value_index}; - auto it = replace_buffer_.find(key); - if (it != replace_buffer_.end()) { - const Tensor& dst = it->second; - PrimExpr ret = CallNode::make(op->dtype, dst->op->name, op->args, op->call_type, dst->op, - dst->value_index); - return this->VisitExpr(ret); - } + PrimExpr VisitExpr_(const ProducerLoadNode* op) final { + PrimExpr expr = StmtExprMutator::VisitExpr_(op); + op = expr.as(); + CHECK(op != nullptr); + + auto tensor = Downcast(op->producer); + TensorKey key{tensor->op, tensor->value_index}; + + auto it = replace_buffer_.find(key); + if (it != replace_buffer_.end()) { + const Tensor& dst = it->second; + return ProducerLoad(dst, op->indices); + } else { + return expr; } - return StmtExprMutator::VisitExpr_(op); } PrimExpr VisitExpr_(const VarNode* op) final { diff --git a/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc b/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc index 84166d11881b..3b15f89b90db 100644 --- a/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc +++ b/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc @@ -54,6 +54,11 @@ struct Tile { int k{-1}; }; +TensorKey TensorKeyFromProducer(DataProducer producer) { + auto tensor = Downcast(producer); + return TensorKey{tensor->op, tensor->value_index}; +} + std::string simplify_name(std::string input) { auto pos = input.find("."); if (pos != std::string::npos) { @@ -153,27 +158,25 @@ class MMAMatcher : public StmtVisitor { }; // Check whether the storage scope is local - bool check_local_buffer_(const CallNode* op, BufferInfo* bi) { - if (op->call_type == CallNode::Halide) { - auto it = storage_scope_.find(op->func.get()); - if (it == storage_scope_.end()) { - return false; - } - const std::string& strkey = it->second; - if (strkey != "local") { - return false; - } - auto it1 = buf_map_.find(TensorKey{op->func, op->value_index}); - if (it1 == buf_map_.end()) { - return false; - } - *bi = it1->second; - if (bi->released) { - return false; - } - return true; + bool check_local_buffer_(const ProducerLoadNode* op, BufferInfo* bi) { + auto tensor = Downcast(op->producer); + auto it = storage_scope_.find(tensor.get()); + if (it == storage_scope_.end()) { + return false; } - return false; + const std::string& strkey = it->second; + if (strkey != "local") { + return false; + } + auto it1 = buf_map_.find(TensorKey{tensor->op, tensor->value_index}); + if (it1 == buf_map_.end()) { + return false; + } + *bi = it1->second; + if (bi->released) { + return false; + } + return true; } // Do the pattern matching @@ -183,7 +186,7 @@ class MMAMatcher : public StmtVisitor { return false; } - auto* load_c = add->a.as(); + auto* load_c = add->a.as(); BufferInfo buffer_c; if (!check_local_buffer_(load_c, &buffer_c) || !buffer_c.same_as(store_buffer) || !(buffer_c.dtype == DataType::Float(32) || buffer_c.dtype == DataType::Int(32))) { @@ -196,7 +199,7 @@ class MMAMatcher : public StmtVisitor { } auto load_a_expr = unpack_type_cast(mul->a, buffer_c.dtype); - auto load_a = load_a_expr.as(); + auto load_a = load_a_expr.as(); BufferInfo buffer_a; if (!check_local_buffer_(load_a, &buffer_a) || !(buffer_a.dtype == DataType::Float(16) || buffer_a.dtype == DataType::Int(8) || @@ -206,7 +209,7 @@ class MMAMatcher : public StmtVisitor { } auto load_b_expr = unpack_type_cast(mul->b, buffer_c.dtype); - auto load_b = load_b_expr.as(); + auto load_b = load_b_expr.as(); BufferInfo buffer_b; if (!check_local_buffer_(load_b, &buffer_b) || !(buffer_b.dtype == DataType::Float(16) || buffer_b.dtype == DataType::Int(8) || @@ -470,7 +473,7 @@ class BufferAnalyser : public StmtExprVisitor { strides_.insert(std::make_pair(key.GetName(), strides)); if (frag_reg_.count(bi.name)) { - PrimExpr dst = CallNode::make(bi.dtype, bi.name, op->args, CallNode::Halide, op->func, 0); + PrimExpr dst = ProducerLoad(Downcast(op->func).output(0), op->args); frag_load_.insert(std::make_pair(op, dst)); auto rel_index = bi.RelIndex(op->args); @@ -525,69 +528,70 @@ class BufferAnalyser : public StmtExprVisitor { } } - const CallNode* value = op->value.as(); - if (value != nullptr && frag_reg_.count(value->name)) { - PrimExpr dst = CallNode::make(bi.dtype, bi.name, op->args, CallNode::Halide, op->func, 0); + const ProducerLoadNode* value = op->value.as(); + // TODO(tvm-team): string matching is dangerous, consider other means. + if (value != nullptr && frag_reg_.count(value->producer->GetNameHint())) { + PrimExpr dst = ProducerLoad(Downcast(op->func).output(0), op->args); frag_store_.insert(std::make_pair(op, dst)); } } - void VisitExpr_(const CallNode* op) final { + void VisitExpr_(const ProducerLoadNode* op) final { StmtExprVisitor::VisitExpr_(op); - if (op->call_type == CallNode::Halide) { - TensorKey key{op->func, op->value_index}; - auto it = buf_map_.find(key); - CHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << key.f; - const BufferInfo& bi = it->second; - CHECK(!bi.released) << "Read a buffer that is already out of scope"; - if (matrix_abc_.count(op->name)) { - if (bi.shape.size() < 2) { + auto tensor = Downcast(op->producer); + TensorKey key{tensor->op, tensor->value_index}; + auto it = buf_map_.find(key); + CHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << key.f; + const BufferInfo& bi = it->second; + CHECK(!bi.released) << "Read a buffer that is already out of scope"; + + if (matrix_abc_.count(tensor->op->name)) { + if (bi.shape.size() < 2) { + invalid_ = true; + return; + } + for (auto i = bi.shape.size() - 1; i + 2 >= bi.shape.size(); --i) { + const IntImmNode* shape = bi.shape[i].as(); + if (shape == nullptr || shape->value % 16 != 0) { invalid_ = true; return; } - for (auto i = bi.shape.size() - 1; i + 2 >= bi.shape.size(); --i) { - const IntImmNode* shape = bi.shape[i].as(); - if (shape == nullptr || shape->value % 16 != 0) { - invalid_ = true; - return; - } - } } + } - Array strides; - if (bi.strides.size() > 0) { - strides = bi.strides; - } else { - for (size_t i = 1; i < bi.shape.size(); ++i) { - PrimExpr stride = IntImm(DataType::Int(32), 1); - for (size_t j = bi.shape.size() - 1; j >= i; --j) { - stride = MulNode::make(stride, bi.shape[j]); - } - strides.push_back(stride); + Array strides; + if (bi.strides.size() > 0) { + strides = bi.strides; + } else { + for (size_t i = 1; i < bi.shape.size(); ++i) { + PrimExpr stride = IntImm(DataType::Int(32), 1); + for (size_t j = bi.shape.size() - 1; j >= i; --j) { + stride = MulNode::make(stride, bi.shape[j]); } - strides.push_back(make_const(DataType::Int(32), 1)); + strides.push_back(stride); } - strides_.insert(std::make_pair(key.GetName(), strides)); + strides.push_back(make_const(DataType::Int(32), 1)); + } + strides_.insert(std::make_pair(key.GetName(), strides)); - if (!frag_reg_.count(bi.name)) { - return; - } + if (!frag_reg_.count(bi.name)) { + return; + } - auto rel_index = bi.RelIndex(op->args); - if (op->args.size() < 2) { - invalid_ = true; - return; - } - for (auto i = op->args.size() - 1; i + 2 >= op->args.size(); --i) { - index_visitor.scaling_factor_ = 16; - if (const IntImmNode* shape = bi.shape[i].as()) { - index_visitor.scaling_factor_ = shape->value; - } - auto index = rel_index[i]; - auto simplified_index = analyzer_.Simplify(index); - index_visitor(simplified_index); + auto rel_index = bi.RelIndex(op->indices); + if (op->indices.size() < 2) { + invalid_ = true; + return; + } + for (auto i = op->indices.size() - 1; i + 2 >= op->indices.size(); --i) { + index_visitor.scaling_factor_ = 16; + if (const IntImmNode* shape = bi.shape[i].as()) { + index_visitor.scaling_factor_ = shape->value; } + auto index = rel_index[i]; + auto simplified_index = analyzer_.Simplify(index); + index_visitor(simplified_index); } } @@ -837,11 +841,11 @@ class TensorCoreIRMutator : public StmtExprMutator { if (it != mma_sync_.end()) { const auto& operands = it->second; PrimExpr a = operands[0]; - auto ca = a.as(); + auto ca = a.as(); PrimExpr b = operands[1]; - auto cb = b.as(); + auto cb = b.as(); PrimExpr c = operands[2]; - auto cc = c.as(); + auto cc = c.as(); ObjectPtr buffer_node_a = make_object(); ObjectPtr buffer_node_b = make_object(); @@ -866,24 +870,24 @@ class TensorCoreIRMutator : public StmtExprMutator { }; auto call_add_c = [this, &cc, &buffer_node_c, &mma_sync_call](const Buffer& buffer) { - return add_buffer_bind_scope_(cc, buffer_node_c, TensorKey{cc->func, cc->value_index}, - mma_sync_call, cc->dtype); + return add_buffer_bind_scope_(cc, buffer_node_c, TensorKeyFromProducer(cc->producer), + mma_sync_call); }; auto call_add_b = [this, &cb, &buffer_node_b, &call_add_c](const Buffer& buffer) { - return add_buffer_bind_scope_(cb, buffer_node_b, TensorKey{cb->func, cb->value_index}, - call_add_c, cb->dtype); + return add_buffer_bind_scope_(cb, buffer_node_b, TensorKeyFromProducer(cb->producer), + call_add_c); }; - return add_buffer_bind_scope_(ca, buffer_node_a, TensorKey{ca->func, ca->value_index}, - call_add_b, ca->dtype); + return add_buffer_bind_scope_(ca, buffer_node_a, TensorKeyFromProducer(ca->producer), + call_add_b); } auto it2 = frag_load_.find(op); if (it2 != frag_load_.end()) { PrimExpr dst = it2->second; if (op->value.as() != nullptr || op->value.as() != nullptr) { - auto call = dst.as(); + auto pload = dst.as(); auto fill_fragment_call = [this, &op](const Buffer& buffer) { return EvaluateNode::make(CallNode::make(DataType::Handle(), intrinsic::tvm_fill_fragment, @@ -893,8 +897,8 @@ class TensorCoreIRMutator : public StmtExprMutator { }; ObjectPtr buffer_node = make_object(); - return add_buffer_bind_scope_(call, buffer_node, TensorKey{call->func, call->value_index}, - fill_fragment_call, call->dtype); + return add_buffer_bind_scope_(pload, buffer_node, TensorKeyFromProducer(pload->producer), + fill_fragment_call); } const CallNode* value = op->value.as(); @@ -912,16 +916,17 @@ class TensorCoreIRMutator : public StmtExprMutator { PrimExpr mutated_value = thread_idx_mutator(op->value); PrimExpr src = CallNode::make(value->dtype, "&", {mutated_value}, CallNode::Extern); - auto call = dst.as(); + auto pload = dst.as(); PrimExpr matrix_major; - auto iter2 = matrix_major_.find(simplify_name(call->name)); - CHECK(iter2 != matrix_major_.end()) << "Can not determine matrix major for " << call->name; + auto iter2 = matrix_major_.find(simplify_name(pload->producer->GetNameHint())); + CHECK(iter2 != matrix_major_.end()) + << "Can not determine matrix major for " << pload->producer->GetNameHint(); if (iter2->second == "col_major") { matrix_major = StringImmNode::make("col_major"); } else if (iter2->second == "row_major") { matrix_major = StringImmNode::make("row_major"); } else { - LOG(FATAL) << "invalid matrix major for " << call->name; + LOG(FATAL) << "invalid matrix major for " << pload->producer->GetNameHint(); } auto load_matrix_call = [this, &src, &stride, &matrix_major](const Buffer& buffer) { @@ -933,8 +938,8 @@ class TensorCoreIRMutator : public StmtExprMutator { }; ObjectPtr buffer_node = make_object(); - return add_buffer_bind_scope_(call, buffer_node, TensorKey{op->func, op->value_index}, - load_matrix_call, call->dtype); + return add_buffer_bind_scope_(pload, buffer_node, TensorKeyFromProducer(pload->producer), + load_matrix_call); } auto it3 = frag_store_.find(op); @@ -953,7 +958,7 @@ class TensorCoreIRMutator : public StmtExprMutator { dst = thread_idx_mutator(dst); dst = CallNode::make(DataType::Handle(), "&", {dst}, CallNode::Extern); - auto call = op->value.as(); + auto pload = op->value.as(); auto store_matrix_call = [this, &dst, &stride](const Buffer& buffer) { return EvaluateNode::make( @@ -964,8 +969,8 @@ class TensorCoreIRMutator : public StmtExprMutator { }; ObjectPtr buffer_node = make_object(); - return add_buffer_bind_scope_(call, buffer_node, TensorKey{call->func, call->value_index}, - store_matrix_call, call->dtype); + return add_buffer_bind_scope_(pload, buffer_node, TensorKeyFromProducer(pload->producer), + store_matrix_call); } return stmt; @@ -1023,10 +1028,10 @@ class TensorCoreIRMutator : public StmtExprMutator { return tile_size; } - Stmt add_buffer_bind_scope_(const CallNode* call, const ObjectPtr& buffer_node, - const TensorKey& key, - const std::function& call_back, - DataType datatype) { + Stmt add_buffer_bind_scope_(const ProducerLoadNode* pload, + const ObjectPtr& buffer_node, const TensorKey& key, + const std::function& call_back) { + auto tensor = Downcast(pload->producer); auto it = bounds_.find(key); CHECK(it != bounds_.end()); Array min_bound; @@ -1039,7 +1044,7 @@ class TensorCoreIRMutator : public StmtExprMutator { for (size_t i = 0; i < it->second.size() - 2; ++i) { shape.push_back(it->second[i]->extent); } - auto tile_size = get_tile_size_(simplify_name(call->name)); + auto tile_size = get_tile_size_(simplify_name(tensor->op->name)); shape.push_back(tile_size[0]); shape.push_back(tile_size[1]); @@ -1054,18 +1059,18 @@ class TensorCoreIRMutator : public StmtExprMutator { strides.push_back(make_const(DataType::Int(32), 1)); PrimExpr elem_offset = IntImm(DataType::Int(32), 0); - CHECK_EQ(call->args.size(), min_bound.size()); + CHECK_EQ(pload->indices.size(), min_bound.size()); for (size_t i = 0; i < min_bound.size(); i++) { elem_offset = AddNode::make( - elem_offset, MulNode::make(strides[i], SubNode::make(call->args[i], min_bound[i]))); + elem_offset, MulNode::make(strides[i], SubNode::make(pload->indices[i], min_bound[i]))); } - auto it2 = matrix_abc_.find(simplify_name(call->name)); - CHECK(it2 != matrix_abc_.end()) << "Cannot find matrix info for " << call->name; - buffer_node->data = Var(call->name, DataType::Handle()); - buffer_node->name = call->name; + auto it2 = matrix_abc_.find(simplify_name(tensor->op->name)); + CHECK(it2 != matrix_abc_.end()) << "Cannot find matrix info for " << tensor->op->name; + buffer_node->data = Var(tensor->op->name, DataType::Handle()); + buffer_node->name = tensor->op->name; buffer_node->scope = "wmma." + it2->second; - buffer_node->dtype = datatype; + buffer_node->dtype = tensor->dtype; buffer_node->strides = strides; buffer_node->shape = shape; buffer_node->data_alignment = 1; @@ -1077,17 +1082,17 @@ class TensorCoreIRMutator : public StmtExprMutator { tensor_node->value_index = key.value_index; tensor_node->op = Downcast(key.f); tensor_node->shape = shape; - tensor_node->dtype = datatype; - Tensor tensor(tensor_node); + tensor_node->dtype = tensor->dtype; + Tensor tensor_bind(tensor_node); Array args; - for (size_t i = 0; i < call->args.size(); ++i) { - args.push_back(call->args[i]); + for (size_t i = 0; i < pload->indices.size(); ++i) { + args.push_back(pload->indices[i]); args.push_back(shape[i]); } auto tuple = CallNode::make(DataType::Handle(), intrinsic::tvm_tuple, args, CallNode::Intrinsic); - Array node = {buffer, tensor}; + Array node = {buffer, tensor_bind}; return AttrStmtNode::make(node, "buffer_bind_scope", tuple, call_back(buffer)); } diff --git a/src/te/schedule/schedule_postproc_to_primfunc.cc b/src/te/schedule/schedule_postproc_to_primfunc.cc index 57e5528870ea..96df24dc6c7a 100644 --- a/src/te/schedule/schedule_postproc_to_primfunc.cc +++ b/src/te/schedule/schedule_postproc_to_primfunc.cc @@ -114,17 +114,12 @@ class TensorToBufferMapper : public StmtExprMutator { return BufferStore(buffer, op->value, op->args); } - PrimExpr VisitExpr_(const CallNode* op) final { + PrimExpr VisitExpr_(const ProducerLoadNode* op) final { auto ret = StmtExprMutator::VisitExpr_(op); - op = ret.as(); - - if (op->call_type == CallNode::Halide) { - Tensor tensor = Downcast(op->func).output(op->value_index); - Buffer buffer = GetBuffer(tensor); - return tir::BufferLoad(buffer, op->args); - } else { - return ret; - } + op = ret.as(); + Tensor tensor = Downcast(op->producer); + Buffer buffer = GetBuffer(tensor); + return tir::BufferLoad(buffer, op->indices); } private: diff --git a/src/te/tensor.cc b/src/te/tensor.cc index 606797da5e87..1a31a85a14f8 100644 --- a/src/te/tensor.cc +++ b/src/te/tensor.cc @@ -47,14 +47,16 @@ PrimExpr Tensor::operator()(Array indices) const { } PrimExpr Tensor::operator()(Array indices) const { - using tir::CallNode; if (ndim() != 0) { CHECK_EQ(ndim(), indices.size()) << "Tensor dimension mismatch in read" << "ndim = " << ndim() << ", indices.size=" << indices.size(); } - auto n = CallNode::make((*this)->dtype, (*this)->op->name, indices, CallNode::Halide, (*this)->op, - (*this)->value_index); - return n; + + return ProducerLoad((*this), indices); +} + +String TensorNode::GetNameHint() const { + return op->num_outputs() == 1 ? op->name : (op->name + ".v" + std::to_string(value_index)); } Tensor Operation::output(size_t i) const { diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 8b9a8e2f7812..e1d8b3f00647 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -270,25 +270,17 @@ bool CallNode::is_vectorizable() const { return false; } -PrimExpr CallNode::make(DataType dtype, std::string name, Array args, CallType call_type, - FunctionRef func, int value_index) { +PrimExpr CallNode::make(DataType dtype, std::string name, Array args, + CallType call_type) { for (size_t i = 0; i < args.size(); ++i) { CHECK(args[i].defined()); } - if (call_type == Halide) { - for (size_t i = 0; i < args.size(); ++i) { - CHECK(args[i].dtype().is_int()); - } - } - ObjectPtr node = make_object(); node->dtype = dtype; node->name = std::move(name); node->args = std::move(args); node->call_type = call_type; - node->func = std::move(func); - node->value_index = value_index; return PrimExpr(node); } @@ -403,6 +395,21 @@ TVM_REGISTER_GLOBAL("tir.BufferLoad").set_body_typed([](Buffer buffer, Array indices) { + ObjectPtr node = make_object(); + node->dtype = producer->GetDataType(); + node->producer = std::move(producer); + node->indices = std::move(indices); + data_ = std::move(node); +} + +TVM_REGISTER_GLOBAL("tir.ProducerLoad") + .set_body_typed([](DataProducer producer, Array indices) { + return ProducerLoad(producer, indices); + }); + +TVM_REGISTER_NODE_TYPE(ProducerLoadNode); + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); @@ -638,6 +645,19 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "]"; }); +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << op->producer->GetNameHint() << "["; + for (size_t i = 0; i < op->indices.size(); ++i) { + p->Print(op->indices[i]); + if (i < op->indices.size() - 1) { + p->stream << ", "; + } + } + p->stream << "]"; + }); + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); @@ -758,8 +778,7 @@ TVM_REGISTER_GLOBAL("tir.Load").set_body([](TVMArgs args, TVMRetValue* ret) { }); TVM_REGISTER_GLOBAL("tir.Call") - .set_body_typed([](DataType type, std::string name, Array args, int call_type, - FunctionRef func, int value_index) { + .set_body_typed([](DataType type, std::string name, Array args, int call_type) { Array prim_expr_args; for (const auto& it : args) { CHECK(it->IsInstance() || it->IsInstance()); @@ -769,8 +788,7 @@ TVM_REGISTER_GLOBAL("tir.Call") prim_expr_args.push_back(Downcast(it)); } } - return CallNode::make(type, name, prim_expr_args, static_cast(call_type), - func, value_index); + return CallNode::make(type, name, prim_expr_args, static_cast(call_type)); }); } // namespace tir diff --git a/src/tir/ir/expr_functor.cc b/src/tir/ir/expr_functor.cc index 7f30abea3613..98d61a0a8987 100644 --- a/src/tir/ir/expr_functor.cc +++ b/src/tir/ir/expr_functor.cc @@ -41,6 +41,10 @@ void ExprVisitor::VisitExpr_(const BufferLoadNode* op) { VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); }); } +void ExprVisitor::VisitExpr_(const ProducerLoadNode* op) { + VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); }); +} + void ExprVisitor::VisitExpr_(const LetNode* op) { this->VisitExpr(op->value); this->VisitExpr(op->body); @@ -135,6 +139,16 @@ PrimExpr ExprMutator::VisitExpr_(const BufferLoadNode* op) { } } +PrimExpr ExprMutator::VisitExpr_(const ProducerLoadNode* op) { + auto fmutate = [this](const PrimExpr& e) { return this->VisitExpr(e); }; + Array indices = MutateArray(op->indices, fmutate); + if (indices.same_as(op->indices)) { + return GetRef(op); + } else { + return ProducerLoad(op->producer, indices); + } +} + PrimExpr ExprMutator::VisitExpr_(const LetNode* op) { PrimExpr value = this->VisitExpr(op->value); PrimExpr body = this->VisitExpr(op->body); @@ -152,7 +166,7 @@ PrimExpr ExprMutator::VisitExpr_(const CallNode* op) { if (args.same_as(op->args)) { return GetRef(op); } else { - return CallNode::make(op->dtype, op->name, args, op->call_type, op->func, op->value_index); + return CallNode::make(op->dtype, op->name, args, op->call_type); } } diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index 646e00855c2b..44d0e8df951c 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -336,10 +336,9 @@ class StorageFlattener : public StmtExprMutator { return stmt; } - PrimExpr VisitExpr_(const CallNode* op) final { - CHECK(op->call_type != CallNode::Halide) << "Cannot handle Halide calls " - << " please run SchedulePostProcToPrimFunc first"; - return StmtExprMutator::VisitExpr_(op); + PrimExpr VisitExpr_(const ProducerLoadNode* op) final { + LOG(FATAL) << "ProducerLoad cannot appear in a valid TIR PrimFunc."; + return PrimExpr(); } Stmt VisitStmt_(const ProvideNode* op) final { diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index 9e553cb12ceb..743bac49b354 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -209,8 +209,7 @@ class Vectorizer : public StmtExprMutator { int lanes = std::max(t.dtype().lanes(), f.dtype().lanes()); t = BroadcastTo(t, lanes); f = BroadcastTo(f, lanes); - return CallNode::make(op->dtype.with_lanes(lanes), op->name, {cond, t, f}, op->call_type, - op->func, op->value_index); + return CallNode::make(op->dtype.with_lanes(lanes), op->name, {cond, t, f}, op->call_type); } } // Call @@ -232,8 +231,7 @@ class Vectorizer : public StmtExprMutator { if (op->args.same_as(new_args)) { return GetRef(op); } else { - return CallNode::make(op->dtype, op->name, new_args, op->call_type, op->func, - op->value_index); + return CallNode::make(op->dtype, op->name, new_args, op->call_type); } } else { int lane = 0; @@ -242,8 +240,7 @@ class Vectorizer : public StmtExprMutator { if (op->args.same_as(new_args)) { return GetRef(op); } else { - return CallNode::make(op->dtype.with_lanes(lane), op->name, new_args, op->call_type, - op->func, op->value_index); + return CallNode::make(op->dtype.with_lanes(lane), op->name, new_args, op->call_type); } } } diff --git a/tests/lint/git-clang-format.sh b/tests/lint/git-clang-format.sh index c76d4f47a02b..dc4450ab6c69 100755 --- a/tests/lint/git-clang-format.sh +++ b/tests/lint/git-clang-format.sh @@ -19,13 +19,21 @@ set -e set -u set -o pipefail -if [ "$#" -lt 1 ]; then - echo "Usage: tests/lint/git-clang-format.sh " +if [[ "$1" == "-i" ]]; then + INPLACE_FORMAT=1 + shift 1 +else + INPLACE_FORMAT=0 +fi + +if [[ "$#" -lt 1 ]]; then + echo "Usage: tests/lint/git-clang-format.sh [-i] " echo "" echo "Run clang-format on files that changed since " echo "Examples:" echo "- Compare last one commit: tests/lint/git-clang-format.sh HEAD~1" echo "- Compare against upstream/master: tests/lint/git-clang-format.sh upsstream/master" + echo "You can also add -i option to do inplace format" exit 1 fi @@ -50,6 +58,12 @@ fi # Print out specific version ${CLANG_FORMAT} --version +if [[ ${INPLACE_FORMAT} -eq 1 ]]; then + echo "Running inplace git-clang-format against" $1 + git-${CLANG_FORMAT} --extensions h,mm,c,cc --binary=${CLANG_FORMAT} $1 + exit 0 +fi + echo "Running git-clang-format against" $1 git-${CLANG_FORMAT} --diff --extensions h,mm,c,cc --binary=${CLANG_FORMAT} $1 1> /tmp/$$.clang-format.txt echo "---------clang-format log----------" diff --git a/tests/python/unittest/test_arith_canonical_simplify.py b/tests/python/unittest/test_arith_canonical_simplify.py index 179152273c00..525cd6c30736 100644 --- a/tests/python/unittest/test_arith_canonical_simplify.py +++ b/tests/python/unittest/test_arith_canonical_simplify.py @@ -202,7 +202,7 @@ def test_reduce_combiner_simplify(): assert tvm.ir.structural_equal(lhs, rhs) # Test that components with side effects are not removed - side_effect = lambda *xs: tvm.tir.Call("int32", "dummy", xs, tvm.tir.Call.Intrinsic, None, 0) + side_effect = lambda *xs: tvm.tir.Call("int32", "dummy", xs, tvm.tir.Call.Intrinsic) ck.verify(sum_and_prod((A[k], side_effect(A[10-k])), k)[0], sum_and_prod((A[k], side_effect(A[10-k])), k)[0]) ck.verify(sum_and_prod((side_effect(A[k]), A[10-k]), k)[0], diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py index 2e53bfd569dc..34db08f40c2f 100644 --- a/tests/python/unittest/test_target_codegen_llvm.py +++ b/tests/python/unittest/test_target_codegen_llvm.py @@ -34,7 +34,7 @@ def test_llvm_intrin(): ] ib.emit(tvm.tir.Evaluate( tvm.tir.Call( - "int32", "prefetch", args, tvm.tir.Call.Intrinsic, None, 0))) + "int32", "prefetch", args, tvm.tir.Call.Intrinsic))) body = ib.get() mod = tvm.IRModule.from_expr( diff --git a/tests/python/unittest/test_te_hybrid_script.py b/tests/python/unittest/test_te_hybrid_script.py index ea4179d7ca3f..c6f28ad9a4b5 100644 --- a/tests/python/unittest/test_te_hybrid_script.py +++ b/tests/python/unittest/test_te_hybrid_script.py @@ -138,9 +138,9 @@ def test_outer_product(): assert jbody.args[1].name == 'j' assert isinstance(jbody.value, tvm.tir.Mul) mul = jbody.value - assert isinstance(mul.a, tvm.tir.Call) - assert mul.a.name == 'a' - assert mul.b.name == 'b' + assert isinstance(mul.a, tvm.tir.ProducerLoad) + assert mul.a.producer.name == 'a' + assert mul.b.producer.name == 'b' func, ins, outs = run_and_check(outer_product, [n, m, a, b], {n: 99, m: 101}) temp = util.tempdir() @@ -209,29 +209,29 @@ def fanout(n, a): assert jbody.func.name == 'sigma' assert isinstance(jbody.value, tvm.tir.Add) value = jbody.value - assert isinstance(value.a, tvm.tir.Call) - assert value.a.name == 'sigma' - assert len(value.a.args) == 1 - assert value.a.args[0].value == 0 - assert value.b.name == 'a' - assert len(value.b.args) == 1 - assert tvm.ir.structural_equal(value.b.args[0], ir.loop_var + jloop.loop_var) + assert isinstance(value.a, tvm.tir.ProducerLoad) + assert value.a.producer.name == 'sigma' + assert len(value.a.indices) == 1 + assert value.a.indices[0].value == 0 + assert value.b.producer.name == 'a' + assert len(value.b.indices) == 1 + assert tvm.ir.structural_equal(value.b.indices[0], ir.loop_var + jloop.loop_var) divide= rbody[2] assert isinstance(divide, tvm.tir.Provide) assert len(divide.args) == 1 assert divide.args[0].value == 0 value = divide.value assert isinstance(value, tvm.tir.Mul) - assert value.a.name == 'sigma' - assert len(value.a.args) == 1 - assert value.a.args[0].value == 0 + assert value.a.producer.name == 'sigma' + assert len(value.a.indices) == 1 + assert value.a.indices[0].value == 0 assert abs(value.b.value - (1 / 3.0)) < 1e-5 write = rbody[3] assert isinstance(write, tvm.tir.Provide) assert write.func.name == 'b' - assert write.value.name == 'sigma' - assert len(write.value.args) == 1 - assert write.value.args[0].value == 0 + assert write.value.producer.name == 'sigma' + assert len(write.value.indices) == 1 + assert write.value.indices[0].value == 0 func, ins, outs = run_and_check(fanout, [n, a], {n: 10}) run_and_check(func, ins, {n: 10}, outs=outs) diff --git a/tests/python/unittest/test_tir_constructor.py b/tests/python/unittest/test_tir_constructor.py index 4af93fd58c0a..86f87348ec53 100644 --- a/tests/python/unittest/test_tir_constructor.py +++ b/tests/python/unittest/test_tir_constructor.py @@ -112,14 +112,12 @@ def test_expr_constructor(): assert x.vectors[0] == a assert x.indices[0].value == 0 - x = tvm.tir.Call("float32", "xyz", [a], tvm.tir.Call.Extern, None, 0) + x = tvm.tir.Call("float32", "xyz", [a], tvm.tir.Call.Extern) assert isinstance(x, tvm.tir.Call) assert x.dtype == "float32" assert x.name == "xyz" assert x.args[0] == a assert x.call_type == tvm.tir.Call.Extern - assert x.func == None - assert x.value_index == 0 v = te.var("aa") x = tvm.tir.Let(v, 1, v) diff --git a/tests/python/unittest/test_tir_nodes.py b/tests/python/unittest/test_tir_nodes.py index 36c9c764f6ab..e6322592edaf 100644 --- a/tests/python/unittest/test_tir_nodes.py +++ b/tests/python/unittest/test_tir_nodes.py @@ -171,18 +171,18 @@ def test_all(): def test_bitwise(): x = te.var('x') y = te.var('y') - assert str(x << y) == '@shift_left(x: int32, y: int32, dtype=int32, type="pure_intrin", index=0)' - assert str(x >> y) == '@shift_right(x: int32, y: int32, dtype=int32, type="pure_intrin", index=0)' - assert str(x & y) == '@bitwise_and(x: int32, y: int32, dtype=int32, type="pure_intrin", index=0)' - assert str(x | y) == '@bitwise_or(x: int32, y: int32, dtype=int32, type="pure_intrin", index=0)' - assert str(x ^ y) == '@bitwise_xor(x: int32, y: int32, dtype=int32, type="pure_intrin", index=0)' - assert str(10 & x) == '@bitwise_and(10, x: int32, dtype=int32, type="pure_intrin", index=0)' - assert str(10 | x) == '@bitwise_or(10, x: int32, dtype=int32, type="pure_intrin", index=0)' - assert str(10 ^ x) == '@bitwise_xor(10, x: int32, dtype=int32, type="pure_intrin", index=0)' - assert str(10 >> x) == '@shift_right(10, x: int32, dtype=int32, type="pure_intrin", index=0)' - assert str(10 << x) == '@shift_left(10, x: int32, dtype=int32, type="pure_intrin", index=0)' + assert str(x << y) == '@shift_left(x: int32, y: int32, dtype=int32, type="pure_intrin")' + assert str(x >> y) == '@shift_right(x: int32, y: int32, dtype=int32, type="pure_intrin")' + assert str(x & y) == '@bitwise_and(x: int32, y: int32, dtype=int32, type="pure_intrin")' + assert str(x | y) == '@bitwise_or(x: int32, y: int32, dtype=int32, type="pure_intrin")' + assert str(x ^ y) == '@bitwise_xor(x: int32, y: int32, dtype=int32, type="pure_intrin")' + assert str(10 & x) == '@bitwise_and(10, x: int32, dtype=int32, type="pure_intrin")' + assert str(10 | x) == '@bitwise_or(10, x: int32, dtype=int32, type="pure_intrin")' + assert str(10 ^ x) == '@bitwise_xor(10, x: int32, dtype=int32, type="pure_intrin")' + assert str(10 >> x) == '@shift_right(10, x: int32, dtype=int32, type="pure_intrin")' + assert str(10 << x) == '@shift_left(10, x: int32, dtype=int32, type="pure_intrin")' assert str(10 % x) == 'floormod(10, x: int32)' - assert str(~x) == '@bitwise_not(x: int32, dtype=int32, type="pure_intrin", index=0)' + assert str(~x) == '@bitwise_not(x: int32, dtype=int32, type="pure_intrin")' assert(tvm.tir.const(1, "int8x2") >> 1).dtype == "int8x2" assert(x >> tvm.tir.const(1, "int32x2")).dtype == "int32x2" assert(te.var("z", "int8x2") << tvm.tir.const(1, "int8x2")).dtype == "int8x2" @@ -239,10 +239,10 @@ def test_divide_by_zero(): def test_isnan(): x = te.var('x', 'float32') - assert str(tvm.tir.isnan(x)) == '@isnan(x: float32, dtype=bool, type="pure_intrin", index=0)' + assert str(tvm.tir.isnan(x)) == '@isnan(x: float32, dtype=bool, type="pure_intrin")' assert str(tvm.tir.isnan(x).dtype) == 'bool' y = te.var('y', 'float16') - assert str(tvm.tir.isnan(y)) == '@isnan(cast(float32, y: float16), dtype=bool, type="pure_intrin", index=0)' + assert str(tvm.tir.isnan(y)) == '@isnan(cast(float32, y: float16), dtype=bool, type="pure_intrin")' z = te.var('z', 'int32') assert str(tvm.tir.isnan(z)) == 'False' k = te.var('k', 'int8x2') diff --git a/tests/python/unittest/test_tir_transform_combine_context_call.py b/tests/python/unittest/test_tir_transform_combine_context_call.py index 7fd2593bd365..29a330319622 100644 --- a/tests/python/unittest/test_tir_transform_combine_context_call.py +++ b/tests/python/unittest/test_tir_transform_combine_context_call.py @@ -22,7 +22,7 @@ def test_for(): def device_context(dev_id): ctx = tvm.tir.call_extern("handle", "device_context", dev_type, dev_id) return tvm.tir.Call( - "handle", "tvm_thread_context", [ctx], tvm.tir.Call.Intrinsic, None, 0) + "handle", "tvm_thread_context", [ctx], tvm.tir.Call.Intrinsic) ib = tvm.tir.ir_builder.create() n = te.var("n") diff --git a/topi/python/topi/cuda/rcnn/proposal.py b/topi/python/topi/cuda/rcnn/proposal.py index b52fc916ff30..f713bb216808 100644 --- a/topi/python/topi/cuda/rcnn/proposal.py +++ b/topi/python/topi/cuda/rcnn/proposal.py @@ -187,7 +187,7 @@ def argsort_ir(data_buf, out_index_buf): index_out[offset + 1] = temp_index[0] ib.emit(tvm.tir.Call(None, 'tvm_storage_sync', tvm.runtime.convert(['shared']), - tvm.tir.Call.Intrinsic, None, 0)) + tvm.tir.Call.Intrinsic)) return ib.get() @@ -248,7 +248,7 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): p_out[base_idx + i] = True ib.emit(tvm.tir.Call(None, 'tvm_storage_sync', tvm.runtime.convert(['shared']), - tvm.tir.Call.Intrinsic, None, 0)) + tvm.tir.Call.Intrinsic)) return ib.get() diff --git a/topi/python/topi/cuda/sort.py b/topi/python/topi/cuda/sort.py index a1c70c44958d..ddae2bd96135 100644 --- a/topi/python/topi/cuda/sort.py +++ b/topi/python/topi/cuda/sort.py @@ -117,7 +117,7 @@ def sort_ir(data, values_out, axis, is_ascend, indices_out=None): tvm.tir.generic.cast(tid, indices_out.dtype) ib.emit(tvm.tir.Call(None, 'tvm_storage_sync', tvm.runtime.convert(['shared']), - tvm.tir.Call.Intrinsic, None, 0)) + tvm.tir.Call.Intrinsic)) idxd = tvm.tir.indexdiv idxm = tvm.tir.indexmod @@ -145,7 +145,7 @@ def sort_ir(data, values_out, axis, is_ascend, indices_out=None): indices_out[offset + axis_mul_after] = temp_index[0] ib.emit(tvm.tir.Call(None, 'tvm_storage_sync', tvm.runtime.convert(['shared']), - tvm.tir.Call.Intrinsic, None, 0)) + tvm.tir.Call.Intrinsic)) return ib.get() @@ -237,7 +237,7 @@ def sort_nms_ir(data, valid_count, output, axis, is_ascend): output[offset + axis_mul_after] = temp_index[0] ib.emit(tvm.tir.Call(None, 'tvm_storage_sync', tvm.runtime.convert(['shared']), - tvm.tir.Call.Intrinsic, None, 0)) + tvm.tir.Call.Intrinsic)) return ib.get() diff --git a/vta/python/vta/environment.py b/vta/python/vta/environment.py index bbaac2ce1797..e68f098ba53f 100644 --- a/vta/python/vta/environment.py +++ b/vta/python/vta/environment.py @@ -80,7 +80,7 @@ def __init__(self, env): ctx = tvm.tir.call_extern("handle", "VTATLSCommandHandle") self.command_handle = tvm.tir.Call( "handle", "tvm_thread_context", [ctx], - tvm.tir.Call.Intrinsic, None, 0) + tvm.tir.Call.Intrinsic) self.DEBUG_NO_SYNC = False env._dev_ctx = self self.gemm = intrin.gemm(env, env.mock_mode) diff --git a/vta/python/vta/transform.py b/vta/python/vta/transform.py index 1d54bb01bb49..37b4e0e3e7c4 100644 --- a/vta/python/vta/transform.py +++ b/vta/python/vta/transform.py @@ -297,7 +297,7 @@ def _do_fold(stmt): if _match_pragma(stmt, "coproc_sync"): success[0] = True sync = tvm.tir.Call( - "int32", "vta.coproc_sync", [], tvm.tir.Call.Intrinsic, None, 0) + "int32", "vta.coproc_sync", [], tvm.tir.Call.Intrinsic) return tvm.tir.SeqStmt([stmt.body, tvm.tir.Evaluate(sync)]) if _match_pragma(stmt, "trim_loop"): op = stmt.body