Late materialization: benchmark, batch-size test, and batch coalescing before take.

Summary of the changes carried by this patch:
  * benchmarks/test_search.py: adds a "category_no_index" column holding the same
    values as "category", writes the dataset with use_legacy_format=False, and adds
    a "late_materialization" benchmark that filters on either column.
  * tests/test_dataset.py: adds a test that a filtered scan with batch_size=32
    yields batches of exactly 32 rows.
  * lance-encoding/src/decoder.rs: the scheduling trace message now prints the range
    count and overall bounds instead of every range.
  * lance/src/dataset/scanner.rs: wraps the take input in CoalesceBatchesExec sized
    by the scanner batch size.

diff --git a/python/python/benchmarks/test_search.py b/python/python/benchmarks/test_search.py
index 8a6f2758fb..0014bc4be8 100644
--- a/python/python/benchmarks/test_search.py
+++ b/python/python/benchmarks/test_search.py
@@ -58,6 +58,7 @@ def create_table(num_rows, offset) -> pa.Table:
             "vector": vectors,
             "filterable": filterable,
             "category": categories,
+            "category_no_index": categories,
             "genres": genres,
         }
     )
@@ -77,9 +78,11 @@ def create_base_dataset(data_dir: Path) -> lance.LanceDataset:
         rows_remaining -= next_batch_length
         table = create_table(next_batch_length, offset)
         if offset == 0:
-            dataset = lance.write_dataset(table, tmp_path)
+            dataset = lance.write_dataset(table, tmp_path, use_legacy_format=False)
         else:
-            dataset = lance.write_dataset(table, tmp_path, mode="append")
+            dataset = lance.write_dataset(
+                table, tmp_path, mode="append", use_legacy_format=False
+            )
         offset += next_batch_length
 
     dataset.create_index(
@@ -479,3 +482,29 @@ def test_label_list_index_prefilter(test_dataset, benchmark, filter: str):
         prefilter=True,
         filter=filter,
     )
+
+
+@pytest.mark.benchmark(group="late_materialization")
+@pytest.mark.parametrize(
+    "use_index",
+    (False, True),
+    ids=["no_index", "with_index"],
+)
+def test_late_materialization(test_dataset, benchmark, use_index):
+    """Benchmark reading "vector" with a filter, once per filter column."""
+    # "category_no_index" holds the same values as "category" (see create_table),
+    # so the two cases differ only in which column the filter targets.
+    column = "category" if use_index else "category_no_index"
+    print(
+        test_dataset.scanner(
+            columns=["vector"],
+            filter=f"{column} = 0",
+            batch_size=32,
+        ).explain_plan(True)
+    )
+    benchmark(
+        test_dataset.to_table,
+        columns=["vector"],
+        filter=f"{column} = 0",
+        batch_size=32,
+    )
diff --git a/python/python/tests/test_dataset.py b/python/python/tests/test_dataset.py
index 8c8e521d2c..f448fcc8bb 100644
--- a/python/python/tests/test_dataset.py
+++ b/python/python/tests/test_dataset.py
@@ -16,6 +16,7 @@
 
 import lance
 import lance.fragment
+import numpy as np
 import pandas as pd
 import pandas.testing as tm
 import polars as pl
@@ -1974,3 +1975,17 @@ def test_v2_dataset(tmp_path: Path):
 
     fragment = list(dataset.get_fragments())[0]
     assert "minor_version: 3" not in format_fragment(fragment.metadata, dataset)
+
+
+def test_late_materialization_batch_size(tmp_path: Path):
+    """Filtered scans should emit full batches even when the filter drops rows."""
+    table = pa.table({"filter": np.arange(32 * 32), "values": np.arange(32 * 32)})
+    dataset = lance.write_dataset(
+        table, tmp_path, use_legacy_format=False, max_rows_per_file=10000
+    )
+    batches = dataset.to_batches(
+        columns=["values"], filter="filter % 2 == 0", batch_size=32
+    )
+    # 512 even rows split evenly into 16 batches; the list comparison also
+    # fails if no batches come back, which a bare loop would silently accept.
+    assert [batch.num_rows for batch in batches] == [32] * 16
diff --git a/rust/lance-encoding/src/decoder.rs b/rust/lance-encoding/src/decoder.rs
index a42c4edc33..591c979202 100644
--- a/rust/lance-encoding/src/decoder.rs
+++ b/rust/lance-encoding/src/decoder.rs
@@ -909,7 +909,15 @@ impl DecodeBatchScheduler {
         mut schedule_action: impl FnMut(Result),
     ) {
         let rows_requested = ranges.iter().map(|r| r.end - r.start).sum::();
-        trace!("Scheduling ranges {:?} ({} rows)", ranges, rows_requested);
+        // Avoid `unwrap` on first()/last(): an empty `ranges` would otherwise panic
+        // whenever trace logging is enabled.
+        trace!(
+            "Scheduling {} ranges across {}..{} ({} rows)",
+            ranges.len(),
+            ranges.first().map_or(0, |r| r.start),
+            ranges.last().map_or(0, |r| r.end),
+            rows_requested
+        );
 
         let mut context = SchedulerContext::new(io);
         let maybe_root_job = self.root_scheduler.schedule_ranges(ranges, filter);
diff --git a/rust/lance/src/dataset/scanner.rs b/rust/lance/src/dataset/scanner.rs
index fb29d76167..2a4990096b 100644
--- a/rust/lance/src/dataset/scanner.rs
+++ b/rust/lance/src/dataset/scanner.rs
@@ -13,6 +13,7 @@ use async_recursion::async_recursion;
 use datafusion::functions_aggregate::count::count_udaf;
 use datafusion::logical_expr::{lit, Expr};
 use datafusion::physical_expr::PhysicalSortExpr;
+use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec;
 use datafusion::physical_plan::expressions;
 use datafusion::physical_plan::projection::ProjectionExec as DFProjectionExec;
 use datafusion::physical_plan::sorts::sort::SortExec;
@@ -1584,9 +1585,12 @@ impl Scanner {
         projection: &Schema,
         batch_readahead: usize,
     ) -> Result> {
+        // Merge the input batches up to the scanner batch size before the take;
+        // test_late_materialization_batch_size covers the resulting batch sizes.
+        let coalesced = Arc::new(CoalesceBatchesExec::new(input, self.get_batch_size()));
         Ok(Arc::new(TakeExec::try_new(
             self.dataset.clone(),
-            input,
+            coalesced,
             Arc::new(projection.clone()),
             batch_readahead,
         )?))