Skip to content

Commit

Permalink
feat(avm): class id + contract address
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasRidhuan committed Oct 1, 2024
1 parent 769d979 commit 32228e8
Show file tree
Hide file tree
Showing 18 changed files with 395 additions and 164 deletions.
12 changes: 5 additions & 7 deletions barretenberg/cpp/src/barretenberg/bb/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -939,18 +939,17 @@ void vk_as_fields(const std::string& vk_path, const std::string& output_path)
* @param hints_path Path to the file containing the serialised avm circuit hints
* @param output_path Path (directory) to write the output proof and verification keys
*/
void avm_prove(const std::filesystem::path& bytecode_path,
const std::filesystem::path& calldata_path,
void avm_prove(const std::filesystem::path& calldata_path,
const std::filesystem::path& public_inputs_path,
const std::filesystem::path& hints_path,
const std::filesystem::path& output_path)
{
std::vector<uint8_t> const bytecode = read_file(bytecode_path);
std::vector<fr> const calldata = many_from_buffer<fr>(read_file(calldata_path));
std::vector<fr> const public_inputs_vec = many_from_buffer<fr>(read_file(public_inputs_path));
auto const avm_hints = bb::avm_trace::ExecutionHints::from(read_file(hints_path));

vinfo("bytecode size: ", bytecode.size());
// Using [0] is fine now for the top-level call, but we might need to index by address in future
vinfo("bytecode size: ", avm_hints.all_contract_bytecode[0].bytecode.size());
vinfo("calldata size: ", calldata.size());
vinfo("public_inputs size: ", public_inputs_vec.size());
vinfo("hints.storage_value_hints size: ", avm_hints.storage_value_hints.size());
Expand All @@ -965,7 +964,7 @@ void avm_prove(const std::filesystem::path& bytecode_path,

// Prove execution and return vk
auto const [verification_key, proof] =
AVM_TRACK_TIME_V("prove/all", avm_trace::Execution::prove(bytecode, calldata, public_inputs_vec, avm_hints));
AVM_TRACK_TIME_V("prove/all", avm_trace::Execution::prove(calldata, public_inputs_vec, avm_hints));

std::vector<fr> vk_as_fields = verification_key.to_field_elements();

Expand Down Expand Up @@ -1520,7 +1519,6 @@ int main(int argc, char* argv[])
write_recursion_inputs_honk<UltraFlavor>(bytecode_path, witness_path, output_path);
#ifndef DISABLE_AZTEC_VM
} else if (command == "avm_prove") {
std::filesystem::path avm_bytecode_path = get_option(args, "--avm-bytecode", "./target/avm_bytecode.bin");
std::filesystem::path avm_calldata_path = get_option(args, "--avm-calldata", "./target/avm_calldata.bin");
std::filesystem::path avm_public_inputs_path =
get_option(args, "--avm-public-inputs", "./target/avm_public_inputs.bin");
Expand All @@ -1529,7 +1527,7 @@ int main(int argc, char* argv[])
std::filesystem::path output_path = get_option(args, "-o", "./proofs");
extern std::filesystem::path avm_dump_trace_path;
avm_dump_trace_path = get_option(args, "--avm-dump-trace", "");
avm_prove(avm_bytecode_path, avm_calldata_path, avm_public_inputs_path, avm_hints_path, output_path);
avm_prove(avm_calldata_path, avm_public_inputs_path, avm_hints_path, output_path);
} else if (command == "avm_verify") {
return avm_verify(proof_path, vk_path) ? 0 : 1;
#endif
Expand Down
105 changes: 64 additions & 41 deletions barretenberg/cpp/src/barretenberg/vm/avm/tests/execution.test.cpp

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

namespace bb::avm_trace {
using poseidon2 = crypto::Poseidon2<crypto::Poseidon2Bn254ScalarFieldParams>;
AvmBytecodeTraceBuilder::AvmBytecodeTraceBuilder(const std::vector<std::vector<uint8_t>>& all_contracts_bytecode)
AvmBytecodeTraceBuilder::AvmBytecodeTraceBuilder(const std::vector<AvmContractBytecode>& all_contracts_bytecode)
: all_contracts_bytecode(all_contracts_bytecode)
{}

Expand All @@ -31,7 +31,7 @@ void AvmBytecodeTraceBuilder::build_bytecode_columns()
// This is the main loop that will generate the bytecode trace
for (auto& contract_bytecode : all_contracts_bytecode) {
FF running_hash = FF::zero();
auto packed_bytecode = pack_bytecode(contract_bytecode);
auto packed_bytecode = pack_bytecode(contract_bytecode.bytecode);
// This size is already based on the number of fields
for (size_t i = 0; i < packed_bytecode.size(); ++i) {
bytecode_trace.push_back(BytecodeTraceEntry{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,8 @@ class AvmBytecodeTraceBuilder {
// Derive the contract address
FF contract_address{};
};
AvmBytecodeTraceBuilder() = default;
// These interfaces will change when we start feeding in more inputs and hints
AvmBytecodeTraceBuilder(const std::vector<std::vector<uint8_t>>& all_contracts_bytecode);
AvmBytecodeTraceBuilder(const std::vector<AvmContractBytecode>& all_contracts_bytecode);

size_t size() const { return bytecode_trace.size(); }
void reset();
Expand All @@ -38,7 +37,7 @@ class AvmBytecodeTraceBuilder {

std::vector<BytecodeTraceEntry> bytecode_trace;
// The first element is the main top-level contract, the rest are external calls
std::vector<std::vector<uint8_t>> all_contracts_bytecode;
std::vector<AvmContractBytecode> all_contracts_bytecode;
// TODO: Come back to this
// VmPublicInputs public_inputs;
// ExecutionHints hints;
Expand Down
58 changes: 29 additions & 29 deletions barretenberg/cpp/src/barretenberg/vm/avm/trace/execution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,18 +146,13 @@ void show_trace_info(const auto& trace)
} // namespace

// Needed for dependency injection in tests.
Execution::TraceBuilderConstructor Execution::trace_builder_constructor =
[](VmPublicInputs public_inputs,
ExecutionHints execution_hints,
uint32_t side_effect_counter,
std::vector<FF> calldata,
std::vector<std::vector<uint8_t>> all_contract_bytecode) {
return AvmTraceBuilder(std::move(public_inputs),
std::move(execution_hints),
side_effect_counter,
std::move(calldata),
all_contract_bytecode);
};
Execution::TraceBuilderConstructor Execution::trace_builder_constructor = [](VmPublicInputs public_inputs,
ExecutionHints execution_hints,
uint32_t side_effect_counter,
std::vector<FF> calldata) {
return AvmTraceBuilder(
std::move(public_inputs), std::move(execution_hints), side_effect_counter, std::move(calldata));
};

/**
* @brief Temporary routine to generate default public inputs (gas values) until we get
Expand All @@ -180,8 +175,7 @@ std::vector<FF> Execution::getDefaultPublicInputs()
* @throws runtime_error exception when the bytecode is invalid.
* @return The verifier key and zk proof of the execution.
*/
std::tuple<AvmFlavor::VerificationKey, HonkProof> Execution::prove(std::vector<uint8_t> const& bytecode,
std::vector<FF> const& calldata,
std::tuple<AvmFlavor::VerificationKey, HonkProof> Execution::prove(std::vector<FF> const& calldata,
std::vector<FF> const& public_inputs_vec,
ExecutionHints const& execution_hints)
{
Expand All @@ -190,8 +184,8 @@ std::tuple<AvmFlavor::VerificationKey, HonkProof> Execution::prove(std::vector<u
}

std::vector<FF> returndata;
std::vector<Row> trace = AVM_TRACK_TIME_V(
"prove/gen_trace", gen_trace(bytecode, calldata, public_inputs_vec, returndata, execution_hints));
std::vector<Row> trace =
AVM_TRACK_TIME_V("prove/gen_trace", gen_trace(calldata, public_inputs_vec, returndata, execution_hints));
if (!avm_dump_trace_path.empty()) {
info("Dumping trace as CSV to: " + avm_dump_trace_path.string());
dump_trace_as_csv(trace, avm_dump_trace_path);
Expand Down Expand Up @@ -268,31 +262,37 @@ bool Execution::verify(AvmFlavor::VerificationKey vk, HonkProof const& proof)
* @param public_inputs expressed as a vector of finite field elements.
* @return The trace as a vector of Row.
*/
std::vector<Row> Execution::gen_trace(std::vector<uint8_t> const& bytecode,
std::vector<FF> const& calldata,
std::vector<Row> Execution::gen_trace(std::vector<FF> const& calldata,
std::vector<FF> const& public_inputs_vec,
std::vector<FF>& returndata,
ExecutionHints const& execution_hints)

{
std::vector<Instruction> instructions = Deserialization::parse(bytecode);
vinfo("Deserialized " + std::to_string(instructions.size()) + " instructions");

vinfo("------- GENERATING TRACE -------");
// TODO(https://github.com/AztecProtocol/aztec-packages/issues/6718): construction of the public input columns
// should be done in the kernel - this is stubbed and underconstrained
VmPublicInputs public_inputs = avm_trace::convert_public_inputs(public_inputs_vec);
uint32_t start_side_effect_counter =
!public_inputs_vec.empty() ? static_cast<uint32_t>(public_inputs_vec[PCPI_START_SIDE_EFFECT_COUNTER_OFFSET])
: 0;
std::vector<std::vector<uint8_t>> all_contract_bytecode;
all_contract_bytecode.reserve(execution_hints.externalcall_hints.size() + 1);
// Start with the main, top-level contract bytecode
all_contract_bytecode.push_back(bytecode);
for (const auto& externalcall_hint : execution_hints.externalcall_hints) {
all_contract_bytecode.emplace_back(externalcall_hint.bytecode);
}
AvmTraceBuilder trace_builder = Execution::trace_builder_constructor(
public_inputs, execution_hints, start_side_effect_counter, calldata, all_contract_bytecode);

// This address is the top-level contract address
vinfo("Length of all contract bytecode: ", execution_hints.all_contract_bytecode.size());

FF contract_address = std::get<0>(public_inputs)[ADDRESS_SELECTOR];
vinfo("Top level contract address: ", contract_address);
// We use it to extract the bytecode we need to execute
std::vector<uint8_t> bytecode =
std::find_if(execution_hints.all_contract_bytecode.begin(),
execution_hints.all_contract_bytecode.end(),
[&](auto& contract) { return contract.contract_instance.address == contract_address; })
->bytecode;

std::vector<Instruction> instructions = Deserialization::parse(bytecode);
vinfo("Deserialized " + std::to_string(instructions.size()) + " instructions");
AvmTraceBuilder trace_builder =
Execution::trace_builder_constructor(public_inputs, execution_hints, start_side_effect_counter, calldata);

// Copied version of pc maintained in trace builder. The value of pc is evolving based
// on opcode logic and therefore is not maintained here. However, the next opcode in the execution
Expand Down
14 changes: 5 additions & 9 deletions barretenberg/cpp/src/barretenberg/vm/avm/trace/execution.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,10 @@ namespace bb::avm_trace {
class Execution {
public:
static constexpr size_t SRS_SIZE = 1 << 22;
using TraceBuilderConstructor =
std::function<AvmTraceBuilder(VmPublicInputs public_inputs,
ExecutionHints execution_hints,
uint32_t side_effect_counter,
std::vector<FF> calldata,
const std::vector<std::vector<uint8_t>>& all_contract_bytecode)>;
using TraceBuilderConstructor = std::function<AvmTraceBuilder(VmPublicInputs public_inputs,
ExecutionHints execution_hints,
uint32_t side_effect_counter,
std::vector<FF> calldata)>;

Execution() = default;

Expand All @@ -30,8 +28,7 @@ class Execution {

// Bytecode is currently the bytecode of the top-level function call
// Eventually this will be the bytecode of the dispatch function of top-level contract
static std::vector<Row> gen_trace(std::vector<uint8_t> const& bytecode,
std::vector<FF> const& calldata,
static std::vector<Row> gen_trace(std::vector<FF> const& calldata,
std::vector<FF> const& public_inputs,
std::vector<FF>& returndata,
ExecutionHints const& execution_hints);
Expand All @@ -43,7 +40,6 @@ class Execution {
}

static std::tuple<AvmFlavor::VerificationKey, bb::HonkProof> prove(
std::vector<uint8_t> const& bytecode,
std::vector<FF> const& calldata = {},
std::vector<FF> const& public_inputs_vec = getDefaultPublicInputs(),
ExecutionHints const& execution_hints = {});
Expand Down
67 changes: 60 additions & 7 deletions barretenberg/cpp/src/barretenberg/vm/avm/trace/execution_hints.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ struct ExternalCallHint {
uint32_t l2_gas_used;
uint32_t da_gas_used;
FF end_side_effect_counter;
std::vector<uint8_t> bytecode;
FF contract_address;
};

// Add support for deserialization of ExternalCallHint. This is implicitly used by serialize::read
Expand All @@ -25,12 +25,26 @@ inline void read(uint8_t const*& it, ExternalCallHint& hint)
read(it, hint.l2_gas_used);
read(it, hint.da_gas_used);
read(it, hint.end_side_effect_counter);
read(it, hint.bytecode);
read(it, hint.contract_address);
}

struct ContractClassIdHint {
FF artifact_hash;
FF private_fn_root;
FF public_bytecode_commitment;
};

inline void read(uint8_t const*& it, ContractClassIdHint& preimage)
{
using serialize::read;
read(it, preimage.artifact_hash);
read(it, preimage.private_fn_root);
read(it, preimage.public_bytecode_commitment);
}

struct ContractInstanceHint {
FF address;
FF instance_found_in_address;
bool exists; // Useful for membership checks
FF salt;
FF deployer_addr;
FF contract_class_id;
Expand All @@ -43,21 +57,49 @@ inline void read(uint8_t const*& it, ContractInstanceHint& hint)
{
using serialize::read;
read(it, hint.address);
read(it, hint.instance_found_in_address);
read(it, hint.exists);
read(it, hint.salt);
read(it, hint.deployer_addr);
read(it, hint.contract_class_id);
read(it, hint.initialisation_hash);
read(it, hint.public_key_hash);
}

struct AvmContractBytecode {
std::vector<uint8_t> bytecode;
ContractInstanceHint contract_instance;
ContractClassIdHint contract_class_id_preimage;

AvmContractBytecode() = default;
AvmContractBytecode(std::vector<uint8_t> bytecode,
ContractInstanceHint contract_instance,
ContractClassIdHint contract_class_id_preimage)
: bytecode(std::move(bytecode))
, contract_instance(contract_instance)
, contract_class_id_preimage(contract_class_id_preimage)
{}
AvmContractBytecode(std::vector<uint8_t> bytecode)
: bytecode(std::move(bytecode))
{}
};

inline void read(uint8_t const*& it, AvmContractBytecode& bytecode)
{
using serialize::read;
read(it, bytecode.bytecode);
read(it, bytecode.contract_instance);
read(it, bytecode.contract_class_id_preimage);
}

struct ExecutionHints {
std::vector<std::pair<FF, FF>> storage_value_hints;
std::vector<std::pair<FF, FF>> note_hash_exists_hints;
std::vector<std::pair<FF, FF>> nullifier_exists_hints;
std::vector<std::pair<FF, FF>> l1_to_l2_message_exists_hints;
std::vector<ExternalCallHint> externalcall_hints;
std::map<FF, ContractInstanceHint> contract_instance_hints;
// We could make this address-indexed
std::vector<AvmContractBytecode> all_contract_bytecode;

ExecutionHints() = default;

Expand Down Expand Up @@ -92,6 +134,11 @@ struct ExecutionHints {
this->contract_instance_hints = std::move(contract_instance_hints);
return *this;
}
ExecutionHints& with_avm_contract_bytecode(std::vector<AvmContractBytecode> all_contract_bytecode)
{
this->all_contract_bytecode = std::move(all_contract_bytecode);
return *this;
}

static void push_vec_into_map(std::unordered_map<uint32_t, FF>& into_map,
const std::vector<std::pair<FF, FF>>& from_pair_vec)
Expand Down Expand Up @@ -144,14 +191,18 @@ struct ExecutionHints {
contract_instance_hints[instance.address] = instance;
}

std::vector<AvmContractBytecode> all_contract_bytecode;
read(it, all_contract_bytecode);

if (it != data.data() + data.size()) {
throw_or_abort("Failed to deserialize ExecutionHints: only read" + std::to_string(it - data.data()) +
throw_or_abort("Failed to deserialize ExecutionHints: only read " + std::to_string(it - data.data()) +
" bytes out of " + std::to_string(data.size()) + " bytes");
}

return { std::move(storage_value_hints), std::move(note_hash_exists_hints),
std::move(nullifier_exists_hints), std::move(l1_to_l2_message_exists_hints),
std::move(externalcall_hints), std::move(contract_instance_hints) };
std::move(externalcall_hints), std::move(contract_instance_hints),
std::move(all_contract_bytecode) };
}

private:
Expand All @@ -160,13 +211,15 @@ struct ExecutionHints {
std::vector<std::pair<FF, FF>> nullifier_exists_hints,
std::vector<std::pair<FF, FF>> l1_to_l2_message_exists_hints,
std::vector<ExternalCallHint> externalcall_hints,
std::map<FF, ContractInstanceHint> contract_instance_hints)
std::map<FF, ContractInstanceHint> contract_instance_hints,
std::vector<AvmContractBytecode> all_contract_bytecode)
: storage_value_hints(std::move(storage_value_hints))
, note_hash_exists_hints(std::move(note_hash_exists_hints))
, nullifier_exists_hints(std::move(nullifier_exists_hints))
, l1_to_l2_message_exists_hints(std::move(l1_to_l2_message_exists_hints))
, externalcall_hints(std::move(externalcall_hints))
, contract_instance_hints(std::move(contract_instance_hints))
, all_contract_bytecode(std::move(all_contract_bytecode))
{}
};

Expand Down
7 changes: 3 additions & 4 deletions barretenberg/cpp/src/barretenberg/vm/avm/trace/trace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -280,14 +280,13 @@ void AvmTraceBuilder::finalise_mem_trace_lookup_counts()
AvmTraceBuilder::AvmTraceBuilder(VmPublicInputs public_inputs,
ExecutionHints execution_hints_,
uint32_t side_effect_counter,
std::vector<FF> calldata,
const std::vector<std::vector<uint8_t>>& all_contract_bytecode)
std::vector<FF> calldata)
// NOTE: we initialise the environment builder here as it requires public inputs
: calldata(std::move(calldata))
, side_effect_counter(side_effect_counter)
, execution_hints(std::move(execution_hints_))
, kernel_trace_builder(side_effect_counter, public_inputs, execution_hints)
, bytecode_trace_builder(all_contract_bytecode)
, bytecode_trace_builder(execution_hints_.all_contract_bytecode)
{
// TODO: think about cast
gas_trace_builder.set_initial_gas(
Expand Down Expand Up @@ -2545,7 +2544,7 @@ void AvmTraceBuilder::op_get_contract_instance(uint8_t indirect, uint32_t addres
ContractInstanceHint contract_instance = execution_hints.contract_instance_hints.at(read_address.val);

// NOTE: we don't write the first entry (the contract instance's address/key) to memory
std::vector<FF> contract_instance_vec = { contract_instance.instance_found_in_address,
std::vector<FF> contract_instance_vec = { contract_instance.exists ? FF::one() : FF::zero(),
contract_instance.salt,
contract_instance.deployer_addr,
contract_instance.contract_class_id,
Expand Down
Loading

0 comments on commit 32228e8

Please sign in to comment.