From 3a96776fad43bf7eefa2781d233886061900ae60 Mon Sep 17 00:00:00 2001
From: Bodong Yang <86948717+Bodong-Yang@users.noreply.github.com>
Date: Fri, 2 Aug 2024 14:02:10 +0900
Subject: [PATCH] refactor: ota_proxy: use simple-sqlite3-orm instead of vendoring an orm inside otaproxy package (#363)

This PR introduces the simple-sqlite3-orm package (which is also implemented by me). This package is a feature-rich, well-implemented and well-tested replacement for the sqlite3 ORM vendored inside the otaproxy package.

This PR also fixes a potential race condition in which the database entry was committed before the cache file was finalized, resulting in a small chance of the cache file not being found on a database lookup hit.

Other major changes:
* split cache_streaming related logic into the cache_streaming module.
* split lru_cache_helper from the ota_cache module into the lru_cache_helper module.
* in lru_cache_helper, we now limit the max steps to walk down the bucket list when rotating the cache.
* in db, we now use the RETURNING statement when rotating the cache on supported platforms.
* in ota_cache, we now also check db integrity to determine whether to force init.
---
 pyproject.toml                                |   6 +-
 src/ota_proxy/__init__.py                     |   2 +-
 src/ota_proxy/_consts.py                      |   2 +-
 ...che_control.py => cache_control_header.py} |   0
 src/ota_proxy/cache_streaming.py              | 457 +++++++++++++
 src/ota_proxy/config.py                       |   2 +
 src/ota_proxy/db.py                           | 455 ++++---------
 src/ota_proxy/lru_cache_helper.py             | 124 ++++
 src/ota_proxy/orm.py                          | 243 -------
 src/ota_proxy/ota_cache.py                    | 610 ++----------------
 tests/test_ota_proxy/test_cachedb.py          | 263 --------
 tests/test_ota_proxy/test_ota_cache.py        | 183 ++++--
 tests/test_ota_proxy/test_ota_proxy_server.py |  30 +-
 13 files changed, 942 insertions(+), 1435 deletions(-)
 rename src/ota_proxy/{cache_control.py => cache_control_header.py} (100%)
 create mode 100644 src/ota_proxy/cache_streaming.py
 create mode 100644 src/ota_proxy/lru_cache_helper.py
 delete mode 100644 src/ota_proxy/orm.py
 delete mode 100644 tests/test_ota_proxy/test_cachedb.py

diff --git a/pyproject.toml b/pyproject.toml
index e1b905f95..4dd1bc528 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -34,6 +34,7 @@ dependencies = [
   "pyopenssl<25,>=24.1",
   "pyyaml<7,>=6.0.1",
   "requests<2.33,>=2.32",
+  "simple-sqlite3-orm @ https://github.com/pga2rn/simple-sqlite3-orm/releases/download/v0.2.0/simple_sqlite3_orm-0.2.0-py3-none-any.whl",
   "typing-extensions>=4.6.3",
   "urllib3<2.3,>=2.2.2",
   "uvicorn[standard]<0.31,>=0.30",
@@ -45,7 +46,7 @@ optional-dependencies.dev = [
   "flake8",
   "isort",
   "pytest==7.1.2",
-  "pytest-asyncio==0.21",
+  "pytest-asyncio==0.23.8",
   "pytest-mock==3.14",
   "requests-mock",
 ]
@@ -54,6 +55,9 @@ urls.Source = "https://github.com/tier4/ota-client"
 [tool.hatch.version]
 source = "vcs"
 
+[tool.hatch.metadata]
+allow-direct-references = true
+
 [tool.hatch.build.hooks.vcs]
 version-file = "src/_otaclient_version.py"
 
diff --git a/src/ota_proxy/__init__.py b/src/ota_proxy/__init__.py
index f0eb18cbd..7e87bbf70 100644
--- a/src/ota_proxy/__init__.py
+++ b/src/ota_proxy/__init__.py
@@ -24,7 +24,7 @@
 
 from typing_extensions import ParamSpec, Self
 
-from .cache_control import OTAFileCacheControl
+from .cache_control_header import OTAFileCacheControl
 from .config import config
 from .ota_cache import OTACache
 from .server_app import App
diff --git a/src/ota_proxy/_consts.py b/src/ota_proxy/_consts.py
index 4fad83696..d07769f03 100644
--- a/src/ota_proxy/_consts.py
+++ b/src/ota_proxy/_consts.py
@@ -1,6 +1,6 @@
 from multidict import istr
 
-from .cache_control import 
OTAFileCacheControl +from .cache_control_header import OTAFileCacheControl # uvicorn REQ_TYPE_LIFESPAN = "lifespan" diff --git a/src/ota_proxy/cache_control.py b/src/ota_proxy/cache_control_header.py similarity index 100% rename from src/ota_proxy/cache_control.py rename to src/ota_proxy/cache_control_header.py diff --git a/src/ota_proxy/cache_streaming.py b/src/ota_proxy/cache_streaming.py new file mode 100644 index 000000000..7a519326e --- /dev/null +++ b/src/ota_proxy/cache_streaming.py @@ -0,0 +1,457 @@ +# Copyright 2022 TIER IV, INC. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Implementation of cache streaming.""" + + +from __future__ import annotations + +import asyncio +import contextlib +import logging +import os +import threading +import weakref +from concurrent.futures import Executor +from pathlib import Path +from typing import ( + AsyncGenerator, + AsyncIterator, + Callable, + Coroutine, + Generic, + List, + MutableMapping, + Optional, + Tuple, + TypeVar, + Union, +) + +import aiofiles + +from .config import config as cfg +from .db import CacheMeta +from .errors import ( + CacheMultiStreamingFailed, + CacheStreamingFailed, + CacheStreamingInterrupt, + StorageReachHardLimit, +) +from .utils import wait_with_backoff + +logger = logging.getLogger(__name__) + +_WEAKREF = TypeVar("_WEAKREF") + +# cache tracker + + +class CacheTracker(Generic[_WEAKREF]): + """A tracker for an ongoing cache entry. + + This tracker represents an ongoing cache entry under the , + and takes care of the life cycle of this temp cache entry. + It implements the provider/subscriber model for cache writing and cache streaming to + multiple clients. + + This entry will disappear automatically, ensured by gc and weakref + when no strong reference to this tracker(provider finished and no subscribers + attached to this tracker). + When this tracker is garbage collected, the corresponding temp cache entry will + also be removed automatically(ensure by registered finalizer). + + Attributes: + fpath: the path to the temporary cache file. + save_path: the path to finished cache file. + meta: an inst of CacheMeta for the remote OTA file that being cached. + writer_ready: a property indicates whether the provider is + ready to write and streaming data chunks. + writer_finished: a property indicates whether the provider finished the caching. + writer_failed: a property indicates whether provider fails to + finish the caching. + """ + + READER_SUBSCRIBE_WAIT_PROVIDER_TIMEOUT = 2 + FNAME_PART_SEPARATOR = "_" + + @classmethod + def _tmp_file_naming(cls, cache_identifier: str) -> str: + """Create fname for tmp caching entry. + + naming scheme: tmp__ + + NOTE: append 4bytes hex to identify cache entry for the same OTA file between + different trackers. 
+ """ + return ( + f"{cfg.TMP_FILE_PREFIX}{cls.FNAME_PART_SEPARATOR}" + f"{cache_identifier}{cls.FNAME_PART_SEPARATOR}{(os.urandom(4)).hex()}" + ) + + def __init__( + self, + cache_identifier: str, + ref_holder: _WEAKREF, + *, + base_dir: Union[str, Path], + executor: Executor, + callback: _CACHE_ENTRY_REGISTER_CALLBACK, + below_hard_limit_event: threading.Event, + ): + self.fpath = Path(base_dir) / self._tmp_file_naming(cache_identifier) + self.meta: CacheMeta | None = None + self.cache_identifier = cache_identifier + self.save_path = Path(base_dir) / cache_identifier + self._writer_ready = asyncio.Event() + self._writer_finished = asyncio.Event() + self._writer_failed = asyncio.Event() + self._ref = ref_holder + self._subscriber_ref_holder: List[_WEAKREF] = [] + self._executor = executor + self._cache_commit_cb = callback + self._space_availability_event = below_hard_limit_event + + self._bytes_written = 0 + self._cache_write_gen: Optional[AsyncGenerator[int, bytes]] = None + + # self-register the finalizer to this tracker + weakref.finalize( + self, + self.finalizer, + fpath=self.fpath, + ) + + async def _finalize_caching(self): + """Commit cache entry to db and rename tmp cached file with sha256hash + to fialize the caching.""" + # if the file with the same sha256has is already presented, skip the hardlink + # NOTE: no need to clean the tmp file, it will be done by the cache tracker. + assert self.meta is not None + await self._cache_commit_cb(self.meta) + if not self.save_path.is_file(): + self.fpath.link_to(self.save_path) + + @staticmethod + def finalizer(*, fpath: Union[str, Path]): + """Finalizer that cleans up the tmp file when this tracker is gced.""" + Path(fpath).unlink(missing_ok=True) + + @property + def writer_failed(self) -> bool: + return self._writer_failed.is_set() + + @property + def writer_finished(self) -> bool: + return self._writer_finished.is_set() + + @property + def is_cache_valid(self) -> bool: + """Indicates whether the temp cache entry for this tracker is valid.""" + return ( + self.meta is not None + and self._writer_finished.is_set() + and not self._writer_failed.is_set() + ) + + def get_cache_write_gen(self) -> Optional[AsyncGenerator[int, bytes]]: + if self._cache_write_gen: + return self._cache_write_gen + logger.warning(f"tracker for {self.meta} is not ready, skipped") + + async def _provider_write_cache(self) -> AsyncGenerator[int, bytes]: + """Provider writes data chunks from upper caller to tmp cache file. + + If cache writing failed, this method will exit and tracker.writer_failed and + tracker.writer_finished will be set. + """ + logger.debug(f"start to cache for {self.meta=}...") + try: + if not self.meta: + raise ValueError("called before provider tracker is ready, abort") + + async with aiofiles.open(self.fpath, "wb", executor=self._executor) as f: + # let writer become ready when file is open successfully + self._writer_ready.set() + + _written = 0 + while _data := (yield _written): + if not self._space_availability_event.is_set(): + logger.warning( + f"abort writing cache for {self.meta=}: {StorageReachHardLimit.__name__}" + ) + self._writer_failed.set() + return + + _written = await f.write(_data) + self._bytes_written += _written + + logger.debug( + "cache write finished, total bytes written" + f"({self._bytes_written}) for {self.meta=}" + ) + self.meta.cache_size = self._bytes_written + + # NOTE: no need to track the cache commit, + # as writer_finish is meant to set on cache file created. 
+ asyncio.create_task(self._finalize_caching()) + except Exception as e: + logger.exception(f"failed to write cache for {self.meta=}: {e!r}") + self._writer_failed.set() + finally: + self._writer_finished.set() + self._ref = None + + async def _subscribe_cache_streaming(self) -> AsyncIterator[bytes]: + """Subscriber keeps polling chunks from ongoing tmp cache file. + + Subscriber will keep polling until the provider fails or + provider finished and subscriber has read bytes. + + Raises: + CacheMultipleStreamingFailed if provider failed or timeout reading + data chunk from tmp cache file(might be caused by a dead provider). + """ + try: + err_count, _bytes_read = 0, 0 + async with aiofiles.open(self.fpath, "rb", executor=self._executor) as f: + while not (self.writer_finished and _bytes_read == self._bytes_written): + if self.writer_failed: + raise CacheMultiStreamingFailed( + f"provider aborted for {self.meta}" + ) + _bytes_read += len(_chunk := await f.read(cfg.CHUNK_SIZE)) + if _chunk: + err_count = 0 + yield _chunk + continue + + err_count += 1 + if not await wait_with_backoff( + err_count, + _backoff_factor=cfg.STREAMING_BACKOFF_FACTOR, + _backoff_max=cfg.STREAMING_BACKOFF_MAX, + ): + # abort caching due to potential dead streaming coro + _err_msg = f"failed to stream({self.meta=}): timeout getting data, partial read might happen" + logger.error(_err_msg) + # signal streamer to stop streaming + raise CacheMultiStreamingFailed(_err_msg) + finally: + # unsubscribe on finish + self._subscriber_ref_holder.pop() + + async def _read_cache(self) -> AsyncIterator[bytes]: + """Directly open the tmp cache entry and yield data chunks from it. + + Raises: + CacheMultipleStreamingFailed if fails to read from the + cached file, this might indicate a partial written cache file. + """ + _bytes_read, _retry_count = 0, 0 + async with aiofiles.open(self.fpath, "rb", executor=self._executor) as f: + while _bytes_read < self._bytes_written: + if _data := await f.read(cfg.CHUNK_SIZE): + _retry_count = 0 + _bytes_read += len(_data) + yield _data + continue + + # no data is read from the cache entry, + # retry sometimes to ensure all data is acquired + _retry_count += 1 + if not await wait_with_backoff( + _retry_count, + _backoff_factor=cfg.STREAMING_BACKOFF_FACTOR, + _backoff_max=cfg.STREAMING_CACHED_TMP_TIMEOUT, + ): + # abort caching due to potential dead streaming coro + _err_msg = ( + f"open_cached_tmp failed for ({self.meta=}): " + "timeout getting more data, partial written cache file detected" + ) + logger.debug(_err_msg) + # signal streamer to stop streaming + raise CacheMultiStreamingFailed(_err_msg) + + # exposed API + + async def provider_start(self, meta: CacheMeta): + """Register meta to the Tracker, create tmp cache entry and get ready. + + Check _provider_write_cache for more details. + + Args: + meta: inst of CacheMeta for the requested file tracked by this tracker. + This meta is created by open_remote() method. 
+ """ + self.meta = meta + self._cache_write_gen = self._provider_write_cache() + await self._cache_write_gen.asend(None) # type: ignore + + async def provider_on_finished(self): + if not self.writer_finished and self._cache_write_gen: + with contextlib.suppress(StopAsyncIteration): + await self._cache_write_gen.asend(b"") + self._writer_finished.set() + self._ref = None + + async def provider_on_failed(self): + """Manually fail and stop the caching.""" + if not self.writer_finished and self._cache_write_gen: + logger.warning(f"interrupt writer coroutine for {self.meta=}") + with contextlib.suppress(StopAsyncIteration, CacheStreamingInterrupt): + await self._cache_write_gen.athrow(CacheStreamingInterrupt) + + self._writer_failed.set() + self._writer_finished.set() + self._ref = None + + async def subscriber_subscribe_tracker(self) -> Optional[AsyncIterator[bytes]]: + """Reader subscribe this tracker and get a file descriptor to get data chunks.""" + _wait_count = 0 + while not self._writer_ready.is_set(): + _wait_count += 1 + if self.writer_failed or not await wait_with_backoff( + _wait_count, + _backoff_factor=cfg.STREAMING_BACKOFF_FACTOR, + _backoff_max=self.READER_SUBSCRIBE_WAIT_PROVIDER_TIMEOUT, + ): + return # timeout waiting for provider to become ready + + # subscribe on an ongoing cache + if not self.writer_finished and isinstance(self._ref, _Weakref): + self._subscriber_ref_holder.append(self._ref) + return self._subscribe_cache_streaming() + # caching just finished, try to directly read the finished cache entry + elif self.is_cache_valid: + return self._read_cache() + + +# a callback that register the cache entry indicates by input CacheMeta inst to the cache_db +_CACHE_ENTRY_REGISTER_CALLBACK = Callable[[CacheMeta], Coroutine[None, None, None]] + + +class _Weakref: + pass + + +class CachingRegister: + """A tracker register that manages cache trackers. + + For each ongoing caching for unique OTA file, there will be only one unique identifier for it. + + This first caller that requests with the identifier will become the provider and create + a new tracker for this identifier. + The later comes callers will become the subscriber to this tracker. + """ + + def __init__(self, base_dir: Union[str, Path]): + self._base_dir = Path(base_dir) + self._id_ref_dict: MutableMapping[str, _Weakref] = weakref.WeakValueDictionary() + self._ref_tracker_dict: MutableMapping[_Weakref, CacheTracker] = ( + weakref.WeakKeyDictionary() + ) + + async def get_tracker( + self, + cache_identifier: str, + *, + executor: Executor, + callback: _CACHE_ENTRY_REGISTER_CALLBACK, + below_hard_limit_event: threading.Event, + ) -> Tuple[CacheTracker, bool]: + """Get an inst of CacheTracker for the cache_identifier. + + Returns: + An inst of tracker, and a bool indicates the caller is provider(True), + or subscriber(False). 
+ """ + _new_ref = _Weakref() + _ref = self._id_ref_dict.setdefault(cache_identifier, _new_ref) + + # subscriber + if ( + _tracker := self._ref_tracker_dict.get(_ref) + ) and not _tracker.writer_failed: + return _tracker, False + + # provider, or override a failed provider + if _ref is not _new_ref: # override a failed tracker + self._id_ref_dict[cache_identifier] = _new_ref + _ref = _new_ref + + _tracker = CacheTracker( + cache_identifier, + _ref, + base_dir=self._base_dir, + executor=executor, + callback=callback, + below_hard_limit_event=below_hard_limit_event, + ) + self._ref_tracker_dict[_ref] = _tracker + return _tracker, True + + +async def cache_streaming( + fd: AsyncIterator[bytes], + meta: CacheMeta, + tracker: CacheTracker, +) -> AsyncIterator[bytes]: + """A cache streamer that get data chunk from and tees to multiple destination. + + Data chunk yielded from will be teed to: + 1. upper uvicorn otaproxy APP to send back to client, + 2. cache_tracker cache_write_gen for caching. + + Args: + fd: opened connection to a remote file. + meta: meta data of the requested resource. + tracker: an inst of ongoing cache tracker bound to this request. + + Returns: + A bytes async iterator to yield data chunk from, for upper otaproxy uvicorn APP. + + Raises: + CacheStreamingFailed if any exception happens during retrieving. + """ + + async def _inner(): + _cache_write_gen = tracker.get_cache_write_gen() + try: + # tee the incoming chunk to two destinations + async for chunk in fd: + # NOTE: for aiohttp, when HTTP chunk encoding is enabled, + # an empty chunk will be sent to indicate the EOF of stream, + # we MUST handle this empty chunk. + if not chunk: # skip if empty chunk is read from remote + continue + # to caching generator + if _cache_write_gen and not tracker.writer_finished: + try: + await _cache_write_gen.asend(chunk) + except Exception as e: + await tracker.provider_on_failed() # signal tracker + logger.error( + f"cache write coroutine failed for {meta=}, abort caching: {e!r}" + ) + # to uvicorn thread + yield chunk + await tracker.provider_on_finished() + except Exception as e: + logger.exception(f"cache tee failed for {meta=}") + await tracker.provider_on_failed() + raise CacheStreamingFailed from e + + await tracker.provider_start(meta) + return _inner() diff --git a/src/ota_proxy/config.py b/src/ota_proxy/config.py index 6696a77b3..61ab3a2a4 100644 --- a/src/ota_proxy/config.py +++ b/src/ota_proxy/config.py @@ -38,6 +38,8 @@ class Config: 32 * (1024**2): 2, # [32MiB, ~), will not be rotated } DB_FILE = f"{BASE_DIR}/cache_db" + DB_THREADS = 3 + DB_THREAD_WAIT_TIMEOUT = 30 # seconds # DB configuration/setup # ota-cache table diff --git a/src/ota_proxy/db.py b/src/ota_proxy/db.py index 3858bee82..239c9bd51 100644 --- a/src/ota_proxy/db.py +++ b/src/ota_proxy/db.py @@ -15,24 +15,32 @@ from __future__ import annotations -import asyncio import logging import sqlite3 -import threading -import time -from concurrent.futures import ThreadPoolExecutor from pathlib import Path -from typing import Any, Dict, List, Optional, Union +from typing import Optional + +from pydantic import SkipValidation +from simple_sqlite3_orm import ( + ConstrainRepr, + ORMBase, + TableSpec, + TypeAffinityRepr, + utils, +) +from simple_sqlite3_orm._orm import AsyncORMThreadPoolBase +from typing_extensions import Annotated + +from otaclient_common.typing import StrOrPath from ._consts import HEADER_CONTENT_ENCODING, HEADER_OTA_FILE_CACHE_CONTROL -from .cache_control import OTAFileCacheControl +from 
.cache_control_header import OTAFileCacheControl from .config import config as cfg -from .orm import FV, ColumnDescriptor, ORMBase logger = logging.getLogger(__name__) -class CacheMeta(ORMBase): +class CacheMeta(TableSpec): """revision 4 url: unquoted URL from the request of this cache entry. @@ -47,25 +55,51 @@ class CacheMeta(ORMBase): content_encoding: the content_encoding header string comes with resp from remote server. """ - file_sha256: ColumnDescriptor[str] = ColumnDescriptor( - str, "TEXT", "UNIQUE", "NOT NULL", "PRIMARY KEY" - ) - url: ColumnDescriptor[str] = ColumnDescriptor(str, "TEXT", "NOT NULL") - bucket_idx: ColumnDescriptor[int] = ColumnDescriptor( - int, "INTEGER", "NOT NULL", type_guard=True - ) - last_access: ColumnDescriptor[int] = ColumnDescriptor( - int, "INTEGER", "NOT NULL", type_guard=(int, float) - ) - cache_size: ColumnDescriptor[int] = ColumnDescriptor( - int, "INTEGER", "NOT NULL", type_guard=True - ) - file_compression_alg: ColumnDescriptor[str] = ColumnDescriptor( - str, "TEXT", "NOT NULL" - ) - content_encoding: ColumnDescriptor[str] = ColumnDescriptor(str, "TEXT", "NOT NULL") - - def export_headers_to_client(self) -> Dict[str, str]: + file_sha256: Annotated[ + str, + TypeAffinityRepr(str), + ConstrainRepr("PRIMARY KEY"), + SkipValidation, + ] + url: Annotated[ + str, + TypeAffinityRepr(str), + ConstrainRepr("NOT NULL"), + SkipValidation, + ] + bucket_idx: Annotated[ + int, + TypeAffinityRepr(int), + ConstrainRepr("NOT NULL"), + SkipValidation, + ] = 0 + last_access: Annotated[ + int, + TypeAffinityRepr(int), + ConstrainRepr("NOT NULL"), + SkipValidation, + ] = 0 + cache_size: Annotated[ + int, + TypeAffinityRepr(int), + ConstrainRepr("NOT NULL"), + SkipValidation, + ] = 0 + file_compression_alg: Annotated[ + Optional[str], + TypeAffinityRepr(str), + SkipValidation, + ] = None + content_encoding: Annotated[ + Optional[str], + TypeAffinityRepr(str), + SkipValidation, + ] = None + + def __hash__(self) -> int: + return hash(tuple(getattr(self, attrn) for attrn in self.model_fields)) + + def export_headers_to_client(self) -> dict[str, str]: """Export required headers for client. Currently includes content-type, content-encoding and ota-file-cache-control headers. 
@@ -81,304 +115,99 @@ def export_headers_to_client(self) -> Dict[str, str]: res[HEADER_OTA_FILE_CACHE_CONTROL] = ( OTAFileCacheControl.export_kwargs_as_header( file_sha256=self.file_sha256, - file_compression_alg=self.file_compression_alg, + file_compression_alg=self.file_compression_alg or "", ) ) return res -class OTACacheDB: - TABLE_NAME: str = cfg.TABLE_NAME - OTA_CACHE_IDX: List[str] = [ - ( - "CREATE INDEX IF NOT EXISTS " - f"bucket_last_access_idx_{TABLE_NAME} " - f"ON {TABLE_NAME}({CacheMeta.bucket_idx.name}, {CacheMeta.last_access.name})" - ), - ] - - @classmethod - def check_db_file(cls, db_file: Union[str, Path]) -> bool: - if not Path(db_file).is_file(): - return False - try: - with sqlite3.connect(db_file) as con: - con.execute("PRAGMA integrity_check;") - # check whether expected table is in it or not - cur = con.execute( - "SELECT name FROM sqlite_master WHERE type='table' AND name=?", - (cls.TABLE_NAME,), - ) - if cur.fetchone() is None: - logger.warning( - f"{cls.TABLE_NAME} not found, this db file should be initialized" - ) - return False - return True - except sqlite3.DatabaseError as e: - logger.warning(f"{db_file} is corrupted: {e!r}") - return False - - @classmethod - def init_db_file(cls, db_file: Union[str, Path]): - """ - Purge the old db file and init the db files, creating table in it. - """ - # remove the db file first - Path(db_file).unlink(missing_ok=True) - try: - with sqlite3.connect(db_file) as con: - # create ota_cache table - con.execute(CacheMeta.get_create_table_stmt(cls.TABLE_NAME), ()) - # create indices - for idx in cls.OTA_CACHE_IDX: - con.execute(idx, ()) - - ### db performance tunning - # enable WAL mode - con.execute("PRAGMA journal_mode = WAL;") - # set temp_store to memory - con.execute("PRAGMA temp_store = memory;") - # enable mmap (size in bytes) - mmap_size = 16 * 1024 * 1024 # 16MiB - con.execute(f"PRAGMA mmap_size = {mmap_size};") - except sqlite3.Error as e: - logger.debug(f"init db failed: {e!r}") - raise e - - def __init__(self, db_file: Union[str, Path]): - """Connects to OTA cache database. - - Args: - db_file: target db file to connect to. +class CacheMetaORM(ORMBase[CacheMeta]): ... - Raises: - ValueError on invalid ota_cache db file, - sqlite3.Error if optimization settings applied failed. - """ - self._con = sqlite3.connect( - db_file, - check_same_thread=True, # one thread per connection in the threadpool - # isolation_level=None, # enable autocommit mode - ) - self._con.row_factory = sqlite3.Row - if not self.check_db_file(db_file): - raise ValueError(f"invalid db file: {db_file}") - - # db performance tunning, enable optimization - with self._con as con: - # enable WAL mode - con.execute("PRAGMA journal_mode = WAL;") - # set temp_store to memory - con.execute("PRAGMA temp_store = memory;") - # enable mmap (size in bytes) - mmap_size = 16 * 1024 * 1024 # 16MiB - con.execute(f"PRAGMA mmap_size = {mmap_size};") - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.close() - return False - def close(self): - self._con.close() +class AsyncCacheMetaORM(AsyncORMThreadPoolBase[CacheMeta]): - def remove_entries(self, fd: ColumnDescriptor[FV], *_inputs: FV) -> int: - """Remove entri(es) indicated by field(s). + async def rotate_cache( + self, bucket_idx: int, num: int + ) -> Optional[list[CacheMeta]]: + bucket_fn, last_access_fn = "bucket_idx", "last_access" - Args: - fd: field descriptor of the column - *_inputs: a list of field value(s) - - Returns: - returns affected rows count. 
- """ - if not _inputs: - return 0 - if CacheMeta.contains_field(fd): + def _inner(): with self._con as con: - _regulated_input = [(i,) for i in _inputs] - return con.executemany( - f"DELETE FROM {self.TABLE_NAME} WHERE {fd._field_name}=?", - _regulated_input, - ).rowcount - else: - logger.debug(f"invalid inputs detected: {_inputs=}") - return 0 - - def lookup_entry( - self, - fd: ColumnDescriptor[FV], - value: FV, - ) -> Optional[CacheMeta]: - """Lookup entry by field value. - - NOTE: lookup via this method will trigger update to field, - NOTE 2: field is updated by searching key. - - Args: - fd: field descriptor of the column. - value: field value to lookup. - - Returns: - An instance of CacheMeta representing the cache entry, or None if lookup failed. - """ - if not CacheMeta.contains_field(fd): - return - - fd_name = fd.name - with self._con as con: # put the lookup and update into one session - if row := con.execute( - f"SELECT * FROM {self.TABLE_NAME} WHERE {fd_name}=?", - (value,), - ).fetchone(): - # warm up the cache(update last_access timestamp) here - res = CacheMeta.row_to_meta(row) - con.execute( - ( - f"UPDATE {self.TABLE_NAME} SET {CacheMeta.last_access.name}=? " - f"WHERE {fd_name}=?" - ), - (int(time.time()), value), + # check if we have enough entries to rotate + select_stmt = self.orm_table_spec.table_select_stmt( + select_from=self.orm_table_name, + select_cols="*", + function="count", + where_cols=(bucket_fn,), + order_by=(last_access_fn,), + limit=num, ) - return res - - def insert_entry(self, *cache_meta: CacheMeta) -> int: - """Insert an entry into the ota cache table. - - Args: - cache_meta: a list of CacheMeta instances to be inserted - - Returns: - Returns inserted rows count. - """ - if not cache_meta: - return 0 - with self._con as con: - return con.executemany( - f"INSERT OR REPLACE INTO {self.TABLE_NAME} VALUES ({CacheMeta.get_shape()})", - [m.astuple() for m in cache_meta], - ).rowcount - - def lookup_all(self) -> List[CacheMeta]: - """Lookup all entries in the ota cache table. - - Returns: - A list of CacheMeta instances representing each entry. - """ - with self._con as con: - cur = con.execute(f"SELECT * FROM {self.TABLE_NAME}", ()) - return [CacheMeta.row_to_meta(row) for row in cur.fetchall()] - - def rotate_cache(self, bucket_idx: int, num: int) -> Optional[List[str]]: - """Rotate cache entries in LRU flavour. - - Args: - bucket_idx: which bucket for space reserving - num: num of entries needed to be deleted in this bucket - - Return: - A list of OTA file's hashes that needed to be deleted for space reserving, - or None if no enough entries for space reserving. - """ - bucket_fn, last_access_fn = ( - CacheMeta.bucket_idx.name, - CacheMeta.last_access.name, - ) - # first, check whether we have required number of entries in the bucket - with self._con as con: - cur = con.execute( - ( - f"SELECT COUNT(*) FROM {self.TABLE_NAME} WHERE {bucket_fn}=? " - f"ORDER BY {last_access_fn} LIMIT ?" - ), - (bucket_idx, num), - ) - if not (_raw_res := cur.fetchone()): - return - - # NOTE: if we can upgrade to sqlite3 >= 3.35, - # use RETURNING clause instead of using 2 queries as below - - # if we have enough entries for space reserving - if _raw_res[0] >= num: - # first select those entries - _rows = con.execute( - ( - f"SELECT * FROM {self.TABLE_NAME} " - f"WHERE {bucket_fn}=? " - f"ORDER BY {last_access_fn} " - "LIMIT ?" 
- ), - (bucket_idx, num), - ).fetchall() - # and then delete those entries with same conditions - con.execute( - ( - f"DELETE FROM {self.TABLE_NAME} " - f"WHERE {bucket_fn}=? " - f"ORDER BY {last_access_fn} " - "LIMIT ?" - ), - (bucket_idx, num), - ) - return [row[CacheMeta.file_sha256.name] for row in _rows] - - -class _ProxyBase: - """A proxy class base for OTACacheDB that dispatches all requests into a threadpool.""" - - DB_THREAD_POOL_SIZE = 1 - - def _thread_initializer(self, db_f): - """Init a db connection for each thread worker""" - # NOTE: set init to False always as we only operate db when using proxy - self._thread_local.db = OTACacheDB(db_f) - - def __init__(self, db_f: Union[str, Path]): - """Init the database connecting thread pool.""" - self._thread_local = threading.local() - # set thread_pool_size to 1 to make the db access - # to make it totally concurrent. - self._executor = ThreadPoolExecutor( - max_workers=self.DB_THREAD_POOL_SIZE, - initializer=self._thread_initializer, - initargs=(db_f,), - ) - - def close(self): - self._executor.shutdown(wait=True) - - -class AIO_OTACacheDBProxy(_ProxyBase): - async def insert_entry(self, *cache_meta: CacheMeta) -> int: - def _inner(): - _db: OTACacheDB = self._thread_local.db - return _db.insert_entry(*cache_meta) - - return await asyncio.get_running_loop().run_in_executor(self._executor, _inner) - - async def lookup_entry( - self, fd: ColumnDescriptor, _input: Any - ) -> Optional[CacheMeta]: - def _inner(): - _db: OTACacheDB = self._thread_local.db - return _db.lookup_entry(fd, _input) - - return await asyncio.get_running_loop().run_in_executor(self._executor, _inner) + cur = con.execute(select_stmt, {bucket_fn: bucket_idx}) + # we don't have enough entries to delete + if not (_raw_res := cur.fetchone()) or _raw_res[0] < num: + return + + # RETURNING statement is available only after sqlite3 v3.35.0 + if sqlite3.sqlite_version_info < (3, 35, 0): + # first select entries met the requirements + select_to_delete_stmt = self.orm_table_spec.table_select_stmt( + select_from=self.orm_table_name, + where_cols=(bucket_fn,), + order_by=(last_access_fn,), + limit=num, + ) + cur = con.execute(select_to_delete_stmt, {bucket_fn: bucket_idx}) + rows_to_remove = list(cur) + + # delete the target entries + delete_stmt = self.orm_table_spec.table_delete_stmt( + delete_from=self.orm_table_name, + where_cols=(bucket_fn,), + order_by=(last_access_fn,), + limit=num, + ) + con.execute(delete_stmt, {bucket_fn: bucket_idx}) + + return rows_to_remove + else: + rotate_stmt = self.orm_table_spec.table_delete_stmt( + delete_from=self.orm_table_name, + where_cols=(bucket_fn,), + order_by=(last_access_fn,), + limit=num, + returning_cols="*", + ) + cur = con.execute(rotate_stmt, {bucket_fn: bucket_idx}) + return list(cur) - async def remove_entries(self, fd: ColumnDescriptor, *_inputs: Any) -> int: - def _inner(): - _db: OTACacheDB = self._thread_local.db - return _db.remove_entries(fd, *_inputs) + return await self._run_in_pool(_inner) - return await asyncio.get_running_loop().run_in_executor(self._executor, _inner) - async def rotate_cache(self, bucket_idx: int, num: int) -> Optional[List[str]]: - def _inner(): - _db: OTACacheDB = self._thread_local.db - return _db.rotate_cache(bucket_idx, num) +def check_db(db_f: StrOrPath, table_name: str) -> bool: + """Check whether specific db is normal or not.""" + if not Path(db_f).is_file(): + logger.warning(f"{db_f} not found") + return False - return await asyncio.get_running_loop().run_in_executor(self._executor, 
_inner) + con = sqlite3.connect(f"file:{db_f}?mode=ro", uri=True) + try: + if not utils.check_db_integrity(con): + logger.warning(f"{db_f} fails integrity check") + return False + if not utils.lookup_table(con, table_name): + logger.warning(f"{table_name} not found in {db_f}") + return False + finally: + con.close() + return True + + +def init_db(db_f: StrOrPath, table_name: str) -> None: + """Init the database.""" + con = sqlite3.connect(db_f) + orm = CacheMetaORM(con, table_name) + try: + orm.orm_create_table(without_rowid=True) + utils.enable_wal_mode(con, relax_sync_mode=True) + finally: + con.close() diff --git a/src/ota_proxy/lru_cache_helper.py b/src/ota_proxy/lru_cache_helper.py new file mode 100644 index 000000000..5c1df1c8c --- /dev/null +++ b/src/ota_proxy/lru_cache_helper.py @@ -0,0 +1,124 @@ +# Copyright 2022 TIER IV, INC. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Implementation of OTA cache control.""" + + +from __future__ import annotations + +import bisect +import logging +import sqlite3 +import time +from pathlib import Path +from typing import Optional, Union + +from simple_sqlite3_orm import utils + +from .db import AsyncCacheMetaORM, CacheMeta + +logger = logging.getLogger(__name__) + + +class LRUCacheHelper: + """A helper class that provides API for accessing/managing cache entries in ota cachedb. + + Serveral buckets are created according to predefined file size threshould. + Each bucket will maintain the cache entries of that bucket's size definition, + LRU is applied on per-bucket scale. + + NOTE: currently entry in first bucket and last bucket will skip LRU rotate. 
+ """ + + def __init__( + self, + db_f: Union[str, Path], + *, + bsize_dict: dict[int, int], + table_name: str, + thread_nums: int, + thread_wait_timeout: int, + ): + self.bsize_list = list(bsize_dict) + self.bsize_dict = bsize_dict.copy() + + def _con_factory(): + con = sqlite3.connect( + db_f, check_same_thread=False, timeout=thread_wait_timeout + ) + + utils.enable_wal_mode(con, relax_sync_mode=True) + utils.enable_mmap(con) + utils.enable_tmp_store_at_memory(con) + return con + + self._async_db = AsyncCacheMetaORM( + table_name=table_name, + con_factory=_con_factory, + number_of_cons=thread_nums, + ) + + self._closed = False + + def close(self): + self._async_db.orm_pool_shutdown(wait=True, close_connections=True) + self._closed = True + + async def commit_entry(self, entry: CacheMeta) -> bool: + """Commit cache entry meta to the database.""" + # populate bucket and last_access column + entry.bucket_idx = bisect.bisect_right(self.bsize_list, entry.cache_size) - 1 + entry.last_access = int(time.time()) + + if (await self._async_db.orm_insert_entry(entry, or_option="replace")) != 1: + logger.error(f"db: failed to add {entry=}") + return False + return True + + async def lookup_entry(self, file_sha256: str) -> Optional[CacheMeta]: + return await self._async_db.orm_select_entry(file_sha256=file_sha256) + + async def remove_entry(self, file_sha256: str) -> bool: + return bool(await self._async_db.orm_delete_entries(file_sha256=file_sha256)) + + async def rotate_cache(self, size: int) -> Optional[list[str]]: + """Wrapper method for calling the database LRU cache rotating method. + + Args: + size int: the size of file that we want to reserve space for + + Returns: + A list of hashes that needed to be cleaned, or empty list if rotation + is not required, or None if cache rotation cannot be executed. + """ + # NOTE: currently item size smaller than 1st bucket and larger than latest bucket + # will be saved without cache rotating. + if size >= self.bsize_list[-1] or size < self.bsize_list[1]: + return [] + + _cur_bucket_idx = bisect.bisect_right(self.bsize_list, size) - 1 + _cur_bucket_size = self.bsize_list[_cur_bucket_idx] + + # first: check the upper bucket, remove 1 item from any of the + # upper bucket is enough. + # NOTE(20240802): limit the max steps we can go to avoid remove too large file + max_steps = min(len(self.bsize_list), _cur_bucket_idx + 3) + for _bucket_idx in range(_cur_bucket_idx + 1, max_steps): + if res := await self._async_db.rotate_cache(_bucket_idx, 1): + return [entry.file_sha256 for entry in res] + + # second: if cannot find one entry at any upper bucket, check current bucket + if res := await self._async_db.rotate_cache( + _cur_bucket_idx, self.bsize_dict[_cur_bucket_size] + ): + return [entry.file_sha256 for entry in res] diff --git a/src/ota_proxy/orm.py b/src/ota_proxy/orm.py deleted file mode 100644 index 7141ccffd..000000000 --- a/src/ota_proxy/orm.py +++ /dev/null @@ -1,243 +0,0 @@ -# Copyright 2022 TIER IV, INC. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - - -from __future__ import annotations - -from dataclasses import asdict, astuple, dataclass, fields -from io import StringIO -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - Generic, - List, - Optional, - Tuple, - Type, - TypeVar, - Union, - overload, -) - -from typing_extensions import Self, dataclass_transform - -if TYPE_CHECKING: - import sqlite3 - - -class NULL_TYPE: - """Singleton for NULL type.""" - - def __new__(cls, *args, **kwargs) -> None: - return None - - -SQLITE_DATATYPES = Union[ - int, # INTEGER - str, # TEXT - float, # REAL - bytes, # BLOB - bool, # INTEGER 0, 1 - NULL_TYPE, # NULL -] -SQLITE_DATATYPES_SET = {int, str, float, bytes, bool, NULL_TYPE} -FV = TypeVar("FV", bound=SQLITE_DATATYPES) # field value type -TYPE_CHECKER = Callable[[Any], bool] - - -class ColumnDescriptor(Generic[FV]): - """ColumnDescriptor represents a column in a sqlite3 table, - implemented the python descriptor protocol. - - When accessed as attribute of TableCls(subclass of ORMBase) instance, - it will return the value of the column/field. - When accessed as attribute of the TableCls class, - it will return the ColumnDescriptor itself. - """ - - def __init__( - self, - field_type: Type[FV], - *constrains: str, - type_guard: Union[Tuple[Type, ...], TYPE_CHECKER, bool] = False, - default: Optional[FV] = None, - ) -> None: - # whether this field should be included in table def or not - self._skipped = False - self.constrains = " ".join(constrains) # TODO: constrains validation - - self.field_type = field_type - self.default = field_type() if default is None else default - - # init type checker callable - # default to check over the specific field type - self.type_guard_enabled = bool(type_guard) - self.type_checker = lambda x: isinstance(x, field_type) - if isinstance(type_guard, tuple): # check over a list of types - self.type_checker = lambda x: isinstance(x, type_guard) - elif callable(type_guard): # custom type guard function - self.type_checker = type_guard - - @overload - def __get__(self, obj: None, objtype: type) -> Self: - """Descriptor accessed via class.""" - ... - - @overload - def __get__(self, obj, objtype: type) -> FV: - """Descriptor accessed via bound instance.""" - ... - - def __get__(self, obj, objtype=None) -> Union[FV, Self]: - # try accessing bound instance attribute - if not self._skipped and obj is not None: - if isinstance(obj, type): - return self # bound inst is type(class access mode), return descriptor itself - return getattr(obj, self._private_name) # access via instance - return self # return descriptor instance by default - - def __set__(self, obj, value: Any) -> None: - if self._skipped: - return # ignore value setting on skipped field - - # set default value and ignore input value - # 1. field_type is NULL_TYPE - # 2. value is None - # 3. value is descriptor instance itself - # dataclass will retrieve default value with class access form, - # but we return descriptor instance in class access form to expose - # descriptor helper methods(check_type, etc.), so apply special - # treatment here. 
- if self.field_type is NULL_TYPE or value is None or value is self: - return setattr(obj, self._private_name, self.default) - - # handle normal value setting - if self.type_guard_enabled and not self.type_checker(value): - raise TypeError(f"type_guard: expect {self.field_type}, get {type(value)}") - # if value's type is not field_type but subclass or compatible types that can pass type_checker, - # convert the value to the field_type first before assigning - _input_value = ( - value if type(value) is self.field_type else self.field_type(value) - ) - setattr(obj, self._private_name, _input_value) - - def __set_name__(self, owner: type, name: str) -> None: - self.owner = owner - try: - self._index = list(owner.__annotations__).index(name) - except (AttributeError, ValueError): - self._skipped = True # skipped due to annotation missing - self._field_name = name - self._private_name = f"_{owner.__name__}_{name}" - - @property - def name(self) -> str: - return self._field_name - - @property - def index(self) -> int: - return self._index - - def check_type(self, value: Any) -> bool: - return self.type_checker(value) - - -@dataclass_transform() -class ORMeta(type): - """This metaclass is for generating customized .""" - - def __new__(cls, cls_name: str, bases: Tuple[type, ...], classdict: Dict[str, Any]): - new_cls: type = super().__new__(cls, cls_name, bases, classdict) - if len(new_cls.__mro__) > 2: # , ORMBase, object - # we will define our own eq and hash logics, disable dataclass' - # eq method and hash method generation - return dataclass(eq=False, unsafe_hash=False)(new_cls) - else: # ORMBase, object - return new_cls - - -class ORMBase(metaclass=ORMeta): - """Base class for defining a sqlite3 table programatically. - - Subclass of this base class is also a subclass of dataclass. - """ - - @classmethod - def row_to_meta(cls, row: "Union[sqlite3.Row, Dict[str, Any], Tuple[Any]]") -> Self: - parsed = {} - for field in fields(cls): - try: - col: ColumnDescriptor = getattr(cls, field.name) - if isinstance(row, tuple): - parsed[col.name] = row[col.index] - else: - parsed[col.name] = row[col.name] - except (IndexError, KeyError): - continue # silently ignore unknonw input fields - return cls(**parsed) - - @classmethod - def get_create_table_stmt(cls, table_name: str) -> str: - """Generate the sqlite query statement to create the defined table in database. - - Args: - table_name: the name of table to be created - - Returns: - query statement to create the table defined by this class. 
- """ - _col_descriptors: List[ColumnDescriptor] = [ - getattr(cls, field.name) for field in fields(cls) - ] - with StringIO() as buffer: - buffer.write(f"CREATE TABLE {table_name}") - buffer.write("(") - buffer.write( - ", ".join([f"{col.name} {col.constrains}" for col in _col_descriptors]) - ) - buffer.write(")") - return buffer.getvalue() - - @classmethod - def contains_field(cls, _input: Union[str, ColumnDescriptor]) -> bool: - """Check if this table contains field indicated by <_input>.""" - if isinstance(_input, ColumnDescriptor): - return _input.owner.__name__ == cls.__name__ - return isinstance(getattr(cls, _input), ColumnDescriptor) - - @classmethod - def get_shape(cls) -> str: - """Used by insert row query.""" - return ",".join(["?"] * len(fields(cls))) - - def __hash__(self) -> int: - """compute the hash with all stored fields' value.""" - return hash(astuple(self)) - - def __eq__(self, __o: object) -> bool: - if not isinstance(__o, self.__class__): - return False - for field in fields(self): - field_name = field.name - if getattr(self, field_name) != getattr(__o, field_name): - return False - return True - - def astuple(self) -> Tuple[SQLITE_DATATYPES, ...]: - return astuple(self) - - def asdict(self) -> Dict[str, SQLITE_DATATYPES]: - return asdict(self) diff --git a/src/ota_proxy/ota_cache.py b/src/ota_proxy/ota_cache.py index 7a4920eb1..ad7fded61 100644 --- a/src/ota_proxy/ota_cache.py +++ b/src/ota_proxy/ota_cache.py @@ -16,56 +16,32 @@ from __future__ import annotations import asyncio -import bisect -import contextlib import logging -import os import shutil import threading import time -import weakref -from concurrent.futures import Executor, ThreadPoolExecutor +from concurrent.futures import ThreadPoolExecutor from pathlib import Path -from typing import ( - AsyncGenerator, - AsyncIterator, - Callable, - Coroutine, - Dict, - Generic, - List, - Mapping, - MutableMapping, - Optional, - Tuple, - TypeVar, - Union, -) +from typing import AsyncIterator, Dict, List, Mapping, Optional, Tuple from urllib.parse import SplitResult, quote, urlsplit -import aiofiles import aiohttp from multidict import CIMultiDictProxy +from otaclient_common.typing import StrOrPath + from ._consts import HEADER_CONTENT_ENCODING, HEADER_OTA_FILE_CACHE_CONTROL -from .cache_control import OTAFileCacheControl +from .cache_control_header import OTAFileCacheControl +from .cache_streaming import CachingRegister, cache_streaming from .config import config as cfg -from .db import AIO_OTACacheDBProxy, CacheMeta, OTACacheDB -from .errors import ( - BaseOTACacheError, - CacheMultiStreamingFailed, - CacheStreamingFailed, - CacheStreamingInterrupt, - StorageReachHardLimit, -) -from .utils import read_file, url_based_hash, wait_with_backoff +from .db import CacheMeta, check_db, init_db +from .errors import BaseOTACacheError +from .lru_cache_helper import LRUCacheHelper +from .utils import read_file, url_based_hash logger = logging.getLogger(__name__) -_WEAKREF = TypeVar("_WEAKREF") - - # helper functions @@ -86,490 +62,22 @@ def create_cachemeta_for_request( compression_alg: pre-collected information from caller resp_headers_from_upper """ - cache_meta = CacheMeta( - url=raw_url, - content_encoding=resp_headers_from_upper.get(HEADER_CONTENT_ENCODING, ""), - ) - _upper_cache_policy = OTAFileCacheControl.parse_header( resp_headers_from_upper.get(HEADER_OTA_FILE_CACHE_CONTROL, "") ) if _upper_cache_policy.file_sha256: - cache_meta.file_sha256 = _upper_cache_policy.file_sha256 - cache_meta.file_compression_alg = 
_upper_cache_policy.file_compression_alg + file_sha256 = _upper_cache_policy.file_sha256 + file_compression_alg = _upper_cache_policy.file_compression_alg or None else: - cache_meta.file_sha256 = cache_identifier - cache_meta.file_compression_alg = compression_alg - return cache_meta - - -# cache tracker - - -class CacheTracker(Generic[_WEAKREF]): - """A tracker for an ongoing cache entry. - - This tracker represents an ongoing cache entry under the , - and takes care of the life cycle of this temp cache entry. - It implements the provider/subscriber model for cache writing and cache streaming to - multiple clients. - - This entry will disappear automatically, ensured by gc and weakref - when no strong reference to this tracker(provider finished and no subscribers - attached to this tracker). - When this tracker is garbage collected, the corresponding temp cache entry will - also be removed automatically(ensure by registered finalizer). - - Attributes: - fpath: the path to the temporary cache file. - save_path: the path to finished cache file. - meta: an inst of CacheMeta for the remote OTA file that being cached. - writer_ready: a property indicates whether the provider is - ready to write and streaming data chunks. - writer_finished: a property indicates whether the provider finished the caching. - writer_failed: a property indicates whether provider fails to - finish the caching. - """ - - READER_SUBSCRIBE_WAIT_PROVIDER_TIMEOUT = 2 - FNAME_PART_SEPARATOR = "_" - - @classmethod - def _tmp_file_naming(cls, cache_identifier: str) -> str: - """Create fname for tmp caching entry. - - naming scheme: tmp__ - - NOTE: append 4bytes hex to identify cache entry for the same OTA file between - different trackers. - """ - return ( - f"{cfg.TMP_FILE_PREFIX}{cls.FNAME_PART_SEPARATOR}" - f"{cache_identifier}{cls.FNAME_PART_SEPARATOR}{(os.urandom(4)).hex()}" - ) - - def __init__( - self, - cache_identifier: str, - ref_holder: _WEAKREF, - *, - base_dir: Union[str, Path], - executor: Executor, - callback: _CACHE_ENTRY_REGISTER_CALLBACK, - below_hard_limit_event: threading.Event, - ): - self.fpath = Path(base_dir) / self._tmp_file_naming(cache_identifier) - self.meta: CacheMeta = None # type: ignore[assignment] - self.cache_identifier = cache_identifier - self.save_path = Path(base_dir) / cache_identifier - self._writer_ready = asyncio.Event() - self._writer_finished = asyncio.Event() - self._writer_failed = asyncio.Event() - self._ref = ref_holder - self._subscriber_ref_holder: List[_WEAKREF] = [] - self._executor = executor - self._cache_commit_cb = callback - self._space_availability_event = below_hard_limit_event - - self._bytes_written = 0 - self._cache_write_gen: Optional[AsyncGenerator[int, bytes]] = None - - # self-register the finalizer to this tracker - weakref.finalize( - self, - self.finalizer, - fpath=self.fpath, - ) - - async def _finalize_caching(self): - """Commit cache entry to db and rename tmp cached file with sha256hash - to fialize the caching.""" - # if the file with the same sha256has is already presented, skip the hardlink - # NOTE: no need to clean the tmp file, it will be done by the cache tracker. 
- await self._cache_commit_cb(self.meta) - if not self.save_path.is_file(): - self.fpath.link_to(self.save_path) - - @staticmethod - def finalizer(*, fpath: Union[str, Path]): - """Finalizer that cleans up the tmp file when this tracker is gced.""" - Path(fpath).unlink(missing_ok=True) - - @property - def writer_failed(self) -> bool: - return self._writer_failed.is_set() - - @property - def writer_finished(self) -> bool: - return self._writer_finished.is_set() - - @property - def is_cache_valid(self) -> bool: - """Indicates whether the temp cache entry for this tracker is valid.""" - return ( - self.meta is not None - and self._writer_finished.is_set() - and not self._writer_failed.is_set() - ) - - def get_cache_write_gen(self) -> Optional[AsyncGenerator[int, bytes]]: - if self._cache_write_gen: - return self._cache_write_gen - logger.warning(f"tracker for {self.meta} is not ready, skipped") - - async def _provider_write_cache(self) -> AsyncGenerator[int, bytes]: - """Provider writes data chunks from upper caller to tmp cache file. - - If cache writing failed, this method will exit and tracker.writer_failed and - tracker.writer_finished will be set. - """ - logger.debug(f"start to cache for {self.meta=}...") - try: - if not self.meta: - raise ValueError("called before provider tracker is ready, abort") - - async with aiofiles.open(self.fpath, "wb", executor=self._executor) as f: - # let writer become ready when file is open successfully - self._writer_ready.set() - - _written = 0 - while _data := (yield _written): - if not self._space_availability_event.is_set(): - logger.warning( - f"abort writing cache for {self.meta=}: {StorageReachHardLimit.__name__}" - ) - self._writer_failed.set() - return - - _written = await f.write(_data) - self._bytes_written += _written - - logger.debug( - "cache write finished, total bytes written" - f"({self._bytes_written}) for {self.meta=}" - ) - self.meta.cache_size = self._bytes_written - - # NOTE: no need to track the cache commit, - # as writer_finish is meant to set on cache file created. - asyncio.create_task(self._finalize_caching()) - except Exception as e: - logger.exception(f"failed to write cache for {self.meta=}: {e!r}") - self._writer_failed.set() - finally: - self._writer_finished.set() - self._ref = None - - async def _subscribe_cache_streaming(self) -> AsyncIterator[bytes]: - """Subscriber keeps polling chunks from ongoing tmp cache file. - - Subscriber will keep polling until the provider fails or - provider finished and subscriber has read bytes. - - Raises: - CacheMultipleStreamingFailed if provider failed or timeout reading - data chunk from tmp cache file(might be caused by a dead provider). 
- """ - try: - err_count, _bytes_read = 0, 0 - async with aiofiles.open(self.fpath, "rb", executor=self._executor) as f: - while not (self.writer_finished and _bytes_read == self._bytes_written): - if self.writer_failed: - raise CacheMultiStreamingFailed( - f"provider aborted for {self.meta}" - ) - _bytes_read += len(_chunk := await f.read(cfg.CHUNK_SIZE)) - if _chunk: - err_count = 0 - yield _chunk - continue - - err_count += 1 - if not await wait_with_backoff( - err_count, - _backoff_factor=cfg.STREAMING_BACKOFF_FACTOR, - _backoff_max=cfg.STREAMING_BACKOFF_MAX, - ): - # abort caching due to potential dead streaming coro - _err_msg = f"failed to stream({self.meta=}): timeout getting data, partial read might happen" - logger.error(_err_msg) - # signal streamer to stop streaming - raise CacheMultiStreamingFailed(_err_msg) - finally: - # unsubscribe on finish - self._subscriber_ref_holder.pop() - - async def _read_cache(self) -> AsyncIterator[bytes]: - """Directly open the tmp cache entry and yield data chunks from it. - - Raises: - CacheMultipleStreamingFailed if fails to read from the - cached file, this might indicate a partial written cache file. - """ - _bytes_read, _retry_count = 0, 0 - async with aiofiles.open(self.fpath, "rb", executor=self._executor) as f: - while _bytes_read < self._bytes_written: - if _data := await f.read(cfg.CHUNK_SIZE): - _retry_count = 0 - _bytes_read += len(_data) - yield _data - continue - - # no data is read from the cache entry, - # retry sometimes to ensure all data is acquired - _retry_count += 1 - if not await wait_with_backoff( - _retry_count, - _backoff_factor=cfg.STREAMING_BACKOFF_FACTOR, - _backoff_max=cfg.STREAMING_CACHED_TMP_TIMEOUT, - ): - # abort caching due to potential dead streaming coro - _err_msg = ( - f"open_cached_tmp failed for ({self.meta=}): " - "timeout getting more data, partial written cache file detected" - ) - logger.debug(_err_msg) - # signal streamer to stop streaming - raise CacheMultiStreamingFailed(_err_msg) - - # exposed API - - async def provider_start(self, meta: CacheMeta): - """Register meta to the Tracker, create tmp cache entry and get ready. - - Check _provider_write_cache for more details. - - Args: - meta: inst of CacheMeta for the requested file tracked by this tracker. - This meta is created by open_remote() method. 
- """ - self.meta = meta - self._cache_write_gen = self._provider_write_cache() - await self._cache_write_gen.asend(None) # type: ignore - - async def provider_on_finished(self): - if not self.writer_finished and self._cache_write_gen: - with contextlib.suppress(StopAsyncIteration): - await self._cache_write_gen.asend(b"") - self._writer_finished.set() - self._ref = None - - async def provider_on_failed(self): - """Manually fail and stop the caching.""" - if not self.writer_finished and self._cache_write_gen: - logger.warning(f"interrupt writer coroutine for {self.meta=}") - with contextlib.suppress(StopAsyncIteration, CacheStreamingInterrupt): - await self._cache_write_gen.athrow(CacheStreamingInterrupt) - - self._writer_failed.set() - self._writer_finished.set() - self._ref = None - - async def subscriber_subscribe_tracker(self) -> Optional[AsyncIterator[bytes]]: - """Reader subscribe this tracker and get a file descriptor to get data chunks.""" - _wait_count = 0 - while not self._writer_ready.is_set(): - _wait_count += 1 - if self.writer_failed or not await wait_with_backoff( - _wait_count, - _backoff_factor=cfg.STREAMING_BACKOFF_FACTOR, - _backoff_max=self.READER_SUBSCRIBE_WAIT_PROVIDER_TIMEOUT, - ): - return # timeout waiting for provider to become ready - - # subscribe on an ongoing cache - if not self.writer_finished and isinstance(self._ref, _Weakref): - self._subscriber_ref_holder.append(self._ref) - return self._subscribe_cache_streaming() - # caching just finished, try to directly read the finished cache entry - elif self.is_cache_valid: - return self._read_cache() - - -# a callback that register the cache entry indicates by input CacheMeta inst to the cache_db -_CACHE_ENTRY_REGISTER_CALLBACK = Callable[[CacheMeta], Coroutine[None, None, None]] - - -class _Weakref: - pass - - -class CachingRegister: - """A tracker register that manages cache trackers. - - For each ongoing caching for unique OTA file, there will be only one unique identifier for it. - - This first caller that requests with the identifier will become the provider and create - a new tracker for this identifier. - The later comes callers will become the subscriber to this tracker. - """ - - def __init__(self, base_dir: Union[str, Path]): - self._base_dir = Path(base_dir) - self._id_ref_dict: MutableMapping[str, _Weakref] = weakref.WeakValueDictionary() - self._ref_tracker_dict: MutableMapping[_Weakref, CacheTracker] = ( - weakref.WeakKeyDictionary() - ) - - async def get_tracker( - self, - cache_identifier: str, - *, - executor: Executor, - callback: _CACHE_ENTRY_REGISTER_CALLBACK, - below_hard_limit_event: threading.Event, - ) -> Tuple[CacheTracker, bool]: - """Get an inst of CacheTracker for the cache_identifier. - - Returns: - An inst of tracker, and a bool indicates the caller is provider(True), - or subscriber(False). 
- """ - _new_ref = _Weakref() - _ref = self._id_ref_dict.setdefault(cache_identifier, _new_ref) - - # subscriber - if ( - _tracker := self._ref_tracker_dict.get(_ref) - ) and not _tracker.writer_failed: - return _tracker, False - - # provider, or override a failed provider - if _ref is not _new_ref: # override a failed tracker - self._id_ref_dict[cache_identifier] = _new_ref - _ref = _new_ref - - _tracker = CacheTracker( - cache_identifier, - _ref, - base_dir=self._base_dir, - executor=executor, - callback=callback, - below_hard_limit_event=below_hard_limit_event, - ) - self._ref_tracker_dict[_ref] = _tracker - return _tracker, True - - -class LRUCacheHelper: - """A helper class that provides API for accessing/managing cache entries in ota cachedb. - - Serveral buckets are created according to predefined file size threshould. - Each bucket will maintain the cache entries of that bucket's size definition, - LRU is applied on per-bucket scale. - - NOTE: currently entry in first bucket and last bucket will skip LRU rotate. - """ - - BSIZE_LIST = list(cfg.BUCKET_FILE_SIZE_DICT.keys()) - BSIZE_DICT = cfg.BUCKET_FILE_SIZE_DICT - - def __init__(self, db_f: Union[str, Path]): - self._db = AIO_OTACacheDBProxy(db_f) - self._closed = False - - def close(self): - if not self._closed: - self._db.close() - - async def commit_entry(self, entry: CacheMeta) -> bool: - """Commit cache entry meta to the database.""" - # populate bucket and last_access column - entry.bucket_idx = bisect.bisect_right(self.BSIZE_LIST, entry.cache_size) - 1 - entry.last_access = int(time.time()) - - if (await self._db.insert_entry(entry)) != 1: - logger.error(f"db: failed to add {entry=}") - return False - return True - - async def lookup_entry(self, file_sha256: str) -> Optional[CacheMeta]: - return await self._db.lookup_entry(CacheMeta.file_sha256, file_sha256) - - async def remove_entry(self, file_sha256: str) -> bool: - return (await self._db.remove_entries(CacheMeta.file_sha256, file_sha256)) > 0 - - async def rotate_cache(self, size: int) -> Optional[List[str]]: - """Wrapper method for calling the database LRU cache rotating method. - - Args: - size int: the size of file that we want to reserve space for - - Returns: - A list of hashes that needed to be cleaned, or empty list if rotation - is not required, or None if cache rotation cannot be executed. - """ - # NOTE: currently item size smaller than 1st bucket and larger than latest bucket - # will be saved without cache rotating. - if size >= self.BSIZE_LIST[-1] or size < self.BSIZE_LIST[1]: - return [] - - _cur_bucket_idx = bisect.bisect_right(self.BSIZE_LIST, size) - 1 - _cur_bucket_size = self.BSIZE_LIST[_cur_bucket_idx] - - # first check the upper bucket, remove 1 item from any of the - # upper bucket is enough. - for _bucket_idx in range(_cur_bucket_idx + 1, len(self.BSIZE_LIST)): - if res := await self._db.rotate_cache(_bucket_idx, 1): - return res - # if cannot find one entry at any upper bucket, check current bucket - return await self._db.rotate_cache( - _cur_bucket_idx, self.BSIZE_DICT[_cur_bucket_size] - ) - - -async def cache_streaming( - fd: AsyncIterator[bytes], - meta: CacheMeta, - tracker: CacheTracker, -) -> AsyncIterator[bytes]: - """A cache streamer that get data chunk from and tees to multiple destination. - - Data chunk yielded from will be teed to: - 1. upper uvicorn otaproxy APP to send back to client, - 2. cache_tracker cache_write_gen for caching. - - Args: - fd: opened connection to a remote file. - meta: meta data of the requested resource. 
- tracker: an inst of ongoing cache tracker bound to this request. - - Returns: - A bytes async iterator to yield data chunk from, for upper otaproxy uvicorn APP. - - Raises: - CacheStreamingFailed if any exception happens during retrieving. - """ + file_sha256 = cache_identifier + file_compression_alg = compression_alg - async def _inner(): - _cache_write_gen = tracker.get_cache_write_gen() - try: - # tee the incoming chunk to two destinations - async for chunk in fd: - # NOTE: for aiohttp, when HTTP chunk encoding is enabled, - # an empty chunk will be sent to indicate the EOF of stream, - # we MUST handle this empty chunk. - if not chunk: # skip if empty chunk is read from remote - continue - # to caching generator - if _cache_write_gen and not tracker.writer_finished: - try: - await _cache_write_gen.asend(chunk) - except Exception as e: - await tracker.provider_on_failed() # signal tracker - logger.error( - f"cache write coroutine failed for {meta=}, abort caching: {e!r}" - ) - # to uvicorn thread - yield chunk - await tracker.provider_on_finished() - except Exception as e: - logger.exception(f"cache tee failed for {meta=}") - await tracker.provider_on_failed() - raise CacheStreamingFailed from e - - await tracker.provider_start(meta) - return _inner() + return CacheMeta( + file_sha256=file_sha256, + file_compression_alg=file_compression_alg, + url=raw_url, + content_encoding=resp_headers_from_upper.get(HEADER_CONTENT_ENCODING), + ) class OTACache: @@ -595,8 +103,8 @@ def __init__( *, cache_enabled: bool, init_cache: bool, - base_dir: Optional[Union[str, Path]] = None, - db_file: Optional[Union[str, Path]] = None, + base_dir: Optional[StrOrPath] = None, + db_file: Optional[StrOrPath] = None, upper_proxy: str = "", enable_https: bool = False, external_cache: Optional[str] = None, @@ -606,14 +114,26 @@ def __init__( f"init ota_cache({cache_enabled=}, {init_cache=}, {upper_proxy=}, {enable_https=})" ) self._closed = True + self._shutdown_lock = asyncio.Lock() + self.table_name = table_name = cfg.TABLE_NAME self._chunk_size = cfg.CHUNK_SIZE - self._base_dir = Path(base_dir) if base_dir else Path(cfg.BASE_DIR) - self._db_file = Path(db_file) if db_file else Path(cfg.DB_FILE) self._cache_enabled = cache_enabled self._init_cache = init_cache self._enable_https = enable_https - self._executor = ThreadPoolExecutor(thread_name_prefix="ota_cache_executor") + + self._base_dir = Path(base_dir) if base_dir else Path(cfg.BASE_DIR) + self._db_file = db_f = Path(db_file) if db_file else Path(cfg.DB_FILE) + + self._base_dir.mkdir(parents=True, exist_ok=True) + if not check_db(self._db_file, table_name): + logger.info(f"db file is broken, force init db file at {db_f}") + db_f.unlink(missing_ok=True) + self._init_cache = True # force init cache on db file cleanup + + self._executor = ThreadPoolExecutor( + thread_name_prefix="ota_cache_fileio_executor" + ) if external_cache and cache_enabled: logger.info(f"external cache source is enabled at: {external_cache}") @@ -623,14 +143,6 @@ def __init__( self._storage_below_soft_limit_event = threading.Event() self._upper_proxy = upper_proxy - def _check_cache_db(self) -> bool: - """Check ota_cache can be reused or not.""" - return ( - self._base_dir.is_dir() - and self._db_file.is_file() - and OTACacheDB.check_db_file(self._db_file) - ) - async def start(self): """Start the ota_cache instance.""" # silently ignore multi launching of ota_cache @@ -657,15 +169,15 @@ async def start(self): if self._cache_enabled: # purge cache dir if requested(init_cache=True) or 
ota_cache invalid, # and then recreate the cache folder and cache db file. - db_f_valid = self._check_cache_db() - if self._init_cache or not db_f_valid: - logger.warning( - f"purge and init ota_cache: {self._init_cache=}, {db_f_valid}" - ) + if self._init_cache: + logger.warning("purge and init ota_cache") shutil.rmtree(str(self._base_dir), ignore_errors=True) self._base_dir.mkdir(exist_ok=True, parents=True) # init db file with table created - OTACacheDB.init_db_file(self._db_file) + self._db_file.unlink(missing_ok=True) + + init_db(self._db_file, cfg.TABLE_NAME) + # reuse the previously left ota_cache else: # cleanup unfinished tmp files for tmp_f in self._base_dir.glob(f"{cfg.TMP_FILE_PREFIX}*"): @@ -675,7 +187,13 @@ async def start(self): self._executor.submit(self._background_check_free_space) # init cache helper(and connect to ota_cache db) - self._lru_helper = LRUCacheHelper(self._db_file) + self._lru_helper = LRUCacheHelper( + self._db_file, + bsize_dict=cfg.BUCKET_FILE_SIZE_DICT, + table_name=cfg.TABLE_NAME, + thread_nums=cfg.DB_THREADS, + thread_wait_timeout=cfg.DB_THREAD_WAIT_TIMEOUT, + ) self._on_going_caching = CachingRegister(self._base_dir) if self._upper_proxy: @@ -691,11 +209,14 @@ async def close(self): performed by the OTACache. """ logger.debug("shutdown ota-cache...") - if self._cache_enabled and not self._closed: - self._closed = True - await self._session.close() - self._lru_helper.close() - self._executor.shutdown(wait=True) + async with self._shutdown_lock: + if not self._closed: + self._closed = True + await self._session.close() + self._executor.shutdown(wait=True) + + if self._cache_enabled: + self._lru_helper.close() logger.info("shutdown ota-cache completed") @@ -894,6 +415,15 @@ async def _retrieve_file_by_cache( return # check if cache file exists + # NOTE(20240729): there is an edge condition that the finished cached file is not yet renamed, + # but the database entry has already been inserted. Here we wait for 3 rounds for + # cache_commit_callback to rename the tmp file. + _retry_count_max, _factor = 3, 0.01 + for _retry_count in range(_retry_count_max): + if cache_file.is_file(): + break + await asyncio.sleep(_factor * _retry_count) # give away the event loop + if not cache_file.is_file(): logger.warning( f"dangling cache entry found, remove db entry: {meta_db_entry}" @@ -999,9 +529,9 @@ async def retrieve_file( # pre-calculated cache_identifier and corresponding compression_alg cache_identifier = cache_policy.file_sha256 compression_alg = cache_policy.file_compression_alg - if ( - not cache_identifier - ): # fallback to use URL based hash, and clear compression_alg + + # fallback to use URL based hash, and clear compression_alg + if not cache_identifier: cache_identifier = url_based_hash(raw_url) compression_alg = "" diff --git a/tests/test_ota_proxy/test_cachedb.py b/tests/test_ota_proxy/test_cachedb.py deleted file mode 100644 index 4f5992ee4..000000000 --- a/tests/test_ota_proxy/test_cachedb.py +++ /dev/null @@ -1,263 +0,0 @@ -# Copyright 2022 TIER IV, INC. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - - -import logging -import sqlite3 -from dataclasses import dataclass -from os import urandom -from pathlib import Path -from typing import Any, Dict, Tuple - -import pytest - -from ota_proxy import config as cfg -from ota_proxy.orm import NULL_TYPE -from ota_proxy.ota_cache import CacheMeta, OTACacheDB -from ota_proxy.utils import url_based_hash - -logger = logging.getLogger(__name__) - - -class TestORM: - @pytest.fixture(autouse=True) - def create_table_defs(self): - from ota_proxy.orm import NULL_TYPE, ColumnDescriptor, ORMBase - - @dataclass - class TableCls(ORMBase): - str_field: ColumnDescriptor[str] = ColumnDescriptor( - str, - "TEXT", - "UNIQUE", - "NOT NULL", - "PRIMARY KEY", - default="invalid_url", - ) - int_field: ColumnDescriptor[int] = ColumnDescriptor( - int, "INTEGER", "NOT NULL", type_guard=(int, float) - ) - float_field: ColumnDescriptor[float] = ColumnDescriptor( - float, "INTEGER", "NOT NULL", type_guard=(int, float) - ) - op_str_field: ColumnDescriptor[str] = ColumnDescriptor(str, "TEXT") - op_int_field: ColumnDescriptor[int] = ColumnDescriptor( - int, "INTEGER", type_guard=(int, float) - ) - null_field: ColumnDescriptor[NULL_TYPE] = ColumnDescriptor( - NULL_TYPE, "NULL" - ) - - self.table_cls = TableCls - - @pytest.mark.parametrize( - "raw_row, as_dict, as_tuple", - ( - ( - { - "str_field": "unique_str", - "int_field": 123.123, # expect to be converted into int - "float_field": 456.123, - "null_field": "should_not_be_set", - }, - { - "str_field": "unique_str", - "int_field": 123, - "float_field": 456.123, - "op_str_field": "", - "op_int_field": 0, - "null_field": NULL_TYPE(), - }, - ("unique_str", 123, 456.123, "", 0, NULL_TYPE()), - ), - ), - ) - def test_parse_and_export( - self, raw_row, as_dict: Dict[str, Any], as_tuple: Tuple[Any] - ): - table_cls = self.table_cls - parsed = table_cls.row_to_meta(raw_row) - assert parsed.asdict() == as_dict - assert parsed.astuple() == as_tuple - assert table_cls.row_to_meta(parsed.astuple()).asdict() == as_dict - - @pytest.mark.parametrize( - "table_name, expected", - ( - ( - "table_name", - ( - "CREATE TABLE table_name(" - "str_field TEXT UNIQUE NOT NULL PRIMARY KEY, " - "int_field INTEGER NOT NULL, " - "float_field INTEGER NOT NULL, " - "op_str_field TEXT, " - "op_int_field INTEGER, " - "null_field NULL)" - ), - ), - ), - ) - def test_get_create_table_stmt(self, table_name: str, expected: str): - assert self.table_cls.get_create_table_stmt(table_name) == expected - - @pytest.mark.parametrize( - "name", - ( - "str_field", - "int_field", - "float_field", - "op_str_field", - "op_int_field", - "null_field", - ), - ) - def test_contains_field(self, name: str): - table_cls = self.table_cls - assert (col_descriptor := getattr(table_cls, name)) - assert table_cls.contains_field(name) - assert col_descriptor is table_cls.__dict__[name] - assert col_descriptor and table_cls.contains_field(col_descriptor) - - def test_type_check(self): - inst = self.table_cls() - # field float_field is type_checked - inst.float_field = 123.456 - with pytest.raises(TypeError): - inst.float_field = "str_type" - - @pytest.mark.parametrize( - "row_dict", - ( - { - "str_field": "unique_str", - "int_field": 123.9, - "float_field": 456, - }, - ), - ) - def test_with_actual_db(self, row_dict: Dict[str, Any]): - """Setup a new conn to a in-memory otacache_db.""" - table_cls, table_name = self.table_cls, "test_table" - conn = sqlite3.connect(":memory:") - 
conn.row_factory = sqlite3.Row - with conn: # test create table with table_cls - conn.execute(table_cls.get_create_table_stmt(table_name)) - - row_inst = table_cls.row_to_meta(row_dict) - logger.info(row_inst) - with conn: # insert one entry - assert ( - conn.execute( - f"INSERT INTO {table_name} VALUES ({table_cls.get_shape()})", - row_inst.astuple(), - ).rowcount - == 1 - ) - with conn: # get the entry back - cur = conn.execute(f"SELECT * from {table_name}", ()) - row_parsed = table_cls.row_to_meta(cur.fetchone()) - assert row_parsed == row_inst - assert row_parsed.asdict() == row_inst.asdict() - - conn.close() - - -class TestOTACacheDB: - @pytest.fixture(autouse=True) - def prepare_db(self, tmp_path: Path): - self.db_f = db_f = tmp_path / "db_f" - self.entries = entries = [] - - # prepare data - for target_size, rotate_num in cfg.BUCKET_FILE_SIZE_DICT.items(): - for _i in range(rotate_num): - mocked_url = f"{target_size}#{_i}" - entries.append( - CacheMeta( - file_sha256=url_based_hash(mocked_url), - url=mocked_url, - bucket_idx=target_size, - cache_size=target_size, - ) - ) - # insert entry into db - OTACacheDB.init_db_file(db_f) - with OTACacheDB(db_f) as db_conn: - db_conn.insert_entry(*self.entries) - - # prepare connection - try: - self.conn = OTACacheDB(self.db_f) - yield - finally: - self.conn.close() - - def test_insert(self, tmp_path: Path): - db_f = tmp_path / "db_f2" - # first insert - OTACacheDB.init_db_file(db_f) - with OTACacheDB(db_f) as db_conn: - assert db_conn.insert_entry(*self.entries) == len(self.entries) - # check the insertion with another connection - with OTACacheDB(db_f) as db_conn: - _all = db_conn.lookup_all() - # should have the same order as insertion - assert _all == self.entries - - def test_db_corruption(self, tmp_path: Path): - test_db_f = tmp_path / "corrupted_db_f" - # intensionally create corrupted db file - with open(self.db_f, "rb") as src, open(test_db_f, "wb") as dst: - count = 0 - while data := src.read(32): - count += 1 - dst.write(data) - if count % 2 == 0: - dst.write(urandom(8)) - - assert not OTACacheDB.check_db_file(test_db_f) - - def test_lookup_all(self): - """ - NOTE: the timestamp update is only executed at lookup method - """ - entries_set = set(self.entries) - checked_entries = self.conn.lookup_all() - assert entries_set == set(checked_entries) - - def test_lookup(self): - """ - lookup the one entry in the database, and ensure the timestamp is updated - """ - target = self.entries[-1] - # lookup once to update last_acess - checked_entry = self.conn.lookup_entry(CacheMeta.url, target.url) - assert checked_entry == target - checked_entry = self.conn.lookup_entry(CacheMeta.url, target.url) - assert checked_entry and checked_entry.last_access > target.last_access - - def test_delete(self): - """ - delete the whole 8MiB bucket - """ - bucket_size = 8 * (1024**2) - assert ( - self.conn.remove_entries(CacheMeta.bucket_idx, bucket_size) - == cfg.BUCKET_FILE_SIZE_DICT[bucket_size] - ) - assert ( - len(self.conn.lookup_all()) - == len(self.entries) - cfg.BUCKET_FILE_SIZE_DICT[bucket_size] - ) diff --git a/tests/test_ota_proxy/test_ota_cache.py b/tests/test_ota_proxy/test_ota_cache.py index 26dec048c..734911533 100644 --- a/tests/test_ota_proxy/test_ota_cache.py +++ b/tests/test_ota_proxy/test_ota_cache.py @@ -13,93 +13,141 @@ # limitations under the License. 
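The deleted corruption test above exercised OTACacheDB.check_db_file; a standalone equivalent built only on sqlite3 PRAGMAs (a sketch, not the actual check_db helper used by this patch) could look like:

import sqlite3

def db_looks_usable(db_file: str, table_name: str) -> bool:
    """Return True if db_file passes SQLite's integrity check and contains table_name."""
    conn = sqlite3.connect(db_file)
    try:
        integrity, = conn.execute("PRAGMA integrity_check").fetchone()
        if integrity != "ok":
            return False
        tables = {
            row[0]
            for row in conn.execute("SELECT name FROM sqlite_master WHERE type='table'")
        }
        return table_name in tables
    except sqlite3.DatabaseError:  # e.g. "file is not a database" on a corrupted file
        return False
    finally:
        conn.close()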
+from __future__ import annotations + import asyncio +import bisect import logging import random +import sqlite3 from pathlib import Path -from typing import Coroutine, Dict, List, Optional, Tuple +from typing import Coroutine, Optional import pytest +import pytest_asyncio +from simple_sqlite3_orm import ORMBase from ota_proxy import config as cfg -from ota_proxy.db import CacheMeta, OTACacheDB +from ota_proxy.db import CacheMeta from ota_proxy.ota_cache import CachingRegister, LRUCacheHelper from ota_proxy.utils import url_based_hash logger = logging.getLogger(__name__) +TEST_DATA_SET_SIZE = 4096 +TEST_LOOKUP_ENTRIES = 1200 +TEST_DELETE_ENTRIES = 512 -class TestLRUCacheHelper: - @pytest.fixture(scope="class") - def prepare_entries(self): - entries: Dict[str, CacheMeta] = {} - for target_size, rotate_num in cfg.BUCKET_FILE_SIZE_DICT.items(): - for _i in range(rotate_num): - mocked_url = f"{target_size}#{_i}" - entries[mocked_url] = CacheMeta( - url=mocked_url, - bucket_idx=target_size, - cache_size=target_size, - file_sha256=url_based_hash(mocked_url), - ) - return entries +class OTACacheDB(ORMBase[CacheMeta]): + pass + + +@pytest.fixture(autouse=True, scope="module") +def setup_testdata() -> dict[str, CacheMeta]: + size_list = list(cfg.BUCKET_FILE_SIZE_DICT) + + entries: dict[str, CacheMeta] = {} + for idx in range(TEST_DATA_SET_SIZE): + target_size = random.choice(size_list) + mocked_url = f"#{idx}/w/targetsize={target_size}" + file_sha256 = url_based_hash(mocked_url) + + entries[file_sha256] = CacheMeta( + url=mocked_url, + # see lru_cache_helper module for more details + bucket_idx=bisect.bisect_right(size_list, target_size), + cache_size=target_size, + file_sha256=file_sha256, + ) + return entries + + +@pytest.fixture(autouse=True, scope="module") +def entries_to_lookup(setup_testdata: dict[str, CacheMeta]) -> list[CacheMeta]: + return random.sample( + list(setup_testdata.values()), + k=TEST_LOOKUP_ENTRIES, + ) + - @pytest.fixture(scope="class") - def launch_lru_helper(self, tmp_path_factory: pytest.TempPathFactory): - # init db +@pytest.fixture(autouse=True, scope="module") +def entries_to_remove(setup_testdata: dict[str, CacheMeta]) -> list[CacheMeta]: + return random.sample( + list(setup_testdata.values()), + k=TEST_DELETE_ENTRIES, + ) + + +@pytest.mark.asyncio(scope="class") +class TestLRUCacheHelper: + + @pytest_asyncio.fixture(autouse=True, scope="class") + async def lru_helper(self, tmp_path_factory: pytest.TempPathFactory): ota_cache_folder = tmp_path_factory.mktemp("ota-cache") - self._db_f = ota_cache_folder / "db_f" - OTACacheDB.init_db_file(self._db_f) + db_f = ota_cache_folder / "db_f" + + # init table + conn = sqlite3.connect(db_f) + orm = OTACacheDB(conn, cfg.TABLE_NAME) + orm.orm_create_table(without_rowid=True) + conn.close() - lru_cache_helper = LRUCacheHelper(self._db_f) + lru_cache_helper = LRUCacheHelper( + db_f, + bsize_dict=cfg.BUCKET_FILE_SIZE_DICT, + table_name=cfg.TABLE_NAME, + thread_nums=cfg.DB_THREADS, + thread_wait_timeout=cfg.DB_THREAD_WAIT_TIMEOUT, + ) try: yield lru_cache_helper finally: lru_cache_helper.close() - @pytest.fixture(autouse=True) - def setup_test(self, launch_lru_helper, prepare_entries): - self.entries: Dict[str, CacheMeta] = prepare_entries - self.cache_helper: LRUCacheHelper = launch_lru_helper + async def test_commit_entry( + self, lru_helper: LRUCacheHelper, setup_testdata: dict[str, CacheMeta] + ): + for _, entry in setup_testdata.items(): + # deliberately clear the bucket_idx, this should be set by commit_entry method + _copy = 
entry.model_copy() + _copy.bucket_idx = 0 + assert await lru_helper.commit_entry(entry) - async def test_commit_entry(self): - for _, entry in self.entries.items(): - assert await self.cache_helper.commit_entry(entry) + async def test_lookup_entry( + self, + lru_helper: LRUCacheHelper, + entries_to_lookup: list[CacheMeta], + setup_testdata: dict[str, CacheMeta], + ): + for entry in entries_to_lookup: + assert ( + await lru_helper.lookup_entry(entry.file_sha256) + == setup_testdata[entry.file_sha256] + ) - async def test_lookup_entry(self): - target_size, idx = 8 * (1024**2), 6 - target_url = f"{target_size}#{idx}" - file_sha256 = url_based_hash(target_url) + async def test_remove_entry( + self, lru_helper: LRUCacheHelper, entries_to_remove: list[CacheMeta] + ): + for entry in entries_to_remove: + assert await lru_helper.remove_entry(entry.file_sha256) - assert ( - await self.cache_helper.lookup_entry(file_sha256) - == self.entries[target_url] - ) + async def test_rotate_cache(self, lru_helper: LRUCacheHelper): + """Ensure the LRUHelper properly rotates the cache entries. - async def test_remove_entry(self): - target_size, idx = 8 * (1024**2), 6 - target_file_sha256 = url_based_hash(f"{target_size}#{idx}") - assert await self.cache_helper.remove_entry(target_file_sha256) - - async def test_rotate_cache(self): - """Ensure the LRUHelper properly rotates the cache entries.""" - # test 1: reserve space for 32 * (1024**2) bucket - # the 32MB bucket is the last bucket and will not be rotated. - target_bucket = 32 * (1024**2) - entries_to_be_removed = await self.cache_helper.rotate_cache(target_bucket) - assert entries_to_be_removed is not None and len(entries_to_be_removed) == 0 - - # test 2: reserve space for 8 * 1024 bucket - # the next bucket is not empty, so we expecte to remove one entry from the next bucket - target_bucket = 8 * 1024 - assert ( - entries_to_be_removed := await self.cache_helper.rotate_cache(target_bucket) - ) and len(entries_to_be_removed) == 1 + We should file enough entries into the database, so each rotate should be successful. + """ + # NOTE that the first bucket and last bucket will not be rotated, + # see lru_cache_helper module for more details. + for target_bucket in list(cfg.BUCKET_FILE_SIZE_DICT)[1:-1]: + entries_to_be_removed = await lru_helper.rotate_cache(target_bucket) + assert entries_to_be_removed is not None and len(entries_to_be_removed) != 0 class TestOngoingCachingRegister: """ + Testing multiple access to a single resource at the same time with ongoing_cache control. + NOTE; currently this test only testing the weakref implementation part, the file descriptor management part is tested in test_ota_proxy_server """ @@ -108,7 +156,7 @@ class TestOngoingCachingRegister: WORKS_NUM = 128 @pytest.fixture(autouse=True) - def setup_test(self, tmp_path: Path): + async def setup_test(self, tmp_path: Path): self.base_dir = tmp_path / "base_dir" self.base_dir.mkdir(parents=True, exist_ok=True) self.register = CachingRegister(self.base_dir) @@ -128,7 +176,7 @@ async def _wait_for_registeration_finish(self): async def _worker( self, idx: int, - ) -> Tuple[bool, Optional[CacheMeta]]: + ) -> tuple[bool, Optional[CacheMeta]]: """ Returns tuple of bool indicates whether the worker is writter, and CacheMeta from tracker. 
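The concurrency pattern these tests rely on, gating many workers on one asyncio.Event and collecting results with as_completed, boils down to the following (worker count and payload are arbitrary):

import asyncio

async def worker(idx: int, start: asyncio.Event) -> int:
    await start.wait()       # every worker blocks here until released together
    await asyncio.sleep(0)   # stand-in for real work
    return idx

async def main():
    start = asyncio.Event()
    tasks = [asyncio.create_task(worker(i, start)) for i in range(8)]
    start.set()              # release all workers at once
    results = set()
    for fut in asyncio.as_completed(tasks):
        results.add(await fut)   # collect in completion order, not dispatch order
    assert results == set(range(8))

asyncio.run(main())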
@@ -149,7 +197,11 @@ async def _worker( logger.info(f"#{idx} is provider") # NOTE: use last_access field to store worker index # NOTE 2: bypass provider_start method, directly set tracker property - _tracker.meta = CacheMeta(last_access=idx) + _tracker.meta = CacheMeta( + last_access=idx, + url="some_url", + file_sha256="some_filesha256_value", + ) _tracker._writer_ready.set() # simulate waiting for writer finished downloading await self.writer_done_event.wait() @@ -167,15 +219,19 @@ async def _worker( _tracker._subscriber_ref_holder.pop() return False, _tracker.meta - async def test_OngoingCachingRegister(self): - coros: List[Coroutine] = [] + async def test_ongoing_cache_register(self): + """ + Test multiple access to single resource with ongoing_cache control mechanism. + """ + coros: list[Coroutine] = [] for idx in range(self.WORKS_NUM): coros.append(self._worker(idx)) + random.shuffle(coros) # shuffle the corotines to simulate unordered access tasks = [asyncio.create_task(c) for c in coros] logger.info(f"{self.WORKS_NUM} workers have been dispatched") - # start all the worker + # start all the worker, all the workers will now access the same resouce. self.sync_event.set() logger.info("all workers start to subscribe to the register") await self._wait_for_registeration_finish() # wait for all workers finish subscribing @@ -183,7 +239,8 @@ async def test_OngoingCachingRegister(self): ###### check the test result ###### meta_set, writer_meta = set(), None - for is_writer, meta in await asyncio.gather(*tasks): + for _fut in asyncio.as_completed(tasks): + is_writer, meta = await _fut if meta is None: logger.warning( "encount edge condition that subscriber subscribes " diff --git a/tests/test_ota_proxy/test_ota_proxy_server.py b/tests/test_ota_proxy/test_ota_proxy_server.py index 361ef9868..cfc65b496 100644 --- a/tests/test_ota_proxy/test_ota_proxy_server.py +++ b/tests/test_ota_proxy/test_ota_proxy_server.py @@ -13,6 +13,8 @@ # limitations under the License. 
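The server tests below drive multiple aiohttp clients through otaproxy; the essential per-client streaming download is roughly the following (URL, proxy address and chunk size are illustrative, not the test's actual helpers):

import asyncio
from hashlib import sha256
from typing import Optional

import aiohttp

async def download_and_hash(url: str, proxy: Optional[str] = None) -> str:
    """Stream url (optionally through an HTTP proxy) and return its sha256 hex digest."""
    async with aiohttp.ClientSession() as session:
        async with session.get(url, proxy=proxy) as resp:
            resp.raise_for_status()
            hasher = sha256()
            async for chunk in resp.content.iter_chunked(64 * 1024):
                hasher.update(chunk)
            return hasher.hexdigest()

# e.g. asyncio.run(download_and_hash("http://ota-image.example/data/some_file",
#                                    proxy="http://127.0.0.1:8082"))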
+from __future__ import annotations + import asyncio import logging import random @@ -20,7 +22,6 @@ import time from hashlib import sha256 from pathlib import Path -from typing import List from urllib.parse import quote, unquote, urljoin import aiohttp @@ -158,8 +159,8 @@ async def launch_ota_proxy_server(self, setup_ota_proxy_server): pass # ignore exp on shutting down @pytest.fixture(scope="class") - def parse_regulars(self): - regular_entries: List[RegularInf] = [] + def parse_regulars(self) -> list[RegularInf]: + regular_entries: list[RegularInf] = [] with open(self.REGULARS_TXT_PATH, "r") as f: for _line in f: _entry = parse_regulars_from_txt(_line) @@ -221,7 +222,9 @@ async def test_download_file_with_special_fname(self): elif self.space_availability == "exceed_hard_limit": pass - async def ota_image_downloader(self, regular_entries, sync_event: asyncio.Event): + async def ota_image_downloader( + self, regular_entries: list[RegularInf], sync_event: asyncio.Event + ): """Test single client download the whole ota image.""" async with aiohttp.ClientSession() as session: await sync_event.wait() @@ -258,16 +261,21 @@ async def ota_image_downloader(self, regular_entries, sync_event: asyncio.Event) continue raise - async def test_multiple_clients_download_ota_image(self, parse_regulars): + async def test_multiple_clients_download_ota_image( + self, parse_regulars: list[RegularInf] + ): """Test multiple client download the whole ota image simultaneously.""" # ------ dispatch many clients to download from otaproxy simultaneously ------ # # --- execution --- # sync_event = asyncio.Event() - tasks: List[asyncio.Task] = [] + tasks: list[asyncio.Task] = [] for _ in range(self.CLIENTS_NUM): tasks.append( asyncio.create_task( - self.ota_image_downloader(parse_regulars, sync_event) + self.ota_image_downloader( + parse_regulars, + sync_event, + ) ) ) logger.info( @@ -277,7 +285,9 @@ async def test_multiple_clients_download_ota_image(self, parse_regulars): # --- assertions --- # # 1. ensure all clients finished the downloading successfully - await asyncio.gather(*tasks, return_exceptions=False) + for _fut in asyncio.as_completed(tasks): + await _fut + await self.otaproxy_inst.shutdown() # 2. check there is no tmp files left in the ota_cache dir # ensure that the gc for multi-cache-streaming works @@ -338,7 +348,7 @@ async def setup_ota_proxy_server(self, tmp_path: Path): @pytest.fixture(scope="class") def parse_regulars(self): - regular_entries: List[RegularInf] = [] + regular_entries: list[RegularInf] = [] with open(self.REGULARS_TXT_PATH, "r") as f: for _line in f: _entry = parse_regulars_from_txt(_line) @@ -369,7 +379,7 @@ async def test_multiple_clients_download_ota_image(self, parse_regulars): # ------ dispatch many clients to download from otaproxy simultaneously ------ # # --- execution --- # sync_event = asyncio.Event() - tasks: List[asyncio.Task] = [] + tasks: list[asyncio.Task] = [] for _ in range(self.CLIENTS_NUM): tasks.append( asyncio.create_task(