Skip to content

Commit 8226c81

Browse files
committed
[CUDA] Remove htanh from unsupported math ops for CUDA 12.8
This PR removes htanh from the list of unsupported CUDA half operators, as it is started to be supported since CUDA 12.8. Specifically, we added a CUDA version check in the generated CUDA code, so that when the CUDA version is older than 12.8, htanh will still be treated as an unsupported operator and fall back to the packed operation. While for newer CUDA versions, we directly use the function that is defined in `cuda_fp16.h`.
1 parent 41b6da1 commit 8226c81

File tree

2 files changed

+6
-0
lines changed

2 files changed

+6
-0
lines changed

src/target/source/literal/cuda_half_t.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,9 @@ static inline __device__ __host__ half HALF_MATH_NAME(half x) { \
317317
#if defined(__CUDA_ARCH__)
318318
#if (__CUDA_ARCH__ >= 530)
319319
CUDA_UNSUPPORTED_HALF_MATH_BINARY(hpow, powf)
320+
#if ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 8)))
320321
CUDA_UNSUPPORTED_HALF_MATH_UNARY(htanh, tanhf)
322+
#endif
321323
CUDA_UNSUPPORTED_HALF_MATH_UNARY(htan, tanf)
322324
CUDA_UNSUPPORTED_HALF_MATH_UNARY(hatan, atanf)
323325
CUDA_UNSUPPORTED_HALF_MATH_UNARY(herf, erf)
@@ -358,7 +360,9 @@ static inline __device__ __host__ nv_bfloat16 HALF_MATH_NAME(nv_bfloat16 x) {
358360
}
359361
360362
CUDA_UNSUPPORTED_HALF_MATH_BINARY(hpow, powf)
363+
#if ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 8)))
361364
CUDA_UNSUPPORTED_HALF_MATH_UNARY(htanh, tanhf)
365+
#endif
362366
CUDA_UNSUPPORTED_HALF_MATH_UNARY(htan, tanf)
363367
CUDA_UNSUPPORTED_HALF_MATH_UNARY(hatan, atanf)
364368
CUDA_UNSUPPORTED_HALF_MATH_UNARY(herf, erf)

tests/python/meta_schedule/test_meta_schedule_mma_m16n8k8_auto_tensorization.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -717,7 +717,9 @@ class TVM_ALIGNED(2) half {
717717
#if defined(__CUDA_ARCH__)
718718
#if (__CUDA_ARCH__ >= 530)
719719
CUDA_UNSUPPORTED_HALF_MATH_BINARY(hpow, powf)
720+
#if ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 8)))
720721
CUDA_UNSUPPORTED_HALF_MATH_UNARY(htanh, tanhf)
722+
#endif
721723
CUDA_UNSUPPORTED_HALF_MATH_UNARY(htan, tanf)
722724
CUDA_UNSUPPORTED_HALF_MATH_UNARY(hatan, atanf)
723725
CUDA_UNSUPPORTED_HALF_MATH_UNARY(herf, erf)

0 commit comments

Comments
 (0)