Skip to content

Commit

Permalink
fixed bug where tuner would not tune lr if also tuning batch_size (#4688
Browse files Browse the repository at this point in the history
)

* fixed bug where tuner would not tune lr if also tuning batch_size

* added a '+1' to computing the smoothed loss. This maintains the behavior for the smoothed loss as before the bug fix

* pep8 fix

* add changelog

Co-authored-by: chaton <thomas@grid.ai>
Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
  • Loading branch information
4 people authored Mar 9, 2021
1 parent 9eded7f commit 523c59b
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `Trainer` not resetting `lightning_optimizers` when calling `Trainer.fit()` multiple times ([#6372](https://github.com/PyTorchLightning/pytorch-lightning/pull/6372))


- Fixed an issue where the tuner would not tune the learning rate if also tuning the batch size ([#4688](https://github.com/PyTorchLightning/pytorch-lightning/pull/4688))


## [1.2.2] - 2021-03-02

### Added
Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/tuner/lr_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,11 +418,11 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, data
self.progress_bar.update()

current_loss = trainer.train_loop.running_loss.last().item()
current_step = trainer.global_step + 1 # remove the +1 in 1.0
current_step = trainer.global_step

# Avg loss (loss with momentum) + smoothing
self.avg_loss = self.beta * self.avg_loss + (1 - self.beta) * current_loss
smoothed_loss = self.avg_loss / (1 - self.beta**current_step)
smoothed_loss = self.avg_loss / (1 - self.beta**(current_step + 1))

# Check if we diverging
if self.early_stop_threshold is not None:
Expand Down

0 comments on commit 523c59b

Please sign in to comment.