Skip to content

Commit

Permalink
Merge pull request #244 from forta-network/mock-jwt-in-local-mode
Browse files Browse the repository at this point in the history
WIP: Return mock value when fetching jwt in local mode
  • Loading branch information
haseebrabbani authored Jan 18, 2023
2 parents cc8bd29 + 03554e5 commit c1740e1
Show file tree
Hide file tree
Showing 8 changed files with 201 additions and 111 deletions.
3 changes: 2 additions & 1 deletion python-sdk/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pytest==6.2.5
pytest-env==0.6.2
coverage==5.5
coverage==5.5
responses==0.17.0
3 changes: 2 additions & 1 deletion python-sdk/src/forta_agent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from .trace import Trace, TraceAction, TraceResult
from .event_type import EventType
from .network import Network
from .utils import get_json_rpc_url, create_block_event, create_transaction_event, create_alert_event, get_web3_provider, keccak256, get_transaction_receipt, get_alerts, fetch_jwt, decode_jwt, verify_jwt
from .utils import get_json_rpc_url, create_block_event, create_transaction_event, create_alert_event, get_web3_provider, keccak256, get_transaction_receipt, get_alerts
from .jwt import fetch_jwt, decode_jwt, verify_jwt, MOCK_JWT
from web3 import Web3

web3Provider = Web3(Web3.HTTPProvider(get_json_rpc_url()))
112 changes: 112 additions & 0 deletions python-sdk/src/forta_agent/jwt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import os
import requests
import datetime
import json
import base64
import logging
import time
from web3.auto import w3
from web3 import Web3

MOCK_JWT = 'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJib3QtaWQiOiIweDEzazM4N2IzNzc2OWNlMjQyMzZjNDAzZTc2ZmMzMGYwMWZhNzc0MTc2ZTE0MTZjODYxeWZlNmMwN2RmZWY3MWYiLCJleHAiOjE2NjAxMTk0NDMsImlhdCI6MTY2MDExOTQxMywianRpIjoicWtkNWNmYWQtMTg4NC0xMWVkLWE1YzktMDI0MjBhNjM5MzA4IiwibmJmIjoxNjYwMTE5MzgzLCJzdWIiOiIweDU1NmY4QkU0MmY3NmMwMUY5NjBmMzJDQjE5MzZEMmUwZTBFYjNGNEQifQ.9v5OiiYhDoEbhZ-abbiSXa5y-nQXa104YCN_2mK7SP0'


def fetch_jwt(claims={}, expiresAt=None) -> str:
if(os.environ['NODE_ENV'] != 'production'):
return MOCK_JWT

host_name = 'forta-jwt-provider'
port = 8515
path = '/create'
uri = 'http://{host_name}:{port}{path}'.format(
host_name=host_name, port=port, path=path)

if((expiresAt != None) and (isinstance(expiresAt, datetime.datetime) == False)):
raise Exception("expireAt must be of type datetime")

if expiresAt is not None:
exp_in_sec = expiresAt.timestamp()
claims["exp"] = exp_in_sec

response = requests.request("POST", uri, json={'claims': claims})

if response.status_code == 200:
data = response.json()
return data.get('token')
else:
raise Exception(
"Error occured with response fetching jwt token.", response)


def verify_jwt(token: str, polygonUrl: str = 'https://polygon-rpc.com') -> bool:
splitJwt = token.split('.')
rawHeader = splitJwt[0]
rawPayload = splitJwt[1]

header = json.loads(base64.urlsafe_b64decode(
rawHeader + '==').decode('utf-8'))
payload = json.loads(base64.urlsafe_b64decode(
rawPayload + '==').decode('utf-8'))

alg = header['alg']
botId = payload['bot-id']
expiresAt = payload['exp']
signerAddress = payload['sub']

if (signerAddress is None) or (botId is None):
logging.warning('Invalid claim')
return False

if alg != 'ETH':
logging.warning('Unexpected signing method: {alg}'.format(alg=alg))
return False

currentUnixTime = time.mktime(datetime.datetime.utcnow().utctimetuple())

if expiresAt < currentUnixTime:
logging.warning('Jwt expired')
return False

msg = '{header}.{payload}'.format(header=rawHeader, payload=rawPayload)
msgHash = w3.keccak(text=msg)
b64signature = splitJwt[2]
signature = base64.urlsafe_b64decode(f'{b64signature}=').hex()
recoveredSignerAddress = w3.eth.account.recoverHash(
msgHash, signature=signature)

if recoveredSignerAddress != signerAddress:
logging.warn('Signature invalid: expected={signerAddress}, got={recoveredSignerAddress}'.format(
signerAddress=signerAddress, recoveredSignerAddress=recoveredSignerAddress))
return False

