[WIP] Added sklearn wrapper for LDASeq model #1405
Conversation
""" | ||
Sklearn wrapper for LdaSeq model. Class derived from gensim.models.LdaSeqModel | ||
""" | ||
self.corpus = None |
Why do you need a field for the corpus?
@menshikh-iv In my opinion, the user might be interested in knowing the corpus used for training the model (via the get_params function). Should we continue to store this value?
@chinmayapancholi13 No, sklearn does not store X, so we should not either.
@menshikh-iv Yes, that is true for sklearn. Removing the corpus attribute from all the wrappers then.
Sklearn wrapper for LdaSeq model. Class derived from gensim.models.LdaSeqModel
"""
self.corpus = None
self.model = None
Please make this field "private" (prefix it with an underscore).
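For illustration, a minimal sketch of the naming convention being asked for here (the class body is hypothetical, not the PR's actual code):

class SklLdaSeqModel(object):
    def __init__(self, num_topics=10):
        self.num_topics = num_topics
        # a leading underscore marks the attribute as internal ("private") by Python convention
        self._model = None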
initialize='gensim', sstats=None, lda_model=None, obs_variance=0.5, chain_variance=0.005, passes=10,
random_state=None, lda_inference_max_iter=25, em_min_iter=6, em_max_iter=20, chunksize=100)
"""
self.corpus = X
There's no need to save X.
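A minimal sketch of a fit that passes X straight to gensim without keeping a reference to it (the parameter list is trimmed for brevity; the names come from the LdaSeqModel signature shown in this diff):

from gensim import models

def fit(self, X, y=None):
    # hand the corpus directly to gensim; the wrapper keeps only the fitted model
    self.gensim_model = models.LdaSeqModel(
        corpus=X, time_slice=self.time_slice, id2word=self.id2word,
        num_topics=self.num_topics, passes=self.passes,
        random_state=self.random_state, chunksize=self.chunksize
    )
    return self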
""" | ||
Fit the model according to the given training data. | ||
Calls gensim.models.LdaSeqModel: | ||
>>> gensim.models.LdaSeqModel(corpus=None, time_slice=None, id2word=None, alphas=0.01, num_topics=10, |
Please remove this >>> ... block; this example does not help a new user.
@menshikh-iv Should we remove this >>> ... statement in all the model wrappers? This line basically shows how the associated Gensim model is actually called.
You just need to specify the class that is used (you have already done that above) and point to where a user can read its documentation.
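A hedged sketch of the kind of docstring being suggested (the exact wording is an assumption):

def fit(self, X, y=None):
    """
    Fit the model according to the given training data.

    Uses gensim.models.LdaSeqModel; see
    https://radimrehurek.com/gensim/models/ldaseqmodel.html for the full
    parameter documentation.
    """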
em_min_iter=self.em_min_iter, em_max_iter=self.em_max_iter, chunksize=self.chunksize)
return self

def transform(self, docs):
Check the case when you create an instance and call transform immediately (without fit); you need to raise an exception, like sklearn does.
Also, please add an example of the docs param in the docstring.
@menshikh-iv For checking whether the model has been fitted, would it be a good idea to check if self.gensim_model is None or not? This approach would clearly give an error when fit hasn't been called before calling transform, but it also allows the user to set the value of self.gensim_model through the set_params function (or even as wrapper.gensim_model=...) and then call the transform function, which makes sense for us to allow.
I completely forgot about set_params. So I think that if you disable gensim_model in set_params, you can check whether the model is None (it does not cover all cases, but it covers the most obvious one).
Could you elaborate on the meaning of "disabling" the gensim_model param in the set_params function?
Actually, gensim_model is a public attribute of the model, so it can be set like ldaseq_wrapper.gensim_model = some_model, which is almost the same as using the set_params function to set this value. So checking whether self.gensim_model is None should be enough, right?
This would be like:
def transform(self, docs):
    """
    Return the topic proportions for the documents passed.
    """
    if self.gensim_model is None:
        raise NotFittedError("This model has not been fitted yet. Call 'fit' with appropriate arguments before using this method.")
    # the input is an array of arrays
    check = lambda x: [x] if isinstance(x[0], tuple) else x
    ...
Ok, as a temporary option.
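For context, a fuller sketch of what transform might look like with the fitted-check in place. The reshape call matches the diff below, while the per-document inference via self.gensim_model[doc] is an assumption about LdaSeqModel's __getitem__:

import numpy as np
from sklearn.exceptions import NotFittedError

def transform(self, docs):
    """Return the topic proportions for the documents passed."""
    if self.gensim_model is None:
        raise NotFittedError(
            "This model has not been fitted yet. Call 'fit' with "
            "appropriate arguments before using this method."
        )
    # accept a single BoW document as well as an iterable of documents
    check = lambda x: [x] if isinstance(x[0], tuple) else x
    docs = check(docs)
    # infer the topic proportions for each document with the fitted model
    X = [self.gensim_model[doc] for doc in docs]
    return np.reshape(np.array(X), (len(docs), self.num_topics))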
return np.reshape(np.array(X), (len(docs), self.num_topics))

def partial_fit(self, X):
    raise NotImplementedError("'partial_fit' has not been implemented for the LDA Seq model")
LDA Seq model -> SklLdaSeqModel
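That is, something like this (a sketch of the suggested message only):

raise NotImplementedError("'partial_fit' has not been implemented for SklLdaSeqModel")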
for key in param_dict.keys():
    self.assertEqual(model_params[key], param_dict[key])
Add a persistence test with pickle.
And add a test with a pipeline.
score = text_ldaseq.score(corpus, test_target)
self.assertGreater(score, 0.50)

def testPersistence(self):
This is a sanity check only. For persistence, you need to compare the current and loaded models. For this purpose, compare the current and loaded inner matrices, OR take a corpus, transform it with both variants, and compare the results.
Thanks. I have now added code for comparing the vectors transformed by the original and loaded models, in addition to this sanity check. :)
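A minimal sketch of that comparison (self.model, corpus, and the chosen BoW document are assumptions standing in for the actual test fixtures):

import pickle
import numpy as np

def testPersistence(self):
    # round-trip the fitted wrapper through pickle
    loaded = pickle.loads(pickle.dumps(self.model))

    # transform the same document with the original and the loaded model
    doc = corpus[0]  # hypothetical: one BoW document from the training corpus
    original_vec = self.model.transform(doc)
    loaded_vec = loaded.transform(doc)

    # persistence holds if both models produce (nearly) the same topic vector
    self.assertTrue(np.allclose(original_vec, loaded_vec))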
text_ldaseq = Pipeline((('features', model,), ('classifier', clf)))
text_ldaseq.fit(corpus, test_target)
score = text_ldaseq.score(corpus, test_target)
self.assertGreater(score, 0.50)
Will it be correct every time? Don't we need to fix seeds for reproducibility?
We now have a fixed seed which is set before the testPipeline test to ensure that we get similar values.
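A minimal sketch of that seeding (the model construction and the fixtures dictionary, corpus, and test_target are assumptions based on the diff above):

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline

def testPipeline(self):
    np.random.seed(0)  # fixed seed so the score below is reproducible across runs
    model = SklLdaSeqModel(id2word=dictionary, num_topics=2, time_slice=[20, 9, 8])
    clf = LogisticRegression(penalty='l2', C=0.1)
    text_ldaseq = Pipeline([('features', model), ('classifier', clf)])
    text_ldaseq.fit(corpus, test_target)
    score = text_ldaseq.score(corpus, test_target)
    self.assertGreater(score, 0.50)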
Thank you @chinmayapancholi13 👍
This PR adds a scikit-learn wrapper for Gensim's LDASeq model.