diff --git a/aten/src/ATen/cuda/NumericLimits.cuh b/aten/src/ATen/cuda/NumericLimits.cuh
new file mode 100644
index 000000000..986964997
--- /dev/null
+++ b/aten/src/ATen/cuda/NumericLimits.cuh
@@ -0,0 +1,75 @@
+#pragma once
+
+#include <cuda.h>
+#include <limits.h>
+
+// NumericLimits.cuh is a holder for numeric limits definitions of commonly used
+// types. This header is very specific to ROCm HIP and may be removed in the future.
+// This header is derived from the legacy THCNumerics.cuh.
+
+namespace at {
+
+template <typename T>
+struct numeric_limits {
+};
+
+// WARNING: the following at::numeric_limits definitions are there only to support
+// HIP compilation for the moment. Use std::numeric_limits if you are not
+// compiling for ROCm.
+// from @colesbury: "The functions on numeric_limits aren't marked with
+// __device__ which is why they don't work with ROCm. CUDA allows them
+// because they're constexpr."
+template <>
+struct numeric_limits<uint8_t> {
+  static inline __host__ __device__ uint8_t lowest() { return 0; }
+  static inline __host__ __device__ uint8_t max() { return UINT8_MAX; }
+};
+
+template <>
+struct numeric_limits<int8_t> {
+  static inline __host__ __device__ int8_t lowest() { return INT8_MIN; }
+  static inline __host__ __device__ int8_t max() { return INT8_MAX; }
+};
+
+template <>
+struct numeric_limits<int16_t> {
+  static inline __host__ __device__ int16_t lowest() { return INT16_MIN; }
+  static inline __host__ __device__ int16_t max() { return INT16_MAX; }
+};
+
+template <>
+struct numeric_limits<int32_t> {
+  static inline __host__ __device__ int32_t lowest() { return INT32_MIN; }
+  static inline __host__ __device__ int32_t max() { return INT32_MAX; }
+};
+
+template <>
+struct numeric_limits<int64_t> {
+#ifdef _MSC_VER
+  static inline __host__ __device__ int64_t lowest() { return _I64_MIN; }
+  static inline __host__ __device__ int64_t max() { return _I64_MAX; }
+#else
+  static inline __host__ __device__ int64_t lowest() { return INT64_MIN; }
+  static inline __host__ __device__ int64_t max() { return INT64_MAX; }
+#endif
+};
+
+template <>
+struct numeric_limits<at::Half> {
+  static inline __host__ __device__ at::Half lowest() { return at::Half(0xFBFF, at::Half::from_bits); }
+  static inline __host__ __device__ at::Half max() { return at::Half(0x7BFF, at::Half::from_bits); }
+};
+
+template <>
+struct numeric_limits<float> {
+  static inline __host__ __device__ float lowest() { return -FLT_MAX; }
+  static inline __host__ __device__ float max() { return FLT_MAX; }
+};
+
+template <>
+struct numeric_limits<double> {
+  static inline __host__ __device__ double lowest() { return -DBL_MAX; }
+  static inline __host__ __device__ double max() { return DBL_MAX; }
+};
+
+} // namespace at
\ No newline at end of file
diff --git a/aten/src/ATen/native/cuda/SoftMax.cu b/aten/src/ATen/native/cuda/SoftMax.cu
index 6351b8aa6..433e1be5e 100644
--- a/aten/src/ATen/native/cuda/SoftMax.cu
+++ b/aten/src/ATen/native/cuda/SoftMax.cu
@@ -6,9 +6,9 @@
 #include
 #include
 #include
-#include
 
 #include "ATen/AccumulateType.h"
