Skip to content

Commit

Permalink
Add optional serde support to datafusion-proto (apache#2889)
Browse files Browse the repository at this point in the history
  • Loading branch information
tustvold committed Jul 13, 2022
1 parent eed77a2 commit f13ee06
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 29 deletions.
10 changes: 8 additions & 2 deletions datafusion/proto/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ repository = "https://github.com/apache/arrow-datafusion"
readme = "README.md"
authors = ["Apache Arrow <dev@arrow.apache.org>"]
license = "Apache-2.0"
keywords = [ "arrow", "query", "sql" ]
keywords = ["arrow", "query", "sql"]
edition = "2021"
rust-version = "1.58"

Expand All @@ -33,18 +33,24 @@ name = "datafusion_proto"
path = "src/lib.rs"

[features]
default = []
serde = ["pbjson", "pbjson-build", "dep:serde"]

[dependencies]
arrow = { version = "18.0.0" }
datafusion = { path = "../core", version = "10.0.0" }
datafusion-common = { path = "../common", version = "10.0.0" }
datafusion-expr = { path = "../expr", version = "10.0.0" }
prost = "0.10"

serde = { version = "1.0", optional = true }
pbjson = { version = "0.3", optional = true }
pbjson-types = { version = "0.3", optional = true }

[dev-dependencies]
doc-comment = "0.3"
tokio = "1.18"
serde_json = "1.0"

[build-dependencies]
tonic-build = { version = "0.7" }
pbjson-build = { version = "0.3", optional = true }
34 changes: 33 additions & 1 deletion datafusion/proto/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,43 @@
// specific language governing permissions and limitations
// under the License.

type Error = Box<dyn std::error::Error>;
type Result<T, E = Error> = std::result::Result<T, E>;

fn main() -> Result<(), String> {
// for use in docker build where file changes can be wonky
println!("cargo:rerun-if-env-changed=FORCE_REBUILD");

println!("cargo:rerun-if-changed=proto/datafusion.proto");

build()?;

Ok(())
}

#[cfg(feature = "serde")]
fn build() -> Result<(), String> {
let descriptor_path = std::path::PathBuf::from(std::env::var("OUT_DIR").unwrap())
.join("proto_descriptor.bin");

tonic_build::configure()
.file_descriptor_set_path(&descriptor_path)
.compile_well_known_types(true)
.extern_path(".google.protobuf", "::pbjson_types")
.compile(&["proto/datafusion.proto"], &["proto"])
.map_err(|e| format!("protobuf compilation failed: {}", e))?;

let descriptor_set = std::fs::read(descriptor_path).unwrap();
pbjson_build::Builder::new()
.register_descriptors(&descriptor_set)
.unwrap()
.build(&[".datafusion"])
.map_err(|e| format!("pbjson compilation failed: {}", e))?;

Ok(())
}

#[cfg(not(feature = "serde"))]
fn build() -> Result<(), String> {
tonic_build::configure()
.compile(&["proto/datafusion.proto"], &["proto"])
.map_err(|e| format!("protobuf compilation failed: {}", e))
Expand Down
68 changes: 42 additions & 26 deletions datafusion/proto/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ use datafusion_common::DataFusionError;
#[allow(clippy::all)]
pub mod protobuf {
include!(concat!(env!("OUT_DIR"), "/datafusion.rs"));

#[cfg(feature = "serde")]
include!(concat!(env!("OUT_DIR"), "/datafusion.serde.rs"));
}

pub mod bytes;
Expand Down Expand Up @@ -75,19 +78,32 @@ mod roundtrip_tests {
use std::fmt::Formatter;
use std::sync::Arc;

#[cfg(feature = "serde")]
fn roundtrip_serde_test(proto: &protobuf::LogicalExprNode) {
let string = serde_json::to_string(proto).unwrap();
let back: protobuf::LogicalExprNode = serde_json::from_str(&string).unwrap();
assert_eq!(proto, &back);
}

#[cfg(not(feature = "serde"))]
fn roundtrip_serde_test(_proto: &protobuf::LogicalExprNode) {}

// Given a DataFusion logical Expr, convert it to protobuf and back, using debug formatting to test
// equality.
macro_rules! roundtrip_expr_test {
($initial_struct:ident, $ctx:ident) => {
let proto: protobuf::LogicalExprNode = (&$initial_struct).try_into().unwrap();
fn roundtrip_expr_test<T, E>(initial_struct: T, ctx: SessionContext)
where
for<'a> &'a T: TryInto<protobuf::LogicalExprNode, Error = E> + Debug,
E: Debug,
{
let proto: protobuf::LogicalExprNode = (&initial_struct).try_into().unwrap();
let round_trip: Expr = parse_expr(&proto, &ctx).unwrap();

let round_trip: Expr = parse_expr(&proto, &$ctx).unwrap();
assert_eq!(
format!("{:?}", &initial_struct),
format!("{:?}", round_trip)
);

assert_eq!(
format!("{:?}", $initial_struct),
format!("{:?}", round_trip)
);
};
roundtrip_serde_test(&proto);
}

fn new_box_field(name: &str, dt: DataType, nullable: bool) -> Box<Field> {
Expand Down Expand Up @@ -807,23 +823,23 @@ mod roundtrip_tests {
let test_expr = Expr::Not(Box::new(lit(1.0_f32)));

let ctx = SessionContext::new();
roundtrip_expr_test!(test_expr, ctx);
roundtrip_expr_test(test_expr, ctx);
}

#[test]
fn roundtrip_is_null() {
let test_expr = Expr::IsNull(Box::new(col("id")));

let ctx = SessionContext::new();
roundtrip_expr_test!(test_expr, ctx);
roundtrip_expr_test(test_expr, ctx);
}

#[test]
fn roundtrip_is_not_null() {
let test_expr = Expr::IsNotNull(Box::new(col("id")));

let ctx = SessionContext::new();
roundtrip_expr_test!(test_expr, ctx);
roundtrip_expr_test(test_expr, ctx);
}

#[test]
Expand All @@ -836,7 +852,7 @@ mod roundtrip_tests {
};

let ctx = SessionContext::new();
roundtrip_expr_test!(test_expr, ctx);
roundtrip_expr_test(test_expr, ctx);
}

#[test]
Expand All @@ -848,7 +864,7 @@ mod roundtrip_tests {
};

let ctx = SessionContext::new();
roundtrip_expr_test!(test_expr, ctx);
roundtrip_expr_test(test_expr, ctx);
}

#[test]
Expand All @@ -859,7 +875,7 @@ mod roundtrip_tests {
};

let ctx = SessionContext::new();
roundtrip_expr_test!(test_expr, ctx);
roundtrip_expr_test(test_expr, ctx);
}

#[test]
Expand All @@ -871,15 +887,15 @@ mod roundtrip_tests {
};

let ctx = SessionContext::new();
roundtrip_expr_test!(test_expr, ctx);
roundtrip_expr_test(test_expr, ctx);
}

#[test]
fn roundtrip_negative() {
let test_expr = Expr::Negative(Box::new(lit(1.0_f32)));

let ctx = SessionContext::new();
roundtrip_expr_test!(test_expr, ctx);
roundtrip_expr_test(test_expr, ctx);
}

#[test]
Expand All @@ -891,15 +907,15 @@ mod roundtrip_tests {
};

let ctx = SessionContext::new();
roundtrip_expr_test!(test_expr, ctx);
roundtrip_expr_test(test_expr, ctx);
}

