Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhance] Remove unnecessary calls and lazily import to speed up import performance #837

Merged
merged 6 commits into from
Jan 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mmengine/runner/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from typing import Callable, Dict, Optional

import torch
import torchvision

import mmengine
from mmengine.dist import get_dist_info
Expand Down Expand Up @@ -112,6 +111,7 @@ def load(module, prefix=''):


def get_torchvision_models():
import torchvision
if digit_version(torchvision.__version__) < digit_version('0.13.0a0'):
model_urls = dict()
# When the version of torchvision is lower than 0.13, the model url is
Expand Down
3 changes: 2 additions & 1 deletion mmengine/utils/dl_utils/collect_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import subprocess
import sys
from collections import OrderedDict, defaultdict
from distutils import errors

import cv2
import numpy as np
Expand Down Expand Up @@ -47,6 +46,8 @@ def collect_env():
- OpenCV (optional): OpenCV version.
- MMENGINE: MMENGINE version.
"""
from distutils import errors

env_info = OrderedDict()
env_info['sys.platform'] = sys.platform
env_info['Python'] = sys.version.replace('\n', '')
Expand Down
1 change: 0 additions & 1 deletion mmengine/utils/dl_utils/parrots_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,6 @@ def _get_norm() -> tuple:

_ConvNd, _ConvTransposeMixin = _get_conv()
DataLoader, PoolDataLoader = _get_dataloader()
BuildExtension, CppExtension, CUDAExtension = _get_extension()
_BatchNorm, _InstanceNorm, SyncBatchNorm_ = _get_norm()
_AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd = _get_pool()

Expand Down
12 changes: 9 additions & 3 deletions mmengine/utils/package_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,19 @@
import os.path as osp
import subprocess

import pkg_resources
from pkg_resources import get_distribution


def is_installed(package: str) -> bool:
"""Check package whether installed.

Args:
package (str): Name of package to be checked.
"""
# When executing `import mmengine.runner`,
# pkg_resources will be imported and it takes too much time.
# Therefore, import it in function scope to save time.
import pkg_resources
from pkg_resources import get_distribution

# refresh the pkg_resources
# more datails at https://github.com/pypa/setuptools/issues/373
importlib.reload(pkg_resources)
Expand All @@ -33,6 +36,8 @@ def get_installed_path(package: str) -> str:
>>> get_installed_path('mmcls')
>>> '.../lib/python3.7/site-packages/mmcls'
"""
from pkg_resources import get_distribution

# if the package name is not the same as module name, module name should be
# inferred. For example, mmcv-full is the package name, but mmcv is module
# name. If we want to get the installed path of mmcv-full, we should concat
Expand All @@ -51,6 +56,7 @@ def package2module(package: str):
Args:
package (str): Package to infer module name.
"""
from pkg_resources import get_distribution
pkg = get_distribution(package)
if pkg.has_metadata('top_level.txt'):
module_name = pkg.get_metadata('top_level.txt').split('\n')[0]
Expand Down
14 changes: 8 additions & 6 deletions mmengine/visualization/utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.

from typing import Any, List, Optional, Tuple, Type, Union
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Type, Union

import cv2
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.backend_bases import CloseEvent
from matplotlib.backends.backend_agg import FigureCanvasAgg

if TYPE_CHECKING:
from matplotlib.backends.backend_agg import FigureCanvasAgg


def tensor2ndarray(value: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
Expand Down Expand Up @@ -131,6 +130,7 @@ def color_str2rgb(color: str) -> tuple:
Returns:
tuple: RGB color.
"""
import matplotlib
rgb_color: tuple = matplotlib.colors.to_rgb(color)
rgb_color = tuple(int(c * 255) for c in rgb_color)
return rgb_color
Expand Down Expand Up @@ -186,6 +186,8 @@ def wait_continue(figure, timeout: int = 0, continue_key: str = ' ') -> int:
int: If zero, means time out or the user pressed ``continue_key``,
and if one, means the user closed the show figure.
""" # noqa: E501
import matplotlib.pyplot as plt
from matplotlib.backend_bases import CloseEvent
is_inline = 'inline' in plt.get_backend()
if is_inline:
# If use inline backend, interactive input and timeout is no use.
Expand Down Expand Up @@ -226,7 +228,7 @@ def handler(ev):
return 0 # Quit for continue.


def img_from_canvas(canvas: FigureCanvasAgg) -> np.ndarray:
def img_from_canvas(canvas: 'FigureCanvasAgg') -> np.ndarray:
"""Get RGB image from ``FigureCanvasAgg``.

Args:
Expand Down
18 changes: 10 additions & 8 deletions mmengine/visualization/visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,9 @@
from typing import Dict, List, Optional, Sequence, Tuple, Union

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib.collections import (LineCollection, PatchCollection,
PolyCollection)
from matplotlib.figure import Figure
from matplotlib.patches import Circle
from matplotlib.pyplot import new_figure_manager

from mmengine.config import Config
from mmengine.dist import master_only
Expand Down Expand Up @@ -240,6 +233,7 @@ def show(self,
continue_key (str): The key for users to continue. Defaults to
the space key.
"""
import matplotlib.pyplot as plt
is_inline = 'inline' in plt.get_backend()
img = self.get_image() if drawn_img is None else drawn_img
self._init_manager(win_name)
Expand Down Expand Up @@ -302,7 +296,8 @@ def _initialize_fig(self, fig_cfg) -> tuple:
Returns:
tuple: build canvas figure and axes.
"""

from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib.figure import Figure
fig = Figure(**fig_cfg)
ax = fig.add_subplot()
ax.axis(False)
Expand All @@ -318,6 +313,8 @@ def _init_manager(self, win_name: str) -> None:
Args:
win_name (str): The window name.
"""
from matplotlib.figure import Figure
from matplotlib.pyplot import new_figure_manager
if getattr(self, 'manager', None) is None:
self.manager = new_figure_manager(
num=1, FigureClass=Figure, **self.fig_show_cfg)
Expand Down Expand Up @@ -546,6 +543,7 @@ def draw_lines(
If ``line_widths`` is single value, all the lines will
have the same linewidth. Defaults to 2.
"""
from matplotlib.collections import LineCollection
check_type('x_datas', x_datas, (np.ndarray, torch.Tensor))
x_datas = tensor2ndarray(x_datas)
check_type('y_datas', y_datas, (np.ndarray, torch.Tensor))
Expand Down Expand Up @@ -614,6 +612,8 @@ def draw_circles(
alpha (Union[int, float]): The transparency of circles.
Defaults to 0.8.
"""
from matplotlib.collections import PatchCollection
from matplotlib.patches import Circle
check_type('center', center, (np.ndarray, torch.Tensor))
center = tensor2ndarray(center)
check_type('radius', radius, (np.ndarray, torch.Tensor))
Expand Down Expand Up @@ -760,6 +760,7 @@ def draw_polygons(
alpha (Union[int, float]): The transparency of polygons.
Defaults to 0.8.
"""
from matplotlib.collections import PolyCollection
check_type('polygons', polygons, (list, np.ndarray, torch.Tensor))
edge_colors = color_val_matplotlib(edge_colors) # type: ignore
face_colors = color_val_matplotlib(face_colors) # type: ignore
Expand Down Expand Up @@ -916,6 +917,7 @@ def draw_featmap(featmap: torch.Tensor,
Returns:
np.ndarray: RGB image.
"""
import matplotlib.pyplot as plt
assert isinstance(featmap,
torch.Tensor), (f'`featmap` should be torch.Tensor,'
f' but got {type(featmap)}')
Expand Down