Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhance] Add type ints in these files: #2020

Merged
merged 6 commits into from
Jun 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions mmcv/runner/base_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from abc import ABCMeta
from collections import defaultdict
from logging import FileHandler
from typing import Iterable, Optional

import torch.nn as nn

Expand All @@ -29,7 +30,7 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
init_cfg (dict, optional): Initialization config dict.
"""

def __init__(self, init_cfg=None):
def __init__(self, init_cfg: Optional[dict] = None):
"""Initialize BaseModule, inherited from `torch.nn.Module`"""

# NOTE init_cfg can be defined in different levels, but init_cfg
Expand Down Expand Up @@ -133,7 +134,7 @@ def init_weights(self):
del sub_module._params_init_info

@master_only
def _dump_init_info(self, logger_name):
def _dump_init_info(self, logger_name: str):
"""Dump the initialization information to a file named
`initialization.log.json` in workdir.

Expand Down Expand Up @@ -176,7 +177,7 @@ class Sequential(BaseModule, nn.Sequential):
init_cfg (dict, optional): Initialization config dict.
"""

def __init__(self, *args, init_cfg=None):
def __init__(self, *args, init_cfg: Optional[dict] = None):
BaseModule.__init__(self, init_cfg)
nn.Sequential.__init__(self, *args)

Expand All @@ -189,7 +190,9 @@ class ModuleList(BaseModule, nn.ModuleList):
init_cfg (dict, optional): Initialization config dict.
"""

def __init__(self, modules=None, init_cfg=None):
def __init__(self,
modules: Optional[Iterable] = None,
init_cfg: Optional[dict] = None):
BaseModule.__init__(self, init_cfg)
nn.ModuleList.__init__(self, modules)

Expand All @@ -203,6 +206,8 @@ class ModuleDict(BaseModule, nn.ModuleDict):
init_cfg (dict, optional): Initialization config dict.
"""

def __init__(self, modules=None, init_cfg=None):
def __init__(self,
modules: Optional[dict] = None,
init_cfg: Optional[dict] = None):
BaseModule.__init__(self, init_cfg)
nn.ModuleDict.__init__(self, modules)
119 changes: 80 additions & 39 deletions mmcv/runner/checkpoint.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import io
import logging
import os
import os.path as osp
import pkgutil
Expand All @@ -9,6 +10,7 @@
from collections import OrderedDict
from importlib import import_module
from tempfile import TemporaryDirectory
from typing import Callable, List, Optional, Sequence, Union

import torch
import torchvision
Expand Down Expand Up @@ -37,7 +39,10 @@ def _get_mmcv_home():
return mmcv_home


def load_state_dict(module, state_dict, strict=False, logger=None):
def load_state_dict(module: torch.nn.Module,
state_dict: OrderedDict,
strict: bool = False,
logger: Optional[logging.Logger] = None) -> None:
"""Load state_dict to a module.

This method is modified from :meth:`torch.nn.Module.load_state_dict`.
Expand All @@ -53,14 +58,14 @@ def load_state_dict(module, state_dict, strict=False, logger=None):
logger (:obj:`logging.Logger`, optional): Logger to log the error
message. If not specified, print function will be used.
"""
unexpected_keys = []
all_missing_keys = []
err_msg = []
unexpected_keys: List = []
all_missing_keys: List = []
err_msg: List = []

metadata = getattr(state_dict, '_metadata', None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata
state_dict._metadata = metadata # type: ignore

# use _load_from_state_dict to enable checkpoint version control
def load(module, prefix=''):
Expand All @@ -78,7 +83,8 @@ def load(module, prefix=''):
load(child, prefix + name + '.')

load(module)
load = None # break load->load reference cycle
# break load->load reference cycle
load = None # type: ignore

# ignore "num_batches_tracked" of BN layers
missing_keys = [
Expand All @@ -96,7 +102,7 @@ def load(module, prefix=''):
if len(err_msg) > 0 and rank == 0:
err_msg.insert(
0, 'The model and loaded state dict do not match exactly\n')
err_msg = '\n'.join(err_msg)
err_msg = '\n'.join(err_msg) # type: ignore
if strict:
raise RuntimeError(err_msg)
elif logger is not None:
Expand Down Expand Up @@ -220,13 +226,16 @@ def _register_scheme(cls, prefixes, loader, force=False):
sorted(cls._schemes.items(), key=lambda t: t[0], reverse=True))

@classmethod
def register_scheme(cls, prefixes, loader=None, force=False):
def register_scheme(cls,
prefixes: Union[str, Sequence[str]],
loader: Optional[Callable] = None,
force: bool = False):
"""Register a loader to CheckpointLoader.

This method can be used as a normal class method or a decorator.

