Skip to content

Commit

Permalink
fix keychain deserialize (#735)
Browse files Browse the repository at this point in the history
Co-authored-by: brdy <41711440+BrodyHughes@users.noreply.github.com>
  • Loading branch information
greg-schrammel and BrodyHughes authored Jul 10, 2023
1 parent d3cd283 commit d7da33f
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 65 deletions.
15 changes: 9 additions & 6 deletions src/core/keychain/keychainTypes/hardwareWalletKeychain.ts
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,16 @@ export class HardwareWalletKeychain implements IKeychain {
};
}

async deserialize(opts: SerializedHardwareWalletKeychain) {
if (opts?.hdPath) privates.get(this).hdPath = opts.hdPath;
if (opts?.deviceId) privates.get(this).deviceId = opts.deviceId;
if (opts?.wallets) privates.get(this).wallets = opts.wallets;
if (opts?.vendor) this.vendor = opts.vendor;
if (opts?.accountsEnabled)
async deserialize(opts?: SerializedHardwareWalletKeychain) {
if (!opts) return;
if (opts.hdPath) privates.get(this).hdPath = opts.hdPath;
if (opts.deviceId) privates.get(this).deviceId = opts.deviceId;
if (opts.wallets) privates.get(this).wallets = opts.wallets;
if (opts.vendor) this.vendor = opts.vendor;
if (opts.accountsEnabled)
privates.get(this).accountsEnabled = opts.accountsEnabled;
if (opts.accountsDeleted?.length)
privates.get(this).accountsDeleted = opts.accountsDeleted;
}

getAccounts(): Promise<Array<Address>> {
Expand Down
138 changes: 79 additions & 59 deletions src/core/keychain/keychainTypes/hdKeychain.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
/* eslint-disable @typescript-eslint/no-non-null-assertion */
import { Signer } from '@ethersproject/abstract-signer';
import { HDNode } from '@ethersproject/hdnode';
import { Wallet } from '@ethersproject/wallet';
Expand All @@ -8,17 +9,36 @@ import { KeychainType } from '~/core/types/keychainTypes';
import { IKeychain, PrivateKey } from '../IKeychain';
import { autoDiscoverAccounts } from '../utils';

type SupportedHDPath = "m/44'/60'/0'/0";

export interface SerializedHdKeychain {
mnemonic: string;
hdPath?: string;
hdPath?: SupportedHDPath;
accountsEnabled?: number;
type: string;
imported?: boolean;
autodiscover?: boolean;
accountsDeleted?: Array<number>;
}

const privates = new WeakMap();
type TWallet = Omit<Wallet, 'address' | 'privateKey'> & {
address: Address;
privateKey: PrivateKey;
};

const privates = new WeakMap<
IKeychain,
{
wallets: Array<{ wallet: TWallet; index: number }>;
mnemonic: string | null;
accountsEnabled: number;
accountsDeleted: number[];
hdPath: SupportedHDPath;
getWalletForAddress(address: Address): Wallet | undefined;
deriveWallet(index: number): HDNode;
addAccount(index: number): Wallet;
}
>();

export class HdKeychain implements IKeychain {
type: string;
Expand All @@ -34,87 +54,90 @@ export class HdKeychain implements IKeychain {
accountsEnabled: 1,
accountsDeleted: [],
hdPath: "m/44'/60'/0'/0",
getWalletForAddress: (address: Address): Wallet => {
getWalletForAddress: (address: Address) => {
return privates
.get(this)
.get(this)!
.wallets.find(
(wallet: Wallet) =>
(wallet as Wallet).address.toLowerCase() ===
address.toLowerCase(),
) as Wallet;
({ wallet }) =>
wallet.address.toLowerCase() === address.toLowerCase(),
)?.wallet;
},
deriveWallet: (index: number): HDNode => {
const hdNode = HDNode.fromMnemonic(
privates.get(this).mnemonic as string,
);
const derivedWallet = hdNode.derivePath(
`${privates.get(this).hdPath}/${index}`,
);
const _privates = privates.get(this)!;
if (!_privates.mnemonic) throw new Error('No mnemonic');

const hdNode = HDNode.fromMnemonic(_privates.mnemonic);
const derivedWallet = hdNode.derivePath(`${_privates.hdPath}/${index}`);
return derivedWallet;
},

addAccount: (index: number): Wallet => {
const derivedWallet = privates.get(this).deriveWallet(index);
const wallet = new Wallet(derivedWallet.privateKey);
privates.get(this).wallets.push(wallet);
const _privates = privates.get(this)!;
const derivedWallet = _privates.deriveWallet(index);
const wallet = new Wallet(derivedWallet.privateKey) as TWallet;
_privates.wallets.push({ wallet, index: derivedWallet.index });
return wallet;
},
});
}

init(options: SerializedHdKeychain) {
init(options?: SerializedHdKeychain) {
return this.deserialize(options);
}

getSigner(address: Address): Signer {
const wallet = privates.get(this).getWalletForAddress(address);
const _privates = privates.get(this)!;
const wallet = _privates.getWalletForAddress(address);
if (!wallet?._isSigner) throw new Error('Not a signer');
return wallet;
}

async serialize(): Promise<SerializedHdKeychain> {
const _privates = privates.get(this)!;
if (!_privates.mnemonic) throw new Error('No mnemonic');
return {
imported: this.imported,
mnemonic: privates.get(this).mnemonic as string,
accountsEnabled: privates.get(this).accountsEnabled,
hdPath: privates.get(this).hdPath,
mnemonic: _privates.mnemonic,
accountsEnabled: _privates.accountsEnabled,
hdPath: _privates.hdPath,
type: this.type,
accountsDeleted: privates.get(this).accountsDeleted,
accountsDeleted: _privates.accountsDeleted,
};
}

async deserialize(opts: SerializedHdKeychain) {
if (opts?.hdPath) privates.get(this).hdPath = opts.hdPath;
if (opts?.imported) this.imported = opts.imported;
if (opts?.accountsEnabled)
privates.get(this).accountsEnabled = opts.accountsEnabled;

if (opts?.mnemonic) {
privates.get(this).mnemonic = opts.mnemonic;
} else {
privates.get(this).mnemonic = Wallet.createRandom().mnemonic
.phrase as string;
}
async deserialize(opts?: SerializedHdKeychain) {
const _privates = privates.get(this)!;

if (opts?.hdPath) _privates.hdPath = opts.hdPath;
this.imported = !!opts?.imported;
if (opts?.accountsEnabled) _privates.accountsEnabled = opts.accountsEnabled;
if (opts?.accountsDeleted) _privates.accountsDeleted = opts.accountsDeleted;

_privates.mnemonic =
opts?.mnemonic || Wallet.createRandom().mnemonic.phrase;

// If we didn't explicit add a new account, we need attempt to autodiscover the rest
if (opts?.autodiscover) {
const { accountsEnabled } = await autoDiscoverAccounts({
deriveWallet: privates.get(this).deriveWallet,
deriveWallet: _privates.deriveWallet,
});
privates.get(this).accountsEnabled = accountsEnabled;
_privates.accountsEnabled = accountsEnabled;
}

for (let i = 0; i < privates.get(this).accountsEnabled; i++) {
for (let i = 0; i < _privates.accountsEnabled; i++) {
// Do not re-add deleted accounts
if (!opts?.accountsDeleted?.includes(i)) {
privates.get(this).addAccount(i);
_privates.addAccount(i);
}
}
}

async addNewAccount(): Promise<Array<Wallet>> {
privates.get(this).addAccount(privates.get(this).accountsEnabled);
privates.get(this).accountsEnabled += 1;
return privates.get(this).wallets as Wallet[];
const _privates = privates.get(this)!;

_privates.addAccount(_privates.accountsEnabled);
_privates.accountsEnabled += 1;
return _privates.wallets.map(({ wallet }) => wallet);
}

// eslint-disable-next-line @typescript-eslint/no-unused-vars
Expand All @@ -123,35 +146,32 @@ export class HdKeychain implements IKeychain {
}

getAccounts(): Promise<Array<Address>> {
const addresses = privates
.get(this)
.wallets.map((wallet: Wallet) => (wallet as Wallet).address as Address);
const _privates = privates.get(this)!;
const addresses = _privates.wallets.map(({ wallet }) => wallet.address);
return Promise.resolve(addresses);
}

async exportAccount(address: Address): Promise<PrivateKey> {
const wallet = privates.get(this).getWalletForAddress(address);
const wallet = privates.get(this)!.getWalletForAddress(address);
if (!wallet) throw new Error('Account not found');
return wallet.privateKey;
}

async exportKeychain(): Promise<string> {
return privates.get(this).mnemonic as string;
const { mnemonic } = privates.get(this)!;
if (!mnemonic) throw new Error('No mnemonic');
return mnemonic;
}

async removeAccount(address: Address): Promise<void> {
const accounts = await this.getAccounts();
const accountToDeleteIndex = accounts.indexOf(address);
if (accountToDeleteIndex === -1) {
throw new Error('Account not found');
}
const wallets = privates.get(this)!.wallets;

const accountToDelete = wallets.find((w) => w.wallet.address === address);
if (!accountToDelete) throw new Error('Account not found');

const filteredList = privates
.get(this)
.wallets.filter(
(wallet: Wallet) => (wallet as Wallet).address !== address,
);
const filteredList = wallets.filter((w) => w.wallet.address !== address);

privates.get(this).wallets = filteredList;
privates.get(this).accountsDeleted.push(accountToDeleteIndex);
privates.get(this)!.wallets = filteredList;
privates.get(this)!.accountsDeleted.push(accountToDelete.index);
}
}

0 comments on commit d7da33f

Please sign in to comment.