Skip to content

Commit

Permalink
guarantee order if order_by directly precedes collect in a chain (#558)
Browse files Browse the repository at this point in the history
  • Loading branch information
mattseddon authored Nov 5, 2024
1 parent d1838c1 commit 4f87dc8
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 7 deletions.
7 changes: 4 additions & 3 deletions src/datachain/lib/dc.py
Original file line number Diff line number Diff line change
Expand Up @@ -995,7 +995,8 @@ def order_by(self, *args, descending: bool = False) -> "Self":
Order is not guaranteed when steps are added after an `order_by` statement.
I.e. when using `from_dataset` an `order_by` statement should be used if
the order of the records in the chain is important.
Using `order_by` directly before `limit` will give expected results.
Using `order_by` directly before `limit`, `collect` and `collect_flatten`
will give expected results.
See https://github.com/iterative/datachain/issues/477 for further details.
"""
if descending:
Expand Down Expand Up @@ -1191,7 +1192,7 @@ def collect_flatten(self, *, row_factory=None):
a tuple of row values.
"""
db_signals = self._effective_signals_schema.db_signals()
with self._query.select(*db_signals).as_iterable() as rows:
with self._query.ordered_select(*db_signals).as_iterable() as rows:
if row_factory:
rows = (row_factory(db_signals, r) for r in rows)
yield from rows
Expand Down Expand Up @@ -1282,7 +1283,7 @@ def collect(self, *cols: str) -> Iterator[Union[DataType, tuple[DataType, ...]]]
chain = self.select(*cols) if cols else self
signals_schema = chain._effective_signals_schema
db_signals = signals_schema.db_signals()
with self._query.select(*db_signals).as_iterable() as rows:
with self._query.ordered_select(*db_signals).as_iterable() as rows:
for row in rows:
ret = signals_schema.row_to_features(
row, catalog=chain.session.catalog, cache=chain._settings.cache
Expand Down
31 changes: 30 additions & 1 deletion src/datachain/query/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1276,6 +1276,27 @@ def select(self, *args, **kwargs) -> "Self":
query.steps.append(SQLSelect((*args, *named_args)))
return query

@detach
def ordered_select(self, *args, **kwargs) -> "Self":
"""
Select the given columns or expressions using a subquery whilst
maintaining query ordering (only applicable if last step was order_by).
If used with no arguments, this simply creates a subquery and
select all columns from it.
Example:
>>> ds.ordered_select(C.name, C.size * 10)
>>> ds.ordered_select(C.name, size10x=C.size * 10)
"""
named_args = [v.label(k) for k, v in kwargs.items()]
query = self.clone()
order_by = query.last_step if query.is_ordered else None
query.steps.append(SQLSelect((*args, *named_args)))
if order_by:
query.steps.append(order_by)
return query

@detach
def select_except(self, *args) -> "Self":
"""
Expand Down Expand Up @@ -1338,7 +1359,7 @@ def limit(self, n: int) -> "Self":
query = self.clone(new_table=False)
if (
query.steps
and (last_step := query.steps[-1])
and (last_step := query.last_step)
and isinstance(last_step, SQLLimit)
):
query.steps[-1] = SQLLimit(min(n, last_step.n))
Expand Down Expand Up @@ -1591,3 +1612,11 @@ def save(
finally:
self.cleanup()
return self.__class__(name=name, version=version, catalog=self.catalog)

@property
def is_ordered(self) -> bool:
return isinstance(self.last_step, SQLOrderBy)

@property
def last_step(self) -> Optional[Step]:
return self.steps[-1] if self.steps else None
24 changes: 24 additions & 0 deletions tests/func/test_datachain.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,30 @@ def test_show_no_truncate(capsys, test_session):
assert details[i] in normalized_output


@pytest.mark.parametrize("ordered_by", ["letter", "number"])
def test_show_ordered(capsys, test_session, ordered_by):
numbers = [6, 2, 3, 1, 5, 7, 4]
letters = ["u", "y", "x", "z", "v", "t", "w"]

DataChain.from_values(
number=numbers, letter=letters, session=test_session
).order_by(ordered_by).show()

captured = capsys.readouterr()
normalized_lines = [
re.sub(r"\s+", " ", line).strip() for line in captured.out.strip().split("\n")
]

ordered_entries = sorted(
zip(numbers, letters), key=lambda x: x[0 if ordered_by == "number" else 1]
)

assert normalized_lines[0].strip() == "number letter"
for i, line in enumerate(normalized_lines[1:]):
number, letter = ordered_entries[i]
assert line == f"{i} {number} {letter}"


def test_from_storage_dataset_stats(tmp_dir, test_session):
for i in range(4):
(tmp_dir / f"file{i}.txt").write_text(f"file{i}")
Expand Down
2 changes: 1 addition & 1 deletion tests/func/test_listing.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def test_listing_generator(cloud_test_catalog, cloud_type):
entries = sorted(
[e for e in ENTRIES if e.path.startswith("cats/")], key=lambda e: e.path
)
files = sorted(dc.collect("file"), key=lambda f: f.path)
files = dc.order_by("file.path").collect("file")

for cat_file, cat_entry in zip(files, entries):
assert cat_file.source == ctc.src_uri
Expand Down
30 changes: 28 additions & 2 deletions tests/unit/lib/test_datachain.py
Original file line number Diff line number Diff line change
Expand Up @@ -1824,6 +1824,32 @@ def test_order_by_with_nested_columns(test_session, with_function):
]


def test_order_by_collect(test_session):
numbers = [6, 2, 3, 1, 5, 7, 4]
letters = ["u", "y", "x", "z", "v", "t", "w"]

dc = DataChain.from_values(number=numbers, letter=letters, session=test_session)
assert list(dc.order_by("number").collect()) == [
(1, "z"),
(2, "y"),
(3, "x"),
(4, "w"),
(5, "v"),
(6, "u"),
(7, "t"),
]

assert list(dc.order_by("letter").collect()) == [
(7, "t"),
(6, "u"),
(5, "v"),
(4, "w"),
(3, "x"),
(2, "y"),
(1, "z"),
]


@pytest.mark.parametrize("with_function", [True, False])
def test_order_by_descending(test_session, with_function):
names = ["a.txt", "c.txt", "d.txt", "a.txt", "b.txt"]
Expand Down Expand Up @@ -1852,7 +1878,7 @@ def test_union(test_session):
chain2 = DataChain.from_values(value=[3, 4], session=test_session)
chain3 = chain1 | chain2
assert chain3.count() == 4
assert sorted(chain3.collect("value")) == [1, 2, 3, 4]
assert list(chain3.order_by("value").collect("value")) == [1, 2, 3, 4]


def test_union_different_columns(test_session):
Expand Down Expand Up @@ -1887,7 +1913,7 @@ def test_union_different_column_order(test_session):
chain2 = DataChain.from_values(
name=["different", "order"], value=[9, 10], session=test_session
)
assert sorted(chain1.union(chain2).collect()) == [
assert list(chain1.union(chain2).order_by("value").collect()) == [
(1, "chain"),
(2, "more"),
(9, "different"),
Expand Down

0 comments on commit 4f87dc8

Please sign in to comment.