@@ -24,7 +24,9 @@ use scylla::statement::{Consistency, SerialConsistency};
2424use std:: collections:: HashMap ;
2525use std:: convert:: TryInto ;
2626use std:: future:: Future ;
27+ use std:: net:: IpAddr ;
2728use std:: os:: raw:: { c_char, c_int, c_uint} ;
29+ use std:: str:: FromStr ;
2830use std:: sync:: Arc ;
2931use std:: time:: Duration ;
3032
@@ -50,6 +52,8 @@ const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_millis(5000);
5052const DEFAULT_KEEPALIVE_INTERVAL : Duration = Duration :: from_secs ( 30 ) ;
5153// - keepalive timeout is 60 secs
5254const DEFAULT_KEEPALIVE_TIMEOUT : Duration = Duration :: from_secs ( 60 ) ;
55+ // - default local ip address is arbitrary
56+ const DEFAULT_LOCAL_IP_ADDRESS : Option < IpAddr > = None ;
5357
5458const DRIVER_NAME : & str = "ScyllaDB Cpp-Rust Driver" ;
5559const DRIVER_VERSION : & str = env ! ( "CARGO_PKG_VERSION" ) ;
@@ -219,6 +223,7 @@ pub unsafe extern "C" fn cass_cluster_new() -> CassOwnedExclusivePtr<CassCluster
219223 . connection_timeout ( DEFAULT_CONNECT_TIMEOUT )
220224 . keepalive_interval ( DEFAULT_KEEPALIVE_INTERVAL )
221225 . keepalive_timeout ( DEFAULT_KEEPALIVE_TIMEOUT )
226+ . local_ip_address ( DEFAULT_LOCAL_IP_ADDRESS )
222227 } ;
223228
224229 BoxFFI :: into_ptr ( Box :: new ( CassCluster {
@@ -482,6 +487,52 @@ pub unsafe extern "C" fn cass_cluster_set_port(
482487 CassError :: CASS_OK
483488}
484489
490+ #[ unsafe( no_mangle) ]
491+ pub unsafe extern "C" fn cass_cluster_set_local_address (
492+ cluster_raw : CassBorrowedExclusivePtr < CassCluster , CMut > ,
493+ ip : * const c_char ,
494+ ) -> CassError {
495+ // Safety: We assume that string is null-terminated.
496+ unsafe { cass_cluster_set_local_address_n ( cluster_raw, ip, strlen ( ip) ) }
497+ }
498+
499+ #[ unsafe( no_mangle) ]
500+ pub unsafe extern "C" fn cass_cluster_set_local_address_n (
501+ cluster_raw : CassBorrowedExclusivePtr < CassCluster , CMut > ,
502+ ip : * const c_char ,
503+ ip_length : size_t ,
504+ ) -> CassError {
505+ let Some ( cluster) = BoxFFI :: as_mut_ref ( cluster_raw) else {
506+ tracing:: error!( "Provided null cluster pointer to cass_cluster_set_local_address_n!" ) ;
507+ return CassError :: CASS_ERROR_LIB_BAD_PARAMS ;
508+ } ;
509+
510+ // Semantics from cpp-driver - if pointer is null or length is 0, use the
511+ // arbitrary address (INADDR_ANY, or in6addr_any).
512+ let local_addr: Option < IpAddr > = if ip. is_null ( ) || ip_length == 0 {
513+ None
514+ } else {
515+ // SAFETY: We assume that user provides valid pointer and length.
516+ match unsafe { ptr_to_cstr_n ( ip, ip_length) } {
517+ Some ( ip_str) => match IpAddr :: from_str ( ip_str) {
518+ Ok ( addr) => Some ( addr) ,
519+ Err ( err) => {
520+ tracing:: error!( "Failed to parse ip address <{}>: {}" , ip_str, err) ;
521+ return CassError :: CASS_ERROR_LIB_BAD_PARAMS ;
522+ }
523+ } ,
524+ None => {
525+ tracing:: error!( "Provided non-utf8 ip string to cass_cluster_set_local_address_n!" ) ;
526+ return CassError :: CASS_ERROR_LIB_BAD_PARAMS ;
527+ }
528+ }
529+ } ;
530+
531+ cluster. session_builder . config . local_ip_address = local_addr;
532+
533+ CassError :: CASS_OK
534+ }
535+
485536#[ unsafe( no_mangle) ]
486537pub unsafe extern "C" fn cass_cluster_set_credentials (
487538 cluster : CassBorrowedExclusivePtr < CassCluster , CMut > ,
@@ -975,12 +1026,101 @@ mod tests {
9751026 exec_profile:: { cass_execution_profile_free, cass_execution_profile_new} ,
9761027 } ;
9771028 use assert_matches:: assert_matches;
1029+ use std:: net:: { Ipv4Addr , Ipv6Addr } ;
9781030 use std:: {
9791031 collections:: HashSet ,
9801032 convert:: { TryFrom , TryInto } ,
9811033 os:: raw:: c_char,
9821034 } ;
9831035
1036+ #[ test]
1037+ fn test_local_ip_address ( ) {
1038+ unsafe {
1039+ let mut cluster_raw = cass_cluster_new ( ) ;
1040+
1041+ // Check default address
1042+ {
1043+ let cluster = BoxFFI :: as_ref ( cluster_raw. borrow ( ) ) . unwrap ( ) ;
1044+ assert ! ( cluster. session_builder. config. local_ip_address. is_none( ) ) ;
1045+ }
1046+
1047+ // null ip pointer
1048+ {
1049+ assert_cass_error_eq ! (
1050+ cass_cluster_set_local_address( cluster_raw. borrow_mut( ) , std:: ptr:: null( ) ) ,
1051+ CassError :: CASS_OK
1052+ ) ;
1053+
1054+ let cluster = BoxFFI :: as_ref ( cluster_raw. borrow ( ) ) . unwrap ( ) ;
1055+ assert ! ( cluster. session_builder. config. local_ip_address. is_none( ) ) ;
1056+ }
1057+
1058+ // empty string
1059+ {
1060+ assert_cass_error_eq ! (
1061+ cass_cluster_set_local_address( cluster_raw. borrow_mut( ) , c"" . as_ptr( ) ) ,
1062+ CassError :: CASS_OK
1063+ ) ;
1064+
1065+ let cluster = BoxFFI :: as_ref ( cluster_raw. borrow ( ) ) . unwrap ( ) ;
1066+ assert ! ( cluster. session_builder. config. local_ip_address. is_none( ) ) ;
1067+ }
1068+
1069+ // valid ipv4 address
1070+ {
1071+ assert_cass_error_eq ! (
1072+ cass_cluster_set_local_address( cluster_raw. borrow_mut( ) , c"1.2.3.4" . as_ptr( ) ) ,
1073+ CassError :: CASS_OK
1074+ ) ;
1075+
1076+ let cluster = BoxFFI :: as_ref ( cluster_raw. borrow ( ) ) . unwrap ( ) ;
1077+ assert_eq ! (
1078+ cluster. session_builder. config. local_ip_address,
1079+ Some ( Ipv4Addr :: new( 1 , 2 , 3 , 4 ) . into( ) )
1080+ ) ;
1081+ }
1082+
1083+ // valid ipv6 address
1084+ {
1085+ assert_cass_error_eq ! (
1086+ cass_cluster_set_local_address(
1087+ cluster_raw. borrow_mut( ) ,
1088+ c"2001:db8::8a2e:370:7334" . as_ptr( )
1089+ ) ,
1090+ CassError :: CASS_OK
1091+ ) ;
1092+
1093+ let cluster = BoxFFI :: as_ref ( cluster_raw. borrow ( ) ) . unwrap ( ) ;
1094+ assert_eq ! (
1095+ cluster. session_builder. config. local_ip_address,
1096+ Some ( Ipv6Addr :: new( 0x2001 , 0x0db8 , 0 , 0 , 0 , 0x8a2e , 0x0370 , 0x7334 , ) . into( ) )
1097+ ) ;
1098+ }
1099+
1100+ // non-numeric address
1101+ {
1102+ assert_cass_error_eq ! (
1103+ cass_cluster_set_local_address( cluster_raw. borrow_mut( ) , c"foo" . as_ptr( ) ) ,
1104+ CassError :: CASS_ERROR_LIB_BAD_PARAMS
1105+ ) ;
1106+ }
1107+
1108+ // non-valid-utf8 slice
1109+ {
1110+ let non_utf8_slice: & [ u8 ] = & [ 0xF0 , 0x28 , 0x8C , 0x28 , 0x00 ] ;
1111+ assert_cass_error_eq ! (
1112+ cass_cluster_set_local_address(
1113+ cluster_raw. borrow_mut( ) ,
1114+ non_utf8_slice. as_ptr( ) as * const c_char
1115+ ) ,
1116+ CassError :: CASS_ERROR_LIB_BAD_PARAMS
1117+ ) ;
1118+ }
1119+
1120+ cass_cluster_free ( cluster_raw) ;
1121+ }
1122+ }
1123+
9841124 #[ test]
9851125 #[ ntest:: timeout( 100 ) ]
9861126 fn test_load_balancing_config ( ) {
0 commit comments