diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index b534ed224001..cd24c7f081de 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -77,7 +77,7 @@ jobs: CARGO_TARGET_DIR: "/github/home/target" - name: Check Workspace builds with all features run: | - cargo check --workspace --benches --features avro,jit,scheduler + cargo check --workspace --benches --features avro,jit,scheduler,json env: CARGO_HOME: "/github/home/.cargo" CARGO_TARGET_DIR: "/github/home/target" @@ -121,7 +121,7 @@ jobs: run: | export ARROW_TEST_DATA=$(pwd)/testing/data export PARQUET_TEST_DATA=$(pwd)/parquet-testing/data - cargo test --features avro,jit,scheduler + cargo test --features avro,jit,scheduler,json # test datafusion-sql examples cargo run --example sql # test datafusion examples diff --git a/datafusion/proto/Cargo.toml b/datafusion/proto/Cargo.toml index 81481a6ef5e6..d3f70b2c071a 100644 --- a/datafusion/proto/Cargo.toml +++ b/datafusion/proto/Cargo.toml @@ -24,7 +24,7 @@ repository = "https://github.com/apache/arrow-datafusion" readme = "README.md" authors = ["Apache Arrow "] license = "Apache-2.0" -keywords = [ "arrow", "query", "sql" ] +keywords = ["arrow", "query", "sql"] edition = "2021" rust-version = "1.58" @@ -33,18 +33,24 @@ name = "datafusion_proto" path = "src/lib.rs" [features] +default = [] +json = ["pbjson", "pbjson-build", "serde", "serde_json"] [dependencies] arrow = { version = "18.0.0" } datafusion = { path = "../core", version = "10.0.0" } datafusion-common = { path = "../common", version = "10.0.0" } datafusion-expr = { path = "../expr", version = "10.0.0" } +pbjson = { version = "0.3", optional = true } +pbjson-types = { version = "0.3", optional = true } prost = "0.10" - +serde = { version = "1.0", optional = true } +serde_json = { version = "1.0", optional = true } [dev-dependencies] doc-comment = "0.3" tokio = "1.18" [build-dependencies] -tonic-build = { version = "0.7" } +pbjson-build = { version = "0.3", optional = true } +prost-build = { version = "0.7" } diff --git a/datafusion/proto/build.rs b/datafusion/proto/build.rs index 414593eeef4a..e13ffa86a53c 100644 --- a/datafusion/proto/build.rs +++ b/datafusion/proto/build.rs @@ -15,12 +15,44 @@ // specific language governing permissions and limitations // under the License. +type Error = Box; +type Result = std::result::Result; + fn main() -> Result<(), String> { // for use in docker build where file changes can be wonky println!("cargo:rerun-if-env-changed=FORCE_REBUILD"); - println!("cargo:rerun-if-changed=proto/datafusion.proto"); - tonic_build::configure() - .compile(&["proto/datafusion.proto"], &["proto"]) + + build()?; + + Ok(()) +} + +#[cfg(feature = "json")] +fn build() -> Result<(), String> { + let descriptor_path = std::path::PathBuf::from(std::env::var("OUT_DIR").unwrap()) + .join("proto_descriptor.bin"); + + prost_build::Config::new() + .file_descriptor_set_path(&descriptor_path) + .compile_well_known_types() + .extern_path(".google.protobuf", "::pbjson_types") + .compile_protos(&["proto/datafusion.proto"], &["proto"]) + .map_err(|e| format!("protobuf compilation failed: {}", e))?; + + let descriptor_set = std::fs::read(descriptor_path).unwrap(); + pbjson_build::Builder::new() + .register_descriptors(&descriptor_set) + .unwrap() + .build(&[".datafusion"]) + .map_err(|e| format!("pbjson compilation failed: {}", e))?; + + Ok(()) +} + +#[cfg(not(feature = "json"))] +fn build() -> Result<(), String> { + prost_build::Config::new() + .compile_protos(&["proto/datafusion.proto"], &["proto"]) .map_err(|e| format!("protobuf compilation failed: {}", e)) } diff --git a/datafusion/proto/src/bytes/mod.rs b/datafusion/proto/src/bytes/mod.rs index 37374b3eff5a..6ddaecae098f 100644 --- a/datafusion/proto/src/bytes/mod.rs +++ b/datafusion/proto/src/bytes/mod.rs @@ -107,6 +107,19 @@ pub fn logical_plan_to_bytes(plan: &LogicalPlan) -> Result { logical_plan_to_bytes_with_extension_codec(plan, &extension_codec) } +/// Serialize a LogicalPlan as json +#[cfg(feature = "json")] +pub fn logical_plan_to_json(plan: &LogicalPlan) -> Result { + let extension_codec = DefaultExtensionCodec {}; + let protobuf = + protobuf::LogicalPlanNode::try_from_logical_plan(plan, &extension_codec) + .map_err(|e| { + DataFusionError::Plan(format!("Error serializing plan: {}", e)) + })?; + serde_json::to_string(&protobuf) + .map_err(|e| DataFusionError::Plan(format!("Error serializing plan: {}", e))) +} + /// Serialize a LogicalPlan as bytes, using the provided extension codec pub fn logical_plan_to_bytes_with_extension_codec( plan: &LogicalPlan, @@ -121,6 +134,14 @@ pub fn logical_plan_to_bytes_with_extension_codec( Ok(buffer.into()) } +/// Deserialize a LogicalPlan from json +#[cfg(feature = "json")] +pub fn logical_plan_from_json(json: &str, ctx: &SessionContext) -> Result { + let back: protobuf::LogicalPlanNode = serde_json::from_str(json).unwrap(); + let extension_codec = DefaultExtensionCodec {}; + back.try_into_logical_plan(ctx, &extension_codec) +} + /// Deserialize a LogicalPlan from bytes pub fn logical_plan_from_bytes( bytes: &[u8], @@ -183,6 +204,31 @@ mod test { Expr::from_bytes(b"Leet").unwrap(); } + #[test] + #[cfg(feature = "json")] + fn plan_to_json() { + use datafusion_common::DFSchema; + use datafusion_expr::logical_plan::EmptyRelation; + + let plan = LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: Arc::new(DFSchema::empty()), + }); + let actual = logical_plan_to_json(&plan).unwrap(); + let expected = r#"{"emptyRelation":{}}"#.to_string(); + assert_eq!(actual, expected); + } + + #[test] + #[cfg(feature = "json")] + fn json_to_plan() { + let input = r#"{"emptyRelation":{}}"#.to_string(); + let ctx = SessionContext::new(); + let actual = logical_plan_from_json(&input, &ctx).unwrap(); + let result = matches!(actual, LogicalPlan::EmptyRelation(_)); + assert!(result, "Should parse empty relation"); + } + #[test] fn udf_roundtrip_with_registry() { let ctx = context_with_udf(); diff --git a/datafusion/proto/src/lib.rs b/datafusion/proto/src/lib.rs index 0aa00bc75051..4683abd7ac51 100644 --- a/datafusion/proto/src/lib.rs +++ b/datafusion/proto/src/lib.rs @@ -23,6 +23,9 @@ use datafusion_common::DataFusionError; #[allow(clippy::all)] pub mod protobuf { include!(concat!(env!("OUT_DIR"), "/datafusion.rs")); + + #[cfg(feature = "json")] + include!(concat!(env!("OUT_DIR"), "/datafusion.serde.rs")); } pub mod bytes; @@ -75,19 +78,32 @@ mod roundtrip_tests { use std::fmt::Formatter; use std::sync::Arc; + #[cfg(feature = "json")] + fn roundtrip_json_test(proto: &protobuf::LogicalExprNode) { + let string = serde_json::to_string(proto).unwrap(); + let back: protobuf::LogicalExprNode = serde_json::from_str(&string).unwrap(); + assert_eq!(proto, &back); + } + + #[cfg(not(feature = "json"))] + fn roundtrip_json_test(_proto: &protobuf::LogicalExprNode) {} + // Given a DataFusion logical Expr, convert it to protobuf and back, using debug formatting to test // equality. - macro_rules! roundtrip_expr_test { - ($initial_struct:ident, $ctx:ident) => { - let proto: protobuf::LogicalExprNode = (&$initial_struct).try_into().unwrap(); + fn roundtrip_expr_test(initial_struct: T, ctx: SessionContext) + where + for<'a> &'a T: TryInto + Debug, + E: Debug, + { + let proto: protobuf::LogicalExprNode = (&initial_struct).try_into().unwrap(); + let round_trip: Expr = parse_expr(&proto, &ctx).unwrap(); - let round_trip: Expr = parse_expr(&proto, &$ctx).unwrap(); + assert_eq!( + format!("{:?}", &initial_struct), + format!("{:?}", round_trip) + ); - assert_eq!( - format!("{:?}", $initial_struct), - format!("{:?}", round_trip) - ); - }; + roundtrip_json_test(&proto); } fn new_box_field(name: &str, dt: DataType, nullable: bool) -> Box { @@ -807,7 +823,7 @@ mod roundtrip_tests { let test_expr = Expr::Not(Box::new(lit(1.0_f32))); let ctx = SessionContext::new(); - roundtrip_expr_test!(test_expr, ctx); + roundtrip_expr_test(test_expr, ctx); } #[test] @@ -815,7 +831,7 @@ mod roundtrip_tests { let test_expr = Expr::IsNull(Box::new(col("id"))); let ctx = SessionContext::new(); - roundtrip_expr_test!(test_expr, ctx); + roundtrip_expr_test(test_expr, ctx); } #[test] @@ -823,7 +839,7 @@ mod roundtrip_tests { let test_expr = Expr::IsNotNull(Box::new(col("id"))); let ctx = SessionContext::new(); - roundtrip_expr_test!(test_expr, ctx); + roundtrip_expr_test(test_expr, ctx); } #[test] @@ -836,7 +852,7 @@ mod roundtrip_tests { }; let ctx = SessionContext::new(); - roundtrip_expr_test!(test_expr, ctx); + roundtrip_expr_test(test_expr, ctx); } #[test] @@ -848,7 +864,7 @@ mod roundtrip_tests { }; let ctx = SessionContext::new(); - roundtrip_expr_test!(test_expr, ctx); + roundtrip_expr_test(test_expr, ctx); } #[test] @@ -859,7 +875,7 @@ mod roundtrip_tests { }; let ctx = SessionContext::new(); - roundtrip_expr_test!(test_expr, ctx); + roundtrip_expr_test(test_expr, ctx); } #[test] @@ -871,7 +887,7 @@ mod roundtrip_tests { }; let ctx = SessionContext::new(); - roundtrip_expr_test!(test_expr, ctx); + roundtrip_expr_test(test_expr, ctx); } #[test] @@ -879,7 +895,7 @@ mod roundtrip_tests { let test_expr = Expr::Negative(Box::new(lit(1.0_f32))); let ctx = SessionContext::new(); - roundtrip_expr_test!(test_expr, ctx); + roundtrip_expr_test(test_expr, ctx); } #[test] @@ -891,7 +907,7 @@ mod roundtrip_tests { }; let ctx = SessionContext::new(); - roundtrip_expr_test!(test_expr, ctx); + roundtrip_expr_test(test_expr, ctx); } #[test] @@ -899,7 +915,7 @@ mod roundtrip_tests { let test_expr = Expr::Wildcard; let ctx = SessionContext::new(); - roundtrip_expr_test!(test_expr, ctx); + roundtrip_expr_test(test_expr, ctx); } #[test] @@ -909,7 +925,7 @@ mod roundtrip_tests { args: vec![col("col")], }; let ctx = SessionContext::new(); - roundtrip_expr_test!(test_expr, ctx); + roundtrip_expr_test(test_expr, ctx); } #[test] @@ -921,7 +937,7 @@ mod roundtrip_tests { }; let ctx = SessionContext::new(); - roundtrip_expr_test!(test_expr, ctx); + roundtrip_expr_test(test_expr, ctx); } #[test] @@ -975,7 +991,7 @@ mod roundtrip_tests { let mut ctx = SessionContext::new(); ctx.register_udaf(dummy_agg); - roundtrip_expr_test!(test_expr, ctx); + roundtrip_expr_test(test_expr, ctx); } #[test] @@ -1000,7 +1016,7 @@ mod roundtrip_tests { let mut ctx = SessionContext::new(); ctx.register_udf(udf); - roundtrip_expr_test!(test_expr, ctx); + roundtrip_expr_test(test_expr, ctx); } #[test] @@ -1012,7 +1028,7 @@ mod roundtrip_tests { ])); let ctx = SessionContext::new(); - roundtrip_expr_test!(test_expr, ctx); + roundtrip_expr_test(test_expr, ctx); } #[test] @@ -1020,7 +1036,7 @@ mod roundtrip_tests { let test_expr = Expr::GroupingSet(GroupingSet::Rollup(vec![col("a"), col("b")])); let ctx = SessionContext::new(); - roundtrip_expr_test!(test_expr, ctx); + roundtrip_expr_test(test_expr, ctx); } #[test] @@ -1028,6 +1044,6 @@ mod roundtrip_tests { let test_expr = Expr::GroupingSet(GroupingSet::Cube(vec![col("a"), col("b")])); let ctx = SessionContext::new(); - roundtrip_expr_test!(test_expr, ctx); + roundtrip_expr_test(test_expr, ctx); } }