Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changes in preparation for proof input hasing #344

Merged
merged 14 commits into from
Jan 29, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions client/test_commands/scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ def wait_for_tx_update_mk_tree(
def get_mix_parameters_components(
zeth_client: MixerClient,
prover_client: ProverClient,
zksnark: IZKSnarkProvider,
mk_tree: MerkleTree,
sender_ownership_keypair: OwnershipKeyPair,
inputs: List[Tuple[int, ZethNote]],
Expand All @@ -79,8 +78,7 @@ def get_mix_parameters_components(
compute_h_sig_cb)
prover_inputs, signing_keypair = zeth_client.create_prover_inputs(
mix_call_desc)
ext_proof_proto = prover_client.get_proof(prover_inputs)
ext_proof = zksnark.extended_proof_from_proto(ext_proof_proto)
ext_proof = prover_client.get_proof(prover_inputs)
return (
prover_inputs.js_outputs[0],
prover_inputs.js_outputs[1],
Expand Down Expand Up @@ -262,7 +260,6 @@ def compute_h_sig_attack_nf(
get_mix_parameters_components(
zeth_client,
prover_client,
zksnark,
mk_tree,
keystore["Charlie"].ownership_keypair(), # sender
[input1, input2],
Expand Down Expand Up @@ -370,7 +367,6 @@ def charlie_corrupt_bob_deposit(
get_mix_parameters_components(
zeth_client,
prover_client,
zksnark,
mk_tree,
keystore["Bob"].ownership_keypair(),
[input1, input2],
Expand Down
44 changes: 44 additions & 0 deletions client/tests/test_input_hasher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright (c) 2015-2021 Clearmatics Technologies Ltd
#
# SPDX-License-Identifier: LGPL-3.0+

from zeth.core.mimc import MiMC7, MiMC31
from zeth.core.input_hasher import InputHasher
from unittest import TestCase

DUMMY_INPUT_VALUES = [-1, 0, 1]


class TestInputHasher(TestCase):

def test_input_hasher_simple(self) -> None:
# Some very simple cases
mimc = MiMC7()
input_hasher = InputHasher(mimc, 7)
self.assertEqual(mimc.hash_int(7, 0), input_hasher.hash([]))
self.assertEqual(
mimc.hash_int(mimc.hash_int(7, 1), 1), input_hasher.hash([1]))
self.assertEqual(
mimc.hash_int(
mimc.hash_int(
mimc.hash_int(7, 1), 2),
2),
input_hasher.hash([1, 2]))

def test_input_hasher_mimc7(self) -> None:
mimc = MiMC7()
input_hasher = InputHasher(mimc)
values = [x % mimc.prime for x in DUMMY_INPUT_VALUES]
# pylint:disable=line-too-long
expect = 5568471640435576440988459485125198359192118312228711462978763973844457667180 # noqa
# pylint:enable=line-too-long
self.assertEqual(expect, input_hasher.hash(values))

def test_input_hasher_mimc31(self) -> None:
mimc = MiMC31()
input_hasher = InputHasher(mimc)
values = [x % mimc.prime for x in DUMMY_INPUT_VALUES]
# pylint: disable=line-too-long
expect = 1029772481427643815119825324071277815354972734622711297984795198139876181749 # noqa
# pylint: enable=line-too-long
self.assertEqual(expect, input_hasher.hash(values))
8 changes: 1 addition & 7 deletions client/zeth/cli/zeth_get_verification_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from click import command, option, Context, pass_context

from zeth.core.zksnark import get_zksnark_provider
from zeth.cli.utils import create_prover_client
import json
from typing import Optional
Expand All @@ -21,12 +20,7 @@ def get_verification_key(ctx: Context, vk_out: Optional[str]) -> None:
# Get the VK (proto object)
client_ctx = ctx.obj
prover_client = create_prover_client(client_ctx)
vk_proto = prover_client.get_verification_key()

# Get a zksnark provider and convert the VK to json
zksnark_name = prover_client.get_configuration().zksnark_name
zksnark = get_zksnark_provider(zksnark_name)
vk = zksnark.verification_key_from_proto(vk_proto)
vk = prover_client.get_verification_key()
vk_json = vk.to_json_dict()

# Write the json to stdout or a file
Expand Down
36 changes: 36 additions & 0 deletions client/zeth/core/input_hasher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright (c) 2015-2021 Clearmatics Technologies Ltd
#
# SPDX-License-Identifier: LGPL-3.0+

from zeth.core.mimc import MiMCBase
from typing import List


# Default seed, generated as:
# zeth.core.mimc._keccak_256(
# zeth.core.mimc._str_to_bytes("clearmatics_hash_seed"))
DEFAULT_IV_UINT256 = \
13196537064117388418196223856311987714388543839552400408340921397545324034315


class InputHasher:
"""
Note that this is currently experimental code. Hash a series of field
elements via the Merkle-Damgard construction on a MiMC compression
function. Note that since this function only accepts whole numbers of
scalar field elements, there is no ambiguity w.r.t to padding and we could
technically omit the finalization step. It has been kept for now, to allow
time for further consideration, and in case the form of the hasher changes
(e.g. in case we want to be able to hash arbitrary bit strings in the
future).
"""
def __init__(self, compression_fn: MiMCBase, iv: int = DEFAULT_IV_UINT256):
assert compression_fn.prime < (2 << 256)
self._compression_fn = compression_fn
self._iv = iv % compression_fn.prime

def hash(self, values: List[int]) -> int:
current = self._iv
for m in values:
current = self._compression_fn.hash_int(current, m)
return self._compression_fn.hash_int(current, len(values))
11 changes: 9 additions & 2 deletions client/zeth/core/mimc.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,15 @@ def hash(self, left: bytes, right: bytes) -> bytes:
"""
x = int.from_bytes(left, byteorder='big') % self.prime
y = int.from_bytes(right, byteorder='big') % self.prime
result = (self.encrypt(x, y) + x + y) % self.prime
return result.to_bytes(32, byteorder='big')
return self.hash_int(x, y).to_bytes(32, byteorder='big')

def hash_int(self, x: int, y: int) -> int:
"""
Similar to hash, but use field elements directly.
"""
assert x < self.prime
assert y < self.prime
return (self.encrypt(x, y) + x + y) % self.prime

@abstractmethod
def mimc_round(self, message: int, key: int, rc: int) -> int:
Expand Down
12 changes: 4 additions & 8 deletions client/zeth/core/mixer_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,13 +276,12 @@ def deploy(
Deploy Zeth contracts.
"""
prover_config = prover_client.get_configuration()
zksnark = get_zksnark_provider(prover_config.zksnark_name)
vk_proto = prover_client.get_verification_key()
pp = prover_config.pairing_parameters
vk = zksnark.verification_key_from_proto(vk_proto)
vk = prover_client.get_verification_key()
deploy_gas = deploy_gas or constants.DEPLOYMENT_GAS_WEI

contracts_dir = get_contracts_dir()
zksnark = get_zksnark_provider(prover_config.zksnark_name)
pp = prover_config.pairing_parameters
mixer_name = zksnark.get_contract_name(pp)
mixer_src = os.path.join(contracts_dir, mixer_name + ".sol")

Expand Down Expand Up @@ -581,11 +580,8 @@ def create_mix_parameters_and_signing_key(
prover_inputs, signing_keypair = MixerClient.create_prover_inputs(
mix_call_desc)

zksnark = get_zksnark_provider(self.prover_config.zksnark_name)

# Query the prover_server for the related proof
ext_proof_proto = prover_client.get_proof(prover_inputs)
ext_proof = zksnark.extended_proof_from_proto(ext_proof_proto)
ext_proof = prover_client.get_proof(prover_inputs)

# Create the final MixParameters object
mix_params = self.create_mix_parameters_from_proof(
Expand Down
24 changes: 17 additions & 7 deletions client/zeth/core/prover_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
# SPDX-License-Identifier: LGPL-3.0+

from __future__ import annotations
from .pairing import PairingParameters, pairing_parameters_from_proto
from zeth.core.zksnark import IZKSnarkProvider, get_zksnark_provider, \
IVerificationKey, ExtendedProof
from zeth.core.pairing import PairingParameters, pairing_parameters_from_proto
from zeth.api.zeth_messages_pb2 import ProofInputs
from zeth.api.snark_messages_pb2 import VerificationKey, ExtendedProof
from zeth.api import prover_pb2 # type: ignore
from zeth.api import prover_pb2_grpc # type: ignore
import grpc # type: ignore
Expand Down Expand Up @@ -93,14 +94,22 @@ def get_configuration(self) -> ProverConfiguration:

return self.prover_config

def get_verification_key(self) -> VerificationKey:
def get_zksnark_provider(self) -> IZKSnarkProvider:
"""
Get the appropriate zksnark provider, based on the server configuration.
"""
config = self.get_configuration()
return get_zksnark_provider(config.zksnark_name)

def get_verification_key(self) -> IVerificationKey:
"""
Fetch the verification key from the proving service
"""
with grpc.insecure_channel(self.endpoint) as channel:
stub = prover_pb2_grpc.ProverStub(channel) # type: ignore
verificationkey = stub.GetVerificationKey(_make_empty_message())
return verificationkey
vk_proto = stub.GetVerificationKey(_make_empty_message())
zksnark = self.get_zksnark_provider()
return zksnark.verification_key_from_proto(vk_proto)

def get_proof(
self,
Expand All @@ -111,8 +120,9 @@ def get_proof(
with grpc.insecure_channel(self.endpoint) as channel:
stub = prover_pb2_grpc.ProverStub(channel) # type: ignore
print("-------------- Get the proof --------------")
proof = stub.Prove(proof_inputs)
return proof
extproof_proto = stub.Prove(proof_inputs)
zksnark = self.get_zksnark_provider()
return zksnark.extended_proof_from_proto(extproof_proto)


def _make_empty_message() -> empty_pb2.Empty:
Expand Down
2 changes: 1 addition & 1 deletion depends/libsnark
23 changes: 14 additions & 9 deletions libzeth/circuits/circuit_wrapper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,26 +25,27 @@ template<
size_t TreeDepth>
class circuit_wrapper
{
private:
std::shared_ptr<joinsplit_gadget<
libff::Fr<ppT>,
public:
using Field = libff::Fr<ppT>;
// Both `joinsplit` and `joinsplit_gadget` are already used in the
// namespace.
using joinsplit_type = joinsplit_gadget<
Field,
HashT,
HashTreeT,
NumInputs,
NumOutputs,
TreeDepth>>
joinsplit_g;

public:
using Field = libff::Fr<ppT>;
TreeDepth>;
AntoineRondelet marked this conversation as resolved.
Show resolved Hide resolved

circuit_wrapper();
circuit_wrapper(const circuit_wrapper &) = delete;
circuit_wrapper &operator=(const circuit_wrapper &) = delete;

// Generate the trusted setup
typename snarkT::keypair generate_trusted_setup() const;

// Retrieve the constraint system (intended for debugging purposes).
libsnark::protoboard<Field> get_constraint_system() const;
const libsnark::protoboard<Field> &get_constraint_system() const;

// Generate a proof and returns an extended proof
extended_proof<ppT, snarkT> prove(
Expand All @@ -56,6 +57,10 @@ class circuit_wrapper
const bits256 &h_sig_in,
const bits256 &phi_in,
const typename snarkT::proving_key &proving_key) const;

private:
libsnark::protoboard<Field> pb;
std::shared_ptr<joinsplit_type> joinsplit;
};

} // namespace libzeth
Expand Down
26 changes: 10 additions & 16 deletions libzeth/circuits/circuit_wrapper.tcc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@ circuit_wrapper<
NumOutputs,
TreeDepth>::circuit_wrapper()
{
// Joinsplit gadget internally allocates its public data first.
// TODO: joinsplit_gadget should be refactored to be properly composable.
AntoineRondelet marked this conversation as resolved.
Show resolved Hide resolved
joinsplit = std::make_shared<joinsplit_type>(pb);
const size_t num_public_elements = joinsplit->get_num_public_elements();
pb.set_input_sizes(num_public_elements);

// Generate constraints
joinsplit->generate_r1cs_constraints();
}

template<
Expand All @@ -46,11 +54,6 @@ typename snarkT::keypair circuit_wrapper<
NumOutputs,
TreeDepth>::generate_trusted_setup() const
{
libsnark::protoboard<Field> pb;
joinsplit_gadget<Field, HashT, HashTreeT, NumInputs, NumOutputs, TreeDepth>
g(pb);
g.generate_r1cs_constraints();

// Generate a verification and proving key (trusted setup) and write them
// in a file
return snarkT::generate_setup(pb);
Expand All @@ -64,7 +67,7 @@ template<
size_t NumInputs,
size_t NumOutputs,
size_t TreeDepth>
libsnark::protoboard<libff::Fr<ppT>> circuit_wrapper<
const libsnark::protoboard<libff::Fr<ppT>> &circuit_wrapper<
HashT,
HashTreeT,
ppT,
Expand All @@ -73,10 +76,6 @@ libsnark::protoboard<libff::Fr<ppT>> circuit_wrapper<
NumOutputs,
TreeDepth>::get_constraint_system() const
{
libsnark::protoboard<Field> pb;
joinsplit_gadget<Field, HashT, HashTreeT, NumInputs, NumOutputs, TreeDepth>
g(pb);
g.generate_r1cs_constraints();
return pb;
}

Expand Down Expand Up @@ -128,12 +127,7 @@ extended_proof<ppT, snarkT> circuit_wrapper<
throw std::invalid_argument("invalid joinsplit balance");
}

libsnark::protoboard<Field> pb;

joinsplit_gadget<Field, HashT, HashTreeT, NumInputs, NumOutputs, TreeDepth>
g(pb);
g.generate_r1cs_constraints();
g.generate_r1cs_witness(
joinsplit->generate_r1cs_witness(
root, inputs, outputs, vpub_in, vpub_out, h_sig_in, phi_in);

bool is_valid_witness = pb.is_satisfied();
Expand Down
Loading