Skip to content

Commit

Permalink
Merge pull request #344 from clearmatics/input-hashing-prepare
Browse files Browse the repository at this point in the history
Changes in preparation for proof input hasing
  • Loading branch information
AntoineRondelet authored Jan 29, 2021
2 parents e6a41a5 + d94a36a commit 024e9dc
Show file tree
Hide file tree
Showing 24 changed files with 527 additions and 180 deletions.
6 changes: 1 addition & 5 deletions client/test_commands/scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ def wait_for_tx_update_mk_tree(
def get_mix_parameters_components(
zeth_client: MixerClient,
prover_client: ProverClient,
zksnark: IZKSnarkProvider,
mk_tree: MerkleTree,
sender_ownership_keypair: OwnershipKeyPair,
inputs: List[Tuple[int, ZethNote]],
Expand All @@ -79,8 +78,7 @@ def get_mix_parameters_components(
compute_h_sig_cb)
prover_inputs, signing_keypair = zeth_client.create_prover_inputs(
mix_call_desc)
ext_proof_proto = prover_client.get_proof(prover_inputs)
ext_proof = zksnark.extended_proof_from_proto(ext_proof_proto)
ext_proof = prover_client.get_proof(prover_inputs)
return (
prover_inputs.js_outputs[0],
prover_inputs.js_outputs[1],
Expand Down Expand Up @@ -262,7 +260,6 @@ def compute_h_sig_attack_nf(
get_mix_parameters_components(
zeth_client,
prover_client,
zksnark,
mk_tree,
keystore["Charlie"].ownership_keypair(), # sender
[input1, input2],
Expand Down Expand Up @@ -370,7 +367,6 @@ def charlie_corrupt_bob_deposit(
get_mix_parameters_components(
zeth_client,
prover_client,
zksnark,
mk_tree,
keystore["Bob"].ownership_keypair(),
[input1, input2],
Expand Down
44 changes: 44 additions & 0 deletions client/tests/test_input_hasher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright (c) 2015-2021 Clearmatics Technologies Ltd
#
# SPDX-License-Identifier: LGPL-3.0+

from zeth.core.mimc import MiMC7, MiMC31
from zeth.core.input_hasher import InputHasher
from unittest import TestCase

DUMMY_INPUT_VALUES = [-1, 0, 1]


class TestInputHasher(TestCase):

def test_input_hasher_simple(self) -> None:
# Some very simple cases
mimc = MiMC7()
input_hasher = InputHasher(mimc, 7)
self.assertEqual(mimc.hash_int(7, 0), input_hasher.hash([]))
self.assertEqual(
mimc.hash_int(mimc.hash_int(7, 1), 1), input_hasher.hash([1]))
self.assertEqual(
mimc.hash_int(
mimc.hash_int(
mimc.hash_int(7, 1), 2),
2),
input_hasher.hash([1, 2]))

def test_input_hasher_mimc7(self) -> None:
mimc = MiMC7()
input_hasher = InputHasher(mimc)
values = [x % mimc.prime for x in DUMMY_INPUT_VALUES]
# pylint:disable=line-too-long
expect = 5568471640435576440988459485125198359192118312228711462978763973844457667180 # noqa
# pylint:enable=line-too-long
self.assertEqual(expect, input_hasher.hash(values))

def test_input_hasher_mimc31(self) -> None:
mimc = MiMC31()
input_hasher = InputHasher(mimc)
values = [x % mimc.prime for x in DUMMY_INPUT_VALUES]
# pylint: disable=line-too-long
expect = 1029772481427643815119825324071277815354972734622711297984795198139876181749 # noqa
# pylint: enable=line-too-long
self.assertEqual(expect, input_hasher.hash(values))
8 changes: 1 addition & 7 deletions client/zeth/cli/zeth_get_verification_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from click import command, option, Context, pass_context

from zeth.core.zksnark import get_zksnark_provider
from zeth.cli.utils import create_prover_client
import json
from typing import Optional
Expand All @@ -21,12 +20,7 @@ def get_verification_key(ctx: Context, vk_out: Optional[str]) -> None:
# Get the VK (proto object)
client_ctx = ctx.obj
prover_client = create_prover_client(client_ctx)
vk_proto = prover_client.get_verification_key()

# Get a zksnark provider and convert the VK to json
zksnark_name = prover_client.get_configuration().zksnark_name
zksnark = get_zksnark_provider(zksnark_name)
vk = zksnark.verification_key_from_proto(vk_proto)
vk = prover_client.get_verification_key()
vk_json = vk.to_json_dict()

# Write the json to stdout or a file
Expand Down
36 changes: 36 additions & 0 deletions client/zeth/core/input_hasher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright (c) 2015-2021 Clearmatics Technologies Ltd
#
# SPDX-License-Identifier: LGPL-3.0+

from zeth.core.mimc import MiMCBase
from typing import List


# Default seed, generated as:
# zeth.core.mimc._keccak_256(
# zeth.core.mimc._str_to_bytes("clearmatics_hash_seed"))
DEFAULT_IV_UINT256 = \
13196537064117388418196223856311987714388543839552400408340921397545324034315


class InputHasher:
"""
Note that this is currently experimental code. Hash a series of field
elements via the Merkle-Damgard construction on a MiMC compression
function. Note that since this function only accepts whole numbers of
scalar field elements, there is no ambiguity w.r.t to padding and we could
technically omit the finalization step. It has been kept for now, to allow
time for further consideration, and in case the form of the hasher changes
(e.g. in case we want to be able to hash arbitrary bit strings in the
future).
"""
def __init__(self, compression_fn: MiMCBase, iv: int = DEFAULT_IV_UINT256):
assert compression_fn.prime < (2 << 256)
self._compression_fn = compression_fn
self._iv = iv % compression_fn.prime

def hash(self, values: List[int]) -> int:
current = self._iv
for m in values:
current = self._compression_fn.hash_int(current, m)
return self._compression_fn.hash_int(current, len(values))
11 changes: 9 additions & 2 deletions client/zeth/core/mimc.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,15 @@ def hash(self, left: bytes, right: bytes) -> bytes:
"""
x = int.from_bytes(left, byteorder='big') % self.prime
y = int.from_bytes(right, byteorder='big') % self.prime
result = (self.encrypt(x, y) + x + y) % self.prime
return result.to_bytes(32, byteorder='big')
return self.hash_int(x, y).to_bytes(32, byteorder='big')

def hash_int(self, x: int, y: int) -> int:
"""
Similar to hash, but use field elements directly.
"""
assert x < self.prime
assert y < self.prime
return (self.encrypt(x, y) + x + y) % self.prime

@abstractmethod
def mimc_round(self, message: int, key: int, rc: int) -> int:
Expand Down
12 changes: 4 additions & 8 deletions client/zeth/core/mixer_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,13 +276,12 @@ def deploy(
Deploy Zeth contracts.
"""
prover_config = prover_client.get_configuration()
zksnark = get_zksnark_provider(prover_config.zksnark_name)
vk_proto = prover_client.get_verification_key()
pp = prover_config.pairing_parameters
vk = zksnark.verification_key_from_proto(vk_proto)
vk = prover_client.get_verification_key()
deploy_gas = deploy_gas or constants.DEPLOYMENT_GAS_WEI

contracts_dir = get_contracts_dir()
zksnark = get_zksnark_provider(prover_config.zksnark_name)
pp = prover_config.pairing_parameters
mixer_name = zksnark.get_contract_name(pp)
mixer_src = os.path.join(contracts_dir, mixer_name + ".sol")

Expand Down Expand Up @@ -581,11 +580,8 @@ def create_mix_parameters_and_signing_key(
prover_inputs, signing_keypair = MixerClient.create_prover_inputs(
mix_call_desc)

zksnark = get_zksnark_provider(self.prover_config.zksnark_name)

# Query the prover_server for the related proof
ext_proof_proto = prover_client.get_proof(prover_inputs)
ext_proof = zksnark.extended_proof_from_proto(ext_proof_proto)
ext_proof = prover_client.get_proof(prover_inputs)

# Create the final MixParameters object
mix_params = self.create_mix_parameters_from_proof(
Expand Down
24 changes: 17 additions & 7 deletions client/zeth/core/prover_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
# SPDX-License-Identifier: LGPL-3.0+

from __future__ import annotations
from .pairing import PairingParameters, pairing_parameters_from_proto
from zeth.core.zksnark import IZKSnarkProvider, get_zksnark_provider, \
IVerificationKey, ExtendedProof
from zeth.core.pairing import PairingParameters, pairing_parameters_from_proto
from zeth.api.zeth_messages_pb2 import ProofInputs
from zeth.api.snark_messages_pb2 import VerificationKey, ExtendedProof
from zeth.api import prover_pb2 # type: ignore
from zeth.api import prover_pb2_grpc # type: ignore
import grpc # type: ignore
Expand Down Expand Up @@ -93,14 +94,22 @@ def get_configuration(self) -> ProverConfiguration:

return self.prover_config

def get_verification_key(self) -> VerificationKey:
def get_zksnark_provider(self) -> IZKSnarkProvider:
"""
Get the appropriate zksnark provider, based on the server configuration.
"""
config = self.get_configuration()
return get_zksnark_provider(config.zksnark_name)

def get_verification_key(self) -> IVerificationKey:
"""
Fetch the verification key from the proving service
"""
with grpc.insecure_channel(self.endpoint) as channel:
stub = prover_pb2_grpc.ProverStub(channel) # type: ignore
verificationkey = stub.GetVerificationKey(_make_empty_message())
return verificationkey
vk_proto = stub.GetVerificationKey(_make_empty_message())
zksnark = self.get_zksnark_provider()
return zksnark.verification_key_from_proto(vk_proto)

def get_proof(
self,
Expand All @@ -111,8 +120,9 @@ def get_proof(
with grpc.insecure_channel(self.endpoint) as channel:
stub = prover_pb2_grpc.ProverStub(channel) # type: ignore
print("-------------- Get the proof --------------")
proof = stub.Prove(proof_inputs)
return proof
extproof_proto = stub.Prove(proof_inputs)
zksnark = self.get_zksnark_provider()
return zksnark.extended_proof_from_proto(extproof_proto)


def _make_empty_message() -> empty_pb2.Empty:
Expand Down
2 changes: 1 addition & 1 deletion depends/libsnark
23 changes: 14 additions & 9 deletions libzeth/circuits/circuit_wrapper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,26 +25,27 @@ template<
size_t TreeDepth>
class circuit_wrapper
{
private:
std::shared_ptr<joinsplit_gadget<
libff::Fr<ppT>,
public:
using Field = libff::Fr<ppT>;
// Both `joinsplit` and `joinsplit_gadget` are already used in the
// namespace.
using joinsplit_type = joinsplit_gadget<
Field,
HashT,
HashTreeT,
NumInputs,
NumOutputs,
TreeDepth>>
joinsplit_g;

public:
using Field = libff::Fr<ppT>;
TreeDepth>;

circuit_wrapper();
circuit_wrapper(const circuit_wrapper &) = delete;
circuit_wrapper &operator=(const circuit_wrapper &) = delete;

// Generate the trusted setup
typename snarkT::keypair generate_trusted_setup() const;

// Retrieve the constraint system (intended for debugging purposes).
libsnark::protoboard<Field> get_constraint_system() const;
const libsnark::protoboard<Field> &get_constraint_system() const;

// Generate a proof and returns an extended proof
extended_proof<ppT, snarkT> prove(
Expand All @@ -56,6 +57,10 @@ class circuit_wrapper
const bits256 &h_sig_in,
const bits256 &phi_in,
const typename snarkT::proving_key &proving_key) const;

private:
libsnark::protoboard<Field> pb;
std::shared_ptr<joinsplit_type> joinsplit;
};

} // namespace libzeth
Expand Down
26 changes: 10 additions & 16 deletions libzeth/circuits/circuit_wrapper.tcc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@ circuit_wrapper<
NumOutputs,
TreeDepth>::circuit_wrapper()
{
// Joinsplit gadget internally allocates its public data first.
// TODO: joinsplit_gadget should be refactored to be properly composable.
joinsplit = std::make_shared<joinsplit_type>(pb);
const size_t num_public_elements = joinsplit->get_num_public_elements();
pb.set_input_sizes(num_public_elements);

// Generate constraints
joinsplit->generate_r1cs_constraints();
}

template<
Expand All @@ -46,11 +54,6 @@ typename snarkT::keypair circuit_wrapper<
NumOutputs,
TreeDepth>::generate_trusted_setup() const
{
libsnark::protoboard<Field> pb;
joinsplit_gadget<Field, HashT, HashTreeT, NumInputs, NumOutputs, TreeDepth>
g(pb);
g.generate_r1cs_constraints();

// Generate a verification and proving key (trusted setup) and write them
// in a file
return snarkT::generate_setup(pb);
Expand All @@ -64,7 +67,7 @@ template<
size_t NumInputs,
size_t NumOutputs,
size_t TreeDepth>
libsnark::protoboard<libff::Fr<ppT>> circuit_wrapper<
const libsnark::protoboard<libff::Fr<ppT>> &circuit_wrapper<
HashT,
HashTreeT,
ppT,
Expand All @@ -73,10 +76,6 @@ libsnark::protoboard<libff::Fr<ppT>> circuit_wrapper<
NumOutputs,
TreeDepth>::get_constraint_system() const
{
libsnark::protoboard<Field> pb;
joinsplit_gadget<Field, HashT, HashTreeT, NumInputs, NumOutputs, TreeDepth>
g(pb);
g.generate_r1cs_constraints();
return pb;
}

Expand Down Expand Up @@ -128,12 +127,7 @@ extended_proof<ppT, snarkT> circuit_wrapper<
throw std::invalid_argument("invalid joinsplit balance");
}

libsnark::protoboard<Field> pb;

joinsplit_gadget<Field, HashT, HashTreeT, NumInputs, NumOutputs, TreeDepth>
g(pb);
g.generate_r1cs_constraints();
g.generate_r1cs_witness(
joinsplit->generate_r1cs_witness(
root, inputs, outputs, vpub_in, vpub_out, h_sig_in, phi_in);

bool is_valid_witness = pb.is_satisfied();
Expand Down
Loading

0 comments on commit 024e9dc

Please sign in to comment.