@@ -20,9 +20,9 @@ extern crate datafusion_functions;
2020use crate :: function:: error_utils:: {
2121 invalid_arg_count_exec_err, unsupported_data_type_exec_err,
2222} ;
23- use crate :: function:: math:: hex:: spark_hex ;
23+ use crate :: function:: math:: hex:: spark_sha2_hex ;
2424use arrow:: array:: { ArrayRef , AsArray , StringArray } ;
25- use arrow:: datatypes:: { DataType , UInt32Type } ;
25+ use arrow:: datatypes:: { DataType , Int32Type } ;
2626use datafusion_common:: { exec_err, internal_datafusion_err, Result , ScalarValue } ;
2727use datafusion_expr:: Signature ;
2828use datafusion_expr:: { ColumnarValue , ScalarFunctionArgs , ScalarUDFImpl , Volatility } ;
@@ -121,7 +121,7 @@ impl ScalarUDFImpl for SparkSha2 {
121121 ) ) ,
122122 } ?;
123123 let bit_length_type = if arg_types[ 1 ] . is_numeric ( ) {
124- Ok ( DataType :: UInt32 )
124+ Ok ( DataType :: Int32 )
125125 } else if arg_types[ 1 ] . is_null ( ) {
126126 Ok ( DataType :: Null )
127127 } else {
@@ -138,39 +138,24 @@ impl ScalarUDFImpl for SparkSha2 {
138138
139139pub fn sha2 ( args : [ ColumnarValue ; 2 ] ) -> Result < ColumnarValue > {
140140 match args {
141- [ ColumnarValue :: Scalar ( ScalarValue :: Utf8 ( expr_arg) ) , ColumnarValue :: Scalar ( ScalarValue :: UInt32 ( Some ( bit_length_arg) ) ) ] => {
142- match bit_length_arg {
143- 0 | 256 => sha256 ( & [ ColumnarValue :: from ( ScalarValue :: Utf8 ( expr_arg) ) ] ) ,
144- 224 => sha224 ( & [ ColumnarValue :: from ( ScalarValue :: Utf8 ( expr_arg) ) ] ) ,
145- 384 => sha384 ( & [ ColumnarValue :: from ( ScalarValue :: Utf8 ( expr_arg) ) ] ) ,
146- 512 => sha512 ( & [ ColumnarValue :: from ( ScalarValue :: Utf8 ( expr_arg) ) ] ) ,
147- _ => exec_err ! (
148- "sha2 function only supports 224, 256, 384, and 512 bit lengths."
149- ) ,
150- }
151- . map ( |hashed| spark_hex ( & [ hashed] ) . unwrap ( ) )
141+ [ ColumnarValue :: Scalar ( ScalarValue :: Utf8 ( expr_arg) ) , ColumnarValue :: Scalar ( ScalarValue :: Int32 ( Some ( bit_length_arg) ) ) ] => {
142+ compute_sha2 (
143+ bit_length_arg,
144+ & [ ColumnarValue :: from ( ScalarValue :: Utf8 ( expr_arg) ) ] ,
145+ )
152146 }
153- [ ColumnarValue :: Array ( expr_arg) , ColumnarValue :: Scalar ( ScalarValue :: UInt32 ( Some ( bit_length_arg) ) ) ] => {
154- match bit_length_arg {
155- 0 | 256 => sha256 ( & [ ColumnarValue :: from ( expr_arg) ] ) ,
156- 224 => sha224 ( & [ ColumnarValue :: from ( expr_arg) ] ) ,
157- 384 => sha384 ( & [ ColumnarValue :: from ( expr_arg) ] ) ,
158- 512 => sha512 ( & [ ColumnarValue :: from ( expr_arg) ] ) ,
159- _ => exec_err ! (
160- "sha2 function only supports 224, 256, 384, and 512 bit lengths."
161- ) ,
162- }
163- . map ( |hashed| spark_hex ( & [ hashed] ) . unwrap ( ) )
147+ [ ColumnarValue :: Array ( expr_arg) , ColumnarValue :: Scalar ( ScalarValue :: Int32 ( Some ( bit_length_arg) ) ) ] => {
148+ compute_sha2 ( bit_length_arg, & [ ColumnarValue :: from ( expr_arg) ] )
164149 }
165150 [ ColumnarValue :: Scalar ( ScalarValue :: Utf8 ( expr_arg) ) , ColumnarValue :: Array ( bit_length_arg) ] =>
166151 {
167152 let arr: StringArray = bit_length_arg
168- . as_primitive :: < UInt32Type > ( )
153+ . as_primitive :: < Int32Type > ( )
169154 . iter ( )
170155 . map ( |bit_length| {
171156 match sha2 ( [
172157 ColumnarValue :: Scalar ( ScalarValue :: Utf8 ( expr_arg. clone ( ) ) ) ,
173- ColumnarValue :: Scalar ( ScalarValue :: UInt32 ( bit_length) ) ,
158+ ColumnarValue :: Scalar ( ScalarValue :: Int32 ( bit_length) ) ,
174159 ] )
175160 . unwrap ( )
176161 {
@@ -188,15 +173,15 @@ pub fn sha2(args: [ColumnarValue; 2]) -> Result<ColumnarValue> {
188173 }
189174 [ ColumnarValue :: Array ( expr_arg) , ColumnarValue :: Array ( bit_length_arg) ] => {
190175 let expr_iter = expr_arg. as_string :: < i32 > ( ) . iter ( ) ;
191- let bit_length_iter = bit_length_arg. as_primitive :: < UInt32Type > ( ) . iter ( ) ;
176+ let bit_length_iter = bit_length_arg. as_primitive :: < Int32Type > ( ) . iter ( ) ;
192177 let arr: StringArray = expr_iter
193178 . zip ( bit_length_iter)
194179 . map ( |( expr, bit_length) | {
195180 match sha2 ( [
196181 ColumnarValue :: Scalar ( ScalarValue :: Utf8 ( Some (
197182 expr. unwrap ( ) . to_string ( ) ,
198183 ) ) ) ,
199- ColumnarValue :: Scalar ( ScalarValue :: UInt32 ( bit_length) ) ,
184+ ColumnarValue :: Scalar ( ScalarValue :: Int32 ( bit_length) ) ,
200185 ] )
201186 . unwrap ( )
202187 {
@@ -215,3 +200,21 @@ pub fn sha2(args: [ColumnarValue; 2]) -> Result<ColumnarValue> {
215200 _ => exec_err ! ( "Unsupported argument types for sha2 function" ) ,
216201 }
217202}
203+
204+ fn compute_sha2 (
205+ bit_length_arg : i32 ,
206+ expr_arg : & [ ColumnarValue ] ,
207+ ) -> Result < ColumnarValue > {
208+ match bit_length_arg {
209+ 0 | 256 => sha256 ( expr_arg) ,
210+ 224 => sha224 ( expr_arg) ,
211+ 384 => sha384 ( expr_arg) ,
212+ 512 => sha512 ( expr_arg) ,
213+ _ => {
214+ // Return null for unsupported bit lengths instead of error, because spark sha2 does not
215+ // error out for this.
216+ return Ok ( ColumnarValue :: Scalar ( ScalarValue :: Utf8 ( None ) ) ) ;
217+ }
218+ }
219+ . map ( |hashed| spark_sha2_hex ( & [ hashed] ) . unwrap ( ) )
220+ }
0 commit comments