diff --git a/python/python/lance/fragment.py b/python/python/lance/fragment.py
index 61685863e0..5659786fb8 100644
--- a/python/python/lance/fragment.py
+++ b/python/python/lance/fragment.py
@@ -364,6 +364,7 @@ def merge_columns(
         self,
         value_func: Callable[[pa.RecordBatch], pa.RecordBatch],
         columns: Optional[list[str]] = None,
+        batch_size: Optional[int] = None,
     ) -> Tuple[FragmentMetadata, LanceSchema]:
         """Add columns to this Fragment.
 
@@ -390,7 +391,7 @@ def merge_columns(
         Tuple[FragmentMetadata, LanceSchema]
             A new fragment with the added column(s) and the final schema.
         """
-        updater = self._fragment.updater(columns)
+        updater = self._fragment.updater(columns, batch_size)
 
         while True:
             batch = updater.next()
diff --git a/python/python/tests/test_dataset.py b/python/python/tests/test_dataset.py
index 330ac6864b..5c865195ce 100644
--- a/python/python/tests/test_dataset.py
+++ b/python/python/tests/test_dataset.py
@@ -912,6 +912,33 @@ def test_merge_with_commit(tmp_path: Path):
     assert tbl == expected
 
 
+def test_merge_batch_size(tmp_path: Path):
+    # Create dataset with 10 fragments with 100 rows each
+    table = pa.table({"a": range(1000)})
+    for batch_size in [1, 10, 100, 1000]:
+        ds_path = str(tmp_path / str(batch_size))
+        dataset = lance.write_dataset(table, ds_path, max_rows_per_file=100)
+        fragments = []
+
+        def mutate(batch):
+            assert batch.num_rows <= batch_size
+            return pa.RecordBatch.from_pydict({"b": batch.column("a")})
+
+        for frag in dataset.get_fragments():
+            merged, schema = frag.merge_columns(mutate, batch_size=batch_size)
+            fragments.append(merged)
+
+        merge = lance.LanceOperation.Merge(fragments, schema)
+        dataset = lance.LanceDataset.commit(
+            ds_path, merge, read_version=dataset.version
+        )
+
+        dataset.validate()
+        tbl = dataset.to_table()
+        expected = pa.table({"a": range(1000), "b": range(1000)})
+        assert tbl == expected
+
+
 def test_merge_with_schema_holes(tmp_path: Path):
     # Create table with 3 cols
     table = pa.table({"a": range(10)})
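The Python-side change threads `batch_size` from `merge_columns` down to the native updater, as the test above exercises. A minimal usage sketch (the dataset path, column name, and batch size are illustrative, not part of the patch):

```python
import lance
import pyarrow as pa

# Assumes an existing dataset at ./demo.lance with an integer column "a".
dataset = lance.dataset("./demo.lance")

def add_b(batch: pa.RecordBatch) -> pa.RecordBatch:
    # With batch_size=256, each batch seen here holds at most 256 rows.
    return pa.RecordBatch.from_pydict({"b": batch.column("a")})

fragments = []
for frag in dataset.get_fragments():
    merged, schema = frag.merge_columns(add_b, batch_size=256)
    fragments.append(merged)
```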
diff --git a/python/src/fragment.rs b/python/src/fragment.rs
index e384630476..1e69d3b354 100644
--- a/python/src/fragment.rs
+++ b/python/src/fragment.rs
@@ -210,10 +210,12 @@
         Ok(Scanner::new(scn))
     }
 
-    fn updater(&self, columns: Option<Vec<String>>) -> PyResult<Updater> {
+    fn updater(&self, columns: Option<Vec<String>>, batch_size: Option<u32>) -> PyResult<Updater> {
         let cols = columns.as_deref();
         let inner = RT
-            .block_on(None, async { self.fragment.updater(cols, None).await })?
+            .block_on(None, async {
+                self.fragment.updater(cols, None, batch_size).await
+            })?
             .map_err(|err| PyIOError::new_err(err.to_string()))?;
 
         Ok(Updater::new(inner))
     }
diff --git a/rust/lance/src/dataset/fragment.rs b/rust/lance/src/dataset/fragment.rs
index dc5c3d5dad..22acebdd67 100644
--- a/rust/lance/src/dataset/fragment.rs
+++ b/rust/lance/src/dataset/fragment.rs
@@ -1132,6 +1132,7 @@ impl FileFragment {
         &self,
         columns: Option<&[T]>,
         schemas: Option<(Schema, Schema)>,
+        batch_size: Option<u32>,
     ) -> Result<Updater> {
         let mut schema = self.dataset.schema().clone();
 
@@ -1160,11 +1161,11 @@ impl FileFragment {
         let reader = reader?;
         let deletion_vector = deletion_vector?.unwrap_or_default();
 
-        Updater::try_new(self.clone(), reader, deletion_vector, schemas)
+        Updater::try_new(self.clone(), reader, deletion_vector, schemas, batch_size)
     }
 
     pub(crate) async fn merge(mut self, join_column: &str, joiner: &HashJoiner) -> Result<Self> {
-        let mut updater = self.updater(Some(&[join_column]), None).await?;
+        let mut updater = self.updater(Some(&[join_column]), None, None).await?;
 
         while let Some(batch) = updater.next().await? {
             let batch = joiner.collect(batch[join_column].clone()).await?;
@@ -2433,7 +2434,7 @@ mod tests {
         }
 
         let fragment = &mut dataset.get_fragment(0).unwrap();
-        let mut updater = fragment.updater(Some(&["i"]), None).await.unwrap();
+        let mut updater = fragment.updater(Some(&["i"]), None, None).await.unwrap();
         let new_schema = Arc::new(ArrowSchema::new(vec![ArrowField::new(
             "double_i",
             DataType::Int32,
@@ -2677,7 +2678,7 @@ mod tests {
         let fragment = dataset.get_fragments().pop().unwrap();
 
         // Write batch_s using add_columns
-        let mut updater = fragment.updater(Some(&["i"]), None).await?;
+        let mut updater = fragment.updater(Some(&["i"]), None, None).await?;
         updater.next().await?;
         updater.update(batch_s.clone()).await?;
         let frag = updater.finish().await?;
diff --git a/rust/lance/src/dataset/scanner.rs b/rust/lance/src/dataset/scanner.rs
index 31538c21b2..47a6aa58db 100644
--- a/rust/lance/src/dataset/scanner.rs
+++ b/rust/lance/src/dataset/scanner.rs
@@ -66,7 +66,14 @@ use snafu::{location, Location};
 #[cfg(feature = "substrait")]
 use lance_datafusion::substrait::parse_substrait;
 
-pub const DEFAULT_BATCH_SIZE: usize = 8192;
+const BATCH_SIZE_FALLBACK: usize = 8192;
+// For backwards compatibility / historical reasons we re-calculate the default batch size
+// on each call
+pub fn get_default_batch_size() -> Option<usize> {
+    std::env::var("LANCE_DEFAULT_BATCH_SIZE")
+        .map(|val| Some(val.parse().unwrap()))
+        .unwrap_or(None)
+}
 pub const LEGACY_DEFAULT_FRAGMENT_READAHEAD: usize = 4;
 
 lazy_static::lazy_static! {
@@ -262,23 +269,14 @@
         // 64KB, this is 16K rows. For local file systems, the default block size
         // is just 4K, which would mean only 1K rows, which might be a little small.
         // So we use a default minimum of 8K rows.
-        std::env::var("LANCE_DEFAULT_BATCH_SIZE")
-            .map(|bs| {
-                bs.parse().unwrap_or_else(|_| {
-                    panic!(
-                        "The value of LANCE_DEFAULT_BATCH_SIZE ({}) is not a valid batch size",
-                        bs
-                    )
-                })
-            })
-            .unwrap_or_else(|_| {
-                self.batch_size.unwrap_or_else(|| {
-                    std::cmp::max(
-                        self.dataset.object_store().block_size() / 4,
-                        DEFAULT_BATCH_SIZE,
-                    )
-                })
-            })
+        get_default_batch_size().unwrap_or_else(|| {
+            self.batch_size.unwrap_or_else(|| {
+                std::cmp::max(
+                    self.dataset.object_store().block_size() / 4,
+                    BATCH_SIZE_FALLBACK,
+                )
+            })
+        })
     }
 
     fn ensure_not_fragment_scan(&self) -> Result<()> {
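On the scanner side, the env-var lookup is now factored into `get_default_batch_size` so the updater can reuse it. The resulting precedence, paraphrased in Python for illustration (the function and parameter names here are hypothetical, not part of the patch):

```python
def effective_scan_batch_size(env_value, scanner_batch_size, block_size, fallback=8192):
    # Paraphrase of Scanner::get_batch_size after this patch: the
    # LANCE_DEFAULT_BATCH_SIZE env var wins, then the batch size set on the
    # scanner, then a minimum derived from the object store block size.
    if env_value is not None:
        return int(env_value)  # the Rust side panics on an unparsable value
    if scanner_batch_size is not None:
        return scanner_batch_size
    return max(block_size // 4, fallback)
```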
diff --git a/rust/lance/src/dataset/schema_evolution.rs b/rust/lance/src/dataset/schema_evolution.rs
index 9a5865159e..521a22290e 100644
--- a/rust/lance/src/dataset/schema_evolution.rs
+++ b/rust/lance/src/dataset/schema_evolution.rs
@@ -261,7 +261,7 @@ async fn add_columns_impl(
         }
 
         let mut updater = fragment
-            .updater(read_columns_ref, schemas_ref.clone())
+            .updater(read_columns_ref, schemas_ref.clone(), None)
             .await?;
 
         let mut batch_index = 0;
diff --git a/rust/lance/src/dataset/updater.rs b/rust/lance/src/dataset/updater.rs
index 1ba9ce0565..d8a6ae96bb 100644
--- a/rust/lance/src/dataset/updater.rs
+++ b/rust/lance/src/dataset/updater.rs
@@ -10,6 +10,7 @@ use lance_table::utils::stream::ReadBatchFutStream;
 use snafu::{location, Location};
 
 use super::fragment::FragmentReader;
+use super::scanner::get_default_batch_size;
 use super::write::{open_writer, GenericWriter};
 use super::Dataset;
 use crate::dataset::FileFragment;
@@ -59,6 +60,7 @@ impl Updater {
         reader: FragmentReader,
         deletion_vector: DeletionVector,
         schemas: Option<(Schema, Schema)>,
+        batch_size: Option<u32>,
     ) -> Result<Self> {
         let (write_schema, final_schema) = if let Some((write_schema, final_schema)) = schemas {
             (Some(write_schema), Some(final_schema))
@@ -66,9 +68,18 @@ impl Updater {
             (None, None)
         };
 
-        let batch_size = reader.legacy_num_rows_in_batch(0);
+        let legacy_batch_size = reader.legacy_num_rows_in_batch(0);
 
-        let input_stream = reader.read_all(1024)?;
+        let batch_size = match (&legacy_batch_size, batch_size) {
+            // If this is a v1 dataset we must use the row group size of the file
+            (Some(num_rows), _) => *num_rows,
+            // If this is a v2 dataset, let the user pick the batch size
+            (None, Some(batch_size)) => batch_size,
+            // Otherwise, default to 1024 if the user didn't specify anything
+            (None, None) => get_default_batch_size().unwrap_or(1024) as u32,
+        };
+
+        let input_stream = reader.read_all(batch_size)?;
 
         Ok(Self {
             fragment,
@@ -78,7 +89,7 @@ impl Updater {
             write_schema,
             final_schema,
             finished: false,
-            deletion_restorer: DeletionRestorer::new(deletion_vector, batch_size),
+            deletion_restorer: DeletionRestorer::new(deletion_vector, legacy_batch_size),
         })
     }
@@ -226,7 +237,7 @@ struct DeletionRestorer {
     current_row_id: u32,
 
     /// Number of rows in each batch, only used in legacy files for validation
-    batch_size: Option<u32>,
+    legacy_batch_size: Option<u32>,
 
     deletion_vector_iter: Option<Box<dyn Iterator<Item = u32> + Send>>,
 
@@ -234,10 +245,10 @@ struct DeletionRestorer {
 }
 
 impl DeletionRestorer {
-    fn new(deletion_vector: DeletionVector, batch_size: Option<u32>) -> Self {
+    fn new(deletion_vector: DeletionVector, legacy_batch_size: Option<u32>) -> Self {
         Self {
             current_row_id: 0,
-            batch_size,
+            legacy_batch_size,
             deletion_vector_iter: Some(deletion_vector.into_sorted_iter()),
             last_deleted_row_id: None,
         }
@@ -248,12 +259,12 @@ impl DeletionRestorer {
     }
 
     fn is_full(batch_size: Option<u32>, num_rows: u32) -> bool {
-        if let Some(batch_size) = batch_size {
+        if let Some(legacy_batch_size) = batch_size {
             // We should never encounter the case that `batch_size < num_rows` because
             // that would mean we have a v1 writer and it generated a batch with more rows
             // than expected
-            debug_assert!(batch_size >= num_rows);
-            batch_size == num_rows
+            debug_assert!(legacy_batch_size >= num_rows);
+            legacy_batch_size == num_rows
         } else {
             false
         }
@@ -295,7 +306,8 @@
         loop {
             if let Some(next_deleted_id) = next_deleted_id {
                 if next_deleted_id > last_row_id
-                    || (next_deleted_id == last_row_id && Self::is_full(self.batch_size, num_rows))
+                    || (next_deleted_id == last_row_id
+                        && Self::is_full(self.legacy_batch_size, num_rows))
                 {
                     // Either the next deleted id is out of range or it is the next row but
                     // we are full. Either way, stash it and return
@@ -322,7 +334,7 @@
         let deleted_batch_offsets = self.deleted_batch_offsets_in_range(batch.num_rows() as u32);
         let batch = add_blanks(batch, &deleted_batch_offsets)?;
 
-        if let Some(batch_size) = self.batch_size {
+        if let Some(batch_size) = self.legacy_batch_size {
             // validation just in case, when the input has a fixed batch size then the
             // output should have the same fixed batch size (except the last batch)
             let is_last = self.is_exhausted();
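`Updater::try_new` now resolves the read batch size instead of hard-coding 1024. The precedence, paraphrased in Python for illustration (the function and parameter names here are hypothetical, not part of the patch):

```python
def updater_batch_size(legacy_num_rows, requested, env_default):
    # Paraphrase of Updater::try_new after this patch: v1 (legacy) files stay
    # pinned to their row-group size, v2 files honor the caller's batch_size,
    # and otherwise LANCE_DEFAULT_BATCH_SIZE or 1024 is used.
    if legacy_num_rows is not None:
        return legacy_num_rows
    if requested is not None:
        return requested
    return env_default if env_default is not None else 1024
```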
diff --git a/rust/lance/src/dataset/write/merge_insert.rs b/rust/lance/src/dataset/write/merge_insert.rs
index 9f38611171..01d5d66138 100644
--- a/rust/lance/src/dataset/write/merge_insert.rs
+++ b/rust/lance/src/dataset/write/merge_insert.rs
@@ -701,6 +701,7 @@ impl MergeInsertJob {
             .updater(
                 Some(&read_columns),
                 Some((write_schema, dataset.schema().clone())),
+                None,
             )
             .await?;
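Continuing the usage sketch from above, the merged fragments are committed as a new dataset version, mirroring the flow in `test_merge_batch_size` (the path and final assertion are illustrative):

```python
# Commit the merged fragments so the new column becomes visible.
merge = lance.LanceOperation.Merge(fragments, schema)
dataset = lance.LanceDataset.commit(
    "./demo.lance", merge, read_version=dataset.version
)
dataset.validate()
assert "b" in dataset.schema.names
```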