Skip to content

Commit 6a2cd90

Browse files
garimagubader
authored and committed
[SYCL] Float2Half precision requirements, convert method correction.
- Added support for the output operator for the half data type. - Edited the vec_convert.cpp test case to cover the half data type. Corrected the source code for the float2Half conversion. - Corrected the automatic/rte mode implementation. Signed-off-by: Garima Gupta <garima.gupta@intel.com>
1 parent b55e0d1 commit 6a2cd90

File tree

4 files changed

+61
-27
lines changed

4 files changed

+61
-27
lines changed

sycl/include/CL/sycl/half_type.hpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include <cstdint>
1212
#include <functional>
13+
#include <iostream>
1314

1415
namespace cl {
1516
namespace sycl {
@@ -95,3 +96,21 @@ template <> struct hash<cl::sycl::detail::half_impl::half> {
9596
};
9697

9798
} // namespace std
99+
100+
// Global alias `half`: resolves to the built-in _Float16 type when
// __SYCL_DEVICE_ONLY__ is defined, and to the library class
// cl::sycl::detail::half_impl::half otherwise.
#ifdef __SYCL_DEVICE_ONLY__
101+
using half = _Float16;
102+
#else
103+
using half = cl::sycl::detail::half_impl::half;
104+
#endif
105+
106+
// Stream insertion for half: writes rhs to O after converting it to float
// with static_cast, then returns O so insertions can be chained.
inline std::ostream &operator<<(std::ostream &O, half const &rhs) {
107+
O << static_cast<float>(rhs);
108+
return O;
109+
}
110+
111+
// Stream extraction for half: reads a float from I, assigns it to rhs, and
// returns I so extractions can be chained.
inline std::istream &operator>>(std::istream &I, half &rhs) {
112+
float ValFloat = 0.0f;
113+
I >> ValFloat;
114+
// rhs is assigned unconditionally, i.e. without checking the state of I
// after the extraction; callers should test the returned stream.
rhs = ValFloat;
115+
return I;
116+
}

sycl/include/CL/sycl/types.hpp

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -51,16 +51,14 @@
5151

5252
#include <array>
5353
#include <cmath>
54+
#ifndef __SYCL_DEVICE_ONLY__
55+
#include <cfenv>
56+
#pragma STDC FENV_ACCESS ON
57+
#endif
5458

5559
// 4.10.1: Scalar data types
5660
// 4.10.2: SYCL vector types
5761

58-
#ifdef __SYCL_DEVICE_ONLY__
59-
using half = _Float16;
60-
#else
61-
using half = cl::sycl::detail::half_impl::half;
62-
#endif
63-
6462
namespace cl {
6563
namespace sycl {
6664

@@ -258,6 +256,8 @@ detail::enable_if_t<std::is_same<T, R>::value, R> convertImpl(T Value) {
258256
return Value;
259257
}
260258

259+
// Note: for float-to-half conversions, static_cast calls the conversion
260+
// operator implemented for the host, which takes care of the precision requirements.
261261
template <typename T, typename R, rounding_mode roundingMode>
262262
detail::enable_if_t<!std::is_same<T, R>::value &&
263263
(is_int_to_int<T, R>::value ||
@@ -270,16 +270,23 @@ convertImpl(T Value) {
270270

271271
// float to int
272272
template <typename T, typename R, rounding_mode roundingMode>
273-
detail::enable_if_t<!std::is_same<T, R>::value && is_float_to_int<T, R>::value,
274-
R>
275-
convertImpl(T Value) {
273+
detail::enable_if_t<is_float_to_int<T, R>::value, R> convertImpl(T Value) {
276274
#ifndef __SYCL_DEVICE_ONLY__
277275
switch (roundingMode) {
278276
// Round to nearest even is default rounding mode for floating-point types
279277
case rounding_mode::automatic:
280278
// Round to nearest even.
281-
case rounding_mode::rte:
282-
return std::round(Value);
279+
case rounding_mode::rte: {
280+
int OldRoundingDirection = std::fegetround();
281+
int Err = std::fesetround(FE_TONEAREST);
282+
if (Err)
283+
throw runtime_error("Unable to set rounding mode to FE_TONEAREST");
284+
R Result = std::rint(Value);
285+
Err = std::fesetround(OldRoundingDirection);
286+
if (Err)
287+
throw runtime_error("Unable to restore rounding mode.");
288+
return Result;
289+
}
283290
// Round toward zero.
284291
case rounding_mode::rtz:
285292
return std::trunc(Value);
@@ -294,7 +301,7 @@ convertImpl(T Value) {
294301
return static_cast<R>(Value);
295302
};
296303
#else
297-
// TODO implement device side convertion.
304+
// TODO implement device side conversion.
298305
return static_cast<R>(Value);
299306
#endif
300307
}

sycl/source/half_type.cpp

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
#include <CL/sycl/half_type.hpp>
1010
// This is included to enable __builtin_expect()
1111
#include <CL/sycl/detail/platform_util.hpp>
12-
#include <iostream>
1312
#include <cstring>
1413

1514
namespace cl {
@@ -28,7 +27,13 @@ static uint16_t float2Half(const float &Val) {
2827
const int8_t Exp32Diff = Exp32 - 127;
2928

3029
uint16_t Exp16 = 0;
30+
31+
// convert 23-bit mantissa to 10-bit mantissa.
3132
uint16_t Frac16 = Frac32 >> 13;
33+
// Round the mantissa as given in OpenCL spec section : 6.1.1.1 The half data
34+
// type.
35+
if (Frac32 >> 12 & 0x01)
36+
Frac16 += 1;
3237

3338
if (__builtin_expect(Exp32 == 0xff || Exp32Diff > 15, 0)) {
3439
Exp16 = 0x1f;
@@ -111,18 +116,6 @@ static float half2Float(const uint16_t &Val) {
111116
return Result;
112117
}
113118

114-
std::ostream &operator<<(std::ostream &O, const half_impl::half &Val) {
115-
O << static_cast<float>(Val);
116-
return O;
117-
}
118-
119-
std::istream &operator>>(std::istream &I, half_impl::half &ValHalf) {
120-
float ValFloat = 0.0f;
121-
I >> ValFloat;
122-
ValHalf = ValFloat;
123-
return I;
124-
}
125-
126119
namespace half_impl {
127120

128121
half::half(const float &RHS) : Buf(float2Half(RHS)) {}

sycl/test/basic_tests/vec_convert.cpp

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,27 +71,33 @@ int main() {
7171
int8{2, 3, 3, -2, -3, -3, 0, 0});
7272
test<float, int, 8, rounding_mode::automatic>(
7373
float8{+2.3f, +2.5f, +2.7f, -2.3f, -2.5f, -2.7f, 0.f, 0.f},
74-
int8{2, 3, 3, -2, -3, -3, 0, 0});
74+
int8{2, 2, 3, -2, -2, -3, 0, 0});
7575
test<int, float, 8, rounding_mode::automatic>(
7676
int8{2, 3, 3, -2, -3, -3, 0, 0},
7777
float8{2.f, 3.f, 3.f, -2.f, -3.f, -3.f, 0.f, 0.f});
7878
test<float, float, 8, rounding_mode::automatic>(
7979
float8{+2.3f, +2.5f, +2.7f, -2.3f, -2.5f, -2.7f, 0.f, 0.f},
8080
float8{+2.3f, +2.5f, +2.7f, -2.3f, -2.5f, -2.7f, 0.f, 0.f});
81+
test<float, half, 8, rounding_mode::automatic>(
82+
float8{+2.3f, +2.5f, +2.7f, -2.3f, -2.5f, -2.7f, 0.f, 0.f},
83+
half8{+2.3f, +2.5f, +2.7f, -2.3f, -2.5f, -2.7f, 0.f, 0.f});
8184

8285
// rte
8386
test<int, int, 8, rounding_mode::rte>(
8487
int8{2, 3, 3, -2, -3, -3, 0, 0},
8588
int8{2, 3, 3, -2, -3, -3, 0, 0});
8689
test<float, int, 8, rounding_mode::rte>(
8790
float8{+2.3f, +2.5f, +2.7f, -2.3f, -2.5f, -2.7f, 0.f, 0.f},
88-
int8{2, 3, 3, -2, -3, -3, 0, 0});
91+
int8{2, 2, 3, -2, -2, -3, 0, 0});
8992
test<int, float, 8, rounding_mode::rte>(
9093
int8{2, 3, 3, -2, -3, -3, 0, 0},
9194
float8{2.f, 3.f, 3.f, -2.f, -3.f, -3.f, 0.f, 0.f});
9295
test<float, float, 8, rounding_mode::rte>(
9396
float8{+2.3f, +2.5f, +2.7f, -2.3f, -2.5f, -2.7f, 0.f, 0.f},
9497
float8{+2.3f, +2.5f, +2.7f, -2.3f, -2.5f, -2.7f, 0.f, 0.f});
98+
test<float, half, 8, rounding_mode::rte>(
99+
float8{+2.3f, +2.5f, +2.7f, -2.3f, -2.5f, -2.7f, 0.f, 0.f},
100+
half8{+2.3f, +2.5f, +2.7f, -2.3f, -2.5f, -2.7f, 0.f, 0.f});
95101

96102
// rtz
97103
test<int, int, 8, rounding_mode::rtz>(
@@ -106,6 +112,9 @@ int main() {
106112
test<float, float, 8, rounding_mode::rtz>(
107113
float8{+2.3f, +2.5f, +2.7f, -2.3f, -2.5f, -2.7f, 0.f, 0.f},
108114
float8{+2.3f, +2.5f, +2.7f, -2.3f, -2.5f, -2.7f, 0.f, 0.f});
115+
test<float, half, 8, rounding_mode::rtz>(
116+
float8{+2.3f, +2.5f, +2.7f, -2.3f, -2.5f, -2.7f, 0.f, 0.f},
117+
half8{+2.3f, +2.5f, +2.7f, -2.3f, -2.5f, -2.7f, 0.f, 0.f});
109118

110119
// rtp
111120
test<int, int, 8, rounding_mode::rtp>(
@@ -120,6 +129,9 @@ int main() {
120129
test<float, float, 8, rounding_mode::rtp>(
121130
float8{+2.3f, +2.5f, +2.7f, -2.3f, -2.5f, -2.7f, 0.f, 0.f},
122131
float8{+2.3f, +2.5f, +2.7f, -2.3f, -2.5f, -2.7f, 0.f, 0.f});
132+
test<float, half, 8, rounding_mode::rtp>(
133+
float8{+2.3f, +2.5f, +2.7f, -2.3f, -2.5f, -2.7f, 0.f, 0.f},
134+
half8{+2.3f, +2.5f, +2.7f, -2.3f, -2.5f, -2.7f, 0.f, 0.f});
123135

124136
// rtn
125137
test<int, int, 8, rounding_mode::rtn>(
@@ -134,6 +146,9 @@ int main() {
134146
test<float, float, 8, rounding_mode::rtn>(
135147
float8{+2.3f, +2.5f, +2.7f, -2.3f, -2.5f, -2.7f, 0.f, 0.f},
136148
float8{+2.3f, +2.5f, +2.7f, -2.3f, -2.5f, -2.7f, 0.f, 0.f});
149+
test<float, half, 8, rounding_mode::rtn>(
150+
float8{+2.3f, +2.5f, +2.7f, -2.3f, -2.5f, -2.7f, 0.f, 0.f},
151+
half8{+2.3f, +2.5f, +2.7f, -2.3f, -2.5f, -2.7f, 0.f, 0.f});
137152

138153
return 0;
139154
}

0 commit comments

Comments
 (0)