Skip to content

Commit ab4012d

Browse files
committed
make-private
make attributes private to the class
1 parent 1e7f640 commit ab4012d

File tree

2 files changed

+18
-18
lines changed

2 files changed

+18
-18
lines changed

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ def on_train_batch_end(self, trainer, pl_module, *args, **kwargs) -> None:
217217
if self._should_skip_saving_checkpoint(trainer):
218218
return
219219
step = trainer.global_step
220-
skip_batch = self.every_n_train_steps < 1 or ((step + 1) % self.every_n_train_steps != 0)
220+
skip_batch = self._every_n_train_steps < 1 or ((step + 1) % self._every_n_train_steps != 0)
221221
if skip_batch:
222222
return
223223
self.save_checkpoint(trainer, pl_module)
@@ -227,8 +227,8 @@ def on_validation_end(self, trainer, pl_module):
227227
checkpoints can be saved at the end of the val loop
228228
"""
229229
skip = (
230-
self._should_skip_saving_checkpoint(trainer) or self.every_n_val_epochs is None
231-
or self.every_n_val_epochs < 1 or (trainer.current_epoch + 1) % self.every_n_val_epochs != 0
230+
self._should_skip_saving_checkpoint(trainer) or self._every_n_val_epochs < 1
231+
or (trainer.current_epoch + 1) % self._every_n_val_epochs != 0
232232
)
233233
if skip:
234234
return
@@ -291,18 +291,18 @@ def _should_skip_saving_checkpoint(self, trainer) -> bool:
291291
def __validate_init_configuration(self):
292292
if self.save_top_k is not None and self.save_top_k < -1:
293293
raise MisconfigurationException(f'Invalid value for save_top_k={self.save_top_k}. Must be None or >= -1')
294-
if self.every_n_train_steps < 0:
294+
if self._every_n_train_steps < 0:
295295
raise MisconfigurationException(
296-
f'Invalid value for every_n_train_steps={self.every_n_train_steps}. Must be >= 0'
296+
f'Invalid value for every_n_train_steps={self._every_n_train_steps}. Must be >= 0'
297297
)
298-
if self.every_n_val_epochs < 0:
298+
if self._every_n_val_epochs < 0:
299299
raise MisconfigurationException(
300-
f'Invalid value for every_n_val_epochs={self.every_n_val_epochs}. Must be >= 0'
300+
f'Invalid value for every_n_val_epochs={self._every_n_val_epochs}. Must be >= 0'
301301
)
302-
if self.every_n_train_steps > 0 and self.every_n_val_epochs > 0:
302+
if self._every_n_train_steps > 0 and self._every_n_val_epochs > 0:
303303
raise MisconfigurationException(
304-
f'Invalid values for every_n_train_steps={self.every_n_train_steps}'
305-
' and every_n_val_epochs={self.every_n_val_epochs}.'
304+
f'Invalid values for every_n_train_steps={self._every_n_train_steps}'
305+
' and every_n_val_epochs={self._every_n_val_epochs}.'
306306
'Both cannot be enabled at the same time.'
307307
)
308308
if self.monitor is None:
@@ -358,22 +358,22 @@ def __init_triggers(
358358
# Default to running once after each validation epoch if neither
359359
# every_n_train_steps nor every_n_val_epochs is set
360360
if every_n_train_steps is None and every_n_val_epochs is None:
361-
self.every_n_val_epochs = 1
362-
self.every_n_train_steps = 0
361+
self._every_n_val_epochs = 1
362+
self._every_n_train_steps = 0
363363
log.debug("Both every_n_train_steps and every_n_val_epochs are not set. Setting every_n_val_epochs=1")
364364
else:
365-
self.every_n_val_epochs = every_n_val_epochs or 0
366-
self.every_n_train_steps = every_n_train_steps or 0
365+
self._every_n_val_epochs = every_n_val_epochs or 0
366+
self._every_n_train_steps = every_n_train_steps or 0
367367

368368
# period takes precedence over every_n_val_epochs for backwards compatibility
369369
if period is not None:
370370
rank_zero_warn(
371371
'Argument `period` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.'
372372
' Please use `every_n_val_epochs` instead.', DeprecationWarning
373373
)
374-
self.every_n_val_epochs = period
374+
self._every_n_val_epochs = period
375375

376-
self._period = self.every_n_val_epochs
376+
self._period = self._every_n_val_epochs
377377

378378
@property
379379
def period(self) -> Optional[int]:

tests/checkpointing/test_model_checkpoint.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -559,8 +559,8 @@ def test_invalid_every_n_train_steps_val_epochs_combination(tmpdir):
559559
def test_none_every_n_train_steps_val_epochs(tmpdir):
560560
checkpoint_callback = ModelCheckpoint(dirpath=tmpdir)
561561
assert checkpoint_callback.period == 1
562-
assert checkpoint_callback.every_n_val_epochs == 1
563-
assert checkpoint_callback.every_n_train_steps == 0
562+
assert checkpoint_callback._every_n_val_epochs == 1
563+
assert checkpoint_callback._every_n_train_steps == 0
564564

565565

566566
def test_model_checkpoint_save_last_none_monitor(tmpdir, caplog):

0 commit comments

Comments
 (0)