Add Metric.from_mask helper method (#3411) #4894
This looks like it fixes that old issue! Actually @mojtaba-komeili looks like he may be able to use this.
Leaving one comment to see if we can go even one step cleaner
parlai/core/metrics.py
"""
tokens_per_ex = mask.long().sum(dim=-1)
metric_per_ex = (metric_per_token * mask).sum(dim=-1)
metrics = MyMetric.many(metric_per_ex, tokens_per_ex)
Can we use cls.many instead and get away without passing MyMetric as an extra parameter?
Ooh, yes! Good idea.
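A minimal sketch of the suggestion above, under stated assumptions: SketchMetric and SketchAverage are hypothetical stand-ins (not ParlAI's Metric classes), and plain Python lists stand in for the tensors in the diff. It only illustrates why a classmethod that calls cls.many needs no explicit metric-class argument: cls is already bound to whichever subclass from_mask is called on.

```python
class SketchMetric:
    """Hypothetical stand-in for a metric base class; not ParlAI's Metric."""

    def __init__(self, numer, denom):
        self.numer = numer
        self.denom = denom

    def value(self):
        return self.numer / self.denom

    @classmethod
    def many(cls, numers, denoms):
        # One metric object per example, built from the calling class.
        return [cls(n, d) for n, d in zip(numers, denoms)]

    @classmethod
    def from_mask(cls, metric_per_token, mask):
        # List analogue of mask.long().sum(dim=-1)
        tokens_per_ex = [sum(row) for row in mask]
        # List analogue of (metric_per_token * mask).sum(dim=-1)
        metric_per_ex = [
            sum(m * k for m, k in zip(metric_row, mask_row))
            for metric_row, mask_row in zip(metric_per_token, mask)
        ]
        # cls is the subclass the caller used, so no extra parameter is needed.
        return cls.many(metric_per_ex, tokens_per_ex)


class SketchAverage(SketchMetric):
    """Hypothetical subclass, used only to show that cls resolves to it."""


# Illustrative values: a 0 in the mask marks a padding position to ignore.
loss_per_token = [[0.5, 1.5, 9.0], [2.0, 9.0, 9.0]]
mask = [[1, 1, 0], [1, 0, 0]]
metrics = SketchAverage.from_mask(loss_per_token, mask)
```

Each element of metrics is a SketchAverage covering only the unmasked tokens of one example. The sketch does not handle an example whose mask is all zeros.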
lgtm!
Sweet! @stephenroller and @klshuster should I wait for / help dig into the lint and CircleCI failures, or go ahead and land the PR?
Landed! (Discussed with @klshuster offline and it seems that the lint and CircleCI failures are not related to this PR)
Patch description
This change introduces a new from_mask helper function in the Metric class. It also refactors the compute_loss function in torch_generator_agent.py to call the from_mask helper when computing the loss, ppl, and token_acc.
Testing steps
Unit tests
pytest -v tests/test_metrics.py
Note that I added two new unit tests, test_average_metric_from_mask and test_ppl_metric_from_mask.
Manual logging
I manually verified loss, ppl, token_acc, and token_em in torch_generator_agent.py by logging both the original and new values like this and running the below command:
We can see that old_loss equals loss, old_ppl equals ppl, etc.
Other information
autoformat.sh
but it was running quite slowly... Tips on how to use this script would be helpful :)