Skip to content

Commit 4dd6923

Browse files
authored
fix: Fix SparkSha2 to be compliant with Spark response and add support for Int32 (#16350)
* fix: Fix SparkSha2 to be compliant with Spark response and add support for Int32. * Fixed test cases. * Addressed comments. * Fixed missed test case. * Minor cosmetic changes.
1 parent 79f5c8d commit 4dd6923

File tree

3 files changed

+85
-55
lines changed

3 files changed

+85
-55
lines changed

datafusion/spark/src/function/hash/sha2.rs

Lines changed: 32 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@ extern crate datafusion_functions;
2020
use crate::function::error_utils::{
2121
invalid_arg_count_exec_err, unsupported_data_type_exec_err,
2222
};
23-
use crate::function::math::hex::spark_hex;
23+
use crate::function::math::hex::spark_sha2_hex;
2424
use arrow::array::{ArrayRef, AsArray, StringArray};
25-
use arrow::datatypes::{DataType, UInt32Type};
25+
use arrow::datatypes::{DataType, Int32Type};
2626
use datafusion_common::{exec_err, internal_datafusion_err, Result, ScalarValue};
2727
use datafusion_expr::Signature;
2828
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Volatility};
@@ -121,7 +121,7 @@ impl ScalarUDFImpl for SparkSha2 {
121121
)),
122122
}?;
123123
let bit_length_type = if arg_types[1].is_numeric() {
124-
Ok(DataType::UInt32)
124+
Ok(DataType::Int32)
125125
} else if arg_types[1].is_null() {
126126
Ok(DataType::Null)
127127
} else {
@@ -138,39 +138,24 @@ impl ScalarUDFImpl for SparkSha2 {
138138

139139
pub fn sha2(args: [ColumnarValue; 2]) -> Result<ColumnarValue> {
140140
match args {
141-
[ColumnarValue::Scalar(ScalarValue::Utf8(expr_arg)), ColumnarValue::Scalar(ScalarValue::UInt32(Some(bit_length_arg)))] => {
142-
match bit_length_arg {
143-
0 | 256 => sha256(&[ColumnarValue::from(ScalarValue::Utf8(expr_arg))]),
144-
224 => sha224(&[ColumnarValue::from(ScalarValue::Utf8(expr_arg))]),
145-
384 => sha384(&[ColumnarValue::from(ScalarValue::Utf8(expr_arg))]),
146-
512 => sha512(&[ColumnarValue::from(ScalarValue::Utf8(expr_arg))]),
147-
_ => exec_err!(
148-
"sha2 function only supports 224, 256, 384, and 512 bit lengths."
149-
),
150-
}
151-
.map(|hashed| spark_hex(&[hashed]).unwrap())
141+
[ColumnarValue::Scalar(ScalarValue::Utf8(expr_arg)), ColumnarValue::Scalar(ScalarValue::Int32(Some(bit_length_arg)))] => {
142+
compute_sha2(
143+
bit_length_arg,
144+
&[ColumnarValue::from(ScalarValue::Utf8(expr_arg))],
145+
)
152146
}
153-
[ColumnarValue::Array(expr_arg), ColumnarValue::Scalar(ScalarValue::UInt32(Some(bit_length_arg)))] => {
154-
match bit_length_arg {
155-
0 | 256 => sha256(&[ColumnarValue::from(expr_arg)]),
156-
224 => sha224(&[ColumnarValue::from(expr_arg)]),
157-
384 => sha384(&[ColumnarValue::from(expr_arg)]),
158-
512 => sha512(&[ColumnarValue::from(expr_arg)]),
159-
_ => exec_err!(
160-
"sha2 function only supports 224, 256, 384, and 512 bit lengths."
161-
),
162-
}
163-
.map(|hashed| spark_hex(&[hashed]).unwrap())
147+
[ColumnarValue::Array(expr_arg), ColumnarValue::Scalar(ScalarValue::Int32(Some(bit_length_arg)))] => {
148+
compute_sha2(bit_length_arg, &[ColumnarValue::from(expr_arg)])
164149
}
165150
[ColumnarValue::Scalar(ScalarValue::Utf8(expr_arg)), ColumnarValue::Array(bit_length_arg)] =>
166151
{
167152
let arr: StringArray = bit_length_arg
168-
.as_primitive::<UInt32Type>()
153+
.as_primitive::<Int32Type>()
169154
.iter()
170155
.map(|bit_length| {
171156
match sha2([
172157
ColumnarValue::Scalar(ScalarValue::Utf8(expr_arg.clone())),
173-
ColumnarValue::Scalar(ScalarValue::UInt32(bit_length)),
158+
ColumnarValue::Scalar(ScalarValue::Int32(bit_length)),
174159
])
175160
.unwrap()
176161
{
@@ -188,15 +173,15 @@ pub fn sha2(args: [ColumnarValue; 2]) -> Result<ColumnarValue> {
188173
}
189174
[ColumnarValue::Array(expr_arg), ColumnarValue::Array(bit_length_arg)] => {
190175
let expr_iter = expr_arg.as_string::<i32>().iter();
191-
let bit_length_iter = bit_length_arg.as_primitive::<UInt32Type>().iter();
176+
let bit_length_iter = bit_length_arg.as_primitive::<Int32Type>().iter();
192177
let arr: StringArray = expr_iter
193178
.zip(bit_length_iter)
194179
.map(|(expr, bit_length)| {
195180
match sha2([
196181
ColumnarValue::Scalar(ScalarValue::Utf8(Some(
197182
expr.unwrap().to_string(),
198183
))),
199-
ColumnarValue::Scalar(ScalarValue::UInt32(bit_length)),
184+
ColumnarValue::Scalar(ScalarValue::Int32(bit_length)),
200185
])
201186
.unwrap()
202187
{
@@ -215,3 +200,21 @@ pub fn sha2(args: [ColumnarValue; 2]) -> Result<ColumnarValue> {
215200
_ => exec_err!("Unsupported argument types for sha2 function"),
216201
}
217202
}
203+
204+
fn compute_sha2(
205+
bit_length_arg: i32,
206+
expr_arg: &[ColumnarValue],
207+
) -> Result<ColumnarValue> {
208+
match bit_length_arg {
209+
0 | 256 => sha256(expr_arg),
210+
224 => sha224(expr_arg),
211+
384 => sha384(expr_arg),
212+
512 => sha512(expr_arg),
213+
_ => {
214+
// Return null for unsupported bit lengths instead of error, because spark sha2 does not
215+
// error out for this.
216+
return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)));
217+
}
218+
}
219+
.map(|hashed| spark_sha2_hex(&[hashed]).unwrap())
220+
}

datafusion/spark/src/function/math/hex.rs

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -159,13 +159,28 @@ fn hex_encode<T: AsRef<[u8]>>(data: T, lower_case: bool) -> String {
159159
}
160160

161161
#[inline(always)]
162-
fn hex_bytes<T: AsRef<[u8]>>(bytes: T) -> Result<String, std::fmt::Error> {
163-
let hex_string = hex_encode(bytes, false);
162+
fn hex_bytes<T: AsRef<[u8]>>(
163+
bytes: T,
164+
lowercase: bool,
165+
) -> Result<String, std::fmt::Error> {
166+
let hex_string = hex_encode(bytes, lowercase);
164167
Ok(hex_string)
165168
}
166169

167170
/// Spark-compatible `hex` function
168171
pub fn spark_hex(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
172+
compute_hex(args, false)
173+
}
174+
175+
/// Spark-compatible `sha2` function
176+
pub fn spark_sha2_hex(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
177+
compute_hex(args, true)
178+
}
179+
180+
pub fn compute_hex(
181+
args: &[ColumnarValue],
182+
lowercase: bool,
183+
) -> Result<ColumnarValue, DataFusionError> {
169184
if args.len() != 1 {
170185
return Err(DataFusionError::Internal(
171186
"hex expects exactly one argument".to_string(),
@@ -192,7 +207,7 @@ pub fn spark_hex(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionErro
192207

193208
let hexed: StringArray = array
194209
.iter()
195-
.map(|v| v.map(hex_bytes).transpose())
210+
.map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
196211
.collect::<Result<_, _>>()?;
197212

198213
Ok(ColumnarValue::Array(Arc::new(hexed)))
@@ -202,7 +217,7 @@ pub fn spark_hex(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionErro
202217

203218
let hexed: StringArray = array
204219
.iter()
205-
.map(|v| v.map(hex_bytes).transpose())
220+
.map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
206221
.collect::<Result<_, _>>()?;
207222

208223
Ok(ColumnarValue::Array(Arc::new(hexed)))
@@ -212,7 +227,7 @@ pub fn spark_hex(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionErro
212227

213228
let hexed: StringArray = array
214229
.iter()
215-
.map(|v| v.map(hex_bytes).transpose())
230+
.map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
216231
.collect::<Result<_, _>>()?;
217232

218233
Ok(ColumnarValue::Array(Arc::new(hexed)))
@@ -222,7 +237,7 @@ pub fn spark_hex(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionErro
222237

223238
let hexed: StringArray = array
224239
.iter()
225-
.map(|v| v.map(hex_bytes).transpose())
240+
.map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
226241
.collect::<Result<_, _>>()?;
227242

228243
Ok(ColumnarValue::Array(Arc::new(hexed)))
@@ -237,11 +252,11 @@ pub fn spark_hex(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionErro
237252
.collect::<Vec<_>>(),
238253
DataType::Utf8 => as_string_array(dict.values())
239254
.iter()
240-
.map(|v| v.map(hex_bytes).transpose())
255+
.map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
241256
.collect::<Result<_, _>>()?,
242257
DataType::Binary => as_binary_array(dict.values())?
243258
.iter()
244-
.map(|v| v.map(hex_bytes).transpose())
259+
.map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
245260
.collect::<Result<_, _>>()?,
246261
_ => exec_err!(
247262
"hex got an unexpected argument type: {:?}",

datafusion/sqllogictest/test_files/spark/hash/sha2.slt

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -18,48 +18,60 @@
1818
query T
1919
SELECT sha2('Spark', 0::INT);
2020
----
21-
529BC3B07127ECB7E53A4DCF1991D9152C24537D919178022B2C42657F79A26B
21+
529bc3b07127ecb7e53a4dcf1991d9152c24537d919178022b2c42657f79a26b
2222

2323
query T
2424
SELECT sha2('Spark', 256::INT);
2525
----
26-
529BC3B07127ECB7E53A4DCF1991D9152C24537D919178022B2C42657F79A26B
26+
529bc3b07127ecb7e53a4dcf1991d9152c24537d919178022b2c42657f79a26b
2727

2828
query T
2929
SELECT sha2('Spark', 224::INT);
3030
----
31-
DBEAB94971678D36AF2195851C0F7485775A2A7C60073D62FC04549C
31+
dbeab94971678d36af2195851c0f7485775a2a7c60073d62fc04549c
3232

3333
query T
3434
SELECT sha2('Spark', 384::INT);
3535
----
36-
1E40B8D06C248A1CC32428C22582B6219D072283078FA140D9AD297ECADF2CABEFC341B857AD36226AA8D6D79F2AB67D
36+
1e40b8d06c248a1cc32428c22582b6219d072283078fa140d9ad297ecadf2cabefc341b857ad36226aa8d6d79f2ab67d
3737

3838
query T
3939
SELECT sha2('Spark', 512::INT);
4040
----
41-
44844A586C54C9A212DA1DBFE05C5F1705DE1AF5FDA1F0D36297623249B279FD8F0CCEC03F888F4FB13BF7CD83FDAD58591C797F81121A23CFDD5E0897795238
41+
44844a586c54c9a212da1dbfe05c5f1705de1af5fda1f0d36297623249b279fd8f0ccec03f888f4fb13bf7cd83fdad58591c797f81121a23cfdd5e0897795238
42+
43+
query T
44+
SELECT sha2('Spark', 128::INT);
45+
----
46+
NULL
4247

4348
query T
4449
SELECT sha2(expr, 256::INT) FROM VALUES ('foo'), ('bar') AS t(expr);
4550
----
46-
2C26B46B68FFC68FF99B453C1D30413413422D706483BFA0F98A5E886266E7AE
47-
FCDE2B2EDBA56BF408601FB721FE9B5C338D10EE429EA04FAE5511B68FBF8FB9
51+
2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae
52+
fcde2b2edba56bf408601fb721fe9b5c338d10ee429ea04fae5511b68fbf8fb9
4853

4954
query T
50-
SELECT sha2('foo', bit_length) FROM VALUES (0::INT), (256::INT), (224::INT), (384::INT), (512::INT) AS t(bit_length);
55+
SELECT sha2(expr, 128::INT) FROM VALUES ('foo'), ('bar') AS t(expr);
5156
----
52-
2C26B46B68FFC68FF99B453C1D30413413422D706483BFA0F98A5E886266E7AE
53-
2C26B46B68FFC68FF99B453C1D30413413422D706483BFA0F98A5E886266E7AE
54-
0808F64E60D58979FCB676C96EC938270DEA42445AEEFCD3A4E6F8DB
55-
98C11FFDFDD540676B1A137CB1A22B2A70350C9A44171D6B1180C6BE5CBB2EE3F79D532C8A1DD9EF2E8E08E752A3BABB
56-
F7FBBA6E0636F890E56FBBF3283E524C6FA3204AE298382D624741D0DC6638326E282C41BE5E4254D8820772C5518A2C5A8C0C7F7EDA19594A7EB539453E1ED7
57+
NULL
58+
NULL
5759

60+
query T
61+
SELECT sha2('foo', bit_length) FROM VALUES (0::INT), (256::INT), (224::INT), (384::INT), (512::INT), (128::INT) AS t(bit_length);
62+
----
63+
2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae
64+
2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae
65+
0808f64e60d58979fcb676c96ec938270dea42445aeefcd3a4e6f8db
66+
98c11ffdfdd540676b1a137cb1a22b2a70350c9a44171d6b1180c6be5cbb2ee3f79d532c8a1dd9ef2e8e08e752a3babb
67+
f7fbba6e0636f890e56fbbf3283e524c6fa3204ae298382d624741d0dc6638326e282c41be5e4254d8820772c5518a2c5a8c0c7f7eda19594a7eb539453e1ed7
68+
NULL
5869

5970
query T
60-
SELECT sha2(expr, bit_length) FROM VALUES ('foo',0::INT), ('bar',224::INT), ('baz',384::INT), ('qux',512::INT) AS t(expr, bit_length);
71+
SELECT sha2(expr, bit_length) FROM VALUES ('foo',0::INT), ('bar',224::INT), ('baz',384::INT), ('qux',512::INT), ('qux',128::INT) AS t(expr, bit_length);
6172
----
62-
2C26B46B68FFC68FF99B453C1D30413413422D706483BFA0F98A5E886266E7AE
63-
07DAF010DE7F7F0D8D76A76EB8D1EB40182C8D1E7A3877A6686C9BF0
64-
967004D25DE4ABC1BD6A7C9A216254A5AC0733E8AD96DC9F1EA0FAD9619DA7C32D654EC8AD8BA2F9B5728FED6633BD91
65-
8C6BE9ED448A34883A13A13F4EAD4AEFA036B67DCDA59020C01E57EA075EA8A4792D428F2C6FD0C09D1C49994D6C22789336E062188DF29572ED07E7F9779C52
73+
2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae
74+
07daf010de7f7f0d8d76a76eb8d1eb40182c8d1e7a3877a6686c9bf0
75+
967004d25de4abc1bd6a7c9a216254a5ac0733e8ad96dc9f1ea0fad9619da7c32d654ec8ad8ba2f9b5728fed6633bd91
76+
8c6be9ed448a34883a13a13f4ead4aefa036b67dcda59020c01e57ea075ea8a4792d428f2c6fd0c09d1c49994d6c22789336e062188df29572ed07e7f9779c52
77+
NULL

0 commit comments

Comments
 (0)