Skip to content

Commit

Permalink
fix: add bigint support for inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
immortal-tofu committed Mar 1, 2024
1 parent 117ce9c commit a798189
Show file tree
Hide file tree
Showing 5 changed files with 7,200 additions and 7,178 deletions.
91 changes: 51 additions & 40 deletions codegen/generateOverloads.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
type Test = {
inputs: number[];
inputs: bigint[];
output: number | boolean | bigint;
};

Expand All @@ -20,11 +20,11 @@ type SupportedFunction = SupportedFunctionParams &
(
| {
unary?: false;
evalTest: (lhsNumber: number, rhsNumber: number, lhs: number, rhs: number) => number | boolean | bigint;
evalTest: (lhsNumber: bigint, rhsNumber: bigint, lhs: number, rhs: number) => number | boolean | bigint;
}
| {
unary: true;
evalTest: (lhs: number, bits: number) => number | boolean | bigint;
evalTest: (lhs: bigint, bits: number) => number | boolean | bigint;
}
);

Expand All @@ -35,14 +35,25 @@ type SupportedFunction = SupportedFunctionParams &
const SUPPORTED_UINT = [8, 16, 32, 64, 128, 256];
const SUPPORTED_BITS = [4, 8, 16, 32, 64];

const bigIntMin = (...args: bigint[]) => {
return args.reduce((min, e) => (e < min ? e : min), args[0]);
};

const bigIntMax = (...args: bigint[]) => {
return args.reduce((max, e) => (e > max ? e : max), args[0]);
};

const generateNumber = (bits: number) => {
return Math.max(Math.floor(Math.random() * (Math.pow(2, Math.min(bits, 28)) - 1)), 1);
const power = BigInt(Math.pow(2, bits) - 1);
const maxRange = bigIntMin(power, BigInt(Number.MAX_SAFE_INTEGER));
const divider = bigIntMax(BigInt(Math.floor(Math.random() * Number(maxRange))), 1n);
return bigIntMax(power / divider, 1n);
};

