-
Notifications
You must be signed in to change notification settings - Fork 1.6k
Perform type coercion for corr aggregate function #15776
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,7 +26,7 @@ use arrow::array::{ | |
downcast_array, Array, AsArray, BooleanArray, Float64Array, NullBufferBuilder, | ||
UInt64Array, | ||
}; | ||
use arrow::compute::{and, filter, is_not_null, kernels::cast}; | ||
use arrow::compute::{and, filter, is_not_null}; | ||
use arrow::datatypes::{FieldRef, Float64Type, UInt64Type}; | ||
use arrow::{ | ||
array::ArrayRef, | ||
|
@@ -38,10 +38,9 @@ use log::debug; | |
|
||
use crate::covariance::CovarianceAccumulator; | ||
use crate::stddev::StddevAccumulator; | ||
use datafusion_common::{plan_err, Result, ScalarValue}; | ||
use datafusion_common::{Result, ScalarValue}; | ||
use datafusion_expr::{ | ||
function::{AccumulatorArgs, StateFieldsArgs}, | ||
type_coercion::aggregates::NUMERICS, | ||
utils::format_state_name, | ||
Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility, | ||
}; | ||
|
@@ -83,10 +82,13 @@ impl Default for Correlation { | |
} | ||
|
||
impl Correlation { | ||
/// Create a new COVAR_POP aggregate function | ||
/// Create a new CORR aggregate function | ||
pub fn new() -> Self { | ||
Self { | ||
signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable), | ||
signature: Signature::exact( | ||
vec![DataType::Float64, DataType::Float64], | ||
Volatility::Immutable, | ||
), | ||
} | ||
} | ||
} | ||
|
@@ -105,11 +107,7 @@ impl AggregateUDFImpl for Correlation { | |
&self.signature | ||
} | ||
|
||
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> { | ||
if !arg_types[0].is_numeric() { | ||
return plan_err!("Correlation requires numeric input types"); | ||
} | ||
|
||
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> { | ||
Ok(DataType::Float64) | ||
} | ||
|
||
|
@@ -375,10 +373,8 @@ impl GroupsAccumulator for CorrelationGroupsAccumulator { | |
self.sum_xx.resize(total_num_groups, 0.0); | ||
self.sum_yy.resize(total_num_groups, 0.0); | ||
|
||
let array_x = &cast(&values[0], &DataType::Float64)?; | ||
let array_x = downcast_array::<Float64Array>(array_x); | ||
let array_y = &cast(&values[1], &DataType::Float64)?; | ||
let array_y = downcast_array::<Float64Array>(array_y); | ||
let array_x = downcast_array::<Float64Array>(&values[0]); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It looks to me like this code only handles impl Correlation {
/// Create a new COVAR_POP aggregate function
pub fn new() -> Self {
Self {
signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable),
}
}
} I wonder if you just changed the signature to say the function needs Float64 argument types, woudl that be enough? DataFusion already has a bunch of coercion rules, see https://docs.rs/datafusion/latest/datafusion/logical_expr/type_coercion/index.html for example |
||
let array_y = downcast_array::<Float64Array>(&values[1]); | ||
|
||
accumulate_multiple( | ||
group_indices, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