Skip to content

Commit

Permalink
[CHORE]: prepare for nulls first/last kernels (#3301)
Browse files Browse the repository at this point in the history
This PR adds a `nulls_first` parameter to all sort functions in preparation
for implementing the nulls-first/nulls-last kernels. A follow-up PR will add
the actual nulls-first/nulls-last implementations.
  • Loading branch information
universalmind303 authored Nov 20, 2024
1 parent a9bf7c0 commit 7922d2d
Show file tree
Hide file tree
Showing 43 changed files with 521 additions and 166 deletions.
18 changes: 9 additions & 9 deletions daft/daft/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1280,7 +1280,7 @@ def dt_truncate(expr: PyExpr, interval: str, relative_to: PyExpr) -> PyExpr: ...
# expr.list namespace
# ---
def explode(expr: PyExpr) -> PyExpr: ...
def list_sort(expr: PyExpr, desc: PyExpr) -> PyExpr: ...
def list_sort(expr: PyExpr, desc: PyExpr, nulls_first: PyExpr) -> PyExpr: ...
def list_value_counts(expr: PyExpr) -> PyExpr: ...
def list_join(expr: PyExpr, delimiter: PyExpr) -> PyExpr: ...
def list_count(expr: PyExpr, mode: CountMode) -> PyExpr: ...
Expand Down Expand Up @@ -1360,8 +1360,8 @@ class PySeries:
def take(self, idx: PySeries) -> PySeries: ...
def slice(self, start: int, end: int) -> PySeries: ...
def filter(self, mask: PySeries) -> PySeries: ...
def sort(self, descending: bool) -> PySeries: ...
def argsort(self, descending: bool) -> PySeries: ...
def sort(self, descending: bool, nulls_first: bool) -> PySeries: ...
def argsort(self, descending: bool, nulls_first: bool) -> PySeries: ...
def hash(self, seed: PySeries | None = None) -> PySeries: ...
def minhash(
self,
Expand Down Expand Up @@ -1462,7 +1462,7 @@ class PySeries:
def list_count(self, mode: CountMode) -> PySeries: ...
def list_get(self, idx: PySeries, default: PySeries) -> PySeries: ...
def list_slice(self, start: PySeries, end: PySeries | None = None) -> PySeries: ...
def list_sort(self, desc: PySeries) -> PySeries: ...
def list_sort(self, desc: PySeries, nulls_first: PySeries) -> PySeries: ...
def map_get(self, key: PySeries) -> PySeries: ...
def if_else(self, other: PySeries, predicate: PySeries) -> PySeries: ...
def is_null(self) -> PySeries: ...
Expand All @@ -1480,8 +1480,8 @@ class PyTable:
def eval_expression_list(self, exprs: list[PyExpr]) -> PyTable: ...
def take(self, idx: PySeries) -> PyTable: ...
def filter(self, exprs: list[PyExpr]) -> PyTable: ...
def sort(self, sort_keys: list[PyExpr], descending: list[bool]) -> PyTable: ...
def argsort(self, sort_keys: list[PyExpr], descending: list[bool]) -> PySeries: ...
def sort(self, sort_keys: list[PyExpr], descending: list[bool], nulls_first: list[bool]) -> PyTable: ...
def argsort(self, sort_keys: list[PyExpr], descending: list[bool], nulls_first: list[bool]) -> PySeries: ...
def agg(self, to_agg: list[PyExpr], group_by: list[PyExpr]) -> PyTable: ...
def pivot(
self,
Expand Down Expand Up @@ -1559,8 +1559,8 @@ class PyMicroPartition:
def eval_expression_list(self, exprs: list[PyExpr]) -> PyMicroPartition: ...
def take(self, idx: PySeries) -> PyMicroPartition: ...
def filter(self, exprs: list[PyExpr]) -> PyMicroPartition: ...
def sort(self, sort_keys: list[PyExpr], descending: list[bool]) -> PyMicroPartition: ...
def argsort(self, sort_keys: list[PyExpr], descending: list[bool]) -> PySeries: ...
def sort(self, sort_keys: list[PyExpr], descending: list[bool], nulls_first: list[bool]) -> PyMicroPartition: ...
def argsort(self, sort_keys: list[PyExpr], descending: list[bool], nulls_first: list[bool]) -> PySeries: ...
def agg(self, to_agg: list[PyExpr], group_by: list[PyExpr]) -> PyMicroPartition: ...
def hash_join(
self,
Expand Down Expand Up @@ -1727,7 +1727,7 @@ class LogicalPlanBuilder:
variable_name: str,
value_name: str,
) -> LogicalPlanBuilder: ...
def sort(self, sort_by: list[PyExpr], descending: list[bool]) -> LogicalPlanBuilder: ...
def sort(self, sort_by: list[PyExpr], descending: list[bool], nulls_first: list[bool]) -> LogicalPlanBuilder: ...
def hash_repartition(
self,
partition_by: list[PyExpr],
Expand Down
4 changes: 3 additions & 1 deletion daft/dataframe/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1583,8 +1583,10 @@ def sort(
by = [
by,
]

sort_by = self.__column_input_to_expression(by)
builder = self._builder.sort(sort_by=sort_by, descending=desc)

builder = self._builder.sort(sort_by=sort_by, descending=desc, nulls_first=desc)
return DataFrame(builder)

@DataframePublicAPI
Expand Down
7 changes: 6 additions & 1 deletion daft/execution/execution_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -920,15 +920,20 @@ class ReduceToQuantiles(ReduceInstruction):
num_quantiles: int
sort_by: ExpressionsProjection
descending: list[bool]
nulls_first: list[bool] | None = None

def run(self, inputs: list[MicroPartition]) -> list[MicroPartition]:
return self._reduce_to_quantiles(inputs)

def _reduce_to_quantiles(self, inputs: list[MicroPartition]) -> list[MicroPartition]:
merged = MicroPartition.concat(inputs)

nulls_first = self.nulls_first if self.nulls_first is not None else self.descending

# Skip evaluation of expressions by converting to Column Expression, since evaluation was done in Sample
merged_sorted = merged.sort(self.sort_by.to_column_expressions(), descending=self.descending)
merged_sorted = merged.sort(
self.sort_by.to_column_expressions(), descending=self.descending, nulls_first=nulls_first
)

result = merged_sorted.quantiles(self.num_quantiles)
return [result]
Expand Down
2 changes: 2 additions & 0 deletions daft/execution/physical_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -1510,6 +1510,7 @@ def sort(
child_plan: InProgressPhysicalPlan[PartitionT],
sort_by: ExpressionsProjection,
descending: list[bool],
nulls_first: list[bool],
num_partitions: int,
) -> InProgressPhysicalPlan[PartitionT]:
"""Sort the result of `child_plan` according to `sort_info`."""
Expand Down Expand Up @@ -1565,6 +1566,7 @@ def sort(
num_quantiles=num_partitions,
sort_by=sort_by,
descending=descending,
nulls_first=nulls_first,
),
)
.finalize_partition_task_single_output(stage_id=stage_id_reduce)
Expand Down
2 changes: 2 additions & 0 deletions daft/execution/rust_physical_plan_shim.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,13 +202,15 @@ def sort(
input: physical_plan.InProgressPhysicalPlan[PartitionT],
sort_by: list[PyExpr],
descending: list[bool],
nulls_first: list[bool],
num_partitions: int,
) -> physical_plan.InProgressPhysicalPlan[PartitionT]:
expr_projection = ExpressionsProjection([Expression._from_pyexpr(expr) for expr in sort_by])
return physical_plan.sort(
child_plan=input,
sort_by=expr_projection,
descending=descending,
nulls_first=nulls_first,
num_partitions=num_partitions,
)

Expand Down
8 changes: 6 additions & 2 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3116,7 +3116,7 @@ def max(self) -> Expression:
"""
return Expression._from_pyexpr(native.list_max(self._expr))

def sort(self, desc: bool | Expression = False) -> Expression:
def sort(self, desc: bool | Expression = False, nulls_first: bool | Expression | None = None) -> Expression:
"""Sorts the inner lists of a list column.
Example:
Expand Down Expand Up @@ -3145,7 +3145,11 @@ def sort(self, desc: bool | Expression = False) -> Expression:
"""
if isinstance(desc, bool):
desc = Expression._to_expression(desc)
return Expression._from_pyexpr(_list_sort(self._expr, desc._expr))
if nulls_first is None:
nulls_first = desc
elif isinstance(nulls_first, bool):
nulls_first = Expression._to_expression(nulls_first)
return Expression._from_pyexpr(_list_sort(self._expr, desc._expr, nulls_first._expr))


class ExpressionStructNamespace(ExpressionNamespace):
Expand Down
13 changes: 11 additions & 2 deletions daft/logical/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,11 +203,20 @@ def sample(self, fraction: float, with_replacement: bool, seed: int | None) -> L
builder = self._builder.sample(fraction, with_replacement, seed)
return LogicalPlanBuilder(builder)

def sort(self, sort_by: list[Expression], descending: list[bool] | bool = False) -> LogicalPlanBuilder:
def sort(
self,
sort_by: list[Expression],
descending: list[bool] | bool = False,
nulls_first: list[bool] | bool | None = None,
) -> LogicalPlanBuilder:
sort_by_pyexprs = [expr._expr for expr in sort_by]
if not isinstance(descending, list):
descending = [descending] * len(sort_by_pyexprs)
builder = self._builder.sort(sort_by_pyexprs, descending)
if nulls_first is None:
nulls_first = descending
elif isinstance(nulls_first, bool):
nulls_first = [nulls_first] * len(sort_by_pyexprs)
builder = self._builder.sort(sort_by_pyexprs, descending, nulls_first)
return LogicalPlanBuilder(builder)

def hash_repartition(self, num_partitions: int | None, partition_by: list[Expression]) -> LogicalPlanBuilder:
Expand Down
22 changes: 15 additions & 7 deletions daft/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,17 +256,20 @@ def slice(self, start: int, end: int) -> Series:

return Series._from_pyseries(self._series.slice(start, end))

def argsort(self, descending: bool = False) -> Series:
def argsort(self, descending: bool = False, nulls_first: bool | None = None) -> Series:
if not isinstance(descending, bool):
raise TypeError(f"expected `descending` to be bool, got {type(descending)}")
if nulls_first is None:
nulls_first = descending

return Series._from_pyseries(self._series.argsort(descending))
return Series._from_pyseries(self._series.argsort(descending, nulls_first))

def sort(self, descending: bool = False) -> Series:
def sort(self, descending: bool = False, nulls_first: bool | None = None) -> Series:
if not isinstance(descending, bool):
raise TypeError(f"expected `descending` to be bool, got {type(descending)}")

return Series._from_pyseries(self._series.sort(descending))
if nulls_first is None:
nulls_first = descending
return Series._from_pyseries(self._series.sort(descending, nulls_first))

def hash(self, seed: Series | None = None) -> Series:
if not isinstance(seed, Series) and seed is not None:
Expand Down Expand Up @@ -962,10 +965,15 @@ def length(self) -> Series:
def get(self, idx: Series, default: Series) -> Series:
return Series._from_pyseries(self._series.list_get(idx._series, default._series))

def sort(self, desc: bool | Series = False) -> Series:
def sort(self, desc: bool | Series = False, nulls_first: bool | Series | None = None) -> Series:
if isinstance(desc, bool):
desc = Series.from_pylist([desc], name="desc")
return Series._from_pyseries(self._series.list_sort(desc._series))
if nulls_first is None:
nulls_first = desc
elif isinstance(nulls_first, bool):
nulls_first = Series.from_pylist([nulls_first], name="nulls_first")

return Series._from_pyseries(self._series.list_sort(desc._series, nulls_first._series))


class SeriesMapNamespace(SeriesNamespace):
Expand Down
46 changes: 42 additions & 4 deletions daft/table/micropartition.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,12 @@ def filter(self, exprs: ExpressionsProjection) -> MicroPartition:
pyexprs = [e._expr for e in exprs]
return MicroPartition._from_pymicropartition(self._micropartition.filter(pyexprs))

def sort(self, sort_keys: ExpressionsProjection, descending: bool | list[bool] | None = None) -> MicroPartition:
def sort(
self,
sort_keys: ExpressionsProjection,
descending: bool | list[bool] | None = None,
nulls_first: bool | list[bool] | None = None,
) -> MicroPartition:
assert all(isinstance(e, Expression) for e in sort_keys)
pyexprs = [e._expr for e in sort_keys]
if descending is None:
Expand All @@ -189,7 +194,21 @@ def sort(self, sort_keys: ExpressionsProjection, descending: bool | list[bool] |
)
else:
raise TypeError(f"Expected a bool, list[bool] or None for `descending` but got {type(descending)}")
return MicroPartition._from_pymicropartition(self._micropartition.sort(pyexprs, descending))
if nulls_first is None:
nulls_first = descending
elif isinstance(nulls_first, bool):
nulls_first = [nulls_first for _ in pyexprs]
elif isinstance(nulls_first, list):
if len(nulls_first) != len(sort_keys):
raise ValueError(
f"Expected length of `nulls_first` to be the same length as `sort_keys` since a list was passed in,"
f"got {len(nulls_first)} instead of {len(sort_keys)}"
)
else:
nulls_first = [bool(x) for x in nulls_first]
else:
raise TypeError(f"Expected a bool, list[bool] or None for `nulls_first` but got {type(nulls_first)}")
return MicroPartition._from_pymicropartition(self._micropartition.sort(pyexprs, descending, nulls_first))

def sample(
self,
Expand Down Expand Up @@ -349,7 +368,12 @@ def add_monotonically_increasing_id(self, partition_num: int, column_name: str)
# Compute methods (MicroPartition -> Series)
###

def argsort(self, sort_keys: ExpressionsProjection, descending: bool | list[bool] | None = None) -> Series:
def argsort(
self,
sort_keys: ExpressionsProjection,
descending: bool | list[bool] | None = None,
nulls_first: bool | list[bool] | None = None,
) -> Series:
assert all(isinstance(e, Expression) for e in sort_keys)
pyexprs = [e._expr for e in sort_keys]
if descending is None:
Expand All @@ -364,7 +388,21 @@ def argsort(self, sort_keys: ExpressionsProjection, descending: bool | list[bool
)
else:
raise TypeError(f"Expected a bool, list[bool] or None for `descending` but got {type(descending)}")
return Series._from_pyseries(self._micropartition.argsort(pyexprs, descending))
if nulls_first is None:
nulls_first = descending
elif isinstance(nulls_first, bool):
nulls_first = [nulls_first for _ in pyexprs]
elif isinstance(nulls_first, list):
if len(nulls_first) != len(sort_keys):
raise ValueError(
f"Expected length of `nulls_first` to be the same length as `sort_keys` since a list was passed in,"
f"got {len(nulls_first)} instead of {len(sort_keys)}"
)
else:
nulls_first = [bool(x) for x in nulls_first]
else:
raise TypeError(f"Expected a bool, list[bool] or None for `nulls_first` but got {type(nulls_first)}")
return Series._from_pyseries(self._micropartition.argsort(pyexprs, descending, nulls_first))

def __reduce__(self) -> tuple:
names = self.column_names()
Expand Down
44 changes: 40 additions & 4 deletions daft/table/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,12 @@ def filter(self, exprs: ExpressionsProjection) -> Table:
pyexprs = [e._expr for e in exprs]
return Table._from_pytable(self._table.filter(pyexprs))

def sort(self, sort_keys: ExpressionsProjection, descending: bool | list[bool] | None = None) -> Table:
def sort(
self,
sort_keys: ExpressionsProjection,
descending: bool | list[bool] | None = None,
nulls_first: bool | list[bool] | None = None,
) -> Table:
assert all(isinstance(e, Expression) for e in sort_keys)
pyexprs = [e._expr for e in sort_keys]
if descending is None:
Expand All @@ -256,7 +261,19 @@ def sort(self, sort_keys: ExpressionsProjection, descending: bool | list[bool] |
)
else:
raise TypeError(f"Expected a bool, list[bool] or None for `descending` but got {type(descending)}")
return Table._from_pytable(self._table.sort(pyexprs, descending))
if nulls_first is None:
nulls_first = descending
elif isinstance(nulls_first, bool):
nulls_first = [nulls_first for _ in pyexprs]
elif isinstance(nulls_first, list):
if len(nulls_first) != len(sort_keys):
raise ValueError(
f"Expected length of `nulls_first` to be the same length as `sort_keys` since a list was passed in,"
f"got {len(nulls_first)} instead of {len(sort_keys)}"
)
else:
nulls_first = [bool(x) for x in nulls_first]
return Table._from_pytable(self._table.sort(pyexprs, descending, nulls_first))

def sample(
self,
Expand Down Expand Up @@ -378,7 +395,12 @@ def add_monotonically_increasing_id(self, partition_num: int, column_name: str)
# Compute methods (Table -> Series)
###

def argsort(self, sort_keys: ExpressionsProjection, descending: bool | list[bool] | None = None) -> Series:
def argsort(
self,
sort_keys: ExpressionsProjection,
descending: bool | list[bool] | None = None,
nulls_first: bool | list[bool] | None = None,
) -> Series:
assert all(isinstance(e, Expression) for e in sort_keys)
pyexprs = [e._expr for e in sort_keys]
if descending is None:
Expand All @@ -393,7 +415,21 @@ def argsort(self, sort_keys: ExpressionsProjection, descending: bool | list[bool
)
else:
raise TypeError(f"Expected a bool, list[bool] or None for `descending` but got {type(descending)}")
return Series._from_pyseries(self._table.argsort(pyexprs, descending))
if nulls_first is None:
nulls_first = descending
elif isinstance(nulls_first, bool):
nulls_first = [nulls_first for _ in pyexprs]
elif isinstance(nulls_first, list):
if len(nulls_first) != len(sort_keys):
raise ValueError(
f"Expected length of `nulls_first` to be the same length as `sort_keys` since a list was passed in,"
f"got {len(nulls_first)} instead of {len(sort_keys)}"
)
else:
nulls_first = [bool(x) for x in nulls_first]
else:
raise TypeError(f"Expected a bool, list[bool] or None for `nulls_first` but got {type(nulls_first)}")
return Series._from_pyseries(self._table.argsort(pyexprs, descending, nulls_first))

def __reduce__(self) -> tuple:
names = self.column_names()
Expand Down
Loading

0 comments on commit 7922d2d

Please sign in to comment.