Skip to content

Commit

Permalink
fix: cast literal to timestamp (#5517)
Browse files Browse the repository at this point in the history
* fix: cast literal to timestamp

* update tests for all transformation

* handle cast between same type

* refactor cast_between_timestamp to avoid overflow

* handle overflow to None
  • Loading branch information
Weijun-H authored Mar 13, 2023
1 parent 9464bf2 commit 1f8ede5
Showing 1 changed file with 210 additions and 4 deletions.
214 changes: 210 additions & 4 deletions datafusion/optimizer/src/unwrap_cast_in_comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,15 @@ use crate::{OptimizerConfig, OptimizerRule};
use arrow::datatypes::{
DataType, TimeUnit, MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION,
};
use arrow::temporal_conversions::{MICROSECONDS, MILLISECONDS, NANOSECONDS};
use datafusion_common::{DFSchemaRef, DataFusionError, Result, ScalarValue};
use datafusion_expr::expr::{BinaryExpr, Cast, TryCast};
use datafusion_expr::expr_rewriter::{ExprRewriter, RewriteRecursion};
use datafusion_expr::utils::from_plan;
use datafusion_expr::{
binary_expr, in_list, lit, Expr, ExprSchemable, LogicalPlan, Operator,
};
use std::cmp::Ordering;
use std::sync::Arc;

/// [`UnwrapCastInComparison`] attempts to remove casts from
Expand Down Expand Up @@ -400,16 +402,36 @@ fn try_cast_literal_to_type(
DataType::UInt32 => ScalarValue::UInt32(Some(value as u32)),
DataType::UInt64 => ScalarValue::UInt64(Some(value as u64)),
DataType::Timestamp(TimeUnit::Second, tz) => {
ScalarValue::TimestampSecond(Some(value as i64), tz.clone())
let value = cast_between_timestamp(
lit_data_type,
DataType::Timestamp(TimeUnit::Second, tz.clone()),
value,
);
ScalarValue::TimestampSecond(value, tz.clone())
}
DataType::Timestamp(TimeUnit::Millisecond, tz) => {
ScalarValue::TimestampMillisecond(Some(value as i64), tz.clone())
let value = cast_between_timestamp(
lit_data_type,
DataType::Timestamp(TimeUnit::Millisecond, tz.clone()),
value,
);
ScalarValue::TimestampMillisecond(value, tz.clone())
}
DataType::Timestamp(TimeUnit::Microsecond, tz) => {
ScalarValue::TimestampMicrosecond(Some(value as i64), tz.clone())
let value = cast_between_timestamp(
lit_data_type,
DataType::Timestamp(TimeUnit::Microsecond, tz.clone()),
value,
);
ScalarValue::TimestampMicrosecond(value, tz.clone())
}
DataType::Timestamp(TimeUnit::Nanosecond, tz) => {
ScalarValue::TimestampNanosecond(Some(value as i64), tz.clone())
let value = cast_between_timestamp(
lit_data_type,
DataType::Timestamp(TimeUnit::Nanosecond, tz.clone()),
value,
);
ScalarValue::TimestampNanosecond(value, tz.clone())
}
DataType::Decimal128(p, s) => {
ScalarValue::Decimal128(Some(value), *p, *s)
Expand All @@ -428,6 +450,32 @@ fn try_cast_literal_to_type(
}
}

/// Cast a timestamp value from one unit to another
fn cast_between_timestamp(from: DataType, to: DataType, value: i128) -> Option<i64> {
let value = value as i64;
let from_scale = match from {
DataType::Timestamp(TimeUnit::Second, _) => 1,
DataType::Timestamp(TimeUnit::Millisecond, _) => MILLISECONDS,
DataType::Timestamp(TimeUnit::Microsecond, _) => MICROSECONDS,
DataType::Timestamp(TimeUnit::Nanosecond, _) => NANOSECONDS,
_ => return Some(value),
};

let to_scale = match to {
DataType::Timestamp(TimeUnit::Second, _) => 1,
DataType::Timestamp(TimeUnit::Millisecond, _) => MILLISECONDS,
DataType::Timestamp(TimeUnit::Microsecond, _) => MICROSECONDS,
DataType::Timestamp(TimeUnit::Nanosecond, _) => NANOSECONDS,
_ => return Some(value),
};

match from_scale.cmp(&to_scale) {
Ordering::Less => value.checked_mul(to_scale / from_scale),
Ordering::Greater => Some(value / (from_scale / to_scale)),
Ordering::Equal => Some(value),
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -1070,4 +1118,162 @@ mod tests {
}
}
}

#[test]
fn test_try_cast_literal_to_timestamp() {
// same timestamp
let new_scalar = try_cast_literal_to_type(
&ScalarValue::TimestampNanosecond(Some(123456), None),
&DataType::Timestamp(TimeUnit::Nanosecond, None),
)
.unwrap()
.unwrap();

assert_eq!(
new_scalar,
ScalarValue::TimestampNanosecond(Some(123456), None)
);

// TimestampNanosecond to TimestampMicrosecond
let new_scalar = try_cast_literal_to_type(
&ScalarValue::TimestampNanosecond(Some(123456), None),
&DataType::Timestamp(TimeUnit::Microsecond, None),
)
.unwrap()
.unwrap();

assert_eq!(
new_scalar,
ScalarValue::TimestampMicrosecond(Some(123), None)
);

// TimestampNanosecond to TimestampMillisecond
let new_scalar = try_cast_literal_to_type(
&ScalarValue::TimestampNanosecond(Some(123456), None),
&DataType::Timestamp(TimeUnit::Millisecond, None),
)
.unwrap()
.unwrap();

assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(Some(0), None));

// TimestampNanosecond to TimestampSecond
let new_scalar = try_cast_literal_to_type(
&ScalarValue::TimestampNanosecond(Some(123456), None),
&DataType::Timestamp(TimeUnit::Second, None),
)
.unwrap()
.unwrap();

assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(0), None));

