Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ indexmap = "2.12.0"
insta = { version = "1.43.2", features = ["glob", "filters"] }
itertools = "0.14"
log = "^0.4"
num-traits = { version = "0.2" }
object_store = { version = "0.12.4", default-features = false }
parking_lot = "0.12"
parquet = { version = "57.0.0", default-features = false, features = [
Expand Down
2 changes: 1 addition & 1 deletion datafusion/datasource-avro/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ datafusion-physical-expr-common = { workspace = true }
datafusion-physical-plan = { workspace = true }
datafusion-session = { workspace = true }
futures = { workspace = true }
num-traits = { version = "0.2" }
num-traits = { workspace = true }
object_store = { workspace = true }

[dev-dependencies]
Expand Down
1 change: 1 addition & 0 deletions datafusion/functions/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ hex = { version = "0.4", optional = true }
itertools = { workspace = true }
log = { workspace = true }
md-5 = { version = "^0.10.0", optional = true }
num-traits = { workspace = true }
rand = { workspace = true }
regex = { workspace = true, optional = true }
sha2 = { version = "^0.10.9", optional = true }
Expand Down
36 changes: 6 additions & 30 deletions datafusion/functions/src/math/abs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ use std::sync::Arc;

use arrow::array::{
ArrayRef, Decimal128Array, Decimal256Array, Decimal32Array, Decimal64Array,
Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array,
Float16Array, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array,
Int8Array,
};
use arrow::datatypes::DataType;
use arrow::error::ArrowError;
Expand All @@ -34,6 +35,7 @@ use datafusion_expr::{
Volatility,
};
use datafusion_macros::user_doc;
use num_traits::sign::Signed;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need num_traits to get the abs implementation for f16 as it's only implemented for this trait: https://docs.rs/half/latest/half/struct.f16.html#method.abs-2


type MathArrayFunction = fn(&ArrayRef) -> Result<ArrayRef>;

Expand Down Expand Up @@ -81,6 +83,7 @@ macro_rules! make_decimal_abs_function {
/// Return different implementations based on input datatype to reduce branches during execution
fn create_abs_function(input_data_type: &DataType) -> Result<MathArrayFunction> {
match input_data_type {
DataType::Float16 => Ok(make_abs_function!(Float16Array)),
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fix here

DataType::Float32 => Ok(make_abs_function!(Float32Array)),
DataType::Float64 => Ok(make_abs_function!(Float64Array)),

Expand Down Expand Up @@ -143,6 +146,7 @@ impl ScalarUDFImpl for AbsFunc {
fn as_any(&self) -> &dyn Any {
self
}

fn name(&self) -> &str {
"abs"
}
Expand All @@ -152,35 +156,7 @@ impl ScalarUDFImpl for AbsFunc {
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
match arg_types[0] {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a little cleanup

DataType::Float32 => Ok(DataType::Float32),
DataType::Float64 => Ok(DataType::Float64),
DataType::Int8 => Ok(DataType::Int8),
DataType::Int16 => Ok(DataType::Int16),
DataType::Int32 => Ok(DataType::Int32),
DataType::Int64 => Ok(DataType::Int64),
DataType::Null => Ok(DataType::Null),
DataType::UInt8 => Ok(DataType::UInt8),
DataType::UInt16 => Ok(DataType::UInt16),
DataType::UInt32 => Ok(DataType::UInt32),
DataType::UInt64 => Ok(DataType::UInt64),
DataType::Decimal32(precision, scale) => {
Ok(DataType::Decimal32(precision, scale))
}
DataType::Decimal64(precision, scale) => {
Ok(DataType::Decimal64(precision, scale))
}
DataType::Decimal128(precision, scale) => {
Ok(DataType::Decimal128(precision, scale))
}
DataType::Decimal256(precision, scale) => {
Ok(DataType::Decimal256(precision, scale))
}
_ => not_impl_err!(
"Unsupported data type {} for function abs",
arg_types[0].to_string()
),
}
Ok(arg_types[0].clone())
}

fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
Expand Down
93 changes: 52 additions & 41 deletions datafusion/sqllogictest/test_files/math.slt
Original file line number Diff line number Diff line change
Expand Up @@ -139,16 +139,16 @@ select abs(arrow_cast('-1.2', 'Utf8'));

statement ok
CREATE TABLE test_nullable_integer(
c1 TINYINT,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed my editor to remove trailing whitespaces on write hence these changes; I wonder if we should have a lint for SLT files? 🤔

c2 SMALLINT,
c3 INT,
c4 BIGINT,
c5 TINYINT UNSIGNED,
c6 SMALLINT UNSIGNED,
c7 INT UNSIGNED,
c8 BIGINT UNSIGNED,
c1 TINYINT,
c2 SMALLINT,
c3 INT,
c4 BIGINT,
c5 TINYINT UNSIGNED,
c6 SMALLINT UNSIGNED,
c7 INT UNSIGNED,
c8 BIGINT UNSIGNED,
dataset TEXT
)
)
AS VALUES
(NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, 'nulls'),
(0, 0, 0, 0, 0, 0, 0, 0, 'zeros'),
Expand Down Expand Up @@ -237,7 +237,7 @@ SELECT c8%0 FROM test_nullable_integer

# abs: return type
query TTTTTTTT rowsort
select
select
arrow_typeof(abs(c1)), arrow_typeof(abs(c2)), arrow_typeof(abs(c3)), arrow_typeof(abs(c4)),
arrow_typeof(abs(c5)), arrow_typeof(abs(c6)), arrow_typeof(abs(c7)), arrow_typeof(abs(c8))
from test_nullable_integer limit 1
Expand Down Expand Up @@ -285,13 +285,13 @@ drop table test_nullable_integer

statement ok
CREATE TABLE test_non_nullable_integer(
c1 TINYINT NOT NULL,
c2 SMALLINT NOT NULL,
c3 INT NOT NULL,
c4 BIGINT NOT NULL,
c5 TINYINT UNSIGNED NOT NULL,
c6 SMALLINT UNSIGNED NOT NULL,
c7 INT UNSIGNED NOT NULL,
c1 TINYINT NOT NULL,
c2 SMALLINT NOT NULL,
c3 INT NOT NULL,
c4 BIGINT NOT NULL,
c5 TINYINT UNSIGNED NOT NULL,
c6 SMALLINT UNSIGNED NOT NULL,
c7 INT UNSIGNED NOT NULL,
c8 BIGINT UNSIGNED NOT NULL
);

Expand Down Expand Up @@ -363,7 +363,7 @@ CREATE TABLE test_nullable_float(
c2 double
) AS VALUES
(-1.0, -1.0),
(1.0, 1.0),
(1.0, 1.0),
(NULL, NULL),
(0., 0.),
('NaN'::double, 'NaN'::double);
Expand Down Expand Up @@ -412,14 +412,25 @@ Float32 Float64

# abs: floats
query RR rowsort
SELECT abs(c1), abs(c2) from test_nullable_float
SELECT abs(c1), abs(c2) from test_nullable_float
----
0 0
1 1
1 1
NULL NULL
NaN NaN

# f16
query TR rowsort
SELECT arrow_typeof(abs(arrow_cast(c1, 'Float16'))), abs(arrow_cast(c1, 'Float16'))
FROM test_nullable_float
----
Float16 0
Float16 1
Float16 1
Float16 NULL
Float16 NaN

statement ok
drop table test_nullable_float

Expand All @@ -428,7 +439,7 @@ statement ok
CREATE TABLE test_non_nullable_float(
c1 float NOT NULL,
c2 double NOT NULL
);
);

query I
INSERT INTO test_non_nullable_float VALUES
Expand Down Expand Up @@ -478,27 +489,27 @@ drop table test_non_nullable_float
statement ok
CREATE TABLE test_nullable_decimal(
c1 DECIMAL(10, 2), /* Decimal128 */
c2 DECIMAL(38, 10), /* Decimal128 with max precision */
c2 DECIMAL(38, 10), /* Decimal128 with max precision */
c3 DECIMAL(40, 2), /* Decimal256 */
c4 DECIMAL(76, 10) /* Decimal256 with max precision */
) AS VALUES
(0, 0, 0, 0),
c4 DECIMAL(76, 10) /* Decimal256 with max precision */
) AS VALUES
(0, 0, 0, 0),
(NULL, NULL, NULL, NULL);

query I
INSERT into test_nullable_decimal values
(
-99999999.99,
'-9999999999999999999999999999.9999999999',
'-99999999999999999999999999999999999999.99',
-99999999.99,
'-9999999999999999999999999999.9999999999',
'-99999999999999999999999999999999999999.99',
'-999999999999999999999999999999999999999999999999999999999999999999.9999999999'
),
),
(
99999999.99,
'9999999999999999999999999999.9999999999',
'99999999999999999999999999999999999999.99',
99999999.99,
'9999999999999999999999999999.9999999999',
'99999999999999999999999999999999999999.99',
'999999999999999999999999999999999999999999999999999999999999999999.9999999999'
)
)
----
2

Expand Down Expand Up @@ -533,9 +544,9 @@ SELECT c1%0 FROM test_nullable_decimal WHERE c1 IS NOT NULL;

# abs: return type
query TTTT
SELECT
arrow_typeof(abs(c1)),
arrow_typeof(abs(c2)),
SELECT
arrow_typeof(abs(c1)),
arrow_typeof(abs(c2)),
arrow_typeof(abs(c3)),
arrow_typeof(abs(c4))
FROM test_nullable_decimal limit 1
Expand All @@ -552,11 +563,11 @@ SELECT abs(c1), abs(c2), abs(c3), abs(c4) FROM test_nullable_decimal
NULL NULL NULL NULL

statement ok
drop table test_nullable_decimal
drop table test_nullable_decimal


statement ok
CREATE TABLE test_non_nullable_decimal(c1 DECIMAL(9,2) NOT NULL);
CREATE TABLE test_non_nullable_decimal(c1 DECIMAL(9,2) NOT NULL);

query I
INSERT INTO test_non_nullable_decimal VALUES(1)
Expand All @@ -569,13 +580,13 @@ SELECT c1*0 FROM test_non_nullable_decimal
0

query error DataFusion error: Arrow error: Divide by zero error
SELECT c1/0 FROM test_non_nullable_decimal
SELECT c1/0 FROM test_non_nullable_decimal

query error DataFusion error: Arrow error: Divide by zero error
SELECT c1%0 FROM test_non_nullable_decimal
SELECT c1%0 FROM test_non_nullable_decimal

statement ok
drop table test_non_nullable_decimal
drop table test_non_nullable_decimal

statement ok
CREATE TABLE signed_integers(
Expand Down Expand Up @@ -615,7 +626,7 @@ NULL NULL NULL

# scalar maxes and/or negative 1
query III
select
select
gcd(9223372036854775807, -9223372036854775808), -- i64::MAX, i64::MIN
gcd(9223372036854775807, -1), -- i64::MAX, -1
gcd(-9223372036854775808, -1); -- i64::MIN, -1
Expand Down