Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions src/lighteval/metrics/imports/bert_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,16 @@
logger = logging.getLogger(__name__)


def validate_tokenizer_length(tokenizer: AutoTokenizer, override_length: int | None) -> int:
    """Return the effective max sequence length for *tokenizer*.

    Args:
        tokenizer: tokenizer whose ``model_max_length`` attribute is inspected.
        override_length: if not None, takes precedence over the tokenizer's
            own reported limit.

    Returns:
        int: the maximum length to use (override, tokenizer's limit, or 512).
    """
    # Explicit sentinel check (the annotation is `int | None`); truthiness
    # would silently discard an explicit override of 0.
    if override_length is not None:
        return override_length
    # HF tokenizers report VERY_LARGE_INTEGER (int(1e30)) as model_max_length
    # when no limit was configured for the checkpoint — fall back to 512.
    if tokenizer.model_max_length == int(1e30):
        logger.warning("Could not read model_max_length attribute for BERTScorer's tokenizer - defaulting to 512.")
        return 512
    return tokenizer.model_max_length


def padding(arr, pad_token, dtype=torch.long):
lens = torch.LongTensor([len(a) for a in arr])
max_len = lens.max().item()
Expand Down Expand Up @@ -321,6 +331,7 @@ def __init__(
lang=None,
rescale_with_baseline=False,
baseline_path=None,
tokenizer_max_len: int | None = None,
):
"""Initialize BERTScorer.

Expand All @@ -343,6 +354,7 @@ def __init__(
return_hash (bool): Return hash code of the setting.
rescale_with_baseline (bool): Rescale bertscore with pre-computed baseline.
baseline_path (str): Customized baseline file.
tokenizer_max_len (int, optional): will override the tokenizer's max model length if set.
"""
assert lang is not None or model_type is not None, "Either lang or model_type should be specified"

Expand All @@ -366,6 +378,7 @@ def __init__(

# Model and tokenizer are lazily loaded in `score()`.
self._tokenizer = None
self._tokenizer_len = tokenizer_max_len
self._model = None

self._idf_dict = None
Expand Down Expand Up @@ -430,6 +443,9 @@ def score(self, cands, refs, verbose=False, batch_size=64, return_hash=False):
if self._model is None:
logger.info(f"Loading BERTScorer model `{self._model_type}`")
self._tokenizer = AutoTokenizer.from_pretrained(self._model_type)
self._tokenizer.model_max_length = validate_tokenizer_length(
tokenizer=self._tokenizer, override_length=self._tokenizer_len
)
self._model = AutoModel.from_pretrained(self._model_type)
self._model.eval()
self._model.to(self.device)
Expand Down
6 changes: 5 additions & 1 deletion src/lighteval/metrics/metrics_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,7 +643,11 @@ def compute(self, doc: Doc, model_response: ModelResponse, **kwargs) -> dict[str
logger.warning("The first metric computation step might be a bit longer as we need to download the model.")
# We only initialize on first compute
self.bert_scorer = BERTScorer(
model_type="microsoft/deberta-large-mnli", lang="en", rescale_with_baseline=True, num_layers=9
model_type="microsoft/deberta-large-mnli",
lang="en",
rescale_with_baseline=True,
num_layers=9,
tokenizer_max_len=512,
)
golds = as_list(golds)
predictions = as_list(predictions)
Expand Down