diff --git a/python/python/tests/test_dataset.py b/python/python/tests/test_dataset.py
index 9cf7c824a60..8bebfe2c4cb 100644
--- a/python/python/tests/test_dataset.py
+++ b/python/python/tests/test_dataset.py
@@ -685,6 +685,154 @@ def test_take_rowid_rowaddr(tmp_path: Path):
     assert sample_dataset.num_columns == 2
 
 
+@pytest.mark.parametrize(
+    "column_name",
+    [
+        "_rowid",
+        "_rowaddr",
+        "_rowoffset",
+        "_row_created_at_version",
+        "_row_last_updated_at_version",
+    ],
+)
+def test_take_system_columns_values(tmp_path: Path, column_name: str):
+    """Test that every system column returns correct values in take."""
+    table = pa.table({"a": range(100), "b": range(100, 200)})
+    base_dir = tmp_path / "test_take_system_columns_values"
+    # Use max_rows_per_file to create multiple fragments
+    lance.write_dataset(table, base_dir, max_rows_per_file=25)
+    dataset = lance.dataset(base_dir)
+
+    indices = [0, 5, 10, 50, 99]
+    result = dataset.take(indices, columns=[column_name, "a"])
+    assert result.num_rows == len(indices)
+    assert result.schema.names == [column_name, "a"]
+
+    col_values = result.column(column_name).to_pylist()
+    a_values = result.column("a").to_pylist()
+
+    # Verify column type is UInt64
+    assert result.column(column_name).type == pa.uint64()
+
+    # Verify data column values
+    assert a_values == indices
+
+    # Verify system column values based on column type
+    if column_name in ("_rowid", "_rowaddr"):
+        # Without stable row IDs, _rowid equals _rowaddr (not the index).
+        # Row address = (fragment_id << 32) | row_offset_within_fragment
+        # With max_rows_per_file=25: frag0=0-24, frag1=25-49, frag2=50-74, frag3=75-99
+        expected_addrs = [
+            (0 << 32) | 0,  # index 0: fragment 0, offset 0
+            (0 << 32) | 5,  # index 5: fragment 0, offset 5
+            (0 << 32) | 10,  # index 10: fragment 0, offset 10
+            (2 << 32) | 0,  # index 50: fragment 2, offset 0
+            (3 << 32) | 24,  # index 99: fragment 3, offset 24
+        ]
+        assert col_values == expected_addrs
+    elif column_name in ("_row_created_at_version", "_row_last_updated_at_version"):
+        # All rows created/updated at version 1
+        assert col_values == [1] * len(indices)
+    elif column_name == "_rowoffset":
+        # No rows are deleted, so the dataset-wide offset equals the take index
+        assert col_values == indices
+
+
+def test_take_system_columns_column_ordering(tmp_path: Path):
+    """Test that column ordering is preserved when using system columns."""
+    table = pa.table({"a": range(50), "b": range(50, 100)})
+    base_dir = tmp_path / "test_take_column_ordering"
+    lance.write_dataset(table, base_dir)
+    dataset = lance.dataset(base_dir)
+
+    indices = [0, 1, 2]
+
+    # Test different orderings with all system columns
+    result = dataset.take(indices, columns=["_rowid", "a", "_rowaddr"])
+    assert result.schema.names == ["_rowid", "a", "_rowaddr"]
+
+    result = dataset.take(indices, columns=["a", "_rowaddr", "_rowid"])
+    assert result.schema.names == ["a", "_rowaddr", "_rowid"]
+
+    result = dataset.take(indices, columns=["_rowaddr", "_rowid", "b", "a"])
+    assert result.schema.names == ["_rowaddr", "_rowid", "b", "a"]
+
+    # Test with version columns
+    result = dataset.take(
+        indices,
+        columns=[
+            "_row_created_at_version",
+            "a",
+            "_row_last_updated_at_version",
+            "_rowid",
+        ],
+    )
+    assert result.schema.names == [
+        "_row_created_at_version",
+        "a",
+        "_row_last_updated_at_version",
+        "_rowid",
+    ]
+
+    # Test with all system columns in mixed order
+    result = dataset.take(
+        indices,
+        columns=[
+            "_rowoffset",
+            "_row_last_updated_at_version",
+            "b",
+            "_rowaddr",
+            "_row_created_at_version",
+            "a",
+            "_rowid",
+        ],
+    )
+    assert result.schema.names == [
+        "_rowoffset",
+        "_row_last_updated_at_version",
+        "b",
+        "_rowaddr",
+        "_row_created_at_version",
+        "a",
+        "_rowid",
+    ]
+
+
+def test_take_version_system_columns(tmp_path: Path):
+    """Test _row_created_at_version and _row_last_updated_at_version columns."""
+    table = pa.table({"a": range(50)})
+    base_dir = tmp_path / "test_take_version_columns"
+    lance.write_dataset(table, base_dir, enable_stable_row_ids=True)
+    dataset = lance.dataset(base_dir)
+
+    # Initial version is 1
+    initial_version = dataset.version
+
+    indices = [0, 10, 25]
+    result = dataset.take(
+        indices,
+        columns=["a", "_row_created_at_version", "_row_last_updated_at_version"],
+    )
+
+    assert result.num_rows == 3
+    created_at = result.column("_row_created_at_version").to_pylist()
+    updated_at = result.column("_row_last_updated_at_version").to_pylist()
+
+    # All rows were created and last updated at the initial version
+    assert created_at == [initial_version] * 3
+    assert updated_at == [initial_version] * 3
+
+    # Append a second batch of rows, which commits a new dataset version
+    table2 = pa.table({"a": range(50, 100)})
+    lance.write_dataset(table2, base_dir, mode="append")
+    dataset = lance.dataset(base_dir)
+
+    # New rows should have version 2
+    result = dataset.take([50, 60], columns=["_row_created_at_version"])
+    created_at = result.column("_row_created_at_version").to_pylist()
+    assert created_at == [dataset.version] * 2
+
+
 @pytest.mark.parametrize("indices", [[], [1, 1], [1, 1, 20, 20, 21], [21, 0, 21, 1, 0]])
 def test_take_duplicate_index(tmp_path: Path, indices: List[int]):
     table = pa.table({"x": range(24)})
