diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc index de96f923e2fa..14fbc4d8401c 100644 --- a/src/target/source/codegen_opencl.cc +++ b/src/target/source/codegen_opencl.cc @@ -567,11 +567,7 @@ void CodeGenOpenCL::VisitExpr_(const SelectNode* op, std::ostream& os) { os << ", "; PrintExpr(op->condition, oss); if (op->dtype.is_float()) { - if (op->condition.dtype().is_uint() || op->condition.dtype().is_int()) { - os << oss.str(); - } else { - os << CastTo(oss.str(), DataType::Int(op->dtype.bits(), op->dtype.lanes())); - } + os << CastTo(oss.str(), DataType::Int(op->dtype.bits(), op->dtype.lanes())); } else { os << CastFromTo(oss.str(), op->condition.dtype(), op->dtype); } diff --git a/tests/python/unittest/test_target_codegen_opencl.py b/tests/python/unittest/test_target_codegen_opencl.py index bc2d0a84fd9d..67dc37363ea9 100644 --- a/tests/python/unittest/test_target_codegen_opencl.py +++ b/tests/python/unittest/test_target_codegen_opencl.py @@ -168,20 +168,34 @@ def check_type_casting(ctx, n, dtype): c = tvm.nd.empty((n,), dtype, ctx) assembly = fun.imported_modules[0].get_source() - false_branch = "((float4)(0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f))" - true_branch = "((float4)(1.000000e+00f, 1.000000e+00f, 1.000000e+00f, 1.000000e+00f))" - lcond = "(convert_uint4(((uint4)((((int)get_local_id(0)) == 3), (((int)get_local_id(0)) == 3), (((int)get_local_id(0)) == 3), (((int)get_local_id(0)) == 3)))))" - rcond = "(convert_uint4((((int4)((0)+(1*0), (0)+(1*1), (0)+(1*2), (0)+(1*3))) == ((int4)(3, 3, 3, 3)))))" - cond = "({} && {})".format(lcond, rcond) - select = "select({}, {}, {})".format(false_branch, true_branch, cond) - count = assembly.count(select) - assert count == 1 - fun(c) + if dtype == "float32": + false_branch = "((float4)(0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f))" + true_branch = "((float4)(1.000000e+00f, 1.000000e+00f, 1.000000e+00f, 1.000000e+00f))" + lcond = "convert_int4(((convert_uint4(((uint4)((((int)get_local_id(0)) == 3), (((int)get_local_id(0)) == 3), (((int)get_local_id(0)) == 3), (((int)get_local_id(0)) == 3)))))" + rcond = "(convert_uint4((((int4)((0)+(1*0), (0)+(1*1), (0)+(1*2), (0)+(1*3))) == ((int4)(3, 3, 3, 3)))))" + cond = "({} && {})".format(lcond, rcond) + select = "select({}, {}, {})".format(false_branch, true_branch, cond) + count = assembly.count(select) + assert count == 1 + fun(c) + + elif dtype == "float16": + false_branch = "((half4)((half)0.000000e+00f, (half)0.000000e+00f, (half)0.000000e+00f, (half)0.000000e+00f))" + true_branch = "((half4)((half)1.000000e+00f, (half)1.000000e+00f, (half)1.000000e+00f, (half)1.000000e+00f))" + lcond = "convert_short4(((convert_uint4(((uint4)((((int)get_local_id(0)) == 3), (((int)get_local_id(0)) == 3), (((int)get_local_id(0)) == 3), (((int)get_local_id(0)) == 3)))))" + rcond = "(convert_uint4((((int4)((0)+(1*0), (0)+(1*1), (0)+(1*2), (0)+(1*3))) == ((int4)(3, 3, 3, 3)))))))" + cond = "({} && {})".format(lcond, rcond) + select = "select({}, {}, {})".format(false_branch, true_branch, cond) + count = assembly.count(select) + assert count == 1 + fun(c) dev = tvm.device(target, 0) check_type_casting(dev, 16, "float32") + # fp16 is not yet supported in ci + # check_type_casting(dev, 16, "float16") if __name__ == "__main__":