Skip to content

Commit 410d292

Browse files
authored
fix: support float16 for abs() (#18304)
## Which issue does this PR close? <!-- We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. For example `Closes #123` indicates that this PR will close issue #123. --> N/A ## Rationale for this change <!-- Why are you proposing this change? If this is already explained clearly in the issue then this section is not needed. Explaining clearly why changes are proposed helps reviewers understand your changes and offer better suggestions for fixes. --> Cover missing f16 type for `abs` ## What changes are included in this PR? <!-- There is no need to duplicate the description in the issue here but it is sometimes worth providing a summary of the individual changes in this PR. --> Support `abs` on f16; also do some cleanup. ## Are these changes tested? <!-- We typically require tests for all PRs in order to: 1. Prevent the code from being accidentally broken by subsequent changes 2. Serve as another way to document the expected behavior of the code If tests are not included in your PR, please explain why (for example, are they covered by existing tests)? --> Added SLT. ## Are there any user-facing changes? <!-- If there are user-facing changes then we may require documentation to be updated before approving the PR. --> No. <!-- If there are any breaking changes to public APIs, please add the `api change` label. -->
1 parent 6ee019e commit 410d292

File tree

6 files changed

+62
-72
lines changed

6 files changed

+62
-72
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ indexmap = "2.12.0"
159159
insta = { version = "1.43.2", features = ["glob", "filters"] }
160160
itertools = "0.14"
161161
log = "^0.4"
162+
num-traits = { version = "0.2" }
162163
object_store = { version = "0.12.4", default-features = false }
163164
parking_lot = "0.12"
164165
parquet = { version = "57.0.0", default-features = false, features = [

datafusion/datasource-avro/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ datafusion-physical-expr-common = { workspace = true }
4141
datafusion-physical-plan = { workspace = true }
4242
datafusion-session = { workspace = true }
4343
futures = { workspace = true }
44-
num-traits = { version = "0.2" }
44+
num-traits = { workspace = true }
4545
object_store = { workspace = true }
4646

4747
[dev-dependencies]

datafusion/functions/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ hex = { version = "0.4", optional = true }
7878
itertools = { workspace = true }
7979
log = { workspace = true }
8080
md-5 = { version = "^0.10.0", optional = true }
81+
num-traits = { workspace = true }
8182
rand = { workspace = true }
8283
regex = { workspace = true, optional = true }
8384
sha2 = { version = "^0.10.9", optional = true }

datafusion/functions/src/math/abs.rs

Lines changed: 6 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ use std::sync::Arc;
2222

2323
use arrow::array::{
2424
ArrayRef, Decimal128Array, Decimal256Array, Decimal32Array, Decimal64Array,
25-
Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array,
25+
Float16Array, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array,
26+
Int8Array,
2627
};
2728
use arrow::datatypes::DataType;
2829
use arrow::error::ArrowError;
@@ -34,6 +35,7 @@ use datafusion_expr::{
3435
Volatility,
3536
};
3637
use datafusion_macros::user_doc;
38+
use num_traits::sign::Signed;
3739

3840
type MathArrayFunction = fn(&ArrayRef) -> Result<ArrayRef>;
3941

@@ -81,6 +83,7 @@ macro_rules! make_decimal_abs_function {
8183
/// Return different implementations based on input datatype to reduce branches during execution
8284
fn create_abs_function(input_data_type: &DataType) -> Result<MathArrayFunction> {
8385
match input_data_type {
86+
DataType::Float16 => Ok(make_abs_function!(Float16Array)),
8487
DataType::Float32 => Ok(make_abs_function!(Float32Array)),
8588
DataType::Float64 => Ok(make_abs_function!(Float64Array)),
8689

@@ -143,6 +146,7 @@ impl ScalarUDFImpl for AbsFunc {
143146
fn as_any(&self) -> &dyn Any {
144147
self
145148
}
149+
146150
fn name(&self) -> &str {
147151
"abs"
148152
}
@@ -152,35 +156,7 @@ impl ScalarUDFImpl for AbsFunc {
152156
}
153157

154158
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
155-
match arg_types[0] {
156-
DataType::Float32 => Ok(DataType::Float32),
157-
DataType::Float64 => Ok(DataType::Float64),
158-
DataType::Int8 => Ok(DataType::Int8),
159-
DataType::Int16 => Ok(DataType::Int16),
160-
DataType::Int32 => Ok(DataType::Int32),
161-
DataType::Int64 => Ok(DataType::Int64),
162-
DataType::Null => Ok(DataType::Null),
163-
DataType::UInt8 => Ok(DataType::UInt8),
164-
DataType::UInt16 => Ok(DataType::UInt16),
165-
DataType::UInt32 => Ok(DataType::UInt32),
166-
DataType::UInt64 => Ok(DataType::UInt64),
167-
DataType::Decimal32(precision, scale) => {
168-
Ok(DataType::Decimal32(precision, scale))
169-
}
170-
DataType::Decimal64(precision, scale) => {
171-
Ok(DataType::Decimal64(precision, scale))
172-
}
173-
DataType::Decimal128(precision, scale) => {
174-
Ok(DataType::Decimal128(precision, scale))
175-
}
176-
DataType::Decimal256(precision, scale) => {
177-
Ok(DataType::Decimal256(precision, scale))
178-
}
179-
_ => not_impl_err!(
180-
"Unsupported data type {} for function abs",
181-
arg_types[0].to_string()
182-
),
183-
}
159+
Ok(arg_types[0].clone())
184160
}
185161

186162
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {

datafusion/sqllogictest/test_files/math.slt

Lines changed: 52 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -139,16 +139,16 @@ select abs(arrow_cast('-1.2', 'Utf8'));
139139

140140
statement ok
141141
CREATE TABLE test_nullable_integer(
142-
c1 TINYINT,
143-
c2 SMALLINT,
144-
c3 INT,
145-
c4 BIGINT,
146-
c5 TINYINT UNSIGNED,
147-
c6 SMALLINT UNSIGNED,
148-
c7 INT UNSIGNED,
149-
c8 BIGINT UNSIGNED,
142+
c1 TINYINT,
143+
c2 SMALLINT,
144+
c3 INT,
145+
c4 BIGINT,
146+
c5 TINYINT UNSIGNED,
147+
c6 SMALLINT UNSIGNED,
148+
c7 INT UNSIGNED,
149+
c8 BIGINT UNSIGNED,
150150
dataset TEXT
151-
)
151+
)
152152
AS VALUES
153153
(NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, 'nulls'),
154154
(0, 0, 0, 0, 0, 0, 0, 0, 'zeros'),
@@ -237,7 +237,7 @@ SELECT c8%0 FROM test_nullable_integer
237237

238238
# abs: return type
239239
query TTTTTTTT rowsort
240-
select
240+
select
241241
arrow_typeof(abs(c1)), arrow_typeof(abs(c2)), arrow_typeof(abs(c3)), arrow_typeof(abs(c4)),
242242
arrow_typeof(abs(c5)), arrow_typeof(abs(c6)), arrow_typeof(abs(c7)), arrow_typeof(abs(c8))
243243
from test_nullable_integer limit 1
@@ -285,13 +285,13 @@ drop table test_nullable_integer
285285

286286
statement ok
287287
CREATE TABLE test_non_nullable_integer(
288-
c1 TINYINT NOT NULL,
289-
c2 SMALLINT NOT NULL,
290-
c3 INT NOT NULL,
291-
c4 BIGINT NOT NULL,
292-
c5 TINYINT UNSIGNED NOT NULL,
293-
c6 SMALLINT UNSIGNED NOT NULL,
294-
c7 INT UNSIGNED NOT NULL,
288+
c1 TINYINT NOT NULL,
289+
c2 SMALLINT NOT NULL,
290+
c3 INT NOT NULL,
291+
c4 BIGINT NOT NULL,
292+
c5 TINYINT UNSIGNED NOT NULL,
293+
c6 SMALLINT UNSIGNED NOT NULL,
294+
c7 INT UNSIGNED NOT NULL,
295295
c8 BIGINT UNSIGNED NOT NULL
296296
);
297297

@@ -363,7 +363,7 @@ CREATE TABLE test_nullable_float(
363363
c2 double
364364
) AS VALUES
365365
(-1.0, -1.0),
366-
(1.0, 1.0),
366+
(1.0, 1.0),
367367
(NULL, NULL),
368368
(0., 0.),
369369
('NaN'::double, 'NaN'::double);
@@ -412,14 +412,25 @@ Float32 Float64
412412

413413
# abs: floats
414414
query RR rowsort
415-
SELECT abs(c1), abs(c2) from test_nullable_float
415+
SELECT abs(c1), abs(c2) from test_nullable_float
416416
----
417417
0 0
418418
1 1
419419
1 1
420420
NULL NULL
421421
NaN NaN
422422

423+
# f16
424+
query TR rowsort
425+
SELECT arrow_typeof(abs(arrow_cast(c1, 'Float16'))), abs(arrow_cast(c1, 'Float16'))
426+
FROM test_nullable_float
427+
----
428+
Float16 0
429+
Float16 1
430+
Float16 1
431+
Float16 NULL
432+
Float16 NaN
433+
423434
statement ok
424435
drop table test_nullable_float
425436

@@ -428,7 +439,7 @@ statement ok
428439
CREATE TABLE test_non_nullable_float(
429440
c1 float NOT NULL,
430441
c2 double NOT NULL
431-
);
442+
);
432443

433444
query I
434445
INSERT INTO test_non_nullable_float VALUES
@@ -478,27 +489,27 @@ drop table test_non_nullable_float
478489
statement ok
479490
CREATE TABLE test_nullable_decimal(
480491
c1 DECIMAL(10, 2), /* Decimal128 */
481-
c2 DECIMAL(38, 10), /* Decimal128 with max precision */
492+
c2 DECIMAL(38, 10), /* Decimal128 with max precision */
482493
c3 DECIMAL(40, 2), /* Decimal256 */
483-
c4 DECIMAL(76, 10) /* Decimal256 with max precision */
484-
) AS VALUES
485-
(0, 0, 0, 0),
494+
c4 DECIMAL(76, 10) /* Decimal256 with max precision */
495+
) AS VALUES
496+
(0, 0, 0, 0),
486497
(NULL, NULL, NULL, NULL);
487498

488499
query I
489500
INSERT into test_nullable_decimal values
490501
(
491-
-99999999.99,
492-
'-9999999999999999999999999999.9999999999',
493-
'-99999999999999999999999999999999999999.99',
502+
-99999999.99,
503+
'-9999999999999999999999999999.9999999999',
504+
'-99999999999999999999999999999999999999.99',
494505
'-999999999999999999999999999999999999999999999999999999999999999999.9999999999'
495-
),
506+
),
496507
(
497-
99999999.99,
498-
'9999999999999999999999999999.9999999999',
499-
'99999999999999999999999999999999999999.99',
508+
99999999.99,
509+
'9999999999999999999999999999.9999999999',
510+
'99999999999999999999999999999999999999.99',
500511
'999999999999999999999999999999999999999999999999999999999999999999.9999999999'
501-
)
512+
)
502513
----
503514
2
504515

@@ -533,9 +544,9 @@ SELECT c1%0 FROM test_nullable_decimal WHERE c1 IS NOT NULL;
533544

534545
# abs: return type
535546
query TTTT
536-
SELECT
537-
arrow_typeof(abs(c1)),
538-
arrow_typeof(abs(c2)),
547+
SELECT
548+
arrow_typeof(abs(c1)),
549+
arrow_typeof(abs(c2)),
539550
arrow_typeof(abs(c3)),
540551
arrow_typeof(abs(c4))
541552
FROM test_nullable_decimal limit 1
@@ -552,11 +563,11 @@ SELECT abs(c1), abs(c2), abs(c3), abs(c4) FROM test_nullable_decimal
552563
NULL NULL NULL NULL
553564

554565
statement ok
555-
drop table test_nullable_decimal
566+
drop table test_nullable_decimal
556567

557568

558569
statement ok
559-
CREATE TABLE test_non_nullable_decimal(c1 DECIMAL(9,2) NOT NULL);
570+
CREATE TABLE test_non_nullable_decimal(c1 DECIMAL(9,2) NOT NULL);
560571

561572
query I
562573
INSERT INTO test_non_nullable_decimal VALUES(1)
@@ -569,13 +580,13 @@ SELECT c1*0 FROM test_non_nullable_decimal
569580
0
570581

571582
query error DataFusion error: Arrow error: Divide by zero error
572-
SELECT c1/0 FROM test_non_nullable_decimal
583+
SELECT c1/0 FROM test_non_nullable_decimal
573584

574585
query error DataFusion error: Arrow error: Divide by zero error
575-
SELECT c1%0 FROM test_non_nullable_decimal
586+
SELECT c1%0 FROM test_non_nullable_decimal
576587

577588
statement ok
578-
drop table test_non_nullable_decimal
589+
drop table test_non_nullable_decimal
579590

580591
statement ok
581592
CREATE TABLE signed_integers(
@@ -615,7 +626,7 @@ NULL NULL NULL
615626

616627
# scalar maxes and/or negative 1
617628
query III
618-
select
629+
select
619630
gcd(9223372036854775807, -9223372036854775808), -- i64::MAX, i64::MIN
620631
gcd(9223372036854775807, -1), -- i64::MAX, -1
621632
gcd(-9223372036854775808, -1); -- i64::MIN, -1

0 commit comments

Comments
 (0)