Skip to content

Commit

Permalink
Separate login and MFA verification flows
Browse files Browse the repository at this point in the history
  • Loading branch information
klejejs committed Dec 9, 2024
1 parent 2a64387 commit 2d103ad
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 31 deletions.
13 changes: 6 additions & 7 deletions hass_nabucasa/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,14 +220,13 @@ def run_executor(self, callback: Callable, *args: Any) -> asyncio.Future:
"""
return self.client.loop.run_in_executor(None, callback, *args)

async def login(
self,
email: str,
password: str,
totp_code: str | None = None,
) -> None:
async def login(self, email: str, password: str) -> None:
"""Log a user in."""
await self.auth.async_login(email, password, totp_code)
await self.auth.async_login(email, password)

async def login_verify_totp(self, code: str) -> None:
"""Verify TOTP code during login."""
await self.auth.async_login_verify_totp(code)

async def logout(self) -> None:
"""Close connection and remove all credentials."""
Expand Down
115 changes: 91 additions & 24 deletions hass_nabucasa/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from functools import lru_cache, partial
import logging
import random
import time
from typing import TYPE_CHECKING, Any

import async_timeout
Expand All @@ -15,7 +16,7 @@
import pycognito
from pycognito.exceptions import ForceChangePasswordException, MFAChallengeException

from .const import MESSAGE_AUTH_FAIL
from .const import LOGIN_MFA_CHALLENGE_EXPIRATION, MESSAGE_AUTH_FAIL

if TYPE_CHECKING:
from . import Cloud, _ClientT
Expand All @@ -31,6 +32,14 @@ class Unauthenticated(CloudError):
"""Raised when authentication failed."""


class LoginNotStarted(CloudError):
"""Raised when login has not been started but MFA was provided."""


class LoginExpired(CloudError):
"""Raised when MFA challenge for login has expired."""


class MFARequired(CloudError):
"""Raised when MFA is required."""

Expand Down Expand Up @@ -65,6 +74,44 @@ class UnknownError(CloudError):
"""Raised when an unknown error occurs."""


class MFAChallenge:
"""Store MFA challenge data."""

def __init__(self) -> None:
"""Initialize MFA challenge."""
self._email: str | None = None
self._mfa_tokens: dict[str, str] | None = None
self._timestamp: str | None = None

def _clear_challenge(self) -> None:
"""Clear MFA challenge data."""
self._email = None
self._mfa_tokens = None
self._timestamp = None

def store_challenge(self, email: str, mfa_tokens: dict[str, str]) -> None:
"""Store new MFA challenge data."""
self._email = email
self._mfa_tokens = mfa_tokens
self._timestamp = str(int(time.time()))

def get_active_challenge_data(self) -> tuple[str, dict[str, str]]:
"""Return active challenge data."""
if self._email is None or self._mfa_tokens is None:
raise LoginNotStarted

if int(time.time()) - int(self._timestamp) > LOGIN_MFA_CHALLENGE_EXPIRATION:
self._clear_challenge()
raise LoginExpired

email = self._email
mfa_tokens = self._mfa_tokens

self._clear_challenge()

return email, mfa_tokens


AWS_EXCEPTIONS: dict[str, type[CloudError]] = {
"CodeMismatchException": InvalidTotpCode,
"UserNotFoundException": UserNotFound,
Expand All @@ -84,6 +131,7 @@ def __init__(self, cloud: Cloud[_ClientT]) -> None:
self._refresh_task: asyncio.Task | None = None
self._session: boto3.Session | None = None
self._request_lock = asyncio.Lock()
self._mfa_challenge = MFAChallenge()

cloud.iot.register_on_connect(self.on_connect)
cloud.iot.register_on_disconnect(self.on_disconnect)
Expand Down Expand Up @@ -172,12 +220,7 @@ async def async_forgot_password(self, email: str) -> None:
except BotoCoreError as err:
raise UnknownError from err

async def async_login(
self,
email: str,
password: str,
totp_code: str | None = None,
) -> None:
async def async_login(self, email: str, password: str) -> None:
"""Log user in and fetch certificate."""
try:
async with self._request_lock:
Expand All @@ -187,23 +230,10 @@ async def async_login(
partial(self._create_cognito_client, username=email),
)

try:
async with async_timeout.timeout(30):
await self.cloud.run_executor(
partial(cognito.authenticate, password=password),
)
except MFAChallengeException as err:
if totp_code is None:
raise

async with async_timeout.timeout(30):
await self.cloud.run_executor(
partial(
cognito.respond_to_software_token_mfa_challenge,
code=totp_code,
mfa_tokens=err.get_tokens(),
),
)
async with async_timeout.timeout(30):
await self.cloud.run_executor(
partial(cognito.authenticate, password=password),
)

task = await self.cloud.update_token(
cognito.id_token,
Expand All @@ -215,6 +245,7 @@ async def async_login(
await task

except MFAChallengeException as err:
self._mfa_challenge.store_challenge(email, err.get_tokens())
raise MFARequired from err

except ForceChangePasswordException as err:
Expand All @@ -226,6 +257,42 @@ async def async_login(
except BotoCoreError as err:
raise UnknownError from err

async def async_login_verify_totp(self, code: str) -> None:
"""Log user in and fetch certificate if MFA is required."""
try:
async with self._request_lock:
assert not self.cloud.is_logged_in, "Cannot login if already logged in."

email, mfa_tokens = self._mfa_challenge.get_active_challenge_data()

cognito: pycognito.Cognito = await self.cloud.run_executor(
partial(self._create_cognito_client, username=email),
)

async with async_timeout.timeout(30):
await self.cloud.run_executor(
partial(
cognito.respond_to_software_token_mfa_challenge,
code=code,
mfa_tokens=mfa_tokens,
),
)

task = await self.cloud.update_token(
cognito.id_token,
cognito.access_token,
cognito.refresh_token,
)

if task:
await task

except ClientError as err:
raise _map_aws_exception(err) from err

except BotoCoreError as err:
raise UnknownError from err

async def async_check_token(self) -> None:
"""Check that the token is valid and renew if necessary."""
async with self._request_lock:
Expand Down
2 changes: 2 additions & 0 deletions hass_nabucasa/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

REQUEST_TIMEOUT = 10

LOGIN_MFA_CHALLENGE_EXPIRATION = 60

MODE_PROD = "production"
MODE_DEV = "development"

Expand Down

0 comments on commit 2d103ad

Please sign in to comment.