Lda difference #1334
@@ -33,11 +33,13 @@
import logging
import numpy as np
import numbers
from random import sample
import os

from gensim import interfaces, utils, matutils
from gensim.matutils import dirichlet_expectation
from gensim.models import basemodel
from gensim.matutils import kullback_leibler, hellinger, jaccard_set

from itertools import chain
from scipy.special import gammaln, psi  # gamma function utils
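Side note (not part of the diff): the three functions pulled in from `gensim.matutils` are the candidate distance measures for the new `diff` method. A rough sketch of what they operate on, assuming `jaccard_set` (added elsewhere in this PR) takes two word sets and returns one minus their Jaccard similarity:

```python
import numpy as np
from gensim.matutils import kullback_leibler, hellinger, jaccard_set

# Two toy topic-word weight vectors (rows of a lambda matrix look similar,
# just longer and unnormalized).
p = np.array([0.6, 0.3, 0.1])
q = np.array([0.2, 0.5, 0.3])

print(kullback_leibler(p, q))  # asymmetric divergence, >= 0
print(hellinger(p, q))         # symmetric distance in [0, 1]

# The jaccard option works on sets of top words instead of weight vectors.
print(jaccard_set({'graph', 'trees'}, {'graph', 'minors'}))
```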
@@ -965,6 +967,72 @@ def get_term_topics(self, word_id, minimum_probability=None):

        return values

    def diff(self, other, distance="kullback_leibler", num_words=100, n_ann_terms=10, normed=True):
        """
        Calculate the topic-to-topic difference between two LDA models.

        `other`: an instance of `LdaMulticore` or `LdaModel`.
        `distance`: the function applied to compute the difference between each pair of topics.
        Available values: `kullback_leibler`, `hellinger` and `jaccard`.
        `num_words`: the number of most relevant words used when distance == `jaccard` (also used for annotation).
        `n_ann_terms`: the maximum number of words in the intersection/symmetric difference between topics (used for annotation).
        `normed`: if True, the matrix Z is normalized.

        Returns a matrix Z with shape (m1.num_topics, m2.num_topics), where Z[i][j] is the difference
        between topic_i and topic_j, and an annotation matrix with shape (m1.num_topics, m2.num_topics, 2, None), where
        annotation[i][j] = [[`int_1`, `int_2`, ...], [`diff_1`, `diff_2`, ...]],
        `int_k` is a word from the intersection of `topic_i` and `topic_j`, and
        `diff_l` is a word from the symmetric difference of `topic_i` and `topic_j`.

        Example:
        >>> m1, m2 = LdaMulticore.load(path_1), LdaMulticore.load(path_2)
        >>> mdiff, annotation = m1.diff(m2)
        >>> print(mdiff)  # matrix with the difference for each topic pair from `m1` and `m2`
        >>> print(annotation)  # matrix with positive/negative words for each topic pair from `m1` and `m2`
        """

        distances = {"kullback_leibler": kullback_leibler,
Review comment: Hanging indent. @tmylk
"hellinger": hellinger, | ||
"jaccard": jaccard_set} | ||

        if distance not in distances:
            valid_keys = ", ".join("`{}`".format(x) for x in distances.keys())
            raise ValueError("Incorrect distance, valid only {}".format(valid_keys))

        if not isinstance(other, self.__class__):
            raise ValueError("The parameter `other` must be of type `{}`".format(self.__class__.__name__))

        distance_func = distances[distance]
        d1, d2 = self.state.get_lambda(), other.state.get_lambda()
        t1_size, t2_size = d1.shape[0], d2.shape[0]

        fst_topics = [{w for (w, _) in self.show_topic(topic, topn=num_words)} for topic in range(t1_size)]
        snd_topics = [{w for (w, _) in other.show_topic(topic, topn=num_words)} for topic in range(t2_size)]

        if distance == "jaccard":
            d1, d2 = fst_topics, snd_topics

        z = np.zeros((t1_size, t2_size))
        for topic1 in range(t1_size):
            for topic2 in range(t2_size):
                z[topic1][topic2] = distance_func(d1[topic1], d2[topic2])

        if normed:
            if np.abs(np.max(z)) > 1e-8:
                z /= np.max(z)

        annotation = [[None for _ in range(t2_size)] for _ in range(t1_size)]
Review comment: You can create lists using […]. Although I don't see the point of this initialization. Why not just start empty and append, in the loop below? What's with the […]

Author reply: Initialization allows for more readable code (just an assignment to a cell inside the loop).

Review comment: I see. If that's your worry, isn't creating the 2D matrix as a numpy matrix (2D array) simpler/more readable?

Author reply: A numpy matrix with a complex (object) element type […]
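For what it's worth, the reviewer's alternative would look roughly like the sketch below (my illustration, not code from the PR): a 2D numpy array with `dtype=object` can hold the `[pos_tokens, neg_tokens]` pairs, at the cost of an object-typed array sitting next to the numeric `z`.

```python
import numpy as np

# Illustrative sizes; inside diff() these would be t1_size and t2_size.
t1_size, t2_size = 3, 4

# Preallocate a 2D object array instead of nested lists; each cell can then
# be assigned an arbitrary [pos_tokens, neg_tokens] pair inside the loop.
annotation = np.empty((t1_size, t2_size), dtype=object)
for topic1 in range(t1_size):
    for topic2 in range(t2_size):
        annotation[topic1, topic2] = [['shared_word'], ['differing_word']]  # placeholder values
```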
        for topic1 in range(t1_size):
            for topic2 in range(t2_size):
                pos_tokens = fst_topics[topic1] & snd_topics[topic2]
                neg_tokens = fst_topics[topic1].symmetric_difference(snd_topics[topic2])

                pos_tokens = sample(pos_tokens, min(len(pos_tokens), n_ann_terms))
                neg_tokens = sample(neg_tokens, min(len(neg_tokens), n_ann_terms))

                annotation[topic1][topic2] = [pos_tokens, neg_tokens]

        return z, annotation

    def __getitem__(self, bow, eps=None):
        """
        Return topic distribution for the given document `bow`, as a list of
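To make the intended usage of the new method a bit more concrete, here is a minimal, self-contained sketch of how the returned matrices might be consumed. The toy corpus is borrowed from the test file below; the `distance` and parameter choices are purely illustrative, not part of the PR:

```python
# Sketch: train two small models and interpret the output of LdaModel.diff().
import numpy as np
from gensim.corpora import Dictionary
from gensim.models import LdaModel

texts = [['human', 'interface', 'computer'],
         ['survey', 'user', 'computer', 'system', 'response', 'time'],
         ['eps', 'user', 'interface', 'system'],
         ['graph', 'minors', 'trees']]
dictionary = Dictionary(texts)
corpus = [dictionary.doc2bow(text) for text in texts]

m1 = LdaModel(corpus=corpus, id2word=dictionary, num_topics=2, passes=10)
m2 = LdaModel(corpus=corpus, id2word=dictionary, num_topics=2, passes=10)

mdiff, annotation = m1.diff(m2, distance="hellinger", num_words=10, n_ann_terms=5)

# For each topic in m1, find the closest topic in m2 (smallest distance)
# and show the shared / differing top words recorded in the annotation.
for i, j in enumerate(np.argmin(mdiff, axis=1)):
    int_tokens, diff_tokens = annotation[i][j]
    print("m1 topic %d ~ m2 topic %d (distance %.3f)" % (i, j, mdiff[i][j]))
    print("  shared words:    %s" % ", ".join(int_tokens))
    print("  differing words: %s" % ", ".join(diff_tokens))
```

Rows of `mdiff` correspond to topics of the calling model and columns to topics of `other`, so `np.argmin(mdiff, axis=1)` picks the closest counterpart for each of `m1`'s topics.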
@@ -0,0 +1,54 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2016 Radim Rehurek <radimrehurek@seznam.cz>
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html

import unittest
import numpy as np

from gensim.corpora import Dictionary
from gensim.models import LdaModel


class TestLdaDiff(unittest.TestCase):
    def setUp(self):
        texts = [['human', 'interface', 'computer'],
Review comment: Hanging indent.
                 ['survey', 'user', 'computer', 'system', 'response', 'time'],
                 ['eps', 'user', 'interface', 'system'],
                 ['system', 'human', 'system', 'eps'],
                 ['user', 'response', 'time'],
                 ['trees'],
                 ['graph', 'trees'],
                 ['graph', 'minors', 'trees'],
                 ['graph', 'minors', 'survey']]
        self.dictionary = Dictionary(texts)
        self.corpus = [self.dictionary.doc2bow(text) for text in texts]
        self.num_topics = 5
        self.n_ann_terms = 10
        self.model = LdaModel(corpus=self.corpus, id2word=self.dictionary, num_topics=self.num_topics, passes=10)

    def testBasic(self):
        mdiff, annotation = self.model.diff(self.model, n_ann_terms=self.n_ann_terms)

        self.assertEqual(mdiff.shape, (self.num_topics, self.num_topics))
        self.assertEqual(len(annotation), self.num_topics)
        self.assertEqual(len(annotation[0]), self.num_topics)

    def testIdentity(self):
        for dist_name in ["hellinger", "kullback_leibler", "jaccard"]:
            mdiff, annotation = self.model.diff(self.model, n_ann_terms=self.n_ann_terms, distance=dist_name)

            for row in annotation:
                for (int_tokens, diff_tokens) in row:
                    self.assertEqual(diff_tokens, [])
                    self.assertEqual(len(int_tokens), self.n_ann_terms)

            self.assertTrue(np.allclose(np.diag(mdiff), np.zeros(mdiff.shape[0], dtype=mdiff.dtype)))

            if dist_name == "jaccard":
                self.assertTrue(np.allclose(mdiff, np.zeros(mdiff.shape, dtype=mdiff.dtype)))

    def testInput(self):
        self.assertRaises(ValueError, self.model.diff, self.model, n_ann_terms=self.n_ann_terms, distance='something')
        self.assertRaises(ValueError, self.model.diff, [], n_ann_terms=self.n_ann_terms, distance='something')
Review comment: Will throw an exception if both inputs are empty -- is that desired?

Review comment: Missing docstring.
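One possible response to the review comments above is sketched here (my suggestion, not part of the PR): give `testInput` a docstring, and use a valid distance name in the second assertion so that the type check on `other` is what actually raises, rather than the distance check firing first.

```python
    # Sketch only: would replace the testInput method inside TestLdaDiff above.
    def testInput(self):
        """Check that diff() rejects an unknown distance and a wrongly-typed `other`."""
        # Unknown distance name -> ValueError, regardless of `other`.
        self.assertRaises(ValueError, self.model.diff, self.model,
                          n_ann_terms=self.n_ann_terms, distance='something')
        # Wrong type for `other` with a valid distance -> the isinstance check raises.
        self.assertRaises(ValueError, self.model.diff, [],
                          n_ann_terms=self.n_ann_terms, distance='jaccard')
```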