fix gradient norm tracking for row_log_interval > 1 #3489
Conversation
@rohitgr7 @Tim-Chard The test I wrote replicates the bug, but it is not very elegant. I'll try to make it better, but if you have any suggestions, let me know. EDIT: I have now improved the test.
Codecov Report
@@            Coverage Diff            @@
##           master    #3489     +/-   ##
=========================================
+ Coverage      86%      91%       +5%
=========================================
  Files         107      107
  Lines        8025     8025
=========================================
+ Hits         6903     7291      +388
+ Misses       1122      734      -388
You beat me to it! I was going to change the test above to include the row_log_interval parameter, roughly as in the code snippet below, but it doesn't work because of the corner cases (the logging of the first and last batch). You could handle them, but that would make the test just as complex as the code being tested, so I think yours is the way to go.
In case you can think of some simple way to do it, here is the partially modified test:
import os

import numpy as np
import pytest

from pytorch_lightning import Trainer

# `reset_seed` and `ModelWithManualGradTracker` are helpers defined in the
# surrounding test module


@pytest.mark.parametrize("norm_type", [1., 1.25, 2, 3, 5, 10, 'inf'])
@pytest.mark.parametrize("log_interval", [1, 2, 20])
def test_grad_tracking(tmpdir, norm_type, log_interval, rtol=5e-3):
    os.environ['PL_DEV_DEBUG'] = '1'
    # rtol=5e-3 respects the 3-decimal rounding in `.grad_norms` and above
    reset_seed()

    # use a custom grad tracking module and a list logger
    model = ModelWithManualGradTracker(norm_type)
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=3,
        track_grad_norm=norm_type,
        row_log_interval=log_interval,
    )
    result = trainer.fit(model)
    assert result == 1, "Training failed"

    logged_metrics = trainer.dev_debugger.logged_metrics
    expected_grad_norms = model.stored_grad_norms[::log_interval]
    # TODO: remove the first batch and include the last batch?
    assert len(logged_metrics) == len(expected_grad_norms)

    # compare the logged metrics against the norms tracked on `.backward`
    for mod, log in zip(expected_grad_norms, logged_metrics):
        common = mod.keys() & log.keys()
        log, mod = [log[k] for k in common], [mod[k] for k in common]
        assert np.allclose(log, mod, rtol=rtol)
Co-authored-by: Tim Chard <timchard@hotmail.com>
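For reference, a minimal sketch of how those corner cases might be handled. The helper below is hypothetical, not part of this PR, and it assumes logging fires every log_interval batches plus on the last batch of each epoch; that assumption is exactly the kind of detail that would need verifying against the trainer code:

def expected_log_indices(num_batches, log_interval):
    # hypothetical helper: which batch indices should have logged grad norms,
    # assuming logging fires when (batch_idx + 1) % log_interval == 0
    idxs = [i for i in range(num_batches) if (i + 1) % log_interval == 0]
    # assumption: the last batch of an epoch is always logged
    if num_batches - 1 not in idxs:
        idxs.append(num_batches - 1)
    return idxs

Building expected_grad_norms from these indices, epoch by epoch, is precisely the extra complexity the comment above wants to keep out of the test.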
@Tim-Chard Thanks. Yes, the test you mention is a bit complex and partially re-implements the functionality it is supposed to check, so it's better not to add row_log_interval there.
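For context, a sketch of what that functionality looks like: manually computing per-parameter gradient norms after .backward(). This is an illustrative stand-in, not the repo's ModelWithManualGradTracker; the grad_{norm_type}_norm_{name} key format and the 3-decimal rounding are assumptions inferred from the test above:

import torch
from torch import nn

def manual_grad_norms(module: nn.Module, norm_type) -> dict:
    # illustrative stand-in, not the repo's ModelWithManualGradTracker;
    # the key format and 3-decimal rounding here are assumptions
    p_norm = float(norm_type)  # handles 'inf' -> float('inf')
    norms = {}
    for name, param in module.named_parameters():
        if param.grad is not None:
            norm = param.grad.detach().norm(p_norm).item()
            norms[f'grad_{norm_type}_norm_{name}'] = round(norm, 3)
    return norms

Calling something like this from on_after_backward and appending the result to a stored_grad_norms list is presumably how the test model builds the reference values the logged metrics are compared against.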
Sorry, it's my "on call" shift, so I'm trying to deal with these bug reports as fast as possible because of the upcoming release.
Don't be; it's good to see the project can squash bugs this quickly.
This pull request is now in conflict... :(
LGTM 🐰
What does this PR do?
Fixes #3487
A test is added that fails on master, replicating the bug report.
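For readers coming from the issue, a hypothetical minimal reproduction; TinyModel and the exact settings below are placeholders, not taken from the PR, assuming the Trainer API of this release:

import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer

class TinyModel(LightningModule):
    # minimal illustrative module, not from this PR
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.cross_entropy(self(x), y)
        return {'loss': loss}

    def train_dataloader(self):
        x = torch.randn(64, 32)
        y = torch.randint(0, 2, (64,))
        return DataLoader(TensorDataset(x, y), batch_size=8)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

model = TinyModel()
trainer = Trainer(track_grad_norm=2, row_log_interval=5, max_epochs=1)
trainer.fit(model)
# before this fix, the grad norms logged every 5th row reportedly disagreed
# with norms computed manually on the corresponding batches (see the test)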
Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.
Did you have fun?
Make sure you had fun coding 🙃