diff --git a/Cargo.lock b/Cargo.lock
index 72dd645b439..6afe9459890 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -8652,6 +8652,7 @@ dependencies = [
  "datafusion-datasource",
  "datafusion-execution",
  "datafusion-expr",
+ "datafusion-functions",
  "datafusion-physical-expr",
  "datafusion-physical-expr-adapter",
  "datafusion-physical-expr-common",
diff --git a/Cargo.toml b/Cargo.toml
index 2d3ff3c48eb..f85a4291f80 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -101,6 +101,7 @@ datafusion-common-runtime = { version = "50" }
 datafusion-datasource = { version = "50", default-features = false }
 datafusion-execution = { version = "50" }
 datafusion-expr = { version = "50" }
+datafusion-functions = { version = "50" }
 datafusion-physical-expr = { version = "50" }
 datafusion-physical-expr-adapter = { version = "50" }
 datafusion-physical-expr-common = { version = "50" }
diff --git a/vortex-datafusion/Cargo.toml b/vortex-datafusion/Cargo.toml
index 65e9a8d2c42..c929b1befc1 100644
--- a/vortex-datafusion/Cargo.toml
+++ b/vortex-datafusion/Cargo.toml
@@ -23,6 +23,7 @@ datafusion-common-runtime = { workspace = true }
 datafusion-datasource = { workspace = true, default-features = false }
 datafusion-execution = { workspace = true }
 datafusion-expr = { workspace = true }
+datafusion-functions = { workspace = true }
 datafusion-physical-expr = { workspace = true }
 datafusion-physical-expr-adapter = { workspace = true }
 datafusion-physical-expr-common = { workspace = true }
diff --git a/vortex-datafusion/src/convert/exprs.rs b/vortex-datafusion/src/convert/exprs.rs
index e1072bfa1d5..047333a2ed1 100644
--- a/vortex-datafusion/src/convert/exprs.rs
+++ b/vortex-datafusion/src/convert/exprs.rs
@@ -5,7 +5,8 @@
 use std::sync::Arc;
 use arrow_schema::{DataType, Schema};
 use datafusion_expr::Operator as DFOperator;
-use datafusion_physical_expr::{PhysicalExpr, PhysicalExprRef};
+use datafusion_functions::core::getfield::GetFieldFunc;
+use datafusion_physical_expr::{PhysicalExpr, PhysicalExprRef, ScalarFunctionExpr};
 use datafusion_physical_expr_common::physical_expr::is_dynamic_physical_expr;
 use datafusion_physical_plan::expressions as df_expr;
 use itertools::Itertools;
@@ -104,10 +105,47 @@ impl TryFromDataFusion for ExprRef {
             return Ok(if in_list.negated() { not(expr) } else { expr });
         }
 
+        if let Some(scalar_fn) = df.as_any().downcast_ref::<ScalarFunctionExpr>() {
+            return try_convert_scalar_function(scalar_fn);
+        }
+
         vortex_bail!("Couldn't convert DataFusion physical {df} expression to a vortex expression")
     }
 }
 
