Skip to content

Commit 24e3d38

Browse files
committed
More work reduction in case of early exit
1 parent 39ab973 commit 24e3d38

File tree

1 file changed

+70
-66
lines changed
  • datafusion/physical-expr/src/expressions

1 file changed

+70
-66
lines changed

datafusion/physical-expr/src/expressions/case.rs

Lines changed: 70 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -269,14 +269,14 @@ impl CaseExpr {
269269
)
270270
};
271271

272-
let (mut remainder_batch, mut base_value, mut current_value) = initial_values;
272+
let (mut remainder, mut base_value, mut current_value) = initial_values;
273273

274274
for i in 0..self.when_then_expr.len() {
275275
// Evaluate the 'when' predicate for the remainder batch
276276
// This results in a boolean array with the same length as the remaining number of rows
277277
let when_expression = &self.when_then_expr[i].0;
278-
let when_value = when_expression.evaluate(&remainder_batch)?;
279-
let when_value = when_value.to_array(remainder_batch.num_rows())?;
278+
let when_value = when_expression.evaluate(&remainder)?;
279+
let when_value = when_value.to_array(remainder.num_rows())?;
280280
let when_value = compare_with_eq(
281281
&base_value,
282282
&when_value,
@@ -294,43 +294,45 @@ impl CaseExpr {
294294
continue;
295295
}
296296

297-
// Make sure 'NULL' is treated as false
298-
let when_value = match when_value.null_count() {
299-
0 => Cow::Borrowed(when_value),
300-
_ => Cow::Owned(prep_null_mask_filter(when_value)),
301-
};
302-
303-
// Filter the remainder batch based on the 'when' value
304-
// This results in a batch containing only the rows that need to be evaluated
305-
// for the current branch
306-
let then_batch = filter_record_batch(&remainder_batch, &when_value)?;
307-
308297
// Evaluate the then expression for the matching rows
309298
let then_expression = &self.when_then_expr[i].1;
310-
let then_value = then_expression.evaluate(&then_batch)?;
311299

312-
// Expand the 'when' match array using the 'remainder scatter' array
313-
// This results in a truthy boolean array than we can use to merge the
314-
// 'then' values with the `current_value` array.
315-
let then_merge = scatter(&remainder_scatter, when_value.as_ref())?;
316-
let then_merge = then_merge.as_boolean();
300+
if when_match_count == remainder.num_rows() {
301+
// If the 'when' predicate matched all remaining rows, there's nothing left to do, so
302+
// we can return early
303+
let then_value = then_expression.evaluate(remainder.as_ref())?;
304+
current_value =
305+
merge_result(&current_value, then_value, &remainder_scatter)?;
306+
return Ok(ColumnarValue::Array(current_value));
307+
} else {
308+
// Make sure 'NULL' is treated as false
309+
let when_value = match when_value.null_count() {
310+
0 => Cow::Borrowed(when_value),
311+
_ => Cow::Owned(prep_null_mask_filter(when_value)),
312+
};
313+
314+
// Filter the remainder batch based on the 'when' value
315+
// This results in a batch containing only the rows that need to be evaluated
316+
// for the current branch
317+
let then_batch = filter_record_batch(&remainder, &when_value)?;
318+
let then_value = then_expression.evaluate(&then_batch)?;
317319

318-
// Merge the 'then' values with the `current_value` array
319-
current_value = merge_result(&current_value, then_value, then_merge)?;
320+
// Expand the 'when' match array using the 'remainder scatter' array
321+
// This results in a truthy boolean array that we can use to merge the
322+
// 'then' values with the `current_value` array.
323+
let then_merge = scatter(&remainder_scatter, when_value.as_ref())?;
324+
let then_merge = then_merge.as_boolean();
320325

321-
// If the 'when' predicate matched all remaining row, there's nothing left to do so
322-
// we can return early
323-
if remainder_batch.num_rows() == when_match_count {
324-
return Ok(ColumnarValue::Array(current_value));
325-
}
326+
// Merge the 'then' values with the `current_value` array
327+
current_value = merge_result(&current_value, then_value, then_merge)?;
326328

327-
// Clear the positions in 'remainder scatter' for which we just evaluated a value
328-
remainder_scatter = and_not(&remainder_scatter, then_merge)?;
329+
// Clear the positions in 'remainder scatter' for which we just evaluated a value
330+
remainder_scatter = and_not(&remainder_scatter, then_merge)?;
331+
};
329332

330333
// Finally, prepare the remainder batch for the next branch
331-
let next_selection = not(&when_value)?;
332-
remainder_batch =
333-
Cow::Owned(filter_record_batch(&remainder_batch, &next_selection)?);
334+
let next_selection = not(when_value)?;
335+
remainder = Cow::Owned(filter_record_batch(&remainder, &next_selection)?);
334336
base_value = filter(base_value.as_ref(), &next_selection)?;
335337
}
336338

@@ -339,7 +341,7 @@ impl CaseExpr {
339341
if let Some(e) = self.else_expr() {
340342
// keep `else_expr`'s data type and return type consistent
341343
let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?;
342-
let else_value = expr.evaluate(&remainder_batch)?;
344+
let else_value = expr.evaluate(&remainder)?;
343345
current_value = merge_result(&current_value, else_value, &remainder_scatter)?;
344346
}
345347

@@ -360,14 +362,14 @@ impl CaseExpr {
360362
let mut current_value = new_null_array(&return_type, batch.num_rows());
361363

362364
let mut remainder_scatter = BooleanArray::from(vec![true; batch.num_rows()]);
363-
let mut remainder_batch = Cow::Borrowed(batch);
365+
let mut remainder = Cow::Borrowed(batch);
364366

365367
for i in 0..self.when_then_expr.len() {
366368
// Evaluate the 'when' predicate for the remainder batch
367369
// This results in a boolean array with the same length as the remaining number of rows
368370
let when_predicate = &self.when_then_expr[i].0;
369-
let when_value = when_predicate.evaluate(&remainder_batch)?;
370-
let when_value = when_value.to_array(remainder_batch.num_rows())?;
371+
let when_value = when_predicate.evaluate(&remainder)?;
372+
let when_value = when_value.to_array(remainder.num_rows())?;
371373
let when_value = as_boolean_array(&when_value).map_err(|_| {
372374
internal_datafusion_err!("WHEN expression did not return a BooleanArray")
373375
})?;
@@ -378,51 +380,53 @@ impl CaseExpr {
378380
continue;
379381
}
380382

381-
// Make sure 'NULL' is treated as false
382-
let when_value = match when_value.null_count() {
383-
0 => Cow::Borrowed(when_value),
384-
_ => Cow::Owned(prep_null_mask_filter(when_value)),
385-
};
386-
387-
// Filter the remainder batch based on the 'when' value
388-
// This results in a batch containing only the rows that need to be evaluated
389-
// for the current branch
390-
let then_batch = filter_record_batch(&remainder_batch, &when_value)?;
391-
392383
// Evaluate the then expression for the matching rows
393384
let then_expression = &self.when_then_expr[i].1;
394-
let then_value = then_expression.evaluate(&then_batch)?;
395385

396-
// Expand the 'when' match array using the 'remainder scatter' array
397-
// This results in a truthy boolean array than we can use to merge the
398-
// 'then' values with the `current_value` array.
399-
let then_merge = scatter(&remainder_scatter, when_value.as_ref())?;
400-
let then_merge = then_merge.as_boolean();
386+
if when_match_count == remainder.num_rows() {
387+
// If the 'when' predicate matched all remaining rows, there's nothing left to do, so
388+
// we can return early
389+
let then_value = then_expression.evaluate(remainder.as_ref())?;
390+
current_value =
391+
merge_result(&current_value, then_value, &remainder_scatter)?;
392+
return Ok(ColumnarValue::Array(current_value));
393+
} else {
394+
// Make sure 'NULL' is treated as false
395+
let when_value = match when_value.null_count() {
396+
0 => Cow::Borrowed(when_value),
397+
_ => Cow::Owned(prep_null_mask_filter(when_value)),
398+
};
399+
400+
// Filter the remainder batch based on the 'when' value
401+
// This results in a batch containing only the rows that need to be evaluated
402+
// for the current branch
403+
let then_batch = filter_record_batch(&remainder, &when_value)?;
404+
let then_value = then_expression.evaluate(&then_batch)?;
401405

402-
// Merge the 'then' values with the `current_value` array
403-
current_value = merge_result(&current_value, then_value, then_merge)?;
406+
// Expand the 'when' match array using the 'remainder scatter' array
407+
// This results in a truthy boolean array that we can use to merge the
408+
// 'then' values with the `current_value` array.
409+
let then_merge = scatter(&remainder_scatter, when_value.as_ref())?;
410+
let then_merge = then_merge.as_boolean();
404411

405-
// If the 'when' predicate matched all remaining row, there's nothing left to do so
406-
// we can return early
407-
if remainder_batch.num_rows() == when_match_count {
408-
return Ok(ColumnarValue::Array(current_value));
409-
}
412+
// Merge the 'then' values with the `current_value` array
413+
current_value = merge_result(&current_value, then_value, then_merge)?;
410414

411-
// Clear the positions in 'remainder scatter' for which we just evaluated a value
412-
remainder_scatter = and_not(&remainder_scatter, then_merge)?;
415+
// Clear the positions in 'remainder scatter' for which we just evaluated a value
416+
remainder_scatter = and_not(&remainder_scatter, then_merge)?;
417+
};
413418

414419
// Finally, prepare the remainder batch for the next branch
415-
let next_selection = not(&when_value)?;
416-
remainder_batch =
417-
Cow::Owned(filter_record_batch(&remainder_batch, &next_selection)?);
420+
let next_selection = not(when_value)?;
421+
remainder = Cow::Owned(filter_record_batch(&remainder, &next_selection)?);
418422
}
419423

420424
// If we reached this point, some rows were left unmatched.
421425
// Check if those need to be evaluated using the 'else' expression.
422426
if let Some(e) = self.else_expr() {
423427
// keep `else_expr`'s data type and return type consistent
424428
let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?;
425-
let else_value = expr.evaluate(&remainder_batch)?;
429+
let else_value = expr.evaluate(&remainder)?;
426430
current_value = merge_result(&current_value, else_value, &remainder_scatter)?;
427431
}
428432

0 commit comments

Comments
 (0)