Skip to content

Commit

Permalink
refactor(ir): Refactor IR
Browse files Browse the repository at this point in the history
- Remove variables for boundary constraints
- Remove variables graph
- Remove variables vector
  • Loading branch information
tohrnii committed Dec 20, 2022
1 parent b6fddf7 commit d2591d6
Show file tree
Hide file tree
Showing 10 changed files with 170 additions and 532 deletions.
18 changes: 4 additions & 14 deletions air-script/tests/variables/variables.air
Original file line number Diff line number Diff line change
Expand Up @@ -14,28 +14,18 @@ periodic_columns:
k0: [1, 1, 1, 1, 1, 1, 1, 0]

boundary_constraints:
# define boundary constraints against the main trace at the first row of the trace.
let x = 1
enf a.first = stack_inputs[0]
let y = [x, 4 - 2]
enf b.first = x
enf c.first = y[0]
let z = [[x, 3], [4 - 2, 8 + 8]]
# define boundary constraints against the main trace at the last row of the trace.
enf a.last = stack_outputs[0]
enf b.last = z[0][0]
enf c.last = stack_outputs[2]

# set the first row of the auxiliary column p to 1
enf p.first = 1

enf a.last = 1

transition_constraints:
let m = 0

# the selector must be binary.
enf s^2 = s

