Skip to content

Commit

Permalink
feat: add get function for maps (for string -> primitive) (#500)
Browse files Browse the repository at this point in the history
Adds the `get` function for maps whose keys are `string` or `large_string`
and whose values are primitive types.

Further additions can support more types. See
#494 for the task list.

Testing:
* Verified that the sample map data (with a `LargeUtf8`) produces results
as expected.
  • Loading branch information
jordanrfrazier authored Jul 18, 2023
1 parent 0ca8805 commit c700afe
Show file tree
Hide file tree
Showing 51 changed files with 905 additions and 184 deletions.
1 change: 1 addition & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
{
"rust-analyzer.cargo.features": "all",
"editor.formatOnSave": true,
}
5 changes: 4 additions & 1 deletion crates/sparrow-api/src/kaskada/v1alpha/plan_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ impl Literal {
}
Some(literal::Literal::IntervalMonths(v)) => ScalarValue::IntervalMonths(Some(*v)),
Some(literal::Literal::Utf8(v)) => ScalarValue::Utf8(Some(v.clone())),
Some(literal::Literal::LargeUtf8(v)) => ScalarValue::LargeUtf8(Some(v.clone())),
Some(literal::Literal::Record(v)) => {
let values = v
.values
Expand Down Expand Up @@ -132,6 +133,7 @@ impl From<&ScalarValue> for Literal {
)),
ScalarValue::IntervalMonths(Some(v)) => Some(literal::Literal::IntervalMonths(*v)),
ScalarValue::Utf8(Some(v)) => Some(literal::Literal::Utf8(v.clone())),
ScalarValue::LargeUtf8(Some(v)) => Some(literal::Literal::LargeUtf8(v.clone())),
ScalarValue::Record(v) => {
let values: Vec<_> = v
.values()
Expand Down Expand Up @@ -422,7 +424,8 @@ impl std::fmt::Display for super::Literal {
Some(literal::Literal::IntervalMonths(i)) => {
write!(f, "interval_months:{i}")
}
Some(literal::Literal::Utf8(ref str)) => write!(f, "\\\"{str}\\\""),
Some(literal::Literal::Utf8(str)) => write!(f, "\\\"{str}\\\""),
Some(literal::Literal::LargeUtf8(str)) => write!(f, "\\\"{str}\\\""),
unreachable => unreachable!("Unable to format {:?}", unreachable),
}
}
Expand Down
117 changes: 75 additions & 42 deletions crates/sparrow-api/src/kaskada/v1alpha/schema_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,29 +14,32 @@ impl DataType {
kind: Some(data_type::Kind::Struct(Schema { fields })),
}
}

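/// Creates a new map type called `name` from the given fields.
///
/// `ordered` is recorded on the resulting map type unchanged.
/// `fields` should have two elements, the first being the key field
/// and the second being the value field; each field's name, data type
/// and nullability are copied into the map type.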
pub fn new_map(fields: Vec<schema::Field>) -> Self {
pub fn new_map(name: &str, ordered: bool, fields: Vec<schema::Field>) -> Self {
debug_assert!(fields.len() == 2);
let key = &fields[0];
let value = &fields[1];
Self {
kind: Some(data_type::Kind::Map(Box::new(data_type::Map {
key: Some(Box::new(
fields[0]
.data_type
.as_ref()
.expect("data type to exist")
.clone(),
name: name.to_string(),
ordered,
key_name: key.name.clone(),
key_type: Some(Box::new(
key.data_type.as_ref().expect("data type to exist").clone(),
)),
value: Some(Box::new(
fields[1]
key_is_nullable: key.nullable,
value_name: value.name.clone(),
value_type: Some(Box::new(
value
.data_type
.as_ref()
.expect("data type to exist")
.clone(),
)),
value_is_nullable: value.nullable,
}))),
}
}
Expand Down Expand Up @@ -165,6 +168,9 @@ impl TryFrom<&arrow::datatypes::DataType> for DataType {
Ok(DataType::new_primitive(PrimitiveType::IntervalYearMonth))
}
arrow::datatypes::DataType::Utf8 => Ok(DataType::new_primitive(PrimitiveType::String)),
arrow::datatypes::DataType::LargeUtf8 => {
Ok(DataType::new_primitive(PrimitiveType::LargeString))
}
arrow::datatypes::DataType::Struct(fields) => {
let fields = fields
.iter()
Expand All @@ -174,6 +180,7 @@ impl TryFrom<&arrow::datatypes::DataType> for DataType {
Ok(data_type) => Ok(schema::Field {
name,
data_type: Some(data_type),
nullable: field.is_nullable(),
}),
Err(err) => Err(err.with_prepend_field(name)),
}
Expand All @@ -183,7 +190,7 @@ impl TryFrom<&arrow::datatypes::DataType> for DataType {
}
// Note: the `ordered` field may let us specialize the implementation
// to use binary search in the future.
arrow::datatypes::DataType::Map(s, _) => {
arrow::datatypes::DataType::Map(s, is_ordered) => {
// [DataType::Map] is represented as a list of structs with two fields: `key` and `value`
let arrow::datatypes::DataType::Struct(fields) = s.data_type() else {
// unexpected - maps should always contain a struct
Expand All @@ -194,23 +201,25 @@ impl TryFrom<&arrow::datatypes::DataType> for DataType {
let key = &fields[0];
let value = &fields[1];
let key = schema::Field {
name: "key".to_owned(),
name: key.name().to_owned(),
data_type: Some(key.data_type().try_into().map_err(
|err: ConversionError<arrow::datatypes::DataType>| {
err.with_prepend_field("key".to_owned())
},
)?),
nullable: key.is_nullable(),
};
let value = schema::Field {
name: "value".to_owned(),
name: value.name().to_owned(),
data_type: Some(value.data_type().try_into().map_err(
|err: ConversionError<arrow::datatypes::DataType>| {
err.with_prepend_field("value".to_owned())
},
)?),
nullable: value.is_nullable(),
};

Ok(DataType::new_map(vec![key, value]))
Ok(DataType::new_map(s.name(), *is_ordered, vec![key, value]))
}
unsupported => Err(ConversionError::new_unsupported(unsupported.clone())),
}
Expand Down Expand Up @@ -282,6 +291,7 @@ impl TryFrom<&DataType> for arrow::datatypes::DataType {
Some(PrimitiveType::F32) => Ok(arrow::datatypes::DataType::Float32),
Some(PrimitiveType::F64) => Ok(arrow::datatypes::DataType::Float64),
Some(PrimitiveType::String) => Ok(arrow::datatypes::DataType::Utf8),
Some(PrimitiveType::LargeString) => Ok(arrow::datatypes::DataType::LargeUtf8),
Some(PrimitiveType::IntervalDayTime) => {
Ok(arrow::datatypes::DataType::Interval(
arrow::datatypes::IntervalUnit::DayTime,
Expand Down Expand Up @@ -333,26 +343,36 @@ impl TryFrom<&DataType> for arrow::datatypes::DataType {
let item_type = arrow::datatypes::Field::new("item", item_type, true);
Ok(arrow::datatypes::DataType::List(Arc::new(item_type)))
}
Some(data_type::Kind::Map(map)) => match (map.key.as_ref(), map.value.as_ref()) {
(Some(key), Some(value)) => {
let key = arrow::datatypes::DataType::try_from(key.as_ref())
.map_err(|e| e.with_prepend_field("map key".to_owned()))?;
let value = arrow::datatypes::DataType::try_from(value.as_ref())
.map_err(|e| e.with_prepend_field("map value".to_owned()))?;

let fields = arrow::datatypes::Fields::from(vec![
arrow::datatypes::Field::new("key", key, false),
arrow::datatypes::Field::new("value", value, false),
]);
let s = arrow::datatypes::Field::new(
"entries",
arrow::datatypes::DataType::Struct(fields),
false,
);
Ok(arrow::datatypes::DataType::Map(Arc::new(s), false))
Some(data_type::Kind::Map(map)) => {
match (map.key_type.as_ref(), map.value_type.as_ref()) {
(Some(key), Some(value)) => {
let key = arrow::datatypes::DataType::try_from(key.as_ref())
.map_err(|e| e.with_prepend_field("map key".to_owned()))?;
let value = arrow::datatypes::DataType::try_from(value.as_ref())
.map_err(|e| e.with_prepend_field("map value".to_owned()))?;

let fields = arrow::datatypes::Fields::from(vec![
arrow::datatypes::Field::new(
map.key_name.clone(),
key,
map.key_is_nullable,
),
arrow::datatypes::Field::new(
map.value_name.clone(),
value,
map.value_is_nullable,
),
]);
let s = arrow::datatypes::Field::new(
map.name.clone(),
arrow::datatypes::DataType::Struct(fields),
false,
);
Ok(arrow::datatypes::DataType::Map(Arc::new(s), map.ordered))
}
_ => Err(ConversionError::new_unsupported(value.clone())),
}
_ => Err(ConversionError::new_unsupported(value.clone())),
},
}
None | Some(data_type::Kind::Window(_)) => {
Err(ConversionError::new_unsupported(value.clone()))
}
Expand Down Expand Up @@ -384,6 +404,7 @@ impl TryFrom<&arrow::datatypes::Schema> for Schema {
Ok(data_type) => Ok(schema::Field {
name,
data_type: Some(data_type),
nullable: field.is_nullable(),
}),
Err(err) => Err(err.with_prepend_field(name)),
}
Expand Down Expand Up @@ -423,6 +444,7 @@ mod tests {
arrow::datatypes::DataType::Float32,
arrow::datatypes::DataType::Float64,
arrow::datatypes::DataType::Utf8,
arrow::datatypes::DataType::LargeUtf8,
arrow::datatypes::DataType::Timestamp(arrow::datatypes::TimeUnit::Second, None),
arrow::datatypes::DataType::Timestamp(arrow::datatypes::TimeUnit::Millisecond, None),
arrow::datatypes::DataType::Timestamp(arrow::datatypes::TimeUnit::Microsecond, None),
Expand Down Expand Up @@ -489,23 +511,30 @@ mod tests {

#[test]
fn test_unsupported_datatype() {
let err = DataType::try_from(&arrow::datatypes::DataType::LargeUtf8).unwrap_err();
let err = DataType::try_from(&arrow::datatypes::DataType::FixedSizeBinary(1)).unwrap_err();
assert_eq!(
err,
ConversionError {
fields: vec![],
data_type: arrow::datatypes::DataType::LargeUtf8
data_type: arrow::datatypes::DataType::FixedSizeBinary(1)
}
);
assert_eq!(&err.to_string(), "Unsupported conversion from 'LargeUtf8'");
assert_eq!(
&err.to_string(),
"Unsupported conversion from 'FixedSizeBinary(1)'"
);
}

#[test]
fn test_unsupported_nested_struct() {
let inner_struct_type = arrow::datatypes::DataType::Struct(
vec![
arrow::datatypes::Field::new("a", arrow::datatypes::DataType::Int64, true),
arrow::datatypes::Field::new("b", arrow::datatypes::DataType::LargeUtf8, true),
arrow::datatypes::Field::new(
"b",
arrow::datatypes::DataType::FixedSizeBinary(1),
true,
),
]
.into(),
);
Expand All @@ -521,12 +550,12 @@ mod tests {
err,
ConversionError {
fields: vec!["b".to_owned(), "x".to_owned()],
data_type: arrow::datatypes::DataType::LargeUtf8,
data_type: arrow::datatypes::DataType::FixedSizeBinary(1),
}
);
assert_eq!(
&err.to_string(),
"Unsupported conversion from 'LargeUtf8' for field 'x.b'"
"Unsupported conversion from 'FixedSizeBinary(1)' for field 'x.b'"
);
}

Expand All @@ -535,7 +564,11 @@ mod tests {
let inner_struct_type = arrow::datatypes::DataType::Struct(
vec![
arrow::datatypes::Field::new("a", arrow::datatypes::DataType::Int64, true),
arrow::datatypes::Field::new("b", arrow::datatypes::DataType::LargeUtf8, true),
arrow::datatypes::Field::new(
"b",
arrow::datatypes::DataType::FixedSizeBinary(1),
true,
),
]
.into(),
);
Expand All @@ -548,12 +581,12 @@ mod tests {
err,
ConversionError {
fields: vec!["b".to_owned(), "x".to_owned()],
data_type: arrow::datatypes::DataType::LargeUtf8,
data_type: arrow::datatypes::DataType::FixedSizeBinary(1),
}
);
assert_eq!(
&err.to_string(),
"Unsupported conversion from 'LargeUtf8' for field 'x.b'"
"Unsupported conversion from 'FixedSizeBinary(1)' for field 'x.b'"
);
}
}
9 changes: 9 additions & 0 deletions crates/sparrow-arrow/src/downcast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use arrow::array::{
StructArray,
};
use arrow::datatypes::ArrowPrimitiveType;
use arrow_array::MapArray;

/// Downcast an `ArrayRef` to a `PrimitiveArray<T>`.
pub fn downcast_primitive_array<T: ArrowPrimitiveType>(
Expand Down Expand Up @@ -55,6 +56,14 @@ pub fn downcast_struct_array(array: &dyn Array) -> anyhow::Result<&StructArray>
.with_context(|| format!("Unable to downcast {:?} to struct array", array.data_type()))
}

/// View a dynamically typed Arrow array as a [`MapArray`].
///
/// # Errors
///
/// Returns an error mentioning the array's actual data type when the
/// underlying concrete array is not a `MapArray`.
pub fn downcast_map_array(array: &dyn Array) -> anyhow::Result<&MapArray> {
    let maybe_map = array.as_any().downcast_ref::<MapArray>();
    maybe_map.with_context(|| format!("Unable to downcast {:?} to map array", array.data_type()))
}

/// Downcast an `ArrayRef` to a `BooleanArray`.
pub fn downcast_boolean_array(array: &dyn Array) -> anyhow::Result<&BooleanArray> {
array
Expand Down
21 changes: 21 additions & 0 deletions crates/sparrow-arrow/src/scalar_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use std::sync::Arc;
use anyhow::{anyhow, Context};
use arrow::array::{Array, ArrayRef, BooleanArray, NullArray, PrimitiveArray, StringArray};
use arrow::datatypes::*;
use arrow_array::LargeStringArray;
use decorum::Total;
use itertools::izip;
use num::{One, Signed, Zero};
Expand Down Expand Up @@ -53,6 +54,8 @@ pub enum ScalarValue {
IntervalMonths(Option<i32>),
/// UTF-8 encoded strings with 32 bit offsets.
Utf8(Option<String>),
/// Large UTF-8 encoded strings with 64 bit offsets.
LargeUtf8(Option<String>),
/// Records.
Record(Box<ScalarRecord>),
}
Expand Down Expand Up @@ -144,6 +147,7 @@ impl std::fmt::Display for ScalarValue {
}
ScalarValue::IntervalMonths(Some(months)) => write!(f, "interval_months:{months}"),
ScalarValue::Utf8(Some(str)) => write!(f, "\\\"{str}\\\""),
ScalarValue::LargeUtf8(Some(str)) => write!(f, "\\\"{str}\\\""),
unreachable => unreachable!("Unable to format {unreachable:?}"),
}
}
Expand Down Expand Up @@ -289,6 +293,7 @@ impl ScalarValue {
DataType::Interval(IntervalUnit::DayTime) => Ok(Self::IntervalDayTime(None)),
DataType::Interval(IntervalUnit::YearMonth) => Ok(Self::IntervalMonths(None)),
DataType::Utf8 => Ok(Self::Utf8(None)),
DataType::LargeUtf8 => Ok(Self::LargeUtf8(None)),
DataType::Struct(fields) => Ok(Self::Record(Box::new(ScalarRecord {
value: None,
fields: fields.clone(),
Expand Down Expand Up @@ -326,6 +331,7 @@ impl ScalarValue {
ScalarValue::IntervalDayTime(_) => DataType::Interval(IntervalUnit::DayTime),
ScalarValue::IntervalMonths(_) => DataType::Interval(IntervalUnit::YearMonth),
ScalarValue::Utf8(_) => DataType::Utf8,
ScalarValue::LargeUtf8(_) => DataType::LargeUtf8,
ScalarValue::Record(record) => DataType::Struct(record.fields.clone()),
}
}
Expand Down Expand Up @@ -389,6 +395,10 @@ impl ScalarValue {
let iter = std::iter::repeat(s).take(len);
Arc::new(iter.cloned().collect::<StringArray>())
}
ScalarValue::LargeUtf8(s) => {
let iter = std::iter::repeat(s).take(len);
Arc::new(iter.cloned().collect::<LargeStringArray>())
}
ScalarValue::Record(record) => {
let fields: Vec<(FieldRef, ArrayRef)> = if let Some(values) = &record.value {
izip!(&record.fields, values)
Expand Down Expand Up @@ -544,6 +554,15 @@ impl ScalarValue {
Ok(Self::Utf8(None))
}
}
DataType::LargeUtf8 => {
if array.is_valid(row) {
let array: &LargeStringArray = downcast_string_array(array)?;
let string = array.value(row).to_owned();
Ok(Self::LargeUtf8(Some(string)))
} else {
Ok(Self::LargeUtf8(None))
}
}
DataType::Struct(fields) => {
let value = if array.is_valid(row) {
let array = downcast_struct_array(array)?;
Expand Down Expand Up @@ -589,6 +608,7 @@ impl ScalarValue {
ScalarValue::IntervalDayTime(n) => n.is_none(),
ScalarValue::IntervalMonths(n) => n.is_none(),
ScalarValue::Utf8(n) => n.is_none(),
ScalarValue::LargeUtf8(n) => n.is_none(),
ScalarValue::Record(record) => record.value.is_none(),
}
}
Expand Down Expand Up @@ -623,6 +643,7 @@ impl ScalarValue {
ScalarValue::IntervalDayTime(_) => ScalarValue::IntervalDayTime(None),
ScalarValue::IntervalMonths(_) => ScalarValue::IntervalMonths(None),
ScalarValue::Utf8(_) => ScalarValue::Utf8(None),
ScalarValue::LargeUtf8(_) => ScalarValue::LargeUtf8(None),
ScalarValue::Record(record) => ScalarValue::Record(Box::new(ScalarRecord {
value: None,
fields: record.fields.clone(),
Expand Down
Loading

0 comments on commit c700afe

Please sign in to comment.