import collections.abc as collections
import functools
import logging
import random
import warnings
from typing import Any, Callable, Dict, Optional, TextIO, Tuple, Type, TypeVar, Union, cast

import torch

def convert_tensor(
    x: Union[torch.Tensor, collections.Sequence, collections.Mapping, str, bytes],
    device: Optional[Union[str, torch.device]] = None,
    non_blocking: bool = False,
) -> Union[torch.Tensor, collections.Sequence, collections.Mapping, str, bytes]:
    """Move tensors to the relevant device.

    Args:
        x: input tensor or mapping, or sequence of tensors.
        device: device type to move ``x`` to.
        non_blocking: convert a CPU Tensor with pinned memory to a CUDA Tensor
            asynchronously with respect to the host if possible.
    """

    def _func(tensor: torch.Tensor) -> torch.Tensor:
        return tensor.to(device=device, non_blocking=non_blocking) if device is not None else tensor

    return apply_to_tensor(x, _func)
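
# A minimal usage sketch (not part of the module; falls back to CPU when CUDA
# is unavailable): ``convert_tensor`` recurses through mappings and sequences,
# so every tensor inside the batch is moved to the target device:
#
#     batch = {"image": torch.rand(2, 3), "label": torch.tensor([0, 1])}
#     device = "cuda" if torch.cuda.is_available() else "cpu"
#     batch = convert_tensor(batch, device=device, non_blocking=True)
#     assert batch["image"].device.type == device
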
def apply_to_tensor(
    x: Union[torch.Tensor, collections.Sequence, collections.Mapping, str, bytes], func: Callable
) -> Union[torch.Tensor, collections.Sequence, collections.Mapping, str, bytes]:
    """Apply a function on a tensor or mapping, or sequence of tensors.

    Args:
        x: input tensor or mapping, or sequence of tensors.
        func: the function to apply on ``x``.
    """
    return apply_to_type(x, torch.Tensor, func)
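
# A minimal sketch of ``apply_to_tensor`` (not part of the module): detach
# every tensor in a nested structure; non-tensor leaves pass through unchanged:
#
#     data = [torch.rand(2, requires_grad=True), {"y": torch.rand(2, requires_grad=True)}]
#     detached = apply_to_tensor(data, lambda t: t.detach())
#     assert not detached[0].requires_grad and not detached[1]["y"].requires_grad
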
def apply_to_type(
    x: Union[Any, collections.Sequence, collections.Mapping, str, bytes],
    input_type: Union[Type, Tuple[Type[Any], Any]],
    func: Callable,
) -> Union[Any, collections.Sequence, collections.Mapping, str, bytes]:
    """Apply a function on an object of `input_type` or mapping, or sequence of objects of `input_type`.

    Args:
        x: object or mapping or sequence.
        input_type: data type of ``x``.
        func: the function to apply on ``x``.
    """
    if isinstance(x, input_type):
        return func(x)
    if isinstance(x, (str, bytes)):
        return x
    if isinstance(x, collections.Mapping):
        return cast(Callable, type(x))({k: apply_to_type(sample, input_type, func) for k, sample in x.items()})
    if isinstance(x, tuple) and hasattr(x, "_fields"):  # namedtuple
        return cast(Callable, type(x))(*(apply_to_type(sample, input_type, func) for sample in x))
    if isinstance(x, collections.Sequence):
        return cast(Callable, type(x))([apply_to_type(sample, input_type, func) for sample in x])
    raise TypeError(f"x must contain {input_type}, dicts or lists; found {type(x)}")
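
# A minimal sketch of the dispatch order above (not part of the module):
# ``input_type`` instances are transformed, strings/bytes pass through, and
# mappings, namedtuples and other sequences are rebuilt with their original
# container type:
#
#     from collections import namedtuple
#
#     Pair = namedtuple("Pair", ["a", "b"])
#     doubled = apply_to_type(Pair(a=1.0, b=[2.0, "keep"]), float, lambda v: v * 2)
#     assert doubled == Pair(a=2.0, b=[4.0, "keep"])
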
def to_onehot(indices: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Convert a tensor of indices of any shape `(N, ...)` to a
    tensor of one-hot indicators of shape `(N, num_classes, ...)` and of type uint8.
    Output's device is equal to the input's device.

    Args:
        indices: input tensor to convert.
        num_classes: number of classes for one-hot tensor.

    .. versionchanged:: 0.4.3
        This function is now torchscriptable.
    """
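    # A minimal usage sketch of the behavior described above (not part of the
    # module; the function body is elided below):
    #
    #     y = to_onehot(torch.tensor([0, 2]), num_classes=3)
    #     # y: tensor([[1, 0, 0],
    #     #            [0, 0, 1]], dtype=torch.uint8)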
@@ -78,12 +95,12 @@ def setup_logger(
7895 """Setups logger: name, level, format etc.
7996
8097 Args:
81- name (str, optional) : new name for the logger. If None, the standard logger is used.
82- level (int) : logging level, e.g. CRITICAL, ERROR, WARNING, INFO, DEBUG.
83- stream (TextIO, optional) : logging stream. If None, the standard stream is used (sys.stderr).
84- format (str) : logging format. By default, `%(asctime)s %(name)s %(levelname)s: %(message)s`.
85- filepath (str, optional) : Optional logging file path. If not None, logs are written to the file.
86- distributed_rank (int, optional) : Optional, rank in distributed configuration to avoid logger setup for workers.
98+ name: new name for the logger. If None, the standard logger is used.
99+ level: logging level, e.g. CRITICAL, ERROR, WARNING, INFO, DEBUG.
100+ stream: logging stream. If None, the standard stream is used (sys.stderr).
101+ format: logging format. By default, `%(asctime)s %(name)s %(levelname)s: %(message)s`.
102+ filepath: Optional logging file path. If not None, logs are written to the file.
103+ distributed_rank: Optional, rank in distributed configuration to avoid logger setup for workers.
87104 If None, distributed_rank is initialized to the rank of process.
88105
89106 Returns:
@@ -156,7 +173,7 @@ def manual_seed(seed: int) -> None:
156173 """Setup random state from a seed for `torch`, `random` and optionally `numpy` (if can be imported).
157174
158175 Args:
159- seed (int) : Random state seed
176+ seed: Random state seed
160177
161178 .. versionchanged:: 0.4.3
162179 Added ``torch.cuda.manual_seed_all(seed)``.
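
# A minimal reproducibility sketch (not part of the module): re-seeding with
# the same value makes the torch RNG repeat its sequence:
#
#     manual_seed(42)
#     a = torch.rand(3)
#     manual_seed(42)
#     assert torch.equal(a, torch.rand(3))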