Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
frankxu2004 committed May 4, 2022
1 parent 9d981f9 commit 0522639
Show file tree
Hide file tree
Showing 3 changed files with 213 additions and 7 deletions.
9 changes: 8 additions & 1 deletion analysis/kv_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,14 @@
'last_linear_ip_scores.npy',
'last_linear_ip_recomp_scores.npy',
'last_linear_scores.npy',
'last_linear_recomp_scores.npy'
'last_linear_recomp_scores.npy',
'kv1_att_finetune_scores.npy',
'kv2_att_finetune_scores.npy',
'kv3_att_finetune_scores.npy',
'kv4_att_finetune_scores.npy',
'kv5_att_finetune_scores.npy',
'kv6_att_finetune_scores.npy',
'kv9_att_finetune_scores.npy',
]:
extra_scores = np.load(f)
extra_scores = torch.from_numpy(extra_scores).cuda()
Expand Down
82 changes: 82 additions & 0 deletions analysis/ov_analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
"""Oracle / interpolation analysis of per-token LM scores.

Loads per-token log-probability scores saved by eval_lm (the base LM plus
several kNN / overfit variants) and, for each variant, writes one TSV row to
``ov_interpolation.txt`` with:

  * extra_only_ppl          -- perplexity of the variant's scores alone
  * best_lmbda / best_ppl   -- best fixed-weight log-linear interpolation of
                               the base LM with the variant (grid search)
  * oracle_ppl              -- per-token max over {LM, variant}
  * match_knn               -- fraction of tokens where this variant's oracle
                               choice agrees with the first (kNN) variant's
  * extra_helping_percentage-- fraction of tokens where the variant beats the LM
  * knn_helping_ppl         -- ppl when the kNN oracle's per-token choice is
                               applied to this variant's score matrix
"""
import math
from tqdm import tqdm  # NOTE(review): imported but unused in this script

import numpy as np
import torch

from fairseq.data import Dictionary

# BPE vocabulary; used only to find word-continuation tokens ("@@" suffix).
dictionary = Dictionary.load('data-bin/wikitext103-bpe/dict.txt')
print(len(dictionary))

bpe_cont = "@@"
# Indices of all tokens that continue a word (i.e. end with "@@").
bpe_toks = {
    i
    for i in range(len(dictionary))
    if dictionary[i].endswith(bpe_cont)
}

tokens = np.load('overfit_analysis/tokens.npy')
lm_scores = np.load('overfit_analysis/lm_scores.npy')

assert len(tokens) == len(lm_scores)

lm_scores = torch.from_numpy(lm_scores).cuda()
tgt_len = tokens.size
# Count BPE continuation positions: word-level perplexity only counts
# word-final tokens, so continuations are excluded from the denominator.
skipped_toks = 0
for i in range(tgt_len - 1):
    if tokens[i].item() in bpe_toks:
        skipped_toks += 1

count = len(tokens) - skipped_toks

