
Implementation of BARTScore and BLEURT #11

Open
ranjan-p-mishra opened this issue Jan 17, 2025 · 1 comment

Comments

@ranjan-p-mishra

Hi,

I am trying to reproduce the results in the paper, but the implementations of BARTScore and BLEURT do not seem to be provided. I am using the following code to compute BARTScore. Can you confirm that it is implemented correctly?

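```python
import numpy as np
from bart_score import BARTScorer  # bart_score.py from https://github.com/neulab/BARTScore


def BART_score(predictions, references, device='cuda', checkpoint='facebook/bart-large-cnn', bart_model_path=None):
    bart_scorer = BARTScorer(device=device, checkpoint=checkpoint)

    # Optionally load fine-tuned weights (e.g. the bart_score.pth file)
    if bart_model_path:
        bart_scorer.load(path=bart_model_path)

    # BARTScorer.score(srcs, tgts) scores each tgt conditioned on its src,
    # so this computes log p(reference | prediction) for each example
    scores = bart_scorer.score(predictions, references, batch_size=4)

    mean_score = np.mean(scores)
    std_score = np.std(scores)

    return mean_score, std_score
```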
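For reference, as far as I can tell from `bart_score.py` in the neulab repository, `BARTScorer.score(srcs, tgts)` returns, for each pair, the length-normalized log-likelihood of the target $y$ given the source $x$ under the BART model $\theta$ (token log-probabilities averaged over the target length):

```math
\mathrm{BARTScore}(x \rightarrow y) = \frac{1}{|y|} \sum_{t=1}^{|y|} \log p_\theta\left(y_t \mid y_{<t}, x\right)
```

Each score is therefore at most zero, and values closer to zero mean the target is more likely under the model. In the call above, $x$ is the prediction and $y$ is the reference.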

Also, could you kindly provide some details on how the BLEURT score was calculated?

@Martin-qyma
Collaborator

Thank you for your interest in XRec! Below is my implementation; I hope you find it helpful.

```python
import numpy as np
from bleurt import score
from BARTScore_main.bart_score import BARTScorer

# You can find the source code of BLEURT and BARTScore at the following links:
# https://github.com/neulab/BARTScore (remember to download the bart_score.pth file as well)
# https://github.com/google-research/bleurt


def BLEURT_score(predictions, references):
    # Test checkpoint bundled with the BLEURT repository (path relative to the clone)
    checkpoint = "bleurt/bleurt/test_checkpoint"
    scorer = score.BleurtScorer(checkpoint)
    scores = scorer.score(references=references, candidates=predictions)
    # Return the mean and standard deviation of the per-example scores
    return np.mean(scores), np.std(scores)


def BART_score_para(predictions, references):
    # multi_ref_score expects one list of references per prediction
    references = [[element] for element in references]
    bart_scorer = BARTScorer(device="cpu", checkpoint="facebook/bart-large-cnn")
    # Load the paraphrase-finetuned weights
    bart_scorer.load(path="bart_score.pth")
    scores = bart_scorer.multi_ref_score(
        predictions, references, agg="max", batch_size=4
    )
    return np.mean(scores), np.std(scores)


a = "This is a test."
b = "This is another test."
print(BLEURT_score([a], [b]))
print(BART_score_para([a], [b]))
```
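`BART_score_para` wraps each reference in a one-element list because `multi_ref_score` takes a list of references per prediction. The sketch below mirrors, as I understand it from `bart_score.py`, how that method aggregates: it scores every prediction against the i-th reference of each list, then takes the per-example mean or max. With a single reference per example, `agg="max"` just returns that single score. `multi_ref_aggregate` and `toy_score` are hypothetical names; `toy_score` is a word-overlap stand-in for `BARTScorer.score`, used only so the example runs without downloading BART. The final summary uses population statistics, which is what `np.mean` and `np.std` (default `ddof=0`) compute.

```python
from statistics import mean, pstdev
from typing import Callable, List, Sequence

# Hypothetical stand-in signature for BARTScorer.score(srcs, tgts):
# any function returning one float per (src, tgt) pair.
ScoreFn = Callable[[Sequence[str], Sequence[str]], List[float]]


def multi_ref_aggregate(score_fn: ScoreFn, srcs: Sequence[str],
                        tgts: Sequence[Sequence[str]], agg: str = "mean") -> List[float]:
    """Score each src against every reference in its list, then aggregate.

    A sketch of the aggregation logic only, not the library's code.
    """
    ref_nums = {len(refs) for refs in tgts}
    if len(ref_nums) != 1:
        raise ValueError("Every sample must have the same number of references.")
    ref_num = ref_nums.pop()
    # Row i scores every src against its i-th reference
    score_matrix = [score_fn(srcs, [refs[i] for refs in tgts]) for i in range(ref_num)]
    columns = list(zip(*score_matrix))  # one tuple of scores per sample
    if agg == "mean":
        return [mean(col) for col in columns]
    if agg == "max":
        return [max(col) for col in columns]
    raise NotImplementedError(agg)


def toy_score(srcs, tgts):
    """Hypothetical toy scorer: minus the number of reference words missing from the source."""
    return [float(-len(set(t.split()) - set(s.split()))) for s, t in zip(srcs, tgts)]


predictions = ["This is a test.", "A cat sat."]
references = [["This is another test."], ["A cat sat."]]
per_example = multi_ref_aggregate(toy_score, predictions, references, agg="max")
# mean / pstdev match np.mean / np.std with their default ddof=0
print(per_example, mean(per_example), pstdev(per_example))
```

Swapping `toy_score` for a real scorer's `score` method would leave the aggregation unchanged; only the per-pair scores differ.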
