-
Notifications
You must be signed in to change notification settings - Fork 1.8k
fix: support float16 for abs()
#18304
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,7 +22,8 @@ use std::sync::Arc; | |
|
|
||
| use arrow::array::{ | ||
| ArrayRef, Decimal128Array, Decimal256Array, Decimal32Array, Decimal64Array, | ||
| Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, | ||
| Float16Array, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, | ||
| Int8Array, | ||
| }; | ||
| use arrow::datatypes::DataType; | ||
| use arrow::error::ArrowError; | ||
|
|
@@ -34,6 +35,7 @@ use datafusion_expr::{ | |
| Volatility, | ||
| }; | ||
| use datafusion_macros::user_doc; | ||
| use num_traits::sign::Signed; | ||
|
|
||
| type MathArrayFunction = fn(&ArrayRef) -> Result<ArrayRef>; | ||
|
|
||
|
|
@@ -81,6 +83,7 @@ macro_rules! make_decimal_abs_function { | |
| /// Return different implementations based on input datatype to reduce branches during execution | ||
| fn create_abs_function(input_data_type: &DataType) -> Result<MathArrayFunction> { | ||
| match input_data_type { | ||
| DataType::Float16 => Ok(make_abs_function!(Float16Array)), | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The fix is here: this adds the `DataType::Float16` arm to the match in `create_abs_function`. |
||
| DataType::Float32 => Ok(make_abs_function!(Float32Array)), | ||
| DataType::Float64 => Ok(make_abs_function!(Float64Array)), | ||
|
|
||
|
|
@@ -143,6 +146,7 @@ impl ScalarUDFImpl for AbsFunc { | |
| fn as_any(&self) -> &dyn Any { | ||
| self | ||
| } | ||
|
|
||
| fn name(&self) -> &str { | ||
| "abs" | ||
| } | ||
|
|
@@ -152,35 +156,7 @@ impl ScalarUDFImpl for AbsFunc { | |
| } | ||
|
|
||
| fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> { | ||
| match arg_types[0] { | ||
|
Contributor
Author
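There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Just a little cleanup: the per-type match in `return_type` is replaced by returning a clone of the input type, `Ok(arg_types[0].clone())`. |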
||
| DataType::Float32 => Ok(DataType::Float32), | ||
| DataType::Float64 => Ok(DataType::Float64), | ||
| DataType::Int8 => Ok(DataType::Int8), | ||
| DataType::Int16 => Ok(DataType::Int16), | ||
| DataType::Int32 => Ok(DataType::Int32), | ||
| DataType::Int64 => Ok(DataType::Int64), | ||
| DataType::Null => Ok(DataType::Null), | ||
| DataType::UInt8 => Ok(DataType::UInt8), | ||
| DataType::UInt16 => Ok(DataType::UInt16), | ||
| DataType::UInt32 => Ok(DataType::UInt32), | ||
| DataType::UInt64 => Ok(DataType::UInt64), | ||
| DataType::Decimal32(precision, scale) => { | ||
| Ok(DataType::Decimal32(precision, scale)) | ||
| } | ||
| DataType::Decimal64(precision, scale) => { | ||
| Ok(DataType::Decimal64(precision, scale)) | ||
| } | ||
| DataType::Decimal128(precision, scale) => { | ||
| Ok(DataType::Decimal128(precision, scale)) | ||
| } | ||
| DataType::Decimal256(precision, scale) => { | ||
| Ok(DataType::Decimal256(precision, scale)) | ||
| } | ||
| _ => not_impl_err!( | ||
| "Unsupported data type {} for function abs", | ||
| arg_types[0].to_string() | ||
| ), | ||
| } | ||
| Ok(arg_types[0].clone()) | ||
| } | ||
|
|
||
| fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> { | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -139,16 +139,16 @@ select abs(arrow_cast('-1.2', 'Utf8')); | |
|
|
||
| statement ok | ||
| CREATE TABLE test_nullable_integer( | ||
| c1 TINYINT, | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I changed my editor to remove trailing whitespace on write, hence these changes; I wonder if we should have a lint for SLT files? 🤔 |
||
| c2 SMALLINT, | ||
| c3 INT, | ||
| c4 BIGINT, | ||
| c5 TINYINT UNSIGNED, | ||
| c6 SMALLINT UNSIGNED, | ||
| c7 INT UNSIGNED, | ||
| c8 BIGINT UNSIGNED, | ||
| c1 TINYINT, | ||
| c2 SMALLINT, | ||
| c3 INT, | ||
| c4 BIGINT, | ||
| c5 TINYINT UNSIGNED, | ||
| c6 SMALLINT UNSIGNED, | ||
| c7 INT UNSIGNED, | ||
| c8 BIGINT UNSIGNED, | ||
| dataset TEXT | ||
| ) | ||
| ) | ||
| AS VALUES | ||
| (NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, 'nulls'), | ||
| (0, 0, 0, 0, 0, 0, 0, 0, 'zeros'), | ||
|
|
@@ -237,7 +237,7 @@ SELECT c8%0 FROM test_nullable_integer | |
|
|
||
| # abs: return type | ||
| query TTTTTTTT rowsort | ||
| select | ||
| select | ||
| arrow_typeof(abs(c1)), arrow_typeof(abs(c2)), arrow_typeof(abs(c3)), arrow_typeof(abs(c4)), | ||
| arrow_typeof(abs(c5)), arrow_typeof(abs(c6)), arrow_typeof(abs(c7)), arrow_typeof(abs(c8)) | ||
| from test_nullable_integer limit 1 | ||
|
|
@@ -285,13 +285,13 @@ drop table test_nullable_integer | |
|
|
||
| statement ok | ||
| CREATE TABLE test_non_nullable_integer( | ||
| c1 TINYINT NOT NULL, | ||
| c2 SMALLINT NOT NULL, | ||
| c3 INT NOT NULL, | ||
| c4 BIGINT NOT NULL, | ||
| c5 TINYINT UNSIGNED NOT NULL, | ||
| c6 SMALLINT UNSIGNED NOT NULL, | ||
| c7 INT UNSIGNED NOT NULL, | ||
| c1 TINYINT NOT NULL, | ||
| c2 SMALLINT NOT NULL, | ||
| c3 INT NOT NULL, | ||
| c4 BIGINT NOT NULL, | ||
| c5 TINYINT UNSIGNED NOT NULL, | ||
| c6 SMALLINT UNSIGNED NOT NULL, | ||
| c7 INT UNSIGNED NOT NULL, | ||
| c8 BIGINT UNSIGNED NOT NULL | ||
| ); | ||
|
|
||
|
|
@@ -363,7 +363,7 @@ CREATE TABLE test_nullable_float( | |
| c2 double | ||
| ) AS VALUES | ||
| (-1.0, -1.0), | ||
| (1.0, 1.0), | ||
| (1.0, 1.0), | ||
| (NULL, NULL), | ||
| (0., 0.), | ||
| ('NaN'::double, 'NaN'::double); | ||
|
|
@@ -412,14 +412,25 @@ Float32 Float64 | |
|
|
||
| # abs: floats | ||
| query RR rowsort | ||
| SELECT abs(c1), abs(c2) from test_nullable_float | ||
| SELECT abs(c1), abs(c2) from test_nullable_float | ||
| ---- | ||
| 0 0 | ||
| 1 1 | ||
| 1 1 | ||
| NULL NULL | ||
| NaN NaN | ||
|
|
||
| # f16 | ||
| query TR rowsort | ||
| SELECT arrow_typeof(abs(arrow_cast(c1, 'Float16'))), abs(arrow_cast(c1, 'Float16')) | ||
| FROM test_nullable_float | ||
| ---- | ||
| Float16 0 | ||
| Float16 1 | ||
| Float16 1 | ||
| Float16 NULL | ||
| Float16 NaN | ||
|
|
||
| statement ok | ||
| drop table test_nullable_float | ||
|
|
||
|
|
@@ -428,7 +439,7 @@ statement ok | |
| CREATE TABLE test_non_nullable_float( | ||
| c1 float NOT NULL, | ||
| c2 double NOT NULL | ||
| ); | ||
| ); | ||
|
|
||
| query I | ||
| INSERT INTO test_non_nullable_float VALUES | ||
|
|
@@ -478,27 +489,27 @@ drop table test_non_nullable_float | |
| statement ok | ||
| CREATE TABLE test_nullable_decimal( | ||
| c1 DECIMAL(10, 2), /* Decimal128 */ | ||
| c2 DECIMAL(38, 10), /* Decimal128 with max precision */ | ||
| c2 DECIMAL(38, 10), /* Decimal128 with max precision */ | ||
| c3 DECIMAL(40, 2), /* Decimal256 */ | ||
| c4 DECIMAL(76, 10) /* Decimal256 with max precision */ | ||
| ) AS VALUES | ||
| (0, 0, 0, 0), | ||
| c4 DECIMAL(76, 10) /* Decimal256 with max precision */ | ||
| ) AS VALUES | ||
| (0, 0, 0, 0), | ||
| (NULL, NULL, NULL, NULL); | ||
|
|
||
| query I | ||
| INSERT into test_nullable_decimal values | ||
| ( | ||
| -99999999.99, | ||
| '-9999999999999999999999999999.9999999999', | ||
| '-99999999999999999999999999999999999999.99', | ||
| -99999999.99, | ||
| '-9999999999999999999999999999.9999999999', | ||
| '-99999999999999999999999999999999999999.99', | ||
| '-999999999999999999999999999999999999999999999999999999999999999999.9999999999' | ||
| ), | ||
| ), | ||
| ( | ||
| 99999999.99, | ||
| '9999999999999999999999999999.9999999999', | ||
| '99999999999999999999999999999999999999.99', | ||
| 99999999.99, | ||
| '9999999999999999999999999999.9999999999', | ||
| '99999999999999999999999999999999999999.99', | ||
| '999999999999999999999999999999999999999999999999999999999999999999.9999999999' | ||
| ) | ||
| ) | ||
| ---- | ||
| 2 | ||
|
|
||
|
|
@@ -533,9 +544,9 @@ SELECT c1%0 FROM test_nullable_decimal WHERE c1 IS NOT NULL; | |
|
|
||
| # abs: return type | ||
| query TTTT | ||
| SELECT | ||
| arrow_typeof(abs(c1)), | ||
| arrow_typeof(abs(c2)), | ||
| SELECT | ||
| arrow_typeof(abs(c1)), | ||
| arrow_typeof(abs(c2)), | ||
| arrow_typeof(abs(c3)), | ||
| arrow_typeof(abs(c4)) | ||
| FROM test_nullable_decimal limit 1 | ||
|
|
@@ -552,11 +563,11 @@ SELECT abs(c1), abs(c2), abs(c3), abs(c4) FROM test_nullable_decimal | |
| NULL NULL NULL NULL | ||
|
|
||
| statement ok | ||
| drop table test_nullable_decimal | ||
| drop table test_nullable_decimal | ||
|
|
||
|
|
||
| statement ok | ||
| CREATE TABLE test_non_nullable_decimal(c1 DECIMAL(9,2) NOT NULL); | ||
| CREATE TABLE test_non_nullable_decimal(c1 DECIMAL(9,2) NOT NULL); | ||
|
|
||
| query I | ||
| INSERT INTO test_non_nullable_decimal VALUES(1) | ||
|
|
@@ -569,13 +580,13 @@ SELECT c1*0 FROM test_non_nullable_decimal | |
| 0 | ||
|
|
||
| query error DataFusion error: Arrow error: Divide by zero error | ||
| SELECT c1/0 FROM test_non_nullable_decimal | ||
| SELECT c1/0 FROM test_non_nullable_decimal | ||
|
|
||
| query error DataFusion error: Arrow error: Divide by zero error | ||
| SELECT c1%0 FROM test_non_nullable_decimal | ||
| SELECT c1%0 FROM test_non_nullable_decimal | ||
|
|
||
| statement ok | ||
| drop table test_non_nullable_decimal | ||
| drop table test_non_nullable_decimal | ||
|
|
||
| statement ok | ||
| CREATE TABLE signed_integers( | ||
|
|
@@ -615,7 +626,7 @@ NULL NULL NULL | |
|
|
||
| # scalar maxes and/or negative 1 | ||
| query III | ||
| select | ||
| select | ||
| gcd(9223372036854775807, -9223372036854775808), -- i64::MAX, i64::MIN | ||
| gcd(9223372036854775807, -1), -- i64::MAX, -1 | ||
| gcd(-9223372036854775808, -1); -- i64::MIN, -1 | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We need `num_traits` (hence the new `use num_traits::sign::Signed;` import) to get an `abs` implementation for `f16`, because `f16` only provides `abs` through that trait implementation: https://docs.rs/half/latest/half/struct.f16.html#method.abs-2