# Per-token oracle choice (0 = LM, 1 = extra) of the FIRST file below
# (knn_scores). It is assigned on the first loop iteration and then reused
# to compare every later variant against the kNN oracle.
knn_helping = 0
with open('ov_interpolation.txt', 'w') as outfile:
    for f in ['overfit_analysis/knn_scores.npy',
              'overfit_analysis/knn_recomp_scores.npy',
              'overfit_analysis/knn_ip_scores.npy',
              'overfit_analysis/knn_ip_recomp_scores.npy',
              'overfit_analysis/overfit129_lm_scores.npy',
              ]:
        extra_scores = np.load(f)
        extra_scores = torch.from_numpy(extra_scores).cuda()
        # Row 0: base LM log-probs; row 1: this variant's log-probs.
        combine_probs = torch.stack([lm_scores, extra_scores], dim=0)

        # Oracle: per token, take whichever model scored it higher.
        oracle_scores, argmaxs = torch.max(combine_probs, dim=0)

        oracle_ppl = torch.exp(-oracle_scores.sum() / count)

        # Only the first file's name contains this substring, so knn_helping
        # is the kNN oracle's per-token choice for all later iterations.
        if 'knn_scores.npy' in f:
            knn_helping = argmaxs

        # Trivially 1.0 on the first iteration (argmaxs == knn_helping).
        match_knn = torch.sum(argmaxs == knn_helping).item() / len(tokens)
        extra_helping_percentage = torch.sum(argmaxs).item() / len(tokens)

        # Apply the kNN oracle's choice to THIS variant's score matrix:
        # LM score where the kNN oracle chose the LM, extra score elsewhere.
        knn_helping_scores = -(combine_probs[0][knn_helping == 0].sum() +
                               combine_probs[1][knn_helping == 1].sum())

        knn_helping_ppl = torch.exp(knn_helping_scores / count)

        extra_only_ppl = torch.exp(-extra_scores.sum() / count)

        # Grid search the interpolation weight lmbda in p = (1-lmbda)*p_lm
        # + lmbda*p_extra, done in log space via logsumexp.
        best_ppl = 1e10
        best_lmbda = 0
        for lmbda in np.linspace(0.0, 0.999, num=200):
            coeffs = torch.ones_like(combine_probs)
            # At lmbda == 0.0, np.log(0) = -inf; logsumexp then reduces to
            # the pure LM scores, which is the intended endpoint.
            coeffs[0] = np.log(1 - lmbda)
            coeffs[1] = np.log(lmbda)

            scores = torch.logsumexp(combine_probs + coeffs, dim=0)

            score_sum = scores.sum()

            avg_nll_loss = -score_sum / count / math.log(2)  # convert to base 2
            ppl = 2 ** avg_nll_loss.item()  # equals exp(-score_sum / count)
            if ppl < best_ppl:
                best_ppl = ppl
                best_lmbda = lmbda

        outfile.write(f'{f}\t{extra_only_ppl}\t{best_lmbda}\t{best_ppl}\t{oracle_ppl}\t'
                      f'{match_knn}\t{extra_helping_percentage}\t{knn_helping_ppl}\n')
129 changes: 123 additions & 6 deletions wikitext_bpe_additional_linear.sh
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,23 @@ python eval_lm.py data-bin/wikitext103-bpe --path checkpoints/wikitext103-bpe-ad
--gen-subset valid --bpe subword_nmt --remove-bpe \
--model-overrides "{'knn_keytype': 'last_ffn_input', 'use_last_ffn_input': True}"


## after softmax, reinit
CUDA_VISIBLE_DEVICES=1,2,4,5,6,7 python train.py --task language_modeling \
## ATT K=1
python train.py --task language_modeling \
data-bin/wikitext103-bpe \
--save-dir checkpoints/wikitext103-bpe-additional-linear-after-softmax-reinit \
--save-dir checkpoints/wikitext103-bpe-kv1-att-fix \
--arch transformer_lm_wikibpe --restore-file checkpoints/wikitext103-bpe/checkpoint_best.pt \
--knn-keytype last_ffn_input --use-last-ffn-input --finetune-out-embed --init-out-embed \
--reset-optimizer --reset-dataloader --reset-meters \
--max-update 286000 --optimizer nag --lr 1e-2 --clip-norm 100 \
--max-tokens 12288 --update-freq 1 --tokens-per-sample 3072 --seed 1 \
--max-update 100000 --max-lr 1.0 --t-mult 2 --lr-scheduler cosine --lr-shrink 0.75 \
--warmup-updates 5000 --warmup-init-lr 1e-07 --min-lr 1e-09 --optimizer nag --lr 0.01 --clip-norm 0.1 \
--max-tokens 3072 --update-freq 3 --tokens-per-sample 3072 --seed 1 \
--sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d --fp16

# Score the fine-tuned K=1 model on the validation set and save per-token
# log-probs to kv1_att_finetune_scores.npy (consumed by analysis/kv_analysis.py).
python eval_lm.py data-bin/wikitext103-bpe --path checkpoints/wikitext103-bpe-kv1-att-fix/checkpoint_best.pt \
--sample-break-mode complete --max-tokens 3072 --context-window 2560 --softmax-batch 1024 \
--gen-subset valid --bpe subword_nmt --remove-bpe \
--model-overrides "{'knn_keytype': 'last_ffn_input', 'use_last_ffn_input': True}" \
--save-scores kv1_att_finetune_scores.npy

## ATT K=9
python train.py --task language_modeling \
Expand All @@ -69,3 +74,115 @@ python train.py --task language_modeling \
--max-update 286000 --optimizer nag --lr 5e-2 --clip-norm 100 \
--max-tokens 3072 --update-freq 3 --tokens-per-sample 3072 --seed 1 \
--sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d --fp16

