Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(expr): respect overload order for function factories #10313

Merged
merged 5 commits into from
Mar 3, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/query/codegen/src/writes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

mod arithmetics_type_v2;
mod arithmetics_type;
mod register;

pub use arithmetics_type_v2::*;
pub use arithmetics_type::*;
pub use register::*;
22 changes: 12 additions & 10 deletions src/query/codegen/src/writes/register.rs
Original file line number Diff line number Diff line change
Expand Up @@ -342,19 +342,21 @@ pub fn codegen_register() {
F: Fn({arg_f_closure_sig}) -> FunctionDomain<O> + 'static + Clone + Copy + Send + Sync,
G: for <'a> Fn({arg_g_closure_sig} &mut EvalContext) -> Value<O> + 'static + Clone + Copy + Send + Sync,
{{
let func = Arc::new(Function {{
signature: FunctionSignature {{
name: name.to_string(),
args_type: vec![{arg_sig_type}],
return_type: O::data_type(),
property,
}},
calc_domain: Box::new(erase_calc_domain_generic_{n_args}_arg::<{arg_generics} O>(calc_domain)),
eval: Box::new(erase_function_generic_{n_args}_arg(func)),
}});
let id = self.next_function_id(name);
self.funcs
.entry(name.to_string())
.or_insert_with(Vec::new)
.push(Arc::new(Function {{
signature: FunctionSignature {{
name: name.to_string(),
args_type: vec![{arg_sig_type}],
return_type: O::data_type(),
property,
}},
calc_domain: Box::new(erase_calc_domain_generic_{n_args}_arg::<{arg_generics} O>(calc_domain)),
eval: Box::new(erase_function_generic_{n_args}_arg(func)),
}}));
.push((func, id));
}}
"
)
Expand Down
6 changes: 3 additions & 3 deletions src/query/expression/src/evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -676,13 +676,13 @@ impl<'a, Index: ColumnIndex> ConstantFolder<'a, Index> {
let mut old_expr = expr.clone();
let mut old_domain = None;
for _ in 0..MAX_ITERATIONS {
let (new_expr, domain) = self.fold_once(&old_expr);
let (new_expr, new_domain) = self.fold_once(&old_expr);

if new_expr == old_expr {
return (new_expr, domain);
return (new_expr, new_domain);
}
old_expr = new_expr;
old_domain = domain;
old_domain = new_domain;
}

error!("maximum iterations reached while folding expression");
Expand Down
111 changes: 78 additions & 33 deletions src/query/expression/src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use serde::Serialize;
use crate::date_helper::TzLUT;
use crate::property::Domain;
use crate::property::FunctionProperty;
use crate::type_check::can_auto_cast_to;
use crate::type_check::try_unify_signature;
use crate::types::nullable::NullableColumn;
use crate::types::*;
use crate::utils::arrow::constant_bitmap;
Expand Down Expand Up @@ -166,17 +166,17 @@ impl Function {
}
}

/// A function to build function depending on the const parameters and the type of arguments (before coercion).
///
/// The first argument is the const parameters and the second argument is the types of arguments.
pub type FunctionFactory =
Box<dyn Fn(&[usize], &[DataType]) -> Option<Arc<Function>> + Send + Sync + 'static>;

