Skip to content

Commit 4f2efaf

Browse files
authored
Implement cuda::std::numeric_limits for __half and __nv_bfloat16 (#3361)
* implement `cuda::std::numeric_limits` for `__half` and `__nv_bfloat16`
1 parent c339a52 commit 4f2efaf

36 files changed

+563
-201
lines changed

libcudacxx/include/cuda/std/limits

Lines changed: 199 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,10 @@
2222
#endif // no system header
2323

2424
#include <cuda/std/__bit/bit_cast.h>
25-
#include <cuda/std/__type_traits/is_arithmetic.h>
25+
#include <cuda/std/__type_traits/integral_constant.h>
26+
#include <cuda/std/__type_traits/is_extended_floating_point.h>
27+
#include <cuda/std/__type_traits/is_floating_point.h>
28+
#include <cuda/std/__type_traits/is_integral.h>
2629
#include <cuda/std/climits>
2730
#include <cuda/std/version>
2831

@@ -46,7 +49,46 @@ enum float_denorm_style
4649
denorm_present = 1
4750
};
4851

49-
template <class _Tp, bool = is_arithmetic<_Tp>::value>
52+
enum class __numeric_limits_type
53+
{
54+
__integral,
55+
__bool,
56+
__floating_point,
57+
__other,
58+
};
59+
60+
template <class _Tp>
61+
_LIBCUDACXX_HIDE_FROM_ABI constexpr __numeric_limits_type __make_numeric_limits_type()
62+
{
63+
#if !defined(_CCCL_NO_IF_CONSTEXPR)
64+
_CCCL_IF_CONSTEXPR (_CCCL_TRAIT(is_same, _Tp, bool))
65+
{
66+
return __numeric_limits_type::__bool;
67+
}
68+
else _CCCL_IF_CONSTEXPR (_CCCL_TRAIT(is_integral, _Tp))
69+
{
70+
return __numeric_limits_type::__integral;
71+
}
72+
else _CCCL_IF_CONSTEXPR (_CCCL_TRAIT(is_floating_point, _Tp) || _CCCL_TRAIT(__is_extended_floating_point, _Tp))
73+
{
74+
return __numeric_limits_type::__floating_point;
75+
}
76+
else
77+
{
78+
return __numeric_limits_type::__other;
79+
}
80+
#else // ^^^ !_CCCL_NO_IF_CONSTEXPR ^^^ // vvv _CCCL_NO_IF_CONSTEXPR vvv
81+
return _CCCL_TRAIT(is_same, _Tp, bool)
82+
? __numeric_limits_type::__bool
83+
: (_CCCL_TRAIT(is_integral, _Tp)
84+
? __numeric_limits_type::__integral
85+
: (_CCCL_TRAIT(is_floating_point, _Tp) || _CCCL_TRAIT(__is_extended_floating_point, _Tp)
86+
? __numeric_limits_type::__floating_point
87+
: __numeric_limits_type::__other));
88+
#endif // _CCCL_NO_IF_CONSTEXPR
89+
}
90+
91+
template <class _Tp, __numeric_limits_type = __make_numeric_limits_type<_Tp>()>
5092
class __numeric_limits_impl
5193
{
5294
public:
@@ -135,7 +177,7 @@ struct __int_min<_Tp, __digits, false>
135177
};
136178

137179
template <class _Tp>
138-
class __numeric_limits_impl<_Tp, true>
180+
class __numeric_limits_impl<_Tp, __numeric_limits_type::__integral>
139181
{
140182
public:
141183
using type = _Tp;
@@ -212,7 +254,7 @@ public:
212254
};
213255

214256
template <>
215-
class __numeric_limits_impl<bool, true>
257+
class __numeric_limits_impl<bool, __numeric_limits_type::__bool>
216258
{
217259
public:
218260
using type = bool;
@@ -286,7 +328,7 @@ public:
286328
};
287329

288330
template <>
289-
class __numeric_limits_impl<float, true>
331+
class __numeric_limits_impl<float, __numeric_limits_type::__floating_point>
290332
{
291333
public:
292334
using type = float;
@@ -381,7 +423,7 @@ public:
381423
};
382424

383425
template <>
384-
class __numeric_limits_impl<double, true>
426+
class __numeric_limits_impl<double, __numeric_limits_type::__floating_point>
385427
{
386428
public:
387429
using type = double;
@@ -476,7 +518,7 @@ public:
476518
};
477519

478520
template <>
479-
class __numeric_limits_impl<long double, true>
521+
class __numeric_limits_impl<long double, __numeric_limits_type::__floating_point>
480522
{
481523
#ifndef _LIBCUDACXX_HAS_NO_LONG_DOUBLE
482524

@@ -551,6 +593,156 @@ public:
551593
#endif // !_LIBCUDACXX_HAS_NO_LONG_DOUBLE
552594
};
553595

