Skip to content

Commit

Permalink
Update Cloud ICE server fetch update and reset logic
Browse files Browse the repository at this point in the history
  • Loading branch information
klejejs committed Dec 3, 2024
1 parent 90ceb37 commit b33cc6d
Show file tree
Hide file tree
Showing 2 changed files with 156 additions and 6 deletions.
17 changes: 13 additions & 4 deletions hass_nabucasa/ice_servers.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ async def _async_fetch_ice_servers(self) -> list[RTCIceServer]:
if TYPE_CHECKING:
assert self.cloud.id_token is not None

if self.cloud.subscription_expired:
return []

async with self.cloud.websession.get(
f"https://{self.cloud.servicehandlers_server}/webrtc/ice_servers",
headers={
Expand Down Expand Up @@ -80,7 +83,16 @@ async def _async_refresh_ice_servers(self) -> None:
if self._ice_servers_listener is not None:
await self._ice_servers_listener()
except ClientResponseError as err:
_LOGGER.error("Can't refresh ICE servers: %s", err)
_LOGGER.error("Can't refresh ICE servers: %s", err.message)

# We should not keep the existing ICE servers with old timestamps
# as that will retrigger a refresh almost immediately.
if err.status in (401, 403):
self._ice_servers = []

# we should update the listener as the ICE servers might have changed
if self._ice_servers_listener is not None:
await self._ice_servers_listener()
except asyncio.CancelledError:
# Task is canceled, stop it.
break
Expand Down Expand Up @@ -116,9 +128,6 @@ async def perform_ice_server_update() -> None:
self._ice_servers_listener_unregister()
self._ice_servers_listener_unregister = None

if not self._ice_servers:
return

self._ice_servers_listener_unregister = await register_ice_server_fn(
self._ice_servers,
)
Expand Down
145 changes: 143 additions & 2 deletions tests/test_ice_servers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from webrtc_models import RTCIceServer

from hass_nabucasa import ice_servers
from tests.utils.aiohttp import AiohttpClientMocker


@pytest.fixture
Expand All @@ -17,8 +18,8 @@ def ice_servers_api(auth_cloud_mock) -> ice_servers.IceServers:
return ice_servers.IceServers(auth_cloud_mock)


@pytest.fixture(autouse=True)
def mock_ice_servers(aioclient_mock):
@pytest.fixture
def mock_ice_servers(aioclient_mock: AiohttpClientMocker):
"""Mock ICE servers."""
aioclient_mock.get(
"https://example.com/test/webrtc/ice_servers",
Expand All @@ -34,6 +35,7 @@ def mock_ice_servers(aioclient_mock):

async def test_ice_servers_listener_registration_triggers_periodic_ice_servers_update(
ice_servers_api: ice_servers.IceServers,
mock_ice_servers,
):
"""Test that registering an ICE servers listener triggers a periodic update."""
times_register_called_successfully = 0
Expand Down Expand Up @@ -80,6 +82,145 @@ def unregister():
assert ice_servers_api._ice_servers_listener_unregister is None


async def test_ice_server_refresh_sets_ice_server_list_empty_on_expired_subscription(
ice_servers_api: ice_servers.IceServers,
aioclient_mock: AiohttpClientMocker,
):
"""Test that the ICE server list is set to empty when the subscription expires."""
times_register_called_successfully = 0

ice_servers_api._get_refresh_sleep_time = lambda: 0

ice_servers_api.cloud.subscription_expired = True

async def register_ice_servers(ice_servers: list[RTCIceServer]):
nonlocal times_register_called_successfully

# This assert will silently fail and variable will not be incremented
assert len(ice_servers) == 0

times_register_called_successfully += 1

def unregister():
pass

return unregister

await ice_servers_api.async_register_ice_servers_listener(register_ice_servers)

# Let the periodic update run once
await asyncio.sleep(0)

assert ice_servers_api._ice_servers == []

assert len(aioclient_mock.mock_calls) == 0
assert times_register_called_successfully == 1
assert ice_servers_api._refresh_task is not None
assert ice_servers_api._ice_servers_listener is not None
assert ice_servers_api._ice_servers_listener_unregister is not None


async def test_ice_server_refresh_sets_ice_server_list_empty_on_401_403_client_error(
ice_servers_api: ice_servers.IceServers,
aioclient_mock: AiohttpClientMocker,
):
"""Test that ICE server list is empty when server returns 401 or 403 errors."""
aioclient_mock.get(
"https://example.com/test/webrtc/ice_servers",
status=403,
json={"message": "Boom!"},
)

times_register_called_successfully = 0

ice_servers_api._get_refresh_sleep_time = lambda: 0

ice_servers_api._ice_servers = [
RTCIceServer(
urls="turn:example.com:80",
username="12345678:test-user",
credential="secret-value",
),
]

async def register_ice_servers(ice_servers: list[RTCIceServer]):
nonlocal times_register_called_successfully

# This assert will silently fail and variable will not be incremented
assert len(ice_servers) == 0

times_register_called_successfully += 1

def unregister():
pass

return unregister

await ice_servers_api.async_register_ice_servers_listener(register_ice_servers)

# Let the periodic update run once
await asyncio.sleep(0)

assert ice_servers_api._ice_servers == []

assert times_register_called_successfully == 1
assert ice_servers_api._refresh_task is not None
assert ice_servers_api._ice_servers_listener is not None
assert ice_servers_api._ice_servers_listener_unregister is not None


async def test_ice_server_refresh_keeps_ice_server_list_on_other_client_errors(
ice_servers_api: ice_servers.IceServers,
aioclient_mock,
):
"""Test that ICE server list is not set to empty when server returns an error."""
aioclient_mock.get(
"https://example.com/test/webrtc/ice_servers",
status=500,
json={"message": "Boom!"},
)

times_register_called_successfully = 0

ice_servers_api._get_refresh_sleep_time = lambda: 0

ice_servers_api._ice_servers = [
RTCIceServer(
urls="turn:example.com:80",
username="12345678:test-user",
credential="secret-value",
),
]

async def register_ice_servers(ice_servers: list[RTCIceServer]):
nonlocal times_register_called_successfully

# These asserts will silently fail and variable will not be incremented
assert len(ice_servers) == 1
assert ice_servers[0].urls == "turn:example.com:80"
assert ice_servers[0].username == "12345678:test-user"
assert ice_servers[0].credential == "secret-value"

times_register_called_successfully += 1

def unregister():
pass

return unregister

await ice_servers_api.async_register_ice_servers_listener(register_ice_servers)

# Let the periodic update run once
await asyncio.sleep(0)

assert ice_servers_api._ice_servers != []

assert times_register_called_successfully == 1
assert ice_servers_api._refresh_task is not None
assert ice_servers_api._ice_servers_listener is not None
assert ice_servers_api._ice_servers_listener_unregister is not None


def test_get_refresh_sleep_time(ice_servers_api: ice_servers.IceServers):
"""Test get refresh sleep time."""
min_timestamp = 8888888888
Expand Down

0 comments on commit b33cc6d

Please sign in to comment.