
Commit 2cb5d2b

Author: Songki Choi
Merge load / resize / cache to optimize data loading efficiency for detection & instance segmentation (#2453)
* Implement LoadResizeDataFromOTXDataset & ResizeTo for det/iseg
* Apply LoadResizeDataFromOTXDataset to ATSS
* Enable non-uint8 image caching
* Apply LoadResizeDataFromOTXDataset to SSD
* Apply LoadResizeDataFromOTXDataset to YOLOX
* Apply LoadResizeDataFromOTXDataset to MaskRCNNs
* Cache "cache_key" in the pipeline

Signed-off-by: Songki Choi <songki.choi@intel.com>
1 parent e634abf commit 2cb5d2b

File tree: 12 files changed, +282 −49 lines
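The unifying idea across these changes: decode, optionally resize to an intermediate scale, and cache the resized sample so later epochs skip both the decode and the resize. Below is a dependency-light sketch of that flow with illustrative names only; the real implementation is LoadResizeDataFromOTXDataset plus MemCacheHandler inside OTX.

import numpy as np

_cache = {}  # stands in for OTX's MemCacheHandler: key -> (resized image, meta)

def nearest_downscale(img, max_hw):
    """Crude nearest-neighbor downscale so the sketch stays dependency-free."""
    h, w = img.shape[:2]
    step = max(int(np.ceil(max(h / max_hw[0], w / max_hw[1]))), 1)
    return img[::step, ::step]

def load_resize_cache(key, decode_fn, max_hw=(800, 1088)):
    """Decode + resize only once per key; later calls are served from the cache."""
    if key in _cache:
        return _cache[key]
    img = nearest_downscale(decode_fn(), max_hw)  # resize *before* caching
    _cache[key] = (img, {"img_shape": img.shape[:2]})
    return _cache[key]

# First call decodes and resizes; the second call does neither.
img, meta = load_resize_cache("item-0", lambda: np.zeros((2000, 3000, 3), np.uint8))
img2, _ = load_resize_cache("item-0", lambda: np.zeros((2000, 3000, 3), np.uint8))
assert img is img2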

src/otx/algorithms/common/adapters/mmcv/pipelines/load_image_from_otx_dataset.py

+5 −2

@@ -45,8 +45,11 @@ def _get_unique_key(results: Dict[str, Any]) -> Tuple:
         # TODO: We should improve it by assigning an unique id to DatasetItemEntity.
         # This is because there is a case which
         # d_item.media.path is None, but d_item.media.data is not None
+        if "cache_key" in results:
+            return results["cache_key"]
         d_item = results["dataset_item"]
-        return d_item.media.path, d_item.roi.id
+        results["cache_key"] = d_item.media.path, d_item.roi.id
+        return results["cache_key"]

     def __call__(self, results: Dict[str, Any]):
         """Callback function of LoadImageFromOTXDataset."""

@@ -177,12 +180,12 @@ def _save_cache(self, results: Dict[str, Any]):
             return
         key = self._get_unique_key(results)
         meta = results.copy()
-        meta.pop("dataset_item")  # remove irrlevant info
         img = meta.pop("img")
         self._mem_cache_handler.put(key, img, meta)

     def __call__(self, results: Dict[str, Any]) -> Dict[str, Any]:
         """Callback function."""
+        results = results.copy()
         cached_results = self._load_cache(results)
         if cached_results:
             return cached_results
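The `results = results.copy()` added to `__call__` presumably keeps in-place mutations (popping "dataset_item", stashing "cache_key") from leaking back into the dict the caller holds. A tiny standalone illustration of that pattern, with a hypothetical stage rather than the OTX class:

def stage(results):
    results = results.copy()              # shallow copy: cheap, isolates key changes
    results["cache_key"] = ("path", "roi-0")
    results.pop("dataset_item")           # only the copy loses the key
    return results

base = {"dataset_item": "item-0", "ann_info": {"label_list": []}}
out = stage(base)
assert "dataset_item" in base and "cache_key" not in base  # caller's dict untouched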

src/otx/algorithms/detection/adapters/mmdet/datasets/pipelines/__init__.py

+8 −1

@@ -3,7 +3,12 @@
 # SPDX-License-Identifier: Apache-2.0
 #

-from .load_pipelines import LoadAnnotationFromOTXDataset, LoadImageFromOTXDataset
+from .load_pipelines import (
+    LoadAnnotationFromOTXDataset,
+    LoadImageFromOTXDataset,
+    LoadResizeDataFromOTXDataset,
+    ResizeTo,
+)
 from .torchvision2mmdet import (
     BranchImage,
     ColorJitter,

@@ -19,6 +24,8 @@
 __all__ = [
     "LoadImageFromOTXDataset",
     "LoadAnnotationFromOTXDataset",
+    "LoadResizeDataFromOTXDataset",
+    "ResizeTo",
     "ColorJitter",
     "RandomGrayscale",
     "RandomErasing",

src/otx/algorithms/detection/adapters/mmdet/datasets/pipelines/load_pipelines.py

+51 −16

@@ -1,21 +1,12 @@
 """Collection Pipeline for detection task."""
-# Copyright (C) 2021 Intel Corporation
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions
-# and limitations under the License.
+# Copyright (C) 2021-2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
 import copy
-from typing import Any, Dict
+from typing import Any, Dict, Optional

-from mmdet.datasets.builder import PIPELINES
+from mmdet.datasets.builder import PIPELINES, build_from_cfg
+from mmdet.datasets.pipelines import Resize

 import otx.algorithms.common.adapters.mmcv.pipelines.load_image_from_otx_dataset as load_image_base
 from otx.algorithms.detection.adapters.mmdet.datasets.dataset import (

@@ -30,6 +21,50 @@ class LoadImageFromOTXDataset(load_image_base.LoadImageFromOTXDataset):
     """Pipeline element that loads an image from a OTX Dataset on the fly."""


+@PIPELINES.register_module()
+class LoadResizeDataFromOTXDataset(load_image_base.LoadResizeDataFromOTXDataset):
+    """Load and resize image & annotation with cache support."""
+
+    def _create_load_ann_op(self, cfg: Optional[Dict]) -> Optional[Any]:
+        """Creates annotation loading operation."""
+        if cfg is None:
+            return None
+        return build_from_cfg(cfg, PIPELINES)
+
+    def _create_resize_op(self, cfg: Optional[Dict]) -> Optional[Any]:
+        """Creates resize operation."""
+        if cfg is None:
+            return None
+        return build_from_cfg(cfg, PIPELINES)
+
+
+@PIPELINES.register_module()
+class ResizeTo(Resize):
+    """Resize to a specific size.
+
+    This operation resizes only when the input is not already in the desired shape.
+    If it is already in that shape, it returns the input dict unchanged for efficiency.
+
+    Args:
+        img_scale (tuple): Image scale for resizing (w, h).
+    """
+
+    def __init__(self, **kwargs):
+        super().__init__(override=True, **kwargs)  # Allow multiple calls
+
+    def __call__(self, results: Dict[str, Any]):
+        """Callback function of ResizeTo.
+
+        Args:
+            results: Inputs to be transformed.
+        """
+        img_shape = results.get("img_shape", (0, 0))
+        img_scale = self.img_scale[0]
+        if img_shape[0] == img_scale[0] and img_shape[1] == img_scale[1]:
+            return results
+        return super().__call__(results)
+
+
 @PIPELINES.register_module()
 class LoadAnnotationFromOTXDataset:
     """Pipeline element that loads an annotation from a OTX Dataset on the fly.

@@ -84,7 +119,7 @@ def _load_masks(results, ann_info):

     def __call__(self, results: Dict[str, Any]):
         """Callback function of LoadAnnotationFromOTXDataset."""
-        dataset_item = results.pop("dataset_item")
+        dataset_item = results.pop("dataset_item")  # Prevent unnecessary deepcopy
         label_list = results.pop("ann_info")["label_list"]
         ann_info = get_annotation_mmdet_format(dataset_item, label_list, self.domain, self.min_size)
         if self.with_bbox:
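Both `_create_*_op` hooks delegate to mmdet's registry, so the `load_ann_cfg` and `resize_cfg` entries in the configs below are ordinary pipeline dicts. A minimal usage sketch, assuming mmdet 2.x is installed and the OTX pipeline module has been imported so its transforms are registered:

from mmdet.datasets.builder import PIPELINES, build_from_cfg

# Mirrors the resize_cfg / load_ann_cfg entries used in the data pipeline configs below.
resize_op = build_from_cfg(dict(type="Resize", img_scale=(1088, 800), keep_ratio=True), PIPELINES)
load_ann_op = build_from_cfg(dict(type="LoadAnnotationFromOTXDataset", with_bbox=True), PIPELINES)
# Each op is then callable on a results dict:
#   results = load_ann_op(results)
#   results = resize_op(results)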

src/otx/algorithms/detection/configs/base/data/atss_data_pipeline.py

+12 −2

@@ -9,14 +9,24 @@
 __img_norm_cfg = dict(mean=[0, 0, 0], std=[255, 255, 255], to_rgb=True)

 train_pipeline = [
-    dict(type="LoadImageFromOTXDataset", enable_memcache=True),
-    dict(type="LoadAnnotationFromOTXDataset", with_bbox=True),
+    dict(
+        type="LoadResizeDataFromOTXDataset",
+        load_ann_cfg=dict(type="LoadAnnotationFromOTXDataset", with_bbox=True),
+        resize_cfg=dict(
+            type="Resize",
+            img_scale=(1088, 800),  # max sizes in random image scales
+            keep_ratio=True,
+            downscale_only=True,
+        ),  # Resize to intermediate size if org image is bigger
+        enable_memcache=True,  # Cache after resizing image & annotations
+    ),
     dict(type="MinIoURandomCrop", min_ious=(0.1, 0.3, 0.5, 0.7, 0.9), min_crop_size=0.3),
     dict(
         type="Resize",
         img_scale=[(992, 736), (896, 736), (1088, 736), (992, 672), (992, 800)],
         multiscale_mode="value",
         keep_ratio=False,
+        override=True,  # Allow multiple resize
     ),
     dict(type="RandomFlip", flip_ratio=0.5),
     dict(type="Normalize", **__img_norm_cfg),

src/otx/algorithms/detection/configs/base/data/iseg_efficientnet_data_pipeline.py

+14 −7

@@ -11,15 +11,22 @@
 __img_norm_cfg = dict(mean=(103.53, 116.28, 123.675), std=(1.0, 1.0, 1.0), to_rgb=True)

 train_pipeline = [
-    dict(type="LoadImageFromOTXDataset", enable_memcache=True),
     dict(
-        type="LoadAnnotationFromOTXDataset",
-        domain="instance_segmentation",
-        with_bbox=True,
-        with_mask=True,
-        poly2mask=False,
+        type="LoadResizeDataFromOTXDataset",
+        load_ann_cfg=dict(
+            type="LoadAnnotationFromOTXDataset",
+            domain="instance_segmentation",
+            with_bbox=True,
+            with_mask=True,
+            poly2mask=False,
+        ),
+        resize_cfg=dict(
+            type="Resize",
+            img_scale=__img_size,
+            keep_ratio=False,
+        ),
+        enable_memcache=True,  # Cache after resizing image & annotations
     ),
-    dict(type="Resize", img_scale=__img_size, keep_ratio=False),
     dict(type="RandomFlip", flip_ratio=0.5),
     dict(type="Normalize", **__img_norm_cfg),
     dict(type="Pad", size_divisor=32),

src/otx/algorithms/detection/configs/base/data/iseg_resnet_data_pipeline.py

+14 −7

@@ -11,15 +11,22 @@
 __img_norm_cfg = dict(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)

 train_pipeline = [
-    dict(type="LoadImageFromOTXDataset", enable_memcache=True),
     dict(
-        type="LoadAnnotationFromOTXDataset",
-        domain="instance_segmentation",
-        with_bbox=True,
-        with_mask=True,
-        poly2mask=False,
+        type="LoadResizeDataFromOTXDataset",
+        load_ann_cfg=dict(
+            type="LoadAnnotationFromOTXDataset",
+            domain="instance_segmentation",
+            with_bbox=True,
+            with_mask=True,
+            poly2mask=False,
+        ),
+        resize_cfg=dict(
+            type="Resize",
+            img_scale=__img_size,
+            keep_ratio=False,
+        ),
+        enable_memcache=True,  # Cache after resizing image & annotations
     ),
-    dict(type="Resize", img_scale=__img_size, keep_ratio=False),
     dict(type="RandomFlip", flip_ratio=0.5),
     dict(type="Normalize", **__img_norm_cfg),
     dict(type="DefaultFormatBundle"),

src/otx/algorithms/detection/configs/detection/cspdarknet_yolox/data_pipeline.py

+13 −3

@@ -34,7 +34,7 @@
         hue_delta=18,
     ),
     dict(type="RandomFlip", flip_ratio=0.5),
-    dict(type="Resize", img_scale=__img_size, keep_ratio=True),
+    dict(type="Resize", img_scale=__img_size, keep_ratio=True, override=True),  # Allow multiple resize
    dict(type="Pad", pad_to_square=True, pad_val=114.0),
     dict(type="Normalize", **__img_norm_cfg),
     dict(type="DefaultFormatBundle"),

@@ -82,8 +82,18 @@
     dataset=dict(
         type=__dataset_type,
         pipeline=[
-            dict(type="LoadImageFromOTXDataset", to_float32=False, enable_memcache=True),
-            dict(type="LoadAnnotationFromOTXDataset", with_bbox=True),
+            dict(
+                type="LoadResizeDataFromOTXDataset",
+                load_ann_cfg=dict(type="LoadAnnotationFromOTXDataset", with_bbox=True),
+                resize_cfg=dict(
+                    type="Resize",
+                    img_scale=__img_size,
+                    keep_ratio=True,
+                    downscale_only=True,
+                ),  # Resize to intermediate size if org image is bigger
+                to_float32=False,
+                enable_memcache=True,  # Cache after resizing image & annotations
+            ),
         ],
     ),
     pipeline=train_pipeline,

src/otx/algorithms/detection/configs/detection/mobilenetv2_ssd/data_pipeline.py

+13 −3

@@ -21,8 +21,18 @@
 __img_norm_cfg = dict(mean=[0, 0, 0], std=[255, 255, 255], to_rgb=True)

 train_pipeline = [
-    dict(type="LoadImageFromOTXDataset", to_float32=True, enable_memcache=True),
-    dict(type="LoadAnnotationFromOTXDataset", with_bbox=True),
+    dict(
+        type="LoadResizeDataFromOTXDataset",
+        load_ann_cfg=dict(type="LoadAnnotationFromOTXDataset", with_bbox=True),
+        resize_cfg=dict(
+            type="Resize",
+            img_scale=__img_size,
+            keep_ratio=True,
+            downscale_only=True,
+        ),  # Resize to intermediate size if org image is bigger
+        to_float32=True,
+        enable_memcache=True,  # Cache after resizing image & annotations
+    ),
     dict(
         type="PhotoMetricDistortion",
         brightness_delta=32,

@@ -31,7 +41,7 @@
         hue_delta=18,
     ),
     dict(type="MinIoURandomCrop", min_ious=(0.1, 0.3, 0.5, 0.7, 0.9), min_crop_size=0.1),
-    dict(type="Resize", img_scale=__img_size, keep_ratio=False),
+    dict(type="Resize", img_scale=__img_size, keep_ratio=False, override=True),  # Allow multiple resize
     dict(type="Normalize", **__img_norm_cfg),
     dict(type="RandomFlip", flip_ratio=0.5),
     dict(type="DefaultFormatBundle"),

src/otx/core/data/caching/mem_cache_handler.py

+6 −5

@@ -63,9 +63,9 @@ def get(self, key: Any) -> Tuple[Optional[np.ndarray], Optional[Dict]]:

         addr = self._cache_addr[key]

-        offset, count, shape, strides, meta = addr
+        offset, count, dtype, shape, strides, meta = addr

-        data = np.frombuffer(self._arr, dtype=np.uint8, count=count, offset=offset)
+        data = np.frombuffer(self._arr, dtype=dtype, count=count, offset=offset)
         return np.lib.stride_tricks.as_strided(data, shape, strides), meta

     def put(self, key: Any, data: np.ndarray, meta: Optional[Dict] = None) -> Optional[int]:

@@ -82,20 +82,21 @@ def put(self, key: Any, data: np.ndarray, meta: Optional[Dict] = None) -> Optional[int]:
         if self._freeze.value:
             return None

-        assert data.dtype == np.uint8
+        data_bytes = data.size * data.itemsize

         with self._lock:
-            new_page = self._cur_page.value + data.size
+            new_page = self._cur_page.value + data_bytes

             if key in self._cache_addr or new_page > self.mem_size:
                 return None

             offset = ct.byref(self._arr, self._cur_page.value)
-            ct.memmove(offset, data.ctypes.data, data.size)
+            ct.memmove(offset, data.ctypes.data, data_bytes)

             self._cache_addr[key] = (
                 self._cur_page.value,
                 data.size,
+                data.dtype,
                 data.shape,
                 data.strides,
                 meta,
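With the dtype stored per cache entry, non-uint8 images (for example float32 after `to_float32=True`) round-trip through the shared buffer correctly. A dependency-free sketch of that round trip, using a plain bytearray in place of the shared ctypes array (illustrative only):

import numpy as np

buf = bytearray(1 << 20)  # stand-in for the shared ctypes array
addr = {}                 # key -> (offset, count, dtype, shape, strides)
cur = 0

def put(key, data):
    global cur
    nbytes = data.size * data.itemsize          # itemsize matters once dtype != uint8
    buf[cur:cur + nbytes] = data.tobytes()
    addr[key] = (cur, data.size, data.dtype, data.shape, data.strides)
    cur += nbytes

def get(key):
    offset, count, dtype, shape, strides = addr[key]
    flat = np.frombuffer(buf, dtype=dtype, count=count, offset=offset)
    return np.lib.stride_tricks.as_strided(flat, shape, strides)

img = np.random.rand(8, 8, 3).astype(np.float32)
put("img_f32", img)
assert np.array_equal(get("img_f32"), img)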

tests/unit/algorithms/common/adapters/mmcv/pipelines/test_load_image_from_otx_dataset.py

+16

@@ -185,3 +185,19 @@ def test_enable_memcache(self, fxt_caching_dataset_cls, fxt_data_list):

        # The second round requires no read.
        assert mock.call_count == 0
+
+
+@pytest.mark.parametrize("mode", ["singleprocessing", "multiprocessing"])
+def test_memcache_image_itemtype(mode):
+    img = (np.random.rand(10, 10, 3) * 255).astype(np.uint8)
+    MemCacheHandlerSingleton.create(mode, img.size * img.itemsize)
+    cache = MemCacheHandlerSingleton.get()
+    cache.put("img_u8", img)
+    img_cached, _ = cache.get("img_u8")
+    assert np.array_equal(img, img_cached)
+    img = np.random.rand(10, 10, 3).astype(np.float)
+    MemCacheHandlerSingleton.create(mode, img.size * img.itemsize)
+    cache = MemCacheHandlerSingleton.get()
+    cache.put("img_f32", img)
+    img_cached, _ = cache.get("img_f32")
+    assert np.array_equal(img, img_cached)
