1515// specific language governing permissions and limitations
1616// under the License.
1717
18- use arrow:: array:: Array ;
18+ use arrow:: array:: { Array , ArrayBuilder } ;
1919use arrow:: datatypes:: DataType ;
2020use datafusion_common:: { Result , ScalarValue } ;
2121use datafusion_expr:: {
2222 ColumnarValue , ScalarFunctionArgs , ScalarUDFImpl , Signature , TypeSignature ,
2323 Volatility ,
2424} ;
2525use datafusion_functions:: string:: concat:: ConcatFunc ;
26- use std:: { any:: Any , sync:: Arc } ;
26+ use std:: any:: Any ;
27+ use std:: sync:: Arc ;
2728
2829/// Spark-compatible `concat` expression
2930/// <https://spark.apache.org/docs/latest/api/sql/index.html#concat>
@@ -97,44 +98,64 @@ fn spark_concat(args: ScalarFunctionArgs) -> Result<ColumnarValue> {
9798 ) ) ) ;
9899 }
99100
101+ // Step 1: Check for NULL mask in incoming args
102+ let null_mask = compute_null_mask ( & arg_values, number_rows) ?;
103+
104+ // If all scalars and any is NULL, return NULL immediately
105+ if null_mask. is_none ( ) {
106+ return Ok ( ColumnarValue :: Scalar ( ScalarValue :: Utf8 ( None ) ) ) ;
107+ }
108+
109+ // Step 2: Delegate to DataFusion's concat
110+ let concat_func = ConcatFunc :: new ( ) ;
111+ let func_args = ScalarFunctionArgs {
112+ args : arg_values,
113+ arg_fields,
114+ number_rows,
115+ return_field,
116+ config_options,
117+ } ;
118+ let result = concat_func. invoke_with_args ( func_args) ?;
119+
120+ // Step 3: Apply NULL mask to result
121+ apply_null_mask ( result, null_mask)
122+ }
123+
124+ /// Compute NULL mask for the arguments
125+ /// Returns None if all scalars and any is NULL, or a Vec<bool> for arrays
126+ fn compute_null_mask (
127+ args : & [ ColumnarValue ] ,
128+ number_rows : usize ,
129+ ) -> Result < Option < Vec < bool > > > {
100130 // Check if all arguments are scalars
101- let all_scalars = arg_values
131+ let all_scalars = args
102132 . iter ( )
103133 . all ( |arg| matches ! ( arg, ColumnarValue :: Scalar ( _) ) ) ;
104134
105135 if all_scalars {
106136 // For scalars, check if any is NULL
107- for arg in & arg_values {
137+ for arg in args {
108138 if let ColumnarValue :: Scalar ( scalar) = arg {
109139 if scalar. is_null ( ) {
110- // Return NULL if any argument is NULL (Spark behavior)
111- return Ok ( ColumnarValue :: Scalar ( ScalarValue :: Utf8 ( None ) ) ) ;
140+ // Return None to indicate all values should be NULL
141+ return Ok ( None ) ;
112142 }
113143 }
114144 }
115- // No NULLs found, delegate to DataFusion's concat
116- let concat_func = ConcatFunc :: new ( ) ;
117- let func_args = ScalarFunctionArgs {
118- args : arg_values,
119- arg_fields,
120- number_rows,
121- return_field,
122- config_options,
123- } ;
124- concat_func. invoke_with_args ( func_args)
145+ // No NULLs in scalars
146+ Ok ( Some ( vec ! [ ] ) )
125147 } else {
126- // For arrays, we need to check each row for NULLs and return NULL for that row
127- // Get array length
128- let array_len = arg_values
148+ // For arrays, compute NULL mask for each row
149+ let array_len = args
129150 . iter ( )
130151 . find_map ( |arg| match arg {
131152 ColumnarValue :: Array ( array) => Some ( array. len ( ) ) ,
132153 _ => None ,
133154 } )
134155 . unwrap_or ( number_rows) ;
135156
136- // Convert all scalars to arrays
137- let arrays: Result < Vec < _ > > = arg_values
157+ // Convert all scalars to arrays for uniform processing
158+ let arrays: Result < Vec < _ > > = args
138159 . iter ( )
139160 . map ( |arg| match arg {
140161 ColumnarValue :: Array ( array) => Ok ( Arc :: clone ( array) ) ,
@@ -143,7 +164,7 @@ fn spark_concat(args: ScalarFunctionArgs) -> Result<ColumnarValue> {
143164 . collect ( ) ;
144165 let arrays = arrays?;
145166
146- // Check for NULL values in each row
167+ // Compute NULL mask
147168 let mut null_mask = vec ! [ false ; array_len] ;
148169 for array in & arrays {
149170 for ( i, null_flag) in null_mask. iter_mut ( ) . enumerate ( ) . take ( array_len) {
@@ -153,86 +174,90 @@ fn spark_concat(args: ScalarFunctionArgs) -> Result<ColumnarValue> {
153174 }
154175 }
155176
156- // Delegate to DataFusion's concat
157- let concat_func = ConcatFunc :: new ( ) ;
158- let func_args = ScalarFunctionArgs {
159- args : arg_values,
160- arg_fields,
161- number_rows,
162- return_field,
163- config_options,
164- } ;
177+ Ok ( Some ( null_mask) )
178+ }
179+ }
165180
166- let result = concat_func. invoke_with_args ( func_args) ?;
181+ /// Apply NULL mask to the result
182+ fn apply_null_mask (
183+ result : ColumnarValue ,
184+ null_mask : Option < Vec < bool > > ,
185+ ) -> Result < ColumnarValue > {
186+ match ( result, null_mask) {
187+ // Scalar with NULL mask means return NULL
188+ ( ColumnarValue :: Scalar ( _) , None ) => {
189+ Ok ( ColumnarValue :: Scalar ( ScalarValue :: Utf8 ( None ) ) )
190+ }
191+ // Scalar without NULL mask, return as-is
192+ ( scalar @ ColumnarValue :: Scalar ( _) , Some ( mask) ) if mask. is_empty ( ) => Ok ( scalar) ,
193+ // Array with NULL mask
194+ ( ColumnarValue :: Array ( array) , Some ( null_mask) ) if !null_mask. is_empty ( ) => {
195+ let array_len = array. len ( ) ;
196+ let return_type = array. data_type ( ) ;
167197
168- // Apply NULL mask to the result
169- match result {
170- ColumnarValue :: Array ( array) => {
171- let return_type = array. data_type ( ) ;
172- let mut builder: Box < dyn arrow:: array:: ArrayBuilder > = match return_type {
173- DataType :: Utf8 => {
174- let string_array = array
175- . as_any ( )
176- . downcast_ref :: < arrow:: array:: StringArray > ( )
177- . unwrap ( ) ;
178- let mut builder =
179- arrow:: array:: StringBuilder :: with_capacity ( array_len, 0 ) ;
180- for ( i, & is_null) in null_mask. iter ( ) . enumerate ( ) . take ( array_len)
181- {
182- if is_null || string_array. is_null ( i) {
183- builder. append_null ( ) ;
184- } else {
185- builder. append_value ( string_array. value ( i) ) ;
186- }
198+ let mut builder: Box < dyn ArrayBuilder > = match return_type {
199+ DataType :: Utf8 => {
200+ let string_array = array
201+ . as_any ( )
202+ . downcast_ref :: < arrow:: array:: StringArray > ( )
203+ . unwrap ( ) ;
204+ let mut builder =
205+ arrow:: array:: StringBuilder :: with_capacity ( array_len, 0 ) ;
206+ for ( i, & is_null) in null_mask. iter ( ) . enumerate ( ) . take ( array_len) {
207+ if is_null || string_array. is_null ( i) {
208+ builder. append_null ( ) ;
209+ } else {
210+ builder. append_value ( string_array. value ( i) ) ;
187211 }
188- Box :: new ( builder)
189212 }
190- DataType :: LargeUtf8 => {
191- let string_array = array
192- . as_any ( )
193- . downcast_ref :: < arrow :: array:: LargeStringArray > ( )
194- . unwrap ( ) ;
195- let mut builder =
196- arrow :: array :: LargeStringBuilder :: with_capacity ( array_len , 0 ) ;
197- for ( i , & is_null ) in null_mask . iter ( ) . enumerate ( ) . take ( array_len )
198- {
199- if is_null || string_array . is_null ( i ) {
200- builder . append_null ( ) ;
201- } else {
202- builder . append_value ( string_array . value ( i ) ) ;
203- }
213+ Box :: new ( builder )
214+ }
215+ DataType :: LargeUtf8 => {
216+ let string_array = array
217+ . as_any ( )
218+ . downcast_ref :: < arrow :: array :: LargeStringArray > ( )
219+ . unwrap ( ) ;
220+ let mut builder =
221+ arrow :: array :: LargeStringBuilder :: with_capacity ( array_len , 0 ) ;
222+ for ( i , & is_null) in null_mask . iter ( ) . enumerate ( ) . take ( array_len ) {
223+ if is_null || string_array . is_null ( i ) {
224+ builder . append_null ( ) ;
225+ } else {
226+ builder . append_value ( string_array . value ( i ) ) ;
204227 }
205- Box :: new ( builder)
206228 }
207- DataType :: Utf8View => {
208- let string_array = array
209- . as_any ( )
210- . downcast_ref :: < arrow :: array:: StringViewArray > ( )
211- . unwrap ( ) ;
212- let mut builder =
213- arrow :: array :: StringViewBuilder :: with_capacity ( array_len ) ;
214- for ( i , & is_null ) in null_mask . iter ( ) . enumerate ( ) . take ( array_len )
215- {
216- if is_null || string_array . is_null ( i ) {
217- builder . append_null ( ) ;
218- } else {
219- builder . append_value ( string_array . value ( i ) ) ;
220- }
229+ Box :: new ( builder )
230+ }
231+ DataType :: Utf8View => {
232+ let string_array = array
233+ . as_any ( )
234+ . downcast_ref :: < arrow :: array :: StringViewArray > ( )
235+ . unwrap ( ) ;
236+ let mut builder =
237+ arrow :: array :: StringViewBuilder :: with_capacity ( array_len ) ;
238+ for ( i , & is_null) in null_mask . iter ( ) . enumerate ( ) . take ( array_len ) {
239+ if is_null || string_array . is_null ( i ) {
240+ builder . append_null ( ) ;
241+ } else {
242+ builder . append_value ( string_array . value ( i ) ) ;
221243 }
222- Box :: new ( builder)
223- }
224- _ => {
225- return datafusion_common:: exec_err!(
226- "Unsupported return type for concat: {:?}" ,
227- return_type
228- ) ;
229244 }
230- } ;
245+ Box :: new ( builder)
246+ }
247+ _ => {
248+ return datafusion_common:: exec_err!(
249+ "Unsupported return type for concat: {:?}" ,
250+ return_type
251+ ) ;
252+ }
253+ } ;
231254
232- Ok ( ColumnarValue :: Array ( builder. finish ( ) ) )
233- }
234- other => Ok ( other) ,
255+ Ok ( ColumnarValue :: Array ( builder. finish ( ) ) )
235256 }
257+ // Array without NULL mask, return as-is
258+ ( array @ ColumnarValue :: Array ( _) , _) => Ok ( array) ,
259+ // Shouldn't happen
260+ ( scalar, _) => Ok ( scalar) ,
236261 }
237262}
238263
@@ -243,7 +268,6 @@ mod tests {
243268 use arrow:: array:: StringArray ;
244269 use arrow:: datatypes:: DataType ;
245270 use datafusion_common:: Result ;
246- use datafusion_expr:: ColumnarValue ;
247271
248272 #[ test]
249273 fn test_concat_basic ( ) -> Result < ( ) > {
0 commit comments