"""
The Lightning training loop handles everything except the actual computations of your model.
To decide what will happen in your training loop, define the `training_step` function.

Below are all the things Lightning automates for you in the training loop.

Accumulated gradients
---------------------

Accumulating gradients runs K small batches of size N and sums their gradients before taking a
single optimizer step. The loss of each batch is divided by K, so the effective batch size is KxN.

.. code-block:: python

    # DEFAULT (ie: no accumulated grads)
    trainer = Trainer(accumulate_grad_batches=1)
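
The same idea can be sketched in plain PyTorch. This is an illustration, not Lightning's code: the
toy model, the data and K = 4 are assumptions. As in the loop below, the loss is divided by K,
`backward()` runs on every batch, and the optimizer only steps once every K batches.

.. code-block:: python

    import torch
    from torch import nn

    # toy setup (assumed for illustration): K = 4 accumulated batches of N = 8 samples
    torch.manual_seed(0)
    model = nn.Linear(3, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    loss_fn = nn.MSELoss()
    accumulate_grad_batches = 4
    batches = [(torch.randn(8, 3), torch.randn(8, 1)) for _ in range(8)]

    optimizer_steps = 0
    for batch_idx, (x, y) in enumerate(batches):
        # divide by K so the summed gradients equal the mean over the K batches
        loss = loss_fn(model(x), y) / accumulate_grad_batches
        # backward runs on every batch; gradients add up in each parameter's .grad
        loss.backward()
        # step and reset the gradients only once every K batches
        if (batch_idx + 1) % accumulate_grad_batches == 0:
            optimizer.step()
            optimizer.zero_grad()
            optimizer_steps += 1

With 8 batches and K = 4 the optimizer steps twice. Each step uses gradients equal to those of a
single batch of 32 samples, because the parameters do not change inside an accumulation window.
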

Force training for min or max epochs
------------------------------------

It can be useful to force training for a minimum number of epochs or to limit it to a maximum
number of epochs.

.. code-block:: python

    # DEFAULT
    trainer = Trainer(min_epochs=1, max_epochs=1000)
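
Inside the epoch loop, an early-stopping request is only honoured once `min_epochs` has been
reached, and the loop never runs past `max_epochs`. Below is a simplified sketch of that rule. The
function names are hypothetical, and the real loop also checks `min_steps`, `max_steps` and
`fast_dev_run`.

.. code-block:: python

    def should_stop(epoch, min_epochs, early_stop_requested):
        # hypothetical helper: epochs are 0-indexed, so epoch min_epochs - 1
        # is the last epoch that must run before stopping is allowed
        met_min_epochs = epoch >= min_epochs - 1
        return early_stop_requested and met_min_epochs


    def count_epochs(min_epochs, max_epochs, stop_requested_from):
        # hypothetical driver: early stopping asks to stop from epoch
        # stop_requested_from onwards
        epochs_run = 0
        for epoch in range(max_epochs):
            epochs_run += 1
            if should_stop(epoch, min_epochs, epoch >= stop_requested_from):
                break
        return epochs_run


    # stopping is requested immediately, but min_epochs=5 still forces 5 epochs
    print(count_epochs(min_epochs=5, max_epochs=1000, stop_requested_from=0))  # prints 5
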

Force disable early stop
------------------------

To disable early stopping, pass None to `early_stop_callback`.

.. code-block:: python

    # DEFAULT
    trainer = Trainer(early_stop_callback=None)

Gradient Clipping
-----------------

Gradient clipping may be enabled to avoid exploding gradients.
Specifically, this will `clip the gradient norm computed over all model parameters together
<https://pytorch.org/docs/stable/nn.html#torch.nn.utils.clip_grad_norm_>`_.

.. code-block:: python

    # DEFAULT (ie: don't clip)
    trainer = Trainer(gradient_clip_val=0)

    # clip gradients with norm above 0.5
    trainer = Trainer(gradient_clip_val=0.5)
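
The linked PyTorch utility can be tried on its own. The sketch below assumes a toy model whose
gradients are made large on purpose. It shows two things: the norm is measured over all parameters
as one vector, and the value returned is the norm *before* clipping. How Lightning's
`clip_gradients` calls it is not reproduced here.

.. code-block:: python

    import torch
    from torch import nn

    torch.manual_seed(0)
    model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
    # targets far from the predictions produce large gradients
    loss = ((model(torch.randn(16, 4)) - 100.0) ** 2).mean()
    loss.backward()

    # all parameters' gradients are treated as one concatenated vector;
    # the return value is their total norm measured before clipping
    total_norm_before = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)

    # norm of all gradients after clipping, again taken over every parameter together
    total_norm_after = torch.cat([p.grad.flatten() for p in model.parameters()]).norm(2)
    print(float(total_norm_before), float(total_norm_after))
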

Inspect gradient norms
----------------------

Looking at grad norms can help you figure out where training might be going wrong.
They are computed just before the optimizer step, on batches whose index is a multiple
of `row_log_interval`.

.. code-block:: python

    # DEFAULT (-1 doesn't track norms)
    trainer = Trainer(track_grad_norm=-1)

    # track the LP norm (P=2 here)
    trainer = Trainer(track_grad_norm=2)
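
Lightning computes these through `LightningModule.grad_norm`, and its exact output keys are not
reproduced here. The sketch below uses a hypothetical stand-alone helper and a toy model. It
computes the p-norm of each parameter's gradient plus the total over all of them.

.. code-block:: python

    import torch
    from torch import nn


    def grad_p_norms(module, norm_type=2.0):
        # hypothetical helper, not Lightning's grad_norm: p-norm of each parameter's
        # gradient plus the p-norm over all gradients taken together
        norms = {}
        for name, param in module.named_parameters():
            if param.grad is not None:
                norms[name] = param.grad.detach().norm(norm_type).item()
        # the p-norm of the per-parameter p-norms equals the p-norm of all
        # gradient entries concatenated into one vector
        norms["total"] = torch.tensor(list(norms.values())).norm(norm_type).item()
        return norms


    torch.manual_seed(0)
    model = nn.Linear(3, 2)
    model(torch.randn(4, 3)).sum().backward()
    norms = grad_p_norms(model, norm_type=2.0)
    print(norms)
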

Set how much of the training set to check
-----------------------------------------

If you don't want to check 100% of the training set (for debugging or if it's huge), set this flag.
`train_percent_check` will be overwritten by `overfit_pct` if `overfit_pct > 0`.

.. code-block:: python

    # DEFAULT
    trainer = Trainer(train_percent_check=1.0)

    # check 10% only
    trainer = Trainer(train_percent_check=0.1)

Packed sequences as inputs
--------------------------

When using PackedSequence, do 2 things:

1. Return either a padded tensor in the dataset or a list of variable-length tensors
   in the dataloader `collate_fn` (the example below shows the list implementation).
2. Pack the sequence in forward or in the training and validation steps, depending on the use case.

.. code-block:: python

    from torch.nn.utils import rnn

    # For use in dataloader
    def collate_fn(batch):
        x = [item[0] for item in batch]
        y = [item[1] for item in batch]
        return x, y

    # In module
    def training_step(self, batch, batch_idx):
        x = rnn.pack_sequence(batch[0], enforce_sorted=False)
        y = rnn.pack_sequence(batch[1], enforce_sorted=False)
        # ... run the model on the packed sequences and return the loss
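
Here is a self-contained sketch of the same pattern outside a LightningModule. It assumes a toy
dataset of variable-length sequences with fixed-size targets, so the targets are stacked rather
than packed, and a single-layer `nn.LSTM`.

.. code-block:: python

    import torch
    from torch import nn
    from torch.nn.utils import rnn
    from torch.utils.data import DataLoader

    torch.manual_seed(0)
    # toy samples: a sequence of shape (T_i, 1) and a target of shape (1,)
    dataset = [(torch.randn(t, 1), torch.randn(1)) for t in (5, 3, 7, 2)]


    def collate_fn(batch):
        # keep the variable-length inputs as a list; stack the fixed-size targets
        x = [item[0] for item in batch]
        y = torch.stack([item[1] for item in batch])
        return x, y


    loader = DataLoader(dataset, batch_size=4, collate_fn=collate_fn)
    lstm = nn.LSTM(input_size=1, hidden_size=4, batch_first=True)

    x, y = next(iter(loader))
    # enforce_sorted=False accepts the sequences in any length order
    packed = rnn.pack_sequence(x, enforce_sorted=False)
    packed_out, (h_n, c_n) = lstm(packed)
    # back to a padded (batch, max_len, hidden) tensor plus the true lengths,
    # restored to the original sample order
    padded_out, lengths = rnn.pad_packed_sequence(packed_out, batch_first=True)
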

Truncated Backpropagation Through Time
--------------------------------------

There are times when multiple backwards passes are needed for each batch.
For example, it may save memory to use Truncated Backpropagation Through Time when training RNNs.

When this flag is enabled, each batch is split into sequences of size `truncated_bptt_steps`
and passed to `training_step(...)` separately, together with the hidden states returned for the
previous split. A default splitting function is provided; however, you can override it for more
flexibility. See `tbptt_split_batch`.

.. code-block:: python

    # DEFAULT (single backwards pass per batch)
    trainer = Trainer(truncated_bptt_steps=None)

    # (split batch into sequences of size 2)
    trainer = Trainer(truncated_bptt_steps=2)
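
Lightning's default `tbptt_split_batch` is not reproduced here. The sketch below illustrates the
idea with assumed batch-first tensors shaped (batch, time, ...). A hypothetical splitter slices
every tensor along the time dimension. Each split then gets its own backward pass, and the hidden
state carried to the next split is detached.

.. code-block:: python

    import torch
    from torch import nn


    def split_along_time(batch, split_size):
        # hypothetical splitter: slice every tensor in the batch along dim 1 (time)
        time_len = batch[0].size(1)
        return [[t_x[:, t:t + split_size] for t_x in batch] for t in range(0, time_len, split_size)]


    torch.manual_seed(0)
    x = torch.randn(2, 5, 3)  # (batch, time, features)
    y = torch.randn(2, 5)     # (batch, time)
    splits = split_along_time([x, y], split_size=2)

    lstm = nn.LSTM(input_size=3, hidden_size=4, batch_first=True)
    head = nn.Linear(4, 1)
    hiddens = None
    backward_passes = 0
    for x_split, y_split in splits:
        out, hiddens = lstm(x_split, hiddens)
        loss = ((head(out).squeeze(-1) - y_split) ** 2).mean()
        loss.backward()  # one backward pass per split
        backward_passes += 1
        # detach so the next split's graph stops at this boundary
        hiddens = tuple(h.detach() for h in hiddens)

Without the `detach`, the second split's backward pass would try to go back through the first
split's graph, which has already been freed.
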

NaN detection and intervention
------------------------------

When the `terminate_on_nan` flag is enabled, Lightning will check after the forward and backward
pass of every training batch that

1. the loss you return in `training_step` is finite (not NaN and not +/-inf), and
2. the model parameters have finite values.

Lightning will terminate the training loop with an error message if NaN or infinite
values are detected. If this happens, you should investigate numerically unstable operations
in your model.

.. code-block:: python

    # DEFAULT (won't perform the NaN check)
    trainer = Trainer(terminate_on_nan=False)

    # (NaN check each batch and terminate on NaN or infinite values)
    trainer = Trainer(terminate_on_nan=True)
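
Lightning's `detect_nan_tensors` is implemented in another mixin and is not shown here. Below is a
minimal sketch of an equivalent check; the function name and error messages are hypothetical.

.. code-block:: python

    import torch
    from torch import nn


    def check_finite(loss, model):
        # hypothetical stand-in for the terminate_on_nan check: raise if the loss
        # or any parameter contains NaN or +/-inf
        if not torch.isfinite(loss).all():
            raise ValueError(f"The loss returned in training_step is {loss}.")
        for name, param in model.named_parameters():
            if not torch.isfinite(param).all():
                raise ValueError(f"Detected NaN and/or inf values in parameter {name}.")


    model = nn.Linear(2, 1)
    check_finite(torch.tensor(0.5), model)  # finite: passes silently

    caught = None
    try:
        check_finite(torch.tensor(float("nan")), model)
    except ValueError as err:
        caught = str(err)
    print(caught)
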
"""
import atexit
import signal
from abc import ABC, abstractmethod
from typing import Callable
from typing import Union, List
import numpy as np
import torch
from torch.utils.data import DataLoader
from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
import subprocess
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.utilities.parsing import AttributeDict
from pytorch_lightning.utilities.memory import recursive_detach

try:
    from apex import amp
except ImportError:
    APEX_AVAILABLE = False
else:
    APEX_AVAILABLE = True

try:
    import torch_xla.distributed.parallel_loader as xla_pl
    import torch_xla.core.xla_model as xm
except ImportError:
    XLA_AVAILABLE = False
else:
    XLA_AVAILABLE = True

try:
    import horovod.torch as hvd
except ImportError:
    HOROVOD_AVAILABLE = False
else:
    HOROVOD_AVAILABLE = True

# signals that should be caught for a graceful trainer shutdown
SIGNAL_TERMINATE = ('SIGTERM', 'SIGSEGV', 'SIGINT')


class TrainerTrainLoopMixin(ABC):

    # this is just a summary on variables used in this abstract class,
    #  the proper values/initialisation should be done in child class
    max_epochs: int
    min_epochs: int
    on_gpu: bool
    use_ddp: bool
    use_dp: bool
    use_ddp2: bool
    use_horovod: bool
    single_gpu: bool
    use_tpu: bool
    data_parallel_device_ids: ...
    check_val_every_n_epoch: ...
    num_training_batches: int
    val_check_batch: ...
    num_val_batches: int
    disable_validation: bool
    fast_dev_run: ...
    accumulation_scheduler: ...
    lr_schedulers: ...
    enable_early_stop: ...
    early_stop_callback: ...
    callback_metrics: ...
    logger: Union[LightningLoggerBase, bool]
    global_step: int
    testing: bool
    log_save_interval: float
    proc_rank: int
    row_log_interval: float
    truncated_bptt_steps: ...
    optimizers: ...
    optimizer_frequencies: ...
    accumulate_grad_batches: int
    track_grad_norm: ...
    model: LightningModule
    interrupted: bool
    running_loss: ...
    progress_bar_dict: ...
    reduce_lr_on_plateau_scheduler: ...
    profiler: ...
    batch_idx: int
    precision: ...
    train_dataloader: DataLoader
    reload_dataloaders_every_epoch: bool
    max_steps: int
    min_steps: int
    total_batch_idx: int
    checkpoint_callback: ...
    terminate_on_nan: bool
    tpu_id: int

    # Callback system
    callbacks: List[Callback]
    on_train_start: Callable
    on_train_end: Callable
    on_batch_start: Callable
    on_batch_end: Callable
    on_epoch_start: Callable
    on_epoch_end: Callable
    on_validation_end: Callable

    @abstractmethod
    def get_model(self):
        """Warning: this is just an empty shell for code implemented in another class."""

    @abstractmethod
    def is_function_implemented(self, *args):
        """Warning: this is just an empty shell for code implemented in another class."""

    @abstractmethod
    def run_evaluation(self, *args):
        """Warning: this is just an empty shell for code implemented in another class."""

    @abstractmethod
    def transfer_batch_to_gpu(self, *args):
        """Warning: this is just an empty shell for code implemented in another class."""

    @abstractmethod
    def transfer_batch_to_tpu(self, *args):
        """Warning: this is just an empty shell for code implemented in another class."""

    @abstractmethod
    def clip_gradients(self):
        """Warning: this is just an empty shell for code implemented in another class."""

    @abstractmethod
    def detect_nan_tensors(self, *args):
        """Warning: this is just an empty shell for code implemented in another class."""

    @abstractmethod
    def is_overridden(self, *args):
        """Warning: this is just an empty shell for code implemented in another class."""

    @abstractmethod
    def add_progress_bar_metrics(self, *args):
        """Warning: this is just an empty shell for code implemented in another class."""

    @abstractmethod
    def log_metrics(self, *args):
        """Warning: this is just an empty shell for code implemented in another class."""

    @abstractmethod
    def process_output(self, *args):
        """Warning: this is just an empty shell for code implemented in another class."""

    @abstractmethod
    def reset_train_dataloader(self, *args):
        """Warning: this is just an empty shell for code implemented in another class."""

    @abstractmethod
    def reset_val_dataloader(self, model):
        """Warning: this is just an empty shell for code implemented in another class."""

    @abstractmethod
    def has_arg(self, *args):
        """Warning: this is just an empty shell for code implemented in another class."""

    def train(self):
        # add signal handlers for process kills
        # def _signal_kill_handler(*args):
        #     return TrainerTrainLoopMixin.run_training_teardown(self)
        #
        # orig_signal_handlers = {}
        # for sig_name in SIGNAL_TERMINATE:
        #     orig_signal_handlers[sig_name] = signal.signal(getattr(signal, sig_name),
        #                                                    _signal_kill_handler)

        # get model
        model = self.get_model()

        # load data
        # if reload_dataloaders_every_epoch, this is moved to the epoch loop
        if not self.reload_dataloaders_every_epoch:
            self.reset_train_dataloader(model)
        self.reset_val_dataloader(model)

        # Train start events
        with self.profiler.profile('on_train_start'):
            # callbacks
            self.on_train_start()
            # initialize early stop callback
            if self.early_stop_callback is not None:
                self.early_stop_callback.on_train_start(self, self.get_model())
            # model hooks
            model.on_train_start()

        try:
            # run all epochs
            for epoch in range(self.current_epoch, self.max_epochs):
                # reset train dataloader
                if self.reload_dataloaders_every_epoch:
                    self.reset_train_dataloader(model)
                # set seed for distributed sampler (enables shuffling for each epoch)
                if (self.use_ddp or self.use_horovod) \
                        and hasattr(self.train_dataloader, 'sampler') \
                        and hasattr(self.train_dataloader.sampler, 'set_epoch'):
                    self.train_dataloader.sampler.set_epoch(epoch)

                # update training progress in trainer and model
                model.current_epoch = epoch
                self.current_epoch = epoch

                # change the number of accumulated batches according to the accumulation_scheduler
                self.accumulation_scheduler.on_epoch_start(self, self.get_model())

                # stores accumulated grad fractions per batch
                self.batch_loss_value = TensorRunningAccum(
                    window_length=self.accumulate_grad_batches
                )

                # -----------------
                # RUN TNG EPOCH
                # -----------------
                self.run_training_epoch()

                if self.max_steps and self.max_steps == self.global_step:
                    self.run_training_teardown()
                    return

                # update LR schedulers
                self.update_learning_rates(interval='epoch')

                # early stopping
                met_min_epochs = epoch >= self.min_epochs - 1
                met_min_steps = self.global_step >= self.min_steps if self.min_steps else True

                # TODO wrap this logic into the callback
                if self.enable_early_stop:
                    if (met_min_epochs and met_min_steps) or self.fast_dev_run:
                        should_stop = self.early_stop_callback.on_validation_end(self, self.get_model())
                        # stop training
                        stop = should_stop and met_min_epochs
                        if stop:
                            self.run_training_teardown()
                            return

            self.run_training_teardown()

        except KeyboardInterrupt:
            rank_zero_warn('Detected KeyboardInterrupt, attempting graceful shutdown...')

            # user could press ctrl+c many times... only shutdown once
            if not self.interrupted:
                self.interrupted = True

                for proc in self.interactive_ddp_procs:
                    subprocess.Popen.kill(proc)

                self.run_training_teardown()

    def run_training_epoch(self):

        # get model
        model = self.get_model()

        # Epoch start events
        with self.profiler.profile('on_epoch_start'):
            # callbacks
            self.on_epoch_start()

            # model hooks
            if self.is_function_implemented('on_epoch_start'):
                model.on_epoch_start()

        # track local dataloader so TPU can wrap each epoch
        train_dataloader = self.train_dataloader

        # on TPU we have to wrap it under the ParallelLoader
        if self.use_tpu:
            device = xm.xla_device(self.tpu_id)
            train_dataloader = xla_pl.ParallelLoader(train_dataloader, [device])
            train_dataloader = train_dataloader.per_device_loader(device)

        # bookkeeping
        epoch_outputs = []
        to_pbar_on_epoch_end = []
        to_log_on_epoch_end = []

        # run epoch
        for batch_idx, (batch, is_last_batch) in self.profiler.profile_iterable(
                enumerate(_with_is_last(train_dataloader)), "get_train_batch"
        ):
            # stop epoch if we limited the number of training batches
            if batch_idx >= self.num_training_batches:
                break

            self.batch_idx = batch_idx

            model.global_step = self.global_step

            # ------------------------------------
            # TRAINING_STEP + TRAINING_STEP_END
            # ------------------------------------
            batch_output, training_step_output_for_epoch_end = self.run_training_batch(batch, batch_idx)

            # log or add these metrics to the pbar when the epoch completes
            to_log_on_epoch_end.append(batch_output.to_log_on_epoch_end)
            to_pbar_on_epoch_end.append(batch_output.to_pbar_on_epoch_end)

            # TODO: pull out reduce fxs
            # to_log_pbar_reduce_fxs.append(batch_output.training_step_outputs)

            # only track outputs when user implements training_epoch_end
            # otherwise we will build up unnecessary memory
            if self.is_overridden('training_epoch_end', model=self.get_model()):
                epoch_outputs.append(training_step_output_for_epoch_end)

            # when returning -1 from train_step, we end epoch early
            early_stop_epoch = batch_output.signal == -1

            # TODO: consolidate all actions that need to take place only after
            # self.accumulate_grad_batches steps (optimizer step, lr update, global step increment)
            if (self.batch_idx + 1) % self.accumulate_grad_batches == 0:
                # update lr
                self.update_learning_rates(interval='step')

            # ---------------
            # RUN VAL STEP
            # ---------------
            is_val_check_batch = (batch_idx + 1) % self.val_check_batch == 0
            can_check_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
            can_check_val = not self.disable_validation and can_check_epoch
            should_check_val = is_val_check_batch or early_stop_epoch
            should_check_val = should_check_val or (is_last_batch and self.val_check_batch == float('inf'))
            should_check_val = can_check_val and should_check_val

            # ---------------
            # CHECKPOINTING, EARLY STOPPING
            # ---------------
            # fast_dev_run always forces val checking after train batch
            if self.fast_dev_run or should_check_val:
                self.run_evaluation(test_mode=self.testing)
                self.call_checkpoint_callback()

            # when logs should be saved
            should_save_log = (batch_idx + 1) % self.log_save_interval == 0 or early_stop_epoch
            if should_save_log or self.fast_dev_run:
                if self.proc_rank == 0 and self.logger is not None:
                    self.logger.save()

            # when metrics should be logged
            should_log_metrics = batch_idx % self.row_log_interval == 0 or early_stop_epoch
            if should_log_metrics or self.fast_dev_run:
                # logs user requested information to logger
                self.log_metrics(batch_output.to_log_on_batch_end, batch_output.grad_norm_dic)

            # progress global step according to grads progress
            if (self.batch_idx + 1) % self.accumulate_grad_batches == 0:
                self.global_step += 1
            self.total_batch_idx += 1

            # max steps reached, end training
            if self.max_steps is not None and self.max_steps == self.global_step:
                break

            # end epoch early
            # stop when the flag is changed or we've gone past the amount
            # requested in the batches
            if early_stop_epoch or self.fast_dev_run:
                break

        if self.use_horovod:
            hvd.join(hvd.local_rank() if self.on_gpu else -1)

        # process epoch outputs
        model = self.get_model()
        if self.is_overridden('training_epoch_end', model=model):
            epoch_output = model.training_epoch_end(epoch_outputs)
            # TODO: create a result object from here, process then put in correct areas
            _processed_outputs = self.process_output(epoch_output)
            log_epoch_metrics = _processed_outputs[2]
            callback_epoch_metrics = _processed_outputs[3]
            self.log_metrics(log_epoch_metrics, {})
            self.callback_metrics.update(callback_epoch_metrics)
            self.add_progress_bar_metrics(_processed_outputs[1])

        # when no val loop is present or fast-dev-run still need to call checkpoints
        if not self.is_overridden('validation_step') and not (self.fast_dev_run or should_check_val):
            self.call_checkpoint_callback()

        # Epoch end events
        with self.profiler.profile('on_epoch_end'):
            # callbacks
            self.on_epoch_end()
            # model hooks
            if self.is_function_implemented('on_epoch_end'):
                model.on_epoch_end()

    def run_training_batch(self, batch, batch_idx):
        # TODO: verify new refactor

        # track grad norms
        grad_norm_dic = {}

        # track all metrics for callbacks
        batch_callback_metrics = []

        # track metrics to log either on the batch end or on the epoch end
        to_log_on_batch_end = []
        to_pbar_on_epoch_end = []
        to_log_on_epoch_end = []

        if batch is None:
            return AttributeDict(signal=0, grad_norm_dic=grad_norm_dic)

        # Batch start events
        with self.profiler.profile('on_batch_start'):
            # callbacks
            self.on_batch_start()

            # hooks
            if self.is_function_implemented('on_batch_start'):
                response = self.get_model().on_batch_start(batch)
                if response == -1:
                    return AttributeDict(signal=-1, grad_norm_dic=grad_norm_dic)

        splits = [batch]
        if self.truncated_bptt_steps is not None:
            model_ref = self.get_model()
            with self.profiler.profile('tbptt_split_batch'):
                splits = model_ref.tbptt_split_batch(batch, self.truncated_bptt_steps)

        # apply TBPTT. normal training means TBPTT = 1 (ie: don't split the batch across time)
        self.hiddens = None
        # in TBPTT we will only pass the last step output to epoch_end
        for split_idx, split_batch in enumerate(splits):
            self.split_idx = split_idx

            # loop over each optimizer
            for opt_idx, optimizer in self._get_optimizers_iterable():
                # make sure only the gradients of the current optimizer's parameters are calculated
                # in the training step to prevent dangling gradients in multiple-optimizer setup.
                if len(self.optimizers) > 1:
                    for param in self.get_model().parameters():
                        param.requires_grad = False
                    for group in optimizer.param_groups:
                        for param in group['params']:
                            param.requires_grad = True

                # -------------------
                # calculate loss
                # -------------------
                opt_closure_result = self.optimizer_closure(
                    split_batch,
                    batch_idx,
                    opt_idx,
                    optimizer,
                    self.hiddens
                )

                # ------------------------------
                # POST forward bookkeeping
                # ------------------------------
                # TODO: track the outputs to the loggers and pbar
                batch_callback_metrics.append(opt_closure_result.training_step_output.callback_metrics)
                self.add_progress_bar_metrics(opt_closure_result.training_step_output.pbar_on_batch_end)
                to_log_on_batch_end.append(opt_closure_result.training_step_output.log_on_batch_end)

                # loss, training_step_output, training_step_output_for_epoch_end, hiddens
                self.hiddens = opt_closure_result.hiddens

                # check if loss or model weights are nan
                if self.terminate_on_nan:
                    self.detect_nan_tensors(opt_closure_result.loss)

                # track total loss for logging (avoid mem leaks)
                self.batch_loss_value.append(opt_closure_result.loss)

                # ------------------------------
                # BACKWARD PASS
                # ------------------------------
                # gradient update with accumulated gradients
                if (self.batch_idx + 1) % self.accumulate_grad_batches == 0:

                    # backward
                    grad_norm_dic = self.run_batch_backward_pass(split_batch, batch_idx, opt_idx, optimizer)

                    # calculate running loss for display
                    self.running_loss.append(self.batch_loss_value.mean())

                    # reset for next set of accumulated grads
                    self.batch_loss_value.reset()

        # Batch end events
        with self.profiler.profile('on_batch_end'):
            # callbacks
            self.on_batch_end()
            # model hooks
            if self.is_function_implemented('on_batch_end'):
                self.get_model().on_batch_end()

        # collapse all metrics into one dict
        to_log_on_batch_end = {k: v for d in to_log_on_batch_end for k, v in d.items()}
        to_pbar_on_epoch_end = {k: v for d in to_pbar_on_epoch_end for k, v in d.items()}
        to_log_on_epoch_end = {k: v for d in to_log_on_epoch_end for k, v in d.items()}

        # track all metrics for callbacks
        # TODO: make early stopping and checkpoint use the new metrics
        self.callback_metrics.update({k: v for d in batch_callback_metrics for k, v in d.items()})

        # training_step_output are passed to training_epoch_end
        # batch_log_metrics metrics to log
        result = AttributeDict(
            signal=0,
            grad_norm_dic=grad_norm_dic,
            to_log_on_batch_end=to_log_on_batch_end,
            training_step_output=opt_closure_result.training_step_output,
            to_pbar_on_epoch_end=to_pbar_on_epoch_end,
            to_log_on_epoch_end=to_log_on_epoch_end
        )
        return result, opt_closure_result.training_step_output_for_epoch_end

    def run_batch_backward_pass(self, split_batch, batch_idx, opt_idx, optimizer):
        # ------------------
        # GRAD NORMS
        # ------------------
        # track gradient norms when requested
        grad_norm_dic = {}
        if batch_idx % self.row_log_interval == 0:
            if float(self.track_grad_norm) > 0:
                model = self.get_model()
                grad_norm_dic = model.grad_norm(self.track_grad_norm)

        # ------------------
        # CLIP GRADS
        # ------------------
        if self.use_amp and self.use_native_amp:
            self.scaler.unscale_(optimizer)
        self.clip_gradients()

        # ------------------
        # .STEP + ZERO_GRAD
        # ------------------
        model = self.get_model()
        with self.profiler.profile('optimizer_step'):
            # closure that re-runs forward + backward and returns the loss,
            # as second-order optimizers (e.g. LBFGS) expect from optimizer.step(closure)
            lambda_closure = lambda: self.optimizer_closure(
                split_batch,
                batch_idx,
                opt_idx,
                optimizer,
                self.hiddens
            ).loss
            model.optimizer_step(self.current_epoch, batch_idx,
                                 optimizer, opt_idx,
                                 lambda_closure)

        return grad_norm_dic

    def optimizer_closure(self, split_batch, batch_idx, opt_idx, optimizer, hiddens):
        """
        wrap the forward step in a closure so second order methods work
        """
        # ---------------------------
        # FORWARD
        # ---------------------------
        with self.profiler.profile('model_forward'):
            if self.use_amp and self.use_native_amp:
                with torch.cuda.amp.autocast():
                    training_step_output = self.training_forward(split_batch, batch_idx,
                                                                 opt_idx, hiddens)
            else:
                training_step_output = self.training_forward(split_batch, batch_idx, opt_idx,
                                                             hiddens)

            # ----------------------------
            # PROCESS THE RESULT
            # ----------------------------
            # support returning a simple tensor to minimize
            if isinstance(training_step_output, torch.Tensor):
                training_step_output = Result(minimize=training_step_output)

            # if the user decides to finally reduce things in epoch_end, save raw output without graphs
            training_step_output_for_epoch_end = recursive_detach(training_step_output)

            # format and reduce outputs accordingly
            if isinstance(training_step_output, Result):
                training_step_output = self.process_step_result(training_step_output, train=True)
            else:
                training_step_output = self.process_output(training_step_output, train=True)

            # accumulate loss
            # (if accumulate_grad_batches = 1 no effect)
            closure_loss = training_step_output.batch_loss / self.accumulate_grad_batches

        # backward pass
        model_ref = self.get_model()
        with self.profiler.profile('model_backward'):
            # scale loss for 16 bit
            if self.precision == 16 and not self.on_tpu:
                closure_loss = model_ref.amp_scale_loss(closure_loss, optimizer, opt_idx)

            # do backward pass
            model_ref.backward(self, closure_loss, optimizer, opt_idx)

            # once backward has been applied, release graph
            closure_loss = closure_loss.detach()
            training_step_output.batch_loss = training_step_output.batch_loss.detach()

        if self.use_horovod:
            # Synchronize Horovod to ensure gradient manipulations (e.g., loss scaling) are valid
            optimizer.synchronize()

        # call the on_after_backward hook
        if self.is_function_implemented('on_after_backward'):
            model_ref = self.get_model()
            with self.profiler.profile('on_after_backward'):
                model_ref.on_after_backward()

        result = AttributeDict(
            loss=closure_loss,
            training_step_output=training_step_output,
            training_step_output_for_epoch_end=training_step_output_for_epoch_end,
            hiddens=training_step_output.hiddens,
        )
        return result

    def _get_optimizers_iterable(self):
        if not self.optimizer_frequencies:
            # call training_step once per optimizer
            return list(enumerate(self.optimizers))

        optimizer_freq_cumsum = np.cumsum(self.optimizer_frequencies)
        optimizers_loop_length = optimizer_freq_cumsum[-1]
        current_place_in_loop = self.total_batch_idx % optimizers_loop_length

        # find optimizer index by looking for the first {item > current_place} in the cumsum list
        opt_idx = np.argmax(optimizer_freq_cumsum > current_place_in_loop)
        return [(opt_idx, self.optimizers[opt_idx])]

    # @atexit.register
    def run_training_teardown(self):
        if hasattr(self, '_teardown_already_run') and self._teardown_already_run:
            return

        # Train end events
        with self.profiler.profile('on_train_end'):
            # callbacks
            self.on_train_end()
            # model hooks
            if self.is_function_implemented('on_train_end'):
                self.get_model().on_train_end()

        if self.logger is not None:
            self.logger.finalize("success")

        # summarize profile results
        self.profiler.describe()

        self._teardown_already_run = True

    def training_forward(self, batch, batch_idx, opt_idx, hiddens):
        # TODO: add test to make sure training_step and training_step_end were called
        """
        Handle forward for each training case (distributed, single gpu, etc...)
        :param batch: the (possibly TBPTT-split) training batch
        :param batch_idx: index of the batch in the current epoch
        :param opt_idx: index of the optimizer currently in use
        :param hiddens: hidden states from the previous TBPTT split (None otherwise)
        :return: output of training_step, or of training_step_end / training_end if overridden
        """
        # ---------------
        # FORWARD
        # ---------------
        # enable not needing to add opt_idx to training_step
        args = [batch, batch_idx]

        if len(self.optimizers) > 1:
            if self.has_arg('training_step', 'optimizer_idx'):
                args.append(opt_idx)
            else:
                num_opts = len(self.optimizers)
                raise ValueError(
                    f'Your LightningModule defines {num_opts} optimizers but '
                    f'training_step is missing the "optimizer_idx" argument.'
                )

        # pass hiddens if using tbptt
        if self.truncated_bptt_steps is not None:
            args.append(hiddens)

        # distributed forward
        if self.use_ddp or self.use_ddp2 or self.use_dp:
            training_step_output = self.model(*args)

        # Horovod
        elif self.use_horovod and self.on_gpu:
            batch = self.transfer_batch_to_gpu(batch, hvd.local_rank())
            args[0] = batch
            training_step_output = self.model.training_step(*args)

        # single GPU forward
        elif self.single_gpu:
            gpu_id = 0
            if isinstance(self.data_parallel_device_ids, list):
                gpu_id = self.data_parallel_device_ids[0]

            # Don't copy the batch since there is a single gpu that the batch could
            # be referenced from and if there are multiple optimizers the batch will
            # wind up copying it to the same device repeatedly.
            batch = self.transfer_batch_to_gpu(batch, gpu_id)
            args[0] = batch
            training_step_output = self.model.training_step(*args)

        # TPU support
        elif self.use_tpu:
            batch = self.transfer_batch_to_tpu(batch, self.tpu_id)
            args[0] = batch
            training_step_output = self.model.training_step(*args)

        # CPU forward
        else:
            training_step_output = self.model.training_step(*args)

        train_fwd_result = training_step_output

        # ------------------------------------------
        # TRAINING_STEP_END
        # ------------------------------------------
        call_train_step_end = self.is_overridden('training_step_end')
        call_train_end = self.is_overridden('training_end')  # TODO: remove in 1.0.0

        if call_train_step_end or call_train_end:
            callback_name = 'training_step_end' if call_train_step_end else 'training_end'
            model_ref = self.get_model()
            with self.profiler.profile(callback_name):
                # format the inputs to the callback
                train_step_end_input = training_step_output

                # apply the callback
                callback_fx = getattr(model_ref, callback_name)
                train_step_end_output = callback_fx(train_step_end_input)
                train_fwd_result = train_step_end_output

        if call_train_end:
            rank_zero_warn('`training_end` was deprecated in 0.7.0 and will be removed in 1.0.0.'
                           ' Use training_step_end instead', DeprecationWarning)

        return train_fwd_result

    def update_learning_rates(self, interval: str):
        """Update learning rates.

        Args:
            interval: either 'epoch' or 'step'.
        """
        if not self.lr_schedulers:
            return

        for lr_scheduler in self.lr_schedulers:
            current_idx = self.batch_idx if interval == 'step' else self.current_epoch
            current_idx += 1  # batch and epoch indices both start at 0
            # Take step if call to update_learning_rates matches the interval key and
            # the current step modulo the scheduler's frequency is zero
            if lr_scheduler['interval'] == interval and current_idx % lr_scheduler['frequency'] == 0:
                # If instance of ReduceLROnPlateau, we need to pass validation loss
                if lr_scheduler['reduce_on_plateau']:
                    monitor_key = lr_scheduler['monitor']
                    monitor_val = self.callback_metrics.get(monitor_key)
                    if monitor_val is None:
                        avail_metrics = ','.join(list(self.callback_metrics.keys()))
                        raise MisconfigurationException(
                            f'ReduceLROnPlateau conditioned on metric {monitor_key}'
                            f' which is not available. Available metrics are: {avail_metrics}.'
                            ' Condition can be set using `monitor` key in lr scheduler dict'
                        )
                    lr_scheduler['scheduler'].step(monitor_val)
                else:
                    lr_scheduler['scheduler'].step()

    def call_checkpoint_callback(self):
        if self.checkpoint_callback is not None:
            self.checkpoint_callback.on_validation_end(self, self.get_model())


def _with_is_last(iterable):
    """Pass through values from the given iterable with an added boolean indicating if this is the last item.
    See `https://stackoverflow.com/a/1630350 <https://stackoverflow.com/a/1630350>`_"""
    it = iter(iterable)
    try:
        last = next(it)
    except StopIteration:
        # empty iterable: yield nothing instead of leaking StopIteration (PEP 479)
        return
    for val in it:
        # yield last and has next
        yield last, False
        last = val
    # yield last, no longer has next
    yield last, True