diff --git a/examples/bench.rs b/examples/bench.rs index e48b9085..70986127 100644 --- a/examples/bench.rs +++ b/examples/bench.rs @@ -6,6 +6,7 @@ use futures::StreamExt; use tokio::net::TcpListener; use pgwire::api::auth::noop::NoopStartupHandler; +use pgwire::api::copy::NoopCopyHandler; use pgwire::api::query::{PlaceholderExtendedQueryHandler, SimpleQueryHandler}; use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo, QueryResponse, Response}; use pgwire::api::{ClientInfo, MakeHandler, StatelessMakeHandler, Type}; @@ -73,6 +74,7 @@ pub async fn main() { PlaceholderExtendedQueryHandler, ))); let authenticator = Arc::new(StatelessMakeHandler::new(Arc::new(NoopStartupHandler))); + let noop_copy_handler = Arc::new(NoopCopyHandler); let server_addr = "127.0.0.1:5433"; let listener = TcpListener::bind(server_addr).await.unwrap(); @@ -82,6 +84,8 @@ pub async fn main() { let authenticator_ref = authenticator.make(); let processor_ref = processor.make(); let placeholder_ref = placeholder.make(); + let copy_handler_ref = noop_copy_handler.clone(); + tokio::spawn(async move { process_socket( incoming_socket.0, @@ -89,6 +93,7 @@ pub async fn main() { authenticator_ref, processor_ref, placeholder_ref, + copy_handler_ref, ) .await }); diff --git a/examples/duckdb.rs b/examples/duckdb.rs index d2caf07f..5f04494b 100644 --- a/examples/duckdb.rs +++ b/examples/duckdb.rs @@ -8,6 +8,7 @@ use futures::stream; use futures::Stream; use pgwire::api::auth::md5pass::{hash_md5_password, MakeMd5PasswordAuthStartupHandler}; use pgwire::api::auth::{AuthSource, DefaultServerParameterProvider, LoginInfo, Password}; +use pgwire::api::copy::NoopCopyHandler; use pgwire::api::portal::{Format, Portal}; use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler}; use pgwire::api::results::{ @@ -350,6 +351,7 @@ pub async fn main() { Arc::new(parameters), )); let processor = Arc::new(MakeDuckDBBackend::new()); + let noop_copy_handler = Arc::new(NoopCopyHandler); let server_addr = "127.0.0.1:5432"; let listener = TcpListener::bind(server_addr).await.unwrap(); @@ -358,6 +360,8 @@ pub async fn main() { let incoming_socket = listener.accept().await.unwrap(); let authenticator_ref = authenticator.make(); let processor_ref = processor.make(); + let copy_handler_ref = noop_copy_handler.clone(); + tokio::spawn(async move { process_socket( incoming_socket.0, @@ -365,6 +369,7 @@ pub async fn main() { authenticator_ref, processor_ref.clone(), processor_ref, + copy_handler_ref, ) .await }); diff --git a/examples/gluesql.rs b/examples/gluesql.rs index 8e955456..b81b714b 100644 --- a/examples/gluesql.rs +++ b/examples/gluesql.rs @@ -6,6 +6,7 @@ use tokio::net::TcpListener; use gluesql::prelude::*; use pgwire::api::auth::noop::NoopStartupHandler; +use pgwire::api::copy::NoopCopyHandler; use pgwire::api::query::{PlaceholderExtendedQueryHandler, SimpleQueryHandler}; use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo, QueryResponse, Response, Tag}; use pgwire::api::{ClientInfo, MakeHandler, StatelessMakeHandler, Type}; @@ -170,6 +171,7 @@ pub async fn main() { PlaceholderExtendedQueryHandler, ))); let authenticator = Arc::new(StatelessMakeHandler::new(Arc::new(NoopStartupHandler))); + let noop_copy_handler = Arc::new(NoopCopyHandler); let server_addr = "127.0.0.1:5432"; let listener = TcpListener::bind(server_addr).await.unwrap(); @@ -179,6 +181,8 @@ pub async fn main() { let authenticator_ref = authenticator.make(); let processor_ref = processor.make(); let placeholder_ref = placeholder.make(); + let copy_handler_ref = noop_copy_handler.clone(); + tokio::spawn(async move { process_socket( incoming_socket.0, @@ -186,6 +190,7 @@ pub async fn main() { authenticator_ref, processor_ref, placeholder_ref, + copy_handler_ref, ) .await }); diff --git a/examples/scram.rs b/examples/scram.rs index 6e4ad1a9..5c4b0024 100644 --- a/examples/scram.rs +++ b/examples/scram.rs @@ -12,6 +12,7 @@ use tokio_rustls::TlsAcceptor; use pgwire::api::auth::scram::{gen_salted_password, MakeSASLScramAuthStartupHandler}; use pgwire::api::auth::{AuthSource, DefaultServerParameterProvider, LoginInfo, Password}; +use pgwire::api::copy::NoopCopyHandler; use pgwire::api::query::{PlaceholderExtendedQueryHandler, SimpleQueryHandler}; use pgwire::api::results::{Response, Tag}; @@ -83,6 +84,7 @@ pub async fn main() { let placeholder = Arc::new(StatelessMakeHandler::new(Arc::new( PlaceholderExtendedQueryHandler, ))); + let noop_copy_handler = Arc::new(NoopCopyHandler); let mut authenticator = MakeSASLScramAuthStartupHandler::new( Arc::new(DummyAuthDB), Arc::new(DefaultServerParameterProvider::default()), @@ -103,6 +105,8 @@ pub async fn main() { let authenticator_ref = authenticator.make(); let processor_ref = processor.make(); let placeholder_ref = placeholder.make(); + let copy_handler_ref = noop_copy_handler.clone(); + tokio::spawn(async move { process_socket( incoming_socket.0, @@ -110,6 +114,7 @@ pub async fn main() { authenticator_ref, processor_ref, placeholder_ref, + copy_handler_ref, ) .await }); diff --git a/examples/secure_server.rs b/examples/secure_server.rs index 7af5d47d..8bfe3ece 100644 --- a/examples/secure_server.rs +++ b/examples/secure_server.rs @@ -11,6 +11,7 @@ use tokio_rustls::rustls::ServerConfig; use tokio_rustls::TlsAcceptor; use pgwire::api::auth::noop::NoopStartupHandler; +use pgwire::api::copy::NoopCopyHandler; use pgwire::api::query::{PlaceholderExtendedQueryHandler, SimpleQueryHandler}; use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo, QueryResponse, Response, Tag}; use pgwire::api::{ClientInfo, MakeHandler, StatelessMakeHandler, Type}; @@ -84,6 +85,7 @@ pub async fn main() { PlaceholderExtendedQueryHandler, ))); let authenticator = Arc::new(StatelessMakeHandler::new(Arc::new(NoopStartupHandler))); + let noop_copy_handler = Arc::new(NoopCopyHandler); let server_addr = "127.0.0.1:5433"; let tls_acceptor = Arc::new(setup_tls().unwrap()); @@ -96,6 +98,7 @@ pub async fn main() { let authenticator_ref = authenticator.make(); let processor_ref = processor.make(); let placeholder_ref = placeholder.make(); + let copy_handler_ref = noop_copy_handler.clone(); tokio::spawn(async move { process_socket( incoming_socket.0, @@ -103,6 +106,7 @@ pub async fn main() { authenticator_ref, processor_ref, placeholder_ref, + copy_handler_ref, ) .await }); diff --git a/examples/server.rs b/examples/server.rs index f7ea7031..f51350c3 100644 --- a/examples/server.rs +++ b/examples/server.rs @@ -6,6 +6,7 @@ use futures::{stream, Sink, SinkExt, StreamExt}; use tokio::net::TcpListener; use pgwire::api::auth::noop::NoopStartupHandler; +use pgwire::api::copy::NoopCopyHandler; use pgwire::api::query::{PlaceholderExtendedQueryHandler, SimpleQueryHandler}; use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo, QueryResponse, Response, Tag}; use pgwire::api::{ClientInfo, MakeHandler, StatelessMakeHandler, Type}; @@ -76,6 +77,7 @@ pub async fn main() { PlaceholderExtendedQueryHandler, ))); let authenticator = Arc::new(StatelessMakeHandler::new(Arc::new(NoopStartupHandler))); + let noop_copy_handler = Arc::new(NoopCopyHandler); let server_addr = "127.0.0.1:5432"; let listener = TcpListener::bind(server_addr).await.unwrap(); @@ -85,6 +87,7 @@ pub async fn main() { let authenticator_ref = authenticator.make(); let processor_ref = processor.make(); let placeholder_ref = placeholder.make(); + let copy_handler_ref = noop_copy_handler.clone(); tokio::spawn(async move { process_socket( incoming_socket.0, @@ -92,6 +95,7 @@ pub async fn main() { authenticator_ref, processor_ref, placeholder_ref, + copy_handler_ref, ) .await }); diff --git a/examples/sqlite.rs b/examples/sqlite.rs index a085cc70..8c45c24d 100644 --- a/examples/sqlite.rs +++ b/examples/sqlite.rs @@ -3,8 +3,10 @@ use std::sync::{Arc, Mutex}; use async_trait::async_trait; use futures::stream; use futures::Stream; + use pgwire::api::auth::md5pass::{hash_md5_password, MakeMd5PasswordAuthStartupHandler}; use pgwire::api::auth::{AuthSource, DefaultServerParameterProvider, LoginInfo, Password}; +use pgwire::api::copy::NoopCopyHandler; use pgwire::api::portal::{Format, Portal}; use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler}; use pgwire::api::results::{ @@ -306,6 +308,7 @@ pub async fn main() { Arc::new(parameters), )); let processor = Arc::new(MakeSqliteBackend::new()); + let noop_copy_handler = Arc::new(NoopCopyHandler); let server_addr = "127.0.0.1:5432"; let listener = TcpListener::bind(server_addr).await.unwrap(); @@ -314,6 +317,8 @@ pub async fn main() { let incoming_socket = listener.accept().await.unwrap(); let authenticator_ref = authenticator.make(); let processor_ref = processor.make(); + let copy_handler_ref = noop_copy_handler.clone(); + tokio::spawn(async move { process_socket( incoming_socket.0, @@ -321,6 +326,7 @@ pub async fn main() { authenticator_ref, processor_ref.clone(), processor_ref, + copy_handler_ref, ) .await }); diff --git a/src/api/copy.rs b/src/api/copy.rs new file mode 100644 index 00000000..09a71ad0 --- /dev/null +++ b/src/api/copy.rs @@ -0,0 +1,101 @@ +use async_trait::async_trait; +use futures::sink::{Sink, SinkExt}; +use std::fmt::Debug; + +use crate::error::{PgWireError, PgWireResult}; +use crate::messages::copy::{ + CopyBothResponse, CopyData, CopyDone, CopyFail, CopyInResponse, CopyOutResponse, +}; +use crate::messages::PgWireBackendMessage; + +use super::ClientInfo; + +/// handler for copy messages +#[async_trait] +pub trait CopyHandler: Send + Sync { + async fn on_copy_data(&self, _client: &mut C, _copy_data: CopyData) -> PgWireResult<()> + where + C: ClientInfo + Sink + Unpin + Send + Sync, + C::Error: Debug, + PgWireError: From<>::Error>, + { + Ok(()) + } + + async fn on_copy_done(&self, _client: &mut C, _done: CopyDone) -> PgWireResult<()> + where + C: ClientInfo + Sink + Unpin + Send + Sync, + C::Error: Debug, + PgWireError: From<>::Error>, + { + Ok(()) + } + + async fn on_copy_fail(&self, _client: &mut C, _fail: CopyFail) -> PgWireResult<()> + where + C: ClientInfo + Sink + Unpin + Send + Sync, + C::Error: Debug, + PgWireError: From<>::Error>, + { + Ok(()) + } +} + +pub async fn send_copy_in_response( + client: &mut C, + overall_format: i8, + columns: usize, + column_formats: Vec, +) -> PgWireResult<()> +where + C: ClientInfo + Sink + Unpin + Send + Sync, + C::Error: Debug, + PgWireError: From<>::Error>, +{ + let resp = CopyInResponse::new(overall_format, columns as i16, column_formats); + client + .send(PgWireBackendMessage::CopyInResponse(resp)) + .await?; + Ok(()) +} + +pub async fn send_copy_out_response( + client: &mut C, + overall_format: i8, + columns: usize, + column_formats: Vec, +) -> PgWireResult<()> +where + C: ClientInfo + Sink + Unpin + Send + Sync, + C::Error: Debug, + PgWireError: From<>::Error>, +{ + let resp = CopyOutResponse::new(overall_format, columns as i16, column_formats); + client + .send(PgWireBackendMessage::CopyOutResponse(resp)) + .await?; + Ok(()) +} + +pub async fn send_copy_both_response( + client: &mut C, + overall_format: i8, + columns: usize, + column_formats: Vec, +) -> PgWireResult<()> +where + C: ClientInfo + Sink + Unpin + Send + Sync, + C::Error: Debug, + PgWireError: From<>::Error>, +{ + let resp = CopyBothResponse::new(overall_format, columns as i16, column_formats); + client + .send(PgWireBackendMessage::CopyBothResponse(resp)) + .await?; + Ok(()) +} + +#[derive(Clone, Copy, Debug, Default)] +pub struct NoopCopyHandler; + +impl CopyHandler for NoopCopyHandler {} diff --git a/src/api/mod.rs b/src/api/mod.rs index b6594b6d..70be7098 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -7,6 +7,7 @@ use std::sync::Arc; pub use postgres_types::Type; pub mod auth; +pub mod copy; pub mod portal; pub mod query; pub mod results; @@ -22,6 +23,7 @@ pub enum PgWireConnectionState { AuthenticationInProgress, ReadyForQuery, QueryInProgress, + CopyInProgress, AwaitingSync, } diff --git a/src/tokio.rs b/src/tokio.rs index d7be1922..5a19c0b0 100644 --- a/src/tokio.rs +++ b/src/tokio.rs @@ -10,6 +10,7 @@ use tokio_rustls::TlsAcceptor; use tokio_util::codec::{Decoder, Encoder, Framed}; use crate::api::auth::StartupHandler; +use crate::api::copy::CopyHandler; use crate::api::query::ExtendedQueryHandler; use crate::api::query::SimpleQueryHandler; use crate::api::{ClientInfo, ClientPortalStore, DefaultClient, PgWireConnectionState}; @@ -93,18 +94,20 @@ impl ClientPortalStore for Framed> { } } -async fn process_message( +async fn process_message( message: PgWireFrontendMessage, socket: &mut Framed>, authenticator: Arc, query_handler: Arc, extended_query_handler: Arc, + copy_handler: Arc, ) -> PgWireResult<()> where S: AsyncRead + AsyncWrite + Unpin + Send + Sync, A: StartupHandler, Q: SimpleQueryHandler, EQ: ExtendedQueryHandler, + C: CopyHandler, { match socket.codec().client_info.state() { PgWireConnectionState::AwaitingStartup @@ -146,6 +149,15 @@ where PgWireFrontendMessage::Close(close) => { extended_query_handler.on_close(socket, close).await?; } + PgWireFrontendMessage::CopyData(copy_data) => { + copy_handler.on_copy_data(socket, copy_data).await?; + } + PgWireFrontendMessage::CopyDone(copy_done) => { + copy_handler.on_copy_done(socket, copy_done).await?; + } + PgWireFrontendMessage::CopyFail(copy_fail) => { + copy_handler.on_copy_fail(socket, copy_fail).await?; + } _ => {} } } @@ -235,17 +247,19 @@ async fn peek_for_sslrequest( Ok(ssl) } -pub async fn process_socket( +pub async fn process_socket( tcp_socket: TcpStream, tls_acceptor: Option>, startup_handler: Arc, query_handler: Arc, extended_query_handler: Arc, + copy_handler: Arc, ) -> Result<(), IOError> where A: StartupHandler, Q: SimpleQueryHandler, EQ: ExtendedQueryHandler, + C: CopyHandler, { let addr = tcp_socket.peer_addr()?; tcp_socket.set_nodelay(true)?; @@ -266,6 +280,7 @@ where startup_handler.clone(), query_handler.clone(), extended_query_handler.clone(), + copy_handler.clone(), ) .await { @@ -290,6 +305,7 @@ where startup_handler.clone(), query_handler.clone(), extended_query_handler.clone(), + copy_handler.clone(), ) .await { diff --git a/tests-integration/test-server/src/main.rs b/tests-integration/test-server/src/main.rs index a086ecc0..536cf41b 100644 --- a/tests-integration/test-server/src/main.rs +++ b/tests-integration/test-server/src/main.rs @@ -4,8 +4,10 @@ use std::time::{Duration, SystemTime}; use async_trait::async_trait; use futures::stream; use futures::StreamExt; + use pgwire::api::auth::scram::{gen_salted_password, MakeSASLScramAuthStartupHandler}; use pgwire::api::auth::{AuthSource, DefaultServerParameterProvider, LoginInfo, Password}; +use pgwire::api::copy::NoopCopyHandler; use pgwire::api::portal::{Format, Portal}; use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler}; use pgwire::api::results::{ @@ -213,6 +215,7 @@ pub async fn main() { ); authenticator.set_iterations(ITERATIONS); let processor = Arc::new(MakeDummyDatabase); + let noop_copy_handler = Arc::new(NoopCopyHandler); let server_addr = "127.0.0.1:5432"; let listener = TcpListener::bind(server_addr).await.unwrap(); @@ -221,6 +224,8 @@ pub async fn main() { let incoming_socket = listener.accept().await.unwrap(); let authenticator_ref = authenticator.make(); let processor_ref = processor.make(); + let copy_handler_ref = noop_copy_handler.clone(); + tokio::spawn(async move { process_socket( incoming_socket.0, @@ -228,6 +233,7 @@ pub async fn main() { authenticator_ref, processor_ref.clone(), processor_ref, + copy_handler_ref, ) .await });