+/// Attempts to convert a DataFusion ScalarFunctionExpr to a Vortex expression.
+fn try_convert_scalar_function(scalar_fn: &ScalarFunctionExpr) -> VortexResult<ExprRef> {
+    if let Some(get_field_fn) = ScalarFunctionExpr::try_downcast_func::<GetFieldFunc>(scalar_fn) {
+        let source_expr = get_field_fn
+            .args()
+            .first()
+            .ok_or_else(|| vortex_err!("get_field missing source expression"))?
+            .as_ref();
+        let field_name_expr = get_field_fn
+            .args()
+            .get(1)
+            .ok_or_else(|| vortex_err!("get_field missing field name argument"))?;
+        let field_name = field_name_expr
+            .as_any()
+            .downcast_ref::<df_expr::Literal>()
+            .ok_or_else(|| vortex_err!("get_field field name must be a literal"))?
+            .value()
+            .try_as_str()
+            .flatten()
+            .ok_or_else(|| vortex_err!("get_field field name must be a UTF-8 string"))?;
+        return Ok(get_item(
+            field_name.to_string(),
+            ExprRef::try_from_df(source_expr)?,
+        ));
+    }
+
+    tracing::debug!(
+        function_name = scalar_fn.name(),
+        "Unsupported ScalarFunctionExpr"
+    );
+    vortex_bail!("Unsupported ScalarFunctionExpr: {}", scalar_fn.name())
+}
+
 impl TryFromDataFusion for Operator {
     fn try_from_df(value: &DFOperator) -> VortexResult<Self> {
         match value {
@@ -188,6 +226,9 @@ pub(crate) fn can_be_pushed_down(df_expr: &PhysicalExprRef, schema: &Schema) ->
     } else if let Some(in_list) = expr.downcast_ref::<df_expr::InListExpr>() {
         can_be_pushed_down(in_list.expr(), schema)
             && in_list.list().iter().all(|e| can_be_pushed_down(e, schema))
+    } else if let Some(scalar_fn) = expr.downcast_ref::<ScalarFunctionExpr>() {
+        // Only get_field pushdown is supported.
+        ScalarFunctionExpr::try_downcast_func::<GetFieldFunc>(scalar_fn).is_some()
     } else {
         tracing::debug!(%df_expr, "DataFusion expression can't be pushed down");
         false
@@ -203,6 +244,12 @@ fn can_binary_be_pushed_down(binary: &df_expr::BinaryExpr, schema: &Schema) -> b
 
 fn supported_data_types(dt: &DataType) -> bool {
     use DataType::*;
+
+    // For dictionary types, check if the value type is supported.
+    if let Dictionary(_, value_type) = dt {
+        return supported_data_types(value_type.as_ref());
+    }
+
     let is_supported = dt.is_null()
         || dt.is_numeric()
         || matches!(
@@ -232,9 +279,11 @@
 mod tests {
     use std::sync::Arc;
 
-    use arrow_schema::{DataType, Field, Schema, TimeUnit as ArrowTimeUnit};
+    use arrow_schema::{DataType, Field, Fields, Schema, TimeUnit as ArrowTimeUnit};
+    use datafusion::functions::core::getfield::GetFieldFunc;
     use datafusion_common::ScalarValue;
-    use datafusion_expr::Operator as DFOperator;
+    use datafusion_common::config::ConfigOptions;
+    use datafusion_expr::{Operator as DFOperator, ScalarUDF};
     use datafusion_physical_expr::PhysicalExpr;
     use datafusion_physical_plan::expressions as df_expr;
     use insta::assert_snapshot;
@@ -415,6 +464,22 @@
         false
     )]
     #[case::struct_type(DataType::Struct(vec![Field::new("field", DataType::Int32, true)].into()), false)]
+    // Dictionary types - should be supported if value type is supported
+    #[case::dict_utf8(
+        DataType::Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Utf8)),
+        true
+    )]
+    #[case::dict_int32(
+        DataType::Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Int32)),
+        true
+    )]
+    #[case::dict_unsupported(
+        DataType::Dictionary(
+            Box::new(DataType::UInt32),
+            Box::new(DataType::List(Arc::new(Field::new("item", DataType::Int32, true))))
+        ),
+        false
+    )]
     fn test_supported_data_types(#[case] data_type: DataType, #[case] expected: bool) {
         assert_eq!(supported_data_types(&data_type), expected);
     }
@@ -518,4 +583,53 @@
 
         assert!(!can_be_pushed_down(&like_expr, &test_schema));
     }
+
+    #[test]
+    fn test_expr_from_df_get_field() {
+        let struct_col = Arc::new(df_expr::Column::new("my_struct", 0)) as Arc<dyn PhysicalExpr>;
+        let field_name = Arc::new(df_expr::Literal::new(ScalarValue::Utf8(Some(
+            "field1".to_string(),
+        )))) as Arc<dyn PhysicalExpr>;
+        let get_field_expr = ScalarFunctionExpr::new(
+            "get_field",
+            Arc::new(ScalarUDF::from(GetFieldFunc::new())),
+            vec![struct_col, field_name],
+            Arc::new(Field::new("field1", DataType::Utf8, true)),
+            Arc::new(ConfigOptions::new()),
+        );
+        let result = ExprRef::try_from_df(&get_field_expr).unwrap();
+        assert_snapshot!(result.display_tree().to_string(), @r"
+        GetItem(field1)
+        └── GetItem(my_struct)
+            └── Root
+        ");
+    }
+
+    #[test]
+    fn test_can_be_pushed_down_get_field() {
+        let struct_fields = Fields::from(vec![
+            Field::new("field1", DataType::Utf8, true),
+            Field::new("field2", DataType::Int32, true),
+        ]);
+        let schema = Schema::new(vec![Field::new(
+            "my_struct",
+            DataType::Struct(struct_fields),
+            true,
+        )]);
+
+        let struct_col = Arc::new(df_expr::Column::new("my_struct", 0)) as Arc<dyn PhysicalExpr>;
+        let field_name = Arc::new(df_expr::Literal::new(ScalarValue::Utf8(Some(
+            "field1".to_string(),
+        )))) as Arc<dyn PhysicalExpr>;
+
+        let get_field_expr = Arc::new(ScalarFunctionExpr::new(
+            "get_field",
+            Arc::new(ScalarUDF::from(GetFieldFunc::new())),
+            vec![struct_col, field_name],
+            Arc::new(Field::new("field1", DataType::Utf8, true)),
+            Arc::new(ConfigOptions::new()),
+        )) as Arc<dyn PhysicalExpr>;
+
+        assert!(can_be_pushed_down(&get_field_expr, &schema));
+    }
 }
