From c9bfeae8edc08bdbd05e1a84b870eccee5192f96 Mon Sep 17 00:00:00 2001
From: Lordworms
Date: Sun, 21 Jul 2024 16:23:45 -0700
Subject: [PATCH 1/4] Implement physical plan serialization for COPY plans
 CsvLogicalExtensionCodec

---
 .../examples/custom_file_format.rs            |   4 +
 .../core/src/datasource/file_format/arrow.rs  |   6 +-
 .../core/src/datasource/file_format/avro.rs   |  11 ++
 .../core/src/datasource/file_format/csv.rs    |  15 +-
 .../core/src/datasource/file_format/json.rs   |  12 ++
 .../core/src/datasource/file_format/mod.rs    |  16 +-
 .../src/datasource/file_format/parquet.rs     |  11 ++
 .../proto/src/logical_plan/file_formats.rs    | 146 ++++++++++++++++--
 datafusion/proto/src/logical_plan/mod.rs      |   7 +-
 .../tests/cases/roundtrip_logical_plan.rs     |  43 ++++--
 10 files changed, 242 insertions(+), 29 deletions(-)

diff --git a/datafusion-examples/examples/custom_file_format.rs b/datafusion-examples/examples/custom_file_format.rs
index bdb702375c94..c3f59e0409f4 100644
--- a/datafusion-examples/examples/custom_file_format.rs
+++ b/datafusion-examples/examples/custom_file_format.rs
@@ -166,6 +166,10 @@ impl FileFormatFactory for TSVFileFactory {
     fn default(&self) -> std::sync::Arc<dyn FileFormat> {
         todo!()
     }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
 }
 
 impl GetExt for TSVFileFactory {
diff --git a/datafusion/core/src/datasource/file_format/arrow.rs b/datafusion/core/src/datasource/file_format/arrow.rs
index 6bcbd4347682..8b6a8800119d 100644
--- a/datafusion/core/src/datasource/file_format/arrow.rs
+++ b/datafusion/core/src/datasource/file_format/arrow.rs
@@ -66,7 +66,7 @@ const INITIAL_BUFFER_BYTES: usize = 1048576;
 /// If the buffered Arrow data exceeds this size, it is flushed to object store
 const BUFFER_FLUSH_BYTES: usize = 1024000;
 
-#[derive(Default)]
+#[derive(Default, Debug)]
 /// Factory struct used to create [ArrowFormat]
 pub struct ArrowFormatFactory;
 
@@ -89,6 +89,10 @@ impl FileFormatFactory for ArrowFormatFactory {
     fn default(&self) -> Arc<dyn FileFormat> {
         Arc::new(ArrowFormat)
     }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
 }
 
 impl GetExt for ArrowFormatFactory {
diff --git a/datafusion/core/src/datasource/file_format/avro.rs b/datafusion/core/src/datasource/file_format/avro.rs
index f4f9adcba7ed..5190bdbe153a 100644
--- a/datafusion/core/src/datasource/file_format/avro.rs
+++ b/datafusion/core/src/datasource/file_format/avro.rs
@@ -19,6 +19,7 @@
 
 use std::any::Any;
 use std::collections::HashMap;
+use std::fmt;
 use std::sync::Arc;
 
 use arrow::datatypes::Schema;
@@ -64,6 +65,16 @@ impl FileFormatFactory for AvroFormatFactory {
     fn default(&self) -> Arc<dyn FileFormat> {
         Arc::new(AvroFormat)
     }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+}
+
+impl fmt::Debug for AvroFormatFactory {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("AvroFormatFactory").finish()
+    }
 }
 
 impl GetExt for AvroFormatFactory {
diff --git a/datafusion/core/src/datasource/file_format/csv.rs b/datafusion/core/src/datasource/file_format/csv.rs
index 185f50883b2c..8a3cfa153606 100644
--- a/datafusion/core/src/datasource/file_format/csv.rs
+++ b/datafusion/core/src/datasource/file_format/csv.rs
@@ -58,7 +58,8 @@ use object_store::{delimited::newline_delimited_stream, ObjectMeta, ObjectStore}
 #[derive(Default)]
 /// Factory struct used to create [CsvFormatFactory]
 pub struct CsvFormatFactory {
-    options: Option<CsvOptions>,
+    /// the options for csv file read
+    pub options: Option<CsvOptions>,
 }
 
 impl CsvFormatFactory {
@@ -75,6 +76,14 @@ impl CsvFormatFactory {
     }
 }
 
+impl fmt::Debug for CsvFormatFactory {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("CsvFormatFactory") + .field("options", &self.options) + .finish() + } +} + impl FileFormatFactory for CsvFormatFactory { fn create( &self, @@ -103,6 +112,10 @@ impl FileFormatFactory for CsvFormatFactory { fn default(&self) -> Arc { Arc::new(CsvFormat::default()) } + + fn as_any(&self) -> &dyn Any { + self + } } impl GetExt for CsvFormatFactory { diff --git a/datafusion/core/src/datasource/file_format/json.rs b/datafusion/core/src/datasource/file_format/json.rs index 007b084f504d..9de9c3d7d871 100644 --- a/datafusion/core/src/datasource/file_format/json.rs +++ b/datafusion/core/src/datasource/file_format/json.rs @@ -102,6 +102,10 @@ impl FileFormatFactory for JsonFormatFactory { fn default(&self) -> Arc { Arc::new(JsonFormat::default()) } + + fn as_any(&self) -> &dyn Any { + self + } } impl GetExt for JsonFormatFactory { @@ -111,6 +115,14 @@ impl GetExt for JsonFormatFactory { } } +impl fmt::Debug for JsonFormatFactory { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("JsonFormatFactory") + .field("options", &self.options) + .finish() + } +} + /// New line delimited JSON `FileFormat` implementation. #[derive(Debug, Default)] pub struct JsonFormat { diff --git a/datafusion/core/src/datasource/file_format/mod.rs b/datafusion/core/src/datasource/file_format/mod.rs index 1aa93a106aff..da2b0b56fc7e 100644 --- a/datafusion/core/src/datasource/file_format/mod.rs +++ b/datafusion/core/src/datasource/file_format/mod.rs @@ -49,11 +49,11 @@ use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement}; use async_trait::async_trait; use file_compression_type::FileCompressionType; use object_store::{ObjectMeta, ObjectStore}; - +use std::fmt::Debug; /// Factory for creating [`FileFormat`] instances based on session and command level options /// /// Users can provide their own `FileFormatFactory` to support arbitrary file formats -pub trait FileFormatFactory: Sync + Send + GetExt { +pub trait FileFormatFactory: Sync + Send + GetExt + Debug { /// Initialize a [FileFormat] and configure based on session and command level options fn create( &self, @@ -63,6 +63,10 @@ pub trait FileFormatFactory: Sync + Send + GetExt { /// Initialize a [FileFormat] with all options set to default values fn default(&self) -> Arc; + + /// Returns the table source as [`Any`] so that it can be + /// downcast to a specific implementation. + fn as_any(&self) -> &dyn Any; } /// This trait abstracts all the file format specific implementations @@ -138,6 +142,7 @@ pub trait FileFormat: Send + Sync + fmt::Debug { /// The former trait is a superset of the latter trait, which includes execution time /// relevant methods. [FileType] is only used in logical planning and only implements /// the subset of methods required during logical planning. 
+#[derive(Debug)]
 pub struct DefaultFileType {
     file_format_factory: Arc<dyn FileFormatFactory>,
 }
@@ -149,6 +154,11 @@ impl DefaultFileType {
             file_format_factory,
         }
     }
+
+    /// get a [FileFormatFactory] struct
+    pub fn get_format_factory(&self) -> Arc<dyn FileFormatFactory> {
+        self.file_format_factory.clone()
+    }
 }
 
 impl FileType for DefaultFileType {
@@ -159,7 +169,7 @@ impl FileType for DefaultFileType {
 
 impl Display for DefaultFileType {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        self.file_format_factory.default().fmt(f)
+        write!(f, "{:?}", self.file_format_factory)
     }
 }
 
diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs
index d4e77b911c9f..3250b59fa1d1 100644
--- a/datafusion/core/src/datasource/file_format/parquet.rs
+++ b/datafusion/core/src/datasource/file_format/parquet.rs
@@ -140,6 +140,10 @@ impl FileFormatFactory for ParquetFormatFactory {
     fn default(&self) -> Arc<dyn FileFormat> {
         Arc::new(ParquetFormat::default())
     }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
 }
 
 impl GetExt for ParquetFormatFactory {
@@ -149,6 +153,13 @@ impl GetExt for ParquetFormatFactory {
     }
 }
 
+impl fmt::Debug for ParquetFormatFactory {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("ParquetFormatFactory")
+            .field("ParquetFormatFactory", &self.options)
+            .finish()
+    }
+}
 /// The Apache Parquet `FileFormat` implementation
 #[derive(Debug, Default)]
 pub struct ParquetFormat {
diff --git a/datafusion/proto/src/logical_plan/file_formats.rs b/datafusion/proto/src/logical_plan/file_formats.rs
index 09e36a650b9f..69028a6a9592 100644
--- a/datafusion/proto/src/logical_plan/file_formats.rs
+++ b/datafusion/proto/src/logical_plan/file_formats.rs
@@ -18,19 +18,121 @@
 use std::sync::Arc;
 
 use datafusion::{
+    config::CsvOptions,
     datasource::file_format::{
         arrow::ArrowFormatFactory, csv::CsvFormatFactory, json::JsonFormatFactory,
         parquet::ParquetFormatFactory, FileFormatFactory,
     },
     prelude::SessionContext,
 };
-use datafusion_common::{not_impl_err, TableReference};
+use datafusion_common::{
+    exec_err, not_impl_err, parsers::CompressionTypeVariant, DataFusionError,
+    TableReference,
+};
+use prost::Message;
+
+use crate::protobuf::CsvOptions as CsvOptionsProto;
 
 use super::LogicalExtensionCodec;
 
 #[derive(Debug)]
 pub struct CsvLogicalExtensionCodec;
 
+impl CsvOptionsProto {
+    fn from_factory(factory: &CsvFormatFactory) -> Self {
+        if let Some(options) = &factory.options {
+            CsvOptionsProto {
+                has_header: options.has_header.map_or(vec![], |v| vec![v as u8]),
+                delimiter: vec![options.delimiter],
+                quote: vec![options.quote],
+                escape: options.escape.map_or(vec![], |v| vec![v]),
+                double_quote: options.double_quote.map_or(vec![], |v| vec![v as u8]),
+                compression: options.compression as i32,
+                schema_infer_max_rec: options.schema_infer_max_rec as u64,
+                date_format: options.date_format.clone().unwrap_or_default(),
+                datetime_format: options.datetime_format.clone().unwrap_or_default(),
+                timestamp_format: options.timestamp_format.clone().unwrap_or_default(),
+                timestamp_tz_format: options
+                    .timestamp_tz_format
+                    .clone()
+                    .unwrap_or_default(),
+                time_format: options.time_format.clone().unwrap_or_default(),
+                null_value: options.null_value.clone().unwrap_or_default(),
+                comment: options.comment.map_or(vec![], |v| vec![v]),
+            }
+        } else {
+            CsvOptionsProto::default()
+        }
+    }
+}
+
+impl From<&CsvOptionsProto> for CsvOptions {
+    fn from(proto: &CsvOptionsProto) -> Self {
+        CsvOptions {
+            has_header: if !proto.has_header.is_empty() {
+                Some(proto.has_header[0] != 0)
+            } else {
+                None
+            },
+            delimiter: proto.delimiter.first().copied().unwrap_or(b','),
+            quote: proto.quote.first().copied().unwrap_or(b'"'),
+            escape: if !proto.escape.is_empty() {
+                Some(proto.escape[0])
+            } else {
+                None
+            },
+            double_quote: if !proto.double_quote.is_empty() {
+                Some(proto.double_quote[0] != 0)
+            } else {
+                None
+            },
+            compression: match proto.compression {
+                0 => CompressionTypeVariant::GZIP,
+                1 => CompressionTypeVariant::BZIP2,
+                2 => CompressionTypeVariant::XZ,
+                3 => CompressionTypeVariant::ZSTD,
+                _ => CompressionTypeVariant::UNCOMPRESSED,
+            },
+            schema_infer_max_rec: proto.schema_infer_max_rec as usize,
+            date_format: if proto.date_format.is_empty() {
+                None
+            } else {
+                Some(proto.date_format.clone())
+            },
+            datetime_format: if proto.datetime_format.is_empty() {
+                None
+            } else {
+                Some(proto.datetime_format.clone())
+            },
+            timestamp_format: if proto.timestamp_format.is_empty() {
+                None
+            } else {
+                Some(proto.timestamp_format.clone())
+            },
+            timestamp_tz_format: if proto.timestamp_tz_format.is_empty() {
+                None
+            } else {
+                Some(proto.timestamp_tz_format.clone())
+            },
+            time_format: if proto.time_format.is_empty() {
+                None
+            } else {
+                Some(proto.time_format.clone())
+            },
+            null_value: if proto.null_value.is_empty() {
+                None
+            } else {
+                Some(proto.null_value.clone())
+            },
+            comment: if !proto.comment.is_empty() {
+                Some(proto.comment[0])
+            } else {
+                None
+            },
+        }
+    }
+}
+
 // TODO! This is a placeholder for now and needs to be implemented for real.
 impl LogicalExtensionCodec for CsvLogicalExtensionCodec {
     fn try_decode(
@@ -73,17 +175,41 @@ impl LogicalExtensionCodec for CsvLogicalExtensionCodec {
 
     fn try_decode_file_format(
         &self,
-        __buf: &[u8],
-        __ctx: &SessionContext,
+        buf: &[u8],
+        _ctx: &SessionContext,
     ) -> datafusion_common::Result<Arc<dyn FileFormatFactory>> {
-        Ok(Arc::new(CsvFormatFactory::new()))
+        let proto = CsvOptionsProto::decode(buf).map_err(|e| {
+            DataFusionError::Execution(format!(
+                "Failed to decode CsvOptionsProto: {:?}",
+                e
+            ))
+        })?;
+        let options: CsvOptions = (&proto).into();
+        Ok(Arc::new(CsvFormatFactory {
+            options: Some(options),
+        }))
     }
 
     fn try_encode_file_format(
         &self,
-        __buf: &[u8],
-        __node: Arc<dyn FileFormatFactory>,
+        buf: &mut Vec<u8>,
+        node: Arc<dyn FileFormatFactory>,
     ) -> datafusion_common::Result<()> {
+        let options =
+            if let Some(csv_factory) = node.as_any().downcast_ref::<CsvFormatFactory>() {
+                csv_factory.options.clone().unwrap_or_default()
+            } else {
+                return exec_err!("{}", "Unsupported FileFormatFactory type".to_string());
+            };
+
+        let proto = CsvOptionsProto::from_factory(&CsvFormatFactory {
+            options: Some(options),
+        });
+
+        proto.encode(buf).map_err(|e| {
+            DataFusionError::Execution(format!("Failed to encode CsvOptions: {:?}", e))
+        })?;
+
         Ok(())
     }
 }
@@ -141,7 +267,7 @@ impl LogicalExtensionCodec for JsonLogicalExtensionCodec {
     fn try_encode_file_format(
         &self,
-        __buf: &[u8],
+        __buf: &mut Vec<u8>,
         __node: Arc<dyn FileFormatFactory>,
     ) -> datafusion_common::Result<()> {
         Ok(())
     }
@@ -201,7 +327,7 @@ impl LogicalExtensionCodec for ParquetLogicalExtensionCodec {
     fn try_encode_file_format(
         &self,
-        __buf: &[u8],
+        __buf: &mut Vec<u8>,
         __node: Arc<dyn FileFormatFactory>,
     ) -> datafusion_common::Result<()> {
         Ok(())
     }
@@ -261,7 +387,7 @@ impl LogicalExtensionCodec for ArrowLogicalExtensionCodec {
     fn try_encode_file_format(
         &self,
-        __buf: &[u8],
+        __buf: &mut Vec<u8>,
         __node: Arc<dyn FileFormatFactory>,
     ) -> datafusion_common::Result<()> {
         Ok(())
     }
@@ -321,7 +447,7 @@ impl LogicalExtensionCodec for AvroLogicalExtensionCodec {
     fn try_encode_file_format(
         &self,
-        __buf: &[u8],
+        __buf: &mut Vec<u8>,
         __node: Arc<dyn FileFormatFactory>,
     ) -> datafusion_common::Result<()> {
         Ok(())
     }
diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs
index 2a963fb13ccf..5427f34e8e07 100644
--- a/datafusion/proto/src/logical_plan/mod.rs
+++ b/datafusion/proto/src/logical_plan/mod.rs
@@ -131,7 +131,7 @@ pub trait LogicalExtensionCodec: Debug + Send + Sync {
     fn try_encode_file_format(
         &self,
-        _buf: &[u8],
+        _buf: &mut Vec<u8>,
         _node: Arc<dyn FileFormatFactory>,
     ) -> Result<()> {
         Ok(())
     }
@@ -1666,10 +1666,9 @@ impl AsLogicalPlan for LogicalPlanNode {
                     input,
                     extension_codec,
                 )?;
-
-                let buf = Vec::new();
+                let mut buf = Vec::new();
                 extension_codec
-                    .try_encode_file_format(&buf, file_type_to_format(file_type)?)?;
+                    .try_encode_file_format(&mut buf, file_type_to_format(file_type)?)?;
 
                 Ok(protobuf::LogicalPlanNode {
                     logical_plan_type: Some(LogicalPlanType::CopyTo(Box::new(
diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
index 3476d5d042cc..42a8fc380713 100644
--- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
@@ -15,12 +15,6 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use std::any::Any;
-use std::collections::HashMap;
-use std::fmt::{self, Debug, Formatter};
-use std::sync::Arc;
-use std::vec;
-
 use arrow::array::{
     ArrayRef, FixedSizeListArray, Int32Builder, MapArray, MapBuilder, StringBuilder,
 };
@@ -29,11 +23,16 @@ use arrow::datatypes::{
     IntervalUnit, Schema, SchemaRef, TimeUnit, UnionFields, UnionMode,
 };
 use prost::Message;
+use std::any::Any;
+use std::collections::HashMap;
+use std::fmt::{self, Debug, Formatter};
+use std::sync::Arc;
+use std::vec;
 
 use datafusion::datasource::file_format::arrow::ArrowFormatFactory;
 use datafusion::datasource::file_format::csv::CsvFormatFactory;
-use datafusion::datasource::file_format::format_as_file_type;
 use datafusion::datasource::file_format::parquet::ParquetFormatFactory;
+use datafusion::datasource::file_format::{format_as_file_type, DefaultFileType};
 use datafusion::datasource::provider::TableProviderFactory;
 use datafusion::datasource::TableProvider;
 use datafusion::execution::session_state::SessionStateBuilder;
@@ -379,7 +378,9 @@ async fn roundtrip_logical_plan_copy_to_writer_options() -> Result<()> {
     parquet_format.global.dictionary_page_size_limit = 444;
     parquet_format.global.max_row_group_size = 555;
 
-    let file_type = format_as_file_type(Arc::new(ParquetFormatFactory::new()));
+    let file_type = format_as_file_type(Arc::new(
+        ParquetFormatFactory::new_with_options(parquet_format),
+    ));
 
     let plan = LogicalPlan::Copy(CopyTo {
         input: Arc::new(input),
@@ -394,7 +395,6 @@ async fn roundtrip_logical_plan_copy_to_writer_options() -> Result<()> {
     let logical_round_trip =
         logical_plan_from_bytes_with_extension_codec(&bytes, &ctx, &codec)?;
     assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}"));
-
     match logical_round_trip {
         LogicalPlan::Copy(copy_to) => {
             assert_eq!("test.parquet", copy_to.output_url);
@@ -457,7 +457,9 @@ async fn roundtrip_logical_plan_copy_to_csv() -> Result<()> {
     csv_format.time_format = Some("HH:mm:ss".to_string());
     csv_format.null_value = Some("NIL".to_string());
 
-    let file_type = format_as_file_type(Arc::new(CsvFormatFactory::new()));
+    let file_type = format_as_file_type(Arc::new(CsvFormatFactory::new_with_options(
+        csv_format.clone(),
+    )));
 
     let plan = LogicalPlan::Copy(CopyTo {
         input: Arc::new(input),
@@ -478,6 +480,27 @@
             assert_eq!("test.csv", copy_to.output_url);
             assert_eq!("csv".to_string(), copy_to.file_type.get_ext());
             assert_eq!(vec!["a", "b", "c"], copy_to.partition_by);
+
+            let file_type = copy_to
+                .file_type
+                .as_ref()
+                .as_any()
+                .downcast_ref::<DefaultFileType>()
+                .unwrap();
+
+            let format_factory = file_type.get_format_factory();
+            let csv_factory = format_factory
+                .as_ref()
+                .as_any()
+                .downcast_ref::<CsvFormatFactory>()
+                .unwrap();
+            let csv_config = csv_factory.options.as_ref().unwrap();
+            assert_eq!(csv_format.delimiter, csv_config.delimiter);
+            assert_eq!(csv_format.date_format, csv_config.date_format);
+            assert_eq!(csv_format.datetime_format, csv_config.datetime_format);
+            assert_eq!(csv_format.timestamp_format, csv_config.timestamp_format);
+            assert_eq!(csv_format.time_format, csv_config.time_format);
+            assert_eq!(csv_format.null_value, csv_config.null_value)
         }
         _ => panic!(),
     }

From 0008fd725eb20460c38adec84bae901b3cb26d09 Mon Sep 17 00:00:00 2001
From: Lordworms
Date: Sun, 21 Jul 2024 16:27:44 -0700
Subject: [PATCH 2/4] fix check

---
 datafusion-examples/examples/custom_file_format.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/datafusion-examples/examples/custom_file_format.rs b/datafusion-examples/examples/custom_file_format.rs
index c3f59e0409f4..8612a1cc4430 100644
--- a/datafusion-examples/examples/custom_file_format.rs
+++ b/datafusion-examples/examples/custom_file_format.rs
@@ -131,7 +131,7 @@ impl FileFormat for TSVFileFormat {
     }
 }
 
-#[derive(Default)]
+#[derive(Default, Debug)]
 /// Factory for creating TSV file formats
 ///
 /// This factory is a wrapper around the CSV file format factory

From 4b65293d4578501dcdfb11b67f34337dfdf2aa12 Mon Sep 17 00:00:00 2001
From: Lordworms
Date: Mon, 22 Jul 2024 16:41:54 -0700
Subject: [PATCH 3/4] optimize code

---
 datafusion/core/src/datasource/file_format/mod.rs      | 6 +++---
 datafusion/proto/tests/cases/roundtrip_logical_plan.rs | 2 +-
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/datafusion/core/src/datasource/file_format/mod.rs b/datafusion/core/src/datasource/file_format/mod.rs
index da2b0b56fc7e..500f20af474f 100644
--- a/datafusion/core/src/datasource/file_format/mod.rs
+++ b/datafusion/core/src/datasource/file_format/mod.rs
@@ -155,9 +155,9 @@ impl DefaultFileType {
         }
     }
 
-    /// get a [FileFormatFactory] struct
-    pub fn get_format_factory(&self) -> Arc<dyn FileFormatFactory> {
-        self.file_format_factory.clone()
+    /// get a reference to the inner [FileFormatFactory] struct
+    pub fn as_format_factory(&self) -> &Arc<dyn FileFormatFactory> {
+        &self.file_format_factory
     }
 }
 
diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
index 42a8fc380713..f722d4cd86f8 100644
--- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
@@ -488,7 +488,7 @@ async fn roundtrip_logical_plan_copy_to_csv() -> Result<()> {
                 .downcast_ref::<DefaultFileType>()
                 .unwrap();
 
-            let format_factory = file_type.get_format_factory();
+            let format_factory = file_type.as_format_factory();
             let csv_factory = format_factory
                 .as_ref()
                 .as_any()
                 .downcast_ref::<CsvFormatFactory>()
                 .unwrap();

From 76e9676585266a89e795fa7c08a7401667bbb59d Mon Sep 17 00:00:00 2001
From: Lordworms
Date: Mon, 22 Jul 2024 16:55:18 -0700
Subject: [PATCH 4/4] optimize code

---
 datafusion/proto/src/logical_plan/file_formats.rs | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/datafusion/proto/src/logical_plan/file_formats.rs b/datafusion/proto/src/logical_plan/file_formats.rs
index 69028a6a9592..2c4085b88869 100644
--- a/datafusion/proto/src/logical_plan/file_formats.rs
+++ b/datafusion/proto/src/logical_plan/file_formats.rs
@@ -59,6 +59,9 @@ impl CsvOptionsProto {
                 time_format: options.time_format.clone().unwrap_or_default(),
                 null_value: options.null_value.clone().unwrap_or_default(),
                 comment: options.comment.map_or(vec![], |v| vec![v]),
+                newlines_in_values: options
+                    .newlines_in_values
+                    .map_or(vec![], |v| vec![v as u8]),
             }
         } else {
             CsvOptionsProto::default()
@@ -129,6 +132,11 @@ impl From<&CsvOptionsProto> for CsvOptions {
             } else {
                 None
             },
+            newlines_in_values: if proto.newlines_in_values.is_empty() {
+                None
+            } else {
+                Some(proto.newlines_in_values[0] != 0)
+            },
         }
     }
 }
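
Taken together, the four patches let CsvLogicalExtensionCodec carry the CsvOptions of a COPY ... TO plan across serialization by packing them into the new CsvOptionsProto message. The code below is a minimal usage sketch, not part of the patch series: it assumes the post-patch API shown above (CsvFormatFactory::new_with_options, the now-public options field, the as_any method, and the &mut Vec<u8> encode signature) and assumes the codec is reachable at the path datafusion_proto::logical_plan::file_formats::CsvLogicalExtensionCodec.

use std::sync::Arc;

use datafusion::config::CsvOptions;
use datafusion::datasource::file_format::csv::CsvFormatFactory;
use datafusion::datasource::file_format::FileFormatFactory;
use datafusion::error::Result;
use datafusion::prelude::SessionContext;
use datafusion_proto::logical_plan::file_formats::CsvLogicalExtensionCodec;
use datafusion_proto::logical_plan::LogicalExtensionCodec;

fn main() -> Result<()> {
    let ctx = SessionContext::new();
    let codec = CsvLogicalExtensionCodec;

    // CSV options that should survive the proto round trip.
    let mut options = CsvOptions::default();
    options.delimiter = b'|';
    options.null_value = Some("NIL".to_string());

    // Encode the factory's options into protobuf bytes, as the planner now
    // does when it serializes the CopyTo node's file type.
    let factory = Arc::new(CsvFormatFactory::new_with_options(options));
    let mut buf = Vec::new();
    codec.try_encode_file_format(&mut buf, factory)?;

    // Decode the bytes back into a CsvFormatFactory and inspect its options.
    let decoded = codec.try_decode_file_format(&buf, &ctx)?;
    let csv_factory = decoded
        .as_ref()
        .as_any()
        .downcast_ref::<CsvFormatFactory>()
        .expect("decoded factory should be a CsvFormatFactory");
    let decoded_options = csv_factory.options.as_ref().expect("options should be set");

    assert_eq!(b'|', decoded_options.delimiter);
    assert_eq!(Some("NIL".to_string()), decoded_options.null_value);
    Ok(())
}

Under those assumptions the sketch mirrors what the updated LogicalPlanNode serialization does for CopyTo: try_encode_file_format packs the factory's CsvOptions into a CsvOptionsProto buffer, and try_decode_file_format rebuilds a CsvFormatFactory carrying equivalent options, with unset proto fields falling back to defaults such as the ',' delimiter.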