diff --git a/mmcv/runner/base_module.py b/mmcv/runner/base_module.py
index e3751dfa3f..7e64bdfb17 100644
--- a/mmcv/runner/base_module.py
+++ b/mmcv/runner/base_module.py
@@ -4,6 +4,7 @@
 from abc import ABCMeta
 from collections import defaultdict
 from logging import FileHandler
+from typing import Iterable, Optional
 
 import torch.nn as nn
 
@@ -29,7 +30,7 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
         init_cfg (dict, optional): Initialization config dict.
     """
 
-    def __init__(self, init_cfg=None):
+    def __init__(self, init_cfg: Optional[dict] = None):
         """Initialize BaseModule, inherited from `torch.nn.Module`"""
 
         # NOTE init_cfg can be defined in different levels, but init_cfg
@@ -133,7 +134,7 @@ def init_weights(self):
                 del sub_module._params_init_info
 
     @master_only
-    def _dump_init_info(self, logger_name):
+    def _dump_init_info(self, logger_name: str):
         """Dump the initialization information to a file named
         `initialization.log.json` in workdir.
 
@@ -176,7 +177,7 @@ class Sequential(BaseModule, nn.Sequential):
         init_cfg (dict, optional): Initialization config dict.
     """
 
-    def __init__(self, *args, init_cfg=None):
+    def __init__(self, *args, init_cfg: Optional[dict] = None):
         BaseModule.__init__(self, init_cfg)
         nn.Sequential.__init__(self, *args)
 
@@ -189,7 +190,9 @@ class ModuleList(BaseModule, nn.ModuleList):
         init_cfg (dict, optional): Initialization config dict.
     """
 
-    def __init__(self, modules=None, init_cfg=None):
+    def __init__(self,
+                 modules: Optional[Iterable] = None,
+                 init_cfg: Optional[dict] = None):
         BaseModule.__init__(self, init_cfg)
         nn.ModuleList.__init__(self, modules)
 
@@ -203,6 +206,8 @@ class ModuleDict(BaseModule, nn.ModuleDict):
         init_cfg (dict, optional): Initialization config dict.
     """
 
-    def __init__(self, modules=None, init_cfg=None):
+    def __init__(self,
+                 modules: Optional[dict] = None,
+                 init_cfg: Optional[dict] = None):
         BaseModule.__init__(self, init_cfg)
         nn.ModuleDict.__init__(self, modules)
diff --git a/mmcv/runner/checkpoint.py b/mmcv/runner/checkpoint.py
index 31e395edc7..bc58c5406e 100644
--- a/mmcv/runner/checkpoint.py
+++ b/mmcv/runner/checkpoint.py
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import io
+import logging
 import os
 import os.path as osp
 import pkgutil
@@ -9,6 +10,7 @@
 from collections import OrderedDict
 from importlib import import_module
 from tempfile import TemporaryDirectory
+from typing import Callable, List, Optional, Sequence, Union
 
 import torch
 import torchvision
@@ -37,7 +39,10 @@ def _get_mmcv_home():
     return mmcv_home
 
 
-def load_state_dict(module, state_dict, strict=False, logger=None):
+def load_state_dict(module: torch.nn.Module,
+                    state_dict: OrderedDict,
+                    strict: bool = False,
+                    logger: Optional[logging.Logger] = None) -> None:
     """Load state_dict to a module.
 
     This method is modified from :meth:`torch.nn.Module.load_state_dict`.
@@ -53,14 +58,14 @@ def load_state_dict(module, state_dict, strict=False, logger=None):
         logger (:obj:`logging.Logger`, optional): Logger to log the error
             message. If not specified, print function will be used.
