diff --git a/bert_score/utils.py b/bert_score/utils.py index 2925ab2..d25d049 100644 --- a/bert_score/utils.py +++ b/bert_score/utils.py @@ -403,6 +403,8 @@ def dedup_and_sort(l): def pad_batch_stats(sen_batch, stats_dict, device): stats = [stats_dict[s] for s in sen_batch] emb, idf = zip(*stats) + emb = [e.to(device) for e in emb] + idf = [i.to(device) for i in idf] lens = [e.size(0) for e in emb] emb_pad = pad_sequence(emb, batch_first=True, padding_value=2.0) idf_pad = pad_sequence(idf, batch_first=True) @@ -413,22 +415,24 @@ def length_to_mask(lens): base = torch.arange(max_len, dtype=torch.long).expand(len(lens), max_len) return base < lens.unsqueeze(1) - pad_mask = length_to_mask(lens) - return emb_pad.to(device), pad_mask.to(device), idf_pad.to(device) + pad_mask = length_to_mask(lens).to(device) + return emb_pad, pad_mask, idf_pad device = next(model.parameters()).device iter_range = range(0, len(refs), batch_size) if verbose: print("computing greedy matching.") iter_range = tqdm(iter_range) - for batch_start in iter_range: - batch_refs = refs[batch_start : batch_start + batch_size] - batch_hyps = hyps[batch_start : batch_start + batch_size] - ref_stats = pad_batch_stats(batch_refs, stats_dict, device) - hyp_stats = pad_batch_stats(batch_hyps, stats_dict, device) - P, R, F1 = greedy_cos_idf(*ref_stats, *hyp_stats, all_layers) - preds.append(torch.stack((P, R, F1), dim=-1).cpu()) + with torch.no_grad(): + for batch_start in iter_range: + batch_refs = refs[batch_start : batch_start + batch_size] + batch_hyps = hyps[batch_start : batch_start + batch_size] + ref_stats = pad_batch_stats(batch_refs, stats_dict, device) + hyp_stats = pad_batch_stats(batch_hyps, stats_dict, device) + + P, R, F1 = greedy_cos_idf(*ref_stats, *hyp_stats, all_layers) + preds.append(torch.stack((P, R, F1), dim=-1).cpu()) preds = torch.cat(preds, dim=1 if all_layers else 0) return preds