Skip to content

Commit

Permalink
[REFACTOR][IR] Remove UIntImm to use IntImm
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed Jan 14, 2020
1 parent 4f7e1db commit 1d87946
Show file tree
Hide file tree
Showing 52 changed files with 82 additions and 289 deletions.
4 changes: 0 additions & 4 deletions include/tvm/attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -490,8 +490,6 @@ inline void SetIntValue(T* ptr, const TVMArgValue& val) {
CHECK(expr.defined());
if (const ir::IntImmNode* op = expr.as<ir::IntImmNode>()) {
*ptr = static_cast<T>(op->value);
} else if (const ir::UIntImmNode* op = expr.as<ir::UIntImmNode>()) {
*ptr = static_cast<T>(op->value);
} else {
LOG(FATAL) << "Expect int value, but get " << expr->GetTypeKey();
}
Expand Down Expand Up @@ -523,8 +521,6 @@ inline void SetValue<double>(double* ptr, const TVMArgValue& val) {
*ptr = static_cast<double>(op->value);
} else if (const ir::IntImmNode* op = expr.as<ir::IntImmNode>()) {
*ptr = static_cast<double>(op->value);
} else if (const ir::UIntImmNode* op = expr.as<ir::UIntImmNode>()) {
*ptr = static_cast<double>(op->value);
} else {
LOG(FATAL) << "Expect float value, but get " << expr->GetTypeKey();
}
Expand Down
25 changes: 2 additions & 23 deletions include/tvm/expr_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,21 +82,6 @@ inline const int64_t* as_const_int(const PrimExpr& x) {
}
}

/*!
* \brief Get x as constant uint expression.
* \param x The expression
* \return the address to the int expression,
* return nullptr, if x is not UIntImm.
*/
inline const uint64_t* as_const_uint(const PrimExpr& x) {
if (!x.defined()) return nullptr;
if (const ir::UIntImmNode* op = x.as<ir::UIntImmNode>()) {
return &(op->value);
} else {
return nullptr;
}
}

/*!
* \brief Check whether x is a constant integer expression.
* \param x The input argument
Expand Down Expand Up @@ -626,11 +611,11 @@ TVM_DECLARE_INTRIN_UNARY(atan);

// Implementation details after this
inline bool is_const(const PrimExpr& x) {
if (x.as<ir::IntImmNode>() || x.as<ir::UIntImmNode>()) {
if (x.as<ir::IntImmNode>()) {
return true;
} else if (const auto* op = x.as<ir::BroadcastNode>()) {
const PrimExpr& val = op->value;
if (val.as<ir::IntImmNode>() || val.as<ir::UIntImmNode>()) {
if (val.as<ir::IntImmNode>()) {
return true;
}
}
Expand All @@ -640,8 +625,6 @@ inline bool is_const(const PrimExpr& x) {
inline bool is_positive_const(const PrimExpr& a) {
if (const ir::IntImmNode* op = a.as<ir::IntImmNode>()) {
return op->value > 0;
} else if (const ir::UIntImmNode* op = a.as<ir::UIntImmNode>()) {
return op->value > 0;
} else {
return false;
}
Expand All @@ -658,14 +641,10 @@ inline bool is_negative_const(const PrimExpr& a) {
inline bool is_const_int(const PrimExpr& x, int64_t value) {
if (const auto* op = x.as<ir::IntImmNode>()) {
return op->value == value;
} else if (const auto* op = x.as<ir::UIntImmNode>()) {
return op->value == static_cast<uint64_t>(value);
} else if (const auto* op = x.as<ir::BroadcastNode>()) {
const PrimExpr& val = op->value;
if (const auto* opv = val.as<ir::IntImmNode>()) {
return opv->value == value;
} else if (const auto* opv = val.as<ir::UIntImmNode>()) {
return opv->value == static_cast<uint64_t>(value);
}
}
return false;
Expand Down
17 changes: 0 additions & 17 deletions include/tvm/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,23 +39,6 @@ namespace ir {
using IntImmNode = tvm::IntImmNode;
using VarNode = tvm::VarNode;

/*! \brief constant unsigned integer. */
class UIntImmNode : public PrimExprNode {
public:
/*! \brief The constant value content. */
uint64_t value;

void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &dtype);
v->Visit("value", &value);
}

TVM_DLL static PrimExpr make(DataType t, uint64_t value);

