Skip to content

Commit

Permalink
Optimize constraints for r1cs (#64)
Browse files Browse the repository at this point in the history
* Optimize R1CS Backend
  • Loading branch information
katat authored May 29, 2024
1 parent 05c6c93 commit 222e575
Show file tree
Hide file tree
Showing 27 changed files with 365 additions and 441 deletions.
3 changes: 1 addition & 2 deletions examples/fixture/asm/r1cs/arithmetic.asm
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
@ noname.0.7.0

v_3 == (v_1 + v_2) * (1)
2 == (v_3) * (1)
2 == (v_1 + v_2) * (1)
14 changes: 4 additions & 10 deletions examples/fixture/asm/r1cs/assignment.asm
Original file line number Diff line number Diff line change
@@ -1,12 +1,6 @@
@ noname.0.7.0

v_2 == (v_1 + 1) * (1)
v_3 == (v_2 + 3) * (1)
v_4 == (v_1 + 2) * (1)
v_5 == (v_1 + 3) * (1)
v_6 == (2 * v_1) * (1)
v_7 == (2 * v_1) * (1)
v_7 == (v_6) * (1)
4 == (v_6) * (1)
4 == (v_4) * (1)
5 == (v_5) * (1)
2 * v_1 == (2 * v_1) * (1)
4 == (2 * v_1) * (1)
4 == (v_1 + 2) * (1)
5 == (v_1 + 3) * (1)
14 changes: 5 additions & 9 deletions examples/fixture/asm/r1cs/bool.asm
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
@ noname.0.7.0

v_2 == (v_1 + -1) * (1)
v_3 == (v_1) * (v_2)
0 == (v_3) * (1)
v_5 == (v_4 + -1) * (1)
v_6 == (v_4) * (v_5)
0 == (v_6) * (1)
0 == (v_4 + v_7) * (1)
v_8 == (v_7 + 1) * (1)
1 == (v_8) * (1)
v_2 == (v_1) * (v_1 + -1)
0 == (v_2) * (1)
v_4 == (v_3) * (v_3 + -1)
0 == (v_4) * (1)
1 == (-1 * v_3 + 1) * (1)
1 == (v_1) * (1)
5 changes: 2 additions & 3 deletions examples/fixture/asm/r1cs/const.asm
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
@ noname.0.7.0

1 == (v_1) * (1)
v_3 == (v_1 + 1) * (1)
2 == (v_3) * (1)
v_3 == (v_2) * (1)
2 == (v_1 + 1) * (1)
v_1 + 1 == (v_2) * (1)
30 changes: 11 additions & 19 deletions examples/fixture/asm/r1cs/equals.asm
Original file line number Diff line number Diff line change
@@ -1,21 +1,13 @@
@ noname.0.7.0

0 == (v_1 + v_3) * (1)
v_4 == (v_2 + v_3) * (1)
v_6 == (v_5) * (v_4)
0 == (v_7 + v_8) * (1)
v_9 == (v_8 + 1) * (1)
v_9 == (v_6) * (1)
v_10 == (v_7) * (v_4)
0 == (v_10) * (1)
1 == (v_7) * (1)
3 == (v_11) * (1)
0 == (v_1 + v_12) * (1)
v_13 == (v_11 + v_12) * (1)
v_15 == (v_14) * (v_13)
0 == (v_16 + v_17) * (1)
v_18 == (v_17 + 1) * (1)
v_18 == (v_15) * (1)
v_19 == (v_16) * (v_13)
0 == (v_19) * (1)
1 == (v_16) * (1)
v_4 == (v_3) * (-1 * v_1 + v_2)
-1 * v_5 + 1 == (v_4) * (1)
v_6 == (v_5) * (-1 * v_1 + v_2)
0 == (v_6) * (1)
1 == (v_5) * (1)
3 == (v_7) * (1)
v_9 == (v_8) * (-1 * v_1 + v_7)
-1 * v_10 + 1 == (v_9) * (1)
v_11 == (v_10) * (-1 * v_1 + v_7)
0 == (v_11) * (1)
1 == (v_10) * (1)
4 changes: 1 addition & 3 deletions examples/fixture/asm/r1cs/for_loop.asm
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
@ noname.0.7.0

v_5 == (v_2 + v_3) * (1)
v_6 == (v_4 + v_5) * (1)
v_1 == (v_6) * (1)
v_1 == (v_2 + v_3 + v_4) * (1)
6 changes: 2 additions & 4 deletions examples/fixture/asm/r1cs/functions.asm
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
@ noname.0.7.0

v_2 == (v_1 + 3) * (1)
4 == (v_2) * (1)
v_3 == (2 * v_2) * (1)
8 == (v_3) * (1)
4 == (v_1 + 3) * (1)
8 == (2 * v_1 + 6) * (1)
24 changes: 8 additions & 16 deletions examples/fixture/asm/r1cs/if_else.asm
Original file line number Diff line number Diff line change
@@ -1,18 +1,10 @@
@ noname.0.7.0

v_2 == (v_1 + 1) * (1)
1 == (v_3) * (1)
0 == (v_1 + v_4) * (1)
v_5 == (v_3 + v_4) * (1)
v_7 == (v_6) * (v_5)
0 == (v_8 + v_9) * (1)
v_10 == (v_9 + 1) * (1)
v_10 == (v_7) * (1)
v_11 == (v_8) * (v_5)
0 == (v_11) * (1)
v_12 == (v_2) * (v_8)
0 == (v_8 + v_13) * (1)
v_14 == (v_13 + 1) * (1)
v_15 == (v_14) * (v_1)
v_16 == (v_12 + v_15) * (1)
2 == (v_16) * (1)
1 == (v_2) * (1)
v_4 == (v_3) * (-1 * v_1 + v_2)
-1 * v_5 + 1 == (v_4) * (1)
v_6 == (v_5) * (-1 * v_1 + v_2)
0 == (v_6) * (1)
v_7 == (v_1 + 1) * (v_5)
v_8 == (-1 * v_5 + 1) * (v_1)
2 == (v_7 + v_8) * (1)
4 changes: 1 addition & 3 deletions examples/fixture/asm/r1cs/iterate.asm
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
@ noname.0.7.0

v_3 == (v_1 + -1) * (1)
v_4 == (v_3 + 3) * (1)
v_4 == (v_2) * (1)
v_1 + 2 == (v_2) * (1)
6 changes: 2 additions & 4 deletions examples/fixture/asm/r1cs/methods.asm
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
@ noname.0.7.0

v_2 == (v_1 + 1) * (1)
1 == (v_1) * (1)
2 == (v_2) * (1)
v_3 == (v_1 + v_2) * (1)
3 == (v_3) * (1)
2 == (v_1 + 1) * (1)
3 == (2 * v_1 + 1) * (1)
6 changes: 2 additions & 4 deletions examples/fixture/asm/r1cs/mutable.asm
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,5 @@

2 == (v_1) * (1)
2 == (v_1) * (1)
v_3 == (v_1 + v_2) * (1)
5 == (v_3) * (1)
v_4 == (2 * v_3) * (1)
10 == (v_4) * (1)
5 == (v_1 + v_2) * (1)
10 == (2 * v_1 + 2 * v_2) * (1)
6 changes: 2 additions & 4 deletions examples/fixture/asm/r1cs/public_output.asm
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
@ noname.0.7.0

v_4 == (v_1 + v_2) * (1)
2 == (v_4) * (1)
v_5 == (v_4 + 6) * (1)
v_5 == (v_3) * (1)
2 == (v_1 + v_2) * (1)
v_1 + v_2 + 6 == (v_3) * (1)
6 changes: 3 additions & 3 deletions src/backends/kimchi/builtin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,9 @@ pub fn poseidon(
// create each variable
let var = compiler.backend.new_internal_var(
Value::Hint(Arc::new(move |backend, env| {
let x1 = backend.compute_var(env, prev_0)?;
let x2 = backend.compute_var(env, prev_1)?;
let x3 = backend.compute_var(env, prev_2)?;
let x1 = backend.compute_var(env, &prev_0)?;
let x2 = backend.compute_var(env, &prev_1)?;
let x3 = backend.compute_var(env, &prev_2)?;

let mut acc = vec![x1, x2, x3];

Expand Down
16 changes: 8 additions & 8 deletions src/backends/kimchi/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ pub type VestaField = kimchi::mina_curves::pasta::Fp;
/// Number of columns in the execution trace.
pub const NUM_REGISTERS: usize = kimchi::circuits::wires::COLUMNS;

use super::{Backend, BackendField, CellVar};
use super::{Backend, BackendField, BackendVar};

impl BackendField for VestaField {}

Expand Down Expand Up @@ -249,7 +249,7 @@ pub struct KimchiCellVar {
pub span: Span,
}

impl CellVar for KimchiCellVar {}
impl BackendVar for KimchiCellVar {}

impl KimchiCellVar {
fn new(index: usize, span: Span) -> Self {
Expand All @@ -259,7 +259,7 @@ impl KimchiCellVar {

impl Backend for KimchiVesta {
type Field = VestaField;
type CellVar = KimchiCellVar;
type Var = KimchiCellVar;
type GeneratedWitness = GeneratedWitness;

fn poseidon() -> crate::imports::FnHandle<Self> {
Expand Down Expand Up @@ -304,7 +304,7 @@ impl Backend for KimchiVesta {

fn finalize_circuit(
&mut self,
public_output: Option<Var<Self::Field, Self::CellVar>>,
public_output: Option<Var<Self::Field, Self::Var>>,
returned_cells: Option<Vec<KimchiCellVar>>,
main_span: Span,
) -> Result<()> {
Expand Down Expand Up @@ -371,7 +371,7 @@ impl Backend for KimchiVesta {
fn compute_var(
&self,
env: &mut crate::witness::WitnessEnv<Self::Field>,
var: Self::CellVar,
var: &Self::Var,
) -> crate::error::Result<Self::Field> {
let val = self.witness_vars.get(&var.index).unwrap();
self.compute_val(env, val, var.index)
Expand Down Expand Up @@ -407,7 +407,7 @@ impl Backend for KimchiVesta {
.push((row, col));
Self::Field::zero()
} else {
self.compute_var(witness_env, *var)?
self.compute_var(witness_env, var)?
}
} else {
Self::Field::zero()
Expand All @@ -422,7 +422,7 @@ impl Backend for KimchiVesta {
let mut public_outputs = vec![];

for (var, rows_cols) in public_outputs_vars {
let val = self.compute_var(witness_env, var)?;
let val = self.compute_var(witness_env, &var)?;
for (row, col) in rows_cols {
witness[row][col] = val;
}
Expand Down Expand Up @@ -726,7 +726,7 @@ impl Backend for KimchiVesta {
cvar
}

fn add_private_input(&mut self, val: Value<Self>, span: Span) -> Self::CellVar {
fn add_private_input(&mut self, val: Value<Self>, span: Span) -> Self::Var {
let cvar = self.new_internal_var(val, span);
self.private_input_indices.push(cvar.index);

Expand Down
52 changes: 29 additions & 23 deletions src/backends/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ pub trait BackendField:

/// This trait allows different backends to have different cell var types.
/// It is intended to make it opaque to the frondend.
pub trait CellVar: Default + Clone + Copy + Debug + PartialEq + Eq + Hash {}
pub trait BackendVar: Clone + Debug + PartialEq + Eq {}

pub enum BackendKind {
KimchiVesta(KimchiVesta),
Expand Down Expand Up @@ -58,7 +58,7 @@ pub trait Backend: Clone {

/// The CellVar type for the backend.
/// Different backend is allowed to have different CellVar types.
type CellVar: CellVar;
type Var: BackendVar;

/// The generated witness type for the backend. Each backend may define its own witness format to be generated.
type GeneratedWitness;
Expand All @@ -69,37 +69,43 @@ pub trait Backend: Clone {

/// Create a new cell variable and record it.
/// It increments the variable index for look up later.
fn new_internal_var(&mut self, val: Value<Self>, span: Span) -> Self::CellVar;
fn new_internal_var(&mut self, val: Value<Self>, span: Span) -> Self::Var;

/// negate a var
fn neg(&mut self, var: &Self::CellVar, span: Span) -> Self::CellVar;
fn neg(&mut self, var: &Self::Var, span: Span) -> Self::Var;

/// add two vars
fn add(&mut self, lhs: &Self::CellVar, rhs: &Self::CellVar, span: Span) -> Self::CellVar;
fn add(&mut self, lhs: &Self::Var, rhs: &Self::Var, span: Span) -> Self::Var;

/// sub two vars
fn sub(&mut self, lhs: &Self::Var, rhs: &Self::Var, span: Span) -> Self::Var {
let rhs_neg = self.neg(rhs, span);
self.add(lhs, &rhs_neg, span)
}

/// add a var with a constant
fn add_const(&mut self, var: &Self::CellVar, cst: &Self::Field, span: Span) -> Self::CellVar;
fn add_const(&mut self, var: &Self::Var, cst: &Self::Field, span: Span) -> Self::Var;

/// multiply a var with another var
fn mul(&mut self, lhs: &Self::CellVar, rhs: &Self::CellVar, span: Span) -> Self::CellVar;
fn mul(&mut self, lhs: &Self::Var, rhs: &Self::Var, span: Span) -> Self::Var;

/// multiply a var with a constant
fn mul_const(&mut self, var: &Self::CellVar, cst: &Self::Field, span: Span) -> Self::CellVar;
fn mul_const(&mut self, var: &Self::Var, cst: &Self::Field, span: Span) -> Self::Var;

/// add a constraint to assert a var equals a constant
fn assert_eq_const(&mut self, var: &Self::CellVar, cst: Self::Field, span: Span);
fn assert_eq_const(&mut self, var: &Self::Var, cst: Self::Field, span: Span);

/// add a constraint to assert a var equals another var
fn assert_eq_var(&mut self, lhs: &Self::CellVar, rhs: &Self::CellVar, span: Span);
fn assert_eq_var(&mut self, lhs: &Self::Var, rhs: &Self::Var, span: Span);

/// Process a public input
fn add_public_input(&mut self, val: Value<Self>, span: Span) -> Self::CellVar;
fn add_public_input(&mut self, val: Value<Self>, span: Span) -> Self::Var;

/// Process a private input
fn add_private_input(&mut self, val: Value<Self>, span: Span) -> Self::CellVar;
fn add_private_input(&mut self, val: Value<Self>, span: Span) -> Self::Var;

/// Process a public output
fn add_public_output(&mut self, val: Value<Self>, span: Span) -> Self::CellVar;
fn add_public_output(&mut self, val: Value<Self>, span: Span) -> Self::Var;

/// This should be called only when you want to constrain a constant for real.
/// Gates that handle constants should always make sure to call this function when they want them constrained.
Expand All @@ -108,13 +114,13 @@ pub trait Backend: Clone {
label: Option<&'static str>,
value: Self::Field,
span: Span,
) -> Self::CellVar;
) -> Self::Var;

/// Backends should implement this function to load and compute the value of a CellVar.
fn compute_var(
&self,
env: &mut WitnessEnv<Self::Field>,
var: Self::CellVar,
var: &Self::Var,
) -> Result<Self::Field>;

/// Compute the value of the symbolic cell variables.
Expand Down Expand Up @@ -142,20 +148,20 @@ pub trait Backend: Clone {
Value::LinearCombination(lc, cst) => {
let mut res = *cst;
for (coeff, var) in lc {
res += self.compute_var(env, *var)? * *coeff;
res += self.compute_var(env, var)? * *coeff;
}
env.cached_values.insert(cache_key, res); // cache
Ok(res)
}
Value::Mul(lhs, rhs) => {
let lhs = self.compute_var(env, *lhs)?;
let rhs = self.compute_var(env, *rhs)?;
let lhs = self.compute_var(env, lhs)?;
let rhs = self.compute_var(env, rhs)?;
let res = lhs * rhs;
env.cached_values.insert(cache_key, res); // cache
Ok(res)
}
Value::Inverse(v) => {
let v = self.compute_var(env, *v)?;
let v = self.compute_var(env, v)?;
let res = v.inverse().unwrap_or_else(Self::Field::zero);
env.cached_values.insert(cache_key, res); // cache
Ok(res)
Expand All @@ -164,13 +170,13 @@ pub trait Backend: Clone {
Value::PublicOutput(var) => {
// var can be none. what could be the better way to pass in the span in that case?
// let span = self.main_info().span;
let var = var.ok_or_else(|| {
let var = var.as_ref().ok_or_else(|| {
Error::new("runtime", ErrorKind::MissingReturn, Span::default())
})?;
self.compute_var(env, var)
}
Value::Scale(scalar, var) => {
let var = self.compute_var(env, *var)?;
let var = self.compute_var(env, var)?;
Ok(*scalar * var)
}
}
Expand All @@ -179,8 +185,8 @@ pub trait Backend: Clone {
/// Finalize the circuit by doing some sanitizing checks.
fn finalize_circuit(
&mut self,
public_output: Option<Var<Self::Field, Self::CellVar>>,
returned_cells: Option<Vec<Self::CellVar>>,
public_output: Option<Var<Self::Field, Self::Var>>,
returned_cells: Option<Vec<Self::Var>>,
main_span: Span,
) -> Result<()>;

Expand Down
Loading

0 comments on commit 222e575

Please sign in to comment.