Skip to content

Commit

Permalink
add gpu flag
Browse files (browse the repository at this point in the history)
  • Loading branch information
urialon committed Mar 18, 2022
1 parent 8afab92 commit 0c1d86c
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 17 deletions.
45 changes: 28 additions & 17 deletions fairseq/knnlm.py
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
import torch
import faiss
import faiss.contrib.torch_utils
import math
import numpy as np
from fairseq import utils
Expand Down Expand Up @@ -27,16 +28,25 @@ def setup_faiss(self, args):
print('Reading datastore took {} s'.format(time.time() - start))
index.nprobe = args.probe

if args.knnlm_gpu:
start = time.time()
co = faiss.GpuClonerOptions()
co.useFloat16 = True
index = faiss.index_cpu_to_gpu(faiss.StandardGpuResources(), 0, index, co)
print('Moving index to GPU took {} s'.format(time.time() - start))

dstore_float_dtype = np.float32
dstore_int_dtype = np.int32
if args.dstore_fp16:
print('Keys are fp16 and vals are int16')
if not args.no_load_keys:
self.keys = np.memmap(args.dstore_filename+'_keys.npy', dtype=np.float16, mode='r', shape=(self.dstore_size, self.dimension))
self.vals = np.memmap(args.dstore_filename+'_vals.npy', dtype=np.int16, mode='r', shape=(self.dstore_size, 1))
else:
print('Keys are fp32 and vals are int64')
if not args.no_load_keys:
self.keys = np.memmap(args.dstore_filename+'_keys.npy', dtype=np.float32, mode='r', shape=(self.dstore_size, self.dimension))
self.vals = np.memmap(args.dstore_filename+'_vals.npy', dtype=np.int, mode='r', shape=(self.dstore_size, 1))
print('Keys are fp16')
dstore_float_dtype = np.float16

if not args.no_load_keys:
self.keys = np.memmap(args.dstore_filename+'_keys.npy', dtype=dstore_float_dtype, mode='r',
shape=(self.dstore_size, self.dimension))
self.vals = np.memmap(args.dstore_filename+'_vals.npy', dtype=np.int64, mode='r',
shape=(self.dstore_size, 1))
# self.vals = torch.from_numpy(self.vals).to(self.device)

# If you wish to load all the keys into memory
# CAUTION: Only do this if your RAM can handle it!
Expand All @@ -46,16 +56,17 @@ def setup_faiss(self, args):

if not args.no_load_keys:
del self.keys
self.keys_from_memmap = np.memmap(args.dstore_filename+'_keys.npy', dtype=np.float32, mode='r', shape=(self.dstore_size, self.dimension))
self.keys = np.zeros((self.dstore_size, self.dimension), dtype=np.float16 if args.dstore_fp16 else np.float32)
self.keys_from_memmap = np.memmap(args.dstore_filename+'_keys.npy', dtype=dstore_float_dtype, mode='r', shape=(self.dstore_size, self.dimension))
self.keys = np.zeros((self.dstore_size, self.dimension), dtype=dstore_float_dtype)
self.keys = self.keys_from_memmap[:]
self.keys = self.keys.astype(np.float16 if args.dstore_fp16 else np.float32)
self.keys = self.keys.astype(dstore_float_dtype)

del self.vals
self.vals_from_memmap = np.memmap(args.dstore_filename+'_vals.npy', dtype=np.int, mode='r', shape=(self.dstore_size, 1))
self.vals = np.zeros((self.dstore_size, 1), dtype=np.int16 if args.dstore_fp16 else np.int)
self.vals = self.vals_from_memmap[:]
self.vals = self.vals.astype(np.int16 if args.dstore_fp16 else np.int)
vals_from_memmap = np.memmap(args.dstore_filename+'_vals.npy', dtype=np.int, mode='r', shape=(self.dstore_size, 1))
self.vals = np.zeros((self.dstore_size, 1), dtype=np.int64)
self.vals = vals_from_memmap[:]
self.vals = self.vals.astype(dstore_int_dtype)
del vals_from_memmap
print('Loading to memory took {} s'.format(time.time() - start))

return index
Expand Down Expand Up @@ -110,7 +121,7 @@ def dist_func(d, k, q, function=None):

# (T_reducedxB)
yhat_knn_prob = torch.logsumexp(probs + index_mask, dim=-1).clone()
full_yhat_knn_prob = torch.full([qshape[0]*qshape[1]], -10000).cuda()
full_yhat_knn_prob = torch.full([qshape[0]*qshape[1]], -10000.0).cuda()
full_yhat_knn_prob[tgt != pad_idx] = yhat_knn_prob

# TxBx1
Expand Down
1 change: 1 addition & 0 deletions fairseq/options.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -487,6 +487,7 @@ def add_eval_lm_args(parser):
help='helpful for certain ops that are only used during eval')
group.add_argument('--knnlm', action='store_true',
help='use the k-nearest neighbors language model')
group.add_argument('--knnlm-gpu', action='store_true')
group.add_argument('--save-knnlm-dstore', action='store_true',
help='save keys for the knnlm datastore')
group.add_argument('--dstore-mmap', default=None, type=str,
Expand Down

0 comments on commit 0c1d86c

Please sign in to comment.