@@ -109,6 +109,7 @@ template <typename Group> bool GroupAll(Group g, bool pred) {
109109template <typename ParentGroup>
110110bool GroupAll (ext::oneapi::experimental::ballot_group<ParentGroup> g,
111111 bool pred) {
112+ #if defined (__SPIR__)
112113 // ballot_group partitions its parent into two groups (0 and 1)
113114 // We have to force each group down different control flow
114115 // Work-items in the "false" group (0) may still be active
@@ -117,6 +118,10 @@ bool GroupAll(ext::oneapi::experimental::ballot_group<ParentGroup> g,
117118 } else {
118119 return __spirv_GroupNonUniformAll (group_scope<ParentGroup>::value, pred);
119120 }
121+ #elif defined (__NVPTX__)
122+ sycl::vec<unsigned , 4 > MemberMask = detail::ExtractMask (detail::GetMask (g));
123+ return __nvvm_vote_all_sync (MemberMask[0 ], pred);
124+ #endif
120125}
121126
122127template <typename Group> bool GroupAny (Group g, bool pred) {
@@ -125,6 +130,7 @@ template <typename Group> bool GroupAny(Group g, bool pred) {
125130template <typename ParentGroup>
126131bool GroupAny (ext::oneapi::experimental::ballot_group<ParentGroup> g,
127132 bool pred) {
133+ #if defined (__SPIR__)
128134 // ballot_group partitions its parent into two groups (0 and 1)
129135 // We have to force each group down different control flow
130136 // Work-items in the "false" group (0) may still be active
@@ -133,6 +139,10 @@ bool GroupAny(ext::oneapi::experimental::ballot_group<ParentGroup> g,
133139 } else {
134140 return __spirv_GroupNonUniformAny (group_scope<ParentGroup>::value, pred);
135141 }
142+ #elif defined (__NVPTX__)
143+ sycl::vec<unsigned , 4 > MemberMask = detail::ExtractMask (detail::GetMask (g));
144+ return __nvvm_vote_any_sync (MemberMask[0 ], pred);
145+ #endif
136146}
137147
138148// Native broadcasts map directly to a SPIR-V GroupBroadcast intrinsic
@@ -219,13 +229,18 @@ GroupBroadcast(sycl::ext::oneapi::experimental::ballot_group<ParentGroup> g,
219229 // ballot_group partitions its parent into two groups (0 and 1)
220230 // We have to force each group down different control flow
221231 // Work-items in the "false" group (0) may still be active
232+ #if defined(__SPIR__)
222233 if (g.get_group_id () == 1 ) {
223234 return __spirv_GroupNonUniformBroadcast (group_scope<ParentGroup>::value,
224235 OCLX, OCLId);
225236 } else {
226237 return __spirv_GroupNonUniformBroadcast (group_scope<ParentGroup>::value,
227238 OCLX, OCLId);
228239 }
240+ #elif defined(__NVPTX__)
241+ sycl::vec<unsigned , 4 > MemberMask = detail::ExtractMask (detail::GetMask (g));
242+ return __nvvm_shfl_sync_idx_i32 (MemberMask[0 ], x, LocalId, 31 ); // 31 not 32 as docs suggest.
243+ #endif
229244}
230245
231246template <typename Group, typename T, typename IdT>
@@ -886,7 +901,7 @@ ControlBarrier(Group, memory_scope FenceScope, memory_order Order) {
886901template <typename Group>
887902typename std::enable_if_t <
888903 ext::oneapi::experimental::is_user_constructed_group_v<Group>>
889- ControlBarrier (Group, memory_scope FenceScope, memory_order Order) {
904+ ControlBarrier (Group g , memory_scope FenceScope, memory_order Order) {
890905#if defined(__SPIR__)
891906 // SPIR-V does not define an instruction to synchronize partial groups.
892907 // However, most (possibly all?) of the current SPIR-V targets execute
@@ -899,6 +914,7 @@ ControlBarrier(Group, memory_scope FenceScope, memory_order Order) {
899914 __spv::MemorySemanticsMask::CrossWorkgroupMemory);
900915#elif defined(__NVPTX__)
901916 // TODO: Call syncwarp with appropriate mask extracted from the group
917+ __nvvm_bar_warp_sync (detail::ExtractMask (detail::GetMask (g))[0 ]);
902918#endif
903919}
904920
0 commit comments