From da6d69dd4e4441803dfded825ad3c2d3a5c5d850 Mon Sep 17 00:00:00 2001 From: Josh Pschorr Date: Fri, 15 Nov 2024 16:24:27 -0800 Subject: [PATCH] Change modeling of Literals in the AST remove ambiguity --- partiql-ast/src/ast.rs | 31 +++++++++- partiql-ast/src/pretty.rs | 61 +++++++++++++++++++ partiql-logical-planner/src/lower.rs | 69 +++++++--------------- partiql-parser/src/parse/parse_util.rs | 75 +++++++++++++++++++++++- partiql-parser/src/parse/partiql.lalrpop | 62 +++++++++++++++++--- 5 files changed, 238 insertions(+), 60 deletions(-) diff --git a/partiql-ast/src/ast.rs b/partiql-ast/src/ast.rs index b4db2608..25475279 100644 --- a/partiql-ast/src/ast.rs +++ b/partiql-ast/src/ast.rs @@ -454,16 +454,41 @@ pub enum Lit { #[visit(skip)] HexStringLit(String), #[visit(skip)] - StructLit(AstNode), + StructLit(AstNode), #[visit(skip)] - BagLit(AstNode), + BagLit(AstNode), #[visit(skip)] - ListLit(AstNode), + ListLit(AstNode), /// E.g. `TIME WITH TIME ZONE` in `SELECT TIME WITH TIME ZONE '12:00' FROM ...` #[visit(skip)] TypedLit(String, Type), } +#[derive(Clone, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct LitField { + pub first: String, + pub second: AstNode, +} + +#[derive(Clone, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct StructLit { + pub fields: Vec, +} + +#[derive(Clone, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct BagLit { + pub values: Vec, +} + +#[derive(Clone, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct ListLit { + pub values: Vec, +} + #[derive(Visit, Clone, Debug, PartialEq, Eq)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct VarRef { diff --git a/partiql-ast/src/pretty.rs b/partiql-ast/src/pretty.rs index e4913149..3f1c818f 100644 --- a/partiql-ast/src/pretty.rs +++ b/partiql-ast/src/pretty.rs @@ -699,6 +699,38 @@ impl PrettyDoc for StructExprPair { } } +impl PrettyDoc for StructLit { + fn pretty_doc<'b, D, A>(&'b self, arena: &'b D) -> DocBuilder<'b, D, A> + where + D: DocAllocator<'b, A>, + D::Doc: Clone, + A: Clone, + { + let wrapped = self.fields.iter().map(|p| unsafe { + let x: &'b StructLitField = std::mem::transmute(p); + x + }); + pretty_seq(wrapped, "{", "}", ",", PRETTY_INDENT_MINOR_NEST, arena) + } +} + +pub struct StructLitField(pub LitField); + +impl PrettyDoc for StructLitField { + fn pretty_doc<'b, D, A>(&'b self, arena: &'b D) -> DocBuilder<'b, D, A> + where + D: DocAllocator<'b, A>, + D::Doc: Clone, + A: Clone, + { + let k = self.0.first.pretty_doc(arena); + let v = self.0.second.pretty_doc(arena); + let sep = arena.text(": "); + + k.append(sep).group().append(v).group() + } +} + impl PrettyDoc for Bag { fn pretty_doc<'b, D, A>(&'b self, arena: &'b D) -> DocBuilder<'b, D, A> where @@ -728,6 +760,35 @@ impl PrettyDoc for List { } } +impl PrettyDoc for BagLit { + fn pretty_doc<'b, D, A>(&'b self, arena: &'b D) -> DocBuilder<'b, D, A> + where + D: DocAllocator<'b, A>, + D::Doc: Clone, + A: Clone, + { + pretty_seq( + &self.values, + "<<", + ">>", + ",", + PRETTY_INDENT_MINOR_NEST, + arena, + ) + } +} + +impl PrettyDoc for ListLit { + fn pretty_doc<'b, D, A>(&'b self, arena: &'b D) -> DocBuilder<'b, D, A> + where + D: DocAllocator<'b, A>, + D::Doc: Clone, + A: Clone, + { + pretty_seq(&self.values, "[", "]", ",", PRETTY_INDENT_MINOR_NEST, arena) + } +} + impl PrettyDoc for Sexp { fn pretty_doc<'b, D, A>(&'b self, _arena: &'b D) -> DocBuilder<'b, D, A> where diff --git a/partiql-logical-planner/src/lower.rs b/partiql-logical-planner/src/lower.rs index 23a6a0bf..3f3d1284 100644 --- a/partiql-logical-planner/src/lower.rs +++ b/partiql-logical-planner/src/lower.rs @@ -1893,30 +1893,14 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> { } fn lit_to_value(lit: &Lit) -> Result { - fn expect_lit(v: &Expr) -> Result { - match v { - Expr::Lit(l) => lit_to_value(&l.node), - _ => Err(AstTransformError::IllegalState( - "non literal in literal aggregate".to_string(), - )), - } - } - - fn tuple_pair(pair: &ast::ExprPair) -> Option> { - let key = match expect_lit(pair.first.as_ref()) { - Ok(Value::String(s)) => s.as_ref().clone(), - Ok(_) => { - return Some(Err(AstTransformError::IllegalState( - "non string literal in literal struct key".to_string(), - ))) - } - Err(e) => return Some(Err(e)), - }; - - match expect_lit(pair.second.as_ref()) { - Ok(Value::Missing) => None, - Ok(val) => Some(Ok((key, val))), - Err(e) => Some(Err(e)), + fn tuple_pair(field: &ast::LitField) -> Option> { + let key = field.first.clone(); + match &field.second.node { + Lit::Missing => None, + value => match lit_to_value(value) { + Ok(value) => Some(Ok((key, value))), + Err(e) => Some(Err(e)), + }, } } @@ -1947,21 +1931,13 @@ fn lit_to_value(lit: &Lit) -> Result { )) } Lit::BagLit(b) => { - let bag: Result = b - .node - .values - .iter() - .map(|l| expect_lit(l.as_ref())) - .collect(); + let bag: Result = + b.node.values.iter().map(lit_to_value).collect(); Value::from(bag?) } Lit::ListLit(l) => { - let l: Result = l - .node - .values - .iter() - .map(|l| expect_lit(l.as_ref())) - .collect(); + let l: Result = + l.node.values.iter().map(lit_to_value).collect(); Value::from(l?) } Lit::StructLit(s) => { @@ -2006,6 +1982,7 @@ fn parse_embedded_ion_str(contents: &str) -> Result { mod tests { use super::*; use crate::LogicalPlanner; + use assert_matches::assert_matches; use partiql_catalog::catalog::{PartiqlCatalog, TypeEnvEntry}; use partiql_logical::BindingsOp::Project; use partiql_logical::ValueExpr; @@ -2023,13 +2000,13 @@ mod tests { assert!(logical.is_err()); let lowering_errs = logical.expect_err("Expect errs").errors; assert_eq!(lowering_errs.len(), 2); - assert_eq!( + assert_matches!( lowering_errs.first(), - Some(&AstTransformError::UnsupportedFunction("foo".to_string())) + Some(AstTransformError::UnsupportedFunction(fnc)) if fnc == "foo" ); - assert_eq!( + assert_matches!( lowering_errs.get(1), - Some(&AstTransformError::UnsupportedFunction("bar".to_string())) + Some(AstTransformError::UnsupportedFunction(fnc)) if fnc == "bar" ); } @@ -2045,17 +2022,13 @@ mod tests { assert!(logical.is_err()); let lowering_errs = logical.expect_err("Expect errs").errors; assert_eq!(lowering_errs.len(), 2); - assert_eq!( + assert_matches!( lowering_errs.first(), - Some(&AstTransformError::InvalidNumberOfArguments( - "abs".to_string() - )) + Some(AstTransformError::InvalidNumberOfArguments(fnc)) if fnc == "abs" ); - assert_eq!( + assert_matches!( lowering_errs.get(1), - Some(&AstTransformError::InvalidNumberOfArguments( - "mod".to_string() - )) + Some(AstTransformError::InvalidNumberOfArguments(fnc)) if fnc == "mod" ); } diff --git a/partiql-parser/src/parse/parse_util.rs b/partiql-parser/src/parse/parse_util.rs index 2bb1da0f..13137f04 100644 --- a/partiql-parser/src/parse/parse_util.rs +++ b/partiql-parser/src/parse/parse_util.rs @@ -1,9 +1,11 @@ use partiql_ast::ast; use crate::parse::parser_state::ParserState; +use crate::ParseError; use bitflags::bitflags; +use partiql_ast::ast::{AstNode, Expr, Lit, LitField}; use partiql_common::node::NodeIdGenerator; -use partiql_common::syntax::location::ByteOffset; +use partiql_common::syntax::location::{ByteOffset, BytePosition}; bitflags! { /// Set of AST node attributes to use as synthesized attributes. @@ -33,7 +35,7 @@ pub(crate) struct Synth { impl Synth { #[inline] - pub fn new(data: T, attrs: Attrs) -> Self { + fn new(data: T, attrs: Attrs) -> Self { Synth { data, attrs } } @@ -41,6 +43,17 @@ impl Synth { pub fn empty(data: T) -> Self { Self::new(data, Attrs::empty()) } + + #[inline] + pub fn lit(data: T) -> Self { + Self::new(data, Attrs::LIT) + } + + pub fn map_data(self, f: impl FnOnce(T) -> U) -> Synth { + let Self { data, attrs } = self; + let data = f(data); + Synth::new(data, attrs) + } } impl FromIterator> for Synth> { @@ -170,3 +183,61 @@ pub(crate) fn strip_expr(q: ast::AstNode) -> Box { Box::new(ast::Expr::Query(q)) } } + +#[inline] +#[track_caller] +fn illegal_literal<'a, T>() -> Result> { + Err(ParseError::IllegalState("Expected literal".to_string())) +} + +pub(crate) type LitFlattenResult<'a, T> = Result>; +#[inline] +pub(crate) fn struct_to_lit<'a>(strct: ast::Struct) -> LitFlattenResult<'a, ast::StructLit> { + strct + .fields + .into_iter() + .map(exprpair_to_lit) + .collect::>>() + .map(|fields| ast::StructLit { fields }) +} + +#[inline] +pub(crate) fn bag_to_lit<'a>(bag: ast::Bag) -> LitFlattenResult<'a, ast::BagLit> { + bag.values + .into_iter() + .map(|v| expr_to_lit(*v).map(|n| n.node)) + .collect::>>() + .map(|values| ast::BagLit { values }) +} + +#[inline] +pub(crate) fn list_to_lit<'a>(list: ast::List) -> LitFlattenResult<'a, ast::ListLit> { + list.values + .into_iter() + .map(|v| expr_to_lit(*v).map(|n| n.node)) + .collect::>>() + .map(|values| ast::ListLit { values }) +} + +#[inline] +pub(crate) fn exprpair_to_lit<'a>(pair: ast::ExprPair) -> LitFlattenResult<'a, ast::LitField> { + let ast::ExprPair { first, second } = pair; + let (first, second) = (expr_to_litstr(*first)?, expr_to_lit(*second)?); + Ok(ast::LitField { first, second }) +} + +#[inline] +pub(crate) fn expr_to_litstr<'a>(expr: ast::Expr) -> LitFlattenResult<'a, String> { + match expr_to_lit(expr)?.node { + Lit::CharStringLit(s) | Lit::NationalCharStringLit(s) => Ok(s), + _ => illegal_literal(), + } +} + +#[inline] +pub(crate) fn expr_to_lit<'a>(expr: ast::Expr) -> LitFlattenResult<'a, ast::AstNode> { + match expr { + Expr::Lit(lit) => Ok(lit), + _ => illegal_literal(), + } +} diff --git a/partiql-parser/src/parse/partiql.lalrpop b/partiql-parser/src/parse/partiql.lalrpop index 7a9708c0..7bfc0175 100644 --- a/partiql-parser/src/parse/partiql.lalrpop +++ b/partiql-parser/src/parse/partiql.lalrpop @@ -9,7 +9,17 @@ use partiql_ast::ast; use partiql_common::syntax::location::{ByteOffset, BytePosition, Location, ToLocated}; -use crate::parse::parse_util::{strip_expr, strip_query, strip_query_set, CallSite, Attrs, Synth}; +use crate::parse::parse_util::{ + strip_expr, + strip_query, + strip_query_set, + struct_to_lit, + bag_to_lit, + list_to_lit, + CallSite, + Attrs, + Synth +}; use crate::parse::parser_state::ParserState; use partiql_common::node::NodeIdGenerator; @@ -568,8 +578,7 @@ ExprQuery: Box = { ExprQuerySynth: Synth> = { => { - let Synth{data, attrs} = e; - Synth::new(Box::new(data), attrs) + e.map_data(|e| Box::new(e)) } } @@ -865,13 +874,39 @@ ExprPrecedence01: Synth = { ExprTerm: Synth = { => Synth::empty(s), - => Synth::new(ast::Expr::Lit( state.node(lit, lo..hi) ), Attrs::LIT), + => Synth::lit(ast::Expr::Lit( state.node(lit, lo..hi) )), => Synth::empty(v), => { if c.attrs.contains(Attrs::LIT) { match c.data { - ast::Expr::List(l) => Synth::new(ast::Expr::Lit( state.node(ast::Lit::ListLit(l), lo..hi) ), Attrs::LIT), - ast::Expr::Bag(b) => Synth::new(ast::Expr::Lit( state.node(ast::Lit::BagLit(b), lo..hi) ), Attrs::LIT), + ast::Expr::List(l) => { + match list_to_lit(l.node) { + Ok(list_lit) => { + let list_lit = state.node(list_lit, lo..hi); + let lit = state.node(ast::Lit::ListLit(list_lit), lo..hi); + Synth::lit(ast::Expr::Lit( lit )) + }, + Err(e) => { + let err = lpop::ErrorRecovery{error: e.into(), dropped_tokens: Default::default()}; + state.errors.push(err); + Synth::empty(ast::Expr::Error) + } + } + }, + ast::Expr::Bag(b) => { + match bag_to_lit(b.node) { + Ok(bag_lit) => { + let bag_lit = state.node(bag_lit, lo..hi); + let lit = state.node(ast::Lit::BagLit(bag_lit), lo..hi); + Synth::lit(ast::Expr::Lit( lit )) + }, + Err(e) => { + let err = lpop::ErrorRecovery{error: e.into(), dropped_tokens: Default::default()}; + state.errors.push(err); + Synth::empty(ast::Expr::Error) + } + } + }, _ => unreachable!(), } } else { @@ -881,7 +916,20 @@ ExprTerm: Synth = { => { if t.attrs.contains(Attrs::LIT) { match t.data { - ast::Expr::Struct(s) => Synth::new(ast::Expr::Lit( state.node(ast::Lit::StructLit(s), lo..hi) ), Attrs::LIT), + ast::Expr::Struct(s) => { + match struct_to_lit(s.node) { + Ok(struct_lit) => { + let struct_lit = state.node(struct_lit, lo..hi); + let lit = state.node(ast::Lit::StructLit(struct_lit), lo..hi); + Synth::lit(ast::Expr::Lit( lit )) + }, + Err(e) => { + let err = lpop::ErrorRecovery{error: e.into(), dropped_tokens: Default::default()}; + state.errors.push(err); + Synth::empty(ast::Expr::Error) + } + } + }, _ => unreachable!(), } } else {