Skip to content

Commit

Permalink
[Refactor] Refactor PackEditInputs and EditDataSample (#1573)
Browse files Browse the repository at this point in the history
* refactor packeditinput and editdatasample

* move util from formatting to img_utils

* fix bugs in ut

* use permute instead of transpose in all_to_tensor

* remove undesired output results

* fix the path of results in ut

* use np.ascontiguousarray(img) in image_to_tensor

* remove entry function from unit test files

* refine sinGAN's config

Co-authored-by: LeoXing1996 <xingzn1996@hotmail.com>
  • Loading branch information
zengyh1900 and LeoXing1996 committed Feb 7, 2023
1 parent be2e9f8 commit b219a43
Show file tree
Hide file tree
Showing 30 changed files with 424 additions and 1,062 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ coverage.xml
*.cover
.hypothesis/
.pytest_cache/
tests/data/out

# Translations
*.mo
Expand Down
7 changes: 6 additions & 1 deletion configs/singan/singan_balloons.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,13 @@
test_pkl_data=test_pkl_data)

# DATA
pipeline = [
dict(
type='PackEditInputs',
keys=[f'real_scale{i}' for i in range(num_scales)] + ['input_sample'])
]
data_root = './data/singan/balloons.png'
train_dataloader = dict(dataset=dict(data_root=data_root))
train_dataloader = dict(dataset=dict(data_root=data_root, pipeline=pipeline))

# HOOKS
custom_hooks = [
Expand Down
6 changes: 5 additions & 1 deletion configs/singan/singan_fish.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,11 @@
dataset_type = 'SinGANDataset'
data_root = './data/singan/fish-crop.jpg'

pipeline = [dict(type='PackEditInputs', pack_all=True)]
pipeline = [
dict(
type='PackEditInputs',
keys=[f'real_scale{i}' for i in range(num_scales)] + ['input_sample'])
]
dataset = dict(
type=dataset_type,
data_root=data_root,
Expand Down
2 changes: 1 addition & 1 deletion mmedit/apis/inferencers/inference_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ def matting_inference(model, img, trimap):
# prepare data
data = dict(merged_path=img, trimap_path=trimap)
_data = test_pipeline(data)
trimap = _data['data_samples'].trimap.data
trimap = _data['data_samples'].trimap
data = dict()
data['inputs'] = torch.cat([_data['inputs'], trimap], dim=0).float()
data = collate([data])
Expand Down
2 changes: 1 addition & 1 deletion mmedit/datasets/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
RandomResizedCrop)
from .fgbg import (CompositeFg, MergeFgAndBg, PerturbBg, RandomJitter,
RandomLoadResizeBg)
from .formatting import PackEditInputs, ToTensor
from .formatting import PackEditInputs
from .generate_assistant import (GenerateCoordinateAndCell,
GenerateFacialHeatmap)
from .generate_frame_indices import (GenerateFrameIndices,
Expand Down
Loading

0 comments on commit b219a43

Please sign in to comment.