static constexpr const char* _type_key = "UIntImm";
TVM_DECLARE_FINAL_OBJECT_INFO(UIntImmNode, PrimExprNode);
};

/*! \brief Floating point constants. */
class FloatImmNode : public PrimExprNode {
public:
Expand Down
4 changes: 0 additions & 4 deletions include/tvm/ir_functor_ext.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,6 @@ class ExprFunctor<R(const PrimExpr& n, Args...)> {
virtual R VisitExpr_(const BroadcastNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const ShuffleNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const IntImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const UIntImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const FloatImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const StringImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExprDefault_(const Object* op, Args ...) {
Expand Down Expand Up @@ -203,7 +202,6 @@ class ExprFunctor<R(const PrimExpr& n, Args...)> {
IR_EXPR_FUNCTOR_DISPATCH(ShuffleNode);
IR_EXPR_FUNCTOR_DISPATCH(BroadcastNode);
IR_EXPR_FUNCTOR_DISPATCH(IntImmNode);
IR_EXPR_FUNCTOR_DISPATCH(UIntImmNode);
IR_EXPR_FUNCTOR_DISPATCH(FloatImmNode);
IR_EXPR_FUNCTOR_DISPATCH(StringImmNode);
return vtable;
Expand Down Expand Up @@ -327,7 +325,6 @@ class TVM_DLL ExprVisitor :
void VisitExpr_(const BroadcastNode* op) override;
void VisitExpr_(const ShuffleNode* op) override;
void VisitExpr_(const IntImmNode* op) override;
void VisitExpr_(const UIntImmNode* op) override;
void VisitExpr_(const FloatImmNode* op) override;
void VisitExpr_(const StringImmNode* op) override;
};
Expand Down Expand Up @@ -372,7 +369,6 @@ class TVM_DLL ExprMutator :
PrimExpr VisitExpr_(const BroadcastNode* op) override;
PrimExpr VisitExpr_(const ShuffleNode* op) override;
PrimExpr VisitExpr_(const IntImmNode* op) override;
PrimExpr VisitExpr_(const UIntImmNode* op) override;
PrimExpr VisitExpr_(const FloatImmNode* op) override;
PrimExpr VisitExpr_(const StringImmNode* op) override;
};
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/autotvm/task/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def args_to_workload(x, topi_compute_func=None):
workload = tuple([args_to_workload(a) for a in x])
elif isinstance(x, (str, int, float, np.int, np.float, expr.Var)):
workload = x
elif isinstance(x, (expr.StringImm, expr.UIntImm, expr.IntImm, expr.FloatImm)):
elif isinstance(x, (expr.StringImm, expr.IntImm, expr.FloatImm)):
workload = x.value
elif x is None:
workload = 0
Expand Down Expand Up @@ -344,7 +344,7 @@ def _count_flop(exp):
if len(source) != 1:
raise FlopCalculationError("Found multiple output in the source of reduce op")
return num_iter * (_count_flop(combiner[0]) + _count_flop(source[0]))
if isinstance(exp, (expr.FloatImm, expr.IntImm, expr.UIntImm)):
if isinstance(exp, (expr.FloatImm, expr.IntImm)):
return 0
if isinstance(exp, expr.Cast):
return _count_flop(exp.value)
Expand Down
8 changes: 4 additions & 4 deletions python/tvm/autotvm/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,9 @@ def get_const_int(exp):
"""
if isinstance(exp, int):
return exp
if not isinstance(exp, (expr.IntImm, expr.UIntImm)):
if not isinstance(exp, (expr.IntImm,)):
exp = ir_pass.Simplify(exp)
if not isinstance(exp, (expr.IntImm, expr.UIntImm)):
if not isinstance(exp, (expr.IntImm,)):
raise ValueError("Expect value to be constant int")
return exp.value

Expand All @@ -179,9 +179,9 @@ def get_const_tuple(in_tuple):
for elem in in_tuple:
if isinstance(elem, expr.Var):
ret.append(elem)
elif not isinstance(elem, (expr.IntImm, expr.UIntImm, int)):
elif not isinstance(elem, (expr.IntImm, int)):
elem = ir_pass.Simplify(elem)
if not isinstance(elem, (expr.IntImm, expr.UIntImm)):
if not isinstance(elem, (expr.IntImm)):
ret.append(elem)
else:
ret.append(get_const_int(elem))
Expand Down
17 changes: 0 additions & 17 deletions python/tvm/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,23 +341,6 @@ def __int__(self):
return self.value


@register_object
class UIntImm(ConstExpr):
"""UInt constant.
Parameters
----------
dtype : str
The data type
value : int
The constant value.
"""
def __init__(self, dtype, value):
self.__init_handle_by_constructor__(
_make.UIntImm, dtype, value)


@register_object
class StringImm(ConstExpr):
"""String constant.
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/hybrid/calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,6 @@ def max_num_threads(func_id, args):
if args.__len__() == 0:
res = _tgt.current_target().max_num_threads
else:
_internal_assert(isinstance(args[0], _expr.UIntImm), "In tvm bool should be uint")
_internal_assert(isinstance(args[0], _expr.IntImm), "In tvm bool should be uint")
res = _tgt.current_target(args[0].value).max_num_threads
return _api.convert(res)
4 changes: 2 additions & 2 deletions python/tvm/hybrid/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ def visit_Subscript(self, node):
if isinstance(i, numbers.Integral):
arr = arr[i]
else:
_internal_assert(isinstance(i, (_expr.IntImm, _expr.UIntImm)), \
_internal_assert(isinstance(i, (_expr.IntImm,)), \
"All indices are supposed to be constants")
arr = arr[i.value]
return arr
Expand All @@ -413,7 +413,7 @@ def visit_If(self, node):
cond = _ir_pass.CanonicalSimplify(self.visit(node.test))

# Return no IfThenElse if proven
if isinstance(cond, _expr.UIntImm):
if isinstance(cond, _expr.IntImm):
if cond.value:
return visit_list_to_block(self.visit, node.body)
if node.orelse:
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/hybrid/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
#pylint: disable=invalid-name
np_arg_types = tuple(list(numeric_types) + [numpy.ndarray])
tvm_arg_types = (Tensor, Array, _expr.Var, _expr.ConstExpr)
halide_imm_types = (_expr.IntImm, _expr.FloatImm, _expr.UIntImm)
halide_imm_types = (_expr.IntImm, _expr.FloatImm)


def _internal_assert(cond, err):
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -931,7 +931,7 @@ def _shape():
def _impl(inputs, attr, params):
is_symbolic_shape = False
for axis in attr['_input_shapes'][inputs[0]]:
if not isinstance(axis, (int, tvm.expr.IntImm, tvm.expr.UIntImm)):
if not isinstance(axis, (int, tvm.expr.IntImm)):
is_symbolic_shape = True
break

Expand Down
1 change: 0 additions & 1 deletion src/api/api_ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,6 @@ TVM_REGISTER_GLOBAL("make.CommReducer")
REGISTER_MAKE(Reduce);
REGISTER_MAKE(AttrStmt);

REGISTER_MAKE(UIntImm);
REGISTER_MAKE(FloatImm);
REGISTER_MAKE(StringImm);

Expand Down
6 changes: 3 additions & 3 deletions src/arithmetic/analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,15 +87,15 @@ bool Analyzer::CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound) {
}

bool Analyzer::CanProve(const PrimExpr& expr) {
if (const auto* ptr = expr.as<ir::UIntImmNode>()) {
if (const auto* ptr = expr.as<IntImmNode>()) {
return ptr->value != 0;
}
auto res = this->rewrite_simplify(expr);
if (const auto* ptr = res.as<ir::UIntImmNode>()) {
if (const auto* ptr = res.as<IntImmNode>()) {
return ptr->value != 0;
}
res = this->canonical_simplify(expr);
if (const auto* ptr = res.as<ir::UIntImmNode>()) {
if (const auto* ptr = res.as<IntImmNode>()) {
return ptr->value != 0;
}
return false;
Expand Down
43 changes: 18 additions & 25 deletions src/arithmetic/const_fold.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,6 @@ inline bool IsIndexType(const DataType& type) {


#define TVM_ARITH_CONST_PROPAGATION(BODY) \
using ir::IntImmNode; \
using ir::UIntImmNode; \
using ir::FloatImmNode; \
const IntImmNode* pa = a.as<IntImmNode>(); \
const IntImmNode* pb = b.as<IntImmNode>(); \
Expand All @@ -87,8 +85,6 @@ inline bool IsIndexType(const DataType& type) {


#define TVM_INDEX_CONST_PROPAGATION(BODY) \
using ir::IntImmNode; \
using ir::UIntImmNode; \
const IntImmNode* pa = a.as<IntImmNode>(); \
const IntImmNode* pb = b.as<IntImmNode>(); \
const DataType& ta = a.dtype(); \
Expand Down Expand Up @@ -268,62 +264,61 @@ inline PrimExpr TryConstFold<ir::MaxNode>(PrimExpr a, PrimExpr b) {
template<>
inline PrimExpr TryConstFold<ir::GTNode>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value > pb->value);
if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value > fb->value);
if (pa && pb) return IntImm(DataType::UInt(1), pa->value > pb->value);
if (fa && fb) return IntImm(DataType::UInt(1), fa->value > fb->value);
});
return PrimExpr();
}

template<>
inline PrimExpr TryConstFold<ir::GENode>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value >= pb->value);
if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value >= fb->value);
if (pa && pb) return IntImm(DataType::UInt(1), pa->value >= pb->value);
if (fa && fb) return IntImm(DataType::UInt(1), fa->value >= fb->value);
});
return PrimExpr();
}

