Commit

chore: reduce revision overlap
penovicp committed Jan 22, 2024
1 parent cd632bb commit d9de436
Showing 6 changed files with 117 additions and 60 deletions.
4 changes: 2 additions & 2 deletions __mocks__/typedData/example_baseTypes.json
@@ -22,8 +22,8 @@
   "domain": {
     "name": "StarkNet Mail",
     "version": "1",
-    "chainId": 1,
-    "revision": 1
+    "chainId": "1",
+    "revision": "1"
   },
   "message": {
     "n0": "0x3e8",
4 changes: 2 additions & 2 deletions __mocks__/typedData/example_enum.json
@@ -17,8 +17,8 @@
   "domain": {
     "name": "StarkNet Mail",
     "version": "1",
-    "chainId": 1,
-    "revision": 1
+    "chainId": "1",
+    "revision": "1"
   },
   "message": {
     "someEnum": {
4 changes: 2 additions & 2 deletions __mocks__/typedData/example_presetTypes.json
@@ -15,8 +15,8 @@
   "domain": {
     "name": "StarkNet Mail",
     "version": "1",
-    "chainId": 1,
-    "revision": 1
+    "chainId": "1",
+    "revision": "1"
   },
   "message": {
     "n0": {
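Note: all three mock updates are the same change. Under SNIP-12 revision 1 the domain fields are hashed as shortstrings, so chainId and revision are now declared as strings. A minimal sketch of such a domain object (values taken from the mocks above):

const domain = {
  name: 'StarkNet Mail',
  version: '1',
  chainId: '1', // shortstring-encoded under revision 1, hence a string
  revision: '1', // the string '1' selects the active (SNIP-12) revision
};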
4 changes: 2 additions & 2 deletions __tests__/utils/typedData.test.ts
@@ -267,14 +267,14 @@ describe('typedData', () => {
   });
 
   test('should hash messages with revision 1 types', () => {
-    // necessary to mock dependency since function mocks (hash.computePedersenHash; hash.computePoseidonHash) won't work
+    // necessary to spy dependency since function spies (hash.computePedersenHash; hash.computePoseidonHash) won't work
     const spyPedersen = jest.spyOn(starkCurve, 'pedersen');
     const spyPoseidon = jest.spyOn(starkCurve, 'poseidonHashMany');
 
     let messageHash: string;
     messageHash = getMessageHash(exampleBaseTypes, exampleAddress);
     expect(messageHash).toMatchInlineSnapshot(
-      `"0x458a6a7cd8b781412b0bdf25d77cc7012821156f250c14b9fedb418ee2b200a"`
+      `"0x790d9fa99cf9ad91c515aaff9465fcb1c87784d9cfb27271ed193675cd06f9c"`
     );
 
     messageHash = getMessageHash(examplePresetTypes, exampleAddress);
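For readers unfamiliar with the pattern referenced in the updated comment, a minimal sketch of the spy-based assertion; the expect/restore lines are an illustration, not part of this diff:

// Spy on the curve-level primitives; the wrappers hash.computePedersenHash and
// hash.computePoseidonHash are resolved internally, so plain function mocks never fire.
const spyPoseidon = jest.spyOn(starkCurve, 'poseidonHashMany');
getMessageHash(exampleBaseTypes, exampleAddress);
expect(spyPoseidon).toHaveBeenCalled(); // revision 1 hashes with Poseidon
spyPoseidon.mockRestore();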
2 changes: 1 addition & 1 deletion src/types/typedData.ts
@@ -38,7 +38,7 @@ export interface StarkNetDomain extends Record<string, unknown> {
   name?: string;
   version?: string;
   chainId?: string | number;
-  revision?: string | number;
+  revision?: string;
 }
 
 /**
159 changes: 108 additions & 51 deletions src/utils/typedData.ts
@@ -24,36 +24,60 @@ interface Configuration {
   domain: string;
   hashMethod: (data: BigNumberish[]) => string;
   escapeTypeString: (s: string) => string;
+  presetTypes: TypedData['types'];
 }
 
+const presetTypes: TypedData['types'] = {
+  u256: JSON.parse('[{ "name": "low", "type": "u128" }, { "name": "high", "type": "u128" }]'),
+  TokenAmount: JSON.parse(
+    '[{ "name": "token_address", "type": "ContractAddress" }, { "name": "amount", "type": "u256" }]'
+  ),
+  NftId: JSON.parse(
+    '[{ "name": "collection_address", "type": "ContractAddress" }, { "name": "token_id", "type": "u256" }]'
+  ),
+};
+
 const revisionConfiguration: Record<Revision, Configuration> = {
   [Revision.Active]: {
     domain: 'StarknetDomain',
     hashMethod: computePoseidonHash,
     escapeTypeString: (s) => `"${s}"`,
+    presetTypes,
   },
   [Revision.Legacy]: {
     domain: 'StarkNetDomain',
     hashMethod: computePedersenHash,
     escapeTypeString: (s) => s,
+    presetTypes: {},
   },
 };
 
-const presetTypes: TypedData['types'] = {
-  u256: JSON.parse('[{ "name": "low", "type": "u128" }, { "name": "high", "type": "u128" }]'),
-  TokenAmount: JSON.parse(
-    '[{ "name": "token_address", "type": "ContractAddress" }, { "name": "amount", "type": "u256" }]'
-  ),
-  NftId: JSON.parse(
-    '[{ "name": "collection_address", "type": "ContractAddress" }, { "name": "token_id", "type": "u256" }]'
-  ),
-};
+// TODO: replace with utils byteArrayFromString from PR#891 once it is available
+export function byteArrayFromString(targetString: string) {
+  const shortStrings: string[] = splitLongString(targetString);
+  const remainder: string = shortStrings[shortStrings.length - 1];
+  const shortStringsEncoded: BigNumberish[] = shortStrings.map(encodeShortString);
+
+  const [pendingWord, pendingWordLength] =
+    remainder === undefined || remainder.length === 31
+      ? ['0x00', 0]
+      : [shortStringsEncoded.pop()!, remainder.length];
+
+  return {
+    data: shortStringsEncoded.length === 0 ? ['0x00'] : shortStringsEncoded,
+    pending_word: pendingWord,
+    pending_word_len: pendingWordLength,
+  };
+}
 
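A usage sketch of the interim helper above; the values follow this implementation, and the PR#891 replacement may shape them differently:

byteArrayFromString('hello');
// {
//   data: ['0x00'],               // no full 31-character words
//   pending_word: '0x68656c6c6f', // encodeShortString('hello')
//   pending_word_len: 5,
// }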
 function identifyRevision({ types, domain }: TypedData) {
-  if (revisionConfiguration[Revision.Active].domain in types && domain.revision !== Revision.Legacy)
+  if (revisionConfiguration[Revision.Active].domain in types && domain.revision === Revision.Active)
     return Revision.Active;
 
-  if (revisionConfiguration[Revision.Legacy].domain in types && domain.revision !== Revision.Active)
+  if (
+    revisionConfiguration[Revision.Legacy].domain in types &&
+    (domain.revision ?? Revision.Legacy) === Revision.Legacy
+  )
     return Revision.Legacy;
 
   return undefined;
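In effect the two revision checks are now explicit rather than mutually exclusive negations. A comment-only sketch of the resolution table, assuming (as elsewhere in the codebase) that Revision.Active is '1' and Revision.Legacy is '0':

// 'StarknetDomain' in types and domain.revision === '1'   -> Revision.Active
// 'StarkNetDomain' in types and revision omitted or '0'   -> Revision.Legacy
// any other combination                                   -> undefined (later rejected by validateTypedData)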
@@ -96,35 +120,33 @@ export function getDependencies(
   types: TypedData['types'],
   type: string,
   dependencies: string[] = [],
-  contains: string = ''
+  contains: string = '',
+  revision: Revision = Revision.Legacy
 ): string[] {
   // Include pointers (struct arrays)
   if (type[type.length - 1] === '*') {
     type = type.slice(0, -1);
-  }
-  // enum base
-  else if (type === 'enum') {
-    type = contains;
-  }
-  // enum element types
-  else if (type.match(/^\(.*\)$/)) {
-    type = type.slice(1, -1);
+  } else if (revision === Revision.Active) {
+    // enum base
+    if (type === 'enum') {
+      type = contains;
+    }
+    // enum element types
+    else if (type.match(/^\(.*\)$/)) {
+      type = type.slice(1, -1);
+    }
   }
 
-  if (dependencies.includes(type)) {
+  if (dependencies.includes(type) || !types[type]) {
     return dependencies;
   }
 
-  if (!types[type]) {
-    return dependencies;
-  }
-
   return [
     type,
-    ...types[type].reduce<string[]>(
+    ...(types[type] as StarkNetEnumType[]).reduce<string[]>(
       (previous, t) => [
         ...previous,
-        ...getDependencies(types, t.type, previous, (t as StarkNetEnumType).contains).filter(
+        ...getDependencies(types, t.type, previous, t.contains, revision).filter(
           (dependency) => !previous.includes(dependency)
         ),
       ],
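A sketch of what the new revision parameter changes; MyEnum is a hypothetical type name:

const types = { MyEnum: [{ name: 'Some', type: '(u128)' }] } as TypedData['types'];
getDependencies(types, 'enum', [], 'MyEnum', Revision.Active); // 'enum' resolves via contains -> ['MyEnum']
getDependencies(types, 'enum', [], 'MyEnum'); // legacy default: 'enum' is not special-cased -> []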
@@ -157,15 +179,18 @@ export function encodeType(
   type: string,
   revision: Revision = Revision.Legacy
 ): string {
-  const [primary, ...dependencies] = getDependencies(types, type);
+  const [primary, ...dependencies] = getDependencies(types, type, undefined, undefined, revision);
   const newTypes = !primary ? [] : [primary, ...dependencies.sort()];
 
   const esc = revisionConfiguration[revision].escapeTypeString;
 
   return newTypes
     .map((dependency) => {
       const dependencyElements = types[dependency].map((t) => {
-        const targetType = t.type === 'enum' ? (t as StarkNetEnumType).contains : t.type;
+        const targetType =
+          t.type === 'enum' && revision === Revision.Active
+            ? (t as StarkNetEnumType).contains
+            : t.type;
         // parentheses handling for enum variant types
         const typeString = targetType.match(/^\(.*\)$/)
           ? `(${targetType
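The escaping difference is easiest to see side by side. An illustrative Mail/Person type set (not from the mocks) would encode roughly as:

encodeType(types, 'Mail', Revision.Active);
// revision 0 (legacy): Mail(from:Person,to:Person,contents:felt)Person(name:felt,wallet:felt)
// revision 1 (SNIP-12): "Mail"("from":"Person","to":"Person","contents":"felt")"Person"("name":"felt","wallet":"felt")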
@@ -207,8 +232,16 @@ export function encodeValue(
     return [type, getStructHash(types, type, data as TypedData['message'], revision)];
   }
 
-  if (presetTypes[type]) {
-    return [type, getStructHash(presetTypes, type, data as TypedData['message'], revision)];
+  if (revisionConfiguration[revision].presetTypes[type]) {
+    return [
+      type,
+      getStructHash(
+        revisionConfiguration[revision].presetTypes,
+        type,
+        data as TypedData['message'],
+        revision
+      ),
+    ];
   }
 
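A short sketch of what the preset branch above gates — u256 is not user-declared but is available as a revision 1 preset:

encodeValue(types, 'u256', { low: '0x1', high: '0x0' }, undefined, Revision.Active);
// -> ['u256', <struct hash built from the preset definition>]
// Under Revision.Legacy, presetTypes is {} and the branch is skipped.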
   if (type.endsWith('*')) {
@@ -218,29 +251,36 @@

     const mappingMethod: (entry: Record<string, unknown>) => string = isStructArray
       ? (entry) => getStructHash(types, type.slice(0, -1), entry, revision)
-      : (entry) => encodeValue(types, type.slice(0, 1), entry, undefined, revision)[1];
+      : (entry) => encodeValue(types, type.slice(0, -1), entry, undefined, revision)[1];
 
     const hashes: string[] = (data as Array<TypedData['message']>).map(mappingMethod);
     return [type, revisionConfiguration[revision].hashMethod(hashes)];
   }
 
   switch (type) {
     case 'enum': {
-      const [variantKey, variantData] = Object.entries(data as TypedData['message'])[0];
-
-      const parentType = types[ctx.parent as string][0] as StarkNetEnumType;
-      const enumType = types[parentType.contains];
-      const variantType = enumType.find((t) => t.name === variantKey) as StarkNetType;
-      const variantIndex = enumType.indexOf(variantType);
-
-      const encodedSubtypes = variantType.type
-        .slice(1, -1)
-        .split(',')
-        .map((subtype, index) => {
-          const subtypeData = (variantData as unknown[])[index];
-          return encodeValue(types, subtype, subtypeData, undefined, revision)[1];
-        });
-      return [type, revisionConfiguration[revision].hashMethod([variantIndex, ...encodedSubtypes])];
+      if (revision === Revision.Active) {
+        const [variantKey, variantData] = Object.entries(data as TypedData['message'])[0];
+
+        const parentType = types[ctx.parent as string][0] as StarkNetEnumType;
+        const enumType = types[parentType.contains];
+        const variantType = enumType.find((t) => t.name === variantKey) as StarkNetType;
+        const variantIndex = enumType.indexOf(variantType);
+
+        const encodedSubtypes = variantType.type
+          .slice(1, -1)
+          .split(',')
+          .map((subtype, index) => {
+            if (!subtype) return subtype;
+            const subtypeData = (variantData as unknown[])[index];
+            return encodeValue(types, subtype.trim(), subtypeData, undefined, revision)[1];
+          });
+        return [
+          type,
+          revisionConfiguration[revision].hashMethod([variantIndex, ...encodedSubtypes]),
+        ];
+      } // else fall through to default
+      return [type, getHex(data as string)];
     }
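A comment-only sketch of the revision 1 enum path; the type set and data are hypothetical:

// types.MyEnum = [{ name: 'V1', type: '(u128)' }, { name: 'V2', type: '(u128,u128)' }]
// data = { V2: ['0x1', '0x2'] }, with the parent field declared as { type: 'enum', contains: 'MyEnum' }
// -> variantIndex = 1, subtypes 'u128' and 'u128' encode via getHex,
//    result = hashMethod([1, 0x1, 0x2]); under revision 0 the value is now returned via getHex instead.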
     case 'merkletree': {
       const merkleTreeType = getMerkleTreeType(types, ctx);
@@ -255,13 +295,30 @@ export function encodeValue(
     }
     case 'string': {
       if (revision === Revision.Active) {
-        // TODO: should this be skipped for v5 and added as ByteArray in v6?
-        const elements = splitLongString(data as string).map(getHex);
+        const byteArray = byteArrayFromString(data as string);
+        const elements = [
+          byteArray.data.length,
+          ...byteArray.data,
+          byteArray.pending_word,
+          byteArray.pending_word_len,
+        ];
         return [type, revisionConfiguration[revision].hashMethod(elements)];
-      }
+      } // else fall through to default
       return [type, getHex(data as string)];
     }
+    case 'felt':
+    case 'bool':
+    case 'u128':
+    case 'i128':
+    case 'ContractAddress':
+    case 'ClassHash':
+    case 'timestamp':
+    case 'shortstring':
+      return [type, getHex(data as string)];
     default: {
+      if (revision === Revision.Active) {
+        throw new Error(`Unsupported type: ${type}`);
+      }
       return [type, getHex(data as string)];
     }
   }
@@ -277,7 +334,7 @@ export function encodeData<T extends TypedData>(
   data: T['message'],
   revision: Revision = Revision.Legacy
 ) {
-  const targetType = types[type] ?? presetTypes[type];
+  const targetType = types[type] ?? revisionConfiguration[revision].presetTypes[type];
   const [returnTypes, values] = targetType.reduce<[string[], string[]]>(
     ([ts, vs], field) => {
       if (data[field.name] === undefined || (data[field.name] === null && field.type !== 'enum')) {
@@ -313,7 +370,7 @@ export function getStructHash<T extends TypedData>(
 }
 
 /**
- * Get the EIP-191|EIP-712 encoded message to sign, from the typedData object.
+ * Get the SNIP-12 encoded message to sign, from the typedData object.
  */
 export function getMessageHash(typedData: TypedData, account: BigNumberish): string {
   if (!validateTypedData(typedData)) {
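For orientation, the end-to-end call that the updated snapshot pins down; the import surface is an assumption, adjust to however the typedData utilities are re-exported:

import { typedData } from 'starknet'; // assumed re-export
const hash = typedData.getMessageHash(exampleBaseTypes, exampleAddress);
// With the revision 1 mocks above this now evaluates to
// '0x790d9fa99cf9ad91c515aaff9465fcb1c87784d9cfb27271ed193675cd06f9c'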
