From e3f212cc114743dbdc5852ddd777b19eaba3233e Mon Sep 17 00:00:00 2001 From: Stuart Carnie Date: Tue, 4 Apr 2023 05:10:50 +1000 Subject: [PATCH] feat: Add Commands enum to decode prost messages to strong type (#3887) * feat: Add Commands enum to decode known messages to strong type * chore: paste needs to be a dependency * chore: rustfmt * Add docs and use Commands * chore: Rename to `Command`; impl TryFrom * chore: Add `into_any` and `type_url` API * Tweak documentation * fixup * clippy * feat: Add `Command::Unknown(Any)` variant * Updated `do_get` and `do_put` functions to use `Command` enum * Added test for Unknown variant * chore: placate clippy * chore: combine errors * chore: don't change error code --------- Co-authored-by: Andrew Lamb --- arrow-flight/Cargo.toml | 1 + arrow-flight/src/sql/mod.rs | 132 +++++++++++++++-- arrow-flight/src/sql/server.rs | 262 +++++++++++++-------------------- 3 files changed, 224 insertions(+), 171 deletions(-) diff --git a/arrow-flight/Cargo.toml b/arrow-flight/Cargo.toml index 732c24572856..e22642b2a727 100644 --- a/arrow-flight/Cargo.toml +++ b/arrow-flight/Cargo.toml @@ -36,6 +36,7 @@ arrow-schema = { workspace = true } base64 = { version = "0.21", default-features = false, features = ["std"] } tonic = { version = "0.9", default-features = false, features = ["transport", "codegen", "prost"] } bytes = { version = "1", default-features = false } +paste = { version = "1.0" } prost = { version = "0.11", default-features = false, features = ["prost-derive"] } tokio = { version = "1.0", default-features = false, features = ["macros", "rt", "rt-multi-thread"] } futures = { version = "0.3", default-features = false, features = ["alloc"] } diff --git a/arrow-flight/src/sql/mod.rs b/arrow-flight/src/sql/mod.rs index 9ea74c3f35bb..2c26f2bf69b6 100644 --- a/arrow-flight/src/sql/mod.rs +++ b/arrow-flight/src/sql/mod.rs @@ -17,6 +17,7 @@ use arrow_schema::ArrowError; use bytes::Bytes; +use paste::paste; use prost::Message; mod gen { @@ 
-71,22 +72,110 @@ pub trait ProstMessageExt: prost::Message + Default { fn as_any(&self) -> Any; } +/// Macro to coerce a token to an item, specifically +/// to build the `Commands` enum. +/// +/// See: +macro_rules! as_item { + ($i:item) => { + $i + }; +} + macro_rules! prost_message_ext { - ($($name:ty,)*) => { - $( - impl ProstMessageExt for $name { - fn type_url() -> &'static str { - concat!("type.googleapis.com/arrow.flight.protocol.sql.", stringify!($name)) + ($($name:tt,)*) => { + paste! { + $( + const [<$name:snake:upper _TYPE_URL>]: &'static str = concat!("type.googleapis.com/arrow.flight.protocol.sql.", stringify!($name)); + )* + + as_item! { + /// Helper to convert to/from protobuf [`Any`] + /// to a strongly typed enum. + /// + /// # Example + /// ```rust + /// # use arrow_flight::sql::{Any, CommandStatementQuery, Command}; + /// let flightsql_message = CommandStatementQuery { + /// query: "SELECT * FROM foo".to_string(), + /// }; + /// + /// // Given a packed FlightSQL Any message + /// let any_message = Any::pack(&flightsql_message).unwrap(); + /// + /// // decode it to Command: + /// match Command::try_from(any_message).unwrap() { + /// Command::CommandStatementQuery(decoded) => { + /// assert_eq!(flightsql_message, decoded); + /// } + /// _ => panic!("Unexpected decoded message"), + /// } + /// ``` + #[derive(Clone, Debug, PartialEq)] + pub enum Command { + $($name($name),)* + + /// Any message that is not any FlightSQL command. + Unknown(Any), } + } - fn as_any(&self) -> Any { - Any { - type_url: <$name>::type_url().to_string(), - value: self.encode_to_vec().into(), + impl Command { + /// Convert the command to [`Any`]. + pub fn into_any(self) -> Any { + match self { + $( + Self::$name(cmd) => cmd.as_any(), + )* + Self::Unknown(any) => any, + } + } + + /// Get the URL for the command. 
+ pub fn type_url(&self) -> &str { + match self { + $( + Self::$name(_) => [<$name:snake:upper _TYPE_URL>], + )* + Self::Unknown(any) => any.type_url.as_str(), + } + } + } + + impl TryFrom<Any> for Command { + type Error = ArrowError; + + fn try_from(any: Any) -> Result<Self, Self::Error> { + match any.type_url.as_str() { + $( + [<$name:snake:upper _TYPE_URL>] + => { + let m: $name = Message::decode(&*any.value).map_err(|err| { + ArrowError::ParseError(format!("Unable to decode Any value: {err}")) + })?; + Ok(Self::$name(m)) + } + )* + _ => Ok(Self::Unknown(any)), } } } - )* + + $( + impl ProstMessageExt for $name { + fn type_url() -> &'static str { + [<$name:snake:upper _TYPE_URL>] + } + + fn as_any(&self) -> Any { + Any { + type_url: <$name>::type_url().to_string(), + value: self.encode_to_vec().into(), + } + } + } + )* + } }; } @@ -190,4 +279,27 @@ mod tests { let unpack_query: CommandStatementQuery = any.unpack().unwrap().unwrap(); assert_eq!(query, unpack_query); } + + #[test] + fn test_command() { + let query = CommandStatementQuery { + query: "select 1".to_string(), + }; + let any = Any::pack(&query).unwrap(); + let cmd: Command = any.try_into().unwrap(); + + assert!(matches!(cmd, Command::CommandStatementQuery(_))); + assert_eq!(cmd.type_url(), COMMAND_STATEMENT_QUERY_TYPE_URL); + + // Unknown variant + + let any = Any { + type_url: "fake_url".to_string(), + value: Default::default(), + }; + + let cmd: Command = any.try_into().unwrap(); + assert!(matches!(cmd, Command::Unknown(_))); + assert_eq!(cmd.type_url(), "fake_url"); + } } diff --git a/arrow-flight/src/sql/server.rs b/arrow-flight/src/sql/server.rs index 848bfb3852f5..b11fa3e3c3db 100644 --- a/arrow-flight/src/sql/server.rs +++ b/arrow-flight/src/sql/server.rs @@ -17,7 +17,7 @@ use std::pin::Pin; -use crate::sql::Any; +use crate::sql::{Any, Command}; use futures::Stream; use prost::Message; use tonic::{Request, Response, Status, Streaming}; @@ -315,90 +315,46 @@ where let message =
Any::decode(&*request.get_ref().cmd).map_err(decode_error_to_status)?; - if message.is::<CommandStatementQuery>() { - let token = message - .unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"); - return self.get_flight_info_statement(token, request).await; - } - if message.is::<CommandPreparedStatementQuery>() { - let handle = message - .unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"); - return self - .get_flight_info_prepared_statement(handle, request) - .await; - } - if message.is::<CommandGetCatalogs>() { - let token = message - .unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"); - return self.get_flight_info_catalogs(token, request).await; - } - if message.is::<CommandGetDbSchemas>() { - let token = message - .unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"); - return self.get_flight_info_schemas(token, request).await; - } - if message.is::<CommandGetTables>() { - let token = message - .unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"); - return self.get_flight_info_tables(token, request).await; - } - if message.is::<CommandGetTableTypes>() { - let token = message - .unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"); - return self.get_flight_info_table_types(token, request).await; + match Command::try_from(message).map_err(arrow_error_to_status)?
{ + Command::CommandStatementQuery(token) => { + self.get_flight_info_statement(token, request).await + } + Command::CommandPreparedStatementQuery(handle) => { + self.get_flight_info_prepared_statement(handle, request) + .await + } + Command::CommandGetCatalogs(token) => { + self.get_flight_info_catalogs(token, request).await + } + Command::CommandGetDbSchemas(token) => { + return self.get_flight_info_schemas(token, request).await + } + Command::CommandGetTables(token) => { + self.get_flight_info_tables(token, request).await + } + Command::CommandGetTableTypes(token) => { + self.get_flight_info_table_types(token, request).await + } + Command::CommandGetSqlInfo(token) => { + self.get_flight_info_sql_info(token, request).await + } + Command::CommandGetPrimaryKeys(token) => { + self.get_flight_info_primary_keys(token, request).await + } + Command::CommandGetExportedKeys(token) => { + self.get_flight_info_exported_keys(token, request).await + } + Command::CommandGetImportedKeys(token) => { + self.get_flight_info_imported_keys(token, request).await + } + Command::CommandGetCrossReference(token) => { + self.get_flight_info_cross_reference(token, request).await + } + cmd => Err(Status::unimplemented(format!( + "get_flight_info: The defined request is invalid: {}", + cmd.type_url() + ))), } - if message.is::<CommandGetSqlInfo>() { - let token = message - .unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"); - return self.get_flight_info_sql_info(token, request).await; - } - if message.is::<CommandGetPrimaryKeys>() { - let token = message - .unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"); - return self.get_flight_info_primary_keys(token, request).await; - } - if message.is::<CommandGetExportedKeys>() { - let token = message - .unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"); - return self.get_flight_info_exported_keys(token, request).await; - } - if message.is::<CommandGetImportedKeys>() { - let token = message - .unpack() - .map_err(arrow_error_to_status)?
- .expect("unreachable"); - return self.get_flight_info_imported_keys(token, request).await; - } - if message.is::<CommandGetCrossReference>() { - let token = message - .unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"); - return self.get_flight_info_cross_reference(token, request).await; - } - - Err(Status::unimplemented(format!( - "get_flight_info: The defined request is invalid: {}", - message.type_url - ))) } async fn get_schema( @@ -415,47 +371,42 @@ where let msg: Any = Message::decode(&*request.get_ref().ticket) .map_err(decode_error_to_status)?; - fn unpack<T: ProstMessageExt>(msg: Any) -> Result<T, Status> { - msg.unpack() - .map_err(arrow_error_to_status)? - .ok_or_else(|| Status::internal("Expected a command, but found none.")) - } - - if msg.is::<TicketStatementQuery>() { - return self.do_get_statement(unpack(msg)?, request).await; - } - if msg.is::<CommandPreparedStatementQuery>() { - return self.do_get_prepared_statement(unpack(msg)?, request).await; - } - if msg.is::<CommandGetCatalogs>() { - return self.do_get_catalogs(unpack(msg)?, request).await; - } - if msg.is::<CommandGetDbSchemas>() { - return self.do_get_schemas(unpack(msg)?, request).await; - } - if msg.is::<CommandGetTables>() { - return self.do_get_tables(unpack(msg)?, request).await; + match Command::try_from(msg).map_err(arrow_error_to_status)?
{ + Command::TicketStatementQuery(command) => { + self.do_get_statement(command, request).await + } + Command::CommandPreparedStatementQuery(command) => { + self.do_get_prepared_statement(command, request).await + } + Command::CommandGetCatalogs(command) => { + self.do_get_catalogs(command, request).await + } + Command::CommandGetDbSchemas(command) => { + self.do_get_schemas(command, request).await + } + Command::CommandGetTables(command) => { + self.do_get_tables(command, request).await + } + Command::CommandGetTableTypes(command) => { + self.do_get_table_types(command, request).await + } + Command::CommandGetSqlInfo(command) => { + self.do_get_sql_info(command, request).await + } + Command::CommandGetPrimaryKeys(command) => { + self.do_get_primary_keys(command, request).await + } + Command::CommandGetExportedKeys(command) => { + self.do_get_exported_keys(command, request).await + } + Command::CommandGetImportedKeys(command) => { + self.do_get_imported_keys(command, request).await + } + Command::CommandGetCrossReference(command) => { + self.do_get_cross_reference(command, request).await + } + cmd => self.do_get_fallback(request, cmd.into_any()).await, } - if msg.is::<CommandGetTableTypes>() { - return self.do_get_table_types(unpack(msg)?, request).await; - } - if msg.is::<CommandGetSqlInfo>() { - return self.do_get_sql_info(unpack(msg)?, request).await; - } - if msg.is::<CommandGetPrimaryKeys>() { - return self.do_get_primary_keys(unpack(msg)?, request).await; - } - if msg.is::<CommandGetExportedKeys>() { - return self.do_get_exported_keys(unpack(msg)?, request).await; - } - if msg.is::<CommandGetImportedKeys>() { - return self.do_get_imported_keys(unpack(msg)?, request).await; - } - if msg.is::<CommandGetCrossReference>() { - return self.do_get_cross_reference(unpack(msg)?, request).await; - } - - self.do_get_fallback(request, msg).await } async fn do_put( @@ -465,44 +416,33 @@ where let cmd = request.get_mut().message().await?.unwrap(); let message = Any::decode(&*cmd.flight_descriptor.unwrap().cmd) .map_err(decode_error_to_status)?; - if message.is::<CommandStatementUpdate>() { - let token = message - .unpack() -
.map_err(arrow_error_to_status)? - .expect("unreachable"); - let record_count = self.do_put_statement_update(token, request).await?; - let result = DoPutUpdateResult { record_count }; - let output = futures::stream::iter(vec![Ok(PutResult { - app_metadata: result.as_any().encode_to_vec().into(), - })]); - return Ok(Response::new(Box::pin(output))); - } - if message.is::<CommandPreparedStatementQuery>() { - let token = message - .unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"); - return self.do_put_prepared_statement_query(token, request).await; - } - if message.is::<CommandPreparedStatementUpdate>() { - let handle = message - .unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"); - let record_count = self - .do_put_prepared_statement_update(handle, request) - .await?; - let result = DoPutUpdateResult { record_count }; - let output = futures::stream::iter(vec![Ok(PutResult { - app_metadata: result.as_any().encode_to_vec().into(), - })]); - return Ok(Response::new(Box::pin(output))); + match Command::try_from(message).map_err(arrow_error_to_status)?
{ + Command::CommandStatementUpdate(command) => { + let record_count = self.do_put_statement_update(command, request).await?; + let result = DoPutUpdateResult { record_count }; + let output = futures::stream::iter(vec![Ok(PutResult { + app_metadata: result.as_any().encode_to_vec().into(), + })]); + Ok(Response::new(Box::pin(output))) + } + Command::CommandPreparedStatementQuery(command) => { + self.do_put_prepared_statement_query(command, request).await + } + Command::CommandPreparedStatementUpdate(command) => { + let record_count = self + .do_put_prepared_statement_update(command, request) + .await?; + let result = DoPutUpdateResult { record_count }; + let output = futures::stream::iter(vec![Ok(PutResult { + app_metadata: result.as_any().encode_to_vec().into(), + })]); + Ok(Response::new(Box::pin(output))) + } + cmd => Err(Status::invalid_argument(format!( + "do_put: The defined request is invalid: {}", + cmd.type_url() + ))), } - - Err(Status::invalid_argument(format!( - "do_put: The defined request is invalid: {}", - message.type_url - ))) } async fn list_actions(