Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
frankxu2004 committed Apr 24, 2022
1 parent c19d7e8 commit 9d981f9
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 1 deletion.
2 changes: 2 additions & 0 deletions fairseq/models/transformer_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ def add_args(parser):
# args for KNN-Distill
parser.add_argument('--pseudo-vocab-ratio', type=int, default=1,
help='k|V| output embedding matrix')
parser.add_argument('--preserve-out-embed', default=False, action='store_true',
help='load output embed from the checkpoint')
parser.add_argument('--additional-linear', default=False, action='store_true',
help='add additional output embedding for last_ffn_input')
parser.add_argument('--use-l2', action='store_true',
Expand Down
2 changes: 1 addition & 1 deletion fairseq/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def load_checkpoint(
if self.args.additional_linear:
strict = False
# remove pretrained output embedding if using kv options
if self.args.pseudo_vocab_ratio > 1:
if self.args.pseudo_vocab_ratio > 1 and not self.args.preserve_out_embed:
del state["model"]['decoder.embed_out']
strict = False
self.get_model().load_state_dict(
Expand Down
78 changes: 78 additions & 0 deletions wikitext_bpe_alteval.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# LM 3072, context-window 2560, none
#
# Score WikiText-103 (BPE) with the base and the overfitted checkpoints,
# with and without kNN retrieval, saving per-token scores under
# overfit_analysis/ for later comparison.
#
# NOTE: every variable value below is whitespace- and glob-free, so the
# unquoted $VAR expansions split into exactly the intended argument lists.

DATA=data-bin/wikitext103-bpe
CKPT=checkpoints/wikitext103-bpe/checkpoint_best.pt
OVERFIT_CKPT=checkpoints/wikitext103-bpe-overfit-new/checkpoint129.pt
DSTORE=checkpoints/wikitext103-bpe/dstore
OUT=overfit_analysis
EVAL="--sample-break-mode none --max-tokens 3072 --softmax-batch 1024"
BPE="--bpe subword_nmt --remove-bpe"
KNN="--k 1024 --lmbda 0.25 --dstore-size 153225485 --knn-keytype last_ffn_input"
KNN_RUN="--probe 32 --knnlm --fp16 --dstore-fp16"

## eval train
python eval_lm.py $DATA \
    --path $CKPT \
    $EVAL \
    --gen-subset train $BPE \
    --save-tokens $OUT/train_tokens.npy --save-scores $OUT/train_lm_scores.npy

## eval valid
python eval_lm.py $DATA \
    --path $CKPT \
    $EVAL \
    --gen-subset valid $BPE \
    --save-tokens $OUT/tokens.npy --save-scores $OUT/lm_scores.npy

## eval with KNN (pruned index; do_not_recomp_l2 sim, --no-load-keys)
python eval_lm.py $DATA \
    --path $CKPT \
    $EVAL \
    --gen-subset valid --dstore-filename $DSTORE \
    --indexfile checkpoints/wikitext103-bpe/knn_prune.index \
    --model-overrides "{'knn_keytype': 'last_ffn_input'}" \
    $KNN \
    --knn-sim-func "do_not_recomp_l2" --no-load-keys \
    $KNN_RUN $BPE \
    --save-knn-scores $OUT/knn_scores.npy

## eval with KNN recomp (pruned index; keys loaded, default sim)
python eval_lm.py $DATA \
    --path $CKPT \
    $EVAL \
    --gen-subset valid --dstore-filename $DSTORE \
    --indexfile checkpoints/wikitext103-bpe/knn_prune.index \
    --model-overrides "{'knn_keytype': 'last_ffn_input'}" \
    $KNN \
    $KNN_RUN $BPE \
    --save-knn-scores $OUT/knn_recomp_scores.npy

## eval with KNN ip (inner-product index; --no-load-keys)
python eval_lm.py $DATA \
    --path $CKPT \
    $EVAL \
    --gen-subset valid --dstore-filename $DSTORE \
    --indexfile checkpoints/wikitext103-bpe/knn_ip.index --faiss-metric-type ip \
    --model-overrides "{'knn_keytype': 'last_ffn_input'}" \
    $KNN \
    --no-load-keys \
    $KNN_RUN $BPE \
    --save-knn-scores $OUT/knn_ip_scores.npy

## eval with KNN ip recomp (inner-product index; "dot" sim, keys loaded)
python eval_lm.py $DATA \
    --path $CKPT \
    $EVAL \
    --gen-subset valid --dstore-filename $DSTORE \
    --indexfile checkpoints/wikitext103-bpe/knn_ip.index --faiss-metric-type ip \
    --model-overrides "{'knn_keytype': 'last_ffn_input'}" \
    $KNN \
    --knn-sim-func "dot" \
    $KNN_RUN $BPE \
    --save-knn-scores $OUT/knn_ip_recomp_scores.npy


# overfitted model
## eval train
python eval_lm.py $DATA \
    --path $OVERFIT_CKPT \
    $EVAL \
    --gen-subset train $BPE \
    --save-scores $OUT/train_overfit129_lm_scores.npy

## eval valid
python eval_lm.py $DATA \
    --path $OVERFIT_CKPT \
    $EVAL \
    --gen-subset valid $BPE \
    --save-scores $OUT/overfit129_lm_scores.npy

0 comments on commit 9d981f9

Please sign in to comment.