Skip to content

Commit

Permalink
Add connection check feature to authentication process (#797)
Browse files Browse the repository at this point in the history
* Add connection check feature to authentication process

* Fix docstring

* Extract and make it more spesific

* add to base methods
  • Loading branch information
ludeeus authored Feb 13, 2025
1 parent bfea2fb commit cb14a2c
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 9 deletions.
43 changes: 39 additions & 4 deletions hass_nabucasa/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@
from .files import Files
from .google_report_state import GoogleReportState
from .ice_servers import IceServers
from .instance_api import InstanceApi
from .instance_api import (
InstanceApi,
InstanceConnectionDetails,
)
from .iot import CloudIoT
from .remote import RemoteUI
from .utils import UTC, gather_callbacks, parse_date, utcnow
Expand All @@ -41,6 +44,15 @@
_LOGGER = logging.getLogger(__name__)


class AlreadyConnectedError(CloudError):
"""Raised when a connection is already established."""

def __init__(self, *, details: InstanceConnectionDetails) -> None:
"""Initialize an already connected error."""
super().__init__("instance_already_connected")
self.details = details


class Cloud(Generic[_ClientT]):
"""Store the configuration of the cloud connection."""

Expand Down Expand Up @@ -164,6 +176,23 @@ def user_info_path(self) -> Path:
"""Get path to the stored auth."""
return self.path(f"{self.mode}_auth.json")

async def ensure_not_connected(
self,
*,
access_token: str,
) -> None:
"""Raise AlreadyConnectedError if already connected."""
try:
connection = await self.instance.connection(
skip_token_check=True,
access_token=access_token,
)
except CloudError:
return

if connection["connected"]:
raise AlreadyConnectedError(details=connection["details"])

async def update_token(
self,
id_token: str,
Expand Down Expand Up @@ -223,18 +252,24 @@ def run_executor(self, callback: Callable, *args: Any) -> asyncio.Future:
"""
return self.client.loop.run_in_executor(None, callback, *args)

async def login(self, email: str, password: str) -> None:
async def login(
self, email: str, password: str, *, check_connection: bool = False
) -> None:
"""Log a user in."""
await self.auth.async_login(email, password)
await self.auth.async_login(email, password, check_connection=check_connection)

async def login_verify_totp(
self,
email: str,
code: str,
mfa_tokens: dict[str, Any],
*,
check_connection: bool = False,
) -> None:
"""Verify TOTP code during login."""
await self.auth.async_login_verify_totp(email, code, mfa_tokens)
await self.auth.async_login_verify_totp(
email, code, mfa_tokens, check_connection=check_connection
)

async def logout(self) -> None:
"""Close connection and remove all credentials."""
Expand Down
4 changes: 3 additions & 1 deletion hass_nabucasa/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,12 @@ async def _call_cloud_api(
client_timeout: ClientTimeout | None = None,
jsondata: dict[str, Any] | None = None,
headers: dict[str, Any] | None = None,
skip_token_check: bool = False,
) -> Any:
"""Call cloud API."""
data: dict[str, Any] | list[Any] | str | None = None
await self._cloud.auth.async_check_token()
if not skip_token_check:
await self._cloud.auth.async_check_token()
if TYPE_CHECKING:
assert self._cloud.id_token is not None

Expand Down
20 changes: 19 additions & 1 deletion hass_nabucasa/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,13 @@ async def async_forgot_password(self, email: str) -> None:
except BotoCoreError as err:
raise UnknownError from err

async def async_login(self, email: str, password: str) -> None:
async def async_login(
self,
email: str,
password: str,
*,
check_connection: bool = False,
) -> None:
"""Log user in and fetch certificate."""
try:
async with self._request_lock:
Expand All @@ -199,6 +205,11 @@ async def async_login(self, email: str, password: str) -> None:
partial(cognito.authenticate, password=password),
)

if check_connection:
await self.cloud.ensure_not_connected(
access_token=cognito.access_token
)

task = await self.cloud.update_token(
cognito.id_token,
cognito.access_token,
Expand All @@ -225,6 +236,8 @@ async def async_login_verify_totp(
email: str,
code: str,
mfa_tokens: dict[str, Any],
*,
check_connection: bool = False,
) -> None:
"""Log user in and fetch certificate if MFA is required."""
try:
Expand All @@ -246,6 +259,11 @@ async def async_login_verify_totp(
),
)

if check_connection:
await self.cloud.ensure_not_connected(
access_token=cognito.access_token
)

task = await self.cloud.update_token(
cognito.id_token,
cognito.access_token,
Expand Down
10 changes: 8 additions & 2 deletions hass_nabucasa/instance_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,21 @@ def hostname(self) -> str:
assert self._cloud.servicehandlers_server is not None
return self._cloud.servicehandlers_server

async def connection(self) -> InstanceConnection:
async def connection(
self,
*,
access_token: str | None = None,
skip_token_check: bool = False,
) -> InstanceConnection:
"""Get the connection details."""
_LOGGER.debug("Getting instance connection details")
try:
details: InstanceConnection = await self._call_cloud_api(
path="/instance/connection",
headers={
hdrs.AUTHORIZATION: self._cloud.access_token,
hdrs.AUTHORIZATION: access_token or self._cloud.access_token,
},
skip_token_check=skip_token_check,
)
except CloudApiError as err:
raise InstanceApiError(err, orig_exc=err) from err
Expand Down
7 changes: 6 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,18 @@ def _executor(call, *args):
cloud.websession = aioclient_mock.create_session(loop)
cloud.client = MockClient(tmp_path, loop, cloud.websession)

async def update_token(id_token, access_token, refresh_token=None):
async def update_token(
id_token,
access_token,
refresh_token=None,
):
cloud.id_token = id_token
cloud.access_token = access_token
if refresh_token is not None:
cloud.refresh_token = refresh_token

cloud.update_token = MagicMock(side_effect=update_token)
cloud.ensure_not_connected = AsyncMock()

yield cloud

Expand Down
17 changes: 17 additions & 0 deletions tests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,23 @@ async def test_login(mock_cognito, mock_cloud):
)


async def test_login_with_check_connection(mock_cognito, mock_cloud):
"""Test login with connection check."""
auth = auth_api.CognitoAuth(mock_cloud)
mock_cognito.id_token = "test_id_token"
mock_cognito.access_token = "test_access_token"
mock_cognito.refresh_token = "test_refresh_token"

await auth.async_login("user", "pass", check_connection=True)

assert len(mock_cognito.authenticate.mock_calls) == 1
mock_cloud.update_token.assert_called_once_with(
"test_id_token",
"test_access_token",
"test_refresh_token",
)


async def test_register(mock_cognito, cloud_mock):
"""Test registering an account."""
auth = auth_api.CognitoAuth(cloud_mock)
Expand Down

0 comments on commit cb14a2c

Please sign in to comment.