
Fix bug in BERTScore calculation: pred/target misalignment #2347

Merged: 17 commits merged into Lightning-AI:master on Jul 19, 2024

Conversation

@gxy-gxy (Contributor) commented on Feb 3, 2024

Fixes a bug in the BERTScore calculation.

This pull request addresses a bug in the TextDataset class in src/torchmetrics/functional/text/helper_embedding_metric.py. The class has a preprocess function that automatically sorts input text by length to make batch encoding more efficient. However, predictions (preds) and targets (targets) are initialized in separate datasets, so each is sorted independently and the original pairing of predictions and targets is lost. Because BERTScore is computed pairwise, this mismatched ordering is a problem: the datasets must be re-aligned to their original order before the scores are computed. The proposed fix restores the prediction and target embeddings to the original input order, preserving their pairing throughout the calculation.

Here is the fixed code:

# Undo the length-based sorting so the embeddings line up pairwise again
preds_embeddings = preds_embeddings[preds_loader.dataset.sorting_indices]
target_embeddings = target_embeddings[target_loader.dataset.sorting_indices]

This change is essential for preserving the integrity of the BertScore evaluation, ensuring that each prediction is accurately compared against its corresponding target.
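
To make the failure mode concrete, here is a minimal, self-contained sketch of the mechanism (toy code with illustrative names such as length_sort, not the actual torchmetrics internals):

# Each dataset sorts its own texts by length, so pair i of the sorted
# lists is no longer pair i of the original inputs.
preds = ["a fairly long prediction sentence", "short pred"]
targets = ["short target", "a fairly long target sentence"]

def length_sort(texts):
    """Sort texts by length; also return the inverse permutation."""
    order = sorted(range(len(texts)), key=lambda i: len(texts[i]))
    inverse = [0] * len(order)
    for new_pos, old_pos in enumerate(order):
        inverse[old_pos] = new_pos
    return [texts[i] for i in order], inverse

sorted_preds, preds_inv = length_sort(preds)
sorted_targets, targets_inv = length_sort(targets)

# The pairing is broken: ("short pred", "short target") was never an input pair.
assert (sorted_preds[0], sorted_targets[0]) != (preds[0], targets[0])

# The fix: index each sorted sequence with its stored indices so the
# original order (and hence the pairing) is restored before scoring.
restored_preds = [sorted_preds[i] for i in preds_inv]
restored_targets = [sorted_targets[i] for i in targets_inv]
assert restored_preds == preds
assert restored_targets == targets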


📚 Documentation preview 📚: https://torchmetrics--2347.org.readthedocs.build/en/2347/

@gxy-gxy (Contributor, Author) commented on Feb 3, 2024

Here is the test code:

from torchmetrics.text.bert import BERTScore

score_model = BERTScore(model_name_or_path='roberta-large', batch_size=2)

# Inputs are paired element-wise as (text1[i], text2[i]);
# the first pair consists of two identical strings.
text1 = ["Claim A from machine", "Claim A from machine"]
text2 = ["Claim A from machine", "Claim B"]
similarities = score_model(text1, text2)
print(similarities)
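
With correctly aligned pairs, the first pair (two identical strings) should receive a near-perfect score and the second pair a clearly lower one; before the fix, the length-based sorting could attach these scores to the wrong pairs.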

@Borda added the bug / fix label on Feb 14, 2024
@Borda added this to the v1.3.x milestone on Feb 19, 2024
@Borda (Member) commented on Mar 15, 2024

@stancld could you help here, pls?

@baskrahmer (Collaborator) left a comment

Good catch! I can reproduce the bug using your test code snippet. The actual metric values appear to be correct, but their ordering is not always valid.

It might be nice to integrate this case into the current test suite, e.g. an assertion that reversing the targets/preds also reverses the scores; see the sketch below.
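
A possible shape for that test (a hedged sketch assuming the BERTScore usage from the snippet above, and that the returned dict maps metric names to tensors):

import torch
from torchmetrics.text.bert import BERTScore

def test_bertscore_pair_order():
    score_model = BERTScore(model_name_or_path='roberta-large', batch_size=2)
    preds = ["Claim A from machine", "Claim A from machine"]
    target = ["Claim A from machine", "Claim B"]

    forward = score_model(preds, target)
    backward = score_model(preds[::-1], target[::-1])

    # Reversing the input pairs should simply reverse the per-pair scores.
    for key in forward:
        assert torch.allclose(forward[key], backward[key].flip(0), atol=1e-4)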

(Review comment on src/torchmetrics/functional/text/bert.py, marked resolved.)
@Borda (Member) commented on Jul 16, 2024

Here is the test code:

from torchmetrics.text.bert import BERTScore

score_model = BERTScore(model_name_or_path='roberta-large', batch_size=2)

text1 = ["Claim A from machine", "Claim A from machine"]
text2 = ["Claim A from machine", "Claim B"]
similarities = score_model(text1, text2)
print(similarities)

@gxy-gxy can you pls add it as a test?

codecov bot commented on Jul 16, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 39%. Comparing base (30b3fd5) to head (dc3b758).

❗ There is a different number of reports uploaded between BASE (30b3fd5) and HEAD (dc3b758): HEAD has 7 fewer uploads than BASE.

Flag            BASE (30b3fd5)   HEAD (dc3b758)
Windows         4                2
cpu             22               21
torch2.4.0+cpu  2                1
python3.11      7                5
torch2.3.0+cpu  3                2
Additional details and impacted files
@@           Coverage Diff            @@
##           master   #2347     +/-   ##
========================================
- Coverage      69%     39%    -30%     
========================================
  Files         316     316             
  Lines       17878   17874      -4     
========================================
- Hits        12329    7030   -5299     
- Misses       5549   10844   +5295     

@Borda requested a review from baskrahmer on Jul 16, 2024

@Borda (Member) left a comment

adding test - #2347 (comment)

@baskrahmer (Collaborator) left a comment

LGTM! Thanks for the addition. I can quickly add the test.

@Borda enabled auto-merge (squash) on Jul 19, 2024

@Borda (Member) commented on Jul 19, 2024

@baskrahmer, would you mind also adding an entry to the changelog?

mergify bot added the ready label on Jul 19, 2024
@Borda merged commit 75c33ea into Lightning-AI:master on Jul 19, 2024
65 checks passed
Borda pushed a commit that referenced this pull request on Aug 2, 2024
* fix pred target misalignment
* add test

---------

Co-authored-by: Xinyan Guan <xinyan@xinyan.local>
Co-authored-by: Bas Krahmer <baskrahmer@gmail.com>
(cherry picked from commit 75c33ea)
Labels: bug / fix (Something isn't working), ready, topic: Text