@@ -86,8 +86,10 @@ template <typename Group> bool GroupAny(bool pred) {
8686}
8787
8888// Native broadcasts map directly to a SPIR-V GroupBroadcast intrinsic
89+ // FIXME: Do not special-case for half once all backends support all data types.
// is_native_broadcast<T>::value is true only when detail::is_arithmetic<T>
// holds AND T is not half. The line marked '-' is the earlier definition that
// accepted every arithmetic type; the lines marked '+' add the half exclusion.
// NOTE(review): which broadcast path half takes instead depends on traits
// not visible in this hunk -- confirm.
8990template <typename T>
90- using is_native_broadcast = bool_constant<detail::is_arithmetic<T>::value>;
91+ using is_native_broadcast = bool_constant<detail::is_arithmetic<T>::value &&
92+ !std::is_same<T, half>::value>;
9193
9294template <typename T, typename IdT = size_t >
9395using EnableIfNativeBroadcast = detail::enable_if_t <
@@ -121,6 +123,13 @@ template <typename T, typename IdT = size_t>
121123using EnableIfGenericBroadcast = detail::enable_if_t <
122124 is_generic_broadcast<T>::value && std::is_integral<IdT>::value, T>;
123125
126+ // FIXME: Disable widening once all backends support all data types.
// Type mapping performed by WidenOpenCLTypeTo32_t<T>:
//   cl_char,  cl_short  -> cl_int
//   cl_uchar, cl_ushort -> cl_uint
//   any other T         -> T (unchanged)
// Signedness is preserved by selecting cl_int for the signed inputs and
// cl_uint for the unsigned ones. The is_same objects are value-constructed
// and converted to bool (instead of reading ::value) to form the conditions.
127+ template <typename T>
128+ using WidenOpenCLTypeTo32_t = conditional_t <
129+ std::is_same<T, cl_char>() || std::is_same<T, cl_short>(), cl_int,
130+ conditional_t <std::is_same<T, cl_uchar>() || std::is_same<T, cl_ushort>(),
131+ cl_uint, T>>;
132+
124133// Broadcast with scalar local index
125134// Work-group supports any integral type
126135// Sub-group currently supports only uint32_t
@@ -133,21 +142,17 @@ EnableIfNativeBroadcast<T, IdT> GroupBroadcast(T x, IdT local_id) {
133142 using GroupIdT = typename GroupId<Group>::type;
134143 GroupIdT GroupLocalId = static_cast <GroupIdT>(local_id);
135144 using OCLT = detail::ConvertToOpenCLType_t<T>;
145+ using WidenedT = WidenOpenCLTypeTo32_t<OCLT>;
136146 using OCLIdT = detail::ConvertToOpenCLType_t<GroupIdT>;
137- OCLT OCLX = detail::convertDataToType<T, OCLT>(x);
147+ WidenedT OCLX = detail::convertDataToType<T, OCLT>(x);
138148 OCLIdT OCLId = detail::convertDataToType<GroupIdT, OCLIdT>(GroupLocalId);
139149 return __spirv_GroupBroadcast (group_scope<Group>::value, OCLX, OCLId);
140150}
// Bitcast broadcast with a scalar local index.
// x is reinterpreted as ConvertToNativeBroadcastType_t<T> via bit_cast, that
// value is broadcast, and the result is bit_cast back to T before returning.
// The lines marked '-' converted local_id themselves and called
// __spirv_GroupBroadcast directly; the line marked '+' replaces all of that
// with a call to GroupBroadcast<Group>(BroadcastX, local_id), passing
// local_id through unchanged.
// NOTE(review): presumably the delegated call resolves to the native
// overload because BroadcastT is a native broadcast type -- confirm against
// ConvertToNativeBroadcastType_t, which is not visible here.
141151template <typename Group, typename T, typename IdT>
142152EnableIfBitcastBroadcast<T, IdT> GroupBroadcast (T x, IdT local_id) {
143- using GroupIdT = typename GroupId<Group>::type;
144- GroupIdT GroupLocalId = static_cast <GroupIdT>(local_id);
145153 using BroadcastT = ConvertToNativeBroadcastType_t<T>;
146- using OCLIdT = detail::ConvertToOpenCLType_t<GroupIdT>;
147154 auto BroadcastX = bit_cast<BroadcastT>(x);
148- OCLIdT OCLId = detail::convertDataToType<GroupIdT, OCLIdT>(GroupLocalId);
149- BroadcastT Result =
150- __spirv_GroupBroadcast (group_scope<Group>::value, BroadcastX, OCLId);
155+ BroadcastT Result = GroupBroadcast<Group>(BroadcastX, local_id);
151156 return bit_cast<T>(Result);
152157}
153158template <typename Group, typename T, typename IdT>
@@ -173,31 +178,21 @@ EnableIfNativeBroadcast<T> GroupBroadcast(T x, id<Dimensions> local_id) {
173178 }
174179 using IdT = vec<size_t , Dimensions>;
175180 using OCLT = detail::ConvertToOpenCLType_t<T>;
181+ using WidenedT = WidenOpenCLTypeTo32_t<OCLT>;
176182 using OCLIdT = detail::ConvertToOpenCLType_t<IdT>;
177183 IdT VecId;
178184 for (int i = 0 ; i < Dimensions; ++i) {
179185 VecId[i] = local_id[Dimensions - i - 1 ];
180186 }
181- OCLT OCLX = detail::convertDataToType<T, OCLT>(x);
187+ WidenedT OCLX = detail::convertDataToType<T, OCLT>(x);
182188 OCLIdT OCLId = detail::convertDataToType<IdT, OCLIdT>(VecId);
183189 return __spirv_GroupBroadcast (group_scope<Group>::value, OCLX, OCLId);
184190}
// Bitcast broadcast with a multi-dimensional local id (id<Dimensions>).
// x is reinterpreted as ConvertToNativeBroadcastType_t<T> via bit_cast, that
// value is broadcast, and the result is bit_cast back to T before returning.
// The lines marked '-' handled Dimensions == 1 by forwarding local_id[0],
// built a reversed-order index vector (VecId[i] = local_id[Dimensions-i-1]),
// and called __spirv_GroupBroadcast directly. The line marked '+' replaces
// all of that with GroupBroadcast<Group>(BroadcastX, local_id), so any id
// handling is left to the overload that receives the call.
// NOTE(review): presumably that is the native id<Dimensions> overload, which
// would perform the same 1-D shortcut and index reversal -- confirm, since
// only part of that overload is visible in this view.
185191template <typename Group, typename T, int Dimensions>
186192EnableIfBitcastBroadcast<T> GroupBroadcast (T x, id<Dimensions> local_id) {
187- if (Dimensions == 1 ) {
188- return GroupBroadcast<Group>(x, local_id[0 ]);
189- }
190- using IdT = vec<size_t , Dimensions>;
191193 using BroadcastT = ConvertToNativeBroadcastType_t<T>;
192- using OCLIdT = detail::ConvertToOpenCLType_t<IdT>;
193- IdT VecId;
194- for (int i = 0 ; i < Dimensions; ++i) {
195- VecId[i] = local_id[Dimensions - i - 1 ];
196- }
197194 auto BroadcastX = bit_cast<BroadcastT>(x);
198- OCLIdT OCLId = detail::convertDataToType<IdT, OCLIdT>(VecId);
199- BroadcastT Result =
200- __spirv_GroupBroadcast (group_scope<Group>::value, BroadcastX, OCLId);
195+ BroadcastT Result = GroupBroadcast<Group>(BroadcastX, local_id);
201196 return bit_cast<T>(Result);
202197}
203198template <typename Group, typename T, int Dimensions>
0 commit comments