Skip to content

Commit 17bf4b6

Browse files
committed
[SYCL] Add support for half type
Signed-off-by: Mariya Podchishchaeva <mariya.podchishchaeva@intel.com>
1 parent e87838c commit 17bf4b6

File tree

10 files changed

+718
-83
lines changed

10 files changed

+718
-83
lines changed

sycl/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ add_library("${SYCLLibrary}" SHARED
9999
"${sourceRootPath}/device_selector.cpp"
100100
"${sourceRootPath}/event.cpp"
101101
"${sourceRootPath}/exception.cpp"
102+
"${sourceRootPath}/half_type.cpp"
102103
"${sourceRootPath}/kernel.cpp"
103104
"${sourceRootPath}/platform.cpp"
104105
"${sourceRootPath}/queue.cpp"

sycl/include/CL/sycl/half_type.hpp

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
//==-------------- half_type.hpp --- SYCL half type ------------------------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#pragma once
10+
11+
#include <cstdint>
12+
#include <functional>
13+
14+
namespace cl {
15+
namespace sycl {
16+
namespace detail {
17+
namespace half_impl {
18+
19+
class half {
20+
public:
21+
half() = default;
22+
half(const half &) = default;
23+
half(half &&) = default;
24+
25+
half(const float &rhs);
26+
27+
half &operator=(const half &rhs) = default;
28+
29+
// Operator +=, -=, *=, /=
30+
half &operator+=(const half &rhs);
31+
32+
half &operator-=(const half &rhs);
33+
34+
half &operator*=(const half &rhs);
35+
36+
half &operator/=(const half &rhs);
37+
38+
// Operator ++, --
39+
half &operator++() {
40+
*this += 1;
41+
return *this;
42+
}
43+
44+
half operator++(int) {
45+
half ret(*this);
46+
operator++();
47+
return ret;
48+
}
49+
50+
half &operator--() {
51+
*this -= 1;
52+
return *this;
53+
}
54+
55+
half operator--(int) {
56+
half ret(*this);
57+
operator--();
58+
return ret;
59+
}
60+
61+
// Operator float
62+
operator float() const;
63+
64+
template <typename Key> friend struct std::hash;
65+
66+
private:
67+
uint16_t Buf;
68+
};
69+
} // namespace half_impl
70+
} // namespace detail
71+
72+
} // namespace sycl
73+
} // namespace cl
74+
75+
namespace std {
76+
77+
template <> struct hash<cl::sycl::detail::half_impl::half> {
78+
size_t operator()(cl::sycl::detail::half_impl::half const &key) const
79+
noexcept {
80+
return hash<uint16_t>()(key.Buf);
81+
}
82+
};
83+
84+
} // namespace std

sycl/include/CL/sycl/intel/sub_group.hpp

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -139,11 +139,18 @@ struct sub_group {
139139
return BinaryOperation::template calc<T, cl::__spirv::InclusiveScan>(x);
140140
}
141141

142+
template <typename T>
143+
using EnableIfIsArithmeticOrHalf = typename std::enable_if<
144+
(std::is_arithmetic<T>::value ||
145+
std::is_same<typename std::remove_const<T>::type, half>::value),
146+
T>::type;
147+
148+
142149
/* --- one - input shuffles --- */
143150
/* indices in [0 , sub - group size ) */
144151

145152
template <typename T>
146-
typename std::enable_if<std::is_arithmetic<T>::value, T>::type
153+
EnableIfIsArithmeticOrHalf<T>
147154
shuffle(T x, id<1> local_id) {
148155
return cl::__spirv::OpSubgroupShuffleINTEL(x, local_id.get(0));
149156
}
@@ -156,7 +163,7 @@ struct sub_group {
156163
}
157164

158165
template <typename T>
159-
typename std::enable_if<std::is_arithmetic<T>::value, T>::type
166+
EnableIfIsArithmeticOrHalf<T>
160167
shuffle_down(T x, uint32_t delta) {
161168
return shuffle_down(x, x, delta);
162169
}
@@ -168,7 +175,7 @@ struct sub_group {
168175
}
169176

170177
template <typename T>
171-
typename std::enable_if<std::is_arithmetic<T>::value, T>::type
178+
EnableIfIsArithmeticOrHalf<T>
172179
shuffle_up(T x, uint32_t delta) {
173180
return shuffle_up(x, x, delta);
174181
}
@@ -180,7 +187,7 @@ struct sub_group {
180187
}
181188

182189
template <typename T>
183-
typename std::enable_if<std::is_arithmetic<T>::value, T>::type
190+
EnableIfIsArithmeticOrHalf<T>
184191
shuffle_xor(T x, id<1> value) {
185192
return cl::__spirv::OpSubgroupShuffleXorINTEL(x, (uint32_t)value.get(0));
186193
}
@@ -195,7 +202,7 @@ struct sub_group {
195202
/* --- two - input shuffles --- */
196203
/* indices in [0 , 2* sub - group size ) */
197204
template <typename T>
198-
typename std::enable_if<std::is_arithmetic<T>::value, T>::type
205+
EnableIfIsArithmeticOrHalf<T>
199206
shuffle(T x, T y, id<1> local_id) {
200207
return cl::__spirv::OpSubgroupShuffleDownINTEL(
201208
x, y, local_id.get(0) - get_local_id().get(0));
@@ -210,7 +217,7 @@ struct sub_group {
210217
}
211218

212219
template <typename T>
213-
typename std::enable_if<std::is_arithmetic<T>::value, T>::type
220+
EnableIfIsArithmeticOrHalf<T>
214221
shuffle_down(T current, T next, uint32_t delta) {
215222
return cl::__spirv::OpSubgroupShuffleDownINTEL(current, next, delta);
216223
}
@@ -223,7 +230,7 @@ struct sub_group {
223230
}
224231

225232
template <typename T>
226-
typename std::enable_if<std::is_arithmetic<T>::value, T>::type
233+
EnableIfIsArithmeticOrHalf<T>
227234
shuffle_up(T previous, T current, uint32_t delta) {
228235
return cl::__spirv::OpSubgroupShuffleUpINTEL(previous, current, delta);
229236
}

0 commit comments

Comments
 (0)