diff --git a/src/otx/algorithms/common/adapters/mmcv/pipelines/load_image_from_otx_dataset.py b/src/otx/algorithms/common/adapters/mmcv/pipelines/load_image_from_otx_dataset.py index 7e71d95dd00..a642b09a4d8 100644 --- a/src/otx/algorithms/common/adapters/mmcv/pipelines/load_image_from_otx_dataset.py +++ b/src/otx/algorithms/common/adapters/mmcv/pipelines/load_image_from_otx_dataset.py @@ -45,8 +45,11 @@ def _get_unique_key(results: Dict[str, Any]) -> Tuple: # TODO: We should improve it by assigning an unique id to DatasetItemEntity. # This is because there is a case which # d_item.media.path is None, but d_item.media.data is not None + if "cache_key" in results: + return results["cache_key"] d_item = results["dataset_item"] - return d_item.media.path, d_item.roi.id + results["cache_key"] = d_item.media.path, d_item.roi.id + return results["cache_key"] def __call__(self, results: Dict[str, Any]): """Callback function of LoadImageFromOTXDataset.""" @@ -177,12 +180,12 @@ def _save_cache(self, results: Dict[str, Any]): return key = self._get_unique_key(results) meta = results.copy() - meta.pop("dataset_item") # remove irrlevant info img = meta.pop("img") self._mem_cache_handler.put(key, img, meta) def __call__(self, results: Dict[str, Any]) -> Dict[str, Any]: """Callback function.""" + results = results.copy() cached_results = self._load_cache(results) if cached_results: return cached_results diff --git a/src/otx/algorithms/detection/adapters/mmdet/datasets/pipelines/__init__.py b/src/otx/algorithms/detection/adapters/mmdet/datasets/pipelines/__init__.py index a25510ab91e..55efa8bd7e8 100644 --- a/src/otx/algorithms/detection/adapters/mmdet/datasets/pipelines/__init__.py +++ b/src/otx/algorithms/detection/adapters/mmdet/datasets/pipelines/__init__.py @@ -3,7 +3,12 @@ # SPDX-License-Identifier: Apache-2.0 # -from .load_pipelines import LoadAnnotationFromOTXDataset, LoadImageFromOTXDataset +from .load_pipelines import ( + LoadAnnotationFromOTXDataset, + LoadImageFromOTXDataset, + LoadResizeDataFromOTXDataset, + ResizeTo, +) from .torchvision2mmdet import ( BranchImage, ColorJitter, @@ -19,6 +24,8 @@ __all__ = [ "LoadImageFromOTXDataset", "LoadAnnotationFromOTXDataset", + "LoadResizeDataFromOTXDataset", + "ResizeTo", "ColorJitter", "RandomGrayscale", "RandomErasing", diff --git a/src/otx/algorithms/detection/adapters/mmdet/datasets/pipelines/load_pipelines.py b/src/otx/algorithms/detection/adapters/mmdet/datasets/pipelines/load_pipelines.py index c102ea92ba3..3eda94767e3 100644 --- a/src/otx/algorithms/detection/adapters/mmdet/datasets/pipelines/load_pipelines.py +++ b/src/otx/algorithms/detection/adapters/mmdet/datasets/pipelines/load_pipelines.py @@ -1,21 +1,12 @@ """Collection Pipeline for detection task.""" -# Copyright (C) 2021 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions -# and limitations under the License. 
+# Copyright (C) 2021-2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
 import copy
-from typing import Any, Dict
+from typing import Any, Dict, Optional
 
-from mmdet.datasets.builder import PIPELINES
+from mmdet.datasets.builder import PIPELINES, build_from_cfg
+from mmdet.datasets.pipelines import Resize
 
 import otx.algorithms.common.adapters.mmcv.pipelines.load_image_from_otx_dataset as load_image_base
 from otx.algorithms.detection.adapters.mmdet.datasets.dataset import (
@@ -30,6 +21,50 @@ class LoadImageFromOTXDataset(load_image_base.LoadImageFromOTXDataset):
     """Pipeline element that loads an image from a OTX Dataset on the fly."""
 
 
+@PIPELINES.register_module()
+class LoadResizeDataFromOTXDataset(load_image_base.LoadResizeDataFromOTXDataset):
+    """Load and resize image & annotation with cache support."""
+
+    def _create_load_ann_op(self, cfg: Optional[Dict]) -> Optional[Any]:
+        """Creates load annotation operation."""
+        if cfg is None:
+            return None
+        return build_from_cfg(cfg, PIPELINES)
+
+    def _create_resize_op(self, cfg: Optional[Dict]) -> Optional[Any]:
+        """Creates resize operation."""
+        if cfg is None:
+            return None
+        return build_from_cfg(cfg, PIPELINES)
+
+
+@PIPELINES.register_module()
+class ResizeTo(Resize):
+    """Resize to specific size.
+
+    This operation resizes the input only if it is not already at the desired shape.
+    If it is, the input dict is returned unchanged for efficiency.
+
+    Args:
+        img_scale (tuple): Image scale for resizing (w, h).
+    """
+
+    def __init__(self, **kwargs):
+        super().__init__(override=True, **kwargs)  # Allow multiple calls
+
+    def __call__(self, results: Dict[str, Any]):
+        """Callback function of ResizeTo.
+
+        Args:
+            results: Inputs to be transformed.
+        """
+        img_shape = results.get("img_shape", (0, 0))
+        img_scale = self.img_scale[0]
+        if img_shape[0] == img_scale[0] and img_shape[1] == img_scale[1]:
+            return results
+        return super().__call__(results)
+
+
 @PIPELINES.register_module()
 class LoadAnnotationFromOTXDataset:
     """Pipeline element that loads an annotation from a OTX Dataset on the fly.
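Note: the sketch below shows how the pieces introduced above are meant to compose in a training config. `LoadResizeDataFromOTXDataset` loads image and annotations, pre-resizes them, and caches the result, while any later `Resize` in the same pipeline must pass `override=True` because the sample already carries a scale. This mirrors the config changes later in this patch; the scale values are illustrative only, not prescriptive.

    # Minimal usage sketch (assumed, illustrative scale values)
    train_pipeline = [
        dict(
            type="LoadResizeDataFromOTXDataset",
            load_ann_cfg=dict(type="LoadAnnotationFromOTXDataset", with_bbox=True),
            resize_cfg=dict(
                type="Resize",
                img_scale=(1088, 800),  # pre-resize target
                keep_ratio=True,
                downscale_only=True,  # never upscale during pre-resize
            ),
            enable_memcache=True,  # cache the already-resized image & annotations
        ),
        # ... augmentations ...
        # The final resize needs override=True since the data was scaled once already
        dict(type="Resize", img_scale=(992, 736), keep_ratio=False, override=True),
    ]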
@@ -84,7 +119,7 @@ def _load_masks(results, ann_info):
 
     def __call__(self, results: Dict[str, Any]):
         """Callback function of LoadAnnotationFromOTXDataset."""
-        dataset_item = results.pop("dataset_item")
+        dataset_item = results.pop("dataset_item")  # Prevent unnecessary deepcopy
         label_list = results.pop("ann_info")["label_list"]
         ann_info = get_annotation_mmdet_format(dataset_item, label_list, self.domain, self.min_size)
         if self.with_bbox:
diff --git a/src/otx/algorithms/detection/configs/base/data/atss_data_pipeline.py b/src/otx/algorithms/detection/configs/base/data/atss_data_pipeline.py
index 917412cabf8..61d0580684a 100644
--- a/src/otx/algorithms/detection/configs/base/data/atss_data_pipeline.py
+++ b/src/otx/algorithms/detection/configs/base/data/atss_data_pipeline.py
@@ -9,14 +9,24 @@
 __img_norm_cfg = dict(mean=[0, 0, 0], std=[255, 255, 255], to_rgb=True)
 
 train_pipeline = [
-    dict(type="LoadImageFromOTXDataset", enable_memcache=True),
-    dict(type="LoadAnnotationFromOTXDataset", with_bbox=True),
+    dict(
+        type="LoadResizeDataFromOTXDataset",
+        load_ann_cfg=dict(type="LoadAnnotationFromOTXDataset", with_bbox=True),
+        resize_cfg=dict(
+            type="Resize",
+            img_scale=(1088, 800),  # max width/height among the random image scales below
+            keep_ratio=True,
+            downscale_only=True,
+        ),  # Resize to an intermediate size if the original image is bigger
+        enable_memcache=True,  # Cache after resizing image & annotations
+    ),
     dict(type="MinIoURandomCrop", min_ious=(0.1, 0.3, 0.5, 0.7, 0.9), min_crop_size=0.3),
     dict(
         type="Resize",
         img_scale=[(992, 736), (896, 736), (1088, 736), (992, 672), (992, 800)],
         multiscale_mode="value",
         keep_ratio=False,
+        override=True,  # Allow multiple resizes
     ),
     dict(type="RandomFlip", flip_ratio=0.5),
     dict(type="Normalize", **__img_norm_cfg),
diff --git a/src/otx/algorithms/detection/configs/base/data/iseg_efficientnet_data_pipeline.py b/src/otx/algorithms/detection/configs/base/data/iseg_efficientnet_data_pipeline.py
index c125412df3b..63005cce261 100644
--- a/src/otx/algorithms/detection/configs/base/data/iseg_efficientnet_data_pipeline.py
+++ b/src/otx/algorithms/detection/configs/base/data/iseg_efficientnet_data_pipeline.py
@@ -11,15 +11,22 @@
 __img_norm_cfg = dict(mean=(103.53, 116.28, 123.675), std=(1.0, 1.0, 1.0), to_rgb=True)
 
 train_pipeline = [
-    dict(type="LoadImageFromOTXDataset", enable_memcache=True),
     dict(
-        type="LoadAnnotationFromOTXDataset",
-        domain="instance_segmentation",
-        with_bbox=True,
-        with_mask=True,
-        poly2mask=False,
+        type="LoadResizeDataFromOTXDataset",
+        load_ann_cfg=dict(
+            type="LoadAnnotationFromOTXDataset",
+            domain="instance_segmentation",
+            with_bbox=True,
+            with_mask=True,
+            poly2mask=False,
+        ),
+        resize_cfg=dict(
+            type="Resize",
+            img_scale=__img_size,
+            keep_ratio=False,
+        ),
+        enable_memcache=True,  # Cache after resizing image & annotations
     ),
-    dict(type="Resize", img_scale=__img_size, keep_ratio=False),
     dict(type="RandomFlip", flip_ratio=0.5),
     dict(type="Normalize", **__img_norm_cfg),
     dict(type="Pad", size_divisor=32),
diff --git a/src/otx/algorithms/detection/configs/base/data/iseg_resnet_data_pipeline.py b/src/otx/algorithms/detection/configs/base/data/iseg_resnet_data_pipeline.py
index 3b1d9a8ceff..d5010e6fad7 100644
--- a/src/otx/algorithms/detection/configs/base/data/iseg_resnet_data_pipeline.py
+++ b/src/otx/algorithms/detection/configs/base/data/iseg_resnet_data_pipeline.py
@@ -11,15 +11,22 @@
 __img_norm_cfg = dict(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
 
 train_pipeline = [
-    dict(type="LoadImageFromOTXDataset", enable_memcache=True),
     dict(
-        type="LoadAnnotationFromOTXDataset",
-        domain="instance_segmentation",
-        with_bbox=True,
-        with_mask=True,
-        poly2mask=False,
+        type="LoadResizeDataFromOTXDataset",
+        load_ann_cfg=dict(
+            type="LoadAnnotationFromOTXDataset",
+            domain="instance_segmentation",
+            with_bbox=True,
+            with_mask=True,
+            poly2mask=False,
+        ),
+        resize_cfg=dict(
+            type="Resize",
+            img_scale=__img_size,
+            keep_ratio=False,
+        ),
+        enable_memcache=True,  # Cache after resizing image & annotations
     ),
-    dict(type="Resize", img_scale=__img_size, keep_ratio=False),
     dict(type="RandomFlip", flip_ratio=0.5),
     dict(type="Normalize", **__img_norm_cfg),
     dict(type="DefaultFormatBundle"),
diff --git a/src/otx/algorithms/detection/configs/detection/cspdarknet_yolox/data_pipeline.py b/src/otx/algorithms/detection/configs/detection/cspdarknet_yolox/data_pipeline.py
index 3f6a007615b..61d26d49ff3 100644
--- a/src/otx/algorithms/detection/configs/detection/cspdarknet_yolox/data_pipeline.py
+++ b/src/otx/algorithms/detection/configs/detection/cspdarknet_yolox/data_pipeline.py
@@ -34,7 +34,7 @@
         hue_delta=18,
     ),
     dict(type="RandomFlip", flip_ratio=0.5),
-    dict(type="Resize", img_scale=__img_size, keep_ratio=True),
+    dict(type="Resize", img_scale=__img_size, keep_ratio=True, override=True),  # Allow multiple resizes
     dict(type="Pad", pad_to_square=True, pad_val=114.0),
     dict(type="Normalize", **__img_norm_cfg),
     dict(type="DefaultFormatBundle"),
@@ -82,8 +82,18 @@
     dataset=dict(
         type=__dataset_type,
         pipeline=[
-            dict(type="LoadImageFromOTXDataset", to_float32=False, enable_memcache=True),
-            dict(type="LoadAnnotationFromOTXDataset", with_bbox=True),
+            dict(
+                type="LoadResizeDataFromOTXDataset",
+                load_ann_cfg=dict(type="LoadAnnotationFromOTXDataset", with_bbox=True),
+                resize_cfg=dict(
+                    type="Resize",
+                    img_scale=__img_size,
+                    keep_ratio=True,
+                    downscale_only=True,
+                ),  # Resize to an intermediate size if the original image is bigger
+                to_float32=False,
+                enable_memcache=True,  # Cache after resizing image & annotations
+            ),
         ],
     ),
     pipeline=train_pipeline,
diff --git a/src/otx/algorithms/detection/configs/detection/mobilenetv2_ssd/data_pipeline.py b/src/otx/algorithms/detection/configs/detection/mobilenetv2_ssd/data_pipeline.py
index fa22099bf30..1494121e514 100644
--- a/src/otx/algorithms/detection/configs/detection/mobilenetv2_ssd/data_pipeline.py
+++ b/src/otx/algorithms/detection/configs/detection/mobilenetv2_ssd/data_pipeline.py
@@ -21,8 +21,18 @@
 __img_norm_cfg = dict(mean=[0, 0, 0], std=[255, 255, 255], to_rgb=True)
 
 train_pipeline = [
-    dict(type="LoadImageFromOTXDataset", to_float32=True, enable_memcache=True),
-    dict(type="LoadAnnotationFromOTXDataset", with_bbox=True),
+    dict(
+        type="LoadResizeDataFromOTXDataset",
+        load_ann_cfg=dict(type="LoadAnnotationFromOTXDataset", with_bbox=True),
+        resize_cfg=dict(
+            type="Resize",
+            img_scale=__img_size,
+            keep_ratio=True,
+            downscale_only=True,
+        ),  # Resize to an intermediate size if the original image is bigger
+        to_float32=True,
+        enable_memcache=True,  # Cache after resizing image & annotations
+    ),
     dict(
         type="PhotoMetricDistortion",
         brightness_delta=32,
@@ -31,7 +41,7 @@
         hue_delta=18,
     ),
     dict(type="MinIoURandomCrop", min_ious=(0.1, 0.3, 0.5, 0.7, 0.9), min_crop_size=0.1),
-    dict(type="Resize", img_scale=__img_size, keep_ratio=False),
+    dict(type="Resize", img_scale=__img_size, keep_ratio=False, override=True),  # Allow multiple resizes
     dict(type="Normalize", **__img_norm_cfg),
     dict(type="RandomFlip", flip_ratio=0.5),
     dict(type="DefaultFormatBundle"),
diff --git 
a/src/otx/core/data/caching/mem_cache_handler.py b/src/otx/core/data/caching/mem_cache_handler.py index 260af378f82..436fd23fe24 100644 --- a/src/otx/core/data/caching/mem_cache_handler.py +++ b/src/otx/core/data/caching/mem_cache_handler.py @@ -63,9 +63,9 @@ def get(self, key: Any) -> Tuple[Optional[np.ndarray], Optional[Dict]]: addr = self._cache_addr[key] - offset, count, shape, strides, meta = addr + offset, count, dtype, shape, strides, meta = addr - data = np.frombuffer(self._arr, dtype=np.uint8, count=count, offset=offset) + data = np.frombuffer(self._arr, dtype=dtype, count=count, offset=offset) return np.lib.stride_tricks.as_strided(data, shape, strides), meta def put(self, key: Any, data: np.ndarray, meta: Optional[Dict] = None) -> Optional[int]: @@ -82,20 +82,21 @@ def put(self, key: Any, data: np.ndarray, meta: Optional[Dict] = None) -> Option if self._freeze.value: return None - assert data.dtype == np.uint8 + data_bytes = data.size * data.itemsize with self._lock: - new_page = self._cur_page.value + data.size + new_page = self._cur_page.value + data_bytes if key in self._cache_addr or new_page > self.mem_size: return None offset = ct.byref(self._arr, self._cur_page.value) - ct.memmove(offset, data.ctypes.data, data.size) + ct.memmove(offset, data.ctypes.data, data_bytes) self._cache_addr[key] = ( self._cur_page.value, data.size, + data.dtype, data.shape, data.strides, meta, diff --git a/tests/unit/algorithms/common/adapters/mmcv/pipelines/test_load_image_from_otx_dataset.py b/tests/unit/algorithms/common/adapters/mmcv/pipelines/test_load_image_from_otx_dataset.py index c064d077240..67a706c9685 100644 --- a/tests/unit/algorithms/common/adapters/mmcv/pipelines/test_load_image_from_otx_dataset.py +++ b/tests/unit/algorithms/common/adapters/mmcv/pipelines/test_load_image_from_otx_dataset.py @@ -185,3 +185,19 @@ def test_enable_memcache(self, fxt_caching_dataset_cls, fxt_data_list): # The second round requires no read. 
        assert mock.call_count == 0
+
+
+@pytest.mark.parametrize("mode", ["singleprocessing", "multiprocessing"])
+def test_memcache_image_itemtype(mode):
+    img = (np.random.rand(10, 10, 3) * 255).astype(np.uint8)
+    MemCacheHandlerSingleton.create(mode, img.size * img.itemsize)
+    cache = MemCacheHandlerSingleton.get()
+    cache.put("img_u8", img)
+    img_cached, _ = cache.get("img_u8")
+    assert np.array_equal(img, img_cached)
+    img = np.random.rand(10, 10, 3).astype(np.float32)
+    MemCacheHandlerSingleton.create(mode, img.size * img.itemsize)
+    cache = MemCacheHandlerSingleton.get()
+    cache.put("img_f32", img)
+    img_cached, _ = cache.get("img_f32")
+    assert np.array_equal(img, img_cached)
diff --git a/tests/unit/algorithms/detection/adapters/mmdet/datasets/pipelines/test_load_pipelines.py b/tests/unit/algorithms/detection/adapters/mmdet/datasets/pipelines/test_load_pipelines.py
new file mode 100644
index 00000000000..f03c54085fd
--- /dev/null
+++ b/tests/unit/algorithms/detection/adapters/mmdet/datasets/pipelines/test_load_pipelines.py
@@ -0,0 +1,125 @@
+import numpy as np
+import pytest
+
+from otx.algorithms.detection.adapters.mmdet.datasets.pipelines import (
+    LoadResizeDataFromOTXDataset,
+    ResizeTo,
+)
+from otx.api.entities.model_template import TaskType
+from otx.core.data.caching import MemCacheHandlerSingleton
+from tests.test_suite.e2e_test_system import e2e_pytest_unit
+from tests.unit.algorithms.detection.test_helpers import generate_det_dataset
+
+
+@e2e_pytest_unit
+def test_load_resize_data_from_otx_dataset_call(mocker):
+    """Test LoadResizeDataFromOTXDataset."""
+    otx_dataset, labels = generate_det_dataset(
+        TaskType.INSTANCE_SEGMENTATION,  # covers both det & iseg formats
+        image_width=320,
+        image_height=320,
+    )
+    MemCacheHandlerSingleton.create("singleprocessing", otx_dataset[0].numpy.size)
+    op = LoadResizeDataFromOTXDataset(
+        load_ann_cfg=dict(
+            type="LoadAnnotationFromOTXDataset",
+            domain="instance_segmentation",
+            with_bbox=True,
+            with_mask=True,
+            poly2mask=False,
+        ),
+        resize_cfg=dict(type="ResizeTo", img_scale=(32, 16), keep_ratio=False),  # 320x320 -> 16x32
+    )
+    src_dict = dict(
+        dataset_item=otx_dataset[0],
+        width=otx_dataset[0].width,
+        height=otx_dataset[0].height,
+        index=0,
+        ann_info=dict(label_list=labels),
+        bbox_fields=[],
+        mask_fields=[],
+    )
+    dst_dict = op(src_dict)
+    assert dst_dict["ori_shape"][0] == 320
+    assert dst_dict["img_shape"][0] == 16  # height
+    assert dst_dict["img"].shape == dst_dict["img_shape"]
+    assert dst_dict["gt_masks"].width == 32
+    assert dst_dict["gt_masks"].height == 16
+    op._load_img = mocker.MagicMock()
+    dst_dict_from_cache = op(src_dict)
+    assert op._load_img.call_count == 0  # _load_img() should not be called
+    assert np.array_equal(dst_dict["img"], dst_dict_from_cache["img"])
+    assert (dst_dict["gt_labels"] == dst_dict_from_cache["gt_labels"]).all()
+    assert (dst_dict["gt_bboxes"] == dst_dict_from_cache["gt_bboxes"]).all()
+    assert dst_dict["gt_masks"] == dst_dict_from_cache["gt_masks"]
+
+
+@e2e_pytest_unit
+def test_load_resize_data_from_otx_dataset_downscale_only(mocker):
+    """Test LoadResizeDataFromOTXDataset with downscale_only=True."""
+    otx_dataset, labels = generate_det_dataset(
+        TaskType.INSTANCE_SEGMENTATION,  # covers both det & iseg formats
+        image_width=320,
+        image_height=320,
+    )
+    MemCacheHandlerSingleton.create("singleprocessing", otx_dataset[0].numpy.size)
+    op = LoadResizeDataFromOTXDataset(
+        load_ann_cfg=dict(
+            type="LoadAnnotationFromOTXDataset",
domain="instance_segmentation", + with_bbox=True, + with_mask=True, + poly2mask=False, + ), + resize_cfg=dict(type="ResizeTo", img_scale=(640, 640), downscale_only=True), # 320x320 -> 16x32 + ) + src_dict = dict( + dataset_item=otx_dataset[0], + width=otx_dataset[0].width, + height=otx_dataset[0].height, + index=0, + ann_info=dict(label_list=labels), + bbox_fields=[], + mask_fields=[], + ) + dst_dict = op(src_dict) + assert dst_dict["ori_shape"][0] == 320 + assert dst_dict["img_shape"][0] == 320 # Skipped upscale + assert dst_dict["img"].shape == dst_dict["img_shape"] + op._load_img_op = mocker.MagicMock() + dst_dict_from_cache = op(src_dict) + assert op._load_img_op.call_count == 0 # _load_img() should not be called + assert np.array_equal(dst_dict["img"], dst_dict_from_cache["img"]) + assert (dst_dict["gt_labels"] == dst_dict_from_cache["gt_labels"]).all() + assert (dst_dict["gt_bboxes"] == dst_dict_from_cache["gt_bboxes"]).all() + assert dst_dict["gt_masks"] == dst_dict_from_cache["gt_masks"] + + +@e2e_pytest_unit +def test_resize_to(mocker): + """Test ResizeTo.""" + src_dict = dict( + img=np.random.randint(0, 10, (16, 16, 3), dtype=np.uint8), + img_fields=["img"], + ori_shape=(16, 16), + img_shape=(16, 16), + ) + # Test downscale + op = ResizeTo(img_scale=(4, 4)) + dst_dict = op(src_dict) + assert dst_dict["ori_shape"][0] == 16 + assert dst_dict["img_shape"][0] == 4 + assert dst_dict["img"].shape == dst_dict["img_shape"] + # Test upscale from output + op = ResizeTo(img_scale=(8, 8)) + dst_dict = op(dst_dict) + assert dst_dict["ori_shape"][0] == 16 + assert dst_dict["img_shape"][0] == 8 + assert dst_dict["img"].shape == dst_dict["img_shape"] + # Test same size from output + op = ResizeTo(img_scale=(8, 8)) + op._resize_img = mocker.MagicMock() + dst_dict = op(dst_dict) + assert dst_dict["ori_shape"][0] == 16 + assert dst_dict["img_shape"][0] == 8 + assert op._resize_img.call_count == 0 # _resize_img() should not be called diff --git a/tests/unit/algorithms/detection/test_helpers.py b/tests/unit/algorithms/detection/test_helpers.py index 990fa38cbdb..32c8a8d67a7 100644 --- a/tests/unit/algorithms/detection/test_helpers.py +++ b/tests/unit/algorithms/detection/test_helpers.py @@ -68,7 +68,7 @@ def init_environment(params, model_template, task_type=TaskType.DETECTION): return environment -def generate_det_dataset(task_type, number_of_images=1): +def generate_det_dataset(task_type, number_of_images=1, image_width=640, image_height=480): classes = ("rectangle", "ellipse", "triangle") label_schema = generate_label_schema(classes, task_type_to_label_domain(task_type)) @@ -79,8 +79,8 @@ def generate_det_dataset(task_type, number_of_images=1): else: subset = Subset.TRAINING image_numpy, annos = generate_random_annotated_image( - image_width=640, - image_height=480, + image_width=image_width, + image_height=image_height, labels=label_schema.get_labels(False), ) # Convert shapes according to task