let n = [2 * 3, s]
let o = [[s', 3], [4 - 2, 8 + 8]]

# selector should stay the same for all rows of an 8-row cycle.
enf k0 * (s' - s) = m

Expand Down
29 changes: 6 additions & 23 deletions air-script/tests/variables/variables.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ impl Air for VariablesAir {
fn new(trace_info: TraceInfo, public_inputs: PublicInputs, options: WinterProofOptions) -> Self {
let main_degrees = vec![TransitionConstraintDegree::new(2), TransitionConstraintDegree::with_cycles(1, vec![8]), TransitionConstraintDegree::new(2), TransitionConstraintDegree::new(3)];
let aux_degrees = vec![TransitionConstraintDegree::new(2)];
let num_main_assertions = 6;
let num_aux_assertions = 1;
let num_main_assertions = 2;
let num_aux_assertions = 0;

let context = AirContext::new_multi_segment(
trace_info,
Expand All @@ -65,48 +65,31 @@ impl Air for VariablesAir {
}

fn get_assertions(&self) -> Vec<Assertion<Felt>> {
let x = Felt::new(1);
let y = [x, (Felt::new(4)) - (Felt::new(2))];
let z = [[x, Felt::new(3)], [(Felt::new(4)) - (Felt::new(2)), (Felt::new(8)) + (Felt::new(8))]];
let mut result = Vec::new();
result.push(Assertion::single(1, 0, self.stack_inputs[0]));
result.push(Assertion::single(2, 0, x));
result.push(Assertion::single(3, 0, y[0]));
let last_step = self.last_step();
result.push(Assertion::single(1, last_step, self.stack_outputs[0]));
result.push(Assertion::single(2, last_step, z[0][0]));
result.push(Assertion::single(3, last_step, self.stack_outputs[2]));
result.push(Assertion::single(1, last_step, Felt::new(1)));
result
}

fn get_aux_assertions<E: FieldElement<BaseField = Felt>>(&self, aux_rand_elements: &AuxTraceRandElements<E>) -> Vec<Assertion<E>> {
let x = E::from(1_u64);
let y = [E::from(x), (E::from(4_u64)) - (E::from(2_u64))];
let z = [[E::from(x), E::from(3_u64)], [(E::from(4_u64)) - (E::from(2_u64)), (E::from(8_u64)) + (E::from(8_u64))]];
let mut result = Vec::new();
result.push(Assertion::single(0, 0, E::from(1_u64)));
result
}

fn evaluate_transition<E: FieldElement<BaseField = Felt>>(&self, frame: &EvaluationFrame<E>, periodic_values: &[E], result: &mut [E]) {
let m = E::from(0_u64);
let n = [(E::from(2_u64)) * (E::from(3_u64)), current[0]];
let o = [[next[0], E::from(3_u64)], [E::from(4_u64) - (E::from(2_u64)), E::from(8_u64) + E::from(8_u64)]];
let current = frame.current();
let next = frame.next();
result[0] = (current[0]).exp(E::PositiveInteger::from(2_u64)) - (current[0]);
result[1] = (periodic_values[0]) * (next[0] - (current[0])) - (m);
result[2] = (E::from(1_u64) - (current[0])) * (current[3] - (current[1]) + current[2]) - (n[0] - (n[1]));
result[3] = (current[0]) * (current[3] - ((current[1]) * (current[2]))) - (o[0][0] - (o[0][1]) - (o[1][0]));
result[1] = (periodic_values[0]) * (next[0] - (current[0])) - (E::from(0_u64));
result[2] = (E::from(1_u64) - (current[0])) * (current[3] - (current[1]) + current[2]) - ((E::from(2_u64)) * (E::from(3_u64)) - (current[0]));
result[3] = (current[0]) * (current[3] - ((current[1]) * (current[2]))) - (next[0] - (E::from(3_u64)) - (E::from(4_u64) - (E::from(2_u64))));
}

fn evaluate_aux_transition<F, E>(&self, main_frame: &EvaluationFrame<F>, aux_frame: &EvaluationFrame<E>, _periodic_values: &[F], aux_rand_elements: &AuxTraceRandElements<E>, result: &mut [E])
where F: FieldElement<BaseField = Felt>,
E: FieldElement<BaseField = Felt> + ExtensionOf<F>,
{
let m = E::from(0_u64);
let n = [(E::from(2_u64)) * (E::from(3_u64)), current[0]];
let o = [[next[0], E::from(3_u64)], [E::from(4_u64) - (E::from(2_u64)), E::from(8_u64) + E::from(8_u64)]];
let current = aux_frame.current();
let next = aux_frame.next();
result[0] = next[0] - ((current[0]) * (current[3] + aux_rand_elements.get_segment_elements(0)[0]));
Expand Down
60 changes: 1 addition & 59 deletions codegen/winterfell/src/air/boundary_constraints.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::{AirIR, Codegen, Impl};
use ir::{ast::BoundaryVariableType, BoundaryExpr};
use ir::BoundaryExpr;

// HELPERS TO GENERATE THE WINTERFELL BOUNDARY CONSTRAINT METHODS
// ================================================================================================
Expand All @@ -13,12 +13,6 @@ pub(super) fn add_fn_get_assertions(impl_ref: &mut Impl, ir: &AirIR) {
.arg_ref_self()
.ret("Vec<Assertion<Felt>>");

// TODO: Only add variables used in main trace assertions
let variables = add_variables(ir, false);
for variable in variables {
get_assertions.line(variable);
}

// declare the result vector to be returned.
get_assertions.line("let mut result = Vec::new();");

Expand Down Expand Up @@ -62,12 +56,6 @@ pub(super) fn add_fn_get_aux_assertions(impl_ref: &mut Impl, ir: &AirIR) {
.arg("aux_rand_elements", "&AuxTraceRandElements<E>")
.ret("Vec<Assertion<E>>");

// TODO: Only add variables used in aux trace assertions
let variables = add_variables(ir, true);
for variable in variables {
get_aux_assertions.line(variable);
}

// declare the result vector to be returned.
get_aux_assertions.line("let mut result = Vec::new();");

Expand Down Expand Up @@ -100,52 +88,6 @@ pub(super) fn add_fn_get_aux_assertions(impl_ref: &mut Impl, ir: &AirIR) {
get_aux_assertions.line("result");
}

/// A helper function to add variable definitions to the enforce_assertions and
/// enforce_aux_assertions functions.
fn add_variables(ir: &AirIR, is_aux_constraint: bool) -> Vec<String> {
let mut vars = Vec::new();
let variables = ir.boundary_variables();
for variable in variables {
let variable_name = variable.name();
let variable_def = match variable.value() {
BoundaryVariableType::Scalar(expr) => {
format!(
"let {} = {};",
variable_name,
expr.to_string(ir, is_aux_constraint)
)
}
BoundaryVariableType::Vector(vector) => format!(
"let {} = [{}];",
variable_name,
vector
.iter()
.map(|expr| expr.to_string(ir, is_aux_constraint))
.collect::<Vec<String>>()
.join(", ")
),
BoundaryVariableType::Matrix(matrix) => {
let variable_value = {
let mut rows = vec![];
for row in matrix {
rows.push(format!(
"[{}]",
row.iter()
.map(|expr| expr.to_string(ir, is_aux_constraint))
.collect::<Vec<String>>()
.join(", "),
))
}
format!("[{}]", rows.join(", "))
};
format!("let {} = {};", variable_name, variable_value)
}
};
vars.push(variable_def);
}
vars
}

// RUST STRING GENERATION
// ================================================================================================

Expand Down
86 changes: 2 additions & 84 deletions codegen/winterfell/src/air/transition_constraints.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
use super::{AirIR, Impl};
use ir::{
ast::{MatrixAccess, TransitionVariableType, VectorAccess},
transition_stmts::{AlgebraicGraph, ConstantValue, Operation, VariableValue},
Identifier, NodeIndex,
transition_stmts::{AlgebraicGraph, ConstantValue, Operation},
NodeIndex,
};

// HELPERS TO GENERATE THE WINTERFELL TRANSITION CONSTRAINT METHODS
Expand All @@ -20,12 +19,6 @@ pub(super) fn add_fn_evaluate_transition(impl_ref: &mut Impl, ir: &AirIR) {
.arg("periodic_values", "&[E]")
.arg("result", "&mut [E]");

// TODO: Only add variables used in main trace assertions
let variables = add_variables(ir);
for variable in variables {
evaluate_transition.line(variable);
}

// declare current and next trace row arrays.
evaluate_transition.line("let current = frame.current();");
evaluate_transition.line("let next = frame.next();");
Expand Down Expand Up @@ -57,12 +50,6 @@ pub(super) fn add_fn_evaluate_aux_transition(impl_ref: &mut Impl, ir: &AirIR) {
.bound("F", "FieldElement<BaseField = Felt>")
.bound("E", "FieldElement<BaseField = Felt> + ExtensionOf<F>");

// TODO: Only add variables used in aux trace assertions.
let variables = add_variables(ir);
for variable in variables {
evaluate_aux_transition.line(variable);
}

// declare current and next trace row arrays.
evaluate_aux_transition.line("let current = aux_frame.current();");
evaluate_aux_transition.line("let next = aux_frame.next();");
Expand All @@ -78,65 +65,6 @@ pub(super) fn add_fn_evaluate_aux_transition(impl_ref: &mut Impl, ir: &AirIR) {
}
}

/// A helper function to add variable definitions to the evaluate_transition and
/// evaluate_aux_transition functions.
fn add_variables(ir: &AirIR) -> Vec<String> {
let mut vars = Vec::new();
let variables = ir.transition_variables();
let variables_graph = ir.variables_graph();
for variable in variables {
let variable_name = variable.name();
let variable_def = match variable.value() {
TransitionVariableType::Scalar(_) => {
let key = VariableValue::Scalar(variable_name.to_string());
let variable_value = ir.variable_roots().get(&key).unwrap_or_else(|| {
panic!("Variable {} not found in variable_roots map", variable_name)
});
format!(
"let {} = {};",
variable_name,
variable_value.to_string(variables_graph)
)
}
TransitionVariableType::Vector(vector) => {
let mut vector_str = Vec::new();
for idx in 0..vector.len() {
let key = VariableValue::Vector(VectorAccess::new(
Identifier(variable_name.to_string()),
idx,
));
let variable_value = ir.variable_roots().get(&key).unwrap_or_else(|| {
panic!("Variable {} not found in variable_roots map", variable_name)
});
vector_str.push(variable_value.to_string(variables_graph));
}
format!("let {} = [{}];", variable_name, vector_str.join(", "))
}
TransitionVariableType::Matrix(matrix) => {
let mut rows = Vec::new();
for row_idx in 0..matrix.len() {
let mut cols = Vec::new();
for col_idx in 0..matrix[0].len() {
let key = VariableValue::Matrix(MatrixAccess::new(
Identifier(variable_name.to_string()),
row_idx,
col_idx,
));
let variable_value = ir.variable_roots().get(&key).unwrap_or_else(|| {
panic!("Variable {} not found in variable_roots map", variable_name)
});
cols.push(variable_value.to_string(variables_graph));
}
rows.push(format!("[{}]", cols.join(", ")));
}
format!("let {} = [{}];", variable_name, rows.join(", "))
}
};
vars.push(variable_def);
}
vars
}

/// Code generation trait for generating Rust code strings from [AlgebraicGraph] types.
trait Codegen {
fn to_string(&self, graph: &AlgebraicGraph) -> String;
Expand Down Expand Up @@ -164,16 +92,6 @@ impl Codegen for Operation {
matrix_access.row_idx(),
matrix_access.col_idx()
),
Operation::Variable(VariableValue::Scalar(ident), _) => ident.to_string(),
Operation::Variable(VariableValue::Vector(vector_access), _) => {
format!("{}[{}]", vector_access.name(), vector_access.idx())
}
Operation::Variable(VariableValue::Matrix(matrix_access), _) => format!(
"{}[{}][{}]",
matrix_access.name(),
matrix_access.row_idx(),
matrix_access.col_idx()
),
Operation::TraceElement(trace_access) => match trace_access.row_offset() {
0 => {
format!("current[{}]", trace_access.col_idx())
Expand Down
Loading

0 comments on commit d2591d6

Please sign in to comment.