1616// under the License.
1717
1818use crate :: signature:: TypeSignature ;
19- use arrow:: datatypes:: {
20- DataType , FieldRef , TimeUnit , DECIMAL128_MAX_PRECISION , DECIMAL128_MAX_SCALE ,
21- DECIMAL256_MAX_PRECISION , DECIMAL256_MAX_SCALE , DECIMAL32_MAX_PRECISION ,
22- DECIMAL32_MAX_SCALE , DECIMAL64_MAX_PRECISION , DECIMAL64_MAX_SCALE ,
23- } ;
19+ use arrow:: datatypes:: { DataType , FieldRef } ;
2420
2521use datafusion_common:: { internal_err, plan_err, Result } ;
2622
27- pub static STRINGS : & [ DataType ] =
28- & [ DataType :: Utf8 , DataType :: LargeUtf8 , DataType :: Utf8View ] ;
29-
30- pub static SIGNED_INTEGERS : & [ DataType ] = & [
31- DataType :: Int8 ,
32- DataType :: Int16 ,
33- DataType :: Int32 ,
34- DataType :: Int64 ,
35- ] ;
36-
37- pub static UNSIGNED_INTEGERS : & [ DataType ] = & [
38- DataType :: UInt8 ,
39- DataType :: UInt16 ,
40- DataType :: UInt32 ,
41- DataType :: UInt64 ,
42- ] ;
43-
23+ // TODO: remove usage of these (INTEGERS and NUMERICS) in favour of signatures
24+ // see https://github.com/apache/datafusion/issues/18092
4425pub static INTEGERS : & [ DataType ] = & [
4526 DataType :: Int8 ,
4627 DataType :: Int16 ,
@@ -65,24 +46,6 @@ pub static NUMERICS: &[DataType] = &[
6546 DataType :: Float64 ,
6647] ;
6748
68- pub static TIMESTAMPS : & [ DataType ] = & [
69- DataType :: Timestamp ( TimeUnit :: Second , None ) ,
70- DataType :: Timestamp ( TimeUnit :: Millisecond , None ) ,
71- DataType :: Timestamp ( TimeUnit :: Microsecond , None ) ,
72- DataType :: Timestamp ( TimeUnit :: Nanosecond , None ) ,
73- ] ;
74-
75- pub static DATES : & [ DataType ] = & [ DataType :: Date32 , DataType :: Date64 ] ;
76-
77- pub static BINARYS : & [ DataType ] = & [ DataType :: Binary , DataType :: LargeBinary ] ;
78-
79- pub static TIMES : & [ DataType ] = & [
80- DataType :: Time32 ( TimeUnit :: Second ) ,
81- DataType :: Time32 ( TimeUnit :: Millisecond ) ,
82- DataType :: Time64 ( TimeUnit :: Microsecond ) ,
83- DataType :: Time64 ( TimeUnit :: Nanosecond ) ,
84- ] ;
85-
8649/// Validate the length of `input_fields` matches the `signature` for `agg_fun`.
8750///
8851/// This method DOES NOT validate the argument fields - only that (at least one,
@@ -144,260 +107,3 @@ pub fn check_arg_count(
144107 }
145108 Ok ( ( ) )
146109}
147-
148- /// Function return type of a sum
149- pub fn sum_return_type ( arg_type : & DataType ) -> Result < DataType > {
150- match arg_type {
151- DataType :: Int64 => Ok ( DataType :: Int64 ) ,
152- DataType :: UInt64 => Ok ( DataType :: UInt64 ) ,
153- DataType :: Float64 => Ok ( DataType :: Float64 ) ,
154- DataType :: Decimal32 ( precision, scale) => {
155- // in the spark, the result type is DECIMAL(min(38,precision+10), s)
156- // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
157- let new_precision = DECIMAL32_MAX_PRECISION . min ( * precision + 10 ) ;
158- Ok ( DataType :: Decimal32 ( new_precision, * scale) )
159- }
160- DataType :: Decimal64 ( precision, scale) => {
161- // in the spark, the result type is DECIMAL(min(38,precision+10), s)
162- // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
163- let new_precision = DECIMAL64_MAX_PRECISION . min ( * precision + 10 ) ;
164- Ok ( DataType :: Decimal64 ( new_precision, * scale) )
165- }
166- DataType :: Decimal128 ( precision, scale) => {
167- // In the spark, the result type is DECIMAL(min(38,precision+10), s)
168- // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
169- let new_precision = DECIMAL128_MAX_PRECISION . min ( * precision + 10 ) ;
170- Ok ( DataType :: Decimal128 ( new_precision, * scale) )
171- }
172- DataType :: Decimal256 ( precision, scale) => {
173- // In the spark, the result type is DECIMAL(min(38,precision+10), s)
174- // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
175- let new_precision = DECIMAL256_MAX_PRECISION . min ( * precision + 10 ) ;
176- Ok ( DataType :: Decimal256 ( new_precision, * scale) )
177- }
178- other => plan_err ! ( "SUM does not support type \" {other:?}\" " ) ,
179- }
180- }
181-
182- /// Function return type of variance
183- pub fn variance_return_type ( arg_type : & DataType ) -> Result < DataType > {
184- if NUMERICS . contains ( arg_type) {
185- Ok ( DataType :: Float64 )
186- } else {
187- plan_err ! ( "VAR does not support {arg_type}" )
188- }
189- }
190-
191- /// Function return type of covariance
192- pub fn covariance_return_type ( arg_type : & DataType ) -> Result < DataType > {
193- if NUMERICS . contains ( arg_type) {
194- Ok ( DataType :: Float64 )
195- } else {
196- plan_err ! ( "COVAR does not support {arg_type}" )
197- }
198- }
199-
200- /// Function return type of correlation
201- pub fn correlation_return_type ( arg_type : & DataType ) -> Result < DataType > {
202- if NUMERICS . contains ( arg_type) {
203- Ok ( DataType :: Float64 )
204- } else {
205- plan_err ! ( "CORR does not support {arg_type}" )
206- }
207- }
208-
209- /// Function return type of an average
210- pub fn avg_return_type ( func_name : & str , arg_type : & DataType ) -> Result < DataType > {
211- match arg_type {
212- DataType :: Decimal32 ( precision, scale) => {
213- // In the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)).
214- // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
215- let new_precision = DECIMAL32_MAX_PRECISION . min ( * precision + 4 ) ;
216- let new_scale = DECIMAL32_MAX_SCALE . min ( * scale + 4 ) ;
217- Ok ( DataType :: Decimal32 ( new_precision, new_scale) )
218- }
219- DataType :: Decimal64 ( precision, scale) => {
220- // In the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)).
221- // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
222- let new_precision = DECIMAL64_MAX_PRECISION . min ( * precision + 4 ) ;
223- let new_scale = DECIMAL64_MAX_SCALE . min ( * scale + 4 ) ;
224- Ok ( DataType :: Decimal64 ( new_precision, new_scale) )
225- }
226- DataType :: Decimal128 ( precision, scale) => {
227- // In the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)).
228- // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
229- let new_precision = DECIMAL128_MAX_PRECISION . min ( * precision + 4 ) ;
230- let new_scale = DECIMAL128_MAX_SCALE . min ( * scale + 4 ) ;
231- Ok ( DataType :: Decimal128 ( new_precision, new_scale) )
232- }
233- DataType :: Decimal256 ( precision, scale) => {
234- // In the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)).
235- // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
236- let new_precision = DECIMAL256_MAX_PRECISION . min ( * precision + 4 ) ;
237- let new_scale = DECIMAL256_MAX_SCALE . min ( * scale + 4 ) ;
238- Ok ( DataType :: Decimal256 ( new_precision, new_scale) )
239- }
240- DataType :: Duration ( time_unit) => Ok ( DataType :: Duration ( * time_unit) ) ,
241- arg_type if NUMERICS . contains ( arg_type) => Ok ( DataType :: Float64 ) ,
242- DataType :: Dictionary ( _, dict_value_type) => {
243- avg_return_type ( func_name, dict_value_type. as_ref ( ) )
244- }
245- other => plan_err ! ( "{func_name} does not support {other:?}" ) ,
246- }
247- }
248-
249- /// Internal sum type of an average
250- pub fn avg_sum_type ( arg_type : & DataType ) -> Result < DataType > {
251- match arg_type {
252- DataType :: Decimal32 ( precision, scale) => {
253- // In the spark, the sum type of avg is DECIMAL(min(38,precision+10), s)
254- let new_precision = DECIMAL32_MAX_PRECISION . min ( * precision + 10 ) ;
255- Ok ( DataType :: Decimal32 ( new_precision, * scale) )
256- }
257- DataType :: Decimal64 ( precision, scale) => {
258- // In the spark, the sum type of avg is DECIMAL(min(38,precision+10), s)
259- let new_precision = DECIMAL64_MAX_PRECISION . min ( * precision + 10 ) ;
260- Ok ( DataType :: Decimal64 ( new_precision, * scale) )
261- }
262- DataType :: Decimal128 ( precision, scale) => {
263- // In the spark, the sum type of avg is DECIMAL(min(38,precision+10), s)
264- let new_precision = DECIMAL128_MAX_PRECISION . min ( * precision + 10 ) ;
265- Ok ( DataType :: Decimal128 ( new_precision, * scale) )
266- }
267- DataType :: Decimal256 ( precision, scale) => {
268- // In Spark the sum type of avg is DECIMAL(min(38,precision+10), s)
269- let new_precision = DECIMAL256_MAX_PRECISION . min ( * precision + 10 ) ;
270- Ok ( DataType :: Decimal256 ( new_precision, * scale) )
271- }
272- DataType :: Duration ( time_unit) => Ok ( DataType :: Duration ( * time_unit) ) ,
273- arg_type if NUMERICS . contains ( arg_type) => Ok ( DataType :: Float64 ) ,
274- DataType :: Dictionary ( _, dict_value_type) => {
275- avg_sum_type ( dict_value_type. as_ref ( ) )
276- }
277- other => plan_err ! ( "AVG does not support {other:?}" ) ,
278- }
279- }
280-
281- pub fn is_sum_support_arg_type ( arg_type : & DataType ) -> bool {
282- match arg_type {
283- DataType :: Dictionary ( _, dict_value_type) => {
284- is_sum_support_arg_type ( dict_value_type. as_ref ( ) )
285- }
286- _ => matches ! (
287- arg_type,
288- arg_type if NUMERICS . contains( arg_type)
289- || matches!( arg_type, DataType :: Decimal32 ( _, _) | DataType :: Decimal64 ( _, _) |DataType :: Decimal128 ( _, _) | DataType :: Decimal256 ( _, _) )
290- ) ,
291- }
292- }
293-
294- pub fn is_avg_support_arg_type ( arg_type : & DataType ) -> bool {
295- match arg_type {
296- DataType :: Dictionary ( _, dict_value_type) => {
297- is_avg_support_arg_type ( dict_value_type. as_ref ( ) )
298- }
299- _ => matches ! (
300- arg_type,
301- arg_type if NUMERICS . contains( arg_type)
302- || matches!( arg_type, DataType :: Decimal32 ( _, _) | DataType :: Decimal64 ( _, _) |DataType :: Decimal128 ( _, _) | DataType :: Decimal256 ( _, _) )
303- ) ,
304- }
305- }
306-
307- pub fn is_variance_support_arg_type ( arg_type : & DataType ) -> bool {
308- matches ! (
309- arg_type,
310- arg_type if NUMERICS . contains( arg_type)
311- )
312- }
313-
314- pub fn is_covariance_support_arg_type ( arg_type : & DataType ) -> bool {
315- matches ! (
316- arg_type,
317- arg_type if NUMERICS . contains( arg_type)
318- )
319- }
320-
321- pub fn is_correlation_support_arg_type ( arg_type : & DataType ) -> bool {
322- matches ! (
323- arg_type,
324- arg_type if NUMERICS . contains( arg_type)
325- )
326- }
327-
328- pub fn is_integer_arg_type ( arg_type : & DataType ) -> bool {
329- arg_type. is_integer ( )
330- }
331-
332- pub fn coerce_avg_type ( func_name : & str , arg_types : & [ DataType ] ) -> Result < Vec < DataType > > {
333- // Supported types smallint, int, bigint, real, double precision, decimal, or interval
334- // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc
335- fn coerced_type ( func_name : & str , data_type : & DataType ) -> Result < DataType > {
336- match & data_type {
337- DataType :: Decimal32 ( p, s) => Ok ( DataType :: Decimal32 ( * p, * s) ) ,
338- DataType :: Decimal64 ( p, s) => Ok ( DataType :: Decimal64 ( * p, * s) ) ,
339- DataType :: Decimal128 ( p, s) => Ok ( DataType :: Decimal128 ( * p, * s) ) ,
340- DataType :: Decimal256 ( p, s) => Ok ( DataType :: Decimal256 ( * p, * s) ) ,
341- d if d. is_numeric ( ) => Ok ( DataType :: Float64 ) ,
342- DataType :: Duration ( time_unit) => Ok ( DataType :: Duration ( * time_unit) ) ,
343- DataType :: Dictionary ( _, v) => coerced_type ( func_name, v. as_ref ( ) ) ,
344- _ => {
345- plan_err ! (
346- "The function {:?} does not support inputs of type {}." ,
347- func_name,
348- data_type
349- )
350- }
351- }
352- }
353- Ok ( vec ! [ coerced_type( func_name, & arg_types[ 0 ] ) ?] )
354- }
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_variance_return_data_type() -> Result<()> {
        // Numerics produce Float64; decimals are rejected.
        assert_eq!(variance_return_type(&DataType::Float64)?, DataType::Float64);
        assert!(variance_return_type(&DataType::Decimal128(36, 10)).is_err());
        Ok(())
    }

    #[test]
    fn test_sum_return_data_type() -> Result<()> {
        // Precision grows by 10, scale is preserved.
        assert_eq!(
            sum_return_type(&DataType::Decimal128(10, 5))?,
            DataType::Decimal128(20, 5)
        );
        // Precision is capped at the Decimal128 maximum (38).
        assert_eq!(
            sum_return_type(&DataType::Decimal128(36, 10))?,
            DataType::Decimal128(38, 10)
        );
        Ok(())
    }

    #[test]
    fn test_covariance_return_data_type() -> Result<()> {
        // Numerics produce Float64; decimals are rejected.
        assert_eq!(
            covariance_return_type(&DataType::Float64)?,
            DataType::Float64
        );
        assert!(covariance_return_type(&DataType::Decimal128(36, 10)).is_err());
        Ok(())
    }

    #[test]
    fn test_correlation_return_data_type() -> Result<()> {
        // Numerics produce Float64; decimals are rejected.
        assert_eq!(
            correlation_return_type(&DataType::Float64)?,
            DataType::Float64
        );
        assert!(correlation_return_type(&DataType::Decimal128(36, 10)).is_err());
        Ok(())
    }
}
0 commit comments