-
Notifications
You must be signed in to change notification settings - Fork 0
/
suggest_reviewers.py
58 lines (50 loc) · 2.03 KB
/
suggest_reviewers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import json
from model_utils import Example, unk_string
from sacremoses import MosesTokenizer
import argparse
from models import load_model
import numpy as np
import cvxpy as cp
import sys, math
from suggest_utils import calc_reviewer_db_mapping, print_text_report, print_progress
BATCH_SIZE = 128
entok = MosesTokenizer(lang='en')
def create_embeddings(model, examps):
    """Embed textual examples with the given sentence-embedding model.

    :param model: A loaded similarity model exposing ``sp``, ``vocab``,
        ``zero_unk``, ``args``, ``torchify_batch`` and ``encode``
    :param examps: A list of text to embed
    :return: A len(examps) by embedding size numpy matrix of embeddings
    """
    # --- Stage 1: tokenize / subword-encode every example ---
    print(f'Preprocessing {len(examps)} examples (.={BATCH_SIZE} examples)', file=sys.stderr)
    prepared = []
    for idx, text in enumerate(examps):
        sentence = " ".join(entok.tokenize(text, escape=False)).lower()
        if model.sp is not None:
            # Apply the model's sentencepiece segmentation when available
            sentence = " ".join(model.sp.EncodeAsPieces(sentence))
        example = Example(sentence)
        example.populate_embeddings(model.vocab, model.zero_unk, model.args.ngrams)
        if not example.embeddings:
            # Guarantee at least one token so the encoder never sees an empty input
            example.embeddings.append(model.vocab[unk_string])
        prepared.append(example)
        print_progress(idx, BATCH_SIZE)
    print("", file=sys.stderr)
    # --- Stage 2: encode in fixed-size batches ---
    print(f'Embedding {len(examps)} examples (.={BATCH_SIZE} examples)', file=sys.stderr)
    embeddings = np.zeros((len(examps), model.args.dim))
    for start in range(0, len(prepared), BATCH_SIZE):
        end = min(start + BATCH_SIZE, len(prepared))
        batch_x, batch_lens = model.torchify_batch(prepared[start:end])
        vecs = model.encode(batch_x, batch_lens).detach().cpu().numpy()
        # L2-normalize rows so dot products behave as cosine similarity in NN search
        row_norms = np.sqrt((vecs * vecs).sum(axis=1))[:, None]
        embeddings[start:end] = vecs / row_norms
        print_progress(start, BATCH_SIZE)
    print("", file=sys.stderr)
    return embeddings
def calc_similarity_matrix(model, db, quer):
    """Compute pairwise similarity between query texts and database texts.

    Both sides are embedded with :func:`create_embeddings` (rows are
    L2-normalized there), so the inner products below are cosine similarities.

    :param model: The embedding model passed through to ``create_embeddings``
    :param db: List of database texts (one row each in the result's columns)
    :param quer: List of query texts (one row each in the result)
    :return: A len(quer) x len(db) numpy matrix of similarity scores
    """
    database_vectors = create_embeddings(model, db)
    query_vectors = create_embeddings(model, quer)
    print(f'Performing similarity calculation', file=sys.stderr)
    return query_vectors @ database_vectors.T