@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Callable, Iterable, Optional, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, TYPE_CHECKING, Union

import torch
from torch.optim import Optimizer
@@ -24,6 +24,14 @@
from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available
from pytorch_lightning.utilities.enums import AMPType, LightningEnum

+if TYPE_CHECKING:
+    from torch.cuda.amp import GradScaler
+
+    from pytorch_lightning.trainer.trainer import Trainer
+
+
+_STEP_OUTPUT_TYPE = Union[torch.Tensor, Dict[str, torch.Tensor], None]
+

class Accelerator(object):
    """
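
The `if TYPE_CHECKING:` block above imports `GradScaler` and `Trainer` for static type checkers only; at runtime it never executes, which sidesteps the circular import between accelerator and trainer, and is why the annotations below quote `'Trainer'` as a string forward reference. A minimal, self-contained sketch of the idiom (the `runner` module and `Engine` class are hypothetical stand-ins, not part of this diff):

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Seen only by mypy/pyright, never executed at runtime,
        # so importing a module that imports us back is safe here.
        from runner import Runner  # hypothetical module

    class Engine:
        # 'Runner' is quoted: the checker resolves the forward
        # reference, while the interpreter never evaluates it.
        def setup(self, runner: 'Runner') -> None:
            self.runner = runner
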
@@ -54,11 +62,11 @@ def __init__(
        self.precision_plugin = precision_plugin
        self.training_type_plugin = training_type_plugin

-        self.optimizers = None
-        self.lr_schedulers = None
-        self.optimizer_frequencies = None
+        self.optimizers: Sequence = []
+        self.lr_schedulers: Sequence = []
+        self.optimizer_frequencies: Sequence = []

-    def setup(self, trainer, model: LightningModule) -> None:
+    def setup(self, trainer: 'Trainer', model: LightningModule) -> None:
        """
        Connects the plugins to the training process, creates optimizers

@@ -70,13 +78,13 @@ def setup(self, trainer, model: LightningModule) -> None:
        self.setup_optimizers(trainer)
        self.connect_precision_plugin(self.precision_plugin)

-    def start_training(self, trainer):
+    def start_training(self, trainer: 'Trainer') -> None:
        self.training_type_plugin.start_training(trainer)

-    def start_testing(self, trainer):
+    def start_testing(self, trainer: 'Trainer') -> None:
        self.training_type_plugin.start_testing(trainer)

-    def start_predicting(self, trainer):
+    def start_predicting(self, trainer: 'Trainer') -> None:
        self.training_type_plugin.start_predicting(trainer)

    def pre_dispatch(self) -> None:
@@ -113,7 +121,7 @@ def lightning_module(self) -> LightningModule:
    def root_device(self) -> torch.device:
        return self.training_type_plugin.root_device

-    def teardown(self):
+    def teardown(self) -> None:
        """This method is called to teardown the training process.
        It is the right place to release memory and free other resources.
        """
@@ -134,11 +142,14 @@ def batch_to_device(self, batch: Any, device: Optional[torch.device] = None) ->

        return move_data_to_device(batch, device)

-    def on_train_start(self):
+    def on_train_start(self) -> None:
        """Hook to do something upon the training start"""
        pass

-    def training_step(self, args):
+    def training_step(
+        self,
+        args: List[Union[Any, int]],
+    ) -> _STEP_OUTPUT_TYPE:
        """The actual training step.

        Args:
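
Worth pausing on the new return annotation: `_STEP_OUTPUT_TYPE` says each `*_step` wrapper may hand back a bare loss tensor, a dict of named tensors, or `None`. A hedged, standalone sketch of the three admissible shapes (made-up values, not the Lightning API itself):

    from typing import Dict, Union
    import torch

    _STEP_OUTPUT_TYPE = Union[torch.Tensor, Dict[str, torch.Tensor], None]

    def step_output(kind: str) -> _STEP_OUTPUT_TYPE:
        loss = torch.tensor(0.5, requires_grad=True)
        if kind == "tensor":
            return loss            # bare loss tensor
        if kind == "dict":
            return {"loss": loss}  # dict of named tensors
        return None                # nothing to report this step
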
@@ -156,10 +167,10 @@ def training_step(self, args):
        with self.precision_plugin.train_step_context(), self.training_type_plugin.train_step_context():
            return self.training_type_plugin.training_step(*args)

-    def post_training_step(self):
+    def post_training_step(self) -> None:
        self.training_type_plugin.post_training_step()

-    def validation_step(self, args):
+    def validation_step(self, args: List[Union[Any, int]]) -> _STEP_OUTPUT_TYPE:
        """The actual validation step.

        Args:
@@ -177,7 +188,7 @@ def validation_step(self, args):
        with self.precision_plugin.val_step_context(), self.training_type_plugin.val_step_context():
            return self.training_type_plugin.validation_step(*args)

-    def test_step(self, args):
+    def test_step(self, args: List[Union[Any, int]]) -> _STEP_OUTPUT_TYPE:
        """The actual test step.

        Args:
@@ -195,7 +206,7 @@ def test_step(self, args):
        with self.precision_plugin.test_step_context(), self.training_type_plugin.test_step_context():
            return self.training_type_plugin.test_step(*args)

-    def predict(self, args):
+    def predict(self, args: List[Union[Any, int]]) -> _STEP_OUTPUT_TYPE:
        """The actual predict step.

        Args:
@@ -213,23 +224,29 @@ def predict(self, args):
        with self.precision_plugin.predict_context(), self.training_type_plugin.predict_context():
            return self.training_type_plugin.predict(*args)

-    def training_step_end(self, output):
+    def training_step_end(
+        self, output: _STEP_OUTPUT_TYPE
+    ) -> _STEP_OUTPUT_TYPE:
        """A hook to do something at the end of the training step

        Args:
            output: the output of the training step
        """
        return self.training_type_plugin.training_step_end(output)

-    def test_step_end(self, output):
+    def test_step_end(
+        self, output: _STEP_OUTPUT_TYPE
+    ) -> _STEP_OUTPUT_TYPE:
        """A hook to do something at the end of the test step

        Args:
            output: the output of the test step
        """
        return self.training_type_plugin.test_step_end(output)

-    def validation_step_end(self, output):
+    def validation_step_end(
+        self, output: _STEP_OUTPUT_TYPE
+    ) -> _STEP_OUTPUT_TYPE:
        """A hook to do something at the end of the validation step

        Args:
@@ -243,8 +260,8 @@ def backward(
        optimizer: Optimizer,
        optimizer_idx: int,
        should_accumulate: bool,
-        *args,
-        **kwargs,
+        *args: Any,
+        **kwargs: Any,
    ) -> torch.Tensor:
        """Forwards backward-calls to the precision plugin.

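The `*args: Any, **kwargs: Any` annotations keep `backward` a thin pass-through: the real work happens in the precision plugin, which is what lets an AMP plugin scale the loss before autograd runs. A rough sketch of that kind of hook, assuming native `torch.cuda.amp`; `ToyAMPBackward` is an illustrative stand-in, not the actual `PrecisionPlugin` interface:

    import torch
    from torch.cuda.amp import GradScaler

    class ToyAMPBackward:
        def __init__(self) -> None:
            # GradScaler disables itself (with a warning) when CUDA
            # is unavailable, so this sketch still runs on CPU.
            self.scaler = GradScaler()

        def backward(self, loss: torch.Tensor) -> None:
            # Scale the loss so fp16 gradients do not underflow,
            # then run the usual autograd backward pass.
            self.scaler.scale(loss).backward()
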
@@ -262,7 +279,7 @@ def backward(

        return output

-    def optimizer_step(self, optimizer: Optimizer, opt_idx: int, lambda_closure: Callable, **kwargs):
+    def optimizer_step(self, optimizer: Optimizer, opt_idx: int, lambda_closure: Callable, **kwargs: Any) -> None:
        """performs the actual optimizer step.

        Args:
@@ -279,7 +296,9 @@ def optimizer_step(self, optimizer: Optimizer, opt_idx: int, lambda_closure: Cal
        self.precision_plugin.post_optimizer_step(optimizer, opt_idx)
        self.training_type_plugin.post_optimizer_step(optimizer, opt_idx, **kwargs)

-    def run_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs):
+    def run_optimizer_step(
+        self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs: Any
+    ) -> None:
        self.training_type_plugin.optimizer_step(optimizer, lambda_closure=lambda_closure, **kwargs)

    def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Optimizer, opt_idx: int) -> None:
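
`lambda_closure` is threaded through both of these methods because some optimizers re-evaluate the model inside `step`. A runnable sketch with LBFGS, the classic closure-requiring optimizer (toy model and data, not Lightning code):

    import torch

    model = torch.nn.Linear(2, 1)
    optimizer = torch.optim.LBFGS(model.parameters())
    inputs, target = torch.randn(4, 2), torch.randn(4, 1)

    def closure() -> torch.Tensor:
        # LBFGS may call this several times per step to
        # re-evaluate the loss and its gradients.
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(inputs), target)
        loss.backward()
        return loss

    optimizer.step(closure)
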
@@ -292,7 +311,7 @@ def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float]) -> N

        self.precision_plugin.clip_gradients(optimizer, clip_val)

-    def on_train_epoch_end(self, outputs) -> None:
+    def on_train_epoch_end(self, outputs: Sequence[_STEP_OUTPUT_TYPE]) -> None:
        """Hook to do something at the end of a training epoch

        Args:
@@ -304,7 +323,7 @@ def on_train_end(self) -> None:
304323 """Hook to do something at the end of the training"""
305324 pass
306325
307- def setup_optimizers (self , trainer ) :
326+ def setup_optimizers (self , trainer : 'Trainer' ) -> None :
308327 """creates optimizers and schedulers
309328
310329 Args:
@@ -327,7 +346,7 @@ def connect_training_type_plugin(self, plugin: TrainingTypePlugin, model: Lightn
        """
        plugin.connect(model)

-    def connect_precision_plugin(self, plugin: PrecisionPlugin):
+    def connect_precision_plugin(self, plugin: PrecisionPlugin) -> None:
        """Attaches the precision plugin to the accelerator"""
        model, optimizers, schedulers = plugin.connect(self.model, self.optimizers, self.lr_schedulers)
        self.model = model
@@ -351,26 +370,22 @@ def precision(self) -> int:
        return self.precision_plugin.precision

    @property
-    def scaler(self):
-        if hasattr(self.precision_plugin, "scaler"):
-            return self.precision_plugin.scaler
+    def scaler(self) -> Optional['GradScaler']:

-        return None
+        return getattr(self.precision_plugin, 'scaler', None)

    @property
    def rpc_enabled(self) -> bool:
        return self.training_type_plugin.rpc_enabled

-    def optimizer_state(self, optimizer: Optimizer) -> dict:
+    def optimizer_state(self, optimizer: Optimizer) -> Dict[str, torch.Tensor]:
        """
        Returns state of an optimizer. Allows for syncing/collating optimizer state from processes in custom
        plugins.
        """
-        if self.training_type_plugin and hasattr(self.training_type_plugin, "optimizer_state"):
-            return self.training_type_plugin.optimizer_state(optimizer)
-        return optimizer.state_dict()
+        return getattr(self.training_type_plugin, 'optimizer_state', lambda x: x.state_dict())(optimizer)

-    def on_save(self, checkpoint):
+    def on_save(self, checkpoint: Dict[str, Union[Any, torch.Tensor]]) -> Dict[str, Union[Any, torch.Tensor]]:
        return checkpoint

    def barrier(self, name: Optional[str] = None) -> None:
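
Both rewrites in this hunk collapse a `hasattr` branch into a single `getattr` with a default; for `optimizer_state` the default is itself a callable, so the call site stays uniform. A small runnable illustration of the pattern (the plugin classes are hypothetical):

    import torch

    model = torch.nn.Linear(2, 1)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)

    class BarePlugin:            # no optimizer_state override
        pass

    class CustomPlugin:          # provides an override
        def optimizer_state(self, optimizer):
            return {"wrapped": optimizer.state_dict()}

    def optimizer_state(plugin, optimizer):
        # getattr with a callable default: either the plugin's
        # override or a plain state_dict fallback, called uniformly.
        return getattr(plugin, 'optimizer_state', lambda o: o.state_dict())(optimizer)

    assert "param_groups" in optimizer_state(BarePlugin(), opt)
    assert "wrapped" in optimizer_state(CustomPlugin(), opt)
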
@@ -385,7 +400,9 @@ def broadcast(self, obj: object, src: int = 0) -> object:
        """
        return self.training_type_plugin.broadcast(obj, src)

-    def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False):
+    def all_gather(
+        self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False
+    ) -> torch.Tensor:
        """
        Function to gather a tensor from several distributed processes.

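The old annotation `Union[torch.Tensor]` was a one-element union, equivalent to plain `torch.Tensor`, so the new signature is a simplification plus an explicit return type. For intuition, a sketch of the gather contract this method delegates to, assuming `torch.distributed` with an initialised process group (not the plugin's actual implementation):

    import torch
    import torch.distributed as dist

    def all_gather(tensor: torch.Tensor) -> torch.Tensor:
        # Every rank contributes its tensor and receives the stacked
        # result of shape (world_size, *tensor.shape). Assumes
        # dist.init_process_group(...) has already been called.
        world = dist.get_world_size()
        gathered = [torch.zeros_like(tensor) for _ in range(world)]
        dist.all_gather(gathered, tensor)
        return torch.stack(gathered)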