Skip to content

Commit

Permalink
lazily infer client until calling get function of fileio
Browse files Browse the repository at this point in the history
  • Loading branch information
zhouzaida committed Jan 16, 2023
1 parent a79c9f9 commit a4cc3b0
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 35 deletions.
60 changes: 27 additions & 33 deletions mmcv/transforms/loading.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import Optional, Union
from typing import Optional

import mmengine.fileio as fileio
import numpy as np
Expand Down Expand Up @@ -63,6 +63,8 @@ def __init__(self,
self.color_type = color_type
self.imdecode_backend = imdecode_backend

self.file_client_args: Optional[dict] = None
self.backend_args: Optional[dict] = None
if file_client_args is not None:
warnings.warn(
'"file_client_args" will be deprecated in future. '
Expand All @@ -71,21 +73,10 @@ def __init__(self,
raise ValueError(
'"file_client_args" and "backend_args" cannot be set '
'at the same time.')
else:
file_client_args = dict(backend='disk')
if backend_args is None:
backend_args = dict(backend='disk')

self.file_client_args = file_client_args.copy()
self.file_client = fileio.FileClient(**self.file_client_args)
self.backend_args = backend_args.copy()

self.file_backend: Union[fileio.FileClient, fileio.BaseStorageBackend]
if self.file_client_args is None:
self.file_backend = fileio.get_file_backend(
backend_args=self.backend_args)
else:
self.file_backend = self.file_client

self.file_client_args = file_client_args.copy()
if backend_args is not None:
self.backend_args = backend_args.copy()

def transform(self, results: dict) -> Optional[dict]:
"""Functions to load image.
Expand All @@ -100,7 +91,13 @@ def transform(self, results: dict) -> Optional[dict]:

filename = results['img_path']
try:
img_bytes = self.file_backend.get(filename)
if self.file_client_args is not None:
file_client = fileio.FileClient.infer_client(
self.file_client_args, filename)
img_bytes = file_client.get(filename)
else:
img_bytes = fileio.get(
filename, backend_args=self.backend_args)
img = mmcv.imfrombytes(
img_bytes, flag=self.color_type, backend=self.imdecode_backend)
except Exception as e:
Expand Down Expand Up @@ -240,6 +237,8 @@ def __init__(
self.with_keypoints = with_keypoints
self.imdecode_backend = imdecode_backend

self.file_client_args: Optional[dict] = None
self.backend_args: Optional[dict] = None
if file_client_args is not None:
warnings.warn(
'"file_client_args" will be deprecated in future. '
Expand All @@ -248,21 +247,10 @@ def __init__(
raise ValueError(
'"file_client_args" and "backend_args" cannot be set '
'at the same time.')
else:
file_client_args = dict(backend='disk')
if backend_args is None:
backend_args = dict(backend='disk')

self.file_client_args = file_client_args.copy()
self.file_client = fileio.FileClient(**self.file_client_args)
self.backend_args = backend_args.copy()

self.file_backend: Union[fileio.FileClient, fileio.BaseStorageBackend]
if self.file_client_args is None:
self.file_backend = fileio.get_file_backend(
backend_args=self.backend_args)
else:
self.file_backend = self.file_client

self.file_client_args = file_client_args.copy()
if backend_args is not None:
self.backend_args = backend_args.copy()

def _load_bboxes(self, results: dict) -> None:
"""Private function to load bounding box annotations.
Expand Down Expand Up @@ -306,8 +294,14 @@ def _load_seg_map(self, results: dict) -> None:
Returns:
dict: The dict contains loaded semantic segmentation annotations.
"""
if self.file_client_args is not None:
file_client = fileio.FileClient.infer_client(
self.file_client_args, results['seg_map_path'])
img_bytes = file_client.get(results['seg_map_path'])
else:
img_bytes = fileio.get(
results['seg_map_path'], backend_args=self.backend_args)

img_bytes = self.file_backend.get(results['seg_map_path'])
results['gt_seg_map'] = mmcv.imfrombytes(
img_bytes, flag='unchanged',
backend=self.imdecode_backend).squeeze()
Expand Down
4 changes: 2 additions & 2 deletions tests/test_transforms/test_transforms_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def test_load_img(self):
assert results['ori_shape'] == (300, 400)
assert repr(transform) == transform.__class__.__name__ + \
"(ignore_empty=False, to_float32=False, color_type='color', " + \
"imdecode_backend='cv2', file_client_args={'backend': 'disk'})"
"imdecode_backend='cv2', backend_args=None)"

# to_float32
transform = LoadImageFromFile(to_float32=True)
Expand Down Expand Up @@ -148,4 +148,4 @@ def test_repr(self):
'LoadAnnotations(with_bbox=True, '
'with_label=False, with_seg=False, '
"with_keypoints=False, imdecode_backend='cv2', "
"file_client_args={'backend': 'disk'})")
'backend_args=None)')

0 comments on commit a4cc3b0

Please sign in to comment.