Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/default network #4557

Merged
merged 6 commits into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 39 additions & 65 deletions packages/ens-controller/src/EnsController.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ import {
toHex,
InfuraNetworkType,
} from '@metamask/controller-utils';
import type {
NetworkController,
NetworkState,
} from '@metamask/network-controller';
import { defaultState as defaultNetworkState } from '@metamask/network-controller';

import type {
Expand All @@ -19,6 +23,7 @@ import { EnsController, DEFAULT_ENS_NETWORK_MAP } from './EnsController';
import type {
EnsControllerState,
EnsControllerMessenger,
AllowedActions,
} from './EnsController';

const defaultState: EnsControllerState = {
Expand Down Expand Up @@ -74,23 +79,48 @@ const name = 'EnsController';
* @returns A restricted controller messenger.
*/
function getRootMessenger(): RootMessenger {
return new ControllerMessenger();
return new ControllerMessenger<
ExtractAvailableAction<EnsControllerMessenger> | AllowedActions,
ExtractAvailableEvent<EnsControllerMessenger> | never
>();
}

/**
* Constructs the messenger restricted to EnsController actions and events.
*
* @param rootMessenger - The root messenger to base the restricted messenger
* off of.
* @param getNetworkClientByIdMock - Optional mock version of `getNetworkClientById`.
* @returns A restricted controller messenger.
*/
function getRestrictedMessenger(rootMessenger: RootMessenger) {
return rootMessenger.getRestricted<
'EnsController',
'NetworkController:getNetworkClientById'
>({
function getRestrictedMessenger(
rootMessenger: RootMessenger,
getNetworkClientByIdMock?: NetworkController['getNetworkClientById'],
) {
const mockNetworkState = jest.fn<NetworkState, []>().mockReturnValue({
...defaultNetworkState,
selectedNetworkClientId: InfuraNetworkType.mainnet,
});

rootMessenger.registerActionHandler(
'NetworkController:getState',
mockNetworkState,
);

if (!getNetworkClientByIdMock) {
getNetworkClientByIdMock = buildMockGetNetworkClientById();
}
rootMessenger.registerActionHandler(
'NetworkController:getNetworkClientById',
getNetworkClientByIdMock,
);

return rootMessenger.getRestricted<'EnsController', AllowedActions['type']>({
name,
allowedActions: ['NetworkController:getNetworkClientById'],
allowedActions: [
'NetworkController:getNetworkClientById',
'NetworkController:getState',
],
allowedEvents: [],
});
}
Expand Down Expand Up @@ -174,19 +204,13 @@ describe('EnsController', () => {
it('should clear ensResolutionsByAddress state propery on networkDidChange', async () => {
const rootMessenger = getRootMessenger();
const ensControllerMessenger = getRestrictedMessenger(rootMessenger);
const getNetworkClientById = buildMockGetNetworkClientById();
rootMessenger.registerActionHandler(
'NetworkController:getNetworkClientById',
getNetworkClientById,
);
const controller = new EnsController({
messenger: ensControllerMessenger,
state: {
ensResolutionsByAddress: {
[address1Checksum]: 'peaksignal.eth',
},
},
provider: getProvider(),
onNetworkDidChange: (listener) => {
listener({
...defaultNetworkState,
Expand Down Expand Up @@ -492,14 +516,8 @@ describe('EnsController', () => {
it('should return undefined when network is loading', async function () {
const rootMessenger = getRootMessenger();
const ensControllerMessenger = getRestrictedMessenger(rootMessenger);
const getNetworkClientById = buildMockGetNetworkClientById();
rootMessenger.registerActionHandler(
'NetworkController:getNetworkClientById',
getNetworkClientById,
);
const ens = new EnsController({
messenger: ensControllerMessenger,
provider: getProvider(),
onNetworkDidChange: (listener) => {
listener({
...defaultNetworkState,
Expand All @@ -512,19 +530,17 @@ describe('EnsController', () => {

it('should return undefined when network is not ens supported', async function () {
const rootMessenger = getRootMessenger();
const ensControllerMessenger = getRestrictedMessenger(rootMessenger);
const getNetworkClientById = buildMockGetNetworkClientById({
'AAAA-AAAA-AAAA-AAAA': buildCustomNetworkClientConfiguration({
chainId: '0x9999999',
}),
});
rootMessenger.registerActionHandler(
'NetworkController:getNetworkClientById',
const ensControllerMessenger = getRestrictedMessenger(
rootMessenger,
getNetworkClientById,
);
const ens = new EnsController({
messenger: ensControllerMessenger,
provider: getProvider(),
onNetworkDidChange: (listener) => {
listener({
...defaultNetworkState,
Expand All @@ -537,11 +553,6 @@ describe('EnsController', () => {

it('should only resolve an ENS name once', async () => {
const rootMessenger = getRootMessenger();
const getNetworkClientById = buildMockGetNetworkClientById();
rootMessenger.registerActionHandler(
'NetworkController:getNetworkClientById',
getNetworkClientById,
);
const ensControllerMessenger = getRestrictedMessenger(rootMessenger);
const ethProvider = new providersModule.Web3Provider(getProvider());
jest.spyOn(ethProvider, 'resolveName').mockResolvedValue(address1);
Expand All @@ -552,7 +563,6 @@ describe('EnsController', () => {

const ens = new EnsController({
messenger: ensControllerMessenger,
provider: getProvider(),
onNetworkDidChange: (listener) => {
listener({
...defaultNetworkState,
Expand All @@ -567,18 +577,12 @@ describe('EnsController', () => {

it('should fail if lookupAddress through an error', async () => {
const rootMessenger = getRootMessenger();
const getNetworkClientById = buildMockGetNetworkClientById();
rootMessenger.registerActionHandler(
'NetworkController:getNetworkClientById',
getNetworkClientById,
);
const ensControllerMessenger = getRestrictedMessenger(rootMessenger);
const ethProvider = new providersModule.Web3Provider(getProvider());
jest.spyOn(ethProvider, 'lookupAddress').mockRejectedValue('error');
jest.spyOn(providersModule, 'Web3Provider').mockReturnValue(ethProvider);
const ens = new EnsController({
messenger: ensControllerMessenger,
provider: getProvider(),
onNetworkDidChange: (listener) => {
listener({
...defaultNetworkState,
Expand All @@ -592,18 +596,12 @@ describe('EnsController', () => {

it('should fail if lookupAddress returns a null value', async () => {
const rootMessenger = getRootMessenger();
const getNetworkClientById = buildMockGetNetworkClientById();
rootMessenger.registerActionHandler(
'NetworkController:getNetworkClientById',
getNetworkClientById,
);
const ensControllerMessenger = getRestrictedMessenger(rootMessenger);
const ethProvider = new providersModule.Web3Provider(getProvider());
jest.spyOn(ethProvider, 'lookupAddress').mockResolvedValue(null);
jest.spyOn(providersModule, 'Web3Provider').mockReturnValue(ethProvider);
const ens = new EnsController({
messenger: ensControllerMessenger,
provider: getProvider(),
onNetworkDidChange: (listener) => {
listener({
...defaultNetworkState,
Expand All @@ -617,11 +615,6 @@ describe('EnsController', () => {

it('should fail if resolveName through an error', async () => {
const rootMessenger = getRootMessenger();
const getNetworkClientById = buildMockGetNetworkClientById();
rootMessenger.registerActionHandler(
'NetworkController:getNetworkClientById',
getNetworkClientById,
);
const ensControllerMessenger = getRestrictedMessenger(rootMessenger);
const ethProvider = new providersModule.Web3Provider(getProvider());
jest
Expand All @@ -631,7 +624,6 @@ describe('EnsController', () => {
jest.spyOn(providersModule, 'Web3Provider').mockReturnValue(ethProvider);
const ens = new EnsController({
messenger: ensControllerMessenger,
provider: getProvider(),
onNetworkDidChange: (listener) => {
listener({
...defaultNetworkState,
Expand All @@ -645,11 +637,6 @@ describe('EnsController', () => {

it('should fail if resolveName returns a null value', async () => {
const rootMessenger = getRootMessenger();
const getNetworkClientById = buildMockGetNetworkClientById();
rootMessenger.registerActionHandler(
'NetworkController:getNetworkClientById',
getNetworkClientById,
);
const ensControllerMessenger = getRestrictedMessenger(rootMessenger);
const ethProvider = new providersModule.Web3Provider(getProvider());
jest.spyOn(ethProvider, 'resolveName').mockResolvedValue(null);
Expand All @@ -659,7 +646,6 @@ describe('EnsController', () => {
jest.spyOn(providersModule, 'Web3Provider').mockReturnValue(ethProvider);
const ens = new EnsController({
messenger: ensControllerMessenger,
provider: getProvider(),
onNetworkDidChange: (listener) => {
listener({
...defaultNetworkState,
Expand All @@ -673,11 +659,6 @@ describe('EnsController', () => {

it('should fail if registred address is zero x error address', async () => {
const rootMessenger = getRootMessenger();
const getNetworkClientById = buildMockGetNetworkClientById();
rootMessenger.registerActionHandler(
'NetworkController:getNetworkClientById',
getNetworkClientById,
);
const ensControllerMessenger = getRestrictedMessenger(rootMessenger);
const ethProvider = new providersModule.Web3Provider(getProvider());
jest
Expand All @@ -689,7 +670,6 @@ describe('EnsController', () => {
jest.spyOn(providersModule, 'Web3Provider').mockReturnValue(ethProvider);
const ens = new EnsController({
messenger: ensControllerMessenger,
provider: getProvider(),
onNetworkDidChange: (listener) => {
listener({
...defaultNetworkState,
Expand All @@ -703,11 +683,6 @@ describe('EnsController', () => {

it('should fail if the name is registered to a different address than the reverse resolved', async () => {
const rootMessenger = getRootMessenger();
const getNetworkClientById = buildMockGetNetworkClientById();
rootMessenger.registerActionHandler(
'NetworkController:getNetworkClientById',
getNetworkClientById,
);
const ensControllerMessenger = getRestrictedMessenger(rootMessenger);

const ethProvider = new providersModule.Web3Provider(getProvider());
Expand All @@ -718,7 +693,6 @@ describe('EnsController', () => {
jest.spyOn(providersModule, 'Web3Provider').mockReturnValue(ethProvider);
const ens = new EnsController({
messenger: ensControllerMessenger,
provider: getProvider(),
onNetworkDidChange: (listener) => {
listener({
...defaultNetworkState,
Expand Down
66 changes: 40 additions & 26 deletions packages/ens-controller/src/EnsController.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
import type {
ExternalProvider,
JsonRpcFetchFunc,
} from '@ethersproject/providers';
import { Web3Provider } from '@ethersproject/providers';
import type { RestrictedControllerMessenger } from '@metamask/base-controller';
import { BaseController } from '@metamask/base-controller';
Expand All @@ -17,6 +13,7 @@ import {
} from '@metamask/controller-utils';
import type {
NetworkControllerGetNetworkClientByIdAction,
NetworkControllerGetStateAction,
NetworkState,
} from '@metamask/network-controller';
import type { Hex } from '@metamask/utils';
Expand Down Expand Up @@ -72,7 +69,9 @@ export type EnsControllerState = {
ensResolutionsByAddress: { [key: string]: string };
};

type AllowedActions = NetworkControllerGetNetworkClientByIdAction;
export type AllowedActions =
| NetworkControllerGetNetworkClientByIdAction
| NetworkControllerGetStateAction;

export type EnsControllerMessenger = RestrictedControllerMessenger<
typeof name,
Expand Down Expand Up @@ -113,20 +112,17 @@ export class EnsController extends BaseController<
* @param options.registriesByChainId - Map between chain IDs and ENS contract addresses.
* @param options.messenger - A reference to the messaging system.
* @param options.state - Initial state to set on this controller.
* @param options.provider - Provider instance.
* @param options.onNetworkDidChange - Allows subscribing to network controller networkDidChange events.
*/
constructor({
registriesByChainId = DEFAULT_ENS_NETWORK_MAP,
messenger,
state = {},
provider,
onNetworkDidChange,
}: {
registriesByChainId?: Record<number, Hex>;
messenger: EnsControllerMessenger;
state?: Partial<EnsControllerState>;
provider?: ExternalProvider | JsonRpcFetchFunc;
onNetworkDidChange?: (
listener: (networkState: NetworkState) => void,
) => void;
Expand All @@ -153,26 +149,12 @@ export class EnsController extends BaseController<
},
});

if (provider && onNetworkDidChange) {
this.#setDefaultEthProvider(registriesByChainId);

if (onNetworkDidChange) {
onNetworkDidChange(({ selectedNetworkClientId }) => {
this.resetState();
const selectedNetworkClient = this.messagingSystem.call(
'NetworkController:getNetworkClientById',
selectedNetworkClientId,
);
const currentChainId = selectedNetworkClient.configuration.chainId;

if (this.#getChainEnsSupport(currentChainId)) {
this.#ethProvider = new Web3Provider(provider, {
chainId: convertHexToDecimal(currentChainId),
name: CHAIN_ID_TO_ETHERS_NETWORK_NAME_MAP[
currentChainId as ChainId
],
ensAddress: registriesByChainId[parseInt(currentChainId, 16)],
});
} else {
this.#ethProvider = null;
}
this.#setEthProvider(selectedNetworkClientId, registriesByChainId);
});
}
}
Expand Down Expand Up @@ -295,6 +277,38 @@ export class EnsController extends BaseController<
return true;
}

#setDefaultEthProvider(registriesByChainId?: Record<number, Hex>) {
const { selectedNetworkClientId } = this.messagingSystem.call(
'NetworkController:getState',
);
this.#setEthProvider(selectedNetworkClientId, registriesByChainId);
}

#setEthProvider(
selectedNetworkClientId: string,
registriesByChainId?: Record<number, Hex>,
) {
const selectedNetworkClient = this.messagingSystem.call(
'NetworkController:getNetworkClientById',
selectedNetworkClientId,
);
const currentChainId = selectedNetworkClient.configuration.chainId;

if (
registriesByChainId &&
registriesByChainId[parseInt(currentChainId, 16)] &&
this.#getChainEnsSupport(currentChainId)
) {
this.#ethProvider = new Web3Provider(selectedNetworkClient.provider, {
chainId: convertHexToDecimal(currentChainId),
name: CHAIN_ID_TO_ETHERS_NETWORK_NAME_MAP[currentChainId as ChainId],
ensAddress: registriesByChainId[parseInt(currentChainId, 16)],
});
} else {
this.#ethProvider = null;
}
}

/**
mikesposito marked this conversation as resolved.
Show resolved Hide resolved
* Check if the chain supports ENS.
*
Expand Down
Loading