-
-
Notifications
You must be signed in to change notification settings - Fork 4.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added test suite for coherencemodel and aggregation.
- Loading branch information
1 parent
3482910
commit ba28c35
Showing
2 changed files
with
103 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
# | ||
# Copyright (C) 2010 Radim Rehurek <radimrehurek@seznam.cz> | ||
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html | ||
|
||
""" | ||
Automated tests for checking transformation algorithms (the models package). | ||
""" | ||
|
||
import logging | ||
import unittest | ||
|
||
from gensim.topic_coherence import aggregation | ||
|
||
class TestAggregation(unittest.TestCase): | ||
def setUp(self): | ||
self.confirmed_measures = [1.1, 2.2, 3.3, 4.4] | ||
|
||
def testArithmeticMean(self): | ||
"""Test arithmetic_mean()""" | ||
obtained = aggregation.arithmetic_mean(self.confirmed_measures) | ||
expected = 2.75 | ||
self.assertEqual(obtained, expected) | ||
|
||
if __name__ == '__main__': | ||
logging.root.setLevel(logging.WARNING) | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
# | ||
# Copyright (C) 2010 Radim Rehurek <radimrehurek@seznam.cz> | ||
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html | ||
|
||
""" | ||
Automated tests for checking transformation algorithms (the models package). | ||
""" | ||
|
||
import logging | ||
import unittest | ||
import os | ||
import os.path | ||
import tempfile | ||
|
||
import numpy as np | ||
|
||
from gensim.models.coherencemodel import CoherenceModel | ||
from gensim.models.ldamodel import LdaModel | ||
from gensim.models.wrappers import LdaMallet | ||
from gensim.models.wrappers import LdaVowpalWabbit | ||
from gensim.corpora.dictionary import Dictionary | ||
|
||
module_path = os.path.dirname(__file__) # needed because sample data files are located in the same folder | ||
datapath = lambda fname: os.path.join(module_path, 'test_data', fname) | ||
|
||
# set up vars used in testing ("Deerwester" from the web tutorial) | ||
texts = [['human', 'interface', 'computer'], | ||
['survey', 'user', 'computer', 'system', 'response', 'time'], | ||
['eps', 'user', 'interface', 'system'], | ||
['system', 'human', 'system', 'eps'], | ||
['user', 'response', 'time'], | ||
['trees'], | ||
['graph', 'trees'], | ||
['graph', 'minors', 'trees'], | ||
['graph', 'minors', 'survey']] | ||
dictionary = Dictionary(texts) | ||
corpus = [dictionary.doc2bow(text) for text in texts] | ||
|
||
|
||
def testfile(): | ||
# temporary data will be stored to this file | ||
return os.path.join(tempfile.gettempdir(), 'gensim_models.tst') | ||
|
||
class TestCoherenceModel(unittest.TestCase): | ||
def setUp(self): | ||
np.random.seed(8) | ||
self.badLdaModel = LdaModel(corpus=corpus, num_topics=2, passes=1) # Bad lda model | ||
self.goodLdaModel = LdaModel(corpus=corpus, num_topics=2, passes=50) # Good lda model | ||
|
||
def testUMassLdaModel(self): | ||
"""Test U_Mass topic coherence algorithm on LDA Model""" | ||
cm1 = CoherenceModel(model=self.badLdaModel, corpus=corpus, dictionary=dictionary, coherence='u_mass') | ||
cm2 = CoherenceModel(model=self.goodLdaModel, corpus=corpus, dictionary=dictionary, coherence='u_mass') | ||
self.assertTrue(cm1.get_coherence() < cm2.get_coherence()) | ||
|
||
def testCvLdaModel(self): | ||
"""Test C_v topic coherence algorithm on LDA Model""" | ||
cm1 = CoherenceModel(model=self.badLdaModel, texts=texts, dictionary=dictionary, coherence='c_v') | ||
cm2 = CoherenceModel(model=self.goodLdaModel, texts=texts, dictionary=dictionary, coherence='c_v') | ||
self.assertTrue(cm1.get_coherence() < cm2.get_coherence()) | ||
|
||
def testErrors(self): | ||
"""Test if errors are raised on bad input""" | ||
# not providing dictionary | ||
self.assertRaises(ValueError, CoherenceModel, model=self.goodLdaModel, corpus=corpus, coherence='u_mass') | ||
# not providing texts for c_v and instead providing corpus | ||
self.assertRaises(ValueError, CoherenceModel, model=self.goodLdaModel, corpus=corpus, dictionary=dictionary, coherence='c_v') | ||
# not providing corpus or texts for u_mass | ||
self.assertRaises(ValueError, CoherenceModel, self.goodLdaModel, dictionary, 'u_mass') | ||
|
||
if __name__ == '__main__': | ||
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.DEBUG) | ||
unittest.main() |