Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(bb): add ability to write pk to file or stdout #3335

Merged
merged 7 commits into from
Nov 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions barretenberg/acir_tests/flows/all_cmds.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ FLAGS="-c $CRS_PATH $VFLAG"
$BIN gates $FLAGS $BFLAG > /dev/null
$BIN prove -o proof $FLAGS $BFLAG
$BIN write_vk -o vk $FLAGS $BFLAG
$BIN write_pk -o pk $FLAGS $BFLAG
$BIN verify -k vk -p proof $FLAGS

# Check supplemental functions.
Expand Down
67 changes: 44 additions & 23 deletions barretenberg/cpp/src/barretenberg/bb/main.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "barretenberg/dsl/acir_format/acir_format.hpp"
#include "barretenberg/dsl/types.hpp"
#include "barretenberg/plonk/proof_system/proving_key/serialize.hpp"
#include "config.hpp"
#include "get_bytecode.hpp"
#include "get_crs.hpp"
Expand Down Expand Up @@ -183,7 +184,7 @@ bool verify(const std::string& proof_path, bool recursive, const std::string& vk
* @param bytecodePath Path to the file containing the serialized circuit
* @param outputPath Path to write the verification key to
*/
void writeVk(const std::string& bytecodePath, const std::string& outputPath)
void write_vk(const std::string& bytecodePath, const std::string& outputPath)
{
auto constraint_system = get_constraint_system(bytecodePath);
auto acir_composer = init(constraint_system);
Expand All @@ -199,6 +200,22 @@ void writeVk(const std::string& bytecodePath, const std::string& outputPath)
}
}

void write_pk(const std::string& bytecodePath, const std::string& outputPath)
{
auto constraint_system = get_constraint_system(bytecodePath);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not a big deal but vscode has a refactoring thing to turn auto into the explicit type, can make the C++ easier to follow (compared to typescript, C++ tooling is rougher, so prefer more explicit types)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The rest of the file uses the auto convention, so i am reluctant to make this change. Going forward I will use explicit types

auto acir_composer = init(constraint_system);
auto pk = acir_composer.init_proving_key(constraint_system);
auto serialized_pk = to_buffer(*pk);

if (outputPath == "-") {
writeRawBytesToStdout(serialized_pk);
vinfo("pk written to stdout");
} else {
write_file(outputPath, serialized_pk);
vinfo("pk written to: ", outputPath);
}
}

