From 04eac2ad70bf4354f9fcfedd6cc885c9e0678180 Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Thu, 7 Dec 2023 16:36:43 +0000 Subject: [PATCH] Improve reliability of reload (#631) --- aiohttp_devtools/runserver/serve.py | 7 ++++- aiohttp_devtools/runserver/watch.py | 17 ++++++++++-- tests/test_runserver_serve.py | 6 ++++- tests/test_runserver_watch.py | 40 ++++++++++++++++++++++++++--- 4 files changed, 63 insertions(+), 7 deletions(-) diff --git a/aiohttp_devtools/runserver/serve.py b/aiohttp_devtools/runserver/serve.py index 1c9a5134..541d9037 100644 --- a/aiohttp_devtools/runserver/serve.py +++ b/aiohttp_devtools/runserver/serve.py @@ -3,10 +3,11 @@ import json import mimetypes import sys +import time import warnings from errno import EADDRINUSE from pathlib import Path -from typing import Any, Iterator, NoReturn, Optional, Set, Tuple +from typing import Any, Iterator, List, NoReturn, Optional, Set, Tuple from aiohttp import WSMsgType, web from aiohttp.hdrs import LAST_MODIFIED, CONTENT_LENGTH @@ -33,6 +34,7 @@ LIVE_RELOAD_LOCAL_SNIPPET = b'\n\n' HOST = '0.0.0.0' +LAST_RELOAD = web.AppKey("LAST_RELOAD", List[float]) LIVERELOAD_SCRIPT = web.AppKey("LIVERELOAD_SCRIPT", bytes) STATIC_PATH = web.AppKey("STATIC_PATH", str) STATIC_URL = web.AppKey("STATIC_URL", str) @@ -240,6 +242,8 @@ async def src_reload(app: web.Application, path: Optional[str] = None) -> int: else: reloads += 1 + app[LAST_RELOAD][0] = len(app[WS]) + app[LAST_RELOAD][1] = time.time() if reloads: s = '' if reloads == 1 else 's' aux_logger.info('prompted reload of %s on %d client%s', path or 'page', reloads, s) @@ -256,6 +260,7 @@ def create_auxiliary_app( browser_cache: bool = False) -> web.Application: app = web.Application() ws: Set[Tuple[web.WebSocketResponse, str]] = set() + app[LAST_RELOAD] = [0, 0.] app[STATIC_PATH] = static_path or "" app[STATIC_URL] = static_url app[WS] = ws diff --git a/aiohttp_devtools/runserver/watch.py b/aiohttp_devtools/runserver/watch.py index 55698bbd..2dc6516c 100644 --- a/aiohttp_devtools/runserver/watch.py +++ b/aiohttp_devtools/runserver/watch.py @@ -2,6 +2,7 @@ import os import signal import sys +import time from contextlib import suppress from multiprocessing import Process from pathlib import Path @@ -14,7 +15,7 @@ from ..exceptions import AiohttpDevException from ..logs import rs_dft_logger as logger from .config import Config -from .serve import STATIC_PATH, WS, serve_main_app, src_reload +from .serve import LAST_RELOAD, STATIC_PATH, WS, serve_main_app, src_reload class WatchTask: @@ -27,7 +28,7 @@ def __init__(self, path: Union[Path, str]): async def start(self, app: web.Application) -> None: self._app = app self.stopper = asyncio.Event() - self._awatch = awatch(self._path, stop_event=self.stopper) + self._awatch = awatch(self._path, stop_event=self.stopper, step=250) self._task = asyncio.create_task(self._run()) async def _run(self) -> None: @@ -71,8 +72,20 @@ def is_static(changes: Iterable[Tuple[object, str]]) -> bool: async for changes in self._awatch: self._reloads += 1 + logger.debug("file changes: %s", changes) if any(f.endswith('.py') for _, f in changes): logger.debug('%d changes, restarting server', len(changes)) + + count, t = self._app[LAST_RELOAD] + if len(self._app[WS]) < count: + wait_delay = max(t + 5 - time.time(), 0) + logger.debug("waiting upto %s seconds before restarting", wait_delay) + + for i in range(int(wait_delay / 0.1)): + await asyncio.sleep(0.1) + if len(self._app[WS]) >= count: + break + await self._stop_dev_server() self._start_dev_server() await self._src_reload_when_live(live_checks) diff --git a/tests/test_runserver_serve.py b/tests/test_runserver_serve.py index 63e02bfa..6a319b3c 100644 --- a/tests/test_runserver_serve.py +++ b/tests/test_runserver_serve.py @@ -13,7 +13,7 @@ from aiohttp_devtools.runserver.config import Config from aiohttp_devtools.runserver.log_handlers import fmt_size from aiohttp_devtools.runserver.serve import ( - STATIC_PATH, STATIC_URL, WS, check_port_open, cleanup_aux_app, + LAST_RELOAD, STATIC_PATH, STATIC_URL, WS, check_port_open, cleanup_aux_app, modify_main_app, src_reload) from .conftest import SIMPLE_APP, create_future @@ -36,6 +36,7 @@ async def test_aux_reload(smart_caplog): aux_app = Application() ws = MagicMock() ws.send_str = MagicMock(return_value=create_future()) + aux_app[LAST_RELOAD] = [0, 0.] aux_app[STATIC_PATH] = "/path/to/static_files/" aux_app[STATIC_URL] = "/static/" aux_app[WS] = set(((ws, "/foo/bar"),)) # type: ignore[misc] @@ -56,6 +57,7 @@ async def test_aux_reload_no_path(): aux_app = Application() ws = MagicMock() ws.send_str = MagicMock(return_value=create_future()) + aux_app[LAST_RELOAD] = [0, 0.] aux_app[STATIC_PATH] = "/path/to/static_files/" aux_app[STATIC_URL] = "/static/" aux_app[WS] = set(((ws, "/foo/bar"),)) # type: ignore[misc] @@ -74,6 +76,7 @@ async def test_aux_reload_html_different(): aux_app = Application() ws = MagicMock() ws.send_str = MagicMock(return_value=create_future()) + aux_app[LAST_RELOAD] = [0, 0.] aux_app[STATIC_PATH] = "/path/to/static_files/" aux_app[STATIC_URL] = "/static/" aux_app[WS] = set(((ws, "/foo/bar"),)) # type: ignore[misc] @@ -86,6 +89,7 @@ async def test_aux_reload_runtime_error(smart_caplog): ws = MagicMock() ws.send_str = MagicMock(return_value=create_future()) ws.send_str = MagicMock(side_effect=RuntimeError('foobar')) + aux_app[LAST_RELOAD] = [0, 0.] aux_app[STATIC_PATH] = "/path/to/static_files/" aux_app[STATIC_URL] = "/static/" aux_app[WS] = set(((ws, "/foo/bar"),)) # type: ignore[misc] diff --git a/tests/test_runserver_watch.py b/tests/test_runserver_watch.py index 02edf5b2..56769ecb 100644 --- a/tests/test_runserver_watch.py +++ b/tests/test_runserver_watch.py @@ -1,12 +1,13 @@ import asyncio +import time from functools import partial -from typing import Set, Tuple -from unittest.mock import MagicMock, call +from typing import Any, Set, Tuple +from unittest.mock import AsyncMock, MagicMock, call from aiohttp import ClientSession from aiohttp.web import Application, WebSocketResponse -from aiohttp_devtools.runserver.serve import STATIC_PATH, WS +from aiohttp_devtools.runserver.serve import LAST_RELOAD, STATIC_PATH, WS from aiohttp_devtools.runserver.watch import AppTask, LiveReloadTask from .conftest import create_future @@ -81,6 +82,7 @@ async def test_python_no_server(event_loop, mocker): stop_mock = mocker.patch.object(app_task, "_stop_dev_server", autospec=True) mocker.patch.object(app_task, "_run", partial(app_task._run, live_checks=2)) app = Application() + app[LAST_RELOAD] = [0, 0.] app[STATIC_PATH] = "/path/to/" app.src_reload = MagicMock() mock_ws = MagicMock() @@ -192,3 +194,35 @@ async def test_stop_process_dirty(mocker): await app_task._stop_dev_server() assert mock_kill.call_args_list == [call(321, 2)] assert process_mock.kill.called_once() + + +async def test_restart_after_connection_loss(mocker): + mocked_awatch = mocker.patch("aiohttp_devtools.runserver.watch.awatch", autospec=True, spec_set=True) + mocked_awatch.side_effect = create_awatch_mock({("x", "/path/to/file.py")}) + app_task = AppTask(MagicMock()) + start_mock = mocker.patch.object(app_task, "_start_dev_server", autospec=True, spec_set=True) + mock_reload = mocker.patch.object(app_task, "_src_reload_when_live", autospec=True, spec_set=True) + mocker.patch.object(app_task, "_stop_dev_server", autospec=True, spec_set=True) + + app = mocker.create_autospec(Application, spec_set=True, instance=True) + # Simulate connection lost from recent restart. + ws: Set[Any] = set() + d = {WS: ws, LAST_RELOAD: [1, time.time()]} + app.__getitem__.side_effect = lambda k: d.get(k, MagicMock()) + + def update_ws(i): + ws.add(MagicMock(spec_set=())) + return AsyncMock() + + sleep_mock = mocker.patch("asyncio.sleep", autospec=True, spec_set=True) + sleep_mock.side_effect = update_ws + + await app_task.start(app) + assert app_task._task is not None + await app_task._task + assert sleep_mock.call_count < 5 + assert call(0.1) in sleep_mock.call_args_list + mock_reload.assert_called_once() + assert start_mock.call_count == 2 + assert app_task._session is not None + await app_task._session.close()