+#include "ATen/cuda/NumericLimits.cuh"
 
 namespace at {
 
@@ -200,7 +200,7 @@ __global__ void cunn_SpatialSoftMaxForward(
     ////////////////////////////////////////////////////////////
 
     if (blockDim.x > 1) {
-      accscalar_t max_input = THCNumerics<accscalar_t>::min();
+      accscalar_t max_input = at::numeric_limits<accscalar_t>::lowest();
       for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x) {
         const accscalar_t value = static_cast<accscalar_t>(input[data_offset + d * dim_stride]);
         max_input = Max<accscalar_t>()(max_input, value);
@@ -217,7 +217,7 @@ __global__ void cunn_SpatialSoftMaxForward(
       for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x)
         output[data_offset + d * dim_stride] = epilogue(input[data_offset + d * dim_stride]);
     } else {
-      accscalar_t max_input = THCNumerics<accscalar_t>::min();
+      accscalar_t max_input = at::numeric_limits<accscalar_t>::lowest();
       for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x) {
         const accscalar_t value = static_cast<accscalar_t>(input[data_offset + d * dim_stride]);
         max_input = Max<accscalar_t>()(max_input, value);
@@ -403,9 +403,9 @@ cunn_SoftMaxForward(scalar_t *output, scalar_t *input, int classes)
 
   // find the max
   accscalar_t threadMax = ilpReduce<MaxFloat, ILP, scalar_t, accscalar_t>(
-      input, classes, MaxFloat<scalar_t, accscalar_t>(), -THCNumerics<accscalar_t>::max());
+      input, classes, MaxFloat<scalar_t, accscalar_t>(), -at::numeric_limits<accscalar_t>::max());
   accscalar_t max_k = blockReduce<Max, accscalar_t>(
-      sdata, threadMax, Max<accscalar_t>(), -THCNumerics<accscalar_t>::max());
+      sdata, threadMax, Max<accscalar_t>(), -at::numeric_limits<accscalar_t>::max());
 
   // reduce all values
   accscalar_t threadExp = ilpReduce<SumExpFloat, ILP, scalar_t, accscalar_t>(
diff --git a/aten/src/ATen/test/CMakeLists.txt b/aten/src/ATen/test/CMakeLists.txt
index 25d84a36a..190f9de9e 100644
--- a/aten/src/ATen/test/CMakeLists.txt
+++ b/aten/src/ATen/test/CMakeLists.txt
@@ -25,7 +25,8 @@ list(APPEND ATen_CUDA_TEST_SRCS
   ${CMAKE_CURRENT_SOURCE_DIR}/integer_divider_test.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/cuda_rng_test.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/apply_test.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/stream_test.cpp)
+  ${CMAKE_CURRENT_SOURCE_DIR}/stream_test.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/cuda_half_test.cu)
 
 if (CUDNN_FOUND)
   list(APPEND ATen_CUDA_TEST_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/cudnn_test.cpp)
diff --git a/aten/src/ATen/test/cuda_half_test.cu b/aten/src/ATen/test/cuda_half_test.cu
new file mode 100644
index 000000000..fa00e534e
--- /dev/null
+++ b/aten/src/ATen/test/cuda_half_test.cu
@@ -0,0 +1,90 @@
+#define CATCH_CONFIG_MAIN
+#include "catch.hpp"
+
+#include "ATen/ATen.h"
+#include "ATen/cuda/NumericLimits.cuh"
+#include "cuda.h"
+#include "cuda_fp16.h"
+#include "cuda_runtime.h"
+
+#include <assert.h>
+
+using namespace at;
+
+__device__ void test(){
+
+  // test half construction and implicit conversions in device
+  assert(Half(3) == Half(3.0f));
+  assert(static_cast<Half>(3.0f) == Half(3.0f));
+  // there is no float <=> __half implicit conversion
+  assert(static_cast<Half>(3.0f) == 3.0f);
+
+  __half a = __float2half(3.0f);
+  __half b = __float2half(2.0f);
+  __half c = a - Half(b);
+  assert(static_cast<Half>(c) == Half(1.0));
+
+  // asserting if the functions used on
+  // half types give almost equivalent results when using
+  // functions on float.
+  // The purpose of these asserts is to test the device-side
+  // half API for the common mathematical functions.
+  // Note: When calling std math functions from device, don't
+  // use the std namespace, but just "::" so that the function
+  // gets resolved from nvcc math_functions.hpp
+
+  float threshold = 0.00001;
+  assert(::abs(::lgamma(Half(10.0)) - ::lgamma(10.0f)) <= threshold);
+  assert(::abs(::exp(Half(1.0)) - ::exp(1.0f)) <= threshold);
+  assert(::abs(::log(Half(1.0)) - ::log(1.0f)) <= threshold);
+  assert(::abs(::log10(Half(1000.0)) - ::log10(1000.0f)) <= threshold);
+  assert(::abs(::log1p(Half(0.0)) - ::log1p(0.0f)) <= threshold);
+  assert(::abs(::log2(Half(1000.0)) - ::log2(1000.0f)) <= threshold);
+  assert(::abs(::expm1(Half(1.0)) - ::expm1(1.0f)) <= threshold);
+  assert(::abs(::cos(Half(0.0)) - ::cos(0.0f)) <= threshold);
+  assert(::abs(::sin(Half(0.0)) - ::sin(0.0f)) <= threshold);
+  assert(::abs(::sqrt(Half(100.0)) - ::sqrt(100.0f)) <= threshold);
+  assert(::abs(::ceil(Half(2.4)) - ::ceil(2.4f)) <= threshold);
+  assert(::abs(::floor(Half(2.7)) - ::floor(2.7f)) <= threshold);
+  assert(::abs(::trunc(Half(2.7)) - ::trunc(2.7f)) <= threshold);
+  assert(::abs(::acos(Half(-1.0)) - ::acos(-1.0f)) <= threshold);
+  assert(::abs(::cosh(Half(1.0)) - ::cosh(1.0f)) <= threshold);
+  assert(::abs(::acosh(Half(1.0)) - ::acosh(1.0f)) <= threshold);
+  assert(::abs(::asin(Half(1.0)) - ::asin(1.0f)) <= threshold);
+  assert(::abs(::sinh(Half(1.0)) - ::sinh(1.0f)) <= threshold);
+  assert(::abs(::asinh(Half(1.0)) - ::asinh(1.0f)) <= threshold);
+  assert(::abs(::tan(Half(0.0)) - ::tan(0.0f)) <= threshold);
+  assert(::abs(::atan(Half(1.0)) - ::atan(1.0f)) <= threshold);
+  assert(::abs(::tanh(Half(1.0)) - ::tanh(1.0f)) <= threshold);
+  assert(::abs(::erf(Half(10.0)) - ::erf(10.0f)) <= threshold);
+  assert(::abs(::erfc(Half(10.0)) - ::erfc(10.0f)) <= threshold);
+  assert(::abs(::abs(Half(-3.0)) - ::abs(-3.0f)) <= threshold);
+  assert(::abs(::round(Half(2.3)) - ::round(2.3f)) <= threshold);
+  assert(::abs(::pow(Half(2.0), Half(10.0)) - ::pow(2.0f, 10.0f)) <= threshold);
+  assert(::abs(::atan2(Half(7.0), Half(0.0)) - ::atan2(7.0f, 0.0f)) <= threshold);
+  // note: can't use namespace on isnan and isinf in device code
+  #ifdef _MSC_VER
+    // Windows requires this explicit conversion. The reason is unclear;
+    // a related issue with clang: https://reviews.llvm.org/D37906
+    assert(::abs(::isnan((float)Half(0.0)) - ::isnan(0.0f)) <= threshold);
+    assert(::abs(::isinf((float)Half(0.0)) - ::isinf(0.0f)) <= threshold);
+  #else
+    assert(::abs(::isnan(Half(0.0)) - ::isnan(0.0f)) <= threshold);
+    assert(::abs(::isinf(Half(0.0)) - ::isinf(0.0f)) <= threshold);
+  #endif
+}
+
+__global__ void kernel(){
+  test();
+}
+
+void launch_function(){
+  kernel<<<1,1>>>();
+}
+
+TEST_CASE( "half common math functions tests in device", "[cuda]" ) {
+  launch_function();
+  cudaError_t err = cudaDeviceSynchronize();
+  REQUIRE(err == cudaSuccess);
+}
+
diff --git a/aten/src/ATen/test/half_test.cpp b/aten/src/ATen/test/half_test.cpp
index fc70522c8..3b2944803 100644
--- a/aten/src/ATen/test/half_test.cpp
+++ b/aten/src/ATen/test/half_test.cpp
@@ -5,7 +5,10 @@
 #include
 #include
 #include
+#include <cmath>
 #include
+#include "test_seed.h"
+#include "test_assert.h"
 
 using namespace at;
 
@@ -115,3 +118,43 @@ ASSERT_SAME_TYPE(max_exponent);
 ASSERT_SAME_TYPE(max_exponent10);
 ASSERT_SAME_TYPE(traps);
 ASSERT_SAME_TYPE(tinyness_before);
+
+TEST_CASE( "half common math functions test", "[]" ) {
+  float threshold = 0.00001;
+  assert(std::abs(std::lgamma(Half(10.0)) - std::lgamma(10.0f)) <= threshold);
+  assert(std::abs(std::exp(Half(1.0)) - std::exp(1.0f)) <= threshold);
+  assert(std::abs(std::log(Half(1.0)) - std::log(1.0f)) <= threshold);
+  assert(std::abs(std::log10(Half(1000.0)) - std::log10(1000.0f)) <= threshold);
+  assert(std::abs(std::log1p(Half(0.0)) - std::log1p(0.0f)) <= threshold);
+  assert(std::abs(std::log2(Half(1000.0)) - std::log2(1000.0f)) <= threshold);
+  assert(std::abs(std::expm1(Half(1.0)) - std::expm1(1.0f)) <= threshold);
+  assert(std::abs(std::cos(Half(0.0)) - std::cos(0.0f)) <= threshold);
+  assert(std::abs(std::sin(Half(0.0)) - std::sin(0.0f)) <= threshold);
+  assert(std::abs(std::sqrt(Half(100.0)) - std::sqrt(100.0f)) <= threshold);
+  assert(std::abs(std::ceil(Half(2.4)) - std::ceil(2.4f)) <= threshold);
+  assert(std::abs(std::floor(Half(2.7)) - std::floor(2.7f)) <= threshold);
+  assert(std::abs(std::trunc(Half(2.7)) - std::trunc(2.7f)) <= threshold);
+  assert(std::abs(std::acos(Half(-1.0)) - std::acos(-1.0f)) <= threshold);
+  assert(std::abs(std::cosh(Half(1.0)) - std::cosh(1.0f)) <= threshold);
+  assert(std::abs(std::acosh(Half(1.0)) - std::acosh(1.0f)) <= threshold);
+  assert(std::abs(std::asin(Half(1.0)) - std::asin(1.0f)) <= threshold);
+  assert(std::abs(std::sinh(Half(1.0)) - std::sinh(1.0f)) <= threshold);
+  assert(std::abs(std::asinh(Half(1.0)) - std::asinh(1.0f)) <= threshold);
+  assert(std::abs(std::tan(Half(0.0)) - std::tan(0.0f)) <= threshold);
+  assert(std::abs(std::atan(Half(1.0)) - std::atan(1.0f)) <= threshold);
+  assert(std::abs(std::tanh(Half(1.0)) - std::tanh(1.0f)) <= threshold);
+  assert(std::abs(std::erf(Half(10.0)) - std::erf(10.0f)) <= threshold);
+  assert(std::abs(std::erfc(Half(10.0)) - std::erfc(10.0f)) <= threshold);
+  assert(std::abs(std::abs(Half(-3.0)) - std::abs(-3.0f)) <= threshold);
+  assert(std::abs(std::round(Half(2.3)) - std::round(2.3f)) <= threshold);
+  assert(std::abs(std::pow(Half(2.0), Half(10.0)) - std::pow(2.0f, 10.0f)) <= threshold);
+  assert(std::abs(std::atan2(Half(7.0), Half(0.0)) - std::atan2(7.0f, 0.0f)) <= threshold);
+  #ifdef __APPLE__
+    // @TODO: can macos do implicit conversion of Half?
+    assert(std::abs(std::isnan(static_cast<float>(Half(0.0))) - std::isnan(0.0f)) <= threshold);
+    assert(std::abs(std::isinf(static_cast<float>(Half(0.0))) - std::isinf(0.0f)) <= threshold);
+  #else
+    assert(std::abs(std::isnan(Half(0.0)) - std::isnan(0.0f)) <= threshold);
+    assert(std::abs(std::isinf(Half(0.0)) - std::isinf(0.0f)) <= threshold);
+  #endif
+}
\ No newline at end of file
diff --git a/aten/src/THC/CMakeLists.txt b/aten/src/THC/CMakeLists.txt
index 5f92ea34a..44f5d188d 100644
--- a/aten/src/THC/CMakeLists.txt
+++ b/aten/src/THC/CMakeLists.txt
@@ -18,10 +18,6 @@ foreach(THC_TYPE Byte Char Short Int Long Half Float Double)
    endforeach()
 endforeach()
 
-IF(CUDA_HAS_FP16 OR NOT ${CUDA_VERSION} LESS 7.5)
-  LIST(APPEND extra_src ${CMAKE_CURRENT_SOURCE_DIR}/THCHalf.cu)
-ENDIF()
-
 set(ATen_CUDA_SRCS ${ATen_CUDA_SRCS}
   ${CMAKE_CURRENT_SOURCE_DIR}/THCCachingAllocator.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/THCCachingHostAllocator.cpp
diff --git a/aten/src/THC/THCAtomics.cuh b/aten/src/THC/THCAtomics.cuh
index f040699d2..485b5744f 100644
--- a/aten/src/THC/THCAtomics.cuh
+++ b/aten/src/THC/THCAtomics.cuh
@@ -4,8 +4,7 @@
 #include "THC.h"
 #include "THCHalf.h"
 #include "THCNumerics.cuh"
-
-namespace at { struct Half; }
+#include "ATen/ATen.h"
 
 template <typename T, size_t n>
 struct AtomicAddIntegerImpl;
@@ -118,8 +117,8 @@ static inline __device__ void atomicAdd(half *address, half val) {
     old = atomicCAS(address_as_ui, assumed, old);
   } while (assumed != old);
 }
-static inline __device__ void atomicAdd(at::Half *address, half val) {
-  return atomicAdd(reinterpret_cast<half*>(address), val);
+static inline __device__ void atomicAdd(at::Half *address, at::Half val) {
+  atomicAdd(reinterpret_cast<half*>(address), val);
 }
 
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600 || CUDA_VERSION < 8000)
diff --git a/aten/src/THC/THCHalf.cu b/aten/src/THC/THCHalf.cu
deleted file mode 100644
index 786326089..000000000
--- a/aten/src/THC/THCHalf.cu
+++ /dev/null
@@ -1,51 +0,0 @@
-#include "THCHalf.h"
-#include "THCThrustAllocator.cuh"
-#include <thrust/transform.h>
-#include <thrust/execution_policy.h>
-
-struct __half2floatOp {
-  __device__ float operator()(half v) { return __half2float(v); }
-};
-
-struct __float2halfOp {
-  __device__ half operator()(float v) { return __float2half(v); }
-};
-
-void THCFloat2Half(THCState *state, half *out, float *in, ptrdiff_t len) {
-  THCThrustAllocator thrustAlloc(state);
-  thrust::transform(
-#if CUDA_VERSION >= 7000
-    thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)),
-#else
-    thrust::device,
-#endif
-    in, in + len, out, __float2halfOp());
-}
-
-void THCHalf2Float(THCState *state, float *out, half *in, ptrdiff_t len) {
-  THCThrustAllocator thrustAlloc(state);
-  thrust::transform(
-#if CUDA_VERSION >= 7000
-    thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)),
-#else
-    thrust::device,
-#endif
-    in, in + len, out, __half2floatOp());
-}
-
-THC_EXTERNC int THC_nativeHalfInstructions(THCState *state) {
-  cudaDeviceProp* prop =
-    THCState_getCurrentDeviceProperties(state);
-
-  // CC 5.3+
-  return (prop->major > 5 ||
-          (prop->major == 5 && prop->minor == 3));
-}
-
-THC_EXTERNC int THC_fastHalfInstructions(THCState *state) {
-  cudaDeviceProp* prop =
-    THCState_getCurrentDeviceProperties(state);
-
-  // Check for CC 6.0 only (corresponds to P100)
-  return (prop->major == 6 && prop->minor == 0);
-}
diff --git a/aten/src/THC/THCHalf.h b/aten/src/THC/THCHalf.h
index 6b9a4f755..aeae06fc4 100644
--- a/aten/src/THC/THCHalf.h
+++ b/aten/src/THC/THCHalf.h
@@ -12,15 +12,7 @@ typedef __half_raw half;
 #endif
 #endif
 
