Skip to content

Commit

Permalink
[CODEGEN][OPENCL] Fix compile error about ternary expression. (apache…
Browse files Browse the repository at this point in the history
…#2821)

Code like this can't be built with NV OpenCL; it needs an explicit type
  conversion for the ternary expression if the return type is uchar.

       uchar i = 0, j = 0;
       uchar t = max((uchar)j, ((i > 0) ? (uchar)1 : (uchar)0));
  • Loading branch information
lixiaoquan authored and wweic committed Mar 20, 2019
1 parent 8969809 commit a3b703b
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 0 deletions.
19 changes: 19 additions & 0 deletions src/codegen/codegen_opencl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,25 @@ void CodeGenOpenCL::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOL
os << "))";
}

void CodeGenOpenCL::VisitExpr_(const Call *op, std::ostream& os) {  // NOLINT(*)
  // A ternary expression does not always end up with the same type as its
  // operands, so when this call is the tvm_if_then_else intrinsic an explicit
  // cast to the type of args[2] is printed in front of it. Every call, cast or
  // not, is then emitted by the base-class visitor.
  const bool needs_result_cast = op->is_intrinsic(intrinsic::tvm_if_then_else);
  if (needs_result_cast) {
    os << '(';
    PrintType(op->args[2].type(), os);
    os << ')';
  }
  CodeGenC::VisitExpr_(op, os);
}

void CodeGenOpenCL::VisitExpr_(const Select* op, std::ostream& os) {  // NOLINT(*)
  // A ternary expression does not always end up with the same type as its
  // operands, so the Select is always prefixed with an explicit cast to the
  // type of its true_value before the base-class visitor prints it.
  const auto& cast_type = op->true_value.type();
  os << '(';
  PrintType(cast_type, os);
  os << ')';
  CodeGenC::VisitExpr_(op, os);
}

runtime::Module BuildOpenCL(Array<LoweredFunc> funcs) {
using tvm::runtime::Registry;
Expand Down
2 changes: 2 additions & 0 deletions src/codegen/codegen_opencl.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ class CodeGenOpenCL final : public CodeGenC {

// overload visitor
void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*)
// Prints an explicit type cast in front of tvm_if_then_else calls, then
// defers to CodeGenC (definition in codegen_opencl.cc).
void VisitExpr_(const Call* op, std::ostream& os) final; // NOLINT(*)
// Prints an explicit type cast in front of every Select expression, then
// defers to CodeGenC (definition in codegen_opencl.cc).
void VisitExpr_(const Select* op, std::ostream& os) final; // NOLINT(*)

private:
// whether enable fp16 and fp64 extension
Expand Down
55 changes: 55 additions & 0 deletions tests/python/unittest/test_codegen_opencl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import tvm

target = 'opencl'

def test_opencl_ternary_expression():
    """Build and launch small OpenCL kernels that feed a ternary into max().

    For each of int8/uint8/int16/uint16, a kernel computing
    ``max(2, cond ? 1 : 3)`` is built twice: once with the ternary expressed
    through ``tvm.if_then_else`` and once through ``tvm.expr.Select``.
    The test is skipped when the OpenCL target is not enabled.
    """
    if not tvm.module.enabled(target):
        print("skip because opencl is not enabled..")
        return

    def build_and_launch(ctx, n, dtype, make_ternary):
        # Shared body of both checks; `make_ternary` builds the
        # (condition, true_value, false_value) expression under test.
        A = tvm.placeholder((n,), name='A', dtype=dtype)
        true_value = tvm.const(1, dtype=dtype)
        false_value = tvm.const(3, dtype=dtype)
        max_lhs = tvm.const(2, dtype=dtype)
        max_rhs = make_ternary(A[0] > 0, true_value, false_value)
        C = tvm.compute((n,), lambda i: tvm.max(max_lhs, max_rhs), name='C')
        s = tvm.create_schedule(C.op)
        s[C].bind(s[C].op.axis[0], tvm.thread_axis("threadIdx.x"))
        fun = tvm.build(s, [A, C], target)

        a = tvm.nd.empty((n,), A.dtype, ctx)
        c = tvm.nd.empty((n,), A.dtype, ctx)
        # Only compilation is under test, so the output is not inspected.
        fun(a, c)

    def check_if_then_else(ctx, n, dtype):
        build_and_launch(ctx, n, dtype, tvm.if_then_else)

    def check_select(ctx, n, dtype):
        build_and_launch(ctx, n, dtype, tvm.expr.Select)

    ctx = tvm.context(target, 0)

    # All if_then_else cases run first, then all Select cases.
    for check in (check_if_then_else, check_select):
        for dtype in ('int8', 'uint8', 'int16', 'uint16'):
            check(ctx, 1, dtype)


# Allow running this test file directly as a script.
if __name__ == "__main__":
test_opencl_ternary_expression()

0 comments on commit a3b703b

Please sign in to comment.