diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc index 0f5b49210791..437ff6def4ba 100644 --- a/src/target/source/codegen_opencl.cc +++ b/src/target/source/codegen_opencl.cc @@ -226,26 +226,6 @@ void CodeGenOpenCL::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // os << "))"; } -void CodeGenOpenCL::VisitExpr_(const CallNode *op, std::ostream& os) { // NOLINT(*) - /* Return type of ternary expression is not always same as its sub-expressions, - * add a cast */ - if (op->is_intrinsic(intrinsic::tvm_if_then_else)) { - os << "("; - PrintType(op->args[2].dtype(), os); - os << ")"; - } - CodeGenC::VisitExpr_(op, os); -} - -void CodeGenOpenCL::VisitExpr_(const SelectNode* op, std::ostream& os) { // NOLINT(*) - /* Return type of ternary expression is not always same as its sub-expressions, - * add a cast */ - os << "("; - PrintType(op->true_value.dtype(), os); - os << ")"; - CodeGenC::VisitExpr_(op, os); -} - void CodeGenOpenCL::VisitExpr_(const FloatImmNode *op, std::ostream& os) { // NOLINT(*) if (std::isinf(op->value)) { if (op->value < 0) { @@ -259,6 +239,34 @@ void CodeGenOpenCL::VisitExpr_(const FloatImmNode *op, std::ostream& os) { // NO } } +template +inline void PrintBinaryExpr(const T* op, + const char* opstr, + std::ostream& os, + CodeGenOpenCL* p) { + if (op->dtype.lanes() == 1) { + os << opstr << "(("; + p->PrintType(op->a->dtype, os); + os << ")"; + p->PrintExpr(op->a, os); + os << ", ("; + p->PrintType(op->b->dtype, os); + os << ")"; + p->PrintExpr(op->b, os); + os << ')'; + } else { + p->PrintVecBinaryOp(opstr, op->dtype, op->a, op->b, os); + } +} + +void CodeGenOpenCL::VisitExpr_(const MinNode *op, std::ostream& os) { + PrintBinaryExpr(op, "min", os, this); +} + +void CodeGenOpenCL::VisitExpr_(const MaxNode *op, std::ostream& os) { + PrintBinaryExpr(op, "max", os, this); +} + runtime::Module BuildOpenCL(Array funcs) { using tvm::runtime::Registry; bool output_ssa = false; diff --git a/src/target/source/codegen_opencl.h b/src/target/source/codegen_opencl.h index 52c4c58aa8dc..9f1c7f4c3044 100644 --- a/src/target/source/codegen_opencl.h +++ b/src/target/source/codegen_opencl.h @@ -55,9 +55,10 @@ class CodeGenOpenCL final : public CodeGenC { // overload visitor void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*) - void VisitExpr_(const CallNode* op, std::ostream& os) final; // NOLINT(*) - void VisitExpr_(const SelectNode* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const FloatImmNode *op, std::ostream& os) final; // NOLINT(*) + // overload min and max to avoid ambiguous call errors + void VisitExpr_(const MinNode *op, std::ostream& os) final; + void VisitExpr_(const MaxNode *op, std::ostream& os) final; private: // whether enable fp16 and fp64 extension diff --git a/tests/python/unittest/test_target_codegen_opencl.py b/tests/python/unittest/test_target_codegen_opencl.py index 140e1f6fbdea..e403589dff1d 100644 --- a/tests/python/unittest/test_target_codegen_opencl.py +++ b/tests/python/unittest/test_target_codegen_opencl.py @@ -94,6 +94,35 @@ def check_inf_nan(ctx, n, value, dtype): check_inf_nan(ctx, 1, float('nan'), 'float64') +def test_opencl_max(): + def check_max(ctx, n, dtype): + A = te.placeholder((n,), name='A', dtype=dtype) + max_lhs = A[0] + tvm.tir.const(1, dtype=dtype) + max_rhs = tvm.tir.const(0, dtype=dtype) + C = te.compute((n,), lambda i: tvm.te.max(max_lhs, max_rhs), name='C') + s = te.create_schedule(C.op) + s[C].bind(s[C].op.axis[0], te.thread_axis("threadIdx.x")) + fun = tvm.build(s, [A, C], target) + + a = tvm.nd.empty((n,), A.dtype, ctx) + c = tvm.nd.empty((n,), A.dtype, ctx) + # Only need to test compiling here + fun(a, c) + + if not tvm.runtime.enabled(target): + print("skip because opencl is not enabled..") + return + + ctx = tvm.context(target, 0) + + check_max(ctx, 1, 'int8') + check_max(ctx, 1, 'uint8') + check_max(ctx, 1, 'int16') + check_max(ctx, 1, 'uint16') + check_max(ctx, 1, 'float32') + check_max(ctx, 1, 'float64') + + if __name__ == "__main__": test_opencl_ternary_expression() test_opencl_inf_nan()