""" - unexpected_keys = [] - all_missing_keys = [] - err_msg = [] + unexpected_keys: List = [] + all_missing_keys: List = [] + err_msg: List = [] metadata = getattr(state_dict, '_metadata', None) state_dict = state_dict.copy() if metadata is not None: - state_dict._metadata = metadata + state_dict._metadata = metadata # type: ignore # use _load_from_state_dict to enable checkpoint version control def load(module, prefix=''): @@ -78,7 +83,8 @@ def load(module, prefix=''): load(child, prefix + name + '.') load(module) - load = None # break load->load reference cycle + # break load->load reference cycle + load = None # type: ignore # ignore "num_batches_tracked" of BN layers missing_keys = [ @@ -96,7 +102,7 @@ def load(module, prefix=''): if len(err_msg) > 0 and rank == 0: err_msg.insert( 0, 'The model and loaded state dict do not match exactly\n') - err_msg = '\n'.join(err_msg) + err_msg = '\n'.join(err_msg) # type: ignore if strict: raise RuntimeError(err_msg) elif logger is not None: @@ -220,13 +226,16 @@ def _register_scheme(cls, prefixes, loader, force=False): sorted(cls._schemes.items(), key=lambda t: t[0], reverse=True)) @classmethod - def register_scheme(cls, prefixes, loader=None, force=False): + def register_scheme(cls, + prefixes: Union[str, Sequence[str]], + loader: Optional[Callable] = None, + force: bool = False): """Register a loader to CheckpointLoader. This method can be used as a normal class method or a decorator. Args: - prefixes (str or list[str] or tuple[str]): + prefixes (str or Sequence[str]): The prefix of the registered loader. loader (function, optional): The loader function to be registered. When this method is used as a decorator, loader is None. @@ -264,7 +273,12 @@ def _get_checkpoint_loader(cls, path): return cls._schemes[p] @classmethod - def load_checkpoint(cls, filename, map_location=None, logger=None): + def load_checkpoint( + cls, + filename: str, + map_location: Optional[str] = None, + logger: Optional[logging.Logger] = None + ) -> Union[dict, OrderedDict]: """load checkpoint through URL scheme path. Args: @@ -286,7 +300,9 @@ def load_checkpoint(cls, filename, map_location=None, logger=None): @CheckpointLoader.register_scheme(prefixes='') -def load_from_local(filename, map_location): +def load_from_local( + filename: str, + map_location: Optional[str] = None) -> Union[dict, OrderedDict]: """load checkpoint by local file path. Args: @@ -304,7 +320,10 @@ def load_from_local(filename, map_location): @CheckpointLoader.register_scheme(prefixes=('http://', 'https://')) -def load_from_http(filename, map_location=None, model_dir=None): +def load_from_http( + filename: str, + map_location: Optional[str] = None, + model_dir: Optional[str] = None) -> Union[dict, OrderedDict]: """load checkpoint through HTTP or HTTPS scheme path. In distributed setting, this function only download checkpoint at local rank 0. @@ -312,7 +331,7 @@ def load_from_http(filename, map_location=None, model_dir=None): filename (str): checkpoint file path with modelzoo or torchvision prefix map_location (str, optional): Same as :func:`torch.load`. 
-        model_dir (string, optional): directory in which to save the object,
+        model_dir (str, optional): directory in which to save the object,
             Default: None
 
     Returns:
@@ -331,7 +350,9 @@
 
 
 @CheckpointLoader.register_scheme(prefixes='pavi://')
-def load_from_pavi(filename, map_location=None):
+def load_from_pavi(
+        filename: str,
+        map_location: Optional[str] = None) -> Union[dict, OrderedDict]:
     """load checkpoint through the file path prefixed with pavi. In distributed
     setting, this function download ckpt at all ranks to different temporary
     directories.
@@ -363,7 +384,9 @@
 
 
 @CheckpointLoader.register_scheme(prefixes=r'(\S+\:)?s3://')
-def load_from_ceph(filename, map_location=None, backend='petrel'):
+def load_from_ceph(filename: str,
+                   map_location: Optional[str] = None,
+                   backend: str = 'petrel') -> Union[dict, OrderedDict]:
     """load checkpoint through the file path prefixed with s3. In distributed
     setting, this function download ckpt at all ranks to different temporary
     directories.
@@ -376,7 +399,7 @@
     Args:
         filename (str): checkpoint file path with s3 prefix
         map_location (str, optional): Same as :func:`torch.load`.
-        backend (str, optional): The storage backend type. Options are 'ceph',
+        backend (str): The storage backend type. Options are 'ceph',
             'petrel'. Default: 'petrel'.
 
     .. warning::
@@ -410,7 +433,9 @@
 
 
 @CheckpointLoader.register_scheme(prefixes=('modelzoo://', 'torchvision://'))
-def load_from_torchvision(filename, map_location=None):
+def load_from_torchvision(
+        filename: str,
+        map_location: Optional[str] = None) -> Union[dict, OrderedDict]:
     """load checkpoint through the file path prefixed with modelzoo or
    torchvision.
 