-THC_EXTERNC void THCFloat2Half(THCState *state, half *out, float *in, ptrdiff_t len);
-THC_EXTERNC void THCHalf2Float(THCState *state, float *out, half *in, ptrdiff_t len);
 THC_API half THC_float2half(float a);
 THC_API float THC_half2float(half a);
 
-/* Check for native fp16 support on the current device (CC 5.3+) */
-THC_API int THC_nativeHalfInstructions(THCState *state);
-
-/* Check for performant native fp16 support on the current device */
-THC_API int THC_fastHalfInstructions(THCState *state);
-
 #endif
diff --git a/aten/src/THC/THCNumerics.cuh b/aten/src/THC/THCNumerics.cuh
index 9e653b793..9af18f6bb 100644
--- a/aten/src/THC/THCNumerics.cuh
+++ b/aten/src/THC/THCNumerics.cuh
@@ -5,11 +5,18 @@
 #include
 #include
 #include "THCHalf.h"
+#include "ATen/ATen.h"
+#include "ATen/cuda/NumericLimits.cuh"
+
+// WARNING: THCNumerics is being deprecated. Please follow the comments
+// in this file to learn about new usages.
+// Comments on usage:
+// - lt,le,gt,ge,eq,neg,add,mul,sub,div and other binary ops can
+//   be implemented using CUDA_apply_utils or binary cuda kernel
+// - Check NumericLimits.cuh for specialized math functions.
+// - Note how __half and at::Half can be cast. For instance:
+//   static_cast<half>(std::sin(static_cast<float>(a)));
 