template<>
inline PrimExpr TryConstFold<ir::LTNode>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value < pb->value);
if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value < fb->value);
if (pa && pb) return IntImm(DataType::UInt(1), pa->value < pb->value);
if (fa && fb) return IntImm(DataType::UInt(1), fa->value < fb->value);
});
return PrimExpr();
}

template<>
inline PrimExpr TryConstFold<ir::LENode>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value <= pb->value);
if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value <= fb->value);
if (pa && pb) return IntImm(DataType::UInt(1), pa->value <= pb->value);
if (fa && fb) return IntImm(DataType::UInt(1), fa->value <= fb->value);
});
return PrimExpr();
}

template<>
inline PrimExpr TryConstFold<ir::EQNode>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value == pb->value);
if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value == fb->value);
if (pa && pb) return IntImm(DataType::UInt(1), pa->value == pb->value);
if (fa && fb) return IntImm(DataType::UInt(1), fa->value == fb->value);
});
return PrimExpr();
}

template<>
inline PrimExpr TryConstFold<ir::NENode>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value != pb->value);
if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value != fb->value);
if (pa && pb) return IntImm(DataType::UInt(1), pa->value != pb->value);
if (fa && fb) return IntImm(DataType::UInt(1), fa->value != fb->value);
});
return PrimExpr();
}

