Skip to content

Commit

Permalink
[WIP] Add sklearn wrapper for LDA code (#932)
Browse files Browse the repository at this point in the history
* adding basic sklearn wrapper for LDA code

* updating changelog

* adding test case,adding id2word,deleting showtopics

* adding relevant ipynb

* adding transfrom and other get methods and modifying print_topics

* stylizing code to follow conventions

* removing redundant default argumen values

* adding partial_fit

* adding a line in test_sklearn_integration

* using LDAModel as Parent Class

* adding docs, modifying getparam

* changing class name.Adding comments

* adding test case for update and transform

* adding init

* updating changes,fixed typo and changing file name

* deleted base.py

* adding better testPartialFit method and minor changes due to change in class

* change name of test class

* adding changes in classname to ipynb

* Updating CHANGELOG.md

* Updated Main Model. Added fit_predict to class for example

* added sklearn countvectorizer example to ipynb

* adding logistic regression example

* adding if condition for csr_matrix to ldamodel

* adding check for fit csrmatrix also stylizing code

* minor bug.solved, fit should convert X to corpus

* removing fit_predict.adding csr_matrix check for update

* minor updates in ipynb

* adding rst file

* removed "basic" , added rst update to log

* changing indentation in texts

* added file preamble, removed unnecessary space

* following more pep8 conventions

* removing unnecessary comments

* changing isinstance csr_matrix to issparse

* changed to hanging indentation

* changing main filename

* changing module name in test

* updating ipynb with main filename

* changed class name

* changed file name

* fixing filename typo

* adding html file

* deleting html file

* vertical indentation fixes

* adding file to apiref.rst
  • Loading branch information
AadityaJ authored and tmylk committed Jan 29, 2017
1 parent acc45bc commit 0e0c082
Show file tree
Hide file tree
Showing 7 changed files with 532 additions and 3 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ Unreleased:

* FastText wrapper added, can be used for training FastText word representations and performing word2vec operations over it

0.13.5, 2016-12-31

* Added sklearn wrapper for LdaModel along with relevant test cases, ipynb draft and rst docs. (@AadityaJ,[#932](https://github.com/RaRe-Technologies/gensim/pull/932))

0.13.4.1, 2017-01-04

* Disable direct access warnings on save and load of Word2vec/Doc2vec (@tmylk, [#1072](https://github.com/RaRe-Technologies/gensim/pull/1072))
Expand Down
325 changes: 325 additions & 0 deletions docs/notebooks/sklearn_wrapper.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,325 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Using wrappers for Scikit learn API"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This tutorial is about using gensim models as a part of your scikit learn workflow with the help of wrappers found at ```gensim.sklearn_integration.sklearn_wrapper_gensim_ldaModel```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The wrapper available (as of now) are :\n",
"* LdaModel (```gensim.sklearn_integration.sklearn_wrapper_gensim_ldaModel.SklearnWrapperLdaModel```),which implements gensim's ```LdaModel``` in a scikit-learn interface"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### LdaModel"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To use LdaModel begin with importing LdaModel wrapper"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from gensim.sklearn_integration.sklearn_wrapper_gensim_ldaModel import SklearnWrapperLdaModel"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next we will create a dummy set of texts and convert it into a corpus"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from gensim.corpora import Dictionary\n",
"texts = [['complier', 'system', 'computer'],\n",
" ['eulerian', 'node', 'cycle', 'graph', 'tree', 'path'],\n",
" ['graph', 'flow', 'network', 'graph'],\n",
" ['loading', 'computer', 'system'],\n",
" ['user', 'server', 'system'],\n",
" ['tree','hamiltonian'],\n",
" ['graph', 'trees'],\n",
" ['computer', 'kernel', 'malfunction','computer'],\n",
" ['server','system','computer']]\n",
"dictionary = Dictionary(texts)\n",
"corpus = [dictionary.doc2bow(text) for text in texts]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Then to run the LdaModel on it"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING:gensim.models.ldamodel:too few updates, training might not converge; consider increasing the number of passes or iterations to improve accuracy\n"
]
},
{
"data": {
"text/plain": [
"[(0,\n",
" u'0.164*\"computer\" + 0.117*\"system\" + 0.105*\"graph\" + 0.061*\"server\" + 0.057*\"tree\" + 0.046*\"malfunction\" + 0.045*\"kernel\" + 0.045*\"complier\" + 0.043*\"loading\" + 0.039*\"hamiltonian\"'),\n",
" (1,\n",
" u'0.102*\"graph\" + 0.083*\"system\" + 0.072*\"tree\" + 0.064*\"server\" + 0.059*\"user\" + 0.059*\"computer\" + 0.057*\"trees\" + 0.056*\"eulerian\" + 0.055*\"node\" + 0.052*\"flow\"')]"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model=SklearnWrapperLdaModel(num_topics=2,id2word=dictionary,iterations=20, random_state=1)\n",
"model.fit(corpus)\n",
"model.print_topics(2)"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": true
},
"source": [
"### Integration with Sklearn"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To provide a better example of how it can be used with Sklearn, Let's use CountVectorizer method of sklearn. For this example we will use [20 Newsgroups data set](http://qwone.com/~jason/20Newsgroups/). We will only use the categories rec.sport.baseball and sci.crypt and use it to generate topics."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import numpy as np\n",
"from gensim import matutils\n",
"from gensim.models.ldamodel import LdaModel\n",
"from sklearn.datasets import fetch_20newsgroups\n",
"from sklearn.feature_extraction.text import CountVectorizer\n",
"from gensim.sklearn_integration.sklearn_wrapper_gensim_ldaModel import SklearnWrapperLdaModel"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"rand = np.random.mtrand.RandomState(1) # set seed for getting same result\n",
"cats = ['rec.sport.baseball', 'sci.crypt']\n",
"data = fetch_20newsgroups(subset='train',\n",
" categories=cats,\n",
" shuffle=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, we use countvectorizer to convert the collection of text documents to a matrix of token counts."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"vec = CountVectorizer(min_df=10, stop_words='english')\n",
"\n",
"X = vec.fit_transform(data.data)\n",
"vocab = vec.get_feature_names() #vocab to be converted to id2word \n",
"\n",
"id2word=dict([(i, s) for i, s in enumerate(vocab)])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, we just need to fit X and id2word to our Lda wrapper."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"[(0,\n",
" u'0.018*\"cryptography\" + 0.018*\"face\" + 0.017*\"fierkelab\" + 0.008*\"abuse\" + 0.007*\"constitutional\" + 0.007*\"collection\" + 0.007*\"finish\" + 0.007*\"150\" + 0.007*\"fast\" + 0.006*\"difference\"'),\n",
" (1,\n",
" u'0.022*\"corporate\" + 0.022*\"accurate\" + 0.012*\"chance\" + 0.008*\"decipher\" + 0.008*\"example\" + 0.008*\"basically\" + 0.008*\"dawson\" + 0.008*\"cases\" + 0.008*\"consideration\" + 0.008*\"follow\"'),\n",
" (2,\n",
" u'0.034*\"argue\" + 0.031*\"456\" + 0.031*\"arithmetic\" + 0.024*\"courtesy\" + 0.020*\"beastmaster\" + 0.019*\"bitnet\" + 0.015*\"false\" + 0.015*\"classified\" + 0.014*\"cubs\" + 0.014*\"digex\"'),\n",
" (3,\n",
" u'0.108*\"abroad\" + 0.089*\"asking\" + 0.060*\"cryptography\" + 0.035*\"certain\" + 0.030*\"ciphertext\" + 0.030*\"book\" + 0.028*\"69\" + 0.028*\"demand\" + 0.028*\"87\" + 0.027*\"cracking\"'),\n",
" (4,\n",
" u'0.022*\"clark\" + 0.019*\"authentication\" + 0.017*\"candidates\" + 0.016*\"decryption\" + 0.015*\"attempt\" + 0.013*\"creation\" + 0.013*\"1993apr5\" + 0.013*\"acceptable\" + 0.013*\"algorithms\" + 0.013*\"employer\"')]"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"obj=SklearnWrapperLdaModel(id2word=id2word,num_topics=5,passes=20)\n",
"lda=obj.fit(X)\n",
"lda.print_topics()"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": true
},
"source": [
"#### Using together with Scikit learn's Logistic Regression"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now lets try Sklearn's logistic classifier to classify the given categories into two types.Ideally we should get postive weights when cryptography is talked about and negative when baseball is talked about."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from sklearn import linear_model"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def print_features(clf, vocab, n=10):\n",
" ''' Better printing for sorted list '''\n",
" coef = clf.coef_[0]\n",
" print 'Positive features: %s' % (' '.join(['%s:%.2f' % (vocab[j], coef[j]) for j in np.argsort(coef)[::-1][:n] if coef[j] > 0]))\n",
" print 'Negative features: %s' % (' '.join(['%s:%.2f' % (vocab[j], coef[j]) for j in np.argsort(coef)[:n] if coef[j] < 0]))"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Positive features: clipper:1.50 code:1.24 key:1.04 encryption:0.95 chip:0.37 nsa:0.37 government:0.36 uk:0.36 org:0.23 cryptography:0.23\n",
"Negative features: baseball:-1.32 game:-0.71 year:-0.61 team:-0.38 edu:-0.27 games:-0.26 players:-0.23 ball:-0.17 season:-0.14 phillies:-0.11\n"
]
}
],
"source": [
"clf=linear_model.LogisticRegression(penalty='l1', C=0.1) #l1 penalty used\n",
"clf.fit(X,data.target)\n",
"print_features(clf,vocab)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
4 changes: 1 addition & 3 deletions docs/src/apiref.rst
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,4 @@ Modules:
summarization/summariser
summarization/syntactic_unit
summarization/textcleaner



sklearn_integration/sklearn_wrapper_gensim_ldamodel
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
:mod:`sklearn_integration.sklearn_wrapper_gensim_ldamodel.SklearnWrapperLdaModel` -- Scikit learn wrapper for Latent Dirichlet Allocation
======================================================

.. automodule:: gensim.sklearn_integration.sklearn_wrapper_gensim_ldamodel.SklearnWrapperLdaModel
:synopsis: Scikit learn wrapper for LDA model
:members:
:inherited-members:
:undoc-members:
:show-inheritance:
10 changes: 10 additions & 0 deletions gensim/sklearn_integration/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2011 Radim Rehurek <radimrehurek@seznam.cz>
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html
"""Scikit learn wrapper for gensim.
Contains various gensim based implementations which match with scikit-learn standards.
See [1] for complete set of conventions.
[1] http://scikit-learn.org/stable/developers/
"""
Loading

0 comments on commit 0e0c082

Please sign in to comment.