-
Notifications
You must be signed in to change notification settings - Fork 81
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #323 from kostastsa/main
Implementation of rao-blackwellised particle filter (rbpf) and rbpf with optimal resampling.
- Loading branch information
Showing
10 changed files
with
1,079 additions
and
67 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,299 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"attachments": {}, | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Tracking a maneuvering target using the RBPF" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"ename": "ModuleNotFoundError", | ||
"evalue": "No module named 'dynamax.slds'", | ||
"output_type": "error", | ||
"traceback": [ | ||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | ||
"\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", | ||
"\u001b[1;32m/Users/kostastsampourakis/Desktop/code/Python/projects/dynamax/docs/notebooks/slds/rbpf_maneuver.ipynb Cell 2\u001b[0m line \u001b[0;36m8\n\u001b[1;32m <a href='vscode-notebook-cell:/Users/kostastsampourakis/Desktop/code/Python/projects/dynamax/docs/notebooks/slds/rbpf_maneuver.ipynb#W1sZmlsZQ%3D%3D?line=5'>6</a>\u001b[0m sys\u001b[39m.\u001b[39mpath\u001b[39m.\u001b[39mappend(\u001b[39m'\u001b[39m\u001b[39m/Users/kostastsampourakis/Desktop/code/Python/projects/dynamax/dynamax\u001b[39m\u001b[39m'\u001b[39m)\n\u001b[1;32m <a href='vscode-notebook-cell:/Users/kostastsampourakis/Desktop/code/Python/projects/dynamax/docs/notebooks/slds/rbpf_maneuver.ipynb#W1sZmlsZQ%3D%3D?line=6'>7</a>\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mjax\u001b[39;00m \u001b[39mimport\u001b[39;00m vmap, tree_map, jit\n\u001b[0;32m----> <a href='vscode-notebook-cell:/Users/kostastsampourakis/Desktop/code/Python/projects/dynamax/docs/notebooks/slds/rbpf_maneuver.ipynb#W1sZmlsZQ%3D%3D?line=7'>8</a>\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mdynamax\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mslds\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39minference\u001b[39;00m \u001b[39mimport\u001b[39;00m ParamsSLDS, LGParamsSLDS, DiscreteParamsSLDS, rbpfilter, rbpfilter_optimal\n\u001b[1;32m <a href='vscode-notebook-cell:/Users/kostastsampourakis/Desktop/code/Python/projects/dynamax/docs/notebooks/slds/rbpf_maneuver.ipynb#W1sZmlsZQ%3D%3D?line=8'>9</a>\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mdynamax\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mslds\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mmodels\u001b[39;00m \u001b[39mimport\u001b[39;00m SLDS\n\u001b[1;32m <a href='vscode-notebook-cell:/Users/kostastsampourakis/Desktop/code/Python/projects/dynamax/docs/notebooks/slds/rbpf_maneuver.ipynb#W1sZmlsZQ%3D%3D?line=9'>10</a>\u001b[0m \u001b[39m# import MVN from tfd\u001b[39;00m\n", | ||
"\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'dynamax.slds'" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"import dynamax\n", | ||
"import jax.numpy as jnp\n", | ||
"import jax.random as jr\n", | ||
"from functools import partial\n", | ||
"import sys \n", | ||
"sys.path.append('/Users/kostastsampourakis/Desktop/code/Python/projects/dynamax')\n", | ||
"from jax import vmap, tree_map, jit\n", | ||
"from dynamax.slds.inference import ParamsSLDS, LGParamsSLDS, DiscreteParamsSLDS, rbpfilter, rbpfilter_optimal\n", | ||
"from dynamax.slds.models import SLDS\n", | ||
"# import MVN from tfd\n", | ||
"from tensorflow_probability.substrates.jax.distributions import MultivariateNormalFullCovariance as MVN\n", | ||
"\n", | ||
"import seaborn as sns\n", | ||
"import matplotlib.pyplot as plt\n", | ||
"from functools import partial\n", | ||
"from sklearn.preprocessing import OneHotEncoder\n", | ||
"from jax.scipy.special import logit\n", | ||
"from mpl_toolkits.mplot3d import Axes3D\n", | ||
"import jax" | ||
] | ||
}, | ||
{ | ||
"attachments": {}, | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Simulate Data" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"num_states = 3\n", | ||
"num_particles = 1000\n", | ||
"state_dim = 4\n", | ||
"emission_dim = 4\n", | ||
"\n", | ||
"TT = 0.1\n", | ||
"A = jnp.array([[1, TT, 0, 0],\n", | ||
" [0, 1, 0, 0],\n", | ||
" [0, 0, 1, TT],\n", | ||
" [0, 0, 0, 1]])\n", | ||
"\n", | ||
"\n", | ||
"B1 = jnp.array([0, 0, 0, 0])\n", | ||
"B2 = jnp.array([-1.225, -0.35, 1.225, 0.35])\n", | ||
"B3 = jnp.array([1.225, 0.35, -1.225, -0.35])\n", | ||
"B = jnp.stack([B1, B2, B3], axis=0)\n", | ||
"\n", | ||
"Q = 0.2 * jnp.eye(4)\n", | ||
"R = 10.0 * jnp.diag(jnp.array([2, 1, 2, 1]))\n", | ||
"C = jnp.eye(4)\n", | ||
"\n", | ||
"transition_matrix = jnp.array([\n", | ||
" [0.8, 0.1, 0.1],\n", | ||
" [0.1, 0.8, 0.1],\n", | ||
" [0.1, 0.1, 0.8]\n", | ||
"])\n", | ||
"\n", | ||
"discr_params = DiscreteParamsSLDS(\n", | ||
" initial_distribution=jnp.ones(num_states)/num_states,\n", | ||
" transition_matrix=transition_matrix,\n", | ||
" proposal_transition_matrix=transition_matrix\n", | ||
")\n", | ||
"\n", | ||
"lg_params = LGParamsSLDS(\n", | ||
" initial_mean=jnp.ones(state_dim),\n", | ||
" initial_cov=jnp.eye(state_dim),\n", | ||
" dynamics_weights=A,\n", | ||
" dynamics_cov=Q,\n", | ||
" dynamics_bias=jnp.array([B1, B2, B3]),\n", | ||
" dynamics_input_weights=None,\n", | ||
" emission_weights=C,\n", | ||
" emission_cov=R,\n", | ||
" emission_bias=None,\n", | ||
" emission_input_weights=None\n", | ||
")\n", | ||
"\n", | ||
"pre_params = ParamsSLDS(\n", | ||
" discrete=discr_params,\n", | ||
" linear_gaussian=lg_params\n", | ||
")\n", | ||
"\n", | ||
"params = pre_params.initialize(num_states, state_dim, emission_dim)\n", | ||
"\n", | ||
"## Sample states and emissions\n", | ||
"key, next_key = jr.split(jr.PRNGKey(1))\n", | ||
"slds = SLDS(num_states, state_dim, emission_dim)\n", | ||
"dstates, cstates, emissions = slds.sample(params, key, 100)" | ||
] | ||
}, | ||
{ | ||
"attachments": {}, | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Filtering" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"out = rbpfilter_optimal(num_particles, params, emissions, next_key)\n", | ||
"filtered_means = out['means']\n", | ||
"weights = out['weights']\n", | ||
"sampled_dstates = out['states']\n", | ||
"post_mean = jnp.einsum(\"ts,tsm->tm\", weights, filtered_means)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"f = lambda t: jnp.array([jnp.sum(weights[t, jnp.where(sampled_dstates[t]==st)]) for st in range(num_states)])\n", | ||
"p_est = jnp.array(list(map(f, jnp.arange(100))))\n", | ||
"est_dstates = jnp.argmax(p_est, axis=1)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import numpy as np\n", | ||
"\n", | ||
"def plot_3d_belief_state(mu_hist, dim, ax, skip=3, npoints=2000, azimuth=-30, elevation=30, h=0.5):\n", | ||
" nsteps = len(mu_hist)\n", | ||
" xmin, xmax = mu_hist[..., dim].min(), mu_hist[..., dim].max()\n", | ||
" xrange = np.linspace(xmin, xmax, npoints).reshape(-1, 1)\n", | ||
" res = np.apply_along_axis(lambda X: kdeg(xrange, X[..., None], h), 1, mu_hist)\n", | ||
" densities = res[..., dim]\n", | ||
" for t in range(0, nsteps, skip):\n", | ||
" tloc = t * np.ones(npoints)\n", | ||
" px = densities[t]\n", | ||
" ax.plot(tloc, xrange, px, c=\"tab:blue\", linewidth=1)\n", | ||
" ax.set_zlim(0, 1)\n", | ||
" style3d(ax, 1.8, 1.2, 0.7, 0.8)\n", | ||
" ax.view_init(elevation, azimuth)\n", | ||
" ax.set_xlabel(r\"$t$\", fontsize=13)\n", | ||
" ax.set_ylabel(r\"$x_{\"f\"d={dim}\"\",t}$\", fontsize=13)\n", | ||
" ax.set_zlabel(r\"$p(x_{d, t} \\vert y_{1:t})$\", fontsize=13)\n", | ||
"\n", | ||
"def scale_3d(ax, x_scale, y_scale, z_scale, factor): \n", | ||
" scale=np.diag(np.array([x_scale, y_scale, z_scale, 1.0]))\n", | ||
" scale=scale*(1.0/scale.max()) \n", | ||
" scale[3,3]=factor\n", | ||
" def short_proj(): \n", | ||
" return np.dot(Axes3D.get_proj(ax), scale) \n", | ||
" return short_proj \n", | ||
"\n", | ||
"def style3d(ax, x_scale, y_scale, z_scale, factor=0.62):\n", | ||
" plt.gca().patch.set_facecolor('white')\n", | ||
" ax.w_xaxis.set_pane_color((0, 0, 0, 0))\n", | ||
" ax.w_yaxis.set_pane_color((0, 0, 0, 0))\n", | ||
" ax.w_zaxis.set_pane_color((0, 0, 0, 0))\n", | ||
" ax.get_proj = scale_3d(ax, x_scale, y_scale, z_scale, factor)\n", | ||
"\n", | ||
"def kdeg(x, X, h):\n", | ||
" \"\"\"\n", | ||
" KDE under a gaussian kernel\n", | ||
"\n", | ||
" Parameters\n", | ||
" ----------\n", | ||
" x: array(eval, D)\n", | ||
" X: array(obs, D)\n", | ||
" h: float\n", | ||
"\n", | ||
" Returns\n", | ||
" -------\n", | ||
" array(eval):\n", | ||
" KDE around the observed values\n", | ||
" \"\"\"\n", | ||
" N, D = X.shape\n", | ||
" nden, _ = x.shape\n", | ||
"\n", | ||
" Xhat = X.reshape(D, 1, N)\n", | ||
" xhat = x.reshape(D, nden, 1)\n", | ||
" u = xhat - Xhat\n", | ||
" u = np.linalg.norm(u, ord=2, axis=0) ** 2 / (2 * h ** 2)\n", | ||
" px = np.exp(-u).sum(axis=1) / (N * h * np.sqrt(2 * np.pi))\n", | ||
" return px\n", | ||
"\n", | ||
"\n", | ||
" # Plot target dataset\n", | ||
"dict_figures = {}\n", | ||
"color_dict = {0: \"tab:green\", 1: \"tab:red\", 2: \"tab:blue\"}\n", | ||
"fig, ax = plt.subplots()\n", | ||
"color_states_org = [color_dict[int(state)] for state in dstates]\n", | ||
"ax.scatter(*cstates[:, [0, 2]].T, c=\"none\", edgecolors=color_states_org, s=10)\n", | ||
"ax.scatter(*emissions[:, [0, 2]].T, s=5, c=\"black\", alpha=0.6)\n", | ||
"ax.set_title(\"Data\")\n", | ||
"dict_figures[\"rbpf-maneuver-data\"] = fig\n", | ||
"\n", | ||
"# Plot filtered dataset\n", | ||
"fig, ax = plt.subplots()\n", | ||
"rbpf_mse = ((post_mean - cstates)[:, [0, 2]] ** 2).mean(axis=0).sum()\n", | ||
"color_states_est = [color_dict[int(state)] for state in np.array(est_dstates)]\n", | ||
"ax.scatter(*post_mean[:, [0, 2]].T, c=\"none\", edgecolors=color_states_est, s=10)\n", | ||
"ax.set_title(f\"RBPF MSE: {rbpf_mse:.2f}\")\n", | ||
"dict_figures[\"rbpf-maneuver-trace\"] = fig\n", | ||
"\n", | ||
"# Plot belief state of discrete system\n", | ||
"rbpf_error_rate = (dstates != est_dstates).mean()\n", | ||
"fig, ax = plt.subplots(figsize=(2.5, 5))\n", | ||
"sns.heatmap(p_est, cmap=\"viridis\", cbar=False)\n", | ||
"plt.title(f\"RBPF, error rate: {rbpf_error_rate:0.3}\")\n", | ||
"dict_figures[\"rbpf-maneuver-discrete-belief\"] = fig\n", | ||
"\n", | ||
"# Plot ground truth and MAP estimate\n", | ||
"ohe = OneHotEncoder(sparse=False)\n", | ||
"latent_hmap = ohe.fit_transform(dstates[:, None])\n", | ||
"latent_hmap_est = ohe.fit_transform(p_est.argmax(axis=1)[:, None])\n", | ||
"\n", | ||
"fig, ax = plt.subplots(figsize=(2.5, 5))\n", | ||
"sns.heatmap(latent_hmap, cmap=\"viridis\", cbar=False, ax=ax)\n", | ||
"ax.set_title(\"Data\")\n", | ||
"dict_figures[\"rbpf-maneuver-discrete-ground-truth.pdf\"] = fig\n", | ||
"\n", | ||
"fig, ax = plt.subplots(figsize=(2.5, 5))\n", | ||
"sns.heatmap(latent_hmap_est, cmap=\"viridis\", cbar=False, ax=ax)\n", | ||
"ax.set_title(f\"MAP (error rate: {rbpf_error_rate:0.4f})\")\n", | ||
"dict_figures[\"rbpf-maneuver-discrete-map\"] = fig\n", | ||
"\n", | ||
"# Plot belief for state space\n", | ||
"dims = [0, 2]\n", | ||
"for dim in dims:\n", | ||
" fig = plt.figure()\n", | ||
" ax = plt.axes(projection=\"3d\")\n", | ||
" plot_3d_belief_state(filtered_means, dim, ax, h=1.1)\n", | ||
" # pml.savefig(f\"rbpf-maneuver-belief-states-dim{dim}.pdf\", pad_inches=0, bbox_inches=\"tight\")\n", | ||
" dict_figures[f\"rbpf-maneuver-belief-states-dim{dim}.pdf\"] = fig\n", | ||
"\n", | ||
"\n", | ||
"plt.rcParams[\"axes.spines.right\"] = False\n", | ||
"plt.rcParams[\"axes.spines.top\"] = False\n", | ||
"plt.show()" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.11.3" | ||
}, | ||
"orig_nbformat": 4 | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from dynamax.slds.inference import DiscreteParamsSLDS, LGParamsSLDS, ParamsSLDS, rbpfilter, rbpfilter_optimal | ||
from dynamax.slds.models import SLDS |
Oops, something went wrong.