w3Client = Web3(Web3.HTTPProvider(polygonUrl))
DISPTACHER_ABI = [{"inputs": [{"internalType": "uint256", "name": "agentId", "type": "uint256"}, {"internalType": "uint256", "name": "scannerId", "type": "uint256"}],
"name": "areTheyLinked", "outputs": [{"internalType": "bool", "name": "", "type": "bool"}], "stateMutability": "view", "type": "function"}]
# Source: https://docs.forta.network/en/latest/smart-contracts/
DISPATCH_CONTRACT = "0xd46832F3f8EA8bDEFe5316696c0364F01b31a573"
contract = w3Client.eth.contract(
address=DISPATCH_CONTRACT, abi=DISPTACHER_ABI)

areTheyLinked = contract.functions.areTheyLinked(
int(botId, 0), int(recoveredSignerAddress, 0)).call()

return areTheyLinked


class DecodedJwt:
def __init__(self, dict):
self.header = dict.get('header')
self.payload = dict.get('payload')


def decode_jwt(token):
# Adding need 4 byte for pythons b64decode
header = base64.urlsafe_b64decode(
token.split('.')[0] + '==').decode('utf-8')
payload = base64.urlsafe_b64decode(
token.split('.')[1] + '==').decode('utf-8')

return DecodedJwt({
"header": header,
"payload": payload
})
59 changes: 59 additions & 0 deletions python-sdk/src/forta_agent/jwt_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from .jwt import MOCK_JWT, fetch_jwt
import responses
import os
import pytest


@responses.activate
def test_return_mock_jwt():
# Register error response because we should not make a real call
rsp1 = responses.Response(
url="/",
method="POST",
json={"error": "not found"},
status=404,
)
responses.add(rsp1)

os.environ['NODE_ENV'] = 'dev'

token = fetch_jwt()

assert token == MOCK_JWT


@responses.activate
def test_return_valid_JWT():
testJWT = "testJWT"
# Register response because we should not make a real call
rsp1 = responses.Response(
url="http://forta-jwt-provider:8515/create",
method="POST",
json={"token": testJWT},
status=200,
)
responses.add(rsp1)

os.environ['NODE_ENV'] = 'production'

token = fetch_jwt()

assert token == testJWT


@responses.activate
def test_JWT_should_throw_exception():
with pytest.raises(Exception) as e_info:

# Register error response because we should not make a real call
rsp1 = responses.Response(
url="http://forta-jwt-provider:8515/create",
method="POST",
json={"message": 'Bad request'},
status=400,
)
responses.add(rsp1)

os.environ['NODE_ENV'] = 'production'

token = fetch_jwt()
102 changes: 2 additions & 100 deletions python-sdk/src/forta_agent/utils.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,11 @@
import base64
import sys
import os
from jsonc_parser.parser import JsoncParser
import requests
import datetime
import time
from web3.auto import w3
from web3 import Web3
import json
import logging

from .forta_graphql import AlertsResponse

DISPTACHER_ABI = [{"inputs":[{"internalType":"uint256","name":"agentId","type":"uint256"},{"internalType":"uint256","name":"scannerId","type":"uint256"}],"name":"areTheyLinked","outputs":[{"internalType":"bool","name":"","type":"bool"}],"stateMutability":"view","type":"function"}]
DISPATCH_CONTRACT = "0xd46832F3f8EA8bDEFe5316696c0364F01b31a573"; # Source: https://docs.forta.network/en/latest/smart-contracts/

def get_web3_provider():
from . import web3Provider
Expand Down Expand Up @@ -104,7 +96,8 @@ def get_alerts(dict):
query_options = AlertQueryOptions(dict)
payload = query_options.get_query()

response = requests.request("POST", forta_api, json=payload, headers=headers)
response = requests.request(
"POST", forta_api, json=payload, headers=headers)

if response.status_code == 200:
data = response.json().get('data')
Expand Down Expand Up @@ -139,94 +132,3 @@ def hex_to_int(strVal):

def keccak256(val):
return Web3.keccak(text=val).hex()

def fetch_jwt(claims, expiresAt=None) -> str:
host_name = 'forta-jwt-provider'
port = 8515
path = '/create'

uri = 'http://{host_name}:{port}{path}'.format(host_name=host_name, port=port, path=path)

if( (expiresAt != None) and (isinstance(expiresAt, datetime.datetime) == False)):
raise Exception("expireAt must be of type datetime")

if(expiresAt is not None):
exp_in_sec = expiresAt.timestamp()
claims["exp"] = exp_in_sec

try:
response = requests.request("POST", uri, json={'claims': claims})

