Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] Added Jensen Shannon metric and dendrogram visualization #1484

Merged
merged 21 commits into from
Aug 23, 2017
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
378 changes: 378 additions & 0 deletions docs/notebooks/Topic_dendrogram.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,378 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Visualizing Topic clusters\n",
"\n",
"In this notebook, we will see how to visualize the clusters of topics using dendrogram and also see the exact distances with the help of a heatmap. Let's first train a LDA model to get the topics in kaggle's fake news dataset which can be dowloaded from [here](https://www.kaggle.com/mrisdal/fake-news)."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from gensim.models.ldamodel import LdaModel\n",
"from gensim.corpora import Dictionary, MmCorpus\n",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unused imports: gensim.corpora.MmCorpus, scipy.spatial.distance.squareform, plotly.figure_factory

"import pandas as pd\n",
"import re\n",
"from gensim.parsing.preprocessing import remove_stopwords, strip_punctuation\n",
"\n",
"import numpy as np\n",
"from scipy.spatial.distance import pdist, squareform\n",
"\n",
"import plotly.offline as py\n",
"import plotly.graph_objs as go\n",
"import plotly.figure_factory as FF\n",
"\n",
"py.init_notebook_mode()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Train Model\n",
"\n",
"We'll preprocess the data before training. You can refer to this [notebook](https://github.com/RaRe-Technologies/gensim/blob/develop/docs/notebooks/lda_training_tips.ipynb) also for tips and suggestions of pre-processing the text data, and how to train LDA model for getting good results."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"df_fake = pd.read_csv('fake.csv')\n",
"df_fake[['title', 'text', 'language']].head()\n",
"df_fake = df_fake.loc[(pd.notnull(df_fake.text)) & (df_fake.language == 'english')]\n",
"\n",
"# remove stopwords and punctuations\n",
"def preprocess(row):\n",
" return strip_punctuation(remove_stopwords(row.lower()))\n",
" \n",
"df_fake['text'] = df_fake['text'].apply(preprocess)\n",
"\n",
"# Convert data to required input format by LDA\n",
"texts = []\n",
"for line in df_fake.text:\n",
" lowered = line.lower()\n",
" words = re.findall(r'\\w+', lowered, flags = re.UNICODE | re.LOCALE)\n",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unnecessary spaces flags = re.UNICODE | re.LOCALE