# Score the K=9 model on validation and save per-token log-probs to
# kv9_att_finetune_scores.npy (consumed by analysis/kv_analysis.py).
python eval_lm.py data-bin/wikitext103-bpe --path checkpoints/wikitext103-bpe-kv9-att-fix/checkpoint_best.pt \
--sample-break-mode complete --max-tokens 3072 --context-window 2560 --softmax-batch 1024 \
--gen-subset valid --bpe subword_nmt --remove-bpe --pseudo-vocab-ratio 9 \
--model-overrides "{'knn_keytype': 'last_ffn_input', 'use_last_ffn_input': True}" \
--save-scores kv9_att_finetune_scores.npy

## ATT K=5
# Fine-tune only the output embeddings (--finetune-out-embed) from the base
# checkpoint with pseudo-vocab ratio 5; cosine LR schedule.
# NOTE(review): uses --lr 0.0001 and --warmup-updates 16000, unlike the
# K=2/4/6 runs below (--lr 0.01, 5000 warmup) — confirm intentional.
python train.py --task language_modeling \
data-bin/wikitext103-bpe \
--save-dir checkpoints/wikitext103-bpe-kv5-att-fix \
--arch transformer_lm_wikibpe \
--restore-file checkpoints/wikitext103-bpe/checkpoint_best.pt \
--reset-optimizer --reset-dataloader --reset-meters \
--knn-keytype last_ffn_input --use-last-ffn-input --finetune-out-embed \
--pseudo-vocab-ratio 5 --criterion agg_softmax \
--max-update 286000 --max-lr 1.0 --t-mult 2 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 \
--warmup-updates 16000 --warmup-init-lr 1e-07 --min-lr 1e-09 --optimizer nag --lr 0.0001 --clip-norm 0.1 \
--max-tokens 3072 --update-freq 3 --tokens-per-sample 3072 --seed 1 \
--sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d --fp16

## ATT K=3
# Same recipe with pseudo-vocab ratio 3; max-lr 0.5 here vs 1.0 elsewhere.
python train.py --task language_modeling \
data-bin/wikitext103-bpe \
--save-dir checkpoints/wikitext103-bpe-kv3-att-fix \
--arch transformer_lm_wikibpe \
--restore-file checkpoints/wikitext103-bpe/checkpoint_best.pt \
--reset-optimizer --reset-dataloader --reset-meters \
--knn-keytype last_ffn_input --use-last-ffn-input --finetune-out-embed \
--pseudo-vocab-ratio 3 --criterion agg_softmax \
--max-update 100000 --max-lr 0.5 --t-mult 2 --lr-scheduler cosine --lr-shrink 0.75 \
--warmup-updates 5000 --warmup-init-lr 1e-07 --min-lr 1e-09 --optimizer nag --lr 0.0001 --clip-norm 0.1 \
--max-tokens 3072 --update-freq 3 --tokens-per-sample 3072 --seed 1 \
--sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d --fp16

## ATT K=6
# Same recipe with pseudo-vocab ratio 6.
python train.py --task language_modeling \
data-bin/wikitext103-bpe \
--save-dir checkpoints/wikitext103-bpe-kv6-att-fix \
--arch transformer_lm_wikibpe \
--restore-file checkpoints/wikitext103-bpe/checkpoint_best.pt \
--reset-optimizer --reset-dataloader --reset-meters \
--knn-keytype last_ffn_input --use-last-ffn-input --finetune-out-embed \
--pseudo-vocab-ratio 6 --criterion agg_softmax \
--max-update 100000 --max-lr 1.0 --t-mult 2 --lr-scheduler cosine --lr-shrink 0.75 \
--warmup-updates 5000 --warmup-init-lr 1e-07 --min-lr 1e-09 --optimizer nag --lr 0.01 --clip-norm 0.1 \
--max-tokens 3072 --update-freq 3 --tokens-per-sample 3072 --seed 1 \
--sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d --fp16


