diff --git a/Cargo.toml b/Cargo.toml index f1b6565b..17890512 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,61 +1,52 @@ -[package] -name = "plonk_verifier" -version = "0.1.0" -edition = "2021" - -[dependencies] -itertools = "0.10.3" -lazy_static = "1.4.0" -num-bigint = "0.4" -num-traits = "0.2" - -rand = "0.8" -rand_chacha = "0.3.1" -serde = { version = "1.0", features = ["derive"] } -serde_json = "1.0" -hex = "0.4.3" -ark-std = { version = "0.3", features = ["print-trace"] } - -halo2_curves = { git = "https://github.com/privacy-scaling-explorations/halo2curves", tag = "0.2.1", package = "halo2curves" } - -# system_halo2 -halo2_proofs = { git = "https://github.com/privacy-scaling-explorations/halo2", tag = "v2022_09_10", optional = true } - -# loader_evm -ethereum_types = { package = "ethereum-types", version = "0.13.1", default-features = false, features = ["std"], optional = true } -sha3 = { version = "0.10.1", optional = true } -foundry_evm = { git = "https://github.com/jonathanpwang/foundry", package = "foundry-evm", branch = "fix/pin-revm-to-rev", optional = true } - -# loader_halo2 -halo2_base = { git = "ssh://github.com/axiom-crypto/halo2-lib-working.git", package = "halo2_base", default-features = false, optional = true } -halo2_ecc = { git = "ssh://github.com/axiom-crypto/halo2-lib-working.git", package = "halo2_ecc", default-features = false, optional = true } -poseidon = { git = "https://github.com/privacy-scaling-explorations/poseidon", branch = "padding", optional = true } - -[dev-dependencies] -paste = "1.0.7" - -# loader_evm -crossterm = { version = "0.22.1" } -tui = { version = "0.16.0", default-features = false, features = ["crossterm"] } - -# zkevm -zkevm_circuit_benchmarks = {git = "https://github.com/privacy-scaling-explorations/zkevm-circuits.git", package = "circuit-benchmarks", features = ["benches"] } -zkevm_circuits = {git = "https://github.com/privacy-scaling-explorations/zkevm-circuits.git", package = "zkevm-circuits" } - -[features] -default = ["loader_evm", "loader_halo2", "system_halo2"] -loader_evm = ["dep:ethereum_types", "dep:sha3", "dep:foundry_evm"] -loader_halo2 = ["dep:halo2_proofs", "dep:halo2_base", "halo2_ecc", "dep:poseidon"] -system_halo2 = ["dep:halo2_proofs"] -sanity_check = [] - -[patch."https://github.com/privacy-scaling-explorations/halo2"] -halo2_proofs = { git = "https://github.com/axiom-crypto/halo2.git", branch = "axiom/add-serialization", package = "halo2_proofs" } - -[[example]] -name = "evm-verifier" -required-features = ["loader_evm", "system_halo2"] - -[[example]] -name = "evm-verifier-with-accumulator" -required-features = ["loader_halo2", "loader_evm", "system_halo2"] +[workspace] +members = [ + "snark-verifier", + "snark-verifier-sdk", +] + +[profile.dev] +opt-level = 3 + +# Local "release" mode, more optimized than dev but faster to compile than release +[profile.local] +inherits = "dev" +opt-level = 3 +# Set this to 1 or 2 to get more useful backtraces +debug = true +debug-assertions = false +panic = 'unwind' +# better recompile times +incremental = true +lto = "thin" +codegen-units = 16 + +[profile.release] +opt-level = 3 +debug = false +debug-assertions = false +lto = "fat" +# codegen-units = 1 +panic = "abort" +incremental = false + +# For performance profiling +[profile.flamegraph] +inherits = "release" +debug = true + +[patch."ssh://github.com/axiom-crypto/axiom-core-working.git"] +halo2-base = { path = "../axiom-core-working/halo2-lib/halo2-base" } +halo2-ecc = { path = "../axiom-core-working/halo2-lib/halo2-ecc" } + 
+[patch."https://github.com/axiom-crypto/halo2.git"] +halo2_proofs = { path = "../halo2/halo2_proofs" } +halo2curves = { path = "../halo2/arithmetic/curves" } +poseidon = { path = "../halo2/primitives/poseidon" } + +# patch for now because PSE/halo2 has not yet updated halo2curves version, unnecessary if halo2_proofs is using latest halo2curves with Fq12 public +[patch."https://github.com/privacy-scaling-explorations/halo2curves.git"] +halo2curves = { path = "../halo2/arithmetic/curves" } + +# patch just because we cannot patch the same repo with different tag: serialization is already in latest PSE/halo2 but not in v2022_10_22 +[patch."https://github.com/privacy-scaling-explorations/halo2.git"] +halo2_proofs = { git = "https://github.com/axiom-crypto/halo2.git", branch = "axiom/add-serialization" } \ No newline at end of file diff --git a/README.md b/README.md index f9b83a4c..bcd16c74 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,3 @@ # PLONK Verifier Generic PLONK verifier. - -## SRS - -Note that if aggregating snarks with different `K` params size, you should generate the largest srs necessarily and then `downgrade` to the smaller param sizes so that the first two points are the same for all srs files. diff --git a/examples/evm-verifier-with-accumulator.rs b/examples/evm-verifier-with-accumulator.rs deleted file mode 100644 index 77d58f26..00000000 --- a/examples/evm-verifier-with-accumulator.rs +++ /dev/null @@ -1,367 +0,0 @@ -use application::StandardPlonk; -use ark_std::{end_timer, start_timer}; -use ethereum_types::Address; -use foundry_evm::executor::{fork::MultiFork, Backend, ExecutorBuilder}; -use halo2_curves::bn256::{Bn256, Fq, Fr, G1Affine}; -use halo2_proofs::{ - dev::MockProver, - plonk::{create_proof, verify_proof, Circuit, ProvingKey, VerifyingKey}, - poly::{ - commitment::{Params, ParamsProver}, - kzg::{ - commitment::{KZGCommitmentScheme, ParamsKZG}, - multiopen::{ProverGWC, VerifierGWC}, - strategy::AccumulatorStrategy, - }, - VerificationStrategy, - }, - transcript::{ - Blake2bRead, Blake2bWrite, Challenge255, EncodedChallenge, TranscriptReadBuffer, - TranscriptWriterBuffer, - }, -}; -use itertools::Itertools; -use plonk_verifier::{ - loader::{ - evm::{encode_calldata, EvmLoader}, - native::NativeLoader, - }, - pcs::kzg::{Gwc19, Kzg, LimbsEncoding}, - system::halo2::{ - aggregation::{ - self, create_snark_shplonk, gen_pk, gen_srs, write_bytes, AggregationCircuit, Snark, - TargetCircuit, - }, - compile, - transcript::evm::EvmTranscript, - Config, - }, - verifier::{self, PlonkVerifier}, -}; -use rand::rngs::OsRng; -use std::{fs, io::Cursor, rc::Rc}; - -const LIMBS: usize = 3; -const BITS: usize = 88; - -type Pcs = Kzg; -// type As = KzgAs; -type Plonk = verifier::Plonk>; - -mod application { - use halo2_curves::bn256::Fr; - use halo2_proofs::{ - circuit::{Layouter, SimpleFloorPlanner, Value}, - plonk::{Advice, Circuit, Column, ConstraintSystem, Error, Fixed, Instance}, - poly::Rotation, - }; - use rand::RngCore; - - #[derive(Clone, Copy)] - pub struct StandardPlonkConfig { - a: Column, - b: Column, - c: Column, - q_a: Column, - q_b: Column, - q_c: Column, - q_ab: Column, - constant: Column, - #[allow(dead_code)] - instance: Column, - } - - impl StandardPlonkConfig { - fn configure(meta: &mut ConstraintSystem) -> Self { - let [a, b, c] = [(); 3].map(|_| meta.advice_column()); - let [q_a, q_b, q_c, q_ab, constant] = [(); 5].map(|_| meta.fixed_column()); - let instance = meta.instance_column(); - - [a, b, c].map(|column| meta.enable_equality(column)); - - 
meta.create_gate( - "q_a·a + q_b·b + q_c·c + q_ab·a·b + constant + instance = 0", - |meta| { - let [a, b, c] = - [a, b, c].map(|column| meta.query_advice(column, Rotation::cur())); - let [q_a, q_b, q_c, q_ab, constant] = [q_a, q_b, q_c, q_ab, constant] - .map(|column| meta.query_fixed(column, Rotation::cur())); - let instance = meta.query_instance(instance, Rotation::cur()); - Some( - q_a * a.clone() - + q_b * b.clone() - + q_c * c - + q_ab * a * b - + constant - + instance, - ) - }, - ); - - StandardPlonkConfig { a, b, c, q_a, q_b, q_c, q_ab, constant, instance } - } - } - - #[derive(Clone, Default)] - pub struct StandardPlonk(pub Fr); - - impl StandardPlonk { - pub fn rand(mut rng: R) -> Self { - Self(Fr::from(rng.next_u32() as u64)) - } - - pub fn num_instance() -> Vec { - vec![1] - } - - pub fn instances(&self) -> Vec> { - vec![vec![self.0]] - } - } - - impl Circuit for StandardPlonk { - type Config = StandardPlonkConfig; - type FloorPlanner = SimpleFloorPlanner; - - fn without_witnesses(&self) -> Self { - Self::default() - } - - fn configure(meta: &mut ConstraintSystem) -> Self::Config { - meta.set_minimum_degree(4); - StandardPlonkConfig::configure(meta) - } - - fn synthesize( - &self, - config: Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), Error> { - layouter.assign_region( - || "", - |mut region| { - region.assign_advice(|| "", config.a, 0, || Value::known(self.0))?; - region.assign_fixed(|| "", config.q_a, 0, || Value::known(-Fr::one()))?; - - region.assign_advice(|| "", config.a, 1, || Value::known(-Fr::from(5)))?; - for (idx, column) in (1..).zip([ - config.q_a, - config.q_b, - config.q_c, - config.q_ab, - config.constant, - ]) { - region.assign_fixed(|| "", column, 1, || Value::known(Fr::from(idx)))?; - } - - let a = region.assign_advice(|| "", config.a, 2, || Value::known(Fr::one()))?; - a.copy_advice(|| "", &mut region, config.b, 3)?; - a.copy_advice(|| "", &mut region, config.c, 4)?; - - Ok(()) - }, - ) - } - } -} - -fn gen_proof< - C: Circuit + Clone, - E: EncodedChallenge, - TR: TranscriptReadBuffer>, G1Affine, E>, - TW: TranscriptWriterBuffer, G1Affine, E>, ->( - params: &ParamsKZG, - pk: &ProvingKey, - circuit: C, - instances: Vec>, -) -> Vec { - MockProver::run(params.k(), &circuit, instances.clone()).unwrap().assert_satisfied(); - // For testing purposes: Native verify - // Uncomment to test if evm verifier fails silently - /*{ - let proof = { - let mut transcript = Blake2bWrite::init(Vec::new()); - create_proof::< - KZGCommitmentScheme, - ProverGWC<_>, - Challenge255, - _, - Blake2bWrite, G1Affine, Challenge255>, - _, - >( - params, - pk, - &[circuit.clone()], - &[&[instances[0].as_slice()]], - OsRng, - &mut transcript, - ) - .unwrap(); - transcript.finalize() - }; - let svk = params.get_g()[0].into(); - let dk = (params.g2(), params.s_g2()).into(); - let protocol = compile( - params, - pk.get_vk(), - Config::kzg(aggregation::KZG_QUERY_INSTANCE) - .with_num_instance(vec![instances[0].len()]) - .with_accumulator_indices(aggregation::AggregationCircuit::accumulator_indices()), - ); - let mut transcript = Blake2bRead::<_, G1Affine, _>::init(proof.as_slice()); - let instances = &[instances[0].to_vec()]; - let proof = Plonk::read_proof(&svk, &protocol, instances, &mut transcript).unwrap(); - assert!(Plonk::verify(&svk, &dk, &protocol, instances, &proof).unwrap()); - }*/ - - let instances = instances.iter().map(|instances| instances.as_slice()).collect_vec(); - let proof = { - let mut transcript = TW::init(Vec::new()); - create_proof::, ProverGWC<_>, _, _, 
TW, _>( - params, - pk, - &[circuit], - &[instances.as_slice()], - OsRng, - &mut transcript, - ) - .unwrap(); - transcript.finalize() - }; - - let accept = { - let mut transcript = TR::init(Cursor::new(proof.clone())); - VerificationStrategy::<_, VerifierGWC<_>>::finalize( - verify_proof::<_, VerifierGWC<_>, _, TR, _>( - params.verifier_params(), - pk.get_vk(), - AccumulatorStrategy::new(params.verifier_params()), - &[instances.as_slice()], - &mut transcript, - ) - .unwrap(), - ) - }; - assert!(accept); - - proof -} - -fn gen_aggregation_evm_verifier( - params: &ParamsKZG, - vk: &VerifyingKey, - num_instance: Vec, - accumulator_indices: Vec<(usize, usize)>, -) -> Vec { - let svk = params.get_g()[0].into(); - let dk = (params.g2(), params.s_g2()).into(); - let protocol = compile( - params, - vk, - Config::kzg(aggregation::KZG_QUERY_INSTANCE) - .with_num_instance(num_instance.clone()) - .with_accumulator_indices(accumulator_indices), - ); - - let loader = EvmLoader::new::(); - let mut transcript = EvmTranscript::<_, Rc, _, _>::new(loader.clone()); - - let instances = transcript.load_instances(num_instance); - let proof = Plonk::read_proof(&svk, &protocol, &instances, &mut transcript).unwrap(); - Plonk::verify(&svk, &dk, &protocol, &instances, &proof).unwrap(); - - loader.deployment_code() -} - -fn evm_verify(deployment_code: Vec, instances: Vec>, proof: Vec) { - let calldata = encode_calldata(&instances, &proof); - fs::write("./data/verifier_calldata.dat", hex::encode(&calldata)).unwrap(); - let success = { - let mut evm = ExecutorBuilder::default() - .with_gas_limit(u64::MAX.into()) - .build(Backend::new(MultiFork::new().0, None)); - - let caller = Address::from_low_u64_be(0xfe); - let verifier = evm.deploy(caller, deployment_code.into(), 0.into(), None).unwrap(); - dbg!(verifier.gas); - let verifier = verifier.address; - let result = evm.call_raw(caller, verifier, calldata.into(), 0.into()).unwrap(); - - dbg!(result.gas); - - !result.reverted - }; - assert!(success); -} - -pub fn load_verify_circuit_degree() -> u32 { - let path = "./configs/verify_circuit.config"; - let params_str = - std::fs::read_to_string(path).expect(format!("{} file should exist", path).as_str()); - let params: plonk_verifier::system::halo2::Halo2VerifierCircuitConfigParams = - serde_json::from_str(params_str.as_str()).unwrap(); - params.degree -} - -impl TargetCircuit for StandardPlonk { - const N_PROOFS: usize = 1; - const NAME: &'static str = "standard_plonk"; - - type Circuit = Self; -} - -fn main() { - let k = load_verify_circuit_degree(); - let params = gen_srs(k); - - let params_app = { - let mut params = params.clone(); - params.downsize(8); - params - }; - let app_circuit = StandardPlonk::rand(OsRng); - let snark = create_snark_shplonk::( - ¶ms_app, - vec![app_circuit.clone()], - vec![vec![vec![app_circuit.0]]], - None, - ); - let snarks = vec![snark]; - - let agg_circuit = AggregationCircuit::new(¶ms, snarks, true); - let pk = gen_pk(¶ms, &agg_circuit, "standard_plonk_agg_circuit"); - - let deploy_time = start_timer!(|| "generate aggregation evm verifier code"); - let deployment_code = gen_aggregation_evm_verifier( - ¶ms, - pk.get_vk(), - agg_circuit.num_instance(), - AggregationCircuit::accumulator_indices(), - ); - end_timer!(deploy_time); - fs::write("./data/verifier_bytecode.dat", hex::encode(&deployment_code)).unwrap(); - - // use different input snarks to test instances etc - let app_circuit = StandardPlonk::rand(OsRng); - let snark = create_snark_shplonk::( - ¶ms_app, - vec![app_circuit.clone()], - 
vec![vec![vec![app_circuit.0]]], - None, - ); - let snarks = vec![snark]; - let agg_circuit = AggregationCircuit::new(¶ms, snarks, true); - let proof_time = start_timer!(|| "create agg_circuit proof"); - let proof = gen_proof::<_, _, EvmTranscript, EvmTranscript>( - ¶ms, - &pk, - agg_circuit.clone(), - agg_circuit.instances(), - ); - end_timer!(proof_time); - - let verify_time = start_timer!(|| "on-chain verification"); - evm_verify(deployment_code, agg_circuit.instances(), proof); - end_timer!(verify_time); -} diff --git a/rust-toolchain b/rust-toolchain index 24b6d11f..51ab4759 100644 --- a/rust-toolchain +++ b/rust-toolchain @@ -1 +1 @@ -nightly-2022-08-23 +nightly-2022-10-28 \ No newline at end of file diff --git a/snark-verifier-sdk/Cargo.toml b/snark-verifier-sdk/Cargo.toml new file mode 100644 index 00000000..fd5c2344 --- /dev/null +++ b/snark-verifier-sdk/Cargo.toml @@ -0,0 +1,66 @@ +[package] +name = "snark-verifier-sdk" +version = "0.0.1" +edition = "2021" + +[dependencies] +itertools = "0.10.3" +lazy_static = "1.4.0" +num-bigint = "0.4.3" +num-integer = "0.1.45" +num-traits = "0.2.15" +rand = "0.8" +rand_chacha = "0.3.1" +hex = "0.4" +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +bincode = "1.3.3" +ark-std = { version = "0.3.0", features = ["print-trace"], optional = true } + +halo2-base = { git = "ssh://github.com/axiom-crypto/axiom-core-working.git", branch = "experiment/optimizations", default-features = false } +snark-verifier = { path = "../snark-verifier", default-features = false } + +# loader_evm +ethereum-types = { version = "0.14", default-features = false, features = ["std"], optional = true } +# sha3 = { version = "0.10", optional = true } +# revm = { version = "2.3.1", optional = true } +# bytes = { version = "1.2", optional = true } +# rlp = { version = "0.5", default-features = false, features = ["std"], optional = true } + +# zkevm benchmarks +zkevm-circuits = { git = "https://github.com/jonathanpwang/zkevm-circuits.git", branch = "bench-12-04", features = ["test"], optional = true } +bus-mapping = { git = "https://github.com/jonathanpwang/zkevm-circuits.git", branch = "bench-12-04", optional = true } +eth-types = { git = "https://github.com/jonathanpwang/zkevm-circuits.git", branch = "bench-12-04", optional = true } +mock = { git = "https://github.com/jonathanpwang/zkevm-circuits.git", branch = "bench-12-04", optional = true } + +[dev-dependencies] +ark-std = { version = "0.3.0", features = ["print-trace"] } +paste = "1.0.7" +pprof = { version = "0.11", features = ["criterion", "flamegraph"] } +criterion = "0.4" +criterion-macro = "0.4" +# loader_evm +crossterm = { version = "0.25" } +tui = { version = "0.19", default-features = false, features = ["crossterm"] } + +[features] +default = ["loader_evm", "loader_halo2", "zkevm", "halo2-pse", "halo2-base/jemallocator"] +display = ["snark-verifier/display"] +loader_evm = ["snark-verifier/loader_evm", "dep:ethereum-types"] +loader_halo2 = ["snark-verifier/loader_halo2"] +parallel = ["snark-verifier/parallel"] +# EXACTLY one of halo2-pse / halo2-axiom should always be turned on; not sure how to enforce this with Cargo +halo2-pse = ["snark-verifier/halo2-pse"] +halo2-axiom = ["snark-verifier/halo2-axiom"] + +zkevm = ["dep:zkevm-circuits", "dep:bus-mapping", "dep:mock", "dep:eth-types"] + +[[bench]] +name = "standard_plonk" +required-features = ["loader_halo2"] +harness = false + +[[bench]] +name = "zkevm" +required-features = ["loader_halo2", "zkevm", "halo2-pse", "halo2-base/jemallocator"] 
+harness = false \ No newline at end of file diff --git a/snark-verifier-sdk/benches/standard_plonk.rs b/snark-verifier-sdk/benches/standard_plonk.rs new file mode 100644 index 00000000..c72072eb --- /dev/null +++ b/snark-verifier-sdk/benches/standard_plonk.rs @@ -0,0 +1,242 @@ +use criterion::{criterion_group, criterion_main}; +use criterion::{BenchmarkId, Criterion}; +use pprof::criterion::{Output, PProfProfiler}; + +use ark_std::{end_timer, start_timer}; +use halo2_base::halo2_proofs; +use halo2_proofs::halo2curves as halo2_curves; +use halo2_proofs::{ + halo2curves::bn256::Bn256, + poly::{commitment::Params, kzg::commitment::ParamsKZG}, +}; +use rand::rngs::OsRng; +use rand::SeedableRng; +use rand_chacha::ChaCha20Rng; +use snark_verifier::loader::native::NativeLoader; +use snark_verifier_sdk::{ + gen_pk, + halo2::{ + aggregation::AggregationCircuit, gen_proof_shplonk, gen_snark_shplonk, PoseidonTranscript, + POSEIDON_SPEC, + }, + Snark, +}; + +mod application { + use super::halo2_curves::bn256::Fr; + use super::halo2_proofs::{ + circuit::{Layouter, SimpleFloorPlanner, Value}, + plonk::{Advice, Circuit, Column, ConstraintSystem, Error, Fixed, Instance}, + poly::Rotation, + }; + use rand::RngCore; + use snark_verifier_sdk::CircuitExt; + + #[derive(Clone, Copy)] + pub struct StandardPlonkConfig { + a: Column, + b: Column, + c: Column, + q_a: Column, + q_b: Column, + q_c: Column, + q_ab: Column, + constant: Column, + #[allow(dead_code)] + instance: Column, + } + + impl StandardPlonkConfig { + fn configure(meta: &mut ConstraintSystem) -> Self { + let [a, b, c] = [(); 3].map(|_| meta.advice_column()); + let [q_a, q_b, q_c, q_ab, constant] = [(); 5].map(|_| meta.fixed_column()); + let instance = meta.instance_column(); + + [a, b, c].map(|column| meta.enable_equality(column)); + + meta.create_gate( + "q_a·a + q_b·b + q_c·c + q_ab·a·b + constant + instance = 0", + |meta| { + let [a, b, c] = + [a, b, c].map(|column| meta.query_advice(column, Rotation::cur())); + let [q_a, q_b, q_c, q_ab, constant] = [q_a, q_b, q_c, q_ab, constant] + .map(|column| meta.query_fixed(column, Rotation::cur())); + let instance = meta.query_instance(instance, Rotation::cur()); + Some( + q_a * a.clone() + + q_b * b.clone() + + q_c * c + + q_ab * a * b + + constant + + instance, + ) + }, + ); + + StandardPlonkConfig { a, b, c, q_a, q_b, q_c, q_ab, constant, instance } + } + } + + #[derive(Clone, Default)] + pub struct StandardPlonk(Fr); + + impl StandardPlonk { + pub fn rand(mut rng: R) -> Self { + Self(Fr::from(rng.next_u32() as u64)) + } + } + + impl CircuitExt for StandardPlonk { + fn num_instance() -> Vec { + vec![1] + } + + fn instances(&self) -> Vec> { + vec![vec![self.0]] + } + } + + impl Circuit for StandardPlonk { + type Config = StandardPlonkConfig; + type FloorPlanner = SimpleFloorPlanner; + + fn without_witnesses(&self) -> Self { + Self::default() + } + + fn configure(meta: &mut ConstraintSystem) -> Self::Config { + meta.set_minimum_degree(4); + StandardPlonkConfig::configure(meta) + } + + fn synthesize( + &self, + config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + layouter.assign_region( + || "", + |mut region| { + #[cfg(feature = "halo2-pse")] + { + region.assign_advice(|| "", config.a, 0, || Value::known(self.0))?; + region.assign_fixed(|| "", config.q_a, 0, || Value::known(-Fr::one()))?; + + region.assign_advice( + || "", + config.a, + 1, + || Value::known(-Fr::from(5u64)), + )?; + for (idx, column) in (1..).zip([ + config.q_a, + config.q_b, + config.q_c, + config.q_ab, + 
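As a quick arithmetic sanity check on the gate defined above (a standalone sketch, not part of the bench): row 0 of `StandardPlonk` assigns `a = x` with `q_a = -1` while exposing `x` in the instance column, so the selector combination collapses to `-x + x = 0`.

```rust
use halo2_base::halo2_proofs::halo2curves::{bn256::Fr, group::ff::Field};

fn main() {
    // Row 0 of StandardPlonk: a = x, q_a = -1, public instance = x,
    // every other selector and advice cell is zero.
    let x = Fr::from(0xdead_beef_u64);
    let (a, b, c) = (x, Fr::zero(), Fr::zero());
    let (q_a, q_b, q_c, q_ab, constant) =
        (-Fr::one(), Fr::zero(), Fr::zero(), Fr::zero(), Fr::zero());
    let instance = x;

    // q_a·a + q_b·b + q_c·c + q_ab·a·b + constant + instance = 0
    let gate = q_a * a + q_b * b + q_c * c + q_ab * a * b + constant + instance;
    assert_eq!(gate, Fr::zero());
}
```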
config.constant, + ]) { + region.assign_fixed( + || "", + column, + 1, + || Value::known(Fr::from(idx as u64)), + )?; + } + + let a = + region.assign_advice(|| "", config.a, 2, || Value::known(Fr::one()))?; + a.copy_advice(|| "", &mut region, config.b, 3)?; + a.copy_advice(|| "", &mut region, config.c, 4)?; + } + #[cfg(feature = "halo2-axiom")] + { + region.assign_advice( + config.a, + 0, + Value::known(Assigned::Trivial(self.0)), + )?; + region.assign_fixed(config.q_a, 0, -Fr::one()); + + region.assign_advice(config.a, 1, Value::known(-Fr::from(5u64)))?; + for (idx, column) in (1..).zip([ + config.q_a, + config.q_b, + config.q_c, + config.q_ab, + config.constant, + ]) { + region.assign_fixed(column, 1, Fr::from(idx as u64)); + } + + let a = region.assign_advice(config.a, 2, Value::known(Fr::one()))?; + a.copy_advice(&mut region, config.b, 3); + a.copy_advice(&mut region, config.c, 4); + } + + Ok(()) + }, + ) + } + } +} + +fn gen_application_snark( + params: &ParamsKZG, + transcript: &mut PoseidonTranscript>, +) -> Snark { + let circuit = application::StandardPlonk::rand(OsRng); + + let pk = gen_pk(params, &circuit, None); + gen_snark_shplonk(params, &pk, circuit, transcript, &mut OsRng, None) +} + +fn bench(c: &mut Criterion) { + std::env::set_var("VERIFY_CONFIG", "./configs/example_evm_accumulator.config"); + let k = 21; + let params = halo2_base::utils::fs::gen_srs(k); + let params_app = { + let mut params = params.clone(); + params.downsize(8); + params + }; + + let mut transcript = + PoseidonTranscript::::from_spec(vec![], POSEIDON_SPEC.clone()); + let snarks = [(); 3].map(|_| gen_application_snark(¶ms_app, &mut transcript)); + + let start1 = start_timer!(|| "Create aggregation circuit"); + let mut rng = ChaCha20Rng::from_entropy(); + let agg_circuit = AggregationCircuit::new(¶ms, snarks, &mut transcript, &mut rng); + end_timer!(start1); + + let pk = gen_pk(¶ms, &agg_circuit, None); + + let mut group = c.benchmark_group("plonk-prover"); + group.sample_size(10); + group.bench_with_input( + BenchmarkId::new("standard-plonk-agg", k), + &(¶ms, &pk, &agg_circuit), + |b, &(params, pk, agg_circuit)| { + b.iter(|| { + let instances = agg_circuit.instances(); + gen_proof_shplonk( + params, + pk, + agg_circuit.clone(), + instances, + &mut transcript, + &mut rng, + None, + ) + }) + }, + ); + group.finish(); +} + +criterion_group! 
{ + name = benches; + config = Criterion::default().with_profiler(PProfProfiler::new(10, Output::Flamegraph(None))); + targets = bench +} +criterion_main!(benches); diff --git a/snark-verifier-sdk/benches/zkevm.rs b/snark-verifier-sdk/benches/zkevm.rs new file mode 100644 index 00000000..a35573b1 --- /dev/null +++ b/snark-verifier-sdk/benches/zkevm.rs @@ -0,0 +1,136 @@ +use ark_std::{end_timer, start_timer}; +use halo2_base::halo2_proofs; +use halo2_base::utils::fs::gen_srs; +use halo2_proofs::halo2curves::bn256::Fr; +use rand::SeedableRng; +use rand_chacha::ChaCha20Rng; +use snark_verifier::loader::native::NativeLoader; +use snark_verifier_sdk::{ + self, + evm::{ + evm_verify, gen_evm_proof_gwc, gen_evm_proof_shplonk, gen_evm_verifier_gwc, + gen_evm_verifier_shplonk, + }, + gen_pk, + halo2::{ + aggregation::load_verify_circuit_degree, aggregation::AggregationCircuit, gen_proof_gwc, + gen_proof_shplonk, gen_snark_shplonk, PoseidonTranscript, POSEIDON_SPEC, + }, +}; +use std::env::{set_var, var}; +use std::path::Path; + +use criterion::{criterion_group, criterion_main}; +use criterion::{BenchmarkId, Criterion}; +use pprof::criterion::{Output, PProfProfiler}; + +pub mod zkevm { + use super::Fr; + use bus_mapping::{circuit_input_builder::CircuitsParams, mock::BlockData}; + use eth_types::geth_types::GethData; + use mock::TestContext; + use zkevm_circuits::evm_circuit::{witness::block_convert, EvmCircuit}; + + pub fn test_circuit() -> EvmCircuit { + let empty_data: GethData = + TestContext::<0, 0>::new(None, |_| {}, |_, _| {}, |b, _| b).unwrap().into(); + + let mut builder = BlockData::new_from_geth_data_with_params( + empty_data.clone(), + CircuitsParams::default(), + ) + .new_circuit_input_builder(); + + builder.handle_block(&empty_data.eth_block, &empty_data.geth_traces).unwrap(); + + let block = block_convert(&builder.block, &builder.code_db).unwrap(); + + EvmCircuit::::new(block) + } +} + +fn bench(c: &mut Criterion) { + let mut rng = ChaCha20Rng::from_entropy(); + let mut transcript = + PoseidonTranscript::::from_spec(vec![], POSEIDON_SPEC.clone()); + + // === create zkevm evm circuit snark === + let k: u32 = var("DEGREE") + .unwrap_or_else(|_| { + set_var("DEGREE", "18"); + "18".to_owned() + }) + .parse() + .unwrap(); + let circuit = zkevm::test_circuit(); + let params_app = gen_srs(k); + let pk = gen_pk(¶ms_app, &circuit, Some(Path::new("data/zkevm_evm.pkey"))); + let snark = gen_snark_shplonk( + ¶ms_app, + &pk, + circuit, + &mut transcript, + &mut rng, + Some((Path::new("data/zkevm_evm.in"), Path::new("data/zkevm_evm.pf"))), + ); + let snarks = [snark]; + // === finished zkevm evm circuit === + + // === now to do aggregation === + set_var("VERIFY_CONFIG", "./configs/bench_zkevm.config"); + let k = load_verify_circuit_degree(); + let params = gen_srs(k); + + let start1 = start_timer!(|| "Create aggregation circuit"); + let agg_circuit = AggregationCircuit::new(¶ms, snarks, &mut transcript, &mut rng); + end_timer!(start1); + + let pk = gen_pk(¶ms, &agg_circuit, None); + + let mut group = c.benchmark_group("plonk-prover"); + group.sample_size(10); + group.bench_with_input( + BenchmarkId::new("zkevm-evm-agg", k), + &(¶ms, &pk, &agg_circuit), + |b, &(params, pk, agg_circuit)| { + b.iter(|| { + let instances = agg_circuit.instances(); + gen_proof_shplonk( + params, + pk, + agg_circuit.clone(), + instances, + &mut transcript, + &mut rng, + None, + ); + }) + }, + ); + group.finish(); + + #[cfg(feature = "loader_evm")] + { + let deployment_code = + gen_evm_verifier_shplonk::(¶ms, pk.get_vk(), 
None); + + let start2 = start_timer!(|| "Create EVM proof"); + let proof = gen_evm_proof_shplonk( + ¶ms, + &pk, + agg_circuit.clone(), + agg_circuit.instances(), + &mut rng, + ); + end_timer!(start2); + + evm_verify(deployment_code, agg_circuit.instances(), proof); + } +} + +criterion_group! { + name = benches; + config = Criterion::default().with_profiler(PProfProfiler::new(10, Output::Flamegraph(None))); + targets = bench +} +criterion_main!(benches); diff --git a/configs/verify_circuit.config b/snark-verifier-sdk/configs/bench_zkevm.config similarity index 50% rename from configs/verify_circuit.config rename to snark-verifier-sdk/configs/bench_zkevm.config index 6f70df6a..1cda14ab 100644 --- a/configs/verify_circuit.config +++ b/snark-verifier-sdk/configs/bench_zkevm.config @@ -1 +1 @@ -{"strategy":"Simple","degree":23,"num_advice":7,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":22,"limb_bits":88,"num_limbs":3} \ No newline at end of file +{"strategy":"Simple","degree":23,"num_advice":5,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":22,"limb_bits":88,"num_limbs":3} \ No newline at end of file diff --git a/snark-verifier-sdk/configs/example_evm_accumulator.config b/snark-verifier-sdk/configs/example_evm_accumulator.config new file mode 100644 index 00000000..fcda49a0 --- /dev/null +++ b/snark-verifier-sdk/configs/example_evm_accumulator.config @@ -0,0 +1 @@ +{"strategy":"Simple","degree":21,"num_advice":5,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":20,"limb_bits":88,"num_limbs":3} \ No newline at end of file diff --git a/snark-verifier-sdk/configs/verify_circuit.config b/snark-verifier-sdk/configs/verify_circuit.config new file mode 100644 index 00000000..e65b2b52 --- /dev/null +++ b/snark-verifier-sdk/configs/verify_circuit.config @@ -0,0 +1 @@ +{"strategy":"Simple","degree":21,"num_advice":4,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":20,"limb_bits":88,"num_limbs":3} \ No newline at end of file diff --git a/snark-verifier-sdk/src/evm.rs b/snark-verifier-sdk/src/evm.rs new file mode 100644 index 00000000..f9d6215e --- /dev/null +++ b/snark-verifier-sdk/src/evm.rs @@ -0,0 +1,187 @@ +use super::{CircuitExt, Plonk}; +use ethereum_types::Address; +use halo2_base::halo2_proofs::{ + dev::MockProver, + halo2curves::bn256::{Bn256, Fq, Fr, G1Affine}, + plonk::{create_proof, verify_proof, Circuit, ProvingKey, VerifyingKey}, + poly::{ + commitment::{Params, ParamsProver, Prover, Verifier}, + kzg::{ + commitment::{KZGCommitmentScheme, ParamsKZG}, + msm::DualMSM, + multiopen::{ProverGWC, ProverSHPLONK, VerifierGWC, VerifierSHPLONK}, + strategy::{AccumulatorStrategy, GuardKZG}, + }, + VerificationStrategy, + }, + transcript::{TranscriptReadBuffer, TranscriptWriterBuffer}, +}; +use itertools::Itertools; +use rand::Rng; +use snark_verifier::{ + loader::evm::{compile_yul, encode_calldata, EvmLoader, ExecutorBuilder}, + pcs::{ + kzg::{Bdfg21, Gwc19, Kzg, KzgAccumulator, KzgDecidingKey, KzgSuccinctVerifyingKey}, + Decider, MultiOpenScheme, PolynomialCommitmentScheme, + }, + system::halo2::{compile, transcript::evm::EvmTranscript, Config}, + verifier::PlonkVerifier, +}; +use std::{fs, io, path::Path, rc::Rc}; + +/// Generates a proof for evm verification using either SHPLONK or GWC proving method. Uses Keccak for Fiat-Shamir. 
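The doc comment above introduces the EVM-facing helpers. A usage sketch tying them together, mirroring the `loader_evm` block of the zkevm bench (the wrapper function name and the calldata path are illustrative): build the Yul/bytecode verifier once from the verifying key, prove with the Keccak-based EVM transcript, then check the proof in a local EVM and/or dump calldata.

```rust
use std::path::Path;

use halo2_base::halo2_proofs::{
    halo2curves::bn256::{Bn256, G1Affine},
    plonk::ProvingKey,
    poly::kzg::commitment::ParamsKZG,
};
use snark_verifier_sdk::{
    evm::{evm_verify, gen_evm_proof_shplonk, gen_evm_verifier_shplonk, write_calldata},
    halo2::aggregation::AggregationCircuit,
};

// Illustrative driver around the helpers defined in this module.
fn verify_on_evm(
    params: &ParamsKZG<Bn256>,
    pk: &ProvingKey<G1Affine>,
    agg_circuit: AggregationCircuit,
    rng: &mut impl rand::Rng,
) {
    // Compiles the Yul verifier to bytecode; passing `Some(path)` would also write the Yul source.
    let deployment_code =
        gen_evm_verifier_shplonk::<AggregationCircuit>(params, pk.get_vk(), None);

    let instances = agg_circuit.instances();
    let proof = gen_evm_proof_shplonk(params, pk, agg_circuit, instances.clone(), rng);

    // Deploys the verifier in a local EVM and asserts the call does not revert.
    evm_verify(deployment_code, instances.clone(), proof.clone());

    // Alternatively, persist ABI-encoded calldata for an on-chain call.
    write_calldata(&instances, &proof, Path::new("data/verifier_calldata.dat")).unwrap();
}
```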
+pub fn gen_evm_proof<'params, C, P, V>( + params: &'params ParamsKZG, + pk: &'params ProvingKey, + circuit: C, + instances: Vec>, + rng: &mut impl Rng, +) -> Vec +where + C: Circuit, + P: Prover<'params, KZGCommitmentScheme>, + V: Verifier< + 'params, + KZGCommitmentScheme, + Guard = GuardKZG<'params, Bn256>, + MSMAccumulator = DualMSM<'params, Bn256>, + >, +{ + #[cfg(debug_assertions)] + { + MockProver::run(params.k(), &circuit, instances.clone()).unwrap().assert_satisfied(); + } + + let instances = instances.iter().map(|instances| instances.as_slice()).collect_vec(); + let proof = { + let mut transcript = TranscriptWriterBuffer::<_, G1Affine, _>::init(Vec::new()); + create_proof::, P, _, _, EvmTranscript<_, _, _, _>, _>( + params, + pk, + &[circuit], + &[instances.as_slice()], + rng, + &mut transcript, + ) + .unwrap(); + transcript.finalize() + }; + + let accept = { + let mut transcript = TranscriptReadBuffer::<_, G1Affine, _>::init(proof.as_slice()); + VerificationStrategy::<_, V>::finalize( + verify_proof::<_, V, _, EvmTranscript<_, _, _, _>, _>( + params.verifier_params(), + pk.get_vk(), + AccumulatorStrategy::new(params.verifier_params()), + &[instances.as_slice()], + &mut transcript, + ) + .unwrap(), + ) + }; + assert!(accept); + + proof +} + +pub fn gen_evm_proof_gwc<'params, C: Circuit>( + params: &'params ParamsKZG, + pk: &'params ProvingKey, + circuit: C, + instances: Vec>, + rng: &mut impl Rng, +) -> Vec { + gen_evm_proof::, VerifierGWC<_>>(params, pk, circuit, instances, rng) +} + +pub fn gen_evm_proof_shplonk<'params, C: Circuit>( + params: &'params ParamsKZG, + pk: &'params ProvingKey, + circuit: C, + instances: Vec>, + rng: &mut impl Rng, +) -> Vec { + gen_evm_proof::, VerifierSHPLONK<_>>(params, pk, circuit, instances, rng) +} + +pub fn gen_evm_verifier( + params: &ParamsKZG, + vk: &VerifyingKey, + path: Option<&Path>, +) -> Vec +where + C: CircuitExt, + PCS: PolynomialCommitmentScheme< + G1Affine, + Rc, + Accumulator = KzgAccumulator>, + > + MultiOpenScheme< + G1Affine, + Rc, + SuccinctVerifyingKey = KzgSuccinctVerifyingKey, + > + Decider, DecidingKey = KzgDecidingKey>, +{ + let svk = params.get_g()[0].into(); + let dk = (params.g2(), params.s_g2()).into(); + let protocol = compile( + params, + vk, + Config::kzg() + .with_num_instance(C::num_instance()) + .with_accumulator_indices(C::accumulator_indices()), + ); + + let loader = EvmLoader::new::(); + let protocol = protocol.loaded(&loader); + let mut transcript = EvmTranscript::<_, Rc, _, _>::new(&loader); + + let instances = transcript.load_instances(C::num_instance()); + let proof = Plonk::::read_proof(&svk, &protocol, &instances, &mut transcript).unwrap(); + Plonk::::verify(&svk, &dk, &protocol, &instances, &proof).unwrap(); + + let yul_code = loader.yul_code(); + let byte_code = compile_yul(&yul_code); + if let Some(path) = path { + path.parent().and_then(|dir| fs::create_dir_all(dir).ok()).unwrap(); + fs::write(path, yul_code).unwrap(); + } + byte_code +} + +pub fn gen_evm_verifier_gwc>( + params: &ParamsKZG, + vk: &VerifyingKey, + path: Option<&Path>, +) -> Vec { + gen_evm_verifier::>(params, vk, path) +} + +pub fn gen_evm_verifier_shplonk>( + params: &ParamsKZG, + vk: &VerifyingKey, + path: Option<&Path>, +) -> Vec { + gen_evm_verifier::>(params, vk, path) +} + +pub fn evm_verify(deployment_code: Vec, instances: Vec>, proof: Vec) { + let calldata = encode_calldata(&instances, &proof); + let success = { + let mut evm = ExecutorBuilder::default().with_gas_limit(u64::MAX.into()).build(); + + let caller = 
Address::from_low_u64_be(0xfe); + let verifier = evm.deploy(caller, deployment_code.into(), 0.into()).address.unwrap(); + let result = evm.call_raw(caller, verifier, calldata.into(), 0.into()); + + dbg!(result.gas_used); + + !result.reverted + }; + assert!(success); +} + +pub fn write_calldata(instances: &[Vec], proof: &[u8], path: &Path) -> io::Result<()> { + let calldata = encode_calldata(instances, proof); + fs::write(path, hex::encode(calldata)) +} diff --git a/snark-verifier-sdk/src/halo2.rs b/snark-verifier-sdk/src/halo2.rs new file mode 100644 index 00000000..b453b67e --- /dev/null +++ b/snark-verifier-sdk/src/halo2.rs @@ -0,0 +1,330 @@ +use super::{read_instances, write_instances, CircuitExt, Snark, SnarkWitness}; +#[cfg(feature = "display")] +use ark_std::{end_timer, start_timer}; +use halo2_base::{halo2_proofs, poseidon::Spec}; +use halo2_proofs::{ + circuit::Layouter, + dev::MockProver, + halo2curves::{ + bn256::{Bn256, Fr, G1Affine}, + group::ff::Field, + }, + plonk::{ + create_proof, keygen_vk, verify_proof, Circuit, ConstraintSystem, Error, ProvingKey, + VerifyingKey, + }, + poly::{ + commitment::{Params, ParamsProver, Prover, Verifier}, + kzg::{ + commitment::{KZGCommitmentScheme, ParamsKZG}, + msm::DualMSM, + multiopen::{ProverGWC, ProverSHPLONK, VerifierGWC, VerifierSHPLONK}, + strategy::{AccumulatorStrategy, GuardKZG}, + }, + VerificationStrategy, + }, +}; +use itertools::Itertools; +use lazy_static::lazy_static; +use rand::Rng; +use snark_verifier::{ + cost::CostEstimation, + loader::native::NativeLoader, + pcs::{self, MultiOpenScheme}, + system::halo2::{compile, Config}, + util::transcript::TranscriptWrite, + verifier::PlonkProof, +}; +use std::{fs, iter, marker::PhantomData, path::Path}; + +pub mod aggregation; + +// Poseidon parameters +const T: usize = 5; +const RATE: usize = 4; +const R_F: usize = 8; +const R_P: usize = 60; + +pub type PoseidonTranscript = + snark_verifier::system::halo2::transcript::halo2::PoseidonTranscript< + G1Affine, + L, + S, + T, + RATE, + R_F, + R_P, + >; + +lazy_static! { + pub static ref POSEIDON_SPEC: Spec = Spec::new(R_F, R_P); +} + +/// Generates a native proof using either SHPLONK or GWC proving method. Uses Poseidon for Fiat-Shamir. +/// +/// Caches the instances and proof if `path` is specified. 
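A usage sketch of the caching behaviour described here, in the style of the zkevm bench (the helper name and the `data/` file names are illustrative): when `path` is supplied, the prover first compares the requested public instances against the cached ones and reuses the proof on disk if they match; otherwise it proves and then writes both files.

```rust
use std::path::Path;

use halo2_base::halo2_proofs::{
    halo2curves::bn256::{Bn256, Fr, G1Affine},
    plonk::ProvingKey,
    poly::kzg::commitment::ParamsKZG,
};
use snark_verifier::loader::native::NativeLoader;
use snark_verifier_sdk::{
    halo2::{gen_snark_shplonk, PoseidonTranscript, POSEIDON_SPEC},
    CircuitExt, Snark,
};

// Illustrative helper: prove `circuit` once and cache the instances + proof under data/.
fn snark_with_cache<C: CircuitExt<Fr>>(
    params: &ParamsKZG<Bn256>,
    pk: &ProvingKey<G1Affine>,
    circuit: C,
    rng: &mut impl rand::Rng,
) -> Snark {
    let mut transcript =
        PoseidonTranscript::<NativeLoader, Vec<u8>>::from_spec(vec![], POSEIDON_SPEC.clone());
    gen_snark_shplonk(
        params,
        pk,
        circuit,
        &mut transcript,
        rng,
        // First run proves and writes both files; later runs whose instances match the
        // cached ones read the proof back instead of re-proving.
        Some((Path::new("data/app.in"), Path::new("data/app.pf"))),
    )
}
```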
+pub fn gen_proof<'params, C, P, V>( + params: &'params ParamsKZG, + pk: &'params ProvingKey, + circuit: C, + instances: Vec>, + transcript: &mut PoseidonTranscript>, + rng: &mut impl Rng, + path: Option<(&Path, &Path)>, +) -> Vec +where + C: Circuit, + P: Prover<'params, KZGCommitmentScheme>, + V: Verifier< + 'params, + KZGCommitmentScheme, + Guard = GuardKZG<'params, Bn256>, + MSMAccumulator = DualMSM<'params, Bn256>, + >, +{ + #[cfg(debug_assertions)] + { + MockProver::run(params.k(), &circuit, instances.clone()).unwrap().assert_satisfied(); + } + + let mut proof: Option> = None; + + if let Some((instance_path, proof_path)) = path { + let cached_instances = read_instances(instance_path); + if matches!(cached_instances, Ok(tmp) if tmp == instances) && proof_path.exists() { + #[cfg(feature = "display")] + let read_time = start_timer!(|| format!("Reading proof from {proof_path:?}")); + + proof = Some(fs::read(proof_path).unwrap()); + + #[cfg(feature = "display")] + end_timer!(read_time); + } + } + + let instances = instances.iter().map(Vec::as_slice).collect_vec(); + + let proof = proof.unwrap_or_else(|| { + #[cfg(feature = "display")] + let proof_time = start_timer!(|| "Create proof"); + + transcript.clear(); + create_proof::<_, P, _, _, _, _>(params, pk, &[circuit], &[&instances], rng, transcript) + .unwrap(); + let proof = transcript.stream_mut().split_off(0); + + #[cfg(feature = "display")] + end_timer!(proof_time); + + if let Some((instance_path, proof_path)) = path { + write_instances(&instances, instance_path); + fs::write(proof_path, &proof).unwrap(); + } + proof + }); + + debug_assert!({ + let mut transcript = PoseidonTranscript::::new(proof.as_slice()); + VerificationStrategy::<_, V>::finalize( + verify_proof::<_, V, _, _, _>( + params.verifier_params(), + pk.get_vk(), + AccumulatorStrategy::new(params.verifier_params()), + &[instances.as_slice()], + &mut transcript, + ) + .unwrap(), + ) + }); + + proof +} + +/// Generates a native proof using original Plonk (GWC '19) multi-open scheme. Uses Poseidon for Fiat-Shamir. +/// +/// Caches the instances and proof if `path` is specified. +pub fn gen_proof_gwc>( + params: &ParamsKZG, + pk: &ProvingKey, + circuit: C, + instances: Vec>, + transcript: &mut PoseidonTranscript>, + rng: &mut impl Rng, + path: Option<(&Path, &Path)>, +) -> Vec { + gen_proof::, VerifierGWC<_>>( + params, pk, circuit, instances, transcript, rng, path, + ) +} + +/// Generates a native proof using SHPLONK multi-open scheme. Uses Poseidon for Fiat-Shamir. +/// +/// Caches the instances and proof if `path` is specified. +pub fn gen_proof_shplonk>( + params: &ParamsKZG, + pk: &ProvingKey, + circuit: C, + instances: Vec>, + transcript: &mut PoseidonTranscript>, + rng: &mut impl Rng, + path: Option<(&Path, &Path)>, +) -> Vec { + gen_proof::, VerifierSHPLONK<_>>( + params, pk, circuit, instances, transcript, rng, path, + ) +} + +/// Generates a SNARK using either SHPLONK or GWC multi-open scheme. Uses Poseidon for Fiat-Shamir. +/// +/// Caches the instances and proof if `path` is specified. 
+pub fn gen_snark<'params, ConcreteCircuit, P, V>( + params: &'params ParamsKZG, + pk: &'params ProvingKey, + circuit: ConcreteCircuit, + transcript: &mut PoseidonTranscript>, + rng: &mut impl Rng, + path: Option<(&Path, &Path)>, +) -> Snark +where + ConcreteCircuit: CircuitExt, + P: Prover<'params, KZGCommitmentScheme>, + V: Verifier< + 'params, + KZGCommitmentScheme, + Guard = GuardKZG<'params, Bn256>, + MSMAccumulator = DualMSM<'params, Bn256>, + >, +{ + let protocol = compile( + params, + pk.get_vk(), + Config::kzg() + .with_num_instance(ConcreteCircuit::num_instance()) + .with_accumulator_indices(ConcreteCircuit::accumulator_indices()), + ); + + let instances = circuit.instances(); + let proof = gen_proof::( + params, + pk, + circuit, + instances.clone(), + transcript, + rng, + path, + ); + + Snark::new(protocol, instances, proof) +} + +/// Generates a SNARK using GWC multi-open scheme. Uses Poseidon for Fiat-Shamir. +/// +/// Caches the instances and proof if `path` is specified. +pub fn gen_snark_gwc>( + params: &ParamsKZG, + pk: &ProvingKey, + circuit: ConcreteCircuit, + transcript: &mut PoseidonTranscript>, + rng: &mut impl Rng, + path: Option<(&Path, &Path)>, +) -> Snark { + gen_snark::, VerifierGWC<_>>( + params, pk, circuit, transcript, rng, path, + ) +} + +/// Generates a SNARK using SHPLONK multi-open scheme. Uses Poseidon for Fiat-Shamir. +/// +/// Caches the instances and proof if `path` is specified. +pub fn gen_snark_shplonk>( + params: &ParamsKZG, + pk: &ProvingKey, + circuit: ConcreteCircuit, + transcript: &mut PoseidonTranscript>, + rng: &mut impl Rng, + path: Option<(&Path, &Path)>, +) -> Snark { + gen_snark::, VerifierSHPLONK<_>>( + params, pk, circuit, transcript, rng, path, + ) +} + +pub fn gen_dummy_snark( + params: &ParamsKZG, + vk: Option<&VerifyingKey>, +) -> Snark +where + ConcreteCircuit: CircuitExt, + MOS: MultiOpenScheme + + CostEstimation>>, +{ + struct CsProxy(PhantomData<(F, C)>); + + impl> Circuit for CsProxy { + type Config = C::Config; + type FloorPlanner = C::FloorPlanner; + + fn without_witnesses(&self) -> Self { + CsProxy(PhantomData) + } + + fn configure(meta: &mut ConstraintSystem) -> Self::Config { + C::configure(meta) + } + + fn synthesize( + &self, + config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + // when `C` has simple selectors, we tell `CsProxy` not to over-optimize the selectors (e.g., compressing them all into one) by turning all selectors on in the first row + // currently this only works if all simple selector columns are used in the actual circuit and there are overlaps amongst all enabled selectors (i.e., the actual circuit will not optimize constraint system further) + layouter.assign_region( + || "", + |mut region| { + for q in C::selectors(&config).iter() { + q.enable(&mut region, 0)?; + } + Ok(()) + }, + )?; + Ok(()) + } + } + + let dummy_vk = vk + .is_none() + .then(|| keygen_vk(params, &CsProxy::(PhantomData)).unwrap()); + let protocol = compile( + params, + vk.or(dummy_vk.as_ref()).unwrap(), + Config::kzg() + .with_num_instance(ConcreteCircuit::num_instance()) + .with_accumulator_indices(ConcreteCircuit::accumulator_indices()), + ); + let instances = ConcreteCircuit::num_instance() + .into_iter() + .map(|n| iter::repeat(Fr::default()).take(n).collect()) + .collect(); + let proof = { + let mut transcript = PoseidonTranscript::::new(Vec::new()); + for _ in 0..protocol + .num_witness + .iter() + .chain(Some(&protocol.quotient.num_chunk())) + .sum::() + { + 
transcript.write_ec_point(G1Affine::default()).unwrap(); + } + for _ in 0..protocol.evaluations.len() { + transcript.write_scalar(Fr::default()).unwrap(); + } + let queries = PlonkProof::::empty_queries(&protocol); + for _ in 0..MOS::estimate_cost(&queries).num_commitment { + transcript.write_ec_point(G1Affine::default()).unwrap(); + } + transcript.finalize() + }; + + Snark::new(protocol, instances, proof) +} diff --git a/snark-verifier-sdk/src/halo2/aggregation.rs b/snark-verifier-sdk/src/halo2/aggregation.rs new file mode 100644 index 00000000..81cfd12e --- /dev/null +++ b/snark-verifier-sdk/src/halo2/aggregation.rs @@ -0,0 +1,375 @@ +use crate::{Plonk, BITS, LIMBS}; +#[cfg(feature = "display")] +use ark_std::{end_timer, start_timer}; +use halo2_base::halo2_proofs::{ + circuit::{Layouter, SimpleFloorPlanner, Value}, + halo2curves::bn256::{Bn256, Fq, Fr, G1Affine}, + plonk::{self, Circuit, Column, ConstraintSystem, Instance, Selector}, + poly::{commitment::ParamsProver, kzg::commitment::ParamsKZG}, +}; +use halo2_base::{Context, ContextParams}; +use itertools::Itertools; +use rand::Rng; +use snark_verifier::{ + loader::{ + self, + halo2::halo2_ecc::{self, ecc::EccChip}, + native::NativeLoader, + }, + pcs::{ + kzg::{Bdfg21, Kzg, KzgAccumulator, KzgAs, KzgSuccinctVerifyingKey}, + AccumulationScheme, AccumulationSchemeProver, MultiOpenScheme, PolynomialCommitmentScheme, + }, + util::arithmetic::fe_to_limbs, + verifier::PlonkVerifier, +}; +use std::{fs::File, rc::Rc}; + +use super::{CircuitExt, PoseidonTranscript, Snark, SnarkWitness, POSEIDON_SPEC}; + +type Svk = KzgSuccinctVerifyingKey; +type BaseFieldEccChip = halo2_ecc::ecc::BaseFieldEccChip; +type Halo2Loader<'a> = loader::halo2::Halo2Loader<'a, G1Affine, BaseFieldEccChip>; +type Shplonk = Plonk>; + +pub fn load_verify_circuit_degree() -> u32 { + let path = std::env::var("VERIFY_CONFIG") + .unwrap_or_else(|_| "./configs/verify_circuit.config".to_string()); + let params: AggregationConfigParams = serde_json::from_reader( + File::open(path.as_str()).unwrap_or_else(|_| panic!("{path} does not exist")), + ) + .unwrap(); + params.degree +} + +/// Core function used in `synthesize` to aggregate multiple `snarks`. 
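For context on how `load_verify_circuit_degree` and the aggregation machinery are driven end to end, a sketch in the style of the benches (the config path is one of the files added in this PR; the wrapper function itself is illustrative): the circuit shape is read from the JSON file named by `VERIFY_CONFIG`, an SRS of that degree is produced, and the snarks are folded into one proof of the aggregation circuit.

```rust
use halo2_base::utils::fs::gen_srs;
use rand::SeedableRng;
use rand_chacha::ChaCha20Rng;
use snark_verifier::loader::native::NativeLoader;
use snark_verifier_sdk::{
    gen_pk,
    halo2::{
        aggregation::{load_verify_circuit_degree, AggregationCircuit},
        gen_proof_shplonk, PoseidonTranscript, POSEIDON_SPEC,
    },
    Snark,
};

// Illustrative driver: aggregate previously generated application snarks.
fn aggregate_snarks(snarks: Vec<Snark>) -> Vec<u8> {
    // e.g. {"strategy":"Simple","degree":21,"num_advice":5,...} (example_evm_accumulator.config)
    std::env::set_var("VERIFY_CONFIG", "./configs/example_evm_accumulator.config");
    let k = load_verify_circuit_degree(); // reads `degree` from that JSON file
    let params = gen_srs(k);

    let mut rng = ChaCha20Rng::from_entropy();
    let mut transcript =
        PoseidonTranscript::<NativeLoader, Vec<u8>>::from_spec(vec![], POSEIDON_SPEC.clone());

    // Re-verifies each snark natively, accumulates with SHPLONK, and records the proof shape.
    let agg_circuit = AggregationCircuit::new(&params, snarks, &mut transcript, &mut rng);
    let pk = gen_pk(&params, &agg_circuit, None);

    let instances = agg_circuit.instances();
    gen_proof_shplonk(&params, &pk, agg_circuit, instances, &mut transcript, &mut rng, None)
}
```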
+/// +/// Returns the assigned instances of previous snarks (all concatenated together) and the new final pair that needs to be verified in a pairing check +pub fn aggregate<'a, PCS>( + svk: &PCS::SuccinctVerifyingKey, + loader: &Rc>, + snarks: &[SnarkWitness], + as_proof: Value<&'_ [u8]>, +) -> ( + Vec>, + KzgAccumulator>>, +) +where + PCS: PolynomialCommitmentScheme< + G1Affine, + Rc>, + Accumulator = KzgAccumulator>>, + > + MultiOpenScheme>>, +{ + let assign_instances = |instances: &[Vec>]| { + instances + .iter() + .map(|instances| { + instances.iter().map(|instance| loader.assign_scalar(*instance)).collect_vec() + }) + .collect_vec() + }; + + // TODO pre-allocate capacity better + let mut previous_instances = vec![]; + let mut transcript = PoseidonTranscript::>, _>::from_spec( + loader, + Value::unknown(), + POSEIDON_SPEC.clone(), + ); + + let mut accumulators = snarks + .iter() + .flat_map(|snark| { + let protocol = snark.protocol.loaded(loader); + // TODO use 1d vector + let instances = assign_instances(&snark.instances); + + // read the transcript and perform Fiat-Shamir + // run through verification computation and produce the final pair `succinct` + transcript.new_stream(snark.proof()); + let proof = + Plonk::::read_proof(svk, &protocol, &instances, &mut transcript).unwrap(); + let accumulator = + Plonk::::succinct_verify(svk, &protocol, &instances, &proof).unwrap(); + + previous_instances.extend(instances.into_iter().flatten()); + + accumulator + }) + .collect_vec(); + + let accumulator = if accumulators.len() > 1 { + transcript.new_stream(as_proof); + let proof = + KzgAs::::read_proof(&Default::default(), &accumulators, &mut transcript).unwrap(); + KzgAs::::verify(&Default::default(), &accumulators, &proof).unwrap() + } else { + accumulators.pop().unwrap() + }; + + (previous_instances, accumulator) +} + +#[derive(serde::Serialize, serde::Deserialize)] +pub struct AggregationConfigParams { + pub strategy: halo2_ecc::fields::fp::FpStrategy, + pub degree: u32, + pub num_advice: usize, + pub num_lookup_advice: usize, + pub num_fixed: usize, + pub lookup_bits: usize, + pub limb_bits: usize, + pub num_limbs: usize, +} + +#[derive(Clone)] +pub struct AggregationConfig { + pub base_field_config: halo2_ecc::fields::fp::FpConfig, + pub instance: Column, +} + +impl AggregationConfig { + pub fn configure(meta: &mut ConstraintSystem, params: AggregationConfigParams) -> Self { + assert!( + params.limb_bits == BITS && params.num_limbs == LIMBS, + "For now we fix limb_bits = {}, otherwise change code", + BITS + ); + let base_field_config = halo2_ecc::fields::fp::FpConfig::configure( + meta, + params.strategy, + &[params.num_advice], + &[params.num_lookup_advice], + params.num_fixed, + params.lookup_bits, + BITS, + LIMBS, + halo2_base::utils::modulus::(), + 0, + params.degree as usize, + ); + + let instance = meta.instance_column(); + meta.enable_equality(instance); + + Self { base_field_config, instance } + } + + pub fn range(&self) -> &halo2_base::gates::range::RangeConfig { + &self.base_field_config.range + } + + pub fn gate(&self) -> &halo2_base::gates::flex_gate::FlexGateConfig { + &self.base_field_config.range.gate + } + + pub fn ecc_chip(&self) -> halo2_ecc::ecc::BaseFieldEccChip { + EccChip::construct(self.base_field_config.clone()) + } +} + +/// Aggregation circuit that does not re-expose any public inputs from aggregated snarks +/// +/// This is mostly a reference implementation. 
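To make the "does not re-expose any public inputs" point concrete: what the circuit does expose is the final KZG accumulator pair, with each base-field coordinate split into `LIMBS = 3` limbs of `BITS = 88` bits (the `LimbsEncoding<LIMBS, BITS>` convention used throughout this PR). A standalone sketch of that decomposition using plain big integers (limb order assumed least-significant first, mirroring `fe_to_limbs`):

```rust
use num_bigint::BigUint;

fn main() {
    const BITS: u32 = 88;
    const LIMBS: u32 = 3;

    // Stand-in for one accumulator coordinate (a ~254-bit base-field element).
    let x = BigUint::parse_bytes(b"1234567890123456789012345678901234567890123456789012345678", 10)
        .unwrap();

    let mask = (BigUint::from(1u8) << BITS) - BigUint::from(1u8);
    let limbs: Vec<BigUint> =
        (0..LIMBS).map(|i| (x.clone() >> (BITS * i)) & mask.clone()).collect();

    // Recombining the limbs recovers the original value, so the four coordinates of
    // (lhs, rhs) become 4 * 3 = 12 instance cells without losing information.
    let recombined = limbs
        .iter()
        .enumerate()
        .fold(BigUint::from(0u8), |acc, (i, limb)| acc + (limb.clone() << (BITS * i as u32)));
    assert_eq!(recombined, x);
}
```

The aggregated snarks' own public inputs are absorbed as private witnesses of the in-circuit verifier and are not forwarded to this instance column.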
In practice one will probably need to re-implement the circuit for one's particular use case with specific instance logic. +#[derive(Clone)] +pub struct AggregationCircuit { + svk: Svk, + snarks: Vec, + instances: Vec, + as_proof: Value>, +} + +impl AggregationCircuit { + pub fn new( + params: &ParamsKZG, + snarks: impl IntoIterator, + transcript_write: &mut PoseidonTranscript>, + rng: &mut impl Rng, + ) -> Self { + let svk = params.get_g()[0].into(); + let snarks = snarks.into_iter().collect_vec(); + + // TODO: this is all redundant calculation to get the public output + // Halo2 should just be able to expose public output to instance column directly + let mut transcript_read = + PoseidonTranscript::::from_spec(&[], POSEIDON_SPEC.clone()); + let accumulators = snarks + .iter() + .flat_map(|snark| { + transcript_read.new_stream(snark.proof.as_slice()); + let proof = Shplonk::read_proof( + &svk, + &snark.protocol, + &snark.instances, + &mut transcript_read, + ) + .unwrap(); + Shplonk::succinct_verify(&svk, &snark.protocol, &snark.instances, &proof).unwrap() + }) + .collect_vec(); + + let (accumulator, as_proof) = { + transcript_write.clear(); + // We always use SHPLONK for accumulation scheme when aggregating proofs + let accumulator = KzgAs::>::create_proof( + &Default::default(), + &accumulators, + transcript_write, + rng, + ) + .unwrap(); + (accumulator, transcript_write.stream_mut().split_off(0)) + }; + + let KzgAccumulator { lhs, rhs } = accumulator; + let instances = [lhs.x, lhs.y, rhs.x, rhs.y].map(fe_to_limbs::<_, _, LIMBS, BITS>).concat(); + + Self { + svk, + snarks: snarks.into_iter().map_into().collect(), + instances, + as_proof: Value::known(as_proof), + } + } + + pub fn accumulator_indices() -> Vec<(usize, usize)> { + (0..4 * LIMBS).map(|idx| (0, idx)).collect() + } + + pub fn num_instance() -> Vec { + vec![4 * LIMBS] + } + + pub fn instances(&self) -> Vec> { + vec![self.instances.clone()] + } + + pub fn as_proof(&self) -> Value<&[u8]> { + self.as_proof.as_ref().map(Vec::as_slice) + } +} + +impl CircuitExt for AggregationCircuit { + fn num_instance() -> Vec { + // [..lhs, ..rhs] + vec![4 * LIMBS] + } + + fn instances(&self) -> Vec> { + vec![self.instances.clone()] + } + + fn accumulator_indices() -> Option> { + Some((0..4 * LIMBS).map(|idx| (0, idx)).collect()) + } + + fn selectors(config: &Self::Config) -> Vec { + config.gate().basic_gates[0].iter().map(|gate| gate.q_enable).collect() + } +} + +impl Circuit for AggregationCircuit { + type Config = AggregationConfig; + type FloorPlanner = SimpleFloorPlanner; + + fn without_witnesses(&self) -> Self { + Self { + svk: self.svk, + snarks: self.snarks.iter().map(SnarkWitness::without_witnesses).collect(), + instances: Vec::new(), + as_proof: Value::unknown(), + } + } + + fn configure(meta: &mut plonk::ConstraintSystem) -> Self::Config { + let path = std::env::var("VERIFY_CONFIG") + .unwrap_or_else(|_| "configs/verify_circuit.config".to_owned()); + let params: AggregationConfigParams = serde_json::from_reader( + File::open(path.as_str()).unwrap_or_else(|_| panic!("{path:?} does not exist")), + ) + .unwrap(); + + AggregationConfig::configure(meta, params) + } + + fn synthesize( + &self, + config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), plonk::Error> { + config.range().load_lookup_table(&mut layouter)?; + + // assume using simple floor planner + let mut first_pass = halo2_base::SKIP_FIRST_PASS; + let mut assigned_instances = vec![]; + + layouter.assign_region( + || "", + |region| { + if first_pass { + first_pass 
= false; + return Ok(()); + } + #[cfg(feature = "display")] + let witness_time = start_timer!(|| "Witness Collection"); + let ctx = Context::new( + region, + ContextParams { + max_rows: config.gate().max_rows, + num_context_ids: 1, + fixed_columns: config.gate().constants.clone(), + }, + ); + + let ecc_chip = config.ecc_chip(); + let loader = Halo2Loader::new(ecc_chip, ctx); + let (_, KzgAccumulator { lhs, rhs }) = aggregate::>( + &self.svk, + &loader, + &self.snarks, + self.as_proof(), + ); + + let lhs = lhs.assigned(); + let rhs = rhs.assigned(); + + config.base_field_config.finalize(&mut loader.ctx_mut()); + #[cfg(feature = "display")] + println!("Total advice cells: {}", loader.ctx().total_advice); + #[cfg(feature = "display")] + println!("Advice columns used: {}", loader.ctx().advice_alloc[0][0].0 + 1); + + assigned_instances = lhs + .x + .truncation + .limbs + .iter() + .chain(lhs.y.truncation.limbs.iter()) + .chain(rhs.x.truncation.limbs.iter()) + .chain(rhs.y.truncation.limbs.iter()) + .map(|assigned| { + #[cfg(feature = "halo2-axiom")] + { + *assigned.cell() + } + #[cfg(feature = "halo2-pse")] + { + assigned.cell() + } + }) + .collect_vec(); + #[cfg(feature = "display")] + end_timer!(witness_time); + Ok(()) + }, + )?; + + // Expose instances + // TODO: use less instances by following Scroll's strategy of keeping only last bit of y coordinate + for (i, cell) in assigned_instances.into_iter().enumerate() { + layouter.constrain_instance(cell, config.instance, i); + } + Ok(()) + } +} diff --git a/snark-verifier-sdk/src/lib.rs b/snark-verifier-sdk/src/lib.rs new file mode 100644 index 00000000..e46b704d --- /dev/null +++ b/snark-verifier-sdk/src/lib.rs @@ -0,0 +1,198 @@ +#[cfg(feature = "display")] +use ark_std::{end_timer, start_timer}; +use halo2_base::halo2_proofs; +use halo2_proofs::{ + circuit::Value, + halo2curves::{ + bn256::{Bn256, Fr, G1Affine}, + group::ff::Field, + }, + plonk::{keygen_pk, keygen_vk, Circuit, ProvingKey, Selector}, + poly::kzg::commitment::ParamsKZG, +}; +use itertools::Itertools; +use snark_verifier::{pcs::kzg::LimbsEncoding, verifier, Protocol}; +use std::{ + fs::{self, File}, + io::{BufReader, BufWriter}, + path::Path, +}; + +#[cfg(feature = "loader_evm")] +pub mod evm; +#[cfg(feature = "loader_halo2")] +pub mod halo2; + +const LIMBS: usize = 3; +const BITS: usize = 88; + +/// PCS be either `Kzg` or `Kzg` +pub type Plonk = verifier::Plonk>; + +pub struct Snark { + pub protocol: Protocol, + pub instances: Vec>, + pub proof: Vec, +} + +impl Snark { + pub fn new(protocol: Protocol, instances: Vec>, proof: Vec) -> Self { + Self { protocol, instances, proof } + } +} + +impl From for SnarkWitness { + fn from(snark: Snark) -> Self { + Self { + protocol: snark.protocol, + instances: snark + .instances + .into_iter() + .map(|instances| instances.into_iter().map(Value::known).collect_vec()) + .collect(), + proof: Value::known(snark.proof), + } + } +} + +#[derive(Clone)] +pub struct SnarkWitness { + pub protocol: Protocol, + pub instances: Vec>>, + pub proof: Value>, +} + +impl SnarkWitness { + pub fn without_witnesses(&self) -> Self { + SnarkWitness { + protocol: self.protocol.clone(), + instances: self + .instances + .iter() + .map(|instances| vec![Value::unknown(); instances.len()]) + .collect(), + proof: Value::unknown(), + } + } + + pub fn proof(&self) -> Value<&[u8]> { + self.proof.as_ref().map(Vec::as_slice) + } +} + +pub trait CircuitExt: Circuit { + fn num_instance() -> Vec; + + fn instances(&self) -> Vec>; + + fn accumulator_indices() -> Option> { + None + 
} + + /// Output the simple selector columns (before selector compression) of the circuit + fn selectors(_: &Self::Config) -> Vec { + vec![] + } +} + +pub fn gen_pk>( + params: &ParamsKZG, + circuit: &C, + path: Option<&Path>, +) -> ProvingKey { + if let Some(path) = path { + match File::open(path) { + Ok(f) => { + #[cfg(feature = "display")] + let read_time = start_timer!(|| format!("Reading pkey from {path:?}")); + + // TODO: bench if BufReader is indeed faster than Read + let mut bufreader = BufReader::new(f); + let pk = ProvingKey::read::<_, C>(&mut bufreader, params) + .expect("Reading pkey should not fail"); + + #[cfg(feature = "display")] + end_timer!(read_time); + + pk + } + Err(_) => { + #[cfg(feature = "display")] + let pk_time = start_timer!(|| "Generating vkey & pkey"); + + let vk = keygen_vk(params, circuit).unwrap(); + let pk = keygen_pk(params, vk, circuit).unwrap(); + + #[cfg(feature = "display")] + end_timer!(pk_time); + + #[cfg(feature = "display")] + let write_time = start_timer!(|| format!("Writing pkey to {path:?}")); + + path.parent().and_then(|dir| fs::create_dir_all(dir).ok()).unwrap(); + let mut f = BufWriter::new(File::create(path).unwrap()); + pk.write(&mut f).unwrap(); + + #[cfg(feature = "display")] + end_timer!(write_time); + + pk + } + } + } else { + #[cfg(feature = "display")] + let pk_time = start_timer!(|| "Generating vkey & pkey"); + + let vk = keygen_vk(params, circuit).unwrap(); + let pk = keygen_pk(params, vk, circuit).unwrap(); + + #[cfg(feature = "display")] + end_timer!(pk_time); + + pk + } +} + +pub fn read_instances(path: impl AsRef) -> Result>, bincode::Error> { + let f = File::open(path)?; + let reader = BufReader::new(f); + let instances: Vec> = bincode::deserialize_from(reader)?; + instances + .into_iter() + .map(|instance_column| { + instance_column + .iter() + .map(|bytes| { + Option::from(Fr::from_bytes(bytes)).ok_or(Box::new(bincode::ErrorKind::Custom( + "Invalid finite field point".to_owned(), + ))) + }) + .collect::, _>>() + }) + .collect() +} + +pub fn write_instances(instances: &[&[Fr]], path: impl AsRef) { + let instances: Vec> = instances + .iter() + .map(|instance_column| instance_column.iter().map(|x| x.to_bytes()).collect_vec()) + .collect_vec(); + let f = BufWriter::new(File::create(path).unwrap()); + bincode::serialize_into(f, &instances).unwrap(); +} + +#[cfg(feature = "zkevm")] +mod zkevm { + use super::CircuitExt; + use eth_types::Field; + use zkevm_circuits::evm_circuit::EvmCircuit; + + impl CircuitExt for EvmCircuit { + fn instances(&self) -> Vec> { + vec![] + } + fn num_instance() -> Vec { + vec![] + } + } +} diff --git a/snark-verifier/Cargo.toml b/snark-verifier/Cargo.toml new file mode 100644 index 00000000..146c69fb --- /dev/null +++ b/snark-verifier/Cargo.toml @@ -0,0 +1,61 @@ +[package] +name = "snark-verifier" +version = "0.1.0" +edition = "2021" + +[dependencies] +itertools = "0.10.3" +lazy_static = "1.4.0" +num-bigint = "0.4.3" +num-integer = "0.1.45" +num-traits = "0.2.15" +hex = "0.4" +rand = "0.8" + +# Use halo2-base as non-optional dependency because it re-exports halo2_proofs, halo2curves, and poseidon, using different repos based on feature flag "halo2-axiom" or "halo2-pse" +halo2-base = { git = "ssh://github.com/axiom-crypto/axiom-core-working.git", branch = "experiment/optimizations", default-features = false } + +# parallel +rayon = { version = "1.5.3", optional = true } + +# loader_evm +ethereum-types = { version = "0.14", default-features = false, features = ["std"], optional = true } +sha3 = { version 
= "0.10", optional = true } +revm = { version = "2.3.1", optional = true } +bytes = { version = "1.2", optional = true } +rlp = { version = "0.5", default-features = false, features = ["std"], optional = true } + +# loader_halo2 +halo2-ecc = { git = "ssh://github.com/axiom-crypto/axiom-core-working.git", branch = "experiment/optimizations", default-features = false, optional = true } + +[dev-dependencies] +ark-std = { version = "0.3.0", features = ["print-trace"] } +paste = "1.0.7" +rand_chacha = "0.3.1" +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +# loader_evm +crossterm = { version = "0.25" } +tui = { version = "0.19", default-features = false, features = ["crossterm"] } + +[features] +default = ["loader_evm", "loader_halo2", "halo2-pse"] +display = ["halo2-base/display", "halo2-ecc?/display"] +loader_evm = ["dep:ethereum-types", "dep:sha3", "dep:revm", "dep:bytes", "dep:rlp"] +loader_halo2 = ["halo2-ecc"] +parallel = ["dep:rayon"] +# EXACTLY one of halo2-pse / halo2-axiom should always be turned on; not sure how to enforce this with Cargo +halo2-pse = ["halo2-base/halo2-pse", "halo2-ecc?/halo2-pse"] +halo2-axiom = ["halo2-base/halo2-axiom", "halo2-ecc?/halo2-axiom"] + +[[example]] +name = "evm-verifier" +required-features = ["loader_evm"] + +[[example]] +name = "evm-verifier-with-accumulator" +required-features = ["loader_halo2", "loader_evm"] + +[[example]] +name = "recursion" +required-features = ["loader_halo2"] \ No newline at end of file diff --git a/snark-verifier/configs/example_evm_accumulator.config b/snark-verifier/configs/example_evm_accumulator.config new file mode 100644 index 00000000..fcda49a0 --- /dev/null +++ b/snark-verifier/configs/example_evm_accumulator.config @@ -0,0 +1 @@ +{"strategy":"Simple","degree":21,"num_advice":5,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":20,"limb_bits":88,"num_limbs":3} \ No newline at end of file diff --git a/snark-verifier/configs/verify_circuit.config b/snark-verifier/configs/verify_circuit.config new file mode 100644 index 00000000..e65b2b52 --- /dev/null +++ b/snark-verifier/configs/verify_circuit.config @@ -0,0 +1 @@ +{"strategy":"Simple","degree":21,"num_advice":4,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":20,"limb_bits":88,"num_limbs":3} \ No newline at end of file diff --git a/examples/README.md b/snark-verifier/examples/README.md similarity index 77% rename from examples/README.md rename to snark-verifier/examples/README.md index 443dc67e..67cd4267 100644 --- a/examples/README.md +++ b/snark-verifier/examples/README.md @@ -1,4 +1,4 @@ -In `plonk-verifier` root directory: +In `snark-verifier` root directory: 1. 
Create `./configs/verify_circuit.config` diff --git a/snark-verifier/examples/evm-verifier-with-accumulator.rs b/snark-verifier/examples/evm-verifier-with-accumulator.rs new file mode 100644 index 00000000..cdb5bb16 --- /dev/null +++ b/snark-verifier/examples/evm-verifier-with-accumulator.rs @@ -0,0 +1,675 @@ +use ethereum_types::Address; +use halo2_base::halo2_proofs; +use halo2_proofs::{ + dev::MockProver, + halo2curves::bn256::{Bn256, Fq, Fr, G1Affine}, + plonk::{create_proof, keygen_pk, keygen_vk, verify_proof, Circuit, ProvingKey, VerifyingKey}, + poly::{ + commitment::{Params, ParamsProver}, + kzg::{ + commitment::{KZGCommitmentScheme, ParamsKZG}, + multiopen::{ProverGWC, VerifierGWC}, + strategy::AccumulatorStrategy, + }, + VerificationStrategy, + }, + transcript::{EncodedChallenge, TranscriptReadBuffer, TranscriptWriterBuffer}, +}; +use itertools::Itertools; +use rand::rngs::OsRng; +use snark_verifier::{ + loader::{ + evm::{self, encode_calldata, EvmLoader, ExecutorBuilder}, + native::NativeLoader, + }, + pcs::kzg::{Gwc19, Kzg, KzgAs, LimbsEncoding}, + system::halo2::{compile, transcript::evm::EvmTranscript, Config}, + verifier::{self, PlonkVerifier}, +}; +use std::{io::Cursor, rc::Rc}; + +const LIMBS: usize = 3; +const BITS: usize = 88; + +type Pcs = Kzg; +type As = KzgAs; +type Plonk = verifier::Plonk>; + +mod application { + use super::halo2_proofs::{ + circuit::{Layouter, SimpleFloorPlanner, Value}, + plonk::{Advice, Circuit, Column, ConstraintSystem, Error, Fixed, Instance}, + poly::Rotation, + }; + use super::Fr; + use halo2_base::halo2_proofs::plonk::Assigned; + use rand::RngCore; + + #[derive(Clone, Copy)] + pub struct StandardPlonkConfig { + a: Column, + b: Column, + c: Column, + q_a: Column, + q_b: Column, + q_c: Column, + q_ab: Column, + constant: Column, + #[allow(dead_code)] + instance: Column, + } + + impl StandardPlonkConfig { + fn configure(meta: &mut ConstraintSystem) -> Self { + let [a, b, c] = [(); 3].map(|_| meta.advice_column()); + let [q_a, q_b, q_c, q_ab, constant] = [(); 5].map(|_| meta.fixed_column()); + let instance = meta.instance_column(); + + [a, b, c].map(|column| meta.enable_equality(column)); + + meta.create_gate( + "q_a·a + q_b·b + q_c·c + q_ab·a·b + constant + instance = 0", + |meta| { + let [a, b, c] = + [a, b, c].map(|column| meta.query_advice(column, Rotation::cur())); + let [q_a, q_b, q_c, q_ab, constant] = [q_a, q_b, q_c, q_ab, constant] + .map(|column| meta.query_fixed(column, Rotation::cur())); + let instance = meta.query_instance(instance, Rotation::cur()); + Some( + q_a * a.clone() + + q_b * b.clone() + + q_c * c + + q_ab * a * b + + constant + + instance, + ) + }, + ); + + StandardPlonkConfig { a, b, c, q_a, q_b, q_c, q_ab, constant, instance } + } + } + + #[derive(Clone, Default)] + pub struct StandardPlonk(Fr); + + impl StandardPlonk { + pub fn rand(mut rng: R) -> Self { + Self(Fr::from(rng.next_u32() as u64)) + } + + pub fn num_instance() -> Vec { + vec![1] + } + + pub fn instances(&self) -> Vec> { + vec![vec![self.0]] + } + } + + impl Circuit for StandardPlonk { + type Config = StandardPlonkConfig; + type FloorPlanner = SimpleFloorPlanner; + + fn without_witnesses(&self) -> Self { + Self::default() + } + + fn configure(meta: &mut ConstraintSystem) -> Self::Config { + meta.set_minimum_degree(4); + StandardPlonkConfig::configure(meta) + } + + fn synthesize( + &self, + config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + layouter.assign_region( + || "", + |mut region| { + #[cfg(feature = "halo2-pse")] + { + 
region.assign_advice(|| "", config.a, 0, || Value::known(self.0))?; + region.assign_fixed(|| "", config.q_a, 0, || Value::known(-Fr::one()))?; + + region.assign_advice( + || "", + config.a, + 1, + || Value::known(-Fr::from(5u64)), + )?; + for (idx, column) in (1..).zip([ + config.q_a, + config.q_b, + config.q_c, + config.q_ab, + config.constant, + ]) { + region.assign_fixed( + || "", + column, + 1, + || Value::known(Fr::from(idx as u64)), + )?; + } + + let a = + region.assign_advice(|| "", config.a, 2, || Value::known(Fr::one()))?; + a.copy_advice(|| "", &mut region, config.b, 3)?; + a.copy_advice(|| "", &mut region, config.c, 4)?; + } + #[cfg(feature = "halo2-axiom")] + { + region.assign_advice( + config.a, + 0, + Value::known(Assigned::Trivial(self.0)), + )?; + region.assign_fixed(config.q_a, 0, Assigned::Trivial(-Fr::one())); + + region.assign_advice( + config.a, + 1, + Value::known(Assigned::Trivial(-Fr::from(5u64))), + )?; + for (idx, column) in (1..).zip([ + config.q_a, + config.q_b, + config.q_c, + config.q_ab, + config.constant, + ]) { + region.assign_fixed(column, 1, Assigned::Trivial(Fr::from(idx as u64))); + } + + let a = region.assign_advice( + config.a, + 2, + Value::known(Assigned::Trivial(Fr::one())), + )?; + a.copy_advice(&mut region, config.b, 3); + a.copy_advice(&mut region, config.c, 4); + } + + Ok(()) + }, + ) + } + } +} + +mod aggregation { + use super::halo2_proofs::{ + circuit::{Cell, Layouter, SimpleFloorPlanner, Value}, + plonk::{self, Circuit, Column, ConstraintSystem, Instance, Selector}, + poly::{commitment::ParamsProver, kzg::commitment::ParamsKZG}, + }; + use super::{As, Plonk, BITS, LIMBS}; + use super::{Bn256, Fq, Fr, G1Affine}; + use ark_std::{end_timer, start_timer}; + use halo2_base::{Context, ContextParams}; + use halo2_ecc::ecc::EccChip; + use itertools::Itertools; + use rand::rngs::OsRng; + use snark_verifier::{ + loader::{self, native::NativeLoader}, + pcs::{ + kzg::{KzgAccumulator, KzgSuccinctVerifyingKey}, + AccumulationScheme, AccumulationSchemeProver, + }, + system, + util::arithmetic::fe_to_limbs, + verifier::PlonkVerifier, + Protocol, + }; + use std::{fs::File, rc::Rc}; + + const T: usize = 5; + const RATE: usize = 4; + const R_F: usize = 8; + const R_P: usize = 60; + + type Svk = KzgSuccinctVerifyingKey; + type BaseFieldEccChip = halo2_ecc::ecc::BaseFieldEccChip; + type Halo2Loader<'a> = loader::halo2::Halo2Loader<'a, G1Affine, BaseFieldEccChip>; + pub type PoseidonTranscript = + system::halo2::transcript::halo2::PoseidonTranscript; + + pub struct Snark { + protocol: Protocol, + instances: Vec>, + proof: Vec, + } + + impl Snark { + pub fn new(protocol: Protocol, instances: Vec>, proof: Vec) -> Self { + Self { protocol, instances, proof } + } + } + + impl From for SnarkWitness { + fn from(snark: Snark) -> Self { + Self { + protocol: snark.protocol, + instances: snark + .instances + .into_iter() + .map(|instances| instances.into_iter().map(Value::known).collect_vec()) + .collect(), + proof: Value::known(snark.proof), + } + } + } + + #[derive(Clone)] + pub struct SnarkWitness { + protocol: Protocol, + instances: Vec>>, + proof: Value>, + } + + impl SnarkWitness { + fn without_witnesses(&self) -> Self { + SnarkWitness { + protocol: self.protocol.clone(), + instances: self + .instances + .iter() + .map(|instances| vec![Value::unknown(); instances.len()]) + .collect(), + proof: Value::unknown(), + } + } + + fn proof(&self) -> Value<&[u8]> { + self.proof.as_ref().map(Vec::as_slice) + } + } + + pub fn aggregate<'a>( + svk: &Svk, + loader: &Rc>, + 
snarks: &[SnarkWitness], + as_proof: Value<&'_ [u8]>, + ) -> KzgAccumulator>> { + let assign_instances = |instances: &[Vec>]| { + instances + .iter() + .map(|instances| { + instances.iter().map(|instance| loader.assign_scalar(*instance)).collect_vec() + }) + .collect_vec() + }; + + let accumulators = snarks + .iter() + .flat_map(|snark| { + let protocol = snark.protocol.loaded(loader); + let instances = assign_instances(&snark.instances); + let mut transcript = + PoseidonTranscript::, _>::new(loader, snark.proof()); + let proof = Plonk::read_proof(svk, &protocol, &instances, &mut transcript).unwrap(); + Plonk::succinct_verify(svk, &protocol, &instances, &proof).unwrap() + }) + .collect_vec(); + + let acccumulator = { + let mut transcript = PoseidonTranscript::, _>::new(loader, as_proof); + let proof = + As::read_proof(&Default::default(), &accumulators, &mut transcript).unwrap(); + As::verify(&Default::default(), &accumulators, &proof).unwrap() + }; + + acccumulator + } + + #[derive(serde::Serialize, serde::Deserialize)] + pub struct AggregationConfigParams { + pub strategy: halo2_ecc::fields::fp::FpStrategy, + pub degree: u32, + pub num_advice: usize, + pub num_lookup_advice: usize, + pub num_fixed: usize, + pub lookup_bits: usize, + pub limb_bits: usize, + pub num_limbs: usize, + } + + #[derive(Clone)] + pub struct AggregationConfig { + pub base_field_config: halo2_ecc::fields::fp::FpConfig, + pub instance: Column, + } + + impl AggregationConfig { + pub fn configure(meta: &mut ConstraintSystem, params: AggregationConfigParams) -> Self { + assert!( + params.limb_bits == BITS && params.num_limbs == LIMBS, + "For now we fix limb_bits = {}, otherwise change code", + BITS + ); + let base_field_config = halo2_ecc::fields::fp::FpConfig::configure( + meta, + params.strategy, + &[params.num_advice], + &[params.num_lookup_advice], + params.num_fixed, + params.lookup_bits, + params.limb_bits, + params.num_limbs, + halo2_base::utils::modulus::(), + 0, + params.degree as usize, + ); + + let instance = meta.instance_column(); + meta.enable_equality(instance); + + Self { base_field_config, instance } + } + + pub fn range(&self) -> &halo2_base::gates::range::RangeConfig { + &self.base_field_config.range + } + + pub fn ecc_chip(&self) -> halo2_ecc::ecc::BaseFieldEccChip { + EccChip::construct(self.base_field_config.clone()) + } + } + + #[derive(Clone)] + pub struct AggregationCircuit { + svk: Svk, + snarks: Vec, + instances: Vec, + as_proof: Value>, + } + + impl AggregationCircuit { + pub fn new(params: &ParamsKZG, snarks: impl IntoIterator) -> Self { + let svk = params.get_g()[0].into(); + let snarks = snarks.into_iter().collect_vec(); + + let accumulators = snarks + .iter() + .flat_map(|snark| { + let mut transcript = + PoseidonTranscript::::new(snark.proof.as_slice()); + let proof = + Plonk::read_proof(&svk, &snark.protocol, &snark.instances, &mut transcript) + .unwrap(); + Plonk::succinct_verify(&svk, &snark.protocol, &snark.instances, &proof).unwrap() + }) + .collect_vec(); + + let (accumulator, as_proof) = { + let mut transcript = PoseidonTranscript::::new(Vec::new()); + let accumulator = + As::create_proof(&Default::default(), &accumulators, &mut transcript, OsRng) + .unwrap(); + (accumulator, transcript.finalize()) + }; + + let KzgAccumulator { lhs, rhs } = accumulator; + let instances = + [lhs.x, lhs.y, rhs.x, rhs.y].map(fe_to_limbs::<_, _, LIMBS, BITS>).concat(); + + Self { + svk, + snarks: snarks.into_iter().map_into().collect(), + instances, + as_proof: Value::known(as_proof), + } + } + + pub 
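// Sketch (not part of the diff) of the decomposition `fe_to_limbs::<_, _, LIMBS, BITS>` performs
// above: each ~254-bit accumulator coordinate is split into LIMBS = 3 words of BITS = 88 bits,
// least-significant limb first, so the four coordinates become 4 * LIMBS public instances.
fn to_limbs(x: &num_bigint::BigUint, limbs: usize, bits: usize) -> Vec<num_bigint::BigUint> {
    let mask = (num_bigint::BigUint::from(1u8) << bits) - num_bigint::BigUint::from(1u8);
    (0..limbs).map(|i| (x >> (i * bits)) & &mask).collect()
}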
fn as_proof(&self) -> Value<&[u8]> { + self.as_proof.as_ref().map(Vec::as_slice) + } + + pub fn num_instance() -> Vec { + // [..lhs, ..rhs] + vec![4 * LIMBS] + } + + pub fn instances(&self) -> Vec> { + vec![self.instances.clone()] + } + + pub fn accumulator_indices() -> Vec<(usize, usize)> { + (0..4 * LIMBS).map(|idx| (0, idx)).collect() + } + } + + impl Circuit for AggregationCircuit { + type Config = AggregationConfig; + type FloorPlanner = SimpleFloorPlanner; + + fn without_witnesses(&self) -> Self { + Self { + svk: self.svk, + snarks: self.snarks.iter().map(SnarkWitness::without_witnesses).collect(), + instances: Vec::new(), + as_proof: Value::unknown(), + } + } + + fn configure(meta: &mut plonk::ConstraintSystem) -> Self::Config { + let path = std::env::var("VERIFY_CONFIG").unwrap(); + let params: AggregationConfigParams = serde_json::from_reader( + File::open(path.as_str()) + .unwrap_or_else(|err| panic!("Path {path} does not exist: {err:?}")), + ) + .unwrap(); + + AggregationConfig::configure(meta, params) + } + + fn synthesize( + &self, + config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), plonk::Error> { + config.range().load_lookup_table(&mut layouter)?; + let max_rows = config.range().gate.max_rows; + + let mut first_pass = halo2_base::SKIP_FIRST_PASS; // assume using simple floor planner + let mut assigned_instances: Option> = None; + layouter.assign_region( + || "", + |region| { + if first_pass { + first_pass = false; + return Ok(()); + } + let witness_time = start_timer!(|| "Witness Collection"); + let ctx = Context::new( + region, + ContextParams { + max_rows, + num_context_ids: 1, + fixed_columns: config.base_field_config.range.gate.constants.clone(), + }, + ); + + let ecc_chip = config.ecc_chip(); + let loader = Halo2Loader::new(ecc_chip, ctx); + let KzgAccumulator { lhs, rhs } = + aggregate(&self.svk, &loader, &self.snarks, self.as_proof()); + + let lhs = lhs.assigned(); + let rhs = rhs.assigned(); + + config.base_field_config.finalize(&mut loader.ctx_mut()); + #[cfg(feature = "display")] + println!("Total advice cells: {}", loader.ctx().total_advice); + #[cfg(feature = "display")] + println!("Advice columns used: {}", loader.ctx().advice_alloc[0].0 + 1); + + let instances: Vec<_> = lhs + .x + .truncation + .limbs + .iter() + .chain(lhs.y.truncation.limbs.iter()) + .chain(rhs.x.truncation.limbs.iter()) + .chain(rhs.y.truncation.limbs.iter()) + .map(|assigned| assigned.cell().clone()) + .collect(); + assigned_instances = Some(instances); + end_timer!(witness_time); + Ok(()) + }, + )?; + + // Expose instances + // TODO: use less instances by following Scroll's strategy of keeping only last bit of y coordinate + let mut layouter = layouter.namespace(|| "expose"); + for (i, cell) in assigned_instances.unwrap().into_iter().enumerate() { + layouter.constrain_instance(cell, config.instance, i); + } + Ok(()) + } + } +} + +fn gen_pk>(params: &ParamsKZG, circuit: &C) -> ProvingKey { + let vk = keygen_vk(params, circuit).unwrap(); + keygen_pk(params, vk, circuit).unwrap() +} + +fn gen_proof< + C: Circuit, + E: EncodedChallenge, + TR: TranscriptReadBuffer>, G1Affine, E>, + TW: TranscriptWriterBuffer, G1Affine, E>, +>( + params: &ParamsKZG, + pk: &ProvingKey, + circuit: C, + instances: Vec>, +) -> Vec { + MockProver::run(params.k(), &circuit, instances.clone()).unwrap().assert_satisfied(); + + let instances = instances.iter().map(|instances| instances.as_slice()).collect_vec(); + let proof = { + let mut transcript = TW::init(Vec::new()); + create_proof::, 
ProverGWC<_>, _, _, TW, _>( + params, + pk, + &[circuit], + &[instances.as_slice()], + OsRng, + &mut transcript, + ) + .unwrap(); + transcript.finalize() + }; + + let accept = { + let mut transcript = TR::init(Cursor::new(proof.clone())); + VerificationStrategy::<_, VerifierGWC<_>>::finalize( + verify_proof::<_, VerifierGWC<_>, _, TR, _>( + params.verifier_params(), + pk.get_vk(), + AccumulatorStrategy::new(params.verifier_params()), + &[instances.as_slice()], + &mut transcript, + ) + .unwrap(), + ) + }; + assert!(accept); + + proof +} + +fn gen_application_snark(params: &ParamsKZG) -> aggregation::Snark { + let circuit = application::StandardPlonk::rand(OsRng); + + let pk = gen_pk(params, &circuit); + let protocol = compile( + params, + pk.get_vk(), + Config::kzg().with_num_instance(application::StandardPlonk::num_instance()), + ); + + let proof = gen_proof::< + _, + _, + aggregation::PoseidonTranscript, + aggregation::PoseidonTranscript, + >(params, &pk, circuit.clone(), circuit.instances()); + aggregation::Snark::new(protocol, circuit.instances(), proof) +} + +fn gen_aggregation_evm_verifier( + params: &ParamsKZG, + vk: &VerifyingKey, + num_instance: Vec, + accumulator_indices: Vec<(usize, usize)>, +) -> Vec { + let svk = params.get_g()[0].into(); + let dk = (params.g2(), params.s_g2()).into(); + let protocol = compile( + params, + vk, + Config::kzg() + .with_num_instance(num_instance.clone()) + .with_accumulator_indices(Some(accumulator_indices)), + ); + + let loader = EvmLoader::new::(); + let protocol = protocol.loaded(&loader); + let mut transcript = EvmTranscript::<_, Rc, _, _>::new(&loader); + + let instances = transcript.load_instances(num_instance); + let proof = Plonk::read_proof(&svk, &protocol, &instances, &mut transcript).unwrap(); + Plonk::verify(&svk, &dk, &protocol, &instances, &proof).unwrap(); + + evm::compile_yul(&loader.yul_code()) +} + +fn evm_verify(deployment_code: Vec, instances: Vec>, proof: Vec) { + let calldata = encode_calldata(&instances, &proof); + let success = { + let mut evm = ExecutorBuilder::default().with_gas_limit(u64::MAX.into()).build(); + + let caller = Address::from_low_u64_be(0xfe); + let verifier = evm.deploy(caller, deployment_code.into(), 0.into()).address.unwrap(); + let result = evm.call_raw(caller, verifier, calldata.into(), 0.into()); + + dbg!(result.gas_used); + + !result.reverted + }; + assert!(success); +} + +fn main() { + std::env::set_var("VERIFY_CONFIG", "./configs/example_evm_accumulator.config"); + let params = halo2_base::utils::fs::gen_srs(21); + let params_app = { + let mut params = params.clone(); + params.downsize(8); + params + }; + + let snarks = [(); 3].map(|_| gen_application_snark(¶ms_app)); + let agg_circuit = aggregation::AggregationCircuit::new(¶ms, snarks); + let pk = gen_pk(¶ms, &agg_circuit); + let deployment_code = gen_aggregation_evm_verifier( + ¶ms, + pk.get_vk(), + aggregation::AggregationCircuit::num_instance(), + aggregation::AggregationCircuit::accumulator_indices(), + ); + + let proof = gen_proof::<_, _, EvmTranscript, EvmTranscript>( + ¶ms, + &pk, + agg_circuit.clone(), + agg_circuit.instances(), + ); + evm_verify(deployment_code, agg_circuit.instances(), proof); +} diff --git a/examples/evm-verifier.rs b/snark-verifier/examples/evm-verifier.rs similarity index 66% rename from examples/evm-verifier.rs rename to snark-verifier/examples/evm-verifier.rs index 613bd5ed..d7a1f0c8 100644 --- a/examples/evm-verifier.rs +++ b/snark-verifier/examples/evm-verifier.rs @@ -1,11 +1,11 @@ use ethereum_types::Address; 
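// Hedged sketch of the calldata layout assumed when `encode_calldata` feeds the generated Yul
// verifier above: each public instance serialized as a 32-byte big-endian word, followed by the
// raw proof bytes (an assumption drawn from how the verifier loads instances via calldata).
fn encode_calldata_sketch(instance_words: &[[u8; 32]], proof: &[u8]) -> Vec<u8> {
    instance_words
        .iter()
        .flat_map(|word| word.iter().copied()) // instances first, one 32-byte word each
        .chain(proof.iter().copied())          // proof bytes appended after the instances
        .collect()
}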
-use foundry_evm::executor::{fork::MultiFork, Backend, ExecutorBuilder}; -use halo2_curves::bn256::{Bn256, Fq, Fr, G1Affine}; +use halo2_base::halo2_proofs; use halo2_proofs::{ circuit::{Layouter, SimpleFloorPlanner, Value}, dev::MockProver, + halo2curves::bn256::{Bn256, Fq, Fr, G1Affine}, plonk::{ - create_proof, keygen_pk, keygen_vk, verify_proof, Advice, Circuit, Column, + create_proof, keygen_pk, keygen_vk, verify_proof, Advice, Assigned, Circuit, Column, ConstraintSystem, Error, Fixed, Instance, ProvingKey, VerifyingKey, }, poly::{ @@ -17,20 +17,16 @@ use halo2_proofs::{ }, Rotation, VerificationStrategy, }, - transcript::{ - Blake2bRead, Blake2bWrite, Challenge255, TranscriptReadBuffer, TranscriptWriterBuffer, - }, + transcript::{TranscriptReadBuffer, TranscriptWriterBuffer}, }; use itertools::Itertools; -use plonk_verifier::{ - loader::evm::{encode_calldata, EvmLoader}, +use rand::{rngs::OsRng, RngCore}; +use snark_verifier::{ + loader::evm::{self, encode_calldata, EvmLoader, ExecutorBuilder}, pcs::kzg::{Gwc19, Kzg}, - system::halo2::{ - aggregation::KZG_QUERY_INSTANCE, compile, transcript::evm::EvmTranscript, Config, - }, + system::halo2::{compile, transcript::evm::EvmTranscript, Config}, verifier::{self, PlonkVerifier}, }; -use rand::{rngs::OsRng, RngCore}; use std::rc::Rc; type Plonk = verifier::Plonk>; @@ -117,20 +113,59 @@ impl Circuit for StandardPlonk { layouter.assign_region( || "", |mut region| { - region.assign_advice(|| "", config.a, 0, || Value::known(self.0))?; - region.assign_fixed(|| "", config.q_a, 0, || Value::known(-Fr::one()))?; - - region.assign_advice(|| "", config.a, 1, || Value::known(-Fr::from(5)))?; - for (idx, column) in - (1..).zip([config.q_a, config.q_b, config.q_c, config.q_ab, config.constant]) + #[cfg(feature = "halo2-pse")] { - region.assign_fixed(|| "", column, 1, || Value::known(Fr::from(idx)))?; + region.assign_advice(|| "", config.a, 0, || Value::known(self.0))?; + region.assign_fixed(|| "", config.q_a, 0, || Value::known(-Fr::one()))?; + + region.assign_advice(|| "", config.a, 1, || Value::known(-Fr::from(5u64)))?; + for (idx, column) in (1..).zip([ + config.q_a, + config.q_b, + config.q_c, + config.q_ab, + config.constant, + ]) { + region.assign_fixed( + || "", + column, + 1, + || Value::known(Fr::from(idx as u64)), + )?; + } + + let a = region.assign_advice(|| "", config.a, 2, || Value::known(Fr::one()))?; + a.copy_advice(|| "", &mut region, config.b, 3)?; + a.copy_advice(|| "", &mut region, config.c, 4)?; + } + #[cfg(feature = "halo2-axiom")] + { + region.assign_advice(config.a, 0, Value::known(Assigned::Trivial(self.0)))?; + region.assign_fixed(config.q_a, 0, Assigned::Trivial(-Fr::one())); + + region.assign_advice( + config.a, + 1, + Value::known(Assigned::Trivial(-Fr::from(5u64))), + )?; + for (idx, column) in (1..).zip([ + config.q_a, + config.q_b, + config.q_c, + config.q_ab, + config.constant, + ]) { + region.assign_fixed(column, 1, Assigned::Trivial(Fr::from(idx as u64))); + } + + let a = region.assign_advice( + config.a, + 2, + Value::known(Assigned::Trivial(Fr::one())), + )?; + a.copy_advice(&mut region, config.b, 3); + a.copy_advice(&mut region, config.c, 4); } - - let a = region.assign_advice(|| "", config.a, 2, || Value::known(Fr::one()))?; - a.copy_advice(|| "", &mut region, config.b, 3)?; - a.copy_advice(|| "", &mut region, config.c, 4)?; - Ok(()) }, ) @@ -146,7 +181,7 @@ fn gen_pk>(params: &ParamsKZG, circuit: &C) -> ProvingKey< keygen_pk(params, vk, circuit).unwrap() } -fn gen_proof + Clone>( +fn gen_proof>( params: 
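// Quick arithmetic check (not in the diff) that the assignments above satisfy the StandardPlonk
// gate named earlier in `create_gate`:  q_a·a + q_b·b + q_c·c + q_ab·a·b + constant + instance = 0.
// row 0: a = x, q_a = -1, instance = x, rest 0            ->  -x + x = 0
// row 1: a = -5, (q_a..constant) = (1,2,3,4,5), rest 0    ->  -5 + 5 = 0
fn standard_plonk_gate(a: i64, b: i64, c: i64, q: [i64; 5], instance: i64) -> i64 {
    let [q_a, q_b, q_c, q_ab, constant] = q;
    q_a * a + q_b * b + q_c * c + q_ab * a * b + constant + instance
}
// standard_plonk_gate(x, 0, 0, [-1, 0, 0, 0, 0], x) == 0
// standard_plonk_gate(-5, 0, 0, [1, 2, 3, 4, 5], 0) == 0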
&ParamsKZG, pk: &ProvingKey, circuit: C, @@ -154,44 +189,6 @@ fn gen_proof + Clone>( ) -> Vec { MockProver::run(params.k(), &circuit, instances.clone()).unwrap().assert_satisfied(); - // For testing purposes: - // Native verify - // Uncomment to test if evm verifier fails silently - /*{ - let proof = { - let mut transcript = Blake2bWrite::init(Vec::new()); - create_proof::< - KZGCommitmentScheme, - ProverGWC<_>, - Challenge255, - _, - Blake2bWrite, G1Affine, Challenge255>, - _, - >( - params, - pk, - &[circuit.clone()], - &[&[instances[0].as_slice()]], - OsRng, - &mut transcript, - ) - .unwrap(); - transcript.finalize() - }; - let svk = params.get_g()[0].into(); - let dk = (params.g2(), params.s_g2()).into(); - let protocol = compile( - params, - pk.get_vk(), - Config::kzg(KZG_QUERY_INSTANCE).with_num_instance(vec![instances[0].len()]), - ); - let mut transcript = Blake2bRead::<_, G1Affine, _>::init(proof.as_slice()); - let instances = &[instances[0].to_vec()]; - let proof = Plonk::read_proof(&svk, &protocol, instances, &mut transcript).unwrap(); - assert!(Plonk::verify(&svk, &dk, &protocol, instances, &proof).unwrap()); - println!("==Native verify passed=="); - }*/ - let instances = instances.iter().map(|instances| instances.as_slice()).collect_vec(); let proof = { let mut transcript = TranscriptWriterBuffer::<_, G1Affine, _>::init(Vec::new()); @@ -232,34 +229,29 @@ fn gen_evm_verifier( ) -> Vec { let svk = params.get_g()[0].into(); let dk = (params.g2(), params.s_g2()).into(); - let protocol = compile( - params, - vk, - Config::kzg(KZG_QUERY_INSTANCE).with_num_instance(num_instance.clone()), - ); + let protocol = compile(params, vk, Config::kzg().with_num_instance(num_instance.clone())); let loader = EvmLoader::new::(); - let mut transcript = EvmTranscript::<_, Rc, _, _>::new(loader.clone()); + let protocol = protocol.loaded(&loader); + let mut transcript = EvmTranscript::<_, Rc, _, _>::new(&loader); let instances = transcript.load_instances(num_instance); let proof = Plonk::read_proof(&svk, &protocol, &instances, &mut transcript).unwrap(); Plonk::verify(&svk, &dk, &protocol, &instances, &proof).unwrap(); - loader.deployment_code() + evm::compile_yul(&loader.yul_code()) } fn evm_verify(deployment_code: Vec, instances: Vec>, proof: Vec) { let calldata = encode_calldata(&instances, &proof); let success = { - let mut evm = ExecutorBuilder::default() - .with_gas_limit(u64::MAX.into()) - .build(Backend::new(MultiFork::new().0, None)); + let mut evm = ExecutorBuilder::default().with_gas_limit(u64::MAX.into()).build(); let caller = Address::from_low_u64_be(0xfe); - let verifier = evm.deploy(caller, deployment_code.into(), 0.into(), None).unwrap().address; - let result = evm.call_raw(caller, verifier, calldata.into(), 0.into()).unwrap(); + let verifier = evm.deploy(caller, deployment_code.into(), 0.into()).address.unwrap(); + let result = evm.call_raw(caller, verifier, calldata.into(), 0.into()); - dbg!(result.gas); + dbg!(result.gas_used); !result.reverted }; @@ -271,7 +263,8 @@ fn main() { let circuit = StandardPlonk::rand(OsRng); let pk = gen_pk(¶ms, &circuit); - let proof = gen_proof(¶ms, &pk, circuit.clone(), circuit.instances()); let deployment_code = gen_evm_verifier(¶ms, pk.get_vk(), StandardPlonk::num_instance()); + + let proof = gen_proof(¶ms, &pk, circuit.clone(), circuit.instances()); evm_verify(deployment_code, circuit.instances(), proof); } diff --git a/snark-verifier/examples/recursion.rs b/snark-verifier/examples/recursion.rs new file mode 100644 index 00000000..569d6a12 --- 
/dev/null +++ b/snark-verifier/examples/recursion.rs @@ -0,0 +1,907 @@ +#![allow(clippy::type_complexity)] + +use ark_std::{end_timer, start_timer}; +use common::*; +use halo2_base::halo2_proofs; +use halo2_base::utils::fs::gen_srs; +use halo2_proofs::{ + circuit::{Layouter, SimpleFloorPlanner, Value}, + dev::MockProver, + halo2curves::{ + bn256::{Bn256, Fq, Fr, G1Affine}, + group::ff::Field, + FieldExt, + }, + plonk::{ + self, create_proof, keygen_pk, keygen_vk, Circuit, ConstraintSystem, Error, ProvingKey, + Selector, VerifyingKey, + }, + poly::{ + commitment::ParamsProver, + kzg::{ + commitment::ParamsKZG, + multiopen::{ProverGWC, VerifierGWC}, + strategy::AccumulatorStrategy, + }, + Rotation, VerificationStrategy, + }, +}; +use itertools::Itertools; +use rand_chacha::rand_core::OsRng; +use snark_verifier::{ + loader::{self, native::NativeLoader, Loader, ScalarLoader}, + pcs::{ + kzg::{Gwc19, Kzg, KzgAccumulator, KzgAs, KzgSuccinctVerifyingKey, LimbsEncoding}, + AccumulationScheme, AccumulationSchemeProver, + }, + system::halo2::{self, compile, Config}, + util::{ + arithmetic::{fe_to_fe, fe_to_limbs}, + hash, + }, + verifier::{self, PlonkProof, PlonkVerifier}, + Protocol, +}; +use std::{fs, iter, marker::PhantomData, rc::Rc}; + +use crate::recursion::AggregationConfigParams; + +const LIMBS: usize = 3; +const BITS: usize = 88; +const T: usize = 5; +const RATE: usize = 4; +const R_F: usize = 8; +const R_P: usize = 60; + +type Pcs = Kzg; +type Svk = KzgSuccinctVerifyingKey; +type As = KzgAs; +type Plonk = verifier::Plonk>; +type Poseidon = hash::Poseidon; +type PoseidonTranscript = + halo2::transcript::halo2::PoseidonTranscript; + +mod common { + use super::*; + use halo2_proofs::{plonk::verify_proof, poly::commitment::Params}; + use snark_verifier::{cost::CostEstimation, util::transcript::TranscriptWrite}; + + pub fn poseidon>( + loader: &L, + inputs: &[L::LoadedScalar], + ) -> L::LoadedScalar { + let mut hasher = Poseidon::new(loader, R_F, R_P); + hasher.update(inputs); + hasher.squeeze() + } + + pub struct Snark { + pub protocol: Protocol, + pub instances: Vec>, + pub proof: Vec, + } + + impl Snark { + pub fn new(protocol: Protocol, instances: Vec>, proof: Vec) -> Self { + Self { protocol, instances, proof } + } + } + + impl From for SnarkWitness { + fn from(snark: Snark) -> Self { + Self { + protocol: snark.protocol, + instances: snark + .instances + .into_iter() + .map(|instances| instances.into_iter().map(Value::known).collect_vec()) + .collect(), + proof: Value::known(snark.proof), + } + } + } + + #[derive(Clone)] + pub struct SnarkWitness { + pub protocol: Protocol, + pub instances: Vec>>, + pub proof: Value>, + } + + impl SnarkWitness { + pub fn without_witnesses(&self) -> Self { + SnarkWitness { + protocol: self.protocol.clone(), + instances: self + .instances + .iter() + .map(|instances| vec![Value::unknown(); instances.len()]) + .collect(), + proof: Value::unknown(), + } + } + + pub fn proof(&self) -> Value<&[u8]> { + self.proof.as_ref().map(Vec::as_slice) + } + } + + pub trait CircuitExt: Circuit { + fn num_instance() -> Vec; + + fn instances(&self) -> Vec>; + + fn accumulator_indices() -> Option> { + None + } + + /// Output the simple selector columns (before selector compression) of the circuit + fn selectors(_: &Self::Config) -> Vec { + vec![] + } + } + + pub fn gen_pk>(params: &ParamsKZG, circuit: &C) -> ProvingKey { + let vk = keygen_vk(params, circuit).unwrap(); + keygen_pk(params, vk, circuit).unwrap() + } + + pub fn gen_proof>( + params: &ParamsKZG, + pk: &ProvingKey, 
+ circuit: C, + instances: Vec>, + ) -> Vec { + if params.k() > 3 { + let mock = start_timer!(|| "Mock prover"); + MockProver::run(params.k(), &circuit, instances.clone()).unwrap().assert_satisfied(); + end_timer!(mock); + } + + let instances = instances.iter().map(Vec::as_slice).collect_vec(); + let proof = { + let mut transcript = PoseidonTranscript::::new(Vec::new()); + create_proof::<_, ProverGWC<_>, _, _, _, _>( + params, + pk, + &[circuit], + &[instances.as_slice()], + OsRng, + &mut transcript, + ) + .unwrap(); + transcript.finalize() + }; + + let accept = { + let mut transcript = PoseidonTranscript::::new(proof.as_slice()); + VerificationStrategy::<_, VerifierGWC<_>>::finalize( + verify_proof::<_, VerifierGWC<_>, _, _, _>( + params.verifier_params(), + pk.get_vk(), + AccumulatorStrategy::new(params.verifier_params()), + &[instances.as_slice()], + &mut transcript, + ) + .unwrap(), + ) + }; + assert!(accept); + + proof + } + + pub fn gen_snark>( + params: &ParamsKZG, + pk: &ProvingKey, + circuit: ConcreteCircuit, + ) -> Snark { + let protocol = compile( + params, + pk.get_vk(), + Config::kzg() + .with_num_instance(ConcreteCircuit::num_instance()) + .with_accumulator_indices(ConcreteCircuit::accumulator_indices()), + ); + + let instances = circuit.instances(); + let proof = gen_proof(params, pk, circuit, instances.clone()); + + Snark::new(protocol, instances, proof) + } + + pub fn gen_dummy_snark>( + params: &ParamsKZG, + vk: Option<&VerifyingKey>, + ) -> Snark { + struct CsProxy(PhantomData<(F, C)>); + + impl> Circuit for CsProxy { + type Config = C::Config; + type FloorPlanner = C::FloorPlanner; + + fn without_witnesses(&self) -> Self { + CsProxy(PhantomData) + } + + fn configure(meta: &mut ConstraintSystem) -> Self::Config { + C::configure(meta) + } + + fn synthesize( + &self, + config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + // when `C` has simple selectors, we tell `CsProxy` not to over-optimize the selectors (e.g., compressing them all into one) by turning all selectors on in the first row + // currently this only works if all simple selector columns are used in the actual circuit and there are overlaps amongst all enabled selectors (i.e., the actual circuit will not optimize constraint system further) + layouter.assign_region( + || "", + |mut region| { + for q in C::selectors(&config).iter() { + q.enable(&mut region, 0)?; + } + Ok(()) + }, + )?; + Ok(()) + } + } + + let dummy_vk = vk + .is_none() + .then(|| keygen_vk(params, &CsProxy::(PhantomData)).unwrap()); + let protocol = compile( + params, + vk.or(dummy_vk.as_ref()).unwrap(), + Config::kzg() + .with_num_instance(ConcreteCircuit::num_instance()) + .with_accumulator_indices(ConcreteCircuit::accumulator_indices()), + ); + let instances = ConcreteCircuit::num_instance() + .into_iter() + .map(|n| iter::repeat_with(|| Fr::random(OsRng)).take(n).collect()) + .collect(); + let proof = { + let mut transcript = PoseidonTranscript::::new(Vec::new()); + for _ in 0..protocol + .num_witness + .iter() + .chain(Some(&protocol.quotient.num_chunk())) + .sum::() + { + transcript.write_ec_point(G1Affine::random(OsRng)).unwrap(); + } + for _ in 0..protocol.evaluations.len() { + transcript.write_scalar(Fr::random(OsRng)).unwrap(); + } + let queries = PlonkProof::::empty_queries(&protocol); + for _ in 0..Pcs::estimate_cost(&queries).num_commitment { + transcript.write_ec_point(G1Affine::random(OsRng)).unwrap(); + } + transcript.finalize() + }; + + Snark::new(protocol, instances, proof) + } +} + +mod application 
{ + use super::*; + + #[derive(Clone, Default)] + pub struct Square(Fr); + + impl Circuit for Square { + type Config = Selector; + type FloorPlanner = SimpleFloorPlanner; + + fn without_witnesses(&self) -> Self { + Self::default() + } + + fn configure(meta: &mut ConstraintSystem) -> Self::Config { + let q = meta.selector(); + let i = meta.instance_column(); + meta.create_gate("square", |meta| { + let q = meta.query_selector(q); + let [i, i_w] = [0, 1].map(|rotation| meta.query_instance(i, Rotation(rotation))); + Some(q * (i.clone() * i - i_w)) + }); + q + } + + fn synthesize( + &self, + q: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + layouter.assign_region(|| "", |mut region| q.enable(&mut region, 0)) + } + } + + impl CircuitExt for Square { + fn num_instance() -> Vec { + vec![2] + } + + fn instances(&self) -> Vec> { + vec![vec![self.0, self.0.square()]] + } + } + + impl recursion::StateTransition for Square { + type Input = (); + + fn new(state: Fr) -> Self { + Self(state) + } + + fn state_transition(&self, _: Self::Input) -> Fr { + self.0.square() + } + } +} + +mod recursion { + use std::fs::File; + + use halo2_base::{ + gates::GateInstructions, AssignedValue, Context, ContextParams, QuantumCell::Existing, + }; + use halo2_ecc::ecc::EccChip; + use halo2_proofs::plonk::{Column, Instance}; + use snark_verifier::loader::halo2::{EccInstructions, IntegerInstructions}; + + use super::*; + + type BaseFieldEccChip = halo2_ecc::ecc::BaseFieldEccChip; + type Halo2Loader<'a> = loader::halo2::Halo2Loader<'a, G1Affine, BaseFieldEccChip>; + + pub trait StateTransition { + type Input; + + fn new(state: Fr) -> Self; + + fn state_transition(&self, input: Self::Input) -> Fr; + } + + fn succinct_verify<'a>( + svk: &Svk, + loader: &Rc>, + snark: &SnarkWitness, + preprocessed_digest: Option>, + ) -> (Vec>>, Vec>>>) { + let protocol = if let Some(preprocessed_digest) = preprocessed_digest { + let preprocessed_digest = loader.scalar_from_assigned(preprocessed_digest); + let protocol = snark.protocol.loaded_preprocessed_as_witness(loader); + let inputs = protocol + .preprocessed + .iter() + .flat_map(|preprocessed| { + let assigned = preprocessed.assigned(); + [assigned.x(), assigned.y()] + .map(|coordinate| loader.scalar_from_assigned(coordinate.native().clone())) + }) + .chain(protocol.transcript_initial_state.clone()) + .collect_vec(); + loader.assert_eq("", &poseidon(loader, &inputs), &preprocessed_digest).unwrap(); + protocol + } else { + snark.protocol.loaded(loader) + }; + + let instances = snark + .instances + .iter() + .map(|instances| { + instances.iter().map(|instance| loader.assign_scalar(*instance)).collect_vec() + }) + .collect_vec(); + let mut transcript = PoseidonTranscript::, _>::new(loader, snark.proof()); + let proof = Plonk::read_proof(svk, &protocol, &instances, &mut transcript).unwrap(); + let accumulators = Plonk::succinct_verify(svk, &protocol, &instances, &proof).unwrap(); + + ( + instances + .into_iter() + .map(|instance| { + instance.into_iter().map(|instance| instance.into_assigned()).collect() + }) + .collect(), + accumulators, + ) + } + + fn select_accumulator<'a>( + loader: &Rc>, + condition: &AssignedValue<'a, Fr>, + lhs: &KzgAccumulator>>, + rhs: &KzgAccumulator>>, + ) -> Result>>, Error> { + let [lhs, rhs]: [_; 2] = [lhs.lhs.assigned(), lhs.rhs.assigned()] + .iter() + .zip([rhs.lhs.assigned(), rhs.rhs.assigned()].iter()) + .map(|(lhs, rhs)| loader.ecc_chip().select(&mut loader.ctx_mut(), lhs, rhs, condition)) + .collect::>() + .try_into() + 
.unwrap(); + Ok(KzgAccumulator::new( + loader.ec_point_from_assigned(lhs), + loader.ec_point_from_assigned(rhs), + )) + } + + fn accumulate<'a>( + loader: &Rc>, + accumulators: Vec>>>, + as_proof: Value<&'_ [u8]>, + ) -> KzgAccumulator>> { + let mut transcript = PoseidonTranscript::, _>::new(loader, as_proof); + let proof = As::read_proof(&Default::default(), &accumulators, &mut transcript).unwrap(); + As::verify(&Default::default(), &accumulators, &proof).unwrap() + } + + #[derive(serde::Serialize, serde::Deserialize)] + pub struct AggregationConfigParams { + pub strategy: halo2_ecc::fields::fp::FpStrategy, + pub degree: u32, + pub num_advice: usize, + pub num_lookup_advice: usize, + pub num_fixed: usize, + pub lookup_bits: usize, + pub limb_bits: usize, + pub num_limbs: usize, + } + + #[derive(Clone)] + pub struct RecursionConfig { + pub base_field_config: halo2_ecc::fields::fp::FpConfig, + pub instance: Column, + } + + impl RecursionConfig { + pub fn configure(meta: &mut ConstraintSystem, params: AggregationConfigParams) -> Self { + assert!( + params.limb_bits == BITS && params.num_limbs == LIMBS, + "For now we fix limb_bits = {}, otherwise change code", + BITS + ); + let base_field_config = halo2_ecc::fields::fp::FpConfig::configure( + meta, + params.strategy, + &[params.num_advice], + &[params.num_lookup_advice], + params.num_fixed, + params.lookup_bits, + params.limb_bits, + params.num_limbs, + halo2_base::utils::modulus::(), + 0, + params.degree as usize, + ); + + let instance = meta.instance_column(); + meta.enable_equality(instance); + + Self { base_field_config, instance } + } + + pub fn gate(&self) -> &halo2_base::gates::flex_gate::FlexGateConfig { + &self.base_field_config.range.gate + } + + pub fn range(&self) -> &halo2_base::gates::range::RangeConfig { + &self.base_field_config.range + } + + pub fn ecc_chip(&self) -> halo2_ecc::ecc::BaseFieldEccChip { + EccChip::construct(self.base_field_config.clone()) + } + } + + #[derive(Clone)] + pub struct RecursionCircuit { + svk: Svk, + default_accumulator: KzgAccumulator, + app: SnarkWitness, + previous: SnarkWitness, + round: usize, + instances: Vec, + as_proof: Value>, + } + + impl RecursionCircuit { + const PREPROCESSED_DIGEST_ROW: usize = 4 * LIMBS; + const INITIAL_STATE_ROW: usize = 4 * LIMBS + 1; + const STATE_ROW: usize = 4 * LIMBS + 2; + const ROUND_ROW: usize = 4 * LIMBS + 3; + + pub fn new( + params: &ParamsKZG, + app: Snark, + previous: Snark, + initial_state: Fr, + state: Fr, + round: usize, + ) -> Self { + let svk = params.get_g()[0].into(); + let default_accumulator = KzgAccumulator::new(params.get_g()[1], params.get_g()[0]); + + let succinct_verify = |snark: &Snark| { + let mut transcript = + PoseidonTranscript::::new(snark.proof.as_slice()); + let proof = + Plonk::read_proof(&svk, &snark.protocol, &snark.instances, &mut transcript) + .unwrap(); + Plonk::succinct_verify(&svk, &snark.protocol, &snark.instances, &proof).unwrap() + }; + + let accumulators = iter::empty() + .chain(succinct_verify(&app)) + .chain((round > 0).then(|| succinct_verify(&previous)).unwrap_or_else(|| { + let num_accumulator = 1 + previous.protocol.accumulator_indices.len(); + vec![default_accumulator.clone(); num_accumulator] + })) + .collect_vec(); + + let (accumulator, as_proof) = { + let mut transcript = PoseidonTranscript::::new(Vec::new()); + let accumulator = + As::create_proof(&Default::default(), &accumulators, &mut transcript, OsRng) + .unwrap(); + (accumulator, transcript.finalize()) + }; + + let preprocessed_digest = { + let inputs = 
previous + .protocol + .preprocessed + .iter() + .flat_map(|preprocessed| [preprocessed.x, preprocessed.y]) + .map(fe_to_fe) + .chain(previous.protocol.transcript_initial_state) + .collect_vec(); + poseidon(&NativeLoader, &inputs) + }; + let instances = + [accumulator.lhs.x, accumulator.lhs.y, accumulator.rhs.x, accumulator.rhs.y] + .into_iter() + .flat_map(fe_to_limbs::<_, _, LIMBS, BITS>) + .chain([preprocessed_digest, initial_state, state, Fr::from(round as u64)]) + .collect(); + + Self { + svk, + default_accumulator, + app: app.into(), + previous: previous.into(), + round, + instances, + as_proof: Value::known(as_proof), + } + } + + fn initial_snark(params: &ParamsKZG, vk: Option<&VerifyingKey>) -> Snark { + let mut snark = gen_dummy_snark::(params, vk); + let g = params.get_g(); + snark.instances = vec![[g[1].x, g[1].y, g[0].x, g[0].y] + .into_iter() + .flat_map(fe_to_limbs::<_, _, LIMBS, BITS>) + .chain([Fr::zero(); 4]) + .collect_vec()]; + snark + } + + fn as_proof(&self) -> Value<&[u8]> { + self.as_proof.as_ref().map(Vec::as_slice) + } + + fn load_default_accumulator<'a>( + &self, + loader: &Rc>, + ) -> Result>>, Error> { + let [lhs, rhs] = + [self.default_accumulator.lhs, self.default_accumulator.rhs].map(|default| { + let assigned = + loader.ecc_chip().assign_constant(&mut loader.ctx_mut(), default).unwrap(); + loader.ec_point_from_assigned(assigned) + }); + Ok(KzgAccumulator::new(lhs, rhs)) + } + } + + impl Circuit for RecursionCircuit { + type Config = RecursionConfig; + type FloorPlanner = SimpleFloorPlanner; + + fn without_witnesses(&self) -> Self { + Self { + svk: self.svk, + default_accumulator: self.default_accumulator.clone(), + app: self.app.without_witnesses(), + previous: self.previous.without_witnesses(), + round: self.round, + instances: self.instances.clone(), + as_proof: Value::unknown(), + } + } + + fn configure(meta: &mut plonk::ConstraintSystem) -> Self::Config { + let path = std::env::var("VERIFY_CONFIG") + .unwrap_or_else(|_| "configs/verify_circuit.config".to_owned()); + let params: AggregationConfigParams = serde_json::from_reader( + File::open(path.as_str()).unwrap_or_else(|err| panic!("{err:?}")), + ) + .unwrap(); + + RecursionConfig::configure(meta, params) + } + + fn synthesize( + &self, + config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + config.range().load_lookup_table(&mut layouter)?; + let max_rows = config.range().gate.max_rows; + let main_gate = config.gate(); + + let mut first_pass = halo2_base::SKIP_FIRST_PASS; // assume using simple floor planner + let mut assigned_instances = Vec::new(); + layouter.assign_region( + || "", + |region| { + if first_pass { + first_pass = false; + return Ok(()); + } + let mut ctx = Context::new( + region, + ContextParams { + max_rows, + num_context_ids: 1, + fixed_columns: config.base_field_config.range.gate.constants.clone(), + }, + ); + + let [preprocessed_digest, initial_state, state, round] = [ + self.instances[Self::PREPROCESSED_DIGEST_ROW], + self.instances[Self::INITIAL_STATE_ROW], + self.instances[Self::STATE_ROW], + self.instances[Self::ROUND_ROW], + ] + .map(|instance| { + main_gate.assign_integer(&mut ctx, Value::known(instance)).unwrap() + }); + let first_round = main_gate.is_zero(&mut ctx, &round); + let not_first_round = main_gate.not(&mut ctx, Existing(&first_round)); + + let loader = Halo2Loader::new(config.ecc_chip(), ctx); + let (mut app_instances, app_accumulators) = + succinct_verify(&self.svk, &loader, &self.app, None); + let (mut previous_instances, 
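// Sketch of the preprocessed-digest binding used above (natively in `RecursionCircuit::new`
// and in-circuit via `loaded_preprocessed_as_witness` in `succinct_verify`): the coordinates of
// every preprocessed commitment plus the transcript initial state are Poseidon-hashed, and that
// hash is exposed/constrained as a public input, so the verifying-key data enters as witnesses
// bound by a digest rather than as circuit constants. Hypothetical helper; `hash` stands for the
// Poseidon sponge used in this example.
fn preprocessed_digest<F: Copy>(
    hash: impl Fn(&[F]) -> F,
    preprocessed_xy: &[F],                // x, y of each preprocessed commitment, mapped into Fr
    transcript_initial_state: Option<F>,
) -> F {
    let inputs: Vec<F> =
        preprocessed_xy.iter().copied().chain(transcript_initial_state).collect();
    hash(&inputs)
}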
previous_accumulators) = succinct_verify( + &self.svk, + &loader, + &self.previous, + Some(preprocessed_digest.clone()), + ); + + let default_accmulator = self.load_default_accumulator(&loader)?; + let previous_accumulators = previous_accumulators + .iter() + .map(|previous_accumulator| { + select_accumulator( + &loader, + &first_round, + &default_accmulator, + previous_accumulator, + ) + }) + .collect::, Error>>()?; + + let KzgAccumulator { lhs, rhs } = accumulate( + &loader, + [app_accumulators, previous_accumulators].concat(), + self.as_proof(), + ); + + let lhs = lhs.into_assigned(); + let rhs = rhs.into_assigned(); + let app_instances = app_instances.pop().unwrap(); + let previous_instances = previous_instances.pop().unwrap(); + + let mut ctx = loader.ctx_mut(); + for (lhs, rhs) in [ + // Propagate preprocessed_digest + ( + &main_gate.mul( + &mut ctx, + Existing(&preprocessed_digest), + Existing(¬_first_round), + ), + &previous_instances[Self::PREPROCESSED_DIGEST_ROW], + ), + // Propagate initial_state + ( + &main_gate.mul( + &mut ctx, + Existing(&initial_state), + Existing(¬_first_round), + ), + &previous_instances[Self::INITIAL_STATE_ROW], + ), + // Verify initial_state is same as the first application snark + ( + &main_gate.mul( + &mut ctx, + Existing(&initial_state), + Existing(&first_round), + ), + &main_gate.mul( + &mut ctx, + Existing(&app_instances[0]), + Existing(&first_round), + ), + ), + // Verify current state is same as the current application snark + (&state, &app_instances[1]), + // Verify previous state is same as the current application snark + ( + &main_gate.mul( + &mut ctx, + Existing(&app_instances[0]), + Existing(¬_first_round), + ), + &previous_instances[Self::STATE_ROW], + ), + // Verify round is increased by 1 when not at first round + ( + &round, + &main_gate.add( + &mut ctx, + Existing(¬_first_round), + Existing(&previous_instances[Self::ROUND_ROW]), + ), + ), + ] { + ctx.region.constrain_equal(lhs.cell(), rhs.cell()); + } + + // IMPORTANT: + config.base_field_config.finalize(&mut ctx); + #[cfg(feature = "display")] + dbg!(ctx.total_advice); + #[cfg(feature = "display")] + println!("Advice columns used: {}", ctx.advice_alloc[0][0].0 + 1); + + assigned_instances.extend( + [lhs.x(), lhs.y(), rhs.x(), rhs.y()] + .into_iter() + .flat_map(|coordinate| coordinate.limbs()) + .chain([preprocessed_digest, initial_state, state, round].iter()) + .map(|assigned| assigned.cell()), + ); + Ok(()) + }, + )?; + + assert_eq!(assigned_instances.len(), 4 * LIMBS + 4); + for (row, limb) in assigned_instances.into_iter().enumerate() { + layouter.constrain_instance(limb, config.instance, row); + } + + Ok(()) + } + } + + impl CircuitExt for RecursionCircuit { + fn num_instance() -> Vec { + // [..lhs, ..rhs, preprocessed_digest, initial_state, state, round] + vec![4 * LIMBS + 4] + } + + fn instances(&self) -> Vec> { + vec![self.instances.clone()] + } + + fn accumulator_indices() -> Option> { + Some((0..4 * LIMBS).map(|idx| (0, idx)).collect()) + } + + fn selectors(config: &Self::Config) -> Vec { + config.base_field_config.range.gate.basic_gates[0] + .iter() + .map(|gate| gate.q_enable) + .collect() + } + } + + pub fn gen_recursion_pk>( + recursion_params: &ParamsKZG, + app_params: &ParamsKZG, + app_vk: &VerifyingKey, + ) -> ProvingKey { + let recursion = RecursionCircuit::new( + recursion_params, + gen_dummy_snark::(app_params, Some(app_vk)), + RecursionCircuit::initial_snark(recursion_params, None), + Fr::zero(), + Fr::zero(), + 0, + ); + gen_pk(recursion_params, &recursion) + } + + 
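// The "propagate" constraints in `synthesize` above all have the shape
//     value * not_first_round == previous_instance
// so on the first round (not_first_round = 0) they degenerate to 0 == previous_instance, which
// holds because `initial_snark` fills those instance rows with Fr::zero(). Minimal sketch over
// plain integers:
fn propagated_constraint_holds(value: u64, previous_instance: u64, not_first_round: u64) -> bool {
    value * not_first_round == previous_instance
}
// propagated_constraint_holds(digest, 0, 0)      // first round: dummy previous instance is zero
// propagated_constraint_holds(digest, digest, 1) // later rounds: value must match the previous snark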
pub fn gen_recursion_snark + StateTransition>( + app_params: &ParamsKZG, + recursion_params: &ParamsKZG, + app_pk: &ProvingKey, + recursion_pk: &ProvingKey, + initial_state: Fr, + inputs: Vec, + ) -> (Fr, Snark) { + let mut state = initial_state; + let mut app = ConcreteCircuit::new(state); + let mut previous = + RecursionCircuit::initial_snark(recursion_params, Some(recursion_pk.get_vk())); + for (round, input) in inputs.into_iter().enumerate() { + state = app.state_transition(input); + println!("Generate app snark"); + let app_snark = gen_snark(app_params, app_pk, app); + let recursion = RecursionCircuit::new( + recursion_params, + app_snark, + previous, + initial_state, + state, + round, + ); + println!("Generate recursion snark"); + previous = gen_snark(recursion_params, recursion_pk, recursion); + app = ConcreteCircuit::new(state); + } + (state, previous) + } +} + +fn main() { + let app_params = gen_srs(3); + let recursion_config: AggregationConfigParams = + serde_json::from_reader(fs::File::open("configs/verify_circuit.config").unwrap()).unwrap(); + let k = recursion_config.degree; + let recursion_params = gen_srs(k); + + let app_pk = gen_pk(&app_params, &application::Square::default()); + + let pk_time = start_timer!(|| "Generate recursion pk"); + let recursion_pk = recursion::gen_recursion_pk::( + &recursion_params, + &app_params, + app_pk.get_vk(), + ); + end_timer!(pk_time); + + let num_round = 1; + let pf_time = start_timer!(|| "Generate full recursive snark"); + let (final_state, snark) = recursion::gen_recursion_snark::( + &app_params, + &recursion_params, + &app_pk, + &recursion_pk, + Fr::from(2u64), + vec![(); num_round], + ); + end_timer!(pf_time); + assert_eq!(final_state, Fr::from(2u64).pow(&[1 << num_round, 0, 0, 0])); + + let accept = { + let svk = recursion_params.get_g()[0].into(); + let dk = (recursion_params.g2(), recursion_params.s_g2()).into(); + let mut transcript = PoseidonTranscript::::new(snark.proof.as_slice()); + let proof = + Plonk::read_proof(&svk, &snark.protocol, &snark.instances, &mut transcript).unwrap(); + Plonk::verify(&svk, &dk, &snark.protocol, &snark.instances, &proof).unwrap() + }; + assert!(accept) +} diff --git a/src/cost.rs b/snark-verifier/src/cost.rs similarity index 100% rename from src/cost.rs rename to snark-verifier/src/cost.rs diff --git a/src/lib.rs b/snark-verifier/src/lib.rs similarity index 73% rename from src/lib.rs rename to snark-verifier/src/lib.rs index 4e740638..a6e4b3e6 100644 --- a/src/lib.rs +++ b/snark-verifier/src/lib.rs @@ -1,6 +1,3 @@ -#![feature(int_log)] -#![feature(int_roundings)] -#![feature(assert_matches)] #![allow(clippy::type_complexity)] #![allow(clippy::too_many_arguments)] #![allow(clippy::upper_case_acronyms)] @@ -12,6 +9,10 @@ pub mod system; pub mod util; pub mod verifier; +pub(crate) use halo2_base::halo2_proofs; +pub(crate) use halo2_base::poseidon; +pub(crate) use halo2_proofs::halo2curves as halo2_curves; + #[derive(Clone, Debug)] pub enum Error { InvalidInstances, @@ -23,10 +24,14 @@ pub enum Error { } #[derive(Clone, Debug)] -pub struct Protocol { +pub struct Protocol +where + C: util::arithmetic::CurveAffine, + L: loader::Loader, +{ // Common description pub domain: util::arithmetic::Domain, - pub preprocessed: Vec, + pub preprocessed: Vec, pub num_instance: Vec, pub num_witness: Vec, pub num_challenge: Vec, @@ -34,7 +39,7 @@ pub struct Protocol { pub queries: Vec, pub quotient: util::protocol::QuotientPolynomial, // Minor customization - pub transcript_initial_state: Option, + pub 
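// The assertion in `main` above follows from the toy state transition: each round squares the
// state, so after n rounds the final state is initial^(2^n); for initial = 2 and num_round = 1
// that is 2^(1 << 1) = 4. Plain-integer sketch:
fn expected_final_state(initial: u128, rounds: u32) -> u128 {
    (0..rounds).fold(initial, |state, _| state * state)
}
// expected_final_state(2, 1) == 4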
transcript_initial_state: Option, pub instance_committing_key: Option>, pub linearization: Option, pub accumulator_indices: Vec>, diff --git a/src/loader.rs b/snark-verifier/src/loader.rs similarity index 75% rename from src/loader.rs rename to snark-verifier/src/loader.rs index 2308b08b..297390d0 100644 --- a/src/loader.rs +++ b/snark-verifier/src/loader.rs @@ -5,7 +5,7 @@ use crate::{ }, Error, }; -use std::{fmt::Debug, iter}; +use std::{borrow::Cow, fmt::Debug, iter, ops::Deref}; pub mod native; @@ -19,10 +19,6 @@ pub trait LoadedEcPoint: Clone + Debug + PartialEq { type Loader: Loader; fn loader(&self) -> &Self::Loader; - - fn multi_scalar_multiplication( - pairs: impl IntoIterator>::LoadedScalar, Self)>, - ) -> Self; } pub trait LoadedScalar: Clone + Debug + PartialEq + FieldOps { @@ -34,23 +30,10 @@ pub trait LoadedScalar: Clone + Debug + PartialEq + FieldOps { self.clone() * self } - fn mul_add(a: &Self, b: &Self, c: &Self) -> Self; - - fn mul_add_constant(a: &Self, b: &Self, c: &F) -> Self; - fn invert(&self) -> Option { FieldOps::invert(self) } - fn batch_invert<'a>(values: impl IntoIterator) - where - Self: 'a, - { - values - .into_iter() - .for_each(|value| *value = LoadedScalar::invert(value).unwrap_or_else(|| value.clone())) - } - fn pow_const(&self, mut exp: u64) -> Self { assert!(exp > 0); @@ -101,6 +84,12 @@ pub trait EcPointLoader { lhs: &Self::LoadedEcPoint, rhs: &Self::LoadedEcPoint, ) -> Result<(), Error>; + + fn multi_scalar_multiplication( + pairs: &[(&Self::LoadedScalar, &Self::LoadedEcPoint)], + ) -> Self::LoadedEcPoint + where + Self: ScalarLoader; } pub trait ScalarLoader { @@ -123,7 +112,7 @@ pub trait ScalarLoader { rhs: &Self::LoadedScalar, ) -> Result<(), Error>; - fn sum_with_coeff_and_constant( + fn sum_with_coeff_and_const( &self, values: &[(F, &Self::LoadedScalar)], constant: F, @@ -134,19 +123,24 @@ pub trait ScalarLoader { let loader = values.first().unwrap().1.loader(); iter::empty() - .chain(if constant == F::zero() { None } else { Some(loader.load_const(&constant)) }) + .chain(if constant == F::zero() { + None + } else { + Some(Cow::Owned(loader.load_const(&constant))) + }) .chain(values.iter().map(|&(coeff, value)| { if coeff == F::one() { - value.clone() + Cow::Borrowed(value) } else { - loader.load_const(&coeff) * value + Cow::Owned(loader.load_const(&coeff) * value) } })) - .reduce(|acc, term| acc + term) + .reduce(|acc, term| Cow::Owned(acc.into_owned() + term.deref())) .unwrap() + .into_owned() } - fn sum_products_with_coeff_and_constant( + fn sum_products_with_coeff_and_const( &self, values: &[(F, &Self::LoadedScalar, &Self::LoadedScalar)], constant: F, @@ -157,7 +151,11 @@ pub trait ScalarLoader { let loader = values.first().unwrap().1.loader(); iter::empty() - .chain(if constant == F::zero() { None } else { Some(loader.load_const(&constant)) }) + .chain(if constant == F::zero() { + None + } else { + Some(loader.load_const(&constant)) + }) .chain(values.iter().map(|&(coeff, lhs, rhs)| { if coeff == F::one() { lhs.clone() * rhs @@ -170,39 +168,61 @@ pub trait ScalarLoader { } fn sum_with_coeff(&self, values: &[(F, &Self::LoadedScalar)]) -> Self::LoadedScalar { - self.sum_with_coeff_and_constant(values, F::zero()) + self.sum_with_coeff_and_const(values, F::zero()) + } + + fn sum_with_const(&self, values: &[&Self::LoadedScalar], constant: F) -> Self::LoadedScalar { + self.sum_with_coeff_and_const( + &values.iter().map(|&value| (F::one(), value)).collect_vec(), + constant, + ) + } + + fn sum(&self, values: &[&Self::LoadedScalar]) -> 
Self::LoadedScalar { + self.sum_with_const(values, F::zero()) } fn sum_products_with_coeff( &self, values: &[(F, &Self::LoadedScalar, &Self::LoadedScalar)], ) -> Self::LoadedScalar { - self.sum_products_with_coeff_and_constant(values, F::zero()) + self.sum_products_with_coeff_and_const(values, F::zero()) } - fn sum_products( + fn sum_products_with_const( &self, values: &[(&Self::LoadedScalar, &Self::LoadedScalar)], + constant: F, ) -> Self::LoadedScalar { - self.sum_products_with_coeff_and_constant( - &values.iter().map(|&(lhs, rhs)| (F::one(), lhs, rhs)).collect_vec(), - F::zero(), - ) - } - - fn sum_with_const(&self, values: &[&Self::LoadedScalar], constant: F) -> Self::LoadedScalar { - self.sum_with_coeff_and_constant( - &values.iter().map(|&value| (F::one(), value)).collect_vec(), + self.sum_products_with_coeff_and_const( + &values + .iter() + .map(|&(lhs, rhs)| (F::one(), lhs, rhs)) + .collect_vec(), constant, ) } - fn sum(&self, values: &[&Self::LoadedScalar]) -> Self::LoadedScalar { - self.sum_with_const(values, F::zero()) + fn sum_products( + &self, + values: &[(&Self::LoadedScalar, &Self::LoadedScalar)], + ) -> Self::LoadedScalar { + self.sum_products_with_const(values, F::zero()) } fn product(&self, values: &[&Self::LoadedScalar]) -> Self::LoadedScalar { - values.iter().fold(self.load_one(), |acc, value| acc * *value) + values + .iter() + .fold(self.load_one(), |acc, value| acc * *value) + } + + fn batch_invert<'a>(values: impl IntoIterator) + where + Self::LoadedScalar: 'a, + { + values + .into_iter() + .for_each(|value| *value = LoadedScalar::invert(value).unwrap_or_else(|| value.clone())) } } diff --git a/src/loader/evm.rs b/snark-verifier/src/loader/evm.rs similarity index 57% rename from src/loader/evm.rs rename to snark-verifier/src/loader/evm.rs index 7a07670c..263da0e2 100644 --- a/src/loader/evm.rs +++ b/snark-verifier/src/loader/evm.rs @@ -6,7 +6,10 @@ mod util; mod test; pub use loader::{EcPoint, EvmLoader, Scalar}; -pub use util::{encode_calldata, estimate_gas, fe_to_u256, modulus, u256_to_fe, MemoryChunk}; +pub use util::{ + compile_yul, encode_calldata, estimate_gas, fe_to_u256, modulus, u256_to_fe, ExecutorBuilder, + MemoryChunk, +}; pub use ethereum_types::U256; diff --git a/snark-verifier/src/loader/evm/code.rs b/snark-verifier/src/loader/evm/code.rs new file mode 100644 index 00000000..840d1e67 --- /dev/null +++ b/snark-verifier/src/loader/evm/code.rs @@ -0,0 +1,75 @@ +pub enum Precompiled { + BigModExp = 0x05, + Bn254Add = 0x6, + Bn254ScalarMul = 0x7, + Bn254Pairing = 0x8, +} + +#[derive(Clone, Debug)] +pub struct YulCode { + // runtime code area + runtime: String, +} + +impl YulCode { + pub fn new() -> Self { + YulCode { + runtime: String::new(), + } + } + + pub fn code(&self, base_modulus: String, scalar_modulus: String) -> String { + format!( + " + object \"plonk_verifier\" {{ + code {{ + function allocate(size) -> ptr {{ + ptr := mload(0x40) + if eq(ptr, 0) {{ ptr := 0x60 }} + mstore(0x40, add(ptr, size)) + }} + let size := datasize(\"Runtime\") + let offset := allocate(size) + datacopy(offset, dataoffset(\"Runtime\"), size) + return(offset, size) + }} + object \"Runtime\" {{ + code {{ + let success:bool := true + let f_p := {base_modulus} + let f_q := {scalar_modulus} + function validate_ec_point(x, y) -> valid:bool {{ + {{ + let x_lt_p:bool := lt(x, {base_modulus}) + let y_lt_p:bool := lt(y, {base_modulus}) + valid := and(x_lt_p, y_lt_p) + }} + {{ + let x_is_zero:bool := eq(x, 0) + let y_is_zero:bool := eq(y, 0) + let x_or_y_is_zero:bool := or(x_is_zero, 
y_is_zero) + let x_and_y_is_not_zero:bool := not(x_or_y_is_zero) + valid := and(x_and_y_is_not_zero, valid) + }} + {{ + let y_square := mulmod(y, y, {base_modulus}) + let x_square := mulmod(x, x, {base_modulus}) + let x_cube := mulmod(x_square, x, {base_modulus}) + let x_cube_plus_3 := addmod(x_cube, 3, {base_modulus}) + let y_square_eq_x_cube_plus_3:bool := eq(x_cube_plus_3, y_square) + valid := and(y_square_eq_x_cube_plus_3, valid) + }} + }} + {} + }} + }} + }}", + self.runtime + ) + } + + pub fn runtime_append(&mut self, mut code: String) { + code.push('\n'); + self.runtime.push_str(&code); + } +} diff --git a/snark-verifier/src/loader/evm/loader.rs b/snark-verifier/src/loader/evm/loader.rs new file mode 100644 index 00000000..db15c8d7 --- /dev/null +++ b/snark-verifier/src/loader/evm/loader.rs @@ -0,0 +1,910 @@ +use crate::{ + loader::{ + evm::{ + code::{Precompiled, YulCode}, + fe_to_u256, modulus, u256_to_fe, + }, + EcPointLoader, LoadedEcPoint, LoadedScalar, Loader, ScalarLoader, + }, + util::{ + arithmetic::{CurveAffine, FieldOps, PrimeField}, + Itertools, + }, + Error, +}; +use ethereum_types::{U256, U512}; +use hex; +use std::{ + cell::RefCell, + collections::HashMap, + fmt::{self, Debug}, + iter, + ops::{Add, AddAssign, DerefMut, Mul, MulAssign, Neg, Sub, SubAssign}, + rc::Rc, +}; + +#[derive(Clone, Debug)] +pub enum Value { + Constant(T), + Memory(usize), + Negated(Box>), + Sum(Box>, Box>), + Product(Box>, Box>), +} + +impl PartialEq for Value { + fn eq(&self, other: &Self) -> bool { + self.identifier() == other.identifier() + } +} + +impl Value { + fn identifier(&self) -> String { + match &self { + Value::Constant(_) | Value::Memory(_) => format!("{self:?}"), + Value::Negated(value) => format!("-({value:?})"), + Value::Sum(lhs, rhs) => format!("({lhs:?} + {rhs:?})"), + Value::Product(lhs, rhs) => format!("({lhs:?} * {rhs:?})"), + } + } +} + +#[derive(Clone, Debug)] +pub struct EvmLoader { + base_modulus: U256, + scalar_modulus: U256, + code: RefCell, + ptr: RefCell, + cache: RefCell>, + #[cfg(test)] + gas_metering_ids: RefCell>, +} + +fn hex_encode_u256(value: &U256) -> String { + let mut bytes = [0; 32]; + value.to_big_endian(&mut bytes); + format!("0x{}", hex::encode(bytes)) +} + +impl EvmLoader { + pub fn new() -> Rc + where + Base: PrimeField, + Scalar: PrimeField, + { + let base_modulus = modulus::(); + let scalar_modulus = modulus::(); + let code = YulCode::new(); + + Rc::new(Self { + base_modulus, + scalar_modulus, + code: RefCell::new(code), + ptr: Default::default(), + cache: Default::default(), + #[cfg(test)] + gas_metering_ids: RefCell::new(Vec::new()), + }) + } + + pub fn yul_code(self: &Rc) -> String { + let code = " + if not(success) { revert(0, 0) } + return(0, 0)" + .to_string(); + self.code.borrow_mut().runtime_append(code); + self.code.borrow().code( + hex_encode_u256(&self.base_modulus), + hex_encode_u256(&self.scalar_modulus), + ) + } + + pub fn allocate(self: &Rc, size: usize) -> usize { + let ptr = *self.ptr.borrow(); + *self.ptr.borrow_mut() += size; + ptr + } + + pub(crate) fn ptr(&self) -> usize { + *self.ptr.borrow() + } + + pub(crate) fn code_mut(&self) -> impl DerefMut + '_ { + self.code.borrow_mut() + } + + fn push(self: &Rc, scalar: &Scalar) -> String { + match scalar.value.clone() { + Value::Constant(constant) => { + format!("{constant}") + } + Value::Memory(ptr) => { + format!("mload({ptr:#x})") + } + Value::Negated(value) => { + let v = self.push(&self.scalar(*value)); + format!("sub(f_q, {v})") + } + Value::Sum(lhs, rhs) => { + let lhs = 
self.push(&self.scalar(*lhs)); + let rhs = self.push(&self.scalar(*rhs)); + format!("addmod({lhs}, {rhs}, f_q)") + } + Value::Product(lhs, rhs) => { + let lhs = self.push(&self.scalar(*lhs)); + let rhs = self.push(&self.scalar(*rhs)); + format!("mulmod({lhs}, {rhs}, f_q)") + } + } + } + + pub fn calldataload_scalar(self: &Rc, offset: usize) -> Scalar { + let ptr = self.allocate(0x20); + let code = format!("mstore({ptr:#x}, mod(calldataload({offset:#x}), f_q))"); + self.code.borrow_mut().runtime_append(code); + self.scalar(Value::Memory(ptr)) + } + + pub fn calldataload_ec_point(self: &Rc, offset: usize) -> EcPoint { + let x_ptr = self.allocate(0x40); + let y_ptr = x_ptr + 0x20; + let x_cd_ptr = offset; + let y_cd_ptr = offset + 0x20; + let validate_code = self.validate_ec_point(); + let code = format!( + " + {{ + let x := calldataload({x_cd_ptr:#x}) + mstore({x_ptr:#x}, x) + let y := calldataload({y_cd_ptr:#x}) + mstore({y_ptr:#x}, y) + {validate_code} + }}" + ); + self.code.borrow_mut().runtime_append(code); + self.ec_point(Value::Memory(x_ptr)) + } + + pub fn ec_point_from_limbs( + self: &Rc, + x_limbs: [&Scalar; LIMBS], + y_limbs: [&Scalar; LIMBS], + ) -> EcPoint { + let ptr = self.allocate(0x40); + let mut code = String::new(); + for (idx, limb) in x_limbs.iter().enumerate() { + let limb_i = self.push(limb); + let shift = idx * BITS; + if idx == 0 { + code.push_str(format!("let x := {limb_i}\n").as_str()); + } else { + code.push_str(format!("x := add(x, shl({shift}, {limb_i}))\n").as_str()); + } + } + let x_ptr = ptr; + code.push_str(format!("mstore({x_ptr}, x)\n").as_str()); + for (idx, limb) in y_limbs.iter().enumerate() { + let limb_i = self.push(limb); + let shift = idx * BITS; + if idx == 0 { + code.push_str(format!("let y := {limb_i}\n").as_str()); + } else { + code.push_str(format!("y := add(y, shl({shift}, {limb_i}))\n").as_str()); + } + } + let y_ptr = ptr + 0x20; + code.push_str(format!("mstore({y_ptr}, y)\n").as_str()); + let validate_code = self.validate_ec_point(); + let code = format!( + "{{ + {code} + {validate_code} + }}" + ); + self.code.borrow_mut().runtime_append(code); + self.ec_point(Value::Memory(ptr)) + } + + fn validate_ec_point(self: &Rc) -> String { + "success := and(validate_ec_point(x, y), success)".to_string() + } + + pub(crate) fn scalar(self: &Rc, value: Value) -> Scalar { + let value = if matches!( + value, + Value::Constant(_) | Value::Memory(_) | Value::Negated(_) + ) { + value + } else { + let identifier = value.identifier(); + let some_ptr = self.cache.borrow().get(&identifier).cloned(); + let ptr = if let Some(ptr) = some_ptr { + ptr + } else { + let v = self.push(&Scalar { + loader: self.clone(), + value, + }); + let ptr = self.allocate(0x20); + self.code + .borrow_mut() + .runtime_append(format!("mstore({ptr:#x}, {v})")); + self.cache.borrow_mut().insert(identifier, ptr); + ptr + }; + Value::Memory(ptr) + }; + Scalar { + loader: self.clone(), + value, + } + } + + fn ec_point(self: &Rc, value: Value<(U256, U256)>) -> EcPoint { + EcPoint { + loader: self.clone(), + value, + } + } + + pub fn keccak256(self: &Rc, ptr: usize, len: usize) -> usize { + let hash_ptr = self.allocate(0x20); + let code = format!("mstore({hash_ptr:#x}, keccak256({ptr:#x}, {len}))"); + self.code.borrow_mut().runtime_append(code); + hash_ptr + } + + pub fn copy_scalar(self: &Rc, scalar: &Scalar, ptr: usize) { + let scalar = self.push(scalar); + self.code + .borrow_mut() + .runtime_append(format!("mstore({ptr:#x}, {scalar})")); + } + + pub fn dup_scalar(self: &Rc, scalar: 
&Scalar) -> Scalar { + let ptr = self.allocate(0x20); + self.copy_scalar(scalar, ptr); + self.scalar(Value::Memory(ptr)) + } + + pub fn dup_ec_point(self: &Rc, value: &EcPoint) -> EcPoint { + let ptr = self.allocate(0x40); + match value.value { + Value::Constant((x, y)) => { + let x_ptr = ptr; + let y_ptr = ptr + 0x20; + let x = hex_encode_u256(&x); + let y = hex_encode_u256(&y); + let code = format!( + "mstore({x_ptr:#x}, {x}) + mstore({y_ptr:#x}, {y})" + ); + self.code.borrow_mut().runtime_append(code); + } + Value::Memory(src_ptr) => { + let x_ptr = ptr; + let y_ptr = ptr + 0x20; + let src_x = src_ptr; + let src_y = src_ptr + 0x20; + let code = format!( + "mstore({x_ptr:#x}, mload({src_x:#x})) + mstore({y_ptr:#x}, mload({src_y:#x}))" + ); + self.code.borrow_mut().runtime_append(code); + } + Value::Negated(_) | Value::Sum(_, _) | Value::Product(_, _) => { + unreachable!() + } + } + self.ec_point(Value::Memory(ptr)) + } + + fn staticcall(self: &Rc, precompile: Precompiled, cd_ptr: usize, rd_ptr: usize) { + let (cd_len, rd_len) = match precompile { + Precompiled::BigModExp => (0xc0, 0x20), + Precompiled::Bn254Add => (0x80, 0x40), + Precompiled::Bn254ScalarMul => (0x60, 0x40), + Precompiled::Bn254Pairing => (0x180, 0x20), + }; + let a = precompile as usize; + let code = format!("success := and(eq(staticcall(gas(), {a:#x}, {cd_ptr:#x}, {cd_len:#x}, {rd_ptr:#x}, {rd_len:#x}), 1), success)"); + self.code.borrow_mut().runtime_append(code); + } + + fn invert(self: &Rc, scalar: &Scalar) -> Scalar { + let rd_ptr = self.allocate(0x20); + let [cd_ptr, ..] = [ + &self.scalar(Value::Constant(0x20.into())), + &self.scalar(Value::Constant(0x20.into())), + &self.scalar(Value::Constant(0x20.into())), + scalar, + &self.scalar(Value::Constant(self.scalar_modulus - 2)), + &self.scalar(Value::Constant(self.scalar_modulus)), + ] + .map(|value| self.dup_scalar(value).ptr()); + self.staticcall(Precompiled::BigModExp, cd_ptr, rd_ptr); + self.scalar(Value::Memory(rd_ptr)) + } + + fn ec_point_add(self: &Rc, lhs: &EcPoint, rhs: &EcPoint) -> EcPoint { + let rd_ptr = self.dup_ec_point(lhs).ptr(); + self.dup_ec_point(rhs); + self.staticcall(Precompiled::Bn254Add, rd_ptr, rd_ptr); + self.ec_point(Value::Memory(rd_ptr)) + } + + fn ec_point_scalar_mul(self: &Rc, ec_point: &EcPoint, scalar: &Scalar) -> EcPoint { + let rd_ptr = self.dup_ec_point(ec_point).ptr(); + self.dup_scalar(scalar); + self.staticcall(Precompiled::Bn254ScalarMul, rd_ptr, rd_ptr); + self.ec_point(Value::Memory(rd_ptr)) + } + + pub fn pairing( + self: &Rc, + lhs: &EcPoint, + g2: (U256, U256, U256, U256), + rhs: &EcPoint, + minus_s_g2: (U256, U256, U256, U256), + ) { + let rd_ptr = self.dup_ec_point(lhs).ptr(); + self.allocate(0x80); + let g2_0 = hex_encode_u256(&g2.0); + let g2_0_ptr = rd_ptr + 0x40; + let g2_1 = hex_encode_u256(&g2.1); + let g2_1_ptr = rd_ptr + 0x60; + let g2_2 = hex_encode_u256(&g2.2); + let g2_2_ptr = rd_ptr + 0x80; + let g2_3 = hex_encode_u256(&g2.3); + let g2_3_ptr = rd_ptr + 0xa0; + let code = format!( + "mstore({g2_0_ptr:#x}, {g2_0}) + mstore({g2_1_ptr:#x}, {g2_1}) + mstore({g2_2_ptr:#x}, {g2_2}) + mstore({g2_3_ptr:#x}, {g2_3})" + ); + self.code.borrow_mut().runtime_append(code); + self.dup_ec_point(rhs); + self.allocate(0x80); + let minus_s_g2_0 = hex_encode_u256(&minus_s_g2.0); + let minus_s_g2_0_ptr = rd_ptr + 0x100; + let minus_s_g2_1 = hex_encode_u256(&minus_s_g2.1); + let minus_s_g2_1_ptr = rd_ptr + 0x120; + let minus_s_g2_2 = hex_encode_u256(&minus_s_g2.2); + let minus_s_g2_2_ptr = rd_ptr + 0x140; + let minus_s_g2_3 = 
hex_encode_u256(&minus_s_g2.3); + let minus_s_g2_3_ptr = rd_ptr + 0x160; + let code = format!( + "mstore({minus_s_g2_0_ptr:#x}, {minus_s_g2_0}) + mstore({minus_s_g2_1_ptr:#x}, {minus_s_g2_1}) + mstore({minus_s_g2_2_ptr:#x}, {minus_s_g2_2}) + mstore({minus_s_g2_3_ptr:#x}, {minus_s_g2_3})" + ); + self.code.borrow_mut().runtime_append(code); + self.staticcall(Precompiled::Bn254Pairing, rd_ptr, rd_ptr); + let code = format!("success := and(eq(mload({rd_ptr:#x}), 1), success)"); + self.code.borrow_mut().runtime_append(code); + } + + fn add(self: &Rc, lhs: &Scalar, rhs: &Scalar) -> Scalar { + if let (Value::Constant(lhs), Value::Constant(rhs)) = (&lhs.value, &rhs.value) { + let out = (U512::from(lhs) + U512::from(rhs)) % U512::from(self.scalar_modulus); + return self.scalar(Value::Constant(out.try_into().unwrap())); + } + + self.scalar(Value::Sum( + Box::new(lhs.value.clone()), + Box::new(rhs.value.clone()), + )) + } + + fn sub(self: &Rc, lhs: &Scalar, rhs: &Scalar) -> Scalar { + if rhs.is_const() { + return self.add(lhs, &self.neg(rhs)); + } + + self.scalar(Value::Sum( + Box::new(lhs.value.clone()), + Box::new(Value::Negated(Box::new(rhs.value.clone()))), + )) + } + + fn mul(self: &Rc, lhs: &Scalar, rhs: &Scalar) -> Scalar { + if let (Value::Constant(lhs), Value::Constant(rhs)) = (&lhs.value, &rhs.value) { + let out = (U512::from(lhs) * U512::from(rhs)) % U512::from(self.scalar_modulus); + return self.scalar(Value::Constant(out.try_into().unwrap())); + } + + self.scalar(Value::Product( + Box::new(lhs.value.clone()), + Box::new(rhs.value.clone()), + )) + } + + fn neg(self: &Rc, scalar: &Scalar) -> Scalar { + if let Value::Constant(constant) = scalar.value { + return self.scalar(Value::Constant(self.scalar_modulus - constant)); + } + + self.scalar(Value::Negated(Box::new(scalar.value.clone()))) + } +} + +#[cfg(test)] +impl EvmLoader { + fn start_gas_metering(self: &Rc, identifier: &str) { + self.gas_metering_ids + .borrow_mut() + .push(identifier.to_string()); + let code = format!("let {identifier} := gas()"); + self.code.borrow_mut().runtime_append(code); + } + + fn end_gas_metering(self: &Rc) { + let code = format!( + "log1(0, 0, sub({}, gas()))", + self.gas_metering_ids.borrow().last().unwrap() + ); + self.code.borrow_mut().runtime_append(code); + } + + pub fn print_gas_metering(self: &Rc, costs: Vec) { + for (identifier, cost) in self.gas_metering_ids.borrow().iter().zip(costs) { + println!("{}: {}", identifier, cost); + } + } +} + +#[derive(Clone)] +pub struct EcPoint { + loader: Rc, + value: Value<(U256, U256)>, +} + +impl EcPoint { + pub(crate) fn loader(&self) -> &Rc { + &self.loader + } + + pub(crate) fn value(&self) -> Value<(U256, U256)> { + self.value.clone() + } + + pub(crate) fn ptr(&self) -> usize { + match self.value { + Value::Memory(ptr) => ptr, + _ => unreachable!(), + } + } +} + +impl Debug for EcPoint { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("EcPoint") + .field("value", &self.value) + .finish() + } +} + +impl PartialEq for EcPoint { + fn eq(&self, other: &Self) -> bool { + self.value == other.value + } +} + +impl LoadedEcPoint for EcPoint +where + C: CurveAffine, + C::ScalarExt: PrimeField, +{ + type Loader = Rc; + + fn loader(&self) -> &Rc { + &self.loader + } +} + +#[derive(Clone)] +pub struct Scalar { + loader: Rc, + value: Value, +} + +impl Scalar { + pub(crate) fn loader(&self) -> &Rc { + &self.loader + } + + pub(crate) fn value(&self) -> Value { + self.value.clone() + } + + pub(crate) fn is_const(&self) -> bool { + 
matches!(self.value, Value::Constant(_)) + } + + pub(crate) fn ptr(&self) -> usize { + match self.value { + Value::Memory(ptr) => ptr, + _ => *self + .loader + .cache + .borrow() + .get(&self.value.identifier()) + .unwrap(), + } + } +} + +impl Debug for Scalar { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Scalar") + .field("value", &self.value) + .finish() + } +} + +impl Add for Scalar { + type Output = Self; + + fn add(self, rhs: Self) -> Self { + self.loader.add(&self, &rhs) + } +} + +impl Sub for Scalar { + type Output = Self; + + fn sub(self, rhs: Self) -> Self { + self.loader.sub(&self, &rhs) + } +} + +impl Mul for Scalar { + type Output = Self; + + fn mul(self, rhs: Self) -> Self { + self.loader.mul(&self, &rhs) + } +} + +impl Neg for Scalar { + type Output = Self; + + fn neg(self) -> Self { + self.loader.neg(&self) + } +} + +impl<'a> Add<&'a Self> for Scalar { + type Output = Self; + + fn add(self, rhs: &'a Self) -> Self { + self.loader.add(&self, rhs) + } +} + +impl<'a> Sub<&'a Self> for Scalar { + type Output = Self; + + fn sub(self, rhs: &'a Self) -> Self { + self.loader.sub(&self, rhs) + } +} + +impl<'a> Mul<&'a Self> for Scalar { + type Output = Self; + + fn mul(self, rhs: &'a Self) -> Self { + self.loader.mul(&self, rhs) + } +} + +impl AddAssign for Scalar { + fn add_assign(&mut self, rhs: Self) { + *self = self.loader.add(self, &rhs); + } +} + +impl SubAssign for Scalar { + fn sub_assign(&mut self, rhs: Self) { + *self = self.loader.sub(self, &rhs); + } +} + +impl MulAssign for Scalar { + fn mul_assign(&mut self, rhs: Self) { + *self = self.loader.mul(self, &rhs); + } +} + +impl<'a> AddAssign<&'a Self> for Scalar { + fn add_assign(&mut self, rhs: &'a Self) { + *self = self.loader.add(self, rhs); + } +} + +impl<'a> SubAssign<&'a Self> for Scalar { + fn sub_assign(&mut self, rhs: &'a Self) { + *self = self.loader.sub(self, rhs); + } +} + +impl<'a> MulAssign<&'a Self> for Scalar { + fn mul_assign(&mut self, rhs: &'a Self) { + *self = self.loader.mul(self, rhs); + } +} + +impl FieldOps for Scalar { + fn invert(&self) -> Option { + Some(self.loader.invert(self)) + } +} + +impl PartialEq for Scalar { + fn eq(&self, other: &Self) -> bool { + self.value == other.value + } +} + +impl> LoadedScalar for Scalar { + type Loader = Rc; + + fn loader(&self) -> &Self::Loader { + &self.loader + } +} + +impl EcPointLoader for Rc +where + C: CurveAffine, + C::Scalar: PrimeField, +{ + type LoadedEcPoint = EcPoint; + + fn ec_point_load_const(&self, value: &C) -> EcPoint { + let coordinates = value.coordinates().unwrap(); + let [x, y] = [coordinates.x(), coordinates.y()] + .map(|coordinate| U256::from_little_endian(coordinate.to_repr().as_ref())); + self.ec_point(Value::Constant((x, y))) + } + + fn ec_point_assert_eq(&self, _: &str, _: &EcPoint, _: &EcPoint) -> Result<(), Error> { + unimplemented!() + } + + fn multi_scalar_multiplication( + pairs: &[(&>::LoadedScalar, &EcPoint)], + ) -> EcPoint { + pairs + .iter() + .cloned() + .map(|(scalar, ec_point)| match scalar.value { + Value::Constant(constant) if U256::one() == constant => ec_point.clone(), + _ => ec_point.loader.ec_point_scalar_mul(ec_point, scalar), + }) + .reduce(|acc, ec_point| acc.loader.ec_point_add(&acc, &ec_point)) + .unwrap() + } +} + +impl> ScalarLoader for Rc { + type LoadedScalar = Scalar; + + fn load_const(&self, value: &F) -> Scalar { + self.scalar(Value::Constant(fe_to_u256(*value))) + } + + fn assert_eq(&self, _: &str, _: &Scalar, _: &Scalar) -> Result<(), Error> { + unimplemented!() + } + 
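+ /// Emits Yul that evaluates `constant + sum_i(coeff_i * value_i)` in the scalar field `f_q`, chaining `addmod`/`mulmod`, stores the result in a freshly allocated 32-byte memory slot, and returns it as a memory-backed `Scalar`; an empty `values` slice simply loads the constant.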
+ fn sum_with_coeff_and_const(&self, values: &[(F, &Scalar)], constant: F) -> Scalar { + if values.is_empty() { + return self.load_const(&constant); + } + + let push_addend = |(coeff, value): &(F, &Scalar)| { + assert_ne!(*coeff, F::zero()); + match (*coeff == F::one(), &value.value) { + (true, _) => self.push(value), + (false, Value::Constant(value)) => self.push(&self.scalar(Value::Constant( + fe_to_u256(*coeff * u256_to_fe::(*value)), + ))), + (false, _) => { + let value = self.push(value); + let coeff = self.push(&self.scalar(Value::Constant(fe_to_u256(*coeff)))); + format!("mulmod({value}, {coeff}, f_q)") + } + } + }; + + let mut values = values.iter(); + let initial_value = if constant == F::zero() { + push_addend(values.next().unwrap()) + } else { + self.push(&self.scalar(Value::Constant(fe_to_u256(constant)))) + }; + + let mut code = format!("let result := {initial_value}\n"); + for value in values { + let v = push_addend(value); + let addend = format!("result := addmod({v}, result, f_q)\n"); + code.push_str(addend.as_str()); + } + + let ptr = self.allocate(0x20); + code.push_str(format!("mstore({ptr}, result)").as_str()); + self.code.borrow_mut().runtime_append(format!( + "{{ + {code} + }}" + )); + + self.scalar(Value::Memory(ptr)) + } + + fn sum_products_with_coeff_and_const( + &self, + values: &[(F, &Scalar, &Scalar)], + constant: F, + ) -> Scalar { + if values.is_empty() { + return self.load_const(&constant); + } + + let push_addend = |(coeff, lhs, rhs): &(F, &Scalar, &Scalar)| { + assert_ne!(*coeff, F::zero()); + match (*coeff == F::one(), &lhs.value, &rhs.value) { + (_, Value::Constant(lhs), Value::Constant(rhs)) => { + self.push(&self.scalar(Value::Constant(fe_to_u256( + *coeff * u256_to_fe::(*lhs) * u256_to_fe::(*rhs), + )))) + } + (_, value @ Value::Memory(_), Value::Constant(constant)) + | (_, Value::Constant(constant), value @ Value::Memory(_)) => { + let v1 = self.push(&self.scalar(value.clone())); + let v2 = self.push(&self.scalar(Value::Constant(fe_to_u256( + *coeff * u256_to_fe::(*constant), + )))); + format!("mulmod({v1}, {v2}, f_q)") + } + (true, _, _) => { + let rhs = self.push(rhs); + let lhs = self.push(lhs); + format!("mulmod({rhs}, {lhs}, f_q)") + } + (false, _, _) => { + let rhs = self.push(rhs); + let lhs = self.push(lhs); + let value = self.push(&self.scalar(Value::Constant(fe_to_u256(*coeff)))); + format!("mulmod({rhs}, mulmod({lhs}, {value}, f_q), f_q)") + } + } + }; + + let mut values = values.iter(); + let initial_value = if constant == F::zero() { + push_addend(values.next().unwrap()) + } else { + self.push(&self.scalar(Value::Constant(fe_to_u256(constant)))) + }; + + let mut code = format!("let result := {initial_value}\n"); + for value in values { + let v = push_addend(value); + let addend = format!("result := addmod({v}, result, f_q)\n"); + code.push_str(addend.as_str()); + } + + let ptr = self.allocate(0x20); + code.push_str(format!("mstore({ptr}, result)").as_str()); + self.code.borrow_mut().runtime_append(format!( + "{{ + {code} + }}" + )); + + self.scalar(Value::Memory(ptr)) + } + + // batch_invert algorithm + // n := values.len() - 1 + // input : values[0], ..., values[n] + // output : values[0]^{-1}, ..., values[n]^{-1} + // 1. products[i] <- values[0] * ... * values[i], i = 1, ..., n + // 2. inv <- (products[n])^{-1} + // 3. v_n <- values[n] + // 4. values[n] <- products[n - 1] * inv (values[n]^{-1}) + // 5. 
inv <- v_n * inv + fn batch_invert<'a>(values: impl IntoIterator) { + let values = values.into_iter().collect_vec(); + let loader = &values.first().unwrap().loader; + let products = iter::once(values[0].clone()) + .chain( + iter::repeat_with(|| loader.allocate(0x20)) + .map(|ptr| loader.scalar(Value::Memory(ptr))) + .take(values.len() - 1), + ) + .collect_vec(); + + let initial_value = loader.push(products.first().unwrap()); + let mut code = format!("let prod := {initial_value}\n"); + for (_, (value, product)) in values.iter().zip(products.iter()).skip(1).enumerate() { + let v = loader.push(value); + let ptr = product.ptr(); + code.push_str( + format!( + " + prod := mulmod({v}, prod, f_q) + mstore({ptr:#x}, prod) + " + ) + .as_str(), + ); + } + loader.code.borrow_mut().runtime_append(format!( + "{{ + {code} + }}" + )); + + let inv = loader.push(&loader.invert(products.last().unwrap())); + + let mut code = format!( + " + let inv := {inv} + let v + " + ); + for (value, product) in values.iter().rev().zip( + products + .iter() + .rev() + .skip(1) + .map(Some) + .chain(iter::once(None)), + ) { + if let Some(product) = product { + let val_ptr = value.ptr(); + let prod_ptr = product.ptr(); + let v = loader.push(value); + code.push_str( + format!( + " + v := {v} + mstore({val_ptr}, mulmod(mload({prod_ptr:#x}), inv, f_q)) + inv := mulmod(v, inv, f_q) + " + ) + .as_str(), + ); + } else { + let ptr = value.ptr(); + code.push_str(format!("mstore({ptr:#x}, inv)\n").as_str()); + } + } + loader.code.borrow_mut().runtime_append(format!( + "{{ + {code} + }}" + )); + } +} + +impl Loader for Rc +where + C: CurveAffine, + C::Scalar: PrimeField, +{ + #[cfg(test)] + fn start_cost_metering(&self, identifier: &str) { + self.start_gas_metering(identifier) + } + + #[cfg(test)] + fn end_cost_metering(&self) { + self.end_gas_metering() + } +} diff --git a/snark-verifier/src/loader/evm/test.rs b/snark-verifier/src/loader/evm/test.rs new file mode 100644 index 00000000..e6f3703e --- /dev/null +++ b/snark-verifier/src/loader/evm/test.rs @@ -0,0 +1,49 @@ +use crate::{ + loader::evm::{test::tui::Tui, util::ExecutorBuilder}, + util::Itertools, +}; +use ethereum_types::{Address, U256}; +use std::env::var_os; + +mod tui; + +fn debug() -> bool { + matches!( + var_os("DEBUG"), + Some(value) if value.to_str() == Some("1") + ) +} + +pub fn execute(deployment_code: Vec, calldata: Vec) -> (bool, u64, Vec) { + assert!( + deployment_code.len() <= 0x6000, + "Contract size {} exceeds the limit 24576", + deployment_code.len() + ); + + let debug = debug(); + let caller = Address::from_low_u64_be(0xfe); + + let mut evm = ExecutorBuilder::default() + .with_gas_limit(u64::MAX.into()) + .set_debugger(debug) + .build(); + + let contract = evm + .deploy(caller, deployment_code.into(), 0.into()) + .address + .unwrap(); + let result = evm.call_raw(caller, contract, calldata.into(), 0.into()); + + let costs = result + .logs + .into_iter() + .map(|log| U256::from_big_endian(log.topics[0].as_bytes()).as_u64()) + .collect_vec(); + + if debug { + Tui::new(result.debug.unwrap().flatten(0), 0).start(); + } + + (!result.reverted, result.gas_used, costs) +} diff --git a/src/loader/evm/test/tui.rs b/snark-verifier/src/loader/evm/test/tui.rs similarity index 95% rename from src/loader/evm/test/tui.rs rename to snark-verifier/src/loader/evm/test/tui.rs index fcaef36c..c0c4d7f8 100644 --- a/src/loader/evm/test/tui.rs +++ b/snark-verifier/src/loader/evm/test/tui.rs @@ -1,5 +1,6 @@ //! 
Copied and modified from https://github.com/foundry-rs/foundry/blob/master/ui/src/lib.rs +use crate::loader::evm::util::executor::{CallKind, DebugStep}; use crossterm::{ event::{ self, DisableMouseCapture, EnableMouseCapture, Event, KeyCode, KeyEvent, KeyModifiers, @@ -8,11 +9,8 @@ use crossterm::{ execute, terminal::{disable_raw_mode, enable_raw_mode, EnterAlternateScreen, LeaveAlternateScreen}, }; -use foundry_evm::{ - debug::{DebugStep, Instruction}, - revm::opcode, - Address, CallKind, -}; +use ethereum_types::Address; +use revm::opcode; use std::{ cmp::{max, min}, io, @@ -90,7 +88,7 @@ impl Tui { self.terminal.clear().unwrap(); let mut draw_memory: DrawMemory = DrawMemory::default(); - let debug_call: Vec<(Address, Vec, CallKind)> = self.debug_arena.clone(); + let debug_call = &self.debug_arena; let mut opcode_list: Vec = debug_call[0] .1 .iter() @@ -207,7 +205,7 @@ impl Tui { } KeyCode::Char('s') => { for _ in 0..Tui::buffer_as_number(&self.key_buffer, 1) { - let remaining_ops = opcode_list[self.current_step..].to_vec().clone(); + let remaining_ops = &opcode_list[self.current_step..]; self.current_step += remaining_ops .iter() .enumerate() @@ -233,7 +231,7 @@ impl Tui { } KeyCode::Char('a') => { for _ in 0..Tui::buffer_as_number(&self.key_buffer, 1) { - let prev_ops = opcode_list[..self.current_step].to_vec().clone(); + let prev_ops = &opcode_list[..self.current_step]; self.current_step = prev_ops .iter() .enumerate() @@ -618,12 +616,7 @@ impl Tui { .borders(Borders::ALL); let min_len = usize::max(format!("{}", stack.len()).len(), 2); - let indices_affected = - if let Instruction::OpCode(op) = debug_steps[current_step].instruction { - stack_indices_affected(op) - } else { - vec![] - }; + let indices_affected = stack_indices_affected(debug_steps[current_step].instruction.0); let text: Vec = stack .iter() @@ -699,33 +692,29 @@ impl Tui { let mut word = None; let mut color = None; - if let Instruction::OpCode(op) = debug_steps[current_step].instruction { - let stack_len = debug_steps[current_step].stack.len(); - if stack_len > 0 { - let w = debug_steps[current_step].stack[stack_len - 1]; - match op { - opcode::MLOAD => { - word = Some(w.as_usize() / 32); - color = Some(Color::Cyan); - } - opcode::MSTORE => { - word = Some(w.as_usize() / 32); - color = Some(Color::Red); - } - _ => {} + let stack_len = debug_steps[current_step].stack.len(); + if stack_len > 0 { + let w = debug_steps[current_step].stack[stack_len - 1]; + match debug_steps[current_step].instruction.0 { + opcode::MLOAD => { + word = Some(w.as_usize() / 32); + color = Some(Color::Cyan); } + opcode::MSTORE => { + word = Some(w.as_usize() / 32); + color = Some(Color::Red); + } + _ => {} } } if current_step > 0 { let prev_step = current_step - 1; let stack_len = debug_steps[prev_step].stack.len(); - if let Instruction::OpCode(op) = debug_steps[prev_step].instruction { - if op == opcode::MSTORE { - let prev_top = debug_steps[prev_step].stack[stack_len - 1]; - word = Some(prev_top.as_usize() / 32); - color = Some(Color::Green); - } + if debug_steps[prev_step].instruction.0 == opcode::MSTORE { + let prev_top = debug_steps[prev_step].stack[stack_len - 1]; + word = Some(prev_top.as_usize() / 32); + color = Some(Color::Green); } } diff --git a/src/loader/evm/util.rs b/snark-verifier/src/loader/evm/util.rs similarity index 65% rename from src/loader/evm/util.rs rename to snark-verifier/src/loader/evm/util.rs index 0d9698bd..a7df5209 100644 --- a/src/loader/evm/util.rs +++ b/snark-verifier/src/loader/evm/util.rs @@ -3,7 +3,15 @@ use 
crate::{ util::{arithmetic::PrimeField, Itertools}, }; use ethereum_types::U256; -use std::iter; +use std::{ + io::Write, + iter, + process::{Command, Stdio}, +}; + +pub(crate) mod executor; + +pub use executor::ExecutorBuilder; pub struct MemoryChunk { ptr: usize, @@ -90,3 +98,37 @@ pub fn estimate_gas(cost: Cost) -> usize { intrinsic_cost + calldata_cost + ec_operation_cost } + +pub fn compile_yul(code: &str) -> Vec { + let mut cmd = Command::new("solc") + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .arg("--bin") + .arg("--yul") + .arg("-") + .spawn() + .unwrap(); + cmd.stdin + .take() + .unwrap() + .write_all(code.as_bytes()) + .unwrap(); + let output = cmd.wait_with_output().unwrap().stdout; + let binary = *split_by_ascii_whitespace(&output).last().unwrap(); + hex::decode(binary).unwrap() +} + +fn split_by_ascii_whitespace(bytes: &[u8]) -> Vec<&[u8]> { + let mut split = Vec::new(); + let mut start = None; + for (idx, byte) in bytes.iter().enumerate() { + if byte.is_ascii_whitespace() { + if let Some(start) = start.take() { + split.push(&bytes[start..idx]); + } + } else if start.is_none() { + start = Some(idx); + } + } + split +} diff --git a/snark-verifier/src/loader/evm/util/executor.rs b/snark-verifier/src/loader/evm/util/executor.rs new file mode 100644 index 00000000..ec9695e0 --- /dev/null +++ b/snark-verifier/src/loader/evm/util/executor.rs @@ -0,0 +1,868 @@ +//! Copied and modified from https://github.com/foundry-rs/foundry/blob/master/evm/src/executor/mod.rs + +use bytes::Bytes; +use ethereum_types::{Address, H256, U256, U64}; +use revm::{ + evm_inner, opcode, spec_opcode_gas, Account, BlockEnv, CallInputs, CallScheme, CreateInputs, + CreateScheme, Database, DatabaseCommit, EVMData, Env, ExecutionResult, Gas, GasInspector, + InMemoryDB, Inspector, Interpreter, Memory, OpCode, Return, TransactOut, TransactTo, TxEnv, +}; +use sha3::{Digest, Keccak256}; +use std::{cell::RefCell, collections::HashMap, fmt::Display, rc::Rc}; + +macro_rules! return_ok { + () => { + Return::Continue | Return::Stop | Return::Return | Return::SelfDestruct + }; +} + +fn keccak256(data: impl AsRef<[u8]>) -> [u8; 32] { + Keccak256::digest(data.as_ref()).into() +} + +fn get_contract_address(sender: impl Into
<Address>, nonce: impl Into<U256>) -> Address { + let mut stream = rlp::RlpStream::new(); + stream.begin_list(2); + stream.append(&sender.into()); + stream.append(&nonce.into()); + + let hash = keccak256(&stream.out()); + + let mut bytes = [0u8; 20]; + bytes.copy_from_slice(&hash[12..]); + Address::from(bytes) +} + +fn get_create2_address( + from: impl Into<Address>
, + salt: [u8; 32], + init_code: impl Into<Bytes>, +) -> Address { + get_create2_address_from_hash(from, salt, keccak256(init_code.into().as_ref()).to_vec()) +} + +fn get_create2_address_from_hash( + from: impl Into<Address>
, + salt: [u8; 32], + init_code_hash: impl Into<Bytes>, +) -> Address { + let bytes = [ + &[0xff], + from.into().as_bytes(), + salt.as_slice(), + init_code_hash.into().as_ref(), + ] + .concat(); + + let hash = keccak256(&bytes); + + let mut bytes = [0u8; 20]; + bytes.copy_from_slice(&hash[12..]); + Address::from(bytes) +} + +fn get_create_address(call: &CreateInputs, nonce: u64) -> Address { + match call.scheme { + CreateScheme::Create => get_contract_address(call.caller, nonce), + CreateScheme::Create2 { salt } => { + let mut buffer: [u8; 4 * 8] = [0; 4 * 8]; + salt.to_big_endian(&mut buffer); + get_create2_address(call.caller, buffer, call.init_code.clone()) + } + } +} + +#[derive(Clone, Debug, Default)] +pub struct Log { + pub address: Address, + pub topics: Vec<H256>, + pub data: Bytes, + pub block_hash: Option<H256>, + pub block_number: Option<U64>, + pub transaction_hash: Option<H256>, + pub transaction_index: Option<U64>, + pub log_index: Option<U256>, + pub transaction_log_index: Option<U256>, + pub log_type: Option<String>, + pub removed: Option<bool>, +} + +#[derive(Clone, Debug, Default)] +struct LogCollector { + logs: Vec<Log>, +} + +impl<DB: Database> Inspector<DB> for LogCollector { + fn log(&mut self, _: &mut EVMData<'_, DB>, address: &Address, topics: &[H256], data: &Bytes) { + self.logs.push(Log { + address: *address, + topics: topics.to_vec(), + data: data.clone(), + ..Default::default() + }); + } + + fn call( + &mut self, + _: &mut EVMData<'_, DB>, + call: &mut CallInputs, + _: bool, + ) -> (Return, Gas, Bytes) { + (Return::Continue, Gas::new(call.gas_limit), Bytes::new()) + } +} + +#[derive(Clone, Debug, Copy)] +pub enum CallKind { + Call, + StaticCall, + CallCode, + DelegateCall, + Create, + Create2, +} + +impl Default for CallKind { + fn default() -> Self { + CallKind::Call + } +} + +impl From<CallScheme> for CallKind { + fn from(scheme: CallScheme) -> Self { + match scheme { + CallScheme::Call => CallKind::Call, + CallScheme::StaticCall => CallKind::StaticCall, + CallScheme::CallCode => CallKind::CallCode, + CallScheme::DelegateCall => CallKind::DelegateCall, + } + } +} + +impl From<CreateScheme> for CallKind { + fn from(create: CreateScheme) -> Self { + match create { + CreateScheme::Create => CallKind::Create, + CreateScheme::Create2 { .. 
} => CallKind::Create2, + } + } +} + +#[derive(Clone, Debug, Default)] +pub struct DebugArena { + pub arena: Vec, +} + +impl DebugArena { + fn push_node(&mut self, mut new_node: DebugNode) -> usize { + fn recursively_push( + arena: &mut Vec, + entry: usize, + mut new_node: DebugNode, + ) -> usize { + match new_node.depth { + _ if arena[entry].depth == new_node.depth - 1 => { + let id = arena.len(); + new_node.location = arena[entry].children.len(); + new_node.parent = Some(entry); + arena[entry].children.push(id); + arena.push(new_node); + id + } + _ => { + let child = *arena[entry].children.last().unwrap(); + recursively_push(arena, child, new_node) + } + } + } + + if self.arena.is_empty() { + self.arena.push(new_node); + 0 + } else if new_node.depth == 0 { + let id = self.arena.len(); + new_node.location = self.arena[0].children.len(); + new_node.parent = Some(0); + self.arena[0].children.push(id); + self.arena.push(new_node); + id + } else { + recursively_push(&mut self.arena, 0, new_node) + } + } + + #[cfg(test)] + pub fn flatten(&self, entry: usize) -> Vec<(Address, Vec, CallKind)> { + let node = &self.arena[entry]; + + let mut flattened = vec![]; + if !node.steps.is_empty() { + flattened.push((node.address, node.steps.clone(), node.kind)); + } + flattened.extend(node.children.iter().flat_map(|child| self.flatten(*child))); + + flattened + } +} + +#[derive(Clone, Debug, Default)] +pub struct DebugNode { + pub parent: Option, + pub children: Vec, + pub location: usize, + pub address: Address, + pub kind: CallKind, + pub depth: usize, + pub steps: Vec, +} + +#[derive(Clone, Debug)] +pub struct DebugStep { + pub stack: Vec, + pub memory: Memory, + pub instruction: Instruction, + pub push_bytes: Option>, + pub pc: usize, + pub total_gas_used: u64, +} + +impl Default for DebugStep { + fn default() -> Self { + Self { + stack: vec![], + memory: Memory::new(), + instruction: Instruction(revm::opcode::INVALID), + push_bytes: None, + pc: 0, + total_gas_used: 0, + } + } +} + +impl DebugStep { + #[cfg(test)] + pub fn pretty_opcode(&self) -> String { + if let Some(push_bytes) = &self.push_bytes { + format!("{}(0x{})", self.instruction, hex::encode(push_bytes)) + } else { + self.instruction.to_string() + } + } +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub struct Instruction(pub u8); + +impl From for Instruction { + fn from(op: u8) -> Instruction { + Instruction(op) + } +} + +impl Display for Instruction { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + OpCode::try_from_u8(self.0).map_or_else( + || format!("UNDEFINED(0x{:02x})", self.0), + |opcode| opcode.as_str().to_string(), + ) + ) + } +} + +#[derive(Clone, Debug)] +struct Debugger { + arena: DebugArena, + head: usize, + context: Address, + gas_inspector: Rc>, +} + +impl Debugger { + fn new(gas_inspector: Rc>) -> Self { + Self { + arena: Default::default(), + head: Default::default(), + context: Default::default(), + gas_inspector, + } + } + + fn enter(&mut self, depth: usize, address: Address, kind: CallKind) { + self.context = address; + self.head = self.arena.push_node(DebugNode { + depth, + address, + kind, + ..Default::default() + }); + } + + fn exit(&mut self) { + if let Some(parent_id) = self.arena.arena[self.head].parent { + let DebugNode { + depth, + address, + kind, + .. 
+ } = self.arena.arena[parent_id]; + self.context = address; + self.head = self.arena.push_node(DebugNode { + depth, + address, + kind, + ..Default::default() + }); + } + } +} + +impl Inspector for Debugger { + fn step( + &mut self, + interpreter: &mut Interpreter, + data: &mut EVMData<'_, DB>, + _is_static: bool, + ) -> Return { + let pc = interpreter.program_counter(); + let op = interpreter.contract.bytecode.bytecode()[pc]; + + let opcode_infos = spec_opcode_gas(data.env.cfg.spec_id); + let opcode_info = &opcode_infos[op as usize]; + + let push_size = if opcode_info.is_push() { + (op - opcode::PUSH1 + 1) as usize + } else { + 0 + }; + let push_bytes = match push_size { + 0 => None, + n => { + let start = pc + 1; + let end = start + n; + Some(interpreter.contract.bytecode.bytecode()[start..end].to_vec()) + } + }; + + let spent = interpreter.gas.limit() - self.gas_inspector.borrow().gas_remaining(); + let total_gas_used = spent - (interpreter.gas.refunded() as u64).min(spent / 5); + + self.arena.arena[self.head].steps.push(DebugStep { + pc, + stack: interpreter.stack().data().clone(), + memory: interpreter.memory.clone(), + instruction: Instruction(op), + push_bytes, + total_gas_used, + }); + + Return::Continue + } + + fn call( + &mut self, + data: &mut EVMData<'_, DB>, + call: &mut CallInputs, + _: bool, + ) -> (Return, Gas, Bytes) { + self.enter( + data.journaled_state.depth() as usize, + call.context.code_address, + call.context.scheme.into(), + ); + + (Return::Continue, Gas::new(call.gas_limit), Bytes::new()) + } + + fn call_end( + &mut self, + _: &mut EVMData<'_, DB>, + _: &CallInputs, + gas: Gas, + status: Return, + retdata: Bytes, + _: bool, + ) -> (Return, Gas, Bytes) { + self.exit(); + + (status, gas, retdata) + } + + fn create( + &mut self, + data: &mut EVMData<'_, DB>, + call: &mut CreateInputs, + ) -> (Return, Option
, Gas, Bytes) { + let nonce = data.journaled_state.account(call.caller).info.nonce; + self.enter( + data.journaled_state.depth() as usize, + get_create_address(call, nonce), + CallKind::Create, + ); + + ( + Return::Continue, + None, + Gas::new(call.gas_limit), + Bytes::new(), + ) + } + + fn create_end( + &mut self, + _: &mut EVMData<'_, DB>, + _: &CreateInputs, + status: Return, + address: Option
<Address>, + gas: Gas, + retdata: Bytes, + ) -> (Return, Option<Address>
, Gas, Bytes) { + self.exit(); + + (status, address, gas, retdata) + } +} + +#[macro_export] +macro_rules! call_inspectors { + ($id:ident, [ $($inspector:expr),+ ], $call:block) => { + $({ + if let Some($id) = $inspector { + $call; + } + })+ + } +} + +struct InspectorData { + logs: Vec, + debug: Option, +} + +#[derive(Default)] +struct InspectorStack { + gas: Option>>, + logs: Option, + debugger: Option, +} + +impl InspectorStack { + fn collect_inspector_states(self) -> InspectorData { + InspectorData { + logs: self.logs.map(|logs| logs.logs).unwrap_or_default(), + debug: self.debugger.map(|debugger| debugger.arena), + } + } +} + +impl Inspector for InspectorStack { + fn initialize_interp( + &mut self, + interpreter: &mut Interpreter, + data: &mut EVMData<'_, DB>, + is_static: bool, + ) -> Return { + call_inspectors!( + inspector, + [ + &mut self.gas.as_deref().map(|gas| gas.borrow_mut()), + &mut self.logs, + &mut self.debugger + ], + { + let status = inspector.initialize_interp(interpreter, data, is_static); + + if status != Return::Continue { + return status; + } + } + ); + + Return::Continue + } + + fn step( + &mut self, + interpreter: &mut Interpreter, + data: &mut EVMData<'_, DB>, + is_static: bool, + ) -> Return { + call_inspectors!( + inspector, + [ + &mut self.gas.as_deref().map(|gas| gas.borrow_mut()), + &mut self.logs, + &mut self.debugger + ], + { + let status = inspector.step(interpreter, data, is_static); + + if status != Return::Continue { + return status; + } + } + ); + + Return::Continue + } + + fn log( + &mut self, + evm_data: &mut EVMData<'_, DB>, + address: &Address, + topics: &[H256], + data: &Bytes, + ) { + call_inspectors!(inspector, [&mut self.logs], { + inspector.log(evm_data, address, topics, data); + }); + } + + fn step_end( + &mut self, + interpreter: &mut Interpreter, + data: &mut EVMData<'_, DB>, + is_static: bool, + status: Return, + ) -> Return { + call_inspectors!( + inspector, + [ + &mut self.gas.as_deref().map(|gas| gas.borrow_mut()), + &mut self.logs, + &mut self.debugger + ], + { + let status = inspector.step_end(interpreter, data, is_static, status); + + if status != Return::Continue { + return status; + } + } + ); + + Return::Continue + } + + fn call( + &mut self, + data: &mut EVMData<'_, DB>, + call: &mut CallInputs, + is_static: bool, + ) -> (Return, Gas, Bytes) { + call_inspectors!( + inspector, + [ + &mut self.gas.as_deref().map(|gas| gas.borrow_mut()), + &mut self.logs, + &mut self.debugger + ], + { + let (status, gas, retdata) = inspector.call(data, call, is_static); + + if status != Return::Continue { + return (status, gas, retdata); + } + } + ); + + (Return::Continue, Gas::new(call.gas_limit), Bytes::new()) + } + + fn call_end( + &mut self, + data: &mut EVMData<'_, DB>, + call: &CallInputs, + remaining_gas: Gas, + status: Return, + retdata: Bytes, + is_static: bool, + ) -> (Return, Gas, Bytes) { + call_inspectors!( + inspector, + [ + &mut self.gas.as_deref().map(|gas| gas.borrow_mut()), + &mut self.logs, + &mut self.debugger + ], + { + let (new_status, new_gas, new_retdata) = inspector.call_end( + data, + call, + remaining_gas, + status, + retdata.clone(), + is_static, + ); + + if new_status != status || (new_status == Return::Revert && new_retdata != retdata) + { + return (new_status, new_gas, new_retdata); + } + } + ); + + (status, remaining_gas, retdata) + } + + fn create( + &mut self, + data: &mut EVMData<'_, DB>, + call: &mut CreateInputs, + ) -> (Return, Option
, Gas, Bytes) { + call_inspectors!( + inspector, + [ + &mut self.gas.as_deref().map(|gas| gas.borrow_mut()), + &mut self.logs, + &mut self.debugger + ], + { + let (status, addr, gas, retdata) = inspector.create(data, call); + + if status != Return::Continue { + return (status, addr, gas, retdata); + } + } + ); + + ( + Return::Continue, + None, + Gas::new(call.gas_limit), + Bytes::new(), + ) + } + + fn create_end( + &mut self, + data: &mut EVMData<'_, DB>, + call: &CreateInputs, + status: Return, + address: Option
<Address>, + remaining_gas: Gas, + retdata: Bytes, + ) -> (Return, Option<Address>
, Gas, Bytes) { + call_inspectors!( + inspector, + [ + &mut self.gas.as_deref().map(|gas| gas.borrow_mut()), + &mut self.logs, + &mut self.debugger + ], + { + let (new_status, new_address, new_gas, new_retdata) = inspector.create_end( + data, + call, + status, + address, + remaining_gas, + retdata.clone(), + ); + + if new_status != status { + return (new_status, new_address, new_gas, new_retdata); + } + } + ); + + (status, address, remaining_gas, retdata) + } + + fn selfdestruct(&mut self) { + call_inspectors!(inspector, [&mut self.logs, &mut self.debugger], { + Inspector::::selfdestruct(inspector); + }); + } +} + +pub struct RawCallResult { + pub exit_reason: Return, + pub reverted: bool, + pub result: Bytes, + pub gas_used: u64, + pub gas_refunded: u64, + pub logs: Vec, + pub debug: Option, + pub state_changeset: Option>, + pub env: Env, + pub out: TransactOut, +} + +#[derive(Clone, Debug)] +pub struct DeployResult { + pub exit_reason: Return, + pub reverted: bool, + pub address: Option
, + pub gas_used: u64, + pub gas_refunded: u64, + pub logs: Vec, + pub debug: Option, + pub env: Env, +} + +#[derive(Debug, Default)] +pub struct ExecutorBuilder { + debugger: bool, + gas_limit: Option, +} + +impl ExecutorBuilder { + pub fn set_debugger(mut self, enable: bool) -> Self { + self.debugger = enable; + self + } + + pub fn with_gas_limit(mut self, gas_limit: U256) -> Self { + self.gas_limit = Some(gas_limit); + self + } + + pub fn build(self) -> Executor { + Executor::new(self.debugger, self.gas_limit.unwrap_or(U256::MAX)) + } +} + +#[derive(Clone, Debug)] +pub struct Executor { + db: InMemoryDB, + debugger: bool, + gas_limit: U256, +} + +impl Executor { + fn new(debugger: bool, gas_limit: U256) -> Self { + Executor { + db: InMemoryDB::default(), + debugger, + gas_limit, + } + } + + pub fn db_mut(&mut self) -> &mut InMemoryDB { + &mut self.db + } + + pub fn deploy(&mut self, from: Address, code: Bytes, value: U256) -> DeployResult { + let env = self.build_test_env(from, TransactTo::Create(CreateScheme::Create), code, value); + let result = self.call_raw_with_env(env); + self.commit(&result); + + let RawCallResult { + exit_reason, + out, + gas_used, + gas_refunded, + logs, + debug, + env, + .. + } = result; + + let address = match (exit_reason, out) { + (return_ok!(), TransactOut::Create(_, Some(address))) => Some(address), + _ => None, + }; + + DeployResult { + exit_reason, + reverted: !matches!(exit_reason, return_ok!()), + address, + gas_used, + gas_refunded, + logs, + debug, + env, + } + } + + pub fn call_raw( + &self, + from: Address, + to: Address, + calldata: Bytes, + value: U256, + ) -> RawCallResult { + let env = self.build_test_env(from, TransactTo::Call(to), calldata, value); + self.call_raw_with_env(env) + } + + fn call_raw_with_env(&self, mut env: Env) -> RawCallResult { + let mut inspector = self.inspector(); + let result = + evm_inner::<_, true>(&mut env, &mut self.db.clone(), &mut inspector).transact(); + let (exec_result, state_changeset) = result; + let ExecutionResult { + exit_reason, + gas_refunded, + gas_used, + out, + .. 
+ } = exec_result; + + let result = match out { + TransactOut::Call(ref data) => data.to_owned(), + _ => Bytes::default(), + }; + let InspectorData { logs, debug } = inspector.collect_inspector_states(); + + RawCallResult { + exit_reason, + reverted: !matches!(exit_reason, return_ok!()), + result, + gas_used, + gas_refunded, + logs: logs.to_vec(), + debug, + state_changeset: Some(state_changeset.into_iter().collect()), + env, + out, + } + } + + fn commit(&mut self, result: &RawCallResult) { + if let Some(state_changeset) = result.state_changeset.as_ref() { + self.db + .commit(state_changeset.clone().into_iter().collect()); + } + } + + fn inspector(&self) -> InspectorStack { + let mut stack = InspectorStack { + logs: Some(LogCollector::default()), + ..Default::default() + }; + if self.debugger { + let gas_inspector = Rc::new(RefCell::new(GasInspector::default())); + stack.gas = Some(gas_inspector.clone()); + stack.debugger = Some(Debugger::new(gas_inspector)); + } + stack + } + + fn build_test_env( + &self, + caller: Address, + transact_to: TransactTo, + data: Bytes, + value: U256, + ) -> Env { + Env { + block: BlockEnv { + gas_limit: self.gas_limit, + ..BlockEnv::default() + }, + tx: TxEnv { + caller, + transact_to, + data, + value, + gas_limit: self.gas_limit.as_u64(), + ..TxEnv::default() + }, + ..Env::default() + } + } +} diff --git a/snark-verifier/src/loader/halo2.rs b/snark-verifier/src/loader/halo2.rs new file mode 100644 index 00000000..0e84d506 --- /dev/null +++ b/snark-verifier/src/loader/halo2.rs @@ -0,0 +1,67 @@ +use crate::halo2_proofs::circuit; +use crate::{util::arithmetic::CurveAffine, Protocol}; +use std::rc::Rc; + +pub(crate) mod loader; +mod shim; + +#[cfg(test)] +pub(crate) mod test; + +pub use loader::{EcPoint, Halo2Loader, Scalar}; +pub use shim::{Context, EccInstructions, IntegerInstructions}; +pub use util::Valuetools; + +pub use halo2_ecc; + +mod util { + use crate::halo2_proofs::circuit::Value; + + pub trait Valuetools: Iterator> { + fn fold_zipped(self, init: B, mut f: F) -> Value + where + Self: Sized, + F: FnMut(B, V) -> B, + { + self.fold(Value::known(init), |acc, value| { + acc.zip(value).map(|(acc, value)| f(acc, value)) + }) + } + } + + impl>> Valuetools for I {} +} + +impl Protocol +where + C: CurveAffine, +{ + pub fn loaded_preprocessed_as_witness<'a, EccChip: EccInstructions<'a, C>>( + &self, + loader: &Rc>, + ) -> Protocol>> { + let preprocessed = self + .preprocessed + .iter() + .map(|preprocessed| loader.assign_ec_point(circuit::Value::known(*preprocessed))) + .collect(); + let transcript_initial_state = + self.transcript_initial_state.as_ref().map(|transcript_initial_state| { + loader.assign_scalar(circuit::Value::known(*transcript_initial_state)) + }); + Protocol { + domain: self.domain.clone(), + preprocessed, + num_instance: self.num_instance.clone(), + num_witness: self.num_witness.clone(), + num_challenge: self.num_challenge.clone(), + evaluations: self.evaluations.clone(), + queries: self.queries.clone(), + quotient: self.quotient.clone(), + transcript_initial_state, + instance_committing_key: self.instance_committing_key.clone(), + linearization: self.linearization, + accumulator_indices: self.accumulator_indices.clone(), + } + } +} diff --git a/snark-verifier/src/loader/halo2/loader.rs b/snark-verifier/src/loader/halo2/loader.rs new file mode 100644 index 00000000..845c6c92 --- /dev/null +++ b/snark-verifier/src/loader/halo2/loader.rs @@ -0,0 +1,678 @@ +use crate::halo2_proofs::circuit; +use crate::{ + loader::{ + 
halo2::shim::{EccInstructions, IntegerInstructions}, + EcPointLoader, LoadedEcPoint, LoadedScalar, Loader, ScalarLoader, + }, + util::{ + arithmetic::{CurveAffine, Field, FieldOps}, + Itertools, + }, +}; +use std::{ + cell::{Ref, RefCell, RefMut}, + fmt::{self, Debug}, + marker::PhantomData, + ops::{Add, AddAssign, Deref, Mul, MulAssign, Neg, Sub, SubAssign}, + rc::Rc, +}; + +#[derive(Debug)] +pub struct Halo2Loader<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> { + ecc_chip: RefCell, + ctx: RefCell, + num_scalar: RefCell, + num_ec_point: RefCell, + _marker: PhantomData, + #[cfg(test)] + row_meterings: RefCell>, +} + +impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Halo2Loader<'a, C, EccChip> { + pub fn new(ecc_chip: EccChip, ctx: EccChip::Context) -> Rc { + Rc::new(Self { + ecc_chip: RefCell::new(ecc_chip), + ctx: RefCell::new(ctx), + num_scalar: RefCell::default(), + num_ec_point: RefCell::default(), + #[cfg(test)] + row_meterings: RefCell::default(), + _marker: PhantomData, + }) + } + + pub fn into_ctx(self) -> EccChip::Context { + self.ctx.into_inner() + } + + pub fn ecc_chip(&self) -> Ref { + self.ecc_chip.borrow() + } + + pub fn scalar_chip(&self) -> Ref { + Ref::map(self.ecc_chip(), |ecc_chip| ecc_chip.scalar_chip()) + } + + pub fn ctx(&self) -> Ref { + self.ctx.borrow() + } + + pub fn ctx_mut(&self) -> RefMut<'_, EccChip::Context> { + self.ctx.borrow_mut() + } + + fn assign_const_scalar(self: &Rc, constant: C::Scalar) -> EccChip::AssignedScalar { + self.scalar_chip().assign_constant(&mut self.ctx_mut(), constant).unwrap() + } + + pub fn assign_scalar( + self: &Rc, + scalar: circuit::Value, + ) -> Scalar<'a, C, EccChip> { + let assigned = self.scalar_chip().assign_integer(&mut self.ctx_mut(), scalar).unwrap(); + self.scalar_from_assigned(assigned) + } + + pub fn scalar_from_assigned( + self: &Rc, + assigned: EccChip::AssignedScalar, + ) -> Scalar<'a, C, EccChip> { + self.scalar(Value::Assigned(assigned)) + } + + fn scalar( + self: &Rc, + value: Value, + ) -> Scalar<'a, C, EccChip> { + let index = *self.num_scalar.borrow(); + *self.num_scalar.borrow_mut() += 1; + Scalar { loader: self.clone(), index, value: value.into() } + } + + fn assign_const_ec_point(self: &Rc, constant: C) -> EccChip::AssignedEcPoint { + self.ecc_chip().assign_constant(&mut self.ctx_mut(), constant).unwrap() + } + + pub fn assign_ec_point( + self: &Rc, + ec_point: circuit::Value, + ) -> EcPoint<'a, C, EccChip> { + let assigned = self.ecc_chip().assign_point(&mut self.ctx_mut(), ec_point).unwrap(); + self.ec_point_from_assigned(assigned) + } + + pub fn ec_point_from_assigned( + self: &Rc, + assigned: EccChip::AssignedEcPoint, + ) -> EcPoint<'a, C, EccChip> { + self.ec_point(Value::Assigned(assigned)) + } + + fn ec_point( + self: &Rc, + value: Value, + ) -> EcPoint<'a, C, EccChip> { + let index = *self.num_ec_point.borrow(); + *self.num_ec_point.borrow_mut() += 1; + EcPoint { loader: self.clone(), index, value: value.into() } + } + + fn add( + self: &Rc, + lhs: &Scalar<'a, C, EccChip>, + rhs: &Scalar<'a, C, EccChip>, + ) -> Scalar<'a, C, EccChip> { + let output = match (lhs.value().deref(), rhs.value().deref()) { + (Value::Constant(lhs), Value::Constant(rhs)) => Value::Constant(*lhs + rhs), + (Value::Assigned(assigned), Value::Constant(constant)) + | (Value::Constant(constant), Value::Assigned(assigned)) => self + .scalar_chip() + .sum_with_coeff_and_const( + &mut self.ctx_mut(), + &[(C::Scalar::one(), assigned)], + *constant, + ) + .map(Value::Assigned) + .unwrap(), + (Value::Assigned(lhs), 
Value::Assigned(rhs)) => self + .scalar_chip() + .sum_with_coeff_and_const( + &mut self.ctx_mut(), + &[(C::Scalar::one(), lhs), (C::Scalar::one(), rhs)], + C::Scalar::zero(), + ) + .map(Value::Assigned) + .unwrap(), + }; + self.scalar(output) + } + + fn sub( + self: &Rc, + lhs: &Scalar<'a, C, EccChip>, + rhs: &Scalar<'a, C, EccChip>, + ) -> Scalar<'a, C, EccChip> { + let output = match (lhs.value().deref(), rhs.value().deref()) { + (Value::Constant(lhs), Value::Constant(rhs)) => Value::Constant(*lhs - rhs), + (Value::Constant(constant), Value::Assigned(assigned)) => self + .scalar_chip() + .sum_with_coeff_and_const( + &mut self.ctx_mut(), + &[(-C::Scalar::one(), assigned)], + *constant, + ) + .map(Value::Assigned) + .unwrap(), + (Value::Assigned(assigned), Value::Constant(constant)) => self + .scalar_chip() + .sum_with_coeff_and_const( + &mut self.ctx_mut(), + &[(C::Scalar::one(), assigned)], + -*constant, + ) + .map(Value::Assigned) + .unwrap(), + (Value::Assigned(lhs), Value::Assigned(rhs)) => { + IntegerInstructions::sub(self.scalar_chip().deref(), &mut self.ctx_mut(), lhs, rhs) + .map(Value::Assigned) + .unwrap() + } + }; + self.scalar(output) + } + + fn mul( + self: &Rc, + lhs: &Scalar<'a, C, EccChip>, + rhs: &Scalar<'a, C, EccChip>, + ) -> Scalar<'a, C, EccChip> { + let output = match (lhs.value().deref(), rhs.value().deref()) { + (Value::Constant(lhs), Value::Constant(rhs)) => Value::Constant(*lhs * rhs), + (Value::Assigned(assigned), Value::Constant(constant)) + | (Value::Constant(constant), Value::Assigned(assigned)) => self + .scalar_chip() + .sum_with_coeff_and_const( + &mut self.ctx_mut(), + &[(*constant, assigned)], + C::Scalar::zero(), + ) + .map(Value::Assigned) + .unwrap(), + (Value::Assigned(lhs), Value::Assigned(rhs)) => self + .scalar_chip() + .sum_products_with_coeff_and_const( + &mut self.ctx_mut(), + &[(C::Scalar::one(), lhs, rhs)], + C::Scalar::zero(), + ) + .map(Value::Assigned) + .unwrap(), + }; + self.scalar(output) + } + + fn neg(self: &Rc, scalar: &Scalar<'a, C, EccChip>) -> Scalar<'a, C, EccChip> { + let output = match scalar.value().deref() { + Value::Constant(constant) => Value::Constant(constant.neg()), + Value::Assigned(assigned) => { + IntegerInstructions::neg(self.scalar_chip().deref(), &mut self.ctx_mut(), assigned) + .map(Value::Assigned) + .unwrap() + } + }; + self.scalar(output) + } + + fn invert(self: &Rc, scalar: &Scalar<'a, C, EccChip>) -> Scalar<'a, C, EccChip> { + let output = match scalar.value().deref() { + Value::Constant(constant) => Value::Constant(Field::invert(constant).unwrap()), + Value::Assigned(assigned) => Value::Assigned( + IntegerInstructions::invert( + self.scalar_chip().deref(), + &mut self.ctx_mut(), + assigned, + ) + .unwrap(), + ), + }; + self.scalar(output) + } +} + +#[cfg(test)] +impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Halo2Loader<'a, C, EccChip> { + fn start_row_metering(self: &Rc, identifier: &str) { + use crate::loader::halo2::shim::Context; + + self.row_meterings.borrow_mut().push((identifier.to_string(), self.ctx().offset())) + } + + fn end_row_metering(self: &Rc) { + use crate::loader::halo2::shim::Context; + + let mut row_meterings = self.row_meterings.borrow_mut(); + let (_, row) = row_meterings.last_mut().unwrap(); + *row = self.ctx().offset() - *row; + } + + pub fn print_row_metering(self: &Rc) { + for (identifier, cost) in self.row_meterings.borrow().iter() { + println!("{}: {}", identifier, cost); + } + } +} + +#[derive(Clone, Debug)] +pub enum Value { + Constant(T), + Assigned(L), +} + +impl 
Value { + fn maybe_const(&self) -> Option + where + T: Copy, + { + match self { + Value::Constant(constant) => Some(*constant), + _ => None, + } + } + + fn assigned(&self) -> &L { + match self { + Value::Assigned(assigned) => assigned, + _ => unreachable!(), + } + } +} + +#[derive(Clone)] +pub struct Scalar<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> { + loader: Rc>, + index: usize, + value: RefCell>, +} + +impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Scalar<'a, C, EccChip> { + pub fn loader(&self) -> &Rc> { + &self.loader + } + + pub fn into_assigned(self) -> EccChip::AssignedScalar { + match self.value.into_inner() { + Value::Constant(constant) => self.loader.assign_const_scalar(constant), + Value::Assigned(assigned) => assigned, + } + } + + pub fn assigned(&self) -> Ref { + if let Some(constant) = self.maybe_const() { + *self.value.borrow_mut() = Value::Assigned(self.loader.assign_const_scalar(constant)) + } + Ref::map(self.value.borrow(), Value::assigned) + } + + fn value(&self) -> Ref> { + self.value.borrow() + } + + fn maybe_const(&self) -> Option { + self.value().deref().maybe_const() + } +} + +impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> PartialEq for Scalar<'a, C, EccChip> { + fn eq(&self, other: &Self) -> bool { + self.index == other.index + } +} + +impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> LoadedScalar + for Scalar<'a, C, EccChip> +{ + type Loader = Rc>; + + fn loader(&self) -> &Self::Loader { + &self.loader + } +} + +impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Debug for Scalar<'a, C, EccChip> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Scalar").field("value", &self.value).finish() + } +} + +impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> FieldOps for Scalar<'a, C, EccChip> { + fn invert(&self) -> Option { + Some(self.loader.invert(self)) + } +} + +impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Add for Scalar<'a, C, EccChip> { + type Output = Self; + + fn add(self, rhs: Self) -> Self::Output { + Halo2Loader::add(&self.loader, &self, &rhs) + } +} + +impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Sub for Scalar<'a, C, EccChip> { + type Output = Self; + + fn sub(self, rhs: Self) -> Self::Output { + Halo2Loader::sub(&self.loader, &self, &rhs) + } +} + +impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Mul for Scalar<'a, C, EccChip> { + type Output = Self; + + fn mul(self, rhs: Self) -> Self::Output { + Halo2Loader::mul(&self.loader, &self, &rhs) + } +} + +impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Neg for Scalar<'a, C, EccChip> { + type Output = Self; + + fn neg(self) -> Self::Output { + Halo2Loader::neg(&self.loader, &self) + } +} + +impl<'a, 'b, C: CurveAffine, EccChip: EccInstructions<'a, C>> Add<&'b Self> + for Scalar<'a, C, EccChip> +{ + type Output = Self; + + fn add(self, rhs: &'b Self) -> Self::Output { + Halo2Loader::add(&self.loader, &self, rhs) + } +} + +impl<'a, 'b, C: CurveAffine, EccChip: EccInstructions<'a, C>> Sub<&'b Self> + for Scalar<'a, C, EccChip> +{ + type Output = Self; + + fn sub(self, rhs: &'b Self) -> Self::Output { + Halo2Loader::sub(&self.loader, &self, rhs) + } +} + +impl<'a, 'b, C: CurveAffine, EccChip: EccInstructions<'a, C>> Mul<&'b Self> + for Scalar<'a, C, EccChip> +{ + type Output = Self; + + fn mul(self, rhs: &'b Self) -> Self::Output { + Halo2Loader::mul(&self.loader, &self, rhs) + } +} + +impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> AddAssign for Scalar<'a, C, EccChip> { + fn 
add_assign(&mut self, rhs: Self) { + *self = Halo2Loader::add(&self.loader, self, &rhs) + } +} + +impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> SubAssign for Scalar<'a, C, EccChip> { + fn sub_assign(&mut self, rhs: Self) { + *self = Halo2Loader::sub(&self.loader, self, &rhs) + } +} + +impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> MulAssign for Scalar<'a, C, EccChip> { + fn mul_assign(&mut self, rhs: Self) { + *self = Halo2Loader::mul(&self.loader, self, &rhs) + } +} + +impl<'a, 'b, C: CurveAffine, EccChip: EccInstructions<'a, C>> AddAssign<&'b Self> + for Scalar<'a, C, EccChip> +{ + fn add_assign(&mut self, rhs: &'b Self) { + *self = Halo2Loader::add(&self.loader, self, rhs) + } +} + +impl<'a, 'b, C: CurveAffine, EccChip: EccInstructions<'a, C>> SubAssign<&'b Self> + for Scalar<'a, C, EccChip> +{ + fn sub_assign(&mut self, rhs: &'b Self) { + *self = Halo2Loader::sub(&self.loader, self, rhs) + } +} + +impl<'a, 'b, C: CurveAffine, EccChip: EccInstructions<'a, C>> MulAssign<&'b Self> + for Scalar<'a, C, EccChip> +{ + fn mul_assign(&mut self, rhs: &'b Self) { + *self = Halo2Loader::mul(&self.loader, self, rhs) + } +} + +#[derive(Clone)] +pub struct EcPoint<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> { + loader: Rc>, + index: usize, + value: RefCell>, +} + +impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> EcPoint<'a, C, EccChip> { + pub fn into_assigned(self) -> EccChip::AssignedEcPoint { + match self.value.into_inner() { + Value::Constant(constant) => self.loader.assign_const_ec_point(constant), + Value::Assigned(assigned) => assigned, + } + } + + pub fn assigned(&self) -> Ref { + if let Some(constant) = self.maybe_const() { + *self.value.borrow_mut() = Value::Assigned(self.loader.assign_const_ec_point(constant)) + } + Ref::map(self.value.borrow(), Value::assigned) + } + + fn value(&self) -> Ref> { + self.value.borrow() + } + + fn maybe_const(&self) -> Option { + self.value().deref().maybe_const() + } +} + +impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> PartialEq for EcPoint<'a, C, EccChip> { + fn eq(&self, other: &Self) -> bool { + self.index == other.index + } +} + +impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> LoadedEcPoint + for EcPoint<'a, C, EccChip> +{ + type Loader = Rc>; + + fn loader(&self) -> &Self::Loader { + &self.loader + } +} + +impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Debug for EcPoint<'a, C, EccChip> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("EcPoint").field("index", &self.index).field("value", &self.value).finish() + } +} + +impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> ScalarLoader + for Rc> +{ + type LoadedScalar = Scalar<'a, C, EccChip>; + + fn load_const(&self, value: &C::Scalar) -> Scalar<'a, C, EccChip> { + self.scalar(Value::Constant(*value)) + } + + fn assert_eq( + &self, + annotation: &str, + lhs: &Scalar<'a, C, EccChip>, + rhs: &Scalar<'a, C, EccChip>, + ) -> Result<(), crate::Error> { + self.scalar_chip() + .assert_equal(&mut self.ctx_mut(), &lhs.assigned(), &rhs.assigned()) + .map_err(|_| crate::Error::AssertionFailure(annotation.to_string())) + } + + fn sum_with_coeff_and_const( + &self, + values: &[(C::Scalar, &Scalar<'a, C, EccChip>)], + constant: C::Scalar, + ) -> Scalar<'a, C, EccChip> { + let values = values.iter().map(|(coeff, value)| (*coeff, value.assigned())).collect_vec(); + self.scalar(Value::Assigned( + self.scalar_chip() + .sum_with_coeff_and_const(&mut self.ctx_mut(), &values, constant) + .unwrap(), + )) + } + + fn 
sum_products_with_coeff_and_const( + &self, + values: &[(C::Scalar, &Scalar<'a, C, EccChip>, &Scalar<'a, C, EccChip>)], + constant: C::Scalar, + ) -> Scalar<'a, C, EccChip> { + let values = values + .iter() + .map(|(coeff, lhs, rhs)| (*coeff, lhs.assigned(), rhs.assigned())) + .collect_vec(); + self.scalar(Value::Assigned( + self.scalar_chip() + .sum_products_with_coeff_and_const(&mut self.ctx_mut(), &values, constant) + .unwrap(), + )) + } +} + +impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> EcPointLoader + for Rc> +{ + type LoadedEcPoint = EcPoint<'a, C, EccChip>; + + fn ec_point_load_const(&self, ec_point: &C) -> EcPoint<'a, C, EccChip> { + self.ec_point(Value::Constant(*ec_point)) + } + + fn ec_point_assert_eq( + &self, + annotation: &str, + lhs: &EcPoint<'a, C, EccChip>, + rhs: &EcPoint<'a, C, EccChip>, + ) -> Result<(), crate::Error> { + if let (Value::Constant(lhs), Value::Constant(rhs)) = + (lhs.value().deref(), rhs.value().deref()) + { + assert_eq!(lhs, rhs); + Ok(()) + } else { + let lhs = lhs.assigned(); + let rhs = rhs.assigned(); + self.ecc_chip() + .assert_equal(&mut self.ctx_mut(), lhs.deref(), rhs.deref()) + .map_err(|_| crate::Error::AssertionFailure(annotation.to_string())) + } + } + + fn multi_scalar_multiplication( + pairs: &[(&>::LoadedScalar, &EcPoint<'a, C, EccChip>)], + ) -> EcPoint<'a, C, EccChip> { + let loader = &pairs[0].0.loader; + + let (constant, fixed_base, variable_base_non_scaled, variable_base_scaled) = + pairs.iter().cloned().fold( + (C::identity(), Vec::new(), Vec::new(), Vec::new()), + |( + mut constant, + mut fixed_base, + mut variable_base_non_scaled, + mut variable_base_scaled, + ), + (scalar, base)| { + match (scalar.value().deref(), base.value().deref()) { + (Value::Constant(scalar), Value::Constant(base)) => { + constant = (*base * scalar + constant).into() + } + (Value::Assigned(_), Value::Constant(base)) => { + fixed_base.push((scalar, *base)) + } + (Value::Constant(scalar), Value::Assigned(_)) + if scalar.eq(&C::Scalar::one()) => + { + variable_base_non_scaled.push(base); + } + _ => variable_base_scaled.push((scalar, base)), + }; + (constant, fixed_base, variable_base_non_scaled, variable_base_scaled) + }, + ); + + let fixed_base_msm = (!fixed_base.is_empty()) + .then(|| { + let fixed_base = fixed_base + .into_iter() + .map(|(scalar, base)| (scalar.assigned(), base)) + .collect_vec(); + loader + .ecc_chip + .borrow_mut() + .fixed_base_msm(&mut loader.ctx_mut(), &fixed_base) + .unwrap() + }) + .map(RefCell::new); + let variable_base_msm = (!variable_base_scaled.is_empty()) + .then(|| { + let variable_base_scaled = variable_base_scaled + .into_iter() + .map(|(scalar, base)| (scalar.assigned(), base.assigned())) + .collect_vec(); + loader + .ecc_chip + .borrow_mut() + .variable_base_msm(&mut loader.ctx_mut(), &variable_base_scaled) + .unwrap() + }) + .map(RefCell::new); + let output = loader + .ecc_chip() + .sum_with_const( + &mut loader.ctx_mut(), + &variable_base_non_scaled + .into_iter() + .map(EcPoint::assigned) + .chain(fixed_base_msm.as_ref().map(RefCell::borrow)) + .chain(variable_base_msm.as_ref().map(RefCell::borrow)) + .collect_vec(), + constant, + ) + .unwrap(); + + loader.ec_point_from_assigned(output) + } +} + +impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Loader + for Rc> +{ + #[cfg(test)] + fn start_cost_metering(&self, identifier: &str) { + self.start_row_metering(identifier) + } + + #[cfg(test)] + fn end_cost_metering(&self) { + self.end_row_metering() + } +} diff --git 
a/snark-verifier/src/loader/halo2/shim.rs b/snark-verifier/src/loader/halo2/shim.rs new file mode 100644 index 00000000..588e9482 --- /dev/null +++ b/snark-verifier/src/loader/halo2/shim.rs @@ -0,0 +1,703 @@ +use crate::halo2_proofs::{ + circuit::{Cell, Value}, + plonk::Error, +}; +use crate::util::arithmetic::{CurveAffine, FieldExt}; +use std::{fmt::Debug, ops::Deref}; + +pub trait Context: Debug { + fn constrain_equal(&mut self, lhs: Cell, rhs: Cell) -> Result<(), Error>; + + fn offset(&self) -> usize; +} + +pub trait IntegerInstructions<'a, F: FieldExt>: Clone + Debug { + type Context: Context; + type AssignedCell: Clone + Debug; + type AssignedInteger: Clone + Debug; + + fn assign_integer( + &self, + ctx: &mut Self::Context, + integer: Value, + ) -> Result; + + fn assign_constant( + &self, + ctx: &mut Self::Context, + integer: F, + ) -> Result; + + fn sum_with_coeff_and_const( + &self, + ctx: &mut Self::Context, + values: &[(F::Scalar, impl Deref)], + constant: F::Scalar, + ) -> Result; + + fn sum_products_with_coeff_and_const( + &self, + ctx: &mut Self::Context, + values: &[( + F::Scalar, + impl Deref, + impl Deref, + )], + constant: F::Scalar, + ) -> Result; + + fn sub( + &self, + ctx: &mut Self::Context, + lhs: &Self::AssignedInteger, + rhs: &Self::AssignedInteger, + ) -> Result; + + fn neg( + &self, + ctx: &mut Self::Context, + value: &Self::AssignedInteger, + ) -> Result; + + fn invert( + &self, + ctx: &mut Self::Context, + value: &Self::AssignedInteger, + ) -> Result; + + fn assert_equal( + &self, + ctx: &mut Self::Context, + lhs: &Self::AssignedInteger, + rhs: &Self::AssignedInteger, + ) -> Result<(), Error>; +} + +pub trait EccInstructions<'a, C: CurveAffine>: Clone + Debug { + type Context: Context; + type ScalarChip: IntegerInstructions< + 'a, + C::Scalar, + Context = Self::Context, + AssignedCell = Self::AssignedCell, + AssignedInteger = Self::AssignedScalar, + >; + type AssignedCell: Clone + Debug; + type AssignedScalar: Clone + Debug; + type AssignedEcPoint: Clone + Debug; + + fn scalar_chip(&self) -> &Self::ScalarChip; + + fn assign_constant( + &self, + ctx: &mut Self::Context, + ec_point: C, + ) -> Result; + + fn assign_point( + &self, + ctx: &mut Self::Context, + ec_point: Value, + ) -> Result; + + fn sum_with_const( + &self, + ctx: &mut Self::Context, + values: &[impl Deref], + constant: C, + ) -> Result; + + fn fixed_base_msm( + &mut self, + ctx: &mut Self::Context, + pairs: &[(impl Deref, C)], + ) -> Result; + + fn variable_base_msm( + &mut self, + ctx: &mut Self::Context, + pairs: &[( + impl Deref, + impl Deref, + )], + ) -> Result; + + fn assert_equal( + &self, + ctx: &mut Self::Context, + lhs: &Self::AssignedEcPoint, + rhs: &Self::AssignedEcPoint, + ) -> Result<(), Error>; +} + +mod halo2_lib { + use crate::halo2_proofs::{ + circuit::{Cell, Value}, + halo2curves::CurveAffineExt, + plonk::Error, + }; + use crate::{ + loader::halo2::{Context, EccInstructions, IntegerInstructions}, + util::arithmetic::{CurveAffine, Field}, + }; + use halo2_base::{ + self, + gates::{flex_gate::FlexGateConfig, GateInstructions, RangeInstructions}, + utils::PrimeField, + AssignedValue, + QuantumCell::{Constant, Existing, Witness}, + }; + use halo2_ecc::{ + bigint::CRTInteger, + ecc::{fixed_base::FixedEcPoint, BaseFieldEccChip, EcPoint}, + fields::FieldChip, + }; + use std::ops::Deref; + + type AssignedInteger<'v, C> = CRTInteger<'v, ::ScalarExt>; + type AssignedEcPoint<'v, C> = EcPoint<::ScalarExt, AssignedInteger<'v, C>>; + + impl<'a, F: PrimeField> Context for halo2_base::Context<'a, 
F> { + fn constrain_equal(&mut self, lhs: Cell, rhs: Cell) -> Result<(), Error> { + #[cfg(feature = "halo2-axiom")] + self.region.constrain_equal(&lhs, &rhs); + #[cfg(feature = "halo2-pse")] + self.region.constrain_equal(lhs, rhs); + Ok(()) + } + + fn offset(&self) -> usize { + unreachable!() + } + } + + impl<'a, F: PrimeField> IntegerInstructions<'a, F> for FlexGateConfig { + type Context = halo2_base::Context<'a, F>; + type AssignedCell = AssignedValue<'a, F>; + type AssignedInteger = AssignedValue<'a, F>; + + fn assign_integer( + &self, + ctx: &mut Self::Context, + integer: Value, + ) -> Result { + Ok(self.assign_region_last(ctx, vec![Witness(integer)], vec![])) + } + + fn assign_constant( + &self, + ctx: &mut Self::Context, + integer: F, + ) -> Result { + Ok(self.assign_region_last(ctx, vec![Constant(integer)], vec![])) + } + + fn sum_with_coeff_and_const( + &self, + ctx: &mut Self::Context, + values: &[(F::Scalar, impl Deref)], + constant: F, + ) -> Result { + let mut a = Vec::with_capacity(values.len() + 1); + let mut b = Vec::with_capacity(values.len() + 1); + if constant != F::zero() { + a.push(Constant(constant)); + b.push(Constant(F::one())); + } + a.extend(values.iter().map(|(_, a)| Existing(a))); + b.extend(values.iter().map(|(c, _)| Constant(*c))); + Ok(self.inner_product(ctx, a, b)) + } + + fn sum_products_with_coeff_and_const( + &self, + ctx: &mut Self::Context, + values: &[( + F::Scalar, + impl Deref, + impl Deref, + )], + constant: F, + ) -> Result { + match values.len() { + 0 => self.assign_constant(ctx, constant), + _ => Ok(self.sum_products_with_coeff_and_var( + ctx, + values.iter().map(|(c, a, b)| (*c, Existing(a), Existing(b))), + Constant(constant), + )), + } + } + + fn sub( + &self, + ctx: &mut Self::Context, + a: &Self::AssignedInteger, + b: &Self::AssignedInteger, + ) -> Result { + Ok(GateInstructions::sub(self, ctx, Existing(a), Existing(b))) + } + + fn neg( + &self, + ctx: &mut Self::Context, + a: &Self::AssignedInteger, + ) -> Result { + Ok(GateInstructions::neg(self, ctx, Existing(a))) + } + + fn invert( + &self, + ctx: &mut Self::Context, + a: &Self::AssignedInteger, + ) -> Result { + // make sure scalar != 0 + let is_zero = self.is_zero(ctx, a); + self.assert_is_const(ctx, &is_zero, F::zero()); + Ok(GateInstructions::div_unsafe(self, ctx, Constant(F::one()), Existing(a))) + } + + fn assert_equal( + &self, + ctx: &mut Self::Context, + a: &Self::AssignedInteger, + b: &Self::AssignedInteger, + ) -> Result<(), Error> { + ctx.region.constrain_equal(a.cell(), b.cell()); + Ok(()) + } + } + + impl<'a, C: CurveAffineExt> EccInstructions<'a, C> for BaseFieldEccChip + where + C::ScalarExt: PrimeField, + C::Base: PrimeField, + { + type Context = halo2_base::Context<'a, C::Scalar>; + type ScalarChip = FlexGateConfig; + type AssignedCell = AssignedValue<'a, C::Scalar>; + type AssignedScalar = AssignedValue<'a, C::Scalar>; + type AssignedEcPoint = AssignedEcPoint<'a, C>; + + fn scalar_chip(&self) -> &Self::ScalarChip { + self.field_chip.range().gate() + } + + fn assign_constant( + &self, + ctx: &mut Self::Context, + point: C, + ) -> Result { + let fixed = FixedEcPoint::::from_curve( + point, + self.field_chip.num_limbs, + self.field_chip.limb_bits, + ); + Ok(FixedEcPoint::assign( + fixed, + self.field_chip(), + ctx, + self.field_chip().native_modulus(), + )) + } + + fn assign_point( + &self, + ctx: &mut Self::Context, + point: Value, + ) -> Result { + let assigned = self.assign_point(ctx, point); + let is_valid = self.is_on_curve_or_infinity::(ctx, &assigned); + 
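+ // `is_on_curve_or_infinity` yields a 0/1 flag; the next line constrains it to equal 1,
+ // so only witness points on the curve (or the identity) can be assigned here.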
self.field_chip.range.gate.assert_is_const(ctx, &is_valid, C::Scalar::one()); + Ok(assigned) + } + + fn sum_with_const( + &self, + ctx: &mut Self::Context, + values: &[impl Deref], + constant: C, + ) -> Result { + let constant = if bool::from(constant.is_identity()) { + None + } else { + let constant = EccInstructions::::assign_constant(self, ctx, constant).unwrap(); + Some(constant) + }; + Ok(self.sum::(ctx, constant.iter().chain(values.iter().map(Deref::deref)))) + } + + fn variable_base_msm( + &mut self, + ctx: &mut Self::Context, + pairs: &[( + impl Deref, + impl Deref, + )], + ) -> Result { + let (scalars, points): (Vec<_>, Vec<_>) = pairs + .iter() + .map(|(scalar, point)| (vec![scalar.deref().clone()], point.deref().clone())) + .unzip(); + + Ok(BaseFieldEccChip::::variable_base_msm::( + self, + ctx, + &points, + &scalars, + C::Scalar::NUM_BITS as usize, + 4, // empirically clump factor of 4 seems to be best + )) + } + + fn fixed_base_msm( + &mut self, + ctx: &mut Self::Context, + pairs: &[(impl Deref, C)], + ) -> Result { + let (scalars, points): (Vec<_>, Vec<_>) = pairs + .iter() + .filter_map(|(scalar, point)| { + if point.is_identity().into() { + None + } else { + Some((vec![scalar.deref().clone()], *point)) + } + }) + .unzip(); + + Ok(BaseFieldEccChip::::fixed_base_msm::( + self, + ctx, + &points, + &scalars, + C::Scalar::NUM_BITS as usize, + 0, + 4, + )) + } + + fn assert_equal( + &self, + ctx: &mut Self::Context, + a: &Self::AssignedEcPoint, + b: &Self::AssignedEcPoint, + ) -> Result<(), Error> { + self.assert_equal(ctx, a, b); + Ok(()) + } + } +} + +/* +mod halo2_wrong { + use crate::{ + loader::halo2::{Context, EccInstructions, IntegerInstructions}, + util::{ + arithmetic::{CurveAffine, FieldExt, Group}, + Itertools, + }, + }; + use halo2_proofs::{ + circuit::{AssignedCell, Cell, Value}, + plonk::Error, + }; + use halo2_wrong_ecc::{ + integer::rns::Common, + maingate::{ + CombinationOption, CombinationOptionCommon, MainGate, MainGateInstructions, RegionCtx, + Term, + }, + AssignedPoint, BaseFieldEccChip, + }; + use rand::rngs::OsRng; + use std::{iter, ops::Deref}; + + impl<'a, F: FieldExt> Context for RegionCtx<'a, F> { + fn constrain_equal(&mut self, lhs: Cell, rhs: Cell) -> Result<(), Error> { + self.constrain_equal(lhs, rhs) + } + + fn offset(&self) -> usize { + self.offset() + } + } + + impl<'a, F: FieldExt> IntegerInstructions<'a, F> for MainGate { + type Context = RegionCtx<'a, F>; + type AssignedCell = AssignedCell; + type AssignedInteger = AssignedCell; + + fn assign_integer( + &self, + ctx: &mut Self::Context, + integer: Value, + ) -> Result { + self.assign_value(ctx, integer) + } + + fn assign_constant( + &self, + ctx: &mut Self::Context, + integer: F, + ) -> Result { + MainGateInstructions::assign_constant(self, ctx, integer) + } + + fn sum_with_coeff_and_const( + &self, + ctx: &mut Self::Context, + values: &[(F, impl Deref)], + constant: F, + ) -> Result { + self.compose( + ctx, + &values + .iter() + .map(|(coeff, assigned)| Term::Assigned(assigned, *coeff)) + .collect_vec(), + constant, + ) + } + + fn sum_products_with_coeff_and_const( + &self, + ctx: &mut Self::Context, + values: &[( + F, + impl Deref, + impl Deref, + )], + constant: F, + ) -> Result { + match values.len() { + 0 => MainGateInstructions::assign_constant(self, ctx, constant), + 1 => { + let (scalar, lhs, rhs) = &values[0]; + let output = lhs + .value() + .zip(rhs.value()) + .map(|(lhs, rhs)| *scalar * lhs * rhs + constant); + + Ok(self + .apply( + ctx, + [ + Term::Zero, + Term::Zero, + 
Term::assigned_to_mul(lhs), + Term::assigned_to_mul(rhs), + Term::unassigned_to_sub(output), + ], + constant, + CombinationOption::OneLinerDoubleMul(*scalar), + )? + .swap_remove(4)) + } + _ => { + let (scalar, lhs, rhs) = &values[0]; + self.apply( + ctx, + [Term::assigned_to_mul(lhs), Term::assigned_to_mul(rhs)], + constant, + CombinationOptionCommon::CombineToNextScaleMul(-F::one(), *scalar).into(), + )?; + let acc = + Value::known(*scalar) * lhs.value() * rhs.value() + Value::known(constant); + let output = values.iter().skip(1).fold( + Ok::<_, Error>(acc), + |acc, (scalar, lhs, rhs)| { + acc.and_then(|acc| { + self.apply( + ctx, + [ + Term::assigned_to_mul(lhs), + Term::assigned_to_mul(rhs), + Term::Zero, + Term::Zero, + Term::Unassigned(acc, F::one()), + ], + F::zero(), + CombinationOptionCommon::CombineToNextScaleMul( + -F::one(), + *scalar, + ) + .into(), + )?; + Ok(acc + Value::known(*scalar) * lhs.value() * rhs.value()) + }) + }, + )?; + self.apply( + ctx, + [ + Term::Zero, + Term::Zero, + Term::Zero, + Term::Zero, + Term::Unassigned(output, F::zero()), + ], + F::zero(), + CombinationOptionCommon::OneLinerAdd.into(), + ) + .map(|mut outputs| outputs.swap_remove(4)) + } + } + } + + fn sub( + &self, + ctx: &mut Self::Context, + lhs: &Self::AssignedInteger, + rhs: &Self::AssignedInteger, + ) -> Result { + MainGateInstructions::sub(self, ctx, lhs, rhs) + } + + fn neg( + &self, + ctx: &mut Self::Context, + value: &Self::AssignedInteger, + ) -> Result { + MainGateInstructions::neg_with_constant(self, ctx, value, F::zero()) + } + + fn invert( + &self, + ctx: &mut Self::Context, + value: &Self::AssignedInteger, + ) -> Result { + MainGateInstructions::invert_unsafe(self, ctx, value) + } + + fn assert_equal( + &self, + ctx: &mut Self::Context, + lhs: &Self::AssignedInteger, + rhs: &Self::AssignedInteger, + ) -> Result<(), Error> { + let mut eq = true; + lhs.value().zip(rhs.value()).map(|(lhs, rhs)| { + eq &= lhs == rhs; + }); + MainGateInstructions::assert_equal(self, ctx, lhs, rhs) + .and(eq.then_some(()).ok_or(Error::Synthesis)) + } + } + + impl<'a, C: CurveAffine, const LIMBS: usize, const BITS: usize> EccInstructions<'a, C> + for BaseFieldEccChip + { + type Context = RegionCtx<'a, C::Scalar>; + type ScalarChip = MainGate; + type AssignedCell = AssignedCell; + type AssignedScalar = AssignedCell; + type AssignedEcPoint = AssignedPoint; + + fn scalar_chip(&self) -> &Self::ScalarChip { + self.main_gate() + } + + fn assign_constant( + &self, + ctx: &mut Self::Context, + ec_point: C, + ) -> Result { + self.assign_constant(ctx, ec_point) + } + + fn assign_point( + &self, + ctx: &mut Self::Context, + ec_point: Value, + ) -> Result { + self.assign_point(ctx, ec_point) + } + + fn sum_with_const( + &self, + ctx: &mut Self::Context, + values: &[impl Deref], + constant: C, + ) -> Result { + if values.is_empty() { + return self.assign_constant(ctx, constant); + } + + let constant = (!bool::from(constant.is_identity())) + .then(|| self.assign_constant(ctx, constant)) + .transpose()?; + let output = iter::empty() + .chain(constant) + .chain(values.iter().map(|value| value.deref().clone())) + .map(Ok) + .reduce(|acc, ec_point| self.add(ctx, &acc?, &ec_point?)) + .unwrap()?; + self.normalize(ctx, &output) + } + + fn fixed_base_msm( + &mut self, + ctx: &mut Self::Context, + pairs: &[(impl Deref, C)], + ) -> Result { + assert!(!pairs.is_empty()); + + // FIXME: Implement fixed base MSM in halo2_wrong + let pairs = pairs + .iter() + .filter(|(_, base)| !bool::from(base.is_identity())) + .map(|(scalar, base)| 
{ + Ok::<_, Error>((scalar.deref().clone(), self.assign_constant(ctx, *base)?)) + }) + .collect::, _>>()?; + let pairs = pairs.iter().map(|(scalar, base)| (scalar, base)).collect_vec(); + self.variable_base_msm(ctx, &pairs) + } + + fn variable_base_msm( + &mut self, + ctx: &mut Self::Context, + pairs: &[( + impl Deref, + impl Deref, + )], + ) -> Result { + assert!(!pairs.is_empty()); + + const WINDOW_SIZE: usize = 3; + let pairs = pairs + .iter() + .map(|(scalar, base)| (base.deref().clone(), scalar.deref().clone())) + .collect_vec(); + let output = match self.mul_batch_1d_horizontal(ctx, pairs.clone(), WINDOW_SIZE) { + Err(_) => { + if self.assign_aux(ctx, WINDOW_SIZE, pairs.len()).is_err() { + let aux_generator = Value::known(C::Curve::random(OsRng).into()); + self.assign_aux_generator(ctx, aux_generator)?; + self.assign_aux(ctx, WINDOW_SIZE, pairs.len())?; + } + self.mul_batch_1d_horizontal(ctx, pairs, WINDOW_SIZE) + } + result => result, + }?; + self.normalize(ctx, &output) + } + + fn assert_equal( + &self, + ctx: &mut Self::Context, + lhs: &Self::AssignedEcPoint, + rhs: &Self::AssignedEcPoint, + ) -> Result<(), Error> { + let mut eq = true; + [(lhs.x(), rhs.x()), (lhs.y(), rhs.y())].map(|(lhs, rhs)| { + lhs.integer().zip(rhs.integer()).map(|(lhs, rhs)| { + eq &= lhs.value() == rhs.value(); + }); + }); + self.assert_equal(ctx, lhs, rhs).and(eq.then_some(()).ok_or(Error::Synthesis)) + } + } +} +*/ diff --git a/src/loader/halo2/test.rs b/snark-verifier/src/loader/halo2/test.rs similarity index 73% rename from src/loader/halo2/test.rs rename to snark-verifier/src/loader/halo2/test.rs index 09fc0102..96de6747 100644 --- a/src/loader/halo2/test.rs +++ b/snark-verifier/src/loader/halo2/test.rs @@ -1,12 +1,10 @@ +use crate::halo2_proofs::circuit::Value; use crate::{ util::{arithmetic::CurveAffine, Itertools}, Protocol, }; -use halo2_proofs::circuit::Value; -mod circuit; - -pub use circuit::standard::StandardPlonk; +#[derive(Clone, Debug)] pub struct Snark { pub protocol: Protocol, pub instances: Vec>, @@ -23,6 +21,7 @@ impl Snark { } } +#[derive(Clone, Debug)] pub struct SnarkWitness { pub protocol: Protocol, pub instances: Vec>>, @@ -44,16 +43,17 @@ impl From> for SnarkWitness { } impl SnarkWitness { + pub fn new_without_witness(protocol: Protocol) -> Self { + let instances = protocol + .num_instance + .iter() + .map(|num_instance| vec![Value::unknown(); *num_instance]) + .collect(); + SnarkWitness { protocol, instances, proof: Value::unknown() } + } + pub fn without_witnesses(&self) -> Self { - SnarkWitness { - protocol: self.protocol.clone(), - instances: self - .instances - .iter() - .map(|instances| vec![Value::unknown(); instances.len()]) - .collect(), - proof: Value::unknown(), - } + SnarkWitness::new_without_witness(self.protocol.clone()) } pub fn proof(&self) -> Value<&[u8]> { diff --git a/src/loader/native.rs b/snark-verifier/src/loader/native.rs similarity index 75% rename from src/loader/native.rs rename to snark-verifier/src/loader/native.rs index e1915663..6fce383a 100644 --- a/src/loader/native.rs +++ b/snark-verifier/src/loader/native.rs @@ -19,15 +19,6 @@ impl LoadedEcPoint for C { fn loader(&self) -> &NativeLoader { &LOADER } - - fn multi_scalar_multiplication(pairs: impl IntoIterator) -> Self { - pairs - .into_iter() - .map(|(scalar, base)| base * scalar) - .reduce(|acc, value| acc + value) - .unwrap() - .to_affine() - } } impl FieldOps for F { @@ -42,14 +33,6 @@ impl LoadedScalar for F { fn loader(&self) -> &NativeLoader { &LOADER } - - fn mul_add(a: &F, b: &F, c: &F) -> 
Self { - *a * *b + *c - } - - fn mul_add_constant(a: &F, b: &F, c: &F) -> Self { - *a * *b + *c - } } impl EcPointLoader for NativeLoader { @@ -65,7 +48,21 @@ impl EcPointLoader for NativeLoader { lhs: &Self::LoadedEcPoint, rhs: &Self::LoadedEcPoint, ) -> Result<(), Error> { - lhs.eq(rhs).then_some(()).ok_or_else(|| Error::AssertionFailure(annotation.to_string())) + lhs.eq(rhs) + .then_some(()) + .ok_or_else(|| Error::AssertionFailure(annotation.to_string())) + } + + fn multi_scalar_multiplication( + pairs: &[(&>::LoadedScalar, &C)], + ) -> C { + pairs + .iter() + .cloned() + .map(|(scalar, base)| *base * scalar) + .reduce(|acc, value| acc + value) + .unwrap() + .to_affine() } } @@ -82,7 +79,9 @@ impl ScalarLoader for NativeLoader { lhs: &Self::LoadedScalar, rhs: &Self::LoadedScalar, ) -> Result<(), Error> { - lhs.eq(rhs).then_some(()).ok_or_else(|| Error::AssertionFailure(annotation.to_string())) + lhs.eq(rhs) + .then_some(()) + .ok_or_else(|| Error::AssertionFailure(annotation.to_string())) } } diff --git a/src/pcs.rs b/snark-verifier/src/pcs.rs similarity index 91% rename from src/pcs.rs rename to snark-verifier/src/pcs.rs index 65804895..2726e4f3 100644 --- a/src/pcs.rs +++ b/snark-verifier/src/pcs.rs @@ -10,6 +10,7 @@ use crate::{ use rand::Rng; use std::fmt::Debug; +// pub mod ipa; pub mod kzg; pub trait PolynomialCommitmentScheme: Clone + Debug @@ -29,11 +30,7 @@ pub struct Query { impl Query { pub fn with_evaluation(self, eval: T) -> Query { - Query { - poly: self.poly, - shift: self.shift, - eval, - } + Query { poly: self.poly, shift: self.shift, eval } } } @@ -123,7 +120,7 @@ where L: Loader, PCS: PolynomialCommitmentScheme, { - fn from_repr(repr: Vec) -> Result; + fn from_repr(repr: &[&L::LoadedScalar]) -> Result; } impl AccumulatorEncoding for () @@ -132,7 +129,7 @@ where L: Loader, PCS: PolynomialCommitmentScheme, { - fn from_repr(_: Vec) -> Result { + fn from_repr(_: &[&L::LoadedScalar]) -> Result { unimplemented!() } } diff --git a/snark-verifier/src/pcs/ipa.rs b/snark-verifier/src/pcs/ipa.rs new file mode 100644 index 00000000..a2b34824 --- /dev/null +++ b/snark-verifier/src/pcs/ipa.rs @@ -0,0 +1,447 @@ +use crate::{ + loader::{native::NativeLoader, LoadedScalar, Loader, ScalarLoader}, + pcs::PolynomialCommitmentScheme, + util::{ + arithmetic::{ + inner_product, powers, Curve, CurveAffine, Domain, Field, Fraction, PrimeField, + }, + msm::{multi_scalar_multiplication, Msm}, + parallelize, + poly::Polynomial, + transcript::{TranscriptRead, TranscriptWrite}, + Itertools, + }, + Error, +}; +use rand::Rng; +use std::{fmt::Debug, iter, marker::PhantomData}; + +mod accumulation; +mod accumulator; +mod decider; +mod multiopen; + +pub use accumulation::{IpaAs, IpaAsProof}; +pub use accumulator::IpaAccumulator; +pub use decider::IpaDecidingKey; +pub use multiopen::{Bgh19, Bgh19Proof, Bgh19SuccinctVerifyingKey}; + +#[derive(Clone, Debug)] +pub struct Ipa(PhantomData<(C, MOS)>); + +impl PolynomialCommitmentScheme for Ipa +where + C: CurveAffine, + L: Loader, + MOS: Clone + Debug, +{ + type Accumulator = IpaAccumulator; +} + +impl Ipa +where + C: CurveAffine, +{ + pub fn create_proof( + pk: &IpaProvingKey, + p: &[C::Scalar], + z: &C::Scalar, + omega: Option<&C::Scalar>, + transcript: &mut T, + mut rng: R, + ) -> Result, Error> + where + T: TranscriptWrite, + R: Rng, + { + let mut p_prime = Polynomial::new(p.to_vec()); + if pk.zk() { + let p_bar = { + let mut p_bar = Polynomial::rand(p.len(), &mut rng); + let p_bar_at_z = p_bar.evaluate(*z); + p_bar[0] -= p_bar_at_z; + p_bar + }; + let 
omega_bar = C::Scalar::random(&mut rng); + let c_bar = pk.commit(&p_bar, Some(omega_bar)); + transcript.write_ec_point(c_bar)?; + + let alpha = transcript.squeeze_challenge(); + let omega_prime = *omega.unwrap() + alpha * omega_bar; + transcript.write_scalar(omega_prime)?; + + p_prime = p_prime + &(p_bar * alpha); + }; + + let xi_0 = transcript.squeeze_challenge(); + let h_prime = pk.h * xi_0; + let mut bases = pk.g.clone(); + let mut coeffs = p_prime.to_vec(); + let mut zs = powers(*z).take(coeffs.len()).collect_vec(); + + let k = pk.domain.k; + let mut xi = Vec::with_capacity(k); + for i in 0..k { + let half = 1 << (k - i - 1); + + let l_i = multi_scalar_multiplication(&coeffs[half..], &bases[..half]) + + h_prime * inner_product(&coeffs[half..], &zs[..half]); + let r_i = multi_scalar_multiplication(&coeffs[..half], &bases[half..]) + + h_prime * inner_product(&coeffs[..half], &zs[half..]); + transcript.write_ec_point(l_i.to_affine())?; + transcript.write_ec_point(r_i.to_affine())?; + + let xi_i = transcript.squeeze_challenge(); + let xi_i_inv = Field::invert(&xi_i).unwrap(); + + let (bases_l, bases_r) = bases.split_at_mut(half); + let (coeffs_l, coeffs_r) = coeffs.split_at_mut(half); + let (zs_l, zs_r) = zs.split_at_mut(half); + parallelize(bases_l, |(bases_l, start)| { + let mut tmp = Vec::with_capacity(bases_l.len()); + for (lhs, rhs) in bases_l.iter().zip(bases_r[start..].iter()) { + tmp.push(lhs.to_curve() + *rhs * xi_i); + } + C::Curve::batch_normalize(&tmp, bases_l); + }); + parallelize(coeffs_l, |(coeffs_l, start)| { + for (lhs, rhs) in coeffs_l.iter_mut().zip(coeffs_r[start..].iter()) { + *lhs += xi_i_inv * rhs; + } + }); + parallelize(zs_l, |(zs_l, start)| { + for (lhs, rhs) in zs_l.iter_mut().zip(zs_r[start..].iter()) { + *lhs += xi_i * rhs; + } + }); + bases = bases_l.to_vec(); + coeffs = coeffs_l.to_vec(); + zs = zs_l.to_vec(); + + xi.push(xi_i); + } + + transcript.write_ec_point(bases[0])?; + transcript.write_scalar(coeffs[0])?; + + Ok(IpaAccumulator::new(xi, bases[0])) + } + + pub fn read_proof>( + svk: &IpaSuccinctVerifyingKey, + transcript: &mut T, + ) -> Result, Error> + where + T: TranscriptRead, + { + IpaProof::read(svk, transcript) + } + + pub fn succinct_verify>( + svk: &IpaSuccinctVerifyingKey, + commitment: &Msm, + z: &L::LoadedScalar, + eval: &L::LoadedScalar, + proof: &IpaProof, + ) -> Result, Error> { + let loader = z.loader(); + let h = loader.ec_point_load_const(&svk.h); + let s = svk.s.as_ref().map(|s| loader.ec_point_load_const(s)); + let h = Msm::::base(&h); + + let h_prime = h * &proof.xi_0; + let lhs = { + let c_prime = match ( + s.as_ref(), + proof.c_bar_alpha.as_ref(), + proof.omega_prime.as_ref(), + ) { + (Some(s), Some((c_bar, alpha)), Some(omega_prime)) => { + let s = Msm::::base(s); + commitment.clone() + Msm::base(c_bar) * alpha - s * omega_prime + } + (None, None, None) => commitment.clone(), + _ => unreachable!(), + }; + let c_0 = c_prime + h_prime.clone() * eval; + let c_k = c_0 + + proof + .rounds + .iter() + .zip(proof.xi_inv().iter()) + .flat_map(|(Round { l, r, xi }, xi_inv)| [(l, xi_inv), (r, xi)]) + .map(|(base, scalar)| Msm::::base(base) * scalar) + .sum::>(); + c_k.evaluate(None) + }; + let rhs = { + let u = Msm::::base(&proof.u); + let v_prime = h_eval(&proof.xi(), z) * &proof.c; + (u * &proof.c + h_prime * &v_prime).evaluate(None) + }; + + loader.ec_point_assert_eq("C_k == c[U] + v'[H']", &lhs, &rhs)?; + + Ok(IpaAccumulator::new(proof.xi(), proof.u.clone())) + } +} + +#[derive(Clone, Debug)] +pub struct IpaProvingKey { + pub domain: 
Domain, + pub g: Vec, + pub h: C, + pub s: Option, +} + +impl IpaProvingKey { + pub fn new(domain: Domain, g: Vec, h: C, s: Option) -> Self { + Self { domain, g, h, s } + } + + pub fn zk(&self) -> bool { + self.s.is_some() + } + + pub fn svk(&self) -> IpaSuccinctVerifyingKey { + IpaSuccinctVerifyingKey::new(self.domain.clone(), self.h, self.s) + } + + pub fn dk(&self) -> IpaDecidingKey { + IpaDecidingKey::new(self.g.clone()) + } + + pub fn commit(&self, poly: &Polynomial, omega: Option) -> C { + let mut c = multi_scalar_multiplication(&poly[..], &self.g); + match (self.s, omega) { + (Some(s), Some(omega)) => c += s * omega, + (None, None) => {} + _ => unreachable!(), + }; + c.to_affine() + } +} + +impl IpaProvingKey { + #[cfg(test)] + pub fn rand(k: usize, zk: bool, mut rng: R) -> Self { + use crate::util::arithmetic::{root_of_unity, Group}; + + let domain = Domain::new(k, root_of_unity(k)); + let mut g = vec![C::default(); 1 << k]; + C::Curve::batch_normalize( + &iter::repeat_with(|| C::Curve::random(&mut rng)) + .take(1 << k) + .collect_vec(), + &mut g, + ); + let h = C::Curve::random(&mut rng).to_affine(); + let s = zk.then(|| C::Curve::random(&mut rng).to_affine()); + Self { domain, g, h, s } + } +} + +#[derive(Clone, Debug)] +pub struct IpaSuccinctVerifyingKey { + pub domain: Domain, + pub h: C, + pub s: Option, +} + +impl IpaSuccinctVerifyingKey { + pub fn new(domain: Domain, h: C, s: Option) -> Self { + Self { domain, h, s } + } + + pub fn zk(&self) -> bool { + self.s.is_some() + } +} + +#[derive(Clone, Debug)] +pub struct IpaProof +where + C: CurveAffine, + L: Loader, +{ + c_bar_alpha: Option<(L::LoadedEcPoint, L::LoadedScalar)>, + omega_prime: Option, + xi_0: L::LoadedScalar, + rounds: Vec>, + u: L::LoadedEcPoint, + c: L::LoadedScalar, +} + +impl IpaProof +where + C: CurveAffine, + L: Loader, +{ + pub fn new( + c_bar_alpha: Option<(L::LoadedEcPoint, L::LoadedScalar)>, + omega_prime: Option, + xi_0: L::LoadedScalar, + rounds: Vec>, + u: L::LoadedEcPoint, + c: L::LoadedScalar, + ) -> Self { + Self { + c_bar_alpha, + omega_prime, + xi_0, + rounds, + u, + c, + } + } + + pub fn read(svk: &IpaSuccinctVerifyingKey, transcript: &mut T) -> Result + where + T: TranscriptRead, + { + let c_bar_alpha = svk + .zk() + .then(|| { + let c_bar = transcript.read_ec_point()?; + let alpha = transcript.squeeze_challenge(); + Ok((c_bar, alpha)) + }) + .transpose()?; + let omega_prime = svk.zk().then(|| transcript.read_scalar()).transpose()?; + let xi_0 = transcript.squeeze_challenge(); + let rounds = iter::repeat_with(|| { + Ok(Round::new( + transcript.read_ec_point()?, + transcript.read_ec_point()?, + transcript.squeeze_challenge(), + )) + }) + .take(svk.domain.k) + .collect::, _>>()?; + let u = transcript.read_ec_point()?; + let c = transcript.read_scalar()?; + Ok(Self { + c_bar_alpha, + omega_prime, + xi_0, + rounds, + u, + c, + }) + } + + pub fn xi(&self) -> Vec { + self.rounds.iter().map(|round| round.xi.clone()).collect() + } + + pub fn xi_inv(&self) -> Vec { + let mut xi_inv = self.xi().into_iter().map(Fraction::one_over).collect_vec(); + L::batch_invert(xi_inv.iter_mut().filter_map(Fraction::denom_mut)); + xi_inv.iter_mut().for_each(Fraction::evaluate); + xi_inv + .into_iter() + .map(|xi_inv| xi_inv.evaluated().clone()) + .collect() + } +} + +#[derive(Clone, Debug)] +pub struct Round +where + C: CurveAffine, + L: Loader, +{ + l: L::LoadedEcPoint, + r: L::LoadedEcPoint, + xi: L::LoadedScalar, +} + +impl Round +where + C: CurveAffine, + L: Loader, +{ + pub fn new(l: L::LoadedEcPoint, r: 
L::LoadedEcPoint, xi: L::LoadedScalar) -> Self { + Self { l, r, xi } + } +} + +pub fn h_eval>(xi: &[T], z: &T) -> T { + let loader = z.loader(); + let one = loader.load_one(); + loader.product( + &iter::successors(Some(z.clone()), |z| Some(z.square())) + .zip(xi.iter().rev()) + .map(|(z, xi)| z * xi + &one) + .collect_vec() + .iter() + .collect_vec(), + ) +} + +pub fn h_coeffs(xi: &[F], scalar: F) -> Vec { + assert!(!xi.is_empty()); + + let mut coeffs = vec![F::zero(); 1 << xi.len()]; + coeffs[0] = scalar; + + for (len, xi) in xi.iter().rev().enumerate().map(|(i, xi)| (1 << i, xi)) { + let (left, right) = coeffs.split_at_mut(len); + let right = &mut right[0..len]; + right.copy_from_slice(left); + for coeffs in right { + *coeffs *= xi; + } + } + + coeffs +} + +#[cfg(all(test, feature = "system_halo2"))] +mod test { + use crate::{ + pcs::{ + ipa::{self, IpaProvingKey}, + Decider, + }, + util::{arithmetic::Field, msm::Msm, poly::Polynomial}, + }; + use halo2_curves::pasta::pallas; + use halo2_proofs::transcript::{ + Blake2bRead, Blake2bWrite, TranscriptReadBuffer, TranscriptWriterBuffer, + }; + use rand::rngs::OsRng; + + #[test] + fn test_ipa() { + type Ipa = ipa::Ipa; + + let k = 10; + let mut rng = OsRng; + + for zk in [false, true] { + let pk = IpaProvingKey::::rand(k, zk, &mut rng); + let (c, z, v, proof) = { + let p = Polynomial::::rand(pk.domain.n, &mut rng); + let omega = pk.zk().then(|| pallas::Scalar::random(&mut rng)); + let c = pk.commit(&p, omega); + let z = pallas::Scalar::random(&mut rng); + let v = p.evaluate(z); + let mut transcript = Blake2bWrite::init(Vec::new()); + Ipa::create_proof(&pk, &p[..], &z, omega.as_ref(), &mut transcript, &mut rng) + .unwrap(); + (c, z, v, transcript.finalize()) + }; + + let svk = pk.svk(); + let accumulator = { + let mut transcript = Blake2bRead::init(proof.as_slice()); + let proof = Ipa::read_proof(&svk, &mut transcript).unwrap(); + Ipa::succinct_verify(&svk, &Msm::base(&c), &z, &v, &proof).unwrap() + }; + + let dk = pk.dk(); + assert!(Ipa::decide(&dk, accumulator)); + } + } +} diff --git a/snark-verifier/src/pcs/ipa/accumulation.rs b/snark-verifier/src/pcs/ipa/accumulation.rs new file mode 100644 index 00000000..07f294de --- /dev/null +++ b/snark-verifier/src/pcs/ipa/accumulation.rs @@ -0,0 +1,279 @@ +use crate::{ + loader::{native::NativeLoader, LoadedScalar, Loader}, + pcs::{ + ipa::{ + h_coeffs, h_eval, Ipa, IpaAccumulator, IpaProof, IpaProvingKey, IpaSuccinctVerifyingKey, + }, + AccumulationScheme, AccumulationSchemeProver, PolynomialCommitmentScheme, + }, + util::{ + arithmetic::{Curve, CurveAffine, Field}, + msm::Msm, + poly::Polynomial, + transcript::{TranscriptRead, TranscriptWrite}, + Itertools, + }, + Error, +}; +use rand::Rng; +use std::{array, iter, marker::PhantomData}; + +#[derive(Clone, Debug)] +pub struct IpaAs(PhantomData); + +impl AccumulationScheme for IpaAs +where + C: CurveAffine, + L: Loader, + PCS: PolynomialCommitmentScheme>, +{ + type VerifyingKey = IpaSuccinctVerifyingKey; + type Proof = IpaAsProof; + + fn read_proof( + vk: &Self::VerifyingKey, + instances: &[PCS::Accumulator], + transcript: &mut T, + ) -> Result + where + T: TranscriptRead, + { + IpaAsProof::read(vk, instances, transcript) + } + + fn verify( + vk: &Self::VerifyingKey, + instances: &[PCS::Accumulator], + proof: &Self::Proof, + ) -> Result { + let loader = proof.z.loader(); + let s = vk.s.as_ref().map(|s| loader.ec_point_load_const(s)); + + let (u, h) = instances + .iter() + .map(|IpaAccumulator { u, xi }| (u.clone(), h_eval(xi, &proof.z))) + 
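+ // `h_eval(xi, z)` (defined in ipa.rs above) evaluates h(X) = prod_i (1 + xi_i * X^(2^(k-1-i))) at the challenge z,
+ // so each accumulator contributes its point `u` together with the claimed evaluation of its folded polynomial.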
.chain(proof.a_b_u.as_ref().map(|(a, b, u)| (u.clone(), a.clone() * &proof.z + b))) + .unzip::<_, _, Vec<_>, Vec<_>>(); + let powers_of_alpha = proof.alpha.powers(u.len()); + + let mut c = powers_of_alpha + .iter() + .zip(u.iter()) + .map(|(power_of_alpha, u)| Msm::::base(u) * power_of_alpha) + .sum::>(); + if let Some(omega) = proof.omega.as_ref() { + c += Msm::base(s.as_ref().unwrap()) * omega; + } + let v = loader.sum_products(&powers_of_alpha.iter().zip(h.iter()).collect_vec()); + + Ipa::::succinct_verify(vk, &c, &proof.z, &v, &proof.ipa) + } +} + +#[derive(Clone, Debug)] +pub struct IpaAsProof +where + C: CurveAffine, + L: Loader, + PCS: PolynomialCommitmentScheme>, +{ + a_b_u: Option<(L::LoadedScalar, L::LoadedScalar, L::LoadedEcPoint)>, + omega: Option, + alpha: L::LoadedScalar, + z: L::LoadedScalar, + ipa: IpaProof, + _marker: PhantomData, +} + +impl IpaAsProof +where + C: CurveAffine, + L: Loader, + PCS: PolynomialCommitmentScheme>, +{ + fn read( + vk: &IpaSuccinctVerifyingKey, + instances: &[PCS::Accumulator], + transcript: &mut T, + ) -> Result + where + T: TranscriptRead, + { + assert!(instances.len() > 1); + + let a_b_u = vk + .zk() + .then(|| { + let a = transcript.read_scalar()?; + let b = transcript.read_scalar()?; + let u = transcript.read_ec_point()?; + Ok((a, b, u)) + }) + .transpose()?; + let omega = vk + .zk() + .then(|| { + let omega = transcript.read_scalar()?; + Ok(omega) + }) + .transpose()?; + + for accumulator in instances { + for xi in accumulator.xi.iter() { + transcript.common_scalar(xi)?; + } + transcript.common_ec_point(&accumulator.u)?; + } + + let alpha = transcript.squeeze_challenge(); + let z = transcript.squeeze_challenge(); + + let ipa = IpaProof::read(vk, transcript)?; + + Ok(Self { a_b_u, omega, alpha, z, ipa, _marker: PhantomData }) + } +} + +impl AccumulationSchemeProver for IpaAs +where + C: CurveAffine, + PCS: PolynomialCommitmentScheme>, +{ + type ProvingKey = IpaProvingKey; + + fn create_proof( + pk: &Self::ProvingKey, + instances: &[PCS::Accumulator], + transcript: &mut T, + mut rng: R, + ) -> Result + where + T: TranscriptWrite, + R: Rng, + { + assert!(instances.len() > 1); + + let a_b_u = pk + .zk() + .then(|| { + let [a, b] = array::from_fn(|_| C::Scalar::random(&mut rng)); + let u = (pk.g[1] * a + pk.g[0] * b).to_affine(); + transcript.write_scalar(a)?; + transcript.write_scalar(b)?; + transcript.write_ec_point(u)?; + Ok((a, b, u)) + }) + .transpose()?; + let omega = pk + .zk() + .then(|| { + let omega = C::Scalar::random(&mut rng); + transcript.write_scalar(omega)?; + Ok(omega) + }) + .transpose()?; + + for accumulator in instances { + for xi in accumulator.xi.iter() { + transcript.common_scalar(xi)?; + } + transcript.common_ec_point(&accumulator.u)?; + } + + let alpha = transcript.squeeze_challenge(); + let z = transcript.squeeze_challenge(); + + let (u, h) = instances + .iter() + .map(|IpaAccumulator { u, xi }| (*u, h_coeffs(xi, C::Scalar::one()))) + .chain(a_b_u.map(|(a, b, u)| { + ( + u, + iter::empty() + .chain([b, a]) + .chain(iter::repeat_with(C::Scalar::zero).take(pk.domain.n - 2)) + .collect(), + ) + })) + .unzip::<_, _, Vec<_>, Vec<_>>(); + let powers_of_alpha = alpha.powers(u.len()); + + let h = powers_of_alpha + .into_iter() + .zip(h.into_iter().map(Polynomial::new)) + .map(|(power_of_alpha, h)| h * power_of_alpha) + .sum::>(); + + Ipa::::create_proof(pk, &h.to_vec(), &z, omega.as_ref(), transcript, &mut rng) + } +} + +#[cfg(test)] +mod test { + use crate::halo2_curves::pasta::pallas; + use crate::halo2_proofs::transcript::{ + 
Blake2bRead, Blake2bWrite, TranscriptReadBuffer, TranscriptWriterBuffer, + }; + use crate::{ + pcs::{ + ipa::{self, IpaProvingKey}, + AccumulationScheme, AccumulationSchemeProver, Decider, + }, + util::{arithmetic::Field, msm::Msm, poly::Polynomial, Itertools}, + }; + use rand::rngs::OsRng; + use std::iter; + + #[test] + fn test_ipa_as() { + type Ipa = ipa::Ipa; + type IpaAs = ipa::IpaAs; + + let k = 10; + let zk = true; + let mut rng = OsRng; + + let pk = IpaProvingKey::::rand(k, zk, &mut rng); + let accumulators = iter::repeat_with(|| { + let (c, z, v, proof) = { + let p = Polynomial::::rand(pk.domain.n, &mut rng); + let omega = pk.zk().then(|| pallas::Scalar::random(&mut rng)); + let c = pk.commit(&p, omega); + let z = pallas::Scalar::random(&mut rng); + let v = p.evaluate(z); + let mut transcript = Blake2bWrite::init(Vec::new()); + Ipa::create_proof(&pk, &p[..], &z, omega.as_ref(), &mut transcript, &mut rng) + .unwrap(); + (c, z, v, transcript.finalize()) + }; + + let svk = pk.svk(); + let accumulator = { + let mut transcript = Blake2bRead::init(proof.as_slice()); + let proof = Ipa::read_proof(&svk, &mut transcript).unwrap(); + Ipa::succinct_verify(&svk, &Msm::base(&c), &z, &v, &proof).unwrap() + }; + + accumulator + }) + .take(10) + .collect_vec(); + + let proof = { + let apk = pk.clone(); + let mut transcript = Blake2bWrite::init(Vec::new()); + IpaAs::create_proof(&apk, &accumulators, &mut transcript, &mut rng).unwrap(); + transcript.finalize() + }; + + let accumulator = { + let avk = pk.svk(); + let mut transcript = Blake2bRead::init(proof.as_slice()); + let proof = IpaAs::read_proof(&avk, &accumulators, &mut transcript).unwrap(); + IpaAs::verify(&avk, &accumulators, &proof).unwrap() + }; + + let dk = pk.dk(); + assert!(Ipa::decide(&dk, accumulator)); + } +} diff --git a/snark-verifier/src/pcs/ipa/accumulator.rs b/snark-verifier/src/pcs/ipa/accumulator.rs new file mode 100644 index 00000000..27d9d5c7 --- /dev/null +++ b/snark-verifier/src/pcs/ipa/accumulator.rs @@ -0,0 +1,21 @@ +use crate::{loader::Loader, util::arithmetic::CurveAffine}; + +#[derive(Clone, Debug)] +pub struct IpaAccumulator +where + C: CurveAffine, + L: Loader, +{ + pub xi: Vec, + pub u: L::LoadedEcPoint, +} + +impl IpaAccumulator +where + C: CurveAffine, + L: Loader, +{ + pub fn new(xi: Vec, u: L::LoadedEcPoint) -> Self { + Self { xi, u } + } +} diff --git a/snark-verifier/src/pcs/ipa/decider.rs b/snark-verifier/src/pcs/ipa/decider.rs new file mode 100644 index 00000000..2cf8c6cc --- /dev/null +++ b/snark-verifier/src/pcs/ipa/decider.rs @@ -0,0 +1,57 @@ +#[derive(Clone, Debug)] +pub struct IpaDecidingKey { + pub g: Vec, +} + +impl IpaDecidingKey { + pub fn new(g: Vec) -> Self { + Self { g } + } +} + +impl From> for IpaDecidingKey { + fn from(g: Vec) -> IpaDecidingKey { + IpaDecidingKey::new(g) + } +} + +mod native { + use crate::{ + loader::native::NativeLoader, + pcs::{ + ipa::{h_coeffs, Ipa, IpaAccumulator, IpaDecidingKey}, + Decider, + }, + util::{ + arithmetic::{Curve, CurveAffine, Field}, + msm::multi_scalar_multiplication, + }, + }; + use std::fmt::Debug; + + impl Decider for Ipa + where + C: CurveAffine, + MOS: Clone + Debug, + { + type DecidingKey = IpaDecidingKey; + type Output = bool; + + fn decide( + dk: &Self::DecidingKey, + IpaAccumulator { u, xi }: IpaAccumulator, + ) -> bool { + let h = h_coeffs(&xi, C::Scalar::one()); + u == multi_scalar_multiplication(&h, &dk.g).to_affine() + } + + fn decide_all( + dk: &Self::DecidingKey, + accumulators: Vec>, + ) -> bool { + !accumulators + .into_iter() + 
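+ // `decide` checks u == MSM(h_coeffs(xi, 1), g); `decide_all` passes only if no accumulator fails that check.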
.any(|accumulator| !Self::decide(dk, accumulator)) + } + } +} diff --git a/snark-verifier/src/pcs/ipa/multiopen.rs b/snark-verifier/src/pcs/ipa/multiopen.rs new file mode 100644 index 00000000..9f685e76 --- /dev/null +++ b/snark-verifier/src/pcs/ipa/multiopen.rs @@ -0,0 +1,3 @@ +mod bgh19; + +pub use bgh19::{Bgh19, Bgh19Proof, Bgh19SuccinctVerifyingKey}; diff --git a/snark-verifier/src/pcs/ipa/multiopen/bgh19.rs b/snark-verifier/src/pcs/ipa/multiopen/bgh19.rs new file mode 100644 index 00000000..29d291ad --- /dev/null +++ b/snark-verifier/src/pcs/ipa/multiopen/bgh19.rs @@ -0,0 +1,417 @@ +use crate::{ + loader::{LoadedScalar, Loader, ScalarLoader}, + pcs::{ + ipa::{Ipa, IpaProof, IpaSuccinctVerifyingKey, Round}, + MultiOpenScheme, Query, + }, + util::{ + arithmetic::{ilog2, CurveAffine, Domain, FieldExt, Fraction}, + msm::Msm, + transcript::TranscriptRead, + Itertools, + }, + Error, +}; +use std::{ + collections::{BTreeMap, BTreeSet}, + iter, + marker::PhantomData, +}; + +#[derive(Clone, Debug)] +pub struct Bgh19; + +impl MultiOpenScheme for Ipa +where + C: CurveAffine, + L: Loader, +{ + type SuccinctVerifyingKey = Bgh19SuccinctVerifyingKey; + type Proof = Bgh19Proof; + + fn read_proof( + svk: &Self::SuccinctVerifyingKey, + queries: &[Query], + transcript: &mut T, + ) -> Result + where + T: TranscriptRead, + { + Bgh19Proof::read(svk, queries, transcript) + } + + fn succinct_verify( + svk: &Self::SuccinctVerifyingKey, + commitments: &[Msm], + x: &L::LoadedScalar, + queries: &[Query], + proof: &Self::Proof, + ) -> Result { + let loader = x.loader(); + let g = loader.ec_point_load_const(&svk.g); + + // Multiopen + let sets = query_sets(queries); + let p = { + let coeffs = query_set_coeffs(&sets, x, &proof.x_3); + + let powers_of_x_1 = proof + .x_1 + .powers(sets.iter().map(|set| set.polys.len()).max().unwrap()); + let f_eval = { + let powers_of_x_2 = proof.x_2.powers(sets.len()); + let f_evals = sets + .iter() + .zip(coeffs.iter()) + .zip(proof.q_evals.iter()) + .map(|((set, coeff), q_eval)| set.f_eval(coeff, q_eval, &powers_of_x_1)) + .collect_vec(); + x.loader() + .sum_products(&powers_of_x_2.iter().zip(f_evals.iter().rev()).collect_vec()) + }; + let msms = sets + .iter() + .zip(proof.q_evals.iter()) + .map(|(set, q_eval)| set.msm(commitments, q_eval, &powers_of_x_1)); + + let (mut msm, constant) = iter::once(Msm::base(&proof.f) - Msm::constant(f_eval)) + .chain(msms) + .zip(proof.x_4.powers(sets.len() + 1).into_iter().rev()) + .map(|(msm, power_of_x_4)| msm * &power_of_x_4) + .sum::>() + .split(); + if let Some(constant) = constant { + msm += Msm::base(&g) * &constant; + } + msm + }; + + // IPA + Ipa::::succinct_verify(&svk.ipa, &p, &proof.x_3, &loader.load_zero(), &proof.ipa) + } +} + +#[derive(Clone, Debug)] +pub struct Bgh19SuccinctVerifyingKey { + g: C, + ipa: IpaSuccinctVerifyingKey, +} + +impl Bgh19SuccinctVerifyingKey { + pub fn new(domain: Domain, g: C, w: C, u: C) -> Self { + Self { + g, + ipa: IpaSuccinctVerifyingKey::new(domain, u, Some(w)), + } + } +} + +#[derive(Clone, Debug)] +pub struct Bgh19Proof +where + C: CurveAffine, + L: Loader, +{ + // Multiopen + x_1: L::LoadedScalar, + x_2: L::LoadedScalar, + f: L::LoadedEcPoint, + x_3: L::LoadedScalar, + q_evals: Vec, + x_4: L::LoadedScalar, + // IPA + ipa: IpaProof, +} + +impl Bgh19Proof +where + C: CurveAffine, + L: Loader, +{ + fn read>( + svk: &Bgh19SuccinctVerifyingKey, + queries: &[Query], + transcript: &mut T, + ) -> Result { + // Multiopen + let x_1 = transcript.squeeze_challenge(); + let x_2 = 
transcript.squeeze_challenge(); + let f = transcript.read_ec_point()?; + let x_3 = transcript.squeeze_challenge(); + let q_evals = transcript.read_n_scalars(query_sets(queries).len())?; + let x_4 = transcript.squeeze_challenge(); + // IPA + let s = transcript.read_ec_point()?; + let xi = transcript.squeeze_challenge(); + let z = transcript.squeeze_challenge(); + let rounds = iter::repeat_with(|| { + Ok(Round::new( + transcript.read_ec_point()?, + transcript.read_ec_point()?, + transcript.squeeze_challenge(), + )) + }) + .take(svk.ipa.domain.k) + .collect::, _>>()?; + let c = transcript.read_scalar()?; + let blind = transcript.read_scalar()?; + let g = transcript.read_ec_point()?; + Ok(Bgh19Proof { + x_1, + x_2, + f, + x_3, + q_evals, + x_4, + ipa: IpaProof::new(Some((s, xi)), Some(blind), z, rounds, g, c), + }) + } +} + +fn query_sets(queries: &[Query]) -> Vec> +where + F: FieldExt, + T: Clone, +{ + let poly_shifts = queries.iter().fold( + Vec::<(usize, Vec, Vec<&T>)>::new(), + |mut poly_shifts, query| { + if let Some(pos) = poly_shifts + .iter() + .position(|(poly, _, _)| *poly == query.poly) + { + let (_, shifts, evals) = &mut poly_shifts[pos]; + if !shifts.contains(&query.shift) { + shifts.push(query.shift); + evals.push(&query.eval); + } + } else { + poly_shifts.push((query.poly, vec![query.shift], vec![&query.eval])); + } + poly_shifts + }, + ); + + poly_shifts.into_iter().fold( + Vec::>::new(), + |mut sets, (poly, shifts, evals)| { + if let Some(pos) = sets.iter().position(|set| { + BTreeSet::from_iter(set.shifts.iter()) == BTreeSet::from_iter(shifts.iter()) + }) { + let set = &mut sets[pos]; + if !set.polys.contains(&poly) { + set.polys.push(poly); + set.evals.push( + set.shifts + .iter() + .map(|lhs| { + let idx = shifts.iter().position(|rhs| lhs == rhs).unwrap(); + evals[idx] + }) + .collect(), + ); + } + } else { + let set = QuerySet { + shifts, + polys: vec![poly], + evals: vec![evals], + }; + sets.push(set); + } + sets + }, + ) +} + +fn query_set_coeffs(sets: &[QuerySet], x: &T, x_3: &T) -> Vec> +where + F: FieldExt, + T: LoadedScalar, +{ + let loader = x.loader(); + let superset = sets + .iter() + .flat_map(|set| set.shifts.clone()) + .sorted() + .dedup(); + + let size = 2.max( + ilog2((sets.iter().map(|set| set.shifts.len()).max().unwrap() - 1).next_power_of_two()) + 1, + ); + let powers_of_x = x.powers(size); + let x_3_minus_x_shift_i = BTreeMap::from_iter( + superset.map(|shift| (shift, x_3.clone() - x.clone() * loader.load_const(&shift))), + ); + + let mut coeffs = sets + .iter() + .map(|set| QuerySetCoeff::new(&set.shifts, &powers_of_x, x_3, &x_3_minus_x_shift_i)) + .collect_vec(); + + T::Loader::batch_invert(coeffs.iter_mut().flat_map(QuerySetCoeff::denoms)); + T::Loader::batch_invert(coeffs.iter_mut().flat_map(QuerySetCoeff::denoms)); + coeffs.iter_mut().for_each(QuerySetCoeff::evaluate); + + coeffs +} + +#[derive(Clone, Debug)] +struct QuerySet<'a, F, T> { + shifts: Vec, + polys: Vec, + evals: Vec>, +} + +impl<'a, F, T> QuerySet<'a, F, T> +where + F: FieldExt, + T: LoadedScalar, +{ + fn msm>( + &self, + commitments: &[Msm<'a, C, L>], + q_eval: &T, + powers_of_x_1: &[T], + ) -> Msm { + self.polys + .iter() + .rev() + .zip(powers_of_x_1) + .map(|(poly, power_of_x_1)| commitments[*poly].clone() * power_of_x_1) + .sum::>() + - Msm::constant(q_eval.clone()) + } + + fn f_eval(&self, coeff: &QuerySetCoeff, q_eval: &T, powers_of_x_1: &[T]) -> T { + let loader = q_eval.loader(); + let r_eval = { + let r_evals = self + .evals + .iter() + .map(|evals| { + loader.sum_products( + 
&coeff + .eval_coeffs + .iter() + .zip(evals.iter()) + .map(|(coeff, eval)| (coeff.evaluated(), *eval)) + .collect_vec(), + ) * coeff.r_eval_coeff.as_ref().unwrap().evaluated() + }) + .collect_vec(); + loader.sum_products(&r_evals.iter().rev().zip(powers_of_x_1).collect_vec()) + }; + + (q_eval.clone() - r_eval) * coeff.f_eval_coeff.evaluated() + } +} + +#[derive(Clone, Debug)] +struct QuerySetCoeff { + eval_coeffs: Vec>, + r_eval_coeff: Option>, + f_eval_coeff: Fraction, + _marker: PhantomData, +} + +impl QuerySetCoeff +where + F: FieldExt, + T: LoadedScalar, +{ + fn new(shifts: &[F], powers_of_x: &[T], x_3: &T, x_3_minus_x_shift_i: &BTreeMap) -> Self { + let loader = x_3.loader(); + let normalized_ell_primes = shifts + .iter() + .enumerate() + .map(|(j, shift_j)| { + shifts + .iter() + .enumerate() + .filter(|&(i, _)| i != j) + .map(|(_, shift_i)| (*shift_j - shift_i)) + .reduce(|acc, value| acc * value) + .unwrap_or_else(|| F::one()) + }) + .collect_vec(); + + let x = &powers_of_x[1].clone(); + let x_pow_k_minus_one = { + let k_minus_one = shifts.len() - 1; + powers_of_x + .iter() + .enumerate() + .skip(1) + .filter_map(|(i, power_of_x)| { + (k_minus_one & (1 << i) == 1).then(|| power_of_x.clone()) + }) + .reduce(|acc, value| acc * value) + .unwrap_or_else(|| loader.load_one()) + }; + + let barycentric_weights = shifts + .iter() + .zip(normalized_ell_primes.iter()) + .map(|(shift, normalized_ell_prime)| { + loader.sum_products_with_coeff(&[ + (*normalized_ell_prime, &x_pow_k_minus_one, x_3), + (-(*normalized_ell_prime * shift), &x_pow_k_minus_one, x), + ]) + }) + .map(Fraction::one_over) + .collect_vec(); + + let f_eval_coeff = Fraction::one_over( + loader.product( + &shifts + .iter() + .map(|shift| x_3_minus_x_shift_i.get(shift).unwrap()) + .collect_vec(), + ), + ); + + Self { + eval_coeffs: barycentric_weights, + r_eval_coeff: None, + f_eval_coeff, + _marker: PhantomData, + } + } + + fn denoms(&mut self) -> impl IntoIterator { + if self.eval_coeffs.first().unwrap().denom().is_some() { + return self + .eval_coeffs + .iter_mut() + .chain(Some(&mut self.f_eval_coeff)) + .filter_map(Fraction::denom_mut) + .collect_vec(); + } + + if self.r_eval_coeff.is_none() { + self.eval_coeffs + .iter_mut() + .chain(Some(&mut self.f_eval_coeff)) + .for_each(Fraction::evaluate); + + let loader = self.f_eval_coeff.evaluated().loader(); + let barycentric_weights_sum = loader.sum( + &self + .eval_coeffs + .iter() + .map(Fraction::evaluated) + .collect_vec(), + ); + self.r_eval_coeff = Some(Fraction::one_over(barycentric_weights_sum)); + + return vec![self.r_eval_coeff.as_mut().unwrap().denom_mut().unwrap()]; + } + + unreachable!() + } + + fn evaluate(&mut self) { + self.r_eval_coeff.as_mut().unwrap().evaluate(); + } +} diff --git a/src/pcs/kzg.rs b/snark-verifier/src/pcs/kzg.rs similarity index 93% rename from src/pcs/kzg.rs rename to snark-verifier/src/pcs/kzg.rs index 9f10bd44..056589a8 100644 --- a/src/pcs/kzg.rs +++ b/snark-verifier/src/pcs/kzg.rs @@ -15,6 +15,9 @@ pub use accumulator::{KzgAccumulator, LimbsEncoding}; pub use decider::KzgDecidingKey; pub use multiopen::{Bdfg21, Bdfg21Proof, Gwc19, Gwc19Proof}; +#[cfg(feature = "loader_halo2")] +pub use accumulator::LimbsEncodingInstructions; + #[derive(Clone, Debug)] pub struct Kzg(PhantomData<(M, MOS)>); diff --git a/src/pcs/kzg/accumulation.rs b/snark-verifier/src/pcs/kzg/accumulation.rs similarity index 92% rename from src/pcs/kzg/accumulation.rs rename to snark-verifier/src/pcs/kzg/accumulation.rs index 3dea8031..4273ce9a 100644 --- 
a/src/pcs/kzg/accumulation.rs +++ b/snark-verifier/src/pcs/kzg/accumulation.rs @@ -44,16 +44,16 @@ where ) -> Result { let (lhs, rhs) = instances .iter() - .cloned() - .map(|accumulator| (accumulator.lhs, accumulator.rhs)) - .chain(proof.blind.clone()) + .map(|accumulator| (&accumulator.lhs, &accumulator.rhs)) + .chain(proof.blind.as_ref().map(|(lhs, rhs)| (lhs, rhs))) .unzip::<_, _, Vec<_>, Vec<_>>(); let powers_of_r = proof.r.powers(lhs.len()); - let [lhs, rhs] = [lhs, rhs].map(|msms| { - msms.into_iter() + let [lhs, rhs] = [lhs, rhs].map(|bases| { + bases + .into_iter() .zip(powers_of_r.iter()) - .map(|(msm, r)| Msm::::base(msm) * r) + .map(|(base, r)| Msm::::base(base) * r) .sum::>() .evaluate(None) }); @@ -114,7 +114,7 @@ where where T: TranscriptRead, { - assert!(instances.len() > 1); + assert!(!instances.is_empty()); for accumulator in instances { transcript.common_ec_point(&accumulator.lhs)?; @@ -153,7 +153,7 @@ where T: TranscriptWrite, R: Rng, { - assert!(instances.len() > 1); + assert!(!instances.is_empty()); for accumulator in instances { transcript.common_ec_point(&accumulator.lhs)?; @@ -184,7 +184,7 @@ where let powers_of_r = r.powers(lhs.len()); let [lhs, rhs] = [lhs, rhs].map(|msms| { - msms.into_iter() + msms.iter() .zip(powers_of_r.iter()) .map(|(msm, power_of_r)| Msm::::base(msm) * power_of_r) .sum::>() diff --git a/snark-verifier/src/pcs/kzg/accumulator.rs b/snark-verifier/src/pcs/kzg/accumulator.rs new file mode 100644 index 00000000..efc28cd8 --- /dev/null +++ b/snark-verifier/src/pcs/kzg/accumulator.rs @@ -0,0 +1,312 @@ +use crate::{loader::Loader, util::arithmetic::CurveAffine}; +use std::fmt::Debug; + +#[derive(Clone, Debug)] +pub struct KzgAccumulator +where + C: CurveAffine, + L: Loader, +{ + pub lhs: L::LoadedEcPoint, + pub rhs: L::LoadedEcPoint, +} + +impl KzgAccumulator +where + C: CurveAffine, + L: Loader, +{ + pub fn new(lhs: L::LoadedEcPoint, rhs: L::LoadedEcPoint) -> Self { + Self { lhs, rhs } + } +} + +/// `AccumulatorEncoding` that encodes `Accumulator` into limbs. +/// +/// Since in circuit everything are in scalar field, but `Accumulator` might contain base field elements, so we split them into limbs. +/// The const generic `LIMBS` and `BITS` respectively represents how many limbs +/// a base field element are split into and how many bits each limbs could have. 
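+///
+/// For example (illustrative parameters, not fixed by this type), with `LIMBS = 3` and
+/// `BITS = 88` a base field element `x` is carried as three scalars `x_0, x_1, x_2` with
+/// `x = x_0 + x_1 * 2^88 + x_2 * 2^176` and each `x_i < 2^88`. An accumulator `(lhs, rhs)`
+/// therefore occupies `4 * LIMBS` scalars: the limbs of `lhs.x`, `lhs.y`, `rhs.x` and `rhs.y`,
+/// in that order, which is exactly the layout `from_repr` below expects.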
+#[derive(Clone, Debug)] +pub struct LimbsEncoding; + +mod native { + use crate::{ + loader::native::NativeLoader, + pcs::{ + kzg::{KzgAccumulator, LimbsEncoding}, + AccumulatorEncoding, PolynomialCommitmentScheme, + }, + util::{ + arithmetic::{fe_from_limbs, CurveAffine}, + Itertools, + }, + Error, + }; + + impl AccumulatorEncoding + for LimbsEncoding + where + C: CurveAffine, + PCS: PolynomialCommitmentScheme< + C, + NativeLoader, + Accumulator = KzgAccumulator, + >, + { + fn from_repr(limbs: &[&C::Scalar]) -> Result { + assert_eq!(limbs.len(), 4 * LIMBS); + + let [lhs_x, lhs_y, rhs_x, rhs_y]: [_; 4] = limbs + .chunks(LIMBS) + .into_iter() + .map(|limbs| { + fe_from_limbs::<_, _, LIMBS, BITS>( + limbs.iter().map(|limb| **limb).collect_vec().try_into().unwrap(), + ) + }) + .collect_vec() + .try_into() + .unwrap(); + let accumulator = KzgAccumulator::new( + C::from_xy(lhs_x, lhs_y).unwrap(), + C::from_xy(rhs_x, rhs_y).unwrap(), + ); + + Ok(accumulator) + } + } +} + +#[cfg(feature = "loader_evm")] +mod evm { + use crate::{ + loader::evm::{EvmLoader, Scalar}, + pcs::{ + kzg::{KzgAccumulator, LimbsEncoding}, + AccumulatorEncoding, PolynomialCommitmentScheme, + }, + util::{ + arithmetic::{CurveAffine, PrimeField}, + Itertools, + }, + Error, + }; + use std::rc::Rc; + + impl AccumulatorEncoding, PCS> + for LimbsEncoding + where + C: CurveAffine, + C::Scalar: PrimeField, + PCS: PolynomialCommitmentScheme< + C, + Rc, + Accumulator = KzgAccumulator>, + >, + { + fn from_repr(limbs: &[&Scalar]) -> Result { + assert_eq!(limbs.len(), 4 * LIMBS); + + let loader = limbs[0].loader(); + + let [lhs_x, lhs_y, rhs_x, rhs_y]: [[_; LIMBS]; 4] = limbs + .chunks(LIMBS) + .into_iter() + .map(|limbs| limbs.to_vec().try_into().unwrap()) + .collect_vec() + .try_into() + .unwrap(); + let accumulator = KzgAccumulator::new( + loader.ec_point_from_limbs::(lhs_x, lhs_y), + loader.ec_point_from_limbs::(rhs_x, rhs_y), + ); + + Ok(accumulator) + } + } +} + +#[cfg(feature = "loader_halo2")] +pub use halo2::LimbsEncodingInstructions; + +#[cfg(feature = "loader_halo2")] +mod halo2 { + use crate::halo2_proofs::{circuit::Value, plonk}; + use crate::{ + loader::halo2::{EccInstructions, Halo2Loader, Scalar, Valuetools}, + pcs::{ + kzg::{KzgAccumulator, LimbsEncoding}, + AccumulatorEncoding, PolynomialCommitmentScheme, + }, + util::{ + arithmetic::{fe_from_limbs, CurveAffine}, + Itertools, + }, + Error, + }; + use std::{iter, ops::Deref, rc::Rc}; + + fn ec_point_from_limbs( + limbs: &[Value<&C::Scalar>], + ) -> Value { + assert_eq!(limbs.len(), 2 * LIMBS); + + let [x, y] = [&limbs[..LIMBS], &limbs[LIMBS..]].map(|limbs| { + limbs + .iter() + .cloned() + .fold_zipped(Vec::new(), |mut acc, limb| { + acc.push(*limb); + acc + }) + .map(|limbs| fe_from_limbs::<_, _, LIMBS, BITS>(limbs.try_into().unwrap())) + }); + + x.zip(y).map(|(x, y)| C::from_xy(x, y).unwrap()) + } + + pub trait LimbsEncodingInstructions<'a, C: CurveAffine, const LIMBS: usize, const BITS: usize>: + EccInstructions<'a, C> + { + fn assign_ec_point_from_limbs( + &self, + ctx: &mut Self::Context, + limbs: &[impl Deref], + ) -> Result; + + fn assign_ec_point_to_limbs( + &self, + ctx: &mut Self::Context, + ec_point: impl Deref, + ) -> Result, plonk::Error>; + } + + impl<'a, C, PCS, EccChip, const LIMBS: usize, const BITS: usize> + AccumulatorEncoding>, PCS> for LimbsEncoding + where + C: CurveAffine, + PCS: PolynomialCommitmentScheme< + C, + Rc>, + Accumulator = KzgAccumulator>>, + >, + EccChip: LimbsEncodingInstructions<'a, C, LIMBS, BITS>, + { + fn from_repr(limbs: 
&[&Scalar<'a, C, EccChip>]) -> Result { + assert_eq!(limbs.len(), 4 * LIMBS); + + let loader = limbs[0].loader(); + + let [lhs, rhs] = [&limbs[..2 * LIMBS], &limbs[2 * LIMBS..]].map(|limbs| { + let assigned = loader + .ecc_chip() + .assign_ec_point_from_limbs( + &mut loader.ctx_mut(), + &limbs.iter().map(|limb| limb.assigned()).collect_vec(), + ) + .unwrap(); + loader.ec_point_from_assigned(assigned) + }); + + Ok(KzgAccumulator::new(lhs, rhs)) + } + } + + mod halo2_lib { + use super::*; + use halo2_base::{halo2_proofs::halo2curves::CurveAffineExt, utils::PrimeField}; + use halo2_ecc::ecc::BaseFieldEccChip; + + impl<'a, C, const LIMBS: usize, const BITS: usize> + LimbsEncodingInstructions<'a, C, LIMBS, BITS> for BaseFieldEccChip + where + C: CurveAffineExt, + C::ScalarExt: PrimeField, + C::Base: PrimeField, + { + fn assign_ec_point_from_limbs( + &self, + ctx: &mut Self::Context, + limbs: &[impl Deref], + ) -> Result { + assert_eq!(limbs.len(), 2 * LIMBS); + + let ec_point = self.assign_point::( + ctx, + ec_point_from_limbs::<_, LIMBS, BITS>( + &limbs.iter().map(|limb| limb.value()).collect_vec(), + ), + ); + + for (src, dst) in limbs + .iter() + .zip_eq(iter::empty().chain(ec_point.x().limbs()).chain(ec_point.y().limbs())) + { + ctx.region.constrain_equal(src.cell(), dst.cell()); + } + + Ok(ec_point) + } + + fn assign_ec_point_to_limbs( + &self, + _: &mut Self::Context, + ec_point: impl Deref, + ) -> Result, plonk::Error> { + Ok(iter::empty() + .chain(ec_point.x().limbs()) + .chain(ec_point.y().limbs()) + .cloned() + .collect()) + } + } + } + + /* + mod halo2_wrong { + use super::*; + use halo2_wrong_ecc::BaseFieldEccChip; + + impl<'a, C: CurveAffine, const LIMBS: usize, const BITS: usize> + LimbsEncodingInstructions<'a, C, LIMBS, BITS> for BaseFieldEccChip + { + fn assign_ec_point_from_limbs( + &self, + ctx: &mut Self::Context, + limbs: &[impl Deref], + ) -> Result { + assert_eq!(limbs.len(), 2 * LIMBS); + + let ec_point = self.assign_point( + ctx, + ec_point_from_limbs::<_, LIMBS, BITS>( + &limbs.iter().map(|limb| limb.value()).collect_vec(), + ), + )?; + + for (src, dst) in limbs + .iter() + .zip_eq(iter::empty().chain(ec_point.x().limbs()).chain(ec_point.y().limbs())) + { + ctx.constrain_equal(src.cell(), dst.as_ref().cell())?; + } + + Ok(ec_point) + } + + fn assign_ec_point_to_limbs( + &self, + _: &mut Self::Context, + ec_point: impl Deref, + ) -> Result, plonk::Error> { + Ok(iter::empty() + .chain(ec_point.x().limbs()) + .chain(ec_point.y().limbs()) + .map(|limb| limb.as_ref()) + .cloned() + .collect()) + } + } + } + */ +} diff --git a/src/pcs/kzg/decider.rs b/snark-verifier/src/pcs/kzg/decider.rs similarity index 90% rename from src/pcs/kzg/decider.rs rename to snark-verifier/src/pcs/kzg/decider.rs index b6957883..baabda6c 100644 --- a/src/pcs/kzg/decider.rs +++ b/snark-verifier/src/pcs/kzg/decider.rs @@ -60,7 +60,15 @@ mod native { ) -> bool { !accumulators .into_iter() - .any(|accumulator| !Self::decide(dk, accumulator)) + //.enumerate() + .any(|accumulator| { + /*let decide = Self::decide(dk, accumulator); + if !decide { + panic!("{i}"); + } + !decide*/ + !Self::decide(dk, accumulator) + }) } } } @@ -132,19 +140,13 @@ mod evm { let hash_ptr = loader.keccak256(lhs[0].ptr(), lhs.len() * 0x80); let challenge_ptr = loader.allocate(0x20); - loader - .code_mut() - .push(loader.scalar_modulus()) - .push(hash_ptr) - .mload() - .r#mod() - .push(challenge_ptr) - .mstore(); + let code = format!("mstore({challenge_ptr}, mod(mload({hash_ptr}), f_q))"); + loader.code_mut().runtime_append(code); 
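// Editorial sketch, not part of this diff: the Yul statement appended just above reduces the
// keccak digest of the accumulator points modulo the scalar-field modulus (the quantity the
// removed lines pushed via `loader.scalar_modulus()`, named `f_q` in the generated code) to
// obtain the batching challenge. An off-chain equivalent of that reduction, for illustration
// only; the exact point serialization feeding the hash is assumed, not taken from this diff.
use num_bigint::BigUint;
use sha3::{Digest, Keccak256};

/// Hash already-encoded accumulator points and reduce the digest into the scalar field.
/// EVM `mload` reads a 32-byte word big-endian, hence `from_bytes_be`.
fn challenge_from_transcript(encoded_points: &[u8], fr_modulus: &BigUint) -> BigUint {
    let digest = Keccak256::digest(encoded_points);
    BigUint::from_bytes_be(&digest) % fr_modulus
}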
let challenge = loader.scalar(Value::Memory(challenge_ptr)); let powers_of_challenge = LoadedScalar::::powers(&challenge, lhs.len()); let [lhs, rhs] = [lhs, rhs].map(|msms| { - msms.into_iter() + msms.iter() .zip(powers_of_challenge.iter()) .map(|(msm, power_of_challenge)| { Msm::>::base(msm) * power_of_challenge diff --git a/src/pcs/kzg/multiopen.rs b/snark-verifier/src/pcs/kzg/multiopen.rs similarity index 100% rename from src/pcs/kzg/multiopen.rs rename to snark-verifier/src/pcs/kzg/multiopen.rs diff --git a/src/pcs/kzg/multiopen/bdfg21.rs b/snark-verifier/src/pcs/kzg/multiopen/bdfg21.rs similarity index 89% rename from src/pcs/kzg/multiopen/bdfg21.rs rename to snark-verifier/src/pcs/kzg/multiopen/bdfg21.rs index fb012dea..f542f750 100644 --- a/src/pcs/kzg/multiopen/bdfg21.rs +++ b/snark-verifier/src/pcs/kzg/multiopen/bdfg21.rs @@ -6,7 +6,7 @@ use crate::{ MultiOpenScheme, Query, }, util::{ - arithmetic::{CurveAffine, FieldExt, Fraction, MultiMillerLoop}, + arithmetic::{ilog2, CurveAffine, FieldExt, Fraction, MultiMillerLoop}, msm::Msm, transcript::TranscriptRead, Itertools, @@ -47,8 +47,8 @@ where queries: &[Query], proof: &Bdfg21Proof, ) -> Result { + let sets = query_sets(queries); let f = { - let sets = query_sets(queries); let coeffs = query_set_coeffs(&sets, z, &proof.z_prime); let powers_of_mu = proof @@ -62,10 +62,10 @@ where msms.zip(proof.gamma.powers(sets.len()).into_iter()) .map(|(msm, power_of_gamma)| msm * &power_of_gamma) .sum::>() - - Msm::base(proof.w.clone()) * &coeffs[0].z_s + - Msm::base(&proof.w) * &coeffs[0].z_s }; - let rhs = Msm::base(proof.w_prime.clone()); + let rhs = Msm::base(&proof.w_prime); let lhs = f + rhs.clone() * &proof.z_prime; Ok(KzgAccumulator::new( @@ -143,7 +143,7 @@ fn query_sets(queries: &[Query]) -> Vec(queries: &[Query]) -> Vec(queries: &[Query]) -> Vec>( - sets: &[QuerySet], +fn query_set_coeffs<'a, F: FieldExt, T: LoadedScalar>( + sets: &[QuerySet<'a, F, T>], z: &T, z_prime: &T, ) -> Vec> { @@ -175,10 +175,7 @@ fn query_set_coeffs>( .dedup(); let size = 2.max( - (sets.iter().map(|set| set.shifts.len()).max().unwrap() - 1) - .next_power_of_two() - .ilog2() as usize - + 1, + ilog2((sets.iter().map(|set| set.shifts.len()).max().unwrap() - 1).next_power_of_two()) + 1, ); let powers_of_z = z.powers(size); let z_prime_minus_z_shift_i = BTreeMap::from_iter(superset.map(|shift| { @@ -206,25 +203,25 @@ fn query_set_coeffs>( }) .collect_vec(); - T::batch_invert(coeffs.iter_mut().flat_map(QuerySetCoeff::denoms)); - T::batch_invert(coeffs.iter_mut().flat_map(QuerySetCoeff::denoms)); + T::Loader::batch_invert(coeffs.iter_mut().flat_map(QuerySetCoeff::denoms)); + T::Loader::batch_invert(coeffs.iter_mut().flat_map(QuerySetCoeff::denoms)); coeffs.iter_mut().for_each(QuerySetCoeff::evaluate); coeffs } #[derive(Clone, Debug)] -struct QuerySet { +struct QuerySet<'a, F, T> { shifts: Vec, polys: Vec, - evals: Vec>, + evals: Vec>, } -impl> QuerySet { +impl<'a, F: FieldExt, T: LoadedScalar> QuerySet<'a, F, T> { fn msm>( &self, coeff: &QuerySetCoeff, - commitments: &[Msm], + commitments: &[Msm<'a, C, L>], powers_of_mu: &[T], ) -> Msm { self.polys @@ -244,7 +241,7 @@ impl> QuerySet { &coeff .eval_coeffs .iter() - .zip(evals.iter()) + .zip(evals.iter().cloned()) .map(|(coeff, eval)| (coeff.evaluated(), eval)) .collect_vec(), ) * coeff.r_eval_coeff.as_ref().unwrap().evaluated(); @@ -291,18 +288,15 @@ where }) .collect_vec(); - let z = &powers_of_z[1].clone(); + let z = &powers_of_z[1]; let z_pow_k_minus_one = { let k_minus_one = shifts.len() - 1; powers_of_z .iter() 
.enumerate() .skip(1) - .filter_map(|(i, power_of_z)| { - (k_minus_one & (1 << i) == 1).then(|| power_of_z.clone()) - }) - .reduce(|acc, value| acc * value) - .unwrap_or_else(|| loader.load_one()) + .filter_map(|(i, power_of_z)| (k_minus_one & (1 << i) == 1).then(|| power_of_z)) + .fold(loader.load_one(), |acc, value| acc * value) }; let barycentric_weights = shifts @@ -357,7 +351,7 @@ where .map(Fraction::evaluated) .collect_vec(), ); - self.r_eval_coeff = Some(match self.commitment_coeff.clone() { + self.r_eval_coeff = Some(match self.commitment_coeff.as_ref() { Some(coeff) => Fraction::new(coeff.evaluated().clone(), barycentric_weights_sum), None => Fraction::one_over(barycentric_weights_sum), }); diff --git a/src/pcs/kzg/multiopen/gwc19.rs b/snark-verifier/src/pcs/kzg/multiopen/gwc19.rs similarity index 90% rename from src/pcs/kzg/multiopen/gwc19.rs rename to snark-verifier/src/pcs/kzg/multiopen/gwc19.rs index 121fce8a..6e3f579f 100644 --- a/src/pcs/kzg/multiopen/gwc19.rs +++ b/snark-verifier/src/pcs/kzg/multiopen/gwc19.rs @@ -55,15 +55,13 @@ where .map(|(msm, power_of_u)| msm * power_of_u) .sum::>() }; - let z_omegas = sets - .iter() - .map(|set| z.clone() * &z.loader().load_const(&set.shift)); + let z_omegas = sets.iter().map(|set| z.loader().load_const(&set.shift) * z); let rhs = proof .ws .iter() .zip(powers_of_u.iter()) - .map(|(w, power_of_u)| Msm::base(w.clone()) * power_of_u) + .map(|(w, power_of_u)| Msm::base(w) * power_of_u) .collect_vec(); let lhs = f + rhs .iter() @@ -105,25 +103,25 @@ where } } -struct QuerySet { +struct QuerySet<'a, F, T> { shift: F, polys: Vec, - evals: Vec, + evals: Vec<&'a T>, } -impl QuerySet +impl<'a, F, T> QuerySet<'a, F, T> where F: PrimeField, T: Clone, { fn msm>( &self, - commitments: &[Msm], + commitments: &[Msm<'a, C, L>], powers_of_v: &[L::LoadedScalar], ) -> Msm { self.polys .iter() - .zip(self.evals.iter()) + .zip(self.evals.iter().cloned()) .map(|(poly, eval)| { let commitment = commitments[*poly].clone(); commitment - Msm::constant(eval.clone()) @@ -142,12 +140,12 @@ where queries.iter().fold(Vec::new(), |mut sets, query| { if let Some(pos) = sets.iter().position(|set| set.shift == query.shift) { sets[pos].polys.push(query.poly); - sets[pos].evals.push(query.eval.clone()); + sets[pos].evals.push(&query.eval); } else { sets.push(QuerySet { shift: query.shift, polys: vec![query.poly], - evals: vec![query.eval.clone()], + evals: vec![&query.eval], }); } sets diff --git a/snark-verifier/src/system.rs b/snark-verifier/src/system.rs new file mode 100644 index 00000000..edf79228 --- /dev/null +++ b/snark-verifier/src/system.rs @@ -0,0 +1 @@ +pub mod halo2; diff --git a/src/system/halo2.rs b/snark-verifier/src/system/halo2.rs similarity index 80% rename from src/system/halo2.rs rename to snark-verifier/src/system/halo2.rs index 743b90a6..1ba7c1cc 100644 --- a/src/system/halo2.rs +++ b/snark-verifier/src/system/halo2.rs @@ -1,3 +1,8 @@ +use crate::halo2_proofs::{ + plonk::{self, Any, ConstraintSystem, FirstPhase, SecondPhase, ThirdPhase, VerifyingKey}, + poly::{self, commitment::Params}, + transcript::{EncodedChallenge, Transcript}, +}; use crate::{ util::{ arithmetic::{root_of_unity, CurveAffine, Domain, FieldExt, Rotation}, @@ -8,34 +13,15 @@ use crate::{ }, Protocol, }; -use halo2_curves::bn256::{Fq, Fr}; -use halo2_proofs::{ - plonk::{ - self, Any, Column, ConstraintSystem, FirstPhase, Instance, SecondPhase, ThirdPhase, - VerifyingKey, - }, - poly::{ - self, - commitment::{Params, ParamsProver}, - }, - transcript::{EncodedChallenge, 
Transcript}, -}; -use serde::{Deserialize, Serialize}; -use std::{ - fs::{self, File}, - io::{self, BufReader, BufWriter}, - iter, - mem::size_of, -}; +use num_integer::Integer; +use std::{io, iter, mem::size_of}; -pub mod aggregation; +// pub mod strategy; pub mod transcript; -pub const LIMBS: usize = 3; -pub const BITS: usize = 88; - #[cfg(test)] -mod test; +#[cfg(feature = "loader_halo2")] +pub(crate) mod test; #[derive(Clone, Debug, Default)] pub struct Config { @@ -47,8 +33,8 @@ pub struct Config { } impl Config { - pub fn kzg(query_instance: bool) -> Self { - Self { zk: true, query_instance, num_proof: 1, ..Default::default() } + pub fn kzg() -> Self { + Self { zk: true, query_instance: false, num_proof: 1, ..Default::default() } } pub fn ipa() -> Self { @@ -60,6 +46,11 @@ impl Config { self } + pub fn set_query_instance(mut self, query_instance: bool) -> Self { + self.query_instance = query_instance; + self + } + pub fn with_num_proof(mut self, num_proof: usize) -> Self { assert!(num_proof > 0); self.num_proof = num_proof; @@ -71,8 +62,11 @@ impl Config { self } - pub fn with_accumulator_indices(mut self, accumulator_indices: Vec<(usize, usize)>) -> Self { - self.accumulator_indices = Some(accumulator_indices); + pub fn with_accumulator_indices( + mut self, + accumulator_indices: Option>, + ) -> Self { + self.accumulator_indices = accumulator_indices; self } } @@ -82,11 +76,13 @@ pub fn compile<'a, C: CurveAffine, P: Params<'a, C>>( vk: &VerifyingKey, config: Config, ) -> Protocol { + assert_eq!(vk.get_domain().k(), params.k()); + let cs = vk.cs(); let Config { zk, query_instance, num_proof, num_instance, accumulator_indices } = config; - let k = vk.get_domain().empty_lagrange().len().ilog2(); - let domain = Domain::new(k as usize, root_of_unity(k as usize)); + let k = params.k() as usize; + let domain = Domain::new(k, root_of_unity(k)); let preprocessed = vk .fixed_commitments() @@ -127,7 +123,7 @@ pub fn compile<'a, C: CurveAffine, P: Params<'a, C>>( let instance_committing_key = query_instance.then(|| { instance_committing_key( params, - polynomials.num_instance().into_iter().max().unwrap_or_default(), + Iterator::max(polynomials.num_instance().into_iter()).unwrap_or_default(), ) }); @@ -191,7 +187,7 @@ impl<'a, F: FieldExt> Polynomials<'a, F> { degree - 1 }; - let num_phase = *cs.advice_column_phase().iter().max().unwrap() as usize + 1; + let num_phase = *cs.advice_column_phase().iter().max().unwrap_or(&0) as usize + 1; let remapping = |phase: Vec| { let num = phase.iter().fold(vec![0; num_phase], |mut num, phase| { num[*phase as usize] += 1; @@ -227,11 +223,10 @@ impl<'a, F: FieldExt> Polynomials<'a, F> { challenge_index, num_lookup_permuted: 2 * cs.lookups().len(), permutation_chunk_size, - num_permutation_z: cs - .permutation() - .get_columns() - .len() - .div_ceil(permutation_chunk_size), + num_permutation_z: Integer::div_ceil( + &cs.permutation().get_columns().len(), + &permutation_chunk_size, + ), num_lookup_z: cs.lookups().len(), } } @@ -545,27 +540,30 @@ impl<'a, F: FieldExt> Polynomials<'a, F> { .zip(polys.chunks(self.permutation_chunk_size)) .zip(permutation_fixeds.chunks(self.permutation_chunk_size)) .enumerate() - .map(|(i, ((((z, z_w, _), (_, z_next_w, _)), polys), permutation_fixeds))| { - let left = if self.zk || zs.len() == 1 { - z_w.clone() - } else { - z_w + l_last * (z_next_w - z_w) - } * polys - .iter() - .zip(permutation_fixeds.iter()) - .map(|(poly, permutation_fixed)| { - poly + beta * permutation_fixed + gamma - }) - .reduce(|acc, expr| acc * expr) - 
.unwrap(); - let right = - z * polys + .map( + |( + i, + ((((z, z_omega, _), (_, z_next_omega, _)), polys), permutation_fixeds), + )| { + let left = if self.zk || zs.len() == 1 { + z_omega.clone() + } else { + z_omega + l_last * (z_next_omega - z_omega) + } * polys + .iter() + .zip(permutation_fixeds.iter()) + .map(|(poly, permutation_fixed)| { + poly + beta * permutation_fixed + gamma + }) + .reduce(|acc, expr| acc * expr) + .unwrap(); + let right = z * polys .iter() .zip( iter::successors( - Some(F::DELTA.pow_vartime(&[(i - * self.permutation_chunk_size) - as u64])), + Some(F::DELTA.pow_vartime([ + (i * self.permutation_chunk_size) as u64, + ])), |delta| Some(F::DELTA * delta), ) .map(Expression::Constant), @@ -573,12 +571,13 @@ impl<'a, F: FieldExt> Polynomials<'a, F> { .map(|(poly, delta)| poly + beta * delta * identity + gamma) .reduce(|acc, expr| acc * expr) .unwrap(); - if self.zk { - l_active * (left - right) - } else { - left - right - } - }), + if self.zk { + l_active * (left - right) + } else { + left - right + } + }, + ), ) .collect_vec() } @@ -615,29 +614,35 @@ impl<'a, F: FieldExt> Polynomials<'a, F> { .lookups() .iter() .zip(polys.iter()) - .flat_map(|(lookup, (z, z_w, permuted_input, permuted_input_w_inv, permuted_table))| { - let input = compress(lookup.input_expressions()); - let table = compress(lookup.table_expressions()); - iter::empty() - .chain(Some(l_0 * (one - z))) - .chain(self.zk.then(|| l_last * (z * z - z))) - .chain(Some(if self.zk { - l_active - * (z_w * (permuted_input + beta) * (permuted_table + gamma) - - z * (input + beta) * (table + gamma)) - } else { - z_w * (permuted_input + beta) * (permuted_table + gamma) - - z * (input + beta) * (table + gamma) - })) - .chain(self.zk.then(|| l_0 * (permuted_input - permuted_table))) - .chain(Some(if self.zk { - l_active - * (permuted_input - permuted_table) - * (permuted_input - permuted_input_w_inv) - } else { - (permuted_input - permuted_table) * (permuted_input - permuted_input_w_inv) - })) - }) + .flat_map( + |( + lookup, + (z, z_omega, permuted_input, permuted_input_omega_inv, permuted_table), + )| { + let input = compress(lookup.input_expressions()); + let table = compress(lookup.table_expressions()); + iter::empty() + .chain(Some(l_0 * (one - z))) + .chain(self.zk.then(|| l_last * (z * z - z))) + .chain(Some(if self.zk { + l_active + * (z_omega * (permuted_input + beta) * (permuted_table + gamma) + - z * (input + beta) * (table + gamma)) + } else { + z_omega * (permuted_input + beta) * (permuted_table + gamma) + - z * (input + beta) * (table + gamma) + })) + .chain(self.zk.then(|| l_0 * (permuted_input - permuted_table))) + .chain(Some(if self.zk { + l_active + * (permuted_input - permuted_table) + * (permuted_input - permuted_input_omega_inv) + } else { + (permuted_input - permuted_table) + * (permuted_input - permuted_input_omega_inv) + })) + }, + ) .collect_vec() } @@ -741,74 +746,3 @@ fn instance_committing_key<'a, C: CurveAffine, P: Params<'a, C>>( InstanceCommittingKey { bases, constant: Some(w) } } - -// for tuning the circuit -#[derive(Serialize, Deserialize)] -pub struct Halo2VerifierCircuitConfigParams { - pub strategy: halo2_ecc::fields::fp::FpStrategy, - pub degree: u32, - pub num_advice: usize, - pub num_lookup_advice: usize, - pub num_fixed: usize, - pub lookup_bits: usize, - pub limb_bits: usize, - pub num_limbs: usize, -} - -#[derive(Clone)] -pub struct Halo2VerifierCircuitConfig { - pub base_field_config: halo2_ecc::fields::fp::FpConfig, - pub instance: Column, -} - -impl 
Halo2VerifierCircuitConfig { - pub fn configure( - meta: &mut ConstraintSystem, - params: Halo2VerifierCircuitConfigParams, - ) -> Self { - assert!( - params.limb_bits == BITS && params.num_limbs == LIMBS, - "For now we fix limb_bits = {}, otherwise change code", - BITS - ); - let base_field_config = halo2_ecc::fields::fp::FpConfig::configure( - meta, - params.strategy, - &[params.num_advice], - &[params.num_lookup_advice], - params.num_fixed, - params.lookup_bits, - params.limb_bits, - params.num_limbs, - halo2_base::utils::modulus::(), - "verifier".to_string(), - ); - - let instance = meta.instance_column(); - meta.enable_equality(instance); - - Self { base_field_config, instance } - } -} - -pub fn read_or_create_srs<'a, C: CurveAffine, P: ParamsProver<'a, C>>( - k: u32, - setup: impl Fn(u32) -> P, -) -> P { - let dir = "./params"; - let path = format!("{}/kzg_bn254_{}.srs", dir, k); - match fs::File::open(path.as_str()) { - Ok(f) => { - println!("read params from {}", path); - let mut reader = BufReader::new(f); - P::read(&mut reader).unwrap() - } - Err(_) => { - println!("creating params for {}", k); - fs::create_dir_all(dir).unwrap(); - let params = setup(k); - params.write(&mut BufWriter::new(File::create(path).unwrap())).unwrap(); - params - } - } -} diff --git a/src/system/halo2/aggregation.rs b/snark-verifier/src/system/halo2/aggregation.rs similarity index 76% rename from src/system/halo2/aggregation.rs rename to snark-verifier/src/system/halo2/aggregation.rs index 67ffc809..a3b09c15 100644 --- a/src/system/halo2/aggregation.rs +++ b/snark-verifier/src/system/halo2/aggregation.rs @@ -1,6 +1,6 @@ use super::{BITS, LIMBS}; use crate::{ - loader::{self, native::NativeLoader}, + loader::{self, native::NativeLoader, Loader}, pcs::{ kzg::{ Bdfg21, Kzg, KzgAccumulator, KzgAs, KzgAsProvingKey, KzgAsVerifyingKey, @@ -47,16 +47,16 @@ use num_bigint::BigUint; use num_traits::Num; use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng}; use std::{ - fs::{self, File}, + fs::File, io::{BufReader, BufWriter, Cursor, Read, Write}, path::Path, rc::Rc, }; -pub const T: usize = 5; -pub const RATE: usize = 4; +pub const T: usize = 3; +pub const RATE: usize = 2; pub const R_F: usize = 8; -pub const R_P: usize = 60; +pub const R_P: usize = 57; pub type Halo2Loader<'a, 'b> = loader::halo2::Halo2Loader<'a, 'b, G1Affine>; pub type PoseidonTranscript = @@ -124,6 +124,14 @@ impl SnarkWitness { } } + pub fn protocol(&self) -> &Protocol { + &self.protocol + } + + pub fn instances(&self) -> &[Vec>] { + &self.instances + } + pub fn proof(&self) -> Value<&[u8]> { self.proof.as_ref().map(Vec::as_slice) } @@ -189,6 +197,98 @@ pub fn aggregate<'a, 'b>( .collect_vec() } +pub fn recursive_aggregate<'a, 'b>( + svk: &Svk, + loader: &Rc>, + snarks: &[SnarkWitness], + recursive_snark: &SnarkWitness, + as_vk: &AsVk, + as_proof: Value<&'_ [u8]>, + use_dummy: AssignedValue, +) -> (Vec>, Vec>>) { + let assign_instances = |instances: &[Vec>]| { + instances + .iter() + .map(|instances| { + instances.iter().map(|instance| loader.assign_scalar(*instance)).collect_vec() + }) + .collect_vec() + }; + + let mut assigned_instances = vec![]; + let mut accumulators = snarks + .iter() + .flat_map(|snark| { + let instances = assign_instances(&snark.instances); + assigned_instances.push( + instances + .iter() + .flat_map(|instance| instance.iter().map(|scalar| scalar.assigned())) + .collect_vec(), + ); + let mut transcript = + PoseidonTranscript::, _, _>::new(loader, snark.proof()); + let proof = + Plonk::read_proof(svk, &snark.protocol, 
&instances, &mut transcript).unwrap(); + Plonk::succinct_verify(svk, &snark.protocol, &instances, &proof).unwrap() + }) + .collect_vec(); + + let use_dummy = loader.scalar_from_assigned(use_dummy); + + let prev_instances = assign_instances(&recursive_snark.instances); + let mut accs = { + let mut transcript = + PoseidonTranscript::, _, _>::new(loader, recursive_snark.proof()); + let proof = + Plonk::read_proof(svk, &recursive_snark.protocol, &prev_instances, &mut transcript) + .unwrap(); + let mut accs = Plonk::succinct_verify_or_dummy( + svk, + &recursive_snark.protocol, + &prev_instances, + &proof, + &use_dummy, + ) + .unwrap(); + for acc in accs.iter_mut() { + (*acc).lhs = + loader.ec_point_select(&accumulators[0].lhs, &acc.lhs, &use_dummy).unwrap(); + (*acc).rhs = + loader.ec_point_select(&accumulators[0].rhs, &acc.rhs, &use_dummy).unwrap(); + } + accs + }; + accumulators.append(&mut accs); + + let KzgAccumulator { lhs, rhs } = { + let mut transcript = PoseidonTranscript::, _, _>::new(loader, as_proof); + let proof = As::read_proof(as_vk, &accumulators, &mut transcript).unwrap(); + As::verify(as_vk, &accumulators, &proof).unwrap() + }; + + let lhs = lhs.assigned(); + let rhs = rhs.assigned(); + + let mut new_instances = prev_instances + .iter() + .flat_map(|instance| instance.iter().map(|scalar| scalar.assigned())) + .collect_vec(); + for (i, acc_limb) in lhs + .x + .truncation + .limbs + .iter() + .chain(lhs.y.truncation.limbs.iter()) + .chain(rhs.x.truncation.limbs.iter()) + .chain(rhs.y.truncation.limbs.iter()) + .enumerate() + { + new_instances[i] = acc_limb.clone(); + } + (new_instances, assigned_instances) +} + #[derive(Clone)] pub struct AggregationCircuit { svk: Svk, @@ -288,15 +388,7 @@ impl AggregationCircuit { first_pass = false; return Ok(()); } - let ctx = Context::new( - region, - ContextParams { - num_advice: vec![( - config.base_field_config.range.context_id.clone(), - config.base_field_config.range.gate.num_advice, - )], - }, - ); + let ctx = config.base_field_config.new_context(region); let loader = Halo2Loader::new(&config.base_field_config, ctx); let instances = aggregate( @@ -340,10 +432,11 @@ impl Circuit for AggregationCircuit { } fn configure(meta: &mut plonk::ConstraintSystem) -> Self::Config { - let path = "./configs/verify_circuit.config"; - let params_str = fs::read_to_string(path).expect(format!("{} should exist", path).as_str()); - let params: Halo2VerifierCircuitConfigParams = - serde_json::from_str(params_str.as_str()).unwrap(); + let path = std::env::var("VERIFY_CONFIG").expect("export VERIFY_CONFIG with config path"); + let params: Halo2VerifierCircuitConfigParams = serde_json::from_reader( + File::open(path.as_str()).expect(format!("{} file should exist", path).as_str()), + ) + .unwrap(); Halo2VerifierCircuitConfig::configure(meta, params) } @@ -381,12 +474,14 @@ pub fn gen_vk>( name: &str, ) -> VerifyingKey { let path = format!("./data/{}_{}.vkey", name, params.k()); + #[cfg(feature = "serialize")] match File::open(path.as_str()) { Ok(f) => { - println!("Reading vkey from {}", path); + let read_time = start_timer!(|| format!("Reading vkey from {}", path)); let mut bufreader = BufReader::new(f); let vk = VerifyingKey::read::<_, ConcreteCircuit>(&mut bufreader, params) .expect("Reading vkey should not fail"); + end_timer!(read_time); vk } Err(_) => { @@ -399,6 +494,13 @@ pub fn gen_vk>( vk } } + #[cfg(not(feature = "serialize"))] + { + let vk_time = start_timer!(|| "vkey"); + let vk = keygen_vk(params, circuit).unwrap(); + end_timer!(vk_time); + vk + } 
} pub fn gen_pk>( @@ -407,12 +509,14 @@ pub fn gen_pk>( name: &str, ) -> ProvingKey { let path = format!("./data/{}_{}.pkey", name, params.k()); + #[cfg(feature = "serialize")] match File::open(path.as_str()) { Ok(f) => { - println!("Reading pkey from {}", path); + let read_time = start_timer!(|| format!("Reading pkey from {}", path)); let mut bufreader = BufReader::new(f); let pk = ProvingKey::read::<_, ConcreteCircuit>(&mut bufreader, params) .expect("Reading pkey should not fail"); + end_timer!(read_time); pk } Err(_) => { @@ -426,6 +530,14 @@ pub fn gen_pk>( pk } } + #[cfg(not(feature = "serialize"))] + { + let vk = gen_vk::(params, circuit, name); + let pk_time = start_timer!(|| "pkey"); + let pk = keygen_pk(params, vk, circuit).unwrap(); + end_timer!(pk_time); + pk + } } pub fn read_bytes(path: &str) -> Vec { @@ -486,15 +598,16 @@ pub fn write_instances(instances: &Vec>>, path: &str) { pub trait TargetCircuit { const N_PROOFS: usize; - const NAME: &'static str; type Circuit: Circuit; + + fn name() -> String; } // this is a toggle that should match the fork of halo2_proofs you are using -// the current default in PSE/main is `true`, while there is a PR to make it `false`: +// the current default in PSE/main is `false`, before 2022_10_22 it was `true`: // see https://github.com/privacy-scaling-explorations/halo2/pull/96/files -pub const KZG_QUERY_INSTANCE: bool = true; +pub const KZG_QUERY_INSTANCE: bool = false; pub fn create_snark_shplonk( params: &ParamsKZG, @@ -502,7 +615,7 @@ pub fn create_snark_shplonk( instances: Vec>>, // instances[i][j][..] is the i-th circuit's j-th instance column accumulator_indices: Option>, ) -> Snark { - println!("CREATING SNARK FOR: {}", T::NAME); + println!("CREATING SNARK FOR: {}", T::name()); let config = if let Some(accumulator_indices) = accumulator_indices { Config::kzg(KZG_QUERY_INSTANCE) .set_zk(true) @@ -512,7 +625,7 @@ pub fn create_snark_shplonk( Config::kzg(KZG_QUERY_INSTANCE).set_zk(true).with_num_proof(T::N_PROOFS) }; - let pk = gen_pk(params, &circuits[0], T::NAME); + let pk = gen_pk(params, &circuits[0], T::name().as_str()); // num_instance[i] is length of the i-th instance columns in circuit 0 (all circuits should have same shape of instances) let num_instance = instances[0].iter().map(|instance_column| instance_column.len()).collect(); let protocol = compile(params, pk.get_vk(), config.with_num_instance(num_instance)); @@ -526,16 +639,19 @@ pub fn create_snark_shplonk( // TODO: need to cache the instances as well! 
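// Editorial sketch, not part of this diff: the Config builder introduced earlier in this
// patch, written out for an aggregation-style protocol whose accumulator limbs occupy the
// first 4 * LIMBS cells of instance column 0 (the layout used by this patch's tests). The
// crate name `snark_verifier`, the public module path, and LIMBS = 3 are assumptions based
// on the new workspace layout and the constants used elsewhere in the patch.
use snark_verifier::system::halo2::Config;

const LIMBS: usize = 3; // limbs per base-field coordinate, matching the accumulator encoding

fn aggregation_config(num_proof: usize, num_instance: Vec<usize>) -> Config {
    Config::kzg()
        .set_zk(true)
        .with_num_proof(num_proof)
        .with_num_instance(num_instance)
        // accumulator limbs sit at (column 0, row idx) for idx in 0..4 * LIMBS
        .with_accumulator_indices(Some((0..4 * LIMBS).map(|idx| (0, idx)).collect()))
}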
let proof = { - let path = format!("./data/proof_{}_{}.dat", T::NAME, params.k()); - let instance_path = format!("./data/instances_{}_{}.dat", T::NAME, params.k()); + let path = format!("./data/proof_{}_{}.dat", T::name(), params.k()); + let instance_path = format!("./data/instances_{}_{}.dat", T::name(), params.k()); let cached_instances = read_instances::(instance_path.as_str()); + #[cfg(feature = "serialize")] if cached_instances.is_some() && Path::new(path.as_str()).exists() && cached_instances.unwrap() == instances { + let proof_time = start_timer!(|| "read proof"); let mut file = File::open(path.as_str()).unwrap(); let mut buf = vec![]; file.read_to_end(&mut buf).unwrap(); + end_timer!(proof_time); buf } else { let proof_time = start_timer!(|| "create proof"); @@ -556,6 +672,23 @@ pub fn create_snark_shplonk( end_timer!(proof_time); proof } + #[cfg(not(feature = "serialize"))] + { + let proof_time = start_timer!(|| "create proof"); + let mut transcript = PoseidonTranscript::, _>::init(Vec::new()); + create_proof::, ProverSHPLONK<_>, ChallengeScalar<_>, _, _, _>( + params, + &pk, + &circuits, + instances2.as_slice(), + &mut ChaCha20Rng::from_seed(Default::default()), + &mut transcript, + ) + .unwrap(); + let proof = transcript.finalize(); + end_timer!(proof_time); + proof + } }; let verify_time = start_timer!(|| "verify proof"); diff --git a/snark-verifier/src/system/halo2/strategy.rs b/snark-verifier/src/system/halo2/strategy.rs new file mode 100644 index 00000000..de66f8e3 --- /dev/null +++ b/snark-verifier/src/system/halo2/strategy.rs @@ -0,0 +1,53 @@ +pub mod ipa { + use crate::util::arithmetic::CurveAffine; + use halo2_proofs::{ + plonk::Error, + poly::{ + commitment::MSM, + ipa::{ + commitment::{IPACommitmentScheme, ParamsIPA}, + msm::MSMIPA, + multiopen::VerifierIPA, + strategy::GuardIPA, + }, + VerificationStrategy, + }, + }; + + #[derive(Clone, Debug)] + pub struct SingleStrategy<'a, C: CurveAffine> { + msm: MSMIPA<'a, C>, + } + + impl<'a, C: CurveAffine> VerificationStrategy<'a, IPACommitmentScheme, VerifierIPA<'a, C>> + for SingleStrategy<'a, C> + { + type Output = C; + + fn new(params: &'a ParamsIPA) -> Self { + SingleStrategy { + msm: MSMIPA::new(params), + } + } + + fn process( + self, + f: impl FnOnce(MSMIPA<'a, C>) -> Result, Error>, + ) -> Result { + let guard = f(self.msm)?; + + let g = guard.compute_g(); + let (msm, _) = guard.use_g(g); + + if msm.check() { + Ok(g) + } else { + Err(Error::ConstraintSystemFailure) + } + } + + fn finalize(self) -> bool { + unreachable!() + } + } +} diff --git a/src/system/halo2/test.rs b/snark-verifier/src/system/halo2/test.rs similarity index 72% rename from src/system/halo2/test.rs rename to snark-verifier/src/system/halo2/test.rs index de803d13..840a3fcb 100644 --- a/src/system/halo2/test.rs +++ b/snark-verifier/src/system/halo2/test.rs @@ -1,25 +1,37 @@ -use super::{read_or_create_srs, Halo2VerifierCircuitConfigParams}; -use ark_std::{end_timer, start_timer}; -use halo2_proofs::{ +use crate::halo2_proofs::{ + dev::MockProver, plonk::{create_proof, verify_proof, Circuit, ProvingKey}, poly::{ - commitment::{CommitmentScheme, ParamsProver, Prover, Verifier}, + commitment::{CommitmentScheme, Params, ParamsProver, Prover, Verifier}, VerificationStrategy, }, transcript::{EncodedChallenge, TranscriptReadBuffer, TranscriptWriterBuffer}, }; +use crate::util::arithmetic::CurveAffine; use rand_chacha::rand_core::RngCore; -use std::io::Cursor; +use std::{fs, io::Cursor}; +mod circuit; +// mod ipa; mod kzg; -pub fn load_verify_circuit_degree() 
-> u32 { - let path = "./configs/verify_circuit.config"; - let params_str = - std::fs::read_to_string(path).expect(format!("{} file should exist", path).as_str()); - let params: Halo2VerifierCircuitConfigParams = - serde_json::from_str(params_str.as_str()).unwrap(); - params.degree +pub use circuit::standard::StandardPlonk; + +pub fn read_or_create_srs<'a, C: CurveAffine, P: ParamsProver<'a, C>>( + dir: &str, + k: u32, + setup: impl Fn(u32) -> P, +) -> P { + let path = format!("{}/k-{}.srs", dir, k); + match fs::File::open(path.as_str()) { + Ok(mut file) => P::read(&mut file).unwrap(), + Err(_) => { + fs::create_dir_all(dir).unwrap(); + let params = setup(k); + params.write(&mut fs::File::create(path).unwrap()).unwrap(); + params + } + } } pub fn create_proof_checked<'a, S, C, P, V, VS, TW, TR, EC, R>( @@ -42,7 +54,16 @@ where EC: EncodedChallenge, R: RngCore, { - let proof_time = start_timer!(|| "create proof"); + for (circuit, instances) in circuits.iter().zip(instances.iter()) { + MockProver::run( + params.k(), + circuit, + instances.iter().map(|instance| instance.to_vec()).collect(), + ) + .unwrap() + .assert_satisfied(); + } + let proof = { let mut transcript = TW::init(Vec::new()); create_proof::( @@ -56,58 +77,34 @@ where .unwrap(); transcript.finalize() }; - end_timer!(proof_time); - - let verify_time = start_timer!(|| "verify proof"); let output = { let params = params.verifier_params(); let strategy = VS::new(params); let mut transcript = TR::init(Cursor::new(proof.clone())); verify_proof(params, pk.get_vk(), strategy, instances, &mut transcript).unwrap() }; - end_timer!(verify_time); finalize(proof, output) } macro_rules! halo2_prepare { ($dir:expr, $k:expr, $setup:expr, $config:expr, $create_circuit:expr) => {{ - use halo2_proofs::{plonk::{keygen_pk, keygen_vk}}; + use $crate::halo2_proofs::plonk::{keygen_pk, keygen_vk}; + use std::iter; use $crate::{ system::halo2::{compile, test::read_or_create_srs}, - util::{Itertools}, + util::{arithmetic::GroupEncoding, Itertools}, }; - use ark_std::{start_timer, end_timer}; - let circuits = (0..$config.num_proof).map(|_| $create_circuit).collect_vec(); + let params = read_or_create_srs($dir, $k, $setup); - /* - let mock_time = start_timer!(|| "mock prover"); - let instances = circuits.iter().map(|circuit| circuit.instances()).collect_vec(); - - for (circuit, instance) in circuits.iter().zip(instances.iter()) { - MockProver::run( - $k, - circuit, - instance.clone(), - ) - .unwrap() - .assert_satisfied(); - } - end_timer!(mock_time); - */ - - let params = read_or_create_srs($k, $setup); + let circuits = iter::repeat_with(|| $create_circuit) + .take($config.num_proof) + .collect_vec(); let pk = if $config.zk { - let vk_time = start_timer!(|| "vkey"); let vk = keygen_vk(¶ms, &circuits[0]).unwrap(); - end_timer!(vk_time); - - let pk_time = start_timer!(|| "pkey"); let pk = keygen_pk(¶ms, vk, &circuits[0]).unwrap(); - end_timer!(pk_time); - pk } else { // TODO: Re-enable optional-zk when it's merged in pse/halo2. @@ -158,9 +155,10 @@ macro_rules! halo2_create_snark { $protocol:expr, $circuits:expr ) => {{ - use itertools::Itertools; use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng}; - use $crate::{loader::halo2::test::Snark, system::halo2::test::create_proof_checked}; + use $crate::{ + loader::halo2::test::Snark, system::halo2::test::create_proof_checked, util::Itertools, + }; let instances = $circuits.iter().map(|circuit| circuit.instances()).collect_vec(); let proof = { @@ -204,7 +202,7 @@ macro_rules! 
halo2_native_verify { $svk:expr, $dk:expr ) => {{ - use halo2_proofs::poly::commitment::ParamsProver; + use $crate::halo2_proofs::poly::commitment::ParamsProver; use $crate::verifier::PlonkVerifier; let proof = diff --git a/snark-verifier/src/system/halo2/test/circuit.rs b/snark-verifier/src/system/halo2/test/circuit.rs new file mode 100644 index 00000000..ab713995 --- /dev/null +++ b/snark-verifier/src/system/halo2/test/circuit.rs @@ -0,0 +1,2 @@ +// pub mod maingate; +pub mod standard; diff --git a/snark-verifier/src/system/halo2/test/circuit/maingate.rs b/snark-verifier/src/system/halo2/test/circuit/maingate.rs new file mode 100644 index 00000000..82d63b5e --- /dev/null +++ b/snark-verifier/src/system/halo2/test/circuit/maingate.rs @@ -0,0 +1,111 @@ +use crate::util::arithmetic::{CurveAffine, FieldExt}; +use halo2_proofs::{ + circuit::{floor_planner::V1, Layouter, Value}, + plonk::{Circuit, ConstraintSystem, Error}, +}; +use halo2_wrong_ecc::{ + maingate::{ + MainGate, MainGateConfig, MainGateInstructions, RangeChip, RangeConfig, RangeInstructions, + RegionCtx, + }, + BaseFieldEccChip, EccConfig, +}; +use rand::RngCore; + +#[derive(Clone)] +pub struct MainGateWithRangeConfig { + main_gate_config: MainGateConfig, + range_config: RangeConfig, +} + +impl MainGateWithRangeConfig { + pub fn configure( + meta: &mut ConstraintSystem, + composition_bits: Vec, + overflow_bits: Vec, + ) -> Self { + let main_gate_config = MainGate::::configure(meta); + let range_config = + RangeChip::::configure(meta, &main_gate_config, composition_bits, overflow_bits); + MainGateWithRangeConfig { + main_gate_config, + range_config, + } + } + + pub fn main_gate(&self) -> MainGate { + MainGate::new(self.main_gate_config.clone()) + } + + pub fn range_chip(&self) -> RangeChip { + RangeChip::new(self.range_config.clone()) + } + + pub fn ecc_chip( + &self, + ) -> BaseFieldEccChip { + BaseFieldEccChip::new(EccConfig::new( + self.range_config.clone(), + self.main_gate_config.clone(), + )) + } +} + +#[derive(Clone, Default)] +pub struct MainGateWithRange(Vec); + +impl MainGateWithRange { + pub fn new(inner: Vec) -> Self { + Self(inner) + } + + pub fn rand(mut rng: R) -> Self { + Self::new(vec![F::from(rng.next_u32() as u64)]) + } + + pub fn instances(&self) -> Vec> { + vec![self.0.clone()] + } +} + +impl Circuit for MainGateWithRange { + type Config = MainGateWithRangeConfig; + type FloorPlanner = V1; + + fn without_witnesses(&self) -> Self { + Self(vec![F::zero()]) + } + + fn configure(meta: &mut ConstraintSystem) -> Self::Config { + MainGateWithRangeConfig::configure(meta, vec![8], vec![4, 7]) + } + + fn synthesize( + &self, + config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + let main_gate = config.main_gate(); + let range_chip = config.range_chip(); + range_chip.load_table(&mut layouter)?; + + let a = layouter.assign_region( + || "", + |region| { + let mut ctx = RegionCtx::new(region, 0); + range_chip.decompose(&mut ctx, Value::known(F::from(u64::MAX)), 8, 64)?; + range_chip.decompose(&mut ctx, Value::known(F::from(u32::MAX as u64)), 8, 39)?; + let a = range_chip.assign(&mut ctx, Value::known(self.0[0]), 8, 68)?; + let b = main_gate.sub_sub_with_constant(&mut ctx, &a, &a, &a, F::from(2))?; + let cond = main_gate.assign_bit(&mut ctx, Value::known(F::one()))?; + main_gate.select(&mut ctx, &a, &b, &cond)?; + + Ok(a) + }, + )?; + + main_gate.expose_public(layouter, a, 0)?; + + Ok(()) + } +} diff --git a/snark-verifier/src/system/halo2/test/circuit/standard.rs 
b/snark-verifier/src/system/halo2/test/circuit/standard.rs new file mode 100644 index 00000000..bfa94df4 --- /dev/null +++ b/snark-verifier/src/system/halo2/test/circuit/standard.rs @@ -0,0 +1,146 @@ +use crate::halo2_proofs::{ + circuit::{Layouter, SimpleFloorPlanner, Value}, + plonk::{Advice, Circuit, Column, ConstraintSystem, Error, Fixed, Instance}, + poly::Rotation, +}; +use crate::util::arithmetic::FieldExt; +use halo2_base::halo2_proofs::plonk::Assigned; +use rand::RngCore; + +#[allow(dead_code)] +#[derive(Clone)] +pub struct StandardPlonkConfig { + a: Column, + b: Column, + c: Column, + q_a: Column, + q_b: Column, + q_c: Column, + q_ab: Column, + constant: Column, + instance: Column, +} + +impl StandardPlonkConfig { + pub fn configure(meta: &mut ConstraintSystem) -> Self { + let [a, b, c] = [(); 3].map(|_| meta.advice_column()); + let [q_a, q_b, q_c, q_ab, constant] = [(); 5].map(|_| meta.fixed_column()); + let instance = meta.instance_column(); + + [a, b, c].map(|column| meta.enable_equality(column)); + + meta.create_gate( + "q_a·a + q_b·b + q_c·c + q_ab·a·b + constant + instance = 0", + |meta| { + let [a, b, c] = [a, b, c].map(|column| meta.query_advice(column, Rotation::cur())); + let [q_a, q_b, q_c, q_ab, constant] = [q_a, q_b, q_c, q_ab, constant] + .map(|column| meta.query_fixed(column, Rotation::cur())); + let instance = meta.query_instance(instance, Rotation::cur()); + Some( + q_a * a.clone() + + q_b * b.clone() + + q_c * c + + q_ab * a * b + + constant + + instance, + ) + }, + ); + + StandardPlonkConfig { a, b, c, q_a, q_b, q_c, q_ab, constant, instance } + } +} + +#[derive(Clone, Default)] +pub struct StandardPlonk(F); + +impl StandardPlonk { + pub fn rand(mut rng: R) -> Self { + Self(F::from(rng.next_u32() as u64)) + } + + pub fn instances(&self) -> Vec> { + vec![vec![self.0]] + } +} + +impl Circuit for StandardPlonk { + type Config = StandardPlonkConfig; + type FloorPlanner = SimpleFloorPlanner; + + fn without_witnesses(&self) -> Self { + Self::default() + } + + fn configure(meta: &mut ConstraintSystem) -> Self::Config { + meta.set_minimum_degree(4); + StandardPlonkConfig::configure(meta) + } + + fn synthesize( + &self, + config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + layouter.assign_region( + || "", + |mut region| { + #[cfg(feature = "halo2-pse")] + { + region.assign_advice(|| "", config.a, 0, || Value::known(self.0))?; + region.assign_fixed(|| "", config.q_a, 0, || Value::known(-F::one()))?; + + region.assign_advice(|| "", config.a, 1, || Value::known(-F::from(5u64)))?; + for (idx, column) in (1..).zip([ + config.q_a, + config.q_b, + config.q_c, + config.q_ab, + config.constant, + ]) { + region.assign_fixed( + || "", + column, + 1, + || Value::known(F::from(idx as u64)), + )?; + } + + let a = region.assign_advice(|| "", config.a, 2, || Value::known(F::one()))?; + a.copy_advice(|| "", &mut region, config.b, 3)?; + a.copy_advice(|| "", &mut region, config.c, 4)?; + } + #[cfg(feature = "halo2-axiom")] + { + region.assign_advice(config.a, 0, Value::known(Assigned::Trivial(self.0)))?; + region.assign_fixed(config.q_a, 0, Assigned::Trivial(-F::one())); + + region.assign_advice( + config.a, + 1, + Value::known(Assigned::Trivial(-F::from(5u64))), + )?; + for (idx, column) in (1..).zip([ + config.q_a, + config.q_b, + config.q_c, + config.q_ab, + config.constant, + ]) { + region.assign_fixed(column, 1, Assigned::Trivial(F::from(idx as u64))); + } + + let a = region.assign_advice( + config.a, + 2, + Value::known(Assigned::Trivial(F::one())), + 
)?; + a.copy_advice(&mut region, config.b, 3); + a.copy_advice(&mut region, config.c, 4); + } + + Ok(()) + }, + ) + } +} diff --git a/snark-verifier/src/system/halo2/test/ipa.rs b/snark-verifier/src/system/halo2/test/ipa.rs new file mode 100644 index 00000000..07fd6efd --- /dev/null +++ b/snark-verifier/src/system/halo2/test/ipa.rs @@ -0,0 +1,143 @@ +use crate::util::arithmetic::CurveAffine; +use halo2_proofs::poly::{ + commitment::{Params, ParamsProver}, + ipa::commitment::ParamsIPA, +}; +use std::mem::size_of; + +mod native; + +pub const TESTDATA_DIR: &str = "./src/system/halo2/test/ipa/testdata"; + +pub fn setup(k: u32) -> ParamsIPA { + ParamsIPA::new(k) +} + +pub fn w_u() -> (C, C) { + let mut buf = Vec::new(); + setup::(1).write(&mut buf).unwrap(); + + let repr = C::Repr::default(); + let repr_len = repr.as_ref().len(); + let offset = size_of::() + 4 * repr_len; + + let [w, u] = [offset, offset + repr_len].map(|offset| { + let mut repr = C::Repr::default(); + repr.as_mut() + .copy_from_slice(&buf[offset..offset + repr_len]); + C::from_bytes(&repr).unwrap() + }); + + (w, u) +} + +macro_rules! halo2_ipa_config { + ($zk:expr, $num_proof:expr) => { + $crate::system::halo2::Config::ipa() + .set_zk($zk) + .with_num_proof($num_proof) + }; + ($zk:expr, $num_proof:expr, $accumulator_indices:expr) => { + $crate::system::halo2::Config::ipa() + .set_zk($zk) + .with_num_proof($num_proof) + .with_accumulator_indices($accumulator_indices) + }; +} + +macro_rules! halo2_ipa_prepare { + ($dir:expr, $curve:path, $k:expr, $config:expr, $create_circuit:expr) => {{ + use $crate::system::halo2::test::{halo2_prepare, ipa::setup}; + + halo2_prepare!($dir, $k, setup::<$curve>, $config, $create_circuit) + }}; + (pallas::Affine, $k:expr, $config:expr, $create_circuit:expr) => {{ + use halo2_curves::pasta::pallas; + use $crate::system::halo2::test::ipa::TESTDATA_DIR; + + halo2_ipa_prepare!( + &format!("{TESTDATA_DIR}/pallas"), + pallas::Affine, + $k, + $config, + $create_circuit + ) + }}; + (vesta::Affine, $k:expr, $config:expr, $create_circuit:expr) => {{ + use halo2_curves::pasta::vesta; + use $crate::system::halo2::test::ipa::TESTDATA_DIR; + + halo2_ipa_prepare!( + &format!("{TESTDATA_DIR}/vesta"), + vesta::Affine, + $k, + $config, + $create_circuit + ) + }}; +} + +macro_rules! halo2_ipa_create_snark { + ( + $prover:ty, + $verifier:ty, + $transcript_read:ty, + $transcript_write:ty, + $encoded_challenge:ty, + $params:expr, + $pk:expr, + $protocol:expr, + $circuits:expr + ) => {{ + use halo2_proofs::poly::ipa::commitment::IPACommitmentScheme; + use $crate::{ + system::halo2::{strategy::ipa::SingleStrategy, test::halo2_create_snark}, + util::arithmetic::GroupEncoding, + }; + + halo2_create_snark!( + IPACommitmentScheme<_>, + $prover, + $verifier, + SingleStrategy<_>, + $transcript_read, + $transcript_write, + $encoded_challenge, + |proof, g| { [proof, g.to_bytes().as_ref().to_vec()].concat() }, + $params, + $pk, + $protocol, + $circuits + ) + }}; +} + +macro_rules! 
halo2_ipa_native_verify { + ( + $plonk_verifier:ty, + $params:expr, + $protocol:expr, + $instances:expr, + $transcript:expr + ) => {{ + use $crate::{ + pcs::ipa::{Bgh19SuccinctVerifyingKey, IpaDecidingKey}, + system::halo2::test::{halo2_native_verify, ipa::w_u}, + }; + + let (w, u) = w_u(); + halo2_native_verify!( + $plonk_verifier, + $params, + $protocol, + $instances, + $transcript, + &Bgh19SuccinctVerifyingKey::new($protocol.domain.clone(), $params.get_g()[0], w, u), + &IpaDecidingKey::new($params.get_g().to_vec()) + ) + }}; +} + +pub(crate) use { + halo2_ipa_config, halo2_ipa_create_snark, halo2_ipa_native_verify, halo2_ipa_prepare, +}; diff --git a/snark-verifier/src/system/halo2/test/ipa/native.rs b/snark-verifier/src/system/halo2/test/ipa/native.rs new file mode 100644 index 00000000..7d9e09bb --- /dev/null +++ b/snark-verifier/src/system/halo2/test/ipa/native.rs @@ -0,0 +1,59 @@ +use crate::{ + pcs::ipa::{Bgh19, Ipa}, + system::halo2::test::ipa::{ + halo2_ipa_config, halo2_ipa_create_snark, halo2_ipa_native_verify, halo2_ipa_prepare, + }, + system::halo2::test::StandardPlonk, + verifier::Plonk, +}; +use halo2_curves::pasta::pallas; +use halo2_proofs::{ + poly::ipa::multiopen::{ProverIPA, VerifierIPA}, + transcript::{Blake2bRead, Blake2bWrite, Challenge255, TranscriptReadBuffer}, +}; +use paste::paste; +use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng}; + +macro_rules! test { + (@ $name:ident, $k:expr, $config:expr, $create_cirucit:expr, $prover:ty, $verifier:ty, $plonk_verifier:ty) => { + paste! { + #[test] + fn []() { + let (params, pk, protocol, circuits) = halo2_ipa_prepare!( + pallas::Affine, + $k, + $config, + $create_cirucit + ); + let snark = halo2_ipa_create_snark!( + $prover, + $verifier, + Blake2bWrite<_, _, _>, + Blake2bRead<_, _, _>, + Challenge255<_>, + ¶ms, + &pk, + &protocol, + &circuits + ); + halo2_ipa_native_verify!( + $plonk_verifier, + params, + &snark.protocol, + &snark.instances, + &mut Blake2bRead::<_, pallas::Affine, _>::init(snark.proof.as_slice()) + ); + } + } + }; + ($name:ident, $k:expr, $config:expr, $create_cirucit:expr) => { + test!(@ $name, $k, $config, $create_cirucit, ProverIPA, VerifierIPA, Plonk::>); + } +} + +test!( + zk_standard_plonk_rand, + 9, + halo2_ipa_config!(true, 1), + StandardPlonk::rand(ChaCha20Rng::from_seed(Default::default())) +); diff --git a/src/system/halo2/test/kzg.rs b/snark-verifier/src/system/halo2/test/kzg.rs similarity index 84% rename from src/system/halo2/test/kzg.rs rename to snark-verifier/src/system/halo2/test/kzg.rs index 5713af10..6cf145db 100644 --- a/src/system/halo2/test/kzg.rs +++ b/snark-verifier/src/system/halo2/test/kzg.rs @@ -1,5 +1,5 @@ +use crate::halo2_proofs::poly::kzg::commitment::ParamsKZG; use crate::util::arithmetic::MultiMillerLoop; -use halo2_proofs::poly::kzg::commitment::ParamsKZG; use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng}; mod native; @@ -22,12 +22,10 @@ pub fn setup(k: u32) -> ParamsKZG { macro_rules! halo2_kzg_config { ($zk:expr, $num_proof:expr) => { - $crate::system::halo2::Config::kzg(crate::system::halo2::aggregation::KZG_QUERY_INSTANCE) - .set_zk($zk) - .with_num_proof($num_proof) + $crate::system::halo2::Config::kzg().set_zk($zk).with_num_proof($num_proof) }; ($zk:expr, $num_proof:expr, $accumulator_indices:expr) => { - $crate::system::halo2::Config::kzg(crate::system::halo2::aggregation::KZG_QUERY_INSTANCE) + $crate::system::halo2::Config::kzg() .set_zk($zk) .with_num_proof($num_proof) .with_accumulator_indices($accumulator_indices) @@ -36,7 +34,7 @@ macro_rules! 
halo2_kzg_config { macro_rules! halo2_kzg_prepare { ($k:expr, $config:expr, $create_circuit:expr) => {{ - use halo2_curves::bn256::Bn256; + use $crate::halo2_curves::bn256::Bn256; #[allow(unused_imports)] use $crate::system::halo2::test::{ halo2_prepare, @@ -59,7 +57,9 @@ macro_rules! halo2_kzg_create_snark { $protocol:expr, $circuits:expr ) => {{ - use halo2_proofs::poly::kzg::{commitment::KZGCommitmentScheme, strategy::SingleStrategy}; + use $crate::halo2_proofs::poly::kzg::{ + commitment::KZGCommitmentScheme, strategy::SingleStrategy, + }; use $crate::system::halo2::test::halo2_create_snark; halo2_create_snark!( diff --git a/src/system/halo2/test/kzg/evm.rs b/snark-verifier/src/system/halo2/test/kzg/evm.rs similarity index 80% rename from src/system/halo2/test/kzg/evm.rs rename to snark-verifier/src/system/halo2/test/kzg/evm.rs index febaf167..28357415 100644 --- a/src/system/halo2/test/kzg/evm.rs +++ b/snark-verifier/src/system/halo2/test/kzg/evm.rs @@ -1,10 +1,14 @@ +use crate::{halo2_curves, halo2_proofs}; use crate::{ - loader::{halo2::test::StandardPlonk, native::NativeLoader}, + loader::native::NativeLoader, pcs::kzg::{Bdfg21, Gwc19, Kzg, LimbsEncoding}, system::halo2::{ - test::kzg::{ - self, halo2_kzg_config, halo2_kzg_create_snark, halo2_kzg_native_verify, - halo2_kzg_prepare, BITS, LIMBS, + test::{ + kzg::{ + self, halo2_kzg_config, halo2_kzg_create_snark, halo2_kzg_native_verify, + halo2_kzg_prepare, BITS, LIMBS, + }, + StandardPlonk, }, transcript::evm::{ChallengeEvm, EvmTranscript}, }, @@ -21,7 +25,7 @@ macro_rules! halo2_kzg_evm_verify { use halo2_proofs::poly::commitment::ParamsProver; use std::rc::Rc; use $crate::{ - loader::evm::{encode_calldata, execute, EvmLoader}, + loader::evm::{compile_yul, encode_calldata, execute, EvmLoader}, system::halo2::{ test::kzg::{BITS, LIMBS}, transcript::evm::EvmTranscript, @@ -31,21 +35,22 @@ macro_rules! 
halo2_kzg_evm_verify { }; let loader = EvmLoader::new::(); - let runtime_code = { + let deployment_code = { let svk = $params.get_g()[0].into(); let dk = ($params.g2(), $params.s_g2()).into(); - let mut transcript = EvmTranscript::<_, Rc, _, _>::new(loader.clone()); + let protocol = $protocol.loaded(&loader); + let mut transcript = EvmTranscript::<_, Rc, _, _>::new(&loader); let instances = transcript .load_instances($instances.iter().map(|instances| instances.len()).collect_vec()); - let proof = <$plonk_verifier>::read_proof(&svk, $protocol, &instances, &mut transcript) + let proof = <$plonk_verifier>::read_proof(&svk, &protocol, &instances, &mut transcript) .unwrap(); - <$plonk_verifier>::verify(&svk, &dk, $protocol, &instances, &proof).unwrap(); + <$plonk_verifier>::verify(&svk, &dk, &protocol, &instances, &proof).unwrap(); - loader.runtime_code() + compile_yul(&loader.yul_code()) }; let (accept, total_cost, costs) = - execute(runtime_code, encode_calldata($instances, &$proof)); + execute(deployment_code, encode_calldata($instances, &$proof)); loader.print_gas_metering(costs); println!("Total gas cost: {}", total_cost); @@ -107,12 +112,20 @@ test!( halo2_kzg_config!(true, 1), StandardPlonk::rand(ChaCha20Rng::from_seed(Default::default())) ); +/* +test!( + zk_main_gate_with_range_with_mock_kzg_accumulator, + 9, + halo2_kzg_config!(true, 1, (0..4 * LIMBS).map(|idx| (0, idx)).collect()), + main_gate_with_range_with_mock_kzg_accumulator::() +); +*/ test!( #[cfg(feature = "loader_halo2")], #[ignore = "cause it requires 32GB memory to run"], zk_accumulation_two_snark, 22, - halo2_kzg_config!(true, 1, (0..4 * LIMBS).map(|idx| (0, idx)).collect()), + halo2_kzg_config!(true, 1, Some((0..4 * LIMBS).map(|idx| (0, idx)).collect())), kzg::halo2::Accumulation::two_snark() ); test!( @@ -120,6 +133,6 @@ test!( #[ignore = "cause it requires 32GB memory to run"], zk_accumulation_two_snark_with_accumulator, 22, - halo2_kzg_config!(true, 1, (0..4 * LIMBS).map(|idx| (0, idx)).collect()), + halo2_kzg_config!(true, 1, Some((0..4 * LIMBS).map(|idx| (0, idx)).collect())), kzg::halo2::Accumulation::two_snark_with_accumulator() ); diff --git a/src/system/halo2/test/kzg/halo2.rs b/snark-verifier/src/system/halo2/test/kzg/halo2.rs similarity index 74% rename from src/system/halo2/test/kzg/halo2.rs rename to snark-verifier/src/system/halo2/test/kzg/halo2.rs index e332b036..bd52426d 100644 --- a/src/system/halo2/test/kzg/halo2.rs +++ b/snark-verifier/src/system/halo2/test/kzg/halo2.rs @@ -1,7 +1,24 @@ +use crate::halo2_curves::bn256::{Bn256, Fq, Fr, G1Affine}; +use crate::halo2_proofs::{ + circuit::{Layouter, SimpleFloorPlanner, Value}, + plonk::{self, create_proof, verify_proof, Circuit, Column, ConstraintSystem, Instance}, + poly::{ + commitment::ParamsProver, + kzg::{ + commitment::{KZGCommitmentScheme, ParamsKZG}, + multiopen::{ProverSHPLONK, VerifierSHPLONK}, + strategy::SingleStrategy, + }, + }, + transcript::{ + Blake2bRead, Blake2bWrite, Challenge255, EncodedChallenge, TranscriptReadBuffer, + TranscriptWriterBuffer, + }, +}; use crate::{ loader::{ self, - halo2::test::{Snark, SnarkWitness, StandardPlonk}, + halo2::test::{Snark, SnarkWitness}, native::NativeLoader, }, pcs::{ @@ -17,35 +34,21 @@ use crate::{ halo2_kzg_config, halo2_kzg_create_snark, halo2_kzg_native_verify, halo2_kzg_prepare, BITS, LIMBS, }, - load_verify_circuit_degree, + StandardPlonk, }, transcript::halo2::{ChallengeScalar, PoseidonTranscript as GenericPoseidonTranscript}, - Halo2VerifierCircuitConfig, Halo2VerifierCircuitConfigParams, }, 
util::{arithmetic::fe_to_limbs, Itertools}, verifier::{self, PlonkVerifier}, }; use ark_std::{end_timer, start_timer}; use halo2_base::{Context, ContextParams}; -use halo2_curves::bn256::{Bn256, Fq, Fr, G1Affine}; +use halo2_ecc::ecc::EccChip; use halo2_ecc::fields::fp::FpConfig; -use halo2_proofs::{ - circuit::{Layouter, SimpleFloorPlanner, Value}, - plonk::{self, create_proof, verify_proof, Circuit}, - poly::{ - commitment::ParamsProver, - kzg::{ - commitment::{KZGCommitmentScheme, ParamsKZG}, - multiopen::{ProverSHPLONK, VerifierSHPLONK}, - strategy::SingleStrategy, - }, - }, - transcript::{ - Blake2bRead, Blake2bWrite, Challenge255, TranscriptReadBuffer, TranscriptWriterBuffer, - }, -}; use paste::paste; use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng}; +use serde::{Deserialize, Serialize}; +use std::fs::File; use std::{ io::{Cursor, Read, Write}, rc::Rc, @@ -56,8 +59,9 @@ const RATE: usize = 4; const R_F: usize = 8; const R_P: usize = 60; -type Halo2Loader<'a, 'b> = loader::halo2::Halo2Loader<'a, 'b, G1Affine>; -type PoseidonTranscript = GenericPoseidonTranscript; +type BaseFieldEccChip = halo2_ecc::ecc::BaseFieldEccChip; +type Halo2Loader<'a> = loader::halo2::Halo2Loader<'a, G1Affine, BaseFieldEccChip>; +type PoseidonTranscript = GenericPoseidonTranscript; type Pcs = Kzg; type Svk = KzgSuccinctVerifyingKey; @@ -66,13 +70,70 @@ type AsPk = KzgAsProvingKey; type AsVk = KzgAsVerifyingKey; type Plonk = verifier::Plonk>; -pub fn accumulate<'a, 'b>( +// for tuning the circuit +#[derive(Serialize, Deserialize)] +pub struct Halo2VerifierCircuitConfigParams { + pub strategy: halo2_ecc::fields::fp::FpStrategy, + pub degree: u32, + pub num_advice: usize, + pub num_lookup_advice: usize, + pub num_fixed: usize, + pub lookup_bits: usize, + pub limb_bits: usize, + pub num_limbs: usize, +} + +pub fn load_verify_circuit_degree() -> u32 { + let path = "./configs/verify_circuit.config"; + let params: Halo2VerifierCircuitConfigParams = + serde_json::from_reader(File::open(path).unwrap_or_else(|err| panic!("{err:?}"))).unwrap(); + params.degree +} + +#[derive(Clone)] +pub struct Halo2VerifierCircuitConfig { + pub base_field_config: halo2_ecc::fields::fp::FpConfig, + pub instance: Column, +} + +impl Halo2VerifierCircuitConfig { + pub fn configure( + meta: &mut ConstraintSystem, + params: Halo2VerifierCircuitConfigParams, + ) -> Self { + assert!( + params.limb_bits == BITS && params.num_limbs == LIMBS, + "For now we fix limb_bits = {}, otherwise change code", + BITS + ); + let base_field_config = halo2_ecc::fields::fp::FpConfig::configure( + meta, + params.strategy, + &[params.num_advice], + &[params.num_lookup_advice], + params.num_fixed, + params.lookup_bits, + params.limb_bits, + params.num_limbs, + halo2_base::utils::modulus::(), + 0, + params.degree as usize, + ); + + let instance = meta.instance_column(); + meta.enable_equality(instance); + + Self { base_field_config, instance } + } +} + +pub fn accumulate<'a>( svk: &Svk, - loader: &Rc>, + loader: &Rc>, snarks: &[SnarkWitness], as_vk: &AsVk, as_proof: Value<&'_ [u8]>, -) -> KzgAccumulator>> { +) -> KzgAccumulator>> { let assign_instances = |instances: &[Vec>]| { instances .iter() @@ -85,17 +146,17 @@ pub fn accumulate<'a, 'b>( let mut accumulators = snarks .iter() .flat_map(|snark| { + let protocol = snark.protocol.loaded(loader); let instances = assign_instances(&snark.instances); let mut transcript = - PoseidonTranscript::, _, _>::new(loader, snark.proof()); - let proof = - Plonk::read_proof(svk, &snark.protocol, &instances, &mut 
transcript).unwrap(); - Plonk::succinct_verify(svk, &snark.protocol, &instances, &proof).unwrap() + PoseidonTranscript::, _>::new(loader, snark.proof()); + let proof = Plonk::read_proof(svk, &protocol, &instances, &mut transcript).unwrap(); + Plonk::succinct_verify(svk, &protocol, &instances, &proof).unwrap() }) .collect_vec(); let acccumulator = if accumulators.len() > 1 { - let mut transcript = PoseidonTranscript::, _, _>::new(loader, as_proof); + let mut transcript = PoseidonTranscript::, _>::new(loader, as_proof); let proof = As::read_proof(as_vk, &accumulators, &mut transcript).unwrap(); As::verify(as_vk, &accumulators, &proof).unwrap() } else { @@ -129,7 +190,7 @@ impl Accumulation { .iter() .flat_map(|snark| { let mut transcript = - PoseidonTranscript::::new(snark.proof.as_slice()); + PoseidonTranscript::::new(snark.proof.as_slice()); let proof = Plonk::read_proof(&svk, &snark.protocol, &snark.instances, &mut transcript) .unwrap(); @@ -139,7 +200,7 @@ impl Accumulation { let as_pk = AsPk::new(Some((params.get_g()[0], params.get_g()[1]))); let (accumulator, as_proof) = if accumulators.len() > 1 { - let mut transcript = PoseidonTranscript::::new(Vec::new()); + let mut transcript = PoseidonTranscript::::new(Vec::new()); let accumulator = As::create_proof( &as_pk, &accumulators, @@ -175,8 +236,8 @@ impl Accumulation { let snark = halo2_kzg_create_snark!( ProverSHPLONK<_>, VerifierSHPLONK<_>, - PoseidonTranscript<_, _, _>, - PoseidonTranscript<_, _, _>, + PoseidonTranscript<_, _>, + PoseidonTranscript<_, _>, ChallengeScalar<_>, ¶ms, &pk, @@ -195,8 +256,8 @@ impl Accumulation { halo2_kzg_create_snark!( ProverSHPLONK<_>, VerifierSHPLONK<_>, - PoseidonTranscript<_, _, _>, - PoseidonTranscript<_, _, _>, + PoseidonTranscript<_, _>, + PoseidonTranscript<_, _>, ChallengeScalar<_>, ¶ms, &pk, @@ -212,15 +273,15 @@ impl Accumulation { const K: u32 = 22; halo2_kzg_prepare!( K, - halo2_kzg_config!(true, 2, Self::accumulator_indices()), + halo2_kzg_config!(true, 2, Some(Self::accumulator_indices())), Self::two_snark() ) }; let snark = halo2_kzg_create_snark!( ProverSHPLONK<_>, VerifierSHPLONK<_>, - PoseidonTranscript<_, _, _>, - PoseidonTranscript<_, _, _>, + PoseidonTranscript<_, _>, + PoseidonTranscript<_, _>, ChallengeScalar<_>, ¶ms, &pk, @@ -275,7 +336,8 @@ impl Circuit for Accumulation { params.limb_bits, params.num_limbs, halo2_base::utils::modulus::(), - "verify".to_string(), + 0, + params.degree as usize, ); let instance = meta.instance_column(); @@ -293,58 +355,55 @@ impl Circuit for Accumulation { config.base_field_config.load_lookup_table(&mut layouter)?; // Need to trick layouter to skip first pass in get shape mode - let using_simple_floor_planner = true; - let mut first_pass = true; - let mut final_pair = None; + let mut first_pass = halo2_base::SKIP_FIRST_PASS; + let mut assigned_instances = None; layouter.assign_region( || "", |region| { - if using_simple_floor_planner && first_pass { + if first_pass { first_pass = false; return Ok(()); } let ctx = Context::new( region, ContextParams { - num_advice: vec![( - config.base_field_config.range.context_id.clone(), - config.base_field_config.range.gate.num_advice, - )], + max_rows: config.base_field_config.range.gate.max_rows, + num_context_ids: 1, + fixed_columns: config.base_field_config.range.gate.constants.clone(), }, ); - let loader = Halo2Loader::new(&config.base_field_config, ctx); + let loader = + Halo2Loader::new(EccChip::construct(config.base_field_config.clone()), ctx); let KzgAccumulator { lhs, rhs } = accumulate(&self.svk, 
&loader, &self.snarks, &self.as_vk, self.as_proof()); + let lhs = lhs.assigned(); + let rhs = rhs.assigned(); // REQUIRED STEP - loader.finalize(); - final_pair = Some((lhs.assigned(), rhs.assigned())); + config.base_field_config.finalize(&mut loader.ctx_mut()); + + let instances: Vec<_> = lhs + .x + .truncation + .limbs + .iter() + .chain(lhs.y.truncation.limbs.iter()) + .chain(rhs.x.truncation.limbs.iter()) + .chain(rhs.y.truncation.limbs.iter()) + .map(|assigned| assigned.cell().clone()) + .collect(); + assigned_instances = Some(instances); Ok(()) }, )?; - let (lhs, rhs) = final_pair.unwrap(); - Ok({ - // TODO: use less instances by following Scroll's strategy of keeping only last bit of y coordinate - let mut layouter = layouter.namespace(|| "expose"); - for (i, assigned_instance) in lhs - .x - .truncation - .limbs - .iter() - .chain(lhs.y.truncation.limbs.iter()) - .chain(rhs.x.truncation.limbs.iter()) - .chain(rhs.y.truncation.limbs.iter()) - .enumerate() - { - layouter.constrain_instance( - assigned_instance.cell().clone(), - config.instance, - i, - )?; - } - }) + // TODO: use less instances by following Scroll's strategy of keeping only last bit of y coordinate + let mut layouter = layouter.namespace(|| "expose"); + for (i, cell) in assigned_instances.unwrap().into_iter().enumerate() { + layouter.constrain_instance(cell, config.instance, i); + } + Ok(()) } } @@ -391,14 +450,14 @@ test!( // create aggregation circuit A that aggregates two simple snarks {B,C}, then verify proof of this aggregation circuit A zk_aggregate_two_snarks, 21, - halo2_kzg_config!(true, 1, Accumulation::accumulator_indices()), + halo2_kzg_config!(true, 1, Some(Accumulation::accumulator_indices())), Accumulation::two_snark() ); test!( // create aggregation circuit A that aggregates two copies of same aggregation circuit B that aggregates two simple snarks {C, D}, then verifies proof of this aggregation circuit A zk_aggregate_two_snarks_with_accumulator, 22, // 22 = 21 + 1 since there are two copies of circuit B - halo2_kzg_config!(true, 1, Accumulation::accumulator_indices()), + halo2_kzg_config!(true, 1, Some(Accumulation::accumulator_indices())), Accumulation::two_snark_with_accumulator() ); @@ -431,7 +490,7 @@ pub fn create_snark() -> (ParamsKZG, Snark) { // TODO: need to cache the instances as well! 
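// Illustrative note, not part of the patch: the block below memoizes the raw SHPLONK proof
// bytes under ./data/proof_{NAME}.data so repeated test runs can skip proving. A caller that
// only needs to re-check a cached proof could reload it roughly as follows (the concrete
// generic arguments are assumptions based on the surrounding test code, not part of the diff):
//     let bytes = std::fs::read(format!("./data/proof_{}.data", T::NAME)).unwrap();
//     let mut transcript = PoseidonTranscript::<NativeLoader, _>::new(bytes.as_slice());
// The public instances are still re-derived from the circuit on every run, which is what the
// TODO above about caching them as well refers to.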
let proof = { - let path = format!("./src/system/halo2/test/data/proof_{}.data", T::NAME); + let path = format!("./data/proof_{}.data", T::NAME); match std::fs::File::open(path.as_str()) { Ok(mut file) => { let mut buf = vec![]; @@ -439,8 +498,7 @@ pub fn create_snark() -> (ParamsKZG, Snark) { buf } Err(_) => { - let mut transcript = - PoseidonTranscript::, _>::init(Vec::new()); + let mut transcript = PoseidonTranscript::>::init(Vec::new()); create_proof::, ProverSHPLONK<_>, _, _, _, _>( ¶ms, &pk, @@ -451,7 +509,8 @@ pub fn create_snark() -> (ParamsKZG, Snark) { ) .unwrap(); let proof = transcript.finalize(); - let mut file = std::fs::File::create(path.as_str()).unwrap(); + let mut file = std::fs::File::create(path.as_str()) + .expect(format!("{:?} should exist", path).as_str()); file.write_all(&proof).unwrap(); proof } @@ -464,7 +523,7 @@ pub fn create_snark() -> (ParamsKZG, Snark) { let verifier_params = params.verifier_params(); let strategy = SingleStrategy::new(¶ms); let mut transcript = - >, _> as TranscriptReadBuffer< + >> as TranscriptReadBuffer< _, _, _, @@ -483,6 +542,7 @@ pub fn create_snark() -> (ParamsKZG, Snark) { (params, Snark::new(protocol.clone(), instances0.into_iter().flatten().collect_vec(), proof)) } +/* pub mod zkevm { use super::*; use zkevm_circuit_benchmarks::evm_circuit::TestCircuit as EvmCircuit; @@ -557,3 +617,4 @@ pub mod zkevm { evm_and_state_aggregation_circuit() ); } +*/ diff --git a/src/system/halo2/test/kzg/native.rs b/snark-verifier/src/system/halo2/test/kzg/native.rs similarity index 79% rename from src/system/halo2/test/kzg/native.rs rename to snark-verifier/src/system/halo2/test/kzg/native.rs index 0f2849a3..0801a317 100644 --- a/src/system/halo2/test/kzg/native.rs +++ b/snark-verifier/src/system/halo2/test/kzg/native.rs @@ -1,17 +1,19 @@ +use crate::halo2_curves::bn256::{Bn256, G1Affine}; +use crate::halo2_proofs::{ + poly::kzg::multiopen::{ProverGWC, ProverSHPLONK, VerifierGWC, VerifierSHPLONK}, + transcript::{Blake2bRead, Blake2bWrite, Challenge255, TranscriptReadBuffer}, +}; use crate::{ - loader::halo2::test::StandardPlonk, pcs::kzg::{Bdfg21, Gwc19, Kzg, LimbsEncoding}, - system::halo2::test::kzg::{ - halo2_kzg_config, halo2_kzg_create_snark, halo2_kzg_native_verify, halo2_kzg_prepare, BITS, - LIMBS, + system::halo2::test::{ + kzg::{ + halo2_kzg_config, halo2_kzg_create_snark, halo2_kzg_native_verify, halo2_kzg_prepare, + BITS, LIMBS, + }, + StandardPlonk, }, verifier::Plonk, }; -use halo2_curves::bn256::{Bn256, G1Affine}; -use halo2_proofs::{ - poly::kzg::multiopen::{ProverGWC, ProverSHPLONK, VerifierGWC, VerifierSHPLONK}, - transcript::{Blake2bRead, Blake2bWrite, Challenge255, TranscriptReadBuffer}, -}; use paste::paste; use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng}; @@ -58,3 +60,11 @@ test!( halo2_kzg_config!(true, 2), StandardPlonk::rand(ChaCha20Rng::from_seed(Default::default())) ); +/* +test!( + zk_main_gate_with_range_with_mock_kzg_accumulator, + 9, + halo2_kzg_config!(true, 2, (0..4 * LIMBS).map(|idx| (0, idx)).collect()), + main_gate_with_range_with_mock_kzg_accumulator::() +); +*/ diff --git a/src/system/halo2/transcript.rs b/snark-verifier/src/system/halo2/transcript.rs similarity index 99% rename from src/system/halo2/transcript.rs rename to snark-verifier/src/system/halo2/transcript.rs index 2200bbf4..8b97f968 100644 --- a/src/system/halo2/transcript.rs +++ b/snark-verifier/src/system/halo2/transcript.rs @@ -1,3 +1,4 @@ +use crate::halo2_proofs; use crate::{ loader::native::{self, NativeLoader}, util::{ diff --git 
a/src/system/halo2/transcript/evm.rs b/snark-verifier/src/system/halo2/transcript/evm.rs similarity index 85% rename from src/system/halo2/transcript/evm.rs rename to snark-verifier/src/system/halo2/transcript/evm.rs index 6b98e0ac..909bb71d 100644 --- a/src/system/halo2/transcript/evm.rs +++ b/snark-verifier/src/system/halo2/transcript/evm.rs @@ -1,12 +1,13 @@ +use crate::halo2_proofs; use crate::{ loader::{ evm::{loader::Value, u256_to_fe, EcPoint, EvmLoader, MemoryChunk, Scalar}, native::{self, NativeLoader}, Loader, }, - system::halo2::aggregation::KZG_QUERY_INSTANCE, util::{ arithmetic::{Coordinates, CurveAffine, PrimeField}, + hash::{Digest, Keccak256}, transcript::{Transcript, TranscriptRead}, Itertools, }, @@ -14,7 +15,6 @@ use crate::{ }; use ethereum_types::U256; use halo2_proofs::transcript::EncodedChallenge; -use sha3::{Digest, Keccak256}; use std::{ io::{self, Read, Write}, iter, @@ -25,7 +25,6 @@ pub struct EvmTranscript, S, B> { loader: L, stream: S, buf: B, - query_instance_reset: bool, _marker: PhantomData, } @@ -34,14 +33,12 @@ where C: CurveAffine, C::Scalar: PrimeField, { - pub fn new(loader: Rc) -> Self { - let ptr = if KZG_QUERY_INSTANCE { 0 } else { loader.allocate(0x20) }; + pub fn new(loader: &Rc) -> Self { + let ptr = loader.allocate(0x20); assert_eq!(ptr, 0); let mut buf = MemoryChunk::new(ptr); - if !KZG_QUERY_INSTANCE { - buf.extend(0x20); - } - Self { loader, stream: 0, buf, query_instance_reset: false, _marker: PhantomData } + buf.extend(0x20); + Self { loader: loader.clone(), stream: 0, buf, _marker: PhantomData } } pub fn load_instances(&mut self, num_instance: Vec) -> Vec> { @@ -72,7 +69,9 @@ where fn squeeze_challenge(&mut self) -> Scalar { let len = if self.buf.len() == 0x20 { assert_eq!(self.loader.ptr(), self.buf.end()); - self.loader.code_mut().push(1).push(self.buf.end()).mstore8(); + let buf_end = self.buf.end(); + let code = format!("mstore8({buf_end}, 1)"); + self.loader.code_mut().runtime_append(code); 0x21 } else { self.buf.len() @@ -81,17 +80,14 @@ where let challenge_ptr = self.loader.allocate(0x20); let dup_hash_ptr = self.loader.allocate(0x20); - self.loader - .code_mut() - .push(hash_ptr) - .mload() - .push(self.loader.scalar_modulus()) - .dup(1) - .r#mod() - .push(challenge_ptr) - .mstore() - .push(dup_hash_ptr) - .mstore(); + let code = format!( + "{{ + let hash := mload({hash_ptr:#x}) + mstore({challenge_ptr:#x}, mod(hash, f_q)) + mstore({dup_hash_ptr:#x}, hash) + }}" + ); + self.loader.code_mut().runtime_append(code); self.buf.reset(dup_hash_ptr); self.buf.extend(0x20); @@ -101,18 +97,8 @@ where fn common_ec_point(&mut self, ec_point: &EcPoint) -> Result<(), Error> { if let Value::Memory(ptr) = ec_point.value() { - // this should never to reached if a vk is first hashed into transcript - if KZG_QUERY_INSTANCE && !self.query_instance_reset && self.buf.end() != ptr { - self.buf.reset(self.loader.ptr()); - self.query_instance_reset = true; - } - if self.buf.end() != ptr { - assert!(self.buf.end() > ptr && KZG_QUERY_INSTANCE); - self.loader.dup_ec_point(ec_point); - self.buf.extend(0x40); - } else { - self.buf.extend(0x40); - } + assert_eq!(self.buf.end(), ptr); + self.buf.extend(0x40); } else { unreachable!() } @@ -122,12 +108,6 @@ where fn common_scalar(&mut self, scalar: &Scalar) -> Result<(), Error> { match scalar.value() { Value::Constant(_) if self.buf.ptr() == 0 => { - if KZG_QUERY_INSTANCE && !self.query_instance_reset { - self.buf.reset(self.loader.ptr()); - self.buf.extend(0x20); - self.loader.allocate(0x20); - 
self.query_instance_reset = true; - } self.loader.copy_scalar(scalar, self.buf.ptr()); } Value::Memory(ptr) => { @@ -165,13 +145,7 @@ where C: CurveAffine, { pub fn new(stream: S) -> Self { - Self { - loader: NativeLoader, - stream, - buf: Vec::new(), - query_instance_reset: false, - _marker: PhantomData, - } + Self { loader: NativeLoader, stream, buf: Vec::new(), _marker: PhantomData } } } diff --git a/snark-verifier/src/system/halo2/transcript/halo2.rs b/snark-verifier/src/system/halo2/transcript/halo2.rs new file mode 100644 index 00000000..5e343740 --- /dev/null +++ b/snark-verifier/src/system/halo2/transcript/halo2.rs @@ -0,0 +1,447 @@ +use crate::halo2_proofs; +use crate::{ + loader::{ + halo2::{EcPoint, EccInstructions, Halo2Loader, Scalar}, + native::{self, NativeLoader}, + Loader, ScalarLoader, + }, + util::{ + arithmetic::{fe_to_fe, CurveAffine, PrimeField}, + hash::Poseidon, + transcript::{Transcript, TranscriptRead, TranscriptWrite}, + Itertools, + }, + Error, +}; +use halo2_proofs::{circuit::Value, transcript::EncodedChallenge}; +use std::{ + io::{self, Read, Write}, + rc::Rc, +}; + +/// Encoding that encodes elliptic curve point into native field elements. +pub trait NativeEncoding<'a, C>: EccInstructions<'a, C> +where + C: CurveAffine, +{ + fn encode( + &self, + ctx: &mut Self::Context, + ec_point: &Self::AssignedEcPoint, + ) -> Result, Error>; +} + +pub struct PoseidonTranscript< + C, + L, + S, + const T: usize, + const RATE: usize, + const R_F: usize, + const R_P: usize, +> where + C: CurveAffine, + L: Loader, +{ + loader: L, + stream: S, + buf: Poseidon>::LoadedScalar, T, RATE>, +} + +impl<'a, C, R, EccChip, const T: usize, const RATE: usize, const R_F: usize, const R_P: usize> + PoseidonTranscript>, Value, T, RATE, R_F, R_P> +where + C: CurveAffine, + R: Read, + EccChip: NativeEncoding<'a, C>, +{ + pub fn new(loader: &Rc>, stream: Value) -> Self { + let buf = Poseidon::new(loader, R_F, R_P); + Self { loader: loader.clone(), stream, buf } + } + + pub fn from_spec( + loader: &Rc>, + stream: Value, + spec: crate::poseidon::Spec, + ) -> Self { + let buf = Poseidon::from_spec(loader, spec); + Self { loader: loader.clone(), stream, buf } + } + + pub fn clear(&mut self) { + self.buf.clear(); + } + + pub fn new_stream(&mut self, stream: Value) { + self.buf.clear(); + self.stream = stream; + } +} + +impl<'a, C, R, EccChip, const T: usize, const RATE: usize, const R_F: usize, const R_P: usize> + Transcript>> + for PoseidonTranscript>, Value, T, RATE, R_F, R_P> +where + C: CurveAffine, + R: Read, + EccChip: NativeEncoding<'a, C>, +{ + fn loader(&self) -> &Rc> { + &self.loader + } + + fn squeeze_challenge(&mut self) -> Scalar<'a, C, EccChip> { + self.buf.squeeze() + } + + fn common_scalar(&mut self, scalar: &Scalar<'a, C, EccChip>) -> Result<(), Error> { + self.buf.update(&[scalar.clone()]); + Ok(()) + } + + fn common_ec_point(&mut self, ec_point: &EcPoint<'a, C, EccChip>) -> Result<(), Error> { + let encoded = self + .loader + .ecc_chip() + .encode(&mut self.loader.ctx_mut(), &ec_point.assigned()) + .map(|encoded| { + encoded + .into_iter() + .map(|encoded| self.loader.scalar_from_assigned(encoded)) + .collect_vec() + }) + .map_err(|_| { + Error::Transcript( + io::ErrorKind::Other, + "Failed to encode elliptic curve point into native field elements".to_string(), + ) + })?; + self.buf.update(&encoded); + Ok(()) + } +} + +impl<'a, C, R, EccChip, const T: usize, const RATE: usize, const R_F: usize, const R_P: usize> + TranscriptRead>> + for PoseidonTranscript>, Value, T, RATE, R_F, 
R_P> +where + C: CurveAffine, + R: Read, + EccChip: NativeEncoding<'a, C>, +{ + fn read_scalar(&mut self) -> Result, Error> { + let scalar = self.stream.as_mut().and_then(|stream| { + let mut data = ::Repr::default(); + if stream.read_exact(data.as_mut()).is_err() { + return Value::unknown(); + } + Option::::from(C::Scalar::from_repr(data)) + .map(Value::known) + .unwrap_or_else(Value::unknown) + }); + let scalar = self.loader.assign_scalar(scalar); + self.common_scalar(&scalar)?; + Ok(scalar) + } + + fn read_ec_point(&mut self) -> Result, Error> { + let ec_point = self.stream.as_mut().and_then(|stream| { + let mut compressed = C::Repr::default(); + if stream.read_exact(compressed.as_mut()).is_err() { + return Value::unknown(); + } + Option::::from(C::from_bytes(&compressed)) + .map(Value::known) + .unwrap_or_else(Value::unknown) + }); + let ec_point = self.loader.assign_ec_point(ec_point); + self.common_ec_point(&ec_point)?; + Ok(ec_point) + } +} + +impl + PoseidonTranscript +{ + pub fn new(stream: S) -> Self { + Self { loader: NativeLoader, stream, buf: Poseidon::new(&NativeLoader, R_F, R_P) } + } + + pub fn from_spec(stream: S, spec: crate::poseidon::Spec) -> Self { + Self { loader: NativeLoader, stream, buf: Poseidon::from_spec(&NativeLoader, spec) } + } + + pub fn clear(&mut self) { + self.buf.clear(); + } + + pub fn new_stream(&mut self, stream: S) { + self.buf.clear(); + self.stream = stream; + } +} + +impl + Transcript for PoseidonTranscript +{ + fn loader(&self) -> &NativeLoader { + &native::LOADER + } + + fn squeeze_challenge(&mut self) -> C::Scalar { + self.buf.squeeze() + } + + fn common_scalar(&mut self, scalar: &C::Scalar) -> Result<(), Error> { + self.buf.update(&[*scalar]); + Ok(()) + } + + fn common_ec_point(&mut self, ec_point: &C) -> Result<(), Error> { + let encoded: Vec<_> = Option::from(ec_point.coordinates().map(|coordinates| { + [coordinates.x(), coordinates.y()].into_iter().cloned().map(fe_to_fe).collect_vec() + })) + .ok_or_else(|| { + Error::Transcript( + io::ErrorKind::Other, + "Invalid elliptic curve point encoding in proof".to_string(), + ) + })?; + self.buf.update(&encoded); + Ok(()) + } +} + +impl + TranscriptRead for PoseidonTranscript +where + C: CurveAffine, + R: Read, +{ + fn read_scalar(&mut self) -> Result { + let mut data = ::Repr::default(); + self.stream + .read_exact(data.as_mut()) + .map_err(|err| Error::Transcript(err.kind(), err.to_string()))?; + let scalar = C::Scalar::from_repr_vartime(data).ok_or_else(|| { + Error::Transcript(io::ErrorKind::Other, "Invalid scalar encoding in proof".to_string()) + })?; + self.common_scalar(&scalar)?; + Ok(scalar) + } + + fn read_ec_point(&mut self) -> Result { + let mut data = C::Repr::default(); + self.stream + .read_exact(data.as_mut()) + .map_err(|err| Error::Transcript(err.kind(), err.to_string()))?; + let ec_point = Option::::from(C::from_bytes(&data)).ok_or_else(|| { + Error::Transcript( + io::ErrorKind::Other, + "Invalid elliptic curve point encoding in proof".to_string(), + ) + })?; + self.common_ec_point(&ec_point)?; + Ok(ec_point) + } +} + +impl + PoseidonTranscript +where + C: CurveAffine, + W: Write, +{ + pub fn stream_mut(&mut self) -> &mut W { + &mut self.stream + } + + pub fn finalize(self) -> W { + self.stream + } +} + +impl TranscriptWrite + for PoseidonTranscript +where + C: CurveAffine, + W: Write, +{ + fn write_scalar(&mut self, scalar: C::Scalar) -> Result<(), Error> { + self.common_scalar(&scalar)?; + let data = scalar.to_repr(); + self.stream_mut().write_all(data.as_ref()).map_err(|err| 
{ + Error::Transcript(err.kind(), "Failed to write scalar to transcript".to_string()) + }) + } + + fn write_ec_point(&mut self, ec_point: C) -> Result<(), Error> { + self.common_ec_point(&ec_point)?; + let data = ec_point.to_bytes(); + self.stream_mut().write_all(data.as_ref()).map_err(|err| { + Error::Transcript( + err.kind(), + "Failed to write elliptic curve to transcript".to_string(), + ) + }) + } +} + +pub struct ChallengeScalar(C::Scalar); + +impl EncodedChallenge for ChallengeScalar { + type Input = C::Scalar; + + fn new(challenge_input: &C::Scalar) -> Self { + ChallengeScalar(*challenge_input) + } + + fn get_scalar(&self) -> C::Scalar { + self.0 + } +} + +impl + halo2_proofs::transcript::Transcript> + for PoseidonTranscript +{ + fn squeeze_challenge(&mut self) -> ChallengeScalar { + ChallengeScalar::new(&Transcript::squeeze_challenge(self)) + } + + fn common_point(&mut self, ec_point: C) -> io::Result<()> { + match Transcript::common_ec_point(self, &ec_point) { + Err(Error::Transcript(kind, msg)) => Err(io::Error::new(kind, msg)), + Err(_) => unreachable!(), + _ => Ok(()), + } + } + + fn common_scalar(&mut self, scalar: C::Scalar) -> io::Result<()> { + match Transcript::common_scalar(self, &scalar) { + Err(Error::Transcript(kind, msg)) => Err(io::Error::new(kind, msg)), + Err(_) => unreachable!(), + _ => Ok(()), + } + } +} + +impl + halo2_proofs::transcript::TranscriptRead> + for PoseidonTranscript +where + C: CurveAffine, + R: Read, +{ + fn read_point(&mut self) -> io::Result { + match TranscriptRead::read_ec_point(self) { + Err(Error::Transcript(kind, msg)) => Err(io::Error::new(kind, msg)), + Err(_) => unreachable!(), + Ok(value) => Ok(value), + } + } + + fn read_scalar(&mut self) -> io::Result { + match TranscriptRead::read_scalar(self) { + Err(Error::Transcript(kind, msg)) => Err(io::Error::new(kind, msg)), + Err(_) => unreachable!(), + Ok(value) => Ok(value), + } + } +} + +impl + halo2_proofs::transcript::TranscriptReadBuffer> + for PoseidonTranscript +where + C: CurveAffine, + R: Read, +{ + fn init(reader: R) -> Self { + Self::new(reader) + } +} + +impl + halo2_proofs::transcript::TranscriptWrite> + for PoseidonTranscript +where + C: CurveAffine, + W: Write, +{ + fn write_point(&mut self, ec_point: C) -> io::Result<()> { + halo2_proofs::transcript::Transcript::>::common_point( + self, ec_point, + )?; + let data = ec_point.to_bytes(); + self.stream_mut().write_all(data.as_ref()) + } + + fn write_scalar(&mut self, scalar: C::Scalar) -> io::Result<()> { + halo2_proofs::transcript::Transcript::>::common_scalar(self, scalar)?; + let data = scalar.to_repr(); + self.stream_mut().write_all(data.as_ref()) + } +} + +impl + halo2_proofs::transcript::TranscriptWriterBuffer> + for PoseidonTranscript +where + C: CurveAffine, + W: Write, +{ + fn init(writer: W) -> Self { + Self::new(writer) + } + + fn finalize(self) -> W { + self.finalize() + } +} + +mod halo2_lib { + use crate::halo2_curves::CurveAffineExt; + use crate::system::halo2::transcript::halo2::NativeEncoding; + use halo2_base::utils::PrimeField; + use halo2_ecc::ecc::BaseFieldEccChip; + + impl<'a, C: CurveAffineExt> NativeEncoding<'a, C> for BaseFieldEccChip + where + C::Scalar: PrimeField, + C::Base: PrimeField, + { + fn encode( + &self, + _: &mut Self::Context, + ec_point: &Self::AssignedEcPoint, + ) -> Result, crate::Error> { + Ok(vec![ec_point.x().native().clone(), ec_point.y().native().clone()]) + } + } +} + +/* +mod halo2_wrong { + use crate::system::halo2::transcript::halo2::NativeEncoding; + use 
halo2_curves::CurveAffine; + use halo2_proofs::circuit::AssignedCell; + use halo2_wrong_ecc::BaseFieldEccChip; + + impl<'a, C: CurveAffine, const LIMBS: usize, const BITS: usize> NativeEncoding<'a, C> + for BaseFieldEccChip + { + fn encode( + &self, + _: &mut Self::Context, + ec_point: &Self::AssignedEcPoint, + ) -> Result>, crate::Error> { + Ok(vec![ + ec_point.x().native().clone(), + ec_point.y().native().clone(), + ]) + } + } +} +*/ diff --git a/snark-verifier/src/util.rs b/snark-verifier/src/util.rs new file mode 100644 index 00000000..b42db61c --- /dev/null +++ b/snark-verifier/src/util.rs @@ -0,0 +1,47 @@ +pub mod arithmetic; +pub mod hash; +pub mod msm; +pub mod poly; +pub mod protocol; +pub mod transcript; + +pub(crate) use itertools::Itertools; + +#[cfg(feature = "parallel")] +pub(crate) use rayon::current_num_threads; + +pub fn parallelize_iter(iter: I, f: F) +where + I: Send + Iterator, + T: Send, + F: Fn(T) + Send + Sync + Clone, +{ + #[cfg(feature = "parallel")] + rayon::scope(|scope| { + for item in iter { + let f = f.clone(); + scope.spawn(move |_| f(item)); + } + }); + #[cfg(not(feature = "parallel"))] + iter.for_each(f); +} + +pub fn parallelize(v: &mut [T], f: F) +where + T: Send, + F: Fn((&mut [T], usize)) + Send + Sync + Clone, +{ + #[cfg(feature = "parallel")] + { + let num_threads = current_num_threads(); + let chunk_size = v.len() / num_threads; + if chunk_size < num_threads { + f((v, 0)); + } else { + parallelize_iter(v.chunks_mut(chunk_size).zip((0..).step_by(chunk_size)), f); + } + } + #[cfg(not(feature = "parallel"))] + f((v, 0)); +} diff --git a/src/util/arithmetic.rs b/snark-verifier/src/util/arithmetic.rs similarity index 87% rename from src/util/arithmetic.rs rename to snark-verifier/src/util/arithmetic.rs index 36d42cf5..7a795339 100644 --- a/src/util/arithmetic.rs +++ b/snark-verifier/src/util/arithmetic.rs @@ -8,6 +8,7 @@ use std::{ ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}, }; +use crate::halo2_curves; pub use halo2_curves::{ group::{ ff::{BatchInvert, Field, PrimeField}, @@ -122,8 +123,8 @@ impl Domain { pub fn rotate_scalar(&self, scalar: F, rotation: Rotation) -> F { match rotation.0.cmp(&0) { Ordering::Equal => scalar, - Ordering::Greater => scalar * self.gen.pow_vartime(&[rotation.0 as u64]), - Ordering::Less => scalar * self.gen_inv.pow_vartime(&[(-rotation.0) as u64]), + Ordering::Greater => scalar * self.gen.pow_vartime([rotation.0 as u64]), + Ordering::Less => scalar * self.gen_inv.pow_vartime([(-rotation.0) as u64]), } } } @@ -170,19 +171,23 @@ impl Fraction { self.eval = Some( self.numer - .as_ref() - .map(|numer| numer.clone() * &self.denom) + .take() + .map(|numer| numer * &self.denom) .unwrap_or_else(|| self.denom.clone()), ); } pub fn evaluated(&self) -> &T { - assert!(self.inv); + assert!(self.eval.is_some()); self.eval.as_ref().unwrap() } } +pub fn ilog2(value: usize) -> usize { + (usize::BITS - value.leading_zeros() - 1) as usize +} + pub fn modulus() -> BigUint { fe_to_big(-F::one()) + 1usize } @@ -221,19 +226,24 @@ pub fn fe_to_limbs [F2; LIMBS] { let big = BigUint::from_bytes_le(fe.to_repr().as_ref()); - let mask = (BigUint::one() << BITS) - 1usize; + let mask = &((BigUint::one() << BITS) - 1usize); (0usize..) 
.step_by(BITS) .take(LIMBS) - .map(move |shift| fe_from_big((&big >> shift) & &mask)) + .map(|shift| fe_from_big((&big >> shift) & mask)) .collect_vec() .try_into() .unwrap() } -pub fn powers(scalar: F) -> impl Iterator -where - for<'a> F: Mul<&'a F, Output = F> + One + Clone, -{ - iter::successors(Some(F::one()), move |power| Some(scalar.clone() * power)) +pub fn powers(scalar: F) -> impl Iterator { + iter::successors(Some(F::one()), move |power| Some(scalar * power)) +} + +pub fn inner_product(lhs: &[F], rhs: &[F]) -> F { + lhs.iter() + .zip_eq(rhs.iter()) + .map(|(lhs, rhs)| *lhs * rhs) + .reduce(|acc, product| acc + product) + .unwrap_or_default() } diff --git a/snark-verifier/src/util/hash.rs b/snark-verifier/src/util/hash.rs new file mode 100644 index 00000000..17ede0b3 --- /dev/null +++ b/snark-verifier/src/util/hash.rs @@ -0,0 +1,6 @@ +mod poseidon; + +pub use crate::util::hash::poseidon::Poseidon; + +#[cfg(feature = "loader_evm")] +pub use sha3::{Digest, Keccak256}; diff --git a/snark-verifier/src/util/hash/poseidon.rs b/snark-verifier/src/util/hash/poseidon.rs new file mode 100644 index 00000000..fa7442f4 --- /dev/null +++ b/snark-verifier/src/util/hash/poseidon.rs @@ -0,0 +1,165 @@ +use crate::poseidon::{self, SparseMDSMatrix, Spec}; +use crate::{ + loader::{LoadedScalar, ScalarLoader}, + util::{arithmetic::FieldExt, Itertools}, +}; +use std::{iter, marker::PhantomData, mem}; + +#[derive(Clone)] +struct State { + inner: [L; T], + _marker: PhantomData, +} + +impl, const T: usize, const RATE: usize> State { + fn new(inner: [L; T]) -> Self { + Self { inner, _marker: PhantomData } + } + + fn loader(&self) -> &L::Loader { + self.inner[0].loader() + } + + fn power5_with_constant(value: &L, constant: &F) -> L { + value.loader().sum_products_with_const(&[(value, &value.square().square())], *constant) + } + + fn sbox_full(&mut self, constants: &[F; T]) { + for (state, constant) in self.inner.iter_mut().zip(constants.iter()) { + *state = Self::power5_with_constant(state, constant); + } + } + + fn sbox_part(&mut self, constant: &F) { + self.inner[0] = Self::power5_with_constant(&self.inner[0], constant); + } + + fn absorb_with_pre_constants(&mut self, inputs: &[L], pre_constants: &[F; T]) { + assert!(inputs.len() < T); + + self.inner[0] = self.loader().sum_with_const(&[&self.inner[0]], pre_constants[0]); + self.inner.iter_mut().zip(pre_constants.iter()).skip(1).zip(inputs).for_each( + |((state, constant), input)| { + *state = state.loader().sum_with_const(&[state, input], *constant); + }, + ); + self.inner + .iter_mut() + .zip(pre_constants.iter()) + .skip(1 + inputs.len()) + .enumerate() + .for_each(|(idx, (state, constant))| { + *state = state.loader().sum_with_const( + &[state], + if idx == 0 { F::one() + constant } else { *constant }, + ); + }); + } + + fn apply_mds(&mut self, mds: &[[F; T]; T]) { + self.inner = mds + .iter() + .map(|row| { + self.loader() + .sum_with_coeff(&row.iter().cloned().zip(self.inner.iter()).collect_vec()) + }) + .collect_vec() + .try_into() + .unwrap(); + } + + fn apply_sparse_mds(&mut self, mds: &SparseMDSMatrix) { + self.inner = iter::once( + self.loader() + .sum_with_coeff(&mds.row().iter().cloned().zip(self.inner.iter()).collect_vec()), + ) + .chain(mds.col_hat().iter().zip(self.inner.iter().skip(1)).map(|(coeff, state)| { + self.loader().sum_with_coeff(&[(*coeff, &self.inner[0]), (F::one(), state)]) + })) + .collect_vec() + .try_into() + .unwrap(); + } +} + +pub struct Poseidon { + spec: Spec, + default_state: State, + state: State, + buf: Vec, +} + 
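// Usage sketch (illustrative, not part of the patch): this sponge is generic over the loader,
// so the same code drives both the native Poseidon transcript and the in-circuit one. A minimal
// native-field example, where T = 5 together with the RATE/R_F/R_P values used by the tests is
// an assumption, would look like:
//     use crate::loader::native::NativeLoader;
//     use crate::halo2_curves::bn256::Fr;
//     let mut poseidon = Poseidon::<Fr, Fr, 5, 4>::new(&NativeLoader, 8, 60);
//     poseidon.update(&[Fr::from(1), Fr::from(2)]);
//     let challenge: Fr = poseidon.squeeze();
// update() only buffers elements; squeeze() absorbs the buffer in RATE-sized chunks through the
// permutation and returns state word 1 as the challenge, as implemented below.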
+impl, const T: usize, const RATE: usize> Poseidon { + pub fn new(loader: &L::Loader, r_f: usize, r_p: usize) -> Self { + let default_state = + State::new(poseidon::State::default().words().map(|state| loader.load_const(&state))); + Self { + spec: Spec::new(r_f, r_p), + state: default_state.clone(), + default_state, + buf: Vec::new(), + } + } + + pub fn from_spec(loader: &L::Loader, spec: Spec) -> Self { + let default_state = + State::new(poseidon::State::default().words().map(|state| loader.load_const(&state))); + Self { spec, state: default_state.clone(), default_state, buf: Vec::new() } + } + + pub fn clear(&mut self) { + self.state = self.default_state.clone(); + self.buf.clear(); + } + + pub fn update(&mut self, elements: &[L]) { + self.buf.extend_from_slice(elements); + } + + pub fn squeeze(&mut self) -> L { + let buf = mem::take(&mut self.buf); + let exact = buf.len() % RATE == 0; + + for chunk in buf.chunks(RATE) { + self.permutation(chunk); + } + if exact { + self.permutation(&[]); + } + + self.state.inner[1].clone() + } + + fn permutation(&mut self, inputs: &[L]) { + let r_f = self.spec.r_f() / 2; + let mds = self.spec.mds_matrices().mds().rows(); + let pre_sparse_mds = self.spec.mds_matrices().pre_sparse_mds().rows(); + let sparse_matrices = self.spec.mds_matrices().sparse_matrices(); + + // First half of the full rounds + let constants = self.spec.constants().start(); + self.state.absorb_with_pre_constants(inputs, &constants[0]); + for constants in constants.iter().skip(1).take(r_f - 1) { + self.state.sbox_full(constants); + self.state.apply_mds(&mds); + } + self.state.sbox_full(constants.last().unwrap()); + self.state.apply_mds(&pre_sparse_mds); + + // Partial rounds + let constants = self.spec.constants().partial(); + for (constant, sparse_mds) in constants.iter().zip(sparse_matrices.iter()) { + self.state.sbox_part(constant); + self.state.apply_sparse_mds(sparse_mds); + } + + // Second half of the full rounds + let constants = self.spec.constants().end(); + for constants in constants.iter() { + self.state.sbox_full(constants); + self.state.apply_mds(&mds); + } + self.state.sbox_full(&[F::zero(); T]); + self.state.apply_mds(&mds); + } +} diff --git a/snark-verifier/src/util/msm.rs b/snark-verifier/src/util/msm.rs new file mode 100644 index 00000000..014a29e8 --- /dev/null +++ b/snark-verifier/src/util/msm.rs @@ -0,0 +1,332 @@ +use crate::{ + loader::{LoadedEcPoint, Loader}, + util::{ + arithmetic::{CurveAffine, Group, PrimeField}, + Itertools, + }, +}; +use num_integer::Integer; +use std::{ + default::Default, + iter::{self, Sum}, + mem::size_of, + ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}, +}; + +#[derive(Clone, Debug)] +pub struct Msm<'a, C: CurveAffine, L: Loader> { + constant: Option, + scalars: Vec, + bases: Vec<&'a L::LoadedEcPoint>, +} + +impl<'a, C, L> Default for Msm<'a, C, L> +where + C: CurveAffine, + L: Loader, +{ + fn default() -> Self { + Self { + constant: None, + scalars: Vec::new(), + bases: Vec::new(), + } + } +} + +impl<'a, C, L> Msm<'a, C, L> +where + C: CurveAffine, + L: Loader, +{ + pub fn constant(constant: L::LoadedScalar) -> Self { + Msm { + constant: Some(constant), + ..Default::default() + } + } + + pub fn base<'b: 'a>(base: &'b L::LoadedEcPoint) -> Self { + let one = base.loader().load_one(); + Msm { + scalars: vec![one], + bases: vec![base], + ..Default::default() + } + } + + pub(crate) fn size(&self) -> usize { + self.bases.len() + } + + pub(crate) fn split(mut self) -> (Self, Option) { + let constant = self.constant.take(); + 
(self, constant) + } + + pub(crate) fn try_into_constant(self) -> Option { + self.bases.is_empty().then(|| self.constant.unwrap()) + } + + pub fn evaluate(self, gen: Option) -> L::LoadedEcPoint { + let gen = gen.map(|gen| { + self.bases + .first() + .unwrap() + .loader() + .ec_point_load_const(&gen) + }); + let pairs = iter::empty() + .chain( + self.constant + .as_ref() + .map(|constant| (constant, gen.as_ref().unwrap())), + ) + .chain(self.scalars.iter().zip(self.bases.into_iter())) + .collect_vec(); + L::multi_scalar_multiplication(&pairs) + } + + pub fn scale(&mut self, factor: &L::LoadedScalar) { + if let Some(constant) = self.constant.as_mut() { + *constant *= factor; + } + for scalar in self.scalars.iter_mut() { + *scalar *= factor + } + } + + pub fn push<'b: 'a>(&mut self, scalar: L::LoadedScalar, base: &'b L::LoadedEcPoint) { + if let Some(pos) = self.bases.iter().position(|exist| exist.eq(&base)) { + self.scalars[pos] += &scalar; + } else { + self.scalars.push(scalar); + self.bases.push(base); + } + } + + pub fn extend<'b: 'a>(&mut self, mut other: Msm<'b, C, L>) { + match (self.constant.as_mut(), other.constant.as_ref()) { + (Some(lhs), Some(rhs)) => *lhs += rhs, + (None, Some(_)) => self.constant = other.constant.take(), + _ => {} + }; + for (scalar, base) in other.scalars.into_iter().zip(other.bases) { + self.push(scalar, base); + } + } +} + +impl<'a, 'b, C, L> Add> for Msm<'a, C, L> +where + 'b: 'a, + C: CurveAffine, + L: Loader, +{ + type Output = Msm<'a, C, L>; + + fn add(mut self, rhs: Msm<'b, C, L>) -> Self::Output { + self.extend(rhs); + self + } +} + +impl<'a, 'b, C, L> AddAssign> for Msm<'a, C, L> +where + 'b: 'a, + C: CurveAffine, + L: Loader, +{ + fn add_assign(&mut self, rhs: Msm<'b, C, L>) { + self.extend(rhs); + } +} + +impl<'a, 'b, C, L> Sub> for Msm<'a, C, L> +where + 'b: 'a, + C: CurveAffine, + L: Loader, +{ + type Output = Msm<'a, C, L>; + + fn sub(mut self, rhs: Msm<'b, C, L>) -> Self::Output { + self.extend(-rhs); + self + } +} + +impl<'a, 'b, C, L> SubAssign> for Msm<'a, C, L> +where + 'b: 'a, + C: CurveAffine, + L: Loader, +{ + fn sub_assign(&mut self, rhs: Msm<'b, C, L>) { + self.extend(-rhs); + } +} + +impl<'a, C, L> Mul<&L::LoadedScalar> for Msm<'a, C, L> +where + C: CurveAffine, + L: Loader, +{ + type Output = Msm<'a, C, L>; + + fn mul(mut self, rhs: &L::LoadedScalar) -> Self::Output { + self.scale(rhs); + self + } +} + +impl<'a, C, L> MulAssign<&L::LoadedScalar> for Msm<'a, C, L> +where + C: CurveAffine, + L: Loader, +{ + fn mul_assign(&mut self, rhs: &L::LoadedScalar) { + self.scale(rhs); + } +} + +impl<'a, C, L> Neg for Msm<'a, C, L> +where + C: CurveAffine, + L: Loader, +{ + type Output = Msm<'a, C, L>; + fn neg(mut self) -> Msm<'a, C, L> { + self.constant = self.constant.map(|constant| -constant); + for scalar in self.scalars.iter_mut() { + *scalar = -scalar.clone(); + } + self + } +} + +impl<'a, C, L> Sum for Msm<'a, C, L> +where + C: CurveAffine, + L: Loader, +{ + fn sum>(iter: I) -> Self { + iter.reduce(|acc, item| acc + item).unwrap_or_default() + } +} + +#[derive(Clone, Copy)] +enum Bucket { + None, + Affine(C), + Projective(C::Curve), +} + +impl Bucket { + fn add_assign(&mut self, rhs: &C) { + *self = match *self { + Bucket::None => Bucket::Affine(*rhs), + Bucket::Affine(lhs) => Bucket::Projective(lhs + *rhs), + Bucket::Projective(mut lhs) => { + lhs += *rhs; + Bucket::Projective(lhs) + } + } + } + + fn add(self, mut rhs: C::Curve) -> C::Curve { + match self { + Bucket::None => rhs, + Bucket::Affine(lhs) => { + rhs += lhs; + rhs + } + 
Bucket::Projective(lhs) => lhs + rhs, + } + } +} + +fn multi_scalar_multiplication_serial( + scalars: &[C::Scalar], + bases: &[C], + result: &mut C::Curve, +) { + let scalars = scalars.iter().map(|scalar| scalar.to_repr()).collect_vec(); + let num_bytes = scalars[0].as_ref().len(); + let num_bits = 8 * num_bytes; + + let window_size = (scalars.len() as f64).ln().ceil() as usize + 2; + let num_buckets = (1 << window_size) - 1; + + let windowed_scalar = |idx: usize, bytes: &::Repr| { + let skip_bits = idx * window_size; + let skip_bytes = skip_bits / 8; + + let mut value = [0; size_of::()]; + for (dst, src) in value.iter_mut().zip(bytes.as_ref()[skip_bytes..].iter()) { + *dst = *src; + } + + (usize::from_le_bytes(value) >> (skip_bits - (skip_bytes * 8))) & num_buckets + }; + + let num_window = Integer::div_ceil(&num_bits, &window_size); + for idx in (0..num_window).rev() { + for _ in 0..window_size { + *result = result.double(); + } + + let mut buckets = vec![Bucket::None; num_buckets]; + + for (scalar, base) in scalars.iter().zip(bases.iter()) { + let scalar = windowed_scalar(idx, scalar); + if scalar != 0 { + buckets[scalar - 1].add_assign(base); + } + } + + let mut running_sum = C::Curve::identity(); + for bucket in buckets.into_iter().rev() { + running_sum = bucket.add(running_sum); + *result += &running_sum; + } + } +} + +// Copy from https://github.com/zcash/halo2/blob/main/halo2_proofs/src/arithmetic.rs +pub fn multi_scalar_multiplication(scalars: &[C::Scalar], bases: &[C]) -> C::Curve { + assert_eq!(scalars.len(), bases.len()); + + #[cfg(feature = "parallel")] + { + use crate::util::{current_num_threads, parallelize_iter}; + + let num_threads = current_num_threads(); + if scalars.len() < num_threads { + let mut result = C::Curve::identity(); + multi_scalar_multiplication_serial(scalars, bases, &mut result); + return result; + } + + let chunk_size = Integer::div_ceil(&scalars.len(), &num_threads); + let mut results = vec![C::Curve::identity(); num_threads]; + parallelize_iter( + scalars + .chunks(chunk_size) + .zip(bases.chunks(chunk_size)) + .zip(results.iter_mut()), + |((scalars, bases), result)| { + multi_scalar_multiplication_serial(scalars, bases, result); + }, + ); + results + .iter() + .fold(C::Curve::identity(), |acc, result| acc + result) + } + #[cfg(not(feature = "parallel"))] + { + let mut result = C::Curve::identity(); + multi_scalar_multiplication_serial(scalars, bases, &mut result); + result + } +} diff --git a/snark-verifier/src/util/poly.rs b/snark-verifier/src/util/poly.rs new file mode 100644 index 00000000..ea120b33 --- /dev/null +++ b/snark-verifier/src/util/poly.rs @@ -0,0 +1,175 @@ +use crate::util::{arithmetic::Field, parallelize}; +use rand::Rng; +use std::{ + iter::{self, Sum}, + ops::{ + Add, Index, IndexMut, Mul, Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, + RangeToInclusive, Sub, + }, +}; + +#[derive(Clone, Debug)] +pub struct Polynomial(Vec); + +impl Polynomial { + pub fn new(inner: Vec) -> Self { + Self(inner) + } + + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } + + pub fn len(&self) -> usize { + self.0.len() + } + + pub fn iter(&self) -> impl Iterator { + self.0.iter() + } + + pub fn iter_mut(&mut self) -> impl Iterator { + self.0.iter_mut() + } + + pub fn to_vec(self) -> Vec { + self.0 + } +} + +impl Polynomial { + pub fn rand(n: usize, mut rng: R) -> Self { + Self::new(iter::repeat_with(|| F::random(&mut rng)).take(n).collect()) + } + + pub fn evaluate(&self, x: F) -> F { + let evaluate_serial = |coeffs: &[F]| { + coeffs + .iter() 
+ .rev() + .fold(F::zero(), |acc, coeff| acc * x + coeff) + }; + + #[cfg(feature = "parallel")] + { + use crate::util::{arithmetic::powers, current_num_threads, parallelize_iter}; + use num_integer::Integer; + + let num_threads = current_num_threads(); + if self.len() * 2 < num_threads { + return evaluate_serial(&self.0); + } + + let chunk_size = Integer::div_ceil(&self.len(), &num_threads); + let mut results = vec![F::zero(); num_threads]; + parallelize_iter( + results + .iter_mut() + .zip(self.0.chunks(chunk_size)) + .zip(powers(x.pow_vartime(&[chunk_size as u64, 0, 0, 0]))), + |((result, coeffs), scalar)| *result = evaluate_serial(coeffs) * scalar, + ); + results.iter().fold(F::zero(), |acc, result| acc + result) + } + #[cfg(not(feature = "parallel"))] + evaluate_serial(&self.0) + } +} + +impl<'a, F: Field> Add<&'a Polynomial> for Polynomial { + type Output = Polynomial; + + fn add(mut self, rhs: &'a Polynomial) -> Polynomial { + parallelize(&mut self.0, |(lhs, start)| { + for (lhs, rhs) in lhs.iter_mut().zip(rhs.0[start..].iter()) { + *lhs += *rhs; + } + }); + self + } +} + +impl<'a, F: Field> Sub<&'a Polynomial> for Polynomial { + type Output = Polynomial; + + fn sub(mut self, rhs: &'a Polynomial) -> Polynomial { + parallelize(&mut self.0, |(lhs, start)| { + for (lhs, rhs) in lhs.iter_mut().zip(rhs.0[start..].iter()) { + *lhs -= *rhs; + } + }); + self + } +} + +impl Sub for Polynomial { + type Output = Polynomial; + + fn sub(mut self, rhs: F) -> Polynomial { + self.0[0] -= rhs; + self + } +} + +impl Add for Polynomial { + type Output = Polynomial; + + fn add(mut self, rhs: F) -> Polynomial { + self.0[0] += rhs; + self + } +} + +impl Mul for Polynomial { + type Output = Polynomial; + + fn mul(mut self, rhs: F) -> Polynomial { + if rhs == F::zero() { + return Polynomial::new(vec![F::zero(); self.len()]); + } + if rhs == F::one() { + return self; + } + parallelize(&mut self.0, |(lhs, _)| { + for lhs in lhs.iter_mut() { + *lhs *= rhs; + } + }); + self + } +} + +impl Sum for Polynomial { + fn sum>(iter: I) -> Self { + iter.reduce(|acc, item| acc + &item).unwrap() + } +} + +macro_rules! 
impl_index { + ($($range:ty => $output:ty,)*) => { + $( + impl Index<$range> for Polynomial { + type Output = $output; + + fn index(&self, index: $range) -> &$output { + self.0.index(index) + } + } + impl IndexMut<$range> for Polynomial { + fn index_mut(&mut self, index: $range) -> &mut $output { + self.0.index_mut(index) + } + } + )* + }; +} + +impl_index!( + usize => F, + Range => [F], + RangeFrom => [F], + RangeFull => [F], + RangeInclusive => [F], + RangeTo => [F], + RangeToInclusive => [F], +); diff --git a/src/util/protocol.rs b/snark-verifier/src/util/protocol.rs similarity index 88% rename from src/util/protocol.rs rename to snark-verifier/src/util/protocol.rs index 6996aa59..f7747060 100644 --- a/src/util/protocol.rs +++ b/snark-verifier/src/util/protocol.rs @@ -4,7 +4,9 @@ use crate::{ arithmetic::{CurveAffine, Domain, Field, Fraction, Rotation}, Itertools, }, + Protocol, }; +use num_integer::Integer; use num_traits::One; use std::{ cmp::max, @@ -14,6 +16,37 @@ use std::{ ops::{Add, Mul, Neg, Sub}, }; +impl Protocol +where + C: CurveAffine, +{ + pub fn loaded>(&self, loader: &L) -> Protocol { + let preprocessed = self + .preprocessed + .iter() + .map(|preprocessed| loader.ec_point_load_const(preprocessed)) + .collect(); + let transcript_initial_state = self + .transcript_initial_state + .as_ref() + .map(|transcript_initial_state| loader.load_const(transcript_initial_state)); + Protocol { + domain: self.domain.clone(), + preprocessed, + num_instance: self.num_instance.clone(), + num_witness: self.num_witness.clone(), + num_challenge: self.num_challenge.clone(), + evaluations: self.evaluations.clone(), + queries: self.queries.clone(), + quotient: self.quotient.clone(), + transcript_initial_state, + instance_committing_key: self.instance_committing_key.clone(), + linearization: self.linearization, + accumulator_indices: self.accumulator_indices.clone(), + } + } +} + #[derive(Clone, Copy, Debug)] pub enum CommonPolynomial { Identity, @@ -49,11 +82,11 @@ where let langranges = langranges.into_iter().sorted().dedup().collect_vec(); let one = loader.load_one(); - let zn_minus_one = zn.clone() - one; + let zn_minus_one = zn.clone() - &one; let zn_minus_one_inv = Fraction::one_over(zn_minus_one.clone()); let n_inv = loader.load_const(&domain.n_inv); - let numer = zn_minus_one.clone() * n_inv; + let numer = zn_minus_one.clone() * &n_inv; let omegas = langranges .iter() .map(|&i| loader.load_const(&domain.rotate_scalar(C::Scalar::one(), Rotation(i)))) @@ -116,7 +149,7 @@ pub struct QuotientPolynomial { impl QuotientPolynomial { pub fn num_chunk(&self) -> usize { - (self.numerator.degree() - 1).div_ceil(self.chunk_degree) + Integer::div_ceil(&(self.numerator.degree() - 1), &self.chunk_degree) } } @@ -345,7 +378,7 @@ fn merge_left_right(a: Option>, b: Option>) -> O } } -#[derive(Clone, Debug)] +#[derive(Clone, Copy, Debug)] pub enum LinearizationStrategy { /// Older linearization strategy of GWC19, which has linearization /// polynomial that doesn't evaluate to 0, and requires prover to send extra diff --git a/src/util/transcript.rs b/snark-verifier/src/util/transcript.rs similarity index 100% rename from src/util/transcript.rs rename to snark-verifier/src/util/transcript.rs diff --git a/src/verifier.rs b/snark-verifier/src/verifier.rs similarity index 87% rename from src/verifier.rs rename to snark-verifier/src/verifier.rs index 51e05382..0eef23d2 100644 --- a/src/verifier.rs +++ b/snark-verifier/src/verifier.rs @@ -20,7 +20,7 @@ where fn read_proof( svk: &MOS::SuccinctVerifyingKey, - 
protocol: &Protocol, + protocol: &Protocol, instances: &[Vec], transcript: &mut T, ) -> Result @@ -29,7 +29,7 @@ where fn succinct_verify( svk: &MOS::SuccinctVerifyingKey, - protocol: &Protocol, + protocol: &Protocol, instances: &[Vec], proof: &Self::Proof, ) -> Result, Error>; @@ -37,15 +37,14 @@ where fn verify( svk: &MOS::SuccinctVerifyingKey, dk: &MOS::DecidingKey, - protocol: &Protocol, + protocol: &Protocol, instances: &[Vec], proof: &Self::Proof, ) -> Result where MOS: Decider, { - let accumulators = Self::succinct_verify(svk, protocol, instances, proof) - .expect("succinct verify should not fail"); + let accumulators = Self::succinct_verify(svk, protocol, instances, proof)?; let output = MOS::decide_all(dk, accumulators); Ok(output) } diff --git a/src/verifier/plonk.rs b/snark-verifier/src/verifier/plonk.rs similarity index 82% rename from src/verifier/plonk.rs rename to snark-verifier/src/verifier/plonk.rs index 7d653137..6ed1d550 100644 --- a/src/verifier/plonk.rs +++ b/snark-verifier/src/verifier/plonk.rs @@ -29,7 +29,7 @@ where fn read_proof( svk: &MOS::SuccinctVerifyingKey, - protocol: &Protocol, + protocol: &Protocol, instances: &[Vec], transcript: &mut T, ) -> Result @@ -41,7 +41,7 @@ where fn succinct_verify( svk: &MOS::SuccinctVerifyingKey, - protocol: &Protocol, + protocol: &Protocol, instances: &[Vec], proof: &Self::Proof, ) -> Result, Error> { @@ -52,7 +52,7 @@ where &proof.z, ); - L::LoadedScalar::batch_invert(common_poly_eval.denoms()); + L::batch_invert(common_poly_eval.denoms()); common_poly_eval.evaluate(); common_poly_eval @@ -96,9 +96,9 @@ where L: Loader, MOS: MultiOpenScheme, { - fn read( + pub fn read( svk: &MOS::SuccinctVerifyingKey, - protocol: &Protocol, + protocol: &Protocol, instances: &[Vec], transcript: &mut T, ) -> Result @@ -106,18 +106,30 @@ where T: TranscriptRead, AE: AccumulatorEncoding, { - if protocol.num_instance != instances.iter().map(|instances| instances.len()).collect_vec() + if let Some(transcript_initial_state) = &protocol.transcript_initial_state { + transcript.common_scalar(transcript_initial_state)?; + } + + if protocol.num_instance + != instances + .iter() + .map(|instances| instances.len()) + .collect_vec() { return Err(Error::InvalidInstances); } let committed_instances = if let Some(ick) = &protocol.instance_committing_key { - // this case is synonymous with KZG_QUERY_INSTANCE = true let loader = transcript.loader(); - let bases = - ick.bases.iter().map(|value| loader.ec_point_load_const(value)).collect_vec(); - // constant not used in kzg: - // let constant = ick.constant.as_ref().map(|value| loader.ec_point_load_const(value)); + let bases = ick + .bases + .iter() + .map(|value| loader.ec_point_load_const(value)) + .collect_vec(); + let constant = ick + .constant + .as_ref() + .map(|value| loader.ec_point_load_const(value)); let committed_instances = instances .iter() @@ -125,29 +137,18 @@ where instances .iter() .zip(bases.iter()) - .map(|(scalar, base)| Msm::::base(base.clone()) * scalar) - // .chain(constant.clone().map(|constant| Msm::base(constant))) + .map(|(scalar, base)| Msm::::base(base) * scalar) + .chain(constant.as_ref().map(Msm::base)) .sum::>() .evaluate(None) }) .collect_vec(); - - // For EvmTranscript we need to hash in vk here so that the buffer is reset after the above MSMs - if let Some(transcript_initial_state) = &protocol.transcript_initial_state { - transcript.common_scalar(&loader.load_const(transcript_initial_state))?; - } - for committed_instance in committed_instances.iter() { 
transcript.common_ec_point(committed_instance)?; } Some(committed_instances) } else { - if let Some(transcript_initial_state) = &protocol.transcript_initial_state { - transcript - .common_scalar(&transcript.loader().load_const(transcript_initial_state))?; - } - for instances in instances.iter() { for instance in instances.iter() { transcript.common_scalar(instance)?; @@ -163,7 +164,10 @@ where .iter() .zip(protocol.num_challenge.iter()) .map(|(&n, &m)| { - Ok((transcript.read_n_ec_points(n)?, transcript.squeeze_n_challenges(m))) + Ok(( + transcript.read_n_ec_points(n)?, + transcript.squeeze_n_challenges(m), + )) }) .collect::, Error>>()? .into_iter() @@ -186,9 +190,13 @@ where .accumulator_indices .iter() .map(|accumulator_indices| { - accumulator_indices.iter().map(|&(i, j)| instances[i][j].clone()).collect() + AE::from_repr( + &accumulator_indices + .iter() + .map(|&(i, j)| &instances[i][j]) + .collect_vec(), + ) }) - .map(AE::from_repr) .collect::, _>>()?; Ok(Self { @@ -203,13 +211,15 @@ where }) } - fn empty_queries(protocol: &Protocol) -> Vec> { + pub fn empty_queries(protocol: &Protocol) -> Vec> { protocol .queries .iter() .map(|query| pcs::Query { poly: query.poly, - shift: protocol.domain.rotate_scalar(C::Scalar::one(), query.rotation), + shift: protocol + .domain + .rotate_scalar(C::Scalar::one(), query.rotation), eval: (), }) .collect() @@ -217,35 +227,35 @@ where fn queries( &self, - protocol: &Protocol, + protocol: &Protocol, mut evaluations: HashMap, ) -> Vec> { Self::empty_queries(protocol) .into_iter() - .zip(protocol.queries.iter().map(|query| evaluations.remove(query).unwrap())) + .zip( + protocol + .queries + .iter() + .map(|query| evaluations.remove(query).unwrap()), + ) .map(|(query, eval)| query.with_evaluation(eval)) .collect() } - fn commitments( - &self, - protocol: &Protocol, + fn commitments<'a>( + &'a self, + protocol: &'a Protocol, common_poly_eval: &CommonPolynomialEvaluation, evaluations: &mut HashMap, ) -> Result>, Error> { let loader = common_poly_eval.zn().loader(); let mut commitments = iter::empty() - .chain( - protocol - .preprocessed - .iter() - .map(|value| Msm::base(loader.ec_point_load_const(value))), - ) + .chain(protocol.preprocessed.iter().map(Msm::base)) .chain( self.committed_instances - .clone() + .as_ref() .map(|committed_instances| { - committed_instances.into_iter().map(Msm::base).collect_vec() + committed_instances.iter().map(Msm::base).collect_vec() }) .unwrap_or_else(|| { iter::repeat_with(Default::default) @@ -253,7 +263,7 @@ where .collect_vec() }), ) - .chain(self.witnesses.iter().cloned().map(Msm::base)) + .chain(self.witnesses.iter().map(Msm::base)) .collect_vec(); let numerator = protocol.quotient.numerator.evaluate( @@ -300,7 +310,7 @@ where .pow_const(protocol.quotient.chunk_degree as u64) .powers(self.quotients.len()) .into_iter() - .zip(self.quotients.iter().cloned().map(Msm::base)) + .zip(self.quotients.iter().map(Msm::base)) .map(|(coeff, chunk)| chunk * &coeff) .sum::>(); match protocol.linearization { @@ -320,13 +330,18 @@ where let (msm, constant) = (numerator - quotient * common_poly_eval.zn_minus_one()).split(); commitments.push(msm); - evaluations.insert(quotient_query, constant.unwrap_or_else(|| loader.load_zero())); + evaluations.insert( + quotient_query, + constant.unwrap_or_else(|| loader.load_zero()), + ); } None => { commitments.push(quotient); evaluations.insert( quotient_query, - numerator.try_into_constant().ok_or(Error::InvalidLinearization)? 
+ numerator + .try_into_constant() + .ok_or(Error::InvalidLinearization)? * common_poly_eval.zn_minus_one_inv(), ); } @@ -337,7 +352,7 @@ where fn evaluations( &self, - protocol: &Protocol, + protocol: &Protocol, instances: &[Vec], common_poly_eval: &CommonPolynomialEvaluation, ) -> Result, Error> { @@ -366,7 +381,13 @@ where let evals = iter::empty() .chain(instance_evals.into_iter().flatten()) - .chain(protocol.evaluations.iter().cloned().zip(self.evaluations.iter().cloned())) + .chain( + protocol + .evaluations + .iter() + .cloned() + .zip(self.evaluations.iter().cloned()), + ) .collect(); Ok(evals) @@ -398,9 +419,13 @@ where } } -fn langranges(protocol: &Protocol, instances: &[Vec]) -> impl IntoIterator +fn langranges( + protocol: &Protocol, + instances: &[Vec], +) -> impl IntoIterator where C: CurveAffine, + L: Loader, { let instance_eval_lagrange = protocol.instance_committing_key.is_none().then(|| { let queries = { @@ -423,7 +448,7 @@ where } }); let max_instance_len = - instances.iter().map(|instance| instance.len()).max().unwrap_or_default(); + Iterator::max(instances.iter().map(|instance| instance.len())).unwrap_or_default(); -max_rotation..max_instance_len as i32 + min_rotation.abs() }); protocol diff --git a/src/loader/evm/code.rs b/src/loader/evm/code.rs deleted file mode 100644 index 80dd5c71..00000000 --- a/src/loader/evm/code.rs +++ /dev/null @@ -1,295 +0,0 @@ -use crate::util::Itertools; -use ethereum_types::U256; -use std::{collections::HashMap, iter}; - -pub enum Precompiled { - BigModExp = 0x05, - Bn254Add = 0x6, - Bn254ScalarMul = 0x7, - Bn254Pairing = 0x8, -} - -#[derive(Clone, Debug)] -pub struct Code { - code: Vec, - constants: HashMap, - stack_len: usize, -} - -impl Code { - pub fn new(constants: impl IntoIterator) -> Self { - let mut code = Self { - code: Vec::new(), - constants: HashMap::new(), - stack_len: 0, - }; - let constants = constants.into_iter().collect_vec(); - for constant in constants.iter() { - code.push(*constant); - } - code.constants = HashMap::from_iter( - constants - .into_iter() - .enumerate() - .map(|(idx, value)| (value, idx)), - ); - code - } - - pub fn deployment(code: Vec) -> Vec { - let code_len = code.len(); - assert_ne!(code_len, 0); - - iter::empty() - .chain([ - PUSH1 + 1, - (code_len >> 8) as u8, - (code_len & 0xff) as u8, - PUSH1, - 14, - PUSH1, - 0, - CODECOPY, - ]) - .chain([ - PUSH1 + 1, - (code_len >> 8) as u8, - (code_len & 0xff) as u8, - PUSH1, - 0, - RETURN, - ]) - .chain(code) - .collect() - } - - pub fn stack_len(&self) -> usize { - self.stack_len - } - - pub fn len(&self) -> usize { - self.code.len() - } - - pub fn is_empty(&self) -> bool { - self.code.is_empty() - } - - pub fn push>(&mut self, value: T) -> &mut Self { - let value = value.into(); - match self.constants.get(&value) { - Some(idx) if (0..16).contains(&(self.stack_len - idx - 1)) => { - self.dup(self.stack_len - idx - 1); - } - _ => { - let mut bytes = vec![0; 32]; - value.to_big_endian(&mut bytes); - let bytes = bytes - .iter() - .position(|byte| *byte != 0) - .map_or(vec![0], |pos| bytes.drain(pos..).collect()); - self.code.push(PUSH1 - 1 + bytes.len() as u8); - self.code.extend(bytes); - self.stack_len += 1; - } - } - self - } - - pub fn dup(&mut self, pos: usize) -> &mut Self { - assert!((0..16).contains(&pos)); - self.code.push(DUP1 + pos as u8); - self.stack_len += 1; - self - } - - pub fn swap(&mut self, pos: usize) -> &mut Self { - assert!((1..17).contains(&pos)); - self.code.push(SWAP1 - 1 + pos as u8); - self - } -} - -impl From for Vec { - fn 
from(code: Code) -> Self { - code.code - } -} - -macro_rules! impl_opcodes { - ($($method:ident -> ($opcode:ident, $stack_len_diff:expr))*) => { - $( - #[allow(dead_code)] - impl Code { - pub fn $method(&mut self) -> &mut Self { - self.code.push($opcode); - self.stack_len = ((self.stack_len as isize) + $stack_len_diff) as usize; - self - } - } - )* - }; -} - -impl_opcodes!( - stop -> (STOP, 0) - add -> (ADD, -1) - mul -> (MUL, -1) - sub -> (SUB, -1) - div -> (DIV, -1) - sdiv -> (SDIV, -1) - r#mod -> (MOD, -1) - smod -> (SMOD, -1) - addmod -> (ADDMOD, -2) - mulmod -> (MULMOD, -2) - exp -> (EXP, -1) - signextend -> (SIGNEXTEND, -1) - lt -> (LT, -1) - gt -> (GT, -1) - slt -> (SLT, -1) - sgt -> (SGT, -1) - eq -> (EQ, -1) - iszero -> (ISZERO, 0) - and -> (AND, -1) - or -> (OR, -1) - xor -> (XOR, -1) - not -> (NOT, 0) - byte -> (BYTE, -1) - shl -> (SHL, -1) - shr -> (SHR, -1) - sar -> (SAR, -1) - keccak256 -> (SHA3, -1) - address -> (ADDRESS, 1) - balance -> (BALANCE, 0) - origin -> (ORIGIN, 1) - caller -> (CALLER, 1) - callvalue -> (CALLVALUE, 1) - calldataload -> (CALLDATALOAD, 0) - calldatasize -> (CALLDATASIZE, 1) - calldatacopy -> (CALLDATACOPY, -3) - codesize -> (CODESIZE, 1) - codecopy -> (CODECOPY, -3) - gasprice -> (GASPRICE, 1) - extcodesize -> (EXTCODESIZE, 0) - extcodecopy -> (EXTCODECOPY, -4) - returndatasize -> (RETURNDATASIZE, 1) - returndatacopy -> (RETURNDATACOPY, -3) - extcodehash -> (EXTCODEHASH, 0) - blockhash -> (BLOCKHASH, 0) - coinbase -> (COINBASE, 1) - timestamp -> (TIMESTAMP, 1) - number -> (NUMBER, 1) - difficulty -> (DIFFICULTY, 1) - gaslimit -> (GASLIMIT, 1) - chainid -> (CHAINID, 1) - selfbalance -> (SELFBALANCE, 1) - basefee -> (BASEFEE, 1) - pop -> (POP, -1) - mload -> (MLOAD, 0) - mstore -> (MSTORE, -2) - mstore8 -> (MSTORE8, -2) - sload -> (SLOAD, 0) - sstore -> (SSTORE, -2) - jump -> (JUMP, -1) - jumpi -> (JUMPI, -2) - pc -> (PC, 1) - msize -> (MSIZE, 1) - gas -> (GAS, 1) - jumpdest -> (JUMPDEST, 0) - log0 -> (LOG0, -2) - log1 -> (LOG1, -3) - log2 -> (LOG2, -4) - log3 -> (LOG3, -5) - log4 -> (LOG4, -6) - create -> (CREATE, -2) - call -> (CALL, -6) - callcode -> (CALLCODE, -6) - r#return -> (RETURN, -2) - delegatecall -> (DELEGATECALL, -5) - create2 -> (CREATE2, -3) - staticcall -> (STATICCALL, -5) - revert -> (REVERT, -2) - selfdestruct -> (SELFDESTRUCT, -1) -); - -const STOP: u8 = 0x00; -const ADD: u8 = 0x01; -const MUL: u8 = 0x02; -const SUB: u8 = 0x03; -const DIV: u8 = 0x04; -const SDIV: u8 = 0x05; -const MOD: u8 = 0x06; -const SMOD: u8 = 0x07; -const ADDMOD: u8 = 0x08; -const MULMOD: u8 = 0x09; -const EXP: u8 = 0x0A; -const SIGNEXTEND: u8 = 0x0B; -const LT: u8 = 0x10; -const GT: u8 = 0x11; -const SLT: u8 = 0x12; -const SGT: u8 = 0x13; -const EQ: u8 = 0x14; -const ISZERO: u8 = 0x15; -const AND: u8 = 0x16; -const OR: u8 = 0x17; -const XOR: u8 = 0x18; -const NOT: u8 = 0x19; -const BYTE: u8 = 0x1A; -const SHL: u8 = 0x1B; -const SHR: u8 = 0x1C; -const SAR: u8 = 0x1D; -const SHA3: u8 = 0x20; -const ADDRESS: u8 = 0x30; -const BALANCE: u8 = 0x31; -const ORIGIN: u8 = 0x32; -const CALLER: u8 = 0x33; -const CALLVALUE: u8 = 0x34; -const CALLDATALOAD: u8 = 0x35; -const CALLDATASIZE: u8 = 0x36; -const CALLDATACOPY: u8 = 0x37; -const CODESIZE: u8 = 0x38; -const CODECOPY: u8 = 0x39; -const GASPRICE: u8 = 0x3A; -const EXTCODESIZE: u8 = 0x3B; -const EXTCODECOPY: u8 = 0x3C; -const RETURNDATASIZE: u8 = 0x3D; -const RETURNDATACOPY: u8 = 0x3E; -const EXTCODEHASH: u8 = 0x3F; -const BLOCKHASH: u8 = 0x40; -const COINBASE: u8 = 0x41; -const TIMESTAMP: u8 = 0x42; -const NUMBER: u8 = 
0x43; -const DIFFICULTY: u8 = 0x44; -const GASLIMIT: u8 = 0x45; -const CHAINID: u8 = 0x46; -const SELFBALANCE: u8 = 0x47; -const BASEFEE: u8 = 0x48; -const POP: u8 = 0x50; -const MLOAD: u8 = 0x51; -const MSTORE: u8 = 0x52; -const MSTORE8: u8 = 0x53; -const SLOAD: u8 = 0x54; -const SSTORE: u8 = 0x55; -const JUMP: u8 = 0x56; -const JUMPI: u8 = 0x57; -const PC: u8 = 0x58; -const MSIZE: u8 = 0x59; -const GAS: u8 = 0x5A; -const JUMPDEST: u8 = 0x5B; -const PUSH1: u8 = 0x60; -const DUP1: u8 = 0x80; -const SWAP1: u8 = 0x90; -const LOG0: u8 = 0xA0; -const LOG1: u8 = 0xA1; -const LOG2: u8 = 0xA2; -const LOG3: u8 = 0xA3; -const LOG4: u8 = 0xA4; -const CREATE: u8 = 0xF0; -const CALL: u8 = 0xF1; -const CALLCODE: u8 = 0xF2; -const RETURN: u8 = 0xF3; -const DELEGATECALL: u8 = 0xF4; -const CREATE2: u8 = 0xF5; -const STATICCALL: u8 = 0xFA; -const REVERT: u8 = 0xFD; -const SELFDESTRUCT: u8 = 0xFF; diff --git a/src/loader/evm/loader.rs b/src/loader/evm/loader.rs deleted file mode 100644 index 752ddae2..00000000 --- a/src/loader/evm/loader.rs +++ /dev/null @@ -1,940 +0,0 @@ -use crate::{ - loader::evm::{ - code::{Code, Precompiled}, - fe_to_u256, modulus, - }, - loader::{evm::u256_to_fe, EcPointLoader, LoadedEcPoint, LoadedScalar, Loader, ScalarLoader}, - util::{ - arithmetic::{CurveAffine, FieldOps, PrimeField}, - Itertools, - }, - Error, -}; -use ethereum_types::{U256, U512}; -use std::{ - cell::RefCell, - collections::HashMap, - fmt::{self, Debug}, - iter, - ops::{Add, AddAssign, DerefMut, Mul, MulAssign, Neg, Sub, SubAssign}, - rc::Rc, -}; - -#[derive(Clone, Debug)] -pub enum Value { - Constant(T), - Memory(usize), - Negated(Box>), - Sum(Box>, Box>), - Product(Box>, Box>), -} - -impl PartialEq for Value { - fn eq(&self, other: &Self) -> bool { - self.identifier() == other.identifier() - } -} - -impl Value { - fn identifier(&self) -> String { - match &self { - Value::Constant(_) | Value::Memory(_) => format!("{:?}", self), - Value::Negated(value) => format!("-({:?})", value), - Value::Sum(lhs, rhs) => format!("({:?} + {:?})", lhs, rhs), - Value::Product(lhs, rhs) => format!("({:?} * {:?})", lhs, rhs), - } - } -} - -#[derive(Clone, Debug)] -pub struct EvmLoader { - base_modulus: U256, - scalar_modulus: U256, - code: RefCell, - ptr: RefCell, - cache: RefCell>, - #[cfg(test)] - gas_metering_ids: RefCell>, -} - -impl EvmLoader { - pub fn new() -> Rc - where - Base: PrimeField, - Scalar: PrimeField, - { - let base_modulus = modulus::(); - let scalar_modulus = modulus::(); - let code = Code::new([1.into(), base_modulus, scalar_modulus - 1, scalar_modulus]) - .push(1) - .to_owned(); - Rc::new(Self { - base_modulus, - scalar_modulus, - code: RefCell::new(code), - ptr: Default::default(), - cache: Default::default(), - #[cfg(test)] - gas_metering_ids: RefCell::new(Vec::new()), - }) - } - - pub fn deployment_code(self: &Rc) -> Vec { - Code::deployment(self.runtime_code()) - } - - pub fn runtime_code(self: &Rc) -> Vec { - let mut code = self.code.borrow().clone(); - let dst = code.len() + 9; - code.push(dst).jumpi().push(0).push(0).revert().jumpdest().stop().to_owned().into() - } - - pub fn allocate(self: &Rc, size: usize) -> usize { - let ptr = *self.ptr.borrow(); - *self.ptr.borrow_mut() += size; - ptr - } - - pub(crate) fn scalar_modulus(&self) -> U256 { - self.scalar_modulus - } - - pub(crate) fn ptr(&self) -> usize { - *self.ptr.borrow() - } - - pub(crate) fn code_mut(&self) -> impl DerefMut + '_ { - self.code.borrow_mut() - } - - pub(crate) fn scalar(self: &Rc, value: Value) -> Scalar { - let value = if 
matches!(value, Value::Constant(_) | Value::Memory(_) | Value::Negated(_)) { - value - } else { - let identifier = value.identifier(); - let some_ptr = self.cache.borrow().get(&identifier).cloned(); - let ptr = if let Some(ptr) = some_ptr { - ptr - } else { - self.push(&Scalar { loader: self.clone(), value }); - let ptr = self.allocate(0x20); - self.code.borrow_mut().push(ptr).mstore(); - self.cache.borrow_mut().insert(identifier, ptr); - ptr - }; - Value::Memory(ptr) - }; - Scalar { loader: self.clone(), value } - } - - fn ec_point(self: &Rc, value: Value<(U256, U256)>) -> EcPoint { - EcPoint { loader: self.clone(), value } - } - - fn push(self: &Rc, scalar: &Scalar) { - match scalar.value.clone() { - Value::Constant(constant) => { - self.code.borrow_mut().push(constant); - } - Value::Memory(ptr) => { - self.code.borrow_mut().push(ptr).mload(); - } - Value::Negated(value) => { - self.push(&self.scalar(*value)); - self.code.borrow_mut().push(self.scalar_modulus).sub(); - } - Value::Sum(lhs, rhs) => { - self.code.borrow_mut().push(self.scalar_modulus); - self.push(&self.scalar(*lhs)); - self.push(&self.scalar(*rhs)); - self.code.borrow_mut().addmod(); - } - Value::Product(lhs, rhs) => { - self.code.borrow_mut().push(self.scalar_modulus); - self.push(&self.scalar(*lhs)); - self.push(&self.scalar(*rhs)); - self.code.borrow_mut().mulmod(); - } - } - } - - pub fn calldataload_scalar(self: &Rc, offset: usize) -> Scalar { - let ptr = self.allocate(0x20); - self.code - .borrow_mut() - .push(self.scalar_modulus) - .push(offset) - .calldataload() - .r#mod() - .push(ptr) - .mstore(); - self.scalar(Value::Memory(ptr)) - } - - pub fn calldataload_ec_point(self: &Rc, offset: usize) -> EcPoint { - let ptr = self.allocate(0x40); - self.code - .borrow_mut() - // [..., success] - .push(offset) - // [..., success, x_cd_ptr] - .calldataload() - // [..., success, x] - .dup(0) - // [..., success, x, x] - .push(ptr) - // [..., success, x, x, x_ptr] - .mstore() - // [..., success, x] - .push(offset + 0x20) - // [..., success, x, y_cd_ptr] - .calldataload() - // [..., success, x, y] - .dup(0) - // [..., success, x, y, y] - .push(ptr + 0x20) - // [..., success, x, y, y, y_ptr] - .mstore(); - // [..., success, x, y] - self.validate_ec_point(); - self.ec_point(Value::Memory(ptr)) - } - - pub fn ec_point_from_limbs( - self: &Rc, - x_limbs: [Scalar; LIMBS], - y_limbs: [Scalar; LIMBS], - ) -> EcPoint { - let ptr = self.allocate(0x40); - for (ptr, limbs) in [(ptr, x_limbs), (ptr + 0x20, y_limbs)] { - for (idx, limb) in limbs.into_iter().enumerate() { - self.push(&limb); - // [..., success, acc] - if idx > 0 { - self.code - .borrow_mut() - .push(idx * BITS) - // [..., success, acc, limb_i, shift] - .shl() - // [..., success, acc, limb_i << shift] - .add(); - // [..., success, acc] - } - } - self.code - .borrow_mut() - // [..., success, coordinate] - .dup(0) - // [..., success, coordinate, coordinate] - .push(ptr) - // [..., success, coordinate, coordinate, ptr] - .mstore(); - // [..., success, coordinate] - } - // [..., success, x, y] - self.validate_ec_point(); - self.ec_point(Value::Memory(ptr)) - } - - fn validate_ec_point(self: &Rc) { - self.code - .borrow_mut() - // [..., success, x, y] - .push(self.base_modulus) - // [..., success, x, y, p] - .dup(2) - // [..., success, x, y, p, x] - .lt() - // [..., success, x, y, x_lt_p] - .push(self.base_modulus) - // [..., success, x, y, x_lt_p, p] - .dup(2) - // [..., success, x, y, x_lt_p, p, y] - .lt() - // [..., success, x, y, x_lt_p, y_lt_p] - .and() - // [..., success, x, 
y, valid] - .dup(2) - // [..., success, x, y, valid, x] - .iszero() - // [..., success, x, y, valid, x_is_zero] - .dup(2) - // [..., success, x, y, valid, x_is_zero, y] - .iszero() - // [..., success, x, y, valid, x_is_zero, y_is_zero] - .or() - // [..., success, x, y, valid, x_or_y_is_zero] - .not() - // [..., success, x, y, valid, x_and_y_is_not_zero] - .and() - // [..., success, x, y, valid] - .push(self.base_modulus) - // [..., success, x, y, valid, p] - .dup(2) - // [..., success, x, y, valid, p, y] - .dup(0) - // [..., success, x, y, valid, p, y, y] - .mulmod() - // [..., success, x, y, valid, y_square] - .push(self.base_modulus) - // [..., success, x, y, valid, y_square, p] - .push(3) - // [..., success, x, y, valid, y_square, p, 3] - .push(self.base_modulus) - // [..., success, x, y, valid, y_square, p, 3, p] - .dup(6) - // [..., success, x, y, valid, y_square, p, 3, p, x] - .push(self.base_modulus) - // [..., success, x, y, valid, y_square, p, 3, p, x, p] - .dup(1) - // [..., success, x, y, valid, y_square, p, 3, p, x, p, x] - .dup(0) - // [..., success, x, y, valid, y_square, p, 3, p, x, p, x, x] - .mulmod() - // [..., success, x, y, valid, y_square, p, 3, p, x, x_square] - .mulmod() - // [..., success, x, y, valid, y_square, p, 3, x_cube] - .addmod() - // [..., success, x, y, valid, y_square, x_cube_plus_3] - .eq() - // [..., success, x, y, valid, y_square_eq_x_cube_plus_3] - .and() - // [..., success, x, y, valid] - .swap(2) - // [..., success, valid, y, x] - .pop() - // [..., success, valid, y] - .pop() - // [..., success, valid] - .and(); - } - - pub fn keccak256(self: &Rc, ptr: usize, len: usize) -> usize { - let hash_ptr = self.allocate(0x20); - self.code.borrow_mut().push(len).push(ptr).keccak256().push(hash_ptr).mstore(); - hash_ptr - } - - pub fn copy_scalar(self: &Rc, scalar: &Scalar, ptr: usize) { - self.push(scalar); - self.code.borrow_mut().push(ptr).mstore(); - } - - pub fn dup_scalar(self: &Rc, scalar: &Scalar) -> Scalar { - let ptr = self.allocate(0x20); - self.copy_scalar(scalar, ptr); - self.scalar(Value::Memory(ptr)) - } - - pub fn dup_ec_point(self: &Rc, value: &EcPoint) -> EcPoint { - let ptr = self.allocate(0x40); - match value.value { - Value::Constant((x, y)) => { - self.code.borrow_mut().push(x).push(ptr).mstore().push(y).push(ptr + 0x20).mstore(); - } - Value::Memory(src_ptr) => { - self.code - .borrow_mut() - .push(src_ptr) - .mload() - .push(ptr) - .mstore() - .push(src_ptr + 0x20) - .mload() - .push(ptr + 0x20) - .mstore(); - } - Value::Negated(_) | Value::Sum(_, _) | Value::Product(_, _) => { - unreachable!() - } - } - self.ec_point(Value::Memory(ptr)) - } - - fn staticcall(self: &Rc, precompile: Precompiled, cd_ptr: usize, rd_ptr: usize) { - let (cd_len, rd_len) = match precompile { - Precompiled::BigModExp => (0xc0, 0x20), - Precompiled::Bn254Add => (0x80, 0x40), - Precompiled::Bn254ScalarMul => (0x60, 0x40), - Precompiled::Bn254Pairing => (0x180, 0x20), - }; - self.code - .borrow_mut() - .push(rd_len) - .push(rd_ptr) - .push(cd_len) - .push(cd_ptr) - .push(precompile as usize) - .gas() - .staticcall() - .and(); - } - - fn invert(self: &Rc, scalar: &Scalar) -> Scalar { - let rd_ptr = self.allocate(0x20); - let [cd_ptr, ..] 
= [ - &self.scalar(Value::Constant(0x20.into())), - &self.scalar(Value::Constant(0x20.into())), - &self.scalar(Value::Constant(0x20.into())), - scalar, - &self.scalar(Value::Constant(self.scalar_modulus - 2)), - &self.scalar(Value::Constant(self.scalar_modulus)), - ] - .map(|value| self.dup_scalar(value).ptr()); - self.staticcall(Precompiled::BigModExp, cd_ptr, rd_ptr); - self.scalar(Value::Memory(rd_ptr)) - } - - fn ec_point_add(self: &Rc, lhs: &EcPoint, rhs: &EcPoint) -> EcPoint { - let rd_ptr = self.dup_ec_point(lhs).ptr(); - self.dup_ec_point(rhs); - self.staticcall(Precompiled::Bn254Add, rd_ptr, rd_ptr); - self.ec_point(Value::Memory(rd_ptr)) - } - - fn ec_point_scalar_mul(self: &Rc, ec_point: &EcPoint, scalar: &Scalar) -> EcPoint { - let rd_ptr = self.dup_ec_point(ec_point).ptr(); - self.dup_scalar(scalar); - self.staticcall(Precompiled::Bn254ScalarMul, rd_ptr, rd_ptr); - self.ec_point(Value::Memory(rd_ptr)) - } - - pub fn pairing( - self: &Rc, - lhs: &EcPoint, - g2: (U256, U256, U256, U256), - rhs: &EcPoint, - minus_s_g2: (U256, U256, U256, U256), - ) { - let rd_ptr = self.dup_ec_point(lhs).ptr(); - self.allocate(0x80); - self.code - .borrow_mut() - .push(g2.0) - .push(rd_ptr + 0x40) - .mstore() - .push(g2.1) - .push(rd_ptr + 0x60) - .mstore() - .push(g2.2) - .push(rd_ptr + 0x80) - .mstore() - .push(g2.3) - .push(rd_ptr + 0xa0) - .mstore(); - self.dup_ec_point(rhs); - self.allocate(0x80); - self.code - .borrow_mut() - .push(minus_s_g2.0) - .push(rd_ptr + 0x100) - .mstore() - .push(minus_s_g2.1) - .push(rd_ptr + 0x120) - .mstore() - .push(minus_s_g2.2) - .push(rd_ptr + 0x140) - .mstore() - .push(minus_s_g2.3) - .push(rd_ptr + 0x160) - .mstore(); - self.staticcall(Precompiled::Bn254Pairing, rd_ptr, rd_ptr); - self.code.borrow_mut().push(rd_ptr).mload().and(); - } - - fn add(self: &Rc, lhs: &Scalar, rhs: &Scalar) -> Scalar { - if let (Value::Constant(lhs), Value::Constant(rhs)) = (&lhs.value, &rhs.value) { - let out = (U512::from(lhs) + U512::from(rhs)) % U512::from(self.scalar_modulus); - return self.scalar(Value::Constant(out.try_into().unwrap())); - } - - self.scalar(Value::Sum(Box::new(lhs.value.clone()), Box::new(rhs.value.clone()))) - } - - fn sub(self: &Rc, lhs: &Scalar, rhs: &Scalar) -> Scalar { - if rhs.is_const() { - return self.add(lhs, &self.neg(rhs)); - } - - self.scalar(Value::Sum( - Box::new(lhs.value.clone()), - Box::new(Value::Negated(Box::new(rhs.value.clone()))), - )) - } - - fn mul(self: &Rc, lhs: &Scalar, rhs: &Scalar) -> Scalar { - if let (Value::Constant(lhs), Value::Constant(rhs)) = (&lhs.value, &rhs.value) { - let out = (U512::from(lhs) * U512::from(rhs)) % U512::from(self.scalar_modulus); - return self.scalar(Value::Constant(out.try_into().unwrap())); - } - - self.scalar(Value::Product(Box::new(lhs.value.clone()), Box::new(rhs.value.clone()))) - } - - fn neg(self: &Rc, scalar: &Scalar) -> Scalar { - if let Value::Constant(constant) = scalar.value { - return self.scalar(Value::Constant(self.scalar_modulus - constant)); - } - - self.scalar(Value::Negated(Box::new(scalar.value.clone()))) - } -} - -#[cfg(test)] -impl EvmLoader { - fn start_gas_metering(self: &Rc, identifier: &str) { - self.gas_metering_ids.borrow_mut().push(identifier.to_string()); - self.code.borrow_mut().gas().swap(1); - } - - fn end_gas_metering(self: &Rc) { - self.code.borrow_mut().swap(1).push(9).gas().swap(2).sub().sub().push(0).push(0).log1(); - } - - pub fn print_gas_metering(self: &Rc, costs: Vec) { - for (identifier, cost) in self.gas_metering_ids.borrow().iter().zip(costs) { - 
println!("{}: {}", identifier, cost); - } - } -} - -#[derive(Clone)] -pub struct EcPoint { - loader: Rc, - value: Value<(U256, U256)>, -} - -impl EcPoint { - pub(crate) fn loader(&self) -> &Rc { - &self.loader - } - - pub(crate) fn value(&self) -> Value<(U256, U256)> { - self.value.clone() - } - - pub(crate) fn ptr(&self) -> usize { - match self.value { - Value::Memory(ptr) => ptr, - _ => unreachable!(), - } - } -} - -impl Debug for EcPoint { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("EcPoint").field("value", &self.value).finish() - } -} - -impl PartialEq for EcPoint { - fn eq(&self, other: &Self) -> bool { - self.value == other.value - } -} - -impl LoadedEcPoint for EcPoint -where - C: CurveAffine, - C::ScalarExt: PrimeField, -{ - type Loader = Rc; - - fn loader(&self) -> &Rc { - &self.loader - } - - fn multi_scalar_multiplication(pairs: impl IntoIterator) -> Self { - pairs - .into_iter() - .map(|(scalar, ec_point)| match scalar.value { - Value::Constant(constant) if constant == U256::one() => ec_point, - _ => ec_point.loader.ec_point_scalar_mul(&ec_point, &scalar), - }) - .reduce(|acc, ec_point| acc.loader.ec_point_add(&acc, &ec_point)) - .unwrap() - } -} - -#[derive(Clone)] -pub struct Scalar { - loader: Rc, - value: Value, -} - -impl Scalar { - pub(crate) fn loader(&self) -> &Rc { - &self.loader - } - - pub(crate) fn value(&self) -> Value { - self.value.clone() - } - - pub(crate) fn is_const(&self) -> bool { - matches!(self.value, Value::Constant(_)) - } - - pub(crate) fn ptr(&self) -> usize { - match self.value { - Value::Memory(ptr) => ptr, - _ => *self.loader.cache.borrow().get(&self.value.identifier()).unwrap(), - } - } -} - -impl Debug for Scalar { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Scalar").field("value", &self.value).finish() - } -} - -impl Add for Scalar { - type Output = Self; - - fn add(self, rhs: Self) -> Self { - self.loader.add(&self, &rhs) - } -} - -impl Sub for Scalar { - type Output = Self; - - fn sub(self, rhs: Self) -> Self { - self.loader.sub(&self, &rhs) - } -} - -impl Mul for Scalar { - type Output = Self; - - fn mul(self, rhs: Self) -> Self { - self.loader.mul(&self, &rhs) - } -} - -impl Neg for Scalar { - type Output = Self; - - fn neg(self) -> Self { - self.loader.neg(&self) - } -} - -impl<'a> Add<&'a Self> for Scalar { - type Output = Self; - - fn add(self, rhs: &'a Self) -> Self { - self.loader.add(&self, rhs) - } -} - -impl<'a> Sub<&'a Self> for Scalar { - type Output = Self; - - fn sub(self, rhs: &'a Self) -> Self { - self.loader.sub(&self, rhs) - } -} - -impl<'a> Mul<&'a Self> for Scalar { - type Output = Self; - - fn mul(self, rhs: &'a Self) -> Self { - self.loader.mul(&self, rhs) - } -} - -impl AddAssign for Scalar { - fn add_assign(&mut self, rhs: Self) { - *self = self.loader.add(self, &rhs); - } -} - -impl SubAssign for Scalar { - fn sub_assign(&mut self, rhs: Self) { - *self = self.loader.sub(self, &rhs); - } -} - -impl MulAssign for Scalar { - fn mul_assign(&mut self, rhs: Self) { - *self = self.loader.mul(self, &rhs); - } -} - -impl<'a> AddAssign<&'a Self> for Scalar { - fn add_assign(&mut self, rhs: &'a Self) { - *self = self.loader.add(self, rhs); - } -} - -impl<'a> SubAssign<&'a Self> for Scalar { - fn sub_assign(&mut self, rhs: &'a Self) { - *self = self.loader.sub(self, rhs); - } -} - -impl<'a> MulAssign<&'a Self> for Scalar { - fn mul_assign(&mut self, rhs: &'a Self) { - *self = self.loader.mul(self, rhs); - } -} - -impl FieldOps for Scalar { - fn invert(&self) -> 
Option { - Some(self.loader.invert(self)) - } -} - -impl PartialEq for Scalar { - fn eq(&self, other: &Self) -> bool { - self.value == other.value - } -} - -impl> LoadedScalar for Scalar { - type Loader = Rc; - - fn loader(&self) -> &Rc { - &self.loader - } - - fn mul_add(a: &Self, b: &Self, c: &Self) -> Self { - a.clone() * b + c - } - - fn mul_add_constant(a: &Self, b: &Self, c: &F) -> Self { - a.clone() * b + a.loader().load_const(c) - } - - fn batch_invert<'a>(values: impl IntoIterator) { - let values = values.into_iter().collect_vec(); - let loader = &values.first().unwrap().loader; - let products = iter::once(values[0].clone()) - .chain( - iter::repeat_with(|| loader.allocate(0x20)) - .map(|ptr| loader.scalar(Value::Memory(ptr))) - .take(values.len() - 1), - ) - .collect_vec(); - - loader.code.borrow_mut().push(loader.scalar_modulus); - for _ in 2..values.len() { - loader.code.borrow_mut().dup(0); - } - - loader.push(products.first().unwrap()); - for (idx, (value, product)) in values.iter().zip(products.iter()).skip(1).enumerate() { - loader.push(value); - loader.code.borrow_mut().mulmod(); - if idx < values.len() - 2 { - loader.code.borrow_mut().dup(0); - } - loader.code.borrow_mut().push(product.ptr()).mstore(); - } - - let inv = loader.invert(products.last().unwrap()); - - loader.code.borrow_mut().push(loader.scalar_modulus); - for _ in 2..values.len() { - loader.code.borrow_mut().dup(0); - } - - loader.push(&inv); - for (value, product) in - values.iter().rev().zip(products.iter().rev().skip(1).map(Some).chain(iter::once(None))) - { - if let Some(product) = product { - loader.push(value); - loader - .code - .borrow_mut() - .dup(2) - .dup(2) - .push(product.ptr()) - .mload() - .mulmod() - .push(value.ptr()) - .mstore() - .mulmod(); - } else { - loader.code.borrow_mut().push(value.ptr()).mstore(); - } - } - } -} - -impl EcPointLoader for Rc -where - C: CurveAffine, - C::Scalar: PrimeField, -{ - type LoadedEcPoint = EcPoint; - - fn ec_point_load_const(&self, value: &C) -> EcPoint { - let coordinates = value.coordinates().unwrap(); - let [x, y] = [coordinates.x(), coordinates.y()] - .map(|coordinate| U256::from_little_endian(coordinate.to_repr().as_ref())); - self.ec_point(Value::Constant((x, y))) - } - - fn ec_point_assert_eq(&self, _: &str, _: &EcPoint, _: &EcPoint) -> Result<(), Error> { - unimplemented!() - } -} - -impl> ScalarLoader for Rc { - type LoadedScalar = Scalar; - - fn load_const(&self, value: &F) -> Scalar { - self.scalar(Value::Constant(fe_to_u256(*value))) - } - - fn assert_eq(&self, _: &str, _: &Scalar, _: &Scalar) -> Result<(), Error> { - unimplemented!() - } - - fn sum_with_coeff_and_constant(&self, values: &[(F, &Scalar)], constant: F) -> Scalar { - if values.is_empty() { - return self.load_const(&constant); - } - - let push_addend = |(coeff, value): &(F, &Scalar)| { - assert_ne!(*coeff, F::zero()); - match (*coeff == F::one(), &value.value) { - (true, _) => { - self.push(value); - } - (false, Value::Constant(value)) => { - self.push( - &self.scalar(Value::Constant(fe_to_u256(*coeff * u256_to_fe::(*value)))), - ); - } - (false, _) => { - self.code.borrow_mut().push(self.scalar_modulus); - self.push(&self.scalar(Value::Constant(fe_to_u256(*coeff)))); - self.push(value); - self.code.borrow_mut().mulmod(); - } - } - }; - - let mut values = values.iter(); - if constant == F::zero() { - push_addend(values.next().unwrap()); - } else { - self.push(&self.scalar(Value::Constant(fe_to_u256(constant)))); - } - - let chunk_size = 16 - self.code.borrow().stack_len(); - for 
values in &values.chunks(chunk_size) { - let values = values.into_iter().collect_vec(); - - self.code.borrow_mut().push(self.scalar_modulus); - for _ in 1..chunk_size.min(values.len()) { - self.code.borrow_mut().dup(0); - } - self.code.borrow_mut().swap(chunk_size.min(values.len())); - - for value in values { - push_addend(value); - self.code.borrow_mut().addmod(); - } - } - - let ptr = self.allocate(0x20); - self.code.borrow_mut().push(ptr).mstore(); - - self.scalar(Value::Memory(ptr)) - } - - fn sum_products_with_coeff_and_constant( - &self, - values: &[(F, &Scalar, &Scalar)], - constant: F, - ) -> Scalar { - if values.is_empty() { - return self.load_const(&constant); - } - - let push_addend = - |(coeff, lhs, rhs): &(F, &Scalar, &Scalar)| { - assert_ne!(*coeff, F::zero()); - match (*coeff == F::one(), &lhs.value, &rhs.value) { - (_, Value::Constant(lhs), Value::Constant(rhs)) => { - self.push(&self.scalar(Value::Constant(fe_to_u256( - *coeff * u256_to_fe::(*lhs) * u256_to_fe::(*rhs), - )))); - } - (_, value @ Value::Memory(_), Value::Constant(constant)) - | (_, Value::Constant(constant), value @ Value::Memory(_)) => { - self.code.borrow_mut().push(self.scalar_modulus); - self.push(&self.scalar(Value::Constant(fe_to_u256( - *coeff * u256_to_fe::(*constant), - )))); - self.push(&self.scalar(value.clone())); - self.code.borrow_mut().mulmod(); - } - (true, _, _) => { - self.code.borrow_mut().push(self.scalar_modulus); - self.push(lhs); - self.push(rhs); - self.code.borrow_mut().mulmod(); - } - (false, _, _) => { - self.code.borrow_mut().push(self.scalar_modulus).dup(0); - self.push(&self.scalar(Value::Constant(fe_to_u256(*coeff)))); - self.push(lhs); - self.code.borrow_mut().mulmod(); - self.push(rhs); - self.code.borrow_mut().mulmod(); - } - } - }; - - let mut values = values.iter(); - if constant == F::zero() { - push_addend(values.next().unwrap()); - } else { - self.push(&self.scalar(Value::Constant(fe_to_u256(constant)))); - } - - let chunk_size = 16 - self.code.borrow().stack_len(); - for values in &values.chunks(chunk_size) { - let values = values.into_iter().collect_vec(); - - self.code.borrow_mut().push(self.scalar_modulus); - for _ in 1..chunk_size.min(values.len()) { - self.code.borrow_mut().dup(0); - } - self.code.borrow_mut().swap(chunk_size.min(values.len())); - - for value in values { - push_addend(value); - self.code.borrow_mut().addmod(); - } - } - - let ptr = self.allocate(0x20); - self.code.borrow_mut().push(ptr).mstore(); - - self.scalar(Value::Memory(ptr)) - } -} - -impl Loader for Rc -where - C: CurveAffine, - C::Scalar: PrimeField, -{ - #[cfg(test)] - fn start_cost_metering(&self, identifier: &str) { - self.start_gas_metering(identifier) - } - - #[cfg(test)] - fn end_cost_metering(&self) { - self.end_gas_metering() - } -} diff --git a/src/loader/evm/test.rs b/src/loader/evm/test.rs deleted file mode 100644 index 845d03fc..00000000 --- a/src/loader/evm/test.rs +++ /dev/null @@ -1,56 +0,0 @@ -use crate::{loader::evm::test::tui::Tui, util::Itertools}; -use foundry_evm::{ - executor::{backend::Backend, fork::MultiFork, ExecutorBuilder}, - revm::{AccountInfo, Bytecode}, - utils::h256_to_u256_be, - Address, -}; -use std::env::var_os; - -mod tui; - -fn debug() -> bool { - matches!( - var_os("DEBUG"), - Some(value) if value.to_str() == Some("1") - ) -} - -pub fn execute(code: Vec, calldata: Vec) -> (bool, u64, Vec) { - assert!( - code.len() <= 0x6000, - "Contract size {} exceeds the limit 24576", - code.len() - ); - - let debug = debug(); - let caller = 
Address::from_low_u64_be(0xfe); - let callee = Address::from_low_u64_be(0xff); - - let mut evm = ExecutorBuilder::default() - .with_gas_limit(u64::MAX.into()) - .set_tracing(debug) - .set_debugger(debug) - .build(Backend::new(MultiFork::new().0, None)); - - evm.backend_mut().insert_account_info( - callee, - AccountInfo::new(0.into(), 1, Bytecode::new_raw(code.into())), - ); - - let result = evm - .call_raw(caller, callee, calldata.into(), 0.into()) - .unwrap(); - - let costs = result - .logs - .into_iter() - .map(|log| h256_to_u256_be(log.topics[0]).as_u64()) - .collect_vec(); - - if debug { - Tui::new(result.debug.unwrap().flatten(0), 0).start(); - } - - (!result.reverted, result.gas, costs) -} diff --git a/src/loader/halo2.rs b/src/loader/halo2.rs deleted file mode 100644 index ef6922fd..00000000 --- a/src/loader/halo2.rs +++ /dev/null @@ -1,26 +0,0 @@ -pub(crate) mod loader; -pub mod poseidon_chip; - -#[cfg(test)] -pub(crate) mod test; - -pub use loader::{EcPoint, Halo2Loader, Scalar}; -pub use util::Valuetools; - -mod util { - use halo2_proofs::circuit::Value; - - pub trait Valuetools: Iterator> { - fn fold_zipped(self, init: B, mut f: F) -> Value - where - Self: Sized, - F: FnMut(B, V) -> B, - { - self.into_iter().fold(Value::known(init), |acc, value| { - acc.zip(value).map(|(acc, value)| f(acc, value)) - }) - } - } - - impl>> Valuetools for I {} -} diff --git a/src/loader/halo2/loader.rs b/src/loader/halo2/loader.rs deleted file mode 100644 index 003f0811..00000000 --- a/src/loader/halo2/loader.rs +++ /dev/null @@ -1,900 +0,0 @@ -use crate::{ - loader::{EcPointLoader, LoadedEcPoint, LoadedScalar, Loader, ScalarLoader}, - util::{ - arithmetic::{Curve, CurveAffine, Field, FieldOps, PrimeCurveAffine, PrimeField}, - Itertools, - }, -}; -use halo2_base::{ - self, - gates::{flex_gate::FlexGateConfig, range::RangeConfig, GateInstructions, RangeInstructions}, - utils::fe_to_bigint, - Context, - QuantumCell::{self, Constant, Existing, Witness}, -}; -use halo2_ecc::{ - bigint::{CRTInteger, OverflowInteger}, - ecc::{fixed::FixedEccPoint, EccChip, EccPoint}, - fields::{fp::FpConfig, FieldChip}, -}; -use halo2_proofs::circuit; -use num_bigint::{BigInt, BigUint}; -use std::{ - cell::RefCell, - fmt::{self, Debug}, - ops::{Add, AddAssign, Deref, DerefMut, Mul, MulAssign, Neg, Sub, SubAssign}, - rc::Rc, -}; - -pub type AssignedValue = halo2_base::AssignedValue<::Scalar>; -pub type BaseFieldChip = FpConfig<::ScalarExt, ::Base>; -pub type AssignedInteger = CRTInteger<::ScalarExt>; -pub type AssignedEcPoint = EccPoint<::ScalarExt, AssignedInteger>; - -// Sometimes it is useful to know that a cell is really a constant, for optimization purposes -#[derive(Clone, Debug)] -pub enum Value { - Constant(T), - Assigned(L), -} - -pub struct Halo2Loader<'a, 'b, C: CurveAffine> { - pub ecc_chip: EccChip<'a, C::Scalar, BaseFieldChip>, - ctx: RefCell>, - num_ec_point: RefCell, - num_scalar: RefCell, -} -impl<'a, 'b, C: CurveAffine> Debug for Halo2Loader<'a, 'b, C> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Halo2Loader") - .field("num_ec_point", &self.num_ec_point) - .field("num_scalar", &self.num_scalar) - .finish() - } -} - -impl<'a, 'b, C: CurveAffine> Halo2Loader<'a, 'b, C> -where - C::Base: PrimeField, -{ - pub fn new(field_chip: &'a BaseFieldChip, ctx: Context<'b, C::Scalar>) -> Rc { - Rc::new(Self { - ecc_chip: EccChip::construct(field_chip), - ctx: RefCell::new(ctx), - num_ec_point: RefCell::new(0), - num_scalar: RefCell::new(0), - }) - } - - pub fn ecc_chip(&self) -> 
&EccChip<'a, C::Scalar, BaseFieldChip> { - &self.ecc_chip - } - - pub fn field_chip(&self) -> &BaseFieldChip { - &self.ecc_chip.field_chip - } - - pub fn range(&self) -> &RangeConfig { - self.field_chip().range() - } - - pub fn gate(&self) -> &FlexGateConfig { - &self.range().gate - } - - pub fn ctx(&self) -> impl Deref> + '_ { - self.ctx.borrow() - } - - pub(crate) fn ctx_mut(&self) -> impl DerefMut> + '_ { - self.ctx.borrow_mut() - } - - pub fn finalize(&self) { - let stats = self - .field_chip() - .finalize(&mut self.ctx_mut()) - .expect("finalizing constants and lookups"); - println!("stats (max rows fixed, total fixed cells, max rows lookup) {:?}", stats); - - let total_cells = - self.ctx.borrow().advice_rows[&self.range().context_id].iter().sum::(); - println!("total non-lookup advice cells used: {}", total_cells); - println!( - "total cells used in special lookup advice columns: {}", - self.ctx.borrow().cells_to_lookup.len() - ); - } - - pub fn assign_const_scalar(self: &Rc, constant: C::Scalar) -> Scalar<'a, 'b, C> { - let output = if constant == C::Scalar::zero() { - self.gate().load_zero(&mut self.ctx_mut()).unwrap() - } else { - let assigned = self - .gate() - .assign_region_smart( - &mut self.ctx_mut(), - vec![Constant(constant)], - vec![], - vec![], - vec![], - ) - .unwrap(); - assigned[0].clone() - }; - self.scalar(Value::Assigned(output)) - } - - pub fn assign_scalar(self: &Rc, scalar: circuit::Value) -> Scalar<'a, 'b, C> { - let assigned = self - .gate() - .assign_region_smart(&mut self.ctx_mut(), vec![Witness(scalar)], vec![], vec![], vec![]) - .unwrap(); - self.scalar(Value::Assigned(assigned[0].clone())) - } - - pub fn scalar(self: &Rc, value: Value>) -> Scalar<'a, 'b, C> { - let index = *self.num_scalar.borrow(); - *self.num_scalar.borrow_mut() += 1; - Scalar { loader: self.clone(), index, value } - } - - pub fn ec_point(self: &Rc, assigned: AssignedEcPoint) -> EcPoint<'a, 'b, C> { - let index = *self.num_ec_point.borrow(); - *self.num_ec_point.borrow_mut() += 1; - EcPoint { loader: self.clone(), value: Value::Assigned(assigned), index } - } - - pub fn assign_const_ec_point(self: &Rc, ec_point: C) -> EcPoint<'a, 'b, C> { - let index = *self.num_ec_point.borrow(); - *self.num_ec_point.borrow_mut() += 1; - EcPoint { loader: self.clone(), value: Value::Constant(ec_point), index } - } - - pub fn assign_ec_point(self: &Rc, ec_point: circuit::Value) -> EcPoint<'a, 'b, C> { - let assigned = self.ecc_chip.assign_point(&mut self.ctx_mut(), ec_point).unwrap(); - let is_on_curve_or_infinity = - self.ecc_chip.is_on_curve_or_infinity::(&mut self.ctx_mut(), &assigned).unwrap(); - self.gate().assert_is_const( - &mut self.ctx_mut(), - &is_on_curve_or_infinity, - C::Scalar::one(), - ); - - self.ec_point(assigned) - } - - pub fn assign_ec_point_from_limbs( - self: &Rc, - x_limbs: Vec>, - y_limbs: Vec>, - ) -> EcPoint<'a, 'b, C> { - let limbs_to_crt = |limbs| { - let native = OverflowInteger::evaluate( - self.gate(), - &mut self.ctx_mut(), - &limbs, - self.field_chip().limb_bits, - ) - .unwrap(); - let mut big_value = circuit::Value::known(BigInt::from(0)); - for limb in limbs.iter().rev() { - let limb_big = limb.value().map(|v| fe_to_bigint(v)); - big_value = big_value.map(|acc| acc << self.field_chip().limb_bits) + limb_big; - } - let truncation = OverflowInteger::construct( - limbs, - (BigUint::from(1u64) << self.field_chip().limb_bits) - 1usize, - self.field_chip().limb_bits, - self.field_chip().p.clone() - 1usize, - ); - CRTInteger::construct(truncation, native, big_value) - }; - - 
let ec_point = EccPoint::construct(limbs_to_crt(x_limbs), limbs_to_crt(y_limbs)); - self.ecc_chip - .assert_is_on_curve::(&mut self.ctx_mut(), &ec_point) - .expect("ec point should lie on curve"); - - self.ec_point(ec_point) - } - - fn add(self: &Rc, lhs: &Scalar<'a, 'b, C>, rhs: &Scalar<'a, 'b, C>) -> Scalar<'a, 'b, C> { - let output = match (&lhs.value, &rhs.value) { - (Value::Constant(lhs), Value::Constant(rhs)) => Value::Constant(*lhs + rhs), - (Value::Assigned(assigned), Value::Constant(constant)) - | (Value::Constant(constant), Value::Assigned(assigned)) => Value::Assigned( - GateInstructions::add( - self.gate(), - &mut self.ctx_mut(), - &Existing(assigned), - &Constant(*constant), - ) - .expect("add should not fail"), - ), - (Value::Assigned(lhs), Value::Assigned(rhs)) => Value::Assigned( - GateInstructions::add( - self.gate(), - &mut self.ctx_mut(), - &Existing(lhs), - &Existing(rhs), - ) - .expect("add should not fail"), - ), - }; - self.scalar(output) - } - - fn sub(self: &Rc, lhs: &Scalar<'a, 'b, C>, rhs: &Scalar<'a, 'b, C>) -> Scalar<'a, 'b, C> { - let output = match (&lhs.value, &rhs.value) { - (Value::Constant(lhs), Value::Constant(rhs)) => Value::Constant(*lhs - rhs), - (Value::Constant(constant), Value::Assigned(assigned)) => Value::Assigned( - GateInstructions::sub( - self.gate(), - &mut self.ctx_mut(), - &Constant(*constant), - &Existing(assigned), - ) - .expect("sub should not fail"), - ), - (Value::Assigned(assigned), Value::Constant(constant)) => Value::Assigned( - GateInstructions::sub( - self.gate(), - &mut self.ctx_mut(), - &Existing(assigned), - &Constant(*constant), - ) - .expect("sub should not fail"), - ), - (Value::Assigned(lhs), Value::Assigned(rhs)) => Value::Assigned( - GateInstructions::sub( - self.gate(), - &mut self.ctx_mut(), - &Existing(lhs), - &Existing(rhs), - ) - .expect("sub should not fail"), - ), - }; - self.scalar(output) - } - - fn mul(self: &Rc, lhs: &Scalar<'a, 'b, C>, rhs: &Scalar<'a, 'b, C>) -> Scalar<'a, 'b, C> { - let output = match (&lhs.value, &rhs.value) { - (Value::Constant(lhs), Value::Constant(rhs)) => Value::Constant(*lhs * rhs), - (Value::Assigned(assigned), Value::Constant(constant)) - | (Value::Constant(constant), Value::Assigned(assigned)) => Value::Assigned( - GateInstructions::mul( - self.gate(), - &mut self.ctx_mut(), - &Existing(assigned), - &Constant(*constant), - ) - .expect("mul should not fail"), - ), - (Value::Assigned(lhs), Value::Assigned(rhs)) => Value::Assigned( - GateInstructions::mul( - self.gate(), - &mut self.ctx_mut(), - &Existing(lhs), - &Existing(rhs), - ) - .expect("mul should not fail"), - ), - }; - self.scalar(output) - } - - fn mul_add( - self: &Rc, - a: &Scalar<'a, 'b, C>, - b: &Scalar<'a, 'b, C>, - c: &Scalar<'a, 'b, C>, - ) -> Scalar<'a, 'b, C> { - if let (Value::Constant(a), Value::Constant(b), Value::Constant(c)) = - (&a.value, &b.value, &c.value) - { - return self.scalar(Value::Constant(*a * b + c)); - } - let a = match &a.value { - Value::Constant(constant) => Constant(*constant), - Value::Assigned(assigned) => Existing(assigned), - }; - let b = match &b.value { - Value::Constant(constant) => Constant(*constant), - Value::Assigned(assigned) => Existing(assigned), - }; - let c = match &c.value { - Value::Constant(constant) => Constant(*constant), - Value::Assigned(assigned) => Existing(assigned), - }; - let output = self.gate().mul_add(&mut self.ctx_mut(), &a, &b, &c).unwrap(); - self.scalar(Value::Assigned(output)) - } - - fn neg(self: &Rc, scalar: &Scalar<'a, 'b, C>) -> Scalar<'a, 'b, C> { - let 
output = match &scalar.value { - Value::Constant(constant) => Value::Constant(constant.neg()), - Value::Assigned(assigned) => Value::Assigned( - GateInstructions::neg(self.gate(), &mut self.ctx_mut(), &Existing(assigned)) - .expect("neg should not fail"), - ), - }; - self.scalar(output) - } - - fn invert(self: &Rc, scalar: &Scalar<'a, 'b, C>) -> Scalar<'a, 'b, C> { - let output = match &scalar.value { - Value::Constant(constant) => Value::Constant(Field::invert(constant).unwrap()), - Value::Assigned(assigned) => Value::Assigned({ - // make sure scalar != 0 - let is_zero = - RangeInstructions::is_zero(self.range(), &mut self.ctx_mut(), assigned) - .unwrap(); - self.gate().assert_is_const(&mut self.ctx_mut(), &is_zero, C::Scalar::zero()); - GateInstructions::div_unsafe( - self.gate(), - &mut self.ctx_mut(), - &Constant(C::Scalar::one()), - &Existing(assigned), - ) - .expect("invert should not fail") - }), - }; - self.scalar(output) - } - - fn div(self: &Rc, lhs: &Scalar<'a, 'b, C>, rhs: &Scalar<'a, 'b, C>) -> Scalar<'a, 'b, C> { - let output = match (&lhs.value, &rhs.value) { - (Value::Constant(lhs), Value::Constant(rhs)) => { - Value::Constant(*lhs * Field::invert(rhs).unwrap()) - } - (Value::Constant(constant), Value::Assigned(assigned)) => Value::Assigned( - GateInstructions::div_unsafe( - self.gate(), - &mut self.ctx_mut(), - &Constant(*constant), - &Existing(assigned), - ) - .expect("div should not fail"), - ), - (Value::Assigned(assigned), Value::Constant(constant)) => Value::Assigned( - GateInstructions::div_unsafe( - self.gate(), - &mut self.ctx_mut(), - &Existing(assigned), - &Constant(*constant), - ) - .expect("div should not fail"), - ), - (Value::Assigned(lhs), Value::Assigned(rhs)) => Value::Assigned( - GateInstructions::div_unsafe( - self.gate(), - &mut self.ctx_mut(), - &Existing(lhs), - &Existing(rhs), - ) - .expect("div should not fail"), - ), - }; - self.scalar(output) - } -} - -#[derive(Clone)] -pub struct Scalar<'a, 'b, C: CurveAffine> { - loader: Rc>, - index: usize, - value: Value>, -} - -impl<'a, 'b, C: CurveAffine> Scalar<'a, 'b, C> { - pub fn assigned(&self) -> AssignedValue { - match &self.value { - Value::Constant(constant) => self.loader.assign_const_scalar(*constant).assigned(), - Value::Assigned(assigned) => assigned.clone(), - } - } - - pub fn to_quantum(&self) -> QuantumCell { - match &self.value { - Value::Constant(constant) => Constant(*constant), - Value::Assigned(assigned) => Existing(assigned), - } - } -} - -impl<'a, 'b, C: CurveAffine> PartialEq for Scalar<'a, 'b, C> { - fn eq(&self, other: &Self) -> bool { - self.index == other.index - } -} - -impl<'a, 'b, C: CurveAffine> LoadedScalar for Scalar<'a, 'b, C> { - type Loader = Rc>; - - fn loader(&self) -> &Self::Loader { - &self.loader - } - - fn mul_add(a: &Self, b: &Self, c: &Self) -> Self { - let loader = a.loader(); - Halo2Loader::mul_add(loader, a, b, c) - } - - fn mul_add_constant(a: &Self, b: &Self, c: &C::Scalar) -> Self { - Self::mul_add(a, b, &a.loader().scalar(Value::Constant(*c))) - } - - fn pow_const(&self, exp: u64) -> Self { - fn get_naf(mut e: u64) -> Vec { - // https://en.wikipedia.org/wiki/Non-adjacent_form - // NAF for exp: - let mut naf: Vec = Vec::with_capacity(32); - - // generate the NAF for exp - for _ in 0..64 { - if e & 1 == 1 { - let z = 2i8 - (e % 4) as i8; - e = e / 2; - if z == -1 { - e += 1; - } - naf.push(z); - } else { - naf.push(0); - e = e / 2; - } - } - if e != 0 { - assert_eq!(e, 1); - naf.push(1); - } - naf - } - - assert!(exp > 0); - let is_zero = 
RangeInstructions::is_zero( - self.loader().range(), - &mut self.loader.ctx_mut(), - &self.assigned(), - ) - .unwrap(); - self.loader.gate().assert_is_const(&mut self.loader.ctx_mut(), &is_zero, C::Scalar::zero()); - - let naf = get_naf(exp); - let mut acc = self.clone(); - let mut is_started = false; - - for &z in naf.iter().rev() { - if is_started { - acc = acc.clone() * &acc; - } - if z != 0 { - if is_started { - acc = if z == 1 { acc * self } else { (&self.loader).div(&acc, self) }; - } else { - is_started = true; - } - } - } - acc - } -} - -impl<'a, 'b, C: CurveAffine> Debug for Scalar<'a, 'b, C> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Scalar").field("value", &self.value).finish() - } -} - -impl<'a, 'b, C: CurveAffine> FieldOps for Scalar<'a, 'b, C> { - fn invert(&self) -> Option { - Some(self.loader.invert(self)) - } -} - -impl<'a, 'b, C: CurveAffine> Add for Scalar<'a, 'b, C> { - type Output = Self; - - fn add(self, rhs: Self) -> Self::Output { - Halo2Loader::add(&self.loader, &self, &rhs) - } -} -impl<'a, 'b, C: CurveAffine> Sub for Scalar<'a, 'b, C> { - type Output = Self; - - fn sub(self, rhs: Self) -> Self::Output { - Halo2Loader::sub(&self.loader, &self, &rhs) - } -} - -impl<'a, 'b, C: CurveAffine> Mul for Scalar<'a, 'b, C> { - type Output = Self; - - fn mul(self, rhs: Self) -> Self::Output { - Halo2Loader::mul(&self.loader, &self, &rhs) - } -} - -impl<'a, 'b, C: CurveAffine> Neg for Scalar<'a, 'b, C> { - type Output = Self; - - fn neg(self) -> Self::Output { - Halo2Loader::neg(&self.loader, &self) - } -} - -impl<'a, 'b, 'c, C: CurveAffine> Add<&'c Self> for Scalar<'a, 'b, C> { - type Output = Self; - - fn add(self, rhs: &'c Self) -> Self::Output { - Halo2Loader::add(&self.loader, &self, rhs) - } -} - -impl<'a, 'b, 'c, C: CurveAffine> Sub<&'c Self> for Scalar<'a, 'b, C> { - type Output = Self; - - fn sub(self, rhs: &'c Self) -> Self::Output { - Halo2Loader::sub(&self.loader, &self, rhs) - } -} - -impl<'a, 'b, 'c, C: CurveAffine> Mul<&'c Self> for Scalar<'a, 'b, C> { - type Output = Self; - - fn mul(self, rhs: &'c Self) -> Self::Output { - Halo2Loader::mul(&self.loader, &self, rhs) - } -} - -impl<'a, 'b, C: CurveAffine> AddAssign for Scalar<'a, 'b, C> { - fn add_assign(&mut self, rhs: Self) { - *self = Halo2Loader::add(&self.loader, self, &rhs) - } -} - -impl<'a, 'b, C: CurveAffine> SubAssign for Scalar<'a, 'b, C> { - fn sub_assign(&mut self, rhs: Self) { - *self = Halo2Loader::sub(&self.loader, self, &rhs) - } -} - -impl<'a, 'b, C: CurveAffine> MulAssign for Scalar<'a, 'b, C> { - fn mul_assign(&mut self, rhs: Self) { - *self = Halo2Loader::mul(&self.loader, self, &rhs) - } -} - -impl<'a, 'b, 'c, C: CurveAffine> AddAssign<&'c Self> for Scalar<'a, 'b, C> { - fn add_assign(&mut self, rhs: &'c Self) { - *self = (&self.loader).add(self, rhs) - } -} - -impl<'a, 'b, 'c, C: CurveAffine> SubAssign<&'c Self> for Scalar<'a, 'b, C> { - fn sub_assign(&mut self, rhs: &'c Self) { - *self = (&self.loader).sub(self, rhs) - } -} - -impl<'a, 'b, 'c, C: CurveAffine> MulAssign<&'c Self> for Scalar<'a, 'b, C> { - fn mul_assign(&mut self, rhs: &'c Self) { - *self = (&self.loader).mul(self, rhs) - } -} - -#[derive(Clone)] -pub struct EcPoint<'a, 'b, C: CurveAffine> { - loader: Rc>, - index: usize, - pub value: Value>, -} - -impl<'a, 'b, C: CurveAffine> EcPoint<'a, 'b, C> { - pub fn assigned(&self) -> AssignedEcPoint { - match &self.value { - Value::Constant(constant) => { - let point = FixedEccPoint::from_g1( - constant, - self.loader.field_chip().num_limbs, - 
self.loader.field_chip().limb_bits, - ); - point.assign(self.loader.field_chip(), &mut self.loader.ctx_mut()).unwrap() - } - Value::Assigned(assigned) => assigned.clone(), - } - } -} - -impl<'a, 'b, C: CurveAffine> PartialEq for EcPoint<'a, 'b, C> { - fn eq(&self, other: &Self) -> bool { - self.index == other.index - } -} - -impl<'a, 'b, C: CurveAffine> LoadedEcPoint for EcPoint<'a, 'b, C> -where - C::Base: PrimeField, -{ - type Loader = Rc>; - - fn loader(&self) -> &Self::Loader { - &self.loader - } - - fn multi_scalar_multiplication( - pairs: impl IntoIterator, Self)>, - ) -> Self { - let pairs = pairs.into_iter().collect_vec(); - let loader = &pairs[0].0.loader; - - let mut sum_constants = None; - - let (mut non_scaled, fixed, scaled) = pairs.iter().fold( - (Vec::new(), Vec::new(), Vec::new()), - |(mut non_scaled, mut fixed, mut scaled), (scalar, ec_point)| { - if matches!(scalar.value, Value::Constant(constant) if constant == C::Scalar::one()) - { - non_scaled.push(ec_point.assigned()); - } else { - match &ec_point.value { - Value::Constant(constant_pt) => { - if let Value::Constant(constant_scalar) = scalar.value { - let prod = (constant_pt.clone() * constant_scalar).to_affine(); - sum_constants = if let Some(sum) = sum_constants { - Some(C::Curve::to_affine(&(sum + prod))) - } else { - Some(prod) - }; - } - fixed.push((constant_pt.clone(), scalar.assigned())); - } - Value::Assigned(assigned_pt) => { - scaled.push((assigned_pt.clone(), scalar.assigned())); - } - } - } - (non_scaled, fixed, scaled) - }, - ); - if let Some(sum) = sum_constants { - non_scaled.push(loader.assign_const_ec_point(sum).assigned()); - } - - let mut sum = None; - if !scaled.is_empty() { - sum = loader - .ecc_chip - .multi_scalar_mult::( - &mut loader.ctx_mut(), - &scaled.iter().map(|pair| pair.0.clone()).collect(), - &scaled.into_iter().map(|pair| vec![pair.1]).collect(), - ::NUM_BITS as usize, - 4, - ) - .ok(); - } - if !non_scaled.is_empty() || !fixed.is_empty() { - let rand_point = loader.ecc_chip.load_random_point::(&mut loader.ctx_mut()).unwrap(); - let mut acc = if let Some(prev) = sum { - loader - .ecc_chip - .add_unequal(&mut loader.ctx_mut(), &prev, &rand_point, true) - .unwrap() - } else { - rand_point.clone() - }; - for point in non_scaled.into_iter() { - acc = - loader.ecc_chip.add_unequal(&mut loader.ctx_mut(), &acc, &point, true).unwrap(); - } - for (constant_point, scalar) in fixed.iter() { - if constant_point.is_identity().into() { - continue; - } - let fixed_point = FixedEccPoint::from_g1( - constant_point, - loader.field_chip().num_limbs, - loader.field_chip().limb_bits, - ); - let fixed_msm = loader - .ecc_chip - .fixed_base_scalar_mult( - &mut loader.ctx_mut(), - &fixed_point, - &vec![scalar.clone()], - C::Scalar::NUM_BITS as usize, - 4, - ) - .expect("fixed msms should not fail"); - acc = loader - .ecc_chip - .add_unequal(&mut loader.ctx_mut(), &acc, &fixed_msm, true) - .unwrap(); - } - acc = loader - .ecc_chip - .sub_unequal(&mut loader.ctx_mut(), &acc, &rand_point, true) - .unwrap(); - sum = Some(acc); - } - loader.ec_point(sum.unwrap()) - } -} - -impl<'a, 'b, C: CurveAffine> Debug for EcPoint<'a, 'b, C> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("EcPoint").field("assigned", &self.assigned()).finish() - } -} - -impl<'a, 'b, C: CurveAffine> Add for EcPoint<'a, 'b, C> { - type Output = Self; - - fn add(self, _: Self) -> Self::Output { - todo!() - } -} - -impl<'a, 'b, C: CurveAffine> Sub for EcPoint<'a, 'b, C> { - type Output = Self; - - fn sub(self, _: 
Self) -> Self::Output { - todo!() - } -} - -impl<'a, 'b, C: CurveAffine> Neg for EcPoint<'a, 'b, C> { - type Output = Self; - - fn neg(self) -> Self::Output { - todo!() - } -} - -impl<'a, 'b, 'c, C: CurveAffine> Add<&'c Self> for EcPoint<'a, 'b, C> { - type Output = Self; - - fn add(self, rhs: &'c Self) -> Self::Output { - self + rhs.clone() - } -} - -impl<'a, 'b, 'c, C: CurveAffine> Sub<&'c Self> for EcPoint<'a, 'b, C> { - type Output = Self; - - fn sub(self, rhs: &'c Self) -> Self::Output { - self - rhs.clone() - } -} - -impl<'a, 'b, C: CurveAffine> AddAssign for EcPoint<'a, 'b, C> { - fn add_assign(&mut self, rhs: Self) { - *self = self.clone() + rhs - } -} - -impl<'a, 'b, C: CurveAffine> SubAssign for EcPoint<'a, 'b, C> { - fn sub_assign(&mut self, rhs: Self) { - *self = self.clone() - rhs - } -} - -impl<'a, 'b, 'c, C: CurveAffine> AddAssign<&'c Self> for EcPoint<'a, 'b, C> { - fn add_assign(&mut self, rhs: &'c Self) { - *self = self.clone() + rhs - } -} - -impl<'a, 'b, 'c, C: CurveAffine> SubAssign<&'c Self> for EcPoint<'a, 'b, C> { - fn sub_assign(&mut self, rhs: &'c Self) { - *self = self.clone() - rhs - } -} - -impl<'a, 'b, C: CurveAffine> ScalarLoader for Rc> { - type LoadedScalar = Scalar<'a, 'b, C>; - - fn load_const(&self, value: &C::Scalar) -> Scalar<'a, 'b, C> { - self.scalar(Value::Constant(*value)) - } - - fn assert_eq( - &self, - annotation: &str, - lhs: &Self::LoadedScalar, - rhs: &Self::LoadedScalar, - ) -> Result<(), crate::Error> { - match (&lhs.value, &rhs.value) { - (Value::Constant(lhs), Value::Constant(rhs)) => { - assert_eq!(*lhs, *rhs); - } - _ => { - let loader = lhs.loader(); - loader - .gate() - .assert_equal(&mut loader.ctx_mut(), &lhs.to_quantum(), &rhs.to_quantum()) - .expect(annotation); - } - } - Ok(()) - } - - fn sum_with_coeff_and_constant( - &self, - values: &[(C::Scalar, &Self::LoadedScalar)], - constant: C::Scalar, - ) -> Self::LoadedScalar { - let mut a = Vec::with_capacity(values.len() + 1); - let mut b = Vec::with_capacity(values.len() + 1); - if constant != C::Scalar::zero() { - a.push(Constant(C::Scalar::one())); - b.push(Constant(constant)); - } - a.extend(values.iter().map(|(_, a)| match &a.value { - Value::Constant(constant) => Constant(*constant), - Value::Assigned(assigned) => Existing(assigned), - })); - b.extend(values.iter().map(|(c, _)| Constant(*c))); - let (_, _, sum) = self.gate().inner_product(&mut self.ctx_mut(), &a, &b).unwrap(); - - self.scalar(Value::Assigned(sum)) - } - - fn sum_products_with_coeff_and_constant( - &self, - values: &[(C::Scalar, &Self::LoadedScalar, &Self::LoadedScalar)], - constant: C::Scalar, - ) -> Self::LoadedScalar { - let mut prods = Vec::with_capacity(values.len()); - for (c, a, b) in values { - let a = match &a.value { - Value::Assigned(assigned) => Existing(assigned), - Value::Constant(constant) => Constant(*constant), - }; - let b = match &b.value { - Value::Assigned(assigned) => Existing(assigned), - Value::Constant(constant) => Constant(*constant), - }; - prods.push((*c, a, b)); - } - let output = self - .gate() - .sum_products_with_coeff_and_var(&mut self.ctx_mut(), &prods[..], &Constant(constant)) - .unwrap(); - self.scalar(Value::Assigned(output)) - } -} - -impl<'a, 'b, C: CurveAffine> EcPointLoader for Rc> { - type LoadedEcPoint = EcPoint<'a, 'b, C>; - - fn ec_point_load_const(&self, ec_point: &C) -> EcPoint<'a, 'b, C> { - self.assign_const_ec_point(*ec_point) - } - - fn ec_point_assert_eq( - &self, - annotation: &str, - lhs: &Self::LoadedEcPoint, - rhs: &Self::LoadedEcPoint, - ) -> Result<(), 
crate::Error> { - let loader = lhs.loader(); - match (&lhs.value, &rhs.value) { - (Value::Constant(lhs), Value::Constant(rhs)) => { - assert_eq!(*lhs, *rhs); - } - _ => { - loader - .ecc_chip - .assert_equal(&mut loader.ctx_mut(), &lhs.assigned(), &rhs.assigned()) - .expect(annotation); - } - } - Ok(()) - } -} - -impl<'a, 'b, C: CurveAffine> Loader for Rc> {} diff --git a/src/loader/halo2/poseidon_chip.rs b/src/loader/halo2/poseidon_chip.rs deleted file mode 100644 index 3ac90fe7..00000000 --- a/src/loader/halo2/poseidon_chip.rs +++ /dev/null @@ -1,184 +0,0 @@ -use std::marker::PhantomData; - -use crate::loader::{LoadedScalar, ScalarLoader}; -use crate::util::arithmetic::{FieldExt, PrimeField}; -// taken from https://github.com/scroll-tech/halo2-snark-aggregator/tree/main/halo2-snark-aggregator-api/src/hash -use poseidon::{SparseMDSMatrix, Spec, State}; - -struct PoseidonState, const T: usize, const RATE: usize> { - s: [L; T], - _marker: PhantomData, -} - -impl, const T: usize, const RATE: usize> - PoseidonState -{ - fn x_power5_with_constant(x: &L, constant: &F) -> L { - let x2 = x.clone() * x; - let x4 = x2.clone() * x2; - LoadedScalar::mul_add_constant(&x, &x4, constant) - } - - fn sbox_full(&mut self, constants: &[F; T]) { - for (x, constant) in self.s.iter_mut().zip(constants.iter()) { - *x = Self::x_power5_with_constant(x, constant); - } - } - - fn sbox_part(&mut self, constant: &F) { - let x = &mut self.s[0]; - *x = Self::x_power5_with_constant(x, constant); - } - - fn absorb_with_pre_constants(&mut self, inputs: Vec, pre_constants: &[F; T]) { - assert!(inputs.len() < T); - let offset = inputs.len() + 1; - - self.s[0] = L::Loader::sum_with_const( - self.s[0].loader(), - &self.s[..1].iter().collect::>()[..], - pre_constants[0], - ); - - for ((x, constant), input) in - self.s.iter_mut().skip(1).zip(pre_constants.iter().skip(1)).zip(inputs.iter()) - { - *x = L::Loader::sum_with_const(x.loader(), &[x, input], *constant); - } - - for (i, (x, constant)) in - self.s.iter_mut().skip(offset).zip(pre_constants.iter().skip(offset)).enumerate() - { - *x = L::Loader::sum_with_const( - x.loader(), - &[x], - if i == 0 { F::one() + constant } else { *constant }, - ); - } - } - - fn apply_mds(&mut self, mds: &[[F; T]; T]) { - let res = mds - .iter() - .map(|row| { - let a = self - .s - .iter() - .zip(row.iter()) - .map(|(e, word)| (*word, e.clone())) - .collect::>(); - - L::Loader::sum_with_coeff( - a[0].1.loader(), - &a.iter().map(|(c, b)| (*c, b)).collect::>()[..], - ) - }) - .collect::>(); - - self.s = res.try_into().unwrap(); - } - - fn apply_sparse_mds(&mut self, mds: &SparseMDSMatrix) { - let a = self - .s - .iter() - .zip(mds.row().iter()) - .map(|(e, word)| (*word, e.clone())) - .collect::>(); - - let mut res = vec![L::Loader::sum_with_coeff( - a[0].1.loader(), - &a.iter().map(|(c, b)| (*c, b)).collect::>()[..], - )]; - - for (e, x) in mds.col_hat().iter().zip(self.s.iter().skip(1)) { - res.push(L::Loader::sum_with_coeff(x.loader(), &[(*e, &self.s[0]), (F::one(), x)])); - } - - for (x, new_x) in self.s.iter_mut().zip(res.into_iter()) { - *x = new_x - } - } -} - -pub struct PoseidonChip< - F: PrimeField + FieldExt, - L: LoadedScalar, - const T: usize, - const RATE: usize, -> { - state: PoseidonState, - spec: Spec, - absorbing: Vec, -} - -impl, const T: usize, const RATE: usize> - PoseidonChip -{ - pub fn new(loader: L::Loader, r_f: usize, r_p: usize) -> Self { - let init_state = State::::default() - .words() - .iter() - .map(|x| loader.load_const(x)) - .collect::>(); - - Self { - spec: 
Spec::new(r_f, r_p), - state: PoseidonState { s: init_state.try_into().unwrap(), _marker: PhantomData }, - absorbing: Vec::new(), - } - } - - pub fn update(&mut self, elements: &[L]) { - self.absorbing.extend_from_slice(elements); - } - - pub fn squeeze(&mut self) -> L { - let mut input_elements = vec![]; - input_elements.append(&mut self.absorbing); - - let mut padding_offset = 0; - - for chunk in input_elements.chunks(RATE) { - padding_offset = RATE - chunk.len(); - self.permutation(chunk.to_vec()); - } - - if padding_offset == 0 { - self.permutation(vec![]); - } - - self.state.s[1].clone() - } - - fn permutation(&mut self, inputs: Vec) { - let r_f = self.spec.r_f() / 2; - let mds = &self.spec.mds_matrices().mds().rows(); - - let constants = &self.spec.constants().start(); - self.state.absorb_with_pre_constants(inputs, &constants[0]); - for constants in constants.iter().skip(1).take(r_f - 1) { - self.state.sbox_full(constants); - self.state.apply_mds(mds); - } - - let pre_sparse_mds = &self.spec.mds_matrices().pre_sparse_mds().rows(); - self.state.sbox_full(constants.last().unwrap()); - self.state.apply_mds(&pre_sparse_mds); - - let sparse_matrices = &self.spec.mds_matrices().sparse_matrices(); - let constants = &self.spec.constants().partial(); - for (constant, sparse_mds) in constants.iter().zip(sparse_matrices.iter()) { - self.state.sbox_part(constant); - self.state.apply_sparse_mds(sparse_mds); - } - - let constants = &self.spec.constants().end(); - for constants in constants.iter() { - self.state.sbox_full(constants); - self.state.apply_mds(mds); - } - self.state.sbox_full(&[F::zero(); T]); - self.state.apply_mds(mds); - } -} diff --git a/src/loader/halo2/test/circuit.rs b/src/loader/halo2/test/circuit.rs deleted file mode 100644 index c480006d..00000000 --- a/src/loader/halo2/test/circuit.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod standard; diff --git a/src/loader/halo2/test/circuit/standard.rs b/src/loader/halo2/test/circuit/standard.rs deleted file mode 100644 index 90f30f2b..00000000 --- a/src/loader/halo2/test/circuit/standard.rs +++ /dev/null @@ -1,122 +0,0 @@ -use crate::util::arithmetic::FieldExt; -use halo2_proofs::{ - circuit::{floor_planner::V1, Layouter, Value}, - plonk::{Advice, Circuit, Column, ConstraintSystem, Error, Fixed, Instance}, - poly::Rotation, -}; -use rand::RngCore; - -#[allow(dead_code)] -#[derive(Clone)] -pub struct StandardPlonkConfig { - a: Column, - b: Column, - c: Column, - q_a: Column, - q_b: Column, - q_c: Column, - q_ab: Column, - constant: Column, - instance: Column, -} - -impl StandardPlonkConfig { - pub fn configure(meta: &mut ConstraintSystem) -> Self { - let [a, b, c] = [(); 3].map(|_| meta.advice_column()); - let [q_a, q_b, q_c, q_ab, constant] = [(); 5].map(|_| meta.fixed_column()); - let instance = meta.instance_column(); - - [a, b, c].map(|column| meta.enable_equality(column)); - - meta.create_gate( - "q_a·a + q_b·b + q_c·c + q_ab·a·b + constant + instance = 0", - |meta| { - let [a, b, c] = [a, b, c].map(|column| meta.query_advice(column, Rotation::cur())); - let [q_a, q_b, q_c, q_ab, constant] = [q_a, q_b, q_c, q_ab, constant] - .map(|column| meta.query_fixed(column, Rotation::cur())); - let instance = meta.query_instance(instance, Rotation::cur()); - Some( - q_a * a.clone() - + q_b * b.clone() - + q_c * c - + q_ab * a * b - + constant - + instance, - ) - }, - ); - - StandardPlonkConfig { - a, - b, - c, - q_a, - q_b, - q_c, - q_ab, - constant, - instance, - } - } -} - -#[derive(Clone, Default)] -pub struct StandardPlonk(F); - -impl 
StandardPlonk { - pub fn rand(mut rng: R) -> Self { - Self(F::from(rng.next_u32() as u64)) - } - - pub fn instances(&self) -> Vec> { - vec![vec![self.0]] - } -} - -impl Circuit for StandardPlonk { - type Config = StandardPlonkConfig; - type FloorPlanner = V1; - - fn without_witnesses(&self) -> Self { - Self::default() - } - - fn configure(meta: &mut ConstraintSystem) -> Self::Config { - meta.set_minimum_degree(4); - StandardPlonkConfig::configure(meta) - } - - fn synthesize( - &self, - config: Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), Error> { - layouter.assign_region( - || "", - |mut region| { - region.assign_advice(|| "", config.a, 0, || Value::known(self.0))?; - region.assign_fixed(|| "", config.q_a, 0, || Value::known(-F::one()))?; - - region.assign_advice(|| "", config.a, 1, || Value::known(-F::from(5)))?; - for (column, idx) in [ - config.q_a, - config.q_b, - config.q_c, - config.q_ab, - config.constant, - ] - .iter() - .zip(1..) - { - region.assign_fixed(|| "", *column, 1, || Value::known(F::from(idx)))?; - } - - let a = region.assign_advice(|| "", config.a, 2, || Value::known(F::one()))?; - a.copy_advice(|| "", &mut region, config.b, 3)?; - a.copy_advice(|| "", &mut region, config.c, 4)?; - - Ok(()) - }, - ) - } -} diff --git a/src/pcs/kzg/accumulator.rs b/src/pcs/kzg/accumulator.rs deleted file mode 100644 index 6ceb70fa..00000000 --- a/src/pcs/kzg/accumulator.rs +++ /dev/null @@ -1,169 +0,0 @@ -use crate::{loader::Loader, util::arithmetic::CurveAffine}; -use std::fmt::Debug; - -#[derive(Clone, Debug)] -pub struct KzgAccumulator -where - C: CurveAffine, - L: Loader, -{ - pub lhs: L::LoadedEcPoint, - pub rhs: L::LoadedEcPoint, -} - -impl KzgAccumulator -where - C: CurveAffine, - L: Loader, -{ - pub fn new(lhs: L::LoadedEcPoint, rhs: L::LoadedEcPoint) -> Self { - Self { lhs, rhs } - } -} - -/// `AccumulatorEncoding` that encodes `Accumulator` into limbs. -/// -/// Since in circuit everything are in scalar field, but `Accumulator` might contain base field elements, so we split them into limbs. -/// The const generic `LIMBS` and `BITS` respectively represents how many limbs -/// a base field element are split into and how many bits each limbs could have. 
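The doc comment above explains why accumulator coordinates travel as limbs: the circuit only works over the scalar field, so each base-field element is split into `LIMBS` chunks of `BITS` bits (the crate's `fe_from_limbs` helper, used in the `native` module below, reverses that split). A minimal standalone sketch of the split/recombine step, assuming only the `num-bigint` crate — illustration only, not the crate's implementation:

```rust
// Standalone sketch (not the crate's implementation): split a big integer into
// `limbs` chunks of `bits` bits each, then recombine them. For a ~254-bit
// base-field element, 3 limbs of 88 bits are enough, for example.
use num_bigint::BigUint;

fn to_limbs(x: &BigUint, limbs: usize, bits: usize) -> Vec<BigUint> {
    let mask = (BigUint::from(1u8) << bits) - BigUint::from(1u8);
    (0..limbs).map(|i| (x >> (i * bits)) & &mask).collect()
}

fn from_limbs(limbs: &[BigUint], bits: usize) -> BigUint {
    // Recombine from the most significant limb downwards.
    limbs.iter().rev().fold(BigUint::from(0u8), |acc, limb| (acc << bits) + limb)
}

fn main() {
    let x = (BigUint::from(1u8) << 253usize) + BigUint::from(12345u32);
    let limbs = to_limbs(&x, 3, 88);
    assert_eq!(from_limbs(&limbs, 88), x);
}
```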
-#[derive(Clone, Debug)] -pub struct LimbsEncoding; - -mod native { - use crate::{ - loader::native::NativeLoader, - pcs::{ - kzg::{KzgAccumulator, LimbsEncoding}, - AccumulatorEncoding, PolynomialCommitmentScheme, - }, - util::{ - arithmetic::{fe_from_limbs, CurveAffine}, - Itertools, - }, - Error, - }; - - impl AccumulatorEncoding - for LimbsEncoding - where - C: CurveAffine, - PCS: PolynomialCommitmentScheme< - C, - NativeLoader, - Accumulator = KzgAccumulator, - >, - { - fn from_repr(limbs: Vec) -> Result { - assert_eq!(limbs.len(), 4 * LIMBS); - - let [lhs_x, lhs_y, rhs_x, rhs_y]: [_; 4] = limbs - .chunks(LIMBS) - .into_iter() - .map(|limbs| fe_from_limbs::<_, _, LIMBS, BITS>(limbs.try_into().unwrap())) - .collect_vec() - .try_into() - .unwrap(); - let accumulator = KzgAccumulator::new( - C::from_xy(lhs_x, lhs_y).unwrap(), - C::from_xy(rhs_x, rhs_y).unwrap(), - ); - - Ok(accumulator) - } - } -} - -#[cfg(feature = "loader_evm")] -mod evm { - use crate::{ - loader::evm::{EvmLoader, Scalar}, - pcs::{ - kzg::{KzgAccumulator, LimbsEncoding}, - AccumulatorEncoding, PolynomialCommitmentScheme, - }, - util::{ - arithmetic::{CurveAffine, PrimeField}, - Itertools, - }, - Error, - }; - use std::rc::Rc; - - impl AccumulatorEncoding, PCS> - for LimbsEncoding - where - C: CurveAffine, - C::Scalar: PrimeField, - PCS: PolynomialCommitmentScheme< - C, - Rc, - Accumulator = KzgAccumulator>, - >, - { - fn from_repr(limbs: Vec) -> Result { - assert_eq!(limbs.len(), 4 * LIMBS); - - let loader = limbs[0].loader(); - - let [lhs_x, lhs_y, rhs_x, rhs_y]: [[_; LIMBS]; 4] = limbs - .chunks(LIMBS) - .into_iter() - .map(|limbs| limbs.to_vec().try_into().unwrap()) - .collect_vec() - .try_into() - .unwrap(); - let accumulator = KzgAccumulator::new( - loader.ec_point_from_limbs::(lhs_x, lhs_y), - loader.ec_point_from_limbs::(rhs_x, rhs_y), - ); - - Ok(accumulator) - } - } -} - -#[cfg(feature = "loader_halo2")] -mod halo2 { - use crate::{ - loader::halo2::{Halo2Loader, Scalar}, - loader::LoadedScalar, - pcs::{ - kzg::{KzgAccumulator, LimbsEncoding}, - AccumulatorEncoding, PolynomialCommitmentScheme, - }, - util::{arithmetic::CurveAffine, Itertools}, - Error, - }; - use std::rc::Rc; - - impl<'a, 'b, C, PCS, const LIMBS: usize, const BITS: usize> - AccumulatorEncoding>, PCS> for LimbsEncoding - where - C: CurveAffine, - PCS: PolynomialCommitmentScheme< - C, - Rc>, - Accumulator = KzgAccumulator>>, - >, - { - fn from_repr(limbs: Vec>) -> Result { - assert_eq!(limbs.len(), 4 * LIMBS); - - let loader = limbs[0].loader(); - - let assigned_limbs = limbs.iter().map(|limb| limb.assigned()).collect_vec(); - let [lhs, rhs] = [&assigned_limbs[..2 * LIMBS], &assigned_limbs[2 * LIMBS..]].map( - |assigned_limbs| { - loader.assign_ec_point_from_limbs( - assigned_limbs[..LIMBS].to_vec(), - assigned_limbs[LIMBS..2 * LIMBS].to_vec(), - ) - }, - ); - - let accumulator = KzgAccumulator::new(lhs, rhs); - - Ok(accumulator) - } - } -} diff --git a/src/system.rs b/src/system.rs deleted file mode 100644 index 5d5aa99c..00000000 --- a/src/system.rs +++ /dev/null @@ -1,2 +0,0 @@ -#[cfg(feature = "system_halo2")] -pub mod halo2; diff --git a/src/system/halo2/test/README.md b/src/system/halo2/test/README.md deleted file mode 100644 index 88becc84..00000000 --- a/src/system/halo2/test/README.md +++ /dev/null @@ -1,19 +0,0 @@ -In `plonk-verifier` root directory: - -1. Create `params` folder. Do not reuse params generated from other versions of `halo2_proofs` for now. - -2. Create `configs/verify_circuit.config`. - -3. 
Create `src/system/halo2/test/data` directory. Then run - -For single evm circuit verification: - -``` -cargo test --release -- --nocapture system::halo2::test::kzg::halo2::zkevm::test_shplonk_bench_evm_circuit --exact -``` - -For evm circuit + state circuit aggregation: - -``` -cargo test --release -- --nocapture system::halo2::test::kzg::halo2::zkevm::test_shplonk_bench_evm_and_state --exact -``` diff --git a/src/system/halo2/transcript/halo2.rs b/src/system/halo2/transcript/halo2.rs deleted file mode 100644 index a72d4284..00000000 --- a/src/system/halo2/transcript/halo2.rs +++ /dev/null @@ -1,424 +0,0 @@ -use crate::{ - loader::{ - halo2::{ - loader::{AssignedEcPoint, EcPoint, Halo2Loader, Scalar, Value}, - poseidon_chip::PoseidonChip, - }, - native::NativeLoader, - Loader, - }, - util::{ - arithmetic::{Coordinates, CurveAffine, PrimeField}, - transcript::{Transcript, TranscriptRead, TranscriptWrite}, - }, - Error, -}; -use ::poseidon::Poseidon; -use halo2_base::utils::{biguint_to_fe, fe_to_biguint}; -use halo2_curves::group::GroupEncoding; -use halo2_proofs::{circuit, transcript::EncodedChallenge}; -use std::{ - io::{self, Read, Write}, - marker::PhantomData, - rc::Rc, - slice::from_ref, -}; - -pub struct PoseidonTranscript< - C: CurveAffine, - L: Loader, - S, - B, - const T: usize, - const RATE: usize, - const R_F: usize, - const R_P: usize, -> { - loader: L, - stream: S, - buf: B, - _marker: PhantomData, -} - -impl< - 'a, - 'b, - R: Read, - C: CurveAffine, - const T: usize, - const RATE: usize, - const R_F: usize, - const R_P: usize, - > - PoseidonTranscript< - C, - Rc>, - circuit::Value, - PoseidonChip, T, RATE>, - T, - RATE, - R_F, - R_P, - > -{ - pub fn new(loader: &Rc>, stream: circuit::Value) -> Self { - Self { - loader: loader.clone(), - stream, - buf: PoseidonChip::new(loader.clone(), R_F, R_P), - _marker: PhantomData, - } - } - - fn encode_point(&self, v: &AssignedEcPoint) -> Vec> { - let x_native = v.x.native.clone(); - let y_native = v.y.native.clone(); - [x_native, y_native] - .into_iter() - .map(|x| (&self.loader).scalar(Value::Assigned(x))) - .collect() - } -} - -impl< - 'a, - 'b, - R: Read, - C: CurveAffine, - const T: usize, - const RATE: usize, - const R_F: usize, - const R_P: usize, - > Transcript>> - for PoseidonTranscript< - C, - Rc>, - circuit::Value, - PoseidonChip, T, RATE>, - T, - RATE, - R_F, - R_P, - > -{ - fn loader(&self) -> &Rc> { - &self.loader - } - - fn squeeze_challenge(&mut self) -> Scalar<'a, 'b, C> { - self.buf.squeeze() - } - - fn common_scalar(&mut self, scalar: &Scalar<'a, 'b, C>) -> Result<(), Error> { - self.buf.update(from_ref(scalar)); - Ok(()) - } - - fn common_ec_point(&mut self, ec_point: &EcPoint<'a, 'b, C>) -> Result<(), Error> { - self.buf.update(&self.encode_point(&ec_point.assigned())[..]); - Ok(()) - } -} - -impl< - 'a, - 'b, - R: Read, - C: CurveAffine, - const T: usize, - const RATE: usize, - const R_F: usize, - const R_P: usize, - > TranscriptRead>> - for PoseidonTranscript< - C, - Rc>, - circuit::Value, - PoseidonChip, T, RATE>, - T, - RATE, - R_F, - R_P, - > -{ - fn read_scalar(&mut self) -> Result, Error> { - let scalar = self.stream.as_mut().and_then(|stream| { - let mut data = ::Repr::default(); - if stream.read_exact(data.as_mut()).is_err() { - return circuit::Value::unknown(); - } - Option::::from(C::Scalar::from_repr(data)) - .map(circuit::Value::known) - .unwrap_or_else(circuit::Value::unknown) - }); - let scalar = self.loader.assign_scalar(scalar); - self.common_scalar(&scalar)?; - Ok(scalar) - } - - fn 
read_ec_point(&mut self) -> Result, Error> { - let ec_point = self.stream.as_mut().and_then(|stream| { - let mut compressed = C::Repr::default(); - if stream.read_exact(compressed.as_mut()).is_err() { - return circuit::Value::unknown(); - } - Option::::from(C::from_bytes(&compressed)) - .map(circuit::Value::known) - .unwrap_or_else(circuit::Value::unknown) - }); - let ec_point = self.loader.assign_ec_point(ec_point); - self.common_ec_point(&ec_point)?; - Ok(ec_point) - } -} - -impl - PoseidonTranscript, T, RATE, R_F, R_P> -{ - pub fn new(stream: S) -> Self { - Self { loader: NativeLoader, stream, buf: Poseidon::new(R_F, R_P), _marker: PhantomData } - } -} - -impl - Transcript - for PoseidonTranscript, T, RATE, R_F, R_P> -{ - fn loader(&self) -> &NativeLoader { - &self.loader - } - - fn squeeze_challenge(&mut self) -> C::Scalar { - self.buf.squeeze() - } - - fn common_scalar(&mut self, scalar: &C::Scalar) -> Result<(), Error> { - self.buf.update(&[*scalar]); - Ok(()) - } - - fn common_ec_point(&mut self, ec_point: &C) -> Result<(), Error> { - let coords: Coordinates = Option::from(ec_point.coordinates()).ok_or_else(|| { - Error::Transcript( - io::ErrorKind::Other, - "Cannot write points at infinity to the transcript".to_string(), - ) - })?; - let x = biguint_to_fe(&fe_to_biguint(coords.x())); - let y = biguint_to_fe(&fe_to_biguint(coords.y())); - self.buf.update(&[x, y]); - Ok(()) - } -} - -impl< - C: CurveAffine, - R: Read, - const T: usize, - const RATE: usize, - const R_F: usize, - const R_P: usize, - > TranscriptRead - for PoseidonTranscript, T, RATE, R_F, R_P> -{ - fn read_scalar(&mut self) -> Result { - let mut data = ::Repr::default(); - self.stream - .read_exact(data.as_mut()) - .map_err(|err| Error::Transcript(err.kind(), err.to_string()))?; - let scalar = C::Scalar::from_repr_vartime(data).ok_or_else(|| { - Error::Transcript(io::ErrorKind::Other, "Invalid scalar encoding in proof".to_string()) - })?; - self.common_scalar(&scalar)?; - Ok(scalar) - } - - fn read_ec_point(&mut self) -> Result { - let mut data = C::Repr::default(); - self.stream - .read_exact(data.as_mut()) - .map_err(|err| Error::Transcript(err.kind(), err.to_string()))?; - let ec_point = - Option::::from(::from_bytes(&data)).ok_or_else(|| { - Error::Transcript( - io::ErrorKind::Other, - "Invalid elliptic curve point encoding in proof".to_string(), - ) - })?; - self.common_ec_point(&ec_point)?; - Ok(ec_point) - } -} - -impl< - C: CurveAffine, - W: Write, - const T: usize, - const RATE: usize, - const R_F: usize, - const R_P: usize, - > PoseidonTranscript, T, RATE, R_F, R_P> -{ - pub fn stream_mut(&mut self) -> &mut W { - &mut self.stream - } - - pub fn finalize(self) -> W { - self.stream - } -} - -impl< - C: CurveAffine, - W: Write, - const T: usize, - const RATE: usize, - const R_F: usize, - const R_P: usize, - > TranscriptWrite - for PoseidonTranscript, T, RATE, R_F, R_P> -{ - fn write_scalar(&mut self, scalar: C::Scalar) -> Result<(), Error> { - self.common_scalar(&scalar)?; - let data = scalar.to_repr(); - self.stream_mut().write_all(data.as_ref()).map_err(|err| { - Error::Transcript(err.kind(), "Failed to write scalar to transcript".to_string()) - }) - } - - fn write_ec_point(&mut self, ec_point: C) -> Result<(), Error> { - self.common_ec_point(&ec_point)?; - let data = ec_point.to_bytes(); - self.stream_mut().write_all(data.as_ref()).map_err(|err| { - Error::Transcript( - err.kind(), - "Failed to write elliptic curve to transcript".to_string(), - ) - }) - } -} - -pub struct ChallengeScalar(C::Scalar); - -impl 
EncodedChallenge for ChallengeScalar { - type Input = C::Scalar; - - fn new(challenge_input: &C::Scalar) -> Self { - ChallengeScalar(*challenge_input) - } - - fn get_scalar(&self) -> C::Scalar { - self.0 - } -} - -impl - halo2_proofs::transcript::Transcript> - for PoseidonTranscript, T, RATE, R_F, R_P> -{ - fn squeeze_challenge(&mut self) -> ChallengeScalar { - ChallengeScalar::new(&Transcript::squeeze_challenge(self)) - } - - fn common_point(&mut self, ec_point: C) -> io::Result<()> { - match Transcript::common_ec_point(self, &ec_point) { - Err(Error::Transcript(kind, msg)) => Err(io::Error::new(kind, msg)), - Err(_) => unreachable!(), - _ => Ok(()), - } - } - - fn common_scalar(&mut self, scalar: C::Scalar) -> io::Result<()> { - match Transcript::common_scalar(self, &scalar) { - Err(Error::Transcript(kind, msg)) => Err(io::Error::new(kind, msg)), - Err(_) => unreachable!(), - _ => Ok(()), - } - } -} - -impl< - C: CurveAffine, - R: Read, - const T: usize, - const RATE: usize, - const R_F: usize, - const R_P: usize, - > halo2_proofs::transcript::TranscriptRead> - for PoseidonTranscript, T, RATE, R_F, R_P> -{ - fn read_point(&mut self) -> io::Result { - match TranscriptRead::read_ec_point(self) { - Err(Error::Transcript(kind, msg)) => Err(io::Error::new(kind, msg)), - Err(_) => unreachable!(), - Ok(value) => Ok(value), - } - } - - fn read_scalar(&mut self) -> io::Result { - match TranscriptRead::read_scalar(self) { - Err(Error::Transcript(kind, msg)) => Err(io::Error::new(kind, msg)), - Err(_) => unreachable!(), - Ok(value) => Ok(value), - } - } -} - -impl< - C: CurveAffine, - R: Read, - const T: usize, - const RATE: usize, - const R_F: usize, - const R_P: usize, - > halo2_proofs::transcript::TranscriptReadBuffer> - for PoseidonTranscript, T, RATE, R_F, R_P> -{ - fn init(reader: R) -> Self { - Self::new(reader) - } -} - -impl< - C: CurveAffine, - W: Write, - const T: usize, - const RATE: usize, - const R_F: usize, - const R_P: usize, - > halo2_proofs::transcript::TranscriptWrite> - for PoseidonTranscript, T, RATE, R_F, R_P> -{ - fn write_point(&mut self, ec_point: C) -> io::Result<()> { - halo2_proofs::transcript::Transcript::>::common_point( - self, ec_point, - )?; - let data = ec_point.to_bytes(); - self.stream_mut().write_all(data.as_ref()) - } - - fn write_scalar(&mut self, scalar: C::Scalar) -> io::Result<()> { - halo2_proofs::transcript::Transcript::>::common_scalar(self, scalar)?; - let data = scalar.to_repr(); - self.stream_mut().write_all(data.as_ref()) - } -} - -impl< - C: CurveAffine, - W: Write, - const T: usize, - const RATE: usize, - const R_F: usize, - const R_P: usize, - > halo2_proofs::transcript::TranscriptWriterBuffer> - for PoseidonTranscript, T, RATE, R_F, R_P> -{ - fn init(writer: W) -> Self { - Self::new(writer) - } - - fn finalize(self) -> W { - self.finalize() - } -} diff --git a/src/util.rs b/src/util.rs deleted file mode 100644 index ec2e450b..00000000 --- a/src/util.rs +++ /dev/null @@ -1,6 +0,0 @@ -pub mod arithmetic; -pub mod msm; -pub mod protocol; -pub mod transcript; - -pub(crate) use itertools::Itertools; diff --git a/src/util/msm.rs b/src/util/msm.rs deleted file mode 100644 index a7a3d45d..00000000 --- a/src/util/msm.rs +++ /dev/null @@ -1,203 +0,0 @@ -use crate::{ - loader::{LoadedEcPoint, Loader}, - util::arithmetic::CurveAffine, -}; -use std::{ - default::Default, - iter::{self, Sum}, - ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}, -}; - -#[derive(Clone, Debug)] -pub struct Msm> { - constant: Option, - scalars: Vec, - bases: Vec, -} - 
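The `Msm` struct just defined is a deferred multi-scalar multiplication: `constant` is the coefficient of a generator supplied at evaluation time, while `scalars` and `bases` are parallel vectors of terms. In the impls that follow, `push` merges a new term (summing scalars when a base repeats), `scale` multiplies every coefficient, and `evaluate` finally performs the MSM. A toy model of that bookkeeping, with plain integers standing in for scalars and curve points (the `ToyMsm` type here is illustrative, not the crate's API):

```rust
// Toy model (integers stand in for scalars and curve points) of the
// deferred-MSM bookkeeping in `Msm`: terms are only combined in `evaluate`,
// `push` merges repeated bases by summing their scalars, and the optional
// `constant` is the coefficient of a generator supplied at evaluation time.
#[derive(Default)]
struct ToyMsm {
    constant: Option<i128>,
    scalars: Vec<i128>,
    bases: Vec<i128>,
}

impl ToyMsm {
    fn push(&mut self, scalar: i128, base: i128) {
        if let Some(pos) = self.bases.iter().position(|b| *b == base) {
            self.scalars[pos] += scalar;
        } else {
            self.scalars.push(scalar);
            self.bases.push(base);
        }
    }

    fn scale(&mut self, factor: i128) {
        if let Some(c) = self.constant.as_mut() {
            *c *= factor;
        }
        for s in self.scalars.iter_mut() {
            *s *= factor;
        }
    }

    // In the additive toy group, "scalar * point" is just integer multiplication.
    fn evaluate(&self, gen: i128) -> i128 {
        self.constant.unwrap_or(0) * gen
            + self.scalars.iter().zip(&self.bases).map(|(s, b)| s * b).sum::<i128>()
    }
}

fn main() {
    let mut msm = ToyMsm { constant: Some(1), ..Default::default() };
    msm.push(2, 10);
    msm.push(3, 7);
    msm.push(5, 10); // same base as before: coefficients merge to 7
    msm.scale(2);
    assert_eq!(msm.evaluate(100), 2 * (1 * 100 + 7 * 10 + 3 * 7));
}
```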
-impl Default for Msm -where - C: CurveAffine, - L: Loader, -{ - fn default() -> Self { - Self { - constant: None, - scalars: Vec::new(), - bases: Vec::new(), - } - } -} - -impl Msm -where - C: CurveAffine, - L: Loader, -{ - pub fn constant(constant: L::LoadedScalar) -> Self { - Msm { - constant: Some(constant), - ..Default::default() - } - } - - pub fn base(base: L::LoadedEcPoint) -> Self { - let one = base.loader().load_one(); - Msm { - scalars: vec![one], - bases: vec![base], - ..Default::default() - } - } - - pub(crate) fn size(&self) -> usize { - self.bases.len() - } - - pub(crate) fn split(mut self) -> (Self, Option) { - let constant = self.constant.take(); - (self, constant) - } - - pub(crate) fn try_into_constant(self) -> Option { - self.bases.is_empty().then(|| self.constant.unwrap()) - } - - pub fn evaluate(self, gen: Option) -> L::LoadedEcPoint { - let gen = gen.map(|gen| { - self.bases - .first() - .unwrap() - .loader() - .ec_point_load_const(&gen) - }); - L::LoadedEcPoint::multi_scalar_multiplication( - iter::empty() - .chain(self.constant.map(|constant| (constant, gen.unwrap()))) - .chain(self.scalars.into_iter().zip(self.bases.into_iter())), - ) - } - - pub fn scale(&mut self, factor: &L::LoadedScalar) { - if let Some(constant) = self.constant.as_mut() { - *constant *= factor; - } - for scalar in self.scalars.iter_mut() { - *scalar *= factor - } - } - - pub fn push(&mut self, scalar: L::LoadedScalar, base: L::LoadedEcPoint) { - if let Some(pos) = self.bases.iter().position(|exist| exist.eq(&base)) { - self.scalars[pos] += scalar; - } else { - self.scalars.push(scalar); - self.bases.push(base); - } - } - - pub fn extend(&mut self, mut other: Self) { - match (self.constant.as_mut(), other.constant.as_ref()) { - (Some(lhs), Some(rhs)) => *lhs += rhs, - (None, Some(_)) => self.constant = other.constant.take(), - _ => {} - }; - for (scalar, base) in other.scalars.into_iter().zip(other.bases) { - self.push(scalar, base); - } - } -} - -impl Add> for Msm -where - C: CurveAffine, - L: Loader, -{ - type Output = Msm; - - fn add(mut self, rhs: Msm) -> Self::Output { - self.extend(rhs); - self - } -} - -impl AddAssign> for Msm -where - C: CurveAffine, - L: Loader, -{ - fn add_assign(&mut self, rhs: Msm) { - self.extend(rhs); - } -} - -impl Sub> for Msm -where - C: CurveAffine, - L: Loader, -{ - type Output = Msm; - - fn sub(mut self, rhs: Msm) -> Self::Output { - self.extend(-rhs); - self - } -} - -impl SubAssign> for Msm -where - C: CurveAffine, - L: Loader, -{ - fn sub_assign(&mut self, rhs: Msm) { - self.extend(-rhs); - } -} - -impl Mul<&L::LoadedScalar> for Msm -where - C: CurveAffine, - L: Loader, -{ - type Output = Msm; - - fn mul(mut self, rhs: &L::LoadedScalar) -> Self::Output { - self.scale(rhs); - self - } -} - -impl MulAssign<&L::LoadedScalar> for Msm -where - C: CurveAffine, - L: Loader, -{ - fn mul_assign(&mut self, rhs: &L::LoadedScalar) { - self.scale(rhs); - } -} - -impl Neg for Msm -where - C: CurveAffine, - L: Loader, -{ - type Output = Msm; - fn neg(mut self) -> Msm { - self.constant = self.constant.map(|constant| -constant); - for scalar in self.scalars.iter_mut() { - *scalar = -scalar.clone(); - } - self - } -} - -impl Sum for Msm -where - C: CurveAffine, - L: Loader, -{ - fn sum>(iter: I) -> Self { - iter.reduce(|acc, item| acc + item).unwrap_or_default() - } -}
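Earlier in this diff, `PoseidonChip::update` only buffers elements and `squeeze` drives the sponge: the buffer is consumed in chunks of `RATE`, and one extra padding permutation runs when the final chunk is exactly full (or the buffer was empty). A toy driver showing just that scheduling, with the Poseidon permutation replaced by a stub (the `ToySponge` type is illustrative only):

```rust
// Toy sponge driver: reproduces the chunk-by-RATE scheduling of
// `PoseidonChip::squeeze`, with the Poseidon round function stubbed out.
const RATE: usize = 2;

struct ToySponge {
    state: [u64; RATE + 1],
    absorbing: Vec<u64>,
}

impl ToySponge {
    fn new() -> Self {
        Self { state: [0; RATE + 1], absorbing: Vec::new() }
    }

    fn update(&mut self, elements: &[u64]) {
        self.absorbing.extend_from_slice(elements);
    }

    fn permutation(&mut self, inputs: &[u64]) {
        // Stand-in for the real permutation: absorb inputs, then mix the state.
        for (s, x) in self.state.iter_mut().skip(1).zip(inputs) {
            *s = s.wrapping_add(*x);
        }
        self.state.rotate_left(1);
    }

    fn squeeze(&mut self) -> u64 {
        let inputs = std::mem::take(&mut self.absorbing);
        // Starts true so an empty buffer still gets one permutation,
        // mirroring the `padding_offset == 0` check in the original.
        let mut needs_padding = true;
        for chunk in inputs.chunks(RATE) {
            needs_padding = chunk.len() == RATE;
            self.permutation(chunk);
        }
        if needs_padding {
            self.permutation(&[]);
        }
        self.state[1]
    }
}

fn main() {
    let mut sponge = ToySponge::new();
    sponge.update(&[1, 2, 3]);
    println!("toy digest: {}", sponge.squeeze());
}
```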