Skip to content

Commit

Permalink
fix final ruff errors
Browse files Browse the repository at this point in the history
  • Loading branch information
sritchie committed Jan 23, 2024
1 parent 455d843 commit df18958
Show file tree
Hide file tree
Showing 21 changed files with 62 additions and 57 deletions.
10 changes: 1 addition & 9 deletions bayes3d/__init__.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,12 @@
"""
.. include:: ./documentation.md
"""

from .camera import *
from .likelihood import *
from .renderer import *
from .rgbd import *
from .transforms_3d import *
from .viz import *

try:
import genjax

from .genjax import *
except ImportError as e:
print("GenJAX not installed. Importing bayes3d without genjax dependencies.")
print(e)


RENDERER = None
13 changes: 10 additions & 3 deletions bayes3d/_mkl/gaussian_renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,19 @@
normal_pdf = jax.scipy.stats.norm.pdf
normal_logpdf = jax.scipy.stats.norm.logpdf
inv = jnp.linalg.inv
from bayes3d._mkl.types import (
Array,
CholeskyMatrix,
CovarianceMatrix,
Direction,
Float,
Matrix,
PrecisionMatrix,
Vector,
)

key = jax.random.PRNGKey(0)

# %% ../../scripts/_mkl/notebooks/06a - Gaussian Renderer.ipynb 5
from bayes3d._mkl.types import *


# %% ../../scripts/_mkl/notebooks/06a - Gaussian Renderer.ipynb 6
def ellipsoid_embedding(cov: CovarianceMatrix) -> Matrix:
Expand Down
17 changes: 9 additions & 8 deletions bayes3d/_mkl/simple_likelihood.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../../scripts/_mkl/notebooks/10 - Simple Likelihood.ipynb.

# %% auto 0
import genjax
import genjax._src.generative_functions.distributions.tensorflow_probability as gentfp
import jax
import jax.numpy as jnp
import tensorflow_probability.substrates.jax as tfp
from genjax._src.generative_functions.distributions.distribution import ExactDensity

from bayes3d._mkl.utils import *

__all__ = [
"key",
"tfd",
Expand All @@ -20,17 +29,10 @@
]

# %% ../../scripts/_mkl/notebooks/10 - Simple Likelihood.ipynb 3
import genjax
import jax
import jax.numpy as jnp

from bayes3d._mkl.utils import *

key = jax.random.PRNGKey(0)

# %% ../../scripts/_mkl/notebooks/10 - Simple Likelihood.ipynb 4
import genjax._src.generative_functions.distributions.tensorflow_probability as gentfp
import tensorflow_probability.substrates.jax as tfp

tfd = tfp.distributions

Expand Down Expand Up @@ -177,7 +179,6 @@ def sensor_model(Y, sig, out):


# %% ../../scripts/_mkl/notebooks/10 - Simple Likelihood.ipynb 15
from genjax._src.generative_functions.distributions.distribution import ExactDensity


def wrap_into_dist(score_func):
Expand Down
18 changes: 9 additions & 9 deletions bayes3d/_mkl/trimesh_to_gaussians.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,15 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../../scripts/_mkl/notebooks/05 - Trimesh to Gaussians.ipynb.

# %% auto 0
import jax
import jax.numpy as jnp
import jaxlib
import numpy as np
import trimesh
from jax import jit, vmap

from bayes3d._mkl.utils import *

