Inference Move API #178

Merged · 1 commit · Sep 16, 2024
237 changes: 205 additions & 32 deletions notebooks/bayes3d_paper/old_inference_algorithm.ipynb

Large diffs are not rendered by default.

75 changes: 48 additions & 27 deletions notebooks/bayes3d_paper/run_ycbv_evaluation.py
File mode changed: 100644 → 100755
@@ -1,8 +1,9 @@
#!/usr/bin/env python

import os

import b3d
import b3d.chisight.gen3d.image_kernel as image_kernel
import b3d.chisight.gen3d.pixel_kernels.pixel_rgbd_kernels as pixel_rgbd_kernels
import b3d.chisight.gen3d.transition_kernels as transition_kernels
import fire
import genjax
@@ -20,10 +21,13 @@


def run_tracking(scene=None, object=None, debug=False):
import b3d

FRAME_RATE = 50

ycb_dir = os.path.join(b3d.get_assets_path(), "bop/ycbv")
b3d.rr_init("run_ycbv_evaluation")
b3d.utils.rr_init("run_ycbv_evaluation")

ycb_dir = os.path.join(b3d.utils.get_assets_path(), "bop/ycbv")

if scene is None:
scenes = range(48, 60)
@@ -33,25 +37,29 @@ def run_tracking(scene=None, object=None, debug=False):
scenes = scene

hyperparams = {
"pose_kernel": transition_kernels.UniformPoseDriftKernel(max_shift=0.1),
"pose_kernel": transition_kernels.GaussianVMFPoseDriftKernel(
variance=0.02, concentration=1000.0
),
"color_kernel": transition_kernels.LaplaceNotTruncatedColorDriftKernel(
scale=0.15
scale=0.02
),
"visibility_prob_kernel": transition_kernels.DiscreteFlipKernel(
resample_probability=0.05, possible_values=jnp.array([0.01, 0.99])
resample_probability=0.05, support=jnp.array([1e-5, 1.0 - 1e-5])
),
"depth_nonreturn_prob_kernel": transition_kernels.DiscreteFlipKernel(
resample_probability=0.05, possible_values=jnp.array([0.01, 0.99])
resample_probability=0.05, support=jnp.array([1e-5, 1.0 - 1e-5])
),
"depth_scale_kernel": transition_kernels.DiscreteFlipKernel(
resample_probability=0.05, possible_values=jnp.array([0.0025, 0.01, 0.02])
resample_probability=0.05,
support=jnp.array([0.01, 0.005, 0.01, 0.02]),
),
"color_scale_kernel": transition_kernels.DiscreteFlipKernel(
resample_probability=0.05, possible_values=jnp.array([0.05, 0.1, 0.15])
resample_probability=0.05, support=jnp.array([0.001])
),
"image_kernel": image_kernel.NoOcclusionPerVertexImageKernel(
pixel_rgbd_kernels.OldOcclusionPixelRGBDDistribution()
),
"image_likelihood": image_kernel.SimpleNoRenderImageLikelihood(),
}
info_from_trace = hyperparams["image_likelihood"].info_from_trace

for scene_id in scenes:
print(f"Scene {scene_id}")
@@ -119,25 +127,30 @@ def run_tracking(scene=None, object=None, debug=False):
model_vertices = model_vertices[subset]
model_colors = model_colors[subset]

hyperparams["intrinsics"] = {
"fx": fx,
"fy": fy,
"cx": cx,
"cy": cy,
"image_height": Pytree.const(image_height),
"image_width": Pytree.const(image_width),
"near": 0.01,
"far": 3.0,
}
hyperparams["vertices"] = model_vertices

num_vertices = model_vertices.shape[0]
previous_state = {
"pose": template_pose,
"colors": model_colors,
"visibility_prob": jnp.ones(num_vertices)
* hyperparams["visibility_prob_kernel"].possible_values[-1],
* hyperparams["visibility_prob_kernel"].support[-1],
"depth_nonreturn_prob": jnp.ones(num_vertices)
* hyperparams["depth_nonreturn_prob_kernel"].possible_values[0],
"depth_scale": hyperparams["depth_scale_kernel"].possible_values[0],
"color_scale": hyperparams["color_scale_kernel"].possible_values[0],
* hyperparams["depth_nonreturn_prob_kernel"].support[0],
"depth_scale": hyperparams["depth_scale_kernel"].support[0],
"color_scale": hyperparams["color_scale_kernel"].support[0],
}

hyperparams["vertices"] = model_vertices
hyperparams["fx"] = fx
hyperparams["fy"] = fy
hyperparams["cx"] = cx
hyperparams["cy"] = cy
hyperparams["image_height"] = Pytree.const(image_height)
hyperparams["image_width"] = Pytree.const(image_width)
choicemap = (
genjax.ChoiceMap.d(
{
@@ -160,12 +173,19 @@
key, choicemap, (hyperparams, previous_state)
)[0]

from b3d.chisight.gen3d.inference import inference_step
import b3d.chisight.gen3d.inference as inference
import b3d.chisight.gen3d.inference_old as inference_old
import b3d.chisight.gen3d.settings

inference_hyperparams = b3d.chisight.gen3d.settings.inference_hyperparams

### Run inference ###
for T in tqdm(range(len(all_data))):
key = b3d.split_key(key)
trace = inference_step(trace, key, all_data[T]["rgbd"])
trace = inference.advance_time(key, trace, all_data[T]["rgbd"])
trace = inference_old.inference_step(trace, key, inference_hyperparams)[
0
]

Review comment (Collaborator):
@nishadgothoskar for tonight, can we aim to launch at least 1 run that is using the new inference algorithm? This will be through a call to inference.inference_step_c2f. I can add this to the script if helpful.

Reply (Collaborator, Author):
yes definitely. that is the plan.
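A minimal sketch (not part of this PR's diff) of how the suggested coarse-to-fine call might slot into the loop above. It assumes inference.inference_step_c2f takes the same (trace, key, inference_hyperparams) arguments and returns a tuple whose first element is the trace, mirroring inference_old.inference_step; the actual c2f signature may differ. trace, key, and all_data are taken from the surrounding script.

    import b3d
    import b3d.chisight.gen3d.inference as inference
    import b3d.chisight.gen3d.settings
    from tqdm import tqdm

    inference_hyperparams = b3d.chisight.gen3d.settings.inference_hyperparams
    tracking_results = {}
    for T in tqdm(range(len(all_data))):
        key = b3d.split_key(key)
        # Advance the model one timestep with the new RGBD observation.
        trace = inference.advance_time(key, trace, all_data[T]["rgbd"])
        # Hypothetical: new coarse-to-fine inference move; the argument order
        # and the [0] indexing are assumptions carried over from the old API.
        trace = inference.inference_step_c2f(trace, key, inference_hyperparams)[0]
        tracking_results[T] = trace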
tracking_results[T] = trace

if debug:
@@ -190,14 +210,15 @@ def run_tracking(scene=None, object=None, debug=False):
)

trace = tracking_results[len(all_data) - 1]
info = info_from_trace(trace)
rendered_rgbd = info["latent_rgbd"]
latent_rgb = b3d.chisight.gen3d.image_kernel.get_latent_rgb_image(
trace.get_retval()["new_state"], trace.get_args()[0]
)

a = b3d.viz_rgb(
trace.get_choices()["rgbd"][..., :3],
)
b = b3d.viz_rgb(
rendered_rgbd[..., :3],
latent_rgb[..., :3],
)
b3d.multi_panel([a, b, b3d.overlay_image(a, b)]).save(
f"photo_SCENE_{scene_id}_OBJECT_INDEX_{OBJECT_INDEX}_POSES.png"
78 changes: 40 additions & 38 deletions notebooks/gen3d/interactive_visualization.ipynb
@@ -14,7 +14,8 @@
"import b3d\n",
"import jax.numpy as jnp\n",
"import pytest\n",
"import matplotlib.pyplot as plt"
"import matplotlib.pyplot as plt\n",
"from genjax import Pytree"
]
},
{
@@ -23,56 +24,45 @@
"metadata": {},
"outputs": [],
"source": [
"near, far, image_height, image_width = 0.001, 1.0, 480, 640\n",
"img_model = image_kernel.NoOcclusionPerVertexImageKernel(\n",
" near, far, image_height, image_width\n",
")\n",
"\n",
"inference_hyperparams = inference.InferenceHyperparams(\n",
" n_poses=1500,\n",
" do_stochastic_color_proposals=True,\n",
" pose_proposal_std=0.04,\n",
" pose_proposal_conc=1000.,\n",
" prev_color_proposal_laplace_scale=0.001,\n",
" obs_color_proposal_laplace_scale=0.001,\n",
")\n",
"\n",
"hyperparams = {\n",
" \"pose_kernel\": transition_kernels.UniformPoseDriftKernel(max_shift=0.1),\n",
" \"color_kernel\": transition_kernels.LaplaceNotTruncatedColorDriftKernel(\n",
" scale= 0.05\n",
" ),\n",
" \"visibility_prob_kernel\": transition_kernels.DiscreteFlipKernel(\n",
" resample_probability=0.1, support=jnp.array([0.01, 0.99])\n",
" ),\n",
" \"depth_nonreturn_prob_kernel\": transition_kernels.DiscreteFlipKernel(\n",
" resample_probability=0.1, support=jnp.array([0.01, 0.99])\n",
" ),\n",
" \"depth_scale_kernel\": transition_kernels.DiscreteFlipKernel(\n",
" resample_probability=0.1,\n",
" support=jnp.array([0.0025, 0.01, 0.02, 0.1, 0.4, 1.0]),\n",
" ),\n",
" \"color_scale_kernel\": transition_kernels.DiscreteFlipKernel(\n",
" resample_probability=0.1, support=jnp.array([0.002, 0.01, 0.025, 0.05, 0.1, 0.15, 0.3, 0.8])\n",
" ),\n",
" \"image_kernel\": img_model,\n",
"}"
"import b3d.chisight.gen3d.settings \n",
"b3d.reload(b3d.chisight.gen3d.settings)\n",
"import b3d.chisight.gen3d.settings \n",
"inference_hyperparams = b3d.chisight.gen3d.settings.inference_hyperparams\n",
"hyperparams = b3d.chisight.gen3d.settings.hyperparams"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"hyperparams[\"intrinsics\"] = {\n",
" \"fx\": 100.0,\n",
" \"fy\": 100.0,\n",
" \"cx\": 50.0,\n",
" \"cy\": 50.0,\n",
" \"near\": 0.01,\n",
" \"far\": 3.0,\n",
" \"image_width\": Pytree.const(100),\n",
" \"image_height\": Pytree.const(100),\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e43892100b2c48e9898f4c3e89e22554",
"model_id": "1212c5d7004143ccbc04e120fe8f44f4",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"interactive(children=(ToggleButtons(description='Prev Vis Prob:', options=('0.01', '0.99'), value='0.01'), Tog…"
"interactive(children=(FloatSlider(value=0.10000000149011612, continuous_update=False, description='Observed R:…"
]
},
"metadata": {},
@@ -81,12 +71,24 @@
],
"source": [
"from b3d.chisight.gen3d.visualization import create_interactive_visualization\n",
"from b3d.chisight.gen3d.inference_old import attribute_proposal\n",
"b3d.reload(b3d.chisight.gen3d.inference_old)\n",
"b3d.reload(b3d.chisight.gen3d.visualization)\n",
"observed_rgbd_for_point = jnp.array([0.1, 0.2, 0.3, 0.4])\n",
"\n",
"observed_rgbd_for_point = jnp.array([0.1, 0.2, 0.3, 0.0])\n",
"latent_rgbd_for_point = jnp.array([0.1, 0.2, 0.3, 1.0])\n",
"previous_color = jnp.array([0.1, 0.2, 0.3])\n",
"previous_visibility_prob = hyperparams[\"visibility_prob_kernel\"].support[-1]\n",
"previous_dnrp = hyperparams[\"depth_nonreturn_prob_kernel\"].support[0]\n",
"create_interactive_visualization(\n",
" observed_rgbd_for_point,\n",
" latent_rgbd_for_point,\n",
" hyperparams,\n",
" inference_hyperparams,\n",
" previous_color,\n",
" previous_visibility_prob,\n",
" previous_dnrp,\n",
" attribute_proposal\n",
")"
]
},