-
Notifications
You must be signed in to change notification settings - Fork 0
/
cache_paper_embeddings.py
28 lines (23 loc) · 822 Bytes
/
cache_paper_embeddings.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import sys
import torch
import json
import pickle
import gzip
from suggest_utils import calc_reviewer_db_mapping, print_text_report, print_progress
from suggest_reviewers import create_embeddings, calc_similarity_matrix
from models import load_model
import torch
# Compute an embedding for every cached submission abstract and pickle the
# result to paper_embeddings.pkl.
#
# Input: ../cached_or.pkl, a pickled dict whose values expose
# `.content["abstract"]` (presumably OpenReview notes -- TODO confirm).
# NOTE(review): pickle.load can execute arbitrary code; only point this at a
# cache file produced by a trusted process.
with open("../cached_or.pkl", "rb") as fin:
    accepted_submissions = pickle.load(fin)

# Collect abstracts in dict iteration order. abstract_keys[i] is the key of
# abstracts[i], so the two lists stay aligned by construction.
abstract_keys = list(accepted_submissions)
abstracts = [accepted_submissions[k].content["abstract"] for k in abstract_keys]

# Load the trained similarity model and switch it to inference mode.
print('Loading model', file=sys.stderr)
model, epoch = load_model(None, "scratch/similarity-model.pt")
model.eval()
# Sanity check: eval() must have cleared the training flag.
assert not model.training

# Embed every abstract and cache the result. Only the embeddings are written,
# not the keys; consumers presumably recover keys from cached_or.pkl in the
# same iteration order -- verify against create_embeddings / downstream users.
paper_embs = create_embeddings(model, abstracts)
# `with` guarantees the output file is flushed and closed before exit.
with open("paper_embeddings.pkl", "wb") as fout:
    pickle.dump(paper_embs, fout)