diff --git a/fairseq/knnlm.py b/fairseq/knnlm.py index 310f025..79e0345 100644 --- a/fairseq/knnlm.py +++ b/fairseq/knnlm.py @@ -1,5 +1,6 @@ import torch import faiss +import faiss.contrib.torch_utils import math import numpy as np from fairseq import utils @@ -27,16 +28,25 @@ def setup_faiss(self, args): print('Reading datastore took {} s'.format(time.time() - start)) index.nprobe = args.probe + if args.knnlm_gpu: + start = time.time() + co = faiss.GpuClonerOptions() + co.useFloat16 = True + index = faiss.index_cpu_to_gpu(faiss.StandardGpuResources(), 0, index, co) + print('Moving index to GPU took {} s'.format(time.time() - start)) + + dstore_float_dtype = np.float32 + dstore_int_dtype = np.int32 if args.dstore_fp16: - print('Keys are fp16 and vals are int16') - if not args.no_load_keys: - self.keys = np.memmap(args.dstore_filename+'_keys.npy', dtype=np.float16, mode='r', shape=(self.dstore_size, self.dimension)) - self.vals = np.memmap(args.dstore_filename+'_vals.npy', dtype=np.int16, mode='r', shape=(self.dstore_size, 1)) - else: - print('Keys are fp32 and vals are int64') - if not args.no_load_keys: - self.keys = np.memmap(args.dstore_filename+'_keys.npy', dtype=np.float32, mode='r', shape=(self.dstore_size, self.dimension)) - self.vals = np.memmap(args.dstore_filename+'_vals.npy', dtype=np.int, mode='r', shape=(self.dstore_size, 1)) + print('Keys are fp16') + dstore_float_dtype = np.float16 + + if not args.no_load_keys: + self.keys = np.memmap(args.dstore_filename+'_keys.npy', dtype=dstore_float_dtype, mode='r', + shape=(self.dstore_size, self.dimension)) + self.vals = np.memmap(args.dstore_filename+'_vals.npy', dtype=np.int64, mode='r', + shape=(self.dstore_size, 1)) + # self.vals = torch.from_numpy(self.vals).to(self.device) # If you wish to load all the keys into memory # CAUTION: Only do this if your RAM can handle it! 
@@ -46,16 +56,17 @@ def setup_faiss(self, args): if not args.no_load_keys: del self.keys - self.keys_from_memmap = np.memmap(args.dstore_filename+'_keys.npy', dtype=np.float32, mode='r', shape=(self.dstore_size, self.dimension)) - self.keys = np.zeros((self.dstore_size, self.dimension), dtype=np.float16 if args.dstore_fp16 else np.float32) + self.keys_from_memmap = np.memmap(args.dstore_filename+'_keys.npy', dtype=dstore_float_dtype, mode='r', shape=(self.dstore_size, self.dimension)) + self.keys = np.zeros((self.dstore_size, self.dimension), dtype=dstore_float_dtype) self.keys = self.keys_from_memmap[:] - self.keys = self.keys.astype(np.float16 if args.dstore_fp16 else np.float32) + self.keys = self.keys.astype(dstore_float_dtype) del self.vals - self.vals_from_memmap = np.memmap(args.dstore_filename+'_vals.npy', dtype=np.int, mode='r', shape=(self.dstore_size, 1)) - self.vals = np.zeros((self.dstore_size, 1), dtype=np.int16 if args.dstore_fp16 else np.int) - self.vals = self.vals_from_memmap[:] - self.vals = self.vals.astype(np.int16 if args.dstore_fp16 else np.int) + vals_from_memmap = np.memmap(args.dstore_filename+'_vals.npy', dtype=np.int64, mode='r', shape=(self.dstore_size, 1)) + self.vals = np.zeros((self.dstore_size, 1), dtype=np.int64) + self.vals = vals_from_memmap[:] + self.vals = self.vals.astype(dstore_int_dtype) + del vals_from_memmap print('Loading to memory took {} s'.format(time.time() - start)) return index @@ -110,7 +121,7 @@ def dist_func(d, k, q, function=None): # (T_reducedxB) yhat_knn_prob = torch.logsumexp(probs + index_mask, dim=-1).clone() - full_yhat_knn_prob = torch.full([qshape[0]*qshape[1]], -10000).cuda() + full_yhat_knn_prob = torch.full([qshape[0]*qshape[1]], -10000.0).cuda() full_yhat_knn_prob[tgt != pad_idx] = yhat_knn_prob # TxBx1 diff --git a/fairseq/options.py b/fairseq/options.py index 186c195..69f5cca 100644 --- a/fairseq/options.py +++ b/fairseq/options.py @@ -487,6 +487,7 @@ def add_eval_lm_args(parser): help='helpful for 
certain ops that are only used during eval') group.add_argument('--knnlm', action='store_true', help='use the k-nearest neighbors language model') + group.add_argument('--knnlm-gpu', action='store_true', help='move the faiss index to the GPU (device 0, fp16) for k-nearest neighbors search') group.add_argument('--save-knnlm-dstore', action='store_true', help='save keys for the knnlm datastore') group.add_argument('--dstore-mmap', default=None, type=str,