-
-
Notifications
You must be signed in to change notification settings - Fork 4.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] HDP #1055
Merged
Merged
[WIP] HDP #1055
Changes from 2 commits
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
7bc3ddb
Added print methods, lda_model
bhargavvader 4d3af60
Added HDP tests
bhargavvader ae9de89
Merge branch 'develop' of https://github.com/RaRe-Technologies/gensim…
bhargavvader e639fda
Changelog
bhargavvader e03f67f
CHANGELOG
bhargavvader 565b39e
Removed duplicate code
bhargavvader 57762c9
Removed duplicate code
bhargavvader 76a982c
Added import
bhargavvader e101b30
Fixed Changelog
bhargavvader File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -38,7 +38,7 @@ | |
import scipy.special as sp | ||
|
||
from gensim import interfaces, utils, matutils | ||
from gensim.models import basemodel | ||
from gensim.models import basemodel, ldamodel | ||
from six.moves import xrange | ||
|
||
logger = logging.getLogger(__name__) | ||
|
@@ -56,6 +56,21 @@ def dirichlet_expectation(alpha): | |
return(sp.psi(alpha) - sp.psi(np.sum(alpha, 1))[:, np.newaxis]) | ||
|
||
|
||
def get_random_state(seed):
    """Coerce `seed` into a np.random.RandomState instance.

    Accepted values:

    * an existing np.random.RandomState, returned as-is;
    * an integer (Python or numpy), used to seed a new RandomState;
    * None or the np.random module itself, which select numpy's global
      RandomState singleton.

    Anything else raises ValueError.

    Method originally from maciejkula/glove-python, and written by @joshloyal
    """
    if isinstance(seed, np.random.RandomState):
        # Already a generator: hand it back untouched so callers share state.
        return seed
    if isinstance(seed, (numbers.Integral, np.integer)):
        return np.random.RandomState(seed)
    if seed is None or seed is np.random:
        # Fall back to the module-level generator behind np.random.* functions.
        return np.random.mtrand._rand
    raise ValueError('%r cannot be used to seed a np.random.RandomState instance' % seed)
