Skip to content

Commit

Permalink
fix: select narrow dtype (apache#10519)
Browse files Browse the repository at this point in the history
  • Loading branch information
ganler authored and pfk-beta committed Apr 11, 2022
1 parent 17040a0 commit 2e4edcb
Showing 1 changed file with 18 additions and 0 deletions.
18 changes: 18 additions & 0 deletions src/tir/transforms/narrow_datatype.cc
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,24 @@ class DataTypeRewriter : public StmtExprMutator {
return StmtExprMutator::VisitExpr_(op);
}

PrimExpr VisitExpr_(const SelectNode* op) final {
PrimExpr condition = this->VisitExpr(op->condition);
PrimExpr true_value = this->VisitExpr(op->true_value);
PrimExpr false_value = this->VisitExpr(op->false_value);
if (condition.same_as(op->condition) && true_value.same_as(op->true_value) &&
false_value.same_as(op->false_value)) {
return GetRef<PrimExpr>(op);
} else {
if (op->true_value.dtype().is_int() && op->false_value.dtype().is_int()) {
int bits = std::max(true_value.dtype().bits(), false_value.dtype().bits());
DataType dtype = true_value.dtype().with_bits(bits);
if (true_value.dtype() != dtype) true_value = cast(dtype, true_value);
if (false_value.dtype() != dtype) false_value = cast(dtype, false_value);
}
return Select(condition, true_value, false_value);
}
}

PrimExpr VisitExpr_(const RampNode* op) final {
PrimExpr base = VisitExpr(op->base);
PrimExpr stride = VisitExpr(op->stride);
Expand Down

0 comments on commit 2e4edcb

Please sign in to comment.