Skip to content

Commit

Permalink
test: add integration test for variables
Browse files Browse the repository at this point in the history
  • Loading branch information
tohrnii committed Dec 22, 2022
1 parent b7e0acf commit 678c234
Show file tree
Hide file tree
Showing 7 changed files with 195 additions and 8 deletions.
10 changes: 10 additions & 0 deletions air-script/tests/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,13 @@ fn constants() {
let expected = expect_file!["constants/constants.rs"];
expected.assert_eq(&generated_air);
}

#[test]
fn variables() {
let generated_air = Test::new("tests/variables/variables.air".to_string())
.transpile()
.unwrap();

let expected = expect_file!["variables/variables.rs"];
expected.assert_eq(&generated_air);
}
49 changes: 49 additions & 0 deletions air-script/tests/variables/variables.air
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Variables AIR in AirScript

def VariablesAir

trace_columns:
main: [s, a, b, c]
aux: [p]

public_inputs:
stack_inputs: [16]
stack_outputs: [16]

periodic_columns:
k0: [1, 1, 1, 1, 1, 1, 1, 0]

boundary_constraints:
# define boundary constraints against the main trace at the first row of the trace.
let x = 1
enf a.first = stack_inputs[0]
let y = [x, 4 - 2]
enf b.first = x
enf c.first = y[0]
let z = [[x, 3], [4 - 2, 8 + 8]]
# define boundary constraints against the main trace at the last row of the trace.
enf a.last = stack_outputs[0]
enf b.last = z[0][0]
enf c.last = stack_outputs[2]

# set the first row of the auxiliary column p to 1
enf p.first = 1

transition_constraints:
let m = 0
# the selector must be binary.
enf s^2 = s

