Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add bitwise air constraints with currently supported syntax #58

Merged
merged 3 commits into from
Nov 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
# will have compiled files and executables
/target/

# Remove system config files generated by mac os
**/.DS_Store

# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries
# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html
Cargo.lock
Expand Down
61 changes: 61 additions & 0 deletions air-script/tests/bitwise/bitwise.air
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
def BitwiseAir

public_inputs:
stack_inputs: [16]

trace_columns:
main: [s, a, b, a0, a1, a2, a3, b0, b1, b2, b3, zp, z, dummy]

periodic_columns:
k0: [1, 0, 0, 0, 0, 0, 0, 0]
k1: [1, 1, 1, 1, 1, 1, 1, 0]

boundary_constraints:
# This is a dummy trace column to satisfy requirement of at least one boundary constraint.
enf dummy.first = 0

transition_constraints:
# Enforce that selector must be binary
enf s^2 - s = 0

# Enforce that selector should stay the same throughout the cycle.
enf k1 * (s' - s) = 0

# Enforce that input is decomposed into valid bits
enf a0^2 - a0 = 0
enf a1^2 - a1 = 0
enf a2^2 - a2 = 0
enf a3^2 - a3 = 0
enf b0^2 - b0 = 0
enf b1^2 - b1 = 0
enf b2^2 - b2 = 0
enf b3^2 - b3 = 0

# Enforce that the values in the column a in the first row should be the aggregation of the
# decomposed bit columns a0..a3.
enf k0 * (a - (2^0 * a0 + 2^1 * a1 + 2^2 * a2 + 2^3 * a3)) = 0
# Enforce that the values in the column b in the first row should be the aggregation of the
# decomposed bit columns b0..b3.
enf k0 * (b - (2^0 * b0 + 2^1 * b1 + 2^2 * b2 + 2^3 * b3)) = 0

# Enforce that for all rows in an 8-row cycle except for the last one, the values in a and b
# columns are increased by the values contained in the individual bit columns a and b.
enf k1 * (a' - (a * 16 + 2^0 * a0 + 2^1 * a1 + 2^2 * a2 + 2^3 * a3)) = 0
enf k1 * (b' - (b * 16 + 2^0 * b0 + 2^1 * b1 + 2^2 * b2 + 2^3 * b3)) = 0

# Enforce that in the first row, the aggregated output value of the previous row should be 0.
enf k0 * zp = 0

# Enforce that for each row except the last, the aggregated output value must equal the
# previous aggregated output value in the next row.
enf k1 * (z - zp') = 0

# Enforce that for all rows the value in the z column is computed by multiplying the previous
# output value (from the zp column in the current row) by 16 and then adding it to the bitwise
# operation applied to the row's set of bits of a and b. The entire constraint must also be
# multiplied by the operation selector flag to ensure it is only applied for the appropriate
# operation. The constraint for AND is enforced when s = 0 and the constraint for XOR is
# enforced when s = 1. Because the selectors for the AND and XOR operations are mutually
# exclusive, the constraints for different operations can be aggregated into the same result
# indices.
enf (1 - s) * (z - (zp * 16 + 2^0 * a0 * b0 + 2^1 * a1 * b1 + 2^2 * a2 * b2 + 2^3 * a3 * b3)) + s * (z - (zp * 16 + 2^0 * (a0 + b0 - 2 * a0 * b0) + 2^1 * (a1 + b1 - 2 * a1 * b1) + 2^2 * (a2 + b2 - 2 * a2 * b2) + 2^3 * (a3 + b3 - 2 * a3 * b3))) = 0
102 changes: 102 additions & 0 deletions air-script/tests/bitwise/bitwise.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
use winter_air::{Air, AirContext, Assertion, AuxTraceRandElements, EvaluationFrame, ProofOptions as WinterProofOptions, TransitionConstraintDegree, TraceInfo};
use winter_math::{fields, ExtensionOf, FieldElement};
use winter_utils::{collections, ByteWriter, Serializable};

pub struct PublicInputs {
stack_inputs: [Felt; 16],
}

impl PublicInputs {
pub fn new(stack_inputs: [Felt; 16]) -> Self {
Self { stack_inputs }
}
}

impl Serializable for PublicInputs {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
target.write(self.stack_inputs.as_slice());
}
}

pub struct BitwiseAir {
context: AirContext<Felt>,
stack_inputs: [Felt; 16],
}

impl BitwiseAir {
pub fn last_step(&self) -> usize {
self.trace_length() - self.context().num_transition_exemptions()
}
}

impl Air for BitwiseAir {
type BaseField = Felt;
type PublicInputs = PublicInputs;

fn context(&self) -> &AirContext<Felt> {
&self.context
}

fn new(trace_info: TraceInfo, public_inputs: PublicInputs, options: WinterProofOptions) -> Self {
let main_degrees = vec![TransitionConstraintDegree::new(2), TransitionConstraintDegree::with_cycles(1, vec![8]), TransitionConstraintDegree::new(2), TransitionConstraintDegree::new(2), TransitionConstraintDegree::new(2), TransitionConstraintDegree::new(2), TransitionConstraintDegree::new(2), TransitionConstraintDegree::new(2), TransitionConstraintDegree::new(2), TransitionConstraintDegree::new(2), TransitionConstraintDegree::with_cycles(1, vec![8]), TransitionConstraintDegree::with_cycles(1, vec![8]), TransitionConstraintDegree::with_cycles(1, vec![8]), TransitionConstraintDegree::with_cycles(1, vec![8]), TransitionConstraintDegree::with_cycles(1, vec![8]), TransitionConstraintDegree::with_cycles(1, vec![8]), TransitionConstraintDegree::new(3)];
let aux_degrees = Vec::new();
let num_main_assertions = 1;
let num_aux_assertions = 0;

let context = AirContext::new_multi_segment(
trace_info,
main_degrees,
aux_degrees,
num_main_assertions,
num_aux_assertions,
options,
)
.set_num_transition_exemptions(2);
Self { context, stack_inputs: public_inputs.stack_inputs }
}

fn get_periodic_column_values(&self) -> Vec<Vec<Felt>> {
vec![vec![Felt::new(1), Felt::new(0), Felt::new(0), Felt::new(0), Felt::new(0), Felt::new(0), Felt::new(0), Felt::new(0)], vec![Felt::new(1), Felt::new(1), Felt::new(1), Felt::new(1), Felt::new(1), Felt::new(1), Felt::new(1), Felt::new(0)]]
}

fn get_assertions(&self) -> Vec<Assertion<Felt>> {
let mut result = Vec::new();
result.push(Assertion::single(13, 0, Felt::new(0)));
result
}

fn get_aux_assertions<E: FieldElement<BaseField = Felt>>(&self, aux_rand_elements: &AuxTraceRandElements<E>) -> Vec<Assertion<E>> {
let mut result = Vec::new();
result
}

fn evaluate_transition<E: FieldElement<BaseField = Felt>>(&self, frame: &EvaluationFrame<E>, periodic_values: &[E], result: &mut [E]) {
let current = frame.current();
let next = frame.next();
result[0] = (current[0]).exp(E::PositiveInteger::from(2_u64)) - (current[0]) - (E::from(0_u64));
result[1] = (periodic_values[1]) * (next[0] - (current[0])) - (E::from(0_u64));
result[2] = (current[3]).exp(E::PositiveInteger::from(2_u64)) - (current[3]) - (E::from(0_u64));
Copy link
Contributor

@bobbinth bobbinth Nov 15, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment for the future: exponentiations could be really expensive, especially for constant-time implementations (like the one we are currently using). So, whenever possible, we should replace them with multiplications or more specialized operations.

For example, here, instead of doing (current[0]).exp(E::PositiveInteger::from(2_u64)) we should be doing (current[0] * current[0]) or current[0].square().

Also, if we know that a constant value is smaller than u32 we should try to use conversions from u32. For example: E::from(2_u32) instead of E::from(2_u64) as we may come up with a more efficient reduction for smaller values later on.

result[3] = (current[4]).exp(E::PositiveInteger::from(2_u64)) - (current[4]) - (E::from(0_u64));
result[4] = (current[5]).exp(E::PositiveInteger::from(2_u64)) - (current[5]) - (E::from(0_u64));
result[5] = (current[6]).exp(E::PositiveInteger::from(2_u64)) - (current[6]) - (E::from(0_u64));
result[6] = (current[7]).exp(E::PositiveInteger::from(2_u64)) - (current[7]) - (E::from(0_u64));
result[7] = (current[8]).exp(E::PositiveInteger::from(2_u64)) - (current[8]) - (E::from(0_u64));
result[8] = (current[9]).exp(E::PositiveInteger::from(2_u64)) - (current[9]) - (E::from(0_u64));
result[9] = (current[10]).exp(E::PositiveInteger::from(2_u64)) - (current[10]) - (E::from(0_u64));
result[10] = (periodic_values[0]) * (current[1] - (((E::from(2_u64)).exp(E::PositiveInteger::from(0_u64))) * (current[3]) + ((E::from(2_u64)).exp(E::PositiveInteger::from(1_u64))) * (current[4]) + ((E::from(2_u64)).exp(E::PositiveInteger::from(2_u64))) * (current[5]) + ((E::from(2_u64)).exp(E::PositiveInteger::from(3_u64))) * (current[6]))) - (E::from(0_u64));
result[11] = (periodic_values[0]) * (current[2] - (((E::from(2_u64)).exp(E::PositiveInteger::from(0_u64))) * (current[7]) + ((E::from(2_u64)).exp(E::PositiveInteger::from(1_u64))) * (current[8]) + ((E::from(2_u64)).exp(E::PositiveInteger::from(2_u64))) * (current[9]) + ((E::from(2_u64)).exp(E::PositiveInteger::from(3_u64))) * (current[10]))) - (E::from(0_u64));
result[12] = (periodic_values[1]) * (next[1] - ((current[1]) * (E::from(16_u64)) + ((E::from(2_u64)).exp(E::PositiveInteger::from(0_u64))) * (current[3]) + ((E::from(2_u64)).exp(E::PositiveInteger::from(1_u64))) * (current[4]) + ((E::from(2_u64)).exp(E::PositiveInteger::from(2_u64))) * (current[5]) + ((E::from(2_u64)).exp(E::PositiveInteger::from(3_u64))) * (current[6]))) - (E::from(0_u64));
result[13] = (periodic_values[1]) * (next[2] - ((current[2]) * (E::from(16_u64)) + ((E::from(2_u64)).exp(E::PositiveInteger::from(0_u64))) * (current[7]) + ((E::from(2_u64)).exp(E::PositiveInteger::from(1_u64))) * (current[8]) + ((E::from(2_u64)).exp(E::PositiveInteger::from(2_u64))) * (current[9]) + ((E::from(2_u64)).exp(E::PositiveInteger::from(3_u64))) * (current[10]))) - (E::from(0_u64));
result[14] = (periodic_values[0]) * (current[11]) - (E::from(0_u64));
result[15] = (periodic_values[1]) * (current[12] - (next[11])) - (E::from(0_u64));
result[16] = (E::from(1_u64) - (current[0])) * (current[12] - ((current[11]) * (E::from(16_u64)) + (((E::from(2_u64)).exp(E::PositiveInteger::from(0_u64))) * (current[3])) * (current[7]) + (((E::from(2_u64)).exp(E::PositiveInteger::from(1_u64))) * (current[4])) * (current[8]) + (((E::from(2_u64)).exp(E::PositiveInteger::from(2_u64))) * (current[5])) * (current[9]) + (((E::from(2_u64)).exp(E::PositiveInteger::from(3_u64))) * (current[6])) * (current[10]))) + (current[0]) * (current[12] - ((current[11]) * (E::from(16_u64)) + ((E::from(2_u64)).exp(E::PositiveInteger::from(0_u64))) * (current[3] + current[7] - (((E::from(2_u64)) * (current[3])) * (current[7]))) + ((E::from(2_u64)).exp(E::PositiveInteger::from(1_u64))) * (current[4] + current[8] - (((E::from(2_u64)) * (current[4])) * (current[8]))) + ((E::from(2_u64)).exp(E::PositiveInteger::from(2_u64))) * (current[5] + current[9] - (((E::from(2_u64)) * (current[5])) * (current[9]))) + ((E::from(2_u64)).exp(E::PositiveInteger::from(3_u64))) * (current[6] + current[10] - (((E::from(2_u64)) * (current[6])) * (current[10]))))) - (E::from(0_u64));
}

fn evaluate_aux_transition<F, E>(&self, main_frame: &EvaluationFrame<F>, aux_frame: &EvaluationFrame<E>, _periodic_values: &[F], aux_rand_elements: &AuxTraceRandElements<E>, result: &mut [E])
where F: FieldElement<BaseField = Felt>,
E: FieldElement<BaseField = Felt> + ExtensionOf<F>,
{
let current = aux_frame.current();
let next = aux_frame.next();
}
}
10 changes: 10 additions & 0 deletions air-script/tests/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,13 @@ fn system() {
let expected = expect_file!["system/system.rs"];
expected.assert_eq(&generated_air);
}

#[test]
fn bitwise() {
let generated_air = Test::new("tests/bitwise/bitwise.air".to_string())
.transpile()
.unwrap();

let expected = expect_file!["bitwise/bitwise.rs"];
expected.assert_eq(&generated_air);
}