Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

errors: pager API errors refactor #1160

Merged
merged 5 commits into from
Jan 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 51 additions & 46 deletions scylla/src/client/pager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,15 @@ use crate::cluster::{ClusterState, NodeRef};
#[allow(deprecated)]
use crate::cql_to_rust::{FromRow, FromRowError};
use crate::deserialize::DeserializeOwnedRow;
use crate::errors::{ProtocolError, RequestError};
use crate::errors::{QueryError, RequestAttemptError};
use crate::errors::{RequestAttemptError, RequestError};
use crate::frame::response::result;
use crate::network::Connection;
use crate::observability::driver_tracing::RequestSpan;
use crate::observability::history::{self, HistoryListener};
use crate::observability::metrics::Metrics;
use crate::policies::load_balancing::{self, RoutingInfo};
use crate::policies::retry::{RequestInfo, RetryDecision, RetrySession};
use crate::prepared_statement::PartitionKeyError;
use crate::response::query_result::ColumnSpecs;
use crate::response::{NonErrorQueryResponse, QueryResponse};
use crate::statement::{prepared_statement::PreparedStatement, query::Query};
Expand Down Expand Up @@ -79,9 +79,7 @@ mod checked_channel_sender {
use tokio::sync::mpsc;
use uuid::Uuid;

use crate::errors::QueryError;

use super::ReceivedPage;
use super::{NextPageError, ReceivedPage};

/// A value whose existence proves that there was an attempt
/// to send an item of type T through a channel.
Expand All @@ -106,7 +104,7 @@ mod checked_channel_sender {
}
}

type ResultPage = Result<ReceivedPage, QueryError>;
type ResultPage = Result<ReceivedPage, NextPageError>;

impl ProvingSender<ResultPage> {
pub(crate) async fn send_empty_page(
Expand All @@ -127,12 +125,12 @@ mod checked_channel_sender {

use checked_channel_sender::{ProvingSender, SendAttemptedProof};

type PageSendAttemptedProof = SendAttemptedProof<Result<ReceivedPage, QueryError>>;
type PageSendAttemptedProof = SendAttemptedProof<Result<ReceivedPage, NextPageError>>;

// PagerWorker works in the background to fetch pages
// QueryPager receives them through a channel
struct PagerWorker<'a, QueryFunc, SpanCreatorFunc> {
sender: ProvingSender<Result<ReceivedPage, QueryError>>,
sender: ProvingSender<Result<ReceivedPage, NextPageError>>,

// Closure used to perform a single page query
// AsyncFn(Arc<Connection>, Option<Arc<[u8]>>) -> Result<QueryResponse, RequestAttemptError>
Expand Down Expand Up @@ -267,7 +265,10 @@ where
}

self.log_request_error(&last_error);
let (proof, _) = self.sender.send(Err(last_error.into_query_error())).await;
let (proof, _) = self
.sender
.send(Err(NextPageError::RequestFailure(last_error)))
.await;
proof
}

Expand Down Expand Up @@ -477,7 +478,7 @@ where
/// any complicated logic related to retries, it just fetches pages from
/// a single connection.
struct SingleConnectionPagerWorker<Fetcher> {
sender: ProvingSender<Result<ReceivedPage, QueryError>>,
sender: ProvingSender<Result<ReceivedPage, NextPageError>>,
fetcher: Fetcher,
}

Expand All @@ -490,21 +491,22 @@ where
match self.do_work().await {
Ok(proof) => proof,
Err(err) => {
let (proof, _) = self.sender.send(Err(err)).await;
let (proof, _) = self
.sender
.send(Err(NextPageError::RequestFailure(
RequestError::LastAttemptError(err),
)))
.await;
proof
}
}
}

