Skip to content

Commit e751042

Browse files
Jeff Yangvfdev-5
authored andcommitted
docs: rm type hints in ignite.utils (1) (#1684)
* docs: rm type hints in ignite.utils * fix: change input_ to input in ignite.utils * fix: change input to x
1 parent 1ff17ae commit e751042

File tree

1 file changed

+45
-28
lines changed

1 file changed

+45
-28
lines changed

ignite/utils.py

Lines changed: 45 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import functools
33
import logging
44
import random
5-
import sys
65
import warnings
76
from typing import Any, Callable, Dict, Optional, TextIO, Tuple, Type, TypeVar, Union, cast
87

@@ -12,53 +11,71 @@
1211

1312

1413
def convert_tensor(
15-
input_: Union[torch.Tensor, collections.Sequence, collections.Mapping, str, bytes],
14+
x: Union[torch.Tensor, collections.Sequence, collections.Mapping, str, bytes],
1615
device: Optional[Union[str, torch.device]] = None,
1716
non_blocking: bool = False,
1817
) -> Union[torch.Tensor, collections.Sequence, collections.Mapping, str, bytes]:
19-
"""Move tensors to relevant device."""
18+
"""Move tensors to relevant device.
19+
20+
Args:
21+
x: input tensor or mapping, or sequence of tensors.
22+
device: device type to move ``x``.
23+
non_blocking: convert a CPU Tensor with pinned memory to a CUDA Tensor
24+
asynchronously with respect to the host if possible
25+
"""
2026

2127
def _func(tensor: torch.Tensor) -> torch.Tensor:
2228
return tensor.to(device=device, non_blocking=non_blocking) if device is not None else tensor
2329

24-
return apply_to_tensor(input_, _func)
30+
return apply_to_tensor(x, _func)
2531

2632

2733
def apply_to_tensor(
28-
input_: Union[torch.Tensor, collections.Sequence, collections.Mapping, str, bytes], func: Callable
34+
x: Union[torch.Tensor, collections.Sequence, collections.Mapping, str, bytes], func: Callable
2935
) -> Union[torch.Tensor, collections.Sequence, collections.Mapping, str, bytes]:
3036
"""Apply a function on a tensor or mapping, or sequence of tensors.
37+
38+
Args:
39+
x: input tensor or mapping, or sequence of tensors.
40+
func: the function to apply on ``x``.
3141
"""
32-
return apply_to_type(input_, torch.Tensor, func)
42+
return apply_to_type(x, torch.Tensor, func)
3343

3444

3545
def apply_to_type(
36-
input_: Union[Any, collections.Sequence, collections.Mapping, str, bytes],
46+
x: Union[Any, collections.Sequence, collections.Mapping, str, bytes],
3747
input_type: Union[Type, Tuple[Type[Any], Any]],
3848
func: Callable,
3949
) -> Union[Any, collections.Sequence, collections.Mapping, str, bytes]:
40-
"""Apply a function on a object of `input_type` or mapping, or sequence of objects of `input_type`.
50+
"""Apply a function on an object of `input_type` or mapping, or sequence of objects of `input_type`.
51+
52+
Args:
53+
x: object or mapping or sequence.
54+
input_type: data type of ``x``.
55+
func: the function to apply on ``x``.
4156
"""
42-
if isinstance(input_, input_type):
43-
return func(input_)
44-
if isinstance(input_, (str, bytes)):
45-
return input_
46-
if isinstance(input_, collections.Mapping):
47-
return cast(Callable, type(input_))(
48-
{k: apply_to_type(sample, input_type, func) for k, sample in input_.items()}
49-
)
50-
if isinstance(input_, tuple) and hasattr(input_, "_fields"): # namedtuple
51-
return cast(Callable, type(input_))(*(apply_to_type(sample, input_type, func) for sample in input_))
52-
if isinstance(input_, collections.Sequence):
53-
return cast(Callable, type(input_))([apply_to_type(sample, input_type, func) for sample in input_])
54-
raise TypeError((f"input must contain {input_type}, dicts or lists; found {type(input_)}"))
57+
if isinstance(x, input_type):
58+
return func(x)
59+
if isinstance(x, (str, bytes)):
60+
return x
61+
if isinstance(x, collections.Mapping):
62+
return cast(Callable, type(x))({k: apply_to_type(sample, input_type, func) for k, sample in x.items()})
63+
if isinstance(x, tuple) and hasattr(x, "_fields"): # namedtuple
64+
return cast(Callable, type(x))(*(apply_to_type(sample, input_type, func) for sample in x))
65+
if isinstance(x, collections.Sequence):
66+
return cast(Callable, type(x))([apply_to_type(sample, input_type, func) for sample in x])
67+
raise TypeError((f"x must contain {input_type}, dicts or lists; found {type(x)}"))
5568

5669

5770
def to_onehot(indices: torch.Tensor, num_classes: int) -> torch.Tensor:
5871
"""Convert a tensor of indices of any shape `(N, ...)` to a
5972
tensor of one-hot indicators of shape `(N, num_classes, ...) and of type uint8. Output's device is equal to the
6073
input's device`.
6174
75+
Args:
76+
indices: input tensor to convert.
77+
num_classes: number of classes for one-hot tensor.
78+
6279
.. versionchanged:: 0.4.3
6380
This functions is now torchscriptable.
6481
"""
@@ -78,12 +95,12 @@ def setup_logger(
7895
"""Setups logger: name, level, format etc.
7996
8097
Args:
81-
name (str, optional): new name for the logger. If None, the standard logger is used.
82-
level (int): logging level, e.g. CRITICAL, ERROR, WARNING, INFO, DEBUG.
83-
stream (TextIO, optional): logging stream. If None, the standard stream is used (sys.stderr).
84-
format (str): logging format. By default, `%(asctime)s %(name)s %(levelname)s: %(message)s`.
85-
filepath (str, optional): Optional logging file path. If not None, logs are written to the file.
86-
distributed_rank (int, optional): Optional, rank in distributed configuration to avoid logger setup for workers.
98+
name: new name for the logger. If None, the standard logger is used.
99+
level: logging level, e.g. CRITICAL, ERROR, WARNING, INFO, DEBUG.
100+
stream: logging stream. If None, the standard stream is used (sys.stderr).
101+
format: logging format. By default, `%(asctime)s %(name)s %(levelname)s: %(message)s`.
102+
filepath: Optional logging file path. If not None, logs are written to the file.
103+
distributed_rank: Optional, rank in distributed configuration to avoid logger setup for workers.
87104
If None, distributed_rank is initialized to the rank of process.
88105
89106
Returns:
@@ -156,7 +173,7 @@ def manual_seed(seed: int) -> None:
156173
"""Setup random state from a seed for `torch`, `random` and optionally `numpy` (if can be imported).
157174
158175
Args:
159-
seed (int): Random state seed
176+
seed: Random state seed
160177
161178
.. versionchanged:: 0.4.3
162179
Added ``torch.cuda.manual_seed_all(seed)``.

0 commit comments

Comments
 (0)