Skip to content

Commit

Permalink
Merge pull request #1075 from muzarski/fix-use-keyspace-errors-manage…
Browse files Browse the repository at this point in the history
…ment-logic

errors: fix driver's logic that bases on error variants returned from query execution
  • Loading branch information
wprzytula authored Sep 20, 2024
2 parents 44a3093 + ce0fdca commit ab07be6
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 71 deletions.
62 changes: 37 additions & 25 deletions scylla/src/transport/cluster.rs
Original file line number Diff line number Diff line change
Expand Up @@ -731,31 +731,7 @@ impl ClusterWorker {
let use_keyspace_results: Vec<Result<(), QueryError>> =
join_all(use_keyspace_futures).await;

// If there was at least one Ok and the rest were IoErrors we can return Ok
// keyspace name is correct and will be used on broken connection on the next reconnect

// If there were only IoErrors then return IoError
// If there was an error different than IoError return this error - something is wrong

let mut was_ok: bool = false;
let mut io_error: Option<Arc<std::io::Error>> = None;

for result in use_keyspace_results {
match result {
Ok(()) => was_ok = true,
Err(err) => match err {
QueryError::IoError(io_err) => io_error = Some(io_err),
_ => return Err(err),
},
}
}

if was_ok {
return Ok(());
}

// We can unwrap io_error because use_keyspace_futures must be nonempty
Err(QueryError::IoError(io_error.unwrap()))
use_keyspace_result(use_keyspace_results.into_iter())
}

async fn perform_refresh(&mut self) -> Result<(), QueryError> {
Expand Down Expand Up @@ -788,3 +764,39 @@ impl ClusterWorker {
self.cluster_data.store(new_cluster_data);
}
}

/// Combines the per-node/per-connection outcomes of a use_keyspace operation
/// into a single result.
///
/// The outcomes are reduced as follows:
/// - Any error other than [`QueryError::BrokenConnection`] or
///   [`QueryError::ConnectionPoolError`] is returned immediately - something is
///   wrong, and the remaining outcomes are not inspected.
/// - Otherwise, if at least one outcome was `Ok`, `Ok(())` is returned: the
///   keyspace name is correct and will be used on the broken connections on
///   the next reconnect.
/// - Otherwise every outcome was a broken connection error, and the last one
///   seen is returned.
///
/// # Panics
///
/// This function assumes that the `use_keyspace_results` iterator is NON-EMPTY
/// and panics if it yields no items.
pub(crate) fn use_keyspace_result(
    use_keyspace_results: impl Iterator<Item = Result<(), QueryError>>,
) -> Result<(), QueryError> {
    let mut was_ok = false;
    // Most recently seen broken connection error; only reported if nothing succeeded.
    let mut broken_conn_error: Option<QueryError> = None;

    for result in use_keyspace_results {
        match result {
            Ok(()) => was_ok = true,
            // Connection-level failures do not say anything about the keyspace itself,
            // so remember the error and keep looking at the remaining outcomes.
            Err(
                err @ (QueryError::BrokenConnection(_) | QueryError::ConnectionPoolError(_)),
            ) => broken_conn_error = Some(err),
            // Any other error means the request itself is wrong - report it right away.
            Err(err) => return Err(err),
        }
    }

    if was_ok {
        return Ok(());
    }

    // Reaching this point without an Ok means every outcome was a broken connection
    // error, so broken_conn_error is set as long as the iterator was non-empty.
    Err(broken_conn_error
        .expect("use_keyspace_results must be non-empty: no outcome was yielded"))
}
26 changes: 1 addition & 25 deletions scylla/src/transport/connection_pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1108,31 +1108,7 @@ impl PoolRefiller {
.await
.map_err(|_| QueryError::TimeoutError)?;

// If there was at least one Ok and the rest were IoErrors we can return Ok
// keyspace name is correct and will be used on broken connection on the next reconnect

// If there were only IoErrors then return IoError
// If there was an error different than IoError return this error - something is wrong

let mut was_ok: bool = false;
let mut io_error: Option<Arc<std::io::Error>> = None;

for result in use_keyspace_results {
match result {
Ok(()) => was_ok = true,
Err(err) => match err {
QueryError::IoError(io_err) => io_error = Some(io_err),
_ => return Err(err),
},
}
}

if was_ok {
return Ok(());
}

// We can unwrap io_error because use_keyspace_futures must be nonempty
Err(QueryError::IoError(io_error.unwrap()))
super::cluster::use_keyspace_result(use_keyspace_results.into_iter())
};

