From a17e934a5b3b8d317865ea9b0019d089c8a79d5c Mon Sep 17 00:00:00 2001 From: rangoliu Date: Tue, 21 Feb 2023 10:43:07 +0800 Subject: [PATCH] [Refact] support all models in translation inferencer (#1650) * fix translation load bug. * fix lint. * support cyclegan inference. * fix lint. --- mmedit/apis/inferencers/mmedit_inferencer.py | 2 +- mmedit/apis/inferencers/translation_inferencer.py | 9 ++++----- mmedit/edit.py | 1 + 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/mmedit/apis/inferencers/mmedit_inferencer.py b/mmedit/apis/inferencers/mmedit_inferencer.py index 3d29237800..0da98016aa 100644 --- a/mmedit/apis/inferencers/mmedit_inferencer.py +++ b/mmedit/apis/inferencers/mmedit_inferencer.py @@ -52,7 +52,7 @@ def __init__(self, elif self.task in ['inpainting', 'Inpainting']: self.inferencer = InpaintingInferencer( config, ckpt, device, extra_parameters, seed=seed) - elif self.task in ['translation', 'Image2Image Translation']: + elif self.task in ['translation', 'Image2Image']: self.inferencer = TranslationInferencer( config, ckpt, device, extra_parameters, seed=seed) elif self.task in ['restoration', 'Image Super-Resolution']: diff --git a/mmedit/apis/inferencers/translation_inferencer.py b/mmedit/apis/inferencers/translation_inferencer.py index a48e445359..63976fba40 100644 --- a/mmedit/apis/inferencers/translation_inferencer.py +++ b/mmedit/apis/inferencers/translation_inferencer.py @@ -42,16 +42,15 @@ def preprocess(self, img: InputsType) -> Dict: test_pipeline = Compose(cfg.test_pipeline) # prepare data - data = dict() # dirty code to deal with test data pipeline + data = dict() data['pair_path'] = img - data[f'img_{source_domain}_path'] = img - data[f'img_{self.target_domain}_path'] = img - + data['img_A_path'] = img + data['img_B_path'] = img data = collate([test_pipeline(data)]) data = self.model.data_preprocessor(data, False) - inputs_dict = data['inputs'] + inputs_dict = data['inputs'] results = inputs_dict[f'img_{source_domain}'] return results diff --git a/mmedit/edit.py b/mmedit/edit.py index 2207f9610d..44ccf173ec 100644 --- a/mmedit/edit.py +++ b/mmedit/edit.py @@ -57,6 +57,7 @@ class MMEdit: # translation models 'pix2pix', + 'cyclegan', # restoration models 'esrgan',