Skip to content

Commit

Permalink
feat: properly pipe AvmCircuitPublicInputs to witgen
Browse files Browse the repository at this point in the history
  • Loading branch information
dbanks12 committed Nov 21, 2024
1 parent 06ef61e commit 244d222
Show file tree
Hide file tree
Showing 14 changed files with 267 additions and 387 deletions.
140 changes: 9 additions & 131 deletions yarn-project/bb-prover/src/avm_proving.test.ts
Original file line number Diff line number Diff line change
@@ -1,41 +1,16 @@
import {
AvmCircuitInputs,
AvmCircuitPublicInputs,
Gas,
GlobalVariables,
type PublicFunction,
PublicKeys,
SerializableContractInstance,
VerificationKeyData,
} from '@aztec/circuits.js';
import { makeContractClassPublic, makeContractInstanceFromClassId } from '@aztec/circuits.js/testing';
import { AztecAddress } from '@aztec/foundation/aztec-address';
import { Fr, Point } from '@aztec/foundation/fields';
import { VerificationKeyData } from '@aztec/circuits.js';
import { Fr } from '@aztec/foundation/fields';
import { createDebugLogger } from '@aztec/foundation/log';
import { openTmpStore } from '@aztec/kv-store/utils';
import { AvmSimulator, PublicSideEffectTrace, type WorldStateDB } from '@aztec/simulator';
import {
getAvmTestContractBytecode,
getAvmTestContractFunctionSelector,
initContext,
initExecutionEnvironment,
initPersistableStateManager,
resolveAvmTestContractAssertionMessage,
} from '@aztec/simulator/avm/fixtures';
import { NoopTelemetryClient } from '@aztec/telemetry-client/noop';
import { MerkleTrees } from '@aztec/world-state';
import { simulateAvmTestContractGenerateCircuitInputs } from '@aztec/simulator/public/fixtures';

import { mock } from 'jest-mock-extended';
import fs from 'node:fs/promises';
import { tmpdir } from 'node:os';
import path from 'path';

import { type BBSuccess, BB_RESULT, generateAvmProof, verifyAvmProof } from './bb/execute.js';
import { getPublicInputs } from './test/test_avm.js';
import { extractAvmVkData } from './verification_key/verification_key_data.js';

const TIMEOUT = 180_000;
const TIMESTAMP = new Fr(99833);

