diff --git a/CHANGELOG.md b/CHANGELOG.md index 4102485ca9..5e52f9a812 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Annotation-related classes were moved into a new module, `datumaro.components.annotation` () +- Rollback utilities replaced with Scope utilities + () ### Deprecated - TBD diff --git a/datumaro/cli/contexts/model.py b/datumaro/cli/contexts/model.py index 378af1ff76..178b389deb 100644 --- a/datumaro/cli/contexts/model.py +++ b/datumaro/cli/contexts/model.py @@ -9,7 +9,7 @@ import shutil from datumaro.components.project import Environment -from datumaro.util import error_rollback, on_error_do +from datumaro.util.scope import on_error_do, scoped from ..util import CliException, MultilineFormatter, add_subparser from ..util.project import ( @@ -46,7 +46,7 @@ def build_add_parser(parser_ctor=argparse.ArgumentParser): return parser -@error_rollback +@scoped def add_command(args): project = load_project(args.project_dir) diff --git a/datumaro/cli/contexts/project/__init__.py b/datumaro/cli/contexts/project/__init__.py index c12ddaf132..18d68c5a7a 100644 --- a/datumaro/cli/contexts/project/__init__.py +++ b/datumaro/cli/contexts/project/__init__.py @@ -21,8 +21,8 @@ from datumaro.components.project import PROJECT_DEFAULT_CONFIG as DEFAULT_CONFIG from datumaro.components.project import Environment, Project from datumaro.components.validator import TaskType -from datumaro.util import error_rollback, on_error_do from datumaro.util.os_util import make_file_name +from datumaro.util.scope import on_error_do, scoped from ...util import CliException, MultilineFormatter, add_subparser from ...util.project import generate_next_file_name, load_project @@ -526,7 +526,7 @@ def build_diff_parser(parser_ctor=argparse.ArgumentParser): return parser -@error_rollback +@scoped def diff_command(args): first_project = load_project(args.project_dir) second_project = load_project(args.other_project_dir) diff --git a/datumaro/components/converter.py b/datumaro/components/converter.py index ec8f9c2de5..21c0410cad 100644 --- a/datumaro/components/converter.py +++ b/datumaro/components/converter.py @@ -12,8 +12,8 @@ from datumaro.components.cli_plugin import CliPlugin from datumaro.components.dataset import DatasetPatch from datumaro.components.extractor import DatasetItem -from datumaro.util import error_rollback, on_error_do from datumaro.util.image import Image +from datumaro.util.scope import on_error_do, scoped class Converter(CliPlugin): @@ -36,7 +36,7 @@ def convert(cls, extractor, save_dir, **options): return converter.apply() @classmethod - @error_rollback + @scoped def patch(cls, dataset, patch, save_dir, **options): # This solution is not any better in performance than just # writing a dataset, but in case of patching (i.e. writing diff --git a/datumaro/components/dataset.py b/datumaro/components/dataset.py index ecd5db46c0..06e52dd8eb 100644 --- a/datumaro/components/dataset.py +++ b/datumaro/components/dataset.py @@ -27,8 +27,9 @@ DEFAULT_SUBSET_NAME, CategoriesInfo, DatasetItem, Extractor, IExtractor, ItemTransform, Transform, ) -from datumaro.util import error_rollback, is_method_redefined, on_error_do +from datumaro.util import is_method_redefined from datumaro.util.log_utils import logging_disabled +from datumaro.util.scope import on_error_do, scoped DEFAULT_FORMAT = 'datumaro' @@ -790,7 +791,7 @@ def bind(self, path: str, format: Optional[str] = None, *, def flush_changes(self): self._data.flush_changes() - @error_rollback + @scoped def export(self, save_dir: str, format, **kwargs): inplace = (save_dir == self._source_path and format == self._format) diff --git a/datumaro/util/__init__.py b/datumaro/util/__init__.py index 7c8a52e7b1..6a2328b7e1 100644 --- a/datumaro/util/__init__.py +++ b/datumaro/util/__init__.py @@ -2,15 +2,11 @@ # # SPDX-License-Identifier: MIT -from contextlib import ExitStack, contextmanager -from functools import partial, wraps +from functools import wraps from inspect import isclass from itertools import islice from typing import Iterable, Tuple import distutils.util -import threading - -import attr NOTSET = object() @@ -126,99 +122,3 @@ def real_decorator(decoratee): return real_decorator return wrapped_decorator - -class Rollback: - _thread_locals = threading.local() - - @attr.attrs - class Handler: - callback = attr.attrib() - enabled = attr.attrib(default=True) - ignore_errors = attr.attrib(default=False) - - def __call__(self): - if self.enabled: - try: - self.callback() - except: # pylint: disable=bare-except - if not self.ignore_errors: - raise - - def __init__(self): - self._handlers = {} - self._stack = ExitStack() - self.enabled = True - - def add(self, callback, *args, - name=None, enabled=True, ignore_errors=False, - fwd_kwargs=None, **kwargs): - if args or kwargs or fwd_kwargs: - if fwd_kwargs: - kwargs.update(fwd_kwargs) - callback = partial(callback, *args, **kwargs) - name = name or hash(callback) - assert name not in self._handlers - handler = self.Handler(callback, - enabled=enabled, ignore_errors=ignore_errors) - self._handlers[name] = handler - self._stack.callback(handler) - return name - - do = add # readability alias - - def enable(self, name=None): - if name: - self._handlers[name].enabled = True - else: - self.enabled = True - - def disable(self, name=None): - if name: - self._handlers[name].enabled = False - else: - self.enabled = False - - def clean(self): - self.__exit__(None, None, None) - - def __enter__(self): - return self - - def __exit__(self, type=None, value=None, \ - traceback=None): # pylint: disable=redefined-builtin - if type is None: - return - if not self.enabled: - return - self._stack.__exit__(type, value, traceback) - - @classmethod - def current(cls) -> "Rollback": - return cls._thread_locals.current - - @contextmanager - def as_current(self): - previous = getattr(self._thread_locals, 'current', None) - self._thread_locals.current = self - try: - yield - finally: - self._thread_locals.current = previous - -# shorthand for common cases -def on_error_do(callback, *args, ignore_errors=False): - Rollback.current().do(callback, *args, ignore_errors=ignore_errors) - -@optional_arg_decorator -def error_rollback(func, arg_name=None): - @wraps(func) - def wrapped_func(*args, **kwargs): - with Rollback() as manager: - if arg_name is None: - with manager.as_current(): - ret_val = func(*args, **kwargs) - else: - kwargs[arg_name] = manager - ret_val = func(*args, **kwargs) - return ret_val - return wrapped_func diff --git a/datumaro/util/scope.py b/datumaro/util/scope.py new file mode 100644 index 0000000000..a857882ca5 --- /dev/null +++ b/datumaro/util/scope.py @@ -0,0 +1,151 @@ +# Copyright (C) 2021 Intel Corporation +# +# SPDX-License-Identifier: MIT + +from contextlib import ExitStack, contextmanager +from functools import partial, wraps +from typing import Any, Callable, ContextManager, Dict, Optional, Tuple, TypeVar +import threading + +from attr import attrs + +from datumaro.util import optional_arg_decorator + +T = TypeVar('T') + +class Scope: + """ + A context manager that allows to register error and exit callbacks. + """ + + _thread_locals = threading.local() + + @attrs(auto_attribs=True) + class ExitHandler: + callback: Callable[[], Any] + ignore_errors: bool = True + + def __exit__(self, exc_type, exc_value, exc_traceback): + try: + self.callback() + except Exception: + if not self.ignore_errors: + raise + + @attrs + class ErrorHandler(ExitHandler): + def __exit__(self, exc_type, exc_value, exc_traceback): + if exc_type: + return super().__exit__(exc_type=exc_type, exc_value=exc_value, + exc_traceback=exc_traceback) + + + def __init__(self): + self._stack = ExitStack() + self.enabled = True + + def on_error_do(self, callback: Callable, + *args, kwargs: Optional[Dict[str, Any]] = None, + ignore_errors: bool = False): + """ + Registers a function to be called on scope exit because of an error. + + If ignore_errors is True, the errors from this function call + will be ignored. + """ + + self._register_callback(self.ErrorHandler, + ignore_errors=ignore_errors, + callback=callback, args=args, kwargs=kwargs) + + def on_exit_do(self, callback: Callable, + *args, kwargs: Optional[Dict[str, Any]] = None, + ignore_errors: bool = False): + """ + Registers a function to be called on scope exit. + """ + + self._register_callback(self.ExitHandler, + ignore_errors=ignore_errors, + callback=callback, args=args, kwargs=kwargs) + + def _register_callback(self, handler_type, callback: Callable, + args: Tuple[Any] = None, kwargs: Dict[str, Any] = None, + ignore_errors: bool = False): + if args or kwargs: + callback = partial(callback, *args, **(kwargs or {})) + + self._stack.push(handler_type(callback, ignore_errors=ignore_errors)) + + def add(self, cm: ContextManager[T]) -> T: + """ + Enters a context manager and adds it to the exit stack. + + Returns: cm.__enter__() result + """ + + return self._stack.enter_context(cm) + + def enable(self): + self.enabled = True + + def disable(self): + self.enabled = False + + def close(self): + self.__exit__(None, None, None) + + def __enter__(self) -> 'Scope': + return self + + def __exit__(self, exc_type, exc_value, exc_traceback): + if not self.enabled: + return + + self._stack.__exit__(exc_type, exc_value, exc_traceback) + self._stack.pop_all() # prevent issues on repetitive calls + + @classmethod + def current(cls) -> 'Scope': + return cls._thread_locals.current + + @contextmanager + def as_current(self): + previous = getattr(self._thread_locals, 'current', None) + self._thread_locals.current = self + try: + yield + finally: + self._thread_locals.current = previous + +@optional_arg_decorator +def scoped(func, arg_name=None): + """ + A function decorator, which allows to do actions with the current scope, + such as registering error and exit callbacks and context managers. + """ + + @wraps(func) + def wrapped_func(*args, **kwargs): + with Scope() as scope: + if arg_name is None: + with scope.as_current(): + ret_val = func(*args, **kwargs) + else: + kwargs[arg_name] = scope + ret_val = func(*args, **kwargs) + return ret_val + + return wrapped_func + +# Shorthands for common cases +def on_error_do(callback, *args, ignore_errors=False, kwargs=None): + return Scope.current().on_error_do(callback, *args, + ignore_errors=ignore_errors, kwargs=kwargs) + +def on_exit_do(callback, *args, ignore_errors=False, kwargs=None): + return Scope.current().on_exit_do(callback, *args, + ignore_errors=ignore_errors, kwargs=kwargs) + +def scope_add(cm: ContextManager[T]) -> T: + return Scope.current().add(cm) diff --git a/tests/test_util.py b/tests/test_util.py index f124e7cdb2..66a791df76 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -1,127 +1,130 @@ +from contextlib import suppress from unittest import TestCase, mock import os import os.path as osp -from datumaro.util import ( - Rollback, error_rollback, is_method_redefined, on_error_do, -) +from datumaro.util import is_method_redefined from datumaro.util.os_util import walk +from datumaro.util.scope import Scope, on_error_do, on_exit_do, scoped from datumaro.util.test_utils import TestDir from .requirements import Requirements, mark_requirement -class TestRollback(TestCase): +class TestException(Exception): + pass + +class ScopeTest(TestCase): @mark_requirement(Requirements.DATUM_GENERAL_REQ) - def test_does_not_call_on_no_error(self): - success = True - def cb(): - nonlocal success - success = False + def test_calls_only_exit_callback_on_exit(self): + error_cb = mock.MagicMock() + exit_cb = mock.MagicMock() - with Rollback() as on_error: - on_error.do(cb) + with Scope() as scope: + scope.on_error_do(error_cb) + scope.on_exit_do(exit_cb) - self.assertTrue(success) + error_cb.assert_not_called() + exit_cb.assert_called_once() @mark_requirement(Requirements.DATUM_GENERAL_REQ) - def test_calls_on_error(self): - success = False - def cb(): - nonlocal success - success = True - - try: - with Rollback() as on_error: - on_error.do(cb) - raise Exception('err') - except Exception: # nosec - disable B110:try_except_pass check - pass - finally: - self.assertTrue(success) + def test_calls_both_callbacks_on_error(self): + error_cb = mock.MagicMock() + exit_cb = mock.MagicMock() + + with self.assertRaises(TestException), Scope() as scope: + scope.on_error_do(error_cb) + scope.on_exit_do(exit_cb) + raise TestException() + + error_cb.assert_called_once() + exit_cb.assert_called_once() + + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + def test_adds_cm(self): + cm = mock.Mock() + cm.__enter__ = mock.MagicMock(return_value=42) + cm.__exit__ = mock.MagicMock() + + with Scope() as scope: + retval = scope.add(cm) + + cm.__enter__.assert_called_once() + cm.__exit__.assert_called_once() + self.assertEqual(42, retval) + + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + def test_calls_cm_on_error(self): + cm = mock.Mock() + cm.__enter__ = mock.MagicMock() + cm.__exit__ = mock.MagicMock() + + with suppress(TestException), Scope() as scope: + scope.add(cm) + raise TestException() + + cm.__enter__.assert_called_once() + cm.__exit__.assert_called_once() @mark_requirement(Requirements.DATUM_GENERAL_REQ) def test_decorator_calls_on_error(self): - success = False - def cb(): - nonlocal success - success = True + cb = mock.MagicMock() - @error_rollback('on_error') - def foo(on_error=None): - on_error.do(cb) - raise Exception('err') + @scoped('scope') + def foo(scope=None): + scope.on_error_do(cb) + raise TestException() - try: + with suppress(TestException): foo() - except Exception: # nosec - disable B110:try_except_pass check - pass - finally: - self.assertTrue(success) + + cb.assert_called_once() @mark_requirement(Requirements.DATUM_GENERAL_REQ) def test_decorator_does_not_call_on_no_error(self): - success = True - def cb(): - nonlocal success - success = False + error_cb = mock.MagicMock() + exit_cb = mock.MagicMock() - @error_rollback('on_error') - def foo(on_error=None): - on_error.do(cb) + @scoped('scope') + def foo(scope=None): + scope.on_error_do(error_cb) + scope.on_exit_do(exit_cb) foo() - self.assertTrue(success) + error_cb.assert_not_called() + exit_cb.assert_called_once() @mark_requirement(Requirements.DATUM_GENERAL_REQ) def test_decorator_supports_implicit_form(self): - success = False - def cb(): - nonlocal success - success = True + error_cb = mock.MagicMock() + exit_cb = mock.MagicMock() - @error_rollback + @scoped def foo(): - on_error_do(cb) - raise Exception('err') + on_error_do(error_cb) + on_exit_do(exit_cb) + raise TestException() - try: + with suppress(TestException): foo() - except Exception: # nosec - disable B110:try_except_pass check - pass - finally: - self.assertTrue(success) + + error_cb.assert_called_once() + exit_cb.assert_called_once() @mark_requirement(Requirements.DATUM_GENERAL_REQ) def test_can_fowrard_args(self): - success1 = False - def cb1(a1, a2=None, ignore_errors=None): - nonlocal success1 - if a1 == 5 and a2 == 2 and ignore_errors == None: - success1 = True - - success2 = False - def cb2(a1, a2=None, ignore_errors=None): - nonlocal success2 - if a1 == 5 and a2 == 2 and ignore_errors == 4: - success2 = True - - try: - with Rollback() as on_error: - on_error.do(cb1, 5, a2=2, ignore_errors=True) - on_error.do(cb2, 5, a2=2, ignore_errors=True, - fwd_kwargs={'ignore_errors': 4}) - raise Exception('err') - except Exception: # nosec - disable B110:try_except_pass check - pass - finally: - self.assertTrue(success1) - self.assertTrue(success2) + cb = mock.MagicMock() + + with suppress(TestException), Scope() as scope: + scope.on_error_do(cb, 5, ignore_errors=True, kwargs={'a2': 2}) + raise TestException() + + cb.assert_called_once_with(5, a2=2) @mark_requirement(Requirements.DATUM_GENERAL_REQ) def test_decorator_can_return_on_success_in_implicit_form(self): - @error_rollback + @scoped def f(): return 42 @@ -131,8 +134,8 @@ def f(): @mark_requirement(Requirements.DATUM_GENERAL_REQ) def test_decorator_can_return_on_success_in_explicit_form(self): - @error_rollback('on_error') - def f(on_error=None): + @scoped('scope') + def f(scope=None): return 42 retval = f()