From b5f6c83f2345bde03f5c2a15eb3d4f058c7af99a Mon Sep 17 00:00:00 2001 From: suremarc <8771538+suremarc@users.noreply.github.com> Date: Thu, 7 Sep 2023 12:36:05 -0500 Subject: [PATCH 1/7] change Streaming to Peekable> --- arrow-flight/src/sql/server.rs | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/arrow-flight/src/sql/server.rs b/arrow-flight/src/sql/server.rs index 102d97105a2e..fa04ab8b7cac 100644 --- a/arrow-flight/src/sql/server.rs +++ b/arrow-flight/src/sql/server.rs @@ -19,7 +19,7 @@ use std::pin::Pin; -use futures::Stream; +use futures::{stream::Peekable, Stream}; use prost::Message; use tonic::{Request, Response, Status, Streaming}; @@ -366,7 +366,7 @@ pub trait FlightSqlService: Sync + Send + Sized + 'static { /// Implementors may override to handle additional calls to do_put() async fn do_put_fallback( &self, - _request: Request>, + _request: Request>>, message: Any, ) -> Result::DoPutStream>, Status> { Err(Status::unimplemented(format!( @@ -379,7 +379,7 @@ pub trait FlightSqlService: Sync + Send + Sized + 'static { async fn do_put_statement_update( &self, _ticket: CommandStatementUpdate, - _request: Request>, + _request: Request>>, ) -> Result { Err(Status::unimplemented( "do_put_statement_update has no default implementation", @@ -390,7 +390,7 @@ pub trait FlightSqlService: Sync + Send + Sized + 'static { async fn do_put_prepared_statement_query( &self, _query: CommandPreparedStatementQuery, - _request: Request>, + _request: Request>>, ) -> Result::DoPutStream>, Status> { Err(Status::unimplemented( "do_put_prepared_statement_query has no default implementation", @@ -401,7 +401,7 @@ pub trait FlightSqlService: Sync + Send + Sized + 'static { async fn do_put_prepared_statement_update( &self, _query: CommandPreparedStatementUpdate, - _request: Request>, + _request: Request>>, ) -> Result { Err(Status::unimplemented( "do_put_prepared_statement_update has no default implementation", @@ -412,7 +412,7 @@ pub trait FlightSqlService: Sync + Send + Sized + 'static { async fn do_put_substrait_plan( &self, _query: CommandStatementSubstraitPlan, - _request: Request>, + _request: Request>>, ) -> Result { Err(Status::unimplemented( "do_put_substrait_plan has no default implementation", @@ -688,9 +688,10 @@ where async fn do_put( &self, - mut request: Request>, + request: Request>, ) -> Result, Status> { - let cmd = request.get_mut().message().await?.unwrap(); + let mut request = request.map(futures::StreamExt::peekable); + let cmd = Pin::new(request.get_mut()).peek().await.unwrap().clone()?; let message = Any::decode(&*cmd.flight_descriptor.unwrap().cmd) .map_err(decode_error_to_status)?; match Command::try_from(message).map_err(arrow_error_to_status)? { From e52373deef837bce15194cf4090ffafa71dab500 Mon Sep 17 00:00:00 2001 From: suremarc <8771538+suremarc@users.noreply.github.com> Date: Thu, 7 Sep 2023 12:50:27 -0500 Subject: [PATCH 2/7] add explanatory comment --- arrow-flight/src/sql/server.rs | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/arrow-flight/src/sql/server.rs b/arrow-flight/src/sql/server.rs index fa04ab8b7cac..9406eb220341 100644 --- a/arrow-flight/src/sql/server.rs +++ b/arrow-flight/src/sql/server.rs @@ -690,8 +690,15 @@ where &self, request: Request>, ) -> Result, Status> { + // See issue #4658: https://github.com/apache/arrow-rs/issues/4658 + // To dispatch to the correct `do_put` method, we cannot discard the first message, + // as it may contain the Arrow schema, which the `do_put` handler may need. + // To allow the first message to be reused by the `do_put` handler, + // we wrap this stream in a `Peekable` one, which allows us to peek at + // the first message without discarding it. let mut request = request.map(futures::StreamExt::peekable); let cmd = Pin::new(request.get_mut()).peek().await.unwrap().clone()?; + let message = Any::decode(&*cmd.flight_descriptor.unwrap().cmd) .map_err(decode_error_to_status)?; match Command::try_from(message).map_err(arrow_error_to_status)? { From 352555f4fab06bda3a1e146c2431b75de24546c9 Mon Sep 17 00:00:00 2001 From: suremarc <8771538+suremarc@users.noreply.github.com> Date: Thu, 7 Sep 2023 15:54:02 -0500 Subject: [PATCH 3/7] working test --- arrow-flight/examples/flight_sql_server.rs | 9 +- arrow-flight/src/bin/flight_sql_client.rs | 104 +++++++++-- arrow-flight/src/sql/client.rs | 54 +++++- arrow-flight/tests/flight_sql_client_cli.rs | 185 +++++++++++++++----- 4 files changed, 289 insertions(+), 63 deletions(-) diff --git a/arrow-flight/examples/flight_sql_server.rs b/arrow-flight/examples/flight_sql_server.rs index 1e99957390d8..4c01146976c2 100644 --- a/arrow-flight/examples/flight_sql_server.rs +++ b/arrow-flight/examples/flight_sql_server.rs @@ -17,6 +17,7 @@ use base64::prelude::BASE64_STANDARD; use base64::Engine; +use futures::stream::Peekable; use futures::{stream, Stream, TryStreamExt}; use once_cell::sync::Lazy; use prost::Message; @@ -602,7 +603,7 @@ impl FlightSqlService for FlightSqlServiceImpl { async fn do_put_statement_update( &self, _ticket: CommandStatementUpdate, - _request: Request>, + _request: Request>>, ) -> Result { Ok(FAKE_UPDATE_RESULT) } @@ -610,7 +611,7 @@ impl FlightSqlService for FlightSqlServiceImpl { async fn do_put_substrait_plan( &self, _ticket: CommandStatementSubstraitPlan, - _request: Request>, + _request: Request>>, ) -> Result { Err(Status::unimplemented( "do_put_substrait_plan not implemented", @@ -620,7 +621,7 @@ impl FlightSqlService for FlightSqlServiceImpl { async fn do_put_prepared_statement_query( &self, _query: CommandPreparedStatementQuery, - _request: Request>, + _request: Request>>, ) -> Result::DoPutStream>, Status> { Err(Status::unimplemented( "do_put_prepared_statement_query not implemented", @@ -630,7 +631,7 @@ impl FlightSqlService for FlightSqlServiceImpl { async fn do_put_prepared_statement_update( &self, _query: CommandPreparedStatementUpdate, - _request: Request>, + _request: Request>>, ) -> Result { Err(Status::unimplemented( "do_put_prepared_statement_update not implemented", diff --git a/arrow-flight/src/bin/flight_sql_client.rs b/arrow-flight/src/bin/flight_sql_client.rs index 20c8062f899e..d7b02414c5cc 100644 --- a/arrow-flight/src/bin/flight_sql_client.rs +++ b/arrow-flight/src/bin/flight_sql_client.rs @@ -15,15 +15,16 @@ // specific language governing permissions and limitations // under the License. -use std::{sync::Arc, time::Duration}; +use std::{error::Error, sync::Arc, time::Duration}; -use arrow_array::RecordBatch; -use arrow_cast::pretty::pretty_format_batches; +use arrow_array::{ArrayRef, Datum, RecordBatch, StringArray}; +use arrow_cast::{cast_with_options, pretty::pretty_format_batches, CastOptions}; use arrow_flight::{ sql::client::FlightSqlServiceClient, utils::flight_data_to_batches, FlightData, + FlightInfo, }; use arrow_schema::{ArrowError, Schema}; -use clap::Parser; +use clap::{Parser, Subcommand}; use futures::TryStreamExt; use tonic::transport::{Channel, ClientTlsConfig, Endpoint}; use tracing_log::log::info; @@ -98,8 +99,20 @@ struct Args { #[clap(flatten)] client_args: ClientArgs, - /// SQL query. - query: String, + #[clap(subcommand)] + cmd: Command, +} + +#[derive(Debug, Subcommand)] +enum Command { + StatementQuery { + query: String, + }, + PreparedStatementQuery { + query: String, + #[clap(short, value_parser = parse_key_val)] + params: Vec<(String, String)>, + }, } #[tokio::main] @@ -108,12 +121,50 @@ async fn main() { setup_logging(); let mut client = setup_client(args.client_args).await.expect("setup client"); - let info = client - .execute(args.query, None) + let flight_info = match args.cmd { + Command::StatementQuery { query } => client + .execute(query, None) + .await + .expect("execute statement"), + Command::PreparedStatementQuery { query, params } => { + let mut prepared_stmt = client + .prepare(query, None) + .await + .expect("prepare statement"); + + if !params.is_empty() { + prepared_stmt + .set_parameters( + construct_record_batch_from_params( + ¶ms, + prepared_stmt + .parameter_schema() + .expect("get parameter schema"), + ) + .expect("construct parameters"), + ) + .expect("bind parameters") + } + + prepared_stmt + .execute() + .await + .expect("execute prepared statement") + } + }; + + let batches = execute_flight(&mut client, flight_info) .await - .expect("prepare statement"); - info!("got flight info"); + .expect("read flight data"); + let res = pretty_format_batches(batches.as_slice()).expect("format results"); + println!("{res}"); +} + +async fn execute_flight( + client: &mut FlightSqlServiceClient, + info: FlightInfo, +) -> Result, ArrowError> { let schema = Arc::new(Schema::try_from(info.clone()).expect("valid schema")); let mut batches = Vec::with_capacity(info.endpoint.len() + 1); batches.push(RecordBatch::new_empty(schema)); @@ -134,8 +185,27 @@ async fn main() { } info!("received data"); - let res = pretty_format_batches(batches.as_slice()).expect("format results"); - println!("{res}"); + Ok(batches) +} + +fn construct_record_batch_from_params( + params: &[(String, String)], + parameter_schema: &Schema, +) -> Result { + let mut items = Vec::<(&String, ArrayRef)>::new(); + + for (name, value) in params { + let field = parameter_schema.field_with_name(name)?; + let value_as_array = StringArray::new_scalar(value); + let casted = cast_with_options( + value_as_array.get().0, + field.data_type(), + &CastOptions::default(), + )?; + items.push((name, casted)) + } + + RecordBatch::try_from_iter(items) } fn setup_logging() { @@ -203,3 +273,13 @@ async fn setup_client( Ok(client) } + +/// Parse a single key-value pair +fn parse_key_val( + s: &str, +) -> Result<(String, String), Box> { + let pos = s + .find('=') + .ok_or_else(|| format!("invalid KEY=value: no `=` found in `{s}`"))?; + Ok((s[..pos].parse()?, s[pos + 1..].parse()?)) +} diff --git a/arrow-flight/src/sql/client.rs b/arrow-flight/src/sql/client.rs index 4b1f38ebcbb7..d1dcfd6b7fcd 100644 --- a/arrow-flight/src/sql/client.rs +++ b/arrow-flight/src/sql/client.rs @@ -24,6 +24,8 @@ use std::collections::HashMap; use std::str::FromStr; use tonic::metadata::AsciiMetadataKey; +use crate::encode::FlightDataEncoderBuilder; +use crate::error::FlightError; use crate::flight_service_client::FlightServiceClient; use crate::sql::server::{CLOSE_PREPARED_STATEMENT, CREATE_PREPARED_STATEMENT}; use crate::sql::{ @@ -32,8 +34,8 @@ use crate::sql::{ CommandGetCrossReference, CommandGetDbSchemas, CommandGetExportedKeys, CommandGetImportedKeys, CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables, CommandGetXdbcTypeInfo, - CommandPreparedStatementQuery, CommandStatementQuery, CommandStatementUpdate, - DoPutUpdateResult, ProstMessageExt, SqlInfo, + CommandPreparedStatementQuery, CommandPreparedStatementUpdate, CommandStatementQuery, + CommandStatementUpdate, DoPutUpdateResult, ProstMessageExt, SqlInfo, }; use crate::{ Action, FlightData, FlightDescriptor, FlightInfo, HandshakeRequest, @@ -439,9 +441,12 @@ impl PreparedStatement { /// Executes the prepared statement query on the server. pub async fn execute(&mut self) -> Result { + self.write_bind_params().await?; + let cmd = CommandPreparedStatementQuery { prepared_statement_handle: self.handle.clone(), }; + let result = self .flight_sql_client .get_flight_info_for_command(cmd) @@ -451,7 +456,9 @@ impl PreparedStatement { /// Executes the prepared statement update query on the server. pub async fn execute_update(&mut self) -> Result { - let cmd = CommandPreparedStatementQuery { + self.write_bind_params().await?; + + let cmd = CommandPreparedStatementUpdate { prepared_statement_handle: self.handle.clone(), }; let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec()); @@ -492,6 +499,36 @@ impl PreparedStatement { Ok(()) } + /// Submit parameters to the server, if any have been set on this prepared statement instance + async fn write_bind_params(&mut self) -> Result<(), ArrowError> { + if let Some(ref params_batch) = self.parameter_binding { + let cmd = CommandPreparedStatementQuery { + prepared_statement_handle: self.handle.clone(), + }; + + let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec()); + let flight_stream_builder = FlightDataEncoderBuilder::new() + .with_flight_descriptor(Some(descriptor)) + .with_schema(params_batch.schema()); + let flight_data = flight_stream_builder + .build(futures::stream::iter( + self.parameter_binding.clone().map(Ok), + )) + .try_collect::>() + .await + .map_err(flight_error_to_arrow_error)?; + + self.flight_sql_client + .do_put(stream::iter(flight_data)) + .await? + .try_collect::>() + .await + .map_err(status_to_arrow_error)?; + } + + Ok(()) + } + /// Close the prepared statement, so that this PreparedStatement can not used /// anymore and server can free up any resources. pub async fn close(mut self) -> Result<(), ArrowError> { @@ -515,6 +552,17 @@ fn status_to_arrow_error(status: tonic::Status) -> ArrowError { ArrowError::IpcError(format!("{status:?}")) } +fn flight_error_to_arrow_error(err: FlightError) -> ArrowError { + match err { + FlightError::Arrow(e) => e, + FlightError::NotYetImplemented(s) => ArrowError::NotYetImplemented(s), + FlightError::Tonic(status) => status_to_arrow_error(status), + FlightError::ProtocolError(e) => ArrowError::IpcError(e), + FlightError::DecodeError(s) => ArrowError::IpcError(s), + FlightError::ExternalError(e) => ArrowError::ExternalError(e), + } +} + // A polymorphic structure to natively represent different types of data contained in `FlightData` pub enum ArrowFlightData { RecordBatch(RecordBatch), diff --git a/arrow-flight/tests/flight_sql_client_cli.rs b/arrow-flight/tests/flight_sql_client_cli.rs index 912bcc75a9df..36c744f1cf76 100644 --- a/arrow-flight/tests/flight_sql_client_cli.rs +++ b/arrow-flight/tests/flight_sql_client_cli.rs @@ -19,6 +19,7 @@ use std::{net::SocketAddr, pin::Pin, sync::Arc, time::Duration}; use arrow_array::{ArrayRef, Int64Array, RecordBatch, StringArray}; use arrow_flight::{ + decode::FlightRecordBatchStream, flight_service_server::{FlightService, FlightServiceServer}, sql::{ server::FlightSqlService, ActionBeginSavepointRequest, @@ -36,11 +37,13 @@ use arrow_flight::{ }, utils::batches_to_flight_data, Action, FlightData, FlightDescriptor, FlightEndpoint, FlightInfo, HandshakeRequest, - HandshakeResponse, Ticket, + HandshakeResponse, IpcMessage, PutResult, SchemaAsIpc, Ticket, }; +use arrow_ipc::writer::IpcWriteOptions; use arrow_schema::{ArrowError, DataType, Field, Schema}; use assert_cmd::Command; -use futures::Stream; +use bytes::Bytes; +use futures::{stream::Peekable, Stream, StreamExt, TryStreamExt}; use prost::Message; use tokio::{net::TcpListener, task::JoinHandle}; use tonic::{Request, Response, Status, Streaming}; @@ -63,6 +66,7 @@ async fn test_simple() { .arg(addr.ip().to_string()) .arg("--port") .arg(addr.port().to_string()) + .arg("statement-query") .arg(QUERY) .assert() .success() @@ -87,10 +91,56 @@ async fn test_simple() { ); } +const PREPARED_QUERY: &str = "SELECT * FROM table WHERE field = $1"; +const PREPARED_STATEMENT_HANDLE: &str = "prepared_statement_handle"; + +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +async fn test_do_put_prepared_statement() { + let test_server = FlightSqlServiceImpl {}; + let fixture = TestFixture::new(&test_server).await; + let addr = fixture.addr; + + let stdout = tokio::task::spawn_blocking(move || { + Command::cargo_bin("flight_sql_client") + .unwrap() + .env_clear() + .env("RUST_BACKTRACE", "1") + .env("RUST_LOG", "warn") + .arg("--host") + .arg(addr.ip().to_string()) + .arg("--port") + .arg(addr.port().to_string()) + .arg("prepared-statement-query") + .arg(PREPARED_QUERY) + .args(["-p", "$1=string"]) + .args(["-p", "$2=64"]) + .assert() + .success() + .get_output() + .stdout + .clone() + }) + .await + .unwrap(); + + fixture.shutdown_and_wait().await; + + assert_eq!( + std::str::from_utf8(&stdout).unwrap().trim(), + "+--------------+-----------+\ + \n| field_string | field_int |\ + \n+--------------+-----------+\ + \n| Hello | 42 |\ + \n| lovely | |\ + \n| FlightSQL! | 1337 |\ + \n+--------------+-----------+", + ); +} + /// All tests must complete within this many seconds or else the test server is shutdown const DEFAULT_TIMEOUT_SECONDS: u64 = 30; -#[derive(Clone)] +#[derive(Clone, Default)] pub struct FlightSqlServiceImpl {} impl FlightSqlServiceImpl { @@ -116,6 +166,59 @@ impl FlightSqlServiceImpl { ]; RecordBatch::try_new(Arc::new(schema), cols) } + + fn create_fake_prepared_stmt( + ) -> Result { + let handle = PREPARED_STATEMENT_HANDLE.to_string(); + let schema = Schema::new(vec![ + Field::new("field_string", DataType::Utf8, false), + Field::new("field_int", DataType::Int64, true), + ]); + + let parameter_schema = Schema::new(vec![ + Field::new("$1", DataType::Utf8, false), + Field::new("$2", DataType::Int64, true), + ]); + + Ok(ActionCreatePreparedStatementResult { + prepared_statement_handle: handle.into(), + dataset_schema: serialize_schema(&schema)?, + parameter_schema: serialize_schema(¶meter_schema)?, + }) + } + + fn fake_flight_info(&self) -> Result { + let batch = Self::fake_result()?; + + Ok(FlightInfo::new() + .try_with_schema(&batch.schema()) + .expect("encoding schema") + .with_endpoint( + FlightEndpoint::new().with_ticket(Ticket::new( + FetchResults { + handle: String::from("part_1"), + } + .as_any() + .encode_to_vec(), + )), + ) + .with_endpoint( + FlightEndpoint::new().with_ticket(Ticket::new( + FetchResults { + handle: String::from("part_2"), + } + .as_any() + .encode_to_vec(), + )), + ) + .with_total_records(batch.num_rows() as i64) + .with_total_bytes(batch.get_array_memory_size() as i64) + .with_ordered(false)) + } +} + +fn serialize_schema(schema: &Schema) -> Result { + Ok(IpcMessage::try_from(SchemaAsIpc::new(schema, &IpcWriteOptions::default()))?.0) } #[tonic::async_trait] @@ -164,45 +267,21 @@ impl FlightSqlService for FlightSqlServiceImpl { ) -> Result, Status> { assert_eq!(query.query, QUERY); - let batch = Self::fake_result().unwrap(); - - let info = FlightInfo::new() - .try_with_schema(&batch.schema()) - .expect("encoding schema") - .with_endpoint( - FlightEndpoint::new().with_ticket(Ticket::new( - FetchResults { - handle: String::from("part_1"), - } - .as_any() - .encode_to_vec(), - )), - ) - .with_endpoint( - FlightEndpoint::new().with_ticket(Ticket::new( - FetchResults { - handle: String::from("part_2"), - } - .as_any() - .encode_to_vec(), - )), - ) - .with_total_records(batch.num_rows() as i64) - .with_total_bytes(batch.get_array_memory_size() as i64) - .with_ordered(false); - - let resp = Response::new(info); + let resp = Response::new(self.fake_flight_info().unwrap()); Ok(resp) } async fn get_flight_info_prepared_statement( &self, - _cmd: CommandPreparedStatementQuery, + cmd: CommandPreparedStatementQuery, _request: Request, ) -> Result, Status> { - Err(Status::unimplemented( - "get_flight_info_prepared_statement not implemented", - )) + assert_eq!( + cmd.prepared_statement_handle, + PREPARED_STATEMENT_HANDLE.as_bytes() + ); + let resp = Response::new(self.fake_flight_info().unwrap()); + Ok(resp) } async fn get_flight_info_substrait_plan( @@ -426,7 +505,7 @@ impl FlightSqlService for FlightSqlServiceImpl { async fn do_put_statement_update( &self, _ticket: CommandStatementUpdate, - _request: Request>, + _request: Request>>, ) -> Result { Err(Status::unimplemented( "do_put_statement_update not implemented", @@ -436,7 +515,7 @@ impl FlightSqlService for FlightSqlServiceImpl { async fn do_put_substrait_plan( &self, _ticket: CommandStatementSubstraitPlan, - _request: Request>, + _request: Request>>, ) -> Result { Err(Status::unimplemented( "do_put_substrait_plan not implemented", @@ -446,17 +525,36 @@ impl FlightSqlService for FlightSqlServiceImpl { async fn do_put_prepared_statement_query( &self, _query: CommandPreparedStatementQuery, - _request: Request>, + request: Request>>, ) -> Result::DoPutStream>, Status> { - Err(Status::unimplemented( - "do_put_prepared_statement_query not implemented", + // just make sure decoding the parameters works + let parameters = FlightRecordBatchStream::new_from_flight_data( + request.into_inner().map_err(|e| e.into()), + ) + .try_collect::>() + .await?; + + for (left, right) in parameters[0].schema().all_fields().iter().zip(vec![ + Field::new("$1", DataType::Utf8, false), + Field::new("$2", DataType::Int64, true), + ]) { + if left.name() != right.name() || left.data_type() != right.data_type() { + return Err(Status::invalid_argument(format!( + "Parameters did not match parameter schema\ngot {}", + parameters[0].schema(), + ))); + } + } + + Ok(Response::new( + futures::stream::once(async { Ok(PutResult::default()) }).boxed(), )) } async fn do_put_prepared_statement_update( &self, _query: CommandPreparedStatementUpdate, - _request: Request>, + _request: Request>>, ) -> Result { Err(Status::unimplemented( "do_put_prepared_statement_update not implemented", @@ -468,9 +566,8 @@ impl FlightSqlService for FlightSqlServiceImpl { _query: ActionCreatePreparedStatementRequest, _request: Request, ) -> Result { - Err(Status::unimplemented( - "do_action_create_prepared_statement not implemented", - )) + Self::create_fake_prepared_stmt() + .map_err(|e| Status::internal(format!("Unable to serialize schema: {e}"))) } async fn do_action_close_prepared_statement( From c82dd1c028c54fb1f321bfda818b79f6fc2af41f Mon Sep 17 00:00:00 2001 From: suremarc <8771538+suremarc@users.noreply.github.com> Date: Thu, 7 Sep 2023 16:01:05 -0500 Subject: [PATCH 4/7] trigger pre-commit hooks? From 7f06bc5ca977bcab87be6298e8f1083b50bf93d9 Mon Sep 17 00:00:00 2001 From: Matthew Cramerus <8771538+suremarc@users.noreply.github.com> Date: Fri, 8 Sep 2023 09:28:37 -0500 Subject: [PATCH 5/7] Update arrow-flight/src/sql/client.rs Co-authored-by: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> --- arrow-flight/src/sql/client.rs | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/arrow-flight/src/sql/client.rs b/arrow-flight/src/sql/client.rs index d1dcfd6b7fcd..b691d305e67c 100644 --- a/arrow-flight/src/sql/client.rs +++ b/arrow-flight/src/sql/client.rs @@ -555,11 +555,7 @@ fn status_to_arrow_error(status: tonic::Status) -> ArrowError { fn flight_error_to_arrow_error(err: FlightError) -> ArrowError { match err { FlightError::Arrow(e) => e, - FlightError::NotYetImplemented(s) => ArrowError::NotYetImplemented(s), - FlightError::Tonic(status) => status_to_arrow_error(status), - FlightError::ProtocolError(e) => ArrowError::IpcError(e), - FlightError::DecodeError(s) => ArrowError::IpcError(s), - FlightError::ExternalError(e) => ArrowError::ExternalError(e), + e => ArrowError::ExternalError(Box::new(e)) } } From 80855ecfe200b2496e38f4f4cfb0fabcdd136632 Mon Sep 17 00:00:00 2001 From: suremarc <8771538+suremarc@users.noreply.github.com> Date: Fri, 8 Sep 2023 09:41:03 -0500 Subject: [PATCH 6/7] remove unnecessary multi-thread annotation --- arrow-flight/tests/flight_sql_client_cli.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/arrow-flight/tests/flight_sql_client_cli.rs b/arrow-flight/tests/flight_sql_client_cli.rs index 36c744f1cf76..c92e15af1711 100644 --- a/arrow-flight/tests/flight_sql_client_cli.rs +++ b/arrow-flight/tests/flight_sql_client_cli.rs @@ -50,7 +50,7 @@ use tonic::{Request, Response, Status, Streaming}; const QUERY: &str = "SELECT * FROM table;"; -#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +#[tokio::test] async fn test_simple() { let test_server = FlightSqlServiceImpl {}; let fixture = TestFixture::new(&test_server).await; @@ -94,7 +94,7 @@ async fn test_simple() { const PREPARED_QUERY: &str = "SELECT * FROM table WHERE field = $1"; const PREPARED_STATEMENT_HANDLE: &str = "prepared_statement_handle"; -#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +#[tokio::test] async fn test_do_put_prepared_statement() { let test_server = FlightSqlServiceImpl {}; let fixture = TestFixture::new(&test_server).await; From 6cfc5c4a9d2230cae7dbe3b85e088009c87c55aa Mon Sep 17 00:00:00 2001 From: suremarc <8771538+suremarc@users.noreply.github.com> Date: Fri, 15 Sep 2023 14:49:44 -0500 Subject: [PATCH 7/7] rework api --- arrow-flight/examples/flight_sql_server.rs | 10 +- arrow-flight/src/sql/client.rs | 2 +- arrow-flight/src/sql/server.rs | 100 ++++++++++++++++++-- arrow-flight/tests/flight_sql_client_cli.rs | 17 ++-- 4 files changed, 108 insertions(+), 21 deletions(-) diff --git a/arrow-flight/examples/flight_sql_server.rs b/arrow-flight/examples/flight_sql_server.rs index 4c01146976c2..d1aeae6f0a6c 100644 --- a/arrow-flight/examples/flight_sql_server.rs +++ b/arrow-flight/examples/flight_sql_server.rs @@ -15,9 +15,9 @@ // specific language governing permissions and limitations // under the License. +use arrow_flight::sql::server::PeekableFlightDataStream; use base64::prelude::BASE64_STANDARD; use base64::Engine; -use futures::stream::Peekable; use futures::{stream, Stream, TryStreamExt}; use once_cell::sync::Lazy; use prost::Message; @@ -603,7 +603,7 @@ impl FlightSqlService for FlightSqlServiceImpl { async fn do_put_statement_update( &self, _ticket: CommandStatementUpdate, - _request: Request>>, + _request: Request, ) -> Result { Ok(FAKE_UPDATE_RESULT) } @@ -611,7 +611,7 @@ impl FlightSqlService for FlightSqlServiceImpl { async fn do_put_substrait_plan( &self, _ticket: CommandStatementSubstraitPlan, - _request: Request>>, + _request: Request, ) -> Result { Err(Status::unimplemented( "do_put_substrait_plan not implemented", @@ -621,7 +621,7 @@ impl FlightSqlService for FlightSqlServiceImpl { async fn do_put_prepared_statement_query( &self, _query: CommandPreparedStatementQuery, - _request: Request>>, + _request: Request, ) -> Result::DoPutStream>, Status> { Err(Status::unimplemented( "do_put_prepared_statement_query not implemented", @@ -631,7 +631,7 @@ impl FlightSqlService for FlightSqlServiceImpl { async fn do_put_prepared_statement_update( &self, _query: CommandPreparedStatementUpdate, - _request: Request>>, + _request: Request, ) -> Result { Err(Status::unimplemented( "do_put_prepared_statement_update not implemented", diff --git a/arrow-flight/src/sql/client.rs b/arrow-flight/src/sql/client.rs index b691d305e67c..2d382cf2ca20 100644 --- a/arrow-flight/src/sql/client.rs +++ b/arrow-flight/src/sql/client.rs @@ -555,7 +555,7 @@ fn status_to_arrow_error(status: tonic::Status) -> ArrowError { fn flight_error_to_arrow_error(err: FlightError) -> ArrowError { match err { FlightError::Arrow(e) => e, - e => ArrowError::ExternalError(Box::new(e)) + e => ArrowError::ExternalError(Box::new(e)), } } diff --git a/arrow-flight/src/sql/server.rs b/arrow-flight/src/sql/server.rs index 9406eb220341..a158ed77f54d 100644 --- a/arrow-flight/src/sql/server.rs +++ b/arrow-flight/src/sql/server.rs @@ -19,7 +19,7 @@ use std::pin::Pin; -use futures::{stream::Peekable, Stream}; +use futures::{stream::Peekable, Stream, StreamExt}; use prost::Message; use tonic::{Request, Response, Status, Streaming}; @@ -366,7 +366,7 @@ pub trait FlightSqlService: Sync + Send + Sized + 'static { /// Implementors may override to handle additional calls to do_put() async fn do_put_fallback( &self, - _request: Request>>, + _request: Request, message: Any, ) -> Result::DoPutStream>, Status> { Err(Status::unimplemented(format!( @@ -379,7 +379,7 @@ pub trait FlightSqlService: Sync + Send + Sized + 'static { async fn do_put_statement_update( &self, _ticket: CommandStatementUpdate, - _request: Request>>, + _request: Request, ) -> Result { Err(Status::unimplemented( "do_put_statement_update has no default implementation", @@ -390,7 +390,7 @@ pub trait FlightSqlService: Sync + Send + Sized + 'static { async fn do_put_prepared_statement_query( &self, _query: CommandPreparedStatementQuery, - _request: Request>>, + _request: Request, ) -> Result::DoPutStream>, Status> { Err(Status::unimplemented( "do_put_prepared_statement_query has no default implementation", @@ -401,7 +401,7 @@ pub trait FlightSqlService: Sync + Send + Sized + 'static { async fn do_put_prepared_statement_update( &self, _query: CommandPreparedStatementUpdate, - _request: Request>>, + _request: Request, ) -> Result { Err(Status::unimplemented( "do_put_prepared_statement_update has no default implementation", @@ -412,7 +412,7 @@ pub trait FlightSqlService: Sync + Send + Sized + 'static { async fn do_put_substrait_plan( &self, _query: CommandStatementSubstraitPlan, - _request: Request>>, + _request: Request, ) -> Result { Err(Status::unimplemented( "do_put_substrait_plan has no default implementation", @@ -696,7 +696,7 @@ where // To allow the first message to be reused by the `do_put` handler, // we wrap this stream in a `Peekable` one, which allows us to peek at // the first message without discarding it. - let mut request = request.map(futures::StreamExt::peekable); + let mut request = request.map(PeekableFlightDataStream::new); let cmd = Pin::new(request.get_mut()).peek().await.unwrap().clone()?; let message = Any::decode(&*cmd.flight_descriptor.unwrap().cmd) @@ -965,3 +965,89 @@ fn decode_error_to_status(err: prost::DecodeError) -> Status { fn arrow_error_to_status(err: arrow_schema::ArrowError) -> Status { Status::internal(format!("{err:?}")) } + +/// A wrapper around [`Streaming`] that allows "peeking" at the +/// message at the front of the stream without consuming it. +/// This is needed because sometimes the first message in the stream will contain +/// a [`FlightDescriptor`] in addition to potentially any data, and the dispatch logic +/// must inspect this information. +/// +/// # Example +/// +/// [`PeekableFlightDataStream::peek`] can be used to peek at the first message without +/// discarding it; otherwise, `PeekableFlightDataStream` can be used as a regular stream. +/// See the following example: +/// +/// ```no_run +/// use arrow_array::RecordBatch; +/// use arrow_flight::decode::FlightRecordBatchStream; +/// use arrow_flight::FlightDescriptor; +/// use arrow_flight::error::FlightError; +/// use arrow_flight::sql::server::PeekableFlightDataStream; +/// use tonic::{Request, Status}; +/// use futures::TryStreamExt; +/// +/// #[tokio::main] +/// async fn main() -> Result<(), Status> { +/// let request: Request = todo!(); +/// let stream: PeekableFlightDataStream = request.into_inner(); +/// +/// // The first message contains the flight descriptor and the schema. +/// // Read the flight descriptor without discarding the schema: +/// let flight_descriptor: FlightDescriptor = stream +/// .peek() +/// .await +/// .cloned() +/// .transpose()? +/// .and_then(|data| data.flight_descriptor) +/// .expect("first message should contain flight descriptor"); +/// +/// // Pass the stream through a decoder +/// let batches: Vec = FlightRecordBatchStream::new_from_flight_data( +/// request.into_inner().map_err(|e| e.into()), +/// ) +/// .try_collect() +/// .await?; +/// } +/// ``` +pub struct PeekableFlightDataStream { + inner: Peekable>, +} + +impl PeekableFlightDataStream { + fn new(stream: Streaming) -> Self { + Self { + inner: stream.peekable(), + } + } + + /// Convert this stream into a `Streaming`. + /// Any messages observed through [`Self::peek`] will be lost + /// after the conversion. + pub fn into_inner(self) -> Streaming { + self.inner.into_inner() + } + + /// Convert this stream into a `Peekable>`. + /// Preserves the state of the stream, so that calls to [`Self::peek`] + /// and [`Self::poll_next`] are the same. + pub fn into_peekable(self) -> Peekable> { + self.inner + } + + /// Peek at the head of this stream without advancing it. + pub async fn peek(&mut self) -> Option<&Result> { + Pin::new(&mut self.inner).peek().await + } +} + +impl Stream for PeekableFlightDataStream { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.inner.poll_next_unpin(cx) + } +} diff --git a/arrow-flight/tests/flight_sql_client_cli.rs b/arrow-flight/tests/flight_sql_client_cli.rs index c92e15af1711..221e776218c3 100644 --- a/arrow-flight/tests/flight_sql_client_cli.rs +++ b/arrow-flight/tests/flight_sql_client_cli.rs @@ -22,9 +22,10 @@ use arrow_flight::{ decode::FlightRecordBatchStream, flight_service_server::{FlightService, FlightServiceServer}, sql::{ - server::FlightSqlService, ActionBeginSavepointRequest, - ActionBeginSavepointResult, ActionBeginTransactionRequest, - ActionBeginTransactionResult, ActionCancelQueryRequest, ActionCancelQueryResult, + server::{FlightSqlService, PeekableFlightDataStream}, + ActionBeginSavepointRequest, ActionBeginSavepointResult, + ActionBeginTransactionRequest, ActionBeginTransactionResult, + ActionCancelQueryRequest, ActionCancelQueryResult, ActionClosePreparedStatementRequest, ActionCreatePreparedStatementRequest, ActionCreatePreparedStatementResult, ActionCreatePreparedSubstraitPlanRequest, ActionEndSavepointRequest, ActionEndTransactionRequest, Any, CommandGetCatalogs, @@ -43,7 +44,7 @@ use arrow_ipc::writer::IpcWriteOptions; use arrow_schema::{ArrowError, DataType, Field, Schema}; use assert_cmd::Command; use bytes::Bytes; -use futures::{stream::Peekable, Stream, StreamExt, TryStreamExt}; +use futures::{Stream, StreamExt, TryStreamExt}; use prost::Message; use tokio::{net::TcpListener, task::JoinHandle}; use tonic::{Request, Response, Status, Streaming}; @@ -505,7 +506,7 @@ impl FlightSqlService for FlightSqlServiceImpl { async fn do_put_statement_update( &self, _ticket: CommandStatementUpdate, - _request: Request>>, + _request: Request, ) -> Result { Err(Status::unimplemented( "do_put_statement_update not implemented", @@ -515,7 +516,7 @@ impl FlightSqlService for FlightSqlServiceImpl { async fn do_put_substrait_plan( &self, _ticket: CommandStatementSubstraitPlan, - _request: Request>>, + _request: Request, ) -> Result { Err(Status::unimplemented( "do_put_substrait_plan not implemented", @@ -525,7 +526,7 @@ impl FlightSqlService for FlightSqlServiceImpl { async fn do_put_prepared_statement_query( &self, _query: CommandPreparedStatementQuery, - request: Request>>, + request: Request, ) -> Result::DoPutStream>, Status> { // just make sure decoding the parameters works let parameters = FlightRecordBatchStream::new_from_flight_data( @@ -554,7 +555,7 @@ impl FlightSqlService for FlightSqlServiceImpl { async fn do_put_prepared_statement_update( &self, _query: CommandPreparedStatementUpdate, - _request: Request>>, + _request: Request, ) -> Result { Err(Status::unimplemented( "do_put_prepared_statement_update not implemented",