diff --git a/.dev_scripts/download_models.py b/.dev_scripts/download_models.py index ead0796c62..f3349aaf15 100644 --- a/.dev_scripts/download_models.py +++ b/.dev_scripts/download_models.py @@ -76,6 +76,7 @@ def download(args): http_prefix_long = 'https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmediting/' # noqa http_prefix_short = 'https://download.openmmlab.com/mmediting/' + http_prefix_gen = 'https://download.openmmlab.com/mmgen/' # load model list if args.model_list: @@ -112,6 +113,11 @@ def download(args): model_name = model_weight_url[len(http_prefix_long):] elif model_weight_url.startswith(http_prefix_short): model_name = model_weight_url[len(http_prefix_short):] + elif model_weight_url.startswith(http_prefix_gen): + model_name = model_weight_url[len(http_prefix_gen):] + elif model_weight_url == '': + print(f'{model_info.Name} weight is missing') + return None else: raise ValueError(f'Unknown url prefix. \'{model_weight_url}\'') diff --git a/.dev_scripts/test_benchmark.py b/.dev_scripts/test_benchmark.py index 50bc913390..ae36afb13b 100644 --- a/.dev_scripts/test_benchmark.py +++ b/.dev_scripts/test_benchmark.py @@ -100,12 +100,18 @@ def create_test_job_batch(commands, model_info, args, port, script_name): http_prefix_short = 'https://download.openmmlab.com/mmediting/' http_prefix_long = 'https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmediting/' # noqa + http_prefix_gen = 'https://download.openmmlab.com/mmgen/' model_weight_url = model_info.weights if model_weight_url.startswith(http_prefix_long): model_name = model_weight_url[len(http_prefix_long):] elif model_weight_url.startswith(http_prefix_short): model_name = model_weight_url[len(http_prefix_short):] + elif model_weight_url.startswith(http_prefix_gen): + model_name = model_weight_url[len(http_prefix_gen):] + elif model_weight_url == '': + print(f'{fname} weight is missing') + return None else: raise ValueError(f'Unknown url prefix. \'{model_weight_url}\'') diff --git a/.dev_scripts/train_benchmark.py b/.dev_scripts/train_benchmark.py index d8f5eca71a..a7acc7e759 100644 --- a/.dev_scripts/train_benchmark.py +++ b/.dev_scripts/train_benchmark.py @@ -171,14 +171,19 @@ def create_train_job_batch(commands, model_info, args, port, script_name): config = Path(config) assert config.exists(), f'{fname}: {config} not found.' - # get n gpus try: n_gpus = int(model_info.metadata.data['GPUs'].split()[0]) except Exception: if 'official' in model_info.config: return None else: - n_gpus = 1 + pattern = r'\d+xb\d+' + parse_res = re.search(pattern, config.name) + if not parse_res: + # defaults to use 1 gpu + n_gpus = 1 + else: + n_gpus = int(parse_res.group().split('x')[0]) if args.gpus_per_job is not None: n_gpus = min(args.gpus_per_job, n_gpus) diff --git a/tools/test.py b/tools/test.py index 58dcdb9d8c..f353544daa 100644 --- a/tools/test.py +++ b/tools/test.py @@ -5,6 +5,7 @@ import mmengine from mmengine.config import Config, DictAction +from mmengine.hooks import Hook from mmengine.runner import Runner from mmedit.utils import print_colored_log, register_all_modules @@ -73,7 +74,7 @@ def main(): if args.out: - class SaveMetricHook(mmengine.Hook): + class SaveMetricHook(Hook): def after_test_epoch(self, _, metrics=None): if metrics is not None: