[CODEGEN][CUDA][OPENCL] Handle INF and NAN #3194

Merged 1 commit on May 17, 2019
19 changes: 17 additions & 2 deletions src/codegen/codegen_cuda.cc
@@ -57,6 +57,10 @@ std::string CodeGenCUDA::Finish() {
     decl_stream << "#include <sm_61_intrinsics.h>\n";
   }
 
+  if (need_math_constants_h_) {
+    decl_stream << "#include <math_constants.h>\n";
+  }
+
   return CodeGenC::Finish();
 }

@@ -318,8 +322,19 @@ inline void PrintConst(const FloatImm* op, std::ostream& os, CodeGenCUDA* p) { // NOLINT(*)
   switch (op->type.bits()) {
     case 64: case 32: {
       std::ostringstream temp;
-      temp << std::scientific << op->value;
-      if (op->type.bits() == 32) temp << 'f';
+      if (std::isinf(op->value)) {
+        if (op->value < 0) {
+          temp << "-";
+        }
+        temp << ((op->type.bits() == 32) ? "CUDART_INF_F" : "CUDART_INF");
+        p->need_math_constants_h_ = true;
+      } else if (std::isnan(op->value)) {
+        temp << ((op->type.bits() == 32) ? "CUDART_NAN_F" : "CUDART_NAN");
+        p->need_math_constants_h_ = true;
+      } else {
+        temp << std::scientific << op->value;
+        if (op->type.bits() == 32) temp << 'f';
+      }
       p->MarkConst(temp.str());
       os << temp.str();
       break;
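For illustration (not part of this diff), a minimal Python sketch of how the new CUDA behavior can be exercised, mirroring the unit test below. It assumes get_source() on the imported module returns the generated CUDA C, and that a CUDA toolchain is installed:

```python
import tvm

# Sketch: build a kernel that stores +inf and inspect the generated CUDA
# source. After this PR, the float32 literal should appear as CUDART_INF_F
# and the preamble should gain "#include <math_constants.h>".
n = 4
A = tvm.placeholder((n,), name='A', dtype='float32')
C = tvm.compute((n,), lambda i: tvm.const(float('inf'), dtype='float32'), name='C')
s = tvm.create_schedule(C.op)
s[C].bind(s[C].op.axis[0], tvm.thread_axis("threadIdx.x"))
fun = tvm.build(s, [A, C], 'cuda')
cuda_src = fun.imported_modules[0].get_source()
assert "CUDART_INF_F" in cuda_src
assert "#include <math_constants.h>" in cuda_src
```
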
7 changes: 6 additions & 1 deletion src/codegen/codegen_cuda.h
@@ -39,7 +39,9 @@ class CodeGenCUDA final : public CodeGenC {
   void Init(bool output_ssa);
   void AddFunction(LoweredFunc f);
   std::string Finish();
-  bool need_include_path() { return (enable_fp16_ || enable_int8_); }
+  bool need_include_path() {
+    return (enable_fp16_ || enable_int8_ || need_math_constants_h_);
+  }
   // override behavior
   void VisitStmt_(const ir::For* op) final;
   void PrintStorageSync(const Call* op) final;
@@ -70,6 +72,9 @@
   bool enable_fp16_{false};
   // whether enable int8
   bool enable_int8_{false};
+  // whether need math_constants.h
+  bool need_math_constants_h_{false};
+  friend void PrintConst(const FloatImm* op, std::ostream& os, CodeGenCUDA* p);
 };
 
 }  // namespace codegen
13 changes: 13 additions & 0 deletions src/codegen/codegen_opencl.cc
@@ -247,6 +247,19 @@ void CodeGenOpenCL::VisitExpr_(const Select* op, std::ostream& os) { // NOLINT(*)
   CodeGenC::VisitExpr_(op, os);
 }
 
+void CodeGenOpenCL::VisitExpr_(const FloatImm *op, std::ostream& os) { // NOLINT(*)
+  if (std::isinf(op->value)) {
+    if (op->value < 0) {
+      os << "-";
+    }
+    os << "INFINITY";
+  } else if (std::isnan(op->value)) {
+    os << "NAN";
+  } else {
+    CodeGenC::VisitExpr_(op, os);
+  }
+}
+
 runtime::Module BuildOpenCL(Array<LoweredFunc> funcs) {
   using tvm::runtime::Registry;
   bool output_ssa = false;
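A matching sketch (again, not part of this diff) for the OpenCL path. INFINITY and NAN are standard OpenCL C macros, so unlike CUDA no extra header needs to be emitted; this assumes TVM was built with OpenCL support:

```python
import tvm

# Sketch: the same pattern against the OpenCL backend. The NaN literal
# should be printed as the standard OpenCL C macro NAN.
n = 4
A = tvm.placeholder((n,), name='A', dtype='float32')
C = tvm.compute((n,), lambda i: tvm.const(float('nan'), dtype='float32'), name='C')
s = tvm.create_schedule(C.op)
s[C].bind(s[C].op.axis[0], tvm.thread_axis("threadIdx.x"))
fun = tvm.build(s, [A, C], 'opencl')
assert "NAN" in fun.imported_modules[0].get_source()
```
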
1 change: 1 addition & 0 deletions src/codegen/codegen_opencl.h
@@ -59,6 +59,7 @@ class CodeGenOpenCL final : public CodeGenC {
   void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*)
   void VisitExpr_(const Call* op, std::ostream& os) final; // NOLINT(*)
   void VisitExpr_(const Select* op, std::ostream& os) final; // NOLINT(*)
+  void VisitExpr_(const FloatImm *op, std::ostream& os) final; // NOLINT(*)
 
  private:
   // whether enable fp16 and fp64 extension
30 changes: 30 additions & 0 deletions tests/python/unittest/test_codegen_cuda.py
@@ -125,8 +125,38 @@ def check_cuda(n, value):
     check_cuda(64, 0)
     check_cuda(64, -3)
 
+
+def test_cuda_inf_nan():
+    target = 'cuda'
+    def check_inf_nan(ctx, n, value, dtype):
+        A = tvm.placeholder((n,), name='A', dtype=dtype)
+        inf_value = tvm.const(value, dtype=dtype)
+        C = tvm.compute((n,), lambda i: inf_value, name='C')
+        s = tvm.create_schedule(C.op)
+        s[C].bind(s[C].op.axis[0], tvm.thread_axis("threadIdx.x"))
+        fun = tvm.build(s, [A, C], target)
+        a = tvm.nd.empty((n,), A.dtype, ctx)
+        c = tvm.nd.empty((n,), A.dtype, ctx)
+        # Only need to test compiling here
+        fun(a, c)
+
+    if not tvm.gpu(0).exist or not tvm.module.enabled("cuda"):
+        print("skip because cuda is not enabled..")
+        return
+
+    ctx = tvm.context(target, 0)
+
+    check_inf_nan(ctx, 1, -float('inf'), 'float32')
+    check_inf_nan(ctx, 1, -float('inf'), 'float64')
+    check_inf_nan(ctx, 1, float('inf'), 'float32')
+    check_inf_nan(ctx, 1, float('inf'), 'float64')
+    check_inf_nan(ctx, 1, float('nan'), 'float32')
+    check_inf_nan(ctx, 1, float('nan'), 'float64')
+
+
 if __name__ == "__main__":
     test_cuda_vectorize_add()
     test_cuda_multiply_add()
     test_cuda_vectorize_load()
     test_cuda_make_int8x4()
+    test_cuda_inf_nan()
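
The test deliberately stops at compilation. A hypothetical value check that could be layered on top, assuming an actual CUDA device is present (the numpy comparison is illustrative, not from this PR):

```python
import numpy as np
import tvm

# Hypothetical follow-up to check_inf_nan: run the kernel and verify the
# output really is infinite, not just that the source compiles.
ctx = tvm.gpu(0)
n = 1
A = tvm.placeholder((n,), name='A', dtype='float32')
C = tvm.compute((n,), lambda i: tvm.const(float('inf'), dtype='float32'), name='C')
s = tvm.create_schedule(C.op)
s[C].bind(s[C].op.axis[0], tvm.thread_axis("threadIdx.x"))
fun = tvm.build(s, [A, C], 'cuda')
a = tvm.nd.empty((n,), 'float32', ctx)
c = tvm.nd.empty((n,), 'float32', ctx)
fun(a, c)
assert np.isinf(c.asnumpy()).all()
```
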
27 changes: 27 additions & 0 deletions tests/python/unittest/test_codegen_opencl.py
@@ -66,6 +66,33 @@ def check_select(ctx, n, dtype):
     check_select(ctx, 1, 'int16')
     check_select(ctx, 1, 'uint16')
 
+def test_opencl_inf_nan():
+    def check_inf_nan(ctx, n, value, dtype):
+        A = tvm.placeholder((n,), name='A', dtype=dtype)
+        inf_value = tvm.const(value, dtype=dtype)
+        C = tvm.compute((n,), lambda i: inf_value, name='C')
+        s = tvm.create_schedule(C.op)
+        s[C].bind(s[C].op.axis[0], tvm.thread_axis("threadIdx.x"))
+        fun = tvm.build(s, [A, C], target)
+        a = tvm.nd.empty((n,), A.dtype, ctx)
+        c = tvm.nd.empty((n,), A.dtype, ctx)
+        # Only need to test compiling here
+        fun(a, c)
+
+    if not tvm.module.enabled(target):
+        print("skip because opencl is not enabled..")
+        return
+
+    ctx = tvm.context(target, 0)
+
+    check_inf_nan(ctx, 1, -float('inf'), 'float32')
+    check_inf_nan(ctx, 1, -float('inf'), 'float64')
+    check_inf_nan(ctx, 1, float('inf'), 'float32')
+    check_inf_nan(ctx, 1, float('inf'), 'float64')
+    check_inf_nan(ctx, 1, float('nan'), 'float32')
+    check_inf_nan(ctx, 1, float('nan'), 'float64')
+
 
 if __name__ == "__main__":
     test_opencl_ternary_expression()
+    test_opencl_inf_nan()