Skip to content

Commit 5c1a0c3

Browse files
committed
cluster: implement cass_cluster_set_local_address[_n]
In addition, set the default value and implemented unit test.
1 parent 716de88 commit 5c1a0c3

File tree

1 file changed

+140
-0
lines changed

1 file changed

+140
-0
lines changed

scylla-rust-wrapper/src/cluster.rs

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@ use scylla::statement::{Consistency, SerialConsistency};
2424
use std::collections::HashMap;
2525
use std::convert::TryInto;
2626
use std::future::Future;
27+
use std::net::IpAddr;
2728
use std::os::raw::{c_char, c_int, c_uint};
29+
use std::str::FromStr;
2830
use std::sync::Arc;
2931
use std::time::Duration;
3032

@@ -50,6 +52,8 @@ const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_millis(5000);
5052
const DEFAULT_KEEPALIVE_INTERVAL: Duration = Duration::from_secs(30);
5153
// - keepalive timeout is 60 secs
5254
const DEFAULT_KEEPALIVE_TIMEOUT: Duration = Duration::from_secs(60);
55+
// - default local ip address is arbitrary
56+
const DEFAULT_LOCAL_IP_ADDRESS: Option<IpAddr> = None;
5357

5458
const DRIVER_NAME: &str = "ScyllaDB Cpp-Rust Driver";
5559
const DRIVER_VERSION: &str = env!("CARGO_PKG_VERSION");
@@ -219,6 +223,7 @@ pub unsafe extern "C" fn cass_cluster_new() -> CassOwnedExclusivePtr<CassCluster
219223
.connection_timeout(DEFAULT_CONNECT_TIMEOUT)
220224
.keepalive_interval(DEFAULT_KEEPALIVE_INTERVAL)
221225
.keepalive_timeout(DEFAULT_KEEPALIVE_TIMEOUT)
226+
.local_ip_address(DEFAULT_LOCAL_IP_ADDRESS)
222227
};
223228

