-
Notifications
You must be signed in to change notification settings - Fork 3
/
SortingKernel.cpp
90 lines (80 loc) · 3.18 KB
/
SortingKernel.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/NumericUtils.h>
#include <ATen/native/Sorting.h>
#include <ATen/native/SortingUtils.h>
namespace at { namespace native {

namespace {

// Reorders [first, last) so that the k elements that `comp` orders first end
// up in [first, first + k).
//
// - use_partial_sort == true : std::partial_sort is used, so the k-element
//   prefix is always left fully ordered (regardless of `sorted`).
// - use_partial_sort == false: std::nth_element partitions around position
//   k - 1 (requires k >= 1); the prefix is then ordered only if `sorted`.
//
// `comp` must be a strict weak ordering over the iterator's value type.
template <typename RandomIt, typename Compare>
void topk_select_front(
    RandomIt first,
    RandomIt last,
    int64_t k,
    bool use_partial_sort,
    bool sorted,
    Compare comp) {
  if (use_partial_sort) {
    std::partial_sort(first, first + k, last, comp);
    return;
  }
  std::nth_element(first, first + k - 1, last, comp);
  if (sorted) {
    // nth_element has already placed the element at index k - 1 in its final
    // position, so only the k - 1 elements in front of it need ordering.
    std::sort(first, first + k - 1, comp);
  }
}

// CPU top-k kernel: for every 1-D slice handed out by dim_apply along `dim`,
// writes the k selected elements of `self` into `values` and their positions
// within the slice into `indices`.
//
//   values  - output, receives the selected elements (same scalar type as self)
//   indices - output, receives int64 positions of those elements in the slice
//   self    - input tensor
//   k       - number of elements to select per slice
//   dim     - dimension forwarded to dim_apply
//   largest - true: pick the largest elements; false: pick the smallest
//   sorted  - request the k results in order; only consulted on the
//             nth_element path, the partial_sort path is always ordered
//
// NOTE(review): no range check on k is done here (0 <= k <= slice length is
// presumably validated by the caller) - confirm.
static void topk_kernel(
    Tensor& values,
    Tensor& indices,
    const Tensor& self,
    int64_t k,
    int64_t dim,
    bool largest,
    bool sorted) {
  AT_DISPATCH_ALL_TYPES(self.scalar_type(), "topk_cpu", [&] {
    dim_apply(
        {self, values, indices},
        dim,
        [&](int64_t /*i (unused)*/, TensorList tl) {
          auto tmp_values = tl[0].accessor<scalar_t, 1>();
          auto mode_values = tl[1].accessor<scalar_t, 1>();
          auto mode_indices = tl[2].accessor<int64_t, 1>();

          // Nothing is written when k == 0, so skip copying and reordering
          // the slice altogether.
          if (k == 0) {
            return;
          }

          auto n = tmp_values.size(0);
          // Heuristic: when k is small relative to the slice length
          // (k * 64 <= n) use partial_sort, otherwise nth_element (+ sort).
          auto use_partial_sort = k * 64 <= n;

          // Pair every value with its position so that the original index
          // survives the reordering.
          // NOTE(review): the scratch buffer is deliberately local to each
          // callback invocation; sharing it across invocations would only be
          // safe if dim_apply runs the callback serially - confirm before
          // hoisting.
          using elem_t = std::pair<scalar_t, int64_t>;
          std::vector<elem_t> queue(n);
          for (int64_t j = 0; j < n; j++) {
            queue[j].first = tmp_values[j];
            queue[j].second = j;
          }

          // we want NaN to be sorted as top for numpy compatibility
          //
          // Descending order with NaN treated as the greatest value: a NaN
          // precedes every non-NaN.
          const auto descending_nan_first =
              [](const elem_t& x, const elem_t& y) -> bool {
                return ((_isnan<scalar_t>(x.first) && !_isnan<scalar_t>(y.first)) || (x.first > y.first));
              };
          // Ascending order with NaN treated as the greatest value: every
          // non-NaN precedes a NaN.
          const auto ascending_nan_last =
              [](const elem_t& x, const elem_t& y) -> bool {
                return ((!_isnan<scalar_t>(x.first) && _isnan<scalar_t>(y.first)) || (x.first < y.first));
              };

          if (largest) {
            topk_select_front(
                queue.begin(), queue.end(), k, use_partial_sort, sorted,
                descending_nan_first);
          } else {
            topk_select_front(
                queue.begin(), queue.end(), k, use_partial_sort, sorted,
                ascending_nan_last);
          }

          // Emit the selected prefix: value and original position.
          for (int64_t j = 0; j < k; j++) {
            mode_values[j] = queue[j].first;
            mode_indices[j] = queue[j].second;
          }
        });
  });
}

} // anonymous namespace

REGISTER_DISPATCH(topk_stub, &topk_kernel);

}} //at::native