From 33d5240da3fdcd727341385ebd59e1f6cd4f46df Mon Sep 17 00:00:00 2001 From: ganler Date: Mon, 7 Mar 2022 13:59:08 -0600 Subject: [PATCH] fix: select narrow dtype --- src/tir/transforms/narrow_datatype.cc | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/src/tir/transforms/narrow_datatype.cc b/src/tir/transforms/narrow_datatype.cc index dd5f54e52455..3bd47db1990f 100644 --- a/src/tir/transforms/narrow_datatype.cc +++ b/src/tir/transforms/narrow_datatype.cc @@ -253,6 +253,24 @@ class DataTypeRewriter : public StmtExprMutator { return StmtExprMutator::VisitExpr_(op); } + PrimExpr VisitExpr_(const SelectNode* op) final { + PrimExpr condition = this->VisitExpr(op->condition); + PrimExpr true_value = this->VisitExpr(op->true_value); + PrimExpr false_value = this->VisitExpr(op->false_value); + if (condition.same_as(op->condition) && true_value.same_as(op->true_value) && + false_value.same_as(op->false_value)) { + return GetRef(op); + } else { + if (op->true_value.dtype().is_int() && op->false_value.dtype().is_int()) { + int bits = std::max(true_value.dtype().bits(), false_value.dtype().bits()); + DataType dtype = true_value.dtype().with_bits(bits); + if (true_value.dtype() != dtype) true_value = cast(dtype, true_value); + if (false_value.dtype() != dtype) false_value = cast(dtype, false_value); + } + return Select(condition, true_value, false_value); + } + } + PrimExpr VisitExpr_(const RampNode* op) final { PrimExpr base = VisitExpr(op->base); PrimExpr stride = VisitExpr(op->stride);