Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Custom metrics from compute_loss in TGA #4905

Closed
wants to merge 12 commits into from
Closed

Conversation

mojtaba-komeili
Copy link
Contributor

Patch description
The compute_loss function in TGA generates the loss_per_token values but they are not accessible from outside this method. This PR adds a handle for calculating custom metrics that require loss_per_token. The added method (custom_loss_metrics) can be overridden by its children to work with the loss_per_token and batch for generating custom metrics.

@mojtaba-komeili mojtaba-komeili changed the title Custom metrics for loss values in TGA Custom metrics from compute_loss in TGA Dec 2, 2022
@@ -34,7 +34,7 @@
from parlai.utils.misc import warn_once
from parlai.utils.io import PathManager
import parlai.utils.logging as logging
from parlai.core.metrics import Metric, SumMetric, AverageMetric, FairseqBleuMetric
from parlai.core.metrics import SumMetric, AverageMetric, FairseqBleuMetric
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just for the lint error.

mojtaba-komeili and others added 2 commits December 5, 2022 13:55
* cat or concat

* back to cat

* Only add the metric if it is not None

* lint
* zero3 init commit

* minor cleanup:

* handle mpeval

* remove fairscale dependence

* fsdp avail

* update reqs

* better reqs

* autoformat

* autofromat
@@ -715,6 +715,12 @@ def _encoder_input(self, batch):
"""
return self._model_input(batch)

def custom_loss_metrics(self, batch, loss_per_token):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

super nit: can we call this something more related to what it's doing?

e.g., compute_per_token_metrics?

dependabot bot and others added 5 commits December 5, 2022 14:47
Bumps [decode-uri-component](https://github.com/SamVerschueren/decode-uri-component) from 0.2.0 to 0.2.2.
- [Release notes](https://github.com/SamVerschueren/decode-uri-component/releases)
- [Commits](SamVerschueren/decode-uri-component@v0.2.0...v0.2.2)

---
updated-dependencies:
- dependency-name: decode-uri-component
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
…o compute-loss

* 'compute-loss' of github.com:facebookresearch/ParlAI:
  lint
  added the custom_loss_metrics
@klshuster klshuster self-requested a review December 13, 2022 16:24
klshuster and others added 3 commits December 13, 2022 16:16
@mojtaba-komeili
Copy link
Contributor Author

Rebase added some unwanted changes to this PR. Closing it and opening a new clean one.

@mojtaba-komeili mojtaba-komeili deleted the compute-loss branch December 15, 2022 01:38
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants