@@ -131,6 +131,13 @@ impl Default for UDTDataType {
131131 }
132132}
133133
134+ #[ derive( Clone , Debug , PartialEq ) ]
135+ pub enum MapDataType {
136+ Untyped ,
137+ Key ( Arc < CassDataType > ) ,
138+ KeyAndValue ( Arc < CassDataType > , Arc < CassDataType > ) ,
139+ }
140+
134141#[ derive( Clone , Debug , PartialEq ) ]
135142pub enum CassDataType {
136143 Value ( CassValueType ) ,
@@ -146,10 +153,7 @@ pub enum CassDataType {
146153 frozen : bool ,
147154 } ,
148155 Map {
149- // None, None stands for untyped map.
150- // Some, None stands for a map with an untyped value type.
151- key_type : Option < Arc < CassDataType > > ,
152- val_type : Option < Arc < CassDataType > > ,
156+ typ : MapDataType ,
153157 frozen : bool ,
154158 } ,
155159 // Empty vector stands for untyped tuple.
@@ -183,29 +187,22 @@ impl CassDataType {
183187 }
184188 _ => false ,
185189 } ,
186- CassDataType :: Map {
187- key_type : k,
188- val_type : v,
189- ..
190- } => match other {
191- CassDataType :: Map {
192- key_type : k_other,
193- val_type : v_other,
194- ..
195- } => match ( ( k, v) , ( k_other, v_other) ) {
190+ CassDataType :: Map { typ : t, .. } => match other {
191+ CassDataType :: Map { typ : t_other, .. } => match ( t, t_other) {
196192 // See https://github.com/scylladb/cpp-driver/blob/master/src/data_type.hpp#L218
197193 // In cpp-driver the types are held in a vector.
198194 // The logic is following:
199195
200196 // If either of vectors is empty, skip the typecheck.
201- ( ( None , None ) , _) => true ,
202- ( _, ( None , None ) ) => true ,
197+ ( MapDataType :: Untyped , _) => true ,
198+ ( _, MapDataType :: Untyped ) => true ,
203199
204200 // Otherwise, the vectors should have equal length and we perform the typecheck for subtypes.
205- ( ( Some ( k) , None ) , ( Some ( k_other) , None ) ) => k. typecheck_equals ( k_other) ,
206- ( ( Some ( k) , Some ( v) ) , ( Some ( k_other) , Some ( v_other) ) ) => {
207- k. typecheck_equals ( k_other) && v. typecheck_equals ( v_other)
208- }
201+ ( MapDataType :: Key ( k) , MapDataType :: Key ( k_other) ) => k. typecheck_equals ( k_other) ,
202+ (
203+ MapDataType :: KeyAndValue ( k, v) ,
204+ MapDataType :: KeyAndValue ( k_other, v_other) ,
205+ ) => k. typecheck_equals ( k_other) && v. typecheck_equals ( v_other) ,
209206 _ => false ,
210207 } ,
211208 _ => false ,
@@ -278,16 +275,18 @@ pub fn get_column_type_from_cql_type(
278275 frozen : * frozen,
279276 } ,
280277 CollectionType :: Map ( key, value) => CassDataType :: Map {
281- key_type : Some ( Arc :: new ( get_column_type_from_cql_type (
282- key,
283- user_defined_types,
284- keyspace_name,
285- ) ) ) ,
286- val_type : Some ( Arc :: new ( get_column_type_from_cql_type (
287- value,
288- user_defined_types,
289- keyspace_name,
290- ) ) ) ,
278+ typ : MapDataType :: KeyAndValue (
279+ Arc :: new ( get_column_type_from_cql_type (
280+ key,
281+ user_defined_types,
282+ keyspace_name,
283+ ) ) ,
284+ Arc :: new ( get_column_type_from_cql_type (
285+ value,
286+ user_defined_types,
287+ keyspace_name,
288+ ) ) ,
289+ ) ,
291290 frozen : * frozen,
292291 } ,
293292 CollectionType :: Set ( set) => CassDataType :: Set {
@@ -340,10 +339,19 @@ impl CassDataType {
340339 }
341340 }
342341 CassDataType :: Map {
343- key_type, val_type, ..
342+ typ : MapDataType :: Untyped ,
343+ ..
344+ } => None ,
345+ CassDataType :: Map {
346+ typ : MapDataType :: Key ( k) ,
347+ ..
348+ } => ( index == 0 ) . then_some ( k) ,
349+ CassDataType :: Map {
350+ typ : MapDataType :: KeyAndValue ( k, v) ,
351+ ..
344352 } => match index {
345- 0 => key_type . as_ref ( ) ,
346- 1 => val_type . as_ref ( ) ,
353+ 0 => Some ( k ) ,
354+ 1 => Some ( v ) ,
347355 _ => None ,
348356 } ,
349357 CassDataType :: Tuple ( v) => v. get ( index) ,
@@ -361,17 +369,28 @@ impl CassDataType {
361369 }
362370 } ,
363371 CassDataType :: Map {
364- key_type, val_type, ..
372+ typ : MapDataType :: KeyAndValue ( _, _) ,
373+ ..
374+ } => Err ( CassError :: CASS_ERROR_LIB_BAD_PARAMS ) ,
375+ CassDataType :: Map {
376+ typ : MapDataType :: Key ( k) ,
377+ frozen,
365378 } => {
366- if key_type. is_some ( ) && val_type. is_some ( ) {
367- Err ( CassError :: CASS_ERROR_LIB_BAD_PARAMS )
368- } else if key_type. is_none ( ) {
369- * key_type = Some ( sub_type) ;
370- Ok ( ( ) )
371- } else {
372- * val_type = Some ( sub_type) ;
373- Ok ( ( ) )
374- }
379+ * self = CassDataType :: Map {
380+ typ : MapDataType :: KeyAndValue ( k. clone ( ) , sub_type) ,
381+ frozen : * frozen,
382+ } ;
383+ Ok ( ( ) )
384+ }
385+ CassDataType :: Map {
386+ typ : MapDataType :: Untyped ,
387+ frozen,
388+ } => {
389+ * self = CassDataType :: Map {
390+ typ : MapDataType :: Key ( sub_type) ,
391+ frozen : * frozen,
392+ } ;
393+ Ok ( ( ) )
375394 }
376395 CassDataType :: Tuple ( types) => {
377396 types. push ( sub_type) ;
@@ -423,8 +442,10 @@ pub fn get_column_type(column_type: &ColumnType) -> CassDataType {
423442 frozen : false ,
424443 } ,
425444 ColumnType :: Map ( key, value) => CassDataType :: Map {
426- key_type : Some ( Arc :: new ( get_column_type ( key. as_ref ( ) ) ) ) ,
427- val_type : Some ( Arc :: new ( get_column_type ( value. as_ref ( ) ) ) ) ,
445+ typ : MapDataType :: KeyAndValue (
446+ Arc :: new ( get_column_type ( key. as_ref ( ) ) ) ,
447+ Arc :: new ( get_column_type ( value. as_ref ( ) ) ) ,
448+ ) ,
428449 frozen : false ,
429450 } ,
430451 ColumnType :: Set ( boxed_type) => CassDataType :: Set {
@@ -475,8 +496,7 @@ pub unsafe extern "C" fn cass_data_type_new(value_type: CassValueType) -> *const
475496 } ,
476497 CassValueType :: CASS_VALUE_TYPE_TUPLE => CassDataType :: Tuple ( Vec :: new ( ) ) ,
477498 CassValueType :: CASS_VALUE_TYPE_MAP => CassDataType :: Map {
478- key_type : None ,
479- val_type : None ,
499+ typ : MapDataType :: Untyped ,
480500 frozen : false ,
481501 } ,
482502 CassValueType :: CASS_VALUE_TYPE_UDT => CassDataType :: UDT ( UDTDataType :: new ( ) ) ,
@@ -673,9 +693,11 @@ pub unsafe extern "C" fn cass_data_type_sub_type_count(data_type: *const CassDat
673693 CassDataType :: Value ( ..) => 0 ,
674694 CassDataType :: UDT ( udt_data_type) => udt_data_type. field_types . len ( ) as size_t ,
675695 CassDataType :: List { typ, .. } | CassDataType :: Set { typ, .. } => typ. is_some ( ) as size_t ,
676- CassDataType :: Map {
677- key_type, val_type, ..
678- } => key_type. is_some ( ) as size_t + val_type. is_some ( ) as size_t ,
696+ CassDataType :: Map { typ, .. } => match typ {
697+ MapDataType :: Untyped => 0 ,
698+ MapDataType :: Key ( _) => 1 ,
699+ MapDataType :: KeyAndValue ( _, _) => 2 ,
700+ } ,
679701 CassDataType :: Tuple ( v) => v. len ( ) as size_t ,
680702 CassDataType :: Custom ( ..) => 0 ,
681703 }
0 commit comments