let n = [2 * 3, s]
let o = [[s', 3], [4 - 2, 8 + 8]]
# selector should stay the same for all rows of an 8-row cycle.
enf k0 * (s' - s) = m

# c = a + b when s = 0.
enf (1 - s) * (c - a + b) = n[0] - n[1]

# c = a * b when s = 1.
enf s * (c - a * b) = o[0][0] - o[0][1] - o[1][0]

# the auxiliary column contains the product of values of c offset by a random value.
enf p' = p * (c + $rand[0])
114 changes: 114 additions & 0 deletions air-script/tests/variables/variables.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
use winter_air::{Air, AirContext, Assertion, AuxTraceRandElements, EvaluationFrame, ProofOptions as WinterProofOptions, TransitionConstraintDegree, TraceInfo};
use winter_math::fields::f64::BaseElement as Felt;
use winter_math::{ExtensionOf, FieldElement};
use winter_utils::collections::Vec;
use winter_utils::{ByteWriter, Serializable};

pub struct PublicInputs {
stack_inputs: [Felt; 16],
stack_outputs: [Felt; 16],
}

impl PublicInputs {
pub fn new(stack_inputs: [Felt; 16], stack_outputs: [Felt; 16]) -> Self {
Self { stack_inputs, stack_outputs }
}
}

impl Serializable for PublicInputs {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
target.write(self.stack_inputs.as_slice());
target.write(self.stack_outputs.as_slice());
}
}

pub struct VariablesAir {
context: AirContext<Felt>,
stack_inputs: [Felt; 16],
stack_outputs: [Felt; 16],
}

impl VariablesAir {
pub fn last_step(&self) -> usize {
self.trace_length() - self.context().num_transition_exemptions()
}
}

impl Air for VariablesAir {
type BaseField = Felt;
type PublicInputs = PublicInputs;

fn context(&self) -> &AirContext<Felt> {
&self.context
}

fn new(trace_info: TraceInfo, public_inputs: PublicInputs, options: WinterProofOptions) -> Self {
let main_degrees = vec![TransitionConstraintDegree::new(2), TransitionConstraintDegree::with_cycles(1, vec![8]), TransitionConstraintDegree::new(2), TransitionConstraintDegree::new(3)];
let aux_degrees = vec![TransitionConstraintDegree::new(2)];
let num_main_assertions = 6;
let num_aux_assertions = 1;

let context = AirContext::new_multi_segment(
trace_info,
main_degrees,
aux_degrees,
num_main_assertions,
num_aux_assertions,
options,
)
.set_num_transition_exemptions(2);
Self { context, stack_inputs: public_inputs.stack_inputs, stack_outputs: public_inputs.stack_outputs }
}

fn get_periodic_column_values(&self) -> Vec<Vec<Felt>> {
vec![vec![Felt::new(1), Felt::new(1), Felt::new(1), Felt::new(1), Felt::new(1), Felt::new(1), Felt::new(1), Felt::new(0)]]
}

fn get_assertions(&self) -> Vec<Assertion<Felt>> {
let x = Felt::new(1);
let y = [x, (Felt::new(4)) - (Felt::new(2))];
let z = [[x, Felt::new(3)], [(Felt::new(4)) - (Felt::new(2)), (Felt::new(8)) + (Felt::new(8))]];
let mut result = Vec::new();
result.push(Assertion::single(1, 0, self.stack_inputs[0]));
result.push(Assertion::single(2, 0, x));
result.push(Assertion::single(3, 0, y[0]));
let last_step = self.last_step();
result.push(Assertion::single(1, last_step, self.stack_outputs[0]));
result.push(Assertion::single(2, last_step, z[0][0]));
result.push(Assertion::single(3, last_step, self.stack_outputs[2]));
result
}

fn get_aux_assertions<E: FieldElement<BaseField = Felt>>(&self, aux_rand_elements: &AuxTraceRandElements<E>) -> Vec<Assertion<E>> {
let x = E::from(1_u64);
let y = [E::from(x), (E::from(4_u64)) - (E::from(2_u64))];
let z = [[E::from(x), E::from(3_u64)], [(E::from(4_u64)) - (E::from(2_u64)), (E::from(8_u64)) + (E::from(8_u64))]];
let mut result = Vec::new();
result.push(Assertion::single(0, 0, E::from(1_u64)));
result
}

fn evaluate_transition<E: FieldElement<BaseField = Felt>>(&self, frame: &EvaluationFrame<E>, periodic_values: &[E], result: &mut [E]) {
let m = E::from(0_u64);
let n = [(E::from(2_u64)) * (E::from(3_u64)), current[0]];
let o = [[next[0], E::from(3_u64)], [E::from(4_u64) - (E::from(2_u64)), E::from(8_u64) + E::from(8_u64)]];
let current = frame.current();
let next = frame.next();
result[0] = (current[0]).exp(E::PositiveInteger::from(2_u64)) - (current[0]);
result[1] = (periodic_values[0]) * (next[0] - (current[0])) - (m);
result[2] = (E::from(1_u64) - (current[0])) * (current[3] - (current[1]) + current[2]) - (n[0] - (n[1]));
result[3] = (current[0]) * (current[3] - ((current[1]) * (current[2]))) - (o[0][0] - (o[0][1]) - (o[1][0]));
}

fn evaluate_aux_transition<F, E>(&self, main_frame: &EvaluationFrame<F>, aux_frame: &EvaluationFrame<E>, _periodic_values: &[F], aux_rand_elements: &AuxTraceRandElements<E>, result: &mut [E])
where F: FieldElement<BaseField = Felt>,
E: FieldElement<BaseField = Felt> + ExtensionOf<F>,
{
let m = E::from(0_u64);
let n = [(E::from(2_u64)) * (E::from(3_u64)), current[0]];
let o = [[next[0], E::from(3_u64)], [E::from(4_u64) - (E::from(2_u64)), E::from(8_u64) + E::from(8_u64)]];
let current = aux_frame.current();
let next = aux_frame.next();
result[0] = next[0] - ((current[0]) * (current[3] + aux_rand_elements.get_segment_elements(0)[0]));
}
}
1 change: 0 additions & 1 deletion ir/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,6 @@ impl AirIR {
}

pub fn num_aux_assertions(&self) -> usize {
println!("{}", self.num_trace_segments);
if self.num_trace_segments == 2 {
self.boundary_stmts.num_boundary_constraints(1)
} else {
Expand Down
24 changes: 20 additions & 4 deletions ir/src/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,6 @@ fn err_variable_vector_invalid_access() {

let parsed = parse(source).expect("Parsing failed");
let result = AirIR::from_source(&parsed);
println!("{:?}", result);
assert!(result.is_err());
}

Expand All @@ -594,14 +593,31 @@ fn err_variable_matrix_invalid_access() {
public_inputs:
stack_inputs: [16]
transition_constraints:
let a = [1, 2]
enf clk' = clk + a[1]
let a = [[1, 2, 3], [4, 5, 6]]
enf clk' = clk + a[1][3]
boundary_constraints:
enf clk.first = 0
enf clk.last = 1";

let parsed = parse(source).expect("Parsing failed");
let result = AirIR::from_source(&parsed);
assert!(result.is_err());

let source = "
constants:
A: [[2, 3], [1, 0]]
trace_columns:
main: [clk]
public_inputs:
stack_inputs: [16]
transition_constraints:
let a = [[1, 2, 3], [4, 5, 6]]
enf clk' = clk + a[2][0]
boundary_constraints:
enf clk.first = 0
enf clk.last = 1";

let parsed = parse(source).expect("Parsing failed");
let result = AirIR::from_source(&parsed);
println!("{:?}", result);
assert!(result.is_err());
}
1 change: 0 additions & 1 deletion ir/src/transition_stmts/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,4 +198,3 @@ impl TransitionStmts {
Ok(())
}
}

4 changes: 2 additions & 2 deletions parser/src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ pub use transition_constraints::*;
// ================================================================================================

/// [Source] is the root node of the AST representing the AIR constraints file.
#[derive(Debug, PartialEq)]
#[derive(Debug, Eq, PartialEq)]
pub struct Source(pub Vec<SourceSection>);

/// Source is divided into SourceSections.
Expand All @@ -36,7 +36,7 @@ pub struct Source(pub Vec<SourceSection>);
/// representing the first and last rows of the column.
/// - TransitionConstraints: Transition Constraints to be enforced on the trace columns defined
/// in the TraceCols section.
#[derive(Debug, PartialEq)]
#[derive(Debug, Eq, PartialEq)]
pub enum SourceSection {
AirDef(Identifier),
Constant(Constant),
Expand Down

0 comments on commit 678c234

Please sign in to comment.