From a9564b472ed59ff4ebd73bb54c54715412185565 Mon Sep 17 00:00:00 2001 From: universalmind303 Date: Thu, 14 Nov 2024 16:35:15 -0800 Subject: [PATCH] chore: wire up all nulls_first logic --- daft/daft/__init__.pyi | 18 +- daft/dataframe/dataframe.py | 4 +- daft/expressions/expressions.py | 8 +- daft/logical/builder.py | 13 +- daft/series.py | 22 ++- daft/table/micropartition.py | 31 +++- daft/table/table.py | 44 ++++- src/daft-core/src/array/ops/list.rs | 40 ++++- src/daft-core/src/array/ops/sort.rs | 162 +++++++++++++----- src/daft-core/src/python/series.rs | 7 +- .../src/series/array_impl/data_array.rs | 4 +- .../src/series/array_impl/logical_array.rs | 4 +- .../src/series/array_impl/nested_array.rs | 2 +- src/daft-core/src/series/ops/list.rs | 12 +- src/daft-core/src/series/ops/sort.rs | 4 +- src/daft-functions/src/list/sort.rs | 17 +- src/daft-local-execution/src/pipeline.rs | 3 +- src/daft-local-execution/src/sinks/sort.rs | 6 +- src/daft-logical-plan/src/builder.rs | 7 +- src/daft-logical-plan/src/ops/sort.rs | 10 +- .../optimization/rules/push_down_filter.rs | 4 +- src/daft-micropartition/src/ops/sort.rs | 11 +- src/daft-micropartition/src/python.rs | 7 +- src/daft-physical-plan/src/ops/sort.rs | 10 +- src/daft-sql/src/modules/list.rs | 4 +- src/daft-sql/src/planner.rs | 1 - src/daft-table/src/ops/groups.rs | 7 +- src/daft-table/src/ops/joins/mod.rs | 8 + src/daft-table/src/ops/sort.rs | 2 +- src/daft-table/src/python.rs | 14 +- 30 files changed, 361 insertions(+), 125 deletions(-) diff --git a/daft/daft/__init__.pyi b/daft/daft/__init__.pyi index 3598a4042d..da0c4bbeea 100644 --- a/daft/daft/__init__.pyi +++ b/daft/daft/__init__.pyi @@ -1308,7 +1308,7 @@ def dt_truncate(expr: PyExpr, interval: str, relative_to: PyExpr) -> PyExpr: ... # expr.list namespace # --- def explode(expr: PyExpr) -> PyExpr: ... -def list_sort(expr: PyExpr, desc: PyExpr) -> PyExpr: ... +def list_sort(expr: PyExpr, desc: PyExpr, nulls_first: PyExpr) -> PyExpr: ... def list_value_counts(expr: PyExpr) -> PyExpr: ... def list_join(expr: PyExpr, delimiter: PyExpr) -> PyExpr: ... def list_count(expr: PyExpr, mode: CountMode) -> PyExpr: ... @@ -1354,8 +1354,8 @@ class PySeries: def take(self, idx: PySeries) -> PySeries: ... def slice(self, start: int, end: int) -> PySeries: ... def filter(self, mask: PySeries) -> PySeries: ... - def sort(self, descending: bool) -> PySeries: ... - def argsort(self, descending: bool) -> PySeries: ... + def sort(self, descending: bool, nulls_first: bool) -> PySeries: ... + def argsort(self, descending: bool, nulls_first: bool) -> PySeries: ... def hash(self, seed: PySeries | None = None) -> PySeries: ... def minhash( self, @@ -1456,7 +1456,7 @@ class PySeries: def list_count(self, mode: CountMode) -> PySeries: ... def list_get(self, idx: PySeries, default: PySeries) -> PySeries: ... def list_slice(self, start: PySeries, end: PySeries | None = None) -> PySeries: ... - def list_sort(self, desc: PySeries) -> PySeries: ... + def list_sort(self, desc: PySeries, nulls_first: PySeries) -> PySeries: ... def map_get(self, key: PySeries) -> PySeries: ... def if_else(self, other: PySeries, predicate: PySeries) -> PySeries: ... def is_null(self) -> PySeries: ... @@ -1474,8 +1474,8 @@ class PyTable: def eval_expression_list(self, exprs: list[PyExpr]) -> PyTable: ... def take(self, idx: PySeries) -> PyTable: ... def filter(self, exprs: list[PyExpr]) -> PyTable: ... - def sort(self, sort_keys: list[PyExpr], descending: list[bool]) -> PyTable: ... - def argsort(self, sort_keys: list[PyExpr], descending: list[bool]) -> PySeries: ... + def sort(self, sort_keys: list[PyExpr], descending: list[bool], nulls_first: list[bool]) -> PyTable: ... + def argsort(self, sort_keys: list[PyExpr], descending: list[bool], nulls_first: list[bool]) -> PySeries: ... def agg(self, to_agg: list[PyExpr], group_by: list[PyExpr]) -> PyTable: ... def pivot( self, @@ -1553,8 +1553,8 @@ class PyMicroPartition: def eval_expression_list(self, exprs: list[PyExpr]) -> PyMicroPartition: ... def take(self, idx: PySeries) -> PyMicroPartition: ... def filter(self, exprs: list[PyExpr]) -> PyMicroPartition: ... - def sort(self, sort_keys: list[PyExpr], descending: list[bool]) -> PyMicroPartition: ... - def argsort(self, sort_keys: list[PyExpr], descending: list[bool]) -> PySeries: ... + def sort(self, sort_keys: list[PyExpr], descending: list[bool], nulls_first: list[bool]) -> PyMicroPartition: ... + def argsort(self, sort_keys: list[PyExpr], descending: list[bool], nulls_first: list[bool]) -> PySeries: ... def agg(self, to_agg: list[PyExpr], group_by: list[PyExpr]) -> PyMicroPartition: ... def hash_join( self, @@ -1721,7 +1721,7 @@ class LogicalPlanBuilder: variable_name: str, value_name: str, ) -> LogicalPlanBuilder: ... - def sort(self, sort_by: list[PyExpr], descending: list[bool]) -> LogicalPlanBuilder: ... + def sort(self, sort_by: list[PyExpr], descending: list[bool], nulls_first: list[bool]) -> LogicalPlanBuilder: ... def hash_repartition( self, partition_by: list[PyExpr], diff --git a/daft/dataframe/dataframe.py b/daft/dataframe/dataframe.py index a79443e327..6383efc556 100644 --- a/daft/dataframe/dataframe.py +++ b/daft/dataframe/dataframe.py @@ -1525,6 +1525,7 @@ def sort( self, by: Union[ColumnInputType, List[ColumnInputType]], desc: Union[bool, List[bool]] = False, + nulls_first: Union[bool, List[bool]] = False, ) -> "DataFrame": """Sorts DataFrame globally @@ -1583,8 +1584,9 @@ def sort( by = [ by, ] + sort_by = self.__column_input_to_expression(by) - builder = self._builder.sort(sort_by=sort_by, descending=desc) + builder = self._builder.sort(sort_by=sort_by, descending=desc, nulls_first=nulls_first) return DataFrame(builder) @DataframePublicAPI diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index 44bbc302e8..855daff484 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -3112,7 +3112,7 @@ def max(self) -> Expression: """ return Expression._from_pyexpr(native.list_max(self._expr)) - def sort(self, desc: bool | Expression = False) -> Expression: + def sort(self, desc: bool | Expression = False, nulls_first: bool | Expression | None = None) -> Expression: """Sorts the inner lists of a list column. Example: @@ -3141,7 +3141,11 @@ def sort(self, desc: bool | Expression = False) -> Expression: """ if isinstance(desc, bool): desc = Expression._to_expression(desc) - return Expression._from_pyexpr(_list_sort(self._expr, desc._expr)) + if nulls_first is None: + nulls_first = desc + elif isinstance(nulls_first, bool): + nulls_first = Expression._to_expression(nulls_first) + return Expression._from_pyexpr(_list_sort(self._expr, desc._expr, nulls_first._expr)) class ExpressionStructNamespace(ExpressionNamespace): diff --git a/daft/logical/builder.py b/daft/logical/builder.py index ed85a55517..412e400cfd 100644 --- a/daft/logical/builder.py +++ b/daft/logical/builder.py @@ -203,11 +203,20 @@ def sample(self, fraction: float, with_replacement: bool, seed: int | None) -> L builder = self._builder.sample(fraction, with_replacement, seed) return LogicalPlanBuilder(builder) - def sort(self, sort_by: list[Expression], descending: list[bool] | bool = False) -> LogicalPlanBuilder: + def sort( + self, + sort_by: list[Expression], + descending: list[bool] | bool = False, + nulls_first: list[bool] | bool | None = None, + ) -> LogicalPlanBuilder: sort_by_pyexprs = [expr._expr for expr in sort_by] if not isinstance(descending, list): descending = [descending] * len(sort_by_pyexprs) - builder = self._builder.sort(sort_by_pyexprs, descending) + if nulls_first is None: + nulls_first = descending + elif isinstance(nulls_first, bool): + nulls_first = [nulls_first] * len(sort_by_pyexprs) + builder = self._builder.sort(sort_by_pyexprs, descending, nulls_first) return LogicalPlanBuilder(builder) def hash_repartition(self, num_partitions: int | None, partition_by: list[Expression]) -> LogicalPlanBuilder: diff --git a/daft/series.py b/daft/series.py index 7053d8668e..95788d5f7f 100644 --- a/daft/series.py +++ b/daft/series.py @@ -256,17 +256,20 @@ def slice(self, start: int, end: int) -> Series: return Series._from_pyseries(self._series.slice(start, end)) - def argsort(self, descending: bool = False) -> Series: + def argsort(self, descending: bool = False, nulls_first: bool | None = None) -> Series: if not isinstance(descending, bool): raise TypeError(f"expected `descending` to be bool, got {type(descending)}") + if nulls_first is None: + nulls_first = descending - return Series._from_pyseries(self._series.argsort(descending)) + return Series._from_pyseries(self._series.argsort(descending, nulls_first)) - def sort(self, descending: bool = False) -> Series: + def sort(self, descending: bool = False, nulls_first: bool | None = None) -> Series: if not isinstance(descending, bool): raise TypeError(f"expected `descending` to be bool, got {type(descending)}") - - return Series._from_pyseries(self._series.sort(descending)) + if nulls_first is None: + nulls_first = descending + return Series._from_pyseries(self._series.sort(descending, nulls_first)) def hash(self, seed: Series | None = None) -> Series: if not isinstance(seed, Series) and seed is not None: @@ -962,10 +965,15 @@ def length(self) -> Series: def get(self, idx: Series, default: Series) -> Series: return Series._from_pyseries(self._series.list_get(idx._series, default._series)) - def sort(self, desc: bool | Series = False) -> Series: + def sort(self, desc: bool | Series = False, nulls_first: bool | Series | None = None) -> Series: if isinstance(desc, bool): desc = Series.from_pylist([desc], name="desc") - return Series._from_pyseries(self._series.list_sort(desc._series)) + if nulls_first is None: + nulls_first = desc + elif isinstance(nulls_first, bool): + nulls_first = Series.from_pylist([nulls_first], name="nulls_first") + + return Series._from_pyseries(self._series.list_sort(desc._series, nulls_first._series)) class SeriesMapNamespace(SeriesNamespace): diff --git a/daft/table/micropartition.py b/daft/table/micropartition.py index 1a0443dffc..c3297c9a6c 100644 --- a/daft/table/micropartition.py +++ b/daft/table/micropartition.py @@ -174,7 +174,12 @@ def filter(self, exprs: ExpressionsProjection) -> MicroPartition: pyexprs = [e._expr for e in exprs] return MicroPartition._from_pymicropartition(self._micropartition.filter(pyexprs)) - def sort(self, sort_keys: ExpressionsProjection, descending: bool | list[bool] | None = None, nulls_first: bool | list[bool] | None=None) -> MicroPartition: + def sort( + self, + sort_keys: ExpressionsProjection, + descending: bool | list[bool] | None = None, + nulls_first: bool | list[bool] | None = None, + ) -> MicroPartition: assert all(isinstance(e, Expression) for e in sort_keys) pyexprs = [e._expr for e in sort_keys] if descending is None: @@ -187,7 +192,6 @@ def sort(self, sort_keys: ExpressionsProjection, descending: bool | list[bool] | f"Expected length of `descending` to be the same length as `sort_keys` since a list was passed in," f"got {len(descending)} instead of {len(sort_keys)}" ) - else: raise TypeError(f"Expected a bool, list[bool] or None for `descending` but got {type(descending)}") if nulls_first is None: @@ -364,7 +368,12 @@ def add_monotonically_increasing_id(self, partition_num: int, column_name: str) # Compute methods (MicroPartition -> Series) ### - def argsort(self, sort_keys: ExpressionsProjection, descending: bool | list[bool] | None = None) -> Series: + def argsort( + self, + sort_keys: ExpressionsProjection, + descending: bool | list[bool] | None = None, + nulls_first: bool | list[bool] | None = None, + ) -> Series: assert all(isinstance(e, Expression) for e in sort_keys) pyexprs = [e._expr for e in sort_keys] if descending is None: @@ -379,7 +388,21 @@ def argsort(self, sort_keys: ExpressionsProjection, descending: bool | list[bool ) else: raise TypeError(f"Expected a bool, list[bool] or None for `descending` but got {type(descending)}") - return Series._from_pyseries(self._micropartition.argsort(pyexprs, descending)) + if nulls_first is None: + nulls_first = descending + elif isinstance(nulls_first, bool): + nulls_first = [nulls_first for _ in pyexprs] + elif isinstance(nulls_first, list): + if len(nulls_first) != len(sort_keys): + raise ValueError( + f"Expected length of `nulls_first` to be the same length as `sort_keys` since a list was passed in," + f"got {len(nulls_first)} instead of {len(sort_keys)}" + ) + else: + nulls_first = [bool(x) for x in nulls_first] + else: + raise TypeError(f"Expected a bool, list[bool] or None for `nulls_first` but got {type(nulls_first)}") + return Series._from_pyseries(self._micropartition.argsort(pyexprs, descending, nulls_first)) def __reduce__(self) -> tuple: names = self.column_names() diff --git a/daft/table/table.py b/daft/table/table.py index 9ab769b337..66757cce2a 100644 --- a/daft/table/table.py +++ b/daft/table/table.py @@ -241,7 +241,12 @@ def filter(self, exprs: ExpressionsProjection) -> Table: pyexprs = [e._expr for e in exprs] return Table._from_pytable(self._table.filter(pyexprs)) - def sort(self, sort_keys: ExpressionsProjection, descending: bool | list[bool] | None = None) -> Table: + def sort( + self, + sort_keys: ExpressionsProjection, + descending: bool | list[bool] | None = None, + nulls_first: bool | list[bool] | None = None, + ) -> Table: assert all(isinstance(e, Expression) for e in sort_keys) pyexprs = [e._expr for e in sort_keys] if descending is None: @@ -256,7 +261,19 @@ def sort(self, sort_keys: ExpressionsProjection, descending: bool | list[bool] | ) else: raise TypeError(f"Expected a bool, list[bool] or None for `descending` but got {type(descending)}") - return Table._from_pytable(self._table.sort(pyexprs, descending)) + if nulls_first is None: + nulls_first = descending + elif isinstance(nulls_first, bool): + nulls_first = [nulls_first for _ in pyexprs] + elif isinstance(nulls_first, list): + if len(nulls_first) != len(sort_keys): + raise ValueError( + f"Expected length of `nulls_first` to be the same length as `sort_keys` since a list was passed in," + f"got {len(nulls_first)} instead of {len(sort_keys)}" + ) + else: + nulls_first = [bool(x) for x in nulls_first] + return Table._from_pytable(self._table.sort(pyexprs, descending, nulls_first)) def sample( self, @@ -378,7 +395,12 @@ def add_monotonically_increasing_id(self, partition_num: int, column_name: str) # Compute methods (Table -> Series) ### - def argsort(self, sort_keys: ExpressionsProjection, descending: bool | list[bool] | None = None) -> Series: + def argsort( + self, + sort_keys: ExpressionsProjection, + descending: bool | list[bool] | None = None, + nulls_first: bool | list[bool] | None = None, + ) -> Series: assert all(isinstance(e, Expression) for e in sort_keys) pyexprs = [e._expr for e in sort_keys] if descending is None: @@ -393,7 +415,21 @@ def argsort(self, sort_keys: ExpressionsProjection, descending: bool | list[bool ) else: raise TypeError(f"Expected a bool, list[bool] or None for `descending` but got {type(descending)}") - return Series._from_pyseries(self._table.argsort(pyexprs, descending)) + if nulls_first is None: + nulls_first = descending + elif isinstance(nulls_first, bool): + nulls_first = [nulls_first for _ in pyexprs] + elif isinstance(nulls_first, list): + if len(nulls_first) != len(sort_keys): + raise ValueError( + f"Expected length of `nulls_first` to be the same length as `sort_keys` since a list was passed in," + f"got {len(nulls_first)} instead of {len(sort_keys)}" + ) + else: + nulls_first = [bool(x) for x in nulls_first] + else: + raise TypeError(f"Expected a bool, list[bool] or None for `nulls_first` but got {type(nulls_first)}") + return Series._from_pyseries(self._table.argsort(pyexprs, descending, nulls_first)) def __reduce__(self) -> tuple: names = self.column_names() diff --git a/src/daft-core/src/array/ops/list.rs b/src/daft-core/src/array/ops/list.rs index 660df6bb49..2e60efa550 100644 --- a/src/daft-core/src/array/ops/list.rs +++ b/src/daft-core/src/array/ops/list.rs @@ -205,16 +205,18 @@ fn list_sort_helper( flat_child: &Series, offsets: &OffsetsBuffer, desc_iter: impl Iterator, + nulls_first_iter: impl Iterator, validity: impl Iterator, ) -> DaftResult> { desc_iter + .zip(nulls_first_iter) .zip(validity) .enumerate() - .map(|(i, (desc, valid))| { + .map(|(i, ((desc, nulls_first), valid))| { let start = *offsets.get(i).unwrap() as usize; let end = *offsets.get(i + 1).unwrap() as usize; if valid { - flat_child.slice(start, end)?.sort(desc) + flat_child.slice(start, end)?.sort(desc, nulls_first) } else { Ok(Series::full_null( flat_child.name(), @@ -237,11 +239,11 @@ fn list_sort_helper_fixed_size( .zip(nulls_first_iter) .zip(validity) .enumerate() - .map(|(i, ((desc,, valid))| { + .map(|(i, ((desc, nulls_first), valid))| { let start = i * fixed_size; let end = (i + 1) * fixed_size; if valid { - flat_child.slice(start, end)?.sort(desc) + flat_child.slice(start, end)?.sort(desc, nulls_first) } else { Ok(Series::full_null( flat_child.name(), @@ -570,29 +572,45 @@ impl ListArray { } // Sorts the lists within a list column - pub fn list_sort(&self, desc: &BooleanArray) -> DaftResult { + pub fn list_sort(&self, desc: &BooleanArray, nulls_first: &BooleanArray) -> DaftResult { let offsets = self.offsets(); let child_series = if desc.len() == 1 { let desc_iter = repeat(desc.get(0).unwrap()).take(self.len()); + let nulls_first_iter = repeat(nulls_first.get(0).unwrap()).take(self.len()); if let Some(validity) = self.validity() { - list_sort_helper(&self.flat_child, offsets, desc_iter, validity.iter())? + list_sort_helper( + &self.flat_child, + offsets, + desc_iter, + nulls_first_iter, + validity.iter(), + )? } else { list_sort_helper( &self.flat_child, offsets, desc_iter, + nulls_first_iter, repeat(true).take(self.len()), )? } } else { let desc_iter = desc.as_arrow().values_iter(); + let nulls_first_iter = nulls_first.as_arrow().values_iter(); if let Some(validity) = self.validity() { - list_sort_helper(&self.flat_child, offsets, desc_iter, validity.iter())? + list_sort_helper( + &self.flat_child, + offsets, + desc_iter, + nulls_first_iter, + validity.iter(), + )? } else { list_sort_helper( &self.flat_child, offsets, desc_iter, + nulls_first_iter, repeat(true).take(self.len()), )? } @@ -791,16 +809,18 @@ impl FixedSizeListArray { } // Sorts the lists within a list column - pub fn list_sort(&self, desc: &BooleanArray) -> DaftResult { + pub fn list_sort(&self, desc: &BooleanArray, nulls_first: &BooleanArray) -> DaftResult { let fixed_size = self.fixed_element_len(); let child_series = if desc.len() == 1 { let desc_iter = repeat(desc.get(0).unwrap()).take(self.len()); + let nulls_first_iter = repeat(nulls_first.get(0).unwrap()).take(self.len()); if let Some(validity) = self.validity() { list_sort_helper_fixed_size( &self.flat_child, fixed_size, desc_iter, + nulls_first_iter, validity.iter(), )? } else { @@ -808,16 +828,19 @@ impl FixedSizeListArray { &self.flat_child, fixed_size, desc_iter, + nulls_first_iter, repeat(true).take(self.len()), )? } } else { let desc_iter = desc.as_arrow().values_iter(); + let nulls_first_iter = nulls_first.as_arrow().values_iter(); if let Some(validity) = self.validity() { list_sort_helper_fixed_size( &self.flat_child, fixed_size, desc_iter, + nulls_first_iter, validity.iter(), )? } else { @@ -825,6 +848,7 @@ impl FixedSizeListArray { &self.flat_child, fixed_size, desc_iter, + nulls_first_iter, repeat(true).take(self.len()), )? } diff --git a/src/daft-core/src/array/ops/sort.rs b/src/daft-core/src/array/ops/sort.rs index fa3ccad594..22bfec02ab 100644 --- a/src/daft-core/src/array/ops/sort.rs +++ b/src/daft-core/src/array/ops/sort.rs @@ -2,7 +2,7 @@ use arrow2::{ array::ord::{self, DynComparator}, types::Index, }; -use common_error::DaftResult; +use common_error::{DaftError, DaftResult}; use super::{arrow2::sort::primitive::common::multi_column_idx_sort, as_arrow::AsArrow}; #[cfg(feature = "python")] @@ -62,11 +62,16 @@ where T: DaftIntegerType, ::Native: Ord, { - pub fn argsort(&self, descending: bool) -> DaftResult> + pub fn argsort(&self, descending: bool, nulls_first: bool) -> DaftResult> where I: DaftIntegerType, ::Native: arrow2::types::Index, { + if nulls_first != descending { + return Err(DaftError::NotImplemented( + "nulls_first is not implemented".to_string(), + )); + } let arrow_array = self.as_arrow(); let result = @@ -83,11 +88,17 @@ where &self, others: &[Series], descending: &[bool], + nulls_first: &[bool], ) -> DaftResult> where I: DaftIntegerType, ::Native: arrow2::types::Index, { + if nulls_first.iter().zip(descending).any(|(a, b)| *a != *b) { + return Err(DaftError::not_implemented( + "nulls first is not yet implemented", + )); + } let arrow_array = self.as_arrow(); let first_desc = *descending.first().unwrap(); @@ -134,10 +145,10 @@ where Ok(DataArray::::from((self.name(), Box::new(result)))) } - pub fn sort(&self, descending: bool) -> DaftResult { + pub fn sort(&self, descending: bool, nulls_first: bool) -> DaftResult { let options = arrow2::compute::sort::SortOptions { descending, - nulls_first: descending, + nulls_first, }; let arrow_array = self.as_arrow(); @@ -154,11 +165,16 @@ where } impl Float32Array { - pub fn argsort(&self, descending: bool) -> DaftResult> + pub fn argsort(&self, descending: bool, nulls_first: bool) -> DaftResult> where I: DaftIntegerType, ::Native: arrow2::types::Index, { + if nulls_first != descending { + return Err(DaftError::NotImplemented( + "nulls_first is not implemented".to_string(), + )); + } let arrow_array = self.as_arrow(); let result = @@ -175,11 +191,17 @@ impl Float32Array { &self, others: &[Series], descending: &[bool], + nulls_first: &[bool], ) -> DaftResult> where I: DaftIntegerType, ::Native: arrow2::types::Index, { + if nulls_first.iter().zip(descending).any(|(a, b)| a != b) { + return Err(DaftError::NotImplemented( + "nulls_first is not implemented".to_string(), + )); + } let arrow_array = self.as_arrow(); let first_desc = *descending.first().unwrap(); @@ -226,10 +248,10 @@ impl Float32Array { Ok(DataArray::::from((self.name(), Box::new(result)))) } - pub fn sort(&self, descending: bool) -> DaftResult { + pub fn sort(&self, descending: bool, nulls_first: bool) -> DaftResult { let options = arrow2::compute::sort::SortOptions { descending, - nulls_first: descending, + nulls_first, }; let arrow_array = self.as_arrow(); @@ -246,11 +268,16 @@ impl Float32Array { } impl Float64Array { - pub fn argsort(&self, descending: bool) -> DaftResult> + pub fn argsort(&self, descending: bool, nulls_first: bool) -> DaftResult> where I: DaftIntegerType, ::Native: arrow2::types::Index, { + if nulls_first != descending { + return Err(DaftError::not_implemented( + "nulls first is not yet implemented", + )); + } let arrow_array = self.as_arrow(); let result = @@ -267,11 +294,18 @@ impl Float64Array { &self, others: &[Series], descending: &[bool], + nulls_first: &[bool], ) -> DaftResult> where I: DaftIntegerType, ::Native: arrow2::types::Index, { + if nulls_first.iter().zip(descending).any(|(a, b)| *a != *b) { + return Err(DaftError::not_implemented( + "nulls first is not yet implemented", + )); + } + let arrow_array = self.as_arrow(); let first_desc = *descending.first().unwrap(); @@ -318,10 +352,10 @@ impl Float64Array { Ok(DataArray::::from((self.name(), Box::new(result)))) } - pub fn sort(&self, descending: bool) -> DaftResult { + pub fn sort(&self, descending: bool, nulls_first: bool) -> DaftResult { let options = arrow2::compute::sort::SortOptions { descending, - nulls_first: descending, + nulls_first, }; let arrow_array = self.as_arrow(); @@ -338,11 +372,16 @@ impl Float64Array { } impl Decimal128Array { - pub fn argsort(&self, descending: bool) -> DaftResult> + pub fn argsort(&self, descending: bool, nulls_first: bool) -> DaftResult> where I: DaftIntegerType, ::Native: arrow2::types::Index, { + if nulls_first != descending { + return Err(DaftError::not_implemented( + "nulls first is not yet implemented", + )); + } let arrow_array = self.as_arrow(); let result = @@ -359,11 +398,17 @@ impl Decimal128Array { &self, others: &[Series], descending: &[bool], + nulls_first: &[bool], ) -> DaftResult> where I: DaftIntegerType, ::Native: arrow2::types::Index, { + if nulls_first.iter().zip(descending).any(|(a, b)| *a != *b) { + return Err(DaftError::not_implemented( + "nulls first is not yet implemented", + )); + } let arrow_array = self.as_arrow(); let first_desc = *descending.first().unwrap(); @@ -410,10 +455,10 @@ impl Decimal128Array { Ok(DataArray::::from((self.name(), Box::new(result)))) } - pub fn sort(&self, descending: bool) -> DaftResult { + pub fn sort(&self, descending: bool, nulls_first: bool) -> DaftResult { let options = arrow2::compute::sort::SortOptions { descending, - nulls_first: descending, + nulls_first, }; let arrow_array = self.as_arrow(); @@ -430,7 +475,7 @@ impl Decimal128Array { } impl NullArray { - pub fn argsort(&self, _descending: bool) -> DaftResult> + pub fn argsort(&self, _descending: bool, _nulls_first: bool) -> DaftResult> where I: DaftIntegerType, ::Native: arrow2::types::Index, @@ -442,11 +487,17 @@ impl NullArray { &self, others: &[Series], descending: &[bool], + nulls_first: &[bool], ) -> DaftResult> where I: DaftIntegerType, ::Native: arrow2::types::Index, { + if nulls_first.iter().zip(descending).any(|(a, b)| *a != *b) { + return Err(DaftError::not_implemented( + "nulls first is not yet implemented", + )); + } let first_desc = *descending.first().unwrap(); let others_cmp = build_multi_array_compare(others, &descending[1..])?; @@ -466,20 +517,20 @@ impl NullArray { Ok(DataArray::::from((self.name(), Box::new(result)))) } - pub fn sort(&self, _descending: bool) -> DaftResult { + pub fn sort(&self, _descending: bool, _nulls_first: bool) -> DaftResult { Ok(self.clone()) } } impl BooleanArray { - pub fn argsort(&self, descending: bool) -> DaftResult> + pub fn argsort(&self, descending: bool, nulls_first: bool) -> DaftResult> where I: DaftIntegerType, ::Native: arrow2::types::Index, { let options = arrow2::compute::sort::SortOptions { descending, - nulls_first: descending, + nulls_first, }; let result = @@ -492,11 +543,17 @@ impl BooleanArray { &self, others: &[Series], descending: &[bool], + nulls_first: &[bool], ) -> DaftResult> where I: DaftIntegerType, ::Native: arrow2::types::Index, { + if nulls_first.iter().zip(descending).any(|(a, b)| *a != *b) { + return Err(DaftError::not_implemented( + "nulls first is not yet implemented", + )); + } let first_desc = *descending.first().unwrap(); let others_cmp = build_multi_array_compare(others, &descending[1..])?; @@ -547,10 +604,10 @@ impl BooleanArray { Ok(DataArray::::from((self.name(), Box::new(result)))) } - pub fn sort(&self, descending: bool) -> DaftResult { + pub fn sort(&self, descending: bool, nulls_first: bool) -> DaftResult { let options = arrow2::compute::sort::SortOptions { descending, - nulls_first: descending, + nulls_first, }; let result = arrow2::compute::sort::sort(self.data(), &options, None)?; @@ -562,14 +619,18 @@ impl BooleanArray { macro_rules! impl_binary_like_sort { ($da:ident) => { impl $da { - pub fn argsort(&self, descending: bool) -> DaftResult> + pub fn argsort( + &self, + descending: bool, + nulls_first: bool, + ) -> DaftResult> where I: DaftIntegerType, ::Native: arrow2::types::Index, { let options = arrow2::compute::sort::SortOptions { descending, - nulls_first: descending, + nulls_first, }; let result = arrow2::compute::sort::sort_to_indices::( @@ -585,11 +646,17 @@ macro_rules! impl_binary_like_sort { &self, others: &[Series], descending: &[bool], + nulls_first: &[bool], ) -> DaftResult> where I: DaftIntegerType, ::Native: arrow2::types::Index, { + if nulls_first.iter().zip(descending).any(|(a, b)| *a != *b) { + return Err(DaftError::not_implemented( + "nulls first is not yet implemented", + )); + } let first_desc = *descending.first().unwrap(); let others_cmp = build_multi_array_compare(others, &descending[1..])?; @@ -635,10 +702,10 @@ macro_rules! impl_binary_like_sort { Ok(DataArray::::from((self.name(), Box::new(result)))) } - pub fn sort(&self, descending: bool) -> DaftResult { + pub fn sort(&self, descending: bool, nulls_first: bool) -> DaftResult { let options = arrow2::compute::sort::SortOptions { descending, - nulls_first: descending, + nulls_first, }; let result = arrow2::compute::sort::sort(self.data(), &options, None)?; @@ -653,7 +720,7 @@ impl_binary_like_sort!(BinaryArray); impl_binary_like_sort!(Utf8Array); impl FixedSizeBinaryArray { - pub fn argsort(&self, _descending: bool) -> DaftResult> + pub fn argsort(&self, _descending: bool, _nulls_first: bool) -> DaftResult> where I: DaftIntegerType, ::Native: arrow2::types::Index, @@ -664,6 +731,7 @@ impl FixedSizeBinaryArray { &self, _others: &[Series], _descending: &[bool], + _nulls_first: &[bool], ) -> DaftResult> where I: DaftIntegerType, @@ -671,120 +739,120 @@ impl FixedSizeBinaryArray { { todo!("impl argsort_multikey for FixedSizeBinaryArray") } - pub fn sort(&self, _descending: bool) -> DaftResult { + pub fn sort(&self, _descending: bool, _nulls_first: bool) -> DaftResult { todo!("impl sort for FixedSizeBinaryArray") } } impl FixedSizeListArray { - pub fn sort(&self, _descending: bool) -> DaftResult { + pub fn sort(&self, _descending: bool, _nulls_first: bool) -> DaftResult { todo!("impl sort for FixedSizeListArray") } } impl ListArray { - pub fn sort(&self, _descending: bool) -> DaftResult { + pub fn sort(&self, _descending: bool, _nulls_first: bool) -> DaftResult { todo!("impl sort for ListArray") } } impl MapArray { - pub fn sort(&self, _descending: bool) -> DaftResult { + pub fn sort(&self, _descending: bool, _nulls_first: bool) -> DaftResult { todo!("impl sort for MapArray") } } impl StructArray { - pub fn sort(&self, _descending: bool) -> DaftResult { + pub fn sort(&self, _descending: bool, _nulls_first: bool) -> DaftResult { todo!("impl sort for StructArray") } } impl ExtensionArray { - pub fn sort(&self, _descending: bool) -> DaftResult { + pub fn sort(&self, _descending: bool, _nulls_first: bool) -> DaftResult { todo!("impl sort for ExtensionArray") } } impl IntervalArray { - pub fn sort(&self, _descending: bool) -> DaftResult { + pub fn sort(&self, _descending: bool, _nulls_first: bool) -> DaftResult { todo!("impl sort for IntervalArray") } } #[cfg(feature = "python")] impl PythonArray { - pub fn sort(&self, _descending: bool) -> DaftResult { + pub fn sort(&self, _descending: bool, _nulls_first: bool) -> DaftResult { todo!("impl sort for python array") } } impl DateArray { - pub fn sort(&self, descending: bool) -> DaftResult { - let new_array = self.physical.sort(descending)?; + pub fn sort(&self, descending: bool, nulls_first: bool) -> DaftResult { + let new_array = self.physical.sort(descending, nulls_first)?; Ok(Self::new(self.field.clone(), new_array)) } } impl TimeArray { - pub fn sort(&self, descending: bool) -> DaftResult { - let new_array = self.physical.sort(descending)?; + pub fn sort(&self, descending: bool, nulls_first: bool) -> DaftResult { + let new_array = self.physical.sort(descending, nulls_first)?; Ok(Self::new(self.field.clone(), new_array)) } } impl DurationArray { - pub fn sort(&self, descending: bool) -> DaftResult { - let new_array = self.physical.sort(descending)?; + pub fn sort(&self, descending: bool, nulls_first: bool) -> DaftResult { + let new_array = self.physical.sort(descending, nulls_first)?; Ok(Self::new(self.field.clone(), new_array)) } } impl TimestampArray { - pub fn sort(&self, descending: bool) -> DaftResult { - let new_array = self.physical.sort(descending)?; + pub fn sort(&self, descending: bool, nulls_first: bool) -> DaftResult { + let new_array = self.physical.sort(descending, nulls_first)?; Ok(Self::new(self.field.clone(), new_array)) } } impl EmbeddingArray { - pub fn sort(&self, _descending: bool) -> DaftResult { + pub fn sort(&self, _descending: bool, _nulls_first: bool) -> DaftResult { todo!("impl sort for EmbeddingArray") } } impl ImageArray { - pub fn sort(&self, _descending: bool) -> DaftResult { + pub fn sort(&self, _descending: bool, _nulls_first: bool) -> DaftResult { todo!("impl sort for ImageArray") } } impl FixedShapeImageArray { - pub fn sort(&self, _descending: bool) -> DaftResult { + pub fn sort(&self, _descending: bool, _nulls_first: bool) -> DaftResult { todo!("impl sort for FixedShapeImageArray") } } impl TensorArray { - pub fn sort(&self, _descending: bool) -> DaftResult { + pub fn sort(&self, _descending: bool, _nulls_first: bool) -> DaftResult { todo!("impl sort for TensorArray") } } impl SparseTensorArray { - pub fn sort(&self, _descending: bool) -> DaftResult { + pub fn sort(&self, _descending: bool, _nulls_first: bool) -> DaftResult { todo!("impl sort for SparseTensorArray") } } impl FixedShapeSparseTensorArray { - pub fn sort(&self, _descending: bool) -> DaftResult { + pub fn sort(&self, _descending: bool, _nulls_first: bool) -> DaftResult { todo!("impl sort for FixedShapeSparseTensorArray") } } impl FixedShapeTensorArray { - pub fn sort(&self, _descending: bool) -> DaftResult { + pub fn sort(&self, _descending: bool, _nulls_first: bool) -> DaftResult { todo!("impl sort for FixedShapeTensorArray") } } diff --git a/src/daft-core/src/python/series.rs b/src/daft-core/src/python/series.rs index 3aa78a66c0..f8f1b3002e 100644 --- a/src/daft-core/src/python/series.rs +++ b/src/daft-core/src/python/series.rs @@ -696,8 +696,11 @@ impl PySeries { Ok(self.series.list_slice(&start.series, &end.series)?.into()) } - pub fn list_sort(&self, desc: &Self) -> PyResult { - Ok(self.series.list_sort(&desc.series)?.into()) + pub fn list_sort(&self, desc: &Self, nulls_first: &Self) -> PyResult { + Ok(self + .series + .list_sort(&desc.series, &nulls_first.series)? + .into()) } pub fn map_get(&self, key: &Self) -> PyResult { diff --git a/src/daft-core/src/series/array_impl/data_array.rs b/src/daft-core/src/series/array_impl/data_array.rs index 1efa60f3a7..a00a51d219 100644 --- a/src/daft-core/src/series/array_impl/data_array.rs +++ b/src/daft-core/src/series/array_impl/data_array.rs @@ -118,8 +118,8 @@ macro_rules! impl_series_like_for_data_array { fn slice(&self, start: usize, end: usize) -> DaftResult { Ok(self.0.slice(start, end)?.into_series()) } - fn sort(&self, descending: bool) -> DaftResult { - Ok(self.0.sort(descending)?.into_series()) + fn sort(&self, descending: bool, nulls_first: bool) -> DaftResult { + Ok(self.0.sort(descending, nulls_first)?.into_series()) } fn str_value(&self, idx: usize) -> DaftResult { self.0.str_value(idx) diff --git a/src/daft-core/src/series/array_impl/logical_array.rs b/src/daft-core/src/series/array_impl/logical_array.rs index 85316f0ec4..07df8923f2 100644 --- a/src/daft-core/src/series/array_impl/logical_array.rs +++ b/src/daft-core/src/series/array_impl/logical_array.rs @@ -116,8 +116,8 @@ macro_rules! impl_series_like_for_logical_array { Ok($da::new(self.0.field.clone(), new_array).into_series()) } - fn sort(&self, descending: bool) -> DaftResult { - Ok(self.0.sort(descending)?.into_series()) + fn sort(&self, descending: bool, nulls_first: bool) -> DaftResult { + Ok(self.0.sort(descending, nulls_first)?.into_series()) } fn str_value(&self, idx: usize) -> DaftResult { diff --git a/src/daft-core/src/series/array_impl/nested_array.rs b/src/daft-core/src/series/array_impl/nested_array.rs index 1bd618e616..be60ee2f26 100644 --- a/src/daft-core/src/series/array_impl/nested_array.rs +++ b/src/daft-core/src/series/array_impl/nested_array.rs @@ -121,7 +121,7 @@ macro_rules! impl_series_like_for_nested_arrays { Ok(self.0.not_null()?.into_series()) } - fn sort(&self, _descending: bool) -> DaftResult { + fn sort(&self, _descending: bool, _nulls_first: bool) -> DaftResult { Err(DaftError::ValueError(format!( "Cannot sort a {}", stringify!($da) diff --git a/src/daft-core/src/series/ops/list.rs b/src/daft-core/src/series/ops/list.rs index 81a4788067..7ed940ca3d 100644 --- a/src/daft-core/src/series/ops/list.rs +++ b/src/daft-core/src/series/ops/list.rs @@ -159,14 +159,16 @@ impl Series { } } - pub fn list_sort(&self, desc: &Self) -> DaftResult { + pub fn list_sort(&self, desc: &Self, nulls_first: &Self) -> DaftResult { let desc_arr = desc.bool()?; + let nulls_first = nulls_first.bool()?; match self.data_type() { - DataType::List(_) => Ok(self.list()?.list_sort(desc_arr)?.into_series()), - DataType::FixedSizeList(..) => { - Ok(self.fixed_size_list()?.list_sort(desc_arr)?.into_series()) - } + DataType::List(_) => Ok(self.list()?.list_sort(desc_arr, nulls_first)?.into_series()), + DataType::FixedSizeList(..) => Ok(self + .fixed_size_list()? + .list_sort(desc_arr, nulls_first)? + .into_series()), dt => Err(DaftError::TypeError(format!( "List sort not implemented for {}", dt diff --git a/src/daft-core/src/series/ops/sort.rs b/src/daft-core/src/series/ops/sort.rs index 6212a34c25..2fd3007faa 100644 --- a/src/daft-core/src/series/ops/sort.rs +++ b/src/daft-core/src/series/ops/sort.rs @@ -15,7 +15,7 @@ impl Series { let series = self.as_physical()?; with_match_comparable_daft_types!(series.data_type(), |$T| { let downcasted = series.downcast::<<$T as DaftDataType>::ArrayType>()?; - Ok(downcasted.argsort::(descending)?.into_series()) + Ok(downcasted.argsort::(descending, nulls_first)?.into_series()) }) } @@ -51,7 +51,7 @@ impl Series { let first = sort_keys.first().unwrap().as_physical()?; with_match_comparable_daft_types!(first.data_type(), |$T| { let downcasted = first.downcast::<<$T as DaftDataType>::ArrayType>()?; - let result = downcasted.argsort_multikey::(&sort_keys[1..], descending)?; + let result = downcasted.argsort_multikey::(&sort_keys[1..], descending, nulls_first)?; Ok(result.into_series()) }) } diff --git a/src/daft-functions/src/list/sort.rs b/src/daft-functions/src/list/sort.rs index 2d1ef45afb..b4b5a77814 100644 --- a/src/daft-functions/src/list/sort.rs +++ b/src/daft-functions/src/list/sort.rs @@ -21,7 +21,7 @@ impl ScalarUDF for ListSort { fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { match inputs { - [data, desc] => match (data.to_field(schema), desc.to_field(schema)) { + [data, desc, _nulls_first] => match (data.to_field(schema), desc.to_field(schema)) { (Ok(field), Ok(desc_field)) => match (&field.dtype, &desc_field.dtype) { ( l @ (DataType::List(_) | DataType::FixedSizeList(_, _)), @@ -34,7 +34,7 @@ impl ScalarUDF for ListSort { (Err(e), _) | (_, Err(e)) => Err(e), }, _ => Err(DaftError::SchemaMismatch(format!( - "Expected 2 input args, got {}", + "Expected 3 input args, got {}", inputs.len() ))), } @@ -42,9 +42,9 @@ impl ScalarUDF for ListSort { fn evaluate(&self, inputs: &[Series]) -> DaftResult { match inputs { - [data, desc] => data.list_sort(desc), + [data, desc, nulls_first] => data.list_sort(desc, nulls_first), _ => Err(DaftError::ValueError(format!( - "Expected 2 input args, got {}", + "Expected 3 input args, got {}", inputs.len() ))), } @@ -52,9 +52,10 @@ impl ScalarUDF for ListSort { } #[must_use] -pub fn list_sort(input: ExprRef, desc: Option) -> ExprRef { +pub fn list_sort(input: ExprRef, desc: Option, nulls_first: Option) -> ExprRef { let desc = desc.unwrap_or_else(|| lit(false)); - ScalarFunction::new(ListSort {}, vec![input, desc]).into() + let nulls_first = nulls_first.unwrap_or_else(|| desc.clone()); + ScalarFunction::new(ListSort {}, vec![input, desc, nulls_first]).into() } #[cfg(feature = "python")] @@ -66,6 +67,6 @@ use { #[cfg(feature = "python")] #[pyfunction] #[pyo3(name = "list_sort")] -pub fn py_list_sort(expr: PyExpr, desc: PyExpr) -> PyResult { - Ok(list_sort(expr.into(), Some(desc.into())).into()) +pub fn py_list_sort(expr: PyExpr, desc: PyExpr, nulls_first: PyExpr) -> PyResult { + Ok(list_sort(expr.into(), Some(desc.into()), Some(nulls_first.into())).into()) } diff --git a/src/daft-local-execution/src/pipeline.rs b/src/daft-local-execution/src/pipeline.rs index e213e77e0a..44fed780ac 100644 --- a/src/daft-local-execution/src/pipeline.rs +++ b/src/daft-local-execution/src/pipeline.rs @@ -314,9 +314,10 @@ pub fn physical_plan_to_pipeline( input, sort_by, descending, + nulls_first, .. }) => { - let sort_sink = SortSink::new(sort_by.clone(), descending.clone()); + let sort_sink = SortSink::new(sort_by.clone(), descending.clone(), nulls_first.clone()); let child_node = physical_plan_to_pipeline(input, psets, cfg)?; BlockingSinkNode::new(Arc::new(sort_sink), child_node).boxed() } diff --git a/src/daft-local-execution/src/sinks/sort.rs b/src/daft-local-execution/src/sinks/sort.rs index 83c933d1ec..c6ba5f991d 100644 --- a/src/daft-local-execution/src/sinks/sort.rs +++ b/src/daft-local-execution/src/sinks/sort.rs @@ -41,13 +41,15 @@ impl BlockingSinkState for SortState { pub struct SortSink { sort_by: Vec, descending: Vec, + nulls_first: Vec, } impl SortSink { - pub fn new(sort_by: Vec, descending: Vec) -> Self { + pub fn new(sort_by: Vec, descending: Vec, nulls_first: Vec) -> Self { Self { sort_by, descending, + nulls_first, } } } @@ -80,7 +82,7 @@ impl BlockingSink for SortSink { state.finalize() }); let concated = MicroPartition::concat(parts)?; - let sorted = Arc::new(concated.sort(&self.sort_by, &self.descending)?); + let sorted = Arc::new(concated.sort(&self.sort_by, &self.descending, &self.nulls_first)?); Ok(Some(sorted.into())) } diff --git a/src/daft-logical-plan/src/builder.rs b/src/daft-logical-plan/src/builder.rs index f2d1a1f03b..c251537ff2 100644 --- a/src/daft-logical-plan/src/builder.rs +++ b/src/daft-logical-plan/src/builder.rs @@ -705,7 +705,12 @@ impl PyLogicalPlanBuilder { .into()) } - pub fn sort(&self, sort_by: Vec, descending: Vec, nulls_first: Vec) -> PyResult { + pub fn sort( + &self, + sort_by: Vec, + descending: Vec, + nulls_first: Vec, + ) -> PyResult { Ok(self .builder .sort(pyexprs_to_exprs(sort_by), descending, nulls_first)? diff --git a/src/daft-logical-plan/src/ops/sort.rs b/src/daft-logical-plan/src/ops/sort.rs index c6d2e9009f..046089c556 100644 --- a/src/daft-logical-plan/src/ops/sort.rs +++ b/src/daft-logical-plan/src/ops/sort.rs @@ -62,7 +62,15 @@ impl Sort { .sort_by .iter() .zip(self.descending.iter()) - .map(|(sb, d)| format!("({}, {})", sb, if *d { "descending" } else { "ascending" },)) + .zip(self.nulls_first.iter()) + .map(|((sb, d), nf)| { + format!( + "({}, {}, {})", + sb, + if *d { "descending" } else { "ascending" }, + if *nf { "nulls first" } else { "nulls last" } + ) + }) .join(", "); res.push(format!("Sort: Sort by = {}", pairs)); res diff --git a/src/daft-logical-plan/src/optimization/rules/push_down_filter.rs b/src/daft-logical-plan/src/optimization/rules/push_down_filter.rs index 9e38a0f8c9..2cf2ea14ad 100644 --- a/src/daft-logical-plan/src/optimization/rules/push_down_filter.rs +++ b/src/daft-logical-plan/src/optimization/rules/push_down_filter.rs @@ -567,7 +567,9 @@ mod tests { } else { scan_plan.filter(pred)? }; - let expected = expected_filter_scan.sort(sort_by, descending, nulls_first)?.build(); + let expected = expected_filter_scan + .sort(sort_by, descending, nulls_first)? + .build(); assert_optimized_plan_eq(plan, expected)?; Ok(()) } diff --git a/src/daft-micropartition/src/ops/sort.rs b/src/daft-micropartition/src/ops/sort.rs index cf2a54aca1..f2be0bb38e 100644 --- a/src/daft-micropartition/src/ops/sort.rs +++ b/src/daft-micropartition/src/ops/sort.rs @@ -32,16 +32,21 @@ impl MicroPartition { } } - pub fn argsort(&self, sort_keys: &[ExprRef], descending: &[bool]) -> DaftResult { + pub fn argsort( + &self, + sort_keys: &[ExprRef], + descending: &[bool], + nulls_first: &[bool], + ) -> DaftResult { let io_stats = IOStatsContext::new("MicroPartition::argsort"); let tables = self.concat_or_get(io_stats)?; match tables.as_slice() { [] => { let empty_table = Table::empty(Some(self.schema.clone()))?; - empty_table.argsort(sort_keys, descending) + empty_table.argsort(sort_keys, descending, nulls_first) } - [single] => single.argsort(sort_keys, descending), + [single] => single.argsort(sort_keys, descending, nulls_first), _ => unreachable!(), } } diff --git a/src/daft-micropartition/src/python.rs b/src/daft-micropartition/src/python.rs index 1c3d0cb070..7de380dd93 100644 --- a/src/daft-micropartition/src/python.rs +++ b/src/daft-micropartition/src/python.rs @@ -207,6 +207,7 @@ impl PyMicroPartition { py: Python, sort_keys: Vec, descending: Vec, + nulls_first: Vec, ) -> PyResult { let converted_exprs: Vec = sort_keys .into_iter() @@ -215,7 +216,11 @@ impl PyMicroPartition { py.allow_threads(|| { Ok(self .inner - .argsort(converted_exprs.as_slice(), descending.as_slice())? + .argsort( + converted_exprs.as_slice(), + descending.as_slice(), + nulls_first.as_slice(), + )? .into()) }) } diff --git a/src/daft-physical-plan/src/ops/sort.rs b/src/daft-physical-plan/src/ops/sort.rs index 9ba2809907..777f61043d 100644 --- a/src/daft-physical-plan/src/ops/sort.rs +++ b/src/daft-physical-plan/src/ops/sort.rs @@ -39,7 +39,15 @@ impl Sort { .sort_by .iter() .zip(self.descending.iter()) - .map(|(sb, d)| format!("({}, {})", sb, if *d { "descending" } else { "ascending" },)) + .zip(self.nulls_first.iter()) + .map(|((sb, d), nf)| { + format!( + "({}, {}, {})", + sb, + if *d { "descending" } else { "ascending" }, + if *nf { "nulls first" } else { "nulls last" } + ) + }) .join(", "); res.push(format!("Sort: Sort by = {}", pairs)); res.push(format!("Num partitions = {}", self.num_partitions)); diff --git a/src/daft-sql/src/modules/list.rs b/src/daft-sql/src/modules/list.rs index bd6db25990..8523863aac 100644 --- a/src/daft-sql/src/modules/list.rs +++ b/src/daft-sql/src/modules/list.rs @@ -304,7 +304,7 @@ impl SQLFunction for SQLListSort { match inputs { [input] => { let input = planner.plan_function_arg(input)?; - Ok(daft_functions::list::sort(input, None)) + Ok(daft_functions::list::sort(input, None, None)) } [input, order] => { let input = planner.plan_function_arg(input)?; @@ -323,7 +323,7 @@ impl SQLFunction for SQLListSort { } _ => unsupported_sql_err!("invalid order for list_sort"), }; - Ok(daft_functions::list::sort(input, Some(order))) + Ok(daft_functions::list::sort(input, Some(order), None)) } _ => unsupported_sql_err!( "invalid arguments for list_sort. Expected list_sort(expr, ASC|DESC)" diff --git a/src/daft-sql/src/planner.rs b/src/daft-sql/src/planner.rs index cd1109341d..8299a69b66 100644 --- a/src/daft-sql/src/planner.rs +++ b/src/daft-sql/src/planner.rs @@ -692,7 +692,6 @@ impl SQLPlanner { unsupported_sql_err!("WITH FILL"); } let expr = self.plan_expr(&order_by_expr.expr)?; - desc.push(!order_by_expr.asc.unwrap_or(true)); exprs.push(expr); } diff --git a/src/daft-table/src/ops/groups.rs b/src/daft-table/src/ops/groups.rs index 580d2c4288..7239c5a35c 100644 --- a/src/daft-table/src/ops/groups.rs +++ b/src/daft-table/src/ops/groups.rs @@ -54,8 +54,11 @@ impl Table { // ) // Begin by doing the argsort. - let argsort_series = - Series::argsort_multikey(self.columns.as_slice(), &vec![false; self.columns.len()])?; + let argsort_series = Series::argsort_multikey( + self.columns.as_slice(), + &vec![false; self.columns.len()], + &vec![false; self.columns.len()], + )?; let argsort_array = argsort_series.downcast::()?; // The result indices. diff --git a/src/daft-table/src/ops/joins/mod.rs b/src/daft-table/src/ops/joins/mod.rs index 97c034a9df..f66de100ec 100644 --- a/src/daft-table/src/ops/joins/mod.rs +++ b/src/daft-table/src/ops/joins/mod.rs @@ -130,6 +130,10 @@ impl Table { .take(left_on.len()) .collect::>() .as_slice(), + std::iter::repeat(false) + .take(left_on.len()) + .collect::>() + .as_slice(), )?; if right_on.is_empty() { return Err(DaftError::ValueError( @@ -142,6 +146,10 @@ impl Table { .take(right_on.len()) .collect::>() .as_slice(), + std::iter::repeat(false) + .take(right_on.len()) + .collect::>() + .as_slice(), )?; return left.sort_merge_join(&right, left_on, right_on, true); diff --git a/src/daft-table/src/ops/sort.rs b/src/daft-table/src/ops/sort.rs index b56157ef06..540013a02e 100644 --- a/src/daft-table/src/ops/sort.rs +++ b/src/daft-table/src/ops/sort.rs @@ -30,7 +30,7 @@ impl Table { } if sort_keys.len() == 1 { self.eval_expression(sort_keys.first().unwrap())? - .argsort(*descending.first().unwrap(), nulls_first.first().copied() + .argsort(*descending.first().unwrap(), *nulls_first.first().unwrap()) } else { let expr_result = self.eval_expression_list(sort_keys)?; Series::argsort_multikey(expr_result.columns.as_slice(), descending, nulls_first) diff --git a/src/daft-table/src/python.rs b/src/daft-table/src/python.rs index 54e9db65f6..c2c0cc622a 100644 --- a/src/daft-table/src/python.rs +++ b/src/daft-table/src/python.rs @@ -54,6 +54,7 @@ impl PyTable { py: Python, sort_keys: Vec, descending: Vec, + nulls_first: Vec, ) -> PyResult { let converted_exprs: Vec = sort_keys .into_iter() @@ -62,7 +63,11 @@ impl PyTable { py.allow_threads(|| { Ok(self .table - .sort(converted_exprs.as_slice(), descending.as_slice())? + .sort( + converted_exprs.as_slice(), + descending.as_slice(), + nulls_first.as_slice(), + )? .into()) }) } @@ -72,6 +77,7 @@ impl PyTable { py: Python, sort_keys: Vec, descending: Vec, + nulls_first: Vec, ) -> PyResult { let converted_exprs: Vec = sort_keys .into_iter() @@ -80,7 +86,11 @@ impl PyTable { py.allow_threads(|| { Ok(self .table - .argsort(converted_exprs.as_slice(), descending.as_slice())? + .argsort( + converted_exprs.as_slice(), + descending.as_slice(), + nulls_first.as_slice(), + )? .into()) }) }