tokio::task::spawn(async move {
Expand Down
12 changes: 7 additions & 5 deletions scylla/src/transport/downgrading_consistency_retry_policy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ impl RetrySession for DowngradingConsistencyRetrySession {
match query_info.error {
// Basic errors - there are some problems on this node
// Retry on a different one if possible
QueryError::IoError(_)
QueryError::BrokenConnection(_)
| QueryError::ConnectionPoolError(_)
| QueryError::DbError(DbError::Overloaded, _)
| QueryError::DbError(DbError::ServerError, _)
| QueryError::DbError(DbError::TruncateError, _) => {
Expand Down Expand Up @@ -181,12 +182,10 @@ impl RetrySession for DowngradingConsistencyRetrySession {

#[cfg(test)]
mod tests {
use std::{io::ErrorKind, sync::Arc};

use bytes::Bytes;

use crate::test_utils::setup_tracing;
use crate::transport::errors::BadQuery;
use crate::transport::errors::{BadQuery, BrokenConnectionErrorKind, ConnectionPoolError};

use super::*;

Expand Down Expand Up @@ -328,7 +327,10 @@ mod tests {
QueryError::DbError(DbError::Overloaded, String::new()),
QueryError::DbError(DbError::TruncateError, String::new()),
QueryError::DbError(DbError::ServerError, String::new()),
QueryError::IoError(Arc::new(std::io::Error::new(ErrorKind::Other, "test"))),
QueryError::BrokenConnection(
BrokenConnectionErrorKind::TooManyOrphanedStreamIds(5).into(),
),
QueryError::ConnectionPoolError(ConnectionPoolError::Initializing),
];

for &cl in CONSISTENCY_LEVELS {
Expand Down
9 changes: 0 additions & 9 deletions scylla/src/transport/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,6 @@ pub enum QueryError {
#[error("Failed to deserialize ERROR response: {0}")]
CqlErrorParseError(#[from] CqlErrorParseError),

/// Input/Output error has occurred, connection broken etc.
#[error("IO Error: {0}")]
IoError(Arc<std::io::Error>),

/// Selected node's connection pool is in invalid state.
#[error("No connections in the pool: {0}")]
ConnectionPoolError(#[from] ConnectionPoolError),
Expand Down Expand Up @@ -154,7 +150,6 @@ impl From<QueryError> for NewSessionError {
QueryError::BadQuery(e) => NewSessionError::BadQuery(e),
QueryError::CqlResultParseError(e) => NewSessionError::CqlResultParseError(e),
QueryError::CqlErrorParseError(e) => NewSessionError::CqlErrorParseError(e),
QueryError::IoError(e) => NewSessionError::IoError(e),
QueryError::ConnectionPoolError(e) => NewSessionError::ConnectionPoolError(e),
QueryError::ProtocolError(m) => NewSessionError::ProtocolError(m),
QueryError::InvalidMessage(m) => NewSessionError::InvalidMessage(m),
Expand Down Expand Up @@ -207,10 +202,6 @@ pub enum NewSessionError {
#[error("Failed to deserialize ERROR response: {0}")]
CqlErrorParseError(#[from] CqlErrorParseError),

/// Input/Output error has occurred, connection broken etc.
#[error("IO Error: {0}")]
IoError(Arc<std::io::Error>),

/// Selected node's connection pool is in invalid state.
#[error("No connections in the pool: {0}")]
ConnectionPoolError(#[from] ConnectionPoolError),
Expand Down
1 change: 0 additions & 1 deletion scylla/src/transport/load_balancing/default.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2855,7 +2855,6 @@ mod latency_awareness {
| QueryError::CqlResultParseError(_)
| QueryError::CqlErrorParseError(_)
| QueryError::InvalidMessage(_)
| QueryError::IoError(_)
| QueryError::ProtocolError(_)
| QueryError::TimeoutError
| QueryError::RequestTimeout(_) => true,
Expand Down
14 changes: 9 additions & 5 deletions scylla/src/transport/retry_policy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,8 @@ impl RetrySession for DefaultRetrySession {
match query_info.error {
// Basic errors - there are some problems on this node
// Retry on a different one if possible
QueryError::IoError(_)
QueryError::BrokenConnection(_)
| QueryError::ConnectionPoolError(_)
| QueryError::DbError(DbError::Overloaded, _)
| QueryError::DbError(DbError::ServerError, _)
| QueryError::DbError(DbError::TruncateError, _) => {
Expand Down Expand Up @@ -221,11 +222,11 @@ mod tests {
use super::{DefaultRetryPolicy, QueryInfo, RetryDecision, RetryPolicy};
use crate::statement::Consistency;
use crate::test_utils::setup_tracing;
use crate::transport::errors::{BadQuery, QueryError};
use crate::transport::errors::{
BadQuery, BrokenConnectionErrorKind, ConnectionPoolError, QueryError,
};
use crate::transport::errors::{DbError, WriteType};
use bytes::Bytes;
use std::io::ErrorKind;
use std::sync::Arc;

fn make_query_info(error: &QueryError, is_idempotent: bool) -> QueryInfo<'_> {
QueryInfo {
Expand Down Expand Up @@ -323,7 +324,10 @@ mod tests {
QueryError::DbError(DbError::Overloaded, String::new()),
QueryError::DbError(DbError::TruncateError, String::new()),
QueryError::DbError(DbError::ServerError, String::new()),
QueryError::IoError(Arc::new(std::io::Error::new(ErrorKind::Other, "test"))),
QueryError::BrokenConnection(
BrokenConnectionErrorKind::TooManyOrphanedStreamIds(5).into(),
),
QueryError::ConnectionPoolError(ConnectionPoolError::Initializing),
];

for error in idempotent_next_errors {
Expand Down
3 changes: 2 additions & 1 deletion scylla/src/transport/speculative_execution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ impl SpeculativeExecutionPolicy for PercentileSpeculativeExecutionPolicy {
fn can_be_ignored<ResT>(result: &Result<ResT, QueryError>) -> bool {
match result {
Ok(_) => false,
Err(QueryError::IoError(_)) => true,
Err(QueryError::BrokenConnection(_)) => true,
Err(QueryError::ConnectionPoolError(_)) => true,
Err(QueryError::TimeoutError) => true,
_ => false,
}
Expand Down

0 comments on commit ab07be6

Please sign in to comment.