Skip to content

Commit

Permalink
update subsample experiments
Browse files Browse the repository at this point in the history
  • Loading branch information
frankxu2004 committed Sep 26, 2022
1 parent 62e4f52 commit ea6e402
Show file tree
Hide file tree
Showing 18 changed files with 746 additions and 51 deletions.
35 changes: 29 additions & 6 deletions analysis/kv_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,35 @@
# 'epoch_scores/3v-att-init-finetune-epoch2.npy',
# 'epoch_scores/3v-att-init-finetune-epoch3.npy',
# 'epoch_scores/3v-att-init-finetune-epoch4.npy',
'interpolated_loss_scores/kv1_interpolated_scores.npy',
'interpolated_loss_scores/kv1_att_interpolated_scores.npy',
'interpolated_loss_scores/kv3_interpolated_scores.npy',
'interpolated_loss_scores/kv3_att_interpolated_scores.npy',
'weighted_loss_scores/kv1_weighted_scores.npy',
'weighted_loss_scores/kv1_att_weighted_scores.npy',
# 'interpolated_loss_scores/kv1_interpolated_scores.npy',
# 'interpolated_loss_scores/kv1_att_interpolated_scores.npy',
# 'interpolated_loss_scores/kv3_interpolated_scores.npy',
# 'interpolated_loss_scores/kv3_att_interpolated_scores.npy',
# 'weighted_loss_scores/kv1_weighted_scores.npy',
# 'weighted_loss_scores/kv1_att_weighted_scores.npy',
'0.05_knn_scores.npy',
'0.1_knn_scores.npy',
'0.2_knn_scores.npy',
'0.3_knn_scores.npy',
'0.4_knn_scores.npy',
'0.5_knn_scores.npy',
'0.6_knn_scores.npy',
'0.7_knn_scores.npy',
'0.8_knn_scores.npy',
'0.9_knn_scores.npy',
'0.05_knn_recomp_scores.npy',
'0.1_knn_recomp_scores.npy',
'0.2_knn_recomp_scores.npy',
'0.3_knn_recomp_scores.npy',
'0.4_knn_recomp_scores.npy',
'0.5_knn_recomp_scores.npy',
'0.6_knn_recomp_scores.npy',
'0.7_knn_recomp_scores.npy',
'0.8_knn_recomp_scores.npy',
'0.9_knn_recomp_scores.npy',
'0.05_knn_scores_2.npy',
'0.1_knn_scores_2.npy',
'0.2_knn_scores_2.npy',
]

# extra_score_files = [f'epoch_scores/3v-att-init-finetune-epoch{e}.npy' for e in range(1, 40)]
Expand Down
26 changes: 26 additions & 0 deletions analysis/subsample_datastore_0.05.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import numpy as np

from fairseq.data import Dictionary
np.random.seed(0)
dstore_size = 153225485
vec_dim = 1024
subsample_size = 7661274

dictionary = Dictionary.load('data-bin/wikitext103-bpe/dict.txt')
print(len(dictionary))

keys_from_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_keys.npy',
dtype=np.float16, mode='r', shape=(dstore_size, vec_dim))
vals_from_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_vals.npy',
dtype=np.int64, mode='r', shape=(dstore_size, 1))

subsampled_idxs = np.random.choice(dstore_size, subsample_size, replace=False)

print('sampled idx created')
subsampled_keys_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_subsampled_0.05_keys.npy',
dtype=np.float16, mode='w+', shape=(subsample_size, vec_dim))
subsampled_vals_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_subsampled_0.05_vals.npy',
dtype=np.int64, mode='w+', shape=(subsample_size, 1))

subsampled_keys_memmap[:] = keys_from_memmap[subsampled_idxs]
subsampled_vals_memmap[:] = vals_from_memmap[subsampled_idxs]
26 changes: 26 additions & 0 deletions analysis/subsample_datastore_0.1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import numpy as np

from fairseq.data import Dictionary
np.random.seed(0)
dstore_size = 153225485
vec_dim = 1024
subsample_size = 15322548

dictionary = Dictionary.load('data-bin/wikitext103-bpe/dict.txt')
print(len(dictionary))

keys_from_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_keys.npy',
dtype=np.float16, mode='r', shape=(dstore_size, vec_dim))
vals_from_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_vals.npy',
dtype=np.int64, mode='r', shape=(dstore_size, 1))

subsampled_idxs = np.random.choice(dstore_size, subsample_size, replace=False)

print('sampled idx created')
subsampled_keys_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_subsampled_0.1_keys.npy',
dtype=np.float16, mode='w+', shape=(subsample_size, vec_dim))
subsampled_vals_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_subsampled_0.1_vals.npy',
dtype=np.int64, mode='w+', shape=(subsample_size, 1))

