From 0424320469b526ddc79287b27fc025c5fbfef306 Mon Sep 17 00:00:00 2001 From: tohrnii <100405913+tohrnii@users.noreply.github.com> Date: Wed, 14 Dec 2022 17:08:56 +0000 Subject: [PATCH] refactor(ir): Refactor IR - Remove variables for boundary constraints - Remove variables graph - Remove variables vector --- air-script/tests/variables/variables.air | 18 +--- air-script/tests/variables/variables.rs | 29 ++--- .../src/air/boundary_constraints.rs | 60 +---------- .../src/air/transition_constraints.rs | 86 +-------------- ir/src/boundary_stmts.rs | 79 +++++--------- ir/src/lib.rs | 24 +---- ir/src/symbol_table.rs | 59 +--------- ir/src/tests/mod.rs | 52 +-------- ir/src/transition_stmts/graph.rs | 91 +++++++++------- ir/src/transition_stmts/mod.rs | 101 +++--------------- 10 files changed, 113 insertions(+), 486 deletions(-) diff --git a/air-script/tests/variables/variables.air b/air-script/tests/variables/variables.air index a41b12c80..4f37a2ee4 100644 --- a/air-script/tests/variables/variables.air +++ b/air-script/tests/variables/variables.air @@ -14,28 +14,18 @@ periodic_columns: k0: [1, 1, 1, 1, 1, 1, 1, 0] boundary_constraints: - # define boundary constraints against the main trace at the first row of the trace. - let x = 1 enf a.first = stack_inputs[0] - let y = [x, 4 - 2] - enf b.first = x - enf c.first = y[0] - let z = [[x, 3], [4 - 2, 8 + 8]] - # define boundary constraints against the main trace at the last row of the trace. - enf a.last = stack_outputs[0] - enf b.last = z[0][0] - enf c.last = stack_outputs[2] - - # set the first row of the auxiliary column p to 1 - enf p.first = 1 - + enf a.last = 1 + transition_constraints: let m = 0 + # the selector must be binary. enf s^2 = s let n = [2 * 3, s] let o = [[s', 3], [4 - 2, 8 + 8]] + # selector should stay the same for all rows of an 8-row cycle. enf k0 * (s' - s) = m diff --git a/air-script/tests/variables/variables.rs b/air-script/tests/variables/variables.rs index dc01eee59..4583eedd3 100644 --- a/air-script/tests/variables/variables.rs +++ b/air-script/tests/variables/variables.rs @@ -45,8 +45,8 @@ impl Air for VariablesAir { fn new(trace_info: TraceInfo, public_inputs: PublicInputs, options: WinterProofOptions) -> Self { let main_degrees = vec![TransitionConstraintDegree::new(2), TransitionConstraintDegree::with_cycles(1, vec![8]), TransitionConstraintDegree::new(2), TransitionConstraintDegree::new(3)]; let aux_degrees = vec![TransitionConstraintDegree::new(2)]; - let num_main_assertions = 6; - let num_aux_assertions = 1; + let num_main_assertions = 2; + let num_aux_assertions = 0; let context = AirContext::new_multi_segment( trace_info, @@ -65,48 +65,31 @@ impl Air for VariablesAir { } fn get_assertions(&self) -> Vec> { - let x = Felt::new(1); - let y = [x, (Felt::new(4)) - (Felt::new(2))]; - let z = [[x, Felt::new(3)], [(Felt::new(4)) - (Felt::new(2)), (Felt::new(8)) + (Felt::new(8))]]; let mut result = Vec::new(); result.push(Assertion::single(1, 0, self.stack_inputs[0])); - result.push(Assertion::single(2, 0, x)); - result.push(Assertion::single(3, 0, y[0])); let last_step = self.last_step(); - result.push(Assertion::single(1, last_step, self.stack_outputs[0])); - result.push(Assertion::single(2, last_step, z[0][0])); - result.push(Assertion::single(3, last_step, self.stack_outputs[2])); + result.push(Assertion::single(1, last_step, Felt::new(1))); result } fn get_aux_assertions>(&self, aux_rand_elements: &AuxTraceRandElements) -> Vec> { - let x = E::from(1_u64); - let y = [E::from(x), (E::from(4_u64)) - (E::from(2_u64))]; - let z = [[E::from(x), E::from(3_u64)], [(E::from(4_u64)) - (E::from(2_u64)), (E::from(8_u64)) + (E::from(8_u64))]]; let mut result = Vec::new(); - result.push(Assertion::single(0, 0, E::from(1_u64))); result } fn evaluate_transition>(&self, frame: &EvaluationFrame, periodic_values: &[E], result: &mut [E]) { - let m = E::from(0_u64); - let n = [(E::from(2_u64)) * (E::from(3_u64)), current[0]]; - let o = [[next[0], E::from(3_u64)], [E::from(4_u64) - (E::from(2_u64)), E::from(8_u64) + E::from(8_u64)]]; let current = frame.current(); let next = frame.next(); result[0] = (current[0]).exp(E::PositiveInteger::from(2_u64)) - (current[0]); - result[1] = (periodic_values[0]) * (next[0] - (current[0])) - (m); - result[2] = (E::from(1_u64) - (current[0])) * (current[3] - (current[1]) + current[2]) - (n[0] - (n[1])); - result[3] = (current[0]) * (current[3] - ((current[1]) * (current[2]))) - (o[0][0] - (o[0][1]) - (o[1][0])); + result[1] = (periodic_values[0]) * (next[0] - (current[0])) - (E::from(0_u64)); + result[2] = (E::from(1_u64) - (current[0])) * (current[3] - (current[1]) + current[2]) - ((E::from(2_u64)) * (E::from(3_u64)) - (current[0])); + result[3] = (current[0]) * (current[3] - ((current[1]) * (current[2]))) - (next[0] - (E::from(3_u64)) - (E::from(4_u64) - (E::from(2_u64)))); } fn evaluate_aux_transition(&self, main_frame: &EvaluationFrame, aux_frame: &EvaluationFrame, _periodic_values: &[F], aux_rand_elements: &AuxTraceRandElements, result: &mut [E]) where F: FieldElement, E: FieldElement + ExtensionOf, { - let m = E::from(0_u64); - let n = [(E::from(2_u64)) * (E::from(3_u64)), current[0]]; - let o = [[next[0], E::from(3_u64)], [E::from(4_u64) - (E::from(2_u64)), E::from(8_u64) + E::from(8_u64)]]; let current = aux_frame.current(); let next = aux_frame.next(); result[0] = next[0] - ((current[0]) * (current[3] + aux_rand_elements.get_segment_elements(0)[0])); diff --git a/codegen/winterfell/src/air/boundary_constraints.rs b/codegen/winterfell/src/air/boundary_constraints.rs index a42e72ca0..472ea85d2 100644 --- a/codegen/winterfell/src/air/boundary_constraints.rs +++ b/codegen/winterfell/src/air/boundary_constraints.rs @@ -1,5 +1,5 @@ use super::{AirIR, Codegen, Impl}; -use ir::{ast::BoundaryVariableType, BoundaryExpr}; +use ir::BoundaryExpr; // HELPERS TO GENERATE THE WINTERFELL BOUNDARY CONSTRAINT METHODS // ================================================================================================ @@ -13,12 +13,6 @@ pub(super) fn add_fn_get_assertions(impl_ref: &mut Impl, ir: &AirIR) { .arg_ref_self() .ret("Vec>"); - // TODO: Only add variables used in main trace assertions - let variables = add_variables(ir, false); - for variable in variables { - get_assertions.line(variable); - } - // declare the result vector to be returned. get_assertions.line("let mut result = Vec::new();"); @@ -62,12 +56,6 @@ pub(super) fn add_fn_get_aux_assertions(impl_ref: &mut Impl, ir: &AirIR) { .arg("aux_rand_elements", "&AuxTraceRandElements") .ret("Vec>"); - // TODO: Only add variables used in aux trace assertions - let variables = add_variables(ir, true); - for variable in variables { - get_aux_assertions.line(variable); - } - // declare the result vector to be returned. get_aux_assertions.line("let mut result = Vec::new();"); @@ -100,52 +88,6 @@ pub(super) fn add_fn_get_aux_assertions(impl_ref: &mut Impl, ir: &AirIR) { get_aux_assertions.line("result"); } -/// A helper function to add variable definitions to the enforce_assertions and -/// enforce_aux_assertions functions. -fn add_variables(ir: &AirIR, is_aux_constraint: bool) -> Vec { - let mut vars = Vec::new(); - let variables = ir.boundary_variables(); - for variable in variables { - let variable_name = variable.name(); - let variable_def = match variable.value() { - BoundaryVariableType::Scalar(expr) => { - format!( - "let {} = {};", - variable_name, - expr.to_string(ir, is_aux_constraint) - ) - } - BoundaryVariableType::Vector(vector) => format!( - "let {} = [{}];", - variable_name, - vector - .iter() - .map(|expr| expr.to_string(ir, is_aux_constraint)) - .collect::>() - .join(", ") - ), - BoundaryVariableType::Matrix(matrix) => { - let variable_value = { - let mut rows = vec![]; - for row in matrix { - rows.push(format!( - "[{}]", - row.iter() - .map(|expr| expr.to_string(ir, is_aux_constraint)) - .collect::>() - .join(", "), - )) - } - format!("[{}]", rows.join(", ")) - }; - format!("let {} = {};", variable_name, variable_value) - } - }; - vars.push(variable_def); - } - vars -} - // RUST STRING GENERATION // ================================================================================================ diff --git a/codegen/winterfell/src/air/transition_constraints.rs b/codegen/winterfell/src/air/transition_constraints.rs index 4f985c1fd..51c3f61d8 100644 --- a/codegen/winterfell/src/air/transition_constraints.rs +++ b/codegen/winterfell/src/air/transition_constraints.rs @@ -1,8 +1,7 @@ use super::{AirIR, Impl}; use ir::{ - ast::{MatrixAccess, TransitionVariableType, VectorAccess}, - transition_stmts::{AlgebraicGraph, ConstantValue, Operation, VariableValue}, - Identifier, NodeIndex, + transition_stmts::{AlgebraicGraph, ConstantValue, Operation}, + NodeIndex, }; // HELPERS TO GENERATE THE WINTERFELL TRANSITION CONSTRAINT METHODS @@ -20,12 +19,6 @@ pub(super) fn add_fn_evaluate_transition(impl_ref: &mut Impl, ir: &AirIR) { .arg("periodic_values", "&[E]") .arg("result", "&mut [E]"); - // TODO: Only add variables used in main trace assertions - let variables = add_variables(ir); - for variable in variables { - evaluate_transition.line(variable); - } - // declare current and next trace row arrays. evaluate_transition.line("let current = frame.current();"); evaluate_transition.line("let next = frame.next();"); @@ -57,12 +50,6 @@ pub(super) fn add_fn_evaluate_aux_transition(impl_ref: &mut Impl, ir: &AirIR) { .bound("F", "FieldElement") .bound("E", "FieldElement + ExtensionOf"); - // TODO: Only add variables used in aux trace assertions. - let variables = add_variables(ir); - for variable in variables { - evaluate_aux_transition.line(variable); - } - // declare current and next trace row arrays. evaluate_aux_transition.line("let current = aux_frame.current();"); evaluate_aux_transition.line("let next = aux_frame.next();"); @@ -78,65 +65,6 @@ pub(super) fn add_fn_evaluate_aux_transition(impl_ref: &mut Impl, ir: &AirIR) { } } -/// A helper function to add variable definitions to the evaluate_transition and -/// evaluate_aux_transition functions. -fn add_variables(ir: &AirIR) -> Vec { - let mut vars = Vec::new(); - let variables = ir.transition_variables(); - let variables_graph = ir.variables_graph(); - for variable in variables { - let variable_name = variable.name(); - let variable_def = match variable.value() { - TransitionVariableType::Scalar(_) => { - let key = VariableValue::Scalar(variable_name.to_string()); - let variable_value = ir.variable_roots().get(&key).unwrap_or_else(|| { - panic!("Variable {} not found in variable_roots map", variable_name) - }); - format!( - "let {} = {};", - variable_name, - variable_value.to_string(variables_graph) - ) - } - TransitionVariableType::Vector(vector) => { - let mut vector_str = Vec::new(); - for idx in 0..vector.len() { - let key = VariableValue::Vector(VectorAccess::new( - Identifier(variable_name.to_string()), - idx, - )); - let variable_value = ir.variable_roots().get(&key).unwrap_or_else(|| { - panic!("Variable {} not found in variable_roots map", variable_name) - }); - vector_str.push(variable_value.to_string(variables_graph)); - } - format!("let {} = [{}];", variable_name, vector_str.join(", ")) - } - TransitionVariableType::Matrix(matrix) => { - let mut rows = Vec::new(); - for row_idx in 0..matrix.len() { - let mut cols = Vec::new(); - for col_idx in 0..matrix[0].len() { - let key = VariableValue::Matrix(MatrixAccess::new( - Identifier(variable_name.to_string()), - row_idx, - col_idx, - )); - let variable_value = ir.variable_roots().get(&key).unwrap_or_else(|| { - panic!("Variable {} not found in variable_roots map", variable_name) - }); - cols.push(variable_value.to_string(variables_graph)); - } - rows.push(format!("[{}]", cols.join(", "))); - } - format!("let {} = [{}];", variable_name, rows.join(", ")) - } - }; - vars.push(variable_def); - } - vars -} - /// Code generation trait for generating Rust code strings from [AlgebraicGraph] types. trait Codegen { fn to_string(&self, graph: &AlgebraicGraph) -> String; @@ -164,16 +92,6 @@ impl Codegen for Operation { matrix_access.row_idx(), matrix_access.col_idx() ), - Operation::Variable(VariableValue::Scalar(ident), _) => ident.to_string(), - Operation::Variable(VariableValue::Vector(vector_access), _) => { - format!("{}[{}]", vector_access.name(), vector_access.idx()) - } - Operation::Variable(VariableValue::Matrix(matrix_access), _) => format!( - "{}[{}][{}]", - matrix_access.name(), - matrix_access.row_idx(), - matrix_access.col_idx() - ), Operation::TraceElement(trace_access) => match trace_access.row_offset() { 0 => { format!("current[{}]", trace_access.col_idx()) diff --git a/ir/src/boundary_stmts.rs b/ir/src/boundary_stmts.rs index e74157b6f..8f9423390 100644 --- a/ir/src/boundary_stmts.rs +++ b/ir/src/boundary_stmts.rs @@ -1,7 +1,7 @@ use crate::TraceSegment; use super::{BTreeMap, BoundaryExpr, IdentifierType, SemanticError, SymbolTable}; -use parser::ast::{self, BoundaryStmt, BoundaryVariable, BoundaryVariableType}; +use parser::ast::{self, BoundaryStmt}; // BOUNDARY CONSTRAINTS // ================================================================================================ @@ -10,14 +10,18 @@ use parser::ast::{self, BoundaryStmt, BoundaryVariable, BoundaryVariableType}; /// boundaries (first row and last row). For ease of code generation and evaluation, constraints are /// sorted into maps by the boundary. This also simplifies ensuring that there are no conflicting /// constraints sharing a boundary and column index. -/// TODO: generalize the way we store boundary constraints for more trace segments. #[derive(Default, Debug)] pub(crate) struct BoundaryStmts { boundary_constraints: Vec<(BTreeMap, BTreeMap)>, - variables: Vec, } impl BoundaryStmts { + pub fn new(num_trace_segments: usize) -> Self { + Self { + boundary_constraints: vec![(BTreeMap::new(), BTreeMap::new()); num_trace_segments], + } + } + // --- ACCESSORS ------------------------------------------------------------------------------ pub fn num_boundary_constraints(&self, trace_segment: TraceSegment) -> usize { @@ -47,10 +51,6 @@ impl BoundaryStmts { .collect() } - pub fn variables(&self) -> &Vec { - &self.variables - } - // --- MUTATORS ------------------------------------------------------------------------------- /// Add a boundary statement from the AST. The statement can be either be a variable or a @@ -70,26 +70,7 @@ impl BoundaryStmts { stmt: &BoundaryStmt, ) -> Result<(), SemanticError> { match stmt { - BoundaryStmt::Variable(boundary_variable) => { - // validate the expressions inside the variable's values - match boundary_variable.value() { - BoundaryVariableType::Scalar(expr) => validate_expression(symbol_table, expr)?, - BoundaryVariableType::Vector(vector) => { - for expr in vector { - validate_expression(symbol_table, expr)?; - } - } - BoundaryVariableType::Matrix(matrix) => { - for row in matrix { - for expr in row { - validate_expression(symbol_table, expr)?; - } - } - } - } - symbol_table.insert_boundary_variable(boundary_variable)?; - self.variables.push(boundary_variable.clone()); - } + BoundaryStmt::Variable(_) => unimplemented!(), BoundaryStmt::Constraint(constraint) => { // validate the expression let expr = constraint.value(); @@ -99,34 +80,22 @@ impl BoundaryStmts { let col_type = symbol_table.get_type(constraint.column())?; let result = match col_type { IdentifierType::TraceColumn(column) => match column.trace_segment() { - 0 => { - if self.boundary_constraints.is_empty() { - self.boundary_constraints - .push((BTreeMap::default(), BTreeMap::default())); - } - match constraint.boundary() { - ast::Boundary::First => self.boundary_constraints[0] - .0 - .insert(column.col_idx(), expr), - ast::Boundary::Last => self.boundary_constraints[0] - .1 - .insert(column.col_idx(), expr), - } - } - 1 => { - if self.boundary_constraints.len() == 1 { - self.boundary_constraints - .push((BTreeMap::default(), BTreeMap::default())); - } - match constraint.boundary() { - ast::Boundary::First => self.boundary_constraints[1] - .0 - .insert(column.col_idx(), expr), - ast::Boundary::Last => self.boundary_constraints[1] - .1 - .insert(column.col_idx(), expr), - } - } + 0 => match constraint.boundary() { + ast::Boundary::First => self.boundary_constraints[0] + .0 + .insert(column.col_idx(), expr), + ast::Boundary::Last => self.boundary_constraints[0] + .1 + .insert(column.col_idx(), expr), + }, + 1 => match constraint.boundary() { + ast::Boundary::First => self.boundary_constraints[1] + .0 + .insert(column.col_idx(), expr), + ast::Boundary::Last => self.boundary_constraints[1] + .1 + .insert(column.col_idx(), expr), + }, _ => unimplemented!(), }, _ => { diff --git a/ir/src/lib.rs b/ir/src/lib.rs index d6270ed0d..900c95fb5 100644 --- a/ir/src/lib.rs +++ b/ir/src/lib.rs @@ -1,7 +1,7 @@ pub use parser::ast::{ self, boundary_constraints::BoundaryExpr, constants::Constant, Identifier, PublicInput, + TransitionVariable, }; -use parser::ast::{BoundaryVariable, TransitionVariable}; use std::collections::BTreeMap; mod symbol_table; @@ -11,7 +11,7 @@ pub mod boundary_stmts; use boundary_stmts::BoundaryStmts; pub mod transition_stmts; -use transition_stmts::{AlgebraicGraph, TransitionStmts, VariableValue, MIN_CYCLE_LENGTH}; +use transition_stmts::{AlgebraicGraph, TransitionStmts, MIN_CYCLE_LENGTH}; pub use transition_stmts::{NodeIndex, TransitionConstraintDegree}; mod error; @@ -34,8 +34,6 @@ pub type PeriodicColumns = Vec>; #[derive(Default, Debug)] pub struct AirIR { air_name: String, - //TODO: remove dead code attribute - #[allow(dead_code)] constants: Constants, num_trace_segments: usize, public_inputs: PublicInputs, @@ -93,7 +91,7 @@ impl AirIR { let num_trace_segments = symbol_table.num_trace_segments(); // then process the constraints & validate them against the symbol table. - let mut boundary_stmts = BoundaryStmts::default(); + let mut boundary_stmts = BoundaryStmts::new(num_trace_segments); let mut transition_stmts = TransitionStmts::new(num_trace_segments); for section in source { match section { @@ -185,10 +183,6 @@ impl AirIR { } } - pub fn boundary_variables(&self) -> &Vec { - self.boundary_stmts.variables() - } - // --- PUBLIC ACCESSORS FOR TRANSITION CONSTRAINTS -------------------------------------------- pub fn constraint_degrees( @@ -205,16 +199,4 @@ impl AirIR { pub fn transition_graph(&self) -> &AlgebraicGraph { self.transition_stmts.graph() } - - pub fn transition_variables(&self) -> &Vec { - self.transition_stmts.variables() - } - - pub fn variable_roots(&self) -> &BTreeMap { - self.transition_stmts.variable_roots() - } - - pub fn variables_graph(&self) -> &AlgebraicGraph { - self.transition_stmts.variables_graph() - } } diff --git a/ir/src/symbol_table.rs b/ir/src/symbol_table.rs index 3f6698c21..72621bc2e 100644 --- a/ir/src/symbol_table.rs +++ b/ir/src/symbol_table.rs @@ -4,8 +4,8 @@ use super::{ }; use parser::ast::{ constants::{Constant, ConstantType}, - BoundaryVariable, BoundaryVariableType, Identifier, MatrixAccess, PeriodicColumn, PublicInput, - TransitionVariable, TransitionVariableType, VectorAccess, + Identifier, MatrixAccess, PeriodicColumn, PublicInput, TransitionVariable, + TransitionVariableType, VectorAccess, }; use std::fmt::Display; @@ -21,7 +21,7 @@ pub(super) enum IdentifierType { /// an identifier for a periodic column, containing its index out of all periodic columns and /// its cycle length in that order. PeriodicColumn(usize, usize), - BoundaryVariable(BoundaryVariable), + /// an identifier for a transition variable, containing its name and value TransitionVariable(TransitionVariable), } @@ -34,7 +34,6 @@ impl Display for IdentifierType { Self::TraceColumn(column) => { write!(f, "TraceColumn in segment {}", column.trace_segment()) } - Self::BoundaryVariable(_) => write!(f, "BoundaryVariable"), Self::TransitionVariable(_) => write!(f, "TransitionVariable"), } } @@ -183,17 +182,6 @@ impl SymbolTable { Ok(()) } - pub(super) fn insert_boundary_variable( - &mut self, - variable: &BoundaryVariable, - ) -> Result<(), SemanticError> { - self.insert_symbol( - variable.name(), - IdentifierType::BoundaryVariable(variable.clone()), - )?; - Ok(()) - } - pub(super) fn insert_transition_variable( &mut self, variable: &TransitionVariable, @@ -268,23 +256,6 @@ impl SymbolTable { )) } } - IdentifierType::BoundaryVariable(boundary_variable) => { - if let BoundaryVariableType::Vector(vector) = boundary_variable.value() { - if vector_access.idx() < vector.len() { - Ok(symbol_type) - } else { - Err(SemanticError::vector_access_out_of_bounds( - vector_access, - vector.len(), - )) - } - } else { - Err(SemanticError::invalid_vector_access( - vector_access, - symbol_type, - )) - } - } IdentifierType::TransitionVariable(transition_variable) => { if let TransitionVariableType::Vector(vector) = transition_variable.value() { if vector_access.idx() < vector.len() { @@ -340,30 +311,6 @@ impl SymbolTable { } Ok(symbol_type) } - IdentifierType::BoundaryVariable(boundary_variable) => { - if let BoundaryVariableType::Matrix(matrix) = boundary_variable.value() { - if matrix_access.row_idx() >= matrix.len() { - return Err(SemanticError::matrix_access_out_of_bounds( - matrix_access, - matrix.len(), - matrix[0].len(), - )); - } - if matrix_access.col_idx() >= matrix[0].len() { - return Err(SemanticError::matrix_access_out_of_bounds( - matrix_access, - matrix.len(), - matrix[0].len(), - )); - } - Ok(symbol_type) - } else { - Err(SemanticError::invalid_matrix_access( - matrix_access, - symbol_type, - )) - } - } IdentifierType::TransitionVariable(transition_variable) => { if let TransitionVariableType::Matrix(matrix) = transition_variable.value() { if matrix_access.row_idx() >= matrix.len() { diff --git a/ir/src/tests/mod.rs b/ir/src/tests/mod.rs index 33aee830c..3f96be389 100644 --- a/ir/src/tests/mod.rs +++ b/ir/src/tests/mod.rs @@ -41,31 +41,6 @@ fn boundary_constraints_with_constants() { assert!(result.is_ok()); } -#[test] -fn boundary_constraints_with_variables() { - let source = " - constants: - A: 123 - B: [1, 2, 3] - C: [[1, 2, 3], [4, 5, 6]] - trace_columns: - main: [clk] - public_inputs: - stack_inputs: [16] - transition_constraints: - enf clk' = clk - 1 - boundary_constraints: - let a = 1 - let b = [a, a*a] - let c = [[b[0] - clk, clk - a], [1 + 8, 2^2]] - enf clk.first = A + a - b[0] - enf clk.last = B[0] + C[0][1] - c[0][1]"; - - let parsed = parse(source).expect("Parsing failed"); - let result = AirIR::from_source(&parsed); - assert!(result.is_ok()); -} - #[test] fn err_tc_invalid_vector_access() { let source = " @@ -270,11 +245,12 @@ fn transition_constraints_with_variables() { let c = [[clk' - clk, clk - a], [1 + 8, 2^2]] enf c[0][0] = 1 boundary_constraints: - enf clk.first = A + a - b[0] - enf clk.last = B[0] + C[0][1] - c[0][1]"; + enf clk.first = A + enf clk.last = B[0] + C[0][1]"; let parsed = parse(source).expect("Parsing failed"); let result = AirIR::from_source(&parsed); + println!("{:?}", result); assert!(result.is_ok()); } @@ -549,28 +525,6 @@ fn err_tc_variable_access_before_declaration() { assert!(result.is_err()); } -#[test] -fn err_variable_declaration_in_bc_and_tc() { - let source = " - constants: - A: [[2, 3], [1, 0]] - trace_columns: - main: [clk] - public_inputs: - stack_inputs: [16] - transition_constraints: - let a = 1 - enf clk' = clk + a - boundary_constraints: - let a = 0 - enf clk.first = a - enf clk.last = 1"; - - let parsed = parse(source).expect("Parsing failed"); - let result = AirIR::from_source(&parsed); - assert!(result.is_err()); -} - #[test] fn err_variable_vector_invalid_access() { let source = " diff --git a/ir/src/transition_stmts/graph.rs b/ir/src/transition_stmts/graph.rs index 36c854109..34ceff902 100644 --- a/ir/src/transition_stmts/graph.rs +++ b/ir/src/transition_stmts/graph.rs @@ -53,11 +53,6 @@ impl AlgebraicGraph { // recursively walk the subgraph and compute the degree from the operation and child nodes match self.node(index).op() { Operation::Constant(_) | Operation::RandomValue(_) => 0, - // TODO: Get the degree from variables graph instead of adding them again in constraint - // graph. - Operation::Variable(_, expr_node_index) => { - self.accumulate_degree(cycles, expr_node_index) - } Operation::TraceElement(_) => 1, Operation::PeriodicColumn(index, cycle_len) => { cycles.insert(*index, *cycle_len); @@ -95,6 +90,7 @@ impl AlgebraicGraph { &mut self, symbol_table: &mut SymbolTable, expr: TransitionExpr, + variable_roots: &mut BTreeMap, ) -> Result<(TraceSegment, NodeIndex), SemanticError> { match expr { TransitionExpr::Const(value) => { @@ -104,13 +100,13 @@ impl AlgebraicGraph { Ok((trace_segment, node_index)) } TransitionExpr::Elem(Identifier(ident)) => { - self.insert_symbol_access(symbol_table, &ident) + self.insert_symbol_access(symbol_table, &ident, variable_roots) } TransitionExpr::VectorAccess(vector_access) => { - self.insert_vector_access(symbol_table, &vector_access) + self.insert_vector_access(symbol_table, &vector_access, variable_roots) } TransitionExpr::MatrixAccess(matrix_access) => { - self.insert_matrix_access(symbol_table, &matrix_access) + self.insert_matrix_access(symbol_table, &matrix_access, variable_roots) } TransitionExpr::Next(Identifier(ident)) => self.insert_next(symbol_table, &ident), TransitionExpr::Rand(index) => { @@ -124,8 +120,8 @@ impl AlgebraicGraph { } TransitionExpr::Add(lhs, rhs) => { // add both subexpressions. - let (lhs_segment, lhs) = self.insert_expr(symbol_table, *lhs)?; - let (rhs_segment, rhs) = self.insert_expr(symbol_table, *rhs)?; + let (lhs_segment, lhs) = self.insert_expr(symbol_table, *lhs, variable_roots)?; + let (rhs_segment, rhs) = self.insert_expr(symbol_table, *rhs, variable_roots)?; // add the expression. let trace_segment = lhs_segment.max(rhs_segment); let node_index = self.insert_op(Operation::Add(lhs, rhs)); @@ -133,8 +129,8 @@ impl AlgebraicGraph { } TransitionExpr::Sub(lhs, rhs) => { // add both subexpressions. - let (lhs_segment, lhs) = self.insert_expr(symbol_table, *lhs)?; - let (rhs_segment, rhs) = self.insert_expr(symbol_table, *rhs)?; + let (lhs_segment, lhs) = self.insert_expr(symbol_table, *lhs, variable_roots)?; + let (rhs_segment, rhs) = self.insert_expr(symbol_table, *rhs, variable_roots)?; // add the expression. let trace_segment = lhs_segment.max(rhs_segment); let node_index = self.insert_op(Operation::Sub(lhs, rhs)); @@ -142,8 +138,8 @@ impl AlgebraicGraph { } TransitionExpr::Mul(lhs, rhs) => { // add both subexpressions. - let (lhs_segment, lhs) = self.insert_expr(symbol_table, *lhs)?; - let (rhs_segment, rhs) = self.insert_expr(symbol_table, *rhs)?; + let (lhs_segment, lhs) = self.insert_expr(symbol_table, *lhs, variable_roots)?; + let (rhs_segment, rhs) = self.insert_expr(symbol_table, *rhs, variable_roots)?; // add the expression. let trace_segment = lhs_segment.max(rhs_segment); let node_index = self.insert_op(Operation::Mul(lhs, rhs)); @@ -151,7 +147,7 @@ impl AlgebraicGraph { } TransitionExpr::Exp(lhs, rhs) => { // add base subexpression. - let (trace_segment, lhs) = self.insert_expr(symbol_table, *lhs)?; + let (trace_segment, lhs) = self.insert_expr(symbol_table, *lhs, variable_roots)?; // add exponent subexpression. let node_index = self.insert_op(Operation::Exp(lhs, rhs as usize)); Ok((trace_segment, node_index)) @@ -186,6 +182,7 @@ impl AlgebraicGraph { &mut self, symbol_table: &mut SymbolTable, ident: &str, + variable_roots: &mut BTreeMap, ) -> Result<(TraceSegment, NodeIndex), SemanticError> { let elem_type = symbol_table.get_type(ident)?; match elem_type { @@ -210,13 +207,19 @@ impl AlgebraicGraph { } IdentifierType::TransitionVariable(transition_variable) => { if let TransitionVariableType::Scalar(expr) = transition_variable.value() { - let (trace_segment, expr_node_index) = - self.insert_expr(symbol_table, expr.clone())?; - let node_index = self.insert_op(Operation::Variable( - VariableValue::Scalar(ident.to_string()), - expr_node_index, - )); - Ok((trace_segment, node_index)) + if let Some((trace_segment, node_index)) = + variable_roots.get(&VariableValue::Scalar(ident.to_string())) + { + Ok((*trace_segment, *node_index)) + } else { + let (trace_segment, node_index) = + self.insert_expr(symbol_table, expr.clone(), variable_roots)?; + variable_roots.insert( + VariableValue::Scalar(ident.to_string()), + (trace_segment, node_index), + ); + Ok((trace_segment, node_index)) + } } else { Err(SemanticError::InvalidUsage(format!( "Identifier {} was declared as a {} which is not a supported type.", @@ -237,6 +240,7 @@ impl AlgebraicGraph { &mut self, symbol_table: &mut SymbolTable, vector_access: &VectorAccess, + variable_roots: &mut BTreeMap, ) -> Result<(TraceSegment, NodeIndex), SemanticError> { let symbol_type = symbol_table.access_vector_element(vector_access)?; match symbol_type { @@ -250,13 +254,19 @@ impl AlgebraicGraph { IdentifierType::TransitionVariable(transition_variable) => { if let TransitionVariableType::Vector(vector) = transition_variable.value() { let expr = &vector[vector_access.idx()]; - let (trace_segment, expr_node_index) = - self.insert_expr(symbol_table, expr.clone())?; - let node_index = self.insert_op(Operation::Variable( - VariableValue::Vector(vector_access.clone()), - expr_node_index, - )); - Ok((trace_segment, node_index)) + if let Some((trace_segment, node_index)) = + variable_roots.get(&VariableValue::Vector(vector_access.clone())) + { + Ok((*trace_segment, *node_index)) + } else { + let (trace_segment, node_index) = + self.insert_expr(symbol_table, expr.clone(), variable_roots)?; + variable_roots.insert( + VariableValue::Vector(vector_access.clone()), + (trace_segment, node_index), + ); + Ok((trace_segment, node_index)) + } } else { Err(SemanticError::InvalidUsage(format!( "Identifier {} was declared as a {} which is not a supported type.", @@ -278,6 +288,7 @@ impl AlgebraicGraph { &mut self, symbol_table: &mut SymbolTable, matrix_access: &MatrixAccess, + variable_roots: &mut BTreeMap, ) -> Result<(TraceSegment, NodeIndex), SemanticError> { let symbol_type = symbol_table.access_matrix_element(matrix_access)?; match symbol_type { @@ -291,13 +302,19 @@ impl AlgebraicGraph { IdentifierType::TransitionVariable(transition_variable) => { if let TransitionVariableType::Matrix(matrix) = transition_variable.value() { let expr = &matrix[matrix_access.row_idx()][matrix_access.col_idx()]; - let (trace_segment, expr_node_index) = - self.insert_expr(symbol_table, expr.clone())?; - let node_index = self.insert_op(Operation::Variable( - VariableValue::Matrix(matrix_access.clone()), - expr_node_index, - )); - Ok((trace_segment, node_index)) + if let Some((trace_segment, node_index)) = + variable_roots.get(&VariableValue::Matrix(matrix_access.clone())) + { + Ok((*trace_segment, *node_index)) + } else { + let (trace_segment, node_index) = + self.insert_expr(symbol_table, expr.clone(), variable_roots)?; + variable_roots.insert( + VariableValue::Matrix(matrix_access.clone()), + (trace_segment, node_index), + ); + Ok((trace_segment, node_index)) + } } else { Err(SemanticError::invalid_matrix_access( matrix_access, @@ -351,8 +368,6 @@ impl Node { pub enum Operation { /// An inlined or named constant with identifier and access indices. Constant(ConstantValue), - /// A variable with [VariableValue] and root index of the variable's expression value. - Variable(VariableValue, NodeIndex), /// An identifier for an element in the trace segment, column, and row offset specified by the /// [TraceAccess] TraceElement(TraceAccess), diff --git a/ir/src/transition_stmts/mod.rs b/ir/src/transition_stmts/mod.rs index 7e8c4955d..f4f8e7668 100644 --- a/ir/src/transition_stmts/mod.rs +++ b/ir/src/transition_stmts/mod.rs @@ -1,10 +1,6 @@ -use std::collections::BTreeMap; - use super::{SemanticError, SymbolTable, TraceSegment}; -use parser::ast::{ - Identifier, MatrixAccess, TransitionExpr, TransitionStmt, TransitionVariable, - TransitionVariableType, VectorAccess, -}; +use parser::ast::TransitionStmt; +use std::collections::BTreeMap; mod degree; pub use degree::TransitionConstraintDegree; @@ -32,16 +28,10 @@ pub(super) struct TransitionStmts { /// A directed acyclic graph which represents all of the transition constraints. constraints_graph: AlgebraicGraph, - /// A vector containing all the variables defined in the transition constraints section. - variables: Vec, - - /// Variable roots for all the variables in the variables graph. For each element in a vector - /// or a matrix, a new root is added with a key equal to the [VariableValue] of the element. - variable_roots: BTreeMap, - - /// A directed acyclic graph which represents all of the variables defined in the transition - /// constraints section. - variables_graph: AlgebraicGraph, + /// Variable roots for the variables used in transition constraints. For each element in a + /// vector or a matrix, a new root is added with a key equal to the [VariableValue] of the + /// element. + variable_roots: BTreeMap, } impl TransitionStmts { @@ -51,9 +41,7 @@ impl TransitionStmts { Self { constraint_roots: vec![Vec::new(); num_trace_segments], constraints_graph: AlgebraicGraph::default(), - variables: Vec::new(), variable_roots: BTreeMap::new(), - variables_graph: AlgebraicGraph::default(), } } @@ -91,24 +79,6 @@ impl TransitionStmts { &self.constraints_graph } - /// Returns all the variables defined in the transition constraints section. - pub fn variables(&self) -> &Vec { - &self.variables - } - - /// Returns variable roots map for the variables defined in the transition constraints section. - /// The value of the map contains the tip of the subgraph representing the variable within the - /// variables [AlgebraicGraph]. - pub fn variable_roots(&self) -> &BTreeMap { - &self.variable_roots - } - - /// Returns the [AlgebraicGraph] representing all variables defined in the transition - /// constraints section. - pub fn variables_graph(&self) -> &AlgebraicGraph { - &self.variables_graph - } - // --- MUTATORS ------------------------------------------------------------------------------- /// Adds the provided parsed transition statement to the graph. The statement can either be a @@ -127,44 +97,15 @@ impl TransitionStmts { stmt: &TransitionStmt, ) -> Result<(), SemanticError> { match stmt { - TransitionStmt::Variable(variable) => { - symbol_table.insert_transition_variable(variable)?; - match variable.value() { - TransitionVariableType::Scalar(expr) => { - let variable_value = VariableValue::Scalar(variable.name().to_string()); - self.insert_variable_expr(symbol_table, variable_value, expr)?; - } - TransitionVariableType::Vector(vector) => { - for (idx, expr) in vector.iter().enumerate() { - let variable_value = VariableValue::Vector(VectorAccess::new( - Identifier(variable.name().to_string()), - idx, - )); - self.insert_variable_expr(symbol_table, variable_value, expr)?; - } - } - TransitionVariableType::Matrix(matrix) => { - for (row_idx, row) in matrix.iter().enumerate() { - for (col_idx, expr) in row.iter().enumerate() { - let variable_value = VariableValue::Matrix(MatrixAccess::new( - Identifier(variable.name().to_string()), - row_idx, - col_idx, - )); - self.insert_variable_expr(symbol_table, variable_value, expr)?; - } - } - } - } - - self.variables.push(variable.clone()) - } TransitionStmt::Constraint(constraint) => { let expr = constraint.expr(); // add it to the transition constraints graph and get its entry index. - let (trace_segment, root_index) = - self.constraints_graph.insert_expr(symbol_table, expr)?; + let (trace_segment, root_index) = self.constraints_graph.insert_expr( + symbol_table, + expr, + &mut self.variable_roots, + )?; // the constraint should not be against an undeclared trace segment. if symbol_table.num_trace_segments() <= trace_segment.into() { @@ -176,25 +117,11 @@ impl TransitionStmts { // add the transition constraint to the appropriate set of constraints. self.constraint_roots[trace_segment as usize].push(root_index); } + TransitionStmt::Variable(variable) => { + symbol_table.insert_transition_variable(variable)? + } } Ok(()) } - - /// A helper function to insert variable expression in the variables graph as a subgraph and - /// add its root to the variable_roots map. - fn insert_variable_expr( - &mut self, - symbol_table: &mut SymbolTable, - variable_value: VariableValue, - expr: &TransitionExpr, - ) -> Result<(), SemanticError> { - // add it to the transition constraints graph and get its entry index. - let (_, root_index) = self - .variables_graph - .insert_expr(symbol_table, expr.clone())?; - - self.variable_roots.insert(variable_value, root_index); - Ok(()) - } }