Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change modeling of Literals in the AST remove ambiguity #517

Merged
merged 1 commit into from
Nov 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 28 additions & 3 deletions partiql-ast/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -454,16 +454,41 @@ pub enum Lit {
#[visit(skip)]
HexStringLit(String),
#[visit(skip)]
StructLit(AstNode<Struct>),
StructLit(AstNode<StructLit>),
#[visit(skip)]
BagLit(AstNode<Bag>),
BagLit(AstNode<BagLit>),
#[visit(skip)]
ListLit(AstNode<List>),
ListLit(AstNode<ListLit>),
/// E.g. `TIME WITH TIME ZONE` in `SELECT TIME WITH TIME ZONE '12:00' FROM ...`
#[visit(skip)]
TypedLit(String, Type),
}

#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct LitField {
pub first: String,
pub second: AstNode<Lit>,
}

#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct StructLit {
pub fields: Vec<LitField>,
}

#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct BagLit {
pub values: Vec<Lit>,
}

#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ListLit {
pub values: Vec<Lit>,
}

#[derive(Visit, Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct VarRef {
Expand Down
61 changes: 61 additions & 0 deletions partiql-ast/src/pretty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -699,6 +699,38 @@ impl PrettyDoc for StructExprPair {
}
}

impl PrettyDoc for StructLit {
fn pretty_doc<'b, D, A>(&'b self, arena: &'b D) -> DocBuilder<'b, D, A>
where
D: DocAllocator<'b, A>,
D::Doc: Clone,
A: Clone,
{
let wrapped = self.fields.iter().map(|p| unsafe {
let x: &'b StructLitField = std::mem::transmute(p);
x
});
pretty_seq(wrapped, "{", "}", ",", PRETTY_INDENT_MINOR_NEST, arena)
}
}

pub struct StructLitField(pub LitField);

impl PrettyDoc for StructLitField {
fn pretty_doc<'b, D, A>(&'b self, arena: &'b D) -> DocBuilder<'b, D, A>
where
D: DocAllocator<'b, A>,
D::Doc: Clone,
A: Clone,
{
let k = self.0.first.pretty_doc(arena);
let v = self.0.second.pretty_doc(arena);
let sep = arena.text(": ");

k.append(sep).group().append(v).group()
}
}

impl PrettyDoc for Bag {
fn pretty_doc<'b, D, A>(&'b self, arena: &'b D) -> DocBuilder<'b, D, A>
where
Expand Down Expand Up @@ -728,6 +760,35 @@ impl PrettyDoc for List {
}
}

impl PrettyDoc for BagLit {
fn pretty_doc<'b, D, A>(&'b self, arena: &'b D) -> DocBuilder<'b, D, A>
where
D: DocAllocator<'b, A>,
D::Doc: Clone,
A: Clone,
{
pretty_seq(
&self.values,
"<<",
">>",
",",
PRETTY_INDENT_MINOR_NEST,
arena,
)
}
}

impl PrettyDoc for ListLit {
fn pretty_doc<'b, D, A>(&'b self, arena: &'b D) -> DocBuilder<'b, D, A>
where
D: DocAllocator<'b, A>,
D::Doc: Clone,
A: Clone,
{
pretty_seq(&self.values, "[", "]", ",", PRETTY_INDENT_MINOR_NEST, arena)
}
}

