Skip to content

Commit

Permalink
Merge pull request #323 from kostastsa/main
Browse files Browse the repository at this point in the history
Implementation of rao-blackwellised particle filter (rbpf) and rbpf with optimal resampling.
  • Loading branch information
murphyk authored Nov 1, 2023
2 parents 95a09ba + edbfb9c commit a1cb385
Show file tree
Hide file tree
Showing 10 changed files with 1,079 additions and 67 deletions.
84 changes: 19 additions & 65 deletions docs/notebooks/linear_gaussian_ssm/lgssm_learning.ipynb

Large diffs are not rendered by default.

299 changes: 299 additions & 0 deletions docs/notebooks/slds/rbpf_maneuver.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,299 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Tracking a maneuvering target using the RBPF"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"ename": "ModuleNotFoundError",
"evalue": "No module named 'dynamax.slds'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m/Users/kostastsampourakis/Desktop/code/Python/projects/dynamax/docs/notebooks/slds/rbpf_maneuver.ipynb Cell 2\u001b[0m line \u001b[0;36m8\n\u001b[1;32m <a href='vscode-notebook-cell:/Users/kostastsampourakis/Desktop/code/Python/projects/dynamax/docs/notebooks/slds/rbpf_maneuver.ipynb#W1sZmlsZQ%3D%3D?line=5'>6</a>\u001b[0m sys\u001b[39m.\u001b[39mpath\u001b[39m.\u001b[39mappend(\u001b[39m'\u001b[39m\u001b[39m/Users/kostastsampourakis/Desktop/code/Python/projects/dynamax/dynamax\u001b[39m\u001b[39m'\u001b[39m)\n\u001b[1;32m <a href='vscode-notebook-cell:/Users/kostastsampourakis/Desktop/code/Python/projects/dynamax/docs/notebooks/slds/rbpf_maneuver.ipynb#W1sZmlsZQ%3D%3D?line=6'>7</a>\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mjax\u001b[39;00m \u001b[39mimport\u001b[39;00m vmap, tree_map, jit\n\u001b[0;32m----> <a href='vscode-notebook-cell:/Users/kostastsampourakis/Desktop/code/Python/projects/dynamax/docs/notebooks/slds/rbpf_maneuver.ipynb#W1sZmlsZQ%3D%3D?line=7'>8</a>\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mdynamax\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mslds\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39minference\u001b[39;00m \u001b[39mimport\u001b[39;00m ParamsSLDS, LGParamsSLDS, DiscreteParamsSLDS, rbpfilter, rbpfilter_optimal\n\u001b[1;32m <a href='vscode-notebook-cell:/Users/kostastsampourakis/Desktop/code/Python/projects/dynamax/docs/notebooks/slds/rbpf_maneuver.ipynb#W1sZmlsZQ%3D%3D?line=8'>9</a>\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mdynamax\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mslds\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mmodels\u001b[39;00m \u001b[39mimport\u001b[39;00m SLDS\n\u001b[1;32m <a href='vscode-notebook-cell:/Users/kostastsampourakis/Desktop/code/Python/projects/dynamax/docs/notebooks/slds/rbpf_maneuver.ipynb#W1sZmlsZQ%3D%3D?line=9'>10</a>\u001b[0m \u001b[39m# import MVN from tfd\u001b[39;00m\n",
"\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'dynamax.slds'"
]
}
],
"source": [
"import dynamax\n",
"import jax.numpy as jnp\n",
"import jax.random as jr\n",
"from functools import partial\n",
"import sys \n",
"sys.path.append('/Users/kostastsampourakis/Desktop/code/Python/projects/dynamax')\n",
"from jax import vmap, tree_map, jit\n",
"from dynamax.slds.inference import ParamsSLDS, LGParamsSLDS, DiscreteParamsSLDS, rbpfilter, rbpfilter_optimal\n",
"from dynamax.slds.models import SLDS\n",
"# import MVN from tfd\n",
"from tensorflow_probability.substrates.jax.distributions import MultivariateNormalFullCovariance as MVN\n",
"\n",
"import seaborn as sns\n",
"import matplotlib.pyplot as plt\n",
"from functools import partial\n",
"from sklearn.preprocessing import OneHotEncoder\n",
"from jax.scipy.special import logit\n",
"from mpl_toolkits.mplot3d import Axes3D\n",
"import jax"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Simulate Data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"num_states = 3\n",
"num_particles = 1000\n",
"state_dim = 4\n",
"emission_dim = 4\n",
"\n",
"TT = 0.1\n",
"A = jnp.array([[1, TT, 0, 0],\n",
" [0, 1, 0, 0],\n",
" [0, 0, 1, TT],\n",
" [0, 0, 0, 1]])\n",
"\n",
"\n",
"B1 = jnp.array([0, 0, 0, 0])\n",
"B2 = jnp.array([-1.225, -0.35, 1.225, 0.35])\n",
"B3 = jnp.array([1.225, 0.35, -1.225, -0.35])\n",
"B = jnp.stack([B1, B2, B3], axis=0)\n",
"\n",
"Q = 0.2 * jnp.eye(4)\n",
"R = 10.0 * jnp.diag(jnp.array([2, 1, 2, 1]))\n",
"C = jnp.eye(4)\n",
"\n",
"transition_matrix = jnp.array([\n",
" [0.8, 0.1, 0.1],\n",
" [0.1, 0.8, 0.1],\n",
" [0.1, 0.1, 0.8]\n",
"])\n",
"\n",
"discr_params = DiscreteParamsSLDS(\n",
" initial_distribution=jnp.ones(num_states)/num_states,\n",
" transition_matrix=transition_matrix,\n",
" proposal_transition_matrix=transition_matrix\n",
")\n",
"\n",
"lg_params = LGParamsSLDS(\n",
" initial_mean=jnp.ones(state_dim),\n",
" initial_cov=jnp.eye(state_dim),\n",
" dynamics_weights=A,\n",
" dynamics_cov=Q,\n",
" dynamics_bias=jnp.array([B1, B2, B3]),\n",
" dynamics_input_weights=None,\n",
" emission_weights=C,\n",
" emission_cov=R,\n",
" emission_bias=None,\n",
" emission_input_weights=None\n",
")\n",
"\n",
"pre_params = ParamsSLDS(\n",
" discrete=discr_params,\n",
" linear_gaussian=lg_params\n",
")\n",
"\n",
"params = pre_params.initialize(num_states, state_dim, emission_dim)\n",
"\n",
"## Sample states and emissions\n",
"key, next_key = jr.split(jr.PRNGKey(1))\n",
"slds = SLDS(num_states, state_dim, emission_dim)\n",
"dstates, cstates, emissions = slds.sample(params, key, 100)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Filtering"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"out = rbpfilter_optimal(num_particles, params, emissions, next_key)\n",
"filtered_means = out['means']\n",
"weights = out['weights']\n",
"sampled_dstates = out['states']\n",
"post_mean = jnp.einsum(\"ts,tsm->tm\", weights, filtered_means)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"f = lambda t: jnp.array([jnp.sum(weights[t, jnp.where(sampled_dstates[t]==st)]) for st in range(num_states)])\n",
"p_est = jnp.array(list(map(f, jnp.arange(100))))\n",
"est_dstates = jnp.argmax(p_est, axis=1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def plot_3d_belief_state(mu_hist, dim, ax, skip=3, npoints=2000, azimuth=-30, elevation=30, h=0.5):\n",
" nsteps = len(mu_hist)\n",
" xmin, xmax = mu_hist[..., dim].min(), mu_hist[..., dim].max()\n",
" xrange = np.linspace(xmin, xmax, npoints).reshape(-1, 1)\n",
" res = np.apply_along_axis(lambda X: kdeg(xrange, X[..., None], h), 1, mu_hist)\n",
" densities = res[..., dim]\n",
" for t in range(0, nsteps, skip):\n",
" tloc = t * np.ones(npoints)\n",
" px = densities[t]\n",
" ax.plot(tloc, xrange, px, c=\"tab:blue\", linewidth=1)\n",
" ax.set_zlim(0, 1)\n",
" style3d(ax, 1.8, 1.2, 0.7, 0.8)\n",
" ax.view_init(elevation, azimuth)\n",
" ax.set_xlabel(r\"$t$\", fontsize=13)\n",
" ax.set_ylabel(r\"$x_{\"f\"d={dim}\"\",t}$\", fontsize=13)\n",
" ax.set_zlabel(r\"$p(x_{d, t} \\vert y_{1:t})$\", fontsize=13)\n",
"\n",
"def scale_3d(ax, x_scale, y_scale, z_scale, factor): \n",
" scale=np.diag(np.array([x_scale, y_scale, z_scale, 1.0]))\n",
" scale=scale*(1.0/scale.max()) \n",
" scale[3,3]=factor\n",
" def short_proj(): \n",
" return np.dot(Axes3D.get_proj(ax), scale) \n",
" return short_proj \n",
"\n",
"def style3d(ax, x_scale, y_scale, z_scale, factor=0.62):\n",
" plt.gca().patch.set_facecolor('white')\n",
" ax.w_xaxis.set_pane_color((0, 0, 0, 0))\n",
" ax.w_yaxis.set_pane_color((0, 0, 0, 0))\n",
" ax.w_zaxis.set_pane_color((0, 0, 0, 0))\n",
" ax.get_proj = scale_3d(ax, x_scale, y_scale, z_scale, factor)\n",
"\n",
"def kdeg(x, X, h):\n",
" \"\"\"\n",
" KDE under a gaussian kernel\n",
"\n",
" Parameters\n",
" ----------\n",
" x: array(eval, D)\n",
" X: array(obs, D)\n",
" h: float\n",
"\n",
" Returns\n",
" -------\n",
" array(eval):\n",
" KDE around the observed values\n",
" \"\"\"\n",
" N, D = X.shape\n",
" nden, _ = x.shape\n",
"\n",
" Xhat = X.reshape(D, 1, N)\n",
" xhat = x.reshape(D, nden, 1)\n",
" u = xhat - Xhat\n",
" u = np.linalg.norm(u, ord=2, axis=0) ** 2 / (2 * h ** 2)\n",
" px = np.exp(-u).sum(axis=1) / (N * h * np.sqrt(2 * np.pi))\n",
" return px\n",
"\n",
"\n",
" # Plot target dataset\n",
"dict_figures = {}\n",
"color_dict = {0: \"tab:green\", 1: \"tab:red\", 2: \"tab:blue\"}\n",
"fig, ax = plt.subplots()\n",
"color_states_org = [color_dict[int(state)] for state in dstates]\n",
"ax.scatter(*cstates[:, [0, 2]].T, c=\"none\", edgecolors=color_states_org, s=10)\n",
"ax.scatter(*emissions[:, [0, 2]].T, s=5, c=\"black\", alpha=0.6)\n",
"ax.set_title(\"Data\")\n",
"dict_figures[\"rbpf-maneuver-data\"] = fig\n",
"\n",
"# Plot filtered dataset\n",
"fig, ax = plt.subplots()\n",
"rbpf_mse = ((post_mean - cstates)[:, [0, 2]] ** 2).mean(axis=0).sum()\n",
"color_states_est = [color_dict[int(state)] for state in np.array(est_dstates)]\n",
"ax.scatter(*post_mean[:, [0, 2]].T, c=\"none\", edgecolors=color_states_est, s=10)\n",
"ax.set_title(f\"RBPF MSE: {rbpf_mse:.2f}\")\n",
"dict_figures[\"rbpf-maneuver-trace\"] = fig\n",
"\n",
"# Plot belief state of discrete system\n",
"rbpf_error_rate = (dstates != est_dstates).mean()\n",
"fig, ax = plt.subplots(figsize=(2.5, 5))\n",
"sns.heatmap(p_est, cmap=\"viridis\", cbar=False)\n",
"plt.title(f\"RBPF, error rate: {rbpf_error_rate:0.3}\")\n",
"dict_figures[\"rbpf-maneuver-discrete-belief\"] = fig\n",
"\n",
"# Plot ground truth and MAP estimate\n",
"ohe = OneHotEncoder(sparse=False)\n",
"latent_hmap = ohe.fit_transform(dstates[:, None])\n",
"latent_hmap_est = ohe.fit_transform(p_est.argmax(axis=1)[:, None])\n",
"\n",
"fig, ax = plt.subplots(figsize=(2.5, 5))\n",
"sns.heatmap(latent_hmap, cmap=\"viridis\", cbar=False, ax=ax)\n",
"ax.set_title(\"Data\")\n",
"dict_figures[\"rbpf-maneuver-discrete-ground-truth.pdf\"] = fig\n",
"\n",
"fig, ax = plt.subplots(figsize=(2.5, 5))\n",
"sns.heatmap(latent_hmap_est, cmap=\"viridis\", cbar=False, ax=ax)\n",
"ax.set_title(f\"MAP (error rate: {rbpf_error_rate:0.4f})\")\n",
"dict_figures[\"rbpf-maneuver-discrete-map\"] = fig\n",
"\n",
"# Plot belief for state space\n",
"dims = [0, 2]\n",
"for dim in dims:\n",
" fig = plt.figure()\n",
" ax = plt.axes(projection=\"3d\")\n",
" plot_3d_belief_state(filtered_means, dim, ax, h=1.1)\n",
" # pml.savefig(f\"rbpf-maneuver-belief-states-dim{dim}.pdf\", pad_inches=0, bbox_inches=\"tight\")\n",
" dict_figures[f\"rbpf-maneuver-belief-states-dim{dim}.pdf\"] = fig\n",
"\n",
"\n",
"plt.rcParams[\"axes.spines.right\"] = False\n",
"plt.rcParams[\"axes.spines.top\"] = False\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.3"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
1 change: 0 additions & 1 deletion dynamax/linear_gaussian_ssm/inference_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,6 @@ class TestFilteringAndSmoothing():
# Posteriors from full joint distribution
joint_means, joint_covs = joint_posterior_mvn(params, emissions)


## For sampling tests
lgssm, params = build_lgssm_for_sampling()

Expand Down
1 change: 0 additions & 1 deletion dynamax/linear_gaussian_ssm/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,6 @@ def e_step(

return (init_stats, dynamics_stats, emission_stats), posterior.marginal_loglik


def initialize_m_step_state(
self,
params: ParamsLGSSM,
Expand Down
2 changes: 2 additions & 0 deletions dynamax/slds/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from dynamax.slds.inference import DiscreteParamsSLDS, LGParamsSLDS, ParamsSLDS, rbpfilter, rbpfilter_optimal
from dynamax.slds.models import SLDS
Loading

0 comments on commit a1cb385

Please sign in to comment.