Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
frankxu2004 committed Apr 17, 2022
1 parent b2f45c8 commit 5e64775
Show file tree
Hide file tree
Showing 7 changed files with 60 additions and 17 deletions.
23 changes: 15 additions & 8 deletions analysis/kv_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,21 @@
count = len(tokens) - skipped_toks

knn_helping = 0
with open('variants_interpolation.txt', 'w') as outfile:
with open('new_variants_interpolation.txt', 'w') as outfile:
for f in ['best_knn_only_scores.npy',
'seed3_scores.npy',
'kv_scores/2v_scores.npy',
'kv1_finetune_scores.npy',
'kv2_finetune_scores.npy',
'kv_scores/3v_scores.npy',
'additional_linear_scores/add_linear_scores.npy',
'additional_linear_scores/additional_softmax_scores_best.npy']:
'kv3_finetune_scores.npy',
'kv4_finetune_scores.npy',
'kv5_finetune_scores.npy',
'kv6_finetune_scores.npy',
'kv7_finetune_scores.npy',
'kv8_finetune_scores.npy',
'kv9_finetune_scores.npy',
'ip_recomp_knn_scores.npy',
'recomp_knn_scores.npy',
'ip_knn_scores.npy',
]:
overfit_scores = np.load(f)
overfit_scores = torch.from_numpy(overfit_scores).cuda()
combine_probs = torch.stack([lm_scores, overfit_scores], dim=0)
Expand All @@ -49,7 +56,7 @@

oracle_ppl = torch.exp(-oracle_scores.sum() / count)

if 'knn' in f:
if 'best_knn_only_scores' in f:
knn_helping = argmaxs

match_knn = torch.sum(argmaxs == knn_helping).item() / len(tokens)
Expand All @@ -60,7 +67,7 @@

knn_helping_ppl = torch.exp(knn_helping_scores / count)

extra_only_ppl = torch.exp(- overfit_scores.sum() / count)
extra_only_ppl = torch.exp(-overfit_scores.sum() / count)

best_ppl = 1e10
best_lmbda = 0
Expand Down
2 changes: 2 additions & 0 deletions fairseq/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,8 @@ def add_eval_lm_args(parser):
help='centroids scaling')
parser.add_argument('--save-scores', type=str, default='',
help='scores npy file path to save')
parser.add_argument('--save-knn-scores', type=str, default='',
help='knn scores npy file path to save')



Expand Down
2 changes: 1 addition & 1 deletion fairseq/sequence_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(self, tgt_dict, softmax_batch=None, compute_alignment=False, args=N
# one-hot coef
self.coef = AggSoftmaxCriterion.initialize_projection_matrix(tgt_dict, args.pseudo_vocab_ratio)
if torch.cuda.is_available() and not args.cpu:
self.coef = self.coef.cuda()
self.coef = self.coef.float().cuda()
if args.load_centroid_distribution:
# load prior coef
from scipy import sparse
Expand Down
4 changes: 2 additions & 2 deletions fairseq_cli/eval_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,8 +306,8 @@ def main(parsed_args):
if parsed_args.save_scores:
np.save(parsed_args.save_scores, np.concatenate(all_scores))

# if all_knn_scores:
# np.save('best_knn_only_scores.npy', np.concatenate(all_knn_scores))
if all_knn_scores and parsed_args.save_knn_scores:
np.save(parsed_args.save_knn_scores, np.concatenate(all_knn_scores))

if args.output_word_stats:
for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True):
Expand Down
2 changes: 1 addition & 1 deletion wikitext_bpe.sh
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ python eval_lm.py data-bin/wikitext103-bpe \
--k 1024 --lmbda 0.25 --dstore-size 153225485 --knn-keytype last_ffn_input \
--probe 32 --knnlm --fp16 --dstore-fp16 --bpe subword_nmt --remove-bpe

## all knn only
## lambda 0.99 nearly knn only
python eval_lm.py data-bin/wikitext103-bpe \
--path checkpoints/wikitext103-bpe/checkpoint_best.pt \
--sample-break-mode complete --max-tokens 3072 \
Expand Down
41 changes: 37 additions & 4 deletions wikitext_bpe_final_linear_knn.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@ python eval_lm.py data-bin/wikitext103-bpe \
--softmax-batch 1024 --gen-subset train \
--context-window 1536 --tokens-per-sample 1536 \
--dstore-mmap checkpoints/wikitext103-bpe/last_linear_inp/dstore --knn-keytype 'last_linear_input' \
--dstore-size 112733184 --model-overrides "{'knn_keytype': 'last_linear_input'}" \
--dstore-size 153225485 --model-overrides "{'knn_keytype': 'last_linear_input'}" \
--save-knnlm-dstore --fp16 --dstore-fp16

