 import numbers
 import warnings
 from functools import partial
-from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Union, cast
+from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Tuple, Union, cast

 import torch
 import torch.nn as nn
@@ -184,9 +184,7 @@ def _setup_common_training_handlers(
         checkpoint_handler = Checkpoint(
             to_save, cast(Union[Callable, BaseSaveHandler], save_handler), filename_prefix="training", **kwargs
         )
-        trainer.add_event_handler(
-            Events.ITERATION_COMPLETED(every=save_every_iters), checkpoint_handler
-        )  # type: ignore[arg-type]
+        trainer.add_event_handler(Events.ITERATION_COMPLETED(every=save_every_iters), checkpoint_handler)

     if with_gpu_stats:
         GpuInfo().attach(
@@ -195,7 +193,7 @@ def _setup_common_training_handlers(

     if output_names is not None:

-        def output_transform(x, index, name):
+        def output_transform(x: Any, index: int, name: str) -> Any:
             if isinstance(x, Mapping):
                 return x[name]
             elif isinstance(x, Sequence):
@@ -216,9 +214,7 @@ def output_transform(x, index, name):
     if with_pbars:
         if with_pbar_on_iters:
             ProgressBar(persist=False).attach(
-                trainer,
-                metric_names="all",
-                event_name=Events.ITERATION_COMPLETED(every=log_every_iters),  # type: ignore[arg-type]
+                trainer, metric_names="all", event_name=Events.ITERATION_COMPLETED(every=log_every_iters)
             )

         ProgressBar(persist=True, bar_format="").attach(
@@ -266,18 +262,18 @@ def _setup_common_distrib_training_handlers(
             raise TypeError("Train sampler should be torch DistributedSampler and have `set_epoch` method")

         @trainer.on(Events.EPOCH_STARTED)
-        def distrib_set_epoch(engine):
-            train_sampler.set_epoch(engine.state.epoch - 1)
+        def distrib_set_epoch(engine: Engine) -> None:
+            cast(DistributedSampler, train_sampler).set_epoch(engine.state.epoch - 1)


-def empty_cuda_cache(_) -> None:
+def empty_cuda_cache(_: Engine) -> None:
     torch.cuda.empty_cache()
     import gc

     gc.collect()


-def setup_any_logging(logger, logger_module, trainer, optimizers, evaluators, log_every_iters) -> None:
+def setup_any_logging(logger, logger_module, trainer, optimizers, evaluators, log_every_iters) -> None:  # type: ignore
     raise DeprecationWarning(
         "ignite.contrib.engines.common.setup_any_logging is deprecated since 0.4.0 and will be removed in 0.6.0. "
         "Please use instead: setup_tb_logging, setup_visdom_logging or setup_mlflow_logging etc."
@@ -549,7 +545,7 @@ def setup_trains_logging(


 def get_default_score_fn(metric_name: str) -> Any:
-    def wrapper(engine: Engine):
+    def wrapper(engine: Engine) -> Any:
         score = engine.state.metrics[metric_name]
         return score

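For reference, `get_default_score_fn` builds the `score_function` typically handed to a `Checkpoint` handler, so that saved models are ranked by a validation metric. Below is a minimal usage sketch; the `model` and `evaluator` objects and the "accuracy" metric name are illustrative assumptions, not part of this commit:

from ignite.contrib.engines import common
from ignite.engine import Events
from ignite.handlers import Checkpoint, DiskSaver

# Keep only the single best checkpoint, scored by the evaluator's
# "accuracy" metric (assumed name) through get_default_score_fn.
best_model_handler = Checkpoint(
    {"model": model},  # `model` is an assumed nn.Module
    DiskSaver("/tmp/checkpoints", require_empty=False),
    filename_prefix="best",
    n_saved=1,
    score_name="accuracy",
    score_function=common.get_default_score_fn("accuracy"),
)
# Re-rank checkpoints whenever the (assumed) evaluator finishes a run.
evaluator.add_event_handler(Events.COMPLETED, best_model_handler)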