async fn do_work(&mut self) -> Result<PageSendAttemptedProof, QueryError> {
async fn do_work(&mut self) -> Result<PageSendAttemptedProof, RequestAttemptError> {
let mut paging_state = PagingState::start();
loop {
let result = (self.fetcher)(paging_state)
.await
.map_err(RequestAttemptError::into_query_error)?;
let response = result
.into_non_error_query_response()
.map_err(RequestAttemptError::into_query_error)?;
let result = (self.fetcher)(paging_state).await?;
let response = result.into_non_error_query_response()?;
match response.response {
NonErrorResponse::Result(result::Result::Rows((rows, paging_state_response))) => {
let (proof, send_result) = self
Expand Down Expand Up @@ -539,10 +541,9 @@ where
return Ok(proof);
}
_ => {
return Err(ProtocolError::UnexpectedResponse(
return Err(RequestAttemptError::UnexpectedResponse(
response.response.to_response_kind(),
)
.into());
));
}
}
}
Expand All @@ -565,7 +566,7 @@ where
/// is not the intended target type.
pub struct QueryPager {
current_page: RawRowLendingIterator,
page_receiver: mpsc::Receiver<Result<ReceivedPage, QueryError>>,
page_receiver: mpsc::Receiver<Result<ReceivedPage, NextPageError>>,
tracing_ids: Vec<Uuid>,
}

Expand All @@ -583,7 +584,7 @@ impl QueryPager {
/// borrows from self.
///
/// This is cancel-safe.
async fn next(&mut self) -> Option<Result<ColumnIterator, QueryError>> {
async fn next(&mut self) -> Option<Result<ColumnIterator, NextRowError>> {
let res = std::future::poll_fn(|cx| Pin::new(&mut *self).poll_fill_page(cx)).await;
match res {
Some(Ok(())) => {}
Expand All @@ -596,15 +597,15 @@ impl QueryPager {
self.current_page
.next()
.unwrap()
.map_err(|err| NextRowError::RowDeserializationError(err).into()),
.map_err(NextRowError::RowDeserializationError),
)
}

/// Tries to acquire a non-empty page, if current page is exhausted.
fn poll_fill_page<'r>(
mut self: Pin<&'r mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<(), QueryError>>> {
) -> Poll<Option<Result<(), NextRowError>>> {
if !self.is_current_page_exhausted() {
return Poll::Ready(Some(Ok(())));
}
Expand All @@ -627,14 +628,11 @@ impl QueryPager {
fn poll_next_page<'r>(
mut self: Pin<&'r mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<(), QueryError>>> {
) -> Poll<Option<Result<(), NextRowError>>> {
let mut s = self.as_mut();

let received_page = ready_some_ok!(Pin::new(&mut s.page_receiver).poll_recv(cx));

// TODO: see my other comment next to QueryError::NextRowError
// This is the place where conversion happens. To fix this, we need to refactor error types in iterator API.
// The `page_receiver`'s error type should be narrowed from QueryError to some other error type.
let raw_rows_with_deserialized_metadata =
received_page.rows.deserialize_metadata().map_err(|err| {
NextRowError::NextPageError(NextPageError::ResultMetadataParseError(err))
Expand Down Expand Up @@ -689,8 +687,8 @@ impl QueryPager {
execution_profile: Arc<ExecutionProfileInner>,
cluster_data: Arc<ClusterState>,
metrics: Arc<Metrics>,
) -> Result<Self, QueryError> {
let (sender, receiver) = mpsc::channel(1);
) -> Result<Self, NextRowError> {
let (sender, receiver) = mpsc::channel::<Result<ReceivedPage, NextPageError>>(1);

let consistency = query
.config
Expand Down Expand Up @@ -768,8 +766,8 @@ impl QueryPager {

pub(crate) async fn new_for_prepared_statement(
config: PreparedIteratorConfig,
) -> Result<Self, QueryError> {
let (sender, receiver) = mpsc::channel(1);
) -> Result<Self, NextRowError> {
let (sender, receiver) = mpsc::channel::<Result<ReceivedPage, NextPageError>>(1);

let consistency = config
.prepared
Expand Down Expand Up @@ -803,7 +801,9 @@ impl QueryPager {
) {
Ok(res) => res.unzip(),
Err(err) => {
let (proof, _res) = ProvingSender::from(sender).send(Err(err)).await;
let (proof, _res) = ProvingSender::from(sender)
.send(Err(NextPageError::PartitionKeyError(err)))
.await;
return proof;
}
};
Expand Down Expand Up @@ -889,8 +889,8 @@ impl QueryPager {
connection: Arc<Connection>,
consistency: Consistency,
serial_consistency: Option<SerialConsistency>,
) -> Result<Self, QueryError> {
let (sender, receiver) = mpsc::channel::<Result<ReceivedPage, QueryError>>(1);
) -> Result<Self, NextRowError> {
let (sender, receiver) = mpsc::channel::<Result<ReceivedPage, NextPageError>>(1);

let page_size = query.get_validated_page_size();

Expand Down Expand Up @@ -919,8 +919,8 @@ impl QueryPager {
connection: Arc<Connection>,
consistency: Consistency,
serial_consistency: Option<SerialConsistency>,
) -> Result<Self, QueryError> {
let (sender, receiver) = mpsc::channel::<Result<ReceivedPage, QueryError>>(1);
) -> Result<Self, NextRowError> {
let (sender, receiver) = mpsc::channel::<Result<ReceivedPage, NextPageError>>(1);

let page_size = prepared.get_validated_page_size();

Expand All @@ -946,8 +946,8 @@ impl QueryPager {

async fn new_from_worker_future(
worker_task: impl Future<Output = PageSendAttemptedProof> + Send + 'static,
mut receiver: mpsc::Receiver<Result<ReceivedPage, QueryError>>,
) -> Result<Self, QueryError> {
mut receiver: mpsc::Receiver<Result<ReceivedPage, NextPageError>>,
) -> Result<Self, NextRowError> {
tokio::task::spawn(worker_task);

// This unwrap is safe because:
Expand Down Expand Up @@ -1035,14 +1035,14 @@ impl<RowT> Stream for TypedRowStream<RowT>
where
RowT: DeserializeOwnedRow,
{
type Item = Result<RowT, QueryError>;
type Item = Result<RowT, NextRowError>;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let next_fut = async {
self.raw_row_lending_stream.next().await.map(|res| {
res.and_then(|column_iterator| {
<RowT as DeserializeRow>::deserialize(column_iterator)
.map_err(|err| NextRowError::RowDeserializationError(err).into())
.map_err(NextRowError::RowDeserializationError)
})
})
};
Expand All @@ -1057,12 +1057,17 @@ where
#[derive(Error, Debug, Clone)]
#[non_exhaustive]
pub enum NextPageError {
/// PK extraction and/or token calculation error. Applies only for prepared statements.
#[error("Failed to extract PK and compute token required for routing: {0}")]
PartitionKeyError(#[from] PartitionKeyError),

/// Failed to run a request responsible for fetching new page.
#[error(transparent)]
RequestFailure(#[from] RequestError),

/// Failed to deserialize result metadata associated with next page response.
#[error("Failed to deserialize result metadata associated with next page response: {0}")]
ResultMetadataParseError(#[from] ResultMetadataAndRowsCountParseError),
// TODO: This should also include a variant representing an error that occurred during
// query that fetches the next page. However, as of now, it would require that we include QueryError here.
// This would introduce a cyclic dependency: QueryError -> NextRowError -> NextPageError -> QueryError.
Comment on lines 1068 to -1065
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Commit: "iterator: narrow error type of internal items "
❓ You removed this comment, and included "RequestFailure". This does not introduce a cycle because you used "RequestError" instead of "QueryError", right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct.

}

/// An error returned by async iterator API.
Expand Down Expand Up @@ -1172,7 +1177,7 @@ mod legacy {
pub enum LegacyNextRowError {
/// Query to fetch next page has failed
#[error(transparent)]
QueryError(#[from] QueryError),
NextRowError(#[from] NextRowError),

/// Parsing values in row as given types failed
#[error(transparent)]
Expand Down
8 changes: 6 additions & 2 deletions scylla/src/client/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ use crate::policies::host_filter::HostFilter;
use crate::policies::load_balancing::{self, RoutingInfo};
use crate::policies::retry::{RequestInfo, RetryDecision, RetrySession};
use crate::policies::speculative_execution;
use crate::prepared_statement::PreparedStatement;
use crate::prepared_statement::{PartitionKeyError, PreparedStatement};
use crate::query::Query;
#[allow(deprecated)]
use crate::response::legacy_query_result::LegacyQueryResult;
Expand Down Expand Up @@ -1235,6 +1235,7 @@ where
self.metrics.clone(),
)
.await
.map_err(QueryError::from)
} else {
// Making QueryPager::new_for_query work with values is too hard (if even possible)
// so instead of sending one prepare to a specific connection on each iterator query,
Expand All @@ -1249,6 +1250,7 @@ where
metrics: self.metrics.clone(),
})
.await
.map_err(QueryError::from)
}
}

Expand Down Expand Up @@ -1394,7 +1396,8 @@ where
let paging_state_ref = &paging_state;

let (partition_key, token) = prepared
.extract_partition_key_and_calculate_token(prepared.get_partitioner_name(), values_ref)?
.extract_partition_key_and_calculate_token(prepared.get_partitioner_name(), values_ref)
.map_err(PartitionKeyError::into_query_error)?
.unzip();

let execution_profile = prepared
Expand Down Expand Up @@ -1503,6 +1506,7 @@ where
metrics: self.metrics.clone(),
})
.await
.map_err(QueryError::from)
}

async fn do_batch(
Expand Down
Loading
Loading