Skip to content

Commit

Permalink
Inference Move API (#178)
Browse files: browse the repository at this point in the history
  • Loading branch information
nishadgothoskar authored Sep 16, 2024
1 parent e7775e3 commit 4b8e6fe
Show file tree
Hide file tree
Showing 14 changed files with 760 additions and 333 deletions.
237 changes: 205 additions & 32 deletions notebooks/bayes3d_paper/old_inference_algorithm.ipynb

Large diffs are not rendered by default.

75 changes: 48 additions & 27 deletions notebooks/bayes3d_paper/run_ycbv_evaluation.py
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
#!/usr/bin/env python

import os

import b3d
import b3d.chisight.gen3d.image_kernel as image_kernel
import b3d.chisight.gen3d.pixel_kernels.pixel_rgbd_kernels as pixel_rgbd_kernels
import b3d.chisight.gen3d.transition_kernels as transition_kernels
import fire
import genjax
Expand All @@ -20,10 +21,13 @@


def run_tracking(scene=None, object=None, debug=False):
import b3d

FRAME_RATE = 50

ycb_dir = os.path.join(b3d.get_assets_path(), "bop/ycbv")
b3d.rr_init("run_ycbv_evaluation")
b3d.utils.rr_init("run_ycbv_evaluation")

ycb_dir = os.path.join(b3d.utils.get_assets_path(), "bop/ycbv")

if scene is None:
scenes = range(48, 60)
Expand All @@ -33,25 +37,29 @@ def run_tracking(scene=None, object=None, debug=False):
scenes = scene

hyperparams = {
"pose_kernel": transition_kernels.UniformPoseDriftKernel(max_shift=0.1),
"pose_kernel": transition_kernels.GaussianVMFPoseDriftKernel(
variance=0.02, concentration=1000.0
),
"color_kernel": transition_kernels.LaplaceNotTruncatedColorDriftKernel(
scale=0.15
scale=0.02
),
"visibility_prob_kernel": transition_kernels.DiscreteFlipKernel(
resample_probability=0.05, possible_values=jnp.array([0.01, 0.99])
resample_probability=0.05, support=jnp.array([1e-5, 1.0 - 1e-5])
),
"depth_nonreturn_prob_kernel": transition_kernels.DiscreteFlipKernel(
resample_probability=0.05, possible_values=jnp.array([0.01, 0.99])
resample_probability=0.05, support=jnp.array([1e-5, 1.0 - 1e-5])
),
"depth_scale_kernel": transition_kernels.DiscreteFlipKernel(
resample_probability=0.05, possible_values=jnp.array([0.0025, 0.01, 0.02])
resample_probability=0.05,
support=jnp.array([0.01, 0.005, 0.01, 0.02]),
),
"color_scale_kernel": transition_kernels.DiscreteFlipKernel(
resample_probability=0.05, possible_values=jnp.array([0.05, 0.1, 0.15])
resample_probability=0.05, support=jnp.array([0.001])
),
"image_kernel": image_kernel.NoOcclusionPerVertexImageKernel(
pixel_rgbd_kernels.OldOcclusionPixelRGBDDistribution()
),
"image_likelihood": image_kernel.SimpleNoRenderImageLikelihood(),
}
info_from_trace = hyperparams["image_likelihood"].info_from_trace

for scene_id in scenes:
print(f"Scene {scene_id}")
Expand Down Expand Up @@ -119,25 +127,30 @@ def run_tracking(scene=None, object=None, debug=False):
model_vertices = model_vertices[subset]
model_colors = model_colors[subset]

hyperparams["intrinsics"] = {
"fx": fx,
"fy": fy,
"cx": cx,
"cy": cy,
"image_height": Pytree.const(image_height),
"image_width": Pytree.const(image_width),
"near": 0.01,
"far": 3.0,
}
hyperparams["vertices"] = model_vertices

num_vertices = model_vertices.shape[0]
previous_state = {
"pose": template_pose,
"colors": model_colors,
"visibility_prob": jnp.ones(num_vertices)
* hyperparams["visibility_prob_kernel"].possible_values[-1],
* hyperparams["visibility_prob_kernel"].support[-1],
"depth_nonreturn_prob": jnp.ones(num_vertices)
* hyperparams["depth_nonreturn_prob_kernel"].possible_values[0],
"depth_scale": hyperparams["depth_scale_kernel"].possible_values[0],
"color_scale": hyperparams["color_scale_kernel"].possible_values[0],
* hyperparams["depth_nonreturn_prob_kernel"].support[0],
"depth_scale": hyperparams["depth_scale_kernel"].support[0],
"color_scale": hyperparams["color_scale_kernel"].support[0],
}

hyperparams["vertices"] = model_vertices
hyperparams["fx"] = fx
hyperparams["fy"] = fy
hyperparams["cx"] = cx
hyperparams["cy"] = cy
hyperparams["image_height"] = Pytree.const(image_height)
hyperparams["image_width"] = Pytree.const(image_width)
choicemap = (
genjax.ChoiceMap.d(
{
Expand All @@ -160,12 +173,19 @@ def run_tracking(scene=None, object=None, debug=False):
key, choicemap, (hyperparams, previous_state)
)[0]

from b3d.chisight.gen3d.inference import inference_step
import b3d.chisight.gen3d.inference as inference
import b3d.chisight.gen3d.inference_old as inference_old
import b3d.chisight.gen3d.settings

inference_hyperparams = b3d.chisight.gen3d.settings.inference_hyperparams

### Run inference ###
for T in tqdm(range(len(all_data))):
key = b3d.split_key(key)
trace = inference_step(trace, key, all_data[T]["rgbd"])
trace = inference.advance_time(key, trace, all_data[T]["rgbd"])
trace = inference_old.inference_step(trace, key, inference_hyperparams)[
0
]
tracking_results[T] = trace

if debug:
Expand All @@ -190,14 +210,15 @@ def run_tracking(scene=None, object=None, debug=False):
)

trace = tracking_results[len(all_data) - 1]
info = info_from_trace(trace)
rendered_rgbd = info["latent_rgbd"]
latent_rgb = b3d.chisight.gen3d.image_kernel.get_latent_rgb_image(
trace.get_retval()["new_state"], trace.get_args()[0]
)

a = b3d.viz_rgb(
trace.get_choices()["rgbd"][..., :3],
)
b = b3d.viz_rgb(
rendered_rgbd[..., :3],
latent_rgb[..., :3],
)
b3d.multi_panel([a, b, b3d.overlay_image(a, b)]).save(
f"photo_SCENE_{scene_id}_OBJECT_INDEX_{OBJECT_INDEX}_POSES.png"
Expand Down
78 changes: 40 additions & 38 deletions notebooks/gen3d/interactive_visualization.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
"import b3d\n",
"import jax.numpy as jnp\n",
"import pytest\n",
"import matplotlib.pyplot as plt"
"import matplotlib.pyplot as plt\n",
"from genjax import Pytree"
]
},
{
Expand All @@ -23,56 +24,45 @@
"metadata": {},
"outputs": [],
"source": [
"near, far, image_height, image_width = 0.001, 1.0, 480, 640\n",
"img_model = image_kernel.NoOcclusionPerVertexImageKernel(\n",
" near, far, image_height, image_width\n",
")\n",
"\n",
"inference_hyperparams = inference.InferenceHyperparams(\n",
" n_poses=1500,\n",
" do_stochastic_color_proposals=True,\n",
" pose_proposal_std=0.04,\n",
" pose_proposal_conc=1000.,\n",
" prev_color_proposal_laplace_scale=0.001,\n",
" obs_color_proposal_laplace_scale=0.001,\n",
")\n",
"\n",
"hyperparams = {\n",
" \"pose_kernel\": transition_kernels.UniformPoseDriftKernel(max_shift=0.1),\n",
" \"color_kernel\": transition_kernels.LaplaceNotTruncatedColorDriftKernel(\n",
" scale= 0.05\n",
" ),\n",
" \"visibility_prob_kernel\": transition_kernels.DiscreteFlipKernel(\n",
" resample_probability=0.1, support=jnp.array([0.01, 0.99])\n",
" ),\n",
" \"depth_nonreturn_prob_kernel\": transition_kernels.DiscreteFlipKernel(\n",
" resample_probability=0.1, support=jnp.array([0.01, 0.99])\n",
" ),\n",
" \"depth_scale_kernel\": transition_kernels.DiscreteFlipKernel(\n",
" resample_probability=0.1,\n",
" support=jnp.array([0.0025, 0.01, 0.02, 0.1, 0.4, 1.0]),\n",
" ),\n",
" \"color_scale_kernel\": transition_kernels.DiscreteFlipKernel(\n",
" resample_probability=0.1, support=jnp.array([0.002, 0.01, 0.025, 0.05, 0.1, 0.15, 0.3, 0.8])\n",
" ),\n",
" \"image_kernel\": img_model,\n",
"}"
"import b3d.chisight.gen3d.settings \n",
"b3d.reload(b3d.chisight.gen3d.settings)\n",
"import b3d.chisight.gen3d.settings \n",
"inference_hyperparams = b3d.chisight.gen3d.settings.inference_hyperparams\n",
"hyperparams = b3d.chisight.gen3d.settings.hyperparams"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"hyperparams[\"intrinsics\"] = {\n",
" \"fx\": 100.0,\n",
" \"fy\": 100.0,\n",
" \"cx\": 50.0,\n",
" \"cy\": 50.0,\n",
" \"near\": 0.01,\n",
" \"far\": 3.0,\n",
" \"image_width\": Pytree.const(100),\n",
" \"image_height\": Pytree.const(100),\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e43892100b2c48e9898f4c3e89e22554",
"model_id": "1212c5d7004143ccbc04e120fe8f44f4",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"interactive(children=(ToggleButtons(description='Prev Vis Prob:', options=('0.01', '0.99'), value='0.01'), Tog"
"interactive(children=(FloatSlider(value=0.10000000149011612, continuous_update=False, description='Observed R:"
]
},
"metadata": {},
Expand All @@ -81,12 +71,24 @@
],
"source": [
"from b3d.chisight.gen3d.visualization import create_interactive_visualization\n",
"from b3d.chisight.gen3d.inference_old import attribute_proposal\n",
"b3d.reload(b3d.chisight.gen3d.inference_old)\n",
"b3d.reload(b3d.chisight.gen3d.visualization)\n",
"observed_rgbd_for_point = jnp.array([0.1, 0.2, 0.3, 0.4])\n",
"\n",
"observed_rgbd_for_point = jnp.array([0.1, 0.2, 0.3, 0.0])\n",
"latent_rgbd_for_point = jnp.array([0.1, 0.2, 0.3, 1.0])\n",
"previous_color = jnp.array([0.1, 0.2, 0.3])\n",
"previous_visibility_prob = hyperparams[\"visibility_prob_kernel\"].support[-1]\n",
"previous_dnrp = hyperparams[\"depth_nonreturn_prob_kernel\"].support[0]\n",
"create_interactive_visualization(\n",
" observed_rgbd_for_point,\n",
" latent_rgbd_for_point,\n",
" hyperparams,\n",
" inference_hyperparams,\n",
" previous_color,\n",
" previous_visibility_prob,\n",
" previous_dnrp,\n",
" attribute_proposal\n",
")"
]
},
Expand Down
Loading

0 comments on commit 4b8e6fe

Please sign in to comment.