From bcd389b17332fe83ee53300e5c7024244f38536d Mon Sep 17 00:00:00 2001 From: tohrnii <100405913+tohrnii@users.noreply.github.com> Date: Tue, 31 Jan 2023 06:48:55 +0000 Subject: [PATCH 1/4] feat(core): add list folding structures to air-script-core --- air-script-core/src/expression.rs | 12 +++++++++++- air-script-core/src/lib.rs | 2 +- air-script-core/src/variable.rs | 4 ++-- 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/air-script-core/src/expression.rs b/air-script-core/src/expression.rs index adef0d89..b9a46278 100644 --- a/air-script-core/src/expression.rs +++ b/air-script-core/src/expression.rs @@ -1,4 +1,6 @@ -use super::{Identifier, IndexedTraceAccess, MatrixAccess, NamedTraceAccess, VectorAccess}; +use super::{ + Identifier, IndexedTraceAccess, ListComprehension, MatrixAccess, NamedTraceAccess, VectorAccess, +}; /// Arithmetic expressions for evaluation of constraints. #[derive(Debug, Eq, PartialEq, Clone)] @@ -21,4 +23,12 @@ pub enum Expression { Sub(Box, Box), Mul(Box, Box), Exp(Box, Box), + ListFolding(ListFoldingType), +} + +#[derive(Debug, Clone, Eq, PartialEq)] +pub enum ListFoldingType { + Expr(ListComprehension), + Sum(ListComprehension), + Prod(ListComprehension), } diff --git a/air-script-core/src/lib.rs b/air-script-core/src/lib.rs index f42ee7ed..8cd6c495 100644 --- a/air-script-core/src/lib.rs +++ b/air-script-core/src/lib.rs @@ -5,7 +5,7 @@ mod constant; pub use constant::{Constant, ConstantType}; mod expression; -pub use expression::Expression; +pub use expression::{Expression, ListFoldingType}; mod identifier; pub use identifier::Identifier; diff --git a/air-script-core/src/variable.rs b/air-script-core/src/variable.rs index ea6d4953..00706227 100644 --- a/air-script-core/src/variable.rs +++ b/air-script-core/src/variable.rs @@ -30,7 +30,7 @@ pub enum VariableType { #[derive(Debug, Clone, Eq, PartialEq)] pub struct ListComprehension { - expression: Expression, + expression: Box, context: Vec<(Identifier, Iterable)>, } @@ -38,7 +38,7 @@ impl ListComprehension { /// Creates a new list comprehension. pub fn new(expression: Expression, context: Vec<(Identifier, Iterable)>) -> Self { Self { - expression, + expression: Box::new(expression), context, } } From b0b0c97a9d3f97adea0f3ceb23ee59ef69a5221d Mon Sep 17 00:00:00 2001 From: tohrnii <100405913+tohrnii@users.noreply.github.com> Date: Tue, 31 Jan 2023 06:49:20 +0000 Subject: [PATCH 2/4] feat(parser): add list folding parsing rules to grammar --- parser/src/parser/grammar.lalrpop | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/parser/src/parser/grammar.lalrpop b/parser/src/parser/grammar.lalrpop index e5f04f8f..71324b38 100644 --- a/parser/src/parser/grammar.lalrpop +++ b/parser/src/parser/grammar.lalrpop @@ -2,9 +2,9 @@ use crate::{ ast::{ boundary_constraints::{Boundary, BoundaryConstraint, BoundaryStmt}, integrity_constraints::{IntegrityConstraint, IntegrityStmt}, - Constant, ConstantType, Expression, Identifier, Variable, VariableType, ListComprehension, - Iterable, Source, SourceSection, Trace, TraceCols, - PublicInput, PeriodicColumn, IndexedTraceAccess, NamedTraceAccess, Range, MatrixAccess, + Constant, ConstantType, Expression, Identifier, IndexedTraceAccess, Iterable, + ListComprehension, ListFoldingType, MatrixAccess, NamedTraceAccess, PeriodicColumn, + PublicInput, Range, Source, SourceSection, Trace, TraceCols, Variable, VariableType, VectorAccess }, error::{Error, ParseError::{InvalidInt, InvalidTraceCols, MissingMainTraceCols, InvalidConst, InvalidListComprehension, MissingBoundaryConstraint, @@ -195,6 +195,7 @@ BoundaryAtom: Expression = { => Expression::Elem(ident), => Expression::VectorAccess(vector_access), => Expression::MatrixAccess(matrix_access), + > => Expression::ListFolding(list_folding_type), } // INTEGRITY CONSTRAINTS @@ -269,6 +270,8 @@ IntegrityAtom: Expression = { => Expression::VectorAccess(vector_access), => Expression::MatrixAccess(matrix_access), => Expression::NamedTraceAccess(trace_access), + > => + Expression::ListFolding(list_folding_type), } // ATOMS @@ -333,6 +336,13 @@ ListComprehension: ListComprehension = { } } +ListFoldingType: ListFoldingType = { + "sum" "(" "[" > "]" ")" => + ListFoldingType::Sum(list_comprehension), + "prod" "(" "[" > "]" ")" => + ListFoldingType::Prod(list_comprehension), +} + Members: Vec = { => vec![member], "(" > ")" => members @@ -378,6 +388,8 @@ extern { "let" => Token::Let, "for" => Token::For, "in" => Token::In, + "sum" => Token::Sum, + "prod" => Token::Prod, "const" => Token::Const, "trace_columns" => Token::TraceColumns, "main" => Token::MainDecl, From 08763f98facb78a3bae950fc6e912ec7f54027b9 Mon Sep 17 00:00:00 2001 From: tohrnii <100405913+tohrnii@users.noreply.github.com> Date: Tue, 31 Jan 2023 06:49:57 +0000 Subject: [PATCH 3/4] feat(parser): add list folding support to parser --- ir/src/constraints/graph.rs | 1 + parser/src/ast/mod.rs | 3 ++- parser/src/lexer/mod.rs | 8 ++++++++ 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/ir/src/constraints/graph.rs b/ir/src/constraints/graph.rs index bd4f1a31..91a706e6 100644 --- a/ir/src/constraints/graph.rs +++ b/ir/src/constraints/graph.rs @@ -228,6 +228,7 @@ impl AlgebraicGraph { let lhs_base = self.accumulate_degree(cycles, lhs); lhs_base * rhs } + Expression::ListFolding(_) => todo!(), } } diff --git a/parser/src/ast/mod.rs b/parser/src/ast/mod.rs index 0408d650..fedb7e54 100644 --- a/parser/src/ast/mod.rs +++ b/parser/src/ast/mod.rs @@ -1,6 +1,7 @@ pub(crate) use air_script_core::{ Constant, ConstantType, Expression, Identifier, IndexedTraceAccess, Iterable, - ListComprehension, MatrixAccess, NamedTraceAccess, Range, Variable, VariableType, VectorAccess, + ListComprehension, ListFoldingType, MatrixAccess, NamedTraceAccess, Range, Variable, + VariableType, VectorAccess, }; pub mod pub_inputs; diff --git a/parser/src/lexer/mod.rs b/parser/src/lexer/mod.rs index c0441254..06b748de 100644 --- a/parser/src/lexer/mod.rs +++ b/parser/src/lexer/mod.rs @@ -99,6 +99,14 @@ pub enum Token { #[token("in")] In, + /// Used to declare sum list folding operation in the AIR constraints module. + #[token("sum")] + Sum, + + /// Used to declare prod list folding operation in the AIR constraints module. + #[token("prod")] + Prod, + // GENERAL KEYWORDS // -------------------------------------------------------------------------------------------- /// Keyword to signify that a constraint needs to be enforced From 0d3b94e43495a9b823abd78c815bdae506fb0b2c Mon Sep 17 00:00:00 2001 From: tohrnii <100405913+tohrnii@users.noreply.github.com> Date: Tue, 31 Jan 2023 06:50:14 +0000 Subject: [PATCH 4/4] test: add list folding lexer and parser tests --- air-script-core/src/expression.rs | 1 - ir/src/constraints/graph.rs | 2 +- parser/src/lexer/tests/list_comprehension.rs | 76 +++ parser/src/parser/tests/list_comprehension.rs | 516 +++++++++++++++++- 4 files changed, 592 insertions(+), 3 deletions(-) diff --git a/air-script-core/src/expression.rs b/air-script-core/src/expression.rs index b9a46278..0bb286b6 100644 --- a/air-script-core/src/expression.rs +++ b/air-script-core/src/expression.rs @@ -28,7 +28,6 @@ pub enum Expression { #[derive(Debug, Clone, Eq, PartialEq)] pub enum ListFoldingType { - Expr(ListComprehension), Sum(ListComprehension), Prod(ListComprehension), } diff --git a/ir/src/constraints/graph.rs b/ir/src/constraints/graph.rs index 91a706e6..6d721ff5 100644 --- a/ir/src/constraints/graph.rs +++ b/ir/src/constraints/graph.rs @@ -161,6 +161,7 @@ impl AlgebraicGraph { lhs.domain(), )) } + Expression::ListFolding(_) => todo!(), } } @@ -228,7 +229,6 @@ impl AlgebraicGraph { let lhs_base = self.accumulate_degree(cycles, lhs); lhs_base * rhs } - Expression::ListFolding(_) => todo!(), } } diff --git a/parser/src/lexer/tests/list_comprehension.rs b/parser/src/lexer/tests/list_comprehension.rs index fcbb0f9d..c4627121 100644 --- a/parser/src/lexer/tests/list_comprehension.rs +++ b/parser/src/lexer/tests/list_comprehension.rs @@ -72,3 +72,79 @@ fn multiple_iterables_comprehension() { ]; expect_valid_tokenization(source, tokens); } + +#[test] +fn one_iterable_folding() { + let source = "let y = sum([x for x in x])"; + let tokens = vec![ + Token::Let, + Token::Ident("y".to_string()), + Token::Equal, + Token::Sum, + Token::Lparen, + Token::Lsqb, + Token::Ident("x".to_string()), + Token::For, + Token::Ident("x".to_string()), + Token::In, + Token::Ident("x".to_string()), + Token::Rsqb, + Token::Rparen, + ]; + expect_valid_tokenization(source, tokens); +} + +#[test] +fn multiple_iterables_list_folding() { + let source = "let a = sum([w + x - y - z for (w, x, y, z) in (0..3, x, y[0..3], z[0..3])])"; + let tokens = vec![ + Token::Let, + Token::Ident("a".to_string()), + Token::Equal, + Token::Sum, + Token::Lparen, + Token::Lsqb, + Token::Ident("w".to_string()), + Token::Plus, + Token::Ident("x".to_string()), + Token::Minus, + Token::Ident("y".to_string()), + Token::Minus, + Token::Ident("z".to_string()), + Token::For, + Token::Lparen, + Token::Ident("w".to_string()), + Token::Comma, + Token::Ident("x".to_string()), + Token::Comma, + Token::Ident("y".to_string()), + Token::Comma, + Token::Ident("z".to_string()), + Token::Rparen, + Token::In, + Token::Lparen, + Token::Num("0".to_string()), + Token::Range, + Token::Num("3".to_string()), + Token::Comma, + Token::Ident("x".to_string()), + Token::Comma, + Token::Ident("y".to_string()), + Token::Lsqb, + Token::Num("0".to_string()), + Token::Range, + Token::Num("3".to_string()), + Token::Rsqb, + Token::Comma, + Token::Ident("z".to_string()), + Token::Lsqb, + Token::Num("0".to_string()), + Token::Range, + Token::Num("3".to_string()), + Token::Rsqb, + Token::Rparen, + Token::Rsqb, + Token::Rparen, + ]; + expect_valid_tokenization(source, tokens); +} diff --git a/parser/src/parser/tests/list_comprehension.rs b/parser/src/parser/tests/list_comprehension.rs index a08b7b85..0f4458d3 100644 --- a/parser/src/parser/tests/list_comprehension.rs +++ b/parser/src/parser/tests/list_comprehension.rs @@ -1,4 +1,4 @@ -use air_script_core::{Iterable, ListComprehension, Range}; +use air_script_core::{Iterable, ListComprehension, ListFoldingType, Range}; use super::{build_parse_test, Identifier, IntegrityConstraint, Source}; use crate::{ @@ -814,3 +814,517 @@ fn err_ic_lc_two_members_one_iterable() { )); build_parse_test!(source).expect_error(error); } + +// LIST FOLDING +// ================================================================================================ + +#[test] +fn bc_one_iterable_identifier_lf() { + let source = " + trace_columns: + main: [a, b, c[4]] + boundary_constraints: + let x = sum([col^7 for col in c]) + let y = prod([col^7 for col in c]) + enf a.first = x + y"; + + let expected = Source(vec![ + Trace(Trace { + main_cols: vec![ + TraceCols::new(Identifier("a".to_string()), 1), + TraceCols::new(Identifier("b".to_string()), 1), + TraceCols::new(Identifier("c".to_string()), 4), + ], + aux_cols: vec![], + }), + BoundaryConstraints(vec![ + BoundaryStmt::Variable(Variable::new( + Identifier("x".to_string()), + VariableType::Scalar(ListFolding(ListFoldingType::Sum(ListComprehension::new( + Exp( + Box::new(Elem(Identifier("col".to_string()))), + Box::new(Const(7)), + ), + vec![( + Identifier("col".to_string()), + Iterable::Identifier(Identifier("c".to_string())), + )], + )))), + )), + BoundaryStmt::Variable(Variable::new( + Identifier("y".to_string()), + VariableType::Scalar(ListFolding(ListFoldingType::Prod(ListComprehension::new( + Exp( + Box::new(Elem(Identifier("col".to_string()))), + Box::new(Const(7)), + ), + vec![( + Identifier("col".to_string()), + Iterable::Identifier(Identifier("c".to_string())), + )], + )))), + )), + BoundaryStmt::Constraint(BoundaryConstraint::new( + NamedTraceAccess::new(Identifier("a".to_string()), 0, 0), + Boundary::First, + Add( + Box::new(Elem(Identifier("x".to_string()))), + Box::new(Elem(Identifier("y".to_string()))), + ), + )), + ]), + ]); + + build_parse_test!(source).expect_ast(expected); +} + +#[test] +fn bc_two_iterable_identifier_lf() { + let source = " + trace_columns: + main: [a, b, c[4], d[4]] + boundary_constraints: + let x = sum([c * d for (c, d) in (c, d)]) + let y = prod([c + d for (c, d) in (c, d)]) + enf a.first = x + y"; + + let expected = Source(vec![ + Trace(Trace { + main_cols: vec![ + TraceCols::new(Identifier("a".to_string()), 1), + TraceCols::new(Identifier("b".to_string()), 1), + TraceCols::new(Identifier("c".to_string()), 4), + TraceCols::new(Identifier("d".to_string()), 4), + ], + aux_cols: vec![], + }), + BoundaryConstraints(vec![ + BoundaryStmt::Variable(Variable::new( + Identifier("x".to_string()), + VariableType::Scalar(ListFolding(ListFoldingType::Sum(ListComprehension::new( + Mul( + Box::new(Elem(Identifier("c".to_string()))), + Box::new(Elem(Identifier("d".to_string()))), + ), + vec![ + ( + Identifier("c".to_string()), + Iterable::Identifier(Identifier("c".to_string())), + ), + ( + Identifier("d".to_string()), + Iterable::Identifier(Identifier("d".to_string())), + ), + ], + )))), + )), + BoundaryStmt::Variable(Variable::new( + Identifier("y".to_string()), + VariableType::Scalar(ListFolding(ListFoldingType::Prod(ListComprehension::new( + Add( + Box::new(Elem(Identifier("c".to_string()))), + Box::new(Elem(Identifier("d".to_string()))), + ), + vec![ + ( + Identifier("c".to_string()), + Iterable::Identifier(Identifier("c".to_string())), + ), + ( + Identifier("d".to_string()), + Iterable::Identifier(Identifier("d".to_string())), + ), + ], + )))), + )), + BoundaryStmt::Constraint(BoundaryConstraint::new( + NamedTraceAccess::new(Identifier("a".to_string()), 0, 0), + Boundary::First, + Add( + Box::new(Elem(Identifier("x".to_string()))), + Box::new(Elem(Identifier("y".to_string()))), + ), + )), + ]), + ]); + + build_parse_test!(source).expect_ast(expected); +} + +#[test] +fn bc_two_iterables_identifier_range_lf() { + let source = " + trace_columns: + main: [a, b, c[4]] + boundary_constraints: + let x = sum([i * c for (i, c) in (0..4, c)]) + let y = prod([i + c for (i, c) in (0..4, c)]) + enf a.first = x + y"; + + let expected = Source(vec![ + Trace(Trace { + main_cols: vec![ + TraceCols::new(Identifier("a".to_string()), 1), + TraceCols::new(Identifier("b".to_string()), 1), + TraceCols::new(Identifier("c".to_string()), 4), + ], + aux_cols: vec![], + }), + BoundaryConstraints(vec![ + BoundaryStmt::Variable(Variable::new( + Identifier("x".to_string()), + VariableType::Scalar(ListFolding(ListFoldingType::Sum(ListComprehension::new( + Mul( + Box::new(Elem(Identifier("i".to_string()))), + Box::new(Elem(Identifier("c".to_string()))), + ), + vec![ + ( + Identifier("i".to_string()), + Iterable::Range(Range::new(0, 4)), + ), + ( + Identifier("c".to_string()), + Iterable::Identifier(Identifier("c".to_string())), + ), + ], + )))), + )), + BoundaryStmt::Variable(Variable::new( + Identifier("y".to_string()), + VariableType::Scalar(ListFolding(ListFoldingType::Prod(ListComprehension::new( + Add( + Box::new(Elem(Identifier("i".to_string()))), + Box::new(Elem(Identifier("c".to_string()))), + ), + vec![ + ( + Identifier("i".to_string()), + Iterable::Range(Range::new(0, 4)), + ), + ( + Identifier("c".to_string()), + Iterable::Identifier(Identifier("c".to_string())), + ), + ], + )))), + )), + BoundaryStmt::Constraint(BoundaryConstraint::new( + NamedTraceAccess::new(Identifier("a".to_string()), 0, 0), + Boundary::First, + Add( + Box::new(Elem(Identifier("x".to_string()))), + Box::new(Elem(Identifier("y".to_string()))), + ), + )), + ]), + ]); + + build_parse_test!(source).expect_ast(expected); +} + +#[test] +fn ic_one_iterable_identifier_lf() { + let source = " + trace_columns: + main: [a, b, c[4]] + integrity_constraints: + let x = sum([col^7 for col in c]) + let y = prod([col^7 for col in c]) + enf a = x + y"; + + let expected = Source(vec![ + Trace(Trace { + main_cols: vec![ + TraceCols::new(Identifier("a".to_string()), 1), + TraceCols::new(Identifier("b".to_string()), 1), + TraceCols::new(Identifier("c".to_string()), 4), + ], + aux_cols: vec![], + }), + IntegrityConstraints(vec![ + IntegrityStmt::Variable(Variable::new( + Identifier("x".to_string()), + VariableType::Scalar(ListFolding(ListFoldingType::Sum(ListComprehension::new( + Exp( + Box::new(Elem(Identifier("col".to_string()))), + Box::new(Const(7)), + ), + vec![( + Identifier("col".to_string()), + Iterable::Identifier(Identifier("c".to_string())), + )], + )))), + )), + IntegrityStmt::Variable(Variable::new( + Identifier("y".to_string()), + VariableType::Scalar(ListFolding(ListFoldingType::Prod(ListComprehension::new( + Exp( + Box::new(Elem(Identifier("col".to_string()))), + Box::new(Const(7)), + ), + vec![( + Identifier("col".to_string()), + Iterable::Identifier(Identifier("c".to_string())), + )], + )))), + )), + IntegrityStmt::Constraint(IntegrityConstraint::new( + Elem(Identifier("a".to_string())), + Add( + Box::new(Elem(Identifier("x".to_string()))), + Box::new(Elem(Identifier("y".to_string()))), + ), + )), + ]), + ]); + + build_parse_test!(source).expect_ast(expected); +} + +#[test] +fn ic_two_iterable_identifier_lf() { + let source = " + trace_columns: + main: [a, b, c[4], d[4]] + integrity_constraints: + let x = sum([c * d for (c, d) in (c, d)]) + let y = prod([c + d for (c, d) in (c, d)]) + enf a = x + y"; + + let expected = Source(vec![ + Trace(Trace { + main_cols: vec![ + TraceCols::new(Identifier("a".to_string()), 1), + TraceCols::new(Identifier("b".to_string()), 1), + TraceCols::new(Identifier("c".to_string()), 4), + TraceCols::new(Identifier("d".to_string()), 4), + ], + aux_cols: vec![], + }), + IntegrityConstraints(vec![ + IntegrityStmt::Variable(Variable::new( + Identifier("x".to_string()), + VariableType::Scalar(ListFolding(ListFoldingType::Sum(ListComprehension::new( + Mul( + Box::new(Elem(Identifier("c".to_string()))), + Box::new(Elem(Identifier("d".to_string()))), + ), + vec![ + ( + Identifier("c".to_string()), + Iterable::Identifier(Identifier("c".to_string())), + ), + ( + Identifier("d".to_string()), + Iterable::Identifier(Identifier("d".to_string())), + ), + ], + )))), + )), + IntegrityStmt::Variable(Variable::new( + Identifier("y".to_string()), + VariableType::Scalar(ListFolding(ListFoldingType::Prod(ListComprehension::new( + Add( + Box::new(Elem(Identifier("c".to_string()))), + Box::new(Elem(Identifier("d".to_string()))), + ), + vec![ + ( + Identifier("c".to_string()), + Iterable::Identifier(Identifier("c".to_string())), + ), + ( + Identifier("d".to_string()), + Iterable::Identifier(Identifier("d".to_string())), + ), + ], + )))), + )), + IntegrityStmt::Constraint(IntegrityConstraint::new( + Elem(Identifier("a".to_string())), + Add( + Box::new(Elem(Identifier("x".to_string()))), + Box::new(Elem(Identifier("y".to_string()))), + ), + )), + ]), + ]); + + build_parse_test!(source).expect_ast(expected); +} + +#[test] +fn ic_two_iterables_identifier_range_lf() { + let source = " + trace_columns: + main: [a, b, c[4]] + integrity_constraints: + let x = sum([i * c for (i, c) in (0..4, c)]) + let y = prod([i + c for (i, c) in (0..4, c)]) + enf a = x + y"; + + let expected = Source(vec![ + Trace(Trace { + main_cols: vec![ + TraceCols::new(Identifier("a".to_string()), 1), + TraceCols::new(Identifier("b".to_string()), 1), + TraceCols::new(Identifier("c".to_string()), 4), + ], + aux_cols: vec![], + }), + IntegrityConstraints(vec![ + IntegrityStmt::Variable(Variable::new( + Identifier("x".to_string()), + VariableType::Scalar(ListFolding(ListFoldingType::Sum(ListComprehension::new( + Mul( + Box::new(Elem(Identifier("i".to_string()))), + Box::new(Elem(Identifier("c".to_string()))), + ), + vec![ + ( + Identifier("i".to_string()), + Iterable::Range(Range::new(0, 4)), + ), + ( + Identifier("c".to_string()), + Iterable::Identifier(Identifier("c".to_string())), + ), + ], + )))), + )), + IntegrityStmt::Variable(Variable::new( + Identifier("y".to_string()), + VariableType::Scalar(ListFolding(ListFoldingType::Prod(ListComprehension::new( + Add( + Box::new(Elem(Identifier("i".to_string()))), + Box::new(Elem(Identifier("c".to_string()))), + ), + vec![ + ( + Identifier("i".to_string()), + Iterable::Range(Range::new(0, 4)), + ), + ( + Identifier("c".to_string()), + Iterable::Identifier(Identifier("c".to_string())), + ), + ], + )))), + )), + IntegrityStmt::Constraint(IntegrityConstraint::new( + Elem(Identifier("a".to_string())), + Add( + Box::new(Elem(Identifier("x".to_string()))), + Box::new(Elem(Identifier("y".to_string()))), + ), + )), + ]), + ]); + + build_parse_test!(source).expect_ast(expected); +} + +#[test] +fn ic_three_iterables_slice_identifier_range_lf() { + let source = " + trace_columns: + main: [a, b[6], c[4]] + integrity_constraints: + let x = sum([m * n * i for (m, n, i) in (b[1..5], c, 0..4)]) + let x = sum([m * n * i for (m, n, i) in (b[1..5], c, 0..4)]) + enf a = x + y"; + + let expected = Source(vec![ + Trace(Trace { + main_cols: vec![ + TraceCols::new(Identifier("a".to_string()), 1), + TraceCols::new(Identifier("b".to_string()), 6), + TraceCols::new(Identifier("c".to_string()), 4), + ], + aux_cols: vec![], + }), + IntegrityConstraints(vec![ + IntegrityStmt::Variable(Variable::new( + Identifier("x".to_string()), + VariableType::Scalar(ListFolding(ListFoldingType::Sum(ListComprehension::new( + Mul( + Box::new(Mul( + Box::new(Elem(Identifier("m".to_string()))), + Box::new(Elem(Identifier("n".to_string()))), + )), + Box::new(Elem(Identifier("i".to_string()))), + ), + vec![ + ( + Identifier("m".to_string()), + Iterable::Slice(Identifier("b".to_string()), Range::new(1, 5)), + ), + ( + Identifier("n".to_string()), + Iterable::Identifier(Identifier("c".to_string())), + ), + ( + Identifier("i".to_string()), + Iterable::Range(Range::new(0, 4)), + ), + ], + )))), + )), + IntegrityStmt::Variable(Variable::new( + Identifier("x".to_string()), + VariableType::Scalar(ListFolding(ListFoldingType::Sum(ListComprehension::new( + Mul( + Box::new(Mul( + Box::new(Elem(Identifier("m".to_string()))), + Box::new(Elem(Identifier("n".to_string()))), + )), + Box::new(Elem(Identifier("i".to_string()))), + ), + vec![ + ( + Identifier("m".to_string()), + Iterable::Slice(Identifier("b".to_string()), Range::new(1, 5)), + ), + ( + Identifier("n".to_string()), + Iterable::Identifier(Identifier("c".to_string())), + ), + ( + Identifier("i".to_string()), + Iterable::Range(Range::new(0, 4)), + ), + ], + )))), + )), + IntegrityStmt::Constraint(IntegrityConstraint::new( + Elem(Identifier("a".to_string())), + Add( + Box::new(Elem(Identifier("x".to_string()))), + Box::new(Elem(Identifier("y".to_string()))), + ), + )), + ]), + ]); + + build_parse_test!(source).expect_ast(expected); +} + +// INVALID LIST FOLDING +// ================================================================================================ + +#[test] +fn err_ic_lf_single_members_double_iterables() { + let source = " + trace_columns: + main: [a, b, c[4]] + + integrity_constraints: + let x = sum([c for c in (c, d)]) + enf a = x"; + + let error = Error::ParseError(ParseError::InvalidListComprehension( + "Number of members and iterables must match".to_string(), + )); + build_parse_test!(source).expect_error(error); +}