Resolve FitLoop setter TODOs #16803

Merged · 3 commits · Feb 20, 2023
Changes from all commits
2 changes: 2 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -326,6 +326,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed `ProgressBarBase.{train_batch_idx,val_batch_idx,test_batch_idx,predict_batch_idx}` properties ([#16760](https://github.com/Lightning-AI/lightning/pull/16760))


- Removed the `fit_loop.{min,max}_steps` setters ([#16803](https://github.com/Lightning-AI/lightning/pull/16803))


- Removed the `Trainer(track_grad_norm=...)` argument ([#16745](https://github.com/Lightning-AI/lightning/pull/16745))

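For readers of the changelog entry above, a minimal migration sketch (assuming an already constructed Trainer named `trainer`; the commented-out assignment is the pattern this PR removes):

from lightning.pytorch import Trainer

trainer = Trainer(max_epochs=1)
# Before this PR, FitLoop exposed a setter that forwarded to the epoch loop:
# trainer.fit_loop.max_steps = 100
# After this PR, write to the epoch loop directly, as the updated call sites below do:
trainer.fit_loop.epoch_loop.max_steps = 100
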
16 changes: 0 additions & 16 deletions src/lightning/pytorch/loops/fit_loop.py
@@ -102,27 +102,11 @@ def min_steps(self) -> Optional[int]:
"""Returns the minimum number of steps to run."""
return self.epoch_loop.min_steps

@min_steps.setter
def min_steps(self, value: Optional[int]) -> None:
"""Sets the minimum number of steps (forwards to epoch_loop)"""
# TODO: This setter is required by debugging connector (fast dev run), should be avoided
self.epoch_loop.min_steps = value

@property
def max_steps(self) -> int:
"""Returns the maximum number of steps to run."""
return self.epoch_loop.max_steps

@max_steps.setter
def max_steps(self, value: int) -> None:
"""Sets the maximum number of steps (forwards to epoch_loop)"""
# TODO: This setter is required by debugging connector (fast dev run), should be avoided
if value < -1:
raise MisconfigurationException(
f"`max_steps` must be a non-negative integer or -1 (infinite steps). You passed in {value}."
)
self.epoch_loop.max_steps = value

@_Loop.restarting.setter
def restarting(self, restarting: bool) -> None:
# if the last epoch completely finished, we are not actually restarting
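
With the setters deleted, `min_steps` and `max_steps` remain read-only properties on `_FitLoop` that forward to the epoch loop, so assigning to them no longer works. A small sketch of the resulting behavior (assuming a constructed `trainer` as above):

current = trainer.fit_loop.max_steps           # still readable; forwards to epoch_loop.max_steps
try:
    trainer.fit_loop.max_steps = 5             # the property no longer has a setter
except AttributeError:
    trainer.fit_loop.epoch_loop.max_steps = 5  # write to the epoch loop instead
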
2 changes: 1 addition & 1 deletion src/lightning/pytorch/trainer/setup.py
@@ -72,7 +72,7 @@ def _init_debugging_flags(

trainer.limit_test_batches = num_batches
trainer.limit_predict_batches = num_batches
-trainer.fit_loop.max_steps = num_batches
+trainer.fit_loop.epoch_loop.max_steps = num_batches
trainer.num_sanity_val_steps = 0
trainer.fit_loop.max_epochs = 1
trainer.val_check_interval = 1.0
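
For context, this hunk is in the fast_dev_run path of the debugging flags, which now cap the epoch loop directly instead of going through the removed FitLoop setter. A usage sketch (the commented internals paraphrase this hunk and are not the full implementation):

from lightning.pytorch import Trainer

# fast_dev_run=True runs a single batch through fit/validate/test as a smoke test;
# internally the debugging flags now do roughly:
#     trainer.fit_loop.epoch_loop.max_steps = num_batches
#     trainer.fit_loop.max_epochs = 1
trainer = Trainer(fast_dev_run=True)
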
4 changes: 2 additions & 2 deletions src/lightning/pytorch/tuner/batch_size_scaling.py
@@ -129,7 +129,7 @@ def __scale_batch_reset_params(trainer: "pl.Trainer", steps_per_trial: int) -> N
if isinstance(loop, pl.loops._FitLoop):
trainer.limit_train_batches = 1.0
trainer.limit_val_batches = steps_per_trial
-trainer.fit_loop.max_steps = steps_per_trial
+trainer.fit_loop.epoch_loop.max_steps = steps_per_trial
elif isinstance(loop, pl.loops._EvaluationLoop):
stage = trainer.state.stage
assert stage is not None
@@ -145,7 +145,7 @@ def __scale_batch_restore_params(trainer: "pl.Trainer", params: Dict[str, Any])
loop = trainer._active_loop
assert loop is not None
if isinstance(loop, pl.loops._FitLoop):
-loop.max_steps = params["max_steps"]
+loop.epoch_loop.max_steps = params["max_steps"]
trainer.limit_train_batches = params["limit_train_batches"]
trainer.limit_val_batches = params["limit_val_batches"]
elif isinstance(loop, pl.loops._EvaluationLoop):
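
Both hunks follow the same save/restore pattern: cap the epoch loop for the duration of a trial, then put the original value back. A minimal sketch of that pattern (the helper name and the trial body are hypothetical, not part of the tuner API):

from typing import Any, Dict

import lightning.pytorch as pl


def _capped_trial(trainer: "pl.Trainer", steps_per_trial: int) -> None:
    # hypothetical helper mirroring __scale_batch_reset_params / __scale_batch_restore_params
    saved: Dict[str, Any] = {"max_steps": trainer.fit_loop.epoch_loop.max_steps}
    trainer.fit_loop.epoch_loop.max_steps = steps_per_trial  # cap this trial
    try:
        ...  # run one batch-size trial here
    finally:
        trainer.fit_loop.epoch_loop.max_steps = saved["max_steps"]  # restore afterwards
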
6 changes: 3 additions & 3 deletions src/lightning/pytorch/tuner/lr_finder.py
@@ -320,7 +320,7 @@ def __lr_finder_reset_params(trainer: "pl.Trainer", num_training: int, early_sto
# No logging
trainer.logger = DummyLogger() if trainer.logger is not None else None
# Max step set to number of iterations starting at current number of iterations
-trainer.fit_loop.max_steps = num_training + trainer.global_step
+trainer.fit_loop.epoch_loop.max_steps = num_training + trainer.global_step
trainer.limit_val_batches = num_training


@@ -329,10 +329,10 @@ def __lr_finder_restore_params(trainer: "pl.Trainer", params: Dict[str, Any]) ->
trainer.strategy.lr_scheduler_configs = params["lr_scheduler_configs"]
trainer.callbacks = params["callbacks"]
trainer.loggers = params["loggers"]
-trainer.fit_loop.max_steps = params["max_steps"]
+loop = trainer.fit_loop
+loop.epoch_loop.max_steps = params["max_steps"]
trainer.limit_val_batches = params["limit_val_batches"]

-loop = trainer.fit_loop
loop.load_state_dict(deepcopy(params["loop_state_dict"]))
loop.restarting = False
trainer.should_stop = False
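
The reset sets the cap relative to the current global step, so the LR finder runs exactly `num_training` additional steps from wherever training currently is. A small worked example with illustrative numbers:

num_training = 100     # learning-rate candidates to evaluate
global_step = 40       # optimizer steps already taken before lr_find
# trainer.fit_loop.epoch_loop.max_steps = num_training + trainer.global_step
max_steps = num_training + global_step
assert max_steps == 140  # i.e. 100 more steps starting from step 40
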
4 changes: 2 additions & 2 deletions tests/tests_pytorch/loops/test_training_loop.py
@@ -152,11 +152,11 @@ def test_fit_loop_done_log_messages(caplog):
epoch_loop = Mock()
epoch_loop.global_step = 10
fit_loop.epoch_loop = epoch_loop
-fit_loop.max_steps = 10
+epoch_loop.max_steps = 10
assert fit_loop.done
assert "max_steps=10` reached" in caplog.text
caplog.clear()
-fit_loop.max_steps = 20
+epoch_loop.max_steps = 20

fit_loop.epoch_progress.current.processed = 3
fit_loop.max_epochs = 3
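
The test exercises the `done` check through a mocked epoch loop: the fit loop is considered done once the epoch loop's global step reaches its `max_steps` budget. A simplified stand-in for that condition (not the full `_FitLoop.done` implementation):

def _max_steps_reached(global_step: int, max_steps: int) -> bool:
    # simplified: -1 means "no step limit", otherwise stop once the budget is spent
    return max_steps != -1 and global_step >= max_steps


assert _max_steps_reached(global_step=10, max_steps=10)      # produces the "max_steps=10` reached" message
assert not _max_steps_reached(global_step=10, max_steps=20)  # raising the cap keeps the loop running
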
2 changes: 1 addition & 1 deletion tests/tests_pytorch/trainer/test_dataloaders.py
@@ -1184,7 +1184,7 @@ def test_dataloaders_reset_and_attach(tmpdir):
assert trainer.train_dataloader.dataset is dataloader_0.dataset
assert trainer.val_dataloaders[0].dataset is dataloader_1.dataset
# 2nd fit
-trainer.fit_loop.max_steps += 1
+trainer.fit_loop.epoch_loop.max_steps += 1
trainer.fit(model, train_dataloaders=dataloader_2, val_dataloaders=dataloader_3)
assert trainer.train_dataloader.dataset is dataloader_2.dataset
assert trainer.val_dataloaders[0].dataset is dataloader_3.dataset