diff --git a/python/src/dataset.rs b/python/src/dataset.rs
index 0af4b46008a..b4b4d8af050 100644
--- a/python/src/dataset.rs
+++ b/python/src/dataset.rs
@@ -1299,7 +1299,7 @@ impl Dataset {
             Arc::new(
                 self.ds
                     .schema()
-                    .project(&columns)
+                    .project_preserve_system_columns(&columns)
                     .map_err(|err| PyValueError::new_err(err.to_string()))?,
             )
         } else {
diff --git
a/rust/lance-core/src/datatypes/schema.rs b/rust/lance-core/src/datatypes/schema.rs index 3e8340bd19b..caa31fdf9da 100644 --- a/rust/lance-core/src/datatypes/schema.rs +++ b/rust/lance-core/src/datatypes/schema.rs @@ -16,7 +16,11 @@ use lance_arrow::*; use snafu::location; use super::field::{BlobVersion, Field, OnTypeMismatch, SchemaCompareOptions}; -use crate::{Error, Result, ROW_ADDR, ROW_ADDR_FIELD, ROW_ID, ROW_ID_FIELD, WILDCARD}; +use crate::{ + Error, Result, ROW_ADDR, ROW_ADDR_FIELD, ROW_CREATED_AT_VERSION, ROW_CREATED_AT_VERSION_FIELD, + ROW_ID, ROW_ID_FIELD, ROW_LAST_UPDATED_AT_VERSION, ROW_LAST_UPDATED_AT_VERSION_FIELD, + ROW_OFFSET, ROW_OFFSET_FIELD, WILDCARD, +}; /// Lance Schema. #[derive(Default, Debug, Clone, DeepSizeOf)] @@ -221,7 +225,12 @@ impl Schema { } } - fn do_project>(&self, columns: &[T], err_on_missing: bool) -> Result { + fn do_project>( + &self, + columns: &[T], + err_on_missing: bool, + preserve_system_columns: bool, + ) -> Result { let mut candidates: Vec = vec![]; for col in columns { let split = parse_field_path(col.as_ref())?; @@ -234,7 +243,30 @@ impl Schema { } else { candidates.push(projected_field) } - } else if err_on_missing && first != ROW_ID && first != ROW_ADDR { + } else if crate::is_system_column(first) { + if preserve_system_columns { + if first == ROW_ID { + candidates.push(Field::try_from(ROW_ID_FIELD.clone())?); + } else if first == ROW_ADDR { + candidates.push(Field::try_from(ROW_ADDR_FIELD.clone())?); + } else if first == ROW_OFFSET { + candidates.push(Field::try_from(ROW_OFFSET_FIELD.clone())?); + } else if first == ROW_CREATED_AT_VERSION { + candidates.push(Field::try_from(ROW_CREATED_AT_VERSION_FIELD.clone())?); + } else if first == ROW_LAST_UPDATED_AT_VERSION { + candidates + .push(Field::try_from(ROW_LAST_UPDATED_AT_VERSION_FIELD.clone())?); + } else { + return Err(Error::Schema { + message: format!( + "System column {} is currently not supported in projection", + first + ), + location: location!(), + }); + } + 
} + } else if err_on_missing { return Err(Error::Schema { message: format!("Column {} does not exist", col.as_ref()), location: location!(), @@ -255,12 +287,17 @@ impl Schema { /// let projected = schema.project(&["col1", "col2.sub_col3.field4"])?; /// ``` pub fn project>(&self, columns: &[T]) -> Result { - self.do_project(columns, true) + self.do_project(columns, true, false) } /// Project the columns over the schema, dropping unrecognized columns pub fn project_or_drop>(&self, columns: &[T]) -> Result { - self.do_project(columns, false) + self.do_project(columns, false, false) + } + + /// Project the columns over the schema, preserving system columns. + pub fn project_preserve_system_columns>(&self, columns: &[T]) -> Result { + self.do_project(columns, true, true) } /// Check that the top level fields don't contain `.` in their names @@ -1832,6 +1869,41 @@ mod tests { assert_eq!(ArrowSchema::from(&projected), expected_arrow_schema); } + #[test] + fn test_schema_projection_preserving_system_columns() { + let arrow_schema = ArrowSchema::new(vec![ + ArrowField::new("a", DataType::Int32, false), + ArrowField::new( + "b", + DataType::Struct(ArrowFields::from(vec![ + ArrowField::new("f1", DataType::Utf8, true), + ArrowField::new("f2", DataType::Boolean, false), + ArrowField::new("f3", DataType::Float32, false), + ])), + true, + ), + ArrowField::new("c", DataType::Float64, false), + ]); + let schema = Schema::try_from(&arrow_schema).unwrap(); + let projected = schema + .project_preserve_system_columns(&["b.f1", "b.f3", "_rowid", "c"]) + .unwrap(); + + let expected_arrow_schema = ArrowSchema::new(vec![ + ArrowField::new( + "b", + DataType::Struct(ArrowFields::from(vec![ + ArrowField::new("f1", DataType::Utf8, true), + ArrowField::new("f3", DataType::Float32, false), + ])), + true, + ), + ArrowField::new("_rowid", DataType::UInt64, true), + ArrowField::new("c", DataType::Float64, false), + ]); + assert_eq!(ArrowSchema::from(&projected), expected_arrow_schema); + } + 
#[test] fn test_schema_project_by_ids() { let arrow_schema = ArrowSchema::new(vec![ diff --git a/rust/lance-datafusion/src/projection.rs b/rust/lance-datafusion/src/projection.rs index f586ac4bb20..aacb63d4118 100644 --- a/rust/lance-datafusion/src/projection.rs +++ b/rust/lance-datafusion/src/projection.rs @@ -264,6 +264,8 @@ impl ProjectionPlan { let mut with_row_id = false; let mut with_row_addr = false; let mut must_add_row_offset = false; + let mut with_row_last_updated_at_version = false; + let mut with_row_created_at_version = false; for field in projection.fields.iter() { if lance_core::is_system_column(&field.name) { @@ -273,10 +275,14 @@ impl ProjectionPlan { must_add_row_offset = true; } else if field.name == ROW_ADDR { with_row_addr = true; + } else if field.name == ROW_OFFSET { + with_row_addr = true; must_add_row_offset = true; + } else if field.name == ROW_LAST_UPDATED_AT_VERSION { + with_row_last_updated_at_version = true; + } else if field.name == ROW_CREATED_AT_VERSION { + with_row_created_at_version = true; } - // Note: Other system columns like _rowoffset are computed differently - // and shouldn't appear in the schema at this point } else { // Regular data column - validate it exists in base schema if base.schema().field(&field.name).is_none() { @@ -301,6 +307,8 @@ impl ProjectionPlan { .with_blob_version(blob_version); physical_projection.with_row_id = with_row_id; physical_projection.with_row_addr = with_row_addr; + physical_projection.with_row_last_updated_at_version = with_row_last_updated_at_version; + physical_projection.with_row_created_at_version = with_row_created_at_version; // Build output expressions preserving the original order (including system columns) let exprs = projection @@ -431,8 +439,18 @@ impl ProjectionPlan { #[instrument(skip_all, level = "debug")] pub async fn project_batch(&self, batch: RecordBatch) -> Result { let src = Arc::new(OneShotExec::from_batch(batch)); - let physical_exprs = 
self.to_physical_exprs(&self.physical_projection.to_arrow_schema())?; + + // Need to add ROW_OFFSET to get filterable schema + let extra_columns = vec![ + ArrowField::new(ROW_ADDR, DataType::UInt64, true), + ArrowField::new(ROW_OFFSET, DataType::UInt64, true), + ]; + let mut filterable_schema = self.physical_projection.to_schema(); + filterable_schema = filterable_schema.merge(&ArrowSchema::new(extra_columns))?; + + let physical_exprs = self.to_physical_exprs(&(&filterable_schema).into())?; let projection = Arc::new(ProjectionExec::try_new(physical_exprs, src)?); + // Run dummy plan to execute projection, do not log the plan run let stream = execute_plan( projection, diff --git a/rust/lance/src/dataset.rs b/rust/lance/src/dataset.rs index bf4bede7d50..2562a2b9b2c 100644 --- a/rust/lance/src/dataset.rs +++ b/rust/lance/src/dataset.rs @@ -19,16 +19,14 @@ use crate::dataset::transaction::translate_schema_metadata_updates; use crate::session::caches::{DSMetadataCache, ManifestKey, TransactionKey}; use crate::session::index_caches::DSIndexCache; use itertools::Itertools; -use lance_core::datatypes::{ - BlobVersion, Field, OnMissing, OnTypeMismatch, Projectable, Projection, -}; +use lance_core::datatypes::{BlobVersion, OnMissing, OnTypeMismatch, Projectable, Projection}; use lance_core::traits::DatasetTakeRows; use lance_core::utils::address::RowAddress; use lance_core::utils::tracing::{ DATASET_CLEANING_EVENT, DATASET_DELETING_EVENT, DATASET_DROPPING_COLUMN_EVENT, TRACE_DATASET_EVENTS, }; -use lance_core::{ROW_ADDR, ROW_ADDR_FIELD, ROW_ID_FIELD}; +use lance_core::ROW_ADDR; use lance_datafusion::projection::ProjectionPlan; use lance_file::datatypes::populate_schema_dictionary; use lance_file::reader::FileReaderOptions; @@ -337,47 +335,9 @@ impl ProjectionRequest { .map(|s| s.as_ref().to_string()) .collect::>(); - // Separate data columns from system columns - // System columns need to be added to the schema manually since Schema::project - // doesn't include them 
(they're virtual columns) - let mut data_columns = Vec::new(); - let mut system_fields = Vec::new(); - - for col in &columns { - if lance_core::is_system_column(col) { - // For now we only support _rowid and _rowaddr in projections - if col == ROW_ID { - system_fields.push(Field::try_from(ROW_ID_FIELD.clone()).unwrap()); - } else if col == ROW_ADDR { - system_fields.push(Field::try_from(ROW_ADDR_FIELD.clone()).unwrap()); - } - // Note: Other system columns like _rowoffset are handled differently - } else { - data_columns.push(col.as_str()); - } - } - - // Project only the data columns - let mut schema = dataset_schema.project(&data_columns).unwrap(); - - // Add system fields in the order they appeared in the original columns list - // We need to reconstruct the proper order - let mut final_fields = Vec::new(); - for col in &columns { - if lance_core::is_system_column(col) { - // Find and add the system field - if let Some(field) = system_fields.iter().find(|f| &f.name == col) { - final_fields.push(field.clone()); - } - } else { - // Find and add the data field - if let Some(field) = schema.fields.iter().find(|f| &f.name == col) { - final_fields.push(field.clone()); - } - } - } - - schema.fields = final_fields; + let schema = dataset_schema + .project_preserve_system_columns(&columns) + .unwrap(); Self::Schema(Arc::new(schema)) } diff --git a/rust/lance/src/dataset/fragment.rs b/rust/lance/src/dataset/fragment.rs index 8428cf619b4..985040227b8 100644 --- a/rust/lance/src/dataset/fragment.rs +++ b/rust/lance/src/dataset/fragment.rs @@ -1376,7 +1376,8 @@ impl FileFragment { }; // Then call take rows - self.take_rows(&row_ids, projection, false, false).await + self.take_rows(&row_ids, projection, false, false, false, false) + .await } /// Get the deletion vector for this fragment, using the cache if available. 
@@ -1424,13 +1425,17 @@ impl FileFragment { projection: &Schema, with_row_id: bool, with_row_address: bool, + with_row_created_at_version: bool, + with_row_last_updated_at_version: bool, ) -> Result { let reader = self .open( projection, FragReadConfig::default() .with_row_id(with_row_id) - .with_row_address(with_row_address), + .with_row_address(with_row_address) + .with_row_created_at_version(with_row_created_at_version) + .with_row_last_updated_at_version(with_row_last_updated_at_version), ) .await?; @@ -3314,7 +3319,14 @@ mod tests { // Repeated indices are repeated in result. let batch = fragment - .take_rows(&[1, 2, 4, 5, 5, 8], dataset.schema(), false, false) + .take_rows( + &[1, 2, 4, 5, 5, 8], + dataset.schema(), + false, + false, + false, + false, + ) .await .unwrap(); assert_eq!( @@ -3333,7 +3345,14 @@ mod tests { .unwrap(); assert!(fragment.metadata().deletion_file.is_some()); let batch = fragment - .take_rows(&[1, 2, 4, 5, 8], dataset.schema(), false, false) + .take_rows( + &[1, 2, 4, 5, 8], + dataset.schema(), + false, + false, + false, + false, + ) .await .unwrap(); assert_eq!( @@ -3343,7 +3362,7 @@ mod tests { // Empty indices gives empty result let batch = fragment - .take_rows(&[], dataset.schema(), false, false) + .take_rows(&[], dataset.schema(), false, false, false, false) .await .unwrap(); assert_eq!( @@ -3353,7 +3372,14 @@ mod tests { // Can get row ids let batch = fragment - .take_rows(&[1, 2, 4, 5, 8], dataset.schema(), false, true) + .take_rows( + &[1, 2, 4, 5, 8], + dataset.schema(), + false, + true, + false, + false, + ) .await .unwrap(); assert_eq!( diff --git a/rust/lance/src/dataset/take.rs b/rust/lance/src/dataset/take.rs index f39531f9853..00a59f15f5a 100644 --- a/rust/lance/src/dataset/take.rs +++ b/rust/lance/src/dataset/take.rs @@ -5,21 +5,24 @@ use std::{collections::BTreeMap, ops::Range, pin::Pin, sync::Arc}; use crate::dataset::fragment::FragReadConfig; use crate::dataset::rowids::get_row_id_index; +use 
crate::io::exec::AddRowOffsetExec; use crate::{Error, Result}; use arrow::{compute::concat_batches, datatypes::UInt64Type}; use arrow_array::cast::AsArray; -use arrow_array::{Array, RecordBatch, StructArray, UInt64Array}; +use arrow_array::{Array, ArrayRef, RecordBatch, StructArray, UInt64Array}; use arrow_buffer::{ArrowNativeType, BooleanBuffer, Buffer, NullBuffer}; use arrow_schema::Field as ArrowField; +use datafusion::common::Column; use datafusion::error::DataFusionError; use datafusion::physical_plan::stream::RecordBatchStreamAdapter; +use datafusion_expr::Expr; use futures::{Future, Stream, StreamExt, TryStreamExt}; use lance_arrow::RecordBatchExt; use lance_core::datatypes::Schema; use lance_core::utils::address::RowAddress; use lance_core::utils::deletion::OffsetMapper; -use lance_core::ROW_ADDR; -use lance_datafusion::projection::ProjectionPlan; +use lance_core::{ROW_ADDR, ROW_OFFSET}; +use lance_datafusion::projection::{OutputColumn, ProjectionPlan}; use snafu::location; use super::ProjectionRequest; @@ -125,12 +128,37 @@ pub async fn take( } /// Take rows by the internal ROW ids. 
+#[allow(clippy::needless_question_mark)] async fn do_take_rows( mut builder: TakeBuilder, projection: Arc, ) -> Result { + // If we need row addresses in output, add to projection's output expressions + let projection = if builder.with_row_address { + let mut proj = (*projection).clone(); + // Add _rowaddr to output if not already present + if !proj + .requested_output_expr + .iter() + .any(|c| c.name == ROW_ADDR) + { + proj.requested_output_expr.push(OutputColumn { + expr: Expr::Column(Column::from_name(ROW_ADDR)), + name: ROW_ADDR.to_string(), + }); + } + Arc::new(proj) + } else { + projection + }; + let with_row_id_in_projection = projection.physical_projection.with_row_id; let with_row_addr_in_projection = projection.physical_projection.with_row_addr; + let with_row_created_at_version_in_projection = + projection.physical_projection.with_row_created_at_version; + let with_row_last_updated_at_version_in_projection = projection + .physical_projection + .with_row_last_updated_at_version; let row_addrs = builder.get_row_addrs().await?.clone(); @@ -160,6 +188,8 @@ async fn do_take_rows( projection: Arc, with_row_id: bool, with_row_addresses: bool, + with_row_created_at_version: bool, + with_row_last_updated_at_version: bool, ) -> impl Future> + Send { async move { fragment @@ -168,14 +198,15 @@ async fn do_take_rows( projection.as_ref(), with_row_id, with_row_addresses, + with_row_created_at_version, + with_row_last_updated_at_version, ) .await } } let physical_schema = Arc::new(projection.physical_projection.to_bare_schema()); - - let batch = if row_addr_stats.contiguous { + let mut batch = if row_addr_stats.contiguous { // Fastest path: Can use `read_range` directly let start = row_addrs.first().expect("empty range passed to take_rows"); let fragment_id = (start >> 32) as usize; @@ -195,7 +226,9 @@ async fn do_take_rows( let read_config = FragReadConfig::default() .with_row_id(with_row_id_in_projection) - .with_row_address(with_row_addr_in_projection); + 
.with_row_address(with_row_addr_in_projection) + .with_row_created_at_version(with_row_created_at_version_in_projection) + .with_row_last_updated_at_version(with_row_last_updated_at_version_in_projection); let reader = fragment.open(&physical_schema, read_config).await?; reader.legacy_read_range_as_batch(range).await } else if row_addr_stats.sorted { @@ -243,6 +276,8 @@ async fn do_take_rows( physical_schema.clone(), with_row_id_in_projection, with_row_addr_in_projection, + with_row_created_at_version_in_projection, + with_row_last_updated_at_version_in_projection, ); batches.push(batch_fut); } @@ -283,6 +318,8 @@ async fn do_take_rows( physical_schema.clone(), with_row_id_in_projection, true, + with_row_created_at_version_in_projection, + with_row_last_updated_at_version_in_projection, ) }) .buffered(builder.dataset.object_store.io_parallelism()) @@ -329,8 +366,8 @@ async fn do_take_rows( Ok(reordered.into()) }?; - let batch = projection.project_batch(batch).await?; - if builder.with_row_address { + if builder.with_row_address || projection.must_add_row_offset { + // compile `ROW_ADDR` column if batch.num_rows() != row_addrs.len() { return Err(Error::NotSupported { source: format!( @@ -342,12 +379,26 @@ async fn do_take_rows( }); } - let row_addr_col = Arc::new(UInt64Array::from(row_addrs)); - let row_addr_field = ArrowField::new(ROW_ADDR, arrow::datatypes::DataType::UInt64, false); - Ok(batch.try_with_column(row_addr_field, row_addr_col)?) 
- } else { - Ok(batch) + let row_addr_col: ArrayRef = Arc::new(UInt64Array::from(row_addrs)); + + if projection.must_add_row_offset { + // compile and inject `ROW_OFFSET` column + let row_offset_col = + AddRowOffsetExec::compute_row_offset_array(&row_addr_col, builder.dataset).await?; + let row_offset_field = + ArrowField::new(ROW_OFFSET, arrow::datatypes::DataType::UInt64, false); + batch = batch.try_with_column(row_offset_field, row_offset_col)?; + } + + if builder.with_row_address { + // inject `ROW_ADDR` column + let row_addr_field = + ArrowField::new(ROW_ADDR, arrow::datatypes::DataType::UInt64, false); + batch = batch.try_with_column(row_addr_field, row_addr_col)?; + } } + + Ok(projection.project_batch(batch).await?) } async fn take_rows(builder: TakeBuilder) -> Result { diff --git a/rust/lance/src/io/exec/rowids.rs b/rust/lance/src/io/exec/rowids.rs index 0078c1256b7..abe8f229157 100644 --- a/rust/lance/src/io/exec/rowids.rs +++ b/rust/lance/src/io/exec/rowids.rs @@ -393,6 +393,13 @@ impl AddRowOffsetExec { input: Arc, dataset: Arc, ) -> LanceResult { + let frag_id_to_offset = Self::compute_frag_id_to_offset(dataset).await?; + Self::internal_new(input, frag_id_to_offset) + } + + async fn compute_frag_id_to_offset( + dataset: Arc, + ) -> LanceResult>> { let mut frag_id_to_offset = HashMap::new(); let mut row_offset = 0; for frag in dataset.get_fragments() { @@ -408,7 +415,15 @@ impl AddRowOffsetExec { row_offset += frag.count_rows(None).await? as u64; } - Self::internal_new(input, Arc::new(frag_id_to_offset)) + Ok(Arc::new(frag_id_to_offset)) + } + + pub async fn compute_row_offset_array( + row_addr: &ArrayRef, + dataset: Arc, + ) -> Result { + let frag_id_to_offset = Self::compute_frag_id_to_offset(dataset).await?; + Self::compute_row_offsets(row_addr, frag_id_to_offset.as_ref()) } fn compute_row_offsets(