-/// Class for numeric limits of the particular data type, which
-/// includes support for `half`.
-/// Unfortunately since `half` does not have a constructor, these have
-/// to be expressed as functions (either that or non-const statics).
 template <typename T>
 struct THCNumerics {
 };
@@ -28,10 +35,12 @@ static inline __host__ __device__ scalar_t powi(scalar_t a, scalar_t b) {
   return result;
 }
 
+// DEPRECATED: For integral types, use math functions from std and NumericLimits.cuh.
+// Use binary_kernel or CUDA_apply_utils for arithmetic
 template <>
 struct THCNumerics<uint8_t> {
-  static inline __host__ __device__ uint8_t min() { return 0; }
-  static inline __host__ __device__ uint8_t max() { return UCHAR_MAX; }
+  static inline __host__ __device__ uint8_t min() { return at::numeric_limits<uint8_t>::lowest(); }
+  static inline __host__ __device__ uint8_t max() { return at::numeric_limits<uint8_t>::max(); }
 
   static inline __host__ __device__ bool lt(uint8_t a, uint8_t b) { return a < b; }
   static inline __host__ __device__ bool le(uint8_t a, uint8_t b) { return a <= b; }
@@ -53,8 +62,8 @@ struct THCNumerics<uint8_t> {
 
 template <>
 struct THCNumerics<int8_t> {
-  static inline __host__ __device__ int8_t min() { return SCHAR_MIN; }
-  static inline __host__ __device__ int8_t max() { return SCHAR_MAX; }
+  static inline __host__ __device__ int8_t min() { return at::numeric_limits<int8_t>::lowest(); }
+  static inline __host__ __device__ int8_t max() { return at::numeric_limits<int8_t>::max(); }
 
   static inline __host__ __device__ bool lt(int8_t a, int8_t b) { return a < b; }
   static inline __host__ __device__ bool le(int8_t a, int8_t b) { return a <= b; }
@@ -76,8 +85,8 @@ struct THCNumerics<int8_t> {
 
 template <>
 struct THCNumerics<int16_t> {
-  static inline __host__ __device__ int16_t min() { return SHRT_MIN; }
-  static inline __host__ __device__ int16_t max() { return SHRT_MAX; }
+  static inline __host__ __device__ int16_t min() { return at::numeric_limits<int16_t>::lowest(); }
+  static inline __host__ __device__ int16_t max() { return at::numeric_limits<int16_t>::max(); }
 
   static inline __host__ __device__ bool lt(int16_t a, int16_t b) { return a < b; }
   static inline __host__ __device__ bool le(int16_t a, int16_t b) { return a <= b; }
@@ -99,8 +108,8 @@ struct THCNumerics<int16_t> {
 
 template <>
 struct THCNumerics<int32_t> {
-  static inline __host__ __device__ int32_t min() { return INT_MIN; }
-  static inline __host__ __device__ int32_t max() { return INT_MAX; }
+  static inline __host__ __device__ int32_t min() { return at::numeric_limits<int32_t>::lowest(); }
+  static inline __host__ __device__ int32_t max() { return at::numeric_limits<int32_t>::max(); }
 
   static inline __host__ __device__ bool lt(int32_t a, int32_t b) { return a < b; }
   static inline __host__ __device__ bool le(int32_t a, int32_t b) { return a <= b; }
@@ -122,13 +131,8 @@ struct THCNumerics<int32_t> {
 
 template <>
 struct THCNumerics<int64_t> {
-#ifdef _MSC_VER
-  static inline __host__ __device__ int64_t min() { return _I64_MIN; }
-  static inline __host__ __device__ int64_t max() { return _I64_MAX; }
-#else
-  static inline __host__ __device__ int64_t min() { return LONG_MIN; }
-  static inline __host__ __device__ int64_t max() { return LONG_MAX; }
-#endif
+  static inline __host__ __device__ int64_t min() { return at::numeric_limits<int64_t>::lowest(); }
+  static inline __host__ __device__ int64_t max() { return at::numeric_limits<int64_t>::max(); }
 
   static inline __host__ __device__ bool lt(int64_t a, int64_t b) { return a < b; }
   static inline __host__ __device__ bool le(int64_t a, int64_t b) { return a <= b; }
@@ -149,430 +153,222 @@ struct THCNumerics<int64_t> {
   static inline __host__ __device__ bool isinf(int64_t a) { return false; }
 };
 
+// DEPRECATED: use math functions from std and NumericLimits.cuh
 template <>
 struct THCNumerics<half> {
-#if CUDA_VERSION < 9000
-  static inline __host__ __device__ half min() { half h; h.x = 0xfbff; return h; }
-  static inline __host__ __device__ half max() { half h; h.x = 0x7bff; return h; }
-#else
-  static inline __host__ __device__ half min() { __half_raw h; h.x = 0xfbff; return h; }
-  static inline __host__ __device__ half max() { __half_raw h; h.x = 0x7bff; return h; }
-#endif
+  static inline __host__ __device__ half min() { return at::numeric_limits<at::Half>::lowest(); }
+  static inline __host__ __device__ half max() { return at::numeric_limits<at::Half>::max(); }
 
   static inline __host__ __device__ bool lt(half a, half b) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    float fb = __half2float(b);
-    return fa < fb;
-#else // __CUDA_ARCH__
-    return THC_half2float(a) < THC_half2float(b);
-#endif
+    return static_cast<at::Half>(a) < static_cast<at::Half>(b);
   }
 
   static inline __host__ __device__ bool le(half a, half b) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    float fb = __half2float(b);
-    return fa <= fb;
-#else // __CUDA_ARCH__
-    return THC_half2float(a) <= THC_half2float(b);
-#endif
+    return static_cast<at::Half>(a) <= static_cast<at::Half>(b);
   }
 
   static inline __host__ __device__ bool gt(half a, half b) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    float fb = __half2float(b);
-    return fa > fb;
-#else // __CUDA_ARCH__
-    return THC_half2float(a) > THC_half2float(b);
-#endif
+    return static_cast<at::Half>(a) > static_cast<at::Half>(b);
  }
 
   static inline __host__ __device__ bool ge(half a, half b) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    float fb = __half2float(b);
-    return fa >= fb;
-#else // __CUDA_ARCH__
-    return THC_half2float(a) >= THC_half2float(b);
-#endif
+    return static_cast<at::Half>(a) >= static_cast<at::Half>(b);
   }
 
   static inline __host__ __device__ bool eq(half a, half b) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    float fb = __half2float(b);
-    return fa == fb;
-#else // __CUDA_ARCH__
-    return THC_half2float(a) == THC_half2float(b);
-#endif
+    // has to be explicitly cast to float for now, otherwise get error: more than one operator "==" matches these operands
+    // Note: find the overloading for == and != (probably THCTensorTypeUtils.cuh) and resolve
+    return static_cast<float>(static_cast<at::Half>(a)) == static_cast<float>(static_cast<at::Half>(b));
  }
 
  static inline __host__ __device__ bool ne(half a, half b) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    float fb = __half2float(b);
-    return fa != fb;
-#else // __CUDA_ARCH__
-    return THC_half2float(a) != THC_half2float(b);
-#endif
+    // has to be explicitly cast to float for now, otherwise get error: more than one operator "==" matches these operands
+    // Note: find the overloading for == and != (probably THCTensorTypeUtils.cuh) and resolve
+    return static_cast<float>(static_cast<at::Half>(a)) != static_cast<float>(static_cast<at::Half>(b));
   }
 
   static inline __host__ __device__ half exp(half a) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    return __float2half(expf(fa));
-#else // __CUDA_ARCH__
-    return THC_float2half(expf(THC_half2float(a)));
-#endif
+    return static_cast<half>(std::exp(static_cast<float>(a)));
   }
-
+
+  // note that exp10 is not in the std namespace.
   static inline __host__ __device__ half exp10(half a) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    return __float2half(exp10f(fa));
-#else // __CUDA_ARCH__
-    return THC_float2half(exp10f(THC_half2float(a)));
-#endif
+    return static_cast<half>(::exp10(static_cast<float>(a)));
   }
 
   static inline __host__ __device__ half log(half a) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    return __float2half(logf(fa));
-#else // __CUDA_ARCH__
-    return THC_float2half(logf(THC_half2float(a)));
-#endif
+    return static_cast<half>(::log(static_cast<float>(a)));
   }
 
   static inline __host__ __device__ half log10(half a) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    return __float2half(log10f(fa));
-#else // __CUDA_ARCH__
-    return THC_float2half(log10f(THC_half2float(a)));
-#endif
+    return static_cast<half>(::log10(static_cast<float>(a)));
   }
 
   static inline __host__ __device__ half log1p(half a) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    return __float2half(log1pf(fa));