__all__ = [
"Array",
"Shape",
Expand Down Expand Up @@ -156,15 +165,6 @@
"""

# %% ../../scripts/_mkl/notebooks/05 - Trimesh to Gaussians.ipynb 3
import jax
import jax.numpy as jnp
import jaxlib
import numpy as np
import trimesh
from jax import jit, vmap

from bayes3d._mkl.utils import *

Array = np.ndarray | jax.Array
Shape = int | tuple[int, ...]
FaceIndex = int
Expand Down
2 changes: 1 addition & 1 deletion bayes3d/colmap/colmap_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def read_points3D_binary(path_to_model_file):
track_length = read_next_bytes(fid, num_bytes=8, format_char_sequence="Q")[
0
]
track_elems = read_next_bytes(
_track_elems = read_next_bytes(
fid,
num_bytes=8 * track_length,
format_char_sequence="ii" * track_length,
Expand Down
9 changes: 7 additions & 2 deletions bayes3d/genjax/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,12 @@

import bayes3d as b

from .genjax_distributions import *
from .genjax_distributions import (
contact_params_uniform,
image_likelihood,
uniform_discrete,
uniform_pose,
)


@genjax.static
Expand Down Expand Up @@ -64,7 +69,7 @@ def model(array, possible_object_indices, pose_bounds, contact_bounds, all_box_d

variance = genjax.uniform(0.00000000001, 10000.0) @ "variance"
outlier_prob = genjax.uniform(-0.01, 10000.0) @ "outlier_prob"
image = image_likelihood(rendered, variance, outlier_prob) @ "image"
_image = image_likelihood(rendered, variance, outlier_prob) @ "image"
return (
rendered,
indices,
Expand Down
9 changes: 4 additions & 5 deletions bayes3d/neural/cosypose_baseline/cosypose_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,17 @@
import signal
import subprocess
import sys
import time

import numpy as np
import torch
import yaml

cosypose_path = (
f"{os.path.dirname(os.path.abspath(__file__))}/cosypose_baseline/cosypose"
)
sys.path.append(cosypose_path) # TODO cleaner import / add to path

import time

import torch
import yaml

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
Expand Down Expand Up @@ -211,7 +210,7 @@ def cosypose_interface(rgb_imgs, camera_k):

pred_poses = np.asarray(pred.poses.cpu())
pred_ids = [
int(l[-3:]) - 1 for l in pred.infos.label
int(label[-3:]) - 1 for label in pred.infos.label
] # ex) 'obj_000014' for GT_IDX 13
pred_scores = [pred.infos.iloc[i].score for i in range(len(pred.infos))]

Expand Down
8 changes: 4 additions & 4 deletions bayes3d/renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,10 +382,10 @@ def _load_vertices_lowering(ctx, vertices, triangles):
# Extract the numpy type of the inputs
vertices_aval, triangles_aval = ctx.avals_in

if np.dtype(vertices_aval.dtype) != np.float32:
raise NotImplementedError(f"Unsupported vertices dtype {np_dtype}")
if np.dtype(triangles_aval.dtype) != np.int32:
raise NotImplementedError(f"Unsupported triangles dtype {np_dtype}")
if (dt := np.dtype(vertices_aval.dtype)) != np.float32:
raise NotImplementedError(f"Unsupported vertices dtype {dt}")
if (dt := np.dtype(triangles_aval.dtype)) != np.int32:
raise NotImplementedError(f"Unsupported triangles dtype {dt}")

opaque = dr._get_plugin(gl=True).build_load_vertices_descriptor(
r.renderer_env.cpp_wrapper, vertices_aval.shape[0], triangles_aval.shape[0]
Expand Down
1 change: 0 additions & 1 deletion bayes3d/scene_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,6 @@ def add_edge_scene_graph(
scene_graph, parent, child, face_parent, face_child, contact_params
):
print(parent, child, face_parent, face_child)
N = scene_graph.get_poses().shape[0]
sg_parents = jnp.array(scene_graph.parents)
sg_parents = sg_parents.at[child].set(parent)
sg_contact_params = jnp.array(scene_graph.contact_params)
Expand Down
2 changes: 1 addition & 1 deletion bayes3d/utils/occlusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def voxel_occupied_occluded_free(camera_pose, depth_image, grid, intrinsics, tol
occupied = jnp.abs(real_depth_vals - projected_depth_vals) < tolerance
occluded = real_depth_vals < projected_depth_vals
occluded = occluded * (1.0 - occupied)
free = (1.0 - occluded) * (1.0 - occupied)
_free = (1.0 - occluded) * (1.0 - occupied)
return 1.0 * occupied + 0.5 * occluded


Expand Down
4 changes: 2 additions & 2 deletions bayes3d/utils/pybullet_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def pybullet_render(scene):
return image, depth


def create_box(
def create_box_from_pose(
pose,
scale=[1, 1, 1],
restitution=1,
Expand Down Expand Up @@ -500,7 +500,7 @@ def simulate(self, timesteps):
return pyb

def close(self):
if self.pyb_sim == None:
if self.pyb_sim is None:
raise ValueError("No pybullet simulation to close")
else:
p.disconnect(self.pyb_sim.client)
Expand Down
3 changes: 1 addition & 2 deletions bayes3d/utils/r3d_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,9 @@ def load_r3d(r3d_path):

color_paths = natsorted(glob.glob(os.path.join(datapath, "rgbd", "*.jpg")))
depth_paths = natsorted(glob.glob(os.path.join(datapath, "rgbd", "*.depth")))
conf_paths = natsorted(glob.glob(os.path.join(datapath, "rgbd", "*.conf")))

colors = np.array([load_color(color_paths[i]) for i in range(len(color_paths))])
depths = np.array([load_depth(depth_paths[i]) for i in range(len(color_paths))])
depths = np.array([load_depth(depth_paths[i]) for i in range(len(depth_paths))])
depths[np.isnan(depths)] = 0.0

poses = get_poses(metadata)
Expand Down
2 changes: 1 addition & 1 deletion bayes3d/utils/ycb_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def get_test_img(scene_id, img_id, ycb_dir):
cam_depth_scale = image_cam_data["depth_scale"]

# get {visible mask, ID, pose} for each object in the scene
anno = dict()
# anno = dict()

# get GT object model ID+poses
objects_gt_data = scene_imgs_gt_data[remove_zero_pad(img_id)]
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ line-length = 88
indent-width = 4

[tool.ruff.lint]
exclude = ["bayes3d/_mkl/*.py"]
extend-select = ["I"]
select = ["E4", "E7", "E9", "F"]

Expand Down
2 changes: 2 additions & 0 deletions scripts/_mkl/notebooks/kubric/kubric_helper.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import kubric as kb
import numpy as np

rng = np.random.default_rng(2021)


def get_linear_camera_motion_start_end(
movement_speed: float,
Expand Down
2 changes: 1 addition & 1 deletion scripts/experiments/collaborations/arijit_physics.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def body_fun(prev):
def model(T):
pose = b.uniform_pose(jnp.ones(3) * -1.0, jnp.ones(3) * 1.0) @ "init_pose"
velocity = b.gaussian_vmf_pose(jnp.eye(4), 0.01, 10000.0) @ "init_velocity"
evolve = (
_evolve = (
genjax.UnfoldCombinator.new(body_fun, 100)(50, (0, pose, velocity)) @ "dynamics"
)
return 1.0
Expand Down
2 changes: 1 addition & 1 deletion scripts/experiments/colmap/colmap_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def read_points3D_binary(path_to_model_file):
track_length = read_next_bytes(fid, num_bytes=8, format_char_sequence="Q")[
0
]
track_elems = read_next_bytes(
_track_elems = read_next_bytes(
fid,
num_bytes=8 * track_length,
format_char_sequence="ii" * track_length,
Expand Down
8 changes: 4 additions & 4 deletions scripts/experiments/mcs/otp_gen/otp_gen/physics_priors.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,7 @@ def physics_prior_v1(prev_pose, prev_prev_pose, bbox_dims, camera_pose, world2ca

# I1 -> Integrate X-Y-Z forward to current time step
# jprint("pred pos: {}", camera_pose[:3,:] @ jnp.concatenate([pred_pos, 1], axis = None))
physics_estimated_pose = jnp.copy(prev_pose) # orientation is the same
physics_estimated_pose = physics_estimated_pose.at[:3, 3].set(pred_pos)
physics_estimated_pose = prev_pose.at[:3, 3].set(pred_pos)

return physics_estimated_pose

Expand Down Expand Up @@ -152,8 +151,9 @@ def physics_prior_v2(

# I1 -> Integrate X-Y-Z forward to current time step
# jprint("pred pos: {}", camera_pose[:3,:] @ jnp.concatenate([pred_pos, 1], axis = None))
physics_estimated_pose = jnp.copy(prev_pose) # orientation is the same
physics_estimated_pose = physics_estimated_pose.at[:3, 3].set(pred_pos)
# NOTE @sritchie and @colin modified this line in response to a Ruff error;
# we'll flag this in code review but this could be buggy.
physics_estimated_pose = prev_poses[T].at[:3, 3].set(pred_pos)

return physics_estimated_pose

Expand Down
2 changes: 1 addition & 1 deletion scripts/experiments/tabletop/data_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
import genjax
import jax
import jax.numpy as jnp
import joblib

import bayes3d as b
import bayes3d.genjax

console = genjax.pretty(show_locals=False)
import joblib

intrinsics = b.Intrinsics(
height=100, width=100, fx=500.0, fy=500.0, cx=50.0, cy=50.0, near=0.01, far=20.0
Expand Down
2 changes: 1 addition & 1 deletion scripts/experiments/tabletop/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
import genjax
import jax
import jax.numpy as jnp
import joblib
from tqdm import tqdm

import bayes3d as b
import bayes3d.genjax

console = genjax.pretty(show_locals=False)
import joblib

intrinsics = b.Intrinsics(
height=100, width=100, fx=500.0, fy=500.0, cx=50.0, cy=50.0, near=0.01, far=20.0
Expand Down
2 changes: 1 addition & 1 deletion test/test_genjax_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def test_genjax_trace_contains_right_info():
),
)

scores = enumerators.enumerate_choices_get_scores(trace, key, jnp.zeros((100, 3)))
_scores = enumerators.enumerate_choices_get_scores(trace, key, jnp.zeros((100, 3)))

assert trace["parent_0"] == -1
assert (trace["camera_pose"] == jnp.eye(4)).all()
Expand Down

0 comments on commit df18958

Please sign in to comment.