Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: cast list items to default before write with different item names #1959

Merged
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 89 additions & 1 deletion crates/deltalake-core/src/operations/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ fn cast_record_batch_columns(
.iter()
.map(|f| {
let col = batch.column_by_name(f.name()).unwrap();

if let (DataType::Struct(_), DataType::Struct(child_fields)) =
(col.data_type(), f.data_type())
{
Expand All @@ -28,7 +29,7 @@ fn cast_record_batch_columns(
child_columns.clone(),
None,
)) as ArrayRef)
} else if !col.data_type().equals_datatype(f.data_type()) {
} else if is_cast_required(col.data_type(), f.data_type()) {
cast_with_options(col, f.data_type(), cast_options)
} else {
Ok(col.clone())
Expand All @@ -37,6 +38,16 @@ fn cast_record_batch_columns(
.collect::<Result<Vec<_>, _>>()
}

fn is_cast_required(a: &DataType, b: &DataType) -> bool {
match (a, b) {
(DataType::List(a_item), DataType::List(b_item)) => {
// If list item name is not the default('item') the list must be casted
!a.equals_datatype(b) || a_item.name() != b_item.name()
ion-elgreco marked this conversation as resolved.
Show resolved Hide resolved
}
(_, _) => !a.equals_datatype(b),
}
}

/// Cast recordbatch to a new target_schema, by casting each column array
pub fn cast_record_batch(
batch: &RecordBatch,
Expand All @@ -51,3 +62,80 @@ pub fn cast_record_batch(
let columns = cast_record_batch_columns(batch, target_schema.fields(), &cast_options)?;
Ok(RecordBatch::try_new(target_schema, columns)?)
}

#[cfg(test)]
mod tests {
use crate::operations::cast::{cast_record_batch, is_cast_required};
use arrow::array::ArrayData;
use arrow_array::{Array, ArrayRef, ListArray, RecordBatch};
use arrow_buffer::Buffer;
use arrow_schema::{DataType, Field, FieldRef, Fields, Schema, SchemaRef};
use std::sync::Arc;

#[test]
fn test_cast_record_batch_with_list_non_default_item() {
let array = Arc::new(make_list_array()) as ArrayRef;
let source_schema = Schema::new(vec![Field::new(
"list_column",
array.data_type().clone(),
false,
)]);
let record_batch = RecordBatch::try_new(Arc::new(source_schema), vec![array]).unwrap();

let fields = Fields::from(vec![Field::new_list(
"list_column",
Field::new("item", DataType::Int8, false),
false,
)]);
let target_schema = Arc::new(Schema::new(fields)) as SchemaRef;

let result = cast_record_batch(&record_batch, target_schema, false);

let schema = result.unwrap().schema();
let field = schema.column_with_name("list_column").unwrap().1;
if let DataType::List(list_item) = field.data_type() {
assert_eq!(list_item.name(), "item");
} else {
panic!("Not a list");
}
}

fn make_list_array() -> ListArray {
let value_data = ArrayData::builder(DataType::Int32)
.len(8)
.add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7]))
.build()
.unwrap();

let value_offsets = Buffer::from_slice_ref([0, 3, 6, 8]);

let list_data_type = DataType::List(Arc::new(Field::new("element", DataType::Int32, true)));
let list_data = ArrayData::builder(list_data_type)
.len(3)
.add_buffer(value_offsets)
.add_child_data(value_data)
.build()
.unwrap();
ListArray::from(list_data)
}

#[test]
fn test_is_cast_required_with_list() {
let field1 = DataType::List(FieldRef::from(Field::new("item", DataType::Int32, false)));
let field2 = DataType::List(FieldRef::from(Field::new("item", DataType::Int32, false)));

assert!(!is_cast_required(&field1, &field2));
}

#[test]
fn test_is_cast_required_with_list_non_default_item() {
let field1 = DataType::List(FieldRef::from(Field::new("item", DataType::Int32, false)));
let field2 = DataType::List(FieldRef::from(Field::new(
"element",
DataType::Int32,
false,
)));

assert!(is_cast_required(&field1, &field2));
}
}
Loading