Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(bb): load proving key #3525

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions barretenberg/acir_tests/flows/all_cmds.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ FLAGS="-c $CRS_PATH $VFLAG"

# Test we can perform the proof/verify flow.
$BIN gates $FLAGS $BFLAG > /dev/null
$BIN prove -o proof $FLAGS $BFLAG
$BIN write_vk -o vk $FLAGS $BFLAG
$BIN write_pk -o pk $FLAGS $BFLAG
$BIN prove -o proof -i pk $FLAGS $BFLAG
$BIN write_vk -o vk $FLAGS $BFLAG
$BIN verify -k vk -p proof $FLAGS

# Check supplemental functions.
Expand Down
13 changes: 13 additions & 0 deletions barretenberg/acir_tests/flows/write_pk_prove_then_verify.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#!/bin/sh
set -eu

VFLAG=${VERBOSE:+-v}
BFLAG="-b ./target/acir.gz"
FLAGS="-c $CRS_PATH $VFLAG"

# Test we can perform the proof/verify flow.
# This ensures we test independent pk construction through real/garbage witness data paths.
$BIN write_pk -o pk $FLAGS $BFLAG
$BIN prove -o proof -i pk $FLAGS $BFLAG
$BIN write_vk -o vk $FLAGS $BFLAG
$BIN verify -k vk -p proof $FLAGS
16 changes: 14 additions & 2 deletions barretenberg/cpp/src/barretenberg/bb/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,18 +106,27 @@ bool proveAndVerify(const std::string& bytecodePath, const std::string& witnessP
* - Filesystem: The proof is written to the path specified by outputPath
*
* @param bytecodePath Path to the file containing the serialized circuit
// * @param pkPath Path to the file containing the serialized proving key
* @param witnessPath Path to the file containing the serialized witness
* @param recursive Whether to use recursive proof generation of non-recursive
* @param outputPath Path to write the proof to
*/
void prove(const std::string& bytecodePath,
const std::string& pk_path,
const std::string& witnessPath,
bool recursive,
const std::string& outputPath)
{
auto constraint_system = get_constraint_system(bytecodePath);
auto witness = get_witness(witnessPath);
auto acir_composer = init(constraint_system);

if (!pk_path.empty()) {
info("loading proving key from: ", pk_path);
auto pk_data = from_buffer<plonk::proving_key_data>(read_file(pk_path));
acir_composer.load_proving_key(std::move(pk_data));
}

auto proof = acir_composer.create_proof(constraint_system, witness, recursive);

if (outputPath == "-") {
Expand Down Expand Up @@ -380,7 +389,7 @@ int main(int argc, char* argv[])
std::string witness_path = get_option(args, "-w", "./target/witness.gz");
std::string proof_path = get_option(args, "-p", "./proofs/proof");
std::string vk_path = get_option(args, "-k", "./target/vk");
std::string pk_path = get_option(args, "-r", "./target/pk");
std::string pk_path = get_option(args, "-i", "./target/pk");
CRS_PATH = get_option(args, "-c", "./crs");
bool recursive = flag_present(args, "-r") || flag_present(args, "--recursive");

Expand All @@ -400,7 +409,10 @@ int main(int argc, char* argv[])
}
if (command == "prove") {
std::string output_path = get_option(args, "-o", "./proofs/proof");
prove(bytecode_path, witness_path, recursive, output_path);

// Overwriting pk, as we want to use a default value of an empty string when we are not *reading* a pk
std::string pk_path = get_option(args, "-i", "");
prove(bytecode_path, pk_path, witness_path, recursive, output_path);
} else if (command == "gates") {
gateCount(bytecode_path);
} else if (command == "verify") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ std::shared_ptr<proof_system::plonk::proving_key> AcirComposer::init_proving_key
return proving_key_;
}

void AcirComposer::load_proving_key(proof_system::plonk::proving_key_data&& data)
{
proving_key_ = std::make_shared<proof_system::plonk::proving_key>(
std::move(data), srs::get_crs_factory()->get_prover_crs(data.circuit_size));
}

std::vector<uint8_t> AcirComposer::create_proof(acir_format::acir_format& constraint_system,
acir_format::WitnessVector& witness,
bool is_recursive)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class AcirComposer {
void create_circuit(acir_format::acir_format& constraint_system);

std::shared_ptr<proof_system::plonk::proving_key> init_proving_key(acir_format::acir_format& constraint_system);
void load_proving_key(proof_system::plonk::proving_key_data&& data);

std::vector<uint8_t> create_proof(acir_format::acir_format& constraint_system,
acir_format::WitnessVector& witness,
Expand Down
8 changes: 8 additions & 0 deletions barretenberg/cpp/src/barretenberg/dsl/acir_proofs/c_bind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,14 @@ WASM_EXPORT void acir_init_proving_key(in_ptr acir_composer_ptr, uint8_t const*
acir_composer->init_proving_key(constraint_system);
}

WASM_EXPORT void acir_load_proving_key(in_ptr acir_composer_ptr, uint8_t const* pk_buf)
{
auto acir_composer = reinterpret_cast<acir_proofs::AcirComposer*>(*acir_composer_ptr);
auto pk_data = from_buffer<plonk::proving_key_data>(pk_buf);

acir_composer->load_proving_key(std::move(pk_data));
}

WASM_EXPORT void acir_create_proof(in_ptr acir_composer_ptr,
uint8_t const* acir_vec,
uint8_t const* witness_vec,
Expand Down
2 changes: 2 additions & 0 deletions barretenberg/cpp/src/barretenberg/dsl/acir_proofs/c_bind.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ WASM_EXPORT void acir_create_circuit(in_ptr acir_composer_ptr,

WASM_EXPORT void acir_init_proving_key(in_ptr acir_composer_ptr, uint8_t const* constraint_system_buf);

WASM_EXPORT void acir_load_proving_key(in_ptr acir_composer_ptr, uint8_t const* pk_buf);

/**
* It would have been nice to just hold onto the constraint_system in the acir_composer, but we can't waste the
* memory. Being able to reuse the underlying Composer would help as well. But, given the situation, we just have
Expand Down
24 changes: 24 additions & 0 deletions barretenberg/ts/src/barretenberg_api/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,18 @@ export class BarretenbergApi {
return;
}

async acirLoadProvingKey(acirComposerPtr: Ptr, pkBuf: Uint8Array): Promise<void> {
const inArgs = [acirComposerPtr, pkBuf].map(serializeBufferable);
const outTypes: OutputType[] = [];
const result = await this.wasm.callWasmExport(
'acir_load_proving_key',
inArgs,
outTypes.map(t => t.SIZE_IN_BYTES),
);
const out = result.map((r, i) => outTypes[i].fromBuffer(r));
return;
}

async acirCreateProof(
acirComposerPtr: Ptr,
constraintSystemBuf: Uint8Array,
Expand Down Expand Up @@ -761,6 +773,18 @@ export class BarretenbergApiSync {
return;
}

acirLoadProvingKey(acirComposerPtr: Ptr, pkBuf: Uint8Array): void {
const inArgs = [acirComposerPtr, pkBuf].map(serializeBufferable);
const outTypes: OutputType[] = [];
const result = this.wasm.callWasmExport(
'acir_load_proving_key',
inArgs,
outTypes.map(t => t.SIZE_IN_BYTES),
);
const out = result.map((r, i) => outTypes[i].fromBuffer(r));
return;
}

acirCreateProof(
acirComposerPtr: Ptr,
constraintSystemBuf: Uint8Array,
Expand Down
12 changes: 10 additions & 2 deletions barretenberg/ts/src/main.ts
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ export async function proveAndVerify(bytecodePath: string, witnessPath: string,

export async function prove(
bytecodePath: string,
pkPath: string,
witnessPath: string,
crsPath: string,
isRecursive: boolean,
Expand All @@ -124,6 +125,12 @@ export async function prove(
debug(`creating proof...`);
const bytecode = getBytecode(bytecodePath);
const witness = getWitness(witnessPath);

if (pkPath != '') {
debug(`loading proving key from ${pkPath}...`);
await api.acirLoadProvingKey(acirComposer, new RawBuffer(readFileSync(pkPath)));
}

const proof = await api.acirCreateProof(acirComposer, bytecode, witness, isRecursive);
debug(`done.`);

Expand Down Expand Up @@ -316,12 +323,13 @@ program
.command('prove')
.description('Generate a proof and write it to a file.')
.option('-b, --bytecode-path <path>', 'Specify the bytecode path', './target/acir.gz')
.option('-i, --proving-key-path <path>', 'Read proving key from file', '')
.option('-w, --witness-path <path>', 'Specify the witness path', './target/witness.gz')
.option('-r, --recursive', 'prove using recursive prover', false)
.option('-o, --output-path <path>', 'Specify the proof output path', './proofs/proof')
.action(async ({ bytecodePath, witnessPath, recursive, outputPath, crsPath }) => {
.action(async ({ bytecodePath, provingKeyPath, witnessPath, recursive, outputPath, crsPath }) => {
handleGlobalOptions();
await prove(bytecodePath, witnessPath, crsPath, recursive, outputPath);
await prove(bytecodePath, provingKeyPath, witnessPath, crsPath, recursive, outputPath);
});

program
Expand Down