[TE][Fix] Comparison of the output tensor
leeexyz committed Feb 14, 2022
1 parent bb60ee9 commit 6ba46de
Showing 5 changed files with 27 additions and 12 deletions.
7 changes: 4 additions & 3 deletions include/tvm/te/operation.h
@@ -59,8 +59,11 @@ class TVM_DLL OperationNode : public Object {
   std::string name;
   /*! \brief optional tag of the operation */
   std::string tag;
-  /*! \brief additional attributes of the operation*/
+  /*! \brief additional attributes of the operation */
   Map<String, ObjectRef> attrs;
+  /*! \brief output tensors */
+  Array<Tensor> outputs;
+
   // virtual destructor.
   virtual ~OperationNode() {}
   /*! \return number of outputs */
@@ -472,8 +475,6 @@ class HybridOpNode : public OperationNode {
  public:
   /*! \brief The input tensors */
   Array<Tensor> inputs;
-  /*! \brief Symbolic placeholder representation of outputs */
-  Array<Tensor> outputs;
   /*! \brief The axis of iterations */
   Array<IterVar> axis;
   /*! \brief the statement that generates the computation. This is
5 changes: 3 additions & 2 deletions python/tvm/te/tensor.py
@@ -86,11 +86,12 @@ def __eq__(self, other):
             if isinstance(other, _expr.ExprOp):
                 return _expr.EqualOp(self, other)
             return False
+        if self.same_as(other):
+            return True
         if self.ndim == 0 and other.ndim == 0:
             raise ValueError(
                 "Equal == comparison among rank-0 tensor is ambiguous, "
-                "use Tensor.equal for content expression equvalence, "
-                "use Tensor.same_as for exact reference comparison"
+                "use Tensor.equal for content expression equvalence."
             )
         return _ffi_api.TensorEqual(self, other)

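The `same_as` short-circuit above is what the new assertions in test_te_tensor.py exercise. A minimal Python sketch of the intended behavior, assuming a TVM build that includes this commit (the variable and tensor names below are illustrative, not part of the change):

from tvm import te

m = te.size_var("m")
A = te.placeholder((m,), name="A")
s = te.placeholder((), name="s")  # rank-0 placeholder
k = te.reduce_axis((0, m), name="k")
T = te.compute((), lambda: te.sum(A[k] * s(), axis=k), name="T")  # rank-0 result

# T and T.op.output(0) now refer to the same underlying object, so `==`
# returns True via the identity check instead of raising the rank-0
# ambiguity error.
assert T.same_as(T.op.output(0))
assert T == T.op.output(0)
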
1 change: 1 addition & 0 deletions src/te/operation/scan_op.cc
@@ -40,6 +40,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
 TVM_REGISTER_NODE_TYPE(ScanOpNode);
 
 int ScanOpNode::num_outputs() const { return static_cast<int>(update.size()); }
+
 Array<IterVar> ScanOpNode::root_iter_vars() const {
   Array<IterVar> ret{scan_axis};
   for (IterVar iv : spatial_axis_) {
22 changes: 15 additions & 7 deletions src/te/tensor.cc
@@ -78,13 +78,21 @@ String TensorNode::GetNameHint() const {
   return op->num_outputs() == 1 ? op->name : (op->name + ".v" + std::to_string(value_index));
 }
 
-Tensor Operation::output(size_t i) const {
-  auto node = make_object<TensorNode>();
-  node->op = *this;
-  node->value_index = i;
-  node->dtype = (*this)->output_dtype(i);
-  node->shape = (*this)->output_shape(i);
-  return Tensor(node);
+Tensor Operation::output(size_t n) const {
+  if ((*this)->outputs.empty()) {
+    auto* ptr = static_cast<OperationNode*>(get_mutable());
+    size_t num = static_cast<size_t>((*this)->num_outputs());
+    for (size_t i = 0; i < num; ++i) {
+      auto node = make_object<TensorNode>();
+      node->op = *this;
+      node->value_index = i;
+      node->dtype = (*this)->output_dtype(i);
+      node->shape = (*this)->output_shape(i);
+      ptr->outputs.push_back(Tensor(node));
+    }
+  }
+  ICHECK_LT(n, (*this)->outputs.size());
+  return (*this)->outputs[n];
 }
 
 Tensor::Tensor(Array<PrimExpr> shape, DataType dtype, Operation op, int value_index) {
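Since `Operation::output` now builds the full `outputs` array once and returns cached entries, repeated calls yield the same `Tensor` object. A small Python-level sketch of that effect, again assuming a build with this change (names are illustrative):

from tvm import te

A = te.placeholder((8,), name="A")
B = te.compute((8,), lambda i: A[i] * 2.0, name="B")

# Both calls hit the cached Array<Tensor> on OperationNode, so reference
# comparison (same_as) and dictionary lookups stay stable across calls.
t0 = B.op.output(0)
t1 = B.op.output(0)
assert t0.same_as(t1)
assert B.same_as(t0)

d = {B.op.output(0): 1}
assert d[B] == 1  # mirrors the existing assertion in test_te_tensor.py
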
4 changes: 4 additions & 0 deletions tests/python/unittest/test_te_tensor.py
@@ -37,6 +37,8 @@ def test_tensor():
     assert T.op.output(0).__hash__() == T.__hash__()
     d = {T.op.output(0): 1}
     assert d[T] == 1
+    assert T == T.op.output(0)
+    assert T.same_as(T.op.output(0))
     assert T[0][0][0].astype("float16").dtype == "float16"


@@ -49,6 +51,8 @@ def test_rank_zero():
     print(T)
     print(T.op.body)
     assert tuple(T.shape) == ()
+    assert T == T.op.output(0)
+    assert T.same_as(T.op.output(0))


 def test_conv1d():
