Skip to content

Commit

Permalink
[BugFix][Opencl] Explicitly cast min/max operands (#9374)
Browse files Browse the repository at this point in the history
* [BugFix][Opencl] Explicitly cast min/max operands

* enable test_opencl_max
  • Loading branch information
hope51607 authored Oct 27, 2021
1 parent 9315113 commit 37a8d7b
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 0 deletions.
25 changes: 25 additions & 0 deletions src/target/source/codegen_opencl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,31 @@ void CodeGenOpenCL::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // N
}
}

template <typename T>
inline void PrintBinaryExpr(const T* op, const char* opstr, std::ostream& os, CodeGenOpenCL* p) {
if (op->dtype.lanes() == 1) {
os << opstr << "((";
p->PrintType(op->a->dtype, os);
os << ")";
p->PrintExpr(op->a, os);
os << ", (";
p->PrintType(op->b->dtype, os);
os << ")";
p->PrintExpr(op->b, os);
os << ')';
} else {
p->PrintVecBinaryOp(opstr, op->dtype, op->a, op->b, os);
}
}

void CodeGenOpenCL::VisitExpr_(const MinNode* op, std::ostream& os) {
PrintBinaryExpr(op, "min", os, this);
}

void CodeGenOpenCL::VisitExpr_(const MaxNode* op, std::ostream& os) {
PrintBinaryExpr(op, "max", os, this);
}

void CodeGenOpenCL::SetTextureScope(
const std::unordered_map<const VarNode*, std::string>& scope) { // NOLINT(*)
for (auto& texture : scope) {
Expand Down
4 changes: 4 additions & 0 deletions src/target/source/codegen_opencl.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ class CodeGenOpenCL final : public CodeGenC {
void VisitExpr_(const FloatImmNode* op, std::ostream& os) final; // NOLINT(*)
void VisitStmt_(const StoreNode* op) final; // NOLINT(*)

// overload min and max to avoid ambiguous call errors
void VisitExpr_(const MinNode* op, std::ostream& os) final;
void VisitExpr_(const MaxNode* op, std::ostream& os) final;

private:
// whether enable fp16 and fp64 extension
bool enable_fp16_{false};
Expand Down
1 change: 1 addition & 0 deletions tests/python/unittest/test_target_codegen_opencl.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,4 +142,5 @@ def check_erf(dev, n, dtype):
if __name__ == "__main__":
test_opencl_ternary_expression()
test_opencl_inf_nan()
test_opencl_max()
test_opencl_erf()

0 comments on commit 37a8d7b

Please sign in to comment.