From ecfe3b44c5a5976feecf3b4f457400f982fb25c4 Mon Sep 17 00:00:00 2001 From: muzarski Date: Wed, 17 Jul 2024 11:00:34 +0200 Subject: [PATCH 01/33] value: move is_type_compatible to value module --- scylla-rust-wrapper/src/binding.rs | 6 ------ scylla-rust-wrapper/src/tuple.rs | 4 ++-- scylla-rust-wrapper/src/user_type.rs | 7 +++---- scylla-rust-wrapper/src/value.rs | 15 +++++++++++++++ 4 files changed, 20 insertions(+), 12 deletions(-) diff --git a/scylla-rust-wrapper/src/binding.rs b/scylla-rust-wrapper/src/binding.rs index 4c1d37e4..f2339604 100644 --- a/scylla-rust-wrapper/src/binding.rs +++ b/scylla-rust-wrapper/src/binding.rs @@ -47,12 +47,6 @@ //! It can be used for binding named parameter in CassStatement or field by name in CassUserType. //! * Functions from make_appender don't take any extra argument, as they are for use by CassCollection //! functions - values are appended to collection. -use crate::{cass_types::CassDataType, value::CassCqlValue}; - -pub fn is_compatible_type(_data_type: &CassDataType, _value: &Option) -> bool { - // TODO: cppdriver actually checks types. - true -} macro_rules! make_index_binder { ($this:ty, $consume_v:expr, $fn_by_idx:ident, $e:expr, [$($arg:ident @ $t:ty), *]) => { diff --git a/scylla-rust-wrapper/src/tuple.rs b/scylla-rust-wrapper/src/tuple.rs index 941a2f4c..302f4d21 100644 --- a/scylla-rust-wrapper/src/tuple.rs +++ b/scylla-rust-wrapper/src/tuple.rs @@ -1,8 +1,8 @@ use crate::argconv::*; -use crate::binding; use crate::cass_error::CassError; use crate::cass_types::CassDataType; use crate::types::*; +use crate::value; use crate::value::CassCqlValue; use std::sync::Arc; @@ -37,7 +37,7 @@ impl CassTuple { } if let Some(inner_types) = self.get_types() { - if !binding::is_compatible_type(&inner_types[index], &v) { + if !value::is_type_compatible(&v, &inner_types[index]) { return CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE; } } diff --git a/scylla-rust-wrapper/src/user_type.rs b/scylla-rust-wrapper/src/user_type.rs index 5082b6ec..f7900435 100644 --- a/scylla-rust-wrapper/src/user_type.rs +++ b/scylla-rust-wrapper/src/user_type.rs @@ -1,9 +1,8 @@ -use crate::argconv::*; -use crate::binding::is_compatible_type; use crate::cass_error::CassError; use crate::cass_types::CassDataType; use crate::types::*; use crate::value::CassCqlValue; +use crate::{argconv::*, value}; use std::os::raw::c_char; use std::sync::Arc; @@ -20,7 +19,7 @@ impl CassUserType { if index >= self.field_values.len() { return CassError::CASS_ERROR_LIB_INDEX_OUT_OF_BOUNDS; } - if !is_compatible_type(&self.data_type.get_udt_type().field_types[index].1, &value) { + if !value::is_type_compatible(&value, &self.data_type.get_udt_type().field_types[index].1) { return CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE; } self.field_values[index] = value; @@ -37,7 +36,7 @@ impl CassUserType { if index >= self.field_values.len() { return CassError::CASS_ERROR_LIB_INDEX_OUT_OF_BOUNDS; } - if !is_compatible_type(field_type, &value) { + if !value::is_type_compatible(&value, field_type) { return CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE; } self.field_values[index].clone_from(&value); diff --git a/scylla-rust-wrapper/src/value.rs b/scylla-rust-wrapper/src/value.rs index ffd02feb..400e37da 100644 --- a/scylla-rust-wrapper/src/value.rs +++ b/scylla-rust-wrapper/src/value.rs @@ -17,6 +17,8 @@ use scylla::{ }; use uuid::Uuid; +use crate::cass_types::CassDataType; + /// A narrower version of rust driver's CqlValue. /// /// cpp-driver's API allows to map single rust type to @@ -60,6 +62,19 @@ pub enum CassCqlValue { // TODO: custom (?), duration and decimal } +pub fn is_type_compatible(value: &Option, typ: &CassDataType) -> bool { + match value { + Some(v) => v.is_type_compatible(typ), + None => true, + } +} + +impl CassCqlValue { + pub fn is_type_compatible(&self, _typ: &CassDataType) -> bool { + true + } +} + impl SerializeValue for CassCqlValue { fn serialize<'b>( &self, From 970b711cc33ead6d859edaf39ab2609f2bd8a5c7 Mon Sep 17 00:00:00 2001 From: muzarski Date: Thu, 18 Jul 2024 13:42:40 +0200 Subject: [PATCH 02/33] value: prepare test for simple typechecks --- scylla-rust-wrapper/src/cass_types.rs | 4 +- scylla-rust-wrapper/src/value.rs | 70 +++++++++++++++++++++++++++ 2 files changed, 72 insertions(+), 2 deletions(-) diff --git a/scylla-rust-wrapper/src/cass_types.rs b/scylla-rust-wrapper/src/cass_types.rs index 6bd51192..6dd5a2ed 100644 --- a/scylla-rust-wrapper/src/cass_types.rs +++ b/scylla-rust-wrapper/src/cass_types.rs @@ -15,7 +15,7 @@ include!(concat!(env!("OUT_DIR"), "/cppdriver_data_types.rs")); include!(concat!(env!("OUT_DIR"), "/cppdriver_data_query_error.rs")); include!(concat!(env!("OUT_DIR"), "/cppdriver_batch_types.rs")); -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq)] pub struct UDTDataType { // Vec to preserve the order of types pub field_types: Vec<(String, Arc)>, @@ -95,7 +95,7 @@ impl Default for UDTDataType { } } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq)] pub enum CassDataType { Value(CassValueType), UDT(UDTDataType), diff --git a/scylla-rust-wrapper/src/value.rs b/scylla-rust-wrapper/src/value.rs index 400e37da..477a0843 100644 --- a/scylla-rust-wrapper/src/value.rs +++ b/scylla-rust-wrapper/src/value.rs @@ -297,3 +297,73 @@ fn serialize_udt<'b>( .finish() .map_err(|_| mk_ser_err::(BuiltinSerializationErrorKind::SizeOverflow)) } + +#[cfg(test)] +mod tests { + use crate::{ + cass_types::{CassDataType, CassValueType}, + value::{is_type_compatible, CassCqlValue}, + }; + + fn all_value_data_types() -> [CassDataType; 26] { + let from = |v_typ: CassValueType| CassDataType::Value(v_typ); + + [ + from(CassValueType::CASS_VALUE_TYPE_TINY_INT), + from(CassValueType::CASS_VALUE_TYPE_SMALL_INT), + from(CassValueType::CASS_VALUE_TYPE_INT), + from(CassValueType::CASS_VALUE_TYPE_BIGINT), + from(CassValueType::CASS_VALUE_TYPE_COUNTER), + from(CassValueType::CASS_VALUE_TYPE_TIME), + from(CassValueType::CASS_VALUE_TYPE_TIMESTAMP), + from(CassValueType::CASS_VALUE_TYPE_FLOAT), + from(CassValueType::CASS_VALUE_TYPE_DOUBLE), + from(CassValueType::CASS_VALUE_TYPE_BOOLEAN), + from(CassValueType::CASS_VALUE_TYPE_TEXT), + from(CassValueType::CASS_VALUE_TYPE_VARCHAR), + from(CassValueType::CASS_VALUE_TYPE_ASCII), + from(CassValueType::CASS_VALUE_TYPE_BLOB), + from(CassValueType::CASS_VALUE_TYPE_UUID), + from(CassValueType::CASS_VALUE_TYPE_TIMEUUID), + from(CassValueType::CASS_VALUE_TYPE_DATE), + from(CassValueType::CASS_VALUE_TYPE_INET), + from(CassValueType::CASS_VALUE_TYPE_DURATION), + from(CassValueType::CASS_VALUE_TYPE_DECIMAL), + from(CassValueType::CASS_VALUE_TYPE_VARINT), + from(CassValueType::CASS_VALUE_TYPE_TUPLE), + from(CassValueType::CASS_VALUE_TYPE_LIST), + from(CassValueType::CASS_VALUE_TYPE_SET), + from(CassValueType::CASS_VALUE_TYPE_MAP), + from(CassValueType::CASS_VALUE_TYPE_UDT), + ] + } + + #[test] + fn typecheck_simple_test() { + struct TestCase { + value: Option, + compatible_types: Vec, + } + + let test_cases = [ + // Null -> all types + TestCase { + value: None, + compatible_types: all_value_data_types().to_vec(), + }, + ]; + let all_simple_types = all_value_data_types(); + + for case in test_cases { + for typ in all_simple_types.iter() { + let result = is_type_compatible(&case.value, typ); + let expected = case.compatible_types.iter().any(|t| t == typ); + assert_eq!( + expected, result, + "Typecheck test for value {:?} and type {:?} failed. Expected result for the typecheck: {}", + case.value, typ, expected, + ); + } + } + } +} From 6bafea805894f44f6298681b279710a081377543 Mon Sep 17 00:00:00 2001 From: muzarski Date: Wed, 17 Jul 2024 11:08:27 +0200 Subject: [PATCH 03/33] typecheck: i8 --- scylla-rust-wrapper/src/value.rs | 34 +++++++++++++++++++++++++++++--- 1 file changed, 31 insertions(+), 3 deletions(-) diff --git a/scylla-rust-wrapper/src/value.rs b/scylla-rust-wrapper/src/value.rs index 477a0843..442b4055 100644 --- a/scylla-rust-wrapper/src/value.rs +++ b/scylla-rust-wrapper/src/value.rs @@ -17,7 +17,7 @@ use scylla::{ }; use uuid::Uuid; -use crate::cass_types::CassDataType; +use crate::cass_types::{CassDataType, CassValueType}; /// A narrower version of rust driver's CqlValue. /// @@ -70,8 +70,30 @@ pub fn is_type_compatible(value: &Option, typ: &CassDataType) -> b } impl CassCqlValue { - pub fn is_type_compatible(&self, _typ: &CassDataType) -> bool { - true + pub fn is_type_compatible(&self, typ: &CassDataType) -> bool { + match self { + CassCqlValue::TinyInt(_) => { + typ.get_value_type() == CassValueType::CASS_VALUE_TYPE_TINY_INT + } + CassCqlValue::SmallInt(_) => todo!(), + CassCqlValue::Int(_) => todo!(), + CassCqlValue::BigInt(_) => todo!(), + CassCqlValue::Float(_) => todo!(), + CassCqlValue::Double(_) => todo!(), + CassCqlValue::Boolean(_) => todo!(), + CassCqlValue::Text(_) => todo!(), + CassCqlValue::Blob(_) => todo!(), + CassCqlValue::Uuid(_) => todo!(), + CassCqlValue::Date(_) => todo!(), + CassCqlValue::Inet(_) => todo!(), + CassCqlValue::Duration(_) => todo!(), + CassCqlValue::Decimal(_) => todo!(), + CassCqlValue::Tuple(_) => todo!(), + CassCqlValue::List(_) => todo!(), + CassCqlValue::Map(_) => todo!(), + CassCqlValue::Set(_) => todo!(), + CassCqlValue::UserDefinedType { .. } => todo!(), + } } } @@ -340,6 +362,7 @@ mod tests { #[test] fn typecheck_simple_test() { + let from = |v_typ: CassValueType| CassDataType::Value(v_typ); struct TestCase { value: Option, compatible_types: Vec, @@ -351,6 +374,11 @@ mod tests { value: None, compatible_types: all_value_data_types().to_vec(), }, + // i8 -> tinyint + TestCase { + value: Some(CassCqlValue::TinyInt(Default::default())), + compatible_types: vec![from(CassValueType::CASS_VALUE_TYPE_TINY_INT)], + }, ]; let all_simple_types = all_value_data_types(); From d57ce6c195b7649d4e7ccdcdb596b5b402dd2013 Mon Sep 17 00:00:00 2001 From: muzarski Date: Wed, 17 Jul 2024 11:08:59 +0200 Subject: [PATCH 04/33] typecheck: i16 --- scylla-rust-wrapper/src/value.rs | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/scylla-rust-wrapper/src/value.rs b/scylla-rust-wrapper/src/value.rs index 442b4055..368c380a 100644 --- a/scylla-rust-wrapper/src/value.rs +++ b/scylla-rust-wrapper/src/value.rs @@ -75,7 +75,9 @@ impl CassCqlValue { CassCqlValue::TinyInt(_) => { typ.get_value_type() == CassValueType::CASS_VALUE_TYPE_TINY_INT } - CassCqlValue::SmallInt(_) => todo!(), + CassCqlValue::SmallInt(_) => { + typ.get_value_type() == CassValueType::CASS_VALUE_TYPE_SMALL_INT + } CassCqlValue::Int(_) => todo!(), CassCqlValue::BigInt(_) => todo!(), CassCqlValue::Float(_) => todo!(), @@ -379,6 +381,11 @@ mod tests { value: Some(CassCqlValue::TinyInt(Default::default())), compatible_types: vec![from(CassValueType::CASS_VALUE_TYPE_TINY_INT)], }, + // i16 -> smallint + TestCase { + value: Some(CassCqlValue::SmallInt(Default::default())), + compatible_types: vec![from(CassValueType::CASS_VALUE_TYPE_SMALL_INT)], + }, ]; let all_simple_types = all_value_data_types(); From 75e96dfe12d1531c2a45a8125c91822e24ba32b8 Mon Sep 17 00:00:00 2001 From: muzarski Date: Wed, 17 Jul 2024 11:09:44 +0200 Subject: [PATCH 05/33] typecheck: i32 --- scylla-rust-wrapper/src/value.rs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/scylla-rust-wrapper/src/value.rs b/scylla-rust-wrapper/src/value.rs index 368c380a..f3a34257 100644 --- a/scylla-rust-wrapper/src/value.rs +++ b/scylla-rust-wrapper/src/value.rs @@ -78,7 +78,7 @@ impl CassCqlValue { CassCqlValue::SmallInt(_) => { typ.get_value_type() == CassValueType::CASS_VALUE_TYPE_SMALL_INT } - CassCqlValue::Int(_) => todo!(), + CassCqlValue::Int(_) => typ.get_value_type() == CassValueType::CASS_VALUE_TYPE_INT, CassCqlValue::BigInt(_) => todo!(), CassCqlValue::Float(_) => todo!(), CassCqlValue::Double(_) => todo!(), @@ -386,6 +386,11 @@ mod tests { value: Some(CassCqlValue::SmallInt(Default::default())), compatible_types: vec![from(CassValueType::CASS_VALUE_TYPE_SMALL_INT)], }, + // i32 -> int + TestCase { + value: Some(CassCqlValue::Int(Default::default())), + compatible_types: vec![from(CassValueType::CASS_VALUE_TYPE_INT)], + }, ]; let all_simple_types = all_value_data_types(); From 4553bf9cce37ce0a54b59a95b058b51c5d28b1c9 Mon Sep 17 00:00:00 2001 From: muzarski Date: Wed, 17 Jul 2024 11:14:33 +0200 Subject: [PATCH 06/33] typecheck: i64 --- scylla-rust-wrapper/src/value.rs | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/scylla-rust-wrapper/src/value.rs b/scylla-rust-wrapper/src/value.rs index f3a34257..fcb9e11c 100644 --- a/scylla-rust-wrapper/src/value.rs +++ b/scylla-rust-wrapper/src/value.rs @@ -79,7 +79,15 @@ impl CassCqlValue { typ.get_value_type() == CassValueType::CASS_VALUE_TYPE_SMALL_INT } CassCqlValue::Int(_) => typ.get_value_type() == CassValueType::CASS_VALUE_TYPE_INT, - CassCqlValue::BigInt(_) => todo!(), + CassCqlValue::BigInt(_) => { + matches!( + typ.get_value_type(), + CassValueType::CASS_VALUE_TYPE_BIGINT + | CassValueType::CASS_VALUE_TYPE_COUNTER + | CassValueType::CASS_VALUE_TYPE_TIMESTAMP + | CassValueType::CASS_VALUE_TYPE_TIME + ) + } CassCqlValue::Float(_) => todo!(), CassCqlValue::Double(_) => todo!(), CassCqlValue::Boolean(_) => todo!(), @@ -391,6 +399,16 @@ mod tests { value: Some(CassCqlValue::Int(Default::default())), compatible_types: vec![from(CassValueType::CASS_VALUE_TYPE_INT)], }, + // i64 -> bigint/counter/time/timestamp + TestCase { + value: Some(CassCqlValue::BigInt(Default::default())), + compatible_types: vec![ + from(CassValueType::CASS_VALUE_TYPE_BIGINT), + from(CassValueType::CASS_VALUE_TYPE_COUNTER), + from(CassValueType::CASS_VALUE_TYPE_TIME), + from(CassValueType::CASS_VALUE_TYPE_TIMESTAMP), + ], + }, ]; let all_simple_types = all_value_data_types(); From a1df63557ca92dcc8982d5fe3f95be25d16bd06a Mon Sep 17 00:00:00 2001 From: muzarski Date: Wed, 17 Jul 2024 11:15:03 +0200 Subject: [PATCH 07/33] typecheck: f32 --- scylla-rust-wrapper/src/value.rs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/scylla-rust-wrapper/src/value.rs b/scylla-rust-wrapper/src/value.rs index fcb9e11c..f994a54c 100644 --- a/scylla-rust-wrapper/src/value.rs +++ b/scylla-rust-wrapper/src/value.rs @@ -88,7 +88,7 @@ impl CassCqlValue { | CassValueType::CASS_VALUE_TYPE_TIME ) } - CassCqlValue::Float(_) => todo!(), + CassCqlValue::Float(_) => typ.get_value_type() == CassValueType::CASS_VALUE_TYPE_FLOAT, CassCqlValue::Double(_) => todo!(), CassCqlValue::Boolean(_) => todo!(), CassCqlValue::Text(_) => todo!(), @@ -409,6 +409,11 @@ mod tests { from(CassValueType::CASS_VALUE_TYPE_TIMESTAMP), ], }, + // f32 -> float + TestCase { + value: Some(CassCqlValue::Float(Default::default())), + compatible_types: vec![from(CassValueType::CASS_VALUE_TYPE_FLOAT)], + }, ]; let all_simple_types = all_value_data_types(); From 8fab3056a95396add293f9f2bda72d6b9ca4099e Mon Sep 17 00:00:00 2001 From: muzarski Date: Wed, 17 Jul 2024 11:15:30 +0200 Subject: [PATCH 08/33] typecheck: f64 --- scylla-rust-wrapper/src/value.rs | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/scylla-rust-wrapper/src/value.rs b/scylla-rust-wrapper/src/value.rs index f994a54c..9a0b60b0 100644 --- a/scylla-rust-wrapper/src/value.rs +++ b/scylla-rust-wrapper/src/value.rs @@ -89,7 +89,9 @@ impl CassCqlValue { ) } CassCqlValue::Float(_) => typ.get_value_type() == CassValueType::CASS_VALUE_TYPE_FLOAT, - CassCqlValue::Double(_) => todo!(), + CassCqlValue::Double(_) => { + typ.get_value_type() == CassValueType::CASS_VALUE_TYPE_DOUBLE + } CassCqlValue::Boolean(_) => todo!(), CassCqlValue::Text(_) => todo!(), CassCqlValue::Blob(_) => todo!(), @@ -414,6 +416,11 @@ mod tests { value: Some(CassCqlValue::Float(Default::default())), compatible_types: vec![from(CassValueType::CASS_VALUE_TYPE_FLOAT)], }, + // f64 -> double + TestCase { + value: Some(CassCqlValue::Double(Default::default())), + compatible_types: vec![from(CassValueType::CASS_VALUE_TYPE_DOUBLE)], + }, ]; let all_simple_types = all_value_data_types(); From f8c7f3cd5c98638d59668d996405279470a0ec19 Mon Sep 17 00:00:00 2001 From: muzarski Date: Wed, 17 Jul 2024 11:15:55 +0200 Subject: [PATCH 09/33] typecheck: bool --- scylla-rust-wrapper/src/value.rs | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/scylla-rust-wrapper/src/value.rs b/scylla-rust-wrapper/src/value.rs index 9a0b60b0..6c8fdde7 100644 --- a/scylla-rust-wrapper/src/value.rs +++ b/scylla-rust-wrapper/src/value.rs @@ -92,7 +92,9 @@ impl CassCqlValue { CassCqlValue::Double(_) => { typ.get_value_type() == CassValueType::CASS_VALUE_TYPE_DOUBLE } - CassCqlValue::Boolean(_) => todo!(), + CassCqlValue::Boolean(_) => { + typ.get_value_type() == CassValueType::CASS_VALUE_TYPE_BOOLEAN + } CassCqlValue::Text(_) => todo!(), CassCqlValue::Blob(_) => todo!(), CassCqlValue::Uuid(_) => todo!(), @@ -421,6 +423,11 @@ mod tests { value: Some(CassCqlValue::Double(Default::default())), compatible_types: vec![from(CassValueType::CASS_VALUE_TYPE_DOUBLE)], }, + // bool -> boolean + TestCase { + value: Some(CassCqlValue::Boolean(Default::default())), + compatible_types: vec![from(CassValueType::CASS_VALUE_TYPE_BOOLEAN)], + }, ]; let all_simple_types = all_value_data_types(); From e91a44b81631137293c45dae3919033fbe0fd811 Mon Sep 17 00:00:00 2001 From: muzarski Date: Wed, 17 Jul 2024 11:18:10 +0200 Subject: [PATCH 10/33] typecheck: String Cpp-driver allows to bind c/cpp strings to: - string CQL types => ASCII/TEXT/VARCHAR - bytes CQL types => BLOB/VARINT It also allows to bind them to custom values, however we don't plan to support them in cpp-rust-driver. --- scylla-rust-wrapper/src/value.rs | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/scylla-rust-wrapper/src/value.rs b/scylla-rust-wrapper/src/value.rs index 6c8fdde7..e78ad0a7 100644 --- a/scylla-rust-wrapper/src/value.rs +++ b/scylla-rust-wrapper/src/value.rs @@ -95,7 +95,16 @@ impl CassCqlValue { CassCqlValue::Boolean(_) => { typ.get_value_type() == CassValueType::CASS_VALUE_TYPE_BOOLEAN } - CassCqlValue::Text(_) => todo!(), + CassCqlValue::Text(_) => { + matches!( + typ.get_value_type(), + CassValueType::CASS_VALUE_TYPE_TEXT + | CassValueType::CASS_VALUE_TYPE_VARCHAR + | CassValueType::CASS_VALUE_TYPE_ASCII + | CassValueType::CASS_VALUE_TYPE_BLOB + | CassValueType::CASS_VALUE_TYPE_VARINT + ) + } CassCqlValue::Blob(_) => todo!(), CassCqlValue::Uuid(_) => todo!(), CassCqlValue::Date(_) => todo!(), @@ -428,6 +437,16 @@ mod tests { value: Some(CassCqlValue::Boolean(Default::default())), compatible_types: vec![from(CassValueType::CASS_VALUE_TYPE_BOOLEAN)], }, + TestCase { + value: Some(CassCqlValue::Text(Default::default())), + compatible_types: vec![ + from(CassValueType::CASS_VALUE_TYPE_TEXT), + from(CassValueType::CASS_VALUE_TYPE_VARCHAR), + from(CassValueType::CASS_VALUE_TYPE_ASCII), + from(CassValueType::CASS_VALUE_TYPE_BLOB), + from(CassValueType::CASS_VALUE_TYPE_VARINT), + ], + }, ]; let all_simple_types = all_value_data_types(); From 1d682c88472ea7bbbccf5c0e71ea38a506aaba84 Mon Sep 17 00:00:00 2001 From: muzarski Date: Wed, 17 Jul 2024 11:20:04 +0200 Subject: [PATCH 11/33] typecheck: bytes (Vec) cpp-driver allows to bind bytes to custom values as well, but we do not plan to support custom values in cpp-rust-driver. --- scylla-rust-wrapper/src/value.rs | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/scylla-rust-wrapper/src/value.rs b/scylla-rust-wrapper/src/value.rs index e78ad0a7..ee84f5f4 100644 --- a/scylla-rust-wrapper/src/value.rs +++ b/scylla-rust-wrapper/src/value.rs @@ -105,7 +105,10 @@ impl CassCqlValue { | CassValueType::CASS_VALUE_TYPE_VARINT ) } - CassCqlValue::Blob(_) => todo!(), + CassCqlValue::Blob(_) => matches!( + typ.get_value_type(), + CassValueType::CASS_VALUE_TYPE_BLOB | CassValueType::CASS_VALUE_TYPE_VARINT + ), CassCqlValue::Uuid(_) => todo!(), CassCqlValue::Date(_) => todo!(), CassCqlValue::Inet(_) => todo!(), @@ -447,6 +450,14 @@ mod tests { from(CassValueType::CASS_VALUE_TYPE_VARINT), ], }, + // Vec -> blob/varint + TestCase { + value: Some(CassCqlValue::Blob(Default::default())), + compatible_types: vec![ + from(CassValueType::CASS_VALUE_TYPE_BLOB), + from(CassValueType::CASS_VALUE_TYPE_VARINT), + ], + }, ]; let all_simple_types = all_value_data_types(); From 9ee463c3ffe9bc2cf1827d404bf7fd27b516e250 Mon Sep 17 00:00:00 2001 From: muzarski Date: Wed, 17 Jul 2024 11:24:44 +0200 Subject: [PATCH 12/33] typecheck: uuid --- scylla-rust-wrapper/src/value.rs | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/scylla-rust-wrapper/src/value.rs b/scylla-rust-wrapper/src/value.rs index ee84f5f4..ab117862 100644 --- a/scylla-rust-wrapper/src/value.rs +++ b/scylla-rust-wrapper/src/value.rs @@ -109,7 +109,10 @@ impl CassCqlValue { typ.get_value_type(), CassValueType::CASS_VALUE_TYPE_BLOB | CassValueType::CASS_VALUE_TYPE_VARINT ), - CassCqlValue::Uuid(_) => todo!(), + CassCqlValue::Uuid(_) => matches!( + typ.get_value_type(), + CassValueType::CASS_VALUE_TYPE_UUID | CassValueType::CASS_VALUE_TYPE_TIMEUUID + ), CassCqlValue::Date(_) => todo!(), CassCqlValue::Inet(_) => todo!(), CassCqlValue::Duration(_) => todo!(), @@ -458,6 +461,14 @@ mod tests { from(CassValueType::CASS_VALUE_TYPE_VARINT), ], }, + // uuid -> uuid/timeuuid + TestCase { + value: Some(CassCqlValue::Uuid(Default::default())), + compatible_types: vec![ + from(CassValueType::CASS_VALUE_TYPE_UUID), + from(CassValueType::CASS_VALUE_TYPE_TIMEUUID), + ], + }, ]; let all_simple_types = all_value_data_types(); From ff6ae5bcf798a5410f4edf52305c5d253cd7bea9 Mon Sep 17 00:00:00 2001 From: muzarski Date: Wed, 17 Jul 2024 11:25:31 +0200 Subject: [PATCH 13/33] typecheck: u32 (date) --- scylla-rust-wrapper/src/value.rs | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/scylla-rust-wrapper/src/value.rs b/scylla-rust-wrapper/src/value.rs index ab117862..59855e24 100644 --- a/scylla-rust-wrapper/src/value.rs +++ b/scylla-rust-wrapper/src/value.rs @@ -113,7 +113,7 @@ impl CassCqlValue { typ.get_value_type(), CassValueType::CASS_VALUE_TYPE_UUID | CassValueType::CASS_VALUE_TYPE_TIMEUUID ), - CassCqlValue::Date(_) => todo!(), + CassCqlValue::Date(_) => typ.get_value_type() == CassValueType::CASS_VALUE_TYPE_DATE, CassCqlValue::Inet(_) => todo!(), CassCqlValue::Duration(_) => todo!(), CassCqlValue::Decimal(_) => todo!(), @@ -351,6 +351,8 @@ fn serialize_udt<'b>( #[cfg(test)] mod tests { + use scylla::frame::value::CqlDate; + use crate::{ cass_types::{CassDataType, CassValueType}, value::{is_type_compatible, CassCqlValue}, @@ -469,6 +471,11 @@ mod tests { from(CassValueType::CASS_VALUE_TYPE_TIMEUUID), ], }, + // u32 -> date + TestCase { + value: Some(CassCqlValue::Date(CqlDate(Default::default()))), + compatible_types: vec![from(CassValueType::CASS_VALUE_TYPE_DATE)], + }, ]; let all_simple_types = all_value_data_types(); From 23a5ad91aa1c5872e320fda4d8ab34e44292d3bc Mon Sep 17 00:00:00 2001 From: muzarski Date: Wed, 17 Jul 2024 11:26:04 +0200 Subject: [PATCH 14/33] typecheck: inet --- scylla-rust-wrapper/src/value.rs | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/scylla-rust-wrapper/src/value.rs b/scylla-rust-wrapper/src/value.rs index 59855e24..17031073 100644 --- a/scylla-rust-wrapper/src/value.rs +++ b/scylla-rust-wrapper/src/value.rs @@ -114,7 +114,7 @@ impl CassCqlValue { CassValueType::CASS_VALUE_TYPE_UUID | CassValueType::CASS_VALUE_TYPE_TIMEUUID ), CassCqlValue::Date(_) => typ.get_value_type() == CassValueType::CASS_VALUE_TYPE_DATE, - CassCqlValue::Inet(_) => todo!(), + CassCqlValue::Inet(_) => typ.get_value_type() == CassValueType::CASS_VALUE_TYPE_INET, CassCqlValue::Duration(_) => todo!(), CassCqlValue::Decimal(_) => todo!(), CassCqlValue::Tuple(_) => todo!(), @@ -351,6 +351,8 @@ fn serialize_udt<'b>( #[cfg(test)] mod tests { + use std::net::Ipv4Addr; + use scylla::frame::value::CqlDate; use crate::{ @@ -476,6 +478,13 @@ mod tests { value: Some(CassCqlValue::Date(CqlDate(Default::default()))), compatible_types: vec![from(CassValueType::CASS_VALUE_TYPE_DATE)], }, + // IpAddr -> inet + TestCase { + value: Some(CassCqlValue::Inet(std::net::IpAddr::V4( + Ipv4Addr::LOCALHOST, + ))), + compatible_types: vec![from(CassValueType::CASS_VALUE_TYPE_INET)], + }, ]; let all_simple_types = all_value_data_types(); From 506185e2fa1acbe9120f48bee3c32a679140cc25 Mon Sep 17 00:00:00 2001 From: muzarski Date: Mon, 5 Aug 2024 09:22:47 +0200 Subject: [PATCH 15/33] typecheck: CqlDuration --- scylla-rust-wrapper/src/value.rs | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/scylla-rust-wrapper/src/value.rs b/scylla-rust-wrapper/src/value.rs index 17031073..32e61d95 100644 --- a/scylla-rust-wrapper/src/value.rs +++ b/scylla-rust-wrapper/src/value.rs @@ -115,7 +115,9 @@ impl CassCqlValue { ), CassCqlValue::Date(_) => typ.get_value_type() == CassValueType::CASS_VALUE_TYPE_DATE, CassCqlValue::Inet(_) => typ.get_value_type() == CassValueType::CASS_VALUE_TYPE_INET, - CassCqlValue::Duration(_) => todo!(), + CassCqlValue::Duration(_) => { + typ.get_value_type() == CassValueType::CASS_VALUE_TYPE_DURATION + } CassCqlValue::Decimal(_) => todo!(), CassCqlValue::Tuple(_) => todo!(), CassCqlValue::List(_) => todo!(), @@ -353,7 +355,7 @@ fn serialize_udt<'b>( mod tests { use std::net::Ipv4Addr; - use scylla::frame::value::CqlDate; + use scylla::frame::value::{CqlDate, CqlDuration}; use crate::{ cass_types::{CassDataType, CassValueType}, @@ -485,6 +487,15 @@ mod tests { ))), compatible_types: vec![from(CassValueType::CASS_VALUE_TYPE_INET)], }, + // CqlDuration -> duration + TestCase { + value: Some(CassCqlValue::Duration(CqlDuration { + months: 0, + days: 0, + nanoseconds: 0, + })), + compatible_types: vec![from(CassValueType::CASS_VALUE_TYPE_DURATION)], + }, ]; let all_simple_types = all_value_data_types(); From 3deb90c9addacb2e8fd843b5347dbc748ff5611c Mon Sep 17 00:00:00 2001 From: muzarski Date: Tue, 6 Aug 2024 20:22:24 +0200 Subject: [PATCH 16/33] typecheck: CqlDecimal Note: I believe that I previously mentioned that cpp-driver maps `CassDecimal` to varint as well, but this is not true. --- scylla-rust-wrapper/src/value.rs | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/scylla-rust-wrapper/src/value.rs b/scylla-rust-wrapper/src/value.rs index 32e61d95..d3d06830 100644 --- a/scylla-rust-wrapper/src/value.rs +++ b/scylla-rust-wrapper/src/value.rs @@ -118,7 +118,9 @@ impl CassCqlValue { CassCqlValue::Duration(_) => { typ.get_value_type() == CassValueType::CASS_VALUE_TYPE_DURATION } - CassCqlValue::Decimal(_) => todo!(), + CassCqlValue::Decimal(_) => { + typ.get_value_type() == CassValueType::CASS_VALUE_TYPE_DECIMAL + } CassCqlValue::Tuple(_) => todo!(), CassCqlValue::List(_) => todo!(), CassCqlValue::Map(_) => todo!(), @@ -355,7 +357,7 @@ fn serialize_udt<'b>( mod tests { use std::net::Ipv4Addr; - use scylla::frame::value::{CqlDate, CqlDuration}; + use scylla::frame::value::{CqlDate, CqlDecimal, CqlDuration}; use crate::{ cass_types::{CassDataType, CassValueType}, @@ -496,6 +498,13 @@ mod tests { })), compatible_types: vec![from(CassValueType::CASS_VALUE_TYPE_DURATION)], }, + // CqlDecimal -> decimal + TestCase { + value: Some(CassCqlValue::Decimal( + CqlDecimal::from_signed_be_bytes_slice_and_exponent(&[], 0), + )), + compatible_types: vec![from(CassValueType::CASS_VALUE_TYPE_DECIMAL)], + }, ]; let all_simple_types = all_value_data_types(); From 5a513cf4cdab184ac61ddd19d3a0b577d047eb47 Mon Sep 17 00:00:00 2001 From: muzarski Date: Wed, 17 Jul 2024 15:19:33 +0200 Subject: [PATCH 17/33] types: introduce typecheck_equals method This method will be used by `value::is_type_compatible` for compound types, i.e. tuples, collections and udts. Corresponding CassCqlValue variants will hold info about the DataType of value. To perform the typecheck, we will simply compare two types - type of the value with the type provided to `is_type_compatible` function. --- scylla-rust-wrapper/src/cass_types.rs | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/scylla-rust-wrapper/src/cass_types.rs b/scylla-rust-wrapper/src/cass_types.rs index 6dd5a2ed..1c350842 100644 --- a/scylla-rust-wrapper/src/cass_types.rs +++ b/scylla-rust-wrapper/src/cass_types.rs @@ -116,6 +116,25 @@ pub enum CassDataType { Custom(String), } +impl CassDataType { + /// Checks for equality during typechecks. + /// + /// This takes into account the fact that tuples/collections may be untyped. + pub fn typecheck_equals(&self, other: &CassDataType) -> bool { + match self { + CassDataType::Value(t) => *t == other.get_value_type(), + CassDataType::UDT(_) => todo!(), + CassDataType::List { .. } => todo!(), + CassDataType::Set { .. } => todo!(), + CassDataType::Map { .. } => todo!(), + CassDataType::Tuple(_) => todo!(), + CassDataType::Custom(_) => { + unimplemented!("Cpp-rust-driver does not support custom types!") + } + } + } +} + impl From for CassValueType { fn from(native_type: NativeType) -> CassValueType { match native_type { From ab55c6ab42f264f3bad920c8ccdc49f2f5b5f8b2 Mon Sep 17 00:00:00 2001 From: muzarski Date: Wed, 17 Jul 2024 15:26:06 +0200 Subject: [PATCH 18/33] tuple_type: typecheck_equals --- scylla-rust-wrapper/src/cass_types.rs | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/scylla-rust-wrapper/src/cass_types.rs b/scylla-rust-wrapper/src/cass_types.rs index 1c350842..0f6995e2 100644 --- a/scylla-rust-wrapper/src/cass_types.rs +++ b/scylla-rust-wrapper/src/cass_types.rs @@ -112,6 +112,7 @@ pub enum CassDataType { val_type: Option>, frozen: bool, }, + // Empty vector stands for untyped tuple. Tuple(Vec>), Custom(String), } @@ -127,7 +128,23 @@ impl CassDataType { CassDataType::List { .. } => todo!(), CassDataType::Set { .. } => todo!(), CassDataType::Map { .. } => todo!(), - CassDataType::Tuple(_) => todo!(), + CassDataType::Tuple(sub) => match other { + CassDataType::Tuple(other_sub) => { + // If either of tuples is untyped, skip the typecheck for subtypes. + if sub.is_empty() || other_sub.is_empty() { + return true; + } + + // If both are non-empty, check for subtypes equality. + if sub.len() != other_sub.len() { + return false; + } + sub.iter() + .zip(other_sub.iter()) + .all(|(typ, other_typ)| typ.typecheck_equals(other_typ)) + } + _ => false, + }, CassDataType::Custom(_) => { unimplemented!("Cpp-rust-driver does not support custom types!") } From 4e1b9c503ad30789fde5007d4915f52a61079f7e Mon Sep 17 00:00:00 2001 From: muzarski Date: Wed, 17 Jul 2024 15:35:56 +0200 Subject: [PATCH 19/33] udt_type: typecheck_equals --- scylla-rust-wrapper/src/cass_types.rs | 41 ++++++++++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) diff --git a/scylla-rust-wrapper/src/cass_types.rs b/scylla-rust-wrapper/src/cass_types.rs index 0f6995e2..1bbf3b67 100644 --- a/scylla-rust-wrapper/src/cass_types.rs +++ b/scylla-rust-wrapper/src/cass_types.rs @@ -87,6 +87,42 @@ impl UDTDataType { pub fn get_field_by_index(&self, index: usize) -> Option<&Arc> { self.field_types.get(index).map(|(_, b)| b) } + + fn typecheck_equals(&self, other: &UDTDataType) -> bool { + // See: https://github.com/scylladb/cpp-driver/blob/master/src/data_type.hpp#L354-L386 + + if !any_string_empty_or_both_equal(&self.keyspace, &other.keyspace) { + return false; + } + if !any_string_empty_or_both_equal(&self.name, &other.name) { + return false; + } + + // A comment from cpp-driver: + //// UDT's can be considered equal as long as the mutual first fields shared + //// between them are equal. UDT's are append only as far as fields go, so a + //// newer 'version' of the UDT data type after a schema change event should be + //// treated as equivalent in this scenario, by simply looking at the first N + //// mutual fields they should share. + // + // Iterator returned from zip() is perfect for checking the first mutual fields. + for (field, other_field) in self.field_types.iter().zip(other.field_types.iter()) { + // Compare field names. + if field.0 != other_field.0 { + return false; + } + // Compare field types. + if !field.1.typecheck_equals(&other_field.1) { + return false; + } + } + + true + } +} + +fn any_string_empty_or_both_equal(s1: &str, s2: &str) -> bool { + s1.is_empty() || s2.is_empty() || s1 == s2 } impl Default for UDTDataType { @@ -124,7 +160,10 @@ impl CassDataType { pub fn typecheck_equals(&self, other: &CassDataType) -> bool { match self { CassDataType::Value(t) => *t == other.get_value_type(), - CassDataType::UDT(_) => todo!(), + CassDataType::UDT(udt) => match other { + CassDataType::UDT(other_udt) => udt.typecheck_equals(other_udt), + _ => false, + }, CassDataType::List { .. } => todo!(), CassDataType::Set { .. } => todo!(), CassDataType::Map { .. } => todo!(), From 761b0806b2157e3d49c4add9ab1fbee3ec2def84 Mon Sep 17 00:00:00 2001 From: muzarski Date: Wed, 17 Jul 2024 16:01:10 +0200 Subject: [PATCH 20/33] list/set type: typecheck_equals --- scylla-rust-wrapper/src/cass_types.rs | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/scylla-rust-wrapper/src/cass_types.rs b/scylla-rust-wrapper/src/cass_types.rs index 1bbf3b67..16cd5eb4 100644 --- a/scylla-rust-wrapper/src/cass_types.rs +++ b/scylla-rust-wrapper/src/cass_types.rs @@ -136,10 +136,12 @@ pub enum CassDataType { Value(CassValueType), UDT(UDTDataType), List { + // None stands for untyped list. typ: Option>, frozen: bool, }, Set { + // None stands for untyped set. typ: Option>, frozen: bool, }, @@ -164,8 +166,21 @@ impl CassDataType { CassDataType::UDT(other_udt) => udt.typecheck_equals(other_udt), _ => false, }, - CassDataType::List { .. } => todo!(), - CassDataType::Set { .. } => todo!(), + CassDataType::List { typ, .. } | CassDataType::Set { typ, .. } => match other { + CassDataType::List { typ: other_typ, .. } + | CassDataType::Set { typ: other_typ, .. } => { + // If one of them is list, and the other is set, fail the typecheck. + if self.get_value_type() != other.get_value_type() { + return false; + } + match (typ, other_typ) { + // One of them is untyped, skip the typecheck for subtype. + (None, _) | (_, None) => true, + (Some(typ), Some(other_typ)) => typ.typecheck_equals(other_typ), + } + } + _ => false, + }, CassDataType::Map { .. } => todo!(), CassDataType::Tuple(sub) => match other { CassDataType::Tuple(other_sub) => { From c3a2632b9513872e490a37819c2958b7573fe8c9 Mon Sep 17 00:00:00 2001 From: muzarski Date: Wed, 17 Jul 2024 16:18:53 +0200 Subject: [PATCH 21/33] map type: typecheck_equals --- scylla-rust-wrapper/src/cass_types.rs | 30 ++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/scylla-rust-wrapper/src/cass_types.rs b/scylla-rust-wrapper/src/cass_types.rs index 16cd5eb4..12da2384 100644 --- a/scylla-rust-wrapper/src/cass_types.rs +++ b/scylla-rust-wrapper/src/cass_types.rs @@ -146,6 +146,8 @@ pub enum CassDataType { frozen: bool, }, Map { + // None, None stands for untyped map. + // Some, None stands for a map with an untyped value type. key_type: Option>, val_type: Option>, frozen: bool, @@ -181,7 +183,33 @@ impl CassDataType { } _ => false, }, - CassDataType::Map { .. } => todo!(), + CassDataType::Map { + key_type: k, + val_type: v, + .. + } => match other { + CassDataType::Map { + key_type: k_other, + val_type: v_other, + .. + } => match ((k, v), (k_other, v_other)) { + // See https://github.com/scylladb/cpp-driver/blob/master/src/data_type.hpp#L218 + // In cpp-driver the types are held in a vector. + // The logic is following: + + // If either of vectors is empty, skip the typecheck. + ((None, None), _) => true, + (_, (None, None)) => true, + + // Otherwise, the vectors should have equal length and we perform the typecheck for subtypes. + ((Some(k), None), (Some(k_other), None)) => k.typecheck_equals(k_other), + ((Some(k), Some(v)), (Some(k_other), Some(v_other))) => { + k.typecheck_equals(k_other) && v.typecheck_equals(v_other) + } + _ => false, + }, + _ => false, + }, CassDataType::Tuple(sub) => match other { CassDataType::Tuple(other_sub) => { // If either of tuples is untyped, skip the typecheck for subtypes. From cc0b34459010eae275e540043ea9337229f8ea2b Mon Sep 17 00:00:00 2001 From: muzarski Date: Thu, 18 Jul 2024 14:32:37 +0200 Subject: [PATCH 22/33] value: prepare test for complex typechecks --- scylla-rust-wrapper/src/value.rs | 32 +++++++++++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/scylla-rust-wrapper/src/value.rs b/scylla-rust-wrapper/src/value.rs index d3d06830..91a4dccd 100644 --- a/scylla-rust-wrapper/src/value.rs +++ b/scylla-rust-wrapper/src/value.rs @@ -355,7 +355,7 @@ fn serialize_udt<'b>( #[cfg(test)] mod tests { - use std::net::Ipv4Addr; + use std::{net::Ipv4Addr, sync::Arc}; use scylla::frame::value::{CqlDate, CqlDecimal, CqlDuration}; @@ -520,4 +520,34 @@ mod tests { } } } + + #[test] + fn typecheck_complex_test() { + struct TestCase { + value: CassCqlValue, + compatible_types: Vec>, + incompatible_types: Vec>, + } + + let run_test_cases = |test_cases: &[TestCase]| { + for case in test_cases { + for typ in case.compatible_types.iter() { + assert!( + case.value.is_type_compatible(typ), + "Typecheck failed, when it should pass. Value: {:?}, Type: {:?}", + case.value, + typ + ); + } + for typ in case.incompatible_types.iter() { + assert!( + !case.value.is_type_compatible(typ), + "Typecheck passed, when it should fail. Value: {:?}, Type: {:?}", + case.value, + typ + ) + } + } + }; + } } From 91341ecd0b76a51f7f04412b029c55b7d353567f Mon Sep 17 00:00:00 2001 From: muzarski Date: Wed, 17 Jul 2024 17:48:46 +0200 Subject: [PATCH 23/33] value: implement is_type_compatible for tuples --- scylla-rust-wrapper/src/tuple.rs | 5 +- scylla-rust-wrapper/src/value.rs | 124 ++++++++++++++++++++++++++++++- 2 files changed, 124 insertions(+), 5 deletions(-) diff --git a/scylla-rust-wrapper/src/tuple.rs b/scylla-rust-wrapper/src/tuple.rs index 302f4d21..dd8563ae 100644 --- a/scylla-rust-wrapper/src/tuple.rs +++ b/scylla-rust-wrapper/src/tuple.rs @@ -50,7 +50,10 @@ impl CassTuple { impl From<&CassTuple> for CassCqlValue { fn from(tuple: &CassTuple) -> Self { - CassCqlValue::Tuple(tuple.items.clone()) + CassCqlValue::Tuple { + data_type: tuple.data_type.clone(), + fields: tuple.items.clone(), + } } } diff --git a/scylla-rust-wrapper/src/value.rs b/scylla-rust-wrapper/src/value.rs index 91a4dccd..56878d45 100644 --- a/scylla-rust-wrapper/src/value.rs +++ b/scylla-rust-wrapper/src/value.rs @@ -1,4 +1,4 @@ -use std::{convert::TryInto, net::IpAddr}; +use std::{convert::TryInto, net::IpAddr, sync::Arc}; use scylla::{ frame::{ @@ -47,7 +47,10 @@ pub enum CassCqlValue { Inet(IpAddr), Duration(CqlDuration), Decimal(CqlDecimal), - Tuple(Vec>), + Tuple { + data_type: Option>, + fields: Vec>, + }, List(Vec), Map(Vec<(CassCqlValue, CassCqlValue)>), Set(Vec), @@ -121,7 +124,13 @@ impl CassCqlValue { CassCqlValue::Decimal(_) => { typ.get_value_type() == CassValueType::CASS_VALUE_TYPE_DECIMAL } - CassCqlValue::Tuple(_) => todo!(), + CassCqlValue::Tuple { data_type, .. } => { + if let Some(dt) = data_type { + return dt.typecheck_equals(typ); + } + // Untyped tuple. + typ.get_value_type() == CassValueType::CASS_VALUE_TYPE_TUPLE + } CassCqlValue::List(_) => todo!(), CassCqlValue::Map(_) => todo!(), CassCqlValue::Set(_) => todo!(), @@ -198,7 +207,7 @@ impl CassCqlValue { CassCqlValue::Decimal(v) => { ::serialize(v, &ColumnType::Decimal, writer) } - CassCqlValue::Tuple(fields) => serialize_tuple_like(fields.iter(), writer), + CassCqlValue::Tuple { fields, .. } => serialize_tuple_like(fields.iter(), writer), CassCqlValue::List(l) => serialize_sequence(l.len(), l.iter(), writer), CassCqlValue::Map(m) => { serialize_mapping(m.len(), m.iter().map(|p| (&p.0, &p.1)), writer) @@ -549,5 +558,112 @@ mod tests { } } }; + + // Let's make some types accessible for all test cases. + // To make sure that e.g. Tuple against UDT typecheck fails. + let data_type_float = Arc::new(CassDataType::Value(CassValueType::CASS_VALUE_TYPE_FLOAT)); + let data_type_int = Arc::new(CassDataType::Value(CassValueType::CASS_VALUE_TYPE_INT)); + let data_type_bool = Arc::new(CassDataType::Value(CassValueType::CASS_VALUE_TYPE_BOOLEAN)); + let data_type_tuple = Arc::new(CassDataType::Tuple(vec![ + data_type_float.clone(), + data_type_int.clone(), + data_type_bool.clone(), + ])); + + // TUPLES + { + let data_type_untyped_tuple = Arc::new(CassDataType::Tuple(vec![])); + let data_type_small_tuple = Arc::new(CassDataType::Tuple(vec![data_type_bool.clone()])); + let data_type_nested_tuple = Arc::new(CassDataType::Tuple(vec![ + data_type_small_tuple.clone(), + data_type_int.clone(), + data_type_tuple.clone(), + ])); + let data_type_nested_untyped_tuple = Arc::new(CassDataType::Tuple(vec![ + data_type_untyped_tuple.clone(), + data_type_int.clone(), + data_type_untyped_tuple.clone(), + ])); + + let test_cases = &[ + // Untyped tuple -> created via `cass_tuple_new` + TestCase { + value: CassCqlValue::Tuple { + data_type: None, + fields: vec![], + }, + compatible_types: vec![ + data_type_untyped_tuple.clone(), + data_type_small_tuple.clone(), + data_type_tuple.clone(), + data_type_nested_tuple.clone(), + ], + incompatible_types: vec![ + data_type_float.clone(), + data_type_int.clone(), + data_type_bool.clone(), + ], + }, + // Untyped tuple -> used created an untyped tuple data type, and then + // created a tuple value via `cass_tuple_new_from_data_type`. + TestCase { + value: CassCqlValue::Tuple { + data_type: Some(data_type_untyped_tuple.clone()), + fields: vec![], + }, + compatible_types: vec![ + data_type_untyped_tuple.clone(), + data_type_small_tuple.clone(), + data_type_tuple.clone(), + data_type_nested_tuple.clone(), + ], + incompatible_types: vec![ + data_type_float.clone(), + data_type_int.clone(), + data_type_bool.clone(), + ], + }, + // Fully typed tuple. + TestCase { + value: CassCqlValue::Tuple { + data_type: Some(data_type_tuple.clone()), + fields: vec![], + }, + compatible_types: vec![ + data_type_tuple.clone(), + data_type_untyped_tuple.clone(), + ], + incompatible_types: vec![ + data_type_float.clone(), + data_type_int.clone(), + data_type_bool.clone(), + data_type_small_tuple.clone(), + data_type_nested_tuple.clone(), + data_type_nested_tuple.clone(), + ], + }, + // Nested tuple. + TestCase { + value: CassCqlValue::Tuple { + data_type: Some(data_type_nested_tuple.clone()), + fields: vec![], + }, + compatible_types: vec![ + data_type_nested_tuple.clone(), + data_type_untyped_tuple.clone(), + data_type_nested_untyped_tuple.clone(), + ], + incompatible_types: vec![ + data_type_float.clone(), + data_type_int.clone(), + data_type_bool.clone(), + data_type_tuple.clone(), + data_type_small_tuple.clone(), + ], + }, + ]; + + run_test_cases(test_cases); + } } } From 371e8dc474c4a8d20990f9fbd67f1a15f61e51c9 Mon Sep 17 00:00:00 2001 From: muzarski Date: Wed, 17 Jul 2024 17:55:43 +0200 Subject: [PATCH 24/33] value: implement is_type_compatible for udt value --- scylla-rust-wrapper/src/user_type.rs | 3 +- scylla-rust-wrapper/src/value.rs | 76 ++++++++++++++++++++++++++-- 2 files changed, 73 insertions(+), 6 deletions(-) diff --git a/scylla-rust-wrapper/src/user_type.rs b/scylla-rust-wrapper/src/user_type.rs index f7900435..e723c7b7 100644 --- a/scylla-rust-wrapper/src/user_type.rs +++ b/scylla-rust-wrapper/src/user_type.rs @@ -54,8 +54,7 @@ impl CassUserType { impl From<&CassUserType> for CassCqlValue { fn from(user_type: &CassUserType) -> Self { CassCqlValue::UserDefinedType { - keyspace: user_type.data_type.get_udt_type().keyspace.clone(), - type_name: user_type.data_type.get_udt_type().name.clone(), + data_type: user_type.data_type.clone(), fields: user_type .field_values .iter() diff --git a/scylla-rust-wrapper/src/value.rs b/scylla-rust-wrapper/src/value.rs index 56878d45..d4dd1b26 100644 --- a/scylla-rust-wrapper/src/value.rs +++ b/scylla-rust-wrapper/src/value.rs @@ -55,8 +55,7 @@ pub enum CassCqlValue { Map(Vec<(CassCqlValue, CassCqlValue)>), Set(Vec), UserDefinedType { - keyspace: String, - type_name: String, + data_type: Arc, /// Order of `fields` vector must match the order of fields as defined in the UDT. The /// driver does not check it by itself, so incorrect data will be written if the order is /// wrong. @@ -134,7 +133,7 @@ impl CassCqlValue { CassCqlValue::List(_) => todo!(), CassCqlValue::Map(_) => todo!(), CassCqlValue::Set(_) => todo!(), - CassCqlValue::UserDefinedType { .. } => todo!(), + CassCqlValue::UserDefinedType { data_type, .. } => data_type.typecheck_equals(typ), } } } @@ -369,7 +368,7 @@ mod tests { use scylla::frame::value::{CqlDate, CqlDecimal, CqlDuration}; use crate::{ - cass_types::{CassDataType, CassValueType}, + cass_types::{CassDataType, CassValueType, UDTDataType}, value::{is_type_compatible, CassCqlValue}, }; @@ -570,6 +569,22 @@ mod tests { data_type_bool.clone(), ])); + let simple_fields = vec![ + ("foo".to_owned(), data_type_float.clone()), + ("bar".to_owned(), data_type_bool.clone()), + ("baz".to_owned(), data_type_int.clone()), + ]; + let ks_keyspace_name = "ks".to_owned(); + let user_udt_name = "user".to_owned(); + let empty_str = "".to_owned(); + + let data_type_udt_simple = Arc::new(CassDataType::UDT(UDTDataType { + field_types: simple_fields.clone(), + keyspace: ks_keyspace_name.clone(), + name: user_udt_name.clone(), + frozen: false, + })); + // TUPLES { let data_type_untyped_tuple = Arc::new(CassDataType::Tuple(vec![])); @@ -602,6 +617,7 @@ mod tests { data_type_float.clone(), data_type_int.clone(), data_type_bool.clone(), + data_type_udt_simple.clone(), ], }, // Untyped tuple -> used created an untyped tuple data type, and then @@ -621,6 +637,7 @@ mod tests { data_type_float.clone(), data_type_int.clone(), data_type_bool.clone(), + data_type_udt_simple.clone(), ], }, // Fully typed tuple. @@ -640,6 +657,7 @@ mod tests { data_type_small_tuple.clone(), data_type_nested_tuple.clone(), data_type_nested_tuple.clone(), + data_type_udt_simple.clone(), ], }, // Nested tuple. @@ -659,11 +677,61 @@ mod tests { data_type_bool.clone(), data_type_tuple.clone(), data_type_small_tuple.clone(), + data_type_udt_simple.clone(), ], }, ]; run_test_cases(test_cases); } + + // UDT + { + let data_type_udt_simple_empty_keyspace = Arc::new(CassDataType::UDT(UDTDataType { + field_types: simple_fields.clone(), + keyspace: empty_str.to_owned(), + name: user_udt_name.clone(), + frozen: false, + })); + let data_type_udt_simple_empty_name = Arc::new(CassDataType::UDT(UDTDataType { + field_types: simple_fields.clone(), + keyspace: ks_keyspace_name.clone(), + name: empty_str.clone(), + frozen: false, + })); + + // A prefix of simple_fields. + let small_fields = vec![ + ("foo".to_owned(), data_type_float.clone()), + ("bar".to_owned(), data_type_bool.clone()), + ]; + let data_type_udt_small = Arc::new(CassDataType::UDT(UDTDataType { + field_types: small_fields.clone(), + keyspace: ks_keyspace_name.clone(), + name: user_udt_name.clone(), + frozen: false, + })); + + let test_cases = &[TestCase { + value: CassCqlValue::UserDefinedType { + data_type: data_type_udt_simple.clone(), + fields: vec![], + }, + compatible_types: vec![ + data_type_udt_simple.clone(), + data_type_udt_simple_empty_keyspace.clone(), + data_type_udt_simple_empty_name.clone(), + data_type_udt_small.clone(), + ], + incompatible_types: vec![ + data_type_float.clone(), + data_type_int.clone(), + data_type_bool.clone(), + data_type_tuple.clone(), + ], + }]; + + run_test_cases(test_cases); + } } } From 17933ca848a0d9b55b8a067bd3b5e2838003b7c5 Mon Sep 17 00:00:00 2001 From: muzarski Date: Wed, 17 Jul 2024 18:27:04 +0200 Subject: [PATCH 25/33] collection: include info about data type Added information about data type of a collection. This will be needed to implement typechecks. In addition, we can implement two missing API functions for collections: - cass_collection_new_from_data_type -> create a typed collection - cass_collection_data_type -> return a data type of a collection --- README.md | 7 --- scylla-rust-wrapper/src/collection.rs | 66 ++++++++++++++++++++++++++- 2 files changed, 65 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 5ab1ddf4..5b37bfa6 100644 --- a/README.md +++ b/README.md @@ -172,13 +172,6 @@ The driver inherits almost all the features of C/C++ and Rust drivers, such as: Collection - - cass_collection_new_from_data_type - Unimplemented - - - cass_collection_data_type - cass_collection_append_custom[_n] Unimplemented because of the same reasons as binding for statements.
Note: The driver does not check whether the type of the appended value is compatible with the type of the collection items. diff --git a/scylla-rust-wrapper/src/collection.rs b/scylla-rust-wrapper/src/collection.rs index 5964a38f..ef86e43b 100644 --- a/scylla-rust-wrapper/src/collection.rs +++ b/scylla-rust-wrapper/src/collection.rs @@ -1,14 +1,32 @@ use crate::argconv::*; use crate::cass_error::CassError; +use crate::cass_types::CassDataType; use crate::types::*; use crate::value::CassCqlValue; use std::convert::TryFrom; +use std::sync::Arc; include!(concat!(env!("OUT_DIR"), "/cppdriver_data_collection.rs")); +// These constants help us to save an allocation in case user calls `cass_collection_new` (untyped collection). +static UNTYPED_LIST_TYPE: CassDataType = CassDataType::List { + typ: None, + frozen: false, +}; +static UNTYPED_SET_TYPE: CassDataType = CassDataType::Set { + typ: None, + frozen: false, +}; +static UNTYPED_MAP_TYPE: CassDataType = CassDataType::Map { + key_type: None, + val_type: None, + frozen: false, +}; + #[derive(Clone)] pub struct CassCollection { pub collection_type: CassCollectionType, + pub data_type: Option>, pub capacity: usize, pub items: Vec, } @@ -57,18 +75,64 @@ pub unsafe extern "C" fn cass_collection_new( ) -> *mut CassCollection { let capacity = match collection_type { // Maps consist of a key and a value, so twice - // the number of CqlValue will be stored. + // the number of CassCqlValue will be stored. CassCollectionType::CASS_COLLECTION_TYPE_MAP => item_count * 2, _ => item_count, } as usize; Box::into_raw(Box::new(CassCollection { collection_type, + data_type: None, + capacity, + items: Vec::with_capacity(capacity), + })) +} + +#[no_mangle] +unsafe extern "C" fn cass_collection_new_from_data_type( + data_type: *const CassDataType, + item_count: size_t, +) -> *mut CassCollection { + let data_type = clone_arced(data_type); + let (capacity, collection_type) = match data_type.as_ref() { + CassDataType::List { .. } => (item_count, CassCollectionType::CASS_COLLECTION_TYPE_LIST), + CassDataType::Set { .. } => (item_count, CassCollectionType::CASS_COLLECTION_TYPE_SET), + // Maps consist of a key and a value, so twice + // the number of CassCqlValue will be stored. + CassDataType::Map { .. } => (item_count * 2, CassCollectionType::CASS_COLLECTION_TYPE_MAP), + _ => return std::ptr::null_mut(), + }; + let capacity = capacity as usize; + + Box::into_raw(Box::new(CassCollection { + collection_type, + data_type: Some(data_type), capacity, items: Vec::with_capacity(capacity), })) } +#[no_mangle] +unsafe extern "C" fn cass_collection_data_type( + collection: *const CassCollection, +) -> *const CassDataType { + let collection_ref = ptr_to_ref(collection); + + match &collection_ref.data_type { + Some(dt) => Arc::as_ptr(dt), + None => match collection_ref.collection_type { + CassCollectionType::CASS_COLLECTION_TYPE_LIST => &UNTYPED_LIST_TYPE, + CassCollectionType::CASS_COLLECTION_TYPE_SET => &UNTYPED_SET_TYPE, + CassCollectionType::CASS_COLLECTION_TYPE_MAP => &UNTYPED_MAP_TYPE, + // CassCollectionType is a C enum. Panic, if it's out of range. + _ => panic!( + "CassCollectionType enum value out of range: {}", + collection_ref.collection_type.0 + ), + }, + } +} + #[no_mangle] pub unsafe extern "C" fn cass_collection_free(collection: *mut CassCollection) { free_boxed(collection); From 1591511d583cffa2b13446de62497e5ce0a09522 Mon Sep 17 00:00:00 2001 From: muzarski Date: Wed, 17 Jul 2024 18:34:32 +0200 Subject: [PATCH 26/33] value: is_type_compatible for collection values --- scylla-rust-wrapper/src/collection.rs | 20 +- scylla-rust-wrapper/src/value.rs | 309 +++++++++++++++++++++++++- 2 files changed, 312 insertions(+), 17 deletions(-) diff --git a/scylla-rust-wrapper/src/collection.rs b/scylla-rust-wrapper/src/collection.rs index ef86e43b..a818de48 100644 --- a/scylla-rust-wrapper/src/collection.rs +++ b/scylla-rust-wrapper/src/collection.rs @@ -44,10 +44,12 @@ impl TryFrom<&CassCollection> for CassCqlValue { type Error = (); fn try_from(collection: &CassCollection) -> Result { // FIXME: validate that collection items are correct + let data_type = collection.data_type.clone(); match collection.collection_type { - CassCollectionType::CASS_COLLECTION_TYPE_LIST => { - Ok(CassCqlValue::List(collection.items.clone())) - } + CassCollectionType::CASS_COLLECTION_TYPE_LIST => Ok(CassCqlValue::List { + data_type, + values: collection.items.clone(), + }), CassCollectionType::CASS_COLLECTION_TYPE_MAP => { let mut grouped_items = Vec::new(); // FIXME: validate even number of items @@ -58,11 +60,15 @@ impl TryFrom<&CassCollection> for CassCqlValue { grouped_items.push((key, value)); } - Ok(CassCqlValue::Map(grouped_items)) - } - CassCollectionType::CASS_COLLECTION_TYPE_SET => { - Ok(CassCqlValue::Set(collection.items.clone())) + Ok(CassCqlValue::Map { + data_type, + values: grouped_items, + }) } + CassCollectionType::CASS_COLLECTION_TYPE_SET => Ok(CassCqlValue::Set { + data_type, + values: collection.items.clone(), + }), _ => Err(()), } } diff --git a/scylla-rust-wrapper/src/value.rs b/scylla-rust-wrapper/src/value.rs index d4dd1b26..f8d29b1a 100644 --- a/scylla-rust-wrapper/src/value.rs +++ b/scylla-rust-wrapper/src/value.rs @@ -51,9 +51,18 @@ pub enum CassCqlValue { data_type: Option>, fields: Vec>, }, - List(Vec), - Map(Vec<(CassCqlValue, CassCqlValue)>), - Set(Vec), + List { + data_type: Option>, + values: Vec, + }, + Map { + data_type: Option>, + values: Vec<(CassCqlValue, CassCqlValue)>, + }, + Set { + data_type: Option>, + values: Vec, + }, UserDefinedType { data_type: Arc, /// Order of `fields` vector must match the order of fields as defined in the UDT. The @@ -130,9 +139,30 @@ impl CassCqlValue { // Untyped tuple. typ.get_value_type() == CassValueType::CASS_VALUE_TYPE_TUPLE } - CassCqlValue::List(_) => todo!(), - CassCqlValue::Map(_) => todo!(), - CassCqlValue::Set(_) => todo!(), + CassCqlValue::List { data_type, .. } => { + if let Some(dt) = data_type { + dt.typecheck_equals(typ) + } else { + // Untyped list. + typ.get_value_type() == CassValueType::CASS_VALUE_TYPE_LIST + } + } + CassCqlValue::Map { data_type, .. } => { + if let Some(dt) = data_type { + dt.typecheck_equals(typ) + } else { + // Untyped map. + typ.get_value_type() == CassValueType::CASS_VALUE_TYPE_MAP + } + } + CassCqlValue::Set { data_type, .. } => { + if let Some(dt) = data_type { + dt.typecheck_equals(typ) + } else { + // Untyped set. + typ.get_value_type() == CassValueType::CASS_VALUE_TYPE_SET + } + } CassCqlValue::UserDefinedType { data_type, .. } => data_type.typecheck_equals(typ), } } @@ -207,11 +237,15 @@ impl CassCqlValue { ::serialize(v, &ColumnType::Decimal, writer) } CassCqlValue::Tuple { fields, .. } => serialize_tuple_like(fields.iter(), writer), - CassCqlValue::List(l) => serialize_sequence(l.len(), l.iter(), writer), - CassCqlValue::Map(m) => { - serialize_mapping(m.len(), m.iter().map(|p| (&p.0, &p.1)), writer) + CassCqlValue::List { values, .. } => { + serialize_sequence(values.len(), values.iter(), writer) + } + CassCqlValue::Map { values, .. } => { + serialize_mapping(values.len(), values.iter().map(|p| (&p.0, &p.1)), writer) + } + CassCqlValue::Set { values, .. } => { + serialize_sequence(values.len(), values.iter(), writer) } - CassCqlValue::Set(s) => serialize_sequence(s.len(), s.iter(), writer), CassCqlValue::UserDefinedType { fields, .. } => serialize_udt(fields, writer), } } @@ -585,6 +619,22 @@ mod tests { frozen: false, })); + let data_type_int_list = Arc::new(CassDataType::List { + typ: Some(data_type_int.clone()), + frozen: false, + }); + + let data_type_int_set = Arc::new(CassDataType::Set { + typ: Some(data_type_int.clone()), + frozen: false, + }); + + let data_type_bool_float_map = Arc::new(CassDataType::Map { + key_type: Some(data_type_bool.clone()), + val_type: Some(data_type_float.clone()), + frozen: false, + }); + // TUPLES { let data_type_untyped_tuple = Arc::new(CassDataType::Tuple(vec![])); @@ -618,6 +668,9 @@ mod tests { data_type_int.clone(), data_type_bool.clone(), data_type_udt_simple.clone(), + data_type_int_list.clone(), + data_type_int_set.clone(), + data_type_bool_float_map.clone(), ], }, // Untyped tuple -> used created an untyped tuple data type, and then @@ -638,6 +691,9 @@ mod tests { data_type_int.clone(), data_type_bool.clone(), data_type_udt_simple.clone(), + data_type_int_list.clone(), + data_type_int_set.clone(), + data_type_bool_float_map.clone(), ], }, // Fully typed tuple. @@ -658,6 +714,9 @@ mod tests { data_type_nested_tuple.clone(), data_type_nested_tuple.clone(), data_type_udt_simple.clone(), + data_type_int_list.clone(), + data_type_int_set.clone(), + data_type_bool_float_map.clone(), ], }, // Nested tuple. @@ -678,6 +737,9 @@ mod tests { data_type_tuple.clone(), data_type_small_tuple.clone(), data_type_udt_simple.clone(), + data_type_int_list.clone(), + data_type_int_set.clone(), + data_type_bool_float_map.clone(), ], }, ]; @@ -728,10 +790,237 @@ mod tests { data_type_int.clone(), data_type_bool.clone(), data_type_tuple.clone(), + data_type_int_list.clone(), + data_type_int_set.clone(), + data_type_bool_float_map.clone(), ], }]; run_test_cases(test_cases); } + + // COLLECTIONS + { + let data_type_untyped_list = Arc::new(CassDataType::List { + typ: None, + frozen: false, + }); + let data_type_float_list = Arc::new(CassDataType::List { + typ: Some(data_type_float.clone()), + frozen: false, + }); + + let data_type_untyped_set = Arc::new(CassDataType::Set { + typ: None, + frozen: false, + }); + let data_type_float_set = Arc::new(CassDataType::Set { + typ: Some(data_type_float.clone()), + frozen: false, + }); + + let data_type_untyped_map = Arc::new(CassDataType::Map { + key_type: None, + val_type: None, + frozen: false, + }); + let data_type_typed_key_float_map = Arc::new(CassDataType::Map { + key_type: Some(data_type_float.clone()), + val_type: None, + frozen: false, + }); + let data_type_float_int_map = Arc::new(CassDataType::Map { + key_type: Some(data_type_float.clone()), + val_type: Some(data_type_int.clone()), + frozen: false, + }); + + let test_cases = &[ + // Untyped list -> user created it via `cass_collection_new`. + TestCase { + value: CassCqlValue::List { + data_type: None, + values: vec![], + }, + compatible_types: vec![ + data_type_float_list.clone(), + data_type_int_list.clone(), + data_type_untyped_list.clone(), + ], + incompatible_types: vec![ + data_type_float.clone(), + data_type_int.clone(), + data_type_bool.clone(), + data_type_tuple.clone(), + data_type_udt_simple.clone(), + data_type_untyped_set.clone(), + data_type_float_set.clone(), + data_type_int_set.clone(), + data_type_untyped_map.clone(), + data_type_typed_key_float_map.clone(), + data_type_float_int_map.clone(), + data_type_bool_float_map.clone(), + ], + }, + // Typed list. + TestCase { + value: CassCqlValue::List { + data_type: Some(data_type_float_list.clone()), + values: vec![], + }, + compatible_types: vec![ + data_type_float_list.clone(), + data_type_untyped_list.clone(), + ], + incompatible_types: vec![ + data_type_float.clone(), + data_type_int.clone(), + data_type_bool.clone(), + data_type_tuple.clone(), + data_type_udt_simple.clone(), + data_type_int_list.clone(), + data_type_untyped_set.clone(), + data_type_float_set.clone(), + data_type_int_set.clone(), + data_type_untyped_map.clone(), + data_type_typed_key_float_map.clone(), + data_type_float_int_map.clone(), + data_type_bool_float_map.clone(), + ], + }, + // Untyped set (via cass_collection_new). + TestCase { + value: CassCqlValue::Set { + data_type: None, + values: vec![], + }, + compatible_types: vec![ + data_type_untyped_set.clone(), + data_type_float_set.clone(), + data_type_int_set.clone(), + ], + incompatible_types: vec![ + data_type_float.clone(), + data_type_int.clone(), + data_type_bool.clone(), + data_type_tuple.clone(), + data_type_udt_simple.clone(), + data_type_int_list.clone(), + data_type_float_list.clone(), + data_type_untyped_list.clone(), + data_type_untyped_map.clone(), + data_type_typed_key_float_map.clone(), + data_type_float_int_map.clone(), + data_type_bool_float_map.clone(), + ], + }, + // Typed set. + TestCase { + value: CassCqlValue::Set { + data_type: Some(data_type_float_set.clone()), + values: vec![], + }, + compatible_types: vec![ + data_type_untyped_set.clone(), + data_type_float_set.clone(), + ], + incompatible_types: vec![ + data_type_float.clone(), + data_type_int.clone(), + data_type_bool.clone(), + data_type_tuple.clone(), + data_type_udt_simple.clone(), + data_type_int_list.clone(), + data_type_float_list.clone(), + data_type_untyped_list.clone(), + data_type_int_set.clone(), + data_type_untyped_map.clone(), + data_type_typed_key_float_map.clone(), + data_type_float_int_map.clone(), + data_type_bool_float_map.clone(), + ], + }, + // Untyped map (via cass_collection_new). + TestCase { + value: CassCqlValue::Map { + data_type: None, + values: vec![], + }, + compatible_types: vec![ + data_type_untyped_map.clone(), + data_type_typed_key_float_map.clone(), + data_type_float_int_map.clone(), + data_type_bool_float_map.clone(), + ], + incompatible_types: vec![ + data_type_float.clone(), + data_type_int.clone(), + data_type_bool.clone(), + data_type_tuple.clone(), + data_type_udt_simple.clone(), + data_type_float_list.clone(), + data_type_int_list.clone(), + data_type_untyped_list.clone(), + data_type_untyped_set.clone(), + data_type_float_set.clone(), + data_type_int_set.clone(), + ], + }, + // Only key-typed map. + TestCase { + value: CassCqlValue::Map { + data_type: Some(data_type_typed_key_float_map.clone()), + values: vec![], + }, + compatible_types: vec![ + data_type_typed_key_float_map.clone(), + data_type_untyped_map.clone(), + ], + incompatible_types: vec![ + data_type_float.clone(), + data_type_int.clone(), + data_type_bool.clone(), + data_type_tuple.clone(), + data_type_udt_simple.clone(), + data_type_float_list.clone(), + data_type_int_list.clone(), + data_type_untyped_list.clone(), + data_type_untyped_set.clone(), + data_type_float_set.clone(), + data_type_int_set.clone(), + data_type_float_int_map.clone(), + data_type_bool_float_map.clone(), + ], + }, + // Fully typed map + TestCase { + value: CassCqlValue::Map { + data_type: Some(data_type_float_int_map.clone()), + values: vec![], + }, + compatible_types: vec![ + data_type_float_int_map.clone(), + data_type_untyped_map.clone(), + ], + incompatible_types: vec![ + data_type_float.clone(), + data_type_int.clone(), + data_type_bool.clone(), + data_type_tuple.clone(), + data_type_udt_simple.clone(), + data_type_float_list.clone(), + data_type_int_list.clone(), + data_type_untyped_list.clone(), + data_type_untyped_set.clone(), + data_type_float_set.clone(), + data_type_int_set.clone(), + data_type_typed_key_float_map.clone(), + data_type_bool_float_map.clone(), + ], + }, + ]; + + run_test_cases(test_cases) + } } } From 22481bebbc7f401b80053390736652bdfa16c73e Mon Sep 17 00:00:00 2001 From: muzarski Date: Wed, 17 Jul 2024 18:57:49 +0200 Subject: [PATCH 27/33] collection: typecheck value on append --- scylla-rust-wrapper/src/collection.rs | 329 +++++++++++++++++++++++++- 1 file changed, 327 insertions(+), 2 deletions(-) diff --git a/scylla-rust-wrapper/src/collection.rs b/scylla-rust-wrapper/src/collection.rs index a818de48..c22312c2 100644 --- a/scylla-rust-wrapper/src/collection.rs +++ b/scylla-rust-wrapper/src/collection.rs @@ -1,8 +1,8 @@ -use crate::argconv::*; use crate::cass_error::CassError; use crate::cass_types::CassDataType; use crate::types::*; use crate::value::CassCqlValue; +use crate::{argconv::*, value}; use std::convert::TryFrom; use std::sync::Arc; @@ -32,8 +32,52 @@ pub struct CassCollection { } impl CassCollection { + fn typecheck_on_append(&self, value: &Option) -> CassError { + // See https://github.com/scylladb/cpp-driver/blob/master/src/collection.hpp#L100. + let index = self.items.len(); + + // Do validation only if it's a typed collection. + if let Some(data_type) = &self.data_type { + match data_type.as_ref() { + CassDataType::List { typ: subtype, .. } + | CassDataType::Set { typ: subtype, .. } => { + if let Some(subtype) = subtype { + if !value::is_type_compatible(value, subtype) { + return CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE; + } + } + } + CassDataType::Map { + key_type: k_typ, + val_type: v_typ, + .. + } => { + // Cpp-driver does the typecheck only if both map types are present... + // However, we decided not to mimic this behaviour (which is probably a bug). + // We will do the typecheck if just the key type is defined as well (half-typed maps). + if let Some(k_typ) = k_typ { + if index % 2 == 0 && !value::is_type_compatible(value, k_typ) { + return CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE; + } + } + if let Some(v_typ) = v_typ { + if index % 2 != 0 && !value::is_type_compatible(value, v_typ) { + return CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE; + } + } + } + _ => return CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE, + } + } + + CassError::CASS_OK + } + pub fn append_cql_value(&mut self, value: Option) -> CassError { - // FIXME: Bounds check, type check + let err = self.typecheck_on_append(&value); + if err != CassError::CASS_OK { + return err; + } // There is no API to append null, so unwrap is safe self.items.push(value.unwrap()); CassError::CASS_OK @@ -163,3 +207,284 @@ make_binders!(decimal, cass_collection_append_decimal); make_binders!(collection, cass_collection_append_collection); make_binders!(tuple, cass_collection_append_tuple); make_binders!(user_type, cass_collection_append_user_type); + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use crate::{ + cass_error::CassError, + cass_types::{CassDataType, CassValueType}, + collection::{ + cass_collection_append_double, cass_collection_append_float, cass_collection_free, + }, + testing::assert_cass_error_eq, + }; + + use super::{ + cass_bool_t, cass_collection_append_bool, cass_collection_append_int16, + cass_collection_new, cass_collection_new_from_data_type, CassCollectionType, + }; + + #[test] + fn test_typecheck_on_append_to_collection() { + unsafe { + // untyped map (via cass_collection_new, Collection's data type is None). + { + let untyped_map = + cass_collection_new(CassCollectionType::CASS_COLLECTION_TYPE_MAP, 2); + assert_cass_error_eq!( + cass_collection_append_bool(untyped_map, false as cass_bool_t), + CassError::CASS_OK + ); + assert_cass_error_eq!( + cass_collection_append_int16(untyped_map, 42), + CassError::CASS_OK + ); + assert_cass_error_eq!( + cass_collection_append_double(untyped_map, 42.42), + CassError::CASS_OK + ); + assert_cass_error_eq!( + cass_collection_append_float(untyped_map, 42.42), + CassError::CASS_OK + ); + cass_collection_free(untyped_map); + } + + // untyped map (via cass_collection_new_from_data_type - collection's type is Some(untyped_map)). + { + let dt = Arc::new(CassDataType::Map { + key_type: None, + val_type: None, + frozen: false, + }); + + let dt_ptr = Arc::into_raw(dt); + let untyped_map = cass_collection_new_from_data_type(dt_ptr, 2); + + assert_cass_error_eq!( + cass_collection_append_bool(untyped_map, false as cass_bool_t), + CassError::CASS_OK + ); + assert_cass_error_eq!( + cass_collection_append_int16(untyped_map, 42), + CassError::CASS_OK + ); + assert_cass_error_eq!( + cass_collection_append_double(untyped_map, 42.42), + CassError::CASS_OK + ); + assert_cass_error_eq!( + cass_collection_append_float(untyped_map, 42.42), + CassError::CASS_OK + ); + cass_collection_free(untyped_map); + } + + // half-typed map (key-only) + { + let dt = Arc::new(CassDataType::Map { + key_type: Some(Arc::new(CassDataType::Value( + CassValueType::CASS_VALUE_TYPE_BOOLEAN, + ))), + val_type: None, + frozen: false, + }); + + let dt_ptr = Arc::into_raw(dt); + let half_typed_map = cass_collection_new_from_data_type(dt_ptr, 2); + + assert_cass_error_eq!( + cass_collection_append_bool(half_typed_map, false as cass_bool_t), + CassError::CASS_OK + ); + assert_cass_error_eq!( + cass_collection_append_int16(half_typed_map, 42), + CassError::CASS_OK + ); + + // Second entry -> key typecheck failed. + assert_cass_error_eq!( + cass_collection_append_double(half_typed_map, 42.42), + CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE + ); + + // Second entry -> typecheck succesful. + assert_cass_error_eq!( + cass_collection_append_bool(half_typed_map, true as cass_bool_t), + CassError::CASS_OK + ); + assert_cass_error_eq!( + cass_collection_append_double(half_typed_map, 42.42), + CassError::CASS_OK + ); + cass_collection_free(half_typed_map); + } + + // typed map + { + let dt = Arc::new(CassDataType::Map { + key_type: Some(Arc::new(CassDataType::Value( + CassValueType::CASS_VALUE_TYPE_BOOLEAN, + ))), + val_type: Some(Arc::new(CassDataType::Value( + CassValueType::CASS_VALUE_TYPE_SMALL_INT, + ))), + frozen: false, + }); + let dt_ptr = Arc::into_raw(dt); + let bool_to_i16_map = cass_collection_new_from_data_type(dt_ptr, 2); + + // First entry -> typecheck successful. + assert_cass_error_eq!( + cass_collection_append_bool(bool_to_i16_map, false as cass_bool_t), + CassError::CASS_OK + ); + assert_cass_error_eq!( + cass_collection_append_int16(bool_to_i16_map, 42), + CassError::CASS_OK + ); + + // Second entry -> key typecheck failed. + assert_cass_error_eq!( + cass_collection_append_float(bool_to_i16_map, 42.42), + CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE + ); + + // Third entry -> value typecheck failed. + assert_cass_error_eq!( + cass_collection_append_bool(bool_to_i16_map, true as cass_bool_t), + CassError::CASS_OK + ); + assert_cass_error_eq!( + cass_collection_append_float(bool_to_i16_map, 42.42), + CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE + ); + + Arc::from_raw(dt_ptr); + cass_collection_free(bool_to_i16_map); + } + + // untyped set (via cass_collection_new, collection's type is None) + { + let untyped_set = + cass_collection_new(CassCollectionType::CASS_COLLECTION_TYPE_SET, 2); + assert_cass_error_eq!( + cass_collection_append_bool(untyped_set, false as cass_bool_t), + CassError::CASS_OK + ); + assert_cass_error_eq!( + cass_collection_append_int16(untyped_set, 42), + CassError::CASS_OK + ); + cass_collection_free(untyped_set); + } + + // untyped set (via cass_collection_new_from_data_type, collection's type is Some(untyped_set)) + { + let dt = Arc::new(CassDataType::Set { + typ: None, + frozen: false, + }); + + let dt_ptr = Arc::into_raw(dt); + let untyped_set = cass_collection_new_from_data_type(dt_ptr, 2); + + assert_cass_error_eq!( + cass_collection_append_bool(untyped_set, false as cass_bool_t), + CassError::CASS_OK + ); + assert_cass_error_eq!( + cass_collection_append_int16(untyped_set, 42), + CassError::CASS_OK + ); + cass_collection_free(untyped_set); + } + + // typed set + { + let dt = Arc::new(CassDataType::Set { + typ: Some(Arc::new(CassDataType::Value( + CassValueType::CASS_VALUE_TYPE_BOOLEAN, + ))), + frozen: false, + }); + let dt_ptr = Arc::into_raw(dt); + let bool_set = cass_collection_new_from_data_type(dt_ptr, 2); + + assert_cass_error_eq!( + cass_collection_append_bool(bool_set, true as cass_bool_t), + CassError::CASS_OK + ); + assert_cass_error_eq!( + cass_collection_append_float(bool_set, 42.42), + CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE + ); + + Arc::from_raw(dt_ptr); + cass_collection_free(bool_set); + } + + // untyped list (via cass_collection_new, collection's type is None) + { + let untyped_list = + cass_collection_new(CassCollectionType::CASS_COLLECTION_TYPE_LIST, 2); + assert_cass_error_eq!( + cass_collection_append_bool(untyped_list, false as cass_bool_t), + CassError::CASS_OK + ); + assert_cass_error_eq!( + cass_collection_append_int16(untyped_list, 42), + CassError::CASS_OK + ); + cass_collection_free(untyped_list); + } + + // untyped list (via cass_collection_new_from_data_type, collection's type is Some(untyped_list)) + { + let dt = Arc::new(CassDataType::Set { + typ: None, + frozen: false, + }); + + let dt_ptr = Arc::into_raw(dt); + let untyped_list = cass_collection_new_from_data_type(dt_ptr, 2); + + assert_cass_error_eq!( + cass_collection_append_bool(untyped_list, false as cass_bool_t), + CassError::CASS_OK + ); + assert_cass_error_eq!( + cass_collection_append_int16(untyped_list, 42), + CassError::CASS_OK + ); + cass_collection_free(untyped_list); + } + + // typed list + { + let dt = Arc::new(CassDataType::Set { + typ: Some(Arc::new(CassDataType::Value( + CassValueType::CASS_VALUE_TYPE_BOOLEAN, + ))), + frozen: false, + }); + let dt_ptr = Arc::into_raw(dt); + let bool_list = cass_collection_new_from_data_type(dt_ptr, 2); + + assert_cass_error_eq!( + cass_collection_append_bool(bool_list, true as cass_bool_t), + CassError::CASS_OK + ); + assert_cass_error_eq!( + cass_collection_append_float(bool_list, 42.42), + CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE + ); + + Arc::from_raw(dt_ptr); + cass_collection_free(bool_list); + } + } + } +} From 9e2386b9f121dc0a563c12f620869a3620c41a02 Mon Sep 17 00:00:00 2001 From: muzarski Date: Wed, 17 Jul 2024 19:31:12 +0200 Subject: [PATCH 28/33] cass_prepared: hold info about col data types Until now, we would not hold the information about column data types from PreparedMetadata. We need to hold this information to perform a typecheck during binding values to statement. We could construct the data types from column specs each time we bind a value. However, CassDataType might be a heavy nested object, and so I decided to cache it in CassPrepared. --- scylla-rust-wrapper/src/batch.rs | 2 +- scylla-rust-wrapper/src/future.rs | 3 +-- scylla-rust-wrapper/src/prepared.rs | 25 +++++++++++++++++++-- scylla-rust-wrapper/src/session.rs | 19 ++++++++++------ scylla-rust-wrapper/src/statement.rs | 33 +++++++++++++++++----------- 5 files changed, 57 insertions(+), 25 deletions(-) diff --git a/scylla-rust-wrapper/src/batch.rs b/scylla-rust-wrapper/src/batch.rs index d4890402..3cdf36ee 100644 --- a/scylla-rust-wrapper/src/batch.rs +++ b/scylla-rust-wrapper/src/batch.rs @@ -165,7 +165,7 @@ pub unsafe extern "C" fn cass_batch_add_statement( match &statement.statement { Statement::Simple(q) => state.batch.append_statement(q.query.clone()), - Statement::Prepared(p) => state.batch.append_statement((**p).clone()), + Statement::Prepared(p) => state.batch.append_statement(p.statement.clone()), }; state.bound_values.push(statement.bound_values.clone()); diff --git a/scylla-rust-wrapper/src/future.rs b/scylla-rust-wrapper/src/future.rs index b11e99b0..c579dd5f 100644 --- a/scylla-rust-wrapper/src/future.rs +++ b/scylla-rust-wrapper/src/future.rs @@ -8,7 +8,6 @@ use crate::types::*; use crate::uuid::CassUuid; use crate::RUNTIME; use futures::future; -use scylla::prepared_statement::PreparedStatement; use std::future::Future; use std::mem; use std::os::raw::c_void; @@ -20,7 +19,7 @@ pub enum CassResultValue { Empty, QueryResult(Arc), QueryError(Arc), - Prepared(Arc), + Prepared(Arc), } type CassFutureError = (CassError, String); diff --git a/scylla-rust-wrapper/src/prepared.rs b/scylla-rust-wrapper/src/prepared.rs index 5094d652..33fbcba6 100644 --- a/scylla-rust-wrapper/src/prepared.rs +++ b/scylla-rust-wrapper/src/prepared.rs @@ -3,11 +3,32 @@ use std::sync::Arc; use crate::{ argconv::*, + cass_types::{get_column_type, CassDataType}, statement::{CassStatement, Statement}, }; use scylla::prepared_statement::PreparedStatement; -pub type CassPrepared = PreparedStatement; +#[derive(Debug, Clone)] +pub struct CassPrepared { + // Data types of columns from PreparedMetadata. + pub variable_col_data_types: Vec>, + pub statement: PreparedStatement, +} + +impl CassPrepared { + pub fn new_from_prepared_statement(statement: PreparedStatement) -> Self { + let variable_col_data_types = statement + .get_variable_col_specs() + .iter() + .map(|col_spec| Arc::new(get_column_type(&col_spec.typ))) + .collect(); + + Self { + variable_col_data_types, + statement, + } + } +} #[no_mangle] pub unsafe extern "C" fn cass_prepared_free(prepared_raw: *const CassPrepared) { @@ -19,7 +40,7 @@ pub unsafe extern "C" fn cass_prepared_bind( prepared_raw: *const CassPrepared, ) -> *mut CassStatement { let prepared: Arc<_> = clone_arced(prepared_raw); - let bound_values_size = prepared.get_variable_col_specs().len(); + let bound_values_size = prepared.statement.get_variable_col_specs().len(); // cloning prepared statement's arc, because creating CassStatement should not invalidate // the CassPrepared argument diff --git a/scylla-rust-wrapper/src/session.rs b/scylla-rust-wrapper/src/session.rs index 9e54737a..397814b3 100644 --- a/scylla-rust-wrapper/src/session.rs +++ b/scylla-rust-wrapper/src/session.rs @@ -8,6 +8,7 @@ use crate::exec_profile::{CassExecProfile, ExecProfileName, PerStatementExecProf use crate::future::{CassFuture, CassFutureResult, CassResultValue}; use crate::metadata::create_table_metadata; use crate::metadata::{CassKeyspaceMeta, CassMaterializedViewMeta, CassSchemaMeta}; +use crate::prepared::CassPrepared; use crate::query_result::Value::{CollectionValue, RegularValue}; use crate::query_result::{CassResult, CassResultData, CassRow, CassValue, Collection, Value}; use crate::statement::CassStatement; @@ -279,9 +280,9 @@ pub unsafe extern "C" fn cass_session_execute( match &mut statement { Statement::Simple(query) => query.query.set_execution_profile_handle(handle), - Statement::Prepared(prepared) => { - Arc::make_mut(prepared).set_execution_profile_handle(handle) - } + Statement::Prepared(prepared) => Arc::make_mut(prepared) + .statement + .set_execution_profile_handle(handle), } let query_res: Result<(QueryResult, PagingStateResponse), QueryError> = match statement { @@ -300,11 +301,11 @@ pub unsafe extern "C" fn cass_session_execute( Statement::Prepared(prepared) => { if paging_enabled { session - .execute_single_page(&prepared, bound_values, paging_state) + .execute_single_page(&prepared.statement, bound_values, paging_state) .await } else { session - .execute_unpaged(&prepared, bound_values) + .execute_unpaged(&prepared.statement, bound_values) .await .map(|result| (result, PagingStateResponse::NoMorePages)) } @@ -499,7 +500,9 @@ pub unsafe extern "C" fn cass_session_prepare_from_existing( .await .map_err(|err| (CassError::from(&err), err.msg()))?; - Ok(CassResultValue::Prepared(Arc::new(prepared))) + Ok(CassResultValue::Prepared(Arc::new( + CassPrepared::new_from_prepared_statement(prepared), + ))) }) } @@ -542,7 +545,9 @@ pub unsafe extern "C" fn cass_session_prepare_n( // Set Cpp Driver default configuration for queries: prepared.set_consistency(Consistency::One); - Ok(CassResultValue::Prepared(Arc::new(prepared))) + Ok(CassResultValue::Prepared(Arc::new( + CassPrepared::new_from_prepared_statement(prepared), + ))) }) } diff --git a/scylla-rust-wrapper/src/statement.rs b/scylla-rust-wrapper/src/statement.rs index b1129ad9..eb9331b8 100644 --- a/scylla-rust-wrapper/src/statement.rs +++ b/scylla-rust-wrapper/src/statement.rs @@ -1,6 +1,7 @@ use crate::argconv::*; use crate::cass_error::CassError; use crate::exec_profile::PerStatementExecProfile; +use crate::prepared::CassPrepared; use crate::query_result::CassResult; use crate::retry_policy::CassRetryPolicy; use crate::types::*; @@ -9,7 +10,6 @@ use scylla::frame::types::Consistency; use scylla::frame::value::MaybeUnset; use scylla::frame::value::MaybeUnset::{Set, Unset}; use scylla::query::Query; -use scylla::statement::prepared_statement::PreparedStatement; use scylla::statement::SerialConsistency; use scylla::transport::{PagingState, PagingStateResponse}; use std::collections::HashMap; @@ -24,7 +24,7 @@ include!(concat!(env!("OUT_DIR"), "/cppdriver_data_query_error.rs")); pub enum Statement { Simple(SimpleQuery), // Arc is needed, because PreparedStatement is passed by reference to session.execute - Prepared(Arc), + Prepared(Arc), } #[derive(Clone)] @@ -83,6 +83,7 @@ impl CassStatement { match &self.statement { Statement::Prepared(prepared) => { let indices: Vec = prepared + .statement .get_variable_col_specs() .iter() .enumerate() @@ -185,7 +186,9 @@ pub unsafe extern "C" fn cass_statement_set_consistency( if let Some(consistency) = consistency_opt { match &mut ptr_to_ref_mut(statement).statement { Statement::Simple(inner) => inner.query.set_consistency(consistency), - Statement::Prepared(inner) => Arc::make_mut(inner).set_consistency(consistency), + Statement::Prepared(inner) => { + Arc::make_mut(inner).statement.set_consistency(consistency) + } } } @@ -205,7 +208,7 @@ pub unsafe extern "C" fn cass_statement_set_paging_size( statement.paging_enabled = true; match &mut statement.statement { Statement::Simple(inner) => inner.query.set_page_size(page_size), - Statement::Prepared(inner) => Arc::make_mut(inner).set_page_size(page_size), + Statement::Prepared(inner) => Arc::make_mut(inner).statement.set_page_size(page_size), } } @@ -253,7 +256,9 @@ pub unsafe extern "C" fn cass_statement_set_is_idempotent( ) -> CassError { match &mut ptr_to_ref_mut(statement_raw).statement { Statement::Simple(inner) => inner.query.set_is_idempotent(is_idempotent != 0), - Statement::Prepared(inner) => Arc::make_mut(inner).set_is_idempotent(is_idempotent != 0), + Statement::Prepared(inner) => Arc::make_mut(inner) + .statement + .set_is_idempotent(is_idempotent != 0), } CassError::CASS_OK @@ -266,7 +271,7 @@ pub unsafe extern "C" fn cass_statement_set_tracing( ) -> CassError { match &mut ptr_to_ref_mut(statement_raw).statement { Statement::Simple(inner) => inner.query.set_tracing(enabled != 0), - Statement::Prepared(inner) => Arc::make_mut(inner).set_tracing(enabled != 0), + Statement::Prepared(inner) => Arc::make_mut(inner).statement.set_tracing(enabled != 0), } CassError::CASS_OK @@ -288,9 +293,9 @@ pub unsafe extern "C" fn cass_statement_set_retry_policy( match &mut ptr_to_ref_mut(statement).statement { Statement::Simple(inner) => inner.query.set_retry_policy(maybe_arced_retry_policy), - Statement::Prepared(inner) => { - Arc::make_mut(inner).set_retry_policy(maybe_arced_retry_policy) - } + Statement::Prepared(inner) => Arc::make_mut(inner) + .statement + .set_retry_policy(maybe_arced_retry_policy), } CassError::CASS_OK @@ -317,9 +322,9 @@ pub unsafe extern "C" fn cass_statement_set_serial_consistency( match &mut ptr_to_ref_mut(statement).statement { Statement::Simple(inner) => inner.query.set_serial_consistency(Some(consistency)), - Statement::Prepared(inner) => { - Arc::make_mut(inner).set_serial_consistency(Some(consistency)) - } + Statement::Prepared(inner) => Arc::make_mut(inner) + .statement + .set_serial_consistency(Some(consistency)), } CassError::CASS_OK @@ -349,7 +354,9 @@ pub unsafe extern "C" fn cass_statement_set_timestamp( ) -> CassError { match &mut ptr_to_ref_mut(statement).statement { Statement::Simple(inner) => inner.query.set_timestamp(Some(timestamp)), - Statement::Prepared(inner) => Arc::make_mut(inner).set_timestamp(Some(timestamp)), + Statement::Prepared(inner) => Arc::make_mut(inner) + .statement + .set_timestamp(Some(timestamp)), } CassError::CASS_OK From 25b2c72543d60032b0bffe3835d98568805e1df8 Mon Sep 17 00:00:00 2001 From: muzarski Date: Wed, 17 Jul 2024 19:42:09 +0200 Subject: [PATCH 29/33] statement: perform typecheck on bind --- scylla-rust-wrapper/src/statement.rs | 35 +++++++++++++++++++++++----- 1 file changed, 29 insertions(+), 6 deletions(-) diff --git a/scylla-rust-wrapper/src/statement.rs b/scylla-rust-wrapper/src/statement.rs index eb9331b8..dc955d79 100644 --- a/scylla-rust-wrapper/src/statement.rs +++ b/scylla-rust-wrapper/src/statement.rs @@ -1,4 +1,3 @@ -use crate::argconv::*; use crate::cass_error::CassError; use crate::exec_profile::PerStatementExecProfile; use crate::prepared::CassPrepared; @@ -6,6 +5,7 @@ use crate::query_result::CassResult; use crate::retry_policy::CassRetryPolicy; use crate::types::*; use crate::value::CassCqlValue; +use crate::{argconv::*, value}; use scylla::frame::types::Consistency; use scylla::frame::value::MaybeUnset; use scylla::frame::value::MaybeUnset::{Set, Unset}; @@ -45,12 +45,35 @@ pub struct CassStatement { impl CassStatement { fn bind_cql_value(&mut self, index: usize, value: Option) -> CassError { - if index >= self.bound_values.len() { - CassError::CASS_ERROR_LIB_INDEX_OUT_OF_BOUNDS - } else { - self.bound_values[index] = Set(value); - CassError::CASS_OK + let (bound_value, maybe_data_type) = match &self.statement { + Statement::Simple(_) => match self.bound_values.get_mut(index) { + Some(v) => (v, None), + None => return CassError::CASS_ERROR_LIB_INDEX_OUT_OF_BOUNDS, + }, + Statement::Prepared(p) => match ( + self.bound_values.get_mut(index), + p.variable_col_data_types.get(index), + ) { + (Some(v), Some(dt)) => (v, Some(dt)), + (None, None) => return CassError::CASS_ERROR_LIB_INDEX_OUT_OF_BOUNDS, + // This indicates a length mismatch between col specs table and self.bound_values. + // + // It can only occur when user provides bad `count` value in `cass_statement_reset_parameters`. + // Cpp-driver does not verify that both of these values are equal. + // I believe returning CASS_ERROR_LIB_INDEX_OUT_OF_BOUNDS is best we can do here. + _ => return CassError::CASS_ERROR_LIB_INDEX_OUT_OF_BOUNDS, + }, + }; + + // Perform the typecheck. + if let Some(dt) = maybe_data_type { + if !value::is_type_compatible(&value, dt) { + return CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE; + } } + + *bound_value = Set(value); + CassError::CASS_OK } fn bind_multiple_values_by_name( From 8307c182ffd11a72cd789917a28ca9d8921b7ac7 Mon Sep 17 00:00:00 2001 From: muzarski Date: Wed, 17 Jul 2024 19:54:10 +0200 Subject: [PATCH 30/33] tuple: change constant to UNTYPED_TUPLE_TYPE --- scylla-rust-wrapper/src/tuple.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scylla-rust-wrapper/src/tuple.rs b/scylla-rust-wrapper/src/tuple.rs index dd8563ae..df1f9889 100644 --- a/scylla-rust-wrapper/src/tuple.rs +++ b/scylla-rust-wrapper/src/tuple.rs @@ -6,7 +6,7 @@ use crate::value; use crate::value::CassCqlValue; use std::sync::Arc; -static EMPTY_TUPLE_TYPE: CassDataType = CassDataType::Tuple(Vec::new()); +static UNTYPED_TUPLE_TYPE: CassDataType = CassDataType::Tuple(Vec::new()); #[derive(Clone)] pub struct CassTuple { @@ -89,7 +89,7 @@ unsafe extern "C" fn cass_tuple_free(tuple: *mut CassTuple) { unsafe extern "C" fn cass_tuple_data_type(tuple: *const CassTuple) -> *const CassDataType { match &ptr_to_ref(tuple).data_type { Some(t) => Arc::as_ptr(t), - None => &EMPTY_TUPLE_TYPE, + None => &UNTYPED_TUPLE_TYPE, } } From 794ea6d2594445cd33d51905f7eb22db40bd3ede Mon Sep 17 00:00:00 2001 From: muzarski Date: Thu, 18 Jul 2024 15:52:01 +0200 Subject: [PATCH 31/33] PR template: unit tests checkbox --- .github/pull_request_template.md | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index beed77df..250d3c01 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -9,5 +9,6 @@ - [ ] I have split my patch into logically separate commits. - [ ] All commit messages clearly explain what they change and why. - [ ] PR description sums up the changes and reasons why they should be introduced. +- [ ] I have implemented Rust unit tests for the features/changes introduced. - [ ] I have enabled appropriate tests in `.github/workflows/build.yml` in `gtest_filter`. - [ ] I have enabled appropriate tests in `.github/workflows/cassandra.yml` in `gtest_filter`. \ No newline at end of file From 0eb32ff56c4e3e709e368b496f8ab2ca572ca04f Mon Sep 17 00:00:00 2001 From: muzarski Date: Mon, 29 Jul 2024 17:34:47 +0200 Subject: [PATCH 32/33] cleanup: remove a TODO comment --- scylla-rust-wrapper/src/value.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scylla-rust-wrapper/src/value.rs b/scylla-rust-wrapper/src/value.rs index f8d29b1a..28b25793 100644 --- a/scylla-rust-wrapper/src/value.rs +++ b/scylla-rust-wrapper/src/value.rs @@ -174,7 +174,7 @@ impl SerializeValue for CassCqlValue { _typ: &ColumnType, writer: CellWriter<'b>, ) -> Result, SerializationError> { - // _typ is not used, since we do the typechecks during binding (this is still a TODO, high priority). + // _typ is not used, since we do the typechecks during binding. // This is the same approach as cpp-driver. self.do_serialize(writer) } From 8dbe9d19049837b71d967ebb54462d7da749bc1d Mon Sep 17 00:00:00 2001 From: muzarski Date: Fri, 13 Sep 2024 05:15:44 +0200 Subject: [PATCH 33/33] map: represent valid map type states via enum It's not possible for a map type to have value type defined, when key type is not defined. A representation of map type before this commit, would technically allow us to create such state. Introduced a `MapDataType` enum which represents valid states of map data type. --- scylla-rust-wrapper/src/cass_types.rs | 122 ++++++++++++++---------- scylla-rust-wrapper/src/collection.rs | 53 +++++----- scylla-rust-wrapper/src/query_result.rs | 6 +- scylla-rust-wrapper/src/session.rs | 9 +- scylla-rust-wrapper/src/value.rs | 15 ++- 5 files changed, 112 insertions(+), 93 deletions(-) diff --git a/scylla-rust-wrapper/src/cass_types.rs b/scylla-rust-wrapper/src/cass_types.rs index 12da2384..f75540ae 100644 --- a/scylla-rust-wrapper/src/cass_types.rs +++ b/scylla-rust-wrapper/src/cass_types.rs @@ -131,6 +131,13 @@ impl Default for UDTDataType { } } +#[derive(Clone, Debug, PartialEq)] +pub enum MapDataType { + Untyped, + Key(Arc), + KeyAndValue(Arc, Arc), +} + #[derive(Clone, Debug, PartialEq)] pub enum CassDataType { Value(CassValueType), @@ -146,10 +153,7 @@ pub enum CassDataType { frozen: bool, }, Map { - // None, None stands for untyped map. - // Some, None stands for a map with an untyped value type. - key_type: Option>, - val_type: Option>, + typ: MapDataType, frozen: bool, }, // Empty vector stands for untyped tuple. @@ -183,29 +187,22 @@ impl CassDataType { } _ => false, }, - CassDataType::Map { - key_type: k, - val_type: v, - .. - } => match other { - CassDataType::Map { - key_type: k_other, - val_type: v_other, - .. - } => match ((k, v), (k_other, v_other)) { + CassDataType::Map { typ: t, .. } => match other { + CassDataType::Map { typ: t_other, .. } => match (t, t_other) { // See https://github.com/scylladb/cpp-driver/blob/master/src/data_type.hpp#L218 // In cpp-driver the types are held in a vector. // The logic is following: // If either of vectors is empty, skip the typecheck. - ((None, None), _) => true, - (_, (None, None)) => true, + (MapDataType::Untyped, _) => true, + (_, MapDataType::Untyped) => true, // Otherwise, the vectors should have equal length and we perform the typecheck for subtypes. - ((Some(k), None), (Some(k_other), None)) => k.typecheck_equals(k_other), - ((Some(k), Some(v)), (Some(k_other), Some(v_other))) => { - k.typecheck_equals(k_other) && v.typecheck_equals(v_other) - } + (MapDataType::Key(k), MapDataType::Key(k_other)) => k.typecheck_equals(k_other), + ( + MapDataType::KeyAndValue(k, v), + MapDataType::KeyAndValue(k_other, v_other), + ) => k.typecheck_equals(k_other) && v.typecheck_equals(v_other), _ => false, }, _ => false, @@ -278,16 +275,18 @@ pub fn get_column_type_from_cql_type( frozen: *frozen, }, CollectionType::Map(key, value) => CassDataType::Map { - key_type: Some(Arc::new(get_column_type_from_cql_type( - key, - user_defined_types, - keyspace_name, - ))), - val_type: Some(Arc::new(get_column_type_from_cql_type( - value, - user_defined_types, - keyspace_name, - ))), + typ: MapDataType::KeyAndValue( + Arc::new(get_column_type_from_cql_type( + key, + user_defined_types, + keyspace_name, + )), + Arc::new(get_column_type_from_cql_type( + value, + user_defined_types, + keyspace_name, + )), + ), frozen: *frozen, }, CollectionType::Set(set) => CassDataType::Set { @@ -340,10 +339,19 @@ impl CassDataType { } } CassDataType::Map { - key_type, val_type, .. + typ: MapDataType::Untyped, + .. + } => None, + CassDataType::Map { + typ: MapDataType::Key(k), + .. + } => (index == 0).then_some(k), + CassDataType::Map { + typ: MapDataType::KeyAndValue(k, v), + .. } => match index { - 0 => key_type.as_ref(), - 1 => val_type.as_ref(), + 0 => Some(k), + 1 => Some(v), _ => None, }, CassDataType::Tuple(v) => v.get(index), @@ -361,17 +369,28 @@ impl CassDataType { } }, CassDataType::Map { - key_type, val_type, .. + typ: MapDataType::KeyAndValue(_, _), + .. + } => Err(CassError::CASS_ERROR_LIB_BAD_PARAMS), + CassDataType::Map { + typ: MapDataType::Key(k), + frozen, } => { - if key_type.is_some() && val_type.is_some() { - Err(CassError::CASS_ERROR_LIB_BAD_PARAMS) - } else if key_type.is_none() { - *key_type = Some(sub_type); - Ok(()) - } else { - *val_type = Some(sub_type); - Ok(()) - } + *self = CassDataType::Map { + typ: MapDataType::KeyAndValue(k.clone(), sub_type), + frozen: *frozen, + }; + Ok(()) + } + CassDataType::Map { + typ: MapDataType::Untyped, + frozen, + } => { + *self = CassDataType::Map { + typ: MapDataType::Key(sub_type), + frozen: *frozen, + }; + Ok(()) } CassDataType::Tuple(types) => { types.push(sub_type); @@ -423,8 +442,10 @@ pub fn get_column_type(column_type: &ColumnType) -> CassDataType { frozen: false, }, ColumnType::Map(key, value) => CassDataType::Map { - key_type: Some(Arc::new(get_column_type(key.as_ref()))), - val_type: Some(Arc::new(get_column_type(value.as_ref()))), + typ: MapDataType::KeyAndValue( + Arc::new(get_column_type(key.as_ref())), + Arc::new(get_column_type(value.as_ref())), + ), frozen: false, }, ColumnType::Set(boxed_type) => CassDataType::Set { @@ -475,8 +496,7 @@ pub unsafe extern "C" fn cass_data_type_new(value_type: CassValueType) -> *const }, CassValueType::CASS_VALUE_TYPE_TUPLE => CassDataType::Tuple(Vec::new()), CassValueType::CASS_VALUE_TYPE_MAP => CassDataType::Map { - key_type: None, - val_type: None, + typ: MapDataType::Untyped, frozen: false, }, CassValueType::CASS_VALUE_TYPE_UDT => CassDataType::UDT(UDTDataType::new()), @@ -673,9 +693,11 @@ pub unsafe extern "C" fn cass_data_type_sub_type_count(data_type: *const CassDat CassDataType::Value(..) => 0, CassDataType::UDT(udt_data_type) => udt_data_type.field_types.len() as size_t, CassDataType::List { typ, .. } | CassDataType::Set { typ, .. } => typ.is_some() as size_t, - CassDataType::Map { - key_type, val_type, .. - } => key_type.is_some() as size_t + val_type.is_some() as size_t, + CassDataType::Map { typ, .. } => match typ { + MapDataType::Untyped => 0, + MapDataType::Key(_) => 1, + MapDataType::KeyAndValue(_, _) => 2, + }, CassDataType::Tuple(v) => v.len() as size_t, CassDataType::Custom(..) => 0, } diff --git a/scylla-rust-wrapper/src/collection.rs b/scylla-rust-wrapper/src/collection.rs index c22312c2..ea0c2a79 100644 --- a/scylla-rust-wrapper/src/collection.rs +++ b/scylla-rust-wrapper/src/collection.rs @@ -1,5 +1,5 @@ use crate::cass_error::CassError; -use crate::cass_types::CassDataType; +use crate::cass_types::{CassDataType, MapDataType}; use crate::types::*; use crate::value::CassCqlValue; use crate::{argconv::*, value}; @@ -18,8 +18,7 @@ static UNTYPED_SET_TYPE: CassDataType = CassDataType::Set { frozen: false, }; static UNTYPED_MAP_TYPE: CassDataType = CassDataType::Map { - key_type: None, - val_type: None, + typ: MapDataType::Untyped, frozen: false, }; @@ -47,23 +46,27 @@ impl CassCollection { } } } - CassDataType::Map { - key_type: k_typ, - val_type: v_typ, - .. - } => { + + CassDataType::Map { typ, .. } => { // Cpp-driver does the typecheck only if both map types are present... // However, we decided not to mimic this behaviour (which is probably a bug). // We will do the typecheck if just the key type is defined as well (half-typed maps). - if let Some(k_typ) = k_typ { - if index % 2 == 0 && !value::is_type_compatible(value, k_typ) { - return CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE; + match typ { + MapDataType::Key(k_typ) => { + if index % 2 == 0 && !value::is_type_compatible(value, k_typ) { + return CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE; + } } - } - if let Some(v_typ) = v_typ { - if index % 2 != 0 && !value::is_type_compatible(value, v_typ) { - return CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE; + MapDataType::KeyAndValue(k_typ, v_typ) => { + if index % 2 == 0 && !value::is_type_compatible(value, k_typ) { + return CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE; + } + if index % 2 != 0 && !value::is_type_compatible(value, v_typ) { + return CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE; + } } + // Skip the typecheck for untyped map. + MapDataType::Untyped => (), } } _ => return CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE, @@ -214,7 +217,7 @@ mod tests { use crate::{ cass_error::CassError, - cass_types::{CassDataType, CassValueType}, + cass_types::{CassDataType, CassValueType, MapDataType}, collection::{ cass_collection_append_double, cass_collection_append_float, cass_collection_free, }, @@ -255,8 +258,7 @@ mod tests { // untyped map (via cass_collection_new_from_data_type - collection's type is Some(untyped_map)). { let dt = Arc::new(CassDataType::Map { - key_type: None, - val_type: None, + typ: MapDataType::Untyped, frozen: false, }); @@ -285,10 +287,9 @@ mod tests { // half-typed map (key-only) { let dt = Arc::new(CassDataType::Map { - key_type: Some(Arc::new(CassDataType::Value( + typ: MapDataType::Key(Arc::new(CassDataType::Value( CassValueType::CASS_VALUE_TYPE_BOOLEAN, ))), - val_type: None, frozen: false, }); @@ -325,12 +326,12 @@ mod tests { // typed map { let dt = Arc::new(CassDataType::Map { - key_type: Some(Arc::new(CassDataType::Value( - CassValueType::CASS_VALUE_TYPE_BOOLEAN, - ))), - val_type: Some(Arc::new(CassDataType::Value( - CassValueType::CASS_VALUE_TYPE_SMALL_INT, - ))), + typ: MapDataType::KeyAndValue( + Arc::new(CassDataType::Value(CassValueType::CASS_VALUE_TYPE_BOOLEAN)), + Arc::new(CassDataType::Value( + CassValueType::CASS_VALUE_TYPE_SMALL_INT, + )), + ), frozen: false, }); let dt_ptr = Arc::into_raw(dt); diff --git a/scylla-rust-wrapper/src/query_result.rs b/scylla-rust-wrapper/src/query_result.rs index dec575bb..4a98b991 100644 --- a/scylla-rust-wrapper/src/query_result.rs +++ b/scylla-rust-wrapper/src/query_result.rs @@ -1,6 +1,6 @@ use crate::argconv::*; use crate::cass_error::CassError; -use crate::cass_types::{cass_data_type_type, CassDataType, CassValueType}; +use crate::cass_types::{cass_data_type_type, CassDataType, CassValueType, MapDataType}; use crate::inet::CassInet; use crate::metadata::{ CassColumnMeta, CassKeyspaceMeta, CassMaterializedViewMeta, CassSchemaMeta, CassTableMeta, @@ -1239,7 +1239,7 @@ pub unsafe extern "C" fn cass_value_primary_sub_type( } => list.get_value_type(), CassDataType::Set { typ: Some(set), .. } => set.get_value_type(), CassDataType::Map { - key_type: Some(key), + typ: MapDataType::Key(key) | MapDataType::KeyAndValue(key, _), .. } => key.get_value_type(), _ => CassValueType::CASS_VALUE_TYPE_UNKNOWN, @@ -1254,7 +1254,7 @@ pub unsafe extern "C" fn cass_value_secondary_sub_type( match val.value_type.as_ref() { CassDataType::Map { - val_type: Some(value), + typ: MapDataType::KeyAndValue(_, value), .. } => value.get_value_type(), _ => CassValueType::CASS_VALUE_TYPE_UNKNOWN, diff --git a/scylla-rust-wrapper/src/session.rs b/scylla-rust-wrapper/src/session.rs index 397814b3..2f5f17e8 100644 --- a/scylla-rust-wrapper/src/session.rs +++ b/scylla-rust-wrapper/src/session.rs @@ -1,7 +1,7 @@ use crate::argconv::*; use crate::batch::CassBatch; use crate::cass_error::*; -use crate::cass_types::{get_column_type, CassDataType, UDTDataType}; +use crate::cass_types::{get_column_type, CassDataType, MapDataType, UDTDataType}; use crate::cluster::build_session_builder; use crate::cluster::CassCluster; use crate::exec_profile::{CassExecProfile, ExecProfileName, PerStatementExecProfile}; @@ -388,8 +388,7 @@ fn get_column_value(column: CqlValue, column_type: &Arc) -> Value ( CqlValue::Map(map), CassDataType::Map { - key_type: Some(key_typ), - val_type: Some(value_type), + typ: MapDataType::KeyAndValue(key_type, value_type), .. }, ) => CollectionValue(Collection::Map( @@ -397,8 +396,8 @@ fn get_column_value(column: CqlValue, column_type: &Arc) -> Value .map(|(key, val)| { ( CassValue { - value_type: key_typ.clone(), - value: Some(get_column_value(key, key_typ)), + value_type: key_type.clone(), + value: Some(get_column_value(key, key_type)), }, CassValue { value_type: value_type.clone(), diff --git a/scylla-rust-wrapper/src/value.rs b/scylla-rust-wrapper/src/value.rs index 28b25793..7636878f 100644 --- a/scylla-rust-wrapper/src/value.rs +++ b/scylla-rust-wrapper/src/value.rs @@ -402,7 +402,7 @@ mod tests { use scylla::frame::value::{CqlDate, CqlDecimal, CqlDuration}; use crate::{ - cass_types::{CassDataType, CassValueType, UDTDataType}, + cass_types::{CassDataType, CassValueType, MapDataType, UDTDataType}, value::{is_type_compatible, CassCqlValue}, }; @@ -630,8 +630,7 @@ mod tests { }); let data_type_bool_float_map = Arc::new(CassDataType::Map { - key_type: Some(data_type_bool.clone()), - val_type: Some(data_type_float.clone()), + typ: MapDataType::KeyAndValue(data_type_bool.clone(), data_type_float.clone()), frozen: false, }); @@ -820,18 +819,16 @@ mod tests { }); let data_type_untyped_map = Arc::new(CassDataType::Map { - key_type: None, - val_type: None, + typ: MapDataType::Untyped, frozen: false, }); let data_type_typed_key_float_map = Arc::new(CassDataType::Map { - key_type: Some(data_type_float.clone()), - val_type: None, + typ: MapDataType::Key(data_type_float.clone()), + frozen: false, }); let data_type_float_int_map = Arc::new(CassDataType::Map { - key_type: Some(data_type_float.clone()), - val_type: Some(data_type_int.clone()), + typ: MapDataType::KeyAndValue(data_type_float.clone(), data_type_int.clone()), frozen: false, });