impl PrettyDoc for Sexp {
fn pretty_doc<'b, D, A>(&'b self, _arena: &'b D) -> DocBuilder<'b, D, A>
where
Expand Down
69 changes: 21 additions & 48 deletions partiql-logical-planner/src/lower.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1893,30 +1893,14 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> {
}

fn lit_to_value(lit: &Lit) -> Result<Value, AstTransformError> {
fn expect_lit(v: &Expr) -> Result<Value, AstTransformError> {
match v {
Expr::Lit(l) => lit_to_value(&l.node),
_ => Err(AstTransformError::IllegalState(
"non literal in literal aggregate".to_string(),
)),
}
}

fn tuple_pair(pair: &ast::ExprPair) -> Option<Result<(String, Value), AstTransformError>> {
let key = match expect_lit(pair.first.as_ref()) {
Ok(Value::String(s)) => s.as_ref().clone(),
Ok(_) => {
return Some(Err(AstTransformError::IllegalState(
"non string literal in literal struct key".to_string(),
)))
}
Err(e) => return Some(Err(e)),
};

match expect_lit(pair.second.as_ref()) {
Ok(Value::Missing) => None,
Ok(val) => Some(Ok((key, val))),
Err(e) => Some(Err(e)),
fn tuple_pair(field: &ast::LitField) -> Option<Result<(String, Value), AstTransformError>> {
let key = field.first.clone();
match &field.second.node {
Lit::Missing => None,
value => match lit_to_value(value) {
Ok(value) => Some(Ok((key, value))),
Err(e) => Some(Err(e)),
},
}
}

Expand Down Expand Up @@ -1947,21 +1931,13 @@ fn lit_to_value(lit: &Lit) -> Result<Value, AstTransformError> {
))
}
Lit::BagLit(b) => {
let bag: Result<partiql_value::Bag, _> = b
.node
.values
.iter()
.map(|l| expect_lit(l.as_ref()))
.collect();
let bag: Result<partiql_value::Bag, _> =
b.node.values.iter().map(lit_to_value).collect();
Value::from(bag?)
}
Lit::ListLit(l) => {
let l: Result<partiql_value::List, _> = l
.node
.values
.iter()
.map(|l| expect_lit(l.as_ref()))
.collect();
let l: Result<partiql_value::List, _> =
l.node.values.iter().map(lit_to_value).collect();
Value::from(l?)
}
Lit::StructLit(s) => {
Expand Down Expand Up @@ -2006,6 +1982,7 @@ fn parse_embedded_ion_str(contents: &str) -> Result<Value, AstTransformError> {
mod tests {
use super::*;
use crate::LogicalPlanner;
use assert_matches::assert_matches;
use partiql_catalog::catalog::{PartiqlCatalog, TypeEnvEntry};
use partiql_logical::BindingsOp::Project;
use partiql_logical::ValueExpr;
Expand All @@ -2023,13 +2000,13 @@ mod tests {
assert!(logical.is_err());
let lowering_errs = logical.expect_err("Expect errs").errors;
assert_eq!(lowering_errs.len(), 2);
assert_eq!(
assert_matches!(
lowering_errs.first(),
Some(&AstTransformError::UnsupportedFunction("foo".to_string()))
Some(AstTransformError::UnsupportedFunction(fnc)) if fnc == "foo"
);
assert_eq!(
assert_matches!(
lowering_errs.get(1),
Some(&AstTransformError::UnsupportedFunction("bar".to_string()))
Some(AstTransformError::UnsupportedFunction(fnc)) if fnc == "bar"
);
}

Expand All @@ -2045,17 +2022,13 @@ mod tests {
assert!(logical.is_err());
let lowering_errs = logical.expect_err("Expect errs").errors;
assert_eq!(lowering_errs.len(), 2);
assert_eq!(
assert_matches!(
lowering_errs.first(),
Some(&AstTransformError::InvalidNumberOfArguments(
"abs".to_string()
))
Some(AstTransformError::InvalidNumberOfArguments(fnc)) if fnc == "abs"
);
assert_eq!(
assert_matches!(
lowering_errs.get(1),
Some(&AstTransformError::InvalidNumberOfArguments(
"mod".to_string()
))
Some(AstTransformError::InvalidNumberOfArguments(fnc)) if fnc == "mod"
);
}

Expand Down
75 changes: 73 additions & 2 deletions partiql-parser/src/parse/parse_util.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
use partiql_ast::ast;

use crate::parse::parser_state::ParserState;
use crate::ParseError;
use bitflags::bitflags;
use partiql_ast::ast::{AstNode, Expr, Lit, LitField};

Check warning on line 6 in partiql-parser/src/parse/parse_util.rs

View workflow job for this annotation

GitHub Actions / clippy

unused imports: `AstNode` and `LitField`

warning: unused imports: `AstNode` and `LitField` --> partiql-parser/src/parse/parse_util.rs:6:24 | 6 | use partiql_ast::ast::{AstNode, Expr, Lit, LitField}; | ^^^^^^^ ^^^^^^^^ | = note: `#[warn(unused_imports)]` on by default
jpschorr marked this conversation as resolved.
Show resolved Hide resolved
use partiql_common::node::NodeIdGenerator;
use partiql_common::syntax::location::ByteOffset;
use partiql_common::syntax::location::{ByteOffset, BytePosition};

bitflags! {
/// Set of AST node attributes to use as synthesized attributes.
Expand Down Expand Up @@ -33,14 +35,25 @@

impl<T> Synth<T> {
#[inline]
pub fn new(data: T, attrs: Attrs) -> Self {
fn new(data: T, attrs: Attrs) -> Self {
Synth { data, attrs }
}

#[inline]
pub fn empty(data: T) -> Self {
Self::new(data, Attrs::empty())
}

#[inline]
pub fn lit(data: T) -> Self {
Self::new(data, Attrs::LIT)
}

pub fn map_data<U>(self, f: impl FnOnce(T) -> U) -> Synth<U> {
let Self { data, attrs } = self;
let data = f(data);
Synth::new(data, attrs)
}
}

impl<T> FromIterator<Synth<T>> for Synth<Vec<T>> {
Expand Down Expand Up @@ -170,3 +183,61 @@
Box::new(ast::Expr::Query(q))
}
}

#[inline]
#[track_caller]
fn illegal_literal<'a, T>() -> Result<T, crate::error::ParseError<'a, BytePosition>> {
Err(ParseError::IllegalState("Expected literal".to_string()))
}

pub(crate) type LitFlattenResult<'a, T> = Result<T, ParseError<'a>>;
#[inline]
pub(crate) fn struct_to_lit<'a>(strct: ast::Struct) -> LitFlattenResult<'a, ast::StructLit> {
strct
.fields
.into_iter()
.map(exprpair_to_lit)
.collect::<LitFlattenResult<'_, Vec<_>>>()
.map(|fields| ast::StructLit { fields })
}

#[inline]
pub(crate) fn bag_to_lit<'a>(bag: ast::Bag) -> LitFlattenResult<'a, ast::BagLit> {
bag.values
.into_iter()
.map(|v| expr_to_lit(*v).map(|n| n.node))
.collect::<LitFlattenResult<'_, Vec<_>>>()
.map(|values| ast::BagLit { values })
}

#[inline]
pub(crate) fn list_to_lit<'a>(list: ast::List) -> LitFlattenResult<'a, ast::ListLit> {
list.values
.into_iter()
.map(|v| expr_to_lit(*v).map(|n| n.node))
.collect::<LitFlattenResult<'_, Vec<_>>>()
.map(|values| ast::ListLit { values })
}

#[inline]
pub(crate) fn exprpair_to_lit<'a>(pair: ast::ExprPair) -> LitFlattenResult<'a, ast::LitField> {
let ast::ExprPair { first, second } = pair;
let (first, second) = (expr_to_litstr(*first)?, expr_to_lit(*second)?);
Ok(ast::LitField { first, second })
}

#[inline]
pub(crate) fn expr_to_litstr<'a>(expr: ast::Expr) -> LitFlattenResult<'a, String> {
match expr_to_lit(expr)?.node {
Lit::CharStringLit(s) | Lit::NationalCharStringLit(s) => Ok(s),
_ => illegal_literal(),
}
}

#[inline]
pub(crate) fn expr_to_lit<'a>(expr: ast::Expr) -> LitFlattenResult<'a, ast::AstNode<ast::Lit>> {
match expr {
Expr::Lit(lit) => Ok(lit),
_ => illegal_literal(),
}
}
Loading
Loading