Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type hints for mmcv/fileio #1997

Merged
merged 19 commits into from
May 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions mmcv/fileio/handlers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ def dump_to_fileobj(self, obj, file, **kwargs):
def dump_to_str(self, obj, **kwargs):
pass

def load_from_path(self, filepath, mode='r', **kwargs):
def load_from_path(self, filepath: str, mode: str = 'r', **kwargs):
with open(filepath, mode) as f:
return self.load_from_fileobj(f, **kwargs)

def dump_to_path(self, obj, filepath, mode='w', **kwargs):
def dump_to_path(self, obj, filepath: str, mode: str = 'w', **kwargs):
with open(filepath, mode) as f:
self.dump_to_fileobj(obj, f, **kwargs)
32 changes: 22 additions & 10 deletions mmcv/fileio/io.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved.
from io import BytesIO, StringIO
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union

from ..utils import is_list_of, is_str
from ..utils import is_list_of
from .file_client import FileClient
from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler
triple-Mu marked this conversation as resolved.
Show resolved Hide resolved

FileLikeObject = Union[StringIO, BytesIO]

file_handlers = {
'json': JsonHandler(),
'yaml': YamlHandler(),
Expand All @@ -15,7 +18,10 @@
}


def load(file, file_format=None, file_client_args=None, **kwargs):
def load(file: Union[str, Path, FileLikeObject],
file_format: Optional[str] = None,
file_client_args: Optional[Dict] = None,
**kwargs):
"""Load data from json/yaml/pickle files.

This method provides a unified api for loading data from serialized files.
Expand Down Expand Up @@ -45,13 +51,14 @@ def load(file, file_format=None, file_client_args=None, **kwargs):
"""
if isinstance(file, Path):
file = str(file)
if file_format is None and is_str(file):
if file_format is None and isinstance(file, str):
file_format = file.split('.')[-1]
if file_format not in file_handlers:
raise TypeError(f'Unsupported format: {file_format}')

handler = file_handlers[file_format]
if is_str(file):
f: FileLikeObject
if isinstance(file, str):
triple-Mu marked this conversation as resolved.
Show resolved Hide resolved
file_client = FileClient.infer_client(file_client_args, file)
if handler.str_like:
with StringIO(file_client.get_text(file)) as f:
Expand All @@ -66,7 +73,11 @@ def load(file, file_format=None, file_client_args=None, **kwargs):
return obj


def dump(obj, file=None, file_format=None, file_client_args=None, **kwargs):
def dump(obj: Any,
file: Optional[Union[str, Path, FileLikeObject]] = None,
file_format: Optional[str] = None,
file_client_args: Optional[Dict] = None,
**kwargs):
"""Dump data to json/yaml/pickle strings or files.

This method provides a unified api for dumping data as strings or to files,
Expand Down Expand Up @@ -96,18 +107,18 @@ def dump(obj, file=None, file_format=None, file_client_args=None, **kwargs):
if isinstance(file, Path):
file = str(file)
if file_format is None:
if is_str(file):
if isinstance(file, str):
file_format = file.split('.')[-1]
elif file is None:
raise ValueError(
'file_format must be specified since file is None')
if file_format not in file_handlers:
raise TypeError(f'Unsupported format: {file_format}')

f: FileLikeObject
handler = file_handlers[file_format]
if file is None:
return handler.dump_to_str(obj, **kwargs)
elif is_str(file):
elif isinstance(file, str):
file_client = FileClient.infer_client(file_client_args, file)
if handler.str_like:
with StringIO() as f:
Expand All @@ -123,7 +134,8 @@ def dump(obj, file=None, file_format=None, file_client_args=None, **kwargs):
raise TypeError('"file" must be a filename str or a file-object')


def _register_handler(handler, file_formats):
def _register_handler(handler: BaseFileHandler,
file_formats: Union[str, List[str]]) -> None:
"""Register a handler for some file extensions.

Args:
Expand All @@ -142,7 +154,7 @@ def _register_handler(handler, file_formats):
file_handlers[ext] = handler


def register_handler(file_formats, **kwargs):
def register_handler(file_formats: Union[str, list], **kwargs) -> Callable:

def wrap(cls):
_register_handler(cls(**kwargs), file_formats)
Expand Down
22 changes: 12 additions & 10 deletions mmcv/fileio/parse.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
# Copyright (c) OpenMMLab. All rights reserved.

from io import StringIO
from pathlib import Path
from typing import Dict, List, Optional, Union

from .file_client import FileClient


def list_from_file(filename,
prefix='',
offset=0,
max_num=0,
encoding='utf-8',
file_client_args=None):
def list_from_file(filename: Union[str, Path],
prefix: str = '',
offset: int = 0,
max_num: int = 0,
encoding: str = 'utf-8',
file_client_args: Optional[Dict] = None) -> List:
"""Load a text file and parse the content as a list of strings.

Note:
Expand Down Expand Up @@ -52,10 +54,10 @@ def list_from_file(filename,
return item_list


def dict_from_file(filename,
key_type=str,
encoding='utf-8',
file_client_args=None):
def dict_from_file(filename: Union[str, Path],
key_type: type = str,
encoding: str = 'utf-8',
file_client_args: Optional[Dict] = None) -> Dict:
"""Load a text file and parse the content as a dict.

Each line of the text file will be two or more columns split by
Expand Down