@@ -269,14 +269,14 @@ impl CaseExpr {
269269 )
270270 } ;
271271
272- let ( mut remainder_batch , mut base_value, mut current_value) = initial_values;
272+ let ( mut remainder , mut base_value, mut current_value) = initial_values;
273273
274274 for i in 0 ..self . when_then_expr . len ( ) {
275275 // Evaluate the 'when' predicate for the remainder batch
276276 // This results in a boolean array with the same length as the remaining number of rows
277277 let when_expression = & self . when_then_expr [ i] . 0 ;
278- let when_value = when_expression. evaluate ( & remainder_batch ) ?;
279- let when_value = when_value. to_array ( remainder_batch . num_rows ( ) ) ?;
278+ let when_value = when_expression. evaluate ( & remainder ) ?;
279+ let when_value = when_value. to_array ( remainder . num_rows ( ) ) ?;
280280 let when_value = compare_with_eq (
281281 & base_value,
282282 & when_value,
@@ -294,43 +294,45 @@ impl CaseExpr {
294294 continue ;
295295 }
296296
297- // Make sure 'NULL' is treated as false
298- let when_value = match when_value. null_count ( ) {
299- 0 => Cow :: Borrowed ( when_value) ,
300- _ => Cow :: Owned ( prep_null_mask_filter ( when_value) ) ,
301- } ;
302-
303- // Filter the remainder batch based on the 'when' value
304- // This results in a batch containing only the rows that need to be evaluated
305- // for the current branch
306- let then_batch = filter_record_batch ( & remainder_batch, & when_value) ?;
307-
308297 // Evaluate the then expression for the matching rows
309298 let then_expression = & self . when_then_expr [ i] . 1 ;
310- let then_value = then_expression. evaluate ( & then_batch) ?;
311299
312- // Expand the 'when' match array using the 'remainder scatter' array
313- // This results in a truthy boolean array than we can use to merge the
314- // 'then' values with the `current_value` array.
315- let then_merge = scatter ( & remainder_scatter, when_value. as_ref ( ) ) ?;
316- let then_merge = then_merge. as_boolean ( ) ;
300+ if when_match_count == remainder. num_rows ( ) {
301+ // If the 'when' predicate matched all remaining row, there's nothing left to do so
302+ // we can return early
303+ let then_value = then_expression. evaluate ( remainder. as_ref ( ) ) ?;
304+ current_value =
305+ merge_result ( & current_value, then_value, & remainder_scatter) ?;
306+ return Ok ( ColumnarValue :: Array ( current_value) ) ;
307+ } else {
308+ // Make sure 'NULL' is treated as false
309+ let when_value = match when_value. null_count ( ) {
310+ 0 => Cow :: Borrowed ( when_value) ,
311+ _ => Cow :: Owned ( prep_null_mask_filter ( when_value) ) ,
312+ } ;
313+
314+ // Filter the remainder batch based on the 'when' value
315+ // This results in a batch containing only the rows that need to be evaluated
316+ // for the current branch
317+ let then_batch = filter_record_batch ( & remainder, & when_value) ?;
318+ let then_value = then_expression. evaluate ( & then_batch) ?;
317319
318- // Merge the 'then' values with the `current_value` array
319- current_value = merge_result ( & current_value, then_value, then_merge) ?;
320+ // Expand the 'when' match array using the 'remainder scatter' array
321+ // This results in a truthy boolean array than we can use to merge the
322+ // 'then' values with the `current_value` array.
323+ let then_merge = scatter ( & remainder_scatter, when_value. as_ref ( ) ) ?;
324+ let then_merge = then_merge. as_boolean ( ) ;
320325
321- // If the 'when' predicate matched all remaining row, there's nothing left to do so
322- // we can return early
323- if remainder_batch. num_rows ( ) == when_match_count {
324- return Ok ( ColumnarValue :: Array ( current_value) ) ;
325- }
326+ // Merge the 'then' values with the `current_value` array
327+ current_value = merge_result ( & current_value, then_value, then_merge) ?;
326328
327- // Clear the positions in 'remainder scatter' for which we just evaluated a value
328- remainder_scatter = and_not ( & remainder_scatter, then_merge) ?;
329+ // Clear the positions in 'remainder scatter' for which we just evaluated a value
330+ remainder_scatter = and_not ( & remainder_scatter, then_merge) ?;
331+ } ;
329332
330333 // Finally, prepare the remainder batch for the next branch
331- let next_selection = not ( & when_value) ?;
332- remainder_batch =
333- Cow :: Owned ( filter_record_batch ( & remainder_batch, & next_selection) ?) ;
334+ let next_selection = not ( when_value) ?;
335+ remainder = Cow :: Owned ( filter_record_batch ( & remainder, & next_selection) ?) ;
334336 base_value = filter ( base_value. as_ref ( ) , & next_selection) ?;
335337 }
336338
@@ -339,7 +341,7 @@ impl CaseExpr {
339341 if let Some ( e) = self . else_expr ( ) {
340342 // keep `else_expr`'s data type and return type consistent
341343 let expr = try_cast ( Arc :: clone ( e) , & batch. schema ( ) , return_type. clone ( ) ) ?;
342- let else_value = expr. evaluate ( & remainder_batch ) ?;
344+ let else_value = expr. evaluate ( & remainder ) ?;
343345 current_value = merge_result ( & current_value, else_value, & remainder_scatter) ?;
344346 }
345347
@@ -360,14 +362,14 @@ impl CaseExpr {
360362 let mut current_value = new_null_array ( & return_type, batch. num_rows ( ) ) ;
361363
362364 let mut remainder_scatter = BooleanArray :: from ( vec ! [ true ; batch. num_rows( ) ] ) ;
363- let mut remainder_batch = Cow :: Borrowed ( batch) ;
365+ let mut remainder = Cow :: Borrowed ( batch) ;
364366
365367 for i in 0 ..self . when_then_expr . len ( ) {
366368 // Evaluate the 'when' predicate for the remainder batch
367369 // This results in a boolean array with the same length as the remaining number of rows
368370 let when_predicate = & self . when_then_expr [ i] . 0 ;
369- let when_value = when_predicate. evaluate ( & remainder_batch ) ?;
370- let when_value = when_value. to_array ( remainder_batch . num_rows ( ) ) ?;
371+ let when_value = when_predicate. evaluate ( & remainder ) ?;
372+ let when_value = when_value. to_array ( remainder . num_rows ( ) ) ?;
371373 let when_value = as_boolean_array ( & when_value) . map_err ( |_| {
372374 internal_datafusion_err ! ( "WHEN expression did not return a BooleanArray" )
373375 } ) ?;
@@ -378,51 +380,53 @@ impl CaseExpr {
378380 continue ;
379381 }
380382
381- // Make sure 'NULL' is treated as false
382- let when_value = match when_value. null_count ( ) {
383- 0 => Cow :: Borrowed ( when_value) ,
384- _ => Cow :: Owned ( prep_null_mask_filter ( when_value) ) ,
385- } ;
386-
387- // Filter the remainder batch based on the 'when' value
388- // This results in a batch containing only the rows that need to be evaluated
389- // for the current branch
390- let then_batch = filter_record_batch ( & remainder_batch, & when_value) ?;
391-
392383 // Evaluate the then expression for the matching rows
393384 let then_expression = & self . when_then_expr [ i] . 1 ;
394- let then_value = then_expression. evaluate ( & then_batch) ?;
395385
396- // Expand the 'when' match array using the 'remainder scatter' array
397- // This results in a truthy boolean array than we can use to merge the
398- // 'then' values with the `current_value` array.
399- let then_merge = scatter ( & remainder_scatter, when_value. as_ref ( ) ) ?;
400- let then_merge = then_merge. as_boolean ( ) ;
386+ if when_match_count == remainder. num_rows ( ) {
387+ // If the 'when' predicate matched all remaining row, there's nothing left to do so
388+ // we can return early
389+ let then_value = then_expression. evaluate ( remainder. as_ref ( ) ) ?;
390+ current_value =
391+ merge_result ( & current_value, then_value, & remainder_scatter) ?;
392+ return Ok ( ColumnarValue :: Array ( current_value) ) ;
393+ } else {
394+ // Make sure 'NULL' is treated as false
395+ let when_value = match when_value. null_count ( ) {
396+ 0 => Cow :: Borrowed ( when_value) ,
397+ _ => Cow :: Owned ( prep_null_mask_filter ( when_value) ) ,
398+ } ;
399+
400+ // Filter the remainder batch based on the 'when' value
401+ // This results in a batch containing only the rows that need to be evaluated
402+ // for the current branch
403+ let then_batch = filter_record_batch ( & remainder, & when_value) ?;
404+ let then_value = then_expression. evaluate ( & then_batch) ?;
401405
402- // Merge the 'then' values with the `current_value` array
403- current_value = merge_result ( & current_value, then_value, then_merge) ?;
406+ // Expand the 'when' match array using the 'remainder scatter' array
407+ // This results in a truthy boolean array than we can use to merge the
408+ // 'then' values with the `current_value` array.
409+ let then_merge = scatter ( & remainder_scatter, when_value. as_ref ( ) ) ?;
410+ let then_merge = then_merge. as_boolean ( ) ;
404411
405- // If the 'when' predicate matched all remaining row, there's nothing left to do so
406- // we can return early
407- if remainder_batch. num_rows ( ) == when_match_count {
408- return Ok ( ColumnarValue :: Array ( current_value) ) ;
409- }
412+ // Merge the 'then' values with the `current_value` array
413+ current_value = merge_result ( & current_value, then_value, then_merge) ?;
410414
411- // Clear the positions in 'remainder scatter' for which we just evaluated a value
412- remainder_scatter = and_not ( & remainder_scatter, then_merge) ?;
415+ // Clear the positions in 'remainder scatter' for which we just evaluated a value
416+ remainder_scatter = and_not ( & remainder_scatter, then_merge) ?;
417+ } ;
413418
414419 // Finally, prepare the remainder batch for the next branch
415- let next_selection = not ( & when_value) ?;
416- remainder_batch =
417- Cow :: Owned ( filter_record_batch ( & remainder_batch, & next_selection) ?) ;
420+ let next_selection = not ( when_value) ?;
421+ remainder = Cow :: Owned ( filter_record_batch ( & remainder, & next_selection) ?) ;
418422 }
419423
420424 // If we reached this point, some rows were left unmatched.
421425 // Check if those need to be evaluated using the 'else' expression.
422426 if let Some ( e) = self . else_expr ( ) {
423427 // keep `else_expr`'s data type and return type consistent
424428 let expr = try_cast ( Arc :: clone ( e) , & batch. schema ( ) , return_type. clone ( ) ) ?;
425- let else_value = expr. evaluate ( & remainder_batch ) ?;
429+ let else_value = expr. evaluate ( & remainder ) ?;
426430 current_value = merge_result ( & current_value, else_value, & remainder_scatter) ?;
427431 }
428432
0 commit comments