Skip to content

Commit

Permalink
[Codegen] remove fp16 function override for cuda (apache#4331)
Browse files Browse the repository at this point in the history
* add volatile override back

* [codegen] remove fp16 function override for cuda
  • Loading branch information
yzhliu authored and Xingyu Zhou committed Nov 15, 2019
1 parent d763a68 commit 464da21
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 10 deletions.
22 changes: 13 additions & 9 deletions src/codegen/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,19 @@ std::string CodeGenCUDA::Finish() {
<< "{\n return __hgt(__half(a), __half(b)) ? a : b;\n}\n";
decl_stream << "__device__ half min(half a, half b)\n"
<< "{\n return __hlt(__half(a), __half(b)) ? a : b;\n}\n";
decl_stream << "__device__ half operator<="
<< "(__half a, __half b)\n"
<< "{\n return __hlt(a, b);\n}\n";
decl_stream << "__device__ half operator+"
<< "(__half a, __half &b)\n"
<<"{\n return __hadd(a, b);\n}\n";
decl_stream << "__device__ half operator*"
<< "(__half a, __half b)\n"
<< "{\n return __hmul(a, b);\n}\n";
// FIXME(tvm-team): "volatile" is used to enable cross thread reduction,
// which is needed by operations such as softmax.
// However, volatile overloading is not supported in NVRTC and CUDA < 9.2.
// We need to figure out a solution which can satisfy both scenario.
// decl_stream << "__device__ half operator<="
// << "(const volatile __half &a, const volatile __half &b)\n"
// << "{\n return __hlt(a, b);\n}\n";
// decl_stream << "__device__ half operator+"
// << "(const volatile __half &a, const volatile __half &b)\n"
// <<"{\n return __hadd(a, b);\n}\n";
// decl_stream << "__device__ half operator*"
// << "(const volatile __half &a, const volatile __half &b)\n"
// << "{\n return __hmul(a, b);\n}\n";
// otherwise simulate computation via float32
decl_stream << "#else\n";
decl_stream << _cuda_half_t_def;
Expand Down
3 changes: 2 additions & 1 deletion src/codegen/literal/cuda_half_t.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
static constexpr const char* _cuda_half_t_def = R"(
typedef unsigned short uint16_t;
typedef unsigned char uint8_t;
typedef signed char int8_t;
typedef int int32_t;
typedef unsigned long long uint64_t;
typedef unsigned int uint32_t;
Expand Down Expand Up @@ -76,7 +77,7 @@ class TVM_ALIGNED(2) half {
TVM_XINLINE explicit half(const uint8_t& value) { constructor(value); }
TVM_XINLINE explicit half(const int32_t& value) { constructor(value); }
TVM_XINLINE explicit half(const uint32_t& value) { constructor(value); }
TVM_XINLINE explicit half(const int64_t& value) { constructor(value); }
TVM_XINLINE explicit half(const long long& value) { constructor(value); }
TVM_XINLINE explicit half(const uint64_t& value) { constructor(value); }
TVM_XINLINE operator float() const { \
Expand Down

0 comments on commit 464da21

Please sign in to comment.