Args:
prefixes (str or list[str] or tuple[str]):
prefixes (str or Sequence[str]):
The prefix of the registered loader.
loader (function, optional): The loader function to be registered.
When this method is used as a decorator, loader is None.
Expand Down Expand Up @@ -264,7 +273,12 @@ def _get_checkpoint_loader(cls, path):
return cls._schemes[p]

@classmethod
def load_checkpoint(cls, filename, map_location=None, logger=None):
def load_checkpoint(
cls,
filename: str,
map_location: Optional[str] = None,
logger: Optional[logging.Logger] = None
) -> Union[dict, OrderedDict]:
"""load checkpoint through URL scheme path.

Args:
Expand All @@ -286,7 +300,9 @@ def load_checkpoint(cls, filename, map_location=None, logger=None):


@CheckpointLoader.register_scheme(prefixes='')
def load_from_local(filename, map_location):
def load_from_local(
filename: str,
map_location: Optional[str] = None) -> Union[dict, OrderedDict]:
"""load checkpoint by local file path.

Args:
Expand All @@ -304,15 +320,18 @@ def load_from_local(filename, map_location):


@CheckpointLoader.register_scheme(prefixes=('http://', 'https://'))
def load_from_http(filename, map_location=None, model_dir=None):
def load_from_http(
filename: str,
map_location: Optional[str] = None,
model_dir: Optional[str] = None) -> Union[dict, OrderedDict]:
"""load checkpoint through HTTP or HTTPS scheme path. In distributed
setting, this function only download checkpoint at local rank 0.

Args:
filename (str): checkpoint file path with modelzoo or
torchvision prefix
map_location (str, optional): Same as :func:`torch.load`.
model_dir (string, optional): directory in which to save the object,
model_dir (str, optional): directory in which to save the object,
Default: None

Returns:
Expand All @@ -331,7 +350,9 @@ def load_from_http(filename, map_location=None, model_dir=None):


@CheckpointLoader.register_scheme(prefixes='pavi://')
def load_from_pavi(filename, map_location=None):
def load_from_pavi(
filename: str,
map_location: Optional[str] = None) -> Union[dict, OrderedDict]:
"""load checkpoint through the file path prefixed with pavi. In distributed
setting, this function download ckpt at all ranks to different temporary
directories.
Expand Down Expand Up @@ -363,7 +384,9 @@ def load_from_pavi(filename, map_location=None):


@CheckpointLoader.register_scheme(prefixes=r'(\S+\:)?s3://')
def load_from_ceph(filename, map_location=None, backend='petrel'):
def load_from_ceph(filename: str,
map_location: Optional[str] = None,
backend: str = 'petrel') -> Union[dict, OrderedDict]:
"""load checkpoint through the file path prefixed with s3. In distributed
setting, this function download ckpt at all ranks to different temporary
directories.
Expand All @@ -376,7 +399,7 @@ def load_from_ceph(filename, map_location=None, backend='petrel'):
Args:
filename (str): checkpoint file path with s3 prefix
map_location (str, optional): Same as :func:`torch.load`.
backend (str, optional): The storage backend type. Options are 'ceph',
backend (str): The storage backend type. Options are 'ceph',
'petrel'. Default: 'petrel'.

.. warning::
Expand Down Expand Up @@ -410,7 +433,9 @@ def load_from_ceph(filename, map_location=None, backend='petrel'):


@CheckpointLoader.register_scheme(prefixes=('modelzoo://', 'torchvision://'))
def load_from_torchvision(filename, map_location=None):
def load_from_torchvision(
filename: str,
map_location: Optional[str] = None) -> Union[dict, OrderedDict]:
"""load checkpoint through the file path prefixed with modelzoo or
torchvision.

Expand Down Expand Up @@ -439,7 +464,9 @@ def load_from_torchvision(filename, map_location=None):


@CheckpointLoader.register_scheme(prefixes=('open-mmlab://', 'openmmlab://'))
def load_from_openmmlab(filename, map_location=None):
def load_from_openmmlab(
filename: str,
map_location: Optional[str] = None) -> Union[dict, OrderedDict]:
"""load checkpoint through the file path prefixed with open-mmlab or
openmmlab.

Expand Down Expand Up @@ -481,7 +508,9 @@ def load_from_openmmlab(filename, map_location=None):


@CheckpointLoader.register_scheme(prefixes='mmcls://')
def load_from_mmcls(filename, map_location=None):
def load_from_mmcls(
filename: str,
map_location: Optional[str] = None) -> Union[dict, OrderedDict]:
"""load checkpoint through the file path prefixed with mmcls.

Args:
Expand All @@ -500,7 +529,10 @@ def load_from_mmcls(filename, map_location=None):
return checkpoint


