Skip to content

Commit

Permalink
add multi-dim sort
Browse files Browse the repository at this point in the history
  • Loading branch information
ajpotts committed Sep 19, 2024
1 parent e5acb71 commit 3a366a6
Show file tree
Hide file tree
Showing 3 changed files with 8,491 additions and 451 deletions.
3 changes: 2 additions & 1 deletion arkouda/sorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ def sort(pda: pdarray, algorithm: SortingAlgorithm = SortingAlgorithm.RadixSortL
if pda.size == 0:
return zeros(0, dtype=pda.dtype)
repMsg = generic_msg(
cmd=f"sort<{pda.dtype.name},{pda.ndim}>", args={"alg": algorithm.name, "array": pda, "axis": axis}
cmd=f"sort<{pda.dtype.name},{pda.ndim}>",
args={"alg": algorithm.name, "array": pda, "axis": axis},
)
return create_pdarray(cast(str, repMsg))
49 changes: 47 additions & 2 deletions src/SortMsg.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ module SortMsg
use Logging;
use Message;
private use ArgSortMsg;
use NumPyDType;
use NumPyDType only whichDtype;

private config const logLevel = ServerConfig.logLevel;
private config const logChannel = ServerConfig.logChannel;
Expand Down Expand Up @@ -48,8 +48,53 @@ module SortMsg
}
}


proc sort(array: [?d] ?t, alg: string, axis: int): [d] t throws
where ( (t == real) || (t==int) || (t==uint(64)) ) && ( d.rank > 1 ) {

var algorithm: SortingAlgorithm = ArgSortMsg.getSortingAlgorithm(alg);
const itemsize = dtypeSize(whichDtype(t));
overMemLimit(radixSortLSD_keys_memEst(d.size, itemsize));

const DD = domOffAxis(d, axis);
var sorted = makeDistArray((...d.shape), t);

if algorithm == SortingAlgorithm.TwoArrayRadixSort {
for idx in DD {
// make a copy of the array along the slice corresponding to idx
// TODO: create a twoArrayRadixSort that operates on a slice of the array
// in place instead of requiring the copy in/out
var slice = makeDistArray(d.dim(axis).size, t);
forall i in d.dim(axis) with (var perpIdx = idx) {
perpIdx[axis] = i;
slice[i] = array[perpIdx];
}

ArgSortMsg.dynamicTwoArrayRadixSort(slice, comparator=myDefaultComparator);

forall i in d.dim(axis) with (var perpIdx = idx) {
perpIdx[axis] = i;
sorted[perpIdx] = slice[i];
}
}
} else {
// TODO: make a version of radixSortLSD_keys that does the sort on
// slices of `e.a` directly instead of requiring a copy for each slice
for idx in DD {
const sliceDom = domOnAxis(d, idx, axis),
sliced1D = removeDegenRanks(array[sliceDom], 1),
sliceSorted = radixSortLSD_keys(sliced1D);

forall i in sliceDom do sorted[i] = sliceSorted[i[axis]];
}
}

return sorted;
}


proc sort(array: [?d] ?t, alg: string, axis: int): [d] t throws
where ( (t != real) && (t!=int) && (t!=uint(64)) ) || ( d.rank != 1 ) {
where ( (t != real) && (t!=int) && (t!=uint(64)) ) {
throw new Error("Insightful error message.");
}

Expand Down
Loading

0 comments on commit 3a366a6

Please sign in to comment.