From 1c91980d37e3f35bb21a5aeb76feb085a8346d5f Mon Sep 17 00:00:00 2001 From: Facundo Date: Mon, 12 Feb 2024 14:54:48 +0000 Subject: [PATCH] chore(avm-simulator): reduce boilerplate in AVM memory types (#4542) Uses a mixin factory to create the unsigned integral types. This also has the benefit that the mod, bitmask, etc, are created just once per type, and not once per instance. --- .../src/avm/avm_memory_types.test.ts | 182 ++++++++--------- .../simulator/src/avm/avm_memory_types.ts | 191 +++++++----------- 2 files changed, 165 insertions(+), 208 deletions(-) diff --git a/yarn-project/simulator/src/avm/avm_memory_types.test.ts b/yarn-project/simulator/src/avm/avm_memory_types.test.ts index cdb581467de..62f0aff7451 100644 --- a/yarn-project/simulator/src/avm/avm_memory_types.test.ts +++ b/yarn-project/simulator/src/avm/avm_memory_types.test.ts @@ -39,96 +39,98 @@ describe('TaggedMemory', () => { type IntegralClass = typeof Uint8 | typeof Uint16 | typeof Uint32 | typeof Uint64 | typeof Uint128; describe.each([Uint8, Uint16, Uint32, Uint64, Uint128])('Integral Types', (clsValue: IntegralClass) => { - it(`Should construct a new ${clsValue.name} from a number`, () => { - const x = new clsValue(5); - expect(x.toBigInt()).toStrictEqual(5n); - }); - - it(`Should construct a new ${clsValue.name} from a bigint`, () => { - const x = new clsValue(5n); - expect(x.toBigInt()).toStrictEqual(5n); - }); - - it(`Should build a new ${clsValue.name}`, () => { - const x = new clsValue(5); - const newX = x.build(10n); - expect(newX).toStrictEqual(new clsValue(10n)); - }); - - it(`Should add two ${clsValue.name} correctly`, () => { - const a = new clsValue(5); - const b = new clsValue(10); - const result = a.add(b); - expect(result).toStrictEqual(new clsValue(15n)); - }); - - it(`Should subtract two ${clsValue.name} correctly`, () => { - const a = new clsValue(10); - const b = new clsValue(5); - const result = a.sub(b); - expect(result).toStrictEqual(new clsValue(5n)); - }); - - it(`Should multiply two ${clsValue.name} correctly`, () => { - const a = new clsValue(5); - const b = new clsValue(10); - const result = a.mul(b); - expect(result).toStrictEqual(new clsValue(50n)); - }); - - it(`Should divide two ${clsValue.name} correctly`, () => { - const a = new clsValue(10); - const b = new clsValue(5); - const result = a.div(b); - expect(result).toStrictEqual(new clsValue(2n)); - }); - - it('Should shift right ${clsValue.name} correctly', () => { - const uintA = new clsValue(10); - const result = uintA.shr(new clsValue(1n)); - expect(result).toEqual(new clsValue(5n)); - }); - - it('Should shift left ${clsValue.name} correctly', () => { - const uintA = new clsValue(10); - const result = uintA.shl(new clsValue(1n)); - expect(result).toEqual(new clsValue(20n)); - }); - - it('Should and two ${clsValue.name} correctly', () => { - const uintA = new clsValue(10); - const uintB = new clsValue(5); - const result = uintA.and(uintB); - expect(result).toEqual(new clsValue(0n)); - }); - - it('Should or two ${clsValue.name} correctly', () => { - const uintA = new clsValue(10); - const uintB = new clsValue(5); - const result = uintA.or(uintB); - expect(result).toEqual(new clsValue(15n)); - }); - - it('Should xor two ${clsValue.name} correctly', () => { - const uintA = new clsValue(10); - const uintB = new clsValue(5); - const result = uintA.xor(uintB); - expect(result).toEqual(new clsValue(15n)); - }); - - it(`Should check equality of two ${clsValue.name} correctly`, () => { - const a = new clsValue(5); - const b = new clsValue(5); - const c = new clsValue(10); - expect(a.equals(b)).toBe(true); - expect(a.equals(c)).toBe(false); - }); - - it(`Should check if one ${clsValue.name} is less than another correctly`, () => { - const a = new clsValue(5); - const b = new clsValue(10); - expect(a.lt(b)).toBe(true); - expect(b.lt(a)).toBe(false); + describe(`${clsValue.name}`, () => { + it(`Should construct a new ${clsValue.name} from a number`, () => { + const x = new clsValue(5); + expect(x.toBigInt()).toStrictEqual(5n); + }); + + it(`Should construct a new ${clsValue.name} from a bigint`, () => { + const x = new clsValue(5n); + expect(x.toBigInt()).toStrictEqual(5n); + }); + + it(`Should build a new ${clsValue.name}`, () => { + const x = new clsValue(5); + const newX = x.build(10n); + expect(newX).toStrictEqual(new clsValue(10n)); + }); + + it(`Should add two ${clsValue.name} correctly`, () => { + const a = new clsValue(5); + const b = new clsValue(10); + const result = a.add(b); + expect(result).toStrictEqual(new clsValue(15n)); + }); + + it(`Should subtract two ${clsValue.name} correctly`, () => { + const a = new clsValue(10); + const b = new clsValue(5); + const result = a.sub(b); + expect(result).toStrictEqual(new clsValue(5n)); + }); + + it(`Should multiply two ${clsValue.name} correctly`, () => { + const a = new clsValue(5); + const b = new clsValue(10); + const result = a.mul(b); + expect(result).toStrictEqual(new clsValue(50n)); + }); + + it(`Should divide two ${clsValue.name} correctly`, () => { + const a = new clsValue(10); + const b = new clsValue(5); + const result = a.div(b); + expect(result).toStrictEqual(new clsValue(2n)); + }); + + it(`Should shift right ${clsValue.name} correctly`, () => { + const uintA = new clsValue(10); + const result = uintA.shr(new clsValue(1n)); + expect(result).toEqual(new clsValue(5n)); + }); + + it(`Should shift left ${clsValue.name} correctly`, () => { + const uintA = new clsValue(10); + const result = uintA.shl(new clsValue(1n)); + expect(result).toEqual(new clsValue(20n)); + }); + + it(`Should and two ${clsValue.name} correctly`, () => { + const uintA = new clsValue(10); + const uintB = new clsValue(5); + const result = uintA.and(uintB); + expect(result).toEqual(new clsValue(0n)); + }); + + it(`Should or two ${clsValue.name} correctly`, () => { + const uintA = new clsValue(10); + const uintB = new clsValue(5); + const result = uintA.or(uintB); + expect(result).toEqual(new clsValue(15n)); + }); + + it(`Should xor two ${clsValue.name} correctly`, () => { + const uintA = new clsValue(10); + const uintB = new clsValue(5); + const result = uintA.xor(uintB); + expect(result).toEqual(new clsValue(15n)); + }); + + it(`Should check equality of two ${clsValue.name} correctly`, () => { + const a = new clsValue(5); + const b = new clsValue(5); + const c = new clsValue(10); + expect(a.equals(b)).toBe(true); + expect(a.equals(c)).toBe(false); + }); + + it(`Should check if one ${clsValue.name} is less than another correctly`, () => { + const a = new clsValue(5); + const b = new clsValue(10); + expect(a.lt(b)).toBe(true); + expect(b.lt(a)).toBe(false); + }); }); }); diff --git a/yarn-project/simulator/src/avm/avm_memory_types.ts b/yarn-project/simulator/src/avm/avm_memory_types.ts index f4fb0b80e9c..49546d41ec9 100644 --- a/yarn-project/simulator/src/avm/avm_memory_types.ts +++ b/yarn-project/simulator/src/avm/avm_memory_types.ts @@ -4,6 +4,7 @@ import { strict as assert } from 'assert'; import { TagCheckError } from './errors.js'; +/** MemoryValue gathers the common operations for all memory types. */ export abstract class MemoryValue { public abstract add(rhs: MemoryValue): MemoryValue; public abstract sub(rhs: MemoryValue): MemoryValue; @@ -24,6 +25,7 @@ export abstract class MemoryValue { } } +/** IntegralValue gathers the common operations for all integral memory types. */ export abstract class IntegralValue extends MemoryValue { public abstract shl(rhs: IntegralValue): IntegralValue; public abstract shr(rhs: IntegralValue): IntegralValue; @@ -32,141 +34,94 @@ export abstract class IntegralValue extends MemoryValue { public abstract xor(rhs: IntegralValue): IntegralValue; public abstract not(): IntegralValue; - public abstract lt(rhs: MemoryValue): boolean; + public abstract lt(rhs: IntegralValue): boolean; } -// TODO: Optimize calculation of mod, etc. Can only do once per class? -abstract class UnsignedInteger extends IntegralValue { - private readonly bitmask: bigint; - private readonly mod: bigint; - - protected constructor(private n: bigint, private bits: bigint) { - super(); - assert(bits > 0); - // x % 2^n == x & (2^n - 1) - this.mod = 1n << bits; - this.bitmask = this.mod - 1n; - assert(n < this.mod); - } - - public abstract build(n: bigint): UnsignedInteger; - - public add(rhs: UnsignedInteger): UnsignedInteger { - assert(this.bits == rhs.bits); - return this.build((this.n + rhs.n) & this.bitmask); - } - - public sub(rhs: UnsignedInteger): UnsignedInteger { - assert(this.bits == rhs.bits); - const res: bigint = this.n - rhs.n; - return this.build(res >= 0 ? res : res + this.mod); - } - - public mul(rhs: UnsignedInteger): UnsignedInteger { - assert(this.bits == rhs.bits); - return this.build((this.n * rhs.n) & this.bitmask); - } - - public div(rhs: UnsignedInteger): UnsignedInteger { - assert(this.bits == rhs.bits); - return this.build(this.n / rhs.n); - } - - // No sign extension. - public shr(rhs: UnsignedInteger): UnsignedInteger { - assert(this.bits == rhs.bits); - // Note that this.n is > 0 by class invariant. - return this.build(this.n >> rhs.n); - } - - public shl(rhs: UnsignedInteger): UnsignedInteger { - assert(this.bits == rhs.bits); - return this.build((this.n << rhs.n) & this.bitmask); - } - - public and(rhs: UnsignedInteger): UnsignedInteger { - assert(this.bits == rhs.bits); - return this.build(this.n & rhs.n); - } +/** + * This function creates a class for unsigned integers of a given number of bits. + * In TypeScript terms, it's a class mixin. + **/ +function UnsignedIntegerClassFactory(bits: number) { + return class NewUintClass extends IntegralValue { + static readonly mod: bigint = 1n << BigInt(bits); + static readonly bitmask: bigint = this.mod - 1n; + public readonly n: bigint; // Cannot be private due to TS limitations. + + public constructor(n: bigint | number) { + super(); + this.n = BigInt(n); + assert(n < NewUintClass.mod, `Value ${n} is too large for ${this.constructor.name}.`); + } - public or(rhs: UnsignedInteger): UnsignedInteger { - assert(this.bits == rhs.bits); - return this.build(this.n | rhs.n); - } + public build(n: bigint): NewUintClass { + return new this.constructor.prototype.constructor(n); + } - public xor(rhs: UnsignedInteger): UnsignedInteger { - assert(this.bits == rhs.bits); - return this.build(this.n ^ rhs.n); - } + public add(rhs: NewUintClass): NewUintClass { + return this.build((this.n + rhs.n) & NewUintClass.bitmask); + } - public not(): UnsignedInteger { - return this.build(~this.n & this.bitmask); - } + public sub(rhs: NewUintClass): NewUintClass { + const res: bigint = this.n - rhs.n; + return this.build(res >= 0 ? res : res + NewUintClass.mod); + } - public equals(rhs: UnsignedInteger): boolean { - assert(this.bits == rhs.bits); - return this.n === rhs.n; - } + public mul(rhs: NewUintClass): NewUintClass { + return this.build((this.n * rhs.n) & NewUintClass.bitmask); + } - public lt(rhs: UnsignedInteger): boolean { - assert(this.bits == rhs.bits); - return this.n < rhs.n; - } + public div(rhs: NewUintClass): NewUintClass { + return this.build(this.n / rhs.n); + } - public toBigInt(): bigint { - return this.n; - } -} + // No sign extension. + public shr(rhs: NewUintClass): NewUintClass { + // Note that this.n is > 0 by class invariant. + return this.build(this.n >> rhs.n); + } -export class Uint8 extends UnsignedInteger { - constructor(n: number | bigint) { - super(BigInt(n), 8n); - } + public shl(rhs: NewUintClass): NewUintClass { + return this.build((this.n << rhs.n) & NewUintClass.bitmask); + } - public build(n: bigint): Uint8 { - return new Uint8(n); - } -} + public and(rhs: NewUintClass): NewUintClass { + return this.build(this.n & rhs.n); + } -export class Uint16 extends UnsignedInteger { - constructor(n: number | bigint) { - super(BigInt(n), 16n); - } + public or(rhs: NewUintClass): NewUintClass { + return this.build(this.n | rhs.n); + } - public build(n: bigint): Uint16 { - return new Uint16(n); - } -} + public xor(rhs: NewUintClass): NewUintClass { + return this.build(this.n ^ rhs.n); + } -export class Uint32 extends UnsignedInteger { - constructor(n: number | bigint) { - super(BigInt(n), 32n); - } + public not(): NewUintClass { + return this.build(~this.n & NewUintClass.bitmask); + } - public build(n: bigint): Uint32 { - return new Uint32(n); - } -} + public equals(rhs: NewUintClass): boolean { + return this.n === rhs.n; + } -export class Uint64 extends UnsignedInteger { - constructor(n: number | bigint) { - super(BigInt(n), 64n); - } + public lt(rhs: NewUintClass): boolean { + return this.n < rhs.n; + } - public build(n: bigint): Uint64 { - return new Uint64(n); - } + public toBigInt(): bigint { + return this.n; + } + }; } -export class Uint128 extends UnsignedInteger { - constructor(n: number | bigint) { - super(BigInt(n), 128n); - } - - public build(n: bigint): Uint128 { - return new Uint128(n); - } -} +// Now we can create the classes for each unsigned integer type. +// We extend instead of just assigning so that the class has the right name. +// Otherwise they are all called "NewUintClass". +export class Uint8 extends UnsignedIntegerClassFactory(8) {} +export class Uint16 extends UnsignedIntegerClassFactory(16) {} +export class Uint32 extends UnsignedIntegerClassFactory(32) {} +export class Uint64 extends UnsignedIntegerClassFactory(64) {} +export class Uint128 extends UnsignedIntegerClassFactory(128) {} export class Field extends MemoryValue { public static readonly MODULUS: bigint = Fr.MODULUS; @@ -335,7 +290,7 @@ export class TaggedMemory { // Truncates the value to fit the type. public static integralFromTag(v: bigint | number, tag: TypeTag): IntegralValue { - v = v as bigint; + v = BigInt(v); switch (tag) { case TypeTag.UINT8: return new Uint8(v & ((1n << 8n) - 1n));