Skip to content

Commit cc4b86e

Browse files
Backport to 2.8: Implement cuda::std::numeric_limits for __half and __nv_bfloat16 (#3361) (#3490)
With a dedicated C++11 fix Co-authored-by: Michael Schellenberger Costa <miscco@nvidia.com>
1 parent 6d735b6 commit cc4b86e

36 files changed

+564
-201
lines changed

libcudacxx/include/cuda/std/limits

Lines changed: 200 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,10 @@
2222
#endif // no system header
2323

2424
#include <cuda/std/__bit/bit_cast.h>
25-
#include <cuda/std/__type_traits/is_arithmetic.h>
25+
#include <cuda/std/__type_traits/integral_constant.h>
26+
#include <cuda/std/__type_traits/is_extended_floating_point.h>
27+
#include <cuda/std/__type_traits/is_floating_point.h>
28+
#include <cuda/std/__type_traits/is_integral.h>
2629
#include <cuda/std/climits>
2730
#include <cuda/std/version>
2831

@@ -46,7 +49,47 @@ enum float_denorm_style
4649
denorm_present = 1
4750
};
4851

49-
template <class _Tp, bool = is_arithmetic<_Tp>::value>
52+
enum class __numeric_limits_type
53+
{
54+
__integral,
55+
__bool,
56+
__floating_point,
57+
__other,
58+
};
59+
60+
template <class _Tp>
61+
_LIBCUDACXX_HIDE_FROM_ABI constexpr __numeric_limits_type __make_numeric_limits_type()
62+
{
63+
#if !defined(_CCCL_NO_IF_CONSTEXPR)
64+
_CCCL_IF_CONSTEXPR (_CCCL_TRAIT(is_same, _Tp, bool))
65+
{
66+
return __numeric_limits_type::__bool;
67+
}
68+
else _CCCL_IF_CONSTEXPR (_CCCL_TRAIT(is_integral, _Tp))
69+
{
70+
return __numeric_limits_type::__integral;
71+
}
72+
else _CCCL_IF_CONSTEXPR (_CCCL_TRAIT(is_floating_point, _Tp) || _CCCL_TRAIT(__is_extended_floating_point, _Tp))
73+
{
74+
return __numeric_limits_type::__floating_point;
75+
}
76+
else
77+
{
78+
return __numeric_limits_type::__other;
79+
}
80+
_CCCL_UNREACHABLE();
81+
#else // ^^^ !_CCCL_NO_IF_CONSTEXPR ^^^ // vvv _CCCL_NO_IF_CONSTEXPR vvv
82+
return _CCCL_TRAIT(is_same, _Tp, bool)
83+
? __numeric_limits_type::__bool
84+
: (_CCCL_TRAIT(is_integral, _Tp)
85+
? __numeric_limits_type::__integral
86+
: (_CCCL_TRAIT(is_floating_point, _Tp) || _CCCL_TRAIT(__is_extended_floating_point, _Tp)
87+
? __numeric_limits_type::__floating_point
88+
: __numeric_limits_type::__other));
89+
#endif // _CCCL_NO_IF_CONSTEXPR
90+
}
91+
92+
template <class _Tp, __numeric_limits_type = __make_numeric_limits_type<_Tp>()>
5093
class __numeric_limits_impl
5194
{
5295
public:
@@ -135,7 +178,7 @@ struct __int_min<_Tp, __digits, false>
135178
};
136179

137180
template <class _Tp>
138-
class __numeric_limits_impl<_Tp, true>
181+
class __numeric_limits_impl<_Tp, __numeric_limits_type::__integral>
139182
{
140183
public:
141184
using type = _Tp;
@@ -212,7 +255,7 @@ public:
212255
};
213256

214257
template <>
215-
class __numeric_limits_impl<bool, true>
258+
class __numeric_limits_impl<bool, __numeric_limits_type::__bool>
216259
{
217260
public:
218261
using type = bool;
@@ -286,7 +329,7 @@ public:
286329
};
287330

288331
template <>
289-
class __numeric_limits_impl<float, true>
332+
class __numeric_limits_impl<float, __numeric_limits_type::__floating_point>
290333
{
291334
public:
292335
using type = float;
@@ -381,7 +424,7 @@ public:
381424
};
382425

383426
template <>
384-
class __numeric_limits_impl<double, true>
427+
class __numeric_limits_impl<double, __numeric_limits_type::__floating_point>
385428
{
386429
public:
387430
using type = double;
@@ -476,7 +519,7 @@ public:
476519
};
477520

478521
template <>
479-
class __numeric_limits_impl<long double, true>
522+
class __numeric_limits_impl<long double, __numeric_limits_type::__floating_point>
480523
{
481524
#ifndef _LIBCUDACXX_HAS_NO_LONG_DOUBLE
482525

@@ -551,6 +594,156 @@ public:
551594
#endif // !_LIBCUDACXX_HAS_NO_LONG_DOUBLE
552595
};
553596

