From 0bf25f5d5e9d9aef19cc272d24319dd767c770cc Mon Sep 17 00:00:00 2001 From: zengyh1900 Date: Fri, 10 Nov 2023 19:36:53 +0800 Subject: [PATCH 1/2] support using from_pretrained for instance_crop --- .../inst-colorizatioon_full_official_cocostuff-256x256.py | 1 + mmagic/datasets/transforms/crop.py | 6 ++++++ 2 files changed, 7 insertions(+) diff --git a/configs/inst_colorization/inst-colorizatioon_full_official_cocostuff-256x256.py b/configs/inst_colorization/inst-colorizatioon_full_official_cocostuff-256x256.py index f754c38b44..5c4a07670a 100644 --- a/configs/inst_colorization/inst-colorizatioon_full_official_cocostuff-256x256.py +++ b/configs/inst_colorization/inst-colorizatioon_full_official_cocostuff-256x256.py @@ -48,6 +48,7 @@ dict( type='InstanceCrop', config_file='mmdet::mask_rcnn/mask-rcnn_x101-32x8d_fpn_ms-poly-3x_coco.py', # noqa + from_pretrained=None, finesize=256, box_num_upbound=5), dict( diff --git a/mmagic/datasets/transforms/crop.py b/mmagic/datasets/transforms/crop.py index 8af6ced157..5bc2b9eda9 100644 --- a/mmagic/datasets/transforms/crop.py +++ b/mmagic/datasets/transforms/crop.py @@ -958,6 +958,7 @@ class InstanceCrop(BaseTransform): def __init__(self, config_file, + from_pretrained=None, key='img', box_num_upbound=-1, finesize=256): @@ -967,6 +968,11 @@ def __init__(self, "\"mim install 'mmdet >= 3.0.0'\".") cfg = get_config(config_file, pretrained=True) + + # loading checkpoint from local path + if from_pretrained is not None: + cfg.model.backbone.init_cfg.checkpoint=from_pretrained + with DefaultScope.overwrite_default_scope('mmdet'): self.predictor = mmdet_apis.init_detector(cfg, cfg.model_path) From 9f782fcb8e66c15f456416f898f54feccc7b88e7 Mon Sep 17 00:00:00 2001 From: zengyh1900 Date: Fri, 10 Nov 2023 19:46:44 +0800 Subject: [PATCH 2/2] fix lint --- mmagic/datasets/transforms/crop.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mmagic/datasets/transforms/crop.py b/mmagic/datasets/transforms/crop.py index 5bc2b9eda9..22dac70ccc 100644 --- a/mmagic/datasets/transforms/crop.py +++ b/mmagic/datasets/transforms/crop.py @@ -968,11 +968,11 @@ def __init__(self, "\"mim install 'mmdet >= 3.0.0'\".") cfg = get_config(config_file, pretrained=True) - + # loading checkpoint from local path - if from_pretrained is not None: - cfg.model.backbone.init_cfg.checkpoint=from_pretrained - + if from_pretrained is not None: + cfg.model.backbone.init_cfg.checkpoint = from_pretrained + with DefaultScope.overwrite_default_scope('mmdet'): self.predictor = mmdet_apis.init_detector(cfg, cfg.model_path)