diff --git a/ast/src/asm_analysis/mod.rs b/ast/src/asm_analysis/mod.rs index 446a799fae..674e80a5a0 100644 --- a/ast/src/asm_analysis/mod.rs +++ b/ast/src/asm_analysis/mod.rs @@ -42,7 +42,7 @@ pub struct DegreeStatement { pub degree: BigUint, } -#[derive(Clone)] +#[derive(Clone, Debug)] pub enum FunctionStatement { Assignment(AssignmentStatement), Instruction(InstructionStatement), @@ -74,7 +74,7 @@ impl From for FunctionStatement { } } -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct AssignmentStatement { pub start: usize, pub lhs: Vec, @@ -82,20 +82,20 @@ pub struct AssignmentStatement { pub rhs: Box>, } -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct InstructionStatement { pub start: usize, pub instruction: String, pub inputs: Vec>, } -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct LabelStatement { pub start: usize, pub name: String, } -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct DebugDirective { pub start: usize, pub directive: crate::parsed::asm::DebugDirective, @@ -130,7 +130,7 @@ impl Machine { } } -#[derive(Clone, Default)] +#[derive(Clone, Default, Debug)] pub struct Rom { pub statements: Vec>, pub batches: Option>, diff --git a/backend/Cargo.toml b/backend/Cargo.toml index 16a1fee6dd..266c8abb10 100644 --- a/backend/Cargo.toml +++ b/backend/Cargo.toml @@ -5,9 +5,11 @@ edition = "2021" [features] halo2 = ["dep:halo2"] +nova = ["dep:nova"] [dependencies] halo2 = { path = "../halo2", optional = true } +nova = { path = "../nova", optional = true } pil_analyzer = { path = "../pil_analyzer" } number = { path = "../number" } strum = { version = "0.24.1", features = ["derive"] } diff --git a/backend/src/lib.rs b/backend/src/lib.rs index f346a642f1..df28c5801a 100644 --- a/backend/src/lib.rs +++ b/backend/src/lib.rs @@ -1,4 +1,4 @@ -use ast::analyzed::Analyzed; +use ast::{analyzed::Analyzed, asm_analysis::Machine}; use number::FieldElement; use std::io; use strum::{Display, EnumString, EnumVariantNames}; @@ -62,6 +62,9 @@ pub struct Halo2MockBackend; #[cfg(feature = "halo2")] pub struct Halo2AggregationBackend; +#[cfg(feature = "nova")] +pub struct NovaBackend; + #[cfg(feature = "halo2")] impl ProverWithParams for Halo2Backend { fn prove( @@ -102,3 +105,24 @@ impl ProverAggregationWithParams for Halo2AggregationBackend { halo2::prove_aggr_read_proof_params(pil, fixed, witness, proof, params) } } + +#[cfg(feature = "nova")] +// TODO implement ProverWithoutParams, and remove dependent on main_machine +impl NovaBackend { + pub fn prove( + pil: &Analyzed, + main_machine: &Machine, + fixed: Vec<(&str, Vec)>, + witness: Vec<(&str, Vec)>, + public_io: Vec, + ) -> Option { + Some(nova::prove_ast_read_params( + pil, + main_machine, + fixed, + witness, + public_io, + )); + Some(vec![]) + } +} diff --git a/compiler/src/lib.rs b/compiler/src/lib.rs index 17fe2e6350..10807bd434 100644 --- a/compiler/src/lib.rs +++ b/compiler/src/lib.rs @@ -14,7 +14,8 @@ use json::JsonValue; pub mod util; mod verify; -use analysis::analyze; +// TODO should analyze be `pub`? +pub use analysis::analyze; pub use backend::{Backend, Proof}; use number::write_polys_file; use pil_analyzer::json_exporter; diff --git a/nova/Cargo.toml b/nova/Cargo.toml new file mode 100644 index 0000000000..2e4cfbc93f --- /dev/null +++ b/nova/Cargo.toml @@ -0,0 +1,31 @@ +[package] +name = "nova" +version = "0.1.0" +edition = "2021" + +[dependencies] +number = { path = "../number" } +pil_analyzer = { path = "../pil_analyzer" } +num-traits = "0.2.15" +num-integer = "0.1.45" +itertools = "^0.10" +num-bigint = "^0.4" +log = "0.4.17" +rand = "0.8.5" +pasta_curves = { version = "0.5", features = ["repr-c", "serde"] } +ast = { version = "0.1.0", path = "../ast" } +nova-snark = { git = "https://github.com/hero78119/SuperNova.git", branch = "supernova", default-features = false, features = ["supernova"] } +bellpepper-core = { version="0.2.0", default-features = false } +bellpepper = { version="0.2.0", default-features = false } +ff = { version = "0.13.0", features = ["derive"] } +polyexen = { git = "https://github.com/Dhole/polyexen", branch = "feature/shuffles" } + +[dev-dependencies] +analysis = { path = "../analysis" } +executor = { path = "../executor" } +parser = { path = "../parser" } +asm_to_pil = { path = "../asm_to_pil" } +test-log = "0.2.12" +env_logger = "0.10.0" +linker = { path = "../linker" } + diff --git a/zero_nova/Makefile b/nova/Makefile similarity index 100% rename from zero_nova/Makefile rename to nova/Makefile diff --git a/nova/README.md b/nova/README.md new file mode 100644 index 0000000000..656c679757 --- /dev/null +++ b/nova/README.md @@ -0,0 +1,124 @@ +Nova +=========== +This implementation is based on Supernova https://eprint.iacr.org/2022/1758, a Non-Uniform IVC scheme to `solve a stateful machine with a particular instruction set (e.g., EVM, RISC-V)` under R1CS. + +This implemetation is based on https://github.com/microsoft/Nova with toggling "supernova" features + +> 07 Aug 2024 Update: supernova PR still under review https://github.com/microsoft/Nova/pull/204. Therefore temporarily depends on private repo https://github.com/hero78119/SuperNova/tree/supernova_trait_mut + +### folding schemes of powdr-asm +Recursive/folding logical boundary is cut on per instruction defined in powdr-asm. Each instruction will form a relaxed running instance in supernova. Therefore, there will be totally `#inst` (relaxed R1CS) running instance. More accurately, it's `#inst + 1` running instance, for extra `+1` is for folding secondary circuit on primary circuit. Detail is omitted and encourge to check Microsoft Nova repo https://github.com/microsoft/Nova for how 2 cycle of curve was implemented + +### augmented circuit +Each instruction is compiled into a step circuit, following Nova(Supernova) paper terminology, it's also called F circuit. +An augmented circuit := step circuit + nova folding verification circuit. +Furthermore, an augmented circuit has it own isolated constraints system, means there will be no shared circuit among different augmented circuits. Due to the fact, we can also call it instruction-circuit. There will be `#inst` instruction-circuit (More accurate, `#inst + 1` for 2 cycle curve implementation) + +### ROM encoding & label +For each statement in main.ROM, we will encode it as a linear combination under 2 power of LIMB_SIZE, and the LIMB_SIZE is configurable. +ROM array are attached at the end of Nova state. For input params in different type, the linear combination strategy will be adjust accordingly. + +- `reg index`, i.e. x2. `2` will be treat as unsigned index and put into the value +- `sign/unsigned` const. For unsigned value will be put in lc directly. While signed part, it will be convert to signed limb, and on circuit side, signed limb will be convert to negative field value accordingly. +- `label`. i.e. `LOOP_START`, it will be convert to respective pc index. + +Since each instruction circuit has it own params type definition, different constraints circuit will be compiled to handle above situation automatically. + +### Nova state & constraints +Nova state layout as z0 = `(pc, [writable register...] ++ public io ++ ROM)` + +- public io := array of public input input from prover/verifier +- ROM := array `[rom_value_pc1, rom_value_pc2, rom_value_pc3...]` +Each round an instruction is invoked, and in instruction-circuit it will constraints +1. sequence constraints => `rom_value_at_current_pc - linear-combination([opcode_index, input param1, input param2,...output param1, ...], 1 << limb_width) = 0` +2. writable register read/write are value are constraint and match. + +> While which instruction-circuit is invoked determined by prover, an maliculous prover can not invoke arbitrary any instruction-circuit, otherwise sequence constraints will be failed to pass `is_sat` check in the final stage. + +### Public IO +In powdr-asm, an example of psuedo instruction to query public input +``` +reg <=X= ${ ("input", 0) }; +``` +which means query public input at index `0` and store into reg + +In Nova, psuedo instruction is also represented as a individual augmented circuit with below check, similar as other instruction +- sequence constraints: same as other instruction circuit to check `zi[offset + pc] - linear-combination([opcode_index, input param1, input param2,...output param1, ...], 1 << limb_width) = 0` +- writable register read/write constraints => `z[public_io_offset] - writable register = 0` + +Public-io psuedo instruction is not declared in powdr-asm instruction body. To make the circuit compiler flow similar as other normal instruction, a dummy input/output params will be auto populated, so the constraints for `writable register read/write` will reuse the same flow as others. + +### R1CS constraints +An augmented circuit can be viewed as a individual constraint system. PIL in powdr-asm instruction definition body will be compile into respective R1CS constraints. More detail, constraints can be categorized into 2 group +1. sequence constraints + writable register RW (or signed/unsigned/label) => this constraints will be insert into R1CS circuit automatically and transparent to powdr-asm PIL. +2. Powdr-asm PIL constraints: will be compiled into R1CS constraints + +Giving simple example + +``` +machine NovaZero { + + ... + instr incr X -> Y { + Y = X + 1, + } + + instr decr X -> Y { + Y = X - 1 + } + + instr loop { + pc' = pc + } + + instr assert_zero X { + X = 0 + } + + // the main function assigns the first prover input to A, increments it, decrements it, and loops forever + function main { + x0 <=Z= incr(x0); // x0' = 1 + x0 <=Y= decr(x0) // x0' = 0 + assert_zero x0; // x0 == 0 + loop; + } +} +``` + +It will be compiled to below R1CS instance +``` +// in incr instruction circuit +(X*1 + 1*1) * 1 = Y +... + +// in decr instruction circuit +(x*1 + 1.inv()) * 1 = Y +... + +// in assert_zero circuit +(X*1) * 1 = 0 +... +``` + +Note that, the only way to pass data to next round is via writable register, so the `next` will be handled by instruction assignment automatically. There is forbidden to have `next` usage in instruction body. Only exception is `pc'` value, since the only place to mark next pc value is in instruction body. + +The circuit to constraints `Writable Register` will be generate automatically to constraints +- `zi[end_of_reg_offset + pc] - linear-combination([opcode_index, input reg1, input reg2,..., output reg1, output reg2,...], 1 << limb_width) = 0 // in R1CS` +- `zi[input reg1] - assignemnt_reg_1 = 0 // constraints equality with respective PIL polynomial reference` +- `zi[input reg2] - assignemnt_reg_2 = 0 // ...` +- ... +- `zi[output reg1] - assignemnt_reg_3 = 0 // output is optional` +- ... + +> For R1CS in the future, a more efficient `memory commitment` should be also part of the Nova state. For example, merkle tree root, or KZG vector commitment. This enable to constraints RAM Read/Write consistency. For KZG, potientially it can fit into R1CS folding scheme by random linear combination and defering the pairing to the last step then SNARK it. See `https://eprint.iacr.org/2019/1047.pdf` `9.2 Optimizations for the polynomial commitment scheme` for pairing linear combination on LHS/RHS which is potientially applicable to folding scheme. + + +### witness assignment +For 2 groups of constraints in a instruction circuit +- sequence constraints +- pil constrains + +Powdr framework can infer and provide pil constrains, so in witness assigment we just reuse it. +For sequence constraints, it's infer via bellpepper R1CS constraints system automatically, since it's transparent in PIL logic. + +> An fact is, there are also constraints related to (Super)Nova verifier circuit. Its also automatically infer via `bellpepper R1CS constraints system` automatically, and also encaptulated as blackbox \ No newline at end of file diff --git a/nova/src/circuit.rs b/nova/src/circuit.rs new file mode 100644 index 0000000000..bc9087292b --- /dev/null +++ b/nova/src/circuit.rs @@ -0,0 +1,522 @@ +use std::{ + collections::BTreeMap, + iter, + marker::PhantomData, + sync::{Arc, Mutex}, +}; + +use ast::{ + analyzed::{Analyzed, Expression, IdentityKind, PolynomialReference}, + parsed::{asm::Param, BinaryOperator}, +}; +use bellpepper_core::{ + ConstraintSystem, SynthesisError, + {boolean::Boolean, num::AllocatedNum}, +}; +use ff::PrimeField; +use itertools::Itertools; +use log::warn; +use nova_snark::traits::{circuit_supernova::StepCircuit, PrimeFieldExt}; +use number::FieldElement; + +use crate::{ + nonnative::{bignat::BigNat, util::Num}, + utils::{ + add_allocated_num, alloc_const, alloc_num_equals, alloc_one, conditionally_select, + evaluate_expr, find_pc_expression, get_num_at_index, signed_limb_to_neg, WitnessGen, + }, + FREE_INPUT_DUMMY_REG, FREE_INPUT_INSTR_NAME, FREE_INPUT_TY, LIMB_WIDTH, +}; + +/// this NovaStepCircuit can compile single instruction in PIL into R1CS constraints +#[derive(Clone, Debug)] +pub struct NovaStepCircuit<'a, F: PrimeField, T: FieldElement> { + _p: PhantomData, + augmented_circuit_index: usize, + pi_len: usize, + rom_len: usize, + identity_name: String, + io_params: &'a (Vec, Vec), // input,output index + analyzed: &'a Analyzed, + num_registers: usize, + witgen: Arc>>, +} + +impl<'a, F, T> NovaStepCircuit<'a, F, T> +where + F: PrimeField, + T: FieldElement, +{ + /// new + pub fn new( + pi_len: usize, + rom_len: usize, + augmented_circuit_index: usize, + identity_name: String, + analyzed: &'a Analyzed, + io_params: &'a (Vec, Vec), + num_registers: usize, + witgen: Arc>>, + ) -> Self { + NovaStepCircuit { + rom_len, + pi_len, + augmented_circuit_index, + identity_name, + analyzed, + io_params, + num_registers, + witgen, + _p: PhantomData, + } + } +} + +impl<'a, F, T> StepCircuit for NovaStepCircuit<'a, F, T> +where + F: PrimeFieldExt, + T: FieldElement, +{ + fn arity(&self) -> usize { + self.num_registers + self.pi_len + self.rom_len + } + + fn synthesize>( + &self, + cs: &mut CS, + _pc_counter: &AllocatedNum, + z: &[AllocatedNum], + ) -> Result<(AllocatedNum, Vec>), SynthesisError> { + // mapping to AllocatedNum + let mut poly_map = BTreeMap::new(); + + // process pc + poly_map.insert("pc".to_string(), _pc_counter.clone()); + + // process constants and build map for its reference + self.analyzed.constants.iter().try_for_each(|(k, v)| { + let mut v_le = v.to_bytes_le(); + v_le.resize(64, 0); + let v = alloc_const( + cs.namespace(|| format!("const {:?}", v)), + F::from_uniform(&v_le[..]), + )?; + poly_map.insert(k.clone(), v); + Ok::<(), SynthesisError>(()) + })?; + // add constant 1 + poly_map.insert("ONE".to_string(), alloc_one(cs.namespace(|| "constant 1"))?); + + // add constant 2^(LIMB_WIDTH + 1) + let mut max_limb_plus_one = [0u8; 64]; + max_limb_plus_one[LIMB_WIDTH / 8] = 1u8; + let max_limb_plus_one = F::from_uniform(&max_limb_plus_one[..]); + poly_map.insert( + "1 <<(LIMB_WIDTH + 1)".to_string(), + alloc_const( + cs.namespace(|| "constant 1 <<(LIMB_WIDTH + 1)"), + max_limb_plus_one, + )?, + ); + + // parse inst part to construct step circuit + // decompose ROM[pc] into linear combination lc(opcode_index, operand_index1, operand_index2, ... operand_output) + // Noted that here only support single output + // register value can be constrait via `multiple select + sum` with register index on nova `zi` state + // output register need to be constraints in the last of synthesize + + // NOTES: things that do not support + // 1. next query. only support pc as next, other value are all query on same rotation + // ... + + let rom_value = get_num_at_index( + cs.namespace(|| "rom value"), + _pc_counter, + &z[self.pi_len + self.num_registers..], + )?; + let (input_params, output_params) = self.io_params; + // ------------- + let rom_value_bignat = BigNat::from_num( + cs.namespace(|| "rom value bignat"), + &Num::from(rom_value), + LIMB_WIDTH, + 1 + input_params.len() + output_params.len(), // 1 is opcode_index + )?; + let input_output_params_allocnum = rom_value_bignat + .as_limbs() + .iter() + .zip_eq( + iter::once::>(None) + .chain(input_params.iter().map(Some)) + .chain(output_params.iter().map(Some)), + ) + .enumerate() + .map(|(limb_index, (limb, param))| { + match param { + // signed handling + Some(Param { + ty: Some(type_str), .. + }) if type_str == "signed" => signed_limb_to_neg( + cs.namespace(|| format!("limb index {}", limb_index)), + limb, + &poly_map["1 <<(LIMB_WIDTH + 1)"], + LIMB_WIDTH, + ), + _ => limb.as_allocated_num( + cs.namespace(|| format!("rom decompose index {}", limb_index)), + ), + } + }) + .collect::>, SynthesisError>>()?; + + let opcode_index_from_rom = &input_output_params_allocnum[0]; // opcode_index in 0th element + let mut offset = 1; + let input_params_allocnum = // fetch range belongs to input params + &input_output_params_allocnum[offset..offset + input_params.len()]; + offset += input_params.len(); + let output_params_allocnum = // fetch range belongs to output params + &input_output_params_allocnum[offset..offset + output_params.len()]; + + let expected_opcode_index = alloc_const( + cs.namespace(|| "expected_opcode_index"), + F::from(self.augmented_circuit_index as u64), + )?; + cs.enforce( + || "opcode equality", + |lc| lc + opcode_index_from_rom.get_variable(), + |lc| lc + CS::one(), + |lc| lc + expected_opcode_index.get_variable(), + ); + + // process identites + // all the shared constraints will be compiled into each augmented circuit + // TODO optimized to compile only nessesary shared constraints + for (index, id) in self.analyzed.identities.iter().enumerate() { + match id.kind { + IdentityKind::Polynomial => { + // everthing should be in left.selector only + assert_eq!(id.right.expressions.len(), 0); + assert_eq!(id.right.selector, None); + assert_eq!(id.left.expressions.len(), 0); + + let exp = id.expression_for_poly_id(); + + // identities process as below + // case 1: ((main.instr_XXXX * ) - 0) = 0 => `BinaryOperation(BinaryOperation... ` + // case 1.1: (() - 0) = 0 => `BinaryOperation(BinaryOperation... ` + // case 2: PolynomialReference - () = 0 => `BinaryOperation(PolynomialReference...` + match exp { + // case 1/1.1 + Expression::BinaryOperation( + box Expression::BinaryOperation( + box Expression::PolynomialReference( + PolynomialReference { name, next, .. }, + .., + ), + _, + box rhs, + ), + .., + ) => { + // skip next + if *next { + unimplemented!("not support column next in folding scheme") + } + // skip first_step, for it's not being used in folding scheme + if name == "main.first_step" { + continue; + } + // only skip if constraint is dedicated to other instruction + if name.starts_with("main.instr") + && !name.ends_with(&self.identity_name[..]) + { + continue; + } + + let contains_next_ref = exp.contains_next_ref(); + if contains_next_ref { + unimplemented!("not support column next in folding scheme") + } + let num_eval = if name.starts_with("main.instr") + && name.ends_with(&self.identity_name[..]) + { + evaluate_expr( + cs.namespace(|| format!("id index {} rhs {}", index, rhs)), + &mut poly_map, + rhs, + self.witgen.clone(), + )? + } else { + evaluate_expr( + cs.namespace(|| format!("id index {} exp {}", index, exp)), + &mut poly_map, + exp, + self.witgen.clone(), + )? + }; + cs.enforce( + || format!("id index {} constraint {} = 0", index, name), + |lc| lc + num_eval.get_variable(), + |lc| lc + CS::one(), + |lc| lc, + ); + } + // case 2 + Expression::BinaryOperation( + box Expression::PolynomialReference( + PolynomialReference { name, next, .. }, + .., + ), + _, + _, + ) => { + if *next { + continue; + } + let num_eval = evaluate_expr( + cs.namespace(|| format!("id index {} exp {}", index, exp)), + &mut poly_map, + exp, + self.witgen.clone(), + )?; + cs.enforce( + || format!("id index {} constraint {} = 0", index, name), + |lc| lc + num_eval.get_variable(), + |lc| lc + CS::one(), + |lc| lc, + ); + } + _ => unimplemented!("exp {:?} not support", exp), + } + } + _ => (), + } + } + + // special handling for free input + if self.identity_name == FREE_INPUT_INSTR_NAME { + assert_eq!( + input_params_allocnum.len(), + 1, + "free input must have exactly one input param" + ); + let public_input_index = &input_params_allocnum[0]; + // alloc a intermediate dummy intermediate witness + // here just to retrieve it's value + // it wont be under-constraints, as intermediate witness be constraints in below input/output parts + poly_map.insert( + FREE_INPUT_DUMMY_REG.to_string(), + AllocatedNum::alloc(cs.namespace(|| "dummy var"), || { + public_input_index + .get_value() + .and_then(|index| { + let repr = &index.to_repr(); + let index = u32::from_le_bytes(repr.as_ref()[0..4].try_into().unwrap()); + assert!(index < self.pi_len as u32); + z[self.num_registers..][index as usize].get_value() + }) + .ok_or(SynthesisError::AssignmentMissing) + })?, + ); + } + + // process pc_next + // TODO: very inefficient to go through all identities for each folding, need optimisation + for id in &self.analyzed.identities { + match id.kind { + IdentityKind::Polynomial => { + // everthing should be in left.selector only + assert_eq!(id.right.expressions.len(), 0); + assert_eq!(id.right.selector, None); + assert_eq!(id.left.expressions.len(), 0); + + let exp = id.expression_for_poly_id(); + + if let Expression::BinaryOperation( + box Expression::PolynomialReference( + PolynomialReference { name, next, .. }, + .., + ), + BinaryOperator::Sub, + box rhs, + .., + ) = exp + { + // lhs is `pc'` + if name == "main.pc" && *next { + let identity_name = format!("main.instr_{}", self.identity_name); + let exp = find_pc_expression::(rhs, &identity_name); + let pc_next = exp + .and_then(|expr| { + evaluate_expr( + // evaluate rhs pc bumping logic + cs.namespace(|| format!("pc eval on {}", expr)), + &mut poly_map, + &expr, + self.witgen.clone(), + ) + .ok() + }) + .unwrap_or_else(|| { + // by default pc + 1 + add_allocated_num( + cs.namespace(|| format!("instr {} pc + 1", identity_name)), + &poly_map["pc"], + &poly_map["ONE"], + ) + .unwrap() + }); + poly_map.insert("pc_next".to_string(), pc_next); + } + } + } + _ => (), + } + } + + // constraint input param to assigned reg value + input_params_allocnum + .iter() + .zip_eq(input_params.iter()) + .try_for_each(|(index, params)| { + let (name, value) = match params { + // register + Param { name, ty: None } => ( + name.clone(), + get_num_at_index( + cs.namespace(|| format!("regname {}", name)), + index, + z, + )?, + ), + // constant + Param { name, ty: Some(ty) } if ty == "signed" || ty == "unsigned" => { + ( + format!("instr_{}_param_{}", self.identity_name, name), + index.clone(), + ) + }, + // public io + Param { name, ty: Some(ty) } if ty == FREE_INPUT_TY => { + ( + name.clone(), + get_num_at_index( + cs.namespace(|| format!("regname {}", name)), + index, + &z[self.num_registers..], + )?, + ) + }, + // label + Param { name, ty: Some(ty) } if ty == "label" => { + ( + format!("instr_{}_param_{}", self.identity_name, name), + index.clone(), + ) + }, + s => { + unimplemented!("not support {}", s) + }, + }; + if let Some(reg) = poly_map.get(&name) { + cs.enforce( + || format!("input params {} - reg[{}_index] = 0", params, params), + |lc| lc + value.get_variable(), + |lc| lc + CS::one(), + |lc| lc + reg.get_variable(), + ); + Ok::<(), SynthesisError>(()) + } else { + warn!( + "missing input reg name {} in polymap with key {:?}, Probably instr {} defined but never used", + params, + poly_map.keys(), + self.identity_name, + ); + Ok::<(), SynthesisError>(()) + } + })?; + + // constraint zi_next[index] = (index == output_index) ? reg_assigned[output_reg_name]: zi[index] + let zi_next = output_params_allocnum.iter().zip_eq(output_params.iter()).try_fold( + z.to_vec(), + |acc, (output_index, param)| { + assert!(param.ty.is_none()); // output only accept register + (0..z.len()) + .map(|i| { + let i_alloc = alloc_const( + cs.namespace(|| format!("output reg i{} allocated", i)), + F::from(i as u64), + )?; + let equal_bit = Boolean::from(alloc_num_equals( + cs.namespace(|| format!("check reg {} equal bit", i)), + &i_alloc, + output_index, + )?); + if let Some(output) = poly_map.get(¶m.name) { + conditionally_select( + cs.namespace(|| { + format!( + "zi_next constraint with register index {} on reg name {}", + i, param + ) + }), + output, + &acc[i], + &equal_bit, + ) + } else { + warn!( + "missing output reg name {} in polymap with key {:?}, Probably instr {} defined but never used", + param, + poly_map.keys(), + self.identity_name, + ); + Ok(acc[i].clone()) + } + }) + .collect::>, SynthesisError>>() + }, + )?; + + Ok((poly_map["pc_next"].clone(), zi_next)) + } +} + +/// A trivial step circuit that simply returns the input +#[derive(Clone, Debug, Default)] +pub struct SecondaryStepCircuit { + _p: PhantomData, + arity_len: usize, +} + +impl SecondaryStepCircuit +where + F: PrimeField, +{ + /// new TrivialTestCircuit + pub fn new(arity_len: usize) -> Self { + SecondaryStepCircuit { + arity_len, + _p: PhantomData, + } + } +} + +impl StepCircuit for SecondaryStepCircuit +where + F: PrimeField, +{ + fn arity(&self) -> usize { + self.arity_len + } + + fn synthesize>( + &self, + _cs: &mut CS, + _pc_counter: &AllocatedNum, + z: &[AllocatedNum], + ) -> Result<(AllocatedNum, Vec>), SynthesisError> { + Ok((_pc_counter.clone(), z.to_vec())) + } +} diff --git a/nova/src/circuit_builder.rs b/nova/src/circuit_builder.rs new file mode 100644 index 0000000000..3c663314c7 --- /dev/null +++ b/nova/src/circuit_builder.rs @@ -0,0 +1,460 @@ +use core::panic; +use std::{ + collections::BTreeMap, + iter, + sync::{Arc, Mutex}, +}; + +use ast::{ + analyzed::Analyzed, + asm_analysis::{ + AssignmentStatement, FunctionStatement, InstructionStatement, LabelStatement, Machine, + }, + parsed::{ + asm::{FunctionCall, Param}, + Expression::Number, + UnaryOperator, + }, +}; +use ff::PrimeField; +use nova_snark::provider::bn256_grumpkin::{self}; +use nova_snark::{ + compute_digest, + supernova::{gen_commitmentkey_by_r1cs, PublicParams, RecursiveSNARK, RunningClaim}, + traits::Group, +}; +use num_bigint::BigUint; +use number::{BigInt, FieldElement}; + +use crate::{ + circuit::{NovaStepCircuit, SecondaryStepCircuit}, + nonnative::bignat::limbs_to_nat, + utils::WitnessGen, + FREE_INPUT_DUMMY_REG, FREE_INPUT_INSTR_NAME, FREE_INPUT_TY, LIMB_WIDTH, +}; + +// TODO support other cycling curve +type G1 = bn256_grumpkin::bn256::Point; +type G2 = bn256_grumpkin::grumpkin::Point; + +pub(crate) fn nova_prove( + analyzed: &Analyzed, + main_machine: &Machine, + _: Vec<(&str, Vec)>, + witness: Vec<(&str, Vec)>, + public_io: Vec, +) { + if polyexen::expr::get_field_p::<::Scalar>() != T::modulus().to_arbitrary_integer() + { + panic!("powdr modulus doesn't match nova modulus. Make sure you are using Bn254"); + } + + let public_io_len = public_io.len(); + + // TODO to avoid clone witness object, witness are wrapped by lock so it can be shared below + // need to refactor this part to support parallel folding + // redesign this part if possible + // println!("fixed {:?}", fixed); + // println!("witness {:?}", witness); + let witness = Arc::new(Mutex::new(WitnessGen::new( + witness + .iter() + .map(|(k, v)| (k.strip_prefix("main.").unwrap(), v)) + .collect::)>>(), + ))); + + // collect all instr from pil file + let mut instr_index_mapping: BTreeMap = main_machine + .instructions + .iter() + .flat_map(|k| Some(k.name.clone())) + .enumerate() + .map(|(i, v)| (v.clone(), i)) + .collect(); + + instr_index_mapping.insert(FREE_INPUT_INSTR_NAME.to_string(), instr_index_mapping.len()); // index start from 0 + + // collect all register + let regs_index_mapping: BTreeMap = main_machine + .registers + .iter() + .filter(|reg| { + reg.flag.is_none() // only collect writable register + }) + .enumerate() + .map(|(i, reg)| (reg.name.clone(), i)) + .collect(); + // instruction <-> input/output params mapping + let mut instr_io_mapping = main_machine + .instructions + .iter() + .map(|k| { + let input_index: Vec = k.params.inputs.params.clone(); + let output_index: Vec = if let Some(output) = &k.params.outputs { + output.params.iter().for_each(|output_param| { + assert!(output_param.ty.is_none()); // do not support signature other than register type + }); + output.params.clone() + } else { + vec![] + }; + (k.name.clone(), (input_index, output_index)) + }) + .collect::, Vec)>>(); + + // NOTE: create and attach free_input "fake" meta + instr_io_mapping.insert( + FREE_INPUT_INSTR_NAME.to_string(), + ( + vec![Param { + name: FREE_INPUT_DUMMY_REG.to_string(), + ty: Some(FREE_INPUT_TY.to_string()), + }], + vec![Param { + name: FREE_INPUT_DUMMY_REG.to_string(), + ty: None, + }], + ), + ); + + // firstly, compile pil ROM to simple memory commitment + // Idea to represent a instruction is by linear combination lc(instruction,input params.., output params..) + // params can be register or constant. For register, first we translate to register index + // decomposed(inst_encoded, 1 << LIMB_WIDTH) = [, operand1, operand2, operand3...] + // each operand should be fit into 1 << LIMB_WIDTH + // TODO1: move this part to setup stage. + // TODO2: replace this part with more efficient memory commitment strategy, e.g. folding KZG + + let mut pc = 0u64; + + // To support jumping to label behind, we will process rom first to collect all available + // TODO optimise to have just one pass. + let mut label_pc_mapping = main_machine + .rom + .as_ref() + .map(|rom| { + rom.statements + .iter() + .flat_map(|statement| match statement { + FunctionStatement::Label(LabelStatement { name, .. }) => Some((name, pc)), + FunctionStatement::Assignment(..) | FunctionStatement::Instruction(..) => { + pc += 1; + None + } + s => unimplemented!("unimplemented statement {:?}", s), + }) + .collect::>() + }) + .unwrap_or_default(); + + let rom = main_machine.rom.as_ref().map(|rom| { + rom.statements.iter().flat_map(|statement| { + + match statement { + FunctionStatement::Label(LabelStatement{ name, .. }) => { + label_pc_mapping.insert(name, pc); + None + }, + FunctionStatement::Assignment(AssignmentStatement { + lhs, + rhs, + .. // ignore start + }) => { + pc += 1; + let instr_name: String = match rhs { + box ast::parsed::Expression::FunctionCall(FunctionCall{id, ..}) => { + assert!(id != FREE_INPUT_INSTR_NAME, "{} is a reserved instruction name", FREE_INPUT_INSTR_NAME); + id.to_string() + }, + box ast::parsed::Expression::FreeInput(box ast::parsed::Expression::Tuple(vector)) => match &vector[..] { + [ast::parsed::Expression::String(name), ast::parsed::Expression::Number(_)] => { + assert!(name == "input"); + FREE_INPUT_INSTR_NAME.to_string() + }, + _ => unimplemented!() + }, + s => unimplemented!("{:?}", s), + }; + + let mut io_params: Vec<::Scalar> = match rhs { + box ast::parsed::Expression::FunctionCall(FunctionCall{arguments, ..}) => arguments.iter().map(|argument| match argument { + ast::parsed::Expression::PolynomialReference(ast::parsed::PolynomialReference{ namespace, name, index, next }) => { + assert!(!*next); + assert_eq!(*namespace, None); + assert_eq!(*index, None); + // label or register + if let Some(label_pc) = label_pc_mapping.get(name) { + ::Scalar::from(*label_pc) + } else { + ::Scalar::from(regs_index_mapping[name] as u64) + } + }, + Number(n) => { + ::Scalar::from_bytes(&n.to_bytes_le().try_into().unwrap()).unwrap() + }, + ast::parsed::Expression::UnaryOperation(ope,box Number(n)) => { + let value = ::Scalar::from_bytes(&n.to_bytes_le().try_into().unwrap()).unwrap(); + match ope { + UnaryOperator::Plus => value, + UnaryOperator::Minus => get_neg_value_within_limbsize(value, LIMB_WIDTH), + } + }, + x => unimplemented!("unsupported expression {}", x), + }).collect(), + box ast::parsed::Expression::FreeInput(box ast::parsed::Expression::Tuple(vector)) => match &vector[..] { + [ast::parsed::Expression::String(name), ast::parsed::Expression::Number(n)] => { + assert!(name == "input"); + assert!(n <= &T::from(public_io_len as u64)); + vec![::Scalar::from_bytes(&n.to_bytes_le().try_into().unwrap()).unwrap()] + }, + _ => unimplemented!() + }, + _ => unimplemented!(), + }; + + let output_params:Vec<::Scalar> = lhs.iter().map(|x| ::Scalar::from(regs_index_mapping[x] as u64)).collect(); + io_params.extend(output_params); // append output register to the back of input register + + // Now we can do linear combination + let rom_value = if let Some(instr_index) = instr_index_mapping.get(&instr_name) { + limbs_to_nat(iter::once(::Scalar::from(*instr_index as u64)).chain(io_params), LIMB_WIDTH).to_biguint().unwrap() + } else { + panic!("instr_name {:?} not found in instr_index_mapping {:?}", instr_name, instr_index_mapping); + }; + Some(rom_value) + } + FunctionStatement::Instruction(InstructionStatement{ instruction, inputs, ..}) => { + pc += 1; + let io_params: Vec<::Scalar> = inputs.iter().map(|argument| { + match argument { + ast::parsed::Expression::PolynomialReference(ast::parsed::PolynomialReference{ namespace, name, index, next }) => { + assert!(!*next); + assert_eq!(*namespace, None); + assert_eq!(*index, None); + // label or register + if let Some(label_pc) = label_pc_mapping.get(name) { + ::Scalar::from(*label_pc) + } else { + ::Scalar::from(regs_index_mapping[name] as u64) + } + }, + Number(n) => ::Scalar::from_bytes(&n.to_bytes_le().try_into().unwrap()).unwrap(), + ast::parsed::Expression::UnaryOperation(ope,box Number(n)) => { + let value = ::Scalar::from_bytes(&n.to_bytes_le().try_into().unwrap()).unwrap(); + match ope { + UnaryOperator::Plus => value, + UnaryOperator::Minus => get_neg_value_within_limbsize(value, LIMB_WIDTH), + } + }, + x => unimplemented!("unsupported expression {:?}", x), + }}).collect(); + + // Now we can do linear combination + let rom_value = if let Some(instr_index) = instr_index_mapping.get(instruction) { + limbs_to_nat(iter::once(::Scalar::from(*instr_index as u64)).chain(io_params), LIMB_WIDTH).to_biguint().unwrap() + } else { + panic!("instr_name {} not found in instr_index_mapping {:?}", instruction, instr_index_mapping); + }; + Some(rom_value) + } + s => unimplemented!("unimplemented statement {:?}", s), + }}).collect::>() + }).unwrap(); + + // rom.iter().for_each(|v| println!("rom value {:#32x}", v)); + + // build step circuit + // 2 cycles curve, secondary circuit responsible of folding first circuit running instances with new r1cs instance, no application logic + let circuit_secondary = + SecondaryStepCircuit::new(regs_index_mapping.len() + public_io_len + rom.len()); + let num_augmented_circuit = instr_index_mapping.len(); + + // allocated running claims is the list of running instances witness. Number match #instruction + let mut running_claims: Vec< + RunningClaim< + G1, + G2, + NovaStepCircuit<::Scalar, T>, + SecondaryStepCircuit<::Scalar>, + >, + > = instr_index_mapping + .iter() + .map(|(instr_name, index)| { + // Structuring running claims + let test_circuit = NovaStepCircuit::<::Scalar, T>::new( + public_io_len, + rom.len(), + *index, + instr_name.to_string(), + analyzed, + &instr_io_mapping[instr_name], + regs_index_mapping.len(), + witness.clone(), + ); + let running_claim = RunningClaim::< + G1, + G2, + NovaStepCircuit<::Scalar, T>, + SecondaryStepCircuit<::Scalar>, + >::new( + *index, + test_circuit, + circuit_secondary.clone(), + num_augmented_circuit, + ); + running_claim + }) + .collect(); + + // sort running claim by augmented_index, assure we can fetch them by augmented index later + running_claims.sort_by(|a, b| { + a.get_augmented_circuit_index() + .cmp(&b.get_augmented_circuit_index()) + }); + + // TODO detect the max circuit iterate and inspect all running claim R1CS shape, instead of assume first 1 is largest + // generate the commitkey based on max num of constraints and reused it for all other augmented circuit + // let (max_index_circuit, _) = running_claims + // .iter() + // .enumerate() + // .map(|(i, running_claim)| -> (usize, usize) { + // let (r1cs_shape_primary, _) = running_claim.get_r1cs_shape(); + // (i, r1cs_shape_primary.) + // }) + // .max_by(|(_, circuit_size1), (_, circuit_size2)| circuit_size1.cmp(circuit_size2)) + // .unwrap(); + + let ck_primary = gen_commitmentkey_by_r1cs(running_claims[0].get_r1cs_shape().0); + let ck_secondary = gen_commitmentkey_by_r1cs(running_claims[0].get_r1cs_shape().1); + + // set unified ck_primary, ck_secondary and update digest for all running claim + running_claims.iter_mut().for_each(|running_claim| { + running_claim.set_commitmentkey(ck_primary.clone(), ck_secondary.clone()) + }); + + let digest = + compute_digest::>(&[running_claims[0].get_publicparams()]); + let initial_program_counter = ::Scalar::from(0); + + // process register + let mut z0_primary: Vec<::Scalar> = iter::repeat(::Scalar::zero()) + .take(regs_index_mapping.len()) + .collect(); + z0_primary.extend(public_io.iter().map(|value| { + let mut value = value.to_bytes_le(); + value.resize(32, 0); + ::Scalar::from_bytes(&value.try_into().unwrap()).unwrap() + })); + // extend z0_primary/secondary with rom content + z0_primary.extend(rom.iter().map(|memory_value| { + let mut memory_value_bytes = memory_value.to_bytes_le(); + memory_value_bytes.resize(32, 0); + ::Scalar::from_bytes(&memory_value_bytes.try_into().unwrap()).unwrap() + })); + // secondary circuit just fill 0 on z0 + let z0_secondary: Vec<::Scalar> = iter::repeat(::Scalar::zero()) + .take(z0_primary.len() as usize) + .collect(); + + let mut recursive_snark_option: Option> = None; + + // Have estimate of iteration via length of witness + let num_steps = witness.lock().unwrap().num_of_iteration(); + for i in 0..num_steps { + println!("round i {}, total step {}", i, num_steps - 1); + if i > 0 { + // iterate through next witness row + witness.lock().unwrap().next(); + } + let program_counter = recursive_snark_option + .as_ref() + .map(|recursive_snark| recursive_snark.get_program_counter()) + .unwrap_or_else(|| initial_program_counter); + // decompose value_in_rom into (opcode_index, input_operand1, input_operand2, ... output_operand) as witness, + // then using opcode_index to invoke respective running instance + let value_in_rom = &rom[u32::from_le_bytes( + // convert program counter from field to usize (only took le 4 bytes) + program_counter.to_repr().as_ref()[0..4].try_into().unwrap(), + ) as usize]; + let mut opcode_index_bytes = + (value_in_rom % BigUint::from(1_u32 << LIMB_WIDTH)).to_bytes_le(); + opcode_index_bytes.resize(4, 0); + let opcode_index = u32::from_le_bytes(opcode_index_bytes.try_into().unwrap()); + + let mut recursive_snark = recursive_snark_option.unwrap_or_else(|| { + RecursiveSNARK::iter_base_step( + &running_claims[opcode_index as usize], + digest, + program_counter, + opcode_index as usize, + num_augmented_circuit, + &z0_primary, + &z0_secondary, + ) + .unwrap() + }); + + let res = recursive_snark.prove_step( + &running_claims[opcode_index as usize], + &z0_primary, + &z0_secondary, + ); + if let Err(e) = &res { + println!("res failed {:?}", e); + } + assert!(res.is_ok()); + let res = recursive_snark.verify( + &running_claims[opcode_index as usize], + &z0_primary, + &z0_secondary, + ); + if let Err(e) = &res { + println!("res failed {:?}", e); + } + assert!(res.is_ok()); + recursive_snark_option = Some(recursive_snark); + } + + assert!(recursive_snark_option.is_some()); + + // Now you can handle the Result using if let + // let RecursiveSNARK { + // zi_primary, + // zi_secondary, + // program_counter, + // .. + // } = &recursive_snark_option.unwrap(); + + // println!("zi_primary: {:?}", zi_primary); + // println!("zi_secondary: {:?}", zi_secondary); + // println!("final program_counter: {:?}", program_counter); +} + +/// get additive negative of value within limbsize +fn get_neg_value_within_limbsize( + value: ::Scalar, + nbit: usize, +) -> ::Scalar { + let value = value.to_bytes(); + let (lsb, msb) = value.split_at(nbit / 8); + assert_eq!( + msb.iter().map(|v| *v as usize).sum::(), + 0, + "value {:?} is overflow", + value + ); + let mut lsb = lsb.to_vec(); + lsb.resize(32, 0); + let value = ::Scalar::from_bytes(lsb[..].try_into().unwrap()).unwrap(); + + let mut max_limb_plus_one_bytes = vec![0u8; nbit / 8 + 1]; + max_limb_plus_one_bytes[nbit / 8] = 1u8; + max_limb_plus_one_bytes.resize(32, 0); + let max_limb_plus_one = + ::Scalar::from_bytes(max_limb_plus_one_bytes[..].try_into().unwrap()).unwrap(); + + let mut value_neg = (max_limb_plus_one - value).to_bytes()[0..nbit / 8].to_vec(); + value_neg.resize(32, 0); + + ::Scalar::from_bytes(&value_neg[..].try_into().unwrap()).unwrap() +} diff --git a/nova/src/lib.rs b/nova/src/lib.rs new file mode 100644 index 0000000000..77b4c1262a --- /dev/null +++ b/nova/src/lib.rs @@ -0,0 +1,15 @@ +#![feature(box_patterns)] +#![allow(dead_code)] +pub(crate) mod circuit; +pub(crate) mod circuit_builder; +pub(crate) mod nonnative; +pub(crate) mod prover; +pub(crate) mod utils; + +pub use prover::*; + +/// LIMB WITHD as base, represent 2^16 +pub(crate) const LIMB_WIDTH: usize = 16; +pub(crate) const FREE_INPUT_INSTR_NAME: &str = "free_input_instr"; +pub(crate) const FREE_INPUT_TY: &str = "free_input_ty"; +pub(crate) const FREE_INPUT_DUMMY_REG: &str = "free_input_dummy_reg"; diff --git a/nova/src/nonnative/bignat.rs b/nova/src/nonnative/bignat.rs new file mode 100644 index 0000000000..de87081ce1 --- /dev/null +++ b/nova/src/nonnative/bignat.rs @@ -0,0 +1,905 @@ +use super::{ + util::{ + Bitvector, Num, {f_to_nat, nat_to_f}, + }, + OptionExt, +}; +use bellpepper_core::{ConstraintSystem, LinearCombination, SynthesisError}; +use ff::PrimeField; +use num_bigint::BigInt; +use num_traits::cast::ToPrimitive; +use std::borrow::Borrow; +use std::cmp::{max, min}; +use std::convert::From; + +/// Compute the natural number represented by an array of limbs. +/// The limbs are assumed to be based the `limb_width` power of 2. +pub fn limbs_to_nat, I: DoubleEndedIterator>( + limbs: I, + limb_width: usize, +) -> BigInt { + limbs.rev().fold(BigInt::from(0), |mut acc, limb| { + acc <<= limb_width as u32; + acc += f_to_nat(limb.borrow()); + acc + }) +} + +fn int_with_n_ones(n: usize) -> BigInt { + let mut m = BigInt::from(1); + m <<= n as u32; + m -= 1; + m +} + +/// Compute the limbs encoding a natural number. +/// The limbs are assumed to be based the `limb_width` power of 2. +pub fn nat_to_limbs( + nat: &BigInt, + limb_width: usize, + n_limbs: usize, +) -> Result, SynthesisError> { + let mask = int_with_n_ones(limb_width); + let mut nat = nat.clone(); + if nat.bits() as usize <= n_limbs * limb_width { + Ok((0..n_limbs) + .map(|_| { + let r = &nat & &mask; + nat >>= limb_width as u32; + nat_to_f(&r).unwrap() + }) + .collect()) + } else { + eprintln!("nat {nat} does not fit in {n_limbs} limbs of width {limb_width}"); + Err(SynthesisError::Unsatisfiable) + } +} + +#[derive(Clone, PartialEq, Eq)] +pub struct BigNatParams { + pub min_bits: usize, + pub max_word: BigInt, + pub limb_width: usize, + pub n_limbs: usize, +} + +impl BigNatParams { + pub fn new(limb_width: usize, n_limbs: usize) -> Self { + let mut max_word = BigInt::from(1) << limb_width as u32; + max_word -= 1; + BigNatParams { + max_word, + n_limbs, + limb_width, + min_bits: 0, + } + } +} + +/// A representation of a large natural number (a member of {0, 1, 2, ... }) +#[derive(Clone)] +pub struct BigNat { + /// The linear combinations which constrain the value of each limb of the number + pub limbs: Vec>, + /// The witness values for each limb (filled at witness-time) + pub limb_values: Option>, + /// The value of the whole number (filled at witness-time) + pub value: Option, + /// Parameters + pub params: BigNatParams, +} + +impl std::cmp::PartialEq for BigNat { + fn eq(&self, other: &Self) -> bool { + self.value == other.value && self.params == other.params + } +} +impl std::cmp::Eq for BigNat {} + +impl From> for Polynomial { + fn from(other: BigNat) -> Polynomial { + Polynomial { + coefficients: other.limbs, + values: other.limb_values, + } + } +} + +impl BigNat { + /// Allocates a `BigNat` in the circuit with `n_limbs` limbs of width `limb_width` each. + /// If `max_word` is missing, then it is assumed to be `(2 << limb_width) - 1`. + /// The value is provided by a closure returning limb values. + pub fn alloc_from_limbs( + mut cs: CS, + f: F, + max_word: Option, + limb_width: usize, + n_limbs: usize, + ) -> Result + where + CS: ConstraintSystem, + F: FnOnce() -> Result, SynthesisError>, + { + let values_cell = f(); + let mut value = None; + let mut limb_values = None; + let limbs = (0..n_limbs) + .map(|limb_i| { + cs.alloc( + || format!("limb {limb_i}"), + || match values_cell { + Ok(ref vs) => { + if vs.len() != n_limbs { + eprintln!("Values do not match stated limb count"); + return Err(SynthesisError::Unsatisfiable); + } + if value.is_none() { + value = Some(limbs_to_nat::(vs.iter(), limb_width)); + } + if limb_values.is_none() { + limb_values = Some(vs.clone()); + } + Ok(vs[limb_i]) + } + // Hack b/c SynthesisError and io::Error don't implement Clone + Err(ref e) => Err(SynthesisError::from(std::io::Error::new( + std::io::ErrorKind::Other, + format!("{e}"), + ))), + }, + ) + .map(|v| LinearCombination::zero() + v) + }) + .collect::, _>>()?; + Ok(Self { + value, + limb_values, + limbs, + params: BigNatParams { + min_bits: 0, + n_limbs, + max_word: max_word.unwrap_or_else(|| int_with_n_ones(limb_width)), + limb_width, + }, + }) + } + + /// Allocates a `BigNat` in the circuit with `n_limbs` limbs of width `limb_width` each. + /// The `max_word` is gauranteed to be `(2 << limb_width) - 1`. + /// The value is provided by a closure returning a natural number. + pub fn alloc_from_nat( + mut cs: CS, + f: F, + limb_width: usize, + n_limbs: usize, + ) -> Result + where + CS: ConstraintSystem, + F: FnOnce() -> Result, + { + let all_values_cell = + f().and_then(|v| Ok((nat_to_limbs::(&v, limb_width, n_limbs)?, v))); + let mut value = None; + let mut limb_values = Vec::new(); + let limbs = (0..n_limbs) + .map(|limb_i| { + cs.alloc( + || format!("limb {limb_i}"), + || match all_values_cell { + Ok((ref vs, ref v)) => { + if value.is_none() { + value = Some(v.clone()); + } + limb_values.push(vs[limb_i]); + Ok(vs[limb_i]) + } + // Hack b/c SynthesisError and io::Error don't implement Clone + Err(ref e) => Err(SynthesisError::from(std::io::Error::new( + std::io::ErrorKind::Other, + format!("{e}"), + ))), + }, + ) + .map(|v| LinearCombination::zero() + v) + }) + .collect::, _>>()?; + Ok(Self { + value, + limb_values: if !limb_values.is_empty() { + Some(limb_values) + } else { + None + }, + limbs, + params: BigNatParams::new(limb_width, n_limbs), + }) + } + + /// Allocates a `BigNat` in the circuit with `n_limbs` limbs of width `limb_width` each. + /// The `max_word` is gauranteed to be `(2 << limb_width) - 1`. + /// The value is provided by an allocated number + pub fn from_num>( + mut cs: CS, + n: &Num, + limb_width: usize, + n_limbs: usize, + ) -> Result { + let bignat = Self::alloc_from_nat( + cs.namespace(|| "bignat"), + || { + Ok({ + n.value + .as_ref() + .map(|n| f_to_nat(n)) + .ok_or(SynthesisError::AssignmentMissing)? + }) + }, + limb_width, + n_limbs, + )?; + + // check if bignat equals n + // (1) decompose `bignat` into a bitvector `bv` + let bv = bignat.decompose(cs.namespace(|| "bv"))?; + // (2) recompose bits and check if it equals n + n.is_equal(cs.namespace(|| "n"), &bv)?; + + Ok(bignat) + } + + pub fn as_limbs(&self) -> Vec> { + let mut limbs = Vec::new(); + for (i, lc) in self.limbs.iter().enumerate() { + limbs.push(Num::new( + self.limb_values.as_ref().map(|vs| vs[i]), + lc.clone(), + )); + } + limbs + } + + pub fn assert_well_formed>( + &self, + mut cs: CS, + ) -> Result<(), SynthesisError> { + // swap the option and iterator + let limb_values_split = + (0..self.limbs.len()).map(|i| self.limb_values.as_ref().map(|vs| vs[i])); + for (i, (limb, limb_value)) in self.limbs.iter().zip(limb_values_split).enumerate() { + Num::new(limb_value, limb.clone()) + .fits_in_bits(cs.namespace(|| format!("{i}")), self.params.limb_width)?; + } + Ok(()) + } + + /// Break `self` up into a bit-vector. + pub fn decompose>( + &self, + mut cs: CS, + ) -> Result, SynthesisError> { + let limb_values_split = + (0..self.limbs.len()).map(|i| self.limb_values.as_ref().map(|vs| vs[i])); + let bitvectors: Vec> = self + .limbs + .iter() + .zip(limb_values_split) + .enumerate() + .map(|(i, (limb, limb_value))| { + Num::new(limb_value, limb.clone()).decompose( + cs.namespace(|| format!("subdecmop {i}")), + self.params.limb_width, + ) + }) + .collect::, _>>()?; + let mut bits = Vec::new(); + let mut values = Vec::new(); + let mut allocations = Vec::new(); + for bv in bitvectors { + bits.extend(bv.bits); + if let Some(vs) = bv.values { + values.extend(vs) + }; + allocations.extend(bv.allocations); + } + let values = if !values.is_empty() { + Some(values) + } else { + None + }; + Ok(Bitvector { + bits, + values, + allocations, + }) + } + + pub fn enforce_limb_width_agreement( + &self, + other: &Self, + location: &str, + ) -> Result { + if self.params.limb_width == other.params.limb_width { + Ok(self.params.limb_width) + } else { + eprintln!( + "Limb widths {}, {}, do not agree at {}", + self.params.limb_width, other.params.limb_width, location + ); + Err(SynthesisError::Unsatisfiable) + } + } + + pub fn from_poly(poly: Polynomial, limb_width: usize, max_word: BigInt) -> Self { + Self { + params: BigNatParams { + min_bits: 0, + max_word, + n_limbs: poly.coefficients.len(), + limb_width, + }, + limbs: poly.coefficients, + value: poly + .values + .as_ref() + .map(|limb_values| limbs_to_nat::(limb_values.iter(), limb_width)), + limb_values: poly.values, + } + } + + /// Constrain `self` to be equal to `other`, after carrying both. + pub fn equal_when_carried>( + &self, + mut cs: CS, + other: &Self, + ) -> Result<(), SynthesisError> { + self.enforce_limb_width_agreement(other, "equal_when_carried")?; + + // We'll propegate carries over the first `n` limbs. + let n = min(self.limbs.len(), other.limbs.len()); + let target_base = BigInt::from(1u8) << self.params.limb_width as u32; + let mut accumulated_extra = BigInt::from(0usize); + let max_word = std::cmp::max(&self.params.max_word, &other.params.max_word); + let carry_bits = + (((max_word.to_f64().unwrap() * 2.0).log2() - self.params.limb_width as f64).ceil() + + 0.1) as usize; + let mut carry_in = Num::new(Some(Scalar::ZERO), LinearCombination::zero()); + + for i in 0..n { + let carry = Num::alloc(cs.namespace(|| format!("carry value {i}")), || { + Ok(nat_to_f( + &((f_to_nat(&self.limb_values.grab()?[i]) + + f_to_nat(&carry_in.value.unwrap()) + + max_word + - f_to_nat(&other.limb_values.grab()?[i])) + / &target_base), + ) + .unwrap()) + })?; + accumulated_extra += max_word; + + cs.enforce( + || format!("carry {i}"), + |lc| lc, + |lc| lc, + |lc| { + lc + &carry_in.num + &self.limbs[i] - &other.limbs[i] + + (nat_to_f(max_word).unwrap(), CS::one()) + - (nat_to_f(&target_base).unwrap(), &carry.num) + - ( + nat_to_f(&(&accumulated_extra % &target_base)).unwrap(), + CS::one(), + ) + }, + ); + + accumulated_extra /= &target_base; + + if i < n - 1 { + carry.fits_in_bits(cs.namespace(|| format!("carry {i} decomp")), carry_bits)?; + } else { + cs.enforce( + || format!("carry {i} is out"), + |lc| lc, + |lc| lc, + |lc| lc + &carry.num - (nat_to_f(&accumulated_extra).unwrap(), CS::one()), + ); + } + carry_in = carry; + } + + for (i, zero_limb) in self.limbs.iter().enumerate().skip(n) { + cs.enforce( + || format!("zero self {i}"), + |lc| lc, + |lc| lc, + |lc| lc + zero_limb, + ); + } + for (i, zero_limb) in other.limbs.iter().enumerate().skip(n) { + cs.enforce( + || format!("zero other {i}"), + |lc| lc, + |lc| lc, + |lc| lc + zero_limb, + ); + } + Ok(()) + } + + /// Constrain `self` to be equal to `other`, after carrying both. + /// Uses regrouping internally to take full advantage of the field size and reduce the amount + /// of carrying. + pub fn equal_when_carried_regroup>( + &self, + mut cs: CS, + other: &Self, + ) -> Result<(), SynthesisError> { + self.enforce_limb_width_agreement(other, "equal_when_carried_regroup")?; + let max_word = std::cmp::max(&self.params.max_word, &other.params.max_word); + let carry_bits = + (((max_word.to_f64().unwrap() * 2.0).log2() - self.params.limb_width as f64).ceil() + + 0.1) as usize; + let limbs_per_group = (Scalar::CAPACITY as usize - carry_bits) / self.params.limb_width; + let self_grouped = self.group_limbs(limbs_per_group); + let other_grouped = other.group_limbs(limbs_per_group); + self_grouped.equal_when_carried(cs.namespace(|| "grouped"), &other_grouped) + } + + pub fn add(&self, other: &Self) -> Result, SynthesisError> { + self.enforce_limb_width_agreement(other, "add")?; + let n_limbs = max(self.params.n_limbs, other.params.n_limbs); + let max_word = &self.params.max_word + &other.params.max_word; + let limbs: Vec> = (0..n_limbs) + .map(|i| match (self.limbs.get(i), other.limbs.get(i)) { + (Some(a), Some(b)) => a.clone() + b, + (Some(a), None) => a.clone(), + (None, Some(b)) => b.clone(), + (None, None) => unreachable!(), + }) + .collect(); + let limb_values: Option> = self.limb_values.as_ref().and_then(|x| { + other.limb_values.as_ref().map(|y| { + (0..n_limbs) + .map(|i| match (x.get(i), y.get(i)) { + (Some(a), Some(b)) => { + let mut t = *a; + t.add_assign(b); + t + } + (Some(a), None) => *a, + (None, Some(a)) => *a, + (None, None) => unreachable!(), + }) + .collect() + }) + }); + let value = self + .value + .as_ref() + .and_then(|x| other.value.as_ref().map(|y| x + y)); + Ok(Self { + limb_values, + value, + limbs, + params: BigNatParams { + min_bits: max(self.params.min_bits, other.params.min_bits), + n_limbs, + max_word, + limb_width: self.params.limb_width, + }, + }) + } + + /// Compute a `BigNat` contrained to be equal to `self * other % modulus`. + pub fn mult_mod>( + &self, + mut cs: CS, + other: &Self, + modulus: &Self, + ) -> Result<(BigNat, BigNat), SynthesisError> { + self.enforce_limb_width_agreement(other, "mult_mod")?; + let limb_width = self.params.limb_width; + let quotient_bits = + (self.n_bits() + other.n_bits()).saturating_sub(modulus.params.min_bits); + let quotient_limbs = quotient_bits.saturating_sub(1) / limb_width + 1; + let quotient = BigNat::alloc_from_nat( + cs.namespace(|| "quotient"), + || { + Ok({ + let mut x = self.value.grab()?.clone(); + x *= other.value.grab()?; + x /= modulus.value.grab()?; + x + }) + }, + self.params.limb_width, + quotient_limbs, + )?; + quotient.assert_well_formed(cs.namespace(|| "quotient rangecheck"))?; + let remainder = BigNat::alloc_from_nat( + cs.namespace(|| "remainder"), + || { + Ok({ + let mut x = self.value.grab()?.clone(); + x *= other.value.grab()?; + x %= modulus.value.grab()?; + x + }) + }, + self.params.limb_width, + modulus.limbs.len(), + )?; + remainder.assert_well_formed(cs.namespace(|| "remainder rangecheck"))?; + let a_poly = Polynomial::from(self.clone()); + let b_poly = Polynomial::from(other.clone()); + let mod_poly = Polynomial::from(modulus.clone()); + let q_poly = Polynomial::from(quotient.clone()); + let r_poly = Polynomial::from(remainder.clone()); + + // a * b + let left = a_poly.alloc_product(cs.namespace(|| "left"), &b_poly)?; + let right_product = q_poly.alloc_product(cs.namespace(|| "right_product"), &mod_poly)?; + // q * m + r + let right = right_product.sum(&r_poly); + + let left_max_word = { + let mut x = BigInt::from(min(self.limbs.len(), other.limbs.len())); + x *= &self.params.max_word; + x *= &other.params.max_word; + x + }; + let right_max_word = { + let mut x = BigInt::from(std::cmp::min(quotient.limbs.len(), modulus.limbs.len())); + x *= "ient.params.max_word; + x *= &modulus.params.max_word; + x += &remainder.params.max_word; + x + }; + + let left_int = BigNat::from_poly(left, limb_width, left_max_word); + let right_int = BigNat::from_poly(right, limb_width, right_max_word); + left_int.equal_when_carried_regroup(cs.namespace(|| "carry"), &right_int)?; + Ok((quotient, remainder)) + } + + /// Compute a `BigNat` contrained to be equal to `self * other % modulus`. + pub fn red_mod>( + &self, + mut cs: CS, + modulus: &Self, + ) -> Result, SynthesisError> { + self.enforce_limb_width_agreement(modulus, "red_mod")?; + let limb_width = self.params.limb_width; + let quotient_bits = self.n_bits().saturating_sub(modulus.params.min_bits); + let quotient_limbs = quotient_bits.saturating_sub(1) / limb_width + 1; + let quotient = BigNat::alloc_from_nat( + cs.namespace(|| "quotient"), + || Ok(self.value.grab()? / modulus.value.grab()?), + self.params.limb_width, + quotient_limbs, + )?; + quotient.assert_well_formed(cs.namespace(|| "quotient rangecheck"))?; + let remainder = BigNat::alloc_from_nat( + cs.namespace(|| "remainder"), + || Ok(self.value.grab()? % modulus.value.grab()?), + self.params.limb_width, + modulus.limbs.len(), + )?; + remainder.assert_well_formed(cs.namespace(|| "remainder rangecheck"))?; + let mod_poly = Polynomial::from(modulus.clone()); + let q_poly = Polynomial::from(quotient.clone()); + let r_poly = Polynomial::from(remainder.clone()); + + // q * m + r + let right_product = q_poly.alloc_product(cs.namespace(|| "right_product"), &mod_poly)?; + let right = right_product.sum(&r_poly); + + let right_max_word = { + let mut x = BigInt::from(std::cmp::min(quotient.limbs.len(), modulus.limbs.len())); + x *= "ient.params.max_word; + x *= &modulus.params.max_word; + x += &remainder.params.max_word; + x + }; + + let right_int = BigNat::from_poly(right, limb_width, right_max_word); + self.equal_when_carried_regroup(cs.namespace(|| "carry"), &right_int)?; + Ok(remainder) + } + + /// Combines limbs into groups. + pub fn group_limbs(&self, limbs_per_group: usize) -> BigNat { + let n_groups = (self.limbs.len() - 1) / limbs_per_group + 1; + let limb_values = self.limb_values.as_ref().map(|vs| { + let mut values: Vec = vec![Scalar::ZERO; n_groups]; + let mut shift = Scalar::ONE; + let limb_block = (0..self.params.limb_width).fold(Scalar::ONE, |mut l, _| { + l = l.double(); + l + }); + for (i, v) in vs.iter().enumerate() { + if i % limbs_per_group == 0 { + shift = Scalar::ONE; + } + let mut a = shift; + a *= v; + values[i / limbs_per_group].add_assign(&a); + shift.mul_assign(&limb_block); + } + values + }); + let limbs = { + let mut limbs: Vec> = + vec![LinearCombination::zero(); n_groups]; + let mut shift = Scalar::ONE; + let limb_block = (0..self.params.limb_width).fold(Scalar::ONE, |mut l, _| { + l = l.double(); + l + }); + for (i, limb) in self.limbs.iter().enumerate() { + if i % limbs_per_group == 0 { + shift = Scalar::ONE; + } + limbs[i / limbs_per_group] = + std::mem::replace(&mut limbs[i / limbs_per_group], LinearCombination::zero()) + + (shift, limb); + shift.mul_assign(&limb_block); + } + limbs + }; + let max_word = (0..limbs_per_group).fold(BigInt::from(0u8), |mut acc, i| { + acc.set_bit((i * self.params.limb_width) as u64, true); + acc + }) * &self.params.max_word; + BigNat { + params: BigNatParams { + min_bits: self.params.min_bits, + limb_width: self.params.limb_width * limbs_per_group, + n_limbs: limbs.len(), + max_word, + }, + limbs, + limb_values, + value: self.value.clone(), + } + } + + pub fn n_bits(&self) -> usize { + assert!(self.params.n_limbs > 0); + self.params.limb_width * (self.params.n_limbs - 1) + self.params.max_word.bits() as usize + } +} + +pub struct Polynomial { + pub coefficients: Vec>, + pub values: Option>, +} + +impl Polynomial { + pub fn alloc_product>( + &self, + mut cs: CS, + other: &Self, + ) -> Result, SynthesisError> { + let n_product_coeffs = self.coefficients.len() + other.coefficients.len() - 1; + let values = self.values.as_ref().and_then(|self_vs| { + other.values.as_ref().map(|other_vs| { + let mut values: Vec = std::iter::repeat_with(|| Scalar::ZERO) + .take(n_product_coeffs) + .collect(); + for (self_i, self_v) in self_vs.iter().enumerate() { + for (other_i, other_v) in other_vs.iter().enumerate() { + let mut v = *self_v; + v.mul_assign(other_v); + values[self_i + other_i].add_assign(&v); + } + } + values + }) + }); + let coefficients = (0..n_product_coeffs) + .map(|i| { + Ok(LinearCombination::zero() + + cs.alloc(|| format!("prod {i}"), || Ok(values.grab()?[i]))?) + }) + .collect::>, SynthesisError>>()?; + let product = Polynomial { + coefficients, + values, + }; + let one = Scalar::ONE; + let mut x = Scalar::ZERO; + for _ in 1..(n_product_coeffs + 1) { + x.add_assign(&one); + cs.enforce( + || format!("pointwise product @ {x:?}"), + |lc| { + let mut i = Scalar::ONE; + self.coefficients.iter().fold(lc, |lc, c| { + let r = lc + (i, c); + i.mul_assign(&x); + r + }) + }, + |lc| { + let mut i = Scalar::ONE; + other.coefficients.iter().fold(lc, |lc, c| { + let r = lc + (i, c); + i.mul_assign(&x); + r + }) + }, + |lc| { + let mut i = Scalar::ONE; + product.coefficients.iter().fold(lc, |lc, c| { + let r = lc + (i, c); + i.mul_assign(&x); + r + }) + }, + ) + } + Ok(product) + } + + pub fn sum(&self, other: &Self) -> Self { + let n_coeffs = max(self.coefficients.len(), other.coefficients.len()); + let values = self.values.as_ref().and_then(|self_vs| { + other.values.as_ref().map(|other_vs| { + (0..n_coeffs) + .map(|i| { + let mut s = Scalar::ZERO; + if i < self_vs.len() { + s.add_assign(&self_vs[i]); + } + if i < other_vs.len() { + s.add_assign(&other_vs[i]); + } + s + }) + .collect() + }) + }); + let coefficients = (0..n_coeffs) + .map(|i| { + let mut lc = LinearCombination::zero(); + if i < self.coefficients.len() { + lc = lc + &self.coefficients[i]; + } + if i < other.coefficients.len() { + lc = lc + &other.coefficients[i]; + } + lc + }) + .collect(); + Polynomial { + coefficients, + values, + } + } +} + +// #[cfg(test)] +// mod tests { +// use super::*; +// use bellperson::{gadgets::test::TestConstraintSystem, Circuit}; +// use pasta_curves::pallas::Scalar; +// use proptest::prelude::*; + +// pub struct PolynomialMultiplier { +// pub a: Vec, +// pub b: Vec, +// } + +// impl Circuit for PolynomialMultiplier { +// fn synthesize>( +// self, +// cs: &mut CS, +// ) -> Result<(), SynthesisError> { +// let a = Polynomial { +// coefficients: self +// .a +// .iter() +// .enumerate() +// .map(|(i, x)| { +// Ok(LinearCombination::zero() +// + cs.alloc(|| format!("coeff_a {i}"), || Ok(*x))?) +// }) +// .collect::>, SynthesisError>>()?, +// values: Some(self.a), +// }; +// let b = Polynomial { +// coefficients: self +// .b +// .iter() +// .enumerate() +// .map(|(i, x)| { +// Ok(LinearCombination::zero() +// + cs.alloc(|| format!("coeff_b {i}"), || Ok(*x))?) +// }) +// .collect::>, SynthesisError>>()?, +// values: Some(self.b), +// }; +// let _prod = a.alloc_product(cs.namespace(|| "product"), &b)?; +// Ok(()) +// } +// } + +// #[test] +// fn test_polynomial_multiplier_circuit() { +// let mut cs = TestConstraintSystem::::new(); + +// let circuit = PolynomialMultiplier { +// a: [1, 1, 1].iter().map(|i| Scalar::from_u128(*i)).collect(), +// b: [1, 1].iter().map(|i| Scalar::from_u128(*i)).collect(), +// }; + +// circuit.synthesize(&mut cs).expect("synthesis failed"); + +// if let Some(token) = cs.which_is_unsatisfied() { +// eprintln!("Error: {} is unsatisfied", token); +// } +// } + +// #[derive(Debug)] +// pub struct BigNatBitDecompInputs { +// pub n: BigInt, +// } + +// pub struct BigNatBitDecompParams { +// pub limb_width: usize, +// pub n_limbs: usize, +// } + +// pub struct BigNatBitDecomp { +// inputs: Option, +// params: BigNatBitDecompParams, +// } + +// impl Circuit for BigNatBitDecomp { +// fn synthesize>( +// self, +// cs: &mut CS, +// ) -> Result<(), SynthesisError> { +// let n = BigNat::alloc_from_nat( +// cs.namespace(|| "n"), +// || Ok(self.inputs.grab()?.n.clone()), +// self.params.limb_width, +// self.params.n_limbs, +// )?; +// n.decompose(cs.namespace(|| "decomp"))?; +// Ok(()) +// } +// } + +// proptest! { + +// #![proptest_config(ProptestConfig { +// cases: 10, // this test is costlier as max n gets larger +// .. ProptestConfig::default() +// })] +// #[test] +// fn test_big_nat_can_decompose(n in any::(), limb_width in 40u8..200) { +// let n = n as usize; + +// let n_limbs = if n == 0 { +// 1 +// } else { +// (n - 1) / limb_width as usize + 1 +// }; + +// let circuit = BigNatBitDecomp { +// inputs: Some(BigNatBitDecompInputs { +// n: BigInt::from(n), +// }), +// params: BigNatBitDecompParams { +// limb_width: limb_width as usize, +// n_limbs, +// }, +// }; +// let mut cs = TestConstraintSystem::::new(); +// circuit.synthesize(&mut cs).expect("synthesis failed"); +// prop_assert!(cs.is_satisfied()); +// } +// } +// } diff --git a/nova/src/nonnative/mod.rs b/nova/src/nonnative/mod.rs new file mode 100644 index 0000000000..2af0b6251c --- /dev/null +++ b/nova/src/nonnative/mod.rs @@ -0,0 +1,41 @@ +//! This module implements various low-level gadgets, which is copy from Nova non-public mod +//! https://github.com/microsoft/Nova/blob/main/src/gadgets/mod.rs#L3 +//! This module implements various gadgets necessary for doing non-native arithmetic +//! Code in this module is adapted from [bellman-bignat](https://github.com/alex-ozdemir/bellman-bignat), which is licenced under MIT + +use bellpepper_core::SynthesisError; +use ff::PrimeField; + +trait OptionExt { + fn grab(&self) -> Result<&T, SynthesisError>; + fn grab_mut(&mut self) -> Result<&mut T, SynthesisError>; +} + +impl OptionExt for Option { + fn grab(&self) -> Result<&T, SynthesisError> { + self.as_ref().ok_or(SynthesisError::AssignmentMissing) + } + fn grab_mut(&mut self) -> Result<&mut T, SynthesisError> { + self.as_mut().ok_or(SynthesisError::AssignmentMissing) + } +} + +trait BitAccess { + fn get_bit(&self, i: usize) -> Option; +} + +impl BitAccess for Scalar { + fn get_bit(&self, i: usize) -> Option { + if i as u32 >= Scalar::NUM_BITS { + return None; + } + + let (byte_pos, bit_pos) = (i / 8, i % 8); + let byte = self.to_repr().as_ref()[byte_pos]; + let bit = byte >> bit_pos & 1; + Some(bit == 1) + } +} + +pub mod bignat; +pub mod util; diff --git a/nova/src/nonnative/util.rs b/nova/src/nonnative/util.rs new file mode 100644 index 0000000000..54156fb630 --- /dev/null +++ b/nova/src/nonnative/util.rs @@ -0,0 +1,261 @@ +use super::{BitAccess, OptionExt}; +use bellpepper_core::{ + num::AllocatedNum, + {ConstraintSystem, LinearCombination, SynthesisError, Variable}, +}; +// use byteorder::WriteBytesExt; +use ff::{derive::byteorder::WriteBytesExt, PrimeField}; +use num_bigint::{BigInt, Sign}; +use std::convert::From; +use std::io::{self, Write}; + +#[derive(Clone)] +/// A representation of a bit +pub struct Bit { + /// The linear combination which constrain the value of the bit + pub bit: LinearCombination, + /// The value of the bit (filled at witness-time) + pub value: Option, +} + +#[derive(Clone)] +/// A representation of a bit-vector +pub struct Bitvector { + /// The linear combination which constrain the values of the bits + pub bits: Vec>, + /// The value of the bits (filled at witness-time) + pub values: Option>, + /// Allocated bit variables + pub allocations: Vec>, +} + +impl Bit { + /// Allocate a variable in the constraint system which can only be a + /// boolean value. + pub fn alloc(mut cs: CS, value: Option) -> Result + where + CS: ConstraintSystem, + { + let var = cs.alloc( + || "boolean", + || { + if *value.grab()? { + Ok(Scalar::ONE) + } else { + Ok(Scalar::ZERO) + } + }, + )?; + + // Constrain: (1 - a) * a = 0 + // This constrains a to be either 0 or 1. + cs.enforce( + || "boolean constraint", + |lc| lc + CS::one() - var, + |lc| lc + var, + |lc| lc, + ); + + Ok(Self { + bit: LinearCombination::zero() + var, + value, + }) + } +} + +pub struct Num { + pub num: LinearCombination, + pub value: Option, +} + +impl Num { + pub fn new(value: Option, num: LinearCombination) -> Self { + Self { value, num } + } + pub fn alloc(mut cs: CS, value: F) -> Result + where + CS: ConstraintSystem, + F: FnOnce() -> Result, + { + let mut new_value = None; + let var = cs.alloc( + || "num", + || { + let tmp = value()?; + + new_value = Some(tmp); + + Ok(tmp) + }, + )?; + + Ok(Num { + value: new_value, + num: LinearCombination::zero() + var, + }) + } + + pub fn fits_in_bits>( + &self, + mut cs: CS, + n_bits: usize, + ) -> Result<(), SynthesisError> { + let v = self.value; + + // Allocate all but the first bit. + let bits: Vec = (1..n_bits) + .map(|i| { + cs.alloc( + || format!("bit {i}"), + || { + let r = if *v.grab()?.get_bit(i).grab()? { + Scalar::ONE + } else { + Scalar::ZERO + }; + Ok(r) + }, + ) + }) + .collect::>()?; + + for (i, v) in bits.iter().enumerate() { + cs.enforce( + || format!("{i} is bit"), + |lc| lc + *v, + |lc| lc + CS::one() - *v, + |lc| lc, + ) + } + + // Last bit + cs.enforce( + || "last bit", + |mut lc| { + let mut f = Scalar::ONE; + lc = lc + &self.num; + for v in bits.iter() { + f = f.double(); + lc = lc - (f, *v); + } + lc + }, + |mut lc| { + lc = lc + CS::one(); + let mut f = Scalar::ONE; + lc = lc - &self.num; + for v in bits.iter() { + f = f.double(); + lc = lc + (f, *v); + } + lc + }, + |lc| lc, + ); + Ok(()) + } + + /// Computes the natural number represented by an array of bits. + /// Checks if the natural number equals `self` + pub fn is_equal>( + &self, + mut cs: CS, + other: &Bitvector, + ) -> Result<(), SynthesisError> { + let allocations = other.allocations.clone(); + let mut f = Scalar::ONE; + let sum = allocations + .iter() + .fold(LinearCombination::zero(), |lc, bit| { + let l = lc + (f, &bit.bit); + f = f.double(); + l + }); + let sum_lc = LinearCombination::zero() + &self.num - ∑ + cs.enforce(|| "sum", |lc| lc + &sum_lc, |lc| lc + CS::one(), |lc| lc); + Ok(()) + } + + /// Compute the natural number represented by an array of limbs. + /// The limbs are assumed to be based the `limb_width` power of 2. + /// Low-index bits are low-order + pub fn decompose>( + &self, + mut cs: CS, + n_bits: usize, + ) -> Result, SynthesisError> { + let values: Option> = self.value.as_ref().map(|v| { + let num = *v; + (0..n_bits).map(|i| num.get_bit(i).unwrap()).collect() + }); + let allocations: Vec> = (0..n_bits) + .map(|bit_i| { + Bit::alloc( + cs.namespace(|| format!("bit{bit_i}")), + values.as_ref().map(|vs| vs[bit_i]), + ) + }) + .collect::, _>>()?; + let mut f = Scalar::ONE; + let sum = allocations + .iter() + .fold(LinearCombination::zero(), |lc, bit| { + let l = lc + (f, &bit.bit); + f = f.double(); + l + }); + let sum_lc = LinearCombination::zero() + &self.num - ∑ + cs.enforce(|| "sum", |lc| lc + &sum_lc, |lc| lc + CS::one(), |lc| lc); + let bits: Vec> = allocations + .clone() + .into_iter() + .map(|a| LinearCombination::zero() + &a.bit) + .collect(); + Ok(Bitvector { + allocations, + values, + bits, + }) + } + + pub fn as_allocated_num>( + &self, + mut cs: CS, + ) -> Result, SynthesisError> { + let new = AllocatedNum::alloc(cs.namespace(|| "alloc"), || Ok(*self.value.grab()?))?; + cs.enforce( + || "eq", + |lc| lc, + |lc| lc, + |lc| lc + new.get_variable() - &self.num, + ); + Ok(new) + } +} + +impl From> for Num { + fn from(a: AllocatedNum) -> Self { + Self::new(a.get_value(), LinearCombination::zero() + a.get_variable()) + } +} + +fn write_be(f: &F, mut writer: W) -> io::Result<()> { + for digit in f.to_repr().as_ref().iter().rev() { + writer.write_u8(*digit)?; + } + + Ok(()) +} + +/// Convert a field element to a natural number +pub fn f_to_nat(f: &Scalar) -> BigInt { + let mut s = Vec::new(); + write_be(f, &mut s).unwrap(); // f.to_repr().write_be(&mut s).unwrap(); + BigInt::from_bytes_le(Sign::Plus, f.to_repr().as_ref()) +} + +/// Convert a natural number to a field element. +/// Returns `None` if the number is too big for the field. +pub fn nat_to_f(n: &BigInt) -> Option { + Scalar::from_str_vartime(&format!("{n}")) +} diff --git a/nova/src/prover.rs b/nova/src/prover.rs new file mode 100644 index 0000000000..529d2aadfe --- /dev/null +++ b/nova/src/prover.rs @@ -0,0 +1,24 @@ +use ast::{analyzed::Analyzed, asm_analysis::Machine}; +use number::FieldElement; + +use crate::circuit_builder::nova_prove; + +pub fn prove_ast_read_params( + pil: &Analyzed, + main_machine: &Machine, + fixed: Vec<(&str, Vec)>, + witness: Vec<(&str, Vec)>, + public_io: Vec, +) { + prove_ast(pil, main_machine, fixed, witness, public_io) +} +pub fn prove_ast( + pil: &Analyzed, + main_machine: &Machine, + fixed: Vec<(&str, Vec)>, + witness: Vec<(&str, Vec)>, + public_io: Vec, +) { + log::info!("Starting proof generation..."); + nova_prove(pil, main_machine, fixed, witness, public_io); +} diff --git a/nova/src/utils.rs b/nova/src/utils.rs new file mode 100644 index 0000000000..b263e4bbb6 --- /dev/null +++ b/nova/src/utils.rs @@ -0,0 +1,797 @@ +//! This module implements various util function, which is copy from Nova non-public mod +//! https://github.com/microsoft/Nova/blob/main/src/gadgets/mod.rs#L5 +use std::{ + collections::BTreeMap, + sync::{Arc, Mutex}, +}; + +use crate::nonnative::util::Num; + +use super::nonnative::bignat::{nat_to_limbs, BigNat}; +use ast::{ + analyzed::{Expression, PolynomialReference}, + parsed::BinaryOperator, +}; +use bellpepper::gadgets::Assignment; +use bellpepper_core::{ + boolean::{AllocatedBit, Boolean}, + num::AllocatedNum, + ConstraintSystem, LinearCombination, SynthesisError, +}; +use ff::{Field, PrimeField, PrimeFieldBits}; +use nova_snark::traits::{Group, PrimeFieldExt}; +use num_bigint::BigInt; +use number::FieldElement; + +/// Gets as input the little indian representation of a number and spits out the number +pub fn le_bits_to_num( + mut cs: CS, + bits: &[AllocatedBit], +) -> Result, SynthesisError> +where + Scalar: PrimeField + PrimeFieldBits, + CS: ConstraintSystem, +{ + // We loop over the input bits and construct the constraint + // and the field element that corresponds to the result + let mut lc = LinearCombination::zero(); + let mut coeff = Scalar::ONE; + let mut fe = Some(Scalar::ZERO); + for bit in bits.iter() { + lc = lc + (coeff, bit.get_variable()); + fe = bit.get_value().map(|val| { + if val { + fe.unwrap() + coeff + } else { + fe.unwrap() + } + }); + coeff = coeff.double(); + } + let num = AllocatedNum::alloc(cs.namespace(|| "Field element"), || { + fe.ok_or(SynthesisError::AssignmentMissing) + })?; + lc = lc - num.get_variable(); + cs.enforce(|| "compute number from bits", |lc| lc, |lc| lc, |_| lc); + Ok(num) +} + +/// Allocate a variable that is set to zero +pub fn alloc_zero>( + mut cs: CS, +) -> Result, SynthesisError> { + let zero = AllocatedNum::alloc(cs.namespace(|| "alloc"), || Ok(F::ZERO))?; + cs.enforce( + || "check zero is valid", + |lc| lc, + |lc| lc, + |lc| lc + zero.get_variable(), + ); + Ok(zero) +} + +/// Allocate a variable that is set to one +pub fn alloc_one>( + mut cs: CS, +) -> Result, SynthesisError> { + let one = AllocatedNum::alloc(cs.namespace(|| "alloc"), || Ok(F::ONE))?; + cs.enforce( + || "check one is valid", + |lc| lc + CS::one(), + |lc| lc + CS::one(), + |lc| lc + one.get_variable(), + ); + + Ok(one) +} + +/// alloc a field as a constant +/// implemented refer from https://github.com/lurk-lab/lurk-rs/blob/4335fbb3290ed1a1176e29428f7daacb47f8033d/src/circuit/gadgets/data.rs#L387-L402 +pub fn alloc_const>( + mut cs: CS, + val: F, +) -> Result, SynthesisError> { + let allocated = AllocatedNum::::alloc(cs.namespace(|| "allocate const"), || Ok(val))?; + + // allocated * 1 = val + cs.enforce( + || "enforce constant", + |lc| lc + allocated.get_variable(), + |lc| lc + CS::one(), + |_| Boolean::Constant(true).lc(CS::one(), val), + ); + + Ok(allocated) +} + +/// Allocate incremental integers within range [start, end) as vector of AllocatedNum +pub fn alloc_incremental_range_index>( + mut cs: CS, + start: usize, + len: usize, +) -> Result>, SynthesisError> { + if len == 0 { + return Ok(vec![]); + } + + let one = alloc_one(cs.namespace(|| "one"))?; + + let mut res_vec = if start == 0 { + vec![alloc_zero(cs.namespace(|| "zero"))?] + } else { + vec![alloc_const( + cs.namespace(|| format!("start {}", start)), + F::from(start as u64), + )?] + }; + + let _ = (start + 1..len).try_fold(&mut res_vec, |res_vec, i| { + let new_acc = add_allocated_num( + cs.namespace(|| format!("{}", i)), + res_vec.last().unwrap(), + &one, + )?; + res_vec.push(new_acc); + Ok::<&mut Vec>, SynthesisError>(res_vec) + })?; + Ok(res_vec) +} + +/// Allocate a scalar as a base. Only to be used is the scalar fits in base! +pub fn alloc_scalar_as_base( + mut cs: CS, + input: Option, +) -> Result, SynthesisError> +where + G: Group, + ::Scalar: PrimeFieldBits, + CS: ConstraintSystem<::Base>, +{ + AllocatedNum::alloc(cs.namespace(|| "allocate scalar as base"), || { + let input_bits = input.unwrap_or(G::Scalar::ZERO).clone().to_le_bits(); + let mut mult = G::Base::ONE; + let mut val = G::Base::ZERO; + for bit in input_bits { + if bit { + val += mult; + } + mult = mult + mult; + } + Ok(val) + }) +} + +/// interepret scalar as base +pub fn scalar_as_base(input: G::Scalar) -> G::Base { + let input_bits = input.to_le_bits(); + let mut mult = G::Base::ONE; + let mut val = G::Base::ZERO; + for bit in input_bits { + if bit { + val += mult; + } + mult = mult + mult; + } + val +} + +/// Allocate bignat a constant +pub fn alloc_bignat_constant>( + mut cs: CS, + val: &BigInt, + limb_width: usize, + n_limbs: usize, +) -> Result, SynthesisError> { + let limbs = nat_to_limbs(val, limb_width, n_limbs).unwrap(); + let bignat = BigNat::alloc_from_limbs( + cs.namespace(|| "alloc bignat"), + || Ok(limbs.clone()), + None, + limb_width, + n_limbs, + )?; + // Now enforce that the limbs are all equal to the constants + #[allow(clippy::needless_range_loop)] + for i in 0..n_limbs { + cs.enforce( + || format!("check limb {i}"), + |lc| lc + &bignat.limbs[i], + |lc| lc + CS::one(), + |lc| lc + (limbs[i], CS::one()), + ); + } + Ok(bignat) +} + +/// Check that two numbers are equal and return a bit +pub fn alloc_num_equals>( + mut cs: CS, + a: &AllocatedNum, + b: &AllocatedNum, +) -> Result { + // Allocate and constrain `r`: result boolean bit. + // It equals `true` if `a` equals `b`, `false` otherwise + let r_value = match (a.get_value(), b.get_value()) { + (Some(a), Some(b)) => Some(a == b), + _ => None, + }; + + let r = AllocatedBit::alloc(cs.namespace(|| "r"), r_value)?; + + // Allocate t s.t. t=1 if z1 == z2 else 1/(z1 - z2) + + let t = AllocatedNum::alloc(cs.namespace(|| "t"), || { + Ok(if *a.get_value().get()? == *b.get_value().get()? { + F::ONE + } else { + (*a.get_value().get()? - *b.get_value().get()?) + .invert() + .unwrap() + }) + })?; + + cs.enforce( + || "t*(a - b) = 1 - r", + |lc| lc + t.get_variable(), + |lc| lc + a.get_variable() - b.get_variable(), + |lc| lc + CS::one() - r.get_variable(), + ); + + cs.enforce( + || "r*(a - b) = 0", + |lc| lc + r.get_variable(), + |lc| lc + a.get_variable() - b.get_variable(), + |lc| lc, + ); + + Ok(r) +} + +/// If condition return a otherwise b +pub fn conditionally_select>( + mut cs: CS, + a: &AllocatedNum, + b: &AllocatedNum, + condition: &Boolean, +) -> Result, SynthesisError> { + let c = AllocatedNum::alloc(cs.namespace(|| "conditional select result"), || { + if *condition.get_value().get()? { + Ok(*a.get_value().get()?) + } else { + Ok(*b.get_value().get()?) + } + })?; + + // a * condition + b*(1-condition) = c -> + // a * condition - b*condition = c - b + cs.enforce( + || "conditional select constraint", + |lc| lc + a.get_variable() - b.get_variable(), + |_| condition.lc(CS::one(), F::ONE), + |lc| lc + c.get_variable() - b.get_variable(), + ); + + Ok(c) +} + +/// If condition return a otherwise b +pub fn conditionally_select_vec>( + mut cs: CS, + a: &[AllocatedNum], + b: &[AllocatedNum], + condition: &Boolean, +) -> Result>, SynthesisError> { + a.iter() + .zip(b.iter()) + .enumerate() + .map(|(i, (a, b))| { + conditionally_select(cs.namespace(|| format!("select_{i}")), a, b, condition) + }) + .collect::>, SynthesisError>>() +} + +/// If condition return a otherwise b where a and b are BigNats +pub fn conditionally_select_bignat>( + mut cs: CS, + a: &BigNat, + b: &BigNat, + condition: &Boolean, +) -> Result, SynthesisError> { + assert!(a.limbs.len() == b.limbs.len()); + let c = BigNat::alloc_from_nat( + cs.namespace(|| "conditional select result"), + || { + if *condition.get_value().get()? { + Ok(a.value.get()?.clone()) + } else { + Ok(b.value.get()?.clone()) + } + }, + a.params.limb_width, + a.params.n_limbs, + )?; + + // a * condition + b*(1-condition) = c -> + // a * condition - b*condition = c - b + for i in 0..c.limbs.len() { + cs.enforce( + || format!("conditional select constraint {i}"), + |lc| lc + &a.limbs[i] - &b.limbs[i], + |_| condition.lc(CS::one(), F::ONE), + |lc| lc + &c.limbs[i] - &b.limbs[i], + ); + } + Ok(c) +} + +/// Same as the above but Condition is an AllocatedNum that needs to be +/// 0 or 1. 1 => True, 0 => False +pub fn conditionally_select2>( + mut cs: CS, + a: &AllocatedNum, + b: &AllocatedNum, + condition: &AllocatedNum, +) -> Result, SynthesisError> { + let c = AllocatedNum::alloc(cs.namespace(|| "conditional select result"), || { + if *condition.get_value().get()? == F::ONE { + Ok(*a.get_value().get()?) + } else { + Ok(*b.get_value().get()?) + } + })?; + + // a * condition + b*(1-condition) = c -> + // a * condition - b*condition = c - b + cs.enforce( + || "conditional select constraint", + |lc| lc + a.get_variable() - b.get_variable(), + |lc| lc + condition.get_variable(), + |lc| lc + c.get_variable() - b.get_variable(), + ); + + Ok(c) +} + +/// If condition set to 0 otherwise a. Condition is an allocated num +pub fn select_zero_or_num2>( + mut cs: CS, + a: &AllocatedNum, + condition: &AllocatedNum, +) -> Result, SynthesisError> { + let c = AllocatedNum::alloc(cs.namespace(|| "conditional select result"), || { + if *condition.get_value().get()? == F::ONE { + Ok(F::ZERO) + } else { + Ok(*a.get_value().get()?) + } + })?; + + // a * (1 - condition) = c + cs.enforce( + || "conditional select constraint", + |lc| lc + a.get_variable(), + |lc| lc + CS::one() - condition.get_variable(), + |lc| lc + c.get_variable(), + ); + + Ok(c) +} + +/// If condition set to a otherwise 0. Condition is an allocated num +pub fn select_num_or_zero2>( + mut cs: CS, + a: &AllocatedNum, + condition: &AllocatedNum, +) -> Result, SynthesisError> { + let c = AllocatedNum::alloc(cs.namespace(|| "conditional select result"), || { + if *condition.get_value().get()? == F::ONE { + Ok(*a.get_value().get()?) + } else { + Ok(F::ZERO) + } + })?; + + cs.enforce( + || "conditional select constraint", + |lc| lc + a.get_variable(), + |lc| lc + condition.get_variable(), + |lc| lc + c.get_variable(), + ); + + Ok(c) +} + +/// If condition set to a otherwise 0 +pub fn select_num_or_zero>( + mut cs: CS, + a: &AllocatedNum, + condition: &Boolean, +) -> Result, SynthesisError> { + let c = AllocatedNum::alloc(cs.namespace(|| "conditional select result"), || { + if *condition.get_value().get()? { + Ok(*a.get_value().get()?) + } else { + Ok(F::ZERO) + } + })?; + + cs.enforce( + || "conditional select constraint", + |lc| lc + a.get_variable(), + |_| condition.lc(CS::one(), F::ONE), + |lc| lc + c.get_variable(), + ); + + Ok(c) +} + +/// If condition set to 1 otherwise a +pub fn select_one_or_num2>( + mut cs: CS, + a: &AllocatedNum, + condition: &AllocatedNum, +) -> Result, SynthesisError> { + let c = AllocatedNum::alloc(cs.namespace(|| "conditional select result"), || { + if *condition.get_value().get()? == F::ONE { + Ok(F::ONE) + } else { + Ok(*a.get_value().get()?) + } + })?; + + cs.enforce( + || "conditional select constraint", + |lc| lc + CS::one() - a.get_variable(), + |lc| lc + condition.get_variable(), + |lc| lc + c.get_variable() - a.get_variable(), + ); + Ok(c) +} + +/// If condition set to 1 otherwise a - b +pub fn select_one_or_diff2>( + mut cs: CS, + a: &AllocatedNum, + b: &AllocatedNum, + condition: &AllocatedNum, +) -> Result, SynthesisError> { + let c = AllocatedNum::alloc(cs.namespace(|| "conditional select result"), || { + if *condition.get_value().get()? == F::ONE { + Ok(F::ONE) + } else { + Ok(*a.get_value().get()? - *b.get_value().get()?) + } + })?; + + cs.enforce( + || "conditional select constraint", + |lc| lc + CS::one() - a.get_variable() + b.get_variable(), + |lc| lc + condition.get_variable(), + |lc| lc + c.get_variable() - a.get_variable() + b.get_variable(), + ); + Ok(c) +} + +/// If condition set to a otherwise 1 for boolean conditions +pub fn select_num_or_one>( + mut cs: CS, + a: &AllocatedNum, + condition: &Boolean, +) -> Result, SynthesisError> { + let c = AllocatedNum::alloc(cs.namespace(|| "conditional select result"), || { + if *condition.get_value().get()? { + Ok(*a.get_value().get()?) + } else { + Ok(F::ONE) + } + })?; + + cs.enforce( + || "conditional select constraint", + |lc| lc + a.get_variable() - CS::one(), + |_| condition.lc(CS::one(), F::ONE), + |lc| lc + c.get_variable() - CS::one(), + ); + + Ok(c) +} + +/// c = a + b where a, b is AllocatedNum +pub fn add_allocated_num>( + mut cs: CS, + a: &AllocatedNum, + b: &AllocatedNum, +) -> Result, SynthesisError> { + let c = AllocatedNum::alloc(cs.namespace(|| "c"), || { + Ok(*a.get_value().get()? + b.get_value().get()?) + })?; + cs.enforce( + || "c = a + b", + |lc| lc + a.get_variable() + b.get_variable(), + |lc| lc + CS::one(), + |lc| lc + c.get_variable(), + ); + Ok(c) +} + +/// c = a * b where a, b is AllocatedNum +pub fn mul_allocated_num>( + mut cs: CS, + a: &AllocatedNum, + b: &AllocatedNum, +) -> Result, SynthesisError> { + let c = AllocatedNum::alloc(cs.namespace(|| "c"), || { + Ok(*a.get_value().get()? * b.get_value().get()?) + })?; + cs.enforce( + || "c = a * b", + |lc| lc + a.get_variable(), + |lc| lc + b.get_variable(), + |lc| lc + c.get_variable(), + ); + Ok(c) +} + +/// witness generation wrapper to support scan witness row by row +#[derive(Clone, Debug)] +pub struct WitnessGen<'a, T: FieldElement> { + data: Vec<(&'a str, &'a Vec)>, + cur_index: usize, + pub cur_witness: BTreeMap<&'a str, &'a T>, +} + +impl<'a, T: FieldElement> WitnessGen<'a, T> { + fn gen_current_witness( + index: usize, + data: &[(&'a str, &'a Vec)], + prev: &mut BTreeMap<&'a str, &'a T>, + ) { + data.iter().for_each(|(k, v)| { + if let Some(v) = v.get(index) { + prev.insert(*k, v); + } else { + panic!("out of bound: index {:?} but v is {:?}", index, v); + } + }); + } + + pub fn new(data: Vec<(&'a str, &'a Vec)>) -> Self { + let mut cur_witness = BTreeMap::new(); + Self::gen_current_witness(0, &data, &mut cur_witness); + Self { + data, + cur_index: 0, + cur_witness, + } + } + + pub fn get_wit_by_key(&self, k: &str) -> Option<&T> { + if let Some(v) = self.cur_witness.get(k) { + Some(v) + } else { + panic!("key {:?} not found in {:?}", k, self.cur_witness); + } + } + + pub fn next(&mut self) { + self.cur_index += 1; + Self::gen_current_witness(self.cur_index, &self.data, &mut self.cur_witness); + } + + pub fn num_of_iteration(&self) -> usize { + // just collect first witness column as len, since all column will be same length + self.data + .get(0) + .map(|first| first.1.len()) + .unwrap_or_default() + } +} + +pub fn get_num_at_index>( + mut cs: CS, + target_index: &AllocatedNum, + arr: &[AllocatedNum], +) -> Result, SynthesisError> { + let indexes_alloc = alloc_incremental_range_index( + cs.namespace(|| "augment circuit range index"), + 0, + arr.len(), + )?; + + // select target when index match or empty + let zero = alloc_zero(cs.namespace(|| "zero"))?; + let selected_num = indexes_alloc + .iter() + .zip(arr.iter()) + .enumerate() + .map(|(i, (index_alloc, z))| { + let equal_bit = Boolean::from(alloc_num_equals( + cs.namespace(|| format!("check selected_circuit_index {} equal bit", i)), + target_index, + index_alloc, + )?); + conditionally_select( + cs.namespace(|| format!("select on index namespace {}", i)), + z, + &zero, + &equal_bit, + ) + }) + .collect::>, SynthesisError>>()?; + + let selected_num = selected_num + .iter() + .enumerate() + .try_fold(zero, |agg, (i, _num)| { + add_allocated_num(cs.namespace(|| format!("selected_num {}", i)), _num, &agg) + })?; + Ok(selected_num) +} + +/// get negative field value from signed limb +pub fn signed_limb_to_neg>( + mut cs: CS, + limb: &Num, + max_limb_plus_one_const: &AllocatedNum, + nbit: usize, +) -> Result, SynthesisError> { + let limb_alloc = limb.as_allocated_num(cs.namespace(|| "rom decompose index"))?; + let bits = limb.decompose(cs.namespace(|| "index decompose bits"), nbit)?; + let signed_bit = &bits.allocations[nbit - 1]; + let twos_complement = AllocatedNum::alloc(cs.namespace(|| "alloc twos complement"), || { + max_limb_plus_one_const + .get_value() + .zip(limb_alloc.get_value()) + .map(|(a, b)| a - b) + .ok_or(SynthesisError::AssignmentMissing) + })?; + cs.enforce( + || "constraints 2's complement", + |lc| lc + twos_complement.get_variable() + limb_alloc.get_variable(), + |lc| lc + CS::one(), + |lc| lc + max_limb_plus_one_const.get_variable(), + ); + let twos_complement_neg = AllocatedNum::alloc(cs.namespace(|| " 2's complment neg"), || { + twos_complement + .get_value() + .map(|v| v.neg()) + .ok_or(SynthesisError::AssignmentMissing) + })?; + cs.enforce( + || "constraints 2's complement additive neg", + |lc| lc + twos_complement.get_variable() + twos_complement_neg.get_variable(), + |lc| lc + CS::one(), + |lc| lc, + ); + + let c = AllocatedNum::alloc(cs.namespace(|| "conditional select"), || { + signed_bit + .value + .map(|signed_bit_value| { + if signed_bit_value { + twos_complement_neg.get_value().unwrap() + } else { + limb_alloc.get_value().unwrap() + } + }) + .ok_or(SynthesisError::AssignmentMissing) + })?; + + // twos_complement_neg * condition + limb_alloc*(1-condition) = c -> + // twos_complement_neg * condition - limb_alloc*condition = c - limb_alloc + cs.enforce( + || "index conditional select", + |lc| lc + twos_complement_neg.get_variable() - limb_alloc.get_variable(), + |_| signed_bit.bit.clone(), + |lc| lc + c.get_variable() - limb_alloc.get_variable(), + ); + Ok(c) +} + +// TODO optmize constraints to leverage R1CS cost-free additive +// TODO combine FieldElement & PrimeField +pub fn evaluate_expr>( + mut cs: CS, + poly_map: &mut BTreeMap>, + expr: &Expression, + witgen: Arc>>, +) -> Result, SynthesisError> { + match expr { + Expression::Number(n) => { + let mut n_le = n.to_bytes_le(); + n_le.resize(64, 0); + alloc_const( + cs.namespace(|| format!("{:x?}", n.to_string())), + F::from_uniform(&n_le[..]), + ) + } + // this is refer to another polynomial, in other word, witness + Expression::PolynomialReference(PolynomialReference { + index, name, next, .. + }) => { + let name = name.strip_prefix("main.").unwrap(); // TODO FIXME: trim namespace should be happened in unified place + assert_eq!(*index, None); + assert!(!*next); + + Ok(poly_map + .entry(name.to_string()) + .or_insert_with(|| { + AllocatedNum::alloc(cs.namespace(|| format!("{}.{}", expr, name)), || { + let wit_value = witgen.lock().unwrap().get_wit_by_key(name).cloned(); + let mut n_le = wit_value.unwrap().to_bytes_le(); + n_le.resize(64, 0); + let f = F::from_uniform(&n_le[..]); + Ok(f) + }) + .unwrap() + }) + .clone()) + } + Expression::BinaryOperation(lhe, op, rhe) => { + let lhe = evaluate_expr(cs.namespace(|| "lhe"), poly_map, lhe, witgen.clone())?; + let rhe = evaluate_expr(cs.namespace(|| "rhe"), poly_map, rhe, witgen)?; + match op { + BinaryOperator::Add => add_allocated_num(cs, &lhe, &rhe), + BinaryOperator::Sub => { + let rhe_neg: AllocatedNum = + AllocatedNum::alloc(cs.namespace(|| "inv"), || { + rhe.get_value() + .map(|v| v.neg()) + .ok_or_else(|| SynthesisError::AssignmentMissing) + })?; + + // (a + a_neg) * 1 = 0 + cs.enforce( + || "(a + a_neg) * 1 = 0", + |lc| lc + rhe.get_variable() + rhe_neg.get_variable(), + |lc| lc + CS::one(), + |lc| lc, + ); + + add_allocated_num(cs, &lhe, &rhe_neg) + } + BinaryOperator::Mul => mul_allocated_num(cs, &lhe, &rhe), + _ => unimplemented!("{}", expr), + } + } + Expression::Constant(constant_name) => { + poly_map + .get(constant_name) + .cloned() + .ok_or_else(|| SynthesisError::AssignmentMissing) // constant must exist + } + _ => unimplemented!("{}", expr), + } +} + +// find rhs expression where the left hand side is the instr_name, e.g. instr_name * () +// this is a workaround and inefficient way, since we are working on PIL file. Need to optimise it. +pub fn find_pc_expression>( + expr: &Expression, + instr_name: &String, +) -> Option>> { + match expr { + Expression::Number(_) => None, + // this is refer to another polynomial, in other word, witness + Expression::PolynomialReference(PolynomialReference { .. }) => None, + Expression::BinaryOperation(lhe, operator, rhe) => { + let find_match_expr = match lhe { + // early pattern match on lhs to retrive the instr * () + box Expression::PolynomialReference(PolynomialReference { name, .. }) => { + if name == instr_name && *operator == BinaryOperator::Mul { + Some(rhe) + } else { + None + } + } + _ => None, + }; + find_match_expr + .cloned() + .or_else(|| find_pc_expression::(rhe, instr_name)) + .or_else(|| find_pc_expression::(lhe, instr_name)) + } + Expression::Constant(_) => None, + _ => unimplemented!("{}", expr), + } +} diff --git a/nova/zero.asm b/nova/zero.asm new file mode 100644 index 0000000000..91f5a16097 --- /dev/null +++ b/nova/zero.asm @@ -0,0 +1,96 @@ +machine NovaZero { + + degree 32; + + // this simple machine does not have submachines + + reg pc[@pc]; + reg X[<=]; + reg Y[<=]; + reg Z[<=]; + reg x0; + reg x1; + + constraints { + col witness XInv; + col witness XIsZero; + XIsZero = 1 - X * XInv; + XIsZero * X = 0; + XIsZero * (1 - XIsZero) = 0; + } + + constraints { + col witness x_b0; + col witness x_b1; + col witness x_b2; + col witness x_b3; + + // constraints bit + x_b0 * (1-x_b0) = 0; + x_b1 * (1-x_b1) = 0; + x_b2 * (1-x_b2) = 0; + x_b3 * (1-x_b3) = 0; + // ... + } + + instr incr X -> Y { + Y = X + 1, + X = x_b0 + x_b1 * 2 + x_b2 * 2**2 + x_b3 * 2**3, + pc' = pc + 1 + } + + instr add X, Y -> Z { + Z = X + Y, + pc' = pc + 1 + } + + instr sub X, Y -> Z { + Z = X - Y, + pc' = pc + 1 + } + + instr decr X -> Y { + Y = X - 1 + } + + instr addi X, Y:signed -> Z { + Z = X + Y, + pc' = pc + 1 + } + + // an instruction to loop forever, as we must fill the whole execution trace + instr loop { + pc' = pc + } + + // an instruction only proceed pc + 1 if X = 0 + instr bnz X, Y: label { + pc' = (1 - XIsZero) * (Y) + XIsZero * (pc + 1) + } + + instr assert_zero X { + X = 0 + } + + constraints { + } + + // the main function assigns the first prover input to A, increments it, decrements it, and loops forever + function main { + x0 <=X= ${ ("input", 0) }; + x1 <=Z= addi(x0, 1); // x1 = 1 + x0 <=Y= decr(x1); // x0 = 0 + x1 <=Y= incr(x0); // x1 = 1 + x0 <=Z= add(x1, x1); // x0 = 1 + 1 + x0 <=Z= addi(x0, 1); // x0 = 2 + 1 + x0 <=Z= addi(x0, -2); // x0 = 3 - 2 + x0 <=Z= sub(x0, x0); // x0 - x0 = 0 + assert_zero x0; // x0 == 0 + x1 <=X= ${ ("input", 1) }; + LOOP:: + x1 <=Z= addi(x1, -1); + bnz x1, LOOP; + assert_zero x1; // x1 == 0 + loop; + } +} diff --git a/powdr_cli/Cargo.toml b/powdr_cli/Cargo.toml index 8348476615..de7ab1052d 100644 --- a/powdr_cli/Cargo.toml +++ b/powdr_cli/Cargo.toml @@ -4,8 +4,9 @@ version = "0.1.0" edition = "2021" [features] -default = ["halo2"] +default = ["halo2", "nova"] halo2 = ["dep:halo2", "backend/halo2", "compiler/halo2"] +nova = ["backend/nova", "dep:nova"] [dependencies] clap = { version = "^4.3", features = ["derive"] } @@ -16,6 +17,7 @@ parser = { path = "../parser" } riscv = { path = "../riscv" } number = { path = "../number" } halo2 = { path = "../halo2", optional = true } +nova = { path = "../nova", optional = true } backend = { path = "../backend" } pilopt = { path = "../pilopt" } strum = { version = "0.24.1", features = ["derive"] } diff --git a/powdr_cli/src/main.rs b/powdr_cli/src/main.rs index 37e767fab9..7b1419cef8 100644 --- a/powdr_cli/src/main.rs +++ b/powdr_cli/src/main.rs @@ -3,10 +3,12 @@ mod util; use clap::{Parser, Subcommand}; -use compiler::{compile_pil_or_asm, Backend}; +use compiler::util::{read_fixed, read_witness}; +use compiler::{analyze, analyze_pil, compile_pil_or_asm, Backend}; use env_logger::{Builder, Target}; use log::LevelFilter; use number::{Bn254Field, FieldElement, GoldilocksField}; +use parser::parse_asm; use riscv::{compile_riscv_asm, compile_rust}; use std::io::BufWriter; use std::{borrow::Cow, collections::HashSet, fs, io::Write, path::Path}; @@ -184,6 +186,35 @@ enum Commands { params: Option, }, + ProveNova { + /// Input PIL file + file: String, + + /// TODO retrive asm info from PIL + asm_file: String, + + /// The field to use + #[arg(long)] + #[arg(default_value_t = FieldArgument::Bn254)] + #[arg(value_parser = clap_enum_variants!(FieldArgument))] + field: FieldArgument, + + /// Directory to find the committed and fixed values + #[arg(short, long)] + #[arg(default_value_t = String::from("."))] + dir: String, + + /// Comma-separated list of free inputs (numbers). Assumes queries to have the form + /// ("input", ). + #[arg(short, long)] + #[arg(default_value_t = String::new())] + inputs: String, + + /// File containing previously generated setup parameters. + #[arg(long)] + params: Option, + }, + Setup { /// Size of the parameters size: usize, @@ -355,6 +386,64 @@ fn run_command(command: Commands) { } => { setup(size, dir, field, backend); } + #[cfg(feature = "nova")] + Commands::ProveNova { + file, + asm_file, + field: _, + inputs, + dir, + params: _, + } => { + // Remove BN254 Hardcode + let pil = Path::new(&file); + let contents = fs::read_to_string(asm_file.clone()).unwrap(); + + let parsed = parse_asm::(Some(&asm_file[..]), &contents[..]) + .unwrap_or_else(|err| { + eprintln!("Error parsing .asm file:"); + err.output_to_stderr(); + panic!(); + }); + let analysed_asm = analyze::(parsed).unwrap(); + let pi_inputs = split_inputs(&inputs); + + // retrieve instance of the Main state machine + let main = match analysed_asm.machines.len() { + // if there is a single machine, treat it as main + 1 => analysed_asm.machines.values().next().unwrap().clone(), + // otherwise, find the machine called "Main" + _ => analysed_asm + .machines + .get("Main") + .expect("couldn't find a Main state machine") + .clone(), + }; + + let dir = Path::new(&dir); + + let pil = analyze_pil::(pil); + let fixed = read_fixed::(&pil, dir); + let witness = read_witness::(&pil, dir); + + // let params = fs::File::open(dir.join(params.unwrap())).unwrap(); + NovaBackend::prove::(&pil, &main, fixed.0, witness.0, pi_inputs); + + // TODO: this probably should be abstracted alway in a common backends API, + // // maybe a function "get_file_extension()". + // let proof_filename = if let Backend::Halo2Aggr = backend { + // "proof_aggr.bin" + // } else { + // "proof.bin" + // }; + // if let Some(proof) = proof { + // let mut proof_file = fs::File::create(dir.join(proof_filename)).unwrap(); + // let mut proof_writer = BufWriter::new(&mut proof_file); + // proof_writer.write_all(&proof).unwrap(); + // proof_writer.flush().unwrap(); + // log::info!("Wrote {proof_filename}."); + // } + } #[cfg(not(feature = "halo2"))] _ => unreachable!(), diff --git a/rust-toolchain.toml b/rust-toolchain.toml index 292fe499e3..5d56faf9ae 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,2 +1,2 @@ [toolchain] -channel = "stable" +channel = "nightly" diff --git a/zero_nova/init.json b/zero_nova/init.json deleted file mode 100644 index 029f0a0a7b..0000000000 --- a/zero_nova/init.json +++ /dev/null @@ -1,4 +0,0 @@ -{ - "pc": "0", - "A": "0" -} diff --git a/zero_nova/nova.params b/zero_nova/nova.params deleted file mode 100644 index 167f997138..0000000000 Binary files a/zero_nova/nova.params and /dev/null differ diff --git a/zero_nova/steps.json b/zero_nova/steps.json deleted file mode 100644 index 08fac6d1e3..0000000000 --- a/zero_nova/steps.json +++ /dev/null @@ -1,209 +0,0 @@ -[ - { - "X": "0", - "reg_write_X_A": "1", - "instr_ASSERT_IS_ZERO_A": "0", - "X_const": "0", - "X_read_free": "0", - "read_X_A": "0", - "read_X_pc": "0", - "X_free_value": "1", - "first_step": "1", - "first_step_next": "0" - } - ,{ - "X": "0", - "reg_write_X_A": "0", - "instr_ASSERT_IS_ZERO_A": "1", - "X_const": "0", - "X_read_free": "0", - "read_X_A": "0", - "read_X_pc": "0", - "X_free_value": "0", - "first_step": "0", - "first_step_next": "1" - } - ,{ - "X": "0", - "reg_write_X_A": "0", - "instr_ASSERT_IS_ZERO_A": "1", - "X_const": "0", - "X_read_free": "0", - "read_X_A": "0", - "read_X_pc": "0", - "X_free_value": "0", - "first_step": "0", - "first_step_next": "1" - } - ,{ - "X": "0", - "reg_write_X_A": "0", - "instr_ASSERT_IS_ZERO_A": "1", - "X_const": "0", - "X_read_free": "0", - "read_X_A": "0", - "read_X_pc": "0", - "X_free_value": "0", - "first_step": "0", - "first_step_next": "1" - } - ,{ - "X": "0", - "reg_write_X_A": "0", - "instr_ASSERT_IS_ZERO_A": "1", - "X_const": "0", - "X_read_free": "0", - "read_X_A": "0", - "read_X_pc": "0", - "X_free_value": "0", - "first_step": "0", - "first_step_next": "1" - } - ,{ - "X": "0", - "reg_write_X_A": "0", - "instr_ASSERT_IS_ZERO_A": "1", - "X_const": "0", - "X_read_free": "0", - "read_X_A": "0", - "read_X_pc": "0", - "X_free_value": "0", - "first_step": "0", - "first_step_next": "1" - } - ,{ - "X": "0", - "reg_write_X_A": "0", - "instr_ASSERT_IS_ZERO_A": "1", - "X_const": "0", - "X_read_free": "0", - "read_X_A": "0", - "read_X_pc": "0", - "X_free_value": "0", - "first_step": "0", - "first_step_next": "1" - } - ,{ - "X": "0", - "reg_write_X_A": "0", - "instr_ASSERT_IS_ZERO_A": "1", - "X_const": "0", - "X_read_free": "0", - "read_X_A": "0", - "read_X_pc": "0", - "X_free_value": "0", - "first_step": "0", - "first_step_next": "1" - } - ,{ - "X": "0", - "reg_write_X_A": "0", - "instr_ASSERT_IS_ZERO_A": "1", - "X_const": "0", - "X_read_free": "0", - "read_X_A": "0", - "read_X_pc": "0", - "X_free_value": "0", - "first_step": "0", - "first_step_next": "1" - } - ,{ - "X": "0", - "reg_write_X_A": "0", - "instr_ASSERT_IS_ZERO_A": "1", - "X_const": "0", - "X_read_free": "0", - "read_X_A": "0", - "read_X_pc": "0", - "X_free_value": "0", - "first_step": "0", - "first_step_next": "1" - } - ,{ - "X": "0", - "reg_write_X_A": "0", - "instr_ASSERT_IS_ZERO_A": "1", - "X_const": "0", - "X_read_free": "0", - "read_X_A": "0", - "read_X_pc": "0", - "X_free_value": "0", - "first_step": "0", - "first_step_next": "1" - } - ,{ - "X": "0", - "reg_write_X_A": "0", - "instr_ASSERT_IS_ZERO_A": "1", - "X_const": "0", - "X_read_free": "0", - "read_X_A": "0", - "read_X_pc": "0", - "X_free_value": "0", - "first_step": "0", - "first_step_next": "1" - } - ,{ - "X": "0", - "reg_write_X_A": "0", - "instr_ASSERT_IS_ZERO_A": "1", - "X_const": "0", - "X_read_free": "0", - "read_X_A": "0", - "read_X_pc": "0", - "X_free_value": "0", - "first_step": "0", - "first_step_next": "1" - } - ,{ - "X": "0", - "reg_write_X_A": "0", - "instr_ASSERT_IS_ZERO_A": "1", - "X_const": "0", - "X_read_free": "0", - "read_X_A": "0", - "read_X_pc": "0", - "X_free_value": "0", - "first_step": "0", - "first_step_next": "1" - } - ,{ - "X": "0", - "reg_write_X_A": "0", - "instr_ASSERT_IS_ZERO_A": "1", - "X_const": "0", - "X_read_free": "0", - "read_X_A": "0", - "read_X_pc": "0", - "X_free_value": "0", - "first_step": "0", - "first_step_next": "1" - } - ,{ - "X": "0", - "reg_write_X_A": "0", - "instr_ASSERT_IS_ZERO_A": "1", - "X_const": "0", - "X_read_free": "0", - "read_X_A": "0", - "read_X_pc": "0", - "X_free_value": "0", - "first_step": "0", - "first_step_next": "1" - } - ,{ - "X": "0", - "reg_write_X_A": "0", - "instr_ASSERT_IS_ZERO_A": "1", - "X_const": "0", - "X_read_free": "0", - "read_X_A": "0", - "read_X_pc": "0", - "X_free_value": "0", - "first_step": "0", - "first_step_next": "1" - } - - - -] diff --git a/zero_nova/zero.asm b/zero_nova/zero.asm deleted file mode 100644 index e7b0f54880..0000000000 --- a/zero_nova/zero.asm +++ /dev/null @@ -1,15 +0,0 @@ -machine Zero { - degree 2; - - reg pc[@pc]; - reg X[<=]; - - reg A; - - instr ASSERT_IS_ZERO_A { A = 0 } - - function main { - A <=X= ${ ("input", 0) }; // pc = 0 => input[0] = X && X = A - ASSERT_IS_ZERO_A; // pc = 1 => A = 0 - } -} diff --git a/zero_nova/zero.zok b/zero_nova/zero.zok deleted file mode 100644 index 14b977c506..0000000000 --- a/zero_nova/zero.zok +++ /dev/null @@ -1,37 +0,0 @@ -struct State { - field pc; - field A; -} - -struct Witness { - field X; - field reg_write_X_A; - field instr_ASSERT_IS_ZERO_A; - field X_const; - field X_read_free; - field read_X_A; - field read_X_pc; - field X_free_value; - field first_step; - field first_step_next; -} - -def onestep(State s, private Witness w) -> State { - assert((w.first_step * s.A) == 0); - assert((w.instr_ASSERT_IS_ZERO_A * (s.A - 0)) == 0); - assert(w.X == ((((w.read_X_A * s.A) + (w.read_X_pc * s.pc)) + w.X_const) + (w.X_read_free * w.X_free_value))); - - field A_next = (((w.first_step_next * 0) + (w.reg_write_X_A * w.X)) + ((1 - (w.first_step_next + w.reg_write_X_A)) * s.A)); - field pc_next = ((1 - w.first_step_next) * (s.pc + 1)); - - return State { - pc: pc_next, - A: A_next, - }; -} - -def main(State s, private Witness[10000] w) -> State { - for i in 10000 { - s = onestep(s, w[i]); - } -}