Skip to content

Commit

Permalink
update subsample experiments
Browse files Browse the repository at this point in the history
  • Loading branch information
frankxu2004 committed Sep 29, 2022
1 parent ea6e402 commit bff8d61
Show file tree
Hide file tree
Showing 10 changed files with 612 additions and 149 deletions.
17 changes: 17 additions & 0 deletions analysis/compare_faiss_real.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import numpy as np

# Compare the neighbour ids retrieved by the FAISS index with the ones
# retrieved by the exhaustive ("real") search for the 0.1 subsampled datastore.
tokens = np.load('tokens.npy')

faiss_knns = np.load('dstore_subsampled_0.1_faiss_mask_knns.npy')
real_knns = np.load('dstore_subsampled_0.1_real_mask_knns.npy')

for loaded in (faiss_knns, real_knns, tokens):
    print(loaded.shape)

# For every query row, count how many distinct neighbour ids appear in both
# result sets, then accumulate over all rows.
matched = sum(
    len(set(faiss_knns[row]) & set(real_knns[row]))
    for row in range(len(tokens))
)

print(matched)
# Overlap expressed as a fraction of all FAISS neighbour slots.
print(matched/faiss_knns.size)
70 changes: 70 additions & 0 deletions analysis/faiss_mask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import numpy as np
import time
import torch
import torch.nn.functional as F
import tqdm
import faiss


@torch.jit.script
def my_cdist(x1, x2, x2_norm):
    """Compute ``x2_norm^T - 2 * x1 @ x2^T + |x1|^2`` for two 2-D matrices.

    When ``x2_norm`` holds the squared L2 norm of every row of ``x2`` (one
    value per row, as a column), the result is the matrix of pairwise squared
    Euclidean distances between the rows of ``x1`` and the rows of ``x2``.

    Args:
        x1: query matrix, one vector per row.
        x2: key matrix, one vector per row.
        x2_norm: precomputed per-row term for ``x2``; it is transposed and
            broadcast over the rows of the result.

    Returns:
        A tensor with one row per ``x1`` row and one column per ``x2`` row.
    """
    query_sq_norm = x1.pow(2).sum(dim=-1, keepdim=True)
    key_term = x2_norm.transpose(-2, -1)
    keys_t = x2.transpose(-2, -1)
    # addmm fuses "key_term + (-2) * (x1 @ keys_t)" into one call.
    pairwise = torch.addmm(key_term, x1, keys_t, alpha=-2)
    pairwise.add_(query_sq_norm)
    return pairwise


# ---------------------------------------------------------------------------
# Experiment configuration: number of neighbours to retrieve, key width, and
# the datastore / index files of the 0.05 subsampled datastore.
# ---------------------------------------------------------------------------
topk = 1024
dimension = 1024
dstore_filename = "checkpoints/wikitext103-bpe/dstore_subsampled_0.05"
dstore_size = 7661274
indexfile = "checkpoints/wikitext103-bpe/knn_subsampled_0.05.index"


index = faiss.read_index(indexfile, faiss.IO_FLAG_ONDISK_SAME_DIR)
index.nprobe = 32

print('Loading tokens...')
tokens = np.load('tokens.npy')
tokens = torch.from_numpy(tokens)

print('Loading queries...')
# Queries are rounded to float16 here and widened back to float32 only when
# they are handed to FAISS below.
queries = np.load('all_queries.npy').astype(np.float16)
queries = torch.from_numpy(queries)

assert len(queries) == len(tokens)
print(queries.dtype)

# NOTE(review): the keys memmap is not read anywhere in this script (FAISS
# answers the search); it is kept so the set of files touched is unchanged.
keys_from_memmap = np.memmap(dstore_filename + '_keys.npy', dtype=np.float16, mode='r',
                             shape=(dstore_size, dimension))
vals_from_memmap = np.memmap(dstore_filename + '_vals.npy', dtype=np.int64, mode='r',
                             shape=(dstore_size, 1))
print('Loading to memory...')
start = time.time()

# astype() copies the memory-mapped values into an in-memory array.
vals = vals_from_memmap[:]
vals = vals.astype(np.int64)

print('Loading to memory took {} s'.format(time.time() - start))

