Skip to content

Commit

Permalink
[Enhancement] support all inpainting models inferencer (#1833)
Browse files Browse the repository at this point in the history
* [Enhancement] support all inpainting models inferencer

* [Enhancement] support all inpainting models inferencer
  • Loading branch information
Z-Fran authored May 4, 2023
1 parent 6a28cc6 commit 50e2b9b
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 19 deletions.
22 changes: 4 additions & 18 deletions mmagic/apis/inferencers/inpainting_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,7 @@
import torch
from mmengine import mkdir_or_exist
from mmengine.dataset import Compose
from mmengine.dataset.utils import default_collate as collate
from torch.nn.parallel import scatter

from mmagic.structures import DataSample
from mmagic.utils import tensor2img
from .base_mmagic_inferencer import BaseMMagicInferencer, InputsType, PredType

Expand Down Expand Up @@ -53,25 +50,15 @@ def preprocess(self, img: InputsType, mask: InputsType) -> Dict:
# prepare data
_data = infer_pipeline(dict(gt_path=img, mask_path=mask))
data = dict()
data['inputs'] = _data['inputs'] / 255.0
data = collate([data])
data['inputs'] = [_data['inputs']]
data['data_samples'] = [_data['data_samples']]
if 'cuda' in str(self.device):
data = scatter(data, [self.device])[0]
data['data_samples'][0].mask.data = scatter(
data['data_samples'][0].mask.data, [self.device])[0] / 255.0

# save masks and masked_imgs to visualize
self.masks = data['data_samples'][0].mask.data * 255
self.masked_imgs = data['inputs'][0]

data['data_samples'] = DataSample.stack(data['data_samples'])
return data

def forward(self, inputs: InputsType) -> PredType:
"""Forward the inputs to the model."""
inputs = self.model.data_preprocessor(inputs)
with torch.no_grad():
result, x = self.model(mode='tensor', **inputs)
result = self.model(mode='predict', **inputs)
return result

def visualize(self,
Expand All @@ -89,8 +76,7 @@ def visualize(self,
Returns:
List[np.ndarray]: Result of visualize
"""
result = preds[0]
result = result * self.masks + self.masked_imgs * (1. - self.masks)
result = preds[0].output.pred_img / 255.

result = tensor2img(result)[..., ::-1]
if result_out_dir:
Expand Down
5 changes: 4 additions & 1 deletion mmagic/apis/mmagic_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,11 @@ class MMagicInferencer:
'gca',

# inpainting models
'global_local',
'aot_gan',
'deepfillv1',
'deepfillv2',
'global_local',
'partial_conv',

# translation models
'pix2pix',
Expand Down

0 comments on commit 50e2b9b

Please sign in to comment.