Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(sqlsmith): gen implicit cast #7629

Merged
merged 5 commits into from
Feb 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 42 additions & 32 deletions src/tests/sqlsmith/src/sql_gen/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@ use risingwave_sqlparser::ast::{
TrimWhereField, UnaryOperator, Value,
};

use crate::sql_gen::types::{data_type_to_ast_data_type, AGG_FUNC_TABLE, CAST_TABLE, FUNC_TABLE};
use crate::sql_gen::types::{
data_type_to_ast_data_type, AGG_FUNC_TABLE, EXPLICIT_CAST_TABLE, FUNC_TABLE,
IMPLICIT_CAST_TABLE, INVARIANT_FUNC_SET,
};
use crate::sql_gen::{SqlGenerator, SqlGeneratorContext};

static STRUCT_FIELD_NAMES: [&str; 26] = [
Expand Down Expand Up @@ -87,7 +90,7 @@ impl<'a, R: Rng> SqlGenerator<'a, R> {
match self.rng.gen_range(0..=range) {
0..=70 => self.gen_func(typ, context),
71..=80 => self.gen_exists(typ, context),
81..=90 => self.gen_cast(typ, context),
81..=90 => self.gen_explicit_cast(typ, context),
91..=99 => self.gen_agg(typ),
_ => unreachable!(),
}
Expand Down Expand Up @@ -151,7 +154,7 @@ impl<'a, R: Rng> SqlGenerator<'a, R> {
}

fn gen_struct_data_type(&mut self, depth: usize) -> DataType {
let num_fields = self.rng.gen_range(1..10);
let num_fields = self.rng.gen_range(1..4);
let fields = (0..num_fields)
.map(|_| self.gen_data_type_inner(depth))
.collect();
Expand Down Expand Up @@ -199,15 +202,19 @@ impl<'a, R: Rng> SqlGenerator<'a, R> {
}
}

fn gen_cast(&mut self, ret: &DataType, context: SqlGeneratorContext) -> Expr {
self.gen_cast_inner(ret, context)
fn gen_explicit_cast(&mut self, ret: &DataType, context: SqlGeneratorContext) -> Expr {
self.gen_explicit_cast_inner(ret, context)
.unwrap_or_else(|| self.gen_simple_scalar(ret))
}

/// Generate casts from a cast map.
/// TODO: Assign casts have to be tested via `INSERT`.
fn gen_cast_inner(&mut self, ret: &DataType, context: SqlGeneratorContext) -> Option<Expr> {
let casts = CAST_TABLE.get(ret)?;
fn gen_explicit_cast_inner(
&mut self,
ret: &DataType,
context: SqlGeneratorContext,
) -> Option<Expr> {
let casts = EXPLICIT_CAST_TABLE.get(ret)?;
let cast_sig = casts.choose(&mut self.rng).unwrap();

use CastContext as T;
Expand All @@ -220,32 +227,18 @@ impl<'a, R: Rng> SqlGenerator<'a, R> {
Some(Expr::Cast { expr, data_type })
}

// TODO: Re-enable implicit casts
// Currently these implicit cast expressions may surface in:
// select items, functions and so on.
// Type-inference could result in different type from what SQLGenerator expects.
// For example:
// Suppose we had implicit cast expr from smallint->int.
// We then generated 1::smallint with implicit type int.
// If it was part of this expression:
// SELECT 1::smallint as col0;
// Then, when generating other expressions, SqlGenerator sees `col0` with type `int`,
// but its type will be inferred as `smallint` actually in the frontend.
//
// Functions also encounter problems, and could infer to the wrong type.
// May refer to type inference rules:
// https://github.com/risingwavelabs/risingwave/blob/650810a5a9b86028036cb3b51eec5b18d8f814d5/src/frontend/src/expr/type_inference/func.rs#L445-L464
// Therefore it is disabled for now.
// T::Implicit if context.can_implicit_cast() => {
// self.gen_expr(cast_sig.from_type, context).into()
// }

// TODO: Generate this when e2e inserts are generated.
// T::Assign
_ => None,
_ => unreachable!(),
}
}

/// NOTE: This can result in ambiguous expressions.
/// Should only be used in unambiguous context.
fn gen_implicit_cast(&mut self, ret: &DataType, context: SqlGeneratorContext) -> Expr {
self.gen_expr(ret, context)
}

fn gen_func(&mut self, ret: &DataType, context: SqlGeneratorContext) -> Expr {
match self.rng.gen_bool(0.1) {
true => self.gen_variadic_func(ret, context),
Expand Down Expand Up @@ -275,7 +268,7 @@ impl<'a, R: Rng> SqlGenerator<'a, R> {
}

fn gen_case(&mut self, ret: &DataType, context: SqlGeneratorContext) -> Expr {
let n = self.rng.gen_range(1..10);
let n = self.rng.gen_range(1..4);
Expr::Case {
operand: None,
conditions: self.gen_n_exprs_with_type(n, &DataType::Boolean, context),
Expand Down Expand Up @@ -304,8 +297,16 @@ impl<'a, R: Rng> SqlGenerator<'a, R> {
}

fn gen_concat_args(&mut self, context: SqlGeneratorContext) -> Vec<Expr> {
let n = self.rng.gen_range(1..10);
self.gen_n_exprs_with_type(n, &DataType::Varchar, context)
let n = self.rng.gen_range(1..4);
(0..n)
.map(|_| {
if self.rng.gen_bool(0.1) {
self.gen_explicit_cast(&DataType::Varchar, context)
} else {
self.gen_expr(&DataType::Varchar, context)
}
})
.collect()
}

/// Generates `n` expressions of type `ret`.
Expand All @@ -324,10 +325,19 @@ impl<'a, R: Rng> SqlGenerator<'a, R> {
Some(funcs) => funcs,
};
let func = funcs.choose(&mut self.rng).unwrap();
let can_implicit_cast = INVARIANT_FUNC_SET.contains(&func.func);
let exprs: Vec<Expr> = func
.inputs_type
.iter()
.map(|t| self.gen_expr(t, context))
.map(|t| {
if let Some(from_tys) = IMPLICIT_CAST_TABLE.get(t)
&& can_implicit_cast && self.flip_coin() {
let from_ty = &from_tys.choose(&mut self.rng).unwrap().from_type;
self.gen_implicit_cast(from_ty, context)
} else {
self.gen_expr(t, context)
}
})
.collect();
let expr = if exprs.len() == 1 {
make_unary_op(func.func, &exprs[0])
Expand Down
2 changes: 1 addition & 1 deletion src/tests/sqlsmith/src/sql_gen/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ impl<'a, R: Rng> SqlGenerator<'a, R> {
value: self.gen_temporal_scalar(typ),
})),
T::List { datatype: ref ty } => {
let n = self.rng.gen_range(1..=100); // Avoid ambiguous type
let n = self.rng.gen_range(1..=4); // Avoid ambiguous type
Expr::Array(self.gen_simple_scalar_list(ty, n))
}
// ENABLE: https://github.com/risingwavelabs/risingwave/issues/6934
Expand Down
50 changes: 38 additions & 12 deletions src/tests/sqlsmith/src/sql_gen/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

//! This module contains datatypes and functions which can be generated by sqlsmith.

use std::collections::HashMap;
use std::collections::{HashMap, HashSet};
use std::sync::LazyLock;

use itertools::Itertools;
Expand Down Expand Up @@ -169,6 +169,18 @@ pub(crate) static FUNC_TABLE: LazyLock<HashMap<DataType, Vec<FuncSig>>> = LazyLo
funcs
});

/// Set of invariant functions
// ENABLE: https://github.com/risingwavelabs/risingwave/issues/5826
pub(crate) static INVARIANT_FUNC_SET: LazyLock<HashSet<ExprType>> = LazyLock::new(|| {
func_sigs()
.map(|sig| sig.func)
.counts()
.into_iter()
.filter(|(_key, count)| *count == 1)
.map(|(key, _)| key)
.collect()
});

/// Table which maps aggregate functions' return types to possible function signatures.
// ENABLE: https://github.com/risingwavelabs/risingwave/issues/5826
pub(crate) static AGG_FUNC_TABLE: LazyLock<HashMap<DataType, Vec<AggFuncSig>>> =
Expand All @@ -191,14 +203,28 @@ pub(crate) static AGG_FUNC_TABLE: LazyLock<HashMap<DataType, Vec<AggFuncSig>>> =
/// NOTE: We avoid cast from varchar to other datatypes apart from itself.
/// This is because arbitrary strings may not be able to cast,
/// creating large number of invalid queries.
pub(crate) static CAST_TABLE: LazyLock<HashMap<DataType, Vec<CastSig>>> = LazyLock::new(|| {
let mut casts = HashMap::<DataType, Vec<CastSig>>::new();
cast_sigs()
.filter_map(|cast| cast.try_into().ok())
.filter(|cast: &CastSig| {
cast.context == CastContext::Explicit || cast.context == CastContext::Implicit
})
.filter(|cast| cast.from_type != DataType::Varchar || cast.to_type == DataType::Varchar)
.for_each(|cast| casts.entry(cast.to_type.clone()).or_default().push(cast));
casts
});
pub(crate) static EXPLICIT_CAST_TABLE: LazyLock<HashMap<DataType, Vec<CastSig>>> =
LazyLock::new(|| {
let mut casts = HashMap::<DataType, Vec<CastSig>>::new();
cast_sigs()
.filter_map(|cast| cast.try_into().ok())
.filter(|cast: &CastSig| cast.context == CastContext::Explicit)
.filter(|cast| cast.from_type != DataType::Varchar || cast.to_type == DataType::Varchar)
.for_each(|cast| casts.entry(cast.to_type.clone()).or_default().push(cast));
casts
});

/// Build a cast map from return types to viable cast-signatures.
/// NOTE: We avoid cast from varchar to other datatypes apart from itself.
/// This is because arbitrary strings may not be able to cast,
/// creating large number of invalid queries.
pub(crate) static IMPLICIT_CAST_TABLE: LazyLock<HashMap<DataType, Vec<CastSig>>> =
LazyLock::new(|| {
let mut casts = HashMap::<DataType, Vec<CastSig>>::new();
cast_sigs()
.filter_map(|cast| cast.try_into().ok())
.filter(|cast: &CastSig| cast.context == CastContext::Implicit)
.filter(|cast| cast.from_type != DataType::Varchar || cast.to_type == DataType::Varchar)
.for_each(|cast| casts.entry(cast.to_type.clone()).or_default().push(cast));
casts
});