if response.status_code == 200:
data = response.json()
return data.get('token')

else:
raise Exception("Error occured with response fetching jwt token.")

except requests.exceptions.RequestException as err:
if("Name does not resolve" in str(err)):
print("Could not resolve host 'forta-jwt-provider'. This url host can only be resolved inside of a running scan node")
raise err
else:
raise err

def verify_jwt(token: str, polygonUrl: str ='https://polygon-rpc.com') -> bool:
splitJwt = token.split('.')
rawHeader = splitJwt[0]
rawPayload = splitJwt[1]

header = json.loads(base64.urlsafe_b64decode(rawHeader + '==').decode('utf-8'))
payload = json.loads(base64.urlsafe_b64decode(rawPayload + '==').decode('utf-8'))

alg = header['alg']
botId = payload['bot-id']
expiresAt = payload['exp']
signerAddress = payload['sub']

if (signerAddress is None) or (botId is None):
logging.warning('Invalid claim')
return False

if alg != 'ETH':
logging.warning('Unexpected signing method: {alg}'.format(alg=alg))
return False

currentUnixTime = time.mktime(datetime.datetime.utcnow().utctimetuple())

if expiresAt < currentUnixTime:
logging.warning('Jwt expired')
return False

msg = '{header}.{payload}'.format(header=rawHeader, payload=rawPayload)
msgHash = w3.keccak(text=msg)
b64signature = splitJwt[2]
signature = base64.urlsafe_b64decode(f'{b64signature}=').hex()
recoveredSignerAddress = w3.eth.account.recoverHash(msgHash, signature=signature)

if recoveredSignerAddress != signerAddress:
logging.warn('Signature invalid: expected={signerAddress}, got={recoveredSignerAddress}'.format(signerAddress=signerAddress, recoveredSignerAddress=recoveredSignerAddress))
return False

w3Client = Web3(Web3.HTTPProvider(polygonUrl))
contract = w3Client.eth.contract(address=DISPATCH_CONTRACT,abi=DISPTACHER_ABI)

areTheyLinked = contract.functions.areTheyLinked(int(botId,0), int(recoveredSignerAddress,0)).call()

return areTheyLinked

class DecodedJwt:
def __init__(self, dict):
self.header = dict.get('header')
self.payload = dict.get('payload')


def decode_jwt(token):
# Adding need 4 byte for pythons b64decode
header = base64.urlsafe_b64decode(token.split('.')[0] + '==').decode('utf-8')
payload = base64.urlsafe_b64decode(token.split('.')[1] + '==').decode('utf-8')

return DecodedJwt({
"header": header,
"payload": payload
})
6 changes: 4 additions & 2 deletions sdk/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ import {
import {
fetchJwt,
decodeJwt,
verifyJwt
verifyJwt,
MOCK_JWT
} from "./jwt"
import awilixConfigureContainer from '../cli/di.container';
import {InitializeResponse} from "./initialize.response";
Expand Down Expand Up @@ -120,5 +121,6 @@ export {
getAlerts,
fetchJwt,
decodeJwt,
verifyJwt
verifyJwt,
MOCK_JWT
}
22 changes: 15 additions & 7 deletions sdk/jwt.spec.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
import { FetchJwt, provideFetchJwt } from "./jwt";
import { FetchJwt, provideFetchJwt, MOCK_JWT } from "./jwt";

describe("JWT methods", () => {
const mockJWT =
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c";

describe("fetchJWT", () => {
let fetchJwt: FetchJwt;
const mockAxios = {
Expand All @@ -20,18 +17,20 @@ describe("JWT methods", () => {

beforeEach(() => resetMocks());

it("should return a JWT string", async () => {
it("should return a JWT string from scan node when in production mode", async () => {
process.env.NODE_ENV = "production";
const jwt = "someJwt";
mockAxios.post.mockReturnValueOnce({
data: {
token: mockJWT,
token: jwt,
},
});
const claims = { some: "claim" };
const expiresAt = new Date();

const token = await fetchJwt(claims, expiresAt);

expect(token).toEqual(mockJWT);
expect(token).toEqual(jwt);
expect(mockAxios.post).toHaveBeenCalledTimes(1);
expect(mockAxios.post).toHaveBeenCalledWith(
`http://forta-jwt-provider:8515/create`,
Expand All @@ -43,5 +42,14 @@ describe("JWT methods", () => {
}
);
});

it("should return a mock JWT when not in production mode", async () => {
process.env.NODE_ENV = "development";

const token = await fetchJwt({});

expect(token).toEqual(MOCK_JWT);
expect(mockAxios.post).toHaveBeenCalledTimes(0);
});
});
});
Loading

0 comments on commit c1740e1

Please sign in to comment.