Skip to content

Commit

Permalink
Implement physical plan serialization for json Copy plans
Browse files Browse the repository at this point in the history
  • Loading branch information
Lordworms committed Jul 25, 2024
1 parent 20b298e commit 5ffda27
Show file tree
Hide file tree
Showing 12 changed files with 296 additions and 13 deletions.
3 changes: 2 additions & 1 deletion datafusion/core/src/datasource/file_format/json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ use object_store::{GetResultPayload, ObjectMeta, ObjectStore};
#[derive(Default)]
/// Factory struct used to create [JsonFormat]
pub struct JsonFormatFactory {
options: Option<JsonOptions>,
/// the options carried by format factory
pub options: Option<JsonOptions>,
}

impl JsonFormatFactory {
Expand Down
5 changes: 5 additions & 0 deletions datafusion/proto-common/proto/datafusion_common.proto
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ message ParquetFormat {

message AvroFormat {}

message NdJsonFormat {
JsonOptions options = 1;
}


message PrimaryKeyConstraint{
repeated uint64 indices = 1;
}
Expand Down
91 changes: 91 additions & 0 deletions datafusion/proto-common/src/generated/pbjson.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4642,6 +4642,97 @@ impl<'de> serde::Deserialize<'de> for Map {
deserializer.deserialize_struct("datafusion_common.Map", FIELDS, GeneratedVisitor)
}
}
impl serde::Serialize for NdJsonFormat {
#[allow(deprecated)]
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
use serde::ser::SerializeStruct;
let mut len = 0;
if self.options.is_some() {
len += 1;
}
let mut struct_ser = serializer.serialize_struct("datafusion_common.NdJsonFormat", len)?;
if let Some(v) = self.options.as_ref() {
struct_ser.serialize_field("options", v)?;
}
struct_ser.end()
}
}
impl<'de> serde::Deserialize<'de> for NdJsonFormat {
#[allow(deprecated)]
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
const FIELDS: &[&str] = &[
"options",
];

#[allow(clippy::enum_variant_names)]
enum GeneratedField {
Options,
}
impl<'de> serde::Deserialize<'de> for GeneratedField {
fn deserialize<D>(deserializer: D) -> std::result::Result<GeneratedField, D::Error>
where
D: serde::Deserializer<'de>,
{
struct GeneratedVisitor;

impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
type Value = GeneratedField;

fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(formatter, "expected one of: {:?}", &FIELDS)
}

#[allow(unused_variables)]
fn visit_str<E>(self, value: &str) -> std::result::Result<GeneratedField, E>
where
E: serde::de::Error,
{
match value {
"options" => Ok(GeneratedField::Options),
_ => Err(serde::de::Error::unknown_field(value, FIELDS)),
}
}
}
deserializer.deserialize_identifier(GeneratedVisitor)
}
}
struct GeneratedVisitor;
impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
type Value = NdJsonFormat;

fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
formatter.write_str("struct datafusion_common.NdJsonFormat")
}

