Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Codegen] remove fp16 function override for cuda #4331

Merged
merged 2 commits into from
Nov 14, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 13 additions & 9 deletions src/codegen/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,19 @@ std::string CodeGenCUDA::Finish() {
<< "{\n return __hgt(__half(a), __half(b)) ? a : b;\n}\n";
decl_stream << "__device__ half min(half a, half b)\n"
<< "{\n return __hlt(__half(a), __half(b)) ? a : b;\n}\n";
decl_stream << "__device__ half operator<="
<< "(__half a, __half b)\n"
<< "{\n return __hlt(a, b);\n}\n";
decl_stream << "__device__ half operator+"
<< "(__half a, __half &b)\n"
<<"{\n return __hadd(a, b);\n}\n";
decl_stream << "__device__ half operator*"
<< "(__half a, __half b)\n"
<< "{\n return __hmul(a, b);\n}\n";
// FIXME(tvm-team): "volatile" is used to enable cross thread reduction,
// which is needed by operations such as softmax.
// However, volatile overloading is not supported in NVRTC and CUDA < 9.2.
// We need to figure out a solution which can satisfy both scenario.
// decl_stream << "__device__ half operator<="
// << "(const volatile __half &a, const volatile __half &b)\n"
// << "{\n return __hlt(a, b);\n}\n";
// decl_stream << "__device__ half operator+"
// << "(const volatile __half &a, const volatile __half &b)\n"
// <<"{\n return __hadd(a, b);\n}\n";
// decl_stream << "__device__ half operator*"
// << "(const volatile __half &a, const volatile __half &b)\n"
// << "{\n return __hmul(a, b);\n}\n";
// otherwise simulate computation via float32
decl_stream << "#else\n";
decl_stream << _cuda_half_t_def;
Expand Down
3 changes: 2 additions & 1 deletion src/codegen/literal/cuda_half_t.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
static constexpr const char* _cuda_half_t_def = R"(
typedef unsigned short uint16_t;
typedef unsigned char uint8_t;
typedef signed char int8_t;
typedef int int32_t;
typedef unsigned long long uint64_t;
typedef unsigned int uint32_t;
Expand Down Expand Up @@ -76,7 +77,7 @@ class TVM_ALIGNED(2) half {
TVM_XINLINE explicit half(const uint8_t& value) { constructor(value); }
TVM_XINLINE explicit half(const int32_t& value) { constructor(value); }
TVM_XINLINE explicit half(const uint32_t& value) { constructor(value); }
TVM_XINLINE explicit half(const int64_t& value) { constructor(value); }
TVM_XINLINE explicit half(const long long& value) { constructor(value); }
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why this change?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do I need to #include <cstdint> for int64_t ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's okay to use long long as they are the same for cuda

TVM_XINLINE explicit half(const uint64_t& value) { constructor(value); }

TVM_XINLINE operator float() const { \
Expand Down