Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update Cloud ICE server fetch update and reset logic #756

Merged
merged 4 commits into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions hass_nabucasa/ice_servers.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ async def _async_fetch_ice_servers(self) -> list[RTCIceServer]:
if TYPE_CHECKING:
assert self.cloud.id_token is not None

if self.cloud.subscription_expired:
return []

async with self.cloud.websession.get(
f"https://{self.cloud.servicehandlers_server}/webrtc/ice_servers",
headers={
Expand Down Expand Up @@ -77,14 +80,21 @@ async def _async_refresh_ice_servers(self) -> None:
try:
self._ice_servers = await self._async_fetch_ice_servers()

if self._ice_servers_listener is not None:
await self._ice_servers_listener()
except ClientResponseError as err:
_LOGGER.error("Can't refresh ICE servers: %s", err)
_LOGGER.error("Can't refresh ICE servers: %s", err.message)

# We should not keep the existing ICE servers with old timestamps
# as that will retrigger a refresh almost immediately.
if err.status in (401, 403):
self._ice_servers = []

except asyncio.CancelledError:
# Task is canceled, stop it.
break

if self._ice_servers_listener is not None:
await self._ice_servers_listener()

sleep_time = self._get_refresh_sleep_time()
await asyncio.sleep(sleep_time)

Expand Down Expand Up @@ -116,9 +126,6 @@ async def perform_ice_server_update() -> None:
self._ice_servers_listener_unregister()
self._ice_servers_listener_unregister = None

if not self._ice_servers:
return

self._ice_servers_listener_unregister = await register_ice_server_fn(
self._ice_servers,
)
Expand Down
145 changes: 143 additions & 2 deletions tests/test_ice_servers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from webrtc_models import RTCIceServer

from hass_nabucasa import ice_servers
from tests.utils.aiohttp import AiohttpClientMocker


@pytest.fixture
Expand All @@ -17,8 +18,8 @@ def ice_servers_api(auth_cloud_mock) -> ice_servers.IceServers:
return ice_servers.IceServers(auth_cloud_mock)


@pytest.fixture(autouse=True)
def mock_ice_servers(aioclient_mock):
@pytest.fixture
def mock_ice_servers(aioclient_mock: AiohttpClientMocker):
"""Mock ICE servers."""
aioclient_mock.get(
"https://example.com/test/webrtc/ice_servers",
Expand All @@ -34,6 +35,7 @@ def mock_ice_servers(aioclient_mock):

async def test_ice_servers_listener_registration_triggers_periodic_ice_servers_update(
ice_servers_api: ice_servers.IceServers,
mock_ice_servers,
):
"""Test that registering an ICE servers listener triggers a periodic update."""
times_register_called_successfully = 0
Expand Down Expand Up @@ -80,6 +82,145 @@ def unregister():
assert ice_servers_api._ice_servers_listener_unregister is None


async def test_ice_server_refresh_sets_ice_server_list_empty_on_expired_subscription(
ice_servers_api: ice_servers.IceServers,
aioclient_mock: AiohttpClientMocker,
):
"""Test that the ICE server list is set to empty when the subscription expires."""
times_register_called_successfully = 0

ice_servers_api._get_refresh_sleep_time = lambda: 0

ice_servers_api.cloud.subscription_expired = True

async def register_ice_servers(ice_servers: list[RTCIceServer]):
nonlocal times_register_called_successfully

# This assert will silently fail and variable will not be incremented
assert len(ice_servers) == 0

times_register_called_successfully += 1

def unregister():
pass

return unregister

await ice_servers_api.async_register_ice_servers_listener(register_ice_servers)

# Let the periodic update run once
await asyncio.sleep(0)

assert ice_servers_api._ice_servers == []

assert len(aioclient_mock.mock_calls) == 0
assert times_register_called_successfully == 1
assert ice_servers_api._refresh_task is not None
assert ice_servers_api._ice_servers_listener is not None
assert ice_servers_api._ice_servers_listener_unregister is not None


async def test_ice_server_refresh_sets_ice_server_list_empty_on_401_403_client_error(
ice_servers_api: ice_servers.IceServers,
aioclient_mock: AiohttpClientMocker,
):
"""Test that ICE server list is empty when server returns 401 or 403 errors."""
aioclient_mock.get(
"https://example.com/test/webrtc/ice_servers",
status=403,
json={"message": "Boom!"},
)

times_register_called_successfully = 0

ice_servers_api._get_refresh_sleep_time = lambda: 0

ice_servers_api._ice_servers = [
RTCIceServer(
urls="turn:example.com:80",
username="12345678:test-user",
credential="secret-value",
),
]

async def register_ice_servers(ice_servers: list[RTCIceServer]):
nonlocal times_register_called_successfully

# This assert will silently fail and variable will not be incremented
assert len(ice_servers) == 0

times_register_called_successfully += 1

def unregister():
pass

return unregister

await ice_servers_api.async_register_ice_servers_listener(register_ice_servers)

# Let the periodic update run once
await asyncio.sleep(0)

assert ice_servers_api._ice_servers == []

assert times_register_called_successfully == 1
assert ice_servers_api._refresh_task is not None
assert ice_servers_api._ice_servers_listener is not None
assert ice_servers_api._ice_servers_listener_unregister is not None


async def test_ice_server_refresh_keeps_ice_server_list_on_other_client_errors(
ice_servers_api: ice_servers.IceServers,
aioclient_mock,
):
"""Test that ICE server list is not set to empty when server returns an error."""
aioclient_mock.get(
"https://example.com/test/webrtc/ice_servers",
status=500,
json={"message": "Boom!"},
)

times_register_called_successfully = 0

ice_servers_api._get_refresh_sleep_time = lambda: 0

ice_servers_api._ice_servers = [
RTCIceServer(
urls="turn:example.com:80",
username="12345678:test-user",
credential="secret-value",
),
]

async def register_ice_servers(ice_servers: list[RTCIceServer]):
nonlocal times_register_called_successfully

# These asserts will silently fail and variable will not be incremented
assert len(ice_servers) == 1
assert ice_servers[0].urls == "turn:example.com:80"
assert ice_servers[0].username == "12345678:test-user"
assert ice_servers[0].credential == "secret-value"

times_register_called_successfully += 1

def unregister():
pass

return unregister

await ice_servers_api.async_register_ice_servers_listener(register_ice_servers)

# Let the periodic update run once
await asyncio.sleep(0)

assert ice_servers_api._ice_servers != []

assert times_register_called_successfully == 1
assert ice_servers_api._refresh_task is not None
assert ice_servers_api._ice_servers_listener is not None
assert ice_servers_api._ice_servers_listener_unregister is not None


def test_get_refresh_sleep_time(ice_servers_api: ice_servers.IceServers):
"""Test get refresh sleep time."""
min_timestamp = 8888888888
Expand Down
Loading