diff --git a/analysis/kv_analysis.py b/analysis/kv_analysis.py index de8f8bb..880dba5 100644 --- a/analysis/kv_analysis.py +++ b/analysis/kv_analysis.py @@ -79,12 +79,35 @@ # 'epoch_scores/3v-att-init-finetune-epoch2.npy', # 'epoch_scores/3v-att-init-finetune-epoch3.npy', # 'epoch_scores/3v-att-init-finetune-epoch4.npy', - 'interpolated_loss_scores/kv1_interpolated_scores.npy', - 'interpolated_loss_scores/kv1_att_interpolated_scores.npy', - 'interpolated_loss_scores/kv3_interpolated_scores.npy', - 'interpolated_loss_scores/kv3_att_interpolated_scores.npy', - 'weighted_loss_scores/kv1_weighted_scores.npy', - 'weighted_loss_scores/kv1_att_weighted_scores.npy', + # 'interpolated_loss_scores/kv1_interpolated_scores.npy', + # 'interpolated_loss_scores/kv1_att_interpolated_scores.npy', + # 'interpolated_loss_scores/kv3_interpolated_scores.npy', + # 'interpolated_loss_scores/kv3_att_interpolated_scores.npy', + # 'weighted_loss_scores/kv1_weighted_scores.npy', + # 'weighted_loss_scores/kv1_att_weighted_scores.npy', + '0.05_knn_scores.npy', + '0.1_knn_scores.npy', + '0.2_knn_scores.npy', + '0.3_knn_scores.npy', + '0.4_knn_scores.npy', + '0.5_knn_scores.npy', + '0.6_knn_scores.npy', + '0.7_knn_scores.npy', + '0.8_knn_scores.npy', + '0.9_knn_scores.npy', + '0.05_knn_recomp_scores.npy', + '0.1_knn_recomp_scores.npy', + '0.2_knn_recomp_scores.npy', + '0.3_knn_recomp_scores.npy', + '0.4_knn_recomp_scores.npy', + '0.5_knn_recomp_scores.npy', + '0.6_knn_recomp_scores.npy', + '0.7_knn_recomp_scores.npy', + '0.8_knn_recomp_scores.npy', + '0.9_knn_recomp_scores.npy', + '0.05_knn_scores_2.npy', + '0.1_knn_scores_2.npy', + '0.2_knn_scores_2.npy', ] # extra_score_files = [f'epoch_scores/3v-att-init-finetune-epoch{e}.npy' for e in range(1, 40)] diff --git a/analysis/subsample_datastore_0.05.py b/analysis/subsample_datastore_0.05.py new file mode 100644 index 0000000..9ceaf1e --- /dev/null +++ b/analysis/subsample_datastore_0.05.py @@ -0,0 +1,26 @@ +import numpy as np + +from fairseq.data import Dictionary +np.random.seed(0) +dstore_size = 153225485 +vec_dim = 1024 +subsample_size = 7661274 + +dictionary = Dictionary.load('data-bin/wikitext103-bpe/dict.txt') +print(len(dictionary)) + +keys_from_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_keys.npy', + dtype=np.float16, mode='r', shape=(dstore_size, vec_dim)) +vals_from_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_vals.npy', + dtype=np.int64, mode='r', shape=(dstore_size, 1)) + +subsampled_idxs = np.random.choice(dstore_size, subsample_size, replace=False) + +print('sampled idx created') +subsampled_keys_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_subsampled_0.05_keys.npy', + dtype=np.float16, mode='w+', shape=(subsample_size, vec_dim)) +subsampled_vals_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_subsampled_0.05_vals.npy', + dtype=np.int64, mode='w+', shape=(subsample_size, 1)) + +subsampled_keys_memmap[:] = keys_from_memmap[subsampled_idxs] +subsampled_vals_memmap[:] = vals_from_memmap[subsampled_idxs] diff --git a/analysis/subsample_datastore_0.1.py b/analysis/subsample_datastore_0.1.py new file mode 100644 index 0000000..2cee783 --- /dev/null +++ b/analysis/subsample_datastore_0.1.py @@ -0,0 +1,26 @@ +import numpy as np + +from fairseq.data import Dictionary +np.random.seed(0) +dstore_size = 153225485 +vec_dim = 1024 +subsample_size = 15322548 + +dictionary = Dictionary.load('data-bin/wikitext103-bpe/dict.txt') +print(len(dictionary)) + +keys_from_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_keys.npy', + dtype=np.float16, mode='r', shape=(dstore_size, vec_dim)) +vals_from_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_vals.npy', + dtype=np.int64, mode='r', shape=(dstore_size, 1)) + +subsampled_idxs = np.random.choice(dstore_size, subsample_size, replace=False) + +print('sampled idx created') +subsampled_keys_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_subsampled_0.1_keys.npy', + dtype=np.float16, mode='w+', shape=(subsample_size, vec_dim)) +subsampled_vals_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_subsampled_0.1_vals.npy', + dtype=np.int64, mode='w+', shape=(subsample_size, 1)) + +subsampled_keys_memmap[:] = keys_from_memmap[subsampled_idxs] +subsampled_vals_memmap[:] = vals_from_memmap[subsampled_idxs] diff --git a/analysis/subsample_datastore_0.2.py b/analysis/subsample_datastore_0.2.py new file mode 100644 index 0000000..1476ef8 --- /dev/null +++ b/analysis/subsample_datastore_0.2.py @@ -0,0 +1,21 @@ +import numpy as np +np.random.seed(0) +dstore_size = 153225485 +vec_dim = 1024 +subsample_size = 30645096 + +keys_from_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_keys.npy', + dtype=np.float16, mode='r', shape=(dstore_size, vec_dim)) +vals_from_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_vals.npy', + dtype=np.int64, mode='r', shape=(dstore_size, 1)) + +subsampled_idxs = np.random.choice(dstore_size, subsample_size, replace=False) + +print('sampled idx created') +subsampled_keys_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_subsampled_0.2_keys.npy', + dtype=np.float16, mode='w+', shape=(subsample_size, vec_dim)) +subsampled_vals_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_subsampled_0.2_vals.npy', + dtype=np.int64, mode='w+', shape=(subsample_size, 1)) + +subsampled_keys_memmap[:] = keys_from_memmap[subsampled_idxs] +subsampled_vals_memmap[:] = vals_from_memmap[subsampled_idxs] diff --git a/analysis/subsample_datastore_0.3.py b/analysis/subsample_datastore_0.3.py new file mode 100644 index 0000000..785f94b --- /dev/null +++ b/analysis/subsample_datastore_0.3.py @@ -0,0 +1,21 @@ +import numpy as np +np.random.seed(0) +dstore_size = 153225485 +vec_dim = 1024 +subsample_size = 45967645 + +keys_from_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_keys.npy', + dtype=np.float16, mode='r', shape=(dstore_size, vec_dim)) +vals_from_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_vals.npy', + dtype=np.int64, mode='r', shape=(dstore_size, 1)) + +subsampled_idxs = np.random.choice(dstore_size, subsample_size, replace=False) + +print('sampled idx created') +subsampled_keys_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_subsampled_0.3_keys.npy', + dtype=np.float16, mode='w+', shape=(subsample_size, vec_dim)) +subsampled_vals_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_subsampled_0.3_vals.npy', + dtype=np.int64, mode='w+', shape=(subsample_size, 1)) + +subsampled_keys_memmap[:] = keys_from_memmap[subsampled_idxs] +subsampled_vals_memmap[:] = vals_from_memmap[subsampled_idxs] diff --git a/analysis/subsample_datastore_0.4.py b/analysis/subsample_datastore_0.4.py new file mode 100644 index 0000000..cddc283 --- /dev/null +++ b/analysis/subsample_datastore_0.4.py @@ -0,0 +1,21 @@ +import numpy as np +np.random.seed(0) +dstore_size = 153225485 +vec_dim = 1024 +subsample_size = 61290194 + +keys_from_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_keys.npy', + dtype=np.float16, mode='r', shape=(dstore_size, vec_dim)) +vals_from_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_vals.npy', + dtype=np.int64, mode='r', shape=(dstore_size, 1)) + +subsampled_idxs = np.random.choice(dstore_size, subsample_size, replace=False) + +print('sampled idx created') +subsampled_keys_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_subsampled_0.4_keys.npy', + dtype=np.float16, mode='w+', shape=(subsample_size, vec_dim)) +subsampled_vals_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_subsampled_0.4_vals.npy', + dtype=np.int64, mode='w+', shape=(subsample_size, 1)) + +subsampled_keys_memmap[:] = keys_from_memmap[subsampled_idxs] +subsampled_vals_memmap[:] = vals_from_memmap[subsampled_idxs] diff --git a/analysis/subsample_datastore_0.5.py b/analysis/subsample_datastore_0.5.py new file mode 100644 index 0000000..d3b8383 --- /dev/null +++ b/analysis/subsample_datastore_0.5.py @@ -0,0 +1,21 @@ +import numpy as np +np.random.seed(0) +dstore_size = 153225485 +vec_dim = 1024 +subsample_size = 76612742 + +keys_from_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_keys.npy', + dtype=np.float16, mode='r', shape=(dstore_size, vec_dim)) +vals_from_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_vals.npy', + dtype=np.int64, mode='r', shape=(dstore_size, 1)) + +subsampled_idxs = np.random.choice(dstore_size, subsample_size, replace=False) + +print('sampled idx created') +subsampled_keys_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_subsampled_0.5_keys.npy', + dtype=np.float16, mode='w+', shape=(subsample_size, vec_dim)) +subsampled_vals_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_subsampled_0.5_vals.npy', + dtype=np.int64, mode='w+', shape=(subsample_size, 1)) + +subsampled_keys_memmap[:] = keys_from_memmap[subsampled_idxs] +subsampled_vals_memmap[:] = vals_from_memmap[subsampled_idxs] diff --git a/analysis/subsample_datastore_0.6.py b/analysis/subsample_datastore_0.6.py new file mode 100644 index 0000000..c16b87e --- /dev/null +++ b/analysis/subsample_datastore_0.6.py @@ -0,0 +1,22 @@ +import numpy as np + +np.random.seed(0) +dstore_size = 153225485 +vec_dim = 1024 +subsample_size = 91935291 + +keys_from_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_keys.npy', + dtype=np.float16, mode='r', shape=(dstore_size, vec_dim)) +vals_from_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_vals.npy', + dtype=np.int64, mode='r', shape=(dstore_size, 1)) + +subsampled_idxs = np.random.choice(dstore_size, subsample_size, replace=False) + +print('sampled idx created') +subsampled_keys_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_subsampled_0.6_keys.npy', + dtype=np.float16, mode='w+', shape=(subsample_size, vec_dim)) +subsampled_vals_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_subsampled_0.6_vals.npy', + dtype=np.int64, mode='w+', shape=(subsample_size, 1)) + +subsampled_keys_memmap[:] = keys_from_memmap[subsampled_idxs] +subsampled_vals_memmap[:] = vals_from_memmap[subsampled_idxs] diff --git a/analysis/subsample_datastore_0.7.py b/analysis/subsample_datastore_0.7.py new file mode 100644 index 0000000..aa15378 --- /dev/null +++ b/analysis/subsample_datastore_0.7.py @@ -0,0 +1,22 @@ +import numpy as np + +np.random.seed(0) +dstore_size = 153225485 +vec_dim = 1024 +subsample_size = 107257840 + +keys_from_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_keys.npy', + dtype=np.float16, mode='r', shape=(dstore_size, vec_dim)) +vals_from_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_vals.npy', + dtype=np.int64, mode='r', shape=(dstore_size, 1)) + +subsampled_idxs = np.random.choice(dstore_size, subsample_size, replace=False) + +print('sampled idx created') +subsampled_keys_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_subsampled_0.7_keys.npy', + dtype=np.float16, mode='w+', shape=(subsample_size, vec_dim)) +subsampled_vals_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_subsampled_0.7_vals.npy', + dtype=np.int64, mode='w+', shape=(subsample_size, 1)) + +subsampled_keys_memmap[:] = keys_from_memmap[subsampled_idxs] +subsampled_vals_memmap[:] = vals_from_memmap[subsampled_idxs] diff --git a/analysis/subsample_datastore_0.8.py b/analysis/subsample_datastore_0.8.py new file mode 100644 index 0000000..431cb52 --- /dev/null +++ b/analysis/subsample_datastore_0.8.py @@ -0,0 +1,22 @@ +import numpy as np + +np.random.seed(0) +dstore_size = 153225485 +vec_dim = 1024 +subsample_size = 122580388 + +keys_from_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_keys.npy', + dtype=np.float16, mode='r', shape=(dstore_size, vec_dim)) +vals_from_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_vals.npy', + dtype=np.int64, mode='r', shape=(dstore_size, 1)) + +subsampled_idxs = np.random.choice(dstore_size, subsample_size, replace=False) + +print('sampled idx created') +subsampled_keys_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_subsampled_0.8_keys.npy', + dtype=np.float16, mode='w+', shape=(subsample_size, vec_dim)) +subsampled_vals_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_subsampled_0.8_vals.npy', + dtype=np.int64, mode='w+', shape=(subsample_size, 1)) + +subsampled_keys_memmap[:] = keys_from_memmap[subsampled_idxs] +subsampled_vals_memmap[:] = vals_from_memmap[subsampled_idxs] diff --git a/analysis/subsample_datastore_0.9.py b/analysis/subsample_datastore_0.9.py new file mode 100644 index 0000000..19e9b57 --- /dev/null +++ b/analysis/subsample_datastore_0.9.py @@ -0,0 +1,22 @@ +import numpy as np + +np.random.seed(0) +dstore_size = 153225485 +vec_dim = 1024 +subsample_size = 137902936 + +keys_from_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_keys.npy', + dtype=np.float16, mode='r', shape=(dstore_size, vec_dim)) +vals_from_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_vals.npy', + dtype=np.int64, mode='r', shape=(dstore_size, 1)) + +subsampled_idxs = np.random.choice(dstore_size, subsample_size, replace=False) + +print('sampled idx created') +subsampled_keys_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_subsampled_0.9_keys.npy', + dtype=np.float16, mode='w+', shape=(subsample_size, vec_dim)) +subsampled_vals_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_subsampled_0.9_vals.npy', + dtype=np.int64, mode='w+', shape=(subsample_size, 1)) + +subsampled_keys_memmap[:] = keys_from_memmap[subsampled_idxs] +subsampled_vals_memmap[:] = vals_from_memmap[subsampled_idxs] diff --git a/build_dstore.py b/build_dstore.py index 2737eda..af49801 100644 --- a/build_dstore.py +++ b/build_dstore.py @@ -4,13 +4,13 @@ import faiss import time - parser = argparse.ArgumentParser() parser.add_argument('--dstore_mmap', type=str, help='memmap where keys and vals are stored') parser.add_argument('--dstore_size', type=int, help='number of items saved in the datastore memmap') parser.add_argument('--dimension', type=int, default=1024, help='Size of each key') parser.add_argument('--dstore_fp16', '--dstore-fp16', default=False, action='store_true') -parser.add_argument('--seed', type=int, default=1, help='random seed for sampling the subset of vectors to train the cache') +parser.add_argument('--seed', type=int, default=1, + help='random seed for sampling the subset of vectors to train the cache') parser.add_argument('--ncentroids', type=int, default=4096, help='number of centroids faiss should learn') parser.add_argument('--code_size', type=int, default=64, help='size of quantized vectors') parser.add_argument('--probe', type=int, default=8, help='number of clusters to query') @@ -25,40 +25,42 @@ print(args) if args.dstore_fp16: - keys = np.memmap(args.dstore_mmap+'_keys.npy', dtype=np.float16, mode='r', shape=(args.dstore_size, args.dimension)) - vals = np.memmap(args.dstore_mmap+'_vals.npy', dtype=np.int64, mode='r', shape=(args.dstore_size, 1)) + keys = np.memmap(args.dstore_mmap + '_keys.npy', dtype=np.float16, mode='r', + shape=(args.dstore_size, args.dimension)) + vals = np.memmap(args.dstore_mmap + '_vals.npy', dtype=np.int64, mode='r', shape=(args.dstore_size, 1)) else: - keys = np.memmap(args.dstore_mmap+'_keys.npy', dtype=np.float32, mode='r', shape=(args.dstore_size, args.dimension)) - vals = np.memmap(args.dstore_mmap+'_vals.npy', dtype=np.int64, mode='r', shape=(args.dstore_size, 1)) + keys = np.memmap(args.dstore_mmap + '_keys.npy', dtype=np.float32, mode='r', + shape=(args.dstore_size, args.dimension)) + vals = np.memmap(args.dstore_mmap + '_vals.npy', dtype=np.int64, mode='r', shape=(args.dstore_size, 1)) -if not os.path.exists(args.faiss_index+".trained"): - # Initialize faiss index - metric = faiss.METRIC_L2 if args.metric == 'l2' else faiss.METRIC_INNER_PRODUCT - quantizer = faiss.IndexFlatL2(args.dimension) if args.metric == 'l2' else faiss.IndexFlatIP(args.dimension) - index = faiss.IndexIVFPQ(quantizer, args.dimension, - args.ncentroids, args.code_size, 8, metric) - index.nprobe = args.probe +# if not os.path.exists(args.faiss_index+".trained"): +# Initialize faiss index +metric = faiss.METRIC_L2 if args.metric == 'l2' else faiss.METRIC_INNER_PRODUCT +quantizer = faiss.IndexFlatL2(args.dimension) if args.metric == 'l2' else faiss.IndexFlatIP(args.dimension) +index = faiss.IndexIVFPQ(quantizer, args.dimension, + args.ncentroids, args.code_size, 8, metric) +index.nprobe = args.probe - print('Training Index') - np.random.seed(args.seed) - random_sample = np.random.choice(np.arange(vals.shape[0]), size=[min(1000000, vals.shape[0])], replace=False) - start = time.time() - # Faiss does not handle adding keys in fp16 as of writing this. - index.train(keys[random_sample].astype(np.float32)) - print('Training took {} s'.format(time.time() - start)) +print('Training Index') +np.random.seed(args.seed) +random_sample = np.random.choice(np.arange(vals.shape[0]), size=[min(1000000, vals.shape[0])], replace=False) +start = time.time() +# Faiss does not handle adding keys in fp16 as of writing this. +index.train(keys[random_sample].astype(np.float32)) +print('Training took {} s'.format(time.time() - start)) - print('Writing index after training') - start = time.time() - faiss.write_index(index, args.faiss_index+".trained") - print('Writing index took {} s'.format(time.time()-start)) +print('Writing index after training') +start = time.time() +faiss.write_index(index, args.faiss_index + ".trained") +print('Writing index took {} s'.format(time.time() - start)) print('Adding Keys') -index = faiss.read_index(args.faiss_index+".trained") +index = faiss.read_index(args.faiss_index + ".trained") start = args.starting_point start_time = time.time() while start < args.dstore_size: - end = min(args.dstore_size, start+args.num_keys_to_add_at_a_time) + end = min(args.dstore_size, start + args.num_keys_to_add_at_a_time) to_add = keys[start:end].copy() index.add_with_ids(to_add.astype(np.float32), np.arange(start, end)) start += args.num_keys_to_add_at_a_time @@ -73,4 +75,4 @@ print('Writing Index') start_time = time.time() faiss.write_index(index, args.faiss_index) -print('Writing index took {} s'.format(time.time()-start_time)) +print('Writing index took {} s'.format(time.time() - start_time)) diff --git a/cluster/group_vecs.py b/cluster/group_vecs.py index c18f825..c0c3d8f 100644 --- a/cluster/group_vecs.py +++ b/cluster/group_vecs.py @@ -14,7 +14,6 @@ vals_from_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_vals.npy', dtype=np.int64, mode='r', shape=(dstore_size, 1)) - keys = np.zeros((dstore_size, vec_dim), dtype=np.float16) vals = np.zeros((dstore_size, 1), dtype=np.int64) @@ -24,6 +23,6 @@ vals = vals.squeeze() -#exclude 0 +# exclude 0 for word_id in tqdm.tqdm(range(1, len(dictionary))): - np.save('dstore/ids/' + str(word_id) + '.npy', keys[vals==word_id]) + np.save('dstore/ids/' + str(word_id) + '.npy', keys[vals == word_id]) diff --git a/fairseq/options.py b/fairseq/options.py index fce5898..fcbb59f 100644 --- a/fairseq/options.py +++ b/fairseq/options.py @@ -536,6 +536,8 @@ def add_eval_lm_args(parser): help='tokens npy file path to save') parser.add_argument('--save-knn-scores', type=str, default='', help='knn scores npy file path to save') + parser.add_argument('--save-queries', type=str, default='', + help='query vectors npy file path to save') # fmt: on diff --git a/fairseq/sequence_scorer.py b/fairseq/sequence_scorer.py index 50697ba..d481178 100644 --- a/fairseq/sequence_scorer.py +++ b/fairseq/sequence_scorer.py @@ -142,6 +142,8 @@ def combine_knn_and_vocab_probs(knn_p, vocab_p, coeff): orig_target.permute(1, 0), pad_idx=self.pad) yhat_knn_prob = yhat_knn_prob.permute(1, 0, 2).squeeze(-1) + queries = queries.permute(1, 0, 2) + if self.args.fp16: yhat_knn_prob = yhat_knn_prob.half() probs = probs.half() @@ -159,6 +161,7 @@ def combine_knn_and_vocab_probs(knn_p, vocab_p, coeff): avg_attn = attn else: avg_attn.add_(attn) + if len(models) > 1: avg_probs.div_(len(models)) avg_probs.log_() @@ -176,6 +179,7 @@ def combine_knn_and_vocab_probs(knn_p, vocab_p, coeff): avg_probs_i = avg_probs[i][start_idxs[i]:start_idxs[i] + tgt_len] if 'knn_dstore' in kwargs: knn_probs_i = yhat_knn_prob[i][start_idxs[i]:start_idxs[i] + tgt_len] + queries_i = queries[i][start_idxs[i]:start_idxs[i] + tgt_len] score_i = avg_probs_i.sum() / tgt_len if avg_attn is not None: avg_attn_i = avg_attn[i] @@ -200,5 +204,6 @@ def combine_knn_and_vocab_probs(knn_p, vocab_p, coeff): 'dstore_keys': decoder_out[1][self.args.knn_keytype][start_idxs[i]:, i, :] if self.args.save_knnlm_dstore else None, 'knn_probs': knn_probs_i if 'knn_dstore' in kwargs else None, + 'queries': queries_i if 'knn_dstore' in kwargs else None, }]) return hypos diff --git a/fairseq_cli/eval_lm.py b/fairseq_cli/eval_lm.py index 8fdb32f..79becf1 100644 --- a/fairseq_cli/eval_lm.py +++ b/fairseq_cli/eval_lm.py @@ -21,7 +21,6 @@ from fairseq.sequence_scorer import SequenceScorer from fairseq.knnlm import KNN_Dstore - logging.basicConfig( format='%(asctime)s | %(levelname)s | %(name)s | %(message)s', datefmt='%Y-%m-%d %H:%M:%S', @@ -166,18 +165,24 @@ def main(parsed_args): print('keytype being saved:', args.knn_keytype) if args.dstore_fp16: print('Saving fp16') - dstore_keys = np.memmap(args.dstore_mmap+'_keys.npy', dtype=np.float16, mode='w+', shape=(args.dstore_size, args.decoder_embed_dim)) - dstore_vals = np.memmap(args.dstore_mmap+'_vals.npy', dtype=np.int64, mode='w+', shape=(args.dstore_size, 1)) + dstore_keys = np.memmap(args.dstore_mmap + '_keys.npy', dtype=np.float16, mode='w+', + shape=(args.dstore_size, args.decoder_embed_dim)) + dstore_vals = np.memmap(args.dstore_mmap + '_vals.npy', dtype=np.int64, mode='w+', + shape=(args.dstore_size, 1)) else: print('Saving fp32') - dstore_keys = np.memmap(args.dstore_mmap+'_keys.npy', dtype=np.float32, mode='w+', shape=(args.dstore_size, args.decoder_embed_dim)) - dstore_vals = np.memmap(args.dstore_mmap+'_vals.npy', dtype=np.int64, mode='w+', shape=(args.dstore_size, 1)) + dstore_keys = np.memmap(args.dstore_mmap + '_keys.npy', dtype=np.float32, mode='w+', + shape=(args.dstore_size, args.decoder_embed_dim)) + dstore_vals = np.memmap(args.dstore_mmap + '_vals.npy', dtype=np.int64, mode='w+', + shape=(args.dstore_size, 1)) dstore_idx = 0 all_token_ids = [] all_scores = [] all_knn_scores = [] + all_queries = [] + for ex_i, sample in enumerate(t): if 'net_input' not in sample: continue @@ -191,8 +196,7 @@ def main(parsed_args): hypos = scorer.generate(models, sample) gen_timer.stop(sample['ntokens']) - - for i, hypos_i in enumerate(hypos): + for idx, hypos_i in enumerate(hypos): hypo = hypos_i[0] if args.save_knnlm_dstore: shape = hypo['dstore_keys'].shape @@ -203,20 +207,20 @@ def main(parsed_args): hypo['dstore_keys'] = hypo['dstore_keys'][:shape[0]] hypo['tokens'] = hypo['tokens'][:shape[0]] if args.dstore_fp16: - dstore_keys[dstore_idx:shape[0]+dstore_idx] = hypo['dstore_keys'].view( + dstore_keys[dstore_idx:shape[0] + dstore_idx] = hypo['dstore_keys'].view( -1, args.decoder_embed_dim).cpu().numpy().astype(np.float16) - dstore_vals[dstore_idx:shape[0]+dstore_idx] = hypo['tokens'].view( + dstore_vals[dstore_idx:shape[0] + dstore_idx] = hypo['tokens'].view( -1, 1).cpu().numpy().astype(np.int64) else: - dstore_keys[dstore_idx:shape[0]+dstore_idx] = hypo['dstore_keys'].view( + dstore_keys[dstore_idx:shape[0] + dstore_idx] = hypo['dstore_keys'].view( -1, args.decoder_embed_dim).cpu().numpy().astype(np.float32) - dstore_vals[dstore_idx:shape[0]+dstore_idx] = hypo['tokens'].view( + dstore_vals[dstore_idx:shape[0] + dstore_idx] = hypo['tokens'].view( -1, 1).cpu().numpy().astype(np.int64) dstore_idx += shape[0] else: print('Skipping this one with shape', shape) - sample_id = sample['id'][i] + sample_id = sample['id'][idx] tokens = hypo['tokens'] tgt_len = tokens.numel() @@ -225,6 +229,8 @@ def main(parsed_args): if args.knnlm: knn_scores = hypo['knn_probs'].float() all_knn_scores.append(knn_scores.cpu().numpy()) + queries = hypo['queries'].float() + all_queries.append(queries.cpu().numpy()) if args.add_bos_token: assert hypo['tokens'][0].item() == task.target_dictionary.bos() @@ -245,8 +251,8 @@ def main(parsed_args): pos_scores[i + 1] += pos_scores[i] pos_scores[i] = 0 - #inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf')) - #if inf_scores.any(): + # inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf')) + # if inf_scores.any(): # logger.info( # 'skipping tokens with inf scores:', # task.target_dictionary.string(tokens[inf_scores.nonzero()]) @@ -298,7 +304,7 @@ def main(parsed_args): gen_timer.n, gen_timer.sum, 1. / gen_timer.avg )) logger.info('Loss (base 2): {:.4f}, Perplexity: {:.2f}'.format( - avg_nll_loss, 2**avg_nll_loss + avg_nll_loss, 2 ** avg_nll_loss )) # saving @@ -308,6 +314,8 @@ def main(parsed_args): np.save(parsed_args.save_scores, np.concatenate(all_scores)) if all_knn_scores and parsed_args.save_knn_scores: np.save(parsed_args.save_knn_scores, np.concatenate(all_knn_scores)) + if all_queries and parsed_args.save_queries: + np.save(parsed_args.save_queries, np.concatenate(all_queries)) if args.output_word_stats: for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True): diff --git a/wikitext_bpe_memknn.sh b/wikitext_bpe_memknn.sh new file mode 100644 index 0000000..abd4203 --- /dev/null +++ b/wikitext_bpe_memknn.sh @@ -0,0 +1,78 @@ +# 0.05 +python analysis/subsample_datastore_0.05.py + +# build index +python build_dstore.py \ + --dstore_mmap checkpoints/wikitext103-bpe/dstore_subsampled_0.05 \ + --dstore_size 7661274 \ + --faiss_index checkpoints/wikitext103-bpe/knn_subsampled_0.05.index \ + --num_keys_to_add_at_a_time 500000 \ + --starting_point 0 --dstore-fp16 --dimension 1024 + +# eval with index +# no recompute +python eval_lm.py data-bin/wikitext103-bpe \ + --path checkpoints/wikitext103-bpe/checkpoint_best.pt \ + --sample-break-mode complete --max-tokens 3072 \ + --context-window 2560 --softmax-batch 1024 \ + --gen-subset valid --dstore-filename checkpoints/wikitext103-bpe/dstore_subsampled_0.05 \ + --indexfile checkpoints/wikitext103-bpe/knn_subsampled_0.05.index \ + --model-overrides "{'knn_keytype': 'last_ffn_input'}" \ + --k 1024 --lmbda 0.25 --dstore-size 7661274 --knn-keytype last_ffn_input \ + --knn-sim-func "do_not_recomp_l2" --no-load-keys \ + --probe 32 --knnlm --fp16 --dstore-fp16 --bpe subword_nmt --remove-bpe --save-knn-scores 0.05_knn_scores_2.npy \ + --save-queries all_queries.npy + + +# 0.1 +python analysis/subsample_datastore_0.1.py + +# build index +python build_dstore.py \ + --dstore_mmap checkpoints/wikitext103-bpe/dstore_subsampled_0.1 \ + --dstore_size 15322548 \ + --faiss_index checkpoints/wikitext103-bpe/knn_subsampled_0.1.index \ + --num_keys_to_add_at_a_time 500000 \ + --starting_point 0 --dstore-fp16 --dimension 1024 + +# eval with index +# no recompute +python eval_lm.py data-bin/wikitext103-bpe \ + --path checkpoints/wikitext103-bpe/checkpoint_best.pt \ + --sample-break-mode complete --max-tokens 3072 \ + --context-window 2560 --softmax-batch 1024 \ + --gen-subset valid --dstore-filename checkpoints/wikitext103-bpe/dstore_subsampled_0.1 \ + --indexfile checkpoints/wikitext103-bpe/knn_subsampled_0.1.index \ + --model-overrides "{'knn_keytype': 'last_ffn_input'}" \ + --k 1024 --lmbda 0.25 --dstore-size 15322548 --knn-keytype last_ffn_input \ + --knn-sim-func "do_not_recomp_l2" --no-load-keys \ + --probe 32 --knnlm --fp16 --dstore-fp16 --bpe subword_nmt --remove-bpe --save-knn-scores 0.1_knn_scores_2.npy \ + --save-queries all_queries_1.npy + + +# 0.2 +python analysis/subsample_datastore_0.2.py + +# build index +python build_dstore.py \ + --dstore_mmap checkpoints/wikitext103-bpe/dstore_subsampled_0.2 \ + --dstore_size 30645096 \ + --faiss_index checkpoints/wikitext103-bpe/knn_subsampled_0.2.index \ + --num_keys_to_add_at_a_time 500000 \ + --starting_point 0 --dstore-fp16 --dimension 1024 + +# eval with index +# no recompute +python eval_lm.py data-bin/wikitext103-bpe \ + --path checkpoints/wikitext103-bpe/checkpoint_best.pt \ + --sample-break-mode complete --max-tokens 3072 \ + --context-window 2560 --softmax-batch 1024 \ + --gen-subset valid --dstore-filename checkpoints/wikitext103-bpe/dstore_subsampled_0.2 \ + --indexfile checkpoints/wikitext103-bpe/knn_subsampled_0.2.index \ + --model-overrides "{'knn_keytype': 'last_ffn_input'}" \ + --k 1024 --lmbda 0.25 --dstore-size 30645096 --knn-keytype last_ffn_input \ + --knn-sim-func "do_not_recomp_l2" --no-load-keys \ + --probe 32 --knnlm --fp16 --dstore-fp16 --bpe subword_nmt --remove-bpe --save-knn-scores 0.2_knn_scores_2.npy \ + --save-queries all_queries_2.npy + + diff --git a/wikitext_bpe_subsample.sh b/wikitext_bpe_subsample.sh new file mode 100644 index 0000000..6b95ba3 --- /dev/null +++ b/wikitext_bpe_subsample.sh @@ -0,0 +1,354 @@ +## 0.05 +#python analysis/subsample_datastore_0.05.py +# +## build index +#python build_dstore.py \ +# --dstore_mmap checkpoints/wikitext103-bpe/dstore_subsampled_0.05 \ +# --dstore_size 7661274 \ +# --faiss_index checkpoints/wikitext103-bpe/knn_subsampled_0.05.index \ +# --num_keys_to_add_at_a_time 500000 \ +# --starting_point 0 --dstore-fp16 --dimension 1024 +# +## eval with index +## no recompute +#python eval_lm.py data-bin/wikitext103-bpe \ +# --path checkpoints/wikitext103-bpe/checkpoint_best.pt \ +# --sample-break-mode complete --max-tokens 3072 \ +# --context-window 2560 --softmax-batch 1024 \ +# --gen-subset valid --dstore-filename checkpoints/wikitext103-bpe/dstore_subsampled_0.05 \ +# --indexfile checkpoints/wikitext103-bpe/knn_subsampled_0.05.index \ +# --model-overrides "{'knn_keytype': 'last_ffn_input'}" \ +# --k 1024 --lmbda 0.25 --dstore-size 7661274 --knn-keytype last_ffn_input \ +# --knn-sim-func "do_not_recomp_l2" --no-load-keys \ +# --probe 32 --knnlm --fp16 --dstore-fp16 --bpe subword_nmt --remove-bpe --save-knn-scores 0.05_knn_scores.npy +# +## recompute +#python eval_lm.py data-bin/wikitext103-bpe \ +# --path checkpoints/wikitext103-bpe/checkpoint_best.pt \ +# --sample-break-mode complete --max-tokens 3072 \ +# --context-window 2560 --softmax-batch 1024 \ +# --gen-subset valid --dstore-filename checkpoints/wikitext103-bpe/dstore_subsampled_0.05 \ +# --indexfile checkpoints/wikitext103-bpe/knn_subsampled_0.05.index \ +# --model-overrides "{'knn_keytype': 'last_ffn_input'}" \ +# --k 1024 --lmbda 0.25 --dstore-size 7661274 --knn-keytype last_ffn_input \ +# --probe 32 --knnlm --fp16 --dstore-fp16 --bpe subword_nmt --remove-bpe --save-knn-scores 0.05_knn_recomp_scores.npy +# +## 0.1 +#python analysis/subsample_datastore_0.1.py +# +## build index +#python build_dstore.py \ +# --dstore_mmap checkpoints/wikitext103-bpe/dstore_subsampled_0.1 \ +# --dstore_size 15322548 \ +# --faiss_index checkpoints/wikitext103-bpe/knn_subsampled_0.1.index \ +# --num_keys_to_add_at_a_time 500000 \ +# --starting_point 0 --dstore-fp16 --dimension 1024 +# +## eval with index +## no recompute +#python eval_lm.py data-bin/wikitext103-bpe \ +# --path checkpoints/wikitext103-bpe/checkpoint_best.pt \ +# --sample-break-mode complete --max-tokens 3072 \ +# --context-window 2560 --softmax-batch 1024 \ +# --gen-subset valid --dstore-filename checkpoints/wikitext103-bpe/dstore_subsampled_0.1 \ +# --indexfile checkpoints/wikitext103-bpe/knn_subsampled_0.1.index \ +# --model-overrides "{'knn_keytype': 'last_ffn_input'}" \ +# --k 1024 --lmbda 0.25 --dstore-size 15322548 --knn-keytype last_ffn_input \ +# --knn-sim-func "do_not_recomp_l2" --no-load-keys \ +# --probe 32 --knnlm --fp16 --dstore-fp16 --bpe subword_nmt --remove-bpe --save-knn-scores 0.1_knn_scores.npy +# +## recompute +#python eval_lm.py data-bin/wikitext103-bpe \ +# --path checkpoints/wikitext103-bpe/checkpoint_best.pt \ +# --sample-break-mode complete --max-tokens 3072 \ +# --context-window 2560 --softmax-batch 1024 \ +# --gen-subset valid --dstore-filename checkpoints/wikitext103-bpe/dstore_subsampled_0.1 \ +# --indexfile checkpoints/wikitext103-bpe/knn_subsampled_0.1.index \ +# --model-overrides "{'knn_keytype': 'last_ffn_input'}" \ +# --k 1024 --lmbda 0.25 --dstore-size 15322548 --knn-keytype last_ffn_input \ +# --probe 32 --knnlm --fp16 --dstore-fp16 --bpe subword_nmt --remove-bpe --save-knn-scores 0.1_knn_recomp_scores.npy +# +## 0.2 +#python analysis/subsample_datastore_0.2.py +# +## build index +#python build_dstore.py \ +# --dstore_mmap checkpoints/wikitext103-bpe/dstore_subsampled_0.2 \ +# --dstore_size 30645096 \ +# --faiss_index checkpoints/wikitext103-bpe/knn_subsampled_0.2.index \ +# --num_keys_to_add_at_a_time 500000 \ +# --starting_point 0 --dstore-fp16 --dimension 1024 +# +## eval with index +## no recompute +#python eval_lm.py data-bin/wikitext103-bpe \ +# --path checkpoints/wikitext103-bpe/checkpoint_best.pt \ +# --sample-break-mode complete --max-tokens 3072 \ +# --context-window 2560 --softmax-batch 1024 \ +# --gen-subset valid --dstore-filename checkpoints/wikitext103-bpe/dstore_subsampled_0.2 \ +# --indexfile checkpoints/wikitext103-bpe/knn_subsampled_0.2.index \ +# --model-overrides "{'knn_keytype': 'last_ffn_input'}" \ +# --k 1024 --lmbda 0.25 --dstore-size 30645096 --knn-keytype last_ffn_input \ +# --knn-sim-func "do_not_recomp_l2" --no-load-keys \ +# --probe 32 --knnlm --fp16 --dstore-fp16 --bpe subword_nmt --remove-bpe --save-knn-scores 0.2_knn_scores.npy +# +## recompute +#python eval_lm.py data-bin/wikitext103-bpe \ +# --path checkpoints/wikitext103-bpe/checkpoint_best.pt \ +# --sample-break-mode complete --max-tokens 3072 \ +# --context-window 2560 --softmax-batch 1024 \ +# --gen-subset valid --dstore-filename checkpoints/wikitext103-bpe/dstore_subsampled_0.2 \ +# --indexfile checkpoints/wikitext103-bpe/knn_subsampled_0.2.index \ +# --model-overrides "{'knn_keytype': 'last_ffn_input'}" \ +# --k 1024 --lmbda 0.25 --dstore-size 30645096 --knn-keytype last_ffn_input \ +# --probe 32 --knnlm --fp16 --dstore-fp16 --bpe subword_nmt --remove-bpe --save-knn-scores 0.2_knn_recomp_scores.npy +# +# +## 0.3 +#python analysis/subsample_datastore_0.3.py +# +## build index +#python build_dstore.py \ +# --dstore_mmap checkpoints/wikitext103-bpe/dstore_subsampled_0.3 \ +# --dstore_size 45967645 \ +# --faiss_index checkpoints/wikitext103-bpe/knn_subsampled_0.3.index \ +# --num_keys_to_add_at_a_time 500000 \ +# --starting_point 0 --dstore-fp16 --dimension 1024 +# +## eval with index +## no recompute +#python eval_lm.py data-bin/wikitext103-bpe \ +# --path checkpoints/wikitext103-bpe/checkpoint_best.pt \ +# --sample-break-mode complete --max-tokens 3072 \ +# --context-window 2560 --softmax-batch 1024 \ +# --gen-subset valid --dstore-filename checkpoints/wikitext103-bpe/dstore_subsampled_0.3 \ +# --indexfile checkpoints/wikitext103-bpe/knn_subsampled_0.3.index \ +# --model-overrides "{'knn_keytype': 'last_ffn_input'}" \ +# --k 1024 --lmbda 0.25 --dstore-size 45967645 --knn-keytype last_ffn_input \ +# --knn-sim-func "do_not_recomp_l2" --no-load-keys \ +# --probe 32 --knnlm --fp16 --dstore-fp16 --bpe subword_nmt --remove-bpe --save-knn-scores 0.3_knn_scores.npy +# +## recompute +#python eval_lm.py data-bin/wikitext103-bpe \ +# --path checkpoints/wikitext103-bpe/checkpoint_best.pt \ +# --sample-break-mode complete --max-tokens 3072 \ +# --context-window 2560 --softmax-batch 1024 \ +# --gen-subset valid --dstore-filename checkpoints/wikitext103-bpe/dstore_subsampled_0.3 \ +# --indexfile checkpoints/wikitext103-bpe/knn_subsampled_0.3.index \ +# --model-overrides "{'knn_keytype': 'last_ffn_input'}" \ +# --k 1024 --lmbda 0.25 --dstore-size 45967645 --knn-keytype last_ffn_input \ +# --probe 32 --knnlm --fp16 --dstore-fp16 --bpe subword_nmt --remove-bpe --save-knn-scores 0.3_knn_recomp_scores.npy +# +# +## 0.4 +#python analysis/subsample_datastore_0.4.py +# +## build index +#python build_dstore.py \ +# --dstore_mmap checkpoints/wikitext103-bpe/dstore_subsampled_0.4 \ +# --dstore_size 61290194 \ +# --faiss_index checkpoints/wikitext103-bpe/knn_subsampled_0.4.index \ +# --num_keys_to_add_at_a_time 500000 \ +# --starting_point 0 --dstore-fp16 --dimension 1024 +# +## eval with index +## no recompute +#python eval_lm.py data-bin/wikitext103-bpe \ +# --path checkpoints/wikitext103-bpe/checkpoint_best.pt \ +# --sample-break-mode complete --max-tokens 3072 \ +# --context-window 2560 --softmax-batch 1024 \ +# --gen-subset valid --dstore-filename checkpoints/wikitext103-bpe/dstore_subsampled_0.4 \ +# --indexfile checkpoints/wikitext103-bpe/knn_subsampled_0.4.index \ +# --model-overrides "{'knn_keytype': 'last_ffn_input'}" \ +# --k 1024 --lmbda 0.25 --dstore-size 61290194 --knn-keytype last_ffn_input \ +# --knn-sim-func "do_not_recomp_l2" --no-load-keys \ +# --probe 32 --knnlm --fp16 --dstore-fp16 --bpe subword_nmt --remove-bpe --save-knn-scores 0.4_knn_scores.npy +# +## recompute +#python eval_lm.py data-bin/wikitext103-bpe \ +# --path checkpoints/wikitext103-bpe/checkpoint_best.pt \ +# --sample-break-mode complete --max-tokens 3072 \ +# --context-window 2560 --softmax-batch 1024 \ +# --gen-subset valid --dstore-filename checkpoints/wikitext103-bpe/dstore_subsampled_0.4 \ +# --indexfile checkpoints/wikitext103-bpe/knn_subsampled_0.4.index \ +# --model-overrides "{'knn_keytype': 'last_ffn_input'}" \ +# --k 1024 --lmbda 0.25 --dstore-size 61290194 --knn-keytype last_ffn_input \ +# --probe 32 --knnlm --fp16 --dstore-fp16 --bpe subword_nmt --remove-bpe --save-knn-scores 0.4_knn_recomp_scores.npy +# +## 0.5 +#python analysis/subsample_datastore_0.5.py +# +## build index +#python build_dstore.py \ +# --dstore_mmap checkpoints/wikitext103-bpe/dstore_subsampled_0.5 \ +# --dstore_size 76612742 \ +# --faiss_index checkpoints/wikitext103-bpe/knn_subsampled_0.5.index \ +# --num_keys_to_add_at_a_time 500000 \ +# --starting_point 0 --dstore-fp16 --dimension 1024 +# +## eval with index +## no recompute +#python eval_lm.py data-bin/wikitext103-bpe \ +# --path checkpoints/wikitext103-bpe/checkpoint_best.pt \ +# --sample-break-mode complete --max-tokens 3072 \ +# --context-window 2560 --softmax-batch 1024 \ +# --gen-subset valid --dstore-filename checkpoints/wikitext103-bpe/dstore_subsampled_0.5 \ +# --indexfile checkpoints/wikitext103-bpe/knn_subsampled_0.5.index \ +# --model-overrides "{'knn_keytype': 'last_ffn_input'}" \ +# --k 1024 --lmbda 0.25 --dstore-size 76612742 --knn-keytype last_ffn_input \ +# --knn-sim-func "do_not_recomp_l2" --no-load-keys \ +# --probe 32 --knnlm --fp16 --dstore-fp16 --bpe subword_nmt --remove-bpe --save-knn-scores 0.5_knn_scores.npy +# +## recompute +#python eval_lm.py data-bin/wikitext103-bpe \ +# --path checkpoints/wikitext103-bpe/checkpoint_best.pt \ +# --sample-break-mode complete --max-tokens 3072 \ +# --context-window 2560 --softmax-batch 1024 \ +# --gen-subset valid --dstore-filename checkpoints/wikitext103-bpe/dstore_subsampled_0.5 \ +# --indexfile checkpoints/wikitext103-bpe/knn_subsampled_0.5.index \ +# --model-overrides "{'knn_keytype': 'last_ffn_input'}" \ +# --k 1024 --lmbda 0.25 --dstore-size 76612742 --knn-keytype last_ffn_input \ +# --probe 32 --knnlm --fp16 --dstore-fp16 --bpe subword_nmt --remove-bpe --save-knn-scores 0.5_knn_recomp_scores.npy +# +# +## 0.6 +#python analysis/subsample_datastore_0.6.py +# +## build index +#python build_dstore.py \ +# --dstore_mmap checkpoints/wikitext103-bpe/dstore_subsampled_0.6 \ +# --dstore_size 91935291 \ +# --faiss_index checkpoints/wikitext103-bpe/knn_subsampled_0.6.index \ +# --num_keys_to_add_at_a_time 500000 \ +# --starting_point 0 --dstore-fp16 --dimension 1024 +# +## eval with index +## no recompute +#python eval_lm.py data-bin/wikitext103-bpe \ +# --path checkpoints/wikitext103-bpe/checkpoint_best.pt \ +# --sample-break-mode complete --max-tokens 3072 \ +# --context-window 2560 --softmax-batch 1024 \ +# --gen-subset valid --dstore-filename checkpoints/wikitext103-bpe/dstore_subsampled_0.6 \ +# --indexfile checkpoints/wikitext103-bpe/knn_subsampled_0.6.index \ +# --model-overrides "{'knn_keytype': 'last_ffn_input'}" \ +# --k 1024 --lmbda 0.25 --dstore-size 91935291 --knn-keytype last_ffn_input \ +# --knn-sim-func "do_not_recomp_l2" --no-load-keys \ +# --probe 32 --knnlm --fp16 --dstore-fp16 --bpe subword_nmt --remove-bpe --save-knn-scores 0.6_knn_scores.npy +# +## recompute +#python eval_lm.py data-bin/wikitext103-bpe \ +# --path checkpoints/wikitext103-bpe/checkpoint_best.pt \ +# --sample-break-mode complete --max-tokens 3072 \ +# --context-window 2560 --softmax-batch 1024 \ +# --gen-subset valid --dstore-filename checkpoints/wikitext103-bpe/dstore_subsampled_0.6 \ +# --indexfile checkpoints/wikitext103-bpe/knn_subsampled_0.6.index \ +# --model-overrides "{'knn_keytype': 'last_ffn_input'}" \ +# --k 1024 --lmbda 0.25 --dstore-size 91935291 --knn-keytype last_ffn_input \ +# --probe 32 --knnlm --fp16 --dstore-fp16 --bpe subword_nmt --remove-bpe --save-knn-scores 0.6_knn_recomp_scores.npy + + +## 0.7 +#python analysis/subsample_datastore_0.7.py +# +## build index +#python build_dstore.py \ +# --dstore_mmap checkpoints/wikitext103-bpe/dstore_subsampled_0.7 \ +# --dstore_size 107257840 \ +# --faiss_index checkpoints/wikitext103-bpe/knn_subsampled_0.7.index \ +# --num_keys_to_add_at_a_time 500000 \ +# --starting_point 0 --dstore-fp16 --dimension 1024 +# +## eval with index +## no recompute +#python eval_lm.py data-bin/wikitext103-bpe \ +# --path checkpoints/wikitext103-bpe/checkpoint_best.pt \ +# --sample-break-mode complete --max-tokens 3072 \ +# --context-window 2560 --softmax-batch 1024 \ +# --gen-subset valid --dstore-filename checkpoints/wikitext103-bpe/dstore_subsampled_0.7 \ +# --indexfile checkpoints/wikitext103-bpe/knn_subsampled_0.7.index \ +# --model-overrides "{'knn_keytype': 'last_ffn_input'}" \ +# --k 1024 --lmbda 0.25 --dstore-size 107257840 --knn-keytype last_ffn_input \ +# --knn-sim-func "do_not_recomp_l2" --no-load-keys \ +# --probe 32 --knnlm --fp16 --dstore-fp16 --bpe subword_nmt --remove-bpe --save-knn-scores 0.7_knn_scores.npy +# +## recompute +#python eval_lm.py data-bin/wikitext103-bpe \ +# --path checkpoints/wikitext103-bpe/checkpoint_best.pt \ +# --sample-break-mode complete --max-tokens 3072 \ +# --context-window 2560 --softmax-batch 1024 \ +# --gen-subset valid --dstore-filename checkpoints/wikitext103-bpe/dstore_subsampled_0.7 \ +# --indexfile checkpoints/wikitext103-bpe/knn_subsampled_0.7.index \ +# --model-overrides "{'knn_keytype': 'last_ffn_input'}" \ +# --k 1024 --lmbda 0.25 --dstore-size 107257840 --knn-keytype last_ffn_input \ +# --probe 32 --knnlm --fp16 --dstore-fp16 --bpe subword_nmt --remove-bpe --save-knn-scores 0.7_knn_recomp_scores.npy + + +# 0.8 +python analysis/subsample_datastore_0.8.py + +# build index +python build_dstore.py \ + --dstore_mmap checkpoints/wikitext103-bpe/dstore_subsampled_0.8 \ + --dstore_size 122580388 \ + --faiss_index checkpoints/wikitext103-bpe/knn_subsampled_0.8.index \ + --num_keys_to_add_at_a_time 500000 \ + --starting_point 0 --dstore-fp16 --dimension 1024 + +# eval with index +# no recompute +python eval_lm.py data-bin/wikitext103-bpe \ + --path checkpoints/wikitext103-bpe/checkpoint_best.pt \ + --sample-break-mode complete --max-tokens 3072 \ + --context-window 2560 --softmax-batch 1024 \ + --gen-subset valid --dstore-filename checkpoints/wikitext103-bpe/dstore_subsampled_0.8 \ + --indexfile checkpoints/wikitext103-bpe/knn_subsampled_0.8.index \ + --model-overrides "{'knn_keytype': 'last_ffn_input'}" \ + --k 1024 --lmbda 0.25 --dstore-size 122580388 --knn-keytype last_ffn_input \ + --knn-sim-func "do_not_recomp_l2" --no-load-keys \ + --probe 32 --knnlm --fp16 --dstore-fp16 --bpe subword_nmt --remove-bpe --save-knn-scores 0.8_knn_scores.npy + +# recompute +python eval_lm.py data-bin/wikitext103-bpe \ + --path checkpoints/wikitext103-bpe/checkpoint_best.pt \ + --sample-break-mode complete --max-tokens 3072 \ + --context-window 2560 --softmax-batch 1024 \ + --gen-subset valid --dstore-filename checkpoints/wikitext103-bpe/dstore_subsampled_0.8 \ + --indexfile checkpoints/wikitext103-bpe/knn_subsampled_0.8.index \ + --model-overrides "{'knn_keytype': 'last_ffn_input'}" \ + --k 1024 --lmbda 0.25 --dstore-size 122580388 --knn-keytype last_ffn_input \ + --probe 32 --knnlm --fp16 --dstore-fp16 --bpe subword_nmt --remove-bpe --save-knn-scores 0.8_knn_recomp_scores.npy + +# 0.9 +python analysis/subsample_datastore_0.9.py + +# build index +python build_dstore.py \ + --dstore_mmap checkpoints/wikitext103-bpe/dstore_subsampled_0.9 \ + --dstore_size 137902936 \ + --faiss_index checkpoints/wikitext103-bpe/knn_subsampled_0.9.index \ + --num_keys_to_add_at_a_time 500000 \ + --starting_point 0 --dstore-fp16 --dimension 1024 + +# eval with index +# no recompute +python eval_lm.py data-bin/wikitext103-bpe \ + --path checkpoints/wikitext103-bpe/checkpoint_best.pt \ + --sample-break-mode complete --max-tokens 3072 \ + --context-window 2560 --softmax-batch 1024 \ + --gen-subset valid --dstore-filename checkpoints/wikitext103-bpe/dstore_subsampled_0.9 \ + --indexfile checkpoints/wikitext103-bpe/knn_subsampled_0.9.index \ + --model-overrides "{'knn_keytype': 'last_ffn_input'}" \ + --k 1024 --lmbda 0.25 --dstore-size 137902936 --knn-keytype last_ffn_input \ + --knn-sim-func "do_not_recomp_l2" --no-load-keys \ + --probe 32 --knnlm --fp16 --dstore-fp16 --bpe subword_nmt --remove-bpe --save-knn-scores 0.9_knn_scores.npy + +# recompute +python eval_lm.py data-bin/wikitext103-bpe \ + --path checkpoints/wikitext103-bpe/checkpoint_best.pt \ + --sample-break-mode complete --max-tokens 3072 \ + --context-window 2560 --softmax-batch 1024 \ + --gen-subset valid --dstore-filename checkpoints/wikitext103-bpe/dstore_subsampled_0.9 \ + --indexfile checkpoints/wikitext103-bpe/knn_subsampled_0.9.index \ + --model-overrides "{'knn_keytype': 'last_ffn_input'}" \ + --k 1024 --lmbda 0.25 --dstore-size 137902936 --knn-keytype last_ffn_input \ + --probe 32 --knnlm --fp16 --dstore-fp16 --bpe subword_nmt --remove-bpe --save-knn-scores 0.9_knn_recomp_scores.npy