Thank you for your interest in XRec! Below is my implementation; I hope you find it helpful.
```python
import numpy as np
from bleurt import score
from BARTScore_main.bart_score import BARTScorer

'''
You can find the source code of BLEURT and BARTScore at the following links:
https://github.com/neulab/BARTScore (remember to download the bart_score.pth file as well)
https://github.com/google-research/bleurt
'''


def BLEURT_score(predictions, references):
    # small test checkpoint bundled inside the cloned BLEURT repo
    checkpoint = "bleurt/bleurt/test_checkpoint"
    scorer = score.BleurtScorer(checkpoint)
    # one BLEURT score per (reference, candidate) pair
    scores = scorer.score(references=references, candidates=predictions)
    # return the mean and standard deviation (np.std default: population, ddof=0)
    return np.mean(scores), np.std(scores)


def BART_score_para(predictions, references):
    # multi_ref_score expects one list of references per prediction
    references = [[element] for element in references]
    bart_scorer = BARTScorer(device="cpu", checkpoint="facebook/bart-large-cnn")
    # load the downloaded BARTScore weights on top of bart-large-cnn
    bart_scorer.load(path="bart_score.pth")
    # renamed from `score` to avoid confusion with the imported bleurt module
    scores = bart_scorer.multi_ref_score(
        predictions, references, agg="max", batch_size=4
    )
    return np.mean(scores), np.std(scores)


if __name__ == "__main__":
    a = "This is a test."
    b = "This is another test."
    print(BLEURT_score([a], [b]))
    print(BART_score_para([a], [b]))
```
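One detail in `BART_score_para` is worth spelling out. Each reference is wrapped in its own list because, as far as I can tell from the BARTScore source, `multi_ref_score` scores every prediction against each of its references and then reduces over the references with `agg`. With exactly one reference per prediction, that reduction is a no-op, so `agg="max"` should give the same numbers as scoring the pairs directly. The pure-Python sketch below illustrates only that aggregation step. `multi_ref_aggregate` is a hypothetical name, and this is not BARTScore's actual code.

```python
from statistics import mean
from typing import Callable, Dict, List, Sequence

# Hypothetical helper that mimics the aggregation step of a multi-reference scorer.
# per_ref_scores[i] holds the scores of prediction i against each of its references.
_REDUCERS: Dict[str, Callable[[Sequence[float]], float]] = {"max": max, "mean": mean}


def multi_ref_aggregate(per_ref_scores: Sequence[Sequence[float]], agg: str = "max") -> List[float]:
    """Collapse each prediction's per-reference scores into a single score."""
    if agg not in _REDUCERS:
        raise ValueError(f"unsupported agg: {agg!r}")
    reduce_fn = _REDUCERS[agg]
    return [float(reduce_fn(scores)) for scores in per_ref_scores]


if __name__ == "__main__":
    # With one reference per prediction, "max" and "mean" agree,
    # so the choice of agg does not change the result.
    single_ref = [[-2.5], [-1.0]]
    print(multi_ref_aggregate(single_ref, "max"), multi_ref_aggregate(single_ref, "mean"))
```

With several references per prediction, `"max"` keeps the best-matching reference and `"mean"` averages over all of them, so the choice of `agg` only matters in that case.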
Hi,
I am trying to reproduce the results in the paper, but the implementations of BARTScore and BLEURT do not seem to be provided. I am using the code below to compute BARTScore. Could you confirm whether it is implemented correctly?
```python
def BART_score(predictions, references, device='cuda', checkpoint='facebook/bart-large-cnn', bart_model_path=None):
    bart_scorer = BARTScorer(device=device, checkpoint=checkpoint)
    # ... (remainder of the function not shown here)
```
Also, could you kindly provide some details on how the BLEURT score was calculated?
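On the BLEURT question: in the reply's `BLEURT_score`, BLEURT produces one score per aligned (reference, candidate) pair, and the reported figures are the mean and standard deviation of those per-pair scores. Note that `np.std` defaults to the population standard deviation (`ddof=0`). The sketch below restates that computation with the scorer passed in as a callable, so it can be checked without downloading a checkpoint. `corpus_summary` and the toy `exact_match` scorer are hypothetical names, and the batching loop is an added assumption, not something the reply does.

```python
from statistics import mean, pstdev
from typing import Callable, List, Sequence, Tuple

# Any callable taking keyword args references=[...], candidates=[...] and
# returning one float per pair (the call style used for BleurtScorer.score above).
PairScorer = Callable[..., List[float]]


def corpus_summary(predictions: Sequence[str], references: Sequence[str],
                   scorer: PairScorer, batch_size: int = 16) -> Tuple[float, float]:
    """Score aligned pairs in batches; return (mean, population std) like np.mean/np.std."""
    if len(predictions) != len(references):
        raise ValueError("predictions and references must be aligned one-to-one")
    if not predictions:
        raise ValueError("need at least one (prediction, reference) pair")
    scores: List[float] = []
    for start in range(0, len(predictions), batch_size):
        batch_refs = list(references[start:start + batch_size])
        batch_cands = list(predictions[start:start + batch_size])
        batch_scores = scorer(references=batch_refs, candidates=batch_cands)
        if len(batch_scores) != len(batch_cands):
            raise ValueError("scorer must return exactly one score per pair")
        scores.extend(batch_scores)
    # pstdev divides by N, matching np.std's default ddof=0
    return mean(scores), pstdev(scores)


def exact_match(references: List[str], candidates: List[str]) -> List[float]:
    """Hypothetical toy scorer standing in for BLEURT: 1.0 on exact match, else 0.0."""
    return [1.0 if ref == cand else 0.0 for ref, cand in zip(references, candidates)]


if __name__ == "__main__":
    preds = ["a cat sat", "hello", "same", "other"]
    refs = ["a cat sat", "world", "same", "thing"]
    print(corpus_summary(preds, refs, exact_match, batch_size=3))
```

Since `BleurtScorer.score` is called with `references=` and `candidates=` keywords in the reply, a real `scorer.score` bound method should be usable in place of `exact_match`.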