Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Simplified Primitive and Boolean scalar #648

Merged
merged 1 commit into from
Dec 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/security.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ jobs:
toolchain: nightly-2021-10-24
override: true
- uses: Swatinem/rust-cache@v1
with:
key: key1
- name: Install Miri
run: |
rustup component add miri
Expand Down
4 changes: 4 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ jobs:
toolchain: nightly-2021-10-24
override: true
- uses: Swatinem/rust-cache@v1
with:
key: key1
- name: Install Miri
run: |
rustup component add miri
Expand All @@ -96,6 +98,8 @@ jobs:
toolchain: nightly-2021-10-24
override: true
- uses: Swatinem/rust-cache@v1
with:
key: key1
- name: Install Miri
run: |
rustup component add miri
Expand Down
25 changes: 13 additions & 12 deletions src/compute/comparison/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -240,58 +240,59 @@ macro_rules! compare_scalar {
Boolean => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref::<BooleanScalar>().unwrap();
boolean::$op(lhs, rhs.value())
// validity checked above
boolean::$op(lhs, rhs.value().unwrap())
}
Int8 => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref::<PrimitiveScalar<i8>>().unwrap();
primitive::$op::<i8>(lhs, rhs.value())
primitive::$op::<i8>(lhs, rhs.value().unwrap())
}
Int16 => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref::<PrimitiveScalar<i16>>().unwrap();
primitive::$op::<i16>(lhs, rhs.value())
primitive::$op::<i16>(lhs, rhs.value().unwrap())
}
Int32 | Date32 | Time32(_) | Interval(IntervalUnit::YearMonth) => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref::<PrimitiveScalar<i32>>().unwrap();
primitive::$op::<i32>(lhs, rhs.value())
primitive::$op::<i32>(lhs, rhs.value().unwrap())
}
Int64 | Timestamp(_, _) | Date64 | Time64(_) | Duration(_) => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref::<PrimitiveScalar<i64>>().unwrap();
primitive::$op::<i64>(lhs, rhs.value())
primitive::$op::<i64>(lhs, rhs.value().unwrap())
}
UInt8 => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref::<PrimitiveScalar<u8>>().unwrap();
primitive::$op::<u8>(lhs, rhs.value())
primitive::$op::<u8>(lhs, rhs.value().unwrap())
}
UInt16 => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref::<PrimitiveScalar<u16>>().unwrap();
primitive::$op::<u16>(lhs, rhs.value())
primitive::$op::<u16>(lhs, rhs.value().unwrap())
}
UInt32 => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref::<PrimitiveScalar<u32>>().unwrap();
primitive::$op::<u32>(lhs, rhs.value())
primitive::$op::<u32>(lhs, rhs.value().unwrap())
}
UInt64 => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref::<PrimitiveScalar<u64>>().unwrap();
primitive::$op::<u64>(lhs, rhs.value())
primitive::$op::<u64>(lhs, rhs.value().unwrap())
}
Float16 => unreachable!(),
Float32 => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref::<PrimitiveScalar<f32>>().unwrap();
primitive::$op::<f32>(lhs, rhs.value())
primitive::$op::<f32>(lhs, rhs.value().unwrap())
}
Float64 => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref::<PrimitiveScalar<f64>>().unwrap();
primitive::$op::<f64>(lhs, rhs.value())
primitive::$op::<f64>(lhs, rhs.value().unwrap())
}
Utf8 => {
let lhs = lhs.as_any().downcast_ref().unwrap();
Expand All @@ -309,7 +310,7 @@ macro_rules! compare_scalar {
.as_any()
.downcast_ref::<PrimitiveScalar<i128>>()
.unwrap();
primitive::$op::<i128>(lhs, rhs.value())
primitive::$op::<i128>(lhs, rhs.value().unwrap())
}
Binary => {
let lhs = lhs.as_any().downcast_ref().unwrap();
Expand Down
25 changes: 7 additions & 18 deletions src/scalar/boolean.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,32 +3,21 @@ use crate::datatypes::DataType;
use super::Scalar;

/// The [`Scalar`] implementation of a boolean.
#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq)]
pub struct BooleanScalar {
value: bool,
is_valid: bool,
}

impl PartialEq for BooleanScalar {
fn eq(&self, other: &Self) -> bool {
self.is_valid == other.is_valid && ((!self.is_valid) | (self.value == other.value))
}
value: Option<bool>,
}

impl BooleanScalar {
/// Returns a new [`BooleanScalar`]
#[inline]
pub fn new(v: Option<bool>) -> Self {
let is_valid = v.is_some();
Self {
value: v.unwrap_or_default(),
is_valid,
}
pub fn new(value: Option<bool>) -> Self {
Self { value }
}

/// The value irrespectively of the validity
/// The value
#[inline]
pub fn value(&self) -> bool {
pub fn value(&self) -> Option<bool> {
self.value
}
}
Expand All @@ -41,7 +30,7 @@ impl Scalar for BooleanScalar {

#[inline]
fn is_valid(&self) -> bool {
self.is_valid
self.value.is_some()
}

#[inline]
Expand Down
36 changes: 8 additions & 28 deletions src/scalar/primitive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,16 @@ use super::Scalar;

/// The implementation of [`Scalar`] for primitive, semantically equivalent to [`Option<T>`]
/// with [`DataType`].
#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq)]
pub struct PrimitiveScalar<T: NativeType> {
// Not Option<T> because this offers a stabler pointer offset on the struct
value: T,
is_valid: bool,
value: Option<T>,
data_type: DataType,
}

impl<T: NativeType> PartialEq for PrimitiveScalar<T> {
fn eq(&self, other: &Self) -> bool {
self.data_type == other.data_type
&& self.is_valid == other.is_valid
&& ((!self.is_valid) | (self.value == other.value))
}
}

impl<T: NativeType> PrimitiveScalar<T> {
/// Returns a new [`PrimitiveScalar`].
#[inline]
pub fn new(data_type: DataType, v: Option<T>) -> Self {
pub fn new(data_type: DataType, value: Option<T>) -> Self {
if !T::is_valid(&data_type) {
Err(ArrowError::InvalidArgumentError(format!(
"Type {} does not support logical type {}",
Expand All @@ -36,30 +26,20 @@ impl<T: NativeType> PrimitiveScalar<T> {
)))
.unwrap()
}
let is_valid = v.is_some();
Self {
value: v.unwrap_or_default(),
is_valid,
data_type,
}
Self { value, data_type }
}

/// Returns the value irrespectively of its validity.
/// Returns the optional value.
#[inline]
pub fn value(&self) -> T {
pub fn value(&self) -> Option<T> {
self.value
}

/// Returns a new `PrimitiveScalar` with the same value but different [`DataType`]
/// # Panic
/// This function panics if the `data_type` is not valid for self's physical type `T`.
pub fn to(self, data_type: DataType) -> Self {
let v = if self.is_valid {
Some(self.value)
} else {
None
};
Self::new(data_type, v)
Self::new(data_type, self.value)
}
}

Expand All @@ -78,7 +58,7 @@ impl<T: NativeType> Scalar for PrimitiveScalar<T> {

#[inline]
fn is_valid(&self) -> bool {
self.is_valid
self.value.is_some()
}

#[inline]
Expand Down
2 changes: 1 addition & 1 deletion tests/it/scalar/boolean.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ fn equal() {
fn basics() {
let a = BooleanScalar::new(Some(true));

assert!(a.value());
assert_eq!(a.value(), Some(true));
assert_eq!(a.data_type(), &DataType::Boolean);
assert!(a.is_valid());

Expand Down
3 changes: 1 addition & 2 deletions tests/it/scalar/primitive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,8 @@ fn equal() {
fn basics() {
let a = PrimitiveScalar::from(Some(2i32));

assert_eq!(a.value(), 2i32);
assert_eq!(a.value(), Some(2i32));
assert_eq!(a.data_type(), &DataType::Int32);
assert!(a.is_valid());

let a = a.to(DataType::Date32);
assert_eq!(a.data_type(), &DataType::Date32);
Expand Down