Skip to content

cargo: update rust-driver to 0.14 #160

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Sep 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 10 additions & 14 deletions scylla-rust-wrapper/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 5 additions & 3 deletions scylla-rust-wrapper/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ categories = ["database"]
license = "MIT OR Apache-2.0"

[dependencies]
scylla = { version = "0.13.1", features = ["ssl"] }
scylla = { git = "https://github.com/scylladb/scylla-rust-driver.git", rev = "v0.14.0", features = [
"ssl",
] }
tokio = { version = "1.27.0", features = ["full"] }
lazy_static = "1.4.0"
uuid = "1.1.2"
Expand All @@ -29,11 +31,11 @@ bindgen = "0.65"
chrono = "0.4.20"

[dev-dependencies]
scylla-proxy = { git = "https://github.com/scylladb/scylla-rust-driver.git", rev = "v0.14.0" }

assert_matches = "1.5.0"
ntest = "0.9.3"
rusty-fork = "0.3.0"
scylla-proxy = { version = "0.0.4" }

[lib]
name = "scylla_cpp_driver"
crate-type = ["cdylib", "staticlib"]
Expand Down
4 changes: 4 additions & 0 deletions scylla-rust-wrapper/src/cass_error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ impl From<&QueryError> for CassError {
QueryError::UnableToAllocStreamId => CassError::CASS_ERROR_LIB_NO_STREAMS,
QueryError::RequestTimeout(_) => CassError::CASS_ERROR_LIB_REQUEST_TIMED_OUT,
QueryError::TranslationError(_) => CassError::CASS_ERROR_LIB_HOST_RESOLUTION,
QueryError::CqlResponseParseError(_) => CassError::CASS_ERROR_LIB_UNEXPECTED_RESPONSE,
}
}
}
Expand Down Expand Up @@ -83,6 +84,9 @@ impl From<&NewSessionError> for CassError {
NewSessionError::UnableToAllocStreamId => CassError::CASS_ERROR_LAST_ENTRY,
NewSessionError::RequestTimeout(_) => CassError::CASS_ERROR_LIB_REQUEST_TIMED_OUT,
NewSessionError::TranslationError(_) => CassError::CASS_ERROR_LIB_HOST_RESOLUTION,
NewSessionError::CqlResponseParseError(_) => {
CassError::CASS_ERROR_LIB_UNEXPECTED_RESPONSE
}
}
}
}
Expand Down
6 changes: 4 additions & 2 deletions scylla-rust-wrapper/src/prepared.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use scylla::frame::value::MaybeUnset::Unset;
use scylla::{frame::value::MaybeUnset::Unset, transport::PagingState};
use std::sync::Arc;

