Skip to content

Commit

Permalink
[BugFix] - Replace python-jose by PyJWT (#6407)
Browse files Browse the repository at this point in the history
* fix: replace python-jose by PyJWT

* add unit test

* add unit test

---------

Co-authored-by: Theodore Aptekarev <aptekarev@gmail.com>
  • Loading branch information
montezdesousa and piiq committed May 14, 2024
1 parent 88cdd75 commit 0769378
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 73 deletions.
10 changes: 4 additions & 6 deletions openbb_platform/core/openbb_core/app/service/hub_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
from warnings import warn

from fastapi import HTTPException
from jose import JWTError
from jose.exceptions import ExpiredSignatureError
from jose.jwt import decode, get_unverified_header
from jwt import ExpiredSignatureError, PyJWTError, decode, get_unverified_header
from openbb_core.app.model.abstract.error import OpenBBError
from openbb_core.app.model.credentials import Credentials
from openbb_core.app.model.hub.hub_session import HubSession
Expand Down Expand Up @@ -139,7 +137,7 @@ def _get_session_from_platform_token(self, token: str) -> HubSession:
if not token:
raise OpenBBError("Platform personal access token not found.")

self.check_token_expiration(token)
self._check_token_expiration(token)

response = post(
url=self._base_url + "/sdk/login",
Expand Down Expand Up @@ -259,7 +257,7 @@ def platform2hub(self, credentials: Credentials) -> HubUserSettings:
return settings

@staticmethod
def check_token_expiration(token: str) -> None:
def _check_token_expiration(token: str) -> None:
"""Check token expiration, raises exception if expired."""
try:
header_data = get_unverified_header(token)
Expand All @@ -271,5 +269,5 @@ def check_token_expiration(token: str) -> None:
)
except ExpiredSignatureError as e:
raise OpenBBError("Platform personal access token expired.") from e
except JWTError as e:
except PyJWTError as e:
raise OpenBBError("Failed to decode Platform token.") from e
84 changes: 18 additions & 66 deletions openbb_platform/core/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion openbb_platform/core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ websockets = "^12.0"
pandas = ">=1.5.3"
html5lib = "^1.1"
fastapi = "^0.104.1"
python-jose = "^3.3.0"
uuid7 = "^0.1.0"
posthog = "^3.3.1"
python-multipart = "^0.0.7"
Expand All @@ -23,6 +22,7 @@ importlib-metadata = "^6.8.0"
python-dotenv = "^1.0.0"
aiohttp = "^3.9.5"
ruff = "^0.1.6"
pyjwt = "^2.8.0"

[tool.poetry.group.dev.dependencies]
pytest = "^7.0.0"
Expand Down
34 changes: 34 additions & 0 deletions openbb_platform/core/tests/app/service/test_hub_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@


from pathlib import Path
from time import time
from unittest.mock import MagicMock, patch

import pytest
from jwt import encode
from openbb_core.app.service.hub_service import (
Credentials,
HubService,
Expand Down Expand Up @@ -308,3 +310,35 @@ def test_platform2hub():
assert user_settings.features_keys["API_FRED_KEY"] == "fred"
assert user_settings.features_keys["benzinga_api_key"] == "benzinga"
assert "some_api_key" not in user_settings.features_keys


@pytest.mark.parametrize(
"token, message",
[
# valid
(
encode(
{"some": "payload", "exp": int(time()) + 100},
"secret",
algorithm="HS256",
),
None,
),
# expired
(
encode(
{"some": "payload", "exp": int(time())}, "secret", algorithm="HS256"
),
"Platform personal access token expired.",
),
# invalid
("invalid_token", "Failed to decode Platform token."),
],
)
def test__check_token_expiration(token, message):
"""Test check token expiration function."""
if message:
with pytest.raises(OpenBBError, match=message):
HubService._check_token_expiration(token)
else:
HubService._check_token_expiration(token)

0 comments on commit 0769378

Please sign in to comment.