diff --git a/datafusion/proto/Cargo.toml b/datafusion/proto/Cargo.toml index 81481a6ef5e6..c24c23f4f54f 100644 --- a/datafusion/proto/Cargo.toml +++ b/datafusion/proto/Cargo.toml @@ -24,7 +24,7 @@ repository = "https://github.com/apache/arrow-datafusion" readme = "README.md" authors = ["Apache Arrow <dev@arrow.apache.org>"] license = "Apache-2.0" -keywords = [ "arrow", "query", "sql" ] +keywords = ["arrow", "query", "sql"] edition = "2021" rust-version = "1.58" @@ -33,6 +33,8 @@ name = "datafusion_proto" path = "src/lib.rs" [features] +default = [] +serde = ["pbjson", "pbjson-build", "dep:serde"] [dependencies] arrow = { version = "18.0.0" } @@ -40,11 +42,15 @@ datafusion = { path = "../core", version = "10.0.0" } datafusion-common = { path = "../common", version = "10.0.0" } datafusion-expr = { path = "../expr", version = "10.0.0" } prost = "0.10" - +serde = { version = "1.0", optional = true } +pbjson = { version = "0.3", optional = true } +pbjson-types = { version = "0.3", optional = true } [dev-dependencies] doc-comment = "0.3" tokio = "1.18" +serde_json = "1.0" [build-dependencies] tonic-build = { version = "0.7" } +pbjson-build = { version = "0.3", optional = true } diff --git a/datafusion/proto/build.rs b/datafusion/proto/build.rs index 414593eeef4a..1e613f2ef9c0 100644 --- a/datafusion/proto/build.rs +++ b/datafusion/proto/build.rs @@ -15,11 +15,43 @@ // specific language governing permissions and limitations // under the License. 
+type Error = Box<dyn std::error::Error>; +type Result<T, E = Error> = std::result::Result<T, E>; + fn main() -> Result<(), String> { // for use in docker build where file changes can be wonky println!("cargo:rerun-if-env-changed=FORCE_REBUILD"); - println!("cargo:rerun-if-changed=proto/datafusion.proto"); + + build()?; + + Ok(()) +} + +#[cfg(feature = "serde")] +fn build() -> Result<(), String> { + let descriptor_path = std::path::PathBuf::from(std::env::var("OUT_DIR").unwrap()) + .join("proto_descriptor.bin"); + + tonic_build::configure() + .file_descriptor_set_path(&descriptor_path) + .compile_well_known_types(true) + .extern_path(".google.protobuf", "::pbjson_types") + .compile(&["proto/datafusion.proto"], &["proto"]) + .map_err(|e| format!("protobuf compilation failed: {}", e))?; + + let descriptor_set = std::fs::read(descriptor_path).unwrap(); + pbjson_build::Builder::new() + .register_descriptors(&descriptor_set) + .unwrap() + .build(&[".datafusion"]) + .map_err(|e| format!("pbjson compilation failed: {}", e))?; + + Ok(()) +} + +#[cfg(not(feature = "serde"))] +fn build() -> Result<(), String> { tonic_build::configure() .compile(&["proto/datafusion.proto"], &["proto"]) .map_err(|e| format!("protobuf compilation failed: {}", e)) diff --git a/datafusion/proto/src/lib.rs b/datafusion/proto/src/lib.rs index 0aa00bc75051..b50dab611080 100644 --- a/datafusion/proto/src/lib.rs +++ b/datafusion/proto/src/lib.rs @@ -23,6 +23,9 @@ use datafusion_common::DataFusionError; #[allow(clippy::all)] pub mod protobuf { include!(concat!(env!("OUT_DIR"), "/datafusion.rs")); + + #[cfg(feature = "serde")] + include!(concat!(env!("OUT_DIR"), "/datafusion.serde.rs")); } pub mod bytes; @@ -75,19 +78,32 @@ mod roundtrip_tests { use std::fmt::Formatter; use std::sync::Arc; + #[cfg(feature = "serde")] + fn roundtrip_serde_test(proto: &protobuf::LogicalExprNode) { + let string = serde_json::to_string(proto).unwrap(); + let back: protobuf::LogicalExprNode = serde_json::from_str(&string).unwrap(); + assert_eq!(proto, &back); + } + + 
#[cfg(not(feature = "serde"))] + fn roundtrip_serde_test(_proto: &protobuf::LogicalExprNode) {} + // Given a DataFusion logical Expr, convert it to protobuf and back, using debug formatting to test // equality. - macro_rules! roundtrip_expr_test { - ($initial_struct:ident, $ctx:ident) => { - let proto: protobuf::LogicalExprNode = (&$initial_struct).try_into().unwrap(); + fn roundtrip_expr_test<T, E>(initial_struct: T, ctx: SessionContext) + where + for<'a> &'a T: TryInto<protobuf::LogicalExprNode, Error = E> + Debug, + E: Debug, + { + let proto: protobuf::LogicalExprNode = (&initial_struct).try_into().unwrap(); + let round_trip: Expr = parse_expr(&proto, &ctx).unwrap(); - let round_trip: Expr = parse_expr(&proto, &$ctx).unwrap(); + assert_eq!( + format!("{:?}", &initial_struct), + format!("{:?}", round_trip) + ); - assert_eq!( - format!("{:?}", $initial_struct), - format!("{:?}", round_trip) - ); - }; + roundtrip_serde_test(&proto); } fn new_box_field(name: &str, dt: DataType, nullable: bool) -> Box<Field> { @@ -807,7 +823,7 @@ mod roundtrip_tests { let test_expr = Expr::Not(Box::new(lit(1.0_f32))); let ctx = SessionContext::new(); - roundtrip_expr_test!(test_expr, ctx); + roundtrip_expr_test(test_expr, ctx); } #[test] @@ -815,7 +831,7 @@ mod roundtrip_tests { let test_expr = Expr::IsNull(Box::new(col("id"))); let ctx = SessionContext::new(); - roundtrip_expr_test!(test_expr, ctx); + roundtrip_expr_test(test_expr, ctx); } #[test] @@ -823,7 +839,7 @@ mod roundtrip_tests { let test_expr = Expr::IsNotNull(Box::new(col("id"))); let ctx = SessionContext::new(); - roundtrip_expr_test!(test_expr, ctx); + roundtrip_expr_test(test_expr, ctx); } #[test] @@ -836,7 +852,7 @@ mod roundtrip_tests { }; let ctx = SessionContext::new(); - roundtrip_expr_test!(test_expr, ctx); + roundtrip_expr_test(test_expr, ctx); } #[test] @@ -848,7 +864,7 @@ mod roundtrip_tests { }; let ctx = SessionContext::new(); - roundtrip_expr_test!(test_expr, ctx); + roundtrip_expr_test(test_expr, ctx); } #[test] @@ -859,7 +875,7 @@ mod roundtrip_tests { }; 
let ctx = SessionContext::new(); - roundtrip_expr_test!(test_expr, ctx); + roundtrip_expr_test(test_expr, ctx); } #[test] @@ -871,7 +887,7 @@ mod roundtrip_tests { }; let ctx = SessionContext::new(); - roundtrip_expr_test!(test_expr, ctx); + roundtrip_expr_test(test_expr, ctx); } #[test] @@ -879,7 +895,7 @@ mod roundtrip_tests { let test_expr = Expr::Negative(Box::new(lit(1.0_f32))); let ctx = SessionContext::new(); - roundtrip_expr_test!(test_expr, ctx); + roundtrip_expr_test(test_expr, ctx); } #[test] @@ -891,7 +907,7 @@ mod roundtrip_tests { }; let ctx = SessionContext::new(); - roundtrip_expr_test!(test_expr, ctx); + roundtrip_expr_test(test_expr, ctx); } #[test] @@ -899,7 +915,7 @@ mod roundtrip_tests { let test_expr = Expr::Wildcard; let ctx = SessionContext::new(); - roundtrip_expr_test!(test_expr, ctx); + roundtrip_expr_test(test_expr, ctx); } #[test] @@ -909,7 +925,7 @@ mod roundtrip_tests { args: vec![col("col")], }; let ctx = SessionContext::new(); - roundtrip_expr_test!(test_expr, ctx); + roundtrip_expr_test(test_expr, ctx); } #[test] @@ -921,7 +937,7 @@ mod roundtrip_tests { }; let ctx = SessionContext::new(); - roundtrip_expr_test!(test_expr, ctx); + roundtrip_expr_test(test_expr, ctx); } #[test] @@ -975,7 +991,7 @@ mod roundtrip_tests { let mut ctx = SessionContext::new(); ctx.register_udaf(dummy_agg); - roundtrip_expr_test!(test_expr, ctx); + roundtrip_expr_test(test_expr, ctx); } #[test] @@ -1000,7 +1016,7 @@ mod roundtrip_tests { let mut ctx = SessionContext::new(); ctx.register_udf(udf); - roundtrip_expr_test!(test_expr, ctx); + roundtrip_expr_test(test_expr, ctx); } #[test] @@ -1012,7 +1028,7 @@ mod roundtrip_tests { ])); let ctx = SessionContext::new(); - roundtrip_expr_test!(test_expr, ctx); + roundtrip_expr_test(test_expr, ctx); } #[test] @@ -1020,7 +1036,7 @@ mod roundtrip_tests { let test_expr = Expr::GroupingSet(GroupingSet::Rollup(vec![col("a"), col("b")])); let ctx = SessionContext::new(); - roundtrip_expr_test!(test_expr, ctx); + 
roundtrip_expr_test(test_expr, ctx); } #[test] @@ -1028,6 +1044,6 @@ mod roundtrip_tests { let test_expr = Expr::GroupingSet(GroupingSet::Cube(vec![col("a"), col("b")])); let ctx = SessionContext::new(); - roundtrip_expr_test!(test_expr, ctx); + roundtrip_expr_test(test_expr, ctx); } }