From 83ff99dedd6fe8a53941751f09711df70ed4706a Mon Sep 17 00:00:00 2001 From: universalmind303 Date: Wed, 6 Dec 2023 17:28:04 -0600 Subject: [PATCH 01/15] a working poc --- Cargo.lock | 28 ++ crates/glaredb/src/server.rs | 6 +- crates/rpcsrv/Cargo.toml | 1 + crates/rpcsrv/src/flight_handler.rs | 557 ++++++++++++++++++++++++++++ crates/rpcsrv/src/handler.rs | 3 +- crates/rpcsrv/src/lib.rs | 2 +- crates/sqlexec/src/context/local.rs | 9 + crates/sqlexec/src/engine.rs | 34 +- crates/sqlexec/src/session.rs | 70 +++- 9 files changed, 690 insertions(+), 20 deletions(-) create mode 100644 crates/rpcsrv/src/flight_handler.rs diff --git a/Cargo.lock b/Cargo.lock index a2dec94d8..cd67a1066 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -351,6 +351,33 @@ dependencies = [ "num", ] +[[package]] +name = "arrow-flight" +version = "47.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd938ea4a0e8d0db2b9f47ebba792f73f6188f4289707caeaf93a3be705e5ed5" +dependencies = [ + "arrow-arith", + "arrow-array", + "arrow-buffer", + "arrow-cast", + "arrow-data", + "arrow-ipc", + "arrow-ord", + "arrow-row", + "arrow-schema", + "arrow-select", + "arrow-string", + "base64 0.21.5", + "bytes", + "futures", + "once_cell", + "paste", + "prost", + "tokio", + "tonic", +] + [[package]] name = "arrow-ipc" version = "47.0.0" @@ -5943,6 +5970,7 @@ dependencies = [ name = "rpcsrv" version = "0.7.1" dependencies = [ + "arrow-flight", "async-trait", "bytes", "dashmap", diff --git a/crates/glaredb/src/server.rs b/crates/glaredb/src/server.rs index 4aa5df663..f72c6a8a2 100644 --- a/crates/glaredb/src/server.rs +++ b/crates/glaredb/src/server.rs @@ -4,6 +4,7 @@ use pgsrv::auth::LocalAuthenticator; use pgsrv::handler::{ProtocolHandler, ProtocolHandlerConfig}; use protogen::gen::rpcsrv::service::execution_service_server::ExecutionServiceServer; use protogen::gen::rpcsrv::simple::simple_service_server::SimpleServiceServer; +use rpcsrv::flight_handler::{FlightServiceServer, FlightSessionHandler}; use rpcsrv::handler::{RpcHandler, SimpleHandler}; use sqlexec::engine::{Engine, EngineStorageConfig}; use std::collections::HashMap; @@ -200,10 +201,13 @@ impl ComputeServer { self.disable_rpc_auth, self.integration_testing, ); + let flight_handler = FlightSessionHandler::try_new(&self.engine).await?; + tokio::spawn(async move { let mut server = Server::builder() .trace_fn(|_| debug_span!("rpc_service_request")) - .add_service(ExecutionServiceServer::new(handler)); + .add_service(ExecutionServiceServer::new(handler)) + .add_service(FlightServiceServer::new(flight_handler)); // Add in the simple interface if requested. if self.enable_simple_query_rpc { diff --git a/crates/rpcsrv/Cargo.toml b/crates/rpcsrv/Cargo.toml index fc6e10c4c..cd7dfc564 100644 --- a/crates/rpcsrv/Cargo.toml +++ b/crates/rpcsrv/Cargo.toml @@ -27,3 +27,4 @@ tonic = { workspace = true } bytes = "1.4" futures = "0.3.29" dashmap = "5.5.0" +arrow-flight = { version = "47.0.0", features = ["flight-sql-experimental"] } diff --git a/crates/rpcsrv/src/flight_handler.rs b/crates/rpcsrv/src/flight_handler.rs new file mode 100644 index 000000000..50e31ba10 --- /dev/null +++ b/crates/rpcsrv/src/flight_handler.rs @@ -0,0 +1,557 @@ +use crate::errors::{Result, RpcsrvError}; + +use dashmap::DashMap; +use datafusion::{ + arrow::{ipc::writer::IpcWriteOptions, record_batch::RecordBatch}, + logical_expr::LogicalPlan, + physical_plan::SendableRecordBatchStream, +}; +use datafusion_ext::vars::SessionVars; +use datafusion_proto::protobuf::PhysicalPlanNode; +use once_cell::sync::Lazy; +use sqlexec::{ + context::remote::RemoteSessionContext, + engine::{Engine, SessionStorageConfig}, + extension_codec::GlareDBExtensionCodec, + remote::provider_cache::ProviderCache, + session::Session, + OperationInfo, +}; +use std::{ + fmt::Debug, + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; +use uuid::Uuid; + +pub use arrow_flight::flight_service_server::FlightServiceServer; +use arrow_flight::{ + encode::FlightDataEncoderBuilder, flight_service_server::FlightService, sql::*, Action, + FlightDescriptor, FlightEndpoint, FlightInfo, IpcMessage, SchemaAsIpc, Ticket, +}; +use arrow_flight::{ + sql::{ + metadata::{SqlInfoData, SqlInfoDataBuilder}, + server::FlightSqlService, + }, + HandshakeRequest, HandshakeResponse, +}; +use datafusion_proto::physical_plan::AsExecutionPlan; +use futures::{lock::Mutex, Stream}; +use futures::{StreamExt, TryStreamExt}; +use prost::Message; +use tonic::{Request, Response, Status, Streaming}; +static INSTANCE_SQL_DATA: Lazy = Lazy::new(|| { + let mut builder = SqlInfoDataBuilder::new(); + // Server information + builder.append(SqlInfo::FlightSqlServerName, "Example Flight SQL Server"); + builder.append(SqlInfo::FlightSqlServerVersion, "1"); + // 1.3 comes from https://github.com/apache/arrow/blob/f9324b79bf4fc1ec7e97b32e3cce16e75ef0f5e3/format/Schema.fbs#L24 + builder.append(SqlInfo::FlightSqlServerArrowVersion, "1.3"); + builder.build().unwrap() +}); +macro_rules! status { + ($desc:expr, $err:expr) => { + Status::internal(format!("{}: {} at {}:{}", $desc, $err, file!(), line!())) + }; +} + +pub struct FlightSessionHandler { + engine: Arc, + sess: Arc>, + remote_session_contexts: Arc>>, + statements: Arc>, + results: Arc>>, +} +impl FlightSessionHandler { + // todo: figure out how to close inactive sessions + async fn create_ctx(&self) -> Result { + let uuid = Uuid::new_v4(); + let session = self + .engine + .new_remote_session_context(uuid, SessionStorageConfig::default()) + .await + .map_err(|e| Status::internal(format!("Error creating session: {e}")))?; + self.remote_session_contexts + .insert(uuid.clone(), Arc::new(session)); + Ok(uuid.to_string()) + } + pub async fn try_new(engine: &Arc) -> Result { + let session = engine + .new_untracked_session_context(SessionVars::default(), SessionStorageConfig::default()) + .await + .map_err(|e| Status::internal(format!("Error creating session: {e}")))?; + + Ok(Self { + engine: engine.clone(), + statements: Arc::new(DashMap::new()), + results: Arc::new(DashMap::new()), + remote_session_contexts: Arc::new(DashMap::new()), + sess: Arc::new(Mutex::new(session)), + }) + } + async fn get_ctx(&self, req: &Request) -> Result>, Status> { + Ok(self.sess.clone()) + } +} + +#[tonic::async_trait] +impl FlightSqlService for FlightSessionHandler { + type FlightService = FlightSessionHandler; + + async fn do_handshake( + &self, + request: Request>, + ) -> Result< + Response> + Send>>>, + Status, + > { + todo!() + } + + async fn do_get_fallback( + &self, + request: Request, + message: Any, + ) -> Result::DoGetStream>, Status> { + if !message.is::() { + panic!("Expected ActionExecutePhysicalPlan but got {:?}", message) + } + let ctx = self.get_ctx(&request).await?; + + let ctx = ctx.lock().await; + + let plan: ActionExecutePhysicalPlan = message + .unpack() + .map_err(|e| Status::internal(format!("{e:?}")))? + .ok_or_else(|| Status::internal("Expected FetchResults but got None!"))?; + let plan = plan.plan; + let cache = ProviderCache::default(); + let codec = ctx.extension_codec(&cache); + + let plan = + PhysicalPlanNode::try_decode(&plan).map_err(|e| Status::internal(format!("{e:?}")))?; + + let plan = plan + .try_into_physical_plan(ctx.df_ctx(), ctx.df_ctx().runtime_env().as_ref(), &codec) + .unwrap(); + + let stream = ctx.execute_physical(plan).await.unwrap(); + let schema = stream.schema(); + let stream = + stream.map_err(|e| arrow_flight::error::FlightError::ExternalError(Box::new(e))); + // while let Some(batch) = stream.next().await { + // batches.push(batch.unwrap()); + // } + + // let batch_stream = futures::stream::iter(batches).map(Ok); + + let stream = FlightDataEncoderBuilder::new() + .with_schema(schema) + .build(stream) + .map_err(Status::from); + + Ok(Response::new(Box::pin(stream))) + } + + async fn get_flight_info_statement( + &self, + query: CommandStatementQuery, + request: Request, + ) -> Result, Status> { + let flight_descriptor = request.into_inner(); + let ticket = Ticket::new(query.encode_to_vec()); + let endpoint = FlightEndpoint::new().with_ticket(ticket); + + let flight_info = FlightInfo::new() + .with_endpoint(endpoint) + .with_descriptor(flight_descriptor); + + Ok(tonic::Response::new(flight_info)) + } + + async fn get_flight_info_substrait_plan( + &self, + _query: CommandStatementSubstraitPlan, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented( + "get_flight_info_substrait_plan not implemented", + )) + } + + async fn get_flight_info_prepared_statement( + &self, + cmd: CommandPreparedStatementQuery, + request: Request, + ) -> Result, Status> { + let handle = std::str::from_utf8(&cmd.prepared_statement_handle) + .map_err(|e| status!("Unable to parse handle", e))?; + let ctx = self.get_ctx(&request).await?; + let mut ctx = ctx.lock().await; + + let portal = ctx + .get_portal(handle) + .map_err(|e| status!("Unable to get portal", e))?; + let plan = portal.logical_plan().unwrap(); + let plan = plan + .clone() + .try_into_datafusion_plan() + .map_err(RpcsrvError::from)?; + let physical = ctx + .create_physical_plan(plan, &OperationInfo::default()) + .await + .map_err(|e| status!("Unable to execute portal", e))?; + // Encode the physical plan into a protobuf message. + let physical_plan = { + let node = PhysicalPlanNode::try_from_physical_plan( + physical, + &GlareDBExtensionCodec::new_encoder(), + ) + .unwrap(); + let mut buf = Vec::new(); + node.try_encode(&mut buf).unwrap(); + buf + }; + let action = ActionExecutePhysicalPlan { + plan: physical_plan, + }; + + let ticket = Ticket::new(action.as_any().encode_to_vec()); + + let endpoint = FlightEndpoint::new().with_ticket(ticket); + + // Ideally, we'd start the execution here, but instead we defer it all to the "do_get" call. + let flight_info = FlightInfo::new() + .with_descriptor(FlightDescriptor::new_cmd(vec![])) + .with_endpoint(endpoint); + + Ok(tonic::Response::new(flight_info)) + } + + // async fn get_flight_info_catalogs( + // &self, + // query: CommandGetCatalogs, + // request: Request, + // ) -> Result, Status> { + // todo!() + // } + + // async fn get_flight_info_schemas( + // &self, + // query: CommandGetDbSchemas, + // request: Request, + // ) -> Result, Status> { + // todo!() + // } + + // async fn get_flight_info_tables( + // &self, + // query: CommandGetTables, + // request: Request, + // ) -> Result, Status> { + // todo!() + // } + + // async fn get_flight_info_table_types( + // &self, + // _query: CommandGetTableTypes, + // _request: Request, + // ) -> Result, Status> { + // Err(Status::unimplemented( + // "get_flight_info_table_types not implemented", + // )) + // } + + async fn get_flight_info_sql_info( + &self, + query: CommandGetSqlInfo, + request: Request, + ) -> Result, Status> { + let flight_descriptor = request.into_inner(); + let ticket = Ticket::new(query.encode_to_vec()); + let endpoint = FlightEndpoint::new().with_ticket(ticket); + + let flight_info = FlightInfo::new() + .try_with_schema(query.into_builder(&INSTANCE_SQL_DATA).schema().as_ref()) + .map_err(|e| status!("Unable to encode schema", e))? + .with_descriptor(flight_descriptor); + + Ok(tonic::Response::new(flight_info)) + } + + // async fn get_flight_info_primary_keys( + // &self, + // _query: CommandGetPrimaryKeys, + // _request: Request, + // ) -> Result, Status> { + + // Err(Status::unimplemented( + // "get_flight_info_primary_keys not implemented", + // )) + // } + + // async fn get_flight_info_exported_keys( + // &self, + // _query: CommandGetExportedKeys, + // _request: Request, + // ) -> Result, Status> { + + // Err(Status::unimplemented( + // "get_flight_info_exported_keys not implemented", + // )) + // } + + // async fn get_flight_info_imported_keys( + // &self, + // _query: CommandGetImportedKeys, + // _request: Request, + // ) -> Result, Status> { + + // Err(Status::unimplemented( + // "get_flight_info_imported_keys not implemented", + // )) + // } + + // async fn get_flight_info_cross_reference( + // &self, + // _query: CommandGetCrossReference, + // _request: Request, + // ) -> Result, Status> { + + // Err(Status::unimplemented( + // "get_flight_info_imported_keys not implemented", + // )) + // } + + // async fn get_flight_info_xdbc_type_info( + // &self, + // query: CommandGetXdbcTypeInfo, + // request: Request, + // ) -> Result, Status> { + + // todo!() + // } + + // // do_get + // async fn do_get_statement( + // &self, + // _ticket: TicketStatementQuery, + // _request: Request, + // ) -> Result::DoGetStream>, Status> { + + // Err(Status::unimplemented("do_get_statement not implemented")) + // } + + // async fn do_get_prepared_statement( + // &self, + // _query: CommandPreparedStatementQuery, + // _request: Request, + // ) -> Result::DoGetStream>, Status> { + + // Err(Status::unimplemented( + // "do_get_prepared_statement not implemented", + // )) + // } + + // async fn do_get_catalogs( + // &self, + // query: CommandGetCatalogs, + // _request: Request, + // ) -> Result::DoGetStream>, Status> { + + // todo!() + // } + + // async fn do_get_schemas( + // &self, + // query: CommandGetDbSchemas, + // _request: Request, + // ) -> Result::DoGetStream>, Status> { + + // todo!() + // } + + // async fn do_get_tables( + // &self, + // query: CommandGetTables, + // _request: Request, + // ) -> Result::DoGetStream>, Status> { + + // todo!() + // } + + // async fn do_get_table_types( + // &self, + // _query: CommandGetTableTypes, + // _request: Request, + // ) -> Result::DoGetStream>, Status> { + + // Err(Status::unimplemented("do_get_table_types not implemented")) + // } + + // async fn do_get_sql_info( + // &self, + // query: CommandGetSqlInfo, + // _request: Request, + // ) -> Result::DoGetStream>, Status> { + + // todo!() + // } + // async fn do_get_primary_keys( + // &self, + // _query: CommandGetPrimaryKeys, + // _request: Request, + // ) -> Result::DoGetStream>, Status> { + + // Err(Status::unimplemented("do_get_primary_keys not implemented")) + // } + + // async fn do_get_exported_keys( + // &self, + // _query: CommandGetExportedKeys, + // _request: Request, + // ) -> Result::DoGetStream>, Status> { + + // Err(Status::unimplemented( + // "do_get_exported_keys not implemented", + // )) + // } + + // async fn do_get_imported_keys( + // &self, + // _query: CommandGetImportedKeys, + // _request: Request, + // ) -> Result::DoGetStream>, Status> { + // Err(Status::unimplemented( + // "do_get_imported_keys not implemented", + // )) + // } + + // async fn do_get_cross_reference( + // &self, + // _query: CommandGetCrossReference, + // _request: Request, + // ) -> Result::DoGetStream>, Status> { + // Err(Status::unimplemented( + // "do_get_cross_reference not implemented", + // )) + // } + + // async fn do_get_xdbc_type_info( + // &self, + // query: CommandGetXdbcTypeInfo, + // _request: Request, + // ) -> Result::DoGetStream>, Status> { + + // todo!() + // } + + // // do_put + // async fn do_put_statement_update( + // &self, + // _ticket: CommandStatementUpdate, + // _request: Request, + // ) -> Result { + + // todo!() + // } + + // async fn do_put_substrait_plan( + // &self, + // _ticket: CommandStatementSubstraitPlan, + // _request: Request, + // ) -> Result { + + // Err(Status::unimplemented( + // "do_put_substrait_plan not implemented", + // )) + // } + + // async fn do_put_prepared_statement_query( + // &self, + // _query: CommandPreparedStatementQuery, + // _request: Request, + // ) -> Result::DoPutStream>, Status> { + + // Err(Status::unimplemented( + // "do_put_prepared_statement_query not implemented", + // )) + // } + + // async fn do_put_prepared_statement_update( + // &self, + // _query: CommandPreparedStatementUpdate, + // _request: Request, + // ) -> Result { + // Err(Status::unimplemented( + // "do_put_prepared_statement_update not implemented", + // )) + // } + + async fn do_action_create_prepared_statement( + &self, + query: ActionCreatePreparedStatementRequest, + request: Request, + ) -> Result { + let ctx = self.get_ctx(&request).await?; + let mut ctx = ctx.lock().await; + let handle = uuid::Uuid::new_v4().to_string(); + ctx.prepare_portal(&handle, &query.query) + .await + .map_err(|e| status!("Unable to prepare statement", e))?; + + let lp = ctx + .query_to_lp(&query.query) + .await + .map_err(|e| status!("Unable to parse query", e))?; + + let output_schema = lp.output_schema().unwrap(); + let message = SchemaAsIpc::new(&output_schema, &IpcWriteOptions::default()) + .try_into() + .map_err(|e| status!("Unable to serialize schema", e))?; + let IpcMessage(schema_bytes) = message; + let res = ActionCreatePreparedStatementResult { + prepared_statement_handle: handle.into(), + dataset_schema: schema_bytes, + parameter_schema: Default::default(), // TODO: parameters + }; + Ok(res) + } + + async fn do_action_close_prepared_statement( + &self, + query: ActionClosePreparedStatementRequest, + request: Request, + ) -> Result<(), Status> { + let ctx = self.get_ctx(&request).await?; + let mut ctx = ctx.lock().await; + let handle = std::str::from_utf8(&query.prepared_statement_handle) + .map_err(|e| status!("Unable to parse handle", e))?; + ctx.remove_portal(handle); + + Ok(()) + } + + async fn register_sql_info(&self, _id: i32, _result: &SqlInfo) {} +} + +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ActionExecutePhysicalPlan { + #[prost(bytes, tag = "1")] + pub plan: Vec, +} + +impl ProstMessageExt for ActionExecutePhysicalPlan { + fn type_url() -> &'static str { + "type.googleapis.com/glaredb.rpcsrv.ActionExecutePhysicalPlan" + } + + fn as_any(&self) -> Any { + Any { + type_url: ActionExecutePhysicalPlan::type_url().to_string(), + value: ::prost::Message::encode_to_vec(self).into(), + } + } +} diff --git a/crates/rpcsrv/src/handler.rs b/crates/rpcsrv/src/handler.rs index a1b28952a..e7ce1544f 100644 --- a/crates/rpcsrv/src/handler.rs +++ b/crates/rpcsrv/src/handler.rs @@ -39,6 +39,7 @@ use tonic::{Request, Response, Status, Streaming}; use tracing::info; use uuid::Uuid; +#[derive(Clone)] pub struct RpcHandler { /// Core db engine for creating sessions. engine: Arc, @@ -327,7 +328,7 @@ impl Stream for ExecutionResponseBatchStream { /// lifetime of a query. pub struct SimpleHandler { /// Core db engine for creating sessions. - engine: Arc, + pub engine: Arc, } impl SimpleHandler { diff --git a/crates/rpcsrv/src/lib.rs b/crates/rpcsrv/src/lib.rs index e76d562c9..90f1ca469 100644 --- a/crates/rpcsrv/src/lib.rs +++ b/crates/rpcsrv/src/lib.rs @@ -1,5 +1,5 @@ pub mod errors; pub mod handler; pub mod proxy; - +pub mod flight_handler; mod session; diff --git a/crates/sqlexec/src/context/local.rs b/crates/sqlexec/src/context/local.rs index 3454861d1..46f485507 100644 --- a/crates/sqlexec/src/context/local.rs +++ b/crates/sqlexec/src/context/local.rs @@ -481,6 +481,15 @@ impl Portal { output_fields }) } + pub fn logical_plan(&self) -> Option<&LogicalPlan> { + self.stmt.plan.as_ref() + } + pub fn input_paramaters(&self) -> Option<&HashMap>> { + self.stmt.input_paramaters() + } + pub fn output_schema(&self) -> Option<&ArrowSchema> { + self.stmt.output_schema.as_ref() + } } /// Iterator over the various fields of output schema. diff --git a/crates/sqlexec/src/engine.rs b/crates/sqlexec/src/engine.rs index 85097f795..3b467dd7d 100644 --- a/crates/sqlexec/src/engine.rs +++ b/crates/sqlexec/src/engine.rs @@ -370,16 +370,34 @@ impl Engine { self.session_counter.load(Ordering::Relaxed) } + /// Create a new tracked local session. + /// + pub async fn new_local_session_context( + &self, + vars: SessionVars, + storage: SessionStorageConfig, + ) -> Result { + let session = self.new_untracked_session_context(vars, storage).await?; + let prev = self.session_counter.fetch_add(1, Ordering::Relaxed); + debug!(session_count = prev + 1, "new session opened"); + + Ok(TrackedSession { + inner: session, + session_counter: self.session_counter.clone(), + }) + } + /// Create a new local session, initializing it with the provided session /// variables. + /// Unlike [`new_local_session_context`] this doesn't track the session. // TODO: This is _very_ easy to mess up with the vars since we implement // default (which defaults to the nil uuid), but using default would is // incorrect in any case we're running Cloud. - pub async fn new_local_session_context( + pub async fn new_untracked_session_context( &self, vars: SessionVars, storage: SessionStorageConfig, - ) -> Result { + ) -> Result { let database_id = vars.database_id(); let metastore = self.supervisor.init_client(database_id).await?; let native = self @@ -395,7 +413,7 @@ impl Engine { }, ); - let session = Session::new( + Session::new( vars, catalog, metastore.into(), @@ -403,15 +421,7 @@ impl Engine { self.tracker.clone(), self.spill_path.clone(), self.task_scheduler.clone(), - )?; - - let prev = self.session_counter.fetch_add(1, Ordering::Relaxed); - debug!(session_count = prev + 1, "new session opened"); - - Ok(TrackedSession { - inner: session, - session_counter: self.session_counter.clone(), - }) + ) } /// Create a new remote session for plan execution. diff --git a/crates/sqlexec/src/session.rs b/crates/sqlexec/src/session.rs index db1419dbe..d42938f87 100644 --- a/crates/sqlexec/src/session.rs +++ b/crates/sqlexec/src/session.rs @@ -10,6 +10,7 @@ use crate::distexec::scheduler::{OutputSink, Scheduler}; use crate::distexec::stream::create_coalescing_adapter; use crate::environment::EnvironmentReader; use crate::errors::{ExecError, Result}; +use crate::extension_codec::GlareDBExtensionCodec; use crate::parser::StatementWithExtensions; use crate::planner::logical_plan::*; use crate::planner::physical_plan::{ @@ -18,6 +19,7 @@ use crate::planner::physical_plan::{ }; use crate::remote::client::RemoteClient; use crate::remote::planner::{DDLExtensionPlanner, RemotePhysicalPlanner}; +use crate::remote::provider_cache::ProviderCache; use catalog::mutator::CatalogMutator; use catalog::session_catalog::SessionCatalog; use datafusion::arrow::datatypes::Schema; @@ -109,7 +111,45 @@ pub enum ExecutionResult { /// Credentials are dropped. DropCredentials, } +pub struct PreparedStmt { + stmt: Option, +} +impl<'a> TryFrom<&'a str> for PreparedStmt { + type Error = ExecError; + fn try_from(query: &'a str) -> Result { + let mut statements = crate::parser::parse_sql(query)?; + match statements.len() { + 0 => Err(ExecError::String("No statements in query".to_string())), + 1 => Ok(PreparedStmt { + stmt: statements.pop_front(), + }), + _ => Err(ExecError::String( + "More than one statement in query".to_string(), + )), + } + } +} + +impl<'a> TryFrom<&'a String> for PreparedStmt { + type Error = ExecError; + fn try_from(query: &'a String) -> Result { + let s: &str = query; + s.try_into() + } +} +impl TryFrom> for PreparedStmt { + type Error = ExecError; + fn try_from(stmt: Option) -> Result { + Ok(PreparedStmt { stmt }) + } +} +impl TryFrom for PreparedStmt { + type Error = ExecError; + fn try_from(stmt: StatementWithExtensions) -> Result { + Ok(PreparedStmt { stmt: Some(stmt) }) + } +} impl ExecutionResult { /// Create a result from a stream and a physical plan. /// @@ -457,13 +497,15 @@ impl Session { } /// Prepare a parsed statement for future execution. - pub async fn prepare_statement( + pub async fn prepare_statement>( &mut self, name: String, - stmt: Option, + stmt: T, params: Vec, // OIDs ) -> Result<()> { - self.ctx.prepare_statement(name, stmt, params).await + let stmt: PreparedStmt = stmt.try_into()?; + + self.ctx.prepare_statement(name, stmt.stmt, params).await } pub fn get_prepared_statement(&self, name: &str) -> Result<&PreparedStatement> { @@ -679,8 +721,7 @@ impl Session { 0 => Err(ExecError::String("No statements in query".to_string())), 1 => { let stmt = statements.pop_front().unwrap(); - self.prepare_statement(UNNAMED, Some(stmt), Vec::new()) - .await?; + self.prepare_statement(UNNAMED, stmt, Vec::new()).await?; let prepared = self.get_prepared_statement(&UNNAMED)?; let num_fields = prepared.output_fields().map(|f| f.len()).unwrap_or(0); self.bind_statement( @@ -704,4 +745,23 @@ impl Session { datafusion_ext::vars::Dialect::Prql => crate::parser::parse_prql(query), } } + pub async fn prepare_portal(&mut self, handle: &str, query: &str) -> Result<()> { + self.prepare_statement(handle.to_string(), query, Vec::new()) + .await?; + let prepared = self.get_prepared_statement(handle)?; + + let num_fields = prepared.output_fields().map(|f| f.len()).unwrap_or(0); + self.bind_statement( + handle.to_string(), + handle, + Vec::new(), + vec![Format::Binary; num_fields], + )?; + Ok(()) + } + /// Returns the extension codec used for serializing and deserializing data + /// over RPCs. + pub fn extension_codec<'a>(&self, cache: &'a ProviderCache) -> GlareDBExtensionCodec<'a> { + GlareDBExtensionCodec::new_decoder(cache, self.ctx.df_ctx().runtime_env()) + } } From 6380219589326443dff8a9a13f3656da152fee31 Mon Sep 17 00:00:00 2001 From: universalmind303 Date: Thu, 7 Dec 2023 11:03:00 -0600 Subject: [PATCH 02/15] add better err handling and multi session support --- crates/rpcsrv/src/errors.rs | 7 + crates/rpcsrv/src/flight_handler.rs | 423 +++++++--------------------- 2 files changed, 102 insertions(+), 328 deletions(-) diff --git a/crates/rpcsrv/src/errors.rs b/crates/rpcsrv/src/errors.rs index faa27198a..9df72dd76 100644 --- a/crates/rpcsrv/src/errors.rs +++ b/crates/rpcsrv/src/errors.rs @@ -1,3 +1,5 @@ +use datafusion::arrow::error::ArrowError; + #[derive(Debug, thiserror::Error)] pub enum RpcsrvError { #[error("Invalid {0} id: {1}")] @@ -53,6 +55,11 @@ pub enum RpcsrvError { #[error("{0}")] Internal(String), + + #[error("{0}")] + ParseError(String), + #[error("{0}")] + ArrowError(ArrowError), } pub type Result = std::result::Result; diff --git a/crates/rpcsrv/src/flight_handler.rs b/crates/rpcsrv/src/flight_handler.rs index 50e31ba10..128ebca54 100644 --- a/crates/rpcsrv/src/flight_handler.rs +++ b/crates/rpcsrv/src/flight_handler.rs @@ -1,28 +1,17 @@ use crate::errors::{Result, RpcsrvError}; use dashmap::DashMap; -use datafusion::{ - arrow::{ipc::writer::IpcWriteOptions, record_batch::RecordBatch}, - logical_expr::LogicalPlan, - physical_plan::SendableRecordBatchStream, -}; +use datafusion::arrow::ipc::writer::IpcWriteOptions; use datafusion_ext::vars::SessionVars; use datafusion_proto::protobuf::PhysicalPlanNode; use once_cell::sync::Lazy; use sqlexec::{ context::remote::RemoteSessionContext, - engine::{Engine, SessionStorageConfig}, + engine::{Engine, SessionStorageConfig, TrackedSession}, extension_codec::GlareDBExtensionCodec, - remote::provider_cache::ProviderCache, - session::Session, OperationInfo, }; -use std::{ - fmt::Debug, - pin::Pin, - sync::Arc, - task::{Context, Poll}, -}; +use std::{pin::Pin, sync::Arc}; use uuid::Uuid; pub use arrow_flight::flight_service_server::FlightServiceServer; @@ -38,19 +27,20 @@ use arrow_flight::{ HandshakeRequest, HandshakeResponse, }; use datafusion_proto::physical_plan::AsExecutionPlan; +use futures::TryStreamExt; use futures::{lock::Mutex, Stream}; -use futures::{StreamExt, TryStreamExt}; use prost::Message; use tonic::{Request, Response, Status, Streaming}; + static INSTANCE_SQL_DATA: Lazy = Lazy::new(|| { let mut builder = SqlInfoDataBuilder::new(); // Server information - builder.append(SqlInfo::FlightSqlServerName, "Example Flight SQL Server"); - builder.append(SqlInfo::FlightSqlServerVersion, "1"); - // 1.3 comes from https://github.com/apache/arrow/blob/f9324b79bf4fc1ec7e97b32e3cce16e75ef0f5e3/format/Schema.fbs#L24 + builder.append(SqlInfo::FlightSqlServerName, "GlareDB Flight Server"); + builder.append(SqlInfo::FlightSqlServerVersion, env!("CARGO_PKG_VERSION")); builder.append(SqlInfo::FlightSqlServerArrowVersion, "1.3"); builder.build().unwrap() }); + macro_rules! status { ($desc:expr, $err:expr) => { Status::internal(format!("{}: {} at {}:{}", $desc, $err, file!(), line!())) @@ -59,40 +49,51 @@ macro_rules! status { pub struct FlightSessionHandler { engine: Arc, - sess: Arc>, - remote_session_contexts: Arc>>, - statements: Arc>, - results: Arc>>, + /// The remote context is used to execute the queries in a stateless manner. + remote_ctx: Arc, + sessions: DashMap>>, } + impl FlightSessionHandler { - // todo: figure out how to close inactive sessions - async fn create_ctx(&self) -> Result { - let uuid = Uuid::new_v4(); - let session = self + async fn create_ctx(&self, handle: &str) -> Result>, Status> { + if self.sessions.contains_key(handle) { + let sess = self.sessions.get(handle).unwrap().clone(); + return Ok(sess); + } + let uuid = Uuid::parse_str(handle) + .map_err(|e| RpcsrvError::ParseError(format!("Error parsing uuid: {e}")))?; + let _ = self .engine .new_remote_session_context(uuid, SessionStorageConfig::default()) .await - .map_err(|e| Status::internal(format!("Error creating session: {e}")))?; - self.remote_session_contexts - .insert(uuid.clone(), Arc::new(session)); - Ok(uuid.to_string()) - } - pub async fn try_new(engine: &Arc) -> Result { - let session = engine - .new_untracked_session_context(SessionVars::default(), SessionStorageConfig::default()) + .map_err(RpcsrvError::from)?; + + let sess = self + .engine + .new_local_session_context(SessionVars::default(), SessionStorageConfig::default()) .await - .map_err(|e| Status::internal(format!("Error creating session: {e}")))?; + .map_err(RpcsrvError::from)?; + + let sess = Arc::new(Mutex::new(sess)); + self.sessions.insert(handle.to_string(), sess.clone()); + Ok(sess) + } + pub async fn try_new(engine: &Arc) -> Result { + let exec_ctx = engine + .new_remote_session_context(Uuid::new_v4(), SessionStorageConfig::default()) + .await?; Ok(Self { engine: engine.clone(), - statements: Arc::new(DashMap::new()), - results: Arc::new(DashMap::new()), - remote_session_contexts: Arc::new(DashMap::new()), - sess: Arc::new(Mutex::new(session)), + remote_ctx: Arc::new(exec_ctx), + sessions: DashMap::new(), }) } - async fn get_ctx(&self, req: &Request) -> Result>, Status> { - Ok(self.sess.clone()) + + async fn get_ctx(&self, handle: &str) -> Result>, Status> { + self.sessions.get(handle).map(|s| s.clone()).ok_or_else(|| { + Status::internal(format!("Unable to find session with handle {}", handle)) + }) } } @@ -102,7 +103,7 @@ impl FlightSqlService for FlightSessionHandler { async fn do_handshake( &self, - request: Request>, + _: Request>, ) -> Result< Response> + Send>>>, Status, @@ -112,40 +113,40 @@ impl FlightSqlService for FlightSessionHandler { async fn do_get_fallback( &self, - request: Request, + _: Request, message: Any, ) -> Result::DoGetStream>, Status> { if !message.is::() { - panic!("Expected ActionExecutePhysicalPlan but got {:?}", message) + Err(Status::unimplemented(format!( + "do_get: The defined request is invalid: {}", + message.type_url + )))? } - let ctx = self.get_ctx(&request).await?; - - let ctx = ctx.lock().await; - let plan: ActionExecutePhysicalPlan = message .unpack() - .map_err(|e| Status::internal(format!("{e:?}")))? + .map_err(RpcsrvError::from)? .ok_or_else(|| Status::internal("Expected FetchResults but got None!"))?; - let plan = plan.plan; - let cache = ProviderCache::default(); - let codec = ctx.extension_codec(&cache); + let ActionExecutePhysicalPlan { plan, handle } = plan; - let plan = - PhysicalPlanNode::try_decode(&plan).map_err(|e| Status::internal(format!("{e:?}")))?; + let ctx = self.get_ctx(&handle).await?; + + let ctx = ctx.lock().await; + + let plan = PhysicalPlanNode::try_decode(&plan).map_err(RpcsrvError::from)?; + let codec = self.remote_ctx.extension_codec(); let plan = plan .try_into_physical_plan(ctx.df_ctx(), ctx.df_ctx().runtime_env().as_ref(), &codec) - .unwrap(); + .map_err(RpcsrvError::from)?; + + let stream = self + .remote_ctx + .execute_physical(plan) + .map_err(RpcsrvError::from)?; - let stream = ctx.execute_physical(plan).await.unwrap(); let schema = stream.schema(); let stream = stream.map_err(|e| arrow_flight::error::FlightError::ExternalError(Box::new(e))); - // while let Some(batch) = stream.next().await { - // batches.push(batch.unwrap()); - // } - - // let batch_stream = futures::stream::iter(batches).map(Ok); let stream = FlightDataEncoderBuilder::new() .with_schema(schema) @@ -184,38 +185,43 @@ impl FlightSqlService for FlightSessionHandler { async fn get_flight_info_prepared_statement( &self, cmd: CommandPreparedStatementQuery, - request: Request, + _: Request, ) -> Result, Status> { let handle = std::str::from_utf8(&cmd.prepared_statement_handle) .map_err(|e| status!("Unable to parse handle", e))?; - let ctx = self.get_ctx(&request).await?; - let mut ctx = ctx.lock().await; - let portal = ctx - .get_portal(handle) - .map_err(|e| status!("Unable to get portal", e))?; + let ctx = self.get_ctx(handle).await?; + let ctx = ctx.lock().await; + let portal = ctx.get_portal(handle).map_err(RpcsrvError::from)?; + let plan = portal.logical_plan().unwrap(); + let plan = plan .clone() .try_into_datafusion_plan() .map_err(RpcsrvError::from)?; + let physical = ctx .create_physical_plan(plan, &OperationInfo::default()) .await - .map_err(|e| status!("Unable to execute portal", e))?; + .map_err(RpcsrvError::from)?; + // Encode the physical plan into a protobuf message. let physical_plan = { let node = PhysicalPlanNode::try_from_physical_plan( physical, &GlareDBExtensionCodec::new_encoder(), ) - .unwrap(); + .map_err(RpcsrvError::from)?; + let mut buf = Vec::new(); - node.try_encode(&mut buf).unwrap(); + node.try_encode(&mut buf).map_err(RpcsrvError::from)?; buf }; + let action = ActionExecutePhysicalPlan { plan: physical_plan, + handle: handle.to_string(), }; let ticket = Ticket::new(action.as_any().encode_to_vec()); @@ -230,306 +236,65 @@ impl FlightSqlService for FlightSessionHandler { Ok(tonic::Response::new(flight_info)) } - // async fn get_flight_info_catalogs( - // &self, - // query: CommandGetCatalogs, - // request: Request, - // ) -> Result, Status> { - // todo!() - // } - - // async fn get_flight_info_schemas( - // &self, - // query: CommandGetDbSchemas, - // request: Request, - // ) -> Result, Status> { - // todo!() - // } - - // async fn get_flight_info_tables( - // &self, - // query: CommandGetTables, - // request: Request, - // ) -> Result, Status> { - // todo!() - // } - - // async fn get_flight_info_table_types( - // &self, - // _query: CommandGetTableTypes, - // _request: Request, - // ) -> Result, Status> { - // Err(Status::unimplemented( - // "get_flight_info_table_types not implemented", - // )) - // } - async fn get_flight_info_sql_info( &self, query: CommandGetSqlInfo, request: Request, ) -> Result, Status> { let flight_descriptor = request.into_inner(); - let ticket = Ticket::new(query.encode_to_vec()); - let endpoint = FlightEndpoint::new().with_ticket(ticket); let flight_info = FlightInfo::new() .try_with_schema(query.into_builder(&INSTANCE_SQL_DATA).schema().as_ref()) - .map_err(|e| status!("Unable to encode schema", e))? + .map_err(RpcsrvError::from)? .with_descriptor(flight_descriptor); Ok(tonic::Response::new(flight_info)) } - // async fn get_flight_info_primary_keys( - // &self, - // _query: CommandGetPrimaryKeys, - // _request: Request, - // ) -> Result, Status> { - - // Err(Status::unimplemented( - // "get_flight_info_primary_keys not implemented", - // )) - // } - - // async fn get_flight_info_exported_keys( - // &self, - // _query: CommandGetExportedKeys, - // _request: Request, - // ) -> Result, Status> { - - // Err(Status::unimplemented( - // "get_flight_info_exported_keys not implemented", - // )) - // } - - // async fn get_flight_info_imported_keys( - // &self, - // _query: CommandGetImportedKeys, - // _request: Request, - // ) -> Result, Status> { - - // Err(Status::unimplemented( - // "get_flight_info_imported_keys not implemented", - // )) - // } - - // async fn get_flight_info_cross_reference( - // &self, - // _query: CommandGetCrossReference, - // _request: Request, - // ) -> Result, Status> { - - // Err(Status::unimplemented( - // "get_flight_info_imported_keys not implemented", - // )) - // } - - // async fn get_flight_info_xdbc_type_info( - // &self, - // query: CommandGetXdbcTypeInfo, - // request: Request, - // ) -> Result, Status> { - - // todo!() - // } - - // // do_get - // async fn do_get_statement( - // &self, - // _ticket: TicketStatementQuery, - // _request: Request, - // ) -> Result::DoGetStream>, Status> { - - // Err(Status::unimplemented("do_get_statement not implemented")) - // } - - // async fn do_get_prepared_statement( - // &self, - // _query: CommandPreparedStatementQuery, - // _request: Request, - // ) -> Result::DoGetStream>, Status> { - - // Err(Status::unimplemented( - // "do_get_prepared_statement not implemented", - // )) - // } - - // async fn do_get_catalogs( - // &self, - // query: CommandGetCatalogs, - // _request: Request, - // ) -> Result::DoGetStream>, Status> { - - // todo!() - // } - - // async fn do_get_schemas( - // &self, - // query: CommandGetDbSchemas, - // _request: Request, - // ) -> Result::DoGetStream>, Status> { - - // todo!() - // } - - // async fn do_get_tables( - // &self, - // query: CommandGetTables, - // _request: Request, - // ) -> Result::DoGetStream>, Status> { - - // todo!() - // } - - // async fn do_get_table_types( - // &self, - // _query: CommandGetTableTypes, - // _request: Request, - // ) -> Result::DoGetStream>, Status> { - - // Err(Status::unimplemented("do_get_table_types not implemented")) - // } - - // async fn do_get_sql_info( - // &self, - // query: CommandGetSqlInfo, - // _request: Request, - // ) -> Result::DoGetStream>, Status> { - - // todo!() - // } - // async fn do_get_primary_keys( - // &self, - // _query: CommandGetPrimaryKeys, - // _request: Request, - // ) -> Result::DoGetStream>, Status> { - - // Err(Status::unimplemented("do_get_primary_keys not implemented")) - // } - - // async fn do_get_exported_keys( - // &self, - // _query: CommandGetExportedKeys, - // _request: Request, - // ) -> Result::DoGetStream>, Status> { - - // Err(Status::unimplemented( - // "do_get_exported_keys not implemented", - // )) - // } - - // async fn do_get_imported_keys( - // &self, - // _query: CommandGetImportedKeys, - // _request: Request, - // ) -> Result::DoGetStream>, Status> { - // Err(Status::unimplemented( - // "do_get_imported_keys not implemented", - // )) - // } - - // async fn do_get_cross_reference( - // &self, - // _query: CommandGetCrossReference, - // _request: Request, - // ) -> Result::DoGetStream>, Status> { - // Err(Status::unimplemented( - // "do_get_cross_reference not implemented", - // )) - // } - - // async fn do_get_xdbc_type_info( - // &self, - // query: CommandGetXdbcTypeInfo, - // _request: Request, - // ) -> Result::DoGetStream>, Status> { - - // todo!() - // } - - // // do_put - // async fn do_put_statement_update( - // &self, - // _ticket: CommandStatementUpdate, - // _request: Request, - // ) -> Result { - - // todo!() - // } - - // async fn do_put_substrait_plan( - // &self, - // _ticket: CommandStatementSubstraitPlan, - // _request: Request, - // ) -> Result { - - // Err(Status::unimplemented( - // "do_put_substrait_plan not implemented", - // )) - // } - - // async fn do_put_prepared_statement_query( - // &self, - // _query: CommandPreparedStatementQuery, - // _request: Request, - // ) -> Result::DoPutStream>, Status> { - - // Err(Status::unimplemented( - // "do_put_prepared_statement_query not implemented", - // )) - // } - - // async fn do_put_prepared_statement_update( - // &self, - // _query: CommandPreparedStatementUpdate, - // _request: Request, - // ) -> Result { - // Err(Status::unimplemented( - // "do_put_prepared_statement_update not implemented", - // )) - // } - + // I think it's safe to create a session for the duration of the prepared statement? async fn do_action_create_prepared_statement( &self, query: ActionCreatePreparedStatementRequest, - request: Request, + _: Request, ) -> Result { - let ctx = self.get_ctx(&request).await?; - let mut ctx = ctx.lock().await; let handle = uuid::Uuid::new_v4().to_string(); + let ctx = self.create_ctx(&handle).await?; + let mut ctx = ctx.lock().await; + ctx.prepare_portal(&handle, &query.query) .await - .map_err(|e| status!("Unable to prepare statement", e))?; + .map_err(RpcsrvError::from)?; let lp = ctx .query_to_lp(&query.query) .await - .map_err(|e| status!("Unable to parse query", e))?; + .map_err(RpcsrvError::from)?; + + let output_schema = lp.output_schema().expect("no output schema"); - let output_schema = lp.output_schema().unwrap(); let message = SchemaAsIpc::new(&output_schema, &IpcWriteOptions::default()) .try_into() - .map_err(|e| status!("Unable to serialize schema", e))?; + .map_err(RpcsrvError::from)?; + let IpcMessage(schema_bytes) = message; let res = ActionCreatePreparedStatementResult { prepared_statement_handle: handle.into(), dataset_schema: schema_bytes, parameter_schema: Default::default(), // TODO: parameters }; + Ok(res) } async fn do_action_close_prepared_statement( &self, query: ActionClosePreparedStatementRequest, - request: Request, + _: Request, ) -> Result<(), Status> { - let ctx = self.get_ctx(&request).await?; - let mut ctx = ctx.lock().await; let handle = std::str::from_utf8(&query.prepared_statement_handle) - .map_err(|e| status!("Unable to parse handle", e))?; - ctx.remove_portal(handle); + .map_err(|e| RpcsrvError::ParseError(e.to_string()))?; + + self.sessions.remove(handle); Ok(()) } @@ -541,6 +306,8 @@ impl FlightSqlService for FlightSessionHandler { pub struct ActionExecutePhysicalPlan { #[prost(bytes, tag = "1")] pub plan: Vec, + #[prost(string, tag = "2")] + pub handle: String, } impl ProstMessageExt for ActionExecutePhysicalPlan { From bbfe3db77001c74ed1047250aa826304d8667f93 Mon Sep 17 00:00:00 2001 From: universalmind303 Date: Thu, 7 Dec 2023 11:11:28 -0600 Subject: [PATCH 03/15] cleanup --- adbc_flight.py | 30 ++++++++++++++++++++++++++++++ crates/glaredb/src/local.rs | 2 +- crates/pgsrv/src/handler.rs | 5 +---- crates/rpcsrv/src/handler.rs | 2 +- crates/sqlexec/src/engine.rs | 34 ++++++++++++---------------------- crates/sqlexec/src/session.rs | 31 ++++++++++++++----------------- crates/testing/src/slt/test.rs | 4 +--- 7 files changed, 60 insertions(+), 48 deletions(-) create mode 100644 adbc_flight.py diff --git a/adbc_flight.py b/adbc_flight.py new file mode 100644 index 000000000..5a61c87f7 --- /dev/null +++ b/adbc_flight.py @@ -0,0 +1,30 @@ +import adbc_driver_flightsql.dbapi +import polars as pl + +with adbc_driver_flightsql.dbapi.connect("grpc://0.0.0.0:6789") as conn: + cursor = conn.cursor() + cursor.execute( + "SELECT * from '/Users/corygrinstead/Development/glaredb/testdata/csv/userdata1.csv'" + ) + res = cursor.fetch_arrow_table() + print(pl.from_arrow(res)) + + with adbc_driver_flightsql.dbapi.connect("grpc://0.0.0.0:6789") as conn2: + cursor = conn2.cursor() + cursor.execute( + "SELECT * from '/Users/corygrinstead/Development/glaredb/testdata/csv/userdata1.csv'" + ) + res = cursor.fetch_arrow_table() + print(pl.from_arrow(res)) + cursor.close() + + with adbc_driver_flightsql.dbapi.connect("grpc://0.0.0.0:6789") as conn3: + cursor = conn3.cursor() + cursor.execute( + "SELECT * from '/Users/corygrinstead/Development/glaredb/testdata/csv/userdata1.csv'" + ) + res = cursor.fetch_arrow_table() + print(pl.from_arrow(res)) + cursor.close() + cursor.close() + diff --git a/crates/glaredb/src/local.rs b/crates/glaredb/src/local.rs index d69841227..5bed84f54 100644 --- a/crates/glaredb/src/local.rs +++ b/crates/glaredb/src/local.rs @@ -207,7 +207,7 @@ impl LocalSession { let statements = self.sess.parse_query(text)?; for stmt in statements { self.sess - .prepare_statement(UNNAMED, Some(stmt), Vec::new()) + .prepare_statement(UNNAMED, stmt, Vec::new()) .await?; let prepared = self.sess.get_prepared_statement(&UNNAMED)?; let num_fields = prepared.output_fields().map(|f| f.len()).unwrap_or(0); diff --git a/crates/pgsrv/src/handler.rs b/crates/pgsrv/src/handler.rs index bc1d42d79..35a1c34e6 100644 --- a/crates/pgsrv/src/handler.rs +++ b/crates/pgsrv/src/handler.rs @@ -461,10 +461,7 @@ where const UNNAMED: String = String::new(); // Parse... - if let Err(e) = session - .prepare_statement(UNNAMED, Some(stmt), Vec::new()) - .await - { + if let Err(e) = session.prepare_statement(UNNAMED, stmt, Vec::new()).await { self.send_error(e.into()).await?; return self.ready_for_query().await; }; diff --git a/crates/rpcsrv/src/handler.rs b/crates/rpcsrv/src/handler.rs index e7ce1544f..0458ede8d 100644 --- a/crates/rpcsrv/src/handler.rs +++ b/crates/rpcsrv/src/handler.rs @@ -328,7 +328,7 @@ impl Stream for ExecutionResponseBatchStream { /// lifetime of a query. pub struct SimpleHandler { /// Core db engine for creating sessions. - pub engine: Arc, + engine: Arc, } impl SimpleHandler { diff --git a/crates/sqlexec/src/engine.rs b/crates/sqlexec/src/engine.rs index 3b467dd7d..85097f795 100644 --- a/crates/sqlexec/src/engine.rs +++ b/crates/sqlexec/src/engine.rs @@ -370,34 +370,16 @@ impl Engine { self.session_counter.load(Ordering::Relaxed) } - /// Create a new tracked local session. - /// - pub async fn new_local_session_context( - &self, - vars: SessionVars, - storage: SessionStorageConfig, - ) -> Result { - let session = self.new_untracked_session_context(vars, storage).await?; - let prev = self.session_counter.fetch_add(1, Ordering::Relaxed); - debug!(session_count = prev + 1, "new session opened"); - - Ok(TrackedSession { - inner: session, - session_counter: self.session_counter.clone(), - }) - } - /// Create a new local session, initializing it with the provided session /// variables. - /// Unlike [`new_local_session_context`] this doesn't track the session. // TODO: This is _very_ easy to mess up with the vars since we implement // default (which defaults to the nil uuid), but using default would is // incorrect in any case we're running Cloud. - pub async fn new_untracked_session_context( + pub async fn new_local_session_context( &self, vars: SessionVars, storage: SessionStorageConfig, - ) -> Result { + ) -> Result { let database_id = vars.database_id(); let metastore = self.supervisor.init_client(database_id).await?; let native = self @@ -413,7 +395,7 @@ impl Engine { }, ); - Session::new( + let session = Session::new( vars, catalog, metastore.into(), @@ -421,7 +403,15 @@ impl Engine { self.tracker.clone(), self.spill_path.clone(), self.task_scheduler.clone(), - ) + )?; + + let prev = self.session_counter.fetch_add(1, Ordering::Relaxed); + debug!(session_count = prev + 1, "new session opened"); + + Ok(TrackedSession { + inner: session, + session_counter: self.session_counter.clone(), + }) } /// Create a new remote session for plan execution. diff --git a/crates/sqlexec/src/session.rs b/crates/sqlexec/src/session.rs index d42938f87..fd6a62520 100644 --- a/crates/sqlexec/src/session.rs +++ b/crates/sqlexec/src/session.rs @@ -10,7 +10,6 @@ use crate::distexec::scheduler::{OutputSink, Scheduler}; use crate::distexec::stream::create_coalescing_adapter; use crate::environment::EnvironmentReader; use crate::errors::{ExecError, Result}; -use crate::extension_codec::GlareDBExtensionCodec; use crate::parser::StatementWithExtensions; use crate::planner::logical_plan::*; use crate::planner::physical_plan::{ @@ -19,7 +18,6 @@ use crate::planner::physical_plan::{ }; use crate::remote::client::RemoteClient; use crate::remote::planner::{DDLExtensionPlanner, RemotePhysicalPlanner}; -use crate::remote::provider_cache::ProviderCache; use catalog::mutator::CatalogMutator; use catalog::session_catalog::SessionCatalog; use datafusion::arrow::datatypes::Schema; @@ -111,16 +109,18 @@ pub enum ExecutionResult { /// Credentials are dropped. DropCredentials, } -pub struct PreparedStmt { +// this just makes the `prepare_statement` method a bit more ergonomic. +pub struct PrepareStatementArg { stmt: Option, } -impl<'a> TryFrom<&'a str> for PreparedStmt { + +impl<'a> TryFrom<&'a str> for PrepareStatementArg { type Error = ExecError; fn try_from(query: &'a str) -> Result { let mut statements = crate::parser::parse_sql(query)?; match statements.len() { 0 => Err(ExecError::String("No statements in query".to_string())), - 1 => Ok(PreparedStmt { + 1 => Ok(PrepareStatementArg { stmt: statements.pop_front(), }), _ => Err(ExecError::String( @@ -130,7 +130,7 @@ impl<'a> TryFrom<&'a str> for PreparedStmt { } } -impl<'a> TryFrom<&'a String> for PreparedStmt { +impl<'a> TryFrom<&'a String> for PrepareStatementArg { type Error = ExecError; fn try_from(query: &'a String) -> Result { let s: &str = query; @@ -138,18 +138,19 @@ impl<'a> TryFrom<&'a String> for PreparedStmt { } } -impl TryFrom> for PreparedStmt { +impl TryFrom> for PrepareStatementArg { type Error = ExecError; fn try_from(stmt: Option) -> Result { - Ok(PreparedStmt { stmt }) + Ok(PrepareStatementArg { stmt }) } } -impl TryFrom for PreparedStmt { +impl TryFrom for PrepareStatementArg { type Error = ExecError; fn try_from(stmt: StatementWithExtensions) -> Result { - Ok(PreparedStmt { stmt: Some(stmt) }) + Ok(PrepareStatementArg { stmt: Some(stmt) }) } } + impl ExecutionResult { /// Create a result from a stream and a physical plan. /// @@ -497,13 +498,13 @@ impl Session { } /// Prepare a parsed statement for future execution. - pub async fn prepare_statement>( + pub async fn prepare_statement>( &mut self, name: String, stmt: T, params: Vec, // OIDs ) -> Result<()> { - let stmt: PreparedStmt = stmt.try_into()?; + let stmt: PrepareStatementArg = stmt.try_into()?; self.ctx.prepare_statement(name, stmt.stmt, params).await } @@ -745,6 +746,7 @@ impl Session { datafusion_ext::vars::Dialect::Prql => crate::parser::parse_prql(query), } } + pub async fn prepare_portal(&mut self, handle: &str, query: &str) -> Result<()> { self.prepare_statement(handle.to_string(), query, Vec::new()) .await?; @@ -759,9 +761,4 @@ impl Session { )?; Ok(()) } - /// Returns the extension codec used for serializing and deserializing data - /// over RPCs. - pub fn extension_codec<'a>(&self, cache: &'a ProviderCache) -> GlareDBExtensionCodec<'a> { - GlareDBExtensionCodec::new_decoder(cache, self.ctx.df_ctx().runtime_env()) - } } diff --git a/crates/testing/src/slt/test.rs b/crates/testing/src/slt/test.rs index 3f1f6cccd..90b78889c 100644 --- a/crates/testing/src/slt/test.rs +++ b/crates/testing/src/slt/test.rs @@ -289,9 +289,7 @@ impl AsyncDB for TestClient { let statements = session.parse_query(sql)?; for stmt in statements { - session - .prepare_statement(UNNAMED, Some(stmt), Vec::new()) - .await?; + session.prepare_statement(UNNAMED, stmt, Vec::new()).await?; let prepared = session.get_prepared_statement(&UNNAMED)?; let num_fields = prepared.output_fields().map(|f| f.len()).unwrap_or(0); session.bind_statement( From 73560d0a5d317298b7256853b90d2351204776b2 Mon Sep 17 00:00:00 2001 From: universalmind303 Date: Thu, 7 Dec 2023 11:12:27 -0600 Subject: [PATCH 04/15] clippy --- crates/rpcsrv/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/rpcsrv/src/lib.rs b/crates/rpcsrv/src/lib.rs index 90f1ca469..68f7a2446 100644 --- a/crates/rpcsrv/src/lib.rs +++ b/crates/rpcsrv/src/lib.rs @@ -1,5 +1,5 @@ pub mod errors; +pub mod flight_handler; pub mod handler; pub mod proxy; -pub mod flight_handler; mod session; From 031352d2897828367792a1d1631027c21e2aa2d1 Mon Sep 17 00:00:00 2001 From: universalmind303 Date: Thu, 7 Dec 2023 11:12:58 -0600 Subject: [PATCH 05/15] remove accidentally commited file --- adbc_flight.py | 30 ------------------------------ 1 file changed, 30 deletions(-) delete mode 100644 adbc_flight.py diff --git a/adbc_flight.py b/adbc_flight.py deleted file mode 100644 index 5a61c87f7..000000000 --- a/adbc_flight.py +++ /dev/null @@ -1,30 +0,0 @@ -import adbc_driver_flightsql.dbapi -import polars as pl - -with adbc_driver_flightsql.dbapi.connect("grpc://0.0.0.0:6789") as conn: - cursor = conn.cursor() - cursor.execute( - "SELECT * from '/Users/corygrinstead/Development/glaredb/testdata/csv/userdata1.csv'" - ) - res = cursor.fetch_arrow_table() - print(pl.from_arrow(res)) - - with adbc_driver_flightsql.dbapi.connect("grpc://0.0.0.0:6789") as conn2: - cursor = conn2.cursor() - cursor.execute( - "SELECT * from '/Users/corygrinstead/Development/glaredb/testdata/csv/userdata1.csv'" - ) - res = cursor.fetch_arrow_table() - print(pl.from_arrow(res)) - cursor.close() - - with adbc_driver_flightsql.dbapi.connect("grpc://0.0.0.0:6789") as conn3: - cursor = conn3.cursor() - cursor.execute( - "SELECT * from '/Users/corygrinstead/Development/glaredb/testdata/csv/userdata1.csv'" - ) - res = cursor.fetch_arrow_table() - print(pl.from_arrow(res)) - cursor.close() - cursor.close() - From 1f5c882f52197d5bc55ab0df9cca527e3a898ede Mon Sep 17 00:00:00 2001 From: universalmind303 Date: Thu, 7 Dec 2023 11:54:19 -0600 Subject: [PATCH 06/15] more cleanup --- adbc_flight.py | 11 +++++++++++ crates/rpcsrv/src/flight_handler.rs | 9 ++------- 2 files changed, 13 insertions(+), 7 deletions(-) create mode 100644 adbc_flight.py diff --git a/adbc_flight.py b/adbc_flight.py new file mode 100644 index 000000000..4c8cd1222 --- /dev/null +++ b/adbc_flight.py @@ -0,0 +1,11 @@ +import adbc_driver_flightsql.dbapi +import polars as pl + +with adbc_driver_flightsql.dbapi.connect("grpc://0.0.0.0:6789") as conn: + cursor = conn.cursor() + cursor.execute( + "SELECT * from './testdata/csv/userdata1.csv'" + ) + res = cursor.fetch_arrow_table() + print(pl.from_arrow(res)) + cursor.close() \ No newline at end of file diff --git a/crates/rpcsrv/src/flight_handler.rs b/crates/rpcsrv/src/flight_handler.rs index 128ebca54..622c2b226 100644 --- a/crates/rpcsrv/src/flight_handler.rs +++ b/crates/rpcsrv/src/flight_handler.rs @@ -60,13 +60,6 @@ impl FlightSessionHandler { let sess = self.sessions.get(handle).unwrap().clone(); return Ok(sess); } - let uuid = Uuid::parse_str(handle) - .map_err(|e| RpcsrvError::ParseError(format!("Error parsing uuid: {e}")))?; - let _ = self - .engine - .new_remote_session_context(uuid, SessionStorageConfig::default()) - .await - .map_err(RpcsrvError::from)?; let sess = self .engine @@ -76,6 +69,7 @@ impl FlightSessionHandler { let sess = Arc::new(Mutex::new(sess)); self.sessions.insert(handle.to_string(), sess.clone()); + Ok(sess) } @@ -83,6 +77,7 @@ impl FlightSessionHandler { let exec_ctx = engine .new_remote_session_context(Uuid::new_v4(), SessionStorageConfig::default()) .await?; + Ok(Self { engine: engine.clone(), remote_ctx: Arc::new(exec_ctx), From 61280182a96855179aee1a1439915d664febd902 Mon Sep 17 00:00:00 2001 From: universalmind303 Date: Thu, 7 Dec 2023 11:54:47 -0600 Subject: [PATCH 07/15] remove accidentally commited file --- adbc_flight.py | 11 ----------- 1 file changed, 11 deletions(-) delete mode 100644 adbc_flight.py diff --git a/adbc_flight.py b/adbc_flight.py deleted file mode 100644 index 4c8cd1222..000000000 --- a/adbc_flight.py +++ /dev/null @@ -1,11 +0,0 @@ -import adbc_driver_flightsql.dbapi -import polars as pl - -with adbc_driver_flightsql.dbapi.connect("grpc://0.0.0.0:6789") as conn: - cursor = conn.cursor() - cursor.execute( - "SELECT * from './testdata/csv/userdata1.csv'" - ) - res = cursor.fetch_arrow_table() - print(pl.from_arrow(res)) - cursor.close() \ No newline at end of file From ae54f1f9d603f3f2b8782d8cff7d19b39e83e34b Mon Sep 17 00:00:00 2001 From: universalmind303 Date: Thu, 7 Dec 2023 11:57:42 -0600 Subject: [PATCH 08/15] add todo comment --- crates/rpcsrv/src/flight_handler.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/rpcsrv/src/flight_handler.rs b/crates/rpcsrv/src/flight_handler.rs index 622c2b226..0e926cb3a 100644 --- a/crates/rpcsrv/src/flight_handler.rs +++ b/crates/rpcsrv/src/flight_handler.rs @@ -103,7 +103,7 @@ impl FlightSqlService for FlightSessionHandler { Response> + Send>>>, Status, > { - todo!() + todo!("support TLS") } async fn do_get_fallback( From 1fb144994c0c49b208bec450885f512d7643cb32 Mon Sep 17 00:00:00 2001 From: universalmind303 Date: Fri, 8 Dec 2023 10:24:27 -0600 Subject: [PATCH 09/15] better session handling --- crates/glaredb/src/server.rs | 1 - crates/rpcsrv/src/flight_handler.rs | 265 +++++++++++++++++++--------- crates/sqlexec/src/engine.rs | 33 ++-- crates/sqlexec/src/session.rs | 2 +- 4 files changed, 207 insertions(+), 94 deletions(-) diff --git a/crates/glaredb/src/server.rs b/crates/glaredb/src/server.rs index f72c6a8a2..029c79bc6 100644 --- a/crates/glaredb/src/server.rs +++ b/crates/glaredb/src/server.rs @@ -215,7 +215,6 @@ impl ComputeServer { let handler = SimpleHandler::new(self.engine.clone()); server = server.add_service(SimpleServiceServer::new(handler)); } - if let Err(e) = server.serve(addr).await { // TODO: Maybe panic instead? Revisit once we have // everything working. diff --git a/crates/rpcsrv/src/flight_handler.rs b/crates/rpcsrv/src/flight_handler.rs index 0e926cb3a..01a6cdb21 100644 --- a/crates/rpcsrv/src/flight_handler.rs +++ b/crates/rpcsrv/src/flight_handler.rs @@ -1,14 +1,15 @@ use crate::errors::{Result, RpcsrvError}; use dashmap::DashMap; -use datafusion::arrow::ipc::writer::IpcWriteOptions; +use datafusion::{arrow::ipc::writer::IpcWriteOptions, physical_plan::ExecutionPlan}; use datafusion_ext::vars::SessionVars; use datafusion_proto::protobuf::PhysicalPlanNode; use once_cell::sync::Lazy; use sqlexec::{ context::remote::RemoteSessionContext, - engine::{Engine, SessionStorageConfig, TrackedSession}, + engine::{Engine, SessionStorageConfig}, extension_codec::GlareDBExtensionCodec, + session::Session, OperationInfo, }; use std::{pin::Pin, sync::Arc}; @@ -41,54 +42,148 @@ static INSTANCE_SQL_DATA: Lazy = Lazy::new(|| { builder.build().unwrap() }); -macro_rules! status { - ($desc:expr, $err:expr) => { - Status::internal(format!("{}: {} at {}:{}", $desc, $err, file!(), line!())) - }; -} +/// Custom header clients can use to specify the database they want to connect to. +/// the ADBC driver requires it to be passed in as `adbc.flight.sql.rpc.call_header.x-database` +const DATABASE_HEADER: &str = "x-database"; pub struct FlightSessionHandler { engine: Arc, - /// The remote context is used to execute the queries in a stateless manner. - remote_ctx: Arc, - sessions: DashMap>>, + /// TODO: currently, we aren't removing these sessions, so this will grow forever. + /// there's no close/shutdown hook, so the sessions can at most only be tied to a single transaction, not a connection. + /// We'll want to implement a time based eviction policy, or a max size. + remote_sessions: DashMap>, + // We use [`Session`] instead of [`TrackedSession`] because tracked sessions need to be + // explicitly closed, and we don't have a way to do that yet. + sessions: DashMap>>, } impl FlightSessionHandler { - async fn create_ctx(&self, handle: &str) -> Result>, Status> { - if self.sessions.contains_key(handle) { - let sess = self.sessions.get(handle).unwrap().clone(); - return Ok(sess); + pub async fn try_new(engine: &Arc) -> Result { + Ok(Self { + engine: engine.clone(), + remote_sessions: DashMap::new(), + sessions: DashMap::new(), + }) + } + + async fn get_or_create_ctx( + &self, + request: &Request, + ) -> Result<(String, Arc>), Status> { + let db_handle = request + .metadata() + .get(DATABASE_HEADER) + .map(|s| s.to_str().unwrap().to_string()) + .unwrap_or_else(|| Uuid::default().to_string()); + if self.sessions.contains_key(&db_handle) { + let sess = self.sessions.get(&db_handle).unwrap().clone(); + return Ok((db_handle, sess)); } + let db_id = Uuid::parse_str(&db_handle).map_err(|e| { + Status::internal(format!( + "Unable to parse database handle: {}", + e.to_string() + )) + })?; + + let session_vars = + SessionVars::default().with_database_id(db_id, datafusion::variable::VarType::System); let sess = self .engine - .new_local_session_context(SessionVars::default(), SessionStorageConfig::default()) + .new_untracked_session(session_vars, SessionStorageConfig::default()) .await .map_err(RpcsrvError::from)?; + let remote_sess = self + .engine + .new_remote_session_context(db_id, SessionStorageConfig::default()) + .await + .map_err(RpcsrvError::from)?; let sess = Arc::new(Mutex::new(sess)); - self.sessions.insert(handle.to_string(), sess.clone()); + self.sessions.insert(db_handle.clone(), sess.clone()); + self.remote_sessions + .insert(db_handle.clone(), Arc::new(remote_sess)); - Ok(sess) + Ok((db_handle, sess)) } - pub async fn try_new(engine: &Arc) -> Result { - let exec_ctx = engine - .new_remote_session_context(Uuid::new_v4(), SessionStorageConfig::default()) - .await?; + async fn get_exec_ctx( + &self, + request: &Request, + ) -> Result, Status> { + let db_handle = request + .metadata() + .get(DATABASE_HEADER) + .map(|s| s.to_str().unwrap().to_string()) + .unwrap_or_else(|| Uuid::default().to_string()); + + if let Some(sess) = self.remote_sessions.get(&db_handle) { + Ok(sess.clone()) + } else { + Err(Status::internal(format!( + "Unable to find session with handle {}", + db_handle + ))) + } + } - Ok(Self { - engine: engine.clone(), - remote_ctx: Arc::new(exec_ctx), - sessions: DashMap::new(), - }) + async fn get_ctx(&self, request: &Request) -> Result>, Status> { + let db_handle = request + .metadata() + .get(DATABASE_HEADER) + .map(|s| s.to_str().unwrap().to_string()) + .unwrap_or_else(|| Uuid::default().to_string()); + + self.sessions + .get(&db_handle) + .map(|s| s.clone()) + .ok_or_else(|| { + Status::internal(format!("Unable to find session with handle {}", db_handle)) + }) } - async fn get_ctx(&self, handle: &str) -> Result>, Status> { - self.sessions.get(handle).map(|s| s.clone()).ok_or_else(|| { - Status::internal(format!("Unable to find session with handle {}", handle)) - }) + async fn do_action_execute_physical_plan( + &self, + req: &Request, + query: ActionExecutePhysicalPlan, + ) -> Result::DoGetStream>, Status> { + let ctx = self.get_exec_ctx(&req).await?; + + let plan = PhysicalPlanNode::try_decode(&query.plan).map_err(RpcsrvError::from)?; + let codec = ctx.extension_codec(); + + let plan = plan + .try_into_physical_plan( + ctx.get_datafusion_context(), + ctx.get_datafusion_context().runtime_env().as_ref(), + &codec, + ) + .map_err(RpcsrvError::from)?; + + self.execute_physical_plan(req, plan).await + } + + async fn execute_physical_plan( + &self, + request: &Request, + physical_plan: Arc, + ) -> Result::DoGetStream>, Status> { + let ctx = self.get_exec_ctx(&request).await?; + let stream = ctx + .execute_physical(physical_plan) + .map_err(RpcsrvError::from)?; + + let schema = stream.schema(); + let stream = + stream.map_err(|e| arrow_flight::error::FlightError::ExternalError(Box::new(e))); + + let stream = FlightDataEncoderBuilder::new() + .with_schema(schema) + .build(stream) + .map_err(Status::from); + + Ok(Response::new(Box::pin(stream))) } } @@ -108,47 +203,52 @@ impl FlightSqlService for FlightSessionHandler { async fn do_get_fallback( &self, - _: Request, + req: Request, message: Any, ) -> Result::DoGetStream>, Status> { - if !message.is::() { - Err(Status::unimplemented(format!( - "do_get: The defined request is invalid: {}", - message.type_url - )))? + match message.type_url.as_str() { + ActionExecutePhysicalPlan::TYPE_URL => { + let action: ActionExecutePhysicalPlan = message + .unpack() + .map_err(RpcsrvError::from)? + .ok_or_else(|| { + Status::internal("Expected ActionExecutePhysicalPlan but got None!") + })?; + + return self.do_action_execute_physical_plan(&req, action).await; + } + // All non specified types should be handled as a sql query + other => { + let mut ctx = self + .engine + .new_local_session_context( + SessionVars::default(), + SessionStorageConfig::default(), + ) + .await + .map_err(RpcsrvError::from)?; + match sqlexec::parser::parse_sql(&other) { + Ok(statements) => { + let lp = ctx + .parsed_to_lp(statements) + .await + .map_err(RpcsrvError::from)?; + let physical = ctx + .create_physical_plan( + lp.try_into_datafusion_plan().map_err(RpcsrvError::from)?, + &OperationInfo::default(), + ) + .await + .map_err(RpcsrvError::from)?; + self.execute_physical_plan(&req, physical).await + } + Err(e) => Err(Status::internal(format!( + "Expected a SQL query, instead received: {}", + e.to_string() + ))), + } + } } - let plan: ActionExecutePhysicalPlan = message - .unpack() - .map_err(RpcsrvError::from)? - .ok_or_else(|| Status::internal("Expected FetchResults but got None!"))?; - let ActionExecutePhysicalPlan { plan, handle } = plan; - - let ctx = self.get_ctx(&handle).await?; - - let ctx = ctx.lock().await; - - let plan = PhysicalPlanNode::try_decode(&plan).map_err(RpcsrvError::from)?; - let codec = self.remote_ctx.extension_codec(); - - let plan = plan - .try_into_physical_plan(ctx.df_ctx(), ctx.df_ctx().runtime_env().as_ref(), &codec) - .map_err(RpcsrvError::from)?; - - let stream = self - .remote_ctx - .execute_physical(plan) - .map_err(RpcsrvError::from)?; - - let schema = stream.schema(); - let stream = - stream.map_err(|e| arrow_flight::error::FlightError::ExternalError(Box::new(e))); - - let stream = FlightDataEncoderBuilder::new() - .with_schema(schema) - .build(stream) - .map_err(Status::from); - - Ok(Response::new(Box::pin(stream))) } async fn get_flight_info_statement( @@ -179,15 +279,12 @@ impl FlightSqlService for FlightSessionHandler { async fn get_flight_info_prepared_statement( &self, - cmd: CommandPreparedStatementQuery, - _: Request, + _: CommandPreparedStatementQuery, + req: Request, ) -> Result, Status> { - let handle = std::str::from_utf8(&cmd.prepared_statement_handle) - .map_err(|e| status!("Unable to parse handle", e))?; - - let ctx = self.get_ctx(handle).await?; + let (handle, ctx) = self.get_or_create_ctx(&req).await?; let ctx = ctx.lock().await; - let portal = ctx.get_portal(handle).map_err(RpcsrvError::from)?; + let portal = ctx.get_portal(&handle).map_err(RpcsrvError::from)?; let plan = portal.logical_plan().unwrap(); @@ -250,10 +347,9 @@ impl FlightSqlService for FlightSessionHandler { async fn do_action_create_prepared_statement( &self, query: ActionCreatePreparedStatementRequest, - _: Request, + req: Request, ) -> Result { - let handle = uuid::Uuid::new_v4().to_string(); - let ctx = self.create_ctx(&handle).await?; + let (handle, ctx) = self.get_or_create_ctx(&req).await?; let mut ctx = ctx.lock().await; ctx.prepare_portal(&handle, &query.query) @@ -284,12 +380,12 @@ impl FlightSqlService for FlightSessionHandler { async fn do_action_close_prepared_statement( &self, query: ActionClosePreparedStatementRequest, - _: Request, + req: Request, ) -> Result<(), Status> { let handle = std::str::from_utf8(&query.prepared_statement_handle) .map_err(|e| RpcsrvError::ParseError(e.to_string()))?; - - self.sessions.remove(handle); + let ctx = self.get_ctx(&req).await?; + ctx.lock().await.remove_portal(handle); Ok(()) } @@ -305,9 +401,14 @@ pub struct ActionExecutePhysicalPlan { pub handle: String, } +impl ActionExecutePhysicalPlan { + pub const TYPE_URL: &'static str = + "type.googleapis.com/glaredb.rpcsrv.ActionExecutePhysicalPlan"; +} + impl ProstMessageExt for ActionExecutePhysicalPlan { fn type_url() -> &'static str { - "type.googleapis.com/glaredb.rpcsrv.ActionExecutePhysicalPlan" + Self::TYPE_URL } fn as_any(&self) -> Any { diff --git a/crates/sqlexec/src/engine.rs b/crates/sqlexec/src/engine.rs index 85097f795..b1919429c 100644 --- a/crates/sqlexec/src/engine.rs +++ b/crates/sqlexec/src/engine.rs @@ -266,6 +266,7 @@ impl EngineStorageConfig { } /// Hold configuration and clients needed to create database sessions. +/// An engine is able to support multiple [`Session`]'s across multiple db instances pub struct Engine { /// Metastore client supervisor. supervisor: MetastoreClientSupervisor, @@ -380,6 +381,26 @@ impl Engine { vars: SessionVars, storage: SessionStorageConfig, ) -> Result { + let session = self.new_untracked_session(vars, storage).await?; + + let prev = self.session_counter.fetch_add(1, Ordering::Relaxed); + debug!(session_count = prev + 1, "new session opened"); + + Ok(TrackedSession { + inner: session, + session_counter: self.session_counter.clone(), + }) + } + + /// Create a new untracked session. + /// + /// This does not increment the session counter. + /// So any session created with this method will not prevent the engine from shutting down. + pub async fn new_untracked_session( + &self, + vars: SessionVars, + storage: SessionStorageConfig, + ) -> Result { let database_id = vars.database_id(); let metastore = self.supervisor.init_client(database_id).await?; let native = self @@ -395,7 +416,7 @@ impl Engine { }, ); - let session = Session::new( + Session::new( vars, catalog, metastore.into(), @@ -403,15 +424,7 @@ impl Engine { self.tracker.clone(), self.spill_path.clone(), self.task_scheduler.clone(), - )?; - - let prev = self.session_counter.fetch_add(1, Ordering::Relaxed); - debug!(session_count = prev + 1, "new session opened"); - - Ok(TrackedSession { - inner: session, - session_counter: self.session_counter.clone(), - }) + ) } /// Create a new remote session for plan execution. diff --git a/crates/sqlexec/src/session.rs b/crates/sqlexec/src/session.rs index fd6a62520..4b6567cf1 100644 --- a/crates/sqlexec/src/session.rs +++ b/crates/sqlexec/src/session.rs @@ -757,7 +757,7 @@ impl Session { handle.to_string(), handle, Vec::new(), - vec![Format::Binary; num_fields], + vec![Format::Text; num_fields], )?; Ok(()) } From 843941338aaa393d3aa16d838a1fea0386e9fe2e Mon Sep 17 00:00:00 2001 From: universalmind303 Date: Fri, 8 Dec 2023 10:29:22 -0600 Subject: [PATCH 10/15] cleanup --- crates/glaredb/src/server.rs | 3 ++- crates/rpcsrv/src/flight_handler.rs | 6 +++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/crates/glaredb/src/server.rs b/crates/glaredb/src/server.rs index 029c79bc6..4012df1a5 100644 --- a/crates/glaredb/src/server.rs +++ b/crates/glaredb/src/server.rs @@ -201,7 +201,8 @@ impl ComputeServer { self.disable_rpc_auth, self.integration_testing, ); - let flight_handler = FlightSessionHandler::try_new(&self.engine).await?; + + let flight_handler = FlightSessionHandler::new(&self.engine); tokio::spawn(async move { let mut server = Server::builder() diff --git a/crates/rpcsrv/src/flight_handler.rs b/crates/rpcsrv/src/flight_handler.rs index 01a6cdb21..e2759bb02 100644 --- a/crates/rpcsrv/src/flight_handler.rs +++ b/crates/rpcsrv/src/flight_handler.rs @@ -58,12 +58,12 @@ pub struct FlightSessionHandler { } impl FlightSessionHandler { - pub async fn try_new(engine: &Arc) -> Result { - Ok(Self { + pub fn new(engine: &Arc) -> Self { + Self { engine: engine.clone(), remote_sessions: DashMap::new(), sessions: DashMap::new(), - }) + } } async fn get_or_create_ctx( From a75efff4dcb15d3dc429ef3195f0515c9c47483c Mon Sep 17 00:00:00 2001 From: universalmind303 Date: Fri, 8 Dec 2023 10:47:04 -0600 Subject: [PATCH 11/15] cleanup some comments --- crates/rpcsrv/src/flight_handler.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/crates/rpcsrv/src/flight_handler.rs b/crates/rpcsrv/src/flight_handler.rs index e2759bb02..e4e0d7379 100644 --- a/crates/rpcsrv/src/flight_handler.rs +++ b/crates/rpcsrv/src/flight_handler.rs @@ -321,6 +321,8 @@ impl FlightSqlService for FlightSessionHandler { let endpoint = FlightEndpoint::new().with_ticket(ticket); // Ideally, we'd start the execution here, but instead we defer it all to the "do_get" call. + // Eventually, we should asynchronously start the execution here, + // and return a `Ticket` that contains information on how to retrieve the results. let flight_info = FlightInfo::new() .with_descriptor(FlightDescriptor::new_cmd(vec![])) .with_endpoint(endpoint); @@ -343,7 +345,6 @@ impl FlightSqlService for FlightSessionHandler { Ok(tonic::Response::new(flight_info)) } - // I think it's safe to create a session for the duration of the prepared statement? async fn do_action_create_prepared_statement( &self, query: ActionCreatePreparedStatementRequest, From 02a395d32c9ec5cec42dfd459a20d0f367f738ac Mon Sep 17 00:00:00 2001 From: universalmind303 Date: Fri, 8 Dec 2023 10:47:19 -0600 Subject: [PATCH 12/15] fmt --- crates/sqlexec/src/engine.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/sqlexec/src/engine.rs b/crates/sqlexec/src/engine.rs index b1919429c..3d0b96488 100644 --- a/crates/sqlexec/src/engine.rs +++ b/crates/sqlexec/src/engine.rs @@ -393,7 +393,7 @@ impl Engine { } /// Create a new untracked session. - /// + /// /// This does not increment the session counter. /// So any session created with this method will not prevent the engine from shutting down. pub async fn new_untracked_session( From 92d3ac0b4e0c14a2eb40210a0fea2801276ca34d Mon Sep 17 00:00:00 2001 From: universalmind303 Date: Fri, 8 Dec 2023 11:01:13 -0600 Subject: [PATCH 13/15] clippy --- crates/rpcsrv/src/flight_handler.rs | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/crates/rpcsrv/src/flight_handler.rs b/crates/rpcsrv/src/flight_handler.rs index e4e0d7379..ba5560a2f 100644 --- a/crates/rpcsrv/src/flight_handler.rs +++ b/crates/rpcsrv/src/flight_handler.rs @@ -80,12 +80,8 @@ impl FlightSessionHandler { return Ok((db_handle, sess)); } - let db_id = Uuid::parse_str(&db_handle).map_err(|e| { - Status::internal(format!( - "Unable to parse database handle: {}", - e.to_string() - )) - })?; + let db_id = Uuid::parse_str(&db_handle) + .map_err(|e| Status::internal(format!("Unable to parse database handle: {e}")))?; let session_vars = SessionVars::default().with_database_id(db_id, datafusion::variable::VarType::System); @@ -148,7 +144,7 @@ impl FlightSessionHandler { req: &Request, query: ActionExecutePhysicalPlan, ) -> Result::DoGetStream>, Status> { - let ctx = self.get_exec_ctx(&req).await?; + let ctx = self.get_exec_ctx(req).await?; let plan = PhysicalPlanNode::try_decode(&query.plan).map_err(RpcsrvError::from)?; let codec = ctx.extension_codec(); @@ -169,7 +165,7 @@ impl FlightSessionHandler { request: &Request, physical_plan: Arc, ) -> Result::DoGetStream>, Status> { - let ctx = self.get_exec_ctx(&request).await?; + let ctx = self.get_exec_ctx(request).await?; let stream = ctx .execute_physical(physical_plan) .map_err(RpcsrvError::from)?; @@ -227,7 +223,7 @@ impl FlightSqlService for FlightSessionHandler { ) .await .map_err(RpcsrvError::from)?; - match sqlexec::parser::parse_sql(&other) { + match sqlexec::parser::parse_sql(other) { Ok(statements) => { let lp = ctx .parsed_to_lp(statements) @@ -243,8 +239,7 @@ impl FlightSqlService for FlightSessionHandler { self.execute_physical_plan(&req, physical).await } Err(e) => Err(Status::internal(format!( - "Expected a SQL query, instead received: {}", - e.to_string() + "Expected a SQL query, instead received: {e}" ))), } } From 1de23c206ecf294812b975f283158e3113bb018f Mon Sep 17 00:00:00 2001 From: universalmind303 Date: Fri, 8 Dec 2023 12:15:28 -0600 Subject: [PATCH 14/15] refactor away remote sessions --- crates/rpcsrv/src/flight_handler.rs | 153 ++++++++-------------------- crates/rpcsrv/src/handler.rs | 1 - 2 files changed, 43 insertions(+), 111 deletions(-) diff --git a/crates/rpcsrv/src/flight_handler.rs b/crates/rpcsrv/src/flight_handler.rs index ba5560a2f..a75669da6 100644 --- a/crates/rpcsrv/src/flight_handler.rs +++ b/crates/rpcsrv/src/flight_handler.rs @@ -1,14 +1,11 @@ use crate::errors::{Result, RpcsrvError}; use dashmap::DashMap; -use datafusion::{arrow::ipc::writer::IpcWriteOptions, physical_plan::ExecutionPlan}; +use datafusion::{arrow::ipc::writer::IpcWriteOptions, logical_expr::LogicalPlan}; use datafusion_ext::vars::SessionVars; -use datafusion_proto::protobuf::PhysicalPlanNode; use once_cell::sync::Lazy; use sqlexec::{ - context::remote::RemoteSessionContext, engine::{Engine, SessionStorageConfig}, - extension_codec::GlareDBExtensionCodec, session::Session, OperationInfo, }; @@ -27,7 +24,6 @@ use arrow_flight::{ }, HandshakeRequest, HandshakeResponse, }; -use datafusion_proto::physical_plan::AsExecutionPlan; use futures::TryStreamExt; use futures::{lock::Mutex, Stream}; use prost::Message; @@ -48,10 +44,11 @@ const DATABASE_HEADER: &str = "x-database"; pub struct FlightSessionHandler { engine: Arc, - /// TODO: currently, we aren't removing these sessions, so this will grow forever. - /// there's no close/shutdown hook, so the sessions can at most only be tied to a single transaction, not a connection. - /// We'll want to implement a time based eviction policy, or a max size. - remote_sessions: DashMap>, + // since plans can be tied to any session, we can't use a single session to store them. + logical_plans: DashMap, + // TODO: currently, we aren't removing these sessions, so this will grow forever. + // there's no close/shutdown hook, so the sessions can at most only be tied to a single transaction, not a connection. + // We'll want to implement a time based eviction policy, or a max size. // We use [`Session`] instead of [`TrackedSession`] because tracked sessions need to be // explicitly closed, and we don't have a way to do that yet. sessions: DashMap>>, @@ -61,7 +58,7 @@ impl FlightSessionHandler { pub fn new(engine: &Arc) -> Self { Self { engine: engine.clone(), - remote_sessions: DashMap::new(), + logical_plans: DashMap::new(), sessions: DashMap::new(), } } @@ -74,7 +71,8 @@ impl FlightSessionHandler { .metadata() .get(DATABASE_HEADER) .map(|s| s.to_str().unwrap().to_string()) - .unwrap_or_else(|| Uuid::default().to_string()); + .unwrap_or_else(|| Uuid::new_v4().to_string()); + if self.sessions.contains_key(&db_handle) { let sess = self.sessions.get(&db_handle).unwrap().clone(); return Ok((db_handle, sess)); @@ -90,84 +88,41 @@ impl FlightSessionHandler { .new_untracked_session(session_vars, SessionStorageConfig::default()) .await .map_err(RpcsrvError::from)?; - - let remote_sess = self - .engine - .new_remote_session_context(db_id, SessionStorageConfig::default()) - .await - .map_err(RpcsrvError::from)?; let sess = Arc::new(Mutex::new(sess)); self.sessions.insert(db_handle.clone(), sess.clone()); - self.remote_sessions - .insert(db_handle.clone(), Arc::new(remote_sess)); Ok((db_handle, sess)) } - async fn get_exec_ctx( - &self, - request: &Request, - ) -> Result, Status> { - let db_handle = request - .metadata() - .get(DATABASE_HEADER) - .map(|s| s.to_str().unwrap().to_string()) - .unwrap_or_else(|| Uuid::default().to_string()); - - if let Some(sess) = self.remote_sessions.get(&db_handle) { - Ok(sess.clone()) - } else { - Err(Status::internal(format!( - "Unable to find session with handle {}", - db_handle - ))) - } - } - - async fn get_ctx(&self, request: &Request) -> Result>, Status> { - let db_handle = request - .metadata() - .get(DATABASE_HEADER) - .map(|s| s.to_str().unwrap().to_string()) - .unwrap_or_else(|| Uuid::default().to_string()); - - self.sessions - .get(&db_handle) - .map(|s| s.clone()) - .ok_or_else(|| { - Status::internal(format!("Unable to find session with handle {}", db_handle)) - }) - } - async fn do_action_execute_physical_plan( &self, req: &Request, - query: ActionExecutePhysicalPlan, + query: ActionExecuteLogicalPlan, ) -> Result::DoGetStream>, Status> { - let ctx = self.get_exec_ctx(req).await?; - - let plan = PhysicalPlanNode::try_decode(&query.plan).map_err(RpcsrvError::from)?; - let codec = ctx.extension_codec(); - - let plan = plan - .try_into_physical_plan( - ctx.get_datafusion_context(), - ctx.get_datafusion_context().runtime_env().as_ref(), - &codec, - ) - .map_err(RpcsrvError::from)?; - - self.execute_physical_plan(req, plan).await + let ActionExecuteLogicalPlan { handle } = query; + let lp = self + .logical_plans + .get(&handle) + .ok_or_else(|| Status::internal(format!("Unable to find logical plan {}", handle)))? + .clone(); + self.execute_lp(req, lp).await } - async fn execute_physical_plan( + async fn execute_lp( &self, request: &Request, - physical_plan: Arc, + lp: LogicalPlan, ) -> Result::DoGetStream>, Status> { - let ctx = self.get_exec_ctx(request).await?; + let (_, ctx) = self.get_or_create_ctx(request).await?; + let ctx = ctx.lock().await; + let plan = ctx + .create_physical_plan(lp, &OperationInfo::default()) + .await + .map_err(RpcsrvError::from)?; + let stream = ctx - .execute_physical(physical_plan) + .execute_physical(plan) + .await .map_err(RpcsrvError::from)?; let schema = stream.schema(); @@ -203,8 +158,8 @@ impl FlightSqlService for FlightSessionHandler { message: Any, ) -> Result::DoGetStream>, Status> { match message.type_url.as_str() { - ActionExecutePhysicalPlan::TYPE_URL => { - let action: ActionExecutePhysicalPlan = message + ActionExecuteLogicalPlan::TYPE_URL => { + let action: ActionExecuteLogicalPlan = message .unpack() .map_err(RpcsrvError::from)? .ok_or_else(|| { @@ -228,15 +183,11 @@ impl FlightSqlService for FlightSessionHandler { let lp = ctx .parsed_to_lp(statements) .await + .map_err(RpcsrvError::from)? + .try_into_datafusion_plan() .map_err(RpcsrvError::from)?; - let physical = ctx - .create_physical_plan( - lp.try_into_datafusion_plan().map_err(RpcsrvError::from)?, - &OperationInfo::default(), - ) - .await - .map_err(RpcsrvError::from)?; - self.execute_physical_plan(&req, physical).await + + self.execute_lp(&req, lp).await } Err(e) => Err(Status::internal(format!( "Expected a SQL query, instead received: {e}" @@ -288,26 +239,9 @@ impl FlightSqlService for FlightSessionHandler { .try_into_datafusion_plan() .map_err(RpcsrvError::from)?; - let physical = ctx - .create_physical_plan(plan, &OperationInfo::default()) - .await - .map_err(RpcsrvError::from)?; - - // Encode the physical plan into a protobuf message. - let physical_plan = { - let node = PhysicalPlanNode::try_from_physical_plan( - physical, - &GlareDBExtensionCodec::new_encoder(), - ) - .map_err(RpcsrvError::from)?; - - let mut buf = Vec::new(); - node.try_encode(&mut buf).map_err(RpcsrvError::from)?; - buf - }; + self.logical_plans.insert(handle.clone(), plan); - let action = ActionExecutePhysicalPlan { - plan: physical_plan, + let action = ActionExecuteLogicalPlan { handle: handle.to_string(), }; @@ -380,8 +314,9 @@ impl FlightSqlService for FlightSessionHandler { ) -> Result<(), Status> { let handle = std::str::from_utf8(&query.prepared_statement_handle) .map_err(|e| RpcsrvError::ParseError(e.to_string()))?; - let ctx = self.get_ctx(&req).await?; + let (_, ctx) = self.get_or_create_ctx(&req).await?; ctx.lock().await.remove_portal(handle); + self.logical_plans.remove(handle); Ok(()) } @@ -390,26 +325,24 @@ impl FlightSqlService for FlightSessionHandler { } #[derive(Clone, PartialEq, ::prost::Message)] -pub struct ActionExecutePhysicalPlan { - #[prost(bytes, tag = "1")] - pub plan: Vec, +pub struct ActionExecuteLogicalPlan { #[prost(string, tag = "2")] pub handle: String, } -impl ActionExecutePhysicalPlan { +impl ActionExecuteLogicalPlan { pub const TYPE_URL: &'static str = - "type.googleapis.com/glaredb.rpcsrv.ActionExecutePhysicalPlan"; + "type.googleapis.com/glaredb.rpcsrv.ActionExecuteLogicalPlan"; } -impl ProstMessageExt for ActionExecutePhysicalPlan { +impl ProstMessageExt for ActionExecuteLogicalPlan { fn type_url() -> &'static str { Self::TYPE_URL } fn as_any(&self) -> Any { Any { - type_url: ActionExecutePhysicalPlan::type_url().to_string(), + type_url: ActionExecuteLogicalPlan::type_url().to_string(), value: ::prost::Message::encode_to_vec(self).into(), } } diff --git a/crates/rpcsrv/src/handler.rs b/crates/rpcsrv/src/handler.rs index 0458ede8d..a1b28952a 100644 --- a/crates/rpcsrv/src/handler.rs +++ b/crates/rpcsrv/src/handler.rs @@ -39,7 +39,6 @@ use tonic::{Request, Response, Status, Streaming}; use tracing::info; use uuid::Uuid; -#[derive(Clone)] pub struct RpcHandler { /// Core db engine for creating sessions. engine: Arc, From eb82f2883629f91e1b6fe7ee4310e9d8eaae17ba Mon Sep 17 00:00:00 2001 From: universalmind303 Date: Fri, 8 Dec 2023 12:18:42 -0600 Subject: [PATCH 15/15] cleanup --- crates/rpcsrv/src/flight_handler.rs | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/crates/rpcsrv/src/flight_handler.rs b/crates/rpcsrv/src/flight_handler.rs index a75669da6..46d00a99e 100644 --- a/crates/rpcsrv/src/flight_handler.rs +++ b/crates/rpcsrv/src/flight_handler.rs @@ -168,16 +168,11 @@ impl FlightSqlService for FlightSessionHandler { return self.do_action_execute_physical_plan(&req, action).await; } + // All non specified types should be handled as a sql query other => { - let mut ctx = self - .engine - .new_local_session_context( - SessionVars::default(), - SessionStorageConfig::default(), - ) - .await - .map_err(RpcsrvError::from)?; + let (_, ctx) = self.get_or_create_ctx(&req).await?; + let mut ctx = ctx.lock().await; match sqlexec::parser::parse_sql(other) { Ok(statements) => { let lp = ctx