1+ // ==------------ group_sort_impl.hpp ---------------------------------------==//
2+ //
3+ // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+ // See https://llvm.org/LICENSE.txt for license information.
5+ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+ //
7+ // ===----------------------------------------------------------------------===//
8+ // This file includes some functions for group sorting algorithm implementations
9+ //
10+
11+ #pragma once
12+
13+ #if __cplusplus >= 201703L
14+ #include < CL/sycl/detail/helpers.hpp>
15+
16+ #ifdef __SYCL_DEVICE_ONLY__
17+
18+ __SYCL_INLINE_NAMESPACE (cl) {
19+ namespace sycl {
20+ namespace detail {
21+
22+ // ---- merge sort implementation
23+
24+ // following two functions could be useless if std::[lower|upper]_bound worked
25+ // well
26+ template <typename Acc, typename Value, typename Compare>
27+ std::size_t lower_bound (Acc acc, std::size_t first, std::size_t last,
28+ const Value &value, Compare comp) {
29+ std::size_t n = last - first;
30+ std::size_t cur = n;
31+ std::size_t it;
32+ while (n > 0 ) {
33+ it = first;
34+ cur = n / 2 ;
35+ it += cur;
36+ if (comp (acc[it], value)) {
37+ n -= cur + 1 , first = ++it;
38+ } else
39+ n = cur;
40+ }
41+ return first;
42+ }
43+
44+ template <typename Acc, typename Value, typename Compare>
45+ std::size_t upper_bound (Acc acc, const std::size_t first,
46+ const std::size_t last, const Value &value,
47+ Compare comp) {
48+ return detail::lower_bound (acc, first, last, value,
49+ [comp](auto x, auto y) { return !comp (y, x); });
50+ }
51+
52+ // swap for all data types including tuple-like types
53+ template <typename T> void swap_tuples (T &a, T &b) { std::swap (a, b); }
54+
55+ template <template <typename ...> class TupleLike , typename T1, typename T2>
56+ void swap_tuples (TupleLike<T1, T2> &&a, TupleLike<T1, T2> &&b) {
57+ std::swap (std::get<0 >(a), std::get<0 >(b));
58+ std::swap (std::get<1 >(a), std::get<1 >(b));
59+ }
60+
61+ template <typename Iter> struct GetValueType {
62+ using type = typename std::iterator_traits<Iter>::value_type;
63+ };
64+
65+ template <typename ElementType, access::address_space Space>
66+ struct GetValueType <sycl::multi_ptr<ElementType, Space>> {
67+ using type = ElementType;
68+ };
69+
70+ // since we couldn't assign data to raw memory, it's better to use placement
71+ // for first assignment
72+ template <typename Acc, typename T>
73+ void set_value (Acc ptr, const std::size_t idx, const T &val, bool is_first) {
74+ if (is_first) {
75+ ::new (ptr + idx) T (val);
76+ } else {
77+ ptr[idx] = val;
78+ }
79+ }
80+
81+ template <typename InAcc, typename OutAcc, typename Compare>
82+ void merge (const std::size_t offset, InAcc &in_acc1, OutAcc &out_acc1,
83+ const std::size_t start_1, const std::size_t end_1,
84+ const std::size_t end_2, const std::size_t start_out, Compare comp,
85+ const std::size_t chunk, bool is_first) {
86+ const std::size_t start_2 = end_1;
87+ // Borders of the sequences to merge within this call
88+ const std::size_t local_start_1 =
89+ sycl::min (static_cast <std::size_t >(offset + start_1), end_1);
90+ const std::size_t local_end_1 =
91+ sycl::min (static_cast <std::size_t >(local_start_1 + chunk), end_1);
92+ const std::size_t local_start_2 =
93+ sycl::min (static_cast <std::size_t >(offset + start_2), end_2);
94+ const std::size_t local_end_2 =
95+ sycl::min (static_cast <std::size_t >(local_start_2 + chunk), end_2);
96+
97+ const std::size_t local_size_1 = local_end_1 - local_start_1;
98+ const std::size_t local_size_2 = local_end_2 - local_start_2;
99+
100+ // TODO: process cases where all elements of 1st sequence > 2nd, 2nd > 1st
101+ // to improve performance
102+
103+ // Process 1st sequence
104+ if (local_start_1 < local_end_1) {
105+ // Reduce the range for searching within the 2nd sequence and handle bound
106+ // items find left border in 2nd sequence
107+ const auto local_l_item_1 = in_acc1[local_start_1];
108+ std::size_t l_search_bound_2 =
109+ detail::lower_bound (in_acc1, start_2, end_2, local_l_item_1, comp);
110+ const std::size_t l_shift_1 = local_start_1 - start_1;
111+ const std::size_t l_shift_2 = l_search_bound_2 - start_2;
112+
113+ set_value (out_acc1, start_out + l_shift_1 + l_shift_2, local_l_item_1,
114+ is_first);
115+
116+ std::size_t r_search_bound_2{};
117+ // find right border in 2nd sequence
118+ if (local_size_1 > 1 ) {
119+ const auto local_r_item_1 = in_acc1[local_end_1 - 1 ];
120+ r_search_bound_2 = detail::lower_bound (in_acc1, l_search_bound_2, end_2,
121+ local_r_item_1, comp);
122+ const auto r_shift_1 = local_end_1 - 1 - start_1;
123+ const auto r_shift_2 = r_search_bound_2 - start_2;
124+
125+ set_value (out_acc1, start_out + r_shift_1 + r_shift_2, local_r_item_1,
126+ is_first);
127+ }
128+
129+ // Handle intermediate items
130+ for (std::size_t idx = local_start_1 + 1 ; idx < local_end_1 - 1 ; ++idx) {
131+ const auto intermediate_item_1 = in_acc1[idx];
132+ // we shouldn't seek in whole 2nd sequence. Just for the part where the
133+ // 1st sequence should be
134+ l_search_bound_2 =
135+ detail::lower_bound (in_acc1, l_search_bound_2, r_search_bound_2,
136+ intermediate_item_1, comp);
137+ const std::size_t shift_1 = idx - start_1;
138+ const std::size_t shift_2 = l_search_bound_2 - start_2;
139+
140+ set_value (out_acc1, start_out + shift_1 + shift_2, intermediate_item_1,
141+ is_first);
142+ }
143+ }
144+ // Process 2nd sequence
145+ if (local_start_2 < local_end_2) {
146+ // Reduce the range for searching within the 1st sequence and handle bound
147+ // items find left border in 1st sequence
148+ const auto local_l_item_2 = in_acc1[local_start_2];
149+ std::size_t l_search_bound_1 =
150+ detail::upper_bound (in_acc1, start_1, end_1, local_l_item_2, comp);
151+ const std::size_t l_shift_1 = l_search_bound_1 - start_1;
152+ const std::size_t l_shift_2 = local_start_2 - start_2;
153+
154+ set_value (out_acc1, start_out + l_shift_1 + l_shift_2, local_l_item_2,
155+ is_first);
156+
157+ std::size_t r_search_bound_1{};
158+ // find right border in 1st sequence
159+ if (local_size_2 > 1 ) {
160+ const auto local_r_item_2 = in_acc1[local_end_2 - 1 ];
161+ r_search_bound_1 = detail::upper_bound (in_acc1, l_search_bound_1, end_1,
162+ local_r_item_2, comp);
163+ const std::size_t r_shift_1 = r_search_bound_1 - start_1;
164+ const std::size_t r_shift_2 = local_end_2 - 1 - start_2;
165+
166+ set_value (out_acc1, start_out + r_shift_1 + r_shift_2, local_r_item_2,
167+ is_first);
168+ }
169+
170+ // Handle intermediate items
171+ for (auto idx = local_start_2 + 1 ; idx < local_end_2 - 1 ; ++idx) {
172+ const auto intermediate_item_2 = in_acc1[idx];
173+ // we shouldn't seek in whole 1st sequence. Just for the part where the
174+ // 2nd sequence should be
175+ l_search_bound_1 =
176+ detail::upper_bound (in_acc1, l_search_bound_1, r_search_bound_1,
177+ intermediate_item_2, comp);
178+ const std::size_t shift_1 = l_search_bound_1 - start_1;
179+ const std::size_t shift_2 = idx - start_2;
180+
181+ set_value (out_acc1, start_out + shift_1 + shift_2, intermediate_item_2,
182+ is_first);
183+ }
184+ }
185+ }
186+
187+ template <typename Iter, typename Compare>
188+ void bubble_sort (Iter first, const std::size_t begin, const std::size_t end,
189+ Compare comp) {
190+ if (begin < end) {
191+ for (std::size_t i = begin; i < end; ++i) {
192+ // Handle intermediate items
193+ for (std::size_t idx = i + 1 ; idx < end; ++idx) {
194+ if (comp (first[idx], first[i])) {
195+ detail::swap_tuples (first[i], first[idx]);
196+ }
197+ }
198+ }
199+ }
200+ }
201+
202+ template <typename Group, typename Iter, typename Compare>
203+ void merge_sort (Group group, Iter first, const std::size_t n, Compare comp,
204+ std::byte *scratch) {
205+ using T = typename GetValueType<Iter>::type;
206+ auto id = sycl::detail::Builder::getNDItem<Group::dimensions>();
207+ const std::size_t idx = id.get_local_linear_id ();
208+ const std::size_t local = group.get_local_range ().size ();
209+ const std::size_t chunk = (n - 1 ) / local + 1 ;
210+
211+ // we need to sort within work item first
212+ bubble_sort (first, idx * chunk, sycl::min ((idx + 1 ) * chunk, n), comp);
213+ id.barrier ();
214+
215+ T *temp = reinterpret_cast <T *>(scratch);
216+ bool data_in_temp = false ;
217+ bool is_first = true ;
218+ std::size_t sorted_size = 1 ;
219+ while (sorted_size * chunk < n) {
220+ const std::size_t start_1 =
221+ sycl::min (2 * sorted_size * chunk * (idx / sorted_size), n);
222+ const std::size_t end_1 = sycl::min (start_1 + sorted_size * chunk, n);
223+ const std::size_t end_2 = sycl::min (end_1 + sorted_size * chunk, n);
224+ const std::size_t offset = chunk * (idx % sorted_size);
225+
226+ if (!data_in_temp) {
227+ merge (offset, first, temp, start_1, end_1, end_2, start_1, comp, chunk,
228+ is_first);
229+ } else {
230+ merge (offset, temp, first, start_1, end_1, end_2, start_1, comp, chunk,
231+ /* is_first*/ false );
232+ }
233+ id.barrier ();
234+
235+ data_in_temp = !data_in_temp;
236+ sorted_size *= 2 ;
237+ if (is_first)
238+ is_first = false ;
239+ }
240+
241+ // copy back if data is in a temporary storage
242+ if (data_in_temp) {
243+ for (std::size_t i = 0 ; i < chunk; ++i) {
244+ if (idx * chunk + i < n) {
245+ first[idx * chunk + i] = temp[idx * chunk + i];
246+ }
247+ }
248+ id.barrier ();
249+ }
250+ }
251+
252+ } // namespace detail
253+ } // namespace sycl
254+ } // __SYCL_INLINE_NAMESPACE(cl)
255+ #endif
256+ #endif // __cplusplus >=201703L
0 commit comments