Skip to content

Commit

Permalink
fix(expr): do not const-eval impure expressions (#9616)
Browse files Browse the repository at this point in the history
Signed-off-by: Bugen Zhao <i@bugenzhao.com>
  • Loading branch information
BugenZhao authored May 6, 2023
1 parent 5161643 commit 445ca8e
Show file tree
Hide file tree
Showing 10 changed files with 107 additions and 62 deletions.
39 changes: 38 additions & 1 deletion e2e_test/ddl/table/generated_columns.slt.part
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ drop table t2;
statement error
create table t2 (v1 int as v2+1, v2 int, v3 int as v1-1);

# Create a table with proctime.
# Test table with proctime.
statement ok
create table t3 (v1 int, v2 Timestamptz as proctime());

Expand Down Expand Up @@ -76,3 +76,40 @@ t

statement ok
drop table t3;

# Test materialized view on source with proctime.
statement ok
create source t4 (
v int,
t timestamptz as proctime()
) with (
connector = 'datagen',
fields.v.kind = 'sequence',
fields.v.start = '1',
fields.v.end = '5',
datagen.rows.per.second='10000',
datagen.split.num = '1'
) row format json;

statement ok
CREATE MATERIALIZED VIEW mv AS SELECT * FROM t4;

sleep 2s

statement ok
flush;

query TT
select v, t >= date '2021-01-01' as later_than_2021 from mv;
----
1 t
2 t
3 t
4 t
5 t

statement ok
drop materialized view mv;

statement ok
drop source t4;
3 changes: 0 additions & 3 deletions proto/expr.proto
Original file line number Diff line number Diff line change
Expand Up @@ -168,9 +168,6 @@ message ExprNode {
JSONB_TYPEOF = 602;
JSONB_ARRAY_LENGTH = 603;

// Functions that return a constant value
PI = 610;

// Non-pure functions below (> 1000)
// ------------------------
// Internal functions
Expand Down
75 changes: 52 additions & 23 deletions src/frontend/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -276,12 +276,26 @@ impl ExprImpl {
Ok(backend_expr.eval_row(input).await?)
}

/// Evaluate a constant expression.
pub fn eval_row_const(&self) -> RwResult<Datum> {
assert!(self.is_const());
self.eval_row(&OwnedRow::empty())
.now_or_never()
.expect("constant expression should not be async")
/// Attempts constant folding: the expression is evaluated only when `ExprImpl::is_const`
/// reports it as constant.
///
/// The result is one of...
/// - `None`: the expression is not constant, so nothing is evaluated,
/// - `Some(Ok(_))`: the expression is constant and evaluation succeeded,
/// - `Some(Err(_))`: the expression is constant but evaluating it returned an error.
///
/// # Panics
///
/// Panics with "constant expression should not be async" if `now_or_never` yields no
/// result for the evaluation future.
pub fn try_fold_const(&self) -> Option<RwResult<Datum>> {
    // Guard clause: non-constant expressions are never evaluated here.
    if !self.is_const() {
        return None;
    }

    // A constant expression reads no input, so it is evaluated against an empty row.
    let folded = self
        .eval_row(&OwnedRow::empty())
        .now_or_never()
        .expect("constant expression should not be async");
    Some(folded)
}

/// Constant-folds this expression through `ExprImpl::try_fold_const`, for callers that
/// already know the expression is constant.
///
/// # Panics
///
/// Panics with "expression is not constant" when `try_fold_const` returns `None`.
pub fn fold_const(&self) -> RwResult<Datum> {
    let Some(folded) = self.try_fold_const() else {
        panic!("expression is not constant");
    };
    folded
}
}

Expand Down Expand Up @@ -539,26 +553,41 @@ impl ExprImpl {
}

/// Checks whether this is a constant expr that can be evaluated over a dummy chunk.
/// Equivalent to `!has_input_ref && !has_agg_call && !has_subquery &&
/// !has_correlated_input_ref` but checks them in one pass.
///
/// The expression tree should only consist of literals and **pure** function calls.
pub fn is_const(&self) -> bool {
struct Has {
has: bool,
}
impl ExprVisitor<()> for Has {
fn merge(_: (), _: ()) {}

fn visit_expr(&mut self, expr: &ExprImpl) {
match expr {
ExprImpl::Literal(_inner) => {}
ExprImpl::FunctionCall(inner) => self.visit_function_call(inner),
_ => self.has = true,
let only_literal_and_func = {
struct HasOthers {
has_others: bool,
}
impl ExprVisitor<()> for HasOthers {
fn merge(_: (), _: ()) {}

fn visit_expr(&mut self, expr: &ExprImpl) {
match expr {
ExprImpl::Literal(_inner) => {}
ExprImpl::FunctionCall(inner) => self.visit_function_call(inner),
ExprImpl::CorrelatedInputRef(_)
| ExprImpl::InputRef(_)
| ExprImpl::AggCall(_)
| ExprImpl::Subquery(_)
| ExprImpl::TableFunction(_)
| ExprImpl::WindowFunction(_)
| ExprImpl::UserDefinedFunction(_)
| ExprImpl::Parameter(_)
| ExprImpl::Now(_) => self.has_others = true,
}
}
}
}
let mut visitor = Has { has: false };
visitor.visit_expr(self);
!visitor.has

let mut visitor = HasOthers { has_others: false };
visitor.visit_expr(self);
!visitor.has_others
};

let is_pure = self.is_pure();

only_literal_and_func && is_pure
}

/// Returns the `InputRefs` of an Equality predicate if it matches
Expand Down
1 change: 0 additions & 1 deletion src/frontend/src/expr/pure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,6 @@ impl ExprVisitor<bool> for ImpureAnalyzer {
| expr_node::Type::JsonbAccessStr
| expr_node::Type::JsonbTypeof
| expr_node::Type::JsonbArrayLength
| expr_node::Type::Pi
| expr_node::Type::Sind
| expr_node::Type::Cosd
| expr_node::Type::Cotd
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,9 @@ impl ExprRewriter for ConstEvalRewriter {
if self.error.is_some() {
return expr;
}
if expr.is_const() {
let data_type = expr.return_type();
match expr.eval_row_const() {
Ok(datum) => Literal::new(datum, data_type).into(),
if let Some(result) = expr.try_fold_const() {
match result {
Ok(datum) => Literal::new(datum, expr.return_type()).into(),
Err(e) => {
self.error = Some(e);
expr
Expand Down
4 changes: 3 additions & 1 deletion src/frontend/src/optimizer/plan_node/logical_over_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,9 @@ impl LogicalOverAgg {
}
offset_expr
.cast_implicit(DataType::Int64)?
.eval_row_const()?
.try_fold_const()
.transpose()?
.flatten()
.map(|v| *v.as_int64() as usize)
.unwrap_or(1usize)
} else {
Expand Down
8 changes: 4 additions & 4 deletions src/frontend/src/optimizer/plan_node/logical_source.rs
Original file line number Diff line number Diff line change
Expand Up @@ -371,24 +371,24 @@ fn expr_to_kafka_timestamp_range(
ExprImpl::FunctionCall(function_call) if function_call.inputs().len() == 2 => {
match (&function_call.inputs()[0], &function_call.inputs()[1]) {
(ExprImpl::InputRef(input_ref), literal)
if literal.is_const()
if let Some(datum) = literal.try_fold_const().transpose()?
&& schema.fields[input_ref.index].name
== KAFKA_TIMESTAMP_COLUMN_NAME
&& literal.return_type() == DataType::Timestamptz =>
{
Ok(Some((
literal.eval_row_const()?.unwrap().into_int64() / 1000,
datum.unwrap().into_int64() / 1000,
false,
)))
}
(literal, ExprImpl::InputRef(input_ref))
if literal.is_const()
if let Some(datum) = literal.try_fold_const().transpose()?
&& schema.fields[input_ref.index].name
== KAFKA_TIMESTAMP_COLUMN_NAME
&& literal.return_type() == DataType::Timestamptz =>
{
Ok(Some((
literal.eval_row_const()?.unwrap().into_int64() / 1000,
datum.unwrap().into_int64() / 1000,
true,
)))
}
Expand Down
12 changes: 1 addition & 11 deletions src/frontend/src/optimizer/rule/always_false_filter_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,7 @@ impl Rule for AlwaysFalseFilterRule {
.predicate()
.conjunctions
.iter()
.filter_map(|e| {
if e.is_const() {
if let Ok(v) = e.eval_row_const() {
Some(v)
} else {
None
}
} else {
None
}
})
.filter_map(|e| e.try_fold_const().transpose().ok().flatten())
.any(|s| s.unwrap_or(ScalarImpl::Bool(true)) == ScalarImpl::Bool(false));
if always_false {
Some(LogicalValues::create(
Expand Down
12 changes: 2 additions & 10 deletions src/frontend/src/optimizer/rule/over_agg_to_topn_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,11 +120,7 @@ fn handle_rank_preds(rank_preds: &[ExprImpl], window_func_pos: usize) -> Option<
for cond in rank_preds {
if let Some((input_ref, cmp, v)) = cond.as_comparison_const() {
assert_eq!(input_ref.index, window_func_pos);
let v = v
.cast_implicit(DataType::Int64)
.ok()?
.eval_row_const()
.ok()??;
let v = v.cast_implicit(DataType::Int64).ok()?.fold_const().ok()??;
let v = *v.as_int64();
match cmp {
ExprType::LessThanOrEqual => ub = ub.map_or(Some(v), |ub| Some(ub.min(v))),
Expand All @@ -135,11 +131,7 @@ fn handle_rank_preds(rank_preds: &[ExprImpl], window_func_pos: usize) -> Option<
}
} else if let Some((input_ref, v)) = cond.as_eq_const() {
assert_eq!(input_ref.index, window_func_pos);
let v = v
.cast_implicit(DataType::Int64)
.ok()?
.eval_row_const()
.ok()??;
let v = v.cast_implicit(DataType::Int64).ok()?.fold_const().ok()??;
let v = *v.as_int64();
if let Some(eq) = eq && eq != v {
tracing::error!(
Expand Down
8 changes: 4 additions & 4 deletions src/frontend/src/utils/condition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,7 @@ impl Condition {
}
};

let Some(new_cond) = new_expr.eval_row_const()? else {
let Some(new_cond) = new_expr.fold_const()? else {
// column = NULL, the result is always NULL.
return Ok(false_cond());
};
Expand All @@ -479,7 +479,7 @@ impl Condition {
let const_expr = const_expr
.cast_implicit(input_ref.data_type.clone())
.unwrap();
let value = const_expr.eval_row_const()?;
let value = const_expr.fold_const()?;
let Some(value) = value else {
continue;
};
Expand Down Expand Up @@ -537,7 +537,7 @@ impl Condition {
}
}
};
let Some(value) = new_expr.eval_row_const()? else {
let Some(value) = new_expr.fold_const()? else {
// column compare with NULL, the result is always NULL.
return Ok(false_cond());
};
Expand Down Expand Up @@ -849,7 +849,7 @@ mod cast_compare {
}
_ => unreachable!(),
};
match const_expr.eval_row_const().map_err(|_| ())? {
match const_expr.fold_const().map_err(|_| ())? {
Some(scalar) => {
let value = scalar.as_integral();
if value > upper_bound {
Expand Down

0 comments on commit 445ca8e

Please sign in to comment.