-
Notifications
You must be signed in to change notification settings - Fork 64
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add function for fetching ICE servers from service handlers (#717)
* Add function for fetching ICE servers from service handlers * Refactor code for fetching webrtc servers * Fix types issues * Remove old code * Migrate to support only one listener * PR suggestion fixes * Minor code improvements, add tests * Add webrtc-models dependency * Add minimum refresh time constraint, improve tests * Improve sleep time timestamps check * Add support for listener that supports list of ICE servers * Move listener unregister clearance in condition * Fix test None check * Improve tests based on PR reviews * Update webrtc-models to a stable release * Code flow improvements * Change webrtc models dependency version requirements * Changed retry to be between 1 and 12 hours
- Loading branch information
Showing
4 changed files
with
274 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,143 @@ | ||
"""Manage ICE servers.""" | ||
|
||
from __future__ import annotations | ||
|
||
import asyncio | ||
from collections.abc import Awaitable, Callable | ||
import logging | ||
import random | ||
import time | ||
from typing import TYPE_CHECKING | ||
|
||
from aiohttp import ClientResponseError | ||
from aiohttp.hdrs import AUTHORIZATION, USER_AGENT | ||
from webrtc_models import RTCIceServer | ||
|
||
if TYPE_CHECKING: | ||
from . import Cloud, _ClientT | ||
|
||
|
||
_LOGGER = logging.getLogger(__name__) | ||
|
||
|
||
class IceServers: | ||
"""Class to manage ICE servers.""" | ||
|
||
def __init__(self, cloud: Cloud[_ClientT]) -> None: | ||
"""Initialize ICE Servers.""" | ||
self.cloud = cloud | ||
self._refresh_task: asyncio.Task | None = None | ||
self._ice_servers: list[RTCIceServer] = [] | ||
self._ice_servers_listener: Callable[[], Awaitable[None]] | None = None | ||
self._ice_servers_listener_unregister: Callable[[], None] | None = None | ||
|
||
async def _async_fetch_ice_servers(self) -> list[RTCIceServer]: | ||
"""Fetch ICE servers.""" | ||
if TYPE_CHECKING: | ||
assert self.cloud.id_token is not None | ||
|
||
async with self.cloud.websession.get( | ||
f"https://{self.cloud.servicehandlers_server}/webrtc/ice_servers", | ||
headers={ | ||
AUTHORIZATION: self.cloud.id_token, | ||
USER_AGENT: self.cloud.client.client_name, | ||
}, | ||
) as resp: | ||
resp.raise_for_status() | ||
|
||
return [ | ||
RTCIceServer( | ||
urls=item["urls"], | ||
username=item["username"], | ||
credential=item["credential"], | ||
) | ||
for item in await resp.json() | ||
] | ||
|
||
def _get_refresh_sleep_time(self) -> int: | ||
"""Get the sleep time for refreshing ICE servers.""" | ||
timestamps = [ | ||
int(server.username.split(":")[0]) | ||
for server in self._ice_servers | ||
if server.username is not None and ":" in server.username | ||
] | ||
|
||
if not timestamps: | ||
return random.randint(3600, 3600 * 12) # 1-12 hours | ||
|
||
if (expiration := min(timestamps) - int(time.time()) - 3600) < 0: | ||
return random.randint(100, 300) | ||
|
||
# 1 hour before the earliest expiration | ||
return expiration | ||
|
||
async def _async_refresh_ice_servers(self) -> None: | ||
"""Handle ICE server refresh.""" | ||
while True: | ||
try: | ||
self._ice_servers = await self._async_fetch_ice_servers() | ||
|
||
if self._ice_servers_listener is not None: | ||
await self._ice_servers_listener() | ||
except ClientResponseError as err: | ||
_LOGGER.error("Can't refresh ICE servers: %s", err) | ||
except asyncio.CancelledError: | ||
# Task is canceled, stop it. | ||
break | ||
|
||
sleep_time = self._get_refresh_sleep_time() | ||
await asyncio.sleep(sleep_time) | ||
|
||
def _on_add_listener(self) -> None: | ||
"""When the instance is connected.""" | ||
self._refresh_task = asyncio.create_task(self._async_refresh_ice_servers()) | ||
|
||
def _on_remove_listener(self) -> None: | ||
"""When the instance is disconnected.""" | ||
if self._refresh_task is not None: | ||
self._refresh_task.cancel() | ||
self._refresh_task = None | ||
|
||
async def async_register_ice_servers_listener( | ||
self, | ||
register_ice_server_fn: Callable[ | ||
[list[RTCIceServer]], | ||
Awaitable[Callable[[], None]], | ||
], | ||
) -> Callable[[], None]: | ||
"""Register a listener for ICE servers and return unregister function.""" | ||
_LOGGER.debug("Registering ICE servers listener") | ||
|
||
async def perform_ice_server_update() -> None: | ||
"""Perform ICE server update by unregistering and registering servers.""" | ||
_LOGGER.debug("Updating ICE servers") | ||
|
||
if self._ice_servers_listener_unregister is not None: | ||
self._ice_servers_listener_unregister() | ||
self._ice_servers_listener_unregister = None | ||
|
||
if not self._ice_servers: | ||
return | ||
|
||
self._ice_servers_listener_unregister = await register_ice_server_fn( | ||
self._ice_servers, | ||
) | ||
|
||
_LOGGER.debug("ICE servers updated") | ||
|
||
def remove_listener() -> None: | ||
"""Remove listener.""" | ||
if self._ice_servers_listener_unregister is not None: | ||
self._ice_servers_listener_unregister() | ||
self._ice_servers_listener_unregister = None | ||
|
||
self._ice_servers = [] | ||
self._ice_servers_listener = None | ||
|
||
self._on_remove_listener() | ||
|
||
self._ice_servers_listener = perform_ice_server_update | ||
|
||
self._on_add_listener() | ||
|
||
return remove_listener |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
"""Test the ICE servers module.""" | ||
|
||
import asyncio | ||
import time | ||
|
||
import pytest | ||
from webrtc_models import RTCIceServer | ||
|
||
from hass_nabucasa import ice_servers | ||
|
||
|
||
@pytest.fixture | ||
def ice_servers_api(auth_cloud_mock) -> ice_servers.IceServers: | ||
"""ICE servers API fixture.""" | ||
auth_cloud_mock.servicehandlers_server = "example.com/test" | ||
auth_cloud_mock.id_token = "mock-id-token" | ||
return ice_servers.IceServers(auth_cloud_mock) | ||
|
||
|
||
@pytest.fixture(autouse=True) | ||
def mock_ice_servers(aioclient_mock): | ||
"""Mock ICE servers.""" | ||
aioclient_mock.get( | ||
"https://example.com/test/webrtc/ice_servers", | ||
json=[ | ||
{ | ||
"urls": "turn:example.com:80", | ||
"username": "12345678:test-user", | ||
"credential": "secret-value", | ||
}, | ||
], | ||
) | ||
|
||
|
||
async def test_ice_servers_listener_registration_triggers_periodic_ice_servers_update( | ||
ice_servers_api: ice_servers.IceServers, | ||
): | ||
"""Test that registering an ICE servers listener triggers a periodic update.""" | ||
times_register_called_successfully = 0 | ||
|
||
ice_servers_api._get_refresh_sleep_time = lambda: 0 | ||
|
||
async def register_ice_servers(ice_servers: list[RTCIceServer]): | ||
nonlocal times_register_called_successfully | ||
|
||
# These asserts will silently fail and variable will not be incremented | ||
assert len(ice_servers) == 1 | ||
assert ice_servers[0].urls == "turn:example.com:80" | ||
assert ice_servers[0].username == "12345678:test-user" | ||
assert ice_servers[0].credential == "secret-value" | ||
|
||
times_register_called_successfully += 1 | ||
|
||
def unregister(): | ||
pass | ||
|
||
return unregister | ||
|
||
unregister = await ice_servers_api.async_register_ice_servers_listener( | ||
register_ice_servers, | ||
) | ||
|
||
# Let the periodic update run once | ||
await asyncio.sleep(0) | ||
# Let the periodic update run again | ||
await asyncio.sleep(0) | ||
|
||
assert times_register_called_successfully == 2 | ||
|
||
unregister() | ||
|
||
# The periodic update should not run again | ||
await asyncio.sleep(0) | ||
|
||
assert times_register_called_successfully == 2 | ||
|
||
assert ice_servers_api._refresh_task is None | ||
assert ice_servers_api._ice_servers == [] | ||
assert ice_servers_api._ice_servers_listener is None | ||
assert ice_servers_api._ice_servers_listener_unregister is None | ||
|
||
|
||
def test_get_refresh_sleep_time(ice_servers_api: ice_servers.IceServers): | ||
"""Test get refresh sleep time.""" | ||
min_timestamp = 8888888888 | ||
|
||
ice_servers_api._ice_servers = [ | ||
RTCIceServer(urls="turn:example.com:80", username="9999999999:test-user"), | ||
RTCIceServer( | ||
urls="turn:example.com:80", | ||
username=f"{min_timestamp!s}:test-user", | ||
), | ||
] | ||
|
||
assert ( | ||
ice_servers_api._get_refresh_sleep_time() | ||
== min_timestamp - int(time.time()) - 3600 | ||
) | ||
|
||
|
||
def test_get_refresh_sleep_time_no_turn_servers( | ||
ice_servers_api: ice_servers.IceServers, | ||
): | ||
"""Test get refresh sleep time.""" | ||
refresh_time = ice_servers_api._get_refresh_sleep_time() | ||
|
||
assert refresh_time >= 3600 | ||
assert refresh_time <= 43200 | ||
|
||
|
||
def test_get_refresh_sleep_time_expiration_less_than_one_hour( | ||
ice_servers_api: ice_servers.IceServers, | ||
): | ||
"""Test get refresh sleep time.""" | ||
min_timestamp = 10 | ||
|
||
ice_servers_api._ice_servers = [ | ||
RTCIceServer(urls="turn:example.com:80", username="12345678:test-user"), | ||
RTCIceServer( | ||
urls="turn:example.com:80", | ||
username=f"{min_timestamp!s}:test-user", | ||
), | ||
] | ||
|
||
refresh_time = ice_servers_api._get_refresh_sleep_time() | ||
|
||
assert refresh_time >= 100 | ||
assert refresh_time <= 300 |