diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 323f512..86ed131 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -25,10 +25,13 @@ repos: hooks: - id: ruff types_or: [ python, pyi, jupyter ] - args: [ --fix ] + args: [ --fix, --force-exclude ] + exclude: 'spring2025-course/localization-tutorial\.py' - id: ruff-format types_or: [ python, pyi, jupyter ] + args: [ --force-exclude ] + exclude: 'spring2025-course/localization-tutorial\.py' - repo: https://github.com/RobertCraigie/pyright-python rev: v1.1.369 diff --git a/pixi.lock b/pixi.lock index fb6c01d..610d259 100644 --- a/pixi.lock +++ b/pixi.lock @@ -2038,7 +2038,7 @@ packages: - pypi: . name: localization-tutorial version: 0.1.0 - sha256: b4dd1acecea6f4a37ced5a989a1249649e6a3038686747caca30223a6a056f51 + sha256: 33e938ca55162ae9ce3a35f7ce2bb1cf5f91aabf0dad8d020256eb42b1fdcb28 requires_dist: - genstudio>=2025.2.2,<2026 - genjax==0.9.1 @@ -2354,7 +2354,7 @@ packages: - python >=3.9 license: ISC purls: - - pkg:pypi/pexpect?source=hash-mapping + - pkg:pypi/pexpect?source=compressed-mapping size: 53561 timestamp: 1733302019362 - conda: https://conda.anaconda.org/conda-forge/noarch/pickleshare-0.7.5-pyhd8ed1ab_1004.conda diff --git a/pyproject.toml b/pyproject.toml index 788775b..6bf8410 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,3 +48,6 @@ cudnn = ">=9.7.1.26,<10" [tool.pixi.feature.cuda.target.linux-64.pypi-dependencies] jax = { version = ">=0.4.35", extras = ["cuda12-local"] } + +[tool.ruff] +exclude = ["spring2025-course/localization-tutorial.py"] diff --git a/spring2025-course/localization-tutorial.py b/spring2025-course/localization-tutorial.py index 6274d0d..8f01c8a 100644 --- a/spring2025-course/localization-tutorial.py +++ b/spring2025-course/localization-tutorial.py @@ -15,7 +15,7 @@ # name: python3 # --- - +# %% # pyright: reportUnusedExpression=false # %% # import sys @@ -41,11 +41,12 @@ import jax import jax.numpy as jnp import genjax +import os from genjax import SelectionBuilder as S from genjax import ChoiceMapBuilder as C from genjax.typing import Array, FloatArray, PRNGKey, IntArray from penzai import pz -from typing import TypeVar, Generic, Callable +from typing import Any, Iterable, TypeVar, Generic, Callable from genstudio.plot import js html = Plot.Hiccup @@ -62,7 +63,6 @@ # %% # General code here - def create_segments(points): """ Given an array of points of shape (N, 2), return an array of @@ -90,7 +90,9 @@ def make_world(wall_verts, clutters_vec): clutters = jax.vmap(create_segments)(clutters_vec) # Combine all points for bounding box calculation - all_points = jnp.vstack((jnp.array(wall_verts), jnp.concatenate(clutters_vec))) + all_points = jnp.vstack( + (jnp.array(wall_verts), jnp.concatenate(clutters_vec)) + ) x_min, y_min = jnp.min(all_points, axis=0) x_max, y_max = jnp.max(all_points, axis=0) @@ -100,14 +102,13 @@ def make_world(wall_verts, clutters_vec): center_point = jnp.array([(x_min + x_max) / 2.0, (y_min + y_max) / 2.0]) return { - "walls": walls, - "wall_verts": wall_verts, - "clutters": clutters, - "bounding_box": bounding_box, - "box_size": box_size, - "center_point": center_point, - } - + "walls": walls, + "wall_verts": wall_verts, + "clutters": clutters, + "bounding_box": bounding_box, + "box_size": box_size, + "center_point": center_point, + } def load_file(file_name): # load from cwd or its parent @@ -119,7 +120,6 @@ def load_file(file_name): with open(f"../{file_name}") as f: return json.load(f) - def load_world(file_name): """ Loads the world configuration from a specified file and constructs the world. @@ -142,7 +142,7 @@ def load_world(file_name): # %% # Specific example code here -world = load_world("world.json") +world = load_world("world.json"); # %% [markdown] # ### Plotting @@ -169,13 +169,20 @@ def load_world(file_name): Plot.color_map({"clutters": "magenta"}), ) -(world_plot + clutters_plot + {"title": "Given data"}) +( + world_plot + + clutters_plot + + {"title": "Given data"} +) # %% [markdown] # Following this initial display of the given data, we *suppress the clutters* until much later in the notebook. # %% -(world_plot + {"title": "Given data"}) +( + world_plot + + {"title": "Given data"} +) # %% [markdown] @@ -185,7 +192,6 @@ def load_world(file_name): # # These will be visualized using arrows whose tip is at the position, and whose direction indicates the heading. - # %% @pz.pytree_dataclass class Pose(genjax.PythonicPytree): @@ -232,20 +238,17 @@ def rotate(self, a: float) -> "Pose": """ return Pose(self.p, self.hd + a) - # %% def pose_wings(pose, opts={}): - return Plot.line( - js( - """ + return Plot.line(js(f""" const pose = %1; let positions = pose.p; let angles = pose.hd; - if (typeof angles === 'number') { + if (typeof angles === 'number') {{ positions = [positions]; angles = [angles]; - } - return Array.from(positions).flatMap((p, i) => { + }} + return Array.from(positions).flatMap((p, i) => {{ const angle = angles[i] const wingAngle = Math.PI / 12 const wingLength = 0.6 @@ -261,21 +264,13 @@ def pose_wings(pose, opts={}): i ] return [wing1, center, wing2] - }) - """, - pose, - expression=False, - ), - z="2", - **opts, - ) - + }}) + """, pose, expression=False), + z="2", + **opts) def pose_body(pose, opts={}): - return Plot.dot( - js("typeof %1.hd === 'number' ? [%1.p] : %1.p", pose), {"r": 4} | opts - ) - + return Plot.dot(js(f"typeof %1.hd === 'number' ? [%1.p] : %1.p", pose), {"r": 4} | opts) def pose_plots(poses, wing_opts={}, body_opts={}, **opts): """ @@ -294,16 +289,16 @@ def pose_plots(poses, wing_opts={}, body_opts={}, **opts): if "color" in opts: wing_opts = wing_opts | {"stroke": opts["color"]} body_opts = body_opts | {"fill": opts["color"]} - return pose_wings(poses, opts | wing_opts) + pose_body(poses, opts | body_opts) + return ( + pose_wings(poses, opts | wing_opts) + pose_body(poses, opts | body_opts) + ) def pose_widget(label, initial_pose, **opts): - return pose_plots( - js(f"$state.{label}"), - render=Plot.renderChildEvents( - { - "onDrag": js( - f""" + return ( + pose_plots(js(f"$state.{label}"), + render=Plot.renderChildEvents({"onDrag": js( + f""" (e) => {{ if (e.shiftKey) {{ const dx = e.x - $state.{label}.p[0]; @@ -314,28 +309,17 @@ def pose_widget(label, initial_pose, **opts): $state.update({{{label}: {{hd: $state.{label}.hd, p: [e.x, e.y]}}}}) }} }} - """ - ) - } - ), - **opts, - ) | Plot.initialState({label: initial_pose.as_dict()}, sync=label) - + """)}), **opts) + | Plot.initialState({label: initial_pose.as_dict()}, sync=label) + ) # %% some_pose = Pose(jnp.array([6.0, 15.0]), jnp.array(0.0)) -( - Plot.html( - "Click-drag on pose to change location. Shift-click-drag on pose to change heading." - ) - | (world_plot + pose_widget("pose", some_pose, color="blue")) - | Plot.html( - js( - "`pose = Pose([${$state.pose.p.map((x) => x.toFixed(2))}], ${$state.pose.hd.toFixed(2)})`" - ) - ) -) +Plot.html("Click-drag on pose to change location. Shift-click-drag on pose to change heading.") | ( + world_plot + + pose_widget("pose", some_pose, color="blue") +) | Plot.html(js("`pose = Pose([${$state.pose.p.map((x) => x.toFixed(2))}], ${$state.pose.hd.toFixed(2)})`")) # %% [markdown] # A static picture in case of limited interactivity: @@ -343,20 +327,19 @@ def pose_widget(label, initial_pose, **opts): # %% key = jax.random.key(0) - def random_pose(k): - p_hd = jax.random.uniform( - k, - shape=(3,), + p_hd = jax.random.uniform(k, shape=(3,), minval=world["bounding_box"][:, 0], - maxval=world["bounding_box"][:, 1], - ) + maxval=world["bounding_box"][:, 1]) return Pose(p_hd[0:2], p_hd[2]) - some_poses = jax.vmap(random_pose)(jax.random.split(key, 20)) -(world_plot + pose_plots(some_poses, color="green") + {"title": "Some poses"}) +( + world_plot + + pose_plots(some_poses, color="green") + + {"title": "Some poses"} +) # %% [markdown] @@ -366,7 +349,6 @@ def random_pose(k): # # An "ideal" sensor reports the exact distance cast to a wall. (It is capped off at a max value in case of error.) - # %% def distance(p, seg, PARALLEL_TOL=1.0e-6): """ @@ -387,14 +369,16 @@ def distance(p, seg, PARALLEL_TOL=1.0e-6): st = jnp.where( jnp.abs(det) < PARALLEL_TOL, jnp.array([jnp.nan, jnp.nan]), - jnp.array( - [ - (segdp[0] * pq[1] - segdp[1] * pq[0]) / det, - (pdp[0] * pq[1] - pdp[1] * pq[0]) / det, - ] - ), + jnp.array([ + (segdp[0] * pq[1] - segdp[1] * pq[0]) / det, + (pdp[0] * pq[1] - pdp[1] * pq[0]) / det + ]) + ) + return jnp.where( + (st[0] >= 0.0) & (st[1] >= 0.0) & (st[1] <= 1.0), + st[0], + jnp.inf ) - return jnp.where((st[0] >= 0.0) & (st[1] >= 0.0) & (st[1] <= 1.0), st[0], jnp.inf) # %% @@ -404,60 +388,45 @@ def distance(p, seg, PARALLEL_TOL=1.0e-6): "box_size": world["box_size"], } - def sensor_distance(pose, walls, box_size): d = jnp.min(jax.vmap(distance, in_axes=(None, 0))(pose, walls)) # Capping to a finite value avoids issues below. return jnp.where(jnp.isinf(d), 2.0 * box_size, d) - # This represents a "fan" of sensor angles, with given field of vision, centered at angle 0. - def make_sensor_angles(sensor_settings): na = sensor_settings["num_angles"] return sensor_settings["fov"] * (jnp.arange(na) - ((na - 1) / 2)) / (na - 1) - sensor_angles = make_sensor_angles(sensor_settings) - def ideal_sensor(sensor_angles, pose): return jax.vmap( - lambda angle: sensor_distance( - pose.rotate(angle), world["walls"], sensor_settings["box_size"] - ) + lambda angle: sensor_distance(pose.rotate(angle), world["walls"], sensor_settings["box_size"]) )(sensor_angles) # %% # Plot sensor data. - def plot_sensors(pose, readings, sensor_angles): - return Plot.Import( - """export const projections = (pose, readings, angles) => Array.from({length: readings.length}, (_, i) => { + return Plot.Import("""export const projections = (pose, readings, angles) => Array.from({length: readings.length}, (_, i) => { const angle = angles[i] + pose.hd const reading = readings[i] return [pose.p[0] + reading * Math.cos(angle), pose.p[1] + reading * Math.sin(angle)] })""", - refer=["projections"], - ) | ( + refer=["projections"]) | ( Plot.line( - js( - "projections(%1, %2, %3).flatMap((projection, i) => [%1.p, projection, i])", - pose, - readings, - sensor_angles, - ), + js("projections(%1, %2, %3).flatMap((projection, i) => [%1.p, projection, i])", pose, readings, sensor_angles), stroke=Plot.constantly("sensor rays"), - ) - + Plot.dot( + ) + + Plot.dot( js("projections(%1, %2, %3)", pose, readings, sensor_angles), r=2.75, fill=Plot.constantly("sensor readings"), - ) - + Plot.colorMap({"sensor rays": "rgba(0,0,0,0.1)", "sensor readings": "#f80"}) + ) + + Plot.colorMap({"sensor rays": "rgba(0,0,0,0.1)", "sensor readings": "#f80"}) ) @@ -466,16 +435,10 @@ def pose_at(state, label): pose_dict = getattr(state, label) return Pose(jnp.array(pose_dict["p"]), jnp.array(pose_dict["hd"])) - def update_ideal_sensors(widget, _, label="pose"): - widget.state.update( - { - (label + "_readings"): ideal_sensor( - sensor_angles, pose_at(widget.state, label) - ) - } - ) - + widget.state.update({ + (label + "_readings"): ideal_sensor(sensor_angles, pose_at(widget.state, label)) + }) ( ( @@ -483,12 +446,10 @@ def update_ideal_sensors(widget, _, label="pose"): + plot_sensors(js("$state.pose"), js("$state.pose_readings"), sensor_angles) + pose_widget("pose", some_pose, color="blue") ) - | Plot.html( - js( - "`pose = Pose([${$state.pose.p.map((x) => x.toFixed(2))}], ${$state.pose.hd.toFixed(2)})`" - ) - ) - | Plot.initialState({"pose_readings": ideal_sensor(sensor_angles, some_pose)}) + | Plot.html(js("`pose = Pose([${$state.pose.p.map((x) => x.toFixed(2))}], ${$state.pose.hd.toFixed(2)})`")) + | Plot.initialState({ + "pose_readings": ideal_sensor(sensor_angles, some_pose) + }) | Plot.onChange({"pose": update_ideal_sensors}) ) @@ -498,17 +459,14 @@ def update_ideal_sensors(widget, _, label="pose"): # %% some_readings = jax.vmap(ideal_sensor, in_axes=(None, 0))(sensor_angles, some_poses) -Plot.Frames( - [ - ( - world_plot - + plot_sensors(pose, some_readings[i], sensor_angles) - + pose_plots(pose) - ) - for i, pose in enumerate(some_poses) - ], - fps=2, -) +Plot.Frames([ + ( + world_plot + + plot_sensors(pose, some_readings[i], sensor_angles) + + pose_plots(pose) + ) + for i, pose in enumerate(some_poses) +], fps=2) # %% [markdown] # ## First steps in modeling uncertainty using Gen @@ -524,15 +482,12 @@ def update_ideal_sensors(widget, _, label="pose"): # # Its declarative model in `Gen` starts with the case of just one sensor reading: - # %% @genjax.gen def sensor_model_one(pose, angle): return ( genjax.normal( - sensor_distance( - pose.rotate(angle), world["walls"], sensor_settings["box_size"] - ), + sensor_distance(pose.rotate(angle), world["walls"], sensor_settings["box_size"]), sensor_settings["s_noise"], ) @ "distance" @@ -580,7 +535,6 @@ def sensor_model_one(pose, angle): # # With a little wrapping, one gets a function of the same type as `ideal_sensor`, ignoring the PRNG key. - # %% def noisy_sensor(key, pose): return sensor_model.propose(key, (pose, sensor_angles))[2] @@ -590,10 +544,12 @@ def noisy_sensor(key, pose): def update_noisy_sensors(widget, _, label="pose"): k1, k2 = jax.random.split(jax.random.wrap_key_data(widget.state.k)) readings = noisy_sensor(k1, pose_at(widget.state, label)) - widget.state.update({"k": jax.random.key_data(k2), (label + "_readings"): readings}) + widget.state.update({ + "k": jax.random.key_data(k2), + (label + "_readings"): readings + }) return readings - key, k1, k2 = jax.random.split(key, 3) ( ( @@ -601,15 +557,11 @@ def update_noisy_sensors(widget, _, label="pose"): + plot_sensors(js("$state.pose"), js("$state.pose_readings"), sensor_angles) + pose_widget("pose", some_pose, color="blue") ) - | Plot.html( - js( - "`pose = Pose([${$state.pose.p.map((x) => x.toFixed(2))}], ${$state.pose.hd.toFixed(2)})`" - ) - ) - | Plot.initialState( - {"k": jax.random.key_data(k1), "pose_readings": noisy_sensor(k2, some_pose)}, - sync={"k"}, - ) + | Plot.html(js("`pose = Pose([${$state.pose.p.map((x) => x.toFixed(2))}], ${$state.pose.hd.toFixed(2)})`")) + | Plot.initialState({ + "k": jax.random.key_data(k1), + "pose_readings": noisy_sensor(k2, some_pose) + }, sync={"k"}) | Plot.onChange({"pose": update_noisy_sensors}) ) @@ -658,52 +610,39 @@ def update_noisy_sensors(widget, _, label="pose"): guess_pose = Pose(jnp.array([2.0, 16]), jnp.array(0.0)) target_pose = Pose(jnp.array([15.0, 4.0]), jnp.array(-1.6)) - def likelihood_function(cm, pose): return sensor_model.assess(cm, (pose, sensor_angles))[0] - def on_guess_pose_chage(widget, _): update_ideal_sensors(widget, None, label="guess") - widget.state.update( - { - "likelihood": likelihood_function( - C["distance"].set(widget.state.target_readings), - pose_at(widget.state, "guess"), - ) - } - ) - + widget.state.update({"likelihood": + likelihood_function( + C["distance"].set(widget.state.target_readings), + pose_at(widget.state, "guess") + ) + }) def on_target_pose_chage(widget, _): update_noisy_sensors(widget, None, label="target") - widget.state.update( - { - "likelihood": likelihood_function( - C["distance"].set(widget.state.target_readings), - pose_at(widget.state, "guess"), - ) - } - ) - + widget.state.update({"likelihood": + likelihood_function( + C["distance"].set(widget.state.target_readings), + pose_at(widget.state, "guess") + ) + }) ( Plot.Grid( ( world_plot - + plot_sensors( - js("$state.guess"), js("$state.target_readings"), sensor_angles - ) + + plot_sensors(js("$state.guess"), js("$state.target_readings"), sensor_angles) + pose_widget("guess", guess_pose, color="blue") - + Plot.cond( - js("$state.show_target_pose"), - pose_widget("target", target_pose, color="gold"), - ) + + Plot.cond(js("$state.show_target_pose"), + pose_widget("target", target_pose, color="gold")) ), ( Plot.rectY( - Plot.js( - """ + Plot.js(""" const data = []; for (let i = 0; i < $state.guess_readings.length; i++) { data.push({ @@ -718,82 +657,61 @@ def on_target_pose_chage(widget, _): }); } return data; - """, - expression=False, - ), + """, expression=False), x="sensor index", y="distance", fill="group", - interval=0.5, + interval=0.5 ) + Plot.domainY([0, 15]) + {"height": 300, "marginBottom": 50} - + Plot.color_map( - { - "wall distances from guess pose": "blue", - "sensor readings from hidden pose": "gold", - } - ) + + Plot.color_map({ + "wall distances from guess pose": "blue", + "sensor readings from hidden pose": "gold" + }) + Plot.colorLegend() + {"legend": {"anchor": "middle", "x": 0.5, "y": 1.2}} | [ "div", {"class": "text-lg mt-2 text-center w-full"}, - Plot.js( - "'log likelihood (greater is better): ' + $state.likelihood.toFixed(2)" - ), + Plot.js("'log likelihood (greater is better): ' + $state.likelihood.toFixed(2)") ] ), - cols=2, + cols=2 ) | ( - Plot.html( + Plot.html([ + "div", + {"class": "flex flex-col gap-4"}, [ - "div", - {"class": "flex flex-col gap-4"}, + "label", + {"class": "flex items-center gap-2 cursor-pointer"}, [ - "label", - {"class": "flex items-center gap-2 cursor-pointer"}, - [ - "input", - { - "type": "checkbox", - "checked": js("$state.show_target_pose"), - "onChange": js( - "(e) => $state.show_target_pose = e.target.checked" - ), - }, - ], - "show target pose", + "input", + { + "type": "checkbox", + "checked": js("$state.show_target_pose"), + "onChange": js("(e) => $state.show_target_pose = e.target.checked") + } ], + "show target pose" ] - ) - & Plot.html( - js( - "`guess = Pose([${$state.guess.p.map((x) => x.toFixed(2))}], ${$state.guess.hd.toFixed(2)})`" - ) - ) - & Plot.html( - js( - "`target = Pose([${$state.target.p.map((x) => x.toFixed(2))}], ${$state.target.hd.toFixed(2)})`" - ) - ) + ]) + & Plot.html(js("`guess = Pose([${$state.guess.p.map((x) => x.toFixed(2))}], ${$state.guess.hd.toFixed(2)})`")) + & Plot.html(js("`target = Pose([${$state.target.p.map((x) => x.toFixed(2))}], ${$state.target.hd.toFixed(2)})`")) ) | Plot.initialState( { "k": jax.random.key_data(k1), "guess_readings": ideal_sensor(sensor_angles, guess_pose), - "target_readings": ( - initial_target_readings := noisy_sensor(k3, target_pose) - ), - "likelihood": likelihood_function( - C["distance"].set(initial_target_readings), guess_pose - ), + "target_readings": (initial_target_readings := noisy_sensor(k3, target_pose)), + "likelihood": likelihood_function(C["distance"].set(initial_target_readings), guess_pose), "show_target_pose": False, - }, - sync={"k", "target_readings"}, - ) - | Plot.onChange({"guess": on_guess_pose_chage, "target": on_target_pose_chage}) + }, sync={"k", "target_readings"}) + | Plot.onChange({ + "guess": on_guess_pose_chage, + "target": on_target_pose_chage + }) ) @@ -802,81 +720,60 @@ def on_target_pose_chage(widget, _): # # The next few widgets all operate on the same principle. A manipulable "camera" pose is shown. Hitting the button at bottom fixes the "target" pose to the camera pose and samples a batch of sensor readings at the target. It then performs some computation (optimization or inference) on those sensor readings and displays everything: the target, its sensor readings, and the computation results. - # %% def on_camera_button(button_handler): def handler(widget, _): k1, k2 = jax.random.split(jax.random.wrap_key_data(widget.state.k)) - widget.state.update( - { - "k": jax.random.key_data(k1), - "target": widget.state.camera, - } - ) + widget.state.update({ + "k": jax.random.key_data(k1), + "target": widget.state.camera, + }) readings = update_noisy_sensors(widget, None, label="target") button_handler(widget, k2, readings) - widget.state.update( - { - "target_exists": True, - } - ) - + widget.state.update({ + "target_exists": True, + }) return handler - def camera_widget( - k, - camera_pose, - button_label, - button_handler, - result_plots=Plot.dot([jnp.sum(world["bounding_box"], axis=1)[0:2]], opacity=1), - bottom_elements=(), - initial_state={}, - sync=set(), -): + k, camera_pose, + button_label, button_handler, + result_plots=Plot.dot([jnp.sum(world["bounding_box"], axis=1)[0:2]], opacity=1), + bottom_elements=(), + initial_state={}, + sync=set()): return ( ( world_plot - + Plot.cond( - js("$state.target_exists"), + + Plot.cond(js("$state.target_exists"), result_plots - + plot_sensors( - js("$state.target"), js("$state.target_readings"), sensor_angles - ) - + pose_plots(js("$state.target"), color="red"), + + plot_sensors(js("$state.target"), js("$state.target_readings"), sensor_angles) + + pose_plots(js("$state.target"), color="red") ) + pose_widget("camera", camera_pose, color="blue") ) | ( - Plot.html( - [ - "div", - {"class": "flex flex-col gap-4"}, - [ - "button", # Changed from 'input' to 'button' - { - "class": "w-24 px-4 py-2 bg-blue-500 text-white rounded hover:bg-blue-600 active:bg-blue-700", - "onClick": on_camera_button(button_handler), - }, - button_label, # Text as child element, not label - ], - ] - ) - & Plot.html( - [ - "div", - Plot.js( - """`camera = Pose([${$state.camera.p.map((x) => x.toFixed(2))}], ${$state.camera.hd.toFixed(2)})`""" - ), - ] - ) - & Plot.html( + Plot.html([ + "div", + {"class": "flex flex-col gap-4"}, [ - "div", - Plot.js("""$state.target_exists ? - `target = Pose([${$state.target.p.map((x) => x.toFixed(2))}], ${$state.target.hd.toFixed(2)})` : ''"""), + "button", # Changed from 'input' to 'button' + { + "class": "w-24 px-4 py-2 bg-blue-500 text-white rounded hover:bg-blue-600 active:bg-blue-700", + "onClick": on_camera_button(button_handler) + }, + button_label # Text as child element, not label ] - ) + ]) + & Plot.html([ + "div", + Plot.js("""`camera = Pose([${$state.camera.p.map((x) => x.toFixed(2))}], ${$state.camera.hd.toFixed(2)})`""") + ]) + & Plot.html([ + "div", + Plot.js("""$state.target_exists ? + `target = Pose([${$state.target.p.map((x) => x.toFixed(2))}], ${$state.target.hd.toFixed(2)})` : ''""") + ]) & bottom_elements ) | Plot.initialState( @@ -885,10 +782,8 @@ def camera_widget( "target_exists": False, "target": {"p": None, "hd": None}, "target_readings": [], - } - | initial_state, - sync=({"k", "target", "camera_readings"} | sync), - ) + } | initial_state, + sync=({"k", "target", "camera_readings"} | sync)) ) @@ -899,22 +794,14 @@ def camera_widget( # # The idea here is just to search for the good poses by brute force, ranging over a suitable discretization grid of the map. So in this widget, the button fires a sweep over a grid of possible poses, computing the likelihood of each. The top `N_keep` are kept, and shown with opacity proportional to their position in that sublist. Moreover, the best fit is shown in purple. - # %% def make_grid(bounds, ns): - return [ - dim.reshape(-1) - for dim in jnp.meshgrid( - *(jnp.linspace(*bound, num=n) for (bound, n) in zip(bounds, ns)) - ) - ] - + return [dim.reshape(-1) for dim in jnp.meshgrid(*(jnp.linspace(*bound, num=n) for (bound, n) in zip(bounds, ns)))] def make_poses_grid_array(bounds, ns): grid_xs, grid_ys, grid_hds = make_grid(bounds, ns) return jnp.array([grid_xs, grid_ys]).T, grid_hds - def make_poses_grid(bounds, ns): return Pose(*make_poses_grid_array(bounds, ns)) @@ -928,21 +815,16 @@ def make_poses_grid(bounds, ns): camera_pose = Pose(jnp.array([2.0, 16]), jnp.array(0.0)) grid_poses = make_poses_grid(world["bounding_box"], N_grid) - - def grid_search_handler(widget, k, readings): jitted_likelihood = jax.jit( lambda pose: likelihood_function(C["distance"].set(readings), pose) ) likelihoods = jax.vmap(jitted_likelihood)(grid_poses) best = jnp.argsort(likelihoods, descending=True)[0:N_keep] - widget.state.update( - { - "grid_poses": grid_poses[best].as_dict(), - "best": grid_poses[best[0]].as_dict(), - } - ) - + widget.state.update({ + "grid_poses": grid_poses[best].as_dict(), + "best": grid_poses[best[0]].as_dict() + }) camera_widget( sub_key, @@ -950,22 +832,16 @@ def grid_search_handler(widget, k, readings): "grid search", grid_search_handler, result_plots=( - pose_plots( - js("$state.grid_poses"), - color="green", - opacity=jnp.arange(1.0, 0.0, -1.0 / N_keep), - ) + pose_plots(js("$state.grid_poses"), color="green", opacity=jnp.arange(1.0, 0.0, -1.0/N_keep)) + pose_plots(js("$state.best"), color="purple") ), bottom_elements=( - Plot.html( - [ - "div", - # For some reason `toFixed` very stubbonrly malfunctions in the following line: - Plot.js("""$state.target_exists ? - `best = Pose([${$state.best.p.map((x) => x.toFixed(2))}], ${$state.best.hd.toFixed(2)})` : ''"""), - ] - ) + Plot.html([ + "div", + # For some reason `toFixed` very stubbonrly malfunctions in the following line: + Plot.js("""$state.target_exists ? + `best = Pose([${$state.best.p.map((x) => x.toFixed(2))}], ${$state.best.hd.toFixed(2)})` : ''""") + ]) ), initial_state={ "grid_poses": {"p": [], "hd": []}, @@ -996,7 +872,6 @@ def grid_search_handler(widget, k, readings): grid_poses = make_poses_grid(world["bounding_box"], N_grid) - def grid_approximation_handler(widget, k, readings): jitted_likelihood = jax.jit( lambda pose: likelihood_function(C["distance"].set(readings), pose) @@ -1007,12 +882,9 @@ def grid_sample_one(k): return grid_poses[genjax.categorical.sample(k, likelihoods)] grid_samples = jax.vmap(grid_sample_one)(jax.random.split(k, N_samples)) - widget.state.update( - { - "sample_poses": grid_samples, - } - ) - + widget.state.update({ + "sample_poses": grid_samples, + }) camera_widget( sub_key, @@ -1041,7 +913,6 @@ def grid_sample_one(k): camera_pose = Pose(jnp.array([15.13, 14.16]), jnp.array(1.5)) - def importance_resampling_handler(widget, k, readings): jitted_likelihood = jax.jit( lambda pose: likelihood_function(C["distance"].set(readings), pose) @@ -1054,12 +925,9 @@ def importance_resample_one(k): return presamples[genjax.categorical.sample(k2, likelihoods)] grid_samples = jax.vmap(importance_resample_one)(jax.random.split(k, N_samples)) - widget.state.update( - { - "sample_poses": grid_samples, - } - ) - + widget.state.update({ + "sample_poses": grid_samples, + }) camera_widget( sub_key, @@ -1083,7 +951,6 @@ def importance_resample_one(k): camera_pose = Pose(jnp.array([15.13, 14.16]), jnp.array(1.5)) - def MCMC_handler(widget, k, readings): jitted_likelihood = jax.jit( lambda pose: likelihood_function(C["distance"].set(readings), pose) @@ -1099,31 +966,25 @@ def do_MH_step(pose_likelihood, k): new_p_hd = jax.random.uniform(k1, shape=(3,), minval=mins, maxval=maxs) new_pose = Pose(new_p_hd[0:2], new_p_hd[2]) new_likelihood = jitted_likelihood(new_pose) - accept = jnp.log(genjax.uniform.sample(k2)) <= new_likelihood - likelihood + accept = (jnp.log(genjax.uniform.sample(k2)) <= new_likelihood - likelihood) return ( jax.tree.map( lambda x, y: jnp.where(accept, x, y), (new_pose, new_likelihood), - (pose, likelihood), + (pose, likelihood) ), - None, + None ) - def sample_MH_one(k): k1, k2 = jax.random.split(k) start_pose = random_pose(k1) start_likelihood = jitted_likelihood(start_pose) - return jax.lax.scan( - do_MH_step, (start_pose, start_likelihood), jax.random.split(k2, N_MH_steps) - )[0][0] + return jax.lax.scan(do_MH_step, (start_pose, start_likelihood), jax.random.split(k2, N_MH_steps))[0][0] grid_samples = jax.vmap(sample_MH_one)(jax.random.split(k, N_samples)) - widget.state.update( - { - "sample_poses": grid_samples, - } - ) - + widget.state.update({ + "sample_poses": grid_samples, + }) camera_widget( sub_key, @@ -1144,14 +1005,12 @@ def sample_MH_one(k): # * an estimated initial pose (= position + heading), and # * a program of controls (= advance distance, followed by rotate heading). - # %% @pz.pytree_dataclass class Control(genjax.PythonicPytree): ds: FloatArray dhd: FloatArray - def load_robot_program(file_name): """ Loads the robot program from a specified file. @@ -1200,11 +1059,8 @@ def load_robot_program(file_name): # # If the motion of the robot is determined in an ideal manner by the controls, then we may simply integrate to determine the resulting path. Naïvely, this results in the following. - # %% -def diag(x): - return (x, x) - +def diag(x): return (x, x) def integrate_controls_unphysical(robot_inputs): """ @@ -1229,34 +1085,21 @@ def integrate_controls_unphysical(robot_inputs): # %% def update_unphysical_path(widget, _): start = pose_at(widget.state, "start") - widget.state.update( - {"path": integrate_controls_unphysical(robot_inputs | {"start": start})} - ) - + widget.state.update({ + "path": integrate_controls_unphysical(robot_inputs | {"start": start}) + }) ( ( world_plot - + pose_plots( - js("$state.path"), - color=Plot.constantly("path from integrating controls (UNphysical)"), - ) - + pose_widget( - "start", robot_inputs["start"], color=Plot.constantly("start pose") - ) - + Plot.color_map( - { - "start pose": "blue", - "path from integrating controls (UNphysical)": "green", - } - ) + + pose_plots(js("$state.path"), color=Plot.constantly("path from integrating controls (UNphysical)")) + + pose_widget("start", robot_inputs["start"], color=Plot.constantly("start pose")) + + Plot.color_map({"start pose": "blue", "path from integrating controls (UNphysical)": "green"}) ) - | Plot.html( - js( - "`start = Pose([${$state.start.p.map((x) => x.toFixed(2))}], ${$state.start.hd.toFixed(2)})`" - ) - ) - | Plot.initialState({"path": integrate_controls_unphysical(robot_inputs)}) + | Plot.html(js("`start = Pose([${$state.start.p.map((x) => x.toFixed(2))}], ${$state.start.hd.toFixed(2)})`")) + | Plot.initialState({ + "path": integrate_controls_unphysical(robot_inputs) + }) | Plot.onChange({"start": update_unphysical_path}) ) @@ -1266,7 +1109,6 @@ def update_unphysical_path(widget, _): # # We employ the following simple physics: when the robot's forward step through a control comes into contact with a wall, that step is interrupted and the robot instead "bounces" a fixed distance from the point of contact in the normal direction to the wall. - # %% @jax.jit def physical_step(p1: FloatArray, p2: FloatArray, hd): @@ -1297,9 +1139,7 @@ def physical_step(p1: FloatArray, p2: FloatArray, hd): collision_point = p1 + closest_wall_distance * step_pose.dp() wall_direction = closest_wall[1] - closest_wall[0] normalized_wall_direction = wall_direction / jnp.linalg.norm(wall_direction) - wall_normal = jnp.array( - [-normalized_wall_direction[1], normalized_wall_direction[0]] - ) + wall_normal = jnp.array([-normalized_wall_direction[1], normalized_wall_direction[0]]) # Ensure wall_normal points away from the robot's direction wall_normal = jnp.where( @@ -1316,7 +1156,6 @@ def physical_step(p1: FloatArray, p2: FloatArray, hd): return Pose(final_position, hd) - def integrate_controls_physical(robot_inputs): """ Integrates controls to generate a path, taking into account physical interactions with walls. @@ -1328,11 +1167,9 @@ def integrate_controls_physical(robot_inputs): - Pose: A Pose object representing the path taken by applying the controls. """ return jax.lax.scan( - lambda pose, control: diag( - physical_step( + lambda pose, control: diag(physical_step( pose.p, pose.p + control.ds * pose.dp(), pose.hd + control.dhd - ) - ), + )), robot_inputs["start"], robot_inputs["controls"], )[1] @@ -1345,35 +1182,24 @@ def integrate_controls_physical(robot_inputs): # %% def update_physical_path(widget, _): start = pose_at(widget.state, "start") - widget.state.update( - {"path": integrate_controls_physical(robot_inputs | {"start": start})} - ) - + widget.state.update({ + "path": integrate_controls_physical(robot_inputs | {"start": start}) + }) ( ( world_plot - + pose_plots( - js("$state.path"), - color=Plot.constantly("path from integrating controls (physical)"), - ) - + pose_widget( - "start", robot_inputs["start"], color=Plot.constantly("start pose") - ) - + Plot.color_map( - {"start pose": "blue", "path from integrating controls (physical)": "green"} - ) + + pose_plots(js("$state.path"), color=Plot.constantly("path from integrating controls (physical)")) + + pose_widget("start", robot_inputs["start"], color=Plot.constantly("start pose")) + + Plot.color_map({"start pose": "blue", "path from integrating controls (physical)": "green"}) ) - | Plot.html( - js( - "`start = Pose([${$state.start.p.map((x) => x.toFixed(2))}], ${$state.start.hd.toFixed(2)})`" - ) - ) - | Plot.initialState({"path": integrate_controls_physical(robot_inputs)}) + | Plot.html(js("`start = Pose([${$state.start.p.map((x) => x.toFixed(2))}], ${$state.start.hd.toFixed(2)})`")) + | Plot.initialState({ + "path": integrate_controls_physical(robot_inputs) + }) | Plot.onChange({"start": update_physical_path}) ) - # %% [markdown] # ### Modeling taking steps # @@ -1402,7 +1228,6 @@ def step_model(motion_settings, start, control): key, k1, k2 = jax.random.split(key, 3) - def confidence_circle(p, p_noise): return Plot.ellipse( p, @@ -1410,7 +1235,6 @@ def confidence_circle(p, p_noise): fill=Plot.constantly("95% confidence region"), ) + Plot.color_map({"95% confidence region": "rgba(255,0,0,0.25)"}) - def update_confidence_circle(widget, _): step = pose_at(widget.state, "step") step_vector = step.p - robot_inputs["start"].p @@ -1421,62 +1245,46 @@ def update_confidence_circle(widget, _): ds = jnp.linalg.norm(step_vector) dhd = (step.hd - tilted_start_hd + jnp.pi) % (2.0 * jnp.pi) - jnp.pi - widget.state.update( - {"start": tilted_start.as_dict(), "control": {"ds": ds, "dhd": dhd}} - ) + widget.state.update({ + "start": tilted_start.as_dict(), + "control": {"ds": ds, "dhd": dhd} + }) k1, k2 = jax.random.split(jax.random.wrap_key_data(widget.state.k)) samples = jax.vmap(step_model.propose, in_axes=(0, None))( jax.random.split(k1, N_samples), (default_motion_settings, tilted_start, Control(ds, dhd)), )[2] - widget.state.update({"k": jax.random.key_data(k2), "samples": samples.as_dict()}) - + widget.state.update({ + "k": jax.random.key_data(k2), + "samples": samples.as_dict() + }) ( ( world_plot + confidence_circle(js("[$state.step.p]"), default_motion_settings["p_noise"]) - + pose_plots( - js("$state.samples"), color=Plot.constantly("samples from the step model") - ) + + pose_plots(js("$state.samples"), color=Plot.constantly("samples from the step model")) + pose_plots(js("$state.start"), color=Plot.constantly("start pose")) - + pose_widget( - "step", - robot_inputs["start"], - color=Plot.constantly("attempt to step to here"), - ) - + Plot.color_map( - { - "start pose": "black", - "attempt to step to here": "blue", - "samples from the step model": "green", - } - ) - ) - | Plot.html( - js( - "`control = Control(${$state.control.ds.toFixed(2)}, ${$state.control.dhd.toFixed(2)})`" - ) - ) - | Plot.initialState( - { - "start": robot_inputs["start"].as_dict(), - "control": {"ds": 0.0, "dhd": 0.0}, - "k": jax.random.key_data(k1), - "samples": ( - jax.vmap(step_model.propose, in_axes=(0, None))( - jax.random.split(k2, N_samples), - ( - default_motion_settings, - robot_inputs["start"], - robot_inputs["controls"][0], - ), - )[2].as_dict() - ), - }, - sync={"k"}, + + pose_widget("step", robot_inputs["start"], color=Plot.constantly("attempt to step to here")) + + Plot.color_map({ + "start pose": "black", + "attempt to step to here": "blue", + "samples from the step model": "green", + }) ) + | Plot.html(js("`control = Control(${$state.control.ds.toFixed(2)}, ${$state.control.dhd.toFixed(2)})`")) + | Plot.initialState({ + "start": robot_inputs["start"].as_dict(), + "control": {"ds": 0.0, "dhd": 0.0}, + "k": jax.random.key_data(k1), + "samples": ( + jax.vmap(step_model.propose, in_axes=(0, None))( + jax.random.split(k2, N_samples), + (default_motion_settings, robot_inputs["start"], robot_inputs["controls"][0]), + )[2].as_dict() + ), + }, sync={"k"}) | Plot.onChange({"step": update_confidence_circle}) ) @@ -1505,7 +1313,6 @@ def update_confidence_circle(widget, _): # %% [markdown] # Here is a single path with confidence circles on each step's draw. - # %% def plot_path_with_confidence(path, step): prev_step = robot_inputs["start"] if step == 0 else path[step - 1] @@ -1513,21 +1320,19 @@ def plot_path_with_confidence(path, step): world_plot + confidence_circle( [prev_step.apply_control(robot_inputs["controls"][step]).p], - default_motion_settings["p_noise"], + default_motion_settings["p_noise"] ) + [pose_plots(path[i]) for i in range(step)] + pose_plots(path[step], color=Plot.constantly("next pose")) + Plot.color_map({"previous poses": "black", "next pose": "green"}) ) - key, sample_key = jax.random.split(key) -path = path_model.propose( - sample_key, (robot_inputs["start"], robot_inputs["controls"]) -)[2][1] +path = path_model.propose(sample_key, (robot_inputs["start"], robot_inputs["controls"]))[2][1] Plot.Frames( [ - plot_path_with_confidence(path, step) + Plot.title("Motion model (samples)") + plot_path_with_confidence(path, step) + + Plot.title("Motion model (samples)") for step in range(len(path)) ], fps=2, @@ -1541,20 +1346,14 @@ def plot_path_with_confidence(path, step): key, sub_key = jax.random.split(key) sample_paths = jax.vmap( - lambda k: path_model.propose(k, (robot_inputs["start"], robot_inputs["controls"]))[ - 2 - ][1] + lambda k: + path_model.propose(k, (robot_inputs["start"], robot_inputs["controls"]))[2][1] )(jax.random.split(sub_key, N_samples)) -Plot.html( - [ - "div.grid.grid-cols-2.gap-4", - *[ - walls_plot + pose_plots(path) + {"maxWidth": 300, "aspectRatio": 1} - for path in sample_paths - ], - ] -) +Plot.html([ + "div.grid.grid-cols-2.gap-4", + *[walls_plot + pose_plots(path) + {"maxWidth": 300, "aspectRatio": 1} for path in sample_paths] +]) # %% [markdown] @@ -1562,7 +1361,6 @@ def plot_path_with_confidence(path, step): # # We fold the sensor model into the motion model to form a "full model", whose traces describe simulations of the entire robot situation as we have described it. - # %% @genjax.gen def full_model_kernel(motion_settings, state, control): @@ -1570,7 +1368,6 @@ def full_model_kernel(motion_settings, state, control): sensor_model(pose, sensor_angles) @ "sensor" return pose - @genjax.gen def full_model(motion_settings): return ( @@ -1603,23 +1400,15 @@ def full_model(motion_settings): # %% [markdown] # By this point, visualization is essential. We will just get a quick picture here, and turn toward a more principled approach immediately thereafter. - # %% def animate_path_and_sensors(path, readings, motion_settings, frame_key=None): - return Plot.Frames( - [ - plot_path_with_confidence(path, step) - + plot_sensors(pose, readings[step], sensor_angles) - for step, pose in enumerate(path) - ], - fps=2, - key=frame_key, - ) - + return Plot.Frames([ + plot_path_with_confidence(path, step) + + plot_sensors(pose, readings[step], sensor_angles) + for step, pose in enumerate(path) + ], fps=2, key=frame_key) -animate_path_and_sensors( - retval[1], cm["steps", "sensor", "distance"], default_motion_settings -) +animate_path_and_sensors(retval[1], cm["steps", "sensor", "distance"], default_motion_settings) # %% [markdown] # ## From choicemaps to traces @@ -1664,7 +1453,12 @@ def animate_path_and_sensors(path, readings, motion_settings, frame_key=None): # %% key, sub_key = jax.random.split(key) -selections = [genjax.Selection.none(), S["p"], S["hd"], S["p"] | S["hd"]] +selections = [ + genjax.Selection.none(), + S["p"], + S["hd"], + S["p"] | S["hd"] +] [ trace.project(k, sel) for k, sel in zip(jax.random.split(sub_key, len(selections)), selections) @@ -1742,16 +1536,13 @@ def animate_path_and_sensors(path, readings, motion_settings, frame_key=None): # %% [markdown] # In addition to the handy data structure inspection `pz.ts.display(trace)` shown above, it is important to develop, ongoing in one's work, visualization code *as a function of the trace*, since all the information is on one place here, and *in the context of the human meaning of the information*. - # %% def get_path(trace): return trace.get_retval()[1] - def get_sensors(trace): return trace.get_choices()["steps", "sensor", "distance"] - def animate_full_trace(trace, frame_key=None): path = get_path(trace) readings = get_sensors(trace) @@ -1760,7 +1551,6 @@ def animate_full_trace(trace, frame_key=None): path, readings, motion_settings, frame_key=frame_key ) - key, sub_key = jax.random.split(key) tr = full_model.simulate(sub_key, (default_motion_settings,)) @@ -1798,12 +1588,8 @@ def animate_full_trace(trace, frame_key=None): observations_low_deviation = get_sensors(trace_low_deviation) observations_high_deviation = get_sensors(trace_high_deviation) -constraints_low_deviation = C["steps", "sensor", "distance"].set( - observations_low_deviation -) -constraints_high_deviation = C["steps", "sensor", "distance"].set( - observations_high_deviation -) +constraints_low_deviation = C["steps", "sensor", "distance"].set(observations_low_deviation) +constraints_high_deviation = C["steps", "sensor", "distance"].set(observations_high_deviation) # %% [markdown] # We summarize the information available to the robot to determine its location: @@ -1823,15 +1609,10 @@ def plt(readings): return plt(readings1) & plt(readings2) - return Plot.Frames( - [ - frame(*scene) - for scene in zip( - path, observations_low_deviation, observations_high_deviation - ) - ], - fps=2, - ) + return Plot.Frames([ + frame(*scene) + for scene in zip(path, observations_low_deviation, observations_high_deviation) + ], fps=2) animate_bare_sensors(itertools.repeat(Pose(world["center_point"], 0.0))) @@ -1908,7 +1689,6 @@ def plt(readings): # # In words, the data are incongruously unlikely for the integrated path. The (log) density of the measurement data, given the integrated path... - # %% def constraint_from_path(path): c_ps = jax.vmap(lambda ix, p: C["steps", ix, "pose", "p"].set(p))( @@ -1920,7 +1700,6 @@ def constraint_from_path(path): ) return c_ps + c_hds - constraints_path_integrated = constraint_from_path(path_integrated) constraints_path_integrated_observations_low_deviation = ( constraints_path_integrated ^ constraints_low_deviation @@ -1942,29 +1721,27 @@ def constraint_from_path(path): (motion_settings_high_deviation,), ) -Plot.Row( - *[ - ( - html("div.f3.b.tc", title) - | animate_full_trace(trace, frame_key="frame") - | html("span.tc", f"score: {score:,.2f}") - ) - for (title, trace, motion_settings, score) in ( - [ - "Low deviation", - trace_path_integrated_observations_low_deviation, - motion_settings_low_deviation, - w_low, - ], - [ - "High deviation", - trace_path_integrated_observations_high_deviation, - motion_settings_high_deviation, - w_high, - ], - ) - ] -) | Plot.Slider("frame", 0, T, fps=2) +Plot.Row(*[ + ( + html("div.f3.b.tc", title) + | animate_full_trace(trace, frame_key="frame") + | html("span.tc", f"score: {score:,.2f}") + ) + for (title, trace, motion_settings, score) in ( + [ + "Low deviation", + trace_path_integrated_observations_low_deviation, + motion_settings_low_deviation, + w_low, + ], + [ + "High deviation", + trace_path_integrated_observations_high_deviation, + motion_settings_high_deviation, + w_high, + ], +) +]) | Plot.Slider("frame", 0, T, fps=2) # %% [markdown] # ...more closely resembles the density of these data back-fitted onto typical (random) paths of the model. @@ -1995,8 +1772,14 @@ def constraint_from_path(path): Plot.new( world_plot, - [pose_plots(pose, color="blue", opacity=0.1) for pose in high_deviation_paths[:20]], - [pose_plots(pose, color="green", opacity=0.1) for pose in low_deviation_paths[:20]], + [ + pose_plots(pose, color="blue", opacity=0.1) + for pose in high_deviation_paths[:20] + ], + [ + pose_plots(pose, color="green", opacity=0.1) + for pose in low_deviation_paths[:20] + ], ) # %% [markdown] @@ -2082,7 +1865,6 @@ def constraint_from_path(path): N_presamples = 2000 N_samples = 20 - def importance_sample( key: PRNGKey, constraints: genjax.ChoiceMap, motion_settings, N: int, K: int ): @@ -2107,19 +1889,11 @@ def importance_sample( key, sub_key = jax.random.split(key) low_posterior = jit_resample( - sub_key, - constraints_low_deviation, - motion_settings_low_deviation, - N_presamples, - N_samples, + sub_key, constraints_low_deviation, motion_settings_low_deviation, N_presamples, N_samples ) key, sub_key = jax.random.split(key) high_posterior = jit_resample( - sub_key, - constraints_high_deviation, - motion_settings_high_deviation, - N_presamples, - N_samples, + sub_key, constraints_high_deviation, motion_settings_high_deviation, N_presamples, N_samples ) @@ -2132,7 +1906,6 @@ def path_to_polyline(path, **options): else: return Plot.dot([path.p], fill=options["stroke"], r=2, **options) - ( world_plot + [ @@ -2149,7 +1922,9 @@ def path_to_polyline(path, **options): + pose_plots( path_high_deviation, fill=Plot.constantly("high deviation path"), opacity=0.2 ) - + pose_plots(path_integrated, fill=Plot.constantly("integrated path"), opacity=0.2) + + pose_plots( + path_integrated, fill=Plot.constantly("integrated path"), opacity=0.2 + ) + Plot.color_map( { "low deviation path": "green", @@ -2195,7 +1970,6 @@ def path_to_polyline(path, **options): StateT = TypeVar("StateT") ControlT = TypeVar("ControlT") - class SISwithRejuvenation(Generic[StateT, ControlT]): """ Given: @@ -2234,6 +2008,7 @@ class SISwithRejuvenation(Generic[StateT, ControlT]): of shape (N, T). """ + def __init__( self, init: StateT, @@ -2243,10 +2018,8 @@ def __init__( [PRNGKey, StateT, ControlT, Array], tuple[genjax.Trace[StateT], float] ], rejuvenate: Callable[ - [PRNGKey, genjax.Trace[StateT], Array, StateT, ControlT], - tuple[genjax.Trace[StateT], float], - ] - | None = None, + [PRNGKey, genjax.Trace[StateT], Array, StateT, ControlT], tuple[genjax.Trace[StateT], float] + ] | None = None, ): self.importance = jax.jit(importance) self.rejuvenate = jax.jit(rejuvenate) if rejuvenate else None @@ -2254,6 +2027,7 @@ def __init__( self.controls = controls self.observations = observations + class Result(Generic[StateT]): """This object contains all of the information generated by the SIS scan, and offers some convenient methods to reconstruct the paths explored @@ -2294,24 +2068,27 @@ def backtrack(self) -> list[list[StateT]]: p.reverse() return paths + def run(self, key: PRNGKey, N: int) -> dict: def step(state, update): particles, log_weights = state key, control, observation = update ks = jax.random.split(key, (3, N)) - samples, log_weight_increments = jax.vmap( - self.importance, in_axes=(0, 0, None, None) - )(ks[0], particles, control, observation) + samples, log_weight_increments = jax.vmap(self.importance, in_axes=(0, 0, None, None))( + ks[0], particles, control, observation + ) indices = jax.vmap(genjax.categorical.sampler, in_axes=(0, None))( ks[1], log_weights + log_weight_increments ) - (resamples, antecedents) = jax.tree.map( - lambda v: v[indices], (samples, particles) - ) + (resamples, antecedents) = jax.tree.map(lambda v: v[indices], (samples, particles)) if self.rejuvenate: - rejuvenated, new_log_weights = jax.vmap( - self.rejuvenate, in_axes=(0, 0, 0, None, None) - )(ks[2], resamples, antecedents, control, observation) + rejuvenated, new_log_weights = jax.vmap(self.rejuvenate, in_axes=(0, 0, 0, None, None))( + ks[2], + resamples, + antecedents, + control, + observation + ) else: rejuvenated, new_log_weights = resamples, jnp.zeros(log_weights.shape) return (rejuvenated.get_retval(), new_log_weights), (samples, indices) @@ -2330,7 +2107,6 @@ def step(state, update): ) return SISwithRejuvenation.Result(N, end, samples, indices) - # %% def localization_sis(motion_settings, observations): return SISwithRejuvenation( @@ -2363,17 +2139,14 @@ def localization_sis(motion_settings, observations): # %% N_particles = 100 - def pose_list_to_plural_pose(pl: list[Pose]) -> Pose: return Pose(jnp.array([pose.p for pose in pl]), [pose.hd for pose in pl]) - key, sub_key = jax.random.split(key) smc_result = localization_sis( motion_settings_high_deviation, observations_high_deviation ).run(sub_key, N_particles) - def plot_sis_result(ground_truth, smc_result): return ( world_plot @@ -2384,7 +2157,6 @@ def plot_sis_result(ground_truth, smc_result): ] ) - plot_sis_result(path_high_deviation, smc_result) # %% N_particles = 20 @@ -2402,26 +2174,20 @@ def plot_sis_result(ground_truth, smc_result): # # While we perturb the particles, we update the weight using the *SMCP3* rule. - # %% # This is the general SMCP3 algorithm in the case where there is no Jacobian term. def run_SMCP3_step(fwd_proposal, bwd_proposal, key, sample, proposal_args): k1, k2 = jax.random.split(key, 2) - _, fwd_proposal_weight, (fwd_update, bwd_choices) = fwd_proposal.propose( - k1, (sample, proposal_args) - ) + _, fwd_proposal_weight, (fwd_update, bwd_choices) = fwd_proposal.propose(k1, (sample, proposal_args)) new_sample, model_weight_diff, _, _ = sample.update(k2, fwd_update) - bwd_proposal_weight, _ = bwd_proposal.assess( - bwd_choices, (new_sample, proposal_args) - ) + bwd_proposal_weight, _ = bwd_proposal.assess(bwd_choices, (new_sample, proposal_args)) new_log_weight = model_weight_diff + bwd_proposal_weight - fwd_proposal_weight return new_sample, new_log_weight - # Forward proposal searches a nearby grid around the sample, # and returns an importance-resampled member. # The joint density (= the density from the full model) serves as @@ -2432,12 +2198,13 @@ def grid_fwd_proposal(sample, args): observation_cm = C["sensor", "distance"].set(observation) log_weights = jax.vmap( - lambda p, hd: full_model_kernel.assess( - observation_cm - | C["pose", "p"].set(p + sample.get_retval().p) - | C["pose", "hd"].set(hd + sample.get_retval().hd), - model_args, - )[0] + lambda p, hd: + full_model_kernel.assess( + observation_cm + | C["pose", "p"].set(p + sample.get_retval().p) + | C["pose", "hd"].set(hd + sample.get_retval().hd), + model_args + )[0] )(*base_grid) fwd_index = genjax.categorical(log_weights) @ "fwd_index" @@ -2446,21 +2213,21 @@ def grid_fwd_proposal(sample, args): C["pose", "p"].set(base_grid[0][fwd_index] + sample.get_retval().p) | C["pose", "hd"].set(base_grid[1][fwd_index] + sample.get_retval().hd) ), - C["bwd_index"].set(len(log_weights) - 1 - fwd_index), + C["bwd_index"].set(len(log_weights) - 1 - fwd_index) ) - # Backwards proposal simply guesses according to the prior over steps, nothing fancier. @genjax.gen def grid_bwd_proposal(new_sample, args): base_grid, _, model_args = args log_weights = jax.vmap( - lambda p, hd: step_model.assess( - C["p"].set(p + new_sample.get_retval().p) - | C["hd"].set(hd + new_sample.get_retval().hd), - model_args, - )[0] + lambda p, hd: + step_model.assess( + C["p"].set(p + new_sample.get_retval().p) + | C["hd"].set(hd + new_sample.get_retval().hd), + model_args + )[0] )(*base_grid) _ = genjax.categorical(log_weights) @ "bwd_index" @@ -2470,7 +2237,10 @@ def grid_bwd_proposal(new_sample, args): # %% def localization_sis_plus_grid_rejuv(motion_settings, M_grid, N_grid, observations): - base_grid = make_poses_grid_array(jnp.array([M_grid / 2.0, M_grid / 2.0]).T, N_grid) + base_grid = make_poses_grid_array( + jnp.array([M_grid / 2.0, M_grid / 2.0]).T, + N_grid + ) return SISwithRejuvenation( robot_inputs["start"], robot_inputs["controls"], @@ -2485,7 +2255,7 @@ def localization_sis_plus_grid_rejuv(motion_settings, M_grid, N_grid, observatio grid_bwd_proposal, key, sample, - (base_grid, observation, (motion_settings, pose, control)), + (base_grid, observation, (motion_settings, pose, control)) ), ) @@ -2495,7 +2265,7 @@ def localization_sis_plus_grid_rejuv(motion_settings, M_grid, N_grid, observatio # %% N_particles = 100 -M_grid = jnp.array([0.5, 0.5, jnp.pi / 600.0]) +M_grid = jnp.array([0.5, 0.5, jnp.pi/600.0]) N_grid = jnp.array([15, 15, 15]) key, sub_key = jax.random.split(key) @@ -2506,6 +2276,4 @@ def localization_sis_plus_grid_rejuv(motion_settings, M_grid, N_grid, observatio motion_settings_high_deviation, observations_high_deviation ).run(sub_key, N_particles) -plot_sis_result(path_high_deviation, smc_result) | plot_sis_result( - path_high_deviation, imp_result -) +plot_sis_result(path_high_deviation, smc_result) | plot_sis_result(path_high_deviation, imp_result)