Skip to content

Commit

Permalink
Fix transaction not being rolled back on Client::transaction() `Fut…
Browse files Browse the repository at this point in the history
…ure` dropped before completion
  • Loading branch information
ilslv committed Oct 28, 2021
1 parent 0adcf58 commit f6189a9
Show file tree
Hide file tree
Showing 2 changed files with 159 additions and 4 deletions.
41 changes: 38 additions & 3 deletions tokio-postgres/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::codec::BackendMessages;
use crate::codec::{BackendMessages, FrontendMessage};
use crate::config::{Host, SslMode};
use crate::connection::{Request, RequestMessages};
use crate::copy_out::CopyOutStream;
Expand All @@ -19,7 +19,7 @@ use fallible_iterator::FallibleIterator;
use futures::channel::mpsc;
use futures::{future, pin_mut, ready, StreamExt, TryStreamExt};
use parking_lot::Mutex;
use postgres_protocol::message::backend::Message;
use postgres_protocol::message::{backend::Message, frontend};
use postgres_types::BorrowToSql;
use std::collections::HashMap;
use std::fmt;
Expand Down Expand Up @@ -488,7 +488,42 @@ impl Client {
///
/// The transaction will roll back by default - use the `commit` method to commit it.
pub async fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
self.batch_execute("BEGIN").await?;
struct RollbackIfNotDone<'me> {
client: &'me Client,
done: bool,
}

impl<'a> Drop for RollbackIfNotDone<'a> {
fn drop(&mut self) {
if self.done {
return;
}

let buf = self.client.inner().with_buf(|buf| {
frontend::query("ROLLBACK", buf).unwrap();
buf.split().freeze()
});
let _ = self
.client
.inner()
.send(RequestMessages::Single(FrontendMessage::Raw(buf)));
}
}

// This is done, as `Future` created by this method can be dropped after
// `RequestMessages` is synchronously send to the `Connection` by
// `batch_execute()`, but before `Responses` is asynchronously polled to
// completion. In that case `Transaction` won't be created and thus
// won't be rolled back.
{
let mut cleaner = RollbackIfNotDone {
client: self,
done: false,
};
self.batch_execute("BEGIN").await?;
cleaner.done = true;
}

Ok(Transaction::new(self))
}

Expand Down
122 changes: 121 additions & 1 deletion tokio-postgres/tests/test/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@
use bytes::{Bytes, BytesMut};
use futures::channel::mpsc;
use futures::{
future, join, pin_mut, stream, try_join, FutureExt, SinkExt, StreamExt, TryStreamExt,
future, join, pin_mut, stream, try_join, Future, FutureExt, SinkExt, StreamExt, TryStreamExt,
};
use pin_project_lite::pin_project;
use std::fmt::Write;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
use tokio::net::TcpStream;
use tokio::time;
Expand All @@ -22,6 +25,35 @@ mod parse;
mod runtime;
mod types;

pin_project! {
/// Polls `F` at most `polls_left` times returning `Some(F::Output)` if
/// [`Future`] returned [`Poll::Ready`] or [`None`] otherwise.
struct Cancellable<F> {
#[pin]
fut: F,
polls_left: usize,
}
}

impl<F: Future> Future for Cancellable<F> {
type Output = Option<F::Output>;

fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
match this.fut.poll(ctx) {
Poll::Ready(r) => Poll::Ready(Some(r)),
Poll::Pending => {
*this.polls_left = this.polls_left.saturating_sub(1);
if *this.polls_left == 0 {
Poll::Ready(None)
} else {
Poll::Pending
}
}
}
}
}

async fn connect_raw(s: &str) -> Result<(Client, Connection<TcpStream, NoTlsStream>), Error> {
let socket = TcpStream::connect("127.0.0.1:5433").await.unwrap();
let config = s.parse::<Config>().unwrap();
Expand All @@ -35,6 +67,20 @@ async fn connect(s: &str) -> Client {
client
}

async fn current_transaction_id(client: &Client) -> i64 {
client
.query("SELECT txid_current()", &[])
.await
.unwrap()
.pop()
.unwrap()
.get::<_, i64>("txid_current")
}

async fn in_transaction(client: &Client) -> bool {
current_transaction_id(client).await == current_transaction_id(client).await
}

#[tokio::test]
async fn plain_password_missing() {
connect_raw("user=pass_user dbname=postgres")
Expand Down Expand Up @@ -377,6 +423,80 @@ async fn transaction_rollback() {
assert_eq!(rows.len(), 0);
}

#[tokio::test]
async fn transaction_future_cancellation() {
let mut client = connect("user=postgres").await;

for i in 0.. {
let done = {
let txn = client.transaction();
let fut = Cancellable {
fut: txn,
polls_left: i,
};
fut.await
.map(|res| res.expect("transaction failed"))
.is_some()
};

assert!(!in_transaction(&client).await);

if done {
break;
}
}
}

#[tokio::test]
async fn transaction_commit_future_cancellation() {
let mut client = connect("user=postgres").await;

for i in 0.. {
let done = {
let txn = client.transaction().await.unwrap();
let commit = txn.commit();
let fut = Cancellable {
fut: commit,
polls_left: i,
};
fut.await
.map(|res| res.expect("transaction failed"))
.is_some()
};

assert!(!in_transaction(&client).await);

if done {
break;
}
}
}

#[tokio::test]
async fn transaction_rollback_future_cancellation() {
let mut client = connect("user=postgres").await;

for i in 0.. {
let done = {
let txn = client.transaction().await.unwrap();
let rollback = txn.rollback();
let fut = Cancellable {
fut: rollback,
polls_left: i,
};
fut.await
.map(|res| res.expect("transaction failed"))
.is_some()
};

assert!(!in_transaction(&client).await);

if done {
break;
}
}
}

#[tokio::test]
async fn transaction_rollback_drop() {
let mut client = connect("user=postgres").await;
Expand Down

0 comments on commit f6189a9

Please sign in to comment.