def _load_checkpoint(filename, map_location=None, logger=None):
def _load_checkpoint(
filename: str,
map_location: Optional[str] = None,
logger: Optional[logging.Logger] = None) -> Union[dict, OrderedDict]:
"""Load checkpoint from somewhere (modelzoo, file, url).

Args:
Expand All @@ -520,7 +552,10 @@ def _load_checkpoint(filename, map_location=None, logger=None):
return CheckpointLoader.load_checkpoint(filename, map_location, logger)


def _load_checkpoint_with_prefix(prefix, filename, map_location=None):
def _load_checkpoint_with_prefix(
prefix: str,
filename: str,
map_location: Optional[str] = None) -> Union[dict, OrderedDict]:
"""Load partial pretrained model with specific prefix.

Args:
Expand Down Expand Up @@ -553,12 +588,13 @@ def _load_checkpoint_with_prefix(prefix, filename, map_location=None):
return state_dict


def load_checkpoint(model,
filename,
map_location=None,
strict=False,
logger=None,
revise_keys=[(r'^module\.', '')]):
def load_checkpoint(
model: torch.nn.Module,
filename: str,
map_location: Optional[str] = None,
strict: bool = False,
logger: Optional[logging.Logger] = None,
revise_keys: list = [(r'^module\.', '')]) -> Union[dict, OrderedDict]:
"""Load checkpoint from a file or URI.

Args:
Expand Down Expand Up @@ -603,7 +639,7 @@ def load_checkpoint(model,
return checkpoint


def weights_to_cpu(state_dict):
def weights_to_cpu(state_dict: OrderedDict) -> OrderedDict:
"""Copy a model state_dict to cpu.

Args:
Expand All @@ -616,11 +652,13 @@ def weights_to_cpu(state_dict):
for key, val in state_dict.items():
state_dict_cpu[key] = val.cpu()
# Keep metadata in state_dict
state_dict_cpu._metadata = getattr(state_dict, '_metadata', OrderedDict())
state_dict_cpu._metadata = getattr( # type: ignore
state_dict, '_metadata', OrderedDict())
return state_dict_cpu


def _save_to_state_dict(module, destination, prefix, keep_vars):
def _save_to_state_dict(module: torch.nn.Module, destination: dict,
prefix: str, keep_vars: bool) -> None:
"""Saves module state to `destination` dictionary.

This method is modified from :meth:`torch.nn.Module._save_to_state_dict`.
Expand All @@ -640,7 +678,10 @@ def _save_to_state_dict(module, destination, prefix, keep_vars):
destination[prefix + name] = buf if keep_vars else buf.detach()


def get_state_dict(module, destination=None, prefix='', keep_vars=False):
def get_state_dict(module: torch.nn.Module,
destination: Optional[OrderedDict] = None,
prefix: str = '',
keep_vars: bool = False) -> OrderedDict:
"""Returns a dictionary containing a whole state of the module.

Both parameters and persistent buffers (e.g. running averages) are
Expand Down Expand Up @@ -669,8 +710,8 @@ def get_state_dict(module, destination=None, prefix='', keep_vars=False):
# below is the same as torch.nn.Module.state_dict()
if destination is None:
destination = OrderedDict()
destination._metadata = OrderedDict()
destination._metadata[prefix[:-1]] = local_metadata = dict(
destination._metadata = OrderedDict() # type: ignore
destination._metadata[prefix[:-1]] = local_metadata = dict( # type: ignore
version=module._version)
_save_to_state_dict(module, destination, prefix, keep_vars)
for name, child in module._modules.items():
Expand All @@ -681,14 +722,14 @@ def get_state_dict(module, destination=None, prefix='', keep_vars=False):
hook_result = hook(module, destination, prefix, local_metadata)
if hook_result is not None:
destination = hook_result
return destination
return destination # type: ignore


def save_checkpoint(model,
filename,
optimizer=None,
meta=None,
file_client_args=None):
def save_checkpoint(model: torch.nn.Module,
filename: str,
optimizer: Optional[Optimizer] = None,
meta: Optional[dict] = None,
file_client_args: Optional[dict] = None) -> None:
"""Save checkpoint to file.

The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
Expand Down
4 changes: 3 additions & 1 deletion mmcv/runner/default_constructor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

from .builder import RUNNER_BUILDERS, RUNNERS


Expand Down Expand Up @@ -34,7 +36,7 @@ class DefaultRunnerConstructor:
>>> runner = build_runner(runner_cfg)
"""

def __init__(self, runner_cfg, default_args=None):
def __init__(self, runner_cfg: dict, default_args: Optional[dict] = None):
if not isinstance(runner_cfg, dict):
raise TypeError('runner_cfg should be a dict',
f'but got {type(runner_cfg)}')
Expand Down
Loading