Skip to content

Commit

Permalink
fix: session account prooving
Browse files Browse the repository at this point in the history
  • Loading branch information
janek26 committed Aug 19, 2022
1 parent 11e10bd commit 0b56833
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 26 deletions.
1 change: 1 addition & 0 deletions __tests__/utils/merkle.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ describe('MerkleTree class', () => {
const proof = tree.getProof('0x7');

const manualProof = [
'0x0', // proofs should always be as long as the tree is deep
MerkleTree.hash('0x5', '0x6'),
MerkleTree.hash(MerkleTree.hash('0x1', '0x2'), MerkleTree.hash('0x3', '0x4')),
];
Expand Down
36 changes: 29 additions & 7 deletions src/account/session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import {
import { feeTransactionVersion, transactionVersion } from '../utils/hash';
import { MerkleTree } from '../utils/merkle';
import { BigNumberish, toBN } from '../utils/number';
import { SignedSession, createMerkleTreeForPolicies } from '../utils/session';
import { SignedSession, createMerkleTreeForPolicies, preparePolicy } from '../utils/session';
import { compileCalldata, estimatedFeeToMaxFee } from '../utils/stark';
import { fromCallsToExecuteCalldataWithNonce } from '../utils/transaction';
import { Account } from './default';
Expand All @@ -39,24 +39,44 @@ export class SessionAccount extends Account implements AccountInterface {
assert(signedSession.root === this.merkleTree.root, 'Invalid session');
}

private async sessionToCall(session: SignedSession): Promise<Call> {
private async sessionToCall(session: SignedSession, proofs: string[][]): Promise<Call> {
return {
contractAddress: this.address,
entrypoint: 'use_plugin',
calldata: compileCalldata({
SESSION_PLUGIN_CLASS_HASH,
classHash: SESSION_PLUGIN_CLASS_HASH,
signer: await this.signer.getPubKey(),
expires: session.expires.toString(),
root: session.root,
proofLength: proofs[0].length.toString(),
...proofs.reduce(
(acc, proof, i) => ({
...acc,
...proof.reduce((acc2, path, j) => ({ ...acc2, [`proof${i}:${j}`]: path }), {}),
}),
{}
),

token1: session.signature[0],
token2: session.signature[1],
root: session.root,
proof: [],
}),
};
}

private proofCalls(calls: Call[]): string[][] {
return calls.map((call) => {
const leaf = preparePolicy({
contractAddress: call.contractAddress,
selector: call.entrypoint,
});
return this.merkleTree.getProof(leaf);
});
}

private async extendCallsBySession(calls: Call[], session: SignedSession): Promise<Call[]> {
return [await this.sessionToCall(session), ...calls];
const proofs = this.proofCalls(calls);
const pluginCall = await this.sessionToCall(session, proofs);
return [pluginCall, ...calls];
}

public async estimateFee(
Expand Down Expand Up @@ -119,7 +139,9 @@ export class SessionAccount extends Account implements AccountInterface {
if (transactionsDetail.maxFee || transactionsDetail.maxFee === 0) {
maxFee = transactionsDetail.maxFee;
} else {
const { suggestedMaxFee } = await this.estimateFee(transactions, { nonce });
const { suggestedMaxFee } = await this.estimateFee(Array.isArray(calls) ? calls : [calls], {
nonce,
});
maxFee = suggestedMaxFee.toString();
}

Expand Down
1 change: 1 addition & 0 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ export * as number from './utils/number';
export * as transaction from './utils/transaction';
export * as stark from './utils/stark';
export * as merkle from './utils/merkle';
export * as session from './utils/session';
export * as ec from './utils/ellipticCurve';
export * as uint256 from './utils/uint256';
export * as shortString from './utils/shortString';
Expand Down
12 changes: 6 additions & 6 deletions src/utils/merkle.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,27 +36,27 @@ export class MerkleTree {
}

public getProof(leaf: string, branch = this.leaves, hashPath: string[] = []): string[] {
if (branch.length === 1) {
return hashPath;
}
const index = branch.indexOf(leaf);
if (index === -1) {
throw new Error('leaf not found');
}
if (branch.length === 1) {
return hashPath;
}
const isLeft = index % 2 === 0;
const neededBranch = (isLeft ? branch[index + 1] : branch[index - 1]) ?? branch[index];
const neededBranch = (isLeft ? branch[index + 1] : branch[index - 1]) ?? '0x0';
const newHashPath = [...hashPath, neededBranch];
const currentBranchLevelIndex =
this.leaves.length === branch.length
? -1
: this.branches.findIndex((b) => b.length === branch.length);
const nextBranch = this.branches[currentBranchLevelIndex + 1] ?? [this.root];
return this.getProof(
neededBranch === leaf
neededBranch === '0x0'
? leaf
: MerkleTree.hash(isLeft ? leaf : neededBranch, isLeft ? neededBranch : leaf),
nextBranch,
neededBranch === leaf ? hashPath : newHashPath
newHashPath
);
}
}
Expand Down
42 changes: 33 additions & 9 deletions src/utils/session.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import type { AccountInterface } from '../account';
import { StarknetChainId } from '../constants';
import { ProviderInterface } from '../provider';
import { Signature } from '../types';
import { pedersen } from './hash';
import { computeHashOnElements } from './hash';
import { MerkleTree } from './merkle';
import { StarkNetDomain, prepareSelector } from './typedData';
import { toBN } from './number';
import { compileCalldata } from './stark';
import { prepareSelector } from './typedData';

interface Policy {
contractAddress: string;
Expand All @@ -23,8 +27,25 @@ export interface SignedSession extends PreparedSession {
signature: Signature;
}

function preparePolicy({ contractAddress, selector }: Policy): string {
return pedersen([contractAddress, prepareSelector(selector)]);
export const SESSION_PLUGIN_CLASS_HASH =
'0x1031d8540af9d984d8d8aa5dff598467008c58b6f6147b7f90fda4b6d8db463';
// H(Policy(contractAddress:felt,selector:selector))
const POLICY_TYPE_HASH = '0x2f0026e78543f036f33e26a8f5891b88c58dc1e20cbbfaf0bb53274da6fa568';

export async function supportsSessions(
address: string,
provider: ProviderInterface
): Promise<boolean> {
const { result } = await provider.callContract({
contractAddress: address,
entrypoint: 'is_plugin',
calldata: compileCalldata({ classHash: SESSION_PLUGIN_CLASS_HASH }),
});
return !toBN(result[0]).isZero();
}

export function preparePolicy({ contractAddress, selector }: Policy): string {
return computeHashOnElements([POLICY_TYPE_HASH, contractAddress, prepareSelector(selector)]);
}

export function createMerkleTreeForPolicies(policies: Policy[]): MerkleTree {
Expand All @@ -38,8 +59,7 @@ export function prepareSession(session: RequestSession): PreparedSession {

export async function createSession(
session: RequestSession,
account: AccountInterface,
domain: StarkNetDomain = {}
account: AccountInterface
): Promise<SignedSession> {
const { expires, key, policies, root } = prepareSession(session);
const signature = await account.signMessage({
Expand All @@ -52,15 +72,19 @@ export async function createSession(
Session: [
{ name: 'key', type: 'felt' },
{ name: 'expires', type: 'felt' },
{ name: 'root', type: 'merkletree', contains: 'Policy*' },
{ name: 'root', type: 'merkletree', contains: 'Policy' },
],
StarkNetDomain: [
{ name: 'name', type: 'felt' },
{ name: 'version', type: 'felt' },
{ name: 'chain_id', type: 'felt' },
{ name: 'chainId', type: 'felt' },
],
},
domain,
domain: {
name: '0x0',
version: '0x0',
chainId: StarknetChainId.TESTNET,
},
message: {
key,
expires,
Expand Down
7 changes: 3 additions & 4 deletions src/utils/typedData/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ function getMerkleTreeType(types: TypedData['types'], ctx: Context) {
if (!isMerkleTree) {
throw new Error(`${ctx.key} is not a merkle tree`);
}
if (merkleType.contains.endsWith('*')) {
throw new Error(`Merkle tree contain property must not be an array but was given ${ctx.key}`);
}
return merkleType.contains;
}
return 'raw';
Expand Down Expand Up @@ -161,10 +164,6 @@ export const encodeValue = (
return ['felt*', computeHashOnElements(data as string[])];
}

if (type === 'raw') {
return ['felt', data as string];
}

if (type === 'selector') {
return ['felt', prepareSelector(data as string)];
}
Expand Down

0 comments on commit 0b56833

Please sign in to comment.