From 23b9456145f5e83444d9fd21a42e0f28f1a905e7 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Sat, 20 Jun 2020 21:49:27 -0700 Subject: [PATCH] Don't keep fetching snitun token when sub expired --- hass_nabucasa/remote.py | 14 ++++++++++++-- tests/conftest.py | 1 + tests/test_remote.py | 10 +++++++++- 3 files changed, 22 insertions(+), 3 deletions(-) diff --git a/hass_nabucasa/remote.py b/hass_nabucasa/remote.py index e94e51722..51b3fc7eb 100644 --- a/hass_nabucasa/remote.py +++ b/hass_nabucasa/remote.py @@ -33,6 +33,10 @@ class RemoteNotConnected(RemoteError): """Raise if a request need connection and we are not ready.""" +class SubscriptionExpired(RemoteError): + """Raise if we cannot connect because subscription expired.""" + + @attr.s class SniTunToken: """Handle snitun token.""" @@ -241,13 +245,17 @@ async def _refresh_snitun_token(self) -> None: _LOGGER.debug("Don't need refresh snitun token") return + if self.cloud.subscription_expired: + raise SubscriptionExpired() + # Generate session token aes_key, aes_iv = generate_aes_keyset() try: async with async_timeout.timeout(30): resp = await cloud_api.async_remote_token(self.cloud, aes_key, aes_iv) - assert resp.status == 200 - except (asyncio.TimeoutError, AssertionError): + if resp.status != 200: + raise RemoteBackendError() + except asyncio.TimeoutError: raise RemoteBackendError() from None data = await resp.json() @@ -283,6 +291,8 @@ async def connect(self) -> None: _LOGGER.error("Connection problem to snitun server") except RemoteBackendError: _LOGGER.error("Can't refresh the snitun token") + except SubscriptionExpired: + pass except AttributeError: pass # Ignore because HA shutdown on snitun token refresh finally: diff --git a/tests/conftest.py b/tests/conftest.py index 49a782f21..5e81cfd83 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -43,6 +43,7 @@ def _executor(call, *args): def auth_cloud_mock(cloud_mock): """Return an authenticated cloud instance.""" cloud_mock.auth.async_check_token.side_effect = mock_coro + cloud_mock.subscription_expired = False return cloud_mock diff --git a/tests/test_remote.py b/tests/test_remote.py index 26ef98172..20af67684 100644 --- a/tests/test_remote.py +++ b/tests/test_remote.py @@ -11,7 +11,7 @@ DISPATCH_REMOTE_CONNECT, DISPATCH_REMOTE_DISCONNECT, ) -from hass_nabucasa.remote import RemoteUI +from hass_nabucasa.remote import RemoteUI, SubscriptionExpired from hass_nabucasa.utils import utcnow from .common import MockAcme, MockSnitun, mock_coro @@ -468,3 +468,11 @@ async def test_certificate_task_renew_cert( await remote.load_backend() await asyncio.sleep(0.1) assert acme_mock.call_issue + + +async def test_refresh_token_no_sub(auth_cloud_mock): + """Test that we rais SubscriptionExpired if expired sub.""" + auth_cloud_mock.subscription_expired = True + + with pytest.raises(SubscriptionExpired): + await RemoteUI(auth_cloud_mock)._refresh_snitun_token()