From b82e70c61771c8a3cef4026dc522f2c99147180b Mon Sep 17 00:00:00 2001 From: Facundo Date: Fri, 26 Jan 2024 20:45:17 +0000 Subject: [PATCH] fix(avm): fix usage of Fr with tagged memory (#4240) Some calls to TaggedMemory were using `Fr` since it respected the defined interface. This was too lax, we really only want specific types in memory. We may allow conversion, but we cannot allow just any type that satisfies the interface. Therefore I'm changing this to specific classes. Ref #4213. --- .../src/avm/avm_memory_types.ts | 57 ++++++++++--------- .../src/avm/opcodes/external_calls.test.ts | 12 ++-- .../src/avm/opcodes/external_calls.ts | 10 ++-- .../src/avm/opcodes/storage.test.ts | 2 +- .../acir-simulator/src/avm/opcodes/storage.ts | 8 ++- 5 files changed, 49 insertions(+), 40 deletions(-) diff --git a/yarn-project/acir-simulator/src/avm/avm_memory_types.ts b/yarn-project/acir-simulator/src/avm/avm_memory_types.ts index 69ed95c4a66..8d6d240c922 100644 --- a/yarn-project/acir-simulator/src/avm/avm_memory_types.ts +++ b/yarn-project/acir-simulator/src/avm/avm_memory_types.ts @@ -2,31 +2,35 @@ import { Fr } from '@aztec/foundation/fields'; import { strict as assert } from 'assert'; -export interface MemoryValue { - add(rhs: MemoryValue): MemoryValue; - sub(rhs: MemoryValue): MemoryValue; - mul(rhs: MemoryValue): MemoryValue; - div(rhs: MemoryValue): MemoryValue; +export abstract class MemoryValue { + public abstract add(rhs: MemoryValue): MemoryValue; + public abstract sub(rhs: MemoryValue): MemoryValue; + public abstract mul(rhs: MemoryValue): MemoryValue; + public abstract div(rhs: MemoryValue): MemoryValue; + + // We need this to be able to build an instance of the subclasses. + public abstract build(n: bigint): MemoryValue; // Use sparingly. - toBigInt(): bigint; + public abstract toBigInt(): bigint; } -export interface IntegralValue extends MemoryValue { - shl(rhs: IntegralValue): IntegralValue; - shr(rhs: IntegralValue): IntegralValue; - and(rhs: IntegralValue): IntegralValue; - or(rhs: IntegralValue): IntegralValue; - xor(rhs: IntegralValue): IntegralValue; - not(): IntegralValue; +export abstract class IntegralValue extends MemoryValue { + public abstract shl(rhs: IntegralValue): IntegralValue; + public abstract shr(rhs: IntegralValue): IntegralValue; + public abstract and(rhs: IntegralValue): IntegralValue; + public abstract or(rhs: IntegralValue): IntegralValue; + public abstract xor(rhs: IntegralValue): IntegralValue; + public abstract not(): IntegralValue; } // TODO: Optimize calculation of mod, etc. Can only do once per class? -abstract class UnsignedInteger implements IntegralValue { +abstract class UnsignedInteger extends IntegralValue { private readonly bitmask: bigint; private readonly mod: bigint; protected constructor(private n: bigint, private bits: bigint) { + super(); assert(bits > 0); // x % 2^n == x & (2^n - 1) this.mod = 1n << bits; @@ -34,9 +38,7 @@ abstract class UnsignedInteger implements IntegralValue { assert(n < this.mod); } - // We need this to be able to build an instance of the subclass - // and not of type UnsignedInteger. - protected abstract build(n: bigint): UnsignedInteger; + public abstract build(n: bigint): UnsignedInteger; public add(rhs: UnsignedInteger): UnsignedInteger { assert(this.bits == rhs.bits); @@ -93,10 +95,6 @@ abstract class UnsignedInteger implements IntegralValue { public toBigInt(): bigint { return this.n; } - - public equals(rhs: UnsignedInteger) { - return this.bits == rhs.bits && this.toBigInt() == rhs.toBigInt(); - } } export class Uint8 extends UnsignedInteger { @@ -104,7 +102,7 @@ export class Uint8 extends UnsignedInteger { super(BigInt(n), 8n); } - protected build(n: bigint): Uint8 { + public build(n: bigint): Uint8 { return new Uint8(n); } } @@ -114,7 +112,7 @@ export class Uint16 extends UnsignedInteger { super(BigInt(n), 16n); } - protected build(n: bigint): Uint16 { + public build(n: bigint): Uint16 { return new Uint16(n); } } @@ -124,7 +122,7 @@ export class Uint32 extends UnsignedInteger { super(BigInt(n), 32n); } - protected build(n: bigint): Uint32 { + public build(n: bigint): Uint32 { return new Uint32(n); } } @@ -134,7 +132,7 @@ export class Uint64 extends UnsignedInteger { super(BigInt(n), 64n); } - protected build(n: bigint): Uint64 { + public build(n: bigint): Uint64 { return new Uint64(n); } } @@ -144,19 +142,24 @@ export class Uint128 extends UnsignedInteger { super(BigInt(n), 128n); } - protected build(n: bigint): Uint128 { + public build(n: bigint): Uint128 { return new Uint128(n); } } -export class Field implements MemoryValue { +export class Field extends MemoryValue { public static readonly MODULUS: bigint = Fr.MODULUS; private readonly rep: Fr; constructor(v: number | bigint | Fr) { + super(); this.rep = new Fr(v); } + public build(n: bigint): Field { + return new Field(n); + } + public add(rhs: Field): Field { return new Field(this.rep.add(rhs.rep)); } diff --git a/yarn-project/acir-simulator/src/avm/opcodes/external_calls.test.ts b/yarn-project/acir-simulator/src/avm/opcodes/external_calls.test.ts index 7a2f74df0e4..8f6fd076038 100644 --- a/yarn-project/acir-simulator/src/avm/opcodes/external_calls.test.ts +++ b/yarn-project/acir-simulator/src/avm/opcodes/external_calls.test.ts @@ -40,7 +40,7 @@ describe('External Calls', () => { const addr = new Fr(123456n); const argsOffset = 2; - const args = [new Fr(1n), new Fr(2n), new Fr(3n)]; + const args = [new Field(1n), new Field(2n), new Field(3n)]; const argsSize = args.length; const retOffset = 8; @@ -48,8 +48,8 @@ describe('External Calls', () => { const successOffset = 7; - machineState.memory.set(0, gas); - machineState.memory.set(1, addr); + machineState.memory.set(0, new Field(gas)); + machineState.memory.set(1, new Field(addr)); machineState.memory.setSlice(2, args); const otherContextInstructions: [Opcode, any[]][] = [ @@ -72,10 +72,10 @@ describe('External Calls', () => { await instruction.execute(machineState, journal); const successValue = machineState.memory.get(successOffset); - expect(successValue).toEqual(new Fr(1n)); + expect(successValue).toEqual(new Field(1n)); const retValue = machineState.memory.getSlice(retOffset, retSize); - expect(retValue).toEqual([new Fr(1n), new Fr(2n)]); + expect(retValue).toEqual([new Field(1n), new Field(2n)]); // Check that the storage call has been merged into the parent journal const { storageWrites } = journal.flush(); @@ -126,7 +126,7 @@ describe('External Calls', () => { // No revert has occurred, but the nested execution has failed const successValue = machineState.memory.get(successOffset); - expect(successValue).toEqual(new Fr(0n)); + expect(successValue).toEqual(new Field(0n)); }); }); }); diff --git a/yarn-project/acir-simulator/src/avm/opcodes/external_calls.ts b/yarn-project/acir-simulator/src/avm/opcodes/external_calls.ts index fc98f6c780c..280b1284f02 100644 --- a/yarn-project/acir-simulator/src/avm/opcodes/external_calls.ts +++ b/yarn-project/acir-simulator/src/avm/opcodes/external_calls.ts @@ -39,10 +39,11 @@ export class Call extends Instruction { // We only take as much data as was specified in the return size -> TODO: should we be reverting here const returnData = returnObject.output.slice(0, this.retSize); + const convertedReturnData = returnData.map(f => new Field(f)); // Write our return data into memory - machineState.memory.set(this.successOffset, new Fr(success)); - machineState.memory.setSlice(this.retOffset, returnData); + machineState.memory.set(this.successOffset, new Field(success ? 1 : 0)); + machineState.memory.setSlice(this.retOffset, convertedReturnData); if (success) { avmContext.mergeJournal(); @@ -84,10 +85,11 @@ export class StaticCall extends Instruction { // We only take as much data as was specified in the return size -> TODO: should we be reverting here const returnData = returnObject.output.slice(0, this.retSize); + const convertedReturnData = returnData.map(f => new Field(f)); // Write our return data into memory - machineState.memory.set(this.successOffset, new Fr(success)); - machineState.memory.setSlice(this.retOffset, returnData); + machineState.memory.set(this.successOffset, new Field(success ? 1 : 0)); + machineState.memory.setSlice(this.retOffset, convertedReturnData); if (success) { avmContext.mergeJournal(); diff --git a/yarn-project/acir-simulator/src/avm/opcodes/storage.test.ts b/yarn-project/acir-simulator/src/avm/opcodes/storage.test.ts index 47bc4433862..e62b0abc561 100644 --- a/yarn-project/acir-simulator/src/avm/opcodes/storage.test.ts +++ b/yarn-project/acir-simulator/src/avm/opcodes/storage.test.ts @@ -63,6 +63,6 @@ describe('Storage Instructions', () => { expect(journal.readStorage).toBeCalledWith(address, new Fr(a.toBigInt())); const actual = machineState.memory.get(1); - expect(actual).toEqual(expectedResult); + expect(actual).toEqual(new Field(expectedResult)); }); }); diff --git a/yarn-project/acir-simulator/src/avm/opcodes/storage.ts b/yarn-project/acir-simulator/src/avm/opcodes/storage.ts index e8f078f7909..419cbecc1da 100644 --- a/yarn-project/acir-simulator/src/avm/opcodes/storage.ts +++ b/yarn-project/acir-simulator/src/avm/opcodes/storage.ts @@ -1,6 +1,7 @@ import { Fr } from '@aztec/foundation/fields'; import { AvmMachineState } from '../avm_machine_state.js'; +import { Field } from '../avm_memory_types.js'; import { AvmInterpreterError } from '../interpreter/interpreter.js'; import { AvmJournal } from '../journal/journal.js'; import { Instruction } from './instruction.js'; @@ -44,9 +45,12 @@ export class SLoad extends Instruction { async execute(machineState: AvmMachineState, journal: AvmJournal): Promise { const slot = machineState.memory.get(this.slotOffset); - const data = journal.readStorage(machineState.executionEnvironment.storageAddress, new Fr(slot.toBigInt())); + const data: Fr = await journal.readStorage( + machineState.executionEnvironment.storageAddress, + new Fr(slot.toBigInt()), + ); - machineState.memory.set(this.destOffset, await data); + machineState.memory.set(this.destOffset, new Field(data)); this.incrementPc(machineState); }