Skip to content

Commit

Permalink
feat: Add Commands enum to decode prost messages to strong type (#3887)
Browse files Browse the repository at this point in the history
* feat: Add Commands enum to decode known messages to strong type

* chore: paste needs to be a dependency

* chore: rustfmt

* Add docs and use Commands

* chore: Rename to `Command`; impl TryFrom<Any>

* chore: Add `into_any` and `type_url` API

* Tweak documentation

* fixup

* clippy

* feat: Add `Command::Unknown(Any)` variant

* Updated `do_get` and `do_put` functions to use `Command` enum
* Added test for Unknown variant

* chore: placate clippy

* chore: combine errors

* chore: don't change error code

---------

Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
  • Loading branch information
stuartcarnie and alamb authored Apr 3, 2023
1 parent 901c061 commit e3f212c
Show file tree
Hide file tree
Showing 3 changed files with 224 additions and 171 deletions.
1 change: 1 addition & 0 deletions arrow-flight/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ arrow-schema = { workspace = true }
base64 = { version = "0.21", default-features = false, features = ["std"] }
tonic = { version = "0.9", default-features = false, features = ["transport", "codegen", "prost"] }
bytes = { version = "1", default-features = false }
paste = { version = "1.0" }
prost = { version = "0.11", default-features = false, features = ["prost-derive"] }
tokio = { version = "1.0", default-features = false, features = ["macros", "rt", "rt-multi-thread"] }
futures = { version = "0.3", default-features = false, features = ["alloc"] }
Expand Down
132 changes: 122 additions & 10 deletions arrow-flight/src/sql/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

use arrow_schema::ArrowError;
use bytes::Bytes;
use paste::paste;
use prost::Message;

mod gen {
Expand Down Expand Up @@ -71,22 +72,110 @@ pub trait ProstMessageExt: prost::Message + Default {
fn as_any(&self) -> Any;
}

/// Macro to coerce a token to an item, specifically
/// to build the `Commands` enum.
///
/// See: <https://danielkeep.github.io/tlborm/book/blk-ast-coercion.html>
macro_rules! as_item {
($i:item) => {
$i
};
}

macro_rules! prost_message_ext {
($($name:ty,)*) => {
$(
impl ProstMessageExt for $name {
fn type_url() -> &'static str {
concat!("type.googleapis.com/arrow.flight.protocol.sql.", stringify!($name))
($($name:tt,)*) => {
paste! {
$(
const [<$name:snake:upper _TYPE_URL>]: &'static str = concat!("type.googleapis.com/arrow.flight.protocol.sql.", stringify!($name));
)*

as_item! {
/// Helper to convert to/from protobuf [`Any`]
/// to a strongly typed enum.
///
/// # Example
/// ```rust
/// # use arrow_flight::sql::{Any, CommandStatementQuery, Command};
/// let flightsql_message = CommandStatementQuery {
/// query: "SELECT * FROM foo".to_string(),
/// };
///
/// // Given a packed FlightSQL Any message
/// let any_message = Any::pack(&flightsql_message).unwrap();
///
/// // decode it to Command:
/// match Command::try_from(any_message).unwrap() {
/// Command::CommandStatementQuery(decoded) => {
/// assert_eq!(flightsql_message, decoded);
/// }
/// _ => panic!("Unexpected decoded message"),
/// }
/// ```
#[derive(Clone, Debug, PartialEq)]
pub enum Command {
$($name($name),)*

/// Any message that is not any FlightSQL command.
Unknown(Any),
}
}

fn as_any(&self) -> Any {
Any {
type_url: <$name>::type_url().to_string(),
value: self.encode_to_vec().into(),
impl Command {
/// Convert the command to [`Any`].
pub fn into_any(self) -> Any {
match self {
$(
Self::$name(cmd) => cmd.as_any(),
)*
Self::Unknown(any) => any,
}
}

/// Get the URL for the command.
pub fn type_url(&self) -> &str {
match self {
$(
Self::$name(_) => [<$name:snake:upper _TYPE_URL>],
)*
Self::Unknown(any) => any.type_url.as_str(),
}
}
}

impl TryFrom<Any> for Command {
type Error = ArrowError;

fn try_from(any: Any) -> Result<Self, Self::Error> {
match any.type_url.as_str() {
$(
[<$name:snake:upper _TYPE_URL>]
=> {
let m: $name = Message::decode(&*any.value).map_err(|err| {
ArrowError::ParseError(format!("Unable to decode Any value: {err}"))
})?;
Ok(Self::$name(m))
}
)*
_ => Ok(Self::Unknown(any)),
}
}
}
)*

$(
impl ProstMessageExt for $name {
fn type_url() -> &'static str {
[<$name:snake:upper _TYPE_URL>]
}

fn as_any(&self) -> Any {
Any {
type_url: <$name>::type_url().to_string(),
value: self.encode_to_vec().into(),
}
}
}
)*
}
};
}

Expand Down Expand Up @@ -190,4 +279,27 @@ mod tests {
let unpack_query: CommandStatementQuery = any.unpack().unwrap().unwrap();
assert_eq!(query, unpack_query);
}

#[test]
fn test_command() {
let query = CommandStatementQuery {
query: "select 1".to_string(),
};
let any = Any::pack(&query).unwrap();
let cmd: Command = any.try_into().unwrap();

assert!(matches!(cmd, Command::CommandStatementQuery(_)));
assert_eq!(cmd.type_url(), COMMAND_STATEMENT_QUERY_TYPE_URL);

// Unknown variant

let any = Any {
type_url: "fake_url".to_string(),
value: Default::default(),
};

let cmd: Command = any.try_into().unwrap();
assert!(matches!(cmd, Command::Unknown(_)));
assert_eq!(cmd.type_url(), "fake_url");
}
}
Loading

0 comments on commit e3f212c

Please sign in to comment.