Skip to content

Commit

Permalink
[Refactor] Delete redundant set_random_seed function (open-mmlab#104)
Browse files Browse the repository at this point in the history
* refactor set_random_seed

* add unittests

* fix unittests error

* fix lint

* avoid bc breaking
  • Loading branch information
wutongshenqiu authored and pppppM committed Mar 27, 2022
1 parent 077c870 commit 0cbe900
Show file tree
Hide file tree
Showing 14 changed files with 139 additions and 46 deletions.
1 change: 1 addition & 0 deletions mmrazor/apis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
from .mmcls import * # noqa: F401,F403
from .mmdet import * # noqa: F401,F403
from .mmseg import * # noqa: F401,F403
from .utils import set_random_seed # noqa: F401
4 changes: 2 additions & 2 deletions mmrazor/apis/mmcls/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .inference import init_mmcls_model
from .train import set_random_seed, train_model
from .train import set_random_seed, train_mmcls_model

__all__ = ['set_random_seed', 'train_model', 'init_mmcls_model']
__all__ = ['train_mmcls_model', 'init_mmcls_model', 'set_random_seed']
24 changes: 15 additions & 9 deletions mmrazor/apis/mmcls/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@


def set_random_seed(seed, deterministic=False):
"""Set random seed.
"""Import `set_random_seed` function here was deprecated in v0.3 and will
be removed in v0.5.
Args:
seed (int): Seed to be used.
Expand All @@ -29,6 +30,11 @@ def set_random_seed(seed, deterministic=False):
to True and ``torch.backends.cudnn.benchmark`` to False.
Default: False.
"""
warnings.warn(
'Deprecated in v0.3 and will be removed in v0.5, '
'please import `set_random_seed` directly from `mmrazor.apis`',
category=DeprecationWarning)

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
Expand All @@ -38,14 +44,14 @@ def set_random_seed(seed, deterministic=False):
torch.backends.cudnn.benchmark = False


def train_model(model,
dataset,
cfg,
distributed=False,
validate=False,
timestamp=None,
device='cuda',
meta=None):
def train_mmcls_model(model,
dataset,
cfg,
distributed=False,
validate=False,
timestamp=None,
device='cuda',
meta=None):
"""Copy from mmclassification and modify some codes.
This is an ugly implementation, and will be deprecated in the future. In
Expand Down
4 changes: 2 additions & 2 deletions mmrazor/apis/mmdet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@

if mmdet is not None:
from .inference import init_mmdet_model
from .train import set_random_seed, train_detector
from .train import set_random_seed, train_mmdet_model

__all__ = ['set_random_seed', 'train_detector', 'init_mmdet_model']
__all__ = ['train_mmdet_model', 'init_mmdet_model', 'set_random_seed']
22 changes: 14 additions & 8 deletions mmrazor/apis/mmdet/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@


def set_random_seed(seed, deterministic=False):
"""Set random seed.
"""Import `set_random_seed` function here was deprecated in v0.3 and will
be removed in v0.5.
Args:
seed (int): Seed to be used.
Expand All @@ -28,6 +29,11 @@ def set_random_seed(seed, deterministic=False):
to True and ``torch.backends.cudnn.benchmark`` to False.
Default: False.
"""
warnings.warn(
'Deprecated in v0.3 and will be removed in v0.5, '
'please import `set_random_seed` directly from `mmrazor.apis`',
category=DeprecationWarning)

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
Expand All @@ -37,13 +43,13 @@ def set_random_seed(seed, deterministic=False):
torch.backends.cudnn.benchmark = False


def train_detector(model,
dataset,
cfg,
distributed=False,
validate=False,
timestamp=None,
meta=None):
def train_mmdet_model(model,
dataset,
cfg,
distributed=False,
validate=False,
timestamp=None,
meta=None):
"""Copy from mmdetection and modify some codes.
This is an ugly implementation, and will be deprecated in the future. In
Expand Down
4 changes: 2 additions & 2 deletions mmrazor/apis/mmseg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@

if mmseg:
from .inference import init_mmseg_model
from .train import set_random_seed, train_segmentor
from .train import set_random_seed, train_mmseg_model

__all__ = ['set_random_seed', 'train_segmentor', 'init_mmseg_model']
__all__ = ['train_mmseg_model', 'init_mmseg_model', 'set_random_seed']
22 changes: 14 additions & 8 deletions mmrazor/apis/mmseg/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@


def set_random_seed(seed, deterministic=False):
"""Set random seed.
"""Import `set_random_seed` function here was deprecated in v0.3 and will
be removed in v0.5.
Args:
seed (int): Seed to be used.
Expand All @@ -25,6 +26,11 @@ def set_random_seed(seed, deterministic=False):
to True and ``torch.backends.cudnn.benchmark`` to False.
Default: False.
"""
warnings.warn(
'Deprecated in v0.3 and will be removed in v0.5, '
'please import `set_random_seed` directly from `mmrazor.apis`',
category=DeprecationWarning)

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
Expand All @@ -34,13 +40,13 @@ def set_random_seed(seed, deterministic=False):
torch.backends.cudnn.benchmark = False


def train_segmentor(model,
dataset,
cfg,
distributed=False,
validate=False,
timestamp=None,
meta=None):
def train_mmseg_model(model,
dataset,
cfg,
distributed=False,
validate=False,
timestamp=None,
meta=None):
"""Copy from mmsegmentation and modify some codes.
This is an ugly implementation, and will be deprecated in the future. In
Expand Down
24 changes: 24 additions & 0 deletions mmrazor/apis/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright (c) OpenMMLab. All rights reserved.
import random

import numpy as np
import torch


def set_random_seed(seed: int, deterministic: bool = False) -> None:
"""Set random seed.
Args:
seed (int): Seed to be used.
deterministic (bool): Whether to set the deterministic option for
CUDNN backend, i.e., set ``torch.backends.cudnn.deterministic``
to True and ``torch.backends.cudnn.benchmark`` to False.
Default: False.
"""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
if deterministic:
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
18 changes: 18 additions & 0 deletions tests/deprecated_api/test_remove_0-5.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest


def test_v0_5_deprecated_set_random_seed() -> None:
warn_msg = 'Deprecated in v0.3 and will be removed in v0.5, ' \
'please import `set_random_seed` directly from `mmrazor.apis`'
from mmrazor.apis.mmcls import set_random_seed
with pytest.deprecated_call(match=warn_msg):
set_random_seed(123)

from mmrazor.apis.mmdet import set_random_seed
with pytest.deprecated_call(match=warn_msg):
set_random_seed(123)

from mmrazor.apis.mmseg import set_random_seed
with pytest.deprecated_call(match=warn_msg):
set_random_seed(123)
30 changes: 21 additions & 9 deletions tests/test_apis/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,26 @@
from mmrazor.apis import init_mmcls_model, init_mmdet_model, init_mmseg_model


def test_init_mmcls_model():
def _sync_bn2bn(config: mmcv.Config) -> None:

def dfs(cfg_dict) -> None:
if isinstance(cfg_dict, dict):
for k, v in cfg_dict.items():
if k == 'norm_cfg':
if v['type'] == 'SyncBN':
v['type'] = 'BN'
dfs(v)

dfs(config._cfg_dict)


def test_init_mmcls_model() -> None:
from mmcls.datasets import ImageNet

config_file = 'configs/nas/spos/spos_subnet_shufflenetv2_8xb128_in1k.py'
config = mmcv.Config.fromfile(config_file)
# Replace SyncBN with BN to inference on CPU
_sync_bn2bn(config)

mutable_file = 'configs/nas/spos/SPOS_SHUFFLENETV2_330M_IN1k_PAPER.yaml'
model = init_mmcls_model(
Expand All @@ -34,10 +49,12 @@ def test_init_mmcls_model():
assert result.get('pred_class') is not None


def test_init_mmdet_model():
def test_init_mmdet_model() -> None:
config_file = \
'configs/nas/detnas/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco.py'
config = mmcv.Config.fromfile(config_file)
# Replace SyncBN with BN to inference on CPU
_sync_bn2bn(config)

mutable_file = \
'configs/nas/detnas/DETNAS_FRCNN_SHUFFLENETV2_340M_COCO_MMRAZOR.yaml'
Expand All @@ -52,17 +69,12 @@ def test_init_mmdet_model():
assert isinstance(result, list)


def test_init_mmseg_model():
def test_init_mmseg_model() -> None:
config_file = 'configs/distill/cwd/' \
'cwd_cls_head_pspnet_r101_d8_pspnet_r18_d8_512x1024_cityscapes_80k.py'
config = mmcv.Config.fromfile(config_file)

# Replace SyncBN with BN to inference on CPU
norm_cfg = dict(type='BN', requires_grad=True)
model_config = config.algorithm.architecture
model_config.model.backbone.norm_cfg = norm_cfg
model_config.model.decode_head.norm_cfg = norm_cfg
model_config.model.auxiliary_head.norm_cfg = norm_cfg
_sync_bn2bn(config)

# Enable test time augmentation
config.data.test.pipeline[1].flip = True
Expand Down
20 changes: 20 additions & 0 deletions tests/test_apis/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch

from mmrazor.apis import set_random_seed


def test_set_random_seed() -> None:
set_random_seed(123, False)
x1 = torch.rand(3, 3)
x2 = np.random.rand(3, 3)

set_random_seed(123, True)
assert torch.backends.cudnn.deterministic
assert not torch.backends.cudnn.benchmark
y1 = torch.rand(3, 3)
y2 = np.random.rand(3, 3)

assert torch.allclose(x1, y1, 1e-6)
assert np.allclose(x2, y2, 1e-6)
4 changes: 2 additions & 2 deletions tools/mmcls/train_mmcls.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from mmcv.runner import get_dist_info, init_dist

# Differences from mmclassification
from mmrazor.apis.mmcls.train import set_random_seed, train_model
from mmrazor.apis import set_random_seed, train_mmcls_model
from mmrazor.models import build_algorithm
from mmrazor.utils import setup_multi_processes

Expand Down Expand Up @@ -179,7 +179,7 @@ def main():
config=cfg.pretty_text,
CLASSES=datasets[0].CLASSES)
# add an attribute for visualization convenience
train_model(
train_mmcls_model(
# Difference from mmclassification
# replace `model` to `algorithm`
algorithm,
Expand Down
4 changes: 2 additions & 2 deletions tools/mmdet/train_mmdet.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from mmdet.datasets import build_dataset
from mmdet.utils import collect_env, get_root_logger

from mmrazor.apis import set_random_seed, train_detector
from mmrazor.apis import set_random_seed, train_mmdet_model
from mmrazor.models import build_algorithm
from mmrazor.utils import setup_multi_processes

Expand Down Expand Up @@ -190,7 +190,7 @@ def main():
CLASSES=datasets[0].CLASSES)
# add an attribute for visualization convenience
algorithm.CLASSES = datasets[0].CLASSES
train_detector(
train_mmdet_model(
algorithm,
datasets,
cfg,
Expand Down
4 changes: 2 additions & 2 deletions tools/mmseg/train_mmseg.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from mmseg.utils import collect_env, get_root_logger

# Differences from mmdetection
from mmrazor.apis.mmseg import set_random_seed, train_segmentor
from mmrazor.apis import set_random_seed, train_mmseg_model
from mmrazor.models.builder import build_algorithm
from mmrazor.utils import setup_multi_processes

Expand Down Expand Up @@ -210,7 +210,7 @@ def main():
algorithm.CLASSES = datasets[0].CLASSES
# passing checkpoint meta for saving best checkpoint
meta.update(cfg.checkpoint_config.meta)
train_segmentor(
train_mmseg_model(
# Difference from mmsegmentation
# replace `model` to `algorithm`
algorithm,
Expand Down

0 comments on commit 0cbe900

Please sign in to comment.