Skip to content

Commit

Permalink
feat: Function dot syntax (#3369)
Browse files Browse the repository at this point in the history
```
>> CREATE TEMP TABLE t (a INT, b TEXT);
┌─────────────────────┐
│ Query success       │
│ No columns returned │
└─────────────────────┘

>> INSERT INTO t VALUES (3, 'cat'), (4, 'dog'), (NULL, 'mouse'), (5, NULL);
┌───────────────┐
│ rows_inserted │
│ UInt64        │
├───────────────┤
│             4 │
└───────────────┘

>> SELECT b.upper() FROM t;
┌───────┐
│ upper │
│ Utf8  │
├───────┤
│ CAT   │
│ NULL  │
│ MOUSE │
│ DOG   │
└───────┘

>> SELECT t.b.upper() AS my_upper, my_upper.repeat(t.a) AS my_repeat, my_repeat.lower() AS my_lower FROM t;
┌──────────┬───────────┬──────────┐
│ my_upper │ my_repeat │ my_lower │
│ Utf8     │ Utf8      │ Utf8     │
├──────────┼───────────┼──────────┤
│ MOUSE    │ NULL      │ NULL     │
│ NULL     │ NULL      │ NULL     │
│ CAT      │ CATCATCAT │ catcatc… │
│ DOG      │ DOGDOGDO… │ dogdogd… │
└──────────┴───────────┴──────────┘
```

Actual chaining will be added later.
  • Loading branch information
scsmithr authored Dec 18, 2024
1 parent 24dba98 commit 7eff2d8
Show file tree
Hide file tree
Showing 10 changed files with 353 additions and 27 deletions.
21 changes: 21 additions & 0 deletions crates/rayexec_execution/src/config/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ pub struct SessionConfig {
pub partitions: u64,
pub batch_size: u64,
pub verify_optimized_plan: bool,
pub enable_function_chaining: bool,
}

impl SessionConfig {
Expand All @@ -30,6 +31,7 @@ impl SessionConfig {
partitions: executor.default_partitions() as u64,
batch_size: 4096,
verify_optimized_plan: false,
enable_function_chaining: true,
}
}

Expand Down Expand Up @@ -103,6 +105,7 @@ static GET_SET_FUNCTIONS: LazyLock<HashMap<&'static str, SettingFunctions>> = La
insert_setting::<AllowNestedLoopJoin>(&mut map);
insert_setting::<Partitions>(&mut map);
insert_setting::<BatchSize>(&mut map);
insert_setting::<EnableFunctionChaining>(&mut map);

map
});
Expand Down Expand Up @@ -218,6 +221,23 @@ impl SessionSetting for VerifyOptimizedPlan {
}
}

/// Session setting backing the `enable_function_chaining` config field.
pub struct EnableFunctionChaining;

impl SessionSetting for EnableFunctionChaining {
    const NAME: &'static str = "enable_function_chaining";
    const DESCRIPTION: &'static str = "If function chaining syntax is enabled.";

    /// Stores the boolean carried by `scalar` into the session config.
    ///
    /// Errors if `scalar` cannot be read as a bool.
    fn set_from_scalar(scalar: ScalarValue, conf: &mut SessionConfig) -> Result<()> {
        conf.enable_function_chaining = scalar.try_as_bool()?;
        Ok(())
    }

    /// Returns the current value of the setting as a scalar.
    fn get_as_scalar(conf: &SessionConfig) -> OwnedScalarValue {
        let enabled = conf.enable_function_chaining;
        enabled.into()
    }
}

#[cfg(test)]
mod tests {
use super::*;
Expand All @@ -230,6 +250,7 @@ mod tests {
partitions: 8,
batch_size: 4096,
verify_optimized_plan: false,
enable_function_chaining: true,
}
}

