-
Notifications
You must be signed in to change notification settings - Fork 19
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add bitwise air constraints with currently supported syntax #58
Merged
Merged
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
def BitwiseAir | ||
|
||
public_inputs: | ||
stack_inputs: [16] | ||
|
||
trace_columns: | ||
main: [s, a, b, a0, a1, a2, a3, b0, b1, b2, b3, zp, z, dummy] | ||
|
||
periodic_columns: | ||
k0: [1, 0, 0, 0, 0, 0, 0, 0] | ||
k1: [1, 1, 1, 1, 1, 1, 1, 0] | ||
|
||
boundary_constraints: | ||
# This is a dummy trace column to satisfy requirement of at least one boundary constraint. | ||
enf dummy.first = 0 | ||
|
||
transition_constraints: | ||
# Enforce that selector must be binary | ||
enf s^2 - s = 0 | ||
|
||
# Enforce that selector should stay the same throughout the cycle. | ||
enf k1 * (s' - s) = 0 | ||
|
||
# Enforce that input is decomposed into valid bits | ||
enf a0^2 - a0 = 0 | ||
enf a1^2 - a1 = 0 | ||
enf a2^2 - a2 = 0 | ||
enf a3^2 - a3 = 0 | ||
enf b0^2 - b0 = 0 | ||
enf b1^2 - b1 = 0 | ||
enf b2^2 - b2 = 0 | ||
enf b3^2 - b3 = 0 | ||
|
||
# Enforce that the values in the column a in the first row should be the aggregation of the | ||
# decomposed bit columns a0..a3. | ||
enf k0 * (a - (2^0 * a0 + 2^1 * a1 + 2^2 * a2 + 2^3 * a3)) = 0 | ||
# Enforce that the values in the column b in the first row should be the aggregation of the | ||
# decomposed bit columns b0..b3. | ||
enf k0 * (b - (2^0 * b0 + 2^1 * b1 + 2^2 * b2 + 2^3 * b3)) = 0 | ||
|
||
# Enforce that for all rows in an 8-row cycle except for the last one, the values in a and b | ||
# columns are increased by the values contained in the individual bit columns a and b. | ||
enf k1 * (a' - (a * 16 + 2^0 * a0 + 2^1 * a1 + 2^2 * a2 + 2^3 * a3)) = 0 | ||
enf k1 * (b' - (b * 16 + 2^0 * b0 + 2^1 * b1 + 2^2 * b2 + 2^3 * b3)) = 0 | ||
|
||
# Enforce that in the first row, the aggregated output value of the previous row should be 0. | ||
enf k0 * zp = 0 | ||
|
||
# Enforce that for each row except the last, the aggregated output value must equal the | ||
# previous aggregated output value in the next row. | ||
enf k1 * (z - zp') = 0 | ||
|
||
# Enforce that for all rows the value in the z column is computed by multiplying the previous | ||
# output value (from the zp column in the current row) by 16 and then adding it to the bitwise | ||
# operation applied to the row's set of bits of a and b. The entire constraint must also be | ||
# multiplied by the operation selector flag to ensure it is only applied for the appropriate | ||
# operation. The constraint for AND is enforced when s = 0 and the constraint for XOR is | ||
# enforced when s = 1. Because the selectors for the AND and XOR operations are mutually | ||
# exclusive, the constraints for different operations can be aggregated into the same result | ||
# indices. | ||
enf (1 - s) * (z - (zp * 16 + 2^0 * a0 * b0 + 2^1 * a1 * b1 + 2^2 * a2 * b2 + 2^3 * a3 * b3)) + s * (z - (zp * 16 + 2^0 * (a0 + b0 - 2 * a0 * b0) + 2^1 * (a1 + b1 - 2 * a1 * b1) + 2^2 * (a2 + b2 - 2 * a2 * b2) + 2^3 * (a3 + b3 - 2 * a3 * b3))) = 0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
use winter_air::{Air, AirContext, Assertion, AuxTraceRandElements, EvaluationFrame, ProofOptions as WinterProofOptions, TransitionConstraintDegree, TraceInfo}; | ||
use winter_math::{fields, ExtensionOf, FieldElement}; | ||
use winter_utils::{collections, ByteWriter, Serializable}; | ||
|
||
pub struct PublicInputs { | ||
stack_inputs: [Felt; 16], | ||
} | ||
|
||
impl PublicInputs { | ||
pub fn new(stack_inputs: [Felt; 16]) -> Self { | ||
Self { stack_inputs } | ||
} | ||
} | ||
|
||
impl Serializable for PublicInputs { | ||
fn write_into<W: ByteWriter>(&self, target: &mut W) { | ||
target.write(self.stack_inputs.as_slice()); | ||
} | ||
} | ||
|
||
pub struct BitwiseAir { | ||
context: AirContext<Felt>, | ||
stack_inputs: [Felt; 16], | ||
} | ||
|
||
impl BitwiseAir { | ||
pub fn last_step(&self) -> usize { | ||
self.trace_length() - self.context().num_transition_exemptions() | ||
} | ||
} | ||
|
||
impl Air for BitwiseAir { | ||
type BaseField = Felt; | ||
type PublicInputs = PublicInputs; | ||
|
||
fn context(&self) -> &AirContext<Felt> { | ||
&self.context | ||
} | ||
|
||
fn new(trace_info: TraceInfo, public_inputs: PublicInputs, options: WinterProofOptions) -> Self { | ||
let main_degrees = vec![TransitionConstraintDegree::new(2), TransitionConstraintDegree::with_cycles(1, vec![8]), TransitionConstraintDegree::new(2), TransitionConstraintDegree::new(2), TransitionConstraintDegree::new(2), TransitionConstraintDegree::new(2), TransitionConstraintDegree::new(2), TransitionConstraintDegree::new(2), TransitionConstraintDegree::new(2), TransitionConstraintDegree::new(2), TransitionConstraintDegree::with_cycles(1, vec![8]), TransitionConstraintDegree::with_cycles(1, vec![8]), TransitionConstraintDegree::with_cycles(1, vec![8]), TransitionConstraintDegree::with_cycles(1, vec![8]), TransitionConstraintDegree::with_cycles(1, vec![8]), TransitionConstraintDegree::with_cycles(1, vec![8]), TransitionConstraintDegree::new(3)]; | ||
let aux_degrees = Vec::new(); | ||
let num_main_assertions = 1; | ||
let num_aux_assertions = 0; | ||
|
||
let context = AirContext::new_multi_segment( | ||
trace_info, | ||
main_degrees, | ||
aux_degrees, | ||
num_main_assertions, | ||
num_aux_assertions, | ||
options, | ||
) | ||
.set_num_transition_exemptions(2); | ||
Self { context, stack_inputs: public_inputs.stack_inputs } | ||
} | ||
|
||
fn get_periodic_column_values(&self) -> Vec<Vec<Felt>> { | ||
vec![vec![Felt::new(1), Felt::new(0), Felt::new(0), Felt::new(0), Felt::new(0), Felt::new(0), Felt::new(0), Felt::new(0)], vec![Felt::new(1), Felt::new(1), Felt::new(1), Felt::new(1), Felt::new(1), Felt::new(1), Felt::new(1), Felt::new(0)]] | ||
} | ||
|
||
fn get_assertions(&self) -> Vec<Assertion<Felt>> { | ||
let mut result = Vec::new(); | ||
result.push(Assertion::single(13, 0, Felt::new(0))); | ||
result | ||
} | ||
|
||
fn get_aux_assertions<E: FieldElement<BaseField = Felt>>(&self, aux_rand_elements: &AuxTraceRandElements<E>) -> Vec<Assertion<E>> { | ||
let mut result = Vec::new(); | ||
result | ||
} | ||
|
||
fn evaluate_transition<E: FieldElement<BaseField = Felt>>(&self, frame: &EvaluationFrame<E>, periodic_values: &[E], result: &mut [E]) { | ||
let current = frame.current(); | ||
let next = frame.next(); | ||
result[0] = (current[0]).exp(E::PositiveInteger::from(2_u64)) - (current[0]) - (E::from(0_u64)); | ||
result[1] = (periodic_values[1]) * (next[0] - (current[0])) - (E::from(0_u64)); | ||
result[2] = (current[3]).exp(E::PositiveInteger::from(2_u64)) - (current[3]) - (E::from(0_u64)); | ||
result[3] = (current[4]).exp(E::PositiveInteger::from(2_u64)) - (current[4]) - (E::from(0_u64)); | ||
result[4] = (current[5]).exp(E::PositiveInteger::from(2_u64)) - (current[5]) - (E::from(0_u64)); | ||
result[5] = (current[6]).exp(E::PositiveInteger::from(2_u64)) - (current[6]) - (E::from(0_u64)); | ||
result[6] = (current[7]).exp(E::PositiveInteger::from(2_u64)) - (current[7]) - (E::from(0_u64)); | ||
result[7] = (current[8]).exp(E::PositiveInteger::from(2_u64)) - (current[8]) - (E::from(0_u64)); | ||
result[8] = (current[9]).exp(E::PositiveInteger::from(2_u64)) - (current[9]) - (E::from(0_u64)); | ||
result[9] = (current[10]).exp(E::PositiveInteger::from(2_u64)) - (current[10]) - (E::from(0_u64)); | ||
result[10] = (periodic_values[0]) * (current[1] - (((E::from(2_u64)).exp(E::PositiveInteger::from(0_u64))) * (current[3]) + ((E::from(2_u64)).exp(E::PositiveInteger::from(1_u64))) * (current[4]) + ((E::from(2_u64)).exp(E::PositiveInteger::from(2_u64))) * (current[5]) + ((E::from(2_u64)).exp(E::PositiveInteger::from(3_u64))) * (current[6]))) - (E::from(0_u64)); | ||
result[11] = (periodic_values[0]) * (current[2] - (((E::from(2_u64)).exp(E::PositiveInteger::from(0_u64))) * (current[7]) + ((E::from(2_u64)).exp(E::PositiveInteger::from(1_u64))) * (current[8]) + ((E::from(2_u64)).exp(E::PositiveInteger::from(2_u64))) * (current[9]) + ((E::from(2_u64)).exp(E::PositiveInteger::from(3_u64))) * (current[10]))) - (E::from(0_u64)); | ||
result[12] = (periodic_values[1]) * (next[1] - ((current[1]) * (E::from(16_u64)) + ((E::from(2_u64)).exp(E::PositiveInteger::from(0_u64))) * (current[3]) + ((E::from(2_u64)).exp(E::PositiveInteger::from(1_u64))) * (current[4]) + ((E::from(2_u64)).exp(E::PositiveInteger::from(2_u64))) * (current[5]) + ((E::from(2_u64)).exp(E::PositiveInteger::from(3_u64))) * (current[6]))) - (E::from(0_u64)); | ||
result[13] = (periodic_values[1]) * (next[2] - ((current[2]) * (E::from(16_u64)) + ((E::from(2_u64)).exp(E::PositiveInteger::from(0_u64))) * (current[7]) + ((E::from(2_u64)).exp(E::PositiveInteger::from(1_u64))) * (current[8]) + ((E::from(2_u64)).exp(E::PositiveInteger::from(2_u64))) * (current[9]) + ((E::from(2_u64)).exp(E::PositiveInteger::from(3_u64))) * (current[10]))) - (E::from(0_u64)); | ||
result[14] = (periodic_values[0]) * (current[11]) - (E::from(0_u64)); | ||
result[15] = (periodic_values[1]) * (current[12] - (next[11])) - (E::from(0_u64)); | ||
result[16] = (E::from(1_u64) - (current[0])) * (current[12] - ((current[11]) * (E::from(16_u64)) + (((E::from(2_u64)).exp(E::PositiveInteger::from(0_u64))) * (current[3])) * (current[7]) + (((E::from(2_u64)).exp(E::PositiveInteger::from(1_u64))) * (current[4])) * (current[8]) + (((E::from(2_u64)).exp(E::PositiveInteger::from(2_u64))) * (current[5])) * (current[9]) + (((E::from(2_u64)).exp(E::PositiveInteger::from(3_u64))) * (current[6])) * (current[10]))) + (current[0]) * (current[12] - ((current[11]) * (E::from(16_u64)) + ((E::from(2_u64)).exp(E::PositiveInteger::from(0_u64))) * (current[3] + current[7] - (((E::from(2_u64)) * (current[3])) * (current[7]))) + ((E::from(2_u64)).exp(E::PositiveInteger::from(1_u64))) * (current[4] + current[8] - (((E::from(2_u64)) * (current[4])) * (current[8]))) + ((E::from(2_u64)).exp(E::PositiveInteger::from(2_u64))) * (current[5] + current[9] - (((E::from(2_u64)) * (current[5])) * (current[9]))) + ((E::from(2_u64)).exp(E::PositiveInteger::from(3_u64))) * (current[6] + current[10] - (((E::from(2_u64)) * (current[6])) * (current[10]))))) - (E::from(0_u64)); | ||
} | ||
|
||
fn evaluate_aux_transition<F, E>(&self, main_frame: &EvaluationFrame<F>, aux_frame: &EvaluationFrame<E>, _periodic_values: &[F], aux_rand_elements: &AuxTraceRandElements<E>, result: &mut [E]) | ||
where F: FieldElement<BaseField = Felt>, | ||
E: FieldElement<BaseField = Felt> + ExtensionOf<F>, | ||
{ | ||
let current = aux_frame.current(); | ||
let next = aux_frame.next(); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Comment for the future: exponentiations could be really expensive, especially for constant-time implementations (like the one we are currently using). So, whenever possible, we should replace them with multiplications or more specialized operations.
For example, here, instead of doing
(current[0]).exp(E::PositiveInteger::from(2_u64))
we should be doing(current[0] * current[0])
orcurrent[0].square()
.Also, if we know that a constant value is smaller than
u32
we should try to use conversions fromu32
. For example:E::from(2_u32)
instead ofE::from(2_u64)
as we may come up with a more efficient reduction for smaller values later on.