use crate::{
Expand Down Expand Up @@ -28,7 +28,9 @@ pub unsafe extern "C" fn cass_prepared_bind(
Box::into_raw(Box::new(CassStatement {
statement,
bound_values: vec![Unset; bound_values_size],
paging_state: None,
paging_state: PagingState::start(),
// Cpp driver disables paging by default.
paging_enabled: false,
request_timeout_ms: None,
exec_profile: None,
}))
Expand Down
24 changes: 15 additions & 9 deletions scylla-rust-wrapper/src/query_result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::metadata::{
use crate::types::*;
use crate::uuid::CassUuid;
use scylla::frame::response::result::{ColumnSpec, CqlValue};
use scylla::Bytes;
use scylla::transport::PagingStateResponse;
use std::convert::TryInto;
use std::os::raw::c_char;
use std::sync::Arc;
Expand All @@ -20,7 +20,7 @@ pub struct CassResult {
}

pub struct CassResultData {
pub paging_state: Option<Bytes>,
pub paging_state_response: PagingStateResponse,
pub col_specs: Vec<ColumnSpec>,
pub tracing_id: Option<Uuid>,
}
Expand Down Expand Up @@ -815,7 +815,7 @@ pub unsafe extern "C" fn cass_result_free(result_raw: *const CassResult) {
#[no_mangle]
pub unsafe extern "C" fn cass_result_has_more_pages(result: *const CassResult) -> cass_bool_t {
let result = ptr_to_ref(result);
result.metadata.paging_state.is_some() as cass_bool_t
(!result.metadata.paging_state_response.finished()) as cass_bool_t
}

#[no_mangle]
Expand Down Expand Up @@ -1298,12 +1298,18 @@ pub unsafe extern "C" fn cass_result_paging_state_token(

let result_from_raw = ptr_to_ref(result);

match &result_from_raw.metadata.paging_state {
Some(result_paging_state) => {
*paging_state_size = result_paging_state.len() as u64;
*paging_state = result_paging_state.as_ptr() as *const c_char;
}
None => {
match &result_from_raw.metadata.paging_state_response {
PagingStateResponse::HasMorePages { state } => match state.as_bytes_slice() {
Some(result_paging_state) => {
*paging_state_size = result_paging_state.len() as u64;
*paging_state = result_paging_state.as_ptr() as *const c_char;
}
None => {
*paging_state_size = 0;
*paging_state = std::ptr::null();
}
},
PagingStateResponse::NoMorePages => {
*paging_state_size = 0;
*paging_state = std::ptr::null();
}
Expand Down
39 changes: 27 additions & 12 deletions scylla-rust-wrapper/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use scylla::frame::types::Consistency;
use scylla::query::Query;
use scylla::transport::errors::QueryError;
use scylla::transport::execution_profile::ExecutionProfileHandle;
use scylla::transport::PagingStateResponse;
use scylla::{QueryResult, Session, SessionBuilder};
use std::collections::HashMap;
use std::future::Future;
Expand Down Expand Up @@ -205,7 +206,7 @@ pub unsafe extern "C" fn cass_session_execute_batch(
Ok(_result) => Ok(CassResultValue::QueryResult(Arc::new(CassResult {
rows: None,
metadata: Arc::new(CassResultData {
paging_state: None,
paging_state_response: PagingStateResponse::NoMorePages,
col_specs: vec![],
tracing_id: None,
}),
Expand Down Expand Up @@ -244,6 +245,7 @@ pub unsafe extern "C" fn cass_session_execute(
// DO NOT refer to `statement_opt` inside the async block, as I've done just to face a segfault.
let statement_opt = ptr_to_ref(statement_raw);
let paging_state = statement_opt.paging_state.clone();
let paging_enabled = statement_opt.paging_enabled;
let bound_values = statement_opt.bound_values.clone();
let request_timeout_ms = statement_opt.request_timeout_ms;

Expand Down Expand Up @@ -274,24 +276,38 @@ pub unsafe extern "C" fn cass_session_execute(
}
}

let query_res: Result<QueryResult, QueryError> = match statement {
let query_res: Result<(QueryResult, PagingStateResponse), QueryError> = match statement {
Statement::Simple(query) => {
session
.query_paged(query.query, bound_values, paging_state)
.await
if paging_enabled {
session
.query_single_page(query.query, bound_values, paging_state)
.await
} else {
session
.query_unpaged(query.query, bound_values)
.await
.map(|result| (result, PagingStateResponse::NoMorePages))
}
}
Statement::Prepared(prepared) => {
session
.execute_paged(&prepared, bound_values, paging_state)
.await
if paging_enabled {
session
.execute_single_page(&prepared, bound_values, paging_state)
.await
} else {
session
.execute_unpaged(&prepared, bound_values)
.await
.map(|result| (result, PagingStateResponse::NoMorePages))
}
}
};

match query_res {
Ok(result) => {
Ok((result, paging_state_response)) => {
let metadata = Arc::new(CassResultData {
paging_state: result.paging_state,
col_specs: result.col_specs,
paging_state_response,
col_specs: result.col_specs().to_vec(),
tracing_id: result.tracing_id,
});
let cass_rows = create_cass_rows_from_rows(result.rows, &metadata);
Expand Down Expand Up @@ -516,7 +532,6 @@ pub unsafe extern "C" fn cass_session_prepare_n(
.map_err(|err| (CassError::from(&err), err.msg()))?;

// Set Cpp Driver default configuration for queries:
prepared.disable_paging();
prepared.set_consistency(Consistency::One);

Ok(CassResultValue::Prepared(Arc::new(prepared)))
Expand Down
56 changes: 23 additions & 33 deletions scylla-rust-wrapper/src/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use scylla::frame::value::MaybeUnset::{Set, Unset};
use scylla::query::Query;
use scylla::statement::prepared_statement::PreparedStatement;
use scylla::statement::SerialConsistency;
use scylla::{BufMut, Bytes, BytesMut};
use scylla::transport::{PagingState, PagingStateResponse};
use std::collections::HashMap;
use std::convert::TryInto;
use std::os::raw::{c_char, c_int};
Expand All @@ -36,7 +36,8 @@ pub struct SimpleQuery {
pub struct CassStatement {
pub statement: Statement,
pub bound_values: Vec<MaybeUnset<Option<CassCqlValue>>>,
pub paging_state: Option<Bytes>,
pub paging_state: PagingState,
pub paging_enabled: bool,
pub request_timeout_ms: Option<cass_uint64_t>,

pub(crate) exec_profile: Option<PerStatementExecProfile>,
Expand Down Expand Up @@ -145,10 +146,7 @@ pub unsafe extern "C" fn cass_statement_new_n(
None => return std::ptr::null_mut(),
};

let mut query = Query::new(query_str.to_string());

// Set Cpp Driver default configuration for queries:
query.disable_paging();
let query = Query::new(query_str.to_string());

let simple_query = SimpleQuery {
query,
Expand All @@ -158,7 +156,9 @@ pub unsafe extern "C" fn cass_statement_new_n(
Box::into_raw(Box::new(CassStatement {
statement: Statement::Simple(simple_query),
bound_values: vec![Unset; parameter_count as usize],
paging_state: None,
paging_state: PagingState::start(),
// Cpp driver disables paging by default.
paging_enabled: false,
request_timeout_ms: None,
exec_profile: None,
}))
Expand Down Expand Up @@ -191,21 +191,15 @@ pub unsafe extern "C" fn cass_statement_set_paging_size(
statement_raw: *mut CassStatement,
page_size: c_int,
) -> CassError {
// TODO: validate page_size
match &mut ptr_to_ref_mut(statement_raw).statement {
Statement::Simple(inner) => {
if page_size == -1 {
inner.query.disable_paging()
} else {
inner.query.set_page_size(page_size)
}
}
Statement::Prepared(inner) => {
if page_size == -1 {
Arc::make_mut(inner).disable_paging()
} else {
Arc::make_mut(inner).set_page_size(page_size)
}
let statement = ptr_to_ref_mut(statement_raw);
if page_size <= 0 {
// Cpp driver sets the page size flag only for positive page size provided by user.
statement.paging_enabled = false;
} else {
statement.paging_enabled = true;
match &mut statement.statement {
Statement::Simple(inner) => inner.query.set_page_size(page_size),
Statement::Prepared(inner) => Arc::make_mut(inner).set_page_size(page_size),
}
}

Expand All @@ -220,9 +214,10 @@ pub unsafe extern "C" fn cass_statement_set_paging_state(
let statement = ptr_to_ref_mut(statement);
let result = ptr_to_ref(result);

statement
.paging_state
.clone_from(&result.metadata.paging_state);
match &result.metadata.paging_state_response {
PagingStateResponse::HasMorePages { state } => statement.paging_state.clone_from(state),
PagingStateResponse::NoMorePages => statement.paging_state = PagingState::start(),
}
CassError::CASS_OK
}

Expand All @@ -235,18 +230,13 @@ pub unsafe extern "C" fn cass_statement_set_paging_state_token(
let statement_from_raw = ptr_to_ref_mut(statement);

if paging_state.is_null() {
statement_from_raw.paging_state = None;
statement_from_raw.paging_state = PagingState::start();
return CassError::CASS_ERROR_LIB_NULL_VALUE;
}

let paging_state_usize: usize = paging_state_size.try_into().unwrap();
let mut b = BytesMut::with_capacity(paging_state_usize);
let paging_state_bytes = slice::from_raw_parts(paging_state, paging_state_usize);
for byte in paging_state_bytes {
b.put_i8(*byte);
}
statement_from_raw.paging_state = Some(b.freeze());

let paging_state_bytes = slice::from_raw_parts(paging_state as *const u8, paging_state_usize);
statement_from_raw.paging_state = PagingState::new_from_raw_bytes(paging_state_bytes);
CassError::CASS_OK
}

Expand Down
Loading
Loading