Skip to content

Commit

Permalink
Replace Rollback with Scope (#444)
Browse files Browse the repository at this point in the history
* Replace rollback with scope

* Add scope_add() function

* Update tests

* Replace rollback uses

* Update changelog
  • Loading branch information
Maxim Zhiltsov authored Sep 6, 2021
1 parent df4a0d6 commit 7f2ca57
Show file tree
Hide file tree
Showing 8 changed files with 251 additions and 194 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Annotation-related classes were moved into a new module,
`datumaro.components.annotation`
(<https://github.com/openvinotoolkit/datumaro/pull/439>)
- Rollback utilities replaced with Scope utilities
(<https://github.com/openvinotoolkit/datumaro/pull/444>)

### Deprecated
- TBD
Expand Down
4 changes: 2 additions & 2 deletions datumaro/cli/contexts/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import shutil

from datumaro.components.project import Environment
from datumaro.util import error_rollback, on_error_do
from datumaro.util.scope import on_error_do, scoped

from ..util import CliException, MultilineFormatter, add_subparser
from ..util.project import (
Expand Down Expand Up @@ -46,7 +46,7 @@ def build_add_parser(parser_ctor=argparse.ArgumentParser):

return parser

@error_rollback
@scoped
def add_command(args):
project = load_project(args.project_dir)

Expand Down
4 changes: 2 additions & 2 deletions datumaro/cli/contexts/project/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
from datumaro.components.project import PROJECT_DEFAULT_CONFIG as DEFAULT_CONFIG
from datumaro.components.project import Environment, Project
from datumaro.components.validator import TaskType
from datumaro.util import error_rollback, on_error_do
from datumaro.util.os_util import make_file_name
from datumaro.util.scope import on_error_do, scoped

from ...util import CliException, MultilineFormatter, add_subparser
from ...util.project import generate_next_file_name, load_project
Expand Down Expand Up @@ -526,7 +526,7 @@ def build_diff_parser(parser_ctor=argparse.ArgumentParser):

return parser

@error_rollback
@scoped
def diff_command(args):
first_project = load_project(args.project_dir)
second_project = load_project(args.other_project_dir)
Expand Down
4 changes: 2 additions & 2 deletions datumaro/components/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from datumaro.components.cli_plugin import CliPlugin
from datumaro.components.dataset import DatasetPatch
from datumaro.components.extractor import DatasetItem
from datumaro.util import error_rollback, on_error_do
from datumaro.util.image import Image
from datumaro.util.scope import on_error_do, scoped


class Converter(CliPlugin):
Expand All @@ -36,7 +36,7 @@ def convert(cls, extractor, save_dir, **options):
return converter.apply()

@classmethod
@error_rollback
@scoped
def patch(cls, dataset, patch, save_dir, **options):
# This solution is not any better in performance than just
# writing a dataset, but in case of patching (i.e. writing
Expand Down
5 changes: 3 additions & 2 deletions datumaro/components/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@
DEFAULT_SUBSET_NAME, CategoriesInfo, DatasetItem, Extractor, IExtractor,
ItemTransform, Transform,
)
from datumaro.util import error_rollback, is_method_redefined, on_error_do
from datumaro.util import is_method_redefined
from datumaro.util.log_utils import logging_disabled
from datumaro.util.scope import on_error_do, scoped

DEFAULT_FORMAT = 'datumaro'

Expand Down Expand Up @@ -790,7 +791,7 @@ def bind(self, path: str, format: Optional[str] = None, *,
def flush_changes(self):
self._data.flush_changes()

@error_rollback
@scoped
def export(self, save_dir: str, format, **kwargs):
inplace = (save_dir == self._source_path and format == self._format)

Expand Down
102 changes: 1 addition & 101 deletions datumaro/util/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,11 @@
#
# SPDX-License-Identifier: MIT

from contextlib import ExitStack, contextmanager
from functools import partial, wraps
from functools import wraps
from inspect import isclass
from itertools import islice
from typing import Iterable, Tuple
import distutils.util
import threading

import attr

NOTSET = object()

Expand Down Expand Up @@ -126,99 +122,3 @@ def real_decorator(decoratee):
return real_decorator

return wrapped_decorator

class Rollback:
_thread_locals = threading.local()

@attr.attrs
class Handler:
callback = attr.attrib()
enabled = attr.attrib(default=True)
ignore_errors = attr.attrib(default=False)

def __call__(self):
if self.enabled:
try:
self.callback()
except: # pylint: disable=bare-except
if not self.ignore_errors:
raise

def __init__(self):
self._handlers = {}
self._stack = ExitStack()
self.enabled = True

def add(self, callback, *args,
name=None, enabled=True, ignore_errors=False,
fwd_kwargs=None, **kwargs):
if args or kwargs or fwd_kwargs:
if fwd_kwargs:
kwargs.update(fwd_kwargs)
callback = partial(callback, *args, **kwargs)
name = name or hash(callback)
assert name not in self._handlers
handler = self.Handler(callback,
enabled=enabled, ignore_errors=ignore_errors)
self._handlers[name] = handler
self._stack.callback(handler)
return name

do = add # readability alias

def enable(self, name=None):
if name:
self._handlers[name].enabled = True
else:
self.enabled = True

def disable(self, name=None):
if name:
self._handlers[name].enabled = False
else:
self.enabled = False

def clean(self):
self.__exit__(None, None, None)

def __enter__(self):
return self

def __exit__(self, type=None, value=None, \
traceback=None): # pylint: disable=redefined-builtin
if type is None:
return
if not self.enabled:
return
self._stack.__exit__(type, value, traceback)

@classmethod
def current(cls) -> "Rollback":
return cls._thread_locals.current

@contextmanager
def as_current(self):
previous = getattr(self._thread_locals, 'current', None)
self._thread_locals.current = self
try:
yield
finally:
self._thread_locals.current = previous

# shorthand for common cases
def on_error_do(callback, *args, ignore_errors=False):
Rollback.current().do(callback, *args, ignore_errors=ignore_errors)

@optional_arg_decorator
def error_rollback(func, arg_name=None):
@wraps(func)
def wrapped_func(*args, **kwargs):
with Rollback() as manager:
if arg_name is None:
with manager.as_current():
ret_val = func(*args, **kwargs)
else:
kwargs[arg_name] = manager
ret_val = func(*args, **kwargs)
return ret_val
return wrapped_func
151 changes: 151 additions & 0 deletions datumaro/util/scope.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
# Copyright (C) 2021 Intel Corporation
#
# SPDX-License-Identifier: MIT

from contextlib import ExitStack, contextmanager
from functools import partial, wraps
from typing import Any, Callable, ContextManager, Dict, Optional, Tuple, TypeVar
import threading

from attr import attrs

from datumaro.util import optional_arg_decorator

T = TypeVar('T')

class Scope:
"""
A context manager that allows to register error and exit callbacks.
"""

_thread_locals = threading.local()

@attrs(auto_attribs=True)
class ExitHandler:
callback: Callable[[], Any]
ignore_errors: bool = True

def __exit__(self, exc_type, exc_value, exc_traceback):
try:
self.callback()
except Exception:
if not self.ignore_errors:
raise

@attrs
class ErrorHandler(ExitHandler):
def __exit__(self, exc_type, exc_value, exc_traceback):
if exc_type:
return super().__exit__(exc_type=exc_type, exc_value=exc_value,
exc_traceback=exc_traceback)


def __init__(self):
self._stack = ExitStack()
self.enabled = True

def on_error_do(self, callback: Callable,
*args, kwargs: Optional[Dict[str, Any]] = None,
ignore_errors: bool = False):
"""
Registers a function to be called on scope exit because of an error.
If ignore_errors is True, the errors from this function call
will be ignored.
"""

self._register_callback(self.ErrorHandler,
ignore_errors=ignore_errors,
callback=callback, args=args, kwargs=kwargs)

def on_exit_do(self, callback: Callable,
*args, kwargs: Optional[Dict[str, Any]] = None,
ignore_errors: bool = False):
"""
Registers a function to be called on scope exit.
"""

self._register_callback(self.ExitHandler,
ignore_errors=ignore_errors,
callback=callback, args=args, kwargs=kwargs)

def _register_callback(self, handler_type, callback: Callable,
args: Tuple[Any] = None, kwargs: Dict[str, Any] = None,
ignore_errors: bool = False):
if args or kwargs:
callback = partial(callback, *args, **(kwargs or {}))

self._stack.push(handler_type(callback, ignore_errors=ignore_errors))

def add(self, cm: ContextManager[T]) -> T:
"""
Enters a context manager and adds it to the exit stack.
Returns: cm.__enter__() result
"""

return self._stack.enter_context(cm)

def enable(self):
self.enabled = True

def disable(self):
self.enabled = False

def close(self):
self.__exit__(None, None, None)

def __enter__(self) -> 'Scope':
return self

def __exit__(self, exc_type, exc_value, exc_traceback):
if not self.enabled:
return

self._stack.__exit__(exc_type, exc_value, exc_traceback)
self._stack.pop_all() # prevent issues on repetitive calls

@classmethod
def current(cls) -> 'Scope':
return cls._thread_locals.current

@contextmanager
def as_current(self):
previous = getattr(self._thread_locals, 'current', None)
self._thread_locals.current = self
try:
yield
finally:
self._thread_locals.current = previous

@optional_arg_decorator
def scoped(func, arg_name=None):
"""
A function decorator, which allows to do actions with the current scope,
such as registering error and exit callbacks and context managers.
"""

@wraps(func)
def wrapped_func(*args, **kwargs):
with Scope() as scope:
if arg_name is None:
with scope.as_current():
ret_val = func(*args, **kwargs)
else:
kwargs[arg_name] = scope
ret_val = func(*args, **kwargs)
return ret_val

return wrapped_func

# Shorthands for common cases
def on_error_do(callback, *args, ignore_errors=False, kwargs=None):
return Scope.current().on_error_do(callback, *args,
ignore_errors=ignore_errors, kwargs=kwargs)

def on_exit_do(callback, *args, ignore_errors=False, kwargs=None):
return Scope.current().on_exit_do(callback, *args,
ignore_errors=ignore_errors, kwargs=kwargs)

def scope_add(cm: ContextManager[T]) -> T:
return Scope.current().add(cm)
Loading

0 comments on commit 7f2ca57

Please sign in to comment.