Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] refactor: stop calling public kernels #9981

Merged
merged 1 commit into from
Nov 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions yarn-project/aztec-node/src/aztec-node/server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ import {
} from '@aztec/p2p';
import { ProtocolContractAddress } from '@aztec/protocol-contracts';
import { GlobalVariableBuilder, SequencerClient } from '@aztec/sequencer-client';
import { PublicProcessorFactory, WASMSimulator, createSimulationProvider } from '@aztec/simulator';
import { PublicProcessorFactory, createSimulationProvider } from '@aztec/simulator';
import { type TelemetryClient } from '@aztec/telemetry-client';
import { NoopTelemetryClient } from '@aztec/telemetry-client/noop';
import { createValidatorClient } from '@aztec/validator-client';
Expand Down Expand Up @@ -733,11 +733,7 @@ export class AztecNodeService implements AztecNode {
feeRecipient,
);
const prevHeader = (await this.blockSource.getBlock(-1))?.header;
const publicProcessorFactory = new PublicProcessorFactory(
this.contractDataSource,
new WASMSimulator(),
this.telemetry,
);
const publicProcessorFactory = new PublicProcessorFactory(this.contractDataSource, this.telemetry);

const fork = await this.worldStateSynchronizer.fork();

Expand Down
30 changes: 14 additions & 16 deletions yarn-project/circuit-types/src/tx/tx.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ import {
ClientIvcProof,
ContractClassRegisteredEvent,
PrivateKernelTailCircuitPublicInputs,
type PublicKernelCircuitPublicInputs,
type PrivateToPublicAccumulatedData,
type ScopedLogHash,
} from '@aztec/circuits.js';
import { type Buffer32 } from '@aztec/foundation/buffer';
import { arraySerializedSizeOfNonEmpty } from '@aztec/foundation/collection';
Expand Down Expand Up @@ -344,29 +345,26 @@ export class Tx extends Gossipable {
* @param logHashes the individual log hashes we want to keep
* @param out the output to put passing logs in, to keep this function abstract
*/
public filterRevertedLogs(kernelOutput: PublicKernelCircuitPublicInputs) {
public filterRevertedLogs(
privateNonRevertible: PrivateToPublicAccumulatedData,
unencryptedLogsHashes: ScopedLogHash[],
) {
this.encryptedLogs = this.encryptedLogs.filterScoped(
kernelOutput.endNonRevertibleData.encryptedLogsHashes,
privateNonRevertible.encryptedLogsHashes,
EncryptedTxL2Logs.empty(),
);

this.unencryptedLogs = this.unencryptedLogs.filterScoped(
kernelOutput.endNonRevertibleData.unencryptedLogsHashes,
UnencryptedTxL2Logs.empty(),
);

this.noteEncryptedLogs = this.noteEncryptedLogs.filter(
kernelOutput.endNonRevertibleData.noteEncryptedLogsHashes,
privateNonRevertible.noteEncryptedLogsHashes,
EncryptedNoteTxL2Logs.empty(),
);

// See comment in enqueued_calls_processor.ts -> tx.filterRevertedLogs()
if (this.data.forPublic) {
this.contractClassLogs = this.contractClassLogs.filterScoped(
this.data.forPublic?.nonRevertibleAccumulatedData.contractClassLogsHashes,
ContractClassTxL2Logs.empty(),
);
}
this.contractClassLogs = this.contractClassLogs.filterScoped(
privateNonRevertible.contractClassLogsHashes,
ContractClassTxL2Logs.empty(),
);

this.unencryptedLogs = this.unencryptedLogs.filterScoped(unencryptedLogsHashes, UnencryptedTxL2Logs.empty());
}
}

Expand Down
3 changes: 0 additions & 3 deletions yarn-project/prover-client/src/mocks/test_context.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ import {
PublicExecutionResultBuilder,
type PublicExecutor,
PublicProcessor,
RealPublicKernelCircuitSimulator,
type SimulationProvider,
WASMSimulator,
type WorldStateDB,
Expand Down Expand Up @@ -69,7 +68,6 @@ export class TestContext {

const publicExecutor = mock<PublicExecutor>();
const worldStateDB = mock<WorldStateDB>();
const publicKernel = new RealPublicKernelCircuitSimulator(new WASMSimulator());
const telemetry = new NoopTelemetryClient();

// Separated dbs for public processor and prover - see public_processor for context
Expand All @@ -89,7 +87,6 @@ export class TestContext {
const processor = PublicProcessor.create(
publicDb,
publicExecutor,
publicKernel,
globalVariables,
Header.empty(),
worldStateDB,
Expand Down
8 changes: 2 additions & 6 deletions yarn-project/prover-node/src/prover-node.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ export class ProverNode implements ClaimsMonitorHandler, EpochMonitorHandler, Pr
private readonly contractDataSource: ContractDataSource,
private readonly worldState: WorldStateSynchronizer,
private readonly coordination: ProverCoordination & Maybe<Service>,
private readonly simulator: SimulationProvider,
private readonly _simulator: SimulationProvider,
private readonly quoteProvider: QuoteProvider,
private readonly quoteSigner: QuoteSigner,
private readonly claimsMonitor: ClaimsMonitor,
Expand Down Expand Up @@ -243,11 +243,7 @@ export class ProverNode implements ClaimsMonitorHandler, EpochMonitorHandler, Pr
const proverDb = await this.worldState.fork(fromBlock - 1);

// Create a processor using the forked world state
const publicProcessorFactory = new PublicProcessorFactory(
this.contractDataSource,
this.simulator,
this.telemetryClient,
);
const publicProcessorFactory = new PublicProcessorFactory(this.contractDataSource, this.telemetryClient);

const cleanUp = async () => {
await publicDb.close();
Expand Down
4 changes: 2 additions & 2 deletions yarn-project/sequencer-client/src/client/sequencer-client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,13 @@ export class SequencerClient {
contractDataSource: ContractDataSource,
l2BlockSource: L2BlockSource,
l1ToL2MessageSource: L1ToL2MessageSource,
simulationProvider: SimulationProvider,
_simulationProvider: SimulationProvider,
telemetryClient: TelemetryClient,
) {
const publisher = new L1Publisher(config, telemetryClient);
const globalsBuilder = new GlobalVariableBuilder(config);

const publicProcessorFactory = new PublicProcessorFactory(contractDataSource, simulationProvider, telemetryClient);
const publicProcessorFactory = new PublicProcessorFactory(contractDataSource, telemetryClient);

const rollup = publisher.getRollupContract();
const [l1GenesisTime, slotDuration] = await Promise.all([
Expand Down
55 changes: 35 additions & 20 deletions yarn-project/simulator/src/avm/journal/journal.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ export class AvmPersistableStateManager {
/** Interface to perform merkle tree operations */
public merkleTrees: MerkleTreeWriteOperations;

/** Make sure a forked state is never merged twice. */
private alreadyMergedIntoParent = false;

constructor(
/** Reference to node storage */
private readonly worldStateDB: WorldStateDB,
Expand Down Expand Up @@ -79,16 +82,46 @@ export class AvmPersistableStateManager {
/**
* Create a new state manager forked from this one
*/
public fork(incrementSideEffectCounter: boolean = false) {
public fork() {
return new AvmPersistableStateManager(
this.worldStateDB,
this.trace.fork(incrementSideEffectCounter),
this.trace.fork(),
this.publicStorage.fork(),
this.nullifiers.fork(),
this.doMerkleOperations,
);
}

/**
* Accept forked world state modifications & traced side effects / hints
*/
public merge(forkedState: AvmPersistableStateManager) {
this._merge(forkedState, /*reverted=*/ false);
}

/**
* Reject forked world state modifications & traced side effects, keep traced hints
*/
public reject(forkedState: AvmPersistableStateManager) {
this._merge(forkedState, /*reverted=*/ true);
}

/**
* Commit cached storage writes to the DB.
* Keeps public storage up to date from tx to tx within a block.
*/
public async commitStorageWritesToDB() {
await this.publicStorage.commitToDB();
}

private _merge(forkedState: AvmPersistableStateManager, reverted: boolean) {
// sanity check to avoid merging the same forked trace twice
assert(!this.alreadyMergedIntoParent, 'Cannot merge forked state that has already been merged into its parent!');
this.publicStorage.acceptAndMerge(forkedState.publicStorage);
this.nullifiers.acceptAndMerge(forkedState.nullifiers);
this.trace.merge(forkedState.trace, reverted);
}

/**
* Write to public storage, journal/trace the write.
*
Expand Down Expand Up @@ -427,24 +460,6 @@ export class AvmPersistableStateManager {
}
}

/**
* Accept forked world state modifications & traced side effects / hints
*/
public mergeForkedState(forkedState: AvmPersistableStateManager) {
this.publicStorage.acceptAndMerge(forkedState.publicStorage);
this.nullifiers.acceptAndMerge(forkedState.nullifiers);
this.trace.merge(forkedState.trace, /*reverted=*/ false);
}

/**
* Reject forked world state modifications & traced side effects, keep traced hints
*/
public rejectForkedState(forkedState: AvmPersistableStateManager) {
this.publicStorage.acceptAndMerge(forkedState.publicStorage);
this.nullifiers.acceptAndMerge(forkedState.nullifiers);
this.trace.merge(forkedState.trace, /*reverted=*/ true);
}

/**
* Get a contract's bytecode from the contracts DB, also trace the contract class and instance
*/
Expand Down
6 changes: 4 additions & 2 deletions yarn-project/simulator/src/avm/opcodes/external_calls.ts
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,11 @@ abstract class ExternalCall extends Instruction {
// Refund unused gas
context.machineState.refundGas(gasLeftToGas(nestedContext.machineState));

// Accept the nested call's state and trace the nested call
// Merge nested call's state and trace based on whether it succeeded.
if (success) {
context.persistableState.mergeForkedState(nestedContext.persistableState);
context.persistableState.merge(nestedContext.persistableState);
} else {
context.persistableState.reject(nestedContext.persistableState);
}
await context.persistableState.traceNestedCall(
/*nestedState=*/ nestedContext.persistableState,
Expand Down
7 changes: 2 additions & 5 deletions yarn-project/simulator/src/public/dual_side_effect_trace.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,8 @@ export class DualSideEffectTrace implements PublicSideEffectTraceInterface {
public readonly enqueuedCallTrace: PublicEnqueuedCallSideEffectTrace,
) {}

public fork(incrementSideEffectCounter: boolean = false) {
return new DualSideEffectTrace(
this.innerCallTrace.fork(incrementSideEffectCounter),
this.enqueuedCallTrace.fork(incrementSideEffectCounter),
);
public fork() {
return new DualSideEffectTrace(this.innerCallTrace.fork(), this.enqueuedCallTrace.fork());
}

public merge(nestedTrace: this, reverted: boolean = false) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ import { makeTuple } from '@aztec/foundation/array';
import { padArrayEnd } from '@aztec/foundation/collection';
import { Fr } from '@aztec/foundation/fields';
import { createDebugLogger } from '@aztec/foundation/log';
import { type Tuple } from '@aztec/foundation/serialize';

import { assert } from 'console';

Expand Down Expand Up @@ -140,6 +139,9 @@ export class PublicEnqueuedCallSideEffectTrace implements PublicSideEffectTraceI

private avmCircuitHints: AvmExecutionHints;

/** Make sure a forked trace is never merged twice. */
private alreadyMergedIntoParent = false;

constructor(
/** The counter of this trace's first side effect. */
public readonly startSideEffectCounter: number = 0,
Expand All @@ -154,9 +156,9 @@ export class PublicEnqueuedCallSideEffectTrace implements PublicSideEffectTraceI
this.avmCircuitHints = AvmExecutionHints.empty();
}

public fork(incrementSideEffectCounter: boolean = false) {
public fork() {
return new PublicEnqueuedCallSideEffectTrace(
incrementSideEffectCounter ? this.sideEffectCounter + 1 : this.sideEffectCounter,
this.sideEffectCounter,
new PublicValidationRequestArrayLengths(
this.previousValidationRequestArrayLengths.noteHashReadRequests + this.noteHashReadRequests.length,
this.previousValidationRequestArrayLengths.nullifierReadRequests + this.nullifierReadRequests.length,
Expand All @@ -178,23 +180,27 @@ export class PublicEnqueuedCallSideEffectTrace implements PublicSideEffectTraceI
);
}

public merge(nestedTrace: this, reverted: boolean = false) {
// TODO(dbanks12): accept & merge nested trace's hints!
this.sideEffectCounter = nestedTrace.sideEffectCounter;
this.enqueuedCalls.push(...nestedTrace.enqueuedCalls);
public merge(forkedTrace: this, reverted: boolean = false) {
// sanity check to avoid merging the same forked trace twice
assert(!this.alreadyMergedIntoParent, 'Cannot merge a forked trace that has already been merged into its parent!');
forkedTrace.alreadyMergedIntoParent = true;

// TODO(dbanks12): accept & merge forked trace's hints!
this.sideEffectCounter = forkedTrace.sideEffectCounter;
this.enqueuedCalls.push(...forkedTrace.enqueuedCalls);

if (!reverted) {
this.publicDataReads.push(...nestedTrace.publicDataReads);
this.publicDataWrites.push(...nestedTrace.publicDataWrites);
this.noteHashReadRequests.push(...nestedTrace.noteHashReadRequests);
this.noteHashes.push(...nestedTrace.noteHashes);
this.nullifierReadRequests.push(...nestedTrace.nullifierReadRequests);
this.nullifierNonExistentReadRequests.push(...nestedTrace.nullifierNonExistentReadRequests);
this.nullifiers.push(...nestedTrace.nullifiers);
this.l1ToL2MsgReadRequests.push(...nestedTrace.l1ToL2MsgReadRequests);
this.l2ToL1Messages.push(...nestedTrace.l2ToL1Messages);
this.unencryptedLogs.push(...nestedTrace.unencryptedLogs);
this.unencryptedLogsHashes.push(...nestedTrace.unencryptedLogsHashes);
this.publicDataReads.push(...forkedTrace.publicDataReads);
this.publicDataWrites.push(...forkedTrace.publicDataWrites);
this.noteHashReadRequests.push(...forkedTrace.noteHashReadRequests);
this.noteHashes.push(...forkedTrace.noteHashes);
this.nullifierReadRequests.push(...forkedTrace.nullifierReadRequests);
this.nullifierNonExistentReadRequests.push(...forkedTrace.nullifierNonExistentReadRequests);
this.nullifiers.push(...forkedTrace.nullifiers);
this.l1ToL2MsgReadRequests.push(...forkedTrace.l1ToL2MsgReadRequests);
this.l2ToL1Messages.push(...forkedTrace.l2ToL1Messages);
this.unencryptedLogs.push(...forkedTrace.unencryptedLogs);
this.unencryptedLogsHashes.push(...forkedTrace.unencryptedLogsHashes);
}
}

Expand Down Expand Up @@ -454,7 +460,7 @@ export class PublicEnqueuedCallSideEffectTrace implements PublicSideEffectTraceI
}

// This tracing function gets called everytime we start simulation/execution.
// This happens both when starting a new top-level trace and the start of every nested trace
// This happens both when starting a new top-level trace and the start of every forked trace
// We use this to collect the AvmContractBytecodeHints
public traceGetBytecode(
contractAddress: AztecAddress,
Expand Down Expand Up @@ -493,7 +499,7 @@ export class PublicEnqueuedCallSideEffectTrace implements PublicSideEffectTraceI
*/
public traceNestedCall(
/** The trace of the nested call. */
nestedCallTrace: this,
_nestedCallTrace: this,
/** The execution environment of the nested call. */
nestedEnvironment: AvmExecutionEnvironment,
/** How much gas was available for this public execution. */
Expand Down Expand Up @@ -631,9 +637,9 @@ export class PublicEnqueuedCallSideEffectTrace implements PublicSideEffectTraceI
/** How much gas was available for this public execution. */
gasLimits: GasSettings,
/** Call requests for setup phase. */
publicSetupCallRequests: Tuple<PublicCallRequest, typeof MAX_ENQUEUED_CALLS_PER_TX>,
publicSetupCallRequests: PublicCallRequest[],
/** Call requests for app logic phase. */
publicAppLogicCallRequests: Tuple<PublicCallRequest, typeof MAX_ENQUEUED_CALLS_PER_TX>,
publicAppLogicCallRequests: PublicCallRequest[],
/** Call request for teardown phase. */
publicTeardownCallRequest: PublicCallRequest,
/** End tree snapshots. */
Expand All @@ -653,8 +659,8 @@ export class PublicEnqueuedCallSideEffectTrace implements PublicSideEffectTraceI
startTreeSnapshots,
startGasUsed,
gasLimits,
publicSetupCallRequests,
publicAppLogicCallRequests,
padArrayEnd(publicSetupCallRequests, PublicCallRequest.empty(), MAX_ENQUEUED_CALLS_PER_TX),
padArrayEnd(publicAppLogicCallRequests, PublicCallRequest.empty(), MAX_ENQUEUED_CALLS_PER_TX),
publicTeardownCallRequest,
/*previousNonRevertibleAccumulatedDataArrayLengths=*/ PrivateToAvmAccumulatedDataArrayLengths.empty(),
/*previousRevertibleAccumulatedDataArrayLengths=*/ PrivateToAvmAccumulatedDataArrayLengths.empty(),
Expand Down
4 changes: 3 additions & 1 deletion yarn-project/simulator/src/public/executor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ export class PublicExecutor {
*/
public async simulate(
stateManager: AvmPersistableStateManager,
executionRequest: PublicExecutionRequest, // TODO(dbanks12): CallRequest instead?
executionRequest: PublicExecutionRequest,
globalVariables: GlobalVariables,
allocatedGas: Gas,
transactionFee: Fr = Fr.ZERO,
Expand Down Expand Up @@ -105,6 +105,8 @@ export class PublicExecutor {
* @param transactionFee - Fee offered for this TX.
* @param startSideEffectCounter - The start counter to initialize the side effect trace with.
* @returns The result of execution including side effect vectors.
* FIXME: this function is only used by the TXE. Ideally we would not support this as an external interface.
* Avoid using this interface as it it shouldn't really exist in the first place.
*/
public async simulateIsolatedEnqueuedCall(
executionRequest: PublicExecutionRequest,
Expand Down
Loading
Loading