Two Teaching Vignettes for Dirichlet Processes #37

207 changes: 207 additions & 0 deletions analysis/gibbs_sampling_DP_process.Rmd
---
title: "Application on Dirichlet Process Model"
author: "Kaiqian Zhang"
date: "3/13/2018"
output:
  html_document:
    code_folding: hide
runtime: shiny
---

```{r setup, include=FALSE}
knitr::opts_chunk$set(echo = TRUE)
```

## Pre-requisites

You need to know how Gibbs sampling works and what a Dirichlet process is.

## Application: Clustering Gaussian Mixture Data

Suppose we have data $X = \{x_1,\dots,x_n\}$ from a 10-dimensional Gaussian mixture model. The $k$th component is multivariate normal with parameters $\boldsymbol{\mu}_k$ and $\Sigma_k$. We do not know the number of clusters, and we will use Gibbs sampling (which we have seen before) to infer it.

Since we do not know the number of mixture components, we use a Dirichlet process as our prior. Here is the prior setting:

* $\beta_k \sim \text{Beta}(1,\alpha_0)$

* $\boldsymbol{\pi} = \text{Stick-Breaking-Process}(\boldsymbol{\beta})$

* $z_i | \boldsymbol{\pi} \sim \text{Multi}(\boldsymbol{\pi})$

* $\boldsymbol{\mu}_k \sim \text{Normal}(\boldsymbol{\mu}_0,\Sigma_0)$

* $\Sigma_k \sim \text{Inverse-Wishart}(\nu_0, \boldsymbol{\psi}_0)$

where $i=1,\dots,n$ and $k=1,2,\dots$ (infinitely many components in principle).
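
For concreteness, here is a minimal sketch of drawing once from this prior with a finite truncation. The truncation level `K_trunc` and all numerical settings below are illustrative assumptions, not part of the model specification above.

```{r}
# A minimal sketch: one draw from the truncated DP mixture prior above.
# K_trunc and all numerical settings here are illustrative assumptions.
library(MASS)      # mvrnorm
library(MCMCpack)  # riwish
K_trunc <- 20; alpha_0 <- 1; p <- 10
mu_0 <- rep(0, p); Sigma_0 <- diag(p); nu_0 <- p + 2; Psi_0 <- diag(p)
beta <- rbeta(K_trunc, 1, alpha_0)
pi <- beta * c(1, cumprod(1 - beta[-K_trunc]))            # stick-breaking weights
z <- sample(K_trunc, size = 5, replace = TRUE, prob = pi) # labels for 5 hypothetical points
mu_k    <- lapply(1:K_trunc, function(k) mvrnorm(1, mu_0, Sigma_0))  # component means
Sigma_k <- lapply(1:K_trunc, function(k) riwish(nu_0, Psi_0))        # component covariances
```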

The conditional posteriors for the $(s+1)$th iteration are:

* $\beta^{(s+1)}_k \sim \text{Beta}(1+n_k, \alpha_0+n_{k+})$, where $n_k$ denotes the number of observations in the $k$th cluster and $n_{k+}=\sum_{j=k+1}^{K}n_j$ is the number of observations in clusters $k+1,\dots,K$.

* $z_i^{(s+1)}\mid \boldsymbol{\pi}^{(s+1)},\boldsymbol{\mu}_k^{(s)}, \Sigma_k^{(s)},x_i \sim \text{Multi}(\tilde{\boldsymbol{\pi}}_{i}^{(s+1)})$, where $\tilde{\pi}_{i,k}^{(s+1)} \propto P(x_i \mid{\boldsymbol{\mu}_k}^{(s)}, \Sigma_k^{(s)})\cdot \pi_{k}^{(s+1)}$. Here, $\tilde{\pi}_{i,k}$ denotes the posterior probability that $x_i$ belongs to the $k$th cluster.

* $\boldsymbol{\mu}_k^{(s+1)} \sim \text{Normal}((\Sigma_0^{-1}+n_k(\Sigma_k^{(s)})^{-1})^{-1} (\Sigma_0^{-1}\boldsymbol{\mu}_0 + n_k(\Sigma_k^{(s)})^{-1}\overline{x}),(\Sigma_0^{-1}+n_k(\Sigma_k^{(s)})^{-1})^{-1})$

* $\Sigma_k^{(s+1)} \sim \text{Inverse-Wishart}(n_k+\nu_0, \boldsymbol{\psi}_0+\sum_{i=1}^{n_k}(x_i-\boldsymbol{\mu}_k^{(s)})(x_i-\boldsymbol{\mu}_k^{(s)})^{T})$.

We write a Gibbs sampler based on the outline above, simulate data, and test whether the sampler can recover the number of mixture components.

We generate data from a 10-dimensional Gaussian mixture with six clusters, with cluster sizes $50:25:100:50:50:25$. Here are our test data.

```{r echo=FALSE}
library(rmarkdown)
library(mvtnorm)
library(MASS)
library(MCMCpack)
```

```{r}
set.seed(12345)
mu1 = runif(10, -1, 0)
mu2 = runif(10, 5, 7)
mu3 = runif(10, 0, 5)
mu4 = runif(10, -5, 0)
mu5 = runif(10, 10, 15)
mu6 = runif(10, 15, 20)
Sigma1 = rWishart(1, 15, diag(1/15, 10))[,,1]
Sigma2 = rWishart(1, 15, diag(1/15, 10))[,,1]
Sigma3 = rWishart(1, 15, diag(1/15, 10))[,,1]
Sigma4 = rWishart(1, 15, diag(1/15, 10))[,,1]
Sigma5 = rWishart(1, 15, diag(1/15, 10))[,,1]
Sigma6 = rWishart(1, 15, diag(1/15, 10))[,,1]
Data = rbind(mvrnorm(50, mu1, Sigma1),
             mvrnorm(25, mu2, Sigma2),
             mvrnorm(100, mu3, Sigma3),
             mvrnorm(50, mu4, Sigma4),
             mvrnorm(50, mu5, Sigma5),
             mvrnorm(25, mu6, Sigma6))
```

Now suppose we do not know the number of clusters in our data. We start the sampler with 20 candidate clusters (a truncation level well above the truth) and let the Gibbs sampler find the true number.

```{r gibbs, echo=FALSE}
stick <- function(activeK, alpha, N){
  # Draw stick-breaking weights given current cluster counts N
  pi = rep(0, activeK)
  beta_vec = rep(0, activeK)
  for (k in 1:activeK){
    # Posterior Beta(1 + n_k, alpha + n_{k+})
    beta_vec[k] = rbeta(1, 1+N[k], alpha+sum(N[k:activeK])-N[k])
  }
  pi[1] = beta_vec[1]
  for (i in 2:activeK){
    pi[i] = beta_vec[i]*prod(1-beta_vec[1:(i-1)])
  }
  pi
}

postProb <- function(my_data, pi, mu_list, sigma_list){
  # Sample cluster assignments; returns an activeK x n indicator matrix
  activeK = length(mu_list)
  temp = pi
  post_pi = rep(0, length(pi))
  class_assignments = matrix(0, activeK, dim(my_data)[1])
  for (i in 1:dim(my_data)[1]){
    for (k in 1:activeK){
      mu_temp = mu_list[[k]]
      sigma_temp = sigma_list[[k]]
      # Posterior responsibility is proportional to weight times likelihood
      temp[k] = pi[k]*dmvnorm(my_data[i,], mu_temp, sigma_temp)
    }
    post_pi = temp/sum(temp)
    Z = match(1, rmultinom(1, 1, post_pi))
    class_assignments[,i] = 0
    class_assignments[Z,i] = 1
  }
  class_assignments
}

updateMu <- function(p, subsetData, n_k, mu_0, lambda, sigma_k){
  # Conjugate normal update for the cluster mean (normal-inverse-Wishart form)
  if(length(subsetData)==p){
    xbar = subsetData  # single observation in this cluster
  }else{
    xbar = colMeans(subsetData)
  }
  mumean = (mu_0*lambda+n_k*xbar)/(lambda+n_k)
  muvar = 1/(lambda+n_k)*sigma_k
  res = mvrnorm(1, mumean, muvar)
  res
}

updateSigma <- function(p, subsetData, n_k, nu, psi, lambda, mu_0){
  # Conjugate inverse-Wishart update for the cluster covariance
  first = nu + n_k
  if(length(subsetData)==p){
    xbar = subsetData  # single observation in this cluster
    second = psi + (lambda*n_k)/(lambda+n_k)*(xbar-mu_0)%*%t(xbar-mu_0)
    riwish(first, second)
  }else{
    xbar = colMeans(subsetData)
    # Within-cluster scatter matrix
    ret = 0
    for (i in 1:n_k){
      temp = subsetData[i,]-xbar
      ret = ret + temp %*% t(temp)
    }
    second = psi + ret + (lambda*n_k)/(lambda+n_k)*(xbar-mu_0)%*%t(xbar-mu_0)
    riwish(first, second)
  }
}

Gibbs <- function(my_data, activeK, rep, nu, psi, mu_0, sigma_0, lambda){
  # Set-up
  p = ncol(my_data)
  pi = rep(1/activeK, activeK) # prior set-up
  mu_list = vector('list', activeK) # likelihood set-up
  sigma_list = vector('list', activeK)
  # Initialization: k-means gives starting values for mu and sigma
  kmeans_init = kmeans(my_data, activeK)
  mu_init_Mat = kmeans_init$centers
  for (i in 1:activeK){
    cluster_index = which(kmeans_init$cluster == i)
    cluster_index = as.vector(cluster_index)
    kth_cluster_data = my_data[cluster_index,]
    sigma_list[[i]] = cov(kth_cluster_data)
    mu_list[[i]] = mu_init_Mat[i,]
  }
  # Initial posterior Z-label matrix
  membership = postProb(my_data, pi, mu_list, sigma_list)
  N = rowSums(membership) # record n_k in each cluster
  print('Initial assignment:')
  print(N)

  for (s in 1:rep){
    pi = stick(activeK, 1, N) # draw mixture weights (alpha_0 = 1)
    membership = postProb(my_data, pi, mu_list, sigma_list) # draw cluster labels
    N = rowSums(membership)
    for (k in 1:activeK){
      if(N[k]>0){
        # Update parameters of occupied clusters
        k_cluster_index = which(membership[k,] %in% 1)
        subsetData = my_data[k_cluster_index,]
        mu_k = updateMu(p, subsetData, N[k], mu_0, lambda, sigma_list[[k]])
        sigma_k = updateSigma(p, subsetData, N[k], nu, psi, lambda, mu_0)
        mu_list[[k]] = mu_k
        sigma_list[[k]] = sigma_k
      }else{
        # Reset empty clusters to the prior values
        mu_list[[k]] = mu_0
        sigma_list[[k]] = sigma_0
      }
    }
  }
  valid_group_index = which(N>0)
  final_mu_list = mu_list[valid_group_index]
  final_sigma_list = sigma_list[valid_group_index]
  groupings = c(1:activeK)%*%membership
  print('Final assignment: ')
  print(N)
  return(list('mu'=final_mu_list, 'sigma' = final_sigma_list, 'pi' = pi, 'membership' = groupings))
}
```


```{r}
result = Gibbs(Data, activeK=20, rep=300, nu=12, psi=diag(10,10), mu_0=rep(0,10), sigma_0=diag(10,10), lambda=1)
```

Our Gibbs sampler also infers $\boldsymbol{\mu}_k$ and $\Sigma_k$ for each cluster, and the estimates are close to the true values. We show the estimates of $\boldsymbol{\mu}_k$ here.

```{r}
result$mu
```
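
As a quick sanity check (a minimal sketch that assumes the objects `result` and `mu1`, ..., `mu6` created above), we can match each estimated mean to its closest true mean and inspect the distances:

```{r}
# Match each estimated cluster mean to its nearest true mean (sanity check).
true_mu <- rbind(mu1, mu2, mu3, mu4, mu5, mu6)
est_mu <- do.call(rbind, result$mu)
# Pairwise Euclidean distances between estimated and true means
dist_mat <- as.matrix(dist(rbind(est_mu, true_mu)))[1:nrow(est_mu),
                                                    nrow(est_mu) + 1:nrow(true_mu)]
apply(dist_mat, 1, which.min)       # which true cluster each estimate matches
round(apply(dist_mat, 1, min), 2)   # distance to the matched true mean
```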
179 changes: 179 additions & 0 deletions analysis/shiny_dirichlet_process.Rmd
---
title: "Introduction to Dirichlet Process"
author: "Kaiqian Zhang"
date: "3/6/2018"
output:
  html_document:
    code_folding: hide
runtime: shiny
---

## Pre-requisites

You need to know what a Dirichlet distribution is and have some introductory knowledge of mixture models.

## Introduction

We have already introduced one popular stochastic process in Bayesian nonparametric inference, the Gaussian process. The Dirichlet process is another widely used stochastic process in this field. Just as the Gaussian process has Gaussian finite-dimensional marginal distributions, the Dirichlet process has Dirichlet finite-dimensional marginal distributions. We often say that the Dirichlet process is a distribution over distributions, i.e., each draw from a Dirichlet process is a probability measure.

We have seen that, when estimating the mixture proportions of a mixture model, the Dirichlet distribution is the conjugate prior for the multinomial distribution. In a similar way, the Dirichlet process is the conjugate prior for infinite, nonparametric discrete distributions, and we usually use it as a prior in infinite mixture models.

## Motivation: A Clustering Problem

Previously, we talked about mixture models. Suppose we have a Gaussian mixture of $K$ components:
\[P(x_1,\dots,x_n\mid \boldsymbol{\pi},\boldsymbol{\mu},\boldsymbol{\sigma^2})=\prod_{i=1}^{n}\sum_{k=1}^{K}\pi_kN(x_i\mid \mu_k,\sigma^2_k),\]
where $\boldsymbol{\pi}=\{\pi_1,\dots,\pi_K\}, \boldsymbol{\mu}=\{\mu_1,\dots,\mu_K\},$ and $\boldsymbol{\sigma^2}=\{\sigma^2_1,\dots,\sigma^2_K\}.$

Suppose we are interested in estimating the mixture weights $\boldsymbol{\pi}$. The posterior distribution is
\[P(\boldsymbol{\pi} \mid x_1,\dots,x_n, \boldsymbol{\mu}, \boldsymbol{\sigma^2}) \propto P(\boldsymbol{\pi}, x_1,\dots,x_n, \boldsymbol{\mu}, \boldsymbol{\sigma^2})\]
\[ = P(x_1,\dots,x_n\mid \boldsymbol{\pi},\boldsymbol{\mu},\boldsymbol{\sigma^2})P(\boldsymbol{\mu})P(\boldsymbol{\sigma^2})P(\boldsymbol{\pi})\]
\[\propto P(x_1,\dots,x_n\mid \boldsymbol{\pi},\boldsymbol{\mu},\boldsymbol{\sigma^2})P(\boldsymbol{\pi}).\]
To simplify posterior inference, we prefer conjugate prior distributions. The conjugate prior for the mixture weights $\boldsymbol{\pi}$ is the Dirichlet distribution.
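
Here is a minimal sketch of that conjugacy (the prior parameters and cluster counts below are made-up illustrations): given a $\text{Dir}(\alpha_1,\dots,\alpha_K)$ prior on $\boldsymbol{\pi}$ and observed cluster counts $(n_1,\dots,n_K)$, the posterior is $\text{Dir}(\alpha_1+n_1,\dots,\alpha_K+n_K)$.

```{r}
# Illustrative Dirichlet-multinomial conjugacy (made-up prior and counts).
library(MCMCpack)            # rdirichlet
alpha_prior <- c(1, 1, 1)    # Dir(1,1,1) prior on a 3-component mixture
counts <- c(12, 30, 8)       # observed cluster counts n_k
posterior_draws <- rdirichlet(5000, alpha_prior + counts)
colMeans(posterior_draws)    # close to (alpha_k + n_k) / sum(alpha + n)
(alpha_prior + counts) / sum(alpha_prior + counts)
```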

Now consider a mixture model again. This time, however, we do not know the number of components. This is essentially a clustering problem. When defining a prior on mixture weights, we need a distribution that allows an infinite number of components. Therefore, we would like a prior that has properties like those of the Dirichlet distribution, but of infinite dimension.

Hence we have
\[(\pi_1,\dots,\pi_K) \sim \text{Dir}(\frac{\alpha}{K},\dots,\frac{\alpha}{K})\]
and we want to take the limit to get infinite dimensions:
\[\boldsymbol{\pi} \sim \lim_{K\to \infty}\text{Dir}(\frac{\alpha}{K},\dots,\frac{\alpha}{K}).\]

For each $\pi_i$ drawn from the above distribution, we associate a draw $\theta_i$ from a base distribution:
\[\theta_i \sim H \quad \text{for } i=1,2,\dots\]

We define
\[G:= \sum_{k=1}^{\infty}\pi_k\delta_{\theta_k},\]
which is a discrete probability measure drawn from a Dirichlet process, $G\sim \text{DP}(H,\alpha)$. We will illustrate this idea with two Shiny apps below.

## Definition

We give the formal definition here:

The Dirichlet process has two parameters, a base distribution $H$ and a positive concentration parameter $\alpha$. We say that $G$ is Dirichlet process distributed with base distribution $H$ and concentration parameter $\alpha$, denoted $G \sim \text{DP}(H,\alpha)$, if
\[(G(A_1),G(A_2),\dots,G(A_k)) \sim \text{Dir}(\alpha H(A_1), \alpha H(A_2), \dots, \alpha H(A_k))\]
for every finite measurable partition $A_1, A_2,\dots, A_k$ of the space on which $H$ is defined.

Notice that each draw $G$ from a DP is a discrete probability measure. $G$ consists of point masses at locations $\theta_i$, which we call atoms. The position of each atom is drawn from $H$, and its weight $\pi_i$ is governed by $\alpha$.
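
As a concrete illustration of the definition (a minimal sketch; $H = \text{Normal}(0,1)$, $\alpha = 10$, and the partition $(-\infty,0]$, $(0,1]$, $(1,\infty)$ are chosen purely for illustration): the joint distribution of $(G(A_1), G(A_2), G(A_3))$ is an ordinary three-dimensional Dirichlet, which we can sample from directly.

```{r}
# Finite-dimensional marginal of a DP under an illustrative partition of the real line.
library(MCMCpack)                       # rdirichlet
alpha <- 10
H_mass <- c(pnorm(0), pnorm(1) - pnorm(0), 1 - pnorm(1))  # H(A_1), H(A_2), H(A_3)
G_marginal <- rdirichlet(5000, alpha * H_mass)
colMeans(G_marginal)                    # approximately equal to H_mass (the DP mean)
H_mass
```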

## Stick-breaking Process

In practice, we can use a stick-breaking process to construct a Dirichlet process and to run simulations. Consider a stick of length $1$. We iteratively sample $\beta_i$ from Beta($1,\alpha$) and break off a fraction $\beta_i$ of the remaining stick. The length broken off at each step is the weight $\pi_i$ for the atom $\theta_i$. The stick-breaking process is as follows:

* Draw $\beta_1 \sim \text{Beta}(1,\alpha)$

* Set $\pi_1 = \beta_1$

* Draw $\beta_2 \sim \text{Beta}(1,\alpha)$

* Set $\pi_2 = \beta_2(1-\beta_1)$

...

* Set $\pi_k = \beta_k\prod_{i=1}^{k-1}(1-\beta_i).$

Let's draw a probability measure from DP$(H,\alpha)$.

```{r}
simulation.DP = function(K, alpha, H){
  # Draw K atom locations from the base distribution H
  if (H=="Uniform(0,1)"){
    theta_positions = runif(K, 0, 1)
  }else if(H=="Normal(0,1)"){
    theta_positions = rnorm(K, 0, 1)
  }else{
    theta_positions = rexp(K, 1)
  }
  # Stick-breaking weights
  beta_vec = rbeta(K, 1, alpha)
  pi = rep(0, K)
  # Initialize the first weight pi
  pi[1] = beta_vec[1]
  for (k in 2:K){
    pi[k] = beta_vec[k]*prod(1-beta_vec[1:(k-1)])
  }
  plot(theta_positions, pi, type='h', xlab='atom position theta', ylab='weight pi',
       main='Generating G from a Stick-breaking Process')
}
```

```{r eruptions, echo=FALSE}
inputPanel(
  selectInput("K", label = "Number of breaks:",
              choices = c(100, 500, 1000, 3000), selected = 100),

  selectInput("alpha", label = "alpha:",
              choices = c(1, 5, 10, 50), selected = 10),

  selectInput("base_dist", label = "Base distribution H:",
              choices = c("Uniform(0,1)", "Normal(0,1)", "Exponential(1)"))
)

renderPlot({
  simulation.DP(as.numeric(input$K), as.numeric(input$alpha), toString(input$base_dist))
}, height = 600, width = 1000)
```

## Chinese Restaurant Process
Another metaphor for the Dirichlet process is the Chinese restaurant process. Imagine a Chinese restaurant with infinitely many tables, each with infinite capacity. As each customer arrives, they may sit at an occupied table with probability proportional to the number of customers already seated there, or at a new table with probability proportional to $\alpha$. Also, at each table a dish is selected and shared by all customers at that table; this is analogous to drawing a $\theta$ from the base distribution $H$.

We simulate a Chinese restaurant process as follows:

* The first customer always chooses the first table.

* The $n$th customer chooses a new table with probability $\frac{\alpha}{n-1+\alpha}$, and an occupied table with probability $\frac{c}{n-1+\alpha}$, where $c$ is the number of people already sitting at that table.

```{r}
Chinese.restaurant.simulation = function(num_customers, alpha, H){
  # Assign each customer to a table
  table = c()
  table[1] = 1  # the first customer always sits at the first table
  for (i in 2:num_customers){
    # The ith customer joins table c with probability c/(i-1+alpha)
    # and opens a new table with probability alpha/(i-1+alpha)
    occupied_prob = table/(i-1+alpha)
    new_table_prob = alpha/(i-1+alpha)
    table_assignment = match(1, rmultinom(1, 1, c(occupied_prob, new_table_prob)))
    if(table_assignment>length(table)){
      table[table_assignment] = 1
    }else{
      table[table_assignment] = table[table_assignment]+1
    }
  }
  pi = table/num_customers
  # Draw a dish (atom theta) for each table from the base distribution H
  if (H=="Uniform(0,1)"){
    theta_positions = runif(length(table), 0, 1)
  }else if(H=="Normal(0,1)"){
    theta_positions = rnorm(length(table), 0, 1)
  }else{
    theta_positions = rexp(length(table), 1)
  }
  plot(theta_positions, pi, type='h', xlab='dish on each table',
       ylab='proportion of customers at each table',
       main='Generating G from a Chinese Restaurant Process')
}
```

```{r hithere, echo=FALSE}
inputPanel(
  selectInput("num_customers", label = "Number of customers:",
              choices = c(100, 500, 1000, 3000), selected = 100),

  selectInput("alpha_1", label = "alpha:",
              choices = c(1, 5, 10, 50, 100), selected = 100),

  selectInput("base_dist_1", label = "Base distribution H:",
              choices = c('Uniform(0,1)', 'Normal(0,1)', 'Exponential(1)'))
)

renderPlot({
  Chinese.restaurant.simulation(as.numeric(input$num_customers), as.numeric(input$alpha_1),
                                toString(input$base_dist_1))
}, height = 600, width = 1000)
```

## Intuitions on $H$ and $\alpha$

(1) Each draw from a DP is a discrete distribution. From our stick-breaking simulation, we observe that the concentration parameter $\alpha$ controls this discreteness. As $\alpha \to 0$, the draws concentrate their mass on a single value. As $\alpha \to \infty$, the draws look more and more like the continuous distribution $H$, and $G \to H$.

(2) The concentration parameter $\alpha$ can also be treated as an inverse variance: for any measurable set $A$, $\text{Var}[G(A)] = \frac{H(A)(1-H(A))}{\alpha+1}$. The larger $\alpha$ is, the smaller this variance, and the more the DP concentrates its mass around its mean (checked numerically in the sketch after item (3)).

(3) The base distribution $H$ is essentially the mean of the DP: $\text{E}[G(A)] = H(A)$. A DP drawing a random distribution around $H$ is analogous to a normal distribution drawing a random number around its mean.
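
The mean and variance facts above are easy to check numerically. The sketch below (illustrative choices: $H = \text{Normal}(0,1)$, $A = (-\infty, 0]$, $\alpha = 5$) samples $G(A)$ from its exact $\text{Beta}(\alpha H(A), \alpha(1-H(A)))$ marginal and compares the empirical mean and variance with $H(A)$ and $\frac{H(A)(1-H(A))}{\alpha+1}$.

```{r}
# Numerical check of E[G(A)] = H(A) and Var[G(A)] = H(A)(1-H(A))/(alpha+1)
# for the illustrative choices H = Normal(0,1), A = (-infinity, 0], alpha = 5.
alpha <- 5
H_A <- pnorm(0)                                   # H(A) = 0.5
G_A <- rbeta(1e5, alpha * H_A, alpha * (1 - H_A)) # exact marginal of G(A)
c(mean(G_A), H_A)                                 # empirical mean vs H(A)
c(var(G_A), H_A * (1 - H_A) / (alpha + 1))        # empirical variance vs formula
```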

## Citations

https://www.stats.ox.ac.uk/~teh/research/npbayes/Teh2010a.pdf

https://en.wikipedia.org/wiki/Dirichlet_process

https://www.cs.cmu.edu/~epxing/Class/10708-14/scribe_notes/scribe_note_lecture19.pdf