This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Improve performance of rem_scalar/div_scalar for integer types (4x-10x) #275

Merged
merged 1 commit into from Aug 12, 2021
1 change: 1 addition & 0 deletions .gitignore
@@ -6,3 +6,4 @@ Cargo.lock
fixtures
settings.json
dev/
.idea/
9 changes: 8 additions & 1 deletion Cargo.toml
@@ -60,6 +60,9 @@ ahash = { version = "0.7", optional = true }

parquet2 = { version = "0.3", optional = true, default_features = false, features = ["stream"] }

# for division/remainder optimization at runtime
strength_reduce = { version = "0.2", optional = true }

[dev-dependencies]
rand = "0.8"
criterion = "0.3"
@@ -98,7 +101,7 @@ io_parquet_compression = [
io_json_integration = ["io_json", "hex"]
io_print = ["comfy-table"]
# the compute kernels. Disabling this significantly reduces compile time.
compute = []
compute = ["strength_reduce"]
# base64 + io_ipc because arrow schemas are stored as base64-encoded ipc format.
io_parquet = ["parquet2", "io_ipc", "base64", "futures"]
benchmarks = ["rand"]
@@ -167,3 +170,7 @@ harness = false
[[bench]]
name = "write_ipc"
harness = false

[[bench]]
name = "arithmetic_kernels"
harness = false
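
The new optional `strength_reduce` dependency does the heavy lifting at runtime. Below is a minimal standalone sketch of the idea, not taken from this PR, assuming only the strength_reduce 0.2 API used in the diff (`StrengthReducedU64::new` plus its `Div`/`Rem` operator overloads): the divisor is analysed once, and every element is then divided with a cheap multiply/shift sequence instead of a hardware division.

```rust
// Sketch only (not part of this PR): precompute the divisor once, reuse it
// for every element of the array.
use strength_reduce::StrengthReducedU64;

fn div_all(values: &[u64], divisor: u64) -> Vec<u64> {
    // `new` panics on a zero divisor, matching the behaviour of plain `/`.
    let reduced = StrengthReducedU64::new(divisor);
    values.iter().map(|&v| v / reduced).collect()
}

fn main() {
    assert_eq!(div_all(&[8, 9, 524287], 4), vec![2, 2, 131071]);
}
```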
52 changes: 52 additions & 0 deletions benches/arithmetic_kernels.rs
@@ -0,0 +1,52 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#[macro_use]
extern crate criterion;
use criterion::Criterion;

use arrow2::array::*;
use arrow2::util::bench_util::*;
use arrow2::{
compute::arithmetics::basic::div::div_scalar, datatypes::DataType, types::NativeType,
};
use num::NumCast;
use std::ops::Div;

fn bench_div_scalar<T>(lhs: &PrimitiveArray<T>, rhs: &T)
where
T: NativeType + Div<Output = T> + NumCast,
{
criterion::black_box(div_scalar(lhs, rhs));
}

fn add_benchmark(c: &mut Criterion) {
let size = 65536;
let arr = create_primitive_array_with_seed::<u64>(size, DataType::UInt64, 0.0, 43);

c.bench_function("divide_scalar 4", |b| {
// 4 is a very fast optimizable divisor
b.iter(|| bench_div_scalar(&arr, &4))
});
c.bench_function("divide_scalar prime", |b| {
// large prime number that is probably harder to simplify
b.iter(|| bench_div_scalar(&arr, &524287))
});
}

criterion_group!(benches, add_benchmark);
criterion_main!(benches);
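
This file only benchmarks `div_scalar`. A matching `rem_scalar` case would have the same shape; the sketch below is illustrative and not part of the PR. It assumes `rem_scalar` can be imported from `compute::arithmetics::basic::rem` the same way `div_scalar` is above, and it would need its own `[[bench]]` entry in Cargo.toml.

```rust
// Hypothetical companion benchmark for rem_scalar (not in this PR).
#[macro_use]
extern crate criterion;
use criterion::Criterion;

use arrow2::array::*;
use arrow2::util::bench_util::*;
use arrow2::{
    compute::arithmetics::basic::rem::rem_scalar, datatypes::DataType, types::NativeType,
};
use num::NumCast;
use std::ops::Rem;

fn bench_rem_scalar<T>(lhs: &PrimitiveArray<T>, rhs: &T)
where
    T: NativeType + Rem<Output = T> + NumCast,
{
    criterion::black_box(rem_scalar(lhs, rhs));
}

fn add_benchmark(c: &mut Criterion) {
    let arr = create_primitive_array_with_seed::<u64>(65536, DataType::UInt64, 0.0, 43);
    // same "hard to simplify" divisor as the div benchmark above
    c.bench_function("rem_scalar prime", |b| b.iter(|| bench_rem_scalar(&arr, &524287)));
}

criterion_group!(benches, add_benchmark);
criterion_main!(benches);
```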
95 changes: 91 additions & 4 deletions src/compute/arithmetics/basic/div.rs
@@ -1,8 +1,9 @@
//! Definition of basic div operations with primitive arrays
use std::ops::Div;

use num::{CheckedDiv, Zero};
use num::{CheckedDiv, NumCast, Zero};

use crate::datatypes::DataType;
use crate::{
array::{Array, PrimitiveArray},
compute::{
@@ -12,6 +13,9 @@ use crate::{
error::{ArrowError, Result},
types::NativeType,
};
use strength_reduce::{
StrengthReducedU16, StrengthReducedU32, StrengthReducedU64, StrengthReducedU8,
};

/// Divides two primitive arrays with the same type.
/// Panics if the divisor is zero or one pair of values overflows.
@@ -109,10 +113,72 @@ where
/// ```
pub fn div_scalar<T>(lhs: &PrimitiveArray<T>, rhs: &T) -> PrimitiveArray<T>
where
T: NativeType + Div<Output = T>,
T: NativeType + Div<Output = T> + NumCast,
{
let rhs = *rhs;
unary(lhs, |a| a / rhs, lhs.data_type().clone())
match T::DATA_TYPE {
DataType::UInt64 => {
let lhs = lhs.as_any().downcast_ref::<PrimitiveArray<u64>>().unwrap();
let rhs = rhs.to_u64().unwrap();

let reduced_div = StrengthReducedU64::new(rhs);
// Safety: we just proved that `lhs` is `PrimitiveArray<u64>` which means that
// T = u64
unsafe {
std::mem::transmute::<PrimitiveArray<u64>, PrimitiveArray<T>>(unary(
lhs,
|a| a / reduced_div,
lhs.data_type().clone(),
))
}
}
DataType::UInt32 => {
let lhs = lhs.as_any().downcast_ref::<PrimitiveArray<u32>>().unwrap();
let rhs = rhs.to_u32().unwrap();

let reduced_div = StrengthReducedU32::new(rhs);
// Safety: we just proved that `lhs` is `PrimitiveArray<u32>` which means that
// T = u32
unsafe {
std::mem::transmute::<PrimitiveArray<u32>, PrimitiveArray<T>>(unary(
lhs,
|a| a / reduced_div,
lhs.data_type().clone(),
))
}
}
DataType::UInt16 => {
let lhs = lhs.as_any().downcast_ref::<PrimitiveArray<u16>>().unwrap();
let rhs = rhs.to_u16().unwrap();

let reduced_div = StrengthReducedU16::new(rhs);
// Safety: we just proved that `lhs` is `PrimitiveArray<u16>` which means that
// T = u16
unsafe {
std::mem::transmute::<PrimitiveArray<u16>, PrimitiveArray<T>>(unary(
lhs,
|a| a / reduced_div,
lhs.data_type().clone(),
))
}
}
DataType::UInt8 => {
let lhs = lhs.as_any().downcast_ref::<PrimitiveArray<u8>>().unwrap();
let rhs = rhs.to_u8().unwrap();

let reduced_div = StrengthReducedU8::new(rhs);
// Safety: we just proved that `lhs` is `PrimitiveArray<u8>` which means that
// T = u8
unsafe {
std::mem::transmute::<PrimitiveArray<u8>, PrimitiveArray<T>>(unary(
lhs,
|a| a / reduced_div,
lhs.data_type().clone(),
))
}
}
_ => unary(lhs, |a| a / rhs, lhs.data_type().clone()),
}
}

/// Checked division of a primitive array of type T by a scalar T. If the
@@ -141,7 +207,7 @@ where
// Implementation of ArrayDiv trait for PrimitiveArrays with a scalar
impl<T> ArrayDiv<T> for PrimitiveArray<T>
where
T: NativeType + Div<Output = T> + NotI128,
T: NativeType + Div<Output = T> + NotI128 + NumCast,
{
type Output = Self;

@@ -226,6 +292,27 @@ mod tests {
// Trait testing
let result = a.div(&1i32).unwrap();
assert_eq!(result, expected);

// check the strength reduced branches
let a = UInt64Array::from(&[None, Some(6), None, Some(6)]);
let result = div_scalar(&a, &1u64);
let expected = UInt64Array::from(&[None, Some(6), None, Some(6)]);
assert_eq!(result, expected);

let a = UInt32Array::from(&[None, Some(6), None, Some(6)]);
let result = div_scalar(&a, &1u32);
let expected = UInt32Array::from(&[None, Some(6), None, Some(6)]);
assert_eq!(result, expected);

let a = UInt16Array::from(&[None, Some(6), None, Some(6)]);
let result = div_scalar(&a, &1u16);
let expected = UInt16Array::from(&[None, Some(6), None, Some(6)]);
assert_eq!(result, expected);

let a = UInt8Array::from(&[None, Some(6), None, Some(6)]);
let result = div_scalar(&a, &1u8);
let expected = UInt8Array::from(&[None, Some(6), None, Some(6)]);
assert_eq!(result, expected);
}

#[test]
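
Each unsigned branch downcasts to the concrete array type, runs the strength-reduced kernel, and transmutes the result back to `PrimitiveArray<T>`; the transmute is sound because the `T::DATA_TYPE` match already proved that `T` is that concrete type. From the caller's side nothing changes. A usage sketch, illustrative rather than taken from the PR, using the same import path as the benchmark above:

```rust
use arrow2::array::UInt64Array;
use arrow2::compute::arithmetics::basic::div::div_scalar;

fn main() {
    // A u64 array takes the StrengthReducedU64 branch; nulls are preserved
    // and the values match plain element-wise division.
    let a = UInt64Array::from(&[Some(9u64), None, Some(524287)]);
    let out = div_scalar(&a, &4u64);
    assert_eq!(out, UInt64Array::from(&[Some(2u64), None, Some(131071)]));
}
```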
96 changes: 92 additions & 4 deletions src/compute/arithmetics/basic/rem.rs
@@ -1,7 +1,8 @@
use std::ops::Rem;

use num::{traits::CheckedRem, Zero};
use num::{traits::CheckedRem, NumCast, Zero};

use crate::datatypes::DataType;
use crate::{
array::{Array, PrimitiveArray},
compute::{
@@ -11,6 +12,9 @@ use crate::{
error::{ArrowError, Result},
types::NativeType,
};
use strength_reduce::{
StrengthReducedU16, StrengthReducedU32, StrengthReducedU64, StrengthReducedU8,
};

/// Remainder of two primitive arrays with the same type.
/// Panics if the divisor is zero or one pair of values overflows.
@@ -106,10 +110,73 @@ where
/// ```
pub fn rem_scalar<T>(lhs: &PrimitiveArray<T>, rhs: &T) -> PrimitiveArray<T>
where
T: NativeType + Rem<Output = T>,
T: NativeType + Rem<Output = T> + NumCast,
{
let rhs = *rhs;
unary(lhs, |a| a % rhs, lhs.data_type().clone())

match T::DATA_TYPE {
DataType::UInt64 => {
let lhs = lhs.as_any().downcast_ref::<PrimitiveArray<u64>>().unwrap();
let rhs = rhs.to_u64().unwrap();

let reduced_rem = StrengthReducedU64::new(rhs);
// Safety: we just proved that `lhs` is `PrimitiveArray<u64>` which means that
// T = u64
unsafe {
std::mem::transmute::<PrimitiveArray<u64>, PrimitiveArray<T>>(unary(
lhs,
|a| a % reduced_rem,
lhs.data_type().clone(),
))
}
}
DataType::UInt32 => {
let lhs = lhs.as_any().downcast_ref::<PrimitiveArray<u32>>().unwrap();
let rhs = rhs.to_u32().unwrap();

let reduced_rem = StrengthReducedU32::new(rhs);
// Safety: we just proved that `lhs` is `PrimitiveArray<u32>` which means that
// T = u32
unsafe {
std::mem::transmute::<PrimitiveArray<u32>, PrimitiveArray<T>>(unary(
lhs,
|a| a % reduced_rem,
lhs.data_type().clone(),
))
}
}
DataType::UInt16 => {
let lhs = lhs.as_any().downcast_ref::<PrimitiveArray<u16>>().unwrap();
let rhs = rhs.to_u16().unwrap();

let reduced_rem = StrengthReducedU16::new(rhs);
// Safety: we just proved that `lhs` is `PrimitiveArray<u16>` which means that
// T = u16
unsafe {
std::mem::transmute::<PrimitiveArray<u16>, PrimitiveArray<T>>(unary(
lhs,
|a| a % reduced_rem,
lhs.data_type().clone(),
))
}
}
DataType::UInt8 => {
let lhs = lhs.as_any().downcast_ref::<PrimitiveArray<u8>>().unwrap();
let rhs = rhs.to_u8().unwrap();

let reduced_rem = StrengthReducedU8::new(rhs);
// Safety: we just proved that `lhs` is `PrimitiveArray<u8>` which means that
// T = u8
unsafe {
std::mem::transmute::<PrimitiveArray<u8>, PrimitiveArray<T>>(unary(
lhs,
|a| a % reduced_rem,
lhs.data_type().clone(),
))
}
}
_ => unary(lhs, |a| a % rhs, lhs.data_type().clone()),
}
}

/// Checked remainder of a primitive array of type T by a scalar T. If the
@@ -137,7 +204,7 @@ where

impl<T> ArrayRem<T> for PrimitiveArray<T>
where
T: NativeType + Rem<Output = T> + NotI128,
T: NativeType + Rem<Output = T> + NotI128 + NumCast,
{
type Output = Self;

@@ -221,6 +288,27 @@ mod tests {
// Trait testing
let result = a.rem(&2i32).unwrap();
assert_eq!(result, expected);

// check the strength reduced branches
let a = UInt64Array::from(&[None, Some(6), None, Some(5)]);
let result = rem_scalar(&a, &2u64);
let expected = UInt64Array::from(&[None, Some(0), None, Some(1)]);
assert_eq!(result, expected);

let a = UInt32Array::from(&[None, Some(6), None, Some(5)]);
let result = rem_scalar(&a, &2u32);
let expected = UInt32Array::from(&[None, Some(0), None, Some(1)]);
assert_eq!(result, expected);

let a = UInt16Array::from(&[None, Some(6), None, Some(5)]);
let result = rem_scalar(&a, &2u16);
let expected = UInt16Array::from(&[None, Some(0), None, Some(1)]);
assert_eq!(result, expected);

let a = UInt8Array::from(&[None, Some(6), None, Some(5)]);
let result = rem_scalar(&a, &2u8);
let expected = UInt8Array::from(&[None, Some(0), None, Some(1)]);
assert_eq!(result, expected);
}

#[test]
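
The remainder kernel mirrors the division kernel branch for branch. A small usage sketch under the same assumptions (the import path mirrors `div_scalar`; illustrative, not from the PR):

```rust
use arrow2::array::UInt32Array;
use arrow2::compute::arithmetics::basic::rem::rem_scalar;

fn main() {
    // A u32 array takes the StrengthReducedU32 branch; signed and floating
    // point arrays fall back to the plain `%` path.
    let a = UInt32Array::from(&[Some(10u32), None, Some(7)]);
    let out = rem_scalar(&a, &3u32);
    assert_eq!(out, UInt32Array::from(&[Some(1u32), None, Some(1)]));
}
```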
5 changes: 3 additions & 2 deletions src/compute/arithmetics/mod.rs
@@ -61,7 +61,7 @@ pub mod time;

use std::ops::{Add, Div, Mul, Neg, Rem, Sub};

use num::Zero;
use num::{NumCast, Zero};

use crate::datatypes::{DataType, TimeUnit};
use crate::error::{ArrowError, Result};
@@ -265,7 +265,8 @@ where
+ Add<Output = T>
+ Sub<Output = T>
+ Mul<Output = T>
+ Rem<Output = T>,
+ Rem<Output = T>
+ NumCast,
{
match op {
Operator::Add => Ok(basic::add::add_scalar(lhs, rhs)),
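
For users, the only visible API change is the extra `NumCast` bound, which propagates to any generic code layered on top of the scalar kernels. A hedged sketch of what such downstream code now looks like; `halve_all` is a made-up helper, not part of arrow2:

```rust
use arrow2::array::{PrimitiveArray, UInt32Array};
use arrow2::compute::arithmetics::basic::div::div_scalar;
use arrow2::types::NativeType;
use num::NumCast;
use std::ops::Div;

// Generic callers of div_scalar/rem_scalar now need `NumCast` in their bounds.
fn halve_all<T>(arr: &PrimitiveArray<T>) -> PrimitiveArray<T>
where
    T: NativeType + Div<Output = T> + NumCast,
{
    // NumCast turns the literal 2 into a T; unwrap is fine for all numeric widths.
    let two: T = NumCast::from(2u8).unwrap();
    div_scalar(arr, &two)
}

fn main() {
    let a = UInt32Array::from(&[Some(2u32), None, Some(9)]);
    assert_eq!(halve_all(&a), UInt32Array::from(&[Some(1u32), None, Some(4)]));
}
```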