From b219a4338d74a888fb459430943ba5a94bb759f7 Mon Sep 17 00:00:00 2001 From: Yanhong Zeng Date: Fri, 13 Jan 2023 16:47:27 +0800 Subject: [PATCH] [Refactor] Refactor PackEditInputs and EditDataSample (#1573) * refactor packeditinput and editdatasample * move util from formatting to img_utils * fix bugs in ut * use permute instead of transpose in all_to_tensor * remove undesired output results * fix the path of results in ut * use np.ascontiguousarray(img) in image_to_tensor * remove entry function from unit test files * refine sinGAN's config Co-authored-by: LeoXing1996 --- .gitignore | 1 + configs/singan/singan_balloons.py | 7 +- configs/singan/singan_fish.py | 6 +- .../apis/inferencers/inference_functions.py | 2 +- mmedit/datasets/transforms/__init__.py | 2 +- mmedit/datasets/transforms/formatting.py | 355 +++------- .../datasets/transforms/generate_assistant.py | 5 +- mmedit/models/base_models/base_mattor.py | 2 +- mmedit/structures/edit_data_sample.py | 626 +++--------------- mmedit/utils/__init__.py | 6 +- mmedit/utils/img_utils.py | 71 ++ .../test_base_mmedit_inferencer.py | 4 - .../test_conditional_inferencer.py | 7 +- .../test_inpainting_inferencer.py | 6 +- .../test_matting_inferencer.py | 6 +- .../test_mmedit_inferencer.py | 4 - .../test_restoration_inferencer.py | 7 +- .../test_text2image_inferencers.py | 2 +- .../test_translation_inferencer.py | 7 +- .../test_unconditional_inferencer.py | 7 +- .../test_video_interpolation_inferencer.py | 4 - .../test_video_restoration_inferencer.py | 12 +- tests/test_datasets/test_singan_dataset.py | 15 +- .../test_transforms/test_formatting.py | 100 +-- .../test_base_models/test_base_mattor.py | 66 +- .../test_editors/test_dim/test_dim.py | 35 +- .../test_editors/test_gca/test_gca.py | 21 +- .../test_indexnet/test_indexnet.py | 67 +- .../test_structures/test_edit_data_sample.py | 4 - tests/test_utils/test_img_utils.py | 29 +- 30 files changed, 424 insertions(+), 1062 deletions(-) diff --git a/.gitignore b/.gitignore index 2a5bba964c..81448c8213 100644 --- a/.gitignore +++ b/.gitignore @@ -47,6 +47,7 @@ coverage.xml *.cover .hypothesis/ .pytest_cache/ +tests/data/out # Translations *.mo diff --git a/configs/singan/singan_balloons.py b/configs/singan/singan_balloons.py index 0f1dbfb93c..56d08cb093 100644 --- a/configs/singan/singan_balloons.py +++ b/configs/singan/singan_balloons.py @@ -17,8 +17,13 @@ test_pkl_data=test_pkl_data) # DATA +pipeline = [ + dict( + type='PackEditInputs', + keys=[f'real_scale{i}' for i in range(num_scales)] + ['input_sample']) +] data_root = './data/singan/balloons.png' -train_dataloader = dict(dataset=dict(data_root=data_root)) +train_dataloader = dict(dataset=dict(data_root=data_root, pipeline=pipeline)) # HOOKS custom_hooks = [ diff --git a/configs/singan/singan_fish.py b/configs/singan/singan_fish.py index 1ffbf734eb..8b24a9e0cb 100644 --- a/configs/singan/singan_fish.py +++ b/configs/singan/singan_fish.py @@ -41,7 +41,11 @@ dataset_type = 'SinGANDataset' data_root = './data/singan/fish-crop.jpg' -pipeline = [dict(type='PackEditInputs', pack_all=True)] +pipeline = [ + dict( + type='PackEditInputs', + keys=[f'real_scale{i}' for i in range(num_scales)] + ['input_sample']) +] dataset = dict( type=dataset_type, data_root=data_root, diff --git a/mmedit/apis/inferencers/inference_functions.py b/mmedit/apis/inferencers/inference_functions.py index f6bc78656f..a0cf6e77d5 100644 --- a/mmedit/apis/inferencers/inference_functions.py +++ b/mmedit/apis/inferencers/inference_functions.py @@ -303,7 +303,7 @@ def 
matting_inference(model, img, trimap): # prepare data data = dict(merged_path=img, trimap_path=trimap) _data = test_pipeline(data) - trimap = _data['data_samples'].trimap.data + trimap = _data['data_samples'].trimap data = dict() data['inputs'] = torch.cat([_data['inputs'], trimap], dim=0).float() data = collate([data]) diff --git a/mmedit/datasets/transforms/__init__.py b/mmedit/datasets/transforms/__init__.py index 39543085f1..cb0da835f2 100644 --- a/mmedit/datasets/transforms/__init__.py +++ b/mmedit/datasets/transforms/__init__.py @@ -11,7 +11,7 @@ RandomResizedCrop) from .fgbg import (CompositeFg, MergeFgAndBg, PerturbBg, RandomJitter, RandomLoadResizeBg) -from .formatting import PackEditInputs, ToTensor +from .formatting import PackEditInputs from .generate_assistant import (GenerateCoordinateAndCell, GenerateFacialHeatmap) from .generate_frame_indices import (GenerateFrameIndices, diff --git a/mmedit/datasets/transforms/formatting.py b/mmedit/datasets/transforms/formatting.py index 5df741884d..b525257d4b 100644 --- a/mmedit/datasets/transforms/formatting.py +++ b/mmedit/datasets/transforms/formatting.py @@ -1,128 +1,52 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Any, List, Tuple +from typing import List, Tuple -import numpy as np -import torch -from mmcv.transforms import to_tensor from mmcv.transforms.base import BaseTransform from mmedit.registry import TRANSFORMS -from mmedit.structures import EditDataSample, PixelData - - -def check_if_image(value: Any) -> bool: - """Check if the input value is image or images. - - If value is a list or Tuple, - recursively check if each element in ``value`` is image. - - Args: - value (Any): The value to be checked. - - Returns: - bool: If the value is image or sequence of images. - """ - - if isinstance(value, (List, Tuple)): - is_image = (len(value) > 0) - for v in value: - is_image = is_image and check_if_image(v) - - else: - is_image = isinstance(value, np.ndarray) and len(value.shape) > 1 - - return is_image - - -def image_to_tensor(img): - """Trans image to tensor. - - Args: - img (np.ndarray): The original image. - - Returns: - Tensor: The output tensor. - """ - - if len(img.shape) < 3: - img = np.expand_dims(img, -1) - img = np.ascontiguousarray(img.transpose(2, 0, 1)) - tensor = to_tensor(img) - - return tensor - - -def images_to_tensor(value): - """Trans image and sequence of frames to tensor. - - Args: - value (np.ndarray | list[np.ndarray] | Tuple[np.ndarray]): - The original image or list of frames. - - Returns: - Tensor: The output tensor. - """ - - if isinstance(value, (List, Tuple)): - # sequence of frames - frames = [image_to_tensor(v) for v in value] - tensor = torch.stack(frames, dim=0) - elif isinstance(value, np.ndarray): - tensor = image_to_tensor(value) - else: - # Maybe the data has been converted to Tensor. - tensor = to_tensor(value) - - return tensor - - -def can_convert_to_image(value): - """Judge whether the input value can be converted to image tensor via - :func:`images_to_tensor` function. - - Args: - value (any): The input value. - - Returns: - bool: If true, the input value can convert to image with - :func:`images_to_tensor`, and vice versa. 
- """ - if isinstance(value, (List, Tuple)): - return all([can_convert_to_image(v) for v in value]) - elif isinstance(value, np.ndarray): - return True - elif isinstance(value, torch.Tensor): - return True - else: - return False +from mmedit.structures import EditDataSample +from mmedit.utils import all_to_tensor @TRANSFORMS.register_module() class PackEditInputs(BaseTransform): - """Pack the inputs data for SR, VFI, matting and inpainting. + """Pack data into EditDataSample for training, evaluation and testing. - Keys for images include ``img``, ``gt``, ``ref``, ``mask``, ``gt_heatmap``, - ``trimap``, ``gt_alpha``, ``gt_fg``, ``gt_bg``. All of them will be - packed into data field of EditDataSample. - pack_all (bool): Whether pack all variables in `results` to `inputs` dict. - This is useful when keys of the input dict is not fixed. - Please be careful when using this function, because we do not - Defaults to False. + MMediting follows the design of data structure from MMEngine. + Data from the loader will be packed into data field of EditDataSample. + More details of DataSample refer to the documentation of MMEngine: + https://mmengine.readthedocs.io/en/latest/advanced_tutorials/data_element.html - Others will be packed into metainfo field of EditDataSample. + Args: + keys Tuple[List[str], str, None]: The keys to saved in returned + inputs, which are used as the input of models, default to + ['img', 'noise', 'merged']. + data_keys Tuple[List[str], str, None]: The keys to saved in + `data_field` of the `data_samples`. + meta_keys Tuple[List[str], str, None]: The meta keys to saved + in `metainfo` of the `data_samples`. All the other data will + be packed into the data of the `data_samples` """ - def __init__(self, - keys: Tuple[List[str], str, None] = None, - pack_all: bool = False): - if keys is not None: - if isinstance(keys, list): - self.keys = keys - else: - self.keys = [keys] - else: - self.keys = None - self.pack_all = pack_all + def __init__( + self, + keys: Tuple[List[str], str] = ['merged', 'img'], + meta_keys: Tuple[List[str], str] = [], + data_keys: Tuple[List[str], str] = [], + ) -> None: + + assert keys is not None, \ + 'keys in PackEditInputs can not be None.' + assert data_keys is not None, \ + 'data_keys in PackEditInputs can not be None.' + assert meta_keys is not None, \ + 'meta_keys in PackEditInputs can not be None.' + + self.keys = keys if isinstance(keys, List) else [keys] + self.data_keys = data_keys if isinstance(data_keys, + List) else [data_keys] + self.meta_keys = meta_keys if isinstance(meta_keys, + List) else [meta_keys] def transform(self, results: dict) -> dict: """Method to pack the input data. @@ -131,185 +55,54 @@ def transform(self, results: dict) -> dict: results (dict): Result dict from the data pipeline. Returns: - dict: + dict: A dict contains + + - 'inputs' (obj:`dict`): The forward data of models. + According to different tasks, the `inputs` may contain images, + videos, labels, text, etc. - - 'inputs' (obj:`torch.Tensor`): The forward data of models. - 'data_samples' (obj:`EditDataSample`): The annotation info of the sample. 
""" - packed_results = dict() - data_sample = EditDataSample() - - pack_keys = [k for k in results.keys()] if self.pack_all else self.keys - if pack_keys is not None: - packed_results['inputs'] = dict() - for key in pack_keys: - val = results[key] - if can_convert_to_image(val): - packed_results['inputs'][key] = images_to_tensor(val) - results.pop(key) - - elif 'img' in results: - img = results.pop('img') - img_tensor = images_to_tensor(img) - packed_results['inputs'] = img_tensor - data_sample.input = PixelData(data=img_tensor.clone()) - - if 'gt' in results: - gt = results.pop('gt') - gt_tensor = images_to_tensor(gt) - if len(gt_tensor.shape) > 3 and gt_tensor.size(0) == 1: - gt_tensor.squeeze_(0) - data_sample.gt_img = PixelData(data=gt_tensor) - - if 'gt_label' in results: - gt_label = results.pop('gt_label') - data_sample.set_gt_label(gt_label) + # prepare inputs + inputs = dict() + for k in self.keys: + value = results.get(k, None) + if value is not None: + inputs[k] = all_to_tensor(value) - if 'img_lq' in results: - img_lq = results.pop('img_lq') - img_lq_tensor = images_to_tensor(img_lq) - data_sample.img_lq = PixelData(data=img_lq_tensor) + # return the inputs as tensor, if it has only one item + if len(inputs.values()) == 1: + inputs = list(inputs.values())[0] - if 'ref' in results: - ref = results.pop('ref') - ref_tensor = images_to_tensor(ref) - data_sample.ref_img = PixelData(data=ref_tensor) - - if 'ref_lq' in results: - ref_lq = results.pop('ref_lq') - ref_lq_tensor = images_to_tensor(ref_lq) - data_sample.ref_lq = PixelData(data=ref_lq_tensor) - - if 'mask' in results: - mask = results.pop('mask') - mask_tensor = images_to_tensor(mask) - data_sample.mask = PixelData(data=mask_tensor) - - if 'gt_heatmap' in results: - gt_heatmap = results.pop('gt_heatmap') - gt_heatmap_tensor = images_to_tensor(gt_heatmap) - data_sample.gt_heatmap = PixelData(data=gt_heatmap_tensor) - - if 'gt_unsharp' in results: - gt_unsharp = results.pop('gt_unsharp') - gt_unsharp_tensor = images_to_tensor(gt_unsharp) - data_sample.gt_unsharp = PixelData(data=gt_unsharp_tensor) - - if 'merged' in results: - # image in matting annotation is named merged - img = results.pop('merged') - img_tensor = images_to_tensor(img) - # used for model inputs - packed_results['inputs'] = img_tensor - # used as ground truth for composition losses - data_sample.gt_merged = PixelData(data=img_tensor.clone()) - - if 'trimap' in results: - trimap = results.pop('trimap') - trimap_tensor = images_to_tensor(trimap) - data_sample.trimap = PixelData(data=trimap_tensor) - - if 'alpha' in results: - # gt_alpha in matting annotation is named alpha - gt_alpha = results.pop('alpha') - gt_alpha_tensor = images_to_tensor(gt_alpha) - data_sample.gt_alpha = PixelData(data=gt_alpha_tensor) - - if 'fg' in results: - # gt_fg in matting annotation is named fg - gt_fg = results.pop('fg') - gt_fg_tensor = images_to_tensor(gt_fg) - data_sample.gt_fg = PixelData(data=gt_fg_tensor) - - if 'bg' in results: - # gt_bg in matting annotation is named bg - gt_bg = results.pop('bg') - gt_bg_tensor = images_to_tensor(gt_bg) - data_sample.gt_bg = PixelData(data=gt_bg_tensor) - - if 'rgb_img' in results: - gt_rgb = results.pop('rgb_img') - gt_rgb_tensor = images_to_tensor(gt_rgb) - data_sample.gt_rgb = PixelData(data=gt_rgb_tensor) - - if 'gray_img' in results: - gray = results.pop('gray_img') - gray_tensor = images_to_tensor(gray) - data_sample.gray = PixelData(data=gray_tensor) - - if 'cropped_img' in results: - cropped_img = results.pop('cropped_img') - 
cropped_img = images_to_tensor(cropped_img) - data_sample.cropped_img = PixelData(data=cropped_img) - - metainfo = dict() - for key in results: - metainfo[key] = results[key] - - data_sample.set_metainfo(metainfo=metainfo) - - packed_results['data_samples'] = data_sample - - return packed_results + data_sample = EditDataSample() + # prepare metainfo and data in DataSample according to predefined keys + predefined_data = { + k: v + for (k, v) in results.items() + if k not in (self.data_keys + self.meta_keys) + } + data_sample.set_predefined_data(predefined_data) + + # prepare metainfo in DataSample according to user-provided meta_keys + required_metainfo = { + k: v + for (k, v) in results.items() if k in self.meta_keys + } + data_sample.set_metainfo(required_metainfo) + + # prepare data in DataSample according to user-provided data_keys + required_data = { + k: v + for (k, v) in results.items() if k in self.data_keys + } + data_sample.set_tensor_data(required_data) + + return {'inputs': inputs, 'data_samples': data_sample} def __repr__(self) -> str: repr_str = self.__class__.__name__ return repr_str - - -@TRANSFORMS.register_module() -class ToTensor(BaseTransform): - """Convert some values in results dict to `torch.Tensor` type in data - loader pipeline. - - Args: - keys (Sequence[str]): Required keys to be converted. - to_float32 (bool): Whether convert tensors of images to float32. - Default: True. - """ - - def __init__(self, keys, to_float32=True): - - self.keys = keys - self.to_float32 = to_float32 - - def _data_to_tensor(self, value): - """Convert the value to tensor.""" - is_image = check_if_image(value) - - if is_image: - tensor = images_to_tensor(value) - if self.to_float32: - tensor = tensor.float() - if len(tensor.shape) > 3 and tensor.size(0) == 1: - tensor.squeeze_(0) - - else: - tensor = to_tensor(value) - - return tensor - - def transform(self, results): - """transform function. - - Args: - results (dict): A dict containing the necessary information and - data for augmentation. - - Returns: - dict: A dict containing the processed data and information.
- """ - - for key in self.keys: - results[key] = self._data_to_tensor(results[key]) - - return results - - def __repr__(self): - - return self.__class__.__name__ + ( - f'(keys={self.keys}, to_float32={self.to_float32})') diff --git a/mmedit/datasets/transforms/generate_assistant.py b/mmedit/datasets/transforms/generate_assistant.py index 202f95fcce..f305b34d0f 100644 --- a/mmedit/datasets/transforms/generate_assistant.py +++ b/mmedit/datasets/transforms/generate_assistant.py @@ -4,8 +4,7 @@ from mmcv.transforms.base import BaseTransform from mmedit.registry import TRANSFORMS -from mmedit.utils import make_coord -from .formatting import images_to_tensor +from mmedit.utils import all_to_tensor, make_coord try: import face_alignment @@ -85,7 +84,7 @@ def transform(self, results): # generate hr_coord (and hr_rgb) if 'gt' in results: crop_hr = results['gt'] - crop_hr = images_to_tensor(crop_hr) + crop_hr = all_to_tensor(crop_hr) self.target_size = crop_hr.shape if self.reshape_gt: hr_rgb = crop_hr.contiguous().view(3, -1).permute(1, 0) diff --git a/mmedit/models/base_models/base_mattor.py b/mmedit/models/base_models/base_mattor.py index 65b331cc4a..720a5baee8 100644 --- a/mmedit/models/base_models/base_mattor.py +++ b/mmedit/models/base_models/base_mattor.py @@ -192,7 +192,7 @@ def postprocess( pa = pa[0] # H, W pa.clamp_(min=0, max=1) - ori_trimap = ds.ori_trimap[:, :, 0] # H, W + ori_trimap = ds.ori_trimap[0, :, :] # H, W pa[ori_trimap == 255] = 1 pa[ori_trimap == 0] = 0 diff --git a/mmedit/structures/edit_data_sample.py b/mmedit/structures/edit_data_sample.py index 22208a2131..963c611d01 100644 --- a/mmedit/structures/edit_data_sample.py +++ b/mmedit/structures/edit_data_sample.py @@ -6,9 +6,8 @@ import numpy as np import torch from mmengine.structures import BaseDataElement, LabelData -from torch import Tensor -from .pixel_data import PixelData +from mmedit.utils import all_to_tensor def format_label(value: Union[torch.Tensor, np.ndarray, Sequence, int], @@ -52,7 +51,16 @@ def format_label(value: Union[torch.Tensor, np.ndarray, Sequence, int], class EditDataSample(BaseDataElement): """A data structure interface of MMEditing. They are used as interfaces - between different components. + between different components, e.g., model, visualizer, evaluator, etc. + Typically, EditDataSample contains all the information and data from + ground-truth and predictions. + + `EditDataSample` inherits from `BaseDataElement`. See more details in: + https://mmengine.readthedocs.io/en/latest/advanced_tutorials/data_element.html + Specifically, an instance of BaseDataElement consists of two components, + - ``metainfo``, which contains some meta information, + e.g., `img_shape`, `img_id`, etc. + - ``data``, which contains the data used in the loop. The attributes in ``EditDataSample`` are divided into several parts: @@ -112,542 +120,77 @@ class EditDataSample(BaseDataElement): ) at 0x1f6a5a99a00> """ - @property - def gt_img(self) -> PixelData: - """This is the function to fetch gt_img in PixelData. - - Returns: - PixelData: data element - """ - return self._gt_img - - @gt_img.setter - def gt_img(self, value: PixelData): - """This is the function used to set gt_img in PixelData. - - Args: - value (PixelData): data element - """ - self.set_field(value, '_gt_img', dtype=PixelData) - - @gt_img.deleter - def gt_img(self): - """This is the function to fetch gt_img.""" - del self._gt_img - - @property - def gt_samples(self) -> 'EditDataSample': - """This is the function to fetch gt_samples. 
- - Returns: - EditDataSample: gt samples. - """ - return self._gt_samples - - @gt_samples.setter - def gt_samples(self, value: 'EditDataSample'): - """This is the function to set gt_samples. - - Args: - value (EditDataSample): gt samples. - """ - self.set_field(value, '_gt_samples', dtype=EditDataSample) - - @gt_samples.deleter - def gt_samples(self): - """This is the function to delete gt_samples.""" - del self._gt_samples - - @property - def noise(self) -> torch.Tensor: - """This is the function to fetch noise. - - Returns: - torch.Tensor: noise. - """ - return self._noise - - @noise.setter - def noise(self, value: PixelData): - """This is the function to set noise. - - Args: - value (PixelData): noise. - """ - self.set_field(value, '_noise', dtype=torch.Tensor) - - @noise.deleter - def noise(self): - """This is the functionto delete noise.""" - del self._noise - - @property - def pred_img(self) -> PixelData: - """This is the function to fetch pred_img in PixelData. - - Returns: - PixelData: data element - """ - return self._pred_img - - @pred_img.setter - def pred_img(self, value: PixelData): - """This is the function to set the value of pred_img in PixelData. - - Args: - value (PixelData): data element - """ - self.set_field(value, '_pred_img', dtype=PixelData) - - @pred_img.deleter - def pred_img(self): - """This is the function to fetch pred_img.""" - del self._pred_img - - @property - def fake_img(self) -> Union[PixelData, Tensor]: - """This is the function to fetch fake_img. - - Returns: - Union[PixelData, Tensor]: The fake img. - """ - return self._fake_img - - @fake_img.setter - def fake_img(self, value: Union[PixelData, Tensor]): - """This is the function to set fake_img. - - Args: - value (Union[PixelData, Tensor]): The value of fake img. - """ - assert isinstance(value, (PixelData, Tensor)) - if isinstance(value, PixelData): - self.set_field(value, '_fake_img', dtype=PixelData) - else: - self.set_field(value, '_fake_img', dtype=Tensor) - - @fake_img.deleter - def fake_img(self): - """This is the function to delete fake_img.""" - del self._fake_img - - @property - def img_lq(self) -> PixelData: - """This is the function to fetch img_lq in PixelData. - - Returns: - PixelData: data element - """ - return self._img_lq - - @img_lq.setter - def img_lq(self, value: PixelData): - """This is the function to set img_lq in PixelData. - - Args: - value (PixelData): data element - """ - self.set_field(value, '_img_lq', dtype=PixelData) - - @img_lq.deleter - def img_lq(self): - """This is the function to delete img_lq.""" - del self._img_lq - - @property - def ref_img(self) -> PixelData: - """This is the function to fetch ref_img. - - Returns: - PixelData: data element - """ - return self._ref_img - - @ref_img.setter - def ref_img(self, value: PixelData): - """This is the function to set the value of ref_img. - - Args: - value (PixelData): data element - """ - self.set_field(value, '_ref_img', dtype=PixelData) - - @ref_img.deleter - def ref_img(self): - """This is the function to fetch ref_img.""" - del self._ref_img - - @property - def ref_lq(self) -> PixelData: - """This is the function to fetch ref_lq. - - Returns: - PixelData: data element - """ - return self._ref_lq - - @ref_lq.setter - def ref_lq(self, value: PixelData): - """This is the function to set the value of ref_lq. 
- - Args: - value (PixelData): data element - """ - self.set_field(value, '_ref_lq', dtype=PixelData) - - @ref_lq.deleter - def ref_lq(self): - """This is the function to delete ref_lq.""" - del self._ref_lq - - @property - def gt_unsharp(self) -> PixelData: - """This is the function to fetch gt_unsharp in PixelData. - - Returns: - PixelData: data element - """ - return self._gt_unsharp - - @gt_unsharp.setter - def gt_unsharp(self, value: PixelData): - """This is the function to set the value of gt_unsharp. - - Args: - value (PixelData): base element - """ - self.set_field(value, '_gt_unsharp', dtype=PixelData) - - @gt_unsharp.deleter - def gt_unsharp(self): - """This is the function to delete gt_unsharp.""" - del self._gt_unsharp - - @property - def mask(self) -> PixelData: - """This is the function to fetch mask. - - Returns: - PixelData: data element - """ - return self._mask - - @mask.setter - def mask(self, value: Union[PixelData, Tensor]): - """This is the function to set the value of mask. - - Args: - value (Union[PixelData, Tensor]): data element - """ - assert isinstance(value, (PixelData, Tensor)) - if isinstance(value, PixelData): - self.set_field(value, '_mask', dtype=PixelData) - else: - self.set_field(value, '_mask', dtype=Tensor) - - @mask.deleter - def mask(self): - """This is the function to delete mask.""" - del self._mask - - @property - def gt_heatmap(self) -> PixelData: - """This is the function to fetch gt_heatmap. - - Returns: - PixelData: data element - """ - return self._gt_heatmap - - @gt_heatmap.setter - def gt_heatmap(self, value: PixelData): - """This is the function to set the value of gt_heatmap. - - Args: - value (PixelData): data element - """ - self.set_field(value, '_gt_heatmap', dtype=PixelData) - - @gt_heatmap.deleter - def gt_heatmap(self): - """This is the function to delete gt_heatmap.""" - del self._gt_heatmap - - @property - def pred_heatmap(self) -> PixelData: - """This is the function to fetch pred_heatmap. - - Returns: - PixelData: data element - """ - return self._pred_heatmap - - @pred_heatmap.setter - def pred_heatmap(self, value: PixelData): - """This is the function to set the value of pred_heatmap. - - Args: - value (PixelData): data element - """ - self.set_field(value, '_pred_heatmap', dtype=PixelData) - - @pred_heatmap.deleter - def pred_heatmap(self): - """This is the function to fetch pred_heatmap.""" - del self._pred_heatmap - - @property - def trimap(self) -> PixelData: - """This is the function to fetch trimap. - - Returns: - PixelData: data element - """ - return self._trimap - - @trimap.setter - def trimap(self, value: PixelData): - """This is the function to set the value of trimap. - - Args: - value (PixelData): data element - """ - self.set_field(value, '_trimap', dtype=PixelData) - - @trimap.deleter - def trimap(self): - """This is the function to delete trimap.""" - del self._trimap - - @property - def gt_alpha(self) -> PixelData: - """This is the function to fetch gt_alpha. - - Returns: - PixelData: data element - """ - return self._gt_alpha - - @gt_alpha.setter - def gt_alpha(self, value: PixelData): - """This is the function to set the value of gt_alpha. - - Args: - value (PixelData): data element - """ - self.set_field(value, '_gt_alpha', dtype=PixelData) - - @gt_alpha.deleter - def gt_alpha(self): - """This is the function to delete gt_alpha.""" - del self._gt_alpha - - @property - def pred_alpha(self) -> PixelData: - """This is the function to fetch pred_alpha. 
- - Returns: - PixelData: data element - """ - return self._pred_alpha - - @pred_alpha.setter - def pred_alpha(self, value: PixelData): - """This is the function to set the value of pred_alpha. - - Args: - value (PixelData): data element - """ - self.set_field(value, '_pred_alpha', dtype=PixelData) - - @pred_alpha.deleter - def pred_alpha(self): - """This is the function to delete pred_alpha.""" - del self._pred_alpha - - @property - def gt_fg(self) -> PixelData: - """This is the function to fetch gt_fg. - - Returns: - PixelData: data element - """ - return self._gt_fg - - @gt_fg.setter - def gt_fg(self, value: PixelData): - """This is the function to set the value of gt_fg. - - Args: - value (PixelData): data element - """ - self.set_field(value, '_gt_fg', dtype=PixelData) - - @gt_fg.deleter - def gt_fg(self): - """This is the function to delete gt_fg.""" - del self._gt_fg - - @property - def pred_fg(self) -> PixelData: - """This is the function to fetch pred_fg. - - Returns: - PixelData: _description_ - """ - return self._pred_fg - - @pred_fg.setter - def pred_fg(self, value: PixelData): - """This is the function to set the value of pred_fg in PixelData. - - Args: - value (PixelData): data element - """ - self.set_field(value, '_pred_fg', dtype=PixelData) - - @pred_fg.deleter - def pred_fg(self): - """This is the function to delete pred_fg.""" - del self._pred_fg - - @property - def gt_bg(self) -> PixelData: - """This is the function to fetch gt_bg. - - Returns: - PixelData: data element - """ - return self._gt_bg - - @gt_bg.setter - def gt_bg(self, value: PixelData): - """This is the function to set the value of gt_bg in PixelData. - - Args: - value (PixelData): data element - """ - self.set_field(value, '_gt_bg', dtype=PixelData) - - @gt_bg.deleter - def gt_bg(self): - """This is the function to delete gt_bg.""" - del self._gt_bg - - @property - def pred_bg(self) -> PixelData: - """This is the function to fetch pred_bg in PixelData. - - Returns: - PixelData: data element - """ - return self._pred_bg - - @pred_bg.setter - def pred_bg(self, value: PixelData): - """This is the function to set the value of pred_bg in PixelData. - - Args: - value (PixelData): data element - """ - self.set_field(value, '_pred_bg', dtype=PixelData) - - @pred_bg.deleter - def pred_bg(self): - """This is the function to fetch pred_bg.""" - del self._pred_bg - - @property - def gt_merged(self) -> PixelData: - """This is the function to fetch gt_merged in PixelData. - - Returns: - PixelData: _description_ - """ - return self._gt_merged - - @gt_merged.setter - def gt_merged(self, value: PixelData): - """This is the function to set gt_merged in PixelDate. - - Args: - value (PixelData): data element - """ - self.set_field(value, '_gt_merged', dtype=PixelData) - - @gt_merged.deleter - def gt_merged(self): - """This is the function to fetch gt_merged.""" - del self._gt_merged - - @property - def sample_model(self) -> str: - """This is the function to fetch sample model. - - Returns: - str: Mode of Sample model. - """ - return self._sample_model - - @sample_model.setter - def sample_model(self, value: str): - """This is the function to set sample model. - - Args: - value (str): The mode of sample model. - """ - self.set_field(value, '_sample_model', dtype=str) - - @sample_model.deleter - def sample_model(self): - """This is the function to delete sample model.""" - del self._sample_model - - @property - def ema(self) -> 'EditDataSample': - """This is the function to fetch ema results. 
- - Returns: - EditDataSample: Results of the ema model. - """ - return self._ema - - @ema.setter - def ema(self, value: 'EditDataSample'): - """This is the function to set ema results. - - Args: - value (EditDataSample): Results of the ema model. - """ - self.set_field(value, '_ema', dtype=EditDataSample) - - @ema.deleter - def ema(self): - """This is the function to delete ema results.""" - del self._ema - - @property - def orig(self) -> 'EditDataSample': - """This is the function to fetch original results. - - Returns: - EditDataSample: Results of the ema model. - """ - return self._orig - - @orig.setter - def orig(self, value: 'EditDataSample'): - """This is the function to set ema results. - - Args: - value (EditDataSample): Results of the ema model. - """ - self.set_field(value, '_orig', dtype=EditDataSample) - - @orig.deleter - def orig(self): - """This is the function to delete ema results.""" - del self._orig + # source_key_in_results: target_key_in_metainfo + META_KEYS = { + 'img_path': 'img_path', + 'merged_path': 'merged_path', + 'trimap_path': 'trimap_path', + 'ori_shape': 'ori_shape', + 'img_shape': 'img_shape', + 'ori_merged_shape': 'ori_merged_shape', + 'ori_trimap_shape': 'ori_trimap_shape', + 'trimap_channel_order': 'trimap_channel_order', + 'empty_box': 'empty_box' + } + + # source_key_in_results: target_key_in_datafield + DATA_KEYS = { + 'gt': 'gt_img', + 'gt_label': 'gt_label', + 'gt_heatmap': 'gt_heatmap', + 'gt_unsharp': 'gt_unsharp', + 'merged': 'gt_merged', + 'fg': 'gt_fg', + 'bg': 'gt_bg', + 'gt_rgb': 'gt_rgb', + 'alpha': 'gt_alpha', + 'img_lq': 'img_lq', + 'ref': 'ref_img', + 'ref_lq': 'ref_lq', + 'mask': 'mask', + 'trimap': 'trimap', + 'gray': 'gray', + 'cropped_img': 'cropped_img', + 'pred_img': 'pred_img', + 'ori_trimap': 'ori_trimap' + } + + def set_predefined_data(self, data: dict) -> None: + """Set or change pre-defined key-value pairs in ``data_field`` by + parameter ``data``. + + Args: + data (dict): A dict containing annotations of image or + model predictions. + """ + + metainfo = { + self.META_KEYS[k]: v + for (k, v) in data.items() if k in self.META_KEYS + } + self.set_metainfo(metainfo) + + data = { + self.DATA_KEYS[k]: v + for (k, v) in data.items() if k in self.DATA_KEYS + } + self.set_tensor_data(data) + + def set_tensor_data(self, data: dict) -> None: + """Convert input data to tensor, and then set or change key-value + pairs in ``data_field`` by parameter ``data``. + + Args: + data (dict): A dict containing annotations of image or + model predictions. + """ + assert isinstance(data, + dict), f'data should be a `dict` but got {data}' + for k, v in data.items(): + if k == 'gt_label': + self.set_gt_label(v) + else: + setattr(self, k, all_to_tensor(v)) def set_gt_label( self, value: Union[np.ndarray, torch.Tensor, Sequence[Number], Number] @@ -677,8 +220,3 @@ def gt_label(self, value: LabelData): value (LabelData): gt label. """ self.set_field(value, '_gt_label', dtype=LabelData) - - @gt_label.deleter - def gt_label(self): - """This is the function to delete gt label.""" - del self._gt_label diff --git a/mmedit/utils/__init__.py index ce14853774..58ed14cc58 100644 --- a/mmedit/utils/__init__.py +++ b/mmedit/utils/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved.
from .cli import modify_args -from .img_utils import get_box_info, reorder_image, tensor2img, to_numpy +from .img_utils import (all_to_tensor, can_convert_to_image, get_box_info, + reorder_image, tensor2img, to_numpy) from .io_utils import MMEDIT_CACHE_DIR, download_from_url # TODO replace with engine's API from .logger import print_colored_log @@ -17,5 +18,6 @@ 'MMEDIT_CACHE_DIR', 'download_from_url', 'get_sampler', 'tensor2img', 'random_choose_unknown', 'add_gaussian_noise', 'adjust_gamma', 'make_coord', 'bbox2mask', 'brush_stroke_mask', 'get_irregular_mask', - 'random_bbox', 'reorder_image', 'to_numpy', 'get_box_info' + 'random_bbox', 'reorder_image', 'to_numpy', 'get_box_info', + 'can_convert_to_image', 'all_to_tensor' ] diff --git a/mmedit/utils/img_utils.py index ff5b2c1f03..70584cc591 100644 --- a/mmedit/utils/img_utils.py +++ b/mmedit/utils/img_utils.py @@ -1,11 +1,82 @@ # Copyright (c) OpenMMLab. All rights reserved. import math +from typing import List, Tuple import numpy as np import torch +from mmcv.transforms import to_tensor from torchvision.utils import make_grid +def can_convert_to_image(value): + """Judge whether the input value can be converted to image tensor via + :func:`all_to_tensor` function. + + Args: + value (any): The input value. + + Returns: + bool: If True, the input value can be converted to an image tensor + with :func:`all_to_tensor`, and vice versa. + """ + if isinstance(value, (List, Tuple)): + return all([can_convert_to_image(v) for v in value]) + elif isinstance(value, np.ndarray) and len(value.shape) > 1: + return True + elif isinstance(value, torch.Tensor): + return True + else: + return False + + +def image_to_tensor(img): + """Convert an image to a tensor. + + Args: + img (np.ndarray): The original image. + + Returns: + Tensor: The output tensor. + """ + + if len(img.shape) < 3: + img = np.expand_dims(img, -1) + img = np.ascontiguousarray(img) + tensor = to_tensor(img).permute(2, 0, 1).contiguous() + + return tensor + + +def all_to_tensor(value): + """Convert an image or a sequence of frames to a tensor. + + Args: + value (np.ndarray | list[np.ndarray] | Tuple[np.ndarray]): + The original image or list of frames. + + Returns: + Tensor: The output tensor. + """ + + if not can_convert_to_image(value): + return value + + if isinstance(value, (List, Tuple)): + # sequence of frames + if len(value) == 1: + tensor = image_to_tensor(value[0]) + else: + frames = [image_to_tensor(v) for v in value] + tensor = torch.stack(frames, dim=0) + elif isinstance(value, np.ndarray): + tensor = image_to_tensor(value) + else: + # Maybe the data has been converted to Tensor. + tensor = to_tensor(value) + + return tensor + + def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)): """Convert torch Tensors into image numpy arrays.
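Usage sketch for the relocated tensor helpers above (not part of the patch itself; the array shapes and values are illustrative assumptions):

    import numpy as np
    import torch

    from mmedit.utils import all_to_tensor, can_convert_to_image

    # An HWC ndarray counts as an image and becomes a CHW tensor.
    img = np.random.rand(64, 64, 3)
    assert can_convert_to_image(img)
    assert all_to_tensor(img).shape == torch.Size([3, 64, 64])

    # A sequence of frames is stacked along a new leading dimension...
    frames = [np.random.rand(64, 64, 3) for _ in range(5)]
    assert all_to_tensor(frames).shape == torch.Size([5, 3, 64, 64])

    # ...unless it holds a single frame, which collapses to one image tensor.
    assert all_to_tensor([img]).shape == torch.Size([3, 64, 64])

    # Values that fail `can_convert_to_image` pass through unchanged.
    assert all_to_tensor('annotation.txt') == 'annotation.txt'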
diff --git a/tests/test_apis/test_inferencers/test_base_mmedit_inferencer.py b/tests/test_apis/test_inferencers/test_base_mmedit_inferencer.py index 2c721868c6..54edff049e 100644 --- a/tests/test_apis/test_inferencers/test_base_mmedit_inferencer.py +++ b/tests/test_apis/test_inferencers/test_base_mmedit_inferencer.py @@ -23,7 +23,3 @@ def test_base_mmedit_inferencer(): inferencer_instance = BaseMMEditInferencer(cfg, None) extra_parameters = inferencer_instance.get_extra_parameters() assert len(extra_parameters) == 0 - - -if __name__ == '__main__': - test_base_mmedit_inferencer() diff --git a/tests/test_apis/test_inferencers/test_conditional_inferencer.py b/tests/test_apis/test_inferencers/test_conditional_inferencer.py index 84f146e319..6bf8d5e148 100644 --- a/tests/test_apis/test_inferencers/test_conditional_inferencer.py +++ b/tests/test_apis/test_inferencers/test_conditional_inferencer.py @@ -13,7 +13,8 @@ def test_conditional_inferencer(): osp.dirname(__file__), '..', '..', '..', 'configs', 'sngan_proj', 'sngan-proj_woReLUinplace_lr2e-4-ndisc5-1xb64_cifar10-32x32.py') result_out_dir = osp.join( - osp.dirname(__file__), '..', '..', 'data', 'conditional_result.png') + osp.dirname(__file__), '..', '..', 'data/out', + 'conditional_result.png') inferencer_instance = \ ConditionalInferencer(cfg, None, @@ -23,7 +24,3 @@ def test_conditional_inferencer(): label=1, result_out_dir=result_out_dir) result_img = inference_result[1] assert result_img.shape == (4, 3, 32, 32) - - -if __name__ == '__main__': - test_conditional_inferencer() diff --git a/tests/test_apis/test_inferencers/test_inpainting_inferencer.py b/tests/test_apis/test_inferencers/test_inpainting_inferencer.py index a8a655f662..df184caa02 100644 --- a/tests/test_apis/test_inferencers/test_inpainting_inferencer.py +++ b/tests/test_apis/test_inferencers/test_inpainting_inferencer.py @@ -21,7 +21,7 @@ def test_inpainting_inferencer(): 'aot-gan_smpgan_4xb4_places-512x512.py', ) result_out_dir = osp.join( - osp.dirname(__file__), '..', '..', 'data', 'inpainting_result.png') + osp.dirname(__file__), '..', '..', 'data/out', 'inpainting_result.png') inferencer_instance = \ InpaintingInferencer(cfg, None) @@ -30,7 +30,3 @@ def test_inpainting_inferencer(): img=masked_img_path, mask=mask_path, result_out_dir=result_out_dir) result_img = inference_result[1] assert result_img.shape == (256, 256, 3) - - -if __name__ == '__main__': - test_inpainting_inferencer() diff --git a/tests/test_apis/test_inferencers/test_matting_inferencer.py b/tests/test_apis/test_inferencers/test_matting_inferencer.py index 5068c3e59a..b78d4400c3 100644 --- a/tests/test_apis/test_inferencers/test_matting_inferencer.py +++ b/tests/test_apis/test_inferencers/test_matting_inferencer.py @@ -13,7 +13,7 @@ def test_matting_inferencer(): img_path = data_root + 'tests/data/matting_dataset/merged/GT05.jpg' trimap_path = data_root + 'tests/data/matting_dataset/trimap/GT05.png' result_out_dir = osp.join( - osp.dirname(__file__), '..', '..', 'data', 'matting_result.png') + osp.dirname(__file__), '..', '..', 'data/out', 'matting_result.png') inferencer_instance = \ MattingInferencer(config, None) @@ -22,7 +22,3 @@ def test_matting_inferencer(): img=img_path, trimap=trimap_path, result_out_dir=result_out_dir) result_img = inference_result[1] assert result_img.numpy().shape == (552, 800) - - -if __name__ == '__main__': - test_matting_inferencer() diff --git a/tests/test_apis/test_inferencers/test_mmedit_inferencer.py b/tests/test_apis/test_inferencers/test_mmedit_inferencer.py index 
adadb0478c..da5e1c1551 100644 --- a/tests/test_apis/test_inferencers/test_mmedit_inferencer.py +++ b/tests/test_apis/test_inferencers/test_mmedit_inferencer.py @@ -60,7 +60,3 @@ def test_mmedit_inferencer(): extra_parameters = inferencer_instance.get_extra_parameters() assert len(extra_parameters) == 2 - - -if __name__ == '__main__': - test_mmedit_inferencer() diff --git a/tests/test_apis/test_inferencers/test_restoration_inferencer.py b/tests/test_apis/test_inferencers/test_restoration_inferencer.py index 111131ef54..eb457775f9 100644 --- a/tests/test_apis/test_inferencers/test_restoration_inferencer.py +++ b/tests/test_apis/test_inferencers/test_restoration_inferencer.py @@ -20,7 +20,8 @@ def test_restoration_inferencer(): config = data_root + 'configs/esrgan/esrgan_x4c64b23g32_1xb16-400k_div2k.py' # noqa img_path = data_root + 'tests/data/image/lq/baboon_x4.png' result_out_dir = osp.join( - osp.dirname(__file__), '..', '..', 'data', 'restoration_result.png') + osp.dirname(__file__), '..', '..', 'data/out', + 'restoration_result.png') inferencer_instance = \ RestorationInferencer(config, None) @@ -29,7 +30,3 @@ def test_restoration_inferencer(): img=img_path, result_out_dir=result_out_dir) result_img = inference_result[1] assert result_img.shape == (480, 500, 3) - - -if __name__ == '__main__': - test_restoration_inferencer() diff --git a/tests/test_apis/test_inferencers/test_text2image_inferencers.py b/tests/test_apis/test_inferencers/test_text2image_inferencers.py index 265238ed97..df0b7ae6be 100644 --- a/tests/test_apis/test_inferencers/test_text2image_inferencers.py +++ b/tests/test_apis/test_inferencers/test_text2image_inferencers.py @@ -93,7 +93,7 @@ def test_translation(self): 'disco-diffusion_adm-u-finetuned_imagenet-512x512.py') text = {0: ['sad']} result_out_dir = osp.join( - osp.dirname(__file__), '..', '..', 'data', 'disco_result.png') + osp.dirname(__file__), '..', '..', 'data/out', 'disco_result.png') with patch.object(Text2ImageInferencer, '_init_model'): inferencer_instance = Text2ImageInferencer( diff --git a/tests/test_apis/test_inferencers/test_translation_inferencer.py b/tests/test_apis/test_inferencers/test_translation_inferencer.py index 6a4b5f5ce9..5d1daaa932 100644 --- a/tests/test_apis/test_inferencers/test_translation_inferencer.py +++ b/tests/test_apis/test_inferencers/test_translation_inferencer.py @@ -16,7 +16,8 @@ def test_translation_inferencer(): osp.dirname(__file__), '..', '..', 'data', 'unpaired', 'trainA', '1.jpg') result_out_dir = osp.join( - osp.dirname(__file__), '..', '..', 'data', 'translation_result.png') + osp.dirname(__file__), '..', '..', 'data/out', + 'translation_result.png') inferencer_instance = \ TranslationInferencer(cfg, None) @@ -25,7 +26,3 @@ def test_translation_inferencer(): img=data_path, result_out_dir=result_out_dir) result_img = inference_result[1] assert result_img[0].numpy().shape == (3, 256, 256) - - -if __name__ == '__main__': - test_translation_inferencer() diff --git a/tests/test_apis/test_inferencers/test_unconditional_inferencer.py b/tests/test_apis/test_inferencers/test_unconditional_inferencer.py index 8ce59bcd08..693890747d 100644 --- a/tests/test_apis/test_inferencers/test_unconditional_inferencer.py +++ b/tests/test_apis/test_inferencers/test_unconditional_inferencer.py @@ -13,7 +13,8 @@ def test_unconditional_inferencer(): osp.dirname(__file__), '..', '..', '..', 'configs', 'styleganv1', 'styleganv1_ffhq-256x256_8xb4-25Mimgs.py') result_out_dir = osp.join( - osp.dirname(__file__), '..', '..', 'data', 
'unconditional_result.png') + osp.dirname(__file__), '..', '..', 'data/out', + 'unconditional_result.png') inferencer_instance = \ UnconditionalInferencer(cfg, @@ -26,7 +27,3 @@ def test_unconditional_inferencer(): inference_result = inferencer_instance(result_out_dir=result_out_dir) result_img = inference_result[1] assert result_img.detach().numpy().shape == (1, 3, 256, 256) - - -if __name__ == '__main__': - test_unconditional_inferencer() diff --git a/tests/test_apis/test_inferencers/test_video_interpolation_inferencer.py b/tests/test_apis/test_inferencers/test_video_interpolation_inferencer.py index 8e9a586db9..1d3352092d 100644 --- a/tests/test_apis/test_inferencers/test_video_interpolation_inferencer.py +++ b/tests/test_apis/test_inferencers/test_video_interpolation_inferencer.py @@ -56,7 +56,3 @@ def test_video_interpolation_inferencer_fps_multiplier(): inference_result = inferencer_instance( video=video_path, result_out_dir=result_out_dir) assert inference_result is None - - -if __name__ == '__main__': - test_video_interpolation_inferencer_fps_multiplier() diff --git a/tests/test_apis/test_inferencers/test_video_restoration_inferencer.py b/tests/test_apis/test_inferencers/test_video_restoration_inferencer.py index 09be036fb4..ba0b8ecad9 100644 --- a/tests/test_apis/test_inferencers/test_video_restoration_inferencer.py +++ b/tests/test_apis/test_inferencers/test_video_restoration_inferencer.py @@ -13,7 +13,7 @@ def test_video_restoration_inferencer(): osp.dirname(__file__), '..', '..', '..', 'configs', 'basicvsr', 'basicvsr_2xb4_reds4.py') result_out_dir = osp.join( - osp.dirname(__file__), '..', '..', 'data', + osp.dirname(__file__), '..', '..', 'data/out', 'video_restoration_result.mp4') data_root = osp.join(osp.dirname(__file__), '../../../') video_path = data_root + 'tests/data/frames/test_inference.mp4' @@ -32,7 +32,7 @@ def test_video_restoration_inferencer_input_dir(): osp.dirname(__file__), '..', '..', '..', 'configs', 'basicvsr', 'basicvsr_2xb4_reds4.py') result_out_dir = osp.join( - osp.dirname(__file__), '..', '..', 'data', + osp.dirname(__file__), '..', '..', 'data/out', 'video_restoration_result.mp4') data_root = osp.join(osp.dirname(__file__), '../../../') input_dir = osp.join(data_root, 'tests/data/frames/sequence/gt/sequence_1') @@ -52,7 +52,7 @@ def test_video_restoration_inferencer_window_size(): osp.dirname(__file__), '..', '..', '..', 'configs', 'basicvsr', 'basicvsr_2xb4_reds4.py') result_out_dir = osp.join( - osp.dirname(__file__), '..', '..', 'data', + osp.dirname(__file__), '..', '..', 'data/out', 'video_restoration_result.mp4') data_root = osp.join(osp.dirname(__file__), '../../../') video_path = data_root + 'tests/data/frames/test_inference.mp4' @@ -74,7 +74,7 @@ def test_video_restoration_inferencer_max_seq_len(): osp.dirname(__file__), '..', '..', '..', 'configs', 'basicvsr', 'basicvsr_2xb4_reds4.py') result_out_dir = osp.join( - osp.dirname(__file__), '..', '..', 'data', + osp.dirname(__file__), '..', '..', 'data/out', 'video_restoration_result.mp4') data_root = osp.join(osp.dirname(__file__), '../../../') video_path = data_root + 'tests/data/frames/test_inference.mp4' @@ -89,7 +89,3 @@ def test_video_restoration_inferencer_max_seq_len(): inference_result = inferencer_instance( video=video_path, result_out_dir=result_out_dir) assert inference_result is None - - -if __name__ == '__main__': - test_video_restoration_inferencer_input_dir() diff --git a/tests/test_datasets/test_singan_dataset.py b/tests/test_datasets/test_singan_dataset.py index 
a872d4242f..96b78b8ae8 100644 --- a/tests/test_datasets/test_singan_dataset.py +++ b/tests/test_datasets/test_singan_dataset.py @@ -15,8 +15,13 @@ def setup_class(cls): osp.dirname(osp.dirname(__file__)), 'data/image/gt/baboon.png') cls.min_size = 25 cls.max_size = 250 + cls.num_scales = 10 cls.scale_factor_init = 0.75 - cls.pipeline = [dict(type='PackEditInputs', pack_all=True)] + cls.pipeline = [ + dict( + type='PackEditInputs', + keys=[f'real_scale{i}' for i in range(cls.num_scales)]) + ] def test_singan_dataset(self): dataset = SinGANDataset( @@ -28,4 +33,5 @@ def test_singan_dataset(self): assert len(dataset) == 1000000 data_dict = dataset[0]['inputs'] - assert all([f'real_scale{i}' in data_dict for i in range(10)]) + assert all( + [f'real_scale{i}' in data_dict for i in range(self.num_scales)]) diff --git a/tests/test_datasets/test_transforms/test_formatting.py index c9ac37ddf1..b7121b5280 100644 --- a/tests/test_datasets/test_transforms/test_formatting.py +++ b/tests/test_datasets/test_transforms/test_formatting.py @@ -3,9 +3,7 @@ import torch from mmcv.transforms import to_tensor -from mmedit.datasets.transforms import PackEditInputs, ToTensor -from mmedit.datasets.transforms.formatting import (can_convert_to_image, - images_to_tensor) +from mmedit.datasets.transforms import PackEditInputs from mmedit.structures.edit_data_sample import EditDataSample @@ -18,24 +16,9 @@ def assert_tensor_equal(img, ref_img, ratio_thr=0.999): assert torch.sum(diff <= 1) / float(area) > ratio_thr -def test_images_to_tensor(): - - data = [np.random.rand(64, 64, 3), np.random.rand(64, 64, 3)] - tensor = images_to_tensor(data) - assert tensor.shape == torch.Size([2, 3, 64, 64]) - - data = np.random.rand(64, 64, 3) - tensor = images_to_tensor(data) - assert tensor.shape == torch.Size([3, 64, 64]) - - data = 1 - tensor = images_to_tensor(data) - assert tensor == torch.tensor(1) - - def test_pack_edit_inputs(): - pack_edit_inputs = PackEditInputs() + pack_edit_inputs = PackEditInputs(meta_keys='a', data_keys='numpy') assert repr(pack_edit_inputs) == 'PackEditInputs' ori_results = dict( img=np.random.rand(64, 64, 3), gt=[np.random.rand(64, 64, 3), np.random.rand(64, 64, 3)], img_lq=np.random.rand(64, 64, 3), ref=np.random.rand(64, 64, 3), ref_lq=np.random.rand(64, 64, 3), mask=np.random.rand(64, 64, 1), gt_heatmap=np.random.rand(64, 64, 1), gt_unsharp=np.random.rand(64, 64, 3), merged=np.random.rand(64, 64, 3), trimap=np.random.rand(64, 64, 1), alpha=np.random.rand(64, 64, 1), fg=np.random.rand(64, 68, 3), bg=np.random.rand(64, 69, 3), img_shape=(64, 64), - a='b') + a='b', + numpy=np.random.rand(48, 48, 3)) results = ori_results.copy() @@ -66,96 +50,58 @@ def test_pack_edit_inputs(): data_sample = packed_results['data_samples'] assert isinstance(data_sample, EditDataSample) + assert data_sample.img_shape == (64, 64) + assert data_sample.a == 'b' + + numpy_tensor = to_tensor(ori_results['numpy']) + numpy_tensor = numpy_tensor.permute(2, 0, 1) + assert_tensor_equal(data_sample.numpy, numpy_tensor) + gt_tensors = [to_tensor(v) for v in ori_results['gt']] gt_tensors = [v.permute(2, 0, 1) for v in gt_tensors] gt_tensor = torch.stack(gt_tensors, dim=0) - assert_tensor_equal(data_sample.gt_img.data, gt_tensor) + assert_tensor_equal(data_sample.gt_img, gt_tensor) img_lq_tensor = to_tensor(ori_results['ref']) img_lq_tensor = img_lq_tensor.permute(2, 0, 1) - assert_tensor_equal(data_sample.ref_img.data, img_lq_tensor) + assert_tensor_equal(data_sample.ref_img, img_lq_tensor) ref_lq_tensor = to_tensor(ori_results['ref']) ref_lq_tensor = ref_lq_tensor.permute(2, 0, 1) - assert_tensor_equal(data_sample.ref_img.data, ref_lq_tensor) + assert_tensor_equal(data_sample.ref_img, ref_lq_tensor) ref_tensor = to_tensor(ori_results['ref'])
ref_tensor = ref_tensor.permute(2, 0, 1) - assert_tensor_equal(data_sample.ref_img.data, ref_tensor) + assert_tensor_equal(data_sample.ref_img, ref_tensor) mask_tensor = to_tensor(ori_results['mask']) mask_tensor = mask_tensor.permute(2, 0, 1) - assert_tensor_equal(data_sample.mask.data, mask_tensor) + assert_tensor_equal(data_sample.mask, mask_tensor) gt_heatmap_tensor = to_tensor(ori_results['gt_heatmap']) gt_heatmap_tensor = gt_heatmap_tensor.permute(2, 0, 1) - assert_tensor_equal(data_sample.gt_heatmap.data, gt_heatmap_tensor) + assert_tensor_equal(data_sample.gt_heatmap, gt_heatmap_tensor) gt_unsharp_tensor = to_tensor(ori_results['gt_heatmap']) gt_unsharp_tensor = gt_unsharp_tensor.permute(2, 0, 1) - assert_tensor_equal(data_sample.gt_heatmap.data, gt_unsharp_tensor) + assert_tensor_equal(data_sample.gt_heatmap, gt_unsharp_tensor) gt_merged_tensor = to_tensor(ori_results['merged']) gt_merged_tensor = gt_merged_tensor.permute(2, 0, 1) - assert_tensor_equal(data_sample.gt_merged.data, gt_merged_tensor) + assert_tensor_equal(data_sample.gt_merged, gt_merged_tensor) trimap_tensor = to_tensor(ori_results['trimap']) trimap_tensor = trimap_tensor.permute(2, 0, 1) - assert_tensor_equal(data_sample.trimap.data, trimap_tensor) + assert_tensor_equal(data_sample.trimap, trimap_tensor) gt_alpha_tensor = to_tensor(ori_results['alpha']) gt_alpha_tensor = gt_alpha_tensor.permute(2, 0, 1) - assert_tensor_equal(data_sample.gt_alpha.data, gt_alpha_tensor) + assert_tensor_equal(data_sample.gt_alpha, gt_alpha_tensor) gt_fg_tensor = to_tensor(ori_results['fg']) gt_fg_tensor = gt_fg_tensor.permute(2, 0, 1) - assert_tensor_equal(data_sample.gt_fg.data, gt_fg_tensor) + assert_tensor_equal(data_sample.gt_fg, gt_fg_tensor) gt_bg_tensor = to_tensor(ori_results['bg']) gt_bg_tensor = gt_bg_tensor.permute(2, 0, 1) - assert_tensor_equal(data_sample.gt_bg.data, gt_bg_tensor) - - assert data_sample.metainfo['img_shape'] == (64, 64) - assert data_sample.metainfo['a'] == 'b' - - # test pack_all - pack_edit_inputs = PackEditInputs(pack_all=True) - results = ori_results.copy() - packed_results = pack_edit_inputs(results) - print(packed_results['inputs'].keys()) - - target_keys = [ - 'img', 'gt', 'img_lq', 'ref', 'ref_lq', 'mask', 'gt_heatmap', - 'gt_unsharp', 'merged', 'trimap', 'alpha', 'fg', 'bg' - ] - assert all([k in target_keys for k in packed_results['inputs']]) - - -def test_to_tensor(): - - ori_results = dict( - img=np.random.rand(64, 64, 3), - gt=[np.random.rand(64, 64, 3), - np.random.rand(64, 64, 3)], - gt1=[np.random.rand(64, 64, 3)], - a=1) - - keys = ['img', 'gt', 'gt1', 'a'] - to_tensor = ToTensor(keys=keys, to_float32=True) - assert repr(to_tensor) == f'ToTensor(keys={keys}, to_float32=True)' - - results = to_tensor(ori_results) - assert set(keys).issubset(results.keys()) - for _, v in results.items(): - assert isinstance(v, torch.Tensor) - - -def test_can_convert_to_image(): - values = [ - np.random.rand(64, 64, 3), - [np.random.rand(64, 61, 3), - np.random.rand(64, 61, 3)], (64, 64), 'b' - ] - targets = [True, True, False, False] - for val, tar in zip(values, targets): - assert can_convert_to_image(val) == tar + assert_tensor_equal(data_sample.gt_bg, gt_bg_tensor) diff --git a/tests/test_models/test_base_models/test_base_mattor.py b/tests/test_models/test_base_models/test_base_mattor.py index d9c8962bd9..6f91a56c00 100644 --- a/tests/test_models/test_base_models/test_base_mattor.py +++ b/tests/test_models/test_base_models/test_base_mattor.py @@ -7,10 +7,11 @@ import torch from mmengine.config import 
ConfigDict +from mmedit.datasets.transforms import PackEditInputs from mmedit.models.base_models import BaseMattor from mmedit.models.editors import DIM from mmedit.registry import MODELS -from mmedit.structures import EditDataSample, PixelData +from mmedit.structures import EditDataSample from mmedit.utils import register_all_modules register_all_modules() @@ -43,31 +44,26 @@ def _demo_input_train(img_shape, batch_size=1, cuda=False, meta={}): merged = torch.from_numpy(np.random.random(color_shape).astype(np.float32)) trimap = torch.from_numpy( np.random.randint(255, size=gray_shape).astype(np.float32)) - alpha = torch.from_numpy(np.random.random(gray_shape).astype(np.float32)) - ori_merged = torch.from_numpy( - np.random.random(color_shape).astype(np.float32)) - fg = torch.from_numpy(np.random.random(color_shape).astype(np.float32)) - bg = torch.from_numpy(np.random.random(color_shape).astype(np.float32)) + inputs = torch.cat((merged, trimap), dim=1) if cuda: - merged = merged.cuda() - trimap = trimap.cuda() - alpha = alpha.cuda() - ori_merged = ori_merged.cuda() - fg = fg.cuda() - bg = bg.cuda() + inputs = inputs.cuda() + + results = dict( + alpha=np.random.random( + (img_shape[0], img_shape[1], 1)).astype(np.float32), + merged=np.random.random( + (img_shape[0], img_shape[1], 3)).astype(np.float32), + fg=np.random.random( + (img_shape[0], img_shape[1], 3)).astype(np.float32), + bg=np.random.random( + (img_shape[0], img_shape[1], 3)).astype(np.float32)) - inputs = torch.cat((merged, trimap), dim=1) data_samples = [] - for a, m, f, b in zip(alpha, ori_merged, fg, bg): - ds = EditDataSample() - - ds.gt_alpha = PixelData(data=a) - ds.gt_merged = PixelData(data=m) - ds.gt_fg = PixelData(data=f) - ds.gt_bg = PixelData(data=b) - for k, v in meta.items(): - ds.set_field(name=k, value=v, field_type='metainfo', dtype=None) - + packinputs = PackEditInputs() + for _ in range(batch_size): + ds = packinputs(results)['data_samples'] + if cuda: + ds = ds.cuda() data_samples.append(ds) return inputs, data_samples @@ -85,25 +81,25 @@ def _demo_input_test(img_shape, batch_size=1, cuda=False, meta={}): color_shape = (batch_size, 3, img_shape[0], img_shape[1]) gray_shape = (batch_size, 1, img_shape[0], img_shape[1]) ori_shape = (img_shape[0], img_shape[1], 1) + merged = torch.from_numpy(np.random.random(color_shape).astype(np.float32)) trimap = torch.from_numpy( np.random.randint(255, size=gray_shape).astype(np.float32)) - ori_alpha = np.random.random(ori_shape).astype(np.float32) - ori_trimap = np.random.randint(256, size=ori_shape).astype(np.float32) + inputs = torch.cat((merged, trimap), dim=1) if cuda: - merged = merged.cuda() - trimap = trimap.cuda() - meta = dict( - ori_alpha=ori_alpha, - ori_trimap=ori_trimap, - ori_merged_shape=img_shape, - # ori_merged_shape=ori_shape, - **meta) + inputs = inputs.cuda() + + results = dict( + ori_alpha=np.random.random(ori_shape).astype(np.float32), + ori_trimap=np.random.randint(256, size=ori_shape).astype(np.float32), + ori_merged_shape=img_shape) - inputs = torch.cat((merged, trimap), dim=1) data_samples = [] + packinputs = PackEditInputs() for _ in range(batch_size): - ds = EditDataSample(metainfo=meta) + ds = packinputs(results)['data_samples'] + if cuda: + ds = ds.cuda() data_samples.append(ds) return inputs, data_samples diff --git a/tests/test_models/test_editors/test_dim/test_dim.py b/tests/test_models/test_editors/test_dim/test_dim.py index 38b6021a86..32b5608a7f 100644 --- a/tests/test_models/test_editors/test_dim/test_dim.py +++ 
b/tests/test_models/test_editors/test_dim/test_dim.py @@ -6,6 +6,7 @@ import torch from mmengine.config import ConfigDict +from mmedit.datasets.transforms import PackEditInputs from mmedit.models.editors import DIM from mmedit.registry import MODELS from mmedit.structures import EditDataSample, PixelData @@ -75,26 +76,29 @@ def _demo_input_test(img_shape, batch_size=1, cuda=False, meta={}): color_shape = (batch_size, 3, img_shape[0], img_shape[1]) gray_shape = (batch_size, 1, img_shape[0], img_shape[1]) ori_shape = (img_shape[0], img_shape[1], 1) + merged = torch.from_numpy(np.random.random(color_shape).astype(np.float32)) trimap = torch.from_numpy( np.random.randint(255, size=gray_shape).astype(np.float32)) - ori_alpha = np.random.random(ori_shape).astype(np.float32) - ori_trimap = np.random.randint(256, size=ori_shape).astype(np.float32) - if cuda: - merged = merged.cuda() - trimap = trimap.cuda() - meta = dict( - ori_alpha=ori_alpha, - ori_trimap=ori_trimap, - ori_merged_shape=img_shape, - **meta) inputs = torch.cat((merged, trimap), dim=1) + + results = { + 'ori_alpha': np.random.random(ori_shape).astype(np.float32), + 'ori_trimap': np.random.randint(256, + size=ori_shape).astype(np.float32), + 'ori_merged_shape': img_shape, + } + packinputs = PackEditInputs() + data_samples = [] for _ in range(batch_size): - ds = EditDataSample(metainfo=meta) + ds = packinputs(results)['data_samples'] + if cuda: + ds = ds.cuda() data_samples.append(ds) + if cuda: + inputs = inputs.cuda() return inputs, data_samples @@ -228,8 +232,8 @@ def test_dim(): # test model forward in test mode with torch.no_grad(): model = MODELS.build(model_cfg) - input_test = _demo_input_test((48, 48)) - output_test = model(*input_test, mode='predict') + inputs, data_samples = _demo_input_test((48, 48)) + output_test = model(inputs, data_samples, mode='predict') assert isinstance(output_test, list) assert isinstance(output_test[0], EditDataSample) pred_alpha = output_test[0].output.pred_alpha.data diff --git a/tests/test_models/test_editors/test_gca/test_gca.py index a441d100fb..688510b54a 100644 --- a/tests/test_models/test_editors/test_gca/test_gca.py +++ b/tests/test_models/test_editors/test_gca/test_gca.py @@ -4,6 +4,7 @@ import torch from mmengine.config import ConfigDict +from mmedit.datasets.transforms import PackEditInputs from mmedit.registry import MODELS from mmedit.structures import EditDataSample, PixelData from mmedit.utils import register_all_modules @@ -66,24 +67,26 @@ def _demo_input_test(img_shape, batch_size=1, cuda=False, meta={}): color_shape = (batch_size, 3, img_shape[0], img_shape[1]) gray_shape = (batch_size, 1, img_shape[0], img_shape[1]) ori_shape = (img_shape[0], img_shape[1], 1) + merged = torch.from_numpy(np.random.random(color_shape).astype(np.float32)) trimap = torch.from_numpy( np.random.randint(255, size=gray_shape).astype(np.float32)) ori_alpha = np.random.random(ori_shape).astype(np.float32) ori_trimap = np.random.randint(256, size=ori_shape).astype(np.float32) - if cuda: - merged = merged.cuda() - trimap = trimap.cuda() - meta = dict( - ori_alpha=ori_alpha, - ori_trimap=ori_trimap, - ori_merged_shape=img_shape, - **meta) inputs = torch.cat((merged, trimap), dim=1) + if cuda: + inputs = inputs.cuda() + + results = dict( ori_alpha=ori_alpha, ori_trimap=ori_trimap, ori_merged_shape=img_shape) + data_samples = 
[] + packinputs = PackEditInputs() for _ in range(batch_size): - ds = EditDataSample(metainfo=meta) + ds = packinputs(results)['data_samples'] + if cuda: + ds = ds.cuda() data_samples.append(ds) return inputs, data_samples diff --git a/tests/test_models/test_editors/test_indexnet/test_indexnet.py b/tests/test_models/test_editors/test_indexnet/test_indexnet.py index 39f397af81..2aed23a387 100644 --- a/tests/test_models/test_editors/test_indexnet/test_indexnet.py +++ b/tests/test_models/test_editors/test_indexnet/test_indexnet.py @@ -3,9 +3,10 @@ import torch from mmengine.config import ConfigDict +from mmedit.datasets.transforms import PackEditInputs from mmedit.models.editors import IndexedUpsample from mmedit.registry import MODELS -from mmedit.structures import EditDataSample, PixelData +from mmedit.structures import EditDataSample from mmedit.utils import register_all_modules @@ -19,34 +20,31 @@ def _demo_input_train(img_shape, batch_size=1, cuda=False, meta={}): """ color_shape = (batch_size, 3, img_shape[0], img_shape[1]) gray_shape = (batch_size, 1, img_shape[0], img_shape[1]) + ori_shape = (img_shape[0], img_shape[1], 1) + merged = torch.from_numpy(np.random.random(color_shape).astype(np.float32)) trimap = torch.from_numpy( np.random.randint(255, size=gray_shape).astype(np.float32)) - alpha = torch.from_numpy(np.random.random(gray_shape).astype(np.float32)) - ori_merged = torch.from_numpy( - np.random.random(color_shape).astype(np.float32)) - fg = torch.from_numpy(np.random.random(color_shape).astype(np.float32)) - bg = torch.from_numpy(np.random.random(color_shape).astype(np.float32)) + inputs = torch.cat((merged, trimap), dim=1) if cuda: - merged = merged.cuda() - trimap = trimap.cuda() - alpha = alpha.cuda() - ori_merged = ori_merged.cuda() - fg = fg.cuda() - bg = bg.cuda() + inputs = inputs.cuda() + + results = dict( + alpha=np.random.random(ori_shape).astype(np.float32), + merged=np.random.random( + (img_shape[0], img_shape[1], 3)).astype(np.float32), + fg=np.random.random( + (img_shape[0], img_shape[1], 3)).astype(np.float32), + bg=np.random.random( + (img_shape[0], img_shape[1], 3)).astype(np.float32), + ) - inputs = torch.cat((merged, trimap), dim=1) data_samples = [] - for a, m, f, b in zip(alpha, ori_merged, fg, bg): - ds = EditDataSample() - - ds.gt_alpha = PixelData(data=a) - ds.gt_merged = PixelData(data=m) - ds.gt_fg = PixelData(data=f) - ds.gt_bg = PixelData(data=b) - for k, v in meta.items(): - ds.set_field(name=k, value=v, field_type='metainfo', dtype=None) - + packinputs = PackEditInputs() + for _ in range(batch_size): + ds = packinputs(results)['data_samples'] + if cuda: + ds = ds.cuda() data_samples.append(ds) return inputs, data_samples @@ -64,24 +62,25 @@ def _demo_input_test(img_shape, batch_size=1, cuda=False, meta={}): color_shape = (batch_size, 3, img_shape[0], img_shape[1]) gray_shape = (batch_size, 1, img_shape[0], img_shape[1]) ori_shape = (img_shape[0], img_shape[1], 1) + merged = torch.from_numpy(np.random.random(color_shape).astype(np.float32)) trimap = torch.from_numpy( np.random.randint(255, size=gray_shape).astype(np.float32)) - ori_alpha = np.random.random(ori_shape).astype(np.float32) - ori_trimap = np.random.randint(256, size=ori_shape).astype(np.float32) + inputs = torch.cat((merged, trimap), dim=1) if cuda: - merged = merged.cuda() - trimap = trimap.cuda() - meta = dict( - ori_alpha=ori_alpha, - ori_trimap=ori_trimap, - ori_merged_shape=img_shape, - **meta) + inputs = inputs.cuda() + + results = dict( + 
ori_alpha=np.random.random(ori_shape).astype(np.float32), + ori_trimap=np.random.randint(256, size=ori_shape).astype(np.float32), + ori_merged_shape=img_shape) - inputs = torch.cat((merged, trimap), dim=1) data_samples = [] + packinputs = PackEditInputs() for _ in range(batch_size): - ds = EditDataSample(metainfo=meta) + ds = packinputs(results)['data_samples'] + if cuda: + ds = ds.cuda() data_samples.append(ds) return inputs, data_samples diff --git a/tests/test_structures/test_edit_data_sample.py b/tests/test_structures/test_edit_data_sample.py index 18cd8a97fe..3bd9d6dfbf 100644 --- a/tests/test_structures/test_edit_data_sample.py +++ b/tests/test_structures/test_edit_data_sample.py @@ -169,10 +169,6 @@ def test_setter(self): assert _equal(edit_data_sample.ignored_data.labels, ignored_data_data['labels']) - # test type error - with pytest.raises(AssertionError): - edit_data_sample.pred_img = torch.rand(2, 4) - # test shape error with pytest.raises(AssertionError): gt_img_data = dict( diff --git a/tests/test_utils/test_img_utils.py b/tests/test_utils/test_img_utils.py index b8382f571d..5ad011daac 100644 --- a/tests/test_utils/test_img_utils.py +++ b/tests/test_utils/test_img_utils.py @@ -3,7 +3,34 @@ import pytest import torch -from mmedit.utils import tensor2img, to_numpy +from mmedit.utils import (all_to_tensor, can_convert_to_image, tensor2img, + to_numpy) + + +def test_all_to_tensor(): + + data = [np.random.rand(64, 64, 3), np.random.rand(64, 64, 3)] + tensor = all_to_tensor(data) + assert tensor.shape == torch.Size([2, 3, 64, 64]) + + data = np.random.rand(64, 64, 3) + tensor = all_to_tensor(data) + assert tensor.shape == torch.Size([3, 64, 64]) + + data = 1 + tensor = all_to_tensor(data) + assert tensor == torch.tensor(1) + + +def test_can_convert_to_image(): + values = [ + np.random.rand(64, 64, 3), + [np.random.rand(64, 61, 3), + np.random.rand(64, 61, 3)], (64, 64), 'b' + ] + targets = [True, True, False, False] + for val, tar in zip(values, targets): + assert can_convert_to_image(val) == tar def test_tensor2img():
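
The refactored helpers above all rely on the same packing pattern: raw numpy annotations go into a plain `results` dict, `PackEditInputs` wraps them into a single `EditDataSample` returned under the `'data_samples'` key, and the whole sample moves to GPU with one `.cuda()` call. A minimal standalone sketch of that pattern (the shapes and the CUDA guard here are illustrative, not taken from the patch):

    import numpy as np
    import torch

    from mmedit.datasets.transforms import PackEditInputs
    from mmedit.utils import all_to_tensor

    h, w = 48, 48
    # Raw annotations, keyed the way the matting test helpers key them.
    results = dict(
        ori_alpha=np.random.random((h, w, 1)).astype(np.float32),
        ori_trimap=np.random.randint(256, size=(h, w, 1)).astype(np.float32),
        ori_merged_shape=(h, w))

    # The transform returns a dict; the packed EditDataSample sits under
    # 'data_samples', which is what every helper above indexes.
    data_sample = PackEditInputs()(results)['data_samples']
    if torch.cuda.is_available():
        data_sample = data_sample.cuda()  # tensors and metainfo move together

    # all_to_tensor stacks a list of HWC frames into one (T, C, H, W) tensor,
    # matching the assertions in test_all_to_tensor above.
    frames = [np.random.rand(64, 64, 3), np.random.rand(64, 64, 3)]
    assert all_to_tensor(frames).shape == torch.Size([2, 3, 64, 64])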