From 9b958487cb39a5b5f40f280690ca454194fa61c8 Mon Sep 17 00:00:00 2001 From: David Blajda Date: Fri, 2 Jun 2023 21:49:19 -0400 Subject: [PATCH] Use table schema if it exists when casting batches --- rust/src/operations/write.rs | 21 +++++++-------------- 1 file changed, 7 insertions(+), 14 deletions(-) diff --git a/rust/src/operations/write.rs b/rust/src/operations/write.rs index 2e7399c8d6..3535ee0864 100644 --- a/rust/src/operations/write.rs +++ b/rust/src/operations/write.rs @@ -222,9 +222,7 @@ pub(crate) async fn write_execution_plan( .and_then(|meta| meta.schema.get_invariants().ok()) .unwrap_or_default(); - println!("here"); - //let schema = snapshot.arrow_schema() - // .unwrap_or(plan.schema().clone()); + let schema = snapshot.arrow_schema().unwrap_or(plan.schema()); let checker = DeltaDataChecker::new(invariants); @@ -232,8 +230,7 @@ pub(crate) async fn write_execution_plan( let mut tasks = vec![]; for i in 0..plan.output_partitioning().partition_count() { let inner_plan = plan.clone(); - //let inner_schema = schema.clone(); - let schema = inner_plan.schema().clone(); + let inner_schema = schema.clone(); let task_ctx = Arc::new(TaskContext::from(&state)); let config = WriterConfig::new( inner_plan.schema(), @@ -250,7 +247,7 @@ pub(crate) async fn write_execution_plan( while let Some(maybe_batch) = stream.next().await { let batch = maybe_batch?; checker_stream.check_batch(&batch).await?; - let arr = cast_record_batch(&batch, schema.clone())?; + let arr = cast_record_batch(&batch, inner_schema.clone())?; writer.write(&arr).await?; } writer.close().await @@ -490,9 +487,9 @@ mod tests { use super::*; use crate::operations::DeltaOps; use crate::writer::test_utils::{get_delta_schema, get_record_batch}; - use arrow_array::{StringArray, Int32Array}; use arrow::datatypes::Field; use arrow::datatypes::Schema as ArrowSchema; + use arrow_array::{Int32Array, StringArray}; use serde_json::json; #[tokio::test] @@ -535,7 +532,6 @@ mod tests { assert_eq!(table.get_file_uris().count(), 1) } - #[tokio::test] async fn test_write_different_types() { // Ensure write fails when data of a different type from the table is provided. @@ -547,10 +543,7 @@ mod tests { let batch = RecordBatch::try_new( Arc::clone(&schema), - vec![Arc::new(Int32Array::from(vec![ - Some(0), - None, - ]))], + vec![Arc::new(Int32Array::from(vec![Some(0), None]))], ) .unwrap(); let table = DeltaOps::new_in_memory().write(vec![batch]).await.unwrap(); @@ -567,11 +560,11 @@ mod tests { Some("Test123".to_owned()), None, ]))], - ).unwrap(); + ) + .unwrap(); let res = DeltaOps::from(table).write(vec![batch]).await; assert!(res.is_err()) - } #[tokio::test]