fn visit_map<V>(self, mut map_: V) -> std::result::Result<NdJsonFormat, V::Error>
where
V: serde::de::MapAccess<'de>,
{
let mut options__ = None;
while let Some(k) = map_.next_key()? {
match k {
GeneratedField::Options => {
if options__.is_some() {
return Err(serde::de::Error::duplicate_field("options"));
}
options__ = map_.next_value()?;
}
}
}
Ok(NdJsonFormat {
options: options__,
})
}
}
deserializer.deserialize_struct("datafusion_common.NdJsonFormat", FIELDS, GeneratedVisitor)
}
}
impl serde::Serialize for ParquetFormat {
#[allow(deprecated)]
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
Expand Down
6 changes: 6 additions & 0 deletions datafusion/proto-common/src/generated/prost.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ pub struct ParquetFormat {
pub struct AvroFormat {}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct NdJsonFormat {
#[prost(message, optional, tag = "1")]
pub options: ::core::option::Option<JsonOptions>,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct PrimaryKeyConstraint {
#[prost(uint64, repeated, tag = "1")]
pub indices: ::prost::alloc::vec::Vec<u64>,
Expand Down
1 change: 1 addition & 0 deletions datafusion/proto/proto/datafusion.proto
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ message ListingTableScanNode {
datafusion_common.CsvFormat csv = 10;
datafusion_common.ParquetFormat parquet = 11;
datafusion_common.AvroFormat avro = 12;
datafusion_common.NdJsonFormat json = 15;
}
repeated LogicalExprNodeCollection file_sort_order = 13;
}
Expand Down
6 changes: 6 additions & 0 deletions datafusion/proto/src/generated/datafusion_proto_common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ pub struct ParquetFormat {
pub struct AvroFormat {}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct NdJsonFormat {
#[prost(message, optional, tag = "1")]
pub options: ::core::option::Option<JsonOptions>,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct PrimaryKeyConstraint {
#[prost(uint64, repeated, tag = "1")]
pub indices: ::prost::alloc::vec::Vec<u64>,
Expand Down
13 changes: 13 additions & 0 deletions datafusion/proto/src/generated/pbjson.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion datafusion/proto/src/generated/prost.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion datafusion/proto/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ pub mod protobuf {
pub use datafusion_proto_common::common::proto_error;
pub use datafusion_proto_common::protobuf_common::{
ArrowOptions, ArrowType, AvroFormat, AvroOptions, CsvFormat, DfSchema,
EmptyMessage, Field, JoinSide, ParquetFormat, ScalarValue, Schema,
EmptyMessage, Field, JoinSide, NdJsonFormat, ParquetFormat, ScalarValue, Schema,
};
pub use datafusion_proto_common::{FromProtoError, ToProtoError};
}
Expand Down
69 changes: 62 additions & 7 deletions datafusion/proto/src/logical_plan/file_formats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
use std::sync::Arc;

use datafusion::{
config::CsvOptions,
config::{CsvOptions, JsonOptions},
datasource::file_format::{
arrow::ArrowFormatFactory, csv::CsvFormatFactory, json::JsonFormatFactory,
parquet::ParquetFormatFactory, FileFormatFactory,
Expand All @@ -31,7 +31,7 @@ use datafusion_common::{
};
use prost::Message;

use crate::protobuf::CsvOptions as CsvOptionsProto;
use crate::protobuf::{CsvOptions as CsvOptionsProto, JsonOptions as JsonOptionsProto};

use super::LogicalExtensionCodec;

Expand Down Expand Up @@ -222,6 +222,34 @@ impl LogicalExtensionCodec for CsvLogicalExtensionCodec {
}
}

impl JsonOptionsProto {
fn from_factory(factory: &JsonFormatFactory) -> Self {
if let Some(options) = &factory.options {
JsonOptionsProto {
compression: options.compression as i32,
schema_infer_max_rec: options.schema_infer_max_rec as u64,
}
} else {
JsonOptionsProto::default()
}
}
}

impl From<&JsonOptionsProto> for JsonOptions {
fn from(proto: &JsonOptionsProto) -> Self {
JsonOptions {
compression: match proto.compression {
0 => CompressionTypeVariant::GZIP,
1 => CompressionTypeVariant::BZIP2,
2 => CompressionTypeVariant::XZ,
3 => CompressionTypeVariant::ZSTD,
_ => CompressionTypeVariant::UNCOMPRESSED,
},
schema_infer_max_rec: proto.schema_infer_max_rec as usize,
}
}
}

#[derive(Debug)]
pub struct JsonLogicalExtensionCodec;

Expand Down Expand Up @@ -267,17 +295,44 @@ impl LogicalExtensionCodec for JsonLogicalExtensionCodec {

fn try_decode_file_format(
&self,
__buf: &[u8],
__ctx: &SessionContext,
buf: &[u8],
_ctx: &SessionContext,
) -> datafusion_common::Result<Arc<dyn FileFormatFactory>> {
Ok(Arc::new(JsonFormatFactory::new()))
let proto = JsonOptionsProto::decode(buf).map_err(|e| {
DataFusionError::Execution(format!(
"Failed to decode JsonOptionsProto: {:?}",
e
))
})?;
let options: JsonOptions = (&proto).into();
Ok(Arc::new(JsonFormatFactory {
options: Some(options),
}))
}

fn try_encode_file_format(
&self,
__buf: &mut Vec<u8>,
__node: Arc<dyn FileFormatFactory>,
buf: &mut Vec<u8>,
node: Arc<dyn FileFormatFactory>,
) -> datafusion_common::Result<()> {
let options = if let Some(json_factory) =
node.as_any().downcast_ref::<JsonFormatFactory>()
{
json_factory.options.clone().unwrap_or_default()
} else {
return Err(DataFusionError::Execution(
"Unsupported FileFormatFactory type".to_string(),
));
};

let proto = JsonOptionsProto::from_factory(&JsonFormatFactory {
options: Some(options),
});

proto.encode(buf).map_err(|e| {
DataFusionError::Execution(format!("Failed to encode JsonOptions: {:?}", e))
})?;

Ok(())
}
}
Expand Down
25 changes: 23 additions & 2 deletions datafusion/proto/src/logical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,10 @@ use datafusion::datasource::file_format::{
};
use datafusion::{
datasource::{
file_format::{avro::AvroFormat, csv::CsvFormat, FileFormat},
file_format::{
avro::AvroFormat, csv::CsvFormat, json::JsonFormat as OtherNdJsonFormat,
FileFormat,
},
listing::{ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl},
view::ViewTable,
TableProvider,
Expand Down Expand Up @@ -395,7 +398,17 @@ impl AsLogicalPlan for LogicalPlanNode {
if let Some(options) = options {
csv = csv.with_options(options.try_into()?)
}
Arc::new(csv)},
Arc::new(csv)
},
FileFormatType::Json(protobuf::NdJsonFormat {
options
}) => {
let mut json = OtherNdJsonFormat::default();
if let Some(options) = options {
json = json.with_options(options.try_into()?)
}
Arc::new(json)
}
FileFormatType::Avro(..) => Arc::new(AvroFormat),
};

Expand Down Expand Up @@ -996,6 +1009,14 @@ impl AsLogicalPlan for LogicalPlanNode {
}));
}

if let Some(json) = any.downcast_ref::<OtherNdJsonFormat>() {
let options = json.options();
maybe_some_type =
Some(FileFormatType::Json(protobuf::NdJsonFormat {
options: Some(options.try_into()?),
}))
}

if any.is::<AvroFormat>() {
maybe_some_type =
Some(FileFormatType::Avro(protobuf::AvroFormat {}))
Expand Down
Loading

0 comments on commit 5ffda27

Please sign in to comment.