-#else // __CUDA_ARCH__
-    return THC_float2half(log1pf(THC_half2float(a)));
-#endif
+    return static_cast<half>(::log1p(static_cast<float>(a)));
   }
 
   static inline __host__ __device__ half log2(half a) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    return __float2half(log2f(fa));
-#else // __CUDA_ARCH__
-    return THC_float2half(log2f(THC_half2float(a)));
-#endif
+    return static_cast<half>(::log2(static_cast<float>(a)));
   }
 
-static inline __host__ __device__ half lgamma(half a) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    return __float2half(lgammaf(fa));
-#else // __CUDA_ARCH__
-    return THC_float2half(lgammaf(THC_half2float(a)));
-#endif
+  static inline __host__ __device__ half lgamma(half a) {
+    return static_cast<half>(::lgamma(static_cast<float>(a)));
   }
 
   static inline __host__ __device__ half expm1(half a) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    return __float2half(expm1f(fa));
-#else // __CUDA_ARCH__
-    return THC_float2half(expm1f(THC_half2float(a)));
-#endif
+    return static_cast<half>(::expm1(static_cast<float>(a)));
   }
 
   static inline __host__ __device__ half cos(half a) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    return __float2half(cosf(fa));
-#else // __CUDA_ARCH__
-    return THC_float2half(cosf(THC_half2float(a)));
-#endif
+    return static_cast<half>(::cos(static_cast<float>(a)));
   }
 
   static inline __host__ __device__ half sin(half a) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    return __float2half(sinf(fa));
