From 9b1c6875e7227f0ad8e434b84047ed3985a33e66 Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Thu, 2 Sep 2021 18:19:03 +0300 Subject: [PATCH 01/12] Replace rollback with scope --- datumaro/util/__init__.py | 102 +---------------------- datumaro/util/scope.py | 171 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 172 insertions(+), 101 deletions(-) create mode 100644 datumaro/util/scope.py diff --git a/datumaro/util/__init__.py b/datumaro/util/__init__.py index 7c8a52e7b1..6a2328b7e1 100644 --- a/datumaro/util/__init__.py +++ b/datumaro/util/__init__.py @@ -2,15 +2,11 @@ # # SPDX-License-Identifier: MIT -from contextlib import ExitStack, contextmanager -from functools import partial, wraps +from functools import wraps from inspect import isclass from itertools import islice from typing import Iterable, Tuple import distutils.util -import threading - -import attr NOTSET = object() @@ -126,99 +122,3 @@ def real_decorator(decoratee): return real_decorator return wrapped_decorator - -class Rollback: - _thread_locals = threading.local() - - @attr.attrs - class Handler: - callback = attr.attrib() - enabled = attr.attrib(default=True) - ignore_errors = attr.attrib(default=False) - - def __call__(self): - if self.enabled: - try: - self.callback() - except: # pylint: disable=bare-except - if not self.ignore_errors: - raise - - def __init__(self): - self._handlers = {} - self._stack = ExitStack() - self.enabled = True - - def add(self, callback, *args, - name=None, enabled=True, ignore_errors=False, - fwd_kwargs=None, **kwargs): - if args or kwargs or fwd_kwargs: - if fwd_kwargs: - kwargs.update(fwd_kwargs) - callback = partial(callback, *args, **kwargs) - name = name or hash(callback) - assert name not in self._handlers - handler = self.Handler(callback, - enabled=enabled, ignore_errors=ignore_errors) - self._handlers[name] = handler - self._stack.callback(handler) - return name - - do = add # readability alias - - def enable(self, name=None): - if name: - self._handlers[name].enabled = True - else: - self.enabled = True - - def disable(self, name=None): - if name: - self._handlers[name].enabled = False - else: - self.enabled = False - - def clean(self): - self.__exit__(None, None, None) - - def __enter__(self): - return self - - def __exit__(self, type=None, value=None, \ - traceback=None): # pylint: disable=redefined-builtin - if type is None: - return - if not self.enabled: - return - self._stack.__exit__(type, value, traceback) - - @classmethod - def current(cls) -> "Rollback": - return cls._thread_locals.current - - @contextmanager - def as_current(self): - previous = getattr(self._thread_locals, 'current', None) - self._thread_locals.current = self - try: - yield - finally: - self._thread_locals.current = previous - -# shorthand for common cases -def on_error_do(callback, *args, ignore_errors=False): - Rollback.current().do(callback, *args, ignore_errors=ignore_errors) - -@optional_arg_decorator -def error_rollback(func, arg_name=None): - @wraps(func) - def wrapped_func(*args, **kwargs): - with Rollback() as manager: - if arg_name is None: - with manager.as_current(): - ret_val = func(*args, **kwargs) - else: - kwargs[arg_name] = manager - ret_val = func(*args, **kwargs) - return ret_val - return wrapped_func diff --git a/datumaro/util/scope.py b/datumaro/util/scope.py new file mode 100644 index 0000000000..7fff715166 --- /dev/null +++ b/datumaro/util/scope.py @@ -0,0 +1,171 @@ +# Copyright (C) 2021 Intel Corporation +# +# SPDX-License-Identifier: MIT + +from contextlib import ExitStack, contextmanager +from functools import partial, wraps +from typing import Any, ContextManager, Dict, Optional +import threading + +from attr import attrs + +from datumaro.util import optional_arg_decorator + + +class Scope: + """ + A context manager that allows to register error and exit callbacks. + """ + + _thread_locals = threading.local() + + @attrs(auto_attribs=True) + class Handler: + callback: Any + enabled: bool = True + ignore_errors: bool = False + + def __call__(self): + if self.enabled: + try: + self.callback() + except: # pylint: disable=bare-except + if not self.ignore_errors: + raise + + def __init__(self): + self._handlers = {} + self._error_stack = ExitStack() + self._exit_stack = ExitStack() + self.enabled = True + + def on_error_do(self, callback, *args, name: Optional[str] = None, + enabled: bool = True, ignore_errors: bool = False, + fwd_kwargs: Optional[Dict[str, Any]] = None, **kwargs) -> str: + """ + Registers a function to be called on scope exit because of an error. + Equivalent to the "except" block of "try-except". + """ + + if args or kwargs or fwd_kwargs: + if fwd_kwargs: + kwargs.update(fwd_kwargs) + callback = partial(callback, *args, **kwargs) + + name = name or hash(callback) + assert name not in self._handlers, "Callback is already registered" + + handler = self.Handler(callback, + enabled=enabled, ignore_errors=ignore_errors) + self._handlers[name] = handler + self._error_stack.callback(handler) + return name + + def on_exit_do(self, callback, *args, name: Optional[str] = None, + enabled: bool = True, ignore_errors: bool = False, + fwd_kwargs: Optional[Dict[str, Any]] = None, **kwargs) -> str: + """ + Registers a function to be called on scope exit unconditionally. + Equivalent to the "finally" block of "try-except". + """ + + if args or kwargs or fwd_kwargs: + if fwd_kwargs: + kwargs.update(fwd_kwargs) + callback = partial(callback, *args, **kwargs) + + name = name or hash(callback) + assert name not in self._handlers, "Callback is already registered" + + handler = self.Handler(callback, + enabled=enabled, ignore_errors=ignore_errors) + self._handlers[name] = handler + self._exit_stack.callback(handler) + return name + + def add(self, cm: ContextManager) -> Any: + """ + Enters a context manager and adds it to the exit stack. + """ + + return self._exit_stack.enter_context(cm) + + def enable(self, name=None): + if name: + self._handlers[name].enabled = True + else: + self.enabled = True + + def disable(self, name=None): + if name: + self._handlers[name].enabled = False + else: + self.enabled = False + + def clean(self): + self.__exit__() + + def __enter__(self): + return self + + def __exit__(self, exc_type=None, exc_value=None, exc_traceback=None): + if not self.enabled: + return + + try: + if exc_type: + self._error_stack.__exit__(exc_type, exc_value, exc_traceback) + finally: + self._exit_stack.__exit__(exc_type, exc_value, exc_traceback) + + @classmethod + def current(cls) -> 'Scope': + return cls._thread_locals.current + + @contextmanager + def as_current(self): + previous = getattr(self._thread_locals, 'current', None) + self._thread_locals.current = self + try: + yield + finally: + self._thread_locals.current = previous + +@optional_arg_decorator +def scoped(func, arg_name=None): + """ + A function decorator, which allows to do actions with the current scope, + such as registering error and exit callbacks and context managers. + """ + + @wraps(func) + def wrapped_func(*args, **kwargs): + with Scope() as scope: + if arg_name is None: + with scope.as_current(): + ret_val = func(*args, **kwargs) + else: + kwargs[arg_name] = scope + ret_val = func(*args, **kwargs) + return ret_val + + return wrapped_func + +# Shorthands for common cases +def on_error_do(callback, *args, ignore_errors=False): + return Scope.current().on_error_do(callback, *args, + ignore_errors=ignore_errors) +on_error_do.__doc__ = Scope.on_error_do.__doc__ + +def on_exit_do(callback, *args, ignore_errors=False): + return Scope.current().on_exit_do(callback, *args, + ignore_errors=ignore_errors) +on_exit_do.__doc__ = Scope.on_exit_do.__doc__ + +def add(cm: ContextManager): + return Scope.current().add(cm) +add.__doc__ = Scope.add.__doc__ + +def current(): + return Scope.current() +current.__doc__ = Scope.current.__doc__ From 0f823cf92002eea439ee375746e88031faa80bae Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Thu, 2 Sep 2021 18:19:15 +0300 Subject: [PATCH 02/12] Update tests --- tests/test_util.py | 186 ++++++++++++++++++++++++--------------------- 1 file changed, 101 insertions(+), 85 deletions(-) diff --git a/tests/test_util.py b/tests/test_util.py index f124e7cdb2..e900b57ee4 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -1,127 +1,143 @@ +from contextlib import suppress from unittest import TestCase, mock import os import os.path as osp -from datumaro.util import ( - Rollback, error_rollback, is_method_redefined, on_error_do, -) +from datumaro.util import is_method_redefined from datumaro.util.os_util import walk +from datumaro.util.scope import Scope, on_error_do, scoped, on_exit_do from datumaro.util.test_utils import TestDir from .requirements import Requirements, mark_requirement -class TestRollback(TestCase): +class TestException(Exception): + pass + +class TestScope(TestCase): @mark_requirement(Requirements.DATUM_GENERAL_REQ) - def test_does_not_call_on_no_error(self): - success = True - def cb(): - nonlocal success - success = False + def test_calls_on_no_error(self): + error_cb = mock.MagicMock() + exit_cb = mock.MagicMock() - with Rollback() as on_error: - on_error.do(cb) + with suppress(TestException), Scope() as scope: + scope.on_error_do(error_cb) + scope.on_exit_do(exit_cb) - self.assertTrue(success) + error_cb.assert_not_called() + exit_cb.assert_called_once() @mark_requirement(Requirements.DATUM_GENERAL_REQ) - def test_calls_on_error(self): - success = False - def cb(): - nonlocal success - success = True - - try: - with Rollback() as on_error: - on_error.do(cb) - raise Exception('err') - except Exception: # nosec - disable B110:try_except_pass check - pass - finally: - self.assertTrue(success) + def test_calls_both_stacks_on_error(self): + error_cb = mock.MagicMock() + exit_cb = mock.MagicMock() + + with suppress(TestException), Scope() as scope: + scope.on_error_do(error_cb) + scope.on_exit_do(exit_cb) + raise TestException('err') + + error_cb.assert_called_once() + exit_cb.assert_called_once() + + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + def test_cant_add_single_callback_in_both_stacks(self): + cb = mock.MagicMock() + + with self.assertRaisesRegex(AssertionError, "already registered"): + with Scope() as scope: + scope.on_error_do(cb) + scope.on_exit_do(cb) + + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + def test_adds_cm(self): + cm = mock.Mock() + cm.__enter__ = mock.MagicMock(return_value=42) + cm.__exit__ = mock.MagicMock() + + with Scope() as scope: + retval = scope.add(cm) + + cm.__enter__.assert_called_once() + cm.__exit__.assert_called_once() + self.assertEqual(42, retval) + + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + def test_calls_cm_on_error(self): + cm = mock.Mock() + cm.__enter__ = mock.MagicMock() + cm.__exit__ = mock.MagicMock() + + with suppress(TestException), Scope() as scope: + scope.add(cm) + raise TestException() + + cm.__enter__.assert_called_once() + cm.__exit__.assert_called_once() @mark_requirement(Requirements.DATUM_GENERAL_REQ) def test_decorator_calls_on_error(self): - success = False - def cb(): - nonlocal success - success = True + cb = mock.MagicMock() - @error_rollback('on_error') - def foo(on_error=None): - on_error.do(cb) - raise Exception('err') + @scoped('scope') + def foo(scope=None): + scope.on_error_do(cb) + raise TestException('err') - try: + with suppress(TestException): foo() - except Exception: # nosec - disable B110:try_except_pass check - pass - finally: - self.assertTrue(success) + + cb.assert_called_once() @mark_requirement(Requirements.DATUM_GENERAL_REQ) def test_decorator_does_not_call_on_no_error(self): - success = True - def cb(): - nonlocal success - success = False + error_cb = mock.MagicMock() + exit_cb = mock.MagicMock() - @error_rollback('on_error') - def foo(on_error=None): - on_error.do(cb) + @scoped('scope') + def foo(scope=None): + scope.on_error_do(error_cb) + scope.on_exit_do(exit_cb) foo() - self.assertTrue(success) + error_cb.assert_not_called() + exit_cb.assert_called_once() @mark_requirement(Requirements.DATUM_GENERAL_REQ) def test_decorator_supports_implicit_form(self): - success = False - def cb(): - nonlocal success - success = True + error_cb = mock.MagicMock() + exit_cb = mock.MagicMock() - @error_rollback + @scoped def foo(): - on_error_do(cb) - raise Exception('err') + on_error_do(error_cb) + on_exit_do(exit_cb) + raise TestException('err') - try: + with suppress(TestException): foo() - except Exception: # nosec - disable B110:try_except_pass check - pass - finally: - self.assertTrue(success) + + error_cb.assert_called_once() + exit_cb.assert_called_once() @mark_requirement(Requirements.DATUM_GENERAL_REQ) def test_can_fowrard_args(self): - success1 = False - def cb1(a1, a2=None, ignore_errors=None): - nonlocal success1 - if a1 == 5 and a2 == 2 and ignore_errors == None: - success1 = True - - success2 = False - def cb2(a1, a2=None, ignore_errors=None): - nonlocal success2 - if a1 == 5 and a2 == 2 and ignore_errors == 4: - success2 = True - - try: - with Rollback() as on_error: - on_error.do(cb1, 5, a2=2, ignore_errors=True) - on_error.do(cb2, 5, a2=2, ignore_errors=True, - fwd_kwargs={'ignore_errors': 4}) - raise Exception('err') - except Exception: # nosec - disable B110:try_except_pass check - pass - finally: - self.assertTrue(success1) - self.assertTrue(success2) + cb1 = mock.MagicMock() + cb2 = mock.MagicMock() + + with suppress(TestException), Scope() as scope: + scope.on_error_do(cb1, 5, a2=2, ignore_errors=True) + scope.on_error_do(cb2, 5, a2=2, ignore_errors=True, + fwd_kwargs={'ignore_errors': 4}) + raise TestException('err') + + cb1.assert_called_once_with(5, a2=2) + cb2.assert_called_once_with(5, a2=2, ignore_errors=4) @mark_requirement(Requirements.DATUM_GENERAL_REQ) def test_decorator_can_return_on_success_in_implicit_form(self): - @error_rollback + @scoped def f(): return 42 @@ -131,8 +147,8 @@ def f(): @mark_requirement(Requirements.DATUM_GENERAL_REQ) def test_decorator_can_return_on_success_in_explicit_form(self): - @error_rollback('on_error') - def f(on_error=None): + @scoped('scope') + def f(scope=None): return 42 retval = f() From 655505ad10e22487fd88b0c6189598459f8c6724 Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Thu, 2 Sep 2021 18:19:39 +0300 Subject: [PATCH 03/12] Replace rollback uses --- datumaro/cli/contexts/model.py | 4 ++-- datumaro/cli/contexts/project/__init__.py | 4 ++-- datumaro/components/converter.py | 4 ++-- datumaro/components/dataset.py | 5 +++-- 4 files changed, 9 insertions(+), 8 deletions(-) diff --git a/datumaro/cli/contexts/model.py b/datumaro/cli/contexts/model.py index 378af1ff76..cebcd124fc 100644 --- a/datumaro/cli/contexts/model.py +++ b/datumaro/cli/contexts/model.py @@ -9,7 +9,7 @@ import shutil from datumaro.components.project import Environment -from datumaro.util import error_rollback, on_error_do +from datumaro.util.scope import scoped, on_error_do from ..util import CliException, MultilineFormatter, add_subparser from ..util.project import ( @@ -46,7 +46,7 @@ def build_add_parser(parser_ctor=argparse.ArgumentParser): return parser -@error_rollback +@scoped def add_command(args): project = load_project(args.project_dir) diff --git a/datumaro/cli/contexts/project/__init__.py b/datumaro/cli/contexts/project/__init__.py index c12ddaf132..f975852c6f 100644 --- a/datumaro/cli/contexts/project/__init__.py +++ b/datumaro/cli/contexts/project/__init__.py @@ -21,7 +21,7 @@ from datumaro.components.project import PROJECT_DEFAULT_CONFIG as DEFAULT_CONFIG from datumaro.components.project import Environment, Project from datumaro.components.validator import TaskType -from datumaro.util import error_rollback, on_error_do +from datumaro.util.scope import scoped, on_error_do from datumaro.util.os_util import make_file_name from ...util import CliException, MultilineFormatter, add_subparser @@ -526,7 +526,7 @@ def build_diff_parser(parser_ctor=argparse.ArgumentParser): return parser -@error_rollback +@scoped def diff_command(args): first_project = load_project(args.project_dir) second_project = load_project(args.other_project_dir) diff --git a/datumaro/components/converter.py b/datumaro/components/converter.py index ec8f9c2de5..45f7c3258a 100644 --- a/datumaro/components/converter.py +++ b/datumaro/components/converter.py @@ -12,7 +12,7 @@ from datumaro.components.cli_plugin import CliPlugin from datumaro.components.dataset import DatasetPatch from datumaro.components.extractor import DatasetItem -from datumaro.util import error_rollback, on_error_do +from datumaro.util.scope import scoped, on_error_do from datumaro.util.image import Image @@ -36,7 +36,7 @@ def convert(cls, extractor, save_dir, **options): return converter.apply() @classmethod - @error_rollback + @scoped def patch(cls, dataset, patch, save_dir, **options): # This solution is not any better in performance than just # writing a dataset, but in case of patching (i.e. writing diff --git a/datumaro/components/dataset.py b/datumaro/components/dataset.py index ecd5db46c0..b186e2b3fd 100644 --- a/datumaro/components/dataset.py +++ b/datumaro/components/dataset.py @@ -27,7 +27,8 @@ DEFAULT_SUBSET_NAME, CategoriesInfo, DatasetItem, Extractor, IExtractor, ItemTransform, Transform, ) -from datumaro.util import error_rollback, is_method_redefined, on_error_do +from datumaro.util import is_method_redefined +from datumaro.util.scope import scoped, on_error_do from datumaro.util.log_utils import logging_disabled DEFAULT_FORMAT = 'datumaro' @@ -790,7 +791,7 @@ def bind(self, path: str, format: Optional[str] = None, *, def flush_changes(self): self._data.flush_changes() - @error_rollback + @scoped def export(self, save_dir: str, format, **kwargs): inplace = (save_dir == self._source_path and format == self._format) From 73d7b3fde2e8b0d2e5fa822b08ac410ef04b8105 Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Thu, 2 Sep 2021 19:28:56 +0300 Subject: [PATCH 04/12] Fix imports --- datumaro/cli/contexts/model.py | 2 +- datumaro/cli/contexts/project/__init__.py | 2 +- datumaro/components/converter.py | 2 +- datumaro/components/dataset.py | 2 +- tests/test_util.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/datumaro/cli/contexts/model.py b/datumaro/cli/contexts/model.py index cebcd124fc..178b389deb 100644 --- a/datumaro/cli/contexts/model.py +++ b/datumaro/cli/contexts/model.py @@ -9,7 +9,7 @@ import shutil from datumaro.components.project import Environment -from datumaro.util.scope import scoped, on_error_do +from datumaro.util.scope import on_error_do, scoped from ..util import CliException, MultilineFormatter, add_subparser from ..util.project import ( diff --git a/datumaro/cli/contexts/project/__init__.py b/datumaro/cli/contexts/project/__init__.py index f975852c6f..18d68c5a7a 100644 --- a/datumaro/cli/contexts/project/__init__.py +++ b/datumaro/cli/contexts/project/__init__.py @@ -21,8 +21,8 @@ from datumaro.components.project import PROJECT_DEFAULT_CONFIG as DEFAULT_CONFIG from datumaro.components.project import Environment, Project from datumaro.components.validator import TaskType -from datumaro.util.scope import scoped, on_error_do from datumaro.util.os_util import make_file_name +from datumaro.util.scope import on_error_do, scoped from ...util import CliException, MultilineFormatter, add_subparser from ...util.project import generate_next_file_name, load_project diff --git a/datumaro/components/converter.py b/datumaro/components/converter.py index 45f7c3258a..21c0410cad 100644 --- a/datumaro/components/converter.py +++ b/datumaro/components/converter.py @@ -12,8 +12,8 @@ from datumaro.components.cli_plugin import CliPlugin from datumaro.components.dataset import DatasetPatch from datumaro.components.extractor import DatasetItem -from datumaro.util.scope import scoped, on_error_do from datumaro.util.image import Image +from datumaro.util.scope import on_error_do, scoped class Converter(CliPlugin): diff --git a/datumaro/components/dataset.py b/datumaro/components/dataset.py index b186e2b3fd..06e52dd8eb 100644 --- a/datumaro/components/dataset.py +++ b/datumaro/components/dataset.py @@ -28,8 +28,8 @@ ItemTransform, Transform, ) from datumaro.util import is_method_redefined -from datumaro.util.scope import scoped, on_error_do from datumaro.util.log_utils import logging_disabled +from datumaro.util.scope import on_error_do, scoped DEFAULT_FORMAT = 'datumaro' diff --git a/tests/test_util.py b/tests/test_util.py index e900b57ee4..4ddc45cb2e 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -5,7 +5,7 @@ from datumaro.util import is_method_redefined from datumaro.util.os_util import walk -from datumaro.util.scope import Scope, on_error_do, scoped, on_exit_do +from datumaro.util.scope import Scope, on_error_do, on_exit_do, scoped from datumaro.util.test_utils import TestDir from .requirements import Requirements, mark_requirement From e75c7abdec72a1712b5f2ae20e0fe6b37174b2c5 Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Thu, 2 Sep 2021 19:30:14 +0300 Subject: [PATCH 05/12] update changelog --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4102485ca9..5e52f9a812 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Annotation-related classes were moved into a new module, `datumaro.components.annotation` () +- Rollback utilities replaced with Scope utilities + () ### Deprecated - TBD From afc658a6bd6e13324deb5b68458f76158347ff57 Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Fri, 3 Sep 2021 14:19:21 +0300 Subject: [PATCH 06/12] Fix add() type anns in Scope --- datumaro/util/scope.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/datumaro/util/scope.py b/datumaro/util/scope.py index 7fff715166..1a86732e28 100644 --- a/datumaro/util/scope.py +++ b/datumaro/util/scope.py @@ -4,7 +4,7 @@ from contextlib import ExitStack, contextmanager from functools import partial, wraps -from typing import Any, ContextManager, Dict, Optional +from typing import Any, ContextManager, Dict, Optional, TypeVar import threading from attr import attrs @@ -12,6 +12,8 @@ from datumaro.util import optional_arg_decorator +T = TypeVar('T') + class Scope: """ A context manager that allows to register error and exit callbacks. @@ -83,9 +85,11 @@ def on_exit_do(self, callback, *args, name: Optional[str] = None, self._exit_stack.callback(handler) return name - def add(self, cm: ContextManager) -> Any: + def add(self, cm: ContextManager[T]) -> T: """ Enters a context manager and adds it to the exit stack. + + Returns: cm.__enter__() result """ return self._exit_stack.enter_context(cm) @@ -152,17 +156,19 @@ def wrapped_func(*args, **kwargs): return wrapped_func # Shorthands for common cases -def on_error_do(callback, *args, ignore_errors=False): +def on_error_do(callback, *args, ignore_errors=False, + fwd_kwargs=None, **kwargs): return Scope.current().on_error_do(callback, *args, - ignore_errors=ignore_errors) + ignore_errors=ignore_errors, fwd_kwargs=fwd_kwargs, **kwargs) on_error_do.__doc__ = Scope.on_error_do.__doc__ -def on_exit_do(callback, *args, ignore_errors=False): +def on_exit_do(callback, *args, ignore_errors=False, + fwd_kwargs=None, **kwargs): return Scope.current().on_exit_do(callback, *args, - ignore_errors=ignore_errors) + ignore_errors=ignore_errors, fwd_kwargs=fwd_kwargs, **kwargs) on_exit_do.__doc__ = Scope.on_exit_do.__doc__ -def add(cm: ContextManager): +def add(cm: ContextManager[T]) -> T: return Scope.current().add(cm) add.__doc__ = Scope.add.__doc__ From 75e80c55a224906bcd8872b8f8a166f68836f084 Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Fri, 3 Sep 2021 15:36:36 +0300 Subject: [PATCH 07/12] Remove current() --- datumaro/util/scope.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/datumaro/util/scope.py b/datumaro/util/scope.py index 1a86732e28..5839c3522b 100644 --- a/datumaro/util/scope.py +++ b/datumaro/util/scope.py @@ -171,7 +171,3 @@ def on_exit_do(callback, *args, ignore_errors=False, def add(cm: ContextManager[T]) -> T: return Scope.current().add(cm) add.__doc__ = Scope.add.__doc__ - -def current(): - return Scope.current() -current.__doc__ = Scope.current.__doc__ From 1a21d483a3841be8d5f45d38420fbaf409ef42c6 Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Fri, 3 Sep 2021 16:22:35 +0300 Subject: [PATCH 08/12] Refactor scope --- datumaro/util/scope.py | 128 +++++++++++++++++------------------------ 1 file changed, 53 insertions(+), 75 deletions(-) diff --git a/datumaro/util/scope.py b/datumaro/util/scope.py index 5839c3522b..793a68562b 100644 --- a/datumaro/util/scope.py +++ b/datumaro/util/scope.py @@ -4,7 +4,7 @@ from contextlib import ExitStack, contextmanager from functools import partial, wraps -from typing import Any, ContextManager, Dict, Optional, TypeVar +from typing import Any, Callable, ContextManager, Dict, Optional, Tuple, TypeVar import threading from attr import attrs @@ -22,68 +22,61 @@ class Scope: _thread_locals = threading.local() @attrs(auto_attribs=True) - class Handler: - callback: Any - enabled: bool = True - ignore_errors: bool = False - - def __call__(self): - if self.enabled: - try: - self.callback() - except: # pylint: disable=bare-except - if not self.ignore_errors: - raise + class ExitHandler: + callback: Callable[[], Any] + ignore_errors: bool = True + + def __exit__(self, exc_type=None, exc_value=None, exc_traceback=None): + try: + self.callback() + except Exception: + if not self.ignore_errors: + raise + + @attrs + class ErrorHandler(ExitHandler): + def __exit__(self, exc_type=None, exc_value=None, exc_traceback=None): + if exc_type: + return super().__exit__(exc_type=exc_type, exc_value=exc_value, + exc_traceback=exc_traceback) + def __init__(self): - self._handlers = {} - self._error_stack = ExitStack() - self._exit_stack = ExitStack() + self._stack = ExitStack() self.enabled = True - def on_error_do(self, callback, *args, name: Optional[str] = None, - enabled: bool = True, ignore_errors: bool = False, - fwd_kwargs: Optional[Dict[str, Any]] = None, **kwargs) -> str: + def on_error_do(self, callback: Callable, + *args, kwargs: Optional[Dict[str, Any]] = None, + ignore_errors: bool = False): """ Registers a function to be called on scope exit because of an error. - Equivalent to the "except" block of "try-except". - """ - if args or kwargs or fwd_kwargs: - if fwd_kwargs: - kwargs.update(fwd_kwargs) - callback = partial(callback, *args, **kwargs) - - name = name or hash(callback) - assert name not in self._handlers, "Callback is already registered" + If ignore_errors is True, the errors from this function call + will be ignored. + """ - handler = self.Handler(callback, - enabled=enabled, ignore_errors=ignore_errors) - self._handlers[name] = handler - self._error_stack.callback(handler) - return name + return self._register_callback(self.ErrorHandler, + ignore_errors=ignore_errors, + callback=callback, args=args, kwargs=kwargs) - def on_exit_do(self, callback, *args, name: Optional[str] = None, - enabled: bool = True, ignore_errors: bool = False, - fwd_kwargs: Optional[Dict[str, Any]] = None, **kwargs) -> str: + def on_exit_do(self, callback: Callable, + *args, kwargs: Optional[Dict[str, Any]] = None, + ignore_errors: bool = False): """ - Registers a function to be called on scope exit unconditionally. - Equivalent to the "finally" block of "try-except". + Registers a function to be called on scope exit. """ - if args or kwargs or fwd_kwargs: - if fwd_kwargs: - kwargs.update(fwd_kwargs) - callback = partial(callback, *args, **kwargs) + return self._register_callback(self.ExitHandler, + ignore_errors=ignore_errors, + callback=callback, args=args, kwargs=kwargs) - name = name or hash(callback) - assert name not in self._handlers, "Callback is already registered" + def _register_callback(self, handler_type, callback: Callable, + args: Tuple[Any] = None, kwargs: Dict[str, Any] = None, + ignore_errors: bool = False): + if args or kwargs: + callback = partial(callback, *args, **(kwargs or {})) - handler = self.Handler(callback, - enabled=enabled, ignore_errors=ignore_errors) - self._handlers[name] = handler - self._exit_stack.callback(handler) - return name + self._stack.push(handler_type(callback, ignore_errors=ignore_errors)) def add(self, cm: ContextManager[T]) -> T: """ @@ -92,21 +85,15 @@ def add(self, cm: ContextManager[T]) -> T: Returns: cm.__enter__() result """ - return self._exit_stack.enter_context(cm) + return self._stack.enter_context(cm) - def enable(self, name=None): - if name: - self._handlers[name].enabled = True - else: - self.enabled = True + def enable(self): + self.enabled = True - def disable(self, name=None): - if name: - self._handlers[name].enabled = False - else: - self.enabled = False + def disable(self): + self.enabled = False - def clean(self): + def close(self): self.__exit__() def __enter__(self): @@ -116,11 +103,7 @@ def __exit__(self, exc_type=None, exc_value=None, exc_traceback=None): if not self.enabled: return - try: - if exc_type: - self._error_stack.__exit__(exc_type, exc_value, exc_traceback) - finally: - self._exit_stack.__exit__(exc_type, exc_value, exc_traceback) + self._stack.__exit__(exc_type, exc_value, exc_traceback) @classmethod def current(cls) -> 'Scope': @@ -156,18 +139,13 @@ def wrapped_func(*args, **kwargs): return wrapped_func # Shorthands for common cases -def on_error_do(callback, *args, ignore_errors=False, - fwd_kwargs=None, **kwargs): +def on_error_do(callback, *args, ignore_errors=False, kwargs=None): return Scope.current().on_error_do(callback, *args, - ignore_errors=ignore_errors, fwd_kwargs=fwd_kwargs, **kwargs) -on_error_do.__doc__ = Scope.on_error_do.__doc__ + ignore_errors=ignore_errors, kwargs=kwargs) -def on_exit_do(callback, *args, ignore_errors=False, - fwd_kwargs=None, **kwargs): +def on_exit_do(callback, *args, ignore_errors=False, kwargs=None): return Scope.current().on_exit_do(callback, *args, - ignore_errors=ignore_errors, fwd_kwargs=fwd_kwargs, **kwargs) -on_exit_do.__doc__ = Scope.on_exit_do.__doc__ + ignore_errors=ignore_errors, kwargs=kwargs) def add(cm: ContextManager[T]) -> T: return Scope.current().add(cm) -add.__doc__ = Scope.add.__doc__ From 350308d661c69e52cf4b4e8723ea0fe2a16f8805 Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Fri, 3 Sep 2021 16:33:14 +0300 Subject: [PATCH 09/12] Update tests --- tests/test_util.py | 37 ++++++++++++------------------------- 1 file changed, 12 insertions(+), 25 deletions(-) diff --git a/tests/test_util.py b/tests/test_util.py index 4ddc45cb2e..66a791df76 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -14,13 +14,13 @@ class TestException(Exception): pass -class TestScope(TestCase): +class ScopeTest(TestCase): @mark_requirement(Requirements.DATUM_GENERAL_REQ) - def test_calls_on_no_error(self): + def test_calls_only_exit_callback_on_exit(self): error_cb = mock.MagicMock() exit_cb = mock.MagicMock() - with suppress(TestException), Scope() as scope: + with Scope() as scope: scope.on_error_do(error_cb) scope.on_exit_do(exit_cb) @@ -28,27 +28,18 @@ def test_calls_on_no_error(self): exit_cb.assert_called_once() @mark_requirement(Requirements.DATUM_GENERAL_REQ) - def test_calls_both_stacks_on_error(self): + def test_calls_both_callbacks_on_error(self): error_cb = mock.MagicMock() exit_cb = mock.MagicMock() - with suppress(TestException), Scope() as scope: + with self.assertRaises(TestException), Scope() as scope: scope.on_error_do(error_cb) scope.on_exit_do(exit_cb) - raise TestException('err') + raise TestException() error_cb.assert_called_once() exit_cb.assert_called_once() - @mark_requirement(Requirements.DATUM_GENERAL_REQ) - def test_cant_add_single_callback_in_both_stacks(self): - cb = mock.MagicMock() - - with self.assertRaisesRegex(AssertionError, "already registered"): - with Scope() as scope: - scope.on_error_do(cb) - scope.on_exit_do(cb) - @mark_requirement(Requirements.DATUM_GENERAL_REQ) def test_adds_cm(self): cm = mock.Mock() @@ -82,7 +73,7 @@ def test_decorator_calls_on_error(self): @scoped('scope') def foo(scope=None): scope.on_error_do(cb) - raise TestException('err') + raise TestException() with suppress(TestException): foo() @@ -113,7 +104,7 @@ def test_decorator_supports_implicit_form(self): def foo(): on_error_do(error_cb) on_exit_do(exit_cb) - raise TestException('err') + raise TestException() with suppress(TestException): foo() @@ -123,17 +114,13 @@ def foo(): @mark_requirement(Requirements.DATUM_GENERAL_REQ) def test_can_fowrard_args(self): - cb1 = mock.MagicMock() - cb2 = mock.MagicMock() + cb = mock.MagicMock() with suppress(TestException), Scope() as scope: - scope.on_error_do(cb1, 5, a2=2, ignore_errors=True) - scope.on_error_do(cb2, 5, a2=2, ignore_errors=True, - fwd_kwargs={'ignore_errors': 4}) - raise TestException('err') + scope.on_error_do(cb, 5, ignore_errors=True, kwargs={'a2': 2}) + raise TestException() - cb1.assert_called_once_with(5, a2=2) - cb2.assert_called_once_with(5, a2=2, ignore_errors=4) + cb.assert_called_once_with(5, a2=2) @mark_requirement(Requirements.DATUM_GENERAL_REQ) def test_decorator_can_return_on_success_in_implicit_form(self): From 4582adc908c5e07c7ce0cd5c9959e00b898bd0e8 Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Fri, 3 Sep 2021 16:36:46 +0300 Subject: [PATCH 10/12] Rename add to scope_add --- datumaro/util/scope.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datumaro/util/scope.py b/datumaro/util/scope.py index 793a68562b..76de253f1b 100644 --- a/datumaro/util/scope.py +++ b/datumaro/util/scope.py @@ -147,5 +147,5 @@ def on_exit_do(callback, *args, ignore_errors=False, kwargs=None): return Scope.current().on_exit_do(callback, *args, ignore_errors=ignore_errors, kwargs=kwargs) -def add(cm: ContextManager[T]) -> T: +def scope_add(cm: ContextManager[T]) -> T: return Scope.current().add(cm) From 49329516bf3235a9f61798959189ca9884e5fd55 Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Fri, 3 Sep 2021 16:42:28 +0300 Subject: [PATCH 11/12] fix imports --- datumaro/util/scope.py | 1 - 1 file changed, 1 deletion(-) diff --git a/datumaro/util/scope.py b/datumaro/util/scope.py index 76de253f1b..c15188d39c 100644 --- a/datumaro/util/scope.py +++ b/datumaro/util/scope.py @@ -11,7 +11,6 @@ from datumaro.util import optional_arg_decorator - T = TypeVar('T') class Scope: From c6d7dbf9b9c5eee7ee0127d9893bbfe2114dff94 Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Fri, 3 Sep 2021 20:40:33 +0300 Subject: [PATCH 12/12] Code cleanings --- datumaro/util/scope.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/datumaro/util/scope.py b/datumaro/util/scope.py index c15188d39c..a857882ca5 100644 --- a/datumaro/util/scope.py +++ b/datumaro/util/scope.py @@ -25,7 +25,7 @@ class ExitHandler: callback: Callable[[], Any] ignore_errors: bool = True - def __exit__(self, exc_type=None, exc_value=None, exc_traceback=None): + def __exit__(self, exc_type, exc_value, exc_traceback): try: self.callback() except Exception: @@ -34,7 +34,7 @@ def __exit__(self, exc_type=None, exc_value=None, exc_traceback=None): @attrs class ErrorHandler(ExitHandler): - def __exit__(self, exc_type=None, exc_value=None, exc_traceback=None): + def __exit__(self, exc_type, exc_value, exc_traceback): if exc_type: return super().__exit__(exc_type=exc_type, exc_value=exc_value, exc_traceback=exc_traceback) @@ -54,7 +54,7 @@ def on_error_do(self, callback: Callable, will be ignored. """ - return self._register_callback(self.ErrorHandler, + self._register_callback(self.ErrorHandler, ignore_errors=ignore_errors, callback=callback, args=args, kwargs=kwargs) @@ -65,7 +65,7 @@ def on_exit_do(self, callback: Callable, Registers a function to be called on scope exit. """ - return self._register_callback(self.ExitHandler, + self._register_callback(self.ExitHandler, ignore_errors=ignore_errors, callback=callback, args=args, kwargs=kwargs) @@ -93,16 +93,17 @@ def disable(self): self.enabled = False def close(self): - self.__exit__() + self.__exit__(None, None, None) - def __enter__(self): + def __enter__(self) -> 'Scope': return self - def __exit__(self, exc_type=None, exc_value=None, exc_traceback=None): + def __exit__(self, exc_type, exc_value, exc_traceback): if not self.enabled: return self._stack.__exit__(exc_type, exc_value, exc_traceback) + self._stack.pop_all() # prevent issues on repetitive calls @classmethod def current(cls) -> 'Scope':