diff --git a/datafusion-cli/src/lib.rs b/datafusion-cli/src/lib.rs index 34fba6f79304..f0b0bc23fd73 100644 --- a/datafusion-cli/src/lib.rs +++ b/datafusion-cli/src/lib.rs @@ -19,7 +19,7 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] #![doc = include_str!("../README.md")] pub const DATAFUSION_CLI_VERSION: &str = env!("CARGO_PKG_VERSION"); diff --git a/datafusion/catalog-listing/src/mod.rs b/datafusion/catalog-listing/src/mod.rs index fb0a960f37b6..1322577b207a 100644 --- a/datafusion/catalog-listing/src/mod.rs +++ b/datafusion/catalog-listing/src/mod.rs @@ -19,7 +19,7 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![cfg_attr(not(test), deny(clippy::clone_on_ref_ptr))] diff --git a/datafusion/catalog/src/lib.rs b/datafusion/catalog/src/lib.rs index 0394b05277da..1c5e38438724 100644 --- a/datafusion/catalog/src/lib.rs +++ b/datafusion/catalog/src/lib.rs @@ -19,7 +19,7 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![cfg_attr(not(test), deny(clippy::clone_on_ref_ptr))] diff --git a/datafusion/common-runtime/src/lib.rs b/datafusion/common-runtime/src/lib.rs index ec8db0bdcd91..a9a7432c8cfc 100644 --- a/datafusion/common-runtime/src/lib.rs +++ b/datafusion/common-runtime/src/lib.rs @@ -19,7 +19,7 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] diff --git a/datafusion/common/src/cast.rs b/datafusion/common/src/cast.rs index 68b753a6678a..791c2d16aeb5 100644 --- a/datafusion/common/src/cast.rs +++ b/datafusion/common/src/cast.rs @@ -22,9 +22,10 @@ use crate::{downcast_value, Result}; use arrow::array::{ - BinaryViewArray, DurationMicrosecondArray, DurationMillisecondArray, - DurationNanosecondArray, DurationSecondArray, Float16Array, Int16Array, Int8Array, - LargeBinaryArray, LargeStringArray, StringViewArray, UInt16Array, + BinaryViewArray, Decimal32Array, Decimal64Array, DurationMicrosecondArray, + DurationMillisecondArray, DurationNanosecondArray, DurationSecondArray, Float16Array, + Int16Array, Int8Array, LargeBinaryArray, LargeStringArray, StringViewArray, + UInt16Array, }; use arrow::{ array::{ @@ -97,6 +98,16 @@ pub fn as_uint64_array(array: &dyn Array) -> Result<&UInt64Array> { Ok(downcast_value!(array, UInt64Array)) } +// Downcast Array to Decimal32Array +pub fn as_decimal32_array(array: &dyn Array) -> Result<&Decimal32Array> { + Ok(downcast_value!(array, Decimal32Array)) +} + +// Downcast Array to Decimal64Array +pub fn as_decimal64_array(array: &dyn Array) -> Result<&Decimal64Array> { + Ok(downcast_value!(array, Decimal64Array)) +} + // Downcast Array to Decimal128Array pub fn as_decimal128_array(array: &dyn Array) -> Result<&Decimal128Array> { Ok(downcast_value!(array, Decimal128Array)) diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs index f9e3b2cee40d..31b77727f6cf 100644 --- a/datafusion/common/src/dfschema.rs +++ b/datafusion/common/src/dfschema.rs @@ -798,6 +798,14 @@ impl DFSchema { .zip(iter2) .all(|((t1, f1), (t2, f2))| t1 == t2 && Self::field_is_semantically_equal(f1, f2)) } + ( + DataType::Decimal32(_l_precision, _l_scale), + DataType::Decimal32(_r_precision, _r_scale), + ) => true, + ( + DataType::Decimal64(_l_precision, _l_scale), + DataType::Decimal64(_r_precision, _r_scale), + ) => true, ( DataType::Decimal128(_l_precision, _l_scale), DataType::Decimal128(_r_precision, _r_scale), @@ -863,6 +871,214 @@ impl DFSchema { .zip(self.inner.fields().iter()) .map(|(qualifier, field)| (qualifier.as_ref(), field)) } + /// Print schema in tree format + /// + /// This method formats the schema + /// with a tree-like structure showing field names, types, and nullability. + /// + /// # Example + /// + /// ``` + /// use datafusion_common::DFSchema; + /// use arrow::datatypes::{DataType, Field, Schema}; + /// use std::collections::HashMap; + /// + /// let schema = DFSchema::from_unqualified_fields( + /// vec![ + /// Field::new("id", DataType::Int32, false), + /// Field::new("name", DataType::Utf8, true), + /// ].into(), + /// HashMap::new() + /// ).unwrap(); + /// + /// assert_eq!(schema.print_schema_tree().to_string(), + /// r#"root + /// |-- id: int32 (nullable = false) + /// |-- name: utf8 (nullable = true)"#); + /// ``` + pub fn print_schema_tree(&self) -> impl Display + '_ { + let mut result = String::from("root\n"); + + for (qualifier, field) in self.iter() { + let field_name = match qualifier { + Some(q) => format!("{}.{}", q, field.name()), + None => field.name().to_string(), + }; + + format_field_with_indent( + &mut result, + &field_name, + field.data_type(), + field.is_nullable(), + " ", + ); + } + + // Remove the trailing newline + if result.ends_with('\n') { + result.pop(); + } + + result + } +} + +/// Format field with proper nested indentation for complex types +fn format_field_with_indent( + result: &mut String, + field_name: &str, + data_type: &DataType, + nullable: bool, + indent: &str, +) { + let nullable_str = nullable.to_string().to_lowercase(); + let child_indent = format!("{indent}| "); + + match data_type { + DataType::List(field) => { + result.push_str(&format!( + "{indent}|-- {field_name}: list (nullable = {nullable_str})\n" + )); + format_field_with_indent( + result, + field.name(), + field.data_type(), + field.is_nullable(), + &child_indent, + ); + } + DataType::LargeList(field) => { + result.push_str(&format!( + "{indent}|-- {field_name}: large list (nullable = {nullable_str})\n" + )); + format_field_with_indent( + result, + field.name(), + field.data_type(), + field.is_nullable(), + &child_indent, + ); + } + DataType::FixedSizeList(field, _size) => { + result.push_str(&format!( + "{indent}|-- {field_name}: fixed size list (nullable = {nullable_str})\n" + )); + format_field_with_indent( + result, + field.name(), + field.data_type(), + field.is_nullable(), + &child_indent, + ); + } + DataType::Map(field, _) => { + result.push_str(&format!( + "{indent}|-- {field_name}: map (nullable = {nullable_str})\n" + )); + if let DataType::Struct(inner_fields) = field.data_type() { + if inner_fields.len() == 2 { + format_field_with_indent( + result, + "key", + inner_fields[0].data_type(), + inner_fields[0].is_nullable(), + &child_indent, + ); + let value_contains_null = + field.is_nullable().to_string().to_lowercase(); + // Handle complex value types properly + match inner_fields[1].data_type() { + DataType::Struct(_) + | DataType::List(_) + | DataType::LargeList(_) + | DataType::FixedSizeList(_, _) + | DataType::Map(_, _) => { + format_field_with_indent( + result, + "value", + inner_fields[1].data_type(), + inner_fields[1].is_nullable(), + &child_indent, + ); + } + _ => { + result.push_str(&format!("{child_indent}|-- value: {} (nullable = {value_contains_null})\n", + format_simple_data_type(inner_fields[1].data_type()))); + } + } + } + } + } + DataType::Struct(fields) => { + result.push_str(&format!( + "{indent}|-- {field_name}: struct (nullable = {nullable_str})\n" + )); + for struct_field in fields { + format_field_with_indent( + result, + struct_field.name(), + struct_field.data_type(), + struct_field.is_nullable(), + &child_indent, + ); + } + } + _ => { + let type_str = format_simple_data_type(data_type); + result.push_str(&format!( + "{indent}|-- {field_name}: {type_str} (nullable = {nullable_str})\n" + )); + } + } +} + +/// Format simple DataType in lowercase format (for leaf nodes) +fn format_simple_data_type(data_type: &DataType) -> String { + match data_type { + DataType::Boolean => "boolean".to_string(), + DataType::Int8 => "int8".to_string(), + DataType::Int16 => "int16".to_string(), + DataType::Int32 => "int32".to_string(), + DataType::Int64 => "int64".to_string(), + DataType::UInt8 => "uint8".to_string(), + DataType::UInt16 => "uint16".to_string(), + DataType::UInt32 => "uint32".to_string(), + DataType::UInt64 => "uint64".to_string(), + DataType::Float16 => "float16".to_string(), + DataType::Float32 => "float32".to_string(), + DataType::Float64 => "float64".to_string(), + DataType::Utf8 => "utf8".to_string(), + DataType::LargeUtf8 => "large_utf8".to_string(), + DataType::Binary => "binary".to_string(), + DataType::LargeBinary => "large_binary".to_string(), + DataType::FixedSizeBinary(_) => "fixed_size_binary".to_string(), + DataType::Date32 => "date32".to_string(), + DataType::Date64 => "date64".to_string(), + DataType::Time32(_) => "time32".to_string(), + DataType::Time64(_) => "time64".to_string(), + DataType::Timestamp(_, tz) => match tz { + Some(tz_str) => format!("timestamp ({tz_str})"), + None => "timestamp".to_string(), + }, + DataType::Interval(_) => "interval".to_string(), + DataType::Dictionary(_, value_type) => { + format_simple_data_type(value_type.as_ref()) + } + DataType::Decimal32(precision, scale) => { + format!("decimal32({precision}, {scale})") + } + DataType::Decimal64(precision, scale) => { + format!("decimal64({precision}, {scale})") + } + DataType::Decimal128(precision, scale) => { + format!("decimal128({precision}, {scale})") + } + DataType::Decimal256(precision, scale) => { + format!("decimal256({precision}, {scale})") + } + DataType::Null => "null".to_string(), + _ => format!("{data_type}").to_lowercase(), + } } impl From for Schema { @@ -1596,6 +1812,27 @@ mod tests { &DataType::Int16 )); + // Succeeds if decimal precision and scale are different + assert!(DFSchema::datatype_is_semantically_equal( + &DataType::Decimal32(1, 2), + &DataType::Decimal32(2, 1), + )); + + assert!(DFSchema::datatype_is_semantically_equal( + &DataType::Decimal64(1, 2), + &DataType::Decimal64(2, 1), + )); + + assert!(DFSchema::datatype_is_semantically_equal( + &DataType::Decimal128(1, 2), + &DataType::Decimal128(2, 1), + )); + + assert!(DFSchema::datatype_is_semantically_equal( + &DataType::Decimal256(1, 2), + &DataType::Decimal256(2, 1), + )); + // Test lists // Succeeds if both have the same element type, disregards names and nullability @@ -1738,4 +1975,488 @@ mod tests { fn test_metadata_n(n: usize) -> HashMap { (0..n).map(|i| (format!("k{i}"), format!("v{i}"))).collect() } + + #[test] + fn test_print_schema_unqualified() { + let schema = DFSchema::from_unqualified_fields( + vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + Field::new("age", DataType::Int64, true), + Field::new("active", DataType::Boolean, false), + ] + .into(), + HashMap::new(), + ) + .unwrap(); + + let output = schema.print_schema_tree(); + + insta::assert_snapshot!(output, @r" + root + |-- id: int32 (nullable = false) + |-- name: utf8 (nullable = true) + |-- age: int64 (nullable = true) + |-- active: boolean (nullable = false) + "); + } + + #[test] + fn test_print_schema_qualified() { + let schema = DFSchema::try_from_qualified_schema( + "table1", + &Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + ]), + ) + .unwrap(); + + let output = schema.print_schema_tree(); + + insta::assert_snapshot!(output, @r" + root + |-- table1.id: int32 (nullable = false) + |-- table1.name: utf8 (nullable = true) + "); + } + + #[test] + fn test_print_schema_complex_types() { + let struct_field = Field::new( + "address", + DataType::Struct(Fields::from(vec![ + Field::new("street", DataType::Utf8, true), + Field::new("city", DataType::Utf8, true), + ])), + true, + ); + + let list_field = Field::new( + "tags", + DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))), + true, + ); + + let schema = DFSchema::from_unqualified_fields( + vec![ + Field::new("id", DataType::Int32, false), + struct_field, + list_field, + Field::new("score", DataType::Decimal128(10, 2), true), + ] + .into(), + HashMap::new(), + ) + .unwrap(); + + let output = schema.print_schema_tree(); + insta::assert_snapshot!(output, @r" + root + |-- id: int32 (nullable = false) + |-- address: struct (nullable = true) + | |-- street: utf8 (nullable = true) + | |-- city: utf8 (nullable = true) + |-- tags: list (nullable = true) + | |-- item: utf8 (nullable = true) + |-- score: decimal128(10, 2) (nullable = true) + "); + } + + #[test] + fn test_print_schema_empty() { + let schema = DFSchema::empty(); + let output = schema.print_schema_tree(); + insta::assert_snapshot!(output, @r###"root"###); + } + + #[test] + fn test_print_schema_deeply_nested_types() { + // Create a deeply nested structure to test indentation and complex type formatting + let inner_struct = Field::new( + "inner", + DataType::Struct(Fields::from(vec![ + Field::new("level1", DataType::Utf8, true), + Field::new("level2", DataType::Int32, false), + ])), + true, + ); + + let nested_list = Field::new( + "nested_list", + DataType::List(Arc::new(Field::new( + "item", + DataType::Struct(Fields::from(vec![ + Field::new("id", DataType::Int64, false), + Field::new("value", DataType::Float64, true), + ])), + true, + ))), + true, + ); + + let map_field = Field::new( + "map_data", + DataType::Map( + Arc::new(Field::new( + "entries", + DataType::Struct(Fields::from(vec![ + Field::new("key", DataType::Utf8, false), + Field::new( + "value", + DataType::List(Arc::new(Field::new( + "item", + DataType::Int32, + true, + ))), + true, + ), + ])), + false, + )), + false, + ), + true, + ); + + let schema = DFSchema::from_unqualified_fields( + vec![ + Field::new("simple_field", DataType::Utf8, true), + inner_struct, + nested_list, + map_field, + Field::new( + "timestamp_field", + DataType::Timestamp( + arrow::datatypes::TimeUnit::Microsecond, + Some("UTC".into()), + ), + false, + ), + ] + .into(), + HashMap::new(), + ) + .unwrap(); + + let output = schema.print_schema_tree(); + + insta::assert_snapshot!(output, @r" + root + |-- simple_field: utf8 (nullable = true) + |-- inner: struct (nullable = true) + | |-- level1: utf8 (nullable = true) + | |-- level2: int32 (nullable = false) + |-- nested_list: list (nullable = true) + | |-- item: struct (nullable = true) + | | |-- id: int64 (nullable = false) + | | |-- value: float64 (nullable = true) + |-- map_data: map (nullable = true) + | |-- key: utf8 (nullable = false) + | |-- value: list (nullable = true) + | | |-- item: int32 (nullable = true) + |-- timestamp_field: timestamp (UTC) (nullable = false) + "); + } + + #[test] + fn test_print_schema_mixed_qualified_unqualified() { + // Test a schema with mixed qualified and unqualified fields + let schema = DFSchema::new_with_metadata( + vec![ + ( + Some("table1".into()), + Arc::new(Field::new("id", DataType::Int32, false)), + ), + (None, Arc::new(Field::new("name", DataType::Utf8, true))), + ( + Some("table2".into()), + Arc::new(Field::new("score", DataType::Float64, true)), + ), + ( + None, + Arc::new(Field::new("active", DataType::Boolean, false)), + ), + ], + HashMap::new(), + ) + .unwrap(); + + let output = schema.print_schema_tree(); + + insta::assert_snapshot!(output, @r" + root + |-- table1.id: int32 (nullable = false) + |-- name: utf8 (nullable = true) + |-- table2.score: float64 (nullable = true) + |-- active: boolean (nullable = false) + "); + } + + #[test] + fn test_print_schema_array_of_map() { + // Test the specific example from user feedback: array of map + let map_field = Field::new( + "entries", + DataType::Struct(Fields::from(vec![ + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Utf8, false), + ])), + false, + ); + + let array_of_map_field = Field::new( + "array_map_field", + DataType::List(Arc::new(Field::new( + "item", + DataType::Map(Arc::new(map_field), false), + false, + ))), + false, + ); + + let schema = DFSchema::from_unqualified_fields( + vec![array_of_map_field].into(), + HashMap::new(), + ) + .unwrap(); + + let output = schema.print_schema_tree(); + + insta::assert_snapshot!(output, @r" + root + |-- array_map_field: list (nullable = false) + | |-- item: map (nullable = false) + | | |-- key: utf8 (nullable = false) + | | |-- value: utf8 (nullable = false) + "); + } + + #[test] + fn test_print_schema_complex_type_combinations() { + // Test various combinations of list, struct, and map types + + // List of structs + let list_of_structs = Field::new( + "list_of_structs", + DataType::List(Arc::new(Field::new( + "item", + DataType::Struct(Fields::from(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + Field::new("score", DataType::Float64, true), + ])), + true, + ))), + true, + ); + + // Struct containing lists + let struct_with_lists = Field::new( + "struct_with_lists", + DataType::Struct(Fields::from(vec![ + Field::new( + "tags", + DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))), + true, + ), + Field::new( + "scores", + DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + false, + ), + Field::new("metadata", DataType::Utf8, true), + ])), + false, + ); + + // Map with struct values + let map_with_struct_values = Field::new( + "map_with_struct_values", + DataType::Map( + Arc::new(Field::new( + "entries", + DataType::Struct(Fields::from(vec![ + Field::new("key", DataType::Utf8, false), + Field::new( + "value", + DataType::Struct(Fields::from(vec![ + Field::new("count", DataType::Int64, false), + Field::new("active", DataType::Boolean, true), + ])), + true, + ), + ])), + false, + )), + false, + ), + true, + ); + + // List of maps + let list_of_maps = Field::new( + "list_of_maps", + DataType::List(Arc::new(Field::new( + "item", + DataType::Map( + Arc::new(Field::new( + "entries", + DataType::Struct(Fields::from(vec![ + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Int32, true), + ])), + false, + )), + false, + ), + true, + ))), + true, + ); + + // Deeply nested: struct containing list of structs containing maps + let deeply_nested = Field::new( + "deeply_nested", + DataType::Struct(Fields::from(vec![ + Field::new("level1", DataType::Utf8, true), + Field::new( + "level2", + DataType::List(Arc::new(Field::new( + "item", + DataType::Struct(Fields::from(vec![ + Field::new("id", DataType::Int32, false), + Field::new( + "properties", + DataType::Map( + Arc::new(Field::new( + "entries", + DataType::Struct(Fields::from(vec![ + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Float64, true), + ])), + false, + )), + false, + ), + true, + ), + ])), + true, + ))), + false, + ), + ])), + true, + ); + + let schema = DFSchema::from_unqualified_fields( + vec![ + list_of_structs, + struct_with_lists, + map_with_struct_values, + list_of_maps, + deeply_nested, + ] + .into(), + HashMap::new(), + ) + .unwrap(); + + let output = schema.print_schema_tree(); + + insta::assert_snapshot!(output, @r" + root + |-- list_of_structs: list (nullable = true) + | |-- item: struct (nullable = true) + | | |-- id: int32 (nullable = false) + | | |-- name: utf8 (nullable = true) + | | |-- score: float64 (nullable = true) + |-- struct_with_lists: struct (nullable = false) + | |-- tags: list (nullable = true) + | | |-- item: utf8 (nullable = true) + | |-- scores: list (nullable = false) + | | |-- item: int32 (nullable = true) + | |-- metadata: utf8 (nullable = true) + |-- map_with_struct_values: map (nullable = true) + | |-- key: utf8 (nullable = false) + | |-- value: struct (nullable = true) + | | |-- count: int64 (nullable = false) + | | |-- active: boolean (nullable = true) + |-- list_of_maps: list (nullable = true) + | |-- item: map (nullable = true) + | | |-- key: utf8 (nullable = false) + | | |-- value: int32 (nullable = false) + |-- deeply_nested: struct (nullable = true) + | |-- level1: utf8 (nullable = true) + | |-- level2: list (nullable = false) + | | |-- item: struct (nullable = true) + | | | |-- id: int32 (nullable = false) + | | | |-- properties: map (nullable = true) + | | | | |-- key: utf8 (nullable = false) + | | | | |-- value: float64 (nullable = false) + "); + } + + #[test] + fn test_print_schema_edge_case_types() { + // Test edge cases and special types + let schema = DFSchema::from_unqualified_fields( + vec![ + Field::new("null_field", DataType::Null, true), + Field::new("binary_field", DataType::Binary, false), + Field::new("large_binary", DataType::LargeBinary, true), + Field::new("large_utf8", DataType::LargeUtf8, false), + Field::new("fixed_size_binary", DataType::FixedSizeBinary(16), true), + Field::new( + "fixed_size_list", + DataType::FixedSizeList( + Arc::new(Field::new("item", DataType::Int32, true)), + 5, + ), + false, + ), + Field::new("decimal32", DataType::Decimal32(9, 4), true), + Field::new("decimal64", DataType::Decimal64(9, 4), true), + Field::new("decimal128", DataType::Decimal128(18, 4), true), + Field::new("decimal256", DataType::Decimal256(38, 10), false), + Field::new("date32", DataType::Date32, true), + Field::new("date64", DataType::Date64, false), + Field::new( + "time32_seconds", + DataType::Time32(arrow::datatypes::TimeUnit::Second), + true, + ), + Field::new( + "time64_nanoseconds", + DataType::Time64(arrow::datatypes::TimeUnit::Nanosecond), + false, + ), + ] + .into(), + HashMap::new(), + ) + .unwrap(); + + let output = schema.print_schema_tree(); + + insta::assert_snapshot!(output, @r" + root + |-- null_field: null (nullable = true) + |-- binary_field: binary (nullable = false) + |-- large_binary: large_binary (nullable = true) + |-- large_utf8: large_utf8 (nullable = false) + |-- fixed_size_binary: fixed_size_binary (nullable = true) + |-- fixed_size_list: fixed size list (nullable = false) + | |-- item: int32 (nullable = true) + |-- decimal32: decimal32(9, 4) (nullable = true) + |-- decimal64: decimal64(9, 4) (nullable = true) + |-- decimal128: decimal128(18, 4) (nullable = true) + |-- decimal256: decimal256(38, 10) (nullable = false) + |-- date32: date32 (nullable = true) + |-- date64: date64 (nullable = false) + |-- time32_seconds: time32 (nullable = true) + |-- time64_nanoseconds: time64 (nullable = false) + "); + } } diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index 3a558fa86789..8a7b765d205d 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -19,7 +19,7 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index 91058575723e..862d896353f4 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -35,13 +35,14 @@ use std::sync::Arc; use crate::cast::{ as_binary_array, as_binary_view_array, as_boolean_array, as_date32_array, - as_date64_array, as_decimal128_array, as_decimal256_array, as_dictionary_array, - as_duration_microsecond_array, as_duration_millisecond_array, - as_duration_nanosecond_array, as_duration_second_array, as_fixed_size_binary_array, - as_fixed_size_list_array, as_float16_array, as_float32_array, as_float64_array, - as_int16_array, as_int32_array, as_int64_array, as_int8_array, as_interval_dt_array, - as_interval_mdn_array, as_interval_ym_array, as_large_binary_array, - as_large_list_array, as_large_string_array, as_string_array, as_string_view_array, + as_date64_array, as_decimal128_array, as_decimal256_array, as_decimal32_array, + as_decimal64_array, as_dictionary_array, as_duration_microsecond_array, + as_duration_millisecond_array, as_duration_nanosecond_array, + as_duration_second_array, as_fixed_size_binary_array, as_fixed_size_list_array, + as_float16_array, as_float32_array, as_float64_array, as_int16_array, as_int32_array, + as_int64_array, as_int8_array, as_interval_dt_array, as_interval_mdn_array, + as_interval_ym_array, as_large_binary_array, as_large_list_array, + as_large_string_array, as_string_array, as_string_view_array, as_time32_millisecond_array, as_time32_second_array, as_time64_microsecond_array, as_time64_nanosecond_array, as_timestamp_microsecond_array, as_timestamp_millisecond_array, as_timestamp_nanosecond_array, @@ -56,17 +57,17 @@ use crate::{_internal_datafusion_err, arrow_datafusion_err}; use arrow::array::{ new_empty_array, new_null_array, Array, ArrayData, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, AsArray, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, - Date64Array, Decimal128Array, Decimal256Array, DictionaryArray, - DurationMicrosecondArray, DurationMillisecondArray, DurationNanosecondArray, - DurationSecondArray, FixedSizeBinaryArray, FixedSizeListArray, Float16Array, - Float32Array, Float64Array, GenericListArray, Int16Array, Int32Array, Int64Array, - Int8Array, IntervalDayTimeArray, IntervalMonthDayNanoArray, IntervalYearMonthArray, - LargeBinaryArray, LargeListArray, LargeStringArray, ListArray, MapArray, - MutableArrayData, PrimitiveArray, Scalar, StringArray, StringViewArray, StructArray, - Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray, - Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, - TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array, - UInt64Array, UInt8Array, UnionArray, + Date64Array, Decimal128Array, Decimal256Array, Decimal32Array, Decimal64Array, + DictionaryArray, DurationMicrosecondArray, DurationMillisecondArray, + DurationNanosecondArray, DurationSecondArray, FixedSizeBinaryArray, + FixedSizeListArray, Float16Array, Float32Array, Float64Array, GenericListArray, + Int16Array, Int32Array, Int64Array, Int8Array, IntervalDayTimeArray, + IntervalMonthDayNanoArray, IntervalYearMonthArray, LargeBinaryArray, LargeListArray, + LargeStringArray, ListArray, MapArray, MutableArrayData, PrimitiveArray, Scalar, + StringArray, StringViewArray, StructArray, Time32MillisecondArray, Time32SecondArray, + Time64MicrosecondArray, Time64NanosecondArray, TimestampMicrosecondArray, + TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, + UInt16Array, UInt32Array, UInt64Array, UInt8Array, UnionArray, }; use arrow::buffer::ScalarBuffer; use arrow::compute::kernels::cast::{cast_with_options, CastOptions}; @@ -75,12 +76,13 @@ use arrow::compute::kernels::numeric::{ }; use arrow::datatypes::{ i256, validate_decimal_precision_and_scale, ArrowDictionaryKeyType, ArrowNativeType, - ArrowTimestampType, DataType, Date32Type, Decimal128Type, Decimal256Type, Field, - Float32Type, Int16Type, Int32Type, Int64Type, Int8Type, IntervalDayTime, - IntervalDayTimeType, IntervalMonthDayNano, IntervalMonthDayNanoType, IntervalUnit, - IntervalYearMonthType, TimeUnit, TimestampMicrosecondType, TimestampMillisecondType, - TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, - UInt8Type, UnionFields, UnionMode, DECIMAL128_MAX_PRECISION, + ArrowTimestampType, DataType, Date32Type, Decimal128Type, Decimal256Type, + Decimal32Type, Decimal64Type, Field, Float32Type, Int16Type, Int32Type, Int64Type, + Int8Type, IntervalDayTime, IntervalDayTimeType, IntervalMonthDayNano, + IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType, TimeUnit, + TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, + TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, UnionFields, + UnionMode, DECIMAL128_MAX_PRECISION, }; use arrow::util::display::{array_value_to_string, ArrayFormatter, FormatOptions}; use cache::{get_or_create_cached_key_array, get_or_create_cached_null_array}; @@ -231,6 +233,10 @@ pub enum ScalarValue { Float32(Option), /// 64bit float Float64(Option), + /// 32bit decimal, using the i32 to represent the decimal, precision scale + Decimal32(Option, u8, i8), + /// 64bit decimal, using the i64 to represent the decimal, precision scale + Decimal64(Option, u8, i8), /// 128bit decimal, using the i128 to represent the decimal, precision scale Decimal128(Option, u8, i8), /// 256bit decimal, using the i256 to represent the decimal, precision scale @@ -340,6 +346,14 @@ impl PartialEq for ScalarValue { // any newly added enum variant will require editing this list // or else face a compile error match (self, other) { + (Decimal32(v1, p1, s1), Decimal32(v2, p2, s2)) => { + v1.eq(v2) && p1.eq(p2) && s1.eq(s2) + } + (Decimal32(_, _, _), _) => false, + (Decimal64(v1, p1, s1), Decimal64(v2, p2, s2)) => { + v1.eq(v2) && p1.eq(p2) && s1.eq(s2) + } + (Decimal64(_, _, _), _) => false, (Decimal128(v1, p1, s1), Decimal128(v2, p2, s2)) => { v1.eq(v2) && p1.eq(p2) && s1.eq(s2) } @@ -459,6 +473,24 @@ impl PartialOrd for ScalarValue { // any newly added enum variant will require editing this list // or else face a compile error match (self, other) { + (Decimal32(v1, p1, s1), Decimal32(v2, p2, s2)) => { + if p1.eq(p2) && s1.eq(s2) { + v1.partial_cmp(v2) + } else { + // Two decimal values can be compared if they have the same precision and scale. + None + } + } + (Decimal32(_, _, _), _) => None, + (Decimal64(v1, p1, s1), Decimal64(v2, p2, s2)) => { + if p1.eq(p2) && s1.eq(s2) { + v1.partial_cmp(v2) + } else { + // Two decimal values can be compared if they have the same precision and scale. + None + } + } + (Decimal64(_, _, _), _) => None, (Decimal128(v1, p1, s1), Decimal128(v2, p2, s2)) => { if p1.eq(p2) && s1.eq(s2) { v1.partial_cmp(v2) @@ -760,6 +792,16 @@ impl Hash for ScalarValue { fn hash(&self, state: &mut H) { use ScalarValue::*; match self { + Decimal32(v, p, s) => { + v.hash(state); + p.hash(state); + s.hash(state) + } + Decimal64(v, p, s) => { + v.hash(state); + p.hash(state); + s.hash(state) + } Decimal128(v, p, s) => { v.hash(state); p.hash(state); @@ -1045,6 +1087,12 @@ impl ScalarValue { DataType::UInt16 => ScalarValue::UInt16(None), DataType::UInt32 => ScalarValue::UInt32(None), DataType::UInt64 => ScalarValue::UInt64(None), + DataType::Decimal32(precision, scale) => { + ScalarValue::Decimal32(None, *precision, *scale) + } + DataType::Decimal64(precision, scale) => { + ScalarValue::Decimal64(None, *precision, *scale) + } DataType::Decimal128(precision, scale) => { ScalarValue::Decimal128(None, *precision, *scale) } @@ -1527,6 +1575,34 @@ impl ScalarValue { DataType::Float16 => ScalarValue::Float16(Some(f16::from_f32(1.0))), DataType::Float32 => ScalarValue::Float32(Some(1.0)), DataType::Float64 => ScalarValue::Float64(Some(1.0)), + DataType::Decimal32(precision, scale) => { + validate_decimal_precision_and_scale::( + *precision, *scale, + )?; + if *scale < 0 { + return _internal_err!("Negative scale is not supported"); + } + match 10_i32.checked_pow(*scale as u32) { + Some(value) => { + ScalarValue::Decimal32(Some(value), *precision, *scale) + } + None => return _internal_err!("Unsupported scale {scale}"), + } + } + DataType::Decimal64(precision, scale) => { + validate_decimal_precision_and_scale::( + *precision, *scale, + )?; + if *scale < 0 { + return _internal_err!("Negative scale is not supported"); + } + match i64::from(10).checked_pow(*scale as u32) { + Some(value) => { + ScalarValue::Decimal64(Some(value), *precision, *scale) + } + None => return _internal_err!("Unsupported scale {scale}"), + } + } DataType::Decimal128(precision, scale) => { validate_decimal_precision_and_scale::( *precision, *scale, @@ -1573,6 +1649,34 @@ impl ScalarValue { DataType::Float16 => ScalarValue::Float16(Some(f16::from_f32(-1.0))), DataType::Float32 => ScalarValue::Float32(Some(-1.0)), DataType::Float64 => ScalarValue::Float64(Some(-1.0)), + DataType::Decimal32(precision, scale) => { + validate_decimal_precision_and_scale::( + *precision, *scale, + )?; + if *scale < 0 { + return _internal_err!("Negative scale is not supported"); + } + match 10_i32.checked_pow(*scale as u32) { + Some(value) => { + ScalarValue::Decimal32(Some(-value), *precision, *scale) + } + None => return _internal_err!("Unsupported scale {scale}"), + } + } + DataType::Decimal64(precision, scale) => { + validate_decimal_precision_and_scale::( + *precision, *scale, + )?; + if *scale < 0 { + return _internal_err!("Negative scale is not supported"); + } + match i64::from(10).checked_pow(*scale as u32) { + Some(value) => { + ScalarValue::Decimal64(Some(-value), *precision, *scale) + } + None => return _internal_err!("Unsupported scale {scale}"), + } + } DataType::Decimal128(precision, scale) => { validate_decimal_precision_and_scale::( *precision, *scale, @@ -1622,6 +1726,38 @@ impl ScalarValue { DataType::Float16 => ScalarValue::Float16(Some(f16::from_f32(10.0))), DataType::Float32 => ScalarValue::Float32(Some(10.0)), DataType::Float64 => ScalarValue::Float64(Some(10.0)), + DataType::Decimal32(precision, scale) => { + if let Err(err) = validate_decimal_precision_and_scale::( + *precision, *scale, + ) { + return _internal_err!("Invalid precision and scale {err}"); + } + if *scale <= 0 { + return _internal_err!("Negative scale is not supported"); + } + match 10_i32.checked_pow((*scale + 1) as u32) { + Some(value) => { + ScalarValue::Decimal32(Some(value), *precision, *scale) + } + None => return _internal_err!("Unsupported scale {scale}"), + } + } + DataType::Decimal64(precision, scale) => { + if let Err(err) = validate_decimal_precision_and_scale::( + *precision, *scale, + ) { + return _internal_err!("Invalid precision and scale {err}"); + } + if *scale <= 0 { + return _internal_err!("Negative scale is not supported"); + } + match i64::from(10).checked_pow((*scale + 1) as u32) { + Some(value) => { + ScalarValue::Decimal64(Some(value), *precision, *scale) + } + None => return _internal_err!("Unsupported scale {scale}"), + } + } DataType::Decimal128(precision, scale) => { if let Err(err) = validate_decimal_precision_and_scale::( *precision, *scale, @@ -1674,6 +1810,12 @@ impl ScalarValue { ScalarValue::Int16(_) => DataType::Int16, ScalarValue::Int32(_) => DataType::Int32, ScalarValue::Int64(_) => DataType::Int64, + ScalarValue::Decimal32(_, precision, scale) => { + DataType::Decimal32(*precision, *scale) + } + ScalarValue::Decimal64(_, precision, scale) => { + DataType::Decimal64(*precision, *scale) + } ScalarValue::Decimal128(_, precision, scale) => { DataType::Decimal128(*precision, *scale) } @@ -1796,6 +1938,24 @@ impl ScalarValue { ); Ok(ScalarValue::IntervalMonthDayNano(Some(val))) } + ScalarValue::Decimal32(Some(v), precision, scale) => { + Ok(ScalarValue::Decimal32( + Some(neg_checked_with_ctx(*v, || { + format!("In negation of Decimal32({v}, {precision}, {scale})") + })?), + *precision, + *scale, + )) + } + ScalarValue::Decimal64(Some(v), precision, scale) => { + Ok(ScalarValue::Decimal64( + Some(neg_checked_with_ctx(*v, || { + format!("In negation of Decimal64({v}, {precision}, {scale})") + })?), + *precision, + *scale, + )) + } ScalarValue::Decimal128(Some(v), precision, scale) => { Ok(ScalarValue::Decimal128( Some(neg_checked_with_ctx(*v, || { @@ -1947,6 +2107,8 @@ impl ScalarValue { ScalarValue::Float16(v) => v.is_none(), ScalarValue::Float32(v) => v.is_none(), ScalarValue::Float64(v) => v.is_none(), + ScalarValue::Decimal32(v, _, _) => v.is_none(), + ScalarValue::Decimal64(v, _, _) => v.is_none(), ScalarValue::Decimal128(v, _, _) => v.is_none(), ScalarValue::Decimal256(v, _, _) => v.is_none(), ScalarValue::Int8(v) => v.is_none(), @@ -2202,19 +2364,19 @@ impl ScalarValue { } let array: ArrayRef = match &data_type { - DataType::Decimal32(_precision, _scale) => { - return _not_impl_err!( - "Decimal32 not supported in ScalarValue::iter_to_array" - ); + DataType::Decimal32(precision, scale) => { + let decimal_array = + ScalarValue::iter_to_decimal32_array(scalars, *precision, *scale)?; + Arc::new(decimal_array) } - DataType::Decimal64(_precision, _scale) => { - return _not_impl_err!( - "Decimal64 not supported in ScalarValue::iter_to_array" - ); + DataType::Decimal64(precision, scale) => { + let decimal_array = + ScalarValue::iter_to_decimal64_array(scalars, *precision, *scale)?; + Arc::new(decimal_array) } DataType::Decimal128(precision, scale) => { let decimal_array = - ScalarValue::iter_to_decimal_array(scalars, *precision, *scale)?; + ScalarValue::iter_to_decimal128_array(scalars, *precision, *scale)?; Arc::new(decimal_array) } DataType::Decimal256(precision, scale) => { @@ -2423,7 +2585,43 @@ impl ScalarValue { Ok(new_null_array(&DataType::Null, length)) } - fn iter_to_decimal_array( + fn iter_to_decimal32_array( + scalars: impl IntoIterator, + precision: u8, + scale: i8, + ) -> Result { + let array = scalars + .into_iter() + .map(|element: ScalarValue| match element { + ScalarValue::Decimal32(v1, _, _) => Ok(v1), + s => { + _internal_err!("Expected ScalarValue::Null element. Received {s:?}") + } + }) + .collect::>()? + .with_precision_and_scale(precision, scale)?; + Ok(array) + } + + fn iter_to_decimal64_array( + scalars: impl IntoIterator, + precision: u8, + scale: i8, + ) -> Result { + let array = scalars + .into_iter() + .map(|element: ScalarValue| match element { + ScalarValue::Decimal64(v1, _, _) => Ok(v1), + s => { + _internal_err!("Expected ScalarValue::Null element. Received {s:?}") + } + }) + .collect::>()? + .with_precision_and_scale(precision, scale)?; + Ok(array) + } + + fn iter_to_decimal128_array( scalars: impl IntoIterator, precision: u8, scale: i8, @@ -2461,7 +2659,43 @@ impl ScalarValue { Ok(array) } - fn build_decimal_array( + fn build_decimal32_array( + value: Option, + precision: u8, + scale: i8, + size: usize, + ) -> Result { + Ok(match value { + Some(val) => Decimal32Array::from(vec![val; size]) + .with_precision_and_scale(precision, scale)?, + None => { + let mut builder = Decimal32Array::builder(size) + .with_precision_and_scale(precision, scale)?; + builder.append_nulls(size); + builder.finish() + } + }) + } + + fn build_decimal64_array( + value: Option, + precision: u8, + scale: i8, + size: usize, + ) -> Result { + Ok(match value { + Some(val) => Decimal64Array::from(vec![val; size]) + .with_precision_and_scale(precision, scale)?, + None => { + let mut builder = Decimal64Array::builder(size) + .with_precision_and_scale(precision, scale)?; + builder.append_nulls(size); + builder.finish() + } + }) + } + + fn build_decimal128_array( value: Option, precision: u8, scale: i8, @@ -2640,8 +2874,14 @@ impl ScalarValue { /// - a `Dictionary` that fails be converted to a dictionary array of size pub fn to_array_of_size(&self, size: usize) -> Result { Ok(match self { + ScalarValue::Decimal32(e, precision, scale) => Arc::new( + ScalarValue::build_decimal32_array(*e, *precision, *scale, size)?, + ), + ScalarValue::Decimal64(e, precision, scale) => Arc::new( + ScalarValue::build_decimal64_array(*e, *precision, *scale, size)?, + ), ScalarValue::Decimal128(e, precision, scale) => Arc::new( - ScalarValue::build_decimal_array(*e, *precision, *scale, size)?, + ScalarValue::build_decimal128_array(*e, *precision, *scale, size)?, ), ScalarValue::Decimal256(e, precision, scale) => Arc::new( ScalarValue::build_decimal256_array(*e, *precision, *scale, size)?, @@ -2951,6 +3191,24 @@ impl ScalarValue { scale: i8, ) -> Result { match array.data_type() { + DataType::Decimal32(_, _) => { + let array = as_decimal32_array(array)?; + if array.is_null(index) { + Ok(ScalarValue::Decimal32(None, precision, scale)) + } else { + let value = array.value(index); + Ok(ScalarValue::Decimal32(Some(value), precision, scale)) + } + } + DataType::Decimal64(_, _) => { + let array = as_decimal64_array(array)?; + if array.is_null(index) { + Ok(ScalarValue::Decimal64(None, precision, scale)) + } else { + let value = array.value(index); + Ok(ScalarValue::Decimal64(Some(value), precision, scale)) + } + } DataType::Decimal128(_, _) => { let array = as_decimal128_array(array)?; if array.is_null(index) { @@ -2969,7 +3227,9 @@ impl ScalarValue { Ok(ScalarValue::Decimal256(Some(value), precision, scale)) } } - _ => _internal_err!("Unsupported decimal type"), + other => { + unreachable!("Invalid type isn't decimal: {other:?}") + } } } @@ -3083,6 +3343,16 @@ impl ScalarValue { Ok(match array.data_type() { DataType::Null => ScalarValue::Null, + DataType::Decimal32(precision, scale) => { + ScalarValue::get_decimal_value_from_array( + array, index, *precision, *scale, + )? + } + DataType::Decimal64(precision, scale) => { + ScalarValue::get_decimal_value_from_array( + array, index, *precision, *scale, + )? + } DataType::Decimal128(precision, scale) => { ScalarValue::get_decimal_value_from_array( array, index, *precision, *scale, @@ -3343,6 +3613,44 @@ impl ScalarValue { ScalarValue::try_from_array(&cast_arr, 0) } + fn eq_array_decimal32( + array: &ArrayRef, + index: usize, + value: Option<&i32>, + precision: u8, + scale: i8, + ) -> Result { + let array = as_decimal32_array(array)?; + if array.precision() != precision || array.scale() != scale { + return Ok(false); + } + let is_null = array.is_null(index); + if let Some(v) = value { + Ok(!array.is_null(index) && array.value(index) == *v) + } else { + Ok(is_null) + } + } + + fn eq_array_decimal64( + array: &ArrayRef, + index: usize, + value: Option<&i64>, + precision: u8, + scale: i8, + ) -> Result { + let array = as_decimal64_array(array)?; + if array.precision() != precision || array.scale() != scale { + return Ok(false); + } + let is_null = array.is_null(index); + if let Some(v) = value { + Ok(!array.is_null(index) && array.value(index) == *v) + } else { + Ok(is_null) + } + } + fn eq_array_decimal( array: &ArrayRef, index: usize, @@ -3410,6 +3718,24 @@ impl ScalarValue { #[inline] pub fn eq_array(&self, array: &ArrayRef, index: usize) -> Result { Ok(match self { + ScalarValue::Decimal32(v, precision, scale) => { + ScalarValue::eq_array_decimal32( + array, + index, + v.as_ref(), + *precision, + *scale, + )? + } + ScalarValue::Decimal64(v, precision, scale) => { + ScalarValue::eq_array_decimal64( + array, + index, + v.as_ref(), + *precision, + *scale, + )? + } ScalarValue::Decimal128(v, precision, scale) => { ScalarValue::eq_array_decimal( array, @@ -3608,6 +3934,8 @@ impl ScalarValue { | ScalarValue::Float16(_) | ScalarValue::Float32(_) | ScalarValue::Float64(_) + | ScalarValue::Decimal32(_, _, _) + | ScalarValue::Decimal64(_, _, _) | ScalarValue::Decimal128(_, _, _) | ScalarValue::Decimal256(_, _, _) | ScalarValue::Int8(_) @@ -3717,6 +4045,8 @@ impl ScalarValue { | ScalarValue::Float16(_) | ScalarValue::Float32(_) | ScalarValue::Float64(_) + | ScalarValue::Decimal32(_, _, _) + | ScalarValue::Decimal64(_, _, _) | ScalarValue::Decimal128(_, _, _) | ScalarValue::Decimal256(_, _, _) | ScalarValue::Int8(_) @@ -4230,6 +4560,12 @@ macro_rules! format_option { impl fmt::Display for ScalarValue { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { + ScalarValue::Decimal32(v, p, s) => { + write!(f, "{v:?},{p:?},{s:?}")?; + } + ScalarValue::Decimal64(v, p, s) => { + write!(f, "{v:?},{p:?},{s:?}")?; + } ScalarValue::Decimal128(v, p, s) => { write!(f, "{v:?},{p:?},{s:?}")?; } @@ -4419,6 +4755,8 @@ fn fmt_binary(data: &[u8], f: &mut fmt::Formatter) -> fmt::Result { impl fmt::Debug for ScalarValue { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { + ScalarValue::Decimal32(_, _, _) => write!(f, "Decimal32({self})"), + ScalarValue::Decimal64(_, _, _) => write!(f, "Decimal64({self})"), ScalarValue::Decimal128(_, _, _) => write!(f, "Decimal128({self})"), ScalarValue::Decimal256(_, _, _) => write!(f, "Decimal256({self})"), ScalarValue::Boolean(_) => write!(f, "Boolean({self})"), diff --git a/datafusion/common/src/types/native.rs b/datafusion/common/src/types/native.rs index 76629e555b8c..237d04f7f70b 100644 --- a/datafusion/common/src/types/native.rs +++ b/datafusion/common/src/types/native.rs @@ -23,6 +23,7 @@ use crate::error::{Result, _internal_err}; use arrow::compute::can_cast_types; use arrow::datatypes::{ DataType, Field, FieldRef, Fields, IntervalUnit, TimeUnit, UnionFields, + DECIMAL128_MAX_PRECISION, DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION, }; use std::{fmt::Display, sync::Arc}; @@ -228,7 +229,15 @@ impl LogicalType for NativeType { (Self::Float16, _) => Float16, (Self::Float32, _) => Float32, (Self::Float64, _) => Float64, - (Self::Decimal(p, s), _) if p <= &38 => Decimal128(*p, *s), + (Self::Decimal(p, s), _) if *p <= DECIMAL32_MAX_PRECISION => { + Decimal32(*p, *s) + } + (Self::Decimal(p, s), _) if *p <= DECIMAL64_MAX_PRECISION => { + Decimal64(*p, *s) + } + (Self::Decimal(p, s), _) if *p <= DECIMAL128_MAX_PRECISION => { + Decimal128(*p, *s) + } (Self::Decimal(p, s), _) => Decimal256(*p, *s), (Self::Timestamp(tu, tz), _) => Timestamp(*tu, tz.clone()), // If given type is Date, return the same type diff --git a/datafusion/core/src/lib.rs b/datafusion/core/src/lib.rs index c114ae7a29d4..ad7f23d9008e 100644 --- a/datafusion/core/src/lib.rs +++ b/datafusion/core/src/lib.rs @@ -19,7 +19,7 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 // diff --git a/datafusion/core/tests/fuzz_cases/record_batch_generator.rs b/datafusion/core/tests/fuzz_cases/record_batch_generator.rs index e7f63b535104..45dba5f7864b 100644 --- a/datafusion/core/tests/fuzz_cases/record_batch_generator.rs +++ b/datafusion/core/tests/fuzz_cases/record_batch_generator.rs @@ -20,18 +20,19 @@ use std::sync::Arc; use arrow::array::{ArrayRef, DictionaryArray, PrimitiveArray, RecordBatch}; use arrow::datatypes::{ ArrowPrimitiveType, BooleanType, DataType, Date32Type, Date64Type, Decimal128Type, - Decimal256Type, DurationMicrosecondType, DurationMillisecondType, - DurationNanosecondType, DurationSecondType, Field, Float32Type, Float64Type, - Int16Type, Int32Type, Int64Type, Int8Type, IntervalDayTimeType, - IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType, Schema, - Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType, - TimeUnit, TimestampMicrosecondType, TimestampMillisecondType, + Decimal256Type, Decimal32Type, Decimal64Type, DurationMicrosecondType, + DurationMillisecondType, DurationNanosecondType, DurationSecondType, Field, + Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, + IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType, + Schema, Time32MillisecondType, Time32SecondType, Time64MicrosecondType, + Time64NanosecondType, TimeUnit, TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, }; use arrow_schema::{ DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, - DECIMAL256_MAX_SCALE, + DECIMAL256_MAX_SCALE, DECIMAL32_MAX_PRECISION, DECIMAL32_MAX_SCALE, + DECIMAL64_MAX_PRECISION, DECIMAL64_MAX_SCALE, }; use datafusion_common::{arrow_datafusion_err, DataFusionError, Result}; use rand::{rng, rngs::StdRng, Rng, SeedableRng}; @@ -104,6 +105,20 @@ pub fn get_supported_types_columns(rng_seed: u64) -> Vec { "duration_nanosecond", DataType::Duration(TimeUnit::Nanosecond), ), + ColumnDescr::new("decimal32", { + let precision: u8 = rng.random_range(1..=DECIMAL32_MAX_PRECISION); + let scale: i8 = rng.random_range( + i8::MIN..=std::cmp::min(precision as i8, DECIMAL32_MAX_SCALE), + ); + DataType::Decimal32(precision, scale) + }), + ColumnDescr::new("decimal64", { + let precision: u8 = rng.random_range(1..=DECIMAL64_MAX_PRECISION); + let scale: i8 = rng.random_range( + i8::MIN..=std::cmp::min(precision as i8, DECIMAL64_MAX_SCALE), + ); + DataType::Decimal64(precision, scale) + }), ColumnDescr::new("decimal128", { let precision: u8 = rng.random_range(1..=DECIMAL128_MAX_PRECISION); let scale: i8 = rng.random_range( @@ -682,6 +697,32 @@ impl RecordBatchGenerator { _ => unreachable!(), } } + DataType::Decimal32(precision, scale) => { + generate_decimal_array!( + self, + num_rows, + max_num_distinct, + null_pct, + batch_gen_rng, + array_gen_rng, + precision, + scale, + Decimal32Type + ) + } + DataType::Decimal64(precision, scale) => { + generate_decimal_array!( + self, + num_rows, + max_num_distinct, + null_pct, + batch_gen_rng, + array_gen_rng, + precision, + scale, + Decimal64Type + ) + } DataType::Decimal128(precision, scale) => { generate_decimal_array!( self, diff --git a/datafusion/datasource-avro/src/mod.rs b/datafusion/datasource-avro/src/mod.rs index de595011df83..ad8ebe11446f 100644 --- a/datafusion/datasource-avro/src/mod.rs +++ b/datafusion/datasource-avro/src/mod.rs @@ -19,7 +19,7 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![cfg_attr(not(test), deny(clippy::clone_on_ref_ptr))] diff --git a/datafusion/datasource/src/mod.rs b/datafusion/datasource/src/mod.rs index 3cd4a1a6c1c9..8b2e49fb4b6e 100644 --- a/datafusion/datasource/src/mod.rs +++ b/datafusion/datasource/src/mod.rs @@ -19,7 +19,7 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![cfg_attr(not(test), deny(clippy::clone_on_ref_ptr))] diff --git a/datafusion/doc/src/lib.rs b/datafusion/doc/src/lib.rs index 9a2c5656bae9..a57d299d026e 100644 --- a/datafusion/doc/src/lib.rs +++ b/datafusion/doc/src/lib.rs @@ -19,7 +19,7 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] #[allow(rustdoc::broken_intra_doc_links)] /// Documentation for use by [`ScalarUDFImpl`](ScalarUDFImpl), diff --git a/datafusion/execution/src/lib.rs b/datafusion/execution/src/lib.rs index e971e838a6e5..55243e301e0e 100644 --- a/datafusion/execution/src/lib.rs +++ b/datafusion/execution/src/lib.rs @@ -19,7 +19,7 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] diff --git a/datafusion/expr-common/src/casts.rs b/datafusion/expr-common/src/casts.rs index c31d4f77c6a7..ae2cf9b78666 100644 --- a/datafusion/expr-common/src/casts.rs +++ b/datafusion/expr-common/src/casts.rs @@ -25,7 +25,9 @@ use std::cmp::Ordering; use arrow::datatypes::{ DataType, TimeUnit, MAX_DECIMAL128_FOR_EACH_PRECISION, - MIN_DECIMAL128_FOR_EACH_PRECISION, + MAX_DECIMAL32_FOR_EACH_PRECISION, MAX_DECIMAL64_FOR_EACH_PRECISION, + MIN_DECIMAL128_FOR_EACH_PRECISION, MIN_DECIMAL32_FOR_EACH_PRECISION, + MIN_DECIMAL64_FOR_EACH_PRECISION, }; use arrow::temporal_conversions::{MICROSECONDS, MILLISECONDS, NANOSECONDS}; use datafusion_common::ScalarValue; @@ -69,6 +71,8 @@ fn is_supported_numeric_type(data_type: &DataType) -> bool { | DataType::Int16 | DataType::Int32 | DataType::Int64 + | DataType::Decimal32(_, _) + | DataType::Decimal64(_, _) | DataType::Decimal128(_, _) | DataType::Timestamp(_, _) ) @@ -114,6 +118,8 @@ fn try_cast_numeric_literal( | DataType::Int32 | DataType::Int64 => 1_i128, DataType::Timestamp(_, _) => 1_i128, + DataType::Decimal32(_, scale) => 10_i128.pow(*scale as u32), + DataType::Decimal64(_, scale) => 10_i128.pow(*scale as u32), DataType::Decimal128(_, scale) => 10_i128.pow(*scale as u32), _ => return None, }; @@ -127,6 +133,20 @@ fn try_cast_numeric_literal( DataType::Int32 => (i32::MIN as i128, i32::MAX as i128), DataType::Int64 => (i64::MIN as i128, i64::MAX as i128), DataType::Timestamp(_, _) => (i64::MIN as i128, i64::MAX as i128), + DataType::Decimal32(precision, _) => ( + // Different precision for decimal32 can store different range of value. + // For example, the precision is 3, the max of value is `999` and the min + // value is `-999` + MIN_DECIMAL32_FOR_EACH_PRECISION[*precision as usize] as i128, + MAX_DECIMAL32_FOR_EACH_PRECISION[*precision as usize] as i128, + ), + DataType::Decimal64(precision, _) => ( + // Different precision for decimal64 can store different range of value. + // For example, the precision is 3, the max of value is `999` and the min + // value is `-999` + MIN_DECIMAL64_FOR_EACH_PRECISION[*precision as usize] as i128, + MAX_DECIMAL64_FOR_EACH_PRECISION[*precision as usize] as i128, + ), DataType::Decimal128(precision, _) => ( // Different precision for decimal128 can store different range of value. // For example, the precision is 3, the max of value is `999` and the min @@ -149,6 +169,46 @@ fn try_cast_numeric_literal( ScalarValue::TimestampMillisecond(Some(v), _) => (*v as i128).checked_mul(mul), ScalarValue::TimestampMicrosecond(Some(v), _) => (*v as i128).checked_mul(mul), ScalarValue::TimestampNanosecond(Some(v), _) => (*v as i128).checked_mul(mul), + ScalarValue::Decimal32(Some(v), _, scale) => { + let v = *v as i128; + let lit_scale_mul = 10_i128.pow(*scale as u32); + if mul >= lit_scale_mul { + // Example: + // lit is decimal(123,3,2) + // target type is decimal(5,3) + // the lit can be converted to the decimal(1230,5,3) + v.checked_mul(mul / lit_scale_mul) + } else if v % (lit_scale_mul / mul) == 0 { + // Example: + // lit is decimal(123000,10,3) + // target type is int32: the lit can be converted to INT32(123) + // target type is decimal(10,2): the lit can be converted to decimal(12300,10,2) + Some(v / (lit_scale_mul / mul)) + } else { + // can't convert the lit decimal to the target data type + None + } + } + ScalarValue::Decimal64(Some(v), _, scale) => { + let v = *v as i128; + let lit_scale_mul = 10_i128.pow(*scale as u32); + if mul >= lit_scale_mul { + // Example: + // lit is decimal(123,3,2) + // target type is decimal(5,3) + // the lit can be converted to the decimal(1230,5,3) + v.checked_mul(mul / lit_scale_mul) + } else if v % (lit_scale_mul / mul) == 0 { + // Example: + // lit is decimal(123000,10,3) + // target type is int32: the lit can be converted to INT32(123) + // target type is decimal(10,2): the lit can be converted to decimal(12300,10,2) + Some(v / (lit_scale_mul / mul)) + } else { + // can't convert the lit decimal to the target data type + None + } + } ScalarValue::Decimal128(Some(v), _, scale) => { let lit_scale_mul = 10_i128.pow(*scale as u32); if mul >= lit_scale_mul { @@ -218,6 +278,12 @@ fn try_cast_numeric_literal( ); ScalarValue::TimestampNanosecond(value, tz.clone()) } + DataType::Decimal32(p, s) => { + ScalarValue::Decimal32(Some(value as i32), *p, *s) + } + DataType::Decimal64(p, s) => { + ScalarValue::Decimal64(Some(value as i64), *p, *s) + } DataType::Decimal128(p, s) => { ScalarValue::Decimal128(Some(value), *p, *s) } diff --git a/datafusion/expr-common/src/lib.rs b/datafusion/expr-common/src/lib.rs index f0bb6f99943b..a4f6414a8c51 100644 --- a/datafusion/expr-common/src/lib.rs +++ b/datafusion/expr-common/src/lib.rs @@ -27,7 +27,7 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] diff --git a/datafusion/expr-common/src/type_coercion/aggregates.rs b/datafusion/expr-common/src/type_coercion/aggregates.rs index e9377ce7de5a..c462246f803e 100644 --- a/datafusion/expr-common/src/type_coercion/aggregates.rs +++ b/datafusion/expr-common/src/type_coercion/aggregates.rs @@ -18,7 +18,8 @@ use crate::signature::TypeSignature; use arrow::datatypes::{ DataType, FieldRef, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, - DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, + DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, DECIMAL32_MAX_PRECISION, + DECIMAL32_MAX_SCALE, DECIMAL64_MAX_PRECISION, DECIMAL64_MAX_SCALE, }; use datafusion_common::{internal_err, plan_err, Result}; @@ -150,6 +151,18 @@ pub fn sum_return_type(arg_type: &DataType) -> Result { DataType::Int64 => Ok(DataType::Int64), DataType::UInt64 => Ok(DataType::UInt64), DataType::Float64 => Ok(DataType::Float64), + DataType::Decimal32(precision, scale) => { + // in the spark, the result type is DECIMAL(min(38,precision+10), s) + // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 + let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 10); + Ok(DataType::Decimal32(new_precision, *scale)) + } + DataType::Decimal64(precision, scale) => { + // in the spark, the result type is DECIMAL(min(38,precision+10), s) + // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 + let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 10); + Ok(DataType::Decimal64(new_precision, *scale)) + } DataType::Decimal128(precision, scale) => { // In the spark, the result type is DECIMAL(min(38,precision+10), s) // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 @@ -196,6 +209,20 @@ pub fn correlation_return_type(arg_type: &DataType) -> Result { /// Function return type of an average pub fn avg_return_type(func_name: &str, arg_type: &DataType) -> Result { match arg_type { + DataType::Decimal32(precision, scale) => { + // In the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)). + // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66 + let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 4); + let new_scale = DECIMAL32_MAX_SCALE.min(*scale + 4); + Ok(DataType::Decimal32(new_precision, new_scale)) + } + DataType::Decimal64(precision, scale) => { + // In the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)). + // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66 + let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 4); + let new_scale = DECIMAL64_MAX_SCALE.min(*scale + 4); + Ok(DataType::Decimal64(new_precision, new_scale)) + } DataType::Decimal128(precision, scale) => { // In the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)). // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66 @@ -222,6 +249,16 @@ pub fn avg_return_type(func_name: &str, arg_type: &DataType) -> Result /// Internal sum type of an average pub fn avg_sum_type(arg_type: &DataType) -> Result { match arg_type { + DataType::Decimal32(precision, scale) => { + // In the spark, the sum type of avg is DECIMAL(min(38,precision+10), s) + let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 10); + Ok(DataType::Decimal32(new_precision, *scale)) + } + DataType::Decimal64(precision, scale) => { + // In the spark, the sum type of avg is DECIMAL(min(38,precision+10), s) + let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 10); + Ok(DataType::Decimal64(new_precision, *scale)) + } DataType::Decimal128(precision, scale) => { // In the spark, the sum type of avg is DECIMAL(min(38,precision+10), s) let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10); @@ -249,7 +286,7 @@ pub fn is_sum_support_arg_type(arg_type: &DataType) -> bool { _ => matches!( arg_type, arg_type if NUMERICS.contains(arg_type) - || matches!(arg_type, DataType::Decimal128(_, _) | DataType::Decimal256(_, _)) + || matches!(arg_type, DataType::Decimal32(_, _) | DataType::Decimal64(_, _) |DataType::Decimal128(_, _) | DataType::Decimal256(_, _)) ), } } @@ -262,7 +299,7 @@ pub fn is_avg_support_arg_type(arg_type: &DataType) -> bool { _ => matches!( arg_type, arg_type if NUMERICS.contains(arg_type) - || matches!(arg_type, DataType::Decimal128(_, _)| DataType::Decimal256(_, _)) + || matches!(arg_type, DataType::Decimal32(_, _) | DataType::Decimal64(_, _) |DataType::Decimal128(_, _) | DataType::Decimal256(_, _)) ), } } @@ -297,6 +334,8 @@ pub fn coerce_avg_type(func_name: &str, arg_types: &[DataType]) -> Result Result { match &data_type { + DataType::Decimal32(p, s) => Ok(DataType::Decimal32(*p, *s)), + DataType::Decimal64(p, s) => Ok(DataType::Decimal64(*p, *s)), DataType::Decimal128(p, s) => Ok(DataType::Decimal128(*p, *s)), DataType::Decimal256(p, s) => Ok(DataType::Decimal256(*p, *s)), d if d.is_numeric() => Ok(DataType::Float64), diff --git a/datafusion/expr-common/src/type_coercion/binary.rs b/datafusion/expr-common/src/type_coercion/binary.rs index 9264a2940dd1..a6f71fcbae4f 100644 --- a/datafusion/expr-common/src/type_coercion/binary.rs +++ b/datafusion/expr-common/src/type_coercion/binary.rs @@ -27,6 +27,8 @@ use arrow::compute::can_cast_types; use arrow::datatypes::{ DataType, Field, FieldRef, Fields, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, + DECIMAL32_MAX_PRECISION, DECIMAL32_MAX_SCALE, DECIMAL64_MAX_PRECISION, + DECIMAL64_MAX_SCALE, }; use datafusion_common::types::NativeType; use datafusion_common::{ @@ -334,22 +336,64 @@ fn math_decimal_coercion( let (lhs_type, value_type) = math_decimal_coercion(lhs_type, value_type)?; Some((lhs_type, value_type)) } - (Null, dec_type @ Decimal128(_, _)) | (dec_type @ Decimal128(_, _), Null) => { - Some((dec_type.clone(), dec_type.clone())) - } - (Decimal128(_, _), Decimal128(_, _)) | (Decimal256(_, _), Decimal256(_, _)) => { + ( + Null, + Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _) | Decimal256(_, _), + ) => Some((rhs_type.clone(), rhs_type.clone())), + ( + Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _) | Decimal256(_, _), + Null, + ) => Some((lhs_type.clone(), lhs_type.clone())), + (Decimal32(_, _), Decimal32(_, _)) + | (Decimal64(_, _), Decimal64(_, _)) + | (Decimal128(_, _), Decimal128(_, _)) + | (Decimal256(_, _), Decimal256(_, _)) => { Some((lhs_type.clone(), rhs_type.clone())) } // Unlike with comparison we don't coerce to a decimal in the case of floating point // numbers, instead falling back to floating point arithmetic instead + ( + Decimal32(_, _), + Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64, + ) => Some(( + lhs_type.clone(), + coerce_numeric_type_to_decimal32(rhs_type)?, + )), + ( + Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64, + Decimal32(_, _), + ) => Some(( + coerce_numeric_type_to_decimal32(lhs_type)?, + rhs_type.clone(), + )), + ( + Decimal64(_, _), + Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64, + ) => Some(( + lhs_type.clone(), + coerce_numeric_type_to_decimal64(rhs_type)?, + )), + ( + Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64, + Decimal64(_, _), + ) => Some(( + coerce_numeric_type_to_decimal64(lhs_type)?, + rhs_type.clone(), + )), ( Decimal128(_, _), Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64, - ) => Some((lhs_type.clone(), coerce_numeric_type_to_decimal(rhs_type)?)), + ) => Some(( + lhs_type.clone(), + coerce_numeric_type_to_decimal128(rhs_type)?, + )), ( Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64, Decimal128(_, _), - ) => Some((coerce_numeric_type_to_decimal(lhs_type)?, rhs_type.clone())), + ) => Some(( + coerce_numeric_type_to_decimal128(lhs_type)?, + rhs_type.clone(), + )), ( Decimal256(_, _), Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64, @@ -925,8 +969,8 @@ fn get_common_decimal_type( ) -> Option { use arrow::datatypes::DataType::*; match decimal_type { - Decimal128(_, _) => { - let other_decimal_type = coerce_numeric_type_to_decimal(other_type)?; + Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _) => { + let other_decimal_type = coerce_numeric_type_to_decimal128(other_type)?; get_wider_decimal_type(decimal_type, &other_decimal_type) } Decimal256(_, _) => { @@ -946,11 +990,23 @@ fn get_wider_decimal_type( rhs_type: &DataType, ) -> Option { match (lhs_decimal_type, rhs_type) { + (DataType::Decimal32(p1, s1), DataType::Decimal32(p2, s2)) => { + // max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2) + let s = *s1.max(s2); + let range = (*p1 as i8 - s1).max(*p2 as i8 - s2); + Some(create_decimal32_type((range + s) as u8, s)) + } + (DataType::Decimal64(p1, s1), DataType::Decimal64(p2, s2)) => { + // max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2) + let s = *s1.max(s2); + let range = (*p1 as i8 - s1).max(*p2 as i8 - s2); + Some(create_decimal64_type((range + s) as u8, s)) + } (DataType::Decimal128(p1, s1), DataType::Decimal128(p2, s2)) => { // max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2) let s = *s1.max(s2); let range = (*p1 as i8 - s1).max(*p2 as i8 - s2); - Some(create_decimal_type((range + s) as u8, s)) + Some(create_decimal128_type((range + s) as u8, s)) } (DataType::Decimal256(p1, s1), DataType::Decimal256(p2, s2)) => { // max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2) @@ -964,7 +1020,39 @@ fn get_wider_decimal_type( /// Convert the numeric data type to the decimal data type. /// We support signed and unsigned integer types and floating-point type. -fn coerce_numeric_type_to_decimal(numeric_type: &DataType) -> Option { +fn coerce_numeric_type_to_decimal32(numeric_type: &DataType) -> Option { + use arrow::datatypes::DataType::*; + // This conversion rule is from spark + // https://github.com/apache/spark/blob/1c81ad20296d34f137238dadd67cc6ae405944eb/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala#L127 + match numeric_type { + Int8 | UInt8 => Some(Decimal32(3, 0)), + Int16 | UInt16 => Some(Decimal32(5, 0)), + // TODO if we convert the floating-point data to the decimal type, it maybe overflow. + Float16 => Some(Decimal32(6, 3)), + _ => None, + } +} + +/// Convert the numeric data type to the decimal data type. +/// We support signed and unsigned integer types and floating-point type. +fn coerce_numeric_type_to_decimal64(numeric_type: &DataType) -> Option { + use arrow::datatypes::DataType::*; + // This conversion rule is from spark + // https://github.com/apache/spark/blob/1c81ad20296d34f137238dadd67cc6ae405944eb/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala#L127 + match numeric_type { + Int8 | UInt8 => Some(Decimal64(3, 0)), + Int16 | UInt16 => Some(Decimal64(5, 0)), + Int32 | UInt32 => Some(Decimal64(10, 0)), + // TODO if we convert the floating-point data to the decimal type, it maybe overflow. + Float16 => Some(Decimal64(6, 3)), + Float32 => Some(Decimal64(14, 7)), + _ => None, + } +} + +/// Convert the numeric data type to the decimal data type. +/// We support signed and unsigned integer types and floating-point type. +fn coerce_numeric_type_to_decimal128(numeric_type: &DataType) -> Option { use arrow::datatypes::DataType::*; // This conversion rule is from spark // https://github.com/apache/spark/blob/1c81ad20296d34f137238dadd67cc6ae405944eb/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala#L127 @@ -1113,7 +1201,21 @@ fn numerical_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option DataType { +fn create_decimal32_type(precision: u8, scale: i8) -> DataType { + DataType::Decimal128( + DECIMAL32_MAX_PRECISION.min(precision), + DECIMAL32_MAX_SCALE.min(scale), + ) +} + +fn create_decimal64_type(precision: u8, scale: i8) -> DataType { + DataType::Decimal128( + DECIMAL64_MAX_PRECISION.min(precision), + DECIMAL64_MAX_SCALE.min(scale), + ) +} + +fn create_decimal128_type(precision: u8, scale: i8) -> DataType { DataType::Decimal128( DECIMAL128_MAX_PRECISION.min(precision), DECIMAL128_MAX_SCALE.min(scale), diff --git a/datafusion/expr-common/src/type_coercion/binary/tests/arithmetic.rs b/datafusion/expr-common/src/type_coercion/binary/tests/arithmetic.rs index fdd41ae2bb47..e6238ba0078d 100644 --- a/datafusion/expr-common/src/type_coercion/binary/tests/arithmetic.rs +++ b/datafusion/expr-common/src/type_coercion/binary/tests/arithmetic.rs @@ -56,32 +56,75 @@ fn test_date_timestamp_arithmetic_error() -> Result<()> { #[test] fn test_decimal_mathematics_op_type() { + // Decimal32 assert_eq!( - coerce_numeric_type_to_decimal(&DataType::Int8).unwrap(), + coerce_numeric_type_to_decimal32(&DataType::Int8).unwrap(), + DataType::Decimal32(3, 0) + ); + assert_eq!( + coerce_numeric_type_to_decimal32(&DataType::Int16).unwrap(), + DataType::Decimal32(5, 0) + ); + assert!(coerce_numeric_type_to_decimal32(&DataType::Int32).is_none()); + assert!(coerce_numeric_type_to_decimal32(&DataType::Int64).is_none(),); + assert_eq!( + coerce_numeric_type_to_decimal32(&DataType::Float16).unwrap(), + DataType::Decimal32(6, 3) + ); + assert!(coerce_numeric_type_to_decimal32(&DataType::Float32).is_none(),); + assert!(coerce_numeric_type_to_decimal32(&DataType::Float64).is_none()); + + // Decimal64 + assert_eq!( + coerce_numeric_type_to_decimal64(&DataType::Int8).unwrap(), + DataType::Decimal64(3, 0) + ); + assert_eq!( + coerce_numeric_type_to_decimal64(&DataType::Int16).unwrap(), + DataType::Decimal64(5, 0) + ); + assert_eq!( + coerce_numeric_type_to_decimal64(&DataType::Int32).unwrap(), + DataType::Decimal64(10, 0) + ); + assert!(coerce_numeric_type_to_decimal64(&DataType::Int64).is_none(),); + assert_eq!( + coerce_numeric_type_to_decimal64(&DataType::Float16).unwrap(), + DataType::Decimal64(6, 3) + ); + assert_eq!( + coerce_numeric_type_to_decimal64(&DataType::Float32).unwrap(), + DataType::Decimal64(14, 7) + ); + assert!(coerce_numeric_type_to_decimal64(&DataType::Float64).is_none()); + + // Decimal128 + assert_eq!( + coerce_numeric_type_to_decimal128(&DataType::Int8).unwrap(), DataType::Decimal128(3, 0) ); assert_eq!( - coerce_numeric_type_to_decimal(&DataType::Int16).unwrap(), + coerce_numeric_type_to_decimal128(&DataType::Int16).unwrap(), DataType::Decimal128(5, 0) ); assert_eq!( - coerce_numeric_type_to_decimal(&DataType::Int32).unwrap(), + coerce_numeric_type_to_decimal128(&DataType::Int32).unwrap(), DataType::Decimal128(10, 0) ); assert_eq!( - coerce_numeric_type_to_decimal(&DataType::Int64).unwrap(), + coerce_numeric_type_to_decimal128(&DataType::Int64).unwrap(), DataType::Decimal128(20, 0) ); assert_eq!( - coerce_numeric_type_to_decimal(&DataType::Float16).unwrap(), + coerce_numeric_type_to_decimal128(&DataType::Float16).unwrap(), DataType::Decimal128(6, 3) ); assert_eq!( - coerce_numeric_type_to_decimal(&DataType::Float32).unwrap(), + coerce_numeric_type_to_decimal128(&DataType::Float32).unwrap(), DataType::Decimal128(14, 7) ); assert_eq!( - coerce_numeric_type_to_decimal(&DataType::Float64).unwrap(), + coerce_numeric_type_to_decimal128(&DataType::Float64).unwrap(), DataType::Decimal128(30, 15) ); } diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index b4ad8387215e..4ce02391253c 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -19,7 +19,7 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 88d49722a587..919faa639a62 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -281,15 +281,14 @@ impl LogicalPlanBuilder { let value = &row[j]; let data_type = value.get_type(schema)?; - if !data_type.equals_datatype(field_type) { - if can_cast_types(&data_type, field_type) { - } else { - return exec_err!( - "type mismatch and can't cast to got {} and {}", - data_type, - field_type - ); - } + if !data_type.equals_datatype(field_type) + && !can_cast_types(&data_type, field_type) + { + return exec_err!( + "type mismatch and can't cast to got {} and {}", + data_type, + field_type + ); } } fields.push(field_type.to_owned(), field_nullable); diff --git a/datafusion/expr/src/test/function_stub.rs b/datafusion/expr/src/test/function_stub.rs index 3feab09bbd19..0091303058f3 100644 --- a/datafusion/expr/src/test/function_stub.rs +++ b/datafusion/expr/src/test/function_stub.rs @@ -23,6 +23,7 @@ use std::any::Any; use arrow::datatypes::{ DataType, FieldRef, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, + DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION, }; use datafusion_common::{exec_err, not_impl_err, utils::take_function_args, Result}; @@ -135,9 +136,10 @@ impl AggregateUDFImpl for Sum { DataType::Dictionary(_, v) => coerced_type(v), // in the spark, the result type is DECIMAL(min(38,precision+10), s) // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 - DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => { - Ok(data_type.clone()) - } + DataType::Decimal32(_, _) + | DataType::Decimal64(_, _) + | DataType::Decimal128(_, _) + | DataType::Decimal256(_, _) => Ok(data_type.clone()), dt if dt.is_signed_integer() => Ok(DataType::Int64), dt if dt.is_unsigned_integer() => Ok(DataType::UInt64), dt if dt.is_floating() => Ok(DataType::Float64), @@ -153,6 +155,18 @@ impl AggregateUDFImpl for Sum { DataType::Int64 => Ok(DataType::Int64), DataType::UInt64 => Ok(DataType::UInt64), DataType::Float64 => Ok(DataType::Float64), + DataType::Decimal32(precision, scale) => { + // in the spark, the result type is DECIMAL(min(38,precision+10), s) + // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 + let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 10); + Ok(DataType::Decimal32(new_precision, *scale)) + } + DataType::Decimal64(precision, scale) => { + // in the spark, the result type is DECIMAL(min(38,precision+10), s) + // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 + let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 10); + Ok(DataType::Decimal64(new_precision, *scale)) + } DataType::Decimal128(precision, scale) => { // in the spark, the result type is DECIMAL(min(38,precision+10), s) // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 763a4e6539fd..bf8ea51ccbc9 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -851,7 +851,10 @@ fn coerced_from<'a>( | UInt64 | Float32 | Float64 - | Decimal128(_, _), + | Decimal32(_, _) + | Decimal64(_, _) + | Decimal128(_, _) + | Decimal256(_, _), ) => Some(type_into.clone()), ( Timestamp(TimeUnit::Nanosecond, None), diff --git a/datafusion/expr/src/type_coercion/mod.rs b/datafusion/expr/src/type_coercion/mod.rs index 4fc150ef2996..bd1acd3f3a2e 100644 --- a/datafusion/expr/src/type_coercion/mod.rs +++ b/datafusion/expr/src/type_coercion/mod.rs @@ -51,6 +51,8 @@ pub fn is_signed_numeric(dt: &DataType) -> bool { | DataType::Float16 | DataType::Float32 | DataType::Float64 + | DataType::Decimal32(_, _) + | DataType::Decimal64(_, _) | DataType::Decimal128(_, _) | DataType::Decimal256(_, _), ) @@ -89,5 +91,11 @@ pub fn is_utf8_or_utf8view_or_large_utf8(dt: &DataType) -> bool { /// Determine whether the given data type `dt` is a `Decimal`. pub fn is_decimal(dt: &DataType) -> bool { - matches!(dt, DataType::Decimal128(_, _) | DataType::Decimal256(_, _)) + matches!( + dt, + DataType::Decimal32(_, _) + | DataType::Decimal64(_, _) + | DataType::Decimal128(_, _) + | DataType::Decimal256(_, _) + ) } diff --git a/datafusion/ffi/src/lib.rs b/datafusion/ffi/src/lib.rs index ff641e8315c7..0c2340e8ce7b 100644 --- a/datafusion/ffi/src/lib.rs +++ b/datafusion/ffi/src/lib.rs @@ -19,7 +19,7 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] diff --git a/datafusion/functions-aggregate-common/src/aggregate/avg_distinct.rs b/datafusion/functions-aggregate-common/src/aggregate/avg_distinct.rs index 3d6889431d61..56cdaf6618de 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/avg_distinct.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/avg_distinct.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +mod decimal; mod numeric; +pub use decimal::DecimalDistinctAvgAccumulator; pub use numeric::Float64DistinctAvgAccumulator; diff --git a/datafusion/functions-aggregate-common/src/aggregate/avg_distinct/decimal.rs b/datafusion/functions-aggregate-common/src/aggregate/avg_distinct/decimal.rs new file mode 100644 index 000000000000..9920bf5bf448 --- /dev/null +++ b/datafusion/functions-aggregate-common/src/aggregate/avg_distinct/decimal.rs @@ -0,0 +1,282 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::{ + array::{ArrayRef, ArrowNumericType}, + datatypes::{ + i256, Decimal128Type, Decimal256Type, Decimal32Type, Decimal64Type, DecimalType, + }, +}; +use datafusion_common::{Result, ScalarValue}; +use datafusion_expr_common::accumulator::Accumulator; +use std::fmt::Debug; +use std::mem::size_of_val; + +use crate::aggregate::sum_distinct::DistinctSumAccumulator; +use crate::utils::DecimalAverager; + +/// Generic implementation of `AVG DISTINCT` for Decimal types. +/// Handles both all Arrow decimal types (32, 64, 128 and 256 bits). +#[derive(Debug)] +pub struct DecimalDistinctAvgAccumulator { + sum_accumulator: DistinctSumAccumulator, + sum_scale: i8, + target_precision: u8, + target_scale: i8, +} + +impl DecimalDistinctAvgAccumulator { + pub fn with_decimal_params( + sum_scale: i8, + target_precision: u8, + target_scale: i8, + ) -> Self { + let data_type = T::TYPE_CONSTRUCTOR(T::MAX_PRECISION, sum_scale); + + Self { + sum_accumulator: DistinctSumAccumulator::new(&data_type), + sum_scale, + target_precision, + target_scale, + } + } +} + +impl Accumulator + for DecimalDistinctAvgAccumulator +{ + fn state(&mut self) -> Result> { + self.sum_accumulator.state() + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + self.sum_accumulator.update_batch(values) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.sum_accumulator.merge_batch(states) + } + + fn evaluate(&mut self) -> Result { + if self.sum_accumulator.distinct_count() == 0 { + return ScalarValue::new_primitive::( + None, + &T::TYPE_CONSTRUCTOR(self.target_precision, self.target_scale), + ); + } + + let sum_scalar = self.sum_accumulator.evaluate()?; + + match sum_scalar { + ScalarValue::Decimal32(Some(sum), _, _) => { + let decimal_averager = DecimalAverager::::try_new( + self.sum_scale, + self.target_precision, + self.target_scale, + )?; + let avg = decimal_averager + .avg(sum, self.sum_accumulator.distinct_count() as i32)?; + Ok(ScalarValue::Decimal32( + Some(avg), + self.target_precision, + self.target_scale, + )) + } + ScalarValue::Decimal64(Some(sum), _, _) => { + let decimal_averager = DecimalAverager::::try_new( + self.sum_scale, + self.target_precision, + self.target_scale, + )?; + let avg = decimal_averager + .avg(sum, self.sum_accumulator.distinct_count() as i64)?; + Ok(ScalarValue::Decimal64( + Some(avg), + self.target_precision, + self.target_scale, + )) + } + ScalarValue::Decimal128(Some(sum), _, _) => { + let decimal_averager = DecimalAverager::::try_new( + self.sum_scale, + self.target_precision, + self.target_scale, + )?; + let avg = decimal_averager + .avg(sum, self.sum_accumulator.distinct_count() as i128)?; + Ok(ScalarValue::Decimal128( + Some(avg), + self.target_precision, + self.target_scale, + )) + } + ScalarValue::Decimal256(Some(sum), _, _) => { + let decimal_averager = DecimalAverager::::try_new( + self.sum_scale, + self.target_precision, + self.target_scale, + )?; + // `distinct_count` returns `u64`, but `avg` expects `i256` + // first convert `u64` to `i128`, then convert `i128` to `i256` to avoid overflow + let distinct_cnt: i128 = self.sum_accumulator.distinct_count() as i128; + let count: i256 = i256::from_i128(distinct_cnt); + let avg = decimal_averager.avg(sum, count)?; + Ok(ScalarValue::Decimal256( + Some(avg), + self.target_precision, + self.target_scale, + )) + } + + _ => unreachable!("Unsupported decimal type: {:?}", sum_scalar), + } + } + + fn size(&self) -> usize { + let fixed_size = size_of_val(self); + + // Account for the size of the sum_accumulator with its contained values + fixed_size + self.sum_accumulator.size() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{ + Decimal128Array, Decimal256Array, Decimal32Array, Decimal64Array, + }; + use std::sync::Arc; + + #[test] + fn test_decimal32_distinct_avg_accumulator() -> Result<()> { + let precision = 5_u8; + let scale = 2_i8; + let array = Decimal32Array::from(vec![ + Some(10_00), + Some(12_50), + Some(17_50), + Some(20_00), + Some(20_00), + Some(30_00), + None, + None, + ]) + .with_precision_and_scale(precision, scale)?; + + let mut accumulator = + DecimalDistinctAvgAccumulator::::with_decimal_params( + scale, 9, 6, + ); + accumulator.update_batch(&[Arc::new(array)])?; + + let result = accumulator.evaluate()?; + let expected_result = ScalarValue::Decimal32(Some(18000000), 9, 6); + assert_eq!(result, expected_result); + + Ok(()) + } + + #[test] + fn test_decimal64_distinct_avg_accumulator() -> Result<()> { + let precision = 10_u8; + let scale = 4_i8; + let array = Decimal64Array::from(vec![ + Some(100_0000), + Some(125_0000), + Some(175_0000), + Some(200_0000), + Some(200_0000), + Some(300_0000), + None, + None, + ]) + .with_precision_and_scale(precision, scale)?; + + let mut accumulator = + DecimalDistinctAvgAccumulator::::with_decimal_params( + scale, 14, 8, + ); + accumulator.update_batch(&[Arc::new(array)])?; + + let result = accumulator.evaluate()?; + let expected_result = ScalarValue::Decimal64(Some(180_00000000), 14, 8); + assert_eq!(result, expected_result); + + Ok(()) + } + + #[test] + fn test_decimal128_distinct_avg_accumulator() -> Result<()> { + let precision = 10_u8; + let scale = 4_i8; + let array = Decimal128Array::from(vec![ + Some(100_0000), + Some(125_0000), + Some(175_0000), + Some(200_0000), + Some(200_0000), + Some(300_0000), + None, + None, + ]) + .with_precision_and_scale(precision, scale)?; + + let mut accumulator = + DecimalDistinctAvgAccumulator::::with_decimal_params( + scale, 14, 8, + ); + accumulator.update_batch(&[Arc::new(array)])?; + + let result = accumulator.evaluate()?; + let expected_result = ScalarValue::Decimal128(Some(180_00000000), 14, 8); + assert_eq!(result, expected_result); + + Ok(()) + } + + #[test] + fn test_decimal256_distinct_avg_accumulator() -> Result<()> { + let precision = 50_u8; + let scale = 2_i8; + + let array = Decimal256Array::from(vec![ + Some(i256::from_i128(10_000)), + Some(i256::from_i128(12_500)), + Some(i256::from_i128(17_500)), + Some(i256::from_i128(20_000)), + Some(i256::from_i128(20_000)), + Some(i256::from_i128(30_000)), + None, + None, + ]) + .with_precision_and_scale(precision, scale)?; + + let mut accumulator = + DecimalDistinctAvgAccumulator::::with_decimal_params( + scale, 54, 6, + ); + accumulator.update_batch(&[Arc::new(array)])?; + + let result = accumulator.evaluate()?; + let expected_result = + ScalarValue::Decimal256(Some(i256::from_i128(180_000000)), 54, 6); + assert_eq!(result, expected_result); + + Ok(()) + } +} diff --git a/datafusion/functions-aggregate-common/src/lib.rs b/datafusion/functions-aggregate-common/src/lib.rs index 203ae98fe1ed..a07ef4d597cf 100644 --- a/datafusion/functions-aggregate-common/src/lib.rs +++ b/datafusion/functions-aggregate-common/src/lib.rs @@ -26,7 +26,7 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] diff --git a/datafusion/functions-aggregate-common/src/min_max.rs b/datafusion/functions-aggregate-common/src/min_max.rs index b02001753215..dc3f44eecda9 100644 --- a/datafusion/functions-aggregate-common/src/min_max.rs +++ b/datafusion/functions-aggregate-common/src/min_max.rs @@ -19,20 +19,490 @@ use arrow::array::{ ArrayRef, AsArray as _, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, - Date64Array, Decimal128Array, Decimal256Array, DurationMicrosecondArray, - DurationMillisecondArray, DurationNanosecondArray, DurationSecondArray, - FixedSizeBinaryArray, Float16Array, Float32Array, Float64Array, Int16Array, - Int32Array, Int64Array, Int8Array, IntervalDayTimeArray, IntervalMonthDayNanoArray, - IntervalYearMonthArray, LargeBinaryArray, LargeStringArray, StringArray, - StringViewArray, Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray, - Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, - TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array, - UInt64Array, UInt8Array, + Date64Array, Decimal128Array, Decimal256Array, Decimal32Array, Decimal64Array, + DurationMicrosecondArray, DurationMillisecondArray, DurationNanosecondArray, + DurationSecondArray, FixedSizeBinaryArray, Float16Array, Float32Array, Float64Array, + Int16Array, Int32Array, Int64Array, Int8Array, IntervalDayTimeArray, + IntervalMonthDayNanoArray, IntervalYearMonthArray, LargeBinaryArray, + LargeStringArray, StringArray, StringViewArray, Time32MillisecondArray, + Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray, + TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, + TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array, }; use arrow::compute; use arrow::datatypes::{DataType, IntervalUnit, TimeUnit}; -use datafusion_common::{downcast_value, Result, ScalarValue}; -use std::cmp::Ordering; +use datafusion_common::{ + downcast_value, internal_err, DataFusionError, Result, ScalarValue, +}; +use datafusion_expr_common::accumulator::Accumulator; +use std::{cmp::Ordering, mem::size_of_val}; + +// min/max of two non-string scalar values. +macro_rules! typed_min_max { + ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident $(, $EXTRA_ARGS:ident)*) => {{ + ScalarValue::$SCALAR( + match ($VALUE, $DELTA) { + (None, None) => None, + (Some(a), None) => Some(*a), + (None, Some(b)) => Some(*b), + (Some(a), Some(b)) => Some((*a).$OP(*b)), + }, + $($EXTRA_ARGS.clone()),* + ) + }}; +} + +macro_rules! typed_min_max_float { + ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident) => {{ + ScalarValue::$SCALAR(match ($VALUE, $DELTA) { + (None, None) => None, + (Some(a), None) => Some(*a), + (None, Some(b)) => Some(*b), + (Some(a), Some(b)) => match a.total_cmp(b) { + choose_min_max!($OP) => Some(*b), + _ => Some(*a), + }, + }) + }}; +} + +// min/max of two scalar string values. +macro_rules! typed_min_max_string { + ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident) => {{ + ScalarValue::$SCALAR(match ($VALUE, $DELTA) { + (None, None) => None, + (Some(a), None) => Some(a.clone()), + (None, Some(b)) => Some(b.clone()), + (Some(a), Some(b)) => Some((a).$OP(b).clone()), + }) + }}; +} + +// min/max of two scalar string values with a prefix argument. +macro_rules! typed_min_max_string_arg { + ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident, $ARG:expr) => {{ + ScalarValue::$SCALAR( + $ARG, + match ($VALUE, $DELTA) { + (None, None) => None, + (Some(a), None) => Some(a.clone()), + (None, Some(b)) => Some(b.clone()), + (Some(a), Some(b)) => Some((a).$OP(b).clone()), + }, + ) + }}; +} + +macro_rules! choose_min_max { + (min) => { + std::cmp::Ordering::Greater + }; + (max) => { + std::cmp::Ordering::Less + }; +} + +macro_rules! interval_min_max { + ($OP:tt, $LHS:expr, $RHS:expr) => {{ + match $LHS.partial_cmp(&$RHS) { + Some(choose_min_max!($OP)) => $RHS.clone(), + Some(_) => $LHS.clone(), + None => { + return internal_err!("Comparison error while computing interval min/max") + } + } + }}; +} + +macro_rules! min_max_generic { + ($VALUE:expr, $DELTA:expr, $OP:ident) => {{ + if $VALUE.is_null() { + let mut delta_copy = $DELTA.clone(); + // When the new value won we want to compact it to + // avoid storing the entire input + delta_copy.compact(); + delta_copy + } else if $DELTA.is_null() { + $VALUE.clone() + } else { + match $VALUE.partial_cmp(&$DELTA) { + Some(choose_min_max!($OP)) => { + // When the new value won we want to compact it to + // avoid storing the entire input + let mut delta_copy = $DELTA.clone(); + delta_copy.compact(); + delta_copy + } + _ => $VALUE.clone(), + } + } + }}; +} + +// min/max of two scalar values of the same type +macro_rules! min_max { + ($VALUE:expr, $DELTA:expr, $OP:ident) => {{ + Ok(match ($VALUE, $DELTA) { + (ScalarValue::Null, ScalarValue::Null) => ScalarValue::Null, + ( + lhs @ ScalarValue::Decimal32(lhsv, lhsp, lhss), + rhs @ ScalarValue::Decimal32(rhsv, rhsp, rhss) + ) => { + if lhsp.eq(rhsp) && lhss.eq(rhss) { + typed_min_max!(lhsv, rhsv, Decimal32, $OP, lhsp, lhss) + } else { + return internal_err!( + "MIN/MAX is not expected to receive scalars of incompatible types {:?}", + (lhs, rhs) + ); + } + } + ( + lhs @ ScalarValue::Decimal64(lhsv, lhsp, lhss), + rhs @ ScalarValue::Decimal64(rhsv, rhsp, rhss) + ) => { + if lhsp.eq(rhsp) && lhss.eq(rhss) { + typed_min_max!(lhsv, rhsv, Decimal64, $OP, lhsp, lhss) + } else { + return internal_err!( + "MIN/MAX is not expected to receive scalars of incompatible types {:?}", + (lhs, rhs) + ); + } + } + ( + lhs @ ScalarValue::Decimal128(lhsv, lhsp, lhss), + rhs @ ScalarValue::Decimal128(rhsv, rhsp, rhss) + ) => { + if lhsp.eq(rhsp) && lhss.eq(rhss) { + typed_min_max!(lhsv, rhsv, Decimal128, $OP, lhsp, lhss) + } else { + return internal_err!( + "MIN/MAX is not expected to receive scalars of incompatible types {:?}", + (lhs, rhs) + ); + } + } + ( + lhs @ ScalarValue::Decimal256(lhsv, lhsp, lhss), + rhs @ ScalarValue::Decimal256(rhsv, rhsp, rhss) + ) => { + if lhsp.eq(rhsp) && lhss.eq(rhss) { + typed_min_max!(lhsv, rhsv, Decimal256, $OP, lhsp, lhss) + } else { + return internal_err!( + "MIN/MAX is not expected to receive scalars of incompatible types {:?}", + (lhs, rhs) + ); + } + } + (ScalarValue::Boolean(lhs), ScalarValue::Boolean(rhs)) => { + typed_min_max!(lhs, rhs, Boolean, $OP) + } + (ScalarValue::Float64(lhs), ScalarValue::Float64(rhs)) => { + typed_min_max_float!(lhs, rhs, Float64, $OP) + } + (ScalarValue::Float32(lhs), ScalarValue::Float32(rhs)) => { + typed_min_max_float!(lhs, rhs, Float32, $OP) + } + (ScalarValue::Float16(lhs), ScalarValue::Float16(rhs)) => { + typed_min_max_float!(lhs, rhs, Float16, $OP) + } + (ScalarValue::UInt64(lhs), ScalarValue::UInt64(rhs)) => { + typed_min_max!(lhs, rhs, UInt64, $OP) + } + (ScalarValue::UInt32(lhs), ScalarValue::UInt32(rhs)) => { + typed_min_max!(lhs, rhs, UInt32, $OP) + } + (ScalarValue::UInt16(lhs), ScalarValue::UInt16(rhs)) => { + typed_min_max!(lhs, rhs, UInt16, $OP) + } + (ScalarValue::UInt8(lhs), ScalarValue::UInt8(rhs)) => { + typed_min_max!(lhs, rhs, UInt8, $OP) + } + (ScalarValue::Int64(lhs), ScalarValue::Int64(rhs)) => { + typed_min_max!(lhs, rhs, Int64, $OP) + } + (ScalarValue::Int32(lhs), ScalarValue::Int32(rhs)) => { + typed_min_max!(lhs, rhs, Int32, $OP) + } + (ScalarValue::Int16(lhs), ScalarValue::Int16(rhs)) => { + typed_min_max!(lhs, rhs, Int16, $OP) + } + (ScalarValue::Int8(lhs), ScalarValue::Int8(rhs)) => { + typed_min_max!(lhs, rhs, Int8, $OP) + } + (ScalarValue::Utf8(lhs), ScalarValue::Utf8(rhs)) => { + typed_min_max_string!(lhs, rhs, Utf8, $OP) + } + (ScalarValue::LargeUtf8(lhs), ScalarValue::LargeUtf8(rhs)) => { + typed_min_max_string!(lhs, rhs, LargeUtf8, $OP) + } + (ScalarValue::Utf8View(lhs), ScalarValue::Utf8View(rhs)) => { + typed_min_max_string!(lhs, rhs, Utf8View, $OP) + } + (ScalarValue::Binary(lhs), ScalarValue::Binary(rhs)) => { + typed_min_max_string!(lhs, rhs, Binary, $OP) + } + (ScalarValue::LargeBinary(lhs), ScalarValue::LargeBinary(rhs)) => { + typed_min_max_string!(lhs, rhs, LargeBinary, $OP) + } + (ScalarValue::FixedSizeBinary(lsize, lhs), ScalarValue::FixedSizeBinary(rsize, rhs)) => { + if lsize == rsize { + typed_min_max_string_arg!(lhs, rhs, FixedSizeBinary, $OP, *lsize) + } + else { + return internal_err!( + "MIN/MAX is not expected to receive FixedSizeBinary of incompatible sizes {:?}", + (lsize, rsize)) + } + } + (ScalarValue::BinaryView(lhs), ScalarValue::BinaryView(rhs)) => { + typed_min_max_string!(lhs, rhs, BinaryView, $OP) + } + (ScalarValue::TimestampSecond(lhs, l_tz), ScalarValue::TimestampSecond(rhs, _)) => { + typed_min_max!(lhs, rhs, TimestampSecond, $OP, l_tz) + } + ( + ScalarValue::TimestampMillisecond(lhs, l_tz), + ScalarValue::TimestampMillisecond(rhs, _), + ) => { + typed_min_max!(lhs, rhs, TimestampMillisecond, $OP, l_tz) + } + ( + ScalarValue::TimestampMicrosecond(lhs, l_tz), + ScalarValue::TimestampMicrosecond(rhs, _), + ) => { + typed_min_max!(lhs, rhs, TimestampMicrosecond, $OP, l_tz) + } + ( + ScalarValue::TimestampNanosecond(lhs, l_tz), + ScalarValue::TimestampNanosecond(rhs, _), + ) => { + typed_min_max!(lhs, rhs, TimestampNanosecond, $OP, l_tz) + } + ( + ScalarValue::Date32(lhs), + ScalarValue::Date32(rhs), + ) => { + typed_min_max!(lhs, rhs, Date32, $OP) + } + ( + ScalarValue::Date64(lhs), + ScalarValue::Date64(rhs), + ) => { + typed_min_max!(lhs, rhs, Date64, $OP) + } + ( + ScalarValue::Time32Second(lhs), + ScalarValue::Time32Second(rhs), + ) => { + typed_min_max!(lhs, rhs, Time32Second, $OP) + } + ( + ScalarValue::Time32Millisecond(lhs), + ScalarValue::Time32Millisecond(rhs), + ) => { + typed_min_max!(lhs, rhs, Time32Millisecond, $OP) + } + ( + ScalarValue::Time64Microsecond(lhs), + ScalarValue::Time64Microsecond(rhs), + ) => { + typed_min_max!(lhs, rhs, Time64Microsecond, $OP) + } + ( + ScalarValue::Time64Nanosecond(lhs), + ScalarValue::Time64Nanosecond(rhs), + ) => { + typed_min_max!(lhs, rhs, Time64Nanosecond, $OP) + } + ( + ScalarValue::IntervalYearMonth(lhs), + ScalarValue::IntervalYearMonth(rhs), + ) => { + typed_min_max!(lhs, rhs, IntervalYearMonth, $OP) + } + ( + ScalarValue::IntervalMonthDayNano(lhs), + ScalarValue::IntervalMonthDayNano(rhs), + ) => { + typed_min_max!(lhs, rhs, IntervalMonthDayNano, $OP) + } + ( + ScalarValue::IntervalDayTime(lhs), + ScalarValue::IntervalDayTime(rhs), + ) => { + typed_min_max!(lhs, rhs, IntervalDayTime, $OP) + } + ( + ScalarValue::IntervalYearMonth(_), + ScalarValue::IntervalMonthDayNano(_), + ) | ( + ScalarValue::IntervalYearMonth(_), + ScalarValue::IntervalDayTime(_), + ) | ( + ScalarValue::IntervalMonthDayNano(_), + ScalarValue::IntervalDayTime(_), + ) | ( + ScalarValue::IntervalMonthDayNano(_), + ScalarValue::IntervalYearMonth(_), + ) | ( + ScalarValue::IntervalDayTime(_), + ScalarValue::IntervalYearMonth(_), + ) | ( + ScalarValue::IntervalDayTime(_), + ScalarValue::IntervalMonthDayNano(_), + ) => { + interval_min_max!($OP, $VALUE, $DELTA) + } + ( + ScalarValue::DurationSecond(lhs), + ScalarValue::DurationSecond(rhs), + ) => { + typed_min_max!(lhs, rhs, DurationSecond, $OP) + } + ( + ScalarValue::DurationMillisecond(lhs), + ScalarValue::DurationMillisecond(rhs), + ) => { + typed_min_max!(lhs, rhs, DurationMillisecond, $OP) + } + ( + ScalarValue::DurationMicrosecond(lhs), + ScalarValue::DurationMicrosecond(rhs), + ) => { + typed_min_max!(lhs, rhs, DurationMicrosecond, $OP) + } + ( + ScalarValue::DurationNanosecond(lhs), + ScalarValue::DurationNanosecond(rhs), + ) => { + typed_min_max!(lhs, rhs, DurationNanosecond, $OP) + } + + ( + lhs @ ScalarValue::Struct(_), + rhs @ ScalarValue::Struct(_), + ) => { + min_max_generic!(lhs, rhs, $OP) + } + + ( + lhs @ ScalarValue::List(_), + rhs @ ScalarValue::List(_), + ) => { + min_max_generic!(lhs, rhs, $OP) + } + + + ( + lhs @ ScalarValue::LargeList(_), + rhs @ ScalarValue::LargeList(_), + ) => { + min_max_generic!(lhs, rhs, $OP) + } + + + ( + lhs @ ScalarValue::FixedSizeList(_), + rhs @ ScalarValue::FixedSizeList(_), + ) => { + min_max_generic!(lhs, rhs, $OP) + } + + e => { + return internal_err!( + "MIN/MAX is not expected to receive scalars of incompatible types {:?}", + e + ) + } + }) + }}; +} + +/// An accumulator to compute the maximum value +#[derive(Debug, Clone)] +pub struct MaxAccumulator { + max: ScalarValue, +} + +impl MaxAccumulator { + /// new max accumulator + pub fn try_new(datatype: &DataType) -> Result { + Ok(Self { + max: ScalarValue::try_from(datatype)?, + }) + } +} + +impl Accumulator for MaxAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values = &values[0]; + let delta = &max_batch(values)?; + let new_max: Result = + min_max!(&self.max, delta, max); + self.max = new_max?; + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.update_batch(states) + } + + fn state(&mut self) -> Result> { + Ok(vec![self.evaluate()?]) + } + fn evaluate(&mut self) -> Result { + Ok(self.max.clone()) + } + + fn size(&self) -> usize { + size_of_val(self) - size_of_val(&self.max) + self.max.size() + } +} + +/// An accumulator to compute the minimum value +#[derive(Debug, Clone)] +pub struct MinAccumulator { + min: ScalarValue, +} + +impl MinAccumulator { + /// new min accumulator + pub fn try_new(datatype: &DataType) -> Result { + Ok(Self { + min: ScalarValue::try_from(datatype)?, + }) + } +} + +impl Accumulator for MinAccumulator { + fn state(&mut self) -> Result> { + Ok(vec![self.evaluate()?]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values = &values[0]; + let delta = &min_batch(values)?; + let new_min: Result = + min_max!(&self.min, delta, min); + self.min = new_min?; + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.update_batch(states) + } + + fn evaluate(&mut self) -> Result { + Ok(self.min.clone()) + } + + fn size(&self) -> usize { + size_of_val(self) - size_of_val(&self.min) + self.min.size() + } +} // Statically-typed version of min/max(array) -> ScalarValue for string types macro_rules! typed_min_max_batch_string { @@ -69,6 +539,26 @@ macro_rules! min_max_batch { ($VALUES:expr, $OP:ident) => {{ match $VALUES.data_type() { DataType::Null => ScalarValue::Null, + DataType::Decimal32(precision, scale) => { + typed_min_max_batch!( + $VALUES, + Decimal32Array, + Decimal32, + $OP, + precision, + scale + ) + } + DataType::Decimal64(precision, scale) => { + typed_min_max_batch!( + $VALUES, + Decimal64Array, + Decimal64, + $OP, + precision, + scale + ) + } DataType::Decimal128(precision, scale) => { typed_min_max_batch!( $VALUES, diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs index f7cb74fd55a2..5aaa7fc224ec 100644 --- a/datafusion/functions-aggregate/src/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -24,9 +24,11 @@ use arrow::array::{ use arrow::compute::sum; use arrow::datatypes::{ - i256, ArrowNativeType, DataType, Decimal128Type, Decimal256Type, DecimalType, - DurationMicrosecondType, DurationMillisecondType, DurationNanosecondType, - DurationSecondType, Field, FieldRef, Float64Type, TimeUnit, UInt64Type, + i256, ArrowNativeType, DataType, Decimal128Type, Decimal256Type, Decimal32Type, + Decimal64Type, DecimalType, DurationMicrosecondType, DurationMillisecondType, + DurationNanosecondType, DurationSecondType, Field, FieldRef, Float64Type, TimeUnit, + UInt64Type, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, + DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION, }; use datafusion_common::{ exec_err, not_impl_err, utils::take_function_args, Result, ScalarValue, @@ -40,7 +42,9 @@ use datafusion_expr::{ ReversedUDAF, Signature, }; -use datafusion_functions_aggregate_common::aggregate::avg_distinct::Float64DistinctAvgAccumulator; +use datafusion_functions_aggregate_common::aggregate::avg_distinct::{ + DecimalDistinctAvgAccumulator, Float64DistinctAvgAccumulator, +}; use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::NullState; use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::{ filtered_null_mask, set_nulls, @@ -120,14 +124,75 @@ impl AggregateUDFImpl for Avg { // instantiate specialized accumulator based for the type if acc_args.is_distinct { - match &data_type { + match (&data_type, acc_args.return_type()) { // Numeric types are converted to Float64 via `coerce_avg_type` during logical plan creation - Float64 => Ok(Box::new(Float64DistinctAvgAccumulator::default())), - _ => exec_err!("AVG(DISTINCT) for {} not supported", data_type), + (Float64, _) => Ok(Box::new(Float64DistinctAvgAccumulator::default())), + + ( + Decimal32(_, scale), + Decimal32(target_precision, target_scale), + ) => Ok(Box::new(DecimalDistinctAvgAccumulator::::with_decimal_params( + *scale, + *target_precision, + *target_scale, + ))), + ( + Decimal64(_, scale), + Decimal64(target_precision, target_scale), + ) => Ok(Box::new(DecimalDistinctAvgAccumulator::::with_decimal_params( + *scale, + *target_precision, + *target_scale, + ))), + ( + Decimal128(_, scale), + Decimal128(target_precision, target_scale), + ) => Ok(Box::new(DecimalDistinctAvgAccumulator::::with_decimal_params( + *scale, + *target_precision, + *target_scale, + ))), + + ( + Decimal256(_, scale), + Decimal256(target_precision, target_scale), + ) => Ok(Box::new(DecimalDistinctAvgAccumulator::::with_decimal_params( + *scale, + *target_precision, + *target_scale, + ))), + + (dt, return_type) => exec_err!( + "AVG(DISTINCT) for ({} --> {}) not supported", + dt, + return_type + ), } } else { - match (&data_type, acc_args.return_field.data_type()) { + match (&data_type, acc_args.return_type()) { (Float64, Float64) => Ok(Box::::default()), + ( + Decimal32(sum_precision, sum_scale), + Decimal32(target_precision, target_scale), + ) => Ok(Box::new(DecimalAvgAccumulator:: { + sum: None, + count: 0, + sum_scale: *sum_scale, + sum_precision: *sum_precision, + target_precision: *target_precision, + target_scale: *target_scale, + })), + ( + Decimal64(sum_precision, sum_scale), + Decimal64(target_precision, target_scale), + ) => Ok(Box::new(DecimalAvgAccumulator:: { + sum: None, + count: 0, + sum_scale: *sum_scale, + sum_precision: *sum_precision, + target_precision: *target_precision, + target_scale: *target_scale, + })), ( Decimal128(sum_precision, sum_scale), Decimal128(target_precision, target_scale), @@ -161,22 +226,37 @@ impl AggregateUDFImpl for Avg { })) } - _ => exec_err!( - "AvgAccumulator for ({} --> {})", - &data_type, - acc_args.return_field.data_type() - ), + (dt, return_type) => { + exec_err!("AvgAccumulator for ({} --> {})", dt, return_type) + } } } } fn state_fields(&self, args: StateFieldsArgs) -> Result> { if args.is_distinct { - // Copied from datafusion_functions_aggregate::sum::Sum::state_fields + // Decimal accumulator actually uses a different precision during accumulation, + // see DecimalDistinctAvgAccumulator::with_decimal_params + let dt = match args.input_fields[0].data_type() { + DataType::Decimal32(_, scale) => { + DataType::Decimal32(DECIMAL32_MAX_PRECISION, *scale) + } + DataType::Decimal64(_, scale) => { + DataType::Decimal64(DECIMAL64_MAX_PRECISION, *scale) + } + DataType::Decimal128(_, scale) => { + DataType::Decimal128(DECIMAL128_MAX_PRECISION, *scale) + } + DataType::Decimal256(_, scale) => { + DataType::Decimal256(DECIMAL256_MAX_PRECISION, *scale) + } + _ => args.return_type().clone(), + }; + // Similar to datafusion_functions_aggregate::sum::Sum::state_fields // since the accumulator uses DistinctSumAccumulator internally. Ok(vec![Field::new_list( format_state_name(args.name, "avg distinct"), - Field::new_list_field(args.return_type().clone(), true), + Field::new_list_field(dt, true), false, ) .into()]) @@ -202,7 +282,12 @@ impl AggregateUDFImpl for Avg { fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { matches!( args.return_field.data_type(), - DataType::Float64 | DataType::Decimal128(_, _) | DataType::Duration(_) + DataType::Float64 + | DataType::Decimal32(_, _) + | DataType::Decimal64(_, _) + | DataType::Decimal128(_, _) + | DataType::Decimal256(_, _) + | DataType::Duration(_) ) && !args.is_distinct } @@ -222,6 +307,44 @@ impl AggregateUDFImpl for Avg { |sum: f64, count: u64| Ok(sum / count as f64), ))) } + ( + Decimal32(_sum_precision, sum_scale), + Decimal32(target_precision, target_scale), + ) => { + let decimal_averager = DecimalAverager::::try_new( + *sum_scale, + *target_precision, + *target_scale, + )?; + + let avg_fn = + move |sum: i32, count: u64| decimal_averager.avg(sum, count as i32); + + Ok(Box::new(AvgGroupsAccumulator::::new( + &data_type, + args.return_field.data_type(), + avg_fn, + ))) + } + ( + Decimal64(_sum_precision, sum_scale), + Decimal64(target_precision, target_scale), + ) => { + let decimal_averager = DecimalAverager::::try_new( + *sum_scale, + *target_precision, + *target_scale, + )?; + + let avg_fn = + move |sum: i64, count: u64| decimal_averager.avg(sum, count as i64); + + Ok(Box::new(AvgGroupsAccumulator::::new( + &data_type, + args.return_field.data_type(), + avg_fn, + ))) + } ( Decimal128(_sum_precision, sum_scale), Decimal128(target_precision, target_scale), diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index 6ef1332ba003..28755427c732 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -30,12 +30,12 @@ use arrow::array::{ use arrow::buffer::{BooleanBuffer, NullBuffer}; use arrow::compute::{self, LexicographicalComparator, SortColumn, SortOptions}; use arrow::datatypes::{ - DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Field, FieldRef, - Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, - Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType, - TimeUnit, TimestampMicrosecondType, TimestampMillisecondType, - TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, - UInt8Type, + DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Decimal32Type, + Decimal64Type, Field, FieldRef, Float16Type, Float32Type, Float64Type, Int16Type, + Int32Type, Int64Type, Int8Type, Time32MillisecondType, Time32SecondType, + Time64MicrosecondType, Time64NanosecondType, TimeUnit, TimestampMicrosecondType, + TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type, + UInt32Type, UInt64Type, UInt8Type, }; use datafusion_common::cast::as_boolean_array; use datafusion_common::utils::{compare_rows, extract_row_at_idx_to_buf, get_row_at_idx}; @@ -185,6 +185,8 @@ impl AggregateUDFImpl for FirstValue { | Float16 | Float32 | Float64 + | Decimal32(_, _) + | Decimal64(_, _) | Decimal128(_, _) | Decimal256(_, _) | Date32 @@ -234,6 +236,8 @@ impl AggregateUDFImpl for FirstValue { DataType::Float32 => create_accumulator::(args), DataType::Float64 => create_accumulator::(args), + DataType::Decimal32(_, _) => create_accumulator::(args), + DataType::Decimal64(_, _) => create_accumulator::(args), DataType::Decimal128(_, _) => create_accumulator::(args), DataType::Decimal256(_, _) => create_accumulator::(args), @@ -1124,6 +1128,8 @@ impl AggregateUDFImpl for LastValue { | Float16 | Float32 | Float64 + | Decimal32(_, _) + | Decimal64(_, _) | Decimal128(_, _) | Decimal256(_, _) | Date32 @@ -1175,6 +1181,8 @@ impl AggregateUDFImpl for LastValue { DataType::Float32 => create_accumulator::(args), DataType::Float64 => create_accumulator::(args), + DataType::Decimal32(_, _) => create_accumulator::(args), + DataType::Decimal64(_, _) => create_accumulator::(args), DataType::Decimal128(_, _) => create_accumulator::(args), DataType::Decimal256(_, _) => create_accumulator::(args), diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index b5bb69f6da9d..4ad0551d16a7 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -19,7 +19,7 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] diff --git a/datafusion/functions-aggregate/src/median.rs b/datafusion/functions-aggregate/src/median.rs index a73ccbd99bc1..a65759594eac 100644 --- a/datafusion/functions-aggregate/src/median.rs +++ b/datafusion/functions-aggregate/src/median.rs @@ -35,7 +35,9 @@ use arrow::{ use arrow::array::Array; use arrow::array::ArrowNativeTypeOp; -use arrow::datatypes::{ArrowNativeType, ArrowPrimitiveType, FieldRef}; +use arrow::datatypes::{ + ArrowNativeType, ArrowPrimitiveType, Decimal32Type, Decimal64Type, FieldRef, +}; use datafusion_common::{ internal_datafusion_err, internal_err, DataFusionError, HashSet, Result, ScalarValue, @@ -166,6 +168,8 @@ impl AggregateUDFImpl for Median { DataType::Float16 => helper!(Float16Type, dt), DataType::Float32 => helper!(Float32Type, dt), DataType::Float64 => helper!(Float64Type, dt), + DataType::Decimal32(_, _) => helper!(Decimal32Type, dt), + DataType::Decimal64(_, _) => helper!(Decimal64Type, dt), DataType::Decimal128(_, _) => helper!(Decimal128Type, dt), DataType::Decimal256(_, _) => helper!(Decimal256Type, dt), _ => Err(DataFusionError::NotImplemented(format!( @@ -205,6 +209,8 @@ impl AggregateUDFImpl for Median { DataType::Float16 => helper!(Float16Type, dt), DataType::Float32 => helper!(Float32Type, dt), DataType::Float64 => helper!(Float64Type, dt), + DataType::Decimal32(_, _) => helper!(Decimal32Type, dt), + DataType::Decimal64(_, _) => helper!(Decimal64Type, dt), DataType::Decimal128(_, _) => helper!(Decimal128Type, dt), DataType::Decimal256(_, _) => helper!(Decimal256Type, dt), _ => Err(DataFusionError::NotImplemented(format!( diff --git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs index 1edf10dfee30..d839c6f023c4 100644 --- a/datafusion/functions-aggregate/src/min_max.rs +++ b/datafusion/functions-aggregate/src/min_max.rs @@ -23,10 +23,10 @@ mod min_max_struct; use arrow::array::ArrayRef; use arrow::datatypes::{ - DataType, Decimal128Type, Decimal256Type, DurationMicrosecondType, - DurationMillisecondType, DurationNanosecondType, DurationSecondType, Float16Type, - Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, - UInt32Type, UInt64Type, UInt8Type, + DataType, Decimal128Type, Decimal256Type, Decimal32Type, Decimal64Type, + DurationMicrosecondType, DurationMillisecondType, DurationNanosecondType, + DurationSecondType, Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, + Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type, }; use datafusion_common::stats::Precision; use datafusion_common::{ @@ -242,6 +242,8 @@ impl AggregateUDFImpl for Max { | Float16 | Float32 | Float64 + | Decimal32(_, _) + | Decimal64(_, _) | Decimal128(_, _) | Decimal256(_, _) | Date32 @@ -323,6 +325,12 @@ impl AggregateUDFImpl for Max { Duration(Nanosecond) => { primitive_max_accumulator!(data_type, i64, DurationNanosecondType) } + Decimal32(_, _) => { + primitive_max_accumulator!(data_type, i32, Decimal32Type) + } + Decimal64(_, _) => { + primitive_max_accumulator!(data_type, i64, Decimal64Type) + } Decimal128(_, _) => { primitive_max_accumulator!(data_type, i128, Decimal128Type) } @@ -919,6 +927,8 @@ impl AggregateUDFImpl for Min { | Float16 | Float32 | Float64 + | Decimal32(_, _) + | Decimal64(_, _) | Decimal128(_, _) | Decimal256(_, _) | Date32 @@ -1000,6 +1010,12 @@ impl AggregateUDFImpl for Min { Duration(Nanosecond) => { primitive_min_accumulator!(data_type, i64, DurationNanosecondType) } + Decimal32(_, _) => { + primitive_min_accumulator!(data_type, i32, Decimal32Type) + } + Decimal64(_, _) => { + primitive_min_accumulator!(data_type, i64, Decimal64Type) + } Decimal128(_, _) => { primitive_min_accumulator!(data_type, i128, Decimal128Type) } diff --git a/datafusion/functions-aggregate/src/sum.rs b/datafusion/functions-aggregate/src/sum.rs index 445c7dfe6b7a..82ce44c19401 100644 --- a/datafusion/functions-aggregate/src/sum.rs +++ b/datafusion/functions-aggregate/src/sum.rs @@ -18,6 +18,8 @@ //! Defines `SUM` and `SUM DISTINCT` aggregate accumulators use ahash::RandomState; +use arrow::datatypes::DECIMAL32_MAX_PRECISION; +use arrow::datatypes::DECIMAL64_MAX_PRECISION; use datafusion_expr::utils::AggregateOrderSensitivity; use std::any::Any; use std::mem::size_of_val; @@ -27,8 +29,8 @@ use arrow::array::ArrowNativeTypeOp; use arrow::array::{ArrowNumericType, AsArray}; use arrow::datatypes::{ArrowNativeType, FieldRef}; use arrow::datatypes::{ - DataType, Decimal128Type, Decimal256Type, Float64Type, Int64Type, UInt64Type, - DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, + DataType, Decimal128Type, Decimal256Type, Decimal32Type, Decimal64Type, Float64Type, + Int64Type, UInt64Type, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, }; use arrow::{array::ArrayRef, datatypes::Field}; use datafusion_common::{ @@ -71,6 +73,12 @@ macro_rules! downcast_sum { DataType::Float64 => { $helper!(Float64Type, $args.return_field.data_type().clone()) } + DataType::Decimal32(_, _) => { + $helper!(Decimal32Type, $args.return_field.data_type().clone()) + } + DataType::Decimal64(_, _) => { + $helper!(Decimal64Type, $args.return_field.data_type().clone()) + } DataType::Decimal128(_, _) => { $helper!(Decimal128Type, $args.return_field.data_type().clone()) } @@ -145,9 +153,10 @@ impl AggregateUDFImpl for Sum { DataType::Dictionary(_, v) => coerced_type(v), // in the spark, the result type is DECIMAL(min(38,precision+10), s) // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 - DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => { - Ok(data_type.clone()) - } + DataType::Decimal32(_, _) + | DataType::Decimal64(_, _) + | DataType::Decimal128(_, _) + | DataType::Decimal256(_, _) => Ok(data_type.clone()), dt if dt.is_signed_integer() => Ok(DataType::Int64), dt if dt.is_unsigned_integer() => Ok(DataType::UInt64), dt if dt.is_floating() => Ok(DataType::Float64), @@ -163,6 +172,18 @@ impl AggregateUDFImpl for Sum { DataType::Int64 => Ok(DataType::Int64), DataType::UInt64 => Ok(DataType::UInt64), DataType::Float64 => Ok(DataType::Float64), + DataType::Decimal32(precision, scale) => { + // in the spark, the result type is DECIMAL(min(38,precision+10), s) + // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 + let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 10); + Ok(DataType::Decimal32(new_precision, *scale)) + } + DataType::Decimal64(precision, scale) => { + // in the spark, the result type is DECIMAL(min(38,precision+10), s) + // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 + let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 10); + Ok(DataType::Decimal64(new_precision, *scale)) + } DataType::Decimal128(precision, scale) => { // in the spark, the result type is DECIMAL(min(38,precision+10), s) // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 diff --git a/datafusion/functions-nested/src/lib.rs b/datafusion/functions-nested/src/lib.rs index 1d3f11b50c61..0a549fb294c6 100644 --- a/datafusion/functions-nested/src/lib.rs +++ b/datafusion/functions-nested/src/lib.rs @@ -19,7 +19,7 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] diff --git a/datafusion/functions-table/src/lib.rs b/datafusion/functions-table/src/lib.rs index 36fcdc7ede56..b339a8f4a52f 100644 --- a/datafusion/functions-table/src/lib.rs +++ b/datafusion/functions-table/src/lib.rs @@ -19,7 +19,7 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![cfg_attr(not(test), deny(clippy::clone_on_ref_ptr))] diff --git a/datafusion/functions-window-common/src/lib.rs b/datafusion/functions-window-common/src/lib.rs index 7f668a20a76a..76341239f6a5 100644 --- a/datafusion/functions-window-common/src/lib.rs +++ b/datafusion/functions-window-common/src/lib.rs @@ -19,7 +19,7 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![cfg_attr(not(test), deny(clippy::clone_on_ref_ptr))] diff --git a/datafusion/functions-window/src/lib.rs b/datafusion/functions-window/src/lib.rs index 10e09542d7c5..139ace4bf709 100644 --- a/datafusion/functions-window/src/lib.rs +++ b/datafusion/functions-window/src/lib.rs @@ -19,7 +19,7 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![cfg_attr(not(test), deny(clippy::clone_on_ref_ptr))] diff --git a/datafusion/functions/src/lib.rs b/datafusion/functions/src/lib.rs index 51cd5df8060d..6a3ff624c0c2 100644 --- a/datafusion/functions/src/lib.rs +++ b/datafusion/functions/src/lib.rs @@ -19,7 +19,7 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] diff --git a/datafusion/macros/src/user_doc.rs b/datafusion/macros/src/user_doc.rs index 31cf9bb1b750..c1c08157aff2 100644 --- a/datafusion/macros/src/user_doc.rs +++ b/datafusion/macros/src/user_doc.rs @@ -19,7 +19,7 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] extern crate proc_macro; use datafusion_expr::scalar_doc_sections::doc_sections_const; diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs index 280010e3d92c..85fa9493f449 100644 --- a/datafusion/optimizer/src/lib.rs +++ b/datafusion/optimizer/src/lib.rs @@ -19,7 +19,7 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] diff --git a/datafusion/physical-expr-adapter/src/lib.rs b/datafusion/physical-expr-adapter/src/lib.rs index 025f1b4b6385..12ea0025e266 100644 --- a/datafusion/physical-expr-adapter/src/lib.rs +++ b/datafusion/physical-expr-adapter/src/lib.rs @@ -19,7 +19,7 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] //! Physical expression schema adaptation utilities for DataFusion diff --git a/datafusion/physical-expr-common/src/lib.rs b/datafusion/physical-expr-common/src/lib.rs index 86d4487f4c12..e21206d90642 100644 --- a/datafusion/physical-expr-common/src/lib.rs +++ b/datafusion/physical-expr-common/src/lib.rs @@ -19,7 +19,7 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index 46f7b30d01aa..50a8e109ae1d 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -19,7 +19,7 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] diff --git a/datafusion/physical-optimizer/src/lib.rs b/datafusion/physical-optimizer/src/lib.rs index 2e56e2cdb31d..4dc95bddb792 100644 --- a/datafusion/physical-optimizer/src/lib.rs +++ b/datafusion/physical-optimizer/src/lib.rs @@ -19,7 +19,7 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] diff --git a/datafusion/physical-plan/src/lib.rs b/datafusion/physical-plan/src/lib.rs index afe61541fc45..17628fd8ad1d 100644 --- a/datafusion/physical-plan/src/lib.rs +++ b/datafusion/physical-plan/src/lib.rs @@ -19,7 +19,7 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] diff --git a/datafusion/proto-common/src/from_proto/mod.rs b/datafusion/proto-common/src/from_proto/mod.rs index c5242d0176e6..902b148a7d64 100644 --- a/datafusion/proto-common/src/from_proto/mod.rs +++ b/datafusion/proto-common/src/from_proto/mod.rs @@ -37,7 +37,6 @@ use datafusion_common::{ TableParquetOptions, }, file_options::{csv_writer::CsvWriterOptions, json_writer::JsonWriterOptions}, - not_impl_err, parsers::CompressionTypeVariant, plan_datafusion_err, stats::Precision, @@ -478,13 +477,13 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { let null_type: DataType = v.try_into()?; null_type.try_into().map_err(Error::DataFusionError)? } - Value::Decimal32Value(_val) => { - return not_impl_err!("Decimal32 protobuf deserialization") - .map_err(Error::DataFusionError) + Value::Decimal32Value(val) => { + let array = vec_to_array(val.value.clone()); + Self::Decimal32(Some(i32::from_be_bytes(array)), val.p as u8, val.s as i8) } - Value::Decimal64Value(_val) => { - return not_impl_err!("Decimal64 protobuf deserialization") - .map_err(Error::DataFusionError) + Value::Decimal64Value(val) => { + let array = vec_to_array(val.value.clone()); + Self::Decimal64(Some(i64::from_be_bytes(array)), val.p as u8, val.s as i8) } Value::Decimal128Value(val) => { let array = vec_to_array(val.value.clone()); diff --git a/datafusion/proto-common/src/lib.rs b/datafusion/proto-common/src/lib.rs index 6400e4bdc66d..9efb234e3994 100644 --- a/datafusion/proto-common/src/lib.rs +++ b/datafusion/proto-common/src/lib.rs @@ -19,7 +19,7 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] diff --git a/datafusion/proto-common/src/to_proto/mod.rs b/datafusion/proto-common/src/to_proto/mod.rs index c06427065733..da88b130b11a 100644 --- a/datafusion/proto-common/src/to_proto/mod.rs +++ b/datafusion/proto-common/src/to_proto/mod.rs @@ -405,6 +405,42 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { }) }) } + ScalarValue::Decimal32(val, p, s) => match *val { + Some(v) => { + let array = v.to_be_bytes(); + let vec_val: Vec = array.to_vec(); + Ok(protobuf::ScalarValue { + value: Some(Value::Decimal32Value(protobuf::Decimal32 { + value: vec_val, + p: *p as i64, + s: *s as i64, + })), + }) + } + None => Ok(protobuf::ScalarValue { + value: Some(protobuf::scalar_value::Value::NullValue( + (&data_type).try_into()?, + )), + }), + }, + ScalarValue::Decimal64(val, p, s) => match *val { + Some(v) => { + let array = v.to_be_bytes(); + let vec_val: Vec = array.to_vec(); + Ok(protobuf::ScalarValue { + value: Some(Value::Decimal64Value(protobuf::Decimal64 { + value: vec_val, + p: *p as i64, + s: *s as i64, + })), + }) + } + None => Ok(protobuf::ScalarValue { + value: Some(protobuf::scalar_value::Value::NullValue( + (&data_type).try_into()?, + )), + }), + }, ScalarValue::Decimal128(val, p, s) => match *val { Some(v) => { let array = v.to_be_bytes(); diff --git a/datafusion/proto/src/lib.rs b/datafusion/proto/src/lib.rs index b4d72aa1b6cb..1594564c650b 100644 --- a/datafusion/proto/src/lib.rs +++ b/datafusion/proto/src/lib.rs @@ -19,7 +19,7 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] diff --git a/datafusion/spark/src/lib.rs b/datafusion/spark/src/lib.rs index bec7d90062eb..1217b81e5a25 100644 --- a/datafusion/spark/src/lib.rs +++ b/datafusion/spark/src/lib.rs @@ -19,7 +19,7 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] // Make cheap clones clear: https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] diff --git a/datafusion/sql/src/lib.rs b/datafusion/sql/src/lib.rs index 7e11f160a397..da15b90d22a8 100644 --- a/datafusion/sql/src/lib.rs +++ b/datafusion/sql/src/lib.rs @@ -19,7 +19,7 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index d4e911a62c09..96852a0de221 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -35,7 +35,9 @@ use arrow::array::{ }, ArrayRef, Date32Array, Date64Array, PrimitiveArray, }; -use arrow::datatypes::{DataType, Decimal128Type, Decimal256Type, DecimalType}; +use arrow::datatypes::{ + DataType, Decimal128Type, Decimal256Type, Decimal32Type, Decimal64Type, DecimalType, +}; use arrow::util::display::array_value_to_string; use datafusion_common::{ internal_datafusion_err, internal_err, not_impl_err, plan_err, Column, Result, @@ -1182,6 +1184,20 @@ impl Unparser<'_> { Ok(ast::Expr::value(ast::Value::Number(f_val, false))) } ScalarValue::Float64(None) => Ok(ast::Expr::value(ast::Value::Null)), + ScalarValue::Decimal32(Some(value), precision, scale) => { + Ok(ast::Expr::value(ast::Value::Number( + Decimal32Type::format_decimal(*value, *precision, *scale), + false, + ))) + } + ScalarValue::Decimal32(None, ..) => Ok(ast::Expr::value(ast::Value::Null)), + ScalarValue::Decimal64(Some(value), precision, scale) => { + Ok(ast::Expr::value(ast::Value::Number( + Decimal64Type::format_decimal(*value, *precision, *scale), + false, + ))) + } + ScalarValue::Decimal64(None, ..) => Ok(ast::Expr::value(ast::Value::Null)), ScalarValue::Decimal128(Some(value), precision, scale) => { Ok(ast::Expr::value(ast::Value::Number( Decimal128Type::format_decimal(*value, *precision, *scale), @@ -1726,13 +1742,9 @@ impl Unparser<'_> { not_impl_err!("Unsupported DataType: conversion: {data_type:?}") } DataType::Dictionary(_, val) => self.arrow_dtype_to_ast_dtype(val), - DataType::Decimal32(_precision, _scale) => { - not_impl_err!("Unsupported DataType: conversion: {data_type:?}") - } - DataType::Decimal64(_precision, _scale) => { - not_impl_err!("Unsupported DataType: conversion: {data_type:?}") - } - DataType::Decimal128(precision, scale) + DataType::Decimal32(precision, scale) + | DataType::Decimal64(precision, scale) + | DataType::Decimal128(precision, scale) | DataType::Decimal256(precision, scale) => { let mut new_precision = *precision as u64; let mut new_scale = *scale as u64; @@ -2179,6 +2191,20 @@ mod tests { (col("need-quoted").eq(lit(1)), r#"("need-quoted" = 1)"#), (col("need quoted").eq(lit(1)), r#"("need quoted" = 1)"#), // See test_interval_scalar_to_expr for interval literals + ( + (col("a") + col("b")).gt(Expr::Literal( + ScalarValue::Decimal32(Some(1123), 4, 3), + None, + )), + r#"((a + b) > 1.123)"#, + ), + ( + (col("a") + col("b")).gt(Expr::Literal( + ScalarValue::Decimal64(Some(1123), 4, 3), + None, + )), + r#"((a + b) > 1.123)"#, + ), ( (col("a") + col("b")).gt(Expr::Literal( ScalarValue::Decimal128(Some(100123), 28, 3), diff --git a/datafusion/sqllogictest/src/lib.rs b/datafusion/sqllogictest/src/lib.rs index 3c786d6bdaac..f3a78607242c 100644 --- a/datafusion/sqllogictest/src/lib.rs +++ b/datafusion/sqllogictest/src/lib.rs @@ -19,7 +19,7 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![cfg_attr(not(test), deny(clippy::clone_on_ref_ptr))] diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index caf8d637ec45..eed3721078c7 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -7322,6 +7322,38 @@ SELECT a, median(b), arrow_typeof(median(b)) FROM group_median_all_nulls GROUP B group0 NULL Int32 group1 NULL Int32 +statement ok +create table t_decimal (c decimal(10, 4)) as values (100.00), (125.00), (175.00), (200.00), (200.00), (300.00), (null), (null); + +# Test avg_distinct for Decimal128 +query RT +select avg(distinct c), arrow_typeof(avg(distinct c)) from t_decimal; +---- +180 Decimal128(14, 8) + +statement ok +drop table t_decimal; + +# Test avg_distinct for Decimal256 +statement ok +create table t_decimal256 (c decimal(50, 2)) as values + (100.00), + (125.00), + (175.00), + (200.00), + (200.00), + (300.00), + (null), + (null); + +query RT +select avg(distinct c), arrow_typeof(avg(distinct c)) from t_decimal256; +---- +180 Decimal256(54, 6) + +statement ok +drop table t_decimal256; + query I with test AS (SELECT i as c1, i + 1 as c2 FROM generate_series(1, 10) t(i)) select count(*) from test WHERE 1 = 1; @@ -7444,55 +7476,65 @@ FROM (VALUES ('a'), ('d'), ('c'), ('a')) t(a_varchar); # distinct average statement ok -create table distinct_avg (a int, b double) as values - (3, null), - (2, null), - (5, 100.5), - (5, 1.0), - (5, 44.112), - (null, 1.0), - (5, 100.5), - (1, 4.09), - (5, 100.5), - (5, 100.5), - (4, null), - (null, null) +create table distinct_avg (a int, b double, c decimal(10, 4), d decimal(50, 2)) as values + (3, null, 100.2562, 90251.21), + (2, null, 100.2562, null), + (5, 100.5, null, 10000000.11), + (5, 1.0, 100.2563, -1.0), + (5, 44.112, -132.12, null), + (null, 1.0, 100.2562, 90251.21), + (5, 100.5, -100.2562, -10000000.11), + (1, 4.09, 4222.124, 0.0), + (5, 100.5, null, 10000000.11), + (5, 100.5, 1.1, 1.0), + (4, null, 4222.124, null), + (null, null, null, null) ; # Need two columns to ensure single_distinct_to_group_by rule doesn't kick in, so we know our actual avg(distinct) code is being tested -query RTRTRR +query RTRTRTRTRRRR select avg(distinct a), arrow_typeof(avg(distinct a)), avg(distinct b), arrow_typeof(avg(distinct b)), + avg(distinct c), + arrow_typeof(avg(distinct c)), + avg(distinct d), + arrow_typeof(avg(distinct d)), avg(a), - avg(b) + avg(b), + avg(c), + avg(d) from distinct_avg; ---- -3 Float64 37.4255 Float64 4 56.52525 +3 Float64 37.4255 Float64 698.56005 Decimal128(14, 8) 15041.868333 Decimal256(54, 6) 4 56.52525 957.11074444 1272562.81625 -query RR rowsort +query RRRR rowsort select avg(distinct a), - avg(distinct b) + avg(distinct b), + avg(distinct c), + avg(distinct d) from distinct_avg group by b; ---- -1 4.09 -3 NULL -5 1 -5 100.5 -5 44.112 +1 4.09 4222.124 0 +3 NULL 2161.1901 90251.21 +5 1 100.25625 45125.105 +5 100.5 -49.5781 0.333333 +5 44.112 -132.12 NULL -query RR +query RRRR select avg(distinct a), - avg(distinct b) + avg(distinct b), + avg(distinct c), + avg(distinct d) from distinct_avg -where a is null and b is null; +where a is null and b is null and c is null and d is null; ---- -NULL NULL +NULL NULL NULL NULL statement ok drop table distinct_avg; diff --git a/datafusion/substrait/src/lib.rs b/datafusion/substrait/src/lib.rs index 0f2fbf199be3..9a4f44e81df2 100644 --- a/datafusion/substrait/src/lib.rs +++ b/datafusion/substrait/src/lib.rs @@ -19,7 +19,7 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![cfg_attr(not(test), deny(clippy::clone_on_ref_ptr))] diff --git a/datafusion/wasmtest/src/lib.rs b/datafusion/wasmtest/src/lib.rs index e30a1046ab27..d2efe995f100 100644 --- a/datafusion/wasmtest/src/lib.rs +++ b/datafusion/wasmtest/src/lib.rs @@ -19,7 +19,7 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] extern crate wasm_bindgen; diff --git a/test-utils/src/array_gen/random_data.rs b/test-utils/src/array_gen/random_data.rs index 78518b7bf9dc..ea2b872f7d86 100644 --- a/test-utils/src/array_gen/random_data.rs +++ b/test-utils/src/array_gen/random_data.rs @@ -17,12 +17,12 @@ use arrow::array::ArrowPrimitiveType; use arrow::datatypes::{ - i256, Date32Type, Date64Type, Decimal128Type, Decimal256Type, - DurationMicrosecondType, DurationMillisecondType, DurationNanosecondType, - DurationSecondType, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, - Int8Type, IntervalDayTime, IntervalDayTimeType, IntervalMonthDayNano, - IntervalMonthDayNanoType, IntervalYearMonthType, Time32MillisecondType, - Time32SecondType, Time64MicrosecondType, Time64NanosecondType, + i256, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Decimal32Type, + Decimal64Type, DurationMicrosecondType, DurationMillisecondType, + DurationNanosecondType, DurationSecondType, Float32Type, Float64Type, Int16Type, + Int32Type, Int64Type, Int8Type, IntervalDayTime, IntervalDayTimeType, + IntervalMonthDayNano, IntervalMonthDayNanoType, IntervalYearMonthType, + Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType, TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, }; @@ -67,6 +67,8 @@ basic_random_data!(Time32MillisecondType); basic_random_data!(Time64MicrosecondType); basic_random_data!(Time64NanosecondType); basic_random_data!(IntervalYearMonthType); +basic_random_data!(Decimal32Type); +basic_random_data!(Decimal64Type); basic_random_data!(Decimal128Type); basic_random_data!(TimestampSecondType); basic_random_data!(TimestampMillisecondType);