Skip to content

Commit

Permalink
test: add integration tests for constants codegen
Browse files Browse the repository at this point in the history
  • Loading branch information
tohrnii committed Dec 6, 2022
1 parent 6c680fb commit aaf85a3
Show file tree
Hide file tree
Showing 7 changed files with 158 additions and 9 deletions.
26 changes: 26 additions & 0 deletions air-script/tests/constants/constants.air
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
def PubInputsAir

constants:
A: 1
B: [0, 1]
C: [[1, 2], [2, 0]]

trace_columns:
main: [a, b, c, d]

public_inputs:
program_hash: [4]
stack_inputs: [4]
stack_outputs: [20]
overflow_addrs: [4]

boundary_constraints:
enf a.first = A
enf b.first = A + B[0] * C[0][1]
enf c.first = (B[0] - C[1][1]) * A
enf d.first = A + B[0] - B[1] + C[0][0] - C[0][1] + C[1][0] - C[1][1]

transition_constraints:
enf a' = a + A
enf b' = B[0] * b
enf c' = (C[0][0] + B[0]) * c
104 changes: 104 additions & 0 deletions air-script/tests/constants/constants.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
use winter_air::{Air, AirContext, Assertion, AuxTraceRandElements, EvaluationFrame, ProofOptions as WinterProofOptions, TransitionConstraintDegree, TraceInfo};
use winter_math::{fields, ExtensionOf, FieldElement};
use winter_utils::{collections, ByteWriter, Serializable};

const A: u64 = 1;
const B: Vec<u64> = vec![0, 1];
const C: Vec<Vec<u64>> = vec![vec![1, 2], vec![2, 0]];

pub struct PublicInputs {
program_hash: [Felt; 4],
stack_inputs: [Felt; 4],
stack_outputs: [Felt; 20],
overflow_addrs: [Felt; 4],
}

impl PublicInputs {
pub fn new(program_hash: [Felt; 4], stack_inputs: [Felt; 4], stack_outputs: [Felt; 20], overflow_addrs: [Felt; 4]) -> Self {
Self { program_hash, stack_inputs, stack_outputs, overflow_addrs }
}
}

impl Serializable for PublicInputs {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
target.write(self.program_hash.as_slice());
target.write(self.stack_inputs.as_slice());
target.write(self.stack_outputs.as_slice());
target.write(self.overflow_addrs.as_slice());
}
}

pub struct PubInputsAir {
context: AirContext<Felt>,
program_hash: [Felt; 4],
stack_inputs: [Felt; 4],
stack_outputs: [Felt; 20],
overflow_addrs: [Felt; 4],
}

impl PubInputsAir {
pub fn last_step(&self) -> usize {
self.trace_length() - self.context().num_transition_exemptions()
}
}

impl Air for PubInputsAir {
type BaseField = Felt;
type PublicInputs = PublicInputs;

fn context(&self) -> &AirContext<Felt> {
&self.context
}

fn new(trace_info: TraceInfo, public_inputs: PublicInputs, options: WinterProofOptions) -> Self {
let main_degrees = vec![TransitionConstraintDegree::new(1), TransitionConstraintDegree::new(1), TransitionConstraintDegree::new(1)];
let aux_degrees = vec![];
let num_main_assertions = 4;
let num_aux_assertions = 0;

let context = AirContext::new_multi_segment(
trace_info,
main_degrees,
aux_degrees,
num_main_assertions,
num_aux_assertions,
options,
)
.set_num_transition_exemptions(2);
Self { context, program_hash: public_inputs.program_hash, stack_inputs: public_inputs.stack_inputs, stack_outputs: public_inputs.stack_outputs, overflow_addrs: public_inputs.overflow_addrs }
}

fn get_periodic_column_values(&self) -> Vec<Vec<Felt>> {
vec![]
}

fn get_assertions(&self) -> Vec<Assertion<Felt>> {
let mut result = Vec::new();
result.push(Assertion::single(0, 0, E::from(A)));
result.push(Assertion::single(1, 0, (E::from(A)) + ((E::from(B[0])) * (E::from(C[0][1])))));
result.push(Assertion::single(2, 0, ((E::from(B[0])) - (E::from(C[1][1]))) * (E::from(A))));
result.push(Assertion::single(3, 0, ((((((E::from(A)) + (E::from(B[0]))) - (E::from(B[1]))) + (E::from(C[0][0]))) - (E::from(C[0][1]))) + (E::from(C[1][0]))) - (E::from(C[1][1]))));
result
}

fn get_aux_assertions<E: FieldElement<BaseField = Felt>>(&self, aux_rand_elements: &AuxTraceRandElements<E>) -> Vec<Assertion<E>> {
let mut result = Vec::new();
result
}

fn evaluate_transition<E: FieldElement<BaseField = Felt>>(&self, frame: &EvaluationFrame<E>, periodic_values: &[E], result: &mut [E]) {
let current = frame.current();
let next = frame.next();
result[0] = next[0] - (current[0] + E::from(A));
result[1] = next[1] - ((E::from(B[0])) * (current[1]));
result[2] = next[2] - ((E::from(C[0][0]) + E::from(B[0])) * (current[2]));
}

fn evaluate_aux_transition<F, E>(&self, main_frame: &EvaluationFrame<F>, aux_frame: &EvaluationFrame<E>, _periodic_values: &[F], aux_rand_elements: &AuxTraceRandElements<E>, result: &mut [E])
where F: FieldElement<BaseField = Felt>,
E: FieldElement<BaseField = Felt> + ExtensionOf<F>,
{
let current = aux_frame.current();
let next = aux_frame.next();
}
}
10 changes: 10 additions & 0 deletions air-script/tests/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,13 @@ fn bitwise() {
let expected = expect_file!["bitwise/bitwise.rs"];
expected.assert_eq(&generated_air);
}