/**
* @brief Writes a Solidity verifier contract for an ACIR circuit to a file
*
Expand Down Expand Up @@ -253,7 +270,7 @@ void contract(const std::string& output_path, const std::string& vk_path)
* @param vk_path Path to the file containing the serialized verification key
* @param output_path Path to write the proof to
*/
void proofAsFields(const std::string& proof_path, std::string const& vk_path, const std::string& output_path)
void proof_as_fields(const std::string& proof_path, std::string const& vk_path, const std::string& output_path)
{
auto acir_composer = init();
auto vk_data = from_buffer<plonk::verification_key_data>(read_file(vk_path));
Expand Down Expand Up @@ -282,7 +299,7 @@ void proofAsFields(const std::string& proof_path, std::string const& vk_path, co
* @param vk_path Path to the file containing the serialized verification key
* @param output_path Path to write the verification key to
*/
void vkAsFields(const std::string& vk_path, const std::string& output_path)
void vk_as_fields(const std::string& vk_path, const std::string& output_path)
{
auto acir_composer = init();
auto vk_data = from_buffer<plonk::verification_key_data>(read_file(vk_path));
Expand Down Expand Up @@ -311,7 +328,7 @@ void vkAsFields(const std::string& vk_path, const std::string& output_path)
*
* @param output_path Path to write the information to
*/
void acvmInfo(const std::string& output_path)
void acvm_info(const std::string& output_path)
{

const char* jsonData = R"({
Expand All @@ -335,12 +352,12 @@ void acvmInfo(const std::string& output_path)
}
}

bool flagPresent(std::vector<std::string>& args, const std::string& flag)
bool flag_present(std::vector<std::string>& args, const std::string& flag)
{
return std::find(args.begin(), args.end(), flag) != args.end();
}

std::string getOption(std::vector<std::string>& args, const std::string& option, const std::string& defaultValue)
std::string get_option(std::vector<std::string>& args, const std::string& option, const std::string& defaultValue)
{
auto itr = std::find(args.begin(), args.end(), option);
return (itr != args.end() && std::next(itr) != args.end()) ? *(std::next(itr)) : defaultValue;
Expand All @@ -350,7 +367,7 @@ int main(int argc, char* argv[])
{
try {
std::vector<std::string> args(argv + 1, argv + argc);
verbose = flagPresent(args, "-v") || flagPresent(args, "--verbose");
verbose = flag_present(args, "-v") || flag_present(args, "--verbose");

if (args.empty()) {
std::cerr << "No command provided.\n";
Expand All @@ -359,46 +376,50 @@ int main(int argc, char* argv[])

std::string command = args[0];

std::string bytecode_path = getOption(args, "-b", "./target/acir.gz");
std::string witness_path = getOption(args, "-w", "./target/witness.gz");
std::string proof_path = getOption(args, "-p", "./proofs/proof");
std::string vk_path = getOption(args, "-k", "./target/vk");
CRS_PATH = getOption(args, "-c", "./crs");
bool recursive = flagPresent(args, "-r") || flagPresent(args, "--recursive");
std::string bytecode_path = get_option(args, "-b", "./target/acir.gz");
std::string witness_path = get_option(args, "-w", "./target/witness.gz");
std::string proof_path = get_option(args, "-p", "./proofs/proof");
std::string vk_path = get_option(args, "-k", "./target/vk");
std::string pk_path = get_option(args, "-r", "./target/pk");
CRS_PATH = get_option(args, "-c", "./crs");
bool recursive = flag_present(args, "-r") || flag_present(args, "--recursive");

// Skip CRS initialization for any command which doesn't require the CRS.
if (command == "--version") {
writeStringToStdout(BB_VERSION);
return 0;
}
if (command == "info") {
std::string output_path = getOption(args, "-o", "info.json");
acvmInfo(output_path);
std::string output_path = get_option(args, "-o", "info.json");
acvm_info(output_path);
return 0;
}

if (command == "prove_and_verify") {
return proveAndVerify(bytecode_path, witness_path, recursive) ? 0 : 1;
}
if (command == "prove") {
std::string output_path = getOption(args, "-o", "./proofs/proof");
std::string output_path = get_option(args, "-o", "./proofs/proof");
prove(bytecode_path, witness_path, recursive, output_path);
} else if (command == "gates") {
gateCount(bytecode_path);
} else if (command == "verify") {
return verify(proof_path, recursive, vk_path) ? 0 : 1;
} else if (command == "contract") {
std::string output_path = getOption(args, "-o", "./target/contract.sol");
std::string output_path = get_option(args, "-o", "./target/contract.sol");
contract(output_path, vk_path);
} else if (command == "write_vk") {
std::string output_path = getOption(args, "-o", "./target/vk");
writeVk(bytecode_path, output_path);
std::string output_path = get_option(args, "-o", "./target/vk");
write_vk(bytecode_path, output_path);
} else if (command == "write_pk") {
std::string output_path = get_option(args, "-o", "./target/pk");
write_pk(bytecode_path, output_path);
} else if (command == "proof_as_fields") {
std::string output_path = getOption(args, "-o", proof_path + "_fields.json");
proofAsFields(proof_path, vk_path, output_path);
std::string output_path = get_option(args, "-o", proof_path + "_fields.json");
proof_as_fields(proof_path, vk_path, output_path);
} else if (command == "vk_as_fields") {
std::string output_path = getOption(args, "-o", vk_path + "_fields.json");
vkAsFields(vk_path, output_path);
std::string output_path = get_option(args, "-o", vk_path + "_fields.json");
vk_as_fields(vk_path, output_path);
} else {
std::cerr << "Unknown command: " << command << "\n";
return 1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "barretenberg/dsl/acir_format/acir_format.hpp"
#include "barretenberg/dsl/acir_format/recursion_constraint.hpp"
#include "barretenberg/dsl/types.hpp"
#include "barretenberg/plonk/proof_system/proving_key/proving_key.hpp"
#include "barretenberg/plonk/proof_system/proving_key/serialize.hpp"
#include "barretenberg/plonk/proof_system/verification_key/sol_gen.hpp"
#include "barretenberg/plonk/proof_system/verification_key/verification_key.hpp"
Expand All @@ -30,12 +31,14 @@ void AcirComposer::create_circuit(acir_format::acir_format& constraint_system)
vinfo("gates: ", builder_.get_total_circuit_size());
}

void AcirComposer::init_proving_key(acir_format::acir_format& constraint_system)
std::shared_ptr<proof_system::plonk::proving_key> AcirComposer::init_proving_key(
acir_format::acir_format& constraint_system)
{
create_circuit(constraint_system);
acir_format::Composer composer;
vinfo("computing proving key...");
proving_key_ = composer.compute_proving_key(builder_);
return proving_key_;
}

std::vector<uint8_t> AcirComposer::create_proof(acir_format::acir_format& constraint_system,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class AcirComposer {

void create_circuit(acir_format::acir_format& constraint_system);

void init_proving_key(acir_format::acir_format& constraint_system);
std::shared_ptr<proof_system::plonk::proving_key> init_proving_key(acir_format::acir_format& constraint_system);

std::vector<uint8_t> create_proof(acir_format::acir_format& constraint_system,
acir_format::WitnessVector& witness,
Expand Down
10 changes: 10 additions & 0 deletions barretenberg/cpp/src/barretenberg/dsl/acir_proofs/c_bind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "barretenberg/common/serialize.hpp"
#include "barretenberg/common/slab_allocator.hpp"
#include "barretenberg/dsl/acir_format/acir_format.hpp"
#include "barretenberg/plonk/proof_system/proving_key/serialize.hpp"
#include "barretenberg/plonk/proof_system/verification_key/verification_key.hpp"
#include "barretenberg/srs/global_crs.hpp"
#include <cstdint>
Expand Down Expand Up @@ -73,6 +74,15 @@ WASM_EXPORT void acir_get_verification_key(in_ptr acir_composer_ptr, uint8_t** o
*out = to_heap_buffer(to_buffer(*vk));
}

WASM_EXPORT void acir_get_proving_key(in_ptr acir_composer_ptr, uint8_t const* acir_vec, uint8_t** out)
{
auto acir_composer = reinterpret_cast<acir_proofs::AcirComposer*>(*acir_composer_ptr);
auto constraint_system = acir_format::circuit_buf_to_acir_format(from_buffer<std::vector<uint8_t>>(acir_vec));
auto pk = acir_composer->init_proving_key(constraint_system);
// We flatten to a vector<uint8_t> first, as that's how we treat it on the calling side.
*out = to_heap_buffer(to_buffer(*pk));
}

WASM_EXPORT void acir_verify_proof(in_ptr acir_composer_ptr,
uint8_t const* proof_buf,
bool const* is_recursive,
Expand Down
2 changes: 2 additions & 0 deletions barretenberg/cpp/src/barretenberg/dsl/acir_proofs/c_bind.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ WASM_EXPORT void acir_init_verification_key(in_ptr acir_composer_ptr);

WASM_EXPORT void acir_get_verification_key(in_ptr acir_composer_ptr, uint8_t** out);

WASM_EXPORT void acir_get_proving_key(in_ptr acir_composer_ptr, uint8_t const* acir_vec, uint8_t** out);

WASM_EXPORT void acir_verify_proof(in_ptr acir_composer_ptr,
uint8_t const* proof_buf,
bool const* is_recursive,
Expand Down
12 changes: 12 additions & 0 deletions barretenberg/ts/src/barretenberg_api/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,18 @@ export class BarretenbergApi {
return out[0];
}

async acirGetProvingKey(acirComposerPtr: Ptr, constraintSystemBuf: Uint8Array): Promise<Uint8Array> {
const inArgs = [acirComposerPtr, constraintSystemBuf].map(serializeBufferable);
const outTypes: OutputType[] = [BufferDeserializer()];
const result = await this.wasm.callWasmExport(
'acir_get_proving_key',
inArgs,
outTypes.map(t => t.SIZE_IN_BYTES),
);
const out = result.map((r, i) => outTypes[i].fromBuffer(r));
return out[0];
}

async acirVerifyProof(acirComposerPtr: Ptr, proofBuf: Uint8Array, isRecursive: boolean): Promise<boolean> {
const inArgs = [acirComposerPtr, proofBuf, isRecursive].map(serializeBufferable);
const outTypes: OutputType[] = [BoolDeserializer()];
Expand Down
29 changes: 29 additions & 0 deletions barretenberg/ts/src/main.ts
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,25 @@ export async function writeVk(bytecodePath: string, crsPath: string, outputPath:
}
}

export async function writePk(bytecodePath: string, crsPath: string, outputPath: string) {
const { api, acirComposer } = await init(bytecodePath, crsPath);
try {
debug('initing proving key...');
const bytecode = getBytecode(bytecodePath);
const pk = await api.acirGetProvingKey(acirComposer, bytecode);

if (outputPath === '-') {
process.stdout.write(pk);
debug(`pk written to stdout`);
} else {
writeFileSync(outputPath, pk);
debug(`pk written to: ${outputPath}`);
}
} finally {
await api.destroy();
}
}

export async function proofAsFields(proofPath: string, vkPath: string, outputPath: string) {
const { api, acirComposer } = await initLite();

Expand Down Expand Up @@ -347,6 +366,16 @@ program
await writeVk(bytecodePath, crsPath, outputPath);
});

program
.command('write_pk')
.description('Output proving key.')
.option('-b, --bytecode-path <path>', 'Specify the bytecode path', './target/acir.gz')
.requiredOption('-o, --output-path <path>', 'Specify the path to write the key')
.action(async ({ bytecodePath, outputPath, crsPath }) => {
handleGlobalOptions();
await writePk(bytecodePath, crsPath, outputPath);
});

program
.command('proof_as_fields')
.description('Return the proof as fields elements')
Expand Down