Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add list folding to parser #127

Merged
merged 4 commits into from
Feb 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion air-script-core/src/expression.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use super::{Identifier, IndexedTraceAccess, MatrixAccess, NamedTraceAccess, VectorAccess};
use super::{
Identifier, IndexedTraceAccess, ListComprehension, MatrixAccess, NamedTraceAccess, VectorAccess,
};

/// Arithmetic expressions for evaluation of constraints.
#[derive(Debug, Eq, PartialEq, Clone)]
Expand All @@ -21,4 +23,11 @@ pub enum Expression {
Sub(Box<Expression>, Box<Expression>),
Mul(Box<Expression>, Box<Expression>),
Exp(Box<Expression>, Box<Expression>),
ListFolding(ListFoldingType),
}

#[derive(Debug, Clone, Eq, PartialEq)]
pub enum ListFoldingType {
Sum(ListComprehension),
Prod(ListComprehension),
}
2 changes: 1 addition & 1 deletion air-script-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ mod constant;
pub use constant::{Constant, ConstantType};

mod expression;
pub use expression::Expression;
pub use expression::{Expression, ListFoldingType};

mod identifier;
pub use identifier::Identifier;
Expand Down
4 changes: 2 additions & 2 deletions air-script-core/src/variable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,15 @@ pub enum VariableType {

#[derive(Debug, Clone, Eq, PartialEq)]
pub struct ListComprehension {
expression: Expression,
expression: Box<Expression>,
context: Vec<(Identifier, Iterable)>,
}

impl ListComprehension {
/// Creates a new list comprehension.
pub fn new(expression: Expression, context: Vec<(Identifier, Iterable)>) -> Self {
Self {
expression,
expression: Box::new(expression),
context,
}
}
Expand Down
1 change: 1 addition & 0 deletions ir/src/constraints/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ impl AlgebraicGraph {
lhs.domain(),
))
}
Expression::ListFolding(_) => todo!(),
}
}

Expand Down
3 changes: 2 additions & 1 deletion parser/src/ast/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
pub(crate) use air_script_core::{
Constant, ConstantType, Expression, Identifier, IndexedTraceAccess, Iterable,
ListComprehension, MatrixAccess, NamedTraceAccess, Range, Variable, VariableType, VectorAccess,
ListComprehension, ListFoldingType, MatrixAccess, NamedTraceAccess, Range, Variable,
VariableType, VectorAccess,
};

pub mod pub_inputs;
Expand Down
8 changes: 8 additions & 0 deletions parser/src/lexer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,14 @@ pub enum Token {
#[token("in")]
In,

/// Used to declare sum list folding operation in the AIR constraints module.
#[token("sum")]
Sum,

/// Used to declare prod list folding operation in the AIR constraints module.
#[token("prod")]
Prod,

// GENERAL KEYWORDS
// --------------------------------------------------------------------------------------------
/// Keyword to signify that a constraint needs to be enforced
Expand Down
76 changes: 76 additions & 0 deletions parser/src/lexer/tests/list_comprehension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,79 @@ fn multiple_iterables_comprehension() {
];
expect_valid_tokenization(source, tokens);
}

#[test]
fn one_iterable_folding() {
let source = "let y = sum([x for x in x])";
let tokens = vec![
Token::Let,
Token::Ident("y".to_string()),
Token::Equal,
Token::Sum,
Token::Lparen,
Token::Lsqb,
Token::Ident("x".to_string()),
Token::For,
Token::Ident("x".to_string()),
Token::In,
Token::Ident("x".to_string()),
Token::Rsqb,
Token::Rparen,
];
expect_valid_tokenization(source, tokens);
}

