
main change: change TxB to BxT
urialon committed May 5, 2022
1 parent 0c1d86c commit 89a29d1
Showing 3 changed files with 98 additions and 5 deletions.
4 changes: 4 additions & 0 deletions fairseq/options.py
@@ -492,6 +492,10 @@ def add_eval_lm_args(parser):
                       help='save keys for the knnlm datastore')
    group.add_argument('--dstore-mmap', default=None, type=str,
                       help='If saving knnlm dstore, save keys and values to this file')
+   group.add_argument('--min-knns', default=1, type=int)
+   group.add_argument('--max-knns', default=None, type=int)
+   group.add_argument('--local', action='store_true')
+   group.add_argument('--no-pointer', action='store_true')
    # fmt: on
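
The four new flags are consumed by sequence_scorer.py below: --min-knns sets the number of live saved neighbors below which a fresh kNN search is triggered, --max-knns caps how many pointers are kept per step (defaulting to the full datastore size), --local gates a local faiss import in eval_lm.py, and --no-pointer disables the pointer mechanism entirely. A minimal sketch of the resulting parsing behavior, written for this page and not part of the commit:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--min-knns', default=1, type=int)     # below this many live pointers, fall back to a full kNN search
parser.add_argument('--max-knns', default=None, type=int)  # cap on pointers per step; None means the whole datastore
parser.add_argument('--local', action='store_true')        # gate the faiss import in eval_lm.py
parser.add_argument('--no-pointer', action='store_true')   # disable pointers entirely (plain kNN-LM scoring)

args = parser.parse_args(['--min-knns', '16', '--max-knns', '1024'])
assert args.min_knns == 16 and args.max_knns == 1024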


94 changes: 90 additions & 4 deletions fairseq/sequence_scorer.py
@@ -6,7 +6,8 @@
import torch
import sys
import numpy as np
-import time
+import scipy.sparse as sp
+import pickle

from fairseq import utils
from fairseq.data import Dictionary
@@ -22,6 +23,26 @@ def __init__(self, tgt_dict, softmax_batch=None, compute_alignment=False, args=None):
        assert self.softmax_batch > 0
        self.compute_alignment = compute_alignment
        self.args = args
+       self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+       if args.cluster is None:
+           self.cluster = None
+           print('No clustering is used.')
+       else:
+           self.cluster = np.load(args.cluster)
+
+       if args.members is None:
+           self.members = None
+           print('No cluster-members file is used.')
+       else:
+           with open(args.members, 'rb') as file:
+               self.members = pickle.load(file)
+
+       if self.members is None or self.cluster is None:
+           self.extend_pointers_using_clusters = lambda pointers: pointers
+       self.max_knns = self.args.max_knns if self.args.max_knns is not None else self.args.dstore_size

    @torch.no_grad()
    def generate(self, models, sample, **kwargs):
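
The two artifacts loaded above drive the pointer extension in extend_pointers_using_clusters at the end of this file: args.cluster maps each datastore entry to a cluster id, and args.members maps each cluster back to its member entries. The diff never shows their concrete types; the sketch below assumes a numpy array for cluster and a scipy CSR matrix for members, which is what the sp import and the np.nonzero(self.members[clusters])[1] access pattern suggest.

import numpy as np
import scipy.sparse as sp

dstore_size, n_clusters = 10, 3
# cluster[i] = id of the cluster containing datastore entry i
cluster = np.array([0, 0, 1, 2, 1, 0, 2, 2, 1, 0])
# members[c, i] != 0  iff  datastore entry i belongs to cluster c
members = sp.csr_matrix(
    (np.ones(dstore_size), (cluster, np.arange(dstore_size))),
    shape=(n_clusters, dstore_size))

# Same access pattern as extend_pointers_using_clusters below:
pointers = np.array([2, 4])              # both land in cluster 1
clusters = np.unique(cluster[pointers])  # -> [1]
print(np.nonzero(members[clusters])[1])  # -> [2 4 8], all of cluster 1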
@@ -99,10 +120,10 @@ def combine_knn_and_vocab_probs(knn_p, vocab_p, coeff):
                    raise ValueError('Only knn *log* probs are supported.')

                yhat_knn_prob = dstore.get_knn_log_prob(
-                   queries,
-                   orig_target.permute(1, 0),
+                   queries.permute(1, 0, 2),
+                   orig_target,
                    pad_idx=self.pad)
-               yhat_knn_prob = yhat_knn_prob.permute(1, 0, 2).squeeze(-1)
+               yhat_knn_prob = yhat_knn_prob.squeeze(-1)
                if self.args.fp16:
                    yhat_knn_prob = yhat_knn_prob.half()
                    probs = probs.half()
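
This hunk is the commit's headline change: the decoder emits queries time-major (T x B x C), and get_knn_log_prob now takes them batch-major (B x T x C), so the permute moves from the target onto the queries and the output no longer needs to be permuted back. A toy illustration with made-up shapes:

import torch

T, B, C = 5, 2, 4
queries = torch.randn(T, B, C)          # time-major, as produced by the decoder
batch_major = queries.permute(1, 0, 2)  # -> (B, T, C); a view, no data copy
assert batch_major.shape == (B, T, C)
assert torch.equal(batch_major[0, 3], queries[3, 0])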
@@ -159,3 +180,68 @@ def combine_knn_and_vocab_probs(knn_p, vocab_p, coeff):
                'dstore_keys': decoder_out[1][self.args.knn_keytype][start_idxs[i]:,i,:] if self.args.save_knnlm_dstore else None,
            }])
        return hypos
+
+    @torch.no_grad()
+    def score_with_knnlm(self, hypos, dstore):
+        # TxBxC
+        for hypo in hypos:
+            hypo = hypo[0]
+            queries = hypo['queries']  # (time, dim)
+            orig_target = hypo['tokens']  # (time, )
+            lm_probs = hypo['positional_scores']  # (time, )
+            if self.args.fp16:
+                lm_probs = lm_probs.half()
+
+            cur_knns = np.array([], dtype=np.int64)
+            cur_dists = np.array([], dtype=np.float32)
+
+            probs_per_timestep = []
+            for i in range(queries.size(0)):
+                perform_search = False
+                extended_pointers = None
+                cur_knns = cur_knns[cur_dists.argsort()[::-1]]
+                pointers = cur_knns + 1
+
+                if self.args.no_pointer or cur_knns.size < self.args.min_knns:
+                    perform_search = True
+
+                extended_pointers = pointers
+                if pointers.size >= self.max_knns:
+                    extended_pointers = extended_pointers[:self.max_knns]
+                elif pointers.size > 0 and not self.args.no_pointer:
+                    extended_pointers = self.extend_pointers_using_clusters(pointers)
+
+                cur_knn_log_prob, knns, correct_vals_mask, dists = dstore.get_knn_log_prob(
+                    queries[i, :].unsqueeze(0),
+                    orig_target[i].unsqueeze(0),
+                    pointers=None if self.args.no_pointer else extended_pointers.reshape(1, -1),
+                    perform_search=perform_search)
+
+                if self.args.fp16:
+                    cur_knn_log_prob = cur_knn_log_prob.half()
+
+                if not self.args.no_pointer:
+                    vals_are_correct_and_pointer_available = correct_vals_mask & (knns < self.args.dstore_size - 1)
+                    cur_knns = knns[vals_are_correct_and_pointer_available]
+                    cur_dists = dists[vals_are_correct_and_pointer_available]
+
+                combined = self.combine_knn_and_vocab_probs(
+                    cur_knn_log_prob, lm_probs[i].unsqueeze(0), self.args.lmbda, dstore)
+                probs_per_timestep.append(combined[0])
+
+            hypo['positional_scores'] = torch.as_tensor(probs_per_timestep)
+        return hypos
+
+    def extend_pointers_using_clusters(self, pointers):
+        # Don't take the same cluster twice
+        clusters, cluster_counts = np.unique(self.cluster[pointers], return_counts=True)
+        # Take the most-pointed-to clusters first (argsort on negated counts sorts descending)
+        clusters = clusters[np.argsort(-cluster_counts)]
+        members = np.nonzero(self.members[clusters])[1]
+        # Prefer datastore entries that were directly pointed to by the previous time step's
+        # datastore entries, over other members of their cluster
+        extended_pointers = np.concatenate([pointers, members])
+        if len(extended_pointers) > self.max_knns:
+            extended_pointers = extended_pointers[:self.max_knns]
+        return extended_pointers
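
combine_knn_and_vocab_probs, called once per timestep above, is the standard kNN-LM interpolation: p(y) = λ·p_knn(y) + (1−λ)·p_LM(y) with λ = args.lmbda. A self-contained sketch of that formula in log space (an assumption about the helper's internals, which the diff does not show):

import math
import torch

def interpolate_log_probs(knn_log_p, lm_log_p, lmbda):
    # log(lmbda * exp(knn) + (1 - lmbda) * exp(lm)), computed stably via logsumexp
    return torch.logsumexp(torch.stack([
        knn_log_p + math.log(lmbda),
        lm_log_p + math.log(1.0 - lmbda),
    ]), dim=0)

knn_log_p = torch.tensor([0.8]).log()
lm_log_p = torch.tensor([0.2]).log()
print(interpolate_log_probs(knn_log_p, lm_log_p, 0.25).exp())  # tensor([0.3500]) = 0.25*0.8 + 0.75*0.2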
5 changes: 4 additions & 1 deletion fairseq_cli/eval_lm.py
@@ -12,6 +12,9 @@
import math
import os

+if '--local' in os.sys.argv:
+    import faiss
+
import torch
import numpy as np
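
The guarded import keeps faiss an optional dependency: it is only loaded when --local appears on the command line (os.sys is simply an alias for the sys module). A sketch of the same pattern with an explicit failure message, using nothing beyond the standard library:

import sys

if '--local' in sys.argv:
    try:
        import faiss  # only needed when the index is searched in-process
    except ImportError as err:
        raise ImportError('--local requires the faiss package to be installed') from err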

@@ -160,7 +163,7 @@ def main(parsed_args):
    if args.dstore_fp16:
        print('Saving fp16')
        dstore_keys = np.memmap(args.dstore_mmap+'_keys.npy', dtype=np.float16, mode='w+', shape=(args.dstore_size, args.decoder_embed_dim))
-       dstore_vals = np.memmap(args.dstore_mmap+'_vals.npy', dtype=np.int16, mode='w+', shape=(args.dstore_size, 1))
+       dstore_vals = np.memmap(args.dstore_mmap+'_vals.npy', dtype=np.int, mode='w+', shape=(args.dstore_size, 1))
    else:
        print('Saving fp32')
        dstore_keys = np.memmap(args.dstore_mmap+'_keys.npy', dtype=np.float32, mode='w+', shape=(args.dstore_size, args.decoder_embed_dim))
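
The dtype fix matters because np.memmap writes a raw buffer with no header: the file must be reopened with exactly the dtype and shape used at write time, and np.int16 silently wraps around for vocabulary ids above 32767, corrupting the stored values. A toy demonstration (the file path and int64 dtype are illustrative, not from the commit):

import numpy as np

vals = np.memmap('/tmp/dstore_vals.npy', dtype=np.int64, mode='w+', shape=(4, 1))
vals[:] = np.array([[50000], [1], [2], [3]])  # token id 50000 would not fit in int16
vals.flush()

# Reading back requires the identical dtype and shape
readback = np.memmap('/tmp/dstore_vals.npy', dtype=np.int64, mode='r', shape=(4, 1))
assert int(readback[0, 0]) == 50000
print(np.array([50000]).astype(np.int16)[0])  # -15536: what int16 truncation would store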