" texts.append(words)\n",
"# Create a dictionary representation of the documents.\n",
"dictionary = Dictionary(texts)\n",
"\n",
"# Filter out words that occur less than 2 documents, or more than 30% of the documents.\n",
"dictionary.filter_extremes(no_below=2, no_above=0.4)\n",
"# Bag-of-words representation of the documents.\n",
"corpus_fake = [dictionary.doc2bow(text) for text in texts]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"lda_fake = LdaModel(corpus=corpus_fake, id2word=dictionary, num_topics=35, passes=30, chunksize=1500, iterations=200, alpha='auto')\n",
"lda_fake.save('lda_35')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"lda_fake = LdaModel.load('lda_35')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Basic Dendrogram\n",
"\n",
"Let's first look at the dendrogram only, to understand how the topic clusters are formed and what does the various values in the plot represents.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# plotly's code"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true,
"scrolled": false
},
"outputs": [],
"source": [
"from gensim.matutils import jensen_shannon\n",
"\n",
"from random import sample\n",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unused imports: random.sample, scipy

"\n",
"import scipy as scp\n",
"from scipy.cluster import hierarchy as sch\n",
"from scipy import spatial as scs\n",
"\n",
"# get topic distributions\n",
"topic_dist = lda_fake.state.get_lambda()\n",
"\n",
"# get topic terms\n",
"num_words = 300\n",
"topic_terms = [{w for (w, _) in lda_fake.show_topic(topic, topn=num_words)} for topic in range(topic_dist.shape[0])]\n",
"\n",
"# no. of terms to display in annotation\n",
"n_ann_terms = 10\n",
"\n",
"# use Jenson-Shannon distance metric in dendrogram\n",
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One more "Jenson".

"def js_dist(X):\n",
" return pdist(X, lambda u, v: jensen_shannon(u, v))\n",
"\n",
"# calculate text annotations\n",
"def text_annotation(topic_dist, topic_terms, n_ann_terms):\n",
" # get dendrogram hierarchy data\n",
" linkagefun = lambda x: sch.linkage(x, 'complete')\n",
" d = js_dist(topic_dist)\n",
" Z = linkagefun(d)\n",
" P = sch.dendrogram(Z, orientation=\"bottom\", no_plot=True)\n",
"\n",
" # store topic no.(leaves) corresponding to the x-ticks in dendrogram\n",
" x_ticks = np.arange(5, len(P['leaves']) * 10 + 5, 10)\n",
" x_topic = dict(zip(P['leaves'], x_ticks))\n",
"\n",
" # store {topic no.:topic terms}\n",
" topic_vals = dict()\n",
" for key, val in x_topic.items():\n",
" topic_vals[val] = (topic_terms[key], topic_terms[key])\n",
"\n",
" text_annotations = []\n",
" # loop through every trace (scatter plot) in dendrogram\n",
" for trace in P['icoord']:\n",
" fst_topic = topic_vals[trace[0]]\n",
" scnd_topic = topic_vals[trace[2]]\n",
" \n",
" # annotation for two ends of current trace\n",
" pos_tokens_t1 = list(fst_topic[0])[:min(len(fst_topic[0]), n_ann_terms)]\n",
" neg_tokens_t1 = list(fst_topic[1])[:min(len(fst_topic[1]), n_ann_terms)]\n",
"\n",
" pos_tokens_t4 = list(scnd_topic[0])[:min(len(scnd_topic[0]), n_ann_terms)]\n",
" neg_tokens_t4 = list(scnd_topic[1])[:min(len(scnd_topic[1]), n_ann_terms)]\n",
"\n",
" t1 = \"<br>\".join((\": \".join((\"+++\", str(pos_tokens_t1))), \": \".join((\"---\", str(neg_tokens_t1)))))\n",
" t2 = t3 = ()\n",
" t4 = \"<br>\".join((\": \".join((\"+++\", str(pos_tokens_t4))), \": \".join((\"---\", str(neg_tokens_t4)))))\n",
"\n",
" # show topic terms in leaves\n",
" if trace[0] in x_ticks:\n",
" t1 = str(list(topic_vals[trace[0]][0])[:n_ann_terms])\n",
" if trace[2] in x_ticks:\n",
" t4 = str(list(topic_vals[trace[2]][0])[:n_ann_terms])\n",
"\n",
" text_annotations.append([t1, t2, t3, t4])\n",
"\n",
" # calculate intersecting/diff for upper level\n",
" intersecting = fst_topic[0] & scnd_topic[0]\n",
" different = fst_topic[0].symmetric_difference(scnd_topic[0])\n",
"\n",
" center = (trace[0] + trace[2]) / 2\n",
" topic_vals[center] = (intersecting, different)\n",
"\n",
" # remove trace value after it is annotated\n",
" topic_vals.pop(trace[0], None)\n",
" topic_vals.pop(trace[2], None) \n",
" \n",
" return text_annotations"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true,
"scrolled": false
},
"outputs": [],
"source": [
"# get text annotations\n",
"annotation = text_annotation(topic_dist, topic_terms, n_ann_terms)\n",
"\n",
"# Plot dendrogram\n",
"dendro = create_dendrogram(topic_dist, distfun=js_dist, labels=range(1, 36), annotation=annotation)\n",
"dendro['layout'].update({'width':1000, 'height':600})\n",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Spaces after :

"py.iplot(dendro)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The x-axis or the leaves of hierarchy represent the topics of our LDA model, y-axis is a measure of closeness of either individual topics or their cluster. Essentially, the y-axis level at which the branches merge (relative to the \"root\" of the tree) is related to their similarity. For ex., topic 15 and 27 are more similar to each other than to topic 25. In addition, topic 8 and 30 are more similar to 28 than topic 15 and 27 are to topic 25 as the height on which they merge is lower than the merge height of 15/27 to 25.\n",
"\n",
"Text annotations visible on hovering over the cluster nodes show the intersecting/different terms of it's two child nodes. Cluster node on first hierarchy level uses the topics on leaves directly to calculate intersecting/different terms, and the upper nodes assume the intersection(+++) as the topic terms of it's child node."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Dendrogram with a Heatmap\n",
"\n",
"Now lets append the distance matrix of the topics below the dendrogram in form of heatmap after which we would be able to see the exact distances between all pair of topics."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# get text annotations\n",
"annotation = text_annotation(topic_dist, topic_terms, n_ann_terms)\n",
"\n",
"# Initialize figure by creating upper dendrogram\n",
"figure = create_dendrogram(topic_dist, distfun=js_dist, labels = range(1, 36), annotation=annotation)\n",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no needed spaces labels = range(1, 36)

"for i in range(len(figure['data'])):\n",
" figure['data'][i]['yaxis'] = 'y2'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# get distance matrix and it's topic annotations\n",
"mdiff, annotation = lda_fake.diff(lda_fake, distance=\"jensen_shannon\", normed=False)\n",
"\n",
"# get reordered topic list\n",
"dendro_leaves = figure['layout']['xaxis']['ticktext']\n",
"dendro_leaves = list(map(int, dendro_leaves-1))\n",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's type of dendro_leaves? Maybe you want to [int(dendro_leaves - 1)]

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a list of int values where every value needs to be subtracted by 1. Simplified this using list comprehension now

"\n",
"# reorder distance matrix\n",
"heat_data = mdiff[dendro_leaves,:]\n",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

space after , (and below)

"heat_data = heat_data[:,dendro_leaves]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# heatmap annotation\n",
"annotation_html = [[\"+++ {}<br>--- {}\".format(\", \".join(int_tokens), \", \".join(diff_tokens))\n",
" for (int_tokens, diff_tokens) in row]\n",
" for row in annotation]\n",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

move to previous line

"\n",
"# plot heatmap of distance matrix\n",
"heatmap = go.Data([\n",
" go.Heatmap(\n",
" z = heat_data,\n",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No needed spaces (and below)

" colorscale = 'YIGnBu',\n",
" text = annotation_html,\n",
" hoverinfo='x+y+z+text'\n",
" )\n",
"])\n",
"\n",
"heatmap[0]['x'] = figure['layout']['xaxis']['tickvals']\n",
"heatmap[0]['y'] = figure['layout']['xaxis']['tickvals']\n",
"\n",
"# Add Heatmap Data to Figure\n",
"figure['data'].extend(heatmap)\n",
"\n",
"dendro_leaves = [x+1 for x in dendro_leaves]\n",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing whitespace x+1

"\n",
"# Edit Layout\n",
"figure['layout'].update({'width':800, 'height':800,\n",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

missing spaces after : (and below)

" 'showlegend':False, 'hovermode': 'closest',\n",
" })\n",
"\n",
"# Edit xaxis\n",
"figure['layout']['xaxis'].update({'domain': [.25, 1],\n",
" 'mirror': False,\n",
" 'showgrid': False,\n",
" 'showline': False,\n",
" \"showticklabels\": True, \n",
" \"tickmode\": \"array\",\n",
" \"ticktext\" : dendro_leaves,\n",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unnecessary space (before :) and below

" \"tickvals\" : figure['layout']['xaxis']['tickvals'],\n",
" 'zeroline': False,\n",
" 'ticks':\"\"})\n",
"# Edit yaxis\n",
"figure['layout']['yaxis'].update({'domain': [0, 0.75],\n",
" 'mirror': False,\n",
" 'showgrid': False,\n",
" 'showline': False,\n",
" \"showticklabels\": True, \n",
" \"tickmode\": \"array\",\n",
" \"ticktext\" : dendro_leaves,\n",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unnecessary space (before :) and below

" \"tickvals\" : figure['layout']['xaxis']['tickvals'],\n",
" 'zeroline': False,\n",
" 'ticks': \"\"})\n",
"# Edit yaxis2\n",
"figure['layout'].update({'yaxis2':{'domain':[0.75, 1],\n",
" 'mirror': False,\n",
" 'showgrid': False,\n",
" 'showline': False,\n",
" 'zeroline': False,\n",
" 'showticklabels': False,\n",
" 'ticks':\"\"}})\n",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

missing space

"\n",
"py.iplot(figure)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now the distance matrix is reordered to match the topic cluster order of the upper dendrogram leaves. We can see the exact distance measure between any two topics in the z-value and also their intersecting or different terms in the +++/--- annotation."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.4.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Loading