Skip to content

Commit

Permalink
feat(sqlsmith): gen implicit cast (risingwavelabs#7629)
Browse files Browse the repository at this point in the history
- [x] Gen for fixed func
- [x] Gen for concat (note that this is implicit cast but in explicit context...)

Approved-By: lmatz

Co-Authored-By: Noel Kwan <noelkwan1998@gmail.com>
Co-Authored-By: Noel Kwan <47273164+kwannoel@users.noreply.github.com>
  • Loading branch information
kwannoel and kwannoel authored Feb 3, 2023
1 parent 1f6b063 commit b720f19
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 45 deletions.
74 changes: 42 additions & 32 deletions src/tests/sqlsmith/src/sql_gen/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@ use risingwave_sqlparser::ast::{
TrimWhereField, UnaryOperator, Value,
};

use crate::sql_gen::types::{data_type_to_ast_data_type, AGG_FUNC_TABLE, CAST_TABLE, FUNC_TABLE};
use crate::sql_gen::types::{
data_type_to_ast_data_type, AGG_FUNC_TABLE, EXPLICIT_CAST_TABLE, FUNC_TABLE,
IMPLICIT_CAST_TABLE, INVARIANT_FUNC_SET,
};
use crate::sql_gen::{SqlGenerator, SqlGeneratorContext};

static STRUCT_FIELD_NAMES: [&str; 26] = [
Expand Down Expand Up @@ -87,7 +90,7 @@ impl<'a, R: Rng> SqlGenerator<'a, R> {
match self.rng.gen_range(0..=range) {
0..=70 => self.gen_func(typ, context),
71..=80 => self.gen_exists(typ, context),
81..=90 => self.gen_cast(typ, context),
81..=90 => self.gen_explicit_cast(typ, context),
91..=99 => self.gen_agg(typ),
_ => unreachable!(),
}
Expand Down Expand Up @@ -151,7 +154,7 @@ impl<'a, R: Rng> SqlGenerator<'a, R> {
}

fn gen_struct_data_type(&mut self, depth: usize) -> DataType {
let num_fields = self.rng.gen_range(1..10);
let num_fields = self.rng.gen_range(1..4);
let fields = (0..num_fields)
.map(|_| self.gen_data_type_inner(depth))
.collect();
Expand Down Expand Up @@ -199,15 +202,19 @@ impl<'a, R: Rng> SqlGenerator<'a, R> {
}
}

fn gen_cast(&mut self, ret: &DataType, context: SqlGeneratorContext) -> Expr {
self.gen_cast_inner(ret, context)
fn gen_explicit_cast(&mut self, ret: &DataType, context: SqlGeneratorContext) -> Expr {
self.gen_explicit_cast_inner(ret, context)
.unwrap_or_else(|| self.gen_simple_scalar(ret))
}

/// Generate casts from a cast map.
/// TODO: Assign casts have to be tested via `INSERT`.
fn gen_cast_inner(&mut self, ret: &DataType, context: SqlGeneratorContext) -> Option<Expr> {
let casts = CAST_TABLE.get(ret)?;
fn gen_explicit_cast_inner(
&mut self,
ret: &DataType,
context: SqlGeneratorContext,
) -> Option<Expr> {
let casts = EXPLICIT_CAST_TABLE.get(ret)?;
let cast_sig = casts.choose(&mut self.rng).unwrap();

use CastContext as T;
Expand All @@ -220,32 +227,18 @@ impl<'a, R: Rng> SqlGenerator<'a, R> {
Some(Expr::Cast { expr, data_type })
}

// TODO: Re-enable implicit casts
// Currently these implicit cast expressions may surface in:
// select items, functions and so on.
// Type-inference could result in different type from what SQLGenerator expects.
// For example:
// Suppose we had implicit cast expr from smallint->int.
// We then generated 1::smallint with implicit type int.
// If it was part of this expression:
// SELECT 1::smallint as col0;
// Then, when generating other expressions, SqlGenerator sees `col0` with type `int`,
// but its type will be inferred as `smallint` actually in the frontend.
//
// Functions also encounter problems, and could infer to the wrong type.
// May refer to type inference rules:
// https://github.com/risingwavelabs/risingwave/blob/650810a5a9b86028036cb3b51eec5b18d8f814d5/src/frontend/src/expr/type_inference/func.rs#L445-L464
// Therefore it is disabled for now.
// T::Implicit if context.can_implicit_cast() => {
// self.gen_expr(cast_sig.from_type, context).into()
// }

// TODO: Generate this when e2e inserts are generated.
// T::Assign
_ => None,
_ => unreachable!(),
}
}

/// NOTE: This can result in ambiguous expressions.
/// Should only be used in unambiguous context.
fn gen_implicit_cast(&mut self, ret: &DataType, context: SqlGeneratorContext) -> Expr {
self.gen_expr(ret, context)
}

fn gen_func(&mut self, ret: &DataType, context: SqlGeneratorContext) -> Expr {
match self.rng.gen_bool(0.1) {
true => self.gen_variadic_func(ret, context),
Expand Down Expand Up @@ -275,7 +268,7 @@ impl<'a, R: Rng> SqlGenerator<'a, R> {
}

fn gen_case(&mut self, ret: &DataType, context: SqlGeneratorContext) -> Expr {
let n = self.rng.gen_range(1..10);
let n = self.rng.gen_range(1..4);
Expr::Case {
operand: None,
conditions: self.gen_n_exprs_with_type(n, &DataType::Boolean, context),
Expand Down Expand Up @@ -304,8 +297,16 @@ impl<'a, R: Rng> SqlGenerator<'a, R> {
}

fn gen_concat_args(&mut self, context: SqlGeneratorContext) -> Vec<Expr> {
let n = self.rng.gen_range(1..10);
self.gen_n_exprs_with_type(n, &DataType::Varchar, context)
let n = self.rng.gen_range(1..4);
(0..n)
.map(|_| {
if self.rng.gen_bool(0.1) {
self.gen_explicit_cast(&DataType::Varchar, context)
} else {
self.gen_expr(&DataType::Varchar, context)
}
})
.collect()
}

/// Generates `n` expressions of type `ret`.
Expand All @@ -324,10 +325,19 @@ impl<'a, R: Rng> SqlGenerator<'a, R> {
Some(funcs) => funcs,
};
let func = funcs.choose(&mut self.rng).unwrap();
let can_implicit_cast = INVARIANT_FUNC_SET.contains(&func.func);
let exprs: Vec<Expr> = func
.inputs_type
.iter()
.map(|t| self.gen_expr(t, context))
.map(|t| {
if let Some(from_tys) = IMPLICIT_CAST_TABLE.get(t)
&& can_implicit_cast && self.flip_coin() {
let from_ty = &from_tys.choose(&mut self.rng).unwrap().from_type;
self.gen_implicit_cast(from_ty, context)
} else {
self.gen_expr(t, context)
}
})
.collect();
let expr = if exprs.len() == 1 {
make_unary_op(func.func, &exprs[0])
Expand Down
2 changes: 1 addition & 1 deletion src/tests/sqlsmith/src/sql_gen/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ impl<'a, R: Rng> SqlGenerator<'a, R> {
value: self.gen_temporal_scalar(typ),
})),
T::List { datatype: ref ty } => {
let n = self.rng.gen_range(1..=100); // Avoid ambiguous type
let n = self.rng.gen_range(1..=4); // Avoid ambiguous type
Expr::Array(self.gen_simple_scalar_list(ty, n))
}
// ENABLE: https://github.com/risingwavelabs/risingwave/issues/6934
Expand Down
50 changes: 38 additions & 12 deletions src/tests/sqlsmith/src/sql_gen/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

//! This module contains datatypes and functions which can be generated by sqlsmith.
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};
use std::sync::LazyLock;

use itertools::Itertools;
Expand Down Expand Up @@ -169,6 +169,18 @@ pub(crate) static FUNC_TABLE: LazyLock<HashMap<DataType, Vec<FuncSig>>> = LazyLo
funcs
});

/// Set of invariant functions
// ENABLE: https://github.com/risingwavelabs/risingwave/issues/5826
pub(crate) static INVARIANT_FUNC_SET: LazyLock<HashSet<ExprType>> = LazyLock::new(|| {
func_sigs()
.map(|sig| sig.func)
.counts()
.into_iter()
.filter(|(_key, count)| *count == 1)
.map(|(key, _)| key)
.collect()
});

/// Table which maps aggregate functions' return types to possible function signatures.
// ENABLE: https://github.com/risingwavelabs/risingwave/issues/5826
pub(crate) static AGG_FUNC_TABLE: LazyLock<HashMap<DataType, Vec<AggFuncSig>>> =
Expand All @@ -191,14 +203,28 @@ pub(crate) static AGG_FUNC_TABLE: LazyLock<HashMap<DataType, Vec<AggFuncSig>>> =
/// NOTE: We avoid cast from varchar to other datatypes apart from itself.
/// This is because arbitrary strings may not be able to cast,
/// creating large number of invalid queries.
pub(crate) static CAST_TABLE: LazyLock<HashMap<DataType, Vec<CastSig>>> = LazyLock::new(|| {
let mut casts = HashMap::<DataType, Vec<CastSig>>::new();
cast_sigs()
.filter_map(|cast| cast.try_into().ok())
.filter(|cast: &CastSig| {
cast.context == CastContext::Explicit || cast.context == CastContext::Implicit
})
.filter(|cast| cast.from_type != DataType::Varchar || cast.to_type == DataType::Varchar)
.for_each(|cast| casts.entry(cast.to_type.clone()).or_default().push(cast));
casts
});
pub(crate) static EXPLICIT_CAST_TABLE: LazyLock<HashMap<DataType, Vec<CastSig>>> =
LazyLock::new(|| {
let mut casts = HashMap::<DataType, Vec<CastSig>>::new();
cast_sigs()
.filter_map(|cast| cast.try_into().ok())
.filter(|cast: &CastSig| cast.context == CastContext::Explicit)
.filter(|cast| cast.from_type != DataType::Varchar || cast.to_type == DataType::Varchar)
.for_each(|cast| casts.entry(cast.to_type.clone()).or_default().push(cast));
casts
});

/// Build a cast map from return types to viable cast-signatures.
/// NOTE: We avoid cast from varchar to other datatypes apart from itself.
/// This is because arbitrary strings may not be able to cast,
/// creating large number of invalid queries.
pub(crate) static IMPLICIT_CAST_TABLE: LazyLock<HashMap<DataType, Vec<CastSig>>> =
LazyLock::new(|| {
let mut casts = HashMap::<DataType, Vec<CastSig>>::new();
cast_sigs()
.filter_map(|cast| cast.try_into().ok())
.filter(|cast: &CastSig| cast.context == CastContext::Implicit)
.filter(|cast| cast.from_type != DataType::Varchar || cast.to_type == DataType::Varchar)
.for_each(|cast| casts.entry(cast.to_type.clone()).or_default().push(cast));
casts
});

0 comments on commit b720f19

Please sign in to comment.