224229
BoxFFI::into_ptr(Box::new(CassCluster {
@@ -482,6 +487,52 @@ pub unsafe extern "C" fn cass_cluster_set_port(
482487
CassError::CASS_OK
483488
}
484489

490+
#[unsafe(no_mangle)]
491+
pub unsafe extern "C" fn cass_cluster_set_local_address(
492+
cluster_raw: CassBorrowedExclusivePtr<CassCluster, CMut>,
493+
ip: *const c_char,
494+
) -> CassError {
495+
// Safety: We assume that string is null-terminated.
496+
unsafe { cass_cluster_set_local_address_n(cluster_raw, ip, strlen(ip)) }
497+
}
498+
499+
#[unsafe(no_mangle)]
500+
pub unsafe extern "C" fn cass_cluster_set_local_address_n(
501+
cluster_raw: CassBorrowedExclusivePtr<CassCluster, CMut>,
502+
ip: *const c_char,
503+
ip_length: size_t,
504+
) -> CassError {
505+
let Some(cluster) = BoxFFI::as_mut_ref(cluster_raw) else {
506+
tracing::error!("Provided null cluster pointer to cass_cluster_set_local_address_n!");
507+
return CassError::CASS_ERROR_LIB_BAD_PARAMS;
508+
};
509+
510+
// Semantics from cpp-driver - if pointer is null or length is 0, use the
511+
// arbitrary address (INADDR_ANY, or in6addr_any).
512+
let local_addr: Option<IpAddr> = if ip.is_null() || ip_length == 0 {
513+
None
514+
} else {
515+
// SAFETY: We assume that user provides valid pointer and length.
516+
match unsafe { ptr_to_cstr_n(ip, ip_length) } {
517+
Some(ip_str) => match IpAddr::from_str(ip_str) {
518+
Ok(addr) => Some(addr),
519+
Err(err) => {
520+
tracing::error!("Failed to parse ip address <{}>: {}", ip_str, err);
521+
return CassError::CASS_ERROR_LIB_BAD_PARAMS;
522+
}
523+
},
524+
None => {
525+
tracing::error!("Provided non-utf8 ip string to cass_cluster_set_local_address_n!");
526+
return CassError::CASS_ERROR_LIB_BAD_PARAMS;
527+
}
528+
}
529+
};
530+
531+
cluster.session_builder.config.local_ip_address = local_addr;
532+
533+
CassError::CASS_OK
534+
}
535+
485536
#[unsafe(no_mangle)]
486537
pub unsafe extern "C" fn cass_cluster_set_credentials(
487538
cluster: CassBorrowedExclusivePtr<CassCluster, CMut>,
@@ -975,12 +1026,101 @@ mod tests {
9751026
exec_profile::{cass_execution_profile_free, cass_execution_profile_new},
9761027
};
9771028
use assert_matches::assert_matches;
1029+
use std::net::{Ipv4Addr, Ipv6Addr};
9781030
use std::{
9791031
collections::HashSet,
9801032
convert::{TryFrom, TryInto},
9811033
os::raw::c_char,
9821034
};
9831035

1036+
#[test]
1037+
fn test_local_ip_address() {
1038+
unsafe {
1039+
let mut cluster_raw = cass_cluster_new();
1040+
1041+
// Check default address
1042+
{
1043+
let cluster = BoxFFI::as_ref(cluster_raw.borrow()).unwrap();
1044+
assert!(cluster.session_builder.config.local_ip_address.is_none());
1045+
}
1046+
1047+
// null ip pointer
1048+
{
1049+
assert_cass_error_eq!(
1050+
cass_cluster_set_local_address(cluster_raw.borrow_mut(), std::ptr::null()),
1051+
CassError::CASS_OK
1052+
);
1053+
1054+
let cluster = BoxFFI::as_ref(cluster_raw.borrow()).unwrap();
1055+
assert!(cluster.session_builder.config.local_ip_address.is_none());
1056+
}
1057+
1058+
// empty string
1059+
{
1060+
assert_cass_error_eq!(
1061+
cass_cluster_set_local_address(cluster_raw.borrow_mut(), c"".as_ptr()),
1062+
CassError::CASS_OK
1063+
);
1064+
1065+
let cluster = BoxFFI::as_ref(cluster_raw.borrow()).unwrap();
1066+
assert!(cluster.session_builder.config.local_ip_address.is_none());
1067+
}
1068+
1069+
// valid ipv4 address
1070+
{
1071+
assert_cass_error_eq!(
1072+
cass_cluster_set_local_address(cluster_raw.borrow_mut(), c"1.2.3.4".as_ptr()),
1073+
CassError::CASS_OK
1074+
);
1075+
1076+
let cluster = BoxFFI::as_ref(cluster_raw.borrow()).unwrap();
1077+
assert_eq!(
1078+
cluster.session_builder.config.local_ip_address,
1079+
Some(Ipv4Addr::new(1, 2, 3, 4).into())
1080+
);
1081+
}
1082+
1083+
// valid ipv6 address
1084+
{
1085+
assert_cass_error_eq!(
1086+
cass_cluster_set_local_address(
1087+
cluster_raw.borrow_mut(),
1088+
c"2001:db8::8a2e:370:7334".as_ptr()
1089+
),
1090+
CassError::CASS_OK
1091+
);
1092+
1093+
let cluster = BoxFFI::as_ref(cluster_raw.borrow()).unwrap();
1094+
assert_eq!(
1095+
cluster.session_builder.config.local_ip_address,
1096+
Some(Ipv6Addr::new(0x2001, 0x0db8, 0, 0, 0, 0x8a2e, 0x0370, 0x7334,).into())
1097+
);
1098+
}
1099+
1100+
// non-numeric address
1101+
{
1102+
assert_cass_error_eq!(
1103+
cass_cluster_set_local_address(cluster_raw.borrow_mut(), c"foo".as_ptr()),
1104+
CassError::CASS_ERROR_LIB_BAD_PARAMS
1105+
);
1106+
}
1107+
1108+
// non-valid-utf8 slice
1109+
{
1110+
let non_utf8_slice: &[u8] = &[0xF0, 0x28, 0x8C, 0x28, 0x00];
1111+
assert_cass_error_eq!(
1112+
cass_cluster_set_local_address(
1113+
cluster_raw.borrow_mut(),
1114+
non_utf8_slice.as_ptr() as *const c_char
1115+
),
1116+
CassError::CASS_ERROR_LIB_BAD_PARAMS
1117+
);
1118+
}
1119+
1120+
cass_cluster_free(cluster_raw);
1121+
}
1122+
}
1123+
9841124
#[test]
9851125
#[ntest::timeout(100)]
9861126
fn test_load_balancing_config() {

0 commit comments

Comments
 (0)