Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🎉 add optional arrow2 support #48

Merged
merged 1 commit into from
Apr 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ num-traits = { version = "0.2.15", default-features = false }
half = { version = "2.1.0", default-features = false, features=["num-traits"], optional = true }
ndarray = { version = "0.15.6", default-features = false, optional = true}
arrow = { version = ">0", default-features = false, optional = true}
# arrow2 = { version = ">0", default-features = false, optional = true}
arrow2 = { version = ">0.0", default-features = false, optional = true}
# once_cell = "1.16.0"

[features]
Expand All @@ -27,6 +27,7 @@ float = []
half = ["dep:half"]
ndarray = ["dep:ndarray"]
arrow = ["dep:arrow"]
arrow2 = ["dep:arrow2"]

[dev-dependencies]
rstest = { version = "0.16", default-features = false }
Expand Down
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

🚀 The function is generic over the type of the array, so it can be used on `&[T]` or `Vec<T>` where `T` can be `f16`<sup>1</sup>, `f32`<sup>2</sup>, `f64`<sup>2</sup>, `i8`, `i16`, `i32`, `i64`, `u8`, `u16`, `u32`, `u64`.

🤝 The trait is implemented for [`slice`](https://doc.rust-lang.org/std/primitive.slice.html), [`Vec`](https://doc.rust-lang.org/std/vec/struct.Vec.html), 1D [`ndarray::ArrayBase`](https://docs.rs/ndarray/latest/ndarray/struct.ArrayBase.html)<sup>3</sup>, and apache [`arrow::PrimitiveArray`](https://docs.rs/arrow/latest/arrow/array/struct.PrimitiveArray.html)<sup>4</sup>.
🤝 The trait is implemented for [`slice`](https://doc.rust-lang.org/std/primitive.slice.html), [`Vec`](https://doc.rust-lang.org/std/vec/struct.Vec.html), 1D [`ndarray::ArrayBase`](https://docs.rs/ndarray/latest/ndarray/struct.ArrayBase.html)<sup>3</sup>, Apache [`arrow::PrimitiveArray`](https://docs.rs/arrow/latest/arrow/array/struct.PrimitiveArray.html)<sup>4</sup>, and [`arrow2::PrimitiveArray`](https://docs.rs/arrow2/latest/arrow2/array/struct.PrimitiveArray.html)<sup>5</sup>.

⚡ **Runtime CPU feature detection** is used to select the most efficient implementation for the current CPU. This means that the same binary can be used on different CPUs without recompilation.

Expand All @@ -18,6 +18,7 @@
> <i><sup>2</sup> for <code>f32</code> and <code>f64</code> you should enable the (default) `"float"` feature.</i>
> <i><sup>3</sup> for <code>ndarray::ArrayBase</code> you should enable the `"ndarray"` feature.</i>
> <i><sup>4</sup> for <code>arrow::PrimitiveArray</code> you should enable the `"arrow"` feature.</i>
> <i><sup>5</sup> for <code>arrow2::PrimitiveArray</code> you should enable the `"arrow2"` feature.</i>

## Installing

Expand Down
92 changes: 92 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
//! - **`half`** - enables the traits for `f16` (requires the [`half`](https://crates.io/crates/half) crate).
//! - **`ndarray`** - adds the traits to [`ndarray::ArrayBase`](https://docs.rs/ndarray/latest/ndarray/struct.ArrayBase.html) (requires the `ndarray` crate).
//! - **`arrow`** - adds the traits to [`arrow::array::PrimitiveArray`](https://docs.rs/arrow/latest/arrow/array/struct.PrimitiveArray.html) (requires the `arrow` crate).
//! - **`arrow2`** - adds the traits to [`arrow2::array::PrimitiveArray`](https://docs.rs/arrow2/latest/arrow2/array/struct.PrimitiveArray.html) (requires the `arrow2` crate).
//!
//!
//! # Examples
Expand Down Expand Up @@ -740,3 +741,94 @@ mod arrow_impl {
}
}
}

// ---------------------- (optional) arrow2 ----------------------

