Skip to content

Commit a1add0a

Browse files
committed
Short-circuit CASE evaluation as soon as no unmatched rows remain
1 parent 057583d commit a1add0a

File tree

1 file changed

+44
-30
lines changed
  • datafusion/physical-expr/src/expressions

1 file changed

+44
-30
lines changed

datafusion/physical-expr/src/expressions/case.rs

Lines changed: 44 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -205,10 +205,15 @@ impl CaseExpr {
205205
let mut current_value = new_null_array(&return_type, batch.num_rows());
206206
// We only consider non-null values while comparing with whens
207207
let mut remainder = not(&base_nulls)?;
208+
let mut remainder_count = remainder.true_count();
208209
for i in 0..self.when_then_expr.len() {
209-
let when_value = self.when_then_expr[i]
210-
.0
211-
.evaluate_selection(batch, &remainder)?;
210+
// If there are no rows left to process, break out of the loop early
211+
if remainder_count == 0 {
212+
break;
213+
}
214+
215+
let when_predicate = &self.when_then_expr[i].0;
216+
let when_value = when_predicate.evaluate_selection(batch, &remainder)?;
212217
let when_value = when_value.into_array(batch.num_rows())?;
213218
// build boolean array representing which rows match the "when" value
214219
let when_match = compare_with_eq(
@@ -224,41 +229,44 @@ impl CaseExpr {
224229
_ => Cow::Owned(prep_null_mask_filter(&when_match)),
225230
};
226231
// Make sure we only consider rows that have not been matched yet
227-
let when_match = and(&when_match, &remainder)?;
232+
let when_value = and(&when_match, &remainder)?;
228233

229-
// When no rows available for when clause, skip then clause
230-
if when_match.true_count() == 0 {
234+
// If the predicate did not match any rows, continue to the next branch immediately
235+
let when_match_count = when_value.true_count();
236+
if when_match_count == 0 {
231237
continue;
232238
}
233239

234-
let then_value = self.when_then_expr[i]
235-
.1
236-
.evaluate_selection(batch, &when_match)?;
240+
let then_expression = &self.when_then_expr[i].1;
241+
let then_value = then_expression.evaluate_selection(batch, &when_value)?;
237242

238243
current_value = match then_value {
239244
ColumnarValue::Scalar(ScalarValue::Null) => {
240-
nullif(current_value.as_ref(), &when_match)?
245+
nullif(current_value.as_ref(), &when_value)?
241246
}
242247
ColumnarValue::Scalar(then_value) => {
243-
zip(&when_match, &then_value.to_scalar()?, &current_value)?
248+
zip(&when_value, &then_value.to_scalar()?, &current_value)?
244249
}
245250
ColumnarValue::Array(then_value) => {
246-
zip(&when_match, &then_value, &current_value)?
251+
zip(&when_value, &then_value, &current_value)?
247252
}
248253
};
249254

250-
remainder = and_not(&remainder, &when_match)?;
255+
remainder = and_not(&remainder, &when_value)?;
256+
remainder_count -= when_match_count;
251257
}
252258

253259
if let Some(e) = self.else_expr() {
254-
// keep `else_expr`'s data type and return type consistent
255-
let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?;
256-
// null and unmatched tuples should be assigned else value
257-
remainder = or(&base_nulls, &remainder)?;
258-
let else_ = expr
259-
.evaluate_selection(batch, &remainder)?
260-
.into_array(batch.num_rows())?;
261-
current_value = zip(&remainder, &else_, &current_value)?;
260+
if remainder_count > 0 {
261+
// keep `else_expr`'s data type and return type consistent
262+
let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?;
263+
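// null and unmatched tuples should be assigned else value. NOTE(review): the enclosing `remainder_count > 0` guard counts only unmatched rows with a non-null base, so rows with a NULL base appear to miss the else value when every non-null row has matched; consider guarding on `or(&base_nulls, &remainder)?.true_count()` instead (confirm)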
264+
remainder = or(&base_nulls, &remainder)?;
265+
let else_ = expr
266+
.evaluate_selection(batch, &remainder)?
267+
.into_array(batch.num_rows())?;
268+
current_value = zip(&remainder, &else_, &current_value)?;
269+
}
262270
}
263271

264272
Ok(ColumnarValue::Array(current_value))
@@ -277,10 +285,15 @@ impl CaseExpr {
277285
// start with nulls as default output
278286
let mut current_value = new_null_array(&return_type, batch.num_rows());
279287
let mut remainder = BooleanArray::from(vec![true; batch.num_rows()]);
288+
let mut remainder_count = batch.num_rows();
280289
for i in 0..self.when_then_expr.len() {
281-
let when_value = self.when_then_expr[i]
282-
.0
283-
.evaluate_selection(batch, &remainder)?;
290+
// If there are no rows left to process, break out of the loop early
291+
if remainder_count == 0 {
292+
break;
293+
}
294+
295+
let when_predicate = &self.when_then_expr[i].0;
296+
let when_value = when_predicate.evaluate_selection(batch, &remainder)?;
284297
let when_value = when_value.into_array(batch.num_rows())?;
285298
let when_value = as_boolean_array(&when_value).map_err(|_| {
286299
internal_datafusion_err!("WHEN expression did not return a BooleanArray")
@@ -293,14 +306,14 @@ impl CaseExpr {
293306
// Make sure we only consider rows that have not been matched yet
294307
let when_value = and(&when_value, &remainder)?;
295308

296-
// When no rows available for when clause, skip then clause
297-
if when_value.true_count() == 0 {
309+
// If the predicate did not match any rows, continue to the next branch immediately
310+
let when_match_count = when_value.true_count();
311+
if when_match_count == 0 {
298312
continue;
299313
}
300314

301-
let then_value = self.when_then_expr[i]
302-
.1
303-
.evaluate_selection(batch, &when_value)?;
315+
let then_expression = &self.when_then_expr[i].1;
316+
let then_value = then_expression.evaluate_selection(batch, &when_value)?;
304317

305318
current_value = match then_value {
306319
ColumnarValue::Scalar(ScalarValue::Null) => {
@@ -317,10 +330,11 @@ impl CaseExpr {
317330
// Succeed tuples should be filtered out for short-circuit evaluation,
318331
// null values for the current when expr should be kept
319332
remainder = and_not(&remainder, &when_value)?;
333+
remainder_count -= when_match_count;
320334
}
321335

322336
if let Some(e) = self.else_expr() {
323-
if remainder.true_count() > 0 {
337+
if remainder_count > 0 {
324338
// keep `else_expr`'s data type and return type consistent
325339
let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?;
326340
let else_ = expr

0 commit comments

Comments
 (0)