146 changes: 146 additions & 0 deletions python/python/tests/test_dataset.py
@@ -685,6 +685,152 @@ def test_take_rowid_rowaddr(tmp_path: Path):
assert sample_dataset.num_columns == 2


@pytest.mark.parametrize(
"column_name",
[
"_rowid",
"_rowaddr",
"_rowoffset",
"_row_created_at_version",
"_row_last_updated_at_version",
],
)
def test_take_system_columns_values(tmp_path: Path, column_name: str):
"""Test that system columns return correct values in take."""
table = pa.table({"a": range(100), "b": range(100, 200)})
base_dir = tmp_path / "test_take_system_columns_values"
# Use max_rows_per_file to create multiple fragments
lance.write_dataset(table, base_dir, max_rows_per_file=25)
dataset = lance.dataset(base_dir)

indices = [0, 5, 10, 50, 99]
result = dataset.take(indices, columns=[column_name, "a"])
assert result.num_rows == len(indices)
assert result.schema.names == [column_name, "a"]

col_values = result.column(column_name).to_pylist()
a_values = result.column("a").to_pylist()

# Verify column type is UInt64
assert result.column(column_name).type == pa.uint64()

# Verify data column values
assert a_values == indices

# Verify system column values based on column type
if column_name == "_rowid":
# Without stable row IDs, _rowid equals _rowaddr (not the index).
# Row address = (fragment_id << 32) | row_offset_within_fragment
# With max_rows_per_file=25: frag0=0-24, frag1=25-49, frag2=50-74, frag3=75-99
expected_rowids = [
(0 << 32) | 0, # index 0: fragment 0, offset 0
(0 << 32) | 5, # index 5: fragment 0, offset 5
(0 << 32) | 10, # index 10: fragment 0, offset 10
(2 << 32) | 0, # index 50: fragment 2, offset 0
(3 << 32) | 24, # index 99: fragment 3, offset 24
]
assert col_values == expected_rowids
elif column_name in ("_row_created_at_version", "_row_last_updated_at_version"):
# All rows created/updated at version 1
assert col_values == [1] * len(indices)
# _rowaddr and _rowoffset values depend on fragment layout


def test_take_system_columns_column_ordering(tmp_path: Path):
"""Test that column ordering is preserved when using system columns."""
table = pa.table({"a": range(50), "b": range(50, 100)})
base_dir = tmp_path / "test_take_column_ordering"
lance.write_dataset(table, base_dir)
dataset = lance.dataset(base_dir)

indices = [0, 1, 2]

# Test different orderings with all system columns
result = dataset.take(indices, columns=["_rowid", "a", "_rowaddr"])
assert result.schema.names == ["_rowid", "a", "_rowaddr"]

result = dataset.take(indices, columns=["a", "_rowaddr", "_rowid"])
assert result.schema.names == ["a", "_rowaddr", "_rowid"]

result = dataset.take(indices, columns=["_rowaddr", "_rowid", "b", "a"])
assert result.schema.names == ["_rowaddr", "_rowid", "b", "a"]

# Test with version columns
result = dataset.take(
indices,
columns=[
"_row_created_at_version",
"a",
"_row_last_updated_at_version",
"_rowid",
],
)
assert result.schema.names == [
"_row_created_at_version",
"a",
"_row_last_updated_at_version",
"_rowid",
]

# Test with all system columns in mixed order
result = dataset.take(
indices,
columns=[
"_rowoffset",
"_row_last_updated_at_version",
"b",
"_rowaddr",
"_row_created_at_version",
"a",
"_rowid",
],
)
assert result.schema.names == [
"_rowoffset",
"_row_last_updated_at_version",
"b",
"_rowaddr",
"_row_created_at_version",
"a",
"_rowid",
]


def test_take_version_system_columns(tmp_path: Path):
"""Test _row_created_at_version and _row_last_updated_at_version columns."""
table = pa.table({"a": range(50)})
base_dir = tmp_path / "test_take_version_columns"
lance.write_dataset(table, base_dir, enable_stable_row_ids=True)
dataset = lance.dataset(base_dir)

# Initial version is 1
initial_version = dataset.version

indices = [0, 10, 25]
result = dataset.take(
indices,
columns=["a", "_row_created_at_version", "_row_last_updated_at_version"],
)

assert result.num_rows == 3
created_at = result.column("_row_created_at_version").to_pylist()
updated_at = result.column("_row_last_updated_at_version").to_pylist()

# All rows were created and last updated at the initial version
assert created_at == [initial_version] * 3
assert updated_at == [initial_version] * 3

# Now append new rows in a second commit
table2 = pa.table({"a": range(50, 100)})
lance.write_dataset(table2, base_dir, mode="append")
dataset = lance.dataset(base_dir)

# New rows should have version 2
result = dataset.take([50, 60], columns=["_row_created_at_version"])
created_at = result.column("_row_created_at_version").to_pylist()
assert created_at == [dataset.version] * 2


@pytest.mark.parametrize("indices", [[], [1, 1], [1, 1, 20, 20, 21], [21, 0, 21, 1, 0]])
def test_take_duplicate_index(tmp_path: Path, indices: List[int]):
table = pa.table({"x": range(24)})
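Taken together, the new tests pin down the user-visible contract of `take` with system columns. Below is a minimal usage sketch of that contract; the dataset path is illustrative, and the expected values assume non-stable row IDs and 25-row fragments, exactly as in `test_take_system_columns_values` above.

```python
import lance
import pyarrow as pa

table = pa.table({"a": range(100)})
lance.write_dataset(table, "/tmp/system_columns_demo", max_rows_per_file=25)
ds = lance.dataset("/tmp/system_columns_demo")

# Row addresses pack the fragment id into the high 32 bits:
#   _rowaddr = (fragment_id << 32) | offset_within_fragment
batch = ds.take([0, 50, 99], columns=["_rowaddr", "a"])
assert batch.column("_rowaddr").to_pylist() == [
    (0 << 32) | 0,   # row 0  -> fragment 0, offset 0
    (2 << 32) | 0,   # row 50 -> fragment 2, offset 0
    (3 << 32) | 24,  # row 99 -> fragment 3, offset 24
]

# Version columns record the commit that created (or last rewrote) a row.
lance.write_dataset(
    pa.table({"a": range(100, 150)}), "/tmp/system_columns_demo", mode="append"
)
ds = lance.dataset("/tmp/system_columns_demo")
created = ds.take([0, 100], columns=["_row_created_at_version"])
assert created.column("_row_created_at_version").to_pylist() == [1, 2]
```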
2 changes: 1 addition & 1 deletion python/src/dataset.rs
@@ -1299,7 +1299,7 @@ impl Dataset {
Arc::new(
self.ds
.schema()
.project(&columns)
.project_preserve_system_columns(&columns)
.map_err(|err| PyValueError::new_err(err.to_string()))?,
)
} else {
Expand Down
82 changes: 77 additions & 5 deletions rust/lance-core/src/datatypes/schema.rs
@@ -16,7 +16,11 @@ use lance_arrow::*;
use snafu::location;

use super::field::{BlobVersion, Field, OnTypeMismatch, SchemaCompareOptions};
use crate::{Error, Result, ROW_ADDR, ROW_ADDR_FIELD, ROW_ID, ROW_ID_FIELD, WILDCARD};
use crate::{
Error, Result, ROW_ADDR, ROW_ADDR_FIELD, ROW_CREATED_AT_VERSION, ROW_CREATED_AT_VERSION_FIELD,
ROW_ID, ROW_ID_FIELD, ROW_LAST_UPDATED_AT_VERSION, ROW_LAST_UPDATED_AT_VERSION_FIELD,
ROW_OFFSET, ROW_OFFSET_FIELD, WILDCARD,
};

/// Lance Schema.
#[derive(Default, Debug, Clone, DeepSizeOf)]
@@ -221,7 +225,12 @@ impl Schema {
}
}

fn do_project<T: AsRef<str>>(&self, columns: &[T], err_on_missing: bool) -> Result<Self> {
fn do_project<T: AsRef<str>>(
&self,
columns: &[T],
err_on_missing: bool,
preserve_system_columns: bool,
) -> Result<Self> {
let mut candidates: Vec<Field> = vec![];
for col in columns {
let split = parse_field_path(col.as_ref())?;
@@ -234,7 +243,30 @@
} else {
candidates.push(projected_field)
}
} else if err_on_missing && first != ROW_ID && first != ROW_ADDR {
} else if crate::is_system_column(first) {
if preserve_system_columns {
if first == ROW_ID {
candidates.push(Field::try_from(ROW_ID_FIELD.clone())?);
} else if first == ROW_ADDR {
candidates.push(Field::try_from(ROW_ADDR_FIELD.clone())?);
} else if first == ROW_OFFSET {
candidates.push(Field::try_from(ROW_OFFSET_FIELD.clone())?);
} else if first == ROW_CREATED_AT_VERSION {
candidates.push(Field::try_from(ROW_CREATED_AT_VERSION_FIELD.clone())?);
} else if first == ROW_LAST_UPDATED_AT_VERSION {
candidates
.push(Field::try_from(ROW_LAST_UPDATED_AT_VERSION_FIELD.clone())?);
} else {
return Err(Error::Schema {
message: format!(
"System column {} is currently not supported in projection",
first
),
location: location!(),
});
}
}
} else if err_on_missing {
return Err(Error::Schema {
message: format!("Column {} does not exist", col.as_ref()),
location: location!(),
@@ -255,12 +287,17 @@
/// let projected = schema.project(&["col1", "col2.sub_col3.field4"])?;
/// ```
pub fn project<T: AsRef<str>>(&self, columns: &[T]) -> Result<Self> {
self.do_project(columns, true)
self.do_project(columns, true, false)
}

/// Project the columns over the schema, dropping unrecognized columns
pub fn project_or_drop<T: AsRef<str>>(&self, columns: &[T]) -> Result<Self> {
self.do_project(columns, false)
self.do_project(columns, false, false)
}

/// Project the columns over the schema, preserving system columns.
pub fn project_preserve_system_columns<T: AsRef<str>>(&self, columns: &[T]) -> Result<Self> {
self.do_project(columns, true, true)
}

/// Check that the top level fields don't contain `.` in their names
@@ -1832,6 +1869,41 @@ mod tests {
assert_eq!(ArrowSchema::from(&projected), expected_arrow_schema);
}

#[test]
fn test_schema_projection_preserving_system_columns() {
let arrow_schema = ArrowSchema::new(vec![
ArrowField::new("a", DataType::Int32, false),
ArrowField::new(
"b",
DataType::Struct(ArrowFields::from(vec![
ArrowField::new("f1", DataType::Utf8, true),
ArrowField::new("f2", DataType::Boolean, false),
ArrowField::new("f3", DataType::Float32, false),
])),
true,
),
ArrowField::new("c", DataType::Float64, false),
]);
let schema = Schema::try_from(&arrow_schema).unwrap();
let projected = schema
.project_preserve_system_columns(&["b.f1", "b.f3", "_rowid", "c"])
.unwrap();

let expected_arrow_schema = ArrowSchema::new(vec![
ArrowField::new(
"b",
DataType::Struct(ArrowFields::from(vec![
ArrowField::new("f1", DataType::Utf8, true),
ArrowField::new("f3", DataType::Float32, false),
])),
true,
),
ArrowField::new("_rowid", DataType::UInt64, true),
ArrowField::new("c", DataType::Float64, false),
]);
assert_eq!(ArrowSchema::from(&projected), expected_arrow_schema);
}

#[test]
fn test_schema_project_by_ids() {
let arrow_schema = ArrowSchema::new(vec![
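The branching that `do_project` grows above is easier to see without the Rust plumbing. Here is a rough Python model of the three entry points; it is a toy that only handles top-level columns (the real code also projects nested `a.b.c` paths), and `SYSTEM_FIELDS` is a stand-in for the `ROW_*_FIELD` constants:

```python
import pyarrow as pa

# Stand-ins for ROW_ID_FIELD etc.; nullability matches the expected
# schema in test_schema_projection_preserving_system_columns above.
SYSTEM_FIELDS = {
    name: pa.field(name, pa.uint64(), nullable=True)
    for name in (
        "_rowid",
        "_rowaddr",
        "_rowoffset",
        "_row_created_at_version",
        "_row_last_updated_at_version",
    )
}

def do_project(schema: pa.Schema, columns, err_on_missing: bool, preserve: bool) -> pa.Schema:
    out = []
    for col in columns:
        top = col.split(".")[0]
        if top in schema.names:
            out.append(schema.field(top))  # real code recurses into nested fields
        elif top in SYSTEM_FIELDS:
            if preserve:
                out.append(SYSTEM_FIELDS[top])
            # otherwise the system column is silently dropped
        elif err_on_missing:
            raise KeyError(f"Column {col} does not exist")
    return pa.schema(out)

# project(cols)                         -> do_project(schema, cols, True, False)
# project_or_drop(cols)                 -> do_project(schema, cols, False, False)
# project_preserve_system_columns(cols) -> do_project(schema, cols, True, True)
```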
24 changes: 21 additions & 3 deletions rust/lance-datafusion/src/projection.rs
@@ -264,6 +264,8 @@ impl ProjectionPlan {
let mut with_row_id = false;
let mut with_row_addr = false;
let mut must_add_row_offset = false;
let mut with_row_last_updated_at_version = false;
let mut with_row_created_at_version = false;

for field in projection.fields.iter() {
if lance_core::is_system_column(&field.name) {
@@ -273,10 +275,14 @@
must_add_row_offset = true;
} else if field.name == ROW_ADDR {
with_row_addr = true;
} else if field.name == ROW_OFFSET {
with_row_addr = true;
must_add_row_offset = true;
} else if field.name == ROW_LAST_UPDATED_AT_VERSION {
with_row_last_updated_at_version = true;
} else if field.name == ROW_CREATED_AT_VERSION {
with_row_created_at_version = true;
}
// Note: remaining system columns are computed elsewhere
// and shouldn't appear in the schema at this point
} else {
// Regular data column - validate it exists in base schema
if base.schema().field(&field.name).is_none() {
@@ -301,6 +307,8 @@
.with_blob_version(blob_version);
physical_projection.with_row_id = with_row_id;
physical_projection.with_row_addr = with_row_addr;
physical_projection.with_row_last_updated_at_version = with_row_last_updated_at_version;
physical_projection.with_row_created_at_version = with_row_created_at_version;

// Build output expressions preserving the original order (including system columns)
let exprs = projection
@@ -431,8 +439,18 @@
#[instrument(skip_all, level = "debug")]
pub async fn project_batch(&self, batch: RecordBatch) -> Result<RecordBatch> {
let src = Arc::new(OneShotExec::from_batch(batch));
let physical_exprs = self.to_physical_exprs(&self.physical_projection.to_arrow_schema())?;

// Widen the schema with ROW_ADDR and ROW_OFFSET so the physical exprs can resolve them
let extra_columns = vec![
ArrowField::new(ROW_ADDR, DataType::UInt64, true),
ArrowField::new(ROW_OFFSET, DataType::UInt64, true),
];
let mut filterable_schema = self.physical_projection.to_schema();
filterable_schema = filterable_schema.merge(&ArrowSchema::new(extra_columns))?;

let physical_exprs = self.to_physical_exprs(&(&filterable_schema).into())?;
let projection = Arc::new(ProjectionExec::try_new(physical_exprs, src)?);

// Run dummy plan to execute projection, do not log the plan run
let stream = execute_plan(
projection,
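On the scan side, the planner translates requested system columns into fetch flags on the physical projection. A condensed Python rendering of the visible flag logic (the `_rowid` branch is hidden inside the collapsed hunk above, so its body here is an assumption; note that `_rowoffset` is derived from the row address, which is why requesting it also turns on `with_row_addr`):

```python
def derive_fetch_flags(requested: list[str]) -> dict[str, bool]:
    """Toy model of the flag derivation in ProjectionPlan shown above."""
    flags = dict.fromkeys(
        [
            "with_row_id",
            "with_row_addr",
            "must_add_row_offset",
            "with_row_created_at_version",
            "with_row_last_updated_at_version",
        ],
        False,
    )
    for name in requested:
        if name == "_rowid":
            flags["with_row_id"] = True  # assumed: this branch is collapsed in the diff
        elif name == "_rowaddr":
            flags["with_row_addr"] = True
        elif name == "_rowoffset":
            # _rowoffset is computed from the row address at read time,
            # so the row address must be fetched as well
            flags["with_row_addr"] = True
            flags["must_add_row_offset"] = True
        elif name == "_row_created_at_version":
            flags["with_row_created_at_version"] = True
        elif name == "_row_last_updated_at_version":
            flags["with_row_last_updated_at_version"] = True
    return flags

assert derive_fetch_flags(["_rowoffset"])["with_row_addr"] is True
```

The same dependency explains the `project_batch` change: the filterable schema is widened with `_rowaddr` and `_rowoffset` so the output expressions can resolve those names even when the batch being projected did not carry them explicitly.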