#[test]
fn multiple_iterables_list_folding() {
let source = "let a = sum([w + x - y - z for (w, x, y, z) in (0..3, x, y[0..3], z[0..3])])";
let tokens = vec![
Token::Let,
Token::Ident("a".to_string()),
Token::Equal,
Token::Sum,
Token::Lparen,
Token::Lsqb,
Token::Ident("w".to_string()),
Token::Plus,
Token::Ident("x".to_string()),
Token::Minus,
Token::Ident("y".to_string()),
Token::Minus,
Token::Ident("z".to_string()),
Token::For,
Token::Lparen,
Token::Ident("w".to_string()),
Token::Comma,
Token::Ident("x".to_string()),
Token::Comma,
Token::Ident("y".to_string()),
Token::Comma,
Token::Ident("z".to_string()),
Token::Rparen,
Token::In,
Token::Lparen,
Token::Num("0".to_string()),
Token::Range,
Token::Num("3".to_string()),
Token::Comma,
Token::Ident("x".to_string()),
Token::Comma,
Token::Ident("y".to_string()),
Token::Lsqb,
Token::Num("0".to_string()),
Token::Range,
Token::Num("3".to_string()),
Token::Rsqb,
Token::Comma,
Token::Ident("z".to_string()),
Token::Lsqb,
Token::Num("0".to_string()),
Token::Range,
Token::Num("3".to_string()),
Token::Rsqb,
Token::Rparen,
Token::Rsqb,
Token::Rparen,
];
expect_valid_tokenization(source, tokens);
}
18 changes: 15 additions & 3 deletions parser/src/parser/grammar.lalrpop
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ use crate::{
ast::{
boundary_constraints::{Boundary, BoundaryConstraint, BoundaryStmt},
integrity_constraints::{IntegrityConstraint, IntegrityStmt},
Constant, ConstantType, Expression, Identifier, Variable, VariableType, ListComprehension,
Iterable, Source, SourceSection, Trace, TraceCols,
PublicInput, PeriodicColumn, IndexedTraceAccess, NamedTraceAccess, Range, MatrixAccess,
Constant, ConstantType, Expression, Identifier, IndexedTraceAccess, Iterable,
ListComprehension, ListFoldingType, MatrixAccess, NamedTraceAccess, PeriodicColumn,
PublicInput, Range, Source, SourceSection, Trace, TraceCols, Variable, VariableType,
VectorAccess
}, error::{Error, ParseError::{InvalidInt, InvalidTraceCols, MissingMainTraceCols,
InvalidConst, InvalidListComprehension, MissingBoundaryConstraint,
Expand Down Expand Up @@ -195,6 +195,7 @@ BoundaryAtom: Expression = {
<ident: Identifier> => Expression::Elem(ident),
<vector_access: VectorAccess> => Expression::VectorAccess(vector_access),
<matrix_access: MatrixAccess> => Expression::MatrixAccess(matrix_access),
<list_folding_type: ListFoldingType<BoundaryExpr>> => Expression::ListFolding(list_folding_type),
}

// INTEGRITY CONSTRAINTS
Expand Down Expand Up @@ -269,6 +270,8 @@ IntegrityAtom: Expression = {
<vector_access: VectorAccess> => Expression::VectorAccess(vector_access),
<matrix_access: MatrixAccess> => Expression::MatrixAccess(matrix_access),
<trace_access: NamedTraceAccessWithOffset> => Expression::NamedTraceAccess(trace_access),
<list_folding_type: ListFoldingType<IntegrityExpr>> =>
Expression::ListFolding(list_folding_type),
}

// ATOMS
Expand Down Expand Up @@ -333,6 +336,13 @@ ListComprehension<T>: ListComprehension = {
}
}

ListFoldingType<T>: ListFoldingType = {
"sum" "(" "[" <list_comprehension: ListComprehension<T>> "]" ")" =>
ListFoldingType::Sum(list_comprehension),
"prod" "(" "[" <list_comprehension: ListComprehension<T>> "]" ")" =>
ListFoldingType::Prod(list_comprehension),
}

Members: Vec<Identifier> = {
<member: Identifier> => vec![member],
"(" <members: CommaElems<Identifier>> ")" => members
Expand Down Expand Up @@ -378,6 +388,8 @@ extern {
"let" => Token::Let,
"for" => Token::For,
"in" => Token::In,
"sum" => Token::Sum,
"prod" => Token::Prod,
"const" => Token::Const,
"trace_columns" => Token::TraceColumns,
"main" => Token::MainDecl,
Expand Down
Loading