Skip to content

Commit

Permalink
only combine final step for limit (#230)
Browse files Browse the repository at this point in the history
  • Loading branch information
mattseddon authored Aug 5, 2024
1 parent 111cfc1 commit 7832e10
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 4 deletions.
12 changes: 8 additions & 4 deletions src/datachain/query/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1383,10 +1383,14 @@ def order_by(self, *args) -> "Self":
@detach
def limit(self, n: int) -> "Self":
    """Return a copy of this query restricted to at most ``n`` rows.

    Only a limit that is the *final* step of the query is merged with the
    new one (keeping the smaller bound). A limit that sits earlier in the
    step list is left untouched, because steps that follow it can change
    the number of rows, so it cannot be combined with a later limit.

    Args:
        n: Maximum number of rows the resulting query may yield.

    Returns:
        A clone of this query (created with ``new_table=False``) whose
        last step is an ``SQLLimit``.
    """
    query = self.clone(new_table=False)
    last_step = query.steps[-1] if query.steps else None
    if isinstance(last_step, SQLLimit):
        # Two back-to-back limits collapse into the stricter of the two.
        query.steps[-1] = SQLLimit(min(n, last_step.n))
    else:
        query.steps.append(SQLLimit(n))
    return query

@detach
Expand Down
19 changes: 19 additions & 0 deletions tests/unit/lib/test_datachain.py
Original file line number Diff line number Diff line change
Expand Up @@ -1285,3 +1285,22 @@ def test_show_limit():
assert dc.limit(3).limit(2).count() == 2
dc.show(1)
assert dc.count() == 5


def test_gen_limit(catalog):
    """Check that ``limit`` composes correctly around a generator step.

    Each of the four input rows expands into three generated rows, so a
    limit applied before ``gen`` bounds the inputs while a limit applied
    after ``gen`` bounds the generated output.
    """

    def func(key, val) -> Iterator[tuple[File, str]]:
        # Emit ``val`` rows per input row, tagged with the source key.
        yield from ((File(name=""), f"{key}_{i}") for i in range(val))

    keys = list("abcd")
    values = [3] * len(keys)
    ds = DataChain.from_values(key=keys, val=values)

    assert ds.count() == 4
    assert ds.gen(res=func).count() == 12
    assert ds.limit(2).gen(res=func).count() == 6

    # (limit before gen, limit after gen, expected row count)
    chained_cases = [
        (2, 1, 1),
        (3, 2, 2),
        (2, 3, 3),
        (3, 10, 9),
    ]
    for before, after, expected in chained_cases:
        assert ds.limit(before).gen(res=func).limit(after).count() == expected

0 comments on commit 7832e10

Please sign in to comment.