diff --git a/include/tvm/te/operation.h b/include/tvm/te/operation.h index 13f39317dbe4c..89074d83e1d64 100644 --- a/include/tvm/te/operation.h +++ b/include/tvm/te/operation.h @@ -59,8 +59,11 @@ class TVM_DLL OperationNode : public Object { std::string name; /*! \brief optional tag of the operation */ std::string tag; - /*! \brief additional attributes of the operation*/ + /*! \brief additional attributes of the operation */ Map attrs; + /*! \brief output tensors */ + Array outputs; + // virtual destructor. virtual ~OperationNode() {} /*! \return number of outputs */ @@ -473,7 +476,7 @@ class HybridOpNode : public OperationNode { /*! \brief The input tensors */ Array inputs; /*! \brief Symbolic placeholder representation of outputs */ - Array outputs; + Array symbolic_outputs; /*! \brief The axis of iterations */ Array axis; /*! \brief the statement that generates the computation. This is @@ -509,6 +512,7 @@ class HybridOpNode : public OperationNode { v->Visit("attrs", &attrs); v->Visit("inputs", &inputs); v->Visit("outputs", &outputs); + v->Visit("symbolic_outputs", &symbolic_outputs); v->Visit("axis", &axis); v->Visit("body", &body); } diff --git a/python/tvm/te/hybrid/parser.py b/python/tvm/te/hybrid/parser.py index 442aeb6f1027c..655748783dd10 100644 --- a/python/tvm/te/hybrid/parser.py +++ b/python/tvm/te/hybrid/parser.py @@ -312,7 +312,7 @@ def visit_Assign(self, node): "You should bind a pure name to the tensors", ) self.add_symbol(node.targets[i].id, Symbol.GlobalBuffer, rhs.output(i)) - rmap[rhs.outputs[i].op] = rhs.output(i) + rmap[rhs.symbolic_outputs[i].op] = rhs.output(i) return utils.replace_io(rhs.body, rmap) _internal_assert(len(node.targets) == 1, "So far only one-valued assignment is supported!") diff --git a/python/tvm/te/tensor.py b/python/tvm/te/tensor.py index fc85d830c91a9..b5609375b8aa3 100644 --- a/python/tvm/te/tensor.py +++ b/python/tvm/te/tensor.py @@ -86,11 +86,12 @@ def __eq__(self, other): if isinstance(other, _expr.ExprOp): return 
_expr.EqualOp(self, other) return False + if self.same_as(other): + return True if self.ndim == 0 and other.ndim == 0: raise ValueError( "Equal == comparison among rank-0 tensor is ambiguous, " - "use Tensor.equal for content expression equvalence, " - "use Tensor.same_as for exact reference comparison" + "use Tensor.equal for content expression equivalence." ) return _ffi_api.TensorEqual(self, other) diff --git a/src/te/operation/extern_op.cc b/src/te/operation/extern_op.cc index b602efcfc28b6..84869ccd77755 100644 --- a/src/te/operation/extern_op.cc +++ b/src/te/operation/extern_op.cc @@ -90,6 +90,7 @@ Operation ExternOpNode::ReplaceInputs(const Operation& self, ICHECK_EQ(self.operator->(), this); auto n = make_object(*this); n->body = ReplaceTensor(this->body, rmap); + n->outputs = Array(); for (size_t i = 0; i < n->inputs.size(); ++i) { Tensor t = n->inputs[i]; if (rmap.count(t)) { diff --git a/src/te/operation/hybrid_op.cc b/src/te/operation/hybrid_op.cc index 5d2412abb3d25..5bb6458229964 100644 --- a/src/te/operation/hybrid_op.cc +++ b/src/te/operation/hybrid_op.cc @@ -49,13 +49,13 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_REGISTER_NODE_TYPE(HybridOpNode); -int HybridOpNode::num_outputs() const { return static_cast(outputs.size()); } +int HybridOpNode::num_outputs() const { return static_cast(symbolic_outputs.size()); } Array HybridOpNode::root_iter_vars() const { return this->axis; } -DataType HybridOpNode::output_dtype(size_t i) const { return outputs[i]->dtype; } +DataType HybridOpNode::output_dtype(size_t i) const { return symbolic_outputs[i]->dtype; } -Array HybridOpNode::output_shape(size_t i) const { return outputs[i]->shape; } +Array HybridOpNode::output_shape(size_t i) const { return symbolic_outputs[i]->shape; } HybridOp::HybridOp(std::string name, std::string tag, Map attrs, Array inputs, Array outputs, Stmt body) { @@ -67,7 +67,7 @@ HybridOp::HybridOp(std::string name, std::string tag, Map att n->tag = std::move(tag); n->attrs = 
std::move(attrs); n->inputs = std::move(inputs); - n->outputs = std::move(outputs); + n->symbolic_outputs = std::move(outputs); n->axis = te::GatherLoopVars(body); n->body = std::move(body); data_ = std::move(n); @@ -104,6 +104,7 @@ Operation HybridOpNode::ReplaceInputs(const Operation& self, ICHECK_EQ(self.operator->(), this); auto n = make_object(*this); n->body = te::ReplaceTensor(this->body, rmap); + n->outputs = Array(); for (size_t i = 0; i < n->inputs.size(); ++i) { Tensor t = n->inputs[i]; if (rmap.count(t)) { @@ -166,7 +167,7 @@ Stmt HybridOpNode::BuildProvide(const Stage& stage, Stmt ret = AttrStmt(make_zero(DataType::Int(32)), tir::attr::extern_scope, 0, this->body); std::unordered_map rmap; for (int i = 0; i < this->num_outputs(); ++i) { - rmap[outputs[i]] = stage->op.output(i); + rmap[symbolic_outputs[i]] = stage->op.output(i); } auto n = make_object(*this); /* This is a story little bit complicated. diff --git a/src/te/operation/scan_op.cc b/src/te/operation/scan_op.cc index 39689bd9654ad..a29cc66017957 100644 --- a/src/te/operation/scan_op.cc +++ b/src/te/operation/scan_op.cc @@ -40,6 +40,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_REGISTER_NODE_TYPE(ScanOpNode); int ScanOpNode::num_outputs() const { return static_cast(update.size()); } + Array ScanOpNode::root_iter_vars() const { Array ret{scan_axis}; for (IterVar iv : spatial_axis_) { @@ -143,6 +144,7 @@ Operation ScanOpNode::ReplaceInputs(const Operation& self, const std::unordered_map& rmap) const { ICHECK_EQ(self.operator->(), this); auto n = make_object(*this); + n->outputs = Array(); for (size_t i = 0; i < n->init.size(); ++i) { if (rmap.count(n->init[i])) { n->init.Set(i, rmap.at(n->init[i])); diff --git a/src/te/operation/tensor_compute_op.cc b/src/te/operation/tensor_compute_op.cc index 262e5a2b97f44..432df75f03b92 100644 --- a/src/te/operation/tensor_compute_op.cc +++ b/src/te/operation/tensor_compute_op.cc @@ -85,6 +85,7 @@ Operation TensorComputeOpNode::ReplaceInputs(const 
Operation& self, const std::unordered_map& rmap) const { ICHECK_EQ(self.operator->(), this); auto n = make_object(*this); + n->outputs = Array(); auto intrin = make_object(*(this->intrin.operator->())); intrin->body = ReplaceTensor(this->intrin->body, rmap); if (intrin->reduce_init.defined()) { diff --git a/src/te/schedule/schedule_dataflow_rewrite.cc b/src/te/schedule/schedule_dataflow_rewrite.cc index fae826b926e3e..e09cdfe146e16 100644 --- a/src/te/schedule/schedule_dataflow_rewrite.cc +++ b/src/te/schedule/schedule_dataflow_rewrite.cc @@ -616,7 +616,7 @@ void InjectInline(ScheduleNode* sch, bool feature_extraction_mode) { const HybridOpNode* hybrid = sch->stages[i]->op.as(); ICHECK(hybrid); Operation op = HybridOp(hybrid->name, hybrid->tag, hybrid->attrs, hybrid->inputs, - hybrid->outputs, new_hybrid_body[i]); + hybrid->symbolic_outputs, new_hybrid_body[i]); op = op->ReplaceInputs(op, repl); for (int idx = 0; idx < s->op->num_outputs(); ++idx) { repl[s->op.output(idx)] = op.output(idx); diff --git a/src/te/schedule/schedule_ops.cc b/src/te/schedule/schedule_ops.cc index 1568df4670af8..99e02ccaf943a 100644 --- a/src/te/schedule/schedule_ops.cc +++ b/src/te/schedule/schedule_ops.cc @@ -298,9 +298,15 @@ class SchedulePostProc : public StmtExprMutator { private: void AddReplace(Tensor src, Tensor dst, Tensor repl_realize = Tensor(), Operation repl_op = Operation()) { - replace_buffer_[src] = dst; - replace_realize_[src] = repl_realize; - replace_op_[src->op.get()] = repl_op; + if (!src.same_as(dst)) { + replace_buffer_[src] = dst; + } + if (!src.same_as(repl_realize)) { + replace_realize_[src] = repl_realize; + } + if (!src->op.same_as(repl_op)) { + replace_op_[src->op.get()] = repl_op; + } } // The thread extent scope. 
std::unordered_map thread_extent_scope_; diff --git a/src/te/tensor.cc b/src/te/tensor.cc index 1d75761216f1e..1f43714ea1072 100644 --- a/src/te/tensor.cc +++ b/src/te/tensor.cc @@ -78,13 +78,22 @@ String TensorNode::GetNameHint() const { return op->num_outputs() == 1 ? op->name : (op->name + ".v" + std::to_string(value_index)); } -Tensor Operation::output(size_t i) const { - auto node = make_object(); - node->op = *this; - node->value_index = i; - node->dtype = (*this)->output_dtype(i); - node->shape = (*this)->output_shape(i); - return Tensor(node); +Tensor Operation::output(size_t n) const { + // cache the output tensors if empty + if ((*this)->outputs.empty()) { + auto* ptr = static_cast(get_mutable()); + size_t num = static_cast((*this)->num_outputs()); + for (size_t i = 0; i < num; ++i) { + auto node = make_object(); + node->op = *this; + node->value_index = i; + node->dtype = (*this)->output_dtype(i); + node->shape = (*this)->output_shape(i); + ptr->outputs.push_back(Tensor(node)); + } + } + ICHECK_LT(n, (*this)->outputs.size()); + return (*this)->outputs[n]; } Tensor::Tensor(Array shape, DataType dtype, Operation op, int value_index) { diff --git a/tests/python/unittest/test_te_tensor.py b/tests/python/unittest/test_te_tensor.py index 6958888e9bb69..6b09410af6c4c 100644 --- a/tests/python/unittest/test_te_tensor.py +++ b/tests/python/unittest/test_te_tensor.py @@ -37,6 +37,8 @@ def test_tensor(): assert T.op.output(0).__hash__() == T.__hash__() d = {T.op.output(0): 1} assert d[T] == 1 + assert T == T.op.output(0) + assert T.same_as(T.op.output(0)) assert T[0][0][0].astype("float16").dtype == "float16" @@ -49,6 +51,8 @@ def test_rank_zero(): print(T) print(T.op.body) assert tuple(T.shape) == () + assert T == T.op.output(0) + assert T.same_as(T.op.output(0)) def test_conv1d():