diff --git a/yarn-project/circuit-types/src/tx/validator/empty_validator.ts b/yarn-project/circuit-types/src/tx/validator/empty_validator.ts index b96bd9b2381..7ad6f80dcf1 100644 --- a/yarn-project/circuit-types/src/tx/validator/empty_validator.ts +++ b/yarn-project/circuit-types/src/tx/validator/empty_validator.ts @@ -4,4 +4,8 @@ export class EmptyTxValidator implements TxValidator public validateTxs(txs: T[]): Promise<[validTxs: T[], invalidTxs: T[]]> { return Promise.resolve([txs, []]); } + + public validateTx(_tx: T): Promise { + return Promise.resolve(true); + } } diff --git a/yarn-project/circuit-types/src/tx/validator/tx_validator.ts b/yarn-project/circuit-types/src/tx/validator/tx_validator.ts index 336d36fe648..6669b56055a 100644 --- a/yarn-project/circuit-types/src/tx/validator/tx_validator.ts +++ b/yarn-project/circuit-types/src/tx/validator/tx_validator.ts @@ -4,5 +4,6 @@ import { type Tx } from '../tx.js'; export type AnyTx = Tx | ProcessedTx; export interface TxValidator { + validateTx(tx: T): Promise; validateTxs(txs: T[]): Promise<[validTxs: T[], invalidTxs: T[]]>; } diff --git a/yarn-project/p2p/src/client/p2p_client.ts b/yarn-project/p2p/src/client/p2p_client.ts index 9fec7856fae..63c5b453f0a 100644 --- a/yarn-project/p2p/src/client/p2p_client.ts +++ b/yarn-project/p2p/src/client/p2p_client.ts @@ -319,8 +319,6 @@ export class P2PClient implements P2P { this.log.debug(`Requested ${txHash.toString()} from peer | success = ${!!tx}`); if (tx) { - // TODO(https://github.com/AztecProtocol/aztec-packages/issues/8485): This check is not sufficient to validate the transaction. We need to validate the entire proof. - // TODO(https://github.com/AztecProtocol/aztec-packages/issues/8483): alter peer scoring system for a validator that returns an invalid transcation await this.txPool.addTxs([tx]); } diff --git a/yarn-project/p2p/src/mocks/index.ts b/yarn-project/p2p/src/mocks/index.ts index 61e16486d33..2610d691203 100644 --- a/yarn-project/p2p/src/mocks/index.ts +++ b/yarn-project/p2p/src/mocks/index.ts @@ -1,3 +1,5 @@ +import { type ClientProtocolCircuitVerifier, type Tx } from '@aztec/circuit-types'; + import { noise } from '@chainsafe/libp2p-noise'; import { yamux } from '@chainsafe/libp2p-yamux'; import { bootstrap } from '@libp2p/bootstrap'; @@ -10,8 +12,10 @@ import { pingHandler, statusHandler } from '../service/reqresp/handlers.js'; import { PING_PROTOCOL, type ReqRespSubProtocolHandlers, + type ReqRespSubProtocolValidators, STATUS_PROTOCOL, TX_REQ_PROTOCOL, + noopValidator, } from '../service/reqresp/interface.js'; import { ReqResp } from '../service/reqresp/reqresp.js'; @@ -57,6 +61,14 @@ export const MOCK_SUB_PROTOCOL_HANDLERS: ReqRespSubProtocolHandlers = { [TX_REQ_PROTOCOL]: (_msg: any) => Promise.resolve(Uint8Array.from(Buffer.from('tx'))), }; +// By default, all requests are valid +// If you want to test an invalid response, you can override the validator +export const MOCK_SUB_PROTOCOL_VALIDATORS: ReqRespSubProtocolValidators = { + [PING_PROTOCOL]: noopValidator, + [STATUS_PROTOCOL]: noopValidator, + [TX_REQ_PROTOCOL]: noopValidator, +}; + /** * @param numberOfNodes - the number of nodes to create * @returns An array of the created nodes @@ -65,10 +77,13 @@ export const createNodes = async (peerManager: PeerManager, numberOfNodes: numbe return await Promise.all(Array.from({ length: numberOfNodes }, () => createReqResp(peerManager))); }; -// TODO: think about where else this can go -export const startNodes = async (nodes: ReqRespNode[], subProtocolHandlers = MOCK_SUB_PROTOCOL_HANDLERS) => { +export const startNodes = async ( + nodes: ReqRespNode[], + subProtocolHandlers = MOCK_SUB_PROTOCOL_HANDLERS, + subProtocolValidators = MOCK_SUB_PROTOCOL_VALIDATORS, +) => { for (const node of nodes) { - await node.req.start(subProtocolHandlers); + await node.req.start(subProtocolHandlers, subProtocolValidators); } }; @@ -105,3 +120,15 @@ export const connectToPeers = async (nodes: ReqRespNode[]): Promise => { } } }; + +// Mock circuit verifier for testing - reimplementation from bb to avoid dependency +export class AlwaysTrueCircuitVerifier implements ClientProtocolCircuitVerifier { + verifyProof(_tx: Tx): Promise { + return Promise.resolve(true); + } +} +export class AlwaysFalseCircuitVerifier implements ClientProtocolCircuitVerifier { + verifyProof(_tx: Tx): Promise { + return Promise.resolve(false); + } +} diff --git a/yarn-project/p2p/src/service/libp2p_service.ts b/yarn-project/p2p/src/service/libp2p_service.ts index aa450fb1b44..2c6778eb8fb 100644 --- a/yarn-project/p2p/src/service/libp2p_service.ts +++ b/yarn-project/p2p/src/service/libp2p_service.ts @@ -47,13 +47,13 @@ import { PeerErrorSeverity } from './peer_scoring.js'; import { pingHandler, statusHandler } from './reqresp/handlers.js'; import { DEFAULT_SUB_PROTOCOL_HANDLERS, + DEFAULT_SUB_PROTOCOL_VALIDATORS, PING_PROTOCOL, type ReqRespSubProtocol, type ReqRespSubProtocolHandlers, STATUS_PROTOCOL, type SubProtocolMap, TX_REQ_PROTOCOL, - subProtocolMap, } from './reqresp/interface.js'; import { ReqResp } from './reqresp/reqresp.js'; import type { P2PService, PeerDiscoveryService } from './service.js'; @@ -162,7 +162,13 @@ export class LibP2PService implements P2PService { this.peerManager.heartbeat(); }, this.config.peerCheckIntervalMS); this.discoveryRunningPromise.start(); - await this.reqresp.start(this.requestResponseHandlers); + + // Define the sub protocol validators - This is done within this start() method to gain a callback to the existing validateTx function + const reqrespSubProtocolValidators = { + ...DEFAULT_SUB_PROTOCOL_VALIDATORS, + [TX_REQ_PROTOCOL]: this.validateRequestedTx.bind(this), + }; + await this.reqresp.start(this.requestResponseHandlers, reqrespSubProtocolValidators); } /** @@ -302,18 +308,11 @@ export class LibP2PService implements P2PService { * @param request The request type to send * @returns */ - async sendRequest( + sendRequest( protocol: SubProtocol, request: InstanceType, ): Promise | undefined> { - const pair = subProtocolMap[protocol]; - - const res = await this.reqresp.sendRequest(protocol, request.toBuffer()); - if (!res) { - return undefined; - } - - return pair.response.fromBuffer(res!); + return this.reqresp.sendRequest(protocol, request); } /** @@ -418,19 +417,53 @@ export class LibP2PService implements P2PService { const txHashString = txHash.toString(); this.logger.verbose(`Received tx ${txHashString} from external peer.`); - const isValidTx = await this.validateTx(tx, peerId); + const isValidTx = await this.validatePropagatedTx(tx, peerId); if (isValidTx) { await this.txPool.addTxs([tx]); } } - private async validateTx(tx: Tx, peerId: PeerId): Promise { + /** + * Validate a tx that has been requested from a peer. + * + * The core component of this validator is that the tx hash MUST match the requested tx hash, + * In order to perform this check, the tx proof must be verified. + * + * Note: This function is called from within `ReqResp.sendRequest` as part of the + * TX_REQ_PROTOCOL subprotocol validation. + * + * @param requestedTxHash - The hash of the tx that was requested. + * @param responseTx - The tx that was received as a response to the request. + * @param peerId - The peer ID of the peer that sent the tx. + * @returns True if the tx is valid, false otherwise. + */ + private async validateRequestedTx(requestedTxHash: TxHash, responseTx: Tx, peerId: PeerId): Promise { + const proofValidator = new TxProofValidator(this.proofVerifier); + const validProof = await proofValidator.validateTx(responseTx); + + // If the node returns the wrong data, we penalize it + if (!requestedTxHash.equals(responseTx.getTxHash())) { + // Returning the wrong data is a low tolerance error + this.peerManager.penalizePeer(peerId, PeerErrorSeverity.MidToleranceError); + return false; + } + + if (!validProof) { + // If the proof is invalid, but the txHash is correct, then this is an active attack and we severly punish + this.peerManager.penalizePeer(peerId, PeerErrorSeverity.LowToleranceError); + return false; + } + + return true; + } + + private async validatePropagatedTx(tx: Tx, peerId: PeerId): Promise { const blockNumber = (await this.l2BlockSource.getBlockNumber()) + 1; // basic data validation const dataValidator = new DataTxValidator(); - const [_, dataInvalidTxs] = await dataValidator.validateTxs([tx]); - if (dataInvalidTxs.length > 0) { + const validData = await dataValidator.validateTx(tx); + if (!validData) { // penalize this.node.services.pubsub.score.markInvalidMessageDelivery(peerId.toString(), Tx.p2pTopic); return false; @@ -438,8 +471,8 @@ export class LibP2PService implements P2PService { // metadata validation const metadataValidator = new MetadataTxValidator(new Fr(this.config.l1ChainId), new Fr(blockNumber)); - const [__, metaInvalidTxs] = await metadataValidator.validateTxs([tx]); - if (metaInvalidTxs.length > 0) { + const validMetadata = await metadataValidator.validateTx(tx); + if (!validMetadata) { // penalize this.node.services.pubsub.score.markInvalidMessageDelivery(peerId.toString(), Tx.p2pTopic); return false; @@ -453,8 +486,8 @@ export class LibP2PService implements P2PService { return index; }, }); - const [___, doubleSpendInvalidTxs] = await doubleSpendValidator.validateTxs([tx]); - if (doubleSpendInvalidTxs.length > 0) { + const validDoubleSpend = await doubleSpendValidator.validateTx(tx); + if (!validDoubleSpend) { // check if nullifier is older than 20 blocks if (blockNumber - this.config.severePeerPenaltyBlockLength > 0) { const snapshotValidator = new DoubleSpendTxValidator({ @@ -467,9 +500,9 @@ export class LibP2PService implements P2PService { }, }); - const [____, snapshotInvalidTxs] = await snapshotValidator.validateTxs([tx]); + const validSnapshot = await snapshotValidator.validateTx(tx); // High penalty if nullifier is older than 20 blocks - if (snapshotInvalidTxs.length > 0) { + if (!validSnapshot) { // penalize this.peerManager.penalizePeer(peerId, PeerErrorSeverity.LowToleranceError); return false; @@ -482,8 +515,8 @@ export class LibP2PService implements P2PService { // proof validation const proofValidator = new TxProofValidator(this.proofVerifier); - const [_____, proofInvalidTxs] = await proofValidator.validateTxs([tx]); - if (proofInvalidTxs.length > 0) { + const validProof = await proofValidator.validateTx(tx); + if (!validProof) { // penalize this.peerManager.penalizePeer(peerId, PeerErrorSeverity.MidToleranceError); return false; diff --git a/yarn-project/p2p/src/service/reqresp/interface.ts b/yarn-project/p2p/src/service/reqresp/interface.ts index 606efc17bc9..8370b8a8a21 100644 --- a/yarn-project/p2p/src/service/reqresp/interface.ts +++ b/yarn-project/p2p/src/service/reqresp/interface.ts @@ -1,5 +1,7 @@ import { Tx, TxHash } from '@aztec/circuit-types'; +import { type PeerId } from '@libp2p/interface'; + /* * Request Response Sub Protocols */ @@ -46,11 +48,29 @@ export interface ProtocolRateLimitQuota { globalLimit: RateLimitQuota; } +export const noopValidator = () => Promise.resolve(true); + /** * A type mapping from supprotocol to it's handling funciton */ export type ReqRespSubProtocolHandlers = Record; +type ResponseValidator = ( + request: RequestIdentifier, + response: Response, + peerId: PeerId, +) => Promise; + +export type ReqRespSubProtocolValidators = { + [S in ReqRespSubProtocol]: ResponseValidator; +}; + +export const DEFAULT_SUB_PROTOCOL_VALIDATORS: ReqRespSubProtocolValidators = { + [PING_PROTOCOL]: noopValidator, + [STATUS_PROTOCOL]: noopValidator, + [TX_REQ_PROTOCOL]: noopValidator, +}; + /** * Sub protocol map determines the request and response types for each * Req Resp protocol diff --git a/yarn-project/p2p/src/service/reqresp/p2p_client.integration.test.ts b/yarn-project/p2p/src/service/reqresp/p2p_client.integration.test.ts index 52f06c8700b..6364c497b86 100644 --- a/yarn-project/p2p/src/service/reqresp/p2p_client.integration.test.ts +++ b/yarn-project/p2p/src/service/reqresp/p2p_client.integration.test.ts @@ -7,17 +7,22 @@ import { getRandomPort } from '@aztec/foundation/testing'; import { type AztecKVStore } from '@aztec/kv-store'; import { type DataStoreConfig, openTmpStore } from '@aztec/kv-store/utils'; +import { SignableENR } from '@chainsafe/enr'; import { describe, expect, it, jest } from '@jest/globals'; +import { multiaddr } from '@multiformats/multiaddr'; import { generatePrivateKey } from 'viem/accounts'; import { type AttestationPool } from '../../attestation_pool/attestation_pool.js'; -import { BootstrapNode } from '../../bootstrap/bootstrap.js'; import { createP2PClient } from '../../client/index.js'; import { MockBlockSource } from '../../client/mocks.js'; import { type P2PClient } from '../../client/p2p_client.js'; -import { type BootnodeConfig, type P2PConfig, getP2PDefaultConfig } from '../../config.js'; +import { type P2PConfig, getP2PDefaultConfig } from '../../config.js'; +import { AlwaysFalseCircuitVerifier, AlwaysTrueCircuitVerifier } from '../../mocks/index.js'; import { type TxPool } from '../../tx_pool/index.js'; +import { convertToMultiaddr } from '../../util.js'; +import { AZTEC_ENR_KEY, AZTEC_NET } from '../discV5_service.js'; import { createLibP2PPeerId } from '../index.js'; +import { PeerErrorSeverity } from '../peer_scoring.js'; /** * Mockify helper for testing purposes. @@ -28,22 +33,6 @@ type Mockify = { const TEST_TIMEOUT = 80000; -const DEFAULT_BOOT_NODE_UDP_PORT = 40400; -async function createBootstrapNode(port: number) { - const peerId = await createLibP2PPeerId(); - const bootstrapNode = new BootstrapNode(); - const config: BootnodeConfig = { - udpListenAddress: `0.0.0.0:${port}`, - udpAnnounceAddress: `127.0.0.1:${port}`, - peerIdPrivateKey: Buffer.from(peerId.privateKey!).toString('hex'), - minPeerCount: 1, - maxPeerCount: 100, - }; - await bootstrapNode.start(config); - - return bootstrapNode; -} - function generatePeerIdPrivateKeys(numberOfPeers: number): string[] { const peerIdPrivateKeys: string[] = []; for (let i = 0; i < numberOfPeers; i++) { @@ -65,21 +54,45 @@ describe('Req Resp p2p client integration', () => { let bootNodePort: number; const logger = createDebugLogger('p2p-client-integration-test'); - const makeBootstrapNode = async (): Promise<[BootstrapNode, string]> => { - bootNodePort = (await getRandomPort()) || DEFAULT_BOOT_NODE_UDP_PORT; - const bootstrapNode = await createBootstrapNode(bootNodePort); - const enr = bootstrapNode.getENR().encodeTxt(); - return [bootstrapNode, enr]; + const getPorts = async (numberOfPeers: number) => { + const ports = []; + for (let i = 0; i < numberOfPeers; i++) { + const port = (await getRandomPort()) || bootNodePort + i + 1; + ports.push(port); + } + return ports; }; - const createClients = async (numberOfPeers: number, bootstrapNodeEnr: string): Promise => { + const createClients = async (numberOfPeers: number, alwaysTrueVerifier: boolean = true): Promise => { const clients: P2PClient[] = []; const peerIdPrivateKeys = generatePeerIdPrivateKeys(numberOfPeers); + + const ports = await getPorts(numberOfPeers); + + const peerEnrs = await Promise.all( + peerIdPrivateKeys.map(async (pk, i) => { + const peerId = await createLibP2PPeerId(pk); + const enr = SignableENR.createFromPeerId(peerId); + + const udpAnnounceAddress = `127.0.0.1:${ports[i]}`; + const publicAddr = multiaddr(convertToMultiaddr(udpAnnounceAddress, 'udp')); + + // ENRS must include the network and a discoverable address (udp for discv5) + enr.set(AZTEC_ENR_KEY, Uint8Array.from([AZTEC_NET])); + enr.setLocationMultiaddr(publicAddr); + + return enr.encodeTxt(); + }), + ); + for (let i = 0; i < numberOfPeers; i++) { // Note these bindings are important - const port = (await getRandomPort()) || bootNodePort + i + 1; - const addr = `127.0.0.1:${port}`; - const listenAddr = `0.0.0.0:${port}`; + const addr = `127.0.0.1:${ports[i]}`; + const listenAddr = `0.0.0.0:${ports[i]}`; + + // Filter nodes so that we only dial active peers + const otherNodes = peerEnrs.filter((_, ind) => ind < i); + const config: P2PConfig & DataStoreConfig = { ...getP2PDefaultConfig(), p2pEnabled: true, @@ -89,7 +102,7 @@ describe('Req Resp p2p client integration', () => { tcpAnnounceAddress: addr, udpAnnounceAddress: addr, l2QueueSize: 1, - bootstrapNodes: [bootstrapNodeEnr], + bootstrapNodes: [...otherNodes], blockCheckIntervalMS: 1000, peerCheckIntervalMS: 1000, transactionProtocol: '', @@ -122,6 +135,7 @@ describe('Req Resp p2p client integration', () => { }; blockSource = new MockBlockSource(); + proofVerifier = alwaysTrueVerifier ? new AlwaysTrueCircuitVerifier() : new AlwaysFalseCircuitVerifier(); kvStore = openTmpStore(); const deps = { txPool: txPool as unknown as TxPool, @@ -150,8 +164,8 @@ describe('Req Resp p2p client integration', () => { }; // Shutdown all test clients - const shutdown = async (clients: P2PClient[], bootnode: BootstrapNode) => { - await Promise.all([bootnode.stop(), ...clients.map(client => client.stop())]); + const shutdown = async (clients: P2PClient[]) => { + await Promise.all([...clients.map(client => client.stop())]); await sleep(1000); }; @@ -160,8 +174,7 @@ describe('Req Resp p2p client integration', () => { async () => { // We want to create a set of nodes and request transaction from them // Not using a before each as a the wind down is not working as expected - const [bootstrapNode, bootstrapNodeEnr] = await makeBootstrapNode(); - const clients = await createClients(NUMBER_OF_PEERS, bootstrapNodeEnr); + const clients = await createClients(NUMBER_OF_PEERS); const [client1] = clients; await sleep(2000); @@ -173,7 +186,8 @@ describe('Req Resp p2p client integration', () => { const requestedTx = await client1.requestTxByHash(txHash); expect(requestedTx).toBeUndefined(); - await shutdown(clients, bootstrapNode); + // await shutdown(clients, bootstrapNode); + await shutdown(clients); }, TEST_TIMEOUT, ); @@ -182,8 +196,7 @@ describe('Req Resp p2p client integration', () => { 'Can request a transaction from another peer', async () => { // We want to create a set of nodes and request transaction from them - const [bootstrapNode, bootstrapNodeEnr] = await makeBootstrapNode(); - const clients = await createClients(NUMBER_OF_PEERS, bootstrapNodeEnr); + const clients = await createClients(NUMBER_OF_PEERS); const [client1] = clients; // Give the nodes time to discover each other @@ -200,7 +213,72 @@ describe('Req Resp p2p client integration', () => { // Expect the tx to be the returned tx to be the same as the one we mocked expect(requestedTx?.toBuffer()).toStrictEqual(tx.toBuffer()); - await shutdown(clients, bootstrapNode); + await shutdown(clients); + }, + TEST_TIMEOUT, + ); + + it( + 'Will penalize peers that send invalid proofs', + async () => { + // We want to create a set of nodes and request transaction from them + const clients = await createClients(NUMBER_OF_PEERS, /*valid proofs*/ false); + const [client1, client2] = clients; + const client2PeerId = (await client2.getEnr()?.peerId())!; + + // Give the nodes time to discover each other + await sleep(6000); + + const penalizePeerSpy = jest.spyOn((client1 as any).p2pService.peerManager, 'penalizePeer'); + + // Perform a get tx request from client 1 + const tx = mockTx(); + const txHash = tx.getTxHash(); + + // Return the correct tx with an invalid proof -> active attack + txPool.getTxByHash.mockImplementationOnce(() => tx); + + const requestedTx = await client1.requestTxByHash(txHash); + // Even though we got a response, the proof was deemed invalid + expect(requestedTx).toBeUndefined(); + + // Low tolerance error is due to the invalid proof + expect(penalizePeerSpy).toHaveBeenCalledWith(client2PeerId, PeerErrorSeverity.LowToleranceError); + + await shutdown(clients); + }, + TEST_TIMEOUT, + ); + + it( + 'Will penalize peers that send the wrong transaction', + async () => { + // We want to create a set of nodes and request transaction from them + const clients = await createClients(NUMBER_OF_PEERS, /*Valid proofs*/ true); + const [client1, client2] = clients; + const client2PeerId = (await client2.getEnr()?.peerId())!; + + // Give the nodes time to discover each other + await sleep(6000); + + const penalizePeerSpy = jest.spyOn((client1 as any).p2pService.peerManager, 'penalizePeer'); + + // Perform a get tx request from client 1 + const tx = mockTx(); + const txHash = tx.getTxHash(); + const tx2 = mockTx(420); + + // Return an invalid tx + txPool.getTxByHash.mockImplementationOnce(() => tx2); + + const requestedTx = await client1.requestTxByHash(txHash); + // Even though we got a response, the proof was deemed invalid + expect(requestedTx).toBeUndefined(); + + // Received wrong tx + expect(penalizePeerSpy).toHaveBeenCalledWith(client2PeerId, PeerErrorSeverity.MidToleranceError); + + await shutdown(clients); }, TEST_TIMEOUT, ); diff --git a/yarn-project/p2p/src/service/reqresp/reqresp.test.ts b/yarn-project/p2p/src/service/reqresp/reqresp.test.ts index e1c5f5ad102..1807a318522 100644 --- a/yarn-project/p2p/src/service/reqresp/reqresp.test.ts +++ b/yarn-project/p2p/src/service/reqresp/reqresp.test.ts @@ -5,9 +5,19 @@ import { describe, expect, it, jest } from '@jest/globals'; import { type MockProxy, mock } from 'jest-mock-extended'; import { CollectiveReqRespTimeoutError, IndiviualReqRespTimeoutError } from '../../errors/reqresp.error.js'; -import { MOCK_SUB_PROTOCOL_HANDLERS, connectToPeers, createNodes, startNodes, stopNodes } from '../../mocks/index.js'; +import { + MOCK_SUB_PROTOCOL_HANDLERS, + MOCK_SUB_PROTOCOL_VALIDATORS, + connectToPeers, + createNodes, + startNodes, + stopNodes, +} from '../../mocks/index.js'; import { type PeerManager } from '../peer_manager.js'; -import { PING_PROTOCOL, TX_REQ_PROTOCOL } from './interface.js'; +import { PeerErrorSeverity } from '../peer_scoring.js'; +import { PING_PROTOCOL, RequestableBuffer, TX_REQ_PROTOCOL } from './interface.js'; + +const PING_REQUEST = RequestableBuffer.fromBuffer(Buffer.from('ping')); // The Req Resp protocol should allow nodes to dial specific peers // and ask for specific data that they missed via the traditional gossip protocol. @@ -31,10 +41,10 @@ describe('ReqResp', () => { await sleep(500); - const res = await pinger.sendRequest(PING_PROTOCOL, Buffer.from('ping')); + const res = await pinger.sendRequest(PING_PROTOCOL, PING_REQUEST); await sleep(500); - expect(res?.toString('utf-8')).toEqual('pong'); + expect(res?.toBuffer().toString('utf-8')).toEqual('pong'); await stopNodes(nodes); }); @@ -53,7 +63,7 @@ describe('ReqResp', () => { void ponger.stop(); // It should return undefined if it cannot dial the peer - const res = await pinger.sendRequest(PING_PROTOCOL, Buffer.from('ping')); + const res = await pinger.sendRequest(PING_PROTOCOL, PING_REQUEST); expect(res).toBeUndefined(); @@ -73,9 +83,9 @@ describe('ReqResp', () => { void nodes[2].req.stop(); // send from the first node - const res = await nodes[0].req.sendRequest(PING_PROTOCOL, Buffer.from('ping')); + const res = await nodes[0].req.sendRequest(PING_PROTOCOL, PING_REQUEST); - expect(res?.toString('utf-8')).toEqual('pong'); + expect(res?.toBuffer().toString('utf-8')).toEqual('pong'); await stopNodes(nodes); }); @@ -125,8 +135,8 @@ describe('ReqResp', () => { await connectToPeers(nodes); await sleep(500); - const res = await nodes[0].req.sendRequest(TX_REQ_PROTOCOL, txHash.toBuffer()); - expect(res).toEqual(tx.toBuffer()); + const res = await nodes[0].req.sendRequest(TX_REQ_PROTOCOL, txHash); + expect(res).toEqual(tx); await stopNodes(nodes); }); @@ -148,7 +158,7 @@ describe('ReqResp', () => { await connectToPeers(nodes); await sleep(500); - const res = await nodes[0].req.sendRequest(TX_REQ_PROTOCOL, txHash.toBuffer()); + const res = await nodes[0].req.sendRequest(TX_REQ_PROTOCOL, txHash); expect(res).toBeUndefined(); await stopNodes(nodes); @@ -170,7 +180,8 @@ describe('ReqResp', () => { await connectToPeers(nodes); await sleep(500); - const res = await nodes[0].req.sendRequest(TX_REQ_PROTOCOL, Buffer.from('tx')); + const request = TxHash.random(); + const res = await nodes[0].req.sendRequest(TX_REQ_PROTOCOL, request); expect(res).toBeUndefined(); // Make sure the error message is logged @@ -179,6 +190,14 @@ describe('ReqResp', () => { } | peerId: ${nodes[1].p2p.peerId.toString()} | subProtocol: ${TX_REQ_PROTOCOL}`; expect(loggerSpy).toHaveBeenCalledWith(errorMessage); + // Expect the peer to be penalized for timing out + expect(peerManager.penalizePeer).toHaveBeenCalledWith( + expect.objectContaining({ + publicKey: nodes[1].p2p.peerId.publicKey, // must use objectContaining as we do not match exactly, as private key is contained in this test mapping + }), + PeerErrorSeverity.HighToleranceError, + ); + await stopNodes(nodes); }); @@ -200,7 +219,8 @@ describe('ReqResp', () => { await connectToPeers(nodes); await sleep(500); - const res = await nodes[0].req.sendRequest(TX_REQ_PROTOCOL, Buffer.from('tx')); + const request = TxHash.random(); + const res = await nodes[0].req.sendRequest(TX_REQ_PROTOCOL, request); expect(res).toBeUndefined(); // Make sure the error message is logged @@ -209,5 +229,47 @@ describe('ReqResp', () => { await stopNodes(nodes); }); + + it('Should penalize peer if transaction validation fails', async () => { + const tx = mockTx(); + const txHash = tx.getTxHash(); + + // Mock that the node will respond with the tx + const protocolHandlers = MOCK_SUB_PROTOCOL_HANDLERS; + protocolHandlers[TX_REQ_PROTOCOL] = (message: Buffer): Promise => { + const receivedHash = TxHash.fromBuffer(message); + if (txHash.equals(receivedHash)) { + return Promise.resolve(Uint8Array.from(tx.toBuffer())); + } + return Promise.resolve(Uint8Array.from(Buffer.from(''))); + }; + + // Mock that the receiving node will find that the transaction is invalid + const protocolValidators = MOCK_SUB_PROTOCOL_VALIDATORS; + protocolValidators[TX_REQ_PROTOCOL] = (_request, _response, peer) => { + peerManager.penalizePeer(peer, PeerErrorSeverity.LowToleranceError); + return Promise.resolve(false); + }; + + const nodes = await createNodes(peerManager, 2); + + await startNodes(nodes, protocolHandlers, protocolValidators); + await sleep(500); + await connectToPeers(nodes); + await sleep(500); + + const res = await nodes[0].req.sendRequest(TX_REQ_PROTOCOL, txHash); + expect(res).toBeUndefined(); + + // Expect the peer to be penalized for sending an invalid response + expect(peerManager.penalizePeer).toHaveBeenCalledWith( + expect.objectContaining({ + publicKey: nodes[1].p2p.peerId.publicKey, // must use objectContaining as we do not match exactly, as private key is contained in this test mapping + }), + PeerErrorSeverity.LowToleranceError, + ); + + await stopNodes(nodes); + }); }); }); diff --git a/yarn-project/p2p/src/service/reqresp/reqresp.ts b/yarn-project/p2p/src/service/reqresp/reqresp.ts index 39e5da0441f..41a9fd97d1e 100644 --- a/yarn-project/p2p/src/service/reqresp/reqresp.ts +++ b/yarn-project/p2p/src/service/reqresp/reqresp.ts @@ -9,11 +9,16 @@ import { type Uint8ArrayList } from 'uint8arraylist'; import { CollectiveReqRespTimeoutError, IndiviualReqRespTimeoutError } from '../../errors/reqresp.error.js'; import { type PeerManager } from '../peer_manager.js'; +import { PeerErrorSeverity } from '../peer_scoring.js'; import { type P2PReqRespConfig } from './config.js'; import { DEFAULT_SUB_PROTOCOL_HANDLERS, + DEFAULT_SUB_PROTOCOL_VALIDATORS, type ReqRespSubProtocol, type ReqRespSubProtocolHandlers, + type ReqRespSubProtocolValidators, + type SubProtocolMap, + subProtocolMap, } from './interface.js'; import { RequestResponseRateLimiter } from './rate_limiter/rate_limiter.js'; @@ -36,10 +41,13 @@ export class ReqResp { private overallRequestTimeoutMs: number; private individualRequestTimeoutMs: number; + // Warning, if the `start` function is not called as the parent class constructor, then the default sub protocol handlers will be used ( not good ) private subProtocolHandlers: ReqRespSubProtocolHandlers = DEFAULT_SUB_PROTOCOL_HANDLERS; + private subProtocolValidators: ReqRespSubProtocolValidators = DEFAULT_SUB_PROTOCOL_VALIDATORS; + private rateLimiter: RequestResponseRateLimiter; - constructor(config: P2PReqRespConfig, protected readonly libp2p: Libp2p, peerManager: PeerManager) { + constructor(config: P2PReqRespConfig, protected readonly libp2p: Libp2p, private peerManager: PeerManager) { this.logger = createDebugLogger('aztec:p2p:reqresp'); this.overallRequestTimeoutMs = config.overallRequestTimeoutMs; @@ -51,8 +59,10 @@ export class ReqResp { /** * Start the reqresp service */ - async start(subProtocolHandlers: ReqRespSubProtocolHandlers) { + async start(subProtocolHandlers: ReqRespSubProtocolHandlers, subProtocolValidators: ReqRespSubProtocolValidators) { this.subProtocolHandlers = subProtocolHandlers; + this.subProtocolValidators = subProtocolValidators; + // Register all protocol handlers for (const subProtocol of Object.keys(this.subProtocolHandlers)) { await this.libp2p.handle(subProtocol, this.streamHandler.bind(this, subProtocol as ReqRespSubProtocol)); @@ -77,29 +87,64 @@ export class ReqResp { * Send a request to peers, returns the first response * * @param subProtocol - The protocol being requested - * @param payload - The payload to send + * @param request - The request to send * @returns - The response from the peer, otherwise undefined + * + * @description + * This method attempts to send a request to all active peers using the specified sub-protocol. + * It opens a stream with each peer, sends the request, and awaits a response. + * If a valid response is received, it returns the response; otherwise, it continues to the next peer. + * If no response is received from any peer, it returns undefined. + * + * The method performs the following steps: + * - Iterates over all active peers. + * - Opens a stream with each peer using the specified sub-protocol. + * + * When a response is received, it is validated using the given sub protocols response validator. + * To see the interface for the response validator - see `interface.ts` + * + * Failing a response validation requests in a severe peer penalty, and will + * prompt the node to continue to search to the next peer. + * For example, a transaction request validator will check that the payload returned does in fact + * match the txHash that was requested. A peer that fails this check an only be an extremely naughty peer. + * + * This entire operation is wrapped in an overall timeout, that is independent of the + * peer it is requesting data from. + * */ - async sendRequest(subProtocol: ReqRespSubProtocol, payload: Buffer): Promise { + async sendRequest( + subProtocol: SubProtocol, + request: InstanceType, + ): Promise | undefined> { const requestFunction = async () => { + const responseValidator = this.subProtocolValidators[subProtocol]; + const requestBuffer = request.toBuffer(); + // Get active peers const peers = this.libp2p.getPeers(); // Attempt to ask all of our peers for (const peer of peers) { - const response = await this.sendRequestToPeer(peer, subProtocol, payload); + const response = await this.sendRequestToPeer(peer, subProtocol, requestBuffer); // If we get a response, return it, otherwise we iterate onto the next peer // We do not consider it a success if we have an empty buffer if (response && response.length > 0) { - return response; + const object = subProtocolMap[subProtocol].response.fromBuffer(response); + // The response validator handles peer punishment within + const isValid = await responseValidator(request, object, peer); + if (!isValid) { + this.logger.error(`Invalid response for ${subProtocol} from ${peer.toString()}`); + return undefined; + } + return object; } } return undefined; }; try { - return await executeTimeoutWithCustomError( + return await executeTimeoutWithCustomError | undefined>( requestFunction, this.overallRequestTimeoutMs, () => new CollectiveReqRespTimeoutError(), @@ -113,10 +158,26 @@ export class ReqResp { /** * Sends a request to a specific peer * + * We first dial a particular protocol for the peer, this ensures that the peer knows + * what to respond with + * + * * @param peerId - The peer to send the request to * @param subProtocol - The protocol to use to request * @param payload - The payload to send * @returns If the request is successful, the response is returned, otherwise undefined + * + * @description + * This method attempts to open a stream with the specified peer, send the payload, + * and await a response. + * If an error occurs, it penalizes the peer and returns undefined. + * + * The method performs the following steps: + * - Opens a stream with the peer using the specified sub-protocol. + * - Sends the payload and awaits a response with a timeout. + * + * If the stream is not closed by the dialled peer, and a timeout occurs, then + * the stream is closed on the requester's end and sender (us) updates its peer score */ async sendRequestToPeer( peerId: PeerId, @@ -129,6 +190,7 @@ export class ReqResp { this.logger.debug(`Stream opened with ${peerId.toString()} for ${subProtocol}`); + // Open the stream with a timeout const result = await executeTimeoutWithCustomError( (): Promise => pipe([payload], stream!, this.readMessage), this.individualRequestTimeoutMs, @@ -141,6 +203,7 @@ export class ReqResp { return result; } catch (e: any) { this.logger.error(`${e.message} | peerId: ${peerId.toString()} | subProtocol: ${subProtocol}`); + this.peerManager.penalizePeer(peerId, PeerErrorSeverity.HighToleranceError); } finally { if (stream) { try { @@ -173,6 +236,16 @@ export class ReqResp { * Reads the incoming stream, determines the protocol, then triggers the appropriate handler * * @param param0 - The incoming stream data + * + * @description + * An individual stream handler will be bound to each sub protocol, and handles returning data back + * to the requesting peer. + * + * The sub protocol handler interface is defined within `interface.ts` and will be assigned to the + * req resp service on start up. + * + * We check rate limits for each peer, note the peer will be penalised within the rate limiter implementation + * if they exceed their peer specific limits. */ private async streamHandler(protocol: ReqRespSubProtocol, { stream, connection }: IncomingStreamData) { // Store a reference to from this for the async generator diff --git a/yarn-project/p2p/src/tx_validator/aggregate_tx_validator.test.ts b/yarn-project/p2p/src/tx_validator/aggregate_tx_validator.test.ts index 1eb7558bf25..c74a0fc5e16 100644 --- a/yarn-project/p2p/src/tx_validator/aggregate_tx_validator.test.ts +++ b/yarn-project/p2p/src/tx_validator/aggregate_tx_validator.test.ts @@ -25,5 +25,9 @@ describe('AggregateTxValidator', () => { txs.filter(tx => this.denyList.has(Tx.getHash(tx).toString())), ]); } + + validateTx(tx: AnyTx): Promise { + return Promise.resolve(this.denyList.has(Tx.getHash(tx).toString())); + } } }); diff --git a/yarn-project/p2p/src/tx_validator/aggregate_tx_validator.ts b/yarn-project/p2p/src/tx_validator/aggregate_tx_validator.ts index 8397a45eede..99dfb6c282d 100644 --- a/yarn-project/p2p/src/tx_validator/aggregate_tx_validator.ts +++ b/yarn-project/p2p/src/tx_validator/aggregate_tx_validator.ts @@ -21,4 +21,14 @@ export class AggregateTxValidator implements TxValid return [txPool, invalidTxs]; } + + async validateTx(tx: T): Promise { + for (const validator of this.#validators) { + const valid = await validator.validateTx(tx); + if (!valid) { + return false; + } + } + return true; + } } diff --git a/yarn-project/p2p/src/tx_validator/data_validator.ts b/yarn-project/p2p/src/tx_validator/data_validator.ts index 39731cbc08d..f284f4638ce 100644 --- a/yarn-project/p2p/src/tx_validator/data_validator.ts +++ b/yarn-project/p2p/src/tx_validator/data_validator.ts @@ -19,6 +19,10 @@ export class DataTxValidator implements TxValidator { return Promise.resolve([validTxs, invalidTxs]); } + validateTx(tx: Tx): Promise { + return Promise.resolve(this.#hasCorrectExecutionRequests(tx)); + } + #hasCorrectExecutionRequests(tx: Tx): boolean { const callRequests = [ ...tx.data.getRevertiblePublicCallRequests(), diff --git a/yarn-project/p2p/src/tx_validator/double_spend_validator.ts b/yarn-project/p2p/src/tx_validator/double_spend_validator.ts index e7a8e065e5a..15ba9f76d27 100644 --- a/yarn-project/p2p/src/tx_validator/double_spend_validator.ts +++ b/yarn-project/p2p/src/tx_validator/double_spend_validator.ts @@ -31,6 +31,10 @@ export class DoubleSpendTxValidator implements TxValidator { return [validTxs, invalidTxs]; } + validateTx(tx: T): Promise { + return this.#uniqueNullifiers(tx, new Set()); + } + async #uniqueNullifiers(tx: AnyTx, thisBlockNullifiers: Set): Promise { const nullifiers = tx.data.getNonEmptyNullifiers().map(x => x.toBigInt()); diff --git a/yarn-project/p2p/src/tx_validator/metadata_validator.ts b/yarn-project/p2p/src/tx_validator/metadata_validator.ts index 145378e21e4..995e61cdb00 100644 --- a/yarn-project/p2p/src/tx_validator/metadata_validator.ts +++ b/yarn-project/p2p/src/tx_validator/metadata_validator.ts @@ -27,6 +27,10 @@ export class MetadataTxValidator implements TxValidator { return Promise.resolve([validTxs, invalidTxs]); } + validateTx(tx: T): Promise { + return Promise.resolve(this.#hasCorrectChainId(tx) && this.#isValidForBlockNumber(tx)); + } + #hasCorrectChainId(tx: T): boolean { if (!tx.data.constants.txContext.chainId.equals(this.chainId)) { this.#log.warn( diff --git a/yarn-project/sequencer-client/src/tx_validator/gas_validator.ts b/yarn-project/sequencer-client/src/tx_validator/gas_validator.ts index ea93d51cafa..0a9f1f9e82a 100644 --- a/yarn-project/sequencer-client/src/tx_validator/gas_validator.ts +++ b/yarn-project/sequencer-client/src/tx_validator/gas_validator.ts @@ -34,6 +34,10 @@ export class GasTxValidator implements TxValidator { return [validTxs, invalidTxs]; } + validateTx(tx: Tx): Promise { + return this.#validateTxFee(tx); + } + async #validateTxFee(tx: Tx): Promise { const feePayer = tx.data.feePayer; // TODO(@spalladino) Eventually remove the is_zero condition as we should always charge fees to every tx diff --git a/yarn-project/sequencer-client/src/tx_validator/phases_validator.ts b/yarn-project/sequencer-client/src/tx_validator/phases_validator.ts index 345fd066ef3..7fa0aaaf067 100644 --- a/yarn-project/sequencer-client/src/tx_validator/phases_validator.ts +++ b/yarn-project/sequencer-client/src/tx_validator/phases_validator.ts @@ -27,7 +27,7 @@ export class PhasesTxValidator implements TxValidator { // which is what we're trying to do as part of the current txs. await this.contractDataSource.addNewContracts(tx); - if (await this.#validateTx(tx)) { + if (await this.validateTx(tx)) { validTxs.push(tx); } else { invalidTxs.push(tx); @@ -39,7 +39,7 @@ export class PhasesTxValidator implements TxValidator { return Promise.resolve([validTxs, invalidTxs]); } - async #validateTx(tx: Tx): Promise { + async validateTx(tx: Tx): Promise { if (!tx.data.forPublic) { this.#log.debug(`Tx ${Tx.getHash(tx)} does not contain enqueued public functions. Skipping phases validation.`); return true;