const safeEval = (
fn: (lhsNumber: number, rhsNumber: number, lhs: number, rhs: number) => number | boolean | bigint,
lhsNumber: number,
rhsNumber: number,
fn: (lhsNumber: bigint, rhsNumber: bigint, lhs: number, rhs: number) => number | boolean | bigint,
lhsNumber: bigint,
rhsNumber: bigint,
lhs: number,
rhs: number,
safeMin: boolean = false,
Expand All @@ -52,10 +63,10 @@ const safeEval = (
const logs: any[] = [];
if (typeof result === 'number' || typeof result === 'bigint') {
while ((result as number | bigint) > Math.pow(2, bitResults) - 1) {
lhsNumber = Math.max(Math.floor(lhsNumber / 2), 1);
rhsNumber = Math.max(Math.floor(rhsNumber / 2), 1);
lhsNumber = lhsNumber / 2n + 1n;
rhsNumber = rhsNumber / 2n + 1n;
result = fn(lhsNumber, rhsNumber, lhs, rhs);
logs.push([lhsNumber, rhsNumber, result]);
logs.push([lhs, rhs, lhsNumber, rhsNumber, result]);
}
}
return { inputs: [lhsNumber, rhsNumber], output: result };
Expand All @@ -65,59 +76,59 @@ export const SUPPORTED_FUNCTIONS: SupportedFunctions = {
add: {
supportedBits: SUPPORTED_BITS,
safeMin: true,
evalTest: (lhsNumber: number, rhsNumber: number) => BigInt(lhsNumber) + BigInt(rhsNumber),
evalTest: (lhsNumber, rhsNumber) => BigInt(lhsNumber) + BigInt(rhsNumber),
},
sub: {
supportedBits: SUPPORTED_BITS,
lhsHigher: true,
evalTest: (lhsNumber: number, rhsNumber: number) => lhsNumber - rhsNumber,
evalTest: (lhsNumber, rhsNumber) => lhsNumber - rhsNumber,
},
mul: {
supportedBits: SUPPORTED_BITS,
safeMin: true,
evalTest: (lhsNumber: number, rhsNumber: number) => BigInt(lhsNumber) * BigInt(rhsNumber),
evalTest: (lhsNumber, rhsNumber) => BigInt(lhsNumber) * BigInt(rhsNumber),
},
div: {
supportedBits: SUPPORTED_BITS,
evalTest: (lhsNumber: number, rhsNumber: number) => Math.floor(lhsNumber / rhsNumber),
evalTest: (lhsNumber, rhsNumber) => lhsNumber / rhsNumber,
scalarOnly: true,
},
rem: {
supportedBits: SUPPORTED_BITS,
evalTest: (lhsNumber: number, rhsNumber: number) => lhsNumber % rhsNumber,
evalTest: (lhsNumber, rhsNumber) => lhsNumber % rhsNumber,
scalarOnly: true,
},
le: {
supportedBits: SUPPORTED_BITS,
evalTest: (lhsNumber: number, rhsNumber: number) => lhsNumber <= rhsNumber,
evalTest: (lhsNumber, rhsNumber) => lhsNumber <= rhsNumber,
},
lt: {
supportedBits: SUPPORTED_BITS,
evalTest: (lhsNumber: number, rhsNumber: number) => lhsNumber < rhsNumber,
evalTest: (lhsNumber, rhsNumber) => lhsNumber < rhsNumber,
},
ge: {
supportedBits: SUPPORTED_BITS,
evalTest: (lhsNumber: number, rhsNumber: number) => lhsNumber >= rhsNumber,
evalTest: (lhsNumber, rhsNumber) => lhsNumber >= rhsNumber,
},
gt: {
supportedBits: SUPPORTED_BITS,
evalTest: (lhsNumber: number, rhsNumber: number) => lhsNumber > rhsNumber,
evalTest: (lhsNumber, rhsNumber) => lhsNumber > rhsNumber,
},
eq: {
supportedBits: SUPPORTED_BITS,
evalTest: (lhsNumber: number, rhsNumber: number) => lhsNumber === rhsNumber,
evalTest: (lhsNumber, rhsNumber) => lhsNumber === rhsNumber,
},
ne: {
supportedBits: SUPPORTED_BITS,
evalTest: (lhsNumber: number, rhsNumber: number) => lhsNumber !== rhsNumber,
evalTest: (lhsNumber, rhsNumber) => lhsNumber !== rhsNumber,
},
shl: {
supportedBits: SUPPORTED_BITS,
limit: 'bits',
evalTest: (lhsNumber: number, rhsNumber: number, lhs: number, rhs: number) => {
evalTest: (lhsNumber, rhsNumber, lhs, rhs) => {
const bits = `${new Array(256).fill('0').join('')}${lhsNumber.toString(2)}`.slice(-lhs).split('');
const r = bits.map((_, index) => {
const newIndex = index + (rhsNumber % lhs);
const newIndex = Number(BigInt(index) + (rhsNumber % BigInt(lhs)));
return newIndex >= bits.length ? '0' : bits[newIndex];
});
return parseInt(r.join(''), 2);
Expand All @@ -126,10 +137,10 @@ export const SUPPORTED_FUNCTIONS: SupportedFunctions = {
shr: {
supportedBits: SUPPORTED_BITS,
limit: 'bits',
evalTest: (lhsNumber: number, rhsNumber: number, lhs: number, rhs: number) => {
evalTest: (lhsNumber, rhsNumber, lhs, rhs) => {
const bits = `${new Array(256).fill('0').join('')}${lhsNumber.toString(2)}`.slice(-lhs).split('');
const r = bits.map((_, index) => {
const newIndex = index - (rhsNumber % lhs);
const newIndex = Number(BigInt(index) - (rhsNumber % BigInt(lhs)));
return newIndex < 0 ? '0' : bits[newIndex];
});
return parseInt(r.join(''), 2);
Expand All @@ -138,31 +149,31 @@ export const SUPPORTED_FUNCTIONS: SupportedFunctions = {
max: {
supportedBits: SUPPORTED_BITS,
unary: false,
evalTest: (lhsNumber: number, rhsNumber: number) => (lhsNumber > rhsNumber ? lhsNumber : rhsNumber),
evalTest: (lhsNumber, rhsNumber) => (lhsNumber > rhsNumber ? lhsNumber : rhsNumber),
},
min: {
supportedBits: SUPPORTED_BITS,
evalTest: (lhsNumber: number, rhsNumber: number) => (lhsNumber < rhsNumber ? lhsNumber : rhsNumber),
evalTest: (lhsNumber, rhsNumber) => (lhsNumber < rhsNumber ? lhsNumber : rhsNumber),
},
or: {
supportedBits: SUPPORTED_BITS,
noScalar: true,
evalTest: (lhsNumber: number, rhsNumber: number) => lhsNumber | rhsNumber,
evalTest: (lhsNumber, rhsNumber) => lhsNumber | rhsNumber,
},
and: {
supportedBits: SUPPORTED_BITS,
noScalar: true,
evalTest: (lhsNumber: number, rhsNumber: number) => lhsNumber & rhsNumber,
evalTest: (lhsNumber, rhsNumber) => lhsNumber & rhsNumber,
},
xor: {
supportedBits: SUPPORTED_BITS,
noScalar: true,
evalTest: (lhsNumber: number, rhsNumber: number) => lhsNumber ^ rhsNumber,
evalTest: (lhsNumber, rhsNumber) => lhsNumber ^ rhsNumber,
},
not: {
supportedBits: SUPPORTED_BITS,
unary: true,
evalTest: (lhsNumber: number, bits: number) => {
evalTest: (lhsNumber, bits) => {
const val = `${new Array(256).fill('0').join('')}${lhsNumber.toString(2)}`.slice(-bits).split('');
return BigInt(
`0b${val
Expand All @@ -177,7 +188,7 @@ export const SUPPORTED_FUNCTIONS: SupportedFunctions = {
neg: {
supportedBits: SUPPORTED_BITS,
unary: true,
evalTest: (lhsNumber: number, bits: number) => {
evalTest: (lhsNumber, bits) => {
const val = `${new Array(256).fill('0').join('')}${lhsNumber.toString(2)}`.slice(-bits).split('');
return (
BigInt(
Expand Down Expand Up @@ -212,9 +223,9 @@ export const generateTests = () => {
const bitResults = Math.min(lhs, rhs);
let rhsNumber = generateNumber(rhs);
if (test.limit === 'bits') {
rhsNumber = 1 + Math.floor(Math.random() * (rhs - 1));
rhsNumber = BigInt(1 + Math.floor(Math.random() * (rhs - 1)));
}
const smallest = Math.max(Math.min(lhsNumber, rhsNumber), 8);
const smallest = bigIntMax(bigIntMin(lhsNumber, rhsNumber), 8n);
const only8bits = test.limit === 'bits' && rhs === 8;
const onlyEncrypted8bits = only8bits && lhs > 4;

Expand All @@ -223,10 +234,10 @@ export const generateTests = () => {
const encryptedTests: Test[] = [];
if (!test.lhsHigher) {
encryptedTests.push(safeEval(test.evalTest, lhsNumber, rhsNumber, lhs, rhs, test.safeMin));
encryptedTests.push(safeEval(test.evalTest, smallest - 4, smallest, lhs, rhs, test.safeMin));
encryptedTests.push(safeEval(test.evalTest, smallest - 4n, smallest, lhs, rhs, test.safeMin));
}
encryptedTests.push(safeEval(test.evalTest, smallest, smallest, lhs, rhs, test.safeMin));
encryptedTests.push(safeEval(test.evalTest, smallest, smallest - 4, lhs, rhs, test.safeMin));
encryptedTests.push(safeEval(test.evalTest, smallest, smallest - 4n, lhs, rhs, test.safeMin));
tests[encryptedTestName] = encryptedTests;
}

Expand All @@ -240,10 +251,10 @@ export const generateTests = () => {
const encryptedTests: Test[] = [];
if (!test.lhsHigher) {
encryptedTests.push(safeEval(test.evalTest, lhsNumber, rhsNumber, lhs, rhs, test.safeMin));
encryptedTests.push(safeEval(test.evalTest, smallest - 4, smallest, lhs, rhs, test.safeMin));
encryptedTests.push(safeEval(test.evalTest, smallest - 4n, smallest, lhs, rhs, test.safeMin));
}
encryptedTests.push(safeEval(test.evalTest, smallest, smallest, lhs, rhs, test.safeMin));
encryptedTests.push(safeEval(test.evalTest, smallest, smallest - 4, lhs, rhs, test.safeMin));
encryptedTests.push(safeEval(test.evalTest, smallest, smallest - 4n, lhs, rhs, test.safeMin));
tests[encryptedTestName] = encryptedTests;
}
if (SUPPORTED_UINT.includes(lhs) && test.limit !== 'bits' && scalarCondition && !test.scalarOnly) {
Expand All @@ -252,10 +263,10 @@ export const generateTests = () => {
const encryptedTests: Test[] = [];
if (!test.lhsHigher) {
encryptedTests.push(safeEval(test.evalTest, lhsNumber, rhsNumber, lhs, rhs, test.safeMin));
encryptedTests.push(safeEval(test.evalTest, smallest - 4, smallest, lhs, rhs, test.safeMin));
encryptedTests.push(safeEval(test.evalTest, smallest - 4n, smallest, lhs, rhs, test.safeMin));
}
encryptedTests.push(safeEval(test.evalTest, smallest, smallest, lhs, rhs, test.safeMin));
encryptedTests.push(safeEval(test.evalTest, smallest, smallest - 4, lhs, rhs, test.safeMin));
encryptedTests.push(safeEval(test.evalTest, smallest, smallest - 4n, lhs, rhs, test.safeMin));
tests[encryptedTestName] = encryptedTests;
}
});
Expand Down
19 changes: 15 additions & 4 deletions codegen/overloadTests.ts
Original file line number Diff line number Diff line change
@@ -1,18 +1,29 @@
import overloads from './overloads.json';
import { OverloadSignature, signatureContractMethodName } from './testgen';

type OverloadTest = {
inputs: number[];
type OverloadTestJSON = {
inputs: (number | bigint | string)[];
output: boolean | number | bigint | string;
};
const transformBigInt = (o: { [methodName: string]: OverloadTest[] }) => {

type OverloadTest = {
inputs: (number | bigint)[];
output: boolean | number | bigint;
};

const transformBigInt = (o: { [methodName: string]: OverloadTestJSON[] }) => {
Object.keys(o).forEach((k) => {
o[k].forEach((test) => {
test.inputs.forEach((input, i) => {
if (typeof input === 'string') test.inputs[i] = BigInt(input);
});
if (typeof test.output === 'string') test.output = BigInt(test.output);
});
});
};

transformBigInt(overloads);

export const overloadTests: { [methodName: string]: OverloadTest[] } = overloads;
export const overloadTests: { [methodName: string]: OverloadTest[] } = overloads as unknown as {
[methodName: string]: OverloadTest[];
};
Loading

0 comments on commit a798189

Please sign in to comment.