describe('AVM WitGen, proof generation and verification', () => {
it(
Expand All @@ -50,77 +25,8 @@ describe('AVM WitGen, proof generation and verification', () => {
);
});

/************************************************************************
* Helpers
************************************************************************/

/**
* If assertionErrString is set, we expect a (non exceptional halting) revert due to a failing assertion and
* we check that the revert reason error contains this string. However, the circuit must correctly prove the
* execution.
*/
const proveAndVerifyAvmTestContract = async (
functionName: string,
calldata: Fr[] = [],
assertionErrString?: string,
) => {
const startSideEffectCounter = 0;
const functionSelector = getAvmTestContractFunctionSelector(functionName);
calldata = [functionSelector.toField(), ...calldata];
const globals = GlobalVariables.empty();
globals.timestamp = TIMESTAMP;

const worldStateDB = mock<WorldStateDB>();
//
// Top level contract call
const bytecode = getAvmTestContractBytecode('public_dispatch');
const fnSelector = getAvmTestContractFunctionSelector('public_dispatch');
const publicFn: PublicFunction = { bytecode, selector: fnSelector };
const contractClass = makeContractClassPublic(0, publicFn);
const contractInstance = makeContractInstanceFromClassId(contractClass.id);

// The values here should match those in `avm_simulator.test.ts`
const instanceGet = new SerializableContractInstance({
version: 1,
salt: new Fr(0x123),
deployer: new AztecAddress(new Fr(0x456)),
contractClassId: new Fr(0x789),
initializationHash: new Fr(0x101112),
publicKeys: new PublicKeys(
new Point(new Fr(0x131415), new Fr(0x161718), false),
new Point(new Fr(0x192021), new Fr(0x222324), false),
new Point(new Fr(0x252627), new Fr(0x282930), false),
new Point(new Fr(0x313233), new Fr(0x343536), false),
),
}).withAddress(contractInstance.address);

worldStateDB.getContractInstance
.mockResolvedValueOnce(contractInstance)
.mockResolvedValueOnce(instanceGet) // test gets deployer
.mockResolvedValueOnce(instanceGet) // test gets class id
.mockResolvedValueOnce(instanceGet) // test gets init hash
.mockResolvedValue(contractInstance);
worldStateDB.getContractClass.mockResolvedValue(contractClass);

const storageValue = new Fr(5);
worldStateDB.storageRead.mockResolvedValue(Promise.resolve(storageValue));

const trace = new PublicSideEffectTrace(startSideEffectCounter);
const telemetry = new NoopTelemetryClient();
const merkleTrees = await (await MerkleTrees.new(openTmpStore(), telemetry)).fork();
worldStateDB.getMerkleInterface.mockReturnValue(merkleTrees);
const persistableState = initPersistableStateManager({ worldStateDB, trace, merkleTrees, doMerkleOperations: true });
const environment = initExecutionEnvironment({
functionSelector,
calldata,
globals,
address: contractInstance.address,
});
const context = initContext({ env: environment, persistableState });

worldStateDB.getBytecode.mockResolvedValue(bytecode);

const startGas = new Gas(context.machineState.gasLeft.daGas, context.machineState.gasLeft.l2Gas);
async function proveAndVerifyAvmTestContract(functionName: string, calldata: Fr[] = []) {
const avmCircuitInputs = await simulateAvmTestContractGenerateCircuitInputs(functionName, calldata);

const internalLogger = createDebugLogger('aztec:avm-proving-test');
const logger = (msg: string, _data?: any) => internalLogger.verbose(msg);
Expand All @@ -129,39 +35,11 @@ const proveAndVerifyAvmTestContract = async (
const bbPath = path.resolve('../../barretenberg/cpp/build/bin/bb');
const bbWorkingDirectory = await fs.mkdtemp(path.join(tmpdir(), 'bb-'));

// First we simulate (though it's not needed in this simple case).
const simulator = new AvmSimulator(context);
const avmResult = await simulator.execute();

if (assertionErrString == undefined) {
expect(avmResult.reverted).toBe(false);
} else {
// Explicit revert when an assertion failed.
expect(avmResult.reverted).toBe(true);
expect(avmResult.revertReason).toBeDefined();
expect(resolveAvmTestContractAssertionMessage(functionName, avmResult.revertReason!, avmResult.output)).toContain(
assertionErrString,
);
}

const pxResult = trace.toPublicFunctionCallResult(
environment,
startGas,
/*bytecode=*/ simulator.getBytecode()!,
avmResult.finalize(),
functionName,
);

const avmCircuitInputs = new AvmCircuitInputs(
functionName,
/*calldata=*/ context.environment.calldata,
/*publicInputs=*/ getPublicInputs(pxResult),
/*avmHints=*/ pxResult.avmCircuitHints,
/*output*/ AvmCircuitPublicInputs.empty(),
);

// Then we prove.
const proofRes = await generateAvmProof(bbPath, bbWorkingDirectory, avmCircuitInputs, logger);
if (proofRes.status === BB_RESULT.FAILURE) {
internalLogger.error(`Proof generation failed: ${proofRes.reason}`);
}
expect(proofRes.status).toEqual(BB_RESULT.SUCCESS);

// Then we test VK extraction and serialization.
Expand All @@ -173,4 +51,4 @@ const proveAndVerifyAvmTestContract = async (
const rawVkPath = path.join(succeededRes.vkPath!, 'vk');
const verificationRes = await verifyAvmProof(bbPath, succeededRes.proofPath!, rawVkPath, logger);
expect(verificationRes.status).toBe(BB_RESULT.SUCCESS);
};
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import {
} from '@aztec/foundation/serialize';
import { type FieldsOf } from '@aztec/foundation/types';

import { inspect } from 'util';

import {
MAX_ENQUEUED_CALLS_PER_CALL,
MAX_L1_TO_L2_MSG_READ_REQUESTS_PER_CALL,
Expand Down Expand Up @@ -324,4 +326,66 @@ export class PublicCircuitPublicInputs {
reader.readField(),
);
}

[inspect.custom]() {
return `PublicCircuitPublicInputs {
callContext: ${inspect(this.callContext)},
argsHash: ${inspect(this.argsHash)},
returnsHash: ${inspect(this.returnsHash)},
noteHashReadRequests: [${this.noteHashReadRequests
.filter(x => !x.isEmpty())
.map(h => inspect(h))
.join(', ')}]},
nullifierReadRequests: [${this.nullifierReadRequests
.filter(x => !x.isEmpty())
.map(h => inspect(h))
.join(', ')}]},
nullifierNonExistentReadRequests: [${this.nullifierNonExistentReadRequests
.filter(x => !x.isEmpty())
.map(h => inspect(h))
.join(', ')}]},
l1ToL2MsgReadRequests: [${this.l1ToL2MsgReadRequests
.filter(x => !x.isEmpty())
.map(h => inspect(h))
.join(', ')}]},
contractStorageUpdateRequests: [${this.contractStorageUpdateRequests
.filter(x => !x.isEmpty())
.map(h => inspect(h))
.join(', ')}]},
contractStorageReads: [${this.contractStorageReads
.filter(x => !x.isEmpty())
.map(h => inspect(h))
.join(', ')}]},
publicCallRequests: [${this.publicCallRequests
.filter(x => !x.isEmpty())
.map(h => inspect(h))
.join(', ')}]},
noteHashes: [${this.noteHashes
.filter(x => !x.isEmpty())
.map(h => inspect(h))
.join(', ')}]},
nullifiers: [${this.nullifiers
.filter(x => !x.isEmpty())
.map(h => inspect(h))
.join(', ')}]},
l2ToL1Msgs: [${this.l2ToL1Msgs
.filter(x => !x.isEmpty())
.map(h => inspect(h))
.join(', ')}]},
startSideEffectCounter: ${inspect(this.startSideEffectCounter)},
endSideEffectCounter: ${inspect(this.endSideEffectCounter)},
startSideEffectCounter: ${inspect(this.startSideEffectCounter)},
unencryptedLogsHashes: [${this.unencryptedLogsHashes
.filter(x => !x.isEmpty())
.map(h => inspect(h))
.join(', ')}]},
historicalHeader: ${inspect(this.historicalHeader)},
globalVariables: ${inspect(this.globalVariables)},
proverAddress: ${inspect(this.proverAddress)},
revertCode: ${inspect(this.revertCode)},
startGasLeft: ${inspect(this.startGasLeft)},
endGasLeft: ${inspect(this.endGasLeft)},
transactionFee: ${inspect(this.transactionFee)},
}`;
}
}
Loading

0 comments on commit 244d222

Please sign in to comment.