Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

first pass at ruff linter fixes #68

Merged
merged 9 commits into from
Jul 23, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.5.2
hooks:
- id: ruff
types_or: [ python, pyi, jupyter ]
args: [ --fix ]

- id: ruff-format
args: [ --exclude, kitti.ipynb ]
types_or: [ python, pyi, jupyter ]
2 changes: 1 addition & 1 deletion b3d/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from .pose import Pose, Rot
from . import camera, colors, pose, types, utils
from . import renderer, io, bayes3d, chisight
from .renderer import Renderer, RendererOriginal
from .renderer import RendererOriginal

__version__ = metadata.version("genjax")
__all__ = [
Expand Down
1 change: 0 additions & 1 deletion b3d/bayes3d/enumerative_proposals.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import jax.numpy as jnp
import b3d
from b3d.pose import Pose
import genjax
from genjax import Pytree


Expand Down
1 change: 0 additions & 1 deletion b3d/bayes3d/model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import genjax
from genjax.generative_functions.distributions import ExactDensity
import jax.numpy as jnp
import b3d
from b3d.pose import Pose
Expand Down
2 changes: 0 additions & 2 deletions b3d/bucket_utils/b3d_pull.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
# gcloud storage cp --recursive gs://hgps_data_bucket/shared .
import argparse
import json
import os
import subprocess
import b3d
from pathlib import Path


## Paths.
Expand Down
3 changes: 0 additions & 3 deletions b3d/bucket_utils/b3d_push.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
import argparse
import json
import os
import subprocess
import b3d
from pathlib import Path


## Paths.
Expand Down
2 changes: 0 additions & 2 deletions b3d/chisight/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +0,0 @@
from importlib import metadata
from . import particle_system, dense, sparse
3 changes: 1 addition & 2 deletions b3d/chisight/dense/dense_model.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from b3d.modeling_utils import uniform_discrete, uniform_pose, gaussian_vmf
from b3d.modeling_utils import uniform_pose
import genjax
import b3d
from b3d import Pose, Mesh
import jax
import jax.numpy as jnp
import b3d.chisight.dense.likelihoods.image_likelihood

Expand Down
8 changes: 0 additions & 8 deletions b3d/chisight/dense/likelihoods/image_likelihood.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,4 @@
import genjax
from genjax.generative_functions.distributions import ExactDensity
import jax.numpy as jnp
import b3d
from b3d import Mesh, Pose
from collections import namedtuple
from b3d.modeling_utils import uniform_discrete, uniform_pose
import jax
import os

from genjax import Pytree

Expand Down
9 changes: 0 additions & 9 deletions b3d/chisight/dense/likelihoods/kray_outlier_volume.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,5 @@
import genjax
from genjax.generative_functions.distributions import ExactDensity
import jax.numpy as jnp
import b3d
from b3d import Mesh, Pose
from collections import namedtuple
from b3d.modeling_utils import uniform_discrete, uniform_pose
import jax
import os

from genjax import Pytree


def get_rgb_depth_inliers_from_observed_rendered_args(
Expand Down
5 changes: 0 additions & 5 deletions b3d/chisight/dense/likelihoods/krays.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
import genjax
from genjax.generative_functions.distributions import ExactDensity
import jax.numpy as jnp
import b3d
from b3d import Mesh, Pose
from collections import namedtuple
from b3d.modeling_utils import uniform_discrete, uniform_pose
import jax
import os

from genjax import Pytree

Expand Down
1 change: 0 additions & 1 deletion b3d/chisight/dense/likelihoods/simple_likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
def simple_likelihood(observed_rgbd, rendered_rgbd, likelihood_args):
fx = likelihood_args["fx"]
fy = likelihood_args["fy"]
far = likelihood_args["far"]

rendered_rgb = rendered_rgbd[..., :3]
observed_rgb = observed_rgbd[..., :3]
Expand Down
4 changes: 1 addition & 3 deletions b3d/chisight/particle_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@
import b3d
from b3d import Pose
import jax
import jax.numpy as jnp
import genjax
from genjax import gen
from b3d import Pose, Mesh
from b3d import Mesh
from b3d.chisight.sparse.gps_utils import add_dummy_var
from b3d.pose import uniform_pose_in_ball

Expand Down Expand Up @@ -244,7 +243,6 @@ def visualize_particle_system(
camera_pose_prior_params,
) = latent_particle_model_args

colors = b3d.distinct_colors(num_clusters.const)
absolute_particle_poses = particle_dynamics_summary["absolute_particle_poses"]
object_poses = particle_dynamics_summary["object_poses"]
camera_pose = particle_dynamics_summary["camera_pose"]
Expand Down
1 change: 0 additions & 1 deletion b3d/chisight/sparse/gps_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from b3d.pose.pose_utils import (
uniform_samples_from_disc,
)
from .dynamic_gps import DynamicGPS
from typing import TypeAlias


Expand Down
4 changes: 1 addition & 3 deletions b3d/chisight/sparse/sparse_gps_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .dynamic_gps import DynamicGPS
from typing import Any, TypeAlias
from b3d.camera import screen_from_world
from genjax import ChoiceMapBuilder as C


@genjax.gen
Expand Down Expand Up @@ -108,7 +109,6 @@ def kernel(carried_state, _):
)

new_state = (t + 1, new_object_poses, new_camera_pose, static_carries)
vector_output = None
return (new_state, new_state)

# TODO: What arguments should be passed to the model?
Expand Down Expand Up @@ -279,8 +279,6 @@ def get_dynamic_gps(tr: SparseGPSModelTrace):
# -----------
# Setters
# -----------
from genjax import ChoiceMapBuilder as C
from genjax._src.core.generative.choice_map import EmptyChm


def set_camera_choice(t, cam: Pose, ch=None):
Expand Down
7 changes: 2 additions & 5 deletions b3d/chisight/sparse/sparse_model_reality_check.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
import matplotlib.pyplot as plt
import genjax
import jax
import jax.numpy as jnp
from b3d.utils import keysplit
from b3d.camera import Intrinsics, screen_from_camera
from b3d.pose import Pose, camera_from_position_and_target
from b3d.camera import Intrinsics
from b3d.pose import Pose
from b3d.pose.pose_utils import uniform_pose_in_ball
from b3d.chisight.sparse.gps_utils import add_dummy_var
from b3d.chisight.sparse.sparse_gps_model import minimal_observation_model
Expand Down
6 changes: 1 addition & 5 deletions b3d/io/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@
import cv2
import imageio

import jax.numpy as jnp
import numpy as np
from PIL import Image
from tqdm import tqdm
from b3d import Mesh, Pose
from b3d import Pose
import glob

YCB_MODEL_NAMES = [
Expand Down Expand Up @@ -70,7 +69,6 @@ def get_ycbv_test_images(ycb_dir, scene_id, images_indices, fields=[]):

scene_rgb_images_dir = os.path.join(scene_data_dir, "rgb")
scene_depth_images_dir = os.path.join(scene_data_dir, "depth")
mask_visib_dir = os.path.join(scene_data_dir, "mask_visib")

with open(os.path.join(scene_data_dir, "scene_camera.json")) as scene_cam_data_json:
scene_cam_data = json.load(scene_cam_data_json)
Expand Down Expand Up @@ -156,8 +154,6 @@ def get_ycb_mesh(ycb_dir, id):


def get_ycbineoat_images(ycbineaot_dir, video_name, images_indices):
id_strs = []

video_dir = os.path.join(ycbineaot_dir, f"{video_name}")

color_files = sorted(glob.glob(f"{video_dir}/rgb/*.png"))
Expand Down
11 changes: 0 additions & 11 deletions b3d/io/utils.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,12 @@
import os
from .video_input import VideoInput
from b3d.utils import get_shared
from b3d.types import Array
from b3d.io import VideoInput, FeatureTrackData
import numpy as np
import jax.numpy as jnp
import matplotlib.pyplot as plt
import jax
import cv2
import numpy as np
from sklearn.utils import Bunch
from pathlib import Path
import argparse
import inspect
import sys


def add_argparse(f):
Expand Down Expand Up @@ -129,7 +122,6 @@ def load_video_to_numpy(
T = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = int(cap.get(cv2.CAP_PROP_FPS))

if times is None:
times = np.arange(T, step=step)
Expand Down Expand Up @@ -222,9 +214,6 @@ def plot_video_summary(
reverse_color_channel=reverse_color_channel,
)

w = vid.shape[2]
h = vid.shape[1]

# Create a plot with the summary
# TODO: Should we hand in an axis?
fig, ax = plt.subplots(1, 1, figsize=(15, 4))
Expand Down
1 change: 0 additions & 1 deletion b3d/mesh.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import b3d
import jax.numpy as jnp
from b3d import Pose
import jax
import trimesh
from jax.tree_util import register_pytree_node_class
Expand Down
6 changes: 1 addition & 5 deletions b3d/renderer/renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def _rasterize_bwd(self, saved_tensors, diffs):
)
dy, ddb = diffs

grads = _rasterize_bwd_custom_call(
_rasterize_bwd_custom_call(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note to stare at this one more closely

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@nishadgothoskar can you please comment on this? -- Do we need to call _rasterize_bwd_custom_call here?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(@sritchie I am not familiar with all the code in renderer.py, so we probably need Nishad's input here.)

self,
pose,
pos,
Expand Down Expand Up @@ -458,7 +458,6 @@ def _render_batch(args, axes):
if pose.ndim != 5:
raise NotImplementedError("Underlying primitive must operate on 4D poses.")

original_shape = pose.shape
poses = jnp.moveaxis(pose, axes[0], 0)
size_1 = poses.shape[0]
size_2 = poses.shape[1]
Expand Down Expand Up @@ -749,9 +748,6 @@ def _interpolate_fwd_lowering(
def _render_batch_interp(args, axes):
attributes, uvs, triangle_ids, faces = args

original_shape_uvs = uvs.shape
original_shape_triangle_ids = triangle_ids.shape

uvs = jnp.moveaxis(uvs, axes[1], 0)
size_1 = uvs.shape[0]
size_2 = uvs.shape[1]
Expand Down
13 changes: 4 additions & 9 deletions b3d/renderer/renderer_original.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,8 @@
from jax.interpreters import batching, mlir, xla
from jax.lib import xla_client
from jaxlib.hlo_helpers import custom_call
import functools
import os
import b3d
import b3d.renderer.nvdiffrast_jax.nvdiffrast.jax as dr
import rerun as rr


def projection_matrix_from_intrinsics(w, h, fx, fy, cx, cy, near, far):
Expand Down Expand Up @@ -192,12 +189,12 @@ def _register_custom_calls():


# @functools.partial(jax.jit, static_argnums=(0,))
def _rasterize_fwd_custom_call(r: "Renderer", pos, tri, resolution):
def _rasterize_fwd_custom_call(r: "RendererOriginal", pos, tri, resolution):
return _build_rasterize_fwd_primitive(r).bind(pos, tri, resolution)


@functools.lru_cache(maxsize=None)
def _build_rasterize_fwd_primitive(r: "Renderer"):
def _build_rasterize_fwd_primitive(r: "RendererOriginal"):
_register_custom_calls()
# For JIT compilation we need a function to evaluate the shape and dtype of the
# outputs of our op for some given inputs
Expand Down Expand Up @@ -277,7 +274,6 @@ def _rasterize_fwd_lowering(ctx, pos, tri, resolution):
def _render_batch(args, axes):
pos, tri, resolution = args

original_shape = pos.shape
pos_reshaped = pos.reshape(-1, *pos.shape[-2:])

(rendered,) = _rasterize_fwd_custom_call(r, pos_reshaped, tri, resolution)
Expand All @@ -304,7 +300,7 @@ def _render_batch(args, axes):

# @functools.partial(jax.jit, static_argnums=(0,))
def _interpolate_fwd_custom_call(
r: "Renderer",
r: "RendererOriginal",
attributes,
rast,
faces,
Expand All @@ -317,7 +313,7 @@ def _interpolate_fwd_custom_call(


# @functools.lru_cache(maxsize=None)
def _build_interpolate_fwd_primitive(r: "Renderer"):
def _build_interpolate_fwd_primitive(r: "RendererOriginal"):
_register_custom_calls()
# For JIT compilation we need a function to evaluate the shape and dtype of the
# outputs of our op for some given inputs
Expand Down Expand Up @@ -390,7 +386,6 @@ def _interpolate_fwd_lowering(
def _interpolate_batch(args, axes):
attributes, rast, faces = args

original_shape = attributes.shape
attributes_reshaped = attributes.reshape(-1, *attributes.shape[-2:])
rast_reshaped = rast.reshape(-1, *rast.shape[-3:])

Expand Down
3 changes: 0 additions & 3 deletions b3d/renderer/torch/pose.py
Original file line number Diff line number Diff line change
@@ -1,3 +0,0 @@
import torch
import pytorch3d.transforms
import pytorch3d.renderer
4 changes: 1 addition & 3 deletions b3d/renderer/torch/renderutils/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,7 @@
# its affiliates is strictly prohibited.

import os
import sys

import numpy as np
import torch
import torch.utils.cpp_extension

Expand Down Expand Up @@ -75,7 +73,7 @@ def find_cl_path():
)
if os.path.exists(lock_fn):
print("Warning: Lock file exists in build directory: '%s'" % lock_fn)
except:
except Exception:
pass

# Compile and load.
Expand Down
Loading