Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
tripleMu committed May 29, 2022
1 parent 1e922b8 commit 92ba607
Show file tree
Hide file tree
Showing 9 changed files with 215 additions and 122 deletions.
21 changes: 13 additions & 8 deletions mmcv/runner/base_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from abc import ABCMeta
from collections import defaultdict
from logging import FileHandler
from typing import Dict, Iterable, Optional

import torch.nn as nn

Expand All @@ -29,7 +30,7 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
init_cfg (dict, optional): Initialization config dict.
"""

def __init__(self, init_cfg=None):
def __init__(self, init_cfg: Optional[Dict] = None):
"""Initialize BaseModule, inherited from `torch.nn.Module`"""

# NOTE init_cfg can be defined in different levels, but init_cfg
Expand All @@ -49,10 +50,10 @@ def __init__(self, init_cfg=None):
# self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)

@property
def is_init(self):
def is_init(self) -> bool:
return self._is_init

def init_weights(self):
def init_weights(self) -> None:
"""Initialize the weights."""

is_top_level_module = False
Expand All @@ -67,7 +68,7 @@ def init_weights(self):
# which indicates whether the parameter has been modified.
# this attribute would be deleted after all parameters
# is initialized.
self._params_init_info = defaultdict(dict)
self._params_init_info: defaultdict = defaultdict(dict)
is_top_level_module = True

# Initialize the `_params_init_info`,
Expand Down Expand Up @@ -133,7 +134,7 @@ def init_weights(self):
del sub_module._params_init_info

@master_only
def _dump_init_info(self, logger_name):
def _dump_init_info(self, logger_name: str) -> None:
"""Dump the initialization information to a file named
`initialization.log.json` in workdir.
Expand Down Expand Up @@ -176,7 +177,7 @@ class Sequential(BaseModule, nn.Sequential):
init_cfg (dict, optional): Initialization config dict.
"""

def __init__(self, *args, init_cfg=None):
def __init__(self, *args, init_cfg: Optional[Dict] = None):
BaseModule.__init__(self, init_cfg)
nn.Sequential.__init__(self, *args)

Expand All @@ -189,7 +190,9 @@ class ModuleList(BaseModule, nn.ModuleList):
init_cfg (dict, optional): Initialization config dict.
"""

def __init__(self, modules=None, init_cfg=None):
def __init__(self,
modules: Optional[Iterable] = None,
init_cfg: Optional[Dict] = None):
BaseModule.__init__(self, init_cfg)
nn.ModuleList.__init__(self, modules)

Expand All @@ -203,6 +206,8 @@ class ModuleDict(BaseModule, nn.ModuleDict):
init_cfg (dict, optional): Initialization config dict.
"""

def __init__(self, modules=None, init_cfg=None):
def __init__(self,
modules: Optional[Dict] = None,
init_cfg: Optional[Dict] = None):
BaseModule.__init__(self, init_cfg)
nn.ModuleDict.__init__(self, modules)
5 changes: 3 additions & 2 deletions mmcv/runner/builder.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Optional

from ..utils import Registry

RUNNERS = Registry('runner')
RUNNER_BUILDERS = Registry('runner builder')


def build_runner_constructor(cfg):
def build_runner_constructor(cfg: dict):
return RUNNER_BUILDERS.build(cfg)


def build_runner(cfg, default_args=None):
def build_runner(cfg: dict, default_args: Optional[dict] = None):
runner_cfg = copy.deepcopy(cfg)
constructor_type = runner_cfg.pop('constructor',
'DefaultRunnerConstructor')
Expand Down
Loading

0 comments on commit 92ba607

Please sign in to comment.