Skip to content

Commit

Permalink
feat: enum read support (#297)
Browse files Browse the repository at this point in the history
* add basic enum support

* add enum to test_all_types
  • Loading branch information
Mause authored Apr 23, 2024
1 parent d613139 commit 0018cd8
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 6 deletions.
21 changes: 20 additions & 1 deletion src/row.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use std::{convert, sync::Arc};

use super::{Error, Result, Statement};
use crate::types::{self, FromSql, FromSqlError, ValueRef};
use crate::types::{self, EnumType, FromSql, FromSqlError, ValueRef};

use arrow::array::DictionaryArray;
use arrow::{
array::{self, Array, ArrayRef, ListArray, StructArray},
datatypes::*,
Expand Down Expand Up @@ -601,6 +602,24 @@ impl<'stmt> Row<'stmt> {

ValueRef::List(arr, row)
}
DataType::Dictionary(key_type, ..) => {
let column = column.as_any();
ValueRef::Enum(
match key_type.as_ref() {
DataType::UInt8 => {
EnumType::UInt8(column.downcast_ref::<DictionaryArray<UInt8Type>>().unwrap())
}
DataType::UInt16 => {
EnumType::UInt16(column.downcast_ref::<DictionaryArray<UInt16Type>>().unwrap())
}
DataType::UInt32 => {
EnumType::UInt32(column.downcast_ref::<DictionaryArray<UInt32Type>>().unwrap())
}
typ => panic!("Unsupported key type: {typ:?}"),
},
row,
)
}
_ => unreachable!("invalid value: {} {}", col, column.data_type()),
}
}
Expand Down
18 changes: 15 additions & 3 deletions src/test_all_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@ fn test_all_types() -> crate::Result<()> {
// union is currently blocked by https://github.com/duckdb/duckdb/pull/11326
"union",
// these remaining types are not yet supported by duckdb-rs
"small_enum",
"medium_enum",
"large_enum",
"struct",
"struct_of_arrays",
"array_of_structs",
Expand Down Expand Up @@ -349,6 +346,21 @@ fn test_single(idx: &mut i32, column: String, value: ValueRef) {
),
_ => assert_eq!(value, ValueRef::Null),
},
"small_enum" => match idx {
0 => assert_eq!(value.to_owned(), Value::Enum("DUCK_DUCK_ENUM".to_string())),
1 => assert_eq!(value.to_owned(), Value::Enum("GOOSE".to_string())),
_ => assert_eq!(value, ValueRef::Null),
},
"medium_enum" => match idx {
0 => assert_eq!(value.to_owned(), Value::Enum("enum_0".to_string())),
1 => assert_eq!(value.to_owned(), Value::Enum("enum_1".to_string())),
_ => assert_eq!(value, ValueRef::Null),
},
"large_enum" => match idx {
0 => assert_eq!(value.to_owned(), Value::Enum("enum_0".to_string())),
1 => assert_eq!(value.to_owned(), Value::Enum("enum_69999".to_string())),
_ => assert_eq!(value, ValueRef::Null),
},
_ => todo!("{column:?}"),
}
}
Expand Down
5 changes: 4 additions & 1 deletion src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ pub use self::{
from_sql::{FromSql, FromSqlError, FromSqlResult},
to_sql::{ToSql, ToSqlOutput},
value::Value,
value_ref::{TimeUnit, ValueRef},
value_ref::{EnumType, TimeUnit, ValueRef},
};

use arrow::datatypes::DataType;
Expand Down Expand Up @@ -149,6 +149,8 @@ pub enum Type {
Interval,
/// LIST
List(Box<Type>),
/// ENUM
Enum,
/// Any
Any,
}
Expand Down Expand Up @@ -219,6 +221,7 @@ impl fmt::Display for Type {
Type::Time64 => f.pad("Time64"),
Type::Interval => f.pad("Interval"),
Type::List(..) => f.pad("List"),
Type::Enum => f.pad("Enum"),
Type::Any => f.pad("Any"),
}
}
Expand Down
3 changes: 3 additions & 0 deletions src/types/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ pub enum Value {
},
/// The value is a list
List(Vec<Value>),
/// The value is an enum
Enum(String),
}

impl From<Null> for Value {
Expand Down Expand Up @@ -225,6 +227,7 @@ impl Value {
Value::Time64(..) => Type::Time64,
Value::Interval { .. } => Type::Interval,
Value::List(_) => todo!(),
Value::Enum(..) => Type::Enum,
}
}
}
36 changes: 35 additions & 1 deletion src/types/value_ref.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ use crate::types::{FromSqlError, FromSqlResult};
use crate::Row;
use rust_decimal::prelude::*;

use arrow::array::{Array, ListArray};
use arrow::array::{Array, DictionaryArray, ListArray};
use arrow::datatypes::{UInt16Type, UInt32Type, UInt8Type};

/// An absolute length of time in seconds, milliseconds, microseconds or nanoseconds.
/// Copy from arrow::datatypes::TimeUnit
Expand Down Expand Up @@ -75,6 +76,19 @@ pub enum ValueRef<'a> {
},
/// The value is a list
List(&'a ListArray, usize),
/// The value is an enum
Enum(EnumType<'a>, usize),
}

/// Wrapper type for different enum sizes
#[derive(Debug, Copy, Clone, PartialEq)]
pub enum EnumType<'a> {
/// The underlying enum type is u8
UInt8(&'a DictionaryArray<UInt8Type>),
/// The underlying enum type is u16
UInt16(&'a DictionaryArray<UInt16Type>),
/// The underlying enum type is u32
UInt32(&'a DictionaryArray<UInt32Type>),
}

impl ValueRef<'_> {
Expand Down Expand Up @@ -103,6 +117,7 @@ impl ValueRef<'_> {
ValueRef::Time64(..) => Type::Time64,
ValueRef::Interval { .. } => Type::Interval,
ValueRef::List(arr, _) => arr.data_type().into(),
ValueRef::Enum(..) => Type::Enum,
}
}

Expand Down Expand Up @@ -170,6 +185,24 @@ impl From<ValueRef<'_>> for Value {
.collect();
Value::List(map)
}
ValueRef::Enum(items, idx) => {
let value = Row::value_ref_internal(
idx,
0,
match items {
EnumType::UInt8(res) => res.values(),
EnumType::UInt16(res) => res.values(),
EnumType::UInt32(res) => res.values(),
},
)
.to_owned();

if let Value::Text(s) = value {
Value::Enum(s)
} else {
panic!("Enum value is not a string")
}
}
}
}
}
Expand Down Expand Up @@ -213,6 +246,7 @@ impl<'a> From<&'a Value> for ValueRef<'a> {
Value::Time64(t, d) => ValueRef::Time64(t, d),
Value::Interval { months, days, nanos } => ValueRef::Interval { months, days, nanos },
Value::List(..) => unimplemented!(),
Value::Enum(..) => todo!(),
}
}
}
Expand Down

0 comments on commit 0018cd8

Please sign in to comment.