Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add function for fetching ICE servers from service handlers #717

Merged
merged 24 commits into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
8943960
Add function for fetching ICE servers from service handlers
klejejs Oct 14, 2024
4fac79c
Refactor code for fetching webrtc servers
klejejs Oct 21, 2024
77e85ac
Fix types issues
klejejs Oct 21, 2024
334c0d0
Merge branch 'master' into feat/add-function-for-getting-ice-servers
klejejs Oct 21, 2024
e545e7f
Remove old code
klejejs Oct 22, 2024
431c39a
Migrate to support only one listener
klejejs Oct 22, 2024
90695f9
Merge branch 'master' into feat/add-function-for-getting-ice-servers
klejejs Oct 22, 2024
ceccd80
PR suggestion fixes
klejejs Oct 22, 2024
eeefdae
Minor code improvements, add tests
klejejs Oct 24, 2024
303b4ae
Merge branch 'master' into feat/add-function-for-getting-ice-servers
klejejs Oct 24, 2024
c4b2c7e
Merge branch 'master' into feat/add-function-for-getting-ice-servers
klejejs Oct 24, 2024
a27348f
Add webrtc-models dependency
klejejs Oct 24, 2024
c42daf5
Add minimum refresh time constraint, improve tests
klejejs Oct 24, 2024
168fbeb
Improve sleep time timestamps check
klejejs Oct 24, 2024
1385434
Add support for listener that supports list of ICE servers
klejejs Oct 24, 2024
ed7ca6d
Move listener unregister clearance in condition
klejejs Oct 24, 2024
0a422e9
Fix test None check
klejejs Oct 24, 2024
25802ba
Improve tests based on PR reviews
klejejs Oct 25, 2024
71f799a
Update webrtc-models to a stable release
klejejs Oct 28, 2024
014cbd6
Merge branch 'master' into feat/add-function-for-getting-ice-servers
klejejs Oct 28, 2024
fa60b99
Code flow improvements
klejejs Oct 29, 2024
554a16a
Change webrtc models dependency version requirements
klejejs Oct 29, 2024
29591d8
Merge branch 'master' into feat/add-function-for-getting-ice-servers
klejejs Oct 29, 2024
a231f3f
Changed retry to be between 1 and 12 hours
klejejs Oct 29, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions hass_nabucasa/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
STATE_CONNECTED,
)
from .google_report_state import GoogleReportState
from .ice_servers import IceServers
from .iot import CloudIoT
from .remote import RemoteUI
from .utils import UTC, gather_callbacks, parse_date, utcnow
Expand Down Expand Up @@ -75,6 +76,7 @@ def __init__(
self.remote = RemoteUI(self)
self.auth = CognitoAuth(self)
self.voice = Voice(self)
self.ice_servers = IceServers(self)

self._init_task: asyncio.Task | None = None

Expand Down
143 changes: 143 additions & 0 deletions hass_nabucasa/ice_servers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
"""Manage ICE servers."""

from __future__ import annotations

import asyncio
from collections.abc import Awaitable, Callable
import logging
import random
import time
from typing import TYPE_CHECKING

from aiohttp import ClientResponseError
from aiohttp.hdrs import AUTHORIZATION, USER_AGENT
from webrtc_models import RTCIceServer

if TYPE_CHECKING:
from . import Cloud, _ClientT


_LOGGER = logging.getLogger(__name__)


class IceServers:
"""Class to manage ICE servers."""

def __init__(self, cloud: Cloud[_ClientT]) -> None:
"""Initialize ICE Servers."""
self.cloud = cloud
self._refresh_task: asyncio.Task | None = None
self._ice_servers: list[RTCIceServer] = []
self._ice_servers_listener: Callable[[], Awaitable[None]] | None = None
self._ice_servers_listener_unregister: Callable[[], None] | None = None

async def _async_fetch_ice_servers(self) -> list[RTCIceServer]:
"""Fetch ICE servers."""
if TYPE_CHECKING:
assert self.cloud.id_token is not None

async with self.cloud.websession.get(
f"https://{self.cloud.servicehandlers_server}/webrtc/ice_servers",
headers={
AUTHORIZATION: self.cloud.id_token,
USER_AGENT: self.cloud.client.client_name,
},
) as resp:
resp.raise_for_status()

return [
RTCIceServer(
urls=item["urls"],
username=item["username"],
credential=item["credential"],
)
for item in await resp.json()
]

def _get_refresh_sleep_time(self) -> int:
"""Get the sleep time for refreshing ICE servers."""
timestamps = [
int(server.username.split(":")[0])
for server in self._ice_servers
if server.username is not None and ":" in server.username
]

if not timestamps:
return random.randint(2400, 3600) # 40-60 minutes

if (expiration := min(timestamps) - int(time.time()) - 3600) < 0:
return random.randint(100, 300)

# 1 hour before the earliest expiration
return expiration

async def _async_refresh_ice_servers(self) -> None:
"""Handle ICE server refresh."""
while True:
try:
self._ice_servers = await self._async_fetch_ice_servers()

if self._ice_servers_listener is not None:
await self._ice_servers_listener()
except ClientResponseError as err:
_LOGGER.error("Can't refresh ICE servers: %s", err)
except asyncio.CancelledError:
# Task is canceled, stop it.
break

sleep_time = self._get_refresh_sleep_time()
await asyncio.sleep(sleep_time)

def _on_add_listener(self) -> None:
"""When the instance is connected."""
self._refresh_task = asyncio.create_task(self._async_refresh_ice_servers())

def _on_remove_listener(self) -> None:
"""When the instance is disconnected."""
if self._refresh_task is not None:
self._refresh_task.cancel()
self._refresh_task = None

async def async_register_ice_servers_listener(
self,
register_ice_server_fn: Callable[
[list[RTCIceServer]],
Awaitable[Callable[[], None]],
],
) -> Callable[[], None]:
"""Register a listener for ICE servers and return unregister function."""
_LOGGER.debug("Registering ICE servers listener")

async def perform_ice_server_update() -> None:
"""Perform ICE server update by unregistering and registering servers."""
_LOGGER.debug("Updating ICE servers")

if self._ice_servers_listener_unregister is not None:
self._ice_servers_listener_unregister()
self._ice_servers_listener_unregister = None

if not self._ice_servers:
return

self._ice_servers_listener_unregister = await register_ice_server_fn(
self._ice_servers,
)

_LOGGER.debug("ICE servers updated")

def remove_listener() -> None:
"""Remove listener."""
if self._ice_servers_listener_unregister is not None:
self._ice_servers_listener_unregister()
self._ice_servers_listener_unregister = None

self._ice_servers = []
self._ice_servers_listener = None

self._on_remove_listener()

self._ice_servers_listener = perform_ice_server_update

self._on_add_listener()

return remove_listener
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ dependencies = [
"pycognito==2024.5.1",
"PyJWT>=2.8.0",
"snitun==0.39.1",
"webrtc-models<1.0.0",
]
description = "Home Assistant cloud integration by Nabu Casa, Inc."
license = {text = "GPL v3"}
Expand Down
128 changes: 128 additions & 0 deletions tests/test_ice_servers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
"""Test the ICE servers module."""

import asyncio
import time

import pytest
from webrtc_models import RTCIceServer

from hass_nabucasa import ice_servers


@pytest.fixture
def ice_servers_api(auth_cloud_mock) -> ice_servers.IceServers:
"""ICE servers API fixture."""
auth_cloud_mock.servicehandlers_server = "example.com/test"
auth_cloud_mock.id_token = "mock-id-token"
return ice_servers.IceServers(auth_cloud_mock)


@pytest.fixture(autouse=True)
def mock_ice_servers(aioclient_mock):
"""Mock ICE servers."""
aioclient_mock.get(
"https://example.com/test/webrtc/ice_servers",
json=[
{
"urls": "turn:example.com:80",
"username": "12345678:test-user",
"credential": "secret-value",
},
],
)


async def test_ice_servers_listener_registration_triggers_periodic_ice_servers_update(
ice_servers_api: ice_servers.IceServers,
):
"""Test that registering an ICE servers listener triggers a periodic update."""
times_register_called_successfully = 0

ice_servers_api._get_refresh_sleep_time = lambda: 0

async def register_ice_servers(ice_servers: list[RTCIceServer]):
nonlocal times_register_called_successfully

# These asserts will silently fail and variable will not be incremented
assert len(ice_servers) == 1
assert ice_servers[0].urls == "turn:example.com:80"
assert ice_servers[0].username == "12345678:test-user"
assert ice_servers[0].credential == "secret-value"

times_register_called_successfully += 1

def unregister():
pass

return unregister

unregister = await ice_servers_api.async_register_ice_servers_listener(
register_ice_servers,
)

# Let the periodic update run once
await asyncio.sleep(0)
# Let the periodic update run again
await asyncio.sleep(0)

assert times_register_called_successfully == 2

unregister()

# The periodic update should not run again
await asyncio.sleep(0)

assert times_register_called_successfully == 2

assert ice_servers_api._refresh_task is None
assert ice_servers_api._ice_servers == []
assert ice_servers_api._ice_servers_listener is None
assert ice_servers_api._ice_servers_listener_unregister is None


def test_get_refresh_sleep_time(ice_servers_api: ice_servers.IceServers):
"""Test get refresh sleep time."""
min_timestamp = 8888888888

ice_servers_api._ice_servers = [
RTCIceServer(urls="turn:example.com:80", username="9999999999:test-user"),
RTCIceServer(
urls="turn:example.com:80",
username=f"{min_timestamp!s}:test-user",
),
]

assert (
ice_servers_api._get_refresh_sleep_time()
== min_timestamp - int(time.time()) - 3600
)


def test_get_refresh_sleep_time_no_turn_servers(
ice_servers_api: ice_servers.IceServers,
):
"""Test get refresh sleep time."""
refresh_time = ice_servers_api._get_refresh_sleep_time()

assert refresh_time >= 2400
assert refresh_time <= 3600


def test_get_refresh_sleep_time_expiration_less_than_one_hour(
ice_servers_api: ice_servers.IceServers,
):
"""Test get refresh sleep time."""
min_timestamp = 10

ice_servers_api._ice_servers = [
RTCIceServer(urls="turn:example.com:80", username="12345678:test-user"),
RTCIceServer(
urls="turn:example.com:80",
username=f"{min_timestamp!s}:test-user",
),
]

refresh_time = ice_servers_api._get_refresh_sleep_time()

assert refresh_time >= 100
assert refresh_time <= 300