diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index a157516f5342..06ee75070ce7 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -1100,42 +1100,6 @@ class Reduce : public PrimExpr { TVM_DEFINE_OBJECT_REF_COW_METHOD(ReduceNode); }; -/*! \brief Any shape. */ -class AnyNode : public PrimExprNode { - public: - void VisitAttrs(AttrVisitor* v) { - v->Visit("dtype", &dtype); - v->Visit("span", &span); - } - - bool SEqualReduce(const AnyNode* other, SEqualReducer equal) const { - return equal(dtype, other->dtype); - } - - void SHashReduce(SHashReducer hash_reduce) const {} - - /*! \brief Convert to var. */ - Var ToVar() const { return Var("any_dim", DataType::Int(32)); } - - /*! \brief Convert to SizeVar. */ - SizeVar ToSizeVar() const { return SizeVar("any_dim", DataType::Int(32)); } - - static constexpr const char* _type_key = "tir.Any"; - TVM_DECLARE_FINAL_OBJECT_INFO(AnyNode, PrimExprNode); -}; - -/*! - * \brief Managed reference to AnyNode - * \sa AnyNode - */ -class Any : public PrimExpr { - public: - TVM_DLL Any(Span span = Span()); - - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Any, PrimExpr, AnyNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(AnyNode); -}; - /* * \brief Template function to convert Map to unordered_map * Sometimes useful for API gluing when internal uses unordered_map diff --git a/include/tvm/tir/expr_functor.h b/include/tvm/tir/expr_functor.h index 7a9cf91a65af..dfa9d7e1e346 100644 --- a/include/tvm/tir/expr_functor.h +++ b/include/tvm/tir/expr_functor.h @@ -149,7 +149,6 @@ class ExprFunctor { virtual R VisitExpr_(const IntImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const FloatImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const StringImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const AnyNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExprDefault_(const Object* op, Args...) { LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); } @@ -192,7 +191,6 @@ class ExprFunctor { IR_EXPR_FUNCTOR_DISPATCH(IntImmNode); IR_EXPR_FUNCTOR_DISPATCH(FloatImmNode); IR_EXPR_FUNCTOR_DISPATCH(StringImmNode); - IR_EXPR_FUNCTOR_DISPATCH(AnyNode); vtable.Finalize(); return vtable; } @@ -244,7 +242,6 @@ class TVM_DLL ExprVisitor : public ExprFunctor { void VisitExpr_(const IntImmNode* op) override; void VisitExpr_(const FloatImmNode* op) override; void VisitExpr_(const StringImmNode* op) override; - void VisitExpr_(const AnyNode* op) override; }; /*! @@ -290,7 +287,6 @@ class TVM_DLL ExprMutator : protected ExprFunctor { PrimExpr VisitExpr_(const IntImmNode* op) override; PrimExpr VisitExpr_(const FloatImmNode* op) override; PrimExpr VisitExpr_(const StringImmNode* op) override; - PrimExpr VisitExpr_(const AnyNode* op) override; }; } // namespace tir diff --git a/include/tvm/topi/detail/strided_slice.h b/include/tvm/topi/detail/strided_slice.h index a69f8f99ae38..f2e021ed98bc 100644 --- a/include/tvm/topi/detail/strided_slice.h +++ b/include/tvm/topi/detail/strided_slice.h @@ -122,6 +122,7 @@ inline Array StridedSliceOutputShape(const Array& ishape, const Array& axes, std::string slice_mode, const Array& begin_canonicalized, bool use_any = false) { + ICHECK(!use_any) << "StridedSliceOutputShape does not legacy use_any"; const size_t src_tensor_dim = ishape.size(); Array out_shape; for (size_t i = 0; i < src_tensor_dim; ++i) { @@ -140,8 +141,6 @@ inline Array StridedSliceOutputShape(const Array& ishape, ICHECK(strides[i] < 0 ? (end_i <= begin_i) : (begin_i <= end_i)) << ": Input [Begin=" << begin[i] << ", End=" << end[i] << "] is invalid for axis=" << i; out_shape.Set(axes[i].IntValue(), cast(out_shape[i].dtype(), PrimExpr(slice_size))); - } else if (use_any) { - out_shape.Set(axes[i].IntValue(), tvm::tir::Any()); } else { out_shape.Set(axes[i].IntValue(), tvm::tir::Var("dim", out_shape[i]->dtype)); } diff --git a/python/tvm/contrib/msc/framework/tensorrt/transform/pattern.py b/python/tvm/contrib/msc/framework/tensorrt/transform/pattern.py index 17aee690e370..cd12b336ab5a 100644 --- a/python/tvm/contrib/msc/framework/tensorrt/transform/pattern.py +++ b/python/tvm/contrib/msc/framework/tensorrt/transform/pattern.py @@ -143,9 +143,7 @@ def _check(sinfo): return False unknown_dim = 0 for s in sinfo.shape.values: - if isinstance(s, (tvm.tir.Var, tvm.tir.Any)): - unknown_dim += 1 - elif isinstance(s, tvm.tir.IntImm) and s < 0: + if isinstance(s, tvm.tir.IntImm) and s < 0: unknown_dim += 1 return unknown_dim <= 1 diff --git a/python/tvm/contrib/msc/plugin/codegen/sources.py b/python/tvm/contrib/msc/plugin/codegen/sources.py index 1ea95a958f7a..3806dabd0e1e 100644 --- a/python/tvm/contrib/msc/plugin/codegen/sources.py +++ b/python/tvm/contrib/msc/plugin/codegen/sources.py @@ -642,11 +642,7 @@ class TVMUtils { Array tvm_shape; for (size_t i = 0; i < meta_shape.ndim(); i++) { auto dim = meta_shape.DimAt(i); - if (dim == -1) { - tvm_shape.push_back(tir::Any()); - } else { - tvm_shape.push_back(Integer(dim)); - } + tvm_shape.push_back(Integer(dim)); } return tvm_shape; } diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 568e05351aad..4f56ec3c15bc 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -26,7 +26,7 @@ from .expr import Add, Sub, Mul, Div, Mod, FloorDiv, FloorMod from .expr import Min, Max, EQ, NE, LT, LE, GT, GE, And, Or, Not from .expr import Select, BufferLoad, ProducerLoad, Ramp, Broadcast, Shuffle -from .expr import Call, CallEffectKind, Let, IterVar, CommReducer, Any +from .expr import Call, CallEffectKind, Let, IterVar, CommReducer from .stmt import Stmt, LetStmt, AssertStmt, ForKind, For, While from .stmt import ( diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py index 37976394f831..6cd4302133c5 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tir/expr.py @@ -1308,15 +1308,3 @@ def __init__( self, var: Var, value: PrimExpr, body: PrimExpr, span: Optional[Span] = None ) -> None: self.__init_handle_by_constructor__(_ffi_api.Let, var, value, body, span) # type: ignore - - -@tvm._ffi.register_object("tir.Any") -class Any(PrimExprWithOp): - """Any node. - - span : Optional[Span] - The location of this expression in the source code. - """ - - def __init__(self, span: Optional[Span] = None) -> None: - self.__init_handle_by_constructor__(_ffi_api.Any, span) # type: ignore diff --git a/python/tvm/topi/nn/pad.py b/python/tvm/topi/nn/pad.py index 7bd2b7632b9c..8833ef38d694 100644 --- a/python/tvm/topi/nn/pad.py +++ b/python/tvm/topi/nn/pad.py @@ -59,10 +59,7 @@ def pad(data, pad_before, pad_after=None, pad_value=0.0, name="PadInput", attrs= ana = tvm.arith.Analyzer() dshape = [] for dim in data.shape: - if isinstance(dim, tvm.tir.Any): - dshape.append(tvm.te.size_var("dim")) - else: - dshape.append(dim) + dshape.append(dim) out_shape = tuple(ana.simplify(dshape[i] + pad_before[i] + pad_after[i]) for i in range(n)) pad_value = ( pad_value diff --git a/python/tvm/topi/utils.py b/python/tvm/topi/utils.py index 71599ad74a62..3a0441ef84af 100644 --- a/python/tvm/topi/utils.py +++ b/python/tvm/topi/utils.py @@ -23,7 +23,7 @@ import numpy as np import tvm from tvm import te -from tvm.tir import Any, SizeVar, bijective_layout, layout +from tvm.tir import SizeVar, bijective_layout, layout from . import cpp, tag @@ -187,7 +187,7 @@ def get_const_tuple(in_tuple): ret = [] ana = None for elem in in_tuple: - if isinstance(elem, (tvm.tir.Var, tvm.tir.expr.Any)): + if isinstance(elem, tvm.tir.Var): ret.append(elem) elif not isinstance(elem, (tvm.tir.IntImm, int)): ana = tvm.arith.Analyzer() if ana is None else ana @@ -525,4 +525,4 @@ def is_target(names): def is_dynamic_shape(shape): """Checks if any part of a shape is dynamic""" - return any([isinstance(x, (Any, SizeVar)) for x in shape]) + return any([isinstance(x, SizeVar) for x in shape]) diff --git a/src/contrib/msc/core/utils.h b/src/contrib/msc/core/utils.h index 41566883036c..84e3c667410c 100644 --- a/src/contrib/msc/core/utils.h +++ b/src/contrib/msc/core/utils.h @@ -221,11 +221,7 @@ class ArrayUtils { TVM_DLL static const Array Cast(const Array& src_array) { Array new_array; for (const auto& s : src_array) { - if (s->IsInstance()) { - new_array.push_back(T(-1)); - } else { - new_array.push_back(Downcast(s)); - } + new_array.push_back(Downcast(s)); } return new_array; } diff --git a/src/script/printer/legacy_repr.cc b/src/script/printer/legacy_repr.cc index 084a86a6f5c7..b047657b3df2 100644 --- a/src/script/printer/legacy_repr.cc +++ b/src/script/printer/legacy_repr.cc @@ -521,9 +521,6 @@ TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) (*p) << ")"; }); -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { (*p) << "?"; }); - TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { auto* op = static_cast(node.get()); diff --git a/src/script/printer/tir/expr.cc b/src/script/printer/tir/expr.cc index 8268e6b35ecb..8ac093149659 100644 --- a/src/script/printer/tir/expr.cc +++ b/src/script/printer/tir/expr.cc @@ -296,11 +296,6 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return prefix->Call(args); }); -TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tir::Any any, ObjectPath p, IRDocsifier d) -> Doc { - return TIR(d, "Any")->Call({}); - }); - TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](tir::Reduce r, ObjectPath p, IRDocsifier d) -> Doc { ExprDoc combiner = d->AsDoc(r->combiner, p->Attr("combiner")); @@ -415,7 +410,6 @@ TVM_SCRIPT_REPR(tir::CallNode, ReprPrintTIR); TVM_SCRIPT_REPR(tir::ShuffleNode, ReprPrintTIR); TVM_SCRIPT_REPR(tir::CommReducerNode, ReprPrintTIR); TVM_SCRIPT_REPR(tir::IndexMapNode, ReprPrintTIR); -TVM_SCRIPT_REPR(tir::AnyNode, ReprPrintTIR); TVM_SCRIPT_REPR(tir::ReduceNode, ReprPrintTIR); } // namespace printer diff --git a/src/tir/analysis/deep_equal.cc b/src/tir/analysis/deep_equal.cc index 1ec9fc5522c8..d4e0284343ef 100644 --- a/src/tir/analysis/deep_equal.cc +++ b/src/tir/analysis/deep_equal.cc @@ -65,9 +65,6 @@ bool ExprDeepEqual::operator()(const PrimExpr& lhs, const PrimExpr& rhs) const { auto* prhs = rhs.as(); return plhs->dtype == prhs->dtype && plhs->value == prhs->value; } - if (lhs.as()) { - return false; - } return DeepCmpSEqualHandler().SEqualReduce(lhs, rhs, false, NullOpt); } diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index ca28520f8f77..b52c85df3575 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -747,18 +747,6 @@ TVM_REGISTER_GLOBAL("tir.Reduce") TVM_REGISTER_NODE_TYPE(ReduceNode); -// Any -Any::Any(Span span) { - auto n = make_object(); - n->dtype = DataType::Int(32); - n->span = std::move(span); - data_ = std::move(n); -} - -TVM_REGISTER_GLOBAL("tir.Any").set_body_typed([](Span span) { return Any(span); }); - -TVM_REGISTER_NODE_TYPE(AnyNode); - // BufferLoad void BufferLoadNode::LegalizeDType() { for (int i = 0; i < static_cast(indices.size()) - 1; i++) { diff --git a/src/tir/ir/expr_functor.cc b/src/tir/ir/expr_functor.cc index 3c117b58a7a3..05e333b78ac6 100644 --- a/src/tir/ir/expr_functor.cc +++ b/src/tir/ir/expr_functor.cc @@ -32,8 +32,6 @@ void ExprVisitor::VisitExpr_(const SizeVarNode* op) { this->VisitExpr_(static_cast(op)); } -void ExprVisitor::VisitExpr_(const AnyNode* op) {} - void ExprVisitor::VisitExpr_(const BufferLoadNode* op) { VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); }); } @@ -119,8 +117,6 @@ PrimExpr ExprMutator::VisitExpr_(const SizeVarNode* op) { return this->VisitExpr_(static_cast(op)); } -PrimExpr ExprMutator::VisitExpr_(const AnyNode* op) { return GetRef(op); } - PrimExpr ExprMutator::VisitExpr_(const BufferLoadNode* op) { auto fmutate = [this](const PrimExpr& e) { return this->VisitExpr(e); }; Array indices = op->indices.Map(fmutate); diff --git a/src/tir/ir/tir_visitor_with_path.cc b/src/tir/ir/tir_visitor_with_path.cc index e0318b21bee3..4f5007aedb3f 100644 --- a/src/tir/ir/tir_visitor_with_path.cc +++ b/src/tir/ir/tir_visitor_with_path.cc @@ -343,8 +343,6 @@ void TIRVisitorWithPath::VisitExpr_(const SizeVarNode* op, ObjectPath path) { VisitExpr_(static_cast(op), path); } -void TIRVisitorWithPath::VisitExpr_(const AnyNode* op, ObjectPath path) {} - void TIRVisitorWithPath::VisitExpr_(const BufferLoadNode* op, ObjectPath path) { Visit(op->buffer, path->Attr("buffer")); Visit(op->indices, path->Attr("indices")); diff --git a/src/tir/ir/tir_visitor_with_path.h b/src/tir/ir/tir_visitor_with_path.h index 1ae6df58f760..61441541da32 100644 --- a/src/tir/ir/tir_visitor_with_path.h +++ b/src/tir/ir/tir_visitor_with_path.h @@ -152,7 +152,6 @@ class TIRVisitorWithPath : protected ExprFunctor