diff --git a/vortex-datafusion/src/convert/scalars.rs b/vortex-datafusion/src/convert/scalars.rs
index 19a5118fd94..8558167c758 100644
--- a/vortex-datafusion/src/convert/scalars.rs
+++ b/vortex-datafusion/src/convert/scalars.rs
@@ -236,6 +236,7 @@ impl FromDataFusion for Scalar {
                     Scalar::null(DType::Decimal(decimal_dtype, nullable))
                 }
             }
+            ScalarValue::Dictionary(_, v) => Scalar::from_df(v.as_ref()),
             _ => unimplemented!("Can't convert {value:?} value to a Vortex scalar"),
         }
     }
 }
diff --git a/vortex-datafusion/src/persistent/opener.rs b/vortex-datafusion/src/persistent/opener.rs
index 5169da6bd8c..9fe2e2b63a9 100644
--- a/vortex-datafusion/src/persistent/opener.rs
+++ b/vortex-datafusion/src/persistent/opener.rs
@@ -4,7 +4,7 @@
 use std::ops::Range;
 use std::sync::{Arc, Weak};
 
-use arrow_schema::{ArrowError, Field, SchemaRef};
+use arrow_schema::{ArrowError, DataType, Field, SchemaRef};
 use datafusion_common::arrow::array::RecordBatch;
 use datafusion_common::{DataFusionError, Result as DFResult};
 use datafusion_datasource::file_meta::FileMeta;
@@ -61,6 +61,87 @@ pub(crate) struct VortexOpener {
     pub has_output_ordering: bool,
 }
 
