Skip to content

Commit

Permalink
Closes #3722: Remove @arkouda.registerND from SortMsg.chpl
Browse files Browse the repository at this point in the history
  • Loading branch information
ajpotts committed Sep 20, 2024
1 parent 8dae0c5 commit 07fe429
Show file tree
Hide file tree
Showing 5 changed files with 274 additions and 147 deletions.
2 changes: 1 addition & 1 deletion arkouda/array_api/searching_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def searchsorted(
_x1 = x1

resp = generic_msg(
cmd=f"searchSorted{x2.ndim}D",
cmd=f"searchSorted<float64,1,float64,{x2.ndim}>",
args={
"x1": _x1._array,
"x2": x2._array,
Expand Down
3 changes: 2 additions & 1 deletion arkouda/sorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ def sort(pda: pdarray, algorithm: SortingAlgorithm = SortingAlgorithm.RadixSortL
if pda.size == 0:
return zeros(0, dtype=pda.dtype)
repMsg = generic_msg(
cmd=f"sort{pda.ndim}D", args={"alg": algorithm.name, "array": pda, "axis": axis}
cmd=f"sort<{pda.dtype.name},{pda.ndim}>",
args={"alg": algorithm.name, "array": pda, "axis": axis},
)
return create_pdarray(cast(str, repMsg))
6 changes: 3 additions & 3 deletions src/ArgSortMsg.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ module ArgSortMsg
};
config const defaultSortAlgorithm: SortingAlgorithm = SortingAlgorithm.RadixSortLSD;

proc getSortingAlgoritm(algoName:string) throws{
proc getSortingAlgorithm(algoName:string) throws{
var algorithm = defaultSortAlgorithm;
if algoName != "" {
try {
Expand Down Expand Up @@ -430,7 +430,7 @@ module ArgSortMsg
{
const name = msgArgs["name"],
algoName = msgArgs["algoName"].toScalar(string),
algorithm = getSortingAlgoritm(algoName),
algorithm = getSortingAlgorithm(algoName),
axis = msgArgs["axis"].toScalar(int),
symEntry = st[msgArgs["name"]]: SymEntry(array_dtype, array_nd),
vals = if (array_dtype == bool) then (symEntry.a:int) else (symEntry.a: array_dtype);
Expand All @@ -449,7 +449,7 @@ module ArgSortMsg
const name = msgArgs["name"].toScalar(string),
strings = getSegString(name, st),
algoName = msgArgs["algoName"].toScalar(string),
algorithm = getSortingAlgoritm(algoName);
algorithm = getSortingAlgorithm(algoName);

// check and throw if over memory limit
overMemLimit((8 * strings.size * 8)
Expand Down
217 changes: 75 additions & 142 deletions src/SortMsg.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ module SortMsg
use Logging;
use Message;
private use ArgSortMsg;
use NumPyDType only whichDtype;

private config const logLevel = ServerConfig.logLevel;
private config const logChannel = ServerConfig.logChannel;
Expand All @@ -29,157 +30,84 @@ module SortMsg
}

/* sort takes pdarray and returns a sorted copy of the array */
@arkouda.registerND
proc sortMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param nd: int): MsgTuple throws {
param pn = Reflection.getRoutineName();
const algoName = msgArgs.getValueOf("alg"),
name = msgArgs.getValueOf("array"),
axis = msgArgs.get("axis").getIntValue(),
rname = st.nextName();

var algorithm: SortingAlgorithm = defaultSortAlgorithm;
if algoName != "" {
try {
algorithm = algoName: SortingAlgorithm;
} catch {
throw getErrorWithContext(
msg="Unrecognized sorting algorithm: %s".format(algoName),
lineNumber=getLineNumber(),
pn,
moduleName=getModuleName(),
errorClass="NotImplementedError"
);
}
/*
  Rank-1 overload of `sort` for real, int and uint(64) element types.

  array : the 1-D array whose values are sorted into a new array
  alg   : sorting algorithm name, resolved by ArgSortMsg.getSortingAlgorithm
  axis  : accepted for signature parity with the other overloads; it is not
          read in this rank-1 overload

  Returns a new array over the same domain `d` holding the sorted values.
  Throws whatever getSortingAlgorithm or overMemLimit throw.
*/
@arkouda.registerCommand
proc sort(array: [?d] ?t, alg: string, axis: int): [d] t throws
    where ((t == real) || (t == int) || (t == uint(64))) && (d.rank == 1) {

    // Resolve the algorithm first so a bad name is reported before the
    // memory check runs (same ordering as before).
    const chosenAlgo: SortingAlgorithm = ArgSortMsg.getSortingAlgorithm(alg);

    // The memory estimate is the radixSortLSD_keys one for both algorithms.
    const bytesPerElem = dtypeSize(whichDtype(t));
    overMemLimit(radixSortLSD_keys_memEst(d.size, bytesPerElem));

    // Default path: every algorithm other than TwoArrayRadixSort.
    if chosenAlgo != SortingAlgorithm.TwoArrayRadixSort {
        var lsdSorted = radixSortLSD_keys(array);
        return lsdSorted;
    }

    // TwoArrayRadixSort is applied to `inPlace` (presumably a distributed
    // copy of `array` -- confirm against makeDistArray), then returned.
    var inPlace = makeDistArray(array);
    ArgSortMsg.dynamicTwoArrayRadixSort(inPlace, comparator=myDefaultComparator);
    return inPlace;
}

var gEnt: borrowed GenSymEntry = getGenericTypedArrayEntry(name, st);

sortLogger.debug(
getModuleName(),pn,getLineNumber(),
"cmd: %s, name: %s, sortedName: %s, dtype: %?, nd: %i, axis: %i".format(
cmd, name, rname, gEnt.dtype, nd, axis
)
);

proc doSort(type t): MsgTuple throws
where nd == 1
{
overMemLimit(radixSortLSD_keys_memEst(gEnt.size, gEnt.itemsize));

const e = toSymEntry(gEnt, t);

if algorithm == SortingAlgorithm.TwoArrayRadixSort {
var sorted = makeDistArray(e.a);
ArgSortMsg.dynamicTwoArrayRadixSort(sorted, comparator=myDefaultComparator);
st.addEntry(rname, createSymEntry(sorted));
} else {
var sorted = radixSortLSD_keys(e.a);
st.addEntry(rname, createSymEntry(sorted));
}

const repMsg = "created " + st.attrib(rname);
sortLogger.debug(getModuleName(),pn,getLineNumber(),repMsg);
return new MsgTuple(repMsg, MsgType.NORMAL);
}
proc sort(array: [?d] ?t, alg: string, axis: int): [d] t throws
where ((t == real) || (t==int) || (t==uint(64))) && (d.rank > 1) {

proc doSort(type t): MsgTuple throws
where nd > 1
{
const e = toSymEntry(gEnt, t, nd),
DD = domOffAxis(e.a.domain, axis);
var sorted = makeDistArray((...e.a.domain.shape), t);

if algorithm == SortingAlgorithm.TwoArrayRadixSort {
for idx in DD {
// make a copy of the array along the slice corresponding to idx
// TODO: create a twoArrayRadixSort that operates on a slice of the array
// in place instead of requiring the copy in/out
var slice = makeDistArray(e.a.domain.dim(axis).size, t);
forall i in e.a.domain.dim(axis) with (var perpIdx = idx) {
perpIdx[axis] = i;
slice[i] = e.a[perpIdx];
}

ArgSortMsg.dynamicTwoArrayRadixSort(slice, comparator=myDefaultComparator);

forall i in e.a.domain.dim(axis) with (var perpIdx = idx) {
perpIdx[axis] = i;
sorted[perpIdx] = slice[i];
}
}
} else {
// TODO: make a version of radixSortLSD_keys that does the sort on
// slices of `e.a` directly instead of requiring a copy for each slice
for idx in DD {
const sliceDom = domOnAxis(e.a.domain, idx, axis),
sliced1D = removeDegenRanks(e.a[sliceDom], 1),
sliceSorted = radixSortLSD_keys(sliced1D);

forall i in sliceDom do sorted[i] = sliceSorted[i[axis]];
var algorithm: SortingAlgorithm = ArgSortMsg.getSortingAlgorithm(alg);
const itemsize = dtypeSize(whichDtype(t));
overMemLimit(radixSortLSD_keys_memEst(d.size, itemsize));

const DD = domOffAxis(d, axis);
var sorted = makeDistArray((...d.shape), t);

if algorithm == SortingAlgorithm.TwoArrayRadixSort {
for idx in DD {
// make a copy of the array along the slice corresponding to idx
// TODO: create a twoArrayRadixSort that operates on a slice of the array
// in place instead of requiring the copy in/out
var slice = makeDistArray(d.dim(axis).size, t);
forall i in d.dim(axis) with (var perpIdx = idx) {
perpIdx[axis] = i;
slice[i] = array[perpIdx];
}
}

st.addEntry(rname, createSymEntry(sorted));
const repMsg = "created " + st.attrib(rname);
sortLogger.debug(getModuleName(),pn,getLineNumber(),repMsg);
return new MsgTuple(repMsg, MsgType.NORMAL);
}
ArgSortMsg.dynamicTwoArrayRadixSort(slice, comparator=myDefaultComparator);

select gEnt.dtype {
when DType.Int64 do return doSort(int);
when DType.UInt64 do return doSort(uint);
when DType.Float64 do return doSort(real);
otherwise {
var errorMsg = notImplementedError(pn,gEnt.dtype);
sortLogger.error(getModuleName(),pn,getLineNumber(), errorMsg);
return new MsgTuple(errorMsg, MsgType.ERROR);
forall i in d.dim(axis) with (var perpIdx = idx) {
perpIdx[axis] = i;
sorted[perpIdx] = slice[i];
}
}
} else {
// TODO: make a version of radixSortLSD_keys that does the sort on
// slices of `array` directly instead of requiring a copy for each slice
for idx in DD {
const sliceDom = domOnAxis(d, idx, axis),
sliced1D = removeDegenRanks(array[sliceDom], 1),
sliceSorted = radixSortLSD_keys(sliced1D);

forall i in sliceDom do sorted[i] = sliceSorted[i[axis]];
}
}

return sorted;
}

// https://data-apis.org/array-api/latest/API_specification/generated/array_api.searchsorted.html#array_api.searchsorted
@arkouda.registerND
proc searchSortedMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param nd: int): MsgTuple throws {
param pn = Reflection.getRoutineName();
const x1 = msgArgs.getValueOf("x1"),
x2 = msgArgs.getValueOf("x2"),
side = msgArgs.getValueOf("side"),
rname = st.nextName();

var gEntX1: borrowed GenSymEntry = getGenericTypedArrayEntry(x1, st),
gEntX2: borrowed GenSymEntry = getGenericTypedArrayEntry(x2, st);
/*
  Fallback overload of `sort` selected when the element type is none of
  real, int or uint(64). It never returns; it always throws an Error whose
  message names the unsupported element type.
*/
proc sort(array: [?d] ?t, alg: string, axis: int): [d] t throws
    where ((t != real) && (t!=int) && (t!=uint(64))) {
    // Build the message first so the throw statement stays short.
    const unsupportedMsg = "sort does not support type %s".format(type2str(t));
    throw new Error(unsupportedMsg);
}

if side != "left" && side != "right" {
throw getErrorWithContext(
msg="Unrecognized side: %s".format(side),
lineNumber=getLineNumber(),
pn,
moduleName=getModuleName(),
errorClass="NotImplementedError"
);
}
// https://data-apis.org/array-api/latest/API_specification/generated/array_api.searchsorted.html#array_api.searchsorted
@arkouda.registerCommand
proc searchSorted(x1: [?d1] ?t, x2: [?d2] ?t2, side: string): [d2] int throws
where (t == real) && (t2 == real) && (d1.rank == 1) {

// TODO: add support for Float32
if gEntX1.dtype != DType.Float64 || gEntX2.dtype != DType.Float64 {
throw getErrorWithContext(
msg="searchsorted only supports Float64 arrays",
lineNumber=getLineNumber(),
pn,
moduleName=getModuleName(),
errorClass="NotImplementedError"
);
if side != "left" && side != "right" {
throw new Error("searchSorted side must be a string with value 'left' or 'right'.");
}

sortLogger.debug(
getModuleName(),pn,getLineNumber(),
"cmd: %s, x1: %s, x2: %s, side: %s, rname: %s, dtype: %?, nd: %i".format(
cmd, x1, x2, side, rname, gEntX1.dtype, nd
)
);

const e1 = toSymEntry(gEntX1, real, 1),
e2 = toSymEntry(gEntX2, real, nd);
var ret = makeDistArray((...e2.a.domain.shape), int);
var ret = makeDistArray((...x2.shape), int);

proc doSearch(const ref a1: [] real, const ref a2: [?d] real, cmp) {
forall idx in ret.domain {
Expand All @@ -189,16 +117,22 @@ module SortMsg
}

select side {
when "left" do doSearch(e1.a, e2.a, new leftCmp());
when "right" do doSearch(e1.a, e2.a, new rightCmp());
when "left" do doSearch(x1, x2, new leftCmp());
when "right" do doSearch(x1, x2, new rightCmp());
otherwise do halt("unreachable");
}

st.addEntry(rname, createSymEntry(ret));
const repMsg = "created " + st.attrib(rname);
sortLogger.debug(getModuleName(),pn,getLineNumber(),repMsg);
return ret;
}

/*
  Fallback overload of `searchSorted` selected when `x1` is not rank 1,
  whatever the element types are. It never returns; it always throws an
  Error stating that `x1` must be one-dimensional.
*/
proc searchSorted(x1: [?d1] ?t, x2: [?d2] ?t2, side: string): [d2] int throws
    where (d1.rank != 1) {
    // The message previously read "only arrays x1", with the verb missing.
    throw new Error("searchSorted only supports arrays x1 of dimension 1.");
}

return new MsgTuple(repMsg, MsgType.NORMAL);
/*
  Fallback overload of `searchSorted` selected when `x1` is rank 1 but at
  least one of the two element types is not real (float64). It never
  returns; it always throws an Error.
*/
proc searchSorted(x1: [?d1] ?t, x2: [?d2] ?t2, side: string): [d2] int throws
    where ((t != real) || (t2 != real)) && (d1.rank == 1) {
    // Bind the message to a name so the throw reads as a single step.
    const unsupportedMsg = "searchSorted only supports float64 type.";
    throw new Error(unsupportedMsg);
}

record leftCmp: relativeComparator {
Expand All @@ -214,5 +148,4 @@ module SortMsg
else return 1;
}
}

}// end module SortMsg
Loading

0 comments on commit 07fe429

Please sign in to comment.