#[test]
fn constants() {
let generated_air = Test::new("tests/constants/constants.air".to_string())
.transpile()
.unwrap();

let expected = expect_file!["constants/constants.rs"];
expected.assert_eq(&generated_air);
}
9 changes: 5 additions & 4 deletions codegen/winterfell/src/air/boundary_constraints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,15 +106,16 @@ impl Codegen for BoundaryExpr {
// TODO: Check element type and cast accordingly.
Self::Elem(ident) => format!("E::from({})", ident),
Self::VectorAccess(vector_access) => {
// check if vector_access is not a public input
// check if vector_access is a public input
// TODO: figure out a better way to handle this lookup.
if ir
.public_inputs()
.iter()
.all(|input| input.0 != vector_access.name())
.any(|input| input.0 == vector_access.name())
{
format!("E::from({}[{}])", vector_access.name(), vector_access.idx())
format!("self.{}[{}]", vector_access.name(), vector_access.idx())
} else {
format!("self.{}[{}]", vector_access.name(), vector_access.idx())
format!("E::from({}[{}])", vector_access.name(), vector_access.idx())
}
}
Self::MatrixAccess(matrix_access) => format!(
Expand Down
2 changes: 1 addition & 1 deletion codegen/winterfell/src/air/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use transition_constraints::{add_fn_evaluate_aux_transition, add_fn_evaluate_tra
/// Updates the provided scope with a new Air struct and Winterfell Air trait implementation
/// which are equivalent the provided AirIR.
pub(super) fn add_air(scope: &mut Scope, ir: &AirIR) {
// add constant declarations
// add constant declarations. Check required to avoid adding extra line during codegen.
if !ir.constants().is_empty() {
add_constants(scope, ir);
}
Expand Down
11 changes: 9 additions & 2 deletions codegen/winterfell/src/air/transition_constraints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,15 @@ impl Codegen for Operation {
match self {
Operation::Constant(ConstantValue::Inline(value)) => format!("E::from({}_u64)", value),
Operation::Constant(ConstantValue::Scalar(ident)) => format!("E::from({})", ident),
Operation::Constant(ConstantValue::Vector(vector_access)) => format!("E::from({}[{}])", vector_access.name(), vector_access.idx()),
Operation::Constant(ConstantValue::Matrix(matrix_access)) => format!("E::from({}[{}][{}])", matrix_access.name(), matrix_access.row_idx(), matrix_access.col_idx()),
Operation::Constant(ConstantValue::Vector(vector_access)) => {
format!("E::from({}[{}])", vector_access.name(), vector_access.idx())
}
Operation::Constant(ConstantValue::Matrix(matrix_access)) => format!(
"E::from({}[{}][{}])",
matrix_access.name(),
matrix_access.row_idx(),
matrix_access.col_idx()
),
Operation::TraceElement(trace_access) => match trace_access.row_offset() {
0 => {
format!("current[{}]", trace_access.col_idx())
Expand Down
5 changes: 3 additions & 2 deletions ir/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@

pub use parser::ast::{
self, boundary_constraints::BoundaryExpr, constants::Constant, Identifier, PublicInput,
};
use parser::ast::{BoundaryStmt, TransitionStmt};
pub use parser::ast::{self, constants::Constant, boundary_constraints::BoundaryExpr, Identifier, PublicInput};
use std::collections::BTreeMap;

mod symbol_table;
Expand Down

0 comments on commit aaf85a3

Please sign in to comment.