Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix] Only update credentials on login #6437

Merged
merged 2 commits into from
May 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions openbb_platform/core/openbb_core/app/model/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,7 @@ def show(self):
[f"{k}: {v}" for k, v in sorted(self.model_dump(mode="json").items())]
)
)

def update(self, incoming: "Credentials"):
"""Update current credentials."""
self.__dict__.update(incoming.model_dump(exclude_none=True))
9 changes: 0 additions & 9 deletions openbb_platform/core/openbb_core/app/service/user_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,6 @@ def write_default_user_settings(
)
path.write_text(user_settings_json, encoding="utf-8")

@classmethod
def update_default(cls, user_settings: UserSettings) -> UserSettings:
"""Update default user settings."""
d1 = cls.read_default_user_settings().model_dump()
d2 = user_settings.model_dump() if user_settings else {}
updated = cls._merge_dicts([d1, d2])

return UserSettings.model_validate(updated)

@staticmethod
def _merge_dicts(list_of_dicts: List[Dict[str, Any]]) -> Dict[str, Any]:
"""Merge a list of dictionaries."""
Expand Down
12 changes: 5 additions & 7 deletions openbb_platform/core/openbb_core/app/static/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from openbb_core.app.static.app_factory import BaseApp


class Account:
class Account: # noqa: D205, D400
"""/account
login
logout
Expand Down Expand Up @@ -123,8 +123,8 @@ def login(
"""
self._hub_service = self._create_hub_service(email, password, pat)
incoming = self._hub_service.pull()
updated: UserSettings = UserService.update_default(incoming)
self._base_app._command_runner.user_settings = updated
self._base_app.user.profile = incoming.profile
self._base_app.user.credentials.update(incoming.credentials)
if remember_me:
Path(self._openbb_directory).mkdir(parents=False, exist_ok=True)
session_file = Path(self._openbb_directory, self.SESSION_FILE)
Expand Down Expand Up @@ -185,10 +185,8 @@ def refresh(self, return_settings: bool = False) -> Optional[UserSettings]:
)
else:
incoming = self._hub_service.pull()
updated: UserSettings = UserService.update_default(incoming)
updated.id = self._base_app._command_runner.user_settings.id
self._base_app._command_runner.user_settings = updated

self._base_app.user.profile = incoming.profile
self._base_app.user.credentials.update(incoming.credentials)
if return_settings:
return self._base_app._command_runner.user_settings
return None
Expand Down
17 changes: 7 additions & 10 deletions openbb_platform/core/openbb_core/app/static/coverage.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,13 @@
from openbb_core.app.static.app_factory import BaseApp


class Coverage:
"""Coverage class.

/coverage

providers
commands
command_model
command_schemas
reference
class Coverage: # noqa: D205, D400
"""/coverage
providers
commands
command_model
command_schemas
reference
"""

def __init__(self, app: "BaseApp"):
Expand Down
15 changes: 0 additions & 15 deletions openbb_platform/core/tests/app/service/test_user_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import tempfile
from pathlib import Path

from openbb_core.app.model.defaults import Defaults
from openbb_core.app.service.user_service import (
UserService,
UserSettings,
Expand Down Expand Up @@ -47,20 +46,6 @@ def test_write_default_user_settings():
temp_path.unlink()


def test_update_default():
"""Test update default user settings."""

# Some settings
defaults_test = Defaults(routes={"test": {"test": "test"}})
other_settings = UserSettings(defaults=defaults_test)

# Update the default settings
updated_settings = UserService.update_default(other_settings)

assert "test" in updated_settings.defaults.model_dump()["routes"]
assert updated_settings.defaults.model_dump()["routes"]["test"] == {"test": "test"}


def test_merge_dicts():
"""Test merge dicts."""
result = UserService._merge_dicts( # pylint: disable=protected-access
Expand Down
Loading