596+
#if defined(_LIBCUDACXX_HAS_NVFP16)
597+
template <>
598+
class __numeric_limits_impl<__half, __numeric_limits_type::__floating_point>
599+
{
600+
public:
601+
using type = __half;
602+
603+
static constexpr bool is_specialized = true;
604+
605+
static constexpr bool is_signed = true;
606+
static constexpr int digits = 11;
607+
static constexpr int digits10 = 3;
608+
static constexpr int max_digits10 = 5;
609+
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type min() noexcept
610+
{
611+
return type(__half_raw{0x0400u});
612+
}
613+
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type max() noexcept
614+
{
615+
return type(__half_raw{0x7bffu});
616+
}
617+
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type lowest() noexcept
618+
{
619+
return type(__half_raw{0xfbffu});
620+
}
621+
622+
static constexpr bool is_integer = false;
623+
static constexpr bool is_exact = false;
624+
static constexpr int radix = __FLT_RADIX__;
625+
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type epsilon() noexcept
626+
{
627+
return type(__half_raw{0x1400u});
628+
}
629+
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type round_error() noexcept
630+
{
631+
return type(__half_raw{0x3800u});
632+
}
633+
634+
static constexpr int min_exponent = -13;
635+
static constexpr int min_exponent10 = -4;
636+
static constexpr int max_exponent = 16;
637+
static constexpr int max_exponent10 = 4;
638+
639+
static constexpr bool has_infinity = true;
640+
static constexpr bool has_quiet_NaN = true;
641+
static constexpr bool has_signaling_NaN = true;
642+
static constexpr float_denorm_style has_denorm = denorm_present;
643+
static constexpr bool has_denorm_loss = false;
644+
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type infinity() noexcept
645+
{
646+
return type(__half_raw{0x7c00u});
647+
}
648+
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type quiet_NaN() noexcept
649+
{
650+
return type(__half_raw{0x7e00u});
651+
}
652+
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type signaling_NaN() noexcept
653+
{
654+
return type(__half_raw{0x7d00u});
655+
}
656+
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type denorm_min() noexcept
657+
{
658+
return type(__half_raw{0x0001u});
659+
}
660+
661+
static constexpr bool is_iec559 = true;
662+
static constexpr bool is_bounded = true;
663+
static constexpr bool is_modulo = false;
664+
665+
static constexpr bool traps = false;
666+
static constexpr bool tinyness_before = false;
667+
static constexpr float_round_style round_style = round_to_nearest;
668+
};
669+
#endif // _LIBCUDACXX_HAS_NVFP16
670+
671+
#if defined(_LIBCUDACXX_HAS_NVBF16)
672+
template <>
673+
class __numeric_limits_impl<__nv_bfloat16, __numeric_limits_type::__floating_point>
674+
{
675+
public:
676+
using type = __nv_bfloat16;
677+
678+
static constexpr bool is_specialized = true;
679+
680+
static constexpr bool is_signed = true;
681+
static constexpr int digits = 8;
682+
static constexpr int digits10 = 2;
683+
static constexpr int max_digits10 = 4;
684+
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type min() noexcept
685+
{
686+
return type(__nv_bfloat16_raw{0x0080u});
687+
}
688+
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type max() noexcept
689+
{
690+
return type(__nv_bfloat16_raw{0x7f7fu});
691+
}
692+
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type lowest() noexcept
693+
{
694+
return type(__nv_bfloat16_raw{0xff7fu});
695+
}
696+
697+
static constexpr bool is_integer = false;
698+
static constexpr bool is_exact = false;
699+
static constexpr int radix = __FLT_RADIX__;
700+
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type epsilon() noexcept
701+
{
702+
return type(__nv_bfloat16_raw{0x3c00u});
703+
}
704+
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type round_error() noexcept
705+
{
706+
return type(__nv_bfloat16_raw{0x3f00u});
707+
}
708+
709+
static constexpr int min_exponent = -125;
710+
static constexpr int min_exponent10 = -37;
711+
static constexpr int max_exponent = 128;
712+
static constexpr int max_exponent10 = 38;
713+
714+
static constexpr bool has_infinity = true;
715+
static constexpr bool has_quiet_NaN = true;
716+
static constexpr bool has_signaling_NaN = true;
717+
static constexpr float_denorm_style has_denorm = denorm_present;
718+
static constexpr bool has_denorm_loss = false;
719+
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type infinity() noexcept
720+
{
721+
return type(__nv_bfloat16_raw{0x7f80u});
722+
}
723+
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type quiet_NaN() noexcept
724+
{
725+
return type(__nv_bfloat16_raw{0x7fc0u});
726+
}
727+
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type signaling_NaN() noexcept
728+
{
729+
return type(__nv_bfloat16_raw{0x7fa0u});
730+
}
731+
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type denorm_min() noexcept
732+
{
733+
return type(__nv_bfloat16_raw{0x0001u});
734+
}
735+
736+
static constexpr bool is_iec559 = true;
737+
static constexpr bool is_bounded = true;
738+
static constexpr bool is_modulo = false;
739+
740+
static constexpr bool traps = false;
741+
static constexpr bool tinyness_before = false;
742+
static constexpr float_round_style round_style = round_to_nearest;
743+
};
744+
#endif // _LIBCUDACXX_HAS_NVBF16
745+
554746
template <class _Tp>
555747
class numeric_limits : public __numeric_limits_impl<_Tp>
556748
{};

