diff --git a/crates/rayexec_execution/src/config/session.rs b/crates/rayexec_execution/src/config/session.rs index 325697497..da2be9443 100644 --- a/crates/rayexec_execution/src/config/session.rs +++ b/crates/rayexec_execution/src/config/session.rs @@ -15,6 +15,7 @@ pub struct SessionConfig { pub partitions: u64, pub batch_size: u64, pub verify_optimized_plan: bool, + pub enable_function_chaining: bool, } impl SessionConfig { @@ -30,6 +31,7 @@ impl SessionConfig { partitions: executor.default_partitions() as u64, batch_size: 4096, verify_optimized_plan: false, + enable_function_chaining: true, } } @@ -103,6 +105,7 @@ static GET_SET_FUNCTIONS: LazyLock> = La insert_setting::(&mut map); insert_setting::(&mut map); insert_setting::(&mut map); + insert_setting::(&mut map); map }); @@ -218,6 +221,23 @@ impl SessionSetting for VerifyOptimizedPlan { } } +pub struct EnableFunctionChaining; + +impl SessionSetting for EnableFunctionChaining { + const NAME: &'static str = "enable_function_chaining"; + const DESCRIPTION: &'static str = "If function chaining syntax is enabled."; + + fn set_from_scalar(scalar: ScalarValue, conf: &mut SessionConfig) -> Result<()> { + let val = scalar.try_as_bool()?; + conf.enable_function_chaining = val; + Ok(()) + } + + fn get_as_scalar(conf: &SessionConfig) -> OwnedScalarValue { + conf.enable_function_chaining.into() + } +} + #[cfg(test)] mod tests { use super::*; @@ -230,6 +250,7 @@ mod tests { partitions: 8, batch_size: 4096, verify_optimized_plan: false, + enable_function_chaining: true, } } diff --git a/crates/rayexec_execution/src/engine/session.rs b/crates/rayexec_execution/src/engine/session.rs index d4d63f6ea..89a012fc1 100644 --- a/crates/rayexec_execution/src/engine/session.rs +++ b/crates/rayexec_execution/src/engine/session.rs @@ -30,7 +30,7 @@ use crate::logical::logical_set::VariableOrAll; use crate::logical::operator::{LogicalOperator, Node}; use crate::logical::planner::plan_statement::StatementPlanner; use crate::logical::resolver::resolve_context::ResolveContext; -use crate::logical::resolver::{ResolveMode, ResolvedStatement, Resolver}; +use crate::logical::resolver::{ResolveConfig, ResolveMode, ResolvedStatement, Resolver}; use crate::optimizer::Optimizer; use crate::runtime::time::Timer; use crate::runtime::{PipelineExecutor, Runtime}; @@ -243,6 +243,9 @@ where &tx, &self.context, self.registry.get_file_handlers(), + ResolveConfig { + enable_function_chaining: self.config.enable_function_chaining, + }, ) .resolve_statement(stmt.statement.clone()) .await?; diff --git a/crates/rayexec_execution/src/logical/binder/bind_query/bind_select_list.rs b/crates/rayexec_execution/src/logical/binder/bind_query/bind_select_list.rs index e417cb345..e3ffe1ded 100644 --- a/crates/rayexec_execution/src/logical/binder/bind_query/bind_select_list.rs +++ b/crates/rayexec_execution/src/logical/binder/bind_query/bind_select_list.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; -use rayexec_error::Result; +use rayexec_error::{RayexecError, Result}; use rayexec_parser::ast; use super::select_expr_expander::ExpandedSelectExpr; @@ -8,10 +8,11 @@ use super::select_list::SelectList; use crate::expr::column_expr::ColumnExpr; use crate::expr::Expression; use crate::logical::binder::bind_context::{BindContext, BindScopeRef}; -use crate::logical::binder::column_binder::DefaultColumnBinder; +use crate::logical::binder::column_binder::{DefaultColumnBinder, ExpressionColumnBinder}; use crate::logical::binder::expr_binder::{BaseExpressionBinder, RecursionContext}; use crate::logical::binder::table_list::TableRef; use crate::logical::resolver::resolve_context::ResolveContext; +use crate::logical::resolver::ResolvedMeta; #[derive(Debug)] pub struct SelectListBinder<'a> { @@ -77,13 +78,19 @@ impl<'a> SelectListBinder<'a> { // Bind the expressions. let expr_binder = BaseExpressionBinder::new(self.current, self.resolve_context); let mut exprs = Vec::with_capacity(projections.len()); - for proj in projections { + for (idx, proj) in projections.into_iter().enumerate() { match proj { ExpandedSelectExpr::Expr { expr, .. } => { + let mut col_binder = SelectAliasColumnBinder { + current_idx: idx, + alias_map: &alias_map, + previous_exprs: &exprs, + }; + let expr = expr_binder.bind_expression( bind_context, &expr, - &mut DefaultColumnBinder, + &mut col_binder, RecursionContext { allow_windows: true, allow_aggregates: true, @@ -242,3 +249,76 @@ impl<'a> SelectListBinder<'a> { Ok(()) } } + +/// Column binder that allows binding to previously defined user aliases. +/// +/// If an ident isn't found in the alias map, then default column binding is +/// used. +/// +/// Aliases are only checked if normal column binding cannot find a column. +#[derive(Debug, Clone, Copy)] +struct SelectAliasColumnBinder<'a> { + /// Index of the expression we're currently planning in the select list. + /// + /// Used to determine if an alias is valid to use. + current_idx: usize, + /// User provided aliases. + alias_map: &'a HashMap, + /// Previously planned expressions. + previous_exprs: &'a [Expression], +} + +impl ExpressionColumnBinder for SelectAliasColumnBinder<'_> { + fn bind_from_root_literal( + &mut self, + bind_scope: BindScopeRef, + bind_context: &mut BindContext, + literal: &ast::Literal, + ) -> Result> { + DefaultColumnBinder.bind_from_root_literal(bind_scope, bind_context, literal) + } + + fn bind_from_ident( + &mut self, + bind_scope: BindScopeRef, + bind_context: &mut BindContext, + ident: &ast::Ident, + _recur: RecursionContext, + ) -> Result> { + let col = ident.as_normalized_string(); + + match DefaultColumnBinder.bind_column(bind_scope, bind_context, None, &col)? { + Some(expr) => Ok(Some(expr)), + None => { + match self.alias_map.get(&col) { + Some(&col_idx) => { + if col_idx < self.current_idx { + // Valid alias reference, use the existing expression. + let aliased_expr = + self.previous_exprs.get(col_idx).ok_or_else(|| { + RayexecError::new("Missing select expression?") + .with_field("idx", col_idx) + })?; + + Ok(Some(aliased_expr.clone())) + } else { + // Not a valid alias expression. + Err(RayexecError::new(format!("'{col}' can only be referenced after it's been defined in the SELECT list"))) + } + } + None => Ok(None), + } + } + } + } + + fn bind_from_idents( + &mut self, + bind_scope: BindScopeRef, + bind_context: &mut BindContext, + idents: &[ast::Ident], + recur: RecursionContext, + ) -> Result> { + DefaultColumnBinder.bind_from_idents(bind_scope, bind_context, idents, recur) + } +} diff --git a/crates/rayexec_execution/src/logical/binder/column_binder.rs b/crates/rayexec_execution/src/logical/binder/column_binder.rs index a80c38d19..549435272 100644 --- a/crates/rayexec_execution/src/logical/binder/column_binder.rs +++ b/crates/rayexec_execution/src/logical/binder/column_binder.rs @@ -89,7 +89,11 @@ impl ExpressionColumnBinder for DefaultColumnBinder { } impl DefaultColumnBinder { - fn bind_column( + /// Binds a column with the given name and optional table alias. + /// + /// This will handle appending correlated columns to the bind context as + /// necessary. + pub fn bind_column( &self, bind_scope: BindScopeRef, bind_context: &mut BindContext, diff --git a/crates/rayexec_execution/src/logical/resolver/expr_resolver.rs b/crates/rayexec_execution/src/logical/resolver/expr_resolver.rs index 70ae225c5..a1f28779a 100644 --- a/crates/rayexec_execution/src/logical/resolver/expr_resolver.rs +++ b/crates/rayexec_execution/src/logical/resolver/expr_resolver.rs @@ -12,6 +12,7 @@ use crate::database::catalog_entry::CatalogEntryType; use crate::logical::binder::expr_binder::BaseExpressionBinder; use crate::logical::operator::LocationRequirement; +#[derive(Debug)] pub struct ExpressionResolver<'a> { resolver: &'a Resolver<'a>, } @@ -237,7 +238,10 @@ impl<'a> ExpressionResolver<'a> { op, right: Box::new(Box::pin(self.resolve_expression(*right, resolve_context)).await?), }), - ast::Expr::Function(func) => self.resolve_function(func, resolve_context).await, + ast::Expr::Function(func) => { + self.resolve_scalar_or_aggregate_function(func, resolve_context) + .await + } ast::Expr::Subquery(subquery) => self.resolve_subquery(subquery, resolve_context).await, ast::Expr::Exists { subquery, @@ -556,20 +560,82 @@ impl<'a> ExpressionResolver<'a> { } } - async fn resolve_function( + async fn resolve_scalar_or_aggregate_function( &self, - func: Box>, + mut func: Box>, resolve_context: &mut ResolveContext, ) -> Result> { // TODO: Search path (with system being the first to check) - if func.reference.0.len() != 1 { - return Err(RayexecError::new( - "Qualified function names not yet supported", - )); + let (catalog, schema, func_name) = match func.reference.0.len() { + 0 => return Err(RayexecError::new("Missing idents for function reference")), // Shouldn't happen. + 1 => ( + "system".to_string(), + "glare_catalog".to_string(), + func.reference.0[0].as_normalized_string(), + ), + 2 => ( + "system".to_string(), + func.reference.0[0].as_normalized_string(), + func.reference.0[1].as_normalized_string(), + ), + 3 => ( + func.reference.0[0].as_normalized_string(), + func.reference.0[1].as_normalized_string(), + func.reference.0[2].as_normalized_string(), + ), + _ => { + // TODO: This could technically be from chained syntax on a + // fully qualified column. + return Err(RayexecError::new("Too many idents for function reference") + .with_field("idents", func.reference.to_string())); + } + }; + + let context = self.resolver.context; + + // See if we can resolve the catalog & schema. If we can't assume we're + // using chained function syntax. + // + // TODO: Make `get_database` return Option. + // TODO: We should be exhaustive about what's part of the qualified + // function call vs what's part of the column. + let is_qualified = func.reference.0.len() > 1; + if self.resolver.config.enable_function_chaining + && is_qualified + && (!context.database_exists(&catalog) + || context + .get_database(&catalog)? + .catalog + .get_schema(self.resolver.tx, &schema)? + .is_none()) + { + let unqualified_name = func.reference.0.pop().unwrap(); // Length checked above. + let unqualified_ref = ast::ObjectReference(vec![unqualified_name]); + + let mut prefix_ref = std::mem::replace(&mut func.reference, unqualified_ref); + + // Now add the prefix we took from the reference as the first + // argument to the function. + + // TODO: Expr binder should probably take of this for us. + let arg_expr = match prefix_ref.0.len() { + 1 => ast::Expr::Ident(prefix_ref.0.pop().unwrap()), + _ => ast::Expr::CompoundIdent(prefix_ref.0), + }; + + func.args.insert( + 0, + ast::FunctionArg::Unnamed { + arg: ast::FunctionArgExpr::Expr(arg_expr), + }, + ); + + // Now try to resolve with just the unqualified reference. + let resolved = + Box::pin(self.resolve_scalar_or_aggregate_function(func, resolve_context)).await?; + + return Ok(resolved); } - let func_name = &func.reference.0[0].as_normalized_string(); - let catalog = "system"; - let schema = "glare_catalog"; let filter = self .resolve_optional_expression(func.filter.map(|e| *e), resolve_context) @@ -582,16 +648,14 @@ impl<'a> ExpressionResolver<'a> { }; let args = Box::pin(self.resolve_function_args(func.args, resolve_context)).await?; - let schema_ent = self - .resolver - .context - .get_database(catalog)? + let schema_ent = context + .get_database(&catalog)? .catalog - .get_schema(self.resolver.tx, schema)? + .get_schema(self.resolver.tx, &schema)? .ok_or_else(|| RayexecError::new(format!("Missing schema: {schema}")))?; // Check if this is a special function. - if let Some(special) = SpecialBuiltinFunction::try_from_name(func_name) { + if let Some(special) = SpecialBuiltinFunction::try_from_name(&func_name) { let resolve_idx = resolve_context .functions .push_resolved(ResolvedFunction::Special(special), LocationRequirement::Any); @@ -606,7 +670,7 @@ impl<'a> ExpressionResolver<'a> { } // Now check scalars. - if let Some(scalar) = schema_ent.get_scalar_function(self.resolver.tx, func_name)? { + if let Some(scalar) = schema_ent.get_scalar_function(self.resolver.tx, &func_name)? { // TODO: Allow unresolved scalars? // TODO: This also assumes scalars (and aggs) are the same everywhere, which // they probably should be for now. @@ -624,7 +688,7 @@ impl<'a> ExpressionResolver<'a> { } // Now check aggregates. - if let Some(aggregate) = schema_ent.get_aggregate_function(self.resolver.tx, func_name)? { + if let Some(aggregate) = schema_ent.get_aggregate_function(self.resolver.tx, &func_name)? { // TODO: Allow unresolved aggregates? let resolve_idx = resolve_context.functions.push_resolved( ResolvedFunction::Aggregate( @@ -651,7 +715,7 @@ impl<'a> ExpressionResolver<'a> { CatalogEntryType::ScalarFunction, CatalogEntryType::AggregateFunction, ], - func_name, + &func_name, )) } diff --git a/crates/rayexec_execution/src/logical/resolver/mod.rs b/crates/rayexec_execution/src/logical/resolver/mod.rs index 75208f435..df2c71a70 100644 --- a/crates/rayexec_execution/src/logical/resolver/mod.rs +++ b/crates/rayexec_execution/src/logical/resolver/mod.rs @@ -112,6 +112,11 @@ impl ResolveMode { } } +#[derive(Debug)] +pub struct ResolveConfig { + pub enable_function_chaining: bool, +} + /// Resolves references in a raw SQL AST with entries in the catalog. #[derive(Debug)] pub struct Resolver<'a> { @@ -119,6 +124,7 @@ pub struct Resolver<'a> { pub tx: &'a CatalogTx, pub context: &'a DatabaseContext, pub file_handlers: &'a FileHandlers, + pub config: ResolveConfig, } impl<'a> Resolver<'a> { @@ -127,12 +133,14 @@ impl<'a> Resolver<'a> { tx: &'a CatalogTx, context: &'a DatabaseContext, file_handlers: &'a FileHandlers, + config: ResolveConfig, ) -> Self { Resolver { resolve_mode, tx, context, file_handlers, + config, } } diff --git a/crates/rayexec_execution/src/logical/resolver/resolve_hybrid.rs b/crates/rayexec_execution/src/logical/resolver/resolve_hybrid.rs index d4cfaf1a0..e07ec7230 100644 --- a/crates/rayexec_execution/src/logical/resolver/resolve_hybrid.rs +++ b/crates/rayexec_execution/src/logical/resolver/resolve_hybrid.rs @@ -13,7 +13,7 @@ use crate::database::{Database, DatabaseContext}; use crate::datasource::{DataSourceRegistry, FileHandlers}; use crate::functions::table::TableFunctionPlanner; use crate::logical::operator::LocationRequirement; -use crate::logical::resolver::ResolveMode; +use crate::logical::resolver::{ResolveConfig, ResolveMode}; /// Extends a context by attaching additional databases using information /// provided by partially bound objects supplied by the client. @@ -139,7 +139,15 @@ impl<'a> HybridResolver<'a> { // Note we're using bindmode normal here since everything we attempt to // bind in this resolver should succeed. HybridResolver { - resolver: Resolver::new(ResolveMode::Normal, tx, context, EMPTY_FILE_HANDLER_REF), + resolver: Resolver::new( + ResolveMode::Normal, + tx, + context, + EMPTY_FILE_HANDLER_REF, + ResolveConfig { + enable_function_chaining: true, // TODO: We'll need to get this from the client. + }, + ), } } diff --git a/slt/standard/functions/chaining.slt b/slt/standard/functions/chaining.slt new file mode 100644 index 000000000..99b013fc9 --- /dev/null +++ b/slt/standard/functions/chaining.slt @@ -0,0 +1,89 @@ +# Function chaining. + +statement ok +CREATE TEMP TABLE t (a INT, b TEXT); + +statement ok +INSERT INTO t VALUES + (3, 'cat'), + (4, 'dog'), + (NULL, 'mouse'), + (5, NULL); + +query T +SELECT b.upper() FROM t ORDER BY 1; +---- +CAT +DOG +MOUSE +NULL + +query T +SELECT t.b.upper() FROM t ORDER BY 1; +---- +CAT +DOG +MOUSE +NULL + +statement error Cannot resolve scalar function or aggregate function with name 'missing_function' +SELECT t.b.missing_function() FROM t ORDER BY 1; + +statement error Invalid inputs to 'upper' +SELECT t.b.upper(4) FROM t ORDER BY 1; + +query T +SELECT b.repeat(a) FROM t ORDER BY 1; +---- +catcatcat +dogdogdogdog +NULL +NULL + +query TTT +SELECT + t.b.upper() AS my_upper, + my_upper.repeat(t.a) AS my_repeat, + my_repeat.lower() AS my_lower + FROM t + ORDER BY 1; +---- +CAT CATCATCAT catcatcat +DOG DOGDOGDOGDOG dogdogdogdog +MOUSE NULL NULL +NULL NULL NULL + +# Should work on aggregates too. +query I +SELECT a.sum() FROM t; +---- +12 + +# TODO: Literals +# query T +# SELECT 'hello'.repeat(3); +# ---- + +query T +SELECT 'hello' AS s, s.repeat(3); +---- +hello hellohellohello + +# Ensure we can disable the behavior. +statement ok +SET enable_function_chaining TO false; + +statement error Missing schema +SELECT b.upper() FROM t ORDER BY 1; + +statement error Missing catalog +SELECT t.b.upper() FROM t ORDER BY 1; + +query T +SELECT system.glare_catalog.upper(t.b) FROM t ORDER BY 1; +---- +CAT +DOG +MOUSE +NULL + diff --git a/slt/standard/functions/qualified.slt b/slt/standard/functions/qualified.slt new file mode 100644 index 000000000..0f906a0f8 --- /dev/null +++ b/slt/standard/functions/qualified.slt @@ -0,0 +1,11 @@ +# Ensure we can call functions partially and fully qualified. + +query I +SELECT glare_catalog.abs(-48); +---- +48 + +query I +SELECT system.glare_catalog.abs(-48); +---- +48 diff --git a/slt/standard/select/reference_alias_in_select.slt b/slt/standard/select/reference_alias_in_select.slt new file mode 100644 index 000000000..09aba6aba --- /dev/null +++ b/slt/standard/select/reference_alias_in_select.slt @@ -0,0 +1,38 @@ +# Allow referencing previously defined aliases in later select items. + +query II +SELECT 1 AS a, a + 2; +---- +1 3 + +statement error 'a' can only be referenced after it's been defined in the SELECT list +SELECT a + 2, 1 AS a; + +# TODO: Unsure if this is even wanted. +# query II +# SELECT 3 AS a, (SELECT a + 5); +# ---- + +query III +SELECT 1 AS a, 2 AS a, a + 3; +---- +1 2 5 + +# Prefer unaliased columns. +query TI rowsort +SELECT 'select' AS a, a FROM (VALUES (1), (2)) v(a); +---- +select 1 +select 2 + +# Prefer unaliased columns. +query TI rowsort +SELECT a, 'select' AS a FROM (VALUES (1), (2)) v(a); +---- +1 select +2 select + +query TT +SELECT 'select' AS a, upper(a); +---- +select SELECT