[CODEGEN][CUDA][OPENCL] Handle INF and NAN (apache#3194)
lixiaoquan authored and wweic committed Jun 27, 2019
1 parent 3ca8f37 commit 50caa7b
Showing 6 changed files with 94 additions and 3 deletions.
19 changes: 17 additions & 2 deletions src/codegen/codegen_cuda.cc
@@ -57,6 +57,10 @@ std::string CodeGenCUDA::Finish() {
decl_stream << "#include <sm_61_intrinsics.h>\n";
}

if (need_math_constants_h_) {
decl_stream << "#include <math_constants.h>\n";
}

return CodeGenC::Finish();
}

@@ -318,8 +322,19 @@ inline void PrintConst(const FloatImm* op, std::ostream& os, CodeGenCUDA* p) { // NOLINT(*)
switch (op->type.bits()) {
case 64: case 32: {
std::ostringstream temp;
temp << std::scientific << op->value;
if (op->type.bits() == 32) temp << 'f';
if (std::isinf(op->value)) {
if (op->value < 0) {
temp << "-";
}
temp << ((op->type.bits() == 32) ? "CUDART_INF_F" : "CUDART_INF");
p->need_math_constants_h_ = true;
} else if (std::isnan(op->value)) {
temp << ((op->type.bits() == 32) ? "CUDART_NAN_F" : "CUDART_NAN");
p->need_math_constants_h_ = true;
} else {
temp << std::scientific << op->value;
if (op->type.bits() == 32) temp << 'f';
}
p->MarkConst(temp.str());
os << temp.str();
break;
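With this change, a float32 ±inf or NaN constant is emitted as CUDART_INF_F / CUDART_NAN_F (CUDART_INF / CUDART_NAN for float64) and math_constants.h is included, instead of the previous std::scientific rendering ("inf" / "nan" plus the 'f' suffix), which nvcc does not accept as a literal. A minimal sketch, not part of this commit, of how the new output can be inspected from Python, assuming TVM is built with CUDA enabled:

import tvm

n = 16
A = tvm.placeholder((n,), name='A', dtype='float32')
C = tvm.compute((n,), lambda i: tvm.const(float('inf'), 'float32'), name='C')
s = tvm.create_schedule(C.op)
s[C].bind(s[C].op.axis[0], tvm.thread_axis("threadIdx.x"))
fun = tvm.build(s, [A, C], 'cuda')
# The imported device module carries the generated CUDA C source; with this
# patch it should reference CUDART_INF_F and pull in math_constants.h.
cuda_src = fun.imported_modules[0].get_source()
assert "#include <math_constants.h>" in cuda_src
assert "CUDART_INF_F" in cuda_src
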
7 changes: 6 additions & 1 deletion src/codegen/codegen_cuda.h
@@ -39,7 +39,9 @@ class CodeGenCUDA final : public CodeGenC {
void Init(bool output_ssa);
void AddFunction(LoweredFunc f);
std::string Finish();
bool need_include_path() { return (enable_fp16_ || enable_int8_); }
bool need_include_path() {
return (enable_fp16_ || enable_int8_ || need_math_constants_h_);
}
// override behavior
void VisitStmt_(const ir::For* op) final;
void PrintStorageSync(const Call* op) final;
@@ -70,6 +72,9 @@
bool enable_fp16_{false};
// whether enable int8
bool enable_int8_{false};
// whether need math_constants.h
bool need_math_constants_h_{false};
friend void PrintConst(const FloatImm* op, std::ostream& os, CodeGenCUDA* p);
};

} // namespace codegen
13 changes: 13 additions & 0 deletions src/codegen/codegen_opencl.cc
@@ -247,6 +247,19 @@ void CodeGenOpenCL::VisitExpr_(const Select* op, std::ostream& os) { // NOLINT(*)
CodeGenC::VisitExpr_(op, os);
}

void CodeGenOpenCL::VisitExpr_(const FloatImm *op, std::ostream& os) { // NOLINT(*)
if (std::isinf(op->value)) {
if (op->value < 0) {
os << "-";
}
os << "INFINITY";
} else if (std::isnan(op->value)) {
os << "NAN";
} else {
CodeGenC::VisitExpr_(op, os);
}
}

runtime::Module BuildOpenCL(Array<LoweredFunc> funcs) {
using tvm::runtime::Registry;
bool output_ssa = false;
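The OpenCL backend instead relies on the standard OpenCL C macros INFINITY and NAN, so no extra header is needed. A parallel sketch, illustrative only and assuming TVM is built with OpenCL enabled:

import tvm

n = 16
A = tvm.placeholder((n,), name='A', dtype='float32')
C = tvm.compute((n,), lambda i: tvm.const(float('nan'), 'float32'), name='C')
s = tvm.create_schedule(C.op)
s[C].bind(s[C].op.axis[0], tvm.thread_axis("threadIdx.x"))
fun = tvm.build(s, [A, C], 'opencl')
# The device module holds the generated OpenCL C source; NAN should appear verbatim.
assert "NAN" in fun.imported_modules[0].get_source()
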
1 change: 1 addition & 0 deletions src/codegen/codegen_opencl.h
@@ -59,6 +59,7 @@ class CodeGenOpenCL final : public CodeGenC {
void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const Call* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const Select* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const FloatImm *op, std::ostream& os) final; // NOLINT(*)

private:
// whether enable fp16 and fp64 extension
30 changes: 30 additions & 0 deletions tests/python/unittest/test_codegen_cuda.py
@@ -125,8 +125,38 @@ def check_cuda(n, value):
check_cuda(64, 0)
check_cuda(64, -3)


def test_cuda_inf_nan():
target = 'cuda'
def check_inf_nan(ctx, n, value, dtype):
A = tvm.placeholder((n,), name='A', dtype=dtype)
inf_value = tvm.const(value, dtype=dtype)
C = tvm.compute((n,), lambda i: inf_value, name='C')
s = tvm.create_schedule(C.op)
s[C].bind(s[C].op.axis[0], tvm.thread_axis("threadIdx.x"))
fun = tvm.build(s, [A, C], target)
a = tvm.nd.empty((n,), A.dtype, ctx)
c = tvm.nd.empty((n,), A.dtype, ctx)
# Only need to test compiling here
fun(a, c)

if not tvm.gpu(0).exist or not tvm.module.enabled("cuda"):
print("skip because cuda is not enabled..")
return

ctx = tvm.context(target, 0)

check_inf_nan(ctx, 1, -float('inf'), 'float32')
check_inf_nan(ctx, 1, -float('inf'), 'float64')
check_inf_nan(ctx, 1, float('inf'), 'float32')
check_inf_nan(ctx, 1, float('inf'), 'float64')
check_inf_nan(ctx, 1, float('nan'), 'float32')
check_inf_nan(ctx, 1, float('nan'), 'float64')


if __name__ == "__main__":
test_cuda_vectorize_add()
test_cuda_multiply_add()
test_cuda_vectorize_load()
test_cuda_make_int8x4()
test_cuda_inf_nan()
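
The new test only checks that the kernel compiles and launches. A hedged sketch of a follow-up numerical check, not part of this commit, verifying that the special values also round-trip to the host (check_inf_nan_values is a hypothetical helper):

import numpy as np
import tvm

def check_inf_nan_values(ctx, n, value, dtype):
    # Same schedule as check_inf_nan above, but also validate the output values.
    A = tvm.placeholder((n,), name='A', dtype=dtype)
    C = tvm.compute((n,), lambda i: tvm.const(value, dtype=dtype), name='C')
    s = tvm.create_schedule(C.op)
    s[C].bind(s[C].op.axis[0], tvm.thread_axis("threadIdx.x"))
    fun = tvm.build(s, [A, C], 'cuda')
    a = tvm.nd.empty((n,), A.dtype, ctx)
    c = tvm.nd.empty((n,), A.dtype, ctx)
    fun(a, c)
    out = c.asnumpy()
    if np.isnan(value):
        assert np.all(np.isnan(out))
    else:
        assert np.all(np.isinf(out))
        assert np.all(np.sign(out) == np.sign(value))
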
27 changes: 27 additions & 0 deletions tests/python/unittest/test_codegen_opencl.py
@@ -66,6 +66,33 @@ def check_select(ctx, n, dtype):
check_select(ctx, 1, 'int16')
check_select(ctx, 1, 'uint16')

def test_opencl_inf_nan():
def check_inf_nan(ctx, n, value, dtype):
A = tvm.placeholder((n,), name='A', dtype=dtype)
inf_value = tvm.const(value, dtype=dtype)
C = tvm.compute((n,), lambda i: inf_value, name='C')
s = tvm.create_schedule(C.op)
s[C].bind(s[C].op.axis[0], tvm.thread_axis("threadIdx.x"))
fun = tvm.build(s, [A, C], target)
a = tvm.nd.empty((n,), A.dtype, ctx)
c = tvm.nd.empty((n,), A.dtype, ctx)
# Only need to test compiling here
fun(a, c)

if not tvm.module.enabled(target):
print("skip because opencl is not enabled..")
return

ctx = tvm.context(target, 0)

check_inf_nan(ctx, 1, -float('inf'), 'float32')
check_inf_nan(ctx, 1, -float('inf'), 'float64')
check_inf_nan(ctx, 1, float('inf'), 'float32')
check_inf_nan(ctx, 1, float('inf'), 'float64')
check_inf_nan(ctx, 1, float('nan'), 'float32')
check_inf_nan(ctx, 1, float('nan'), 'float64')


if __name__ == "__main__":
test_opencl_ternary_expression()
test_opencl_inf_nan()
