Skip to content

Commit

Permalink
Fix: Dataset infos() can be broken if a transform not redefining info…
Browse files Browse the repository at this point in the history
…s() is stacked on the top (#1101)

- Ticket no. 115725
- Fix: Dataset infos() can break when a transform that does not redefine
infos() is stacked on top
- Enhance the StreamDatasetStorage transform tests added in #1077.
- Test `call_count` as well in the tests to validate stacked transforms.

Signed-off-by: Kim, Vinnam <vinnam.kim@intel.com>
Co-authored-by: Wonju Lee <wonju.lee@intel.com>
  • Loading branch information
vinnamkim and wonjuleee authored Jul 18, 2023
1 parent 8867601 commit 144489e
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 7 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Bug fixes
- Create cache dir under only writable filesystem
(<https://github.com/openvinotoolkit/datumaro/pull/1088>)
- Fix: Dataset infos() can break when a transform that does not redefine infos() is stacked on top
(<https://github.com/openvinotoolkit/datumaro/pull/1101>)

## 07/07/2023 - Release 1.4.0rc1
### New features
Expand Down
3 changes: 3 additions & 0 deletions src/datumaro/components/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ def __len__(self):
def media_type(self):
return self._extractor.media_type()

def infos(self):
    """Return the dataset infos reported by the wrapped extractor.

    The call is forwarded unchanged to ``self._extractor``, so a transform
    that does not define its own ``infos()`` still exposes the infos of
    whatever it wraps.
    """
    wrapped = self._extractor
    return wrapped.infos()


class ItemTransform(Transform):
def transform_item(self, item: DatasetItem) -> Optional[DatasetItem]:
Expand Down
48 changes: 48 additions & 0 deletions tests/unit/components/test_dataset_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,29 +75,35 @@ def test_item_transform(self, fxt_stream_extractor: MagicMock):
n_calls = 1

self._test_loop(fxt_stream_extractor, storage, n_calls)
assert fxt_stream_extractor.__iter__.call_count == 1

# Stack transform 1 level
storage.transform(Rename, regex="|item_|rename_|")
self._test_loop(fxt_stream_extractor, storage, n_calls, id_pattern="rename_{idx}")
assert fxt_stream_extractor.__iter__.call_count == 2

# Stack transform 2 level
storage.transform(Rename, regex="|rename_|renameagain_|")
self._test_loop(fxt_stream_extractor, storage, n_calls, id_pattern="renameagain_{idx}")
assert fxt_stream_extractor.__iter__.call_count == 3

def test_subset_transform(self, fxt_stream_extractor: MagicMock):
    """Check the subsets seen through a StreamDatasetStorage as subset transforms are stacked.

    After every ``_test_subsets`` check, the source extractor's ``__iter__``
    call count is asserted to have grown by exactly one (1, then 2, then 3).
    """
    storage = StreamDatasetStorage(source=fxt_stream_extractor)

    # Baseline: no transform stacked yet
    self._test_subsets(fxt_stream_extractor, storage)
    assert fxt_stream_extractor.__iter__.call_count == 1

    # Stack transform 1 level
    storage.transform(RandomSplit, splits=[("train", 0.5), ("val", 0.5)], seed=3003)
    self._test_subsets(fxt_stream_extractor, storage, expect={"train", "val"})
    assert fxt_stream_extractor.__iter__.call_count == 2

    # Stack transform 2 level
    # Both split subsets are mapped onto DEFAULT_SUBSET_NAME, so the check is
    # run with the helper's default expectation again.
    storage.transform(
        MapSubsets, mapping={"train": DEFAULT_SUBSET_NAME, "val": DEFAULT_SUBSET_NAME}
    )
    self._test_subsets(fxt_stream_extractor, storage)
    assert fxt_stream_extractor.__iter__.call_count == 3

def test_info_transform(self, fxt_stream_extractor: MagicMock, fxt_infos: DatasetInfo):
storage = StreamDatasetStorage(source=fxt_stream_extractor)
Expand All @@ -108,6 +114,7 @@ def test_info_transform(self, fxt_stream_extractor: MagicMock, fxt_infos: Datase
storage.transform(ProjectInfos, dst_infos=dst_infos)

assert storage.infos().get("new") == "info"
assert fxt_stream_extractor.__iter__.call_count == 0

def test_categories_transform(
self, fxt_stream_extractor: MagicMock, fxt_categories: CategoriesInfo
Expand All @@ -122,3 +129,44 @@ def test_categories_transform(
actual = set(cat.name for cat in storage.categories()[AnnotationType.label])
expect = set(mapping.values())
assert actual == expect

assert fxt_stream_extractor.__iter__.call_count == 0

def test_mixed_transform(
    self,
    fxt_stream_extractor: MagicMock,
    fxt_infos: DatasetInfo,
    fxt_categories: CategoriesInfo,
):
    """Stack ProjectInfos, RemapLabels and Rename, then check each one's effect.

    The source extractor's ``__iter__`` call count is asserted to stay at 0
    while the transforms are being stacked, and to equal ``n_calls`` once the
    items have been looped over at the end.
    """
    n_calls = 1
    storage = StreamDatasetStorage(source=fxt_stream_extractor)

    # Check extractor infos
    assert storage.infos() == fxt_infos

    dst_infos = {"new": "info"}
    storage.transform(ProjectInfos, dst_infos=dst_infos)
    assert fxt_stream_extractor.__iter__.call_count == 0

    # Check extractor categories
    assert storage.categories() == fxt_categories

    mapping = {"car": "apple", "cat": "banana", "dog": "cinnamon"}
    storage.transform(RemapLabels, mapping=mapping)
    assert fxt_stream_extractor.__iter__.call_count == 0

    # Stack Rename (ItemTransform) on the top
    storage.transform(Rename, regex="|item_|rename_|")
    assert fxt_stream_extractor.__iter__.call_count == 0

    # Check ProjectInfos
    # The infos set by ProjectInfos must still be visible with Rename stacked above it.
    assert storage.infos().get("new") == "info"

    # Check RemapLabels
    actual = set(cat.name for cat in storage.categories()[AnnotationType.label])
    expect = set(mapping.values())
    assert actual == expect

    # Check Rename
    self._test_loop(fxt_stream_extractor, storage, n_calls, id_pattern="rename_{idx}")
    assert fxt_stream_extractor.__iter__.call_count == n_calls
20 changes: 13 additions & 7 deletions tests/unit/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
from datumaro.components.merge.intersect_merge import IntersectMerge
from datumaro.components.progress_reporting import NullProgressReporter, ProgressReporter
from datumaro.components.transformer import ItemTransform, Transform
from datumaro.plugins.transforms import ProjectInfos
from datumaro.plugins.transforms import ProjectInfos, RemapLabels

from ..requirements import Requirements, mark_requirement

Expand Down Expand Up @@ -2275,15 +2275,21 @@ def test_dataset_infos_intersect_merge(
fxt_test_case.assertEqual(dataset.infos(), infos)

@mark_requirement(Requirements.DATUM_GENERAL_REQ)
@pytest.mark.parametrize("is_eager", [True, False])
def test_dataset_infos_transform(
self, fxt_test_case, fxt_sample_dataset_factory, fxt_sample_infos
self, fxt_test_case, fxt_sample_dataset_factory, fxt_sample_infos, is_eager
):
infos_1, infos_2, infos = fxt_sample_infos
with eager_mode(is_eager):
dataset = fxt_sample_dataset_factory(infos=infos_1)

dataset = fxt_sample_dataset_factory(infos=infos_1)
dataset.transform(ProjectInfos, dst_infos=infos_2, overwrite=False)
fxt_test_case.assertEqual(dataset.infos(), infos)

dataset.transform(ProjectInfos, dst_infos=infos_2, overwrite=False)
fxt_test_case.assertEqual(dataset.infos(), infos)
dataset.transform(ProjectInfos, dst_infos=infos_2, overwrite=True)
fxt_test_case.assertEqual(dataset.infos(), infos_2)

dataset.transform(ProjectInfos, dst_infos=infos_2, overwrite=True)
fxt_test_case.assertEqual(dataset.infos(), infos_2)
dataset.transform(
RemapLabels, mapping={"car": "apple", "cat": "banana", "dog": "cinnamon"}
)
fxt_test_case.assertEqual(dataset.infos(), infos_2)

0 comments on commit 144489e

Please sign in to comment.