Skip to content

Commit

Permalink
chore: wire up all nulls_first logic
Browse files Browse the repository at this point in the history
  • Loading branch information
universalmind303 committed Nov 15, 2024
1 parent b213fa5 commit a9564b4
Show file tree
Hide file tree
Showing 30 changed files with 361 additions and 125 deletions.
18 changes: 9 additions & 9 deletions daft/daft/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1308,7 +1308,7 @@ def dt_truncate(expr: PyExpr, interval: str, relative_to: PyExpr) -> PyExpr: ...
# expr.list namespace
# ---
def explode(expr: PyExpr) -> PyExpr: ...
def list_sort(expr: PyExpr, desc: PyExpr) -> PyExpr: ...
def list_sort(expr: PyExpr, desc: PyExpr, nulls_first: PyExpr) -> PyExpr: ...
def list_value_counts(expr: PyExpr) -> PyExpr: ...
def list_join(expr: PyExpr, delimiter: PyExpr) -> PyExpr: ...
def list_count(expr: PyExpr, mode: CountMode) -> PyExpr: ...
Expand Down Expand Up @@ -1354,8 +1354,8 @@ class PySeries:
def take(self, idx: PySeries) -> PySeries: ...
def slice(self, start: int, end: int) -> PySeries: ...
def filter(self, mask: PySeries) -> PySeries: ...
def sort(self, descending: bool) -> PySeries: ...
def argsort(self, descending: bool) -> PySeries: ...
def sort(self, descending: bool, nulls_first: bool) -> PySeries: ...
def argsort(self, descending: bool, nulls_first: bool) -> PySeries: ...
def hash(self, seed: PySeries | None = None) -> PySeries: ...
def minhash(
self,
Expand Down Expand Up @@ -1456,7 +1456,7 @@ class PySeries:
def list_count(self, mode: CountMode) -> PySeries: ...
def list_get(self, idx: PySeries, default: PySeries) -> PySeries: ...
def list_slice(self, start: PySeries, end: PySeries | None = None) -> PySeries: ...
def list_sort(self, desc: PySeries) -> PySeries: ...
def list_sort(self, desc: PySeries, nulls_first: PySeries) -> PySeries: ...
def map_get(self, key: PySeries) -> PySeries: ...
def if_else(self, other: PySeries, predicate: PySeries) -> PySeries: ...
def is_null(self) -> PySeries: ...
Expand All @@ -1474,8 +1474,8 @@ class PyTable:
def eval_expression_list(self, exprs: list[PyExpr]) -> PyTable: ...
def take(self, idx: PySeries) -> PyTable: ...
def filter(self, exprs: list[PyExpr]) -> PyTable: ...
def sort(self, sort_keys: list[PyExpr], descending: list[bool]) -> PyTable: ...
def argsort(self, sort_keys: list[PyExpr], descending: list[bool]) -> PySeries: ...
def sort(self, sort_keys: list[PyExpr], descending: list[bool], nulls_first: list[bool]) -> PyTable: ...
def argsort(self, sort_keys: list[PyExpr], descending: list[bool], nulls_first: list[bool]) -> PySeries: ...
def agg(self, to_agg: list[PyExpr], group_by: list[PyExpr]) -> PyTable: ...
def pivot(
self,
Expand Down Expand Up @@ -1553,8 +1553,8 @@ class PyMicroPartition:
def eval_expression_list(self, exprs: list[PyExpr]) -> PyMicroPartition: ...
def take(self, idx: PySeries) -> PyMicroPartition: ...
def filter(self, exprs: list[PyExpr]) -> PyMicroPartition: ...
def sort(self, sort_keys: list[PyExpr], descending: list[bool]) -> PyMicroPartition: ...
def argsort(self, sort_keys: list[PyExpr], descending: list[bool]) -> PySeries: ...
def sort(self, sort_keys: list[PyExpr], descending: list[bool], nulls_first: list[bool]) -> PyMicroPartition: ...
def argsort(self, sort_keys: list[PyExpr], descending: list[bool], nulls_first: list[bool]) -> PySeries: ...
def agg(self, to_agg: list[PyExpr], group_by: list[PyExpr]) -> PyMicroPartition: ...
def hash_join(
self,
Expand Down Expand Up @@ -1721,7 +1721,7 @@ class LogicalPlanBuilder:
variable_name: str,
value_name: str,
) -> LogicalPlanBuilder: ...
def sort(self, sort_by: list[PyExpr], descending: list[bool]) -> LogicalPlanBuilder: ...
def sort(self, sort_by: list[PyExpr], descending: list[bool], nulls_first: list[bool]) -> LogicalPlanBuilder: ...
def hash_repartition(
self,
partition_by: list[PyExpr],
Expand Down
4 changes: 3 additions & 1 deletion daft/dataframe/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1525,6 +1525,7 @@ def sort(
self,
by: Union[ColumnInputType, List[ColumnInputType]],
desc: Union[bool, List[bool]] = False,
nulls_first: Union[bool, List[bool]] = False,
) -> "DataFrame":
"""Sorts DataFrame globally
Expand Down Expand Up @@ -1583,8 +1584,9 @@ def sort(
by = [
by,
]

sort_by = self.__column_input_to_expression(by)
builder = self._builder.sort(sort_by=sort_by, descending=desc)
builder = self._builder.sort(sort_by=sort_by, descending=desc, nulls_first=nulls_first)
return DataFrame(builder)

@DataframePublicAPI
Expand Down
8 changes: 6 additions & 2 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3112,7 +3112,7 @@ def max(self) -> Expression:
"""
return Expression._from_pyexpr(native.list_max(self._expr))

def sort(self, desc: bool | Expression = False) -> Expression:
def sort(self, desc: bool | Expression = False, nulls_first: bool | Expression | None = None) -> Expression:
"""Sorts the inner lists of a list column.
Example:
Expand Down Expand Up @@ -3141,7 +3141,11 @@ def sort(self, desc: bool | Expression = False) -> Expression:
"""
if isinstance(desc, bool):
desc = Expression._to_expression(desc)
return Expression._from_pyexpr(_list_sort(self._expr, desc._expr))
if nulls_first is None:
nulls_first = desc
elif isinstance(nulls_first, bool):
nulls_first = Expression._to_expression(nulls_first)
return Expression._from_pyexpr(_list_sort(self._expr, desc._expr, nulls_first._expr))


class ExpressionStructNamespace(ExpressionNamespace):
Expand Down
13 changes: 11 additions & 2 deletions daft/logical/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,11 +203,20 @@ def sample(self, fraction: float, with_replacement: bool, seed: int | None) -> L
builder = self._builder.sample(fraction, with_replacement, seed)
return LogicalPlanBuilder(builder)

def sort(
    self,
    sort_by: list[Expression],
    descending: list[bool] | bool = False,
    nulls_first: list[bool] | bool | None = None,
) -> LogicalPlanBuilder:
    """Append a sort on ``sort_by`` to the plan and return the resulting builder.

    Args:
        sort_by: Expressions to sort on; each one's ``_expr`` is forwarded to
            the underlying builder.
        descending: One flag per sort key, or a single bool that is repeated
            for every key.
        nulls_first: One flag per sort key, a single bool that is repeated for
            every key, or ``None`` to reuse the per-key ``descending`` flags.

    Raises:
        ValueError: If ``descending`` or ``nulls_first`` is a sequence whose
            length differs from the number of sort keys.
    """
    sort_by_pyexprs = [expr._expr for expr in sort_by]
    num_keys = len(sort_by_pyexprs)

    if not isinstance(descending, list):
        descending = [descending] * num_keys

    if nulls_first is None:
        # No explicit null ordering requested: follow the sort direction per key.
        nulls_first = descending
    elif isinstance(nulls_first, bool):
        nulls_first = [nulls_first] * num_keys

    # Fail early with a clear message instead of handing mismatched lists to
    # the native builder (mirrors the checks done by the table-level sorts).
    for arg_name, flags in (("descending", descending), ("nulls_first", nulls_first)):
        if len(flags) != num_keys:
            raise ValueError(
                f"Expected length of `{arg_name}` to be the same length as `sort_by` "
                f"since a list was passed in, got {len(flags)} instead of {num_keys}"
            )

    builder = self._builder.sort(sort_by_pyexprs, descending, nulls_first)
    return LogicalPlanBuilder(builder)

def hash_repartition(self, num_partitions: int | None, partition_by: list[Expression]) -> LogicalPlanBuilder:
Expand Down
22 changes: 15 additions & 7 deletions daft/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,17 +256,20 @@ def slice(self, start: int, end: int) -> Series:

return Series._from_pyseries(self._series.slice(start, end))

def argsort(self, descending: bool = False, nulls_first: bool | None = None) -> Series:
    """Call the native ``argsort`` on the wrapped series and wrap the result.

    Args:
        descending: Sort-direction flag forwarded to the native call.
        nulls_first: Null-placement flag forwarded to the native call. When
            ``None`` it takes the same value as ``descending``.

    Raises:
        TypeError: If ``descending`` is not a bool, or ``nulls_first`` is
            neither a bool nor ``None``.
    """
    if not isinstance(descending, bool):
        raise TypeError(f"expected `descending` to be bool, got {type(descending)}")
    if nulls_first is None:
        # Default null placement follows the sort direction.
        nulls_first = descending
    elif not isinstance(nulls_first, bool):
        raise TypeError(f"expected `nulls_first` to be bool or None, got {type(nulls_first)}")

    return Series._from_pyseries(self._series.argsort(descending, nulls_first))

def sort(self, descending: bool = False, nulls_first: bool | None = None) -> Series:
    """Call the native ``sort`` on the wrapped series and wrap the result.

    Args:
        descending: Sort-direction flag forwarded to the native call.
        nulls_first: Null-placement flag forwarded to the native call. When
            ``None`` it takes the same value as ``descending``.

    Raises:
        TypeError: If ``descending`` is not a bool, or ``nulls_first`` is
            neither a bool nor ``None``.
    """
    if not isinstance(descending, bool):
        raise TypeError(f"expected `descending` to be bool, got {type(descending)}")
    if nulls_first is None:
        # Default null placement follows the sort direction.
        nulls_first = descending
    elif not isinstance(nulls_first, bool):
        raise TypeError(f"expected `nulls_first` to be bool or None, got {type(nulls_first)}")

    return Series._from_pyseries(self._series.sort(descending, nulls_first))

def hash(self, seed: Series | None = None) -> Series:
if not isinstance(seed, Series) and seed is not None:
Expand Down Expand Up @@ -962,10 +965,15 @@ def length(self) -> Series:
def get(self, idx: Series, default: Series) -> Series:
return Series._from_pyseries(self._series.list_get(idx._series, default._series))

def sort(self, desc: bool | Series = False, nulls_first: bool | Series | None = None) -> Series:
    """Call the native ``list_sort`` on the wrapped series and wrap the result.

    Args:
        desc: Either a bool, which is wrapped into a one-element Series named
            ``"desc"``, or a Series that is forwarded as given.
        nulls_first: Either a bool, which is wrapped into a one-element Series
            named ``"nulls_first"``, a Series that is forwarded as given, or
            ``None`` to reuse the (possibly wrapped) ``desc`` Series.
    """
    if isinstance(desc, bool):
        desc = Series.from_pylist([desc], name="desc")

    if nulls_first is None:
        # Default null placement follows the sort direction.
        nulls_first = desc
    elif isinstance(nulls_first, bool):
        nulls_first = Series.from_pylist([nulls_first], name="nulls_first")

    # Single native call carrying both flags; the pre-change one-argument
    # call returned early and skipped the nulls_first handling entirely.
    return Series._from_pyseries(self._series.list_sort(desc._series, nulls_first._series))


class SeriesMapNamespace(SeriesNamespace):
Expand Down
31 changes: 27 additions & 4 deletions daft/table/micropartition.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,12 @@ def filter(self, exprs: ExpressionsProjection) -> MicroPartition:
pyexprs = [e._expr for e in exprs]
return MicroPartition._from_pymicropartition(self._micropartition.filter(pyexprs))

def sort(self, sort_keys: ExpressionsProjection, descending: bool | list[bool] | None = None, nulls_first: bool | list[bool] | None=None) -> MicroPartition:
def sort(
self,
sort_keys: ExpressionsProjection,
descending: bool | list[bool] | None = None,
nulls_first: bool | list[bool] | None = None,
) -> MicroPartition:
assert all(isinstance(e, Expression) for e in sort_keys)
pyexprs = [e._expr for e in sort_keys]
if descending is None:
Expand All @@ -187,7 +192,6 @@ def sort(self, sort_keys: ExpressionsProjection, descending: bool | list[bool] |
f"Expected length of `descending` to be the same length as `sort_keys` since a list was passed in,"
f"got {len(descending)} instead of {len(sort_keys)}"
)

else:
raise TypeError(f"Expected a bool, list[bool] or None for `descending` but got {type(descending)}")
if nulls_first is None:
Expand Down Expand Up @@ -364,7 +368,12 @@ def add_monotonically_increasing_id(self, partition_num: int, column_name: str)
# Compute methods (MicroPartition -> Series)
###

def argsort(self, sort_keys: ExpressionsProjection, descending: bool | list[bool] | None = None) -> Series:
def argsort(
self,
sort_keys: ExpressionsProjection,
descending: bool | list[bool] | None = None,
nulls_first: bool | list[bool] | None = None,
) -> Series:
assert all(isinstance(e, Expression) for e in sort_keys)
pyexprs = [e._expr for e in sort_keys]
if descending is None:
Expand All @@ -379,7 +388,21 @@ def argsort(self, sort_keys: ExpressionsProjection, descending: bool | list[bool
)
else:
raise TypeError(f"Expected a bool, list[bool] or None for `descending` but got {type(descending)}")
return Series._from_pyseries(self._micropartition.argsort(pyexprs, descending))
if nulls_first is None:
nulls_first = descending
elif isinstance(nulls_first, bool):
nulls_first = [nulls_first for _ in pyexprs]
elif isinstance(nulls_first, list):
if len(nulls_first) != len(sort_keys):
raise ValueError(
f"Expected length of `nulls_first` to be the same length as `sort_keys` since a list was passed in,"
f"got {len(nulls_first)} instead of {len(sort_keys)}"
)
else:
nulls_first = [bool(x) for x in nulls_first]
else:
raise TypeError(f"Expected a bool, list[bool] or None for `nulls_first` but got {type(nulls_first)}")
return Series._from_pyseries(self._micropartition.argsort(pyexprs, descending, nulls_first))

def __reduce__(self) -> tuple:
names = self.column_names()
Expand Down
44 changes: 40 additions & 4 deletions daft/table/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,12 @@ def filter(self, exprs: ExpressionsProjection) -> Table:
pyexprs = [e._expr for e in exprs]
return Table._from_pytable(self._table.filter(pyexprs))

def sort(self, sort_keys: ExpressionsProjection, descending: bool | list[bool] | None = None) -> Table:
def sort(
self,
sort_keys: ExpressionsProjection,
descending: bool | list[bool] | None = None,
nulls_first: bool | list[bool] | None = None,
) -> Table:
assert all(isinstance(e, Expression) for e in sort_keys)
pyexprs = [e._expr for e in sort_keys]
if descending is None:
Expand All @@ -256,7 +261,19 @@ def sort(self, sort_keys: ExpressionsProjection, descending: bool | list[bool] |
)
else:
raise TypeError(f"Expected a bool, list[bool] or None for `descending` but got {type(descending)}")
return Table._from_pytable(self._table.sort(pyexprs, descending))
if nulls_first is None:
nulls_first = descending
elif isinstance(nulls_first, bool):
nulls_first = [nulls_first for _ in pyexprs]
elif isinstance(nulls_first, list):
if len(nulls_first) != len(sort_keys):
raise ValueError(
f"Expected length of `nulls_first` to be the same length as `sort_keys` since a list was passed in,"
f"got {len(nulls_first)} instead of {len(sort_keys)}"
)
else:
nulls_first = [bool(x) for x in nulls_first]
return Table._from_pytable(self._table.sort(pyexprs, descending, nulls_first))

def sample(
self,
Expand Down Expand Up @@ -378,7 +395,12 @@ def add_monotonically_increasing_id(self, partition_num: int, column_name: str)
# Compute methods (Table -> Series)
###

def argsort(self, sort_keys: ExpressionsProjection, descending: bool | list[bool] | None = None) -> Series:
def argsort(
self,
sort_keys: ExpressionsProjection,
descending: bool | list[bool] | None = None,
nulls_first: bool | list[bool] | None = None,
) -> Series:
assert all(isinstance(e, Expression) for e in sort_keys)
pyexprs = [e._expr for e in sort_keys]
if descending is None:
Expand All @@ -393,7 +415,21 @@ def argsort(self, sort_keys: ExpressionsProjection, descending: bool | list[bool
)
else:
raise TypeError(f"Expected a bool, list[bool] or None for `descending` but got {type(descending)}")
return Series._from_pyseries(self._table.argsort(pyexprs, descending))
if nulls_first is None:
nulls_first = descending
elif isinstance(nulls_first, bool):
nulls_first = [nulls_first for _ in pyexprs]
elif isinstance(nulls_first, list):
if len(nulls_first) != len(sort_keys):
raise ValueError(
f"Expected length of `nulls_first` to be the same length as `sort_keys` since a list was passed in,"
f"got {len(nulls_first)} instead of {len(sort_keys)}"
)
else:
nulls_first = [bool(x) for x in nulls_first]
else:
raise TypeError(f"Expected a bool, list[bool] or None for `nulls_first` but got {type(nulls_first)}")
return Series._from_pyseries(self._table.argsort(pyexprs, descending, nulls_first))

def __reduce__(self) -> tuple:
names = self.column_names()
Expand Down
Loading

0 comments on commit a9564b4

Please sign in to comment.