#[cfg(feature = "arrow2")]
mod arrow2_impl {
use super::*;
use arrow2::array::PrimitiveArray;

impl<T> ArgMinMax for PrimitiveArray<T>
where
T: arrow2::types::NativeType,
for<'a> &'a [T]: ArgMinMax,
{
fn argminmax(&self) -> (usize, usize) {
self.values().as_ref().argminmax()
}

fn argmin(&self) -> usize {
self.values().as_ref().argmin()
}

fn argmax(&self) -> usize {
self.values().as_ref().argmax()
}
}

#[cfg(feature = "float")]
impl<T> NaNArgMinMax for PrimitiveArray<T>
where
T: arrow2::types::NativeType,
for<'a> &'a [T]: NaNArgMinMax,
{
fn nanargminmax(&self) -> (usize, usize) {
self.values().as_ref().nanargminmax()
}

fn nanargmin(&self) -> usize {
self.values().as_ref().nanargmin()
}

fn nanargmax(&self) -> usize {
self.values().as_ref().nanargmax()
}
}

#[cfg(feature = "half")]
#[inline(always)]
/// Convert a PrimitiveArray<arrow2::types::f16> to a slice of half::f16
/// To do so, the pointer to the arrow2::types::f16 slice is casted to a pointer to
/// a slice of half::f16 (since both use u16 as their underlying type)
fn _to_half_f16_slice(
primitive_array_f16: &PrimitiveArray<arrow2::types::f16>,
) -> &[half::f16] {
unsafe {
std::slice::from_raw_parts(
primitive_array_f16.values().as_ptr() as *const half::f16,
primitive_array_f16.len(),
)
}
}

#[cfg(feature = "half")]
impl ArgMinMax for PrimitiveArray<arrow2::types::f16> {
fn argminmax(&self) -> (usize, usize) {
_to_half_f16_slice(self).argminmax()
}

fn argmin(&self) -> usize {
_to_half_f16_slice(self).argmin()
}

fn argmax(&self) -> usize {
_to_half_f16_slice(self).argmax()
}
}

#[cfg(feature = "half")]
impl NaNArgMinMax for PrimitiveArray<arrow2::types::f16> {
fn nanargminmax(&self) -> (usize, usize) {
_to_half_f16_slice(self).nanargminmax()
}

fn nanargmin(&self) -> usize {
_to_half_f16_slice(self).nanargmin()
}

fn nanargmax(&self) -> usize {
_to_half_f16_slice(self).nanargmax()
}
}
}
168 changes: 162 additions & 6 deletions tests/argminmax_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ use dev_utils::utils;
use rand;

const ARRAY_LENGTH: usize = 100_000;
const NB_RANDOM_RUNS: usize = 500;
const RANDOM_ARR_LENGTH: usize = 5_000;

// ----- dtypes_with_nan template -----

