Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[bugfix] Logging only on not should_accumulate() during training #5417

Merged
merged 6 commits into from
Jan 9, 2021

Conversation

tchaton
Copy link
Contributor

@tchaton tchaton commented Jan 8, 2021

What does this PR do?

This PR fixes visual logging with accumulated_grad_batches > 1.

Solution: We can't assume the metrics can be averaged, so we will log only when optimizer_step will be called.

Previously

Screenshot 2021-01-08 at 13 16 19

Now.

Screenshot 2021-01-08 at 12 55 57

Code used to generate the visualization

def test_logging_with_accumulate_grad_batches(tmpdir):
    class LitClassifier(pl.LightningModule):
        def __init__(self, hidden_dim=128, learning_rate=1e-3):
            super().__init__()
            self.save_hyperparameters()

            self.l1 = torch.nn.Linear(28 * 28, self.hparams.hidden_dim)
            self.l2 = torch.nn.Linear(self.hparams.hidden_dim, 10)

            self.train_acc = pl.metrics.Accuracy()

        def forward(self, x):
            x = x.view(x.size(0), -1)
            x = torch.relu(self.l1(x))
            return self.l2(x)

        def training_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            loss = F.cross_entropy(y_hat, y)
            self.log('train_acc', self.train_acc(y_hat, y), on_step=True, on_epoch=True, prog_bar=True)
            self.log("train_loss",loss)
            return loss

        def validation_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            loss = F.cross_entropy(y_hat, y)

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

    def run_test(model, max_epochs, accumulate_grad_batches, batch_size, num_workers=4):
        dataset = MNIST('', train=True, download=True, transform=transforms.ToTensor())
        mnist_train, mnist_val = random_split(dataset, [55000, 5000])
        train_loader = DataLoader(mnist_train,batch_size)
        val_loader = DataLoader(mnist_val,batch_size)

        trainer = pl.Trainer(
            logger=WandbLogger(name="bug", project='.....', save_dir=".", log_model=False),
            accumulate_grad_batches=accumulate_grad_batches,
            limit_train_batches=100,
            log_every_n_steps=5,
            max_epochs=max_epochs
            )
        trainer.fit(model, train_loader, val_loader)

    model = LitClassifier()

    run_test(model, 3, 1, 32)
    run_test(model, 3, 8, 32)

Fixes #5405 <- this links related issue to this PR

Before submitting

  • Was this discussed/approved via a GitHub issue? (not for typos and docs)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests? (not for typos and docs)
  • Did you verify new and existing tests pass locally with your changes?
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or internal minor changes/refactorings)

PR review

Anyone in the community is free to review the PR once the tests have passed.
Before you start reviewing make sure you have read Review guidelines. In short, see the following bullet-list:

  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified
  • Check that target branch and milestone match!

Did you have fun?

Make sure you had fun coding 🙃

@tchaton tchaton self-assigned this Jan 8, 2021
@tchaton tchaton added logger Related to the Loggers priority: 0 High priority task labels Jan 8, 2021
@tchaton tchaton added this to the 1.1.x milestone Jan 8, 2021
@codecov
Copy link

codecov bot commented Jan 8, 2021

Codecov Report

Merging #5417 (20b1abf) into master (f2e99d6) will not change coverage.
The diff coverage is 100%.

@@          Coverage Diff           @@
##           master   #5417   +/-   ##
======================================
  Coverage      93%     93%           
======================================
  Files         134     134           
  Lines        9996    9996           
======================================
  Hits         9313    9313           
  Misses        683     683           

@tchaton tchaton requested review from awaelchli, Borda and SeanNaren and removed request for awaelchli and Borda January 8, 2021 13:55
@carmocca
Copy link
Contributor

carmocca commented Jan 8, 2021

Not familiar with W&B. Can you see the raw data for the epoch graph? Doesn't look right, does it?

@tchaton tchaton marked this pull request as ready for review January 8, 2021 16:05
@tchaton
Copy link
Contributor Author

tchaton commented Jan 8, 2021

Not familiar with W&B. Can you see the raw data for the epoch graph? Doesn't look right, does it?

It is contracted to log on number of optimizer_step.

Here are the other options which I didn't prefer.

Screenshot 2021-01-08 at 12 58 35

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Copy link
Contributor

@teddykoker teddykoker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great fix :)

Copy link
Member

@Borda Borda left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm, just check back-compatibity and chlog

@@ -158,7 +158,7 @@ def cache_training_step_metrics(self, opt_closure_result):
self.logged_metrics.update(logged_metrics_tmp)
self.cached_results.legacy_batch_log_metrics.update(logged_metrics_tmp)

def log_metrics(self, metrics, grad_norm_dic, step=None, log_train_step_metrics=False):
def log_metrics(self, metrics, grad_norm_dic, step=None):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is it used by a user, right? then let's hold back compatibility with API and add a warning...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is internal. LoggerConnector class

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not it is not.

@tchaton tchaton enabled auto-merge (squash) January 8, 2021 21:54
@tchaton tchaton merged commit a053d75 into master Jan 9, 2021
@tchaton tchaton deleted the bugfix/5405_logging_with_accumulated_gradient branch January 9, 2021 00:35
SeanNaren pushed a commit that referenced this pull request Jan 12, 2021
…5417)

* resolve bug

* resolve tests

* update

* Update tests/loggers/test_tensorboard.py

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

(cherry picked from commit a053d75)
SeanNaren pushed a commit that referenced this pull request Jan 12, 2021
…5417)

* resolve bug

* resolve tests

* update

* Update tests/loggers/test_tensorboard.py

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

(cherry picked from commit a053d75)
SeanNaren pushed a commit that referenced this pull request Jan 13, 2021
…5417)

* resolve bug

* resolve tests

* update

* Update tests/loggers/test_tensorboard.py

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

(cherry picked from commit a053d75)
SeanNaren pushed a commit that referenced this pull request Jan 13, 2021
…5417)

* resolve bug

* resolve tests

* update

* Update tests/loggers/test_tensorboard.py

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

(cherry picked from commit a053d75)
SeanNaren pushed a commit that referenced this pull request Jan 19, 2021
…5417)

* resolve bug

* resolve tests

* update

* Update tests/loggers/test_tensorboard.py

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

(cherry picked from commit a053d75)
Borda pushed a commit that referenced this pull request Jan 23, 2021
…5417)

* resolve bug

* resolve tests

* update

* Update tests/loggers/test_tensorboard.py

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

(cherry picked from commit a053d75)
Borda pushed a commit that referenced this pull request Jan 25, 2021
…5417)

* resolve bug

* resolve tests

* update

* Update tests/loggers/test_tensorboard.py

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

(cherry picked from commit a053d75)
Borda pushed a commit that referenced this pull request Jan 26, 2021
…5417)

* resolve bug

* resolve tests

* update

* Update tests/loggers/test_tensorboard.py

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

(cherry picked from commit a053d75)
Borda pushed a commit that referenced this pull request Jan 26, 2021
…5417)

* resolve bug

* resolve tests

* update

* Update tests/loggers/test_tensorboard.py

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

(cherry picked from commit a053d75)
Borda pushed a commit that referenced this pull request Jan 26, 2021
…5417)

* resolve bug

* resolve tests

* update

* Update tests/loggers/test_tensorboard.py

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

(cherry picked from commit a053d75)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
logger Related to the Loggers priority: 0 High priority task
Projects
None yet
Development

Successfully merging this pull request may close these issues.

W&B logger not working as expected with accumulate_grad_batches>1
6 participants