Skip to content

Commit 447cd9c

Browse files
authored
Sync rgb order between torch and ov inference of action classification task (#3551)
* Sync rgb order between torch and ov inference of action classification task * Fix unit tests * Add error for unsupported color format * Modify unit tests * Revert unnecessasry changes
1 parent b621890 commit 447cd9c

File tree

5 files changed

+45
-4
lines changed

5 files changed

+45
-4
lines changed

src/otx/core/data/dataset/action_classification.py

+37-1
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,55 @@
66
from __future__ import annotations
77

88
from functools import partial
9-
from typing import Callable
9+
from typing import TYPE_CHECKING, Callable
1010

1111
import torch
1212
from datumaro import Label
1313

1414
from otx.core.data.dataset.base import OTXDataset
1515
from otx.core.data.entity.action_classification import ActionClsBatchDataEntity, ActionClsDataEntity
1616
from otx.core.data.entity.base import ImageInfo
17+
from otx.core.data.mem_cache import NULL_MEM_CACHE_HANDLER
18+
from otx.core.types.image import ImageColorChannel
19+
20+
if TYPE_CHECKING:
21+
from datumaro import DatasetSubset
22+
23+
from otx.core.data.dataset.base import Transforms
24+
from otx.core.data.mem_cache import MemCacheHandlerBase
1725

1826

1927
class OTXActionClsDataset(OTXDataset[ActionClsDataEntity]):
2028
"""OTXDataset class for action classification task."""
2129

30+
def __init__(
31+
self,
32+
dm_subset: DatasetSubset,
33+
transforms: Transforms,
34+
mem_cache_handler: MemCacheHandlerBase = NULL_MEM_CACHE_HANDLER,
35+
mem_cache_img_max_size: tuple[int, int] | None = None,
36+
max_refetch: int = 1000,
37+
image_color_channel: ImageColorChannel = ImageColorChannel.BGR,
38+
stack_images: bool = True,
39+
to_tv_image: bool = True,
40+
) -> None:
41+
super().__init__(
42+
dm_subset,
43+
transforms,
44+
mem_cache_handler,
45+
mem_cache_img_max_size,
46+
max_refetch,
47+
image_color_channel,
48+
stack_images,
49+
to_tv_image,
50+
)
51+
# TODO(Someone): ImageColorChannel is not used in action classification task
52+
# This task only supports BGR color format.
53+
# There should be implementation that links between ImageColorChannel and action classification task.
54+
if self.image_color_channel != ImageColorChannel.BGR:
55+
msg = "Action classification task only supports BGR color format."
56+
raise ValueError(msg)
57+
2258
def _get_item_impl(self, idx: int) -> ActionClsDataEntity | None:
2359
item = self.dm_subset[idx]
2460

src/otx/core/data/transform_libs/torchvision.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,10 @@ def _transform(self, inpt: Video, params: dict) -> tv_tensors.Video:
236236
start_index = 0
237237
frame_inds = np.concatenate(frame_inds) + start_index
238238

239-
outputs = torch.stack([torch.tensor(inpt[idx].data) for idx in frame_inds], dim=0)
239+
outputs = torch.stack(
240+
[torch.tensor(cv2.cvtColor(inpt[idx].data, cv2.COLOR_RGB2BGR)) for idx in frame_inds],
241+
dim=0,
242+
)
240243
outputs = outputs.permute(0, 3, 1, 2)
241244
outputs = tv_tensors.Video(outputs)
242245
inpt.close()

src/otx/recipe/_base_/data/mmaction_base.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ config:
55
mem_cache_img_max_size:
66
- 500
77
- 500
8-
image_color_channel: RGB
8+
image_color_channel: BGR
99
stack_images: False
1010
unannotated_items_ratio: 0.0
1111
train_subset:

tests/unit/core/data/test_factory.py

+2
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from otx.core.data.transform_libs.mmpretrain import MMPretrainTransformLib
2626
from otx.core.data.transform_libs.mmseg import MMSegTransformLib
2727
from otx.core.data.transform_libs.torchvision import TorchVisionTransformLib
28+
from otx.core.types.image import ImageColorChannel
2829
from otx.core.types.task import OTXTaskType
2930
from otx.core.types.transformer_libs import TransformLibType
3031

@@ -86,6 +87,7 @@ def test_create(
8687
cfg_data_module.vpm_config = mocker.MagicMock(spec=VisualPromptingConfig)
8788
cfg_data_module.vpm_config.use_bbox = False
8889
cfg_data_module.vpm_config.use_point = False
90+
cfg_data_module.image_color_channel = ImageColorChannel.BGR
8991
mocker.patch.object(HLabelInfo, "from_dm_label_groups", return_value=fxt_mock_hlabelinfo)
9092
assert isinstance(
9193
OTXDatasetFactory.create(

tests/unit/core/data/transform_libs/test_torchvision.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838

3939

4040
class MockFrame:
41-
data = np.ndarray([3, 10, 10])
41+
data = np.ndarray([10, 10, 3], dtype=np.uint8)
4242

4343

4444
class MockVideo:

0 commit comments

Comments
 (0)