597+
#if defined(_LIBCUDACXX_HAS_NVFP16)
598+
template <>
599+
class __numeric_limits_impl<__half, __numeric_limits_type::__floating_point>
600+
{
601+
public:
602+
using type = __half;
603+
604+
static constexpr bool is_specialized = true;
605+
606+
static constexpr bool is_signed = true;
607+
static constexpr int digits = 11;
608+
static constexpr int digits10 = 3;
609+
static constexpr int max_digits10 = 5;
610+
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type min() noexcept
611+
{
612+
return type(__half_raw{0x0400u});
613+
}
614+
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type max() noexcept
615+
{
616+
return type(__half_raw{0x7bffu});
617+
}
618+
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type lowest() noexcept
619+
{
620+
return type(__half_raw{0xfbffu});
621+
}
622+
623+
static constexpr bool is_integer = false;
624+
static constexpr bool is_exact = false;
625+
static constexpr int radix = __FLT_RADIX__;
626+
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type epsilon() noexcept
627+
{
628+
return type(__half_raw{0x1400u});
629+
}
630+
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type round_error() noexcept
631+
{
632+
return type(__half_raw{0x3800u});
633+
}
634+
635+
static constexpr int min_exponent = -13;
636+
static constexpr int min_exponent10 = -4;
637+
static constexpr int max_exponent = 16;
638+
static constexpr int max_exponent10 = 4;
639+
640+
static constexpr bool has_infinity = true;
641+
static constexpr bool has_quiet_NaN = true;
642+
static constexpr bool has_signaling_NaN = true;
643+
static constexpr float_denorm_style has_denorm = denorm_present;
644+
static constexpr bool has_denorm_loss = false;
645+
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type infinity() noexcept
646+
{
647+
return type(__half_raw{0x7c00u});
648+
}
649+
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type quiet_NaN() noexcept
650+
{
651+
return type(__half_raw{0x7e00u});
652+
}
653+
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type signaling_NaN() noexcept
654+
{
655+
return type(__half_raw{0x7d00u});
656+
}
657+
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type denorm_min() noexcept
658+
{
659+
return type(__half_raw{0x0001u});
660+
}
661+
662+
static constexpr bool is_iec559 = true;
663+
static constexpr bool is_bounded = true;
664+
static constexpr bool is_modulo = false;
665+
666+
static constexpr bool traps = false;
667+
static constexpr bool tinyness_before = false;
668+
static constexpr float_round_style round_style = round_to_nearest;
669+
};
670+
#endif // _LIBCUDACXX_HAS_NVFP16
671+
672+
#if defined(_LIBCUDACXX_HAS_NVBF16)
673+
template <>
674+
class __numeric_limits_impl<__nv_bfloat16, __numeric_limits_type::__floating_point>
675+
{
676+
public:
677+
using type = __nv_bfloat16;
678+
679+
static constexpr bool is_specialized = true;
680+
681+
static constexpr bool is_signed = true;
682+
static constexpr int digits = 8;
683+
static constexpr int digits10 = 2;
684+
static constexpr int max_digits10 = 4;
685+
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type min() noexcept
686+
{
687+
return type(__nv_bfloat16_raw{0x0080u});
688+
}
689+
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type max() noexcept
690+
{
691+
return type(__nv_bfloat16_raw{0x7f7fu});
692+
}
693+
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type lowest() noexcept
694+
{
695+
return type(__nv_bfloat16_raw{0xff7fu});
696+
}
697+
698+
static constexpr bool is_integer = false;
699+
static constexpr bool is_exact = false;
700+
static constexpr int radix = __FLT_RADIX__;
701+
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type epsilon() noexcept
702+
{
703+
return type(__nv_bfloat16_raw{0x3c00u});
704+
}
705+
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type round_error() noexcept
706+
{
707+
return type(__nv_bfloat16_raw{0x3f00u});
708+
}
709+
710+
static constexpr int min_exponent = -125;
711+
static constexpr int min_exponent10 = -37;
712+
static constexpr int max_exponent = 128;
713+
static constexpr int max_exponent10 = 38;
714+
715+
static constexpr bool has_infinity = true;
716+
static constexpr bool has_quiet_NaN = true;
717+
static constexpr bool has_signaling_NaN = true;
718+
static constexpr float_denorm_style has_denorm = denorm_present;
719+
static constexpr bool has_denorm_loss = false;
720+
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type infinity() noexcept
721+
{
722+
return type(__nv_bfloat16_raw{0x7f80u});
723+
}
724+
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type quiet_NaN() noexcept
725+
{
726+
return type(__nv_bfloat16_raw{0x7fc0u});
727+
}
728+
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type signaling_NaN() noexcept
729+
{
730+
return type(__nv_bfloat16_raw{0x7fa0u});
731+
}
732+
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type denorm_min() noexcept
733+
{
734+
return type(__nv_bfloat16_raw{0x0001u});
735+
}
736+
737+
static constexpr bool is_iec559 = true;
738+
static constexpr bool is_bounded = true;
739+
static constexpr bool is_modulo = false;
740+
741+
static constexpr bool traps = false;
742+
static constexpr bool tinyness_before = false;
743+
static constexpr float_round_style round_style = round_to_nearest;
744+
};
745+
#endif // _LIBCUDACXX_HAS_NVBF16
746+
554747
template <class _Tp>
555748
class numeric_limits : public __numeric_limits_impl<_Tp>
556749
{};

