From 9b1c6875e7227f0ad8e434b84047ed3985a33e66 Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Thu, 2 Sep 2021 18:19:03 +0300 Subject: [PATCH 01/19] Replace rollback with scope --- datumaro/util/__init__.py | 102 +---------------------- datumaro/util/scope.py | 171 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 172 insertions(+), 101 deletions(-) create mode 100644 datumaro/util/scope.py diff --git a/datumaro/util/__init__.py b/datumaro/util/__init__.py index 7c8a52e7b1..6a2328b7e1 100644 --- a/datumaro/util/__init__.py +++ b/datumaro/util/__init__.py @@ -2,15 +2,11 @@ # # SPDX-License-Identifier: MIT -from contextlib import ExitStack, contextmanager -from functools import partial, wraps +from functools import wraps from inspect import isclass from itertools import islice from typing import Iterable, Tuple import distutils.util -import threading - -import attr NOTSET = object() @@ -126,99 +122,3 @@ def real_decorator(decoratee): return real_decorator return wrapped_decorator - -class Rollback: - _thread_locals = threading.local() - - @attr.attrs - class Handler: - callback = attr.attrib() - enabled = attr.attrib(default=True) - ignore_errors = attr.attrib(default=False) - - def __call__(self): - if self.enabled: - try: - self.callback() - except: # pylint: disable=bare-except - if not self.ignore_errors: - raise - - def __init__(self): - self._handlers = {} - self._stack = ExitStack() - self.enabled = True - - def add(self, callback, *args, - name=None, enabled=True, ignore_errors=False, - fwd_kwargs=None, **kwargs): - if args or kwargs or fwd_kwargs: - if fwd_kwargs: - kwargs.update(fwd_kwargs) - callback = partial(callback, *args, **kwargs) - name = name or hash(callback) - assert name not in self._handlers - handler = self.Handler(callback, - enabled=enabled, ignore_errors=ignore_errors) - self._handlers[name] = handler - self._stack.callback(handler) - return name - - do = add # readability alias - - def enable(self, name=None): - if name: - self._handlers[name].enabled = True - else: - self.enabled = True - - def disable(self, name=None): - if name: - self._handlers[name].enabled = False - else: - self.enabled = False - - def clean(self): - self.__exit__(None, None, None) - - def __enter__(self): - return self - - def __exit__(self, type=None, value=None, \ - traceback=None): # pylint: disable=redefined-builtin - if type is None: - return - if not self.enabled: - return - self._stack.__exit__(type, value, traceback) - - @classmethod - def current(cls) -> "Rollback": - return cls._thread_locals.current - - @contextmanager - def as_current(self): - previous = getattr(self._thread_locals, 'current', None) - self._thread_locals.current = self - try: - yield - finally: - self._thread_locals.current = previous - -# shorthand for common cases -def on_error_do(callback, *args, ignore_errors=False): - Rollback.current().do(callback, *args, ignore_errors=ignore_errors) - -@optional_arg_decorator -def error_rollback(func, arg_name=None): - @wraps(func) - def wrapped_func(*args, **kwargs): - with Rollback() as manager: - if arg_name is None: - with manager.as_current(): - ret_val = func(*args, **kwargs) - else: - kwargs[arg_name] = manager - ret_val = func(*args, **kwargs) - return ret_val - return wrapped_func diff --git a/datumaro/util/scope.py b/datumaro/util/scope.py new file mode 100644 index 0000000000..7fff715166 --- /dev/null +++ b/datumaro/util/scope.py @@ -0,0 +1,171 @@ +# Copyright (C) 2021 Intel Corporation +# +# SPDX-License-Identifier: MIT + +from contextlib import ExitStack, contextmanager +from functools import partial, wraps +from typing import Any, ContextManager, Dict, Optional +import threading + +from attr import attrs + +from datumaro.util import optional_arg_decorator + + +class Scope: + """ + A context manager that allows to register error and exit callbacks. + """ + + _thread_locals = threading.local() + + @attrs(auto_attribs=True) + class Handler: + callback: Any + enabled: bool = True + ignore_errors: bool = False + + def __call__(self): + if self.enabled: + try: + self.callback() + except: # pylint: disable=bare-except + if not self.ignore_errors: + raise + + def __init__(self): + self._handlers = {} + self._error_stack = ExitStack() + self._exit_stack = ExitStack() + self.enabled = True + + def on_error_do(self, callback, *args, name: Optional[str] = None, + enabled: bool = True, ignore_errors: bool = False, + fwd_kwargs: Optional[Dict[str, Any]] = None, **kwargs) -> str: + """ + Registers a function to be called on scope exit because of an error. + Equivalent to the "except" block of "try-except". + """ + + if args or kwargs or fwd_kwargs: + if fwd_kwargs: + kwargs.update(fwd_kwargs) + callback = partial(callback, *args, **kwargs) + + name = name or hash(callback) + assert name not in self._handlers, "Callback is already registered" + + handler = self.Handler(callback, + enabled=enabled, ignore_errors=ignore_errors) + self._handlers[name] = handler + self._error_stack.callback(handler) + return name + + def on_exit_do(self, callback, *args, name: Optional[str] = None, + enabled: bool = True, ignore_errors: bool = False, + fwd_kwargs: Optional[Dict[str, Any]] = None, **kwargs) -> str: + """ + Registers a function to be called on scope exit unconditionally. + Equivalent to the "finally" block of "try-except". + """ + + if args or kwargs or fwd_kwargs: + if fwd_kwargs: + kwargs.update(fwd_kwargs) + callback = partial(callback, *args, **kwargs) + + name = name or hash(callback) + assert name not in self._handlers, "Callback is already registered" + + handler = self.Handler(callback, + enabled=enabled, ignore_errors=ignore_errors) + self._handlers[name] = handler + self._exit_stack.callback(handler) + return name + + def add(self, cm: ContextManager) -> Any: + """ + Enters a context manager and adds it to the exit stack. + """ + + return self._exit_stack.enter_context(cm) + + def enable(self, name=None): + if name: + self._handlers[name].enabled = True + else: + self.enabled = True + + def disable(self, name=None): + if name: + self._handlers[name].enabled = False + else: + self.enabled = False + + def clean(self): + self.__exit__() + + def __enter__(self): + return self + + def __exit__(self, exc_type=None, exc_value=None, exc_traceback=None): + if not self.enabled: + return + + try: + if exc_type: + self._error_stack.__exit__(exc_type, exc_value, exc_traceback) + finally: + self._exit_stack.__exit__(exc_type, exc_value, exc_traceback) + + @classmethod + def current(cls) -> 'Scope': + return cls._thread_locals.current + + @contextmanager + def as_current(self): + previous = getattr(self._thread_locals, 'current', None) + self._thread_locals.current = self + try: + yield + finally: + self._thread_locals.current = previous + +@optional_arg_decorator +def scoped(func, arg_name=None): + """ + A function decorator, which allows to do actions with the current scope, + such as registering error and exit callbacks and context managers. + """ + + @wraps(func) + def wrapped_func(*args, **kwargs): + with Scope() as scope: + if arg_name is None: + with scope.as_current(): + ret_val = func(*args, **kwargs) + else: + kwargs[arg_name] = scope + ret_val = func(*args, **kwargs) + return ret_val + + return wrapped_func + +# Shorthands for common cases +def on_error_do(callback, *args, ignore_errors=False): + return Scope.current().on_error_do(callback, *args, + ignore_errors=ignore_errors) +on_error_do.__doc__ = Scope.on_error_do.__doc__ + +def on_exit_do(callback, *args, ignore_errors=False): + return Scope.current().on_exit_do(callback, *args, + ignore_errors=ignore_errors) +on_exit_do.__doc__ = Scope.on_exit_do.__doc__ + +def add(cm: ContextManager): + return Scope.current().add(cm) +add.__doc__ = Scope.add.__doc__ + +def current(): + return Scope.current() +current.__doc__ = Scope.current.__doc__ From 0f823cf92002eea439ee375746e88031faa80bae Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Thu, 2 Sep 2021 18:19:15 +0300 Subject: [PATCH 02/19] Update tests --- tests/test_util.py | 186 ++++++++++++++++++++++++--------------------- 1 file changed, 101 insertions(+), 85 deletions(-) diff --git a/tests/test_util.py b/tests/test_util.py index f124e7cdb2..e900b57ee4 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -1,127 +1,143 @@ +from contextlib import suppress from unittest import TestCase, mock import os import os.path as osp -from datumaro.util import ( - Rollback, error_rollback, is_method_redefined, on_error_do, -) +from datumaro.util import is_method_redefined from datumaro.util.os_util import walk +from datumaro.util.scope import Scope, on_error_do, scoped, on_exit_do from datumaro.util.test_utils import TestDir from .requirements import Requirements, mark_requirement -class TestRollback(TestCase): +class TestException(Exception): + pass + +class TestScope(TestCase): @mark_requirement(Requirements.DATUM_GENERAL_REQ) - def test_does_not_call_on_no_error(self): - success = True - def cb(): - nonlocal success - success = False + def test_calls_on_no_error(self): + error_cb = mock.MagicMock() + exit_cb = mock.MagicMock() - with Rollback() as on_error: - on_error.do(cb) + with suppress(TestException), Scope() as scope: + scope.on_error_do(error_cb) + scope.on_exit_do(exit_cb) - self.assertTrue(success) + error_cb.assert_not_called() + exit_cb.assert_called_once() @mark_requirement(Requirements.DATUM_GENERAL_REQ) - def test_calls_on_error(self): - success = False - def cb(): - nonlocal success - success = True - - try: - with Rollback() as on_error: - on_error.do(cb) - raise Exception('err') - except Exception: # nosec - disable B110:try_except_pass check - pass - finally: - self.assertTrue(success) + def test_calls_both_stacks_on_error(self): + error_cb = mock.MagicMock() + exit_cb = mock.MagicMock() + + with suppress(TestException), Scope() as scope: + scope.on_error_do(error_cb) + scope.on_exit_do(exit_cb) + raise TestException('err') + + error_cb.assert_called_once() + exit_cb.assert_called_once() + + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + def test_cant_add_single_callback_in_both_stacks(self): + cb = mock.MagicMock() + + with self.assertRaisesRegex(AssertionError, "already registered"): + with Scope() as scope: + scope.on_error_do(cb) + scope.on_exit_do(cb) + + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + def test_adds_cm(self): + cm = mock.Mock() + cm.__enter__ = mock.MagicMock(return_value=42) + cm.__exit__ = mock.MagicMock() + + with Scope() as scope: + retval = scope.add(cm) + + cm.__enter__.assert_called_once() + cm.__exit__.assert_called_once() + self.assertEqual(42, retval) + + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + def test_calls_cm_on_error(self): + cm = mock.Mock() + cm.__enter__ = mock.MagicMock() + cm.__exit__ = mock.MagicMock() + + with suppress(TestException), Scope() as scope: + scope.add(cm) + raise TestException() + + cm.__enter__.assert_called_once() + cm.__exit__.assert_called_once() @mark_requirement(Requirements.DATUM_GENERAL_REQ) def test_decorator_calls_on_error(self): - success = False - def cb(): - nonlocal success - success = True + cb = mock.MagicMock() - @error_rollback('on_error') - def foo(on_error=None): - on_error.do(cb) - raise Exception('err') + @scoped('scope') + def foo(scope=None): + scope.on_error_do(cb) + raise TestException('err') - try: + with suppress(TestException): foo() - except Exception: # nosec - disable B110:try_except_pass check - pass - finally: - self.assertTrue(success) + + cb.assert_called_once() @mark_requirement(Requirements.DATUM_GENERAL_REQ) def test_decorator_does_not_call_on_no_error(self): - success = True - def cb(): - nonlocal success - success = False + error_cb = mock.MagicMock() + exit_cb = mock.MagicMock() - @error_rollback('on_error') - def foo(on_error=None): - on_error.do(cb) + @scoped('scope') + def foo(scope=None): + scope.on_error_do(error_cb) + scope.on_exit_do(exit_cb) foo() - self.assertTrue(success) + error_cb.assert_not_called() + exit_cb.assert_called_once() @mark_requirement(Requirements.DATUM_GENERAL_REQ) def test_decorator_supports_implicit_form(self): - success = False - def cb(): - nonlocal success - success = True + error_cb = mock.MagicMock() + exit_cb = mock.MagicMock() - @error_rollback + @scoped def foo(): - on_error_do(cb) - raise Exception('err') + on_error_do(error_cb) + on_exit_do(exit_cb) + raise TestException('err') - try: + with suppress(TestException): foo() - except Exception: # nosec - disable B110:try_except_pass check - pass - finally: - self.assertTrue(success) + + error_cb.assert_called_once() + exit_cb.assert_called_once() @mark_requirement(Requirements.DATUM_GENERAL_REQ) def test_can_fowrard_args(self): - success1 = False - def cb1(a1, a2=None, ignore_errors=None): - nonlocal success1 - if a1 == 5 and a2 == 2 and ignore_errors == None: - success1 = True - - success2 = False - def cb2(a1, a2=None, ignore_errors=None): - nonlocal success2 - if a1 == 5 and a2 == 2 and ignore_errors == 4: - success2 = True - - try: - with Rollback() as on_error: - on_error.do(cb1, 5, a2=2, ignore_errors=True) - on_error.do(cb2, 5, a2=2, ignore_errors=True, - fwd_kwargs={'ignore_errors': 4}) - raise Exception('err') - except Exception: # nosec - disable B110:try_except_pass check - pass - finally: - self.assertTrue(success1) - self.assertTrue(success2) + cb1 = mock.MagicMock() + cb2 = mock.MagicMock() + + with suppress(TestException), Scope() as scope: + scope.on_error_do(cb1, 5, a2=2, ignore_errors=True) + scope.on_error_do(cb2, 5, a2=2, ignore_errors=True, + fwd_kwargs={'ignore_errors': 4}) + raise TestException('err') + + cb1.assert_called_once_with(5, a2=2) + cb2.assert_called_once_with(5, a2=2, ignore_errors=4) @mark_requirement(Requirements.DATUM_GENERAL_REQ) def test_decorator_can_return_on_success_in_implicit_form(self): - @error_rollback + @scoped def f(): return 42 @@ -131,8 +147,8 @@ def f(): @mark_requirement(Requirements.DATUM_GENERAL_REQ) def test_decorator_can_return_on_success_in_explicit_form(self): - @error_rollback('on_error') - def f(on_error=None): + @scoped('scope') + def f(scope=None): return 42 retval = f() From 655505ad10e22487fd88b0c6189598459f8c6724 Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Thu, 2 Sep 2021 18:19:39 +0300 Subject: [PATCH 03/19] Replace rollback uses --- datumaro/cli/contexts/model.py | 4 ++-- datumaro/cli/contexts/project/__init__.py | 4 ++-- datumaro/components/converter.py | 4 ++-- datumaro/components/dataset.py | 5 +++-- 4 files changed, 9 insertions(+), 8 deletions(-) diff --git a/datumaro/cli/contexts/model.py b/datumaro/cli/contexts/model.py index 378af1ff76..cebcd124fc 100644 --- a/datumaro/cli/contexts/model.py +++ b/datumaro/cli/contexts/model.py @@ -9,7 +9,7 @@ import shutil from datumaro.components.project import Environment -from datumaro.util import error_rollback, on_error_do +from datumaro.util.scope import scoped, on_error_do from ..util import CliException, MultilineFormatter, add_subparser from ..util.project import ( @@ -46,7 +46,7 @@ def build_add_parser(parser_ctor=argparse.ArgumentParser): return parser -@error_rollback +@scoped def add_command(args): project = load_project(args.project_dir) diff --git a/datumaro/cli/contexts/project/__init__.py b/datumaro/cli/contexts/project/__init__.py index c12ddaf132..f975852c6f 100644 --- a/datumaro/cli/contexts/project/__init__.py +++ b/datumaro/cli/contexts/project/__init__.py @@ -21,7 +21,7 @@ from datumaro.components.project import PROJECT_DEFAULT_CONFIG as DEFAULT_CONFIG from datumaro.components.project import Environment, Project from datumaro.components.validator import TaskType -from datumaro.util import error_rollback, on_error_do +from datumaro.util.scope import scoped, on_error_do from datumaro.util.os_util import make_file_name from ...util import CliException, MultilineFormatter, add_subparser @@ -526,7 +526,7 @@ def build_diff_parser(parser_ctor=argparse.ArgumentParser): return parser -@error_rollback +@scoped def diff_command(args): first_project = load_project(args.project_dir) second_project = load_project(args.other_project_dir) diff --git a/datumaro/components/converter.py b/datumaro/components/converter.py index ec8f9c2de5..45f7c3258a 100644 --- a/datumaro/components/converter.py +++ b/datumaro/components/converter.py @@ -12,7 +12,7 @@ from datumaro.components.cli_plugin import CliPlugin from datumaro.components.dataset import DatasetPatch from datumaro.components.extractor import DatasetItem -from datumaro.util import error_rollback, on_error_do +from datumaro.util.scope import scoped, on_error_do from datumaro.util.image import Image @@ -36,7 +36,7 @@ def convert(cls, extractor, save_dir, **options): return converter.apply() @classmethod - @error_rollback + @scoped def patch(cls, dataset, patch, save_dir, **options): # This solution is not any better in performance than just # writing a dataset, but in case of patching (i.e. writing diff --git a/datumaro/components/dataset.py b/datumaro/components/dataset.py index ecd5db46c0..b186e2b3fd 100644 --- a/datumaro/components/dataset.py +++ b/datumaro/components/dataset.py @@ -27,7 +27,8 @@ DEFAULT_SUBSET_NAME, CategoriesInfo, DatasetItem, Extractor, IExtractor, ItemTransform, Transform, ) -from datumaro.util import error_rollback, is_method_redefined, on_error_do +from datumaro.util import is_method_redefined +from datumaro.util.scope import scoped, on_error_do from datumaro.util.log_utils import logging_disabled DEFAULT_FORMAT = 'datumaro' @@ -790,7 +791,7 @@ def bind(self, path: str, format: Optional[str] = None, *, def flush_changes(self): self._data.flush_changes() - @error_rollback + @scoped def export(self, save_dir: str, format, **kwargs): inplace = (save_dir == self._source_path and format == self._format) From 73d7b3fde2e8b0d2e5fa822b08ac410ef04b8105 Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Thu, 2 Sep 2021 19:28:56 +0300 Subject: [PATCH 04/19] Fix imports --- datumaro/cli/contexts/model.py | 2 +- datumaro/cli/contexts/project/__init__.py | 2 +- datumaro/components/converter.py | 2 +- datumaro/components/dataset.py | 2 +- tests/test_util.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/datumaro/cli/contexts/model.py b/datumaro/cli/contexts/model.py index cebcd124fc..178b389deb 100644 --- a/datumaro/cli/contexts/model.py +++ b/datumaro/cli/contexts/model.py @@ -9,7 +9,7 @@ import shutil from datumaro.components.project import Environment -from datumaro.util.scope import scoped, on_error_do +from datumaro.util.scope import on_error_do, scoped from ..util import CliException, MultilineFormatter, add_subparser from ..util.project import ( diff --git a/datumaro/cli/contexts/project/__init__.py b/datumaro/cli/contexts/project/__init__.py index f975852c6f..18d68c5a7a 100644 --- a/datumaro/cli/contexts/project/__init__.py +++ b/datumaro/cli/contexts/project/__init__.py @@ -21,8 +21,8 @@ from datumaro.components.project import PROJECT_DEFAULT_CONFIG as DEFAULT_CONFIG from datumaro.components.project import Environment, Project from datumaro.components.validator import TaskType -from datumaro.util.scope import scoped, on_error_do from datumaro.util.os_util import make_file_name +from datumaro.util.scope import on_error_do, scoped from ...util import CliException, MultilineFormatter, add_subparser from ...util.project import generate_next_file_name, load_project diff --git a/datumaro/components/converter.py b/datumaro/components/converter.py index 45f7c3258a..21c0410cad 100644 --- a/datumaro/components/converter.py +++ b/datumaro/components/converter.py @@ -12,8 +12,8 @@ from datumaro.components.cli_plugin import CliPlugin from datumaro.components.dataset import DatasetPatch from datumaro.components.extractor import DatasetItem -from datumaro.util.scope import scoped, on_error_do from datumaro.util.image import Image +from datumaro.util.scope import on_error_do, scoped class Converter(CliPlugin): diff --git a/datumaro/components/dataset.py b/datumaro/components/dataset.py index b186e2b3fd..06e52dd8eb 100644 --- a/datumaro/components/dataset.py +++ b/datumaro/components/dataset.py @@ -28,8 +28,8 @@ ItemTransform, Transform, ) from datumaro.util import is_method_redefined -from datumaro.util.scope import scoped, on_error_do from datumaro.util.log_utils import logging_disabled +from datumaro.util.scope import on_error_do, scoped DEFAULT_FORMAT = 'datumaro' diff --git a/tests/test_util.py b/tests/test_util.py index e900b57ee4..4ddc45cb2e 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -5,7 +5,7 @@ from datumaro.util import is_method_redefined from datumaro.util.os_util import walk -from datumaro.util.scope import Scope, on_error_do, scoped, on_exit_do +from datumaro.util.scope import Scope, on_error_do, on_exit_do, scoped from datumaro.util.test_utils import TestDir from .requirements import Requirements, mark_requirement From e75c7abdec72a1712b5f2ae20e0fe6b37174b2c5 Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Thu, 2 Sep 2021 19:30:14 +0300 Subject: [PATCH 05/19] update changelog --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4102485ca9..5e52f9a812 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Annotation-related classes were moved into a new module, `datumaro.components.annotation` () +- Rollback utilities replaced with Scope utilities + () ### Deprecated - TBD From 564957d79e1aade5c013fd922bf3f9f460bfe914 Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Fri, 3 Sep 2021 14:19:21 +0300 Subject: [PATCH 06/19] Fix add() type anns in Scope --- datumaro/util/scope.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/datumaro/util/scope.py b/datumaro/util/scope.py index 7fff715166..1a86732e28 100644 --- a/datumaro/util/scope.py +++ b/datumaro/util/scope.py @@ -4,7 +4,7 @@ from contextlib import ExitStack, contextmanager from functools import partial, wraps -from typing import Any, ContextManager, Dict, Optional +from typing import Any, ContextManager, Dict, Optional, TypeVar import threading from attr import attrs @@ -12,6 +12,8 @@ from datumaro.util import optional_arg_decorator +T = TypeVar('T') + class Scope: """ A context manager that allows to register error and exit callbacks. @@ -83,9 +85,11 @@ def on_exit_do(self, callback, *args, name: Optional[str] = None, self._exit_stack.callback(handler) return name - def add(self, cm: ContextManager) -> Any: + def add(self, cm: ContextManager[T]) -> T: """ Enters a context manager and adds it to the exit stack. + + Returns: cm.__enter__() result """ return self._exit_stack.enter_context(cm) @@ -152,17 +156,19 @@ def wrapped_func(*args, **kwargs): return wrapped_func # Shorthands for common cases -def on_error_do(callback, *args, ignore_errors=False): +def on_error_do(callback, *args, ignore_errors=False, + fwd_kwargs=None, **kwargs): return Scope.current().on_error_do(callback, *args, - ignore_errors=ignore_errors) + ignore_errors=ignore_errors, fwd_kwargs=fwd_kwargs, **kwargs) on_error_do.__doc__ = Scope.on_error_do.__doc__ -def on_exit_do(callback, *args, ignore_errors=False): +def on_exit_do(callback, *args, ignore_errors=False, + fwd_kwargs=None, **kwargs): return Scope.current().on_exit_do(callback, *args, - ignore_errors=ignore_errors) + ignore_errors=ignore_errors, fwd_kwargs=fwd_kwargs, **kwargs) on_exit_do.__doc__ = Scope.on_exit_do.__doc__ -def add(cm: ContextManager): +def add(cm: ContextManager[T]) -> T: return Scope.current().add(cm) add.__doc__ = Scope.add.__doc__ From 51d871e0af6d6cddf877c81bb3283bb8b4de27f3 Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Fri, 3 Sep 2021 14:37:49 +0300 Subject: [PATCH 07/19] Replace rollback with scoped in the project --- datumaro/cli/commands/checkout.py | 6 ++++- datumaro/cli/commands/commit.py | 6 ++++- datumaro/cli/commands/diff.py | 24 ++++++++++++----- datumaro/cli/commands/explain.py | 11 ++++++-- datumaro/cli/commands/info.py | 9 +++++-- datumaro/cli/commands/log.py | 6 ++++- datumaro/cli/commands/merge.py | 10 +++++-- datumaro/cli/commands/status.py | 5 +++- datumaro/cli/contexts/model.py | 18 +++++++++---- datumaro/cli/contexts/project/__init__.py | 30 ++++++++++++++------- datumaro/cli/contexts/source.py | 13 +++++---- datumaro/cli/util/project.py | 33 ++++++++++++++++------- datumaro/components/project.py | 13 +++++---- 13 files changed, 131 insertions(+), 53 deletions(-) diff --git a/datumaro/cli/commands/checkout.py b/datumaro/cli/commands/checkout.py index 6739cdad77..e8323cbf1c 100644 --- a/datumaro/cli/commands/checkout.py +++ b/datumaro/cli/commands/checkout.py @@ -4,6 +4,9 @@ import argparse +from datumaro.util.scope import scoped +import datumaro.util.scope as scope + from ..util import MultilineFormatter from ..util.project import load_project @@ -45,6 +48,7 @@ def build_parser(parser_ctor=argparse.ArgumentParser): return parser +@scoped def checkout_command(args): has_sep = '--' in args._positionals if has_sep: @@ -60,7 +64,7 @@ def checkout_command(args): raise argparse.ArgumentError('sources', message="When '--' is used, " "at least 1 source name must be specified") - project = load_project(args.project_dir) + project = scope.add(load_project(args.project_dir)) project.checkout(rev=args.rev, sources=args.sources, force=args.force) diff --git a/datumaro/cli/commands/commit.py b/datumaro/cli/commands/commit.py index 5e3f0131b0..67978ce23c 100644 --- a/datumaro/cli/commands/commit.py +++ b/datumaro/cli/commands/commit.py @@ -4,6 +4,9 @@ import argparse +from datumaro.util.scope import scoped +import datumaro.util.scope as scope + from ..util import MultilineFormatter from ..util.project import load_project @@ -32,8 +35,9 @@ def build_parser(parser_ctor=argparse.ArgumentParser): return parser +@scoped def commit_command(args): - project = load_project(args.project_dir) + project = scope.add(load_project(args.project_dir)) old_tree = project.head diff --git a/datumaro/cli/commands/diff.py b/datumaro/cli/commands/diff.py index 26dc087cee..43839697fe 100644 --- a/datumaro/cli/commands/diff.py +++ b/datumaro/cli/commands/diff.py @@ -11,8 +11,9 @@ from datumaro.components.errors import ProjectNotFoundError from datumaro.components.operations import DistanceComparator, ExactComparator -from datumaro.util import error_rollback, on_error_do from datumaro.util.os_util import rmtree +from datumaro.util.scope import on_error_do, scoped +import datumaro.util.scope as scope from ..contexts.project.diff import DiffVisualizer from ..util import MultilineFormatter @@ -132,7 +133,7 @@ def _parse_comparison_method(s): return parser -@error_rollback +@scoped def diff_command(args): dst_dir = args.dst_dir if dst_dir: @@ -149,7 +150,7 @@ def diff_command(args): project = None try: - project = load_project(args.project_dir) + project = scope.add(load_project(args.project_dir)) except ProjectNotFoundError: if args.project_dir: raise @@ -157,14 +158,23 @@ def diff_command(args): try: if not args.second_target: first_dataset = project.working_tree.make_dataset() - second_dataset = parse_full_revpath(args.first_target, project) + second_dataset, target_project = \ + parse_full_revpath(args.first_target, project) + if target_project: + scope.add(target_project) else: - first_dataset = parse_full_revpath(args.first_target, project) - second_dataset = parse_full_revpath(args.second_target, project) + first_dataset, target_project = \ + parse_full_revpath(args.first_target, project) + if target_project: + scope.add(target_project) + + second_dataset, target_project = \ + parse_full_revpath(args.second_target, project) + if target_project: + scope.add(target_project) except Exception as e: raise CliException(str(e)) - if args.method is ComparisonMethod.equality: if args.ignore_field: args.ignore_field = eq_default_if diff --git a/datumaro/cli/commands/explain.py b/datumaro/cli/commands/explain.py index eaa7500bde..45a925a505 100644 --- a/datumaro/cli/commands/explain.py +++ b/datumaro/cli/commands/explain.py @@ -8,6 +8,8 @@ import os.path as osp from datumaro.util.image import is_image, load_image, save_image +from datumaro.util.scope import scoped +import datumaro.util.scope as scope from ..util import MultilineFormatter from ..util.project import load_project, parse_full_revpath @@ -110,11 +112,12 @@ def build_parser(parser_ctor=argparse.ArgumentParser): return parser +@scoped def explain_command(args): from matplotlib import cm import cv2 - project = load_project(args.project_dir) + project = scope.add(load_project(args.project_dir)) model = project.working_tree.models.make_executable_model(args.model) @@ -168,7 +171,11 @@ def explain_command(args): cv2.waitKey(0) else: - dataset = parse_full_revpath(args.target or 'project', project) + dataset, target_project = \ + parse_full_revpath(args.target or 'project', project) + if target_project: + scope.add(target_project) + log.info("Running inference explanation for '%s'" % args.target) for item in dataset: diff --git a/datumaro/cli/commands/info.py b/datumaro/cli/commands/info.py index fa55598498..56f2537102 100644 --- a/datumaro/cli/commands/info.py +++ b/datumaro/cli/commands/info.py @@ -8,6 +8,8 @@ DatasetMergeError, MissingObjectError, ProjectNotFoundError, ) from datumaro.components.extractor import AnnotationType +from datumaro.util.scope import scoped +import datumaro.util.scope as scope from ..util import MultilineFormatter from ..util.project import load_project, parse_full_revpath @@ -55,17 +57,20 @@ def build_parser(parser_ctor=argparse.ArgumentParser): return parser +@scoped def info_command(args): project = None try: - project = load_project(args.project_dir) + project = scope.add(load_project(args.project_dir)) except ProjectNotFoundError: if args.project_dir: raise try: # TODO: avoid computing working tree hashes - dataset = parse_full_revpath(args.target, project) + dataset, target_project = parse_full_revpath(args.target, project) + if target_project: + scope.add(target_project) except DatasetMergeError as e: dataset = None dataset_problem = "Can't merge project sources automatically: %s " \ diff --git a/datumaro/cli/commands/log.py b/datumaro/cli/commands/log.py index f08f946986..31afbe0ed8 100644 --- a/datumaro/cli/commands/log.py +++ b/datumaro/cli/commands/log.py @@ -4,6 +4,9 @@ import argparse +from datumaro.util.scope import scoped +import datumaro.util.scope as scope + from ..util.project import load_project @@ -18,8 +21,9 @@ def build_parser(parser_ctor=argparse.ArgumentParser): return parser +@scoped def log_command(args): - project = load_project(args.project_dir) + project = scope.add(load_project(args.project_dir)) revisions = project.history(args.max_count) if revisions: diff --git a/datumaro/cli/commands/merge.py b/datumaro/cli/commands/merge.py index 13cfd1953f..b34670dd9b 100644 --- a/datumaro/cli/commands/merge.py +++ b/datumaro/cli/commands/merge.py @@ -14,6 +14,8 @@ DatasetMergeError, DatasetQualityError, ProjectNotFoundError, ) from datumaro.components.operations import IntersectMerge +from datumaro.util.scope import scoped +import datumaro.util.scope as scope from ..util import MultilineFormatter from ..util.errors import CliException @@ -102,6 +104,7 @@ def _group(s): return parser +@scoped def merge_command(args): dst_dir = args.dst_dir if dst_dir: @@ -114,7 +117,7 @@ def merge_command(args): project = None try: - project = load_project(args.project_dir) + project = scope.add(load_project(args.project_dir)) except ProjectNotFoundError: if args.project_dir: raise @@ -125,7 +128,10 @@ def merge_command(args): source_datasets.append(project.working_tree.make_dataset()) for t in args.targets: - source_datasets.append(parse_full_revpath(t, project)) + target_dataset, target_project = parse_full_revpath(t, project) + if target_project: + scope.add(target_project) + source_datasets.append(target_dataset) except Exception as e: raise CliException(str(e)) diff --git a/datumaro/cli/commands/status.py b/datumaro/cli/commands/status.py index 9ea1683239..a31e352bfe 100644 --- a/datumaro/cli/commands/status.py +++ b/datumaro/cli/commands/status.py @@ -5,6 +5,8 @@ import argparse from datumaro.cli.util import MultilineFormatter +from datumaro.util.scope import scoped +import datumaro.util.scope as scope from ..util.project import load_project @@ -24,8 +26,9 @@ def build_parser(parser_ctor=argparse.ArgumentParser): return parser +@scoped def status_command(args): - project = load_project(args.project_dir) + project = scope.add(load_project(args.project_dir)) statuses = project.status() diff --git a/datumaro/cli/contexts/model.py b/datumaro/cli/contexts/model.py index a6e636b199..30ba405c81 100644 --- a/datumaro/cli/contexts/model.py +++ b/datumaro/cli/contexts/model.py @@ -11,6 +11,7 @@ from datumaro.components.project import Environment from datumaro.util.os_util import rmtree from datumaro.util.scope import on_error_do, scoped +import datumaro.util.scope as scope from ..util import MultilineFormatter, add_subparser from ..util.errors import CliException @@ -59,7 +60,7 @@ def add_command(args): project = None try: - project = load_project(args.project_dir) + project = scope.add(load_project(args.project_dir)) except ProjectNotFoundError: if not show_plugin_help and args.project_dir: raise @@ -125,8 +126,9 @@ def build_remove_parser(parser_ctor=argparse.ArgumentParser): return parser +@scoped def remove_command(args): - project = load_project(args.project_dir) + project = scope.add(load_project(args.project_dir)) project.remove_model(args.name) project.save() @@ -165,6 +167,7 @@ def build_run_parser(parser_ctor=argparse.ArgumentParser): return parser +@scoped def run_command(args): dst_dir = args.dst_dir if dst_dir: @@ -175,8 +178,12 @@ def run_command(args): dst_dir = generate_next_file_name('%s-inference' % args.model_name) dst_dir = osp.abspath(dst_dir) - project = load_project(args.project_dir) - dataset = parse_full_revpath(args.target, project) + project = scope.add(load_project(args.project_dir)) + + dataset, target_project = parse_full_revpath(args.target, project) + if target_project: + scope.add(target_project) + model = project.make_model(args.model_name) inference = dataset.run_model(model) inference.save(dst_dir) @@ -198,8 +205,9 @@ def build_info_parser(parser_ctor=argparse.ArgumentParser): return parser +@scoped def info_command(args): - project = load_project(args.project_dir) + project = scope.add(load_project(args.project_dir)) if args.name: print(project.models[args.name]) diff --git a/datumaro/cli/contexts/project/__init__.py b/datumaro/cli/contexts/project/__init__.py index 81a955e6a7..a7a6426a38 100644 --- a/datumaro/cli/contexts/project/__init__.py +++ b/datumaro/cli/contexts/project/__init__.py @@ -21,7 +21,8 @@ from datumaro.components.validator import TaskType from datumaro.util import str_to_bool from datumaro.util.os_util import make_file_name -from datumaro.util.scope import on_exit_do, scoped +from datumaro.util.scope import scoped +import datumaro.util.scope as scope from ...util import MultilineFormatter, add_subparser from ...util.errors import CliException @@ -136,6 +137,7 @@ def build_export_parser(parser_ctor=argparse.ArgumentParser): return parser +@scoped def export_command(args): has_sep = '--' in args._positionals if has_sep: @@ -153,7 +155,7 @@ def export_command(args): project = None try: - project = load_project(args.project_dir) + project = scope.add(load_project(args.project_dir)) except ProjectNotFoundError: if not show_plugin_help and args.project_dir: raise @@ -280,8 +282,9 @@ def build_filter_parser(parser_ctor=argparse.ArgumentParser): return parser +@scoped def filter_command(args): - project = load_project(args.project_dir) + project = scope.add(load_project(args.project_dir)) # TODO: check if we can accept a dataset revpath here if not args.dry_run and args.stage and \ @@ -409,6 +412,7 @@ def build_transform_parser(parser_ctor=argparse.ArgumentParser): return parser +@scoped def transform_command(args): has_sep = '--' in args._positionals if has_sep: @@ -426,7 +430,7 @@ def transform_command(args): project = None try: - project = load_project(args.project_dir) + project = scope.add(load_project(args.project_dir)) except ProjectNotFoundError: if not show_plugin_help and args.project_dir: raise @@ -524,15 +528,18 @@ def build_stats_parser(parser_ctor=argparse.ArgumentParser): return parser +@scoped def stats_command(args): project = None try: - project = load_project(args.project_dir) + project = scope.add(load_project(args.project_dir)) except ProjectNotFoundError: if args.project_dir: raise - dataset = parse_full_revpath(args.target, project) + dataset, target_project = parse_full_revpath(args.target, project) + if target_project: + scope.add(target_project) stats = {} stats.update(compute_image_statistics(dataset)) @@ -566,8 +573,9 @@ def build_info_parser(parser_ctor=argparse.ArgumentParser): return parser +@scoped def info_command(args): - project = load_project(args.project_dir) + project = scope.add(load_project(args.project_dir)) rev = project.get_rev(args.revision) env = rev.env @@ -654,6 +662,7 @@ def _parse_task_type(s): return parser +@scoped def validate_command(args): has_sep = '--' in args._positionals if has_sep: @@ -670,7 +679,7 @@ def validate_command(args): project = None try: - project = load_project(args.project_dir) + project = scope.add(load_project(args.project_dir)) except ProjectNotFoundError: if not show_plugin_help and args.project_dir: raise @@ -687,8 +696,11 @@ def validate_command(args): extra_args = validator_type.parse_cmdline(args.extra_args) + dataset, target_project = parse_full_revpath(args.target, project) + if target_project: + scope.add(target_project) + dst_file_name = f'validation-report' - dataset = parse_full_revpath(args.target, project) if args.subset_name is not None: dataset = dataset.get_subset(args.subset_name) dst_file_name += f'-{args.subset_name}' diff --git a/datumaro/cli/contexts/source.py b/datumaro/cli/contexts/source.py index 1dddfd6348..a26b4210c8 100644 --- a/datumaro/cli/contexts/source.py +++ b/datumaro/cli/contexts/source.py @@ -8,7 +8,8 @@ from datumaro.components.errors import ProjectNotFoundError from datumaro.components.project import Environment -from datumaro.util import error_rollback, on_error_do +from datumaro.util.scope import on_error_do, scoped +import datumaro.util.scope as scope from ..util import MultilineFormatter, add_subparser, join_cli_args from ..util.errors import CliException @@ -80,7 +81,7 @@ def build_add_parser(parser_ctor=argparse.ArgumentParser): return parser -@error_rollback +@scoped def add_command(args): # Workaround. Required positionals consume positionals from the end args._positionals += join_cli_args(args, 'url', 'extra_args') @@ -97,7 +98,7 @@ def add_command(args): project = None try: - project = load_project(args.project_dir) + project = scope.add(load_project(args.project_dir)) except ProjectNotFoundError: if not show_plugin_help and args.project_dir: raise @@ -160,8 +161,9 @@ def build_remove_parser(parser_ctor=argparse.ArgumentParser): return parser +@scoped def remove_command(args): - project = load_project(args.project_dir) + project = scope.add(load_project(args.project_dir)) if not args.names: raise CliException("Expected source name") @@ -188,8 +190,9 @@ def build_info_parser(parser_ctor=argparse.ArgumentParser): return parser +@scoped def info_command(args): - project = load_project(args.project_dir) + project = scope.add(load_project(args.project_dir)) if args.name: source = project.working_tree.sources[args.name] diff --git a/datumaro/cli/util/project.py b/datumaro/cli/util/project.py index 19a50233aa..11117d5c4f 100644 --- a/datumaro/cli/util/project.py +++ b/datumaro/cli/util/project.py @@ -12,6 +12,7 @@ from datumaro.components.errors import DatumaroError, ProjectNotFoundError from datumaro.components.project import Project, Revision from datumaro.util.os_util import generate_next_name +from datumaro.util.scope import on_error_do, scoped def load_project(project_dir, readonly=False): @@ -48,7 +49,9 @@ def parse_dataset_pathspec(s: str, format = match["format"] return Dataset.import_from(path, format, env=env) -def parse_revspec(s: str, ctx_project: Optional[Project] = None) -> Dataset: +@scoped +def parse_revspec(s: str, ctx_project: Optional[Project] = None) \ + -> Tuple[Dataset, Project]: """ Parses Revision paths. The syntax is: - [ @ ] [ : ] @@ -56,7 +59,8 @@ def parse_revspec(s: str, ctx_project: Optional[Project] = None) -> Dataset: - The second and the third forms assume an existing "current" project. - Returns: a dataset from the parsed path + Returns: the dataset and the project from the parsed path. + The project is only returned when specified in the revpath. """ match = re.fullmatch(r""" @@ -72,32 +76,41 @@ def parse_revspec(s: str, ctx_project: Optional[Project] = None) -> Dataset: rev = match["rev"] source = match["source"] + target_project = None + assert proj_path if rev: - project = load_project(proj_path, readonly=True) - - # proj_path is either proj_path or rev or source name + target_project = load_project(proj_path, readonly=True) + project = target_project + # proj_path is either proj_path or rev or source name elif Project.find_project_dir(proj_path): - project = load_project(proj_path, readonly=True) + target_project = load_project(proj_path, readonly=True) + project = target_project elif ctx_project: project = ctx_project if project.is_ref(proj_path): rev = proj_path elif not source: source = proj_path + else: raise ProjectNotFoundError("Failed to find project at '%s'. " \ "Specify project path with '-p/--project' or in the " "target pathspec." % proj_path) + if target_project: + on_error_do(Project.close, target_project, ignore_errors=True) + tree = project.get_rev(rev) - return tree.make_dataset(source) + return tree.make_dataset(source), target_project -def parse_full_revpath(s: str, ctx_project: Optional[Project]) -> Dataset: +def parse_full_revpath(s: str, ctx_project: Optional[Project] = None) \ + -> Tuple[Dataset, Optional[Project]]: """ revpath - either a Dataset path or a Revision path. - Returns: a dataset from the parsed path + Returns: the dataset and the project from the parsed path + The project is only returned when specified in the revpath. """ if ctx_project: @@ -107,7 +120,7 @@ def parse_full_revpath(s: str, ctx_project: Optional[Project]) -> Dataset: errors = [] try: - return parse_dataset_pathspec(s, env=env) + return parse_dataset_pathspec(s, env=env), None except (DatumaroError, OSError) as e: errors.append(e) diff --git a/datumaro/components/project.py b/datumaro/components/project.py index 082e309241..2185aa8839 100644 --- a/datumaro/components/project.py +++ b/datumaro/components/project.py @@ -35,13 +35,12 @@ UnsavedChangesError, VcsError, ) from datumaro.components.launcher import Launcher -from datumaro.util import ( - error_rollback, find, on_error_do, parse_str_enum_value, -) +from datumaro.util import find, parse_str_enum_value from datumaro.util.log_utils import catch_logs, logging_disabled from datumaro.util.os_util import ( copytree, generate_next_name, is_subpath, make_file_name, rmtree, ) +from datumaro.util.scope import on_error_do, scoped class ProjectSourceDataset(IDataset): @@ -1521,8 +1520,8 @@ def _init_vcs(self): osp.join(self._root_dir, '.dvc', 'plots'), r=True) @classmethod - @error_rollback - def init(cls, path): + @scoped + def init(cls, path) -> 'Project': existing_project = cls.find_project_dir(path) if existing_project: raise ProjectAlreadyExists(path) @@ -1792,7 +1791,7 @@ def validate_source_name(self, name: str): if name.lower() in reserved_names: raise ValueError("Source name is reserved for internal use") - @error_rollback + @scoped def _download_source(self, url: str, dst_dir: str, no_cache: bool = False): assert url assert dst_dir @@ -1888,7 +1887,7 @@ def _materialize_obj(self, obj_hash: ObjectId) -> str: self._dvc.write_obj(obj_hash, dst_dir, allow_links=True) return dst_dir - @error_rollback + @scoped def import_source(self, name: str, url: Optional[str], format: str, options: Optional[Dict] = None, no_cache: bool = False, rpath: Optional[str] = None) -> Source: From 6b5034c10df4b996e1e06b3af7cf7dae774d899e Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Fri, 3 Sep 2021 14:38:11 +0300 Subject: [PATCH 08/19] Update revpath tests --- datumaro/util/test_utils.py | 5 +- tests/cli/test_revpath.py | 189 ++++++++++++++++++++---------------- 2 files changed, 110 insertions(+), 84 deletions(-) diff --git a/datumaro/util/test_utils.py b/datumaro/util/test_utils.py index e07d926cc2..869f8964b0 100644 --- a/datumaro/util/test_utils.py +++ b/datumaro/util/test_utils.py @@ -50,14 +50,15 @@ class TestDir(FileRemover): ... """ - def __init__(self, path=None): + def __init__(self, path=None, frame_id=2): super().__init__(path, is_dir=True) + self._frame_id = frame_id def __enter__(self): path = self.path if path is None: - path = osp.abspath('temp_%s-' % current_function_name(2)) + path = f'temp_{current_function_name(self._frame_id)}-' path = tempfile.mkdtemp(dir=os.getcwd(), prefix=path) self.path = path else: diff --git a/tests/cli/test_revpath.py b/tests/cli/test_revpath.py index 2a0eafaf6c..2a0c17d713 100644 --- a/tests/cli/test_revpath.py +++ b/tests/cli/test_revpath.py @@ -10,96 +10,121 @@ ) from datumaro.components.extractor import DatasetItem from datumaro.components.project import Project +from datumaro.util.scope import scoped from datumaro.util.test_utils import TestDir +import datumaro.util.scope as scope from ..requirements import Requirements, mark_requirement class TestRevpath(TestCase): @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped def test_can_parse(self): - with TestDir() as test_dir: - dataset_url = osp.join(test_dir, 'source') - dataset = Dataset.from_iterable([DatasetItem(1)]) - dataset.save(dataset_url) - - proj_dir = osp.join(test_dir, 'proj') - proj = Project.init(proj_dir) - proj.import_source('source-1', dataset_url, format=DEFAULT_FORMAT) - ref = proj.commit("second commit", allow_empty=True) - - - with self.subTest("project"): - self.assertTrue(isinstance(parse_full_revpath(proj_dir, None), - IDataset)) - - with self.subTest("project ref"): - self.assertTrue(isinstance( - parse_full_revpath(proj_dir + "@" + ref, None), - IDataset)) - - with self.subTest("project ref source"): - self.assertTrue(isinstance( - parse_full_revpath( - proj_dir + "@" + ref + ":source-1", None), - IDataset)) - - with self.subTest("project ref source stage"): - self.assertTrue(isinstance( - parse_full_revpath( - proj_dir + "@" + ref + ":source-1.root", None), - IDataset)) - - with self.subTest("ref"): - self.assertTrue(isinstance( - parse_full_revpath(ref, proj), - IDataset)) - - with self.subTest("ref source"): - self.assertTrue(isinstance( - parse_full_revpath(ref + ":source-1", proj), - IDataset)) - - with self.subTest("ref source stage"): - self.assertTrue(isinstance( - parse_full_revpath(ref + ":source-1.root", proj), - IDataset)) - - with self.subTest("source"): - self.assertTrue(isinstance( - parse_full_revpath("source-1", proj), - IDataset)) - - with self.subTest("source stage"): - self.assertTrue(isinstance( - parse_full_revpath("source-1.root", proj), - IDataset)) - - with self.subTest("dataset (in context)"): - with self.assertRaises(WrongRevpathError) as cm: - parse_full_revpath(dataset_url, proj) - self.assertEqual( - {UnknownTargetError, MultipleFormatsMatchError}, - set(type(e) for e in cm.exception.problems) - ) - - with self.subTest("dataset format (in context)"): - self.assertTrue(isinstance( - parse_full_revpath(dataset_url + ":datumaro", proj), - IDataset)) - - with self.subTest("dataset (no context)"): - with self.assertRaises(WrongRevpathError) as cm: - parse_full_revpath(dataset_url, None) - self.assertEqual( - {ProjectNotFoundError, MultipleFormatsMatchError}, - set(type(e) for e in cm.exception.problems) - ) - - with self.subTest("dataset format (no context)"): - self.assertTrue(isinstance( - parse_full_revpath(dataset_url + ":datumaro", None), - IDataset)) + test_dir = scope.add(TestDir(frame_id=5)) + + dataset_url = osp.join(test_dir, 'source') + Dataset.from_iterable([DatasetItem(1)]).save(dataset_url) + + proj_dir = osp.join(test_dir, 'proj') + proj = scope.add(Project.init(proj_dir)) + proj.import_source('source-1', dataset_url, format=DEFAULT_FORMAT) + ref = proj.commit("second commit", allow_empty=True) + + with self.subTest("project"): + dataset, project = parse_full_revpath(proj_dir) + if project: + scope.add(project) + self.assertTrue(isinstance(dataset, IDataset)) + self.assertTrue(isinstance(project, Project)) + + with self.subTest("project ref"): + dataset, project = parse_full_revpath(f"{proj_dir}@{ref}") + if project: + scope.add(project) + self.assertTrue(isinstance(dataset, IDataset)) + self.assertTrue(isinstance(project, Project)) + + with self.subTest("project ref source"): + dataset, project = parse_full_revpath(f"{proj_dir}@{ref}:source-1") + if project: + scope.add(project) + self.assertTrue(isinstance(dataset, IDataset)) + self.assertTrue(isinstance(project, Project)) + + with self.subTest("project ref source stage"): + dataset, project = parse_full_revpath( + f"{proj_dir}@{ref}:source-1.root") + if project: + scope.add(project) + self.assertTrue(isinstance(dataset, IDataset)) + self.assertTrue(isinstance(project, Project)) + + with self.subTest("ref"): + dataset, project = parse_full_revpath(ref, proj) + if project: + scope.add(project) + self.assertTrue(isinstance(dataset, IDataset)) + self.assertEqual(None, project) + + with self.subTest("ref source"): + dataset, project = parse_full_revpath(f"{ref}:source-1", proj) + if project: + scope.add(project) + self.assertTrue(isinstance(dataset, IDataset)) + self.assertEqual(None, project) + + with self.subTest("ref source stage"): + dataset, project = parse_full_revpath(f"{ref}:source-1.root", proj) + if project: + scope.add(project) + self.assertTrue(isinstance(dataset, IDataset)) + self.assertEqual(None, project) + + with self.subTest("source"): + dataset, project = parse_full_revpath("source-1", proj) + if project: + scope.add(project) + self.assertTrue(isinstance(dataset, IDataset)) + self.assertEqual(None, project) + + with self.subTest("source stage"): + dataset, project = parse_full_revpath("source-1.root", proj) + if project: + scope.add(project) + self.assertTrue(isinstance(dataset, IDataset)) + self.assertEqual(None, project) + + with self.subTest("dataset (in context)"): + with self.assertRaises(WrongRevpathError) as cm: + parse_full_revpath(dataset_url, proj) + self.assertEqual( + {UnknownTargetError, MultipleFormatsMatchError}, + set(type(e) for e in cm.exception.problems) + ) + + with self.subTest("dataset format (in context)"): + dataset, project = parse_full_revpath( + f"{dataset_url}:datumaro", proj) + if project: + scope.add(project) + self.assertTrue(isinstance(dataset, IDataset)) + self.assertEqual(None, project) + + with self.subTest("dataset (no context)"): + with self.assertRaises(WrongRevpathError) as cm: + parse_full_revpath(dataset_url) + self.assertEqual( + {ProjectNotFoundError, MultipleFormatsMatchError}, + set(type(e) for e in cm.exception.problems) + ) + + with self.subTest("dataset format (no context)"): + dataset, project = parse_full_revpath(f"{dataset_url}:datumaro") + if project: + scope.add(project) + self.assertTrue(isinstance(dataset, IDataset)) + self.assertEqual(None, project) @mark_requirement(Requirements.DATUM_GENERAL_REQ) def test_can_split_local_revpath(self): From 7f2ca570f78bf4fbc74b433798341c790ade5484 Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Mon, 6 Sep 2021 12:17:46 +0300 Subject: [PATCH 09/19] Replace Rollback with Scope (#444) * Replace rollback with scope * Add scope_add() function * Update tests * Replace rollback uses * Update changelog --- CHANGELOG.md | 2 + datumaro/cli/contexts/model.py | 4 +- datumaro/cli/contexts/project/__init__.py | 4 +- datumaro/components/converter.py | 4 +- datumaro/components/dataset.py | 5 +- datumaro/util/__init__.py | 102 +------------ datumaro/util/scope.py | 151 +++++++++++++++++++ tests/test_util.py | 173 +++++++++++----------- 8 files changed, 251 insertions(+), 194 deletions(-) create mode 100644 datumaro/util/scope.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 7fa0576db4..660cbf3d8e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Annotation-related classes were moved into a new module, `datumaro.components.annotation` () +- Rollback utilities replaced with Scope utilities + () ### Deprecated - TBD diff --git a/datumaro/cli/contexts/model.py b/datumaro/cli/contexts/model.py index 378af1ff76..178b389deb 100644 --- a/datumaro/cli/contexts/model.py +++ b/datumaro/cli/contexts/model.py @@ -9,7 +9,7 @@ import shutil from datumaro.components.project import Environment -from datumaro.util import error_rollback, on_error_do +from datumaro.util.scope import on_error_do, scoped from ..util import CliException, MultilineFormatter, add_subparser from ..util.project import ( @@ -46,7 +46,7 @@ def build_add_parser(parser_ctor=argparse.ArgumentParser): return parser -@error_rollback +@scoped def add_command(args): project = load_project(args.project_dir) diff --git a/datumaro/cli/contexts/project/__init__.py b/datumaro/cli/contexts/project/__init__.py index c12ddaf132..18d68c5a7a 100644 --- a/datumaro/cli/contexts/project/__init__.py +++ b/datumaro/cli/contexts/project/__init__.py @@ -21,8 +21,8 @@ from datumaro.components.project import PROJECT_DEFAULT_CONFIG as DEFAULT_CONFIG from datumaro.components.project import Environment, Project from datumaro.components.validator import TaskType -from datumaro.util import error_rollback, on_error_do from datumaro.util.os_util import make_file_name +from datumaro.util.scope import on_error_do, scoped from ...util import CliException, MultilineFormatter, add_subparser from ...util.project import generate_next_file_name, load_project @@ -526,7 +526,7 @@ def build_diff_parser(parser_ctor=argparse.ArgumentParser): return parser -@error_rollback +@scoped def diff_command(args): first_project = load_project(args.project_dir) second_project = load_project(args.other_project_dir) diff --git a/datumaro/components/converter.py b/datumaro/components/converter.py index ec8f9c2de5..21c0410cad 100644 --- a/datumaro/components/converter.py +++ b/datumaro/components/converter.py @@ -12,8 +12,8 @@ from datumaro.components.cli_plugin import CliPlugin from datumaro.components.dataset import DatasetPatch from datumaro.components.extractor import DatasetItem -from datumaro.util import error_rollback, on_error_do from datumaro.util.image import Image +from datumaro.util.scope import on_error_do, scoped class Converter(CliPlugin): @@ -36,7 +36,7 @@ def convert(cls, extractor, save_dir, **options): return converter.apply() @classmethod - @error_rollback + @scoped def patch(cls, dataset, patch, save_dir, **options): # This solution is not any better in performance than just # writing a dataset, but in case of patching (i.e. writing diff --git a/datumaro/components/dataset.py b/datumaro/components/dataset.py index ecd5db46c0..06e52dd8eb 100644 --- a/datumaro/components/dataset.py +++ b/datumaro/components/dataset.py @@ -27,8 +27,9 @@ DEFAULT_SUBSET_NAME, CategoriesInfo, DatasetItem, Extractor, IExtractor, ItemTransform, Transform, ) -from datumaro.util import error_rollback, is_method_redefined, on_error_do +from datumaro.util import is_method_redefined from datumaro.util.log_utils import logging_disabled +from datumaro.util.scope import on_error_do, scoped DEFAULT_FORMAT = 'datumaro' @@ -790,7 +791,7 @@ def bind(self, path: str, format: Optional[str] = None, *, def flush_changes(self): self._data.flush_changes() - @error_rollback + @scoped def export(self, save_dir: str, format, **kwargs): inplace = (save_dir == self._source_path and format == self._format) diff --git a/datumaro/util/__init__.py b/datumaro/util/__init__.py index 7c8a52e7b1..6a2328b7e1 100644 --- a/datumaro/util/__init__.py +++ b/datumaro/util/__init__.py @@ -2,15 +2,11 @@ # # SPDX-License-Identifier: MIT -from contextlib import ExitStack, contextmanager -from functools import partial, wraps +from functools import wraps from inspect import isclass from itertools import islice from typing import Iterable, Tuple import distutils.util -import threading - -import attr NOTSET = object() @@ -126,99 +122,3 @@ def real_decorator(decoratee): return real_decorator return wrapped_decorator - -class Rollback: - _thread_locals = threading.local() - - @attr.attrs - class Handler: - callback = attr.attrib() - enabled = attr.attrib(default=True) - ignore_errors = attr.attrib(default=False) - - def __call__(self): - if self.enabled: - try: - self.callback() - except: # pylint: disable=bare-except - if not self.ignore_errors: - raise - - def __init__(self): - self._handlers = {} - self._stack = ExitStack() - self.enabled = True - - def add(self, callback, *args, - name=None, enabled=True, ignore_errors=False, - fwd_kwargs=None, **kwargs): - if args or kwargs or fwd_kwargs: - if fwd_kwargs: - kwargs.update(fwd_kwargs) - callback = partial(callback, *args, **kwargs) - name = name or hash(callback) - assert name not in self._handlers - handler = self.Handler(callback, - enabled=enabled, ignore_errors=ignore_errors) - self._handlers[name] = handler - self._stack.callback(handler) - return name - - do = add # readability alias - - def enable(self, name=None): - if name: - self._handlers[name].enabled = True - else: - self.enabled = True - - def disable(self, name=None): - if name: - self._handlers[name].enabled = False - else: - self.enabled = False - - def clean(self): - self.__exit__(None, None, None) - - def __enter__(self): - return self - - def __exit__(self, type=None, value=None, \ - traceback=None): # pylint: disable=redefined-builtin - if type is None: - return - if not self.enabled: - return - self._stack.__exit__(type, value, traceback) - - @classmethod - def current(cls) -> "Rollback": - return cls._thread_locals.current - - @contextmanager - def as_current(self): - previous = getattr(self._thread_locals, 'current', None) - self._thread_locals.current = self - try: - yield - finally: - self._thread_locals.current = previous - -# shorthand for common cases -def on_error_do(callback, *args, ignore_errors=False): - Rollback.current().do(callback, *args, ignore_errors=ignore_errors) - -@optional_arg_decorator -def error_rollback(func, arg_name=None): - @wraps(func) - def wrapped_func(*args, **kwargs): - with Rollback() as manager: - if arg_name is None: - with manager.as_current(): - ret_val = func(*args, **kwargs) - else: - kwargs[arg_name] = manager - ret_val = func(*args, **kwargs) - return ret_val - return wrapped_func diff --git a/datumaro/util/scope.py b/datumaro/util/scope.py new file mode 100644 index 0000000000..a857882ca5 --- /dev/null +++ b/datumaro/util/scope.py @@ -0,0 +1,151 @@ +# Copyright (C) 2021 Intel Corporation +# +# SPDX-License-Identifier: MIT + +from contextlib import ExitStack, contextmanager +from functools import partial, wraps +from typing import Any, Callable, ContextManager, Dict, Optional, Tuple, TypeVar +import threading + +from attr import attrs + +from datumaro.util import optional_arg_decorator + +T = TypeVar('T') + +class Scope: + """ + A context manager that allows to register error and exit callbacks. + """ + + _thread_locals = threading.local() + + @attrs(auto_attribs=True) + class ExitHandler: + callback: Callable[[], Any] + ignore_errors: bool = True + + def __exit__(self, exc_type, exc_value, exc_traceback): + try: + self.callback() + except Exception: + if not self.ignore_errors: + raise + + @attrs + class ErrorHandler(ExitHandler): + def __exit__(self, exc_type, exc_value, exc_traceback): + if exc_type: + return super().__exit__(exc_type=exc_type, exc_value=exc_value, + exc_traceback=exc_traceback) + + + def __init__(self): + self._stack = ExitStack() + self.enabled = True + + def on_error_do(self, callback: Callable, + *args, kwargs: Optional[Dict[str, Any]] = None, + ignore_errors: bool = False): + """ + Registers a function to be called on scope exit because of an error. + + If ignore_errors is True, the errors from this function call + will be ignored. + """ + + self._register_callback(self.ErrorHandler, + ignore_errors=ignore_errors, + callback=callback, args=args, kwargs=kwargs) + + def on_exit_do(self, callback: Callable, + *args, kwargs: Optional[Dict[str, Any]] = None, + ignore_errors: bool = False): + """ + Registers a function to be called on scope exit. + """ + + self._register_callback(self.ExitHandler, + ignore_errors=ignore_errors, + callback=callback, args=args, kwargs=kwargs) + + def _register_callback(self, handler_type, callback: Callable, + args: Tuple[Any] = None, kwargs: Dict[str, Any] = None, + ignore_errors: bool = False): + if args or kwargs: + callback = partial(callback, *args, **(kwargs or {})) + + self._stack.push(handler_type(callback, ignore_errors=ignore_errors)) + + def add(self, cm: ContextManager[T]) -> T: + """ + Enters a context manager and adds it to the exit stack. + + Returns: cm.__enter__() result + """ + + return self._stack.enter_context(cm) + + def enable(self): + self.enabled = True + + def disable(self): + self.enabled = False + + def close(self): + self.__exit__(None, None, None) + + def __enter__(self) -> 'Scope': + return self + + def __exit__(self, exc_type, exc_value, exc_traceback): + if not self.enabled: + return + + self._stack.__exit__(exc_type, exc_value, exc_traceback) + self._stack.pop_all() # prevent issues on repetitive calls + + @classmethod + def current(cls) -> 'Scope': + return cls._thread_locals.current + + @contextmanager + def as_current(self): + previous = getattr(self._thread_locals, 'current', None) + self._thread_locals.current = self + try: + yield + finally: + self._thread_locals.current = previous + +@optional_arg_decorator +def scoped(func, arg_name=None): + """ + A function decorator, which allows to do actions with the current scope, + such as registering error and exit callbacks and context managers. + """ + + @wraps(func) + def wrapped_func(*args, **kwargs): + with Scope() as scope: + if arg_name is None: + with scope.as_current(): + ret_val = func(*args, **kwargs) + else: + kwargs[arg_name] = scope + ret_val = func(*args, **kwargs) + return ret_val + + return wrapped_func + +# Shorthands for common cases +def on_error_do(callback, *args, ignore_errors=False, kwargs=None): + return Scope.current().on_error_do(callback, *args, + ignore_errors=ignore_errors, kwargs=kwargs) + +def on_exit_do(callback, *args, ignore_errors=False, kwargs=None): + return Scope.current().on_exit_do(callback, *args, + ignore_errors=ignore_errors, kwargs=kwargs) + +def scope_add(cm: ContextManager[T]) -> T: + return Scope.current().add(cm) diff --git a/tests/test_util.py b/tests/test_util.py index f124e7cdb2..66a791df76 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -1,127 +1,130 @@ +from contextlib import suppress from unittest import TestCase, mock import os import os.path as osp -from datumaro.util import ( - Rollback, error_rollback, is_method_redefined, on_error_do, -) +from datumaro.util import is_method_redefined from datumaro.util.os_util import walk +from datumaro.util.scope import Scope, on_error_do, on_exit_do, scoped from datumaro.util.test_utils import TestDir from .requirements import Requirements, mark_requirement -class TestRollback(TestCase): +class TestException(Exception): + pass + +class ScopeTest(TestCase): @mark_requirement(Requirements.DATUM_GENERAL_REQ) - def test_does_not_call_on_no_error(self): - success = True - def cb(): - nonlocal success - success = False + def test_calls_only_exit_callback_on_exit(self): + error_cb = mock.MagicMock() + exit_cb = mock.MagicMock() - with Rollback() as on_error: - on_error.do(cb) + with Scope() as scope: + scope.on_error_do(error_cb) + scope.on_exit_do(exit_cb) - self.assertTrue(success) + error_cb.assert_not_called() + exit_cb.assert_called_once() @mark_requirement(Requirements.DATUM_GENERAL_REQ) - def test_calls_on_error(self): - success = False - def cb(): - nonlocal success - success = True - - try: - with Rollback() as on_error: - on_error.do(cb) - raise Exception('err') - except Exception: # nosec - disable B110:try_except_pass check - pass - finally: - self.assertTrue(success) + def test_calls_both_callbacks_on_error(self): + error_cb = mock.MagicMock() + exit_cb = mock.MagicMock() + + with self.assertRaises(TestException), Scope() as scope: + scope.on_error_do(error_cb) + scope.on_exit_do(exit_cb) + raise TestException() + + error_cb.assert_called_once() + exit_cb.assert_called_once() + + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + def test_adds_cm(self): + cm = mock.Mock() + cm.__enter__ = mock.MagicMock(return_value=42) + cm.__exit__ = mock.MagicMock() + + with Scope() as scope: + retval = scope.add(cm) + + cm.__enter__.assert_called_once() + cm.__exit__.assert_called_once() + self.assertEqual(42, retval) + + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + def test_calls_cm_on_error(self): + cm = mock.Mock() + cm.__enter__ = mock.MagicMock() + cm.__exit__ = mock.MagicMock() + + with suppress(TestException), Scope() as scope: + scope.add(cm) + raise TestException() + + cm.__enter__.assert_called_once() + cm.__exit__.assert_called_once() @mark_requirement(Requirements.DATUM_GENERAL_REQ) def test_decorator_calls_on_error(self): - success = False - def cb(): - nonlocal success - success = True + cb = mock.MagicMock() - @error_rollback('on_error') - def foo(on_error=None): - on_error.do(cb) - raise Exception('err') + @scoped('scope') + def foo(scope=None): + scope.on_error_do(cb) + raise TestException() - try: + with suppress(TestException): foo() - except Exception: # nosec - disable B110:try_except_pass check - pass - finally: - self.assertTrue(success) + + cb.assert_called_once() @mark_requirement(Requirements.DATUM_GENERAL_REQ) def test_decorator_does_not_call_on_no_error(self): - success = True - def cb(): - nonlocal success - success = False + error_cb = mock.MagicMock() + exit_cb = mock.MagicMock() - @error_rollback('on_error') - def foo(on_error=None): - on_error.do(cb) + @scoped('scope') + def foo(scope=None): + scope.on_error_do(error_cb) + scope.on_exit_do(exit_cb) foo() - self.assertTrue(success) + error_cb.assert_not_called() + exit_cb.assert_called_once() @mark_requirement(Requirements.DATUM_GENERAL_REQ) def test_decorator_supports_implicit_form(self): - success = False - def cb(): - nonlocal success - success = True + error_cb = mock.MagicMock() + exit_cb = mock.MagicMock() - @error_rollback + @scoped def foo(): - on_error_do(cb) - raise Exception('err') + on_error_do(error_cb) + on_exit_do(exit_cb) + raise TestException() - try: + with suppress(TestException): foo() - except Exception: # nosec - disable B110:try_except_pass check - pass - finally: - self.assertTrue(success) + + error_cb.assert_called_once() + exit_cb.assert_called_once() @mark_requirement(Requirements.DATUM_GENERAL_REQ) def test_can_fowrard_args(self): - success1 = False - def cb1(a1, a2=None, ignore_errors=None): - nonlocal success1 - if a1 == 5 and a2 == 2 and ignore_errors == None: - success1 = True - - success2 = False - def cb2(a1, a2=None, ignore_errors=None): - nonlocal success2 - if a1 == 5 and a2 == 2 and ignore_errors == 4: - success2 = True - - try: - with Rollback() as on_error: - on_error.do(cb1, 5, a2=2, ignore_errors=True) - on_error.do(cb2, 5, a2=2, ignore_errors=True, - fwd_kwargs={'ignore_errors': 4}) - raise Exception('err') - except Exception: # nosec - disable B110:try_except_pass check - pass - finally: - self.assertTrue(success1) - self.assertTrue(success2) + cb = mock.MagicMock() + + with suppress(TestException), Scope() as scope: + scope.on_error_do(cb, 5, ignore_errors=True, kwargs={'a2': 2}) + raise TestException() + + cb.assert_called_once_with(5, a2=2) @mark_requirement(Requirements.DATUM_GENERAL_REQ) def test_decorator_can_return_on_success_in_implicit_form(self): - @error_rollback + @scoped def f(): return 42 @@ -131,8 +134,8 @@ def f(): @mark_requirement(Requirements.DATUM_GENERAL_REQ) def test_decorator_can_return_on_success_in_explicit_form(self): - @error_rollback('on_error') - def f(on_error=None): + @scoped('scope') + def f(scope=None): return 42 retval = f() From ca1c60db964da611ea9cbd40230685574abc8527 Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Mon, 6 Sep 2021 12:49:51 +0300 Subject: [PATCH 10/19] Update imports --- datumaro/cli/commands/checkout.py | 5 ++--- datumaro/cli/commands/commit.py | 5 ++--- datumaro/cli/commands/diff.py | 11 +++++------ datumaro/cli/commands/explain.py | 7 +++---- datumaro/cli/commands/info.py | 7 +++---- datumaro/cli/commands/log.py | 5 ++--- datumaro/cli/commands/merge.py | 7 +++---- datumaro/cli/commands/status.py | 5 ++--- datumaro/cli/contexts/source.py | 9 ++++----- tests/cli/test_revpath.py | 29 ++++++++++++++--------------- 10 files changed, 40 insertions(+), 50 deletions(-) diff --git a/datumaro/cli/commands/checkout.py b/datumaro/cli/commands/checkout.py index e8323cbf1c..ce84d8b6d4 100644 --- a/datumaro/cli/commands/checkout.py +++ b/datumaro/cli/commands/checkout.py @@ -4,8 +4,7 @@ import argparse -from datumaro.util.scope import scoped -import datumaro.util.scope as scope +from datumaro.util.scope import scope_add, scoped from ..util import MultilineFormatter from ..util.project import load_project @@ -64,7 +63,7 @@ def checkout_command(args): raise argparse.ArgumentError('sources', message="When '--' is used, " "at least 1 source name must be specified") - project = scope.add(load_project(args.project_dir)) + project = scope_add(load_project(args.project_dir)) project.checkout(rev=args.rev, sources=args.sources, force=args.force) diff --git a/datumaro/cli/commands/commit.py b/datumaro/cli/commands/commit.py index 67978ce23c..477a1d283a 100644 --- a/datumaro/cli/commands/commit.py +++ b/datumaro/cli/commands/commit.py @@ -4,8 +4,7 @@ import argparse -from datumaro.util.scope import scoped -import datumaro.util.scope as scope +from datumaro.util.scope import scope_add, scoped from ..util import MultilineFormatter from ..util.project import load_project @@ -37,7 +36,7 @@ def build_parser(parser_ctor=argparse.ArgumentParser): @scoped def commit_command(args): - project = scope.add(load_project(args.project_dir)) + project = scope_add(load_project(args.project_dir)) old_tree = project.head diff --git a/datumaro/cli/commands/diff.py b/datumaro/cli/commands/diff.py index 43839697fe..6f3b3443bd 100644 --- a/datumaro/cli/commands/diff.py +++ b/datumaro/cli/commands/diff.py @@ -12,8 +12,7 @@ from datumaro.components.errors import ProjectNotFoundError from datumaro.components.operations import DistanceComparator, ExactComparator from datumaro.util.os_util import rmtree -from datumaro.util.scope import on_error_do, scoped -import datumaro.util.scope as scope +from datumaro.util.scope import on_error_do, scope_add, scoped from ..contexts.project.diff import DiffVisualizer from ..util import MultilineFormatter @@ -150,7 +149,7 @@ def diff_command(args): project = None try: - project = scope.add(load_project(args.project_dir)) + project = scope_add(load_project(args.project_dir)) except ProjectNotFoundError: if args.project_dir: raise @@ -161,17 +160,17 @@ def diff_command(args): second_dataset, target_project = \ parse_full_revpath(args.first_target, project) if target_project: - scope.add(target_project) + scope_add(target_project) else: first_dataset, target_project = \ parse_full_revpath(args.first_target, project) if target_project: - scope.add(target_project) + scope_add(target_project) second_dataset, target_project = \ parse_full_revpath(args.second_target, project) if target_project: - scope.add(target_project) + scope_add(target_project) except Exception as e: raise CliException(str(e)) diff --git a/datumaro/cli/commands/explain.py b/datumaro/cli/commands/explain.py index 45a925a505..b3fe2fda1c 100644 --- a/datumaro/cli/commands/explain.py +++ b/datumaro/cli/commands/explain.py @@ -8,8 +8,7 @@ import os.path as osp from datumaro.util.image import is_image, load_image, save_image -from datumaro.util.scope import scoped -import datumaro.util.scope as scope +from datumaro.util.scope import scope_add, scoped from ..util import MultilineFormatter from ..util.project import load_project, parse_full_revpath @@ -117,7 +116,7 @@ def explain_command(args): from matplotlib import cm import cv2 - project = scope.add(load_project(args.project_dir)) + project = scope_add(load_project(args.project_dir)) model = project.working_tree.models.make_executable_model(args.model) @@ -174,7 +173,7 @@ def explain_command(args): dataset, target_project = \ parse_full_revpath(args.target or 'project', project) if target_project: - scope.add(target_project) + scope_add(target_project) log.info("Running inference explanation for '%s'" % args.target) diff --git a/datumaro/cli/commands/info.py b/datumaro/cli/commands/info.py index 56f2537102..1007f4f4e2 100644 --- a/datumaro/cli/commands/info.py +++ b/datumaro/cli/commands/info.py @@ -8,8 +8,7 @@ DatasetMergeError, MissingObjectError, ProjectNotFoundError, ) from datumaro.components.extractor import AnnotationType -from datumaro.util.scope import scoped -import datumaro.util.scope as scope +from datumaro.util.scope import scope_add, scoped from ..util import MultilineFormatter from ..util.project import load_project, parse_full_revpath @@ -61,7 +60,7 @@ def build_parser(parser_ctor=argparse.ArgumentParser): def info_command(args): project = None try: - project = scope.add(load_project(args.project_dir)) + project = scope_add(load_project(args.project_dir)) except ProjectNotFoundError: if args.project_dir: raise @@ -70,7 +69,7 @@ def info_command(args): # TODO: avoid computing working tree hashes dataset, target_project = parse_full_revpath(args.target, project) if target_project: - scope.add(target_project) + scope_add(target_project) except DatasetMergeError as e: dataset = None dataset_problem = "Can't merge project sources automatically: %s " \ diff --git a/datumaro/cli/commands/log.py b/datumaro/cli/commands/log.py index 31afbe0ed8..75452a325b 100644 --- a/datumaro/cli/commands/log.py +++ b/datumaro/cli/commands/log.py @@ -4,8 +4,7 @@ import argparse -from datumaro.util.scope import scoped -import datumaro.util.scope as scope +from datumaro.util.scope import scope_add, scoped from ..util.project import load_project @@ -23,7 +22,7 @@ def build_parser(parser_ctor=argparse.ArgumentParser): @scoped def log_command(args): - project = scope.add(load_project(args.project_dir)) + project = scope_add(load_project(args.project_dir)) revisions = project.history(args.max_count) if revisions: diff --git a/datumaro/cli/commands/merge.py b/datumaro/cli/commands/merge.py index b34670dd9b..565bbfbf9d 100644 --- a/datumaro/cli/commands/merge.py +++ b/datumaro/cli/commands/merge.py @@ -14,8 +14,7 @@ DatasetMergeError, DatasetQualityError, ProjectNotFoundError, ) from datumaro.components.operations import IntersectMerge -from datumaro.util.scope import scoped -import datumaro.util.scope as scope +from datumaro.util.scope import scope_add, scoped from ..util import MultilineFormatter from ..util.errors import CliException @@ -117,7 +116,7 @@ def merge_command(args): project = None try: - project = scope.add(load_project(args.project_dir)) + project = scope_add(load_project(args.project_dir)) except ProjectNotFoundError: if args.project_dir: raise @@ -130,7 +129,7 @@ def merge_command(args): for t in args.targets: target_dataset, target_project = parse_full_revpath(t, project) if target_project: - scope.add(target_project) + scope_add(target_project) source_datasets.append(target_dataset) except Exception as e: raise CliException(str(e)) diff --git a/datumaro/cli/commands/status.py b/datumaro/cli/commands/status.py index a31e352bfe..5900b7a29a 100644 --- a/datumaro/cli/commands/status.py +++ b/datumaro/cli/commands/status.py @@ -5,8 +5,7 @@ import argparse from datumaro.cli.util import MultilineFormatter -from datumaro.util.scope import scoped -import datumaro.util.scope as scope +from datumaro.util.scope import scope_add, scoped from ..util.project import load_project @@ -28,7 +27,7 @@ def build_parser(parser_ctor=argparse.ArgumentParser): @scoped def status_command(args): - project = scope.add(load_project(args.project_dir)) + project = scope_add(load_project(args.project_dir)) statuses = project.status() diff --git a/datumaro/cli/contexts/source.py b/datumaro/cli/contexts/source.py index a26b4210c8..17972b16b2 100644 --- a/datumaro/cli/contexts/source.py +++ b/datumaro/cli/contexts/source.py @@ -8,8 +8,7 @@ from datumaro.components.errors import ProjectNotFoundError from datumaro.components.project import Environment -from datumaro.util.scope import on_error_do, scoped -import datumaro.util.scope as scope +from datumaro.util.scope import on_error_do, scope_add, scoped from ..util import MultilineFormatter, add_subparser, join_cli_args from ..util.errors import CliException @@ -98,7 +97,7 @@ def add_command(args): project = None try: - project = scope.add(load_project(args.project_dir)) + project = scope_add(load_project(args.project_dir)) except ProjectNotFoundError: if not show_plugin_help and args.project_dir: raise @@ -163,7 +162,7 @@ def build_remove_parser(parser_ctor=argparse.ArgumentParser): @scoped def remove_command(args): - project = scope.add(load_project(args.project_dir)) + project = scope_add(load_project(args.project_dir)) if not args.names: raise CliException("Expected source name") @@ -192,7 +191,7 @@ def build_info_parser(parser_ctor=argparse.ArgumentParser): @scoped def info_command(args): - project = scope.add(load_project(args.project_dir)) + project = scope_add(load_project(args.project_dir)) if args.name: source = project.working_tree.sources[args.name] diff --git a/tests/cli/test_revpath.py b/tests/cli/test_revpath.py index 2a0c17d713..9828d0a83d 100644 --- a/tests/cli/test_revpath.py +++ b/tests/cli/test_revpath.py @@ -10,9 +10,8 @@ ) from datumaro.components.extractor import DatasetItem from datumaro.components.project import Project -from datumaro.util.scope import scoped +from datumaro.util.scope import scope_add, scoped from datumaro.util.test_utils import TestDir -import datumaro.util.scope as scope from ..requirements import Requirements, mark_requirement @@ -21,34 +20,34 @@ class TestRevpath(TestCase): @mark_requirement(Requirements.DATUM_GENERAL_REQ) @scoped def test_can_parse(self): - test_dir = scope.add(TestDir(frame_id=5)) + test_dir = scope_add(TestDir(frame_id=5)) dataset_url = osp.join(test_dir, 'source') Dataset.from_iterable([DatasetItem(1)]).save(dataset_url) proj_dir = osp.join(test_dir, 'proj') - proj = scope.add(Project.init(proj_dir)) + proj = scope_add(Project.init(proj_dir)) proj.import_source('source-1', dataset_url, format=DEFAULT_FORMAT) ref = proj.commit("second commit", allow_empty=True) with self.subTest("project"): dataset, project = parse_full_revpath(proj_dir) if project: - scope.add(project) + scope_add(project) self.assertTrue(isinstance(dataset, IDataset)) self.assertTrue(isinstance(project, Project)) with self.subTest("project ref"): dataset, project = parse_full_revpath(f"{proj_dir}@{ref}") if project: - scope.add(project) + scope_add(project) self.assertTrue(isinstance(dataset, IDataset)) self.assertTrue(isinstance(project, Project)) with self.subTest("project ref source"): dataset, project = parse_full_revpath(f"{proj_dir}@{ref}:source-1") if project: - scope.add(project) + scope_add(project) self.assertTrue(isinstance(dataset, IDataset)) self.assertTrue(isinstance(project, Project)) @@ -56,42 +55,42 @@ def test_can_parse(self): dataset, project = parse_full_revpath( f"{proj_dir}@{ref}:source-1.root") if project: - scope.add(project) + scope_add(project) self.assertTrue(isinstance(dataset, IDataset)) self.assertTrue(isinstance(project, Project)) with self.subTest("ref"): dataset, project = parse_full_revpath(ref, proj) if project: - scope.add(project) + scope_add(project) self.assertTrue(isinstance(dataset, IDataset)) self.assertEqual(None, project) with self.subTest("ref source"): dataset, project = parse_full_revpath(f"{ref}:source-1", proj) if project: - scope.add(project) + scope_add(project) self.assertTrue(isinstance(dataset, IDataset)) self.assertEqual(None, project) with self.subTest("ref source stage"): dataset, project = parse_full_revpath(f"{ref}:source-1.root", proj) if project: - scope.add(project) + scope_add(project) self.assertTrue(isinstance(dataset, IDataset)) self.assertEqual(None, project) with self.subTest("source"): dataset, project = parse_full_revpath("source-1", proj) if project: - scope.add(project) + scope_add(project) self.assertTrue(isinstance(dataset, IDataset)) self.assertEqual(None, project) with self.subTest("source stage"): dataset, project = parse_full_revpath("source-1.root", proj) if project: - scope.add(project) + scope_add(project) self.assertTrue(isinstance(dataset, IDataset)) self.assertEqual(None, project) @@ -107,7 +106,7 @@ def test_can_parse(self): dataset, project = parse_full_revpath( f"{dataset_url}:datumaro", proj) if project: - scope.add(project) + scope_add(project) self.assertTrue(isinstance(dataset, IDataset)) self.assertEqual(None, project) @@ -122,7 +121,7 @@ def test_can_parse(self): with self.subTest("dataset format (no context)"): dataset, project = parse_full_revpath(f"{dataset_url}:datumaro") if project: - scope.add(project) + scope_add(project) self.assertTrue(isinstance(dataset, IDataset)) self.assertEqual(None, project) From a434eb857100f58b8e92129ce6c7bd5c7bbe169e Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Mon, 6 Sep 2021 13:07:53 +0300 Subject: [PATCH 11/19] Update on_error_do call --- datumaro/cli/contexts/source.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datumaro/cli/contexts/source.py b/datumaro/cli/contexts/source.py index 17972b16b2..d9a51ed1b7 100644 --- a/datumaro/cli/contexts/source.py +++ b/datumaro/cli/contexts/source.py @@ -129,8 +129,8 @@ def add_command(args): project.import_source(name, url=args.url, format=args.format, options=extra_args, no_cache=args.no_cache, rpath=args.path) - on_error_do(project.remove_source, name, force=True, keep_data=False, - ignore_errors=True) + on_error_do(project.remove_source, name, ignore_errors=True, + kwargs={'force': True, 'keep_data': False}) if not args.no_check: log.info("Checking the source...") From 559e09622f8c708c9860a083dcb66ae02d330153 Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Mon, 6 Sep 2021 13:28:32 +0300 Subject: [PATCH 12/19] Update TestDir class --- datumaro/util/test_utils.py | 26 ++++++++++++++++++-------- tests/cli/test_revpath.py | 2 +- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/datumaro/util/test_utils.py b/datumaro/util/test_utils.py index 869f8964b0..5933db2b70 100644 --- a/datumaro/util/test_utils.py +++ b/datumaro/util/test_utils.py @@ -4,7 +4,7 @@ from enum import Enum, auto from glob import glob -from typing import Collection, Union +from typing import Collection, Optional, Union import inspect import os import os.path as osp @@ -46,20 +46,30 @@ class TestDir(FileRemover): Usage: - with TestDir() as test_dir: - ... + with TestDir() as test_dir: + ... """ - def __init__(self, path=None, frame_id=2): + def __init__(self, path: Optional[str] = None, frame_id: int = 2): + if not path: + prefix = f'temp_{current_function_name(frame_id)}-' + else: + prefix = None + self._prefix = prefix + super().__init__(path, is_dir=True) - self._frame_id = frame_id - def __enter__(self): + def __enter__(self) -> str: + """ + Creates a test directory. + + Returns: path to the directory + """ + path = self.path if path is None: - path = f'temp_{current_function_name(self._frame_id)}-' - path = tempfile.mkdtemp(dir=os.getcwd(), prefix=path) + path = tempfile.mkdtemp(dir=os.getcwd(), prefix=self._prefix) self.path = path else: os.makedirs(path, exist_ok=False) diff --git a/tests/cli/test_revpath.py b/tests/cli/test_revpath.py index 9828d0a83d..fd5541367d 100644 --- a/tests/cli/test_revpath.py +++ b/tests/cli/test_revpath.py @@ -20,7 +20,7 @@ class TestRevpath(TestCase): @mark_requirement(Requirements.DATUM_GENERAL_REQ) @scoped def test_can_parse(self): - test_dir = scope_add(TestDir(frame_id=5)) + test_dir = scope_add(TestDir()) dataset_url = osp.join(test_dir, 'source') Dataset.from_iterable([DatasetItem(1)]).save(dataset_url) From 8323e2e5fb76f95add3a01daa06a9c82c4da38d6 Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Mon, 6 Sep 2021 17:30:46 +0300 Subject: [PATCH 13/19] Fix windows file addition in git --- datumaro/components/project.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/datumaro/components/project.py b/datumaro/components/project.py index 2185aa8839..1f1a8f1e9e 100644 --- a/datumaro/components/project.py +++ b/datumaro/components/project.py @@ -927,8 +927,9 @@ def add(self, paths, base=None): repo_root = osp.abspath(self._project_dir) assert is_subpath(base, base=repo_root), \ "Base path should be inside of the repo" - base = base[len(repo_root) + len(osp.sep) : ] - path_rewriter = lambda entry: osp.relpath(entry.path, base) + base = osp.relpath(base, repo_root) + path_rewriter = lambda entry: osp.relpath(entry.path, base) \ + .replace('\\', '/') if isinstance(paths, str): paths = [paths] @@ -1006,10 +1007,12 @@ def status(self, paths: Union[str, GitTree, Iterable[str]] = None, elif file_exists and not index_entry: status = 'A' elif file_exists and index_entry: - status = self.repo.git.diff('--name-status', + # '--ignore-cr-at-eol' doesn't affect '--name-status' + # so we can't really obtain 'T' + status = self.repo.git.diff('--ignore-cr-at-eol', index_entry.hexsha, file_path) if status: - status = status[0] + status = 'M' assert status in {'', 'M', 'T'}, status else: status = '' # ignore missing paths From 5ea8e1a376c6b50f21997d641f33a6aa99fbc1ff Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Mon, 6 Sep 2021 17:31:40 +0300 Subject: [PATCH 14/19] Apply scoped to project tests --- tests/test_project.py | 1285 +++++++++++++++++++++-------------------- 1 file changed, 665 insertions(+), 620 deletions(-) diff --git a/tests/test_project.py b/tests/test_project.py index d66ae5c79b..95806e35c5 100644 --- a/tests/test_project.py +++ b/tests/test_project.py @@ -17,6 +17,7 @@ from datumaro.components.extractor import DatasetItem, Extractor, ItemTransform from datumaro.components.launcher import Launcher from datumaro.components.project import DiffStatus, Project +from datumaro.util.scope import scoped, scope_add from datumaro.util.test_utils import TestDir, compare_datasets, compare_dirs from .requirements import Requirements, mark_requirement @@ -24,48 +25,57 @@ class ProjectTest(TestCase): @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped def test_can_init_and_load(self): - with TestDir() as test_dir: - Project.init(test_dir) + test_dir = scope_add(TestDir()) - Project(test_dir) + scope_add(Project.init(test_dir)).close() + scope_add(Project(test_dir)) + + self.assertTrue('.datumaro' in os.listdir(test_dir)) @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped def test_can_find_project_in_project_dir(self): - with TestDir() as test_dir: - Project.init(test_dir) + test_dir = scope_add(TestDir()) + + scope_add(Project.init(test_dir)) - self.assertEqual(osp.join(test_dir, '.datumaro'), - Project.find_project_dir(test_dir)) + self.assertEqual(osp.join(test_dir, '.datumaro'), + Project.find_project_dir(test_dir)) @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped def test_cant_find_project_when_no_project(self): - with TestDir() as test_dir: - self.assertEqual(None, Project.find_project_dir(test_dir)) + test_dir = scope_add(TestDir()) + + self.assertEqual(None, Project.find_project_dir(test_dir)) @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped def test_can_add_local_model(self): - with TestDir() as test_dir: - class TestLauncher(Launcher): - pass + class TestLauncher(Launcher): + pass - source_name = 'source' - config = Model({ - 'launcher': 'test', - 'options': { 'a': 5, 'b': 'hello' } - }) + source_name = 'source' + config = Model({ + 'launcher': 'test', + 'options': { 'a': 5, 'b': 'hello' } + }) - project = Project.init(test_dir) - project.env.launchers.register('test', TestLauncher) + test_dir = scope_add(TestDir()) + project = scope_add(Project.init(test_dir)) + project.env.launchers.register('test', TestLauncher) - project.add_model(source_name, - launcher=config.launcher, options=config.options) + project.add_model(source_name, + launcher=config.launcher, options=config.options) - added = project.models[source_name] - self.assertEqual(added.launcher, config.launcher) - self.assertEqual(added.options, config.options) + added = project.models[source_name] + self.assertEqual(added.launcher, config.launcher) + self.assertEqual(added.options, config.options) @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped def test_can_run_inference(self): class TestLauncher(Launcher): def launch(self, inputs): @@ -80,439 +90,459 @@ def launch(self, inputs): launcher_name = 'custom_launcher' model_name = 'model' - with TestDir() as test_dir: - source_url = osp.join(test_dir, 'source') - source_dataset = Dataset.from_iterable([ - DatasetItem(0, image=np.ones([2, 2, 3]) * 0), - DatasetItem(1, image=np.ones([2, 2, 3]) * 1), - ], categories=['a', 'b']) - source_dataset.save(source_url, save_images=True) + test_dir = scope_add(TestDir()) + source_url = osp.join(test_dir, 'source') + source_dataset = Dataset.from_iterable([ + DatasetItem(0, image=np.ones([2, 2, 3]) * 0), + DatasetItem(1, image=np.ones([2, 2, 3]) * 1), + ], categories=['a', 'b']) + source_dataset.save(source_url, save_images=True) - project = Project.init(osp.join(test_dir, 'proj')) - project.env.launchers.register(launcher_name, TestLauncher) - project.add_model(model_name, launcher=launcher_name) - project.import_source('source', source_url, format=DEFAULT_FORMAT) + project = scope_add(Project.init(osp.join(test_dir, 'proj'))) + project.env.launchers.register(launcher_name, TestLauncher) + project.add_model(model_name, launcher=launcher_name) + project.import_source('source', source_url, format=DEFAULT_FORMAT) - dataset = project.working_tree.make_dataset() - model = project.make_model(model_name) + dataset = project.working_tree.make_dataset() + model = project.make_model(model_name) - inference = dataset.run_model(model) + inference = dataset.run_model(model) - compare_datasets(self, expected, inference) + compare_datasets(self, expected, inference) @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped def test_can_import_local_source(self): - with TestDir() as test_dir: - source_base_url = osp.join(test_dir, 'test_repo') - source_file_path = osp.join(source_base_url, 'x', 'y.txt') - os.makedirs(osp.dirname(source_file_path), exist_ok=True) - with open(source_file_path, 'w') as f: - f.write('hello') + test_dir = scope_add(TestDir()) + source_base_url = osp.join(test_dir, 'test_repo') + source_file_path = osp.join(source_base_url, 'x', 'y.txt') + os.makedirs(osp.dirname(source_file_path), exist_ok=True) + with open(source_file_path, 'w') as f: + f.write('hello') - project = Project.init(osp.join(test_dir, 'proj')) - project.import_source('s1', url=source_base_url, format='fmt') + project = scope_add(Project.init(osp.join(test_dir, 'proj'))) + project.import_source('s1', url=source_base_url, format='fmt') - source = project.working_tree.sources['s1'] - self.assertEqual('fmt', source.format) - compare_dirs(self, source_base_url, project.source_data_dir('s1')) - with open(osp.join(test_dir, 'proj', '.gitignore')) as f: - self.assertTrue('s1' in [line.strip() for line in f]) + source = project.working_tree.sources['s1'] + self.assertEqual('fmt', source.format) + compare_dirs(self, source_base_url, project.source_data_dir('s1')) + with open(osp.join(test_dir, 'proj', '.gitignore')) as f: + self.assertTrue('s1' in [line.strip() for line in f]) @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped def test_can_import_local_source_with_relpath(self): # This form must copy all the data in URL, but read only # specified files. Required to support subtasks and subsets. - with TestDir() as test_dir: - source_url = osp.join(test_dir, 'source') - source_dataset = Dataset.from_iterable([ - DatasetItem(0, subset='a', image=np.ones((2, 3, 3)), - annotations=[ Bbox(1, 2, 3, 4, label=0) ]), - DatasetItem(1, subset='b', image=np.zeros((10, 20, 3)), - annotations=[ Bbox(1, 2, 3, 4, label=1) ]), - ], categories=['a', 'b']) - source_dataset.save(source_url, save_images=True) + test_dir = scope_add(TestDir()) + source_url = osp.join(test_dir, 'source') + source_dataset = Dataset.from_iterable([ + DatasetItem(0, subset='a', image=np.ones((2, 3, 3)), + annotations=[ Bbox(1, 2, 3, 4, label=0) ]), + DatasetItem(1, subset='b', image=np.zeros((10, 20, 3)), + annotations=[ Bbox(1, 2, 3, 4, label=1) ]), + ], categories=['a', 'b']) + source_dataset.save(source_url, save_images=True) - expected_dataset = Dataset.from_iterable([ - DatasetItem(1, subset='b', image=np.zeros((10, 20, 3)), - annotations=[ Bbox(1, 2, 3, 4, label=1) ]), - ], categories=['a', 'b']) + expected_dataset = Dataset.from_iterable([ + DatasetItem(1, subset='b', image=np.zeros((10, 20, 3)), + annotations=[ Bbox(1, 2, 3, 4, label=1) ]), + ], categories=['a', 'b']) - project = Project.init(osp.join(test_dir, 'proj')) + project = scope_add(Project.init(osp.join(test_dir, 'proj'))) - project.import_source('s1', url=source_url, format=DEFAULT_FORMAT, - rpath=osp.join('annotations', 'b.json')) + project.import_source('s1', url=source_url, format=DEFAULT_FORMAT, + rpath=osp.join('annotations', 'b.json')) - source = project.working_tree.sources['s1'] - self.assertEqual(DEFAULT_FORMAT, source.format) + source = project.working_tree.sources['s1'] + self.assertEqual(DEFAULT_FORMAT, source.format) - compare_dirs(self, source_url, project.source_data_dir('s1')) - read_dataset = project.working_tree.make_dataset('s1') - compare_datasets(self, expected_dataset, read_dataset, - require_images=True) + compare_dirs(self, source_url, project.source_data_dir('s1')) + read_dataset = project.working_tree.make_dataset('s1') + compare_datasets(self, expected_dataset, read_dataset, + require_images=True) - with open(osp.join(test_dir, 'proj', '.gitignore')) as f: - self.assertTrue('s1' in [line.strip() for line in f]) + with open(osp.join(test_dir, 'proj', '.gitignore')) as f: + self.assertTrue('s1' in [line.strip() for line in f]) @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped def test_cant_import_local_source_with_relpath_outside(self): - with TestDir() as test_dir: - source_url = osp.join(test_dir, 'source') - os.makedirs(source_url) + test_dir = scope_add(TestDir()) + source_url = osp.join(test_dir, 'source') + os.makedirs(source_url) - project = Project.init(osp.join(test_dir, 'proj')) + project = scope_add(Project.init(osp.join(test_dir, 'proj'))) - with self.assertRaises(PathOutsideSourceError): - project.import_source('s1', url=source_url, - format=DEFAULT_FORMAT, rpath='..') + with self.assertRaises(PathOutsideSourceError): + project.import_source('s1', url=source_url, + format=DEFAULT_FORMAT, rpath='..') @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped def test_cant_import_local_source_with_url_inside_project(self): - with TestDir() as test_dir: - source_url = osp.join(test_dir, 'qq') - with open(source_url, 'w') as f: - f.write('hello') + test_dir = scope_add(TestDir()) + source_url = osp.join(test_dir, 'qq') + with open(source_url, 'w') as f: + f.write('hello') - project = Project.init(test_dir) + project = scope_add(Project.init(test_dir)) - with self.assertRaises(SourceUrlInsideProjectError): - project.import_source('s1', url=source_url, - format=DEFAULT_FORMAT) + with self.assertRaises(SourceUrlInsideProjectError): + project.import_source('s1', url=source_url, + format=DEFAULT_FORMAT) @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped def test_can_report_incompatible_sources(self): - with TestDir() as test_dir: - source1_url = osp.join(test_dir, 'dataset1') - dataset1 = Dataset.from_iterable([ - DatasetItem(1, annotations=[Label(0)]), - ], categories=['a', 'b']) - dataset1.save(source1_url) + test_dir = scope_add(TestDir()) + source1_url = osp.join(test_dir, 'dataset1') + dataset1 = Dataset.from_iterable([ + DatasetItem(1, annotations=[Label(0)]), + ], categories=['a', 'b']) + dataset1.save(source1_url) - source2_url = osp.join(test_dir, 'dataset2') - dataset2 = Dataset.from_iterable([ - DatasetItem(1, annotations=[Label(0)]), - ], categories=['c', 'd']) - dataset2.save(source2_url) + source2_url = osp.join(test_dir, 'dataset2') + dataset2 = Dataset.from_iterable([ + DatasetItem(1, annotations=[Label(0)]), + ], categories=['c', 'd']) + dataset2.save(source2_url) - project = Project.init(osp.join(test_dir, 'proj')) - project.import_source('s1', url=source1_url, format=DEFAULT_FORMAT) - project.import_source('s2', url=source2_url, format=DEFAULT_FORMAT) + project = scope_add(Project.init(osp.join(test_dir, 'proj'))) + project.import_source('s1', url=source1_url, format=DEFAULT_FORMAT) + project.import_source('s2', url=source2_url, format=DEFAULT_FORMAT) - with self.assertRaises(DatasetMergeError) as cm: - project.working_tree.make_dataset() + with self.assertRaises(DatasetMergeError) as cm: + project.working_tree.make_dataset() - self.assertEqual({'s1.root', 's2.root'}, cm.exception.sources) + self.assertEqual({'s1.root', 's2.root'}, cm.exception.sources) @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped def test_cant_add_sources_with_same_names(self): - with TestDir() as test_dir: - source_url = osp.join(test_dir, 'test_repo') - dataset = Dataset.from_iterable([ - DatasetItem(1, annotations=[Label(0)]), - DatasetItem(2, annotations=[Label(1)]), - ], categories=['a', 'b']) - dataset.save(source_url) - - project = Project.init(osp.join(test_dir, 'proj')) - project.import_source('s1', url=source_url, format=DEFAULT_FORMAT) - - with self.assertRaises(SourceExistsError): - project.import_source('s1', url=source_url, - format=DEFAULT_FORMAT) - - @mark_requirement(Requirements.DATUM_GENERAL_REQ) - def test_can_import_generated_source(self): - with TestDir() as test_dir: - source_name = 'source' - origin = Source({ - # no url - 'format': 'fmt', - 'options': { 'c': 5, 'd': 'hello' } - }) - project = Project.init(test_dir) + test_dir = scope_add(TestDir()) + source_url = osp.join(test_dir, 'test_repo') + dataset = Dataset.from_iterable([ + DatasetItem(1, annotations=[Label(0)]), + DatasetItem(2, annotations=[Label(1)]), + ], categories=['a', 'b']) + dataset.save(source_url) - project.import_source(source_name, url='', - format=origin.format, options=origin.options) + project = scope_add(Project.init(osp.join(test_dir, 'proj'))) + project.import_source('s1', url=source_url, format=DEFAULT_FORMAT) - added = project.working_tree.sources[source_name] - self.assertEqual(added.format, origin.format) - self.assertEqual(added.options, origin.options) - with open(osp.join(test_dir, '.gitignore')) as f: - self.assertTrue(source_name in [line.strip() for line in f]) + with self.assertRaises(SourceExistsError): + project.import_source('s1', url=source_url, format=DEFAULT_FORMAT) @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped + def test_can_import_generated_source(self): + test_dir = scope_add(TestDir()) + source_name = 'source' + origin = Source({ + # no url + 'format': 'fmt', + 'options': { 'c': 5, 'd': 'hello' } + }) + project = scope_add(Project.init(test_dir)) + + project.import_source(source_name, url='', + format=origin.format, options=origin.options) + + added = project.working_tree.sources[source_name] + self.assertEqual(added.format, origin.format) + self.assertEqual(added.options, origin.options) + with open(osp.join(test_dir, '.gitignore')) as f: + self.assertTrue(source_name in [line.strip() for line in f]) + + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped def test_cant_import_source_with_wrong_name(self): - with TestDir() as test_dir: - project = Project.init(test_dir) + test_dir = scope_add(TestDir()) + project = scope_add(Project.init(test_dir)) - for name in {'dataset', 'project', 'build', '.any'}: - with self.subTest(name=name), \ - self.assertRaisesRegex(ValueError, "Source name"): - project.import_source(name, url='', format='fmt') + for name in {'dataset', 'project', 'build', '.any'}: + with self.subTest(name=name), \ + self.assertRaisesRegex(ValueError, "Source name"): + project.import_source(name, url='', format='fmt') @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped def test_can_remove_source_and_keep_data(self): - with TestDir() as test_dir: - source_url = osp.join(test_dir, 'test_source.txt') - os.makedirs(osp.dirname(source_url), exist_ok=True) - with open(source_url, 'w') as f: - f.write('hello') + test_dir = scope_add(TestDir()) + source_url = osp.join(test_dir, 'test_source.txt') + os.makedirs(osp.dirname(source_url), exist_ok=True) + with open(source_url, 'w') as f: + f.write('hello') - project = Project.init(osp.join(test_dir, 'proj')) - project.import_source('s1', url=source_url, format=DEFAULT_FORMAT) + project = scope_add(Project.init(osp.join(test_dir, 'proj'))) + project.import_source('s1', url=source_url, format=DEFAULT_FORMAT) - project.remove_source('s1', keep_data=True) + project.remove_source('s1', keep_data=True) - self.assertFalse('s1' in project.working_tree.sources) - compare_dirs(self, source_url, project.source_data_dir('s1')) - with open(osp.join(test_dir, 'proj', '.gitignore')) as f: - self.assertFalse('s1' in [line.strip() for line in f]) + self.assertFalse('s1' in project.working_tree.sources) + compare_dirs(self, source_url, project.source_data_dir('s1')) + with open(osp.join(test_dir, 'proj', '.gitignore')) as f: + self.assertFalse('s1' in [line.strip() for line in f]) @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped def test_can_remove_source_and_wipe_data(self): - with TestDir() as test_dir: - source_url = osp.join(test_dir, 'test_source.txt') - os.makedirs(osp.dirname(source_url), exist_ok=True) - with open(source_url, 'w') as f: - f.write('hello') + test_dir = scope_add(TestDir()) + source_url = osp.join(test_dir, 'test_source.txt') + os.makedirs(osp.dirname(source_url), exist_ok=True) + with open(source_url, 'w') as f: + f.write('hello') - project = Project.init(osp.join(test_dir, 'proj')) - project.import_source('s1', url=source_url, format=DEFAULT_FORMAT) + project = scope_add(Project.init(osp.join(test_dir, 'proj'))) + project.import_source('s1', url=source_url, format=DEFAULT_FORMAT) - project.remove_source('s1', keep_data=False) + project.remove_source('s1', keep_data=False) - self.assertFalse('s1' in project.working_tree.sources) - self.assertFalse(osp.exists(project.source_data_dir('s1'))) - with open(osp.join(test_dir, 'proj', '.gitignore')) as f: - self.assertFalse('s1' in [line.strip() for line in f]) + self.assertFalse('s1' in project.working_tree.sources) + self.assertFalse(osp.exists(project.source_data_dir('s1'))) + with open(osp.join(test_dir, 'proj', '.gitignore')) as f: + self.assertFalse('s1' in [line.strip() for line in f]) @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped def test_can_redownload_source_rev_noncached(self): - with TestDir() as test_dir: - source_url = osp.join(test_dir, 'source') - source_dataset = Dataset.from_iterable([ - DatasetItem(0, image=np.ones((2, 3, 3)), - annotations=[ Bbox(1, 2, 3, 4, label=0) ]), - DatasetItem(1, subset='s', image=np.zeros((10, 20, 3)), - annotations=[ Bbox(1, 2, 3, 4, label=1) ]), - ], categories=['a', 'b']) - source_dataset.save(source_url, save_images=True) - - project = Project.init(osp.join(test_dir, 'proj')) - project.import_source('s1', url=source_url, format=DEFAULT_FORMAT) - project.commit("A commit") + test_dir = scope_add(TestDir()) + source_url = osp.join(test_dir, 'source') + source_dataset = Dataset.from_iterable([ + DatasetItem(0, image=np.ones((2, 3, 3)), + annotations=[ Bbox(1, 2, 3, 4, label=0) ]), + DatasetItem(1, subset='s', image=np.zeros((10, 20, 3)), + annotations=[ Bbox(1, 2, 3, 4, label=1) ]), + ], categories=['a', 'b']) + source_dataset.save(source_url, save_images=True) - # remove local source data - project.remove_cache_obj( - project.working_tree.build_targets['s1'].head.hash) - shutil.rmtree(project.source_data_dir('s1')) + project = scope_add(Project.init(osp.join(test_dir, 'proj'))) + project.import_source('s1', url=source_url, format=DEFAULT_FORMAT) + project.commit("A commit") + + # remove local source data + project.remove_cache_obj( + project.working_tree.build_targets['s1'].head.hash) + shutil.rmtree(project.source_data_dir('s1')) - read_dataset = project.working_tree.make_dataset('s1') + read_dataset = project.working_tree.make_dataset('s1') - compare_datasets(self, source_dataset, read_dataset) - compare_dirs(self, source_url, project.cache_path( - project.working_tree.build_targets['s1'].root.hash)) + compare_datasets(self, source_dataset, read_dataset) + compare_dirs(self, source_url, project.cache_path( + project.working_tree.build_targets['s1'].root.hash)) @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped def test_can_redownload_source_and_check_data_hash(self): - with TestDir() as test_dir: - source_url = osp.join(test_dir, 'source') - source_dataset = Dataset.from_iterable([ - DatasetItem(0, image=np.ones((2, 3, 3)), - annotations=[ Bbox(1, 2, 3, 4, label=0) ]), - DatasetItem(1, subset='s', image=np.zeros((10, 20, 3)), - annotations=[ Bbox(1, 2, 3, 4, label=1) ]), - ], categories=['a', 'b']) - source_dataset.save(source_url, save_images=True) - - project = Project.init(osp.join(test_dir, 'proj')) - project.import_source('s1', url=source_url, format=DEFAULT_FORMAT) - project.commit("A commit") + test_dir = scope_add(TestDir()) + source_url = osp.join(test_dir, 'source') + source_dataset = Dataset.from_iterable([ + DatasetItem(0, image=np.ones((2, 3, 3)), + annotations=[ Bbox(1, 2, 3, 4, label=0) ]), + DatasetItem(1, subset='s', image=np.zeros((10, 20, 3)), + annotations=[ Bbox(1, 2, 3, 4, label=1) ]), + ], categories=['a', 'b']) + source_dataset.save(source_url, save_images=True) - # remove local source data - project.remove_cache_obj( - project.working_tree.build_targets['s1'].head.hash) - shutil.rmtree(project.source_data_dir('s1')) + project = scope_add(Project.init(osp.join(test_dir, 'proj'))) + project.import_source('s1', url=source_url, format=DEFAULT_FORMAT) + project.commit("A commit") + + # remove local source data + project.remove_cache_obj( + project.working_tree.build_targets['s1'].head.hash) + shutil.rmtree(project.source_data_dir('s1')) - # modify the source repo - with open(osp.join(source_url, 'extra_file.txt'), 'w') as f: - f.write('text\n') + # modify the source repo + with open(osp.join(source_url, 'extra_file.txt'), 'w') as f: + f.write('text\n') - with self.assertRaises(MismatchingObjectError): - project.working_tree.make_dataset('s1') + with self.assertRaises(MismatchingObjectError): + project.working_tree.make_dataset('s1') @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped def test_can_use_source_from_cache_with_working_copy(self): - with TestDir() as test_dir: - source_url = osp.join(test_dir, 'source') - source_dataset = Dataset.from_iterable([ - DatasetItem(0, image=np.ones((2, 3, 3)), - annotations=[ Bbox(1, 2, 3, 4, label=0) ]), - DatasetItem(1, subset='s', image=np.zeros((10, 20, 3)), - annotations=[ Bbox(1, 2, 3, 4, label=1) ]), - ], categories=['a', 'b']) - source_dataset.save(source_url, save_images=True) - - project = Project.init(osp.join(test_dir, 'proj')) - project.import_source('s1', url=source_url, format=DEFAULT_FORMAT) - project.commit("A commit") + test_dir = scope_add(TestDir()) + source_url = osp.join(test_dir, 'source') + source_dataset = Dataset.from_iterable([ + DatasetItem(0, image=np.ones((2, 3, 3)), + annotations=[ Bbox(1, 2, 3, 4, label=0) ]), + DatasetItem(1, subset='s', image=np.zeros((10, 20, 3)), + annotations=[ Bbox(1, 2, 3, 4, label=1) ]), + ], categories=['a', 'b']) + source_dataset.save(source_url, save_images=True) + + project = scope_add(Project.init(osp.join(test_dir, 'proj'))) + project.import_source('s1', url=source_url, format=DEFAULT_FORMAT) + project.commit("A commit") - shutil.rmtree(project.source_data_dir('s1')) + shutil.rmtree(project.source_data_dir('s1')) - read_dataset = project.working_tree.make_dataset('s1') + read_dataset = project.working_tree.make_dataset('s1') - compare_datasets(self, source_dataset, read_dataset) - self.assertFalse(osp.isdir(project.source_data_dir('s1'))) + compare_datasets(self, source_dataset, read_dataset) + self.assertFalse(osp.isdir(project.source_data_dir('s1'))) @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped def test_raises_an_error_if_local_data_unknown(self): - with TestDir() as test_dir: - source_url = osp.join(test_dir, 'source') - source_dataset = Dataset.from_iterable([ - DatasetItem(0, image=np.ones((2, 3, 3)), - annotations=[ Bbox(1, 2, 3, 4, label=0) ]), - DatasetItem(1, subset='s', image=np.zeros((10, 20, 3)), - annotations=[ Bbox(1, 2, 3, 4, label=1) ]), - ], categories=['a', 'b']) - source_dataset.save(source_url, save_images=True) - - project = Project.init(osp.join(test_dir, 'proj')) - project.import_source('s1', url=source_url, format=DEFAULT_FORMAT) - project.commit("A commit") + test_dir = scope_add(TestDir()) + source_url = osp.join(test_dir, 'source') + source_dataset = Dataset.from_iterable([ + DatasetItem(0, image=np.ones((2, 3, 3)), + annotations=[ Bbox(1, 2, 3, 4, label=0) ]), + DatasetItem(1, subset='s', image=np.zeros((10, 20, 3)), + annotations=[ Bbox(1, 2, 3, 4, label=1) ]), + ], categories=['a', 'b']) + source_dataset.save(source_url, save_images=True) - # remove the cached object so that it couldn't be matched - project.remove_cache_obj( - project.working_tree.build_targets['s1'].root.hash) + project = scope_add(Project.init(osp.join(test_dir, 'proj'))) + project.import_source('s1', url=source_url, format=DEFAULT_FORMAT) + project.commit("A commit") + + # remove the cached object so that it couldn't be matched + project.remove_cache_obj( + project.working_tree.build_targets['s1'].root.hash) - # modify local source data - with open(osp.join(project.source_data_dir('s1'), 'extra.txt'), - 'w') as f: - f.write('text\n') + # modify local source data + with open(osp.join(project.source_data_dir('s1'), 'extra.txt'), + 'w') as f: + f.write('text\n') - with self.assertRaises(ForeignChangesError): - project.working_tree.make_dataset('s1') + with self.assertRaises(ForeignChangesError): + project.working_tree.make_dataset('s1') @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped def test_can_read_working_copy_of_source(self): - with TestDir() as test_dir: - source_url = osp.join(test_dir, 'source') - source_dataset = Dataset.from_iterable([ - DatasetItem(0, image=np.ones((2, 3, 3)), - annotations=[ Bbox(1, 2, 3, 4, label=0) ]), - DatasetItem(1, subset='s', image=np.ones((1, 2, 3)), - annotations=[ Bbox(1, 2, 3, 4, label=1) ]), - ], categories=['a', 'b']) - source_dataset.save(source_url, save_images=True) - - project = Project.init(osp.join(test_dir, 'proj')) - project.import_source('s1', url=source_url, format=DEFAULT_FORMAT) + test_dir = scope_add(TestDir()) + source_url = osp.join(test_dir, 'source') + source_dataset = Dataset.from_iterable([ + DatasetItem(0, image=np.ones((2, 3, 3)), + annotations=[ Bbox(1, 2, 3, 4, label=0) ]), + DatasetItem(1, subset='s', image=np.ones((1, 2, 3)), + annotations=[ Bbox(1, 2, 3, 4, label=1) ]), + ], categories=['a', 'b']) + source_dataset.save(source_url, save_images=True) + + project = scope_add(Project.init(osp.join(test_dir, 'proj'))) + project.import_source('s1', url=source_url, format=DEFAULT_FORMAT) - read_dataset = project.working_tree.make_dataset('s1') + read_dataset = project.working_tree.make_dataset('s1') - compare_datasets(self, source_dataset, read_dataset) - compare_dirs(self, source_url, project.source_data_dir('s1')) + compare_datasets(self, source_dataset, read_dataset) + compare_dirs(self, source_url, project.source_data_dir('s1')) @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped def test_can_read_current_revision_of_source(self): - with TestDir() as test_dir: - source_url = osp.join(test_dir, 'source') - source_dataset = Dataset.from_iterable([ - DatasetItem(0, image=np.ones((2, 3, 3)), - annotations=[ Bbox(1, 2, 3, 4, label=0) ]), - DatasetItem(1, subset='s', image=np.ones((1, 2, 3)), - annotations=[ Bbox(1, 2, 3, 4, label=1) ]), - ], categories=['a', 'b']) - source_dataset.save(source_url, save_images=True) - - project = Project.init(osp.join(test_dir, 'proj')) - project.import_source('s1', url=source_url, format=DEFAULT_FORMAT) - project.commit("A commit") + test_dir = scope_add(TestDir()) + source_url = osp.join(test_dir, 'source') + source_dataset = Dataset.from_iterable([ + DatasetItem(0, image=np.ones((2, 3, 3)), + annotations=[ Bbox(1, 2, 3, 4, label=0) ]), + DatasetItem(1, subset='s', image=np.ones((1, 2, 3)), + annotations=[ Bbox(1, 2, 3, 4, label=1) ]), + ], categories=['a', 'b']) + source_dataset.save(source_url, save_images=True) + + project = scope_add(Project.init(osp.join(test_dir, 'proj'))) + project.import_source('s1', url=source_url, format=DEFAULT_FORMAT) + project.commit("A commit") - shutil.rmtree(project.source_data_dir('s1')) + shutil.rmtree(project.source_data_dir('s1')) - read_dataset = project.head.make_dataset('s1') + read_dataset = project.head.make_dataset('s1') - compare_datasets(self, source_dataset, read_dataset) - self.assertFalse(osp.isdir(project.source_data_dir('s1'))) - compare_dirs(self, source_url, project.head.source_data_dir('s1')) + compare_datasets(self, source_dataset, read_dataset) + self.assertFalse(osp.isdir(project.source_data_dir('s1'))) + compare_dirs(self, source_url, project.head.source_data_dir('s1')) @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped def test_can_make_dataset_from_project(self): - with TestDir() as test_dir: - source_url = osp.join(test_dir, 'test_repo') - source_dataset = Dataset.from_iterable([ - DatasetItem(1, annotations=[Label(0)]), - ], categories=['a', 'b']) - source_dataset.save(source_url) - - project = Project.init(osp.join(test_dir, 'proj')) - project.import_source('s1', url=source_url, format=DEFAULT_FORMAT) + test_dir = scope_add(TestDir()) + source_url = osp.join(test_dir, 'test_repo') + source_dataset = Dataset.from_iterable([ + DatasetItem(1, annotations=[Label(0)]), + ], categories=['a', 'b']) + source_dataset.save(source_url) - read_dataset = project.working_tree.make_dataset() + project = scope_add(Project.init(osp.join(test_dir, 'proj'))) + project.import_source('s1', url=source_url, format=DEFAULT_FORMAT) - compare_datasets(self, source_dataset, read_dataset) + read_dataset = project.working_tree.make_dataset() + + compare_datasets(self, source_dataset, read_dataset) @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped def test_can_make_dataset_from_source(self): - with TestDir() as test_dir: - source_url = osp.join(test_dir, 'test_repo') - dataset = Dataset.from_iterable([ - DatasetItem(1, annotations=[Label(0)]), - ], categories=['a', 'b']) - dataset.save(source_url) - - project = Project.init(osp.join(test_dir, 'proj')) - project.import_source('s1', url=source_url, format=DEFAULT_FORMAT) + test_dir = scope_add(TestDir()) + source_url = osp.join(test_dir, 'test_repo') + dataset = Dataset.from_iterable([ + DatasetItem(1, annotations=[Label(0)]), + ], categories=['a', 'b']) + dataset.save(source_url) - built_dataset = project.working_tree.make_dataset('s1') + project = scope_add(Project.init(osp.join(test_dir, 'proj'))) + project.import_source('s1', url=source_url, format=DEFAULT_FORMAT) - compare_datasets(self, dataset, built_dataset) - self.assertEqual(DEFAULT_FORMAT, built_dataset.format) - self.assertEqual(project.source_data_dir('s1'), - built_dataset.data_path) + built_dataset = project.working_tree.make_dataset('s1') + + compare_datasets(self, dataset, built_dataset) + self.assertEqual(DEFAULT_FORMAT, built_dataset.format) + self.assertEqual(project.source_data_dir('s1'), + built_dataset.data_path) @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped def test_can_add_filter_stage(self): - with TestDir() as test_dir: - source_url = osp.join(test_dir, 'test_repo') - dataset = Dataset.from_iterable([ - DatasetItem(1, annotations=[Label(0)]), - DatasetItem(2, annotations=[Label(1)]), - ], categories=['a', 'b']) - dataset.save(source_url) - - project = Project.init(osp.join(test_dir, 'proj')) - project.import_source('s1', url=source_url, format=DEFAULT_FORMAT) + test_dir = scope_add(TestDir()) + source_url = osp.join(test_dir, 'test_repo') + dataset = Dataset.from_iterable([ + DatasetItem(1, annotations=[Label(0)]), + DatasetItem(2, annotations=[Label(1)]), + ], categories=['a', 'b']) + dataset.save(source_url) - stage = project.working_tree.build_targets.add_filter_stage('s1', - '/item/annotation[label="b"]' - ) + project = scope_add(Project.init(osp.join(test_dir, 'proj'))) + project.import_source('s1', url=source_url, format=DEFAULT_FORMAT) - self.assertTrue(stage in project.working_tree.build_targets) - resulting_dataset = project.working_tree.make_dataset('s1') - compare_datasets(self, Dataset.from_iterable([ - DatasetItem(2, annotations=[Label(1)]), - ], categories=['a', 'b']), resulting_dataset) + stage = project.working_tree.build_targets.add_filter_stage('s1', + '/item/annotation[label="b"]' + ) + + self.assertTrue(stage in project.working_tree.build_targets) + resulting_dataset = project.working_tree.make_dataset('s1') + compare_datasets(self, Dataset.from_iterable([ + DatasetItem(2, annotations=[Label(1)]), + ], categories=['a', 'b']), resulting_dataset) @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped def test_can_add_convert_stage(self): - with TestDir() as test_dir: - source_url = osp.join(test_dir, 'test_repo') - dataset = Dataset.from_iterable([ - DatasetItem(1, annotations=[Label(0)]), - DatasetItem(2, annotations=[Label(1)]), - ], categories=['a', 'b']) - dataset.save(source_url) - - project = Project.init(osp.join(test_dir, 'proj')) - project.import_source('s1', url=source_url, format=DEFAULT_FORMAT) + test_dir = scope_add(TestDir()) + source_url = osp.join(test_dir, 'test_repo') + dataset = Dataset.from_iterable([ + DatasetItem(1, annotations=[Label(0)]), + DatasetItem(2, annotations=[Label(1)]), + ], categories=['a', 'b']) + dataset.save(source_url) - stage = project.working_tree.build_targets.add_convert_stage('s1', - DEFAULT_FORMAT) + project = scope_add(Project.init(osp.join(test_dir, 'proj'))) + project.import_source('s1', url=source_url, format=DEFAULT_FORMAT) - self.assertTrue(stage in project.working_tree.build_targets) + stage = project.working_tree.build_targets.add_convert_stage('s1', + DEFAULT_FORMAT) + + self.assertTrue(stage in project.working_tree.build_targets) @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped def test_can_add_transform_stage(self): class TestTransform(ItemTransform): def __init__(self, extractor, p1=None, p2=None): @@ -524,309 +554,322 @@ def transform_item(self, item): return self.wrap_item(item, attributes={'p1': self.p1, 'p2': self.p2}) - with TestDir() as test_dir: - source_url = osp.join(test_dir, 'test_repo') - dataset = Dataset.from_iterable([ - DatasetItem(1, annotations=[Label(0)]), - DatasetItem(2, annotations=[Label(1)]), - ], categories=['a', 'b']) - dataset.save(source_url) + test_dir = scope_add(TestDir()) + source_url = osp.join(test_dir, 'test_repo') + dataset = Dataset.from_iterable([ + DatasetItem(1, annotations=[Label(0)]), + DatasetItem(2, annotations=[Label(1)]), + ], categories=['a', 'b']) + dataset.save(source_url) - project = Project.init(osp.join(test_dir, 'proj')) - project.import_source('s1', url=source_url, format=DEFAULT_FORMAT) - project.working_tree.env.transforms.register('tr', TestTransform) + project = scope_add(Project.init(osp.join(test_dir, 'proj'))) + project.import_source('s1', url=source_url, format=DEFAULT_FORMAT) + project.working_tree.env.transforms.register('tr', TestTransform) - stage = project.working_tree.build_targets.add_transform_stage('s1', - 'tr', params={'p1': 5, 'p2': ['1', 2, 3.5]} - ) + stage = project.working_tree.build_targets.add_transform_stage('s1', + 'tr', params={'p1': 5, 'p2': ['1', 2, 3.5]} + ) - self.assertTrue(stage in project.working_tree.build_targets) - resulting_dataset = project.working_tree.make_dataset('s1') - compare_datasets(self, Dataset.from_iterable([ - DatasetItem(1, annotations=[Label(0)], - attributes={'p1': 5, 'p2': ['1', 2, 3.5]}), - DatasetItem(2, annotations=[Label(1)], - attributes={'p1': 5, 'p2': ['1', 2, 3.5]}), - ], categories=['a', 'b']), resulting_dataset) + self.assertTrue(stage in project.working_tree.build_targets) + resulting_dataset = project.working_tree.make_dataset('s1') + compare_datasets(self, Dataset.from_iterable([ + DatasetItem(1, annotations=[Label(0)], + attributes={'p1': 5, 'p2': ['1', 2, 3.5]}), + DatasetItem(2, annotations=[Label(1)], + attributes={'p1': 5, 'p2': ['1', 2, 3.5]}), + ], categories=['a', 'b']), resulting_dataset) @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped def test_can_make_dataset_from_stage(self): - with TestDir() as test_dir: - source_url = osp.join(test_dir, 'test_repo') - dataset = Dataset.from_iterable([ - DatasetItem(1, annotations=[Label(0)]), - DatasetItem(2, annotations=[Label(1)]), - ], categories=['a', 'b']) - dataset.save(source_url) - - project = Project.init(osp.join(test_dir, 'proj')) - project.import_source('s1', url=source_url, format=DEFAULT_FORMAT) - stage = project.working_tree.build_targets.add_filter_stage('s1', - '/item/annotation[label="b"]') + test_dir = scope_add(TestDir()) + source_url = osp.join(test_dir, 'test_repo') + dataset = Dataset.from_iterable([ + DatasetItem(1, annotations=[Label(0)]), + DatasetItem(2, annotations=[Label(1)]), + ], categories=['a', 'b']) + dataset.save(source_url) - built_dataset = project.working_tree.make_dataset(stage) + project = scope_add(Project.init(osp.join(test_dir, 'proj'))) + project.import_source('s1', url=source_url, format=DEFAULT_FORMAT) + stage = project.working_tree.build_targets.add_filter_stage('s1', + '/item/annotation[label="b"]') - expected_dataset = Dataset.from_iterable([ - DatasetItem(2, annotations=[Label(1)]), - ], categories=['a', 'b']) - compare_datasets(self, expected_dataset, built_dataset) + built_dataset = project.working_tree.make_dataset(stage) + + expected_dataset = Dataset.from_iterable([ + DatasetItem(2, annotations=[Label(1)]), + ], categories=['a', 'b']) + compare_datasets(self, expected_dataset, built_dataset) @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped def test_can_commit(self): - with TestDir() as test_dir: - project = Project.init(test_dir) + test_dir = scope_add(TestDir()) + project = scope_add(Project.init(test_dir)) - commit_hash = project.commit("First commit", allow_empty=True) + commit_hash = project.commit("First commit", allow_empty=True) - self.assertTrue(project.is_ref(commit_hash)) - self.assertEqual(len(project.history()), 2) - self.assertEqual(project.history()[0], - (commit_hash, "First commit")) + self.assertTrue(project.is_ref(commit_hash)) + self.assertEqual(len(project.history()), 2) + self.assertEqual(project.history()[0], + (commit_hash, "First commit")) @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped def test_cant_commit_empty(self): - with TestDir() as test_dir: - project = Project.init(test_dir) + test_dir = scope_add(TestDir()) + project = scope_add(Project.init(test_dir)) - with self.assertRaises(EmptyCommitError): - project.commit("First commit") + with self.assertRaises(EmptyCommitError): + project.commit("First commit") @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped def test_can_commit_patch(self): - with TestDir() as test_dir: - source_url = osp.join(test_dir, 'test_source.txt') - os.makedirs(osp.dirname(source_url), exist_ok=True) - with open(source_url, 'w') as f: - f.write('hello') - - project = Project.init(osp.join(test_dir, 'proj')) - project.import_source('s1', source_url, format=DEFAULT_FORMAT) - project.commit("First commit") + test_dir = scope_add(TestDir()) + source_url = osp.join(test_dir, 'test_source.txt') + os.makedirs(osp.dirname(source_url), exist_ok=True) + with open(source_url, 'w') as f: + f.write('hello') + + project = scope_add(Project.init(osp.join(test_dir, 'proj'))) + project.import_source('s1', source_url, format=DEFAULT_FORMAT) + project.commit("First commit") - source_path = osp.join( - project.source_data_dir('s1'), - osp.basename(source_url)) - with open(source_path, 'w') as f: - f.write('world') + source_path = osp.join( + project.source_data_dir('s1'), + osp.basename(source_url)) + with open(source_path, 'w') as f: + f.write('world') - commit_hash = project.commit("Second commit", allow_foreign=True) + commit_hash = project.commit("Second commit", allow_foreign=True) - self.assertTrue(project.is_ref(commit_hash)) - self.assertNotEqual( - project.get_rev('HEAD~1').build_targets['s1'].head.hash, - project.working_tree.build_targets['s1'].head.hash) - self.assertTrue(project.is_obj_cached( - project.working_tree.build_targets['s1'].head.hash)) + self.assertTrue(project.is_ref(commit_hash)) + self.assertNotEqual( + project.get_rev('HEAD~1').build_targets['s1'].head.hash, + project.working_tree.build_targets['s1'].head.hash) + self.assertTrue(project.is_obj_cached( + project.working_tree.build_targets['s1'].head.hash)) @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped def test_cant_commit_foreign_changes(self): - with TestDir() as test_dir: - source_url = osp.join(test_dir, 'test_source.txt') - os.makedirs(osp.dirname(source_url), exist_ok=True) - with open(source_url, 'w') as f: - f.write('hello') - - project = Project.init(osp.join(test_dir, 'proj')) - project.import_source('s1', source_url, format=DEFAULT_FORMAT) - project.commit("First commit") + test_dir = scope_add(TestDir()) + source_url = osp.join(test_dir, 'test_source.txt') + os.makedirs(osp.dirname(source_url), exist_ok=True) + with open(source_url, 'w') as f: + f.write('hello') + + project = scope_add(Project.init(osp.join(test_dir, 'proj'))) + project.import_source('s1', source_url, format=DEFAULT_FORMAT) + project.commit("First commit") - source_path = osp.join( - project.source_data_dir('s1'), - osp.basename(source_url)) - with open(source_path, 'w') as f: - f.write('world') + source_path = osp.join( + project.source_data_dir('s1'), + osp.basename(source_url)) + with open(source_path, 'w') as f: + f.write('world') - with self.assertRaises(ForeignChangesError): - project.commit("Second commit") + with self.assertRaises(ForeignChangesError): + project.commit("Second commit") @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped def test_can_checkout_revision(self): - with TestDir() as test_dir: - source_url = osp.join(test_dir, 'test_source.txt') - os.makedirs(osp.dirname(source_url), exist_ok=True) - with open(source_url, 'w') as f: - f.write('hello') - - project = Project.init(osp.join(test_dir, 'proj')) - project.import_source('s1', source_url, format=DEFAULT_FORMAT) - project.commit("First commit") + test_dir = scope_add(TestDir()) + source_url = osp.join(test_dir, 'test_source.txt') + os.makedirs(osp.dirname(source_url), exist_ok=True) + with open(source_url, 'w') as f: + f.write('hello') - source_path = osp.join( - project.source_data_dir('s1'), - osp.basename(source_url)) - with open(source_path, 'w') as f: - f.write('world') - project.commit("Second commit", allow_foreign=True) + project = scope_add(Project.init(osp.join(test_dir, 'proj'))) + project.import_source('s1', source_url, format=DEFAULT_FORMAT) + project.commit("First commit") - project.checkout('HEAD~1') + source_path = osp.join( + project.source_data_dir('s1'), + osp.basename(source_url)) + with open(source_path, 'w') as f: + f.write('world') + project.commit("Second commit", allow_foreign=True) + + project.checkout('HEAD~1') - compare_dirs(self, source_url, project.source_data_dir('s1')) - with open(osp.join(test_dir, 'proj', '.gitignore')) as f: - self.assertTrue('s1' in [line.strip() for line in f]) + compare_dirs(self, source_url, project.source_data_dir('s1')) + with open(osp.join(test_dir, 'proj', '.gitignore')) as f: + self.assertTrue('s1' in [line.strip() for line in f]) @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped def test_can_checkout_sources(self): - with TestDir() as test_dir: - source_url = osp.join(test_dir, 'test_repo') - dataset = Dataset.from_iterable([ - DatasetItem(1, annotations=[Label(0)]), - DatasetItem(2, annotations=[Label(1)]), - ], categories=['a', 'b']) - dataset.save(source_url) - - project = Project.init(osp.join(test_dir, 'proj')) - project.import_source('s1', url=source_url, format=DEFAULT_FORMAT) - project.import_source('s2', url=source_url, format=DEFAULT_FORMAT) - project.commit("Commit 1") - project.remove_source('s1', keep_data=False) # remove s1 from tree - shutil.rmtree(project.source_data_dir('s2')) # modify s2 "manually" + test_dir = scope_add(TestDir()) + source_url = osp.join(test_dir, 'test_repo') + dataset = Dataset.from_iterable([ + DatasetItem(1, annotations=[Label(0)]), + DatasetItem(2, annotations=[Label(1)]), + ], categories=['a', 'b']) + dataset.save(source_url) + + project = scope_add(Project.init(osp.join(test_dir, 'proj'))) + project.import_source('s1', url=source_url, format=DEFAULT_FORMAT) + project.import_source('s2', url=source_url, format=DEFAULT_FORMAT) + project.commit("Commit 1") + project.remove_source('s1', keep_data=False) # remove s1 from tree + shutil.rmtree(project.source_data_dir('s2')) # modify s2 "manually" - project.checkout(sources=['s1', 's2']) + project.checkout(sources=['s1', 's2']) - compare_dirs(self, source_url, project.source_data_dir('s1')) - compare_dirs(self, source_url, project.source_data_dir('s2')) - with open(osp.join(test_dir, 'proj', '.gitignore')) as f: - lines = [line.strip() for line in f] - self.assertTrue('s1' in lines) - self.assertTrue('s2' in lines) + compare_dirs(self, source_url, project.source_data_dir('s1')) + compare_dirs(self, source_url, project.source_data_dir('s2')) + with open(osp.join(test_dir, 'proj', '.gitignore')) as f: + lines = [line.strip() for line in f] + self.assertTrue('s1' in lines) + self.assertTrue('s2' in lines) @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped def test_can_checkout_sources_from_revision(self): - with TestDir() as test_dir: - source_url = osp.join(test_dir, 'test_repo') - dataset = Dataset.from_iterable([ - DatasetItem(1, annotations=[Label(0)]), - DatasetItem(2, annotations=[Label(1)]), - ], categories=['a', 'b']) - dataset.save(source_url) - - project = Project.init(osp.join(test_dir, 'proj')) - project.import_source('s1', url=source_url, format=DEFAULT_FORMAT) - project.commit("Commit 1") - project.remove_source('s1', keep_data=False) - project.commit("Commit 2") + test_dir = scope_add(TestDir()) + source_url = osp.join(test_dir, 'test_repo') + dataset = Dataset.from_iterable([ + DatasetItem(1, annotations=[Label(0)]), + DatasetItem(2, annotations=[Label(1)]), + ], categories=['a', 'b']) + dataset.save(source_url) + + project = scope_add(Project.init(osp.join(test_dir, 'proj'))) + project.import_source('s1', url=source_url, format=DEFAULT_FORMAT) + project.commit("Commit 1") + project.remove_source('s1', keep_data=False) + project.commit("Commit 2") - project.checkout(rev='HEAD~1', sources=['s1']) + project.checkout(rev='HEAD~1', sources=['s1']) - compare_dirs(self, source_url, project.source_data_dir('s1')) + compare_dirs(self, source_url, project.source_data_dir('s1')) @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped def test_can_check_status(self): - with TestDir() as test_dir: - source_url = osp.join(test_dir, 'test_repo') - dataset = Dataset.from_iterable([ - DatasetItem(1, annotations=[Label(0)]), - DatasetItem(2, annotations=[Label(1)]), - ], categories=['a', 'b']) - dataset.save(source_url) - - project = Project.init(osp.join(test_dir, 'proj')) - project.import_source('s1', url=source_url, format=DEFAULT_FORMAT) - project.import_source('s2', url=source_url, format=DEFAULT_FORMAT) - project.import_source('s3', url=source_url, format=DEFAULT_FORMAT) - project.import_source('s4', url=source_url, format=DEFAULT_FORMAT) - project.import_source('s5', url=source_url, format=DEFAULT_FORMAT) - project.commit("Commit 1") + test_dir = scope_add(TestDir()) + source_url = osp.join(test_dir, 'test_repo') + dataset = Dataset.from_iterable([ + DatasetItem(1, annotations=[Label(0)]), + DatasetItem(2, annotations=[Label(1)]), + ], categories=['a', 'b']) + dataset.save(source_url) - project.remove_source('s2') - project.import_source('s6', url=source_url, format=DEFAULT_FORMAT) + project = scope_add(Project.init(osp.join(test_dir, 'proj'))) + project.import_source('s1', url=source_url, format=DEFAULT_FORMAT) + project.import_source('s2', url=source_url, format=DEFAULT_FORMAT) + project.import_source('s3', url=source_url, format=DEFAULT_FORMAT) + project.import_source('s4', url=source_url, format=DEFAULT_FORMAT) + project.import_source('s5', url=source_url, format=DEFAULT_FORMAT) + project.commit("Commit 1") - shutil.rmtree(project.source_data_dir('s3')) + project.remove_source('s2') + project.import_source('s6', url=source_url, format=DEFAULT_FORMAT) - project.working_tree.build_targets \ - .add_transform_stage('s4', 'reindex') - project.working_tree.make_dataset('s4').save() - project.refresh_source_hash('s4') + shutil.rmtree(project.source_data_dir('s3')) - s5_dir = osp.join(project.source_data_dir('s5')) - with open(osp.join(s5_dir, 'annotations', 't.txt'), 'w') as f: - f.write("hello") + project.working_tree.build_targets \ + .add_transform_stage('s4', 'reindex') + project.working_tree.make_dataset('s4').save() + project.refresh_source_hash('s4') - status = project.status() - self.assertEqual({ - 's2': DiffStatus.removed, - 's3': DiffStatus.missing, - 's4': DiffStatus.modified, - 's5': DiffStatus.foreign_modified, - 's6': DiffStatus.added, - }, status) + s5_dir = osp.join(project.source_data_dir('s5')) + with open(osp.join(s5_dir, 'annotations', 't.txt'), 'w') as f: + f.write("hello") + + status = project.status() + self.assertEqual({ + 's2': DiffStatus.removed, + 's3': DiffStatus.missing, + 's4': DiffStatus.modified, + 's5': DiffStatus.foreign_modified, + 's6': DiffStatus.added, + }, status) @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped def test_can_compare_revisions(self): - with TestDir() as test_dir: - source_url = osp.join(test_dir, 'test_repo') - dataset = Dataset.from_iterable([ - DatasetItem(1, annotations=[Label(0)]), - DatasetItem(2, annotations=[Label(1)]), - ], categories=['a', 'b']) - dataset.save(source_url) - - project = Project.init(osp.join(test_dir, 'proj')) - project.import_source('s1', url=source_url, format=DEFAULT_FORMAT) - project.import_source('s2', url=source_url, format=DEFAULT_FORMAT) - rev1 = project.commit("Commit 1") + test_dir = scope_add(TestDir()) + source_url = osp.join(test_dir, 'test_repo') + dataset = Dataset.from_iterable([ + DatasetItem(1, annotations=[Label(0)]), + DatasetItem(2, annotations=[Label(1)]), + ], categories=['a', 'b']) + dataset.save(source_url) - project.remove_source('s2') - project.import_source('s3', url=source_url, format=DEFAULT_FORMAT) - rev2 = project.commit("Commit 2") + project = scope_add(Project.init(osp.join(test_dir, 'proj'))) + project.import_source('s1', url=source_url, format=DEFAULT_FORMAT) + project.import_source('s2', url=source_url, format=DEFAULT_FORMAT) + rev1 = project.commit("Commit 1") - diff = project.diff(rev1, rev2) - self.assertEqual(diff, - { 's2': DiffStatus.removed, 's3': DiffStatus.added }) + project.remove_source('s2') + project.import_source('s3', url=source_url, format=DEFAULT_FORMAT) + rev2 = project.commit("Commit 2") + + diff = project.diff(rev1, rev2) + self.assertEqual(diff, + { 's2': DiffStatus.removed, 's3': DiffStatus.added }) @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped def test_can_restore_revision(self): - with TestDir() as test_dir: - source_url = osp.join(test_dir, 'test_repo') - dataset = Dataset.from_iterable([ - DatasetItem(1, annotations=[Label(0)]), - DatasetItem(2, annotations=[Label(1)]), - ], categories=['a', 'b']) - dataset.save(source_url) - - project = Project.init(osp.join(test_dir, 'proj')) - project.import_source('s1', url=source_url, format=DEFAULT_FORMAT) - rev1 = project.commit("Commit 1") + test_dir = scope_add(TestDir()) + source_url = osp.join(test_dir, 'test_repo') + dataset = Dataset.from_iterable([ + DatasetItem(1, annotations=[Label(0)]), + DatasetItem(2, annotations=[Label(1)]), + ], categories=['a', 'b']) + dataset.save(source_url) + + project = scope_add(Project.init(osp.join(test_dir, 'proj'))) + project.import_source('s1', url=source_url, format=DEFAULT_FORMAT) + rev1 = project.commit("Commit 1") - project.remove_cache_obj(rev1) + project.remove_cache_obj(rev1) - self.assertFalse(project.is_rev_cached(rev1)) + self.assertFalse(project.is_rev_cached(rev1)) - head_dataset = project.head.make_dataset() + head_dataset = project.head.make_dataset() - self.assertTrue(project.is_rev_cached(rev1)) - compare_datasets(self, dataset, head_dataset) + self.assertTrue(project.is_rev_cached(rev1)) + compare_datasets(self, dataset, head_dataset) @mark_requirement(Requirements.DATUM_BUG_404) + @scoped def test_can_add_plugin(self): - with TestDir() as test_dir: - Project.init(test_dir) - - plugin_dir = osp.join(test_dir, '.datumaro', 'plugins') - os.makedirs(plugin_dir) - with open(osp.join(plugin_dir, '__init__.py'), 'w') as f: - f.write(textwrap.dedent(""" - from datumaro.components.extractor import (SourceExtractor, - DatasetItem) - - class MyExtractor(SourceExtractor): - def __iter__(self): - yield from [ - DatasetItem('1'), - DatasetItem('2'), - ] - """)) - - project = Project(test_dir) - project.import_source('src', url='', format='my') - - expected = Dataset.from_iterable([ - DatasetItem('1'), - DatasetItem('2') - ]) - compare_datasets(self, expected, project.working_tree.make_dataset()) + test_dir = scope_add(TestDir()) + scope_add(Project.init(test_dir)).close() + + plugin_dir = osp.join(test_dir, '.datumaro', 'plugins') + os.makedirs(plugin_dir) + with open(osp.join(plugin_dir, '__init__.py'), 'w') as f: + f.write(textwrap.dedent(""" + from datumaro.components.extractor import (SourceExtractor, + DatasetItem) + + class MyExtractor(SourceExtractor): + def __iter__(self): + yield from [ + DatasetItem('1'), + DatasetItem('2'), + ] + """)) + + project = scope_add(Project(test_dir)) + project.import_source('src', url='', format='my') + + expected = Dataset.from_iterable([ + DatasetItem('1'), + DatasetItem('2') + ]) + compare_datasets(self, expected, project.working_tree.make_dataset()) @mark_requirement(Requirements.DATUM_BUG_402) + @scoped def test_can_transform_by_name(self): class CustomExtractor(Extractor): def __iter__(self): @@ -835,32 +878,33 @@ def __iter__(self): DatasetItem('b'), ]) - with TestDir() as test_dir: - extractor_name = 'ext1' - project = Project.init(test_dir) - project.env.extractors.register(extractor_name, CustomExtractor) - project.import_source('src1', url='', format=extractor_name) - dataset = project.working_tree.make_dataset() + test_dir = scope_add(TestDir()) + extractor_name = 'ext1' + project = scope_add(Project.init(test_dir)) + project.env.extractors.register(extractor_name, CustomExtractor) + project.import_source('src1', url='', format=extractor_name) + dataset = project.working_tree.make_dataset() - dataset = dataset.transform('reindex') + dataset = dataset.transform('reindex') - expected = Dataset.from_iterable([ - DatasetItem(1), - DatasetItem(2), - ]) - compare_datasets(self, expected, dataset) + expected = Dataset.from_iterable([ + DatasetItem(1), + DatasetItem(2), + ]) + compare_datasets(self, expected, dataset) @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped def test_cant_modify_readonly(self): - with TestDir() as test_dir: - dataset_url = osp.join(test_dir, 'dataset') - Dataset.from_iterable([ - DatasetItem('a'), - DatasetItem('b'), - ]).save(dataset_url) - - proj_dir = osp.join(test_dir, 'proj') - project = Project.init(proj_dir) + test_dir = scope_add(TestDir()) + dataset_url = osp.join(test_dir, 'dataset') + Dataset.from_iterable([ + DatasetItem('a'), + DatasetItem('b'), + ]).save(dataset_url) + + proj_dir = osp.join(test_dir, 'proj') + with Project.init(proj_dir) as project: project.import_source('source1', url=dataset_url, format=DEFAULT_FORMAT) project.commit('first commit') @@ -871,46 +915,47 @@ def test_cant_modify_readonly(self): project.remove_cache_obj( project.working_tree.sources['source1'].hash) - project = Project(proj_dir, readonly=True) + project = scope_add(Project(proj_dir, readonly=True)) - self.assertTrue(project.readonly) + self.assertTrue(project.readonly) - with self.subTest("add source"), self.assertRaises(ReadonlyProjectError): - project.import_source('src1', url='', format=DEFAULT_FORMAT) + with self.subTest("add source"), self.assertRaises(ReadonlyProjectError): + project.import_source('src1', url='', format=DEFAULT_FORMAT) - with self.subTest("remove source"), self.assertRaises(ReadonlyProjectError): - project.remove_source('src1') + with self.subTest("remove source"), self.assertRaises(ReadonlyProjectError): + project.remove_source('src1') - with self.subTest("add model"), self.assertRaises(ReadonlyProjectError): - project.add_model('m1', launcher='x') + with self.subTest("add model"), self.assertRaises(ReadonlyProjectError): + project.add_model('m1', launcher='x') - with self.subTest("remove model"), self.assertRaises(ReadonlyProjectError): - project.remove_model('m1') + with self.subTest("remove model"), self.assertRaises(ReadonlyProjectError): + project.remove_model('m1') - with self.subTest("checkout"), self.assertRaises(ReadonlyProjectError): - project.checkout('HEAD') + with self.subTest("checkout"), self.assertRaises(ReadonlyProjectError): + project.checkout('HEAD') - with self.subTest("commit"), self.assertRaises(ReadonlyProjectError): - project.commit('third commit', allow_empty=True) + with self.subTest("commit"), self.assertRaises(ReadonlyProjectError): + project.commit('third commit', allow_empty=True) - # Can't re-download the source in a readonly project - with self.subTest("make_dataset"), self.assertRaises(MissingObjectError): - project.get_rev('HEAD').make_dataset() + # Can't re-download the source in a readonly project + with self.subTest("make_dataset"), self.assertRaises(MissingObjectError): + project.get_rev('HEAD').make_dataset() class BackwardCompatibilityTests_v0_1(TestCase): @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped def test_can_load_old_project(self): expected_dataset = Dataset.from_iterable([ DatasetItem(0, subset='train', annotations=[Label(0)]), DatasetItem(1, subset='test', annotations=[Label(1)]), ], categories=['a', 'b']) - with TestDir() as test_dir: - shutil.copytree(osp.join(osp.dirname(__file__), - 'assets', 'compat', 'v0.1', 'project'), - osp.join(test_dir, 'proj')) + test_dir = scope_add(TestDir()) + shutil.copytree(osp.join(osp.dirname(__file__), + 'assets', 'compat', 'v0.1', 'project'), + osp.join(test_dir, 'proj')) - project = Project(osp.join(test_dir, 'proj')) - loaded_dataset = project.working_tree.make_dataset() + project = scope_add(Project(osp.join(test_dir, 'proj'))) + loaded_dataset = project.working_tree.make_dataset() - compare_datasets(self, expected_dataset, loaded_dataset) + compare_datasets(self, expected_dataset, loaded_dataset) From e773b864b23795b4fd315b14e8bc8ebcac21ccbb Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Tue, 7 Sep 2021 14:04:29 +0300 Subject: [PATCH 15/19] Fix source dvcfile placement --- datumaro/components/project.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/datumaro/components/project.py b/datumaro/components/project.py index 1f1a8f1e9e..f431425adb 100644 --- a/datumaro/components/project.py +++ b/datumaro/components/project.py @@ -106,7 +106,10 @@ def _update_ignore_file(paths: Union[str, List[str]], repo_root: str, def _make_ignored_path(path): path = osp.join(repo_root, osp.normpath(path)) assert is_subpath(path, base=repo_root) - return osp.relpath(path, repo_root) + + # Prepend the '/' to match only direct childs. + # Otherwise the rule can be in any path part. + return '/' + osp.relpath(path, repo_root).replace('\\', '/') header = '# The file is autogenerated by Datumaro' @@ -1512,7 +1515,7 @@ def _init_vcs(self): self._git.ignore([ ProjectLayout.cache_dir, ], gitignore=osp.join(self._aux_dir, '.gitignore')) - self._git.ignore([]) + self._git.ignore([]) # create the file if not self._dvc.initialized: self._dvc.init() self._dvc.ignore([ @@ -1744,7 +1747,7 @@ def _source_dvcfile_path(self, name: str, if not root: root = osp.join(self._aux_dir, ProjectLayout.working_tree_dir) - return osp.join(root, TreeLayout.sources_dir, name + '.dvc') + return osp.join(root, TreeLayout.sources_dir, name, 'source.dvc') def _make_tmp_dir(self, suffix: Optional[str] = None): project_tmp_dir = osp.join(self._aux_dir, ProjectLayout.tmp_dir) From 83cc4fc5b0d53effd76eefefb523e33616259f5e Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Tue, 7 Sep 2021 16:24:45 +0300 Subject: [PATCH 16/19] Fix winodws issues --- datumaro/components/project.py | 30 ++++++++++++++++++++---------- datumaro/util/os_util.py | 12 ++++++++++-- tests/test_project.py | 18 +++++++++--------- 3 files changed, 39 insertions(+), 21 deletions(-) diff --git a/datumaro/components/project.py b/datumaro/components/project.py index f431425adb..e4656d65f9 100644 --- a/datumaro/components/project.py +++ b/datumaro/components/project.py @@ -38,7 +38,7 @@ from datumaro.util import find, parse_str_enum_value from datumaro.util.log_utils import catch_logs, logging_disabled from datumaro.util.os_util import ( - copytree, generate_next_name, is_subpath, make_file_name, rmtree, + copytree, generate_next_name, is_subpath, make_file_name, rmfile, rmtree, ) from datumaro.util.scope import on_error_do, scoped @@ -137,7 +137,8 @@ def _make_ignored_path(path): continue line_path = osp.join(repo_root, - osp.normpath(line.split('#', maxsplit=1)[0]).lstrip('/')) + osp.normpath(line.split('#', maxsplit=1)[0]) \ + .replace('\\', '/').lstrip('/')) if mode == IgnoreMode.append: if line_path in paths: @@ -1296,7 +1297,7 @@ def _copy_obj(src, dst, link=False): def remove_cache_obj(self, obj_hash: str): src = self.obj_path(obj_hash) if osp.isfile(src): - os.remove(src) + rmfile(src) return src += self.DIR_HASH_SUFFIX @@ -1308,9 +1309,9 @@ def remove_cache_obj(self, obj_hash: str): for entry in src_meta: entry_path = self.obj_path(entry['md5']) if osp.isfile(entry_path): - os.remove(entry_path) + rmfile(entry_path) - os.remove(src) + rmfile(src) class Tree: # can be: @@ -1460,7 +1461,7 @@ def _migrate_from_v1_to_v2(self) -> bool: new_tree_config.dump(osp.join(wtree_dir, TreeLayout.conf_file)) new_local_config.dump(osp.join(self._aux_dir, ProjectLayout.conf_file)) - os.remove(old_config_tmp) + rmfile(old_config_tmp) log.debug("Finished") @@ -2010,7 +2011,7 @@ def remove_source(self, name: str, force: bool = False, dvcfile = self._source_dvcfile_path(name) if osp.isfile(dvcfile): try: - os.remove(dvcfile) + rmfile(dvcfile) except Exception: if not force: raise @@ -2063,16 +2064,25 @@ def commit(self, message: str, no_cache: bool = False, wtree_dir = osp.join(self._aux_dir, ProjectLayout.working_tree_dir) self.working_tree.save() self._git.add(wtree_dir, base=wtree_dir) - self._git.add([ + + extra_files = [ osp.join(self._root_dir, '.dvc', '.gitignore'), osp.join(self._root_dir, '.dvc', 'config'), osp.join(self._root_dir, '.dvcignore'), osp.join(self._root_dir, '.gitignore'), osp.join(self._aux_dir, '.gitignore'), - ], base=self._root_dir) + ] + self._git.add(extra_files, base=self._root_dir) + head = self._git.commit(message) - copytree(wtree_dir, self.cache_path(head)) + rev_dir = self.cache_path(head) + copytree(wtree_dir, rev_dir) + for p in extra_files: + if osp.isfile(p): + dst_path = osp.join(rev_dir, osp.relpath(p, self._root_dir)) + os.makedirs(osp.dirname(dst_path), exist_ok=True) + shutil.copyfile(p, dst_path) self._head_tree = None diff --git a/datumaro/util/os_util.py b/datumaro/util/os_util.py index f191ad1824..0ec6b8c2f9 100644 --- a/datumaro/util/os_util.py +++ b/datumaro/util/os_util.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: MIT from contextlib import ( - ExitStack, contextmanager, redirect_stderr, redirect_stdout, + ExitStack, contextmanager, redirect_stderr, redirect_stdout, suppress, ) from io import StringIO from typing import Iterable, Optional @@ -22,7 +22,15 @@ # Use rmtree from GitPython to avoid the problem with removal of # readonly files on Windows, which Git uses extensively # It double checks if a file cannot be removed because of readonly flag - from git.util import rmfile, rmtree # pylint: disable=unused-import + import unittest + + from git.util import rmfile # pylint: disable=unused-import + from git.util import rmtree as _rmtree + + def rmtree(path): + with suppress(unittest.SkipTest): + _rmtree(path) + except ModuleNotFoundError: from os import remove as rmfile # pylint: disable=unused-import from shutil import rmtree as rmtree # pylint: disable=unused-import diff --git a/tests/test_project.py b/tests/test_project.py index 95806e35c5..c1ec07f93f 100644 --- a/tests/test_project.py +++ b/tests/test_project.py @@ -17,7 +17,7 @@ from datumaro.components.extractor import DatasetItem, Extractor, ItemTransform from datumaro.components.launcher import Launcher from datumaro.components.project import DiffStatus, Project -from datumaro.util.scope import scoped, scope_add +from datumaro.util.scope import scope_add, scoped from datumaro.util.test_utils import TestDir, compare_datasets, compare_dirs from .requirements import Requirements, mark_requirement @@ -127,7 +127,7 @@ def test_can_import_local_source(self): self.assertEqual('fmt', source.format) compare_dirs(self, source_base_url, project.source_data_dir('s1')) with open(osp.join(test_dir, 'proj', '.gitignore')) as f: - self.assertTrue('s1' in [line.strip() for line in f]) + self.assertTrue('/s1' in [line.strip() for line in f]) @mark_requirement(Requirements.DATUM_GENERAL_REQ) @scoped @@ -164,7 +164,7 @@ def test_can_import_local_source_with_relpath(self): require_images=True) with open(osp.join(test_dir, 'proj', '.gitignore')) as f: - self.assertTrue('s1' in [line.strip() for line in f]) + self.assertTrue('/s1' in [line.strip() for line in f]) @mark_requirement(Requirements.DATUM_GENERAL_REQ) @scoped @@ -254,7 +254,7 @@ def test_can_import_generated_source(self): self.assertEqual(added.format, origin.format) self.assertEqual(added.options, origin.options) with open(osp.join(test_dir, '.gitignore')) as f: - self.assertTrue(source_name in [line.strip() for line in f]) + self.assertTrue('/' + source_name in [line.strip() for line in f]) @mark_requirement(Requirements.DATUM_GENERAL_REQ) @scoped @@ -284,7 +284,7 @@ def test_can_remove_source_and_keep_data(self): self.assertFalse('s1' in project.working_tree.sources) compare_dirs(self, source_url, project.source_data_dir('s1')) with open(osp.join(test_dir, 'proj', '.gitignore')) as f: - self.assertFalse('s1' in [line.strip() for line in f]) + self.assertFalse('/s1' in [line.strip() for line in f]) @mark_requirement(Requirements.DATUM_GENERAL_REQ) @scoped @@ -303,7 +303,7 @@ def test_can_remove_source_and_wipe_data(self): self.assertFalse('s1' in project.working_tree.sources) self.assertFalse(osp.exists(project.source_data_dir('s1'))) with open(osp.join(test_dir, 'proj', '.gitignore')) as f: - self.assertFalse('s1' in [line.strip() for line in f]) + self.assertFalse('/s1' in [line.strip() for line in f]) @mark_requirement(Requirements.DATUM_GENERAL_REQ) @scoped @@ -698,7 +698,7 @@ def test_can_checkout_revision(self): compare_dirs(self, source_url, project.source_data_dir('s1')) with open(osp.join(test_dir, 'proj', '.gitignore')) as f: - self.assertTrue('s1' in [line.strip() for line in f]) + self.assertTrue('/s1' in [line.strip() for line in f]) @mark_requirement(Requirements.DATUM_GENERAL_REQ) @scoped @@ -724,8 +724,8 @@ def test_can_checkout_sources(self): compare_dirs(self, source_url, project.source_data_dir('s2')) with open(osp.join(test_dir, 'proj', '.gitignore')) as f: lines = [line.strip() for line in f] - self.assertTrue('s1' in lines) - self.assertTrue('s2' in lines) + self.assertTrue('/s1' in lines) + self.assertTrue('/s2' in lines) @mark_requirement(Requirements.DATUM_GENERAL_REQ) @scoped From ba5b78ed4c830bdbaecad3f1b1949a41ab681500 Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Tue, 7 Sep 2021 16:51:42 +0300 Subject: [PATCH 17/19] Update cli tests --- datumaro/util/os_util.py | 4 +++- tests/cli/test_merge.py | 9 +++++---- tests/cli/test_project.py | 14 ++++++++------ 3 files changed, 16 insertions(+), 11 deletions(-) diff --git a/datumaro/util/os_util.py b/datumaro/util/os_util.py index 0ec6b8c2f9..3594284f88 100644 --- a/datumaro/util/os_util.py +++ b/datumaro/util/os_util.py @@ -28,8 +28,10 @@ from git.util import rmtree as _rmtree def rmtree(path): - with suppress(unittest.SkipTest): + try: _rmtree(path) + except unittest.SkipTest as e: + raise AssertionError(f"Failed to remove '{path}'") from e except ModuleNotFoundError: from os import remove as rmfile # pylint: disable=unused-import diff --git a/tests/cli/test_merge.py b/tests/cli/test_merge.py index f4eb8dab04..5e5ca62538 100644 --- a/tests/cli/test_merge.py +++ b/tests/cli/test_merge.py @@ -3,9 +3,10 @@ import numpy as np -from datumaro.components.extractor import ( - AnnotationType, Bbox, DatasetItem, LabelCategories, MaskCategories, +from datumaro.components.annotation import ( + AnnotationType, Bbox, LabelCategories, MaskCategories, ) +from datumaro.components.extractor import DatasetItem from datumaro.components.project import Dataset, Project from datumaro.util.test_utils import TestDir, compare_datasets from datumaro.util.test_utils import run_datum as run @@ -58,8 +59,8 @@ def test_can_run_self_merge(self): dataset2.export(dataset2_url, 'voc', save_images=True) proj_dir = osp.join(test_dir, 'proj') - project = Project.init(proj_dir) - project.import_source('source', dataset2_url, 'voc') + with Project.init(proj_dir) as project: + project.import_source('source', dataset2_url, 'voc') result_dir = osp.join(test_dir, 'cmp_result') run(self, 'merge', dataset1_url + ':coco', '-o', result_dir, diff --git a/tests/cli/test_project.py b/tests/cli/test_project.py index 457a8bec03..357bedbecf 100644 --- a/tests/cli/test_project.py +++ b/tests/cli/test_project.py @@ -4,8 +4,9 @@ import numpy as np +from datumaro.components.annotation import Bbox, Label from datumaro.components.dataset import DEFAULT_FORMAT, Dataset -from datumaro.components.extractor import Bbox, DatasetItem, Label +from datumaro.components.extractor import DatasetItem from datumaro.components.project import Project from datumaro.util.test_utils import TestDir, compare_datasets from datumaro.util.test_utils import run_datum as run @@ -157,9 +158,10 @@ def test_can_chain_transforms_in_working_tree(self): run(self, 'transform', '-p', project_dir, '-t', 'remap_labels', '--', '-l', 'a:cat', '-l', 'b:dog') - built_dataset = Project(project_dir).working_tree.make_dataset() + with Project(project_dir) as project: + built_dataset = project.working_tree.make_dataset() - expected_dataset = Dataset.from_iterable([ - DatasetItem('qq', annotations=[Label(1)]), - ], categories=['cat', 'dog']) - compare_datasets(self, expected_dataset, built_dataset) + expected_dataset = Dataset.from_iterable([ + DatasetItem('qq', annotations=[Label(1)]), + ], categories=['cat', 'dog']) + compare_datasets(self, expected_dataset, built_dataset) From c4dd0a894fe7fb4d969e12e4c69e6056f14cf03a Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Tue, 7 Sep 2021 17:36:41 +0300 Subject: [PATCH 18/19] Suppress extra window --- datumaro/components/project.py | 31 ++++++++++++++++++------------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/datumaro/components/project.py b/datumaro/components/project.py index e4656d65f9..d5a881c5b7 100644 --- a/datumaro/components/project.py +++ b/datumaro/components/project.py @@ -40,7 +40,7 @@ from datumaro.util.os_util import ( copytree, generate_next_name, is_subpath, make_file_name, rmfile, rmtree, ) -from datumaro.util.scope import on_error_do, scoped +from datumaro.util.scope import on_error_do, scope_add, scoped class ProjectSourceDataset(IDataset): @@ -1114,6 +1114,7 @@ class DvcWrapper: def module(): try: import dvc + import dvc.env import dvc.main import dvc.repo return dvc @@ -1190,19 +1191,23 @@ def add(self, paths, dvc_path=None, no_commit=False, allow_external=False): self._exec(args) def _exec(self, args, hide_output=True, answer_on_input='y'): - args = ['--cd', self._project_dir, '-q'] + args + args = ['--cd', self._project_dir] + args - with ExitStack() as contexts: - contexts.callback(os.chdir, os.getcwd()) # restore cd after DVC + # Avoid calling an extra process. Improves call performance and + # removes an extra console window on Windows. + os.environ[self.module().env.DVC_NO_ANALYTICS] = '1' + + with ExitStack() as es: + es.callback(os.chdir, os.getcwd()) # restore cd after DVC if answer_on_input is not None: def _input(*args): return answer_on_input - contexts.enter_context(unittest.mock.patch( + es.enter_context(unittest.mock.patch( 'dvc.prompt.input', new=_input)) log.debug("Calling DVC main with args: %s", args) - logs = contexts.enter_context(catch_logs('dvc')) + logs = es.enter_context(catch_logs('dvc')) retcode = self.module().main.main(args) logs = logs.getvalue() @@ -1834,16 +1839,16 @@ def _get_source_hash(dvcfile): obj_hash = obj_hash[:-len(DvcWrapper.DIR_HASH_SUFFIX)] return obj_hash + @scoped def compute_source_hash(self, data_dir: str, dvcfile: Optional[str] = None, no_cache: bool = True, allow_external: bool = True) -> ObjectId: - with ExitStack() as es: - if not dvcfile: - tmp_dir = es.enter_context(self._make_tmp_dir()) - dvcfile = osp.join(tmp_dir, 'source.dvc') + if not dvcfile: + tmp_dir = scope_add(self._make_tmp_dir()) + dvcfile = osp.join(tmp_dir, 'source.dvc') - self._dvc.add(data_dir, dvc_path=dvcfile, no_commit=no_cache, - allow_external=allow_external) - obj_hash = self._get_source_hash(dvcfile) + self._dvc.add(data_dir, dvc_path=dvcfile, no_commit=no_cache, + allow_external=allow_external) + obj_hash = self._get_source_hash(dvcfile) return obj_hash def refresh_source_hash(self, source: str, From cb4f0ab7170b147864ece6731bfc22445b6f8d7a Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Tue, 7 Sep 2021 17:53:19 +0300 Subject: [PATCH 19/19] Dont hide windows file femoval issues --- datumaro/util/os_util.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/datumaro/util/os_util.py b/datumaro/util/os_util.py index 3594284f88..2c3ea42a0f 100644 --- a/datumaro/util/os_util.py +++ b/datumaro/util/os_util.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: MIT from contextlib import ( - ExitStack, contextmanager, redirect_stderr, redirect_stdout, suppress, + ExitStack, contextmanager, redirect_stderr, redirect_stdout, ) from io import StringIO from typing import Iterable, Optional @@ -22,16 +22,9 @@ # Use rmtree from GitPython to avoid the problem with removal of # readonly files on Windows, which Git uses extensively # It double checks if a file cannot be removed because of readonly flag - import unittest - - from git.util import rmfile # pylint: disable=unused-import - from git.util import rmtree as _rmtree - - def rmtree(path): - try: - _rmtree(path) - except unittest.SkipTest as e: - raise AssertionError(f"Failed to remove '{path}'") from e + from git.util import rmfile, rmtree # pylint: disable=unused-import + import git.util + git.util.HIDE_WINDOWS_KNOWN_ERRORS = False except ModuleNotFoundError: from os import remove as rmfile # pylint: disable=unused-import