Skip to content

Commit

Permalink
Made bitonic support arbitrary comparators
Browse files Browse the repository at this point in the history
It is the last template param for bitonicSort, bitonicSortTeam and
bitonicSortTeam2. The default just does "operator<".
  • Loading branch information
brian-kelley committed Sep 4, 2019
1 parent 82ff5f3 commit 7bd18ed
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 31 deletions.
70 changes: 42 additions & 28 deletions src/common/KokkosKernels_Sorting.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -293,13 +293,22 @@ radixSort2(ValueType* values, ValueType* valuesAux, PermType* perm, PermType* pe
}
}

template<typename Value>
struct DefaultComparator
{
KOKKOS_INLINE_FUNCTION bool operator()(const Value lhs, const Value rhs) const
{
return lhs < rhs;
}
};

//Bitonic merge sort (requires only comparison operators and trivially-copyable)
//Pros: In-place, plenty of parallelism for GPUs, and memory references are coalesced
//Con: O(n log^2(n)) serial time is bad on CPUs
//Good diagram of the algorithm at https://en.wikipedia.org/wiki/Bitonic_sorter
template<typename Ordinal, typename ValueType, typename TeamMember>
template<typename Ordinal, typename ValueType, typename TeamMember, typename Comparator = DefaultComparator<ValueType>>
KOKKOS_INLINE_FUNCTION void
bitonicSort(ValueType* values, Ordinal n, const TeamMember mem)
bitonicSortTeam(ValueType* values, Ordinal n, const TeamMember mem)
{
//Algorithm only works on power-of-two input size only.
//If n is not a power-of-two, will implicitly pretend
Expand All @@ -326,14 +335,15 @@ bitonicSort(ValueType* values, Ordinal n, const TeamMember mem)
Ordinal boxStart = boxID << (1 + i - j); //boxID * boxSize
Ordinal boxOffset = t - (boxStart >> 1); //t - boxID * boxSize / 2;
Ordinal elem1 = boxStart + boxOffset;
Comparator comp;
if(j == 0)
{
//first phase (brown box): within a block, compare with the opposite value in the box
Ordinal elem2 = boxStart + boxSize - 1 - boxOffset;
if(elem2 < n)
{
//both elements in bounds, so compare them and swap if out of order
if(values[elem1] > values[elem2])
if(comp(values[elem2], values[elem1]))
{
ValueType temp = values[elem1];
values[elem1] = values[elem2];
Expand All @@ -347,7 +357,7 @@ bitonicSort(ValueType* values, Ordinal n, const TeamMember mem)
Ordinal elem2 = elem1 + boxSize / 2;
if(elem2 < n)
{
if(values[elem1] > values[elem2])
if(comp(values[elem2], values[elem1]))
{
ValueType temp = values[elem1];
values[elem1] = values[elem2];
Expand All @@ -362,9 +372,9 @@ bitonicSort(ValueType* values, Ordinal n, const TeamMember mem)
}

//Sort "values", while applying the same swaps to "perm"
template<typename Ordinal, typename ValueType, typename PermType, typename TeamMember>
template<typename Ordinal, typename ValueType, typename PermType, typename TeamMember, typename Comparator = DefaultComparator<ValueType>>
KOKKOS_INLINE_FUNCTION void
bitonicSort2(ValueType* values, PermType* perm, Ordinal n, const TeamMember mem)
bitonicSortTeam2(ValueType* values, PermType* perm, Ordinal n, const TeamMember mem)
{
//Algorithm only works on power-of-two input size only.
//If n is not a power-of-two, will implicitly pretend
Expand All @@ -391,18 +401,19 @@ bitonicSort2(ValueType* values, PermType* perm, Ordinal n, const TeamMember mem)
Ordinal boxStart = boxID << (1 + i - j); //boxID * boxSize
Ordinal boxOffset = t - (boxStart >> 1); //t - boxID * boxSize / 2;
Ordinal elem1 = boxStart + boxOffset;
Comparator comp;
if(j == 0)
{
//first phase (brown box): within a block, compare with the opposite value in the box
Ordinal elem2 = boxStart + boxSize - 1 - boxOffset;
if(elem2 < n)
{
//both elements in bounds, so compare them and swap if out of order
if(values[elem1] > values[elem2])
if(comp(values[elem2], values[elem1]))
{
ValueType temp = values[elem1];
ValueType temp1 = values[elem1];
values[elem1] = values[elem2];
values[elem2] = temp;
values[elem2] = temp1;
PermType temp2 = perm[elem1];
perm[elem1] = perm[elem2];
perm[elem2] = temp2;
Expand All @@ -415,11 +426,11 @@ bitonicSort2(ValueType* values, PermType* perm, Ordinal n, const TeamMember mem)
Ordinal elem2 = elem1 + boxSize / 2;
if(elem2 < n)
{
if(values[elem1] > values[elem2])
if(comp(values[elem2], values[elem1]))
{
ValueType temp = values[elem1];
ValueType temp1 = values[elem1];
values[elem1] = values[elem2];
values[elem2] = temp;
values[elem2] = temp1;
PermType temp2 = perm[elem1];
perm[elem1] = perm[elem2];
perm[elem2] = temp2;
Expand All @@ -433,19 +444,19 @@ bitonicSort2(ValueType* values, PermType* perm, Ordinal n, const TeamMember mem)
}

//Functor that sorts a view on one team
template<typename View, typename Ordinal, typename TeamMember>
template<typename View, typename Ordinal, typename TeamMember, typename Comparator>
struct BitonicSingleTeamFunctor
{
BitonicSingleTeamFunctor(View& v_) : v(v_) {}
KOKKOS_INLINE_FUNCTION void operator()(const TeamMember t) const
{
bitonicSort(v.data(), v.extent(0), t);
bitonicSortTeam<Ordinal, typename View::value_type, TeamMember, Comparator>(v.data(), v.extent(0), t);
};
View v;
};

//Functor that sorts equally sized chunks on each team
template<typename View, typename Ordinal, typename TeamMember>
template<typename View, typename Ordinal, typename TeamMember, typename Comparator>
struct BitonicChunkFunctor
{
BitonicChunkFunctor(View& v_, Ordinal chunkSize_) : v(v_), chunkSize(chunkSize_) {}
Expand All @@ -456,14 +467,14 @@ struct BitonicChunkFunctor
Ordinal n = chunkSize;
if(chunkStart + n > Ordinal(v.extent(0)))
n = v.extent(0) - chunkStart;
bitonicSort(v.data() + chunkStart, n, t);
bitonicSortTeam<Ordinal, typename View::value_type, TeamMember, Comparator>(v.data() + chunkStart, n, t);
};
View v;
Ordinal chunkSize;
};

//Functor that does just the first phase (brown) of bitonic sort on equally-sized chunks
template<typename View, typename Ordinal, typename TeamMember>
template<typename View, typename Ordinal, typename TeamMember, typename Comparator>
struct BitonicPhase1Functor
{
typedef typename View::value_type Value;
Expand All @@ -480,11 +491,12 @@ struct BitonicPhase1Functor
Kokkos::parallel_for(Kokkos::TeamThreadRange(t, work),
[=](const Ordinal i)
{
Comparator comp;
Ordinal elem1 = boxStart + workStart + i;
Ordinal elem2 = boxStart + workReflect - i;
if(elem2 < Ordinal(v.extent(0)))
{
if(v(elem1) > v(elem2))
if(comp(v(elem2), v(elem1)))
{
Value temp = v(elem1);
v(elem1) = v(elem2);
Expand All @@ -499,7 +511,7 @@ struct BitonicPhase1Functor
};

//Functor that does the second phase (red) of bitonic sort
template<typename View, typename Ordinal, typename TeamMember>
template<typename View, typename Ordinal, typename TeamMember, typename Comparator>
struct BitonicPhase2Functor
{
typedef typename View::value_type Value;
Expand All @@ -516,14 +528,15 @@ struct BitonicPhase2Functor
Ordinal work = boxSize / teamsPerBox / 2;
Ordinal workStart = boxStart + work * (t.league_rank() % teamsPerBox);
Ordinal jump = boxSize / 2;
Comparator comp;
Kokkos::parallel_for(Kokkos::TeamThreadRange(t, work),
[=](const Ordinal i)
{
Ordinal elem1 = workStart + i;
Ordinal elem2 = workStart + jump + i;
if(elem2 < Ordinal(v.extent(0)))
{
if(v(elem1) > v(elem2))
if(comp(v(elem2), v(elem1)))
{
Value temp = v(elem1);
v(elem1) = v(elem2);
Expand Down Expand Up @@ -552,7 +565,7 @@ struct BitonicPhase2Functor
Ordinal elem2 = elem1 + subBoxSize / 2;
if(elem2 < Ordinal(v.extent(0)))
{
if(v(elem1) > v(elem2))
if(comp(v(elem2), v(elem1)))
{
Value temp = v(elem1);
v(elem1) = v(elem2);
Expand All @@ -572,9 +585,10 @@ struct BitonicPhase2Functor
//Generally ~2x slower than Kokkos::sort() for large arrays (> 50 M elements),
//but faster for smaller arrays.
//
//This is also more general: supports 8- and 16-bit integers,
//and any other trivially copyable value type that has device-compatible comparison operators.
template<typename View, typename ExecSpace, typename Ordinal = typename View::size_type>
//This is more general than BinSort: bitonic supports any trivially copyable type
//and an arbitrary device-compatible comparison operator (provided through operator() of Comparator)
//If comparator is void, use operator< (which should only be used for primitives)
template<typename View, typename ExecSpace, typename Ordinal, typename Comparator = DefaultComparator<typename View::value_type>>
void bitonicSort(View v)
{
typedef Kokkos::TeamPolicy<ExecSpace> team_policy;
Expand All @@ -584,7 +598,7 @@ void bitonicSort(View v)
if(n <= Ordinal(1) << 16)
{
Kokkos::parallel_for(team_policy(1, Kokkos::AUTO()),
BitonicSingleTeamFunctor<View, Ordinal, team_member>(v));
BitonicSingleTeamFunctor<View, Ordinal, team_member, Comparator>(v));
}
else
{
Expand All @@ -596,16 +610,16 @@ void bitonicSort(View v)
Ordinal numPerTeam = npot / numTeams;
//First, sort within teams
Kokkos::parallel_for(team_policy(numTeams, Kokkos::AUTO()),
BitonicChunkFunctor<View, Ordinal, team_member>(v, numPerTeam));
BitonicChunkFunctor<View, Ordinal, team_member, Comparator>(v, numPerTeam));
for(int teamsPerBox = 2; teamsPerBox <= npot / numPerTeam; teamsPerBox *= 2)
{
Ordinal boxSize = teamsPerBox * numPerTeam;
Kokkos::parallel_for(team_policy(numTeams, Kokkos::AUTO()),
BitonicPhase1Functor<View, Ordinal, team_member>(v, boxSize, teamsPerBox));
BitonicPhase1Functor<View, Ordinal, team_member, Comparator>(v, boxSize, teamsPerBox));
for(int boxDiv = 1; teamsPerBox >> boxDiv; boxDiv++)
{
Kokkos::parallel_for(team_policy(numTeams, Kokkos::AUTO()),
BitonicPhase2Functor<View, Ordinal, team_member>(v, boxSize >> boxDiv, teamsPerBox >> boxDiv));
BitonicPhase2Functor<View, Ordinal, team_member, Comparator>(v, boxSize >> boxDiv, teamsPerBox >> boxDiv));
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/sparse/KokkosSparse_spadd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ namespace Experimental {
size_type rowStart = Crowptrs(i);
size_type rowEnd = Crowptrs(i + 1);
size_type rowNum = rowEnd - rowStart;
KokkosKernels::Impl::bitonicSort2<size_type, typename CcolindsT::non_const_value_type, typename CcolindsT::non_const_value_type, TeamMember>
KokkosKernels::Impl::bitonicSortTeam2<size_type, typename CcolindsT::non_const_value_type, typename CcolindsT::non_const_value_type, TeamMember>
(Ccolinds.data() + rowStart, ABperm.data() + rowStart, rowNum, t);
}
CrowptrsT Crowptrs;
Expand Down
4 changes: 2 additions & 2 deletions test_common/Test_Common_Sorting.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ struct BitonicSortFunctor
KOKKOS_INLINE_FUNCTION void operator()(const TeamMem t) const
{
int i = t.league_rank();
KokkosKernels::Impl::bitonicSort<int, Value, TeamMem>(&values(offsets(i)), counts(i), t);
KokkosKernels::Impl::bitonicSortTeam<int, Value, TeamMem>(&values(offsets(i)), counts(i), t);
}
ValView values;
OrdView offsets;
Expand Down Expand Up @@ -247,7 +247,7 @@ void testBitonicSort()
//Create a view of randomized data
typedef typename ExecSpace::memory_space mem_space;
typedef Kokkos::View<Scalar*, mem_space> ValView;
size_t n = 100000;
size_t n = 1599898;
ValView data("Bitonic sort testing data", n);
fillRandom(data);
KokkosKernels::Impl::bitonicSort<ValView, ExecSpace, int>(data);
Expand Down

0 comments on commit 7bd18ed

Please sign in to comment.