From 7922d2d810ff92b00008d877aa9a6553bc0dedab Mon Sep 17 00:00:00 2001 From: Cory Grinstead Date: Tue, 19 Nov 2024 18:36:51 -0600 Subject: [PATCH] [CHORE]: prepare for nulls first/last kernels (#3301) This pr adds nulls_first to all sort functions in preparation for implementing the nulls first/last kernels. This will be followed up by adding the actual nulls first/last implementations. --- daft/daft/__init__.pyi | 18 +-- daft/dataframe/dataframe.py | 4 +- daft/execution/execution_step.py | 7 +- daft/execution/physical_plan.py | 2 + daft/execution/rust_physical_plan_shim.py | 2 + daft/expressions/expressions.py | 8 +- daft/logical/builder.py | 13 +- daft/series.py | 22 ++-- daft/table/micropartition.py | 46 ++++++- daft/table/table.py | 44 ++++++- src/daft-core/src/array/ops/list.rs | 42 +++++-- src/daft-core/src/array/ops/sort.rs | 119 +++++++++++------- src/daft-core/src/python/series.rs | 15 ++- .../src/series/array_impl/data_array.rs | 4 +- .../src/series/array_impl/logical_array.rs | 4 +- .../src/series/array_impl/nested_array.rs | 2 +- src/daft-core/src/series/ops/list.rs | 12 +- src/daft-core/src/series/ops/sort.rs | 21 ++-- src/daft-core/src/series/series_like.rs | 2 +- src/daft-core/src/utils/mod.rs | 25 ++++ src/daft-functions/src/list/sort.rs | 17 +-- src/daft-local-execution/src/pipeline.rs | 3 +- src/daft-local-execution/src/sinks/sort.rs | 10 +- src/daft-local-plan/src/plan.rs | 3 + src/daft-local-plan/src/translate.rs | 1 + src/daft-logical-plan/src/builder.rs | 18 ++- src/daft-logical-plan/src/display.rs | 6 +- src/daft-logical-plan/src/logical_plan.rs | 2 +- src/daft-logical-plan/src/ops/sort.rs | 13 +- .../optimization/rules/push_down_filter.rs | 7 +- src/daft-micropartition/src/ops/sort.rs | 20 ++- src/daft-micropartition/src/python.rs | 14 ++- src/daft-physical-plan/src/ops/sort.rs | 13 +- .../src/physical_planner/translate.rs | 4 + src/daft-physical-plan/src/plan.rs | 3 +- src/daft-scheduler/src/scheduler.rs | 2 + src/daft-sql/src/lib.rs | 2 +- src/daft-sql/src/modules/list.rs | 4 +- src/daft-sql/src/planner.rs | 84 ++++++++++--- src/daft-table/src/ops/groups.rs | 7 +- src/daft-table/src/ops/joins/mod.rs | 8 ++ src/daft-table/src/ops/sort.rs | 20 ++- src/daft-table/src/python.rs | 14 ++- 43 files changed, 521 insertions(+), 166 deletions(-) diff --git a/daft/daft/__init__.pyi b/daft/daft/__init__.pyi index b50612920d..368af56011 100644 --- a/daft/daft/__init__.pyi +++ b/daft/daft/__init__.pyi @@ -1280,7 +1280,7 @@ def dt_truncate(expr: PyExpr, interval: str, relative_to: PyExpr) -> PyExpr: ... # expr.list namespace # --- def explode(expr: PyExpr) -> PyExpr: ... -def list_sort(expr: PyExpr, desc: PyExpr) -> PyExpr: ... +def list_sort(expr: PyExpr, desc: PyExpr, nulls_first: PyExpr) -> PyExpr: ... def list_value_counts(expr: PyExpr) -> PyExpr: ... def list_join(expr: PyExpr, delimiter: PyExpr) -> PyExpr: ... def list_count(expr: PyExpr, mode: CountMode) -> PyExpr: ... @@ -1360,8 +1360,8 @@ class PySeries: def take(self, idx: PySeries) -> PySeries: ... def slice(self, start: int, end: int) -> PySeries: ... def filter(self, mask: PySeries) -> PySeries: ... - def sort(self, descending: bool) -> PySeries: ... - def argsort(self, descending: bool) -> PySeries: ... + def sort(self, descending: bool, nulls_first: bool) -> PySeries: ... + def argsort(self, descending: bool, nulls_first: bool) -> PySeries: ... def hash(self, seed: PySeries | None = None) -> PySeries: ... def minhash( self, @@ -1462,7 +1462,7 @@ class PySeries: def list_count(self, mode: CountMode) -> PySeries: ... def list_get(self, idx: PySeries, default: PySeries) -> PySeries: ... def list_slice(self, start: PySeries, end: PySeries | None = None) -> PySeries: ... - def list_sort(self, desc: PySeries) -> PySeries: ... + def list_sort(self, desc: PySeries, nulls_first: PySeries) -> PySeries: ... def map_get(self, key: PySeries) -> PySeries: ... def if_else(self, other: PySeries, predicate: PySeries) -> PySeries: ... def is_null(self) -> PySeries: ... @@ -1480,8 +1480,8 @@ class PyTable: def eval_expression_list(self, exprs: list[PyExpr]) -> PyTable: ... def take(self, idx: PySeries) -> PyTable: ... def filter(self, exprs: list[PyExpr]) -> PyTable: ... - def sort(self, sort_keys: list[PyExpr], descending: list[bool]) -> PyTable: ... - def argsort(self, sort_keys: list[PyExpr], descending: list[bool]) -> PySeries: ... + def sort(self, sort_keys: list[PyExpr], descending: list[bool], nulls_first: list[bool]) -> PyTable: ... + def argsort(self, sort_keys: list[PyExpr], descending: list[bool], nulls_first: list[bool]) -> PySeries: ... def agg(self, to_agg: list[PyExpr], group_by: list[PyExpr]) -> PyTable: ... def pivot( self, @@ -1559,8 +1559,8 @@ class PyMicroPartition: def eval_expression_list(self, exprs: list[PyExpr]) -> PyMicroPartition: ... def take(self, idx: PySeries) -> PyMicroPartition: ... def filter(self, exprs: list[PyExpr]) -> PyMicroPartition: ... - def sort(self, sort_keys: list[PyExpr], descending: list[bool]) -> PyMicroPartition: ... - def argsort(self, sort_keys: list[PyExpr], descending: list[bool]) -> PySeries: ... + def sort(self, sort_keys: list[PyExpr], descending: list[bool], nulls_first: list[bool]) -> PyMicroPartition: ... + def argsort(self, sort_keys: list[PyExpr], descending: list[bool], nulls_first: list[bool]) -> PySeries: ... def agg(self, to_agg: list[PyExpr], group_by: list[PyExpr]) -> PyMicroPartition: ... def hash_join( self, @@ -1727,7 +1727,7 @@ class LogicalPlanBuilder: variable_name: str, value_name: str, ) -> LogicalPlanBuilder: ... - def sort(self, sort_by: list[PyExpr], descending: list[bool]) -> LogicalPlanBuilder: ... + def sort(self, sort_by: list[PyExpr], descending: list[bool], nulls_first: list[bool]) -> LogicalPlanBuilder: ... def hash_repartition( self, partition_by: list[PyExpr], diff --git a/daft/dataframe/dataframe.py b/daft/dataframe/dataframe.py index a79443e327..31148e6896 100644 --- a/daft/dataframe/dataframe.py +++ b/daft/dataframe/dataframe.py @@ -1583,8 +1583,10 @@ def sort( by = [ by, ] + sort_by = self.__column_input_to_expression(by) - builder = self._builder.sort(sort_by=sort_by, descending=desc) + + builder = self._builder.sort(sort_by=sort_by, descending=desc, nulls_first=desc) return DataFrame(builder) @DataframePublicAPI diff --git a/daft/execution/execution_step.py b/daft/execution/execution_step.py index 5755f29538..94873a4bb4 100644 --- a/daft/execution/execution_step.py +++ b/daft/execution/execution_step.py @@ -920,6 +920,7 @@ class ReduceToQuantiles(ReduceInstruction): num_quantiles: int sort_by: ExpressionsProjection descending: list[bool] + nulls_first: list[bool] | None = None def run(self, inputs: list[MicroPartition]) -> list[MicroPartition]: return self._reduce_to_quantiles(inputs) @@ -927,8 +928,12 @@ def run(self, inputs: list[MicroPartition]) -> list[MicroPartition]: def _reduce_to_quantiles(self, inputs: list[MicroPartition]) -> list[MicroPartition]: merged = MicroPartition.concat(inputs) + nulls_first = self.nulls_first if self.nulls_first is not None else self.descending + # Skip evaluation of expressions by converting to Column Expression, since evaluation was done in Sample - merged_sorted = merged.sort(self.sort_by.to_column_expressions(), descending=self.descending) + merged_sorted = merged.sort( + self.sort_by.to_column_expressions(), descending=self.descending, nulls_first=nulls_first + ) result = merged_sorted.quantiles(self.num_quantiles) return [result] diff --git a/daft/execution/physical_plan.py b/daft/execution/physical_plan.py index 49cd76d247..55fa3f1f03 100644 --- a/daft/execution/physical_plan.py +++ b/daft/execution/physical_plan.py @@ -1510,6 +1510,7 @@ def sort( child_plan: InProgressPhysicalPlan[PartitionT], sort_by: ExpressionsProjection, descending: list[bool], + nulls_first: list[bool], num_partitions: int, ) -> InProgressPhysicalPlan[PartitionT]: """Sort the result of `child_plan` according to `sort_info`.""" @@ -1565,6 +1566,7 @@ def sort( num_quantiles=num_partitions, sort_by=sort_by, descending=descending, + nulls_first=nulls_first, ), ) .finalize_partition_task_single_output(stage_id=stage_id_reduce) diff --git a/daft/execution/rust_physical_plan_shim.py b/daft/execution/rust_physical_plan_shim.py index 8d36bcd81d..a304574894 100644 --- a/daft/execution/rust_physical_plan_shim.py +++ b/daft/execution/rust_physical_plan_shim.py @@ -202,6 +202,7 @@ def sort( input: physical_plan.InProgressPhysicalPlan[PartitionT], sort_by: list[PyExpr], descending: list[bool], + nulls_first: list[bool], num_partitions: int, ) -> physical_plan.InProgressPhysicalPlan[PartitionT]: expr_projection = ExpressionsProjection([Expression._from_pyexpr(expr) for expr in sort_by]) @@ -209,6 +210,7 @@ def sort( child_plan=input, sort_by=expr_projection, descending=descending, + nulls_first=nulls_first, num_partitions=num_partitions, ) diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index f317f4ca85..88eb885976 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -3116,7 +3116,7 @@ def max(self) -> Expression: """ return Expression._from_pyexpr(native.list_max(self._expr)) - def sort(self, desc: bool | Expression = False) -> Expression: + def sort(self, desc: bool | Expression = False, nulls_first: bool | Expression | None = None) -> Expression: """Sorts the inner lists of a list column. Example: @@ -3145,7 +3145,11 @@ def sort(self, desc: bool | Expression = False) -> Expression: """ if isinstance(desc, bool): desc = Expression._to_expression(desc) - return Expression._from_pyexpr(_list_sort(self._expr, desc._expr)) + if nulls_first is None: + nulls_first = desc + elif isinstance(nulls_first, bool): + nulls_first = Expression._to_expression(nulls_first) + return Expression._from_pyexpr(_list_sort(self._expr, desc._expr, nulls_first._expr)) class ExpressionStructNamespace(ExpressionNamespace): diff --git a/daft/logical/builder.py b/daft/logical/builder.py index ed85a55517..412e400cfd 100644 --- a/daft/logical/builder.py +++ b/daft/logical/builder.py @@ -203,11 +203,20 @@ def sample(self, fraction: float, with_replacement: bool, seed: int | None) -> L builder = self._builder.sample(fraction, with_replacement, seed) return LogicalPlanBuilder(builder) - def sort(self, sort_by: list[Expression], descending: list[bool] | bool = False) -> LogicalPlanBuilder: + def sort( + self, + sort_by: list[Expression], + descending: list[bool] | bool = False, + nulls_first: list[bool] | bool | None = None, + ) -> LogicalPlanBuilder: sort_by_pyexprs = [expr._expr for expr in sort_by] if not isinstance(descending, list): descending = [descending] * len(sort_by_pyexprs) - builder = self._builder.sort(sort_by_pyexprs, descending) + if nulls_first is None: + nulls_first = descending + elif isinstance(nulls_first, bool): + nulls_first = [nulls_first] * len(sort_by_pyexprs) + builder = self._builder.sort(sort_by_pyexprs, descending, nulls_first) return LogicalPlanBuilder(builder) def hash_repartition(self, num_partitions: int | None, partition_by: list[Expression]) -> LogicalPlanBuilder: diff --git a/daft/series.py b/daft/series.py index 7053d8668e..95788d5f7f 100644 --- a/daft/series.py +++ b/daft/series.py @@ -256,17 +256,20 @@ def slice(self, start: int, end: int) -> Series: return Series._from_pyseries(self._series.slice(start, end)) - def argsort(self, descending: bool = False) -> Series: + def argsort(self, descending: bool = False, nulls_first: bool | None = None) -> Series: if not isinstance(descending, bool): raise TypeError(f"expected `descending` to be bool, got {type(descending)}") + if nulls_first is None: + nulls_first = descending - return Series._from_pyseries(self._series.argsort(descending)) + return Series._from_pyseries(self._series.argsort(descending, nulls_first)) - def sort(self, descending: bool = False) -> Series: + def sort(self, descending: bool = False, nulls_first: bool | None = None) -> Series: if not isinstance(descending, bool): raise TypeError(f"expected `descending` to be bool, got {type(descending)}") - - return Series._from_pyseries(self._series.sort(descending)) + if nulls_first is None: + nulls_first = descending + return Series._from_pyseries(self._series.sort(descending, nulls_first)) def hash(self, seed: Series | None = None) -> Series: if not isinstance(seed, Series) and seed is not None: @@ -962,10 +965,15 @@ def length(self) -> Series: def get(self, idx: Series, default: Series) -> Series: return Series._from_pyseries(self._series.list_get(idx._series, default._series)) - def sort(self, desc: bool | Series = False) -> Series: + def sort(self, desc: bool | Series = False, nulls_first: bool | Series | None = None) -> Series: if isinstance(desc, bool): desc = Series.from_pylist([desc], name="desc") - return Series._from_pyseries(self._series.list_sort(desc._series)) + if nulls_first is None: + nulls_first = desc + elif isinstance(nulls_first, bool): + nulls_first = Series.from_pylist([nulls_first], name="nulls_first") + + return Series._from_pyseries(self._series.list_sort(desc._series, nulls_first._series)) class SeriesMapNamespace(SeriesNamespace): diff --git a/daft/table/micropartition.py b/daft/table/micropartition.py index 5baf3d379c..c3297c9a6c 100644 --- a/daft/table/micropartition.py +++ b/daft/table/micropartition.py @@ -174,7 +174,12 @@ def filter(self, exprs: ExpressionsProjection) -> MicroPartition: pyexprs = [e._expr for e in exprs] return MicroPartition._from_pymicropartition(self._micropartition.filter(pyexprs)) - def sort(self, sort_keys: ExpressionsProjection, descending: bool | list[bool] | None = None) -> MicroPartition: + def sort( + self, + sort_keys: ExpressionsProjection, + descending: bool | list[bool] | None = None, + nulls_first: bool | list[bool] | None = None, + ) -> MicroPartition: assert all(isinstance(e, Expression) for e in sort_keys) pyexprs = [e._expr for e in sort_keys] if descending is None: @@ -189,7 +194,21 @@ def sort(self, sort_keys: ExpressionsProjection, descending: bool | list[bool] | ) else: raise TypeError(f"Expected a bool, list[bool] or None for `descending` but got {type(descending)}") - return MicroPartition._from_pymicropartition(self._micropartition.sort(pyexprs, descending)) + if nulls_first is None: + nulls_first = descending + elif isinstance(nulls_first, bool): + nulls_first = [nulls_first for _ in pyexprs] + elif isinstance(nulls_first, list): + if len(nulls_first) != len(sort_keys): + raise ValueError( + f"Expected length of `nulls_first` to be the same length as `sort_keys` since a list was passed in," + f"got {len(nulls_first)} instead of {len(sort_keys)}" + ) + else: + nulls_first = [bool(x) for x in nulls_first] + else: + raise TypeError(f"Expected a bool, list[bool] or None for `nulls_first` but got {type(nulls_first)}") + return MicroPartition._from_pymicropartition(self._micropartition.sort(pyexprs, descending, nulls_first)) def sample( self, @@ -349,7 +368,12 @@ def add_monotonically_increasing_id(self, partition_num: int, column_name: str) # Compute methods (MicroPartition -> Series) ### - def argsort(self, sort_keys: ExpressionsProjection, descending: bool | list[bool] | None = None) -> Series: + def argsort( + self, + sort_keys: ExpressionsProjection, + descending: bool | list[bool] | None = None, + nulls_first: bool | list[bool] | None = None, + ) -> Series: assert all(isinstance(e, Expression) for e in sort_keys) pyexprs = [e._expr for e in sort_keys] if descending is None: @@ -364,7 +388,21 @@ def argsort(self, sort_keys: ExpressionsProjection, descending: bool | list[bool ) else: raise TypeError(f"Expected a bool, list[bool] or None for `descending` but got {type(descending)}") - return Series._from_pyseries(self._micropartition.argsort(pyexprs, descending)) + if nulls_first is None: + nulls_first = descending + elif isinstance(nulls_first, bool): + nulls_first = [nulls_first for _ in pyexprs] + elif isinstance(nulls_first, list): + if len(nulls_first) != len(sort_keys): + raise ValueError( + f"Expected length of `nulls_first` to be the same length as `sort_keys` since a list was passed in," + f"got {len(nulls_first)} instead of {len(sort_keys)}" + ) + else: + nulls_first = [bool(x) for x in nulls_first] + else: + raise TypeError(f"Expected a bool, list[bool] or None for `nulls_first` but got {type(nulls_first)}") + return Series._from_pyseries(self._micropartition.argsort(pyexprs, descending, nulls_first)) def __reduce__(self) -> tuple: names = self.column_names() diff --git a/daft/table/table.py b/daft/table/table.py index 9ab769b337..66757cce2a 100644 --- a/daft/table/table.py +++ b/daft/table/table.py @@ -241,7 +241,12 @@ def filter(self, exprs: ExpressionsProjection) -> Table: pyexprs = [e._expr for e in exprs] return Table._from_pytable(self._table.filter(pyexprs)) - def sort(self, sort_keys: ExpressionsProjection, descending: bool | list[bool] | None = None) -> Table: + def sort( + self, + sort_keys: ExpressionsProjection, + descending: bool | list[bool] | None = None, + nulls_first: bool | list[bool] | None = None, + ) -> Table: assert all(isinstance(e, Expression) for e in sort_keys) pyexprs = [e._expr for e in sort_keys] if descending is None: @@ -256,7 +261,19 @@ def sort(self, sort_keys: ExpressionsProjection, descending: bool | list[bool] | ) else: raise TypeError(f"Expected a bool, list[bool] or None for `descending` but got {type(descending)}") - return Table._from_pytable(self._table.sort(pyexprs, descending)) + if nulls_first is None: + nulls_first = descending + elif isinstance(nulls_first, bool): + nulls_first = [nulls_first for _ in pyexprs] + elif isinstance(nulls_first, list): + if len(nulls_first) != len(sort_keys): + raise ValueError( + f"Expected length of `nulls_first` to be the same length as `sort_keys` since a list was passed in," + f"got {len(nulls_first)} instead of {len(sort_keys)}" + ) + else: + nulls_first = [bool(x) for x in nulls_first] + return Table._from_pytable(self._table.sort(pyexprs, descending, nulls_first)) def sample( self, @@ -378,7 +395,12 @@ def add_monotonically_increasing_id(self, partition_num: int, column_name: str) # Compute methods (Table -> Series) ### - def argsort(self, sort_keys: ExpressionsProjection, descending: bool | list[bool] | None = None) -> Series: + def argsort( + self, + sort_keys: ExpressionsProjection, + descending: bool | list[bool] | None = None, + nulls_first: bool | list[bool] | None = None, + ) -> Series: assert all(isinstance(e, Expression) for e in sort_keys) pyexprs = [e._expr for e in sort_keys] if descending is None: @@ -393,7 +415,21 @@ def argsort(self, sort_keys: ExpressionsProjection, descending: bool | list[bool ) else: raise TypeError(f"Expected a bool, list[bool] or None for `descending` but got {type(descending)}") - return Series._from_pyseries(self._table.argsort(pyexprs, descending)) + if nulls_first is None: + nulls_first = descending + elif isinstance(nulls_first, bool): + nulls_first = [nulls_first for _ in pyexprs] + elif isinstance(nulls_first, list): + if len(nulls_first) != len(sort_keys): + raise ValueError( + f"Expected length of `nulls_first` to be the same length as `sort_keys` since a list was passed in," + f"got {len(nulls_first)} instead of {len(sort_keys)}" + ) + else: + nulls_first = [bool(x) for x in nulls_first] + else: + raise TypeError(f"Expected a bool, list[bool] or None for `nulls_first` but got {type(nulls_first)}") + return Series._from_pyseries(self._table.argsort(pyexprs, descending, nulls_first)) def __reduce__(self) -> tuple: names = self.column_names() diff --git a/src/daft-core/src/array/ops/list.rs b/src/daft-core/src/array/ops/list.rs index 080fed0ad0..2e60efa550 100644 --- a/src/daft-core/src/array/ops/list.rs +++ b/src/daft-core/src/array/ops/list.rs @@ -205,16 +205,18 @@ fn list_sort_helper( flat_child: &Series, offsets: &OffsetsBuffer, desc_iter: impl Iterator, + nulls_first_iter: impl Iterator, validity: impl Iterator, ) -> DaftResult> { desc_iter + .zip(nulls_first_iter) .zip(validity) .enumerate() - .map(|(i, (desc, valid))| { + .map(|(i, ((desc, nulls_first), valid))| { let start = *offsets.get(i).unwrap() as usize; let end = *offsets.get(i + 1).unwrap() as usize; if valid { - flat_child.slice(start, end)?.sort(desc) + flat_child.slice(start, end)?.sort(desc, nulls_first) } else { Ok(Series::full_null( flat_child.name(), @@ -230,16 +232,18 @@ fn list_sort_helper_fixed_size( flat_child: &Series, fixed_size: usize, desc_iter: impl Iterator, + nulls_first_iter: impl Iterator, validity: impl Iterator, ) -> DaftResult> { desc_iter + .zip(nulls_first_iter) .zip(validity) .enumerate() - .map(|(i, (desc, valid))| { + .map(|(i, ((desc, nulls_first), valid))| { let start = i * fixed_size; let end = (i + 1) * fixed_size; if valid { - flat_child.slice(start, end)?.sort(desc) + flat_child.slice(start, end)?.sort(desc, nulls_first) } else { Ok(Series::full_null( flat_child.name(), @@ -568,29 +572,45 @@ impl ListArray { } // Sorts the lists within a list column - pub fn list_sort(&self, desc: &BooleanArray) -> DaftResult { + pub fn list_sort(&self, desc: &BooleanArray, nulls_first: &BooleanArray) -> DaftResult { let offsets = self.offsets(); let child_series = if desc.len() == 1 { let desc_iter = repeat(desc.get(0).unwrap()).take(self.len()); + let nulls_first_iter = repeat(nulls_first.get(0).unwrap()).take(self.len()); if let Some(validity) = self.validity() { - list_sort_helper(&self.flat_child, offsets, desc_iter, validity.iter())? + list_sort_helper( + &self.flat_child, + offsets, + desc_iter, + nulls_first_iter, + validity.iter(), + )? } else { list_sort_helper( &self.flat_child, offsets, desc_iter, + nulls_first_iter, repeat(true).take(self.len()), )? } } else { let desc_iter = desc.as_arrow().values_iter(); + let nulls_first_iter = nulls_first.as_arrow().values_iter(); if let Some(validity) = self.validity() { - list_sort_helper(&self.flat_child, offsets, desc_iter, validity.iter())? + list_sort_helper( + &self.flat_child, + offsets, + desc_iter, + nulls_first_iter, + validity.iter(), + )? } else { list_sort_helper( &self.flat_child, offsets, desc_iter, + nulls_first_iter, repeat(true).take(self.len()), )? } @@ -789,16 +809,18 @@ impl FixedSizeListArray { } // Sorts the lists within a list column - pub fn list_sort(&self, desc: &BooleanArray) -> DaftResult { + pub fn list_sort(&self, desc: &BooleanArray, nulls_first: &BooleanArray) -> DaftResult { let fixed_size = self.fixed_element_len(); let child_series = if desc.len() == 1 { let desc_iter = repeat(desc.get(0).unwrap()).take(self.len()); + let nulls_first_iter = repeat(nulls_first.get(0).unwrap()).take(self.len()); if let Some(validity) = self.validity() { list_sort_helper_fixed_size( &self.flat_child, fixed_size, desc_iter, + nulls_first_iter, validity.iter(), )? } else { @@ -806,16 +828,19 @@ impl FixedSizeListArray { &self.flat_child, fixed_size, desc_iter, + nulls_first_iter, repeat(true).take(self.len()), )? } } else { let desc_iter = desc.as_arrow().values_iter(); + let nulls_first_iter = nulls_first.as_arrow().values_iter(); if let Some(validity) = self.validity() { list_sort_helper_fixed_size( &self.flat_child, fixed_size, desc_iter, + nulls_first_iter, validity.iter(), )? } else { @@ -823,6 +848,7 @@ impl FixedSizeListArray { &self.flat_child, fixed_size, desc_iter, + nulls_first_iter, repeat(true).take(self.len()), )? } diff --git a/src/daft-core/src/array/ops/sort.rs b/src/daft-core/src/array/ops/sort.rs index fa3ccad594..b63b4c23f1 100644 --- a/src/daft-core/src/array/ops/sort.rs +++ b/src/daft-core/src/array/ops/sort.rs @@ -21,6 +21,7 @@ use crate::{ }, kernels::search_sorted::{build_compare_with_nulls, cmp_float}, series::Series, + utils::{ensure_nulls_first, ensure_nulls_first_arr}, }; pub fn build_multi_array_compare( @@ -62,11 +63,13 @@ where T: DaftIntegerType, ::Native: Ord, { - pub fn argsort(&self, descending: bool) -> DaftResult> + pub fn argsort(&self, descending: bool, nulls_first: bool) -> DaftResult> where I: DaftIntegerType, ::Native: arrow2::types::Index, { + ensure_nulls_first(descending, nulls_first)?; + let arrow_array = self.as_arrow(); let result = @@ -83,11 +86,13 @@ where &self, others: &[Series], descending: &[bool], + nulls_first: &[bool], ) -> DaftResult> where I: DaftIntegerType, ::Native: arrow2::types::Index, { + ensure_nulls_first_arr(descending, nulls_first)?; let arrow_array = self.as_arrow(); let first_desc = *descending.first().unwrap(); @@ -134,10 +139,10 @@ where Ok(DataArray::::from((self.name(), Box::new(result)))) } - pub fn sort(&self, descending: bool) -> DaftResult { + pub fn sort(&self, descending: bool, nulls_first: bool) -> DaftResult { let options = arrow2::compute::sort::SortOptions { descending, - nulls_first: descending, + nulls_first, }; let arrow_array = self.as_arrow(); @@ -154,11 +159,12 @@ where } impl Float32Array { - pub fn argsort(&self, descending: bool) -> DaftResult> + pub fn argsort(&self, descending: bool, nulls_first: bool) -> DaftResult> where I: DaftIntegerType, ::Native: arrow2::types::Index, { + ensure_nulls_first(descending, nulls_first)?; let arrow_array = self.as_arrow(); let result = @@ -175,11 +181,14 @@ impl Float32Array { &self, others: &[Series], descending: &[bool], + nulls_first: &[bool], ) -> DaftResult> where I: DaftIntegerType, ::Native: arrow2::types::Index, { + ensure_nulls_first_arr(descending, nulls_first)?; + let arrow_array = self.as_arrow(); let first_desc = *descending.first().unwrap(); @@ -226,10 +235,10 @@ impl Float32Array { Ok(DataArray::::from((self.name(), Box::new(result)))) } - pub fn sort(&self, descending: bool) -> DaftResult { + pub fn sort(&self, descending: bool, nulls_first: bool) -> DaftResult { let options = arrow2::compute::sort::SortOptions { descending, - nulls_first: descending, + nulls_first, }; let arrow_array = self.as_arrow(); @@ -246,11 +255,12 @@ impl Float32Array { } impl Float64Array { - pub fn argsort(&self, descending: bool) -> DaftResult> + pub fn argsort(&self, descending: bool, nulls_first: bool) -> DaftResult> where I: DaftIntegerType, ::Native: arrow2::types::Index, { + ensure_nulls_first(descending, nulls_first)?; let arrow_array = self.as_arrow(); let result = @@ -267,11 +277,14 @@ impl Float64Array { &self, others: &[Series], descending: &[bool], + nulls_first: &[bool], ) -> DaftResult> where I: DaftIntegerType, ::Native: arrow2::types::Index, { + ensure_nulls_first_arr(descending, nulls_first)?; + let arrow_array = self.as_arrow(); let first_desc = *descending.first().unwrap(); @@ -318,10 +331,10 @@ impl Float64Array { Ok(DataArray::::from((self.name(), Box::new(result)))) } - pub fn sort(&self, descending: bool) -> DaftResult { + pub fn sort(&self, descending: bool, nulls_first: bool) -> DaftResult { let options = arrow2::compute::sort::SortOptions { descending, - nulls_first: descending, + nulls_first, }; let arrow_array = self.as_arrow(); @@ -338,11 +351,12 @@ impl Float64Array { } impl Decimal128Array { - pub fn argsort(&self, descending: bool) -> DaftResult> + pub fn argsort(&self, descending: bool, nulls_first: bool) -> DaftResult> where I: DaftIntegerType, ::Native: arrow2::types::Index, { + ensure_nulls_first(descending, nulls_first)?; let arrow_array = self.as_arrow(); let result = @@ -359,11 +373,13 @@ impl Decimal128Array { &self, others: &[Series], descending: &[bool], + nulls_first: &[bool], ) -> DaftResult> where I: DaftIntegerType, ::Native: arrow2::types::Index, { + ensure_nulls_first_arr(descending, nulls_first)?; let arrow_array = self.as_arrow(); let first_desc = *descending.first().unwrap(); @@ -410,10 +426,10 @@ impl Decimal128Array { Ok(DataArray::::from((self.name(), Box::new(result)))) } - pub fn sort(&self, descending: bool) -> DaftResult { + pub fn sort(&self, descending: bool, nulls_first: bool) -> DaftResult { let options = arrow2::compute::sort::SortOptions { descending, - nulls_first: descending, + nulls_first, }; let arrow_array = self.as_arrow(); @@ -430,7 +446,7 @@ impl Decimal128Array { } impl NullArray { - pub fn argsort(&self, _descending: bool) -> DaftResult> + pub fn argsort(&self, _descending: bool, _nulls_first: bool) -> DaftResult> where I: DaftIntegerType, ::Native: arrow2::types::Index, @@ -442,11 +458,13 @@ impl NullArray { &self, others: &[Series], descending: &[bool], + nulls_first: &[bool], ) -> DaftResult> where I: DaftIntegerType, ::Native: arrow2::types::Index, { + ensure_nulls_first_arr(descending, nulls_first)?; let first_desc = *descending.first().unwrap(); let others_cmp = build_multi_array_compare(others, &descending[1..])?; @@ -466,20 +484,20 @@ impl NullArray { Ok(DataArray::::from((self.name(), Box::new(result)))) } - pub fn sort(&self, _descending: bool) -> DaftResult { + pub fn sort(&self, _descending: bool, _nulls_first: bool) -> DaftResult { Ok(self.clone()) } } impl BooleanArray { - pub fn argsort(&self, descending: bool) -> DaftResult> + pub fn argsort(&self, descending: bool, nulls_first: bool) -> DaftResult> where I: DaftIntegerType, ::Native: arrow2::types::Index, { let options = arrow2::compute::sort::SortOptions { descending, - nulls_first: descending, + nulls_first, }; let result = @@ -492,11 +510,13 @@ impl BooleanArray { &self, others: &[Series], descending: &[bool], + nulls_first: &[bool], ) -> DaftResult> where I: DaftIntegerType, ::Native: arrow2::types::Index, { + ensure_nulls_first_arr(descending, nulls_first)?; let first_desc = *descending.first().unwrap(); let others_cmp = build_multi_array_compare(others, &descending[1..])?; @@ -547,10 +567,10 @@ impl BooleanArray { Ok(DataArray::::from((self.name(), Box::new(result)))) } - pub fn sort(&self, descending: bool) -> DaftResult { + pub fn sort(&self, descending: bool, nulls_first: bool) -> DaftResult { let options = arrow2::compute::sort::SortOptions { descending, - nulls_first: descending, + nulls_first, }; let result = arrow2::compute::sort::sort(self.data(), &options, None)?; @@ -562,14 +582,18 @@ impl BooleanArray { macro_rules! impl_binary_like_sort { ($da:ident) => { impl $da { - pub fn argsort(&self, descending: bool) -> DaftResult> + pub fn argsort( + &self, + descending: bool, + nulls_first: bool, + ) -> DaftResult> where I: DaftIntegerType, ::Native: arrow2::types::Index, { let options = arrow2::compute::sort::SortOptions { descending, - nulls_first: descending, + nulls_first, }; let result = arrow2::compute::sort::sort_to_indices::( @@ -585,11 +609,13 @@ macro_rules! impl_binary_like_sort { &self, others: &[Series], descending: &[bool], + nulls_first: &[bool], ) -> DaftResult> where I: DaftIntegerType, ::Native: arrow2::types::Index, { + ensure_nulls_first_arr(descending, nulls_first)?; let first_desc = *descending.first().unwrap(); let others_cmp = build_multi_array_compare(others, &descending[1..])?; @@ -635,10 +661,10 @@ macro_rules! impl_binary_like_sort { Ok(DataArray::::from((self.name(), Box::new(result)))) } - pub fn sort(&self, descending: bool) -> DaftResult { + pub fn sort(&self, descending: bool, nulls_first: bool) -> DaftResult { let options = arrow2::compute::sort::SortOptions { descending, - nulls_first: descending, + nulls_first, }; let result = arrow2::compute::sort::sort(self.data(), &options, None)?; @@ -653,7 +679,7 @@ impl_binary_like_sort!(BinaryArray); impl_binary_like_sort!(Utf8Array); impl FixedSizeBinaryArray { - pub fn argsort(&self, _descending: bool) -> DaftResult> + pub fn argsort(&self, _descending: bool, _nulls_first: bool) -> DaftResult> where I: DaftIntegerType, ::Native: arrow2::types::Index, @@ -664,6 +690,7 @@ impl FixedSizeBinaryArray { &self, _others: &[Series], _descending: &[bool], + _nulls_first: &[bool], ) -> DaftResult> where I: DaftIntegerType, @@ -671,120 +698,120 @@ impl FixedSizeBinaryArray { { todo!("impl argsort_multikey for FixedSizeBinaryArray") } - pub fn sort(&self, _descending: bool) -> DaftResult { + pub fn sort(&self, _descending: bool, _nulls_first: bool) -> DaftResult { todo!("impl sort for FixedSizeBinaryArray") } } impl FixedSizeListArray { - pub fn sort(&self, _descending: bool) -> DaftResult { + pub fn sort(&self, _descending: bool, _nulls_first: bool) -> DaftResult { todo!("impl sort for FixedSizeListArray") } } impl ListArray { - pub fn sort(&self, _descending: bool) -> DaftResult { + pub fn sort(&self, _descending: bool, _nulls_first: bool) -> DaftResult { todo!("impl sort for ListArray") } } impl MapArray { - pub fn sort(&self, _descending: bool) -> DaftResult { + pub fn sort(&self, _descending: bool, _nulls_first: bool) -> DaftResult { todo!("impl sort for MapArray") } } impl StructArray { - pub fn sort(&self, _descending: bool) -> DaftResult { + pub fn sort(&self, _descending: bool, _nulls_first: bool) -> DaftResult { todo!("impl sort for StructArray") } } impl ExtensionArray { - pub fn sort(&self, _descending: bool) -> DaftResult { + pub fn sort(&self, _descending: bool, _nulls_first: bool) -> DaftResult { todo!("impl sort for ExtensionArray") } } impl IntervalArray { - pub fn sort(&self, _descending: bool) -> DaftResult { + pub fn sort(&self, _descending: bool, _nulls_first: bool) -> DaftResult { todo!("impl sort for IntervalArray") } } #[cfg(feature = "python")] impl PythonArray { - pub fn sort(&self, _descending: bool) -> DaftResult { + pub fn sort(&self, _descending: bool, _nulls_first: bool) -> DaftResult { todo!("impl sort for python array") } } impl DateArray { - pub fn sort(&self, descending: bool) -> DaftResult { - let new_array = self.physical.sort(descending)?; + pub fn sort(&self, descending: bool, nulls_first: bool) -> DaftResult { + let new_array = self.physical.sort(descending, nulls_first)?; Ok(Self::new(self.field.clone(), new_array)) } } impl TimeArray { - pub fn sort(&self, descending: bool) -> DaftResult { - let new_array = self.physical.sort(descending)?; + pub fn sort(&self, descending: bool, nulls_first: bool) -> DaftResult { + let new_array = self.physical.sort(descending, nulls_first)?; Ok(Self::new(self.field.clone(), new_array)) } } impl DurationArray { - pub fn sort(&self, descending: bool) -> DaftResult { - let new_array = self.physical.sort(descending)?; + pub fn sort(&self, descending: bool, nulls_first: bool) -> DaftResult { + let new_array = self.physical.sort(descending, nulls_first)?; Ok(Self::new(self.field.clone(), new_array)) } } impl TimestampArray { - pub fn sort(&self, descending: bool) -> DaftResult { - let new_array = self.physical.sort(descending)?; + pub fn sort(&self, descending: bool, nulls_first: bool) -> DaftResult { + let new_array = self.physical.sort(descending, nulls_first)?; Ok(Self::new(self.field.clone(), new_array)) } } impl EmbeddingArray { - pub fn sort(&self, _descending: bool) -> DaftResult { + pub fn sort(&self, _descending: bool, _nulls_first: bool) -> DaftResult { todo!("impl sort for EmbeddingArray") } } impl ImageArray { - pub fn sort(&self, _descending: bool) -> DaftResult { + pub fn sort(&self, _descending: bool, _nulls_first: bool) -> DaftResult { todo!("impl sort for ImageArray") } } impl FixedShapeImageArray { - pub fn sort(&self, _descending: bool) -> DaftResult { + pub fn sort(&self, _descending: bool, _nulls_first: bool) -> DaftResult { todo!("impl sort for FixedShapeImageArray") } } impl TensorArray { - pub fn sort(&self, _descending: bool) -> DaftResult { + pub fn sort(&self, _descending: bool, _nulls_first: bool) -> DaftResult { todo!("impl sort for TensorArray") } } impl SparseTensorArray { - pub fn sort(&self, _descending: bool) -> DaftResult { + pub fn sort(&self, _descending: bool, _nulls_first: bool) -> DaftResult { todo!("impl sort for SparseTensorArray") } } impl FixedShapeSparseTensorArray { - pub fn sort(&self, _descending: bool) -> DaftResult { + pub fn sort(&self, _descending: bool, _nulls_first: bool) -> DaftResult { todo!("impl sort for FixedShapeSparseTensorArray") } } impl FixedShapeTensorArray { - pub fn sort(&self, _descending: bool) -> DaftResult { + pub fn sort(&self, _descending: bool, _nulls_first: bool) -> DaftResult { todo!("impl sort for FixedShapeTensorArray") } } diff --git a/src/daft-core/src/python/series.rs b/src/daft-core/src/python/series.rs index 28bfcede0e..f8f1b3002e 100644 --- a/src/daft-core/src/python/series.rs +++ b/src/daft-core/src/python/series.rs @@ -299,12 +299,12 @@ impl PySeries { Ok(self.series.filter(mask.series.downcast()?)?.into()) } - pub fn sort(&self, descending: bool) -> PyResult { - Ok(self.series.sort(descending)?.into()) + pub fn sort(&self, descending: bool, nulls_first: bool) -> PyResult { + Ok(self.series.sort(descending, nulls_first)?.into()) } - pub fn argsort(&self, descending: bool) -> PyResult { - Ok(self.series.argsort(descending)?.into()) + pub fn argsort(&self, descending: bool, nulls_first: bool) -> PyResult { + Ok(self.series.argsort(descending, nulls_first)?.into()) } pub fn hash(&self, seed: Option) -> PyResult { @@ -696,8 +696,11 @@ impl PySeries { Ok(self.series.list_slice(&start.series, &end.series)?.into()) } - pub fn list_sort(&self, desc: &Self) -> PyResult { - Ok(self.series.list_sort(&desc.series)?.into()) + pub fn list_sort(&self, desc: &Self, nulls_first: &Self) -> PyResult { + Ok(self + .series + .list_sort(&desc.series, &nulls_first.series)? + .into()) } pub fn map_get(&self, key: &Self) -> PyResult { diff --git a/src/daft-core/src/series/array_impl/data_array.rs b/src/daft-core/src/series/array_impl/data_array.rs index 1efa60f3a7..a00a51d219 100644 --- a/src/daft-core/src/series/array_impl/data_array.rs +++ b/src/daft-core/src/series/array_impl/data_array.rs @@ -118,8 +118,8 @@ macro_rules! impl_series_like_for_data_array { fn slice(&self, start: usize, end: usize) -> DaftResult { Ok(self.0.slice(start, end)?.into_series()) } - fn sort(&self, descending: bool) -> DaftResult { - Ok(self.0.sort(descending)?.into_series()) + fn sort(&self, descending: bool, nulls_first: bool) -> DaftResult { + Ok(self.0.sort(descending, nulls_first)?.into_series()) } fn str_value(&self, idx: usize) -> DaftResult { self.0.str_value(idx) diff --git a/src/daft-core/src/series/array_impl/logical_array.rs b/src/daft-core/src/series/array_impl/logical_array.rs index 85316f0ec4..07df8923f2 100644 --- a/src/daft-core/src/series/array_impl/logical_array.rs +++ b/src/daft-core/src/series/array_impl/logical_array.rs @@ -116,8 +116,8 @@ macro_rules! impl_series_like_for_logical_array { Ok($da::new(self.0.field.clone(), new_array).into_series()) } - fn sort(&self, descending: bool) -> DaftResult { - Ok(self.0.sort(descending)?.into_series()) + fn sort(&self, descending: bool, nulls_first: bool) -> DaftResult { + Ok(self.0.sort(descending, nulls_first)?.into_series()) } fn str_value(&self, idx: usize) -> DaftResult { diff --git a/src/daft-core/src/series/array_impl/nested_array.rs b/src/daft-core/src/series/array_impl/nested_array.rs index 1bd618e616..be60ee2f26 100644 --- a/src/daft-core/src/series/array_impl/nested_array.rs +++ b/src/daft-core/src/series/array_impl/nested_array.rs @@ -121,7 +121,7 @@ macro_rules! impl_series_like_for_nested_arrays { Ok(self.0.not_null()?.into_series()) } - fn sort(&self, _descending: bool) -> DaftResult { + fn sort(&self, _descending: bool, _nulls_first: bool) -> DaftResult { Err(DaftError::ValueError(format!( "Cannot sort a {}", stringify!($da) diff --git a/src/daft-core/src/series/ops/list.rs b/src/daft-core/src/series/ops/list.rs index 81a4788067..7ed940ca3d 100644 --- a/src/daft-core/src/series/ops/list.rs +++ b/src/daft-core/src/series/ops/list.rs @@ -159,14 +159,16 @@ impl Series { } } - pub fn list_sort(&self, desc: &Self) -> DaftResult { + pub fn list_sort(&self, desc: &Self, nulls_first: &Self) -> DaftResult { let desc_arr = desc.bool()?; + let nulls_first = nulls_first.bool()?; match self.data_type() { - DataType::List(_) => Ok(self.list()?.list_sort(desc_arr)?.into_series()), - DataType::FixedSizeList(..) => { - Ok(self.fixed_size_list()?.list_sort(desc_arr)?.into_series()) - } + DataType::List(_) => Ok(self.list()?.list_sort(desc_arr, nulls_first)?.into_series()), + DataType::FixedSizeList(..) => Ok(self + .fixed_size_list()? + .list_sort(desc_arr, nulls_first)? + .into_series()), dt => Err(DaftError::TypeError(format!( "List sort not implemented for {}", dt diff --git a/src/daft-core/src/series/ops/sort.rs b/src/daft-core/src/series/ops/sort.rs index 48ad1288ba..b1f627c389 100644 --- a/src/daft-core/src/series/ops/sort.rs +++ b/src/daft-core/src/series/ops/sort.rs @@ -2,19 +2,26 @@ use common_error::{DaftError, DaftResult}; use crate::{ series::{array_impl::IntoSeries, Series}, + utils::{ensure_nulls_first, ensure_nulls_first_arr}, with_match_comparable_daft_types, }; impl Series { - pub fn argsort(&self, descending: bool) -> DaftResult { + pub fn argsort(&self, descending: bool, nulls_first: bool) -> DaftResult { + ensure_nulls_first(descending, nulls_first)?; let series = self.as_physical()?; with_match_comparable_daft_types!(series.data_type(), |$T| { let downcasted = series.downcast::<<$T as DaftDataType>::ArrayType>()?; - Ok(downcasted.argsort::(descending)?.into_series()) + Ok(downcasted.argsort::(descending, nulls_first)?.into_series()) }) } - pub fn argsort_multikey(sort_keys: &[Self], descending: &[bool]) -> DaftResult { + pub fn argsort_multikey( + sort_keys: &[Self], + descending: &[bool], + nulls_first: &[bool], + ) -> DaftResult { + ensure_nulls_first_arr(descending, nulls_first)?; if sort_keys.len() != descending.len() { return Err(DaftError::ValueError(format!( "sort_keys and descending length must match, got {} vs {}", @@ -27,18 +34,18 @@ impl Series { return sort_keys .first() .unwrap() - .argsort(*descending.first().unwrap()); + .argsort(*descending.first().unwrap(), *nulls_first.first().unwrap()); } let first = sort_keys.first().unwrap().as_physical()?; with_match_comparable_daft_types!(first.data_type(), |$T| { let downcasted = first.downcast::<<$T as DaftDataType>::ArrayType>()?; - let result = downcasted.argsort_multikey::(&sort_keys[1..], descending)?; + let result = downcasted.argsort_multikey::(&sort_keys[1..], descending, nulls_first)?; Ok(result.into_series()) }) } - pub fn sort(&self, descending: bool) -> DaftResult { - self.inner.sort(descending) + pub fn sort(&self, descending: bool, nulls_first: bool) -> DaftResult { + self.inner.sort(descending, nulls_first) } } diff --git a/src/daft-core/src/series/series_like.rs b/src/daft-core/src/series/series_like.rs index 9d152693d9..cc7dc6cd63 100644 --- a/src/daft-core/src/series/series_like.rs +++ b/src/daft-core/src/series/series_like.rs @@ -29,7 +29,7 @@ pub trait SeriesLike: Send + Sync + Any + std::fmt::Debug { fn size_bytes(&self) -> DaftResult; fn is_null(&self) -> DaftResult; fn not_null(&self) -> DaftResult; - fn sort(&self, descending: bool) -> DaftResult; + fn sort(&self, descending: bool, nulls_first: bool) -> DaftResult; fn head(&self, num: usize) -> DaftResult; fn slice(&self, start: usize, end: usize) -> DaftResult; fn take(&self, idx: &Series) -> DaftResult; diff --git a/src/daft-core/src/utils/mod.rs b/src/daft-core/src/utils/mod.rs index baf1dc66fd..031220c65f 100644 --- a/src/daft-core/src/utils/mod.rs +++ b/src/daft-core/src/utils/mod.rs @@ -1,6 +1,31 @@ +use common_error::{DaftError, DaftResult}; + pub mod arrow; pub mod display; pub mod dyn_compare; pub mod identity_hash_set; pub mod stats; pub mod supertype; + +/// Ensure that the nulls_first parameter is compatible with the descending parameter. +/// TODO: remove this function once nulls_first is implemented. +pub(crate) fn ensure_nulls_first(descending: bool, nulls_first: bool) -> DaftResult<()> { + if nulls_first != descending { + return Err(DaftError::NotImplemented( + "nulls_first is not implemented".to_string(), + )); + } + Ok(()) +} + +/// Ensure that the nulls_first parameter is compatible with the descending parameter. +/// TODO: remove this function once nulls_first is implemented. +pub(crate) fn ensure_nulls_first_arr(descending: &[bool], nulls_first: &[bool]) -> DaftResult<()> { + if nulls_first.iter().zip(descending).any(|(a, b)| a != b) { + return Err(DaftError::NotImplemented( + "nulls_first is not implemented".to_string(), + )); + } + + Ok(()) +} diff --git a/src/daft-functions/src/list/sort.rs b/src/daft-functions/src/list/sort.rs index 2d1ef45afb..b4b5a77814 100644 --- a/src/daft-functions/src/list/sort.rs +++ b/src/daft-functions/src/list/sort.rs @@ -21,7 +21,7 @@ impl ScalarUDF for ListSort { fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { match inputs { - [data, desc] => match (data.to_field(schema), desc.to_field(schema)) { + [data, desc, _nulls_first] => match (data.to_field(schema), desc.to_field(schema)) { (Ok(field), Ok(desc_field)) => match (&field.dtype, &desc_field.dtype) { ( l @ (DataType::List(_) | DataType::FixedSizeList(_, _)), @@ -34,7 +34,7 @@ impl ScalarUDF for ListSort { (Err(e), _) | (_, Err(e)) => Err(e), }, _ => Err(DaftError::SchemaMismatch(format!( - "Expected 2 input args, got {}", + "Expected 3 input args, got {}", inputs.len() ))), } @@ -42,9 +42,9 @@ impl ScalarUDF for ListSort { fn evaluate(&self, inputs: &[Series]) -> DaftResult { match inputs { - [data, desc] => data.list_sort(desc), + [data, desc, nulls_first] => data.list_sort(desc, nulls_first), _ => Err(DaftError::ValueError(format!( - "Expected 2 input args, got {}", + "Expected 3 input args, got {}", inputs.len() ))), } @@ -52,9 +52,10 @@ impl ScalarUDF for ListSort { } #[must_use] -pub fn list_sort(input: ExprRef, desc: Option) -> ExprRef { +pub fn list_sort(input: ExprRef, desc: Option, nulls_first: Option) -> ExprRef { let desc = desc.unwrap_or_else(|| lit(false)); - ScalarFunction::new(ListSort {}, vec![input, desc]).into() + let nulls_first = nulls_first.unwrap_or_else(|| desc.clone()); + ScalarFunction::new(ListSort {}, vec![input, desc, nulls_first]).into() } #[cfg(feature = "python")] @@ -66,6 +67,6 @@ use { #[cfg(feature = "python")] #[pyfunction] #[pyo3(name = "list_sort")] -pub fn py_list_sort(expr: PyExpr, desc: PyExpr) -> PyResult { - Ok(list_sort(expr.into(), Some(desc.into())).into()) +pub fn py_list_sort(expr: PyExpr, desc: PyExpr, nulls_first: PyExpr) -> PyResult { + Ok(list_sort(expr.into(), Some(desc.into()), Some(nulls_first.into())).into()) } diff --git a/src/daft-local-execution/src/pipeline.rs b/src/daft-local-execution/src/pipeline.rs index a54e6d8b85..bb98c50e99 100644 --- a/src/daft-local-execution/src/pipeline.rs +++ b/src/daft-local-execution/src/pipeline.rs @@ -291,9 +291,10 @@ pub fn physical_plan_to_pipeline( input, sort_by, descending, + nulls_first, .. }) => { - let sort_sink = SortSink::new(sort_by.clone(), descending.clone()); + let sort_sink = SortSink::new(sort_by.clone(), descending.clone(), nulls_first.clone()); let child_node = physical_plan_to_pipeline(input, psets, cfg)?; BlockingSinkNode::new(Arc::new(sort_sink), child_node).boxed() } diff --git a/src/daft-local-execution/src/sinks/sort.rs b/src/daft-local-execution/src/sinks/sort.rs index 9c2b3f3944..9aa97a2ef8 100644 --- a/src/daft-local-execution/src/sinks/sort.rs +++ b/src/daft-local-execution/src/sinks/sort.rs @@ -46,17 +46,19 @@ impl BlockingSinkState for SortState { struct SortParams { sort_by: Vec, descending: Vec, + nulls_first: Vec, } pub struct SortSink { params: Arc, } impl SortSink { - pub fn new(sort_by: Vec, descending: Vec) -> Self { + pub fn new(sort_by: Vec, descending: Vec, nulls_first: Vec) -> Self { Self { params: Arc::new(SortParams { sort_by, descending, + nulls_first, }), } } @@ -95,7 +97,11 @@ impl BlockingSink for SortSink { state.finalize() }); let concated = MicroPartition::concat(parts)?; - let sorted = Arc::new(concated.sort(¶ms.sort_by, ¶ms.descending)?); + let sorted = Arc::new(concated.sort( + ¶ms.sort_by, + ¶ms.descending, + ¶ms.nulls_first, + )?); Ok(Some(sorted)) }) .into() diff --git a/src/daft-local-plan/src/plan.rs b/src/daft-local-plan/src/plan.rs index cf1cd91a97..cefdace1e3 100644 --- a/src/daft-local-plan/src/plan.rs +++ b/src/daft-local-plan/src/plan.rs @@ -226,11 +226,13 @@ impl LocalPhysicalPlan { input: LocalPhysicalPlanRef, sort_by: Vec, descending: Vec, + nulls_first: Vec, ) -> LocalPhysicalPlanRef { let schema = input.schema().clone(); Self::Sort(Sort { input, sort_by, + nulls_first, descending, schema, plan_stats: PlanStats {}, @@ -429,6 +431,7 @@ pub struct Sort { pub input: LocalPhysicalPlanRef, pub sort_by: Vec, pub descending: Vec, + pub nulls_first: Vec, pub schema: SchemaRef, pub plan_stats: PlanStats, } diff --git a/src/daft-local-plan/src/translate.rs b/src/daft-local-plan/src/translate.rs index de77ce5d70..ffe03edd20 100644 --- a/src/daft-local-plan/src/translate.rs +++ b/src/daft-local-plan/src/translate.rs @@ -107,6 +107,7 @@ pub fn translate(plan: &LogicalPlanRef) -> DaftResult { input, sort.sort_by.clone(), sort.descending.clone(), + sort.nulls_first.clone(), )) } LogicalPlan::Join(join) => { diff --git a/src/daft-logical-plan/src/builder.rs b/src/daft-logical-plan/src/builder.rs index d82186c60e..c251537ff2 100644 --- a/src/daft-logical-plan/src/builder.rs +++ b/src/daft-logical-plan/src/builder.rs @@ -295,9 +295,14 @@ impl LogicalPlanBuilder { Ok(self.with_new_plan(logical_plan)) } - pub fn sort(&self, sort_by: Vec, descending: Vec) -> DaftResult { + pub fn sort( + &self, + sort_by: Vec, + descending: Vec, + nulls_first: Vec, + ) -> DaftResult { let logical_plan: LogicalPlan = - ops::Sort::try_new(self.plan.clone(), sort_by, descending)?.into(); + ops::Sort::try_new(self.plan.clone(), sort_by, descending, nulls_first)?.into(); Ok(self.with_new_plan(logical_plan)) } @@ -700,10 +705,15 @@ impl PyLogicalPlanBuilder { .into()) } - pub fn sort(&self, sort_by: Vec, descending: Vec) -> PyResult { + pub fn sort( + &self, + sort_by: Vec, + descending: Vec, + nulls_first: Vec, + ) -> PyResult { Ok(self .builder - .sort(pyexprs_to_exprs(sort_by), descending)? + .sort(pyexprs_to_exprs(sort_by), descending, nulls_first)? .into()) } diff --git a/src/daft-logical-plan/src/display.rs b/src/daft-logical-plan/src/display.rs index 6bcdba360b..be83f5237b 100644 --- a/src/daft-logical-plan/src/display.rs +++ b/src/daft-logical-plan/src/display.rs @@ -96,7 +96,7 @@ mod test { .limit(1000, false)? .add_monotonically_increasing_id(None)? .distinct()? - .sort(vec![col("last_name")], vec![false])? + .sort(vec![col("last_name")], vec![false], vec![false])? .build(); let plan = LogicalPlanBuilder::new(subplan, None) @@ -132,7 +132,7 @@ Num partitions = 0 Output schema = text#Utf8, id#Int32"] Source5 --> Filter4 Filter4 --> Join3 -Sort6["Sort: Sort by = (col(last_name), ascending)"] +Sort6["Sort: Sort by = (col(last_name), ascending, nulls last)"] Distinct7["Distinct"] MonotonicallyIncreasingId8["MonotonicallyIncreasingId"] Limit9["Limit: 1000"] @@ -170,7 +170,7 @@ Project1 --> Limit0 .limit(1000, false)? .add_monotonically_increasing_id(None)? .distinct()? - .sort(vec![col("last_name")], vec![false])? + .sort(vec![col("last_name")], vec![false], vec![false])? .build(); let plan = LogicalPlanBuilder::new(subplan, None) diff --git a/src/daft-logical-plan/src/logical_plan.rs b/src/daft-logical-plan/src/logical_plan.rs index b9ca2e449c..01c6b510c8 100644 --- a/src/daft-logical-plan/src/logical_plan.rs +++ b/src/daft-logical-plan/src/logical_plan.rs @@ -260,7 +260,7 @@ impl LogicalPlan { Self::Filter(Filter { predicate, .. }) => Self::Filter(Filter::try_new(input.clone(), predicate.clone()).unwrap()), Self::Limit(Limit { limit, eager, .. }) => Self::Limit(Limit::new(input.clone(), *limit, *eager)), Self::Explode(Explode { to_explode, .. }) => Self::Explode(Explode::try_new(input.clone(), to_explode.clone()).unwrap()), - Self::Sort(Sort { sort_by, descending, .. }) => Self::Sort(Sort::try_new(input.clone(), sort_by.clone(), descending.clone()).unwrap()), + Self::Sort(Sort { sort_by, descending, nulls_first, .. }) => Self::Sort(Sort::try_new(input.clone(), sort_by.clone(), descending.clone(), nulls_first.clone()).unwrap()), Self::Repartition(Repartition { repartition_spec: scheme_config, .. }) => Self::Repartition(Repartition::try_new(input.clone(), scheme_config.clone()).unwrap()), Self::Distinct(_) => Self::Distinct(Distinct::new(input.clone())), Self::Aggregate(Aggregate { aggregations, groupby, ..}) => Self::Aggregate(Aggregate::try_new(input.clone(), aggregations.clone(), groupby.clone()).unwrap()), diff --git a/src/daft-logical-plan/src/ops/sort.rs b/src/daft-logical-plan/src/ops/sort.rs index 1c722a85f7..85cd8c2a64 100644 --- a/src/daft-logical-plan/src/ops/sort.rs +++ b/src/daft-logical-plan/src/ops/sort.rs @@ -14,6 +14,7 @@ pub struct Sort { pub input: Arc, pub sort_by: Vec, pub descending: Vec, + pub nulls_first: Vec, } impl Sort { @@ -21,6 +22,7 @@ impl Sort { input: Arc, sort_by: Vec, descending: Vec, + nulls_first: Vec, ) -> logical_plan::Result { if sort_by.is_empty() { return Err(DaftError::ValueError( @@ -51,6 +53,7 @@ impl Sort { input, sort_by, descending, + nulls_first, }) } @@ -62,7 +65,15 @@ impl Sort { .sort_by .iter() .zip(self.descending.iter()) - .map(|(sb, d)| format!("({}, {})", sb, if *d { "descending" } else { "ascending" },)) + .zip(self.nulls_first.iter()) + .map(|((sb, d), nf)| { + format!( + "({}, {}, {})", + sb, + if *d { "descending" } else { "ascending" }, + if *nf { "nulls first" } else { "nulls last" } + ) + }) .join(", "); res.push(format!("Sort: Sort by = {}", pairs)); res diff --git a/src/daft-logical-plan/src/optimization/rules/push_down_filter.rs b/src/daft-logical-plan/src/optimization/rules/push_down_filter.rs index 7af3ffdec7..2cf2ea14ad 100644 --- a/src/daft-logical-plan/src/optimization/rules/push_down_filter.rs +++ b/src/daft-logical-plan/src/optimization/rules/push_down_filter.rs @@ -557,8 +557,9 @@ mod tests { let pred = col("a").lt(lit(2)); let sort_by = vec![col("a")]; let descending = vec![true]; + let nulls_first = vec![false]; let plan = scan_plan - .sort(sort_by.clone(), descending.clone())? + .sort(sort_by.clone(), descending.clone(), nulls_first.clone())? .filter(pred.clone())? .build(); let expected_filter_scan = if push_into_scan { @@ -566,7 +567,9 @@ mod tests { } else { scan_plan.filter(pred)? }; - let expected = expected_filter_scan.sort(sort_by, descending)?.build(); + let expected = expected_filter_scan + .sort(sort_by, descending, nulls_first)? + .build(); assert_optimized_plan_eq(plan, expected)?; Ok(()) } diff --git a/src/daft-micropartition/src/ops/sort.rs b/src/daft-micropartition/src/ops/sort.rs index e8a1ea78b1..f2be0bb38e 100644 --- a/src/daft-micropartition/src/ops/sort.rs +++ b/src/daft-micropartition/src/ops/sort.rs @@ -9,14 +9,19 @@ use daft_table::Table; use crate::micropartition::MicroPartition; impl MicroPartition { - pub fn sort(&self, sort_keys: &[ExprRef], descending: &[bool]) -> DaftResult { + pub fn sort( + &self, + sort_keys: &[ExprRef], + descending: &[bool], + nulls_first: &[bool], + ) -> DaftResult { let io_stats = IOStatsContext::new("MicroPartition::sort"); let tables = self.concat_or_get(io_stats)?; match tables.as_slice() { [] => Ok(Self::empty(Some(self.schema.clone()))), [single] => { - let sorted = single.sort(sort_keys, descending)?; + let sorted = single.sort(sort_keys, descending, nulls_first)?; Ok(Self::new_loaded( self.schema.clone(), Arc::new(vec![sorted]), @@ -27,16 +32,21 @@ impl MicroPartition { } } - pub fn argsort(&self, sort_keys: &[ExprRef], descending: &[bool]) -> DaftResult { + pub fn argsort( + &self, + sort_keys: &[ExprRef], + descending: &[bool], + nulls_first: &[bool], + ) -> DaftResult { let io_stats = IOStatsContext::new("MicroPartition::argsort"); let tables = self.concat_or_get(io_stats)?; match tables.as_slice() { [] => { let empty_table = Table::empty(Some(self.schema.clone()))?; - empty_table.argsort(sort_keys, descending) + empty_table.argsort(sort_keys, descending, nulls_first) } - [single] => single.argsort(sort_keys, descending), + [single] => single.argsort(sort_keys, descending, nulls_first), _ => unreachable!(), } } diff --git a/src/daft-micropartition/src/python.rs b/src/daft-micropartition/src/python.rs index 39bc7ad5c5..7de380dd93 100644 --- a/src/daft-micropartition/src/python.rs +++ b/src/daft-micropartition/src/python.rs @@ -184,6 +184,7 @@ impl PyMicroPartition { py: Python, sort_keys: Vec, descending: Vec, + nulls_first: Vec, ) -> PyResult { let converted_exprs: Vec = sort_keys .into_iter() @@ -192,7 +193,11 @@ impl PyMicroPartition { py.allow_threads(|| { Ok(self .inner - .sort(converted_exprs.as_slice(), descending.as_slice())? + .sort( + converted_exprs.as_slice(), + descending.as_slice(), + nulls_first.as_slice(), + )? .into()) }) } @@ -202,6 +207,7 @@ impl PyMicroPartition { py: Python, sort_keys: Vec, descending: Vec, + nulls_first: Vec, ) -> PyResult { let converted_exprs: Vec = sort_keys .into_iter() @@ -210,7 +216,11 @@ impl PyMicroPartition { py.allow_threads(|| { Ok(self .inner - .argsort(converted_exprs.as_slice(), descending.as_slice())? + .argsort( + converted_exprs.as_slice(), + descending.as_slice(), + nulls_first.as_slice(), + )? .into()) }) } diff --git a/src/daft-physical-plan/src/ops/sort.rs b/src/daft-physical-plan/src/ops/sort.rs index 0b6cda2ffc..777f61043d 100644 --- a/src/daft-physical-plan/src/ops/sort.rs +++ b/src/daft-physical-plan/src/ops/sort.rs @@ -10,6 +10,7 @@ pub struct Sort { pub input: PhysicalPlanRef, pub sort_by: Vec, pub descending: Vec, + pub nulls_first: Vec, pub num_partitions: usize, } @@ -18,12 +19,14 @@ impl Sort { input: PhysicalPlanRef, sort_by: Vec, descending: Vec, + nulls_first: Vec, num_partitions: usize, ) -> Self { Self { input, sort_by, descending, + nulls_first, num_partitions, } } @@ -36,7 +39,15 @@ impl Sort { .sort_by .iter() .zip(self.descending.iter()) - .map(|(sb, d)| format!("({}, {})", sb, if *d { "descending" } else { "ascending" },)) + .zip(self.nulls_first.iter()) + .map(|((sb, d), nf)| { + format!( + "({}, {}, {})", + sb, + if *d { "descending" } else { "ascending" }, + if *nf { "nulls first" } else { "nulls last" } + ) + }) .join(", "); res.push(format!("Sort: Sort by = {}", pairs)); res.push(format!("Num partitions = {}", self.num_partitions)); diff --git a/src/daft-physical-plan/src/physical_planner/translate.rs b/src/daft-physical-plan/src/physical_planner/translate.rs index b3790111f2..9bffaef97b 100644 --- a/src/daft-physical-plan/src/physical_planner/translate.rs +++ b/src/daft-physical-plan/src/physical_planner/translate.rs @@ -141,6 +141,7 @@ pub(super) fn translate_single_logical_node( LogicalPlan::Sort(LogicalSort { sort_by, descending, + nulls_first, .. }) => { let input_physical = physical_children.pop().expect("requires 1 input"); @@ -149,6 +150,7 @@ pub(super) fn translate_single_logical_node( input_physical, sort_by.clone(), descending.clone(), + nulls_first.clone(), num_partitions, )) .arced()) @@ -583,6 +585,7 @@ pub(super) fn translate_single_logical_node( left_physical, left_on.clone(), std::iter::repeat(false).take(left_on.len()).collect(), + std::iter::repeat(false).take(left_on.len()).collect(), num_partitions, )) .arced(); @@ -592,6 +595,7 @@ pub(super) fn translate_single_logical_node( right_physical, right_on.clone(), std::iter::repeat(false).take(right_on.len()).collect(), + std::iter::repeat(false).take(right_on.len()).collect(), num_partitions, )) .arced(); diff --git a/src/daft-physical-plan/src/plan.rs b/src/daft-physical-plan/src/plan.rs index 2d764dce28..740905b6e8 100644 --- a/src/daft-physical-plan/src/plan.rs +++ b/src/daft-physical-plan/src/plan.rs @@ -431,6 +431,7 @@ impl PhysicalPlan { Self::Project(Project::new_with_clustering_spec( input.clone(), projection.clone(), clustering_spec.clone(), ).unwrap()), + Self::ActorPoolProject(ActorPoolProject {projection, ..}) => Self::ActorPoolProject(ActorPoolProject::try_new(input.clone(), projection.clone()).unwrap()), Self::Filter(Filter { predicate, .. }) => Self::Filter(Filter::new(input.clone(), predicate.clone())), Self::Limit(Limit { limit, eager, num_partitions, .. }) => Self::Limit(Limit::new(input.clone(), *limit, *eager, *num_partitions)), @@ -438,7 +439,7 @@ impl PhysicalPlan { Self::Unpivot(Unpivot { ids, values, variable_name, value_name, .. }) => Self::Unpivot(Unpivot::new(input.clone(), ids.clone(), values.clone(), variable_name, value_name)), Self::Pivot(Pivot { group_by, pivot_column, value_column, names, .. }) => Self::Pivot(Pivot::new(input.clone(), group_by.clone(), pivot_column.clone(), value_column.clone(), names.clone())), Self::Sample(Sample { fraction, with_replacement, seed, .. }) => Self::Sample(Sample::new(input.clone(), *fraction, *with_replacement, *seed)), - Self::Sort(Sort { sort_by, descending, num_partitions, .. }) => Self::Sort(Sort::new(input.clone(), sort_by.clone(), descending.clone(), *num_partitions)), + Self::Sort(Sort { sort_by, descending, nulls_first, num_partitions, .. }) => Self::Sort(Sort::new(input.clone(), sort_by.clone(), descending.clone(),nulls_first.clone(), *num_partitions)), Self::ShuffleExchange(ShuffleExchange { strategy, .. }) => Self::ShuffleExchange(ShuffleExchange { input: input.clone(), strategy: strategy.clone() }), Self::Aggregate(Aggregate { aggregations, groupby, ..}) => Self::Aggregate(Aggregate::new(input.clone(), aggregations.clone(), groupby.clone())), Self::TabularWriteParquet(TabularWriteParquet { schema, file_info, .. }) => Self::TabularWriteParquet(TabularWriteParquet::new(schema.clone(), file_info.clone(), input.clone())), diff --git a/src/daft-scheduler/src/scheduler.rs b/src/daft-scheduler/src/scheduler.rs index 05ede0e498..93873a3ab3 100644 --- a/src/daft-scheduler/src/scheduler.rs +++ b/src/daft-scheduler/src/scheduler.rs @@ -466,6 +466,7 @@ fn physical_plan_to_partition_tasks( input, sort_by, descending, + nulls_first, num_partitions, }) => { let upstream_iter = @@ -481,6 +482,7 @@ fn physical_plan_to_partition_tasks( upstream_iter, sort_by_pyexprs, descending.clone(), + nulls_first.clone(), *num_partitions, ))?; Ok(py_iter.into()) diff --git a/src/daft-sql/src/lib.rs b/src/daft-sql/src/lib.rs index 8203d71d3b..75bb2b07c0 100644 --- a/src/daft-sql/src/lib.rs +++ b/src/daft-sql/src/lib.rs @@ -230,7 +230,7 @@ mod tests { let expected = LogicalPlanBuilder::new(tbl_1, None) .select(vec![col("utf8")])? - .sort(vec![col("utf8")], vec![true])? + .sort(vec![col("utf8")], vec![true], vec![true])? .build(); assert_eq!(plan, expected); diff --git a/src/daft-sql/src/modules/list.rs b/src/daft-sql/src/modules/list.rs index bd6db25990..8523863aac 100644 --- a/src/daft-sql/src/modules/list.rs +++ b/src/daft-sql/src/modules/list.rs @@ -304,7 +304,7 @@ impl SQLFunction for SQLListSort { match inputs { [input] => { let input = planner.plan_function_arg(input)?; - Ok(daft_functions::list::sort(input, None)) + Ok(daft_functions::list::sort(input, None, None)) } [input, order] => { let input = planner.plan_function_arg(input)?; @@ -323,7 +323,7 @@ impl SQLFunction for SQLListSort { } _ => unsupported_sql_err!("invalid order for list_sort"), }; - Ok(daft_functions::list::sort(input, Some(order))) + Ok(daft_functions::list::sort(input, Some(order), None)) } _ => unsupported_sql_err!( "invalid arguments for list_sort. Expected list_sort(expr, ASC|DESC)" diff --git a/src/daft-sql/src/planner.rs b/src/daft-sql/src/planner.rs index 5d91923235..1eb1169b5a 100644 --- a/src/daft-sql/src/planner.rs +++ b/src/daft-sql/src/planner.rs @@ -422,7 +422,7 @@ impl<'a> SQLPlanner<'a> { unsupported_sql_err!("ORDER BY [query] [INTERPOLATE]"); }; - let (orderby_exprs, orderby_desc) = + let (orderby_exprs, orderby_desc, orderby_nulls_first) = self.plan_order_by_exprs(order_by.exprs.as_slice())?; for expr in &orderby_exprs { @@ -449,7 +449,9 @@ impl<'a> SQLPlanner<'a> { rel.inner = rel.inner.select(projections)?; } - rel.inner = rel.inner.sort(orderby_exprs, orderby_desc)?; + rel.inner = rel + .inner + .sort(orderby_exprs, orderby_desc, orderby_nulls_first)?; if needs_projection { rel.inner = rel.inner.select(final_projection)?; @@ -477,10 +479,12 @@ impl<'a> SQLPlanner<'a> { // these are orderbys that are part of the final projection let mut orderbys_after_projection = Vec::new(); let mut orderbys_after_projection_desc = Vec::new(); + let mut orderbys_after_projection_nulls_first = Vec::new(); // these are orderbys that are not part of the final projection let mut orderbys_before_projection = Vec::new(); let mut orderbys_before_projection_desc = Vec::new(); + let mut orderbys_before_projection_nulls_first = Vec::new(); for p in projections { let fld = p.to_field(schema)?; @@ -518,7 +522,7 @@ impl<'a> SQLPlanner<'a> { unsupported_sql_err!("ORDER BY [query] [INTERPOLATE]"); }; - let (exprs, desc) = self.plan_order_by_exprs(order_by.exprs.as_slice())?; + let (exprs, desc, nulls_first) = self.plan_order_by_exprs(order_by.exprs.as_slice())?; for (i, expr) in exprs.iter().enumerate() { // the orderby is ordered by a column of the projection @@ -553,12 +557,14 @@ impl<'a> SQLPlanner<'a> { // ex: SELECT count(*) as c FROM t ORDER BY count(*) orderbys_after_projection.push(col(alias.as_ref())); orderbys_after_projection_desc.push(desc[i]); + orderbys_after_projection_nulls_first.push(nulls_first[i]); } else { // its a count(*) that is not in the final projection // ex: SELECT sum(n) FROM t ORDER BY count(*); aggs.push(expr.clone()); orderbys_before_projection.push(col(fld.name.as_ref())); orderbys_before_projection_desc.push(desc[i]); + orderbys_before_projection_nulls_first.push(nulls_first[i]); } } } else if has_agg(expr) { @@ -566,9 +572,11 @@ impl<'a> SQLPlanner<'a> { // so we just need to push the column name orderbys_after_projection.push(col(fld.name.as_ref())); orderbys_after_projection_desc.push(desc[i]); + orderbys_after_projection_nulls_first.push(nulls_first[i]); } else { orderbys_after_projection.push(expr.clone()); orderbys_after_projection_desc.push(desc[i]); + orderbys_after_projection_nulls_first.push(nulls_first[i]); } // the orderby is ordered by an expr from the original schema @@ -594,6 +602,7 @@ impl<'a> SQLPlanner<'a> { }) { orderbys_after_projection.push(col(alias.as_ref())); orderbys_after_projection_desc.push(desc[i]); + orderbys_after_projection_nulls_first.push(nulls_first[i]); } else { // its an aggregate that is not part of the final projection // ex: SELECT sum(a) FROM t ORDER BY sum(b) @@ -603,6 +612,7 @@ impl<'a> SQLPlanner<'a> { // then add it to the orderbys that are not part of the final projection orderbys_before_projection.push(col(fld.name.as_ref())); orderbys_before_projection_desc.push(desc[i]); + orderbys_before_projection_nulls_first.push(nulls_first[i]); } } else { // we know it's a column of the original schema @@ -612,6 +622,7 @@ impl<'a> SQLPlanner<'a> { orderbys_before_projection.push(col(fld.name.as_ref())); orderbys_before_projection_desc.push(desc[i]); + orderbys_before_projection_nulls_first.push(nulls_first[i]); } } else { panic!("unexpected order by expr"); @@ -632,9 +643,11 @@ impl<'a> SQLPlanner<'a> { // order bys that are not in the final projection if has_orderby_before_projection { - rel.inner = rel - .inner - .sort(orderbys_before_projection, orderbys_before_projection_desc)?; + rel.inner = rel.inner.sort( + orderbys_before_projection, + orderbys_before_projection_desc, + orderbys_before_projection_nulls_first, + )?; } // apply the final projection @@ -642,9 +655,11 @@ impl<'a> SQLPlanner<'a> { // order bys that are in the final projection if has_orderby_after_projection { - rel.inner = rel - .inner - .sort(orderbys_after_projection, orderbys_after_projection_desc)?; + rel.inner = rel.inner.sort( + orderbys_after_projection, + orderbys_after_projection_desc, + orderbys_after_projection_nulls_first, + )?; } Ok(()) } @@ -652,25 +667,64 @@ impl<'a> SQLPlanner<'a> { fn plan_order_by_exprs( &self, expr: &[sqlparser::ast::OrderByExpr], - ) -> SQLPlannerResult<(Vec, Vec)> { + ) -> SQLPlannerResult<(Vec, Vec, Vec)> { if expr.is_empty() { unsupported_sql_err!("ORDER BY []"); } let mut exprs = Vec::with_capacity(expr.len()); let mut desc = Vec::with_capacity(expr.len()); + let mut nulls_first = Vec::with_capacity(expr.len()); for order_by_expr in expr { - if order_by_expr.nulls_first.is_some() { - unsupported_sql_err!("NULLS FIRST"); - } + match (order_by_expr.asc, order_by_expr.nulls_first) { + // --------------------------- + // all of these are equivalent + // --------------------------- + // ORDER BY expr + (None, None) | + // ORDER BY expr ASC + (Some(true), None) | + // ORDER BY expr NULLS LAST + (None, Some(false)) | + // ORDER BY expr ASC NULLS LAST + (Some(true), Some(false)) => { + nulls_first.push(false); + desc.push(false); + }, + // --------------------------- + + + // --------------------------- + // ORDER BY expr NULLS FIRST + (None, Some(true)) | + // ORDER BY expr ASC NULLS FIRST + (Some(true), Some(true)) => { + nulls_first.push(true); + desc.push(false); + } + // --------------------------- + + // ORDER BY expr DESC + (Some(false), None) | + // ORDER BY expr DESC NULLS FIRST + (Some(false), Some(true)) => { + nulls_first.push(true); + desc.push(true); + }, + // ORDER BY expr DESC NULLS LAST + (Some(false), Some(false)) => { + nulls_first.push(false); + desc.push(true); + } + + }; if order_by_expr.with_fill.is_some() { unsupported_sql_err!("WITH FILL"); } let expr = self.plan_expr(&order_by_expr.expr)?; - desc.push(!order_by_expr.asc.unwrap_or(true)); exprs.push(expr); } - Ok((exprs, desc)) + Ok((exprs, desc, nulls_first)) } fn plan_from(&mut self, from: &[TableWithJoins]) -> SQLPlannerResult { diff --git a/src/daft-table/src/ops/groups.rs b/src/daft-table/src/ops/groups.rs index 580d2c4288..7239c5a35c 100644 --- a/src/daft-table/src/ops/groups.rs +++ b/src/daft-table/src/ops/groups.rs @@ -54,8 +54,11 @@ impl Table { // ) // Begin by doing the argsort. - let argsort_series = - Series::argsort_multikey(self.columns.as_slice(), &vec![false; self.columns.len()])?; + let argsort_series = Series::argsort_multikey( + self.columns.as_slice(), + &vec![false; self.columns.len()], + &vec![false; self.columns.len()], + )?; let argsort_array = argsort_series.downcast::()?; // The result indices. diff --git a/src/daft-table/src/ops/joins/mod.rs b/src/daft-table/src/ops/joins/mod.rs index 97c034a9df..f66de100ec 100644 --- a/src/daft-table/src/ops/joins/mod.rs +++ b/src/daft-table/src/ops/joins/mod.rs @@ -130,6 +130,10 @@ impl Table { .take(left_on.len()) .collect::>() .as_slice(), + std::iter::repeat(false) + .take(left_on.len()) + .collect::>() + .as_slice(), )?; if right_on.is_empty() { return Err(DaftError::ValueError( @@ -142,6 +146,10 @@ impl Table { .take(right_on.len()) .collect::>() .as_slice(), + std::iter::repeat(false) + .take(right_on.len()) + .collect::>() + .as_slice(), )?; return left.sort_merge_join(&right, left_on, right_on, true); diff --git a/src/daft-table/src/ops/sort.rs b/src/daft-table/src/ops/sort.rs index 20e56c3b0b..540013a02e 100644 --- a/src/daft-table/src/ops/sort.rs +++ b/src/daft-table/src/ops/sort.rs @@ -5,12 +5,22 @@ use daft_dsl::ExprRef; use crate::Table; impl Table { - pub fn sort(&self, sort_keys: &[ExprRef], descending: &[bool]) -> DaftResult { - let argsort = self.argsort(sort_keys, descending)?; + pub fn sort( + &self, + sort_keys: &[ExprRef], + descending: &[bool], + nulls_first: &[bool], + ) -> DaftResult { + let argsort = self.argsort(sort_keys, descending, nulls_first)?; self.take(&argsort) } - pub fn argsort(&self, sort_keys: &[ExprRef], descending: &[bool]) -> DaftResult { + pub fn argsort( + &self, + sort_keys: &[ExprRef], + descending: &[bool], + nulls_first: &[bool], + ) -> DaftResult { if sort_keys.len() != descending.len() { return Err(DaftError::ValueError(format!( "sort_keys and descending length must match, got {} vs {}", @@ -20,10 +30,10 @@ impl Table { } if sort_keys.len() == 1 { self.eval_expression(sort_keys.first().unwrap())? - .argsort(*descending.first().unwrap()) + .argsort(*descending.first().unwrap(), *nulls_first.first().unwrap()) } else { let expr_result = self.eval_expression_list(sort_keys)?; - Series::argsort_multikey(expr_result.columns.as_slice(), descending) + Series::argsort_multikey(expr_result.columns.as_slice(), descending, nulls_first) } } } diff --git a/src/daft-table/src/python.rs b/src/daft-table/src/python.rs index 54e9db65f6..c2c0cc622a 100644 --- a/src/daft-table/src/python.rs +++ b/src/daft-table/src/python.rs @@ -54,6 +54,7 @@ impl PyTable { py: Python, sort_keys: Vec, descending: Vec, + nulls_first: Vec, ) -> PyResult { let converted_exprs: Vec = sort_keys .into_iter() @@ -62,7 +63,11 @@ impl PyTable { py.allow_threads(|| { Ok(self .table - .sort(converted_exprs.as_slice(), descending.as_slice())? + .sort( + converted_exprs.as_slice(), + descending.as_slice(), + nulls_first.as_slice(), + )? .into()) }) } @@ -72,6 +77,7 @@ impl PyTable { py: Python, sort_keys: Vec, descending: Vec, + nulls_first: Vec, ) -> PyResult { let converted_exprs: Vec = sort_keys .into_iter() @@ -80,7 +86,11 @@ impl PyTable { py.allow_threads(|| { Ok(self .table - .argsort(converted_exprs.as_slice(), descending.as_slice())? + .argsort( + converted_exprs.as_slice(), + descending.as_slice(), + nulls_first.as_slice(), + )? .into()) }) }