Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 105 additions & 47 deletions scylla-rust-wrapper/src/cass_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ pub struct UDTDataType {

pub keyspace: String,
pub name: String,
pub frozen: bool,
}

impl UDTDataType {
Expand All @@ -30,13 +31,15 @@ impl UDTDataType {
field_types: Vec::new(),
keyspace: "".to_string(),
name: "".to_string(),
frozen: false,
}
}

pub fn create_with_params(
user_defined_types: &HashMap<String, Arc<UserDefinedType>>,
keyspace_name: &str,
name: &str,
frozen: bool,
) -> UDTDataType {
UDTDataType {
field_types: user_defined_types
Expand All @@ -57,6 +60,7 @@ impl UDTDataType {
.collect(),
keyspace: keyspace_name.to_string(),
name: name.to_owned(),
frozen,
}
}

Expand All @@ -65,6 +69,7 @@ impl UDTDataType {
field_types: Vec::with_capacity(capacity),
keyspace: "".to_string(),
name: "".to_string(),
frozen: false,
}
}

Expand Down Expand Up @@ -94,9 +99,19 @@ impl Default for UDTDataType {
pub enum CassDataType {
Value(CassValueType),
UDT(UDTDataType),
List(Option<Arc<CassDataType>>),
Set(Option<Arc<CassDataType>>),
Map(Option<Arc<CassDataType>>, Option<Arc<CassDataType>>),
List {
typ: Option<Arc<CassDataType>>,
frozen: bool,
},
Set {
typ: Option<Arc<CassDataType>>,
frozen: bool,
},
Map {
key_type: Option<Arc<CassDataType>>,
val_type: Option<Arc<CassDataType>>,
frozen: bool,
},
Tuple(Vec<Arc<CassDataType>>),
Custom(String),
}
Expand Down Expand Up @@ -135,25 +150,36 @@ pub fn get_column_type_from_cql_type(
) -> CassDataType {
match cql_type {
CqlType::Native(native) => CassDataType::Value(native.clone().into()),
CqlType::Collection { type_, .. } => match type_ {
CollectionType::List(list) => CassDataType::List(Some(Arc::new(
get_column_type_from_cql_type(list, user_defined_types, keyspace_name),
))),
CollectionType::Map(key, value) => CassDataType::Map(
Some(Arc::new(get_column_type_from_cql_type(
CqlType::Collection { type_, frozen } => match type_ {
CollectionType::List(list) => CassDataType::List {
typ: Some(Arc::new(get_column_type_from_cql_type(
list,
user_defined_types,
keyspace_name,
))),
frozen: *frozen,
},
CollectionType::Map(key, value) => CassDataType::Map {
key_type: Some(Arc::new(get_column_type_from_cql_type(
key,
user_defined_types,
keyspace_name,
))),
Some(Arc::new(get_column_type_from_cql_type(
val_type: Some(Arc::new(get_column_type_from_cql_type(
value,
user_defined_types,
keyspace_name,
))),
),
CollectionType::Set(set) => CassDataType::Set(Some(Arc::new(
get_column_type_from_cql_type(set, user_defined_types, keyspace_name),
))),
frozen: *frozen,
},
CollectionType::Set(set) => CassDataType::Set {
typ: Some(Arc::new(get_column_type_from_cql_type(
set,
user_defined_types,
keyspace_name,
))),
frozen: *frozen,
},
},
CqlType::Tuple(tuple) => CassDataType::Tuple(
tuple
Expand All @@ -167,7 +193,7 @@ pub fn get_column_type_from_cql_type(
})
.collect(),
),
CqlType::UserDefinedType { definition, .. } => {
CqlType::UserDefinedType { definition, frozen } => {
let name = match definition {
Ok(resolved) => &resolved.name,
Err(not_resolved) => &not_resolved.name,
Expand All @@ -176,6 +202,7 @@ pub fn get_column_type_from_cql_type(
user_defined_types,
keyspace_name,
name,
*frozen,
))
}
}
Expand All @@ -187,16 +214,18 @@ impl CassDataType {
CassDataType::UDT(udt_data_type) => {
udt_data_type.field_types.get(index).map(|(_, b)| b)
}
CassDataType::List(t) | CassDataType::Set(t) => {
CassDataType::List { typ, .. } | CassDataType::Set { typ, .. } => {
if index > 0 {
None
} else {
t.as_ref()
typ.as_ref()
}
}
CassDataType::Map(t1, t2) => match index {
0 => t1.as_ref(),
1 => t2.as_ref(),
CassDataType::Map {
key_type, val_type, ..
} => match index {
0 => key_type.as_ref(),
1 => val_type.as_ref(),
_ => None,
},
CassDataType::Tuple(v) => v.get(index),
Expand All @@ -206,21 +235,23 @@ impl CassDataType {

fn add_sub_data_type(&mut self, sub_type: Arc<CassDataType>) -> Result<(), CassError> {
match self {
CassDataType::List(t) | CassDataType::Set(t) => match t {
CassDataType::List { typ, .. } | CassDataType::Set { typ, .. } => match typ {
Some(_) => Err(CassError::CASS_ERROR_LIB_BAD_PARAMS),
None => {
*t = Some(sub_type);
*typ = Some(sub_type);
Ok(())
}
},
CassDataType::Map(t1, t2) => {
if t1.is_some() && t2.is_some() {
CassDataType::Map {
key_type, val_type, ..
} => {
if key_type.is_some() && val_type.is_some() {
Err(CassError::CASS_ERROR_LIB_BAD_PARAMS)
} else if t1.is_none() {
*t1 = Some(sub_type);
} else if key_type.is_none() {
*key_type = Some(sub_type);
Ok(())
} else {
*t2 = Some(sub_type);
*val_type = Some(sub_type);
Ok(())
}
}
Expand All @@ -243,9 +274,9 @@ impl CassDataType {
match &self {
CassDataType::Value(value_data_type) => *value_data_type,
CassDataType::UDT { .. } => CassValueType::CASS_VALUE_TYPE_UDT,
CassDataType::List(..) => CassValueType::CASS_VALUE_TYPE_LIST,
CassDataType::Set(..) => CassValueType::CASS_VALUE_TYPE_SET,
CassDataType::Map(..) => CassValueType::CASS_VALUE_TYPE_MAP,
CassDataType::List { .. } => CassValueType::CASS_VALUE_TYPE_LIST,
CassDataType::Set { .. } => CassValueType::CASS_VALUE_TYPE_SET,
CassDataType::Map { .. } => CassValueType::CASS_VALUE_TYPE_MAP,
CassDataType::Tuple(..) => CassValueType::CASS_VALUE_TYPE_TUPLE,
CassDataType::Custom(..) => CassValueType::CASS_VALUE_TYPE_CUSTOM,
}
Expand All @@ -268,16 +299,19 @@ pub fn get_column_type(column_type: &ColumnType) -> CassDataType {
ColumnType::Text => CassDataType::Value(CassValueType::CASS_VALUE_TYPE_TEXT),
ColumnType::Timestamp => CassDataType::Value(CassValueType::CASS_VALUE_TYPE_TIMESTAMP),
ColumnType::Inet => CassDataType::Value(CassValueType::CASS_VALUE_TYPE_INET),
ColumnType::List(boxed_type) => {
CassDataType::List(Some(Arc::new(get_column_type(boxed_type.as_ref()))))
}
ColumnType::Map(key, value) => CassDataType::Map(
Some(Arc::new(get_column_type(key.as_ref()))),
Some(Arc::new(get_column_type(value.as_ref()))),
),
ColumnType::Set(boxed_type) => {
CassDataType::Set(Some(Arc::new(get_column_type(boxed_type.as_ref()))))
}
ColumnType::List(boxed_type) => CassDataType::List {
typ: Some(Arc::new(get_column_type(boxed_type.as_ref()))),
frozen: false,
},
ColumnType::Map(key, value) => CassDataType::Map {
key_type: Some(Arc::new(get_column_type(key.as_ref()))),
val_type: Some(Arc::new(get_column_type(value.as_ref()))),
frozen: false,
},
ColumnType::Set(boxed_type) => CassDataType::Set {
typ: Some(Arc::new(get_column_type(boxed_type.as_ref()))),
frozen: false,
},
ColumnType::UserDefinedType {
type_name,
keyspace,
Expand All @@ -289,6 +323,7 @@ pub fn get_column_type(column_type: &ColumnType) -> CassDataType {
.collect(),
keyspace: (*keyspace).clone(),
name: (*type_name).clone(),
frozen: false,
}),
ColumnType::SmallInt => CassDataType::Value(CassValueType::CASS_VALUE_TYPE_SMALL_INT),
ColumnType::TinyInt => CassDataType::Value(CassValueType::CASS_VALUE_TYPE_TINY_INT),
Expand All @@ -312,10 +347,20 @@ pub fn get_column_type(column_type: &ColumnType) -> CassDataType {
#[no_mangle]
pub unsafe extern "C" fn cass_data_type_new(value_type: CassValueType) -> *const CassDataType {
let data_type = match value_type {
CassValueType::CASS_VALUE_TYPE_LIST => CassDataType::List(None),
CassValueType::CASS_VALUE_TYPE_SET => CassDataType::Set(None),
CassValueType::CASS_VALUE_TYPE_LIST => CassDataType::List {
typ: None,
frozen: false,
},
CassValueType::CASS_VALUE_TYPE_SET => CassDataType::Set {
typ: None,
frozen: false,
},
CassValueType::CASS_VALUE_TYPE_TUPLE => CassDataType::Tuple(Vec::new()),
CassValueType::CASS_VALUE_TYPE_MAP => CassDataType::Map(None, None),
CassValueType::CASS_VALUE_TYPE_MAP => CassDataType::Map {
key_type: None,
val_type: None,
frozen: false,
},
CassValueType::CASS_VALUE_TYPE_UDT => CassDataType::UDT(UDTDataType::new()),
CassValueType::CASS_VALUE_TYPE_CUSTOM => CassDataType::Custom("".to_string()),
CassValueType::CASS_VALUE_TYPE_UNKNOWN => return ptr::null_mut(),
Expand Down Expand Up @@ -358,8 +403,19 @@ pub unsafe extern "C" fn cass_data_type_type(data_type: *const CassDataType) ->
data_type.get_value_type()
}

// #[no_mangle]
// pub unsafe extern "C" fn cass_data_type_is_frozen(data_type: *const CassDataType) -> cass_bool_t {}
#[no_mangle]
pub unsafe extern "C" fn cass_data_type_is_frozen(data_type: *const CassDataType) -> cass_bool_t {
let data_type = ptr_to_ref(data_type);
let is_frozen = match data_type {
CassDataType::UDT(udt) => udt.frozen,
CassDataType::List { frozen, .. } => *frozen,
CassDataType::Set { frozen, .. } => *frozen,
CassDataType::Map { frozen, .. } => *frozen,
_ => false,
};

is_frozen as cass_bool_t
}

#[no_mangle]
pub unsafe extern "C" fn cass_data_type_type_name(
Expand Down Expand Up @@ -498,8 +554,10 @@ pub unsafe extern "C" fn cass_data_type_sub_type_count(data_type: *const CassDat
match data_type {
CassDataType::Value(..) => 0,
CassDataType::UDT(udt_data_type) => udt_data_type.field_types.len() as size_t,
CassDataType::List(t) | CassDataType::Set(t) => t.is_some() as size_t,
CassDataType::Map(t1, t2) => t1.is_some() as size_t + t2.is_some() as size_t,
CassDataType::List { typ, .. } | CassDataType::Set { typ, .. } => typ.is_some() as size_t,
CassDataType::Map {
key_type, val_type, ..
} => key_type.is_some() as size_t + val_type.is_some() as size_t,
CassDataType::Tuple(v) => v.len() as size_t,
CassDataType::Custom(..) => 0,
}
Expand Down
16 changes: 12 additions & 4 deletions scylla-rust-wrapper/src/query_result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1170,9 +1170,14 @@ pub unsafe extern "C" fn cass_value_primary_sub_type(
let val = ptr_to_ref(collection);

match val.value_type.as_ref() {
CassDataType::List(Some(list)) => list.get_value_type(),
CassDataType::Set(Some(set)) => set.get_value_type(),
CassDataType::Map(Some(key), _) => key.get_value_type(),
CassDataType::List {
typ: Some(list), ..
} => list.get_value_type(),
CassDataType::Set { typ: Some(set), .. } => set.get_value_type(),
CassDataType::Map {
key_type: Some(key),
..
} => key.get_value_type(),
_ => CassValueType::CASS_VALUE_TYPE_UNKNOWN,
}
}
Expand All @@ -1184,7 +1189,10 @@ pub unsafe extern "C" fn cass_value_secondary_sub_type(
let val = ptr_to_ref(collection);

match val.value_type.as_ref() {
CassDataType::Map(_, Some(value)) => value.get_value_type(),
CassDataType::Map {
val_type: Some(value),
..
} => value.get_value_type(),
_ => CassValueType::CASS_VALUE_TYPE_UNKNOWN,
}
}
Expand Down
Loading