From 2d103ad69873793df94f568e4e8e49f812972613 Mon Sep 17 00:00:00 2001 From: Krisjanis Lejejs Date: Mon, 9 Dec 2024 10:53:54 +0200 Subject: [PATCH] Separate login and MFA verification flows --- hass_nabucasa/__init__.py | 13 ++--- hass_nabucasa/auth.py | 115 ++++++++++++++++++++++++++++++-------- hass_nabucasa/const.py | 2 + 3 files changed, 99 insertions(+), 31 deletions(-) diff --git a/hass_nabucasa/__init__.py b/hass_nabucasa/__init__.py index a2d190517..2f54ca5a7 100644 --- a/hass_nabucasa/__init__.py +++ b/hass_nabucasa/__init__.py @@ -220,14 +220,13 @@ def run_executor(self, callback: Callable, *args: Any) -> asyncio.Future: """ return self.client.loop.run_in_executor(None, callback, *args) - async def login( - self, - email: str, - password: str, - totp_code: str | None = None, - ) -> None: + async def login(self, email: str, password: str) -> None: """Log a user in.""" - await self.auth.async_login(email, password, totp_code) + await self.auth.async_login(email, password) + + async def login_verify_totp(self, code: str) -> None: + """Verify TOTP code during login.""" + await self.auth.async_login_verify_totp(code) async def logout(self) -> None: """Close connection and remove all credentials.""" diff --git a/hass_nabucasa/auth.py b/hass_nabucasa/auth.py index a2cb10ec7..14d670fdd 100644 --- a/hass_nabucasa/auth.py +++ b/hass_nabucasa/auth.py @@ -6,6 +6,7 @@ from functools import lru_cache, partial import logging import random +import time from typing import TYPE_CHECKING, Any import async_timeout @@ -15,7 +16,7 @@ import pycognito from pycognito.exceptions import ForceChangePasswordException, MFAChallengeException -from .const import MESSAGE_AUTH_FAIL +from .const import LOGIN_MFA_CHALLENGE_EXPIRATION, MESSAGE_AUTH_FAIL if TYPE_CHECKING: from . import Cloud, _ClientT @@ -31,6 +32,14 @@ class Unauthenticated(CloudError): """Raised when authentication failed.""" +class LoginNotStarted(CloudError): + """Raised when login has not been started but MFA was provided.""" + + +class LoginExpired(CloudError): + """Raised when MFA challenge for login has expired.""" + + class MFARequired(CloudError): """Raised when MFA is required.""" @@ -65,6 +74,44 @@ class UnknownError(CloudError): """Raised when an unknown error occurs.""" +class MFAChallenge: + """Store MFA challenge data.""" + + def __init__(self) -> None: + """Initialize MFA challenge.""" + self._email: str | None = None + self._mfa_tokens: dict[str, str] | None = None + self._timestamp: str | None = None + + def _clear_challenge(self) -> None: + """Clear MFA challenge data.""" + self._email = None + self._mfa_tokens = None + self._timestamp = None + + def store_challenge(self, email: str, mfa_tokens: dict[str, str]) -> None: + """Store new MFA challenge data.""" + self._email = email + self._mfa_tokens = mfa_tokens + self._timestamp = str(int(time.time())) + + def get_active_challenge_data(self) -> tuple[str, dict[str, str]]: + """Return active challenge data.""" + if self._email is None or self._mfa_tokens is None: + raise LoginNotStarted + + if int(time.time()) - int(self._timestamp) > LOGIN_MFA_CHALLENGE_EXPIRATION: + self._clear_challenge() + raise LoginExpired + + email = self._email + mfa_tokens = self._mfa_tokens + + self._clear_challenge() + + return email, mfa_tokens + + AWS_EXCEPTIONS: dict[str, type[CloudError]] = { "CodeMismatchException": InvalidTotpCode, "UserNotFoundException": UserNotFound, @@ -84,6 +131,7 @@ def __init__(self, cloud: Cloud[_ClientT]) -> None: self._refresh_task: asyncio.Task | None = None self._session: boto3.Session | None = None self._request_lock = asyncio.Lock() + self._mfa_challenge = MFAChallenge() cloud.iot.register_on_connect(self.on_connect) cloud.iot.register_on_disconnect(self.on_disconnect) @@ -172,12 +220,7 @@ async def async_forgot_password(self, email: str) -> None: except BotoCoreError as err: raise UnknownError from err - async def async_login( - self, - email: str, - password: str, - totp_code: str | None = None, - ) -> None: + async def async_login(self, email: str, password: str) -> None: """Log user in and fetch certificate.""" try: async with self._request_lock: @@ -187,23 +230,10 @@ async def async_login( partial(self._create_cognito_client, username=email), ) - try: - async with async_timeout.timeout(30): - await self.cloud.run_executor( - partial(cognito.authenticate, password=password), - ) - except MFAChallengeException as err: - if totp_code is None: - raise - - async with async_timeout.timeout(30): - await self.cloud.run_executor( - partial( - cognito.respond_to_software_token_mfa_challenge, - code=totp_code, - mfa_tokens=err.get_tokens(), - ), - ) + async with async_timeout.timeout(30): + await self.cloud.run_executor( + partial(cognito.authenticate, password=password), + ) task = await self.cloud.update_token( cognito.id_token, @@ -215,6 +245,7 @@ async def async_login( await task except MFAChallengeException as err: + self._mfa_challenge.store_challenge(email, err.get_tokens()) raise MFARequired from err except ForceChangePasswordException as err: @@ -226,6 +257,42 @@ async def async_login( except BotoCoreError as err: raise UnknownError from err + async def async_login_verify_totp(self, code: str) -> None: + """Log user in and fetch certificate if MFA is required.""" + try: + async with self._request_lock: + assert not self.cloud.is_logged_in, "Cannot login if already logged in." + + email, mfa_tokens = self._mfa_challenge.get_active_challenge_data() + + cognito: pycognito.Cognito = await self.cloud.run_executor( + partial(self._create_cognito_client, username=email), + ) + + async with async_timeout.timeout(30): + await self.cloud.run_executor( + partial( + cognito.respond_to_software_token_mfa_challenge, + code=code, + mfa_tokens=mfa_tokens, + ), + ) + + task = await self.cloud.update_token( + cognito.id_token, + cognito.access_token, + cognito.refresh_token, + ) + + if task: + await task + + except ClientError as err: + raise _map_aws_exception(err) from err + + except BotoCoreError as err: + raise UnknownError from err + async def async_check_token(self) -> None: """Check that the token is valid and renew if necessary.""" async with self._request_lock: diff --git a/hass_nabucasa/const.py b/hass_nabucasa/const.py index 573160f55..32f5c2e7d 100644 --- a/hass_nabucasa/const.py +++ b/hass_nabucasa/const.py @@ -6,6 +6,8 @@ REQUEST_TIMEOUT = 10 +LOGIN_MFA_CHALLENGE_EXPIRATION = 60 + MODE_PROD = "production" MODE_DEV = "development"