batch_size = 20000
# Ceiling division: the previous "len // batch_size + 1" scheduled an extra,
# empty batch whenever the query count was an exact multiple of batch_size.
# At least one batch is kept so np.concatenate below always has an input.
num_batches = max(1, (len(queries) + batch_size - 1) // batch_size)

all_probs = []
all_knns = []

for batch_idx in tqdm.tqdm(range(num_batches)):
    lo = batch_idx * batch_size
    hi = lo + batch_size
    batch_queries = queries[lo:hi]
    tgt = tokens[lo:hi].cuda()

    dists, knns = index.search(batch_queries.float().numpy(), topk)
    all_knns.append(knns)

    # Negated index distances act as logits; normalise over the topk axis.
    dists = torch.from_numpy(-1 * dists).cuda()
    probs = F.log_softmax(dists, dim=-1)

    # Datastore value of every retrieved neighbour, compared with the target.
    neighbor_vals = torch.from_numpy(vals[knns]).long().cuda().squeeze(-1)
    match = torch.eq(neighbor_vals, tgt.unsqueeze(-1))
    # FAISS reports id -1 when it finds fewer than topk neighbours; numpy
    # indexing would silently wrap that to the last datastore entry, so such
    # slots must never count as a match.
    match = match & torch.from_numpy(knns >= 0).cuda()

    # Additive mask: 0 where the neighbour's value equals the target,
    # -10000 elsewhere (a large finite penalty instead of -inf, for stability).
    index_mask = (match.float() - 1.0) * 10000.0
    yhat_knn_prob = torch.logsumexp(probs + index_mask, dim=-1)
    all_probs.append(yhat_knn_prob.cpu().numpy())


out_prefix = dstore_filename.split('/')[-1]
np.save(out_prefix + '_faiss_mask.npy', np.concatenate(all_probs))
np.save(out_prefix + '_faiss_mask_knns.npy', np.concatenate(all_knns, axis=0))
70 changes: 70 additions & 0 deletions analysis/faiss_mask_0.1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import numpy as np
import time
import torch
import torch.nn.functional as F
import tqdm
import faiss


@torch.jit.script
def my_cdist(x1, x2, x2_norm):
    """Evaluate ``x2_norm^T - 2 * x1 @ x2^T + |x1|^2`` for 2-D inputs.

    With ``x2_norm`` set to the squared L2 norm of each ``x2`` row (a column
    with one entry per row), this equals the pairwise squared Euclidean
    distance between every row of ``x1`` and every row of ``x2``.

    Args:
        x1: matrix of query vectors (one per row).
        x2: matrix of key vectors (one per row).
        x2_norm: precomputed per-row term for ``x2``, transposed and broadcast
            across the rows of the output.

    Returns:
        Tensor with ``x1.shape[0]`` rows and ``x2.shape[0]`` columns.
    """
    sq_norm_x1 = x1.pow(2).sum(dim=-1, keepdim=True)
    bias_row = x2_norm.transpose(-2, -1)
    x2_t = x2.transpose(-2, -1)
    # One fused call computes bias_row + (-2) * (x1 @ x2_t).
    dist = torch.addmm(bias_row, x1, x2_t, alpha=-2)
    dist.add_(sq_norm_x1)
    return dist


# ---------------------------------------------------------------------------
# Experiment configuration: number of neighbours to retrieve, key width, and
# the datastore / index files of the 0.1 subsampled datastore.
# ---------------------------------------------------------------------------
topk = 1024
dimension = 1024
dstore_filename = "checkpoints/wikitext103-bpe/dstore_subsampled_0.1"
dstore_size = 15322548
indexfile = "checkpoints/wikitext103-bpe/knn_subsampled_0.1.index"


index = faiss.read_index(indexfile, faiss.IO_FLAG_ONDISK_SAME_DIR)
index.nprobe = 32

print('Loading tokens...')
tokens = np.load('tokens.npy')
tokens = torch.from_numpy(tokens)

print('Loading queries...')
# Queries are rounded to float16 here and widened back to float32 only when
# they are handed to FAISS below.
queries = np.load('all_queries.npy').astype(np.float16)
queries = torch.from_numpy(queries)

assert len(queries) == len(tokens)
print(queries.dtype)

# NOTE(review): the keys memmap is not read anywhere in this script (FAISS
# answers the search); it is kept so the set of files touched is unchanged.
keys_from_memmap = np.memmap(dstore_filename + '_keys.npy', dtype=np.float16, mode='r',
                             shape=(dstore_size, dimension))
vals_from_memmap = np.memmap(dstore_filename + '_vals.npy', dtype=np.int64, mode='r',
                             shape=(dstore_size, 1))
print('Loading to memory...')
start = time.time()

# astype() copies the memory-mapped values into an in-memory array.
vals = vals_from_memmap[:]
vals = vals.astype(np.int64)

print('Loading to memory took {} s'.format(time.time() - start))

batch_size = 20000
# Ceiling division: the previous "len // batch_size + 1" scheduled an extra,
# empty batch whenever the query count was an exact multiple of batch_size.
# At least one batch is kept so np.concatenate below always has an input.
num_batches = max(1, (len(queries) + batch_size - 1) // batch_size)

all_probs = []
all_knns = []

for batch_idx in tqdm.tqdm(range(num_batches)):
    lo = batch_idx * batch_size
    hi = lo + batch_size
    batch_queries = queries[lo:hi]
    tgt = tokens[lo:hi].cuda()

    dists, knns = index.search(batch_queries.float().numpy(), topk)
    all_knns.append(knns)

    # Negated index distances act as logits; normalise over the topk axis.
    dists = torch.from_numpy(-1 * dists).cuda()
    probs = F.log_softmax(dists, dim=-1)

    # Datastore value of every retrieved neighbour, compared with the target.
    neighbor_vals = torch.from_numpy(vals[knns]).long().cuda().squeeze(-1)
    match = torch.eq(neighbor_vals, tgt.unsqueeze(-1))
    # FAISS reports id -1 when it finds fewer than topk neighbours; numpy
    # indexing would silently wrap that to the last datastore entry, so such
    # slots must never count as a match.
    match = match & torch.from_numpy(knns >= 0).cuda()

    # Additive mask: 0 where the neighbour's value equals the target,
    # -10000 elsewhere (a large finite penalty instead of -inf, for stability).
    index_mask = (match.float() - 1.0) * 10000.0
    yhat_knn_prob = torch.logsumexp(probs + index_mask, dim=-1)
    all_probs.append(yhat_knn_prob.cpu().numpy())


out_prefix = dstore_filename.split('/')[-1]
np.save(out_prefix + '_faiss_mask.npy', np.concatenate(all_probs))
np.save(out_prefix + '_faiss_mask_knns.npy', np.concatenate(all_knns, axis=0))
8 changes: 5 additions & 3 deletions analysis/kv_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,11 @@
'0.7_knn_recomp_scores.npy',
'0.8_knn_recomp_scores.npy',
'0.9_knn_recomp_scores.npy',
'0.05_knn_scores_2.npy',
'0.1_knn_scores_2.npy',
'0.2_knn_scores_2.npy',
'dstore_subsampled_0.05_real_mask.npy',
'dstore_subsampled_0.05_real_mask_fp32.npy',
'dstore_subsampled_0.05_faiss_mask.npy',
'dstore_subsampled_0.1_real_mask.npy',
'dstore_subsampled_0.1_faiss_mask.npy',
]

# extra_score_files = [f'epoch_scores/3v-att-init-finetune-epoch{e}.npy' for e in range(1, 40)]
Expand Down
78 changes: 78 additions & 0 deletions analysis/real_mask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import numpy as np
import time
import torch
import torch.nn.functional as F
import tqdm


@torch.jit.script
def my_cdist(x1, x2, x2_norm):
    """Compute clamped ``x2_norm^T - 2 * x1 @ x2^T + |x1|^2`` for 2-D inputs.

    With ``x2_norm`` equal to the squared L2 norm of each ``x2`` row (a column
    with one entry per row), the result is the pairwise squared Euclidean
    distance between the rows of ``x1`` and the rows of ``x2``. Values are
    clamped from below so rounding error cannot leave negative entries.

    Args:
        x1: matrix of query vectors (one per row).
        x2: matrix of key vectors (one per row).
        x2_norm: precomputed per-row term for ``x2``, transposed and broadcast
            across the rows of the output.

    Returns:
        Tensor with ``x1.shape[0]`` rows and ``x2.shape[0]`` columns, every
        entry at least the clamp floor.
    """
    sq_norm_x1 = x1.pow(2).sum(dim=-1, keepdim=True)
    bias_row = x2_norm.transpose(-2, -1)
    x2_t = x2.transpose(-2, -1)
    # One fused call computes bias_row + (-2) * (x1 @ x2_t).
    dist = torch.addmm(bias_row, x1, x2_t, alpha=-2)
    dist.add_(sq_norm_x1)
    # Lower bound applied in place, after the full expansion is assembled.
    dist = dist.clamp_min_(1e-20)
    return dist


# ---------------------------------------------------------------------------
# Experiment configuration: number of neighbours to retrieve, key width, and
# the files of the 0.05 subsampled datastore.
# ---------------------------------------------------------------------------
topk = 1024
dimension = 1024
dstore_filename = "checkpoints/wikitext103-bpe/dstore_subsampled_0.05"
dstore_size = 7661274

print('Loading tokens...')
tokens = np.load('tokens.npy')
tokens = torch.from_numpy(tokens)

print('Loading queries...')
queries = np.load('all_queries.npy').astype(np.float16)
queries = torch.from_numpy(queries)

assert len(queries) == len(tokens)
print(queries.dtype)

keys_from_memmap = np.memmap(dstore_filename + '_keys.npy', dtype=np.float16, mode='r',
                             shape=(dstore_size, dimension))
vals_from_memmap = np.memmap(dstore_filename + '_vals.npy', dtype=np.int64, mode='r',
                             shape=(dstore_size, 1))
print('Loading to memory...')
start = time.time()

# astype() copies the memory-mapped arrays into in-memory arrays.
keys = keys_from_memmap[:]
keys = keys.astype(np.float16)

vals = vals_from_memmap[:]
vals = vals.astype(np.int64)

keys = torch.from_numpy(keys)
vals = torch.from_numpy(vals)

print('Loading to memory took {} s'.format(time.time() - start))

# The per-key norm term is computed on a temporary GPU copy that is squared
# in place, so the keys uploaded on the next line are left unmodified.
# NOTE(review): the norms are stored in float16 - confirm they stay in range.
keys_norm = keys.cuda().pow_(2).sum(dim=-1, keepdim=True)
keys = keys.cuda()
vals = vals.cuda()
print(vals.dtype)
batch_size = 200
# Ceiling division: the previous "len // batch_size + 1" scheduled an extra,
# empty batch whenever the query count was an exact multiple of batch_size.
# At least one batch is kept so np.concatenate below always has an input.
num_batches = max(1, (len(queries) + batch_size - 1) // batch_size)

all_probs = []
all_knns = []

for batch_idx in tqdm.tqdm(range(num_batches)):
    lo = batch_idx * batch_size
    hi = lo + batch_size
    batch_queries = queries[lo:hi].cuda()
    tgt = tokens[lo:hi].cuda()

    # Exhaustive search: negate the distances so topk returns the nearest keys.
    distances = -my_cdist(batch_queries, keys, keys_norm)
    res = distances.topk(topk, dim=1)
    knns = res.indices
    all_knns.append(knns.cpu().numpy())

    # Negated distances of the selected neighbours act as logits.
    dists = res.values
    probs = F.log_softmax(dists, dim=-1)

    # Additive mask: 0 where the neighbour's value equals the target,
    # -10000 elsewhere (a large finite penalty instead of -inf, for stability).
    match = torch.eq(vals[knns].squeeze(-1), tgt.unsqueeze(-1))
    index_mask = (match.float() - 1.0) * 10000.0
    yhat_knn_prob = torch.logsumexp(probs + index_mask, dim=-1)
    all_probs.append(yhat_knn_prob.cpu().numpy())


out_prefix = dstore_filename.split('/')[-1]
np.save(out_prefix + '_real_mask.npy', np.concatenate(all_probs))
np.save(out_prefix + '_real_mask_knns.npy', np.concatenate(all_knns, axis=0))

77 changes: 77 additions & 0 deletions analysis/real_mask_0.1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import numpy as np
import time
import torch
import torch.nn.functional as F
import tqdm


@torch.jit.script
def my_cdist(x1, x2, x2_norm):
    """Compute clamped ``x2_norm^T - 2 * x1 @ x2^T + |x1|^2`` for 2-D inputs.

    With ``x2_norm`` equal to the squared L2 norm of each ``x2`` row (a column
    with one entry per row), the result is the pairwise squared Euclidean
    distance between the rows of ``x1`` and the rows of ``x2``.

    The expansion can come out slightly negative through rounding (or when
    ``x2_norm`` itself carries rounding error), so the result is clamped from
    below, matching the ``my_cdist`` used by the 0.05 variant of this script.

    Args:
        x1: matrix of query vectors (one per row).
        x2: matrix of key vectors (one per row).
        x2_norm: precomputed per-row term for ``x2``, transposed and broadcast
            across the rows of the output.

    Returns:
        Tensor with ``x1.shape[0]`` rows and ``x2.shape[0]`` columns, with no
        negative entries.
    """
    x1_norm = x1.pow(2).sum(dim=-1, keepdim=True)
    # addmm fuses "x2_norm^T + (-2) * (x1 @ x2^T)" into a single call.
    res = torch.addmm(x2_norm.transpose(-2, -1), x1, x2.transpose(-2, -1), alpha=-2).add_(x1_norm)
    # Same floor as the sibling implementation; removes negative distances.
    res = res.clamp_min_(1e-20)
    return res


# ---------------------------------------------------------------------------
# Experiment configuration: number of neighbours to retrieve, key width, and
# the files of the 0.1 subsampled datastore.
# ---------------------------------------------------------------------------
topk = 1024
dimension = 1024
dstore_filename = "checkpoints/wikitext103-bpe/dstore_subsampled_0.1"
dstore_size = 15322548

print('Loading tokens...')
tokens = np.load('tokens.npy')
tokens = torch.from_numpy(tokens)

print('Loading queries...')
queries = np.load('all_queries.npy').astype(np.float16)
queries = torch.from_numpy(queries)

assert len(queries) == len(tokens)
print(queries.dtype)

keys_from_memmap = np.memmap(dstore_filename + '_keys.npy', dtype=np.float16, mode='r',
                             shape=(dstore_size, dimension))
vals_from_memmap = np.memmap(dstore_filename + '_vals.npy', dtype=np.int64, mode='r',
                             shape=(dstore_size, 1))
print('Loading to memory...')
start = time.time()

# astype() copies the memory-mapped arrays into in-memory arrays.
keys = keys_from_memmap[:]
keys = keys.astype(np.float16)

vals = vals_from_memmap[:]
vals = vals.astype(np.int64)

keys = torch.from_numpy(keys)
vals = torch.from_numpy(vals)

print('Loading to memory took {} s'.format(time.time() - start))

# The per-key norm term is computed on a temporary GPU copy that is squared
# in place, so the keys uploaded on the next line are left unmodified.
# NOTE(review): the norms are stored in float16 - confirm they stay in range.
keys_norm = keys.cuda().pow_(2).sum(dim=-1, keepdim=True)
keys = keys.cuda()
vals = vals.cuda()

batch_size = 100
# Ceiling division: the previous "len // batch_size + 1" scheduled an extra,
# empty batch whenever the query count was an exact multiple of batch_size.
# At least one batch is kept so np.concatenate below always has an input.
num_batches = max(1, (len(queries) + batch_size - 1) // batch_size)

all_probs = []
all_knns = []

for batch_idx in tqdm.tqdm(range(num_batches)):
    lo = batch_idx * batch_size
    hi = lo + batch_size
    batch_queries = queries[lo:hi].cuda()
    tgt = tokens[lo:hi].cuda()

    # Exhaustive search: negate the distances so topk returns the nearest keys.
    distances = -my_cdist(batch_queries, keys, keys_norm)
    res = distances.topk(topk, dim=1)
    knns = res.indices
    all_knns.append(knns.cpu().numpy())

    # Negated distances of the selected neighbours act as logits.
    dists = res.values
    probs = F.log_softmax(dists, dim=-1)

    # Additive mask: 0 where the neighbour's value equals the target,
    # -10000 elsewhere (a large finite penalty instead of -inf, for stability).
    match = torch.eq(vals[knns].squeeze(-1), tgt.unsqueeze(-1))
    index_mask = (match.float() - 1.0) * 10000.0
    yhat_knn_prob = torch.logsumexp(probs + index_mask, dim=-1)
    all_probs.append(yhat_knn_prob.cpu().numpy())


out_prefix = dstore_filename.split('/')[-1]
np.save(out_prefix + '_real_mask.npy', np.concatenate(all_probs))
np.save(out_prefix + '_real_mask_knns.npy', np.concatenate(all_knns, axis=0))
Loading

0 comments on commit bff8d61

Please sign in to comment.