Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Introduced UnionMode enum #557

Merged
merged 1 commit into from
Oct 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 10 additions & 12 deletions src/array/union/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::{
array::{display::get_value_display, display_fmt, new_empty_array, new_null_array, Array},
bitmap::Bitmap,
buffer::Buffer,
datatypes::{DataType, Field},
datatypes::{DataType, Field, UnionMode},
scalar::{new_scalar, Scalar},
};

Expand Down Expand Up @@ -37,13 +37,13 @@ pub struct UnionArray {
impl UnionArray {
/// Creates a new null [`UnionArray`].
pub fn new_null(data_type: DataType, length: usize) -> Self {
if let DataType::Union(f, _, is_sparse) = &data_type {
if let DataType::Union(f, _, mode) = &data_type {
let fields = f
.iter()
.map(|x| new_null_array(x.data_type().clone(), length).into())
.collect();

let offsets = if *is_sparse {
let offsets = if mode.is_sparse() {
None
} else {
Some((0..length as i32).collect::<Buffer<i32>>())
Expand All @@ -60,13 +60,13 @@ impl UnionArray {

/// Creates a new empty [`UnionArray`].
pub fn new_empty(data_type: DataType) -> Self {
if let DataType::Union(f, _, is_sparse) = &data_type {
if let DataType::Union(f, _, mode) = &data_type {
let fields = f
.iter()
.map(|x| new_empty_array(x.data_type().clone()).into())
.collect();

let offsets = if *is_sparse {
let offsets = if mode.is_sparse() {
None
} else {
Some(Buffer::new())
Expand All @@ -92,7 +92,7 @@ impl UnionArray {
fields: Vec<Arc<dyn Array>>,
offsets: Option<Buffer<i32>>,
) -> Self {
let (f, ids, is_sparse) = Self::get_all(&data_type);
let (f, ids, mode) = Self::get_all(&data_type);

if f.len() != fields.len() {
panic!("The number of `fields` must equal the number of fields in the Union DataType")
Expand All @@ -104,7 +104,7 @@ impl UnionArray {
if !same_data_types {
panic!("All fields' datatype in the union must equal the datatypes on the fields.")
}
if offsets.is_none() != is_sparse {
if offsets.is_none() != mode.is_sparse() {
panic!("Sparsness flag must equal to noness of offsets in UnionArray")
}
let fields_hash = ids.as_ref().map(|ids| {
Expand Down Expand Up @@ -244,11 +244,9 @@ impl Array for UnionArray {
}

impl UnionArray {
fn get_all(data_type: &DataType) -> (&[Field], Option<&[i32]>, bool) {
fn get_all(data_type: &DataType) -> (&[Field], Option<&[i32]>, UnionMode) {
match data_type.to_logical_type() {
DataType::Union(fields, ids, is_sparse) => {
(fields, ids.as_ref().map(|x| x.as_ref()), *is_sparse)
}
DataType::Union(fields, ids, mode) => (fields, ids.as_ref().map(|x| x.as_ref()), *mode),
_ => panic!("Wrong datatype passed to UnionArray."),
}
}
Expand All @@ -264,7 +262,7 @@ impl UnionArray {
/// # Panic
/// Panics iff `data_type`'s logical type is not [`DataType::Union`].
pub fn is_sparse(data_type: &DataType) -> bool {
Self::get_all(data_type).2
Self::get_all(data_type).2.is_sparse()
}
}

Expand Down
35 changes: 33 additions & 2 deletions src/datatypes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ pub enum DataType {
/// A nested datatype that contains a number of sub-fields.
Struct(Vec<Field>),
/// A nested datatype that can represent slots of differing types.
/// Third argument represents sparsness
Union(Vec<Field>, Option<Vec<i32>>, bool),
/// Third argument represents mode
Union(Vec<Field>, Option<Vec<i32>>, UnionMode),
/// A nested type that is represented as
///
/// List<entries: Struct<key: K, value: V>>
Expand Down Expand Up @@ -144,6 +144,37 @@ impl std::fmt::Display for DataType {
}
}

/// Mode of [`DataType::Union`]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UnionMode {
/// Dense union
Dense,
/// Sparse union
Sparse,
}

impl UnionMode {
/// Constructs a [`UnionMode::Sparse`] if the input bool is true,
/// or otherwise constructs a [`UnionMode::Dense`]
pub fn sparse(is_sparse: bool) -> Self {
if is_sparse {
Self::Sparse
} else {
Self::Dense
}
}

/// Returns whether the mode is sparse
pub fn is_sparse(&self) -> bool {
matches!(self, Self::Sparse)
}

/// Returns whether the mode is dense
pub fn is_dense(&self) -> bool {
matches!(self, Self::Dense)
}
}

/// The time units defined in Arrow.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TimeUnit {
Expand Down
10 changes: 5 additions & 5 deletions src/ffi/schema.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::{collections::BTreeMap, convert::TryInto, ffi::CStr, ffi::CString, ptr};

use crate::{
datatypes::{DataType, Extension, Field, IntervalUnit, Metadata, TimeUnit},
datatypes::{DataType, Extension, Field, IntervalUnit, Metadata, TimeUnit, UnionMode},
error::{ArrowError, Result},
};

Expand Down Expand Up @@ -314,7 +314,7 @@ unsafe fn to_data_type(schema: &Ffi_ArrowSchema) -> Result<DataType> {
DataType::Decimal(precision, scale)
} else if !parts.is_empty() && ((parts[0] == "+us") || (parts[0] == "+ud")) {
// union
let is_sparse = parts[0] == "+us";
let mode = UnionMode::sparse(parts[0] == "+us");
let type_ids = parts[1]
.split(',')
.map(|x| {
Expand All @@ -326,7 +326,7 @@ unsafe fn to_data_type(schema: &Ffi_ArrowSchema) -> Result<DataType> {
let fields = (0..schema.n_children as usize)
.map(|x| to_field(schema.child(x)))
.collect::<Result<Vec<_>>>()?;
DataType::Union(fields, Some(type_ids), is_sparse)
DataType::Union(fields, Some(type_ids), mode)
} else {
return Err(ArrowError::Ffi(format!(
"The datatype \"{}\" is still not supported in Rust implementation",
Expand Down Expand Up @@ -397,8 +397,8 @@ fn to_format(data_type: &DataType) -> String {
DataType::Struct(_) => "+s".to_string(),
DataType::FixedSizeBinary(size) => format!("w{}", size),
DataType::FixedSizeList(_, size) => format!("+w:{}", size),
DataType::Union(f, ids, is_sparse) => {
let sparsness = if *is_sparse { 's' } else { 'd' };
DataType::Union(f, ids, mode) => {
let sparsness = if mode.is_sparse() { 's' } else { 'd' };
let mut r = format!("+u{}:", sparsness);
let ids = if let Some(ids) = ids {
ids.iter()
Expand Down
2 changes: 1 addition & 1 deletion src/io/avro/read/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ fn schema_to_field(
.iter()
.map(|s| schema_to_field(s, None, has_nullable, None))
.collect::<Result<Vec<Field>>>()?;
DataType::Union(fields, None, false)
DataType::Union(fields, None, UnionMode::Dense)
}
}
AvroSchema::Record { name, fields, .. } => {
Expand Down
10 changes: 5 additions & 5 deletions src/io/ipc/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ mod ipc {
}

use crate::datatypes::{
get_extension, DataType, Extension, Field, IntervalUnit, Metadata, Schema, TimeUnit,
get_extension, DataType, Extension, Field, IntervalUnit, Metadata, Schema, TimeUnit, UnionMode,
};
use crate::io::ipc::endianess::is_native_little_endian;

Expand Down Expand Up @@ -292,7 +292,7 @@ fn get_data_type(field: ipc::Field, extension: Extension, may_be_dictionary: boo
ipc::Type::Union => {
let type_ = field.type_as_union().unwrap();

let is_sparse = type_.mode() == ipc::UnionMode::Sparse;
let mode = UnionMode::sparse(type_.mode() == ipc::UnionMode::Sparse);

let ids = type_.typeIds().map(|x| x.iter().collect());

Expand All @@ -303,7 +303,7 @@ fn get_data_type(field: ipc::Field, extension: Extension, may_be_dictionary: boo
} else {
vec![]
};
DataType::Union(fields, ids, is_sparse)
DataType::Union(fields, ids, mode)
}
ipc::Type::Map => {
let map = field.type_as_map().unwrap();
Expand Down Expand Up @@ -704,13 +704,13 @@ pub(crate) fn get_fb_field_type<'a>(
children: Some(fbb.create_vector(&empty_fields[..])),
}
}
Union(fields, ids, is_sparse) => {
Union(fields, ids, mode) => {
let children: Vec<_> = fields.iter().map(|field| build_field(fbb, field)).collect();

let ids = ids.as_ref().map(|ids| fbb.create_vector(ids));

let mut builder = ipc::UnionBuilder::new(fbb);
builder.add_mode(if *is_sparse {
builder.add_mode(if mode.is_sparse() {
ipc::UnionMode::Sparse
} else {
ipc::UnionMode::Dense
Expand Down
11 changes: 5 additions & 6 deletions src/io/ipc/read/array/union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use arrow_format::ipc;

use crate::array::UnionArray;
use crate::datatypes::DataType;
use crate::datatypes::UnionMode::Dense;
use crate::error::Result;

use super::super::deserialize::{read, skip, Node};
Expand Down Expand Up @@ -36,8 +37,8 @@ pub fn read_union<R: Read + Seek>(
compression,
)?;

let offsets = if let DataType::Union(_, _, is_sparse) = data_type {
if !is_sparse {
let offsets = if let DataType::Union(_, _, mode) = data_type {
if !mode.is_sparse() {
Some(read_buffer(
buffers,
field_node.length() as usize,
Expand Down Expand Up @@ -82,10 +83,8 @@ pub fn skip_union(
let _ = field_nodes.pop_front().unwrap();

let _ = buffers.pop_front().unwrap();
if let DataType::Union(_, _, is_sparse) = data_type {
if !*is_sparse {
let _ = buffers.pop_front().unwrap();
}
if let DataType::Union(_, _, Dense) = data_type {
let _ = buffers.pop_front().unwrap();
} else {
panic!()
};
Expand Down
11 changes: 7 additions & 4 deletions src/io/json_integration/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@ use std::{
use serde_derive::Deserialize;
use serde_json::{json, Value};

use crate::error::{ArrowError, Result};
use crate::{
datatypes::UnionMode,
error::{ArrowError, Result},
};

use crate::datatypes::{get_extension, DataType, Field, IntervalUnit, Schema, TimeUnit};

Expand Down Expand Up @@ -395,8 +398,8 @@ fn to_data_type(item: &Value, mut children: Vec<Field>) -> Result<DataType> {
}
"struct" => DataType::Struct(children),
"union" => {
let is_sparse = if let Some(Value::String(mode)) = item.get("mode") {
mode == "SPARSE"
let mode = if let Some(Value::String(mode)) = item.get("mode") {
UnionMode::sparse(mode == "SPARSE")
} else {
return Err(ArrowError::Schema("union requires mode".to_string()));
};
Expand All @@ -405,7 +408,7 @@ fn to_data_type(item: &Value, mut children: Vec<Field>) -> Result<DataType> {
} else {
return Err(ArrowError::Schema("union requires ids".to_string()));
};
DataType::Union(children, ids, is_sparse)
DataType::Union(children, ids, mode)
}
"map" => {
let sorted_keys = if let Some(Value::Bool(sorted_keys)) = item.get("keysSorted") {
Expand Down
26 changes: 21 additions & 5 deletions tests/it/array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ mod utf8;

use arrow2::array::{clone, new_empty_array, new_null_array, Array, PrimitiveArray};
use arrow2::bitmap::Bitmap;
use arrow2::datatypes::{DataType, Field};
use arrow2::datatypes::{DataType, Field, UnionMode};

#[test]
fn nulls() {
Expand All @@ -31,8 +31,16 @@ fn nulls() {

// unions' null count is always 0
let datatypes = vec![
DataType::Union(vec![Field::new("a", DataType::Binary, true)], None, false),
DataType::Union(vec![Field::new("a", DataType::Binary, true)], None, true),
DataType::Union(
vec![Field::new("a", DataType::Binary, true)],
None,
UnionMode::Dense,
),
DataType::Union(
vec![Field::new("a", DataType::Binary, true)],
None,
UnionMode::Sparse,
),
];
let a = datatypes
.into_iter()
Expand All @@ -48,8 +56,16 @@ fn empty() {
DataType::Utf8,
DataType::Binary,
DataType::List(Box::new(Field::new("a", DataType::Binary, true))),
DataType::Union(vec![Field::new("a", DataType::Binary, true)], None, true),
DataType::Union(vec![Field::new("a", DataType::Binary, true)], None, false),
DataType::Union(
vec![Field::new("a", DataType::Binary, true)],
None,
UnionMode::Sparse,
),
DataType::Union(
vec![Field::new("a", DataType::Binary, true)],
None,
UnionMode::Dense,
),
];
let a = datatypes.into_iter().all(|x| new_empty_array(x).len() == 0);
assert!(a);
Expand Down
4 changes: 2 additions & 2 deletions tests/it/array/union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ fn display() -> Result<()> {
Field::new("a", DataType::Int32, true),
Field::new("b", DataType::Utf8, true),
];
let data_type = DataType::Union(fields, None, true);
let data_type = DataType::Union(fields, None, UnionMode::Sparse);
let types = Buffer::from(&[0, 0, 1]);
let fields = vec![
Arc::new(Int32Array::from(&[Some(1), None, Some(2)])) as Arc<dyn Array>,
Expand All @@ -28,7 +28,7 @@ fn slice() -> Result<()> {
Field::new("a", DataType::Int32, true),
Field::new("b", DataType::Utf8, true),
];
let data_type = DataType::Union(fields, None, true);
let data_type = DataType::Union(fields, None, UnionMode::Sparse);
let types = Buffer::from(&[0, 0, 1]);
let fields = vec![
Arc::new(Int32Array::from(&[Some(1), None, Some(2)])) as Arc<dyn Array>,
Expand Down
2 changes: 1 addition & 1 deletion tests/it/io/print.rs
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ fn write_union() -> Result<()> {
Field::new("a", DataType::Int32, true),
Field::new("b", DataType::Utf8, true),
];
let data_type = DataType::Union(fields, None, true);
let data_type = DataType::Union(fields, None, UnionMode::Sparse);
let types = Buffer::from(&[0, 0, 1]);
let fields = vec![
Arc::new(Int32Array::from(&[Some(1), None, Some(2)])) as Arc<dyn Array>,
Expand Down