template<>
inline PrimExpr TryConstFold<ir::AndNode>(PrimExpr a, PrimExpr b) {
using ir::UIntImmNode;
const UIntImmNode* pa = a.as<UIntImmNode>();
const UIntImmNode* pb = b.as<UIntImmNode>();
const IntImmNode* pa = a.as<IntImmNode>();
const IntImmNode* pb = b.as<IntImmNode>();
if (pa && pa->value) return b;
if (pa && !pa->value) return a;
if (pb && pb->value) return a;
Expand All @@ -333,9 +328,8 @@ inline PrimExpr TryConstFold<ir::AndNode>(PrimExpr a, PrimExpr b) {

template<>
inline PrimExpr TryConstFold<ir::OrNode>(PrimExpr a, PrimExpr b) {
using ir::UIntImmNode;
const UIntImmNode* pa = a.as<UIntImmNode>();
const UIntImmNode* pb = b.as<UIntImmNode>();
const IntImmNode* pa = a.as<IntImmNode>();
const IntImmNode* pb = b.as<IntImmNode>();
if (pa && pa->value) return a;
if (pa && !pa->value) return b;
if (pb && pb->value) return b;
Expand All @@ -345,10 +339,9 @@ inline PrimExpr TryConstFold<ir::OrNode>(PrimExpr a, PrimExpr b) {

template<>
inline PrimExpr TryConstFold<ir::NotNode>(PrimExpr a) {
using ir::UIntImmNode;
const UIntImmNode* pa = a.as<UIntImmNode>();
const IntImmNode* pa = a.as<IntImmNode>();
if (pa) {
return UIntImmNode::make(DataType::UInt(1), !(pa->value));
return IntImm(DataType::UInt(1), !(pa->value));
}
return PrimExpr();
}
Expand Down
Loading

0 comments on commit 1d87946

Please sign in to comment.