Skip to content

Commit

Permalink
add conditional models
Browse files Browse the repository at this point in the history
  • Loading branch information
liuwenran committed Feb 21, 2023
1 parent a17e934 commit a0d47c2
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 2 deletions.
3 changes: 2 additions & 1 deletion mmedit/apis/inferencers/base_mmedit_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from torchvision import utils

from mmedit.registry import MODELS
from mmedit.utils import ConfigType, SampleList
from mmedit.utils import ConfigType, SampleList, register_all_modules
from .inference_functions import set_random_seed

InputType = Union[str, int, np.ndarray]
Expand Down Expand Up @@ -54,6 +54,7 @@ def __init__(self,
device = torch.device(
'cuda' if torch.cuda.is_available() else 'cpu')
self.device = device
register_all_modules()
super().__init__(config, ckpt, device)

self._init_extra_parameters(extra_parameters)
Expand Down
2 changes: 1 addition & 1 deletion mmedit/apis/inferencers/conditional_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class ConditionalInferencer(BaseMMEditInferencer):
visualize=['result_out_dir'],
postprocess=[])

extra_parameters = dict(num_batches=4, sample_model='ema')
extra_parameters = dict(num_batches=4, sample_model='orig')

def preprocess(self, label: InputsType) -> Dict:
"""Process the inputs into a model-feedable format.
Expand Down
2 changes: 2 additions & 0 deletions mmedit/edit.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ class MMEdit:

# conditional models
'biggan',
'sngan_proj',
'sagan',

# unconditional models
'styleganv1',
Expand Down

0 comments on commit a0d47c2

Please sign in to comment.