diff --git a/app/tray/Notify/ExtensionConnect/index.jsx b/app/tray/Notify/ExtensionConnect/index.jsx
new file mode 100644
index 000000000..a56e53f03
--- /dev/null
+++ b/app/tray/Notify/ExtensionConnect/index.jsx
@@ -0,0 +1,116 @@
+import styled from 'styled-components'
+import { useState } from 'react'
+
+import link from '../../../../resources/link'
+import { capitalize } from '../../../../resources/utils'
+import svg from '../../../../resources/svg'
+import { ClusterBox, Cluster, ClusterRow, ClusterValue } from '../../../../resources/Components/Cluster'
+
+const NotifyTop = styled.div`
+ padding: 24px 0px 16px 0px;
+ display: flex;
+ flex-direction: column;
+ justify-content: center;
+ align-items: center;
+`
+
+const NotifyMain = styled.div`
+ padding: 24px;
+ display: flex;
+ flex-direction: column;
+ justify-content: center;
+ align-items: center;
+ font-size: 14.6px;
+ line-height: 22px;
+ font-weight: 400;
+`
+
+const NotifyPrompt = styled.div`
+ padding: 24px;
+ font-weight: 400;
+ text-transform: uppercase;
+`
+
+const ExtensionId = styled.div`
+ margin: 24px 16px;
+ height: 13px;
+ font-weight: 400;
+ text-transform: uppercase;
+ display: flex;
+ flex-direction: column;
+ justify-content: center;
+ letter-spacing: 0.5px;
+ color: var(--moon);
+`
+
+const VCR = styled.div`
+ font-family: 'FiraCode';
+ font-size: 14px;
+ font-weight: 300;
+ letter-spacing: 0px;
+`
+
+const ConfirmButton = styled.div`
+ padding: 24px;
+ font-weight: 400;
+ text-transform: uppercase;
+ font-size: 16px;
+`
+
+const ExtensionConnectNotification = ({ id, browser, onClose }) => {
+ const respond = (accepted) => link.rpc('respondToExtensionRequest', id, accepted, onClose)
+ const browserName = capitalize(browser)
+ const [copyId, setCopyId] = useState(false)
+
+ return (
+
+
e.stopPropagation()}>
+
+
+
+ {svg.firefox(40)}
+
+
+
+
+
+
+ {`A new ${browserName} extension is attempting to connect as "Frame Companion"`}{' '}
+
+ {`If you did not recently add Frame Companion please verify the extension origin below`}
+
+
+
+
+ {
+ link.send('tray:clipboardData', id)
+ setCopyId(true)
+ setTimeout(() => setCopyId(false), 2000)
+ }}
+ >
+ {copyId ? 'extension origin copied' : {id}}
+
+
+
+
+ Allow this extension to connect?
+
+
+
+ respond(false)}>
+ Decline
+
+ respond(true)}>
+ Accept
+
+
+
+
+
+
+
+ )
+}
+
+export default ExtensionConnectNotification
diff --git a/app/tray/Notify/index.js b/app/tray/Notify/index.js
index a7003e992..f1f869337 100644
--- a/app/tray/Notify/index.js
+++ b/app/tray/Notify/index.js
@@ -8,6 +8,7 @@ import { usesBaseFee } from '../../../resources/domain/transaction'
import { capitalize } from '../../../resources/utils'
import frameIcon from './FrameIcon.png'
+import ExtensionConnectNotification from './ExtensionConnect'
const FEE_WARNING_THRESHOLD_USD = 50
@@ -515,6 +516,7 @@ class Notify extends React.Component {
render() {
const notify = this.store('view.notify')
+
if (notify === 'mainnet') {
return (
this.store.notify()}>
@@ -625,6 +627,10 @@ class Notify extends React.Component {
{this.openExplorer(this.store('view.notifyData'))}
)
+ } else if (notify === 'extensionConnect') {
+ const { browser, id } = this.store('view.notifyData')
+
+ return this.store.notify()} />
} else {
return null
}
diff --git a/main/accounts/types.ts b/main/accounts/types.ts
index fc20ab0af..c32ecfbd8 100644
--- a/main/accounts/types.ts
+++ b/main/accounts/types.ts
@@ -40,11 +40,14 @@ type RequestType =
| 'switchChain'
| 'addToken'
-export interface AccountRequest {
+interface Request {
type: RequestType
+ handlerId: string
+}
+
+export interface AccountRequest extends Request {
origin: string
payload: JSONRPCRequestPayload
- handlerId: string
account: string
status?: RequestStatus
mode?: RequestMode
diff --git a/main/api/origins.ts b/main/api/origins.ts
index 4df61c899..29057ceae 100644
--- a/main/api/origins.ts
+++ b/main/api/origins.ts
@@ -6,6 +6,13 @@ import accounts, { AccessRequest } from '../accounts'
import store from '../store'
const dev = process.env.NODE_ENV === 'development'
+
+const activeExtensionChecks: Record> = {}
+const extensionPrefixes = {
+ firefox: 'moz-extension',
+ safari: 'safari-web-extension'
+}
+
const protocolRegex = /^(?:ws|http)s?:\/\//
interface OriginUpdateResult {
@@ -13,6 +20,13 @@ interface OriginUpdateResult {
hasSession: boolean
}
+type Browser = 'chrome' | 'firefox' | 'safari'
+
+export interface FrameExtension {
+ browser: Browser
+ id: string
+}
+
// allows the Frame extension to request specific methods
const trustedExtensionMethods = ['wallet_getEthereumChains']
@@ -20,7 +34,8 @@ const storeApi = {
getPermission: (address: Address, origin: string) => {
const permissions: Record = store('main.permissions', address) || {}
return Object.values(permissions).find((p) => p.origin === origin)
- }
+ },
+ getKnownExtension: (id: string) => store('main.knownExtensions', id) as boolean
}
export function parseOrigin(origin?: string) {
@@ -39,6 +54,31 @@ async function getPermission(address: Address, origin: string, payload: RPCReque
return permission || requestPermission(address, payload)
}
+async function requestExtensionPermission(extension: FrameExtension) {
+ if (extension.id in activeExtensionChecks) {
+ return activeExtensionChecks[extension.id]
+ }
+
+ const result = new Promise((resolve) => {
+ const obs = store.observer(() => {
+ const isActive = extension.id in activeExtensionChecks
+ const isAllowed = store('main.knownExtensions', extension.id)
+
+ // wait for a response
+ if (isActive && typeof isAllowed !== 'undefined') {
+ delete activeExtensionChecks[extension.id]
+ obs.remove()
+ resolve(isAllowed)
+ }
+ }, 'origins:requestExtension')
+ })
+
+ activeExtensionChecks[extension.id] = result
+ store.notify('extensionConnect', extension)
+
+ return result
+}
+
async function requestPermission(address: Address, fullPayload: RPCRequestPayload) {
const { _origin: originId, ...payload } = fullPayload
@@ -99,29 +139,33 @@ export function updateOrigin(
}
}
-export function isFrameExtension(req: IncomingMessage) {
- const origin = req.headers.origin
- if (!origin) return false
+export function parseFrameExtension(req: IncomingMessage): FrameExtension | undefined {
+ const origin = req.headers.origin || ''
const query = queryString.parse((req.url || '').replace('/', ''))
- const mozOrigin = origin.startsWith('moz-extension://')
- const extOrigin =
- origin.startsWith('chrome-extension://') ||
- origin.startsWith('moz-extension://') ||
- origin.startsWith('safari-web-extension://')
+ const hasExtensionIdentity = query.identity === 'frame-extension'
if (origin === 'chrome-extension://ldcoohedfbjoobcadoglnnmmfbdlmmhf') {
// Match production chrome
- return true
- } else if (mozOrigin || (dev && extOrigin)) {
- // In production, match any Firefox extension origin where query.identity === 'frame-extension'
- // In dev, match any extension where query.identity === 'frame-extension'
- return query.identity === 'frame-extension'
- } else {
- return false
+ return { browser: 'chrome', id: 'ldcoohedfbjoobcadoglnnmmfbdlmmhf' }
+ } else if (origin.startsWith(`${extensionPrefixes.firefox}://`) && hasExtensionIdentity) {
+ // Match production Firefox
+ const extensionId = origin.substring(extensionPrefixes.firefox.length + 3)
+ return { browser: 'firefox', id: extensionId }
+ } else if (origin.startsWith(`${extensionPrefixes.safari}://`) && dev && hasExtensionIdentity) {
+ // Match Safari in dev only
+ return { browser: 'safari', id: 'frame-dev' }
}
}
+export async function isKnownExtension(extension: FrameExtension) {
+ if (extension.browser === 'chrome' || extension.browser === 'safari') return true
+
+ const extensionPermission = storeApi.getKnownExtension(extension.id)
+
+ return extensionPermission ?? requestExtensionPermission(extension)
+}
+
export async function isTrusted(payload: RPCRequestPayload) {
// Permission granted to unknown origins only persist until the Frame is closed, they are not permanent
const { name: originName } = store('main.origins', payload._origin) as { name: string }
diff --git a/main/api/ws.ts b/main/api/ws.ts
index 94cbf6816..0c51c09f2 100644
--- a/main/api/ws.ts
+++ b/main/api/ws.ts
@@ -7,7 +7,14 @@ import provider from '../provider'
import accounts from '../accounts'
import windows from '../windows'
-import { updateOrigin, isTrusted, isFrameExtension, parseOrigin } from './origins'
+import {
+ updateOrigin,
+ isTrusted,
+ parseOrigin,
+ isKnownExtension,
+ FrameExtension,
+ parseFrameExtension
+} from './origins'
import validPayload from './validPayload'
import protectedMethods from './protectedMethods'
import { IncomingMessage, Server } from 'http'
@@ -25,7 +32,7 @@ interface Subscription {
interface FrameWebSocket extends WebSocket {
id: string
origin?: string
- isFrameExtension: boolean
+ frameExtension?: FrameExtension
}
interface ExtensionPayload extends JSONRPCRequestPayload {
@@ -46,7 +53,7 @@ function extendSession(originId: string) {
const handler = (socket: FrameWebSocket, req: IncomingMessage) => {
socket.id = uuid()
socket.origin = req.headers.origin
- socket.isFrameExtension = isFrameExtension(req)
+ socket.frameExtension = parseFrameExtension(req)
const res = (payload: RPCResponsePayload) => {
if (socket.readyState === WebSocket.OPEN) {
@@ -61,7 +68,16 @@ const handler = (socket: FrameWebSocket, req: IncomingMessage) => {
if (!rawPayload) return console.warn('Invalid Payload', data)
let requestOrigin = socket.origin
- if (socket.isFrameExtension) {
+ if (socket.frameExtension) {
+ if (!(await isKnownExtension(socket.frameExtension))) {
+ const error = {
+ message: `Permission denied, approve connection from Frame Companion with id ${socket.frameExtension.id} in Frame to continue`,
+ code: 4001
+ }
+
+ return res({ id: rawPayload.id, jsonrpc: rawPayload.jsonrpc, error })
+ }
+
// Request from extension, swap origin
if (rawPayload.__frameOrigin) {
requestOrigin = rawPayload.__frameOrigin
@@ -75,7 +91,7 @@ const handler = (socket: FrameWebSocket, req: IncomingMessage) => {
if (logTraffic)
log.info(
- `req -> | ${socket.isFrameExtension ? 'ext' : 'ws'} | ${origin} | ${rawPayload.method} | -> | ${
+ `req -> | ${socket.frameExtension ? 'ext' : 'ws'} | ${origin} | ${rawPayload.method} | -> | ${
rawPayload.params
}`
)
@@ -114,7 +130,7 @@ const handler = (socket: FrameWebSocket, req: IncomingMessage) => {
}
if (logTraffic)
log.info(
- `<- res | ${socket.isFrameExtension ? 'ext' : 'ws'} | ${origin} | ${
+ `<- res | ${socket.frameExtension ? 'ext' : 'ws'} | ${origin} | ${
payload.method
} | <- | ${JSON.stringify(response.result || response.error)}`
)
diff --git a/main/rpc/index.js b/main/rpc/index.js
index eddcd7bf6..0f251ad16 100644
--- a/main/rpc/index.js
+++ b/main/rpc/index.js
@@ -134,6 +134,9 @@ const rpc = {
confirmRequestApproval(req, approvalType, approvalData, cb) {
accounts.confirmRequestApproval(req.handlerId, approvalType, approvalData)
},
+ respondToExtensionRequest(id, approved, cb) {
+ callbackWhenDone(() => store.trustExtension(id, approved), cb)
+ },
updateRequest(reqId, actionId, data, cb = () => {}) {
accounts.updateRequest(reqId, actionId, data)
},
diff --git a/main/store/actions/index.js b/main/store/actions/index.js
index 0c1e2235d..1d6d5c3b1 100644
--- a/main/store/actions/index.js
+++ b/main/store/actions/index.js
@@ -501,6 +501,9 @@ module.exports = {
return origins
})
},
+ trustExtension: (u, extensionId, trusted) => {
+ u('main.knownExtensions', (extensions = {}) => ({ ...extensions, [extensionId]: trusted }))
+ },
setBlockHeight: (u, chainId, blockHeight) => {
u('main.networksMeta.ethereum', (chainsMeta) => {
if (chainsMeta[chainId]) {
diff --git a/main/store/state/index.js b/main/store/state/index.js
index f3eec2091..b90041100 100644
--- a/main/store/state/index.js
+++ b/main/store/state/index.js
@@ -214,6 +214,7 @@ const initial = {
derivation: main('trezor.derivation', 'standard')
},
origins: main('origins', {}),
+ knownExtensions: main('knownExtensions', {}),
privacy: {
errorReporting: main('privacy.errorReporting', true)
},
@@ -729,6 +730,10 @@ initial.main.origins = Object.entries(initial.main.origins).reduce((origins, [id
return origins
}, {})
+initial.main.knownExtensions = Object.fromEntries(
+ Object.entries(initial.main.knownExtensions).filter(([id, allowed]) => allowed)
+)
+
// ---
module.exports = () => migrations.apply(initial)
diff --git a/resources/Components/Cluster/style/index.styl b/resources/Components/Cluster/style/index.styl
index 65ea2b3ff..a6f1db4de 100644
--- a/resources/Components/Cluster/style/index.styl
+++ b/resources/Components/Cluster/style/index.styl
@@ -57,7 +57,6 @@
cursor pointer
margin-bottom 0px
position relative
- z-index 3
.clusterValueInteractable
*
diff --git a/test/main/api/origins.test.js b/test/main/api/origins.test.js
index 12fd988c1..dfd30d393 100644
--- a/test/main/api/origins.test.js
+++ b/test/main/api/origins.test.js
@@ -1,7 +1,13 @@
import { v5 as uuidv5 } from 'uuid'
import log from 'electron-log'
-import { parseOrigin, updateOrigin, isTrusted } from '../../../main/api/origins'
+import {
+ parseOrigin,
+ updateOrigin,
+ isTrusted,
+ parseFrameExtension,
+ isKnownExtension
+} from '../../../main/api/origins'
import accounts from '../../../main/accounts'
import store from '../../../main/store'
@@ -152,6 +158,150 @@ describe('#updateOrigin', () => {
})
})
+describe('#parseFrameExtension', () => {
+ it('correctly identifies the Chrome extension', () => {
+ const origin = 'chrome-extension://ldcoohedfbjoobcadoglnnmmfbdlmmhf'
+ const req = { headers: { origin } }
+
+ expect(parseFrameExtension(req)).toStrictEqual({
+ browser: 'chrome',
+ id: 'ldcoohedfbjoobcadoglnnmmfbdlmmhf'
+ })
+ })
+
+ it('does not recognize a Chrome extension with the wrong id', () => {
+ const origin = 'chrome-extension://somebogusid'
+ const req = { headers: { origin } }
+
+ expect(parseFrameExtension(req)).toBeUndefined()
+ })
+
+ it('correctly identifies the Firefox extension', () => {
+ const origin = 'moz-extension://4be0643f-1d98-573b-97cd-ca98a65347dd'
+ const req = { headers: { origin }, url: '/?identity=frame-extension' }
+
+ expect(parseFrameExtension(req)).toStrictEqual({
+ browser: 'firefox',
+ id: '4be0643f-1d98-573b-97cd-ca98a65347dd'
+ })
+ })
+
+ it('does not recognize the Firefox extension without the identity query parameter', () => {
+ const origin = 'moz-extension://4be0643f-1d98-573b-97cd-ca98a65347dd'
+ const req = { headers: { origin }, url: '/' }
+
+ expect(parseFrameExtension(req)).toBeUndefined()
+ })
+
+ it('correctly identifies the Safari extension', async () => {
+ return withEnvironment({ NODE_ENV: 'development' }, async () => {
+ const origin = 'safari-web-extension://4be0643f-1d98-573b-97cd-ca98a65347dd'
+ const req = { headers: { origin }, url: '/?identity=frame-extension' }
+
+ const { parseFrameExtension } = await import('../../../main/api/origins')
+
+ expect(parseFrameExtension(req)).toStrictEqual({
+ browser: 'safari',
+ id: expect.any(String)
+ })
+ })
+ })
+
+ it('does not recognize a Safari extension in production', () => {
+ return withEnvironment({ NODE_ENV: 'production' }, async () => {
+ const origin = 'safari-web-extension://4be0643f-1d98-573b-97cd-ca98a65347dd'
+ const req = { headers: { origin }, url: '/?identity=frame-extension' }
+
+ const { parseFrameExtension } = await import('../../../main/api/origins')
+
+ expect(parseFrameExtension(req)).toBeUndefined()
+ })
+ })
+
+ it('does not recognize the Safari extension without the identity query parameter', () => {
+ return withEnvironment({ NODE_ENV: 'development' }, async () => {
+ const origin = 'safari-web-extension://4be0643f-1d98-573b-97cd-ca98a65347dd'
+ const req = { headers: { origin }, url: '/' }
+
+ const { parseFrameExtension } = await import('../../../main/api/origins')
+
+ expect(parseFrameExtension(req)).toBeUndefined()
+ })
+ })
+
+ it('does not recognize an extension from an unsupported browser', () => {
+ const origin = 'brave-extension://4be0643f-1d98-573b-97cd-ca98a65347dd'
+ const req = { headers: { origin } }
+
+ expect(parseFrameExtension(req)).toBeUndefined()
+ })
+})
+
+describe('#isKnownExtension', () => {
+ beforeEach(() => {
+ store.set('main.knownExtensions', {})
+ store.notify = jest.fn()
+ })
+
+ it('always knows the single Chrome extension', async () => {
+ const extension = { browser: 'chrome', id: 'ldcoohedfbjoobcadoglnnmmfbdlmmhf' }
+ return expect(isKnownExtension(extension)).resolves.toBe(true)
+ })
+
+ it('always knows the single Safari extension', async () => {
+ const extension = { browser: 'safari', id: 'test-frame' }
+ return expect(isKnownExtension(extension)).resolves.toBe(true)
+ })
+
+ it('knows a previously trusted Firefox extension', async () => {
+ const extension = { browser: 'firefox', id: '4be0643f-1d98-573b-97cd-ca98a65347dd' }
+
+ store.set('main.knownExtensions', { [extension.id]: true })
+
+ return expect(isKnownExtension(extension)).resolves.toBe(true)
+ })
+
+ it('rejects a previously rejected Firefox extension', async () => {
+ const extension = { browser: 'firefox', id: '4be0643f-1d98-573b-97cd-ca98a65347dd' }
+
+ store.set('main.knownExtensions', { [extension.id]: false })
+
+ return expect(isKnownExtension(extension)).resolves.toBe(false)
+ })
+
+ it('prompts the user to trust a Firefox extension', async () => {
+ const extension = { browser: 'firefox', id: '4be0643f-1d98-573b-97cd-ca98a65347dd' }
+
+ isKnownExtension(extension)
+
+ expect(store.notify).toHaveBeenCalledWith('extensionConnect', extension)
+ })
+
+ it('allows a user to trust a Firefox extension', async () => {
+ const extension = { browser: 'firefox', id: '4ae0643f-1d98-573b-97cd-ca98a65347dd' }
+
+ store.notify.mockImplementationOnce(() => {
+ // simulate user accepting the request
+ store.set('main.knownExtensions', { [extension.id]: true })
+ store.getObserver('origins:requestExtension').fire()
+ })
+
+ return expect(isKnownExtension(extension)).resolves.toBe(true)
+ })
+
+ it('allows a user to reject a connection from a Firefox extension', async () => {
+ const extension = { browser: 'firefox', id: '4ce0643f-1d98-573b-97cd-ca98a65347dd' }
+
+ store.notify.mockImplementationOnce(() => {
+ // simulate user accepting the request
+ store.set('main.knownExtensions', { [extension.id]: false })
+ store.getObserver('origins:requestExtension').fire()
+ })
+
+ return expect(isKnownExtension(extension)).resolves.toBe(false)
+ })
+})
+
describe('#isTrusted', () => {
const frameTestOriginId = 'bf93061b-3575-40c5-b526-4932b02e1f3f'
@@ -262,3 +412,16 @@ describe('#isTrusted', () => {
})
})
})
+
+// helper functions
+async function withEnvironment(env, test) {
+ const oldEnv = { ...process.env }
+
+ jest.resetModules()
+ process.env = env
+
+ await test()
+
+ process.env = oldEnv
+ jest.resetModules()
+}