From 1d879463163330ddf7fdc4930429a5792972d87f Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 14 Jan 2020 13:45:28 -0800 Subject: [PATCH] [REFACTOR][IR] Remove UIntImm to use IntImm --- include/tvm/attrs.h | 4 -- include/tvm/expr_operator.h | 25 +---------- include/tvm/ir.h | 17 -------- include/tvm/ir_functor_ext.h | 4 -- python/tvm/autotvm/task/task.py | 4 +- python/tvm/autotvm/util.py | 8 ++-- python/tvm/expr.py | 17 -------- python/tvm/hybrid/calls.py | 2 +- python/tvm/hybrid/parser.py | 4 +- python/tvm/hybrid/util.py | 2 +- python/tvm/relay/frontend/tensorflow.py | 2 +- src/api/api_ir.cc | 1 - src/arithmetic/analyzer.cc | 6 +-- src/arithmetic/const_fold.h | 43 ++++++++----------- src/arithmetic/const_int_bound.cc | 8 ---- src/arithmetic/int_set.cc | 4 -- src/arithmetic/modular_set.cc | 8 ---- src/codegen/codegen_c.cc | 8 +--- src/codegen/codegen_c.h | 1 - src/codegen/codegen_opengl.cc | 5 --- src/codegen/codegen_opengl.h | 1 - src/codegen/llvm/codegen_arm.cc | 22 +++++----- src/codegen/llvm/codegen_llvm.cc | 12 ++---- src/codegen/llvm/codegen_llvm.h | 1 - src/codegen/llvm/intrin_rule_llvm.h | 8 ++-- src/codegen/spirv/codegen_spirv.cc | 7 +-- src/codegen/spirv/codegen_spirv.h | 1 - src/codegen/spirv/intrin_rule_spirv.cc | 2 +- src/codegen/stackvm/codegen_stackvm.cc | 6 --- src/codegen/stackvm/codegen_stackvm.h | 1 - src/contrib/hybrid/codegen_hybrid.cc | 5 +-- src/contrib/hybrid/codegen_hybrid.h | 1 - src/lang/attr_functor.h | 4 -- src/lang/attrs.cc | 11 ----- src/lang/expr_operator.cc | 25 ++--------- src/lang/ir.cc | 14 ------ src/pass/arg_binder.cc | 6 +-- src/pass/ir_deep_compare.cc | 4 -- src/pass/ir_functor.cc | 2 - src/pass/lift_attr_scope.cc | 3 -- src/pass/lower_thread_allreduce.cc | 2 +- src/pass/rewrite_unsafe_select.cc | 1 - src/pass/unroll_loop.cc | 4 -- src/relay/ir/pretty_printer.cc | 4 -- src/relay/pass/type_solver.cc | 2 +- src/relay/qnn/util.h | 12 +----- tests/cpp/pattern_match_test.cc | 4 +- tests/python/unittest/test_hybrid_script.py | 2 +- .../python/unittest/test_lang_constructor.py | 7 +-- tests/python/unittest/test_lang_operator.py | 2 +- topi/include/topi/detail/constant_utils.h | 10 ++--- topi/python/topi/util.py | 12 +++--- 52 files changed, 82 insertions(+), 289 deletions(-) diff --git a/include/tvm/attrs.h b/include/tvm/attrs.h index ab9a711d28d89..9d9f98e79695b 100644 --- a/include/tvm/attrs.h +++ b/include/tvm/attrs.h @@ -490,8 +490,6 @@ inline void SetIntValue(T* ptr, const TVMArgValue& val) { CHECK(expr.defined()); if (const ir::IntImmNode* op = expr.as()) { *ptr = static_cast(op->value); - } else if (const ir::UIntImmNode* op = expr.as()) { - *ptr = static_cast(op->value); } else { LOG(FATAL) << "Expect int value, but get " << expr->GetTypeKey(); } @@ -523,8 +521,6 @@ inline void SetValue(double* ptr, const TVMArgValue& val) { *ptr = static_cast(op->value); } else if (const ir::IntImmNode* op = expr.as()) { *ptr = static_cast(op->value); - } else if (const ir::UIntImmNode* op = expr.as()) { - *ptr = static_cast(op->value); } else { LOG(FATAL) << "Expect float value, but get " << expr->GetTypeKey(); } diff --git a/include/tvm/expr_operator.h b/include/tvm/expr_operator.h index cbcb72a151e45..6ca3cea3a2df3 100644 --- a/include/tvm/expr_operator.h +++ b/include/tvm/expr_operator.h @@ -82,21 +82,6 @@ inline const int64_t* as_const_int(const PrimExpr& x) { } } -/*! - * \brief Get x as constant uint expression. - * \param x The expression - * \return the address to the int expression, - * return nullptr, if x is not UIntImm. - */ -inline const uint64_t* as_const_uint(const PrimExpr& x) { - if (!x.defined()) return nullptr; - if (const ir::UIntImmNode* op = x.as()) { - return &(op->value); - } else { - return nullptr; - } -} - /*! * \brief Check whether x is a constant integer expression. * \param x The input argument @@ -626,11 +611,11 @@ TVM_DECLARE_INTRIN_UNARY(atan); // Implementation details after this inline bool is_const(const PrimExpr& x) { - if (x.as() || x.as()) { + if (x.as()) { return true; } else if (const auto* op = x.as()) { const PrimExpr& val = op->value; - if (val.as() || val.as()) { + if (val.as()) { return true; } } @@ -640,8 +625,6 @@ inline bool is_const(const PrimExpr& x) { inline bool is_positive_const(const PrimExpr& a) { if (const ir::IntImmNode* op = a.as()) { return op->value > 0; - } else if (const ir::UIntImmNode* op = a.as()) { - return op->value > 0; } else { return false; } @@ -658,14 +641,10 @@ inline bool is_negative_const(const PrimExpr& a) { inline bool is_const_int(const PrimExpr& x, int64_t value) { if (const auto* op = x.as()) { return op->value == value; - } else if (const auto* op = x.as()) { - return op->value == static_cast(value); } else if (const auto* op = x.as()) { const PrimExpr& val = op->value; if (const auto* opv = val.as()) { return opv->value == value; - } else if (const auto* opv = val.as()) { - return opv->value == static_cast(value); } } return false; diff --git a/include/tvm/ir.h b/include/tvm/ir.h index c637d055928cb..20ebd92fc4239 100644 --- a/include/tvm/ir.h +++ b/include/tvm/ir.h @@ -39,23 +39,6 @@ namespace ir { using IntImmNode = tvm::IntImmNode; using VarNode = tvm::VarNode; -/*! \brief constant unsigned integer. */ -class UIntImmNode : public PrimExprNode { - public: - /*! \brief The constant value content. */ - uint64_t value; - - void VisitAttrs(AttrVisitor* v) { - v->Visit("dtype", &dtype); - v->Visit("value", &value); - } - - TVM_DLL static PrimExpr make(DataType t, uint64_t value); - - static constexpr const char* _type_key = "UIntImm"; - TVM_DECLARE_FINAL_OBJECT_INFO(UIntImmNode, PrimExprNode); -}; - /*! \brief Floating point constants. */ class FloatImmNode : public PrimExprNode { public: diff --git a/include/tvm/ir_functor_ext.h b/include/tvm/ir_functor_ext.h index 7d57564fd3df3..37a1fe4bffb2c 100644 --- a/include/tvm/ir_functor_ext.h +++ b/include/tvm/ir_functor_ext.h @@ -161,7 +161,6 @@ class ExprFunctor { virtual R VisitExpr_(const BroadcastNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const ShuffleNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const IntImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const UIntImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const FloatImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const StringImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExprDefault_(const Object* op, Args ...) { @@ -203,7 +202,6 @@ class ExprFunctor { IR_EXPR_FUNCTOR_DISPATCH(ShuffleNode); IR_EXPR_FUNCTOR_DISPATCH(BroadcastNode); IR_EXPR_FUNCTOR_DISPATCH(IntImmNode); - IR_EXPR_FUNCTOR_DISPATCH(UIntImmNode); IR_EXPR_FUNCTOR_DISPATCH(FloatImmNode); IR_EXPR_FUNCTOR_DISPATCH(StringImmNode); return vtable; @@ -327,7 +325,6 @@ class TVM_DLL ExprVisitor : void VisitExpr_(const BroadcastNode* op) override; void VisitExpr_(const ShuffleNode* op) override; void VisitExpr_(const IntImmNode* op) override; - void VisitExpr_(const UIntImmNode* op) override; void VisitExpr_(const FloatImmNode* op) override; void VisitExpr_(const StringImmNode* op) override; }; @@ -372,7 +369,6 @@ class TVM_DLL ExprMutator : PrimExpr VisitExpr_(const BroadcastNode* op) override; PrimExpr VisitExpr_(const ShuffleNode* op) override; PrimExpr VisitExpr_(const IntImmNode* op) override; - PrimExpr VisitExpr_(const UIntImmNode* op) override; PrimExpr VisitExpr_(const FloatImmNode* op) override; PrimExpr VisitExpr_(const StringImmNode* op) override; }; diff --git a/python/tvm/autotvm/task/task.py b/python/tvm/autotvm/task/task.py index 7f36914eb0a67..5067277d32a87 100644 --- a/python/tvm/autotvm/task/task.py +++ b/python/tvm/autotvm/task/task.py @@ -221,7 +221,7 @@ def args_to_workload(x, topi_compute_func=None): workload = tuple([args_to_workload(a) for a in x]) elif isinstance(x, (str, int, float, np.int, np.float, expr.Var)): workload = x - elif isinstance(x, (expr.StringImm, expr.UIntImm, expr.IntImm, expr.FloatImm)): + elif isinstance(x, (expr.StringImm, expr.IntImm, expr.FloatImm)): workload = x.value elif x is None: workload = 0 @@ -344,7 +344,7 @@ def _count_flop(exp): if len(source) != 1: raise FlopCalculationError("Found multiple output in the source of reduce op") return num_iter * (_count_flop(combiner[0]) + _count_flop(source[0])) - if isinstance(exp, (expr.FloatImm, expr.IntImm, expr.UIntImm)): + if isinstance(exp, (expr.FloatImm, expr.IntImm)): return 0 if isinstance(exp, expr.Cast): return _count_flop(exp.value) diff --git a/python/tvm/autotvm/util.py b/python/tvm/autotvm/util.py index 3026914aed209..54001d3338ad7 100644 --- a/python/tvm/autotvm/util.py +++ b/python/tvm/autotvm/util.py @@ -155,9 +155,9 @@ def get_const_int(exp): """ if isinstance(exp, int): return exp - if not isinstance(exp, (expr.IntImm, expr.UIntImm)): + if not isinstance(exp, (expr.IntImm,)): exp = ir_pass.Simplify(exp) - if not isinstance(exp, (expr.IntImm, expr.UIntImm)): + if not isinstance(exp, (expr.IntImm,)): raise ValueError("Expect value to be constant int") return exp.value @@ -179,9 +179,9 @@ def get_const_tuple(in_tuple): for elem in in_tuple: if isinstance(elem, expr.Var): ret.append(elem) - elif not isinstance(elem, (expr.IntImm, expr.UIntImm, int)): + elif not isinstance(elem, (expr.IntImm, int)): elem = ir_pass.Simplify(elem) - if not isinstance(elem, (expr.IntImm, expr.UIntImm)): + if not isinstance(elem, (expr.IntImm)): ret.append(elem) else: ret.append(get_const_int(elem)) diff --git a/python/tvm/expr.py b/python/tvm/expr.py index 71c0aecd1f6a6..2fd7b78d9d66e 100644 --- a/python/tvm/expr.py +++ b/python/tvm/expr.py @@ -341,23 +341,6 @@ def __int__(self): return self.value -@register_object -class UIntImm(ConstExpr): - """UInt constant. - - Parameters - ---------- - dtype : str - The data type - - value : int - The constant value. - """ - def __init__(self, dtype, value): - self.__init_handle_by_constructor__( - _make.UIntImm, dtype, value) - - @register_object class StringImm(ConstExpr): """String constant. diff --git a/python/tvm/hybrid/calls.py b/python/tvm/hybrid/calls.py index 1d5612e67e80d..7038f6144db34 100644 --- a/python/tvm/hybrid/calls.py +++ b/python/tvm/hybrid/calls.py @@ -156,6 +156,6 @@ def max_num_threads(func_id, args): if args.__len__() == 0: res = _tgt.current_target().max_num_threads else: - _internal_assert(isinstance(args[0], _expr.UIntImm), "In tvm bool should be uint") + _internal_assert(isinstance(args[0], _expr.IntImm), "In tvm bool should be uint") res = _tgt.current_target(args[0].value).max_num_threads return _api.convert(res) diff --git a/python/tvm/hybrid/parser.py b/python/tvm/hybrid/parser.py index 06bcbcabe0c3e..57d6363288160 100644 --- a/python/tvm/hybrid/parser.py +++ b/python/tvm/hybrid/parser.py @@ -386,7 +386,7 @@ def visit_Subscript(self, node): if isinstance(i, numbers.Integral): arr = arr[i] else: - _internal_assert(isinstance(i, (_expr.IntImm, _expr.UIntImm)), \ + _internal_assert(isinstance(i, (_expr.IntImm,)), \ "All indices are supposed to be constants") arr = arr[i.value] return arr @@ -413,7 +413,7 @@ def visit_If(self, node): cond = _ir_pass.CanonicalSimplify(self.visit(node.test)) # Return no IfThenElse if proven - if isinstance(cond, _expr.UIntImm): + if isinstance(cond, _expr.IntImm): if cond.value: return visit_list_to_block(self.visit, node.body) if node.orelse: diff --git a/python/tvm/hybrid/util.py b/python/tvm/hybrid/util.py index 0dd1fa1413299..a08a380dd7678 100644 --- a/python/tvm/hybrid/util.py +++ b/python/tvm/hybrid/util.py @@ -33,7 +33,7 @@ #pylint: disable=invalid-name np_arg_types = tuple(list(numeric_types) + [numpy.ndarray]) tvm_arg_types = (Tensor, Array, _expr.Var, _expr.ConstExpr) -halide_imm_types = (_expr.IntImm, _expr.FloatImm, _expr.UIntImm) +halide_imm_types = (_expr.IntImm, _expr.FloatImm) def _internal_assert(cond, err): diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 7e22d72131ac4..e7f4682e7eb2d 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -931,7 +931,7 @@ def _shape(): def _impl(inputs, attr, params): is_symbolic_shape = False for axis in attr['_input_shapes'][inputs[0]]: - if not isinstance(axis, (int, tvm.expr.IntImm, tvm.expr.UIntImm)): + if not isinstance(axis, (int, tvm.expr.IntImm)): is_symbolic_shape = True break diff --git a/src/api/api_ir.cc b/src/api/api_ir.cc index 049b6ee38d48c..30ca51592c8fc 100644 --- a/src/api/api_ir.cc +++ b/src/api/api_ir.cc @@ -130,7 +130,6 @@ TVM_REGISTER_GLOBAL("make.CommReducer") REGISTER_MAKE(Reduce); REGISTER_MAKE(AttrStmt); -REGISTER_MAKE(UIntImm); REGISTER_MAKE(FloatImm); REGISTER_MAKE(StringImm); diff --git a/src/arithmetic/analyzer.cc b/src/arithmetic/analyzer.cc index 7a3baa678352b..e03e5e2387bfd 100644 --- a/src/arithmetic/analyzer.cc +++ b/src/arithmetic/analyzer.cc @@ -87,15 +87,15 @@ bool Analyzer::CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound) { } bool Analyzer::CanProve(const PrimExpr& expr) { - if (const auto* ptr = expr.as()) { + if (const auto* ptr = expr.as()) { return ptr->value != 0; } auto res = this->rewrite_simplify(expr); - if (const auto* ptr = res.as()) { + if (const auto* ptr = res.as()) { return ptr->value != 0; } res = this->canonical_simplify(expr); - if (const auto* ptr = res.as()) { + if (const auto* ptr = res.as()) { return ptr->value != 0; } return false; diff --git a/src/arithmetic/const_fold.h b/src/arithmetic/const_fold.h index 2bee70ed557a3..3b803ecd84a20 100644 --- a/src/arithmetic/const_fold.h +++ b/src/arithmetic/const_fold.h @@ -76,8 +76,6 @@ inline bool IsIndexType(const DataType& type) { #define TVM_ARITH_CONST_PROPAGATION(BODY) \ - using ir::IntImmNode; \ - using ir::UIntImmNode; \ using ir::FloatImmNode; \ const IntImmNode* pa = a.as(); \ const IntImmNode* pb = b.as(); \ @@ -87,8 +85,6 @@ inline bool IsIndexType(const DataType& type) { #define TVM_INDEX_CONST_PROPAGATION(BODY) \ - using ir::IntImmNode; \ - using ir::UIntImmNode; \ const IntImmNode* pa = a.as(); \ const IntImmNode* pb = b.as(); \ const DataType& ta = a.dtype(); \ @@ -268,8 +264,8 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { template<> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value > pb->value); - if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value > fb->value); + if (pa && pb) return IntImm(DataType::UInt(1), pa->value > pb->value); + if (fa && fb) return IntImm(DataType::UInt(1), fa->value > fb->value); }); return PrimExpr(); } @@ -277,8 +273,8 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { template<> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value >= pb->value); - if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value >= fb->value); + if (pa && pb) return IntImm(DataType::UInt(1), pa->value >= pb->value); + if (fa && fb) return IntImm(DataType::UInt(1), fa->value >= fb->value); }); return PrimExpr(); } @@ -286,8 +282,8 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { template<> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value < pb->value); - if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value < fb->value); + if (pa && pb) return IntImm(DataType::UInt(1), pa->value < pb->value); + if (fa && fb) return IntImm(DataType::UInt(1), fa->value < fb->value); }); return PrimExpr(); } @@ -295,8 +291,8 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { template<> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value <= pb->value); - if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value <= fb->value); + if (pa && pb) return IntImm(DataType::UInt(1), pa->value <= pb->value); + if (fa && fb) return IntImm(DataType::UInt(1), fa->value <= fb->value); }); return PrimExpr(); } @@ -304,8 +300,8 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { template<> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value == pb->value); - if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value == fb->value); + if (pa && pb) return IntImm(DataType::UInt(1), pa->value == pb->value); + if (fa && fb) return IntImm(DataType::UInt(1), fa->value == fb->value); }); return PrimExpr(); } @@ -313,17 +309,16 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { template<> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value != pb->value); - if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value != fb->value); + if (pa && pb) return IntImm(DataType::UInt(1), pa->value != pb->value); + if (fa && fb) return IntImm(DataType::UInt(1), fa->value != fb->value); }); return PrimExpr(); } template<> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { - using ir::UIntImmNode; - const UIntImmNode* pa = a.as(); - const UIntImmNode* pb = b.as(); + const IntImmNode* pa = a.as(); + const IntImmNode* pb = b.as(); if (pa && pa->value) return b; if (pa && !pa->value) return a; if (pb && pb->value) return a; @@ -333,9 +328,8 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { template<> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { - using ir::UIntImmNode; - const UIntImmNode* pa = a.as(); - const UIntImmNode* pb = b.as(); + const IntImmNode* pa = a.as(); + const IntImmNode* pb = b.as(); if (pa && pa->value) return a; if (pa && !pa->value) return b; if (pb && pb->value) return b; @@ -345,10 +339,9 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { template<> inline PrimExpr TryConstFold(PrimExpr a) { - using ir::UIntImmNode; - const UIntImmNode* pa = a.as(); + const IntImmNode* pa = a.as(); if (pa) { - return UIntImmNode::make(DataType::UInt(1), !(pa->value)); + return IntImm(DataType::UInt(1), !(pa->value)); } return PrimExpr(); } diff --git a/src/arithmetic/const_int_bound.cc b/src/arithmetic/const_int_bound.cc index 3a85c39aa3f02..25d88d3429b6f 100644 --- a/src/arithmetic/const_int_bound.cc +++ b/src/arithmetic/const_int_bound.cc @@ -150,14 +150,6 @@ class ConstIntBoundAnalyzer::Impl : return MakeBound(op->value, op->value); } - Entry VisitExpr_(const UIntImmNode* op) final { - if (op->value <= static_cast(kPosInf)) { - return MakeBound(op->value, op->value); - } else { - return Everything(op->dtype); - } - } - Entry VisitExpr_(const AddNode* op) final { Entry a = VisitExpr(op->a); Entry b = VisitExpr(op->b); diff --git a/src/arithmetic/int_set.cc b/src/arithmetic/int_set.cc index 9b1ab3d639071..37d5e9eb5e57e 100644 --- a/src/arithmetic/int_set.cc +++ b/src/arithmetic/int_set.cc @@ -384,10 +384,6 @@ class IntervalSetEvaluator : return IntervalSet::SinglePoint(GetRef(op)); } - IntervalSet VisitExpr_(const UIntImmNode* op) final { - return IntervalSet::SinglePoint(GetRef(op)); - } - IntervalSet VisitExpr_(const VarNode* op) final { Var var = GetRef(op); auto it = dom_map_.find(var); diff --git a/src/arithmetic/modular_set.cc b/src/arithmetic/modular_set.cc index 972c5148134fe..c81842035c9f3 100644 --- a/src/arithmetic/modular_set.cc +++ b/src/arithmetic/modular_set.cc @@ -132,14 +132,6 @@ class ModularSetAnalyzer::Impl : return Entry(0, op->value); } - Entry VisitExpr_(const UIntImmNode* op) final { - if (op->value < std::numeric_limits::max()) { - return Entry(0, static_cast(op->value)); - } else { - return Everything(); - } - } - Entry VisitExpr_(const AddNode* op) final { Entry a = VisitExpr(op->a); Entry b = VisitExpr(op->b); diff --git a/src/codegen/codegen_c.cc b/src/codegen/codegen_c.cc index eae15248751b8..906631368f746 100644 --- a/src/codegen/codegen_c.cc +++ b/src/codegen/codegen_c.cc @@ -386,10 +386,6 @@ inline void PrintUIntConst(DataType dtype, uint64_t val, std::ostream& os, CodeG } } -inline void PrintConst(const UIntImmNode* op, std::ostream& os, CodeGenC* p) { // NOLINT(*) - PrintUIntConst(op->dtype, op->value, os, p); -} - inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenC* p) { // NOLINT(*) switch (op->dtype.bits()) { case 64: case 32: { @@ -413,9 +409,7 @@ inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenC* p) { void CodeGenC::VisitExpr_(const IntImmNode* op, std::ostream& os) { // NOLINT(*) PrintConst(op, os, this); } -void CodeGenC::VisitExpr_(const UIntImmNode* op, std::ostream& os) { // NOLINT(*) - PrintConst(op, os, this); -} + void CodeGenC::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*) PrintConst(op, os, this); } diff --git a/src/codegen/codegen_c.h b/src/codegen/codegen_c.h index cb092c566322a..7e5dd4269c942 100644 --- a/src/codegen/codegen_c.h +++ b/src/codegen/codegen_c.h @@ -128,7 +128,6 @@ class CodeGenC : void VisitExpr_(const ShuffleNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const BroadcastNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const IntImmNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const UIntImmNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const FloatImmNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const StringImmNode* op, std::ostream& os) override; // NOLINT(*) // statment diff --git a/src/codegen/codegen_opengl.cc b/src/codegen/codegen_opengl.cc index 7967c1847ac28..cea276d5cb1a7 100644 --- a/src/codegen/codegen_opengl.cc +++ b/src/codegen/codegen_opengl.cc @@ -247,11 +247,6 @@ void CodeGenOpenGL::VisitExpr_(const IntImmNode* op, std::ostream& os) { CodeGenC::VisitExpr_(op, os); } -void CodeGenOpenGL::VisitExpr_(const UIntImmNode* op, std::ostream& os) { - CHECK_EQ(op->dtype, DataType::UInt(32)) << "GLSL 3.0 only supports 32-bit uints."; - CodeGenC::VisitExpr_(op, os); -} - void CodeGenOpenGL::VisitExpr_(const FloatImmNode* op, std::ostream& os) { CHECK_EQ(op->dtype, DataType::Float(32)) << "GLSL 3.0 only supports 32-bit floats."; CodeGenC::VisitExpr_(op, os); diff --git a/src/codegen/codegen_opengl.h b/src/codegen/codegen_opengl.h index cd1ec83360c68..19ca2ee12c6c3 100644 --- a/src/codegen/codegen_opengl.h +++ b/src/codegen/codegen_opengl.h @@ -50,7 +50,6 @@ class CodeGenOpenGL final : public CodeGenC { // Codegen for immediate values void VisitExpr_(const IntImmNode* op, std::ostream& os) final; // NOLINT(*) - void VisitExpr_(const UIntImmNode* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const FloatImmNode* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const StringImmNode* op, std::ostream& os) final; // NOLINT(*) diff --git a/src/codegen/llvm/codegen_arm.cc b/src/codegen/llvm/codegen_arm.cc index 6879fd5f8542b..44862cf7a97ca 100644 --- a/src/codegen/llvm/codegen_arm.cc +++ b/src/codegen/llvm/codegen_arm.cc @@ -48,7 +48,7 @@ class CodeGenARM final : public CodeGenCPU { llvm::Value* CodeGenARM::CreateIntrinsic(const CallNode* op) { if (op->is_intrinsic("llvm_intrin")) { llvm::Intrinsic::ID id = static_cast( - op->args[0].as()->value); + Downcast(op->args[0])->value); if (id == ::llvm::Intrinsic::ctpop) { PrimExpr e = ARMPopcount(op); return CodeGenCPU::CreateIntrinsic(e.as()); @@ -68,8 +68,8 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode *call) { if (!call->dtype.is_vector() || call->dtype.bits() == 8 || (total_size != 128 && total_size != 64)) { Array vcnt_args; - vcnt_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), ctpop_id)); - vcnt_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1)); + vcnt_args.push_back(IntImm(DataType::UInt(32), ctpop_id)); + vcnt_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt_args.push_back(e); return ir::CallNode::make(call->dtype, "llvm_intrin", vcnt_args, CallNode::PureIntrinsic); } @@ -93,16 +93,16 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode *call) { const CallNode* c0 = input8.as(); CHECK(c0 != nullptr); Array vcnt8_args; - vcnt8_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), ctpop_id)); - vcnt8_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1)); + vcnt8_args.push_back(IntImm(DataType::UInt(32), ctpop_id)); + vcnt8_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt8_args.push_back(input8); PrimExpr vcnt8 = ir::CallNode::make( uint8_type, "llvm_intrin", vcnt8_args, CallNode::PureIntrinsic); // Accumulation 8->16bit Array vcnt16_args; - vcnt16_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), vpaddlu_id)); - vcnt16_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1)); + vcnt16_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id)); + vcnt16_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt16_args.push_back(vcnt8); PrimExpr vcnt16 = ir::CallNode::make( uint16_type, "llvm_intrin", vcnt16_args, CallNode::PureIntrinsic); @@ -112,8 +112,8 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode *call) { // Accumulation 16->32bit Array vcnt32_args; - vcnt32_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), vpaddlu_id)); - vcnt32_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1)); + vcnt32_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id)); + vcnt32_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt32_args.push_back(vcnt16); PrimExpr vcnt32 = ir::CallNode::make( uint32_type, "llvm_intrin", vcnt32_args, CallNode::PureIntrinsic); @@ -123,8 +123,8 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode *call) { // Accumulation 32->64bit Array vcnt64_args; - vcnt64_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), vpaddlu_id)); - vcnt64_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1)); + vcnt64_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id)); + vcnt64_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt64_args.push_back(vcnt32); return ir::CallNode::make( call->dtype, "llvm_intrin", vcnt64_args, CallNode::PureIntrinsic); diff --git a/src/codegen/llvm/codegen_llvm.cc b/src/codegen/llvm/codegen_llvm.cc index 20edd0a901a79..75982cc218485 100644 --- a/src/codegen/llvm/codegen_llvm.cc +++ b/src/codegen/llvm/codegen_llvm.cc @@ -662,15 +662,13 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { if (op->is_intrinsic("llvm_intrin")) { CHECK_GE(op->args.size(), 2U); llvm::Intrinsic::ID id = static_cast( - op->args[0].as()->value); - const uint64_t *num_signature = as_const_uint(op->args[1]); - CHECK(num_signature) << "The second argument should be a uint represents number of arguments, " - << "but " << op->args[1] << " got!\n"; + Downcast(op->args[0])->value); + int64_t num_signature = Downcast(op->args[1])->value; std::vector arg_value; std::vector sig_type; for (size_t i = 2; i < op->args.size(); ++i) { arg_value.push_back(MakeValue(op->args[i])); - if (i - 2 < *num_signature) { + if (i - 2 < static_cast(num_signature)) { sig_type.push_back(arg_value.back()->getType()); } } @@ -810,10 +808,6 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const IntImmNode* op) { return llvm::ConstantInt::getSigned(LLVMType(op->dtype), op->value); } -llvm::Value* CodeGenLLVM::VisitExpr_(const UIntImmNode* op) { - return llvm::ConstantInt::get(LLVMType(op->dtype), op->value); -} - llvm::Value* CodeGenLLVM::VisitExpr_(const FloatImmNode* op) { return llvm::ConstantFP::get(LLVMType(op->dtype), op->value); } diff --git a/src/codegen/llvm/codegen_llvm.h b/src/codegen/llvm/codegen_llvm.h index 34c3ee723e18c..b269f2423fc8c 100644 --- a/src/codegen/llvm/codegen_llvm.h +++ b/src/codegen/llvm/codegen_llvm.h @@ -106,7 +106,6 @@ class CodeGenLLVM : llvm::Value* VisitExpr_(const VarNode* op) override; llvm::Value* VisitExpr_(const CastNode* op) override; llvm::Value* VisitExpr_(const IntImmNode* op) override; - llvm::Value* VisitExpr_(const UIntImmNode* op) override; llvm::Value* VisitExpr_(const FloatImmNode* op) override; llvm::Value* VisitExpr_(const StringImmNode* op) override; llvm::Value* VisitExpr_(const AddNode* op) override; diff --git a/src/codegen/llvm/intrin_rule_llvm.h b/src/codegen/llvm/intrin_rule_llvm.h index b3ab557ee215a..1f839f362f40a 100644 --- a/src/codegen/llvm/intrin_rule_llvm.h +++ b/src/codegen/llvm/intrin_rule_llvm.h @@ -43,8 +43,8 @@ inline void DispatchLLVMPureIntrin(const TVMArgs& targs, TVMRetValue* rv) { CHECK(call != nullptr); Array cargs; // intrin id. - cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), id)); - cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), num_signature)); + cargs.push_back(IntImm(DataType::UInt(32), id)); + cargs.push_back(IntImm(DataType::UInt(32), num_signature)); for (PrimExpr arg : call->args) { cargs.push_back(arg); @@ -60,8 +60,8 @@ inline void DispatchLLVMIntrin(const TVMArgs& targs, TVMRetValue* rv) { CHECK(call != nullptr); Array cargs; // intrin id. - cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), id)); - cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), num_signature)); + cargs.push_back(IntImm(DataType::UInt(32), id)); + cargs.push_back(IntImm(DataType::UInt(32), num_signature)); for (PrimExpr arg : call->args) { cargs.push_back(arg); } diff --git a/src/codegen/spirv/codegen_spirv.cc b/src/codegen/spirv/codegen_spirv.cc index ac7423e8ad875..8016444dad504 100644 --- a/src/codegen/spirv/codegen_spirv.cc +++ b/src/codegen/spirv/codegen_spirv.cc @@ -136,10 +136,6 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const IntImmNode* op) { return builder_->IntImm(builder_->GetSType(op->dtype), op->value); } -spirv::Value CodeGenSPIRV::VisitExpr_(const UIntImmNode* op) { - return builder_->UIntImm(builder_->GetSType(op->dtype), op->value); -} - spirv::Value CodeGenSPIRV::VisitExpr_(const FloatImmNode* op) { return builder_->FloatImm(builder_->GetSType(op->dtype), op->value); } @@ -242,7 +238,8 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const LetNode* op) { spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { if (op->is_intrinsic("spirv_glsl450")) { CHECK_GE(op->args.size(), 2U); - uint32_t inst_id = op->args[0].as()->value; + uint32_t inst_id = static_cast( + op->args[0].as()->value); std::vector values; for (size_t i = 1; i < op->args.size(); ++i) { values.push_back(MakeValue(op->args[i])); diff --git a/src/codegen/spirv/codegen_spirv.h b/src/codegen/spirv/codegen_spirv.h index 3804bda0f2e01..5aa7f9c499101 100644 --- a/src/codegen/spirv/codegen_spirv.h +++ b/src/codegen/spirv/codegen_spirv.h @@ -65,7 +65,6 @@ class CodeGenSPIRV: spirv::Value VisitExpr_(const VarNode* op) override; spirv::Value VisitExpr_(const CastNode* op) override; spirv::Value VisitExpr_(const IntImmNode* op) override; - spirv::Value VisitExpr_(const UIntImmNode* op) override; spirv::Value VisitExpr_(const FloatImmNode* op) override; spirv::Value VisitExpr_(const StringImmNode* op) override; spirv::Value VisitExpr_(const AddNode* op) override; diff --git a/src/codegen/spirv/intrin_rule_spirv.cc b/src/codegen/spirv/intrin_rule_spirv.cc index d41d96db51653..d96883ed02fdd 100644 --- a/src/codegen/spirv/intrin_rule_spirv.cc +++ b/src/codegen/spirv/intrin_rule_spirv.cc @@ -39,7 +39,7 @@ inline void DispatchGLSLPureIntrin(const TVMArgs& targs, TVMRetValue* rv) { CHECK(call != nullptr); Array cargs; // intrin id. - cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), id)); + cargs.push_back(IntImm(DataType::UInt(32), id)); for (PrimExpr arg : call->args) { cargs.push_back(arg); diff --git a/src/codegen/stackvm/codegen_stackvm.cc b/src/codegen/stackvm/codegen_stackvm.cc index eccff6c74c2e9..01096ae1dd469 100644 --- a/src/codegen/stackvm/codegen_stackvm.cc +++ b/src/codegen/stackvm/codegen_stackvm.cc @@ -280,12 +280,6 @@ void CodeGenStackVM::VisitExpr_(const IntImmNode* op) { this->PushOp(StackVM::PUSH_I64, static_cast(op->value)); } -void CodeGenStackVM::VisitExpr_(const UIntImmNode* op) { - CHECK(op->value <= std::numeric_limits::max()) - << "Int constant exceed bound"; - this->PushOp(StackVM::PUSH_I64, static_cast(op->value)); -} - void CodeGenStackVM::VisitExpr_(const FloatImmNode* op) { LOG(FATAL) << "Float Imm is not supported"; } diff --git a/src/codegen/stackvm/codegen_stackvm.h b/src/codegen/stackvm/codegen_stackvm.h index 07989b2062e19..1360cc2d70f18 100644 --- a/src/codegen/stackvm/codegen_stackvm.h +++ b/src/codegen/stackvm/codegen_stackvm.h @@ -136,7 +136,6 @@ class CodeGenStackVM void VisitExpr_(const RampNode* op) final; void VisitExpr_(const BroadcastNode* op) final; void VisitExpr_(const IntImmNode* op) final; - void VisitExpr_(const UIntImmNode* op) final; void VisitExpr_(const FloatImmNode* op) final; void VisitExpr_(const StringImmNode* op) final; // statment diff --git a/src/contrib/hybrid/codegen_hybrid.cc b/src/contrib/hybrid/codegen_hybrid.cc index 7e3d44f26aeff..346ec38089196 100644 --- a/src/contrib/hybrid/codegen_hybrid.cc +++ b/src/contrib/hybrid/codegen_hybrid.cc @@ -79,10 +79,7 @@ void CodeGenHybrid::PrintType(DataType t, std::ostream &os) { void CodeGenHybrid::VisitExpr_(const IntImmNode* op, std::ostream& os) { // NOLINT(*) os << op->value; } -void CodeGenHybrid::VisitExpr_(const UIntImmNode* op, std::ostream& os) { // NOLINT(*) - PrintType(op->dtype, os); - os << "(" << op->value << ")"; -} + void CodeGenHybrid::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*) PrintType(op->dtype, os); os << "(" << std::setprecision(20) << op->value << ")"; diff --git a/src/contrib/hybrid/codegen_hybrid.h b/src/contrib/hybrid/codegen_hybrid.h index 89a1ece577f91..33bd0efae8a45 100644 --- a/src/contrib/hybrid/codegen_hybrid.h +++ b/src/contrib/hybrid/codegen_hybrid.h @@ -117,7 +117,6 @@ class CodeGenHybrid : void VisitExpr_(const RampNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const BroadcastNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const IntImmNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const UIntImmNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const FloatImmNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const StringImmNode* op, std::ostream& os) override; // NOLINT(*) // statment diff --git a/src/lang/attr_functor.h b/src/lang/attr_functor.h index 34ee4b3159a59..4fffc475a7734 100644 --- a/src/lang/attr_functor.h +++ b/src/lang/attr_functor.h @@ -77,7 +77,6 @@ class AttrFunctor { virtual R VisitAttr_(const ArrayNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const StrMapNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::IntImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const ir::UIntImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::FloatImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::StringImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; // deep comparison of symbolic integer expressions. @@ -113,7 +112,6 @@ class AttrFunctor { ATTR_FUNCTOR_DISPATCH(StrMapNode); ATTR_FUNCTOR_DISPATCH(ArrayNode); ATTR_FUNCTOR_DISPATCH(IntImmNode); - ATTR_FUNCTOR_DISPATCH(UIntImmNode); ATTR_FUNCTOR_DISPATCH(FloatImmNode); ATTR_FUNCTOR_DISPATCH(StringImmNode); ATTR_FUNCTOR_DISPATCH(VarNode); @@ -157,7 +155,6 @@ class AttrsEqualHandler : bool VisitAttr_(const ArrayNode* lhs, const ObjectRef& other) final; bool VisitAttr_(const StrMapNode* lhs, const ObjectRef& other) final; bool VisitAttr_(const ir::IntImmNode* lhs, const ObjectRef& other) final; - bool VisitAttr_(const ir::UIntImmNode* lhs, const ObjectRef& other) final; bool VisitAttr_(const ir::FloatImmNode* lhs, const ObjectRef& other) final; bool VisitAttr_(const ir::StringImmNode* lhs, const ObjectRef& other) final; bool VisitAttr_(const ir::AddNode* lhs, const ObjectRef& other) final; @@ -198,7 +195,6 @@ class AttrsHashHandler : protected: size_t VisitAttrDefault_(const Object* lhs) final; size_t VisitAttr_(const ir::IntImmNode* lhs) final; - size_t VisitAttr_(const ir::UIntImmNode* lhs) final; size_t VisitAttr_(const ir::FloatImmNode* lhs) final; size_t VisitAttr_(const ir::StringImmNode* lhs) final; size_t VisitAttr_(const ArrayNode* lhs) final; diff --git a/src/lang/attrs.cc b/src/lang/attrs.cc index 1d3e767a5b714..a590f10e78e51 100644 --- a/src/lang/attrs.cc +++ b/src/lang/attrs.cc @@ -97,13 +97,6 @@ bool AttrsEqualHandler::VisitAttr_(const IntImmNode* lhs, const ObjectRef& other return false; } -bool AttrsEqualHandler::VisitAttr_(const UIntImmNode* lhs, const ObjectRef& other) { - if (const auto* rhs = other.as()) { - return lhs->value == rhs->value; - } - return false; -} - bool AttrsEqualHandler::VisitAttr_(const FloatImmNode* lhs, const ObjectRef& other) { if (const auto* rhs = other.as()) { return lhs->value == rhs->value; @@ -224,10 +217,6 @@ size_t AttrsHashHandler::VisitAttr_(const IntImmNode* op) { return std::hash()(op->value); } -size_t AttrsHashHandler::VisitAttr_(const UIntImmNode* op) { - return std::hash()(op->value); -} - size_t AttrsHashHandler::VisitAttr_(const FloatImmNode* op) { return std::hash()(op->value); } diff --git a/src/lang/expr_operator.cc b/src/lang/expr_operator.cc index 6c7c54726eb90..5f9816f6898b9 100644 --- a/src/lang/expr_operator.cc +++ b/src/lang/expr_operator.cc @@ -86,7 +86,6 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs) { // NOLINT(*) } } - // maximum and min limits PrimExpr max_value(const DataType& dtype) { using namespace ir; @@ -101,11 +100,11 @@ PrimExpr max_value(const DataType& dtype) { } } else if (dtype.is_uint()) { if (dtype.bits() == 64) { - return UIntImmNode::make(dtype, std::numeric_limits::max()); + return make_const(dtype, std::numeric_limits::max()); } else if (dtype.bits() < 64) { uint64_t val = 1; val = (val << static_cast(dtype.bits())) - 1; - return UIntImmNode::make(dtype, val); + return IntImm(dtype, static_cast(val)); } } else if (dtype.is_float()) { if (dtype.bits() == 64) { @@ -132,7 +131,7 @@ PrimExpr min_value(const DataType& dtype) { return IntImm(dtype, val); } } else if (dtype.is_uint()) { - return UIntImmNode::make(dtype, 0); + return IntImm(dtype, 0); } else if (dtype.is_float()) { if (dtype.bits() == 64) { return FloatImmNode::make(dtype, std::numeric_limits::lowest()); @@ -163,24 +162,18 @@ inline bool ConstPowerHelper(ValueType val, int *shift) { bool is_const_power_of_two_integer(const PrimExpr& x, int* shift) { if (const auto* op = x.as()) { return ConstPowerHelper(op->value, shift); - } else if (const auto* op = x.as()) { - return ConstPowerHelper(op->value, shift); } else { return false; } } PrimExpr cast(const DataType& t, PrimExpr value) { - using ir::IntImmNode; - using ir::UIntImmNode; using ir::FloatImmNode; if (value.dtype() == t) return value; // const fold IntImm as they are used in index computations if (t.lanes() == 1) { if (const IntImmNode* op = value.as()) { return make_const(t, op->value); - } else if (const UIntImmNode* op = value.as()) { - return make_const(t, op->value); } else if (const FloatImmNode* op = value.as()) { return make_const(t, op->value); } @@ -192,8 +185,6 @@ PrimExpr cast(const DataType& t, PrimExpr value) { if (value.dtype() != vtype) { if (const IntImmNode* op = value.as()) { value = make_const(vtype, op->value); - } else if (const UIntImmNode* op = value.as()) { - return make_const(t, op->value); } else if (const FloatImmNode* op = value.as()) { value = make_const(vtype, op->value); } else { @@ -330,18 +321,10 @@ PrimExpr max(PrimExpr a, PrimExpr b) { } PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value) { - using ir::IntImmNode; - using ir::UIntImmNode; CHECK(cond.dtype() == DataType::Bool(1)) << "if_then_else only accept the condition to be boolean type."; BinaryOpMatchTypes(true_value, false_value); - if (const UIntImmNode* op = cond.as()) { - if (op->value != 0) { - return true_value; - } else { - return false_value; - } - } else if (const IntImmNode* op = cond.as()) { + if (const IntImmNode* op = cond.as()) { if (op->value != 0) { return true_value; } else { diff --git a/src/lang/ir.cc b/src/lang/ir.cc index 5a24e965e780e..f06a6be5e75a3 100644 --- a/src/lang/ir.cc +++ b/src/lang/ir.cc @@ -31,14 +31,6 @@ namespace tvm { namespace ir { // constructors -PrimExpr UIntImmNode::make(DataType t, uint64_t value) { - CHECK(t.is_uint() && t.lanes() == 1) - << "ValueError: UIntImm can only take scalar"; - ObjectPtr node = make_object(); - node->dtype = t; - node->value = value; - return PrimExpr(node); -} PrimExpr FloatImmNode::make(DataType t, double value) { CHECK_EQ(t.lanes(), 1) @@ -531,11 +523,6 @@ Stmt EvaluateNode::make(PrimExpr value) { } // Printers -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "(" << op->dtype << ")" << op->value; - }); TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) .set_dispatch([](const ObjectRef& node, NodePrinter* p) { @@ -1153,7 +1140,6 @@ TVM_REGISTER_NODE_TYPE(AnyNode); TVM_REGISTER_NODE_TYPE(AttrStmtNode); TVM_REGISTER_NODE_TYPE(FloatImmNode); TVM_REGISTER_NODE_TYPE(IntImmNode); -TVM_REGISTER_NODE_TYPE(UIntImmNode); TVM_REGISTER_NODE_TYPE(StringImmNode); TVM_REGISTER_NODE_TYPE(CastNode); TVM_REGISTER_NODE_TYPE(VarNode); diff --git a/src/pass/arg_binder.cc b/src/pass/arg_binder.cc index 612a56664c8cc..0f350d2d732ef 100644 --- a/src/pass/arg_binder.cc +++ b/src/pass/arg_binder.cc @@ -179,11 +179,11 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, std::ostringstream type_err_msg; type_err_msg << arg_name << ".dtype is expected to be " << dtype; PrimExpr cond = (TVMArrayGet(DataType::UInt(8), handle, intrinsic::kArrTypeCode) == - UIntImmNode::make(DataType::UInt(8), dtype.code()) && + IntImm(DataType::UInt(8), dtype.code()) && TVMArrayGet(DataType::UInt(8), handle, intrinsic::kArrTypeBits) == - UIntImmNode::make(DataType::UInt(8), dtype.bits()) && + IntImm(DataType::UInt(8), dtype.bits()) && TVMArrayGet(DataType::UInt(16), handle, intrinsic::kArrTypeLanes) == - UIntImmNode::make(DataType::UInt(16), dtype.lanes())); + IntImm(DataType::UInt(16), dtype.lanes())); asserts_.emplace_back(AssertStmtNode::make(cond, type_err_msg.str(), nop)); // data field if (Bind_(buffer->data, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrData), diff --git a/src/pass/ir_deep_compare.cc b/src/pass/ir_deep_compare.cc index 6eacb145b29bd..8c441510c51d7 100644 --- a/src/pass/ir_deep_compare.cc +++ b/src/pass/ir_deep_compare.cc @@ -252,10 +252,6 @@ class IRDeepCompare : CompareValue(op->value, other.as()->value); } - void VisitExpr_(const UIntImmNode *op, const PrimExpr& other) final { - CompareValue(op->value, other.as()->value); - } - void VisitExpr_(const FloatImmNode *op, const PrimExpr& other) final { CompareValue(op->value, other.as()->value); } diff --git a/src/pass/ir_functor.cc b/src/pass/ir_functor.cc index 67acec674630d..857206f8dd9f7 100644 --- a/src/pass/ir_functor.cc +++ b/src/pass/ir_functor.cc @@ -260,7 +260,6 @@ DEFINE_BINOP_VISIT_(AndNode); DEFINE_BINOP_VISIT_(OrNode); void ExprVisitor::VisitExpr_(const IntImmNode* op) {} -void ExprVisitor::VisitExpr_(const UIntImmNode* op) {} void ExprVisitor::VisitExpr_(const FloatImmNode* op) {} void ExprVisitor::VisitExpr_(const StringImmNode* op) {} @@ -640,7 +639,6 @@ PrimExpr ExprMutator::VisitExpr_(const CallNode* op) { } DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(IntImmNode) -DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(UIntImmNode) DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(FloatImmNode) DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(StringImmNode) diff --git a/src/pass/lift_attr_scope.cc b/src/pass/lift_attr_scope.cc index 7b760fa4a672e..5aba355b70032 100644 --- a/src/pass/lift_attr_scope.cc +++ b/src/pass/lift_attr_scope.cc @@ -180,9 +180,6 @@ class AttrScopeLifter : public StmtMutator { if (const IntImmNode* op = a.as()) { return op->value == b.as()->value; } - if (const UIntImmNode* op = a.as()) { - return op->value == b.as()->value; - } return false; } diff --git a/src/pass/lower_thread_allreduce.cc b/src/pass/lower_thread_allreduce.cc index a0b07c293b057..d509169df0b15 100644 --- a/src/pass/lower_thread_allreduce.cc +++ b/src/pass/lower_thread_allreduce.cc @@ -120,7 +120,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { const CommReducerNode *combiner = reduce_combiner_.back(); size_t size = combiner->result.size(); - const UIntImmNode *size_of_args = call->args[0].as(); + const IntImmNode *size_of_args = call->args[0].as(); CHECK(size_of_args) << call->args[0]->GetTypeKey(); CHECK_EQ(size, size_of_args->value); Array inits = combiner->identity_element; diff --git a/src/pass/rewrite_unsafe_select.cc b/src/pass/rewrite_unsafe_select.cc index 224a81c123965..9fb19cc4b308c 100644 --- a/src/pass/rewrite_unsafe_select.cc +++ b/src/pass/rewrite_unsafe_select.cc @@ -96,7 +96,6 @@ class UnsafeExprDetector : public ExprFunctor { return false; } bool VisitExpr_(const VarNode* op) final { return false; } - bool VisitExpr_(const UIntImmNode* op) final { return false; } bool VisitExpr_(const IntImmNode* op) final { return false; } bool VisitExpr_(const FloatImmNode* op) final { return false; } bool VisitExpr_(const StringImmNode* op) final { return false; } diff --git a/src/pass/unroll_loop.cc b/src/pass/unroll_loop.cc index b2c50f7a8bd26..26ad591896712 100644 --- a/src/pass/unroll_loop.cc +++ b/src/pass/unroll_loop.cc @@ -159,14 +159,10 @@ class LoopUnroller : public StmtExprMutator { // constant folding. PrimExpr extent = ir::Simplify(op->extent); const IntImmNode *v1 = extent.as(); - const UIntImmNode *v2 = extent.as(); int value = -1; if (v1 != nullptr) { value = static_cast(v1->value); } - if (v2 != nullptr) { - value = static_cast(v2->value); - } return value; } diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index 25650c7766cba..400a6bea22ed9 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -857,10 +857,6 @@ class PrettyPrinter : return PrintConstScalar(op->dtype, &(op->value)); } - Doc VisitAttr_(const ir::UIntImmNode* op) final { - return PrintConstScalar(op->dtype, &(op->value)); - } - Doc VisitAttr_(const ir::FloatImmNode* op) final { return PrintConstScalar(op->dtype, &(op->value)); } diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index d0d8b43f4c613..01280d209c0c2 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -41,7 +41,7 @@ class TypeSolver::Reporter : public TypeReporterNode { } bool Assert(const IndexExpr& cond) final { - if (const uint64_t* pdiff = as_const_uint(cond)) { + if (const int64_t* pdiff = as_const_int(cond)) { return pdiff[0]; } return true; diff --git a/src/relay/qnn/util.h b/src/relay/qnn/util.h index 378a5e3728f44..2e332413c1f6b 100644 --- a/src/relay/qnn/util.h +++ b/src/relay/qnn/util.h @@ -47,14 +47,10 @@ static inline Array get_shape(const Type& type) { static inline const int32_t GetQmin(const DataType& dtype) { CHECK_LE(dtype.bits(), 32) << "QNN ops support int32 or lower precision"; - if (dtype.is_int()) { + if (dtype.is_int() || dtype.is_uint()) { auto* min_value = as_const_int(tvm::min_value(dtype)); CHECK(min_value != nullptr); return static_cast(min_value[0]); - } else if (dtype.is_uint()) { - auto* min_value = as_const_uint(tvm::min_value(dtype)); - CHECK(min_value != nullptr); - return static_cast(min_value[0]); } else { LOG(FATAL) << "Type not supported " << dtype; return -1; // To hide the warning @@ -64,14 +60,10 @@ static inline const int32_t GetQmin(const DataType& dtype) { static inline const int32_t GetQmax(const DataType& dtype) { CHECK_LE(dtype.bits(), 32) << "QNN ops support int32 or lower precision"; - if (dtype.is_int()) { + if (dtype.is_int() || dtype.is_uint()) { auto* max_value = as_const_int(tvm::max_value(dtype)); CHECK(max_value != nullptr); return static_cast(max_value[0]); - } else if (dtype.is_uint()) { - auto* max_value = as_const_uint(tvm::max_value(dtype)); - CHECK(max_value != nullptr); - return static_cast(max_value[0]); } else { LOG(FATAL) << "Type not supported " << dtype; return -1; // To hide the warning diff --git a/tests/cpp/pattern_match_test.cc b/tests/cpp/pattern_match_test.cc index 5392eaeac1e8f..193f2f206c064 100644 --- a/tests/cpp/pattern_match_test.cc +++ b/tests/cpp/pattern_match_test.cc @@ -127,10 +127,10 @@ TEST(Pattern, Basic) { } } -TEST(Pattern, Integer) { +TEST(Pattern, IntImm) { using namespace tvm; tvm::Var tx, ty; - arith::PVar c; + arith::PVar c; arith::PVar v; { // We can match integer and Var, both of which are diff --git a/tests/python/unittest/test_hybrid_script.py b/tests/python/unittest/test_hybrid_script.py index c3c40cf740ad2..5f1facb2b45fc 100644 --- a/tests/python/unittest/test_hybrid_script.py +++ b/tests/python/unittest/test_hybrid_script.py @@ -24,7 +24,7 @@ def run_and_check(func, args, var_dict={}, target='llvm', sch=None, outs=None): def tvm_val_2_py_val(val): val = tvm.ir_pass.Substitute(val, var_dict) val = tvm.ir_pass.Simplify(val) - assert isinstance(val, (tvm.expr.IntImm, tvm.expr.UIntImm)) + assert isinstance(val, (tvm.expr.IntImm,)) return val.value ctx = tvm.context(target, 0) diff --git a/tests/python/unittest/test_lang_constructor.py b/tests/python/unittest/test_lang_constructor.py index fe329494e24e5..c4187858a8a88 100644 --- a/tests/python/unittest/test_lang_constructor.py +++ b/tests/python/unittest/test_lang_constructor.py @@ -38,16 +38,11 @@ def test_expr_constructor(): assert x.value == 2 assert x.dtype == "int64" - x = tvm.expr.UIntImm("uint16", 2) - assert isinstance(x, tvm.expr.UIntImm) - assert x.value == 2 - assert x.dtype == "uint16" - x = tvm.expr.StringImm("xyza") assert isinstance(x, tvm.expr.StringImm) assert x.value == "xyza" - x = tvm.expr.Cast("float32", tvm.expr.IntImm("int32", 1)) + x = tvm.expr.Cast("float32", tvm.expr.IntImm("uint32", 1)) assert isinstance(x, tvm.expr.Cast) assert x.dtype == "float32" assert x.value.value == 1 diff --git a/tests/python/unittest/test_lang_operator.py b/tests/python/unittest/test_lang_operator.py index c57f4a1109ec1..ac2ee6d88cc55 100644 --- a/tests/python/unittest/test_lang_operator.py +++ b/tests/python/unittest/test_lang_operator.py @@ -29,7 +29,7 @@ def test_const_fold(): def check(f, *args): x = f(*[tvm.const(x, "int32") for x in args]) y = f(*args) - if not isinstance(x, (tvm.expr.IntImm, tvm.expr.UIntImm)) or x.value != int(y): + if not isinstance(x, (tvm.expr.IntImm,)) or x.value != int(y): raise ValueError("check error: %s vs %s " % (x, y)) tmod = tvm.truncmod diff --git a/topi/include/topi/detail/constant_utils.h b/topi/include/topi/detail/constant_utils.h index 43ac3a29cd7c4..e6de76f208813 100644 --- a/topi/include/topi/detail/constant_utils.h +++ b/topi/include/topi/detail/constant_utils.h @@ -43,8 +43,7 @@ using namespace tvm; */ inline bool IsConstInt(PrimExpr expr) { return - expr->IsInstance() || - expr->IsInstance(); + expr->IsInstance(); } /*! @@ -56,11 +55,8 @@ inline bool IsConstInt(PrimExpr expr) { * \return The integer value. */ inline int64_t GetConstInt(PrimExpr expr) { - if (expr->IsInstance()) { - return expr.as()->value; - } - if (expr->IsInstance()) { - return expr.as()->value; + if (expr->IsInstance()) { + return expr.as()->value; } LOG(ERROR) << "expr must be a constant integer"; return -1; diff --git a/topi/python/topi/util.py b/topi/python/topi/util.py index 8f32a297d7195..379d1c3b457ba 100644 --- a/topi/python/topi/util.py +++ b/topi/python/topi/util.py @@ -92,9 +92,9 @@ def get_const_int(expr): """ if isinstance(expr, Integral): return expr - if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)): + if not isinstance(expr, (tvm.expr.IntImm,)): expr = tvm.ir_pass.Simplify(expr) - if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)): + if not isinstance(expr, (tvm.expr.IntImm,)): raise ValueError("Expect value to be constant int") return int(expr.value) @@ -136,9 +136,9 @@ def equal_const_int(expr, value): """ if isinstance(expr, Integral): return expr == value - if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)): + if not isinstance(expr, (tvm.expr.IntImm,): expr = tvm.ir_pass.Simplify(expr) - if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)): + if not isinstance(expr, (tvm.expr.IntImm,)): return False return expr.value == value @@ -160,9 +160,9 @@ def get_const_tuple(in_tuple): for elem in in_tuple: if isinstance(elem, tvm.expr.Var): ret.append(elem) - elif not isinstance(elem, (tvm.expr.IntImm, tvm.expr.UIntImm, int)): + elif not isinstance(elem, (tvm.expr.IntImm, int)): elem = tvm.ir_pass.Simplify(elem) - if not isinstance(elem, (tvm.expr.IntImm, tvm.expr.UIntImm)): + if not isinstance(elem, (tvm.expr.IntImm,)): ret.append(elem) else: ret.append(get_const_int(elem))