Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sync RGB order between torch and OV inference of action classification task #3551

Merged
merged 6 commits into from
May 29, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 37 additions & 1 deletion src/otx/core/data/dataset/action_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,55 @@
from __future__ import annotations

from functools import partial
from typing import Callable
from typing import TYPE_CHECKING, Callable

import torch
from datumaro import Label

from otx.core.data.dataset.base import OTXDataset
from otx.core.data.entity.action_classification import ActionClsBatchDataEntity, ActionClsDataEntity
from otx.core.data.entity.base import ImageInfo
from otx.core.data.mem_cache import NULL_MEM_CACHE_HANDLER
from otx.core.types.image import ImageColorChannel

if TYPE_CHECKING:
from datumaro import DatasetSubset

from otx.core.data.dataset.base import Transforms
from otx.core.data.mem_cache import MemCacheHandlerBase


class OTXActionClsDataset(OTXDataset[ActionClsDataEntity]):
"""OTXDataset class for action classification task."""

def __init__(
    self,
    dm_subset: DatasetSubset,
    transforms: Transforms,
    mem_cache_handler: MemCacheHandlerBase = NULL_MEM_CACHE_HANDLER,
    mem_cache_img_max_size: tuple[int, int] | None = None,
    max_refetch: int = 1000,
    image_color_channel: ImageColorChannel = ImageColorChannel.BGR,
    stack_images: bool = True,
    to_tv_image: bool = True,
) -> None:
    """Initialize the dataset and reject any color order other than BGR.

    Every argument is handed, unchanged and in declaration order, to the
    parent initializer; see the base dataset class for their meaning.

    Raises:
        ValueError: If ``self.image_color_channel`` is not
            ``ImageColorChannel.BGR`` once the parent initializer has run.
    """
    # Forwarded positionally, in exactly the order the parameters are declared.
    parent_args = (
        dm_subset,
        transforms,
        mem_cache_handler,
        mem_cache_img_max_size,
        max_refetch,
        image_color_channel,
        stack_images,
        to_tv_image,
    )
    super().__init__(*parent_args)

    # TODO(Someone): ImageColorChannel is not actually consumed by the action
    # classification task, which only supports the BGR color format. An
    # implementation linking ImageColorChannel to this task is still needed;
    # until then, any other value is refused outright.
    if self.image_color_channel != ImageColorChannel.BGR:
        msg = "Action classification task only supports BGR color format."
        raise ValueError(msg)

def _get_item_impl(self, idx: int) -> ActionClsDataEntity | None:
item = self.dm_subset[idx]

Expand Down
5 changes: 4 additions & 1 deletion src/otx/core/data/transform_libs/torchvision.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,10 @@ def _transform(self, inpt: Video, params: dict) -> tv_tensors.Video:
start_index = 0
frame_inds = np.concatenate(frame_inds) + start_index

outputs = torch.stack([torch.tensor(inpt[idx].data) for idx in frame_inds], dim=0)
outputs = torch.stack(
[torch.tensor(cv2.cvtColor(inpt[idx].data, cv2.COLOR_RGB2BGR)) for idx in frame_inds],
dim=0,
)
outputs = outputs.permute(0, 3, 1, 2)
outputs = tv_tensors.Video(outputs)
inpt.close()
Expand Down
4 changes: 2 additions & 2 deletions src/otx/recipe/_base_/data/mmaction_base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ config:
mem_cache_img_max_size:
- 500
- 500
image_color_channel: RGB
image_color_channel: BGR
stack_images: False
unannotated_items_ratio: 0.0
train_subset:
Expand Down Expand Up @@ -78,7 +78,7 @@ config:
transform_lib_type: MMACTION
to_tv_image: False
batch_size: 8
num_workers: 2
num_workers: 0
transforms:
- type: LoadVideoForClassification
- type: DecordInit
Expand Down
2 changes: 2 additions & 0 deletions tests/unit/core/data/test_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from otx.core.data.transform_libs.mmpretrain import MMPretrainTransformLib
from otx.core.data.transform_libs.mmseg import MMSegTransformLib
from otx.core.data.transform_libs.torchvision import TorchVisionTransformLib
from otx.core.types.image import ImageColorChannel
from otx.core.types.task import OTXTaskType
from otx.core.types.transformer_libs import TransformLibType

Expand Down Expand Up @@ -86,6 +87,7 @@ def test_create(
cfg_data_module.vpm_config = mocker.MagicMock(spec=VisualPromptingConfig)
cfg_data_module.vpm_config.use_bbox = False
cfg_data_module.vpm_config.use_point = False
cfg_data_module.image_color_channel = ImageColorChannel.BGR
mocker.patch.object(HLabelInfo, "from_dm_label_groups", return_value=fxt_mock_hlabelinfo)
assert isinstance(
OTXDatasetFactory.create(
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/core/data/transform_libs/test_torchvision.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@


class MockFrame:
    """Minimal stand-in for a video frame that exposes a ``data`` array."""

    # A 10x10 three-channel uint8 image with channels last. It is
    # zero-initialised so tests see deterministic contents:
    # ``np.ndarray(shape)`` would leave the buffer uninitialised.
    data = np.zeros((10, 10, 3), dtype=np.uint8)


class MockVideo:
Expand Down
Loading