libcudacxx/test/libcudacxx/std/containers/views/mdspan/my_int.hpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
#ifndef _MY_INT_HPP
22
#define _MY_INT_HPP
33

4+
#include <cuda/std/limits>
5+
#include <cuda/std/type_traits>
6+
47
#include "test_macros.h"
58

69
struct my_int_non_convertible;
@@ -22,6 +25,10 @@ template <>
2225
struct cuda::std::is_integral<my_int> : cuda::std::true_type
2326
{};
2427

28+
template <>
29+
class cuda::std::numeric_limits<my_int> : public cuda::std::numeric_limits<int>
30+
{};
31+
2532
// Wrapper type that's not implicitly convertible
2633

2734
struct my_int_non_convertible
@@ -43,6 +50,10 @@ template <>
4350
struct cuda::std::is_integral<my_int_non_convertible> : cuda::std::true_type
4451
{};
4552

53+
template <>
54+
class cuda::std::numeric_limits<my_int_non_convertible> : public cuda::std::numeric_limits<int>
55+
{};
56+
4657
// Wrapper type that's not nothrow-constructible
4758

4859
struct my_int_non_nothrow_constructible
@@ -62,4 +73,8 @@ template <>
6273
struct cuda::std::is_integral<my_int_non_nothrow_constructible> : cuda::std::true_type
6374
{};
6475

76+
template <>
77+
class cuda::std::numeric_limits<my_int_non_nothrow_constructible> : public cuda::std::numeric_limits<int>
78+
{};
79+
6580
#endif

libcudacxx/test/libcudacxx/std/language.support/support.limits/limits/is_specialized.pass.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,13 @@ int main(int, char**)
6868
#ifndef _LIBCUDACXX_HAS_NO_LONG_DOUBLE
6969
test<long double>();
7070
#endif
71+
#if defined(_LIBCUDACXX_HAS_NVFP16)
72+
test<__half>();
73+
#endif // _LIBCUDACXX_HAS_NVFP16
74+
#if defined(_LIBCUDACXX_HAS_NVBF16)
75+
test<__nv_bfloat16>();
76+
#endif // _LIBCUDACXX_HAS_NVBF16
77+
7178
static_assert(!cuda::std::numeric_limits<cuda::std::complex<double>>::is_specialized,
7279
"!cuda::std::numeric_limits<cuda::std::complex<double> >::is_specialized");
7380

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
7+
//
8+
//===----------------------------------------------------------------------===//
9+
10+
#ifndef NUMERIC_LIMITS_MEMBERS_COMMON_H
11+
#define NUMERIC_LIMITS_MEMBERS_COMMON_H
12+
13+
// Disable all the extended floating point operations and conversions
14+
#define __CUDA_NO_HALF_CONVERSIONS__ 1
15+
#define __CUDA_NO_HALF_OPERATORS__ 1
16+
#define __CUDA_NO_BFLOAT16_CONVERSIONS__ 1
17+
#define __CUDA_NO_BFLOAT16_OPERATORS__ 1
18+
19+
#include <cuda/std/limits>
20+
21+
template <class T>
22+
__host__ __device__ bool float_eq(T x, T y)
23+
{
24+
return x == y;
25+
}
26+
27+
#if defined(_LIBCUDACXX_HAS_NVFP16)
28+
__host__ __device__ inline bool float_eq(__half x, __half y)
29+
{
30+
return __heq(x, y);
31+
}
32+
#endif // _LIBCUDACXX_HAS_NVFP16
33+
34+
#if defined(_LIBCUDACXX_HAS_NVBF16)
35+
__host__ __device__ inline bool float_eq(__nv_bfloat16 x, __nv_bfloat16 y)
36+
{
37+
return __heq(x, y);
38+
}
39+
#endif // _LIBCUDACXX_HAS_NVBF16
40+
41+
#endif // NUMERIC_LIMITS_MEMBERS_COMMON_H

0 commit comments

Comments
 (0)