Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace Rollback with Scope #444

Merged
merged 12 commits into from
Sep 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Annotation-related classes were moved into a new module,
`datumaro.components.annotation`
(<https://github.com/openvinotoolkit/datumaro/pull/439>)
- Rollback utilities replaced with Scope utilities
(<https://github.com/openvinotoolkit/datumaro/pull/444>)

### Deprecated
- TBD
Expand Down
4 changes: 2 additions & 2 deletions datumaro/cli/contexts/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import shutil

from datumaro.components.project import Environment
from datumaro.util import error_rollback, on_error_do
from datumaro.util.scope import on_error_do, scoped

from ..util import CliException, MultilineFormatter, add_subparser
from ..util.project import (
Expand Down Expand Up @@ -46,7 +46,7 @@ def build_add_parser(parser_ctor=argparse.ArgumentParser):

return parser

@error_rollback
@scoped
def add_command(args):
project = load_project(args.project_dir)

Expand Down
4 changes: 2 additions & 2 deletions datumaro/cli/contexts/project/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
from datumaro.components.project import PROJECT_DEFAULT_CONFIG as DEFAULT_CONFIG
from datumaro.components.project import Environment, Project
from datumaro.components.validator import TaskType
from datumaro.util import error_rollback, on_error_do
from datumaro.util.os_util import make_file_name
from datumaro.util.scope import on_error_do, scoped

from ...util import CliException, MultilineFormatter, add_subparser
from ...util.project import generate_next_file_name, load_project
Expand Down Expand Up @@ -526,7 +526,7 @@ def build_diff_parser(parser_ctor=argparse.ArgumentParser):

return parser

@error_rollback
@scoped
def diff_command(args):
first_project = load_project(args.project_dir)
second_project = load_project(args.other_project_dir)
Expand Down
4 changes: 2 additions & 2 deletions datumaro/components/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from datumaro.components.cli_plugin import CliPlugin
from datumaro.components.dataset import DatasetPatch
from datumaro.components.extractor import DatasetItem
from datumaro.util import error_rollback, on_error_do
from datumaro.util.image import Image
from datumaro.util.scope import on_error_do, scoped


class Converter(CliPlugin):
Expand All @@ -36,7 +36,7 @@ def convert(cls, extractor, save_dir, **options):
return converter.apply()

@classmethod
@error_rollback
@scoped
def patch(cls, dataset, patch, save_dir, **options):
# This solution is not any better in performance than just
# writing a dataset, but in case of patching (i.e. writing
Expand Down
5 changes: 3 additions & 2 deletions datumaro/components/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@
DEFAULT_SUBSET_NAME, CategoriesInfo, DatasetItem, Extractor, IExtractor,
ItemTransform, Transform,
)
from datumaro.util import error_rollback, is_method_redefined, on_error_do
from datumaro.util import is_method_redefined
from datumaro.util.log_utils import logging_disabled
from datumaro.util.scope import on_error_do, scoped

DEFAULT_FORMAT = 'datumaro'

Expand Down Expand Up @@ -790,7 +791,7 @@ def bind(self, path: str, format: Optional[str] = None, *,
def flush_changes(self):
self._data.flush_changes()

@error_rollback
@scoped
def export(self, save_dir: str, format, **kwargs):
inplace = (save_dir == self._source_path and format == self._format)

Expand Down
102 changes: 1 addition & 101 deletions datumaro/util/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,11 @@
#
# SPDX-License-Identifier: MIT

from contextlib import ExitStack, contextmanager
from functools import partial, wraps
from functools import wraps
from inspect import isclass
from itertools import islice
from typing import Iterable, Tuple
import distutils.util
import threading

import attr

NOTSET = object()

Expand Down Expand Up @@ -126,99 +122,3 @@ def real_decorator(decoratee):
return real_decorator

return wrapped_decorator

class Rollback:
_thread_locals = threading.local()

@attr.attrs
class Handler:
callback = attr.attrib()
enabled = attr.attrib(default=True)
ignore_errors = attr.attrib(default=False)

def __call__(self):
if self.enabled:
try:
self.callback()
except: # pylint: disable=bare-except
if not self.ignore_errors:
raise

def __init__(self):
self._handlers = {}
self._stack = ExitStack()
self.enabled = True

def add(self, callback, *args,
name=None, enabled=True, ignore_errors=False,
fwd_kwargs=None, **kwargs):
if args or kwargs or fwd_kwargs:
if fwd_kwargs:
kwargs.update(fwd_kwargs)
callback = partial(callback, *args, **kwargs)
name = name or hash(callback)
assert name not in self._handlers
handler = self.Handler(callback,
enabled=enabled, ignore_errors=ignore_errors)
self._handlers[name] = handler
self._stack.callback(handler)
return name

do = add # readability alias

def enable(self, name=None):
if name:
self._handlers[name].enabled = True
else:
self.enabled = True

def disable(self, name=None):
if name:
self._handlers[name].enabled = False
else:
self.enabled = False

def clean(self):
self.__exit__(None, None, None)

def __enter__(self):
return self

def __exit__(self, type=None, value=None, \
traceback=None): # pylint: disable=redefined-builtin
if type is None:
return
if not self.enabled:
return
self._stack.__exit__(type, value, traceback)

@classmethod
def current(cls) -> "Rollback":
return cls._thread_locals.current

@contextmanager
def as_current(self):
previous = getattr(self._thread_locals, 'current', None)
self._thread_locals.current = self
try:
yield
finally:
self._thread_locals.current = previous

# shorthand for common cases
def on_error_do(callback, *args, ignore_errors=False):
Rollback.current().do(callback, *args, ignore_errors=ignore_errors)

@optional_arg_decorator
def error_rollback(func, arg_name=None):
@wraps(func)
def wrapped_func(*args, **kwargs):
with Rollback() as manager:
if arg_name is None:
with manager.as_current():
ret_val = func(*args, **kwargs)
else:
kwargs[arg_name] = manager
ret_val = func(*args, **kwargs)
return ret_val
return wrapped_func
151 changes: 151 additions & 0 deletions datumaro/util/scope.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
# Copyright (C) 2021 Intel Corporation
#
# SPDX-License-Identifier: MIT

from contextlib import ExitStack, contextmanager
from functools import partial, wraps
from typing import Any, Callable, ContextManager, Dict, Optional, Tuple, TypeVar
import threading

from attr import attrs

from datumaro.util import optional_arg_decorator

T = TypeVar('T')

class Scope:
"""
A context manager that allows to register error and exit callbacks.
"""

_thread_locals = threading.local()

@attrs(auto_attribs=True)
class ExitHandler:
callback: Callable[[], Any]
ignore_errors: bool = True

def __exit__(self, exc_type, exc_value, exc_traceback):
try:
self.callback()
except Exception:
if not self.ignore_errors:
raise

@attrs
class ErrorHandler(ExitHandler):
def __exit__(self, exc_type, exc_value, exc_traceback):
if exc_type:
return super().__exit__(exc_type=exc_type, exc_value=exc_value,
exc_traceback=exc_traceback)


def __init__(self):
self._stack = ExitStack()
self.enabled = True

def on_error_do(self, callback: Callable,
*args, kwargs: Optional[Dict[str, Any]] = None,
ignore_errors: bool = False):
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
ignore_errors: bool = False):
ignore_errors: bool = False) -> None:

(and ditto for the other None-returning functions)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the point in this?

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To help typecheckers find errors.

Copy link
Contributor Author

@zhiltsov-max zhiltsov-max Sep 3, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So you suggest to add it to all functions without return value?

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would say yes (if there are annotations on the function's arguments). Otherwise, the return type defaults to Any.

Copy link
Contributor Author

@zhiltsov-max zhiltsov-max Sep 6, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not convinced it is worthy investment of efforts in most cases. I'm open for a PR on this, but I don't think it is needed for the functions, which names don't not suppose any return value.

"""
Registers a function to be called on scope exit because of an error.

If ignore_errors is True, the errors from this function call
will be ignored.
"""

self._register_callback(self.ErrorHandler,
ignore_errors=ignore_errors,
callback=callback, args=args, kwargs=kwargs)

def on_exit_do(self, callback: Callable,
*args, kwargs: Optional[Dict[str, Any]] = None,
ignore_errors: bool = False):
"""
Registers a function to be called on scope exit.
"""

self._register_callback(self.ExitHandler,
ignore_errors=ignore_errors,
callback=callback, args=args, kwargs=kwargs)

def _register_callback(self, handler_type, callback: Callable,
args: Tuple[Any] = None, kwargs: Dict[str, Any] = None,
ignore_errors: bool = False):
if args or kwargs:
callback = partial(callback, *args, **(kwargs or {}))

self._stack.push(handler_type(callback, ignore_errors=ignore_errors))

def add(self, cm: ContextManager[T]) -> T:
"""
Enters a context manager and adds it to the exit stack.

Returns: cm.__enter__() result
"""

return self._stack.enter_context(cm)

def enable(self):
self.enabled = True

def disable(self):
self.enabled = False

def close(self):
self.__exit__(None, None, None)

def __enter__(self) -> 'Scope':
return self

def __exit__(self, exc_type, exc_value, exc_traceback):
if not self.enabled:
return

self._stack.__exit__(exc_type, exc_value, exc_traceback)
self._stack.pop_all() # prevent issues on repetitive calls

@classmethod
def current(cls) -> 'Scope':
return cls._thread_locals.current

@contextmanager
def as_current(self):
previous = getattr(self._thread_locals, 'current', None)
self._thread_locals.current = self
try:
yield
finally:
self._thread_locals.current = previous

@optional_arg_decorator
def scoped(func, arg_name=None):
"""
A function decorator, which allows to do actions with the current scope,
such as registering error and exit callbacks and context managers.
"""

@wraps(func)
def wrapped_func(*args, **kwargs):
with Scope() as scope:
if arg_name is None:
with scope.as_current():
ret_val = func(*args, **kwargs)
else:
kwargs[arg_name] = scope
ret_val = func(*args, **kwargs)
return ret_val

return wrapped_func

# Shorthands for common cases
def on_error_do(callback, *args, ignore_errors=False, kwargs=None):
return Scope.current().on_error_do(callback, *args,
ignore_errors=ignore_errors, kwargs=kwargs)

def on_exit_do(callback, *args, ignore_errors=False, kwargs=None):
return Scope.current().on_exit_do(callback, *args,
ignore_errors=ignore_errors, kwargs=kwargs)

def scope_add(cm: ContextManager[T]) -> T:
return Scope.current().add(cm)
Loading