
Commit

Fix incorrect size of input for 1st strided window length in `Perplexity of fixed-length models` (#18906)

* update the PPL for stride 512

* fix 1st strided window size

* linting

* fix typo

* styling
ekagra-ranjan authored Sep 6, 2022
1 parent 7d5fde9 commit 0a632f0
Showing 1 changed file with 16 additions and 6 deletions.
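
For context, here is a minimal sketch of the bug this commit fixes. It is illustrative only (not part of the patch), assumes the `max_length = 1024` / `stride = 512` values used in the doc, and uses a made-up `seq_len`:

```python
# Illustrative only: first-window boundaries before vs. after the fix.
max_length, stride, seq_len = 1024, 512, 300_000  # seq_len is a made-up token count

# Old indexing: on the first iteration (i = 0) the window spans only `stride` tokens.
i = 0
begin_loc_old = max(i + stride - max_length, 0)  # 0
end_loc_old = min(i + stride, seq_len)           # 512 -> a 512-token first window

# New indexing: the first window spans the full `max_length` context.
begin_loc_new = 0
end_loc_new = min(begin_loc_new + max_length, seq_len)  # 1024 -> a 1024-token first window

print(end_loc_old - begin_loc_old, end_loc_new - begin_loc_new)  # 512 1024
```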
docs/source/en/perplexity.mdx (22 changes: 16 additions & 6 deletions)
@@ -101,22 +101,32 @@ from tqdm import tqdm
 
 max_length = model.config.n_positions
 stride = 512
+seq_len = encodings.input_ids.size(1)
 
 nlls = []
-for i in tqdm(range(0, encodings.input_ids.size(1), stride)):
-    begin_loc = max(i + stride - max_length, 0)
-    end_loc = min(i + stride, encodings.input_ids.size(1))
-    trg_len = end_loc - i  # may be different from stride on last loop
+prev_end_loc = 0
+for begin_loc in tqdm(range(0, seq_len, stride)):
+    end_loc = min(begin_loc + max_length, seq_len)
+    trg_len = end_loc - prev_end_loc  # may be different from stride on last loop
     input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
     target_ids = input_ids.clone()
     target_ids[:, :-trg_len] = -100
 
     with torch.no_grad():
         outputs = model(input_ids, labels=target_ids)
-        neg_log_likelihood = outputs[0] * trg_len
+
+        # loss is calculated using CrossEntropyLoss which averages over input tokens.
+        # Multiply it with trg_len to get the summation instead of average.
+        # We will take average over all the tokens to get the true average
+        # in the last step of this example.
+        neg_log_likelihood = outputs.loss * trg_len
 
     nlls.append(neg_log_likelihood)
 
+    prev_end_loc = end_loc
+    if end_loc == seq_len:
+        break
+
 ppl = torch.exp(torch.stack(nlls).sum() / end_loc)
 ```
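
For reference, the unchanged `ppl = ...` line at the end of this hunk exponentiates the average negative log-likelihood per token. Written out, treating `end_loc` (which equals `seq_len` when the loop exits) as the total number of scored tokens:

$$\text{PPL} = \exp\!\left(\frac{1}{\texttt{end\_loc}} \sum_{w} \text{nll}_w\right), \qquad \text{nll}_w = \texttt{trg\_len}_w \cdot \mathcal{L}_w,$$

where $\mathcal{L}_w$ is the mean cross-entropy loss the model returns for window $w$.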

@@ -126,5 +136,5 @@ and the better the reported perplexity will typically be.
 
 When we run the above with `stride = 1024`, i.e. no overlap, the resulting PPL is `19.64`, which is about the same
 as the `19.93` reported in the GPT-2 paper. By using `stride = 512` and thereby employing our striding window
-strategy, this jumps down to `16.53`. This is not only a more favorable score, but is calculated in a way that is
+strategy, this jumps down to `16.44`. This is not only a more favorable score, but is calculated in a way that is
 closer to the true autoregressive decomposition of a sequence likelihood.
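
As a quick sanity check (not from the docs), the `(begin_loc, end_loc, trg_len)` schedule produced by the fixed loop can be printed directly; `seq_len = 2200` is a hypothetical length chosen only to keep the output short, with `max_length = 1024` and `stride = 512` as above:

```python
# Hypothetical walk-through of the fixed sliding-window schedule.
max_length, stride, seq_len = 1024, 512, 2200  # seq_len is illustrative

prev_end_loc = 0
for begin_loc in range(0, seq_len, stride):
    end_loc = min(begin_loc + max_length, seq_len)
    trg_len = end_loc - prev_end_loc  # tokens actually scored in this window
    print(begin_loc, end_loc, trg_len)
    prev_end_loc = end_loc
    if end_loc == seq_len:
        break

# Prints:
# 0 1024 1024
# 512 1536 512
# 1024 2048 512
# 1536 2200 152
```

Every token is scored exactly once (1024 + 512 + 512 + 152 == 2200), and each window after the first still conditions on 512 tokens of overlapping context, which is what the striding strategy is meant to achieve.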
