From 51529a6f0f40ec7dc6c9e70e671672a1ad040452 Mon Sep 17 00:00:00 2001 From: tohrnii <100405913+tohrnii@users.noreply.github.com> Date: Wed, 15 Feb 2023 18:23:42 +0000 Subject: [PATCH] refactor: add iterable context --- ir/src/constraints/list_comprehension.rs | 320 +++++++++++++---------- ir/src/tests/mod.rs | 2 +- 2 files changed, 179 insertions(+), 143 deletions(-) diff --git a/ir/src/constraints/list_comprehension.rs b/ir/src/constraints/list_comprehension.rs index 2b09ca4d9..2f82dcf7f 100644 --- a/ir/src/constraints/list_comprehension.rs +++ b/ir/src/constraints/list_comprehension.rs @@ -1,12 +1,13 @@ -use super::{ - graph::CURRENT_ROW, - IdentifierType, SemanticError, SymbolTable, -}; +use std::collections::BTreeMap; + +use super::{graph::CURRENT_ROW, IdentifierType, SemanticError, SymbolTable}; use air_script_core::{ Expression, Identifier, Iterable, ListComprehension, ListFoldingType, NamedTraceAccess, VariableType, VectorAccess, }; +type IterableContext = BTreeMap; + /// Unfolds a list comprehension into a vector of expressions. /// /// Returns an error if there is an error while parsing any of the expressions in the expanded @@ -16,8 +17,19 @@ pub fn unfold_lc( symbol_table: &SymbolTable, ) -> Result, SemanticError> { let lc_length = get_num_iterations(lc, symbol_table)?; + let mut iterable_context = IterableContext::new(); + for (member, iterable) in lc.context() { + if iterable_context + .insert(member.clone(), iterable.clone()) + .is_some() + { + Err(SemanticError::InvalidListComprehension( + "Duplicate member in list comprehension".to_string(), + ))? + } + } let vector = (0..lc_length) - .map(|i| parse_lc_expr(lc.expression(), lc, symbol_table, i)) + .map(|i| parse_lc_expr(lc.expression(), &iterable_context, symbol_table, i)) .collect::, _>>()?; Ok(vector) } @@ -28,33 +40,37 @@ pub fn unfold_lc( /// Returns an error if there is an error while parsing the sub-expression. fn parse_lc_expr( expression: &Expression, - lc: &ListComprehension, + iterable_context: &IterableContext, symbol_table: &SymbolTable, i: usize, ) -> Result { match expression { - Expression::Elem(ident) => parse_elem(ident, expression, lc, symbol_table, i), - Expression::NamedTraceAccess(named_trace_access) => { - parse_named_trace_access(named_trace_access, expression, lc, symbol_table, i) - } + Expression::Elem(ident) => parse_elem(ident, expression, iterable_context, symbol_table, i), + Expression::NamedTraceAccess(named_trace_access) => parse_named_trace_access( + named_trace_access, + expression, + iterable_context, + symbol_table, + i, + ), Expression::Add(lhs, rhs) => { - let lhs = parse_lc_expr(lhs, lc, symbol_table, i)?; - let rhs = parse_lc_expr(rhs, lc, symbol_table, i)?; + let lhs = parse_lc_expr(lhs, iterable_context, symbol_table, i)?; + let rhs = parse_lc_expr(rhs, iterable_context, symbol_table, i)?; Ok(Expression::Add(Box::new(lhs), Box::new(rhs))) } Expression::Sub(lhs, rhs) => { - let lhs = parse_lc_expr(lhs, lc, symbol_table, i)?; - let rhs = parse_lc_expr(rhs, lc, symbol_table, i)?; + let lhs = parse_lc_expr(lhs, iterable_context, symbol_table, i)?; + let rhs = parse_lc_expr(rhs, iterable_context, symbol_table, i)?; Ok(Expression::Sub(Box::new(lhs), Box::new(rhs))) } Expression::Mul(lhs, rhs) => { - let lhs = parse_lc_expr(lhs, lc, symbol_table, i)?; - let rhs = parse_lc_expr(rhs, lc, symbol_table, i)?; + let lhs = parse_lc_expr(lhs, iterable_context, symbol_table, i)?; + let rhs = parse_lc_expr(rhs, iterable_context, symbol_table, i)?; Ok(Expression::Mul(Box::new(lhs), Box::new(rhs))) } Expression::Exp(lhs, rhs) => { - let lhs = parse_lc_expr(lhs, lc, symbol_table, i)?; - let rhs = parse_lc_expr(rhs, lc, symbol_table, i)?; + let lhs = parse_lc_expr(lhs, iterable_context, symbol_table, i)?; + let rhs = parse_lc_expr(rhs, iterable_context, symbol_table, i)?; Ok(Expression::Exp(Box::new(lhs), Box::new(rhs))) } Expression::ListFolding(lf_type) => { @@ -78,99 +94,98 @@ fn parse_lc_expr( fn parse_elem( ident: &Identifier, expression: &Expression, - lc: &ListComprehension, + iterable_context: &IterableContext, symbol_table: &SymbolTable, i: usize, ) -> Result { - for (member, iterable) in lc.context() { - if ident == member { - match iterable { - Iterable::Identifier(ident) => { - let ident_type = symbol_table.get_type(ident.name())?; - match ident_type { - IdentifierType::TraceColumns(trace_columns) => { - validate_access(i, trace_columns.size())?; - return Ok(Expression::NamedTraceAccess(NamedTraceAccess::new( - ident.clone(), - i, - CURRENT_ROW, - ))); - } - IdentifierType::IntegrityVariable(var_type) => { - match var_type.value() { - VariableType::Vector(vector) => { - validate_access(i, vector.len())?; - return Ok(vector[i].clone()); - } - // TODO: Handle matrix access - _ => Err(SemanticError::InvalidListComprehension( - "Iterables should be vectors".to_string(), - ))?, + let iterable = iterable_context.get(ident); + if let Some(iterable_type) = iterable { + match iterable_type { + Iterable::Identifier(ident) => { + let ident_type = symbol_table.get_type(ident.name())?; + match ident_type { + IdentifierType::TraceColumns(trace_columns) => { + validate_access(i, trace_columns.size())?; + return Ok(Expression::NamedTraceAccess(NamedTraceAccess::new( + ident.clone(), + i, + CURRENT_ROW, + ))); + } + IdentifierType::IntegrityVariable(var_type) => { + match var_type.value() { + VariableType::Vector(vector) => { + validate_access(i, vector.len())?; + return Ok(vector[i].clone()); } + // TODO: Handle matrix access + _ => Err(SemanticError::InvalidListComprehension( + "Iterables should be vectors".to_string(), + ))?, } - IdentifierType::PublicInput(size) => { - validate_access(i, *size)?; - return Ok(Expression::VectorAccess(VectorAccess::new( - ident.clone(), - i, - ))); - } - IdentifierType::RandomValuesBinding(_, size) => { - validate_access(i, *size)?; - return Ok(Expression::VectorAccess(VectorAccess::new( - ident.clone(), - i, - ))); - } - _ => Err(SemanticError::InvalidListComprehension( - "Invalid type for a vector".to_string(), - ))?, } + IdentifierType::PublicInput(size) => { + validate_access(i, *size)?; + return Ok(Expression::VectorAccess(VectorAccess::new( + ident.clone(), + i, + ))); + } + IdentifierType::RandomValuesBinding(_, size) => { + validate_access(i, *size)?; + return Ok(Expression::VectorAccess(VectorAccess::new( + ident.clone(), + i, + ))); + } + _ => Err(SemanticError::InvalidListComprehension( + "Invalid type for a vector".to_string(), + ))?, } - Iterable::Range(range) => { - return Ok(Expression::Const((range.start() + i) as u64)); - } - Iterable::Slice(ident, range) => { - let ident_type = symbol_table.get_type(ident.name())?; - match ident_type { - IdentifierType::TraceColumns(trace_columns) => { - validate_access(i, trace_columns.size())?; - return Ok(Expression::NamedTraceAccess(NamedTraceAccess::new( - ident.clone(), - range.start() + i, - CURRENT_ROW, - ))); - } - IdentifierType::IntegrityVariable(var_type) => { - match var_type.value() { - VariableType::Vector(vector) => { - validate_access(i, vector.len())?; - return Ok(vector[range.start() + i].clone()); - } - // TODO: Handle matrix access - _ => Err(SemanticError::InvalidListComprehension( - "Iterables should be vectors".to_string(), - ))?, + } + Iterable::Range(range) => { + return Ok(Expression::Const((range.start() + i) as u64)); + } + Iterable::Slice(ident, range) => { + let ident_type = symbol_table.get_type(ident.name())?; + match ident_type { + IdentifierType::TraceColumns(trace_columns) => { + validate_access(i, trace_columns.size())?; + return Ok(Expression::NamedTraceAccess(NamedTraceAccess::new( + ident.clone(), + range.start() + i, + CURRENT_ROW, + ))); + } + IdentifierType::IntegrityVariable(var_type) => { + match var_type.value() { + VariableType::Vector(vector) => { + validate_access(i, vector.len())?; + return Ok(vector[range.start() + i].clone()); } + // TODO: Handle matrix access + _ => Err(SemanticError::InvalidListComprehension( + "Iterables should be vectors".to_string(), + ))?, } - IdentifierType::PublicInput(size) => { - validate_access(i, *size)?; - return Ok(Expression::VectorAccess(VectorAccess::new( - ident.clone(), - range.start() + i, - ))); - } - IdentifierType::RandomValuesBinding(_, size) => { - validate_access(i, *size)?; - return Ok(Expression::VectorAccess(VectorAccess::new( - ident.clone(), - range.start() + i, - ))); - } - _ => Err(SemanticError::InvalidListComprehension( - "Invalid type for a vector".to_string(), - ))?, } + IdentifierType::PublicInput(size) => { + validate_access(i, *size)?; + return Ok(Expression::VectorAccess(VectorAccess::new( + ident.clone(), + range.start() + i, + ))); + } + IdentifierType::RandomValuesBinding(_, size) => { + validate_access(i, *size)?; + return Ok(Expression::VectorAccess(VectorAccess::new( + ident.clone(), + range.start() + i, + ))); + } + _ => Err(SemanticError::InvalidListComprehension( + "Invalid type for a vector".to_string(), + ))?, } } } @@ -189,49 +204,48 @@ fn parse_elem( fn parse_named_trace_access( named_trace_access: &NamedTraceAccess, expression: &Expression, - lc: &ListComprehension, + iterable_context: &IterableContext, symbol_table: &SymbolTable, i: usize, ) -> Result { - for (member, iterable) in lc.context() { - if named_trace_access.name() == member.name() { - match iterable { - Iterable::Identifier(ident) => { - let ident_type = symbol_table.get_type(ident.name())?; - match ident_type { - IdentifierType::TraceColumns(size) => { - validate_access(i, size.size())?; - return Ok(Expression::NamedTraceAccess(NamedTraceAccess::new( - ident.clone(), - i, - named_trace_access.row_offset(), - ))); - } - _ => Err(SemanticError::InvalidListComprehension( - "Iterable should be a trace column".to_string(), - ))?, + let iterable = iterable_context.get(&Identifier(named_trace_access.name().to_string())); + if let Some(iterable_type) = iterable { + match iterable_type { + Iterable::Identifier(ident) => { + let ident_type = symbol_table.get_type(ident.name())?; + match ident_type { + IdentifierType::TraceColumns(size) => { + validate_access(i, size.size())?; + return Ok(Expression::NamedTraceAccess(NamedTraceAccess::new( + ident.clone(), + i, + named_trace_access.row_offset(), + ))); } + _ => Err(SemanticError::InvalidListComprehension( + "Iterable should be a trace column".to_string(), + ))?, } - Iterable::Range(_) => { - return Err(SemanticError::InvalidListComprehension( - "Iterable cannot be of range type here".to_string(), - )); - } - Iterable::Slice(ident, range) => { - let ident_type = symbol_table.get_type(ident.name())?; - match ident_type { - IdentifierType::TraceColumns(trace_columns) => { - validate_access(i, trace_columns.size())?; - return Ok(Expression::NamedTraceAccess(NamedTraceAccess::new( - ident.clone(), - range.start() + i, - named_trace_access.row_offset(), - ))); - } - _ => Err(SemanticError::InvalidListComprehension( - "Iterable should be a trace column".to_string(), - ))?, + } + Iterable::Range(_) => { + return Err(SemanticError::InvalidListComprehension( + "Iterable cannot be of range type here".to_string(), + )); + } + Iterable::Slice(ident, range) => { + let ident_type = symbol_table.get_type(ident.name())?; + match ident_type { + IdentifierType::TraceColumns(trace_columns) => { + validate_access(i, trace_columns.size())?; + return Ok(Expression::NamedTraceAccess(NamedTraceAccess::new( + ident.clone(), + range.start() + i, + named_trace_access.row_offset(), + ))); } + _ => Err(SemanticError::InvalidListComprehension( + "Iterable should be a trace column".to_string(), + ))?, } } } @@ -248,19 +262,41 @@ fn parse_list_folding( ) -> Result { match lf_type { ListFoldingType::Sum(lc) => { + let mut iterable_context = IterableContext::new(); + for (member, iterable) in lc.context() { + if iterable_context + .insert(member.clone(), iterable.clone()) + .is_some() + { + Err(SemanticError::InvalidListComprehension( + "Duplicate member in list comprehension".to_string(), + ))? + } + } let list = unfold_lc(lc, symbol_table)?; - let mut sum = parse_lc_expr(expression, lc, symbol_table, i)?; + let mut sum = parse_lc_expr(expression, &iterable_context, symbol_table, i)?; for elem in list.iter().skip(1) { - let expr = parse_lc_expr(elem, lc, symbol_table, i)?; + let expr = parse_lc_expr(elem, &iterable_context, symbol_table, i)?; sum = Expression::Add(Box::new(sum), Box::new(expr)); } Ok(sum) } ListFoldingType::Prod(lc) => { + let mut iterable_context = IterableContext::new(); + for (member, iterable) in lc.context() { + if iterable_context + .insert(member.clone(), iterable.clone()) + .is_some() + { + Err(SemanticError::InvalidListComprehension( + "Duplicate member in list comprehension".to_string(), + ))? + } + } let list = unfold_lc(lc, symbol_table)?; - let mut prod = parse_lc_expr(expression, lc, symbol_table, i)?; + let mut prod = parse_lc_expr(expression, &iterable_context, symbol_table, i)?; for elem in list.iter().skip(1) { - let expr = parse_lc_expr(elem, lc, symbol_table, i)?; + let expr = parse_lc_expr(elem, &iterable_context, symbol_table, i)?; prod = Expression::Mul(Box::new(prod), Box::new(expr)); } Ok(prod) diff --git a/ir/src/tests/mod.rs b/ir/src/tests/mod.rs index c8ebae8f8..1a138839b 100644 --- a/ir/src/tests/mod.rs +++ b/ir/src/tests/mod.rs @@ -1049,7 +1049,7 @@ fn lf_in_lc() { let parsed = parse(source).expect("Parsing failed"); let result = AirIR::from_source(&parsed); - + println!("{:?}", result); assert!(result.is_ok()); }