Skip to content

Commit

Permalink
Datum based arithmetic
Browse files Browse the repository at this point in the history
  • Loading branch information
tustvold committed Jul 4, 2023
1 parent 07a721f commit 68a89b3
Show file tree
Hide file tree
Showing 25 changed files with 270 additions and 3,761 deletions.
12 changes: 10 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@ rust-version = "1.64"

[workspace.dependencies]
arrow = { version = "43.0.0", features = ["prettyprint", "dyn_cmp_dict"] }
arrow-flight = { version = "43.0.0", features = ["flight-sql-experimental"] }
arrow-array = { version = "43.0.0", default-features = false, features = ["chrono-tz"] }
arrow-buffer = { version = "43.0.0", default-features = false }
arrow-flight = { version = "43.0.0", features = ["flight-sql-experimental"] }
arrow-schema = { version = "43.0.0", default-features = false }
arrow-array = { version = "43.0.0", default-features = false, features = ["chrono-tz"] }
parquet = { version = "43.0.0", features = ["arrow", "async", "object_store"] }
sqlparser = { version = "0.35", features = ["visitor"] }

Expand All @@ -71,3 +71,11 @@ opt-level = 3
overflow-checks = false
panic = 'unwind'
rpath = false

[patch.crates-io]
arrow = { git = "https://github.com/tustvold/arrow-rs.git", rev = "9c461f7027871b3d1a1b30de7fd26b3ac01cb096" }
arrow-array = { git = "https://github.com/tustvold/arrow-rs.git", rev = "9c461f7027871b3d1a1b30de7fd26b3ac01cb096" }
arrow-buffer = { git = "https://github.com/tustvold/arrow-rs.git", rev = "9c461f7027871b3d1a1b30de7fd26b3ac01cb096" }
arrow-flight = { git = "https://github.com/tustvold/arrow-rs.git", rev = "9c461f7027871b3d1a1b30de7fd26b3ac01cb096" }
arrow-schema = { git = "https://github.com/tustvold/arrow-rs.git", rev = "9c461f7027871b3d1a1b30de7fd26b3ac01cb096" }
parquet = { git = "https://github.com/tustvold/arrow-rs.git", rev = "9c461f7027871b3d1a1b30de7fd26b3ac01cb096" }
45 changes: 15 additions & 30 deletions datafusion-cli/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 7 additions & 0 deletions datafusion-cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,10 @@ assert_cmd = "2.0"
ctor = "0.2.0"
predicates = "3.0"
rstest = "0.17"

[patch.crates-io]
arrow = { git = "https://github.com/tustvold/arrow-rs.git", rev = "9c461f7027871b3d1a1b30de7fd26b3ac01cb096" }
arrow-array = { git = "https://github.com/tustvold/arrow-rs.git", rev = "9c461f7027871b3d1a1b30de7fd26b3ac01cb096" }
arrow-buffer = { git = "https://github.com/tustvold/arrow-rs.git", rev = "9c461f7027871b3d1a1b30de7fd26b3ac01cb096" }
arrow-schema = { git = "https://github.com/tustvold/arrow-rs.git", rev = "9c461f7027871b3d1a1b30de7fd26b3ac01cb096" }
parquet = { git = "https://github.com/tustvold/arrow-rs.git", rev = "9c461f7027871b3d1a1b30de7fd26b3ac01cb096" }
151 changes: 59 additions & 92 deletions datafusion/common/src/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ use arrow::{
DECIMAL128_MAX_PRECISION,
},
};
use arrow_array::timezone::Tz;
use arrow_array::{timezone::Tz, ArrowNativeTypeOp};
use chrono::{Datelike, Duration, NaiveDate, NaiveDateTime};

