Skip to content

Commit

Permalink
fix(avm): fix usage of Fr with tagged memory (AztecProtocol#4240)
Browse files Browse the repository at this point in the history
Some calls to TaggedMemory were using `Fr` since it respected the
defined interface. This was too lax, we really only want specific types
in memory. We may allow conversion, but we cannot allow just any type
that satisfies the interface. Therefore I'm changing this to specific
classes.

Ref AztecProtocol#4213.
  • Loading branch information
fcarreiro authored Jan 26, 2024
1 parent 11f400f commit b82e70c
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 40 deletions.
57 changes: 30 additions & 27 deletions yarn-project/acir-simulator/src/avm/avm_memory_types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,41 +2,43 @@ import { Fr } from '@aztec/foundation/fields';

import { strict as assert } from 'assert';

export interface MemoryValue {
add(rhs: MemoryValue): MemoryValue;
sub(rhs: MemoryValue): MemoryValue;
mul(rhs: MemoryValue): MemoryValue;
div(rhs: MemoryValue): MemoryValue;
export abstract class MemoryValue {
public abstract add(rhs: MemoryValue): MemoryValue;
public abstract sub(rhs: MemoryValue): MemoryValue;
public abstract mul(rhs: MemoryValue): MemoryValue;
public abstract div(rhs: MemoryValue): MemoryValue;

// We need this to be able to build an instance of the subclasses.
public abstract build(n: bigint): MemoryValue;

// Use sparingly.
toBigInt(): bigint;
public abstract toBigInt(): bigint;
}

export interface IntegralValue extends MemoryValue {
shl(rhs: IntegralValue): IntegralValue;
shr(rhs: IntegralValue): IntegralValue;
and(rhs: IntegralValue): IntegralValue;
or(rhs: IntegralValue): IntegralValue;
xor(rhs: IntegralValue): IntegralValue;
not(): IntegralValue;
export abstract class IntegralValue extends MemoryValue {
public abstract shl(rhs: IntegralValue): IntegralValue;
public abstract shr(rhs: IntegralValue): IntegralValue;
public abstract and(rhs: IntegralValue): IntegralValue;
public abstract or(rhs: IntegralValue): IntegralValue;
public abstract xor(rhs: IntegralValue): IntegralValue;
public abstract not(): IntegralValue;
}

// TODO: Optimize calculation of mod, etc. Can only do once per class?
abstract class UnsignedInteger implements IntegralValue {
abstract class UnsignedInteger extends IntegralValue {
private readonly bitmask: bigint;
private readonly mod: bigint;

protected constructor(private n: bigint, private bits: bigint) {
super();
assert(bits > 0);
// x % 2^n == x & (2^n - 1)
this.mod = 1n << bits;
this.bitmask = this.mod - 1n;
assert(n < this.mod);
}

// We need this to be able to build an instance of the subclass
// and not of type UnsignedInteger.
protected abstract build(n: bigint): UnsignedInteger;
public abstract build(n: bigint): UnsignedInteger;

public add(rhs: UnsignedInteger): UnsignedInteger {
assert(this.bits == rhs.bits);
Expand Down Expand Up @@ -93,18 +95,14 @@ abstract class UnsignedInteger implements IntegralValue {
public toBigInt(): bigint {
return this.n;
}

public equals(rhs: UnsignedInteger) {
return this.bits == rhs.bits && this.toBigInt() == rhs.toBigInt();
}
}

export class Uint8 extends UnsignedInteger {
constructor(n: number | bigint) {
super(BigInt(n), 8n);
}

protected build(n: bigint): Uint8 {
public build(n: bigint): Uint8 {
return new Uint8(n);
}
}
Expand All @@ -114,7 +112,7 @@ export class Uint16 extends UnsignedInteger {
super(BigInt(n), 16n);
}

protected build(n: bigint): Uint16 {
public build(n: bigint): Uint16 {
return new Uint16(n);
}
}
Expand All @@ -124,7 +122,7 @@ export class Uint32 extends UnsignedInteger {
super(BigInt(n), 32n);
}

protected build(n: bigint): Uint32 {
public build(n: bigint): Uint32 {
return new Uint32(n);
}
}
Expand All @@ -134,7 +132,7 @@ export class Uint64 extends UnsignedInteger {
super(BigInt(n), 64n);
}

protected build(n: bigint): Uint64 {
public build(n: bigint): Uint64 {
return new Uint64(n);
}
}
Expand All @@ -144,19 +142,24 @@ export class Uint128 extends UnsignedInteger {
super(BigInt(n), 128n);
}

protected build(n: bigint): Uint128 {
public build(n: bigint): Uint128 {
return new Uint128(n);
}
}

export class Field implements MemoryValue {
export class Field extends MemoryValue {
public static readonly MODULUS: bigint = Fr.MODULUS;
private readonly rep: Fr;

constructor(v: number | bigint | Fr) {
super();
this.rep = new Fr(v);
}

public build(n: bigint): Field {
return new Field(n);
}

public add(rhs: Field): Field {
return new Field(this.rep.add(rhs.rep));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,16 @@ describe('External Calls', () => {
const addr = new Fr(123456n);

const argsOffset = 2;
const args = [new Fr(1n), new Fr(2n), new Fr(3n)];
const args = [new Field(1n), new Field(2n), new Field(3n)];
const argsSize = args.length;

const retOffset = 8;
const retSize = 2;

const successOffset = 7;

machineState.memory.set(0, gas);
machineState.memory.set(1, addr);
machineState.memory.set(0, new Field(gas));
machineState.memory.set(1, new Field(addr));
machineState.memory.setSlice(2, args);

const otherContextInstructions: [Opcode, any[]][] = [
Expand All @@ -72,10 +72,10 @@ describe('External Calls', () => {
await instruction.execute(machineState, journal);

const successValue = machineState.memory.get(successOffset);
expect(successValue).toEqual(new Fr(1n));
expect(successValue).toEqual(new Field(1n));

const retValue = machineState.memory.getSlice(retOffset, retSize);
expect(retValue).toEqual([new Fr(1n), new Fr(2n)]);
expect(retValue).toEqual([new Field(1n), new Field(2n)]);

// Check that the storage call has been merged into the parent journal
const { storageWrites } = journal.flush();
Expand Down Expand Up @@ -126,7 +126,7 @@ describe('External Calls', () => {

// No revert has occurred, but the nested execution has failed
const successValue = machineState.memory.get(successOffset);
expect(successValue).toEqual(new Fr(0n));
expect(successValue).toEqual(new Field(0n));
});
});
});
10 changes: 6 additions & 4 deletions yarn-project/acir-simulator/src/avm/opcodes/external_calls.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,11 @@ export class Call extends Instruction {

// We only take as much data as was specified in the return size -> TODO: should we be reverting here
const returnData = returnObject.output.slice(0, this.retSize);
const convertedReturnData = returnData.map(f => new Field(f));

// Write our return data into memory
machineState.memory.set(this.successOffset, new Fr(success));
machineState.memory.setSlice(this.retOffset, returnData);
machineState.memory.set(this.successOffset, new Field(success ? 1 : 0));
machineState.memory.setSlice(this.retOffset, convertedReturnData);

if (success) {
avmContext.mergeJournal();
Expand Down Expand Up @@ -84,10 +85,11 @@ export class StaticCall extends Instruction {

// We only take as much data as was specified in the return size -> TODO: should we be reverting here
const returnData = returnObject.output.slice(0, this.retSize);
const convertedReturnData = returnData.map(f => new Field(f));

// Write our return data into memory
machineState.memory.set(this.successOffset, new Fr(success));
machineState.memory.setSlice(this.retOffset, returnData);
machineState.memory.set(this.successOffset, new Field(success ? 1 : 0));
machineState.memory.setSlice(this.retOffset, convertedReturnData);

if (success) {
avmContext.mergeJournal();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,6 @@ describe('Storage Instructions', () => {
expect(journal.readStorage).toBeCalledWith(address, new Fr(a.toBigInt()));

const actual = machineState.memory.get(1);
expect(actual).toEqual(expectedResult);
expect(actual).toEqual(new Field(expectedResult));
});
});
8 changes: 6 additions & 2 deletions yarn-project/acir-simulator/src/avm/opcodes/storage.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { Fr } from '@aztec/foundation/fields';

import { AvmMachineState } from '../avm_machine_state.js';
import { Field } from '../avm_memory_types.js';
import { AvmInterpreterError } from '../interpreter/interpreter.js';
import { AvmJournal } from '../journal/journal.js';
import { Instruction } from './instruction.js';
Expand Down Expand Up @@ -44,9 +45,12 @@ export class SLoad extends Instruction {
async execute(machineState: AvmMachineState, journal: AvmJournal): Promise<void> {
const slot = machineState.memory.get(this.slotOffset);

const data = journal.readStorage(machineState.executionEnvironment.storageAddress, new Fr(slot.toBigInt()));
const data: Fr = await journal.readStorage(
machineState.executionEnvironment.storageAddress,
new Fr(slot.toBigInt()),
);

machineState.memory.set(this.destOffset, await data);
machineState.memory.set(this.destOffset, new Field(data));

this.incrementPc(machineState);
}
Expand Down

0 comments on commit b82e70c

Please sign in to comment.