diff --git a/postgres/src/client.rs b/postgres/src/client.rs index 42ce6dec9..71894529d 100644 --- a/postgres/src/client.rs +++ b/postgres/src/client.rs @@ -487,6 +487,11 @@ impl Client { self.connection.block_on(self.client.batch_execute(query)) } + /// Check the connection is alive and wait for the confirmation. + pub fn check_connection(&mut self) -> Result<(), Error> { + self.connection.block_on(self.client.check_connection()) + } + /// Begins a new database transaction. /// /// The transaction will roll back by default - use the `commit` method to commit it. diff --git a/postgres/src/test.rs b/postgres/src/test.rs index 0fd404574..4e5b49761 100644 --- a/postgres/src/test.rs +++ b/postgres/src/test.rs @@ -508,3 +508,24 @@ fn check_send() { is_send::(); is_send::>(); } + +#[test] +fn is_closed() { + let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap(); + assert!(!client.is_closed()); + client.check_connection().unwrap(); + + let row = client.query_one("select pg_backend_pid()", &[]).unwrap(); + let pid: i32 = row.get(0); + + { + let mut client2 = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap(); + client2 + .query("SELECT pg_terminate_backend($1)", &[&pid]) + .unwrap(); + } + + assert!(!client.is_closed()); + client.check_connection().unwrap_err(); + assert!(client.is_closed()); +} diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs index b38bbba37..2474c2cbd 100644 --- a/tokio-postgres/src/client.rs +++ b/tokio-postgres/src/client.rs @@ -531,6 +531,12 @@ impl Client { simple_query::batch_execute(self.inner(), query).await } + /// Check the connection is alive and wait for the confirmation. + pub async fn check_connection(&self) -> Result<(), Error> { + // sync is a very quick message to test the connection health. + query::sync(self.inner()).await + } + /// Begins a new database transaction. /// /// The transaction will roll back by default - use the `commit` method to commit it. diff --git a/tokio-postgres/src/connection.rs b/tokio-postgres/src/connection.rs index 414335955..62680a01a 100644 --- a/tokio-postgres/src/connection.rs +++ b/tokio-postgres/src/connection.rs @@ -298,14 +298,7 @@ where self.parameters.get(name).map(|s| &**s) } - /// Polls for asynchronous messages from the server. - /// - /// The server can send notices as well as notifications asynchronously to the client. Applications that wish to - /// examine those messages should use this method to drive the connection rather than its `Future` implementation. - /// - /// Return values of `None` or `Some(Err(_))` are "terminal"; callers should not invoke this method again after - /// receiving one of those values. - pub fn poll_message( + fn poll_message_inner( &mut self, cx: &mut Context<'_>, ) -> Poll>> { @@ -323,6 +316,26 @@ where }, } } + + /// Polls for asynchronous messages from the server. + /// + /// The server can send notices as well as notifications asynchronously to the client. Applications that wish to + /// examine those messages should use this method to drive the connection rather than its `Future` implementation. + /// + /// Return values of `None` or `Some(Err(_))` are "terminal"; callers should not invoke this method again after + /// receiving one of those values. + pub fn poll_message( + &mut self, + cx: &mut Context<'_>, + ) -> Poll>> { + match self.poll_message_inner(cx) { + nominal @ (Poll::Pending | Poll::Ready(Some(Ok(_)))) => nominal, + terminal @ (Poll::Ready(None) | Poll::Ready(Some(Err(_)))) => { + self.receiver.close(); + terminal + } + } + } } impl Future for Connection diff --git a/tokio-postgres/src/query.rs b/tokio-postgres/src/query.rs index 2fcb22d57..8d2aff889 100644 --- a/tokio-postgres/src/query.rs +++ b/tokio-postgres/src/query.rs @@ -323,3 +323,13 @@ impl RowStream { self.rows_affected } } + +pub async fn sync(client: &InnerClient) -> Result<(), Error> { + let buf = Bytes::from_static(b"S\0\0\0\x04"); + let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; + + match responses.next().await? { + Message::ReadyForQuery(_) => Ok(()), + _ => Err(Error::unexpected_message()), + } +} diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index 9a6aa26fe..0a83536c2 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -147,6 +147,12 @@ async fn scram_password_ok() { connect("user=scram_user password=password dbname=postgres").await; } +#[tokio::test] +async fn sync() { + let client = connect("user=postgres").await; + client.check_connection().await.unwrap(); +} + #[tokio::test] async fn pipelined_prepare() { let client = connect("user=postgres").await;