Skip to content

Commit

Permalink
Add function for fetching ICE servers from service handlers (#717)
Browse files Browse the repository at this point in the history
* Add function for fetching ICE servers from service handlers

* Refactor code for fetching webrtc servers

* Fix types issues

* Remove old code

* Migrate to support only one listener

* PR suggestion fixes

* Minor code improvements, add tests

* Add webrtc-models dependency

* Add minimum refresh time constraint, improve tests

* Improve sleep time timestamps check

* Add support for listener that supports list of ICE servers

* Move listener unregister clearance in condition

* Fix test None check

* Improve tests based on PR reviews

* Update webrtc-models to a stable release

* Code flow improvements

* Change webrtc models dependency version requirements

* Changed retry to be between 1 and 12 hours
  • Loading branch information
klejejs authored Oct 29, 2024
1 parent 0a9095d commit 29d64c6
Show file tree
Hide file tree
Showing 4 changed files with 274 additions and 0 deletions.
2 changes: 2 additions & 0 deletions hass_nabucasa/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
STATE_CONNECTED,
)
from .google_report_state import GoogleReportState
from .ice_servers import IceServers
from .iot import CloudIoT
from .remote import RemoteUI
from .utils import UTC, gather_callbacks, parse_date, utcnow
Expand Down Expand Up @@ -75,6 +76,7 @@ def __init__(
self.remote = RemoteUI(self)
self.auth = CognitoAuth(self)
self.voice = Voice(self)
self.ice_servers = IceServers(self)

self._init_task: asyncio.Task | None = None

Expand Down
143 changes: 143 additions & 0 deletions hass_nabucasa/ice_servers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
"""Manage ICE servers."""

from __future__ import annotations

import asyncio
from collections.abc import Awaitable, Callable
import logging
import random
import time
from typing import TYPE_CHECKING

from aiohttp import ClientResponseError
from aiohttp.hdrs import AUTHORIZATION, USER_AGENT
from webrtc_models import RTCIceServer

if TYPE_CHECKING:
from . import Cloud, _ClientT


_LOGGER = logging.getLogger(__name__)


class IceServers:
"""Class to manage ICE servers."""

def __init__(self, cloud: Cloud[_ClientT]) -> None:
"""Initialize ICE Servers."""
self.cloud = cloud
self._refresh_task: asyncio.Task | None = None
self._ice_servers: list[RTCIceServer] = []
self._ice_servers_listener: Callable[[], Awaitable[None]] | None = None
self._ice_servers_listener_unregister: Callable[[], None] | None = None

async def _async_fetch_ice_servers(self) -> list[RTCIceServer]:
"""Fetch ICE servers."""
if TYPE_CHECKING:
assert self.cloud.id_token is not None

async with self.cloud.websession.get(
f"https://{self.cloud.servicehandlers_server}/webrtc/ice_servers",
headers={
AUTHORIZATION: self.cloud.id_token,
USER_AGENT: self.cloud.client.client_name,
},
) as resp:
resp.raise_for_status()

return [
RTCIceServer(
urls=item["urls"],
username=item["username"],
credential=item["credential"],
)
for item in await resp.json()
]

def _get_refresh_sleep_time(self) -> int:
"""Get the sleep time for refreshing ICE servers."""
timestamps = [
int(server.username.split(":")[0])
for server in self._ice_servers
if server.username is not None and ":" in server.username
]

if not timestamps:
return random.randint(3600, 3600 * 12) # 1-12 hours

if (expiration := min(timestamps) - int(time.time()) - 3600) < 0:
return random.randint(100, 300)

# 1 hour before the earliest expiration
return expiration

async def _async_refresh_ice_servers(self) -> None:
"""Handle ICE server refresh."""
while True:
try:
self._ice_servers = await self._async_fetch_ice_servers()

if self._ice_servers_listener is not None:
await self._ice_servers_listener()
except ClientResponseError as err:
_LOGGER.error("Can't refresh ICE servers: %s", err)
except asyncio.CancelledError:
# Task is canceled, stop it.
break

sleep_time = self._get_refresh_sleep_time()
await asyncio.sleep(sleep_time)

def _on_add_listener(self) -> None:
"""When the instance is connected."""
self._refresh_task = asyncio.create_task(self._async_refresh_ice_servers())

def _on_remove_listener(self) -> None:
"""When the instance is disconnected."""
if self._refresh_task is not None:
self._refresh_task.cancel()
self._refresh_task = None

async def async_register_ice_servers_listener(
self,
register_ice_server_fn: Callable[
[list[RTCIceServer]],
Awaitable[Callable[[], None]],
],
) -> Callable[[], None]:
"""Register a listener for ICE servers and return unregister function."""
_LOGGER.debug("Registering ICE servers listener")

async def perform_ice_server_update() -> None:
"""Perform ICE server update by unregistering and registering servers."""
_LOGGER.debug("Updating ICE servers")

if self._ice_servers_listener_unregister is not None:
self._ice_servers_listener_unregister()
self._ice_servers_listener_unregister = None

if not self._ice_servers:
return

self._ice_servers_listener_unregister = await register_ice_server_fn(
self._ice_servers,
)

_LOGGER.debug("ICE servers updated")

def remove_listener() -> None:
"""Remove listener."""
if self._ice_servers_listener_unregister is not None:
self._ice_servers_listener_unregister()
self._ice_servers_listener_unregister = None

self._ice_servers = []
self._ice_servers_listener = None

self._on_remove_listener()

self._ice_servers_listener = perform_ice_server_update

self._on_add_listener()

return remove_listener
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ dependencies = [
"pycognito==2024.5.1",
"PyJWT>=2.8.0",
"snitun==0.39.1",
"webrtc-models<1.0.0",
]
description = "Home Assistant cloud integration by Nabu Casa, Inc."
license = {text = "GPL v3"}
Expand Down
128 changes: 128 additions & 0 deletions tests/test_ice_servers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
"""Test the ICE servers module."""

import asyncio
import time

import pytest
from webrtc_models import RTCIceServer

from hass_nabucasa import ice_servers


@pytest.fixture
def ice_servers_api(auth_cloud_mock) -> ice_servers.IceServers:
"""ICE servers API fixture."""
auth_cloud_mock.servicehandlers_server = "example.com/test"
auth_cloud_mock.id_token = "mock-id-token"
return ice_servers.IceServers(auth_cloud_mock)


@pytest.fixture(autouse=True)
def mock_ice_servers(aioclient_mock):
"""Mock ICE servers."""
aioclient_mock.get(
"https://example.com/test/webrtc/ice_servers",
json=[
{
"urls": "turn:example.com:80",
"username": "12345678:test-user",
"credential": "secret-value",
},
],
)


async def test_ice_servers_listener_registration_triggers_periodic_ice_servers_update(
ice_servers_api: ice_servers.IceServers,
):
"""Test that registering an ICE servers listener triggers a periodic update."""
times_register_called_successfully = 0

ice_servers_api._get_refresh_sleep_time = lambda: 0

async def register_ice_servers(ice_servers: list[RTCIceServer]):
nonlocal times_register_called_successfully

# These asserts will silently fail and variable will not be incremented
assert len(ice_servers) == 1
assert ice_servers[0].urls == "turn:example.com:80"
assert ice_servers[0].username == "12345678:test-user"
assert ice_servers[0].credential == "secret-value"

times_register_called_successfully += 1

def unregister():
pass

return unregister

unregister = await ice_servers_api.async_register_ice_servers_listener(
register_ice_servers,
)

# Let the periodic update run once
await asyncio.sleep(0)
# Let the periodic update run again
await asyncio.sleep(0)

assert times_register_called_successfully == 2

unregister()

# The periodic update should not run again
await asyncio.sleep(0)

assert times_register_called_successfully == 2

assert ice_servers_api._refresh_task is None
assert ice_servers_api._ice_servers == []
assert ice_servers_api._ice_servers_listener is None
assert ice_servers_api._ice_servers_listener_unregister is None


def test_get_refresh_sleep_time(ice_servers_api: ice_servers.IceServers):
"""Test get refresh sleep time."""
min_timestamp = 8888888888

ice_servers_api._ice_servers = [
RTCIceServer(urls="turn:example.com:80", username="9999999999:test-user"),
RTCIceServer(
urls="turn:example.com:80",
username=f"{min_timestamp!s}:test-user",
),
]

assert (
ice_servers_api._get_refresh_sleep_time()
== min_timestamp - int(time.time()) - 3600
)


def test_get_refresh_sleep_time_no_turn_servers(
ice_servers_api: ice_servers.IceServers,
):
"""Test get refresh sleep time."""
refresh_time = ice_servers_api._get_refresh_sleep_time()

assert refresh_time >= 3600
assert refresh_time <= 43200


def test_get_refresh_sleep_time_expiration_less_than_one_hour(
ice_servers_api: ice_servers.IceServers,
):
"""Test get refresh sleep time."""
min_timestamp = 10

ice_servers_api._ice_servers = [
RTCIceServer(urls="turn:example.com:80", username="12345678:test-user"),
RTCIceServer(
urls="turn:example.com:80",
username=f"{min_timestamp!s}:test-user",
),
]

refresh_time = ice_servers_api._get_refresh_sleep_time()

assert refresh_time >= 100
assert refresh_time <= 300

0 comments on commit 29d64c6

Please sign in to comment.