Commit 70d79f8 (parent: 3211c6c)

[gensim] Adapted code to handle the HDP model from gensim along with LDA models.
[requirements] Added gensim to test-requirements.
[tests] Added gensim tests to ensure the prepare/save_html functions still work with LDA and HDP models.

6 files changed, 84 additions and 7 deletions.

pyLDAvis/_prepare.py (+1 -1)

@@ -31,7 +31,7 @@


 def __num_dist_rows__(array, ndigits=2):
-    return int(pd.DataFrame(array).sum(axis=1).map(lambda x: round(x, ndigits)).sum())
+    return array.shape[0] - int((pd.DataFrame(array).sum(axis=1) < 0.999).sum())


 class ValidationError(ValueError):
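
The new __num_dist_rows__ no longer rounds each row sum and adds them up; it counts how many rows of the array sum to at least 0.999, which is more tolerant of the small floating-point drift that truncated HDP distributions can show. A minimal sketch of the new behaviour on toy data (illustrative only, not part of the commit):

# Minimal sketch (illustrative only) of the new behaviour: a row counts
# as a valid distribution unless its probabilities sum to less than 0.999.
import numpy as np
import pandas as pd

def num_dist_rows(array):
    return array.shape[0] - int((pd.DataFrame(array).sum(axis=1) < 0.999).sum())

dists = np.array([[0.5, 0.5],       # sums to 1.0   -> counted
                  [0.4999, 0.5],    # sums to ~1.0  -> counted
                  [0.4, 0.4]])      # sums to 0.8   -> not counted
print(num_dist_rows(dists))         # prints 2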

pyLDAvis/gensim.py (+15 -3)

@@ -33,14 +33,26 @@ def _extract_data(topic_model, corpus, dictionary, doc_topic_dists=None):
     assert doc_lengths.shape[0] == len(corpus), 'Document lengths and corpus have different sizes {} != {}'.format(doc_lengths.shape[0], len(corpus))

     if doc_topic_dists is None:
-        gamma, _ = topic_model.inference(corpus)
+        # If it's an HDP model.
+        if hasattr(topic_model, 'lda_beta'):
+            gamma = topic_model.inference(corpus)
+        else:
+            gamma, _ = topic_model.inference(corpus)
         doc_topic_dists = gamma / gamma.sum(axis=1)[:, None]

-    assert doc_topic_dists.shape[1] == topic_model.num_topics, 'Document topics and number of topics do not match {} != {}'.format(doc_topic_dists.shape[0], topic_model.num_topics)
+    if hasattr(topic_model, 'lda_alpha'):
+        num_topics = len(topic_model.lda_alpha)
+    else:
+        num_topics = topic_model.num_topics
+
+    assert doc_topic_dists.shape[1] == num_topics, 'Document topics and number of topics do not match {} != {}'.format(doc_topic_dists.shape[0], num_topics)

     # get the topic-term distribution straight from gensim without
     # iterating over tuples
-    topic = topic_model.state.get_lambda()
+    if hasattr(topic_model, 'lda_beta'):
+        topic = topic_model.lda_beta
+    else:
+        topic = topic_model.state.get_lambda()
     topic = topic / topic.sum(axis=1)[:, None]
     topic_term_dists = topic[:, fnames_argsort]

rtd_reqs.txt (+1)

@@ -2,3 +2,4 @@ jinja2==2.7.2
 numpydoc>=0.4
 pytest
 future
+gensim

setup.py (+2 -1)

@@ -50,7 +50,8 @@ def run_tests(self):

 test_requirements = [
     'pytest',
-    'funcy'
+    'funcy',
+    'gensim'
 ]

 setup(

tests/pyLDAvis/test_gensim_models.py (new file, +61)

@@ -0,0 +1,61 @@
+#!/usr/bin/env python
+
+
+from gensim.models import LdaModel, HdpModel
+from gensim.corpora.dictionary import Dictionary
+import pyLDAvis.gensim
+import os
+
+
+def get_corpus_dictionary():
+    """Crafts a toy corpus and the dictionary associated with it."""
+    # Toy corpus.
+    corpus = [
+        ['carrot', 'salad', 'tomato'],
+        ['carrot', 'salad', 'dish'],
+        ['tomato', 'dish'],
+        ['tomato', 'salad'],
+
+        ['car', 'break', 'highway'],
+        ['highway', 'accident', 'car'],
+        ['moto', 'break'],
+        ['accident', 'moto', 'car']
+    ]
+
+    dictionary = Dictionary(corpus)
+
+    # Transforming the corpus with the dictionary.
+    corpus = [dictionary.doc2bow(doc) for doc in corpus]
+
+    # Building the reverse index.
+    for (token, uid) in dictionary.token2id.items():
+        dictionary.id2token[uid] = token
+
+    return corpus, dictionary
+
+def test_lda():
+    """Trains an LDA model and tests the html outputs."""
+    corpus, dictionary = get_corpus_dictionary()
+
+    lda = LdaModel(corpus=corpus,
+                   num_topics=2)
+
+    data = pyLDAvis.gensim.prepare(lda, corpus, dictionary)
+    pyLDAvis.save_html(data, 'index_lda.html')
+    os.remove('index_lda.html')
+
+
+def test_hdp():
+    """Trains an HDP model and tests the html outputs."""
+    corpus, dictionary = get_corpus_dictionary()
+
+    hdp = HdpModel(corpus, dictionary.id2token)
+
+    data = pyLDAvis.gensim.prepare(hdp, corpus, dictionary)
+    pyLDAvis.save_html(data, 'index_hdp.html')
+    os.remove('index_hdp.html')
+
+
+if __name__ == "__main__":
+    test_lda()
+    test_hdp()
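
An aside on the reverse-index loop in get_corpus_dictionary: gensim's Dictionary appears to populate id2token lazily on first id lookup, so indexing the dictionary once would be an alternative way to fill it. This is an assumption about the Dictionary API, not something the commit relies on:

# Hedged aside (not from the commit): to my knowledge, Dictionary.id2token
# is rebuilt lazily on the first id lookup, which would make this an
# alternative to the explicit reverse-index loop in the test above.
from gensim.corpora.dictionary import Dictionary

dictionary = Dictionary([['carrot', 'salad', 'tomato']])
_ = dictionary[0]            # any lookup populates dictionary.id2token
print(dictionary.id2token)   # e.g. {0: 'carrot', 1: 'salad', 2: 'tomato'}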

tests/pyLDAvis/test_prepare.py (+4 -2)

@@ -74,8 +74,10 @@ def rounded_token_table(r):
         tt.Freq = tt.Freq.round(5)
         return tt
     ett, ott = both(rounded_token_table)
-    joined = pd.merge(ott, ett, on=['Freq', 'Term'], suffixes=['_o','_e'], how='inner')
-    most_likely_map = pd.DataFrame(joined.groupby('Topic_o')['Topic_e'].value_counts(), columns=['count']).query('count > 100')
+    joined = pd.DataFrame(pd.merge(ott, ett, on=['Freq', 'Term'], suffixes=['_o','_e'], how='inner')\
+                          .groupby('Topic_o')['Topic_e'].value_counts())
+    joined.columns = ['count']
+    most_likely_map = joined.query('count > 100')
     most_likely_map.index.names = ['Topic_o', 'Topic_e']
     df = pd.DataFrame(most_likely_map).reset_index()
     assert_array_equal(df['Topic_o'].values, df['Topic_e'].values)
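
The reshuffled lines build the DataFrame from the grouped value_counts() first and rename its single column afterwards; passing columns=['count'] to the constructor can silently drop the data when the underlying Series already carries a different name. A toy sketch of the rename-after-construction pattern (illustrative data only):

# Toy sketch (not from the test) of the rename-after-construction pattern:
# value_counts() returns a named Series, so the column is renamed after
# wrapping it in a DataFrame rather than via the DataFrame constructor.
import pandas as pd

df = pd.DataFrame({'Topic_o': [1, 1, 1, 2], 'Topic_e': [1, 1, 2, 2]})
counts = pd.DataFrame(df.groupby('Topic_o')['Topic_e'].value_counts())
counts.columns = ['count']
print(counts.query('count > 1'))   # keeps only the (Topic_o=1, Topic_e=1) pair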