## ATT K=2
# Same recipe with pseudo-vocab ratio 2.
python train.py --task language_modeling \
data-bin/wikitext103-bpe \
--save-dir checkpoints/wikitext103-bpe-kv2-att-fix \
--arch transformer_lm_wikibpe \
--restore-file checkpoints/wikitext103-bpe/checkpoint_best.pt \
--reset-optimizer --reset-dataloader --reset-meters \
--knn-keytype last_ffn_input --use-last-ffn-input --finetune-out-embed \
--pseudo-vocab-ratio 2 --criterion agg_softmax \
--max-update 100000 --max-lr 1.0 --t-mult 2 --lr-scheduler cosine --lr-shrink 0.75 \
--warmup-updates 5000 --warmup-init-lr 1e-07 --min-lr 1e-09 --optimizer nag --lr 0.01 --clip-norm 0.1 \
--max-tokens 3072 --update-freq 3 --tokens-per-sample 3072 --seed 1 \
--sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d --fp16

## ATT K=4
# Same recipe with pseudo-vocab ratio 4.
python train.py --task language_modeling \
data-bin/wikitext103-bpe \
--save-dir checkpoints/wikitext103-bpe-kv4-att-fix \
--arch transformer_lm_wikibpe \
--restore-file checkpoints/wikitext103-bpe/checkpoint_best.pt \
--reset-optimizer --reset-dataloader --reset-meters \
--knn-keytype last_ffn_input --use-last-ffn-input --finetune-out-embed \
--pseudo-vocab-ratio 4 --criterion agg_softmax \
--max-update 100000 --max-lr 1.0 --t-mult 2 --lr-scheduler cosine --lr-shrink 0.75 \
--warmup-updates 5000 --warmup-init-lr 1e-07 --min-lr 1e-09 --optimizer nag --lr 0.01 --clip-norm 0.1 \
--max-tokens 3072 --update-freq 3 --tokens-per-sample 3072 --seed 1 \
--sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d --fp16

## ATT K=7
# Same recipe with pseudo-vocab ratio 7; warmup is 10000 here vs 5000 above.
python train.py --task language_modeling \
data-bin/wikitext103-bpe \
--save-dir checkpoints/wikitext103-bpe-kv7-att-fix \
--arch transformer_lm_wikibpe \
--restore-file checkpoints/wikitext103-bpe/checkpoint_best.pt \
--reset-optimizer --reset-dataloader --reset-meters \
--knn-keytype last_ffn_input --use-last-ffn-input --finetune-out-embed \
--pseudo-vocab-ratio 7 --criterion agg_softmax \
--max-update 100000 --max-lr 1.0 --t-mult 2 --lr-scheduler cosine --lr-shrink 0.75 \
--warmup-updates 10000 --warmup-init-lr 1e-07 --min-lr 1e-09 --optimizer nag --lr 0.01 --clip-norm 0.1 \
--max-tokens 3072 --update-freq 3 --tokens-per-sample 3072 --seed 1 \
--sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d --fp16

## ATT K=3 fixed lr
# K=3 variant with a fixed learning rate (no cosine schedule).
# NOTE(review): reuses the same --save-dir as the "ATT K=3" run above, so
# running both overwrites that run's checkpoints — confirm intended.
python train.py --task language_modeling \
data-bin/wikitext103-bpe \
--save-dir checkpoints/wikitext103-bpe-kv3-att-fix \
--arch transformer_lm_wikibpe \
--restore-file checkpoints/wikitext103-bpe/checkpoint_best.pt \
--reset-optimizer --reset-dataloader --reset-meters \
--knn-keytype last_ffn_input --use-last-ffn-input --finetune-out-embed \
--pseudo-vocab-ratio 3 --criterion agg_softmax \
--max-update 286000 --optimizer nag --lr 5e-2 --clip-norm 100 \
--max-tokens 9216 --update-freq 1 --tokens-per-sample 3072 --seed 1 \
--sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d --fp16

## evaluate all
# NOTE(review): despite the header, this evaluates only the K=6 checkpoint;
# other K values need their own eval_lm invocations.
CUDA_VISIBLE_DEVICES=6 python eval_lm.py data-bin/wikitext103-bpe --path checkpoints/wikitext103-bpe-kv6-att-fix/checkpoint_best.pt \
--sample-break-mode complete --max-tokens 3072 --context-window 2560 --softmax-batch 1024 \
--gen-subset valid --bpe subword_nmt --remove-bpe --pseudo-vocab-ratio 6 \
--model-overrides "{'knn_keytype': 'last_ffn_input', 'use_last_ffn_input': True}" \
--save-scores kv6_att_finetune_scores.npy

0 comments on commit 0522639

Please sign in to comment.