subsampled_keys_memmap[:] = keys_from_memmap[subsampled_idxs]
subsampled_vals_memmap[:] = vals_from_memmap[subsampled_idxs]
21 changes: 21 additions & 0 deletions analysis/subsample_datastore_0.2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import numpy as np
np.random.seed(0)
dstore_size = 153225485
vec_dim = 1024
subsample_size = 30645096

keys_from_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_keys.npy',
dtype=np.float16, mode='r', shape=(dstore_size, vec_dim))
vals_from_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_vals.npy',
dtype=np.int64, mode='r', shape=(dstore_size, 1))

subsampled_idxs = np.random.choice(dstore_size, subsample_size, replace=False)

print('sampled idx created')
subsampled_keys_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_subsampled_0.2_keys.npy',
dtype=np.float16, mode='w+', shape=(subsample_size, vec_dim))
subsampled_vals_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_subsampled_0.2_vals.npy',
dtype=np.int64, mode='w+', shape=(subsample_size, 1))

subsampled_keys_memmap[:] = keys_from_memmap[subsampled_idxs]
subsampled_vals_memmap[:] = vals_from_memmap[subsampled_idxs]
21 changes: 21 additions & 0 deletions analysis/subsample_datastore_0.3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import numpy as np
np.random.seed(0)
dstore_size = 153225485
vec_dim = 1024
subsample_size = 45967645

keys_from_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_keys.npy',
dtype=np.float16, mode='r', shape=(dstore_size, vec_dim))
vals_from_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_vals.npy',
dtype=np.int64, mode='r', shape=(dstore_size, 1))

subsampled_idxs = np.random.choice(dstore_size, subsample_size, replace=False)

print('sampled idx created')
subsampled_keys_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_subsampled_0.3_keys.npy',
dtype=np.float16, mode='w+', shape=(subsample_size, vec_dim))
subsampled_vals_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_subsampled_0.3_vals.npy',
dtype=np.int64, mode='w+', shape=(subsample_size, 1))

subsampled_keys_memmap[:] = keys_from_memmap[subsampled_idxs]
subsampled_vals_memmap[:] = vals_from_memmap[subsampled_idxs]
21 changes: 21 additions & 0 deletions analysis/subsample_datastore_0.4.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import numpy as np
np.random.seed(0)
dstore_size = 153225485
vec_dim = 1024
subsample_size = 61290194

keys_from_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_keys.npy',
dtype=np.float16, mode='r', shape=(dstore_size, vec_dim))
vals_from_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_vals.npy',
dtype=np.int64, mode='r', shape=(dstore_size, 1))

subsampled_idxs = np.random.choice(dstore_size, subsample_size, replace=False)

print('sampled idx created')
subsampled_keys_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_subsampled_0.4_keys.npy',
dtype=np.float16, mode='w+', shape=(subsample_size, vec_dim))
subsampled_vals_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_subsampled_0.4_vals.npy',
dtype=np.int64, mode='w+', shape=(subsample_size, 1))

subsampled_keys_memmap[:] = keys_from_memmap[subsampled_idxs]
subsampled_vals_memmap[:] = vals_from_memmap[subsampled_idxs]
21 changes: 21 additions & 0 deletions analysis/subsample_datastore_0.5.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import numpy as np
np.random.seed(0)
dstore_size = 153225485
vec_dim = 1024
subsample_size = 76612742

keys_from_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_keys.npy',
dtype=np.float16, mode='r', shape=(dstore_size, vec_dim))
vals_from_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_vals.npy',
dtype=np.int64, mode='r', shape=(dstore_size, 1))

subsampled_idxs = np.random.choice(dstore_size, subsample_size, replace=False)

print('sampled idx created')
subsampled_keys_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_subsampled_0.5_keys.npy',
dtype=np.float16, mode='w+', shape=(subsample_size, vec_dim))
subsampled_vals_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_subsampled_0.5_vals.npy',
dtype=np.int64, mode='w+', shape=(subsample_size, 1))

subsampled_keys_memmap[:] = keys_from_memmap[subsampled_idxs]
subsampled_vals_memmap[:] = vals_from_memmap[subsampled_idxs]
22 changes: 22 additions & 0 deletions analysis/subsample_datastore_0.6.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import numpy as np

np.random.seed(0)
dstore_size = 153225485
vec_dim = 1024
subsample_size = 91935291

keys_from_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_keys.npy',
dtype=np.float16, mode='r', shape=(dstore_size, vec_dim))
vals_from_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_vals.npy',
dtype=np.int64, mode='r', shape=(dstore_size, 1))

subsampled_idxs = np.random.choice(dstore_size, subsample_size, replace=False)

print('sampled idx created')
subsampled_keys_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_subsampled_0.6_keys.npy',
dtype=np.float16, mode='w+', shape=(subsample_size, vec_dim))
subsampled_vals_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_subsampled_0.6_vals.npy',
dtype=np.int64, mode='w+', shape=(subsample_size, 1))

