Skip to content

Commit

Permalink
[CODEGEN][OPENCL] Explicitly cast min/max operands (apache#5090)
Browse files Browse the repository at this point in the history
* [CODEGEN][OPENCL] Explicitly cast min/max operands

* retrigger CI
  • Loading branch information
kazum authored and zhiics committed Apr 17, 2020
1 parent 5b4c2ef commit 717e217
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 22 deletions.
48 changes: 28 additions & 20 deletions src/target/source/codegen_opencl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -226,26 +226,6 @@ void CodeGenOpenCL::VisitExpr_(const BroadcastNode* op, std::ostream& os) { //
os << "))";
}

// Prints a call expression. For the tvm_if_then_else intrinsic an explicit
// cast to the dtype of the call's third argument is written first, because
// the result type of a ternary expression is not always the same as the type
// of its sub-expressions. The call itself is always printed by CodeGenC.
void CodeGenOpenCL::VisitExpr_(const CallNode *op, std::ostream& os) {  // NOLINT(*)
  const bool needs_cast = op->is_intrinsic(intrinsic::tvm_if_then_else);
  if (needs_cast) {
    os << "(";
    PrintType(op->args[2].dtype(), os);
    os << ")";
  }
  CodeGenC::VisitExpr_(op, os);
}

// Prints a select expression preceded by an explicit cast to the dtype of
// its true_value, because the result type of a ternary expression is not
// always the same as the type of its sub-expressions. The select itself is
// printed by CodeGenC.
void CodeGenOpenCL::VisitExpr_(const SelectNode* op, std::ostream& os) {  // NOLINT(*)
  const auto cast_type = op->true_value.dtype();
  os << "(";
  PrintType(cast_type, os);
  os << ")";
  CodeGenC::VisitExpr_(op, os);
}

void CodeGenOpenCL::VisitExpr_(const FloatImmNode *op, std::ostream& os) { // NOLINT(*)
if (std::isinf(op->value)) {
if (op->value < 0) {
Expand All @@ -259,6 +239,34 @@ void CodeGenOpenCL::VisitExpr_(const FloatImmNode *op, std::ostream& os) { // NO
}
}

// Prints a two-operand builtin call such as min/max.
//
// Vector expressions (lanes != 1) are delegated to PrintVecBinaryOp. For
// scalar expressions the output has the form
//   opstr((<type of a>)<a>, (<type of b>)<b>)
// i.e. each operand is cast explicitly to its own dtype.
//
// op    - binary node providing the fields dtype, a and b
// opstr - name of the function to emit
// os    - stream receiving the generated source
// p     - code generator used to print types and sub-expressions
template<typename T>
inline void PrintBinaryExpr(const T* op,
                            const char* opstr,
                            std::ostream& os,
                            CodeGenOpenCL* p) {
  // Anything that is not a scalar goes through the vector printer.
  if (op->dtype.lanes() != 1) {
    p->PrintVecBinaryOp(opstr, op->dtype, op->a, op->b, os);
    return;
  }

  // Scalar: function name, then both operands with their casts.
  os << opstr << "((";
  p->PrintType(op->a->dtype, os);
  os << ")";
  p->PrintExpr(op->a, os);

  os << ", (";
  p->PrintType(op->b->dtype, os);
  os << ")";
  p->PrintExpr(op->b, os);

  os << ')';
}

void CodeGenOpenCL::VisitExpr_(const MinNode *op, std::ostream& os) {
PrintBinaryExpr(op, "min", os, this);
}

void CodeGenOpenCL::VisitExpr_(const MaxNode *op, std::ostream& os) {
PrintBinaryExpr(op, "max", os, this);
}

runtime::Module BuildOpenCL(Array<LoweredFunc> funcs) {
using tvm::runtime::Registry;
bool output_ssa = false;
Expand Down
5 changes: 3 additions & 2 deletions src/target/source/codegen_opencl.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,10 @@ class CodeGenOpenCL final : public CodeGenC {

// overload visitor
void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const CallNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const SelectNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const FloatImmNode *op, std::ostream& os) final; // NOLINT(*)
// overload min and max to avoid ambiguous call errors
void VisitExpr_(const MinNode *op, std::ostream& os) final;
void VisitExpr_(const MaxNode *op, std::ostream& os) final;

private:
// whether enable fp16 and fp64 extension
Expand Down
29 changes: 29 additions & 0 deletions tests/python/unittest/test_target_codegen_opencl.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,35 @@ def check_inf_nan(ctx, n, value, dtype):
check_inf_nan(ctx, 1, float('nan'), 'float64')


def test_opencl_max():
    """Build and launch a kernel computing max(A[0] + 1, 0) for several dtypes.

    The kernel output is not checked; the test only needs the generated
    code to compile. It is skipped when the target runtime is not enabled.
    """
    def check_max(ctx, n, dtype):
        # Compute definition: every output element is max(A[0] + 1, 0).
        src = te.placeholder((n,), name='A', dtype=dtype)
        lhs = src[0] + tvm.tir.const(1, dtype=dtype)
        rhs = tvm.tir.const(0, dtype=dtype)
        dst = te.compute((n,), lambda i: tvm.te.max(lhs, rhs), name='C')

        # Bind the only axis to threadIdx.x and build for the target.
        sched = te.create_schedule(dst.op)
        sched[dst].bind(sched[dst].op.axis[0], te.thread_axis("threadIdx.x"))
        kernel = tvm.build(sched, [src, dst], target)

        in_buf = tvm.nd.empty((n,), src.dtype, ctx)
        out_buf = tvm.nd.empty((n,), src.dtype, ctx)
        # Only need to test compiling here
        kernel(in_buf, out_buf)

    if not tvm.runtime.enabled(target):
        print("skip because opencl is not enabled..")
        return

    ctx = tvm.context(target, 0)

    for dtype in ('int8', 'uint8', 'int16', 'uint16', 'float32', 'float64'):
        check_max(ctx, 1, dtype)


# Script entry point: run every OpenCL codegen test defined in this file.
if __name__ == "__main__":
    test_opencl_ternary_expression()
    test_opencl_inf_nan()
    # test_opencl_max was missing from this list, so running the file
    # directly skipped the min/max operand-cast test.
    test_opencl_max()

0 comments on commit 717e217

Please sign in to comment.