Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 0 additions & 36 deletions include/tvm/tir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -1100,42 +1100,6 @@ class Reduce : public PrimExpr {
TVM_DEFINE_OBJECT_REF_COW_METHOD(ReduceNode);
};

/*! \brief Any shape. */
class AnyNode : public PrimExprNode {
public:
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &dtype);
v->Visit("span", &span);
}

bool SEqualReduce(const AnyNode* other, SEqualReducer equal) const {
return equal(dtype, other->dtype);
}

void SHashReduce(SHashReducer hash_reduce) const {}

/*! \brief Convert to var. */
Var ToVar() const { return Var("any_dim", DataType::Int(32)); }

/*! \brief Convert to SizeVar. */
SizeVar ToSizeVar() const { return SizeVar("any_dim", DataType::Int(32)); }

static constexpr const char* _type_key = "tir.Any";
TVM_DECLARE_FINAL_OBJECT_INFO(AnyNode, PrimExprNode);
};

/*!
* \brief Managed reference to AnyNode
* \sa AnyNode
*/
class Any : public PrimExpr {
public:
TVM_DLL Any(Span span = Span());

TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Any, PrimExpr, AnyNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(AnyNode);
};

/*
* \brief Template function to convert Map to unordered_map
* Sometimes useful for API gluing when internal uses unordered_map
Expand Down
4 changes: 0 additions & 4 deletions include/tvm/tir/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,6 @@ class ExprFunctor<R(const PrimExpr& n, Args...)> {
virtual R VisitExpr_(const IntImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const FloatImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const StringImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const AnyNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExprDefault_(const Object* op, Args...) {
LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
}
Expand Down Expand Up @@ -192,7 +191,6 @@ class ExprFunctor<R(const PrimExpr& n, Args...)> {
IR_EXPR_FUNCTOR_DISPATCH(IntImmNode);
IR_EXPR_FUNCTOR_DISPATCH(FloatImmNode);
IR_EXPR_FUNCTOR_DISPATCH(StringImmNode);
IR_EXPR_FUNCTOR_DISPATCH(AnyNode);
vtable.Finalize();
return vtable;
}
Expand Down Expand Up @@ -244,7 +242,6 @@ class TVM_DLL ExprVisitor : public ExprFunctor<void(const PrimExpr&)> {
void VisitExpr_(const IntImmNode* op) override;
void VisitExpr_(const FloatImmNode* op) override;
void VisitExpr_(const StringImmNode* op) override;
void VisitExpr_(const AnyNode* op) override;
};

/*!
Expand Down Expand Up @@ -290,7 +287,6 @@ class TVM_DLL ExprMutator : protected ExprFunctor<PrimExpr(const PrimExpr&)> {
PrimExpr VisitExpr_(const IntImmNode* op) override;
PrimExpr VisitExpr_(const FloatImmNode* op) override;
PrimExpr VisitExpr_(const StringImmNode* op) override;
PrimExpr VisitExpr_(const AnyNode* op) override;
};

} // namespace tir
Expand Down
3 changes: 1 addition & 2 deletions include/tvm/topi/detail/strided_slice.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ inline Array<PrimExpr> StridedSliceOutputShape(const Array<PrimExpr>& ishape,
const Array<Integer>& axes, std::string slice_mode,
const Array<PrimExpr>& begin_canonicalized,
bool use_any = false) {
ICHECK(!use_any) << "StridedSliceOutputShape does not legacy use_any";
const size_t src_tensor_dim = ishape.size();
Array<PrimExpr> out_shape;
for (size_t i = 0; i < src_tensor_dim; ++i) {
Expand All @@ -140,8 +141,6 @@ inline Array<PrimExpr> StridedSliceOutputShape(const Array<PrimExpr>& ishape,
ICHECK(strides[i] < 0 ? (end_i <= begin_i) : (begin_i <= end_i))
<< ": Input [Begin=" << begin[i] << ", End=" << end[i] << "] is invalid for axis=" << i;
out_shape.Set(axes[i].IntValue(), cast(out_shape[i].dtype(), PrimExpr(slice_size)));
} else if (use_any) {
out_shape.Set(axes[i].IntValue(), tvm::tir::Any());
} else {
out_shape.Set(axes[i].IntValue(), tvm::tir::Var("dim", out_shape[i]->dtype));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,7 @@ def _check(sinfo):
return False
unknown_dim = 0
for s in sinfo.shape.values:
if isinstance(s, (tvm.tir.Var, tvm.tir.Any)):
unknown_dim += 1
elif isinstance(s, tvm.tir.IntImm) and s < 0:
if isinstance(s, tvm.tir.IntImm) and s < 0:
unknown_dim += 1
return unknown_dim <= 1

Expand Down
6 changes: 1 addition & 5 deletions python/tvm/contrib/msc/plugin/codegen/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,11 +642,7 @@ class TVMUtils {
Array<tvm::PrimExpr> tvm_shape;
for (size_t i = 0; i < meta_shape.ndim(); i++) {
auto dim = meta_shape.DimAt(i);
if (dim == -1) {
tvm_shape.push_back(tir::Any());
} else {
tvm_shape.push_back(Integer(dim));
}
tvm_shape.push_back(Integer(dim));
}
return tvm_shape;
}
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from .expr import Add, Sub, Mul, Div, Mod, FloorDiv, FloorMod
from .expr import Min, Max, EQ, NE, LT, LE, GT, GE, And, Or, Not
from .expr import Select, BufferLoad, ProducerLoad, Ramp, Broadcast, Shuffle
from .expr import Call, CallEffectKind, Let, IterVar, CommReducer, Any
from .expr import Call, CallEffectKind, Let, IterVar, CommReducer

from .stmt import Stmt, LetStmt, AssertStmt, ForKind, For, While
from .stmt import (
Expand Down
12 changes: 0 additions & 12 deletions python/tvm/tir/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1308,15 +1308,3 @@ def __init__(
self, var: Var, value: PrimExpr, body: PrimExpr, span: Optional[Span] = None
) -> None:
self.__init_handle_by_constructor__(_ffi_api.Let, var, value, body, span) # type: ignore


@tvm._ffi.register_object("tir.Any")
class Any(PrimExprWithOp):
"""Any node.

span : Optional[Span]
The location of this expression in the source code.
"""

def __init__(self, span: Optional[Span] = None) -> None:
self.__init_handle_by_constructor__(_ffi_api.Any, span) # type: ignore
5 changes: 1 addition & 4 deletions python/tvm/topi/nn/pad.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,7 @@ def pad(data, pad_before, pad_after=None, pad_value=0.0, name="PadInput", attrs=
ana = tvm.arith.Analyzer()
dshape = []
for dim in data.shape:
if isinstance(dim, tvm.tir.Any):
dshape.append(tvm.te.size_var("dim"))
else:
dshape.append(dim)
dshape.append(dim)
out_shape = tuple(ana.simplify(dshape[i] + pad_before[i] + pad_after[i]) for i in range(n))
pad_value = (
pad_value
Expand Down
6 changes: 3 additions & 3 deletions python/tvm/topi/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import numpy as np
import tvm
from tvm import te
from tvm.tir import Any, SizeVar, bijective_layout, layout
from tvm.tir import SizeVar, bijective_layout, layout

from . import cpp, tag

Expand Down Expand Up @@ -187,7 +187,7 @@ def get_const_tuple(in_tuple):
ret = []
ana = None
for elem in in_tuple:
if isinstance(elem, (tvm.tir.Var, tvm.tir.expr.Any)):
if isinstance(elem, tvm.tir.Var):
ret.append(elem)
elif not isinstance(elem, (tvm.tir.IntImm, int)):
ana = tvm.arith.Analyzer() if ana is None else ana
Expand Down Expand Up @@ -525,4 +525,4 @@ def is_target(names):

def is_dynamic_shape(shape):
"""Checks if any part of a shape is dynamic"""
return any([isinstance(x, (Any, SizeVar)) for x in shape])
return any([isinstance(x, SizeVar) for x in shape])
6 changes: 1 addition & 5 deletions src/contrib/msc/core/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -221,11 +221,7 @@ class ArrayUtils {
TVM_DLL static const Array<T> Cast(const Array<PrimExpr>& src_array) {
Array<T> new_array;
for (const auto& s : src_array) {
if (s->IsInstance<tvm::tir::AnyNode>()) {
new_array.push_back(T(-1));
} else {
new_array.push_back(Downcast<T>(s));
}
new_array.push_back(Downcast<T>(s));
}
return new_array;
}
Expand Down
3 changes: 0 additions & 3 deletions src/script/printer/legacy_repr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -521,9 +521,6 @@ TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
(*p) << ")";
});

TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
.set_dispatch<AnyNode>([](const ObjectRef& node, ReprLegacyPrinter* p) { (*p) << "?"; });

TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
.set_dispatch<BufferLoadNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
auto* op = static_cast<const BufferLoadNode*>(node.get());
Expand Down
6 changes: 0 additions & 6 deletions src/script/printer/tir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -296,11 +296,6 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
return prefix->Call(args);
});

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::Any>("", [](tir::Any any, ObjectPath p, IRDocsifier d) -> Doc {
return TIR(d, "Any")->Call({});
});

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::Reduce>("", [](tir::Reduce r, ObjectPath p, IRDocsifier d) -> Doc {
ExprDoc combiner = d->AsDoc<ExprDoc>(r->combiner, p->Attr("combiner"));
Expand Down Expand Up @@ -415,7 +410,6 @@ TVM_SCRIPT_REPR(tir::CallNode, ReprPrintTIR);
TVM_SCRIPT_REPR(tir::ShuffleNode, ReprPrintTIR);
TVM_SCRIPT_REPR(tir::CommReducerNode, ReprPrintTIR);
TVM_SCRIPT_REPR(tir::IndexMapNode, ReprPrintTIR);
TVM_SCRIPT_REPR(tir::AnyNode, ReprPrintTIR);
TVM_SCRIPT_REPR(tir::ReduceNode, ReprPrintTIR);

} // namespace printer
Expand Down
3 changes: 0 additions & 3 deletions src/tir/analysis/deep_equal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,6 @@ bool ExprDeepEqual::operator()(const PrimExpr& lhs, const PrimExpr& rhs) const {
auto* prhs = rhs.as<IntImmNode>();
return plhs->dtype == prhs->dtype && plhs->value == prhs->value;
}
if (lhs.as<AnyNode>()) {
return false;
}
return DeepCmpSEqualHandler().SEqualReduce(lhs, rhs, false, NullOpt);
}

Expand Down
12 changes: 0 additions & 12 deletions src/tir/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -747,18 +747,6 @@ TVM_REGISTER_GLOBAL("tir.Reduce")

TVM_REGISTER_NODE_TYPE(ReduceNode);

// Any
Any::Any(Span span) {
auto n = make_object<AnyNode>();
n->dtype = DataType::Int(32);
n->span = std::move(span);
data_ = std::move(n);
}

TVM_REGISTER_GLOBAL("tir.Any").set_body_typed([](Span span) { return Any(span); });

TVM_REGISTER_NODE_TYPE(AnyNode);

// BufferLoad
void BufferLoadNode::LegalizeDType() {
for (int i = 0; i < static_cast<int>(indices.size()) - 1; i++) {
Expand Down
4 changes: 0 additions & 4 deletions src/tir/ir/expr_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@ void ExprVisitor::VisitExpr_(const SizeVarNode* op) {
this->VisitExpr_(static_cast<const VarNode*>(op));
}

void ExprVisitor::VisitExpr_(const AnyNode* op) {}

void ExprVisitor::VisitExpr_(const BufferLoadNode* op) {
VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); });
}
Expand Down Expand Up @@ -119,8 +117,6 @@ PrimExpr ExprMutator::VisitExpr_(const SizeVarNode* op) {
return this->VisitExpr_(static_cast<const VarNode*>(op));
}

PrimExpr ExprMutator::VisitExpr_(const AnyNode* op) { return GetRef<PrimExpr>(op); }

PrimExpr ExprMutator::VisitExpr_(const BufferLoadNode* op) {
auto fmutate = [this](const PrimExpr& e) { return this->VisitExpr(e); };
Array<PrimExpr> indices = op->indices.Map(fmutate);
Expand Down
2 changes: 0 additions & 2 deletions src/tir/ir/tir_visitor_with_path.cc
Original file line number Diff line number Diff line change
Expand Up @@ -343,8 +343,6 @@ void TIRVisitorWithPath::VisitExpr_(const SizeVarNode* op, ObjectPath path) {
VisitExpr_(static_cast<const VarNode*>(op), path);
}

void TIRVisitorWithPath::VisitExpr_(const AnyNode* op, ObjectPath path) {}

void TIRVisitorWithPath::VisitExpr_(const BufferLoadNode* op, ObjectPath path) {
Visit(op->buffer, path->Attr("buffer"));
Visit(op->indices, path->Attr("indices"));
Expand Down
1 change: 0 additions & 1 deletion src/tir/ir/tir_visitor_with_path.h
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,6 @@ class TIRVisitorWithPath : protected ExprFunctor<void(const PrimExpr&, ObjectPat
void VisitExpr_(const IntImmNode* op, ObjectPath path) override;
void VisitExpr_(const FloatImmNode* op, ObjectPath path) override;
void VisitExpr_(const StringImmNode* op, ObjectPath path) override;
void VisitExpr_(const AnyNode* op, ObjectPath path) override;

// Utility to call EnterDef/ExitDef. Used in the implementation of
// WithDef.
Expand Down
3 changes: 0 additions & 3 deletions tests/python/arith/test_arith_rewrite_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,6 @@ class TestAddIndex(BaseCompare):

class TestSubIndex(BaseCompare):
x, y, z = te.var("x"), te.var("y"), te.var("z")
a, b = tvm.tir.Any(), tvm.tir.Any()

test_case = tvm.testing.parameter(
TestCase(x + y - y, x),
Expand All @@ -437,8 +436,6 @@ class TestSubIndex(BaseCompare):
TestCase(y - tvm.te.max(x, y), tvm.te.min(y - x, 0)),
# mul co-efficient foldng
TestCase(x - x, 0),
TestCase(a - a, 0),
TestCase(a - b, a - b),
TestCase(x * y - x, x * (y + (-1))),
TestCase(x * y - 10 * x, x * (y + (-10))),
TestCase(y * x - x * z, x * (y - z)),
Expand Down
10 changes: 0 additions & 10 deletions tests/python/tvmscript/test_tvmscript_printer_tir.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,16 +688,6 @@ def test_comm_reducer():
)


def test_any():
obj = tir.Any()
_assert_print(
obj,
"""
T.Any()
""",
)


def test_int_imm():
obj = T.int16(1)
_assert_print(
Expand Down
Loading