[MRG] Added Jensen Shannon metric and dendrogram visualization #1484
Changes from 19 commits
91b7eec
7a5c0c3
a190f7d
2fce748
1614dfb
d6601e5
30ad840
21df1cb
f0ba371
a702843
4998dfb
b31badb
c1e2b9b
2bfdbed
b9e7ab5
b8be8db
ee2a8c1
70b474a
34308a7
cd8ced7
159edb8
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,378 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Visualizing Topic clusters\n", | ||
"\n", | ||
"In this notebook, we will visualize clusters of topics using a dendrogram and inspect the exact pairwise distances with a heatmap. Let's first train an LDA model to get the topics in Kaggle's fake news dataset, which can be downloaded from [here](https://www.kaggle.com/mrisdal/fake-news)." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from gensim.models.ldamodel import LdaModel\n", | ||
"from gensim.corpora import Dictionary, MmCorpus\n", | ||
"import pandas as pd\n", | ||
"import re\n", | ||
"from gensim.parsing.preprocessing import remove_stopwords, strip_punctuation\n", | ||
"\n", | ||
"import numpy as np\n", | ||
"from scipy.spatial.distance import pdist, squareform\n", | ||
"\n", | ||
"import plotly.offline as py\n", | ||
"import plotly.graph_objs as go\n", | ||
"import plotly.figure_factory as FF\n", | ||
"\n", | ||
"py.init_notebook_mode()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Train Model\n", | ||
"\n", | ||
"We'll preprocess the data before training. You can also refer to this [notebook](https://github.com/RaRe-Technologies/gensim/blob/develop/docs/notebooks/lda_training_tips.ipynb) for tips on preprocessing text data and on training an LDA model that gives good results." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"df_fake = pd.read_csv('fake.csv')\n", | ||
"df_fake[['title', 'text', 'language']].head()\n", | ||
"df_fake = df_fake.loc[(pd.notnull(df_fake.text)) & (df_fake.language == 'english')]\n", | ||
"\n", | ||
"# remove stopwords and punctuation\n", | ||
"def preprocess(row):\n", | ||
"    return strip_punctuation(remove_stopwords(row.lower()))\n", | ||
"\n", | ||
"df_fake['text'] = df_fake['text'].apply(preprocess)\n", | ||
"\n", | ||
"# Convert data to required input format by LDA\n", | ||
"texts = []\n", | ||
"for line in df_fake.text:\n", | ||
"    lowered = line.lower()\n", | ||
"    # re.LOCALE is rejected for str patterns on recent Python 3 versions, so only re.UNICODE is kept\n", | ||
"    words = re.findall(r'\\w+', lowered, flags = re.UNICODE)\n", | ||
Reviewer comment: Unnecessary spaces | ||
"    texts.append(words)\n", | ||
"# Create a dictionary representation of the documents.\n", | ||
"dictionary = Dictionary(texts)\n", | ||
"\n", | ||
"# Filter out words that occur in fewer than 2 documents, or in more than 40% of the documents.\n", | ||
"dictionary.filter_extremes(no_below=2, no_above=0.4)\n", | ||
"# Bag-of-words representation of the documents.\n", | ||
"corpus_fake = [dictionary.doc2bow(text) for text in texts]" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"lda_fake = LdaModel(corpus=corpus_fake, id2word=dictionary, num_topics=35, passes=30, chunksize=1500, iterations=200, alpha='auto')\n", | ||
"lda_fake.save('lda_35')" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"lda_fake = LdaModel.load('lda_35')" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Basic Dendrogram\n", | ||
"\n", | ||
"Let's first look at the dendrogram alone to understand how the topic clusters are formed and what the various values in the plot represent.\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# plotly's code" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"collapsed": true, | ||
"scrolled": false | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"from gensim.matutils import jensen_shannon\n", | ||
"\n", | ||
"from random import sample\n", | ||
Reviewer comment: Unused imports: | ||
"\n", | ||
"import scipy as scp\n", | ||
"from scipy.cluster import hierarchy as sch\n", | ||
"from scipy import spatial as scs\n", | ||
"\n", | ||
"# get topic distributions\n", | ||
"topic_dist = lda_fake.state.get_lambda()\n", | ||
"\n", | ||
"# get topic terms\n", | ||
"num_words = 300\n", | ||
"topic_terms = [{w for (w, _) in lda_fake.show_topic(topic, topn=num_words)} for topic in range(topic_dist.shape[0])]\n", | ||
"\n", | ||
"# no. of terms to display in annotation\n", | ||
"n_ann_terms = 10\n", | ||
"\n", | ||
"# use Jenson-Shannon distance metric in dendrogram\n", | ||
Reviewer comment: One more "Jenson". | ||
"def js_dist(X):\n", | ||
"    return pdist(X, lambda u, v: jensen_shannon(u, v))\n", | ||
"\n", | ||
"# calculate text annotations\n", | ||
"def text_annotation(topic_dist, topic_terms, n_ann_terms):\n", | ||
"    # get dendrogram hierarchy data\n", | ||
"    linkagefun = lambda x: sch.linkage(x, 'complete')\n", | ||
"    d = js_dist(topic_dist)\n", | ||
"    Z = linkagefun(d)\n", | ||
"    P = sch.dendrogram(Z, orientation=\"bottom\", no_plot=True)\n", | ||
"\n", | ||
"    # store topic no.(leaves) corresponding to the x-ticks in dendrogram\n", | ||
"    x_ticks = np.arange(5, len(P['leaves']) * 10 + 5, 10)\n", | ||
"    x_topic = dict(zip(P['leaves'], x_ticks))\n", | ||
"\n", | ||
"    # store {topic no.:topic terms}\n", | ||
"    topic_vals = dict()\n", | ||
"    for key, val in x_topic.items():\n", | ||
"        topic_vals[val] = (topic_terms[key], topic_terms[key])\n", | ||
"\n", | ||
"    text_annotations = []\n", | ||
"    # loop through every trace (scatter plot) in dendrogram\n", | ||
"    for trace in P['icoord']:\n", | ||
"        fst_topic = topic_vals[trace[0]]\n", | ||
"        scnd_topic = topic_vals[trace[2]]\n", | ||
"\n", | ||
"        # annotation for two ends of current trace\n", | ||
"        pos_tokens_t1 = list(fst_topic[0])[:min(len(fst_topic[0]), n_ann_terms)]\n", | ||
"        neg_tokens_t1 = list(fst_topic[1])[:min(len(fst_topic[1]), n_ann_terms)]\n", | ||
"\n", | ||
"        pos_tokens_t4 = list(scnd_topic[0])[:min(len(scnd_topic[0]), n_ann_terms)]\n", | ||
"        neg_tokens_t4 = list(scnd_topic[1])[:min(len(scnd_topic[1]), n_ann_terms)]\n", | ||
"\n", | ||
"        t1 = \"<br>\".join((\": \".join((\"+++\", str(pos_tokens_t1))), \": \".join((\"---\", str(neg_tokens_t1)))))\n", | ||
"        t2 = t3 = ()\n", | ||
"        t4 = \"<br>\".join((\": \".join((\"+++\", str(pos_tokens_t4))), \": \".join((\"---\", str(neg_tokens_t4)))))\n", | ||
"\n", | ||
"        # show topic terms in leaves\n", | ||
"        if trace[0] in x_ticks:\n", | ||
"            t1 = str(list(topic_vals[trace[0]][0])[:n_ann_terms])\n", | ||
"        if trace[2] in x_ticks:\n", | ||
"            t4 = str(list(topic_vals[trace[2]][0])[:n_ann_terms])\n", | ||
"\n", | ||
"        text_annotations.append([t1, t2, t3, t4])\n", | ||
"\n", | ||
"        # calculate intersecting/diff for upper level\n", | ||
"        intersecting = fst_topic[0] & scnd_topic[0]\n", | ||
"        different = fst_topic[0].symmetric_difference(scnd_topic[0])\n", | ||
"\n", | ||
"        center = (trace[0] + trace[2]) / 2\n", | ||
"        topic_vals[center] = (intersecting, different)\n", | ||
"\n", | ||
"        # remove trace value after it is annotated\n", | ||
"        topic_vals.pop(trace[0], None)\n", | ||
"        topic_vals.pop(trace[2], None)\n", | ||
"\n", | ||
"    return text_annotations" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"collapsed": true, | ||
"scrolled": false | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# get text annotations\n", | ||
"annotation = text_annotation(topic_dist, topic_terms, n_ann_terms)\n", | ||
"\n", | ||
"# Plot dendrogram\n", | ||
"dendro = create_dendrogram(topic_dist, distfun=js_dist, labels=range(1, 36), annotation=annotation)\n", | ||
"dendro['layout'].update({'width':1000, 'height':600})\n", | ||
Reviewer comment: Spaces after | ||
"py.iplot(dendro)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"The x-axis, i.e. the leaves of the hierarchy, represents the topics of our LDA model; the y-axis measures the closeness of individual topics or of their clusters. Essentially, the y-axis level at which branches merge (relative to the \"root\" of the tree) reflects their similarity. For example, topics 15 and 27 are more similar to each other than to topic 25. In addition, topics 8 and 30 are more similar to topic 28 than topics 15 and 27 are to topic 25, because the height at which they merge is lower than the merge height of 15/27 with 25.\n", | ||
"\n", | ||
"Text annotations, visible when hovering over the cluster nodes, show the intersecting/different terms of a node's two children. A cluster node on the first hierarchy level uses the leaf topics directly to calculate the intersecting/different terms, while upper nodes take the intersection (+++) as the topic terms of each child node." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Dendrogram with a Heatmap\n", | ||
"\n", | ||
"Now let's append the distance matrix of the topics below the dendrogram as a heatmap, so that we can see the exact distances between all pairs of topics." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# get text annotations\n", | ||
"annotation = text_annotation(topic_dist, topic_terms, n_ann_terms)\n", | ||
"\n", | ||
"# Initialize figure by creating upper dendrogram\n", | ||
"figure = create_dendrogram(topic_dist, distfun=js_dist, labels = range(1, 36), annotation=annotation)\n", | ||
Reviewer comment: Unneeded spaces | ||
"for i in range(len(figure['data'])):\n", | ||
"    figure['data'][i]['yaxis'] = 'y2'" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# get distance matrix and its topic annotations\n", | ||
"mdiff, annotation = lda_fake.diff(lda_fake, distance=\"jensen_shannon\", normed=False)\n", | ||
"\n", | ||
"# get reordered topic list\n", | ||
"dendro_leaves = figure['layout']['xaxis']['ticktext']\n", | ||
"dendro_leaves = list(map(int, dendro_leaves-1))\n", | ||
Reviewer comment: What's the type of dendro_leaves? Maybe you want to | ||
Reply: It's a list of int values where every value needs to be subtracted by 1. Simplified this using list comprehension now. | ||
"\n", | ||
"# reorder distance matrix\n", | ||
"heat_data = mdiff[dendro_leaves,:]\n", | ||
Reviewer comment: Space after | ||
"heat_data = heat_data[:,dendro_leaves]" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# heatmap annotation\n", | ||
"annotation_html = [[\"+++ {}<br>--- {}\".format(\", \".join(int_tokens), \", \".join(diff_tokens))\n", | ||
" for (int_tokens, diff_tokens) in row]\n", | ||
" for row in annotation]\n", | ||
Reviewer comment: Move to previous line | ||
"\n", | ||
"# plot heatmap of distance matrix\n", | ||
"heatmap = go.Data([\n", | ||
" go.Heatmap(\n", | ||
" z = heat_data,\n", | ||
Reviewer comment: Unneeded spaces (and below) | ||
" colorscale = 'YlGnBu',\n", | ||
" text = annotation_html,\n", | ||
" hoverinfo='x+y+z+text'\n", | ||
" )\n", | ||
"])\n", | ||
"\n", | ||
"heatmap[0]['x'] = figure['layout']['xaxis']['tickvals']\n", | ||
"heatmap[0]['y'] = figure['layout']['xaxis']['tickvals']\n", | ||
"\n", | ||
"# Add Heatmap Data to Figure\n", | ||
"figure['data'].extend(heatmap)\n", | ||
"\n", | ||
"dendro_leaves = [x+1 for x in dendro_leaves]\n", | ||
Reviewer comment: Missing whitespace | ||
"\n", | ||
"# Edit Layout\n", | ||
"figure['layout'].update({'width':800, 'height':800,\n", | ||
Reviewer comment: Missing spaces after | ||
" 'showlegend':False, 'hovermode': 'closest',\n", | ||
" })\n", | ||
"\n", | ||
"# Edit xaxis\n", | ||
"figure['layout']['xaxis'].update({'domain': [.25, 1],\n", | ||
" 'mirror': False,\n", | ||
" 'showgrid': False,\n", | ||
" 'showline': False,\n", | ||
" \"showticklabels\": True, \n", | ||
" \"tickmode\": \"array\",\n", | ||
" \"ticktext\" : dendro_leaves,\n", | ||
Reviewer comment: Unnecessary space (before | ||
" \"tickvals\" : figure['layout']['xaxis']['tickvals'],\n", | ||
" 'zeroline': False,\n", | ||
" 'ticks':\"\"})\n", | ||
"# Edit yaxis\n", | ||
"figure['layout']['yaxis'].update({'domain': [0, 0.75],\n", | ||
" 'mirror': False,\n", | ||
" 'showgrid': False,\n", | ||
" 'showline': False,\n", | ||
" \"showticklabels\": True, \n", | ||
" \"tickmode\": \"array\",\n", | ||
" \"ticktext\" : dendro_leaves,\n", | ||
Reviewer comment: Unnecessary space (before | ||
" \"tickvals\" : figure['layout']['xaxis']['tickvals'],\n", | ||
" 'zeroline': False,\n", | ||
" 'ticks': \"\"})\n", | ||
"# Edit yaxis2\n", | ||
"figure['layout'].update({'yaxis2':{'domain':[0.75, 1],\n", | ||
" 'mirror': False,\n", | ||
" 'showgrid': False,\n", | ||
" 'showline': False,\n", | ||
" 'zeroline': False,\n", | ||
" 'showticklabels': False,\n", | ||
" 'ticks':\"\"}})\n", | ||
Reviewer comment: Missing space | ||
"\n", | ||
"py.iplot(figure)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Now the distance matrix is reordered to match the topic cluster order of the upper dendrogram leaves. We can see the exact distance measure between any two topics in the z-value and also their intersecting or different terms in the +++/--- annotation." | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.4.3" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
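Note on the metric: the notebook passes the rows of `lda_fake.state.get_lambda()` to a Jensen-Shannon function and clusters on the resulting pairwise distances. The sketch below shows what that computation looks like on a toy matrix. It is a dependency-free illustration, not the implementation added to `gensim.matutils` in this PR: the helper names (`kl_divergence`, `jensen_shannon_divergence`, `pairwise_js`) and the toy data are hypothetical, and normalizing each row to sum to 1 before comparing is an assumption made here.

```python
import math


def kl_divergence(p, q):
    """Kullback-Leibler divergence KL(p || q) in nats; zero-probability terms of p contribute 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def jensen_shannon_divergence(p, q):
    """Jensen-Shannon divergence between two non-negative vectors (normalized here by assumption)."""
    p_sum, q_sum = sum(p), sum(q)
    p = [x / p_sum for x in p]
    q = [x / q_sum for x in q]
    # The mixture m is strictly positive wherever p or q is, so the logs above are defined.
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    return 0.5 * kl_divergence(p, m) + 0.5 * kl_divergence(q, m)


def pairwise_js(rows):
    """Symmetric matrix of Jensen-Shannon divergences between all pairs of rows."""
    n = len(rows)
    dist = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1, n):
            dist[i][j] = dist[j][i] = jensen_shannon_divergence(rows[i], rows[j])
    return dist


# Hypothetical topic-word weights: topics 0 and 1 share their dominant word, topic 2 does not.
toy_topics = [
    [8.0, 1.0, 1.0],
    [7.0, 2.0, 1.0],
    [1.0, 1.0, 8.0],
]
dist = pairwise_js(toy_topics)
```

With natural logarithms the divergence lies between 0 (identical distributions) and ln 2 (disjoint support), which is the range the heatmap z-values would fall in under these assumptions.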
Reviewer comment: Unused imports: gensim.corpora.MmCorpus, scipy.spatial.distance.squareform, plotly.figure_factory
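For context on the `squareform` import: `pdist` returns a condensed vector of pairwise distances, and `scipy.spatial.distance.squareform` expands it into a square matrix. The notebook instead takes the square matrix from `lda_fake.diff` and reorders it by the dendrogram leaves. The sketch below illustrates both steps without SciPy so it can be run anywhere; `condensed_to_square`, `reorder`, and the sample numbers are hypothetical stand-ins, not code from this PR.

```python
def condensed_to_square(condensed, n):
    """Expand a condensed distance vector in pdist order, d(0,1), d(0,2), ..., into an n x n matrix."""
    if len(condensed) != n * (n - 1) // 2:
        raise ValueError("condensed vector length does not match n")
    square = [[0.0] * n for _ in range(n)]
    k = 0
    for i in range(n):
        for j in range(i + 1, n):
            square[i][j] = square[j][i] = condensed[k]
            k += 1
    return square


def reorder(square, leaves):
    """Reorder rows and columns so they follow the dendrogram leaf order."""
    return [[square[i][j] for j in leaves] for i in leaves]


condensed = [0.1, 0.7, 0.6]   # d(0,1), d(0,2), d(1,2)
square = condensed_to_square(condensed, 3)
leaves = [2, 0, 1]            # hypothetical leaf order read off a dendrogram
heat = reorder(square, leaves)
```

Reordering rows and columns with the same index list keeps the matrix symmetric with a zero diagonal, so the heatmap axes line up with the dendrogram leaves above it.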