Fix flight sql do put handling, add bind parameter support to FlightSQL cli client #4797

Merged
merged 7 commits into from
Sep 18, 2023
9 changes: 5 additions & 4 deletions arrow-flight/examples/flight_sql_server.rs
@@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use arrow_flight::sql::server::PeekableFlightDataStream;
use base64::prelude::BASE64_STANDARD;
use base64::Engine;
use futures::{stream, Stream, TryStreamExt};
@@ -602,15 +603,15 @@ impl FlightSqlService for FlightSqlServiceImpl {
async fn do_put_statement_update(
&self,
_ticket: CommandStatementUpdate,
_request: Request<Streaming<FlightData>>,
_request: Request<PeekableFlightDataStream>,
) -> Result<i64, Status> {
Ok(FAKE_UPDATE_RESULT)
}
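
The signature change above swaps `Streaming<FlightData>` for `PeekableFlightDataStream`. The sketch below assumes the motivation is that a `do_put` dispatcher must look at the first message to learn which command it carries, while the handler still needs to receive that same message. It illustrates the peek-then-forward pattern with `std::iter::Peekable` only; `Message`, `dispatch`, and the command strings are hypothetical stand-ins, not the `arrow-flight` API.

```rust
use std::iter::Peekable;

/// Hypothetical stand-in for a `FlightData` message: a command tag plus a payload.
#[derive(Debug, Clone, PartialEq)]
pub struct Message {
    pub command: String,
    pub payload: Vec<u8>,
}

/// Route a stream by peeking at its first message without consuming it,
/// then hand the whole stream (first message included) to the handler.
pub fn dispatch<I>(mut stream: Peekable<I>) -> Result<usize, String>
where
    I: Iterator<Item = Message>,
{
    // `peek` only borrows the first item; the iterator position is unchanged.
    let command = match stream.peek() {
        Some(first) => first.command.clone(),
        None => return Err("empty stream".to_string()),
    };

    match command.as_str() {
        // The toy handler counts every message, including the one peeked at.
        "update" => Ok(stream.count()),
        other => Err(format!("unsupported command: {other}")),
    }
}
```

Had the dispatcher called `next()` instead of `peek()`, the handler in this model would see one message fewer than the client sent.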

async fn do_put_substrait_plan(
&self,
_ticket: CommandStatementSubstraitPlan,
_request: Request<Streaming<FlightData>>,
_request: Request<PeekableFlightDataStream>,
) -> Result<i64, Status> {
Err(Status::unimplemented(
"do_put_substrait_plan not implemented",
@@ -620,7 +621,7 @@ impl FlightSqlService for FlightSqlServiceImpl {
async fn do_put_prepared_statement_query(
&self,
_query: CommandPreparedStatementQuery,
_request: Request<Streaming<FlightData>>,
_request: Request<PeekableFlightDataStream>,
) -> Result<Response<<Self as FlightService>::DoPutStream>, Status> {
Err(Status::unimplemented(
"do_put_prepared_statement_query not implemented",
@@ -630,7 +631,7 @@ impl FlightSqlService for FlightSqlServiceImpl {
async fn do_put_prepared_statement_update(
&self,
_query: CommandPreparedStatementUpdate,
_request: Request<Streaming<FlightData>>,
_request: Request<PeekableFlightDataStream>,
) -> Result<i64, Status> {
Err(Status::unimplemented(
"do_put_prepared_statement_update not implemented",
104 changes: 92 additions & 12 deletions arrow-flight/src/bin/flight_sql_client.rs
@@ -15,15 +15,16 @@
// specific language governing permissions and limitations
// under the License.

use std::{sync::Arc, time::Duration};
use std::{error::Error, sync::Arc, time::Duration};

use arrow_array::RecordBatch;
use arrow_cast::pretty::pretty_format_batches;
use arrow_array::{ArrayRef, Datum, RecordBatch, StringArray};
use arrow_cast::{cast_with_options, pretty::pretty_format_batches, CastOptions};
use arrow_flight::{
sql::client::FlightSqlServiceClient, utils::flight_data_to_batches, FlightData,
FlightInfo,
};
use arrow_schema::{ArrowError, Schema};
use clap::Parser;
use clap::{Parser, Subcommand};
use futures::TryStreamExt;
use tonic::transport::{Channel, ClientTlsConfig, Endpoint};
use tracing_log::log::info;
@@ -98,8 +99,20 @@ struct Args {
#[clap(flatten)]
client_args: ClientArgs,

/// SQL query.
query: String,
#[clap(subcommand)]
cmd: Command,
}

#[derive(Debug, Subcommand)]
enum Command {
StatementQuery {
query: String,
},
PreparedStatementQuery {
query: String,
#[clap(short, value_parser = parse_key_val)]
params: Vec<(String, String)>,
},
}

#[tokio::main]
@@ -108,12 +121,50 @@ async fn main() {
setup_logging();
let mut client = setup_client(args.client_args).await.expect("setup client");

let info = client
.execute(args.query, None)
let flight_info = match args.cmd {
Command::StatementQuery { query } => client
.execute(query, None)
.await
.expect("execute statement"),
Command::PreparedStatementQuery { query, params } => {
let mut prepared_stmt = client
.prepare(query, None)
.await
.expect("prepare statement");

if !params.is_empty() {
prepared_stmt
.set_parameters(
construct_record_batch_from_params(
&params,
prepared_stmt
.parameter_schema()
.expect("get parameter schema"),
)
.expect("construct parameters"),
)
.expect("bind parameters")
}

prepared_stmt
.execute()
.await
.expect("execute prepared statement")
}
};

let batches = execute_flight(&mut client, flight_info)
.await
.expect("prepare statement");
info!("got flight info");
.expect("read flight data");

let res = pretty_format_batches(batches.as_slice()).expect("format results");
println!("{res}");
}

async fn execute_flight(
client: &mut FlightSqlServiceClient<Channel>,
info: FlightInfo,
) -> Result<Vec<RecordBatch>, ArrowError> {
let schema = Arc::new(Schema::try_from(info.clone()).expect("valid schema"));
let mut batches = Vec::with_capacity(info.endpoint.len() + 1);
batches.push(RecordBatch::new_empty(schema));
@@ -134,8 +185,27 @@ async fn main() {
}
info!("received data");

let res = pretty_format_batches(batches.as_slice()).expect("format results");
println!("{res}");
Ok(batches)
}

fn construct_record_batch_from_params(
params: &[(String, String)],
parameter_schema: &Schema,
) -> Result<RecordBatch, ArrowError> {
let mut items = Vec::<(&String, ArrayRef)>::new();

for (name, value) in params {
let field = parameter_schema.field_with_name(name)?;
let value_as_array = StringArray::new_scalar(value);
let casted = cast_with_options(
value_as_array.get().0,
field.data_type(),
&CastOptions::default(),
)?;
items.push((name, casted))
}

RecordBatch::try_from_iter(items)
}
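
The function above looks up each named parameter in the parameter schema and casts its string value to that field's data type. A dependency-free sketch of the same idea follows; `ParamType`, `ParamValue`, and `cast_params` are hypothetical names for illustration and do not use Arrow's `cast_with_options`.

```rust
use std::collections::HashMap;

/// Hypothetical, simplified stand-in for an Arrow data type.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ParamType {
    Int64,
    Float64,
    Utf8,
}

/// Hypothetical typed value produced by casting a string parameter.
#[derive(Debug, Clone, PartialEq)]
pub enum ParamValue {
    Int64(i64),
    Float64(f64),
    Utf8(String),
}

/// Cast each `(name, value)` pair to the type the schema declares for `name`.
/// Fails on an unknown parameter name or a value that does not parse.
pub fn cast_params(
    params: &[(String, String)],
    schema: &HashMap<String, ParamType>,
) -> Result<Vec<(String, ParamValue)>, String> {
    let mut items = Vec::with_capacity(params.len());
    for (name, value) in params {
        // Mirrors `field_with_name`: an unknown name is an error.
        let ty = schema
            .get(name)
            .ok_or_else(|| format!("unknown parameter `{name}`"))?;
        let casted = match ty {
            ParamType::Int64 => ParamValue::Int64(
                value.parse::<i64>().map_err(|e| format!("`{name}`: {e}"))?,
            ),
            ParamType::Float64 => ParamValue::Float64(
                value.parse::<f64>().map_err(|e| format!("`{name}`: {e}"))?,
            ),
            ParamType::Utf8 => ParamValue::Utf8(value.clone()),
        };
        items.push((name.clone(), casted));
    }
    Ok(items)
}
```

As in the original, the first failing parameter aborts the whole conversion instead of producing a partially bound batch.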

fn setup_logging() {
@@ -203,3 +273,13 @@ async fn setup_client(

Ok(client)
}

/// Parse a single key-value pair
fn parse_key_val(
s: &str,
) -> Result<(String, String), Box<dyn Error + Send + Sync + 'static>> {
let pos = s
.find('=')
.ok_or_else(|| format!("invalid KEY=value: no `=` found in `{s}`"))?;
Ok((s[..pos].parse()?, s[pos + 1..].parse()?))
}
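
To make the splitting behaviour above concrete: the parser splits at the first `=`, so a value may itself contain `=`, and an input with no `=` is rejected. The version below is a simplified copy that returns a plain `String` error, an assumption made only to keep the sketch dependency-free and easy to compare in assertions.

```rust
/// Split `KEY=value` at the first `=`.
/// Simplified from the CLI helper: the error type is a plain `String` here.
pub fn parse_key_val(s: &str) -> Result<(String, String), String> {
    let pos = s
        .find('=')
        .ok_or_else(|| format!("invalid KEY=value: no `=` found in `{s}`"))?;
    // Everything before the first `=` is the key; everything after is the value.
    Ok((s[..pos].to_string(), s[pos + 1..].to_string()))
}
```

Given the `#[clap(short, ...)]` attribute on `params`, each pair would presumably be passed on the command line as a repeated `-p KEY=value` flag.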
50 changes: 47 additions & 3 deletions arrow-flight/src/sql/client.rs
@@ -24,6 +24,8 @@ use std::collections::HashMap;
use std::str::FromStr;
use tonic::metadata::AsciiMetadataKey;

use crate::encode::FlightDataEncoderBuilder;
use crate::error::FlightError;
use crate::flight_service_client::FlightServiceClient;
use crate::sql::server::{CLOSE_PREPARED_STATEMENT, CREATE_PREPARED_STATEMENT};
use crate::sql::{
@@ -32,8 +34,8 @@ use crate::sql::{
CommandGetCrossReference, CommandGetDbSchemas, CommandGetExportedKeys,
CommandGetImportedKeys, CommandGetPrimaryKeys, CommandGetSqlInfo,
CommandGetTableTypes, CommandGetTables, CommandGetXdbcTypeInfo,
CommandPreparedStatementQuery, CommandStatementQuery, CommandStatementUpdate,
DoPutUpdateResult, ProstMessageExt, SqlInfo,
CommandPreparedStatementQuery, CommandPreparedStatementUpdate, CommandStatementQuery,
CommandStatementUpdate, DoPutUpdateResult, ProstMessageExt, SqlInfo,
};
use crate::{
Action, FlightData, FlightDescriptor, FlightInfo, HandshakeRequest,
@@ -439,9 +441,12 @@ impl PreparedStatement<Channel> {

/// Executes the prepared statement query on the server.
pub async fn execute(&mut self) -> Result<FlightInfo, ArrowError> {
self.write_bind_params().await?;

let cmd = CommandPreparedStatementQuery {
prepared_statement_handle: self.handle.clone(),
};

let result = self
.flight_sql_client
.get_flight_info_for_command(cmd)
@@ -451,7 +456,9 @@ impl PreparedStatement<Channel> {

/// Executes the prepared statement update query on the server.
pub async fn execute_update(&mut self) -> Result<i64, ArrowError> {
let cmd = CommandPreparedStatementQuery {
self.write_bind_params().await?;

let cmd = CommandPreparedStatementUpdate {
Comment on lines -454 to +461

Contributor Author:

I also forgot to mention, I think this was a bug in the existing implementation. ExecuteUpdate should be performed with a CommandPreparedStatementUpdate command, not a CommandPreparedStatementQuery.

prepared_statement_handle: self.handle.clone(),
};
let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec());
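
Why the command type in `execute_update` matters: the server trait shown earlier in this PR has separate `do_put_prepared_statement_query` and `do_put_prepared_statement_update` handlers, so presumably the command in the descriptor decides which one runs. A toy model of that routing follows, with hypothetical `Command`, `Outcome`, and `route` names in place of the real `arrow-flight` types.

```rust
/// Hypothetical stand-in for the command decoded from a flight descriptor.
#[derive(Debug, Clone, PartialEq)]
pub enum Command {
    PreparedStatementQuery { handle: Vec<u8> },
    PreparedStatementUpdate { handle: Vec<u8> },
}

/// What the toy server does with a `do_put` carrying each command.
#[derive(Debug, PartialEq)]
pub enum Outcome {
    /// Parameters were bound to the statement; nothing was executed.
    ParametersBound,
    /// The update ran and affected this many rows.
    RowsAffected(i64),
}

/// Route a `do_put` by command type, as a server-side dispatcher might.
pub fn route(cmd: &Command) -> Outcome {
    match cmd {
        Command::PreparedStatementQuery { .. } => Outcome::ParametersBound,
        // Fixed row count, purely for illustration.
        Command::PreparedStatementUpdate { .. } => Outcome::RowsAffected(1),
    }
}
```

In this model, sending the query command where an update is intended never reaches the branch that reports a row count.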
@@ -492,6 +499,36 @@ impl PreparedStatement<Channel> {
Ok(())
}

/// Submit parameters to the server, if any have been set on this prepared statement instance
async fn write_bind_params(&mut self) -> Result<(), ArrowError> {
if let Some(ref params_batch) = self.parameter_binding {
let cmd = CommandPreparedStatementQuery {
prepared_statement_handle: self.handle.clone(),
};

let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec());
let flight_stream_builder = FlightDataEncoderBuilder::new()
.with_flight_descriptor(Some(descriptor))
.with_schema(params_batch.schema());
let flight_data = flight_stream_builder
.build(futures::stream::iter(
self.parameter_binding.clone().map(Ok),
))
.try_collect::<Vec<_>>()
.await
.map_err(flight_error_to_arrow_error)?;

self.flight_sql_client
Contributor:

This appears consistent with the FlightSQL specification: it uses do_put to bind the parameter arguments. What isn't clear to me is whether the result should be used in some way.

This would seem to imply some sort of server-side state, which I had perhaps expected FlightSQL not to rely on.

Contributor Author (@suremarc, Sep 8, 2023):

Yeah, I think we are in agreement that it implies server-side state. FWIW, FlightSQL also supports transactions, which I think (maybe wrongly) would also require state. There was also some discussion about adding new RPCs for managing session state at some point (like a close RPC or something).

Contributor:

This seems like a fundamental flaw in FlightSQL tbh: gRPC is not a connection-oriented protocol, so the lifetime of any server state is non-deterministic. I believe @alamb plans to start a discussion to see if we can't fix this.

Contributor:

I filed apache/arrow#37720 and will circulate this around.

.do_put(stream::iter(flight_data))
.await?
.try_collect::<Vec<_>>()
.await
.map_err(status_to_arrow_error)?;
}

Ok(())
}
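
The review thread above turns on the fact that binding parameters via `do_put` and then executing are two separate calls, so the server has to remember the bound values in between. A minimal in-memory model of that two-step flow follows; `ToyServer` and its methods are hypothetical and only mirror the shape of the exchange.

```rust
use std::collections::HashMap;

/// Hypothetical server that keeps bound parameters per statement handle.
#[derive(Default)]
pub struct ToyServer {
    bound: HashMap<Vec<u8>, Vec<String>>,
}

impl ToyServer {
    /// Step 1 (the `do_put` analogue): store parameters for a handle.
    pub fn bind(&mut self, handle: &[u8], params: Vec<String>) {
        self.bound.insert(handle.to_vec(), params);
    }

    /// Step 2 (the execute analogue): run with whatever was bound earlier.
    /// Returns the parameters used, or an empty list if none were bound.
    pub fn execute(&self, handle: &[u8]) -> Vec<String> {
        self.bound.get(handle).cloned().unwrap_or_default()
    }

    /// Closing the statement is what releases the state in this model.
    pub fn close(&mut self, handle: &[u8]) {
        self.bound.remove(handle);
    }
}
```

If a client never calls `close`, nothing in this model ever drops the entry, which is the lifetime concern raised in the thread.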

/// Close the prepared statement, so that this PreparedStatement can not be used
/// anymore and the server can free up any resources.
pub async fn close(mut self) -> Result<(), ArrowError> {
@@ -515,6 +552,13 @@ fn status_to_arrow_error(status: tonic::Status) -> ArrowError {
ArrowError::IpcError(format!("{status:?}"))
}

fn flight_error_to_arrow_error(err: FlightError) -> ArrowError {
match err {
FlightError::Arrow(e) => e,
e => ArrowError::ExternalError(Box::new(e)),
}
}
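
The helper above unwraps the variant that already holds the target error and boxes everything else. The same pattern is sketched below on hypothetical `InnerError` and `OuterError` types (playing the roles of `ArrowError` and `FlightError`), using only the standard library.

```rust
use std::error::Error;
use std::fmt;

/// Hypothetical target error, playing the role of `ArrowError`.
#[derive(Debug)]
pub enum InnerError {
    Invalid(String),
    External(Box<dyn Error + Send + Sync>),
}

/// Hypothetical wrapper error, playing the role of `FlightError`.
#[derive(Debug)]
pub enum OuterError {
    Inner(InnerError),
    Protocol(String),
}

impl fmt::Display for OuterError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            OuterError::Inner(e) => write!(f, "inner error: {e:?}"),
            OuterError::Protocol(msg) => write!(f, "protocol error: {msg}"),
        }
    }
}

impl Error for OuterError {}

/// Return the wrapped inner error as-is; box any other variant.
pub fn outer_to_inner(err: OuterError) -> InnerError {
    match err {
        OuterError::Inner(e) => e,
        e => InnerError::External(Box::new(e)),
    }
}
```

Unwrapping first avoids nesting an inner error inside an external wrapper around itself, so callers can still match on the original variant.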

// A polymorphic structure to natively represent different types of data contained in `FlightData`
pub enum ArrowFlightData {
RecordBatch(RecordBatch),