diff --git a/src/tests/sqlsmith/src/sql_gen/expr.rs b/src/tests/sqlsmith/src/sql_gen/expr.rs index 19ab43463f378..a9a85cf68fbf7 100644 --- a/src/tests/sqlsmith/src/sql_gen/expr.rs +++ b/src/tests/sqlsmith/src/sql_gen/expr.rs @@ -26,7 +26,10 @@ use risingwave_sqlparser::ast::{ TrimWhereField, UnaryOperator, Value, }; -use crate::sql_gen::types::{data_type_to_ast_data_type, AGG_FUNC_TABLE, CAST_TABLE, FUNC_TABLE}; +use crate::sql_gen::types::{ + data_type_to_ast_data_type, AGG_FUNC_TABLE, EXPLICIT_CAST_TABLE, FUNC_TABLE, + IMPLICIT_CAST_TABLE, INVARIANT_FUNC_SET, +}; use crate::sql_gen::{SqlGenerator, SqlGeneratorContext}; static STRUCT_FIELD_NAMES: [&str; 26] = [ @@ -87,7 +90,7 @@ impl<'a, R: Rng> SqlGenerator<'a, R> { match self.rng.gen_range(0..=range) { 0..=70 => self.gen_func(typ, context), 71..=80 => self.gen_exists(typ, context), - 81..=90 => self.gen_cast(typ, context), + 81..=90 => self.gen_explicit_cast(typ, context), 91..=99 => self.gen_agg(typ), _ => unreachable!(), } @@ -151,7 +154,7 @@ impl<'a, R: Rng> SqlGenerator<'a, R> { } fn gen_struct_data_type(&mut self, depth: usize) -> DataType { - let num_fields = self.rng.gen_range(1..10); + let num_fields = self.rng.gen_range(1..4); let fields = (0..num_fields) .map(|_| self.gen_data_type_inner(depth)) .collect(); @@ -199,15 +202,19 @@ impl<'a, R: Rng> SqlGenerator<'a, R> { } } - fn gen_cast(&mut self, ret: &DataType, context: SqlGeneratorContext) -> Expr { - self.gen_cast_inner(ret, context) + fn gen_explicit_cast(&mut self, ret: &DataType, context: SqlGeneratorContext) -> Expr { + self.gen_explicit_cast_inner(ret, context) .unwrap_or_else(|| self.gen_simple_scalar(ret)) } /// Generate casts from a cast map. /// TODO: Assign casts have to be tested via `INSERT`. - fn gen_cast_inner(&mut self, ret: &DataType, context: SqlGeneratorContext) -> Option { - let casts = CAST_TABLE.get(ret)?; + fn gen_explicit_cast_inner( + &mut self, + ret: &DataType, + context: SqlGeneratorContext, + ) -> Option { + let casts = EXPLICIT_CAST_TABLE.get(ret)?; let cast_sig = casts.choose(&mut self.rng).unwrap(); use CastContext as T; @@ -220,32 +227,18 @@ impl<'a, R: Rng> SqlGenerator<'a, R> { Some(Expr::Cast { expr, data_type }) } - // TODO: Re-enable implicit casts - // Currently these implicit cast expressions may surface in: - // select items, functions and so on. - // Type-inference could result in different type from what SQLGenerator expects. - // For example: - // Suppose we had implicit cast expr from smallint->int. - // We then generated 1::smallint with implicit type int. - // If it was part of this expression: - // SELECT 1::smallint as col0; - // Then, when generating other expressions, SqlGenerator sees `col0` with type `int`, - // but its type will be inferred as `smallint` actually in the frontend. - // - // Functions also encounter problems, and could infer to the wrong type. - // May refer to type inference rules: - // https://github.com/risingwavelabs/risingwave/blob/650810a5a9b86028036cb3b51eec5b18d8f814d5/src/frontend/src/expr/type_inference/func.rs#L445-L464 - // Therefore it is disabled for now. - // T::Implicit if context.can_implicit_cast() => { - // self.gen_expr(cast_sig.from_type, context).into() - // } - // TODO: Generate this when e2e inserts are generated. // T::Assign - _ => None, + _ => unreachable!(), } } + /// NOTE: This can result in ambiguous expressions. + /// Should only be used in unambiguous context. + fn gen_implicit_cast(&mut self, ret: &DataType, context: SqlGeneratorContext) -> Expr { + self.gen_expr(ret, context) + } + fn gen_func(&mut self, ret: &DataType, context: SqlGeneratorContext) -> Expr { match self.rng.gen_bool(0.1) { true => self.gen_variadic_func(ret, context), @@ -275,7 +268,7 @@ impl<'a, R: Rng> SqlGenerator<'a, R> { } fn gen_case(&mut self, ret: &DataType, context: SqlGeneratorContext) -> Expr { - let n = self.rng.gen_range(1..10); + let n = self.rng.gen_range(1..4); Expr::Case { operand: None, conditions: self.gen_n_exprs_with_type(n, &DataType::Boolean, context), @@ -304,8 +297,16 @@ impl<'a, R: Rng> SqlGenerator<'a, R> { } fn gen_concat_args(&mut self, context: SqlGeneratorContext) -> Vec { - let n = self.rng.gen_range(1..10); - self.gen_n_exprs_with_type(n, &DataType::Varchar, context) + let n = self.rng.gen_range(1..4); + (0..n) + .map(|_| { + if self.rng.gen_bool(0.1) { + self.gen_explicit_cast(&DataType::Varchar, context) + } else { + self.gen_expr(&DataType::Varchar, context) + } + }) + .collect() } /// Generates `n` expressions of type `ret`. @@ -324,10 +325,19 @@ impl<'a, R: Rng> SqlGenerator<'a, R> { Some(funcs) => funcs, }; let func = funcs.choose(&mut self.rng).unwrap(); + let can_implicit_cast = INVARIANT_FUNC_SET.contains(&func.func); let exprs: Vec = func .inputs_type .iter() - .map(|t| self.gen_expr(t, context)) + .map(|t| { + if let Some(from_tys) = IMPLICIT_CAST_TABLE.get(t) + && can_implicit_cast && self.flip_coin() { + let from_ty = &from_tys.choose(&mut self.rng).unwrap().from_type; + self.gen_implicit_cast(from_ty, context) + } else { + self.gen_expr(t, context) + } + }) .collect(); let expr = if exprs.len() == 1 { make_unary_op(func.func, &exprs[0]) diff --git a/src/tests/sqlsmith/src/sql_gen/scalar.rs b/src/tests/sqlsmith/src/sql_gen/scalar.rs index 6d657385c67e4..1a0a26b5d0af8 100644 --- a/src/tests/sqlsmith/src/sql_gen/scalar.rs +++ b/src/tests/sqlsmith/src/sql_gen/scalar.rs @@ -92,7 +92,7 @@ impl<'a, R: Rng> SqlGenerator<'a, R> { value: self.gen_temporal_scalar(typ), })), T::List { datatype: ref ty } => { - let n = self.rng.gen_range(1..=100); // Avoid ambiguous type + let n = self.rng.gen_range(1..=4); // Avoid ambiguous type Expr::Array(self.gen_simple_scalar_list(ty, n)) } // ENABLE: https://github.com/risingwavelabs/risingwave/issues/6934 diff --git a/src/tests/sqlsmith/src/sql_gen/types.rs b/src/tests/sqlsmith/src/sql_gen/types.rs index 904163d6a964a..9019fd250ae25 100644 --- a/src/tests/sqlsmith/src/sql_gen/types.rs +++ b/src/tests/sqlsmith/src/sql_gen/types.rs @@ -14,7 +14,7 @@ //! This module contains datatypes and functions which can be generated by sqlsmith. -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::sync::LazyLock; use itertools::Itertools; @@ -169,6 +169,18 @@ pub(crate) static FUNC_TABLE: LazyLock>> = LazyLo funcs }); +/// Set of invariant functions +// ENABLE: https://github.com/risingwavelabs/risingwave/issues/5826 +pub(crate) static INVARIANT_FUNC_SET: LazyLock> = LazyLock::new(|| { + func_sigs() + .map(|sig| sig.func) + .counts() + .into_iter() + .filter(|(_key, count)| *count == 1) + .map(|(key, _)| key) + .collect() +}); + /// Table which maps aggregate functions' return types to possible function signatures. // ENABLE: https://github.com/risingwavelabs/risingwave/issues/5826 pub(crate) static AGG_FUNC_TABLE: LazyLock>> = @@ -191,14 +203,28 @@ pub(crate) static AGG_FUNC_TABLE: LazyLock>> = /// NOTE: We avoid cast from varchar to other datatypes apart from itself. /// This is because arbitrary strings may not be able to cast, /// creating large number of invalid queries. -pub(crate) static CAST_TABLE: LazyLock>> = LazyLock::new(|| { - let mut casts = HashMap::>::new(); - cast_sigs() - .filter_map(|cast| cast.try_into().ok()) - .filter(|cast: &CastSig| { - cast.context == CastContext::Explicit || cast.context == CastContext::Implicit - }) - .filter(|cast| cast.from_type != DataType::Varchar || cast.to_type == DataType::Varchar) - .for_each(|cast| casts.entry(cast.to_type.clone()).or_default().push(cast)); - casts -}); +pub(crate) static EXPLICIT_CAST_TABLE: LazyLock>> = + LazyLock::new(|| { + let mut casts = HashMap::>::new(); + cast_sigs() + .filter_map(|cast| cast.try_into().ok()) + .filter(|cast: &CastSig| cast.context == CastContext::Explicit) + .filter(|cast| cast.from_type != DataType::Varchar || cast.to_type == DataType::Varchar) + .for_each(|cast| casts.entry(cast.to_type.clone()).or_default().push(cast)); + casts + }); + +/// Build a cast map from return types to viable cast-signatures. +/// NOTE: We avoid cast from varchar to other datatypes apart from itself. +/// This is because arbitrary strings may not be able to cast, +/// creating large number of invalid queries. +pub(crate) static IMPLICIT_CAST_TABLE: LazyLock>> = + LazyLock::new(|| { + let mut casts = HashMap::>::new(); + cast_sigs() + .filter_map(|cast| cast.try_into().ok()) + .filter(|cast: &CastSig| cast.context == CastContext::Implicit) + .filter(|cast| cast.from_type != DataType::Varchar || cast.to_type == DataType::Varchar) + .for_each(|cast| casts.entry(cast.to_type.clone()).or_default().push(cast)); + casts + });