subsampled_keys_memmap[:] = keys_from_memmap[subsampled_idxs]
subsampled_vals_memmap[:] = vals_from_memmap[subsampled_idxs]
22 changes: 22 additions & 0 deletions analysis/subsample_datastore_0.7.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import numpy as np

np.random.seed(0)
dstore_size = 153225485
vec_dim = 1024
subsample_size = 107257840

keys_from_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_keys.npy',
dtype=np.float16, mode='r', shape=(dstore_size, vec_dim))
vals_from_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_vals.npy',
dtype=np.int64, mode='r', shape=(dstore_size, 1))

subsampled_idxs = np.random.choice(dstore_size, subsample_size, replace=False)

print('sampled idx created')
subsampled_keys_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_subsampled_0.7_keys.npy',
dtype=np.float16, mode='w+', shape=(subsample_size, vec_dim))
subsampled_vals_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_subsampled_0.7_vals.npy',
dtype=np.int64, mode='w+', shape=(subsample_size, 1))

subsampled_keys_memmap[:] = keys_from_memmap[subsampled_idxs]
subsampled_vals_memmap[:] = vals_from_memmap[subsampled_idxs]
22 changes: 22 additions & 0 deletions analysis/subsample_datastore_0.8.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import numpy as np

np.random.seed(0)
dstore_size = 153225485
vec_dim = 1024
subsample_size = 122580388

keys_from_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_keys.npy',
dtype=np.float16, mode='r', shape=(dstore_size, vec_dim))
vals_from_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_vals.npy',
dtype=np.int64, mode='r', shape=(dstore_size, 1))

subsampled_idxs = np.random.choice(dstore_size, subsample_size, replace=False)

print('sampled idx created')
subsampled_keys_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_subsampled_0.8_keys.npy',
dtype=np.float16, mode='w+', shape=(subsample_size, vec_dim))
subsampled_vals_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_subsampled_0.8_vals.npy',
dtype=np.int64, mode='w+', shape=(subsample_size, 1))

subsampled_keys_memmap[:] = keys_from_memmap[subsampled_idxs]
subsampled_vals_memmap[:] = vals_from_memmap[subsampled_idxs]
22 changes: 22 additions & 0 deletions analysis/subsample_datastore_0.9.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import numpy as np

np.random.seed(0)
dstore_size = 153225485
vec_dim = 1024
subsample_size = 137902936

keys_from_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_keys.npy',
dtype=np.float16, mode='r', shape=(dstore_size, vec_dim))
vals_from_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_vals.npy',
dtype=np.int64, mode='r', shape=(dstore_size, 1))

subsampled_idxs = np.random.choice(dstore_size, subsample_size, replace=False)

print('sampled idx created')
subsampled_keys_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_subsampled_0.9_keys.npy',
dtype=np.float16, mode='w+', shape=(subsample_size, vec_dim))
subsampled_vals_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_subsampled_0.9_vals.npy',
dtype=np.int64, mode='w+', shape=(subsample_size, 1))

subsampled_keys_memmap[:] = keys_from_memmap[subsampled_idxs]
subsampled_vals_memmap[:] = vals_from_memmap[subsampled_idxs]
56 changes: 29 additions & 27 deletions build_dstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
import faiss
import time


parser = argparse.ArgumentParser()
parser.add_argument('--dstore_mmap', type=str, help='memmap where keys and vals are stored')
parser.add_argument('--dstore_size', type=int, help='number of items saved in the datastore memmap')
parser.add_argument('--dimension', type=int, default=1024, help='Size of each key')
parser.add_argument('--dstore_fp16', '--dstore-fp16', default=False, action='store_true')
parser.add_argument('--seed', type=int, default=1, help='random seed for sampling the subset of vectors to train the cache')
parser.add_argument('--seed', type=int, default=1,
help='random seed for sampling the subset of vectors to train the cache')
parser.add_argument('--ncentroids', type=int, default=4096, help='number of centroids faiss should learn')
parser.add_argument('--code_size', type=int, default=64, help='size of quantized vectors')
parser.add_argument('--probe', type=int, default=8, help='number of clusters to query')
Expand All @@ -25,40 +25,42 @@
print(args)