# build index
python build_dstore.py \
--dstore_mmap checkpoints/wikitext103-bpe/last_linear_inp/dstore \
--dstore_size 112733184 \
--dstore_size 153225485 \
--faiss_index checkpoints/wikitext103-bpe/last_linear_inp/knn.index \
--num_keys_to_add_at_a_time 500000 \
--starting_point 0 --dstore-fp16 --dimension 1024
Expand All @@ -25,7 +25,7 @@ python eval_lm.py data-bin/wikitext103-bpe \
--gen-subset valid --dstore-filename checkpoints/wikitext103-bpe/last_linear_inp/dstore \
--indexfile checkpoints/wikitext103-bpe/last_linear_inp/knn.index \
--model-overrides "{'knn_keytype': 'last_linear_input'}" \
--k 1024 --lmbda 0.25 --dstore-size 112733184 --knn-keytype last_linear_input \
--k 1024 --lmbda 0.25 --dstore-size 153225485 --knn-keytype last_linear_input \
--knn-sim-func "do_not_recomp_l2" --no-load-keys \
--probe 32 --knnlm --fp16 --dstore-fp16 --bpe subword_nmt --remove-bpe

Expand All @@ -37,5 +37,38 @@ python eval_lm.py data-bin/wikitext103-bpe \
--gen-subset valid --dstore-filename checkpoints/wikitext103-bpe/last_linear_inp/dstore \
--indexfile checkpoints/wikitext103-bpe/last_linear_inp/knn.index \
--model-overrides "{'knn_keytype': 'last_linear_input'}" \
--k 1024 --lmbda 0.25 --dstore-size 112733184 --knn-keytype last_linear_input \
--k 1024 --lmbda 0.25 --dstore-size 153225485 --knn-keytype last_linear_input \
--probe 32 --knnlm --fp16 --dstore-fp16 --bpe subword_nmt --remove-bpe

## USE IP metric
# build index
python build_dstore.py \
--dstore_mmap checkpoints/wikitext103-bpe/last_linear_inp/dstore \
--dstore_size 153225485 \
--faiss_index checkpoints/wikitext103-bpe/last_linear_inp/knn_ip.index \
--num_keys_to_add_at_a_time 500000 \
--starting_point 0 --dstore-fp16 --dimension 1024 --metric ip

# no recompute
python eval_lm.py data-bin/wikitext103-bpe \
--path checkpoints/wikitext103-bpe/checkpoint_best.pt \
--sample-break-mode complete --max-tokens 3072 \
--context-window 2560 --softmax-batch 1024 \
--gen-subset valid --dstore-filename checkpoints/wikitext103-bpe/last_linear_inp/dstore \
--indexfile checkpoints/wikitext103-bpe/last_linear_inp/knn.index --faiss-metric-type ip \
--model-overrides "{'knn_keytype': 'last_linear_input'}" \
--k 1024 --lmbda 0.25 --dstore-size 153225485 --knn-keytype last_linear_input \
--no-load-keys \
--probe 32 --knnlm --fp16 --dstore-fp16 --bpe subword_nmt --remove-bpe

# recompute
python eval_lm.py data-bin/wikitext103-bpe \
--path checkpoints/wikitext103-bpe/checkpoint_best.pt \
--sample-break-mode complete --max-tokens 3072 \
--context-window 2560 --softmax-batch 1024 \
--gen-subset valid --dstore-filename checkpoints/wikitext103-bpe/last_linear_inp/dstore \
--indexfile checkpoints/wikitext103-bpe/last_linear_inp/knn.index --faiss-metric-type ip \
--model-overrides "{'knn_keytype': 'last_linear_input'}" \
--k 1024 --lmbda 0.25 --dstore-size 153225485 --knn-keytype last_linear_input \
--knn-sim-func "dot" \
--probe 32 --knnlm --fp16 --dstore-fp16 --bpe subword_nmt --remove-bpe
3 changes: 2 additions & 1 deletion wikitext_bpe_ip.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ python build_dstore.py \
--num_keys_to_add_at_a_time 500000 \
--starting_point 0 --dstore-fp16 --dimension 1024 --metric ip


# no recompute
python eval_lm.py data-bin/wikitext103-bpe \
--path checkpoints/wikitext103-bpe/checkpoint_best.pt \
--sample-break-mode complete --max-tokens 3072 \
Expand All @@ -20,6 +20,7 @@ python eval_lm.py data-bin/wikitext103-bpe \
--no-load-keys \
--probe 32 --knnlm --fp16 --dstore-fp16 --bpe subword_nmt --remove-bpe

# recompute
python eval_lm.py data-bin/wikitext103-bpe \
--path checkpoints/wikitext103-bpe/checkpoint_best.pt \
--sample-break-mode complete --max-tokens 3072 \
Expand Down

0 comments on commit 5e64775

Please sign in to comment.