diff --git a/app/scripts/lib/account-tracker.test.ts b/app/scripts/controllers/account-tracker-controller.test.ts similarity index 85% rename from app/scripts/lib/account-tracker.test.ts rename to app/scripts/controllers/account-tracker-controller.test.ts index 7cc0dcba14c7..dbabb927fa71 100644 --- a/app/scripts/lib/account-tracker.test.ts +++ b/app/scripts/controllers/account-tracker-controller.test.ts @@ -1,19 +1,19 @@ import EventEmitter from 'events'; import { ControllerMessenger } from '@metamask/base-controller'; import { InternalAccount } from '@metamask/keyring-api'; -import { Hex } from '@metamask/utils'; import { BlockTracker, Provider } from '@metamask/network-controller'; import { flushPromises } from '../../../test/lib/timer-helpers'; -import PreferencesController from '../controllers/preferences-controller'; -import OnboardingController from '../controllers/onboarding'; import { createTestProviderTools } from '../../../test/stub/provider'; -import AccountTracker, { - AccountTrackerOptions, +import PreferencesController from './preferences-controller'; +import type { + AccountTrackerControllerOptions, AllowedActions, AllowedEvents, - getDefaultAccountTrackerState, -} from './account-tracker'; +} from './account-tracker-controller'; +import AccountTrackerController, { + getDefaultAccountTrackerControllerState, +} from './account-tracker-controller'; const noop = () => true; const currentNetworkId = '5'; @@ -68,18 +68,18 @@ type WithControllerOptions = { useMultiAccountBalanceChecker?: boolean; getNetworkClientById?: jest.Mock; getSelectedAccount?: jest.Mock; -} & Partial; +} & Partial; type WithControllerCallback = ({ controller, blockTrackerFromHookStub, blockTrackerStub, - triggerOnAccountRemoved, + triggerAccountRemoved, }: { - controller: AccountTracker; + controller: AccountTrackerController; blockTrackerFromHookStub: MockBlockTracker; blockTrackerStub: MockBlockTracker; - triggerOnAccountRemoved: (address: string) => void; + triggerAccountRemoved: (address: string) => void; }) => ReturnValue; type WithControllerArgs = @@ -132,23 +132,37 @@ function withController( chainId: '0x1', }); - const blockTrackerFromHookStub = buildMockBlockTracker(); + const getNetworkStateStub = jest.fn().mockReturnValue({ + selectedNetworkClientId: 'selectedNetworkClientId', + }); + controllerMessenger.registerActionHandler( + 'NetworkController:getState', + getNetworkStateStub, + ); + const blockTrackerFromHookStub = buildMockBlockTracker(); const getNetworkClientByIdStub = jest.fn().mockReturnValue({ configuration: { - chainId: '0x1', + chainId: currentChainId, }, blockTracker: blockTrackerFromHookStub, provider: providerFromHook, }); - controllerMessenger.registerActionHandler( 'NetworkController:getNetworkClientById', getNetworkClientById || getNetworkClientByIdStub, ); - const controller = new AccountTracker({ - initState: getDefaultAccountTrackerState(), + const getOnboardingControllerState = jest.fn().mockReturnValue({ + completedOnboarding, + }); + controllerMessenger.registerActionHandler( + 'OnboardingController:getState', + getOnboardingControllerState, + ); + + const controller = new AccountTrackerController({ + state: getDefaultAccountTrackerControllerState(), provider: provider as Provider, blockTracker: blockTrackerStub as unknown as BlockTracker, getNetworkIdentifier: jest.fn(), @@ -159,13 +173,20 @@ function withController( }), }, } as PreferencesController, - onboardingController: { - state: { - completedOnboarding, - }, - } as OnboardingController, - controllerMessenger, - getCurrentChainId: () => currentChainId, + messenger: controllerMessenger.getRestricted({ + name: 'AccountTrackerController', + allowedActions: [ + 'AccountsController:getSelectedAccount', + 'NetworkController:getState', + 'NetworkController:getNetworkClientById', + 'OnboardingController:getState', + ], + allowedEvents: [ + 'AccountsController:selectedEvmAccountChange', + 'OnboardingController:stateChange', + 'KeyringController:accountRemoved', + ], + }), ...accountTrackerOptions, }); @@ -173,13 +194,13 @@ function withController( controller, blockTrackerFromHookStub, blockTrackerStub, - triggerOnAccountRemoved: (address: string) => { + triggerAccountRemoved: (address: string) => { controllerMessenger.publish('KeyringController:accountRemoved', address); }, }); } -describe('Account Tracker', () => { +describe('AccountTrackerController', () => { describe('start', () => { it('restarts the subscription to the block tracker and update accounts', async () => { withController(({ controller, blockTrackerStub }) => { @@ -456,9 +477,7 @@ describe('Account Tracker', () => { expect(updateAccountsSpy).toHaveBeenCalledWith(undefined); - const newState = controller.store.getState(); - - expect(newState).toStrictEqual({ + expect(controller.state).toStrictEqual({ accounts: {}, accountsByChainId: {}, currentBlockGasLimit: GAS_LIMIT, @@ -509,9 +528,7 @@ describe('Account Tracker', () => { expect(updateAccountsSpy).toHaveBeenCalledWith('mainnet'); - const newState = controller.store.getState(); - - expect(newState).toStrictEqual({ + expect(controller.state).toStrictEqual({ accounts: {}, accountsByChainId: {}, currentBlockGasLimit: '', @@ -567,8 +584,7 @@ describe('Account Tracker', () => { async ({ controller }) => { await controller.updateAccounts(); - const state = controller.store.getState(); - expect(state).toStrictEqual({ + expect(controller.state).toStrictEqual({ accounts: {}, currentBlockGasLimit: '', accountsByChainId: {}, @@ -579,7 +595,6 @@ describe('Account Tracker', () => { }); describe('chain does not have single call balance address', () => { - const getCurrentChainIdStub: () => Hex = () => '0x999'; // chain without single call balance address const mockAccountsWithSelectedAddress = { ...mockAccounts, [SELECTED_ADDRESS]: { @@ -600,11 +615,9 @@ describe('Account Tracker', () => { { completedOnboarding: true, useMultiAccountBalanceChecker: true, - getCurrentChainId: getCurrentChainIdStub, + state: mockInitialState, }, async ({ controller }) => { - controller.store.updateState(mockInitialState); - await controller.updateAccounts(); const accounts = { @@ -622,8 +635,7 @@ describe('Account Tracker', () => { }, }; - const newState = controller.store.getState(); - expect(newState).toStrictEqual({ + expect(controller.state).toStrictEqual({ accounts, accountsByChainId: { '0x999': accounts, @@ -642,11 +654,9 @@ describe('Account Tracker', () => { { completedOnboarding: true, useMultiAccountBalanceChecker: false, - getCurrentChainId: getCurrentChainIdStub, + state: mockInitialState, }, async ({ controller }) => { - controller.store.updateState(mockInitialState); - await controller.updateAccounts(); const accounts = { @@ -661,8 +671,7 @@ describe('Account Tracker', () => { }, }; - const newState = controller.store.getState(); - expect(newState).toStrictEqual({ + expect(controller.state).toStrictEqual({ accounts, accountsByChainId: { '0x999': accounts, @@ -686,20 +695,18 @@ describe('Account Tracker', () => { getNetworkIdentifier: jest .fn() .mockReturnValue('http://not-localhost:8545'), - getCurrentChainId: () => '0x1', // chain with single call balance address getSelectedAccount: jest.fn().mockReturnValue({ id: 'accountId', address: VALID_ADDRESS, } as InternalAccount), - }, - async ({ controller }) => { - controller.store.updateState({ + state: { accounts: { ...mockAccounts }, accountsByChainId: { '0x1': { ...mockAccounts }, }, - }); - + }, + }, + async ({ controller }) => { await controller.updateAccounts('mainnet'); const accounts = { @@ -713,8 +720,7 @@ describe('Account Tracker', () => { }, }; - const newState = controller.store.getState(); - expect(newState).toStrictEqual({ + expect(controller.state).toStrictEqual({ accounts, accountsByChainId: { '0x1': accounts, @@ -731,75 +737,77 @@ describe('Account Tracker', () => { describe('onAccountRemoved', () => { it('should remove an account from state', () => { - withController(({ controller, triggerOnAccountRemoved }) => { - controller.store.updateState({ - accounts: { ...mockAccounts }, - accountsByChainId: { - [currentChainId]: { - ...mockAccounts, - }, - '0x1': { - ...mockAccounts, - }, - '0x2': { - ...mockAccounts, + withController( + { + state: { + accounts: { ...mockAccounts }, + accountsByChainId: { + [currentChainId]: { + ...mockAccounts, + }, + '0x1': { + ...mockAccounts, + }, + '0x2': { + ...mockAccounts, + }, }, }, - }); - - triggerOnAccountRemoved(VALID_ADDRESS); - - const newState = controller.store.getState(); - - const accounts = { - [VALID_ADDRESS_TWO]: mockAccounts[VALID_ADDRESS_TWO], - }; - - expect(newState).toStrictEqual({ - accounts, - accountsByChainId: { - [currentChainId]: accounts, - '0x1': accounts, - '0x2': accounts, - }, - currentBlockGasLimit: '', - currentBlockGasLimitByChainId: {}, - }); - }); + }, + ({ controller, triggerAccountRemoved }) => { + triggerAccountRemoved(VALID_ADDRESS); + + const accounts = { + [VALID_ADDRESS_TWO]: mockAccounts[VALID_ADDRESS_TWO], + }; + + expect(controller.state).toStrictEqual({ + accounts, + accountsByChainId: { + [currentChainId]: accounts, + '0x1': accounts, + '0x2': accounts, + }, + currentBlockGasLimit: '', + currentBlockGasLimitByChainId: {}, + }); + }, + ); }); }); describe('clearAccounts', () => { it('should reset state', () => { - withController(({ controller }) => { - controller.store.updateState({ - accounts: { ...mockAccounts }, - accountsByChainId: { - [currentChainId]: { - ...mockAccounts, - }, - '0x1': { - ...mockAccounts, - }, - '0x2': { - ...mockAccounts, + withController( + { + state: { + accounts: { ...mockAccounts }, + accountsByChainId: { + [currentChainId]: { + ...mockAccounts, + }, + '0x1': { + ...mockAccounts, + }, + '0x2': { + ...mockAccounts, + }, }, }, - }); - - controller.clearAccounts(); - - const newState = controller.store.getState(); + }, + ({ controller }) => { + controller.clearAccounts(); - expect(newState).toStrictEqual({ - accounts: {}, - accountsByChainId: { - [currentChainId]: {}, - }, - currentBlockGasLimit: '', - currentBlockGasLimitByChainId: {}, - }); - }); + expect(controller.state).toStrictEqual({ + accounts: {}, + accountsByChainId: { + [currentChainId]: {}, + }, + currentBlockGasLimit: '', + currentBlockGasLimitByChainId: {}, + }); + }, + ); }); }); }); diff --git a/app/scripts/lib/account-tracker.ts b/app/scripts/controllers/account-tracker-controller.ts similarity index 68% rename from app/scripts/lib/account-tracker.ts rename to app/scripts/controllers/account-tracker-controller.ts index 8ca119ccf83f..e2c78ea3f3f9 100644 --- a/app/scripts/lib/account-tracker.ts +++ b/app/scripts/controllers/account-tracker-controller.ts @@ -10,7 +10,6 @@ import EthQuery from '@metamask/eth-query'; import { v4 as random } from 'uuid'; -import { ObservableStore } from '@metamask/obs-store'; import log from 'loglevel'; import pify from 'pify'; import { Web3Provider } from '@ethersproject/providers'; @@ -22,10 +21,16 @@ import { NetworkClientConfiguration, NetworkClientId, NetworkControllerGetNetworkClientByIdAction, + NetworkControllerGetStateAction, Provider, } from '@metamask/network-controller'; import { hasProperty, Hex } from '@metamask/utils'; -import { ControllerMessenger } from '@metamask/base-controller'; +import { + BaseController, + ControllerGetStateAction, + ControllerStateChangeEvent, + RestrictedControllerMessenger, +} from '@metamask/base-controller'; import { AccountsControllerGetSelectedAccountAction, AccountsControllerSelectedEvmAccountChangeEvent, @@ -33,51 +38,139 @@ import { import { KeyringControllerAccountRemovedEvent } from '@metamask/keyring-controller'; import { InternalAccount } from '@metamask/keyring-api'; -import OnboardingController, { - OnboardingControllerStateChangeEvent, -} from '../controllers/onboarding'; -import PreferencesController from '../controllers/preferences-controller'; import { LOCALHOST_RPC_URL } from '../../../shared/constants/network'; import { SINGLE_CALL_BALANCES_ADDRESSES } from '../constants/contracts'; -import { previousValueComparator } from './util'; +import { previousValueComparator } from '../lib/util'; +import type { + OnboardingControllerGetStateAction, + OnboardingControllerStateChangeEvent, +} from './onboarding'; +import PreferencesController from './preferences-controller'; + +// Unique name for the controller +const controllerName = 'AccountTrackerController'; type Account = { address: string; balance: string | null; }; -export type AccountTrackerState = { +/** + * The state of the {@link AccountTrackerController} + * + * @property accounts - The accounts currently stored in this AccountTrackerController + * @property accountsByChainId - The accounts currently stored in this AccountTrackerController keyed by chain id + * @property currentBlockGasLimit - A hex string indicating the gas limit of the current block + * @property currentBlockGasLimitByChainId - A hex string indicating the gas limit of the current block keyed by chain id + */ +export type AccountTrackerControllerState = { accounts: Record>; currentBlockGasLimit: string; - accountsByChainId: Record; + accountsByChainId: Record; currentBlockGasLimitByChainId: Record; }; -export const getDefaultAccountTrackerState = (): AccountTrackerState => ({ - accounts: {}, - currentBlockGasLimit: '', - accountsByChainId: {}, - currentBlockGasLimitByChainId: {}, -}); +/** + * {@link AccountTrackerController}'s metadata. + * + * This allows us to choose if fields of the state should be persisted or not + * using the `persist` flag; and if they can be sent to Sentry or not, using + * the `anonymous` flag. + */ +const controllerMetadata = { + accounts: { + persist: true, + anonymous: false, + }, + currentBlockGasLimit: { + persist: true, + anonymous: true, + }, + accountsByChainId: { + persist: true, + anonymous: false, + }, + currentBlockGasLimitByChainId: { + persist: true, + anonymous: true, + }, +}; + +/** + * Function to get default state of the {@link AccountTrackerController}. + */ +export const getDefaultAccountTrackerControllerState = + (): AccountTrackerControllerState => ({ + accounts: {}, + currentBlockGasLimit: '', + accountsByChainId: {}, + currentBlockGasLimitByChainId: {}, + }); + +/** + * Returns the state of the {@link AccountTrackerController}. + */ +export type AccountTrackerControllerGetStateAction = ControllerGetStateAction< + typeof controllerName, + AccountTrackerControllerState +>; +/** + * Actions exposed by the {@link AccountTrackerController}. + */ +export type AccountTrackerControllerActions = + AccountTrackerControllerGetStateAction; + +/** + * Event emitted when the state of the {@link AccountTrackerController} changes. + */ +export type AccountTrackerControllerStateChangeEvent = + ControllerStateChangeEvent< + typeof controllerName, + AccountTrackerControllerState + >; + +/** + * Events emitted by {@link AccountTrackerController}. + */ +export type AccountTrackerControllerEvents = + AccountTrackerControllerStateChangeEvent; + +/** + * Actions that this controller is allowed to call. + */ export type AllowedActions = + | OnboardingControllerGetStateAction | AccountsControllerGetSelectedAccountAction + | NetworkControllerGetStateAction | NetworkControllerGetNetworkClientByIdAction; +/** + * Events that this controller is allowed to subscribe. + */ export type AllowedEvents = | AccountsControllerSelectedEvmAccountChangeEvent | KeyringControllerAccountRemovedEvent | OnboardingControllerStateChangeEvent; -export type AccountTrackerOptions = { - initState: Partial; +/** + * Messenger type for the {@link AccountTrackerController}. + */ +export type AccountTrackerControllerMessenger = RestrictedControllerMessenger< + typeof controllerName, + AccountTrackerControllerActions | AllowedActions, + AccountTrackerControllerEvents | AllowedEvents, + AllowedActions['type'], + AllowedEvents['type'] +>; + +export type AccountTrackerControllerOptions = { + state: Partial; + messenger: AccountTrackerControllerMessenger; provider: Provider; blockTracker: BlockTracker; - getCurrentChainId: () => Hex; getNetworkIdentifier: (config?: NetworkClientConfiguration) => string; preferencesController: PreferencesController; - onboardingController: OnboardingController; - controllerMessenger: ControllerMessenger; }; /** @@ -86,22 +179,12 @@ export type AccountTrackerOptions = { * * It also tracks transaction hashes, and checks their inclusion status on each new block. * - * AccountTracker - * - * @property store The stored object containing all accounts to track, as well as the current block's gas limit. - * @property store.accounts The accounts currently stored in this AccountTracker - * @property store.accountsByChainId The accounts currently stored in this AccountTracker keyed by chain id - * @property store.currentBlockGasLimit A hex string indicating the gas limit of the current block - * @property store.currentBlockGasLimitByChainId A hex string indicating the gas limit of the current block keyed by chain id */ -export default class AccountTracker { - /** - * Observable store containing controller data. - */ - store: ObservableStore; - - resetState: () => void; - +export default class AccountTrackerController extends BaseController< + typeof controllerName, + AccountTrackerControllerState, + AccountTrackerControllerMessenger +> { #pollingTokenSets = new Map>(); #listeners: Record Promise> = @@ -113,52 +196,48 @@ export default class AccountTracker { #currentBlockNumberByChainId: Record = {}; - #getCurrentChainId: AccountTrackerOptions['getCurrentChainId']; - - #getNetworkIdentifier: AccountTrackerOptions['getNetworkIdentifier']; + #getNetworkIdentifier: AccountTrackerControllerOptions['getNetworkIdentifier']; - #preferencesController: AccountTrackerOptions['preferencesController']; - - #onboardingController: AccountTrackerOptions['onboardingController']; - - #controllerMessenger: AccountTrackerOptions['controllerMessenger']; + #preferencesController: AccountTrackerControllerOptions['preferencesController']; #selectedAccount: InternalAccount; /** - * @param opts - Options for initializing the controller - * @param opts.provider - An EIP-1193 provider instance that uses the current global network - * @param opts.blockTracker - A block tracker, which emits events for each new block - * @param opts.getCurrentChainId - A function that returns the `chainId` for the current global network - * @param opts.getNetworkIdentifier - A function that returns the current network or passed nework configuration + * @param options - Options for initializing the controller + * @param options.state - Initial controller state. + * @param options.messenger - Messenger used to communicate with BaseV2 controller. + * @param options.provider - An EIP-1193 provider instance that uses the current global network + * @param options.blockTracker - A block tracker, which emits events for each new block + * @param options.getNetworkIdentifier - A function that returns the current network or passed network configuration + * @param options.preferencesController - The preferences controller */ - constructor(opts: AccountTrackerOptions) { - const initState = getDefaultAccountTrackerState(); - this.store = new ObservableStore({ - ...initState, - ...opts.initState, + constructor(options: AccountTrackerControllerOptions) { + super({ + name: controllerName, + metadata: controllerMetadata, + state: { + ...getDefaultAccountTrackerControllerState(), + ...options.state, + }, + messenger: options.messenger, }); - this.resetState = () => { - this.store.updateState(initState); - }; + this.#provider = options.provider; + this.#blockTracker = options.blockTracker; - this.#provider = opts.provider; - this.#blockTracker = opts.blockTracker; - - this.#getCurrentChainId = opts.getCurrentChainId; - this.#getNetworkIdentifier = opts.getNetworkIdentifier; - this.#preferencesController = opts.preferencesController; - this.#onboardingController = opts.onboardingController; - this.#controllerMessenger = opts.controllerMessenger; + this.#getNetworkIdentifier = options.getNetworkIdentifier; + this.#preferencesController = options.preferencesController; // subscribe to account removal - this.#controllerMessenger.subscribe( + this.messagingSystem.subscribe( 'KeyringController:accountRemoved', (address) => this.removeAccounts([address]), ); - this.#controllerMessenger.subscribe( + const onboardingState = this.messagingSystem.call( + 'OnboardingController:getState', + ); + this.messagingSystem.subscribe( 'OnboardingController:stateChange', previousValueComparator((prevState, currState) => { const { completedOnboarding: prevCompletedOnboarding } = prevState; @@ -167,14 +246,14 @@ export default class AccountTracker { this.updateAccountsAllActiveNetworks(); } return true; - }, this.#onboardingController.state), + }, onboardingState), ); - this.#selectedAccount = this.#controllerMessenger.call( + this.#selectedAccount = this.messagingSystem.call( 'AccountsController:getSelectedAccount', ); - this.#controllerMessenger.subscribe( + this.messagingSystem.subscribe( 'AccountsController:selectedEvmAccountChange', (newAccount) => { const { useMultiAccountBalanceChecker } = @@ -191,6 +270,21 @@ export default class AccountTracker { ); } + resetState(): void { + const { + accounts, + accountsByChainId, + currentBlockGasLimit, + currentBlockGasLimitByChainId, + } = getDefaultAccountTrackerControllerState(); + this.update((state) => { + state.accounts = accounts; + state.accountsByChainId = accountsByChainId; + state.currentBlockGasLimit = currentBlockGasLimit; + state.currentBlockGasLimitByChainId = currentBlockGasLimitByChainId; + }); + } + /** * Starts polling with global selected network */ @@ -220,6 +314,22 @@ export default class AccountTracker { this.#blockTracker.removeListener('latest', this.#updateForBlock); } + /** + * Gets the current chain ID. + */ + #getCurrentChainId(): Hex { + const { selectedNetworkClientId } = this.messagingSystem.call( + 'NetworkController:getState', + ); + const { + configuration: { chainId }, + } = this.messagingSystem.call( + 'NetworkController:getNetworkClientById', + selectedNetworkClientId, + ); + return chainId; + } + /** * Resolves a networkClientId to a network client config * or globally selected network config if not provided @@ -235,7 +345,7 @@ export default class AccountTracker { } { if (networkClientId) { const { configuration, provider, blockTracker } = - this.#controllerMessenger.call( + this.messagingSystem.call( 'NetworkController:getNetworkClientById', networkClientId, ); @@ -355,13 +465,15 @@ export default class AccountTracker { * * @param chainId - The chain ID */ - #getAccountsForChainId(chainId: Hex): AccountTrackerState['accounts'] { - const { accounts, accountsByChainId } = this.store.getState(); + #getAccountsForChainId( + chainId: Hex, + ): AccountTrackerControllerState['accounts'] { + const { accounts, accountsByChainId } = this.state; if (accountsByChainId[chainId]) { return cloneDeep(accountsByChainId[chainId]); } - const newAccounts: AccountTrackerState['accounts'] = {}; + const newAccounts: AccountTrackerControllerState['accounts'] = {}; Object.keys(accounts).forEach((address) => { newAccounts[address] = {}; }); @@ -370,16 +482,16 @@ export default class AccountTracker { /** * Ensures that the locally stored accounts are in sync with a set of accounts stored externally to this - * AccountTracker. + * AccountTrackerController. * - * Once this AccountTracker's accounts are up to date with those referenced by the passed addresses, each + * Once this AccountTrackerController accounts are up to date with those referenced by the passed addresses, each * of these accounts are given an updated balance via EthQuery. * - * @param addresses - The array of hex addresses for accounts with which this AccountTracker's accounts should be + * @param addresses - The array of hex addresses for accounts with which this AccountTrackerController accounts should be * in sync */ syncWithAddresses(addresses: string[]): void { - const { accounts } = this.store.getState(); + const { accounts } = this.state; const locals = Object.keys(accounts); const accountsToAdd: string[] = []; @@ -408,7 +520,7 @@ export default class AccountTracker { */ addAccounts(addresses: string[]): void { const { accounts: _accounts, accountsByChainId: _accountsByChainId } = - this.store.getState(); + this.state; const accounts = cloneDeep(_accounts); const accountsByChainId = cloneDeep(_accountsByChainId); @@ -422,7 +534,10 @@ export default class AccountTracker { }); }); // save accounts state - this.store.updateState({ accounts, accountsByChainId }); + this.update((state) => { + state.accounts = accounts; + state.accountsByChainId = accountsByChainId; + }); // fetch balances for the accounts if there is block number ready if (this.#currentBlockNumberByChainId[this.#getCurrentChainId()]) { @@ -443,7 +558,7 @@ export default class AccountTracker { */ removeAccounts(addresses: string[]): void { const { accounts: _accounts, accountsByChainId: _accountsByChainId } = - this.store.getState(); + this.state; const accounts = cloneDeep(_accounts); const accountsByChainId = cloneDeep(_accountsByChainId); @@ -457,23 +572,26 @@ export default class AccountTracker { }); }); // save accounts state - this.store.updateState({ accounts, accountsByChainId }); + this.update((state) => { + state.accounts = accounts; + state.accountsByChainId = accountsByChainId; + }); } /** * Removes all addresses and associated balances */ clearAccounts(): void { - this.store.updateState({ - accounts: {}, - accountsByChainId: { + this.update((state) => { + state.accounts = {}; + state.accountsByChainId = { [this.#getCurrentChainId()]: {}, - }, + }; }); } /** - * Given a block, updates this AccountTracker's currentBlockGasLimit and currentBlockGasLimitByChainId and then updates + * Given a block, updates this AccountTrackerController currentBlockGasLimit and currentBlockGasLimitByChainId and then updates * each local account's balance via EthQuery * * @private @@ -485,7 +603,7 @@ export default class AccountTracker { }; /** - * Given a block, updates this AccountTracker's currentBlockGasLimitByChainId, and then updates each local account's balance + * Given a block, updates this AccountTrackerController currentBlockGasLimitByChainId, and then updates each local account's balance * via EthQuery * * @private @@ -510,15 +628,11 @@ export default class AccountTracker { return; } const currentBlockGasLimit = currentBlock.gasLimit; - const { currentBlockGasLimitByChainId } = this.store.getState(); - this.store.updateState({ - ...(chainId === this.#getCurrentChainId() && { - currentBlockGasLimit, - }), - currentBlockGasLimitByChainId: { - ...currentBlockGasLimitByChainId, - [chainId]: currentBlockGasLimit, - }, + this.update((state) => { + if (chainId === this.#getCurrentChainId()) { + state.currentBlockGasLimit = currentBlockGasLimit; + } + state.currentBlockGasLimitByChainId[chainId] = currentBlockGasLimit; }); try { @@ -549,7 +663,9 @@ export default class AccountTracker { * @param networkClientId - optional network client ID to use instead of the globally selected network. */ async updateAccounts(networkClientId?: NetworkClientId): Promise { - const { completedOnboarding } = this.#onboardingController.state; + const { completedOnboarding } = this.messagingSystem.call( + 'OnboardingController:getState', + ); if (!completedOnboarding) { return; } @@ -561,11 +677,11 @@ export default class AccountTracker { let addresses = []; if (useMultiAccountBalanceChecker) { - const { accounts } = this.store.getState(); + const { accounts } = this.state; addresses = Object.keys(accounts); } else { - const selectedAddress = this.#controllerMessenger.call( + const selectedAddress = this.messagingSystem.call( 'AccountsController:getSelectedAccount', ).address; @@ -573,14 +689,11 @@ export default class AccountTracker { } const rpcUrl = 'http://127.0.0.1:8545'; - const singleCallBalancesAddress = - SINGLE_CALL_BALANCES_ADDRESSES[ - chainId as keyof typeof SINGLE_CALL_BALANCES_ADDRESSES - ]; if ( identifier === LOCALHOST_RPC_URL || identifier === rpcUrl || - !singleCallBalancesAddress + !((id): id is keyof typeof SINGLE_CALL_BALANCES_ADDRESSES => + id in SINGLE_CALL_BALANCES_ADDRESSES)(chainId) ) { await Promise.all( addresses.map((address) => @@ -590,7 +703,7 @@ export default class AccountTracker { } else { await this.#updateAccountsViaBalanceChecker( addresses, - singleCallBalancesAddress, + SINGLE_CALL_BALANCES_ADDRESSES[chainId], provider, chainId, ); @@ -657,15 +770,11 @@ export default class AccountTracker { newAccounts[address] = result; - const { accountsByChainId } = this.store.getState(); - this.store.updateState({ - ...(chainId === this.#getCurrentChainId() && { - accounts: newAccounts, - }), - accountsByChainId: { - ...accountsByChainId, - [chainId]: newAccounts, - }, + this.update((state) => { + if (chainId === this.#getCurrentChainId()) { + state.accounts = newAccounts; + } + state.accountsByChainId[chainId] = newAccounts; }); } @@ -695,7 +804,7 @@ export default class AccountTracker { const balances = await ethContract.balances(addresses, ethBalance); const accounts = this.#getAccountsForChainId(chainId); - const newAccounts: AccountTrackerState['accounts'] = {}; + const newAccounts: AccountTrackerControllerState['accounts'] = {}; Object.keys(accounts).forEach((address) => { if (!addresses.includes(address)) { newAccounts[address] = { address, balance: null }; @@ -706,15 +815,11 @@ export default class AccountTracker { newAccounts[address] = { address, balance }; }); - const { accountsByChainId } = this.store.getState(); - this.store.updateState({ - ...(chainId === this.#getCurrentChainId() && { - accounts: newAccounts, - }), - accountsByChainId: { - ...accountsByChainId, - [chainId]: newAccounts, - }, + this.update((state) => { + if (chainId === this.#getCurrentChainId()) { + state.accounts = newAccounts; + } + state.accountsByChainId[chainId] = newAccounts; }); } catch (error) { log.warn( diff --git a/app/scripts/controllers/mmi-controller.test.ts b/app/scripts/controllers/mmi-controller.test.ts index 3a9e6cddba6a..348ccd40916b 100644 --- a/app/scripts/controllers/mmi-controller.test.ts +++ b/app/scripts/controllers/mmi-controller.test.ts @@ -99,7 +99,7 @@ describe('MMIController', function () { 'NetworkController:infuraIsUnblocked', ], }), - state: mockNetworkState({chainId: CHAIN_IDS.SEPOLIA}), + state: mockNetworkState({ chainId: CHAIN_IDS.SEPOLIA }), infuraProjectId: 'mock-infura-project-id', }); @@ -272,7 +272,7 @@ describe('MMIController', function () { mmiController.getState = jest.fn(); mmiController.captureException = jest.fn(); - mmiController.accountTracker = { syncWithAddresses: jest.fn() }; + mmiController.accountTrackerController = { syncWithAddresses: jest.fn() }; jest.spyOn(metaMetricsController.store, 'getState').mockReturnValue({ metaMetricsId: mockMetaMetricsId, @@ -385,7 +385,7 @@ describe('MMIController', function () { mmiController.keyringController.addNewAccountForKeyring = jest.fn(); mmiController.custodyController.setAccountDetails = jest.fn(); - mmiController.accountTracker.syncWithAddresses = jest.fn(); + mmiController.accountTrackerController.syncWithAddresses = jest.fn(); mmiController.storeCustodianSupportedChains = jest.fn(); mmiController.custodyController.storeCustodyStatusMap = jest.fn(); @@ -400,7 +400,9 @@ describe('MMIController', function () { expect( mmiController.custodyController.setAccountDetails, ).toHaveBeenCalled(); - expect(mmiController.accountTracker.syncWithAddresses).toHaveBeenCalled(); + expect( + mmiController.accountTrackerController.syncWithAddresses, + ).toHaveBeenCalled(); expect(mmiController.storeCustodianSupportedChains).toHaveBeenCalled(); expect( mmiController.custodyController.storeCustodyStatusMap, diff --git a/app/scripts/controllers/mmi-controller.ts b/app/scripts/controllers/mmi-controller.ts index 0c43684d7f58..d0e905d673d8 100644 --- a/app/scripts/controllers/mmi-controller.ts +++ b/app/scripts/controllers/mmi-controller.ts @@ -39,12 +39,12 @@ import { Signature, ConnectionRequest, } from '../../../shared/constants/mmi-controller'; -import AccountTracker from '../lib/account-tracker'; // TODO: Remove restricted import // eslint-disable-next-line import/no-restricted-paths import { getCurrentChainId } from '../../../ui/selectors'; import MetaMetricsController from './metametrics'; import { getPermissionBackgroundApiMethods } from './permissions'; +import AccountTrackerController from './account-tracker-controller'; import PreferencesController from './preferences-controller'; import { AppStateController } from './app-state'; @@ -86,7 +86,7 @@ export default class MMIController extends EventEmitter { // eslint-disable-next-line @typescript-eslint/no-explicit-any private getPendingNonce: (address: string) => Promise; - private accountTracker: AccountTracker; + private accountTrackerController: AccountTrackerController; private metaMetricsController: MetaMetricsController; @@ -148,7 +148,7 @@ export default class MMIController extends EventEmitter { this.custodyController = opts.custodyController; this.getState = opts.getState; this.getPendingNonce = opts.getPendingNonce; - this.accountTracker = opts.accountTracker; + this.accountTrackerController = opts.accountTrackerController; this.metaMetricsController = opts.metaMetricsController; this.networkController = opts.networkController; this.permissionController = opts.permissionController; @@ -504,7 +504,7 @@ export default class MMIController extends EventEmitter { } }); - this.accountTracker.syncWithAddresses(accountsToTrack); + this.accountTrackerController.syncWithAddresses(accountsToTrack); for (const address of newAccounts) { try { diff --git a/app/scripts/metamask-controller.js b/app/scripts/metamask-controller.js index a5445e16875a..3d6d16df4b95 100644 --- a/app/scripts/metamask-controller.js +++ b/app/scripts/metamask-controller.js @@ -274,7 +274,7 @@ import MMIController from './controllers/mmi-controller'; import { mmiKeyringBuilderFactory } from './mmi-keyring-builder-factory'; ///: END:ONLY_INCLUDE_IF import ComposableObservableStore from './lib/ComposableObservableStore'; -import AccountTracker from './lib/account-tracker'; +import AccountTrackerController from './controllers/account-tracker-controller'; import createDupeReqFilterStream from './lib/createDupeReqFilterStream'; import createLoggerMiddleware from './lib/createLoggerMiddleware'; import { @@ -1222,7 +1222,7 @@ export default class MetamaskController extends EventEmitter { const internalAccountCount = internalAccounts.length; const accountTrackerCount = Object.keys( - this.accountTracker.store.getState().accounts || {}, + this.accountTrackerController.state.accounts || {}, ).length; captureException( @@ -1655,11 +1655,24 @@ export default class MetamaskController extends EventEmitter { }); // account tracker watches balances, nonces, and any code at their address - this.accountTracker = new AccountTracker({ + this.accountTrackerController = new AccountTrackerController({ + state: { accounts: {} }, + messenger: this.controllerMessenger.getRestricted({ + name: 'AccountTrackerController', + allowedActions: [ + 'AccountsController:getSelectedAccount', + 'NetworkController:getState', + 'NetworkController:getNetworkClientById', + 'OnboardingController:getState', + ], + allowedEvents: [ + 'AccountsController:selectedEvmAccountChange', + 'OnboardingController:stateChange', + 'KeyringController:accountRemoved', + ], + }), provider: this.provider, blockTracker: this.blockTracker, - getCurrentChainId: () => - getCurrentChainId({ metamask: this.networkController.state }), getNetworkIdentifier: (providerConfig) => { const { type, rpcUrl } = providerConfig ?? @@ -1669,17 +1682,6 @@ export default class MetamaskController extends EventEmitter { return type === NETWORK_TYPES.RPC ? rpcUrl : type; }, preferencesController: this.preferencesController, - onboardingController: this.onboardingController, - controllerMessenger: this.controllerMessenger.getRestricted({ - name: 'AccountTracker', - allowedActions: ['AccountsController:getSelectedAccount'], - allowedEvents: [ - 'AccountsController:selectedEvmAccountChange', - 'OnboardingController:stateChange', - 'KeyringController:accountRemoved', - ], - }), - initState: { accounts: {} }, }); // start and stop polling for balances based on activeControllerConnections @@ -1998,7 +2000,7 @@ export default class MetamaskController extends EventEmitter { custodyController: this.custodyController, getState: this.getState.bind(this), getPendingNonce: this.getPendingNonce.bind(this), - accountTracker: this.accountTracker, + accountTrackerController: this.accountTrackerController, metaMetricsController: this.metaMetricsController, networkController: this.networkController, permissionController: this.permissionController, @@ -2207,11 +2209,11 @@ export default class MetamaskController extends EventEmitter { this._onUserOperationTransactionUpdated.bind(this), ); - // ensure accountTracker updates balances after network change + // ensure AccountTrackerController updates balances after network change networkControllerMessenger.subscribe( 'NetworkController:networkDidChange', () => { - this.accountTracker.updateAccounts(); + this.accountTrackerController.updateAccounts(); }, ); @@ -2323,7 +2325,7 @@ export default class MetamaskController extends EventEmitter { * On chrome profile re-start, they will be re-initialized. */ const resetOnRestartStore = { - AccountTracker: this.accountTracker.store, + AccountTracker: this.accountTrackerController, TokenRatesController: this.tokenRatesController, DecryptMessageController: this.decryptMessageController, EncryptionPublicKeyController: this.encryptionPublicKeyController, @@ -2448,7 +2450,9 @@ export default class MetamaskController extends EventEmitter { // if this is the first time, clear the state of by calling these methods const resetMethods = [ - this.accountTracker.resetState, + this.accountTrackerController.resetState.bind( + this.accountTrackerController, + ), this.decryptMessageController.resetState.bind( this.decryptMessageController, ), @@ -2548,7 +2552,7 @@ export default class MetamaskController extends EventEmitter { } triggerNetworkrequests() { - this.accountTracker.start(); + this.accountTrackerController.start(); this.txController.startIncomingTransactionPolling(); this.tokenDetectionController.enable(); @@ -2567,7 +2571,7 @@ export default class MetamaskController extends EventEmitter { } stopNetworkRequests() { - this.accountTracker.stop(); + this.accountTrackerController.stop(); this.txController.stopIncomingTransactionPolling(); this.tokenDetectionController.disable(); @@ -4268,8 +4272,8 @@ export default class MetamaskController extends EventEmitter { // Clear notification state this.notificationController.clear(); - // clear accounts in accountTracker - this.accountTracker.clearAccounts(); + // clear accounts in AccountTrackerController + this.accountTrackerController.clearAccounts(); this.txController.clearUnapprovedTransactions(); @@ -4366,14 +4370,14 @@ export default class MetamaskController extends EventEmitter { } /** - * Get an account balance from the AccountTracker or request it directly from the network. + * Get an account balance from the AccountTrackerController or request it directly from the network. * * @param {string} address - The account address * @param {EthQuery} ethQuery - The EthQuery instance to use when asking the network */ getBalance(address, ethQuery) { return new Promise((resolve, reject) => { - const cached = this.accountTracker.store.getState().accounts[address]; + const cached = this.accountTrackerController.state.accounts[address]; if (cached && cached.balance) { resolve(cached.balance); @@ -4431,9 +4435,9 @@ export default class MetamaskController extends EventEmitter { // Automatic login via config password await this.submitPassword(password); - // Updating accounts in this.accountTracker before starting UI syncing ensure that + // Updating accounts in this.accountTrackerController before starting UI syncing ensure that // state has account balance before it is synced with UI - await this.accountTracker.updateAccountsAllActiveNetworks(); + await this.accountTrackerController.updateAccountsAllActiveNetworks(); } finally { this._startUISync(); } @@ -4610,7 +4614,7 @@ export default class MetamaskController extends EventEmitter { oldAccounts.concat(accounts.map((a) => a.address.toLowerCase())), ), ]; - this.accountTracker.syncWithAddresses(accountsToTrack); + this.accountTrackerController.syncWithAddresses(accountsToTrack); return accounts; } @@ -6157,7 +6161,7 @@ export default class MetamaskController extends EventEmitter { return; } - this.accountTracker.syncWithAddresses(addresses); + this.accountTrackerController.syncWithAddresses(addresses); } /** diff --git a/app/scripts/metamask-controller.test.js b/app/scripts/metamask-controller.test.js index 4121160a45af..d1da34c48e0e 100644 --- a/app/scripts/metamask-controller.test.js +++ b/app/scripts/metamask-controller.test.js @@ -738,19 +738,23 @@ describe('MetaMaskController', () => { }); describe('#getBalance', () => { - it('should return the balance known by accountTracker', async () => { + it('should return the balance known by accountTrackerController', async () => { const accounts = {}; const balance = '0x14ced5122ce0a000'; accounts[TEST_ADDRESS] = { balance }; - metamaskController.accountTracker.store.putState({ accounts }); + jest + .spyOn(metamaskController.accountTrackerController, 'state', 'get') + .mockReturnValue({ + accounts, + }); const gotten = await metamaskController.getBalance(TEST_ADDRESS); expect(balance).toStrictEqual(gotten); }); - it('should ask the network for a balance when not known by accountTracker', async () => { + it('should ask the network for a balance when not known by accountTrackerController', async () => { const accounts = {}; const balance = '0x14ced5122ce0a000'; const ethQuery = new EthQuery(); @@ -758,7 +762,11 @@ describe('MetaMaskController', () => { callback(undefined, balance); }); - metamaskController.accountTracker.store.putState({ accounts }); + jest + .spyOn(metamaskController.accountTrackerController, 'state', 'get') + .mockReturnValue({ + accounts, + }); const gotten = await metamaskController.getBalance( TEST_ADDRESS, @@ -1687,21 +1695,27 @@ describe('MetaMaskController', () => { it('should do nothing if there are no keyrings in state', async () => { jest - .spyOn(metamaskController.accountTracker, 'syncWithAddresses') + .spyOn( + metamaskController.accountTrackerController, + 'syncWithAddresses', + ) .mockReturnValue(); const oldState = metamaskController.getState(); await metamaskController._onKeyringControllerUpdate({ keyrings: [] }); expect( - metamaskController.accountTracker.syncWithAddresses, + metamaskController.accountTrackerController.syncWithAddresses, ).not.toHaveBeenCalled(); expect(metamaskController.getState()).toStrictEqual(oldState); }); it('should sync addresses if there are keyrings in state', async () => { jest - .spyOn(metamaskController.accountTracker, 'syncWithAddresses') + .spyOn( + metamaskController.accountTrackerController, + 'syncWithAddresses', + ) .mockReturnValue(); const oldState = metamaskController.getState(); @@ -1714,14 +1728,17 @@ describe('MetaMaskController', () => { }); expect( - metamaskController.accountTracker.syncWithAddresses, + metamaskController.accountTrackerController.syncWithAddresses, ).toHaveBeenCalledWith(accounts); expect(metamaskController.getState()).toStrictEqual(oldState); }); it('should NOT update selected address if already unlocked', async () => { jest - .spyOn(metamaskController.accountTracker, 'syncWithAddresses') + .spyOn( + metamaskController.accountTrackerController, + 'syncWithAddresses', + ) .mockReturnValue(); const oldState = metamaskController.getState(); @@ -1735,14 +1752,17 @@ describe('MetaMaskController', () => { }); expect( - metamaskController.accountTracker.syncWithAddresses, + metamaskController.accountTrackerController.syncWithAddresses, ).toHaveBeenCalledWith(accounts); expect(metamaskController.getState()).toStrictEqual(oldState); }); it('filter out non-EVM addresses prior to calling syncWithAddresses', async () => { jest - .spyOn(metamaskController.accountTracker, 'syncWithAddresses') + .spyOn( + metamaskController.accountTrackerController, + 'syncWithAddresses', + ) .mockReturnValue(); const oldState = metamaskController.getState(); @@ -1759,7 +1779,7 @@ describe('MetaMaskController', () => { }); expect( - metamaskController.accountTracker.syncWithAddresses, + metamaskController.accountTrackerController.syncWithAddresses, ).toHaveBeenCalledWith(accounts); expect(metamaskController.getState()).toStrictEqual(oldState); }); diff --git a/development/ts-migration-dashboard/files-to-convert.json b/development/ts-migration-dashboard/files-to-convert.json index d5063250db16..ea3015d4c1ba 100644 --- a/development/ts-migration-dashboard/files-to-convert.json +++ b/development/ts-migration-dashboard/files-to-convert.json @@ -63,7 +63,6 @@ "app/scripts/inpage.js", "app/scripts/lib/ComposableObservableStore.js", "app/scripts/lib/ComposableObservableStore.test.js", - "app/scripts/lib/account-tracker.js", "app/scripts/lib/cleanErrorStack.js", "app/scripts/lib/cleanErrorStack.test.js", "app/scripts/lib/createLoggerMiddleware.js", diff --git a/shared/constants/mmi-controller.ts b/shared/constants/mmi-controller.ts index 50cc26ef5541..e61d7ed807cd 100644 --- a/shared/constants/mmi-controller.ts +++ b/shared/constants/mmi-controller.ts @@ -12,7 +12,7 @@ import PreferencesController from '../../app/scripts/controllers/preferences-con import { AppStateController } from '../../app/scripts/controllers/app-state'; // TODO: Remove restricted import // eslint-disable-next-line import/no-restricted-paths -import AccountTracker from '../../app/scripts/lib/account-tracker'; +import AccountTrackerController from '../../app/scripts/controllers/account-tracker-controller'; // TODO: Remove restricted import // eslint-disable-next-line import/no-restricted-paths import MetaMetricsController from '../../app/scripts/controllers/metametrics'; @@ -35,7 +35,7 @@ export type MMIControllerOptions = { // TODO: Replace `any` with type // eslint-disable-next-line @typescript-eslint/no-explicit-any getPendingNonce: (address: string) => Promise; - accountTracker: AccountTracker; + accountTrackerController: AccountTrackerController; metaMetricsController: MetaMetricsController; networkController: NetworkController; // TODO: Replace `any` with type