@@ -439,7 +464,9 @@
 
 
 @CheckpointLoader.register_scheme(prefixes=('open-mmlab://', 'openmmlab://'))
-def load_from_openmmlab(filename, map_location=None):
+def load_from_openmmlab(
+        filename: str,
+        map_location: Optional[str] = None) -> Union[dict, OrderedDict]:
     """load checkpoint through the file path prefixed with open-mmlab or
     openmmlab.
 
@@ -481,7 +508,9 @@
 
 
 @CheckpointLoader.register_scheme(prefixes='mmcls://')
-def load_from_mmcls(filename, map_location=None):
+def load_from_mmcls(
+        filename: str,
+        map_location: Optional[str] = None) -> Union[dict, OrderedDict]:
     """load checkpoint through the file path prefixed with mmcls.
 
     Args:
@@ -500,7 +529,10 @@
     return checkpoint
 
 
-def _load_checkpoint(filename, map_location=None, logger=None):
+def _load_checkpoint(
+        filename: str,
+        map_location: Optional[str] = None,
+        logger: Optional[logging.Logger] = None) -> Union[dict, OrderedDict]:
     """Load checkpoint from somewhere (modelzoo, file, url).
 
     Args:
@@ -520,7 +552,10 @@
     return CheckpointLoader.load_checkpoint(filename, map_location, logger)
 
 
-def _load_checkpoint_with_prefix(prefix, filename, map_location=None):
+def _load_checkpoint_with_prefix(
+        prefix: str,
+        filename: str,
+        map_location: Optional[str] = None) -> Union[dict, OrderedDict]:
     """Load partial pretrained model with specific prefix.
 
     Args:
@@ -553,12 +588,13 @@
     return state_dict
 
 
-def load_checkpoint(model,
-                    filename,
-                    map_location=None,
-                    strict=False,
-                    logger=None,
-                    revise_keys=[(r'^module\.', '')]):
+def load_checkpoint(
+        model: torch.nn.Module,
+        filename: str,
+        map_location: Optional[str] = None,
+        strict: bool = False,
+        logger: Optional[logging.Logger] = None,
+        revise_keys: list = [(r'^module\.', '')]) -> Union[dict, OrderedDict]:
     """Load checkpoint from a file or URI.
 
     Args:
@@ -603,7 +639,7 @@
     return checkpoint
 
 
-def weights_to_cpu(state_dict):
+def weights_to_cpu(state_dict: OrderedDict) -> OrderedDict:
     """Copy a model state_dict to cpu.
 
     Args:
@@ -616,11 +652,13 @@
     for key, val in state_dict.items():
         state_dict_cpu[key] = val.cpu()
     # Keep metadata in state_dict
-    state_dict_cpu._metadata = getattr(state_dict, '_metadata', OrderedDict())
+    state_dict_cpu._metadata = getattr(  # type: ignore
+        state_dict, '_metadata', OrderedDict())
     return state_dict_cpu
 
 
-def _save_to_state_dict(module, destination, prefix, keep_vars):
+def _save_to_state_dict(module: torch.nn.Module, destination: dict,
+                        prefix: str, keep_vars: bool) -> None:
     """Saves module state to `destination` dictionary.
 
     This method is modified from :meth:`torch.nn.Module._save_to_state_dict`.
@@ -640,7 +678,10 @@
             destination[prefix + name] = buf if keep_vars else buf.detach()
 
 
-def get_state_dict(module, destination=None, prefix='', keep_vars=False):
+def get_state_dict(module: torch.nn.Module,
+                   destination: Optional[OrderedDict] = None,
+                   prefix: str = '',
+                   keep_vars: bool = False) -> OrderedDict:
     """Returns a dictionary containing a whole state of the module.
 
     Both parameters and persistent buffers (e.g. running averages) are
@@ -669,8 +710,8 @@
     # below is the same as torch.nn.Module.state_dict()
     if destination is None:
         destination = OrderedDict()