if args.dstore_fp16:
keys = np.memmap(args.dstore_mmap+'_keys.npy', dtype=np.float16, mode='r', shape=(args.dstore_size, args.dimension))
vals = np.memmap(args.dstore_mmap+'_vals.npy', dtype=np.int64, mode='r', shape=(args.dstore_size, 1))
keys = np.memmap(args.dstore_mmap + '_keys.npy', dtype=np.float16, mode='r',
shape=(args.dstore_size, args.dimension))
vals = np.memmap(args.dstore_mmap + '_vals.npy', dtype=np.int64, mode='r', shape=(args.dstore_size, 1))
else:
keys = np.memmap(args.dstore_mmap+'_keys.npy', dtype=np.float32, mode='r', shape=(args.dstore_size, args.dimension))
vals = np.memmap(args.dstore_mmap+'_vals.npy', dtype=np.int64, mode='r', shape=(args.dstore_size, 1))
keys = np.memmap(args.dstore_mmap + '_keys.npy', dtype=np.float32, mode='r',
shape=(args.dstore_size, args.dimension))
vals = np.memmap(args.dstore_mmap + '_vals.npy', dtype=np.int64, mode='r', shape=(args.dstore_size, 1))

if not os.path.exists(args.faiss_index+".trained"):
# Initialize faiss index
metric = faiss.METRIC_L2 if args.metric == 'l2' else faiss.METRIC_INNER_PRODUCT
quantizer = faiss.IndexFlatL2(args.dimension) if args.metric == 'l2' else faiss.IndexFlatIP(args.dimension)
index = faiss.IndexIVFPQ(quantizer, args.dimension,
args.ncentroids, args.code_size, 8, metric)
index.nprobe = args.probe
# if not os.path.exists(args.faiss_index+".trained"):
# Initialize faiss index
metric = faiss.METRIC_L2 if args.metric == 'l2' else faiss.METRIC_INNER_PRODUCT
quantizer = faiss.IndexFlatL2(args.dimension) if args.metric == 'l2' else faiss.IndexFlatIP(args.dimension)
index = faiss.IndexIVFPQ(quantizer, args.dimension,
args.ncentroids, args.code_size, 8, metric)
index.nprobe = args.probe

print('Training Index')
np.random.seed(args.seed)
random_sample = np.random.choice(np.arange(vals.shape[0]), size=[min(1000000, vals.shape[0])], replace=False)
start = time.time()
# Faiss does not handle adding keys in fp16 as of writing this.
index.train(keys[random_sample].astype(np.float32))
print('Training took {} s'.format(time.time() - start))
print('Training Index')
np.random.seed(args.seed)
random_sample = np.random.choice(np.arange(vals.shape[0]), size=[min(1000000, vals.shape[0])], replace=False)
start = time.time()
# Faiss does not handle adding keys in fp16 as of writing this.
index.train(keys[random_sample].astype(np.float32))
print('Training took {} s'.format(time.time() - start))

print('Writing index after training')
start = time.time()
faiss.write_index(index, args.faiss_index+".trained")
print('Writing index took {} s'.format(time.time()-start))
print('Writing index after training')
start = time.time()
faiss.write_index(index, args.faiss_index + ".trained")
print('Writing index took {} s'.format(time.time() - start))

print('Adding Keys')
index = faiss.read_index(args.faiss_index+".trained")
index = faiss.read_index(args.faiss_index + ".trained")
start = args.starting_point

start_time = time.time()
while start < args.dstore_size:
end = min(args.dstore_size, start+args.num_keys_to_add_at_a_time)
end = min(args.dstore_size, start + args.num_keys_to_add_at_a_time)
to_add = keys[start:end].copy()
index.add_with_ids(to_add.astype(np.float32), np.arange(start, end))
start += args.num_keys_to_add_at_a_time
Expand All @@ -73,4 +75,4 @@
print('Writing Index')
start_time = time.time()
faiss.write_index(index, args.faiss_index)
print('Writing index took {} s'.format(time.time()-start_time))
print('Writing index took {} s'.format(time.time() - start_time))
5 changes: 2 additions & 3 deletions cluster/group_vecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
vals_from_memmap = np.memmap('checkpoints/wikitext103-bpe/dstore_vals.npy',
dtype=np.int64, mode='r', shape=(dstore_size, 1))


keys = np.zeros((dstore_size, vec_dim), dtype=np.float16)
vals = np.zeros((dstore_size, 1), dtype=np.int64)

Expand All @@ -24,6 +23,6 @@

vals = vals.squeeze()

#exclude 0
# exclude 0
for word_id in tqdm.tqdm(range(1, len(dictionary))):
np.save('dstore/ids/' + str(word_id) + '.npy', keys[vals==word_id])
np.save('dstore/ids/' + str(word_id) + '.npy', keys[vals == word_id])
2 changes: 2 additions & 0 deletions fairseq/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,8 @@ def add_eval_lm_args(parser):
help='tokens npy file path to save')
parser.add_argument('--save-knn-scores', type=str, default='',
help='knn scores npy file path to save')
parser.add_argument('--save-queries', type=str, default='',
help='query vectors npy file path to save')

# fmt: on

Expand Down
Loading

0 comments on commit ea6e402

Please sign in to comment.