Skip to content

Commit

Permalink
Add default_cast_for
Browse files Browse the repository at this point in the history
  • Loading branch information
notfilippo committed Oct 25, 2024
1 parent 7ed7891 commit 1530c5b
Show file tree
Hide file tree
Showing 2 changed files with 160 additions and 9 deletions.
16 changes: 13 additions & 3 deletions datafusion/common/src/types/logical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@
// specific language governing permissions and limitations
// under the License.

use super::NativeType;
use crate::Result;
use arrow_schema::DataType;
use core::fmt;
use std::{cmp::Ordering, hash::Hash, sync::Arc};

use super::NativeType;

/// Signature that uniquely identifies a type among other types.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum TypeSignature<'a> {
Expand Down Expand Up @@ -75,8 +76,17 @@ pub type LogicalTypeRef = Arc<dyn LogicalType>;
/// }
/// ```
pub trait LogicalType: Sync + Send {
/// Get the native backing type of this logical type.
fn native(&self) -> &NativeType;
/// Get the unique type signature for this logical type. Logical types with identical
/// signatures are considered equal.
fn signature(&self) -> TypeSignature<'_>;

/// Get the default physical type to cast `origin` to in order to obtain a physical type
/// that is logically compatible with this logical type.
fn default_cast_for(&self, origin: &DataType) -> Result<DataType> {
self.native().default_cast_for(origin)
}
}