libcudacxx/test/libcudacxx/std/containers/views/mdspan/my_int.hpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
#ifndef _MY_INT_HPP
22
#define _MY_INT_HPP
33

4+
#include <cuda/std/limits>
5+
#include <cuda/std/type_traits>
6+
47
#include "test_macros.h"
58

69
struct my_int_non_convertible;
@@ -22,6 +25,10 @@ template <>
2225
struct cuda::std::is_integral<my_int> : cuda::std::true_type
2326
{};
2427

28+
template <>
29+
class cuda::std::numeric_limits<my_int> : public cuda::std::numeric_limits<int>
30+
{};
31+
2532
// Wrapper type that's not implicitly convertible
2633

2734
struct my_int_non_convertible
@@ -43,6 +50,10 @@ template <>
4350
struct cuda::std::is_integral<my_int_non_convertible> : cuda::std::true_type
4451
{};
4552

53+
template <>
54+
class cuda::std::numeric_limits<my_int_non_convertible> : public cuda::std::numeric_limits<int>
55+
{};
56+
4657
// Wrapper type that's not nothrow-constructible
4758

4859
struct my_int_non_nothrow_constructible
@@ -62,4 +73,8 @@ template <>
6273
struct cuda::std::is_integral<my_int_non_nothrow_constructible> : cuda::std::true_type
6374
{};
6475

76+
template <>
77+
class cuda::std::numeric_limits<my_int_non_nothrow_constructible> : public cuda::std::numeric_limits<int>
78+
{};
79+
6580
#endif

libcudacxx/test/libcudacxx/std/language.support/support.limits/limits/is_specialized.pass.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,13 @@ int main(int, char**)
6868
#ifndef _LIBCUDACXX_HAS_NO_LONG_DOUBLE
6969
test<long double>();
7070
#endif
71+
#if defined(_LIBCUDACXX_HAS_NVFP16)
72+
test<__half>();
73+
#endif // _LIBCUDACXX_HAS_NVFP16
74+
#if defined(_LIBCUDACXX_HAS_NVBF16)
75+
test<__nv_bfloat16>();
76+
#endif // _LIBCUDACXX_HAS_NVBF16
77+
7178
static_assert(!cuda::std::numeric_limits<cuda::std::complex<double>>::is_specialized,
7279
"!cuda::std::numeric_limits<cuda::std::complex<double> >::is_specialized");
7380

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
7+
//
8+
//===----------------------------------------------------------------------===//
9+
10+
#ifndef NUMERIC_LIMITS_MEMBERS_COMMON_H
11+
#define NUMERIC_LIMITS_MEMBERS_COMMON_H
12+
13+
// Disable all the extended floating point operations and conversions
14+
#define __CUDA_NO_HALF_CONVERSIONS__ 1
15+
#define __CUDA_NO_HALF_OPERATORS__ 1
16+
#define __CUDA_NO_BFLOAT16_CONVERSIONS__ 1
17+
#define __CUDA_NO_BFLOAT16_OPERATORS__ 1
18+
19+
#include <cuda/std/limits>
20+
21+
template <class T>
22+
__host__ __device__ bool float_eq(T x, T y)
23+
{
24+
return x == y;
25+
}
26+
27+
#if defined(_LIBCUDACXX_HAS_NVFP16)
28+
__host__ __device__ inline bool float_eq(__half x, __half y)
29+
{
30+
return __heq(x, y);
31+
}
32+
#endif // _LIBCUDACXX_HAS_NVFP16
33+
34+
#if defined(_LIBCUDACXX_HAS_NVBF16)
35+
__host__ __device__ inline bool float_eq(__nv_bfloat16 x, __nv_bfloat16 y)
36+
{
37+
return __heq(x, y);
38+
}
39+
#endif // _LIBCUDACXX_HAS_NVBF16
40+
41+
#endif // NUMERIC_LIMITS_MEMBERS_COMMON_H

0 commit comments

Comments
 (0)