-        destination._metadata = OrderedDict()
-    destination._metadata[prefix[:-1]] = local_metadata = dict(
+        destination._metadata = OrderedDict()  # type: ignore
+    destination._metadata[prefix[:-1]] = local_metadata = dict(  # type: ignore
         version=module._version)
     _save_to_state_dict(module, destination, prefix, keep_vars)
     for name, child in module._modules.items():
@@ -681,14 +722,14 @@
         hook_result = hook(module, destination, prefix, local_metadata)
         if hook_result is not None:
             destination = hook_result
-    return destination
+    return destination  # type: ignore
 
 
-def save_checkpoint(model,
-                    filename,
-                    optimizer=None,
-                    meta=None,
-                    file_client_args=None):
+def save_checkpoint(model: torch.nn.Module,
+                    filename: str,
+                    optimizer: Optional[Optimizer] = None,
+                    meta: Optional[dict] = None,
+                    file_client_args: Optional[dict] = None) -> None:
     """Save checkpoint to file.
 
     The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
diff --git a/mmcv/runner/default_constructor.py b/mmcv/runner/default_constructor.py
index 4a4f2cc646..394b51cfd7 100644
--- a/mmcv/runner/default_constructor.py
+++ b/mmcv/runner/default_constructor.py
@@ -1,4 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional
+
 from .builder import RUNNER_BUILDERS, RUNNERS
 
 
@@ -34,7 +36,7 @@ class DefaultRunnerConstructor:
         >>> runner = build_runner(runner_cfg)
     """
 
-    def __init__(self, runner_cfg, default_args=None):
+    def __init__(self, runner_cfg: dict, default_args: Optional[dict] = None):
         if not isinstance(runner_cfg, dict):
             raise TypeError('runner_cfg should be a dict',
                             f'but got {type(runner_cfg)}')
diff --git a/mmcv/runner/dist_utils.py b/mmcv/runner/dist_utils.py
index 26d77a8f95..abed57d2ca 100644
--- a/mmcv/runner/dist_utils.py
+++ b/mmcv/runner/dist_utils.py
@@ -5,6 +5,7 @@
 import socket
 import subprocess
 from collections import OrderedDict
+from typing import Callable, List, Optional, Tuple
 
 import torch
 import torch.multiprocessing as mp
@@ -33,7 +34,7 @@ def _is_free_port(port):
     return all(s.connect_ex((ip, port)) != 0 for ip in ips)
 
 
-def init_dist(launcher, backend='nccl', **kwargs):
+def init_dist(launcher: str, backend: str = 'nccl', **kwargs) -> None:
     if mp.get_start_method(allow_none=True) is None:
         mp.set_start_method('spawn')
     if launcher == 'pytorch':
@@ -46,7 +47,7 @@
         raise ValueError(f'Invalid launcher type: {launcher}')
 
 
-def _init_dist_pytorch(backend, **kwargs):
+def _init_dist_pytorch(backend: str, **kwargs):
     # TODO: use local_rank instead of rank % num_gpus
     rank = int(os.environ['RANK'])
     if IS_MLU_AVAILABLE:
@@ -63,7 +64,7 @@
     dist.init_process_group(backend=backend, **kwargs)
 
 
-def _init_dist_mpi(backend, **kwargs):
+def _init_dist_mpi(backend: str, **kwargs):
     local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
     torch.cuda.set_device(local_rank)
     if 'MASTER_PORT' not in os.environ:
@@ -76,7 +77,7 @@
     dist.init_process_group(backend=backend, **kwargs)
 
 
-def _init_dist_slurm(backend, port=None):
+def _init_dist_slurm(backend: str, port: Optional[int] = None):
     """Initialize slurm distributed training environment.
 
     If argument ``port`` is not specified, then the master port will be system
@@ -115,7 +116,7 @@
     dist.init_process_group(backend=backend)
 
 
-def get_dist_info():
+def get_dist_info() -> Tuple[int, int]:
     if dist.is_available() and dist.is_initialized():
         rank = dist.get_rank()
         world_size = dist.get_world_size()
@@ -125,7 +126,7 @@
     return rank, world_size
 
 
-def master_only(func):
+def master_only(func: Callable) -> Callable:
 
     @functools.wraps(func)
     def wrapper(*args, **kwargs):
@@ -136,12 +137,14 @@ def wrapper(*args, **kwargs):
     return wrapper
 
 
-def allreduce_params(params, coalesce=True, bucket_size_mb=-1):
+def allreduce_params(params: List[torch.nn.Parameter],
+                     coalesce: bool = True,
+                     bucket_size_mb: int = -1) -> None:
     """Allreduce parameters.
 
     Args:
-        params (list[torch.Parameters]): List of parameters or buffers of a
-            model.
+        params (list[torch.nn.Parameter]): List of parameters or buffers
+            of a model.
         coalesce (bool, optional): Whether allreduce parameters as a whole.
             Defaults to True.
         bucket_size_mb (int, optional): Size of bucket, the unit is MB.
@@ -158,11 +161,13 @@
             dist.all_reduce(tensor.div_(world_size))
 
 
-def allreduce_grads(params, coalesce=True, bucket_size_mb=-1):
+def allreduce_grads(params: List[torch.nn.Parameter],
+                    coalesce: bool = True,
+                    bucket_size_mb: int = -1) -> None:
     """Allreduce gradients.
 
     Args:
-        params (list[torch.Parameters]): List of parameters of a model
+        params (list[torch.nn.Parameter]): List of parameters of a model.
         coalesce (bool, optional): Whether allreduce parameters as a whole.
             Defaults to True.
         bucket_size_mb (int, optional): Size of bucket, the unit is MB.
diff --git a/mmcv/runner/fp16_utils.py b/mmcv/runner/fp16_utils.py
index be3ac3a314..9deb228a5d 100644
--- a/mmcv/runner/fp16_utils.py
+++ b/mmcv/runner/fp16_utils.py
@@ -3,10 +3,12 @@
 import warnings
 from collections import abc
 from inspect import getfullargspec
+from typing import Callable, Iterable, List, Optional
 
 import numpy as np
 import torch
 import torch.nn as nn
+from torch.nn.parameter import Parameter
 
 from mmcv.utils import TORCH_VERSION, digit_version
 from .dist_utils import allreduce_grads as _allreduce_grads
@@ -21,7 +23,7 @@
     pass
 
 
-def cast_tensor_type(inputs, src_type, dst_type):
+def cast_tensor_type(inputs, src_type: torch.dtype, dst_type: torch.dtype):
     """Recursively convert Tensor in inputs from src_type to dst_type.
 
     Note:
@@ -52,18 +54,22 @@
     elif isinstance(inputs, np.ndarray):
         return inputs
     elif isinstance(inputs, abc.Mapping):
-        return type(inputs)({
+        return type(inputs)({  # type: ignore
             k: cast_tensor_type(v, src_type, dst_type)
             for k, v in inputs.items()
         })
     elif isinstance(inputs, abc.Iterable):
-        return type(inputs)(
+        return type(inputs)(  # type: ignore
             cast_tensor_type(item, src_type, dst_type) for item in inputs)
     else:
         return inputs
 
 
-def auto_fp16(apply_to=None, out_fp32=False, supported_types=(nn.Module, )):
+def auto_fp16(
+    apply_to: Optional[Iterable] = None,
+    out_fp32: bool = False,
+    supported_types: tuple = (nn.Module, ),
+) -> Callable:
     """Decorator to enable fp16 training automatically.
 
     This decorator is useful when you write custom modules and want to support
@@ -150,7 +156,8 @@ def new_func(*args, **kwargs):
     return auto_fp16_wrapper
 
 
-def force_fp32(apply_to=None, out_fp16=False):
+def force_fp32(apply_to: Optional[Iterable] = None,
+               out_fp16: bool = False) -> Callable:
     """Decorator to convert input arguments to fp32 in force.
 
     This decorator is useful when you write custom modules and want to support
@@ -236,15 +243,17 @@ def new_func(*args, **kwargs):
     return force_fp32_wrapper
 
 
-def allreduce_grads(params, coalesce=True, bucket_size_mb=-1):
-    warnings.warning(
+def allreduce_grads(params: List[Parameter],
+                    coalesce: bool = True,
+                    bucket_size_mb: int = -1) -> None:
+    warnings.warn(
         '"mmcv.runner.fp16_utils.allreduce_grads" is deprecated, and will be '
         'removed in v2.8. Please switch to "mmcv.runner.allreduce_grads',
         DeprecationWarning)
     _allreduce_grads(params, coalesce=coalesce, bucket_size_mb=bucket_size_mb)
 
 
-def wrap_fp16_model(model):
+def wrap_fp16_model(model: nn.Module) -> None:
     """Wrap the FP32 model to FP16.
 
     If you are using PyTorch >= 1.6, torch.cuda.amp is used as the
@@ -273,7 +282,7 @@
             m.fp16_enabled = True
 
 
-def patch_norm_fp32(module):
+def patch_norm_fp32(module: nn.Module) -> nn.Module:
     """Recursively convert normalization layers from FP16 to FP32.
 
     Args:
@@ -293,7 +302,10 @@
     return module
 
 
-def patch_forward_method(func, src_type, dst_type, convert_output=True):
+def patch_forward_method(func: Callable,
+                         src_type: torch.dtype,
+                         dst_type: torch.dtype,
+                         convert_output: bool = True) -> Callable:
     """Patch the forward method of a module.
 
     Args:
@@ -346,10 +358,10 @@ class LossScaler:
     """
 
     def __init__(self,
-                 init_scale=2**32,
-                 mode='dynamic',
-                 scale_factor=2.,
-                 scale_window=1000):
+                 init_scale: float = 2**32,
+                 mode: str = 'dynamic',
+                 scale_factor: float = 2.,
+                 scale_window: int = 1000):
         self.cur_scale = init_scale
         self.cur_iter = 0
         assert mode in ('dynamic',
@@ -359,7 +371,7 @@ def __init__(self,
         self.scale_factor = scale_factor
         self.scale_window = scale_window
 
-    def has_overflow(self, params):
+    def has_overflow(self, params: List[Parameter]) -> bool:
         """Check if params contain overflow."""
         if self.mode != 'dynamic':
             return False
@@ -382,7 +394,7 @@ def _has_inf_or_nan(x):
                 return True
             return False
 
-    def update_scale(self, overflow):
+    def update_scale(self, overflow: bool) -> None:
         """update the current loss scale value when overflow happens."""
         if self.mode != 'dynamic':
             return
@@ -405,7 +417,7 @@
             scale_factor=self.scale_factor,
             scale_window=self.scale_window)
 
-    def load_state_dict(self, state_dict):
+    def load_state_dict(self, state_dict: dict) -> None:
         """Loads the loss_scaler state dict.
 
         Args:
diff --git a/mmcv/runner/log_buffer.py b/mmcv/runner/log_buffer.py
index d949e2941c..3c9f379637 100644
--- a/mmcv/runner/log_buffer.py
+++ b/mmcv/runner/log_buffer.py
@@ -12,16 +12,16 @@ def __init__(self):
         self.output = OrderedDict()
         self.ready = False
 
-    def clear(self):
+    def clear(self) -> None:
         self.val_history.clear()
         self.n_history.clear()
         self.clear_output()
 
-    def clear_output(self):
+    def clear_output(self) -> None:
         self.output.clear()
         self.ready = False
 
-    def update(self, vars, count=1):
+    def update(self, vars: dict, count: int = 1) -> None:
         assert isinstance(vars, dict)
         for key, var in vars.items():
             if key not in self.val_history:
@@ -30,7 +30,7 @@ def update(self, vars, count=1):
             self.val_history[key].append(var)
             self.n_history[key].append(count)
 
-    def average(self, n=0):
+    def average(self, n: int = 0) -> None:
         """Average latest n values or all values."""
         assert n >= 0
         for key in self.val_history:
diff --git a/mmcv/runner/priority.py b/mmcv/runner/priority.py
index 64cc4e3a05..ff644043b8 100644
--- a/mmcv/runner/priority.py
+++ b/mmcv/runner/priority.py
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from enum import Enum
+from typing import Union
 
 
 class Priority(Enum):
@@ -39,7 +40,7 @@ class Priority(Enum):
     LOWEST = 100
 
 
-def get_priority(priority):
+def get_priority(priority: Union[int, str, Priority]) -> int:
     """Get priority value.
 
     Args: