
add overfitting study
frankxu2004 committed Mar 30, 2022
1 parent 626a8d3 commit d8a9963
Showing 7 changed files with 148 additions and 73 deletions.
4 changes: 2 additions & 2 deletions analysis/interpolation.py
@@ -20,7 +20,7 @@

tokens = np.load('tokens.npy')
lm_scores = np.load('scores.npy')
knn_scores = np.load('knn_only_scores.npy')
knn_scores = np.load('overfit_lm_scores_checkpoint15.npy')

assert len(tokens) == len(lm_scores)
assert len(knn_scores) == len(tokens)
@@ -41,7 +41,7 @@

combine_probs = torch.stack([lm_scores, knn_scores], dim=0)

with open('small_interpolation_result.txt', 'w') as outfile:
with open('overfit15_interpolation_result.txt', 'w') as outfile:
    for lmbda in tqdm(np.linspace(0.0, 0.99, num=50)):
        coeffs = torch.ones_like(combine_probs)
        coeffs[0] = np.log(1 - lmbda)
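Note: the sweep above interpolates the two models in log space, computing log((1 - lmbda) * p_lm + lmbda * p_other) per token via logsumexp. A minimal, self-contained sketch of that step with made-up toy scores (not the saved .npy files):

import numpy as np
import torch

# toy per-token log probabilities (hypothetical values)
lm_scores = torch.tensor([0.2, 0.5, 0.1]).log()
other_scores = torch.tensor([0.4, 0.3, 0.05]).log()
combine_probs = torch.stack([lm_scores, other_scores], dim=0)

lmbda = 0.25
coeffs = torch.ones_like(combine_probs)
coeffs[0] = np.log(1 - lmbda)  # weight on the base LM
coeffs[1] = np.log(lmbda)      # weight on the second model

# log((1 - lmbda) * p_lm + lmbda * p_other), evaluated stably in log space
scores = torch.logsumexp(combine_probs + coeffs, dim=0)
print(scores)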
76 changes: 76 additions & 0 deletions analysis/overfit_analysis.py
@@ -0,0 +1,76 @@
import math
from tqdm import tqdm

import numpy as np
import torch

from fairseq.data import Dictionary

dictionary = Dictionary.load('data-bin/wikitext103-bpe/dict.txt')
print(len(dictionary))

bpe_cont = "@@"
bpe_toks = {
    i
    for i in range(len(dictionary))
    if dictionary[i].endswith(bpe_cont)
}

bpe_len = len(bpe_cont)

tokens = np.load('tokens.npy')
lm_scores = np.load('scores.npy')
knn_scores = np.load('best_knn_only_scores.npy')

assert len(tokens) == len(lm_scores)
assert len(knn_scores) == len(tokens)

lm_scores = torch.from_numpy(lm_scores).cuda()
tgt_len = tokens.size
# count BPE continuation tokens so perplexity is computed over complete tokens only
skipped_toks = 0
for i in range(tgt_len - 1):
    if tokens[i].item() in bpe_toks:
        skipped_toks += 1

count = len(tokens) - skipped_toks

knn_helping = 0
with open('overfit_interpolation_epoch.txt', 'w') as outfile:
    for epoch in tqdm(range(234)):
        if epoch == 0:
            # epoch 0 stands in for the kNN-only baseline; later indices are overfit LM checkpoints
            overfit_scores = np.load('best_knn_only_scores.npy')
        else:
            overfit_scores = np.load('overfit_scores/overfit_lm_scores_checkpoint{}.npy'.format(epoch))
        overfit_scores = torch.from_numpy(overfit_scores).cuda()
        combine_probs = torch.stack([lm_scores, overfit_scores], dim=0)

        # oracle: per token, take whichever model assigns the higher log probability
        oracle_scores, argmaxs = torch.max(combine_probs, dim=0)

        oracle_ppl = torch.exp(-oracle_scores.sum() / count)

        if epoch == 0:
            # remember which tokens the kNN baseline wins on
            knn_helping = argmaxs

        # fraction of tokens whose winning model matches the kNN baseline's pattern
        match_knn = torch.sum(argmaxs == knn_helping).item() / len(tokens)

        # score each token with the model the kNN baseline preferred there
        knn_helping_scores = -(combine_probs[0][knn_helping == 0].sum() + combine_probs[1][knn_helping == 1].sum())
        knn_helping_ppl = torch.exp(knn_helping_scores / count)

        best_ppl = 1e10
        best_lmbda = 0
        for lmbda in np.linspace(0.0, 0.999, num=200):
            coeffs = torch.ones_like(combine_probs)
            coeffs[0] = np.log(1 - lmbda)
            coeffs[1] = np.log(lmbda)

            scores = torch.logsumexp(combine_probs + coeffs, dim=0)

            score_sum = scores.sum()

            avg_nll_loss = -score_sum / count / math.log(2)  # convert to base 2
            ppl = 2 ** avg_nll_loss.item()
            if ppl < best_ppl:
                best_ppl = ppl
                best_lmbda = lmbda

        outfile.write('{}\t{}\t{}\t{}\t{}\t{}\n'.format(
            epoch, best_lmbda, best_ppl, oracle_ppl.item(), match_knn, knn_helping_ppl.item()))
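Note: two perplexity conventions coexist in this script; oracle_ppl exponentiates the mean natural-log loss directly, while the lambda sweep converts to base 2 and raises 2 to the bits-per-token loss. The two agree, as this small check with made-up scores illustrates:

import math
import torch

# toy log probabilities, one row per model (hypothetical values)
combine_probs = torch.tensor([[0.2, 0.5, 0.1],
                              [0.4, 0.3, 0.05]]).log()
count = combine_probs.size(1)

# per-token oracle: pick whichever model assigns the higher probability
oracle_scores, argmaxs = torch.max(combine_probs, dim=0)

ppl_e = torch.exp(-oracle_scores.sum() / count)        # e ** (nats per token)
nll_bits = -oracle_scores.sum() / count / math.log(2)  # bits per token
ppl_2 = 2 ** nll_bits.item()                           # 2 ** (bits per token)
assert abs(ppl_e.item() - ppl_2) < 1e-4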
29 changes: 18 additions & 11 deletions analysis/overfit_correlation.py
@@ -1,12 +1,8 @@
import math

from scipy.stats import pearsonr, spearmanr
from tqdm import tqdm

import numpy as np
import torch
from scipy.stats import pearsonr, spearmanr

from fairseq.data import Dictionary
import torch

dictionary = Dictionary.load('data-bin/wikitext103-bpe/dict.txt')
print(len(dictionary))
@@ -22,7 +18,11 @@

tokens = np.load('tokens.npy')
lm_scores = np.load('scores.npy')
knn_scores = np.load('knn_only_scores.npy')
knn_scores = np.load('best_knn_only_scores.npy')
knnlm_scores = np.load('knnlm_scores.npy')

assert len(lm_scores) == len(tokens)
assert len(knn_scores) == len(tokens)

# calculate Pearson's correlation
corr, _ = pearsonr(lm_scores, knn_scores)
@@ -33,19 +33,26 @@
print('LM vs KNN Spearmans correlation: %.3f' % corr)


for ep in [10, 20, 30, 50, 100, 150, 200]:
    print(ep)
for ep in [2, 5, 10, 15, 18]:
    print('Epoch', ep)
    overfit_scores = np.load('overfit_lm_scores_checkpoint' + str(ep) + '.npy')
    assert len(tokens) == len(lm_scores)
    assert len(knn_scores) == len(tokens)
    assert len(overfit_scores) == len(tokens)

    overfit_probs = np.exp(overfit_scores)
    lm_probs = np.exp(lm_scores)
    knn_probs = np.exp(knn_scores)

    # per-token probability gain of each model over the base LM
    knn_lm_diff = knn_probs - lm_probs
    overfit_lm_diff = overfit_probs - lm_probs

    # calculate Pearson's correlation
    corr, _ = pearsonr(overfit_scores, knn_scores)
    corr, _ = pearsonr(overfit_lm_diff, knn_lm_diff)
    print('OverfitLM vs KNN Pearsons correlation: %.3f' % corr)

    # calculate Spearman's correlation
    corr, _ = spearmanr(overfit_scores, knn_scores)
    corr, _ = spearmanr(overfit_lm_diff, knn_lm_diff)
    print('OverfitLM vs KNN Spearmans correlation: %.3f' % corr)


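Note: the revised loop correlates probability gains over the base LM (p_model - p_lm) rather than raw scores, asking whether the overfit LM helps on the same tokens the kNN component helps on. A toy sketch of that computation with hypothetical probabilities:

import numpy as np
from scipy.stats import pearsonr, spearmanr

# hypothetical per-token probabilities for the three models
lm_probs = np.array([0.2, 0.5, 0.1, 0.3])
knn_probs = np.array([0.4, 0.3, 0.05, 0.5])
overfit_probs = np.array([0.35, 0.4, 0.08, 0.45])

# per-token probability gain of each model over the base LM
knn_lm_diff = knn_probs - lm_probs
overfit_lm_diff = overfit_probs - lm_probs

corr_p, _ = pearsonr(overfit_lm_diff, knn_lm_diff)
corr_s, _ = spearmanr(overfit_lm_diff, knn_lm_diff)
print('Pearson %.3f, Spearman %.3f' % (corr_p, corr_s))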
23 changes: 23 additions & 0 deletions analysis/overfit_loss_curve.py
@@ -0,0 +1,23 @@
import re

train_losses = [0] * 233
val_losses = [0] * 233

with open('checkpoints/wikitext103-bpe-overfit/overfit1.log') as logfile:
    for line in logfile:
        # validation summary lines also start with 'epoch', so match them before training lines
        if line.startswith('epoch') and 'valid on' in line:
            epoch = int(re.search(r'epoch (\d+)', line).group(1))
            loss = float(re.search(r'loss ([-+]?\d*\.?\d+|\d+)', line).group(1))
            val_losses[epoch-1] = loss
        elif line.startswith('epoch'):
            epoch = int(re.search(r'epoch (\d+)', line).group(1))
            loss = float(re.search(r'loss ([-+]?\d*\.?\d+|\d+)', line).group(1))
            train_losses[epoch-1] = loss


print(train_losses)
print(val_losses)

with open('checkpoints/wikitext103-bpe-overfit/curve.csv', 'w') as outfile:
    for i in range(233):
        outfile.write('{}\t{}\t{}\n'.format(i+1, train_losses[i], val_losses[i]))
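Note: the parser assumes each epoch yields one training summary line and one validation line, both starting with 'epoch'. Two hypothetical lines in the shape it expects (real fairseq logs may differ in detail):

import re

sample_lines = [
    "epoch 001 | loss 4.123 | ppl 17.44",                            # training summary
    "epoch 001 | valid on 'valid' subset | loss 3.987 | ppl 15.86",  # validation
]
for line in sample_lines:
    epoch = int(re.search(r'epoch (\d+)', line).group(1))
    loss = float(re.search(r'loss ([-+]?\d*\.?\d+|\d+)', line).group(1))
    print(epoch, 'valid' if 'valid on' in line else 'train', loss)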
6 changes: 3 additions & 3 deletions fairseq_cli/eval_lm.py
@@ -302,9 +302,9 @@ def main(parsed_args):
))

    # np.save('knnlm_tokens.npy', np.concatenate(all_token_ids))
    # np.save('overfit_lm_scores_' + parsed_args.path.split('/')[-1].split('.')[0] + '.npy', np.concatenate(all_scores))
    if all_knn_scores:
        np.save('recompute_knn_only_scores.npy', np.concatenate(all_knn_scores))
    np.save('overfit_scores/overfit_lm_scores_' + parsed_args.path.split('/')[-1].split('.')[0] + '.npy', np.concatenate(all_scores))
    # if all_knn_scores:
    #     np.save('best_knn_only_scores.npy', np.concatenate(all_knn_scores))

    if args.output_word_stats:
        for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True):
75 changes: 18 additions & 57 deletions wikitext_bpe_overfit.sh
@@ -1,67 +1,28 @@
## eval
python eval_lm.py data-bin/wikitext103-bpe \
--path checkpoints/wikitext103-bpe/checkpoint_last.pt \
--sample-break-mode complete --max-tokens 3072 \
--context-window 2560 --softmax-batch 1024 \
--gen-subset valid --bpe subword_nmt --remove-bpe

# save the kNN datastore (keys and values) over the training set
python eval_lm.py data-bin/wikitext103-bpe \
--path checkpoints/wikitext103-bpe/checkpoint_last.pt \
--sample-break-mode none --max-tokens 3072 \
--softmax-batch 1024 --gen-subset train \
--context-window 1536 --tokens-per-sample 1536 \
--dstore-mmap checkpoints/wikitext103-bpe/dstore_last --knn-keytype 'last_ffn_input' \
--dstore-size 153225485 --model-overrides "{'knn_keytype': 'last_ffn_input'}" \
--save-knnlm-dstore --fp16 --dstore-fp16

# build index
python build_dstore.py \
--dstore_mmap checkpoints/wikitext103-bpe/dstore_last \
--dstore_size 153225485 \
--faiss_index checkpoints/wikitext103-bpe/knn_last.index \
--num_keys_to_add_at_a_time 500000 \
--starting_point 0 --dstore-fp16 --dimension 1024

# eval with index
# no recompute: rank neighbors by stored faiss distances without reloading keys (--knn-sim-func do_not_recomp_l2, --no-load-keys)
python eval_lm.py data-bin/wikitext103-bpe \
--path checkpoints/wikitext103-bpe/checkpoint_last.pt \
--sample-break-mode complete --max-tokens 3072 \
--context-window 2560 --softmax-batch 1024 \
--gen-subset valid --dstore-filename checkpoints/wikitext103-bpe/dstore_last \
--indexfile checkpoints/wikitext103-bpe/knn_last.index \
--model-overrides "{'knn_keytype': 'last_ffn_input'}" \
--k 1024 --lmbda 0.25 --dstore-size 153225485 --knn-keytype last_ffn_input \
--knn-sim-func "do_not_recomp_l2" --no-load-keys \
--probe 32 --knnlm --fp16 --dstore-fp16 --bpe subword_nmt --remove-bpe


# recompute: reload keys and recompute exact distances for the retrieved neighbors
python eval_lm.py data-bin/wikitext103-bpe \
--path checkpoints/wikitext103-bpe/checkpoint_last.pt \
--sample-break-mode complete --max-tokens 3072 \
--context-window 2560 --softmax-batch 1024 \
--gen-subset valid --dstore-filename checkpoints/wikitext103-bpe/dstore_last \
--indexfile checkpoints/wikitext103-bpe/knn_last.index \
--model-overrides "{'knn_keytype': 'last_ffn_input'}" \
--k 1024 --lmbda 0.25 --dstore-size 153225485 --knn-keytype last_ffn_input \
--probe 32 --knnlm --fp16 --dstore-fp16 --bpe subword_nmt --remove-bpe


# continue training till overfit
CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --task language_modeling \
# continue training until it overfits, with dropout turned off
CUDA_VISIBLE_DEVICES=0,1,2,3,8,9 python train.py --task language_modeling \
data-bin/wikitext103-bpe \
--save-dir checkpoints/wikitext103-bpe-overfit \
--arch transformer_lm_wikibpe \
--restore-file checkpoints/wikitext103-bpe/checkpoint_last.pt \
--dropout 0 --attention-dropout 0 --activation-dropout 0 \
--restore-file checkpoints/wikitext103-bpe/checkpoint_best.pt \
--reset-optimizer --reset-dataloader --reset-meters \
--max-update 28600 --optimizer nag --lr 1e-3 --clip-norm 0.1 \
--max-update 28600 --optimizer nag --lr 1e-2 --clip-norm 100 \
--max-tokens 3072 --update-freq 4 --tokens-per-sample 3072 --seed 1 \
--sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d --fp16 | tee overfit.log

# continue training on Google Cloud with 4 GPUs
python train.py --task language_modeling \
data-bin/wikitext103-bpe \
--save-dir checkpoints/wikitext103-bpe-overfit \
--arch transformer_lm_wikibpe \
--dropout 0 --attention-dropout 0 --activation-dropout 0 \
--restore-file checkpoints/wikitext103-bpe-overfit/checkpoint_last.pt \
--max-update 28600 --optimizer nag --lr 1e-2 --clip-norm 100 \
--max-tokens 3072 --update-freq 6 --tokens-per-sample 3072 --seed 1 \
--sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d --fp16
--sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d --fp16 | tee overfit.log

python eval_lm.py data-bin/wikitext103-bpe \
--path checkpoints/wikitext103-bpe-overfit/checkpoint242.pt \
--path checkpoints/wikitext103-bpe-overfit/checkpoint2.pt \
--sample-break-mode complete --max-tokens 3072 \
--context-window 2560 --softmax-batch 1024 \
--gen-subset valid --bpe subword_nmt --remove-bpe
8 changes: 8 additions & 0 deletions wikitext_bpe_overfit_eval_all.sh
@@ -0,0 +1,8 @@
for i in {1..233}
do
python eval_lm.py data-bin/wikitext103-bpe \
--path checkpoints/wikitext103-bpe-overfit/checkpoint${i}.pt \
--sample-break-mode complete --max-tokens 3072 \
--context-window 2560 --softmax-batch 1024 \
--gen-subset valid --bpe subword_nmt --remove-bpe
done
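Note: with the eval_lm.py change above, each iteration saves overfit_scores/overfit_lm_scores_checkpoint${i}.npy (the filename is derived from the checkpoint path), which overfit_analysis.py then loads epoch by epoch. A quick sanity check over the outputs, assuming the loop has finished:

import numpy as np

# assumes eval_lm.py wrote one score file per checkpoint into overfit_scores/
for epoch in range(1, 234):
    scores = np.load('overfit_scores/overfit_lm_scores_checkpoint{}.npy'.format(epoch))
    print(epoch, scores.shape, scores.mean())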
