Skip to content

Commit

Permalink
MueLu CoalesceDrop_kokkos: Move serial sort to DroppingCommon
Browse files Browse the repository at this point in the history
Signed-off-by: Christian Glusa <caglusa@sandia.gov>
  • Loading branch information
cgcgcg committed Dec 17, 2024
1 parent 785e97d commit c87813d
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 32 deletions.
33 changes: 1 addition & 32 deletions packages/muelu/src/Graph/MatrixTransformation/MueLu_CutDrop.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -423,37 +423,6 @@ class ScaledDistanceLaplacianComparison {
}
};

template <class view_type, class comparator_type>
KOKKOS_INLINE_FUNCTION void serialHeapSort(view_type& v, comparator_type comparator) {
auto N = v.extent(0);
size_t start = N / 2;
size_t end = N;
while (end > 1) {
if (start > 0)
start = start - 1;
else {
end = end - 1;
auto temp = v(0);
v(0) = v(end);
v(end) = temp;
}
size_t root = start;
while (2 * root + 1 < end) {
size_t child = 2 * root + 1;
if ((child + 1 < end) and (comparator(v(child), v(child + 1))))
++child;

if (comparator(v(root), v(child))) {
auto temp = v(root);
v(root) = v(child);
v(child) = temp;
root = child;
} else
break;
}
}
}

/*!
@class CutDropFunctor
@brief Order each row by a criterion, compare the ratio of values and drop all entries once the ratio is below the threshold.
Expand Down Expand Up @@ -499,7 +468,7 @@ class CutDropFunctor {
for (size_t i = 0; i < nnz; ++i) {
row_permutation(i) = i;
}
serialHeapSort(row_permutation, comparator);
Misc::serialHeapSort(row_permutation, comparator);

size_t keepStart = 0;
size_t dropStart = nnz;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,37 @@ class SymmetrizeFunctor {
}
};

template <class view_type, class comparator_type>
KOKKOS_INLINE_FUNCTION void serialHeapSort(view_type& v, comparator_type comparator) {
auto N = v.extent(0);
size_t start = N / 2;
size_t end = N;
while (end > 1) {
if (start > 0)
start = start - 1;
else {
end = end - 1;
auto temp = v(0);
v(0) = v(end);
v(end) = temp;
}
size_t root = start;
while (2 * root + 1 < end) {
size_t child = 2 * root + 1;
if ((child + 1 < end) and (comparator(v(child), v(child + 1))))
++child;

if (comparator(v(root), v(child))) {
auto temp = v(root);
v(root) = v(child);
v(child) = temp;
root = child;
} else
break;
}
}
}

} // namespace Misc

} // namespace MueLu
Expand Down

0 comments on commit c87813d

Please sign in to comment.