diff --git a/Cargo.lock b/Cargo.lock index 6f23f907ed2..bf7647a26f7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8651,8 +8651,9 @@ dependencies = [ name = "vortex-compute" version = "0.1.0" dependencies = [ + "num-traits", "vortex-buffer", - "vortex-error", + "vortex-dtype", "vortex-mask", "vortex-vector", ] diff --git a/encodings/alp/src/alp/mod.rs b/encodings/alp/src/alp/mod.rs index 6a42da28ba3..6aecf49bf20 100644 --- a/encodings/alp/src/alp/mod.rs +++ b/encodings/alp/src/alp/mod.rs @@ -193,7 +193,7 @@ pub trait ALPFloat: private::Sealed + Float + Display + NativePType { } fn decode_buffer(encoded: BufferMut, exponents: Exponents) -> BufferMut { - encoded.map_each(move |encoded| Self::decode_single(encoded, exponents)) + encoded.map_each_in_place(move |encoded| Self::decode_single(encoded, exponents)) } #[inline(always)] diff --git a/encodings/alp/src/alp_rd/mod.rs b/encodings/alp/src/alp_rd/mod.rs index 180922f54b1..98c0db06f89 100644 --- a/encodings/alp/src/alp_rd/mod.rs +++ b/encodings/alp/src/alp_rd/mod.rs @@ -311,7 +311,7 @@ pub fn alp_rd_decode( // Shift the left-parts and add in the right-parts. let mut index = 0; right_parts - .map_each(|right| { + .map_each_in_place(|right| { let left = values[index]; index += 1; let left = ::from_u16(left); diff --git a/encodings/datetime-parts/src/canonical.rs b/encodings/datetime-parts/src/canonical.rs index cfc4a830d43..401ed6cb030 100644 --- a/encodings/datetime-parts/src/canonical.rs +++ b/encodings/datetime-parts/src/canonical.rs @@ -53,7 +53,7 @@ pub fn decode_to_temporal(array: &DateTimePartsArray) -> TemporalArray { // are constant. let mut values: BufferMut = days_buf .into_buffer_mut::() - .map_each(|d| d * 86_400 * divisor); + .map_each_in_place(|d| d * 86_400 * divisor); if let Some(seconds) = array.seconds().as_constant() { let seconds = seconds diff --git a/encodings/fastlanes/src/for/compress.rs b/encodings/fastlanes/src/for/compress.rs index 3235dcd19d7..fe568a6748a 100644 --- a/encodings/fastlanes/src/for/compress.rs +++ b/encodings/fastlanes/src/for/compress.rs @@ -151,7 +151,9 @@ fn decompress_primitive( values: BufferMut, min: T, ) -> Buffer { - values.map_each(move |v| v.wrapping_add(&min)).freeze() + values + .map_each_in_place(move |v| v.wrapping_add(&min)) + .freeze() } #[cfg(test)] diff --git a/encodings/zigzag/src/compress.rs b/encodings/zigzag/src/compress.rs index 7734e0cd67f..e68441f401e 100644 --- a/encodings/zigzag/src/compress.rs +++ b/encodings/zigzag/src/compress.rs @@ -33,7 +33,10 @@ fn zigzag_encode_primitive( where ::UInt: NativePType, { - PrimitiveArray::new(values.map_each(|v| T::encode(v)).freeze(), validity) + PrimitiveArray::new( + values.map_each_in_place(|v| T::encode(v)).freeze(), + validity, + ) } pub fn zigzag_decode(parray: PrimitiveArray) -> PrimitiveArray { @@ -57,7 +60,10 @@ fn zigzag_decode_primitive( where ::UInt: NativePType, { - PrimitiveArray::new(values.map_each(|v| T::decode(v)).freeze(), validity) + PrimitiveArray::new( + values.map_each_in_place(|v| T::decode(v)).freeze(), + validity, + ) } #[cfg(test)] diff --git a/vortex-array/src/array/operator.rs b/vortex-array/src/array/operator.rs index 594e6baa521..6d6bc133c5f 100644 --- a/vortex-array/src/array/operator.rs +++ b/vortex-array/src/array/operator.rs @@ -3,7 +3,8 @@ use std::sync::Arc; -use vortex_error::VortexResult; +use vortex_dtype::DType; +use vortex_error::{VortexResult, vortex_bail}; use vortex_vector::Vector; use crate::execution::{BatchKernelRef, BindCtx}; @@ -33,6 +34,21 @@ pub trait ArrayOperator: 'static + Send + Sync { impl ArrayOperator for Arc { fn execute_with_selection(&self, selection: Option<&ArrayRef>) -> VortexResult { + if let Some(selection) = selection.as_ref() { + if !matches!(selection.dtype(), DType::Bool(_)) { + vortex_bail!( + "Selection array must be of boolean type, got {}", + selection.dtype() + ); + } + if selection.len() != self.len() { + vortex_bail!( + "Selection array length {} does not match array length {}", + selection.len(), + self.len() + ); + } + } self.as_ref().execute_with_selection(selection) } diff --git a/vortex-array/src/arrays/primitive/array/mod.rs b/vortex-array/src/arrays/primitive/array/mod.rs index 9fdda9b0547..ea7d43c1905 100644 --- a/vortex-array/src/arrays/primitive/array/mod.rs +++ b/vortex-array/src/arrays/primitive/array/mod.rs @@ -192,7 +192,7 @@ impl PrimitiveArray { { let validity = self.validity().clone(); let buffer = match self.try_into_buffer_mut() { - Ok(buffer_mut) => buffer_mut.map_each(f), + Ok(buffer_mut) => buffer_mut.map_each_in_place(f), Err(parray) => BufferMut::::from_iter(parray.buffer::().iter().copied().map(f)), }; PrimitiveArray::new(buffer.freeze(), validity) diff --git a/vortex-array/src/compute/arrays/arithmetic.rs b/vortex-array/src/compute/arrays/arithmetic.rs new file mode 100644 index 00000000000..8ae102c97d4 --- /dev/null +++ b/vortex-array/src/compute/arrays/arithmetic.rs @@ -0,0 +1,430 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use std::hash::{Hash, Hasher}; +use std::sync::LazyLock; + +use enum_map::{Enum, EnumMap, enum_map}; +use vortex_buffer::ByteBuffer; +use vortex_compute::arithmetic::{ + Add, Arithmetic, CheckedArithmetic, CheckedOperator, Div, Mul, Operator, Sub, +}; +use vortex_dtype::{DType, NativePType, PTypeDowncastExt, match_each_native_ptype}; +use vortex_error::{VortexExpect, VortexResult, vortex_err}; +use vortex_scalar::{PValue, Scalar}; +use vortex_vector::PVector; + +use crate::arrays::ConstantArray; +use crate::execution::{BatchKernelRef, BindCtx, kernel}; +use crate::serde::ArrayChildren; +use crate::stats::{ArrayStats, StatsSetRef}; +use crate::vtable::{ + ArrayVTable, NotSupported, OperatorVTable, SerdeVTable, VTable, VisitorVTable, +}; +use crate::{ + Array, ArrayBufferVisitor, ArrayChildVisitor, ArrayEq, ArrayHash, ArrayRef, + DeserializeMetadata, EmptyMetadata, EncodingId, EncodingRef, IntoArray, Precision, vtable, +}; + +/// The set of operators supported by an arithmetic array. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Enum)] +pub enum ArithmeticOperator { + /// Addition - errors on overflow for integers. + Add, + /// Subtraction - errors on overflow for integers. + Sub, + /// Multiplication - errors on overflow for integers. + Mul, + /// Division - errors on division by zero for integers. + Div, +} + +vtable!(Arithmetic); + +#[derive(Debug, Clone)] +pub struct ArithmeticArray { + encoding: EncodingRef, + lhs: ArrayRef, + rhs: ArrayRef, + stats: ArrayStats, +} + +impl ArithmeticArray { + /// Create a new arithmetic array. + pub fn new(lhs: ArrayRef, rhs: ArrayRef, operator: ArithmeticOperator) -> Self { + assert_eq!( + lhs.len(), + rhs.len(), + "Arithmetic arrays require lhs and rhs to have the same length" + ); + + // TODO(ngates): should we automatically cast non-null to nullable if required? + assert!(matches!(lhs.dtype(), DType::Primitive(..))); + assert_eq!(lhs.dtype(), rhs.dtype()); + + Self { + encoding: ENCODINGS[operator].clone(), + lhs, + rhs, + stats: ArrayStats::default(), + } + } + + /// Returns the operator of this logical array. + pub fn operator(&self) -> ArithmeticOperator { + self.encoding.as_::().operator + } +} + +#[derive(Debug, Clone)] +pub struct ArithmeticEncoding { + // We include the operator in the encoding so each operator is a different encoding ID. + // This makes it easier for plugins to construct expressions and perform pushdown + // optimizations. + operator: ArithmeticOperator, +} + +#[allow(clippy::mem_forget)] +static ENCODINGS: LazyLock> = LazyLock::new(|| { + enum_map! { + operator => ArithmeticEncoding { operator }.to_encoding(), + } +}); + +impl VTable for ArithmeticVTable { + type Array = ArithmeticArray; + type Encoding = ArithmeticEncoding; + type ArrayVTable = Self; + type CanonicalVTable = NotSupported; + type OperationsVTable = NotSupported; + type ValidityVTable = NotSupported; + type VisitorVTable = Self; + type ComputeVTable = NotSupported; + type EncodeVTable = NotSupported; + type SerdeVTable = Self; + type OperatorVTable = Self; + + fn id(encoding: &Self::Encoding) -> EncodingId { + match encoding.operator { + ArithmeticOperator::Add => EncodingId::from("vortex.add"), + ArithmeticOperator::Sub => EncodingId::from("vortex.sub"), + ArithmeticOperator::Mul => EncodingId::from("vortex.mul"), + ArithmeticOperator::Div => EncodingId::from("vortex.div"), + } + } + + fn encoding(array: &Self::Array) -> EncodingRef { + array.encoding.clone() + } +} + +impl ArrayVTable for ArithmeticVTable { + fn len(array: &ArithmeticArray) -> usize { + array.lhs.len() + } + + fn dtype(array: &ArithmeticArray) -> &DType { + array.lhs.dtype() + } + + fn stats(array: &ArithmeticArray) -> StatsSetRef<'_> { + array.stats.to_ref(array.as_ref()) + } + + fn array_hash(array: &ArithmeticArray, state: &mut H, precision: Precision) { + array.lhs.array_hash(state, precision); + array.rhs.array_hash(state, precision); + } + + fn array_eq(array: &ArithmeticArray, other: &ArithmeticArray, precision: Precision) -> bool { + array.lhs.array_eq(&other.lhs, precision) && array.rhs.array_eq(&other.rhs, precision) + } +} + +impl VisitorVTable for ArithmeticVTable { + fn visit_buffers(_array: &ArithmeticArray, _visitor: &mut dyn ArrayBufferVisitor) { + // No buffers + } + + fn visit_children(array: &ArithmeticArray, visitor: &mut dyn ArrayChildVisitor) { + visitor.visit_child("lhs", array.lhs.as_ref()); + visitor.visit_child("rhs", array.rhs.as_ref()); + } +} + +impl SerdeVTable for ArithmeticVTable { + type Metadata = EmptyMetadata; + + fn metadata(_array: &ArithmeticArray) -> VortexResult> { + Ok(Some(EmptyMetadata)) + } + + fn build( + encoding: &ArithmeticEncoding, + dtype: &DType, + len: usize, + _metadata: &::Output, + buffers: &[ByteBuffer], + children: &dyn ArrayChildren, + ) -> VortexResult { + assert!(buffers.is_empty()); + + Ok(ArithmeticArray::new( + children.get(0, dtype, len)?, + children.get(1, dtype, len)?, + encoding.operator, + )) + } +} + +impl OperatorVTable for ArithmeticVTable { + fn reduce_children(array: &ArithmeticArray) -> VortexResult> { + match (array.lhs.as_constant(), array.rhs.as_constant()) { + // If both sides are constant, we compute the value now. + (Some(lhs), Some(rhs)) => { + let op: vortex_scalar::NumericOperator = match array.operator() { + ArithmeticOperator::Add => vortex_scalar::NumericOperator::Add, + ArithmeticOperator::Sub => vortex_scalar::NumericOperator::Sub, + ArithmeticOperator::Mul => vortex_scalar::NumericOperator::Mul, + ArithmeticOperator::Div => vortex_scalar::NumericOperator::Div, + }; + let result = lhs + .as_primitive() + .checked_binary_numeric(&rhs.as_primitive(), op) + .ok_or_else(|| { + vortex_err!("Constant arithmetic operation resulted in overflow") + })?; + return Ok(Some( + ConstantArray::new(Scalar::from(result), array.len()).into_array(), + )); + } + // If either side is constant null, the result is constant null. + (Some(lhs), _) if lhs.is_null() => { + return Ok(Some( + ConstantArray::new(Scalar::null(array.dtype().clone()), array.len()) + .into_array(), + )); + } + (_, Some(rhs)) if rhs.is_null() => { + return Ok(Some( + ConstantArray::new(Scalar::null(array.dtype().clone()), array.len()) + .into_array(), + )); + } + _ => {} + } + + Ok(None) + } + + fn bind( + array: &ArithmeticArray, + selection: Option<&ArrayRef>, + ctx: &mut dyn BindCtx, + ) -> VortexResult { + // Optimize for constant RHS + if let Some(rhs_scalar) = array.rhs.as_constant() { + if rhs_scalar.is_null() { + // If the RHS is null, the result is always null. + return ConstantArray::new(Scalar::null(array.dtype().clone()), array.len()) + .into_array() + .bind(selection, ctx); + } + + let lhs = ctx.bind(&array.lhs, selection)?; + return match_each_native_ptype!( + array.dtype().as_ptype(), + integral: |T| { + let rhs: T = rhs_scalar + .as_primitive() + .typed_value::() + .vortex_expect("Already checked for null above"); + Ok(match array.operator() { + ArithmeticOperator::Add => checked_arithmetic_scalar_kernel::(lhs, rhs), + ArithmeticOperator::Sub => checked_arithmetic_scalar_kernel::(lhs, rhs), + ArithmeticOperator::Mul => checked_arithmetic_scalar_kernel::(lhs, rhs), + ArithmeticOperator::Div => checked_arithmetic_scalar_kernel::(lhs, rhs), + }) + }, + floating: |T| { + let rhs: T = rhs_scalar + .as_primitive() + .typed_value::() + .vortex_expect("Already checked for null above"); + Ok(match array.operator() { + ArithmeticOperator::Add => arithmetic_scalar_kernel::(lhs, rhs), + ArithmeticOperator::Sub => arithmetic_scalar_kernel::(lhs, rhs), + ArithmeticOperator::Mul => arithmetic_scalar_kernel::(lhs, rhs), + ArithmeticOperator::Div => arithmetic_scalar_kernel::(lhs, rhs), + }) + } + ); + } + + let lhs = ctx.bind(&array.lhs, selection)?; + let rhs = ctx.bind(&array.rhs, selection)?; + + match_each_native_ptype!( + array.dtype().as_ptype(), + integral: |T| { + Ok(match array.operator() { + ArithmeticOperator::Add => checked_arithmetic_kernel::(lhs, rhs), + ArithmeticOperator::Sub => checked_arithmetic_kernel::(lhs, rhs), + ArithmeticOperator::Mul => checked_arithmetic_kernel::(lhs, rhs), + ArithmeticOperator::Div => checked_arithmetic_kernel::(lhs, rhs), + }) + }, + floating: |T| { + Ok(match array.operator() { + ArithmeticOperator::Add => arithmetic_kernel::(lhs, rhs), + ArithmeticOperator::Sub => arithmetic_kernel::(lhs, rhs), + ArithmeticOperator::Mul => arithmetic_kernel::(lhs, rhs), + ArithmeticOperator::Div => arithmetic_kernel::(lhs, rhs), + }) + } + ) + } +} + +fn arithmetic_kernel(lhs: BatchKernelRef, rhs: BatchKernelRef) -> BatchKernelRef +where + T: NativePType, + Op: Operator, +{ + kernel(move || { + let lhs = lhs.execute()?.into_primitive().downcast::(); + let rhs = rhs.execute()?.into_primitive().downcast::(); + let result = Arithmetic::::eval(lhs, &rhs); + Ok(result.into()) + }) +} + +fn arithmetic_scalar_kernel(lhs: BatchKernelRef, rhs: T) -> BatchKernelRef +where + T: NativePType + TryFrom, + Op: Operator, +{ + kernel(move || { + let lhs = lhs.execute()?.into_primitive().downcast::(); + let result = Arithmetic::::eval(lhs, &rhs); + Ok(result.into()) + }) +} + +fn checked_arithmetic_kernel(lhs: BatchKernelRef, rhs: BatchKernelRef) -> BatchKernelRef +where + T: NativePType, + Op: CheckedOperator, + PVector: for<'a> CheckedArithmetic, Output = PVector>, +{ + kernel(move || { + let lhs = lhs.execute()?.into_primitive().downcast::(); + let rhs = rhs.execute()?.into_primitive().downcast::(); + let result = CheckedArithmetic::::checked_eval(lhs, &rhs) + .ok_or_else(|| vortex_err!("Arithmetic operation resulted in overflow"))?; + Ok(result.into()) + }) +} + +fn checked_arithmetic_scalar_kernel(lhs: BatchKernelRef, rhs: T) -> BatchKernelRef +where + T: NativePType + TryFrom, + Op: CheckedOperator, + PVector: for<'a> CheckedArithmetic>, +{ + kernel(move || { + let lhs = lhs.execute()?.into_primitive().downcast::(); + let result = CheckedArithmetic::::checked_eval(lhs, &rhs) + .ok_or_else(|| vortex_err!("Arithmetic operation resulted in overflow"))?; + Ok(result.into()) + }) +} + +#[cfg(test)] +mod tests { + use vortex_buffer::{bitbuffer, buffer}; + use vortex_dtype::PTypeDowncastExt; + + use crate::arrays::PrimitiveArray; + use crate::compute::arrays::arithmetic::{ArithmeticArray, ArithmeticOperator}; + use crate::{ArrayOperator, ArrayRef, IntoArray}; + + fn add(lhs: ArrayRef, rhs: ArrayRef) -> ArrayRef { + ArithmeticArray::new(lhs, rhs, ArithmeticOperator::Add).into_array() + } + + fn sub(lhs: ArrayRef, rhs: ArrayRef) -> ArrayRef { + ArithmeticArray::new(lhs, rhs, ArithmeticOperator::Sub).into_array() + } + + fn mul(lhs: ArrayRef, rhs: ArrayRef) -> ArrayRef { + ArithmeticArray::new(lhs, rhs, ArithmeticOperator::Mul).into_array() + } + + fn div(lhs: ArrayRef, rhs: ArrayRef) -> ArrayRef { + ArithmeticArray::new(lhs, rhs, ArithmeticOperator::Div).into_array() + } + + #[test] + fn test_add() { + let lhs = PrimitiveArray::from_iter([1u32, 2, 3]).into_array(); + let rhs = PrimitiveArray::from_iter([10u32, 20, 30]).into_array(); + let result = add(lhs, rhs) + .execute() + .unwrap() + .into_primitive() + .downcast::(); + assert_eq!(result.elements(), &buffer![11u32, 22, 33]); + } + + #[test] + fn test_sub() { + let lhs = PrimitiveArray::from_iter([10u32, 20, 30]).into_array(); + let rhs = PrimitiveArray::from_iter([1u32, 2, 3]).into_array(); + let result = sub(lhs, rhs) + .execute() + .unwrap() + .into_primitive() + .downcast::(); + assert_eq!(result.elements(), &buffer![9u32, 18, 27]); + } + + #[test] + fn test_mul() { + let lhs = PrimitiveArray::from_iter([2u32, 3, 4]).into_array(); + let rhs = PrimitiveArray::from_iter([10u32, 20, 30]).into_array(); + let result = mul(lhs, rhs) + .execute() + .unwrap() + .into_primitive() + .downcast::(); + assert_eq!(result.elements(), &buffer![20u32, 60, 120]); + } + + #[test] + fn test_div() { + let lhs = PrimitiveArray::from_iter([100u32, 200, 300]).into_array(); + let rhs = PrimitiveArray::from_iter([10u32, 20, 30]).into_array(); + let result = div(lhs, rhs) + .execute() + .unwrap() + .into_primitive() + .downcast::(); + assert_eq!(result.elements(), &buffer![10u32, 10, 10]); + } + + #[test] + fn test_add_with_selection() { + let lhs = PrimitiveArray::from_iter([1u32, 2, 3]).into_array(); + let rhs = PrimitiveArray::from_iter([10u32, 20, 30]).into_array(); + + let selection = bitbuffer![1 0 1].into_array(); + + let result = add(lhs, rhs) + .execute_with_selection(Some(&selection)) + .unwrap() + .into_primitive() + .downcast::(); + assert_eq!(result.elements(), &buffer![11u32, 33]); + } +} diff --git a/vortex-array/src/compute/arrays/mod.rs b/vortex-array/src/compute/arrays/mod.rs index 97dc8731c56..cb099647ec1 100644 --- a/vortex-array/src/compute/arrays/mod.rs +++ b/vortex-array/src/compute/arrays/mod.rs @@ -1,4 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -mod logical; +pub mod arithmetic; +pub mod logical; diff --git a/vortex-array/src/compute/mod.rs b/vortex-array/src/compute/mod.rs index 094bf982313..27530d63b6f 100644 --- a/vortex-array/src/compute/mod.rs +++ b/vortex-array/src/compute/mod.rs @@ -43,7 +43,7 @@ use crate::{Array, ArrayRef}; #[cfg(feature = "arbitrary")] mod arbitrary; -mod arrays; +pub mod arrays; mod between; mod boolean; mod cast; diff --git a/vortex-array/src/execution/batch.rs b/vortex-array/src/execution/batch.rs index a88cfb8a5bf..928a9b7487b 100644 --- a/vortex-array/src/execution/batch.rs +++ b/vortex-array/src/execution/batch.rs @@ -23,6 +23,7 @@ impl VortexResult + Send + 'static> BatchKernel for Batch } /// Create a batch execution kernel from the given closure. +#[inline(always)] pub fn kernel VortexResult + Send + 'static>(f: F) -> BatchKernelRef { Box::new(BatchKernelAdapter(f)) } diff --git a/vortex-buffer/benches/vortex_buffer.rs b/vortex-buffer/benches/vortex_buffer.rs index ecfab3917cc..5374701ab2c 100644 --- a/vortex-buffer/benches/vortex_buffer.rs +++ b/vortex-buffer/benches/vortex_buffer.rs @@ -76,7 +76,7 @@ impl MapEach for BufferMut { where F: FnMut(T) -> R, { - BufferMut::::map_each(self, f) + BufferMut::::map_each_in_place(self, f) } } diff --git a/vortex-buffer/src/buffer_mut.rs b/vortex-buffer/src/buffer_mut.rs index 75df8ebbcfe..f0ea523b5d4 100644 --- a/vortex-buffer/src/buffer_mut.rs +++ b/vortex-buffer/src/buffer_mut.rs @@ -393,7 +393,7 @@ impl BufferMut { } /// Map each element of the buffer with a closure. - pub fn map_each(self, mut f: F) -> BufferMut + pub fn map_each_in_place(self, mut f: F) -> BufferMut where T: Copy, F: FnMut(T) -> R, @@ -785,7 +785,7 @@ mod test { fn map_each() { let buf = buffer_mut![0i32, 1, 2]; // Add one, and cast to an unsigned u32 in the same closure - let buf = buf.map_each(|i| (i + 1) as u32); + let buf = buf.map_each_in_place(|i| (i + 1) as u32); assert_eq!(buf.as_slice(), &[1u32, 2, 3]); } diff --git a/vortex-compute/Cargo.toml b/vortex-compute/Cargo.toml index a38f1e45fe8..406229cd66e 100644 --- a/vortex-compute/Cargo.toml +++ b/vortex-compute/Cargo.toml @@ -21,12 +21,15 @@ workspace = true [dependencies] vortex-buffer = { workspace = true } -vortex-error = { workspace = true } +vortex-dtype = { workspace = true } vortex-mask = { workspace = true } vortex-vector = { workspace = true } +num-traits = { workspace = true } + [features] -default = ["filter", "logical"] +default = ["arithmetic", "filter", "logical"] +arithmetic = [] filter = [] logical = [] diff --git a/vortex-compute/src/arithmetic/buffer.rs b/vortex-compute/src/arithmetic/buffer.rs new file mode 100644 index 00000000000..9f4f4c36ca8 --- /dev/null +++ b/vortex-compute/src/arithmetic/buffer.rs @@ -0,0 +1,158 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_buffer::{Buffer, BufferMut}; + +use crate::arithmetic::{Arithmetic, Operator}; + +/// Implementation that attempts to downcast to a mutable buffer and operates in-place. +impl Arithmetic> for Buffer +where + T: Copy, + BufferMut: for<'a> Arithmetic, Output = Buffer>, + for<'a> &'a Buffer: Arithmetic, Output = Buffer>, +{ + type Output = Buffer; + + fn eval(self, rhs: &Buffer) -> Self::Output { + match self.try_into_mut() { + Ok(lhs) => lhs.eval(rhs), + Err(lhs) => (&lhs).eval(rhs), // (&lhs) to delegate to borrowed impl + } + } +} + +/// Implementation that operates in-place over a mutable buffer. +impl Arithmetic> for BufferMut +where + T: Copy + num_traits::Zero, + Op: Operator, +{ + type Output = Buffer; + + fn eval(self, rhs: &Buffer) -> Self::Output { + assert_eq!(self.len(), rhs.len()); + + let mut i = 0; + self.map_each_in_place(|a| { + // SAFETY: lengths are equal, so index is in bounds + let b = unsafe { *rhs.get_unchecked(i) }; + i += 1; + + Op::apply(&a, &b) + }) + .freeze() + } +} + +/// Implementation that allocates a new output buffer. +impl Arithmetic for &Buffer +where + Op: Operator, +{ + type Output = Buffer; + + fn eval(self, rhs: &Buffer) -> Self::Output { + assert_eq!(self.len(), rhs.len()); + Buffer::::from_trusted_len_iter( + self.iter().zip(rhs.iter()).map(|(a, b)| Op::apply(a, b)), + ) + } +} + +/// Implementation that attempts to downcast to a mutable buffer and operates in-place against +/// a scalar RHS value. +impl Arithmetic for Buffer +where + BufferMut: for<'a> Arithmetic>, + for<'a> &'a Buffer: Arithmetic>, +{ + type Output = Buffer; + + fn eval(self, rhs: &T) -> Self::Output { + match self.try_into_mut() { + Ok(lhs) => lhs.eval(rhs), + Err(lhs) => (&lhs).eval(rhs), + } + } +} + +/// Implementation that operates in-place over a mutable buffer against a scalar RHS value. +impl Arithmetic for BufferMut +where + T: Copy, + Op: Operator, +{ + type Output = Buffer; + + fn eval(self, rhs: &T) -> Self::Output { + self.map_each_in_place(|a| Op::apply(&a, rhs)).freeze() + } +} + +/// Implementation that allocates a new output buffer operating against a scalar RHS value. +impl Arithmetic for &Buffer +where + Op: Operator, +{ + type Output = Buffer; + + fn eval(self, rhs: &T) -> Self::Output { + Buffer::::from_trusted_len_iter(self.iter().map(|a| Op::apply(a, rhs))) + } +} + +#[cfg(test)] +mod tests { + use vortex_buffer::buffer; + + use crate::arithmetic::{Arithmetic, WrappingAdd, WrappingMul, WrappingSub}; + + #[test] + fn test_add_buffers() { + let left = buffer![1u32, 2, 3, 4]; + let right = buffer![10u32, 20, 30, 40]; + + let result = Arithmetic::::eval(left, &right); + assert_eq!(result, buffer![11u32, 22, 33, 44]); + } + + #[test] + fn test_add_scalar() { + let buf = buffer![1u32, 2, 3, 4]; + let result = Arithmetic::::eval(buf, &10); + assert_eq!(result, buffer![11u32, 12, 13, 14]); + } + + #[test] + fn test_sub_buffers() { + let left = buffer![10u32, 20, 30, 40]; + let right = buffer![1u32, 2, 3, 4]; + + let result = Arithmetic::::eval(left, &right); + assert_eq!(result, buffer![9u32, 18, 27, 36]); + } + + #[test] + fn test_sub_scalar() { + let buf = buffer![10u32, 20, 30, 40]; + let result = Arithmetic::::eval(buf, &5); + assert_eq!(result, buffer![5u32, 15, 25, 35]); + } + + #[test] + fn test_mul_buffers() { + let left = buffer![2u32, 3, 4, 5]; + let right = buffer![10u32, 20, 30, 40]; + + let result = Arithmetic::::eval(left, &right); + assert_eq!(result, buffer![20u32, 60, 120, 200]); + } + + #[test] + fn test_mul_scalar() { + let buf = buffer![1u32, 2, 3, 4]; + let result = Arithmetic::::eval(buf, &10); + assert_eq!(result, buffer![10u32, 20, 30, 40]); + } +} diff --git a/vortex-compute/src/arithmetic/buffer_checked.rs b/vortex-compute/src/arithmetic/buffer_checked.rs new file mode 100644 index 00000000000..2eb8a71f99e --- /dev/null +++ b/vortex-compute/src/arithmetic/buffer_checked.rs @@ -0,0 +1,258 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_buffer::{Buffer, BufferMut}; + +use crate::arithmetic::{CheckedArithmetic, CheckedOperator}; + +/// Implementation that attempts to downcast to a mutable buffer and operates in-place. +impl CheckedArithmetic> for Buffer +where + T: Copy + num_traits::Zero, + BufferMut: for<'a> CheckedArithmetic, Output = Buffer>, + for<'a> &'a Buffer: CheckedArithmetic, Output = Buffer>, +{ + type Output = Buffer; + + fn checked_eval(self, rhs: &Buffer) -> Option { + match self.try_into_mut() { + Ok(lhs) => lhs.checked_eval(rhs), + Err(lhs) => (&lhs).checked_eval(rhs), // (&lhs) to delegate to borrowed impl + } + } +} + +/// Implementation that operates in-place over a mutable buffer. +impl CheckedArithmetic> for BufferMut +where + T: Copy + num_traits::Zero, + Op: CheckedOperator, +{ + type Output = Buffer; + + fn checked_eval(self, rhs: &Buffer) -> Option { + assert_eq!(self.len(), rhs.len()); + + let mut i = 0; + let mut overflow = false; + let buffer = self + .map_each_in_place(|a| { + // SAFETY: lengths are equal, so index is in bounds + let b = unsafe { *rhs.get_unchecked(i) }; + i += 1; + + // On overflow, set flag and write zero + // We don't abort early because this code vectorizes better without the + // branch, and we expect overflow to be an exception rather than the norm. + Op::apply(&a, &b).unwrap_or_else(|| { + overflow = true; + T::zero() + }) + }) + .freeze(); + + (!overflow).then_some(buffer) + } +} + +/// Implementation that allocates a new output buffer. +impl CheckedArithmetic for &Buffer +where + T: Copy + num_traits::Zero, + Op: CheckedOperator, +{ + type Output = Buffer; + + fn checked_eval(self, rhs: &Buffer) -> Option { + assert_eq!(self.len(), rhs.len()); + + let mut overflow = false; + let buffer = + Buffer::::from_trusted_len_iter(self.iter().zip(rhs.iter()).map(|(a, b)| { + // On overflow, set flag and write zero + // We don't abort early because this code vectorizes better without the + // branch, and we expect overflow to be an exception rather than the norm. + Op::apply(a, b).unwrap_or_else(|| { + overflow = true; + T::zero() + }) + })); + (!overflow).then_some(buffer) + } +} + +/// Implementation that attempts to downcast to a mutable buffer and operates in-place against +/// a scalar RHS value. +impl CheckedArithmetic for Buffer +where + T: Copy + num_traits::Zero, + BufferMut: for<'a> CheckedArithmetic>, + for<'a> &'a Buffer: CheckedArithmetic>, +{ + type Output = Buffer; + + fn checked_eval(self, rhs: &T) -> Option { + match self.try_into_mut() { + Ok(lhs) => lhs.checked_eval(rhs), + Err(lhs) => (&lhs).checked_eval(rhs), + } + } +} + +/// Implementation that operates in-place over a mutable buffer against a scalar RHS value. +impl CheckedArithmetic for BufferMut +where + T: Copy + num_traits::Zero, + Op: CheckedOperator, +{ + type Output = Buffer; + + fn checked_eval(self, rhs: &T) -> Option { + let mut overflow = false; + let buffer = self + .map_each_in_place(|a| { + Op::apply(&a, rhs).unwrap_or_else(|| { + overflow = true; + T::zero() + }) + }) + .freeze(); + + (!overflow).then_some(buffer) + } +} + +/// Implementation that allocates a new output buffer operating against a scalar RHS value. +impl CheckedArithmetic for &Buffer +where + T: Copy + num_traits::Zero, + Op: CheckedOperator, +{ + type Output = Buffer; + + fn checked_eval(self, rhs: &T) -> Option { + let mut overflow = false; + let buffer = Buffer::::from_trusted_len_iter(self.iter().map(|a| { + Op::apply(a, rhs).unwrap_or_else(|| { + overflow = true; + T::zero() + }) + })); + + (!overflow).then_some(buffer) + } +} + +#[cfg(test)] +mod tests { + use vortex_buffer::buffer; + + use crate::arithmetic::{Add, CheckedArithmetic, Div, Mul, Sub}; + + #[test] + fn test_add_buffers() { + let left = buffer![1u32, 2, 3, 4]; + let right = buffer![10u32, 20, 30, 40]; + + let result = CheckedArithmetic::::checked_eval(left, &right).unwrap(); + assert_eq!(result, buffer![11u32, 22, 33, 44]); + } + + #[test] + fn test_add_scalar() { + let buf = buffer![1u32, 2, 3, 4]; + let result = CheckedArithmetic::::checked_eval(buf, &10).unwrap(); + assert_eq!(result, buffer![11u32, 12, 13, 14]); + } + + #[test] + fn test_add_overflow() { + let left = buffer![u8::MAX, 100]; + let right = buffer![1u8, 50]; + + let result = CheckedArithmetic::::checked_eval(left, &right); + assert!(result.is_none()); + } + + #[test] + fn test_sub_buffers() { + let left = buffer![10u32, 20, 30, 40]; + let right = buffer![1u32, 2, 3, 4]; + + let result = CheckedArithmetic::::checked_eval(left, &right).unwrap(); + assert_eq!(result, buffer![9u32, 18, 27, 36]); + } + + #[test] + fn test_sub_scalar() { + let buf = buffer![10u32, 20, 30, 40]; + let result = CheckedArithmetic::::checked_eval(buf, &5).unwrap(); + assert_eq!(result, buffer![5u32, 15, 25, 35]); + } + + #[test] + fn test_sub_underflow() { + let left = buffer![5u32, 10]; + let right = buffer![10u32, 5]; + + let result = CheckedArithmetic::::checked_eval(left, &right); + assert!(result.is_none()); + } + + #[test] + fn test_mul_buffers() { + let left = buffer![2u32, 3, 4, 5]; + let right = buffer![10u32, 20, 30, 40]; + + let result = CheckedArithmetic::::checked_eval(left, &right).unwrap(); + assert_eq!(result, buffer![20u32, 60, 120, 200]); + } + + #[test] + fn test_mul_scalar() { + let buf = buffer![1u32, 2, 3, 4]; + let result = CheckedArithmetic::::checked_eval(buf, &10).unwrap(); + assert_eq!(result, buffer![10u32, 20, 30, 40]); + } + + #[test] + fn test_mul_overflow() { + let left = buffer![u8::MAX, 100]; + let right = buffer![2u8, 3]; + + let result = CheckedArithmetic::::checked_eval(left, &right); + assert!(result.is_none()); + } + + #[test] + fn test_div_buffers() { + let left = buffer![100u32, 200, 300, 400]; + let right = buffer![10u32, 20, 30, 40]; + + let result = CheckedArithmetic::::checked_eval(left, &right).unwrap(); + assert_eq!(result, buffer![10u32, 10, 10, 10]); + } + + #[test] + fn test_div_scalar() { + let buf = buffer![100u32, 200, 300, 400]; + let result = CheckedArithmetic::::checked_eval(buf, &10).unwrap(); + assert_eq!(result, buffer![10u32, 20, 30, 40]); + } + + #[test] + fn test_div_by_zero() { + let left = buffer![10u32, 20, 30]; + let right = buffer![2u32, 0, 3]; + + let result = CheckedArithmetic::::checked_eval(left, &right); + assert!(result.is_none()); + } + + #[test] + fn test_div_scalar_by_zero() { + let buf = buffer![10u32, 20, 30]; + let result = CheckedArithmetic::::checked_eval(buf, &0); + assert!(result.is_none()); + } +} diff --git a/vortex-compute/src/arithmetic/mod.rs b/vortex-compute/src/arithmetic/mod.rs new file mode 100644 index 00000000000..701997f1c07 --- /dev/null +++ b/vortex-compute/src/arithmetic/mod.rs @@ -0,0 +1,165 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Arithmetic operations on buffers and vectors. + +use vortex_dtype::half::f16; + +mod buffer; +mod buffer_checked; +mod pvector; +mod pvector_checked; + +/// Trait for arithmetic operations. +pub trait Arithmetic { + /// The result type after performing the operation. + type Output; + + /// Perform the operation. + fn eval(self, rhs: Rhs) -> Self::Output; +} + +/// Trait for checked arithmetic operators. +pub trait Operator { + /// Apply the operator to the two operands. + fn apply(a: &T, b: &T) -> T; +} + +/// Trait for checked arithmetic operations. +pub trait CheckedArithmetic { + /// The result type after performing the operation. + type Output; + + /// Perform the operation, returning None on overflow/underflow or division by zero. + /// See the `Op` marker detailed semantics on the checked behavior. + fn checked_eval(self, rhs: Rhs) -> Option; +} + +/// Trait for checked arithmetic operators. +pub trait CheckedOperator { + /// Apply the operator to the two operands, returning None on overflow/underflow. + fn apply(a: &T, b: &T) -> Option; +} + +/// Marker type for arithmetic addition. +pub struct Add; +/// Marker type for arithmetic subtraction. +pub struct Sub; +/// Marker type for arithmetic multiplication. +pub struct Mul; +/// Marker type for arithmetic division. +pub struct Div; + +/// Marker type for arithmetic addition that wraps on overflow. +pub struct WrappingAdd; +/// Marker type for arithmetic subtraction that wraps on overflow. +pub struct WrappingSub; +/// Marker type for arithmetic multiplication that wraps on overflow. +pub struct WrappingMul; + +/// Marker type for arithmetic addition that saturates on overflow. +pub struct SaturatingAdd; +/// Marker type for arithmetic subtraction that saturates on overflow. +pub struct SaturatingSub; +/// Marker type for arithmetic multiplication that saturates on overflow. +pub struct SaturatingMul; + +impl CheckedOperator for Add { + #[inline(always)] + fn apply(a: &T, b: &T) -> Option { + num_traits::CheckedAdd::checked_add(a, b) + } +} +impl CheckedOperator for Sub { + #[inline(always)] + fn apply(a: &T, b: &T) -> Option { + num_traits::CheckedSub::checked_sub(a, b) + } +} +impl CheckedOperator for Mul { + #[inline(always)] + fn apply(a: &T, b: &T) -> Option { + num_traits::CheckedMul::checked_mul(a, b) + } +} +impl CheckedOperator for Div { + #[inline(always)] + fn apply(a: &T, b: &T) -> Option { + num_traits::CheckedDiv::checked_div(a, b) + } +} + +impl Operator for WrappingAdd { + #[inline(always)] + fn apply(a: &T, b: &T) -> T { + num_traits::WrappingAdd::wrapping_add(a, b) + } +} +impl Operator for WrappingSub { + #[inline(always)] + fn apply(a: &T, b: &T) -> T { + num_traits::WrappingSub::wrapping_sub(a, b) + } +} +impl Operator for WrappingMul { + #[inline(always)] + fn apply(a: &T, b: &T) -> T { + num_traits::WrappingMul::wrapping_mul(a, b) + } +} + +impl Operator for SaturatingAdd { + #[inline(always)] + fn apply(a: &T, b: &T) -> T { + num_traits::SaturatingAdd::saturating_add(a, b) + } +} +impl Operator for SaturatingSub { + #[inline(always)] + fn apply(a: &T, b: &T) -> T { + num_traits::SaturatingSub::saturating_sub(a, b) + } +} +impl Operator for SaturatingMul { + #[inline(always)] + fn apply(a: &T, b: &T) -> T { + num_traits::SaturatingMul::saturating_mul(a, b) + } +} + +/// Macro to implement arithmetic operators for floating-point types. +/// +/// These are not deferred to the `std::ops::Add` since those implementations will panic on +/// overflow in some cases (e.g., debug builds). +macro_rules! impl_float { + ($T:ty) => { + impl Operator<$T> for Add { + #[inline(always)] + fn apply(a: &$T, b: &$T) -> $T { + a + b + } + } + impl Operator<$T> for Sub { + #[inline(always)] + fn apply(a: &$T, b: &$T) -> $T { + a - b + } + } + impl Operator<$T> for Mul { + #[inline(always)] + fn apply(a: &$T, b: &$T) -> $T { + a * b + } + } + impl Operator<$T> for Div { + #[inline(always)] + fn apply(a: &$T, b: &$T) -> $T { + a / b + } + } + }; +} + +impl_float!(f16); +impl_float!(f32); +impl_float!(f64); diff --git a/vortex-compute/src/arithmetic/pvector.rs b/vortex-compute/src/arithmetic/pvector.rs new file mode 100644 index 00000000000..e9196e67864 --- /dev/null +++ b/vortex-compute/src/arithmetic/pvector.rs @@ -0,0 +1,199 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use std::ops::BitAnd; + +use vortex_buffer::{Buffer, BufferMut}; +use vortex_dtype::NativePType; +use vortex_vector::{PVector, PVectorMut, VectorMutOps, VectorOps}; + +use crate::arithmetic::{Arithmetic, Operator}; + +/// Implementation that attempts to downcast to a mutable vector and operates in-place. +impl Arithmetic> for PVector +where + T: NativePType, + Op: Operator, +{ + type Output = PVector; + + fn eval(self, rhs: &PVector) -> Self::Output { + match self.try_into_mut() { + Ok(lhs) => Arithmetic::::eval(lhs, rhs), + Err(lhs) => Arithmetic::::eval(&lhs, rhs), + } + } +} + +/// Implementation that operates in-place over a mutable vector. +impl Arithmetic> for PVectorMut +where + T: NativePType, + Op: Operator, + BufferMut: for<'a> Arithmetic, Output = Buffer>, +{ + type Output = PVector; + + fn eval(self, other: &PVector) -> Self::Output { + assert_eq!(self.len(), other.len()); + + let (lhs_buffer, lhs_validity) = self.into_parts(); + + // TODO(ngates): based on the true count of the validity, we may wish to short-circuit here + // or choose a different implementation. + let validity = lhs_validity.freeze().bitand(other.validity()); + let elements = Arithmetic::::eval(lhs_buffer, other.elements()); + + PVector::new(elements, validity) + } +} + +/// Implementation that allocates a new output vector. +impl Arithmetic> for &PVector +where + T: NativePType, + Op: Operator, + for<'a> &'a Buffer: Arithmetic, Output = Buffer>, +{ + type Output = PVector; + + fn eval(self, rhs: &PVector) -> Self::Output { + assert_eq!(self.len(), rhs.len()); + + // TODO(ngates): based on the true count of the validity, we may wish to short-circuit here + // or choose a different implementation. + let validity = self.validity().bitand(rhs.validity()); + + let elements = Arithmetic::::eval(self.elements(), rhs.elements()); + PVector::new(elements, validity) + } +} + +/// Implementation that attempts to downcast to a mutable vector and operates in-place against +/// a scalar RHS value. +impl Arithmetic for PVector +where + T: NativePType, + Op: Operator, + PVectorMut: for<'a> Arithmetic>, +{ + type Output = PVector; + + fn eval(self, rhs: &T) -> Self::Output { + match self.try_into_mut() { + Ok(lhs) => Arithmetic::::eval(lhs, rhs), + Err(lhs) => Arithmetic::::eval(&lhs, rhs), + } + } +} + +/// Implementation that operates in-place over a mutable vector against a scalar RHS value. +impl Arithmetic for PVectorMut +where + T: NativePType, + Op: Operator, + BufferMut: for<'a> Arithmetic>, +{ + type Output = PVector; + + fn eval(self, rhs: &T) -> Self::Output { + let (lhs_buffer, lhs_validity) = self.into_parts(); + let validity = lhs_validity.freeze(); + + let elements = Arithmetic::::eval(lhs_buffer, rhs); + + PVector::new(elements, validity) + } +} + +/// Implementation that allocates a new output vector against a scalar RHS value. +impl Arithmetic for &PVector +where + T: NativePType, + Op: Operator, + for<'a> &'a Buffer: Arithmetic>, +{ + type Output = PVector; + + fn eval(self, rhs: &T) -> Self::Output { + let buffer = Arithmetic::::eval(self.elements(), rhs); + PVector::new(buffer, self.validity().clone()) + } +} + +#[cfg(test)] +mod tests { + use vortex_buffer::buffer; + use vortex_mask::Mask; + use vortex_vector::{PVector, VectorOps}; + + use crate::arithmetic::{Arithmetic, WrappingAdd, WrappingMul, WrappingSub}; + + #[test] + fn test_add_pvectors() { + let left = PVector::new(buffer![1u32, 2, 3, 4], Mask::new_true(4)); + let right = PVector::new(buffer![10u32, 20, 30, 40], Mask::new_true(4)); + + let result = Arithmetic::::eval(left, &right); + assert_eq!(result.elements(), &buffer![11u32, 22, 33, 44]); + } + + #[test] + fn test_add_scalar() { + let vec = PVector::new(buffer![1u32, 2, 3, 4], Mask::new_true(4)); + let result = Arithmetic::::eval(vec, &10); + assert_eq!(result.elements(), &buffer![11u32, 12, 13, 14]); + } + + #[test] + fn test_add_with_nulls() { + let left = PVector::new(buffer![1u32, 2, 3], Mask::from_iter([true, false, true])); + let right = PVector::new(buffer![10u32, 20, 30], Mask::new_true(3)); + + let result = Arithmetic::::eval(left, &right); + // Validity is AND'd, so if either side is null, result is null + assert_eq!(result.validity(), &Mask::from_iter([true, false, true])); + assert_eq!(result.elements(), &buffer![11u32, 22, 33]); + } + + #[test] + fn test_sub_pvectors() { + let left = PVector::new(buffer![10u32, 20, 30, 40], Mask::new_true(4)); + let right = PVector::new(buffer![1u32, 2, 3, 4], Mask::new_true(4)); + + let result = Arithmetic::::eval(left, &right); + assert_eq!(result.elements(), &buffer![9u32, 18, 27, 36]); + } + + #[test] + fn test_sub_scalar() { + let vec = PVector::new(buffer![10u32, 20, 30, 40], Mask::new_true(4)); + let result = Arithmetic::::eval(vec, &5); + assert_eq!(result.elements(), &buffer![5u32, 15, 25, 35]); + } + + #[test] + fn test_mul_pvectors() { + let left = PVector::new(buffer![2u32, 3, 4, 5], Mask::new_true(4)); + let right = PVector::new(buffer![10u32, 20, 30, 40], Mask::new_true(4)); + + let result = Arithmetic::::eval(left, &right); + assert_eq!(result.elements(), &buffer![20u32, 60, 120, 200]); + } + + #[test] + fn test_mul_scalar() { + let vec = PVector::new(buffer![1u32, 2, 3, 4], Mask::new_true(4)); + let result = Arithmetic::::eval(vec, &10); + assert_eq!(result.elements(), &buffer![10u32, 20, 30, 40]); + } + + #[test] + fn test_scalar_preserves_validity() { + let vec = PVector::new(buffer![1u32, 2, 3], Mask::from_iter([true, false, true])); + let result = Arithmetic::::eval(vec, &10); + + assert_eq!(result.validity(), &Mask::from_iter([true, false, true])); + assert_eq!(result.elements(), &buffer![11u32, 12, 13]); + } +} diff --git a/vortex-compute/src/arithmetic/pvector_checked.rs b/vortex-compute/src/arithmetic/pvector_checked.rs new file mode 100644 index 00000000000..5aaca80b280 --- /dev/null +++ b/vortex-compute/src/arithmetic/pvector_checked.rs @@ -0,0 +1,230 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use std::ops::BitAnd; + +use vortex_buffer::{Buffer, BufferMut}; +use vortex_dtype::NativePType; +use vortex_vector::{PVector, PVectorMut, VectorMutOps, VectorOps}; + +use crate::arithmetic::CheckedArithmetic; + +/// Implementation that attempts to downcast to a mutable vector and operates in-place. +impl CheckedArithmetic> for PVector +where + T: NativePType, + PVectorMut: for<'a> CheckedArithmetic, Output = PVector>, + for<'a> &'a PVector: CheckedArithmetic, Output = PVector>, +{ + type Output = PVector; + + fn checked_eval(self, rhs: &PVector) -> Option { + match self.try_into_mut() { + Ok(lhs) => CheckedArithmetic::::checked_eval(lhs, rhs), + Err(lhs) => CheckedArithmetic::::checked_eval(&lhs, rhs), + } + } +} + +/// Implementation that operates in-place over a mutable vector. +impl CheckedArithmetic> for PVectorMut +where + T: NativePType, + BufferMut: for<'a> CheckedArithmetic, Output = Buffer>, +{ + type Output = PVector; + + fn checked_eval(self, other: &PVector) -> Option { + assert_eq!(self.len(), other.len()); + + let (lhs_buffer, lhs_validity) = self.into_parts(); + + // TODO(ngates): based on the true count of the validity, we may wish to short-circuit here + // or choose a different implementation. + let validity = lhs_validity.freeze().bitand(other.validity()); + let elements = CheckedArithmetic::::checked_eval(lhs_buffer, other.elements())?; + + Some(PVector::new(elements, validity)) + } +} + +/// Implementation that allocates a new output vector. +impl CheckedArithmetic> for &PVector +where + T: NativePType, + for<'a> &'a Buffer: CheckedArithmetic, Output = Buffer>, +{ + type Output = PVector; + + fn checked_eval(self, rhs: &PVector) -> Option { + assert_eq!(self.len(), rhs.len()); + + // TODO(ngates): based on the true count of the validity, we may wish to short-circuit here + // or choose a different implementation. + let validity = self.validity().bitand(rhs.validity()); + + let elements = CheckedArithmetic::::checked_eval(self.elements(), rhs.elements())?; + Some(PVector::new(elements, validity)) + } +} + +/// Implementation that attempts to downcast to a mutable vector and operates in-place against +/// a scalar RHS value. +impl CheckedArithmetic for PVector +where + T: NativePType, + PVectorMut: for<'a> CheckedArithmetic>, + for<'a> &'a PVector: CheckedArithmetic>, +{ + type Output = PVector; + + fn checked_eval(self, rhs: &T) -> Option { + match self.try_into_mut() { + Ok(lhs) => CheckedArithmetic::::checked_eval(lhs, rhs), + Err(lhs) => CheckedArithmetic::::checked_eval(&lhs, rhs), + } + } +} + +/// Implementation that operates in-place over a mutable vector against a scalar RHS value. +impl CheckedArithmetic for PVectorMut +where + T: NativePType, + BufferMut: for<'a> CheckedArithmetic>, +{ + type Output = PVector; + + fn checked_eval(self, rhs: &T) -> Option { + let (lhs_buffer, lhs_validity) = self.into_parts(); + let validity = lhs_validity.freeze(); + + let elements = CheckedArithmetic::::checked_eval(lhs_buffer, rhs)?; + + Some(PVector::new(elements, validity)) + } +} + +/// Implementation that allocates a new output vector against a scalar RHS value. +impl CheckedArithmetic for &PVector +where + T: NativePType, + for<'a> &'a Buffer: CheckedArithmetic>, +{ + type Output = PVector; + + fn checked_eval(self, rhs: &T) -> Option { + let buffer = CheckedArithmetic::::checked_eval(self.elements(), rhs)?; + Some(PVector::new(buffer, self.validity().clone())) + } +} + +#[cfg(test)] +mod tests { + use vortex_buffer::buffer; + use vortex_mask::Mask; + use vortex_vector::{PVector, VectorOps}; + + use crate::arithmetic::{Add, CheckedArithmetic, Div, Mul, Sub}; + + #[test] + fn test_add_pvectors() { + let left = PVector::new(buffer![1u32, 2, 3, 4], Mask::new_true(4)); + let right = PVector::new(buffer![10u32, 20, 30, 40], Mask::new_true(4)); + + let result = CheckedArithmetic::::checked_eval(left, &right).unwrap(); + assert_eq!(result.elements(), &buffer![11u32, 22, 33, 44]); + } + + #[test] + fn test_add_scalar() { + let vec = PVector::new(buffer![1u32, 2, 3, 4], Mask::new_true(4)); + let result = CheckedArithmetic::::checked_eval(vec, &10).unwrap(); + assert_eq!(result.elements(), &buffer![11u32, 12, 13, 14]); + } + + #[test] + fn test_add_with_nulls() { + let left = PVector::new(buffer![1u32, 2, 3], Mask::from_iter([true, false, true])); + let right = PVector::new(buffer![10u32, 20, 30], Mask::new_true(3)); + + let result = CheckedArithmetic::::checked_eval(left, &right).unwrap(); + // Validity is AND'd, so if either side is null, result is null + assert_eq!(result.validity(), &Mask::from_iter([true, false, true])); + assert_eq!(result.elements(), &buffer![11u32, 22, 33]); + } + + #[test] + fn test_sub_pvectors() { + let left = PVector::new(buffer![10u32, 20, 30, 40], Mask::new_true(4)); + let right = PVector::new(buffer![1u32, 2, 3, 4], Mask::new_true(4)); + + let result = CheckedArithmetic::::checked_eval(left, &right).unwrap(); + assert_eq!(result.elements(), &buffer![9u32, 18, 27, 36]); + } + + #[test] + fn test_sub_scalar() { + let vec = PVector::new(buffer![10u32, 20, 30, 40], Mask::new_true(4)); + let result = CheckedArithmetic::::checked_eval(vec, &5).unwrap(); + assert_eq!(result.elements(), &buffer![5u32, 15, 25, 35]); + } + + #[test] + fn test_mul_pvectors() { + let left = PVector::new(buffer![2u32, 3, 4, 5], Mask::new_true(4)); + let right = PVector::new(buffer![10u32, 20, 30, 40], Mask::new_true(4)); + + let result = CheckedArithmetic::::checked_eval(left, &right).unwrap(); + assert_eq!(result.elements(), &buffer![20u32, 60, 120, 200]); + } + + #[test] + fn test_mul_scalar() { + let vec = PVector::new(buffer![1u32, 2, 3, 4], Mask::new_true(4)); + let result = CheckedArithmetic::::checked_eval(vec, &10).unwrap(); + assert_eq!(result.elements(), &buffer![10u32, 20, 30, 40]); + } + + #[test] + fn test_div_pvectors() { + let left = PVector::new(buffer![100u32, 200, 300, 400], Mask::new_true(4)); + let right = PVector::new(buffer![10u32, 20, 30, 40], Mask::new_true(4)); + + let result = CheckedArithmetic::::checked_eval(left, &right).unwrap(); + assert_eq!(result.elements(), &buffer![10u32, 10, 10, 10]); + } + + #[test] + fn test_div_scalar() { + let vec = PVector::new(buffer![100u32, 200, 300, 400], Mask::new_true(4)); + let result = CheckedArithmetic::::checked_eval(vec, &10).unwrap(); + assert_eq!(result.elements(), &buffer![10u32, 20, 30, 40]); + } + + #[test] + fn test_overflow_returns_none() { + let left = PVector::new(buffer![u8::MAX, 100], Mask::new_true(2)); + let right = PVector::new(buffer![1u8, 50], Mask::new_true(2)); + + let result = CheckedArithmetic::::checked_eval(left, &right); + assert!(result.is_none()); + } + + #[test] + fn test_div_by_zero_returns_none() { + let left = PVector::new(buffer![10u32, 20, 30], Mask::new_true(3)); + let right = PVector::new(buffer![2u32, 0, 3], Mask::new_true(3)); + + let result = CheckedArithmetic::::checked_eval(left, &right); + assert!(result.is_none()); + } + + #[test] + fn test_scalar_preserves_validity() { + let vec = PVector::new(buffer![1u32, 2, 3], Mask::from_iter([true, false, true])); + let result = CheckedArithmetic::::checked_eval(vec, &10).unwrap(); + + assert_eq!(result.validity(), &Mask::from_iter([true, false, true])); + assert_eq!(result.elements(), &buffer![11u32, 12, 13]); + } +} diff --git a/vortex-compute/src/lib.rs b/vortex-compute/src/lib.rs index 5c2b009ef03..f83053f422b 100644 --- a/vortex-compute/src/lib.rs +++ b/vortex-compute/src/lib.rs @@ -7,6 +7,8 @@ #![deny(clippy::missing_panics_doc)] #![deny(clippy::missing_safety_doc)] +#[cfg(feature = "arithmetic")] +pub mod arithmetic; #[cfg(feature = "filter")] pub mod filter; #[cfg(feature = "logical")] diff --git a/vortex-compute/src/logical/and.rs b/vortex-compute/src/logical/and.rs index a5f1ec04e34..38c7206d4c3 100644 --- a/vortex-compute/src/logical/and.rs +++ b/vortex-compute/src/logical/and.rs @@ -3,6 +3,7 @@ use std::ops::BitAnd; +use vortex_mask::Mask; use vortex_vector::{BoolVector, VectorOps}; use crate::logical::LogicalAnd; @@ -27,6 +28,22 @@ impl LogicalAnd<&BoolVector> for BoolVector { } } +impl LogicalAnd for &Mask { + type Output = Mask; + + fn and(self, other: Self) -> Self::Output { + self.bitand(other) + } +} + +impl LogicalAnd<&Mask> for Mask { + type Output = Mask; + + fn and(self, other: &Mask) -> Self::Output { + self.bitand(other) + } +} + #[cfg(test)] mod tests { use vortex_buffer::bitbuffer; diff --git a/vortex-compute/src/logical/or.rs b/vortex-compute/src/logical/or.rs index b88cc840c26..57af49f90c7 100644 --- a/vortex-compute/src/logical/or.rs +++ b/vortex-compute/src/logical/or.rs @@ -3,6 +3,7 @@ use std::ops::{BitAnd, BitOr}; +use vortex_mask::Mask; use vortex_vector::{BoolVector, VectorOps}; use crate::logical::LogicalOr; @@ -27,6 +28,22 @@ impl LogicalOr<&BoolVector> for BoolVector { } } +impl LogicalOr for &Mask { + type Output = Mask; + + fn or(self, other: Self) -> Self::Output { + self.bitor(other) + } +} + +impl LogicalOr<&Mask> for Mask { + type Output = Mask; + + fn or(self, other: &Mask) -> Self::Output { + self.bitor(other) + } +} + #[cfg(test)] mod tests { use vortex_buffer::bitbuffer; diff --git a/vortex-dtype/src/ptype.rs b/vortex-dtype/src/ptype.rs index 14a236187c4..863a89a03af 100644 --- a/vortex-dtype/src/ptype.rs +++ b/vortex-dtype/src/ptype.rs @@ -99,6 +99,7 @@ pub trait NativePType: + FromPrimitive + ToBytes + TryFromBytes + + private::Sealed + 'static { /// The PType that corresponds to this native type @@ -149,6 +150,25 @@ pub trait NativePType: fn upcast(input: V::Input) -> V; } +mod private { + use half::f16; + + /// A private trait to prevent external implementations of `NativePType`. + pub trait Sealed {} + + impl Sealed for u8 {} + impl Sealed for u16 {} + impl Sealed for u32 {} + impl Sealed for u64 {} + impl Sealed for i8 {} + impl Sealed for i16 {} + impl Sealed for i32 {} + impl Sealed for i64 {} + impl Sealed for f16 {} + impl Sealed for f32 {} + impl Sealed for f64 {} +} + /// A visitor trait for converting a `NativePType` to another parameterized type. #[allow(missing_docs)] // Kind of obvious. pub trait PTypeDowncast { @@ -170,7 +190,7 @@ pub trait PTypeDowncast { /// Extension trait to provide generic downcasting for [`PTypeDowncast`]. pub trait PTypeDowncastExt: PTypeDowncast { /// Downcast the object to a specific primitive type. - fn into_primitive(self) -> Self::Output + fn downcast(self) -> Self::Output where Self: Sized, { diff --git a/vortex-vector/src/primitive/vector.rs b/vortex-vector/src/primitive/vector.rs index dab6b3d945d..3193fe75e45 100644 --- a/vortex-vector/src/primitive/vector.rs +++ b/vortex-vector/src/primitive/vector.rs @@ -4,7 +4,7 @@ //! Definition and implementation of [`PrimitiveVector`]. use vortex_dtype::half::f16; -use vortex_dtype::{NativePType, PTypeDowncast, PTypeUpcast}; +use vortex_dtype::{NativePType, PType, PTypeDowncast, PTypeUpcast}; use vortex_error::vortex_panic; use super::macros::match_each_pvector; @@ -44,6 +44,25 @@ pub enum PrimitiveVector { F64(PVector), } +impl PrimitiveVector { + /// Returns the [`PType`] of this [`PrimitiveVector`]. + pub fn ptype(&self) -> PType { + match self { + PrimitiveVector::U8(_) => PType::U8, + PrimitiveVector::U16(_) => PType::U16, + PrimitiveVector::U32(_) => PType::U32, + PrimitiveVector::U64(_) => PType::U64, + PrimitiveVector::I8(_) => PType::I8, + PrimitiveVector::I16(_) => PType::I16, + PrimitiveVector::I32(_) => PType::I32, + PrimitiveVector::I64(_) => PType::I64, + PrimitiveVector::F16(_) => PType::F16, + PrimitiveVector::F32(_) => PType::F32, + PrimitiveVector::F64(_) => PType::F64, + } + } +} + impl VectorOps for PrimitiveVector { type Mutable = PrimitiveVectorMut; diff --git a/vortex-vector/src/primitive/vector_mut.rs b/vortex-vector/src/primitive/vector_mut.rs index d1d1436d728..ba6f70a6649 100644 --- a/vortex-vector/src/primitive/vector_mut.rs +++ b/vortex-vector/src/primitive/vector_mut.rs @@ -44,6 +44,25 @@ pub enum PrimitiveVectorMut { F64(PVectorMut), } +impl PrimitiveVectorMut { + /// Returns the [`PType`] of this [`PrimitiveVectorMut`]. + pub fn ptype(&self) -> PType { + match self { + PrimitiveVectorMut::U8(_) => PType::U8, + PrimitiveVectorMut::U16(_) => PType::U16, + PrimitiveVectorMut::U32(_) => PType::U32, + PrimitiveVectorMut::U64(_) => PType::U64, + PrimitiveVectorMut::I8(_) => PType::I8, + PrimitiveVectorMut::I16(_) => PType::I16, + PrimitiveVectorMut::I32(_) => PType::I32, + PrimitiveVectorMut::I64(_) => PType::I64, + PrimitiveVectorMut::F16(_) => PType::F16, + PrimitiveVectorMut::F32(_) => PType::F32, + PrimitiveVectorMut::F64(_) => PType::F64, + } + } +} + impl PrimitiveVectorMut { /// Create a new mutable primitive vector with the given primitive type and capacity. pub fn with_capacity(ptype: PType, capacity: usize) -> Self {