From c85c41f5a79ccd15c72403b6775fe6c80a06d912 Mon Sep 17 00:00:00 2001 From: Andy Lok Date: Thu, 2 Mar 2023 23:14:09 +0800 Subject: [PATCH 1/2] chore(expr): respect overload order for function factories --- ...hmetics_type_v2.rs => arithmetics_type.rs} | 0 src/query/codegen/src/writes/mod.rs | 4 +- src/query/codegen/src/writes/register.rs | 22 +-- src/query/expression/src/evaluator.rs | 6 +- src/query/expression/src/function.rs | 111 ++++++++---- src/query/expression/src/register.rs | 166 ++++++++++-------- src/query/expression/src/type_check.rs | 58 +++--- src/query/expression/src/types.rs | 11 ++ .../expression/src/utils/arithmetics_type.rs | 2 +- src/query/functions/src/scalars/array.rs | 3 + src/query/functions/tests/it/scalars/mod.rs | 2 +- 11 files changed, 237 insertions(+), 148 deletions(-) rename src/query/codegen/src/writes/{arithmetics_type_v2.rs => arithmetics_type.rs} (100%) diff --git a/src/query/codegen/src/writes/arithmetics_type_v2.rs b/src/query/codegen/src/writes/arithmetics_type.rs similarity index 100% rename from src/query/codegen/src/writes/arithmetics_type_v2.rs rename to src/query/codegen/src/writes/arithmetics_type.rs diff --git a/src/query/codegen/src/writes/mod.rs b/src/query/codegen/src/writes/mod.rs index eb81bd8455722..8f29f144eb688 100644 --- a/src/query/codegen/src/writes/mod.rs +++ b/src/query/codegen/src/writes/mod.rs @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-mod arithmetics_type_v2; +mod arithmetics_type; mod register; -pub use arithmetics_type_v2::*; +pub use arithmetics_type::*; pub use register::*; diff --git a/src/query/codegen/src/writes/register.rs b/src/query/codegen/src/writes/register.rs index 8c04e92111328..e13ec15f8efc0 100644 --- a/src/query/codegen/src/writes/register.rs +++ b/src/query/codegen/src/writes/register.rs @@ -342,19 +342,21 @@ pub fn codegen_register() { F: Fn({arg_f_closure_sig}) -> FunctionDomain + 'static + Clone + Copy + Send + Sync, G: for <'a> Fn({arg_g_closure_sig} &mut EvalContext) -> Value + 'static + Clone + Copy + Send + Sync, {{ + let func = Arc::new(Function {{ + signature: FunctionSignature {{ + name: name.to_string(), + args_type: vec![{arg_sig_type}], + return_type: O::data_type(), + property, + }}, + calc_domain: Box::new(erase_calc_domain_generic_{n_args}_arg::<{arg_generics} O>(calc_domain)), + eval: Box::new(erase_function_generic_{n_args}_arg(func)), + }}); + let id = self.next_function_id(name); self.funcs .entry(name.to_string()) .or_insert_with(Vec::new) - .push(Arc::new(Function {{ - signature: FunctionSignature {{ - name: name.to_string(), - args_type: vec![{arg_sig_type}], - return_type: O::data_type(), - property, - }}, - calc_domain: Box::new(erase_calc_domain_generic_{n_args}_arg::<{arg_generics} O>(calc_domain)), - eval: Box::new(erase_function_generic_{n_args}_arg(func)), - }})); + .push((func, id)); }} " ) diff --git a/src/query/expression/src/evaluator.rs b/src/query/expression/src/evaluator.rs index 62667a3fe424e..a465212a4c696 100644 --- a/src/query/expression/src/evaluator.rs +++ b/src/query/expression/src/evaluator.rs @@ -676,13 +676,13 @@ impl<'a, Index: ColumnIndex> ConstantFolder<'a, Index> { let mut old_expr = expr.clone(); let mut old_domain = None; for _ in 0..MAX_ITERATIONS { - let (new_expr, domain) = self.fold_once(&old_expr); + let (new_expr, new_domain) = self.fold_once(&old_expr); if new_expr == old_expr { - return (new_expr, domain); + return 
(new_expr, new_domain); } old_expr = new_expr; - old_domain = domain; + old_domain = new_domain; } error!("maximum iterations reached while folding expression"); diff --git a/src/query/expression/src/function.rs b/src/query/expression/src/function.rs index c56ff96e1866e..d82185aba5c0f 100755 --- a/src/query/expression/src/function.rs +++ b/src/query/expression/src/function.rs @@ -28,7 +28,7 @@ use serde::Serialize; use crate::date_helper::TzLUT; use crate::property::Domain; use crate::property::FunctionProperty; -use crate::type_check::can_auto_cast_to; +use crate::type_check::try_unify_signature; use crate::types::nullable::NullableColumn; use crate::types::*; use crate::utils::arrow::constant_bitmap; @@ -166,17 +166,17 @@ impl Function { } } +/// A function to build function depending on the const parameters and the type of arguments (before coercion). +/// +/// The first argument is the const parameters and the second argument is the types of arguments. +pub type FunctionFactory = + Box Option> + Send + Sync + 'static>; + #[derive(Default)] pub struct FunctionRegistry { - pub funcs: HashMap>>, - /// A function to build function depending on the const parameters and the type of arguments (before coercion). - /// - /// The first argument is the const parameters and the second argument is the number of arguments. - #[allow(clippy::type_complexity)] - pub factories: HashMap< - String, - Vec Option> + Send + Sync + 'static>>, - >, + pub funcs: HashMap, usize)>>, + pub factories: HashMap>, + /// Aliases map from alias function name to original function name. pub aliases: HashMap, @@ -211,14 +211,24 @@ impl FunctionRegistry { pub fn get(&self, id: &FunctionID) -> Option> { match id { - FunctionID::Builtin { name, id } => self.funcs.get(name.as_str())?.get(*id).cloned(), + FunctionID::Builtin { name, id } => self + .funcs + .get(name.as_str())? 
+ .iter() + .find(|(_, func_id)| func_id == id) + .map(|(func, _)| func.clone()), FunctionID::Factory { name, id, params, args_type, } => { - let factory = self.factories.get(name.as_str())?.get(*id)?; + let factory = self + .factories + .get(name.as_str())? + .iter() + .find(|(_, func_id)| func_id == id) + .map(|(func, _)| func)?; factory(params, args_type) } } @@ -235,12 +245,12 @@ impl FunctionRegistry { let mut candidates = Vec::new(); if let Some(funcs) = self.funcs.get(&name) { - candidates.extend(funcs.iter().enumerate().filter_map(|(id, func)| { + candidates.extend(funcs.iter().filter_map(|(func, id)| { if func.signature.name == name && func.signature.args_type.len() == args.len() { Some(( FunctionID::Builtin { name: name.to_string(), - id, + id: *id, }, func.clone(), )) @@ -256,12 +266,12 @@ impl FunctionRegistry { .map(Expr::data_type) .cloned() .collect::>(); - candidates.extend(factories.iter().enumerate().filter_map(|(id, factory)| { + candidates.extend(factories.iter().filter_map(|(factory, id)| { factory(params, &args_type).map(|func| { ( FunctionID::Factory { name: name.to_string(), - id, + id: *id, params: params.to_vec(), args_type: args_type.clone(), }, @@ -271,6 +281,8 @@ impl FunctionRegistry { })); } + candidates.sort_by_key(|(id, _)| id.id()); + candidates } @@ -291,10 +303,11 @@ impl FunctionRegistry { name: &str, factory: impl Fn(&[usize], &[DataType]) -> Option> + 'static + Send + Sync, ) { + let id = self.next_function_id(name); self.factories .entry(name.to_string()) .or_insert_with(Vec::new) - .push(Box::new(factory)); + .push((Box::new(factory), id)); } pub fn register_aliases(&mut self, fn_name: &str, aliases: &[&str]) { @@ -329,26 +342,49 @@ impl FunctionRegistry { self.auto_try_cast_rules.extend(auto_try_cast_rules); } + pub fn next_function_id(&self, name: &str) -> usize { + self.funcs.get(name).map(|funcs| funcs.len()).unwrap_or(0) + + self.factories.get(name).map(|f| f.len()).unwrap_or(0) + } + pub fn check_ambiguity(&self) { for 
(name, funcs) in &self.funcs { let auto_cast_rules = self.get_auto_cast_rules(name); - for (i, former) in funcs.iter().enumerate() { - for latter in funcs.iter().skip(i + 1) { - if former.signature.args_type.len() == latter.signature.args_type.len() - && former - .signature - .args_type + for (former, former_id) in funcs { + for latter in funcs + .iter() + .filter(|(_, id)| id > former_id) + .map(|(func, _)| func.clone()) + .chain( + self.factories + .get(name) + .map(Vec::as_slice) + .unwrap_or(&[]) .iter() - .zip(latter.signature.args_type.iter()) - .all(|(former_arg, latter_arg)| { - can_auto_cast_to(latter_arg, former_arg, auto_cast_rules) - }) - { - panic!( - "Ambiguous signatures for function:\n- {}\n- {}\n\ - Suggestion: swap the order of the overloads.", - former.signature, latter.signature - ); + .filter(|(_, id)| id > former_id) + .filter_map(|(factory, _)| factory(&[], &former.signature.args_type)), + ) + { + if former.signature.args_type.len() == latter.signature.args_type.len() { + if let Ok(subst) = try_unify_signature( + latter.signature.args_type.iter(), + former.signature.args_type.iter(), + auto_cast_rules, + ) { + if subst.apply(&former.signature.return_type).is_ok() + && former + .signature + .args_type + .iter() + .all(|sig_ty| subst.apply(sig_ty).is_ok()) + { + panic!( + "Ambiguous signatures for function:\n- {}\n- {}\n\ + Suggestion: swap the order of the overloads.", + former.signature, latter.signature + ); + } + } } } } @@ -356,6 +392,15 @@ impl FunctionRegistry { } } +impl FunctionID { + pub fn id(&self) -> usize { + match self { + FunctionID::Builtin { id, .. } => *id, + FunctionID::Factory { id, .. 
} => *id, + } + } +} + pub fn wrap_nullable(f: F) -> impl Fn(&[ValueRef], &mut EvalContext) -> Value where F: Fn(&[ValueRef], &mut EvalContext) -> Value { move |args, ctx| { diff --git a/src/query/expression/src/register.rs b/src/query/expression/src/register.rs index 7eb5d01fa1d76..7f846c4169c89 100755 --- a/src/query/expression/src/register.rs +++ b/src/query/expression/src/register.rs @@ -935,19 +935,21 @@ impl FunctionRegistry { F: Fn() -> FunctionDomain + 'static + Clone + Copy + Send + Sync, G: for<'a> Fn(&mut EvalContext) -> Value + 'static + Clone + Copy + Send + Sync, { + let func = Arc::new(Function { + signature: FunctionSignature { + name: name.to_string(), + args_type: vec![], + return_type: O::data_type(), + property, + }, + calc_domain: Box::new(erase_calc_domain_generic_0_arg::(calc_domain)), + eval: Box::new(erase_function_generic_0_arg(func)), + }); + let id = self.next_function_id(name); self.funcs .entry(name.to_string()) .or_insert_with(Vec::new) - .push(Arc::new(Function { - signature: FunctionSignature { - name: name.to_string(), - args_type: vec![], - return_type: O::data_type(), - property, - }, - calc_domain: Box::new(erase_calc_domain_generic_0_arg::(calc_domain)), - eval: Box::new(erase_function_generic_0_arg(func)), - })); + .push((func, id)); } pub fn register_1_arg_core( @@ -965,19 +967,21 @@ impl FunctionRegistry { + Send + Sync, { + let func = Arc::new(Function { + signature: FunctionSignature { + name: name.to_string(), + args_type: vec![I1::data_type()], + return_type: O::data_type(), + property, + }, + calc_domain: Box::new(erase_calc_domain_generic_1_arg::(calc_domain)), + eval: Box::new(erase_function_generic_1_arg(func)), + }); + let id = self.next_function_id(name); self.funcs .entry(name.to_string()) .or_insert_with(Vec::new) - .push(Arc::new(Function { - signature: FunctionSignature { - name: name.to_string(), - args_type: vec![I1::data_type()], - return_type: O::data_type(), - property, - }, - calc_domain: 
Box::new(erase_calc_domain_generic_1_arg::(calc_domain)), - eval: Box::new(erase_function_generic_1_arg(func)), - })); + .push((func, id)); } pub fn register_2_arg_core( @@ -995,19 +999,21 @@ impl FunctionRegistry { + Send + Sync, { + let func = Arc::new(Function { + signature: FunctionSignature { + name: name.to_string(), + args_type: vec![I1::data_type(), I2::data_type()], + return_type: O::data_type(), + property, + }, + calc_domain: Box::new(erase_calc_domain_generic_2_arg::(calc_domain)), + eval: Box::new(erase_function_generic_2_arg(func)), + }); + let id = self.next_function_id(name); self.funcs .entry(name.to_string()) .or_insert_with(Vec::new) - .push(Arc::new(Function { - signature: FunctionSignature { - name: name.to_string(), - args_type: vec![I1::data_type(), I2::data_type()], - return_type: O::data_type(), - property, - }, - calc_domain: Box::new(erase_calc_domain_generic_2_arg::(calc_domain)), - eval: Box::new(erase_function_generic_2_arg(func)), - })); + .push((func, id)); } pub fn register_3_arg_core( @@ -1035,21 +1041,23 @@ impl FunctionRegistry { + Send + Sync, { + let func = Arc::new(Function { + signature: FunctionSignature { + name: name.to_string(), + args_type: vec![I1::data_type(), I2::data_type(), I3::data_type()], + return_type: O::data_type(), + property, + }, + calc_domain: Box::new(erase_calc_domain_generic_3_arg::( + calc_domain, + )), + eval: Box::new(erase_function_generic_3_arg(func)), + }); + let id = self.next_function_id(name); self.funcs .entry(name.to_string()) .or_insert_with(Vec::new) - .push(Arc::new(Function { - signature: FunctionSignature { - name: name.to_string(), - args_type: vec![I1::data_type(), I2::data_type(), I3::data_type()], - return_type: O::data_type(), - property, - }, - calc_domain: Box::new(erase_calc_domain_generic_3_arg::( - calc_domain, - )), - eval: Box::new(erase_function_generic_3_arg(func)), - })); + .push((func, id)); } pub fn register_4_arg_core< @@ -1086,26 +1094,28 @@ impl FunctionRegistry { + 
Send + Sync, { + let func = Arc::new(Function { + signature: FunctionSignature { + name: name.to_string(), + args_type: vec![ + I1::data_type(), + I2::data_type(), + I3::data_type(), + I4::data_type(), + ], + return_type: O::data_type(), + property, + }, + calc_domain: Box::new(erase_calc_domain_generic_4_arg::( + calc_domain, + )), + eval: Box::new(erase_function_generic_4_arg(func)), + }); + let id = self.next_function_id(name); self.funcs .entry(name.to_string()) .or_insert_with(Vec::new) - .push(Arc::new(Function { - signature: FunctionSignature { - name: name.to_string(), - args_type: vec![ - I1::data_type(), - I2::data_type(), - I3::data_type(), - I4::data_type(), - ], - return_type: O::data_type(), - property, - }, - calc_domain: Box::new(erase_calc_domain_generic_4_arg::( - calc_domain, - )), - eval: Box::new(erase_function_generic_4_arg(func)), - })); + .push((func, id)); } pub fn register_5_arg_core< @@ -1144,27 +1154,29 @@ impl FunctionRegistry { + Send + Sync, { + let func = Arc::new(Function { + signature: FunctionSignature { + name: name.to_string(), + args_type: vec![ + I1::data_type(), + I2::data_type(), + I3::data_type(), + I4::data_type(), + I5::data_type(), + ], + return_type: O::data_type(), + property, + }, + calc_domain: Box::new(erase_calc_domain_generic_5_arg::( + calc_domain, + )), + eval: Box::new(erase_function_generic_5_arg(func)), + }); + let id = self.next_function_id(name); self.funcs .entry(name.to_string()) .or_insert_with(Vec::new) - .push(Arc::new(Function { - signature: FunctionSignature { - name: name.to_string(), - args_type: vec![ - I1::data_type(), - I2::data_type(), - I3::data_type(), - I4::data_type(), - I5::data_type(), - ], - return_type: O::data_type(), - property, - }, - calc_domain: Box::new(erase_calc_domain_generic_5_arg::( - calc_domain, - )), - eval: Box::new(erase_function_generic_5_arg(func)), - })); + .push((func, id)); } } diff --git a/src/query/expression/src/type_check.rs 
b/src/query/expression/src/type_check.rs index 0b0582740da0f..bd3d532e3a8d4 100755 --- a/src/query/expression/src/type_check.rs +++ b/src/query/expression/src/type_check.rs @@ -310,9 +310,9 @@ impl Substitution { Ok(self) } - pub fn apply(&self, ty: DataType) -> Result { + pub fn apply(&self, ty: &DataType) -> Result { match ty { - DataType::Generic(idx) => self.0.get(&idx).cloned().ok_or_else(|| { + DataType::Generic(idx) => self.0.get(idx).cloned().ok_or_else(|| { ErrorCode::from_string_no_backtrace(format!("unbound generic type `T{idx}`")) }), DataType::Nullable(box ty) => Ok(DataType::Nullable(Box::new(self.apply(ty)?))), @@ -320,7 +320,7 @@ impl Substitution { DataType::Map(box ty) => match ty { DataType::Tuple(fields_ty) => { let fields_ty = fields_ty - .into_iter() + .iter() .map(|field_ty| self.apply(field_ty)) .collect::>()?; let inner_ty = DataType::Tuple(fields_ty); @@ -330,12 +330,12 @@ impl Substitution { }, DataType::Tuple(fields_ty) => { let fields_ty = fields_ty - .into_iter() + .iter() .map(|field_ty| self.apply(field_ty)) .collect::>()?; Ok(DataType::Tuple(fields_ty)) } - ty => Ok(ty), + ty => Ok(ty.clone()), } } } @@ -348,31 +348,23 @@ pub fn try_check_function( auto_cast_rules: AutoCastRules, fn_registry: &FunctionRegistry, ) -> Result<(Vec>, DataType, Vec)> { - assert_eq!(args.len(), sig.args_type.len()); - - let substs = args - .iter() - .map(Expr::data_type) - .zip(&sig.args_type) - .map(|(src_ty, dest_ty)| unify(src_ty, dest_ty, auto_cast_rules)) - .collect::>>()?; - - let subst = substs - .into_iter() - .try_reduce(|subst1, subst2| subst1.merge(subst2, auto_cast_rules))? 
- .unwrap_or_else(Substitution::empty); + let subst = try_unify_signature( + args.iter().map(Expr::data_type), + sig.args_type.iter(), + auto_cast_rules, + )?; let checked_args = args .iter() .zip(&sig.args_type) .map(|(arg, sig_type)| { - let sig_type = subst.apply(sig_type.clone())?; + let sig_type = subst.apply(sig_type)?; let is_try = fn_registry.is_auto_try_cast_rule(arg.data_type(), &sig_type); check_cast(span, is_try, arg.clone(), &sig_type, fn_registry) }) .collect::>>()?; - let return_type = subst.apply(sig.return_type.clone())?; + let return_type = subst.apply(&sig.return_type)?; let generics = subst .0 @@ -394,13 +386,37 @@ pub fn try_check_function( Ok((checked_args, return_type, generics)) } +pub fn try_unify_signature( + src_tys: impl IntoIterator + ExactSizeIterator, + dest_tys: impl IntoIterator + ExactSizeIterator, + auto_cast_rules: AutoCastRules, +) -> Result { + assert_eq!(src_tys.len(), dest_tys.len()); + + let substs = src_tys + .into_iter() + .zip(dest_tys) + .map(|(src_ty, dest_ty)| unify(src_ty, dest_ty, auto_cast_rules)) + .collect::>>()?; + + Ok(substs + .into_iter() + .try_reduce(|subst1, subst2| subst1.merge(subst2, auto_cast_rules))? 
+ .unwrap_or_else(Substitution::empty)) +} + pub fn unify( src_ty: &DataType, dest_ty: &DataType, auto_cast_rules: AutoCastRules, ) -> Result { match (src_ty, dest_ty) { - (DataType::Generic(_), _) => unreachable!("source type must not contain generic type"), + (DataType::Generic(_), _) => Err(ErrorCode::from_string_no_backtrace( + format!("source type {src_ty} must not contain generic type"), + )), + (ty, DataType::Generic(_)) if ty.has_generic() => Err(ErrorCode::from_string_no_backtrace( + format!("source type {src_ty} must not contain generic type"), + )), (ty, DataType::Generic(idx)) => Ok(Substitution::equation(*idx, ty.clone())), (src_ty, dest_ty) if can_auto_cast_to(src_ty, dest_ty, auto_cast_rules) => { Ok(Substitution::empty()) diff --git a/src/query/expression/src/types.rs b/src/query/expression/src/types.rs index 5b654afdf55ad..23b45ed3255ca 100755 --- a/src/query/expression/src/types.rs +++ b/src/query/expression/src/types.rs @@ -119,6 +119,17 @@ impl DataType { } } + pub fn has_generic(&self) -> bool { + match self { + DataType::Generic(_) => true, + DataType::Nullable(ty) => ty.has_generic(), + DataType::Array(ty) => ty.has_generic(), + DataType::Map(ty) => ty.has_generic(), + DataType::Tuple(tys) => tys.iter().any(|ty| ty.has_generic()), + _ => false, + } + } + pub fn is_unsigned_numeric(&self) -> bool { match self { DataType::Number(ty) => ALL_UNSIGNED_INTEGER_TYPES.contains(ty), diff --git a/src/query/expression/src/utils/arithmetics_type.rs b/src/query/expression/src/utils/arithmetics_type.rs index a56710f8637a1..db209721be6e2 100644 --- a/src/query/expression/src/utils/arithmetics_type.rs +++ b/src/query/expression/src/utils/arithmetics_type.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// This code is generated by src/query/codegen/src/writes/arithmetics_type_v2.rs. DO NOT EDIT. +// This code is generated by src/query/codegen/src/writes/arithmetics_type.rs. 
DO NOT EDIT. use crate::types::number::Number; use crate::types::number::F32; diff --git a/src/query/functions/src/scalars/array.rs b/src/query/functions/src/scalars/array.rs index a3fa8091c7987..16c47d6944e6d 100644 --- a/src/query/functions/src/scalars/array.rs +++ b/src/query/functions/src/scalars/array.rs @@ -102,6 +102,9 @@ pub fn register(registry: &mut FunctionRegistry) { ); registry.register_function_factory("array", |_, args_type| { + if args_type.len() == 0 { + return None; + } Some(Arc::new(Function { signature: FunctionSignature { name: "array".to_string(), diff --git a/src/query/functions/tests/it/scalars/mod.rs b/src/query/functions/tests/it/scalars/mod.rs index d08965994bdea..c1beb32e9d5c1 100644 --- a/src/query/functions/tests/it/scalars/mod.rs +++ b/src/query/functions/tests/it/scalars/mod.rs @@ -228,7 +228,7 @@ fn list_all_builtin_functions() { let fn_registry = &BUILTIN_FUNCTIONS; writeln!(file, "Simple functions:").unwrap(); - for func in fn_registry + for (func, _) in fn_registry .funcs .iter() .sorted_by_key(|(name, _)| name.to_string()) From 7e3e642ffae37aa63abd1dc7b2a18729d3c51113 Mon Sep 17 00:00:00 2001 From: Yang Xiufeng Date: Fri, 3 Mar 2023 09:47:48 +0800 Subject: [PATCH 2/2] fix clippy. --- src/query/functions/src/scalars/array.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/query/functions/src/scalars/array.rs b/src/query/functions/src/scalars/array.rs index 16c47d6944e6d..9d3a0ac7171ec 100644 --- a/src/query/functions/src/scalars/array.rs +++ b/src/query/functions/src/scalars/array.rs @@ -102,7 +102,7 @@ pub fn register(registry: &mut FunctionRegistry) { ); registry.register_function_factory("array", |_, args_type| { - if args_type.len() == 0 { + if args_type.is_empty() { return None; } Some(Arc::new(Function {