#[derive(Default)]
pub struct FunctionRegistry {
pub funcs: HashMap<String, Vec<Arc<Function>>>,
/// A function to build function depending on the const parameters and the type of arguments (before coercion).
///
/// The first argument is the const parameters and the second argument is the number of arguments.
#[allow(clippy::type_complexity)]
pub factories: HashMap<
String,
Vec<Box<dyn Fn(&[usize], &[DataType]) -> Option<Arc<Function>> + Send + Sync + 'static>>,
>,
pub funcs: HashMap<String, Vec<(Arc<Function>, usize)>>,
pub factories: HashMap<String, Vec<(FunctionFactory, usize)>>,

/// Aliases map from alias function name to original function name.
pub aliases: HashMap<String, String>,

Expand Down Expand Up @@ -211,14 +211,24 @@ impl FunctionRegistry {

pub fn get(&self, id: &FunctionID) -> Option<Arc<Function>> {
match id {
FunctionID::Builtin { name, id } => self.funcs.get(name.as_str())?.get(*id).cloned(),
FunctionID::Builtin { name, id } => self
.funcs
.get(name.as_str())?
.iter()
.find(|(_, func_id)| func_id == id)
.map(|(func, _)| func.clone()),
FunctionID::Factory {
name,
id,
params,
args_type,
} => {
let factory = self.factories.get(name.as_str())?.get(*id)?;
let factory = self
.factories
.get(name.as_str())?
.iter()
.find(|(_, func_id)| func_id == id)
.map(|(func, _)| func)?;
factory(params, args_type)
}
}
Expand All @@ -235,12 +245,12 @@ impl FunctionRegistry {
let mut candidates = Vec::new();

if let Some(funcs) = self.funcs.get(&name) {
candidates.extend(funcs.iter().enumerate().filter_map(|(id, func)| {
candidates.extend(funcs.iter().filter_map(|(func, id)| {
if func.signature.name == name && func.signature.args_type.len() == args.len() {
Some((
FunctionID::Builtin {
name: name.to_string(),
id,
id: *id,
},
func.clone(),
))
Expand All @@ -256,12 +266,12 @@ impl FunctionRegistry {
.map(Expr::data_type)
.cloned()
.collect::<Vec<_>>();
candidates.extend(factories.iter().enumerate().filter_map(|(id, factory)| {
candidates.extend(factories.iter().filter_map(|(factory, id)| {
factory(params, &args_type).map(|func| {
(
FunctionID::Factory {
name: name.to_string(),
id,
id: *id,
params: params.to_vec(),
args_type: args_type.clone(),
},
Expand All @@ -271,6 +281,8 @@ impl FunctionRegistry {
}));
}

candidates.sort_by_key(|(id, _)| id.id());

candidates
}

Expand All @@ -291,10 +303,11 @@ impl FunctionRegistry {
name: &str,
factory: impl Fn(&[usize], &[DataType]) -> Option<Arc<Function>> + 'static + Send + Sync,
) {
let id = self.next_function_id(name);
self.factories
.entry(name.to_string())
.or_insert_with(Vec::new)
.push(Box::new(factory));
.push((Box::new(factory), id));
}

pub fn register_aliases(&mut self, fn_name: &str, aliases: &[&str]) {
Expand Down Expand Up @@ -329,33 +342,65 @@ impl FunctionRegistry {
self.auto_try_cast_rules.extend(auto_try_cast_rules);
}

pub fn next_function_id(&self, name: &str) -> usize {
self.funcs.get(name).map(|funcs| funcs.len()).unwrap_or(0)
+ self.factories.get(name).map(|f| f.len()).unwrap_or(0)
}

pub fn check_ambiguity(&self) {
for (name, funcs) in &self.funcs {
let auto_cast_rules = self.get_auto_cast_rules(name);
for (i, former) in funcs.iter().enumerate() {
for latter in funcs.iter().skip(i + 1) {
if former.signature.args_type.len() == latter.signature.args_type.len()
&& former
.signature
.args_type
for (former, former_id) in funcs {
for latter in funcs
.iter()
.filter(|(_, id)| id > former_id)
.map(|(func, _)| func.clone())
.chain(
self.factories
.get(name)
.map(Vec::as_slice)
.unwrap_or(&[])
.iter()
.zip(latter.signature.args_type.iter())
.all(|(former_arg, latter_arg)| {
can_auto_cast_to(latter_arg, former_arg, auto_cast_rules)
})
{
panic!(
"Ambiguous signatures for function:\n- {}\n- {}\n\
Suggestion: swap the order of the overloads.",
former.signature, latter.signature
);
.filter(|(_, id)| id > former_id)
.filter_map(|(factory, _)| factory(&[], &former.signature.args_type)),
)
{
if former.signature.args_type.len() == latter.signature.args_type.len() {
if let Ok(subst) = try_unify_signature(
latter.signature.args_type.iter(),
former.signature.args_type.iter(),
auto_cast_rules,
) {
if subst.apply(&former.signature.return_type).is_ok()
&& former
.signature
.args_type
.iter()
.all(|sig_ty| subst.apply(sig_ty).is_ok())
{
panic!(
"Ambiguous signatures for function:\n- {}\n- {}\n\
Suggestion: swap the order of the overloads.",
former.signature, latter.signature
);
}
}
}
}
}
}
}
}

impl FunctionID {
pub fn id(&self) -> usize {
match self {
FunctionID::Builtin { id, .. } => *id,
FunctionID::Factory { id, .. } => *id,
}
}
}

pub fn wrap_nullable<F>(f: F) -> impl Fn(&[ValueRef<AnyType>], &mut EvalContext) -> Value<AnyType>
where F: Fn(&[ValueRef<AnyType>], &mut EvalContext) -> Value<AnyType> {
move |args, ctx| {
Expand Down
Loading