
fix gradient norm tracking for row_log_interval > 1 #3489

Merged
merged 5 commits into from
Sep 15, 2020

Conversation

awaelchli
Contributor

@awaelchli awaelchli commented Sep 14, 2020

What does this PR do?

Fixes #3487

The test fails on master, replicating the bug report.

Before submitting

  • Was this discussed/approved via a GitHub issue? (no need for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together? Otherwise, we ask you to create a separate PR for every change.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?
  • Did you verify new and existing tests pass locally with your changes?
  • If you made a notable change (that affects users), did you update the CHANGELOG?

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

@awaelchli awaelchli added the bug Something isn't working label Sep 14, 2020
@awaelchli
Contributor Author

awaelchli commented Sep 14, 2020

@rohitgr7 @Tim-Chard The test I wrote replicates the bug, but it is not very elegant. I'll try to make it better, but if you have any suggestions, let me know.

EDIT: I have now improved the test.

@codecov

codecov bot commented Sep 14, 2020

Codecov Report

Merging #3489 into master will increase coverage by 5%.
The diff coverage is 100%.

@@           Coverage Diff           @@
##           master   #3489    +/-   ##
=======================================
+ Coverage      86%     91%    +5%     
=======================================
  Files         107     107            
  Lines        8025    8025            
=======================================
+ Hits         6903    7291   +388     
+ Misses       1122     734   -388     

Contributor

@Tim-Chard Tim-Chard left a comment


You beat me to it! I was going to change the test above to include the row_log_interval parameter, something like the code snippet below, but it doesn't work because of the corner cases (the logging of the first and last batch). You could handle them, but then the test would end up just as complex as the code being tested, so I think yours is the way to go.

In case you can think of some simple way to do it, here is the partially modified test:

import os

import numpy as np
import pytest

from pytorch_lightning import Trainer

# `reset_seed` and `ModelWithManualGradTracker` are helpers defined in the
# existing test module (tests/models/test_grad_norm.py)


@pytest.mark.parametrize("norm_type", [1., 1.25, 2, 3, 5, 10, 'inf'])
@pytest.mark.parametrize("log_interval", [1, 2, 20])
def test_grad_tracking(tmpdir, norm_type, log_interval, rtol=5e-3):
    os.environ['PL_DEV_DEBUG'] = '1'

    # rtol=5e-3 respects the 3 decimals rounding in `.grad_norms` and above

    reset_seed()

    # use a custom grad tracking module and a list logger
    model = ModelWithManualGradTracker(norm_type)

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=3,
        track_grad_norm=norm_type,
        row_log_interval=log_interval,
    )
    result = trainer.fit(model)

    assert result == 1, "Training failed"
    logged_metrics = trainer.dev_debugger.logged_metrics
    expected_grad_norms = model.stored_grad_norms[::log_interval]
    # TODO: remove the first batch and include the last batch?

    assert len(logged_metrics) == len(expected_grad_norms)

    # compare the logged metrics against tracked norms on `.backward`
    for mod, log in zip(expected_grad_norms, logged_metrics):
        common = mod.keys() & log.keys()

        log, mod = [log[k] for k in common], [mod[k] for k in common]

        assert np.allclose(log, mod, rtol=rtol)
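The corner case Tim describes can be illustrated with a small standalone sketch. The helper names below are hypothetical, and the sketch assumes the trainer logs on every `log_interval`-th batch plus the last batch of each epoch; it is only an illustration of why a plain `[::log_interval]` slice over the stored norms drifts out of sync with what actually gets logged, not a description of Lightning's exact logging schedule.

```python
def logged_batches(num_batches, log_interval):
    # Assumption: the trainer logs every `log_interval`-th batch and
    # always the last batch of the epoch.
    return [i for i in range(num_batches)
            if (i + 1) % log_interval == 0 or i == num_batches - 1]


def sliced_batches(num_batches, log_interval):
    # Naive expectation used in the snippet above: take every
    # `log_interval`-th stored norm, starting from batch 0.
    return list(range(num_batches))[::log_interval]


# With 10 batches and interval 2, the two selections never overlap:
# the trainer logs batches 1, 3, 5, 7, 9 while the slice picks 0, 2, 4, 6, 8.
# With interval 20 (> epoch length), the trainer still logs the last batch,
# but the slice picks batch 0 instead.
print(logged_batches(10, 2))   # [1, 3, 5, 7, 9]
print(sliced_batches(10, 2))   # [0, 2, 4, 6, 8]
print(logged_batches(10, 20))  # [9]
print(sliced_batches(10, 20))  # [0]
```

This off-by-one alignment (plus the forced last-batch log) is what would have to be special-cased to make the slicing approach work, which is why the merged test avoids it.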

(Two resolved review comments on tests/models/test_grad_norm.py, now outdated.)
Co-authored-by: Tim Chard <timchard@hotmail.com>
@awaelchli
Contributor Author

@Tim-Chard Thanks, yes, the test you mention is a bit complex and partially re-implements the functionality it is supposed to check. It's better not to add the row_log_interval there.

You beat me to it!

sorry, it's my "on call" shift, so I try to deal with these bug reports as fast as possible because of the upcoming release.

@awaelchli awaelchli marked this pull request as ready for review September 14, 2020 03:59
@mergify mergify bot requested a review from a team September 14, 2020 03:59
@Tim-Chard
Contributor

sorry, it's my "on call" shift, so I try to deal with these bug reports as fast as possible because of the upcoming release.

Don't be; it's good to see the project can squash bugs this quickly.

@awaelchli awaelchli requested review from Borda and rohitgr7 September 14, 2020 17:21
@mergify mergify bot requested a review from a team September 14, 2020 17:37
@mergify mergify bot requested a review from a team September 14, 2020 20:26
@mergify
Contributor

mergify bot commented Sep 15, 2020

This pull request is now in conflict... :(

Member

@Borda Borda left a comment


LGTM 🐰

@Borda Borda added the ready PRs ready to be merged label Sep 15, 2020
@Borda Borda merged commit 4ed96b2 into master Sep 15, 2020
@Borda Borda deleted the bugfix/track_grad_norm_interval branch September 15, 2020 16:41
@Borda Borda added this to the 0.9.x milestone Sep 15, 2020
Labels
bug (Something isn't working), ready (PRs ready to be merged)
Development

Successfully merging this pull request may close these issues.

Gradient norms are not logged unless row_log_interval==1
5 participants