// TimestampMicrosecond to TimestampNanosecond
let new_scalar = try_cast_literal_to_type(
&ScalarValue::TimestampMicrosecond(Some(123), None),
&DataType::Timestamp(TimeUnit::Nanosecond, None),
)
.unwrap()
.unwrap();

assert_eq!(
new_scalar,
ScalarValue::TimestampNanosecond(Some(123000), None)
);

// TimestampMicrosecond to TimestampMillisecond
let new_scalar = try_cast_literal_to_type(
&ScalarValue::TimestampMicrosecond(Some(123), None),
&DataType::Timestamp(TimeUnit::Millisecond, None),
)
.unwrap()
.unwrap();

assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(Some(0), None));

// TimestampMicrosecond to TimestampSecond
let new_scalar = try_cast_literal_to_type(
&ScalarValue::TimestampMicrosecond(Some(123456789), None),
&DataType::Timestamp(TimeUnit::Second, None),
)
.unwrap()
.unwrap();
assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(123), None));

// TimestampMillisecond to TimestampNanosecond
let new_scalar = try_cast_literal_to_type(
&ScalarValue::TimestampMillisecond(Some(123), None),
&DataType::Timestamp(TimeUnit::Nanosecond, None),
)
.unwrap()
.unwrap();
assert_eq!(
new_scalar,
ScalarValue::TimestampNanosecond(Some(123000000), None)
);

// TimestampMillisecond to TimestampMicrosecond
let new_scalar = try_cast_literal_to_type(
&ScalarValue::TimestampMillisecond(Some(123), None),
&DataType::Timestamp(TimeUnit::Microsecond, None),
)
.unwrap()
.unwrap();
assert_eq!(
new_scalar,
ScalarValue::TimestampMicrosecond(Some(123000), None)
);
// TimestampMillisecond to TimestampSecond
let new_scalar = try_cast_literal_to_type(
&ScalarValue::TimestampMillisecond(Some(123456789), None),
&DataType::Timestamp(TimeUnit::Second, None),
)
.unwrap()
.unwrap();
assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(123456), None));

// TimestampSecond to TimestampNanosecond
let new_scalar = try_cast_literal_to_type(
&ScalarValue::TimestampSecond(Some(123), None),
&DataType::Timestamp(TimeUnit::Nanosecond, None),
)
.unwrap()
.unwrap();
assert_eq!(
new_scalar,
ScalarValue::TimestampNanosecond(Some(123000000000), None)
);

// TimestampSecond to TimestampMicrosecond
let new_scalar = try_cast_literal_to_type(
&ScalarValue::TimestampSecond(Some(123), None),
&DataType::Timestamp(TimeUnit::Microsecond, None),
)
.unwrap()
.unwrap();
assert_eq!(
new_scalar,
ScalarValue::TimestampMicrosecond(Some(123000000), None)
);

// TimestampSecond to TimestampMillisecond
let new_scalar = try_cast_literal_to_type(
&ScalarValue::TimestampSecond(Some(123), None),
&DataType::Timestamp(TimeUnit::Millisecond, None),
)
.unwrap()
.unwrap();
assert_eq!(
new_scalar,
ScalarValue::TimestampMillisecond(Some(123000), None)
);

// overflow
let new_scalar = try_cast_literal_to_type(
&ScalarValue::TimestampSecond(Some(i64::MAX), None),
&DataType::Timestamp(TimeUnit::Millisecond, None),
)
.unwrap()
.unwrap();
assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(None, None));
}
}

0 comments on commit 1f8ede5

Please sign in to comment.