705 changes: 380 additions & 325 deletions native/Cargo.lock

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions native/Cargo.toml
@@ -34,12 +34,12 @@ edition = "2021"
 rust-version = "1.86"
 
 [workspace.dependencies]
-arrow = { version = "56.0.0", features = ["prettyprint", "ffi", "chrono-tz"] }
+arrow = { version = "57.0.0", features = ["prettyprint", "ffi", "chrono-tz"] }
 async-trait = { version = "0.1" }
 bytes = { version = "1.10.0" }
-parquet = { version = "56.2.0", default-features = false, features = ["experimental"] }
-datafusion = { version = "50.3.0", default-features = false, features = ["unicode_expressions", "crypto_expressions", "nested_expressions", "parquet"] }
-datafusion-spark = { version = "50.3.0" }
+parquet = { version = "57.0.0", default-features = false, features = ["experimental"] }
+datafusion = { version = "51.0.0", default-features = false, features = ["unicode_expressions", "crypto_expressions", "nested_expressions", "parquet"] }
+datafusion-spark = { version = "51.0.0" }
 datafusion-comet-spark-expr = { path = "spark-expr" }
 datafusion-comet-proto = { path = "proto" }
 chrono = { version = "0.4", default-features = false, features = ["clock"] }
2 changes: 1 addition & 1 deletion native/core/Cargo.toml
@@ -92,7 +92,7 @@ jni = { version = "0.21", features = ["invocation"] }
 lazy_static = "1.4"
 assertables = "9"
 hex = "0.4.3"
-datafusion-functions-nested = { version = "50.3.0" }
+datafusion-functions-nested = { version = "51.0.0" }
 
 [features]
 backtrace = ["datafusion/backtrace"]
2 changes: 1 addition & 1 deletion native/core/src/execution/planner.rs
@@ -2231,7 +2231,7 @@ impl PhysicalPlanner {
 partition_by,
 sort_phy_exprs,
 window_frame.into(),
-input_schema.as_ref(),
+input_schema,
 false, // TODO: Ignore nulls
 false, // TODO: Spark does not support DISTINCT ... OVER
 None,
4 changes: 2 additions & 2 deletions native/core/src/parquet/encryption_support.rs
@@ -54,7 +54,7 @@ impl EncryptionFactory for CometEncryptionFactory {
 _options: &EncryptionFactoryOptions,
 _schema: &SchemaRef,
 _file_path: &Path,
-) -> Result<Option<FileEncryptionProperties>, DataFusionError> {
+) -> Result<Option<Arc<FileEncryptionProperties>>, DataFusionError> {
 Err(DataFusionError::NotImplemented(
 "Comet does not support Parquet encryption yet."
 .parse()
@@ -69,7 +69,7 @@
 &self,
 options: &EncryptionFactoryOptions,
 file_path: &Path,
-) -> Result<Option<FileDecryptionProperties>, DataFusionError> {
+) -> Result<Option<Arc<FileDecryptionProperties>>, DataFusionError> {
 let config: CometEncryptionConfig = options.to_extension_options()?;
 
 let full_path: String = config.uri_base + file_path.as_ref();
2 changes: 1 addition & 1 deletion native/core/src/parquet/parquet_exec.rs
@@ -122,7 +122,7 @@ pub(crate) fn init_datasource_exec(
 object_store_url,
 file_source,
 )
-.with_projection(Some(projection_vector))
+.with_projection_indices(Some(projection_vector))
 .with_table_partition_cols(partition_fields)
 .build()
 }
6 changes: 3 additions & 3 deletions native/core/src/parquet/read/column.rs
@@ -331,19 +331,19 @@ impl ColumnReader {
 None
 };
 match unit {
-ParquetTimeUnit::MILLIS(_) => {
+ParquetTimeUnit::MILLIS => {
 typed_reader!(
 Int64TimestampMillisColumnReader,
 ArrowDataType::Timestamp(time_unit, time_zone)
 )
 }
-ParquetTimeUnit::MICROS(_) => {
+ParquetTimeUnit::MICROS => {
 typed_reader!(
 Int64TimestampMicrosColumnReader,
 ArrowDataType::Timestamp(time_unit, time_zone)
 )
 }
-ParquetTimeUnit::NANOS(_) => {
+ParquetTimeUnit::NANOS => {
 typed_reader!(
 Int64TimestampNanosColumnReader,
 ArrowDataType::Int64
7 changes: 3 additions & 4 deletions native/core/src/parquet/util/jni.rs
@@ -30,7 +30,6 @@ use datafusion::execution::object_store::ObjectStoreUrl;
 use object_store::path::Path;
 use parquet::{
 basic::{Encoding, LogicalType, TimeUnit, Type as PhysicalType},
-format::{MicroSeconds, MilliSeconds, NanoSeconds},
 schema::types::{ColumnDescriptor, ColumnPath, PrimitiveTypeBuilder},
 };
 use url::{ParseError, Url};
@@ -185,9 +184,9 @@ fn convert_logical_type(
 
 fn convert_time_unit(time_unit: jint) -> TimeUnit {
 match time_unit {
-0 => TimeUnit::MILLIS(MilliSeconds::new()),
-1 => TimeUnit::MICROS(MicroSeconds::new()),
-2 => TimeUnit::NANOS(NanoSeconds::new()),
+0 => TimeUnit::MILLIS,
+1 => TimeUnit::MICROS,
+2 => TimeUnit::NANOS,
 _ => panic!("Invalid time unit id for Parquet: {time_unit}"),
 }
 }
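A minimal usage sketch of the updated convert_time_unit (illustrative only, not part of this diff), assuming the fieldless parquet::basic::TimeUnit variants shown above:

assert!(matches!(convert_time_unit(0), TimeUnit::MILLIS));
assert!(matches!(convert_time_unit(1), TimeUnit::MICROS));
assert!(matches!(convert_time_unit(2), TimeUnit::NANOS));
// Any other id panics: "Invalid time unit id for Parquet: 3"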
1 change: 1 addition & 0 deletions native/spark-expr/Cargo.toml
@@ -45,6 +45,7 @@ arrow = {workspace = true}
 criterion = { version = "0.7", features = ["async", "async_tokio", "async_std"] }
 rand = { workspace = true}
 tokio = { version = "1", features = ["rt-multi-thread"] }
+datafusion = { workspace = true, features = ["sql"] }
 
 [lib]
 name = "datafusion_comet_spark_expr"
10 changes: 8 additions & 2 deletions native/spark-expr/src/agg_funcs/avg.rs
@@ -25,8 +25,7 @@ use arrow::compute::sum;
 use arrow::datatypes::{DataType, Field, FieldRef};
 use datafusion::common::{not_impl_err, Result, ScalarValue};
 use datafusion::logical_expr::{
-type_coercion::aggregates::avg_return_type, Accumulator, AggregateUDFImpl, EmitTo,
-GroupsAccumulator, ReversedUDAF, Signature,
+Accumulator, AggregateUDFImpl, EmitTo, GroupsAccumulator, ReversedUDAF, Signature,
 };
 use datafusion::physical_expr::expressions::format_state_name;
 use std::{any::Any, sync::Arc};
@@ -36,6 +35,13 @@ use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs};
 use datafusion::logical_expr::Volatility::Immutable;
 use DataType::*;
 
+fn avg_return_type(_name: &str, data_type: &DataType) -> Result<DataType> {
+match data_type {
+Float64 => Ok(Float64),
+_ => not_impl_err!("Avg return type for {data_type}"),
+}
+}
+
 /// AVG aggregate expression
 #[derive(Debug, Clone, PartialEq, Eq, Hash)]
 pub struct Avg {
19 changes: 17 additions & 2 deletions native/spark-expr/src/agg_funcs/avg_decimal.rs
@@ -32,13 +32,28 @@ use std::{any::Any, sync::Arc};
 
 use crate::utils::{build_bool_state, is_valid_decimal_precision, unlikely};
 use arrow::array::ArrowNativeTypeOp;
-use arrow::datatypes::{MAX_DECIMAL128_FOR_EACH_PRECISION, MIN_DECIMAL128_FOR_EACH_PRECISION};
+use arrow::datatypes::{
+DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, MAX_DECIMAL128_FOR_EACH_PRECISION,
+MIN_DECIMAL128_FOR_EACH_PRECISION,
+};
 use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs};
-use datafusion::logical_expr::type_coercion::aggregates::avg_return_type;
 use datafusion::logical_expr::Volatility::Immutable;
 use num::{integer::div_ceil, Integer};
 use DataType::*;
 
+fn avg_return_type(_name: &str, data_type: &DataType) -> Result<DataType> {
+match data_type {
+Decimal128(precision, scale) => {
+// In Spark, the result type is DECIMAL(min(38, precision + 4), min(38, scale + 4)).
+// Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
+let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 4);
+let new_scale = DECIMAL128_MAX_SCALE.min(*scale + 4);
+Ok(Decimal128(new_precision, new_scale))
+}
+_ => not_impl_err!("Avg return type for {data_type}"),
+}
+}
+
 /// AVG aggregate expression
 #[derive(Debug, Clone, PartialEq, Eq, Hash)]
 pub struct AvgDecimal {
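A worked illustration (not part of this diff) of the result-type rule implemented by the new avg_return_type above, assuming arrow's DECIMAL128_MAX_PRECISION and DECIMAL128_MAX_SCALE are both 38:

// precision and scale each grow by 4, capped at 38
assert_eq!(avg_return_type("avg", &Decimal128(10, 2))?, Decimal128(14, 6));
assert_eq!(avg_return_type("avg", &Decimal128(36, 10))?, Decimal128(38, 14));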
2 changes: 1 addition & 1 deletion native/spark-expr/src/math_funcs/internal/checkoverflow.rs
@@ -136,7 +136,7 @@ impl PhysicalExpr for CheckOverflow {
 );
 
 let new_v: Option<i128> = v.and_then(|v| {
-Decimal128Type::validate_decimal_precision(v, precision)
+Decimal128Type::validate_decimal_precision(v, precision, scale)
 .map(|_| v)
 .ok()
 });
10 changes: 6 additions & 4 deletions native/spark-expr/src/math_funcs/internal/make_decimal.rs
@@ -34,7 +34,7 @@ pub fn spark_make_decimal(
 match &args[0] {
 ColumnarValue::Scalar(v) => match v {
 ScalarValue::Int64(n) => Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(
-long_to_decimal(n, precision),
+long_to_decimal(n, precision, scale),
 precision,
 scale,
 ))),
@@ -44,7 +44,7 @@
 let arr = a.as_primitive::<Int64Type>();
 let mut result = Decimal128Builder::new();
 for v in arr.into_iter() {
-result.append_option(long_to_decimal(&v, precision))
+result.append_option(long_to_decimal(&v, precision, scale))
 }
 let result_type = DataType::Decimal128(precision, scale);
 
@@ -58,9 +58,11 @@
 /// Convert the input long to decimal with the given maximum precision. If overflows, returns null
 /// instead.
 #[inline]
-fn long_to_decimal(v: &Option<i64>, precision: u8) -> Option<i128> {
+fn long_to_decimal(v: &Option<i64>, precision: u8, scale: i8) -> Option<i128> {
 match v {
-Some(v) if validate_decimal_precision(*v as i128, precision).is_ok() => Some(*v as i128),
+Some(v) if validate_decimal_precision(*v as i128, precision, scale).is_ok() => {
+Some(*v as i128)
+}
 _ => None,
 }
 }
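A hypothetical illustration (not part of this diff) of the overflow-to-null behavior described in long_to_decimal's doc comment, using the new three-argument form:

// 12345 has five significant digits, so it fits DECIMAL(5, 0) but not DECIMAL(4, 0)
assert_eq!(long_to_decimal(&Some(12345), 5, 0), Some(12345_i128));
assert_eq!(long_to_decimal(&Some(12345), 4, 0), None);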