-#else // __CUDA_ARCH__
-    return THC_float2half(sinf(THC_half2float(a)));
-#endif
+    return static_cast<half>(::sin(static_cast<float>(a)));
   }
 
   static inline __host__ __device__ half sqrt(half a) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    return __float2half(sqrtf(fa));
-#else // __CUDA_ARCH__
-    return THC_float2half(sqrtf(THC_half2float(a)));
-#endif
+    return static_cast<half>(::sqrt(static_cast<float>(a)));
   }
 
+  // note that rsqrt is not in the std namespace.
   static inline __host__ __device__ half rsqrt(half a) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    return __float2half(rsqrtf(fa));
-#else // __CUDA_ARCH__
-    return THC_float2half(rsqrtf(THC_half2float(a)));
-#endif
+    return static_cast<half>(::rsqrt(static_cast<float>(a)));
   }
 
   static inline __host__ __device__ half ceil(half a) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    return __float2half(ceilf(fa));
-#else // __CUDA_ARCH__
-    return THC_float2half(ceilf(THC_half2float(a)));
-#endif
+    return static_cast<half>(::ceil(static_cast<float>(a)));
   }
 
   static inline __host__ __device__ half floor(half a) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    return __float2half(floorf(fa));
-#else // __CUDA_ARCH__
-    return THC_float2half(floorf(THC_half2float(a)));
-#endif
+    return static_cast<half>(::floor(static_cast<float>(a)));
   }
 
   static inline __host__ __device__ half trunc(half a) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    return __float2half(truncf(fa));
-#else // __CUDA_ARCH__
-    return THC_float2half(truncf(THC_half2float(a)));
-#endif
+    return static_cast<half>(::trunc(static_cast<float>(a)));
   }
 
   static inline __host__ __device__ half neg(half a) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    return __float2half(-fa);
-#else // __CUDA_ARCH__
-    return THC_float2half(-(THC_half2float(a)));
-#endif
+    return static_cast<half>(-(static_cast<float>(a)));
   }
 
   static inline __host__ __device__ half acos(half a) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    return __float2half(acosf(fa));
