forked from PaddlePaddle/PaddleHub
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path module.py
193 lines (158 loc) · 6.61 KB
/
module.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
import os
import paddlehub as hub
from paddlehub.module.module import moduleinfo
from paddlehub.common.logger import logger
from lda_news.inference_engine import InferenceEngine
from lda_news.document import LDADoc, SLDADoc
from lda_news.semantic_matching import SemanticMatching, WordAndDis
from lda_news.tokenizer import LACTokenizer, SimpleTokenizer
from lda_news.config import ModelType
from lda_news.vocab import Vocab, WordCount
@moduleinfo(
    name="lda_news",
    version="1.0.2",
    summary=
    "This is a PaddleHub Module for LDA topic model in news dataset, where we can calculate doc distance, calculate the similarity between query and document, etc",
    author="DesmonDay",
    author_email="",
    type="nlp/semantic_model")
class TopicModel(hub.Module):
    """
    PaddleHub module exposing an LDA topic model trained on a news dataset.

    Provides document distance, document keyword extraction, query/document
    similarity, document topic inference and per-topic keyword listing.
    """

    def _initialize(self):
        """
        Initialize with the necessary elements.

        Loads the LDA inference engine from the module's 'news' directory,
        builds a LAC-based tokenizer, and caches the vocabulary, config,
        per-topic word lists and topic sums. Each topic's word list is sorted
        by descending count so show_topic_keywords can take the first k
        entries directly.
        """
        self.model_dir = os.path.join(self.directory, 'news')
        self.conf_file = 'lda.conf'
        self.__engine = InferenceEngine(self.model_dir, self.conf_file)
        self.vocab_path = os.path.join(self.model_dir, 'vocab_info.txt')
        lac = hub.Module(name="lac")
        # NOTE: SimpleTokenizer(self.vocab_path) is an alternative tokenizer
        # that takes only the vocab path (no LAC module); LAC is used here.
        self.__tokenizer = LACTokenizer(self.vocab_path, lac)

        model = self.__engine.get_model()
        self.vocabulary = model.get_vocab()
        self.config = self.__engine.get_config()
        self.topic_words = model.topic_words()
        self.topic_sum_table = model.topic_sum()

        # Order every topic's words from most to least frequent; the keyword
        # lookup in show_topic_keywords relies on this ordering.
        for i in range(self.config.num_topics):
            self.topic_words[i].sort(key=lambda word_count: word_count.count, reverse=True)
        logger.info("Finish initialization.")

    def cal_doc_distance(self, doc_text1, doc_text2):
        """
        This interface calculates the distance between documents.

        Args:
            doc_text1(str): the input document text 1.
            doc_text2(str): the input document text 2.

        Returns:
            jsd(float): Jensen-Shannon Divergence distance of two documents.
            hd(float): Hellinger Distance of two documents.
            Both are returned together as a (jsd, hd) tuple; smaller values
            mean higher semantic similarity.
        """
        doc1_tokens = self.__tokenizer.tokenize(doc_text1)
        doc2_tokens = self.__tokenizer.tokenize(doc_text2)

        # Document topic inference.
        doc1, doc2 = LDADoc(), LDADoc()
        self.__engine.infer(doc1_tokens, doc1)
        self.__engine.infer(doc2_tokens, doc2)

        # To calculate jsd, we need dense document topic distribution.
        dense_dict1 = doc1.dense_topic_dist()
        dense_dict2 = doc2.dense_topic_dist()

        # Calculate the distance between distributions.
        # The smaller the distance, the higher the document semantic similarity.
        sm = SemanticMatching()
        jsd = sm.jensen_shannon_divergence(dense_dict1, dense_dict2)
        hd = sm.hellinger_distance(dense_dict1, dense_dict2)
        return jsd, hd

    def cal_doc_keywords_similarity(self, document, top_k=10):
        """
        This interface can be used to find top k keywords of document.

        Args:
            document(str): the input document text.
            top_k(int): top k keywords of this document. Non-positive values
                yield an empty list.

        Returns:
            results(list): up to top_k dicts of the form
                {"word": ..., "similarity": ...}, ordered by descending
                likelihood-based similarity between the word and the document.
        """
        d_tokens = self.__tokenizer.tokenize(document)
        # Nothing to score: skip inference entirely, as
        # infer_doc_topic_distribution does for empty input.
        if not d_tokens:
            return []

        # Do topic inference on documents to obtain topic distribution.
        doc = LDADoc()
        self.__engine.infer(d_tokens, doc)
        doc_topic_dist = doc.sparse_topic_dist()

        # The matcher and the model do not change per word; fetch them once.
        sm = SemanticMatching()
        model = self.__engine.get_model()

        items = []
        # dict.fromkeys de-duplicates the tokens while keeping the order in
        # which each word first appears in the document.
        for word in dict.fromkeys(d_tokens):
            wd = WordAndDis()
            wd.word = word
            wd.distance = sm.likelihood_based_similarity(
                terms=[word], doc_topic_dist=doc_topic_dist, model=model)
            items.append(wd)

        # list.sort is stable, so words with equal similarity keep their
        # first-occurrence order.
        items.sort(key=lambda word_dis: word_dis.distance, reverse=True)
        # Clamp at 0 so a negative top_k yields [] instead of slicing from the end.
        return [{"word": item.word, "similarity": item.distance}
                for item in items[:max(top_k, 0)]]

    def cal_query_doc_similarity(self, query, document):
        """
        This interface calculates the similarity between query and document.

        Args:
            query(str): the input query text.
            document(str): the input document text.

        Returns:
            lda_sim(float): likelihood based similarity between query and document
                based on LDA.
        """
        q_tokens = self.__tokenizer.tokenize(query)
        d_tokens = self.__tokenizer.tokenize(document)

        # Only the document needs topic inference; the query is scored
        # against the document's topic distribution.
        doc = LDADoc()
        self.__engine.infer(d_tokens, doc)
        doc_topic_dist = doc.sparse_topic_dist()

        sm = SemanticMatching()
        lda_sim = sm.likelihood_based_similarity(q_tokens, doc_topic_dist, self.__engine.get_model())
        return lda_sim

    def infer_doc_topic_distribution(self, document):
        """
        This interface infers the topic distribution of document.

        Args:
            document(str): the input document text.

        Returns:
            results(list): the sparse topic distribution of document, as dicts
                of the form {"topic id": ..., "distribution": ...}. Empty when
                the document yields no tokens.
        """
        tokens = self.__tokenizer.tokenize(document)
        if not tokens:
            return []

        doc = LDADoc()
        self.__engine.infer(tokens, doc)
        return [{"topic id": topic.tid, "distribution": topic.prob}
                for topic in doc.sparse_topic_dist()]

    def show_topic_keywords(self, topic_id, k=10):
        """
        This interface returns first k keywords under specific topic.

        Args:
            topic_id(int): topic information we want to know; valid ids satisfy
                0 <= topic_id < config.num_topics.
            k(int): top k keywords.

        Returns:
            results(dict): maps each of the topic's (up to) k most frequent
                keywords to its probability, i.e. the word's count divided by
                the topic's entry in topic_sum_table. If topic_id is out of
                range, an error is logged and an empty dict is returned.
        """
        EPS = 1e-8  # avoids division by zero when a topic's sum is 0
        if not 0 <= topic_id < self.config.num_topics:
            logger.error("%d is out of range!" % topic_id)
            return {}

        # topic_words[topic_id] was sorted by descending count in _initialize.
        words = self.topic_words[topic_id]
        denominator = self.topic_sum_table[topic_id] + EPS
        results = {}
        for i in range(min(k, len(words))):
            word_count = words[i]
            results[self.vocabulary[word_count.word_id]] = word_count.count / denominator
        return results