diff --git a/src/tir/transforms/narrow_datatype.cc b/src/tir/transforms/narrow_datatype.cc index d5d145653fa3..c2bf27393173 100644 --- a/src/tir/transforms/narrow_datatype.cc +++ b/src/tir/transforms/narrow_datatype.cc @@ -293,6 +293,24 @@ class DataTypeRewriter : public StmtExprMutator { return StmtExprMutator::VisitExpr_(op); } + PrimExpr VisitExpr_(const SelectNode* op) final { + PrimExpr condition = this->VisitExpr(op->condition); + PrimExpr true_value = this->VisitExpr(op->true_value); + PrimExpr false_value = this->VisitExpr(op->false_value); + if (condition.same_as(op->condition) && true_value.same_as(op->true_value) && + false_value.same_as(op->false_value)) { + return GetRef(op); + } else { + if (op->true_value.dtype().is_int() && op->false_value.dtype().is_int()) { + int bits = std::max(true_value.dtype().bits(), false_value.dtype().bits()); + DataType dtype = true_value.dtype().with_bits(bits); + if (true_value.dtype() != dtype) true_value = cast(dtype, true_value); + if (false_value.dtype() != dtype) false_value = cast(dtype, false_value); + } + return Select(condition, true_value, false_value); + } + } + PrimExpr VisitExpr_(const RampNode* op) final { PrimExpr base = VisitExpr(op->base); PrimExpr stride = VisitExpr(op->stride);