-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathlda_example.py
58 lines (49 loc) · 2.09 KB
/
lda_example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from __future__ import print_function
from collections import defaultdict
from lda import LatentDirichletAllocation
def load_data(path="./data/nips.txt", min_term_occ=5):
    """Load a text corpus and encode each document as vocabulary indices.

    The file is read as one document per line, with the terms of a document
    separated by ``", "``.  Terms that occur fewer than ``min_term_occ``
    times across the whole file are left out of the vocabulary and removed
    from the documents.

    Args:
        path: Location of the corpus text file.
        min_term_occ: Minimum number of occurrences (summed over all lines)
            a term needs in order to be kept.

    Returns:
        A ``(docs, vocabulary)`` tuple.  ``vocabulary`` is the list of kept
        terms; ``docs`` has one entry per input line, each a list of indices
        into ``vocabulary``.  If the file cannot be read, the failure is
        printed and ``([], [])`` is returned.
    """
    # Only the file access is guarded, so mistakes in the processing below
    # surface as normal tracebacks instead of being printed and ignored.
    try:
        # The context manager closes the handle even when reading fails.
        with open(path) as corpus:
            lines = corpus.readlines()
    except (IOError, OSError, UnicodeError) as e:
        print("Failed to load dataset with exception: ", e)
        return [], []

    term_occurrences = defaultdict(int)
    docs = []
    for line in lines:
        # Drop the line terminator before splitting; otherwise the last term
        # of every line would be counted as "term\n", distinct from "term".
        # Empty strings (blank lines) are not terms either.
        doc = [term for term in line.rstrip("\r\n").split(", ") if term]
        docs.append(doc)
        for term in doc:
            term_occurrences[term] += 1

    vocabulary = [term for term, occ in term_occurrences.items()
                  if occ >= min_term_occ]
    inv_voc = {term: t for t, term in enumerate(vocabulary)}
    # Re-encode every document, discarding terms that missed the threshold.
    docs = [[inv_voc[t] for t in doc if t in inv_voc] for doc in docs]
    return docs, vocabulary
def _print_top_topics(topics, count=5):
    """Print the ``count`` largest entries of ``topics`` with their indices.

    ``topics`` only needs to support ``len()`` and integer indexing.
    """
    ranked = sorted(range(len(topics)), key=topics.__getitem__, reverse=True)
    print(["(Top {} : {})".format(w, topics[w]) for w in ranked[:count]])


def main():
    """Run the LDA example end to end on the NIPS dataset.

    Loads the corpus, trains a 50-topic model for 50 iterations, saves the
    parameters and topic terms, then prints the top topics for the first
    document three ways: inferred by the trained model, read from the
    training-time document-topic matrix, and inferred by a model reloaded
    from disk.

    Raises:
        SystemExit: If no documents could be loaded, since every later step
            needs at least ``docs[0]``.
    """
    print("Loading NIPS dataset.")
    docs, vocabulary = load_data()
    if not docs:
        # load_data reports read failures itself and returns empty lists;
        # stop here rather than train on nothing and fail on docs[0] below.
        raise SystemExit("No documents loaded; aborting LDA example.")
    print("NIPS dataset loaded.")

    # test training
    lda = LatentDirichletAllocation(num_topics=50)
    print("Initialized LDA. Estimating parameters with 50 iterations.")
    lda.train(docs=docs, vocabulary=vocabulary, num_iterations=50)
    lda.save_parameters()
    print("Saved LDA Parameters.")

    # test inference model after training
    _print_top_topics(lda.infer_doc(docs[0]))

    print("\n\nComparing to doc in training set..")
    _print_top_topics(lda.doc_topic_matrix[0])

    # Save topic terms to file
    lda.save_topic_terms()
    print("Topic Terms file saved to ./data/topic_terms.txt")

    # Load model from scratch from file
    # NOTE(review): presumably ./data/ttm.mat is what save_parameters() wrote
    # above -- confirm against the lda module.
    lda = LatentDirichletAllocation(model_path="./data/ttm.mat")
    topics = lda.infer_doc(docs[0])
    print("\n\nInferred model:")
    _print_top_topics(topics)
# Run the example only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()