Expand Down
5 changes: 4 additions & 1 deletion crates/rayexec_execution/src/engine/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use crate::logical::logical_set::VariableOrAll;
use crate::logical::operator::{LogicalOperator, Node};
use crate::logical::planner::plan_statement::StatementPlanner;
use crate::logical::resolver::resolve_context::ResolveContext;
use crate::logical::resolver::{ResolveMode, ResolvedStatement, Resolver};
use crate::logical::resolver::{ResolveConfig, ResolveMode, ResolvedStatement, Resolver};
use crate::optimizer::Optimizer;
use crate::runtime::time::Timer;
use crate::runtime::{PipelineExecutor, Runtime};
Expand Down Expand Up @@ -243,6 +243,9 @@ where
&tx,
&self.context,
self.registry.get_file_handlers(),
ResolveConfig {
enable_function_chaining: self.config.enable_function_chaining,
},
)
.resolve_statement(stmt.statement.clone())
.await?;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
use std::collections::HashMap;

use rayexec_error::Result;
use rayexec_error::{RayexecError, Result};
use rayexec_parser::ast;

use super::select_expr_expander::ExpandedSelectExpr;
use super::select_list::SelectList;
use crate::expr::column_expr::ColumnExpr;
use crate::expr::Expression;
use crate::logical::binder::bind_context::{BindContext, BindScopeRef};
use crate::logical::binder::column_binder::DefaultColumnBinder;
use crate::logical::binder::column_binder::{DefaultColumnBinder, ExpressionColumnBinder};
use crate::logical::binder::expr_binder::{BaseExpressionBinder, RecursionContext};
use crate::logical::binder::table_list::TableRef;
use crate::logical::resolver::resolve_context::ResolveContext;
use crate::logical::resolver::ResolvedMeta;

