Commit c6777e7

Jeff Yang and vfdev-5 committed
docs: rm type hints in ignite.contrib.handlers (2) (#1676)
* docs: only show type hints in docstring
* Apply suggestions from code review (Co-authored-by: vfdev <vfdev.5@gmail.com>)
* fix(docs): correctly link to missing links
* docs: rm type hints in docstring of ignite.contrib.handlers
* docs: rm type hints in docstring of ignite.contrib.handlers
* Apply suggestions from code review (Co-authored-by: vfdev <vfdev.5@gmail.com>)
* review: apply suggestions
* fix: no return in __init__
* fix: return None in no argument __init__
* remove kwargs in ConcatScheduler.simulate_values
* fix: remove unused mypy ignore

Co-authored-by: vfdev <vfdev.5@gmail.com>
1 parent e751042 commit c6777e7
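
The pattern applied throughout the diff below: once parameters are annotated in the signature, the type repetition inside the Google-style docstring is dropped, presumably so the Sphinx tooling can render types from the annotations instead. Schematically (a hedged sketch, not code from this commit):

    # before: type repeated in the docstring
    def attach(self, engine, log_handler, event_name):
        """
        Args:
            engine (Engine): engine object.
        """

    # after: type lives only in the signature
    def attach(self, engine: Engine, log_handler: Callable, event_name: Any):
        """
        Args:
            engine: engine object.
        """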

File tree: 13 files changed, +290 −277 lines changed


ignite/contrib/handlers/base_logger.py

Lines changed: 13 additions & 13 deletions
@@ -107,7 +107,7 @@ class BaseWeightsScalarHandler(BaseHandler):
     Helper handler to log model's weights as scalars.
     """
 
-    def __init__(self, model: nn.Module, reduction: Callable = torch.norm, tag: Optional[str] = None) -> None:
+    def __init__(self, model: nn.Module, reduction: Callable = torch.norm, tag: Optional[str] = None):
         if not isinstance(model, torch.nn.Module):
             raise TypeError(f"Argument model should be of type torch.nn.Module, but given {type(model)}")
 
@@ -152,14 +152,14 @@ def attach(
         """Attach the logger to the engine and execute `log_handler` function at `event_name` events.
 
         Args:
-            engine (Engine): engine object.
-            log_handler (callable): a logging handler to execute
+            engine: engine object.
+            log_handler: a logging handler to execute
             event_name: event to attach the logging handler to. Valid events are from
-                :class:`~ignite.engine.events.Events` or class:`~ignite.engine.events.EventsList` or any `event_name`
+                :class:`~ignite.engine.events.Events` or :class:`~ignite.engine.events.EventsList` or any `event_name`
                 added by :meth:`~ignite.engine.engine.Engine.register_events`.
 
         Returns:
-            :class:`~ignite.engine.RemovableEventHandle`, which can be used to remove the handler.
+            :class:`~ignite.engine.events.RemovableEventHandle`, which can be used to remove the handler.
         """
         if isinstance(event_name, EventsList):
             for name in event_name:
@@ -180,15 +180,15 @@ def attach_output_handler(self, engine: Engine, event_name: Any, *args: Any, **k
         """Shortcut method to attach `OutputHandler` to the logger.
 
         Args:
-            engine (Engine): engine object.
+            engine: engine object.
             event_name: event to attach the logging handler to. Valid events are from
                 :class:`~ignite.engine.events.Events` or any `event_name` added by
                 :meth:`~ignite.engine.engine.Engine.register_events`.
-            *args: args to initialize `OutputHandler`
-            **kwargs: kwargs to initialize `OutputHandler`
+            args: args to initialize `OutputHandler`
+            kwargs: kwargs to initialize `OutputHandler`
 
         Returns:
-            :class:`~ignite.engine.RemovableEventHandle`, which can be used to remove the handler.
+            :class:`~ignite.engine.events.RemovableEventHandle`, which can be used to remove the handler.
         """
         return self.attach(engine, self._create_output_handler(*args, **kwargs), event_name=event_name)
 
@@ -198,15 +198,15 @@ def attach_opt_params_handler(
         """Shortcut method to attach `OptimizerParamsHandler` to the logger.
 
         Args:
-            engine (Engine): engine object.
+            engine: engine object.
             event_name: event to attach the logging handler to. Valid events are from
                 :class:`~ignite.engine.events.Events` or any `event_name` added by
                 :meth:`~ignite.engine.engine.Engine.register_events`.
-            *args: args to initialize `OptimizerParamsHandler`
-            **kwargs: kwargs to initialize `OptimizerParamsHandler`
+            args: args to initialize `OptimizerParamsHandler`
+            kwargs: kwargs to initialize `OptimizerParamsHandler`
 
         Returns:
-            :class:`~ignite.engine.RemovableEventHandle`, which can be used to remove the handler.
+            :class:`~ignite.engine.events.RemovableEventHandle`, which can be used to remove the handler.
 
         .. versionchanged:: 0.4.3
             Added missing return statement.
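
For orientation, a minimal sketch of how the attach API documented above is used through a concrete BaseLogger subclass (TensorboardLogger is illustrative; `trainer` is an assumed, pre-built Engine):

    from ignite.engine import Events
    from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger

    tb_logger = TensorboardLogger(log_dir="tb-logs")

    # attach_output_handler forwards args/kwargs to OutputHandler
    handle = tb_logger.attach_output_handler(
        trainer,
        event_name=Events.ITERATION_COMPLETED,
        tag="training",
        output_transform=lambda loss: {"loss": loss},
    )

    # the documented return value is a RemovableEventHandle
    handle.remove()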

ignite/contrib/handlers/clearml_logger.py

Lines changed: 32 additions & 32 deletions
@@ -49,11 +49,11 @@ class ClearMLLogger(BaseLogger):
         clearml-init
 
     Args:
-        project_name (str): The name of the project in which the experiment will be created. If the project
+        project_name: The name of the project in which the experiment will be created. If the project
             does not exist, it is created. If ``project_name`` is ``None``, the repository name is used. (Optional)
-        task_name (str): The name of Task (experiment). If ``task_name`` is ``None``, the Python experiment
+        task_name: The name of Task (experiment). If ``task_name`` is ``None``, the Python experiment
             script's file name is used. (Optional)
-        task_type (str): Optional. The task type. Valid values are:
+        task_type: Optional. The task type. Valid values are:
             - ``TaskTypes.training`` (Default)
             - ``TaskTypes.train``
             - ``TaskTypes.testing``
@@ -119,7 +119,7 @@ class ClearMLLogger(BaseLogger):
 
     """
 
-    def __init__(self, *_: Any, **kwargs: Any) -> None:
+    def __init__(self, *_: Any, **kwargs: Any):
         try:
             from clearml import Task
             from clearml.binding.frameworks.tensorflow_bind import WeightsGradientHistHelper
@@ -270,14 +270,14 @@ def global_step_transform(*args, **kwargs):
         )
 
     Args:
-        tag (str): common title for all produced plots. For example, "training"
-        metric_names (list of str, optional): list of metric names to plot or a string "all" to plot all available
+        tag: common title for all produced plots. For example, "training"
+        metric_names: list of metric names to plot or a string "all" to plot all available
             metrics.
-        output_transform (callable, optional): output transform function to prepare `engine.state.output` as a number.
+        output_transform: output transform function to prepare `engine.state.output` as a number.
             For example, `output_transform = lambda output: output`
             This function can also return a dictionary, e.g `{"loss": loss1, "another_loss": loss2}` to label the plot
             with corresponding keys.
-        global_step_transform (callable, optional): global step transform function to output a desired global step.
+        global_step_transform: global step transform function to output a desired global step.
             Input of the function is `(engine, event_name)`. Output of function should be an integer.
             Default is None, global_step based on attached engine. If provided,
             uses function output as global_step. To setup global step from another engine, please use
@@ -299,7 +299,7 @@ def __init__(
         metric_names: Optional[List[str]] = None,
         output_transform: Optional[Callable] = None,
         global_step_transform: Optional[Callable] = None,
-    ) -> None:
+    ):
         super(OutputHandler, self).__init__(tag, metric_names, output_transform, global_step_transform)
 
     def __call__(self, engine: Engine, logger: ClearMLLogger, event_name: Union[str, Events]) -> None:
@@ -359,13 +359,13 @@ class OptimizerParamsHandler(BaseOptimizerParamsHandler):
         )
 
     Args:
-        optimizer (torch.optim.Optimizer or object): torch optimizer or any object with attribute ``param_groups``
+        optimizer: torch optimizer or any object with attribute ``param_groups``
             as a sequence.
-        param_name (str): parameter name
-        tag (str, optional): common title for all produced plots. For example, "generator"
+        param_name: parameter name
+        tag: common title for all produced plots. For example, "generator"
     """
 
-    def __init__(self, optimizer: Optimizer, param_name: str = "lr", tag: Optional[str] = None) -> None:
+    def __init__(self, optimizer: Optimizer, param_name: str = "lr", tag: Optional[str] = None):
         super(OptimizerParamsHandler, self).__init__(optimizer, param_name, tag)
 
     def __call__(self, engine: Engine, logger: ClearMLLogger, event_name: Union[str, Events]) -> None:
@@ -410,13 +410,13 @@ class WeightsScalarHandler(BaseWeightsScalarHandler):
         )
 
     Args:
-        model (torch.nn.Module): model to log weights
-        reduction (callable): function to reduce parameters into scalar
-        tag (str, optional): common title for all produced plots. For example, "generator"
+        model: model to log weights
+        reduction: function to reduce parameters into scalar
+        tag: common title for all produced plots. For example, "generator"
 
     """
 
-    def __init__(self, model: Module, reduction: Callable = torch.norm, tag: Optional[str] = None) -> None:
+    def __init__(self, model: Module, reduction: Callable = torch.norm, tag: Optional[str] = None):
         super(WeightsScalarHandler, self).__init__(model, reduction, tag=tag)
 
     def __call__(self, engine: Engine, logger: ClearMLLogger, event_name: Union[str, Events]) -> None:
@@ -463,12 +463,12 @@ class WeightsHistHandler(BaseWeightsHistHandler):
         )
 
     Args:
-        model (torch.nn.Module): model to log weights
-        tag (str, optional): common title for all produced plots. For example, 'generator'
+        model: model to log weights
+        tag: common title for all produced plots. For example, 'generator'
 
     """
 
-    def __init__(self, model: Module, tag: Optional[str] = None) -> None:
+    def __init__(self, model: Module, tag: Optional[str] = None):
         super(WeightsHistHandler, self).__init__(model, tag=tag)
 
     def __call__(self, engine: Engine, logger: ClearMLLogger, event_name: Union[str, Events]) -> None:
@@ -517,13 +517,13 @@ class GradsScalarHandler(BaseWeightsScalarHandler):
         )
 
     Args:
-        model (torch.nn.Module): model to log weights
-        reduction (callable): function to reduce parameters into scalar
-        tag (str, optional): common title for all produced plots. For example, "generator"
+        model: model to log weights
+        reduction: function to reduce parameters into scalar
+        tag: common title for all produced plots. For example, "generator"
 
     """
 
-    def __init__(self, model: Module, reduction: Callable = torch.norm, tag: Optional[str] = None) -> None:
+    def __init__(self, model: Module, reduction: Callable = torch.norm, tag: Optional[str] = None):
         super(GradsScalarHandler, self).__init__(model, reduction, tag=tag)
 
     def __call__(self, engine: Engine, logger: ClearMLLogger, event_name: Union[str, Events]) -> None:
@@ -569,12 +569,12 @@ class GradsHistHandler(BaseWeightsHistHandler):
         )
 
     Args:
-        model (torch.nn.Module): model to log weights
-        tag (str, optional): common title for all produced plots. For example, 'generator'
+        model: model to log weights
+        tag: common title for all produced plots. For example, 'generator'
 
     """
 
-    def __init__(self, model: Module, tag: Optional[str] = None) -> None:
+    def __init__(self, model: Module, tag: Optional[str] = None):
         super(GradsHistHandler, self).__init__(model, tag=tag)
 
     def __call__(self, engine: Engine, logger: ClearMLLogger, event_name: Union[str, Events]) -> None:
@@ -602,12 +602,12 @@ class ClearMLSaver(DiskSaver):
     Handler that saves input checkpoint as ClearML artifacts
 
    Args:
-        logger (ClearMLLogger, optional): An instance of :class:`~ignite.contrib.handlers.clearml_logger.ClearMLLogger`,
+        logger: An instance of :class:`~ignite.contrib.handlers.clearml_logger.ClearMLLogger`,
             ensuring a valid ClearML ``Task`` has been initialized. If not provided, and a ClearML Task
             has not been manually initialized, a runtime error will be raised.
-        output_uri (str, optional): The default location for output models and other artifacts uploaded by ClearML. For
+        output_uri: The default location for output models and other artifacts uploaded by ClearML. For
             more information, see ``clearml.Task.init``.
-        dirname (str, optional): Directory path where the checkpoint will be saved. If not provided, a temporary
+        dirname: Directory path where the checkpoint will be saved. If not provided, a temporary
             directory will be created.
 
     Examples:
@@ -645,7 +645,7 @@ def __init__(
         dirname: Optional[str] = None,
         *args: Any,
         **kwargs: Any,
-    ) -> None:
+    ):
 
         self._setup_check_clearml(logger, output_uri)
 
@@ -793,7 +793,7 @@ def get_local_copy(self, filename: str) -> Optional[str]:
         In distributed configuration this method should be called on rank 0 process.
 
         Args:
-            filename (str): artifact name.
+            filename: artifact name.
 
         Returns:
             a local path to a downloaded copy of the artifact
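
A hedged usage sketch tying the handlers above together (trainer, model and optimizer are assumed to exist; project and task names are placeholders):

    from ignite.engine import Events
    from ignite.contrib.handlers.clearml_logger import (
        ClearMLLogger,
        GradsScalarHandler,
        WeightsScalarHandler,
    )

    clearml_logger = ClearMLLogger(project_name="examples", task_name="ignite")

    # log the optimizer's learning rate at every iteration
    clearml_logger.attach_opt_params_handler(
        trainer, event_name=Events.ITERATION_STARTED, optimizer=optimizer, param_name="lr"
    )

    # log weight and gradient norms once per epoch
    clearml_logger.attach(
        trainer, log_handler=WeightsScalarHandler(model), event_name=Events.EPOCH_COMPLETED
    )
    clearml_logger.attach(
        trainer, log_handler=GradsScalarHandler(model), event_name=Events.EPOCH_COMPLETED
    )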

ignite/contrib/handlers/lr_finder.py

Lines changed: 19 additions & 19 deletions
@@ -196,11 +196,11 @@ def plot(self, skip_start: int = 10, skip_end: int = 5, log_lr: bool = True) ->
             pip install matplotlib
 
         Args:
-            skip_start (int, optional): number of batches to trim from the start.
+            skip_start: number of batches to trim from the start.
                 Default: 10.
-            skip_end (int, optional): number of batches to trim from the start.
+            skip_end: number of batches to trim from the start.
                 Default: 5.
-            log_lr (bool, optional): True to plot the learning rate in a logarithmic
+            log_lr: True to plot the learning rate in a logarithmic
                 scale; otherwise, plotted in a linear scale. Default: True.
         """
         try:
@@ -273,27 +273,27 @@ def attach(
             trainer_with_lr_finder.run(dataloader)`
 
         Args:
-            trainer (Engine): lr_finder is attached to this trainer. Please, keep in mind that all attached handlers
+            trainer: lr_finder is attached to this trainer. Please, keep in mind that all attached handlers
                 will be executed.
-            to_save (Mapping): dictionary with optimizer and other objects that needs to be restored after running
+            to_save: dictionary with optimizer and other objects that needs to be restored after running
                 the LR finder. For example, `to_save={'optimizer': optimizer, 'model': model}`. All objects should
                 implement `state_dict` and `load_state_dict` methods.
-            output_transform (callable, optional): function that transforms the trainer's `state.output` after each
+            output_transform: function that transforms the trainer's `state.output` after each
                 iteration. It must return the loss of that iteration.
-            num_iter (int, optional): number of iterations for lr schedule between base lr and end_lr. Default, it will
+            num_iter: number of iterations for lr schedule between base lr and end_lr. Default, it will
                 run for `trainer.state.epoch_length * trainer.state.max_epochs`.
-            end_lr (float, optional): upper bound for lr search. Default, 10.0.
-            step_mode (str, optional): "exp" or "linear", which way should the lr be increased from optimizer's initial
+            end_lr: upper bound for lr search. Default, 10.0.
+            step_mode: "exp" or "linear", which way should the lr be increased from optimizer's initial
                 lr to `end_lr`. Default, "exp".
-            smooth_f (float, optional): loss smoothing factor in range `[0, 1)`. Default, 0.05
-            diverge_th (float, optional): Used for stopping the search when `current loss > diverge_th * best_loss`.
+            smooth_f: loss smoothing factor in range `[0, 1)`. Default, 0.05
+            diverge_th: Used for stopping the search when `current loss > diverge_th * best_loss`.
                 Default, 5.0.
 
+        Returns:
+            trainer_with_lr_finder (trainer used for finding the lr)
+
         Note:
             lr_finder cannot be attached to more than one trainer at a time.
-
-        Returns:
-            trainer_with_lr_finder: trainer used for finding the lr
         """
         if not isinstance(to_save, Mapping):
             raise TypeError(f"Argument to_save should be a mapping, but given {type(to_save)}")
@@ -363,16 +363,16 @@ class _ExponentialLR(_LRScheduler):
     iterations.
 
     Args:
-        optimizer (torch.optim.Optimizer): wrapped optimizer.
-        end_lr (float, optional): the initial learning rate which is the lower
+        optimizer: wrapped optimizer.
+        end_lr: the initial learning rate which is the lower
            boundary of the test. Default: 10.
-        num_iter (int, optional): the number of iterations over which the test
+        num_iter: the number of iterations over which the test
            occurs. Default: 100.
-        last_epoch (int): the index of last epoch. Default: -1.
+        last_epoch: the index of last epoch. Default: -1.
 
    """
 
-    def __init__(self, optimizer: Optimizer, end_lr: float, num_iter: int, last_epoch: int = -1) -> None:
+    def __init__(self, optimizer: Optimizer, end_lr: float, num_iter: int, last_epoch: int = -1):
         self.end_lr = end_lr
         self.num_iter = num_iter
         super(_ExponentialLR, self).__init__(optimizer, last_epoch)
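
The attach docstring above is easier to read next to its intended usage; a sketch based on that docstring (trainer, model, optimizer and dataloader are assumed to exist):

    from ignite.contrib.handlers import FastaiLRFinder

    lr_finder = FastaiLRFinder()
    to_save = {"model": model, "optimizer": optimizer}

    # attach() yields a context manager; objects in to_save are
    # restored to their original state when the block exits
    with lr_finder.attach(trainer, to_save=to_save) as trainer_with_lr_finder:
        trainer_with_lr_finder.run(dataloader)

    print(lr_finder.lr_suggestion())
    lr_finder.plot()  # needs matplotlib, as noted in plot()'s docstring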

ignite/contrib/handlers/mlflow_logger.py

Lines changed: 10 additions & 10 deletions
@@ -24,7 +24,7 @@ class MLflowLogger(BaseLogger):
         pip install mlflow
 
     Args:
-        tracking_uri (str): MLflow tracking uri. See MLflow docs for more details
+        tracking_uri: MLflow tracking uri. See MLflow docs for more details
 
     Examples:
 
@@ -86,7 +86,7 @@ class MLflowLogger(BaseLogger):
         )
     """
 
-    def __init__(self, tracking_uri: Optional[str] = None) -> None:
+    def __init__(self, tracking_uri: Optional[str] = None):
         try:
             import mlflow
         except ImportError:
@@ -182,14 +182,14 @@ def global_step_transform(*args, **kwargs):
         )
 
     Args:
-        tag (str): common title for all produced plots. For example, 'training'
-        metric_names (list of str, optional): list of metric names to plot or a string "all" to plot all available
+        tag: common title for all produced plots. For example, 'training'
+        metric_names: list of metric names to plot or a string "all" to plot all available
             metrics.
-        output_transform (callable, optional): output transform function to prepare `engine.state.output` as a number.
+        output_transform: output transform function to prepare `engine.state.output` as a number.
             For example, `output_transform = lambda output: output`
             This function can also return a dictionary, e.g `{'loss': loss1, 'another_loss': loss2}` to label the plot
             with corresponding keys.
-        global_step_transform (callable, optional): global step transform function to output a desired global step.
+        global_step_transform: global step transform function to output a desired global step.
             Input of the function is `(engine, event_name)`. Output of function should be an integer.
             Default is None, global_step based on attached engine. If provided,
             uses function output as global_step. To setup global step from another engine, please use
@@ -284,13 +284,13 @@ class OptimizerParamsHandler(BaseOptimizerParamsHandler):
         )
 
     Args:
-        optimizer (torch.optim.Optimizer or object): torch optimizer or any object with attribute ``param_groups``
+        optimizer: torch optimizer or any object with attribute ``param_groups``
             as a sequence.
-        param_name (str): parameter name
-        tag (str, optional): common title for all produced plots. For example, 'generator'
+        param_name: parameter name
+        tag: common title for all produced plots. For example, 'generator'
     """
 
-    def __init__(self, optimizer: Optimizer, param_name: str = "lr", tag: Optional[str] = None) -> None:
+    def __init__(self, optimizer: Optimizer, param_name: str = "lr", tag: Optional[str] = None):
         super(OptimizerParamsHandler, self).__init__(optimizer, param_name, tag)
 
     def __call__(self, engine: Engine, logger: MLflowLogger, event_name: Union[str, Events]) -> None:
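
A corresponding sketch for the MLflow logger (trainer and optimizer are assumed; the tracking URI is optional, as the docstring above notes):

    from ignite.engine import Events
    from ignite.contrib.handlers.mlflow_logger import MLflowLogger

    mlflow_logger = MLflowLogger()

    mlflow_logger.attach_output_handler(
        trainer,
        event_name=Events.ITERATION_COMPLETED,
        tag="training",
        output_transform=lambda loss: {"loss": loss},
    )
    mlflow_logger.attach_opt_params_handler(
        trainer, event_name=Events.ITERATION_STARTED, optimizer=optimizer, param_name="lr"
    )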