-#else // __CUDA_ARCH__
-    return THC_float2half(acosf(THC_half2float(a)));
-#endif
+    return static_cast<half>(::acos(static_cast<float>(a)));
   }
 
   static inline __host__ __device__ half cosh(half a) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    return __float2half(coshf(fa));
-#else // __CUDA_ARCH__
-    return THC_float2half(coshf(THC_half2float(a)));
-#endif
+    return static_cast<half>(::cosh(static_cast<float>(a)));
   }
 
   static inline __host__ __device__ half asin(half a) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    return __float2half(asinf(fa));
-#else // __CUDA_ARCH__
-    return THC_float2half(asinf(THC_half2float(a)));
-#endif
+    return static_cast<half>(::asin(static_cast<float>(a)));
   }
 
   static inline __host__ __device__ half sinh(half a) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    return __float2half(sinhf(fa));
-#else // __CUDA_ARCH__
-    return THC_float2half(sinhf(THC_half2float(a)));
-#endif
+    return static_cast<half>(::sinh(static_cast<float>(a)));
   }
 
   static inline __host__ __device__ half tan(half a) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    return __float2half(tanf(fa));
-#else // __CUDA_ARCH__
-    return THC_float2half(tanf(THC_half2float(a)));
-#endif
+    return static_cast<half>(::tan(static_cast<float>(a)));
   }
 
   static inline __host__ __device__ half atan(half a) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    return __float2half(atanf(fa));
-#else // __CUDA_ARCH__
-    return THC_float2half(atanf(THC_half2float(a)));
-#endif
+    return static_cast<half>(::atan(static_cast<float>(a)));
   }
 
   static inline __host__ __device__ half tanh(half a) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    return __float2half(tanhf(fa));
-#else // __CUDA_ARCH__
-    return THC_float2half(tanhf(THC_half2float(a)));
-#endif
+    return static_cast<half>(::tanh(static_cast<float>(a)));
   }
 
   static inline __host__ __device__ half erf(half a) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    return __float2half(erff(fa));
-#else // __CUDA_ARCH__
-    return THC_float2half(erff(THC_half2float(a)));
-#endif
+    return static_cast<half>(::erf(static_cast<float>(a)));
   }
 
   static inline __host__ __device__ half erfc(half a) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    return __float2half(erfcf(fa));
-#else // __CUDA_ARCH__
-    return THC_float2half(erfcf(THC_half2float(a)));
-#endif
+    return static_cast<half>(::erfc(static_cast<float>(a)));
   }
-
+  // note that erfinv is not in the std namespace.
   static inline __host__ __device__ half erfinv(half a) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    return __float2half(erfinvf(fa));
-#else // __CUDA_ARCH__
-    return THC_float2half(erfinvf(THC_half2float(a)));
-#endif
+    return static_cast<half>(::erfinv(static_cast<float>(a)));
   }
 
   static inline __host__ __device__ half abs(half a) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    return __float2half(fabs(fa));
-#else // __CUDA_ARCH__
-    return THC_float2half(fabs(THC_half2float(a)));
-#endif
+    return static_cast<half>(::abs(static_cast<float>(a)));
   }
 
   static inline __host__ __device__ half round(half a) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    return __float2half(roundf(fa));
-#else // __CUDA_ARCH__
-    return THC_float2half(roundf(THC_half2float(a)));
-#endif
+    return static_cast<half>(::round(static_cast<float>(a)));
   }
 
   static inline __host__ __device__ half frac(half a) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    return __float2half(fa - truncf(fa));
-#else // __CUDA_ARCH__
-    float fa = THC_half2float(a);
-    return THC_float2half(fa - floorf(fa));
-#endif
+    #ifdef __CUDA_ARCH__
+        return static_cast<at::Half>(a) - static_cast<at::Half>(::trunc(static_cast<float>(a)));
+    #else // __CUDA_ARCH__
+        return static_cast<at::Half>(a) - static_cast<at::Half>(::floor(static_cast<float>(a)));
+    #endif
   }
 
   static inline __host__ __device__ half cinv(half a) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    return __float2half(1.0f / fa);
-#else // __CUDA_ARCH__
-    return THC_float2half(1.0f / THC_half2float(a));
-#endif
+    return static_cast<half>(1.0f / static_cast<float>(a));
   }
 
   static inline __host__ __device__ half add(half a, half b) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    float fb = __half2float(b);
-    return __float2half( fa + fb );
-#else // __CUDA_ARCH__
-    return THC_float2half(THC_half2float(a) + THC_half2float(b));
-#endif
+    return static_cast<at::Half>(a) + static_cast<at::Half>(b);
   }
 
   static inline __host__ __device__ half div(half a, half b) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    float fb = __half2float(b);
-    return __float2half( fa / fb );
-#else // __CUDA_ARCH__
-    return THC_float2half(THC_half2float(a) / THC_half2float(b));
-#endif
+    return static_cast<at::Half>(a) / static_cast<at::Half>(b);
   }
 
   static inline __host__ __device__ half mul(half a, half b) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    float fb = __half2float(b);
-    return __float2half( fa * fb );
-#else // __CUDA_ARCH__
-    return THC_float2half(THC_half2float(a) * THC_half2float(b));
-#endif
+    return static_cast<at::Half>(a) * static_cast<at::Half>(b);
   }
 
   static inline __host__ __device__ half sub(half a, half b) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    float fb = __half2float(b);
-    return __float2half( fa - fb );
-#else // __CUDA_ARCH__
-    return THC_float2half(THC_half2float(a) - THC_half2float(b));
-#endif
+    return static_cast<at::Half>(a) - static_cast<at::Half>(b);
   }
 
   static inline __host__ __device__ half pow(half a, half b) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    float fb = __half2float(b);