#[test]
fn roundtrip_wildcard() {
let test_expr = Expr::Wildcard;

let ctx = SessionContext::new();
roundtrip_expr_test!(test_expr, ctx);
roundtrip_expr_test(test_expr, ctx);
}

#[test]
Expand All @@ -909,7 +925,7 @@ mod roundtrip_tests {
args: vec![col("col")],
};
let ctx = SessionContext::new();
roundtrip_expr_test!(test_expr, ctx);
roundtrip_expr_test(test_expr, ctx);
}

#[test]
Expand All @@ -921,7 +937,7 @@ mod roundtrip_tests {
};

let ctx = SessionContext::new();
roundtrip_expr_test!(test_expr, ctx);
roundtrip_expr_test(test_expr, ctx);
}

#[test]
Expand Down Expand Up @@ -975,7 +991,7 @@ mod roundtrip_tests {
let mut ctx = SessionContext::new();
ctx.register_udaf(dummy_agg);

roundtrip_expr_test!(test_expr, ctx);
roundtrip_expr_test(test_expr, ctx);
}

#[test]
Expand All @@ -1000,7 +1016,7 @@ mod roundtrip_tests {
let mut ctx = SessionContext::new();
ctx.register_udf(udf);

roundtrip_expr_test!(test_expr, ctx);
roundtrip_expr_test(test_expr, ctx);
}

#[test]
Expand All @@ -1012,22 +1028,22 @@ mod roundtrip_tests {
]));

let ctx = SessionContext::new();
roundtrip_expr_test!(test_expr, ctx);
roundtrip_expr_test(test_expr, ctx);
}

#[test]
fn roundtrip_rollup() {
let test_expr = Expr::GroupingSet(GroupingSet::Rollup(vec![col("a"), col("b")]));

let ctx = SessionContext::new();
roundtrip_expr_test!(test_expr, ctx);
roundtrip_expr_test(test_expr, ctx);
}

#[test]
fn roundtrip_cube() {
let test_expr = Expr::GroupingSet(GroupingSet::Cube(vec![col("a"), col("b")]));

let ctx = SessionContext::new();
roundtrip_expr_test!(test_expr, ctx);
roundtrip_expr_test(test_expr, ctx);
}
}

0 comments on commit f13ee06

Please sign in to comment.