Commit
add mixture of softmax update
frankxu2004 committed May 10, 2022
1 parent 5a00008 commit 2b699c2
Showing 4 changed files with 74 additions and 9 deletions.
10 changes: 9 additions & 1 deletion analysis/kv_analysis.py
@@ -1,5 +1,4 @@
import math
from tqdm import tqdm

import numpy as np
import torch
@@ -61,6 +60,15 @@
'kv5_att_finetune_scores.npy',
'kv6_att_finetune_scores.npy',
'kv9_att_finetune_scores.npy',
'kv3_att_finetune_new_scores.npy',
'mos_scores/mos2_att_embed_finetune.npy',
'mos_scores/mos3_att_embed_finetune.npy',
'mos_scores/mos2_att_finetune.npy',
'mos_scores/mos3_att_finetune.npy',
'mos_scores/mos2_finetune.npy',
'mos_scores/mos3_finetune.npy',
'mos_scores/mos4_finetune.npy',
'mos_scores/mos5_finetune.npy',
]:
extra_scores = np.load(f)
extra_scores = torch.from_numpy(extra_scores).cuda()
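
The loop above loads each listed .npy score file and moves it onto the GPU as extra_scores. As a hedged illustration of the kind of per-token comparison such arrays make possible (this is not the elided body of kv_analysis.py, and it assumes each file stores one natural-log probability per evaluation token, as written by eval_lm.py --save-scores):

import numpy as np

# Assumption: both files hold natural-log per-token probabilities over the same
# evaluation tokens; pairing these two particular files is only an example.
base = np.load('kv3_att_finetune_new_scores.npy')
extra = np.load('mos_scores/mos3_att_finetune.npy')

print('base ppl:', np.exp(-base.mean()))
print('MoS  ppl:', np.exp(-extra.mean()))

# Fraction of tokens on which the MoS model assigns higher probability.
print('win rate:', (extra > base).mean())
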
2 changes: 2 additions & 0 deletions fairseq/options.py
@@ -512,6 +512,8 @@ def add_eval_lm_args(parser):
help='Use GPU faiss')
parser.add_argument('--pseudo-vocab-ratio', type=int, default=1,
help='k|V| output embedding matrix')
parser.add_argument('--k-mos', type=int, default=1,
help='k mixture of softmax')
parser.add_argument('--use-last-ffn-input', action='store_true',
help='if set, use last ffn input to multiply weight matrix')
parser.add_argument('--load-centroids', type=str, default='',
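
The new --k-mos flag selects the number of softmax components in the output layer. For readers unfamiliar with the technique, a minimal mixture-of-softmaxes head in the spirit of Yang et al. (2018) could look like the sketch below; this is an illustrative reimplementation with invented class and parameter names, not the code this commit modifies.

import torch
import torch.nn as nn
import torch.nn.functional as F

class MixtureOfSoftmaxHead(nn.Module):
    """Illustrative k-component mixture-of-softmaxes output layer (hypothetical)."""

    def __init__(self, hidden_dim, vocab_size, k_mos):
        super().__init__()
        self.k_mos = k_mos
        self.prior = nn.Linear(hidden_dim, k_mos)                        # mixture weights pi_1..pi_k
        self.latent = nn.Linear(hidden_dim, k_mos * hidden_dim)          # per-component hidden states
        self.out_embed = nn.Linear(hidden_dim, vocab_size, bias=False)   # shared output embedding

    def forward(self, hidden):                                           # hidden: (batch, hidden_dim)
        pi = F.softmax(self.prior(hidden), dim=-1)                       # (batch, k)
        h = torch.tanh(self.latent(hidden))                              # (batch, k * hidden_dim)
        h = h.view(hidden.size(0), self.k_mos, -1)                       # (batch, k, hidden_dim)
        comp = F.softmax(self.out_embed(h), dim=-1)                      # (batch, k, vocab)
        probs = (pi.unsqueeze(-1) * comp).sum(dim=1)                     # (batch, vocab)
        return probs.clamp_min(1e-9).log()                               # log-probabilities

In this sketch, setting k_mos to 1 collapses the mixture to a single softmax component.
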
46 changes: 41 additions & 5 deletions wikitext_bpe_att_mos_finetune.sh
@@ -11,20 +11,56 @@ python train.py --task language_modeling \


# ATT MOS k=3 + finetune output embedding
- CUDA_VISIBLE_DEVICES=0,1,2,5,6,7 python train.py --task language_modeling \
+ python train.py --task language_modeling \
data-bin/wikitext103-bpe \
--save-dir checkpoints/wikitext103-bpe-mos3-att-embed-finetune \
--arch transformer_lm_wikibpe --restore-file checkpoints/wikitext103-bpe/checkpoint_best.pt \
--reset-optimizer --reset-dataloader --reset-meters \
--finetune-mos --finetune-out-embed --k-mos 3 --knn-keytype last_ffn_input --use-last-ffn-input \
--max-update 286000 --optimizer nag --lr 5e-2 --clip-norm 100 \
- --max-tokens 12288 --update-freq 1 --tokens-per-sample 3072 --seed 1 \
+ --max-tokens 9216 --update-freq 1 --tokens-per-sample 3072 --seed 1 \
--sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d --fp16

# ATT MOS k=2 + finetune output embedding
python train.py --task language_modeling \
data-bin/wikitext103-bpe \
--save-dir checkpoints/wikitext103-bpe-mos2-att-embed-finetune \
--arch transformer_lm_wikibpe --restore-file checkpoints/wikitext103-bpe/checkpoint_best.pt \
--reset-optimizer --reset-dataloader --reset-meters \
--finetune-mos --finetune-out-embed --k-mos 2 --knn-keytype last_ffn_input --use-last-ffn-input \
--max-update 286000 --optimizer nag --lr 1e-3 --clip-norm 100 \
--max-tokens 9216 --update-freq 1 --tokens-per-sample 3072 --seed 1 \
--sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d --fp16

## eval
python eval_lm.py data-bin/wikitext103-bpe \
- --path checkpoints/wikitext103-bpe-kv2-fix/checkpoint_best.pt \
+ --path checkpoints/wikitext103-bpe-mos3-att-finetune/checkpoint_best.pt \
--sample-break-mode complete --max-tokens 3072 \
--context-window 2560 --softmax-batch 1024 --k-mos 3 \
--model-overrides "{'knn_keytype': 'last_ffn_input', 'use_last_ffn_input': True}" \
--gen-subset valid --bpe subword_nmt --remove-bpe \
--save-scores mos_scores/mos3_att_finetune.npy

python eval_lm.py data-bin/wikitext103-bpe \
--path checkpoints/wikitext103-bpe-mos2-att-finetune/checkpoint_best.pt \
--sample-break-mode complete --max-tokens 3072 \
--context-window 2560 --softmax-batch 1024 --k-mos 2 \
--model-overrides "{'knn_keytype': 'last_ffn_input', 'use_last_ffn_input': True}" \
--gen-subset valid --bpe subword_nmt --remove-bpe \
--save-scores mos_scores/mos2_att_finetune.npy

python eval_lm.py data-bin/wikitext103-bpe \
--path checkpoints/wikitext103-bpe-mos3-att-embed-finetune/checkpoint_best.pt \
--sample-break-mode complete --max-tokens 3072 \
--context-window 2560 --softmax-batch 1024 --k-mos 3 \
--model-overrides "{'knn_keytype': 'last_ffn_input', 'use_last_ffn_input': True}" \
--gen-subset valid --bpe subword_nmt --remove-bpe \
--save-scores mos_scores/mos3_att_embed_finetune.npy

python eval_lm.py data-bin/wikitext103-bpe \
--path checkpoints/wikitext103-bpe-mos2-att-embed-finetune/checkpoint_best.pt \
--sample-break-mode complete --max-tokens 3072 \
- --context-window 2560 --softmax-batch 1024 --pseudo-vocab-ratio 2 \
- --gen-subset valid --bpe subword_nmt --remove-bpe
+ --context-window 2560 --softmax-batch 1024 --k-mos 2 \
+ --model-overrides "{'knn_keytype': 'last_ffn_input', 'use_last_ffn_input': True}" \
+ --gen-subset valid --bpe subword_nmt --remove-bpe \
+ --save-scores mos_scores/mos2_att_embed_finetune.npy
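
The training commands in this script restore checkpoints/wikitext103-bpe/checkpoint_best.pt and pass --finetune-mos --finetune-out-embed, which by their names suggest that only the mixture-of-softmax head and the output embedding are trained while the rest of the pretrained transformer stays fixed. A rough sketch of that kind of selective fine-tuning is shown below; the parameter-name filter is hypothetical and not the criterion the repository actually uses.

# Hedged sketch: freeze everything except parameters whose names look like
# output-embedding or MoS parameters. The name filter is invented for
# illustration; the repository's flags handle this inside its training code.
def freeze_for_mos_finetune(model):
    for name, param in model.named_parameters():
        param.requires_grad = ('out_embed' in name) or ('mos' in name)

# The optimizer then only receives the trainable subset, e.g.:
# params = [p for p in model.parameters() if p.requires_grad]
# optimizer = torch.optim.SGD(params, lr=5e-2)   # stand-in for fairseq's nag optimizer
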
25 changes: 22 additions & 3 deletions wikitext_bpe_mos_finetune.sh
@@ -12,7 +12,26 @@ python train.py --task language_modeling \

## eval
python eval_lm.py data-bin/wikitext103-bpe \
- --path checkpoints/wikitext103-bpe-kv2-fix/checkpoint_best.pt \
+ --path checkpoints/wikitext103-bpe-mos3-finetune/checkpoint_best.pt \
--sample-break-mode complete --max-tokens 3072 \
- --context-window 2560 --softmax-batch 1024 --pseudo-vocab-ratio 2 \
- --gen-subset valid --bpe subword_nmt --remove-bpe
+ --context-window 2560 --softmax-batch 1024 --k-mos 3 \
+ --gen-subset valid --bpe subword_nmt --remove-bpe --save-scores mos_scores/mos3_finetune.npy


python eval_lm.py data-bin/wikitext103-bpe \
--path checkpoints/wikitext103-bpe-mos2-finetune/checkpoint_best.pt \
--sample-break-mode complete --max-tokens 3072 \
--context-window 2560 --softmax-batch 1024 --k-mos 2 \
--gen-subset valid --bpe subword_nmt --remove-bpe --save-scores mos_scores/mos2_finetune.npy

python eval_lm.py data-bin/wikitext103-bpe \
--path checkpoints/wikitext103-bpe-mos4-finetune/checkpoint_best.pt \
--sample-break-mode complete --max-tokens 3072 \
--context-window 2560 --softmax-batch 1024 --k-mos 4 \
--gen-subset valid --bpe subword_nmt --remove-bpe --save-scores mos_scores/mos4_finetune.npy

python eval_lm.py data-bin/wikitext103-bpe \
--path checkpoints/wikitext103-bpe-mos5-finetune/checkpoint_best.pt \
--sample-break-mode complete --max-tokens 3072 \
--context-window 2560 --softmax-batch 1024 --k-mos 5 \
--gen-subset valid --bpe subword_nmt --remove-bpe --save-scores mos_scores/mos5_finetune.npy
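
Each eval command above writes its per-token scores to mos_scores/. Assuming those arrays hold natural-log per-token probabilities (an assumption about the --save-scores output, consistent with how kv_analysis.py loads them), the effect of the mixture size k can be summarized in a few lines:

import numpy as np

# Hypothetical summary over the k values evaluated above.
for k in (2, 3, 4, 5):
    scores = np.load(f'mos_scores/mos{k}_finetune.npy')
    print(f'k={k}: valid ppl = {np.exp(-scores.mean()):.2f}')
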
