Optimizing bert_cos_score_idf (#69)
* Optimizing bert_cos_score_idf

1) Pad BERT embeddings on the GPU instead of the CPU. Padding on the CPU is the bottleneck in computing the greedy matching, so padding on the GPU speeds up the matching by ~3x for me. Moving tensors to the GPU then becomes the bottleneck, but moving the pre-padding tensors takes ~2x less time, presumably because no padding values need to be transferred. Overall I get a ~6x speedup on the sequences I'm evaluating (see the sketch just below this list).
2) Wrap the greedy-matching computation in `torch.no_grad()` to save memory. After doing this I was able to double the batch size for greedy matching. I'm not sure whether increasing the batch size here will cause OOMs for others, though, so it's worth someone else checking it out (or simply dropping the batch-size increase); a second sketch, after the note below, illustrates the memory effect.
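
To make the first change concrete, here is a minimal sketch of pad-then-move versus move-then-pad. It is illustrative only: the tensors are random stand-ins for BERT embeddings, and only `pad_sequence` and the `padding_value=2.0` convention come from this repo.

```python
# Hypothetical illustration of the two orderings; not code from this commit.
import torch
from torch.nn.utils.rnn import pad_sequence

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Variable-length embeddings, one (seq_len, hidden) tensor per sentence.
embs = [torch.randn(n, 768) for n in (12, 87, 503)]

# Before: pad on the CPU, then move the padded batch (mostly padding).
batch_before = pad_sequence(embs, batch_first=True, padding_value=2.0).to(device)

# After: move each unpadded tensor first, then pad on the device.
# Less data crosses the CPU->GPU boundary, and the padding itself runs
# on the GPU, which the commit reports is ~3x faster for the matching.
embs_dev = [e.to(device) for e in embs]
batch_after = pad_sequence(embs_dev, batch_first=True, padding_value=2.0)

assert torch.allclose(batch_before, batch_after)  # same result either way
```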

* Removing batch size increase

The larger batch size for greedy matching occasionally caused OOMs, so I removed that part of the change.
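
For reference, a minimal sketch of what wrapping inference in `torch.no_grad()` changes; the linear layer is a stand-in for the real model, not anything from this repo.

```python
# Hypothetical illustration of the memory effect of torch.no_grad().
import torch

model = torch.nn.Linear(768, 768)  # stand-in for the actual computation
x = torch.randn(64, 768)

# Without no_grad, autograd records a graph for every batch, retaining
# intermediate activations that pure inference never needs.
y = model(x)
assert y.requires_grad

# Inside no_grad, graph construction is skipped, freeing that memory
# (the commit found this roughly doubled the feasible batch size,
# although the batch-size increase itself was later reverted).
with torch.no_grad():
    y = model(x)
assert not y.requires_grad
```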
ethanjperez authored Jul 4, 2020
1 parent d645a08 commit 4c10f36
Showing 1 changed file with 13 additions and 9 deletions.
bert_score/utils.py: 13 additions & 9 deletions
@@ -403,6 +403,8 @@ def dedup_and_sort(l):
     def pad_batch_stats(sen_batch, stats_dict, device):
         stats = [stats_dict[s] for s in sen_batch]
         emb, idf = zip(*stats)
+        emb = [e.to(device) for e in emb]
+        idf = [i.to(device) for i in idf]
         lens = [e.size(0) for e in emb]
         emb_pad = pad_sequence(emb, batch_first=True, padding_value=2.0)
         idf_pad = pad_sequence(idf, batch_first=True)
@@ -413,22 +415,24 @@ def length_to_mask(lens):
             base = torch.arange(max_len, dtype=torch.long).expand(len(lens), max_len)
             return base < lens.unsqueeze(1)
 
-        pad_mask = length_to_mask(lens)
-        return emb_pad.to(device), pad_mask.to(device), idf_pad.to(device)
+        pad_mask = length_to_mask(lens).to(device)
+        return emb_pad, pad_mask, idf_pad
 
     device = next(model.parameters()).device
     iter_range = range(0, len(refs), batch_size)
     if verbose:
         print("computing greedy matching.")
         iter_range = tqdm(iter_range)
-    for batch_start in iter_range:
-        batch_refs = refs[batch_start : batch_start + batch_size]
-        batch_hyps = hyps[batch_start : batch_start + batch_size]
-        ref_stats = pad_batch_stats(batch_refs, stats_dict, device)
-        hyp_stats = pad_batch_stats(batch_hyps, stats_dict, device)
-
-        P, R, F1 = greedy_cos_idf(*ref_stats, *hyp_stats, all_layers)
-        preds.append(torch.stack((P, R, F1), dim=-1).cpu())
+    with torch.no_grad():
+        for batch_start in iter_range:
+            batch_refs = refs[batch_start : batch_start + batch_size]
+            batch_hyps = hyps[batch_start : batch_start + batch_size]
+            ref_stats = pad_batch_stats(batch_refs, stats_dict, device)
+            hyp_stats = pad_batch_stats(batch_hyps, stats_dict, device)
+
+            P, R, F1 = greedy_cos_idf(*ref_stats, *hyp_stats, all_layers)
+            preds.append(torch.stack((P, R, F1), dim=-1).cpu())
     preds = torch.cat(preds, dim=1 if all_layers else 0)
     return preds
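
For context, `bert_cos_score_idf` is an internal helper; callers normally reach this code path through the package's public `score()` function. A usage sketch (the example sentences are made up):

```python
# Typical entry point that exercises bert_cos_score_idf internally.
from bert_score import score

cands = ["the cat sat on the mat"]
refs = ["a cat was sitting on the mat"]

# score() tokenizes, embeds, and computes the greedy matching.
P, R, F1 = score(cands, refs, lang="en", verbose=True)
print(F1.mean().item())
```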