// Constants we use throughout this file:
Expand Down Expand Up @@ -743,55 +743,21 @@ macro_rules! impl_op {
($LHS:expr, $RHS:expr, -) => {
match ($LHS, $RHS) {
(
ScalarValue::TimestampSecond(Some(ts_lhs), tz_lhs),
ScalarValue::TimestampSecond(Some(ts_rhs), tz_rhs),
) => {
let err = || {
DataFusionError::Execution(
"Overflow while converting seconds to milliseconds".to_string(),
)
};
ts_sub_to_interval::<MILLISECOND_MODE>(
ts_lhs.checked_mul(1_000).ok_or_else(err)?,
ts_rhs.checked_mul(1_000).ok_or_else(err)?,
tz_lhs.as_deref(),
tz_rhs.as_deref(),
)
},
ScalarValue::TimestampSecond(Some(ts_lhs), _),
ScalarValue::TimestampSecond(Some(ts_rhs), _),
) => Ok(ScalarValue::DurationSecond(Some(ts_lhs.sub_checked(*ts_rhs)?))),
(
ScalarValue::TimestampMillisecond(Some(ts_lhs), tz_lhs),
ScalarValue::TimestampMillisecond(Some(ts_rhs), tz_rhs),
) => ts_sub_to_interval::<MILLISECOND_MODE>(
*ts_lhs,
*ts_rhs,
tz_lhs.as_deref(),
tz_rhs.as_deref(),
),
ScalarValue::TimestampMillisecond(Some(ts_lhs), _),
ScalarValue::TimestampMillisecond(Some(ts_rhs), _),
) => Ok(ScalarValue::DurationMillisecond(Some(ts_lhs.sub_checked(*ts_rhs)?))),
(
ScalarValue::TimestampMicrosecond(Some(ts_lhs), tz_lhs),
ScalarValue::TimestampMicrosecond(Some(ts_rhs), tz_rhs),
) => {
let err = || {
DataFusionError::Execution(
"Overflow while converting microseconds to nanoseconds".to_string(),
)
};
ts_sub_to_interval::<NANOSECOND_MODE>(
ts_lhs.checked_mul(1_000).ok_or_else(err)?,
ts_rhs.checked_mul(1_000).ok_or_else(err)?,
tz_lhs.as_deref(),
tz_rhs.as_deref(),
)
},
ScalarValue::TimestampMicrosecond(Some(ts_lhs), _),
ScalarValue::TimestampMicrosecond(Some(ts_rhs), _),
) => Ok(ScalarValue::DurationMicrosecond(Some(ts_lhs.sub_checked(*ts_rhs)?))),
(
ScalarValue::TimestampNanosecond(Some(ts_lhs), tz_lhs),
ScalarValue::TimestampNanosecond(Some(ts_rhs), tz_rhs),
) => ts_sub_to_interval::<NANOSECOND_MODE>(
*ts_lhs,
*ts_rhs,
tz_lhs.as_deref(),
tz_rhs.as_deref(),
),
ScalarValue::TimestampNanosecond(Some(ts_lhs), _),
ScalarValue::TimestampNanosecond(Some(ts_rhs), _),
) => Ok(ScalarValue::DurationNanosecond(Some(ts_lhs.sub_checked(*ts_rhs)?))),
_ => impl_op_arithmetic!($LHS, $RHS, -)
}
};
Expand Down Expand Up @@ -1147,49 +1113,6 @@ pub const MDN_MODE: i8 = 2;

/// Value for a `const TIME_MODE: bool` generic parameter that selects the
/// millisecond-resolution code path of the timestamp arithmetic helpers.
pub const MILLISECOND_MODE: bool = false;
/// Value for a `const TIME_MODE: bool` generic parameter that selects the
/// nanosecond-resolution code path of the timestamp arithmetic helpers.
pub const NANOSECOND_MODE: bool = true;
/// This function computes subtracts `rhs_ts` from `lhs_ts`, taking timezones
/// into account when given. Units of the resulting interval is specified by
/// the constant `TIME_MODE`.
/// The default behavior of Datafusion is the following:
/// - When subtracting timestamps at seconds/milliseconds precision, the output
/// interval will have the type [`IntervalDayTimeType`].
/// - When subtracting timestamps at microseconds/nanoseconds precision, the
/// output interval will have the type [`IntervalMonthDayNanoType`].
fn ts_sub_to_interval<const TIME_MODE: bool>(
lhs_ts: i64,
rhs_ts: i64,
lhs_tz: Option<&str>,
rhs_tz: Option<&str>,
) -> Result<ScalarValue> {
let parsed_lhs_tz = parse_timezones(lhs_tz)?;
let parsed_rhs_tz = parse_timezones(rhs_tz)?;

let (naive_lhs, naive_rhs) =
calculate_naives::<TIME_MODE>(lhs_ts, parsed_lhs_tz, rhs_ts, parsed_rhs_tz)?;
let delta_secs = naive_lhs.signed_duration_since(naive_rhs);

match TIME_MODE {
MILLISECOND_MODE => {
let as_millisecs = delta_secs.num_milliseconds();
Ok(ScalarValue::new_interval_dt(
(as_millisecs / MILLISECS_IN_ONE_DAY) as i32,
(as_millisecs % MILLISECS_IN_ONE_DAY) as i32,
))
}
NANOSECOND_MODE => {
let as_nanosecs = delta_secs.num_nanoseconds().ok_or_else(|| {
DataFusionError::Execution(String::from(
"Can not compute timestamp differences with nanosecond precision",
))
})?;
Ok(ScalarValue::new_interval_mdn(
0,
(as_nanosecs / NANOSECS_IN_ONE_DAY) as i32,
as_nanosecs % NANOSECS_IN_ONE_DAY,
))
}
}
}

