From 89a29d1ac6e8c1360637aa1bfe77a1be227e83cc Mon Sep 17 00:00:00 2001
From: urialon
Date: Thu, 5 May 2022 14:47:25 -0400
Subject: [PATCH] main change: change TxB to BxT

---
 fairseq/options.py         |  4 ++
 fairseq/sequence_scorer.py | 94 ++++++++++++++++++++++++++++++++++++--
 fairseq_cli/eval_lm.py     |  5 +-
 3 files changed, 98 insertions(+), 5 deletions(-)

diff --git a/fairseq/options.py b/fairseq/options.py
index 69f5cca..f0950cf 100644
--- a/fairseq/options.py
+++ b/fairseq/options.py
@@ -492,6 +492,10 @@ def add_eval_lm_args(parser):
                        help='save keys for the knnlm datastore')
     group.add_argument('--dstore-mmap', default=None, type=str,
                        help='If saving knnlm dstore, save keys and values to this file')
+    group.add_argument('--min-knns', default=1, type=int, help='re-run the kNN search when fewer than this many saved neighbors remain')
+    group.add_argument('--max-knns', default=None, type=int, help='cap on the number of pointers kept per step (defaults to --dstore-size)')
+    group.add_argument('--local', action='store_true')
+    group.add_argument('--no-pointer', action='store_true', help='disable pointer reuse; run a full kNN search at every step')
     # fmt: on
 
 
diff --git a/fairseq/sequence_scorer.py b/fairseq/sequence_scorer.py
index e09fcf4..6b1c74d 100644
--- a/fairseq/sequence_scorer.py
+++ b/fairseq/sequence_scorer.py
@@ -6,7 +6,8 @@
 import torch
 import sys
 import numpy as np
-import time
+import scipy.sparse as sp
+import pickle
 
 from fairseq import utils
 from fairseq.data import Dictionary
@@ -22,6 +23,26 @@ def __init__(self, tgt_dict, softmax_batch=None, compute_alignment=False, args=None):
         assert self.softmax_batch > 0
         self.compute_alignment = compute_alignment
         self.args = args
+        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+        if args.cluster is None:
+            self.cluster = None
+            print('No clustering is used.')
+        else:
+            self.cluster = np.load(args.cluster)
+
+        if args.members is None:
+            self.members = None
+            print('No cluster-members file is used.')
+        else:
+            with open(args.members, 'rb') as file:
+                self.members = pickle.load(file)
+
+        if self.members is None or self.cluster is None:
+            self.extend_pointers_using_clusters = lambda pointers: pointers
+        self.max_knns = self.args.max_knns if self.args.max_knns is not None else self.args.dstore_size
+
+
 
     @torch.no_grad()
     def generate(self, models, sample, **kwargs):
@@ -99,10 +120,10 @@ def combine_knn_and_vocab_probs(knn_p, vocab_p, coeff):
                         raise ValueError('Only knn *log* probs are supported.')
 
                     yhat_knn_prob = dstore.get_knn_log_prob(
-                        queries,
-                        orig_target.permute(1, 0),
+                        queries.permute(1, 0, 2),
+                        orig_target,
                         pad_idx=self.pad)
-                    yhat_knn_prob = yhat_knn_prob.permute(1, 0, 2).squeeze(-1)
+                    yhat_knn_prob = yhat_knn_prob.squeeze(-1)
                     if self.args.fp16:
                         yhat_knn_prob = yhat_knn_prob.half()
                         probs = probs.half()
@@ -159,3 +180,68 @@ def combine_knn_and_vocab_probs(knn_p, vocab_p, coeff):
                 'dstore_keys': decoder_out[1][self.args.knn_keytype][start_idxs[i]:,i,:] if self.args.save_knnlm_dstore else None,
             }])
         return hypos
+
+    @torch.no_grad()
+    def score_with_knnlm(self, hypos, dstore):
+        # Score one hypothesis at a time, reusing datastore pointers across timesteps.
+        for hypo in hypos:
+            hypo = hypo[0]
+            queries = hypo['queries']  # (time, dim)
+            orig_target = hypo['tokens']  # (time, )
+            lm_probs = hypo['positional_scores']  # (time, )
+            if self.args.fp16:
+                lm_probs = lm_probs.half()
+
+            cur_knns = np.array([], dtype=np.int64)
+            cur_dists = np.array([], dtype=np.float32)
+
+            probs_per_timestep = []
+            for i in range(queries.size(0)):
+                perform_search = False
+                extended_pointers = None
+                cur_knns = cur_knns[cur_dists.argsort()[::-1]]
+                pointers = cur_knns + 1  # advance each surviving neighbor to its successor datastore entry
+
+                if self.args.no_pointer or cur_knns.size < self.args.min_knns:
+                    perform_search = True
+
+                extended_pointers = pointers
+                if pointers.size >= self.max_knns:
+                    extended_pointers = extended_pointers[:self.max_knns]
+                elif pointers.size > 0 and not self.args.no_pointer:
+                    extended_pointers = self.extend_pointers_using_clusters(pointers)
+
+                cur_knn_log_prob, knns, correct_vals_mask, dists = dstore.get_knn_log_prob(
+                    queries[i, :].unsqueeze(0),
+                    orig_target[i].unsqueeze(0),
+                    pointers=None if self.args.no_pointer else extended_pointers.reshape(1, -1),
+                    perform_search=perform_search)
+
+                if self.args.fp16:
+                    cur_knn_log_prob = cur_knn_log_prob.half()
+
+                if not self.args.no_pointer:
+                    vals_are_correct_and_pointer_available = correct_vals_mask & (knns < self.args.dstore_size - 1)
+                    cur_knns = knns[vals_are_correct_and_pointer_available]
+                    cur_dists = dists[vals_are_correct_and_pointer_available]
+
+                combined = self.combine_knn_and_vocab_probs(
+                    cur_knn_log_prob, lm_probs[i].unsqueeze(0), self.args.lmbda, dstore)
+                probs_per_timestep.append(combined[0])
+
+
+            hypo['positional_scores'] = torch.as_tensor(probs_per_timestep)
+        return hypos
+
+    def extend_pointers_using_clusters(self, pointers):
+        # Don't take the same cluster twice
+        clusters, cluster_counts = np.unique(self.cluster[pointers], return_counts=True)
+        # Visit the clusters that contain more of the current pointers first
+        clusters = clusters[np.argsort(-cluster_counts)]
+        members = np.nonzero(self.members[clusters])[1]
+        # Prefer datastore entries that were directly pointed to by the previous time step's
+        # datastore entries, over other members of their cluster
+        extended_pointers = np.concatenate([pointers, members])
+        if len(extended_pointers) > self.max_knns:
+            extended_pointers = extended_pointers[:self.max_knns]
+        return extended_pointers
\ No newline at end of file
diff --git a/fairseq_cli/eval_lm.py b/fairseq_cli/eval_lm.py
index 200f870..13ab4a2 100644
--- a/fairseq_cli/eval_lm.py
+++ b/fairseq_cli/eval_lm.py
@@ -12,6 +12,9 @@
 import math
 import os
 
+if '--local' in os.sys.argv:  # runs before argparse; os.sys is the stdlib sys module
+    import faiss
+
 import torch
 import numpy as np
 
@@ -160,7 +163,7 @@ def main(parsed_args):
         if args.dstore_fp16:
             print('Saving fp16')
             dstore_keys = np.memmap(args.dstore_mmap+'_keys.npy', dtype=np.float16, mode='w+', shape=(args.dstore_size, args.decoder_embed_dim))
-            dstore_vals = np.memmap(args.dstore_mmap+'_vals.npy', dtype=np.int16, mode='w+', shape=(args.dstore_size, 1))
+            dstore_vals = np.memmap(args.dstore_mmap+'_vals.npy', dtype=np.int64, mode='w+', shape=(args.dstore_size, 1))
         else:
             print('Saving fp32')
             dstore_keys = np.memmap(args.dstore_mmap+'_keys.npy', dtype=np.float32, mode='w+', shape=(args.dstore_size, args.decoder_embed_dim))
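
Note for reviewers: to make the cluster-extension step above easier to check in isolation, here is a minimal standalone sketch of the same logic. The patch reads args.cluster and args.members in __init__ but does not add those flags in its options.py hunk, so they are presumably defined elsewhere in the fork. The toy cluster array, the sparse members matrix, and MAX_KNNS below are illustrative stand-ins for the loaded files and self.max_knns; the variable name cluster_members is introduced here only to avoid shadowing.

    import numpy as np
    import scipy.sparse as sp

    # Toy datastore: 10 entries assigned to 3 clusters (sizes are illustrative).
    cluster = np.array([0, 0, 1, 1, 1, 2, 2, 0, 1, 2])  # entry index -> cluster id

    # members[c, i] != 0 iff entry i belongs to cluster c, mirroring the pickled
    # sparse matrix that __init__ loads from args.members.
    members = sp.csr_matrix(
        (np.ones_like(cluster), (cluster, np.arange(len(cluster)))),
        shape=(cluster.max() + 1, len(cluster)),
    )

    MAX_KNNS = 6  # stands in for self.max_knns

    def extend_pointers_using_clusters(pointers):
        # Don't take the same cluster twice.
        clusters, cluster_counts = np.unique(cluster[pointers], return_counts=True)
        # Visit the clusters that contain more of the current pointers first.
        clusters = clusters[np.argsort(-cluster_counts)]
        # np.nonzero on the sparse row slice returns (rows, cols); the column
        # indices are the datastore entries belonging to the selected clusters.
        cluster_members = np.nonzero(members[clusters])[1]
        # Direct pointers go first, so truncation drops cluster members, not them.
        extended = np.concatenate([pointers, cluster_members])
        return extended[:MAX_KNNS]

    print(extend_pointers_using_clusters(np.array([2, 3])))
    # -> [2 3 2 3 4 8]: the two original pointers, then the rest of cluster 1.

As in the patched code, an entry that is both a direct pointer and a member of a selected cluster appears twice, and the cap at max_knns is applied only after concatenation, so the direct pointers always survive truncation.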