Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mark Trainer.terminate_on_nan protected and deprecate public property #9849

Merged
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated trainer argument `terminate_on_nan` in favour of `detect_anomaly`([#9175](https://github.com/PyTorchLightning/pytorch-lightning/pull/9175))


- Deprecated `Trainer.terminate_on_nan` public attribute access ([#9849](https://github.com/PyTorchLightning/pytorch-lightning/pull/9849))


- Deprecated `LightningModule.summarize()` in favor of `pytorch_lightning.utilities.model_summary.summarize()`


Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/loops/optimization/optimizer_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ def backward_fn(loss: Tensor) -> None:
self._backward(loss, optimizer, opt_idx)

# check if model weights are nan
if self.trainer.terminate_on_nan:
if self.trainer._terminate_on_nan:
detect_nan_parameters(self.trainer.lightning_module)

return backward_fn
Expand Down Expand Up @@ -460,7 +460,7 @@ def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> Clos

result = ClosureResult.from_training_step_output(training_step_output, self.trainer.accumulate_grad_batches)

if self.trainer.terminate_on_nan:
if self.trainer._terminate_on_nan:
check_finite_loss(result.closure_loss)

if self.trainer.move_metrics_to_cpu:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def on_trainer_init(
f"`track_grad_norm` should be an int, a float or 'inf' (infinity norm). Got {track_grad_norm}."
)

self.trainer.terminate_on_nan = terminate_on_nan
self.trainer._terminate_on_nan = terminate_on_nan
self.trainer.gradient_clip_val = gradient_clip_val
self.trainer.gradient_clip_algorithm = GradClipAlgorithmType(gradient_clip_algorithm.lower())
self.trainer.track_grad_norm = float(track_grad_norm)
27 changes: 20 additions & 7 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1983,13 +1983,6 @@ def _active_loop(self) -> Optional[Union[FitLoop, EvaluationLoop, PredictionLoop
if self.predicting:
return self.predict_loop

@property
def train_loop(self) -> FitLoop:
rank_zero_deprecation(
"`Trainer.train_loop` has been renamed to `Trainer.fit_loop` and will be removed in v1.6."
)
return self.fit_loop

@property
def _ckpt_path(self) -> Optional[str]:
if self.state.fn == TrainerFn.VALIDATING:
Expand Down Expand Up @@ -2039,3 +2032,23 @@ def __getstate__(self):

def __setstate__(self, state):
self.__dict__ = state

@property
def train_loop(self) -> FitLoop:
rank_zero_deprecation(
"`Trainer.train_loop` has been renamed to `Trainer.fit_loop` and will be removed in v1.6."
)
return self.fit_loop

@property
def terminate_on_nan(self) -> bool:
rank_zero_deprecation("`Trainer.terminate_on_nan` is deprecated in v1.5 and will be removed in 1.7.")
return self._terminate_on_nan

@terminate_on_nan.setter
def terminate_on_nan(self, val: bool) -> None:
rank_zero_deprecation(
f"Setting `Trainer.terminate_on_nan = {val}` is deprecated in v1.5 and will be removed in 1.7."
f" Please set `Trainer(detect_anomaly={val})` instead."
)
self._terminate_on_nan = val # : 212
7 changes: 7 additions & 0 deletions tests/deprecated_api/test_remove_1-7.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,13 @@ def test_v1_7_0_trainer_terminate_on_nan(tmpdir, terminate_on_nan):
assert trainer.terminate_on_nan is terminate_on_nan
assert trainer._detect_anomaly is False

trainer = Trainer()
with pytest.deprecated_call(match=r"`Trainer.terminate_on_nan` is deprecated in v1.5"):
_ = trainer.terminate_on_nan

with pytest.deprecated_call(match=r"Setting `Trainer.terminate_on_nan = True` is deprecated in v1.5"):
trainer.terminate_on_nan = True


def test_v1_7_0_deprecated_on_task_dataloader(tmpdir):
class CustomBoringModel(BoringModel):
Expand Down