#[derive(Debug)]
pub struct SelectListBinder<'a> {
Expand Down Expand Up @@ -77,13 +78,19 @@ impl<'a> SelectListBinder<'a> {
// Bind the expressions.
let expr_binder = BaseExpressionBinder::new(self.current, self.resolve_context);
let mut exprs = Vec::with_capacity(projections.len());
for proj in projections {
for (idx, proj) in projections.into_iter().enumerate() {
match proj {
ExpandedSelectExpr::Expr { expr, .. } => {
let mut col_binder = SelectAliasColumnBinder {
current_idx: idx,
alias_map: &alias_map,
previous_exprs: &exprs,
};

let expr = expr_binder.bind_expression(
bind_context,
&expr,
&mut DefaultColumnBinder,
&mut col_binder,
RecursionContext {
allow_windows: true,
allow_aggregates: true,
Expand Down Expand Up @@ -242,3 +249,76 @@ impl<'a> SelectListBinder<'a> {
Ok(())
}
}

/// Column binder that allows binding to previously defined user aliases.
///
/// Normal column binding is always attempted first. The alias map is only
/// consulted when normal binding cannot find a column for a single
/// (unqualified) identifier.
///
/// An alias may only be referenced by expressions that come after it in the
/// select list; referencing it at or before its own position is an error.
#[derive(Debug, Clone, Copy)]
struct SelectAliasColumnBinder<'a> {
/// Index of the expression we're currently planning in the select list.
///
/// Used to determine if an alias is valid to use: only aliases whose index
/// is strictly less than this one may be referenced.
current_idx: usize,
/// User provided aliases, keyed by normalized name. The value is used as an
/// index into `previous_exprs`.
alias_map: &'a HashMap<String, usize>,
/// Previously planned expressions.
previous_exprs: &'a [Expression],
}

impl ExpressionColumnBinder for SelectAliasColumnBinder<'_> {
    /// Root literals never refer to aliases; defer to the default binder.
    fn bind_from_root_literal(
        &mut self,
        bind_scope: BindScopeRef,
        bind_context: &mut BindContext,
        literal: &ast::Literal<ResolvedMeta>,
    ) -> Result<Option<Expression>> {
        DefaultColumnBinder.bind_from_root_literal(bind_scope, bind_context, literal)
    }

    /// Binds a single identifier, falling back to select list aliases.
    ///
    /// Resolution order:
    /// 1. Normal column binding through `DefaultColumnBinder`.
    /// 2. A user alias defined earlier in the select list, in which case the
    ///    already planned expression is cloned.
    ///
    /// Returns `Ok(None)` if the identifier is neither a column nor an alias,
    /// and an error if the alias is defined at or after the current
    /// expression.
    fn bind_from_ident(
        &mut self,
        bind_scope: BindScopeRef,
        bind_context: &mut BindContext,
        ident: &ast::Ident,
        _recur: RecursionContext,
    ) -> Result<Option<Expression>> {
        let col = ident.as_normalized_string();

        // Real columns win over aliases.
        let bound = DefaultColumnBinder.bind_column(bind_scope, bind_context, None, &col)?;
        if let Some(expr) = bound {
            return Ok(Some(expr));
        }

        // Not a column and not an alias either, nothing for us to bind.
        let Some(&col_idx) = self.alias_map.get(&col) else {
            return Ok(None);
        };

        // Alias exists, but is defined at or after the expression being planned.
        if col_idx >= self.current_idx {
            return Err(RayexecError::new(format!(
                "'{col}' can only be referenced after it's been defined in the SELECT list"
            )));
        }

        // Valid alias reference, reuse the expression that was already planned.
        let aliased_expr = self.previous_exprs.get(col_idx).ok_or_else(|| {
            RayexecError::new("Missing select expression?").with_field("idx", col_idx)
        })?;

        Ok(Some(aliased_expr.clone()))
    }

    /// Compound identifiers are never treated as aliases; defer to the default
    /// binder.
    fn bind_from_idents(
        &mut self,
        bind_scope: BindScopeRef,
        bind_context: &mut BindContext,
        idents: &[ast::Ident],
        recur: RecursionContext,
    ) -> Result<Option<Expression>> {
        DefaultColumnBinder.bind_from_idents(bind_scope, bind_context, idents, recur)
    }
}
6 changes: 5 additions & 1 deletion crates/rayexec_execution/src/logical/binder/column_binder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,11 @@ impl ExpressionColumnBinder for DefaultColumnBinder {
}

impl DefaultColumnBinder {
fn bind_column(
/// Binds a column with the given name and optional table alias.
///
/// This will handle appending correlated columns to the bind context as
/// necessary.
pub fn bind_column(
&self,
bind_scope: BindScopeRef,
bind_context: &mut BindContext,
Expand Down
102 changes: 83 additions & 19 deletions crates/rayexec_execution/src/logical/resolver/expr_resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use crate::database::catalog_entry::CatalogEntryType;
use crate::logical::binder::expr_binder::BaseExpressionBinder;
use crate::logical::operator::LocationRequirement;

#[derive(Debug)]
pub struct ExpressionResolver<'a> {
resolver: &'a Resolver<'a>,
}
Expand Down Expand Up @@ -237,7 +238,10 @@ impl<'a> ExpressionResolver<'a> {
op,
right: Box::new(Box::pin(self.resolve_expression(*right, resolve_context)).await?),
}),
ast::Expr::Function(func) => self.resolve_function(func, resolve_context).await,
ast::Expr::Function(func) => {
self.resolve_scalar_or_aggregate_function(func, resolve_context)
.await
}
ast::Expr::Subquery(subquery) => self.resolve_subquery(subquery, resolve_context).await,
ast::Expr::Exists {
subquery,
Expand Down Expand Up @@ -556,20 +560,82 @@ impl<'a> ExpressionResolver<'a> {
}
}

async fn resolve_function(
async fn resolve_scalar_or_aggregate_function(
&self,
func: Box<ast::Function<Raw>>,
mut func: Box<ast::Function<Raw>>,
resolve_context: &mut ResolveContext,
) -> Result<ast::Expr<ResolvedMeta>> {
// TODO: Search path (with system being the first to check)
if func.reference.0.len() != 1 {
return Err(RayexecError::new(
"Qualified function names not yet supported",
));
let (catalog, schema, func_name) = match func.reference.0.len() {
0 => return Err(RayexecError::new("Missing idents for function reference")), // Shouldn't happen.
1 => (
"system".to_string(),
"glare_catalog".to_string(),
func.reference.0[0].as_normalized_string(),
),
2 => (
"system".to_string(),
func.reference.0[0].as_normalized_string(),
func.reference.0[1].as_normalized_string(),
),
3 => (
func.reference.0[0].as_normalized_string(),
func.reference.0[1].as_normalized_string(),
func.reference.0[2].as_normalized_string(),
),
_ => {
// TODO: This could technically be from chained syntax on a
// fully qualified column.
return Err(RayexecError::new("Too many idents for function reference")
.with_field("idents", func.reference.to_string()));
}
};

let context = self.resolver.context;

// See if we can resolve the catalog & schema. If we can't, assume we're
// using chained function syntax.
//
// TODO: Make `get_database` return Option.
// TODO: We should be exhaustive about what's part of the qualified
// function call vs what's part of the column.
let is_qualified = func.reference.0.len() > 1;
if self.resolver.config.enable_function_chaining
&& is_qualified
&& (!context.database_exists(&catalog)
|| context
.get_database(&catalog)?
.catalog
.get_schema(self.resolver.tx, &schema)?
.is_none())
{
let unqualified_name = func.reference.0.pop().unwrap(); // Length checked above.
let unqualified_ref = ast::ObjectReference(vec![unqualified_name]);

let mut prefix_ref = std::mem::replace(&mut func.reference, unqualified_ref);

// Now add the prefix we took from the reference as the first
// argument to the function.

// TODO: Expr binder should probably take care of this for us.
let arg_expr = match prefix_ref.0.len() {
1 => ast::Expr::Ident(prefix_ref.0.pop().unwrap()),
_ => ast::Expr::CompoundIdent(prefix_ref.0),
};

func.args.insert(
0,
ast::FunctionArg::Unnamed {
arg: ast::FunctionArgExpr::Expr(arg_expr),
},
);

// Now try to resolve with just the unqualified reference.
let resolved =
Box::pin(self.resolve_scalar_or_aggregate_function(func, resolve_context)).await?;

return Ok(resolved);
}
let func_name = &func.reference.0[0].as_normalized_string();
let catalog = "system";
let schema = "glare_catalog";

let filter = self
.resolve_optional_expression(func.filter.map(|e| *e), resolve_context)
Expand All @@ -582,16 +648,14 @@ impl<'a> ExpressionResolver<'a> {
};
let args = Box::pin(self.resolve_function_args(func.args, resolve_context)).await?;

let schema_ent = self
.resolver
.context
.get_database(catalog)?
let schema_ent = context
.get_database(&catalog)?
.catalog
.get_schema(self.resolver.tx, schema)?
.get_schema(self.resolver.tx, &schema)?
.ok_or_else(|| RayexecError::new(format!("Missing schema: {schema}")))?;

// Check if this is a special function.
if let Some(special) = SpecialBuiltinFunction::try_from_name(func_name) {
if let Some(special) = SpecialBuiltinFunction::try_from_name(&func_name) {
let resolve_idx = resolve_context
.functions
.push_resolved(ResolvedFunction::Special(special), LocationRequirement::Any);
Expand All @@ -606,7 +670,7 @@ impl<'a> ExpressionResolver<'a> {
}

// Now check scalars.
if let Some(scalar) = schema_ent.get_scalar_function(self.resolver.tx, func_name)? {
if let Some(scalar) = schema_ent.get_scalar_function(self.resolver.tx, &func_name)? {
// TODO: Allow unresolved scalars?
// TODO: This also assumes scalars (and aggs) are the same everywhere, which
// they probably should be for now.
Expand All @@ -624,7 +688,7 @@ impl<'a> ExpressionResolver<'a> {
}

// Now check aggregates.
if let Some(aggregate) = schema_ent.get_aggregate_function(self.resolver.tx, func_name)? {
if let Some(aggregate) = schema_ent.get_aggregate_function(self.resolver.tx, &func_name)? {
// TODO: Allow unresolved aggregates?
let resolve_idx = resolve_context.functions.push_resolved(
ResolvedFunction::Aggregate(
Expand All @@ -651,7 +715,7 @@ impl<'a> ExpressionResolver<'a> {
CatalogEntryType::ScalarFunction,
CatalogEntryType::AggregateFunction,
],
func_name,
&func_name,
))
}

Expand Down
Loading

0 comments on commit 7eff2d8

Please sign in to comment.