diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 86ed131..8234bc7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -25,13 +25,12 @@ repos: hooks: - id: ruff types_or: [ python, pyi, jupyter ] - args: [ --fix, --force-exclude ] - exclude: 'spring2025-course/localization-tutorial\.py' + args: [ --fix ] + exclude: 'spring2025-course/localization-tutorial\.*' - id: ruff-format types_or: [ python, pyi, jupyter ] - args: [ --force-exclude ] - exclude: 'spring2025-course/localization-tutorial\.py' + exclude: 'spring2025-course/localization-tutorial\.*' - repo: https://github.com/RobertCraigie/pyright-python rev: v1.1.369 diff --git a/genjax-localization-tutorial/probcomp-localization-tutorial.ipynb b/genjax-localization-tutorial/probcomp-localization-tutorial.ipynb index a9dfb1f..6faf058 100644 --- a/genjax-localization-tutorial/probcomp-localization-tutorial.ipynb +++ b/genjax-localization-tutorial/probcomp-localization-tutorial.ipynb @@ -60,7 +60,6 @@ "import jax\n", "import jax.numpy as jnp\n", "import genjax\n", - "from urllib.request import urlopen\n", "from genjax import SelectionBuilder as S\n", "from genjax import ChoiceMapBuilder as C\n", "from genjax.typing import FloatArray, PRNGKey\n", diff --git a/spring2025-course/localization-tutorial.py b/spring2025-course/localization-tutorial.py index 88d0347..8876d63 100644 --- a/spring2025-course/localization-tutorial.py +++ b/spring2025-course/localization-tutorial.py @@ -328,10 +328,10 @@ def pose_widget(label, initial_pose, **opts): key = jax.random.key(0) def random_pose(k): - p_hd = jax.random.uniform(k, shape=(3,), + p_array = jax.random.uniform(k, shape=(3,), minval=world["bounding_box"][:, 0], maxval=world["bounding_box"][:, 1]) - return Pose(p_hd[0:2], p_hd[2]) + return Pose(p_array[0:2], p_array[2]) some_poses = jax.vmap(random_pose)(jax.random.split(key, 20)) @@ -401,7 +401,7 @@ def make_sensor_angles(sensor_settings): sensor_angles = make_sensor_angles(sensor_settings) -def ideal_sensor(sensor_angles, pose): +def ideal_sensor(pose): return jax.vmap( lambda angle: sensor_distance(pose.rotate(angle), world["walls"], sensor_settings["box_size"]) )(sensor_angles) @@ -437,7 +437,7 @@ def pose_at(state, label): def update_ideal_sensors(widget, label): widget.state.update({ - (label + "_readings"): ideal_sensor(sensor_angles, pose_at(widget.state, label)) + (label + "_readings"): ideal_sensor(pose_at(widget.state, label)) }) def on_pose_change(widget, _): @@ -451,7 +451,7 @@ def on_pose_change(widget, _): ) | Plot.html(js("`pose = Pose([${$state.pose.p.map((x) => x.toFixed(2))}], ${$state.pose.hd.toFixed(2)})`")) | Plot.initialState({ - "pose_readings": ideal_sensor(sensor_angles, some_pose) + "pose_readings": ideal_sensor(some_pose) }) | Plot.onChange({"pose": on_pose_change}) ) @@ -460,7 +460,7 @@ def on_pose_change(widget, _): # Some pictures in case of limited interactivity: # %% -some_readings = jax.vmap(ideal_sensor, in_axes=(None, 0))(sensor_angles, some_poses) +some_readings = jax.vmap(ideal_sensor)(some_poses) Plot.Frames([ ( @@ -724,7 +724,7 @@ def on_target_pose_chage(widget, _): | Plot.initialState( { "k": jax.random.key_data(k1), - "guess_readings": ideal_sensor(sensor_angles, guess_pose), + "guess_readings": ideal_sensor(guess_pose), "target_readings": (initial_target_readings := noisy_sensor(k3, target_pose, sensor_settings["s_noise"])), "likelihood": likelihood_function(C["distance"].set(initial_target_readings), guess_pose, sensor_settings["s_noise"]), "show_target_pose": False, @@ 
-737,10 +737,137 @@ def on_target_pose_chage(widget, _): ) +# %% [markdown] +# ### Likelihood, prior, and posterior +# +# A common way to proceed is to consider the likelihood function—the probability density of the fixed observations, with varying pose parameter—as the basis of our reasoning. Reasoning about the likelihood can be recovered as a special case of reasoning about the *posterior distribution* over the pose parameter, where the *prior* distribution was uniform with respect to some reference volume over poses. +# +# We introduce here three choices of prior, to illustrate how inference depends on it. + +# %% +# Uniform prior over the whole map. +# (This is just a recapitulation of `random_pose` from above.) + +@genjax.gen +def uniform_pose(mins, maxes): + p_array = genjax.uniform(mins, maxes) @ "p_array" + return Pose(p_array[0:2], p_array[2]) + +whole_map_prior = uniform_pose.partial_apply( + world["bounding_box"][:, 0], + world["bounding_box"][:, 1] +) + +def whole_map_cm_builder(pose): + return C["p_array"].set(pose.as_array()) + + +# %% +key, sub_key = jax.random.split(key) +some_poses = jax.vmap(lambda k: whole_map_prior.simulate(k, ()))(jax.random.split(sub_key, 100)).get_retval() + +( + world_plot + + pose_plots(some_poses, color="green") + + {"title": "Some poses"} +) + + +# %% +# Even mixture of uniform priors over two rooms. +@genjax.gen +def two_room_prior(): + flip = genjax.flip(0.5) @ "flip" + mins, maxes = jnp.where( + flip, + jnp.array([[12.83, 11.19, -jnp.pi], [15.81, 15.26, +jnp.pi]]), + jnp.array([[15.73, 5.79, -jnp.pi], [18.9, 9.57, +jnp.pi]]) + ) + return uniform_pose(mins, maxes) @ "uniform" + +# The two rooms are disjoint so the flip is deterministic. +# The posterior density is incorrect by a *constant* factor of 1/2. +def two_room_cm_builder(pose): + return C["flip"].set(jnp.array(pose.p[1] > 10.0)) | C["uniform", "p_array"].set(pose.as_array()) + + +# %% +key, sub_key = jax.random.split(key) +some_poses = jax.vmap(lambda k: two_room_prior.simulate(k, ()))(jax.random.split(sub_key, 100)).get_retval() + +( + world_plot + + pose_plots(some_poses, color="green") + + {"title": "Some poses"} +) + +# %% +# Prior localized around a single pose +pose_for_localized_prior = Pose(jnp.array([2.0, 16]), jnp.array(0.0)) +spread_of_localized_prior = (0.1, 0.75) +@genjax.gen +def localized_prior(): + p = ( + genjax.mv_normal_diag( + pose_for_localized_prior.p, + spread_of_localized_prior[0] * jnp.ones(2) + ) + @ "p" + ) + hd = ( + genjax.normal( + pose_for_localized_prior.hd, + spread_of_localized_prior[1] + ) + @ "hd" + ) + return Pose(p, hd) + +def localized_cm_builder(pose): + return C["p"].set(pose.p) | C["hd"].set(pose.hd) + + +# %% +key, sub_key = jax.random.split(key) +some_poses = jax.vmap(lambda k: localized_prior.simulate(k, ()))(jax.random.split(sub_key, 100)).get_retval() + +( + world_plot + + pose_plots(some_poses, color="green") + + {"title": "Some poses"} +) + +# %% [markdown] +# We also introduce joint models, whose densities serve as unnormalized posterior densities. 
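+#
+# Why these joint densities serve as "unnormalized posterior densities": by Bayes' rule,
+#
+# $$P(\text{pose} \mid \text{readings}) \propto P(\text{pose}) \cdot P(\text{readings} \mid \text{pose}),$$
+#
+# and the right-hand side is exactly the density of a joint model that first draws a pose from
+# the prior and then runs the sensor model at that pose. The missing normalizing constant does
+# not depend on the pose, so it is harmless when comparing poses.
+
+# %%
+# A quick sanity check (a sketch, not needed for the main flow): each `*_cm_builder` is meant
+# to pair with `assess`, using the same `assess(choice_map, args)` interface the joint models
+# rely on below, to recover a prior's log-density at a concrete pose. The pose here is an
+# arbitrary illustrative value.
+example_pose = Pose(jnp.array([2.0, 16.0]), jnp.array(0.0))
+print(whole_map_prior.assess(whole_map_cm_builder(example_pose), ())[0])
+print(localized_prior.assess(localized_cm_builder(example_pose), ())[0])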
+ +# %% +model_dispatch = { + "whole_map": (whole_map_prior, whole_map_cm_builder), + "two_room": (two_room_prior, two_room_cm_builder), + "localized": (localized_prior, localized_cm_builder), +} + +def make_posterior_density_fn(prior_label, readings, model_noise): + prior, cm_builder = model_dispatch[prior_label] + @genjax.gen + def joint_model(): + pose = prior() @ "pose" + sensor = sensor_model(pose, sensor_angles, model_noise) @ "sensor" + return jax.jit( + lambda pose: + joint_model.assess( + C["pose"].set(cm_builder(pose)) | C["sensor", "distance"].set(readings), + () + )[0] + ) + + # %% [markdown] # ### Visualization setup # -# The next few widgets all operate on the same principle. A manipulable "camera" pose is shown. Hitting the button at bottom fixes the "target" pose to the camera pose and samples a batch of sensor readings at the target. It then performs some computation (optimization or inference) on those sensor readings and displays everything: the target, its sensor readings, and the computation results. +# The next few widgets all operate on the same principle. A manipulable "camera" pose is shown. Hitting the button makes the following happen: +# * The "target" pose gets fixed to the camera pose and a batch of sensor readings is sampled at the target. (The first slider controls the noise in the taking of these readings.) +# * Then some computation (optimization or inference) is performed on these sensor readings and everything gets displayed: the target, its sensor readings, and the computation results. (The second slider controls the sensor noise assumed by the model in this computation.) # %% def on_camera_button(button_handler): @@ -775,18 +902,29 @@ def camera_widget( + pose_widget("camera", camera_pose, color="blue") ) | noise_slider("world_noise", "World/data noise = ", sensor_settings["s_noise"]) + | Plot.html([ + "p", + "Prior:", + [ + "select", + {"onChange": js("(e) => $state.prior = e.target.value")}, + ["option", {"value": "whole_map", "selected": "True"}, "whole map"], + ["option", {"value": "two_room"}, "two room"], + ["option", {"value": "localized"}, "localized"], + ] + ]) | noise_slider("model_noise", "Model/inference noise = ", sensor_settings["s_noise"]) | ( Plot.html([ "div", {"class": "flex flex-col gap-4"}, [ - "button", # Changed from 'input' to 'button' + "button", { "class": "w-24 px-4 py-2 bg-blue-500 text-white rounded hover:bg-blue-600 active:bg-blue-700", "onClick": on_camera_button(button_handler) }, - button_label # Text as child element, not label + button_label ] ]) & Plot.html([ @@ -806,17 +944,18 @@ def camera_widget( "target_exists": False, "target": {"p": None, "hd": None}, "target_readings": [], + "prior": "whole_map" } | initial_state, - sync=({"k", "target", "camera_readings"} | sync)) + sync=({"k", "target", "camera_readings", "prior"} | sync)) ) # %% [markdown] # ### Doing optimization # -# A common response is to optimize the likelihood. The field of optimization is vast, but it is not the path for us here, so we just take a quick look. +# A common response is to optimize the posterior density. The field of optimization is vast, but it is not the path for us here, so we just take a quick look. # -# The idea here is just to search for the good poses by brute force, ranging over a suitable discretization grid of the map. So in this widget, the button fires a sweep over a grid of possible poses, computing the likelihood of each. The top `N_keep` are kept, and shown with opacity proportional to their position in that sublist. 
Moreover, the best fit is shown in purple. +# The idea here is just to search for the good poses by brute force, ranging over a suitable discretization grid of the map. So in this widget, the button fires a sweep over a grid of possible poses, computing the posterior density of each. The top `N_keep` are kept, and shown with opacity proportional to their position in that sublist. Moreover, the best fit is shown in purple. # %% def make_grid(bounds, ns): @@ -838,14 +977,13 @@ def make_poses_grid(bounds, ns): camera_pose = Pose(jnp.array([2.0, 16]), jnp.array(0.0)) -grid_poses = make_poses_grid(world["bounding_box"], N_grid) def grid_search_handler(widget, k, readings): model_noise = float(getattr(widget.state, "model_noise")) - jitted_likelihood = jax.jit( - lambda pose: likelihood_function(C["distance"].set(readings), pose, model_noise) - ) - likelihoods = jax.vmap(jitted_likelihood)(grid_poses) - best = jnp.argsort(likelihoods, descending=True)[0:N_keep] + jitted_posterior = make_posterior_density_fn(widget.state.prior, readings, model_noise) + + grid_poses = make_poses_grid(world["bounding_box"], N_grid) + posterior_densities = jax.vmap(jitted_posterior)(grid_poses) + best = jnp.argsort(posterior_densities, descending=True)[0:N_keep] widget.state.update({ "grid_poses": grid_poses[best].as_dict(), "best": grid_poses[best[0]].as_dict() @@ -878,14 +1016,14 @@ def grid_search_handler(widget, k, readings): # Note how the purple pose is a long way from summarizing the ambiguity in the solution set. # %% [markdown] -# ### Doing inference +# ### Doing probabilistic inference # # We show some ways one might approach this problem probabilistically: our answers to "Where might the robot be?" will all be given in the form of probability distributions. Moreover, these distributions will be embodied as generative samplers. # %% [markdown] # #### Grid approximation sampler # -# Although computing the grid takes work, afterwards accessing its members is cheap. Instead of only taking the best fit, we can draw members from the grid with probability in proportion to their likelihood. The result is the following sampler. +# Although computing the grid takes work, afterwards accessing its members is cheap. Instead of only taking the best fit, we can draw members from the grid with probability in proportion to their posterior density. The result is the following sampler. 
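+
+# %% [markdown]
+# Before wiring this into the widget machinery below, here is the core idea in isolation, as a
+# minimal sketch. The prior label, grid resolution, readings, and noise level would normally be
+# supplied by the widget; the helper name `sketch_grid_posterior_samples` and its defaults are
+# only illustrative.
+
+# %%
+def sketch_grid_posterior_samples(k, readings, model_noise, n_samples=100):
+    # Score every pose on a fixed grid under the unnormalized log posterior density.
+    log_posterior = make_posterior_density_fn("whole_map", readings, model_noise)
+    grid_poses = make_poses_grid(world["bounding_box"], jnp.array([50, 50, 20]))
+    log_densities = jax.vmap(log_posterior)(grid_poses)
+
+    # Draw grid members with probability proportional to their posterior density;
+    # `genjax.categorical` takes log-weights, so no explicit normalization is needed.
+    def sample_one(k):
+        return grid_poses[genjax.categorical.sample(k, log_densities)]
+
+    return jax.vmap(sample_one)(jax.random.split(k, n_samples))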
# %% N_grid = jnp.array([50, 50, 20]) @@ -895,17 +1033,14 @@ def grid_search_handler(widget, k, readings): camera_pose = Pose(jnp.array([15.13, 14.16]), jnp.array(1.5)) -grid_poses = make_poses_grid(world["bounding_box"], N_grid) - def grid_approximation_handler(widget, k, readings): model_noise = float(getattr(widget.state, "model_noise")) - jitted_likelihood = jax.jit( - lambda pose: likelihood_function(C["distance"].set(readings), pose, model_noise) - ) - likelihoods = jax.vmap(jitted_likelihood)(grid_poses) + jitted_posterior = make_posterior_density_fn(widget.state.prior, readings, model_noise) + grid_poses = make_poses_grid(world["bounding_box"], N_grid) + posterior_densities = jax.vmap(jitted_posterior)(grid_poses) def grid_sample_one(k): - return grid_poses[genjax.categorical.sample(k, likelihoods)] + return grid_poses[genjax.categorical.sample(k, posterior_densities)] grid_samples = jax.vmap(grid_sample_one)(jax.random.split(k, N_samples)) widget.state.update({ @@ -929,7 +1064,7 @@ def grid_sample_one(k): # # What if we need not be systematic—and instead we just try a bunch of points, uniformly over all poses, instead of constrained to a grid? # -# Here we first draw `N` pre-samples, assess them, and pick a single representative one in probability proportional to its likelihood, to obtain one sample. The samples obtained this way are then more closely distributed to the posterior. +# Here we first draw `N` pre-samples, assess them, and pick a single representative one in probability proportional to its posterior density, to obtain one sample. The samples obtained this way are then more closely distributed to the posterior. # %% N_presamples = 1000 @@ -941,15 +1076,13 @@ def grid_sample_one(k): def importance_resampling_handler(widget, k, readings): model_noise = float(getattr(widget.state, "model_noise")) - jitted_likelihood = jax.jit( - lambda pose: likelihood_function(C["distance"].set(readings), pose, model_noise) - ) + jitted_posterior = make_posterior_density_fn(widget.state.prior, readings, model_noise) def importance_resample_one(k): k1, k2 = jax.random.split(k) presamples = jax.vmap(random_pose)(jax.random.split(k1, N_presamples)) - likelihoods = jax.vmap(jitted_likelihood)(presamples) - return presamples[genjax.categorical.sample(k2, likelihoods)] + posterior_densities = jax.vmap(jitted_posterior)(presamples) + return presamples[genjax.categorical.sample(k2, posterior_densities)] grid_samples = jax.vmap(importance_resample_one)(jax.random.split(k, N_samples)) widget.state.update({ @@ -980,12 +1113,10 @@ def importance_resample_one(k): def MCMC_handler(widget, k, readings): model_noise = float(getattr(widget.state, "model_noise")) - jitted_likelihood = jax.jit( - lambda pose: likelihood_function(C["distance"].set(readings), pose, model_noise) - ) + jitted_posterior = make_posterior_density_fn(widget.state.prior, readings, model_noise) - def do_MH_step(pose_likelihood, k): - pose, likelihood = pose_likelihood + def do_MH_step(pose_posterior_density, k): + pose, posterior_density = pose_posterior_density k1, k2 = jax.random.split(k) p_hd = pose.as_array() delta = jnp.array([0.5, 0.5, 0.1]) @@ -993,21 +1124,21 @@ def do_MH_step(pose_likelihood, k): maxs = jnp.minimum(p_hd + delta, world["bounding_box"][:, 1]) new_p_hd = jax.random.uniform(k1, shape=(3,), minval=mins, maxval=maxs) new_pose = Pose(new_p_hd[0:2], new_p_hd[2]) - new_likelihood = jitted_likelihood(new_pose) - accept = (jnp.log(genjax.uniform.sample(k2)) <= new_likelihood - likelihood) + new_posterior = 
jitted_posterior(new_pose)
+        accept = (jnp.log(genjax.uniform.sample(k2)) <= new_posterior - posterior_density)
         return (
             jax.tree.map(
                 lambda x, y: jnp.where(accept, x, y),
-                (new_pose, new_likelihood),
-                (pose, likelihood)
+                (new_pose, new_posterior),
+                (pose, posterior_density)
             ),
             None
         )

     def sample_MH_one(k):
         k1, k2 = jax.random.split(k)
         start_pose = random_pose(k1)
-        start_likelihood = jitted_likelihood(start_pose)
-        return jax.lax.scan(do_MH_step, (start_pose, start_likelihood), jax.random.split(k2, N_MH_steps))[0][0]
+        start_posterior = jitted_posterior(start_pose)
+        return jax.lax.scan(do_MH_step, (start_pose, start_posterior), jax.random.split(k2, N_MH_steps))[0][0]

     grid_samples = jax.vmap(sample_MH_one)(jax.random.split(k, N_samples))
     widget.state.update({
@@ -1735,10 +1866,10 @@ def constraint_from_path(path):
 constraints_path_integrated = constraint_from_path(path_integrated)

 constraints_path_integrated_observations_low_deviation = (
-    constraints_path_integrated ^ constraints_low_deviation
+    constraints_path_integrated | constraints_low_deviation
 )
 constraints_path_integrated_observations_high_deviation = (
-    constraints_path_integrated ^ constraints_high_deviation
+    constraints_path_integrated | constraints_high_deviation
 )

 key, sub_key = jax.random.split(key)
@@ -1993,8 +2124,8 @@ def path_to_polyline(path, **options):
 # readings at once. The results were better than guesses, but not accurate, in the
 # high deviation setting.
 #
-# The technique we will use here discards steps with low likelihood at each step, and
-# reinforces steps with high likelihood, allowing better particles to proportionately
+# The technique we will use here discards steps with low posterior density at each step, and
+# reinforces steps with high posterior density, allowing better particles to proportionately
 # search more of the probability space while discarding unpromising particles.
 #
 # The following class attempts to generatlize this idea:
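+
+# %% [markdown]
+# Before that general class, here is a minimal sketch of the single resampling move it builds
+# on: given a population of particles and their unnormalized log posterior densities at the
+# current step, draw a new population with replacement, in proportion to those densities. The
+# function and argument names here are illustrative only, not the class's API.
+
+# %%
+def resample_particles_sketch(k, particles, log_weights, n_particles):
+    # Choose indices with probability proportional to the (log-space, unnormalized) weights.
+    indices = jax.vmap(lambda k: genjax.categorical.sample(k, log_weights))(
+        jax.random.split(k, n_particles)
+    )
+    # Gather the surviving particles; this works for any pytree of stacked particles,
+    # such as a batched Pose.
+    return jax.tree.map(lambda leaf: leaf[indices], particles)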