-    return __float2half(powf(fa, fb));
-#else // __CUDA_ARCH__
-    return THC_float2half(powf(THC_half2float(a), THC_half2float(b)));
-#endif
+    return static_cast<half>(::pow(static_cast<float>(a), static_cast<float>(b)));
   }
 
   static inline __host__ __device__ half atan2(half a, half b) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    float fb = __half2float(b);
-    return __float2half(atan2f(fa, fb));
-#else // __CUDA_ARCH__
-    return THC_float2half(atan2f(THC_half2float(a), THC_half2float(b)));
-#endif
+    return static_cast<half>(::atan2(static_cast<float>(a), static_cast<float>(b)));
   }
 
   static inline __host__ __device__ bool isnan(half a) {
-    // implemented using that a!=a if and only if a is nan
-    return ne(a, a);
+    #ifdef _MSC_VER
+      // Windows requires this explicit conversion. The reason is unclear;
+      // a related issue with clang: https://reviews.llvm.org/D37906
+      return ::isnan((float)static_cast<at::Half>(a));
+    #else
+      return ::isnan(static_cast<at::Half>(a));
+    #endif
   }
 
   static inline __host__ __device__ bool isinf(half a) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    return ::isinf(fa);
-#else // __CUDA_ARCH__
-    return ::isinf(THC_half2float(a));
-#endif
+    #ifdef _MSC_VER
+      // Windows requires this explicit conversion. The reason is unclear;
+      // a related issue with clang: https://reviews.llvm.org/D37906
+      return ::isinf((float)static_cast<at::Half>(a));
+    #else
+      return ::isinf(static_cast<at::Half>(a));
+    #endif
   }
 };
 
+// DEPRECATED: use math functions from std and cuda math API (if needed)
+// note that the functions exp10,rsqrt,erfinv,frac and cinv
+// are not in the std namespace
 template <>
 struct THCNumerics<float> {
-  static inline __host__ __device__ float min() { return -FLT_MAX; }
-  static inline __host__ __device__ float max() { return FLT_MAX; }
+  static inline __host__ __device__ float min() { return at::numeric_limits<float>::lowest(); }
+  static inline __host__ __device__ float max() { return at::numeric_limits<float>::max(); }
 
   static inline __host__ __device__ bool lt(float a, float b) { return a < b; }
   static inline __host__ __device__ bool le(float a, float b) { return a <= b; }
@@ -623,10 +419,13 @@ struct THCNumerics<float> {
 
   static inline __host__ __device__ bool isinf(float a) { return ::isinf(a); }
 };
 
+// DEPRECATED: use math functions from std and cuda math API (if needed)
+// note that the functions exp10,rsqrt,erfinv,frac and cinv
+// are not in the std namespace
 template <>
 struct THCNumerics<double> {
-  static inline __host__ __device__ double min() { return -DBL_MAX; }
-  static inline __host__ __device__ double max() { return DBL_MAX; }
+  static inline __host__ __device__ double min() { return at::numeric_limits<double>::lowest(); }
+  static inline __host__ __device__ double max() { return at::numeric_limits<double>::max(); }
 
   static inline __host__ __device__ bool lt(double a, double b) { return a < b; }
   static inline __host__ __device__ bool le(double a, double b) { return a <= b; }
@@ -677,10 +476,15 @@ struct THCNumerics<double> {
 
   static inline __host__ __device__ bool isinf(double a) { return ::isinf(a); }
 };
 
-/// `half` has some type conversion issues associated with it, since it
-/// is a struct without a constructor/implicit conversion constructor.
-/// We use this to convert scalar values to the given type that the
-/// tensor expects.
+// WARNING: The following note is deprecated
+/// `half` has some type conversion issues associated with it, since it
+/// is a struct without a constructor/implicit conversion constructor.
+/// We use this to convert scalar values to the given type that the
+/// tensor expects.
+///
+/// at::Half has implicit conversions for float and __half types. Moreover
+/// it has constructors for __half and float types.
+
 template <typename In, typename Out>
 struct ScalarConvert {
   static __host__ __device__ Out to(const In v) { return (Out) v; }
 };
@@ -715,6 +519,7 @@
   }
 };
 
+// DEPRECATED: use static_cast in kernels instead of scalar_cast
 template <typename T, typename U>
 __host__ __device__ T scalar_cast(U u) {
   return ScalarConvert<U, T>::to(u);
diff --git a/aten/src/THCUNN/THCHalfAutoNumerics.cuh b/aten/src/THCUNN/THCHalfAutoNumerics.cuh
index fff37f85d..5f8fda899 100644
--- a/aten/src/THCUNN/THCHalfAutoNumerics.cuh
+++ b/aten/src/THCUNN/THCHalfAutoNumerics.cuh
@@ -4,6 +4,9 @@
 #include "THCHalf.h"
 #include "THCNumerics.cuh"
 
+// WARNING: THCNumerics is being deprecated. Read the comments and function usage
+// in THCNumerics to learn about the deprecation
+//
 // Half numerics functions defined as free functions, so cunn code can be
 //written generically, i.e. without excessive calling of THCNumerics functions.
diff --git a/aten/tools/run_tests.sh b/aten/tools/run_tests.sh
index c341b885f..d2669029c 100755
--- a/aten/tools/run_tests.sh
+++ b/aten/tools/run_tests.sh
@@ -28,6 +28,9 @@ fi
 if [[ -x ./stream_test ]]; then
   ./stream_test
 fi
+if [[ -x ./cuda_half_test ]]; then
+  ./cuda_half_test
+fi
 if [ "$VALGRIND" == "ON" ]
 then
   valgrind --suppressions="$VALGRIND_SUP" --error-exitcode=1 ./basic "[cpu]"