Skip to content

Commit

Permalink
feat: track transaction state
Browse files Browse the repository at this point in the history
  • Loading branch information
sunng87 committed Sep 18, 2024
1 parent 6e93585 commit 1883661
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 13 deletions.
4 changes: 3 additions & 1 deletion src/api/auth/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,9 @@ where
)));
let mut message_stream = stream::iter(messages.into_iter().map(Ok));
client.send_all(&mut message_stream).await.unwrap();
client.set_state(PgWireConnectionState::ReadyForQuery);
client.set_state(PgWireConnectionState::ReadyForQuery(
TransactionStatus::Idle,
));
}

pub mod cleartext;
Expand Down
5 changes: 4 additions & 1 deletion src/api/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ use std::sync::Arc;

pub use postgres_types::Type;

use crate::messages::response::TransactionStatus;

pub mod auth;
pub mod copy;
pub mod portal;
Expand All @@ -21,7 +23,8 @@ pub enum PgWireConnectionState {
#[default]
AwaitingStartup,
AuthenticationInProgress,
ReadyForQuery,
// in transaction or not
ReadyForQuery(TransactionStatus),
QueryInProgress,
CopyInProgress(bool),
AwaitingSync,
Expand Down
62 changes: 54 additions & 8 deletions src/api/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,15 @@ pub trait SimpleQueryHandler: Send + Sync {
C::Error: Debug,
PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
{
// make sure client is ready for query
let mut transaction_status = match client.state() {
PgWireConnectionState::ReadyForQuery(transaction_status) => transaction_status,
_ => return Err(PgWireError::NotReadyForQuery),
};

client.set_state(super::PgWireConnectionState::QueryInProgress);
let query_string = query.query;

if is_empty_query(&query_string) {
client
.feed(PgWireBackendMessage::EmptyQueryResponse(EmptyQueryResponse))
Expand All @@ -65,10 +72,21 @@ pub trait SimpleQueryHandler: Send + Sync {
Response::Execution(tag) => {
send_execution_response(client, tag).await?;
}
Response::TransactionStart(tag) => {
send_execution_response(client, tag).await?;
transaction_status = TransactionStatus::Transaction;
}
Response::TransactionEnd(tag) => {
send_execution_response(client, tag).await?;
transaction_status = TransactionStatus::Idle;
}
Response::Error(e) => {
client
.feed(PgWireBackendMessage::ErrorResponse((*e).into()))
.await?;
if transaction_status == TransactionStatus::Transaction {
transaction_status = TransactionStatus::Error;
}
}
Response::CopyIn(result) => {
copy::send_copy_in_response(client, result).await?;
Expand All @@ -92,8 +110,10 @@ pub trait SimpleQueryHandler: Send + Sync {
// to send a `ReadyForQuery` message or reset the connection state
// back to `ReadyForQuery`. This is the responsibility of of the
// `on_copy_done` / `on_copy_fail`.
send_ready_for_query(client, TransactionStatus::Idle).await?;
client.set_state(super::PgWireConnectionState::ReadyForQuery);
send_ready_for_query(client, transaction_status).await?;
client.set_state(super::PgWireConnectionState::ReadyForQuery(
transaction_status,
));
};

Ok(())
Expand Down Expand Up @@ -180,6 +200,14 @@ pub trait ExtendedQueryHandler: Send + Sync {
C::Error: Debug,
PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
{
// make sure client is ready for query
let mut transaction_status = match client.state() {
PgWireConnectionState::ReadyForQuery(transaction_status) => transaction_status,
_ => return Err(PgWireError::NotReadyForQuery),
};

client.set_state(super::PgWireConnectionState::QueryInProgress);

let portal_name = message.name.as_deref().unwrap_or(DEFAULT_NAME);
if let Some(portal) = client.portal_store().get_portal(portal_name) {
match self
Expand All @@ -197,6 +225,15 @@ pub trait ExtendedQueryHandler: Send + Sync {
Response::Execution(tag) => {
send_execution_response(client, tag).await?;
}
Response::TransactionStart(tag) => {
send_execution_response(client, tag).await?;
transaction_status = TransactionStatus::Transaction;
}
Response::TransactionEnd(tag) => {
send_execution_response(client, tag).await?;
transaction_status = TransactionStatus::Idle;
}

Response::Error(err) => {
client
.send(PgWireBackendMessage::ErrorResponse((*err).into()))
Expand All @@ -216,6 +253,12 @@ pub trait ExtendedQueryHandler: Send + Sync {
}
}

if !matches!(client.state(), PgWireConnectionState::CopyInProgress(_)) {
client.set_state(super::PgWireConnectionState::ReadyForQuery(
transaction_status,
));
};

Ok(())
} else {
Err(PgWireError::PortalNotFound(portal_name.to_owned()))
Expand Down Expand Up @@ -266,12 +309,15 @@ pub trait ExtendedQueryHandler: Send + Sync {
C::Error: Debug,
PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
{
client
.send(PgWireBackendMessage::ReadyForQuery(ReadyForQuery::new(
TransactionStatus::Idle,
)))
.await?;
client.flush().await?;
if let PgWireConnectionState::ReadyForQuery(status) = client.state() {
client
.send(PgWireBackendMessage::ReadyForQuery(ReadyForQuery::new(
status,
)))
.await?;
client.flush().await?;
}

Ok(())
}

Expand Down
2 changes: 2 additions & 0 deletions src/api/results.rs
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,8 @@ pub enum Response<'a> {
EmptyQuery,
Query(QueryResponse<'a>),
Execution(Tag),
TransactionStart(Tag),
TransactionEnd(Tag),
Error(Box<ErrorInfo>),
CopyIn(CopyResponse),
CopyOut(CopyResponse),
Expand Down
2 changes: 2 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ pub enum PgWireError {
UnsupportedCertificateSignatureAlgorithm,
#[error("Username is required")]
UserNameRequired,
#[error("Connection is not ready for query")]
NotReadyForQuery,

#[error(transparent)]
ApiError(#[from] Box<dyn std::error::Error + 'static + Send + Sync>),
Expand Down
13 changes: 10 additions & 3 deletions src/tokio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,10 @@ where
PgWireConnectionState::AwaitingSync => {
if let PgWireFrontendMessage::Sync(sync) = message {
extended_query_handler.on_sync(socket, sync).await?;
socket.set_state(PgWireConnectionState::ReadyForQuery);
// TODO: confirm if we need to track transaction state there
socket.set_state(PgWireConnectionState::ReadyForQuery(
TransactionStatus::Idle,
));
}
}
PgWireConnectionState::CopyInProgress(is_extended_query) => {
Expand All @@ -140,7 +143,9 @@ where
// query, we should leave the CopyInProgress state
// before returning the error in order to resume normal
// operation after handling it in process_error.
socket.set_state(PgWireConnectionState::ReadyForQuery);
socket.set_state(PgWireConnectionState::ReadyForQuery(
TransactionStatus::Idle,
));
}
match result {
Ok(_) => {
Expand All @@ -166,7 +171,9 @@ where
// we should leave the CopyInProgress state
// before returning the error in order to resume normal
// operation after handling it in process_error.
socket.set_state(PgWireConnectionState::ReadyForQuery);
socket.set_state(PgWireConnectionState::ReadyForQuery(
TransactionStatus::Idle,
));
}
return Err(error);
}
Expand Down

0 comments on commit 1883661

Please sign in to comment.