+/// Merges the data types of two fields, preferring the logical type from the
+/// table field.
+fn merge_field_types(physical_field: &Field, table_field: &Field) -> DataType {
+    match (physical_field.data_type(), table_field.data_type()) {
+        (DataType::Struct(phys_fields), DataType::Struct(table_fields)) => {
+            let merged_fields = merge_fields(phys_fields, table_fields);
+            DataType::Struct(merged_fields.into())
+        }
+        (DataType::List(phys_field), DataType::List(table_field)) => {
+            DataType::List(Arc::new(Field::new(
+                phys_field.name(),
+                merge_field_types(phys_field, table_field),
+                phys_field.is_nullable(),
+            )))
+        }
+        (DataType::LargeList(phys_field), DataType::LargeList(table_field)) => {
+            DataType::LargeList(Arc::new(Field::new(
+                phys_field.name(),
+                merge_field_types(phys_field, table_field),
+                phys_field.is_nullable(),
+            )))
+        }
+        _ => table_field.data_type().clone(),
+    }
+}
+
+/// Merges two field collections, using logical types from table_fields where available.
+/// Falls back to physical field types when no matching table field is found.
+fn merge_fields(
+    physical_fields: &arrow_schema::Fields,
+    table_fields: &arrow_schema::Fields,
+) -> Vec<Field> {
+    physical_fields
+        .iter()
+        .map(|phys_field| {
+            table_fields
+                .iter()
+                .find(|f| f.name() == phys_field.name())
+                .map(|table_field| {
+                    Field::new(
+                        phys_field.name(),
+                        merge_field_types(phys_field, table_field),
+                        phys_field.is_nullable(),
+                    )
+                })
+                .unwrap_or_else(|| (**phys_field).clone())
+        })
+        .collect()
+}
+
+/// Computes a logical file schema from the physical file schema and the table
+/// schema.
+///
+/// For each field in the physical file schema, looks up the corresponding field
+/// in the table schema and uses its logical type.
+fn compute_logical_file_schema(
+    physical_file_schema: &SchemaRef,
+    table_schema: &SchemaRef,
+) -> SchemaRef {
+    let logical_fields: Vec<Field> = physical_file_schema
+        .fields()
+        .iter()
+        .map(|physical_field| {
+            table_schema
+                .fields()
+                .find(physical_field.name())
+                .map(|(_, table_field)| {
+                    Field::new(
+                        physical_field.name(),
+                        merge_field_types(physical_field, table_field),
+                        physical_field.is_nullable(),
+                    )
+                    .with_metadata(physical_field.metadata().clone())
+                })
+                .unwrap_or_else(|| (**physical_field).clone())
+        })
+        .collect();
+
+    Arc::new(arrow_schema::Schema::new(logical_fields))
+}
+
 impl FileOpener for VortexOpener {
     fn open(&self, file_meta: FileMeta, file: PartitionedFile) -> DFResult<FileOpenFuture> {
         let object_store = self.object_store.clone();
@@ -143,8 +224,11 @@ impl FileOpener for VortexOpener {
             // for schema evolution and divergence between the table's schema and individual files.
             filter = filter
                 .map(|filter| {
+                    let logical_file_schema =
+                        compute_logical_file_schema(&physical_file_schema, &logical_schema);
+
                     let expr = expr_adapter_factory
-                        .create(logical_schema.clone(), physical_file_schema.clone())
+                        .create(logical_file_schema, physical_file_schema.clone())
                        .with_partition_values(partition_values)
                        .rewrite(filter)?;
 
@@ -302,8 +386,9 @@ fn byte_range_to_row_range(byte_range: Range<u64>, row_count: u64, total_size: u
 
 #[cfg(test)]
 mod tests {
+    use arrow_schema::Fields;
     use chrono::Utc;
-    use datafusion::arrow::array::RecordBatch;
+    use datafusion::arrow::array::{RecordBatch, StringArray, StructArray};
     use datafusion::arrow::datatypes::{DataType, Schema};
     use datafusion::arrow::util::display::FormatOptions;
     use datafusion::common::record_batch;
@@ -563,4 +648,90 @@
 
         Ok(())
     }
+
+    #[tokio::test]
+    // This test verifies that expression rewriting doesn't fail when there is
+    // a nested schema mismatch between the physical file schema and logical
+    // table schema.
+    async fn test_adapter_logical_physical_struct_mismatch() -> anyhow::Result<()> {
+        let vx_session = Arc::new(VortexSession::default());
+        let object_store = Arc::new(InMemory::new()) as Arc<dyn ObjectStore>;
+        let file_path = "/path/file.vortex";
+        let file_struct_fields = Fields::from(vec![
+            Field::new("field1", DataType::Utf8, true),
+            Field::new("field2", DataType::Utf8, true),
+        ]);
+        let struct_array = StructArray::new(
+            file_struct_fields.clone(),
+            vec![
+                Arc::new(StringArray::from(vec!["value1", "value2", "value3"])),
+                Arc::new(StringArray::from(vec!["a", "b", "c"])),
+            ],
+            None,
+        );
+        let batch = RecordBatch::try_new(
+            Arc::new(Schema::new(vec![Field::new(
+                "my_struct",
+                DataType::Struct(file_struct_fields),
+                true,
+            )])),
+            vec![Arc::new(struct_array)],
+        )?;
+        let data_size = write_arrow_to_vortex(object_store.clone(), file_path, batch).await?;
+
+        // Table schema has an extra utf8 field.
+        let table_schema = Arc::new(Schema::new(vec![Field::new(
+            "my_struct",
+            DataType::Struct(Fields::from(vec![
+                Field::new(
+                    "field1",
+                    DataType::Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Utf8)),
+                    true,
+                ),
+                Field::new(
+                    "field2",
+                    DataType::Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Utf8)),
+                    true,
+                ),
+                Field::new("field3", DataType::Utf8, true),
+            ])),
+            true,
+        )]));
+
+        let opener = VortexOpener {
+            object_store: object_store.clone(),
+            projection: None,
+            filter: Some(logical2physical(
+                &col("my_struct").is_not_null(),
+                &table_schema,
+            )),
+            file_pruning_predicate: None,
+            expr_adapter_factory: Some(Arc::new(DefaultPhysicalExprAdapterFactory) as _),
+            schema_adapter_factory: Arc::new(DefaultSchemaAdapterFactory),
+            partition_fields: vec![],
+            file_cache: VortexFileCache::new(1, 1, vx_session),
+            logical_schema: table_schema,
+            batch_size: 100,
+            limit: None,
+            metrics: Default::default(),
+            layout_readers: Default::default(),
+            has_output_ordering: false,
+        };
+
+        // The opener should be able to open the file with a filter on the
+        // struct column.
+        let data = opener
+            .open(
+                make_meta(file_path, data_size),
+                PartitionedFile::new(file_path.to_string(), data_size),
+            )?
+            .await?
+            .try_collect::<Vec<_>>()
+            .await?;
+
+        assert_eq!(data.len(), 1);
+        assert_eq!(data[0].num_rows(), 3);
+
+        Ok(())
+    }
 }