|
||
|
||
def expect_log_sticks(sticks): | ||
""" | ||
For stick-breaking hdp, return the E[log(sticks)] | ||
|
@@ -130,7 +145,7 @@ class HdpModel(interfaces.TransformationABC, basemodel.BaseTopicModel): | |
def __init__(self, corpus, id2word, max_chunks=None, max_time=None, | ||
chunksize=256, kappa=1.0, tau=64.0, K=15, T=150, alpha=1, | ||
gamma=1, eta=0.01, scale=1.0, var_converge=0.0001, | ||
outputdir=None): | ||
outputdir=None, random_state=None): | ||
""" | ||
`gamma`: first level concentration | ||
`alpha`: second level concentration | ||
|
@@ -151,6 +166,8 @@ def __init__(self, corpus, id2word, max_chunks=None, max_time=None, | |
self.max_time = max_time | ||
self.outputdir = outputdir | ||
|
||
self.random_state = get_random_state(random_state) | ||
|
||
self.lda_alpha = None | ||
self.lda_beta = None | ||
|
||
|
@@ -169,7 +186,7 @@ def __init__(self, corpus, id2word, max_chunks=None, max_time=None, | |
self.m_var_sticks[1] = range(T - 1, 0, -1) | ||
self.m_varphi_ss = np.zeros(T) | ||
|
||
self.m_lambda = np.random.gamma(1.0, 1.0, (T, self.m_W)) * self.m_D * 100 / (T * self.m_W) - eta | ||
self.m_lambda = self.random_state.gamma(1.0, 1.0, (T, self.m_W)) * self.m_D * 100 / (T * self.m_W) - eta | ||
self.m_eta = eta | ||
self.m_Elogbeta = dirichlet_expectation(self.m_eta + self.m_lambda) | ||
|
||
|
@@ -442,6 +459,21 @@ def update_expectations(self): | |
self.m_timestamp[:] = self.m_updatect | ||
self.m_status_up_to_date = True | ||
|
||
def show_topic(self, topic_id, num_words=20, log=False, formatted=False):
    """
    Return the `num_words` most probable words for the single topic `topic_id`.

    The work is delegated to `HdpTopicFormatter.show_topic`, using a formatter
    built from `id2word` and the current `m_lambda + m_eta` values; `log` and
    `formatted` are passed through unchanged.

    With `formatted=False` (the default) the result is a list of
    (word, weight) pairs. With `formatted=True` the formatter returns the
    second element of whatever its `format_topic` produces.

    """
    # Refresh the model's expectations if they are flagged as stale.
    if not self.m_status_up_to_date:
        self.update_expectations()
    betas = self.m_lambda + self.m_eta
    hdp_formatter = HdpTopicFormatter(self.id2word, betas)
    return hdp_formatter.show_topic(topic_id, num_words, log, formatted)
|
||
def show_topics(self, num_topics=20, num_words=20, log=False, formatted=True): | ||
""" | ||
Print the `num_words` most probable words for `topics` number of topics. | ||
|
@@ -510,6 +542,16 @@ def hdp_to_lda(self): | |
|
||
return (alpha, beta) | ||
|
||
def suggested_lda_model(self):
    """
    Return the closest corresponding ldamodel object for the current hdp model.

    Unlike `hdp_to_lda`, which returns the raw `(alpha, beta)` pair, this
    method wraps those values in an `ldamodel.LdaModel` instance.
    The num_topics is m_T (default is 150) so as to preserve the matrix
    shapes when we assign alpha and beta.

    """
    alpha, beta = self.hdp_to_lda()
    ldam = ldamodel.LdaModel(num_topics=self.m_T, alpha=alpha, id2word=self.id2word, random_state=self.random_state)
    # Slice assignment copies beta into the existing array rather than rebinding it.
    ldam.expElogbeta[:] = beta
    return ldam
|
||
def evaluate_test_corpus(self, corpus): | ||
logger.info('TEST: evaluating test corpus') | ||
if self.lda_alpha is None or self.lda_beta is None: | ||
|
@@ -589,6 +631,32 @@ def show_topics(self, num_topics=10, num_words=10, log=False, formatted=True): | |
|
||
return shown | ||
|
||
def print_topic(self, topic_id, num_words):
    """Return topic `topic_id` as produced by `show_topic` with `formatted=True`."""
    formatted_topic = self.show_topic(topic_id, num_words, formatted=True)
    return formatted_topic
|
||
def show_topic(self, topic_id, num_words, log=False, formatted=False):
    """
    Return the `num_words` highest-weighted terms of topic `topic_id`.

    Row `topic_id` of `self.data` is normalised by its sum, terms are ranked
    by descending weight and the top entries are resolved through
    `show_topic_terms`.

    With `formatted=False` the (word, weight) list is returned directly.
    With `formatted=True` the topic goes through `format_topic`, is logged
    when `log` is set, and the second element of the formatted result is
    returned. `log` has no effect for unformatted output.
    """
    row = list(self.data[topic_id, :])
    weights = row / sum(row)

    # Pair each weight with its word id, heaviest first.
    ranked = sorted(
        zip(weights, xrange(len(weights))),
        key=lambda pair: pair[0],
        reverse=True,
    )
    topic_terms = self.show_topic_terms(ranked, num_words)

    if not formatted:
        return topic_terms

    topic = self.format_topic(topic_id, topic_terms)

    # assuming we only output formatted topics
    if log:
        logger.info(topic)

    # only the terms part is returned, not the topic id
    return topic[1]
|
||
|
||
def show_topic_terms(self, topic_data, num_words):
    """Map the first `num_words` (weight, word id) entries to (word, weight) pairs.

    Word ids are looked up in `self.dictionary`; the input order is preserved.
    """
    leading_entries = topic_data[:num_words]
    return [(self.dictionary[word_id], weight) for weight, word_id in leading_entries]
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why copy-paste from LdaModel? Should it be moved to utils?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll move the common `ldamodel` and `hdpmodel` methods to the respective `utils` and `matutils` files after this PR is merged.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Need to make a few other changes to `utils` and `matutils` as well.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be done as a part of this PR. Duplicate code will not be merged.