/// This function parses the timezone from string to Tz.
/// If it cannot parse or timezone field is [`None`], it returns [`None`].
Expand Down Expand Up @@ -1424,6 +1347,14 @@ where
ScalarValue::IntervalDayTime(Some(i)) => add_day_time(prior, *i, sign),
ScalarValue::IntervalYearMonth(Some(i)) => shift_months(prior, *i, sign),
ScalarValue::IntervalMonthDayNano(Some(i)) => add_m_d_nano(prior, *i, sign),
ScalarValue::DurationSecond(Some(v)) => prior.add(Duration::seconds(*v)),
ScalarValue::DurationMillisecond(Some(v)) => {
prior.add(Duration::milliseconds(*v))
}
ScalarValue::DurationMicrosecond(Some(v)) => {
prior.add(Duration::microseconds(*v))
}
ScalarValue::DurationNanosecond(Some(v)) => prior.add(Duration::nanoseconds(*v)),
other => Err(DataFusionError::Execution(format!(
"DateIntervalExpr does not support non-interval type {other:?}"
)))?,
Expand Down Expand Up @@ -1891,6 +1822,16 @@ impl ScalarValue {
DataType::Interval(IntervalUnit::MonthDayNano) => {
ScalarValue::IntervalMonthDayNano(Some(0))
}
DataType::Duration(TimeUnit::Second) => ScalarValue::DurationSecond(None),
DataType::Duration(TimeUnit::Millisecond) => {
ScalarValue::DurationMillisecond(None)
}
DataType::Duration(TimeUnit::Microsecond) => {
ScalarValue::DurationMicrosecond(None)
}
DataType::Duration(TimeUnit::Nanosecond) => {
ScalarValue::DurationNanosecond(None)
}
_ => {
return Err(DataFusionError::NotImplemented(format!(
"Can't create a zero scalar from data_type \"{datatype:?}\""
Expand Down Expand Up @@ -3191,6 +3132,20 @@ impl ScalarValue {
IntervalMonthDayNano
)
}

DataType::Duration(TimeUnit::Second) => {
typed_cast!(array, index, DurationSecondArray, DurationSecond)
}
DataType::Duration(TimeUnit::Millisecond) => {
typed_cast!(array, index, DurationMillisecondArray, DurationMillisecond)
}
DataType::Duration(TimeUnit::Microsecond) => {
typed_cast!(array, index, DurationMicrosecondArray, DurationMicrosecond)
}
DataType::Duration(TimeUnit::Nanosecond) => {
typed_cast!(array, index, DurationNanosecondArray, DurationNanosecond)
}

other => {
return Err(DataFusionError::NotImplemented(format!(
"Can't create a scalar from array of type \"{other:?}\""
Expand Down Expand Up @@ -3682,6 +3637,18 @@ impl TryFrom<&DataType> for ScalarValue {
DataType::Interval(IntervalUnit::MonthDayNano) => {
ScalarValue::IntervalMonthDayNano(None)
}

DataType::Duration(TimeUnit::Second) => ScalarValue::DurationSecond(None),
DataType::Duration(TimeUnit::Millisecond) => {
ScalarValue::DurationMillisecond(None)
}
DataType::Duration(TimeUnit::Microsecond) => {
ScalarValue::DurationMicrosecond(None)
}
DataType::Duration(TimeUnit::Nanosecond) => {
ScalarValue::DurationNanosecond(None)
}

DataType::Dictionary(index_type, value_type) => ScalarValue::Dictionary(
index_type.clone(),
Box::new(value_type.as_ref().try_into()?),
Expand Down Expand Up @@ -3944,7 +3911,7 @@ mod tests {
use std::sync::Arc;

use arrow::compute::kernels;
use arrow::compute::{self, concat, is_null};
use arrow::compute::{concat, is_null};
use arrow::datatypes::ArrowPrimitiveType;
use arrow::util::pretty::pretty_format_columns;
use arrow_array::ArrowNumericType;
Expand Down Expand Up @@ -4073,7 +4040,7 @@ mod tests {
let right_array = right.to_array();
let arrow_left_array = left_array.as_primitive::<T>();
let arrow_right_array = right_array.as_primitive::<T>();
let arrow_result = compute::add_checked(arrow_left_array, arrow_right_array);
let arrow_result = kernels::numeric::add(arrow_left_array, arrow_right_array);

assert_eq!(scalar_result.is_ok(), arrow_result.is_ok());
}
Expand Down
Loading

0 comments on commit 68a89b3

Please sign in to comment.