diff --git a/docker/sql_setup.sh b/docker/sql_setup.sh
index 422dcbda9..942c36712 100755
--- a/docker/sql_setup.sh
+++ b/docker/sql_setup.sh
@@ -64,6 +64,7 @@ port = 5433
 ssl = on
 ssl_cert_file = 'server.crt'
 ssl_key_file = 'server.key'
+wal_level = logical
 EOCONF
 
 cat > "$PGDATA/pg_hba.conf" <<-EOCONF
@@ -82,6 +83,7 @@ host all ssl_user ::0/0 reject
 
 # IPv4 local connections:
 host all postgres 0.0.0.0/0 trust
+host replication postgres 0.0.0.0/0 trust
 # IPv6 local connections:
 host all postgres ::0/0 trust
 # Unix socket connections:
diff --git a/postgres-protocol/src/message/backend.rs b/postgres-protocol/src/message/backend.rs
index 68b5aa6e5..1ae74a636 100644
--- a/postgres-protocol/src/message/backend.rs
+++ b/postgres-protocol/src/message/backend.rs
@@ -11,6 +11,7 @@ use std::str;
 
 use crate::Oid;
 
+// top-level message tags
 pub const PARSE_COMPLETE_TAG: u8 = b'1';
 pub const BIND_COMPLETE_TAG: u8 = b'2';
 pub const CLOSE_COMPLETE_TAG: u8 = b'3';
@@ -22,6 +23,7 @@ pub const DATA_ROW_TAG: u8 = b'D';
 pub const ERROR_RESPONSE_TAG: u8 = b'E';
 pub const COPY_IN_RESPONSE_TAG: u8 = b'G';
 pub const COPY_OUT_RESPONSE_TAG: u8 = b'H';
+pub const COPY_BOTH_RESPONSE_TAG: u8 = b'W';
 pub const EMPTY_QUERY_RESPONSE_TAG: u8 = b'I';
 pub const BACKEND_KEY_DATA_TAG: u8 = b'K';
 pub const NO_DATA_TAG: u8 = b'n';
@@ -33,6 +35,10 @@ pub const PARAMETER_DESCRIPTION_TAG: u8 = b't';
 pub const ROW_DESCRIPTION_TAG: u8 = b'T';
 pub const READY_FOR_QUERY_TAG: u8 = b'Z';
 
+// replication message tags
+pub const XLOG_DATA_TAG: u8 = b'w';
+pub const PRIMARY_KEEPALIVE_TAG: u8 = b'k';
+
 #[derive(Debug, Copy, Clone)]
 pub struct Header {
     tag: u8,
@@ -93,6 +99,7 @@ pub enum Message {
     CopyDone,
     CopyInResponse(CopyInResponseBody),
     CopyOutResponse(CopyOutResponseBody),
+    CopyBothResponse(CopyBothResponseBody),
     DataRow(DataRowBody),
     EmptyQueryResponse,
     ErrorResponse(ErrorResponseBody),
@@ -190,6 +197,16 @@ impl Message {
                     storage,
                 })
             }
+            COPY_BOTH_RESPONSE_TAG => {
+                let format = buf.read_u8()?;
+                let len = buf.read_u16::<BigEndian>()?;
+                let storage = buf.read_all();
+                Message::CopyBothResponse(CopyBothResponseBody {
+                    format,
+                    len,
+                    storage,
+                })
+            }
             EMPTY_QUERY_RESPONSE_TAG => Message::EmptyQueryResponse,
             BACKEND_KEY_DATA_TAG => {
                 let process_id = buf.read_i32::<BigEndian>()?;
@@ -278,6 +295,57 @@ impl Message {
     }
 }
 
+/// An enum representing Postgres backend replication messages.
+#[non_exhaustive]
+pub enum ReplicationMessage {
+    XLogData(XLogDataBody),
+    PrimaryKeepAlive(PrimaryKeepAliveBody),
+}
+
+impl ReplicationMessage {
+    pub fn parse(bytes: &Bytes) -> io::Result<Self> {
+        let mut buf = Buffer {
+            bytes: bytes.clone(),
+            idx: 0,
+        };
+
+        let tag = buf.read_u8()?;
+
+        let replication_message = match tag {
+            XLOG_DATA_TAG => {
+                let wal_start = buf.read_u64::<BigEndian>()?;
+                let wal_end = buf.read_u64::<BigEndian>()?;
+                let timestamp = buf.read_i64::<BigEndian>()?;
+                let storage = buf.read_all();
+                ReplicationMessage::XLogData(XLogDataBody {
+                    wal_start,
+                    wal_end,
+                    timestamp,
+                    storage,
+                })
+            }
+            PRIMARY_KEEPALIVE_TAG => {
+                let wal_end = buf.read_u64::<BigEndian>()?;
+                let timestamp = buf.read_i64::<BigEndian>()?;
+                let reply = buf.read_u8()?;
+                ReplicationMessage::PrimaryKeepAlive(PrimaryKeepAliveBody {
+                    wal_end,
+                    timestamp,
+                    reply,
+                })
+            }
+            tag => {
+                return Err(io::Error::new(
+                    io::ErrorKind::InvalidInput,
+                    format!("unknown replication message tag `{}`", tag),
+                ));
+            }
+        };
+
+        Ok(replication_message)
+    }
+}
+
 struct Buffer {
     bytes: Bytes,
     idx: usize,
@@ -524,6 +592,27 @@ impl CopyOutResponseBody {
     }
 }
 
+pub struct CopyBothResponseBody {
+    storage: Bytes,
+    len: u16,
+    format: u8,
+}
+
+impl CopyBothResponseBody {
+    #[inline]
+    pub fn format(&self) -> u8 {
+        self.format
+    }
+
+    #[inline]
+    pub fn column_formats(&self) -> ColumnFormats<'_> {
+        ColumnFormats {
+            remaining: self.len,
+            buf: &self.storage,
+        }
+    }
+}
+
 pub struct DataRowBody {
     storage: Bytes,
     len: u16,
@@ -776,6 +865,63 @@ impl RowDescriptionBody {
     }
 }
 
+pub struct XLogDataBody {
+    wal_start: u64,
+    wal_end: u64,
+    timestamp: i64,
+    storage: Bytes,
+}
+
+impl XLogDataBody {
+    #[inline]
+    pub fn wal_start(&self) -> u64 {
+        self.wal_start
+    }
+
+    #[inline]
+    pub fn wal_end(&self) -> u64 {
+        self.wal_end
+    }
+
+    #[inline]
+    pub fn timestamp(&self) -> i64 {
+        self.timestamp
+    }
+
+    #[inline]
+    pub fn data(&self) -> &[u8] {
+        &self.storage
+    }
+
+    #[inline]
+    pub fn into_bytes(self) -> Bytes {
+        self.storage
+    }
+}
+
+pub struct PrimaryKeepAliveBody {
+    wal_end: u64,
+    timestamp: i64,
+    reply: u8,
+}
+
+impl PrimaryKeepAliveBody {
+    #[inline]
+    pub fn wal_end(&self) -> u64 {
+        self.wal_end
+    }
+
+    #[inline]
+    pub fn timestamp(&self) -> i64 {
+        self.timestamp
+    }
+
+    #[inline]
+    pub fn reply(&self) -> u8 {
+        self.reply
+    }
+}
+
 pub struct Fields<'a> {
     buf: &'a [u8],
     remaining: u16,
diff --git a/postgres-protocol/src/message/frontend.rs b/postgres-protocol/src/message/frontend.rs
index 5d0a8ff8c..6af7d6a1e 100644
--- a/postgres-protocol/src/message/frontend.rs
+++ b/postgres-protocol/src/message/frontend.rs
@@ -132,6 +132,48 @@ pub fn close(variant: u8, name: &str, buf: &mut BytesMut) -> io::Result<()> {
     })
 }
 
+#[inline]
+pub fn standby_status_update(
+    write_lsn: u64,
+    flush_lsn: u64,
+    apply_lsn: u64,
+    timestamp: i64,
+    reply: u8,
+    buf: &mut BytesMut,
+) -> io::Result<()> {
+    buf.put_u8(b'd');
+    write_body(buf, |buf| {
+        buf.put_u8(b'r');
+        buf.put_u64(write_lsn);
+        buf.put_u64(flush_lsn);
+        buf.put_u64(apply_lsn);
+        buf.put_i64(timestamp);
+        buf.put_u8(reply);
+        Ok(())
+    })
+}
+
+#[inline]
+pub fn hot_standby_feedback(
+    timestamp: i64,
+    global_xmin: u32,
+    global_xmin_epoch: u32,
+    catalog_xmin: u32,
+    catalog_xmin_epoch: u32,
+    buf: &mut BytesMut,
+) -> io::Result<()> {
+    buf.put_u8(b'd');
+    write_body(buf, |buf| {
+        buf.put_u8(b'h');
+        buf.put_i64(timestamp);
+        buf.put_u32(global_xmin);
+        buf.put_u32(global_xmin_epoch);
+        buf.put_u32(catalog_xmin);
+        buf.put_u32(catalog_xmin_epoch);
+        Ok(())
+    })
+}
+
 pub struct CopyData<T> {
     buf: T,
     len: i32,
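Note (not part of the patch): a minimal sketch of how the protocol pieces above fit together. `ReplicationMessage::parse` consumes the payload of a backend CopyData message, and `frontend::standby_status_update` emits a CopyData-wrapped status reply. The WAL position, timestamp, and flag values below are made up for illustration.

```rust
use bytes::{BufMut, BytesMut};
use postgres_protocol::message::backend::ReplicationMessage;
use postgres_protocol::message::frontend;

fn main() -> std::io::Result<()> {
    // Decode side: hand-build the payload of a CopyData message carrying a
    // primary keepalive ('k'): WAL end (u64 BE), timestamp (i64 BE), reply flag.
    let mut body = BytesMut::new();
    body.put_u8(b'k');
    body.put_u64(0x0000_0001_6000_0000); // made-up WAL location
    body.put_i64(0); // microseconds since 2000-01-01
    body.put_u8(1); // server asks for an immediate status update

    match ReplicationMessage::parse(&body.freeze())? {
        ReplicationMessage::PrimaryKeepAlive(k) => {
            assert_eq!(k.wal_end(), 0x0000_0001_6000_0000);
            assert_eq!(k.reply(), 1);
        }
        _ => unreachable!("expected a primary keepalive"),
    }

    // Encode side: the standby status update ('r') is itself wrapped in a
    // CopyData frame, so the buffer starts with the CopyData tag 'd'.
    let mut out = BytesMut::new();
    frontend::standby_status_update(1, 1, 1, 0, 0, &mut out)?;
    assert_eq!(out[0], b'd');
    Ok(())
}
```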
diff --git a/tokio-postgres/Cargo.toml b/tokio-postgres/Cargo.toml index f659663e1..b7fc479f3 100644 --- a/tokio-postgres/Cargo.toml +++ b/tokio-postgres/Cargo.toml @@ -45,7 +45,7 @@ futures = "0.3" log = "0.4" parking_lot = "0.11" percent-encoding = "2.0" -pin-project-lite = "0.2" +pin-project = "1.0" phf = "0.8" postgres-protocol = { version = "0.5.0", path = "../postgres-protocol" } postgres-types = { version = "0.1.2", path = "../postgres-types" } diff --git a/tokio-postgres/src/binary_copy.rs b/tokio-postgres/src/binary_copy.rs index 20064c728..cfe5bd319 100644 --- a/tokio-postgres/src/binary_copy.rs +++ b/tokio-postgres/src/binary_copy.rs @@ -5,7 +5,7 @@ use crate::{slice_iter, CopyInSink, CopyOutStream, Error}; use byteorder::{BigEndian, ByteOrder}; use bytes::{Buf, BufMut, Bytes, BytesMut}; use futures::{ready, SinkExt, Stream}; -use pin_project_lite::pin_project; +use pin_project::pin_project; use postgres_types::BorrowToSql; use std::convert::TryFrom; use std::io; @@ -18,16 +18,15 @@ use std::task::{Context, Poll}; const MAGIC: &[u8] = b"PGCOPY\n\xff\r\n\0"; const HEADER_LEN: usize = MAGIC.len() + 4 + 4; -pin_project! { - /// A type which serializes rows into the PostgreSQL binary copy format. - /// - /// The copy *must* be explicitly completed via the `finish` method. If it is not, the copy will be aborted. - pub struct BinaryCopyInWriter { - #[pin] - sink: CopyInSink, - types: Vec, - buf: BytesMut, - } +/// A type which serializes rows into the PostgreSQL binary copy format. +/// +/// The copy *must* be explicitly completed via the `finish` method. If it is not, the copy will be aborted. +#[pin_project] +pub struct BinaryCopyInWriter { + #[pin] + sink: CopyInSink, + types: Vec, + buf: BytesMut, } impl BinaryCopyInWriter { @@ -115,14 +114,13 @@ struct Header { has_oids: bool, } -pin_project! { - /// A stream of rows deserialized from the PostgreSQL binary copy format. - pub struct BinaryCopyOutStream { - #[pin] - stream: CopyOutStream, - types: Arc>, - header: Option
, - } +/// A stream of rows deserialized from the PostgreSQL binary copy format. +#[pin_project] +pub struct BinaryCopyOutStream { + #[pin] + stream: CopyOutStream, + types: Arc>, + header: Option
, } impl BinaryCopyOutStream { diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs index 359a7cd16..0a8c6eec8 100644 --- a/tokio-postgres/src/client.rs +++ b/tokio-postgres/src/client.rs @@ -70,7 +70,10 @@ pub struct InnerClient { impl InnerClient { pub fn send(&self, messages: RequestMessages) -> Result { let (sender, receiver) = mpsc::channel(1); - let request = Request { messages, sender }; + let request = Request { + messages: messages, + sender: Some(sender), + }; self.sender .unbounded_send(request) .map_err(|_| Error::closed())?; @@ -81,6 +84,21 @@ impl InnerClient { }) } + // Send a message for the existing entry in the pipeline; don't + // create a new entry in the pipeline. This is needed for CopyBoth + // mode (i.e. streaming replication), where the client may send a + // new message that is part of the existing request. + pub fn unpipelined_send(&self, messages: RequestMessages) -> Result<(), Error> { + let request = Request { + messages: messages, + sender: None, + }; + self.sender + .unbounded_send(request) + .map_err(|_| Error::closed())?; + Ok(()) + } + pub fn typeinfo(&self) -> Option { self.state.lock().typeinfo.clone() } diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index da171cc79..d53b7c04a 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -56,6 +56,16 @@ pub enum ChannelBinding { Require, } +/// Replication mode configuration. +#[derive(Debug, Copy, Clone, PartialEq)] +#[non_exhaustive] +pub enum ReplicationMode { + /// Physical replication. + Physical, + /// Logical replication. + Logical, +} + /// A host specification. #[derive(Debug, Clone, PartialEq)] pub enum Host { @@ -159,6 +169,7 @@ pub struct Config { pub(crate) keepalives_idle: Duration, pub(crate) target_session_attrs: TargetSessionAttrs, pub(crate) channel_binding: ChannelBinding, + pub(crate) replication_mode: Option, } impl Default for Config { @@ -184,6 +195,7 @@ impl Config { keepalives_idle: Duration::from_secs(2 * 60 * 60), target_session_attrs: TargetSessionAttrs::Any, channel_binding: ChannelBinding::Prefer, + replication_mode: None, } } @@ -387,6 +399,17 @@ impl Config { self.channel_binding } + /// Set replication mode. + pub fn replication_mode(&mut self, replication_mode: ReplicationMode) -> &mut Config { + self.replication_mode = Some(replication_mode); + self + } + + /// Get replication mode. 
+ pub fn get_replication_mode(&self) -> Option { + self.replication_mode + } + fn param(&mut self, key: &str, value: &str) -> Result<(), Error> { match key { "user" => { @@ -534,6 +557,12 @@ impl fmt::Debug for Config { } } + let replication_mode_str = match self.replication_mode { + None => "false", + Some(ReplicationMode::Physical) => "true", + Some(ReplicationMode::Logical) => "database", + }; + f.debug_struct("Config") .field("user", &self.user) .field("password", &self.password.as_ref().map(|_| Redaction {})) @@ -548,6 +577,7 @@ impl fmt::Debug for Config { .field("keepalives_idle", &self.keepalives_idle) .field("target_session_attrs", &self.target_session_attrs) .field("channel_binding", &self.channel_binding) + .field("replication", &replication_mode_str.to_string()) .finish() } } diff --git a/tokio-postgres/src/connect_raw.rs b/tokio-postgres/src/connect_raw.rs index d07d5a2df..15d24f9f9 100644 --- a/tokio-postgres/src/connect_raw.rs +++ b/tokio-postgres/src/connect_raw.rs @@ -1,5 +1,5 @@ use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec}; -use crate::config::{self, Config}; +use crate::config::{self, Config, ReplicationMode}; use crate::connect_tls::connect_tls; use crate::maybe_tls_stream::MaybeTlsStream; use crate::tls::{TlsConnect, TlsStream}; @@ -124,6 +124,12 @@ where if let Some(application_name) = &config.application_name { params.push(("application_name", &**application_name)); } + if let Some(replication_mode) = &config.replication_mode { + match replication_mode { + ReplicationMode::Physical => params.push(("replication", "true")), + ReplicationMode::Logical => params.push(("replication", "database")), + } + } let mut buf = BytesMut::new(); frontend::startup_message(params, &mut buf).map_err(Error::encode)?; diff --git a/tokio-postgres/src/connection.rs b/tokio-postgres/src/connection.rs index 9c8e369f1..4b28b7604 100644 --- a/tokio-postgres/src/connection.rs +++ b/tokio-postgres/src/connection.rs @@ -25,7 +25,7 @@ pub enum RequestMessages { pub struct Request { pub messages: RequestMessages, - pub sender: mpsc::Sender, + pub sender: Option>, } pub struct Response { @@ -183,9 +183,9 @@ where match self.receiver.poll_next_unpin(cx) { Poll::Ready(Some(request)) => { trace!("polled new request"); - self.responses.push_back(Response { - sender: request.sender, - }); + if let Some(sender) = request.sender { + self.responses.push_back(Response { sender: sender }); + } Poll::Ready(Some(request.messages)) } Poll::Ready(None) => Poll::Ready(None), diff --git a/tokio-postgres/src/copy_in.rs b/tokio-postgres/src/copy_in.rs index fc712f6db..8de0447fc 100644 --- a/tokio-postgres/src/copy_in.rs +++ b/tokio-postgres/src/copy_in.rs @@ -8,7 +8,7 @@ use futures::channel::mpsc; use futures::future; use futures::{ready, Sink, SinkExt, Stream, StreamExt}; use log::debug; -use pin_project_lite::pin_project; +use pin_project::pin_project; use postgres_protocol::message::backend::Message; use postgres_protocol::message::frontend; use postgres_protocol::message::frontend::CopyData; @@ -69,21 +69,20 @@ enum SinkState { Reading, } -pin_project! { - /// A sink for `COPY ... FROM STDIN` query data. - /// - /// The copy *must* be explicitly completed via the `Sink::close` or `finish` methods. If it is - /// not, the copy will be aborted. - pub struct CopyInSink { - #[pin] - sender: mpsc::Sender, - responses: Responses, - buf: BytesMut, - state: SinkState, - #[pin] - _p: PhantomPinned, - _p2: PhantomData, - } +/// A sink for `COPY ... FROM STDIN` query data. 
+/// +/// The copy *must* be explicitly completed via the `Sink::close` or `finish` methods. If it is +/// not, the copy will be aborted. +#[pin_project] +pub struct CopyInSink { + #[pin] + sender: mpsc::Sender, + responses: Responses, + buf: BytesMut, + state: SinkState, + #[pin] + _p: PhantomPinned, + _p2: PhantomData, } impl CopyInSink diff --git a/tokio-postgres/src/copy_out.rs b/tokio-postgres/src/copy_out.rs index 52691b963..61a5af84c 100644 --- a/tokio-postgres/src/copy_out.rs +++ b/tokio-postgres/src/copy_out.rs @@ -5,7 +5,7 @@ use crate::{query, slice_iter, Error, Statement}; use bytes::Bytes; use futures::{ready, Stream}; use log::debug; -use pin_project_lite::pin_project; +use pin_project::pin_project; use postgres_protocol::message::backend::Message; use std::marker::PhantomPinned; use std::pin::Pin; @@ -38,13 +38,12 @@ async fn start(client: &InnerClient, buf: Bytes) -> Result { Ok(responses) } -pin_project! { - /// A stream of `COPY ... TO STDOUT` query data. - pub struct CopyOutStream { - responses: Responses, - #[pin] - _p: PhantomPinned, - } +/// A stream of `COPY ... TO STDOUT` query data. +#[pin_project] +pub struct CopyOutStream { + responses: Responses, + #[pin] + _p: PhantomPinned, } impl Stream for CopyOutStream { diff --git a/tokio-postgres/src/lib.rs b/tokio-postgres/src/lib.rs index 90c2b0404..9a2a5b86b 100644 --- a/tokio-postgres/src/lib.rs +++ b/tokio-postgres/src/lib.rs @@ -117,7 +117,7 @@ pub use crate::cancel_token::CancelToken; pub use crate::client::Client; -pub use crate::config::Config; +pub use crate::config::{Config, ReplicationMode}; pub use crate::connection::Connection; pub use crate::copy_in::CopyInSink; pub use crate::copy_out::CopyOutStream; @@ -126,6 +126,7 @@ pub use crate::error::Error; pub use crate::generic_client::GenericClient; pub use crate::portal::Portal; pub use crate::query::RowStream; +use crate::replication_client::ReplicationClient; pub use crate::row::{Row, SimpleQueryRow}; pub use crate::simple_query::SimpleQueryStream; #[cfg(feature = "runtime")] @@ -163,6 +164,7 @@ mod maybe_tls_stream; mod portal; mod prepare; mod query; +pub mod replication_client; pub mod row; mod simple_query; #[cfg(feature = "runtime")] @@ -193,6 +195,30 @@ where config.connect(tls).await } +/// A convenience function which parses a connection string and connects to the database in replication mode. Normal queries are not permitted in replication mode. +/// +/// See the documentation for [`Config`] for details on the connection string format. +/// +/// Requires the `runtime` Cargo feature (enabled by default). +/// +/// [`Config`]: config/struct.Config.html +#[cfg(feature = "runtime")] +pub async fn connect_replication( + config: &str, + tls: T, + mode: ReplicationMode, +) -> Result<(ReplicationClient, Connection), Error> +where + T: MakeTlsConnect, +{ + let mut config = config.parse::()?; + config.replication_mode(mode); + config + .connect(tls) + .await + .map(|(client, conn)| (ReplicationClient::new(client), conn)) +} + /// An asynchronous notification. 
#[derive(Clone, Debug)] pub struct Notification { diff --git a/tokio-postgres/src/query.rs b/tokio-postgres/src/query.rs index f139ed915..96b9cb6ec 100644 --- a/tokio-postgres/src/query.rs +++ b/tokio-postgres/src/query.rs @@ -6,7 +6,7 @@ use crate::{Error, Portal, Row, Statement}; use bytes::{Bytes, BytesMut}; use futures::{ready, Stream}; use log::{debug, log_enabled, Level}; -use pin_project_lite::pin_project; +use pin_project::pin_project; use postgres_protocol::message::backend::Message; use postgres_protocol::message::frontend; use std::fmt; @@ -188,14 +188,13 @@ where } } -pin_project! { - /// A stream of table rows. - pub struct RowStream { - statement: Statement, - responses: Responses, - #[pin] - _p: PhantomPinned, - } +/// A stream of table rows. +#[pin_project] +pub struct RowStream { + statement: Statement, + responses: Responses, + #[pin] + _p: PhantomPinned, } impl Stream for RowStream { diff --git a/tokio-postgres/src/replication_client.rs b/tokio-postgres/src/replication_client.rs new file mode 100644 index 000000000..ce3b56401 --- /dev/null +++ b/tokio-postgres/src/replication_client.rs @@ -0,0 +1,882 @@ +//! Streaming replication support. +//! +//! This module allows writing Postgres replication clients. A +//! replication client forms a special connection to the server in +//! either physical replication mode, which receives a stream of raw +//! Write-Ahead Log (WAL) records; or logical replication mode, which +//! receives a stream of data that depends on the output plugin +//! selected. All data and control messages are exchanged in CopyData +//! envelopes. +//! +//! See the [PostgreSQL protocol +//! documentation](https://www.postgresql.org/docs/current/protocol-replication.html) +//! for details of the protocol itself. +//! +//! # Physical Replication Client Example +//! ```no_run +//! extern crate tokio; +//! +//! use postgres_protocol::message::backend::ReplicationMessage; +//! use tokio::stream::StreamExt; +//! use tokio_postgres::{connect_replication, Error, NoTls, ReplicationMode}; +//! +//! #[tokio::main] +//! async fn main() -> Result<(), Error> { +//! let conninfo = "host=localhost user=postgres dbname=postgres"; +//! +//! // form replication connection +//! let (mut rclient, rconnection) = +//! connect_replication(conninfo, NoTls, ReplicationMode::Physical).await?; +//! tokio::spawn(async move { +//! if let Err(e) = rconnection.await { +//! eprintln!("connection error: {}", e); +//! } +//! }); +//! +//! let identify_system = rclient.identify_system().await?; +//! +//! let mut physical_stream = rclient +//! .start_physical_replication(None, identify_system.xlogpos(), None) +//! .await?; +//! +//! while let Some(replication_message) = physical_stream.next().await { +//! match replication_message? { +//! ReplicationMessage::XLogData(xlog_data) => { +//! eprintln!("received XLogData: {:#?}", xlog_data); +//! } +//! ReplicationMessage::PrimaryKeepAlive(keepalive) => { +//! eprintln!("received PrimaryKeepAlive: {:#?}", keepalive); +//! } +//! _ => (), +//! } +//! } +//! +//! Ok(()) +//! } +//! ``` +//! +//! # Logical Replication Client Example +//! +//! This example requires the [wal2json +//! extension](https://github.com/eulerto/wal2json). +//! +//! ```no_run +//! extern crate tokio; +//! +//! use postgres_protocol::message::backend::ReplicationMessage; +//! use tokio::stream::StreamExt; +//! use tokio_postgres::{connect_replication, Error, NoTls, ReplicationMode}; +//! +//! #[tokio::main] +//! async fn main() -> Result<(), Error> { +//! 
let conninfo = "host=localhost user=postgres dbname=postgres"; +//! +//! // form replication connection +//! let (mut rclient, rconnection) = +//! connect_replication(conninfo, NoTls, ReplicationMode::Logical).await?; +//! +//! // spawn connection to run on its own +//! tokio::spawn(async move { +//! if let Err(e) = rconnection.await { +//! eprintln!("connection error: {}", e); +//! } +//! }); +//! +//! let identify_system = rclient.identify_system().await?; +//! +//! let slot = "my_slot"; +//! let plugin = "wal2json"; +//! let options = &vec![("pretty-print", "1")]; +//! +//! let _slotdesc = rclient +//! .create_logical_replication_slot(slot, false, plugin, None) +//! .await?; +//! +//! let mut physical_stream = rclient +//! .start_logical_replication(slot, identify_system.xlogpos(), options) +//! .await?; +//! +//! while let Some(replication_message) = physical_stream.next().await { +//! match replication_message? { +//! ReplicationMessage::XLogData(xlog_data) => { +//! eprintln!("received XLogData: {:#?}", xlog_data); +//! let json = std::str::from_utf8(xlog_data.data()).unwrap(); +//! eprintln!("JSON text: {}", json); +//! } +//! ReplicationMessage::PrimaryKeepAlive(keepalive) => { +//! eprintln!("received PrimaryKeepAlive: {:#?}", keepalive); +//! } +//! _ => (), +//! } +//! } +//! +//! Ok(()) +//! } +//! ``` +//! +//! # Caveats +//! +//! It is recommended that you use a PostgreSQL server patch version +//! of at least: 14.0, 13.2, 12.6, 11.11, 10.16, 9.6.21, or +//! 9.5.25. Earlier patch levels have a bug that doesn't properly +//! handle pipelined requests after streaming has stopped. + +use crate::client::Responses; +use crate::codec::FrontendMessage; +use crate::connection::RequestMessages; +use crate::types::{Lsn, Type}; +use crate::{simple_query, Client, Error}; +use bytes::BytesMut; +use fallible_iterator::FallibleIterator; +use futures::{ready, Stream}; +use pin_project::{pin_project, pinned_drop}; +use postgres_protocol::escape::{escape_identifier, escape_literal}; +use postgres_protocol::message::backend::{Message, ReplicationMessage, RowDescriptionBody}; +use postgres_protocol::message::frontend; +use std::io; +use std::marker::PhantomPinned; +use std::path::{Path, PathBuf}; +use std::pin::Pin; +use std::str::from_utf8; +use std::task::{Context, Poll}; + +/// Result of [identify_system()](ReplicationClient::identify_system()) call. +#[derive(Debug)] +pub struct IdentifySystem { + systemid: String, + timeline: u32, + xlogpos: Lsn, + dbname: Option, +} + +impl IdentifySystem { + /// The unique system identifier identifying the cluster. This can + /// be used to check that the base backup used to initialize the + /// standby came from the same cluster. + pub fn systemid(&self) -> &str { + &self.systemid + } + + /// Current timeline ID. Also useful to check that the standby is + /// consistent with the master. + pub fn timeline(&self) -> u32 { + self.timeline + } + + /// Current WAL flush location. Useful to get a known location in + /// the write-ahead log where streaming can start. + pub fn xlogpos(&self) -> Lsn { + self.xlogpos + } + + /// Database connected to or None. + pub fn dbname(&self) -> Option<&str> { + self.dbname.as_deref() + } +} + +/// Result of [timeline_history()](ReplicationClient::timeline_history()) call. +#[derive(Debug)] +pub struct TimelineHistory { + filename: PathBuf, + content: Vec, +} + +impl TimelineHistory { + /// File name of the timeline history file, e.g., + /// 00000002.history. 
+ pub fn filename(&self) -> &Path { + self.filename.as_path() + } + + /// Contents of the timeline history file. + pub fn content(&self) -> &[u8] { + self.content.as_slice() + } +} + +/// Argument to +/// [create_logical_replication_slot()](ReplicationClient::create_logical_replication_slot). +#[derive(Debug)] +pub enum SnapshotMode { + /// Export the snapshot for use in other sessions. This option + /// can't be used inside a transaction. + ExportSnapshot, + /// Use the snapshot for logical decoding as normal but won't do + /// anything else with it. + NoExportSnapshot, + /// Use the snapshot for the current transaction executing the + /// command. This option must be used in a transaction, and + /// CREATE_REPLICATION_SLOT must be the first command run in that + /// transaction. + UseSnapshot, +} + +/// Description of slot created with +/// [create_physical_replication_slot()](ReplicationClient::create_physical_replication_slot) +/// or +/// [create_logical_replication_slot()](ReplicationClient::create_logical_replication_slot). +#[derive(Debug)] +pub struct CreateReplicationSlotResponse { + slot_name: String, + consistent_point: Lsn, + snapshot_name: Option, + output_plugin: Option, +} + +impl CreateReplicationSlotResponse { + /// The name of the newly-created replication slot. + pub fn slot_name(&self) -> &str { + &self.slot_name + } + + /// The WAL location at which the slot became consistent. This is + /// the earliest location from which streaming can start on this + /// replication slot. + pub fn consistent_point(&self) -> Lsn { + self.consistent_point + } + + /// The identifier of the snapshot exported by the command. The + /// snapshot is valid until a new command is executed on this + /// connection or the replication connection is closed. Null if + /// the created slot is physical. + pub fn snapshot_name(&self) -> Option<&str> { + self.snapshot_name.as_deref() + } + + /// The name of the output plugin used by the newly-created + /// replication slot. Null if the created slot is physical. + pub fn output_plugin(&self) -> Option<&str> { + self.output_plugin.as_deref() + } +} + +/// Response sent after streaming from a timeline that is not the +/// current timeline. +#[derive(Clone, Debug)] +pub struct ReplicationResponse { + next_tli: u64, + next_tli_startpos: Lsn, +} + +impl ReplicationResponse { + /// next timeline's ID + pub fn next_tli(&self) -> u64 { + self.next_tli + } + + /// WAL location where the switch happened + pub fn next_tli_startpos(&self) -> Lsn { + self.next_tli_startpos + } +} + +/// Represents a client connected in replication mode. +pub struct ReplicationClient { + client: Client, +} + +impl ReplicationClient { + /// Requests the server to identify itself. + pub async fn identify_system(&mut self) -> Result { + let command = "IDENTIFY_SYSTEM"; + let mut responses = self.send(command).await?; + let rowdesc = match responses.next().await? { + Message::RowDescription(m) => m, + _ => return Err(Error::unexpected_message()), + }; + let datarow = match responses.next().await? 
{ + Message::DataRow(m) => m, + _ => return Err(Error::unexpected_message()), + }; + + let fields = rowdesc.fields().collect::>().map_err(Error::parse)?; + let ranges = datarow.ranges().collect::>().map_err(Error::parse)?; + + assert_eq!(fields.len(), 4); + assert_eq!(fields[0].type_oid(), Type::TEXT.oid()); + assert_eq!(fields[1].type_oid(), Type::INT4.oid()); + assert_eq!(fields[2].type_oid(), Type::TEXT.oid()); + assert_eq!(fields[3].type_oid(), Type::TEXT.oid()); + assert_eq!(ranges.len(), 4); + + let values: Vec> = ranges + .iter() + .map(|range| { + range + .to_owned() + .map(|r| from_utf8(&datarow.buffer()[r]).unwrap()) + }) + .collect::>(); + + Ok(IdentifySystem { + systemid: values[0].unwrap().to_string(), + timeline: values[1].unwrap().parse::().unwrap(), + xlogpos: Lsn::from(values[2].unwrap()), + dbname: values[3].map(String::from), + }) + } + + /// Requests the server to send the current setting of a run-time + /// parameter. This is similar to the SQL command + /// [SHOW](https://www.postgresql.org/docs/current/sql-show.html). + pub async fn show(&mut self, name: &str) -> Result { + let command = format!("SHOW {}", escape_identifier(name)); + let mut responses = self.send(&command).await?; + let rowdesc = match responses.next().await? { + Message::RowDescription(m) => m, + _ => return Err(Error::unexpected_message()), + }; + let datarow = match responses.next().await? { + Message::DataRow(m) => m, + _ => return Err(Error::unexpected_message()), + }; + + let fields = rowdesc.fields().collect::>().map_err(Error::parse)?; + let ranges = datarow.ranges().collect::>().map_err(Error::parse)?; + + assert_eq!(fields.len(), 1); + assert_eq!(fields[0].type_oid(), Type::TEXT.oid()); + assert_eq!(ranges.len(), 1); + + let val = from_utf8(&datarow.buffer()[ranges[0].to_owned().unwrap()]).unwrap(); + + Ok(String::from(val)) + } + + /// Requests the server to send over the timeline history file for + /// the given timeline ID. + pub async fn timeline_history(&mut self, timeline_id: u32) -> Result { + let command = format!("TIMELINE_HISTORY {}", timeline_id); + let mut responses = self.send(&command).await?; + + let rowdesc = match responses.next().await? { + Message::RowDescription(m) => m, + _ => return Err(Error::unexpected_message()), + }; + let datarow = match responses.next().await? { + Message::DataRow(m) => m, + _ => return Err(Error::unexpected_message()), + }; + + let fields = rowdesc.fields().collect::>().map_err(Error::parse)?; + let ranges = datarow.ranges().collect::>().map_err(Error::parse)?; + + // The TIMELINE_HISTORY command sends a misleading + // RowDescriptor which is different depending on the server + // version, so we ignore it aside from checking for the right + // number of fields. Both fields are documented to be raw + // bytes. + // + // Both fields are documented to return raw bytes. + assert_eq!(fields.len(), 2); + assert_eq!(ranges.len(), 2); + + // Practically speaking, the filename is ASCII, so it's OK to + // treat it as UTF-8, and convert it to a PathBuf. If this + // assumption is violated, generate a useful error message. 
+ let filename_bytes = &datarow.buffer()[ranges[0].to_owned().unwrap()]; + let filename_str = from_utf8(filename_bytes) + .map_err(|_| { + io::Error::new( + io::ErrorKind::InvalidData, + "Timeline history filename is invalid UTF-8", + ) + }) + .map_err(Error::parse)?; + let filename_path = PathBuf::from(filename_str); + + // The file contents are typically ASCII, but we treat it as + // binary because it can contain text in an unknown + // encoding. For instance, the restore point name will be in + // the server encoding (it will not be converted to the client + // encoding before being sent); and the file can also be + // edited by the user to contain arbitrary comments in an + // unknown encoding. + let content_bytes = &datarow.buffer()[ranges[1].to_owned().unwrap()]; + + Ok(TimelineHistory { + filename: filename_path, + content: Vec::from(content_bytes), + }) + } + + /// Create logical replication slot. See [Replication + /// Slots](https://www.postgresql.org/docs/current/warm-standby.html#STREAMING-REPLICATION-SLOTS). + /// + /// Arguments: + /// + /// * `slot_name`: The name of the slot to create. Must be a valid + /// replication slot name (see [Querying and Manipulating + /// Replication + /// Slots](https://www.postgresql.org/docs/13/warm-standby.html#STREAMING-REPLICATION-SLOTS-MANIPULATION)). + /// * `temporary`: Specify that this replication slot is a + /// temporary one. Temporary slots are not saved to disk and are + /// automatically dropped on error or when the session has + /// finished. + /// * `reserve_wal`: Specify that this physical replication slot + /// reserves WAL immediately. Otherwise, WAL is only reserved + /// upon connection from a streaming replication client. + pub async fn create_physical_replication_slot( + &mut self, + slot_name: &str, + temporary: bool, + reserve_wal: bool, + ) -> Result { + let temporary_str = if temporary { " TEMPORARY" } else { "" }; + let reserve_wal_str = if reserve_wal { " RESERVE_WAL" } else { "" }; + let command = format!( + "CREATE_REPLICATION_SLOT {}{} PHYSICAL{}", + escape_identifier(slot_name), + temporary_str, + reserve_wal_str + ); + let mut responses = self.send(&command).await?; + + let rowdesc = match responses.next().await? { + Message::RowDescription(m) => m, + _ => return Err(Error::unexpected_message()), + }; + let datarow = match responses.next().await? { + Message::DataRow(m) => m, + _ => return Err(Error::unexpected_message()), + }; + + let fields = rowdesc.fields().collect::>().map_err(Error::parse)?; + let ranges = datarow.ranges().collect::>().map_err(Error::parse)?; + + assert_eq!(fields.len(), 4); + assert_eq!(fields[0].type_oid(), Type::TEXT.oid()); + assert_eq!(fields[0].format(), 0); + assert_eq!(fields[1].type_oid(), Type::TEXT.oid()); + assert_eq!(fields[1].format(), 0); + assert_eq!(fields[2].type_oid(), Type::TEXT.oid()); + assert_eq!(fields[2].format(), 0); + assert_eq!(fields[3].type_oid(), Type::TEXT.oid()); + assert_eq!(fields[3].format(), 0); + assert_eq!(ranges.len(), 4); + + let values: Vec> = ranges + .iter() + .map(|range| { + range + .to_owned() + .map(|r| from_utf8(&datarow.buffer()[r]).unwrap()) + }) + .collect::>(); + + Ok(CreateReplicationSlotResponse { + slot_name: values[0].unwrap().to_string(), + consistent_point: Lsn::from(values[1].unwrap()), + snapshot_name: values[2].map(String::from), + output_plugin: values[3].map(String::from), + }) + } + + /// Create logical replication slot. 
See [Replication + /// Slots](https://www.postgresql.org/docs/current/warm-standby.html#STREAMING-REPLICATION-SLOTS). + /// + /// Arguments: + /// + /// * `slot_name`: The name of the slot to create. Must be a valid + /// replication slot name (see [Querying and Manipulating + /// Replication + /// Slots](https://www.postgresql.org/docs/13/warm-standby.html#STREAMING-REPLICATION-SLOTS-MANIPULATION)). + /// * `temporary`: Specify that this replication slot is a + /// temporary one. Temporary slots are not saved to disk and are + /// automatically dropped on error or when the session has + /// finished. + /// * `plugin_name`: The name of the output plugin used for + /// logical decoding (see [Logical Decoding Output + /// Plugins](https://www.postgresql.org/docs/current/logicaldecoding-output-plugin.html)). + /// * `snapshot_mode`: Decides what to do with the snapshot + /// created during logical slot initialization. + pub async fn create_logical_replication_slot( + &mut self, + slot_name: &str, + temporary: bool, + plugin_name: &str, + snapshot_mode: Option, + ) -> Result { + let temporary_str = if temporary { " TEMPORARY" } else { "" }; + let snapshot_str = snapshot_mode.map_or("", |mode| match mode { + SnapshotMode::ExportSnapshot => " EXPORT_SNAPSHOT", + SnapshotMode::NoExportSnapshot => " NOEXPORT_SNAPSHOT", + SnapshotMode::UseSnapshot => " USE_SNAPSHOT", + }); + let command = format!( + "CREATE_REPLICATION_SLOT {}{} LOGICAL {}{}", + escape_identifier(slot_name), + temporary_str, + escape_identifier(plugin_name), + snapshot_str + ); + let mut responses = self.send(&command).await?; + + let rowdesc = match responses.next().await? { + Message::RowDescription(m) => m, + _ => return Err(Error::unexpected_message()), + }; + let datarow = match responses.next().await? { + Message::DataRow(m) => m, + _ => return Err(Error::unexpected_message()), + }; + + let fields = rowdesc.fields().collect::>().map_err(Error::parse)?; + let ranges = datarow.ranges().collect::>().map_err(Error::parse)?; + + assert_eq!(fields.len(), 4); + + let values: Vec> = ranges + .iter() + .map(|range| { + range + .to_owned() + .map(|r| from_utf8(&datarow.buffer()[r]).unwrap()) + }) + .collect::>(); + + Ok(CreateReplicationSlotResponse { + slot_name: values[0].unwrap().to_string(), + consistent_point: Lsn::from(values[1].unwrap()), + snapshot_name: values[2].map(String::from), + output_plugin: values[3].map(String::from), + }) + } + + /// Drops a replication slot, freeing any reserved server-side + /// resources. If the slot is a logical slot that was created in a + /// database other than the database the walsender is connected + /// to, this command fails. + pub async fn drop_replication_slot( + &mut self, + slot_name: &str, + wait: bool, + ) -> Result<(), Error> { + let wait_str = if wait { " WAIT" } else { "" }; + let command = format!( + "DROP_REPLICATION_SLOT {}{}", + escape_identifier(slot_name), + wait_str + ); + let _ = self.send(&command).await?; + Ok(()) + } + + /// Begin physical replication, consuming the replication client + /// and producing a replication stream. + /// + /// Arguments: + /// + /// * `slot_name`: If a slot's name is provided via slot_name, it + /// will be updated as replication progresses so that the server + /// knows which WAL segments, and if hot_standby_feedback is on + /// which transactions, are still needed by the standby. + /// * `lsn`: The starting WAL location. 
+ /// * `timeline_id`: If specified, streaming starts on timeline + /// tli; otherwise, the server's current timeline is selected. + pub async fn start_physical_replication<'a>( + &'a mut self, + slot_name: Option<&str>, + lsn: Lsn, + timeline_id: Option, + ) -> Result>>, Error> { + let slot = match slot_name { + Some(name) => format!(" SLOT {}", escape_identifier(name)), + None => String::from(""), + }; + let timeline = match timeline_id { + Some(id) => format!(" TIMELINE {}", id), + None => String::from(""), + }; + let command = format!( + "START_REPLICATION{} PHYSICAL {}{}", + slot, + String::from(lsn), + timeline + ); + + Ok(self.start_replication(command).await?) + } + + /// Begin logical replication, consuming the replication client and producing a replication stream. + /// + /// Arguments: + /// + /// * `slot_name`: If a slot's name is provided via slot_name, it + /// will be updated as replication progresses so that the server + /// knows which WAL segments, and if hot_standby_feedback is on + /// which transactions, are still needed by the standby. + /// * `lsn`: The starting WAL location. + /// * `options`: (name, value) pairs of options passed to the + /// slot's logical decoding plugin. + pub async fn start_logical_replication<'a>( + &'a mut self, + slot_name: &str, + lsn: Lsn, + options: &[(&str, &str)], + ) -> Result>>, Error> { + let slot = format!(" SLOT {}", escape_identifier(slot_name)); + let options_string = if !options.is_empty() { + format!( + " ({})", + options + .iter() + .map(|pair| format!("{} {}", escape_identifier(pair.0), escape_literal(pair.1))) + .collect::>() + .as_slice() + .join(", ") + ) + } else { + String::from("") + }; + let command = format!( + "START_REPLICATION{} LOGICAL {}{}", + slot, + String::from(lsn), + options_string + ); + + Ok(self.start_replication(command).await?) + } + + /// Send update to server. + pub async fn standby_status_update( + &mut self, + write_lsn: Lsn, + flush_lsn: Lsn, + apply_lsn: Lsn, + ts: i64, + reply: u8, + ) -> Result<(), Error> { + let iclient = self.client.inner(); + let mut buf = BytesMut::new(); + let _ = frontend::standby_status_update( + write_lsn.into(), + flush_lsn.into(), + apply_lsn.into(), + ts as i64, + reply, + &mut buf, + ); + let _ = iclient.send(RequestMessages::Single(FrontendMessage::Raw(buf.freeze())))?; + Ok(()) + } + + // Private methods + + pub(crate) fn new(client: Client) -> ReplicationClient { + ReplicationClient { client: client } + } + + // send command to the server, but finish any unfinished replication stream, first + async fn send(&mut self, command: &str) -> Result { + let iclient = self.client.inner(); + let buf = simple_query::encode(iclient, command)?; + let responses = iclient.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; + Ok(responses) + } + + async fn start_replication<'a>( + &'a mut self, + command: String, + ) -> Result>>, Error> { + let mut copyboth_received = false; + let mut replication_response: Option = None; + let mut responses = self.send(&command).await?; + + // Before we construct the ReplicationStream, we must know + // whether the server entered copy mode or not. Otherwise, if + // the ReplicationStream were to be dropped, we wouldn't know + // whether to send a CopyDone message or not (and it would be + // bad to try to receive and process the responses during the + // destructor). + + // If the timeline selected is the current one, the server + // will always enter copy mode. 
If the timeline is historic, + // and if there is no work to do, the server will skip copy + // mode and immediately send a response tuple. + match responses.next().await? { + Message::CopyBothResponse(_) => { + copyboth_received = true; + } + Message::RowDescription(rowdesc) => { + // Never entered copy mode, so don't bother returning + // a stream, just process the response. + replication_response = + Some(recv_replication_response(&mut responses, rowdesc).await?); + } + _ => return Err(Error::unexpected_message()), + } + + Ok(Box::pin(ReplicationStream { + rclient: self, + responses: responses, + copyboth_received: copyboth_received, + copydone_sent: false, + copydone_received: false, + replication_response: replication_response, + _phantom_pinned: PhantomPinned, + })) + } + + fn send_copydone(&mut self) -> Result<(), Error> { + let iclient = self.client.inner(); + let mut buf = BytesMut::new(); + frontend::copy_done(&mut buf); + iclient.unpipelined_send(RequestMessages::Single(FrontendMessage::Raw(buf.freeze())))?; + + Ok(()) + } +} + +/// A stream of data from a `START_REPLICATION` command. All control +/// and data messages will be in +/// [CopyData](postgres_protocol::message::backend::Message::CopyData). +/// +/// Intended to be used with the [next()](tokio::stream::StreamExt::next) method. +/// +/// If the timeline specified with +/// [start_physical_replication()](ReplicationClient::start_physical_replication) +/// or +/// [start_logical_replication()](ReplicationClient::start_logical_replication()) +/// is the current timeline, the stream is indefinite, and must be +/// terminated with +/// [stop_replication()](ReplicationStream::stop_replication()) (which +/// will not return a response tuple); or by dropping the +/// [ReplicationStream](ReplicationStream). +/// +/// If the timeline is not the current timeline, the stream will +/// terminate when the end of the timeline is reached, and +/// [stop_replication()](ReplicationStream::stop_replication()) will +/// return a response tuple. +#[pin_project(PinnedDrop)] +pub struct ReplicationStream<'a> { + rclient: &'a mut ReplicationClient, + responses: Responses, + copyboth_received: bool, + copydone_sent: bool, + copydone_received: bool, + replication_response: Option, + #[pin] + _phantom_pinned: PhantomPinned, +} + +impl ReplicationStream<'_> { + /// Stop replication stream and return the replication client object. + pub async fn stop_replication( + mut self: Pin>, + ) -> Result, Error> { + let this = self.as_mut().project(); + + if this.replication_response.is_some() { + return Ok(this.replication_response.clone()); + } + + // we must be in copy mode; shut it down + assert!(*this.copyboth_received); + if !*this.copydone_sent { + this.rclient.send_copydone()?; + *this.copydone_sent = true; + } + + // If server didn't already shut down copy, drain remaining + // CopyData and the CopyDone. + if !*this.copydone_received { + loop { + match this.responses.next().await? { + Message::CopyData(_) => (), + Message::CopyDone => { + *this.copydone_received = true; + break; + } + _ => return Err(Error::unexpected_message()), + } + } + } + + match this.responses.next().await? 
{ + Message::RowDescription(rowdesc) => { + *this.replication_response = + Some(recv_replication_response(this.responses, rowdesc).await?); + } + Message::CommandComplete(_) => (), + _ => return Err(Error::unexpected_message()), + } + + Ok(this.replication_response.clone()) + } +} + +impl Stream for ReplicationStream<'_> { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + + // if we already got a replication response tuple, we're done + if this.replication_response.is_some() { + return Poll::Ready(None); + } + + // we are in copy mode + assert!(*this.copyboth_received); + assert!(!*this.copydone_sent); + assert!(!*this.copydone_received); + match ready!(this.responses.poll_next(cx)?) { + Message::CopyData(body) => { + let r = ReplicationMessage::parse(&body.into_bytes()); + Poll::Ready(Some(r.map_err(Error::parse))) + } + Message::CopyDone => { + *this.copydone_received = true; + this.rclient.send_copydone()?; + *this.copydone_sent = true; + Poll::Ready(None) + } + _ => Poll::Ready(Some(Err(Error::unexpected_message()))), + } + } +} + +#[pinned_drop] +impl PinnedDrop for ReplicationStream<'_> { + fn drop(mut self: Pin<&mut Self>) { + let this = self.project(); + if *this.copyboth_received && !*this.copydone_sent { + this.rclient.send_copydone().unwrap(); + *this.copydone_sent = true; + } + } +} + +// Read a replication response tuple from the server. This function +// assumes that the caller has already consumed the RowDescription +// from the stream. +async fn recv_replication_response( + responses: &mut Responses, + rowdesc: RowDescriptionBody, +) -> Result { + let fields = rowdesc.fields().collect::>().map_err(Error::parse)?; + assert_eq!(fields.len(), 2); + assert_eq!(fields[0].type_oid(), Type::INT8.oid()); + assert_eq!(fields[0].format(), 0); + assert_eq!(fields[1].type_oid(), Type::TEXT.oid()); + assert_eq!(fields[1].format(), 0); + + match responses.next().await? { + Message::DataRow(datarow) => { + let ranges = datarow.ranges().collect::>().map_err(Error::parse)?; + assert_eq!(ranges.len(), 2); + + let timeline = &datarow.buffer()[ranges[0].to_owned().unwrap()]; + let switch = &datarow.buffer()[ranges[1].to_owned().unwrap()]; + Ok(ReplicationResponse { + next_tli: from_utf8(timeline).unwrap().parse::().unwrap(), + next_tli_startpos: Lsn::from(from_utf8(switch).unwrap()), + }) + } + _ => Err(Error::unexpected_message()), + } +} diff --git a/tokio-postgres/src/simple_query.rs b/tokio-postgres/src/simple_query.rs index 82ac35664..1bf5da8ca 100644 --- a/tokio-postgres/src/simple_query.rs +++ b/tokio-postgres/src/simple_query.rs @@ -6,7 +6,7 @@ use bytes::Bytes; use fallible_iterator::FallibleIterator; use futures::{ready, Stream}; use log::debug; -use pin_project_lite::pin_project; +use pin_project::pin_project; use postgres_protocol::message::backend::Message; use postgres_protocol::message::frontend; use std::marker::PhantomPinned; @@ -45,21 +45,20 @@ pub async fn batch_execute(client: &InnerClient, query: &str) -> Result<(), Erro } } -fn encode(client: &InnerClient, query: &str) -> Result { +pub(crate) fn encode(client: &InnerClient, query: &str) -> Result { client.with_buf(|buf| { frontend::query(query, buf).map_err(Error::encode)?; Ok(buf.split().freeze()) }) } -pin_project! { - /// A stream of simple query results. - pub struct SimpleQueryStream { - responses: Responses, - columns: Option>, - #[pin] - _p: PhantomPinned, - } +/// A stream of simple query results. 
+#[pin_project] +pub struct SimpleQueryStream { + responses: Responses, + columns: Option>, + #[pin] + _p: PhantomPinned, } impl Stream for SimpleQueryStream { diff --git a/tokio-postgres/src/types.rs b/tokio-postgres/src/types.rs index b2e15d059..47f4b7875 100644 --- a/tokio-postgres/src/types.rs +++ b/tokio-postgres/src/types.rs @@ -4,3 +4,45 @@ #[doc(inline)] pub use postgres_types::*; + +use std::fmt; + +/// Log Sequence Number for PostgreSQL Write-Ahead Log (transaction log). +#[derive(Clone, Copy, Eq, Ord, PartialEq, PartialOrd)] +pub struct Lsn(u64); + +impl From<&str> for Lsn { + fn from(lsn_str: &str) -> Self { + let split: Vec<&str> = lsn_str.split('/').collect(); + assert_eq!(split.len(), 2); + let (hi, lo) = ( + u64::from_str_radix(split[0], 16).unwrap(), + u64::from_str_radix(split[1], 16).unwrap(), + ); + Lsn((hi << 32) | lo) + } +} + +impl From for Lsn { + fn from(lsn_u64: u64) -> Self { + Lsn(lsn_u64) + } +} + +impl From for u64 { + fn from(lsn: Lsn) -> u64 { + lsn.0 + } +} + +impl From for String { + fn from(lsn: Lsn) -> String { + format!("{:X}/{:X}", lsn.0 >> 32, lsn.0 & 0x00000000ffffffff) + } +} + +impl fmt::Debug for Lsn { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("Lsn").field(&String::from(*self)).finish() + } +} diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index bf6d72d3e..18976496a 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -18,6 +18,7 @@ use tokio_postgres::{ mod binary_copy; mod parse; +mod replication; #[cfg(feature = "runtime")] mod runtime; mod types; diff --git a/tokio-postgres/tests/test/replication.rs b/tokio-postgres/tests/test/replication.rs new file mode 100644 index 000000000..2323a123d --- /dev/null +++ b/tokio-postgres/tests/test/replication.rs @@ -0,0 +1,230 @@ +use postgres_protocol::message::backend::ReplicationMessage; +use tokio::stream::StreamExt; +use tokio_postgres::replication_client::ReplicationClient; +use tokio_postgres::Client; +use tokio_postgres::{connect, connect_replication, NoTls, ReplicationMode}; + +const LOGICAL_BEGIN_TAG: u8 = b'B'; +const LOGICAL_COMMIT_TAG: u8 = b'C'; +const LOGICAL_INSERT_TAG: u8 = b'I'; + +// Tests missing for timeline_history(). For a timeline history to be +// available, it requires a point-in-time-recovery or a standby +// promotion; neither of which is done in the current test setup. 
+ +// test for: +// - identify_system +// - show +// - slot create/drop +// - physical replication +#[tokio::test] +async fn physical_replication() { + let (sclient, mut rclient) = setup(ReplicationMode::Physical).await; + + simple_exec(&sclient, "drop table if exists test_physical_replication").await; + simple_exec(&sclient, "create table test_physical_replication(i int)").await; + + let identify_system = rclient.identify_system().await.unwrap(); + assert_eq!(identify_system.dbname(), None); + let show_port = rclient.show("port").await.unwrap(); + assert_eq!(show_port, "5433"); + + let slot = "test_physical_slot"; + let _ = rclient.drop_replication_slot(slot, false).await.unwrap(); + let slotdesc = rclient + .create_physical_replication_slot(slot, false, false) + .await + .unwrap(); + assert_eq!(slotdesc.slot_name(), slot); + assert_eq!(slotdesc.snapshot_name(), None); + assert_eq!(slotdesc.output_plugin(), None); + + let mut physical_stream = rclient + .start_physical_replication(None, identify_system.xlogpos(), None) + .await + .unwrap(); + + let _nrows = sclient + .execute("insert into test_physical_replication values(1)", &[]) + .await + .unwrap(); + + let mut got_xlogdata = false; + while let Some(replication_message) = physical_stream.next().await { + if let ReplicationMessage::XLogData(_) = replication_message.unwrap() { + got_xlogdata = true; + break; + } + } + + assert!(got_xlogdata); + + let response = physical_stream.stop_replication().await.unwrap(); + assert!(response.is_none()); + + // repeat simple command after stream is ended + let show_port = rclient.show("port").await.unwrap(); + assert_eq!(show_port, "5433"); + + simple_exec(&sclient, "drop table if exists test_physical_replication").await; +} + +// test for: +// - create/drop slot +// X standby_status_update +// - logical replication +#[tokio::test] +async fn logical_replication() { + let (sclient, mut rclient) = setup(ReplicationMode::Logical).await; + + simple_exec(&sclient, "drop table if exists test_logical_replication").await; + simple_exec(&sclient, "drop publication if exists test_logical_pub").await; + simple_exec(&sclient, "create table test_logical_replication(i int)").await; + simple_exec( + &sclient, + "create publication test_logical_pub for table test_logical_replication", + ) + .await; + + let identify_system = rclient.identify_system().await.unwrap(); + assert_eq!(identify_system.dbname().unwrap(), "postgres"); + + let slot = "test_logical_slot"; + let plugin = "pgoutput"; + let _ = rclient.drop_replication_slot(slot, false).await.unwrap(); + let slotdesc = rclient + .create_logical_replication_slot(slot, false, plugin, None) + .await + .unwrap(); + assert_eq!(slotdesc.slot_name(), slot); + assert!(slotdesc.snapshot_name().is_some()); + assert_eq!(slotdesc.output_plugin(), Some(plugin)); + + let xlog_start = identify_system.xlogpos(); + let options = &vec![ + ("proto_version", "1"), + ("publication_names", "test_logical_pub"), + ]; + + let mut logical_stream = rclient + .start_logical_replication(slot, xlog_start, options) + .await + .unwrap(); + + let _nrows = sclient + .execute("insert into test_logical_replication values(1)", &[]) + .await + .unwrap(); + + let mut got_begin = false; + let mut got_insert = false; + let mut got_commit = false; + while let Some(replication_message) = logical_stream.next().await { + if let ReplicationMessage::XLogData(msg) = replication_message.unwrap() { + match msg.data()[0] { + LOGICAL_BEGIN_TAG => { + assert!(!got_begin); + assert!(!got_insert); + 
assert!(!got_commit); + got_begin = true; + } + LOGICAL_INSERT_TAG => { + assert!(got_begin); + assert!(!got_insert); + assert!(!got_commit); + got_insert = true; + } + LOGICAL_COMMIT_TAG => { + assert!(got_begin); + assert!(got_insert); + assert!(!got_commit); + got_commit = true; + break; + } + _ => (), + } + } + } + + assert!(got_begin); + assert!(got_insert); + assert!(got_commit); + + simple_exec(&sclient, "drop table if exists test_logical_replication").await; + simple_exec(&sclient, "drop publication if exists test_logical_pub").await; +} + +// test for base backup +#[tokio::test] +async fn base_backup() {} + +// Test that a dropped replication stream properly returns to normal +// command processing in the ReplicationClient. +// +// This test will fail on PostgreSQL server versions earlier than the +// following patch versions: 13.2, 12.6, 11.11, 10.16, 9.6.21, +// 9.5.25. In earlier server versions, there's a bug that prevents +// pipelining requests after the client sends a CopyDone message, but +// before the server replies with a CommandComplete. +// +// Disabled until the patch is more widely available. +// #[tokio::test] +#[allow(dead_code)] +async fn drop_replication_stream() { + let (sclient, mut rclient) = setup(ReplicationMode::Physical).await; + + simple_exec(&sclient, "drop table if exists test_drop_stream").await; + simple_exec(&sclient, "create table test_drop_stream(i int)").await; + + let identify_system = rclient.identify_system().await.unwrap(); + assert_eq!(identify_system.dbname(), None); + + let mut physical_stream = rclient + .start_physical_replication(None, identify_system.xlogpos(), None) + .await + .unwrap(); + + let mut got_xlogdata = false; + while let Some(replication_message) = physical_stream.next().await { + if let ReplicationMessage::XLogData(_) = replication_message.unwrap() { + got_xlogdata = true; + break; + } + } + + assert!(got_xlogdata); + + drop(physical_stream); + + // test that simple command completes after replication stream is dropped + let show_port = rclient.show("port").await.unwrap(); + assert_eq!(show_port, "5433"); + + simple_exec(&sclient, "drop table if exists test_drop_stream").await; +} + +async fn setup(mode: ReplicationMode) -> (Client, ReplicationClient) { + let conninfo = "host=127.0.0.1 port=5433 user=postgres"; + + // form SQL connection + let (sclient, sconnection) = connect(conninfo, NoTls).await.unwrap(); + tokio::spawn(async move { + if let Err(e) = sconnection.await { + eprintln!("connection error: {}", e); + } + }); + + // form replication connection + let (rclient, rconnection) = connect_replication(conninfo, NoTls, mode).await.unwrap(); + tokio::spawn(async move { + if let Err(e) = rconnection.await { + eprintln!("connection error: {}", e); + } + }); + + (sclient, rclient) +} + +async fn simple_exec(sclient: &Client, command: &str) { + let _nrows = sclient.execute(command, &[]).await.unwrap(); +}
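Note (not part of the patch): the `Lsn` type added in tokio-postgres/src/types.rs round-trips between PostgreSQL's textual `hi/lo` form and a packed u64. A minimal usage sketch, with an arbitrary sample value:

```rust
use tokio_postgres::types::Lsn;

fn main() {
    // PostgreSQL reports WAL positions as two hex halves separated by '/'
    // (e.g. in IDENTIFY_SYSTEM output); Lsn packs them into a single u64.
    let lsn = Lsn::from("16/B374D848");
    let raw: u64 = lsn.into();
    assert_eq!(raw, (0x16u64 << 32) | 0xB374D848);

    // Converting back yields the same textual form.
    assert_eq!(String::from(lsn), "16/B374D848");
}
```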