diff --git a/datafusion-postgres/src/client.rs b/datafusion-postgres/src/client.rs new file mode 100644 index 0000000..7c1bab0 --- /dev/null +++ b/datafusion-postgres/src/client.rs @@ -0,0 +1,54 @@ +use pgwire::api::ClientInfo; + +// Metadata keys for session-level settings +const METADATA_STATEMENT_TIMEOUT: &str = "statement_timeout_ms"; +const METADATA_TIMEZONE: &str = "timezone"; + +/// Get statement timeout from client metadata +pub fn get_statement_timeout(client: &C) -> Option +where + C: ClientInfo + ?Sized, +{ + client + .metadata() + .get(METADATA_STATEMENT_TIMEOUT) + .and_then(|s| s.parse::().ok()) + .map(std::time::Duration::from_millis) +} + +/// Set statement timeout in client metadata +pub fn set_statement_timeout(client: &mut C, timeout: Option) +where + C: ClientInfo + ?Sized, +{ + let metadata = client.metadata_mut(); + if let Some(duration) = timeout { + metadata.insert( + METADATA_STATEMENT_TIMEOUT.to_string(), + duration.as_millis().to_string(), + ); + } else { + metadata.remove(METADATA_STATEMENT_TIMEOUT); + } +} + +/// Get statement timeout from client metadata +pub fn get_timezone(client: &C) -> Option<&str> +where + C: ClientInfo + ?Sized, +{ + client.metadata().get(METADATA_TIMEZONE).map(|s| s.as_str()) +} + +/// Set statement timeout in client metadata +pub fn set_timezone(client: &mut C, timezone: Option<&str>) +where + C: ClientInfo + ?Sized, +{ + let metadata = client.metadata_mut(); + if let Some(timezone) = timezone { + metadata.insert(METADATA_TIMEZONE.to_string(), timezone.to_string()); + } else { + metadata.remove(METADATA_TIMEZONE); + } +} diff --git a/datafusion-postgres/src/handlers.rs b/datafusion-postgres/src/handlers.rs index 0a8ef0d..bc43b4f 100644 --- a/datafusion-postgres/src/handlers.rs +++ b/datafusion-postgres/src/handlers.rs @@ -8,61 +8,30 @@ use datafusion::error::DataFusionError; use datafusion::logical_expr::LogicalPlan; use datafusion::prelude::*; use datafusion::sql::parser::Statement; -use log::{info, warn}; +use datafusion::sql::sqlparser; +use log::info; use pgwire::api::auth::noop::NoopStartupHandler; use pgwire::api::auth::StartupHandler; use pgwire::api::portal::{Format, Portal}; use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler}; use pgwire::api::results::{ - DescribePortalResponse, DescribeResponse, DescribeStatementResponse, FieldFormat, FieldInfo, - QueryResponse, Response, Tag, + DescribePortalResponse, DescribeResponse, DescribeStatementResponse, Response, Tag, }; use pgwire::api::stmt::QueryParser; use pgwire::api::stmt::StoredStatement; use pgwire::api::{ClientInfo, ErrorHandler, PgWireServerHandlers, Type}; use pgwire::error::{PgWireError, PgWireResult}; use pgwire::messages::response::TransactionStatus; -use tokio::sync::Mutex; use crate::auth::AuthManager; +use crate::client; +use crate::hooks::set_show::SetShowHook; +use crate::hooks::QueryHook; use arrow_pg::datatypes::df; use arrow_pg::datatypes::{arrow_schema_to_pg_fields, into_pg_type}; -use datafusion::sql::sqlparser; use datafusion_pg_catalog::pg_catalog::context::{Permission, ResourceType}; use datafusion_pg_catalog::sql::PostgresCompatibilityParser; -#[async_trait] -pub trait QueryHook: Send + Sync { - /// called in simple query handler to return response directly - async fn handle_simple_query( - &self, - statement: &sqlparser::ast::Statement, - session_context: &SessionContext, - client: &(dyn ClientInfo + Send + Sync), - ) -> Option>; - - /// called at extended query parse phase, for generating `LogicalPlan`from statement - async fn handle_extended_parse_query( - &self, - statement: &sqlparser::ast::Statement, - session_context: &SessionContext, - client: &(dyn ClientInfo + Send + Sync), - ) -> Option>; - - /// called at extended query execute phase, for query execution - async fn handle_extended_query( - &self, - statement: &sqlparser::ast::Statement, - logical_plan: &LogicalPlan, - params: &ParamValues, - session_context: &SessionContext, - client: &(dyn ClientInfo + Send + Sync), - ) -> Option>; -} - -// Metadata keys for session-level settings -const METADATA_STATEMENT_TIMEOUT: &str = "statement_timeout_ms"; - /// Simple startup handler that does no authentication /// For production, use DfAuthSource with proper pgwire authentication handlers pub struct SimpleStartupHandler; @@ -75,12 +44,18 @@ pub struct HandlerFactory { } impl HandlerFactory { - pub fn new( + pub fn new(session_context: Arc, auth_manager: Arc) -> Self { + let session_service = + Arc::new(DfSessionService::new(session_context, auth_manager.clone())); + HandlerFactory { session_service } + } + + pub fn new_with_hooks( session_context: Arc, auth_manager: Arc, query_hooks: Vec>, ) -> Self { - let session_service = Arc::new(DfSessionService::new( + let session_service = Arc::new(DfSessionService::new_with_hooks( session_context, auth_manager.clone(), query_hooks, @@ -122,7 +97,6 @@ impl ErrorHandler for LoggingErrorHandler { pub struct DfSessionService { session_context: Arc, parser: Arc, - timezone: Arc>, auth_manager: Arc, query_hooks: Vec>, } @@ -131,6 +105,14 @@ impl DfSessionService { pub fn new( session_context: Arc, auth_manager: Arc, + ) -> DfSessionService { + let hooks: Vec> = vec![Arc::new(SetShowHook)]; + Self::new_with_hooks(session_context, auth_manager, hooks) + } + + pub fn new_with_hooks( + session_context: Arc, + auth_manager: Arc, query_hooks: Vec>, ) -> DfSessionService { let parser = Arc::new(Parser { @@ -141,40 +123,11 @@ impl DfSessionService { DfSessionService { session_context, parser, - timezone: Arc::new(Mutex::new("UTC".to_string())), auth_manager, query_hooks, } } - /// Get statement timeout from client metadata - fn get_statement_timeout(client: &C) -> Option - where - C: ClientInfo, - { - client - .metadata() - .get(METADATA_STATEMENT_TIMEOUT) - .and_then(|s| s.parse::().ok()) - .map(std::time::Duration::from_millis) - } - - /// Set statement timeout in client metadata - fn set_statement_timeout(client: &mut C, timeout: Option) - where - C: ClientInfo, - { - let metadata = client.metadata_mut(); - if let Some(duration) = timeout { - metadata.insert( - METADATA_STATEMENT_TIMEOUT.to_string(), - duration.as_millis().to_string(), - ); - } else { - metadata.remove(METADATA_STATEMENT_TIMEOUT); - } - } - /// Check if the current user has permission to execute a query async fn check_query_permission(&self, client: &C, query: &str) -> PgWireResult<()> where @@ -250,107 +203,6 @@ impl DfSessionService { ResourceType::All } - fn mock_show_response(name: &str, value: &str) -> PgWireResult { - let fields = vec![FieldInfo::new( - name.to_string(), - None, - None, - Type::VARCHAR, - FieldFormat::Text, - )]; - - let row = { - let mut encoder = pgwire::api::results::DataRowEncoder::new(Arc::new(fields.clone())); - encoder.encode_field(&Some(value))?; - encoder.finish() - }; - - let row_stream = futures::stream::once(async move { row }); - Ok(QueryResponse::new(Arc::new(fields), Box::pin(row_stream))) - } - - async fn try_respond_set_statements( - &self, - client: &mut C, - query_lower: &str, - ) -> PgWireResult> - where - C: ClientInfo, - { - if query_lower.starts_with("set") { - if query_lower.starts_with("set time zone") { - let parts: Vec<&str> = query_lower.split_whitespace().collect(); - if parts.len() >= 4 { - let tz = parts[3].trim_matches('"'); - let mut timezone = self.timezone.lock().await; - *timezone = tz.to_string(); - Ok(Some(Response::Execution(Tag::new("SET")))) - } else { - Err(PgWireError::UserError(Box::new( - pgwire::error::ErrorInfo::new( - "ERROR".to_string(), - "42601".to_string(), - "Invalid SET TIME ZONE syntax".to_string(), - ), - ))) - } - } else if query_lower.starts_with("set statement_timeout") { - let parts: Vec<&str> = query_lower.split_whitespace().collect(); - if parts.len() >= 3 { - let timeout_str = parts[2].trim_matches('"').trim_matches('\''); - - let timeout = if timeout_str == "0" || timeout_str.is_empty() { - None - } else { - // Parse timeout value (supports ms, s, min formats) - let timeout_ms = if timeout_str.ends_with("ms") { - timeout_str.trim_end_matches("ms").parse::() - } else if timeout_str.ends_with("s") { - timeout_str - .trim_end_matches("s") - .parse::() - .map(|s| s * 1000) - } else if timeout_str.ends_with("min") { - timeout_str - .trim_end_matches("min") - .parse::() - .map(|m| m * 60 * 1000) - } else { - // Default to milliseconds - timeout_str.parse::() - }; - - match timeout_ms { - Ok(ms) if ms > 0 => Some(std::time::Duration::from_millis(ms)), - _ => None, - } - }; - - Self::set_statement_timeout(client, timeout); - Ok(Some(Response::Execution(Tag::new("SET")))) - } else { - Err(PgWireError::UserError(Box::new( - pgwire::error::ErrorInfo::new( - "ERROR".to_string(), - "42601".to_string(), - "Invalid SET statement_timeout syntax".to_string(), - ), - ))) - } - } else { - // pass SET query to datafusion - if let Err(e) = self.session_context.sql(query_lower).await { - warn!("SET statement {query_lower} is not supported by datafusion, error {e}, statement ignored"); - } - - // Always return SET success - Ok(Some(Response::Execution(Tag::new("SET")))) - } - } else { - Ok(None) - } - } - async fn try_respond_transaction_statements( &self, client: &C, @@ -401,65 +253,6 @@ impl DfSessionService { _ => Ok(None), } } - - async fn try_respond_show_statements( - &self, - client: &C, - query_lower: &str, - ) -> PgWireResult> - where - C: ClientInfo, - { - if query_lower.starts_with("show ") { - match query_lower.strip_suffix(";").unwrap_or(query_lower) { - "show time zone" => { - let timezone = self.timezone.lock().await.clone(); - let resp = Self::mock_show_response("TimeZone", &timezone)?; - Ok(Some(Response::Query(resp))) - } - "show server_version" => { - let resp = Self::mock_show_response("server_version", "15.0 (DataFusion)")?; - Ok(Some(Response::Query(resp))) - } - "show transaction_isolation" => { - let resp = - Self::mock_show_response("transaction_isolation", "read uncommitted")?; - Ok(Some(Response::Query(resp))) - } - "show catalogs" => { - let catalogs = self.session_context.catalog_names(); - let value = catalogs.join(", "); - let resp = Self::mock_show_response("Catalogs", &value)?; - Ok(Some(Response::Query(resp))) - } - "show search_path" => { - let default_schema = "public"; - let resp = Self::mock_show_response("search_path", default_schema)?; - Ok(Some(Response::Query(resp))) - } - "show statement_timeout" => { - let timeout = Self::get_statement_timeout(client); - let timeout_str = match timeout { - Some(duration) => format!("{}ms", duration.as_millis()), - None => "0".to_string(), - }; - let resp = Self::mock_show_response("statement_timeout", &timeout_str)?; - Ok(Some(Response::Query(resp))) - } - "show transaction isolation level" => { - let resp = Self::mock_show_response("transaction_isolation", "read_committed")?; - Ok(Some(Response::Query(resp))) - } - _ => { - info!("Unsupported show statement: {query_lower}"); - let resp = Self::mock_show_response("unsupported_show_statement", "")?; - Ok(Some(Response::Query(resp))) - } - } - } else { - Ok(None) - } - } } #[async_trait] @@ -520,20 +313,6 @@ impl SimpleQueryHandler for DfSessionService { } } - if let Some(resp) = self - .try_respond_set_statements(client, &query_lower) - .await? - { - return Ok(vec![resp]); - } - - if let Some(resp) = self - .try_respond_show_statements(client, &query_lower) - .await? - { - return Ok(vec![resp]); - } - // Check if we're in a failed transaction and block non-transaction // commands if client.transaction_status() == TransactionStatus::Error { @@ -547,7 +326,7 @@ impl SimpleQueryHandler for DfSessionService { } let df_result = { - let timeout = Self::get_statement_timeout(client); + let timeout = client::get_statement_timeout(client); if let Some(timeout_duration) = timeout { tokio::time::timeout(timeout_duration, self.session_context.sql(&query)) .await @@ -696,10 +475,6 @@ impl ExtendedQueryHandler for DfSessionService { .await?; } - if let Some(resp) = self.try_respond_set_statements(client, &query).await? { - return Ok(resp); - } - if let Some(resp) = self .try_respond_transaction_statements(client, &query) .await? @@ -707,10 +482,6 @@ impl ExtendedQueryHandler for DfSessionService { return Ok(resp); } - if let Some(resp) = self.try_respond_show_statements(client, &query).await? { - return Ok(resp); - } - // Check if we're in a failed transaction and block non-transaction // commands if client.transaction_status() == TransactionStatus::Error { @@ -743,7 +514,7 @@ impl ExtendedQueryHandler for DfSessionService { .map_err(|e| PgWireError::ApiError(Box::new(e)))?; let dataframe = { - let timeout = Self::get_statement_timeout(client); + let timeout = client::get_statement_timeout(client); if let Some(timeout_duration) = timeout { tokio::time::timeout( timeout_duration, @@ -845,7 +616,6 @@ impl Parser { // show statement may not be supported by datafusion if sql_trimmed.starts_with("show") { - // Return a dummy plan for transaction commands - they'll be handled by transaction handler let show_schema = Arc::new(Schema::new(vec![Field::new("show", DataType::Utf8, false)])); let df_schema = show_schema.to_dfschema()?; @@ -926,127 +696,10 @@ fn ordered_param_types(types: &HashMap>) -> Vec, - } - - impl MockClient { - fn new() -> Self { - Self { - metadata: HashMap::new(), - } - } - } - - impl ClientInfo for MockClient { - fn socket_addr(&self) -> std::net::SocketAddr { - "127.0.0.1:5432".parse().unwrap() - } - - fn is_secure(&self) -> bool { - false - } - - fn protocol_version(&self) -> pgwire::messages::ProtocolVersion { - pgwire::messages::ProtocolVersion::PROTOCOL3_0 - } - - fn set_protocol_version(&mut self, _version: pgwire::messages::ProtocolVersion) {} - - fn pid_and_secret_key(&self) -> (i32, pgwire::messages::startup::SecretKey) { - (0, pgwire::messages::startup::SecretKey::I32(0)) - } - - fn set_pid_and_secret_key( - &mut self, - _pid: i32, - _secret_key: pgwire::messages::startup::SecretKey, - ) { - } - - fn state(&self) -> pgwire::api::PgWireConnectionState { - pgwire::api::PgWireConnectionState::ReadyForQuery - } - - fn set_state(&mut self, _new_state: pgwire::api::PgWireConnectionState) {} - - fn transaction_status(&self) -> pgwire::messages::response::TransactionStatus { - pgwire::messages::response::TransactionStatus::Idle - } - - fn set_transaction_status( - &mut self, - _new_status: pgwire::messages::response::TransactionStatus, - ) { - } - - fn metadata(&self) -> &HashMap { - &self.metadata - } - - fn metadata_mut(&mut self) -> &mut HashMap { - &mut self.metadata - } - - fn client_certificates<'a>(&self) -> Option<&[rustls_pki_types::CertificateDer<'a>]> { - None - } - } - - #[tokio::test] - async fn test_statement_timeout_set_and_show() { - let session_context = Arc::new(SessionContext::new()); - let auth_manager = Arc::new(AuthManager::new()); - let service = DfSessionService::new(session_context, auth_manager, vec![]); - let mut client = MockClient::new(); - - // Test setting timeout to 5000ms - let set_response = service - .try_respond_set_statements(&mut client, "set statement_timeout '5000ms'") - .await - .unwrap(); - assert!(set_response.is_some()); - - // Verify the timeout was set in client metadata - let timeout = DfSessionService::get_statement_timeout(&client); - assert_eq!(timeout, Some(Duration::from_millis(5000))); - - // Test SHOW statement_timeout - let show_response = service - .try_respond_show_statements(&client, "show statement_timeout") - .await - .unwrap(); - assert!(show_response.is_some()); - } - - #[tokio::test] - async fn test_statement_timeout_disable() { - let session_context = Arc::new(SessionContext::new()); - let auth_manager = Arc::new(AuthManager::new()); - let service = DfSessionService::new(session_context, auth_manager, vec![]); - let mut client = MockClient::new(); - - // Set timeout first - service - .try_respond_set_statements(&mut client, "set statement_timeout '1000ms'") - .await - .unwrap(); - - // Disable timeout with 0 - service - .try_respond_set_statements(&mut client, "set statement_timeout '0'") - .await - .unwrap(); - - let timeout = DfSessionService::get_statement_timeout(&client); - assert_eq!(timeout, None); - } + use super::*; + use crate::testing::MockClient; struct TestHook; @@ -1056,7 +709,7 @@ mod tests { &self, statement: &sqlparser::ast::Statement, _ctx: &SessionContext, - _client: &(dyn ClientInfo + Sync + Send), + _client: &mut (dyn ClientInfo + Sync + Send), ) -> Option> { if statement.to_string().contains("magic") { Some(Ok(Response::EmptyQuery)) @@ -1080,9 +733,9 @@ mod tests { _logical_plan: &LogicalPlan, _params: &ParamValues, _session_context: &SessionContext, - _client: &(dyn ClientInfo + Send + Sync), + _client: &mut (dyn ClientInfo + Send + Sync), ) -> Option> { - todo!(); + None } } @@ -1090,7 +743,7 @@ mod tests { async fn test_query_hooks() { let hook = TestHook; let ctx = SessionContext::new(); - let client = MockClient::new(); + let mut client = MockClient::new(); // Parse a statement that contains "magic" let parser = PostgresCompatibilityParser::new(); @@ -1098,7 +751,7 @@ mod tests { let stmt = &statements[0]; // Hook should intercept - let result = hook.handle_simple_query(stmt, &ctx, &client).await; + let result = hook.handle_simple_query(stmt, &ctx, &mut client).await; assert!(result.is_some()); // Parse a normal statement @@ -1106,7 +759,7 @@ mod tests { let stmt = &statements[0]; // Hook should not intercept - let result = hook.handle_simple_query(stmt, &ctx, &client).await; + let result = hook.handle_simple_query(stmt, &ctx, &mut client).await; assert!(result.is_none()); } } diff --git a/datafusion-postgres/src/hooks/mod.rs b/datafusion-postgres/src/hooks/mod.rs new file mode 100644 index 0000000..6df8d6e --- /dev/null +++ b/datafusion-postgres/src/hooks/mod.rs @@ -0,0 +1,40 @@ +pub mod set_show; + +use async_trait::async_trait; + +use datafusion::common::ParamValues; +use datafusion::logical_expr::LogicalPlan; +use datafusion::prelude::SessionContext; +use datafusion::sql::sqlparser::ast::Statement; +use pgwire::api::results::Response; +use pgwire::api::ClientInfo; +use pgwire::error::PgWireResult; + +#[async_trait] +pub trait QueryHook: Send + Sync { + /// called in simple query handler to return response directly + async fn handle_simple_query( + &self, + statement: &Statement, + session_context: &SessionContext, + client: &mut (dyn ClientInfo + Send + Sync), + ) -> Option>; + + /// called at extended query parse phase, for generating `LogicalPlan`from statement + async fn handle_extended_parse_query( + &self, + sql: &Statement, + session_context: &SessionContext, + client: &(dyn ClientInfo + Send + Sync), + ) -> Option>; + + /// called at extended query execute phase, for query execution + async fn handle_extended_query( + &self, + statement: &Statement, + logical_plan: &LogicalPlan, + params: &ParamValues, + session_context: &SessionContext, + client: &mut (dyn ClientInfo + Send + Sync), + ) -> Option>; +} diff --git a/datafusion-postgres/src/hooks/set_show.rs b/datafusion-postgres/src/hooks/set_show.rs new file mode 100644 index 0000000..5141187 --- /dev/null +++ b/datafusion-postgres/src/hooks/set_show.rs @@ -0,0 +1,325 @@ +use std::sync::Arc; + +use async_trait::async_trait; +use datafusion::arrow::datatypes::{DataType, Field, Schema}; +use datafusion::common::{ParamValues, ToDFSchema}; +use datafusion::logical_expr::LogicalPlan; +use datafusion::prelude::SessionContext; +use datafusion::sql::sqlparser::ast::Statement; +use log::{info, warn}; +use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo, QueryResponse, Response, Tag}; +use pgwire::api::ClientInfo; +use pgwire::error::{PgWireError, PgWireResult}; +use postgres_types::Type; + +use crate::client; +use crate::QueryHook; + +#[derive(Debug)] +pub struct SetShowHook; + +#[async_trait] +impl QueryHook for SetShowHook { + /// called in simple query handler to return response directly + async fn handle_simple_query( + &self, + statement: &Statement, + session_context: &SessionContext, + client: &mut (dyn ClientInfo + Send + Sync), + ) -> Option> { + match statement { + Statement::Set { .. } => { + let query = statement.to_string(); + let query_lower = query.to_lowercase(); + + try_respond_set_statements(client, &query_lower, session_context).await + } + Statement::ShowVariable { .. } | Statement::ShowStatus { .. } => { + let query = statement.to_string(); + let query_lower = query.to_lowercase(); + + try_respond_show_statements(client, &query_lower, session_context).await + } + _ => None, + } + } + + async fn handle_extended_parse_query( + &self, + stmt: &Statement, + _session_context: &SessionContext, + _client: &(dyn ClientInfo + Send + Sync), + ) -> Option> { + let sql_lower = stmt.to_string().to_lowercase(); + let sql_trimmed = sql_lower.trim(); + + if sql_trimmed.starts_with("show") { + let show_schema = + Arc::new(Schema::new(vec![Field::new("show", DataType::Utf8, false)])); + let result = show_schema + .to_dfschema() + .map(|df_schema| { + LogicalPlan::EmptyRelation(datafusion::logical_expr::EmptyRelation { + produce_one_row: true, + schema: Arc::new(df_schema), + }) + }) + .map_err(|e| PgWireError::ApiError(Box::new(e))); + Some(result) + } else if sql_trimmed.starts_with("set") { + let show_schema = Arc::new(Schema::new(Vec::::new())); + let result = show_schema + .to_dfschema() + .map(|df_schema| { + LogicalPlan::EmptyRelation(datafusion::logical_expr::EmptyRelation { + produce_one_row: true, + schema: Arc::new(df_schema), + }) + }) + .map_err(|e| PgWireError::ApiError(Box::new(e))); + Some(result) + } else { + None + } + } + + async fn handle_extended_query( + &self, + statement: &Statement, + _logical_plan: &LogicalPlan, + _params: &ParamValues, + session_context: &SessionContext, + client: &mut (dyn ClientInfo + Send + Sync), + ) -> Option> { + match statement { + Statement::Set { .. } => { + let query = statement.to_string(); + let query_lower = query.to_lowercase(); + + try_respond_set_statements(client, &query_lower, session_context).await + } + Statement::ShowVariable { .. } | Statement::ShowStatus { .. } => { + let query = statement.to_string(); + let query_lower = query.to_lowercase(); + + try_respond_show_statements(client, &query_lower, session_context).await + } + _ => None, + } + } +} + +fn mock_show_response(name: &str, value: &str) -> PgWireResult { + let fields = vec![FieldInfo::new( + name.to_string(), + None, + None, + Type::VARCHAR, + FieldFormat::Text, + )]; + + let row = { + let mut encoder = DataRowEncoder::new(Arc::new(fields.clone())); + encoder.encode_field(&Some(value))?; + encoder.finish() + }; + + let row_stream = futures::stream::once(async move { row }); + Ok(QueryResponse::new(Arc::new(fields), Box::pin(row_stream))) +} + +async fn try_respond_set_statements( + client: &mut C, + query_lower: &str, + session_context: &SessionContext, +) -> Option> +where + C: ClientInfo + Send + Sync + ?Sized, +{ + if query_lower.starts_with("set") { + let result = if query_lower.starts_with("set time zone") { + let parts: Vec<&str> = query_lower.split_whitespace().collect(); + if parts.len() >= 4 { + let tz = parts[3].trim_matches('"'); + client::set_timezone(client, Some(tz)); + Ok(Response::Execution(Tag::new("SET"))) + } else { + Err(PgWireError::UserError(Box::new( + pgwire::error::ErrorInfo::new( + "ERROR".to_string(), + "42601".to_string(), + "Invalid SET TIME ZONE syntax".to_string(), + ), + ))) + } + } else if query_lower.starts_with("set statement_timeout") { + let parts: Vec<&str> = query_lower.split_whitespace().collect(); + if parts.len() >= 3 { + let timeout_str = parts[2].trim_matches('"').trim_matches('\''); + + let timeout = if timeout_str == "0" || timeout_str.is_empty() { + None + } else { + // Parse timeout value (supports ms, s, min formats) + let timeout_ms = if timeout_str.ends_with("ms") { + timeout_str.trim_end_matches("ms").parse::() + } else if timeout_str.ends_with("s") { + timeout_str + .trim_end_matches("s") + .parse::() + .map(|s| s * 1000) + } else if timeout_str.ends_with("min") { + timeout_str + .trim_end_matches("min") + .parse::() + .map(|m| m * 60 * 1000) + } else { + // Default to milliseconds + timeout_str.parse::() + }; + + match timeout_ms { + Ok(ms) if ms > 0 => Some(std::time::Duration::from_millis(ms)), + _ => None, + } + }; + + client::set_statement_timeout(client, timeout); + Ok(Response::Execution(Tag::new("SET"))) + } else { + Err(PgWireError::UserError(Box::new( + pgwire::error::ErrorInfo::new( + "ERROR".to_string(), + "42601".to_string(), + "Invalid SET statement_timeout syntax".to_string(), + ), + ))) + } + } else { + // pass SET query to datafusion + if let Err(e) = session_context.sql(query_lower).await { + warn!("SET statement {query_lower} is not supported by datafusion, error {e}, statement ignored"); + } + + // Always return SET success + Ok(Response::Execution(Tag::new("SET"))) + }; + + Some(result) + } else { + None + } +} + +async fn try_respond_show_statements( + client: &C, + query_lower: &str, + session_context: &SessionContext, +) -> Option> +where + C: ClientInfo + ?Sized, +{ + if query_lower.starts_with("show ") { + let result = match query_lower.strip_suffix(";").unwrap_or(query_lower) { + "show time zone" => { + let timezone = client::get_timezone(client).unwrap_or("UTC"); + mock_show_response("TimeZone", timezone).map(Response::Query) + } + "show server_version" => { + mock_show_response("server_version", "15.0 (DataFusion)").map(Response::Query) + } + "show transaction_isolation" => { + mock_show_response("transaction_isolation", "read uncommitted").map(Response::Query) + } + "show catalogs" => { + let catalogs = session_context.catalog_names(); + let value = catalogs.join(", "); + mock_show_response("Catalogs", &value).map(Response::Query) + } + "show search_path" => { + let default_schema = "public"; + mock_show_response("search_path", default_schema).map(Response::Query) + } + "show statement_timeout" => { + let timeout = client::get_statement_timeout(client); + let timeout_str = match timeout { + Some(duration) => format!("{}ms", duration.as_millis()), + None => "0".to_string(), + }; + mock_show_response("statement_timeout", &timeout_str).map(Response::Query) + } + "show transaction isolation level" => { + mock_show_response("transaction_isolation", "read_committed").map(Response::Query) + } + _ => { + info!("Unsupported show statement: {query_lower}"); + mock_show_response("unsupported_show_statement", "").map(Response::Query) + } + }; + Some(result) + } else { + None + } +} + +#[cfg(test)] +mod tests { + use std::time::Duration; + + use super::*; + use crate::testing::MockClient; + + #[tokio::test] + async fn test_statement_timeout_set_and_show() { + let session_context = SessionContext::new(); + let mut client = MockClient::new(); + + // Test setting timeout to 5000ms + let set_response = try_respond_set_statements( + &mut client, + "set statement_timeout '5000ms'", + &session_context, + ) + .await; + + assert!(set_response.is_some()); + assert!(set_response.unwrap().is_ok()); + + // Verify the timeout was set in client metadata + let timeout = client::get_statement_timeout(&client); + assert_eq!(timeout, Some(Duration::from_millis(5000))); + + // Test SHOW statement_timeout + let show_response = + try_respond_show_statements(&client, "show statement_timeout", &session_context).await; + + assert!(show_response.is_some()); + assert!(show_response.unwrap().is_ok()); + } + + #[tokio::test] + async fn test_statement_timeout_disable() { + let session_context = SessionContext::new(); + let mut client = MockClient::new(); + + // Set timeout first + let resp = try_respond_set_statements( + &mut client, + "set statement_timeout '1000ms'", + &session_context, + ) + .await; + assert!(resp.is_some()); + assert!(resp.unwrap().is_ok()); + + // Disable timeout with 0 + let resp = + try_respond_set_statements(&mut client, "set statement_timeout '0'", &session_context) + .await; + assert!(resp.is_some()); + assert!(resp.unwrap().is_ok()); + + let timeout = client::get_statement_timeout(&client); + assert_eq!(timeout, None); + } +} diff --git a/datafusion-postgres/src/lib.rs b/datafusion-postgres/src/lib.rs index cdfe6dd..4ced1fc 100644 --- a/datafusion-postgres/src/lib.rs +++ b/datafusion-postgres/src/lib.rs @@ -1,12 +1,15 @@ +pub mod auth; +pub(crate) mod client; mod handlers; +pub mod hooks; +#[cfg(any(test, debug_assertions))] +pub mod testing; use std::fs::File; use std::io::{BufReader, Error as IOError, ErrorKind}; use std::sync::Arc; use datafusion::prelude::SessionContext; - -pub mod auth; use getset::{Getters, Setters, WithSetters}; use log::{info, warn}; use pgwire::api::PgWireServerHandlers; @@ -20,7 +23,8 @@ use tokio_rustls::TlsAcceptor; use crate::auth::AuthManager; use handlers::HandlerFactory; -pub use handlers::{DfSessionService, Parser, QueryHook}; +pub use handlers::{DfSessionService, Parser}; +pub use hooks::QueryHook; /// re-exports pub use arrow_pg; @@ -85,7 +89,25 @@ pub async fn serve( auth_manager: Arc, ) -> Result<(), std::io::Error> { // Create the handler factory with authentication - let factory = Arc::new(HandlerFactory::new(session_context, auth_manager, vec![])); + let factory = Arc::new(HandlerFactory::new(session_context, auth_manager)); + + serve_with_handlers(factory, opts).await +} + +/// Serve the Datafusion `SessionContext` with Postgres protocol, using custom +/// query processing hooks. +pub async fn serve_with_hooks( + session_context: Arc, + opts: &ServerOptions, + auth_manager: Arc, + hooks: Vec>, +) -> Result<(), std::io::Error> { + // Create the handler factory with authentication + let factory = Arc::new(HandlerFactory::new_with_hooks( + session_context, + auth_manager, + hooks, + )); serve_with_handlers(factory, opts).await } diff --git a/datafusion-postgres/tests/common/mod.rs b/datafusion-postgres/src/testing.rs similarity index 94% rename from datafusion-postgres/tests/common/mod.rs rename to datafusion-postgres/src/testing.rs index 6c646ff..b4a51b5 100644 --- a/datafusion-postgres/tests/common/mod.rs +++ b/datafusion-postgres/src/testing.rs @@ -2,7 +2,6 @@ use std::{collections::HashMap, sync::Arc}; use datafusion::prelude::SessionContext; use datafusion_pg_catalog::pg_catalog::setup_pg_catalog; -use datafusion_postgres::{auth::AuthManager, DfSessionService}; use futures::Sink; use pgwire::{ api::{ClientInfo, ClientPortalStore, PgWireConnectionState, METADATA_USER}, @@ -11,6 +10,8 @@ use pgwire::{ }, }; +use crate::{auth::AuthManager, DfSessionService}; + pub fn setup_handlers() -> DfSessionService { let session_context = SessionContext::new(); setup_pg_catalog( @@ -20,11 +21,7 @@ pub fn setup_handlers() -> DfSessionService { ) .expect("Failed to setup sesession context"); - DfSessionService::new( - Arc::new(session_context), - Arc::new(AuthManager::new()), - vec![], - ) + DfSessionService::new(Arc::new(session_context), Arc::new(AuthManager::new())) } #[derive(Debug, Default)] diff --git a/datafusion-postgres/tests/dbeaver.rs b/datafusion-postgres/tests/dbeaver.rs index f99817e..6602b6d 100644 --- a/datafusion-postgres/tests/dbeaver.rs +++ b/datafusion-postgres/tests/dbeaver.rs @@ -1,8 +1,7 @@ -mod common; - -use common::*; use pgwire::api::query::SimpleQueryHandler; +use datafusion_postgres::testing::*; + const DBEAVER_QUERIES: &[&str] = &[ "SET extra_float_digits = 3", "SET application_name = 'PostgreSQL JDBC Driver'", diff --git a/datafusion-postgres/tests/metabase.rs b/datafusion-postgres/tests/metabase.rs index 3c15700..3e9b096 100644 --- a/datafusion-postgres/tests/metabase.rs +++ b/datafusion-postgres/tests/metabase.rs @@ -1,8 +1,7 @@ -mod common; - -use common::*; use pgwire::api::query::SimpleQueryHandler; +use datafusion_postgres::testing::*; + const METABASE_QUERIES: &[&str] = &[ "SET extra_float_digits = 2", "SET application_name = 'Metabase v0.55.1 [f8f63fdf-d8f8-4573-86ea-4fe4a9548041]'", diff --git a/datafusion-postgres/tests/pgcli.rs b/datafusion-postgres/tests/pgcli.rs index dc59e9f..cc15f1d 100644 --- a/datafusion-postgres/tests/pgcli.rs +++ b/datafusion-postgres/tests/pgcli.rs @@ -1,8 +1,7 @@ -mod common; - -use common::*; use pgwire::api::query::SimpleQueryHandler; +use datafusion_postgres::testing::*; + const PGCLI_QUERIES: &[&str] = &[ "SELECT 1", "show time zone", diff --git a/datafusion-postgres/tests/psql.rs b/datafusion-postgres/tests/psql.rs index d88649f..e960757 100644 --- a/datafusion-postgres/tests/psql.rs +++ b/datafusion-postgres/tests/psql.rs @@ -1,8 +1,7 @@ -mod common; - -use common::*; use pgwire::api::query::SimpleQueryHandler; +use datafusion_postgres::testing::*; + const PSQL_QUERIES: &[&str] = &[ "SELECT c.oid, n.nspname,