-
Notifications
You must be signed in to change notification settings - Fork 3.4k
/
properties.py
668 lines (551 loc) · 22 KB
/
properties.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import os
from abc import ABC
from argparse import ArgumentParser, Namespace
from pathlib import Path
from typing import cast, List, Optional, Type, TypeVar, Union
import torch
from torch.optim import Optimizer
import pytorch_lightning as pl
from pytorch_lightning.accelerators import Accelerator
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, ProgressBarBase
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.loggers.base import LoggerCollection
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
from pytorch_lightning.loops import PredictionLoop
from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop
from pytorch_lightning.loops.fit_loop import FitLoop
from pytorch_lightning.plugins import ParallelPlugin, PrecisionPlugin, TrainingTypePlugin
from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.states import RunningStage, TrainerFn, TrainerState, TrainerStatus
from pytorch_lightning.utilities import DeviceType, DistributedType, rank_zero_deprecation, rank_zero_warn
from pytorch_lightning.utilities.argparse import (
add_argparse_args,
from_argparse_args,
parse_argparser,
parse_env_variables,
)
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.model_helpers import is_overridden
class TrainerProperties(ABC):
    """Properties and small helpers mixed into :class:`~pytorch_lightning.trainer.trainer.Trainer`.

    Most properties are thin pass-throughs to the accelerator, the accelerator connector, the loops, the
    logger connector or ``self.state``. Attributes that are only annotated below (no value) are not assigned
    in this class; the class this is mixed into is expected to set them.
    """

    _default_root_dir: str
    _fit_loop: FitLoop
    # cache of `LightningOptimizer` wrappers; reset to ``None`` whenever `optimizers` is reassigned
    _lightning_optimizers = None
    _predict_loop: PredictionLoop
    _progress_bar_callback: ProgressBarBase
    _test_loop: EvaluationLoop
    _validate_loop: EvaluationLoop
    _weights_save_path: str
    accelerator_connector: AcceleratorConnector
    callbacks: List[Callback]
    checkpoint_connector: CheckpointConnector
    reload_dataloaders_every_n_epochs: int
    # may be a count or a fraction; only compared against 0 in this class
    limit_val_batches: Union[int, float]
    # `log_dir` handles a ``None`` logger
    logger: Optional[LightningLoggerBase]
    logger_connector: LoggerConnector
    state: TrainerState

    # .validate() and .test() set these when they load a checkpoint
    validated_ckpt_path: Optional[str] = None
    tested_ckpt_path: Optional[str] = None
    predicted_ckpt_path: Optional[str] = None

    # ------------------------------------------------------------------
    # Accelerator properties
    # ------------------------------------------------------------------

    @property
    def accelerator(self) -> Accelerator:
        return self.accelerator_connector.accelerator

    @property
    def distributed_backend(self) -> Optional[str]:
        # for backward compatibility
        return self.accelerator_connector.distributed_backend

    @property
    def training_type_plugin(self) -> TrainingTypePlugin:
        return self.accelerator.training_type_plugin

    @property
    def precision_plugin(self) -> PrecisionPlugin:
        return self.accelerator.precision_plugin

    @property
    def global_rank(self) -> int:
        return self.accelerator.training_type_plugin.global_rank

    @property
    def local_rank(self) -> int:
        # some training types define a local rank
        return getattr(self.accelerator.training_type_plugin, "local_rank", 0)

    @property
    def node_rank(self) -> int:
        # some training types define a node rank
        return getattr(self.accelerator.training_type_plugin, "node_rank", 0)

    @property
    def world_size(self) -> int:
        # some training types define a world size
        return getattr(self.accelerator.training_type_plugin, "world_size", 1)

    @property
    def should_rank_save_checkpoint(self) -> bool:
        return self.accelerator.training_type_plugin.should_rank_save_checkpoint

    @property
    def _distrib_type(self) -> DistributedType:
        return self.accelerator_connector._distrib_type

    @property
    def _device_type(self) -> DeviceType:
        return self.accelerator_connector._device_type

    @property
    def num_nodes(self) -> int:
        return self.accelerator_connector.num_nodes

    @property
    def num_processes(self) -> int:
        return self.accelerator_connector.num_processes

    @property
    def root_gpu(self) -> Optional[int]:
        return self.accelerator_connector.root_gpu

    @property
    def tpu_cores(self) -> int:
        return self.accelerator_connector.tpu_cores

    @property
    def ipus(self) -> int:
        return self.accelerator_connector.num_ipus

    @property
    def num_gpus(self) -> int:
        return self.accelerator_connector.num_gpus

    @property
    def devices(self) -> Optional[Union[List[int], str, int]]:
        return self.accelerator_connector.devices

    @property
    def data_parallel_device_ids(self) -> Optional[List[int]]:
        return self.accelerator_connector.parallel_device_ids

    @property
    def lightning_module(self) -> "pl.LightningModule":
        return self.accelerator.lightning_module

    @property
    def optimizers(self) -> Optional[List[Optimizer]]:
        return self.accelerator.optimizers

    @optimizers.setter
    def optimizers(self, new_optims: Optional[List[Optimizer]]) -> None:
        # Necessary to rewrap optimizers to lightning
        # They will be re-created when accessing
        # the `lightning_optimizers` trainer property
        self._lightning_optimizers = None
        self.accelerator.optimizers = new_optims

    @property
    def lr_schedulers(self) -> Optional[list]:
        return self.accelerator.lr_schedulers

    @lr_schedulers.setter
    def lr_schedulers(self, new_schedulers: Optional[list]) -> None:
        self.accelerator.lr_schedulers = new_schedulers

    @property
    def optimizer_frequencies(self) -> list:
        return self.accelerator.optimizer_frequencies

    @optimizer_frequencies.setter
    def optimizer_frequencies(self, new_freqs: list) -> None:
        self.accelerator.optimizer_frequencies = new_freqs

    @property
    def amp_backend(self) -> Optional[str]:
        return self.accelerator.amp_backend

    @property
    def precision(self) -> Union[str, int]:
        return self.accelerator.precision

    @property
    def scaler(self):
        return self.accelerator.scaler

    @property
    def gpus(self) -> Optional[Union[List[int], str, int]]:
        return self.accelerator_connector.gpus

    @property
    def model(self) -> torch.nn.Module:
        """
        The LightningModule, but possibly wrapped into DataParallel or DistributedDataParallel.
        To access the pure LightningModule, use
        :meth:`~pytorch_lightning.trainer.trainer.Trainer.lightning_module` instead.
        """
        return self.accelerator.model

    @model.setter
    def model(self, model: torch.nn.Module) -> None:
        """
        Setter for the model, pass-through to accelerator and plugin where the model reference is stored.
        Used by the Tuner to reset the state of Trainer and Accelerator.

        Args:
            model: The LightningModule, possibly wrapped into DataParallel or DistributedDataParallel, depending
                on the backend.
        """
        self.accelerator.model = model

    # ------------------------------------------------------------------
    # General properties
    # ------------------------------------------------------------------

    @property
    def log_dir(self) -> Optional[str]:
        """Directory where logs are written.

        Falls back to ``default_root_dir`` when there is no logger or a ``LoggerCollection``. The result is
        passed through ``accelerator.broadcast`` (presumably so every rank sees the same path).
        """
        if self.logger is None:
            dirpath = self.default_root_dir
        elif isinstance(self.logger, TensorBoardLogger):
            dirpath = self.logger.log_dir
        elif isinstance(self.logger, LoggerCollection):
            dirpath = self.default_root_dir
        else:
            dirpath = self.logger.save_dir
        return self.accelerator.broadcast(dirpath)

    @property
    def use_amp(self) -> bool:
        return self.precision == 16

    @property
    def is_global_zero(self) -> bool:
        return self.global_rank == 0

    @property
    def slurm_job_id(self) -> Optional[int]:
        """The SLURM job id as an ``int``, or ``None``.

        ``None`` is returned when ``SLURM_JOB_ID`` is unset, empty or not an integer, and in interactive
        SLURM sessions (``SLURM_JOB_NAME == "bash"``).
        """
        # in interactive mode, don't make logs use the same job id
        if os.environ.get("SLURM_JOB_NAME") == "bash":
            return None
        job_id = os.environ.get("SLURM_JOB_ID")
        # an empty value used to leak through as the string "" despite the Optional[int] contract
        if not job_id:
            return None
        try:
            return int(job_id)
        except ValueError:
            return None

    @property
    def lightning_optimizers(self) -> List[LightningOptimizer]:
        # built lazily; the cache is invalidated by the `optimizers` setter
        if self._lightning_optimizers is None:
            self.convert_to_lightning_optimizers()
        return self._lightning_optimizers

    @property
    def distributed_sampler_kwargs(self) -> Optional[dict]:
        if isinstance(self.training_type_plugin, ParallelPlugin):
            return self.training_type_plugin.distributed_sampler_kwargs
        return None

    @property
    def data_parallel(self) -> bool:
        return self._distrib_type in (
            DistributedType.DP,
            DistributedType.DDP,
            DistributedType.DDP_SPAWN,
            DistributedType.DDP2,
        )

    @property
    def progress_bar_callback(self) -> Optional[ProgressBarBase]:
        return self._progress_bar_callback

    @property
    def progress_bar_dict(self) -> dict:
        """Read-only for progress bar metrics.

        Merges ``get_progress_bar_dict()`` of the LightningModule with the logged progress bar metrics; on
        name clashes the logged metric wins and a warning is emitted.
        """
        ref_model = cast(pl.LightningModule, self.lightning_module)
        standard_metrics = ref_model.get_progress_bar_dict()
        pbar_metrics = self.progress_bar_metrics
        # sorted so the warning (and the example name it quotes) is deterministic; set order is not
        duplicates = sorted(standard_metrics.keys() & pbar_metrics.keys())
        if duplicates:
            rank_zero_warn(
                f"The progress bar already tracks a metric with the name(s) '{', '.join(duplicates)}' and"
                f" `self.log('{duplicates[0]}', ..., prog_bar=True)` will overwrite this value."
                " If this is undesired, change the name or override `get_progress_bar_dict()`"
                " in `LightningModule`.",
                UserWarning,
            )
        return {**standard_metrics, **pbar_metrics}

    @property
    def _should_reload_dl_epoch(self) -> bool:
        """Check if dataloader should be reloaded in the current epoch."""
        n_epochs = self.reload_dataloaders_every_n_epochs
        # `bool(...)` so a disabled setting (0) yields False rather than the int 0;
        # short-circuiting also avoids the modulo by zero
        return bool(n_epochs) and self.current_epoch % n_epochs == 0

    @property
    def disable_validation(self) -> bool:
        """Check if validation is disabled during training."""
        rank_zero_deprecation(
            "`trainer.disable_validation` is deprecated in v1.4 and will be removed in v1.6."
            " Use `not trainer.enable_validation` instead."
        )
        return not self.enable_validation

    @property
    def enable_validation(self) -> bool:
        """Check if we should run validation during training."""
        model_ref = self.lightning_module
        return is_overridden("validation_step", model_ref) and self.limit_val_batches > 0

    @property
    def default_root_dir(self) -> str:
        """
        The default location to save artifacts of loggers, checkpoints etc.
        It is used as a fallback if logger or checkpoint callback do not define specific save paths.
        """
        if get_filesystem(self._default_root_dir).protocol == "file":
            return os.path.normpath(self._default_root_dir)
        return self._default_root_dir

    @property
    def weights_save_path(self) -> str:
        """
        The default root location to save weights (checkpoints), e.g., when the
        :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` does not define a file path.
        """
        if get_filesystem(self._weights_save_path).protocol == "file":
            return os.path.normpath(self._weights_save_path)
        return self._weights_save_path

    @property
    def early_stopping_callback(self) -> Optional[EarlyStopping]:
        """
        The first :class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping`
        callback in the Trainer.callbacks list, or ``None`` if it doesn't exist.
        """
        callbacks = self.early_stopping_callbacks
        return callbacks[0] if len(callbacks) > 0 else None

    @property
    def early_stopping_callbacks(self) -> List[EarlyStopping]:
        """
        A list of all instances of :class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping`
        found in the Trainer.callbacks list.
        """
        return [c for c in self.callbacks if isinstance(c, EarlyStopping)]

    @property
    def prediction_writer_callbacks(self) -> List[BasePredictionWriter]:
        """
        A list of all instances of :class:`~pytorch_lightning.callbacks.prediction_writer.BasePredictionWriter`
        found in the Trainer.callbacks list.
        """
        return [cb for cb in self.callbacks if isinstance(cb, BasePredictionWriter)]

    @property
    def checkpoint_callback(self) -> Optional[ModelCheckpoint]:
        """
        The first :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint`
        callback in the Trainer.callbacks list, or ``None`` if it doesn't exist.
        """
        callbacks = self.checkpoint_callbacks
        return callbacks[0] if len(callbacks) > 0 else None

    @property
    def checkpoint_callbacks(self) -> List[ModelCheckpoint]:
        """
        A list of all instances of :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint`
        found in the Trainer.callbacks list.
        """
        return [c for c in self.callbacks if isinstance(c, ModelCheckpoint)]

    @property
    def resume_from_checkpoint(self) -> Optional[Union[str, Path]]:
        return self.checkpoint_connector.resume_checkpoint_path

    def save_checkpoint(self, filepath, weights_only: bool = False) -> None:
        """Save a checkpoint to ``filepath`` via the checkpoint connector."""
        self.checkpoint_connector.save_checkpoint(filepath, weights_only)

    # ------------------------------------------------------------------
    # Parsing properties
    # ------------------------------------------------------------------

    @classmethod
    def default_attributes(cls) -> dict:
        """Map each constructor parameter name of ``cls`` to its default value."""
        init_signature = inspect.signature(cls)
        return {k: v.default for k, v in init_signature.parameters.items()}

    @classmethod
    def get_deprecated_arg_names(cls) -> List:
        """Returns a list with deprecated Trainer arguments."""
        depr_arg_names = []
        # only attributes defined directly on `cls` whose name starts with "DEPRECATED" are collected
        for name, val in cls.__dict__.items():
            if name.startswith("DEPRECATED") and isinstance(val, (tuple, list)):
                depr_arg_names.extend(val)
        return depr_arg_names

    @classmethod
    def from_argparse_args(cls: Type["_T"], args: Union[Namespace, ArgumentParser], **kwargs) -> "_T":
        return from_argparse_args(cls, args, **kwargs)

    @classmethod
    def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namespace:
        return parse_argparser(cls, arg_parser)

    @classmethod
    def match_env_arguments(cls) -> Namespace:
        return parse_env_variables(cls)

    @classmethod
    def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs) -> ArgumentParser:
        return add_argparse_args(cls, parent_parser, **kwargs)

    # ------------------------------------------------------------------
    # State properties
    # ------------------------------------------------------------------

    @property
    def interrupted(self) -> bool:
        return self.state.status == TrainerStatus.INTERRUPTED

    @property
    def training(self) -> bool:
        return self.state.stage == RunningStage.TRAINING

    @training.setter
    def training(self, val: bool) -> None:
        # setting False only clears the stage if this stage is the active one
        if val:
            self.state.stage = RunningStage.TRAINING
        elif self.training:
            self.state.stage = None

    @property
    def testing(self) -> bool:
        return self.state.stage == RunningStage.TESTING

    @testing.setter
    def testing(self, val: bool) -> None:
        if val:
            self.state.stage = RunningStage.TESTING
        elif self.testing:
            self.state.stage = None

    @property
    def predicting(self) -> bool:
        return self.state.stage == RunningStage.PREDICTING

    @predicting.setter
    def predicting(self, val: bool) -> None:
        if val:
            self.state.stage = RunningStage.PREDICTING
        elif self.predicting:
            self.state.stage = None

    @property
    def tuning(self) -> bool:
        return self.state.stage == RunningStage.TUNING

    @tuning.setter
    def tuning(self, val: bool) -> None:
        if val:
            self.state.stage = RunningStage.TUNING
        elif self.tuning:
            self.state.stage = None

    @property
    def validating(self) -> bool:
        return self.state.stage == RunningStage.VALIDATING

    @validating.setter
    def validating(self, val: bool) -> None:
        if val:
            self.state.stage = RunningStage.VALIDATING
        elif self.validating:
            self.state.stage = None

    @property
    def evaluating(self) -> bool:
        stage = self.state.stage
        # wrapped in bool() so "no stage" yields False instead of leaking None
        return bool(stage and stage.evaluating)

    @property
    def sanity_checking(self) -> bool:
        return self.state.stage == RunningStage.SANITY_CHECKING

    @sanity_checking.setter
    def sanity_checking(self, val: bool) -> None:
        if val:
            self.state.stage = RunningStage.SANITY_CHECKING
        elif self.sanity_checking:
            self.state.stage = None

    # ------------------------------------------------------------------
    # Loop properties
    # ------------------------------------------------------------------

    @property
    def global_step(self) -> int:
        return self.fit_loop.global_step

    @property
    def current_epoch(self) -> int:
        return self.fit_loop.current_epoch

    @property
    def max_epochs(self) -> Optional[int]:
        return self.fit_loop.max_epochs

    @property
    def min_epochs(self) -> Optional[int]:
        return self.fit_loop.min_epochs

    @property
    def max_steps(self) -> Optional[int]:
        return self.fit_loop.max_steps

    @property
    def min_steps(self) -> Optional[int]:
        return self.fit_loop.min_steps

    @property
    def is_last_batch(self) -> bool:
        return self.fit_loop.epoch_loop.is_last_batch

    @property
    def fit_loop(self) -> FitLoop:
        return self._fit_loop

    @fit_loop.setter
    def fit_loop(self, loop: FitLoop):
        """
        Attach a custom fit loop to this Trainer. It will run with
        :meth:`~pytorch_lightning.trainer.trainer.Trainer.fit`.
        """
        loop.trainer = self
        self._fit_loop = loop

    @property
    def validate_loop(self) -> EvaluationLoop:
        return self._validate_loop

    @validate_loop.setter
    def validate_loop(self, loop: EvaluationLoop):
        """
        Attach a custom validation loop to this Trainer. It will run with
        :meth:`~pytorch_lightning.trainer.trainer.Trainer.validate`. Note that this loop is different from the one
        running during training inside the :meth:`pytorch_lightning.trainer.trainer.Trainer.fit` call.
        """
        loop.trainer = self
        self._validate_loop = loop

    @property
    def test_loop(self) -> EvaluationLoop:
        return self._test_loop

    @test_loop.setter
    def test_loop(self, loop: EvaluationLoop):
        """
        Attach a custom test loop to this Trainer. It will run with
        :meth:`~pytorch_lightning.trainer.trainer.Trainer.test`.
        """
        loop.trainer = self
        self._test_loop = loop

    @property
    def predict_loop(self) -> PredictionLoop:
        return self._predict_loop

    @predict_loop.setter
    def predict_loop(self, loop: PredictionLoop):
        """
        Attach a custom prediction loop to this Trainer. It will run with
        :meth:`~pytorch_lightning.trainer.trainer.Trainer.predict`.
        """
        loop.trainer = self
        self._predict_loop = loop

    @property
    def _evaluation_loop(self) -> EvaluationLoop:
        """The evaluation loop matching the current ``state.fn``.

        Raises:
            RuntimeError: if accessed while not fitting, tuning, validating or testing.
        """
        if self.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING):
            return self.fit_loop.epoch_loop.val_loop
        if self.state.fn == TrainerFn.VALIDATING:
            return self.validate_loop
        if self.state.fn == TrainerFn.TESTING:
            return self.test_loop
        raise RuntimeError("The `Trainer._evaluation_loop` property isn't defined. Accessed outside of scope")

    @property
    def _active_loop(self) -> Optional[Union[FitLoop, EvaluationLoop, PredictionLoop]]:
        """The loop for the current running stage, or ``None`` if no stage is active."""
        if self.training:
            return self.fit_loop
        if self.sanity_checking or self.evaluating:
            return self._evaluation_loop
        if self.predicting:
            return self.predict_loop
        return None

    @property
    def _ckpt_path(self) -> Optional[str]:
        """The checkpoint path loaded by the current validate/test/predict call, else ``None``."""
        if self.state.fn == TrainerFn.VALIDATING:
            return self.validated_ckpt_path
        if self.state.fn == TrainerFn.TESTING:
            return self.tested_ckpt_path
        if self.state.fn == TrainerFn.PREDICTING:
            return self.predicted_ckpt_path
        return None

    # ------------------------------------------------------------------
    # Logging properties
    # ------------------------------------------------------------------

    @property
    def callback_metrics(self) -> dict:
        return self.logger_connector.callback_metrics

    @property
    def logged_metrics(self) -> dict:
        return self.logger_connector.logged_metrics

    @property
    def progress_bar_metrics(self) -> dict:
        return self.logger_connector.progress_bar_metrics

    @property
    def _results(self) -> Optional[ResultCollection]:
        """The ``ResultCollection`` of the active loop, or ``None`` when no loop is active."""
        active_loop = self._active_loop
        if active_loop is not None:
            return active_loop._results
        return None

    # ------------------------------------------------------------------
    # Other
    # ------------------------------------------------------------------

    # TODO: refactor this so that it can be done in LightningOptimizer
    def __getstate__(self):
        """Return the picklable state with the `LightningOptimizer` cache removed.

        A shallow copy of ``__dict__`` is returned so that pickling/deep-copying does not clear the cache on
        the live trainer; the wrappers are rebuilt lazily by the ``lightning_optimizers`` property.
        """
        state = self.__dict__.copy()
        state["_lightning_optimizers"] = None
        return state

    def __setstate__(self, state):
        self.__dict__ = state
# Type variable for the concrete class that TrainerProperties classmethods are called on (e.g. the Trainer).
# `from_argparse_args` uses it so that a subclass calling the classmethod gets its own type back.
_T = TypeVar("_T", bound=TrainerProperties)