Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Custom metrics from compute_loss in TGA (2nd try) #4913

Merged
merged 2 commits into from
Jan 17, 2023
Merged

Conversation

mojtaba-komeili
Copy link
Contributor

Patch description
The compute_loss function in TGA generates the loss_per_token values but they are not accessible from outside this method. This PR adds a handle for calculating custom metrics that require loss_per_token. The added method (custom_loss_metrics) can be overridden by its children to work with the loss_per_token and batch for generating custom metrics.

NOTE: this is cleaned up version of 4905 which became corrupted with other commits after some bad rebase and merge.

@mojtaba-komeili
Copy link
Contributor Author

@klshuster reminding you of this old PR.

Copy link
Contributor

@klshuster klshuster left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for the changes!

@mojtaba-komeili mojtaba-komeili merged commit f249627 into main Jan 17, 2023
@mojtaba-komeili mojtaba-komeili deleted the compute-loss-2 branch January 17, 2023 18:32
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants