diff --git a/demo.py b/demo.py index 7b32f93a..8c434c4a 100755 --- a/demo.py +++ b/demo.py @@ -5,17 +5,18 @@ def test_demo(): import os - import b3d - import b3d.bayes3d as bayes3d import genjax import jax import jax.numpy as jnp import numpy as np import rerun as rr - from b3d import Mesh, Pose from genjax import Pytree from tqdm import tqdm + import b3d + import b3d.bayes3d as bayes3d + from b3d import Mesh, Pose + rr.init("demo") rr.connect("127.0.0.1:8812") diff --git a/demos/detector_segmenter.py b/demos/detector_segmenter.py index 13d73363..2b71bcaa 100644 --- a/demos/detector_segmenter.py +++ b/demos/detector_segmenter.py @@ -1,7 +1,6 @@ import io import os -import b3d import jax import jax.numpy as jnp import numpy @@ -19,6 +18,8 @@ ) from transformers.models.detr.feature_extraction_detr import rgb_to_id +import b3d + processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32") model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32") diff --git a/demos/differentiable_renderer/dense_fitting/standalone_nishad_demo.py b/demos/differentiable_renderer/dense_fitting/standalone_nishad_demo.py index 6479009b..d97a9d67 100644 --- a/demos/differentiable_renderer/dense_fitting/standalone_nishad_demo.py +++ b/demos/differentiable_renderer/dense_fitting/standalone_nishad_demo.py @@ -3,11 +3,12 @@ import functools import os -import b3d import jax import jax.numpy as jnp import numpy as np import rerun as rr + +import b3d from b3d import Pose rr.init("gradients") diff --git a/demos/differentiable_renderer/gradient_based_pose_estimation.py b/demos/differentiable_renderer/gradient_based_pose_estimation.py index 194985a4..151590bc 100644 --- a/demos/differentiable_renderer/gradient_based_pose_estimation.py +++ b/demos/differentiable_renderer/gradient_based_pose_estimation.py @@ -1,15 +1,16 @@ import os from functools import partial -import b3d -import b3d.chisight.dense.differentiable_renderer as rendering import jax import jax.numpy as jnp import optax import rerun as rr -from b3d import Mesh, Pose from tqdm import tqdm +import b3d +import b3d.chisight.dense.differentiable_renderer as rendering +from b3d import Mesh, Pose + rr.init("gradients") rr.connect("127.0.0.1:8812") diff --git a/demos/differentiable_renderer/gradients_for_mug.py b/demos/differentiable_renderer/gradients_for_mug.py index 39901041..2d4dfdda 100644 --- a/demos/differentiable_renderer/gradients_for_mug.py +++ b/demos/differentiable_renderer/gradients_for_mug.py @@ -1,16 +1,17 @@ import os from functools import partial -import b3d -import b3d.chisight.dense.differentiable_renderer as rendering import jax import jax.numpy as jnp import optax import rerun as rr import trimesh -from b3d import Pose from tqdm import tqdm +import b3d +import b3d.chisight.dense.differentiable_renderer as rendering +from b3d import Pose + rr.init("gradients") rr.connect("127.0.0.1:8812") diff --git a/demos/differentiable_renderer/patch_tracking/demo_utils.py b/demos/differentiable_renderer/patch_tracking/demo_utils.py index 81ed9d9a..fd53fd33 100644 --- a/demos/differentiable_renderer/patch_tracking/demo_utils.py +++ b/demos/differentiable_renderer/patch_tracking/demo_utils.py @@ -1,10 +1,11 @@ import os -import b3d -import b3d.utils as utils import jax import jax.numpy as jnp import trimesh + +import b3d +import b3d.utils as utils from b3d import Pose ### Utils ### diff --git a/demos/differentiable_renderer/patch_tracking/model.py b/demos/differentiable_renderer/patch_tracking/model.py index cb8db286..f1e5ef1c 100644 --- a/demos/differentiable_renderer/patch_tracking/model.py +++ b/demos/differentiable_renderer/patch_tracking/model.py @@ -1,12 +1,12 @@ -import b3d -import b3d.chisight.dense.differentiable_renderer as rendering import genjax import jax import jax.numpy as jnp import rerun as rr -from b3d.modeling_utils import uniform_pose +import b3d +import b3d.chisight.dense.differentiable_renderer as rendering import demos.differentiable_renderer.patch_tracking.demo_utils as utils +from b3d.modeling_utils import uniform_pose def normalize(v): diff --git a/demos/differentiable_renderer/patch_tracking/multiple_patch_tracker_2.py b/demos/differentiable_renderer/patch_tracking/multiple_patch_tracker_2.py index 4c03bf37..110af1c3 100644 --- a/demos/differentiable_renderer/patch_tracking/multiple_patch_tracker_2.py +++ b/demos/differentiable_renderer/patch_tracking/multiple_patch_tracker_2.py @@ -1,11 +1,11 @@ -import b3d -import b3d.chisight.dense.patch_tracking as tracking import numpy as np import rerun as rr -from b3d.chisight.dense.model import rr_log_uniformpose_meshes_to_image_model_trace from tqdm import tqdm +import b3d +import b3d.chisight.dense.patch_tracking as tracking import demos.differentiable_renderer.patch_tracking.demo_utils as du +from b3d.chisight.dense.model import rr_log_uniformpose_meshes_to_image_model_trace rr.init("multiple_patch_tracking_2") rr.connect("127.0.0.1:8812") diff --git a/demos/differentiable_renderer/patch_tracking/multiple_patch_tracking.py b/demos/differentiable_renderer/patch_tracking/multiple_patch_tracking.py index fc74137d..860d2524 100644 --- a/demos/differentiable_renderer/patch_tracking/multiple_patch_tracking.py +++ b/demos/differentiable_renderer/patch_tracking/multiple_patch_tracking.py @@ -2,9 +2,6 @@ import os -import b3d -import b3d.chisight.dense.differentiable_renderer as r -import b3d.chisight.dense.likelihoods as l import genjax import jax import jax.numpy as jnp @@ -12,11 +9,14 @@ import optax import rerun as rr import trimesh -from b3d import Pose from tqdm import tqdm +import b3d +import b3d.chisight.dense.differentiable_renderer as r +import b3d.chisight.dense.likelihoods as l import demos.differentiable_renderer.patch_tracking.demo_utils as du import demos.differentiable_renderer.patch_tracking.model as m +from b3d import Pose rr.init("multiple_patch_tracking") rr.connect("127.0.0.1:8812") diff --git a/demos/differentiable_renderer/patch_tracking/single_patch_tracking_adam.py b/demos/differentiable_renderer/patch_tracking/single_patch_tracking_adam.py index 9e716b88..38692d44 100644 --- a/demos/differentiable_renderer/patch_tracking/single_patch_tracking_adam.py +++ b/demos/differentiable_renderer/patch_tracking/single_patch_tracking_adam.py @@ -1,17 +1,17 @@ ### Preliminaries ### -import b3d.chisight.dense.differentiable_renderer as r -import b3d.chisight.dense.likelihoods as l import genjax import jax import jax.numpy as jnp import optax import rerun as rr -from b3d import Pose from tqdm import tqdm +import b3d.chisight.dense.differentiable_renderer as r +import b3d.chisight.dense.likelihoods as l import demos.differentiable_renderer.patch_tracking.demo_utils as du import demos.differentiable_renderer.patch_tracking.model as m +from b3d import Pose rr.init("single_patch_tracking") rr.connect("127.0.0.1:8812") diff --git a/demos/differentiable_renderer/patch_tracking/single_patch_tracking_mh.py b/demos/differentiable_renderer/patch_tracking/single_patch_tracking_mh.py index 2d2d81cd..ee846165 100644 --- a/demos/differentiable_renderer/patch_tracking/single_patch_tracking_mh.py +++ b/demos/differentiable_renderer/patch_tracking/single_patch_tracking_mh.py @@ -2,18 +2,18 @@ import time -import b3d -import b3d.chisight.dense.differentiable_renderer as r -import b3d.chisight.dense.likelihoods as l import genjax import jax import jax.numpy as jnp import rerun as rr -from b3d import Pose from tqdm import tqdm +import b3d +import b3d.chisight.dense.differentiable_renderer as r +import b3d.chisight.dense.likelihoods as l import demos.differentiable_renderer.patch_tracking.demo_utils as du import demos.differentiable_renderer.patch_tracking.model as m +from b3d import Pose rr.init("single_patch_tracking-mh") rr.connect("127.0.0.1:8812") diff --git a/demos/differentiable_renderer/test_barycentric_interp.py b/demos/differentiable_renderer/test_barycentric_interp.py index 0f0e264e..90d65d7a 100644 --- a/demos/differentiable_renderer/test_barycentric_interp.py +++ b/demos/differentiable_renderer/test_barycentric_interp.py @@ -1,9 +1,10 @@ -import b3d -import b3d.chisight.dense.differentiable_renderer as rendering -import b3d.chisight.dense.likelihoods as likelihoods import jax import jax.numpy as jnp import rerun as rr + +import b3d +import b3d.chisight.dense.differentiable_renderer as rendering +import b3d.chisight.dense.likelihoods as likelihoods from b3d import Pose # Set up OpenGL renderer diff --git a/demos/differentiable_renderer/test_basic_gd.py b/demos/differentiable_renderer/test_basic_gd.py index c1e1fa10..d4ef6e9f 100644 --- a/demos/differentiable_renderer/test_basic_gd.py +++ b/demos/differentiable_renderer/test_basic_gd.py @@ -1,13 +1,13 @@ import time -import b3d -import b3d.chisight.dense.differentiable_renderer as rendering -import b3d.chisight.dense.likelihoods as likelihoods import genjax import jax import jax.numpy as jnp import rerun as rr +import b3d +import b3d.chisight.dense.differentiable_renderer as rendering +import b3d.chisight.dense.likelihoods as likelihoods import demos.differentiable_renderer.utils as utils # Set up OpenGL renderer diff --git a/demos/differentiable_renderer/utils.py b/demos/differentiable_renderer/utils.py index 70731d9d..66b70956 100644 --- a/demos/differentiable_renderer/utils.py +++ b/demos/differentiable_renderer/utils.py @@ -1,6 +1,7 @@ -import b3d import jax import jax.numpy as jnp + +import b3d from b3d import Pose diff --git a/demos/fork_knife_smc_identity_pose.py b/demos/fork_knife_smc_identity_pose.py index 65d88599..00a33c99 100644 --- a/demos/fork_knife_smc_identity_pose.py +++ b/demos/fork_knife_smc_identity_pose.py @@ -1,15 +1,16 @@ import os -import b3d import genjax import jax import jax.numpy as jnp import numpy as np import rerun as rr import trimesh -from b3d import Pose from tqdm import tqdm +import b3d +from b3d import Pose + ### Choose experiment INPUT = "fork-visible" # TODO make one dataset with both objects? diff --git a/demos/graphics_edits_demo/demo_visualize.py b/demos/graphics_edits_demo/demo_visualize.py index ca0f5df8..36975adf 100644 --- a/demos/graphics_edits_demo/demo_visualize.py +++ b/demos/graphics_edits_demo/demo_visualize.py @@ -2,15 +2,16 @@ import os import pickle -import b3d import jax import jax.numpy as jnp import numpy as np import rerun as rr import trimesh -from b3d import Pose from tqdm import tqdm +import b3d +from b3d import Pose + rr.init("demo_visualize3") rr.connect("127.0.0.1:8812") diff --git a/demos/graphics_edits_demo/vkm_demo.py b/demos/graphics_edits_demo/vkm_demo.py index a8b8b8d4..db82411d 100644 --- a/demos/graphics_edits_demo/vkm_demo.py +++ b/demos/graphics_edits_demo/vkm_demo.py @@ -1,11 +1,12 @@ import os -import b3d import jax import jax.numpy as jnp import numpy as np import rerun as rr import trimesh + +import b3d from b3d import Pose rr.init("vkm_demo2") diff --git a/demos/mesh_fitting/demo.py b/demos/mesh_fitting/demo.py index afe8af1d..d441e2f0 100644 --- a/demos/mesh_fitting/demo.py +++ b/demos/mesh_fitting/demo.py @@ -1,6 +1,5 @@ import os -import b3d import genjax import jax import jax.numpy as jnp @@ -8,6 +7,7 @@ import rerun as rr from tqdm import tqdm +import b3d import demos.mesh_fitting.model as m import demos.mesh_fitting.tessellation as t diff --git a/demos/mesh_fitting/demo_depth_init.py b/demos/mesh_fitting/demo_depth_init.py index b3907273..e680c44f 100644 --- a/demos/mesh_fitting/demo_depth_init.py +++ b/demos/mesh_fitting/demo_depth_init.py @@ -1,10 +1,10 @@ import os -import b3d import genjax import jax import rerun as rr +import b3d import demos.mesh_fitting.model as m import demos.mesh_fitting.utils as u diff --git a/demos/mesh_fitting/model.py b/demos/mesh_fitting/model.py index bad24cfc..ad6f5f64 100644 --- a/demos/mesh_fitting/model.py +++ b/demos/mesh_fitting/model.py @@ -1,9 +1,10 @@ -import b3d -import b3d.chisight.dense.differentiable_renderer as rendering import genjax import jax import jax.numpy as jnp import rerun as rr + +import b3d +import b3d.chisight.dense.differentiable_renderer as rendering from b3d.modeling_utils import uniform_pose diff --git a/demos/mesh_fitting/utils.py b/demos/mesh_fitting/utils.py index 1eacd42a..b2c67a17 100644 --- a/demos/mesh_fitting/utils.py +++ b/demos/mesh_fitting/utils.py @@ -1,7 +1,7 @@ -import b3d import jax import jax.numpy as jnp +import b3d import demos.mesh_fitting.tessellation as t diff --git a/demos/posterior_datasets/identity_posterior_data_gen.py b/demos/posterior_datasets/identity_posterior_data_gen.py index bd0da975..78fcb1b3 100644 --- a/demos/posterior_datasets/identity_posterior_data_gen.py +++ b/demos/posterior_datasets/identity_posterior_data_gen.py @@ -1,12 +1,13 @@ import os -import b3d import jax import jax.numpy as jnp import numpy as np import rerun as rr import trimesh +import b3d + width = 200 height = 200 fx = 300.0 diff --git a/demos/posterior_datasets/pose_posterior_data_gen.py b/demos/posterior_datasets/pose_posterior_data_gen.py index cf678469..97848057 100644 --- a/demos/posterior_datasets/pose_posterior_data_gen.py +++ b/demos/posterior_datasets/pose_posterior_data_gen.py @@ -1,10 +1,11 @@ import os -import b3d import jax import jax.numpy as jnp import rerun as rr import trimesh + +import b3d from b3d import Pose rr.init("demo") diff --git a/demos/sparse_model/cotracker.py b/demos/sparse_model/cotracker.py index cf5baeb6..66d7beba 100644 --- a/demos/sparse_model/cotracker.py +++ b/demos/sparse_model/cotracker.py @@ -1,10 +1,11 @@ import argparse import time -import b3d import numpy as np import torch +import b3d + parser = argparse.ArgumentParser("r3d_to_video_input") parser.add_argument("input", help=".r3d File", type=str) args = parser.parse_args() diff --git a/demos/sparse_model/sparse_model.py b/demos/sparse_model/sparse_model.py index 836870c2..99fc181a 100644 --- a/demos/sparse_model/sparse_model.py +++ b/demos/sparse_model/sparse_model.py @@ -1,7 +1,6 @@ import os from functools import partial -import b3d import jax import jax.numpy as jnp import optax @@ -9,6 +8,8 @@ import trimesh from tqdm import tqdm +import b3d + def map_nested_fn(fn): """Recursively apply `fn` to the key-value pairs of a nested dict.""" diff --git a/demos/sparse_model/sparse_model_cotracker.py b/demos/sparse_model/sparse_model_cotracker.py index bb691512..dba3b049 100644 --- a/demos/sparse_model/sparse_model_cotracker.py +++ b/demos/sparse_model/sparse_model_cotracker.py @@ -1,7 +1,6 @@ import os from functools import partial -import b3d import jax import jax.numpy as jnp import numpy as np @@ -10,6 +9,8 @@ from matplotlib import colormaps from tqdm import tqdm +import b3d + def map_nested_fn(fn): """Recursively apply `fn` to the key-value pairs of a nested dict.""" diff --git a/demos/sparse_model/sparse_model_torch.py b/demos/sparse_model/sparse_model_torch.py index 17f57ebd..64d06d02 100644 --- a/demos/sparse_model/sparse_model_torch.py +++ b/demos/sparse_model/sparse_model_torch.py @@ -1,12 +1,13 @@ import os -import b3d import pytorch3d.transforms import rerun as rr import torch import trimesh from tqdm import tqdm +import b3d + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") diff --git a/demos/speed_comparisons/demo_splat.py b/demos/speed_comparisons/demo_splat.py index 8d6e69e5..87f33b09 100644 --- a/demos/speed_comparisons/demo_splat.py +++ b/demos/speed_comparisons/demo_splat.py @@ -1,14 +1,15 @@ import os -import b3d import jax import jax.numpy as jnp import numpy as np import rerun as rr -from b3d import Pose from diff_gaussian_rasterization import rasterize_with_depth from tqdm import tqdm +import b3d +from b3d import Pose + rr.init("demo") rr.connect("127.0.0.1:8812") diff --git a/demos/speed_comparisons/demo_torch.py b/demos/speed_comparisons/demo_torch.py index f0178ea9..44b50522 100644 --- a/demos/speed_comparisons/demo_torch.py +++ b/demos/speed_comparisons/demo_torch.py @@ -2,16 +2,17 @@ import os import time -import b3d -import b3d.nvdiffrast_original.torch as dr -import b3d.torch -import b3d.torch.renderutils import pytorch3d.transforms import rerun as rr import torch import trimesh from tqdm import tqdm +import b3d +import b3d.nvdiffrast_original.torch as dr +import b3d.torch +import b3d.torch.renderutils + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") torch.set_default_device("cuda") rr.init("demo") diff --git a/demos/speed_comparisons/demo_torch_vmap.py b/demos/speed_comparisons/demo_torch_vmap.py index 0232943b..402a7dee 100644 --- a/demos/speed_comparisons/demo_torch_vmap.py +++ b/demos/speed_comparisons/demo_torch_vmap.py @@ -2,8 +2,6 @@ import os import time -import b3d -import b3d.nvdiffrast_original.torch as dr import pytorch3d.transforms import rerun as rr import torch @@ -11,6 +9,9 @@ import trimesh from tqdm import tqdm +import b3d +import b3d.nvdiffrast_original.torch as dr + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") rr.init("demo") rr.connect("127.0.0.1:8812") diff --git a/demos/speed_comparisons/demo_tracking_jax.py b/demos/speed_comparisons/demo_tracking_jax.py index c862c138..c7961b51 100644 --- a/demos/speed_comparisons/demo_tracking_jax.py +++ b/demos/speed_comparisons/demo_tracking_jax.py @@ -1,13 +1,14 @@ import os import time -import b3d import jax import jax.numpy as jnp import trimesh -from b3d import Pose from scipy.spatial.transform import Rotation as R +import b3d +from b3d import Pose + height = 100 width = 100 fx = 200.0 diff --git a/demos/test_nvdiffrast_original.py b/demos/test_nvdiffrast_original.py index f7d8c2ad..185e78f4 100644 --- a/demos/test_nvdiffrast_original.py +++ b/demos/test_nvdiffrast_original.py @@ -1,10 +1,11 @@ import os import time -import b3d import jax import jax.numpy as jnp import trimesh + +import b3d from b3d.renderer_original import RendererOriginal width = 200 diff --git a/demos/test_object_search.py b/demos/test_object_search.py index a807341f..60b491de 100644 --- a/demos/test_object_search.py +++ b/demos/test_object_search.py @@ -1,15 +1,16 @@ #!/usr/bin/env python import os -import b3d import genjax import jax import jax.numpy as jnp import numpy as np import rerun as rr -from b3d import Pose from tqdm import tqdm +import b3d +from b3d import Pose + PORT = 8812 rr.init("mug sm2c inference") rr.connect(addr=f"127.0.0.1:{PORT}") diff --git a/demos/test_panda.py b/demos/test_panda.py index 4d72bc16..cfcdae47 100644 --- a/demos/test_panda.py +++ b/demos/test_panda.py @@ -1,15 +1,16 @@ import os import pickle -import b3d import genjax import jax import jax.numpy as jnp import rerun as rr import trimesh -from b3d import Pose from tqdm import tqdm +import b3d +from b3d import Pose + PORT = 8812 rr.init("mug sm2c inference") rr.connect(addr=f"127.0.0.1:{PORT}") diff --git a/demos/test_renderer_fps.py b/demos/test_renderer_fps.py index 672f4ab4..f4ccebe2 100644 --- a/demos/test_renderer_fps.py +++ b/demos/test_renderer_fps.py @@ -1,14 +1,15 @@ import os import time -import b3d -import b3d.nvdiffrast_original.torch as dr import jax import jax.numpy as jnp import numpy as np import rerun as rr import torch import trimesh + +import b3d +import b3d.nvdiffrast_original.torch as dr from b3d.renderer_original import Renderer as RendererOriginal device = torch.device("cuda" if torch.cuda.is_available() else "cpu") diff --git a/demos/test_torch/test_pose.py b/demos/test_torch/test_pose.py index 577440f8..909ad707 100644 --- a/demos/test_torch/test_pose.py +++ b/demos/test_torch/test_pose.py @@ -1,4 +1,5 @@ import torch + from b3d.renderer.torch.pose import Pose diff --git a/demos/tracking_online_learning.py b/demos/tracking_online_learning.py index 2bde0719..c455133f 100644 --- a/demos/tracking_online_learning.py +++ b/demos/tracking_online_learning.py @@ -1,15 +1,16 @@ import os from functools import partial -import b3d import genjax import jax import jax.numpy as jnp import numpy as np import rerun as rr -from b3d import Pose from tqdm import tqdm +import b3d +from b3d import Pose + # Rerun setup PORT = 8812 rr.init("online_learning") diff --git a/notebooks/bayes3d_paper/kitti_data.py b/notebooks/bayes3d_paper/kitti_data.py index 5aa3bea8..abe1c22b 100644 --- a/notebooks/bayes3d_paper/kitti_data.py +++ b/notebooks/bayes3d_paper/kitti_data.py @@ -1,12 +1,13 @@ import os -import b3d import jax.numpy as jnp import numpy as np import pykitti from segment_anything import SamAutomaticMaskGenerator, sam_model_registry from tqdm import tqdm +import b3d + b3d.rr_init("kitti") basedir = os.path.join(b3d.get_assets_path(), "kitti") diff --git a/notebooks/bayes3d_paper/run_ycbv_evaluation.py b/notebooks/bayes3d_paper/run_ycbv_evaluation.py index 89e6c7e3..99522f82 100644 --- a/notebooks/bayes3d_paper/run_ycbv_evaluation.py +++ b/notebooks/bayes3d_paper/run_ycbv_evaluation.py @@ -6,14 +6,15 @@ def run_tracking(scene=None, object=None, debug=False): import importlib import os - import b3d import genjax import jax import jax.numpy as jnp - from b3d import Mesh, Pose from genjax import Pytree from tqdm import tqdm + import b3d + from b3d import Mesh, Pose + importlib.reload(b3d.mesh) importlib.reload(b3d.io.data_loader) importlib.reload(b3d.utils) diff --git a/notebooks/bayes3d_paper/visualize_tracking.py b/notebooks/bayes3d_paper/visualize_tracking.py index 43c3cc8c..73c04820 100644 --- a/notebooks/bayes3d_paper/visualize_tracking.py +++ b/notebooks/bayes3d_paper/visualize_tracking.py @@ -6,11 +6,12 @@ def make_visual(scene=None, object=None, debug=False): import importlib import os - import b3d import jax.numpy as jnp - from b3d import Mesh, Pose from tqdm import tqdm + import b3d + from b3d import Mesh, Pose + importlib.reload(b3d.mesh) importlib.reload(b3d.io.data_loader) importlib.reload(b3d.utils) diff --git a/notebooks/integration_dense.py b/notebooks/integration_dense.py index 12b82ace..ff84ba3f 100644 --- a/notebooks/integration_dense.py +++ b/notebooks/integration_dense.py @@ -1,17 +1,18 @@ import importlib import os -import b3d -import b3d.chisight.particle_system as ps import jax import jax.numpy as jnp +from genjax import Pytree + +import b3d +import b3d.chisight.particle_system as ps from b3d import Mesh, Pose from b3d.chisight.dense.likelihoods import ( KRaysImageLikelihoodArgs, make_krays_image_observation_model, ) from b3d.renderer.renderer_original import RendererOriginal -from genjax import Pytree importlib.reload(ps) diff --git a/notebooks/multi_particle_visualization.py b/notebooks/multi_particle_visualization.py index 89f7d700..e8fd8bdb 100644 --- a/notebooks/multi_particle_visualization.py +++ b/notebooks/multi_particle_visualization.py @@ -1,6 +1,5 @@ import os -import b3d import genjax import jax import jax.numpy as jnp @@ -9,9 +8,11 @@ # from b3d.utils import unproject_depth import rerun as rr import trimesh -from b3d import Pose from tqdm import tqdm +import b3d +from b3d import Pose + rr.init("demo.py") rr.connect("127.0.0.1:8812") rr.save("multi_particle_visualization.rrd") diff --git a/scripts/acquire_object_model.py b/scripts/acquire_object_model.py index b9373644..69bab99b 100644 --- a/scripts/acquire_object_model.py +++ b/scripts/acquire_object_model.py @@ -1,12 +1,13 @@ import argparse import time -import b3d import jax import jax.numpy as jnp -from b3d import Mesh, Pose from tqdm import tqdm +import b3d +from b3d import Mesh, Pose + b3d.rr_init("acquire_object_model") # python scripts/acquire_object_model.py assets/shared_data_bucket/input_data/lysol_static.r3d diff --git a/scripts/cotracker.py b/scripts/cotracker.py index 684b203c..7262cc3f 100644 --- a/scripts/cotracker.py +++ b/scripts/cotracker.py @@ -1,9 +1,10 @@ import time from pathlib import Path -import b3d import numpy as np import torch + +import b3d from b3d.io import FeatureTrackData from b3d.io.utils import add_argparse, path_stem diff --git a/scripts/deformable_demo.py b/scripts/deformable_demo.py index 48356c3d..2e3d75dd 100644 --- a/scripts/deformable_demo.py +++ b/scripts/deformable_demo.py @@ -6,12 +6,13 @@ import numpy as np import optax import rerun as rr +from jax.scipy.spatial.transform import Rotation as Rot +from sklearn.utils import Bunch + from b3d.chisight.sparse.gps_utils import cov_from_dq_composition from b3d.io import MeshData from b3d.pose import Pose from b3d.utils import keysplit -from jax.scipy.spatial.transform import Rotation as Rot -from sklearn.utils import Bunch # ************************** diff --git a/scripts/high_quality_object_mesh_aquisition.py b/scripts/high_quality_object_mesh_aquisition.py index d9ece43a..39c1d0b9 100644 --- a/scripts/high_quality_object_mesh_aquisition.py +++ b/scripts/high_quality_object_mesh_aquisition.py @@ -1,6 +1,7 @@ -import b3d import jax.numpy as jnp +import b3d + b3d.rr_init("high_quality") # python scripts/acquire_object_model.py assets/shared_data_bucket/input_data/lysol_static.r3d diff --git a/scripts/r3d_to_frames_and_mask.py b/scripts/r3d_to_frames_and_mask.py index a5e40e6e..d80ad4ef 100644 --- a/scripts/r3d_to_frames_and_mask.py +++ b/scripts/r3d_to_frames_and_mask.py @@ -1,7 +1,6 @@ from pathlib import Path from typing import Optional -import b3d import cv2 import fire import jax @@ -10,6 +9,8 @@ import trimesh from r3d_to_video_input import load_r3d_video_input +import b3d + def get_masks(rgb_imgs: jax.Array) -> jax.Array: masks = [b3d.carvekit_get_foreground_mask(img) for img in rgb_imgs] diff --git a/scripts/r3d_to_video_input.py b/scripts/r3d_to_video_input.py index 86b68ffa..17a91de7 100644 --- a/scripts/r3d_to_video_input.py +++ b/scripts/r3d_to_video_input.py @@ -5,13 +5,14 @@ import subprocess from pathlib import Path -import b3d import cv2 import jax import jax.numpy as jnp import liblzfse # https://pypi.org/project/pyliblzfse/ from natsort import natsorted +import b3d + def load_depth(filepath): with open(filepath, "rb") as depth_fh: diff --git a/scripts/visualize.py b/scripts/visualize.py index be546770..464810fd 100644 --- a/scripts/visualize.py +++ b/scripts/visualize.py @@ -1,7 +1,8 @@ -import b3d import fire import jax import jax.numpy as jnp + +import b3d from b3d import Pose diff --git a/src/b3d/chisight/dense/dense_only_patch_tracking/model.py b/src/b3d/chisight/dense/dense_only_patch_tracking/model.py index 67587bc6..a34761bc 100644 --- a/src/b3d/chisight/dense/dense_only_patch_tracking/model.py +++ b/src/b3d/chisight/dense/dense_only_patch_tracking/model.py @@ -1,10 +1,11 @@ -import b3d.chisight.dense.differentiable_renderer as rendering -import b3d.utils as utils import genjax import jax import jax.numpy as jnp import numpy as np import rerun as rr + +import b3d.chisight.dense.differentiable_renderer as rendering +import b3d.utils as utils from b3d import Pose from b3d.modeling_utils import uniform_pose diff --git a/src/b3d/chisight/dense/dense_only_patch_tracking/patch_tracking.py b/src/b3d/chisight/dense/dense_only_patch_tracking/patch_tracking.py index 6dcc3010..a1294a09 100644 --- a/src/b3d/chisight/dense/dense_only_patch_tracking/patch_tracking.py +++ b/src/b3d/chisight/dense/dense_only_patch_tracking/patch_tracking.py @@ -1,14 +1,15 @@ +import genjax +import jax +import jax.numpy as jnp +import optax +from genjax import ChoiceMapBuilder as C + import b3d import b3d.chisight.dense.dense_only_patch_tracking.model as m import b3d.chisight.dense.differentiable_renderer import b3d.chisight.dense.differentiable_renderer as r import b3d.chisight.dense.likelihoods as likelihoods -import genjax -import jax -import jax.numpy as jnp -import optax from b3d import Pose -from genjax import ChoiceMapBuilder as C def get_patches(centers, rgbds, X_WC, fx, fy, cx, cy): diff --git a/tests/conftest.py b/tests/conftest.py index c133e9f4..2e231a87 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,7 @@ -import b3d import pytest +import b3d + # Arrange @pytest.fixture diff --git a/tests/dense_model_unit_tests/triangle_depth_posterior/solver/importance.py b/tests/dense_model_unit_tests/triangle_depth_posterior/solver/importance.py index 911b4e04..55fc987d 100644 --- a/tests/dense_model_unit_tests/triangle_depth_posterior/solver/importance.py +++ b/tests/dense_model_unit_tests/triangle_depth_posterior/solver/importance.py @@ -1,9 +1,10 @@ -import b3d import genjax import jax import jax.numpy as jnp from genjax import ChoiceMapBuilder as C +import b3d + from ....common.solver import Solver from .model import get_likelihood, model_factory, rr_log_trace diff --git a/tests/dense_model_unit_tests/triangle_depth_posterior/solver/model.py b/tests/dense_model_unit_tests/triangle_depth_posterior/solver/model.py index 035814a9..094203c4 100644 --- a/tests/dense_model_unit_tests/triangle_depth_posterior/solver/model.py +++ b/tests/dense_model_unit_tests/triangle_depth_posterior/solver/model.py @@ -1,11 +1,12 @@ -import b3d -import b3d.chisight.dense.differentiable_renderer as rendering -import b3d.chisight.dense.likelihoods as likelihoods import genjax import jax import jax.numpy as jnp import rerun as rr +import b3d +import b3d.chisight.dense.differentiable_renderer as rendering +import b3d.chisight.dense.likelihoods as likelihoods + def normalize(weights): return weights / jnp.sum(weights) diff --git a/tests/dense_model_unit_tests/triangle_depth_posterior/task.py b/tests/dense_model_unit_tests/triangle_depth_posterior/task.py index c5bb9de2..e6a93726 100644 --- a/tests/dense_model_unit_tests/triangle_depth_posterior/task.py +++ b/tests/dense_model_unit_tests/triangle_depth_posterior/task.py @@ -1,10 +1,11 @@ -import b3d -import b3d.chisight.dense.differentiable_renderer as differentiable_renderer import jax import jax.numpy as jnp import matplotlib.pyplot as plt import rerun as rr +import b3d +import b3d.chisight.dense.differentiable_renderer as differentiable_renderer + from ...common.task import Task diff --git a/tests/dlpack.py b/tests/dlpack.py index 9a812a07..54bcd38e 100644 --- a/tests/dlpack.py +++ b/tests/dlpack.py @@ -1,6 +1,5 @@ import os -import b3d import jax import jax.numpy as jnp import numpy as np @@ -8,6 +7,8 @@ import torch import trimesh +import b3d + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") mesh_path = os.path.join( diff --git a/tests/dynamic_object_model/kfold_image_kernel_real_data.py b/tests/dynamic_object_model/kfold_image_kernel_real_data.py index d4154169..77943648 100644 --- a/tests/dynamic_object_model/kfold_image_kernel_real_data.py +++ b/tests/dynamic_object_model/kfold_image_kernel_real_data.py @@ -2,10 +2,11 @@ import os -import b3d -import b3d.chisight.dynamic_object_model.kfold_image_kernel as kik import jax import jax.numpy as jnp + +import b3d +import b3d.chisight.dynamic_object_model.kfold_image_kernel as kik from b3d import Mesh b3d.rr_init("kfold_image_kernel2") diff --git a/tests/dynamic_object_model/kfold_image_kernel_unit_test.py b/tests/dynamic_object_model/kfold_image_kernel_unit_test.py index f2465bec..ed0a2d89 100644 --- a/tests/dynamic_object_model/kfold_image_kernel_unit_test.py +++ b/tests/dynamic_object_model/kfold_image_kernel_unit_test.py @@ -1,10 +1,11 @@ import importlib -import b3d -import b3d.chisight.dynamic_object_model.likelihoods.kfold_image_kernel as kik import jax.numpy as jnp from jax.random import PRNGKey +import b3d +import b3d.chisight.dynamic_object_model.likelihoods.kfold_image_kernel as kik + importlib.reload(kik) diff --git a/tests/dynamic_object_model/test_dynamic_object_model.py b/tests/dynamic_object_model/test_dynamic_object_model.py index 92053858..ef4ba1bc 100644 --- a/tests/dynamic_object_model/test_dynamic_object_model.py +++ b/tests/dynamic_object_model/test_dynamic_object_model.py @@ -1,13 +1,14 @@ ### IMPORTS ### -import b3d import jax import jax.numpy as jnp -from b3d import Pose from genjax import ChoiceMapBuilder as C from genjax import Pytree +import b3d +from b3d import Pose + b3d.reload(b3d.chisight.dynamic_object_model) diff --git a/tests/dynamic_object_model/test_pixel_distribution.py b/tests/dynamic_object_model/test_pixel_distribution.py index a29828d9..bced9ca7 100644 --- a/tests/dynamic_object_model/test_pixel_distribution.py +++ b/tests/dynamic_object_model/test_pixel_distribution.py @@ -1,9 +1,10 @@ import importlib -import b3d.chisight.dynamic_object_model.likelihoods.kfold_image_kernel as kik import jax import jax.numpy as jnp +import b3d.chisight.dynamic_object_model.likelihoods.kfold_image_kernel as kik + importlib.reload(kik) diff --git a/tests/dynamic_object_model/test_raycast_nondeterministic.py b/tests/dynamic_object_model/test_raycast_nondeterministic.py index 52bccb4b..3e5bc361 100644 --- a/tests/dynamic_object_model/test_raycast_nondeterministic.py +++ b/tests/dynamic_object_model/test_raycast_nondeterministic.py @@ -1,10 +1,11 @@ import importlib -import b3d -import b3d.chisight.dynamic_object_model.likelihoods.kfold_image_kernel as kfk import jax.numpy as jnp from jax.random import PRNGKey, split +import b3d +import b3d.chisight.dynamic_object_model.likelihoods.kfold_image_kernel as kfk + importlib.reload(kfk) diff --git a/tests/dynamic_object_model/test_truncated_laplace.py b/tests/dynamic_object_model/test_truncated_laplace.py index 3c152680..f8a61ffb 100644 --- a/tests/dynamic_object_model/test_truncated_laplace.py +++ b/tests/dynamic_object_model/test_truncated_laplace.py @@ -1,7 +1,8 @@ -import b3d.chisight.dynamic_object_model.likelihoods.kfold_image_kernel as kik import jax import jax.numpy as jnp +import b3d.chisight.dynamic_object_model.likelihoods.kfold_image_kernel as kik + # importlib.reload(kik) # loc, scale, low, high, uniform_window_size = 0.0, 0.01, 0.0, 1.0, 0.1 # n_grid_steps = 1000 diff --git a/tests/gen3d/test_pixel_color_kernels.py b/tests/gen3d/test_pixel_color_kernels.py index 3a968ec9..d0db725f 100644 --- a/tests/gen3d/test_pixel_color_kernels.py +++ b/tests/gen3d/test_pixel_color_kernels.py @@ -3,6 +3,8 @@ import jax import jax.numpy as jnp import pytest +from genjax.typing import FloatArray + from b3d.chisight.gen3d.pixel_kernels.pixel_color_kernels import ( COLOR_MAX_VAL, COLOR_MIN_VAL, @@ -11,7 +13,6 @@ TruncatedLaplacePixelColorDistribution, UniformPixelColorDistribution, ) -from genjax.typing import FloatArray @partial(jax.jit, static_argnums=(0,)) diff --git a/tests/gen3d/test_pixel_depth_kernels.py b/tests/gen3d/test_pixel_depth_kernels.py index 45e06f9e..ec9415ff 100644 --- a/tests/gen3d/test_pixel_depth_kernels.py +++ b/tests/gen3d/test_pixel_depth_kernels.py @@ -1,6 +1,7 @@ import jax import jax.numpy as jnp import pytest + from b3d.chisight.gen3d.pixel_kernels.pixel_depth_kernels import ( DEPTH_NONRETURN_VAL, UNEXPLAINED_DEPTH_NONRETURN_PROB, diff --git a/tests/gen3d/test_pixel_rgbd_kernels.py b/tests/gen3d/test_pixel_rgbd_kernels.py index f42da130..87895ff6 100644 --- a/tests/gen3d/test_pixel_rgbd_kernels.py +++ b/tests/gen3d/test_pixel_rgbd_kernels.py @@ -1,6 +1,7 @@ import jax import jax.numpy as jnp import pytest + from b3d.chisight.gen3d.pixel_kernels.pixel_color_kernels import ( FullPixelColorDistribution, ) diff --git a/tests/image_likelihood/image_likelihood_tests.py b/tests/image_likelihood/image_likelihood_tests.py index 06c40214..871688ca 100644 --- a/tests/image_likelihood/image_likelihood_tests.py +++ b/tests/image_likelihood/image_likelihood_tests.py @@ -1,11 +1,12 @@ import os -import b3d import jax import jax.numpy as jnp import matplotlib.pyplot as plt import numpy as np import trimesh + +import b3d from b3d import Pose diff --git a/tests/image_likelihood/run_tests.py b/tests/image_likelihood/run_tests.py index df55f8e4..c2766c60 100644 --- a/tests/image_likelihood/run_tests.py +++ b/tests/image_likelihood/run_tests.py @@ -1,13 +1,7 @@ from functools import partial -import b3d import jax import jax.numpy as jnp -from b3d.chisight.dense.likelihoods.image_likelihoods import ( - gaussian_iid_pix_likelihood, - kray_likelihood_intermediate, - threedp3_gmm_likelihood, -) from image_likelihood_tests import ( mug_posterior_peakiness_samples, test_distance_invariance, @@ -16,6 +10,13 @@ test_resolution_invariance, ) +import b3d +from b3d.chisight.dense.likelihoods.image_likelihoods import ( + gaussian_iid_pix_likelihood, + kray_likelihood_intermediate, + threedp3_gmm_likelihood, +) + # set up latent image as likelihood arg def rgbd_latent_likelihood(likelihood, observed_rgbd, rendered_rgbd, likelihood_args): diff --git a/tests/sama4d/data_curation.py b/tests/sama4d/data_curation.py index a1a56a73..09faa2c6 100644 --- a/tests/sama4d/data_curation.py +++ b/tests/sama4d/data_curation.py @@ -1,10 +1,11 @@ import os -import b3d import jax import jax.numpy as jnp import trimesh +import b3d + def get_loaders_for_all_curated_scenes(): """ diff --git a/tests/sama4d/tracks_to_segmentation/keypoints_to_segmentation_task.py b/tests/sama4d/tracks_to_segmentation/keypoints_to_segmentation_task.py index 524ed1d1..baf4dacd 100644 --- a/tests/sama4d/tracks_to_segmentation/keypoints_to_segmentation_task.py +++ b/tests/sama4d/tracks_to_segmentation/keypoints_to_segmentation_task.py @@ -1,8 +1,8 @@ from typing import Callable -import b3d import jax.numpy as jnp +import b3d from tests.common.task import Task diff --git a/tests/sama4d/video_to_tracks/from_initialization/keypoint_tracking_task.py b/tests/sama4d/video_to_tracks/from_initialization/keypoint_tracking_task.py index 41f6613a..2226edb5 100644 --- a/tests/sama4d/video_to_tracks/from_initialization/keypoint_tracking_task.py +++ b/tests/sama4d/video_to_tracks/from_initialization/keypoint_tracking_task.py @@ -1,11 +1,11 @@ from typing import Callable -import b3d import jax.numpy as jnp import numpy as np import rerun as rr import rerun.blueprint as rrb +import b3d from tests.common.task import Task diff --git a/tests/sama4d/video_to_tracks/from_initialization/registration.py b/tests/sama4d/video_to_tracks/from_initialization/registration.py index d5fb161e..f6d52eb1 100644 --- a/tests/sama4d/video_to_tracks/from_initialization/registration.py +++ b/tests/sama4d/video_to_tracks/from_initialization/registration.py @@ -2,9 +2,10 @@ This file registers a default set of tasks and solvers for the video to keypoint tracks task class. """ -import b3d import jax.numpy as jnp +import b3d + from ...data_curation import get_loaders_for_all_curated_scenes from .keypoint_tracking_task import KeypointTrackingTask from .solvers.particle_system_patch_tracking_solver import ( diff --git a/tests/sama4d/video_to_tracks/from_initialization/solvers/dense_only_patch_tracking_solver.py b/tests/sama4d/video_to_tracks/from_initialization/solvers/dense_only_patch_tracking_solver.py index d69d1eae..6b58ca49 100644 --- a/tests/sama4d/video_to_tracks/from_initialization/solvers/dense_only_patch_tracking_solver.py +++ b/tests/sama4d/video_to_tracks/from_initialization/solvers/dense_only_patch_tracking_solver.py @@ -1,12 +1,12 @@ -import b3d -import b3d.chisight.dense.dense_only_patch_tracking.patch_tracking as tracking import jax.numpy as jnp import rerun as rr + +import b3d +import b3d.chisight.dense.dense_only_patch_tracking.patch_tracking as tracking from b3d import Pose from b3d.chisight.dense.dense_only_patch_tracking.model import ( rr_log_uniformpose_meshes_to_image_model_trace, ) - from tests.common.solver import Solver diff --git a/tests/sama4d/video_to_tracks/from_initialization/solvers/particle_system_patch_tracking_solver.py b/tests/sama4d/video_to_tracks/from_initialization/solvers/particle_system_patch_tracking_solver.py index fadfab85..02895172 100644 --- a/tests/sama4d/video_to_tracks/from_initialization/solvers/particle_system_patch_tracking_solver.py +++ b/tests/sama4d/video_to_tracks/from_initialization/solvers/particle_system_patch_tracking_solver.py @@ -1,11 +1,11 @@ -import b3d -import b3d.chisight.dense.differentiable_renderer as diffrend -import b3d.chisight.patch_tracking as tracking import jax import jax.numpy as jnp import rerun as rr -from b3d import Pose +import b3d +import b3d.chisight.dense.differentiable_renderer as diffrend +import b3d.chisight.patch_tracking as tracking +from b3d import Pose from tests.common.solver import Solver diff --git a/tests/sama4d/video_to_tracks/from_initialization/solvers/twod/single_patch_tracker.py b/tests/sama4d/video_to_tracks/from_initialization/solvers/twod/single_patch_tracker.py index be06f6bb..52d70d18 100644 --- a/tests/sama4d/video_to_tracks/from_initialization/solvers/twod/single_patch_tracker.py +++ b/tests/sama4d/video_to_tracks/from_initialization/solvers/twod/single_patch_tracker.py @@ -1,4 +1,3 @@ -import b3d import jax import jax.numpy as jnp import matplotlib.patches as mpatches @@ -6,6 +5,7 @@ from matplotlib.animation import FFMpegWriter, FuncAnimation from matplotlib.gridspec import GridSpec +import b3d from tests.common.solver import Solver diff --git a/tests/sama4d/video_to_tracks/from_initialization/solvers/twod/single_patch_tracker_with_reinitialization.py b/tests/sama4d/video_to_tracks/from_initialization/solvers/twod/single_patch_tracker_with_reinitialization.py index 8d3441e5..efd03f6f 100644 --- a/tests/sama4d/video_to_tracks/from_initialization/solvers/twod/single_patch_tracker_with_reinitialization.py +++ b/tests/sama4d/video_to_tracks/from_initialization/solvers/twod/single_patch_tracker_with_reinitialization.py @@ -1,4 +1,3 @@ -import b3d import jax import jax.numpy as jnp import matplotlib.patches as mpatches @@ -6,6 +5,7 @@ from matplotlib.animation import FFMpegWriter, FuncAnimation from matplotlib.gridspec import GridSpec +import b3d from tests.common.solver import Solver diff --git a/tests/sama4d/video_to_tracks/solvers/conv_with_reinstantiation.py b/tests/sama4d/video_to_tracks/solvers/conv_with_reinstantiation.py index 34741079..a07510cd 100644 --- a/tests/sama4d/video_to_tracks/solvers/conv_with_reinstantiation.py +++ b/tests/sama4d/video_to_tracks/solvers/conv_with_reinstantiation.py @@ -1,6 +1,6 @@ import jax -from b3d.chisight.patch_tracking_2d.patch_tracker import PatchTracker2D +from b3d.chisight.patch_tracking_2d.patch_tracker import PatchTracker2D from tests.common.solver import Solver ### Solver for VideoToTracksTask ### diff --git a/tests/sama4d/video_to_tracks/video_to_tracks_task.py b/tests/sama4d/video_to_tracks/video_to_tracks_task.py index 37729815..3c54e89b 100644 --- a/tests/sama4d/video_to_tracks/video_to_tracks_task.py +++ b/tests/sama4d/video_to_tracks/video_to_tracks_task.py @@ -1,13 +1,13 @@ import warnings from typing import Callable -import b3d import jax.numpy as jnp import matplotlib.animation as animation import matplotlib.pyplot as plt import numpy as np import rerun as rr +import b3d from tests.common.task import Task diff --git a/tests/test_2d_patch_tracker.py b/tests/test_2d_patch_tracker.py index e7315797..58228767 100644 --- a/tests/test_2d_patch_tracker.py +++ b/tests/test_2d_patch_tracker.py @@ -1,5 +1,6 @@ -import b3d import jax + +import b3d from b3d.chisight.patch_tracking_2d.patch_tracker import PatchTracker2D diff --git a/tests/test_bayes3d_model.py b/tests/test_bayes3d_model.py index f3003525..a4f3d7a3 100644 --- a/tests/test_bayes3d_model.py +++ b/tests/test_bayes3d_model.py @@ -1,11 +1,12 @@ -import b3d -import b3d.bayes3d as bayes3d import genjax import jax import jax.numpy as jnp +from genjax import ChoiceMapBuilder as C + +import b3d +import b3d.bayes3d as bayes3d from b3d import Pose from b3d.bayes3d.model import model_multiobject_gl_factory -from genjax import ChoiceMapBuilder as C class TestGroup: diff --git a/tests/test_chisight_dense_gps.py b/tests/test_chisight_dense_gps.py index 418404b5..b5baee97 100644 --- a/tests/test_chisight_dense_gps.py +++ b/tests/test_chisight_dense_gps.py @@ -1,16 +1,17 @@ import importlib -import b3d -import b3d.chisight.particle_system as ps import jax import jax.numpy as jnp +from genjax import Pytree + +import b3d +import b3d.chisight.particle_system as ps from b3d import Mesh, Pose from b3d.chisight.dense.likelihoods import ( KRaysImageLikelihoodArgs, make_krays_image_observation_model, ) from b3d.renderer.renderer_original import RendererOriginal -from genjax import Pytree importlib.reload(ps) diff --git a/tests/test_chisight_sparse_gps.py b/tests/test_chisight_sparse_gps.py index a0fe9371..53fa54b3 100644 --- a/tests/test_chisight_sparse_gps.py +++ b/tests/test_chisight_sparse_gps.py @@ -1,13 +1,14 @@ import importlib -import b3d -import b3d.chisight.particle_system as ps import jax import jax.numpy as jnp -from b3d import Pose from genjax import ChoiceMapBuilder as C from genjax import Pytree +import b3d +import b3d.chisight.particle_system as ps +from b3d import Pose + importlib.reload(ps) diff --git a/tests/test_diff_renderer.py b/tests/test_diff_renderer.py index 423720d7..ad6bd160 100644 --- a/tests/test_diff_renderer.py +++ b/tests/test_diff_renderer.py @@ -1,13 +1,14 @@ from functools import partial -import b3d -import b3d.chisight.dense.differentiable_renderer as rendering import jax import jax.numpy as jnp import optax import rerun as rr from tqdm import tqdm +import b3d +import b3d.chisight.dense.differentiable_renderer as rendering + rr.init("gradients") rr.connect("127.0.0.1:8812") diff --git a/tests/test_image_posterior_resolution_invariance.py b/tests/test_image_posterior_resolution_invariance.py index b7802d94..ad1ae5e7 100644 --- a/tests/test_image_posterior_resolution_invariance.py +++ b/tests/test_image_posterior_resolution_invariance.py @@ -7,18 +7,19 @@ import os import unittest -import b3d -import b3d.bayes3d as bayes3d import genjax import jax import jax.numpy as jnp import matplotlib.pyplot as plt import rerun as rr import trimesh -from b3d import Pose from genjax import Pytree from tqdm import tqdm +import b3d +import b3d.bayes3d as bayes3d +from b3d import Pose + class UpsamplingRenderer(b3d.Renderer): """ diff --git a/tests/test_likelihood_invariances.py b/tests/test_likelihood_invariances.py index b7921820..403afd7f 100644 --- a/tests/test_likelihood_invariances.py +++ b/tests/test_likelihood_invariances.py @@ -1,12 +1,13 @@ import os -import b3d -import b3d.bayes3d as bayes3d import jax import jax.numpy as jnp import rerun as rr import trimesh +import b3d +import b3d.bayes3d as bayes3d + PORT = 8812 rr.init("real") rr.connect(addr=f"127.0.0.1:{PORT}") diff --git a/tests/test_mesh.py b/tests/test_mesh.py index 36c4afdd..3d593372 100644 --- a/tests/test_mesh.py +++ b/tests/test_mesh.py @@ -1,8 +1,9 @@ import unittest -import b3d import jax import jax.numpy as jnp + +import b3d from b3d import Mesh, Pose diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 90d0d072..048e0e9e 100644 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -1,6 +1,7 @@ import genjax import jax import jax.numpy as jnp + from b3d.modeling_utils import ( PythonMixtureDistribution, truncated_color_laplace, diff --git a/tests/test_mug_handle_posterior.py b/tests/test_mug_handle_posterior.py index 4d5a4e2a..68cb73e2 100644 --- a/tests/test_mug_handle_posterior.py +++ b/tests/test_mug_handle_posterior.py @@ -1,15 +1,16 @@ import os -import b3d -import b3d.bayes3d as bayes3d import genjax import jax import jax.numpy as jnp import rerun as rr import trimesh -from b3d import Pose from genjax import Pytree +import b3d +import b3d.bayes3d as bayes3d +from b3d import Pose + PORT = 8812 rr.init("233") rr.connect(addr=f"127.0.0.1:{PORT}") diff --git a/tests/test_mug_smc_pose_inference_synthetic_and_real.py b/tests/test_mug_smc_pose_inference_synthetic_and_real.py index e126c66d..52d47a44 100644 --- a/tests/test_mug_smc_pose_inference_synthetic_and_real.py +++ b/tests/test_mug_smc_pose_inference_synthetic_and_real.py @@ -1,16 +1,17 @@ import os -import b3d -import b3d.bayes3d as bayes3d import genjax import jax import jax.numpy as jnp import rerun as rr import trimesh -from b3d import Pose from genjax import Pytree from tqdm import tqdm +import b3d +import b3d.bayes3d as bayes3d +from b3d import Pose + def test_renderer_full(renderer): PORT = 8812 diff --git a/tests/test_pose.py b/tests/test_pose.py index 92cd0fe9..70565955 100644 --- a/tests/test_pose.py +++ b/tests/test_pose.py @@ -3,9 +3,10 @@ import jax import jax.numpy as jnp import numpy as np -from b3d.pose import Pose, camera_from_position_and_target from jax.scipy.spatial.transform import Rotation as Rot +from b3d.pose import Pose, camera_from_position_and_target + def keysplit(key, *ns): if len(ns) == 0: diff --git a/tests/test_render_ycb_model.py b/tests/test_render_ycb_model.py index 605f34b0..ad7fed55 100644 --- a/tests/test_render_ycb_model.py +++ b/tests/test_render_ycb_model.py @@ -1,10 +1,11 @@ import os -import b3d -import b3d.bayes3d as bayes3d import jax.numpy as jnp import trimesh +import b3d +import b3d.bayes3d as bayes3d + def test_renderer_full(renderer): mesh_path = os.path.join( diff --git a/tests/test_renderer.py b/tests/test_renderer.py index f00ee138..b2c0bac4 100644 --- a/tests/test_renderer.py +++ b/tests/test_renderer.py @@ -1,6 +1,7 @@ -import b3d import jax.numpy as jnp +import b3d + def test_renderer_full(renderer): vertices = jnp.array( diff --git a/tests/test_renderer_original.py b/tests/test_renderer_original.py index d4b09b82..3d52b444 100644 --- a/tests/test_renderer_original.py +++ b/tests/test_renderer_original.py @@ -1,9 +1,10 @@ import importlib -import b3d import jax import jax.numpy as jnp +import b3d + importlib.reload(b3d.renderer.renderer_original)