Expand Down Expand Up @@ -207,8 +209,8 @@ mod default_test {
T: Copy + FromPrimitive + AsPrimitive<usize> + rand::distributions::uniform::SampleUniform,
for<'a> &'a [T]: ArgMinMax,
{
for _ in 0..500 {
let data: Vec<T> = utils::get_random_array::<T>(5_000, min, max);
for _ in 0..NB_RANDOM_RUNS {
let data: Vec<T> = utils::get_random_array::<T>(RANDOM_ARR_LENGTH, min, max);
// Slice
let slice: &[T] = &data;
let (min_slice, max_slice) = slice.argminmax();
Expand Down Expand Up @@ -337,8 +339,8 @@ mod ndarray_tests {
T: Copy + FromPrimitive + AsPrimitive<usize> + rand::distributions::uniform::SampleUniform,
for<'a> &'a [T]: ArgMinMax,
{
for _ in 0..500 {
let data: Vec<T> = utils::get_random_array::<T>(5_000, min, max);
for _ in 0..NB_RANDOM_RUNS {
let data: Vec<T> = utils::get_random_array::<T>(RANDOM_ARR_LENGTH, min, max);
// Slice
let slice: &[T] = &data;
let (min_slice, max_slice) = slice.argminmax();
Expand Down Expand Up @@ -478,8 +480,8 @@ mod arrow_tests {
ArrowDataType: ArrowPrimitiveType<Native = T> + ArrowNumericType,
PrimitiveArray<ArrowDataType>: From<Vec<T>>,
{
for _ in 0..500 {
let data: Vec<T> = utils::get_random_array::<T>(5_000, min, max);
for _ in 0..NB_RANDOM_RUNS {
let data: Vec<T> = utils::get_random_array::<T>(RANDOM_ARR_LENGTH, min, max);
// Slice
let slice: &[T] = &data;
let (min_slice, max_slice) = slice.argminmax();
Expand All @@ -497,3 +499,157 @@ mod arrow_tests {
}
}
}

#[cfg(feature = "arrow2")]
#[cfg(test)]
mod arrow2_tests {
use super::*;

use arrow2::array::PrimitiveArray;
use arrow2::types::NativeType;

// Float and not half
#[cfg(feature = "float")]
#[template]
#[rstest]
#[case::float32(f32::MIN, f32::MAX)]
#[case::float64(f64::MIN, f64::MAX)]
fn dtypes_with_nan_arrow2<T>(#[case] min: T, #[case] max: T) {}

#[apply(dtypes)]
fn test_argminmax_arrow2<T>(#[case] _min: T, #[case] max: T)
where
for<'a> &'a [T]: ArgMinMax,
T: Copy + FromPrimitive + AsPrimitive<usize> + NativeType,
{
// max_index is the max value that can be represented by T
let max_index: usize = std::cmp::min(ARRAY_LENGTH, max.as_());

let data: PrimitiveArray<T> =
PrimitiveArray::from_vec(get_monotonic_array(ARRAY_LENGTH, max_index));
// Test owned PrimitiveArray
let (min, max) = data.argminmax();
assert_eq!(min, 0);
assert_eq!(max, max_index - 1);
// Test borrowed PrimitiveArray
let (min, max) = (&data).argminmax();
assert_eq!(min, 0);
assert_eq!(max, max_index - 1);
}

#[cfg(feature = "float")]
#[apply(dtypes_with_nan_arrow2)]
fn test_argminmax_arrow2_nan<T>(#[case] _min: T, #[case] max: T)
where
for<'a> &'a [T]: NaNArgMinMax,
T: Copy + FromPrimitive + AsPrimitive<usize> + NativeType,
{
// max_index is the max value that can be represented by T
let max_index: usize = std::cmp::min(ARRAY_LENGTH, max.as_());

let data: PrimitiveArray<T> =
PrimitiveArray::from_vec(get_monotonic_array(ARRAY_LENGTH, max_index));
// Test owned PrimitiveArray
let (min, max) = data.nanargminmax();
assert_eq!(min, 0);
assert_eq!(max, max_index - 1);
// Test borrowed PrimitiveArray
let (min, max) = (&data).nanargminmax();
assert_eq!(min, 0);
assert_eq!(max, max_index - 1);
}

#[apply(dtypes)]
fn test_argminmax_many_random_runs_arrow2<T>(#[case] min: T, #[case] max: T)
where
for<'a> &'a [T]: ArgMinMax,
T: Copy
+ FromPrimitive
+ AsPrimitive<usize>
+ rand::distributions::uniform::SampleUniform
+ NativeType,
{
for _ in 0..NB_RANDOM_RUNS {
let data: Vec<T> = utils::get_random_array::<T>(RANDOM_ARR_LENGTH, min, max);
// Slice
let slice: &[T] = &data;
let (min_slice, max_slice) = slice.argminmax();
// Vec
let (min_vec, max_vec) = data.argminmax();
// Arrow
let arrow: PrimitiveArray<T> = PrimitiveArray::from_vec(data);
let (min_arrow, max_arrow) = arrow.argminmax();

// Check
assert_eq!(min_slice, min_vec);
assert_eq!(max_slice, max_vec);
assert_eq!(min_slice, min_arrow);
assert_eq!(max_slice, max_arrow);
}
}

// Perform the same tests with half::f16 - convert to arrow2::types::f16
#[test]
#[cfg(feature = "half")]
fn test_argminmax_arrow2_f16() {
// Get monotonic array
let max_index: usize = 1 << f16::MANTISSA_DIGITS;
let data: Vec<f16> = get_monotonic_array(ARRAY_LENGTH, max_index);
// Convert the half::f16 vec to PrimitiveArray<arrow2::types::f16>
let data: Vec<arrow2::types::f16> = data
.into_iter()
.map(|x| arrow2::types::f16(x.to_bits()))
.collect();

let data: PrimitiveArray<arrow2::types::f16> = PrimitiveArray::from_vec(data);

// --- ArgMinMax
// Test owned PrimitiveArray
let (min, max) = data.argminmax();
assert_eq!(min, 0);
assert_eq!(max, max_index - 1);
// Test borrowed PrimitiveArray
let (min, max) = (&data).argminmax();
assert_eq!(min, 0);
assert_eq!(max, max_index - 1);

// --- NaNArgMinMax
// Test owned PrimitiveArray
let (min, max) = data.nanargminmax();
assert_eq!(min, 0);
assert_eq!(max, max_index - 1);
// Test borrowed PrimitiveArray
let (min, max) = (&data).nanargminmax();
assert_eq!(min, 0);
assert_eq!(max, max_index - 1);

// --- many random runs
for _ in 0..NB_RANDOM_RUNS {
let data: Vec<i16> =
utils::get_random_array::<i16>(RANDOM_ARR_LENGTH, i16::MIN, i16::MAX);
// convert to half::f16
let data_half: Vec<f16> = data.into_iter().map(|x| f16::from_bits(x as u16)).collect();
// convert to arrow2::types::f16
let data: Vec<arrow2::types::f16> = data_half
.clone()
.into_iter()
.map(|x| arrow2::types::f16(x.to_bits()))
.collect();

// Slice
let slice: &[f16] = &data_half;
let (min_slice, max_slice) = slice.argminmax();
// Vec
let (min_vec, max_vec) = data_half.argminmax();
// Arrow2
let arrow: PrimitiveArray<arrow2::types::f16> = PrimitiveArray::from_vec(data);
let (min_arrow, max_arrow) = arrow.argminmax();

// Check
assert_eq!(min_slice, min_vec);
assert_eq!(max_slice, max_vec);
assert_eq!(min_slice, min_arrow);
assert_eq!(max_slice, max_arrow);
}
}
}