Skip to content

Perform type coercion for corr aggregate function #15776

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 10 additions & 14 deletions datafusion/functions-aggregate/src/correlation.rs
Original file line number Diff line number Diff line change
@@ -26,7 +26,7 @@ use arrow::array::{
downcast_array, Array, AsArray, BooleanArray, Float64Array, NullBufferBuilder,
UInt64Array,
};
use arrow::compute::{and, filter, is_not_null, kernels::cast};
use arrow::compute::{and, filter, is_not_null};
use arrow::datatypes::{FieldRef, Float64Type, UInt64Type};
use arrow::{
array::ArrayRef,
@@ -38,10 +38,9 @@ use log::debug;

use crate::covariance::CovarianceAccumulator;
use crate::stddev::StddevAccumulator;
use datafusion_common::{plan_err, Result, ScalarValue};
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::{
function::{AccumulatorArgs, StateFieldsArgs},
type_coercion::aggregates::NUMERICS,
utils::format_state_name,
Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility,
};
@@ -83,10 +82,13 @@ impl Default for Correlation {
}

impl Correlation {
/// Create a new COVAR_POP aggregate function
/// Create a new CORR aggregate function
pub fn new() -> Self {
Self {
signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable),
signature: Signature::exact(
vec![DataType::Float64, DataType::Float64],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

Volatility::Immutable,
),
}
}
}
@@ -105,11 +107,7 @@ impl AggregateUDFImpl for Correlation {
&self.signature
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
if !arg_types[0].is_numeric() {
return plan_err!("Correlation requires numeric input types");
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(DataType::Float64)
}

@@ -375,10 +373,8 @@ impl GroupsAccumulator for CorrelationGroupsAccumulator {
self.sum_xx.resize(total_num_groups, 0.0);
self.sum_yy.resize(total_num_groups, 0.0);

let array_x = &cast(&values[0], &DataType::Float64)?;
let array_x = downcast_array::<Float64Array>(array_x);
let array_y = &cast(&values[1], &DataType::Float64)?;
let array_y = downcast_array::<Float64Array>(array_y);
let array_x = downcast_array::<Float64Array>(&values[0]);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks to me like this code only handles Float64 but the signature of the function reports that it accepts any numeric type:

impl Correlation {
    /// Create a new COVAR_POP aggregate function
    pub fn new() -> Self {
        Self {
            signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable),
        }
    }
}

I wonder if you just changed the signature to say the function needs Float64 argument types, woudl that be enough?

DataFusion already has a bunch of coercion rules, see https://docs.rs/datafusion/latest/datafusion/logical_expr/type_coercion/index.html for example

let array_y = downcast_array::<Float64Array>(&values[1]);

accumulate_multiple(
group_indices,
110 changes: 110 additions & 0 deletions datafusion/sqllogictest/test_files/aggregate.slt
Original file line number Diff line number Diff line change
@@ -2548,7 +2548,117 @@ select covar_samp(c1, c2), arrow_typeof(covar_samp(c1, c2)) from t;
statement ok
drop table t;

# correlation_f64_1
statement ok
create table t (c1 double, c2 double) as values (1, 4), (2, 5), (3, 6);

query RT rowsort
select corr(c1, c2), arrow_typeof(corr(c1, c2)) from t;
----
1 Float64

# correlation with different numeric types (create test data)
statement ok
CREATE OR REPLACE TABLE corr_test(
int8_col TINYINT,
int16_col SMALLINT,
int32_col INT,
int64_col BIGINT,
uint32_col INT UNSIGNED,
float32_col FLOAT,
float64_col DOUBLE
) as VALUES
(1, 10, 100, 1000, 10000, 1.1, 10.1),
(2, 20, 200, 2000, 20000, 2.2, 20.2),
(3, 30, 300, 3000, 30000, 3.3, 30.3),
(4, 40, 400, 4000, 40000, 4.4, 40.4),
(5, 50, 500, 5000, 50000, 5.5, 50.5);

# correlation using int32 and float64
query R
SELECT corr(int32_col, float64_col) FROM corr_test;
----
1

# correlation using int64 and int32
query R
SELECT corr(int64_col, int32_col) FROM corr_test;
----
1

# correlation using float32 and int8
query R
SELECT corr(float32_col, int8_col) FROM corr_test;
----
1

# correlation using uint32 and int16
query R
SELECT corr(uint32_col, int16_col) FROM corr_test;
----
1

# correlation with nulls
statement ok
CREATE OR REPLACE TABLE corr_nulls(
x INT,
y DOUBLE
) as VALUES
(1, 10.0),
(2, 20.0),
(NULL, 30.0),
(4, NULL),
(5, 50.0);

# correlation with some nulls (should skip null pairs)
query R
SELECT corr(x, y) FROM corr_nulls;
----
1

# correlation with single row (should return NULL)
statement ok
CREATE OR REPLACE TABLE corr_single_row(
x INT,
y DOUBLE
) as VALUES
(1, 10.0);

query R
SELECT corr(x, y) FROM corr_single_row;
----
0

# correlation with all nulls
statement ok
CREATE OR REPLACE TABLE corr_all_nulls(
x INT,
y DOUBLE
) as VALUES
(NULL, NULL),
(NULL, NULL);

query R
SELECT corr(x, y) FROM corr_all_nulls;
----
NULL

statement ok
drop table corr_test;

statement ok
drop table corr_nulls;

statement ok
drop table corr_single_row;

statement ok
drop table corr_all_nulls;

# covariance_f64_4
statement ok
drop table if exists t;

statement ok
create table t (c1 double, c2 double) as values (1.1, 4.1), (2.0, 5.0), (3.0, 6.0);