Skip to content

Commit 5e8de9a

Browse files
committed
make-private
make attributes private to the class
1 parent dd16af3 commit 5e8de9a

File tree

2 files changed

+18
-18
lines changed

2 files changed

+18
-18
lines changed

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ def on_train_batch_end(self, trainer, pl_module, *args, **kwargs) -> None:
205205
if self._should_skip_saving_checkpoint(trainer):
206206
return
207207
step = trainer.global_step
208-
skip_batch = self.every_n_train_steps < 1 or ((step + 1) % self.every_n_train_steps != 0)
208+
skip_batch = self._every_n_train_steps < 1 or ((step + 1) % self._every_n_train_steps != 0)
209209
if skip_batch:
210210
return
211211
self.save_checkpoint(trainer, pl_module)
@@ -215,8 +215,8 @@ def on_validation_end(self, trainer, pl_module):
215215
checkpoints can be saved at the end of the val loop
216216
"""
217217
skip = (
218-
self._should_skip_saving_checkpoint(trainer) or self.every_n_val_epochs is None
219-
or self.every_n_val_epochs < 1 or (trainer.current_epoch + 1) % self.every_n_val_epochs != 0
218+
self._should_skip_saving_checkpoint(trainer) or self._every_n_val_epochs < 1
219+
or (trainer.current_epoch + 1) % self._every_n_val_epochs != 0
220220
)
221221
if skip:
222222
return
@@ -279,18 +279,18 @@ def _should_skip_saving_checkpoint(self, trainer) -> bool:
279279
def __validate_init_configuration(self):
280280
if self.save_top_k is not None and self.save_top_k < -1:
281281
raise MisconfigurationException(f'Invalid value for save_top_k={self.save_top_k}. Must be None or >= -1')
282-
if self.every_n_train_steps < 0:
282+
if self._every_n_train_steps < 0:
283283
raise MisconfigurationException(
284-
f'Invalid value for every_n_train_steps={self.every_n_train_steps}. Must be >= 0'
284+
f'Invalid value for every_n_train_steps={self._every_n_train_steps}. Must be >= 0'
285285
)
286-
if self.every_n_val_epochs < 0:
286+
if self._every_n_val_epochs < 0:
287287
raise MisconfigurationException(
288-
f'Invalid value for every_n_val_epochs={self.every_n_val_epochs}. Must be >= 0'
288+
f'Invalid value for every_n_val_epochs={self._every_n_val_epochs}. Must be >= 0'
289289
)
290-
if self.every_n_train_steps > 0 and self.every_n_val_epochs > 0:
290+
if self._every_n_train_steps > 0 and self._every_n_val_epochs > 0:
291291
raise MisconfigurationException(
292-
f'Invalid values for every_n_train_steps={self.every_n_train_steps}'
293-
' and every_n_val_epochs={self.every_n_val_epochs}.'
292+
f'Invalid values for every_n_train_steps={self._every_n_train_steps}'
293+
' and every_n_val_epochs={self._every_n_val_epochs}.'
294294
'Both cannot be enabled at the same time.'
295295
)
296296
if self.monitor is None:
@@ -346,22 +346,22 @@ def __init_triggers(
346346
# Default to running once after each validation epoch if neither
347347
# every_n_train_steps nor every_n_val_epochs is set
348348
if every_n_train_steps is None and every_n_val_epochs is None:
349-
self.every_n_val_epochs = 1
350-
self.every_n_train_steps = 0
349+
self._every_n_val_epochs = 1
350+
self._every_n_train_steps = 0
351351
log.debug("Both every_n_train_steps and every_n_val_epochs are not set. Setting every_n_val_epochs=1")
352352
else:
353-
self.every_n_val_epochs = every_n_val_epochs or 0
354-
self.every_n_train_steps = every_n_train_steps or 0
353+
self._every_n_val_epochs = every_n_val_epochs or 0
354+
self._every_n_train_steps = every_n_train_steps or 0
355355

356356
# period takes precedence over every_n_val_epochs for backwards compatibility
357357
if period is not None:
358358
rank_zero_warn(
359359
'Argument `period` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.'
360360
' Please use `every_n_val_epochs` instead.', DeprecationWarning
361361
)
362-
self.every_n_val_epochs = period
362+
self._every_n_val_epochs = period
363363

364-
self._period = self.every_n_val_epochs
364+
self._period = self._every_n_val_epochs
365365

366366
@property
367367
def period(self) -> Optional[int]:

tests/checkpointing/test_model_checkpoint.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -550,8 +550,8 @@ def test_invalid_every_n_train_steps_val_epochs_combination(tmpdir):
550550
def test_none_every_n_train_steps_val_epochs(tmpdir):
551551
checkpoint_callback = ModelCheckpoint(dirpath=tmpdir)
552552
assert checkpoint_callback.period == 1
553-
assert checkpoint_callback.every_n_val_epochs == 1
554-
assert checkpoint_callback.every_n_train_steps == 0
553+
assert checkpoint_callback._every_n_val_epochs == 1
554+
assert checkpoint_callback._every_n_train_steps == 0
555555

556556

557557
def test_model_checkpoint_save_last_none_monitor(tmpdir, caplog):

0 commit comments

Comments
 (0)