[WIP] Adding sklearn wrapper for LDA code #932
@@ -0,0 +1,138 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Using wrappers for the scikit-learn API"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This tutorial shows how to use gensim models as part of your scikit-learn workflow, with the help of the wrappers found at ```gensim.sklearn_integration.base```."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The wrappers available (as of now) are:\n",
    "* LdaModel (```gensim.sklearn_integration.base.LdaModel```), which implements gensim's ```LdaModel``` in a scikit-learn interface"
   ]
  },

Review comment: Please update the ipynb with the new names of the .py file and of the class.

  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### LdaModel"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To use LdaModel, begin by importing the LdaModel wrapper:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "from gensim.sklearn_integration.base import LdaModel"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next we will create a dummy set of texts and convert it into a corpus:"
   ]
  },

Review comment: Please add the examples to the ipynb from https://gist.github.com/AadityaJ/c98da3d01f76f068242c17b5e1593973

  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from gensim.corpora import mmcorpus, Dictionary\n",
    "texts = [['human', 'interface', 'computer'],\n",
    "         ['survey', 'user', 'computer', 'system', 'response', 'time'],\n",
    "         ['eps', 'user', 'interface', 'system'],\n",
    "         ['system', 'human', 'system', 'eps'],\n",
    "         ['user', 'response', 'time'],\n",
    "         ['trees'],\n",
    "         ['graph', 'trees'],\n",
    "         ['graph', 'minors', 'trees'],\n",
    "         ['graph', 'minors', 'survey']]\n",
    "dictionary = Dictionary(texts)\n",
    "corpus = [dictionary.doc2bow(text) for text in texts]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then to run the LdaModel on it:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[(0, u'0.271*system + 0.181*eps + 0.181*interface + 0.181*human + 0.091*computer + 0.091*user + 0.001*trees + 0.001*graph + 0.001*time + 0.001*minors'), (1, u'0.166*graph + 0.166*trees + 0.111*user + 0.111*survey + 0.111*response + 0.111*minors + 0.111*time + 0.056*computer + 0.056*system + 0.001*human')]\n"
     ]
    }
   ],
   "source": [
    "model = LdaModel(n_topics=2, id2word=dictionary, n_iter=20, random_state=1)\n",
    "model.fit(corpus)\n",
    "print model.print_topics(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
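The notebook's corpus-building step (``Dictionary(texts)`` followed by ``doc2bow``) can be illustrated without gensim. The sketch below is a simplified stand-in for what those calls produce, not gensim's actual implementation — gensim's ``Dictionary`` assigns ids in its own order, whereas this version assigns them in first-seen order:

```python
from collections import Counter

def build_dictionary(texts):
    # Map each unique token to an integer id, in first-seen order
    # (a simplified stand-in for gensim.corpora.Dictionary).
    token2id = {}
    for text in texts:
        for token in text:
            if token not in token2id:
                token2id[token] = len(token2id)
    return token2id

def doc2bow(token2id, text):
    # Convert a token list into a sparse bag-of-words representation:
    # a list of (token_id, count) pairs, sorted by token id.
    counts = Counter(token2id[token] for token in text if token in token2id)
    return sorted(counts.items())

texts = [['human', 'interface', 'computer'],
         ['survey', 'user', 'computer', 'system', 'response', 'time']]
token2id = build_dictionary(texts)
corpus = [doc2bow(token2id, text) for text in texts]
print(corpus[0])  # [(0, 1), (1, 1), (2, 1)]
```

Each document becomes a list of ``(id, count)`` pairs, which is the sparse format the wrapper's ``fit`` consumes.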
@@ -0,0 +1,107 @@
#!/usr/bin/env python

Review comment: Not a good filename; please use lower case, with underscores.

# -*- coding: utf-8 -*-
#
# Copyright (C) 2011 Radim Rehurek <radimrehurek@seznam.cz>
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html
#
"""
Scikit-learn interface for gensim, for easy use of gensim with scikit-learn.
Follows scikit-learn API conventions.
"""
from gensim import models


class SklearnWrapperLdaModel(models.LdaModel, object):

Review comment: PEP8: space after comma. Actually, not relevant at all, because …

    """
    Base LDA module.
    """

    def __init__(self, corpus=None, num_topics=100, id2word=None,
                 distributed=False, chunksize=2000, passes=1, update_every=1,
                 alpha='symmetric', eta=None, decay=0.5, offset=1.0,
                 eval_every=10, iterations=50, gamma_threshold=0.001,
                 minimum_probability=0.01, random_state=None):
        """
        Sklearn wrapper for the LDA model. Derived class for gensim.models.LdaModel.
        """

Review comment: Code style: no vertical indent.

        self.corpus = corpus
        self.num_topics = num_topics
        self.id2word = id2word
        self.distributed = distributed

Review comment: I don't think stuff like …

        self.chunksize = chunksize
        self.passes = passes
        self.update_every = update_every
        self.alpha = alpha
        self.eta = eta
        self.decay = decay
        self.offset = offset
        self.eval_every = eval_every
        self.iterations = iterations
        self.gamma_threshold = gamma_threshold
        self.minimum_probability = minimum_probability
        self.random_state = random_state
        # If no fit function is used, then the corpus is given in __init__.

Review comment: Use normal …

        if self.corpus:
            models.LdaModel.__init__(
                self, corpus=self.corpus, num_topics=self.num_topics, id2word=self.id2word,

Review comment: No vertical indent.

                distributed=self.distributed, chunksize=self.chunksize, passes=self.passes,
                update_every=self.update_every, alpha=self.alpha, eta=self.eta, decay=self.decay,

Review comment: Space after comma (here and everywhere else).

                offset=self.offset, eval_every=self.eval_every, iterations=self.iterations,
                gamma_threshold=self.gamma_threshold, minimum_probability=self.minimum_probability,
                random_state=self.random_state)

    def get_params(self, deep=True):
        """
        Returns all parameters as a dictionary.
        Warning: required for the sklearn API; do not remove.
        """
        if deep:
            return {"corpus": self.corpus, "num_topics": self.num_topics, "id2word": self.id2word,

Review comment: No vertical indent (here and everywhere else).

                    "distributed": self.distributed, "chunksize": self.chunksize, "passes": self.passes,
                    "update_every": self.update_every, "alpha": self.alpha, "eta": self.eta, "decay": self.decay,
                    "offset": self.offset, "eval_every": self.eval_every, "iterations": self.iterations,
                    "gamma_threshold": self.gamma_threshold, "minimum_probability": self.minimum_probability,
                    "random_state": self.random_state}

    def set_params(self, **parameters):
        """
        Set all parameters.

Review comment: Capitalize sentences.

        Warning: required for the sklearn API; do not remove.

Review comment: Also, what are these "Warnings" for? Are they really necessary?

Author reply: I provided "Warnings" as a way to not remove the functions in the future (necessary for the sklearn API). Sure, I can scratch them.

        """
        for parameter, value in parameters.items():
            setattr(self, parameter, value)

Review comment: Why not just …

        return self

    def fit(self, X):
        """
        For fitting the corpus into the class object.
        Calls gensim.models.LdaModel:
        >>> gensim.models.LdaModel(corpus=corpus, num_topics=num_topics, id2word=id2word, passes=passes, update_every=update_every, alpha=alpha, iterations=iterations, eta=eta, random_state=random_state)
        Warning: required for the sklearn API; do not remove.
        """
        self.corpus = X
        models.LdaModel.__init__(
            self, corpus=X, num_topics=self.num_topics, id2word=self.id2word,
            distributed=self.distributed, chunksize=self.chunksize, passes=self.passes,
            update_every=self.update_every, alpha=self.alpha, eta=self.eta, decay=self.decay,
            offset=self.offset, eval_every=self.eval_every, iterations=self.iterations,
            gamma_threshold=self.gamma_threshold, minimum_probability=self.minimum_probability,
            random_state=self.random_state)
        return self

    def transform(self, bow, minimum_probability=None, minimum_phi_value=None, per_word_topics=False):
        """
        Takes a new document (bow) as input and returns the topic distribution
        for the given document, as a list of (topic_id, topic_probability) 2-tuples.
        Warning: required for the sklearn API; do not remove.
        """
        return self.get_document_topics(

Review comment: This doesn't look right -- …

            bow, minimum_probability=minimum_probability,
            minimum_phi_value=minimum_phi_value, per_word_topics=per_word_topics)

    def partial_fit(self, X):
        """
        Train the model over X.
        """
        self.update(corpus=X)

Review comment: please add a transform as in line 85 above
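The `get_params`/`set_params` pair the reviewers discuss above is the core of the scikit-learn estimator contract: `__init__` only stores hyperparameters, and the two methods expose them so tools like `GridSearchCV` can clone and reconfigure the estimator. A minimal, dependency-free sketch of that contract (the class name and parameters here are hypothetical, chosen to mirror the wrapper):

```python
class MiniEstimator(object):
    """Minimal sketch of the scikit-learn estimator contract:
    __init__ only stores hyperparameters; get_params/set_params
    expose them for cloning and grid search."""

    def __init__(self, num_topics=100, iterations=50):
        self.num_topics = num_topics
        self.iterations = iterations

    def get_params(self, deep=True):
        # Return all hyperparameters as a dict, keyed by their __init__ names.
        return {"num_topics": self.num_topics, "iterations": self.iterations}

    def set_params(self, **parameters):
        # Plain built-in setattr(self, ...), not self.setattr(...).
        for parameter, value in parameters.items():
            setattr(self, parameter, value)
        return self  # returning self allows chaining, as sklearn expects

est = MiniEstimator().set_params(num_topics=2)
print(est.get_params())  # {'num_topics': 2, 'iterations': 50}
```

Note the use of the built-in `setattr(self, parameter, value)`; the submitted code's `self.setattr(parameter, value)` would raise `AttributeError`, which is likely what the truncated "Why not just …" comment refers to.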
@@ -0,0 +1,6 @@
"""Scikit-learn wrapper for gensim.

Review comment: Missing file preamble (encoding, author, license etc).

Contains various gensim-based implementations
which match scikit-learn standards.
See [1] for the complete set of conventions.
[1] http://scikit-learn.org/stable/developers/
"""
@@ -0,0 +1,60 @@
import six
import unittest
import numpy

from gensim.sklearn_integration.SklearnWrapperGensimLdaModel import SklearnWrapperLdaModel
from gensim.corpora import Dictionary
from gensim import matutils

texts = [['complier', 'system', 'computer'],
         ['eulerian', 'node', 'cycle', 'graph', 'tree', 'path'],

Review comment: Incorrect indentation.

         ['graph', 'flow', 'network', 'graph'],
         ['loading', 'computer', 'system'],
         ['user', 'server', 'system'],
         ['tree', 'hamiltonian'],
         ['graph', 'trees'],
         ['computer', 'kernel', 'malfunction', 'computer'],
         ['server', 'system', 'computer']]
dictionary = Dictionary(texts)
corpus = [dictionary.doc2bow(text) for text in texts]


class TestLdaModel(unittest.TestCase):

Review comment: Please rename the tests to TestSklearnLDAWrapper

    def setUp(self):
        self.model = SklearnWrapperLdaModel(id2word=dictionary, num_topics=2, passes=100, minimum_probability=0, random_state=numpy.random.seed(0))
        self.model.fit(corpus)

    def testPrintTopic(self):

Review comment: Please add a partial_fit test

        topic = self.model.print_topics(2)

        for k, v in topic:
            self.assertTrue(isinstance(v, six.string_types))
            self.assertTrue(isinstance(k, int))

    def testTransform(self):
        texts_new = ['graph', 'eulerian']
        bow = self.model.id2word.doc2bow(texts_new)
        doc_topics, word_topics, phi_values = self.model.transform(bow, per_word_topics=True)

        for k, v in word_topics:
            self.assertTrue(isinstance(v, list))
            self.assertTrue(isinstance(k, int))
        for k, v in doc_topics:
            self.assertTrue(isinstance(v, float))
            self.assertTrue(isinstance(k, int))
        for k, v in phi_values:
            self.assertTrue(isinstance(v, list))
            self.assertTrue(isinstance(k, int))

    def testPartialFit(self):
        for i in range(10):
            self.model.partial_fit(X=corpus)  # fit against the model again
        doc = list(corpus)[0]  # transform only the first document
        transformed = self.model[doc]
        transformed_approx = matutils.sparse2full(transformed, 2)  # better approximation
        expected = [0.13, 0.87]
        passed = numpy.allclose(sorted(transformed_approx), sorted(expected), atol=1e-1)
        self.assertTrue(passed)


if __name__ == '__main__':
    unittest.main()
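The `testPartialFit` test converts the model's sparse `(topic_id, probability)` output into a fixed-length dense vector with `matutils.sparse2full` before comparing it against the expected distribution. A minimal pure-Python equivalent of that conversion, shown here as a sketch rather than gensim's actual (NumPy-based) implementation:

```python
def sparse2full(doc, length):
    # Convert a sparse [(id, value), ...] document into a dense list
    # of the given length, with zeros at the missing ids
    # (simplified version of gensim.matutils.sparse2full).
    dense = [0.0] * length
    for topic_id, value in doc:
        dense[topic_id] = float(value)
    return dense

print(sparse2full([(1, 0.87)], 2))  # [0.0, 0.87]
```

Topics below `minimum_probability` are dropped from the sparse output, so the dense form is what makes two distributions directly comparable element by element.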
Review comment: Please resolve merge conflicts. Only one line should be added to the changelog. Remove the extra 2 lines about other changes.

Review comment: Please merge in the develop branch to remove merge conflicts.