impl fmt::Debug for dyn LogicalType {
Expand All @@ -90,7 +100,7 @@ impl fmt::Debug for dyn LogicalType {

impl PartialEq for dyn LogicalType {
fn eq(&self, other: &Self) -> bool {
self.native().eq(other.native()) && self.signature().eq(&other.signature())
self.signature().eq(&other.signature())
}
}

Expand Down
153 changes: 147 additions & 6 deletions datafusion/common/src/types/native.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@
// specific language governing permissions and limitations
// under the License.

use std::sync::Arc;

use arrow_schema::{DataType, IntervalUnit, TimeUnit};

use super::{
LogicalFieldRef, LogicalFields, LogicalType, LogicalUnionFields, TypeSignature,
LogicalField, LogicalFieldRef, LogicalFields, LogicalType, LogicalUnionFields,
TypeSignature,
};
use crate::{internal_err, Result};
use arrow_schema::{DataType, Field, FieldRef, IntervalUnit, TimeUnit};
use std::sync::Arc;

/// Representation of a type that DataFusion can handle natively. It is a subset
/// of the physical variants in Arrow's native [`DataType`].
Expand Down Expand Up @@ -188,6 +188,147 @@ impl LogicalType for NativeType {
fn signature(&self) -> TypeSignature<'_> {
TypeSignature::Native(self)
}

fn default_cast_for(&self, origin: &DataType) -> Result<DataType> {
use DataType::*;

fn default_field_cast(to: &LogicalField, from: &Field) -> Result<FieldRef> {
Ok(Arc::new(Field::new(
to.name.clone(),
to.logical_type.default_cast_for(from.data_type())?,
to.nullable,
)))
}

Ok(match (self, origin) {
(Self::Null, _) => Null,
(Self::Boolean, _) => Boolean,
(Self::Int8, _) => Int8,
(Self::Int16, _) => Int16,
(Self::Int32, _) => Int32,
(Self::Int64, _) => Int64,
(Self::UInt8, _) => UInt8,
(Self::UInt16, _) => UInt16,
(Self::UInt32, _) => UInt32,
(Self::UInt64, _) => UInt64,
(Self::Float16, _) => Float16,
(Self::Float32, _) => Float32,
(Self::Float64, _) => Float64,
(Self::Decimal(p, s), _) if p <= &38 => Decimal128(p.clone(), s.clone()),
(Self::Decimal(p, s), _) => Decimal256(p.clone(), s.clone()),
(Self::Timestamp(tu, tz), _) => Timestamp(tu.clone(), tz.clone()),
(Self::Date, _) => Date32,
(Self::Time(tu), _) => match tu {
TimeUnit::Second | TimeUnit::Millisecond => Time32(tu.clone()),
TimeUnit::Microsecond | TimeUnit::Nanosecond => Time64(tu.clone()),
},
(Self::Duration(tu), _) => Duration(tu.clone()),
(Self::Interval(iu), _) => Interval(iu.clone()),
(Self::Binary, LargeUtf8) => LargeBinary,
(Self::Binary, Utf8View) => BinaryView,
(Self::Binary, _) => Binary,
(Self::FixedSizeBinary(size), _) => FixedSizeBinary(size.clone()),
(Self::Utf8, LargeBinary) => LargeUtf8,
(Self::Utf8, BinaryView) => Utf8View,
(Self::Utf8, _) => Utf8,
(Self::List(to_field), List(from_field) | FixedSizeList(from_field, _)) => {
List(default_field_cast(to_field, from_field)?)
}
(Self::List(to_field), LargeList(from_field)) => {
LargeList(default_field_cast(to_field, from_field)?)
}
(Self::List(to_field), ListView(from_field)) => {
ListView(default_field_cast(to_field, from_field)?)
}
(Self::List(to_field), LargeListView(from_field)) => {
LargeListView(default_field_cast(to_field, from_field)?)
}
// List array where each element is a len 1 list of the origin type
(Self::List(field), _) => List(Arc::new(Field::new(
field.name.clone(),
field.logical_type.default_cast_for(origin)?,
field.nullable,
))),
(
Self::FixedSizeList(to_field, to_size),
FixedSizeList(from_field, from_size),
) if from_size == to_size => {
FixedSizeList(default_field_cast(to_field, from_field)?, to_size.clone())
}
(
Self::FixedSizeList(to_field, size),
List(from_field)
| LargeList(from_field)
| ListView(from_field)
| LargeListView(from_field),
) => FixedSizeList(default_field_cast(to_field, from_field)?, size.clone()),
// FixedSizeList array where each element is a len 1 list of the origin type
(Self::FixedSizeList(field, size), _) => FixedSizeList(
Arc::new(Field::new(
field.name.clone(),
field.logical_type.default_cast_for(origin)?,
field.nullable,
)),
size.clone(),
),
// From https://github.com/apache/arrow-rs/blob/56525efbd5f37b89d1b56aa51709cab9f81bc89e/arrow-cast/src/cast/mod.rs#L189-L196
(Self::Struct(to_fields), Struct(from_fields))
if from_fields.len() == to_fields.len() =>
{
Struct(
from_fields
.iter()
.zip(to_fields.iter())
.map(|(from, to)| default_field_cast(to, from))
.collect()?,
)
}
(Self::Struct(to_fields), Null) => Struct(
to_fields
.iter()
.map(|field| {
Ok(Arc::new(Field::new(
field.name.clone(),
field.logical_type.default_cast_for(&Null)?,
field.nullable,
)))
})
.collect()?,
),
(Self::Map(to_field), Map(from_field, sorted)) => {
Map(default_field_cast(to_field, from_field)?, sorted.clone())
}
(Self::Map(field), Null) => Map(
Arc::new(Field::new(
field.name.clone(),
field.logical_type.default_cast_for(&Null)?,
field.nullable,
)),
false,
),
(Self::Union(to_fields), Union(from_fields, mode))
if from_fields.len() == to_fields.len() =>
{
Union(
from_fields
.iter()
.zip(to_fields.iter())
.map(|((_, from), (i, to))| {
(i.clone(), default_field_cast(to, from))
})
.collect()?,
mode.clone(),
)
}
_ => {
return internal_err!(
"Unavailable default cast for native type {:?} from physical type {:?}",
self,
origin
)
}
})
}
}

// The following From<DataType>, From<Field>, ... implementations are temporary
Expand Down Expand Up @@ -230,9 +371,9 @@ impl From<DataType> for NativeType {
DataType::Union(union_fields, _) => {
Union(LogicalUnionFields::from(&union_fields))
}
DataType::Dictionary(_, data_type) => data_type.as_ref().clone().into(),
DataType::Decimal128(p, s) | DataType::Decimal256(p, s) => Decimal(p, s),
DataType::Map(field, _) => Map(Arc::new(field.as_ref().into())),
DataType::Dictionary(_, data_type) => data_type.as_ref().clone().into(),
DataType::RunEndEncoded(_, field) => field.data_type().clone().into(),
}
}
Expand Down

0 comments on commit 1530c5b

Please sign in to comment.