From cb50094fe56b507393d53bf83d2d189a2cfe9b73 Mon Sep 17 00:00:00 2001 From: Haseeb Rabbani Date: Tue, 23 Aug 2022 17:32:22 +0400 Subject: [PATCH 1/2] dont check network id in production --- cli/commands/run/index.spec.ts | 17 ++++++++++------- cli/commands/run/index.ts | 9 ++++++--- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/cli/commands/run/index.spec.ts b/cli/commands/run/index.spec.ts index ec5fa3e..fffda74 100644 --- a/cli/commands/run/index.spec.ts +++ b/cli/commands/run/index.spec.ts @@ -13,6 +13,7 @@ describe("run", () => { getNetwork: jest.fn() } as any const mockExit = jest.spyOn(process, 'exit').mockImplementation(); + let mockIsProduction = true let consoleSpy = jest.spyOn(console, 'warn'); const resetMocks = () => { @@ -21,6 +22,7 @@ describe("run", () => { mockCache.save.mockReset() mockExit.mockReset() consoleSpy.mockReset() + mockIsProduction = true } const defaultChainIds = [1]; @@ -34,11 +36,12 @@ describe("run", () => { it("logs a warning if detected chainId is not in list of configured chainIds", async () => { resetMocks(); (mockEthersProvider.getNetwork as jest.Mock).mockReturnValueOnce({chainId: 234, name: "test"}) + mockIsProduction = false const mockCliArgs = {tx: '0x123'} const mockRunTransaction = jest.fn() mockContainer.resolve.mockReturnValueOnce(mockRunTransaction) - run = provideRun(mockContainer, mockEthersProvider, defaultChainIds, testRpcUrl, mockCache, mockCliArgs) + run = provideRun(mockContainer, mockEthersProvider, defaultChainIds, testRpcUrl, mockIsProduction, mockCache, mockCliArgs) await run() expect(consoleSpy).toBeCalledTimes(1) }) @@ -48,7 +51,7 @@ describe("run", () => { const mockRunTransaction = jest.fn() mockContainer.resolve.mockReturnValueOnce(mockRunTransaction) - run = provideRun(mockContainer, mockEthersProvider, defaultChainIds, testRpcUrl, mockCache, mockCliArgs) + run = provideRun(mockContainer, mockEthersProvider, defaultChainIds, testRpcUrl, mockIsProduction, mockCache, mockCliArgs) await run() expect(mockContainer.resolve).toHaveBeenCalledTimes(1) @@ -65,7 +68,7 @@ describe("run", () => { const mockRunBlock = jest.fn() mockContainer.resolve.mockReturnValueOnce(mockRunBlock) - run = provideRun(mockContainer, mockEthersProvider, defaultChainIds, testRpcUrl, mockCache, mockCliArgs) + run = provideRun(mockContainer, mockEthersProvider, defaultChainIds, testRpcUrl, mockIsProduction, mockCache, mockCliArgs) await run() expect(mockContainer.resolve).toHaveBeenCalledTimes(1) @@ -82,7 +85,7 @@ describe("run", () => { const mockRunBlockRange = jest.fn() mockContainer.resolve.mockReturnValueOnce(mockRunBlockRange) - run = provideRun(mockContainer, mockEthersProvider, defaultChainIds, testRpcUrl, mockCache, mockCliArgs) + run = provideRun(mockContainer, mockEthersProvider, defaultChainIds, testRpcUrl, mockIsProduction, mockCache, mockCliArgs) await run() expect(mockContainer.resolve).toHaveBeenCalledTimes(1) @@ -99,7 +102,7 @@ describe("run", () => { const mockRunFile = jest.fn() mockContainer.resolve.mockReturnValueOnce(mockRunFile) - run = provideRun(mockContainer, mockEthersProvider, defaultChainIds, testRpcUrl, mockCache, mockCliArgs) + run = provideRun(mockContainer, mockEthersProvider, defaultChainIds, testRpcUrl, mockIsProduction, mockCache, mockCliArgs) await run() expect(mockContainer.resolve).toHaveBeenCalledTimes(1) @@ -116,7 +119,7 @@ describe("run", () => { const mockRunProdServer = jest.fn() mockContainer.resolve.mockReturnValueOnce(mockRunProdServer) - run = provideRun(mockContainer, mockEthersProvider, defaultChainIds, testRpcUrl, mockCache, mockCliArgs) + run = provideRun(mockContainer, mockEthersProvider, defaultChainIds, testRpcUrl, mockIsProduction, mockCache, mockCliArgs) await run() expect(mockContainer.resolve).toHaveBeenCalledTimes(1) @@ -133,7 +136,7 @@ describe("run", () => { const mockRunLive = jest.fn() mockContainer.resolve.mockReturnValueOnce(mockRunLive) - run = provideRun(mockContainer, mockEthersProvider, defaultChainIds, testRpcUrl, mockCache, mockCliArgs) + run = provideRun(mockContainer, mockEthersProvider, defaultChainIds, testRpcUrl, mockIsProduction, mockCache, mockCliArgs) await run() expect(mockContainer.resolve).toHaveBeenCalledTimes(1) diff --git a/cli/commands/run/index.ts b/cli/commands/run/index.ts index 5e1ddd6..7c551f8 100644 --- a/cli/commands/run/index.ts +++ b/cli/commands/run/index.ts @@ -15,6 +15,7 @@ export default function provideRun( ethersProvider: providers.JsonRpcProvider, chainIds: number[], jsonRpcUrl: string, + isProduction: boolean, cache: Cache, args: any ): CommandHandler { @@ -27,9 +28,11 @@ export default function provideRun( return async function run(runtimeArgs: any = {}) { args = { ...args, ...runtimeArgs } - const network = await ethersProvider.getNetwork(); - - if(!network || !chainIds.includes(network.chainId)) console.warn(`Warning: Detected chainId mismatch between ${jsonRpcUrl} [chainId: ${network.chainId}] and package.json [chainIds: ${chainIds}]. \n`) + // only check network id during local development + if (!isProduction) { + const network = await ethersProvider.getNetwork(); + if(!network || !chainIds.includes(network.chainId)) console.warn(`Warning: Detected chainId mismatch between ${jsonRpcUrl} [chainId: ${network.chainId}] and package.json [chainIds: ${chainIds}]. \n`) + } // we manually inject the run functions here (instead of through the provide function above) so that // we get RUNTIME errors if certain configuration is missing for that run function e.g. jsonRpcUrl From dd5a8ffb9c17da9ef78200ff65a67193d130dc12 Mon Sep 17 00:00:00 2001 From: Robert Leonard <40375385+Robert-H-Leonard@users.noreply.github.com> Date: Wed, 7 Sep 2022 12:06:20 -0700 Subject: [PATCH 2/2] Adding js/ts util method to verify JWT's generated by a scanner (#212) * Adding js/ts util method to verify JWT's generated by a scanner * Updating variable names and adding polygon rpc * Renaming to remove redundant name * Making sure to export correct method names * Adding param to input polygon url + checking exp of JWT * Adding verify method to python sdk * updating import Co-authored-by: Robert Leonard --- python-sdk/src/forta_agent/__init__.py | 2 +- python-sdk/src/forta_agent/utils.py | 68 ++++++++++++++++++++++- sdk/index.ts | 10 ++-- sdk/utils.ts | 77 +++++++++++++++++++++++++- 4 files changed, 146 insertions(+), 11 deletions(-) diff --git a/python-sdk/src/forta_agent/__init__.py b/python-sdk/src/forta_agent/__init__.py index 95a751b..b3b4efd 100644 --- a/python-sdk/src/forta_agent/__init__.py +++ b/python-sdk/src/forta_agent/__init__.py @@ -7,7 +7,7 @@ from .trace import Trace, TraceAction, TraceResult from .event_type import EventType from .network import Network -from .utils import get_json_rpc_url, create_block_event, create_transaction_event, get_web3_provider, keccak256, get_transaction_receipt, get_alerts, fetch_Jwt_token, decode_Jwt_token +from .utils import get_json_rpc_url, create_block_event, create_transaction_event, get_web3_provider, keccak256, get_transaction_receipt, get_alerts, fetch_jwt, decode_jwt, verify_jwt from web3 import Web3 web3Provider = Web3(Web3.HTTPProvider(get_json_rpc_url())) diff --git a/python-sdk/src/forta_agent/utils.py b/python-sdk/src/forta_agent/utils.py index b861f99..5af65e3 100644 --- a/python-sdk/src/forta_agent/utils.py +++ b/python-sdk/src/forta_agent/utils.py @@ -5,9 +5,16 @@ import sha3 import requests import datetime +import time +from web3.auto import w3 +from web3 import Web3 +import json +import logging from .forta_graphql import AlertsResponse +DISPTACHER_ABI = [{"inputs":[{"internalType":"uint256","name":"agentId","type":"uint256"},{"internalType":"uint256","name":"scannerId","type":"uint256"}],"name":"areTheyLinked","outputs":[{"internalType":"bool","name":"","type":"bool"}],"stateMutability":"view","type":"function"}] +DISPATCH_CONTRACT = "0xd46832F3f8EA8bDEFe5316696c0364F01b31a573"; # Source: https://docs.forta.network/en/latest/smart-contracts/ def get_web3_provider(): from . import web3Provider @@ -130,7 +137,7 @@ def keccak256(val): hash.update(bytes(val, encoding='utf-8')) return f'0x{hash.hexdigest()}' -def fetch_Jwt_token(claims, expiresAt=None) -> str: +def fetch_jwt(claims, expiresAt=None) -> str: host_name = 'forta-jwt-provider' port = 8515 path = '/create' @@ -162,7 +169,62 @@ def fetch_Jwt_token(claims, expiresAt=None) -> str: else: raise err +def verify_jwt(token: str, polygonUrl: str ='https://polygon-rpc.com') -> bool: + splitJwt = token.split('.') + rawHeader = splitJwt[0] + rawPayload = splitJwt[1] -def decode_Jwt_token(token): + header = json.loads(base64.urlsafe_b64decode(rawHeader + '==').decode('utf-8')) + payload = json.loads(base64.urlsafe_b64decode(rawPayload + '==').decode('utf-8')) + + alg = header['alg'] + botId = payload['bot-id'] + expiresAt = payload['exp'] + signerAddress = payload['sub'] + + if (signerAddress is None) or (botId is None): + logging.warning('Invalid claim') + return False + + if alg != 'ETH': + logging.warning('Unexpected signing method: {alg}'.format(alg=alg)) + return False + + currentUnixTime = time.mktime(datetime.datetime.utcnow().utctimetuple()) + + if expiresAt < currentUnixTime: + logging.warning('Jwt expired') + return False + + msg = '{header}.{payload}'.format(header=rawHeader, payload=rawPayload) + msgHash = w3.keccak(text=msg) + b64signature = splitJwt[2] + signature = base64.urlsafe_b64decode(f'{b64signature}=').hex() + recoveredSignerAddress = w3.eth.account.recoverHash(msgHash, signature=signature) + + if recoveredSignerAddress != signerAddress: + logging.warn('Signature invalid: expected={signerAddress}, got={recoveredSignerAddress}'.format(signerAddress=signerAddress, recoveredSignerAddress=recoveredSignerAddress)) + return False + + w3Client = Web3(Web3.HTTPProvider(polygonUrl)) + contract = w3Client.eth.contract(address=DISPATCH_CONTRACT,abi=DISPTACHER_ABI) + + areTheyLinked = contract.functions.areTheyLinked(int(botId,0), int(recoveredSignerAddress,0)).call() + + return areTheyLinked + +class DecodedJwt: + def __init__(self, dict): + self.header = dict.get('header') + self.payload = dict.get('payload') + + +def decode_jwt(token): # Adding need 4 byte for pythons b64decode - return base64.b64decode(token.split('.')[1] + '==').decode('utf-8') \ No newline at end of file + header = base64.urlsafe_b64decode(token.split('.')[0] + '==').decode('utf-8') + payload = base64.urlsafe_b64decode(token.split('.')[1] + '==').decode('utf-8') + + return DecodedJwt({ + "header": header, + "payload": payload + }) \ No newline at end of file diff --git a/sdk/index.ts b/sdk/index.ts index e25a5d0..57df328 100644 --- a/sdk/index.ts +++ b/sdk/index.ts @@ -17,8 +17,9 @@ import { isPrivateFindings, getTransactionReceipt, getAlerts, - fetchJwtToken, - decodeJwtToken + fetchJwt, + decodeJwt, + verifyJwt } from "./utils" import awilixConfigureContainer from '../cli/di.container'; @@ -103,6 +104,7 @@ export { configureContainer, getTransactionReceipt, getAlerts, - fetchJwtToken, - decodeJwtToken + fetchJwt, + decodeJwt, + verifyJwt } \ No newline at end of file diff --git a/sdk/utils.ts b/sdk/utils.ts index 0e02d36..7140398 100644 --- a/sdk/utils.ts +++ b/sdk/utils.ts @@ -10,6 +10,7 @@ import { Log, Receipt } from './receipt' import { TxEventBlock } from './transaction.event' import { Block } from './block' import { ethers } from '.' +import { toUtf8Bytes } from "@ethersproject/strings" import { AlertQueryOptions, AlertsResponse, FORTA_GRAPHQL_URL, getQueryFromAlertOptions, RawGraphqlAlertResponse } from './graphql/forta' import axios from 'axios' @@ -154,7 +155,7 @@ export const getAlerts = async (query: AlertQueryOptions): Promise => { +export const fetchJwt = async (claims: {}, expiresAt?: Date): Promise<{token: string} | null> => { const hostname = 'forta-jwt-provider' const port = 8515 const path = '/create' @@ -188,6 +189,76 @@ export const fetchJwtToken = async (claims: {}, expiresAt?: Date): Promise<{toke } } -export const decodeJwtToken = (token: string) => { - return JSON.parse(Buffer.from((token as string).split('.')[1], 'base64').toString()) +interface DecodedJwt { + header: any, + payload: any +} + +export const decodeJwt = (token: string): DecodedJwt => { + + const splitJwt = (token).split('.'); + const header = JSON.parse(Buffer.from(splitJwt[0], 'base64').toString()) + const payload = JSON.parse(Buffer.from(splitJwt[1], 'base64').toString()) + + return { + header, + payload + } +} + +const DISPTACHER_ARE_THEY_LINKED = "function areTheyLinked(uint256 agentId, uint256 scannerId) external view returns(bool)"; +const DISPATCH_CONTRACT = "0xd46832F3f8EA8bDEFe5316696c0364F01b31a573"; // Source: https://docs.forta.network/en/latest/smart-contracts/ + +export const verifyJwt = async (token: string, polygonRpcUrl: string = "https://polygon-rpc.com"): Promise => { + const splitJwt = (token).split('.') + const rawHeader = splitJwt[0] + const rawPayload = splitJwt[1] + + const header = JSON.parse(Buffer.from(rawHeader, 'base64').toString()) + const payload = JSON.parse(Buffer.from(rawPayload, 'base64').toString()) + + const botId = payload["bot-id"] as string + const expiresAt = payload["exp"] as number + const algorithm = header?.alg; + + if(algorithm !== "ETH") { + console.warn(`Unexpected signing method: ${algorithm}`) + return false + } + + if(!botId) { + console.warn(`Invalid claim`) + return false + } + + const signerAddress = payload?.sub as string | undefined // public key should be contract address that signed the JWT + + if(!signerAddress) { + console.warn(`Invalid claim`) + return false + } + + const currentUnixTime = Math.floor((Date.now() / 1000)) + + if(expiresAt < currentUnixTime) { + console.warn(`Jwt is expired`) + return false + } + + const digest = ethers.utils.keccak256(toUtf8Bytes(`${rawHeader}.${rawPayload}`)) + const signature = `0x${ Buffer.from(splitJwt[2], 'base64').toString('hex')}` + + const recoveredSignerAddress = ethers.utils.recoverAddress(digest, signature) // Contract address that signed message + + if(recoveredSignerAddress !== signerAddress) { + console.warn(`Signature invalid: expected=${signerAddress}, got=${recoveredSignerAddress}`) + return false + } + + const polygonProvider = new ethers.providers.JsonRpcProvider(polygonRpcUrl) + + const dispatchContract = new ethers.Contract(DISPATCH_CONTRACT, [DISPTACHER_ARE_THEY_LINKED], polygonProvider) + const areTheyLinked = await dispatchContract.areTheyLinked(botId, recoveredSignerAddress) + + return areTheyLinked } \ No newline at end of file