"""Callbacks: utilities called at certain points during model training.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import csv
import six
import numpy as np
import time
import json
import warnings
import io
import datetime
from collections import deque
from collections import OrderedDict
try:
    from collections.abc import Iterable  # Python 3.3+
except ImportError:
    from collections import Iterable  # Python 2 fallback
from collections import defaultdict
from keras.utils.generic_utils import Progbar
from keras import backend as K
from keras.engine.training_utils import standardize_input_data
try:
import requests
except ImportError:
requests = None
_TRAIN = 'train'
_TEST = 'test'
_PREDICT = 'predict'
class CallbackList(object):
"""Container abstracting a list of callbacks.
# Arguments
callbacks: List of `Callback` instances.
queue_length: Queue length for keeping
running statistics over callback execution time.
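
    # Example

    A minimal sketch of driving the hooks by hand. This is purely
    illustrative; in practice `Model.fit` builds and drives the
    `CallbackList` itself, and `History` here is just a convenient
    stand-in for any callback:

    ```python
    history = History()
    callbacks = CallbackList([history])
    callbacks.set_params({'epochs': 1, 'verbose': 0})
    callbacks.on_train_begin()
    callbacks.on_epoch_begin(0)
    callbacks.on_epoch_end(0, logs={'loss': 0.5})
    callbacks.on_train_end()
    print(history.history)  # {'loss': [0.5]}
    ```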
"""
def __init__(self, callbacks=None, queue_length=10):
callbacks = callbacks or []
self.callbacks = [c for c in callbacks]
self.queue_length = queue_length
self.params = {}
self.model = None
self._reset_batch_timing()
def _reset_batch_timing(self):
self._delta_t_batch = 0.
self._delta_ts = defaultdict(lambda: deque([], maxlen=self.queue_length))
def append(self, callback):
self.callbacks.append(callback)
def set_params(self, params):
self.params = params
for callback in self.callbacks:
callback.set_params(params)
def set_model(self, model):
self.model = model
for callback in self.callbacks:
callback.set_model(model)
def _call_batch_hook(self, mode, hook, batch, logs=None):
"""Helper function for all batch_{begin | end} methods."""
if not self.callbacks:
return
hook_name = 'on_{mode}_batch_{hook}'.format(mode=mode, hook=hook)
if hook == 'end':
if not hasattr(self, '_t_enter_batch'):
self._t_enter_batch = time.time()
# Batch is ending, calculate batch time
self._delta_t_batch = time.time() - self._t_enter_batch
logs = logs or {}
t_before_callbacks = time.time()
for callback in self.callbacks:
batch_hook = getattr(callback, hook_name)
batch_hook(batch, logs)
self._delta_ts[hook_name].append(time.time() - t_before_callbacks)
delta_t_median = np.median(self._delta_ts[hook_name])
if (self._delta_t_batch > 0. and
delta_t_median > 0.95 * self._delta_t_batch and
delta_t_median > 0.1):
warnings.warn(
'Method (%s) is slow compared '
'to the batch update (%f). Check your callbacks.'
% (hook_name, delta_t_median), RuntimeWarning)
if hook == 'begin':
self._t_enter_batch = time.time()
def _call_begin_hook(self, mode):
"""Helper function for on_{train|test|predict}_begin methods."""
if mode == _TRAIN:
self.on_train_begin()
elif mode == _TEST:
self.on_test_begin()
else:
self.on_predict_begin()
def _call_end_hook(self, mode):
"""Helper function for on_{train|test|predict}_end methods."""
if mode == _TRAIN:
self.on_train_end()
elif mode == _TEST:
self.on_test_end()
else:
self.on_predict_end()
def on_batch_begin(self, batch, logs=None):
self._call_batch_hook(_TRAIN, 'begin', batch, logs=logs)
def on_batch_end(self, batch, logs=None):
self._call_batch_hook(_TRAIN, 'end', batch, logs=logs)
def on_epoch_begin(self, epoch, logs=None):
"""Calls the `on_epoch_begin` methods of its callbacks.
This function should only be called during train mode.
# Arguments
epoch: integer, index of epoch.
            logs: dict, currently no data is passed to this argument for this
                method but that may change in the future.
"""
logs = logs or {}
for callback in self.callbacks:
callback.on_epoch_begin(epoch, logs)
self._reset_batch_timing()
def on_epoch_end(self, epoch, logs=None):
"""Calls the `on_epoch_end` methods of its callbacks.
This function should only be called during train mode.
# Arguments
epoch: integer, index of epoch.
logs: dict, metric results for this training epoch, and for the
validation epoch if validation is performed. Validation result keys
are prefixed with `val_`.
"""
logs = logs or {}
for callback in self.callbacks:
callback.on_epoch_end(epoch, logs)
def on_train_batch_begin(self, batch, logs=None):
"""Calls the `on_train_batch_begin` methods of its callbacks.
# Arguments
batch: integer, index of batch within the current epoch.
logs: dict, has keys `batch` and `size` representing the current
batch number and the size of the batch.
"""
self._call_batch_hook(_TRAIN, 'begin', batch, logs=logs)
def on_train_batch_end(self, batch, logs=None):
"""Calls the `on_train_batch_end` methods of its callbacks.
# Arguments
batch: integer, index of batch within the current epoch.
logs: dict, metric results for this batch.
"""
self._call_batch_hook(_TRAIN, 'end', batch, logs=logs)
def on_test_batch_begin(self, batch, logs=None):
"""Calls the `on_test_batch_begin` methods of its callbacks.
# Arguments
batch: integer, index of batch within the current epoch.
logs: dict, has keys `batch` and `size` representing the current
batch number and the size of the batch.
"""
self._call_batch_hook(_TEST, 'begin', batch, logs=logs)
def on_test_batch_end(self, batch, logs=None):
"""Calls the `on_test_batch_end` methods of its callbacks.
# Arguments
batch: integer, index of batch within the current epoch.
logs: dict, metric results for this batch.
"""
self._call_batch_hook(_TEST, 'end', batch, logs=logs)
def on_predict_batch_begin(self, batch, logs=None):
"""Calls the `on_predict_batch_begin` methods of its callbacks.
# Arguments
batch: integer, index of batch within the current epoch.
logs: dict, has keys `batch` and `size` representing the current
batch number and the size of the batch.
"""
self._call_batch_hook(_PREDICT, 'begin', batch, logs=logs)
def on_predict_batch_end(self, batch, logs=None):
"""Calls the `on_predict_batch_end` methods of its callbacks.
        # Arguments
batch: integer, index of batch within the current epoch.
logs: dict, metric results for this batch.
"""
self._call_batch_hook(_PREDICT, 'end', batch, logs=logs)
def on_train_begin(self, logs=None):
"""Calls the `on_train_begin` methods of its callbacks.
# Arguments
logs: dict, currently no data is passed to this argument for this method
but that may change in the future.
"""
for callback in self.callbacks:
callback.on_train_begin(logs)
def on_train_end(self, logs=None):
"""Calls the `on_train_end` methods of its callbacks.
# Arguments
logs: dict, currently no data is passed to this argument for this method
but that may change in the future.
"""
for callback in self.callbacks:
callback.on_train_end(logs)
def on_test_begin(self, logs=None):
"""Calls the `on_test_begin` methods of its callbacks.
# Arguments
logs: dict, currently no data is passed to this argument for this method
but that may change in the future.
"""
for callback in self.callbacks:
callback.on_test_begin(logs)
def on_test_end(self, logs=None):
"""Calls the `on_test_end` methods of its callbacks.
# Arguments
logs: dict, currently no data is passed to this argument for this method
but that may change in the future.
"""
for callback in self.callbacks:
callback.on_test_end(logs)
def on_predict_begin(self, logs=None):
"""Calls the `on_predict_begin` methods of its callbacks.
# Arguments
logs: dict, currently no data is passed to this argument for this method
but that may change in the future.
"""
for callback in self.callbacks:
callback.on_predict_begin(logs)
def on_predict_end(self, logs=None):
"""Calls the `on_predict_end` methods of its callbacks.
# Arguments
logs: dict, currently no data is passed to this argument for this method
but that may change in the future.
"""
for callback in self.callbacks:
callback.on_predict_end(logs)
def __iter__(self):
return iter(self.callbacks)
class Callback(object):
"""Abstract base class used to build new callbacks.
# Properties
params: dict. Training parameters
            (e.g. verbosity, batch size, number of epochs...).
model: instance of `keras.models.Model`.
Reference of the model being trained.
The `logs` dictionary that callback methods
take as argument will contain keys for quantities relevant to
the current batch or epoch.
Currently, the `.fit()` method of the `Sequential` model class
will include the following quantities in the `logs` that
it passes to its callbacks:
on_epoch_end: logs include `acc` and `loss`, and
optionally include `val_loss`
(if validation is enabled in `fit`), and `val_acc`
(if validation and accuracy monitoring are enabled).
on_batch_begin: logs include `size`,
the number of samples in the current batch.
on_batch_end: logs include `loss`, and optionally `acc`
(if accuracy monitoring is enabled).
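
    # Example

    A minimal sketch of a custom callback. The printed keys are only
    present if the model actually tracks those metrics, and `model`,
    `x_train` and `y_train` are assumed to exist:

    ```python
    class LossPrinter(Callback):
        def on_epoch_end(self, epoch, logs=None):
            logs = logs or {}
            print('Epoch %d: loss=%s, val_loss=%s'
                  % (epoch, logs.get('loss'), logs.get('val_loss')))

    model.fit(x_train, y_train, validation_split=0.1,
              callbacks=[LossPrinter()])
    ```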
"""
def __init__(self):
self.validation_data = None
self.model = None
def set_params(self, params):
self.params = params
def set_model(self, model):
self.model = model
def on_batch_begin(self, batch, logs=None):
"""A backwards compatibility alias for `on_train_batch_begin`."""
def on_batch_end(self, batch, logs=None):
"""A backwards compatibility alias for `on_train_batch_end`."""
def on_epoch_begin(self, epoch, logs=None):
"""Called at the start of an epoch.
Subclasses should override for any actions to run. This function should only
be called during train mode.
# Arguments
epoch: integer, index of epoch.
logs: dict, currently no data is passed to this argument for this method
but that may change in the future.
"""
def on_epoch_end(self, epoch, logs=None):
"""Called at the end of an epoch.
Subclasses should override for any actions to run. This function should only
be called during train mode.
# Arguments
epoch: integer, index of epoch.
logs: dict, metric results for this training epoch, and for the
validation epoch if validation is performed. Validation result keys
are prefixed with `val_`.
"""
def on_train_batch_begin(self, batch, logs=None):
"""Called at the beginning of a training batch in `fit` methods.
Subclasses should override for any actions to run.
# Arguments
batch: integer, index of batch within the current epoch.
logs: dict, has keys `batch` and `size` representing the current
batch number and the size of the batch.
"""
# For backwards compatibility
self.on_batch_begin(batch, logs=logs)
def on_train_batch_end(self, batch, logs=None):
"""Called at the end of a training batch in `fit` methods.
Subclasses should override for any actions to run.
# Arguments
batch: integer, index of batch within the current epoch.
logs: dict, metric results for this batch.
"""
# For backwards compatibility
self.on_batch_end(batch, logs=logs)
def on_test_batch_begin(self, batch, logs=None):
"""Called at the beginning of a batch in `evaluate` methods.
Also called at the beginning of a validation batch in the `fit` methods,
if validation data is provided.
Subclasses should override for any actions to run.
# Arguments
batch: integer, index of batch within the current epoch.
logs: dict, has keys `batch` and `size` representing the current
batch number and the size of the batch.
"""
def on_test_batch_end(self, batch, logs=None):
"""Called at the end of a batch in `evaluate` methods.
Also called at the end of a validation batch in the `fit` methods,
if validation data is provided.
Subclasses should override for any actions to run.
# Arguments
batch: integer, index of batch within the current epoch.
logs: dict, metric results for this batch.
"""
def on_predict_batch_begin(self, batch, logs=None):
"""Called at the beginning of a batch in `predict` methods.
Subclasses should override for any actions to run.
# Arguments
batch: integer, index of batch within the current epoch.
logs: dict, has keys `batch` and `size` representing the current
batch number and the size of the batch.
"""
def on_predict_batch_end(self, batch, logs=None):
"""Called at the end of a batch in `predict` methods.
Subclasses should override for any actions to run.
# Arguments
batch: integer, index of batch within the current epoch.
logs: dict, metric results for this batch.
"""
def on_train_begin(self, logs=None):
"""Called at the beginning of training.
Subclasses should override for any actions to run.
# Arguments
logs: dict, currently no data is passed to this argument for this method
but that may change in the future.
"""
def on_train_end(self, logs=None):
"""Called at the end of training.
Subclasses should override for any actions to run.
# Arguments
logs: dict, currently no data is passed to this argument for this method
but that may change in the future.
"""
def on_test_begin(self, logs=None):
"""Called at the beginning of evaluation or validation.
Subclasses should override for any actions to run.
# Arguments
logs: dict, currently no data is passed to this argument for this method
but that may change in the future.
"""
def on_test_end(self, logs=None):
"""Called at the end of evaluation or validation.
Subclasses should override for any actions to run.
# Arguments
logs: dict, currently no data is passed to this argument for this method
but that may change in the future.
"""
def on_predict_begin(self, logs=None):
"""Called at the beginning of prediction.
Subclasses should override for any actions to run.
# Arguments
logs: dict, currently no data is passed to this argument for this method
but that may change in the future.
"""
def on_predict_end(self, logs=None):
"""Called at the end of prediction.
Subclasses should override for any actions to run.
# Arguments
logs: dict, currently no data is passed to this argument for this method
but that may change in the future.
"""
class BaseLogger(Callback):
"""Callback that accumulates epoch averages of metrics.
This callback is automatically applied to every Keras model.
# Arguments
stateful_metrics: Iterable of string names of metrics that
should *not* be averaged over an epoch.
Metrics in this list will be logged as-is in `on_epoch_end`.
All others will be averaged in `on_epoch_end`.
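
    # Example

    A sketch of the weighted averaging this callback performs; driving
    the hooks manually like this is purely illustrative, since `fit`
    applies a `BaseLogger` automatically:

    ```python
    logger = BaseLogger()
    logger.set_params({'metrics': ['loss']})
    logger.on_epoch_begin(0)
    logger.on_batch_end(0, logs={'size': 2, 'loss': 1.0})
    logger.on_batch_end(1, logs={'size': 3, 'loss': 0.5})
    logs = {}
    logger.on_epoch_end(0, logs)
    # Weighted by batch size: (1.0 * 2 + 0.5 * 3) / (2 + 3) = 0.7
    print(logs['loss'])  # 0.7
    ```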
"""
def __init__(self, stateful_metrics=None):
if stateful_metrics:
self.stateful_metrics = set(stateful_metrics)
else:
self.stateful_metrics = set()
def on_epoch_begin(self, epoch, logs=None):
self.seen = 0
self.totals = {}
def on_batch_end(self, batch, logs=None):
logs = logs or {}
batch_size = logs.get('size', 0)
self.seen += batch_size
for k, v in logs.items():
if k in self.stateful_metrics:
self.totals[k] = v
else:
if k in self.totals:
self.totals[k] += v * batch_size
else:
self.totals[k] = v * batch_size
def on_epoch_end(self, epoch, logs=None):
if logs is not None:
for k in self.params['metrics']:
if k in self.totals:
# Make value available to next callbacks.
if k in self.stateful_metrics:
logs[k] = self.totals[k]
else:
logs[k] = self.totals[k] / self.seen
class TerminateOnNaN(Callback):
"""Callback that terminates training when a NaN loss is encountered.
"""
def on_batch_end(self, batch, logs=None):
logs = logs or {}
loss = logs.get('loss')
if loss is not None:
if np.isnan(loss) or np.isinf(loss):
print('Batch %d: Invalid loss, terminating training' % (batch))
self.model.stop_training = True
class ProgbarLogger(Callback):
"""Callback that prints metrics to stdout.
# Arguments
count_mode: One of "steps" or "samples".
Whether the progress bar should
count samples seen or steps (batches) seen.
stateful_metrics: Iterable of string names of metrics that
should *not* be averaged over an epoch.
Metrics in this list will be logged as-is.
All others will be averaged over time (e.g. loss, etc).
# Raises
ValueError: In case of invalid `count_mode`.
"""
def __init__(self, count_mode='samples',
stateful_metrics=None):
super(ProgbarLogger, self).__init__()
if count_mode == 'samples':
self.use_steps = False
elif count_mode == 'steps':
self.use_steps = True
else:
raise ValueError('Unknown `count_mode`: ' + str(count_mode))
if stateful_metrics:
self.stateful_metrics = set(stateful_metrics)
else:
self.stateful_metrics = set()
def on_train_begin(self, logs=None):
self.verbose = self.params['verbose']
self.epochs = self.params['epochs']
def on_epoch_begin(self, epoch, logs=None):
if self.verbose:
print('Epoch %d/%d' % (epoch + 1, self.epochs))
if self.use_steps:
target = self.params['steps']
else:
target = self.params['samples']
self.target = target
self.progbar = Progbar(target=self.target,
verbose=self.verbose,
stateful_metrics=self.stateful_metrics)
self.seen = 0
def on_batch_begin(self, batch, logs=None):
if self.seen < self.target:
self.log_values = []
def on_batch_end(self, batch, logs=None):
logs = logs or {}
batch_size = logs.get('size', 0)
if self.use_steps:
self.seen += 1
else:
self.seen += batch_size
for k in self.params['metrics']:
if k in logs:
self.log_values.append((k, logs[k]))
# Skip progbar update for the last batch;
# will be handled by on_epoch_end.
if self.verbose and self.seen < self.target:
self.progbar.update(self.seen, self.log_values)
def on_epoch_end(self, epoch, logs=None):
logs = logs or {}
for k in self.params['metrics']:
if k in logs:
self.log_values.append((k, logs[k]))
if self.verbose:
self.progbar.update(self.seen, self.log_values)
class History(Callback):
"""Callback that records events into a `History` object.
This callback is automatically applied to
every Keras model. The `History` object
gets returned by the `fit` method of models.
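
    # Example

    A sketch; `model`, `x_train` and `y_train` are assumed to exist:

    ```python
    history = model.fit(x_train, y_train, epochs=10,
                        validation_split=0.1)
    print(history.history['loss'])      # one value per epoch
    print(history.history['val_loss'])  # present because of the split
    ```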
"""
def on_train_begin(self, logs=None):
self.epoch = []
self.history = {}
def on_epoch_end(self, epoch, logs=None):
logs = logs or {}
self.epoch.append(epoch)
for k, v in logs.items():
self.history.setdefault(k, []).append(v)
class ModelCheckpoint(Callback):
"""Save the model after every epoch.
`filepath` can contain named formatting options,
which will be filled with the values of `epoch` and
keys in `logs` (passed in `on_epoch_end`).
For example: if `filepath` is `weights.{epoch:02d}-{val_loss:.2f}.hdf5`,
then the model checkpoints will be saved with the epoch number and
the validation loss in the filename.
# Arguments
filepath: string, path to save the model file.
monitor: quantity to monitor.
verbose: verbosity mode, 0 or 1.
        save_best_only: if `save_best_only=True`, the model is only saved
            when the monitored quantity improves, so the latest best model
            is never overwritten by a worse one.
save_weights_only: if True, then only the model's weights will be
saved (`model.save_weights(filepath)`), else the full model
is saved (`model.save(filepath)`).
mode: one of {auto, min, max}.
If `save_best_only=True`, the decision
to overwrite the current save file is made
based on either the maximization or the
minimization of the monitored quantity. For `val_acc`,
this should be `max`, for `val_loss` this should
be `min`, etc. In `auto` mode, the direction is
automatically inferred from the name of the monitored quantity.
period: Interval (number of epochs) between checkpoints.
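
    # Example

    A sketch of the `filepath` templating described above; `model`,
    `x_train` and `y_train` are assumed to exist:

    ```python
    checkpoint = ModelCheckpoint('weights.{epoch:02d}-{val_loss:.2f}.hdf5',
                                 monitor='val_loss', save_best_only=True,
                                 mode='min', verbose=1)
    model.fit(x_train, y_train, validation_split=0.1, epochs=20,
              callbacks=[checkpoint])
    ```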
"""
def __init__(self, filepath, monitor='val_loss', verbose=0,
save_best_only=False, save_weights_only=False,
mode='auto', period=1):
super(ModelCheckpoint, self).__init__()
self.monitor = monitor
self.verbose = verbose
self.filepath = filepath
self.save_best_only = save_best_only
self.save_weights_only = save_weights_only
self.period = period
self.epochs_since_last_save = 0
if mode not in ['auto', 'min', 'max']:
            warnings.warn('ModelCheckpoint mode %s is unknown, '
                          'falling back to auto mode.' % (mode),
                          RuntimeWarning)
mode = 'auto'
        if mode == 'min':
            self.monitor_op = np.less
            self.best = np.inf
        elif mode == 'max':
            self.monitor_op = np.greater
            self.best = -np.inf
        else:
            if 'acc' in self.monitor or self.monitor.startswith('fmeasure'):
                self.monitor_op = np.greater
                self.best = -np.inf
            else:
                self.monitor_op = np.less
                self.best = np.inf
def on_epoch_end(self, epoch, logs=None):
logs = logs or {}
self.epochs_since_last_save += 1
if self.epochs_since_last_save >= self.period:
self.epochs_since_last_save = 0
filepath = self.filepath.format(epoch=epoch + 1, **logs)
if self.save_best_only:
current = logs.get(self.monitor)
if current is None:
warnings.warn('Can save best model only with %s available, '
'skipping.' % (self.monitor), RuntimeWarning)
else:
if self.monitor_op(current, self.best):
if self.verbose > 0:
print('\nEpoch %05d: %s improved from %0.5f to %0.5f,'
' saving model to %s'
% (epoch + 1, self.monitor, self.best,
current, filepath))
self.best = current
if self.save_weights_only:
self.model.save_weights(filepath, overwrite=True)
else:
self.model.save(filepath, overwrite=True)
else:
if self.verbose > 0:
print('\nEpoch %05d: %s did not improve from %0.5f' %
(epoch + 1, self.monitor, self.best))
else:
if self.verbose > 0:
print('\nEpoch %05d: saving model to %s' % (epoch + 1, filepath))
if self.save_weights_only:
self.model.save_weights(filepath, overwrite=True)
else:
self.model.save(filepath, overwrite=True)
class EarlyStopping(Callback):
"""Stop training when a monitored quantity has stopped improving.
# Arguments
monitor: quantity to be monitored.
        min_delta: minimum change in the monitored quantity
            to qualify as an improvement, i.e. an absolute
            change of less than `min_delta` will count as
            no improvement.
patience: number of epochs with no improvement
after which training will be stopped.
verbose: verbosity mode.
mode: one of {auto, min, max}. In `min` mode,
training will stop when the quantity
monitored has stopped decreasing; in `max`
mode it will stop when the quantity
monitored has stopped increasing; in `auto`
mode, the direction is automatically inferred
from the name of the monitored quantity.
baseline: Baseline value for the monitored quantity to reach.
Training will stop if the model doesn't show improvement
over the baseline.
restore_best_weights: whether to restore model weights from
the epoch with the best value of the monitored quantity.
If False, the model weights obtained at the last step of
training are used.
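
    # Example

    A sketch: stop after 3 epochs without a `val_loss` improvement of at
    least 0.001, then roll back to the best weights seen. `model`,
    `x_train` and `y_train` are assumed to exist:

    ```python
    early_stopping = EarlyStopping(monitor='val_loss', min_delta=1e-3,
                                   patience=3, restore_best_weights=True,
                                   verbose=1)
    model.fit(x_train, y_train, validation_split=0.1, epochs=100,
              callbacks=[early_stopping])
    ```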
"""
def __init__(self,
monitor='val_loss',
min_delta=0,
patience=0,
verbose=0,
mode='auto',
baseline=None,
restore_best_weights=False):
super(EarlyStopping, self).__init__()
self.monitor = monitor
self.baseline = baseline
self.patience = patience
self.verbose = verbose
self.min_delta = min_delta
self.wait = 0
self.stopped_epoch = 0
self.restore_best_weights = restore_best_weights
self.best_weights = None
if mode not in ['auto', 'min', 'max']:
            warnings.warn('EarlyStopping mode %s is unknown, '
                          'falling back to auto mode.' % mode,
                          RuntimeWarning)
mode = 'auto'
if mode == 'min':
self.monitor_op = np.less
elif mode == 'max':
self.monitor_op = np.greater
else:
if 'acc' in self.monitor:
self.monitor_op = np.greater
else:
self.monitor_op = np.less
if self.monitor_op == np.greater:
self.min_delta *= 1
else:
self.min_delta *= -1
def on_train_begin(self, logs=None):
# Allow instances to be re-used
self.wait = 0
self.stopped_epoch = 0
if self.baseline is not None:
self.best = self.baseline
else:
            self.best = np.inf if self.monitor_op == np.less else -np.inf
def on_epoch_end(self, epoch, logs=None):
current = self.get_monitor_value(logs)
if current is None:
return
if self.monitor_op(current - self.min_delta, self.best):
self.best = current
self.wait = 0
if self.restore_best_weights:
self.best_weights = self.model.get_weights()
else:
self.wait += 1
if self.wait >= self.patience:
self.stopped_epoch = epoch
self.model.stop_training = True
if self.restore_best_weights:
if self.verbose > 0:
print('Restoring model weights from the end of '
'the best epoch')
self.model.set_weights(self.best_weights)
def on_train_end(self, logs=None):
if self.stopped_epoch > 0 and self.verbose > 0:
print('Epoch %05d: early stopping' % (self.stopped_epoch + 1))
def get_monitor_value(self, logs):
monitor_value = logs.get(self.monitor)
if monitor_value is None:
warnings.warn(
'Early stopping conditioned on metric `%s` '
'which is not available. Available metrics are: %s' %
(self.monitor, ','.join(list(logs.keys()))), RuntimeWarning
)
return monitor_value
class RemoteMonitor(Callback):
"""Callback used to stream events to a server.
Requires the `requests` library.
Events are sent to `root + '/publish/epoch/end/'` by default. Calls are
HTTP POST, with a `data` argument which is a
JSON-encoded dictionary of event data.
    If `send_as_json` is set to True, the content type of the request will be
    `"application/json"`. Otherwise the serialized JSON will be sent within a form.
# Arguments
root: String; root url of the target server.
path: String; path relative to `root` to which the events will be sent.
field: String; JSON field under which the data will be stored.
The field is used only if the payload is sent within a form
(i.e. send_as_json is set to False).
headers: Dictionary; optional custom HTTP headers.
        send_as_json: Boolean; whether the request should be sent as
            `"application/json"`.
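
    # Example

    A sketch; the URL is a placeholder for wherever your listener runs,
    and `model`, `x_train` and `y_train` are assumed to exist:

    ```python
    monitor = RemoteMonitor(root='http://localhost:9000',
                            send_as_json=True)
    model.fit(x_train, y_train, callbacks=[monitor])
    ```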
"""
def __init__(self,
root='http://localhost:9000',
path='/publish/epoch/end/',
field='data',
headers=None,
send_as_json=False):
super(RemoteMonitor, self).__init__()
self.root = root
self.path = path
self.field = field
self.headers = headers
self.send_as_json = send_as_json
def on_epoch_end(self, epoch, logs=None):
if requests is None:
raise ImportError('RemoteMonitor requires '
'the `requests` library.')
logs = logs or {}
send = {}
send['epoch'] = epoch
for k, v in logs.items():
if isinstance(v, (np.ndarray, np.generic)):
send[k] = v.item()
else:
send[k] = v
try:
if self.send_as_json:
requests.post(self.root + self.path, json=send, headers=self.headers)
else:
requests.post(self.root + self.path,
{self.field: json.dumps(send)},
headers=self.headers)
except requests.exceptions.RequestException:
warnings.warn('Warning: could not reach RemoteMonitor '
'root server at ' + str(self.root))
class LearningRateScheduler(Callback):
"""Learning rate scheduler.
# Arguments
schedule: a function that takes an epoch index as input
(integer, indexed from 0) and current learning rate
and returns a new learning rate as output (float).
verbose: int. 0: quiet, 1: update messages.
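
    # Example

    A sketch of a step-decay schedule: halve the learning rate every 10
    epochs, starting from whatever rate the optimizer currently has.
    `model`, `x_train` and `y_train` are assumed to exist:

    ```python
    def step_decay(epoch, lr):
        if epoch > 0 and epoch % 10 == 0:
            return lr * 0.5
        return lr

    scheduler = LearningRateScheduler(step_decay, verbose=1)
    model.fit(x_train, y_train, epochs=50, callbacks=[scheduler])
    ```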
"""
def __init__(self, schedule, verbose=0):
super(LearningRateScheduler, self).__init__()
self.schedule = schedule
self.verbose = verbose
def on_epoch_begin(self, epoch, logs=None):
if not hasattr(self.model.optimizer, 'lr'):
            raise ValueError('Optimizer must have an "lr" attribute.')
lr = float(K.get_value(self.model.optimizer.lr))
try: # new API
lr = self.schedule(epoch, lr)
except TypeError: # old API for backward compatibility
lr = self.schedule(epoch)
if not isinstance(lr, (float, np.float32, np.float64)):
            raise ValueError('The output of the "schedule" function '
                             'should be a float.')
K.set_value(self.model.optimizer.lr, lr)
if self.verbose > 0:
print('\nEpoch %05d: LearningRateScheduler setting learning '
'rate to %s.' % (epoch + 1, lr))
def on_epoch_end(self, epoch, logs=None):
logs = logs or {}
logs['lr'] = K.get_value(self.model.optimizer.lr)
class TensorBoard(Callback):
"""TensorBoard basic visualizations.
[TensorBoard](https://www.tensorflow.org/guide/summaries_and_tensorboard)
is a visualization tool provided with TensorFlow.
This callback writes a log for TensorBoard, which allows
you to visualize dynamic graphs of your training and test
metrics, as well as activation histograms for the different
layers in your model.
If you have installed TensorFlow with pip, you should be able
to launch TensorBoard from the command line:
```sh
tensorboard --logdir=/full_path_to_your_logs
```
When using a backend other than TensorFlow, TensorBoard will still work
(if you have TensorFlow installed), but the only feature available will
be the display of the losses and metrics plots.
# Arguments
log_dir: the path of the directory where to save the log
files to be parsed by TensorBoard.
histogram_freq: frequency (in epochs) at which to compute activation
and weight histograms for the layers of the model. If set to 0,
histograms won't be computed. Validation data (or split) must be
specified for histogram visualizations.
batch_size: size of batch of inputs to feed to the network
for histograms computation.
write_graph: whether to visualize the graph in TensorBoard.
The log file can become quite large when
write_graph is set to True.
write_grads: whether to visualize gradient histograms in TensorBoard.
`histogram_freq` must be greater than 0.
write_images: whether to write model weights to visualize as
image in TensorBoard.
embeddings_freq: frequency (in epochs) at which selected embedding
layers will be saved. If set to 0, embeddings won't be computed.
Data to be visualized in TensorBoard's Embedding tab must be passed
as `embeddings_data`.
        embeddings_layer_names: a list of names of layers to keep an eye on.
            If None or an empty list, all the embedding layers will be watched.
        embeddings_metadata: a dictionary which maps layer names to file names
            in which the metadata for each embedding layer is saved. See the
            [details](https://www.tensorflow.org/guide/embedding#metadata)
            about the metadata file format. If the same metadata file is
            used for all embedding layers, a single string can be passed.
embeddings_data: data to be embedded at layers specified in
`embeddings_layer_names`. Numpy array (if the model has a single
input) or list of Numpy arrays (if the model has multiple inputs).
Learn [more about embeddings](
https://www.tensorflow.org/guide/embedding).
        update_freq: `'batch'` or `'epoch'` or integer. When using `'batch'`,
            writes the losses and metrics to TensorBoard after each batch.
            The same applies for `'epoch'`. If using an integer, let's say
            `10000`, the callback will write the metrics and losses to
            TensorBoard every 10000 samples. Note that writing too
            frequently to TensorBoard can slow down your training.