From 455d843c87a430f6b815e7c7baae8a3ce1d26038 Mon Sep 17 00:00:00 2001
From: Sam Ritchie
Date: Tue, 23 Jan 2024 10:57:38 -0500
Subject: [PATCH 1/6] first pass on pre-commit

---
 .github/workflows/pre-commit.yaml | 19 +
 .pre-commit-config.yaml | 17 +
 DEVELOPING.md | 34 ++
 bayes3d/__init__.py | 12 +-
 bayes3d/_mkl/gaussian_particle_system_v0.py | 41 +-
 bayes3d/_mkl/gaussian_renderer.py | 174 ++---
 bayes3d/_mkl/gaussian_sensor_model.py | 26 +-
 bayes3d/_mkl/generic.py | 92 ++--
 bayes3d/_mkl/plotting.py | 92 ++--
 bayes3d/_mkl/pose.py | 257 ++++----
 bayes3d/_mkl/simple_likelihood.py | 154 +++---
 bayes3d/_mkl/table_scene_model.py | 118 ++---
 bayes3d/_mkl/trimesh_to_gaussians.py | 218 ++++----
 bayes3d/_mkl/types.py | 51 +-
 bayes3d/_mkl/utils.py | 130 +++--
 bayes3d/camera.py | 97 ++--
 bayes3d/colmap/colmap_loader.py | 183 ++++---
 bayes3d/colmap/colmap_utils.py | 97 ++--
 bayes3d/colmap/dataset_loader.py | 142 ++++--
 bayes3d/distributions.py | 75 +--
 bayes3d/genjax/genjax_distributions.py | 37 +-
 bayes3d/genjax/model.py | 266 +++++++---
 bayes3d/likelihood.py | 83 ++-
 .../cosypose_baseline/cosypose_utils.py | 185 ++++---
 bayes3d/neural/dino.py | 357 ++++++++-----
 bayes3d/neural/segmentation.py | 16 +-
 bayes3d/renderer.py | 222 ++++----
 .../rendering/nvdiffrast/common/__init__.py | 3 +-
 bayes3d/rendering/nvdiffrast/common/ops.py | 168 ++++---
 .../rendering/nvdiffrast_jax/jax_renderer.py | 474 +++++++++++++-----
 .../nvdiffrast/common/__init__.py | 3 +-
 .../nvdiffrast_jax/nvdiffrast/common/ops.py | 168 ++++---
 .../nvdiffrast_jax/nvdiffrast/jax/__init__.py | 3 +-
 .../nvdiffrast_jax/nvdiffrast/jax/ops.py | 159 +++---
 .../renderer_matching_pytorch.py | 297 +++++++----
 bayes3d/rendering/nvdiffrast_jax/setup.py | 48 +-
 .../_kubric_exec_parallel.py | 36 +-
 .../kubric_interface.py | 99 ++--
 bayes3d/rgbd.py | 45 +-
 bayes3d/scene_graph.py | 200 +++++---
 bayes3d/transforms_3d.py | 163 +++---
 bayes3d/utils/__init__.py | 10 +-
 bayes3d/utils/bbox.py | 27 +-
 bayes3d/utils/enumerations.py | 156 ++++--
 bayes3d/utils/gaussian_splatting.py | 205 +++++---
 bayes3d/utils/icp.py | 114 +++--
 bayes3d/utils/mesh.py | 138 +++--
 bayes3d/utils/occlusion.py | 31 +-
 bayes3d/utils/pybullet_sim.py | 459 ++++++++++++-----
 bayes3d/utils/r3d_loader.py | 42 +-
 bayes3d/utils/utils.py | 254 ++++++----
 bayes3d/utils/ycb_loader.py | 138 ++---
 bayes3d/viz/meshcatviz.py | 45 +-
 bayes3d/viz/open3dviz.py | 161 +++---
 bayes3d/viz/viz.py | 230 ++++---
 demo.py | 79 +--
 pyproject.toml | 43 ++
 .../_mkl/notebooks/kubric/kubric_helper.py | 34 +-
 scripts/_mkl/notebooks/nbexporter.py | 23 +-
 .../collaborations/arijit_physics.py | 22 +-
 scripts/experiments/colmap/colmap_loader.py | 182 ++++---
 scripts/experiments/colmap/dataset_loader.py | 160 ++++--
 scripts/experiments/colmap/run.py | 107 ++--
 .../kubric_dataset_gen/kubric_dataset_gen.py | 61 ++-
 scripts/experiments/deeplearning/sam/sam.py | 76 +--
 .../gaussian_splatting/optimization.py | 18 +-
 .../gaussian_splatting/splatting_simple.ipynb | 2 +-
 .../icra/camera_pose_tracking/util.py | 90 ++--
 .../mcs/cognitive-battery/model.py | 8 +-
 .../mcs/cognitive-battery/scene_graph.py | 4 +-
 .../mcs/otp_gen/otp_gen/physics_priors.py | 82 +--
 scripts/experiments/tabletop/data_gen.py | 97 ++--
 scripts/experiments/tabletop/inference.py | 109 ++--
 scripts/run_colmap.py | 107 ++--
 scripts/ssh.py | 16 +-
 test/test_bbox_intersect.py | 60 ++-
 test/test_colmap.py | 27 +-
 test/test_cosypose.py | 17 +-
 test/test_differentiable_rendering.py | 51 +-
 test/test_genjax_model.py | 78 +--
 test/test_icp.py | 64 ++-
 test/test_kubric.py | 49 +-
 test/test_likelihood.py | 17 +-
 test/test_open3d.py | 51 +-
 test/test_renderer.py | 129 ++---
 test/test_renderer_internals.py | 71 +--
 test/test_renderer_memory.py | 49 +-
 test/test_scene_graph.py | 36 +-
 test/test_splatting.py | 72 ++-
 test/test_transforms_3d.py | 5 +-
 test/test_viz.py | 63 ++-
 test/test_ycb_loading.py | 116 +++--
 92 files changed, 5937 insertions(+), 3409 deletions(-)
 create mode 100644 .github/workflows/pre-commit.yaml
 create mode 100644 .pre-commit-config.yaml
 create mode 100644 DEVELOPING.md
 create mode 100644 pyproject.toml

diff --git a/.github/workflows/pre-commit.yaml b/.github/workflows/pre-commit.yaml
new file mode 100644
index 00000000..41739123
--- /dev/null
+++ b/.github/workflows/pre-commit.yaml
@@ -0,0 +1,19 @@
+name: pre-commit hooks
+
+on:
+  pull_request:
+  push:
+    branches:
+    - main
+
+jobs:
+  pre-commit:
+    runs-on: ubuntu-latest
+    steps:
+    - uses: actions/checkout@v4
+
+    - uses: actions/setup-python@v4
+      with:
+        python-version: 3.11.5
+
+    - uses: pre-commit/action@v3.0.0
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 00000000..10b2a03d
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,17 @@
+repos:
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v2.3.0
+    hooks:
+      - id: check-yaml
+        args: [--unsafe]
+      - id: end-of-file-fixer
+      - id: trailing-whitespace
+
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    rev: v0.1.11
+    hooks:
+      - id: ruff
+        # types_or: [ python, pyi, jupyter ]
+
+      - id: ruff-format
+        # types_or: [ python, pyi, jupyter ]
diff --git a/DEVELOPING.md b/DEVELOPING.md
new file mode 100644
index 00000000..420b15b9
--- /dev/null
+++ b/DEVELOPING.md
@@ -0,0 +1,34 @@
+# Developer's Guide
+
+This guide describes how to complete various tasks you'll encounter when working
+on the Bayes3D codebase.
+
+### Commit Hooks
+
+We use [pre-commit](https://pre-commit.com/) to manage a series of git
+pre-commit hooks for the project; for example, each time you commit code, the
+hooks check that your Python is formatted properly. If it isn't, the
+formatting hook rewrites the offending files, so the commit succeeds when you
+retry it.
+
+All hooks are defined in `.pre-commit-config.yaml`. To install these hooks,
+install `pre-commit` if you don't yet have it. I prefer using
+[pipx](https://github.com/pipxproject/pipx) so that `pre-commit` stays globally
+available.
+
+```bash
+pipx install pre-commit
+```
+
+Then install the hooks with this command:
+
+```bash
+pre-commit install
+```
+
+Now they'll run on every commit. If you want to run them manually, run the
+following command:
+
+```bash
+pre-commit run --all-files
+```
diff --git a/bayes3d/__init__.py b/bayes3d/__init__.py
index 0414da88..2d18c3df 100644
--- a/bayes3d/__init__.py
+++ b/bayes3d/__init__.py
@@ -1,24 +1,20 @@
 """
 .. include:: ./documentation.md
 """
-from .transforms_3d import *
+from .camera import *
+from .likelihood import *
 from .renderer import *
 from .rgbd import *
-from .likelihood import *
-from .camera import *
+from .transforms_3d import *
 from .viz import *
-from . import utils
-from . import distributions
-from . import scene_graph
-from . import colmap
 
 try:
     import genjax
+
     from .genjax import *
 except ImportError as e:
     print("GenJAX not installed. 
Importing bayes3d without genjax dependencies.") print(e) - RENDERER = None diff --git a/bayes3d/_mkl/gaussian_particle_system_v0.py b/bayes3d/_mkl/gaussian_particle_system_v0.py index 543ce60d..880daec0 100644 --- a/bayes3d/_mkl/gaussian_particle_system_v0.py +++ b/bayes3d/_mkl/gaussian_particle_system_v0.py @@ -1,40 +1,39 @@ # AUTOGENERATED! DO NOT EDIT! File to edit: ../../scripts/_mkl/notebooks/07a - Gaussian particle system v0.ipynb. # %% auto 0 -__all__ = ['normal_cdf', 'normal_pdf', 'normal_logpdf', 'inv', 'key', 'Array', 'Shape', 'Bool', 'Float', 'Int', 'Pose'] +__all__ = [ + "normal_cdf", + "normal_pdf", + "normal_logpdf", + "inv", + "key", + "Array", + "Shape", + "Bool", + "Float", + "Int", + "Pose", +] # %% ../../scripts/_mkl/notebooks/07a - Gaussian particle system v0.ipynb 2 -import bayes3d as b3d -import trimesh -import os -from bayes3d._mkl.utils import * -import matplotlib.pyplot as plt -import numpy as np import jax -from jax import jit, vmap import jax.numpy as jnp -from jax.scipy.spatial.transform import Rotation as Rot -from functools import partial -import genjax -from bayes3d.camera import Intrinsics, K_from_intrinsics, camera_rays_from_intrinsics -from bayes3d.transforms_3d import transform_from_pos_target_up, add_homogenous_ones, unproject_depth -import tensorflow_probability as tfp -from tensorflow_probability.substrates.jax.math import lambertw - +import numpy as np +from bayes3d._mkl.utils import * -normal_cdf = jax.scipy.stats.norm.cdf -normal_pdf = jax.scipy.stats.norm.pdf +normal_cdf = jax.scipy.stats.norm.cdf +normal_pdf = jax.scipy.stats.norm.pdf normal_logpdf = jax.scipy.stats.norm.logpdf inv = jnp.linalg.inv key = jax.random.PRNGKey(0) # %% ../../scripts/_mkl/notebooks/07a - Gaussian particle system v0.ipynb 4 -from typing import Any, NamedTuple -import numpy as np +from typing import NamedTuple + import jax -import jaxlib +import numpy as np Array = np.ndarray | jax.Array Shape = int | tuple[int, ...] diff --git a/bayes3d/_mkl/gaussian_renderer.py b/bayes3d/_mkl/gaussian_renderer.py index a912664f..5e56441d 100644 --- a/bayes3d/_mkl/gaussian_renderer.py +++ b/bayes3d/_mkl/gaussian_renderer.py @@ -1,29 +1,37 @@ # AUTOGENERATED! DO NOT EDIT! File to edit: ../../scripts/_mkl/notebooks/06a - Gaussian Renderer.ipynb. 
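(The alias block in the `gaussian_particle_system_v0.py` hunk above — `Array = np.ndarray | jax.Array`, `Shape = int | tuple[int, ...]`, and friends — carries no runtime checking; it exists so that signatures across the `_mkl` modules read semantically. A minimal sketch of the pattern, assuming Python 3.10+ for the `|` union syntax; the `Vector` and `squared_norm` names here are illustrative, not part of the patch:)

```python
import jax
import jax.numpy as jnp
import numpy as np

# Semantic aliases: documentation only, nothing is enforced at runtime.
Array = np.ndarray | jax.Array
Vector = Array  # illustrative, mirroring the aliases in bayes3d/_mkl/types.py


def squared_norm(x: Vector) -> Array:
    # Accepts both numpy and jax arrays, since Array is just a union alias.
    return jnp.sum(x**2)


print(squared_norm(jnp.array([3.0, 4.0])))  # 25.0
```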
# %% auto 0 -__all__ = ['normal_cdf', 'normal_pdf', 'normal_logpdf', 'inv', 'key', 'cast_rays', 'ellipsoid_embedding', 'bilinear', - 'log_gaussian', 'gaussian', 'gaussian_normalizing_constant', 'gaussian_restriction_to_ray', - 'discrete_arrival_probabilities', 'gaussian_time_of_arrival', 'gaussian_most_likely_time_of_arrival', - 'weighted_arrival_intersection', 'argmax_intersection', 'weighted_argmax_intersection'] +__all__ = [ + "normal_cdf", + "normal_pdf", + "normal_logpdf", + "inv", + "key", + "cast_rays", + "ellipsoid_embedding", + "bilinear", + "log_gaussian", + "gaussian", + "gaussian_normalizing_constant", + "gaussian_restriction_to_ray", + "discrete_arrival_probabilities", + "gaussian_time_of_arrival", + "gaussian_most_likely_time_of_arrival", + "weighted_arrival_intersection", + "argmax_intersection", + "weighted_argmax_intersection", +] # %% ../../scripts/_mkl/notebooks/06a - Gaussian Renderer.ipynb 3 -import bayes3d as b3d -import trimesh -import os -from bayes3d._mkl.utils import * -import matplotlib.pyplot as plt -import numpy as np import jax -from jax import jit, vmap import jax.numpy as jnp -from functools import partial -from bayes3d.camera import Intrinsics, K_from_intrinsics, camera_rays_from_intrinsics -from bayes3d.transforms_3d import transform_from_pos_target_up, add_homogenous_ones, unproject_depth -import tensorflow_probability as tfp +from jax import jit, vmap from tensorflow_probability.substrates.jax.math import lambertw -normal_cdf = jax.scipy.stats.norm.cdf -normal_pdf = jax.scipy.stats.norm.pdf +from bayes3d._mkl.utils import * + +normal_cdf = jax.scipy.stats.norm.cdf +normal_pdf = jax.scipy.stats.norm.pdf normal_logpdf = jax.scipy.stats.norm.logpdf inv = jnp.linalg.inv @@ -32,57 +40,64 @@ # %% ../../scripts/_mkl/notebooks/06a - Gaussian Renderer.ipynb 5 from bayes3d._mkl.types import * + # %% ../../scripts/_mkl/notebooks/06a - Gaussian Renderer.ipynb 6 -def ellipsoid_embedding(cov:CovarianceMatrix) -> Matrix: +def ellipsoid_embedding(cov: CovarianceMatrix) -> Matrix: """Returns A with cov = A@A.T""" sigma, U = jnp.linalg.eigh(cov) D = jnp.diag(jnp.sqrt(sigma)) return U @ D @ jnp.linalg.inv(U) + # %% ../../scripts/_mkl/notebooks/06a - Gaussian Renderer.ipynb 7 -def bilinear(x:Array, y:Array, A:Matrix) -> Float: +def bilinear(x: Array, y: Array, A: Matrix) -> Float: return x.T @ A @ y + # %% ../../scripts/_mkl/notebooks/06a - Gaussian Renderer.ipynb 8 -def log_gaussian(x:Vector, mu:Vector, P:PrecisionMatrix) -> Float: +def log_gaussian(x: Vector, mu: Vector, P: PrecisionMatrix) -> Float: """Evaluate an **unnormalized** gaussian at a given point.""" - return -0.5 * bilinear(x-mu, x-mu, P) + return -0.5 * bilinear(x - mu, x - mu, P) -def gaussian(x:Vector, mu:Vector, P:PrecisionMatrix) -> Float: +def gaussian(x: Vector, mu: Vector, P: PrecisionMatrix) -> Float: """Evaluate an **unnormalized** gaussian at a given point.""" - return jnp.exp(-0.5 * bilinear(x-mu, x-mu, P)) + return jnp.exp(-0.5 * bilinear(x - mu, x - mu, P)) -def gaussian_normalizing_constant(P:PrecisionMatrix) -> Float: +def gaussian_normalizing_constant(P: PrecisionMatrix) -> Float: """Returns the normalizing constant of an unnormalized gaussian.""" n = P.shape[0] - return jnp.sqrt(jnp.linalg.det(P)/(2*jnp.pi)**n) + return jnp.sqrt(jnp.linalg.det(P) / (2 * jnp.pi) ** n) + # %% ../../scripts/_mkl/notebooks/06a - Gaussian Renderer.ipynb 9 -def gaussian_restriction_to_ray(loc:Vector, P:PrecisionMatrix, A:CholeskyMatrix, x:Vector, v:Direction): +def gaussian_restriction_to_ray( + loc: Vector, P: 
PrecisionMatrix, A: CholeskyMatrix, x: Vector, v: Direction
+):
     """
-    Restricts a gaussian to a ray and returns
-    the mean `mu` and standard deviation `std`, s.t. we have
+    Restricts a gaussian to a ray and returns
+    the mean `mu` and standard deviation `std`, s.t. we have
     $$
     P(x + t*v | loc, cov) = P(x + mu*v | loc, cov) * N(t | mu, std)
     $$
     """
-    mu = bilinear(loc - x, v, P)/bilinear(v, v, P)
-    std = 1/jnp.linalg.norm(inv(A)@v)
+    mu = bilinear(loc - x, v, P) / bilinear(v, v, P)
+    std = 1 / jnp.linalg.norm(inv(A) @ v)
     return mu, std
 
+
 # %% ../../scripts/_mkl/notebooks/06a - Gaussian Renderer.ipynb 11
-def discrete_arrival_probabilities(occupancy_probs:Vector):
+def discrete_arrival_probabilities(occupancy_probs: Vector):
     """
-    Given an vector of `n` occupancy probabilities of neighbouring pixels,
-    it returns a vector of length `n+1` containing the probabilities of stopping
+    Given a vector of `n` occupancy probabilities of neighbouring pixels,
+    it returns a vector of length `n+1` containing the probabilities of stopping
     at each pixel (while traversing them left to right) or not stopping at all.
     The return array is given by:
     $$
     q_i = p_i \cdot \prod_{j=0}^{i-1} (1 - p_j)
-
+
     $$
     for $i=0,...,n-1$, and
     $$
@@ -95,10 +110,13 @@ def discrete_arrival_probabilities(occupancy_probs:Vector):
     X(T) = \sigma(T)*\exp(- \int_0^T \sigma(t) \ dt).
     $$
     """
-    transmittances = jnp.concatenate([jnp.array([1.0]), jnp.cumprod(1-occupancy_probs)])
+    transmittances = jnp.concatenate(
+        [jnp.array([1.0]), jnp.cumprod(1 - occupancy_probs)]
+    )
     extended_occupancies = jnp.concatenate([occupancy_probs, jnp.array([1.0])])
     return extended_occupancies * transmittances
+
 # %% ../../scripts/_mkl/notebooks/06a - Gaussian Renderer.ipynb 13
 def gaussian_time_of_arrival(xs, mu, sig, w=1.0):
     """
@@ -108,81 +126,103 @@ def gaussian_time_of_arrival(xs, mu, sig, w=1.0):
     Y(T) = w*g(T | \mu, \sigma)*\exp(- \int_0^T w*g(t | \mu, \sigma) \ dt).
     $$
     """
-    ys = w*normal_pdf(xs, loc=mu, scale=sig) * jnp.exp(
-        - w*normal_cdf(xs, loc=mu, scale=sig)
-        + w*normal_cdf(0.0, loc=mu, scale=sig))
-    return ys
-
-
-def gaussian_most_likely_time_of_arrival(mu, sig, w=1.):
+    ys = (
+        w
+        * normal_pdf(xs, loc=mu, scale=sig)
+        * jnp.exp(
+            -w * normal_cdf(xs, loc=mu, scale=sig)
+            + w * normal_cdf(0.0, loc=mu, scale=sig)
+        )
+    )
+    return ys
+
+
+def gaussian_most_likely_time_of_arrival(mu, sig, w=1.0):
     """
     Returns the most likely time of first arrival
-    for a single weighted 1-dimensional Gaussian, i.e. the argmax of
+    for a single weighted 1-dimensional Gaussian, i.e. the argmax of
     $$
     Y(T) = w*g(T | \mu, \sigma)*\exp(- \int_0^T w*g(t | \mu, \sigma) \ dt).
     $$
     """
     # TODO: Check if this is correct, cf. my notes.
-    Z = jnp.sqrt(lambertw(1/(2*jnp.pi) * w**2))
-    return mu - Z*sig
+    Z = jnp.sqrt(lambertw(1 / (2 * jnp.pi) * w**2))
+    return mu - Z * sig
+
 # %% ../../scripts/_mkl/notebooks/06a - Gaussian Renderer.ipynb 16
-def weighted_arrival_intersection(mu:Vector, P:PrecisionMatrix, A:CholeskyMatrix, w:Float, x:Vector, v:Direction):
+def weighted_arrival_intersection(
+    mu: Vector, P: PrecisionMatrix, A: CholeskyMatrix, w: Float, x: Vector, v: Direction
+):
     """
     Returns the "intersection" of a ray with a gaussian which we define as
     the mode of the gaussian restricted to the ray.
""" t0, sig0 = gaussian_restriction_to_ray(mu, P, A, x, v) - w0 = w*gaussian(t0*v, mu, P) - Z = w0/gaussian_normalizing_constant(P) + w0 = w * gaussian(t0 * v, mu, P) + Z = w0 / gaussian_normalizing_constant(P) t = gaussian_most_likely_time_of_arrival(t0, sig0, Z) return t, w0 + # %% ../../scripts/_mkl/notebooks/06a - Gaussian Renderer.ipynb 17 -def argmax_intersection(mu:Vector, P:PrecisionMatrix, x:Vector, v:Direction): +def argmax_intersection(mu: Vector, P: PrecisionMatrix, x: Vector, v: Direction): """ Returns the "intersection" of a ray with a gaussian which we define as the mode of the gaussian restricted to the ray. """ - t = bilinear(mu - x, v, P)/bilinear(v, v, P) + t = bilinear(mu - x, v, P) / bilinear(v, v, P) return t -#|export -def weighted_argmax_intersection(mu:Vector, P:PrecisionMatrix, w:Float, x:Vector, v:Direction): +# |export +def weighted_argmax_intersection( + mu: Vector, P: PrecisionMatrix, w: Float, x: Vector, v: Direction +): """ Returns the "intersection" of a ray with a gaussian which we define as the mode of the gaussian restricted to the ray. """ - t = bilinear(mu - x, v, P)/bilinear(v, v, P) - return t, w*gaussian(x + t*v, mu, P) + t = bilinear(mu - x, v, P) / bilinear(v, v, P) + return t, w * gaussian(x + t * v, mu, P) + # %% ../../scripts/_mkl/notebooks/06a - Gaussian Renderer.ipynb 24 -def _cast_ray(v, mus, precisions, colors, weights, zmax=2.0, bg_color=jnp.array([1.,1.,1.,1.])): +def _cast_ray( + v, + mus, + precisions, + colors, + weights, + zmax=2.0, + bg_color=jnp.array([1.0, 1.0, 1.0, 1.0]), +): # TODO: Deal with negative intersections behind the camera # TODO: Maybe switch to log probs? - # Compute fuzzy intersections `xs` with Gaussians and + # Compute fuzzy intersections `xs` with Gaussians and # their function values `sigmas` - ts, sigmas = vmap(weighted_argmax_intersection, (0,0,0,None,None))( - mus, precisions, weights, jnp.zeros(3), v) - order = jnp.argsort(ts) - ts = ts[order] + ts, sigmas = vmap(weighted_argmax_intersection, (0, 0, 0, None, None))( + mus, precisions, weights, jnp.zeros(3), v + ) + order = jnp.argsort(ts) + ts = ts[order] sigmas = sigmas[order] - xs = ts[:,None]*v[None,:] + xs = ts[:, None] * v[None, :] # TODO: Ensure that alphas are in [0,1] # TODO: Should we reset the color opacity to `op`? # Alternatively we can set `alphas = (1 - jnp.exp(-sigmas*1.0))` -- cf. Fuzzy Metaballs paper alphas = sigmas * (ts > 0) arrival_probs = discrete_arrival_probabilities(alphas) - op = 1 - arrival_probs[-1] # Opacity - mean_depth = jnp.sum(arrival_probs[:-1]*xs[:,2]) \ - + arrival_probs[-1]*zmax - mean_color = jnp.sum(arrival_probs[:-1,None]*colors[order], axis=0) \ - + arrival_probs[-1]*bg_color + op = 1 - arrival_probs[-1] # Opacity + mean_depth = jnp.sum(arrival_probs[:-1] * xs[:, 2]) + arrival_probs[-1] * zmax + mean_color = ( + jnp.sum(arrival_probs[:-1, None] * colors[order], axis=0) + + arrival_probs[-1] * bg_color + ) return mean_depth, mean_color, op -cast_rays = jit(vmap(_cast_ray, (0,None,None,None,None,None,None))) +cast_rays = jit(vmap(_cast_ray, (0, None, None, None, None, None, None))) diff --git a/bayes3d/_mkl/gaussian_sensor_model.py b/bayes3d/_mkl/gaussian_sensor_model.py index fa73d765..dea06350 100644 --- a/bayes3d/_mkl/gaussian_sensor_model.py +++ b/bayes3d/_mkl/gaussian_sensor_model.py @@ -1,37 +1,19 @@ # AUTOGENERATED! DO NOT EDIT! File to edit: ../../scripts/_mkl/notebooks/06b - Gaussian Sensor Model.ipynb. 
# %% auto 0 -__all__ = ['normal_cdf', 'normal_pdf', 'normal_logpdf', 'inv', 'key'] +__all__ = ["normal_cdf", "normal_pdf", "normal_logpdf", "inv", "key"] # %% ../../scripts/_mkl/notebooks/06b - Gaussian Sensor Model.ipynb 3 -import bayes3d as b3d -import trimesh -import os -from bayes3d._mkl.utils import * -import matplotlib.pyplot as plt -import numpy as np import jax -from jax import jit, vmap import jax.numpy as jnp -from functools import partial -from bayes3d.camera import Intrinsics, K_from_intrinsics, camera_rays_from_intrinsics -from bayes3d.transforms_3d import transform_from_pos_target_up, add_homogenous_ones, unproject_depth -import tensorflow_probability as tfp -from tensorflow_probability.substrates.jax.math import lambertw - # %% ../../scripts/_mkl/notebooks/06b - Gaussian Sensor Model.ipynb 4 from bayes3d._mkl.types import * -from bayes3d._mkl.gaussian_renderer import ( - weighted_arrival_intersection, - weighted_argmax_intersection, - discrete_arrival_probabilities -) - +from bayes3d._mkl.utils import * # %% ../../scripts/_mkl/notebooks/06b - Gaussian Sensor Model.ipynb 5 -normal_cdf = jax.scipy.stats.norm.cdf -normal_pdf = jax.scipy.stats.norm.pdf +normal_cdf = jax.scipy.stats.norm.cdf +normal_pdf = jax.scipy.stats.norm.pdf normal_logpdf = jax.scipy.stats.norm.logpdf inv = jnp.linalg.inv diff --git a/bayes3d/_mkl/generic.py b/bayes3d/_mkl/generic.py index 078a83f7..7a0bc275 100644 --- a/bayes3d/_mkl/generic.py +++ b/bayes3d/_mkl/generic.py @@ -1,92 +1,101 @@ # AUTOGENERATED! DO NOT EDIT! File to edit: ../../scripts/_mkl/notebooks/31 - Generic Viewpoint.ipynb. # %% auto 0 -__all__ = ['normal_logpdf', 'normal_pdf', 'truncnorm_logpdf', 'truncnorm_pdf', 'inv', 'logaddexp', 'logsumexp', - 'contact_from_grid', 'generic_viewpoint', 'generic_contact'] +__all__ = [ + "normal_logpdf", + "normal_pdf", + "truncnorm_logpdf", + "truncnorm_pdf", + "inv", + "logaddexp", + "logsumexp", + "contact_from_grid", + "generic_viewpoint", + "generic_contact", +] # %% ../../scripts/_mkl/notebooks/31 - Generic Viewpoint.ipynb 1 import jax -from bayes3d._mkl.utils import keysplit -from bayes3d._mkl.pose import pack_pose -from jax import jit, vmap from jax import numpy as jnp +from jax import vmap from jax.scipy.spatial.transform import Rotation -from scipy.stats import truncnorm as scipy_truncnormal -normal_logpdf = jax.scipy.stats.norm.logpdf -normal_pdf = jax.scipy.stats.norm.pdf +from bayes3d._mkl.pose import pack_pose +from bayes3d._mkl.utils import keysplit + +normal_logpdf = jax.scipy.stats.norm.logpdf +normal_pdf = jax.scipy.stats.norm.pdf truncnorm_logpdf = jax.scipy.stats.truncnorm.logpdf -truncnorm_pdf = jax.scipy.stats.truncnorm.pdf +truncnorm_pdf = jax.scipy.stats.truncnorm.pdf -inv = jnp.linalg.inv +inv = jnp.linalg.inv logaddexp = jnp.logaddexp logsumexp = jax.scipy.special.logsumexp + # %% ../../scripts/_mkl/notebooks/31 - Generic Viewpoint.ipynb 2 def generic_viewpoint(key, cam, n, sig_x, sig_hd): """Generates generix camera poses by varying its xy-coordinates and angle (in the xy-plane).""" - + # TODO: Make a version that varies rot and pitch and potentially roll. 
- - _, keys = keysplit(key,1,2) + + _, keys = keysplit(key, 1, 2) # Generic position - xs = sig_x*jax.random.normal(keys[1], (n,3)) - xs = xs.at[0,:].set(0.0) - xs = xs.at[:,2].set(0.0) + xs = sig_x * jax.random.normal(keys[1], (n, 3)) + xs = xs.at[0, :].set(0.0) + xs = xs.at[:, 2].set(0.0) # Generic rotation - hds = sig_hd*jax.random.normal(keys[0], (n,)) + hds = sig_hd * jax.random.normal(keys[0], (n,)) hds = hds.at[0].set(0.0) - rs = vmap(Rotation.from_euler, (None,0))("y", hds) + rs = vmap(Rotation.from_euler, (None, 0))("y", hds) rs = Rotation.as_matrix(rs) - + # Generic camera poses ps = vmap(pack_pose)(xs, rs) - ps = cam@ps + ps = cam @ ps # Generic weights logps_hd = normal_logpdf(hds, loc=0.0, scale=sig_hd) - logps_x = normal_logpdf( xs, loc=0.0, scale=sig_x).sum(-1) - logps = logps_hd + logps_x + logps_x = normal_logpdf(xs, loc=0.0, scale=sig_x).sum(-1) + logps = logps_hd + logps_x return ps, logps # %% ../../scripts/_mkl/notebooks/31 - Generic Viewpoint.ipynb 3 def generic_contact(key, p0, n, sig_x, sig_hd): - - _, keys = keysplit(key,1,2) + _, keys = keysplit(key, 1, 2) # Generic contact-pose vector - xs = sig_x*jax.random.normal(keys[1], (n,3)) - xs = xs.at[:,2].set(0.0) - xs = xs.at[0,:].set(0.0) + xs = sig_x * jax.random.normal(keys[1], (n, 3)) + xs = xs.at[:, 2].set(0.0) + xs = xs.at[0, :].set(0.0) - hds = sig_hd*jax.random.normal(keys[0], (n,1)) - hds = hds.at[0,:].set(0.0) - rs = vmap(Rotation.from_euler, (None,0))("z", hds) + hds = sig_hd * jax.random.normal(keys[0], (n, 1)) + hds = hds.at[0, :].set(0.0) + rs = vmap(Rotation.from_euler, (None, 0))("z", hds) rs = Rotation.as_matrix(rs) - + # Generic camera poses ps = vmap(pack_pose)(xs, rs) # vs = jnp.concatenate([xs, hds], axis=1) # Generic weights - logps_hd = normal_logpdf(hds[:,0], loc=0.0, scale=sig_hd) - logps_x = normal_logpdf (xs, loc=0.0, scale=sig_x).sum(-1) - logps = logps_hd + logps_x + logps_hd = normal_logpdf(hds[:, 0], loc=0.0, scale=sig_hd) + logps_x = normal_logpdf(xs, loc=0.0, scale=sig_x).sum(-1) + logps = logps_hd + logps_x # Generic object pose - generic_ps = p0@ps + generic_ps = p0 @ ps return generic_ps, logps - # %% ../../scripts/_mkl/notebooks/31 - Generic Viewpoint.ipynb 4 def _contact_from_grid(v, p0=jnp.eye(4), sig_x=1.0, sig_hd=1.0): - x = jnp.array([*v[:2],0.0]) + x = jnp.array([*v[:2], 0.0]) hd = v[2] r = Rotation.from_euler("z", hd) @@ -94,9 +103,10 @@ def _contact_from_grid(v, p0=jnp.eye(4), sig_x=1.0, sig_hd=1.0): p = pack_pose(x, r) logp_hd = normal_logpdf(hd, loc=0.0, scale=sig_hd) - logp_x = normal_logpdf (x, loc=0.0, scale=sig_x).sum(-1) - logp = logp_hd + logp_x + logp_x = normal_logpdf(x, loc=0.0, scale=sig_x).sum(-1) + logp = logp_hd + logp_x + + return p0 @ p, logp - return p0@p, logp -contact_from_grid = vmap(_contact_from_grid, (0,None,None,None)) +contact_from_grid = vmap(_contact_from_grid, (0, None, None, None)) diff --git a/bayes3d/_mkl/plotting.py b/bayes3d/_mkl/plotting.py index ea7d7d8d..1c7fe202 100644 --- a/bayes3d/_mkl/plotting.py +++ b/bayes3d/_mkl/plotting.py @@ -1,31 +1,43 @@ # AUTOGENERATED! DO NOT EDIT! File to edit: ../../scripts/_mkl/notebooks/01 - Plotting.ipynb. 
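(Both `generic_viewpoint` and `generic_contact` in the `generic.py` hunk above follow one recipe: draw Gaussian perturbations around a nominal pose, pin the first sample to the unperturbed pose, and return the proposal log-density next to each sample so callers can reweight. A stripped-down sketch of that recipe in plain JAX, using 2-D translations only so the bayes3d helpers `keysplit`/`pack_pose` are not needed; all names here are illustrative:)

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
n, sig_x = 5, 0.05

# Gaussian xy-perturbations; pin sample 0 to the unperturbed pose,
# mirroring `xs.at[0, :].set(0.0)` in generic_viewpoint above.
xs = sig_x * jax.random.normal(key, (n, 2))
xs = xs.at[0, :].set(0.0)

# Log-density of each perturbation under the proposal.
logps = jax.scipy.stats.norm.logpdf(xs, loc=0.0, scale=sig_x).sum(-1)

# Normalized importance weights, computed stably via logsumexp.
ws = jnp.exp(logps - jax.scipy.special.logsumexp(logps))
print(xs, ws)
```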
 # %% auto 0
-__all__ = ['rgba_from_vals', 'line_collection', 'plot_segs', 'zoom_in', 'unit_vec', 'plot_poses', 'plot_pose']
+__all__ = [
+    "rgba_from_vals",
+    "line_collection",
+    "plot_segs",
+    "zoom_in",
+    "unit_vec",
+    "plot_poses",
+    "plot_pose",
+]
 
 # %% ../../scripts/_mkl/notebooks/01 - Plotting.ipynb 2
+import jax
+import jax.numpy as jnp
 import matplotlib.pyplot as plt
 import numpy as np
-import jax.numpy as jnp
-import jax
+
 
 # %% ../../scripts/_mkl/notebooks/01 - Plotting.ipynb 3
 def rgba_from_vals(vs, q=0.0, cmap="viridis", vmin=None, vmax=None):
-    if isinstance(q,list):
+    if isinstance(q, list):
         v_min = np.quantile(vs, q[0])
         v_max = np.quantile(vs, q[1])
     else:
         v_min = np.quantile(vs, q)
         v_max = np.max(vs)
-    if vmax is not None: v_max = vmax
-    if vmin is not None: v_min = vmin
+    if vmax is not None:
+        v_max = vmax
+    if vmin is not None:
+        v_min = vmin
 
-    cm  = getattr(plt.cm, cmap)
+    cm = getattr(plt.cm, cmap)
     vs_ = np.clip(vs, v_min, v_max)
-    cs  = cm(plt.Normalize()(vs_))
+    cs = cm(plt.Normalize()(vs_))
     return cs
 
+
 # %% ../../scripts/_mkl/notebooks/01 - Plotting.ipynb 5
 from matplotlib.collections import LineCollection
 
@@ -35,42 +47,61 @@ def line_collection(a, b, c=None, linewidth=1, **kwargs):
     lc = LineCollection(lines, colors=c, linewidths=linewidth, **kwargs)
     return lc
 
+
 # %% ../../scripts/_mkl/notebooks/01 - Plotting.ipynb 7
-def plot_segs(segs, c="k", linewidth=1, ax=None, **kwargs):
-    if ax is None: ax = plt.gca()
+def plot_segs(segs, c="k", linewidth=1, ax=None, **kwargs):
+    if ax is None:
+        ax = plt.gca()
     n = 10
-    segs = segs.reshape(-1,2,2)
-    a = segs[:,0]
-    b = segs[:,1]
+    segs = segs.reshape(-1, 2, 2)
+    a = segs[:, 0]
+    b = segs[:, 1]
     lc = line_collection(a, b, linewidth=linewidth, **kwargs)
     lc.set_colors(c)
     ax.add_collection(lc)
 
+
 # %% ../../scripts/_mkl/notebooks/01 - Plotting.ipynb 9
 def zoom_in(x, pad, ax=None):
-    if ax is None: ax = plt.gca()
-    ax.set_xlim(np.min(x[...,0])-pad, np.max(x[...,0])+pad)
-    ax.set_ylim(np.min(x[...,1])-pad, np.max(x[...,1])+pad)
+    if ax is None:
+        ax = plt.gca()
+    ax.set_xlim(np.min(x[..., 0]) - pad, np.max(x[..., 0]) + pad)
+    ax.set_ylim(np.min(x[..., 1]) - pad, np.max(x[..., 1]) + pad)
 
+
 # %% ../../scripts/_mkl/notebooks/01 - Plotting.ipynb 10
-def unit_vec(hd):
+def unit_vec(hd):
     return jnp.array([jnp.cos(hd), jnp.sin(hd)])
 
-def plot_poses(ps, sc=None, r=0.5, clip=-1e12, cs=None, c="lightgray", cmap="viridis", ax=None, q=0.0, zorder=None, linewidth=2):
-    if ax is None: ax = plt.gca()
+
+def plot_poses(
+    ps,
+    sc=None,
+    r=0.5,
+    clip=-1e12,
+    cs=None,
+    c="lightgray",
+    cmap="viridis",
+    ax=None,
+    q=0.0,
+    zorder=None,
+    linewidth=2,
+):
+    if ax is None:
+        ax = plt.gca()
     ax.set_aspect(1)
-    ps = ps.reshape(-1,3)
+    ps = ps.reshape(-1, 3)
 
-    a = ps[:,:2]
-    b = a + r * jax.vmap(unit_vec)(ps[:,2])
+    a = ps[:, :2]
+    b = a + r * jax.vmap(unit_vec)(ps[:, 2])
 
     if cs is None:
         if sc is None:
             cs = c
         else:
             sc = sc.reshape(-1)
-            sc = jnp.where(jnp==-jnp.inf, clip, sc)
-            sc = jnp.clip(sc, clip, jnp.max(sc))
+            sc = jnp.where(sc == -jnp.inf, clip, sc)  # compare sc, not the jnp module
+            sc = jnp.clip(sc, clip, jnp.max(sc))
             sc = jnp.clip(sc, jnp.quantile(sc, q), jnp.max(sc))
             cs = getattr(plt.cm, cmap)(plt.Normalize()(sc))
 
@@ -79,15 +110,14 @@ def plot_poses(ps, sc=None, r=0.5, clip=-1e12, cs=None, c="lightgray", cmap="vi
     b = b[order]
     cs = cs[order]
+    ax.add_collection(line_collection(a, b, c=cs, zorder=zorder, linewidth=linewidth))
 
-    ax.add_collection(line_collection(a,b, c=cs, zorder=zorder, linewidth=linewidth));
-
 # %% ../../scripts/_mkl/notebooks/01 - Plotting.ipynb 11
-def 
plot_pose(p, r=0.5, c="red", ax=None,zorder=None, linewidth=2): - if ax is None: ax = plt.gca() +def plot_pose(p, r=0.5, c="red", ax=None, zorder=None, linewidth=2): + if ax is None: + ax = plt.gca() ax.set_aspect(1) a = p[:2] - b = a + r*unit_vec(p[2]) - ax.plot([a[0],b[0]],[a[1],b[1]], c=c, zorder=zorder, linewidth=linewidth) - + b = a + r * unit_vec(p[2]) + ax.plot([a[0], b[0]], [a[1], b[1]], c=c, zorder=zorder, linewidth=linewidth) diff --git a/bayes3d/_mkl/pose.py b/bayes3d/_mkl/pose.py index 5f5041f8..c55ce583 100644 --- a/bayes3d/_mkl/pose.py +++ b/bayes3d/_mkl/pose.py @@ -1,77 +1,101 @@ # AUTOGENERATED! DO NOT EDIT! File to edit: ../../scripts/_mkl/notebooks/02 - Pose.ipynb. # %% auto 0 -__all__ = ['PI', 'TWOPI', 'CAM_ALONG_X', 'Rot', 'Pose', 'rot2d', 'pack_2dpose', 'apply_2dpose', 'unit_vec', 'adjust_angle', - 'rot_x', 'rot_y', 'rot_z', 'from_euler', 'look_at', 'ax_to_ind', 'Rotation', 'unpack_pose', 'pack_pose', - 'apply_pose', 'mpl_plot_pose', 'lift_pose'] +__all__ = [ + "PI", + "TWOPI", + "CAM_ALONG_X", + "Rot", + "Pose", + "rot2d", + "pack_2dpose", + "apply_2dpose", + "unit_vec", + "adjust_angle", + "rot_x", + "rot_y", + "rot_z", + "from_euler", + "look_at", + "ax_to_ind", + "Rotation", + "unpack_pose", + "pack_pose", + "apply_pose", + "mpl_plot_pose", + "lift_pose", +] # %% ../../scripts/_mkl/notebooks/02 - Pose.ipynb 2 -import jax -import jax.numpy as jnp -import genjax -from genjax.generative_functions.distributions import ExactDensity -from dataclasses import dataclass from collections import namedtuple + +import jax.numpy as jnp from plum import dispatch -PI = jnp.pi -TWOPI = 2*PI +PI = jnp.pi +TWOPI = 2 * PI + # %% ../../scripts/_mkl/notebooks/02 - Pose.ipynb 6 -def rot2d(hd): return jnp.array([ - [jnp.cos(hd), -jnp.sin(hd)], - [jnp.sin(hd), jnp.cos(hd)] - ]); +def rot2d(hd): + return jnp.array([[jnp.cos(hd), -jnp.sin(hd)], [jnp.sin(hd), jnp.cos(hd)]]) + -def pack_2dpose(x,hd): - return jnp.concatenate([x,jnp.array([hd])]) +def pack_2dpose(x, hd): + return jnp.concatenate([x, jnp.array([hd])]) -def apply_2dpose(p, ys): - return ys@rot2d(p[2] - jnp.pi/2).T + p[:2] -def unit_vec(hd): +def apply_2dpose(p, ys): + return ys @ rot2d(p[2] - jnp.pi / 2).T + p[:2] + + +def unit_vec(hd): return jnp.array([jnp.cos(hd), jnp.sin(hd)]) -def adjust_angle(hd): - return (hd + jnp.pi)%(2*jnp.pi) - jnp.pi + +def adjust_angle(hd): + return (hd + jnp.pi) % (2 * jnp.pi) - jnp.pi + # %% ../../scripts/_mkl/notebooks/02 - Pose.ipynb 8 -CAM_ALONG_X = jnp.array([ - [0, 0, 1], - [-1, 0, 0], - [0, -1, 0] -]) +CAM_ALONG_X = jnp.array([[0, 0, 1], [-1, 0, 0], [0, -1, 0]]) def rot_x(theta): - return jnp.array([ - [1, 0, 0], - [0, jnp.cos(theta), -jnp.sin(theta)], - [0, jnp.sin(theta), jnp.cos(theta)] - ]) + return jnp.array( + [ + [1, 0, 0], + [0, jnp.cos(theta), -jnp.sin(theta)], + [0, jnp.sin(theta), jnp.cos(theta)], + ] + ) def rot_y(theta): - return jnp.array([ - [jnp.cos(theta), 0, -jnp.sin(theta)], - [0, 1, 0], - [jnp.sin(theta), 0, jnp.cos(theta)] - ]) + return jnp.array( + [ + [jnp.cos(theta), 0, -jnp.sin(theta)], + [0, 1, 0], + [jnp.sin(theta), 0, jnp.cos(theta)], + ] + ) def rot_z(theta): - return jnp.array([ - [jnp.cos(theta), -jnp.sin(theta), 0], - [jnp.sin(theta), jnp.cos(theta), 0], - [0, 0, 1] - ]) + return jnp.array( + [ + [jnp.cos(theta), -jnp.sin(theta), 0], + [jnp.sin(theta), jnp.cos(theta), 0], + [0, 0, 1], + ] + ) def from_euler(rot, pitch=0.0, roll=0.0): """ Imagine you stand on xy-plane and rotate (z-axis), pitch (y'-axis), and roll (x''-axis). 
""" - return rot_z(rot)@rot_y(pitch)@rot_x(roll) + return rot_z(rot) @ rot_y(pitch) @ rot_x(roll) def look_at(v, roll=0.0, cam=True): @@ -79,47 +103,57 @@ def look_at(v, roll=0.0, cam=True): R = CAM_ALONG_X if cam else jnp.eye(3) n = jnp.linalg.norm(v) - rot = jnp.arctan2(v[1],v[0]) - pitch = jnp.arctan2(v[2],n) - return from_euler(rot, pitch, roll)@R + rot = jnp.arctan2(v[1], v[0]) + pitch = jnp.arctan2(v[2], n) + return from_euler(rot, pitch, roll) @ R + # %% ../../scripts/_mkl/notebooks/02 - Pose.ipynb 9 def ax_to_ind(c): - lookup = {"x":0, "y":1, "z":2} + lookup = {"x": 0, "y": 1, "z": 2} return lookup[c] class Rotation(object): @staticmethod def _x(theta): - return jnp.array([ - [1, 0, 0], - [0, jnp.cos(theta), -jnp.sin(theta)], - [0, jnp.sin(theta), jnp.cos(theta)] - ]) + return jnp.array( + [ + [1, 0, 0], + [0, jnp.cos(theta), -jnp.sin(theta)], + [0, jnp.sin(theta), jnp.cos(theta)], + ] + ) @staticmethod def _y(theta): - return jnp.array([ - [jnp.cos(theta), 0, -jnp.sin(theta)], - [0, 1, 0], - [jnp.sin(theta), 0, jnp.cos(theta)] - ]) + return jnp.array( + [ + [jnp.cos(theta), 0, -jnp.sin(theta)], + [0, 1, 0], + [jnp.sin(theta), 0, jnp.cos(theta)], + ] + ) @staticmethod def _z(theta): - return jnp.array([ - [jnp.cos(theta), -jnp.sin(theta), 0], - [jnp.sin(theta), jnp.cos(theta), 0], - [0, 0, 1] - ]) + return jnp.array( + [ + [jnp.cos(theta), -jnp.sin(theta), 0], + [jnp.sin(theta), jnp.cos(theta), 0], + [0, 0, 1], + ] + ) @staticmethod - def _ax(ax:str, theta): - if ax == "x": return Rotation._x(theta) - elif ax == "y": return Rotation._y(theta) - elif ax == "z": return Rotation._z(theta) - + def _ax(ax: str, theta): + if ax == "x": + return Rotation._x(theta) + elif ax == "y": + return Rotation._y(theta) + elif ax == "z": + return Rotation._z(theta) + @staticmethod def from_euler(order, angles): """ @@ -128,62 +162,91 @@ def from_euler(order, angles): angles : Array of length 3, e.g. 
[0, 0, 0] """ rot_ax = Rotation._ax - return rot_ax(order[0], angles[0])@rot_ax(order[1],angles[1])@rot_ax(order[2],angles[2]) - + return ( + rot_ax(order[0], angles[0]) + @ rot_ax(order[1], angles[1]) + @ rot_ax(order[2], angles[2]) + ) + @staticmethod def look_at(v, roll=0.0, order="zyx"): - """ - """ + """ """ n = jnp.linalg.norm(v) - rot = jnp.arctan2(v[1],v[0]) - pitch = jnp.arctan2(v[2],n) - return Rotation.from_euler(order, [rot, pitch, roll])@yzX[:3,:3] - + rot = jnp.arctan2(v[1], v[0]) + pitch = jnp.arctan2(v[2], n) + return Rotation.from_euler(order, [rot, pitch, roll]) @ yzX[:3, :3] + + Rot = Rotation # %% ../../scripts/_mkl/notebooks/02 - Pose.ipynb 12 -Pose = namedtuple("Pose", ["x", "r"]) +Pose = namedtuple("Pose", ["x", "r"]) + @dispatch -def unpack_pose(p:Pose): +def unpack_pose(p: Pose): return p.x, p.r + @dispatch -def unpack_pose(R:jnp.ndarray): - return R[:3,3], R[:3,:3] +def unpack_pose(R: jnp.ndarray): + return R[:3, 3], R[:3, :3] + + +def pack_pose(x, r): + return jnp.concatenate( + [jnp.concatenate([r, x[:, None]], axis=1), jnp.array([[0, 0, 0, 1]])], axis=0 + ) -def pack_pose(x, r): - return jnp.concatenate([ - jnp.concatenate([r, x[:,None]], axis=1), - jnp.array([[0,0,0,1]])], axis=0) def apply_pose(p, x): t, r = unpack_pose(p) - return x@r.T + t + return x @ r.T + t + # %% ../../scripts/_mkl/notebooks/02 - Pose.ipynb 13 -import matplotlib.pyplot as plt -from mpl_toolkits.mplot3d import Axes3D -def mpl_plot_pose(ax, p, s=0.1, length=0.1, normalize=True, **kwargs): - t,r = unpack_pose(p) +def mpl_plot_pose(ax, p, s=0.1, length=0.1, normalize=True, **kwargs): + t, r = unpack_pose(p) # Coordinate frame data origin = t - x_axis = s*r[:3,0] - y_axis = s*r[:3,1] - z_axis = s*r[:3,2] + x_axis = s * r[:3, 0] + y_axis = s * r[:3, 1] + z_axis = s * r[:3, 2] # Plotting the coordinate frame - ax.quiver(*origin, *x_axis, color='r', label='X-axis', length=length, normalize=normalize, **kwargs) - ax.quiver(*origin, *y_axis, color='g', label='Y-axis', length=length, normalize=normalize, **kwargs) - ax.quiver(*origin, *z_axis, color='b', label='Z-axis', length=length, normalize=normalize, **kwargs) + ax.quiver( + *origin, + *x_axis, + color="r", + label="X-axis", + length=length, + normalize=normalize, + **kwargs, + ) + ax.quiver( + *origin, + *y_axis, + color="g", + label="Y-axis", + length=length, + normalize=normalize, + **kwargs, + ) + ax.quiver( + *origin, + *z_axis, + color="b", + label="Z-axis", + length=length, + normalize=normalize, + **kwargs, + ) + # %% ../../scripts/_mkl/notebooks/02 - Pose.ipynb 15 def lift_pose(x, hd, z=0.0, pitch=0.0, roll=0.0): """Lifts a 2d pose (x,hd) to 3d""" - return pack_pose( - jnp.concatenate([x, jnp.array([z])]), - from_euler(hd) @ CAM_ALONG_X - ) + return pack_pose(jnp.concatenate([x, jnp.array([z])]), from_euler(hd) @ CAM_ALONG_X) diff --git a/bayes3d/_mkl/simple_likelihood.py b/bayes3d/_mkl/simple_likelihood.py index 51bcd220..67a01940 100644 --- a/bayes3d/_mkl/simple_likelihood.py +++ b/bayes3d/_mkl/simple_likelihood.py @@ -1,19 +1,29 @@ # AUTOGENERATED! DO NOT EDIT! File to edit: ../../scripts/_mkl/notebooks/10 - Simple Likelihood.ipynb. 
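(The pose utilities in the preceding `bayes3d/_mkl/pose.py` hunk use the standard homogeneous convention: the rotation fills the top-left 3x3 block of a 4x4 matrix, the translation fills the last column, and applying a pose is `x @ r.T + t`. A small self-contained round-trip check — `pack_pose` is copied from the hunk above, while `apply_pose` is inlined here instead of going through the `plum` dispatch on `unpack_pose`:)

```python
import jax.numpy as jnp


def pack_pose(x, r):
    # 4x4 homogeneous pose: rotation block, translation column, [0 0 0 1] row.
    return jnp.concatenate(
        [jnp.concatenate([r, x[:, None]], axis=1), jnp.array([[0, 0, 0, 1]])], axis=0
    )


def apply_pose(p, x):
    t, r = p[:3, 3], p[:3, :3]
    return x @ r.T + t


t = jnp.array([1.0, 2.0, 3.0])
p = pack_pose(t, jnp.eye(3))
# With the identity rotation, applying the pose is a pure translation.
print(apply_pose(p, jnp.zeros((1, 3))))  # [[1. 2. 3.]]
```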
# %% auto 0 -__all__ = ['key', 'tfd', 'uniform', 'truncnormal', 'normal', 'diagnormal', 'mixture_of_diagnormals', 'mixture_of_normals', - 'mixture_of_truncnormals', 'normal_logpdf', 'truncnorm_logpdf', 'truncnorm_pdf', 'make_simple_sensor_model', - 'make_simple_step_sensor_model', 'wrap_into_dist'] +__all__ = [ + "key", + "tfd", + "uniform", + "truncnormal", + "normal", + "diagnormal", + "mixture_of_diagnormals", + "mixture_of_normals", + "mixture_of_truncnormals", + "normal_logpdf", + "truncnorm_logpdf", + "truncnorm_pdf", + "make_simple_sensor_model", + "make_simple_step_sensor_model", + "wrap_into_dist", +] # %% ../../scripts/_mkl/notebooks/10 - Simple Likelihood.ipynb 3 +import genjax import jax import jax.numpy as jnp -from jax import jit, vmap -import genjax -from genjax import gen, choice_map, vector_choice_map -import matplotlib.pyplot as plt -import numpy as np -import bayes3d + from bayes3d._mkl.utils import * key = jax.random.PRNGKey(0) @@ -21,141 +31,169 @@ # %% ../../scripts/_mkl/notebooks/10 - Simple Likelihood.ipynb 4 import genjax._src.generative_functions.distributions.tensorflow_probability as gentfp import tensorflow_probability.substrates.jax as tfp + tfd = tfp.distributions uniform = genjax.tfp_uniform truncnormal = gentfp.TFPDistribution( - lambda mu, sig, low, high: tfd.TruncatedNormal(mu, sig, low, high)); + lambda mu, sig, low, high: tfd.TruncatedNormal(mu, sig, low, high) +) -normal = gentfp.TFPDistribution( - lambda mu, sig: tfd.Normal(mu, sig)); +normal = gentfp.TFPDistribution(lambda mu, sig: tfd.Normal(mu, sig)) diagnormal = gentfp.TFPDistribution( - lambda mus, sigs: tfd.MultivariateNormalDiag(mus, sigs)); + lambda mus, sigs: tfd.MultivariateNormalDiag(mus, sigs) +) mixture_of_diagnormals = gentfp.TFPDistribution( lambda ws, mus, sig: tfd.MixtureSameFamily( - tfd.Categorical(ws), - tfd.MultivariateNormalDiag(mus, sig * jnp.ones_like(mus)))) + tfd.Categorical(ws), tfd.MultivariateNormalDiag(mus, sig * jnp.ones_like(mus)) + ) +) mixture_of_normals = gentfp.TFPDistribution( lambda ws, mus, sig: tfd.MixtureSameFamily( - tfd.Categorical(ws), - tfd.Normal(mus, sig * jnp.ones_like(mus)))) + tfd.Categorical(ws), tfd.Normal(mus, sig * jnp.ones_like(mus)) + ) +) mixture_of_truncnormals = gentfp.TFPDistribution( lambda ws, mus, sigs, lows, highs: tfd.MixtureSameFamily( - tfd.Categorical(ws), - tfd.TruncatedNormal(mus, sigs, lows, highs))) + tfd.Categorical(ws), tfd.TruncatedNormal(mus, sigs, lows, highs) + ) +) # %% ../../scripts/_mkl/notebooks/10 - Simple Likelihood.ipynb 5 -from scipy.stats import truncnorm as scipy_truncnormal -normal_logpdf = jax.scipy.stats.norm.logpdf +normal_logpdf = jax.scipy.stats.norm.logpdf truncnorm_logpdf = jax.scipy.stats.truncnorm.logpdf -truncnorm_pdf = jax.scipy.stats.truncnorm.pdf +truncnorm_pdf = jax.scipy.stats.truncnorm.pdf # %% ../../scripts/_mkl/notebooks/10 - Simple Likelihood.ipynb 8 -# TODO: The input Y should be an array only containing range measruements as well. +# TODO: The input Y should be an array only containing range measruements as well. 
 # For this to work we need to have the pixel vectors (the rays through each pixel)
+
+
 def make_simple_sensor_model(zmax):
-    """Returns an simple sensor model marginalized over outliers."""
+    """Returns a simple sensor model marginalized over outliers."""
 
     @genjax.drop_arguments
     @genjax.gen
     def _sensor_model(y, sig, outlier):
-
-
         # Compute max range along ray ending at far plane
         # and adding some wiggle room
         z_ = jnp.linalg.norm(y)
-        zmax_ = z_/y[2]*zmax
+        zmax_ = z_ / y[2] * zmax
 
-        inlier_outlier_mix = genjax.tfp_mixture(genjax.tfp_categorical, [truncnormal, genjax.tfp_uniform])
-        z = inlier_outlier_mix([jnp.log(1.0-outlier), jnp.log(outlier)], (
-            (z_, sig, 0.0, zmax_),
-            (0.0, zmax_ + 1e-6))) @ "measurement"
+        inlier_outlier_mix = genjax.tfp_mixture(
+            genjax.tfp_categorical, [truncnormal, genjax.tfp_uniform]
+        )
+        z = (
+            inlier_outlier_mix(
+                [jnp.log(1.0 - outlier), jnp.log(outlier)],
+                ((z_, sig, 0.0, zmax_), (0.0, zmax_ + 1e-6)),
+            )
+            @ "measurement"
+        )
         z = jnp.clip(z, 0.0, zmax_)
-        return z * y/z_
+        return z * y / z_
 
-
     @genjax.gen
-    def sensor_model(Y, sig, out):
+    def sensor_model(Y, sig, out):
         """
-        Simplest sensor model that returns a vector of range measurements conditioned on
+        Simplest sensor model that returns a vector of range measurements conditioned on
         an image, noise level, and outlier probability.
         """
-
-        X = genjax.Map(_sensor_model, (0,None,None))(Y[...,:3].reshape(-1,3), sig, out) @ "X"
+
+        X = (
+            genjax.Map(_sensor_model, (0, None, None))(
+                Y[..., :3].reshape(-1, 3), sig, out
+            )
+            @ "X"
+        )
         X = X.reshape(Y.shape)
         return X
 
     return sensor_model
 
+
 # %% ../../scripts/_mkl/notebooks/10 - Simple Likelihood.ipynb 12
 def make_simple_step_sensor_model(far):
-    """Returns an simple step function sensor model marginalized over outliers."""
+    """Returns a simple step function sensor model marginalized over outliers."""
 
     @genjax.drop_arguments
     @genjax.gen
     def _sensor_model_pixel(y, sig, out):
-
         # Compute max range along ray ending at far plane
-        r_   = jnp.linalg.norm(y)
-        rmax = r_/y[2]*far
+        r_ = jnp.linalg.norm(y)
+        rmax = r_ / y[2] * far
 
         inlier_outlier_mix = genjax.tfp_mixture(
-            genjax.tfp_categorical,
-            [genjax.tfp_uniform, genjax.tfp_uniform])
+            genjax.tfp_categorical, [genjax.tfp_uniform, genjax.tfp_uniform]
+        )
 
         # The `1e-4` term helps with numerical issues from computing rmax
         # at least that's what I think
-        r = inlier_outlier_mix(
-            [jnp.log(1 - out), jnp.log(out)],
-            ((jnp.maximum(r_-sig, 0.0) , jnp.minimum(r_+sig, rmax)), (0.0, rmax + 1e-4))) @ "measurement"
+        r = (
+            inlier_outlier_mix(
+                [jnp.log(1 - out), jnp.log(out)],
+                (
+                    (jnp.maximum(r_ - sig, 0.0), jnp.minimum(r_ + sig, rmax)),
+                    (0.0, rmax + 1e-4),
+                ),
+            )
+            @ "measurement"
+        )
         r = jnp.clip(r, 0.0, rmax)
-        return r * y/r_
+        return r * y / r_
 
-
     @genjax.gen
     def sensor_model(Y, sig, out):
         """
-        Simplest sensor model that returns a vector of range measurements conditioned on
+        Simplest sensor model that returns a vector of range measurements conditioned on
         an image, noise level, and outlier probability.
""" - - X = genjax.Map(_sensor_model_pixel, (0,None,None))(Y[...,:3].reshape(-1,3), sig, out) @ "X" - X = X.reshape(Y[...,:3].shape) + + X = ( + genjax.Map(_sensor_model_pixel, (0, None, None))( + Y[..., :3].reshape(-1, 3), sig, out + ) + @ "X" + ) + X = X.reshape(Y[..., :3].shape) return X return sensor_model + # %% ../../scripts/_mkl/notebooks/10 - Simple Likelihood.ipynb 15 from genjax._src.generative_functions.distributions.distribution import ExactDensity + def wrap_into_dist(score_func): """ - Takes a scoring function + Takes a scoring function - `score_func(observed, latent, ...)` + `score_func(observed, latent, ...)` and wraps it into a genjax distribution. """ + class WrappedScoreFunc(ExactDensity): - def sample(self, key, latent, *args): return latent - def logpdf(self, observed, latent, *args): return score_func(observed, latent, *args) + def sample(self, key, latent, *args): + return latent - return WrappedScoreFunc() + def logpdf(self, observed, latent, *args): + return score_func(observed, latent, *args) + return WrappedScoreFunc() diff --git a/bayes3d/_mkl/table_scene_model.py b/bayes3d/_mkl/table_scene_model.py index 062f18a6..26a4ba66 100644 --- a/bayes3d/_mkl/table_scene_model.py +++ b/bayes3d/_mkl/table_scene_model.py @@ -1,45 +1,44 @@ # AUTOGENERATED! DO NOT EDIT! File to edit: ../../scripts/_mkl/notebooks/30 - Table Scene Model.ipynb. # %% auto 0 -__all__ = ['normal_logpdf', 'normal_pdf', 'truncnorm_logpdf', 'truncnorm_pdf', 'inv', 'logaddexp', 'logsumexp', 'key', - 'make_table_scene_model'] +__all__ = [ + "normal_logpdf", + "normal_pdf", + "truncnorm_logpdf", + "truncnorm_pdf", + "inv", + "logaddexp", + "logsumexp", + "key", + "make_table_scene_model", +] # %% ../../scripts/_mkl/notebooks/30 - Table Scene Model.ipynb 2 -import bayes3d as b3d -import bayes3d.genjax -import joblib -from tqdm import tqdm -import os -import jax.numpy as jnp -import jax -from jax import jit, vmap -import numpy as np import genjax -import trimesh -import matplotlib.pyplot as plt +import jax +import jax.numpy as jnp + from bayes3d.genjax.genjax_distributions import * # console = genjax.pretty(show_locals=False) # %% ../../scripts/_mkl/notebooks/30 - Table Scene Model.ipynb 3 -from jax.scipy.spatial.transform import Rotation -from scipy.stats import truncnorm as scipy_truncnormal -normal_logpdf = jax.scipy.stats.norm.logpdf -normal_pdf = jax.scipy.stats.norm.pdf +normal_logpdf = jax.scipy.stats.norm.logpdf +normal_pdf = jax.scipy.stats.norm.pdf truncnorm_logpdf = jax.scipy.stats.truncnorm.logpdf -truncnorm_pdf = jax.scipy.stats.truncnorm.pdf +truncnorm_pdf = jax.scipy.stats.truncnorm.pdf -inv = jnp.linalg.inv +inv = jnp.linalg.inv logaddexp = jnp.logaddexp logsumexp = jax.scipy.special.logsumexp key = jax.random.PRNGKey(0) # %% ../../scripts/_mkl/notebooks/30 - Table Scene Model.ipynb 4 -from bayes3d._mkl.utils import keysplit from bayes3d._mkl.plotting import * + # %% ../../scripts/_mkl/notebooks/30 - Table Scene Model.ipynb 11 def make_table_scene_model(): """ @@ -52,13 +51,13 @@ def make_table_scene_model(): table = jnp.eye(4) cam = b3d.transform_from_pos_target_up( - jnp.array([0.0, -.5, -.75]), - jnp.zeros(3), + jnp.array([0.0, -.5, -.75]), + jnp.zeros(3), jnp.array([0.0,-1.0,0.0])) args = ( - jnp.arange(3), - jnp.arange(22), + jnp.arange(3), + jnp.arange(22), jnp.array([-jnp.ones(3)*100.0, jnp.ones(3)*100.0]), jnp.array([jnp.array([-0.2, -0.2, -2*jnp.pi]), jnp.array([0.2, 0.2, 2*jnp.pi])]), b3d.RENDERER.model_box_dims @@ -89,43 +88,48 @@ def make_table_scene_model(): """ @genjax.gen - 
def model(nums, - possible_object_indices, - pose_bounds, - contact_bounds, - all_box_dims): - - num_objects = len(nums) # this is a hack, otherwise genajx is complaining - - indices = jnp.array([], dtype=jnp.int32) - root_poses = jnp.zeros((0,4,4)) - contact_params = jnp.zeros((0,3)) - faces_parents = jnp.array([], dtype=jnp.int32) - faces_child = jnp.array([], dtype=jnp.int32) - parents = jnp.array([], dtype=jnp.int32) + def model(nums, possible_object_indices, pose_bounds, contact_bounds, all_box_dims): + num_objects = len(nums) # this is a hack, otherwise genajx is complaining - for i in range(num_objects): - - index = uniform_discrete(possible_object_indices) @ f"id_{i}" - pose = uniform_pose(pose_bounds[0], pose_bounds[1]) @ f"root_pose_{i}" - params = contact_params_uniform(contact_bounds[0], contact_bounds[1]) @ f"contact_params_{i}" - - parent_obj = uniform_discrete(jnp.arange(-1, num_objects - 1)) @ f"parent_{i}" - parent_face = uniform_discrete(jnp.arange(0,6)) @ f"face_parent_{i}" - child_face = uniform_discrete(jnp.arange(0,6)) @ f"face_child_{i}" + indices = jnp.array([], dtype=jnp.int32) + root_poses = jnp.zeros((0, 4, 4)) + contact_params = jnp.zeros((0, 3)) + faces_parents = jnp.array([], dtype=jnp.int32) + faces_child = jnp.array([], dtype=jnp.int32) + parents = jnp.array([], dtype=jnp.int32) - indices = jnp.concatenate([indices, jnp.array([index])]) - root_poses = jnp.concatenate([root_poses, pose.reshape(1,4,4)]) - contact_params = jnp.concatenate([contact_params, params.reshape(1,-1)]) - parents = jnp.concatenate([parents, jnp.array([parent_obj])]) - faces_parents = jnp.concatenate([faces_parents, jnp.array([parent_face])]) - faces_child = jnp.concatenate([faces_child, jnp.array([child_face])]) - - - scene = (root_poses, all_box_dims[indices], parents, contact_params, faces_parents, faces_child) + for i in range(num_objects): + index = uniform_discrete(possible_object_indices) @ f"id_{i}" + pose = uniform_pose(pose_bounds[0], pose_bounds[1]) @ f"root_pose_{i}" + params = ( + contact_params_uniform(contact_bounds[0], contact_bounds[1]) + @ f"contact_params_{i}" + ) + + parent_obj = ( + uniform_discrete(jnp.arange(-1, num_objects - 1)) @ f"parent_{i}" + ) + parent_face = uniform_discrete(jnp.arange(0, 6)) @ f"face_parent_{i}" + child_face = uniform_discrete(jnp.arange(0, 6)) @ f"face_child_{i}" + + indices = jnp.concatenate([indices, jnp.array([index])]) + root_poses = jnp.concatenate([root_poses, pose.reshape(1, 4, 4)]) + contact_params = jnp.concatenate([contact_params, params.reshape(1, -1)]) + parents = jnp.concatenate([parents, jnp.array([parent_obj])]) + faces_parents = jnp.concatenate([faces_parents, jnp.array([parent_face])]) + faces_child = jnp.concatenate([faces_child, jnp.array([child_face])]) + + scene = ( + root_poses, + all_box_dims[indices], + parents, + contact_params, + faces_parents, + faces_child, + ) poses = b.scene_graph.poses_from_scene_graph(*scene) - camera_pose = uniform_pose(pose_bounds[0], pose_bounds[1]) @ f"camera_pose" + camera_pose = uniform_pose(pose_bounds[0], pose_bounds[1]) @ "camera_pose" return camera_pose, poses, indices diff --git a/bayes3d/_mkl/trimesh_to_gaussians.py b/bayes3d/_mkl/trimesh_to_gaussians.py index 391f6f8d..722e2284 100644 --- a/bayes3d/_mkl/trimesh_to_gaussians.py +++ b/bayes3d/_mkl/trimesh_to_gaussians.py @@ -5,10 +5,10 @@ **Example:** ```python from bayes3d._mkl.trimesh_to_gaussians import ( - patch_trimesh, - uniformly_sample_from_mesh, - ellipsoid_embedding, - get_mean_colors, + patch_trimesh, + 
uniformly_sample_from_mesh, + ellipsoid_embedding, + get_mean_colors, pack_transform, transform_from_gaussian ) @@ -38,15 +38,15 @@ # ---------- key = keysplit(key) n_components = 150 -noise = 0.0; +noise = 0.0; X = xs + np.random.randn(*xs.shape)*noise means_init = np.array(uniformly_sample_from_mesh(key, n_components, mesh, with_color=False)[0]); # FIT THE GMM # ----------- -gm = GaussianMixture(n_components=n_components, - tol=1e-3, max_iter=100, - covariance_type="full", +gm = GaussianMixture(n_components=n_components, + tol=1e-3, max_iter=100, + covariance_type="full", means_init=means_init).fit(X) mus = gm.means_ @@ -61,12 +61,38 @@ # AUTOGENERATED! DO NOT EDIT! File to edit: ../../scripts/_mkl/notebooks/05 - Trimesh to Gaussians.ipynb. # %% auto 0 -__all__ = ['Array', 'Shape', 'FaceIndex', 'FaceIndices', 'Array3', 'Array2', 'ArrayNx2', 'ArrayNx3', 'Matrix', 'PrecisionMatrix', - 'CovarianceMatrix', 'SquareMatrix', 'Vector', 'compute_area_and_normals', 'area_of_triangle', - 'patch_trimesh', 'texture_uv_basis', 'uv_to_color', 'barycentric_to_mesh', 'sample_from_face', - 'sample_from_mesh', 'get_colors_from_mesh', 'uniformly_sample_from_mesh', 'get_cluster_counts', - 'get_cluster_colors', 'get_mean_colors', 'ellipsoid_embedding', 'pack_transform', 'transform_from_gaussian', - 'create_ellipsoid_trimesh'] +__all__ = [ + "Array", + "Shape", + "FaceIndex", + "FaceIndices", + "Array3", + "Array2", + "ArrayNx2", + "ArrayNx3", + "Matrix", + "PrecisionMatrix", + "CovarianceMatrix", + "SquareMatrix", + "Vector", + "compute_area_and_normals", + "area_of_triangle", + "patch_trimesh", + "texture_uv_basis", + "uv_to_color", + "barycentric_to_mesh", + "sample_from_face", + "sample_from_mesh", + "get_colors_from_mesh", + "uniformly_sample_from_mesh", + "get_cluster_counts", + "get_cluster_colors", + "get_mean_colors", + "ellipsoid_embedding", + "pack_transform", + "transform_from_gaussian", + "create_ellipsoid_trimesh", +] # %% ../../scripts/_mkl/notebooks/05 - Trimesh to Gaussians.ipynb 2 _doc_ = """ @@ -76,10 +102,10 @@ **Example:** ```python from bayes3d._mkl.trimesh_to_gaussians import ( - patch_trimesh, - uniformly_sample_from_mesh, - ellipsoid_embedding, - get_mean_colors, + patch_trimesh, + uniformly_sample_from_mesh, + ellipsoid_embedding, + get_mean_colors, pack_transform, transform_from_gaussian ) @@ -109,15 +135,15 @@ # ---------- key = keysplit(key) n_components = 150 -noise = 0.0; +noise = 0.0; X = xs + np.random.randn(*xs.shape)*noise means_init = np.array(uniformly_sample_from_mesh(key, n_components, mesh, with_color=False)[0]); # FIT THE GMM # ----------- -gm = GaussianMixture(n_components=n_components, - tol=1e-3, max_iter=100, - covariance_type="full", +gm = GaussianMixture(n_components=n_components, + tol=1e-3, max_iter=100, + covariance_type="full", means_init=means_init).fit(X) mus = gm.means_ @@ -130,40 +156,38 @@ """ # %% ../../scripts/_mkl/notebooks/05 - Trimesh to Gaussians.ipynb 3 -import bayes3d as b3d -import trimesh -from bayes3d._mkl.utils import * -import os -import matplotlib.pyplot as plt -import numpy as np import jax -from jax import jit, vmap import jax.numpy as jnp -from typing import Any, NamedTuple import jaxlib +import numpy as np +import trimesh +from jax import jit, vmap + +from bayes3d._mkl.utils import * Array = np.ndarray | jax.Array Shape = int | tuple[int, ...] 
FaceIndex = int FaceIndices = Array -Array3 = Array -Array2 = Array -ArrayNx2 = Array -ArrayNx3 = Array -Matrix = jaxlib.xla_extension.ArrayImpl -PrecisionMatrix = Matrix +Array3 = Array +Array2 = Array +ArrayNx2 = Array +ArrayNx3 = Array +Matrix = jaxlib.xla_extension.ArrayImpl +PrecisionMatrix = Matrix CovarianceMatrix = Matrix -SquareMatrix = Matrix +SquareMatrix = Matrix Vector = Array + # %% ../../scripts/_mkl/notebooks/05 - Trimesh to Gaussians.ipynb 5 -def area_of_triangle(a:Array3, b:Array3, c:Array3=jnp.zeros(3)): +def area_of_triangle(a: Array3, b: Array3, c: Array3 = jnp.zeros(3)): """Computes the area of a triangle spanned by a,b[,c].""" - x = a-c - y = b-c + x = a - c + y = b - c w = jnp.linalg.norm(x) - h = jnp.linalg.norm(y - jnp.dot(x, y)/w**2 * x) - area = w*h/2 + h = jnp.linalg.norm(y - jnp.dot(x, y) / w**2 * x) + area = w * h / 2 return area @@ -172,16 +196,17 @@ def _compute_area_and_normal(f, vertices): a = vertices[f[1]] - vertices[f[0]] b = vertices[f[2]] - vertices[f[0]] area = area_of_triangle(a, b) - normal = jnp.cross(a,b) + normal = jnp.cross(a, b) return area, normal -compute_area_and_normals = jit(vmap(_compute_area_and_normal, (0,None))) +compute_area_and_normals = jit(vmap(_compute_area_and_normal, (0, None))) + # %% ../../scripts/_mkl/notebooks/05 - Trimesh to Gaussians.ipynb 6 -def patch_trimesh(mesh:trimesh.base.Trimesh): +def patch_trimesh(mesh: trimesh.base.Trimesh): """ - Return a patched copy of a trimesh object, and + Return a patched copy of a trimesh object, and ensure it to have a texture and the following attributes: - `mesh.visual.uv` - `copy.visual.material.to_color` @@ -195,64 +220,66 @@ def patch_trimesh(mesh:trimesh.base.Trimesh): return patched_mesh -def texture_uv_basis(face_idx:Array, mesh): +def texture_uv_basis(face_idx: Array, mesh): """ - Takes a face index and returns the three uv-vectors + Takes a face index and returns the three uv-vectors spanning the face in texture space. """ return mesh.visual.uv[mesh.faces[face_idx]] -def uv_to_color(uv:ArrayNx2, mesh): +def uv_to_color(uv: ArrayNx2, mesh): """Takes texture-uv coordinates and returns the corresponding color.""" return mesh.visual.material.to_color(uv) + # %% ../../scripts/_mkl/notebooks/05 - Trimesh to Gaussians.ipynb 7 -def barycentric_to_mesh(p:Array3, i:FaceIndex, mesh): +def barycentric_to_mesh(p: Array3, i: FaceIndex, mesh): """Converts a point in barycentric coordinates `p` on a face `i` to a 3d point on the mesh.""" - x = jnp.sum(p[:,None]*mesh.vertices[mesh.faces[i]], axis=0) + x = jnp.sum(p[:, None] * mesh.vertices[mesh.faces[i]], axis=0) return x def sample_from_face(key, n, i, mesh): """ - Sample random points `xs`, barycentric coordinates `ps`, and + Sample random points `xs`, barycentric coordinates `ps`, and face indices `fs` from a mesh. """ - _, key = keysplit(key,1,1) - ps = jax.random.dirichlet(key, jnp.ones(3), (n,)).reshape((n,3,1)) - xs = jnp.sum(ps*mesh.vertices[mesh.faces[i]], axis=1) + _, key = keysplit(key, 1, 1) + ps = jax.random.dirichlet(key, jnp.ones(3), (n,)).reshape((n, 3, 1)) + xs = jnp.sum(ps * mesh.vertices[mesh.faces[i]], axis=1) return xs, ps def sample_from_mesh(key, n, mesh): """ - Returns random points `xs`, barycentric coordinates `ps`, and + Returns random points `xs`, barycentric coordinates `ps`, and face indices `fs` from a mesh. """ - _, keys = keysplit(key,1,2) + _, keys = keysplit(key, 1, 2) - # Sample `n` faces from the mesh with - # probability proportional to their area. 
+ # Sample `n` faces from the mesh with + # probability proportional to their area. areas, _ = compute_area_and_normals(mesh.faces, mesh.vertices) fs = jax.random.categorical(keys[0], jnp.log(areas), shape=(n,)) # Sample barycentric coordinates `bs` for each sampled face # and compute the corresponding world coordinates `xs`. - ps = jax.random.dirichlet(keys[1], jnp.ones(3), (n,)).reshape((n,3,1)) - xs = jnp.sum(ps*mesh.vertices[mesh.faces[fs]], axis=1) + ps = jax.random.dirichlet(keys[1], jnp.ones(3), (n,)).reshape((n, 3, 1)) + xs = jnp.sum(ps * mesh.vertices[mesh.faces[fs]], axis=1) return xs, ps, fs - -def get_colors_from_mesh(ps:ArrayNx3, fs:FaceIndices, mesh): + +def get_colors_from_mesh(ps: ArrayNx3, fs: FaceIndices, mesh): """ - Returns the colors of the points on the mesh given + Returns the colors of the points on the mesh given their barycentric coordinates `ps` and face indices `fs`. """ uvs = jnp.sum(ps * texture_uv_basis(fs, mesh), axis=1) - cs = uv_to_color(uvs, mesh)/255 + cs = uv_to_color(uvs, mesh) / 255 return cs + # %% ../../scripts/_mkl/notebooks/05 - Trimesh to Gaussians.ipynb 9 def uniformly_sample_from_mesh(key, n, mesh, with_color=True): """Uniformly sample `n` points and optionally their color on the surface from a mesh.""" @@ -261,10 +288,11 @@ def uniformly_sample_from_mesh(key, n, mesh, with_color=True): if with_color: cs = get_colors_from_mesh(ps, fs, mesh) else: - cs = jnp.full((n,3), 0.5) + cs = jnp.full((n, 3), 0.5) return xs, cs + # %% ../../scripts/_mkl/notebooks/05 - Trimesh to Gaussians.ipynb 10 def get_cluster_counts(m, labels): nums = [] @@ -273,32 +301,33 @@ def get_cluster_counts(m, labels): return np.array(nums) -#|export +# |export def get_cluster_colors(cs, m, labels): colors = [] for label in range(m): colors.append(cs[labels == label]) return colors + # %% ../../scripts/_mkl/notebooks/05 - Trimesh to Gaussians.ipynb 11 def get_mean_colors(cs, n, labels): mean_colors = [] - nums = [] + nums = [] for label in range(n): idx = labels == label num = np.sum(idx) - if num == 0: + if num == 0: c = np.array([0.5, 0.5, 0.5, 0.0]) - else: + else: c = np.mean(cs[idx], axis=0) nums.append(num) mean_colors.append(c) return np.array(mean_colors), np.array(nums) - + # %% ../../scripts/_mkl/notebooks/05 - Trimesh to Gaussians.ipynb 12 -def ellipsoid_embedding(cov:CovarianceMatrix) -> Matrix: +def ellipsoid_embedding(cov: CovarianceMatrix) -> Matrix: """Returns A with cov = A@A.T""" sigma, U = jnp.linalg.eigh(cov) D = jnp.diag(jnp.sqrt(sigma)) @@ -307,25 +336,32 @@ def ellipsoid_embedding(cov:CovarianceMatrix) -> Matrix: # %% ../../scripts/_mkl/notebooks/05 - Trimesh to Gaussians.ipynb 13 def pack_transform(x, A, scale=1.0): - B = scale*A - return jnp.array([ - [B[0,0], B[0,1], B[0,2], x[0]], - [B[1,0], B[1,1], B[1,2], x[1]], - [B[2,0], B[2,1], B[2,2], x[2]], - [0.0, 0.0, 0.0, 1.0] - ]).T - - -def transform_from_gaussian(mu:Vector, cov:CovarianceMatrix=jnp.eye(3), scale=1.0) -> Matrix: + B = scale * A + return jnp.array( + [ + [B[0, 0], B[0, 1], B[0, 2], x[0]], + [B[1, 0], B[1, 1], B[1, 2], x[1]], + [B[2, 0], B[2, 1], B[2, 2], x[2]], + [0.0, 0.0, 0.0, 1.0], + ] + ).T + + +def transform_from_gaussian( + mu: Vector, cov: CovarianceMatrix = jnp.eye(3), scale=1.0 +) -> Matrix: """Returns an affine linear transformation 4x4 matrix from a Gaussian.""" A = ellipsoid_embedding(cov) B = scale * A - return jnp.array([ - [B[0,0], B[0,1], B[0,2], mu[0]], - [B[1,0], B[1,1], B[1,2], mu[1]], - [B[2,0], B[2,1], B[2,2], mu[2]], - [0.0, 0.0, 0.0, 1.0] - ]).T + return jnp.array( + 
[ + [B[0, 0], B[0, 1], B[0, 2], mu[0]], + [B[1, 0], B[1, 1], B[1, 2], mu[1]], + [B[2, 0], B[2, 1], B[2, 2], mu[2]], + [0.0, 0.0, 0.0, 1.0], + ] + ).T + # %% ../../scripts/_mkl/notebooks/05 - Trimesh to Gaussians.ipynb 14 def create_ellipsoid_trimesh(covariance_matrix, num_points=10, scale=0.02): @@ -339,13 +375,15 @@ def create_ellipsoid_trimesh(covariance_matrix, num_points=10, scale=0.02): # Transform the sphere to the ellipsoid sigma, U = np.linalg.eig(covariance_matrix) D = np.diag(np.sqrt(sigma)) - ellipsoid = U @ D @ np.linalg.inv(U) @ np.vstack([x.flatten(), y.flatten(), z.flatten()]) + ellipsoid = ( + U @ D @ np.linalg.inv(U) @ np.vstack([x.flatten(), y.flatten(), z.flatten()]) + ) # Reshape the ellipsoid to match the shape of the original sphere vertices ellipsoid = ellipsoid.T.reshape(num_points, num_points, 3) # Create mesh data - mesh_vertices = scale*ellipsoid.reshape(-1, 3) + mesh_vertices = scale * ellipsoid.reshape(-1, 3) mesh_faces = [] for i in range(num_points - 1): for j in range(num_points - 1): diff --git a/bayes3d/_mkl/types.py b/bayes3d/_mkl/types.py index 139626c3..47ce0a98 100644 --- a/bayes3d/_mkl/types.py +++ b/bayes3d/_mkl/types.py @@ -1,16 +1,33 @@ # AUTOGENERATED! DO NOT EDIT! File to edit: ../../scripts/_mkl/notebooks/00a - Types.ipynb. # %% auto 0 -__all__ = ['Array', 'Shape', 'Bool', 'Float', 'Int', 'FaceIndex', 'FaceIndices', 'ArrayN', 'Array3', 'Array2', 'ArrayNx2', - 'ArrayNx3', 'Matrix', 'PrecisionMatrix', 'CovarianceMatrix', 'CholeskyMatrix', 'SquareMatrix', 'Vector', - 'Direction', 'BaseVector'] +__all__ = [ + "Array", + "Shape", + "Bool", + "Float", + "Int", + "FaceIndex", + "FaceIndices", + "ArrayN", + "Array3", + "Array2", + "ArrayNx2", + "ArrayNx3", + "Matrix", + "PrecisionMatrix", + "CovarianceMatrix", + "CholeskyMatrix", + "SquareMatrix", + "Vector", + "Direction", + "BaseVector", +] # %% ../../scripts/_mkl/notebooks/00a - Types.ipynb 1 -from typing import Any, NamedTuple -import numpy as np import jax import jaxlib - +import numpy as np Array = np.ndarray | jax.Array Shape = int | tuple[int, ...] @@ -19,16 +36,16 @@ Int = Array FaceIndex = int FaceIndices = Array -ArrayN = Array -Array3 = Array -Array2 = Array -ArrayNx2 = Array -ArrayNx3 = Array -Matrix = jaxlib.xla_extension.ArrayImpl -PrecisionMatrix = Matrix +ArrayN = Array +Array3 = Array +Array2 = Array +ArrayNx2 = Array +ArrayNx3 = Array +Matrix = jaxlib.xla_extension.ArrayImpl +PrecisionMatrix = Matrix CovarianceMatrix = Matrix -CholeskyMatrix = Matrix -SquareMatrix = Matrix -Vector = Array -Direction = Vector +CholeskyMatrix = Matrix +SquareMatrix = Matrix +Vector = Array +Direction = Vector BaseVector = Vector diff --git a/bayes3d/_mkl/utils.py b/bayes3d/_mkl/utils.py index 279b3f65..85609306 100644 --- a/bayes3d/_mkl/utils.py +++ b/bayes3d/_mkl/utils.py @@ -1,88 +1,122 @@ # AUTOGENERATED! DO NOT EDIT! File to edit: ../../scripts/_mkl/notebooks/00b - Utils.ipynb. 
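As a sanity check on the Gaussian utilities reformatted above (`ellipsoid_embedding` and `transform_from_gaussian` in `bayes3d/_mkl/trimesh_to_gaussians.py`), here is a minimal sketch. It is an editor's illustration, not part of the patch; it relies only on the documented contract `cov = A @ A.T` and on the import path already used in the module's `_doc_` example.

```python
import jax.numpy as jnp

# Import path assumed from the module's own `_doc_` example above.
from bayes3d._mkl.trimesh_to_gaussians import (
    ellipsoid_embedding,
    transform_from_gaussian,
)

# A symmetric positive-definite covariance.
cov = jnp.array([[2.0, 0.5, 0.0], [0.5, 1.0, 0.0], [0.0, 0.0, 0.5]])

# Per its docstring, the embedding A reproduces the covariance: cov = A @ A.T.
A = ellipsoid_embedding(cov)
assert jnp.allclose(A @ A.T, cov, atol=1e-5)

# The packed 4x4 carries the scaled embedding plus the mean; because
# `transform_from_gaussian` transposes before returning, the mean sits in
# the last row of the returned matrix.
mu = jnp.array([1.0, 2.0, 3.0])
T = transform_from_gaussian(mu, cov, scale=1.0)
assert T.shape == (4, 4)
assert jnp.allclose(T[3, :3], mu)
```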
# %% auto 0 -__all__ = ['key', 'logsumexp', 'cls', 'keysplit', 'bounding_box', 'argmax_axes', 'cam_to_screen', 'screen_to_cam', 'rot2d', - 'pack_2dpose', 'apply_2dpose', 'unit_vec', 'adjust_angle', 'argdiffs', 'Args', 'genjax_sample', - 'deff_gen_func_call', 'deff_gen_func_logpdf'] +__all__ = [ + "key", + "logsumexp", + "cls", + "keysplit", + "bounding_box", + "argmax_axes", + "cam_to_screen", + "screen_to_cam", + "rot2d", + "pack_2dpose", + "apply_2dpose", + "unit_vec", + "adjust_angle", + "argdiffs", + "Args", + "genjax_sample", + "deff_gen_func_call", + "deff_gen_func_logpdf", +] # %% ../../scripts/_mkl/notebooks/00b - Utils.ipynb 2 -import matplotlib.pyplot as plt -from matplotlib.collections import LineCollection -import numpy as np +import genjax import jax import jax.numpy as jnp -import genjax +import numpy as np # %% ../../scripts/_mkl/notebooks/00b - Utils.ipynb 4 -key = jax.random.PRNGKey(0) +key = jax.random.PRNGKey(0) logsumexp = jax.scipy.special.logsumexp + # %% ../../scripts/_mkl/notebooks/00b - Utils.ipynb 5 def keysplit(key, *ns): - if len(ns) == 0: + if len(ns) == 0: return jax.random.split(key, 1)[0] elif len(ns) == 1: - n, = ns - if n == 1: return keysplit(key) - else: return jax.random.split(key, ns[0]) + (n,) = ns + if n == 1: + return keysplit(key) + else: + return jax.random.split(key, ns[0]) else: keys = [] - for n in ns: keys.append(keysplit(key, n)) + for n in ns: + keys.append(keysplit(key, n)) return keys # %% ../../scripts/_mkl/notebooks/00b - Utils.ipynb 7 def bounding_box(arr, pad=0): """Takes a euclidean-like arr (`arr.shape[-1] == 2`) and returns its bounding box.""" - return jnp.array([ - [jnp.min(arr[...,0])-pad, jnp.min(arr[...,1])-pad], - [jnp.max(arr[...,0])+pad, jnp.max(arr[...,1])+pad] - ]) + return jnp.array( + [ + [jnp.min(arr[..., 0]) - pad, jnp.min(arr[..., 1]) - pad], + [jnp.max(arr[..., 0]) + pad, jnp.max(arr[..., 1]) + pad], + ] + ) + # %% ../../scripts/_mkl/notebooks/00b - Utils.ipynb 8 def argmax_axes(a, axes=None): """Argmax along specified axes""" - if axes is None: return jnp.argmax(a) - - n = len(axes) - axes_ = set(range(a.ndim)) + if axes is None: + return jnp.argmax(a) + + n = len(axes) + axes_ = set(range(a.ndim)) axes_0 = axes - axes_1 = sorted(axes_ - set(axes_0)) - axes_ = axes_0 + axes_1 + axes_1 = sorted(axes_ - set(axes_0)) + axes_ = axes_0 + axes_1 b = jnp.transpose(a, axes=axes_) c = b.reshape(np.prod(b.shape[:n]), -1) I = jnp.argmax(c, axis=0) - I = jnp.array([jnp.unravel_index(i, b.shape[:n]) for i in I]).reshape(b.shape[n:] + (n,)) + I = jnp.array([jnp.unravel_index(i, b.shape[:n]) for i in I]).reshape( + b.shape[n:] + (n,) + ) + + return I - return I # %% ../../scripts/_mkl/notebooks/00b - Utils.ipynb 11 -def cam_to_screen(x): return jnp.array([x[0]/x[2], x[1]/x[2], jnp.linalg.norm(x)]) -def screen_to_cam(y): return y[2]*jnp.array([y[0], y[1], 1.0]) +def cam_to_screen(x): + return jnp.array([x[0] / x[2], x[1] / x[2], jnp.linalg.norm(x)]) + + +def screen_to_cam(y): + return y[2] * jnp.array([y[0], y[1], 1.0]) + # %% ../../scripts/_mkl/notebooks/00b - Utils.ipynb 12 -def rot2d(hd): return jnp.array([ - [jnp.cos(hd), -jnp.sin(hd)], - [jnp.sin(hd), jnp.cos(hd)] - ]); +def rot2d(hd): + return jnp.array([[jnp.cos(hd), -jnp.sin(hd)], [jnp.sin(hd), jnp.cos(hd)]]) + -def pack_2dpose(x,hd): - return jnp.concatenate([x,jnp.array([hd])]) +def pack_2dpose(x, hd): + return jnp.concatenate([x, jnp.array([hd])]) -def apply_2dpose(p, ys): - return ys@rot2d(p[2] - jnp.pi/2).T + p[:2] -def unit_vec(hd): +def apply_2dpose(p, ys): + return 
ys @ rot2d(p[2] - jnp.pi / 2).T + p[:2] + + +def unit_vec(hd): return jnp.array([jnp.cos(hd), jnp.sin(hd)]) + def adjust_angle(hd): """Adjusts angle to lie in the interval [-pi,pi).""" - return (hd + jnp.pi)%(2*jnp.pi) - jnp.pi + return (hd + jnp.pi) % (2 * jnp.pi) - jnp.pi + # %% ../../scripts/_mkl/notebooks/00b - Utils.ipynb 14 -from genjax.incremental import UnknownChange, NoChange, Diff +from genjax.incremental import Diff, UnknownChange def argdiffs(args, other=None): @@ -90,47 +124,53 @@ def argdiffs(args, other=None): # %% ../../scripts/_mkl/notebooks/00b - Utils.ipynb 15 -from builtins import property as _property, tuple as _tuple +from builtins import tuple as _tuple from typing import Any class Args(tuple): def __new__(cls, *args, **kwargs): return _tuple.__new__(cls, list(args) + list(kwargs.values())) - + def __init__(self, *args, **kwargs): self._d = dict() - for k,v in kwargs.items(): + for k, v in kwargs.items(): self._d[k] = v setattr(self, k, v) def __getitem__(self, k: str) -> Any: return self._d[k] + # %% ../../scripts/_mkl/notebooks/00b - Utils.ipynb 17 -# +# # Monkey patching `sample` for `BuiltinGenerativeFunction` -# +# cls = genjax._src.generative_functions.static.static_gen_fn.StaticGenerativeFunction + def genjax_sample(self, key, *args, **kwargs): tr = self.simulate(key, args) return tr.get_retval() + setattr(cls, "sample", genjax_sample) -# +# # Monkey patching `sample` for `DeferredGenerativeFunctionCall` -# +# cls = genjax._src.generative_functions.supports_callees.SugaredGenerativeFunctionCall + def deff_gen_func_call(self, key, **kwargs): return self.gen_fn.sample(key, *self.args, **kwargs) + def deff_gen_func_logpdf(self, x, **kwargs): return self.gen_fn.logpdf(x, *self.args, **kwargs) + setattr(cls, "__call__", deff_gen_func_call) setattr(cls, "sample", deff_gen_func_call) setattr(cls, "logpdf", deff_gen_func_logpdf) diff --git a/bayes3d/camera.py b/bayes3d/camera.py index 0c7ea8e0..b71e2204 100644 --- a/bayes3d/camera.py +++ b/bayes3d/camera.py @@ -1,11 +1,17 @@ +from collections import namedtuple + import jax.numpy as jnp import numpy as np + import bayes3d as b + from .transforms_3d import add_homogenous_ones -from collections import namedtuple # Declaring namedtuple() -Intrinsics = namedtuple('Intrinsics', ['height', 'width', 'fx', 'fy', 'cx', 'cy', 'near', 'far']) +Intrinsics = namedtuple( + "Intrinsics", ["height", "width", "fx", "fy", "cx", "cy", "near", "far"] +) + def K_from_intrinsics(intrinsics): """Returns the camera matrix from the intrinsics. @@ -15,11 +21,14 @@ def K_from_intrinsics(intrinsics): Returns: (np.ndarray) The camera matrix K (3x3). """ - return np.array([ - [intrinsics.fx ,0.0, intrinsics.cx], - [0.0 , intrinsics.fy, intrinsics.cy], - [0.0 ,0.0, 1.0], - ]) + return np.array( + [ + [intrinsics.fx, 0.0, intrinsics.cx], + [0.0, intrinsics.fy, intrinsics.cy], + [0.0, 0.0, 1.0], + ] + ) + def scale_camera_parameters(intrinsics, scaling_factor): """Scale the camera parameters by a given factor. 
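For reference, a short usage sketch of the `Intrinsics` record and `K_from_intrinsics` from the hunks above (illustration only, not part of the patch; the numbers are made up and the import path is assumed to follow the file layout):

```python
import numpy as np

# Import path assumed from bayes3d/camera.py.
from bayes3d.camera import Intrinsics, K_from_intrinsics

# A typical VGA-sized pinhole camera (made-up values).
intrinsics = Intrinsics(
    height=480, width=640, fx=500.0, fy=500.0, cx=320.0, cy=240.0, near=0.01, far=10.0
)

# K packs (fx, fy, cx, cy) into the usual 3x3 camera matrix.
K = K_from_intrinsics(intrinsics)
assert K.shape == (3, 3)
assert np.allclose(K[0], [500.0, 0.0, 320.0])
```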
@@ -35,9 +44,12 @@ def scale_camera_parameters(intrinsics, scaling_factor): new_cx = intrinsics.cx * scaling_factor new_cy = intrinsics.cy * scaling_factor - new_h = int(np.round(intrinsics.height * scaling_factor)) + new_h = int(np.round(intrinsics.height * scaling_factor)) new_w = int(np.round(intrinsics.width * scaling_factor)) - return Intrinsics(new_h, new_w, new_fx, new_fy, new_cx, new_cy, intrinsics.near, intrinsics.far) + return Intrinsics( + new_h, new_w, new_fx, new_fy, new_cx, new_cy, intrinsics.near, intrinsics.far + ) + def camera_rays_from_intrinsics(intrinsics): """Returns the camera rays from the intrinsics. @@ -47,16 +59,20 @@ def camera_rays_from_intrinsics(intrinsics): Returns: (np.ndarray) The camera rays (height x width x 3). """ - rows, cols = jnp.meshgrid(jnp.arange(intrinsics.width), jnp.arange(intrinsics.height)) - pixel_coords = jnp.stack([rows,cols],axis=-1) - pixel_coords_dir = (pixel_coords - jnp.array([intrinsics.cx, intrinsics.cy])) / jnp.array([intrinsics.fx, intrinsics.fy]) + rows, cols = jnp.meshgrid( + jnp.arange(intrinsics.width), jnp.arange(intrinsics.height) + ) + pixel_coords = jnp.stack([rows, cols], axis=-1) + pixel_coords_dir = ( + pixel_coords - jnp.array([intrinsics.cx, intrinsics.cy]) + ) / jnp.array([intrinsics.fx, intrinsics.fy]) pixel_coords_dir_h = add_homogenous_ones(pixel_coords_dir) return pixel_coords_dir_h def project_cloud_to_pixels(point_cloud, intrinsics): """Project a point cloud to pixels. - + Args: point_cloud (jnp.ndarray): The point cloud. Shape (N, 3) intrinsics (bayes3d.camera.Intrinsics): The camera intrinsics. @@ -64,13 +80,14 @@ def project_cloud_to_pixels(point_cloud, intrinsics): jnp.ndarray: The pixels. Shape (N, 2) """ point_cloud_normalized = point_cloud / point_cloud[:, 2].reshape(-1, 1) - temp1 = point_cloud_normalized[:, :2] * jnp.array([intrinsics.fx,intrinsics.fy]) + temp1 = point_cloud_normalized[:, :2] * jnp.array([intrinsics.fx, intrinsics.fy]) pixels = temp1 + jnp.array([intrinsics.cx, intrinsics.cy]) return pixels + def render_point_cloud(point_cloud, intrinsics, pixel_smudge=1): """Render a point cloud to an image. - + Args: point_cloud (jnp.ndarray): The point cloud. Shape (N, 3) intrinsics (bayes3d.camera.Intrinsics): The camera intrinsics. @@ -81,14 +98,17 @@ def render_point_cloud(point_cloud, intrinsics, pixel_smudge=1): point_cloud = jnp.vstack([jnp.zeros((1, 3)), transformed_cloud]) pixels = project_cloud_to_pixels(point_cloud, intrinsics) x, y = jnp.meshgrid(jnp.arange(intrinsics.width), jnp.arange(intrinsics.height)) - matches = (jnp.abs(x[:, :, None] - pixels[:, 0]) <= pixel_smudge) & (jnp.abs(y[:, :, None] - pixels[:, 1]) <= pixel_smudge) - matches = matches * (intrinsics.far * 2.0 - point_cloud[:,-1][None, None, :]) - a = jnp.argmax(matches, axis=-1) + matches = (jnp.abs(x[:, :, None] - pixels[:, 0]) <= pixel_smudge) & ( + jnp.abs(y[:, :, None] - pixels[:, 1]) <= pixel_smudge + ) + matches = matches * (intrinsics.far * 2.0 - point_cloud[:, -1][None, None, :]) + a = jnp.argmax(matches, axis=-1) return point_cloud[a] + def render_point_cloud_batched(point_cloud, intrinsics, NUM_PER, pixel_smudge=1): """Render a point cloud to an image in batches. - + Args: point_cloud (jnp.ndarray): The point cloud. Shape (N, 3) intrinsics (bayes3d.camera.Intrinsics): The camera intrinsics. 
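A companion sketch for `project_cloud_to_pixels` above (again an editor's illustration under the same assumptions): each point is divided by its depth and mapped through (fx, fy) and (cx, cy), so a point on the optical axis lands exactly at (cx, cy).

```python
import jax.numpy as jnp

from bayes3d.camera import Intrinsics, project_cloud_to_pixels

intrinsics = Intrinsics(480, 640, 500.0, 500.0, 320.0, 240.0, 0.01, 10.0)
cloud = jnp.array(
    [
        [0.0, 0.0, 2.0],  # on the optical axis
        [0.1, 0.0, 1.0],  # 0.1 to the right at depth 1.0
    ]
)
pixels = project_cloud_to_pixels(cloud, intrinsics)
assert jnp.allclose(pixels[0], jnp.array([320.0, 240.0]))  # maps to (cx, cy)
assert jnp.allclose(pixels[1], jnp.array([370.0, 240.0]))  # 0.1 / 1.0 * fx + cx
```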
@@ -98,16 +118,20 @@ def render_point_cloud_batched(point_cloud, intrinsics, NUM_PER, pixel_smudge=1) all_images = [] num_iters = jnp.ceil(point_cloud.shape[0] / NUM_PER).astype(jnp.int32) for i in range(num_iters): - img = b.render_point_cloud(point_cloud[i*NUM_PER:i*NUM_PER+NUM_PER], intrinsics, pixel_smudge=pixel_smudge) - img = img.at[img[:,:,2] < intrinsics.near].set(intrinsics.far) + img = b.render_point_cloud( + point_cloud[i * NUM_PER : i * NUM_PER + NUM_PER], + intrinsics, + pixel_smudge=pixel_smudge, + ) + img = img.at[img[:, :, 2] < intrinsics.near].set(intrinsics.far) all_images.append(img) - all_images_stack = jnp.stack(all_images,axis=-2) - best = all_images_stack[:,:,:,2].argmin(-1) + all_images_stack = jnp.stack(all_images, axis=-2) + best = all_images_stack[:, :, :, 2].argmin(-1) img = all_images_stack[ np.arange(intrinsics.height)[:, None], np.arange(intrinsics.width)[None, :], best, - : + :, ] return img @@ -119,12 +143,14 @@ def _open_gl_projection_matrix(h, w, fx, fy, cx, cy, near, far): # see http://ksimek.github.io/2013/06/03/calibrated_cameras_in_opengl/ persp = np.zeros((4, 4)) - persp = jnp.array([ - [fx, 0.0, -cx, 0.0], - [0.0, -fy, -cy, 0.0], - [0.0, 0.0, -near+far, near*far], - [0.0, 0.0, -1, 0.0], - ]) + persp = jnp.array( + [ + [fx, 0.0, -cx, 0.0], + [0.0, -fy, -cy, 0.0], + [0.0, 0.0, -near + far, near * far], + [0.0, 0.0, -1, 0.0], + ] + ) # persp[0, 0] = fx # persp[1, 1] = fy # persp[0, 2] = cx @@ -149,6 +175,7 @@ def _open_gl_projection_matrix(h, w, fx, fy, cx, cy, near, far): ) return orth @ persp @ view + def getProjectionMatrix(intrinsics): top = intrinsics.near / intrinsics.fy * intrinsics.height / 2.0 bottom = -top @@ -160,7 +187,11 @@ def getProjectionMatrix(intrinsics): P = P.at[1, 1].set(2.0 * intrinsics.near / (top - bottom)) P = P.at[0, 2].set((right + left) / (right - left)) P = P.at[1, 2].set((top + bottom) / (top - bottom)) - P = P.at[2, 2].set(z_sign * (intrinsics.far + intrinsics.near) / (intrinsics.far - intrinsics.near)) - P = P.at[2, 3].set(-2.0 * (intrinsics.far * intrinsics.near) / (intrinsics.far - intrinsics.near)) + P = P.at[2, 2].set( + z_sign * (intrinsics.far + intrinsics.near) / (intrinsics.far - intrinsics.near) + ) + P = P.at[2, 3].set( + -2.0 * (intrinsics.far * intrinsics.near) / (intrinsics.far - intrinsics.near) + ) P = P.at[3, 2].set(z_sign) - return jnp.transpose(P) \ No newline at end of file + return jnp.transpose(P) diff --git a/bayes3d/colmap/colmap_loader.py b/bayes3d/colmap/colmap_loader.py index 0ee0b1fd..c9fcb6f1 100644 --- a/bayes3d/colmap/colmap_loader.py +++ b/bayes3d/colmap/colmap_loader.py @@ -3,24 +3,27 @@ # GRAPHDECO research group, https://team.inria.fr/graphdeco # All rights reserved. # -# This software is free for non-commercial, research and evaluation use +# This software is free for non-commercial, research and evaluation use # under the terms of the LICENSE.md file. 
# # For inquiries contact george.drettakis@inria.fr # -import numpy as np import collections import struct +import numpy as np + CameraModel = collections.namedtuple( - "CameraModel", ["model_id", "model_name", "num_params"]) -Camera = collections.namedtuple( - "Camera", ["id", "model", "width", "height", "params"]) + "CameraModel", ["model_id", "model_name", "num_params"] +) +Camera = collections.namedtuple("Camera", ["id", "model", "width", "height", "params"]) BaseImage = collections.namedtuple( - "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"]) + "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"] +) Point3D = collections.namedtuple( - "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"]) + "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"] +) CAMERA_MODELS = { CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), CameraModel(model_id=1, model_name="PINHOLE", num_params=4), @@ -32,43 +35,63 @@ CameraModel(model_id=7, model_name="FOV", num_params=5), CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), - CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12) + CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12), } -CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model) - for camera_model in CAMERA_MODELS]) -CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model) - for camera_model in CAMERA_MODELS]) +CAMERA_MODEL_IDS = dict( + [(camera_model.model_id, camera_model) for camera_model in CAMERA_MODELS] +) +CAMERA_MODEL_NAMES = dict( + [(camera_model.model_name, camera_model) for camera_model in CAMERA_MODELS] +) def qvec2rotmat(qvec): - return np.array([ - [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2, - 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], - 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]], - [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], - 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2, - 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]], - [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], - 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], - 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]]) + return np.array( + [ + [ + 1 - 2 * qvec[2] ** 2 - 2 * qvec[3] ** 2, + 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], + 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2], + ], + [ + 2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], + 1 - 2 * qvec[1] ** 2 - 2 * qvec[3] ** 2, + 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1], + ], + [ + 2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], + 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], + 1 - 2 * qvec[1] ** 2 - 2 * qvec[2] ** 2, + ], + ] + ) + def rotmat2qvec(R): Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat - K = np.array([ - [Rxx - Ryy - Rzz, 0, 0, 0], - [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], - [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], - [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0 + K = ( + np.array( + [ + [Rxx - Ryy - Rzz, 0, 0, 0], + [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], + [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], + [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz], + ] + ) + / 3.0 + ) eigvals, eigvecs = np.linalg.eigh(K) qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] if qvec[0] < 0: qvec *= -1 return qvec + class Image(BaseImage): def qvec2rotmat(self): return qvec2rotmat(self.qvec) + def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): """Read and unpack the next bytes from a binary file. 
    :param fid:
@@ -80,6 +103,7 @@ def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"):
    data = fid.read(num_bytes)
    return struct.unpack(endian_character + format_char_sequence, data)

+
def read_points3D_text(path):
    """
    see: src/base/reconstruction.cc
@@ -99,7 +123,6 @@
            if len(line) > 0 and line[0] != "#":
                num_points += 1

-
    xyzs = np.empty((num_points, 3))
    rgbs = np.empty((num_points, 3))
    errors = np.empty((num_points, 1))
@@ -122,6 +145,7 @@

    return xyzs, rgbs, errors

+
def read_points3D_binary(path_to_model_file):
    """
    see: src/base/reconstruction.cc
@@ -129,7 +153,6 @@
    void Reconstruction::WritePoints3DBinary(const std::string& path)
    """

-
    with open(path_to_model_file, "rb") as fid:
        num_points = read_next_bytes(fid, 8, "Q")[0]

@@ -139,20 +162,25 @@

        for p_id in range(num_points):
            binary_point_line_properties = read_next_bytes(
-                fid, num_bytes=43, format_char_sequence="QdddBBBd")
+                fid, num_bytes=43, format_char_sequence="QdddBBBd"
+            )
            xyz = np.array(binary_point_line_properties[1:4])
            rgb = np.array(binary_point_line_properties[4:7])
            error = np.array(binary_point_line_properties[7])
-            track_length = read_next_bytes(
-                fid, num_bytes=8, format_char_sequence="Q")[0]
+            track_length = read_next_bytes(fid, num_bytes=8, format_char_sequence="Q")[
+                0
+            ]
            track_elems = read_next_bytes(
-                fid, num_bytes=8*track_length,
-                format_char_sequence="ii"*track_length)
+                fid,
+                num_bytes=8 * track_length,
+                format_char_sequence="ii" * track_length,
+            )
            xyzs[p_id] = xyz
            rgbs[p_id] = rgb
            errors[p_id] = error
    return xyzs, rgbs, errors

+
def read_intrinsics_text(path):
    """
    Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py
@@ -168,15 +196,18 @@
            elems = line.split()
            camera_id = int(elems[0])
            model = elems[1]
-            assert model == "PINHOLE", "While the loader support other types, the rest of the code assumes PINHOLE"
+            assert (
+                model == "PINHOLE"
+            ), "While the loader supports other types, the rest of the code assumes PINHOLE"
            width = int(elems[2])
            height = int(elems[3])
            params = np.array(tuple(map(float, elems[4:])))
-            cameras[camera_id] = Camera(id=camera_id, model=model,
-                                        width=width, height=height,
-                                        params=params)
+            cameras[camera_id] = Camera(
+                id=camera_id, model=model, width=width, height=height, params=params
+            )
    return cameras

+
def read_extrinsics_binary(path_to_model_file):
    """
    see: src/base/reconstruction.cc
@@ -188,27 +219,38 @@
        num_reg_images = read_next_bytes(fid, 8, "Q")[0]
        for _ in range(num_reg_images):
            binary_image_properties = read_next_bytes(
-                fid, num_bytes=64, format_char_sequence="idddddddi")
+                fid, num_bytes=64, format_char_sequence="idddddddi"
+            )
            image_id = binary_image_properties[0]
            qvec = np.array(binary_image_properties[1:5])
            tvec = np.array(binary_image_properties[5:8])
            camera_id = binary_image_properties[8]
            image_name = ""
            current_char = read_next_bytes(fid, 1, "c")[0]
-            while current_char != b"\x00": # look for the ASCII 0 entry
+            while current_char != b"\x00":  # look for the ASCII 0 entry
                image_name += current_char.decode("utf-8")
                current_char = read_next_bytes(fid, 1, "c")[0]
-            num_points2D = read_next_bytes(fid, num_bytes=8,
-                                           format_char_sequence="Q")[0]
-            x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D,
-                                       format_char_sequence="ddq"*num_points2D)
-            xys = 
np.column_stack([tuple(map(float, x_y_id_s[0::3])), - tuple(map(float, x_y_id_s[1::3]))]) + num_points2D = read_next_bytes(fid, num_bytes=8, format_char_sequence="Q")[ + 0 + ] + x_y_id_s = read_next_bytes( + fid, + num_bytes=24 * num_points2D, + format_char_sequence="ddq" * num_points2D, + ) + xys = np.column_stack( + [tuple(map(float, x_y_id_s[0::3])), tuple(map(float, x_y_id_s[1::3]))] + ) point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) images[image_id] = Image( - id=image_id, qvec=qvec, tvec=tvec, - camera_id=camera_id, name=image_name, - xys=xys, point3D_ids=point3D_ids) + id=image_id, + qvec=qvec, + tvec=tvec, + camera_id=camera_id, + name=image_name, + xys=xys, + point3D_ids=point3D_ids, + ) return images @@ -223,20 +265,24 @@ def read_intrinsics_binary(path_to_model_file): num_cameras = read_next_bytes(fid, 8, "Q")[0] for _ in range(num_cameras): camera_properties = read_next_bytes( - fid, num_bytes=24, format_char_sequence="iiQQ") + fid, num_bytes=24, format_char_sequence="iiQQ" + ) camera_id = camera_properties[0] model_id = camera_properties[1] model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name width = camera_properties[2] height = camera_properties[3] num_params = CAMERA_MODEL_IDS[model_id].num_params - params = read_next_bytes(fid, num_bytes=8*num_params, - format_char_sequence="d"*num_params) - cameras[camera_id] = Camera(id=camera_id, - model=model_name, - width=width, - height=height, - params=np.array(params)) + params = read_next_bytes( + fid, num_bytes=8 * num_params, format_char_sequence="d" * num_params + ) + cameras[camera_id] = Camera( + id=camera_id, + model=model_name, + width=width, + height=height, + params=np.array(params), + ) assert len(cameras) == num_cameras return cameras @@ -260,13 +306,19 @@ def read_extrinsics_text(path): camera_id = int(elems[8]) image_name = elems[9] elems = fid.readline().split() - xys = np.column_stack([tuple(map(float, elems[0::3])), - tuple(map(float, elems[1::3]))]) + xys = np.column_stack( + [tuple(map(float, elems[0::3])), tuple(map(float, elems[1::3]))] + ) point3D_ids = np.array(tuple(map(int, elems[2::3]))) images[image_id] = Image( - id=image_id, qvec=qvec, tvec=tvec, - camera_id=camera_id, name=image_name, - xys=xys, point3D_ids=point3D_ids) + id=image_id, + qvec=qvec, + tvec=tvec, + camera_id=camera_id, + name=image_name, + xys=xys, + point3D_ids=point3D_ids, + ) return images @@ -278,8 +330,9 @@ def read_colmap_bin_array(path): :return: nd array with the floating point values in the value """ with open(path, "rb") as fid: - width, height, channels = np.genfromtxt(fid, delimiter="&", max_rows=1, - usecols=(0, 1, 2), dtype=int) + width, height, channels = np.genfromtxt( + fid, delimiter="&", max_rows=1, usecols=(0, 1, 2), dtype=int + ) fid.seek(0) num_delimiter = 0 byte = fid.read(1) @@ -291,4 +344,4 @@ def read_colmap_bin_array(path): byte = fid.read(1) array = np.fromfile(fid, np.float32) array = array.reshape((width, height, channels), order="F") - return np.transpose(array, (1, 0, 2)).squeeze() \ No newline at end of file + return np.transpose(array, (1, 0, 2)).squeeze() diff --git a/bayes3d/colmap/colmap_utils.py b/bayes3d/colmap/colmap_utils.py index 1ee9052c..b8027b8d 100644 --- a/bayes3d/colmap/colmap_utils.py +++ b/bayes3d/colmap/colmap_utils.py @@ -26,21 +26,24 @@ # GRAPHDECO research group, https://team.inria.fr/graphdeco # All rights reserved. 
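For the quaternion helpers in `colmap_loader.py` above, a round-trip sketch (editor's illustration, not part of the patch; the import path is assumed from the file layout). COLMAP stores rotations as wxyz quaternions, and `rotmat2qvec` flips the sign so the scalar part is non-negative:

```python
import numpy as np

from bayes3d.colmap.colmap_loader import qvec2rotmat, rotmat2qvec

qvec = np.array([0.8, 0.2, 0.4, 0.4])  # unit-norm wxyz quaternion
R = qvec2rotmat(qvec)
assert np.allclose(R @ R.T, np.eye(3), atol=1e-8)  # a proper rotation matrix
assert np.allclose(rotmat2qvec(R), qvec, atol=1e-6)  # round trip recovers qvec
```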
# -# This software is free for non-commercial, research and evaluation use +# This software is free for non-commercial, research and evaluation use # under the terms of the LICENSE.md file. # # For inquiries contact george.drettakis@inria.fr # -import torch import math -import numpy as np from typing import NamedTuple +import numpy as np +import torch + + class BasicPointCloud(NamedTuple): - points : np.array - colors : np.array - normals : np.array + points: np.array + colors: np.array + normals: np.array + def geom_transform_points(points, transf_matrix): P, _ = points.shape @@ -51,6 +54,7 @@ def geom_transform_points(points, transf_matrix): denom = points_out[..., 3:] + 0.0000001 return (points_out[..., :3] / denom).squeeze(dim=0) + def getWorld2View(R, t): Rt = np.zeros((4, 4)) Rt[:3, :3] = R.transpose() @@ -58,7 +62,8 @@ def getWorld2View(R, t): Rt[3, 3] = 1.0 return np.float32(Rt) -def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0): + +def getWorld2View2(R, t, translate=np.array([0.0, 0.0, 0.0]), scale=1.0): Rt = np.zeros((4, 4)) Rt[:3, :3] = R.transpose() Rt[:3, 3] = t @@ -71,6 +76,7 @@ def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0): Rt = np.linalg.inv(C2W) return np.float32(Rt) + def getProjectionMatrix(znear, zfar, fovX, fovY): tanHalfFovY = math.tan((fovY / 2)) tanHalfFovX = math.tan((fovX / 2)) @@ -93,13 +99,14 @@ def getProjectionMatrix(znear, zfar, fovX, fovY): P[2, 3] = -(zfar * znear) / (zfar - znear) return P + def fov2focal(fov, pixels): return pixels / (2 * math.tan(fov / 2)) + def focal2fov(focal, pixels): - return 2*math.atan(pixels/(2*focal)) + return 2 * math.atan(pixels / (2 * focal)) -import torch C0 = 0.28209479177387814 C1 = 0.4886025119029199 @@ -108,7 +115,7 @@ def focal2fov(focal, pixels): -1.0925484305920792, 0.31539156525252005, -1.0925484305920792, - 0.5462742152960396 + 0.5462742152960396, ] C3 = [ -0.5900435899266435, @@ -117,7 +124,7 @@ def focal2fov(focal, pixels): 0.3731763325901154, -0.4570457994644658, 1.445305721320277, - -0.5900435899266435 + -0.5900435899266435, ] C4 = [ 2.5033429417967046, @@ -129,7 +136,7 @@ def focal2fov(focal, pixels): 0.47308734787878004, -1.7701307697799304, 0.6258357354491761, -] +] def eval_sh(deg, sh, dirs): @@ -152,45 +159,55 @@ def eval_sh(deg, sh, dirs): result = C0 * sh[..., 0] if deg > 0: x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] - result = (result - - C1 * y * sh[..., 1] + - C1 * z * sh[..., 2] - - C1 * x * sh[..., 3]) + result = ( + result - C1 * y * sh[..., 1] + C1 * z * sh[..., 2] - C1 * x * sh[..., 3] + ) if deg > 1: xx, yy, zz = x * x, y * y, z * z xy, yz, xz = x * y, y * z, x * z - result = (result + - C2[0] * xy * sh[..., 4] + - C2[1] * yz * sh[..., 5] + - C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + - C2[3] * xz * sh[..., 7] + - C2[4] * (xx - yy) * sh[..., 8]) + result = ( + result + + C2[0] * xy * sh[..., 4] + + C2[1] * yz * sh[..., 5] + + C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + + C2[3] * xz * sh[..., 7] + + C2[4] * (xx - yy) * sh[..., 8] + ) if deg > 2: - result = (result + - C3[0] * y * (3 * xx - yy) * sh[..., 9] + - C3[1] * xy * z * sh[..., 10] + - C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] + - C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + - C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + - C3[5] * z * (xx - yy) * sh[..., 14] + - C3[6] * x * (xx - 3 * yy) * sh[..., 15]) + result = ( + result + + C3[0] * y * (3 * xx - yy) * sh[..., 9] + + C3[1] * xy * z * sh[..., 10] + + C3[2] * y * (4 * zz - xx - yy) * sh[..., 11] + + C3[3] * z * (2 * zz 
- 3 * xx - 3 * yy) * sh[..., 12] + + C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + + C3[5] * z * (xx - yy) * sh[..., 14] + + C3[6] * x * (xx - 3 * yy) * sh[..., 15] + ) if deg > 3: - result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] + - C4[1] * yz * (3 * xx - yy) * sh[..., 17] + - C4[2] * xy * (7 * zz - 1) * sh[..., 18] + - C4[3] * yz * (7 * zz - 3) * sh[..., 19] + - C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + - C4[5] * xz * (7 * zz - 3) * sh[..., 21] + - C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + - C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + - C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]) + result = ( + result + + C4[0] * xy * (xx - yy) * sh[..., 16] + + C4[1] * yz * (3 * xx - yy) * sh[..., 17] + + C4[2] * xy * (7 * zz - 1) * sh[..., 18] + + C4[3] * yz * (7 * zz - 3) * sh[..., 19] + + C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + + C4[5] * xz * (7 * zz - 3) * sh[..., 21] + + C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + + C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + + C4[8] + * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) + * sh[..., 24] + ) return result + def RGB2SH(rgb): return (rgb - 0.5) / C0 + def SH2RGB(sh): - return sh * C0 + 0.5 \ No newline at end of file + return sh * C0 + 0.5 diff --git a/bayes3d/colmap/dataset_loader.py b/bayes3d/colmap/dataset_loader.py index f33f9dfc..88f87a81 100644 --- a/bayes3d/colmap/dataset_loader.py +++ b/bayes3d/colmap/dataset_loader.py @@ -3,24 +3,34 @@ # GRAPHDECO research group, https://team.inria.fr/graphdeco # All rights reserved. # -# This software is free for non-commercial, research and evaluation use +# This software is free for non-commercial, research and evaluation use # under the terms of the LICENSE.md file. # # For inquiries contact george.drettakis@inria.fr # +import json import os import sys -from PIL import Image +from pathlib import Path from typing import NamedTuple -from .colmap_loader import read_extrinsics_text, read_intrinsics_text, qvec2rotmat, \ - read_extrinsics_binary, read_intrinsics_binary, read_points3D_binary, read_points3D_text -from .colmap_utils import getWorld2View2, focal2fov, fov2focal + import numpy as np -import json -from pathlib import Path +from PIL import Image from plyfile import PlyData, PlyElement +from .colmap_loader import ( + qvec2rotmat, + read_extrinsics_binary, + read_extrinsics_text, + read_intrinsics_binary, + read_intrinsics_text, + read_points3D_binary, + read_points3D_text, +) +from .colmap_utils import focal2fov, fov2focal, getWorld2View2 + + class CameraInfo(NamedTuple): uid: int R: np.array @@ -33,6 +43,7 @@ class CameraInfo(NamedTuple): width: int height: int + def getNerfppNorm(cam_info): def get_center_and_diag(cam_centers): cam_centers = np.hstack(cam_centers) @@ -56,12 +67,13 @@ def get_center_and_diag(cam_centers): return {"translate": translate, "radius": radius} + def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder): cam_infos = [] for idx, key in enumerate(cam_extrinsics): - sys.stdout.write('\r') + sys.stdout.write("\r") # the exact output you're looking for: - sys.stdout.write("Reading camera {}/{}".format(idx+1, len(cam_extrinsics))) + sys.stdout.write("Reading camera {}/{}".format(idx + 1, len(cam_extrinsics))) sys.stdout.flush() extr = cam_extrinsics[key] @@ -73,11 +85,11 @@ def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder): R = np.transpose(qvec2rotmat(extr.qvec)) T = np.array(extr.tvec) - if intr.model=="SIMPLE_PINHOLE": + if intr.model == "SIMPLE_PINHOLE": focal_length_x = intr.params[0] FovY = 
focal2fov(focal_length_x, height) FovX = focal2fov(focal_length_x, width) - elif intr.model=="PINHOLE": + elif intr.model == "PINHOLE": focal_length_x = intr.params[0] focal_length_y = intr.params[1] FovY = focal2fov(focal_length_y, height) @@ -89,18 +101,37 @@ def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder): image_name = os.path.basename(image_path).split(".")[0] image = Image.open(image_path) - cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, image=image, - image_path=image_path, image_name=image_name, width=width, height=height) + cam_info = CameraInfo( + uid=uid, + R=R, + T=T, + FovY=FovY, + FovX=FovX, + image=image, + image_path=image_path, + image_name=image_name, + width=width, + height=height, + ) cam_infos.append(cam_info) - sys.stdout.write('\n') + sys.stdout.write("\n") return cam_infos + def storePly(path, xyz, rgb): # Define the dtype for the structured array - dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'), - ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'), - ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')] - + dtype = [ + ("x", "f4"), + ("y", "f4"), + ("z", "f4"), + ("nx", "f4"), + ("ny", "f4"), + ("nz", "f4"), + ("red", "u1"), + ("green", "u1"), + ("blue", "u1"), + ] + normals = np.zeros_like(xyz) elements = np.empty(xyz.shape[0], dtype=dtype) @@ -108,55 +139,64 @@ def storePly(path, xyz, rgb): elements[:] = list(map(tuple, attributes)) # Create the PlyData object and write to file - vertex_element = PlyElement.describe(elements, 'vertex') + vertex_element = PlyElement.describe(elements, "vertex") ply_data = PlyData([vertex_element]) ply_data.write(path) + def readColmapSceneInfo(path, images, eval, llffhold=8): try: cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.bin") cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.bin") cam_extrinsics = read_extrinsics_binary(cameras_extrinsic_file) cam_intrinsics = read_intrinsics_binary(cameras_intrinsic_file) - except: + except Exception: cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.txt") cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.txt") cam_extrinsics = read_extrinsics_text(cameras_extrinsic_file) cam_intrinsics = read_intrinsics_text(cameras_intrinsic_file) - reading_dir = "images" if images == None else images - cam_infos_unsorted = readColmapCameras(cam_extrinsics=cam_extrinsics, cam_intrinsics=cam_intrinsics, images_folder=os.path.join(path, reading_dir)) - cam_infos = sorted(cam_infos_unsorted.copy(), key = lambda x : x.image_name) + reading_dir = "images" if images is None else images + cam_infos_unsorted = readColmapCameras( + cam_extrinsics=cam_extrinsics, + cam_intrinsics=cam_intrinsics, + images_folder=os.path.join(path, reading_dir), + ) + cam_infos = sorted(cam_infos_unsorted.copy(), key=lambda x: x.image_name) if eval: train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold != 0] - test_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold == 0] + # test_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold == 0] else: train_cam_infos = cam_infos - test_cam_infos = [] + # test_cam_infos = [] - nerf_normalization = getNerfppNorm(train_cam_infos) + # nerf_normalization = getNerfppNorm(train_cam_infos) ply_path = os.path.join(path, "sparse/0/points3D.ply") bin_path = os.path.join(path, "sparse/0/points3D.bin") txt_path = os.path.join(path, "sparse/0/points3D.txt") if not os.path.exists(ply_path): - print("Converting point3d.bin to .ply, will happen only the first time you open the 
scene.") + print( + "Converting point3d.bin to .ply, will happen only the first time you open the scene." + ) try: xyz, rgb, _ = read_points3D_binary(bin_path) - except: + except Exception: xyz, rgb, _ = read_points3D_text(txt_path) storePly(ply_path, xyz, rgb) try: path = ply_path plydata = PlyData.read(path) - vertices = plydata['vertex'] - positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T - colors = np.vstack([vertices['red'], vertices['green'], vertices['blue']]).T / 255.0 - normals = np.vstack([vertices['nx'], vertices['ny'], vertices['nz']]).T - except: + vertices = plydata["vertex"] + positions = np.vstack([vertices["x"], vertices["y"], vertices["z"]]).T + colors = ( + np.vstack([vertices["red"], vertices["green"], vertices["blue"]]).T / 255.0 + ) + normals = np.vstack([vertices["nx"], vertices["ny"], vertices["nz"]]).T + except Exception: print("FetchPly failed") - pcd = None + # pcd = None return (positions, colors, normals), train_cam_infos @@ -186,7 +226,9 @@ def readCamerasFromTransforms(path, transformsfile, white_background, extension= # get the world-to-camera transform and set R, T w2c = np.linalg.inv(c2w) - R = np.transpose(w2c[:3,:3]) # R is stored transposed due to 'glm' in CUDA code + R = np.transpose( + w2c[:3, :3] + ) # R is stored transposed due to 'glm' in CUDA code T = w2c[:3, 3] image_path = os.path.join(path, cam_name) @@ -195,17 +237,31 @@ def readCamerasFromTransforms(path, transformsfile, white_background, extension= im_data = np.array(image.convert("RGBA")) - bg = np.array([1,1,1]) if white_background else np.array([0, 0, 0]) + bg = np.array([1, 1, 1]) if white_background else np.array([0, 0, 0]) norm_data = im_data / 255.0 - arr = norm_data[:,:,:3] * norm_data[:, :, 3:4] + bg * (1 - norm_data[:, :, 3:4]) - image = Image.fromarray(np.array(arr*255.0, dtype=np.byte), "RGB") + arr = norm_data[:, :, :3] * norm_data[:, :, 3:4] + bg * ( + 1 - norm_data[:, :, 3:4] + ) + image = Image.fromarray(np.array(arr * 255.0, dtype=np.byte), "RGB") fovy = focal2fov(fov2focal(fovx, image.size[0]), image.size[1]) - FovY = fovy + FovY = fovy FovX = fovx - cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image, - image_path=image_path, image_name=image_name, width=image.size[0], height=image.size[1])) - - return cam_infos \ No newline at end of file + cam_infos.append( + CameraInfo( + uid=idx, + R=R, + T=T, + FovY=FovY, + FovX=FovX, + image=image, + image_path=image_path, + image_name=image_name, + width=image.size[0], + height=image.size[1], + ) + ) + + return cam_infos diff --git a/bayes3d/distributions.py b/bayes3d/distributions.py index 5b16b4c5..1e09cad8 100644 --- a/bayes3d/distributions.py +++ b/bayes3d/distributions.py @@ -1,45 +1,63 @@ -from tensorflow_probability.substrates import jax as tfp import jax import jax.numpy as jnp -from .transforms_3d import ( - quaternion_to_rotation_matrix, - rotation_matrix_to_quaternion -) from jax.scipy.special import logsumexp +from tensorflow_probability.substrates import jax as tfp + +from .transforms_3d import quaternion_to_rotation_matrix, rotation_matrix_to_quaternion + def vmf(key, concentration): - translation =jnp.zeros(3) + translation = jnp.zeros(3) quat = tfp.distributions.VonMisesFisher( jnp.array([1.0, 0.0, 0.0, 0.0]), concentration ).sample(seed=key) - rot_matrix = quaternion_to_rotation_matrix(quat) + rot_matrix = quaternion_to_rotation_matrix(quat) return jnp.vstack( - [jnp.hstack([rot_matrix, translation.reshape(3,1) ]), jnp.array([0.0, 0.0, 0.0, 1.0])] + [ + 
jnp.hstack([rot_matrix, translation.reshape(3, 1)]), + jnp.array([0.0, 0.0, 0.0, 1.0]), + ] ) + + vmf_jit = jax.jit(vmf) + def gaussian_vmf_zero_mean(key, var, concentration): - translation = tfp.distributions.MultivariateNormalDiag(jnp.zeros(3), jnp.ones(3) * var).sample(seed=key) + translation = tfp.distributions.MultivariateNormalDiag( + jnp.zeros(3), jnp.ones(3) * var + ).sample(seed=key) quat = tfp.distributions.VonMisesFisher( jnp.array([1.0, 0.0, 0.0, 0.0]), concentration ).sample(seed=key) - rot_matrix = quaternion_to_rotation_matrix(quat) + rot_matrix = quaternion_to_rotation_matrix(quat) return jnp.vstack( - [jnp.hstack([rot_matrix, translation.reshape(3,1) ]), jnp.array([0.0, 0.0, 0.0, 1.0])] + [ + jnp.hstack([rot_matrix, translation.reshape(3, 1)]), + jnp.array([0.0, 0.0, 0.0, 1.0]), + ] ) + def gaussian_vmf(key, pose_mean, var, concentration): return pose_mean.dot(gaussian_vmf_zero_mean(key, var, concentration)) + + gaussian_vmf_jit = jax.jit(gaussian_vmf) + def gaussian_vmf_logpdf(pose, pose_mean, var, concentration): - translation_prob = tfp.distributions.MultivariateNormalDiag(pose_mean[:3,3], jnp.ones(3) * var).log_prob(pose[:3,3]) - quat_mean = rotation_matrix_to_quaternion(pose_mean[:3,:3]) - quat = rotation_matrix_to_quaternion(pose[:3,:3]) - quat_prob = tfp.distributions.VonMisesFisher( - quat_mean, concentration - ).log_prob(quat) + translation_prob = tfp.distributions.MultivariateNormalDiag( + pose_mean[:3, 3], jnp.ones(3) * var + ).log_prob(pose[:3, 3]) + quat_mean = rotation_matrix_to_quaternion(pose_mean[:3, :3]) + quat = rotation_matrix_to_quaternion(pose[:3, :3]) + quat_prob = tfp.distributions.VonMisesFisher(quat_mean, concentration).log_prob( + quat + ) return translation_prob + quat_prob + + gaussian_vmf_logpdf_jit = jax.jit(gaussian_vmf_logpdf) @@ -47,21 +65,22 @@ def gaussian_vmf_mixture_sample(key, pose_means, log_weights, var, concentration idx = tfp.distributions.Categorical(logits=log_weights).sample(seed=key) return gaussian_vmf(key, pose_means[idx], var, concentration) + def gaussian_vmf_mixture_logpdf(key, pose, pose_means, log_weights, var, concentration): - log_probs = jax.vmap(gaussian_vmf_logpdf, in_axes=(None, 0, None, None))(pose, pose_means, var, concentration) + log_probs = jax.vmap(gaussian_vmf_logpdf, in_axes=(None, 0, None, None))( + pose, pose_means, var, concentration + ) log_mixture_probabilites = log_probs + log_weights return logsumexp(log_mixture_probabilites) - - def gaussian_vmf_logpdf(pose, pose_mean, var, concentration): - translation_prob = tfp.distributions.MultivariateNormalDiag(pose_mean[:3,3], jnp.ones(3) * var).log_prob(pose[:3,3]) - quat_mean = rotation_matrix_to_quaternion(pose_mean[:3,:3]) - quat = rotation_matrix_to_quaternion(pose[:3,:3]) - quat_prob = tfp.distributions.VonMisesFisher( - quat_mean, concentration - ).log_prob(quat) + translation_prob = tfp.distributions.MultivariateNormalDiag( + pose_mean[:3, 3], jnp.ones(3) * var + ).log_prob(pose[:3, 3]) + quat_mean = rotation_matrix_to_quaternion(pose_mean[:3, :3]) + quat = rotation_matrix_to_quaternion(pose[:3, :3]) + quat_prob = tfp.distributions.VonMisesFisher(quat_mean, concentration).log_prob( + quat + ) return translation_prob + quat_prob - - diff --git a/bayes3d/genjax/genjax_distributions.py b/bayes3d/genjax/genjax_distributions.py index 73dd7211..0b7724fc 100644 --- a/bayes3d/genjax/genjax_distributions.py +++ b/bayes3d/genjax/genjax_distributions.py @@ -1,10 +1,13 @@ -import jax from dataclasses import dataclass + +import jax +import jax.numpy as jnp from 
genjax.core.datatypes import JAXGenerativeFunction from genjax.generative_functions.distributions import ExactDensity -import jax.numpy as jnp + import bayes3d as b + @dataclass class GaussianVMFPose(ExactDensity, JAXGenerativeFunction): def sample(self, key, pose_mean, var, concentration, **kwargs): @@ -13,19 +16,25 @@ def sample(self, key, pose_mean, var, concentration, **kwargs): def logpdf(self, pose, pose_mean, var, concentration, **kwargs): return b.distributions.gaussian_vmf_logpdf(pose, pose_mean, var, concentration) + @dataclass class UniformPose(ExactDensity, JAXGenerativeFunction): def sample(self, key, low, high, **kwargs): position = jax.random.uniform(key, shape=(3,)) * (high - low) + low - orientation = b.quaternion_to_rotation_matrix(jax.random.normal(key, shape=(4,))) + orientation = b.quaternion_to_rotation_matrix( + jax.random.normal(key, shape=(4,)) + ) return b.transform_from_rot_and_pos(orientation, position) def logpdf(self, pose, low, high, **kwargs): - position = pose[:3,3] - valid = ((low <= position) & (position <= high)) - position_score = jnp.log((valid * 1.0) * (jnp.ones_like(position) / (high-low))) + position = pose[:3, 3] + valid = (low <= position) & (position <= high) + position_score = jnp.log( + (valid * 1.0) * (jnp.ones_like(position) / (high - low)) + ) return position_score.sum() + jnp.pi**2 + @dataclass class ImageLikelihood(ExactDensity, JAXGenerativeFunction): def sample(self, key, img, variance, outlier_prob): @@ -33,25 +42,30 @@ def sample(self, key, img, variance, outlier_prob): def logpdf(self, observed_image, latent_image, variance, outlier_prob): return b.threedp3_likelihood( - observed_image, latent_image, variance, outlier_prob, + observed_image, + latent_image, + variance, + outlier_prob, ) + @dataclass class ContactParamsUniform(ExactDensity, JAXGenerativeFunction): def sample(self, key, low, high): return jax.random.uniform(key, shape=(3,)) * (high - low) + low def logpdf(self, sampled_val, low, high, **kwargs): - valid = ((low <= sampled_val) & (sampled_val <= high)) - log_probs = jnp.log((valid * 1.0) * (jnp.ones_like(sampled_val) / (high-low))) + valid = (low <= sampled_val) & (sampled_val <= high) + log_probs = jnp.log((valid * 1.0) * (jnp.ones_like(sampled_val) / (high - low))) return log_probs.sum() + @dataclass class UniformDiscreteArray(ExactDensity, JAXGenerativeFunction): def sample(self, key, vals, arr): return jax.random.choice(key, vals, shape=arr.shape) - def logpdf(self, sampled_val, vals, arr,**kwargs): + def logpdf(self, sampled_val, vals, arr, **kwargs): return jnp.log(1.0 / (vals.shape[0])) * arr.shape[0] @@ -60,9 +74,10 @@ class UniformDiscrete(ExactDensity, JAXGenerativeFunction): def sample(self, key, vals): return jax.random.choice(key, vals) - def logpdf(self, sampled_val, vals,**kwargs): + def logpdf(self, sampled_val, vals, **kwargs): return jnp.log(1.0 / (vals.shape[0])) + gaussian_vmf_pose = GaussianVMFPose() image_likelihood = ImageLikelihood() contact_params_uniform = ContactParamsUniform() diff --git a/bayes3d/genjax/model.py b/bayes3d/genjax/model.py index 5f31bdfb..7257f1b8 100644 --- a/bayes3d/genjax/model.py +++ b/bayes3d/genjax/model.py @@ -1,80 +1,125 @@ -from genjax.generative_functions.distributions import ExactDensity -import bayes3d as b -from dataclasses import dataclass -import jax -import jax.numpy as jnp +import inspect +from collections import namedtuple + import genjax import jax -import os -import jax.tree_util as jtu -from tqdm import tqdm -from genjax.incremental import UnknownChange, 
NoChange, Diff -from collections import namedtuple -import inspect +import jax.numpy as jnp +from genjax.incremental import Diff, NoChange, UnknownChange + +import bayes3d as b + from .genjax_distributions import * + @genjax.static def model(array, possible_object_indices, pose_bounds, contact_bounds, all_box_dims): indices = jnp.array([], dtype=jnp.int32) - root_poses = jnp.zeros((0,4,4)) - contact_params = jnp.zeros((0,3)) + root_poses = jnp.zeros((0, 4, 4)) + contact_params = jnp.zeros((0, 3)) faces_parents = jnp.array([], dtype=jnp.int32) faces_child = jnp.array([], dtype=jnp.int32) parents = jnp.array([], dtype=jnp.int32) for i in range(array.shape[0]): - parent_obj = uniform_discrete(jnp.arange(-1,array.shape[0] - 1)) @ f"parent_{i}" - parent_face = uniform_discrete(jnp.arange(0,6)) @ f"face_parent_{i}" - child_face = uniform_discrete(jnp.arange(0,6)) @ f"face_child_{i}" + parent_obj = ( + uniform_discrete(jnp.arange(-1, array.shape[0] - 1)) @ f"parent_{i}" + ) + parent_face = uniform_discrete(jnp.arange(0, 6)) @ f"face_parent_{i}" + child_face = uniform_discrete(jnp.arange(0, 6)) @ f"face_child_{i}" index = uniform_discrete(possible_object_indices) @ f"id_{i}" - pose = uniform_pose( - pose_bounds[0], - pose_bounds[1], - ) @ f"root_pose_{i}" - - params = contact_params_uniform( - contact_bounds[0], - contact_bounds[1] - ) @ f"contact_params_{i}" + pose = ( + uniform_pose( + pose_bounds[0], + pose_bounds[1], + ) + @ f"root_pose_{i}" + ) + params = ( + contact_params_uniform(contact_bounds[0], contact_bounds[1]) + @ f"contact_params_{i}" + ) indices = jnp.concatenate([indices, jnp.array([index])]) - root_poses = jnp.concatenate([root_poses, pose.reshape(1,4,4)]) - contact_params = jnp.concatenate([contact_params, params.reshape(1,-1)]) + root_poses = jnp.concatenate([root_poses, pose.reshape(1, 4, 4)]) + contact_params = jnp.concatenate([contact_params, params.reshape(1, -1)]) parents = jnp.concatenate([parents, jnp.array([parent_obj])]) faces_parents = jnp.concatenate([faces_parents, jnp.array([parent_face])]) faces_child = jnp.concatenate([faces_child, jnp.array([child_face])]) - + box_dims = all_box_dims[indices] poses = b.scene_graph.poses_from_scene_graph( - root_poses, box_dims, parents, contact_params, faces_parents, faces_child) + root_poses, box_dims, parents, contact_params, faces_parents, faces_child + ) - camera_pose = uniform_pose( - pose_bounds[0], - pose_bounds[1], - ) @ f"camera_pose" + camera_pose = ( + uniform_pose( + pose_bounds[0], + pose_bounds[1], + ) + @ "camera_pose" + ) - rendered = b.RENDERER.render( - jnp.linalg.inv(camera_pose) @ poses , indices - )[...,:3] + rendered = b.RENDERER.render(jnp.linalg.inv(camera_pose) @ poses, indices)[..., :3] variance = genjax.uniform(0.00000000001, 10000.0) @ "variance" - outlier_prob = genjax.uniform(-0.01, 10000.0) @ "outlier_prob" + outlier_prob = genjax.uniform(-0.01, 10000.0) @ "outlier_prob" image = image_likelihood(rendered, variance, outlier_prob) @ "image" - return rendered, indices, poses, parents, contact_params, faces_parents, faces_child, root_poses + return ( + rendered, + indices, + poses, + parents, + contact_params, + faces_parents, + faces_child, + root_poses, + ) + + +def get_rendered_image(trace): + return trace.get_retval()[0] + + +def get_indices(trace): + return trace.get_retval()[1] + + +def get_poses(trace): + return trace.get_retval()[2] + + +def get_parents(trace): + return trace.get_retval()[3] + + +def get_contact_params(trace): + return trace.get_retval()[4] + -get_rendered_image = lambda trace: 
trace.get_retval()[0] -get_indices = lambda trace: trace.get_retval()[1] -get_poses = lambda trace: trace.get_retval()[2] -get_parents = lambda trace: trace.get_retval()[3] -get_contact_params = lambda trace: trace.get_retval()[4] -get_faces_parents = lambda trace: trace.get_retval()[5] -get_faces_child = lambda trace: trace.get_retval()[6] -get_root_poses = lambda trace: trace.get_retval()[7] +def get_faces_parents(trace): + return trace.get_retval()[5] + + +def get_faces_child(trace): + return trace.get_retval()[6] + + +def get_root_poses(trace): + return trace.get_retval()[7] + + +def get_outlier_volume(trace): + return trace.get_args()[5] + + +def get_focal_length(trace): + return trace.get_args()[6] + + +def get_far_plane(trace): + return trace.get_args()[7] -get_outlier_volume = lambda trace: trace.get_args()[5] -get_focal_length = lambda trace: trace.get_args()[6] -get_far_plane = lambda trace: trace.get_args()[7] def add_object(trace, key, obj_id, parent, face_parent, face_child): N = b.get_indices(trace).shape[0] + 1 @@ -84,59 +129,90 @@ def add_object(trace, key, obj_id, parent, face_parent, face_child): choices[f"face_parent_{N-1}"] = face_parent choices[f"face_child_{N-1}"] = face_child choices[f"contact_params_{N-1}"] = jnp.zeros(3) - return model.importance(key, choices, - (jnp.arange(N), *trace.get_args()[1:]) - )[1] + return model.importance(key, choices, (jnp.arange(N), *trace.get_args()[1:]))[1] + add_object_jit = jax.jit(add_object) + def print_trace(trace): - print(""" + print( + """ SCORE: {:0.7f} VARIANCE: {:0.7f} OUTLIER_PROB {:0.7f} - """.format(trace.get_score(), trace["variance"], trace["outlier_prob"])) + """.format(trace.get_score(), trace["variance"], trace["outlier_prob"]) + ) + def viz_trace_meshcat(trace, colors=None): b.clear() - b.show_cloud("1", b.apply_transform_jit(trace["image"].reshape(-1,3), trace["camera_pose"])) - b.show_cloud("2", b.apply_transform_jit(get_rendered_image(trace).reshape(-1,3), trace["camera_pose"]),color=b.RED) + b.show_cloud( + "1", b.apply_transform_jit(trace["image"].reshape(-1, 3), trace["camera_pose"]) + ) + b.show_cloud( + "2", + b.apply_transform_jit( + get_rendered_image(trace).reshape(-1, 3), trace["camera_pose"] + ), + color=b.RED, + ) indices = trace.get_retval()[1] if colors is None: colors = b.viz.distinct_colors(max(10, len(indices))) for i in range(len(indices)): - b.show_trimesh(f"obj_{i}", b.RENDERER.meshes[indices[i]],color=colors[i]) + b.show_trimesh(f"obj_{i}", b.RENDERER.meshes[indices[i]], color=colors[i]) b.set_pose(f"obj_{i}", trace.get_retval()[2][i]) - b.show_pose(f"camera_pose", trace["camera_pose"]) - + b.show_pose("camera_pose", trace["camera_pose"]) def make_onehot(n, i, hot=1, cold=0): return tuple(cold if j != i else hot for j in range(n)) + def multivmap(f, args=None): if args is None: args = (True,) * len(inspect.signature(f).parameters) multivmapped = f - for (i, ismapped) in reversed(list(enumerate(args))): + for i, ismapped in reversed(list(enumerate(args))): if ismapped: - multivmapped = jax.vmap(multivmapped, in_axes=make_onehot(len(args), i, hot=0, cold=None)) + multivmapped = jax.vmap( + multivmapped, in_axes=make_onehot(len(args), i, hot=0, cold=None) + ) return multivmapped -Enumerator = namedtuple("Enumerator",["update_choices", "update_choices_with_weight","update_choices_get_score", "enumerate_choices", "enumerate_choices_with_weights", "enumerate_choices_get_scores"]) -def default_chm_builder(addresses, args, chm_args = None): - return genjax.choice_map({ - addr: c for (addr, c) in 
zip(addresses, args) - }) +Enumerator = namedtuple( + "Enumerator", + [ + "update_choices", + "update_choices_with_weight", + "update_choices_get_score", + "enumerate_choices", + "enumerate_choices_with_weights", + "enumerate_choices_get_scores", + ], +) + + +def default_chm_builder(addresses, args, chm_args=None): + return genjax.choice_map({addr: c for (addr, c) in zip(addresses, args)}) + def make_unknown_change_argdiffs(trace): return tuple(map(lambda v: Diff(v, UnknownChange), trace.args)) + def make_no_change_argdiffs(trace): return tuple(map(lambda v: Diff(v, NoChange), trace.args)) -def make_enumerator(addresses, chm_builder = default_chm_builder, argdiff_f = make_unknown_change_argdiffs, chm_args = None): + +def make_enumerator( + addresses, + chm_builder=default_chm_builder, + argdiff_f=make_unknown_change_argdiffs, + chm_args=None, +): def enumerator(trace, key, *args): return trace.update( key, @@ -150,20 +226,58 @@ def enumerator_with_weight(trace, key, *args): chm_builder(addresses, args, chm_args), argdiff_f(trace), )[1:3] - + def enumerator_score(trace, key, *args): return enumerator(trace, key, *args).get_score() - return Enumerator(jax.jit(enumerator), jax.jit(enumerator_with_weight), jax.jit(enumerator_score), jax.jit(multivmap(enumerator, (False, False,) + (True,) * len(addresses))), jax.jit(multivmap(enumerator_with_weight, (False, False,) + (True,) * len(addresses))), jax.jit(multivmap(enumerator_score, (False, False,) + (True,) * len(addresses)))) + return Enumerator( + jax.jit(enumerator), + jax.jit(enumerator_with_weight), + jax.jit(enumerator_score), + jax.jit( + multivmap( + enumerator, + ( + False, + False, + ) + + (True,) * len(addresses), + ) + ), + jax.jit( + multivmap( + enumerator_with_weight, + ( + False, + False, + ) + + (True,) * len(addresses), + ) + ), + jax.jit( + multivmap( + enumerator_score, + ( + False, + False, + ) + + (True,) * len(addresses), + ) + ), + ) + -def viz_trace_rendered_observed(trace, scale = 2): +def viz_trace_rendered_observed(trace, scale=2): return b.viz.hstack_images( [ - b.viz.scale_image(b.get_depth_image(get_rendered_image(trace)[...,2]), scale), - b.viz.scale_image(b.get_depth_image(trace["image"][...,2]), scale) + b.viz.scale_image( + b.get_depth_image(get_rendered_image(trace)[..., 2]), scale + ), + b.viz.scale_image(b.get_depth_image(trace["image"][..., 2]), scale), ] ) + def get_pixelwise_scores(trace, filter_size): log_scores_per_pixel = b.threedp3_likelihood_per_pixel_jit( trace["image"], @@ -172,16 +286,14 @@ def get_pixelwise_scores(trace, filter_size): trace["outlier_prob"], get_outlier_volume(trace), get_focal_length(trace), - filter_size + filter_size, ) return log_scores_per_pixel + def update_address(trace, key, address, value): return trace.update( key, - genjax.choice_map({ - address: value - }), + genjax.choice_map({address: value}), tuple(map(lambda v: Diff(v, UnknownChange), trace.args)), )[2] - diff --git a/bayes3d/likelihood.py b/bayes3d/likelihood.py index 5e6e6829..a4b17486 100644 --- a/bayes3d/likelihood.py +++ b/bayes3d/likelihood.py @@ -1,16 +1,21 @@ -import jax.numpy as jnp -import jax -import numpy as np import functools -from functools import partial -from jax.scipy.special import logsumexp + +import jax +import jax.numpy as jnp ########### @functools.partial( jnp.vectorize, - signature='(m)->()', - excluded=(1,2,3,4,5,6,), + signature="(m)->()", + excluded=( + 1, + 2, + 3, + 4, + 5, + 6, + ), ) def gausssian_mixture_vectorize_old( ij, @@ -21,16 +26,19 @@ def gausssian_mixture_vectorize_old( 
outlier_volume: float, filter_size: int, ): - distances = ( - observed_xyz[ij[0], ij[1], :3] - - jax.lax.dynamic_slice(rendered_xyz_padded, (ij[0], ij[1], 0), (2*filter_size + 1, 2*filter_size + 1, 3)) + distances = observed_xyz[ij[0], ij[1], :3] - jax.lax.dynamic_slice( + rendered_xyz_padded, + (ij[0], ij[1], 0), + (2 * filter_size + 1, 2 * filter_size + 1, 3), ) probabilities = jax.scipy.stats.norm.logpdf( - distances, - loc=0.0, - scale=jnp.sqrt(variance) + distances, loc=0.0, scale=jnp.sqrt(variance) ).sum(-1) - jnp.log(observed_xyz.shape[0] * observed_xyz.shape[1]) - return jnp.logaddexp(probabilities.max() + jnp.log(1.0 - outlier_prob), jnp.log(outlier_prob) - jnp.log(outlier_volume)) + return jnp.logaddexp( + probabilities.max() + jnp.log(1.0 - outlier_prob), + jnp.log(outlier_prob) - jnp.log(outlier_volume), + ) + def threedp3_likelihood_per_pixel_old( observed_xyz: jnp.ndarray, @@ -38,18 +46,45 @@ def threedp3_likelihood_per_pixel_old( variance, outlier_prob, outlier_volume, - filter_size + filter_size, ): - rendered_xyz_padded = jax.lax.pad(rendered_xyz, -100.0, ((filter_size,filter_size,0,),(filter_size,filter_size,0,),(0,0,0,))) - jj, ii = jnp.meshgrid(jnp.arange(observed_xyz.shape[1]), jnp.arange(observed_xyz.shape[0])) - indices = jnp.stack([ii,jj],axis=-1) + rendered_xyz_padded = jax.lax.pad( + rendered_xyz, + -100.0, + ( + ( + filter_size, + filter_size, + 0, + ), + ( + filter_size, + filter_size, + 0, + ), + ( + 0, + 0, + 0, + ), + ), + ) + jj, ii = jnp.meshgrid( + jnp.arange(observed_xyz.shape[1]), jnp.arange(observed_xyz.shape[0]) + ) + indices = jnp.stack([ii, jj], axis=-1) log_probabilities = gausssian_mixture_vectorize_old( - indices, observed_xyz, + indices, + observed_xyz, rendered_xyz_padded, - variance, outlier_prob, outlier_volume, filter_size + variance, + outlier_prob, + outlier_volume, + filter_size, ) return log_probabilities + def threedp3_likelihood_old( observed_xyz: jnp.ndarray, rendered_xyz: jnp.ndarray, @@ -57,14 +92,14 @@ def threedp3_likelihood_old( outlier_prob, outlier_volume, focal_length, - filter_size + filter_size, ): log_probabilities_per_pixel = threedp3_likelihood_per_pixel_old( - observed_xyz, rendered_xyz, variance, - outlier_prob, outlier_volume, filter_size + observed_xyz, rendered_xyz, variance, outlier_prob, outlier_volume, filter_size ) return log_probabilities_per_pixel.sum() + def threedp3_likelihood( observed_xyz: jnp.ndarray, rendered_xyz: jnp.ndarray, @@ -72,6 +107,6 @@ def threedp3_likelihood( outlier_prob, ): distances = jnp.linalg.norm(observed_xyz - rendered_xyz, axis=-1) - probabilities_per_pixel = (distances < variance/2) / variance + probabilities_per_pixel = (distances < variance / 2) / variance average_probability = probabilities_per_pixel.mean() return average_probability diff --git a/bayes3d/neural/cosypose_baseline/cosypose_utils.py b/bayes3d/neural/cosypose_baseline/cosypose_utils.py index 7c10c490..d9d1372a 100644 --- a/bayes3d/neural/cosypose_baseline/cosypose_utils.py +++ b/bayes3d/neural/cosypose_baseline/cosypose_utils.py @@ -2,34 +2,47 @@ import signal import subprocess import sys + import numpy as np -cosypose_path = f"{os.path.dirname(os.path.abspath(__file__))}/cosypose_baseline/cosypose" -sys.path.append(cosypose_path) # TODO cleaner import / add to path -import yaml -import torch +cosypose_path = ( + f"{os.path.dirname(os.path.abspath(__file__))}/cosypose_baseline/cosypose" +) +sys.path.append(cosypose_path) # TODO cleaner import / add to path + import time +import torch +import yaml + 
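The simplified `threedp3_likelihood` closing the `likelihood.py` hunk above scores each pixel with a uniform kernel of width `variance`: density `1/variance` when the observed point lies within `variance/2` of the rendered point, zero otherwise. A toy-shaped sketch of that rule (array sizes are illustrative, not from the patch):

```python
import jax.numpy as jnp

observed = jnp.zeros((4, 4, 3))            # toy observed point-cloud image
rendered = observed.at[..., 2].add(0.01)   # rendered points offset 1cm in depth
distances = jnp.linalg.norm(observed - rendered, axis=-1)
variance = 0.05
per_pixel = (distances < variance / 2) / variance  # 1/variance inside the ball, else 0
print(per_pixel.mean())                    # threedp3_likelihood returns this mean
```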
torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False torch.cuda.synchronize() -COSYPOSE_CONDA_ENV_NAME = 'cosypose' +COSYPOSE_CONDA_ENV_NAME = "cosypose" COSYPOSE_MODEL = None + class CosyPose(object): - def __init__(self, detector_run_id='detector-bop-ycbv-synt+real--292971', coarse_run_id='coarse-bop-ycbv-synt+real--822463', refiner_run_id='refiner-bop-ycbv-synt+real--631598') -> None: - self.detector, self.pose_predictor = self.get_models(detector_run_id, coarse_run_id, refiner_run_id) + def __init__( + self, + detector_run_id="detector-bop-ycbv-synt+real--292971", + coarse_run_id="coarse-bop-ycbv-synt+real--822463", + refiner_run_id="refiner-bop-ycbv-synt+real--631598", + ) -> None: + self.detector, self.pose_predictor = self.get_models( + detector_run_id, coarse_run_id, refiner_run_id + ) def load_detector(self, run_id): print("EXPDIR=", EXP_DIR) run_dir = EXP_DIR / run_id - cfg = yaml.load((run_dir / 'config.yaml').read_text(), Loader=yaml.FullLoader) + cfg = yaml.load((run_dir / "config.yaml").read_text(), Loader=yaml.FullLoader) cfg = check_update_config_detector(cfg) label_to_category_id = cfg.label_to_category_id model = create_model_detector(cfg, len(label_to_category_id)) - ckpt = torch.load(run_dir / 'checkpoint.pth.tar') - ckpt = ckpt['state_dict'] + ckpt = torch.load(run_dir / "checkpoint.pth.tar") + ckpt = ckpt["state_dict"] model.load_state_dict(ckpt) model = model.cuda().eval() model.cfg = cfg @@ -39,9 +52,9 @@ def load_detector(self, run_id): def load_pose_models(self, coarse_run_id, refiner_run_id=None, n_workers=8): run_dir = EXP_DIR / coarse_run_id - cfg = yaml.load((run_dir / 'config.yaml').read_text(), Loader=yaml.FullLoader) + cfg = yaml.load((run_dir / "config.yaml").read_text(), Loader=yaml.FullLoader) cfg = check_update_config_pose(cfg) - #object_ds = BOPObjectDataset(BOP_DS_DIR / 'tless/models_cad') + # object_ds = BOPObjectDataset(BOP_DS_DIR / 'tless/models_cad') object_ds = make_object_dataset(cfg.object_ds_name) mesh_db = MeshDataBase.from_object_ds(object_ds) renderer = BulletBatchRenderer(object_set=cfg.urdf_ds_name, n_workers=n_workers) @@ -51,14 +64,20 @@ def load_model(run_id): if run_id is None: return run_dir = EXP_DIR / run_id - cfg = yaml.load((run_dir / 'config.yaml').read_text(), Loader=yaml.FullLoader) + cfg = yaml.load( + (run_dir / "config.yaml").read_text(), Loader=yaml.FullLoader + ) cfg = check_update_config_pose(cfg) if cfg.train_refiner: - model = create_model_refiner(cfg, renderer=renderer, mesh_db=mesh_db_batched) + model = create_model_refiner( + cfg, renderer=renderer, mesh_db=mesh_db_batched + ) else: - model = create_model_coarse(cfg, renderer=renderer, mesh_db=mesh_db_batched) - ckpt = torch.load(run_dir / 'checkpoint.pth.tar') - ckpt = ckpt['state_dict'] + model = create_model_coarse( + cfg, renderer=renderer, mesh_db=mesh_db_batched + ) + ckpt = torch.load(run_dir / "checkpoint.pth.tar") + ckpt = ckpt["state_dict"] model.load_state_dict(ckpt) model = model.cuda().eval() model.cfg = cfg @@ -67,98 +86,112 @@ def load_model(run_id): coarse_model = load_model(coarse_run_id) refiner_model = load_model(refiner_run_id) - model = CoarseRefinePosePredictor(coarse_model=coarse_model, - refiner_model=refiner_model) + model = CoarseRefinePosePredictor( + coarse_model=coarse_model, refiner_model=refiner_model + ) return model, mesh_db - def get_models(self, detector_run_id, coarse_run_id, refiner_run_id): - #load models + def get_models(self, detector_run_id, coarse_run_id, refiner_run_id): + # load models detector 
= self.load_detector(detector_run_id) - pose_predictor, mesh_db = self.load_pose_models(coarse_run_id=coarse_run_id,refiner_run_id=refiner_run_id,n_workers=4) - return detector,pose_predictor + pose_predictor, mesh_db = self.load_pose_models( + coarse_run_id=coarse_run_id, refiner_run_id=refiner_run_id, n_workers=4 + ) + return detector, pose_predictor def inference(self, image, camera_k): - #[1,540,720,3]->[1,3,540,720] + # [1,540,720,3]->[1,3,540,720] images = torch.from_numpy(image).cuda().float().unsqueeze_(0) images = images.permute(0, 3, 1, 2) / 255 - #[1,3,3] + # [1,3,3] cameras_k = torch.from_numpy(camera_k).cuda().float().unsqueeze_(0) - #2D detector - #print("start detect object.") - box_detections = self.detector.get_detections(images=images, one_instance_per_class=False, - detection_th=0.8,output_masks=False, mask_th=0.9) - #pose esitimition + # 2D detector + # print("start detect object.") + box_detections = self.detector.get_detections( + images=images, + one_instance_per_class=False, + detection_th=0.8, + output_masks=False, + mask_th=0.9, + ) + # pose esitimition if len(box_detections) == 0: return None - #print("start estimate pose.") - final_preds, all_preds= self.pose_predictor.get_predictions(images, cameras_k, detections=box_detections, - n_coarse_iterations=1,n_refiner_iterations=4) - - #result: this_batch_detections, final_preds + # print("start estimate pose.") + final_preds, all_preds = self.pose_predictor.get_predictions( + images, + cameras_k, + detections=box_detections, + n_coarse_iterations=1, + n_refiner_iterations=4, + ) + + # result: this_batch_detections, final_preds return final_preds - + def cosypose_interface(rgb_imgs, camera_k): - rgb_imgs = rgb_imgs[...,:3] + rgb_imgs = rgb_imgs[..., :3] if os.path.exists("/tmp/cosypose_output.npz"): os.remove("/tmp/cosypose_output.npz") else: print("The file does not exist") - - if len(rgb_imgs.shape) == 3: + + if len(rgb_imgs.shape) == 3: # unsqueeze into (1, H, W, 3) if single-image (H,W,3) rgb_imgs = rgb_imgs[None, :] - - np.savez("/tmp/cosypose_input.npz", - rgbs=rgb_imgs, - K=camera_k) + + np.savez("/tmp/cosypose_input.npz", rgbs=rgb_imgs, K=camera_k) print("Entering COSYPOSE") py = os.popen(f"conda run -n {COSYPOSE_CONDA_ENV_NAME} which python").read().strip() print(py) - cmd = f"{py} {os.path.abspath(__file__)}" - pro = subprocess.Popen(cmd, stdout=subprocess.PIPE, - shell=True, preexec_fn=os.setsid) - + cmd = f"{py} {os.path.abspath(__file__)}" + pro = subprocess.Popen( + cmd, stdout=subprocess.PIPE, shell=True, preexec_fn=os.setsid + ) + while not os.path.exists("/tmp/cosypose_output.npz"): # print("waiting...") time.sleep(1.0) - os.killpg(os.getpgid(pro.pid), signal.SIGTERM) # Send the signal to all the process groups + os.killpg( + os.getpgid(pro.pid), signal.SIGTERM + ) # Send the signal to all the process groups print("Finished COSYPOSE") data = np.load("/tmp/cosypose_output.npz") return data -if __name__=="__main__": + +if __name__ == "__main__": print("SUBPROCESS:", sys.version) # expect Python 3.7.6 # do imports here to bypass during imports from jax3dp3 __init__ - from cosypose.datasets.datasets_cfg import make_scene_dataset, make_object_dataset + from cosypose.config import EXP_DIR + from cosypose.datasets.datasets_cfg import make_object_dataset + from cosypose.integrated.detector import Detector + from cosypose.integrated.pose_predictor import CoarseRefinePosePredictor # Pose estimator from cosypose.lib3d.rigid_mesh_database import MeshDataBase - from cosypose.training.pose_models_cfg import 
create_model_refiner, create_model_coarse - from cosypose.training.pose_models_cfg import check_update_config as check_update_config_pose from cosypose.rendering.bullet_batch_renderer import BulletBatchRenderer - from cosypose.integrated.pose_predictor import CoarseRefinePosePredictor - from cosypose.integrated.multiview_predictor import MultiviewScenePredictor - from cosypose.datasets.wrappers.multiview_wrapper import MultiViewWrapper + from cosypose.training.detector_models_cfg import ( + check_update_config as check_update_config_detector, + ) # Detection from cosypose.training.detector_models_cfg import create_model_detector - from cosypose.training.detector_models_cfg import check_update_config as check_update_config_detector - from cosypose.integrated.detector import Detector + from cosypose.training.pose_models_cfg import ( + check_update_config as check_update_config_pose, + ) + from cosypose.training.pose_models_cfg import ( + create_model_coarse, + create_model_refiner, + ) - from cosypose.evaluation.pred_runner.bop_predictions import BopPredictionRunner - - from cosypose.utils.distributed import get_tmp_dir, get_rank - from cosypose.utils.distributed import init_distributed_mode - - from cosypose.config import EXP_DIR - - os.environ["CUDA_VISIBLE_DEVICES"]= '0' + os.environ["CUDA_VISIBLE_DEVICES"] = "0" # load model print("instantiated") @@ -166,7 +199,7 @@ def cosypose_interface(rgb_imgs, camera_k): # load data data = np.load("/tmp/cosypose_input.npz") - rgb_imgs, camera_k = data['rgbs'], data['K'] + rgb_imgs, camera_k = data["rgbs"], data["K"] num_imgs = len(rgb_imgs) all_poses = [] @@ -177,17 +210,21 @@ def cosypose_interface(rgb_imgs, camera_k): print(f"{i+1}/{num_imgs} inference done") pred_poses = np.asarray(pred.poses.cpu()) - pred_ids = [int(l[-3:])-1 for l in pred.infos.label] # ex) 'obj_000014' for GT_IDX 13 + pred_ids = [ + int(l[-3:]) - 1 for l in pred.infos.label + ] # ex) 'obj_000014' for GT_IDX 13 pred_scores = [pred.infos.iloc[i].score for i in range(len(pred.infos))] - all_poses.append(pred_poses); all_ids.append(pred_ids); all_scores.append(pred_scores) + all_poses.append(pred_poses) + all_ids.append(pred_ids) + all_scores.append(pred_scores) print(pred_poses, pred_ids, pred_scores) - np.savez("/tmp/cosypose_output.npz", - pred_poses=np.asarray(all_poses), - pred_ids=np.asarray(all_ids), - pred_scores=np.asarray(all_scores)) + np.savez( + "/tmp/cosypose_output.npz", + pred_poses=np.asarray(all_poses), + pred_ids=np.asarray(all_ids), + pred_scores=np.asarray(all_scores), + ) print("saved results. 
exiting") sys.exit() - - diff --git a/bayes3d/neural/dino.py b/bayes3d/neural/dino.py index 136c7227..ab3c309b 100644 --- a/bayes3d/neural/dino.py +++ b/bayes3d/neural/dino.py @@ -1,32 +1,45 @@ import argparse -import torch -import torchvision.transforms -from torch import nn -from torchvision import transforms -import torch.nn.modules.utils as nn_utils import math -import timm import types +import warnings from pathlib import Path -from typing import Union, List, Tuple -from PIL import Image -import numpy as np +from typing import List, Tuple, Union + import jax.numpy as jnp +import numpy as np +import timm +import torch +import torch.nn.modules.utils as nn_utils +import torchvision +import torchvision.transforms as T +from PIL import Image +from torch import nn +from torchvision import transforms + +import bayes3d as b + def get_embeddings(dinov2_vitg14, rgb): - img = b.get_rgb_image(rgb).convert('RGB') + img = b.get_rgb_image(rgb).convert("RGB") patch_w, patch_h = np.array(img.size) // 14 - transform = T.Compose([ - T.GaussianBlur(9, sigma=(0.1, 2.0)), - T.Resize((patch_h * 14, patch_w * 14)), - T.CenterCrop((patch_h * 14, patch_w * 14)), - T.ToTensor(), - T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), - ]) + transform = T.Compose( + [ + T.GaussianBlur(9, sigma=(0.1, 2.0)), + T.Resize((patch_h * 14, patch_w * 14)), + T.CenterCrop((patch_h * 14, patch_w * 14)), + T.ToTensor(), + T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + ] + ) tensor = transform(img)[:3].unsqueeze(0).to(device) with torch.no_grad(): features_dict = dinov2_vitg14.forward_features(tensor) - features = features_dict['x_norm_patchtokens'][0].reshape((patch_h, patch_w, 384)).permute(2, 0, 1).unsqueeze(0) + features = ( + features_dict["x_norm_patchtokens"][0] + .reshape((patch_h, patch_w, 384)) + .permute(2, 0, 1) + .unsqueeze(0) + ) img_feat_norm = torch.nn.functional.normalize(features, dim=1) output = jnp.array(img_feat_norm.cpu().detach().numpy())[0] del img_feat_norm @@ -34,11 +47,11 @@ def get_embeddings(dinov2_vitg14, rgb): del tensor del features_dict torch.cuda.empty_cache() - return jnp.transpose(output, (1,2,0)) + return jnp.transpose(output, (1, 2, 0)) class ViTExtractor: - """ This class facilitates extraction of features, descriptors, and saliency maps from a ViT. + """This class facilitates extraction of features, descriptors, and saliency maps from a ViT. We use the following notation in the documentation of the module's methods: B - batch size @@ -49,7 +62,13 @@ class ViTExtractor: d - the embedding dimension in the ViT. """ - def __init__(self, model_type: str = 'dino_vits8', stride: int = 4, model: nn.Module = None, device: str = 'cuda'): + def __init__( + self, + model_type: str = "dino_vits8", + stride: int = 4, + model: nn.Module = None, + device: str = "cuda", + ): """ :param model_type: A string specifying the type of model to extract from. 
[dino_vits8 | dino_vits16 | dino_vitb8 | dino_vitb16 | vit_small_patch8_224 | @@ -71,8 +90,12 @@ def __init__(self, model_type: str = 'dino_vits8', stride: int = 4, model: nn.Mo self.p = self.model.patch_embed.patch_size self.stride = self.model.patch_embed.proj.stride - self.mean = (0.485, 0.456, 0.406) if "dino" in self.model_type else (0.5, 0.5, 0.5) - self.std = (0.229, 0.224, 0.225) if "dino" in self.model_type else (0.5, 0.5, 0.5) + self.mean = ( + (0.485, 0.456, 0.406) if "dino" in self.model_type else (0.5, 0.5, 0.5) + ) + self.std = ( + (0.229, 0.224, 0.225) if "dino" in self.model_type else (0.5, 0.5, 0.5) + ) self._feats = [] self.hook_handlers = [] @@ -87,20 +110,22 @@ def create_model(model_type: str) -> nn.Module: vit_base_patch16_224] :return: the model """ - if 'dino' in model_type: - model = torch.hub.load('facebookresearch/dino:main', model_type) + if "dino" in model_type: + model = torch.hub.load("facebookresearch/dino:main", model_type) else: # model from timm -- load weights from timm to dino model (enables working on arbitrary size images). temp_model = timm.create_model(model_type, pretrained=True) model_type_dict = { - 'vit_small_patch16_224': 'dino_vits16', - 'vit_small_patch8_224': 'dino_vits8', - 'vit_base_patch16_224': 'dino_vitb16', - 'vit_base_patch8_224': 'dino_vitb8' + "vit_small_patch16_224": "dino_vits16", + "vit_small_patch8_224": "dino_vits8", + "vit_base_patch16_224": "dino_vitb16", + "vit_base_patch8_224": "dino_vitb8", } - model = torch.hub.load('facebookresearch/dino:main', model_type_dict[model_type]) + model = torch.hub.load( + "facebookresearch/dino:main", model_type_dict[model_type] + ) temp_state_dict = temp_model.state_dict() - del temp_state_dict['head.weight'] - del temp_state_dict['head.bias'] + del temp_state_dict["head.weight"] + del temp_state_dict["head.bias"] model.load_state_dict(temp_state_dict) return model @@ -112,7 +137,10 @@ def _fix_pos_enc(patch_size: int, stride_hw: Tuple[int, int]): :param stride_hw: A tuple containing the new height and width stride respectively. 
:return: the interpolation method """ - def interpolate_pos_encoding(self, x: torch.Tensor, w: int, h: int) -> torch.Tensor: + + def interpolate_pos_encoding( + self, x: torch.Tensor, w: int, h: int + ) -> torch.Tensor: npatch = x.shape[1] - 1 N = self.pos_embed.shape[1] - 1 if npatch == N and w == h: @@ -123,18 +151,24 @@ def interpolate_pos_encoding(self, x: torch.Tensor, w: int, h: int) -> torch.Ten # compute number of tokens taking stride into account w0 = 1 + (w - patch_size) // stride_hw[1] h0 = 1 + (h - patch_size) // stride_hw[0] - assert (w0 * h0 == npatch), f"""got wrong grid size for {h}x{w} with patch_size {patch_size} and + assert w0 * h0 == npatch, f"""got wrong grid size for {h}x{w} with patch_size {patch_size} and stride {stride_hw} got {h0}x{w0}={h0 * w0} expecting {npatch}""" # we add a small number to avoid floating point error in the interpolation # see discussion at https://github.com/facebookresearch/dino/issues/8 w0, h0 = w0 + 0.1, h0 + 0.1 patch_pos_embed = nn.functional.interpolate( - patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), + patch_pos_embed.reshape( + 1, int(math.sqrt(N)), int(math.sqrt(N)), dim + ).permute(0, 3, 1, 2), scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), - mode='bicubic', - align_corners=False, recompute_scale_factor=False + mode="bicubic", + align_corners=False, + recompute_scale_factor=False, + ) + assert ( + int(w0) == patch_pos_embed.shape[-2] + and int(h0) == patch_pos_embed.shape[-1] ) - assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) @@ -153,17 +187,23 @@ def patch_vit_resolution(model: nn.Module, stride: int) -> nn.Module: return model stride = nn_utils._pair(stride) - assert all([(patch_size // s_) * s_ == patch_size for s_ in - stride]), f'stride {stride} should divide patch_size {patch_size}' + assert all( + [(patch_size // s_) * s_ == patch_size for s_ in stride] + ), f"stride {stride} should divide patch_size {patch_size}" # fix the stride model.patch_embed.proj.stride = stride # fix the positional encoding code - model.interpolate_pos_encoding = types.MethodType(ViTExtractor._fix_pos_enc(patch_size, stride), model) + model.interpolate_pos_encoding = types.MethodType( + ViTExtractor._fix_pos_enc(patch_size, stride), model + ) return model - def preprocess(self, image_path: Union[str, Path], - load_size: Union[int, Tuple[int, int]] = None) -> Tuple[torch.Tensor, Image.Image]: + def preprocess( + self, + image_path: Union[str, Path], + load_size: Union[int, Tuple[int, int]] = None, + ) -> Tuple[torch.Tensor, Image.Image]: """ Preprocesses an image before extraction. :param image_path: path to image to be extracted. @@ -172,13 +212,14 @@ def preprocess(self, image_path: Union[str, Path], (1) the preprocessed image as a tensor to insert the model of shape BxCxHxW. 
(2) the pil image in relevant dimensions """ - pil_image = Image.open(image_path).convert('RGB') + pil_image = Image.open(image_path).convert("RGB") if load_size is not None: - pil_image = transforms.Resize(load_size, interpolation=transforms.InterpolationMode.LANCZOS)(pil_image) - prep = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize(mean=self.mean, std=self.std) - ]) + pil_image = transforms.Resize( + load_size, interpolation=transforms.InterpolationMode.LANCZOS + )(pil_image) + prep = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize(mean=self.mean, std=self.std)] + ) prep_img = prep(pil_image)[None, ...] return prep_img, pil_image @@ -186,16 +227,18 @@ def _get_hook(self, facet: str): """ generate a hook method for a specific block and facet. """ - if facet in ['attn', 'token']: + if facet in ["attn", "token"]: + def _hook(model, input, output): self._feats.append(output) + return _hook - if facet == 'query': + if facet == "query": facet_idx = 0 - elif facet == 'key': + elif facet == "key": facet_idx = 1 - elif facet == 'value': + elif facet == "value": facet_idx = 2 else: raise TypeError(f"{facet} is not a supported facet.") @@ -203,8 +246,13 @@ def _hook(model, input, output): def _inner_hook(module, input, output): input = input[0] B, N, C = input.shape - qkv = module.qkv(input).reshape(B, N, 3, module.num_heads, C // module.num_heads).permute(2, 0, 3, 1, 4) - self._feats.append(qkv[facet_idx]) #Bxhxtxd + qkv = ( + module.qkv(input) + .reshape(B, N, 3, module.num_heads, C // module.num_heads) + .permute(2, 0, 3, 1, 4) + ) + self._feats.append(qkv[facet_idx]) # Bxhxtxd + return _inner_hook def _register_hooks(self, layers: List[int], facet: str) -> None: @@ -215,12 +263,20 @@ def _register_hooks(self, layers: List[int], facet: str) -> None: """ for block_idx, block in enumerate(self.model.blocks): if block_idx in layers: - if facet == 'token': - self.hook_handlers.append(block.register_forward_hook(self._get_hook(facet))) - elif facet == 'attn': - self.hook_handlers.append(block.attn.attn_drop.register_forward_hook(self._get_hook(facet))) - elif facet in ['key', 'query', 'value']: - self.hook_handlers.append(block.attn.register_forward_hook(self._get_hook(facet))) + if facet == "token": + self.hook_handlers.append( + block.register_forward_hook(self._get_hook(facet)) + ) + elif facet == "attn": + self.hook_handlers.append( + block.attn.attn_drop.register_forward_hook( + self._get_hook(facet) + ) + ) + elif facet in ["key", "query", "value"]: + self.hook_handlers.append( + block.attn.register_forward_hook(self._get_hook(facet)) + ) else: raise TypeError(f"{facet} is not a supported facet.") @@ -232,7 +288,9 @@ def _unregister_hooks(self) -> None: handle.remove() self.hook_handlers = [] - def _extract_features(self, batch: torch.Tensor, layers: List[int] = 11, facet: str = 'key') -> List[torch.Tensor]: + def _extract_features( + self, batch: torch.Tensor, layers: List[int] = 11, facet: str = "key" + ) -> List[torch.Tensor]: """ extract features from the model :param batch: batch to extract features for. Has shape BxCxHxW. 
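The hook machinery above (`_get_hook`, `_register_hooks`, `_unregister_hooks`) is built on PyTorch forward hooks. A minimal sketch of that mechanism, with a toy module standing in for a ViT block:

```python
import torch
from torch import nn

feats = []
layer = nn.Linear(8, 8)  # toy stand-in for a ViT block
handle = layer.register_forward_hook(lambda mod, inp, out: feats.append(out))
_ = layer(torch.randn(1, 8))  # the hook fires during this forward pass
handle.remove()               # unregister afterwards, as _unregister_hooks does
print(len(feats), feats[0].shape)  # 1 torch.Size([1, 8])
```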
@@ -249,7 +307,10 @@ def _extract_features(self, batch: torch.Tensor, layers: List[int] = 11, facet: _ = self.model(batch) self._unregister_hooks() self.load_size = (H, W) - self.num_patches = (1 + (H - self.p) // self.stride[0], 1 + (W - self.p) // self.stride[1]) + self.num_patches = ( + 1 + (H - self.p) // self.stride[0], + 1 + (W - self.p) // self.stride[1], + ) return self._feats def _log_bin(self, x: torch.Tensor, hierarchy: int = 2) -> torch.Tensor: @@ -263,7 +324,9 @@ def _log_bin(self, x: torch.Tensor, hierarchy: int = 2) -> torch.Tensor: bin_x = x.permute(0, 2, 3, 1).flatten(start_dim=-2, end_dim=-1) # Bx(t-1)x(dxh) bin_x = bin_x.permute(0, 2, 1) - bin_x = bin_x.reshape(B, bin_x.shape[1], self.num_patches[0], self.num_patches[1]) + bin_x = bin_x.reshape( + B, bin_x.shape[1], self.num_patches[0], self.num_patches[1] + ) # Bx(dxh)xnum_patches[0]xnum_patches[1] sub_desc_dim = bin_x.shape[1] @@ -271,37 +334,63 @@ def _log_bin(self, x: torch.Tensor, hierarchy: int = 2) -> torch.Tensor: # compute bins of all sizes for all spatial locations. for k in range(0, hierarchy): # avg pooling with kernel 3**kx3**k - win_size = 3 ** k - avg_pool = torch.nn.AvgPool2d(win_size, stride=1, padding=win_size // 2, count_include_pad=False) + win_size = 3**k + avg_pool = torch.nn.AvgPool2d( + win_size, stride=1, padding=win_size // 2, count_include_pad=False + ) avg_pools.append(avg_pool(bin_x)) - bin_x = torch.zeros((B, sub_desc_dim * num_bins, self.num_patches[0], self.num_patches[1])).to(self.device) + bin_x = torch.zeros( + (B, sub_desc_dim * num_bins, self.num_patches[0], self.num_patches[1]) + ).to(self.device) for y in range(self.num_patches[0]): for x in range(self.num_patches[1]): part_idx = 0 # fill all bins for a spatial location (y, x) for k in range(0, hierarchy): - kernel_size = 3 ** k + kernel_size = 3**k for i in range(y - kernel_size, y + kernel_size + 1, kernel_size): - for j in range(x - kernel_size, x + kernel_size + 1, kernel_size): + for j in range( + x - kernel_size, x + kernel_size + 1, kernel_size + ): if i == y and j == x and k != 0: continue - if 0 <= i < self.num_patches[0] and 0 <= j < self.num_patches[1]: - bin_x[:, part_idx * sub_desc_dim: (part_idx + 1) * sub_desc_dim, y, x] = avg_pools[k][ - :, :, i, j] + if ( + 0 <= i < self.num_patches[0] + and 0 <= j < self.num_patches[1] + ): + bin_x[ + :, + part_idx * sub_desc_dim : (part_idx + 1) + * sub_desc_dim, + y, + x, + ] = avg_pools[k][:, :, i, j] else: # handle padding in a more delicate way than zero padding temp_i = max(0, min(i, self.num_patches[0] - 1)) temp_j = max(0, min(j, self.num_patches[1] - 1)) - bin_x[:, part_idx * sub_desc_dim: (part_idx + 1) * sub_desc_dim, y, x] = avg_pools[k][ - :, :, temp_i, - temp_j] + bin_x[ + :, + part_idx * sub_desc_dim : (part_idx + 1) + * sub_desc_dim, + y, + x, + ] = avg_pools[k][:, :, temp_i, temp_j] part_idx += 1 - bin_x = bin_x.flatten(start_dim=-2, end_dim=-1).permute(0, 2, 1).unsqueeze(dim=1) + bin_x = ( + bin_x.flatten(start_dim=-2, end_dim=-1).permute(0, 2, 1).unsqueeze(dim=1) + ) # Bx1x(t-1)x(dxh) return bin_x - def extract_descriptors(self, batch: torch.Tensor, layer: int = 11, facet: str = 'key', - bin: bool = False, include_cls: bool = False) -> torch.Tensor: + def extract_descriptors( + self, + batch: torch.Tensor, + layer: int = 11, + facet: str = "key", + bin: bool = False, + include_cls: bool = False, + ) -> torch.Tensor: """ extract descriptors from the model :param batch: batch to extract descriptors for. Has shape BxCxHxW. 
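`_log_bin` above builds descriptor bins from stride-1 average pools at window sizes `3**k`, which keeps every scale aligned on the same patch grid. A toy sketch of that pooling step (sizes are hypothetical):

```python
import torch

x = torch.randn(1, 6, 10, 10)  # B x d x H x W toy feature map
pools = []
for k in range(2):             # hierarchy = 2, as in the default above
    win = 3**k
    pool = torch.nn.AvgPool2d(win, stride=1, padding=win // 2, count_include_pad=False)
    pools.append(pool(x))      # spatial size is preserved at every scale
print([tuple(p.shape) for p in pools])  # [(1, 6, 10, 10), (1, 6, 10, 10)]
```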
@@ -310,18 +399,20 @@ def extract_descriptors(self, batch: torch.Tensor, layer: int = 11, facet: str = :param bin: apply log binning to the descriptor. default is False. :return: tensor of descriptors. Bx1xtxd' where d' is the dimension of the descriptors. """ - assert facet in ['key', 'query', 'value', 'token'], f"""{facet} is not a supported facet for descriptors. + assert facet in ["key", "query", "value", "token"], f"""{facet} is not a supported facet for descriptors. choose from ['key' | 'query' | 'value' | 'token'] """ self._extract_features(batch, [layer], facet) x = self._feats[0] - if facet == 'token': - x.unsqueeze_(dim=1) #Bx1xtxd + if facet == "token": + x.unsqueeze_(dim=1) # Bx1xtxd if not include_cls: x = x[:, :, 1:, :] # remove cls token else: assert not bin, "bin = True and include_cls = True are not supported together, set one of them False." if not bin: - desc = x.permute(0, 2, 3, 1).flatten(start_dim=-2, end_dim=-1).unsqueeze(dim=1) # Bx1xtx(dxh) + desc = ( + x.permute(0, 2, 3, 1).flatten(start_dim=-2, end_dim=-1).unsqueeze(dim=1) + ) # Bx1xtx(dxh) else: desc = self._log_bin(x) return desc @@ -333,50 +424,94 @@ def extract_saliency_maps(self, batch: torch.Tensor) -> torch.Tensor: :param batch: batch to extract saliency maps for. Has shape BxCxHxW. :return: a tensor of saliency maps. has shape Bxt-1 """ - assert self.model_type == "dino_vits8", f"saliency maps are supported only for dino_vits model_type." - self._extract_features(batch, [11], 'attn') + assert ( + self.model_type == "dino_vits8" + ), "saliency maps are supported only for dino_vits model_type." + self._extract_features(batch, [11], "attn") head_idxs = [0, 2, 4, 5] - curr_feats = self._feats[0] #Bxhxtxt - cls_attn_map = curr_feats[:, head_idxs, 0, 1:].mean(dim=1) #Bx(t-1) + curr_feats = self._feats[0] # Bxhxtxt + cls_attn_map = curr_feats[:, head_idxs, 0, 1:].mean(dim=1) # Bx(t-1) temp_mins, temp_maxs = cls_attn_map.min(dim=1)[0], cls_attn_map.max(dim=1)[0] - cls_attn_maps = (cls_attn_map - temp_mins) / (temp_maxs - temp_mins) # normalize to range [0,1] + cls_attn_maps = (cls_attn_map - temp_mins) / ( + temp_maxs - temp_mins + ) # normalize to range [0,1] return cls_attn_maps + """ taken from https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse""" + + def str2bool(v): if isinstance(v, bool): return v - if v.lower() in ('yes', 'true', 't', 'y', '1'): + if v.lower() in ("yes", "true", "t", "y", "1"): return True - elif v.lower() in ('no', 'false', 'f', 'n', '0'): + elif v.lower() in ("no", "false", "f", "n", "0"): return False else: - raise argparse.ArgumentTypeError('Boolean value expected.') + raise argparse.ArgumentTypeError("Boolean value expected.") + if __name__ == "__main__": - parser = argparse.ArgumentParser(description='Facilitate ViT Descriptor extraction.') - parser.add_argument('--image_path', type=str, required=True, help='path of the extracted image.') - parser.add_argument('--output_path', type=str, required=True, help='path to file containing extracted descriptors.') - parser.add_argument('--load_size', default=224, type=int, help='load size of the input image.') - parser.add_argument('--stride', default=4, type=int, help="""stride of first convolution layer. - small stride -> higher resolution.""") - parser.add_argument('--model_type', default='dino_vits8', type=str, - help="""type of model to extract. 
- Choose from [dino_vits8 | dino_vits16 | dino_vitb8 | dino_vitb16 | vit_small_patch8_224 | - vit_small_patch16_224 | vit_base_patch8_224 | vit_base_patch16_224]""") - parser.add_argument('--facet', default='key', type=str, help="""facet to create descriptors from. - options: ['key' | 'query' | 'value' | 'token']""") - parser.add_argument('--layer', default=11, type=int, help="layer to create descriptors from.") - parser.add_argument('--bin', default='False', type=str2bool, help="create a binned descriptor if True.") + parser = argparse.ArgumentParser( + description="Facilitate ViT Descriptor extraction." + ) + parser.add_argument( + "--image_path", type=str, required=True, help="path of the extracted image." + ) + parser.add_argument( + "--output_path", + type=str, + required=True, + help="path to file containing extracted descriptors.", + ) + parser.add_argument( + "--load_size", default=224, type=int, help="load size of the input image." + ) + parser.add_argument( + "--stride", + default=4, + type=int, + help="""stride of first convolution layer. + small stride -> higher resolution.""", + ) + parser.add_argument( + "--model_type", + default="dino_vits8", + type=str, + help="""type of model to extract. + Choose from [dino_vits8 | dino_vits16 | dino_vitb8 | dino_vitb16 | vit_small_patch8_224 | + vit_small_patch16_224 | vit_base_patch8_224 | vit_base_patch16_224]""", + ) + parser.add_argument( + "--facet", + default="key", + type=str, + help="""facet to create descriptors from. + options: ['key' | 'query' | 'value' | 'token']""", + ) + parser.add_argument( + "--layer", default=11, type=int, help="layer to create descriptors from." + ) + parser.add_argument( + "--bin", + default="False", + type=str2bool, + help="create a binned descriptor if True.", + ) args = parser.parse_args() with torch.no_grad(): - device = 'cuda' if torch.cuda.is_available() else 'cpu' + device = "cuda" if torch.cuda.is_available() else "cpu" extractor = ViTExtractor(args.model_type, args.stride, device=device) image_batch, image_pil = extractor.preprocess(args.image_path, args.load_size) - print(f"Image {args.image_path} is preprocessed to tensor of size {image_batch.shape}.") - descriptors = extractor.extract_descriptors(image_batch.to(device), args.layer, args.facet, args.bin) + print( + f"Image {args.image_path} is preprocessed to tensor of size {image_batch.shape}." 
+ ) + descriptors = extractor.extract_descriptors( + image_batch.to(device), args.layer, args.facet, args.bin + ) print(f"Descriptors are of size: {descriptors.shape}") torch.save(descriptors, args.output_path) print(f"Descriptors saved to: {args.output_path}") @@ -398,12 +533,11 @@ def __init__( upsample=False, **kwargs, ): - super().__init__() self.extractor = ViTExtractor(model_type, stride, device=device) self.load_size = load_size self.input_image_transform = self.get_input_image_transform() - if upsample == True: + if upsample is True: if "desired_height" in kwargs.keys(): self.desired_height = kwargs["desired_height"] if "desired_width" in kwargs.keys(): @@ -460,10 +594,11 @@ def forward(self, img, apply_default_input_transform=True): feat = upsample_feat_vec(feat, [self.desired_height, self.desired_width]) return feat + class Dino(object): def __init__(self, height, width): self.model = VITFeatureExtractor( - upsample= True, + upsample=True, desired_height=height, desired_width=width, ) @@ -479,5 +614,3 @@ def get_embeddings(self, img): img_feat = self.model.forward(img, apply_default_input_transform=False) img_feat_norm = torch.nn.functional.normalize(img_feat, dim=1) return img_feat_norm.cpu().detach().numpy() - - diff --git a/bayes3d/neural/segmentation.py b/bayes3d/neural/segmentation.py index 85a31d01..3fed2685 100644 --- a/bayes3d/neural/segmentation.py +++ b/bayes3d/neural/segmentation.py @@ -1,24 +1,28 @@ -import bayes3d as b import jax.numpy as jnp +import bayes3d as b + HIINTERFACE = None + + def carvekit_get_foreground_mask(image: b.RGBD): global HIINTERFACE if HIINTERFACE is None: - from carvekit.api.high import HiInterface import torch + from carvekit.api.high import HiInterface + HIINTERFACE = HiInterface( object_type="object", # Can be "object" or "hairs-like". 
batch_size_seg=5, batch_size_matting=1, - device='cuda' if torch.cuda.is_available() else 'cpu', + device="cuda" if torch.cuda.is_available() else "cpu", seg_mask_size=640, # Use 640 for Tracer B7 and 320 for U2Net matting_mask_size=2048, - trimap_prob_threshold=220,#231, + trimap_prob_threshold=220, # 231, trimap_dilation=15, trimap_erosion_iters=20, - fp16=False + fp16=False, ) imgs = HIINTERFACE([b.get_rgb_image(image.rgb)]) - mask = jnp.array(imgs[0])[...,-1] > 0.5 + mask = jnp.array(imgs[0])[..., -1] > 0.5 return mask diff --git a/bayes3d/renderer.py b/bayes3d/renderer.py index 1fa6e658..49837703 100644 --- a/bayes3d/renderer.py +++ b/bayes3d/renderer.py @@ -1,46 +1,47 @@ -from typing import Tuple -import gc import functools -import bayes3d.rendering.nvdiffrast.common as dr -import bayes3d.camera -import bayes3d as j -import bayes3d as b -import bayes3d.transforms_3d as t3d -import trimesh -import jax.numpy as jnp +import gc + import jax +import jax.numpy as jnp +import numpy as np +import trimesh from jax import core, dtypes from jax.core import ShapedArray from jax.interpreters import batching, mlir, xla from jax.lib import xla_client -import numpy as np from jaxlib.hlo_helpers import custom_call -from tqdm import tqdm + +import bayes3d as b +import bayes3d as j +import bayes3d.camera +import bayes3d.rendering.nvdiffrast.common as dr + def _transform_image_zeros(image_jnp, intrinsics): image_jnp_2 = jnp.concatenate( - [ - j.t3d.unproject_depth(image_jnp[:,:,2], intrinsics), - image_jnp[:,:,3:] - ], - axis=-1 + [j.t3d.unproject_depth(image_jnp[:, :, 2], intrinsics), image_jnp[:, :, 3:]], + axis=-1, ) return image_jnp_2 + + _transform_image_zeros_jit = jax.jit(_transform_image_zeros) -_transform_image_zeros_parallel = jax.vmap(_transform_image_zeros, in_axes=(0,None)) +_transform_image_zeros_parallel = jax.vmap(_transform_image_zeros, in_axes=(0, None)) _transform_image_zeros_parallel_jit = jax.jit(_transform_image_zeros_parallel) + def setup_renderer(intrinsics, num_layers=1024): """Setup the renderer. Args: - intrinsics (bayes3d.camera.Intrinsics): The camera intrinsics. + intrinsics (bayes3d.camera.Intrinsics): The camera intrinsics. """ b.RENDERER = Renderer(intrinsics, num_layers=num_layers) + class Renderer(object): def __init__(self, intrinsics, num_layers=1024): """A renderer for rendering meshes. - + Args: intrinsics (bayes3d.camera.Intrinsics): The camera intrinsics. num_layers (int, optional): The number of scenes to render in parallel. Defaults to 1024. 
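For orientation, a hedged usage sketch of the `setup_renderer`/`Renderer` API above; the `Intrinsics` keyword names and the mesh path are assumptions, not taken from this patch:

```python
import jax.numpy as jnp
import bayes3d as b

intrinsics = b.Intrinsics(
    height=100, width=100, fx=100.0, fy=100.0,
    cx=50.0, cy=50.0, near=0.01, far=10.0,
)
b.setup_renderer(intrinsics)               # binds the global b.RENDERER
b.RENDERER.add_mesh_from_file("cube.obj")  # hypothetical mesh file
pose = jnp.eye(4).at[2, 3].set(2.0)        # object two meters ahead of the camera
depth = b.RENDERER.render(pose[None, ...], jnp.array([0]))  # (height, width, 4)
```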
@@ -50,18 +51,24 @@ def __init__(self, intrinsics, num_layers=1024): self.intrinsics = intrinsics self.proj_matrix = b.camera._open_gl_projection_matrix( - intrinsics.height, intrinsics.width, - intrinsics.fx, intrinsics.fy, - intrinsics.cx, intrinsics.cy, - intrinsics.near, intrinsics.far + intrinsics.height, + intrinsics.width, + intrinsics.fx, + intrinsics.fy, + intrinsics.cx, + intrinsics.cy, + intrinsics.near, + intrinsics.far, + ) + + self.renderer_env = dr.RasterizeGLContext( + self.height, self.width, output_db=False ) - - self.renderer_env = dr.RasterizeGLContext(self.height, self.width, output_db=False) build_setup_primitive(self, self.height, self.width, num_layers).bind() - self.meshes =[] - self.mesh_names =[] - self.model_box_dims = jnp.zeros((0,3)) + self.meshes = [] + self.mesh_names = [] + self.model_box_dims = jnp.zeros((0, 3)) def clear_gpu_meshmem(self): """ @@ -77,9 +84,16 @@ def clear_gpu_meshmem(self): # Force the garbage collector to run to reclaim memory gc.collect() - def add_mesh_from_file(self, mesh_filename, mesh_name=None, scaling_factor=1.0, force=None, center_mesh=True): + def add_mesh_from_file( + self, + mesh_filename, + mesh_name=None, + scaling_factor=1.0, + force=None, + center_mesh=True, + ): """Add a mesh to the renderer from a file. - + Args: mesh_filename (str): The filename of the mesh. mesh_name (str, optional): The name of the mesh. Defaults to None. @@ -88,11 +102,16 @@ def add_mesh_from_file(self, mesh_filename, mesh_name=None, scaling_factor=1.0, center_mesh (bool, optional): Whether to center the mesh. Defaults to True. """ mesh = trimesh.load(mesh_filename, force=force) - self.add_mesh(mesh, mesh_name=mesh_name, scaling_factor=scaling_factor, center_mesh=center_mesh) + self.add_mesh( + mesh, + mesh_name=mesh_name, + scaling_factor=scaling_factor, + center_mesh=center_mesh, + ) def add_mesh(self, mesh, mesh_name=None, scaling_factor=1.0, center_mesh=True): """Add a mesh to the renderer. - + Args: mesh (trimesh.Trimesh): The mesh to add. mesh_name (str, optional): The name of the mesh. Defaults to None. 
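The `add_mesh` hunk continuing below centers meshes by subtracting the translation of their axis-aligned bounding box. A plain-numpy stand-in for that step, with the `bayes3d.utils.aabb` semantics assumed:

```python
import numpy as np

vertices = np.random.rand(100, 3) + 5.0   # toy mesh, far from the origin
center = (vertices.min(0) + vertices.max(0)) / 2.0  # AABB midpoint
vertices_centered = vertices - center               # what center_mesh=True does
midpoint = (vertices_centered.min(0) + vertices_centered.max(0)) / 2.0
print(np.allclose(midpoint, 0.0))         # True: the box is now origin-centered
```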
@@ -101,47 +120,45 @@ def add_mesh(self, mesh, mesh_name=None, scaling_factor=1.0, center_mesh=True): """ if mesh_name is None: mesh_name = f"object_{len(self.meshes)}" - + mesh.vertices = mesh.vertices * scaling_factor - + bounding_box_dims, bounding_box_pose = bayes3d.utils.aabb(mesh.vertices) if center_mesh: - if not jnp.isclose(bounding_box_pose[:3,3], 0.0).all(): + if not jnp.isclose(bounding_box_pose[:3, 3], 0.0).all(): print(f"Centering mesh with translation {bounding_box_pose[:3,3]}") - mesh.vertices = mesh.vertices - bounding_box_pose[:3,3] + mesh.vertices = mesh.vertices - bounding_box_pose[:3, 3] self.meshes.append(mesh) self.mesh_names.append(mesh_name) - self.model_box_dims = jnp.vstack( - [ - self.model_box_dims, - bounding_box_dims - ] - ) + self.model_box_dims = jnp.vstack([self.model_box_dims, bounding_box_dims]) vertices = np.array(mesh.vertices) - vertices = np.concatenate([vertices, np.ones((*vertices.shape[:-1],1))],axis=-1) + vertices = np.concatenate( + [vertices, np.ones((*vertices.shape[:-1], 1))], axis=-1 + ) triangles = np.array(mesh.faces) prim = build_load_vertices_primitive(self) prim.bind(jnp.float32(vertices), jnp.int32(triangles)) - - - def render_many_custom_intrinsics(self, poses, indices, intrinsics): proj_matrix = b.camera._open_gl_projection_matrix( - intrinsics.height, intrinsics.width, - intrinsics.fx, intrinsics.fy, - intrinsics.cx, intrinsics.cy, - intrinsics.near, intrinsics.far + intrinsics.height, + intrinsics.width, + intrinsics.fx, + intrinsics.fy, + intrinsics.cx, + intrinsics.cy, + intrinsics.near, + intrinsics.far, ) images_jnp = _render_custom_call(self, poses, indices, proj_matrix)[0] return _transform_image_zeros_parallel(images_jnp, intrinsics) def render_many(self, poses, indices): """Render many scenes in parallel. - + Args: poses (jnp.ndarray): The poses of the objects in the scene. Shape (N, M, 4, 4) where N is the number of scenes and M is the number of objects. 
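A sketch of the `render_many` input contract documented above: N scenes, M object slots, one mesh index per slot. Values are illustrative, and the call itself is commented out since it needs a live GL context:

```python
import jax.numpy as jnp

num_scenes, num_objects = 8, 2
poses = jnp.tile(jnp.eye(4), (num_scenes, num_objects, 1, 1))
poses = poses.at[:, :, 2, 3].set(3.0)  # every object 3m in front of its camera
indices = jnp.array([0, 1])            # which loaded mesh occupies each slot
# depth_images = b.RENDERER.render_many(poses, indices)  # -> (N, H, W, 4)
```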
@@ -155,19 +172,24 @@ def render_many(self, poses, indices): return self.render_many_custom_intrinsics(poses, indices, self.intrinsics) def render(self, poses, indices): - return self.render_many(poses[None,...], indices)[0] + return self.render_many(poses[None, ...], indices)[0] def render_custom_intrinsics(self, poses, indices, intrinsics): - return self.render_many_custom_intrinsics(poses[None,...], indices, intrinsics)[0] + return self.render_many_custom_intrinsics( + poses[None, ...], indices, intrinsics + )[0] + # Useful reference for understanding the custom calls setup: # https://github.com/dfm/extending-jax + @functools.lru_cache def _register_custom_calls(): for _name, _value in dr._get_plugin(gl=True).registrations().items(): xla_client.register_custom_call_target(_name, _value, platform="gpu") + @functools.partial(jax.jit, static_argnums=(0,)) def _render_custom_call(r: "Renderer", poses, indices, intrinsics_matrix): return _build_render_primitive(r).bind(poses, indices, intrinsics_matrix) @@ -176,25 +198,33 @@ def _render_custom_call(r: "Renderer", poses, indices, intrinsics_matrix): @functools.lru_cache(maxsize=None) def _build_render_primitive(r: "Renderer"): _register_custom_calls() + # For JIT compilation we need a function to evaluate the shape and dtype of the # outputs of our op for some given inputs def _render_abstract(poses, indices, intrinsics_matrix): num_images = poses.shape[0] if poses.shape[1] != indices.shape[0]: - raise ValueError(f"Poses Shape: {poses.shape} Indices Shape: {indices.shape}") + raise ValueError( + f"Poses Shape: {poses.shape} Indices Shape: {indices.shape}" + ) dtype = dtypes.canonicalize_dtype(poses.dtype) - return [ShapedArray((num_images, r.height, r.width, 4), dtype), - ShapedArray((), dtype)] + return [ + ShapedArray((num_images, r.height, r.width, 4), dtype), + ShapedArray((), dtype), + ] # Provide an MLIR "lowering" of the render primitive. 
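Before the lowering code below, one note on `_render_custom_call` above: it is jitted with `static_argnums=(0,)`, so the `Renderer` instance is treated as a hashable compile-time constant rather than a traced input. A self-contained illustration of that JAX pattern:

```python
import functools

import jax
import jax.numpy as jnp

@functools.partial(jax.jit, static_argnums=(0,))
def scale(factor, x):
    return factor * x  # `factor` is a Python constant inside the traced function

print(scale(2.0, jnp.arange(3.0)))  # [0. 2. 4.]; retraces only if `factor` changes
```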
def _render_lowering(ctx, poses, indices, intrinsics_matrix): - # Extract the numpy type of the inputs poses_aval, indices_aval, intrinsics_matrix_aval = ctx.avals_in if poses_aval.ndim != 4: - raise NotImplementedError(f"Only 4D inputs supported: got {poses_aval.shape}") + raise NotImplementedError( + f"Only 4D inputs supported: got {poses_aval.shape}" + ) if indices_aval.ndim != 1: - raise NotImplementedError(f"Only 1D inputs supported: got {indices_aval.shape}") + raise NotImplementedError( + f"Only 1D inputs supported: got {indices_aval.shape}" + ) np_dtype = np.dtype(poses_aval.dtype) if np_dtype != np.float32: @@ -204,15 +234,20 @@ def _render_lowering(ctx, poses, indices, intrinsics_matrix): num_images, num_objects = poses_aval.shape[:2] out_shp_dtype = mlir.ir.RankedTensorType.get( - [num_images, r.height, r.width, 4], - mlir.dtype_to_ir_type(poses_aval.dtype)) + [num_images, r.height, r.width, 4], mlir.dtype_to_ir_type(poses_aval.dtype) + ) if num_objects != indices_aval.shape[0]: - raise ValueError(f"Poses Shape: {poses_aval.shape} Indices Shape: {indices_aval.shape}") - opaque = dr._get_plugin(gl=True).build_rasterize_descriptor(r.renderer_env.cpp_wrapper, - [num_objects, num_images]) + raise ValueError( + f"Poses Shape: {poses_aval.shape} Indices Shape: {indices_aval.shape}" + ) + opaque = dr._get_plugin(gl=True).build_rasterize_descriptor( + r.renderer_env.cpp_wrapper, [num_objects, num_images] + ) - scalar_dummy = mlir.ir.RankedTensorType.get([], mlir.dtype_to_ir_type(poses_aval.dtype)) + scalar_dummy = mlir.ir.RankedTensorType.get( + [], mlir.dtype_to_ir_type(poses_aval.dtype) + ) op_name = "jax_rasterize_fwd_gl" return custom_call( op_name, @@ -221,20 +256,26 @@ def _render_lowering(ctx, poses, indices, intrinsics_matrix): # The inputs: operands=[poses, indices, intrinsics_matrix], # Layout specification: - operand_layouts=[(3, 2, 0, 1), (0,), (1,0,)], + operand_layouts=[ + (3, 2, 0, 1), + (0,), + ( + 1, + 0, + ), + ], result_layouts=[(3, 2, 1, 0), ()], # GPU specific additional data - backend_config=opaque + backend_config=opaque, ).results - # ************************************ # * SUPPORT FOR BATCHING WITH VMAP * # ************************************ def _render_batch(args, axes): - poses, indices, intrinsics_matrix = args + poses, indices, intrinsics_matrix = args if poses.ndim != 5: - raise NotImplementedError("Underlying primitive must operate on 4D poses.") + raise NotImplementedError("Underlying primitive must operate on 4D poses.") original_shape = poses.shape poses = jnp.moveaxis(poses, axes[0], 0) @@ -244,16 +285,19 @@ def _render_batch(args, axes): poses = poses.reshape(size_1 * size_2, num_objects, 4, 4) if poses.shape[1] != indices.shape[0]: - raise ValueError(f"Poses Original Shape: {original_shape} Poses Shape: {poses.shape} Indices Shape: {indices.shape}") + raise ValueError( + f"Poses Original Shape: {original_shape} Poses Shape: {poses.shape} Indices Shape: {indices.shape}" + ) if poses.shape[-2:] != (4, 4): - raise ValueError(f"Poses Original Shape: {original_shape} Poses Shape: {poses.shape} Indices Shape: {indices.shape}") + raise ValueError( + f"Poses Original Shape: {original_shape} Poses Shape: {poses.shape} Indices Shape: {indices.shape}" + ) renders, dummy = _render_custom_call(r, poses, indices, intrinsics_matrix) renders = renders.reshape(size_1, size_2, *renders.shape[1:]) out_axes = 0, None return (renders, dummy), out_axes - # ********************************************* # * BOILERPLATE TO REGISTER THE OP WITH JAX * # 
********************************************* @@ -265,17 +309,15 @@ def _render_batch(args, axes): # Connect the XLA translation rules for JIT compilation mlir.register_lowering(_render_prim, _render_lowering, platform="gpu") batching.primitive_batchers[_render_prim] = _render_batch - - return _render_prim - + return _render_prim @functools.lru_cache(maxsize=None) def build_setup_primitive(r: "Renderer", h, w, num_layers): _register_custom_calls() # print('build_setup_primitive') - + # For JIT compilation we need a function to evaluate the shape and dtype of the # outputs of our op for some given inputs def _setup_abstract(): @@ -288,9 +330,12 @@ def _setup_lowering(ctx): # print('lowering setup!') opaque = dr._get_plugin(gl=True).build_setup_descriptor( - r.renderer_env.cpp_wrapper, h, w, num_layers) + r.renderer_env.cpp_wrapper, h, w, num_layers + ) - scalar_dummy = mlir.ir.RankedTensorType.get([], mlir.dtype_to_ir_type(np.dtype(np.float32))) + scalar_dummy = mlir.ir.RankedTensorType.get( + [], mlir.dtype_to_ir_type(np.dtype(np.float32)) + ) op_name = "jax_setup" return custom_call( op_name, @@ -302,7 +347,7 @@ def _setup_lowering(ctx): operand_layouts=[], result_layouts=[(), ()], # GPU specific additional data - backend_config=opaque + backend_config=opaque, ).results # ********************************************* @@ -315,7 +360,7 @@ def _setup_lowering(ctx): # Connect the XLA translation rules for JIT compilation mlir.register_lowering(_prim, _setup_lowering, platform="gpu") - + return _prim @@ -323,7 +368,7 @@ def _setup_lowering(ctx): def build_load_vertices_primitive(r: "Renderer"): _register_custom_calls() # print('build_load_vertices_primitive') - + # For JIT compilation we need a function to evaluate the shape and dtype of the # outputs of our op for some given inputs def _load_vertices_abstract(vertices, triangles): @@ -343,9 +388,12 @@ def _load_vertices_lowering(ctx, vertices, triangles): raise NotImplementedError(f"Unsupported triangles dtype {np_dtype}") opaque = dr._get_plugin(gl=True).build_load_vertices_descriptor( - r.renderer_env.cpp_wrapper, vertices_aval.shape[0], triangles_aval.shape[0]) + r.renderer_env.cpp_wrapper, vertices_aval.shape[0], triangles_aval.shape[0] + ) - scalar_dummy = mlir.ir.RankedTensorType.get([], mlir.dtype_to_ir_type(np.dtype(np.float32))) + scalar_dummy = mlir.ir.RankedTensorType.get( + [], mlir.dtype_to_ir_type(np.dtype(np.float32)) + ) op_name = "jax_load_vertices" return custom_call( op_name, @@ -357,7 +405,7 @@ def _load_vertices_lowering(ctx, vertices, triangles): operand_layouts=[(1, 0), (1, 0)], result_layouts=[(), ()], # GPU specific additional data - backend_config=opaque + backend_config=opaque, ).results # ********************************************* @@ -370,7 +418,5 @@ def _load_vertices_lowering(ctx, vertices, triangles): # Connect the XLA translation rules for JIT compilation mlir.register_lowering(_prim, _load_vertices_lowering, platform="gpu") - - return _prim - + return _prim diff --git a/bayes3d/rendering/nvdiffrast/common/__init__.py b/bayes3d/rendering/nvdiffrast/common/__init__.py index 000b85ba..68394fff 100644 --- a/bayes3d/rendering/nvdiffrast/common/__init__.py +++ b/bayes3d/rendering/nvdiffrast/common/__init__.py @@ -6,5 +6,6 @@ # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. 
-from .ops import RasterizeGLContext, get_log_level, set_log_level, _get_plugin +from .ops import RasterizeGLContext, _get_plugin, get_log_level, set_log_level + __all__ = ["RasterizeGLContext", "get_log_level", "set_log_level", "_get_plugin"] diff --git a/bayes3d/rendering/nvdiffrast/common/ops.py b/bayes3d/rendering/nvdiffrast/common/ops.py index de3eb1b8..fce6cb23 100644 --- a/bayes3d/rendering/nvdiffrast/common/ops.py +++ b/bayes3d/rendering/nvdiffrast/common/ops.py @@ -8,15 +8,17 @@ import importlib import logging -import numpy as np import os + import torch import torch.utils.cpp_extension -#---------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- # C++/Cuda plugin compiler/loader. _cached_plugin = {} + + def _get_plugin(gl=False): assert isinstance(gl, bool) @@ -25,14 +27,24 @@ def _get_plugin(gl=False): return _cached_plugin[gl] # Make sure we can find the necessary compiler and libary binaries. - if os.name == 'nt': + if os.name == "nt": lib_dir = os.path.dirname(__file__) + r"\..\lib" + def find_cl_path(): import glob - for edition in ['Enterprise', 'Professional', 'BuildTools', 'Community']: - vs_relative_path = r"\Microsoft Visual Studio\*\%s\VC\Tools\MSVC\*\bin\Hostx64\x64" % edition - paths = sorted(glob.glob(r"C:\Program Files" + vs_relative_path), reverse=True) - paths += sorted(glob.glob(r"C:\Program Files (x86)" + vs_relative_path), reverse=True) + + for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: + vs_relative_path = ( + r"\Microsoft Visual Studio\*\%s\VC\Tools\MSVC\*\bin\Hostx64\x64" + % edition + ) + paths = sorted( + glob.glob(r"C:\Program Files" + vs_relative_path), reverse=True + ) + paths += sorted( + glob.glob(r"C:\Program Files (x86)" + vs_relative_path), + reverse=True, + ) if paths: return paths[0] @@ -40,95 +52,117 @@ def find_cl_path(): if os.system("where cl.exe >nul 2>nul") != 0: cl_path = find_cl_path() if cl_path is None: - raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") - os.environ['PATH'] += ';' + cl_path + raise RuntimeError( + "Could not locate a supported Microsoft Visual C++ installation" + ) + os.environ["PATH"] += ";" + cl_path # Compiler options. - opts = ['-DNVDR_TORCH'] + opts = ["-DNVDR_TORCH"] # Linker options for the GL-interfacing plugin. ldflags = [] if gl: - if os.name == 'posix': - ldflags = ['-lGL', '-lEGL'] - elif os.name == 'nt': - libs = ['gdi32', 'opengl32', 'user32', 'setgpu'] - ldflags = ['/LIBPATH:' + lib_dir] + ['/DEFAULTLIB:' + x for x in libs] + if os.name == "posix": + ldflags = ["-lGL", "-lEGL"] + elif os.name == "nt": + libs = ["gdi32", "opengl32", "user32", "setgpu"] + ldflags = ["/LIBPATH:" + lib_dir] + ["/DEFAULTLIB:" + x for x in libs] # List of source files. 
if gl: source_files = [ - 'common.cpp', - 'glutil.cpp', - 'rasterize_gl.cpp', + "common.cpp", + "glutil.cpp", + "rasterize_gl.cpp", ] else: source_files = [ - '../common/common.cpp', - '../common/rasterize.cu', - '../common/interpolate.cu', - '../common/texture.cu', - '../common/texture.cpp', - '../common/antialias.cu', - 'torch_bindings.cpp', - 'torch_rasterize.cpp', - 'torch_interpolate.cpp', - 'torch_texture.cpp', - 'torch_antialias.cpp', + "../common/common.cpp", + "../common/rasterize.cu", + "../common/interpolate.cu", + "../common/texture.cu", + "../common/texture.cpp", + "../common/antialias.cu", + "torch_bindings.cpp", + "torch_rasterize.cpp", + "torch_interpolate.cpp", + "torch_texture.cpp", + "torch_antialias.cpp", ] # Some containers set this to contain old architectures that won't compile. We only need the one installed in the machine. - os.environ['TORCH_CUDA_ARCH_LIST'] = '' + os.environ["TORCH_CUDA_ARCH_LIST"] = "" # On Linux, show a warning if GLEW is being forcibly loaded when compiling the GL plugin. - if gl and (os.name == 'posix') and ('libGLEW' in os.environ.get('LD_PRELOAD', '')): - logging.getLogger('nvdiffrast').warning("Warning: libGLEW is being loaded via LD_PRELOAD, and will probably conflict with the OpenGL plugin") + if gl and (os.name == "posix") and ("libGLEW" in os.environ.get("LD_PRELOAD", "")): + logging.getLogger("nvdiffrast").warning( + "Warning: libGLEW is being loaded via LD_PRELOAD, and will probably conflict with the OpenGL plugin" + ) # Try to detect if a stray lock file is left in cache directory and show a warning. This sometimes happens on Windows if the build is interrupted at just the right moment. - plugin_name = 'nvdiffrast_plugin' + ('_gl' if gl else '') + plugin_name = "nvdiffrast_plugin" + ("_gl" if gl else "") try: - lock_fn = os.path.join(torch.utils.cpp_extension._get_build_directory(plugin_name, False), 'lock') + lock_fn = os.path.join( + torch.utils.cpp_extension._get_build_directory(plugin_name, False), "lock" + ) if os.path.exists(lock_fn): - logging.getLogger('nvdiffrast').warning("Lock file exists in build directory: '%s'" % lock_fn) - except: + logging.getLogger("nvdiffrast").warning( + "Lock file exists in build directory: '%s'" % lock_fn + ) + except Exception: pass # Speed up compilation on Windows. - if os.name == 'nt': + if os.name == "nt": # Skip telemetry sending step in vcvarsall.bat - os.environ['VSCMD_SKIP_SENDTELEMETRY'] = '1' + os.environ["VSCMD_SKIP_SENDTELEMETRY"] = "1" # Opportunistically patch distutils to cache MSVC environments. try: import distutils._msvccompiler import functools - if not hasattr(distutils._msvccompiler._get_vc_env, '__wrapped__'): - distutils._msvccompiler._get_vc_env = functools.lru_cache()(distutils._msvccompiler._get_vc_env) - except: + + if not hasattr(distutils._msvccompiler._get_vc_env, "__wrapped__"): + distutils._msvccompiler._get_vc_env = functools.lru_cache()( + distutils._msvccompiler._get_vc_env + ) + except Exception: pass # Compile and load. source_paths = [os.path.join(os.path.dirname(__file__), fn) for fn in source_files] - torch.utils.cpp_extension.load(name=plugin_name, sources=source_paths, extra_cflags=opts, extra_cuda_cflags=opts+['-lineinfo'], extra_ldflags=ldflags, with_cuda=True, verbose=False) + torch.utils.cpp_extension.load( + name=plugin_name, + sources=source_paths, + extra_cflags=opts, + extra_cuda_cflags=opts + ["-lineinfo"], + extra_ldflags=ldflags, + with_cuda=True, + verbose=False, + ) # Import, cache, and return the compiled module. 
_cached_plugin[gl] = importlib.import_module(plugin_name) return _cached_plugin[gl] -#---------------------------------------------------------------------------- + +# ---------------------------------------------------------------------------- # Log level. -#---------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- + def get_log_level(): - '''Get current log level. + """Get current log level. Returns: Current log level in nvdiffrast. See `set_log_level()` for possible values. - ''' + """ return _get_plugin().get_log_level() + def set_log_level(level): - '''Set log level. + """Set log level. Log levels follow the convention on the C++ side of Torch: 0 = Info, @@ -138,19 +172,21 @@ def set_log_level(level): The default log level is 1. Args: - level: New log level as integer. Internal nvdiffrast messages of this + level: New log level as integer. Internal nvdiffrast messages of this severity or higher will be printed, while messages of lower severity will be silent. - ''' + """ _get_plugin().set_log_level(level) -#---------------------------------------------------------------------------- + +# ---------------------------------------------------------------------------- # GL state wrapper. -#---------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- + class RasterizeGLContext: - def __init__(self, height, width, output_db=False, mode='automatic', device=None): - '''Create a new OpenGL rasterizer context. + def __init__(self, height, width, output_db=False, mode="automatic", device=None): + """Create a new OpenGL rasterizer context. Creating an OpenGL context is a slow operation so you should usually reuse the same context in all calls to `rasterize()` on the same CPU thread. The OpenGL context @@ -174,9 +210,9 @@ def __init__(self, height, width, output_db=False, mode='automatic', device=None device. Returns: The newly created OpenGL rasterizer context. - ''' + """ assert output_db is True or output_db is False - assert mode in ['automatic', 'manual'] + assert mode in ["automatic", "manual"] self.output_db = output_db self.mode = mode if device is None: @@ -184,19 +220,21 @@ def __init__(self, height, width, output_db=False, mode='automatic', device=None else: with torch.cuda.device(device): cuda_device_idx = torch.cuda.current_device() - self.cpp_wrapper = _get_plugin(gl=True).RasterizeGLStateWrapper(output_db, mode == 'automatic', cuda_device_idx) - self.active_depth_peeler = None # For error checking only. + self.cpp_wrapper = _get_plugin(gl=True).RasterizeGLStateWrapper( + output_db, mode == "automatic", cuda_device_idx + ) + self.active_depth_peeler = None # For error checking only. def set_context(self): - '''Set (activate) OpenGL context in the current CPU thread. - Only available if context was created in manual mode. - ''' - assert self.mode == 'manual' + """Set (activate) OpenGL context in the current CPU thread. + Only available if context was created in manual mode. + """ + assert self.mode == "manual" self.cpp_wrapper.set_context() def release_context(self): - '''Release (deactivate) currently active OpenGL context. - Only available if context was created in manual mode. - ''' - assert self.mode == 'manual' + """Release (deactivate) currently active OpenGL context. + Only available if context was created in manual mode. 
+ """ + assert self.mode == "manual" self.cpp_wrapper.release_context() diff --git a/bayes3d/rendering/nvdiffrast_jax/jax_renderer.py b/bayes3d/rendering/nvdiffrast_jax/jax_renderer.py index a96252e1..b2357276 100644 --- a/bayes3d/rendering/nvdiffrast_jax/jax_renderer.py +++ b/bayes3d/rendering/nvdiffrast_jax/jax_renderer.py @@ -1,28 +1,25 @@ -from typing import Tuple - import functools -# import bayes3d._rendering.nvdiffrast.common as dr -import bayes3d.rendering.nvdiffrast_jax.nvdiffrast.jax as dr -import bayes3d.camera -import bayes3d as j -import bayes3d as b -import bayes3d.transforms_3d as t3d -import trimesh -import jax.numpy as jnp + import jax +import jax.numpy as jnp +import numpy as np from jax import core, dtypes from jax.core import ShapedArray -from jax.interpreters import batching, mlir, xla +from jax.interpreters import mlir, xla from jax.lib import xla_client -import numpy as np from jaxlib.hlo_helpers import custom_call -from tqdm import tqdm -from jax import custom_vjp + +import bayes3d as b +import bayes3d.camera + +# import bayes3d._rendering.nvdiffrast.common as dr +import bayes3d.rendering.nvdiffrast_jax.nvdiffrast.jax as dr + class Renderer(object): def __init__(self, intrinsics, num_layers=1024): """A renderer for rendering meshes. - + Args: intrinsics (bayes3d.camera.Intrinsics): The camera intrinsics. num_layers (int, optional): The number of scenes to render in parallel. Defaults to 1024. @@ -32,9 +29,9 @@ def __init__(self, intrinsics, num_layers=1024): self.rasterize = jax.tree_util.Partial(self._rasterize, self) self.interpolate = jax.tree_util.Partial(self._interpolate, self) - #------------------ + # ------------------ # Rasterization - #------------------ + # ------------------ @functools.partial(jax.custom_vjp, nondiff_argnums=(0,)) def _rasterize(self, pos, tri, resolution): @@ -54,103 +51,144 @@ def _rasterize_bwd(self, saved_tensors, diffs): _rasterize.defvjp(_rasterize_fwd, _rasterize_bwd) - - #------------------ + # ------------------ # Interpolation - #------------------ + # ------------------ @functools.partial(jax.custom_vjp, nondiff_argnums=(0,)) def _interpolate(self, attr, rast, tri, rast_db, diff_attrs): num_total_attrs = attr.shape[-1] - diff_attrs_all = jax.lax.cond(diff_attrs.shape[0] == num_total_attrs, - lambda:True, - lambda:False) - return _interpolate_fwd_custom_call(self, attr, rast, tri, rast_db, diff_attrs_all, diff_attrs) + diff_attrs_all = jax.lax.cond( + diff_attrs.shape[0] == num_total_attrs, lambda: True, lambda: False + ) + return _interpolate_fwd_custom_call( + self, attr, rast, tri, rast_db, diff_attrs_all, diff_attrs + ) def _interpolate_fwd(self, attr, rast, tri, rast_db, diff_attrs): num_total_attrs = attr.shape[-1] - diff_attrs_all = jax.lax.cond(diff_attrs.shape[0] == num_total_attrs, - lambda:True, - lambda:False) - out, out_da = _interpolate_fwd_custom_call(self, attr, rast, tri, rast_db, diff_attrs_all, diff_attrs) + diff_attrs_all = jax.lax.cond( + diff_attrs.shape[0] == num_total_attrs, lambda: True, lambda: False + ) + out, out_da = _interpolate_fwd_custom_call( + self, attr, rast, tri, rast_db, diff_attrs_all, diff_attrs + ) saved_tensors = (attr, rast, tri, rast_db, diff_attrs_all, diff_attrs) return (out, out_da), saved_tensors - + def _interpolate_bwd(self, saved_tensors, diffs): attr, rast, tri, rast_db, diff_attrs_all, diff_attrs_list = saved_tensors - dy, dda = diffs - g_attr, g_rast, g_rast_db = _interpolate_bwd_custom_call(self, attr, rast, tri, dy, rast_db, dda, diff_attrs_all, diff_attrs_list) + dy, dda 
= diffs + g_attr, g_rast, g_rast_db = _interpolate_bwd_custom_call( + self, attr, rast, tri, dy, rast_db, dda, diff_attrs_all, diff_attrs_list + ) return g_attr, g_rast, None, g_rast_db, None - + _interpolate.defvjp(_interpolate_fwd, _interpolate_bwd) def render_many(self, vertices, faces, poses, intrinsics): jax_renderer = self projection_matrix = b.camera._open_gl_projection_matrix( - intrinsics.height, intrinsics.width, - intrinsics.fx, intrinsics.fy, - intrinsics.cx, intrinsics.cy, - intrinsics.near, intrinsics.far + intrinsics.height, + intrinsics.width, + intrinsics.fx, + intrinsics.fy, + intrinsics.cx, + intrinsics.cy, + intrinsics.near, + intrinsics.far, ) composed_projection = projection_matrix @ poses - vertices_homogenous = jnp.concatenate([vertices, jnp.ones((*vertices.shape[:-1],1))], axis=-1) - clip_spaces_projected_vertices = jnp.einsum("nij,mj->nmi", composed_projection, vertices_homogenous) - rast_out, rast_out_db = jax_renderer.rasterize(clip_spaces_projected_vertices, faces, jnp.array([intrinsics.height, intrinsics.width])) - interpolated_collided_vertices_clip, _ = jax_renderer.interpolate(jnp.tile(vertices_homogenous[None,...],(poses.shape[0],1,1)), rast_out, faces, rast_out_db, jnp.array([0,1,2,3])) - interpolated_collided_vertices = jnp.einsum("a...ij,a...j->a...i", poses, interpolated_collided_vertices_clip) - mask = rast_out[...,-1] > 0 - depth = interpolated_collided_vertices[...,2] * mask + vertices_homogenous = jnp.concatenate( + [vertices, jnp.ones((*vertices.shape[:-1], 1))], axis=-1 + ) + clip_spaces_projected_vertices = jnp.einsum( + "nij,mj->nmi", composed_projection, vertices_homogenous + ) + rast_out, rast_out_db = jax_renderer.rasterize( + clip_spaces_projected_vertices, + faces, + jnp.array([intrinsics.height, intrinsics.width]), + ) + interpolated_collided_vertices_clip, _ = jax_renderer.interpolate( + jnp.tile(vertices_homogenous[None, ...], (poses.shape[0], 1, 1)), + rast_out, + faces, + rast_out_db, + jnp.array([0, 1, 2, 3]), + ) + interpolated_collided_vertices = jnp.einsum( + "a...ij,a...j->a...i", poses, interpolated_collided_vertices_clip + ) + mask = rast_out[..., -1] > 0 + depth = interpolated_collided_vertices[..., 2] * mask return depth def render(self, vertices, faces, object_pose, intrinsics): jax_renderer = self projection_matrix = b.camera._open_gl_projection_matrix( - intrinsics.height, intrinsics.width, - intrinsics.fx, intrinsics.fy, - intrinsics.cx, intrinsics.cy, - intrinsics.near, intrinsics.far + intrinsics.height, + intrinsics.width, + intrinsics.fx, + intrinsics.fy, + intrinsics.cx, + intrinsics.cy, + intrinsics.near, + intrinsics.far, ) final_mtx_proj = projection_matrix @ object_pose - posw = jnp.concatenate([vertices, jnp.ones((*vertices.shape[:-1],1))], axis=-1) + posw = jnp.concatenate([vertices, jnp.ones((*vertices.shape[:-1], 1))], axis=-1) pos_clip_ja = xfm_points(vertices, final_mtx_proj) - rast_out, rast_out_db = jax_renderer.rasterize(pos_clip_ja[None,...], faces, jnp.array([intrinsics.height, intrinsics.width])) - gb_pos,_ = jax_renderer.interpolate(posw[None,...], rast_out, faces, rast_out_db, jnp.array([0,1,2,3])) + rast_out, rast_out_db = jax_renderer.rasterize( + pos_clip_ja[None, ...], + faces, + jnp.array([intrinsics.height, intrinsics.width]), + ) + gb_pos, _ = jax_renderer.interpolate( + posw[None, ...], rast_out, faces, rast_out_db, jnp.array([0, 1, 2, 3]) + ) mask = rast_out[..., -1] > 0 shape_keep = gb_pos.shape gb_pos = gb_pos.reshape(shape_keep[0], -1, shape_keep[-1]) gb_pos = gb_pos[..., :3] depth = 
xfm_points(gb_pos, object_pose) depth = depth.reshape(shape_keep)[..., 2] * -1 - return - (depth * mask), mask + return -(depth * mask), mask + # ================================================================================================ # Register custom call targets helpers # ================================================================================================ def xfm_points(points, matrix): - points2 = jnp.concatenate([points, jnp.ones((*points.shape[:-1],1))], axis=-1) + points2 = jnp.concatenate([points, jnp.ones((*points.shape[:-1], 1))], axis=-1) return jnp.matmul(points2, matrix.T) + # XLA array layout in memory def default_layouts(*shapes): return [range(len(shape) - 1, -1, -1) for shape in shapes] + # Register custom call targets @functools.lru_cache def _register_custom_calls(): for _name, _value in dr._get_plugin(gl=True).registrations().items(): xla_client.register_custom_call_target(_name, _value, platform="gpu") + # ================================================================================================ # Rasterize # ================================================================================================ #### FORWARD #### + # @functools.partial(jax.jit, static_argnums=(0,)) def _rasterize_fwd_custom_call(r: "Renderer", pos, tri, resolution): return _build_rasterize_fwd_primitive(r).bind(pos, tri, resolution) + @functools.lru_cache(maxsize=None) def _build_rasterize_fwd_primitive(r: "Renderer"): _register_custom_calls() @@ -158,30 +196,44 @@ def _build_rasterize_fwd_primitive(r: "Renderer"): # outputs of our op for some given inputs def _rasterize_fwd_abstract(pos, tri, resolution): - if (len(pos.shape) != 3 or pos.shape[-1] != 4): - raise ValueError(f"Pass in a [num_images, num_vertices, 4] sized first input") - num_images= pos.shape[0] + if len(pos.shape) != 3 or pos.shape[-1] != 4: + raise ValueError( + "Pass in a [num_images, num_vertices, 4] sized first input" + ) + num_images = pos.shape[0] dtype = dtypes.canonicalize_dtype(pos.dtype) - return [ShapedArray((num_images, r.intrinsics.height, r.intrinsics.width, 4), dtype), - ShapedArray((num_images, r.intrinsics.height, r.intrinsics.width, 4), dtype)] + return [ + ShapedArray( + (num_images, r.intrinsics.height, r.intrinsics.width, 4), dtype + ), + ShapedArray( + (num_images, r.intrinsics.height, r.intrinsics.width, 4), dtype + ), + ] # Provide an MLIR "lowering" of the rasterize primitive. 
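The `_rasterize_fwd_abstract` rule just defined never touches real data: it maps input avals to output `ShapedArray`s so JAX can trace and JIT the op, and the lowering defined next does the actual work. A toy standalone rule with the same contract (`HEIGHT`/`WIDTH` here are stand-ins for `r.intrinsics.height` / `r.intrinsics.width`):

```python
from jax import dtypes
from jax.core import ShapedArray

HEIGHT, WIDTH = 200, 200  # stand-ins for r.intrinsics.height / width


def toy_rasterize_abstract(pos):
    # pos carries only shape and dtype here, never concrete values.
    if len(pos.shape) != 3 or pos.shape[-1] != 4:
        raise ValueError("Pass in a [num_images, num_vertices, 4] sized input")
    dtype = dtypes.canonicalize_dtype(pos.dtype)
    num_images = pos.shape[0]
    # One rasterizer output plus one derivative buffer, both 4-channel images.
    out = ShapedArray((num_images, HEIGHT, WIDTH, 4), dtype)
    return [out, out]
```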
def _rasterize_fwd_lowering(ctx, pos, tri, resolution): """ - Single-object (one obj represented by tri) rasterization with + Single-object (one obj represented by tri) rasterization with multiple poses (first dimension fo pos) dr.rasterize(glctx, pos, tri, resolution=resolution) """ # Extract the numpy type of the inputs poses_aval, tri_aval, resolution_aval = ctx.avals_in if poses_aval.ndim != 3: - raise NotImplementedError(f"Only 3D vtx position inputs supported: got {poses_aval.shape}") + raise NotImplementedError( + f"Only 3D vtx position inputs supported: got {poses_aval.shape}" + ) if tri_aval.ndim != 2: - raise NotImplementedError(f"Only 2D triangle inputs supported: got {tri_aval.shape}") + raise NotImplementedError( + f"Only 2D triangle inputs supported: got {tri_aval.shape}" + ) if resolution_aval.shape[0] != 2: - raise NotImplementedError(f"Only 2D resolutions supported: got {resolution_aval.shape}") + raise NotImplementedError( + f"Only 2D resolutions supported: got {resolution_aval.shape}" + ) np_dtype = np.dtype(poses_aval.dtype) if np_dtype != np.float32: @@ -193,10 +245,12 @@ def _rasterize_fwd_lowering(ctx, pos, tri, resolution): num_triangles = tri_aval.shape[0] out_shp_dtype = mlir.ir.RankedTensorType.get( [num_images, r.intrinsics.height, r.intrinsics.width, 4], - mlir.dtype_to_ir_type(np_dtype)) + mlir.dtype_to_ir_type(np_dtype), + ) - opaque = dr._get_plugin(gl=True).build_diff_rasterize_fwd_descriptor(r.renderer_env.cpp_wrapper, - [num_images, num_vertices, num_triangles]) + opaque = dr._get_plugin(gl=True).build_diff_rasterize_fwd_descriptor( + r.renderer_env.cpp_wrapper, [num_images, num_vertices, num_triangles] + ) op_name = "jax_rasterize_fwd_gl" @@ -207,8 +261,23 @@ def _rasterize_fwd_lowering(ctx, pos, tri, resolution): # The inputs: operands=[pos, tri, resolution], backend_config=opaque, - operand_layouts=default_layouts(poses_aval.shape, tri_aval.shape, resolution_aval.shape), - result_layouts=default_layouts((num_images, r.intrinsics.height, r.intrinsics.width, 4,), (num_images, r.intrinsics.height, r.intrinsics.width, 4,)), + operand_layouts=default_layouts( + poses_aval.shape, tri_aval.shape, resolution_aval.shape + ), + result_layouts=default_layouts( + ( + num_images, + r.intrinsics.height, + r.intrinsics.width, + 4, + ), + ( + num_images, + r.intrinsics.height, + r.intrinsics.width, + 4, + ), + ), ).results # ********************************************* @@ -225,13 +294,14 @@ def _rasterize_fwd_lowering(ctx, pos, tri, resolution): return _rasterize_prim - #### BACKWARD #### + # @functools.partial(jax.jit, static_argnums=(0,)) def _rasterize_bwd_custom_call(r: "Renderer", pos, tri, rast_out, dy, ddb): return _build_rasterize_bwd_primitive(r).bind(pos, tri, rast_out, dy, ddb) + @functools.lru_cache(maxsize=None) def _build_rasterize_bwd_primitive(r: "Renderer"): _register_custom_calls() @@ -239,8 +309,10 @@ def _build_rasterize_bwd_primitive(r: "Renderer"): # outputs of our op for some given inputs def _rasterize_bwd_abstract(pos, tri, rast_out, dy, ddb): - if (len(pos.shape) != 3): - raise ValueError(f"Pass in a [num_images, num_vertices, 4] sized first input") + if len(pos.shape) != 3: + raise ValueError( + "Pass in a [num_images, num_vertices, 4] sized first input" + ) out_shp = pos.shape dtype = dtypes.canonicalize_dtype(pos.dtype) @@ -256,21 +328,25 @@ def _rasterize_bwd_lowering(ctx, pos, tri, rast_out, dy, ddb): depth, height, width = rast_aval.shape[:3] if rast_aval.ndim != 4: - raise NotImplementedError(f"Rasterization output should be 4D: got 
{rast_aval.shape}") + raise NotImplementedError( + f"Rasterization output should be 4D: got {rast_aval.shape}" + ) if dy_aval.ndim != 4 or ddb_aval.ndim != 4: - raise NotImplementedError(f"Grad outputs from rasterize should be 4D: got dy={dy_aval.shape} and ddb={ddb_aval.shape}") + raise NotImplementedError( + f"Grad outputs from rasterize should be 4D: got dy={dy_aval.shape} and ddb={ddb_aval.shape}" + ) np_dtype = np.dtype(rast_aval.dtype) if np_dtype != np.float32: raise NotImplementedError(f"Unsupported dtype {np_dtype}") out_shp_dtype = mlir.ir.RankedTensorType.get( - [num_images, num_vertices, 4], - mlir.dtype_to_ir_type(np_dtype)) # gradients have same size as the positions + [num_images, num_vertices, 4], mlir.dtype_to_ir_type(np_dtype) + ) # gradients have same size as the positions - opaque = dr._get_plugin(gl=True).build_diff_rasterize_bwd_descriptor([num_images, num_vertices], - [num_triangles], - [depth, height, width]) + opaque = dr._get_plugin(gl=True).build_diff_rasterize_bwd_descriptor( + [num_images, num_vertices], [num_triangles], [depth, height, width] + ) op_name = "jax_rasterize_bwd" @@ -281,8 +357,20 @@ def _rasterize_bwd_lowering(ctx, pos, tri, rast_out, dy, ddb): # The inputs: operands=[pos, tri, rast_out, dy, ddb], backend_config=opaque, - operand_layouts=default_layouts(pos_aval.shape, tri_aval.shape, rast_aval.shape, dy_aval.shape, ddb_aval.shape), - result_layouts=default_layouts((num_images, num_vertices, 4,)), + operand_layouts=default_layouts( + pos_aval.shape, + tri_aval.shape, + rast_aval.shape, + dy_aval.shape, + ddb_aval.shape, + ), + result_layouts=default_layouts( + ( + num_images, + num_vertices, + 4, + ) + ), ).results # ********************************************* @@ -305,9 +393,15 @@ def _rasterize_bwd_lowering(ctx, pos, tri, rast_out, dy, ddb): #### FORWARD #### + # @functools.partial(jax.jit, static_argnums=(0,)) -def _interpolate_fwd_custom_call(r: "Renderer", attr, rast_out, tri, rast_db, diff_attrs_all, diff_attrs): - return _build_interpolate_fwd_primitive(r).bind(attr, rast_out, tri, rast_db, diff_attrs_all, diff_attrs) +def _interpolate_fwd_custom_call( + r: "Renderer", attr, rast_out, tri, rast_db, diff_attrs_all, diff_attrs +): + return _build_interpolate_fwd_primitive(r).bind( + attr, rast_out, tri, rast_db, diff_attrs_all, diff_attrs + ) + # @functools.lru_cache(maxsize=None) def _build_interpolate_fwd_primitive(r: "Renderer"): @@ -315,9 +409,13 @@ def _build_interpolate_fwd_primitive(r: "Renderer"): # For JIT compilation we need a function to evaluate the shape and dtype of the # outputs of our op for some given inputs - def _interpolate_fwd_abstract(attr, rast_out, tri, rast_db, diff_attrs_all, diff_attrs): - if (len(attr.shape) != 3): - raise ValueError(f"Pass in a [num_images, num_vertices, num_attributes] sized first input") + def _interpolate_fwd_abstract( + attr, rast_out, tri, rast_db, diff_attrs_all, diff_attrs + ): + if len(attr.shape) != 3: + raise ValueError( + "Pass in a [num_images, num_vertices, num_attributes] sized first input" + ) num_images, num_vertices, num_attributes = attr.shape _, height, width, _ = rast_out.shape num_tri, _ = tri.shape @@ -326,20 +424,37 @@ def _interpolate_fwd_abstract(attr, rast_out, tri, rast_db, diff_attrs_all, diff dtype = dtypes.canonicalize_dtype(attr.dtype) out_abstract = ShapedArray((num_images, height, width, num_attributes), dtype) - out_db_abstract = ShapedArray((num_images, height, width, 2*num_diff_attrs), dtype) # empty tensor - return [out_abstract, out_db_abstract] - + 
out_db_abstract = ShapedArray( + (num_images, height, width, 2 * num_diff_attrs), dtype + ) # empty tensor + return [out_abstract, out_db_abstract] + # Provide an MLIR "lowering" of the interpolate primitive. - def _interpolate_fwd_lowering(ctx, attr, rast_out, tri, rast_db, diff_attrs_all, diff_attrs): + def _interpolate_fwd_lowering( + ctx, attr, rast_out, tri, rast_db, diff_attrs_all, diff_attrs + ): # Extract the numpy type of the inputs - attr_aval, rast_out_aval, tri_aval, rast_db_aval, _, diff_attr_aval = ctx.avals_in + ( + attr_aval, + rast_out_aval, + tri_aval, + rast_db_aval, + _, + diff_attr_aval, + ) = ctx.avals_in if attr_aval.ndim != 3: - raise NotImplementedError(f"Only 3D attribute inputs supported: got {attr_aval.shape}") + raise NotImplementedError( + f"Only 3D attribute inputs supported: got {attr_aval.shape}" + ) if rast_out_aval.ndim != 4: - raise NotImplementedError(f"Only 4D rast inputs supported: got {rast_out_aval.shape}") + raise NotImplementedError( + f"Only 4D rast inputs supported: got {rast_out_aval.shape}" + ) if tri_aval.ndim != 2: - raise NotImplementedError(f"Only 2D triangle tensors supported: got {tri_aval.shape}") + raise NotImplementedError( + f"Only 2D triangle tensors supported: got {tri_aval.shape}" + ) np_dtype = np.dtype(attr_aval.dtype) if np_dtype != np.float32: @@ -347,7 +462,9 @@ def _interpolate_fwd_lowering(ctx, attr, rast_out, tri, rast_db, diff_attrs_all, if np.dtype(tri_aval.dtype) != np.int32: raise NotImplementedError(f"Unsupported triangle dtype {tri_aval.dtype}") if np.dtype(diff_attr_aval.dtype) != np.int32: - raise NotImplementedError(f"Unsupported diff attribute dtype {diff_attr_aval.dtype}") + raise NotImplementedError( + f"Unsupported diff attribute dtype {diff_attr_aval.dtype}" + ) num_images, num_vertices, num_attributes = attr_aval.shape depth, height, width = rast_out_aval.shape[:3] @@ -355,21 +472,24 @@ def _interpolate_fwd_lowering(ctx, attr, rast_out, tri, rast_db, diff_attrs_all, num_diff_attrs = diff_attr_aval.shape[0] if num_diff_attrs > 0 and rast_db_aval.shape[-1] < num_diff_attrs: - raise NotImplementedError(f"Attempt to propagate bary gradients through {num_diff_attrs} attributes: got {rast_db_aval.shape}") + raise NotImplementedError( + f"Attempt to propagate bary gradients through {num_diff_attrs} attributes: got {rast_db_aval.shape}" + ) out_shp_dtype = mlir.ir.RankedTensorType.get( - [num_images, height, width, num_attributes], - mlir.dtype_to_ir_type(np_dtype)) + [num_images, height, width, num_attributes], mlir.dtype_to_ir_type(np_dtype) + ) out_db_shp_dtype = mlir.ir.RankedTensorType.get( - [num_images, height, width, 2*num_diff_attrs], - mlir.dtype_to_ir_type(np_dtype)) + [num_images, height, width, 2 * num_diff_attrs], + mlir.dtype_to_ir_type(np_dtype), + ) opaque = dr._get_plugin(gl=True).build_diff_interpolate_descriptor( - [num_images, num_vertices, num_attributes], - [depth, height, width], - [num_triangles], - num_diff_attrs # diff wrt all attributes (TODO) - ) + [num_images, num_vertices, num_attributes], + [depth, height, width], + [num_triangles], + num_diff_attrs, # diff wrt all attributes (TODO) + ) op_name = "jax_interpolate_fwd" @@ -380,8 +500,27 @@ def _interpolate_fwd_lowering(ctx, attr, rast_out, tri, rast_db, diff_attrs_all, # The inputs: operands=[attr, rast_out, tri, rast_db, diff_attrs], backend_config=opaque, - operand_layouts=default_layouts(attr_aval.shape, rast_out_aval.shape, tri_aval.shape, rast_db_aval.shape, diff_attr_aval.shape), - result_layouts=default_layouts((num_images, 
height, width, num_attributes,), (num_images, height, width, num_attributes,)), + operand_layouts=default_layouts( + attr_aval.shape, + rast_out_aval.shape, + tri_aval.shape, + rast_db_aval.shape, + diff_attr_aval.shape, + ), + result_layouts=default_layouts( + ( + num_images, + height, + width, + num_attributes, + ), + ( + num_images, + height, + width, + num_attributes, + ), + ), ).results # ********************************************* @@ -389,7 +528,9 @@ def _interpolate_fwd_lowering(ctx, attr, rast_out, tri, rast_db, diff_attrs_all, # ********************************************* _interpolate_prim = core.Primitive(f"interpolate_multiple_fwd_{id(r)}") _interpolate_prim.multiple_results = True - _interpolate_prim.def_impl(functools.partial(xla.apply_primitive, _interpolate_prim)) + _interpolate_prim.def_impl( + functools.partial(xla.apply_primitive, _interpolate_prim) + ) _interpolate_prim.def_abstract_eval(_interpolate_fwd_abstract) # # Connect the XLA translation rules for JIT compilation @@ -398,12 +539,25 @@ def _interpolate_fwd_lowering(ctx, attr, rast_out, tri, rast_db, diff_attrs_all, return _interpolate_prim - #### BACKWARD #### + # @functools.partial(jax.jit, static_argnums=(0,)) -def _interpolate_bwd_custom_call(r: "Renderer", attr, rast_out, tri, dy, rast_db, dda, diff_attrs_all, diff_attrs_list): - return _build_interpolate_bwd_primitive(r).bind(attr, rast_out, tri, dy, rast_db, dda, diff_attrs_all, diff_attrs_list) +def _interpolate_bwd_custom_call( + r: "Renderer", + attr, + rast_out, + tri, + dy, + rast_db, + dda, + diff_attrs_all, + diff_attrs_list, +): + return _build_interpolate_bwd_primitive(r).bind( + attr, rast_out, tri, dy, rast_db, dda, diff_attrs_all, diff_attrs_list + ) + # @functools.lru_cache(maxsize=None) def _build_interpolate_bwd_primitive(r: "Renderer"): @@ -411,9 +565,13 @@ def _build_interpolate_bwd_primitive(r: "Renderer"): # For JIT compilation we need a function to evaluate the shape and dtype of the # outputs of our op for some given inputs - def _interpolate_bwd_abstract(attr, rast_out, tri, dy, rast_db, dda, diff_attrs_all, diff_attrs_list): - if (len(attr.shape) != 3): - raise ValueError(f"Pass in a [num_images, num_vertices, num_attributes] sized first input") + def _interpolate_bwd_abstract( + attr, rast_out, tri, dy, rast_db, dda, diff_attrs_all, diff_attrs_list + ): + if len(attr.shape) != 3: + raise ValueError( + "Pass in a [num_images, num_vertices, num_attributes] sized first input" + ) num_images, num_vertices, num_attributes = attr.shape depth, height, width, rast_channels = rast_out.shape depth_db, height_db, width_db, rast_channels_db = rast_db.shape @@ -421,21 +579,40 @@ def _interpolate_bwd_abstract(attr, rast_out, tri, dy, rast_db, dda, diff_attrs_ dtype = dtypes.canonicalize_dtype(attr.dtype) g_attr_abstract = ShapedArray((num_images, num_vertices, num_attributes), dtype) - g_rast_abstract = ShapedArray((depth, height, width, rast_channels), dtype) - g_rast_db_abstract = ShapedArray((depth_db, height_db, width_db, rast_channels_db), dtype) - return [g_attr_abstract, g_rast_abstract, g_rast_db_abstract] - + g_rast_abstract = ShapedArray((depth, height, width, rast_channels), dtype) + g_rast_db_abstract = ShapedArray( + (depth_db, height_db, width_db, rast_channels_db), dtype + ) + return [g_attr_abstract, g_rast_abstract, g_rast_db_abstract] + # Provide an MLIR "lowering" of the interpolate primitive. 
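The gradient shapes returned by `_interpolate_bwd_abstract` above mirror the `jax.custom_vjp` contract used at the top of this file: the forward rule stashes residuals, and the backward rule returns exactly one gradient slot per differentiable argument, with `None` for integer inputs such as `tri` and the attribute index lists. A minimal sketch of that contract with a toy function in place of the custom calls (the lowering for the real backward pass follows below):

```python
import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.custom_vjp, nondiff_argnums=(0,))
def scale_and_gather(scale, x, idx):
    # Differentiable in x; idx is an integer index and gets no gradient.
    return scale * x[idx]


def _fwd(scale, x, idx):
    # Save whatever the backward pass needs as residuals.
    return scale_and_gather(scale, x, idx), (x.shape, idx)


def _bwd(scale, residuals, g):
    x_shape, idx = residuals
    g_x = jnp.zeros(x_shape).at[idx].add(scale * g)
    return (g_x, None)  # one slot per differentiable argument; None for idx


scale_and_gather.defvjp(_fwd, _bwd)

# jax.grad(lambda x: scale_and_gather(2.0, x, 1))(jnp.arange(3.0)) -> [0., 2., 0.]
```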
- def _interpolate_bwd_lowering(ctx, attr, rast_out, tri, dy, rast_db, dda, diff_attrs_all, diff_attrs_list): + def _interpolate_bwd_lowering( + ctx, attr, rast_out, tri, dy, rast_db, dda, diff_attrs_all, diff_attrs_list + ): # Extract the numpy type of the inputs - attr_aval, rast_out_aval, tri_aval, dy_aval, rast_db_aval, dda_aval, _, diff_attr_aval = ctx.avals_in + ( + attr_aval, + rast_out_aval, + tri_aval, + dy_aval, + rast_db_aval, + dda_aval, + _, + diff_attr_aval, + ) = ctx.avals_in if attr_aval.ndim != 3: - raise NotImplementedError(f"Only 3D attribute inputs supported: got {attr_aval.shape}") + raise NotImplementedError( + f"Only 3D attribute inputs supported: got {attr_aval.shape}" + ) if rast_out_aval.ndim != 4: - raise NotImplementedError(f"Only 4D rast inputs supported: got {rast_out_aval.shape}") + raise NotImplementedError( + f"Only 4D rast inputs supported: got {rast_out_aval.shape}" + ) if tri_aval.ndim != 2: - raise NotImplementedError(f"Only 2D triangle tensors supported: got {tri_aval.shape}") + raise NotImplementedError( + f"Only 2D triangle tensors supported: got {tri_aval.shape}" + ) np_dtype = np.dtype(attr_aval.dtype) if np_dtype != np.float32: @@ -450,21 +627,22 @@ def _interpolate_bwd_lowering(ctx, attr, rast_out, tri, dy, rast_db, dda, diff_a num_diff_attrs = diff_attr_aval.shape[0] g_attr_shp_dtype = mlir.ir.RankedTensorType.get( - [num_images, num_vertices, num_attributes], - mlir.dtype_to_ir_type(np_dtype)) + [num_images, num_vertices, num_attributes], mlir.dtype_to_ir_type(np_dtype) + ) g_rast_shp_dtype = mlir.ir.RankedTensorType.get( - [depth, height, width, rast_channels], - mlir.dtype_to_ir_type(np_dtype)) + [depth, height, width, rast_channels], mlir.dtype_to_ir_type(np_dtype) + ) g_rast_db_shp_dtype = mlir.ir.RankedTensorType.get( [depth_db, height_db, width_db, rast_channels_db], - mlir.dtype_to_ir_type(np_dtype)) + mlir.dtype_to_ir_type(np_dtype), + ) opaque = dr._get_plugin(gl=True).build_diff_interpolate_descriptor( - [num_images, num_vertices, num_attributes], - [depth, height, width], - [num_triangles], - num_diff_attrs - ) + [num_images, num_vertices, num_attributes], + [depth, height, width], + [num_triangles], + num_diff_attrs, + ) op_name = "jax_interpolate_bwd" @@ -475,8 +653,34 @@ def _interpolate_bwd_lowering(ctx, attr, rast_out, tri, dy, rast_db, dda, diff_a # The inputs: operands=[attr, rast_out, tri, dy, rast_db, dda, diff_attrs_list], backend_config=opaque, - operand_layouts=default_layouts(attr_aval.shape, rast_out_aval.shape, tri_aval.shape, dy_aval.shape, rast_db_aval.shape, dda_aval.shape, diff_attr_aval.shape), - result_layouts=default_layouts((num_images, num_vertices, num_attributes,), (depth, height, width, rast_channels,), (depth_db, height_db, width_db, rast_channels_db,)), + operand_layouts=default_layouts( + attr_aval.shape, + rast_out_aval.shape, + tri_aval.shape, + dy_aval.shape, + rast_db_aval.shape, + dda_aval.shape, + diff_attr_aval.shape, + ), + result_layouts=default_layouts( + ( + num_images, + num_vertices, + num_attributes, + ), + ( + depth, + height, + width, + rast_channels, + ), + ( + depth_db, + height_db, + width_db, + rast_channels_db, + ), + ), ).results # ********************************************* @@ -484,10 +688,12 @@ def _interpolate_bwd_lowering(ctx, attr, rast_out, tri, dy, rast_db, dda, diff_a # ********************************************* _interpolate_prim = core.Primitive(f"interpolate_multiple_bwd_{id(r)}") _interpolate_prim.multiple_results = True - 
_interpolate_prim.def_impl(functools.partial(xla.apply_primitive, _interpolate_prim)) + _interpolate_prim.def_impl( + functools.partial(xla.apply_primitive, _interpolate_prim) + ) _interpolate_prim.def_abstract_eval(_interpolate_bwd_abstract) # # Connect the XLA translation rules for JIT compilation mlir.register_lowering(_interpolate_prim, _interpolate_bwd_lowering, platform="gpu") - return _interpolate_prim \ No newline at end of file + return _interpolate_prim diff --git a/bayes3d/rendering/nvdiffrast_jax/nvdiffrast/common/__init__.py b/bayes3d/rendering/nvdiffrast_jax/nvdiffrast/common/__init__.py index 000b85ba..68394fff 100644 --- a/bayes3d/rendering/nvdiffrast_jax/nvdiffrast/common/__init__.py +++ b/bayes3d/rendering/nvdiffrast_jax/nvdiffrast/common/__init__.py @@ -6,5 +6,6 @@ # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. -from .ops import RasterizeGLContext, get_log_level, set_log_level, _get_plugin +from .ops import RasterizeGLContext, _get_plugin, get_log_level, set_log_level + __all__ = ["RasterizeGLContext", "get_log_level", "set_log_level", "_get_plugin"] diff --git a/bayes3d/rendering/nvdiffrast_jax/nvdiffrast/common/ops.py b/bayes3d/rendering/nvdiffrast_jax/nvdiffrast/common/ops.py index de3eb1b8..fce6cb23 100644 --- a/bayes3d/rendering/nvdiffrast_jax/nvdiffrast/common/ops.py +++ b/bayes3d/rendering/nvdiffrast_jax/nvdiffrast/common/ops.py @@ -8,15 +8,17 @@ import importlib import logging -import numpy as np import os + import torch import torch.utils.cpp_extension -#---------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- # C++/Cuda plugin compiler/loader. _cached_plugin = {} + + def _get_plugin(gl=False): assert isinstance(gl, bool) @@ -25,14 +27,24 @@ def _get_plugin(gl=False): return _cached_plugin[gl] # Make sure we can find the necessary compiler and libary binaries. - if os.name == 'nt': + if os.name == "nt": lib_dir = os.path.dirname(__file__) + r"\..\lib" + def find_cl_path(): import glob - for edition in ['Enterprise', 'Professional', 'BuildTools', 'Community']: - vs_relative_path = r"\Microsoft Visual Studio\*\%s\VC\Tools\MSVC\*\bin\Hostx64\x64" % edition - paths = sorted(glob.glob(r"C:\Program Files" + vs_relative_path), reverse=True) - paths += sorted(glob.glob(r"C:\Program Files (x86)" + vs_relative_path), reverse=True) + + for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: + vs_relative_path = ( + r"\Microsoft Visual Studio\*\%s\VC\Tools\MSVC\*\bin\Hostx64\x64" + % edition + ) + paths = sorted( + glob.glob(r"C:\Program Files" + vs_relative_path), reverse=True + ) + paths += sorted( + glob.glob(r"C:\Program Files (x86)" + vs_relative_path), + reverse=True, + ) if paths: return paths[0] @@ -40,95 +52,117 @@ def find_cl_path(): if os.system("where cl.exe >nul 2>nul") != 0: cl_path = find_cl_path() if cl_path is None: - raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") - os.environ['PATH'] += ';' + cl_path + raise RuntimeError( + "Could not locate a supported Microsoft Visual C++ installation" + ) + os.environ["PATH"] += ";" + cl_path # Compiler options. - opts = ['-DNVDR_TORCH'] + opts = ["-DNVDR_TORCH"] # Linker options for the GL-interfacing plugin. 
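A note on the `__init__.py` hunks above (the pattern repeats for each nvdiffrast package): pairing the sorted re-export with an explicit `__all__` tells both readers and the linter that the imported names are intentional public API rather than unused imports, and it pins the behavior of `from <pkg> import *`. In isolation:

```python
# mypkg/__init__.py
from .ops import RasterizeGLContext, _get_plugin, get_log_level, set_log_level

# Without __all__, a linter would flag these as unused imports;
# listing them marks the re-export as deliberate.
__all__ = ["RasterizeGLContext", "get_log_level", "set_log_level", "_get_plugin"]
```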
ldflags = [] if gl: - if os.name == 'posix': - ldflags = ['-lGL', '-lEGL'] - elif os.name == 'nt': - libs = ['gdi32', 'opengl32', 'user32', 'setgpu'] - ldflags = ['/LIBPATH:' + lib_dir] + ['/DEFAULTLIB:' + x for x in libs] + if os.name == "posix": + ldflags = ["-lGL", "-lEGL"] + elif os.name == "nt": + libs = ["gdi32", "opengl32", "user32", "setgpu"] + ldflags = ["/LIBPATH:" + lib_dir] + ["/DEFAULTLIB:" + x for x in libs] # List of source files. if gl: source_files = [ - 'common.cpp', - 'glutil.cpp', - 'rasterize_gl.cpp', + "common.cpp", + "glutil.cpp", + "rasterize_gl.cpp", ] else: source_files = [ - '../common/common.cpp', - '../common/rasterize.cu', - '../common/interpolate.cu', - '../common/texture.cu', - '../common/texture.cpp', - '../common/antialias.cu', - 'torch_bindings.cpp', - 'torch_rasterize.cpp', - 'torch_interpolate.cpp', - 'torch_texture.cpp', - 'torch_antialias.cpp', + "../common/common.cpp", + "../common/rasterize.cu", + "../common/interpolate.cu", + "../common/texture.cu", + "../common/texture.cpp", + "../common/antialias.cu", + "torch_bindings.cpp", + "torch_rasterize.cpp", + "torch_interpolate.cpp", + "torch_texture.cpp", + "torch_antialias.cpp", ] # Some containers set this to contain old architectures that won't compile. We only need the one installed in the machine. - os.environ['TORCH_CUDA_ARCH_LIST'] = '' + os.environ["TORCH_CUDA_ARCH_LIST"] = "" # On Linux, show a warning if GLEW is being forcibly loaded when compiling the GL plugin. - if gl and (os.name == 'posix') and ('libGLEW' in os.environ.get('LD_PRELOAD', '')): - logging.getLogger('nvdiffrast').warning("Warning: libGLEW is being loaded via LD_PRELOAD, and will probably conflict with the OpenGL plugin") + if gl and (os.name == "posix") and ("libGLEW" in os.environ.get("LD_PRELOAD", "")): + logging.getLogger("nvdiffrast").warning( + "Warning: libGLEW is being loaded via LD_PRELOAD, and will probably conflict with the OpenGL plugin" + ) # Try to detect if a stray lock file is left in cache directory and show a warning. This sometimes happens on Windows if the build is interrupted at just the right moment. - plugin_name = 'nvdiffrast_plugin' + ('_gl' if gl else '') + plugin_name = "nvdiffrast_plugin" + ("_gl" if gl else "") try: - lock_fn = os.path.join(torch.utils.cpp_extension._get_build_directory(plugin_name, False), 'lock') + lock_fn = os.path.join( + torch.utils.cpp_extension._get_build_directory(plugin_name, False), "lock" + ) if os.path.exists(lock_fn): - logging.getLogger('nvdiffrast').warning("Lock file exists in build directory: '%s'" % lock_fn) - except: + logging.getLogger("nvdiffrast").warning( + "Lock file exists in build directory: '%s'" % lock_fn + ) + except Exception: pass # Speed up compilation on Windows. - if os.name == 'nt': + if os.name == "nt": # Skip telemetry sending step in vcvarsall.bat - os.environ['VSCMD_SKIP_SENDTELEMETRY'] = '1' + os.environ["VSCMD_SKIP_SENDTELEMETRY"] = "1" # Opportunistically patch distutils to cache MSVC environments. try: import distutils._msvccompiler import functools - if not hasattr(distutils._msvccompiler._get_vc_env, '__wrapped__'): - distutils._msvccompiler._get_vc_env = functools.lru_cache()(distutils._msvccompiler._get_vc_env) - except: + + if not hasattr(distutils._msvccompiler._get_vc_env, "__wrapped__"): + distutils._msvccompiler._get_vc_env = functools.lru_cache()( + distutils._msvccompiler._get_vc_env + ) + except Exception: pass # Compile and load. 
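The compile-and-load step that follows in each of these `ops.py` files uses `torch.utils.cpp_extension.load`, which JIT-compiles the listed sources into an importable extension module on first use and caches the build directory afterwards (hence the lock-file probe above). A minimal, hypothetical standalone use of the same API, for orientation:

```python
import torch.utils.cpp_extension

# Hypothetical one-file extension; "my_op.cpp" must define a PYBIND11_MODULE.
my_op = torch.utils.cpp_extension.load(
    name="my_op_plugin",
    sources=["my_op.cpp"],
    extra_cflags=["-O3"],
    with_cuda=False,  # the nvdiffrast plugin compiles CUDA sources, so it passes True
    verbose=False,
)
```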
source_paths = [os.path.join(os.path.dirname(__file__), fn) for fn in source_files] - torch.utils.cpp_extension.load(name=plugin_name, sources=source_paths, extra_cflags=opts, extra_cuda_cflags=opts+['-lineinfo'], extra_ldflags=ldflags, with_cuda=True, verbose=False) + torch.utils.cpp_extension.load( + name=plugin_name, + sources=source_paths, + extra_cflags=opts, + extra_cuda_cflags=opts + ["-lineinfo"], + extra_ldflags=ldflags, + with_cuda=True, + verbose=False, + ) # Import, cache, and return the compiled module. _cached_plugin[gl] = importlib.import_module(plugin_name) return _cached_plugin[gl] -#---------------------------------------------------------------------------- + +# ---------------------------------------------------------------------------- # Log level. -#---------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- + def get_log_level(): - '''Get current log level. + """Get current log level. Returns: Current log level in nvdiffrast. See `set_log_level()` for possible values. - ''' + """ return _get_plugin().get_log_level() + def set_log_level(level): - '''Set log level. + """Set log level. Log levels follow the convention on the C++ side of Torch: 0 = Info, @@ -138,19 +172,21 @@ def set_log_level(level): The default log level is 1. Args: - level: New log level as integer. Internal nvdiffrast messages of this + level: New log level as integer. Internal nvdiffrast messages of this severity or higher will be printed, while messages of lower severity will be silent. - ''' + """ _get_plugin().set_log_level(level) -#---------------------------------------------------------------------------- + +# ---------------------------------------------------------------------------- # GL state wrapper. -#---------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- + class RasterizeGLContext: - def __init__(self, height, width, output_db=False, mode='automatic', device=None): - '''Create a new OpenGL rasterizer context. + def __init__(self, height, width, output_db=False, mode="automatic", device=None): + """Create a new OpenGL rasterizer context. Creating an OpenGL context is a slow operation so you should usually reuse the same context in all calls to `rasterize()` on the same CPU thread. The OpenGL context @@ -174,9 +210,9 @@ def __init__(self, height, width, output_db=False, mode='automatic', device=None device. Returns: The newly created OpenGL rasterizer context. - ''' + """ assert output_db is True or output_db is False - assert mode in ['automatic', 'manual'] + assert mode in ["automatic", "manual"] self.output_db = output_db self.mode = mode if device is None: @@ -184,19 +220,21 @@ def __init__(self, height, width, output_db=False, mode='automatic', device=None else: with torch.cuda.device(device): cuda_device_idx = torch.cuda.current_device() - self.cpp_wrapper = _get_plugin(gl=True).RasterizeGLStateWrapper(output_db, mode == 'automatic', cuda_device_idx) - self.active_depth_peeler = None # For error checking only. + self.cpp_wrapper = _get_plugin(gl=True).RasterizeGLStateWrapper( + output_db, mode == "automatic", cuda_device_idx + ) + self.active_depth_peeler = None # For error checking only. def set_context(self): - '''Set (activate) OpenGL context in the current CPU thread. - Only available if context was created in manual mode. 
- ''' - assert self.mode == 'manual' + """Set (activate) OpenGL context in the current CPU thread. + Only available if context was created in manual mode. + """ + assert self.mode == "manual" self.cpp_wrapper.set_context() def release_context(self): - '''Release (deactivate) currently active OpenGL context. - Only available if context was created in manual mode. - ''' - assert self.mode == 'manual' + """Release (deactivate) currently active OpenGL context. + Only available if context was created in manual mode. + """ + assert self.mode == "manual" self.cpp_wrapper.release_context() diff --git a/bayes3d/rendering/nvdiffrast_jax/nvdiffrast/jax/__init__.py b/bayes3d/rendering/nvdiffrast_jax/nvdiffrast/jax/__init__.py index 000b85ba..68394fff 100644 --- a/bayes3d/rendering/nvdiffrast_jax/nvdiffrast/jax/__init__.py +++ b/bayes3d/rendering/nvdiffrast_jax/nvdiffrast/jax/__init__.py @@ -6,5 +6,6 @@ # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. -from .ops import RasterizeGLContext, get_log_level, set_log_level, _get_plugin +from .ops import RasterizeGLContext, _get_plugin, get_log_level, set_log_level + __all__ = ["RasterizeGLContext", "get_log_level", "set_log_level", "_get_plugin"] diff --git a/bayes3d/rendering/nvdiffrast_jax/nvdiffrast/jax/ops.py b/bayes3d/rendering/nvdiffrast_jax/nvdiffrast/jax/ops.py index 170849e2..f57923be 100644 --- a/bayes3d/rendering/nvdiffrast_jax/nvdiffrast/jax/ops.py +++ b/bayes3d/rendering/nvdiffrast_jax/nvdiffrast/jax/ops.py @@ -8,15 +8,17 @@ import importlib import logging -import numpy as np import os + import torch import torch.utils.cpp_extension -#---------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- # C++/Cuda plugin compiler/loader. _cached_plugin = {} + + def _get_plugin(gl=False): assert isinstance(gl, bool) @@ -25,14 +27,24 @@ def _get_plugin(gl=False): return _cached_plugin[gl] # Make sure we can find the necessary compiler and libary binaries. - if os.name == 'nt': + if os.name == "nt": lib_dir = os.path.dirname(__file__) + r"\..\lib" + def find_cl_path(): import glob - for edition in ['Enterprise', 'Professional', 'BuildTools', 'Community']: - vs_relative_path = r"\Microsoft Visual Studio\*\%s\VC\Tools\MSVC\*\bin\Hostx64\x64" % edition - paths = sorted(glob.glob(r"C:\Program Files" + vs_relative_path), reverse=True) - paths += sorted(glob.glob(r"C:\Program Files (x86)" + vs_relative_path), reverse=True) + + for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: + vs_relative_path = ( + r"\Microsoft Visual Studio\*\%s\VC\Tools\MSVC\*\bin\Hostx64\x64" + % edition + ) + paths = sorted( + glob.glob(r"C:\Program Files" + vs_relative_path), reverse=True + ) + paths += sorted( + glob.glob(r"C:\Program Files (x86)" + vs_relative_path), + reverse=True, + ) if paths: return paths[0] @@ -40,88 +52,110 @@ def find_cl_path(): if os.system("where cl.exe >nul 2>nul") != 0: cl_path = find_cl_path() if cl_path is None: - raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") - os.environ['PATH'] += ';' + cl_path + raise RuntimeError( + "Could not locate a supported Microsoft Visual C++ installation" + ) + os.environ["PATH"] += ";" + cl_path # Compiler options. - opts = ['-DNVDR_TORCH'] + opts = ["-DNVDR_TORCH"] # Linker options for the GL-interfacing plugin. 
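Before moving on: the `set_context`/`release_context` pair shown at the end of `common/ops.py` above implies a bracketing idiom for manual-mode contexts, since nothing activates them automatically. A sketch of that calling convention, assuming the height/width signature of the `common` variant and a placeholder body for the actual rasterization work:

```python
from bayes3d.rendering.nvdiffrast.common import RasterizeGLContext

glctx = RasterizeGLContext(200, 200, output_db=False, mode="manual")

glctx.set_context()  # activate the GL context on this CPU thread
try:
    ...  # rasterization work that uses glctx goes here
finally:
    glctx.release_context()  # deactivate even if the work above raises
```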
ldflags = [] if gl: - if os.name == 'posix': - ldflags = ['-lGL', '-lEGL'] - elif os.name == 'nt': - libs = ['gdi32', 'opengl32', 'user32', 'setgpu'] - ldflags = ['/LIBPATH:' + lib_dir] + ['/DEFAULTLIB:' + x for x in libs] + if os.name == "posix": + ldflags = ["-lGL", "-lEGL"] + elif os.name == "nt": + libs = ["gdi32", "opengl32", "user32", "setgpu"] + ldflags = ["/LIBPATH:" + lib_dir] + ["/DEFAULTLIB:" + x for x in libs] # List of source files. if gl: source_files = [ - '../common/common.cpp', - '../common/rasterize.cu', - '../common/interpolate.cu', - '../common/glutil.cpp', - '../common/rasterize_gl.cpp', - 'jax_bindings.cpp', - 'jax_rasterize_gl.cpp', - 'jax_interpolate.cpp', + "../common/common.cpp", + "../common/rasterize.cu", + "../common/interpolate.cu", + "../common/glutil.cpp", + "../common/rasterize_gl.cpp", + "jax_bindings.cpp", + "jax_rasterize_gl.cpp", + "jax_interpolate.cpp", ] else: source_files = [] # Some containers set this to contain old architectures that won't compile. We only need the one installed in the machine. - os.environ['TORCH_CUDA_ARCH_LIST'] = '' + os.environ["TORCH_CUDA_ARCH_LIST"] = "" # On Linux, show a warning if GLEW is being forcibly loaded when compiling the GL plugin. - if gl and (os.name == 'posix') and ('libGLEW' in os.environ.get('LD_PRELOAD', '')): - logging.getLogger('nvdiffrast').warning("Warning: libGLEW is being loaded via LD_PRELOAD, and will probably conflict with the OpenGL plugin") + if gl and (os.name == "posix") and ("libGLEW" in os.environ.get("LD_PRELOAD", "")): + logging.getLogger("nvdiffrast").warning( + "Warning: libGLEW is being loaded via LD_PRELOAD, and will probably conflict with the OpenGL plugin" + ) # Try to detect if a stray lock file is left in cache directory and show a warning. This sometimes happens on Windows if the build is interrupted at just the right moment. - plugin_name = 'nvdiffrast_plugin_differentiable' + ('_gl' if gl else '') + plugin_name = "nvdiffrast_plugin_differentiable" + ("_gl" if gl else "") try: - lock_fn = os.path.join(torch.utils.cpp_extension._get_build_directory(plugin_name, False), 'lock') + lock_fn = os.path.join( + torch.utils.cpp_extension._get_build_directory(plugin_name, False), "lock" + ) if os.path.exists(lock_fn): - logging.getLogger('nvdiffrast').warning("Lock file exists in build directory: '%s'" % lock_fn) - except: + logging.getLogger("nvdiffrast").warning( + "Lock file exists in build directory: '%s'" % lock_fn + ) + except Exception: pass # Speed up compilation on Windows. - if os.name == 'nt': + if os.name == "nt": # Skip telemetry sending step in vcvarsall.bat - os.environ['VSCMD_SKIP_SENDTELEMETRY'] = '1' + os.environ["VSCMD_SKIP_SENDTELEMETRY"] = "1" # Opportunistically patch distutils to cache MSVC environments. try: import distutils._msvccompiler import functools - if not hasattr(distutils._msvccompiler._get_vc_env, '__wrapped__'): - distutils._msvccompiler._get_vc_env = functools.lru_cache()(distutils._msvccompiler._get_vc_env) - except: + + if not hasattr(distutils._msvccompiler._get_vc_env, "__wrapped__"): + distutils._msvccompiler._get_vc_env = functools.lru_cache()( + distutils._msvccompiler._get_vc_env + ) + except Exception: pass # Compile and load. 
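One recurring semantic fix in this pass, visible in the two best-effort `try` blocks above (the lock-file probe and the distutils patch): ruff replaces bare `except:` with `except Exception:`. A bare `except` also swallows `KeyboardInterrupt` and `SystemExit`, making an interrupted build hard to kill cleanly; `except Exception` keeps the best-effort behavior without catching those. The lock-file probe, reduced to a standalone sketch:

```python
import logging
import os


def warn_if_locked(lock_path: str) -> None:
    # Best-effort probe: failures must not abort the build, but Ctrl-C
    # (KeyboardInterrupt) and SystemExit should still propagate.
    try:
        if os.path.exists(lock_path):
            logging.getLogger("nvdiffrast").warning(
                "Lock file exists in build directory: '%s'", lock_path
            )
    except Exception:  # a bare `except:` would swallow KeyboardInterrupt too
        pass
```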
source_paths = [os.path.join(os.path.dirname(__file__), fn) for fn in source_files] - torch.utils.cpp_extension.load(name=plugin_name, sources=source_paths, extra_cflags=opts, extra_cuda_cflags=opts+['-lineinfo'], extra_ldflags=ldflags, with_cuda=True, verbose=False) + torch.utils.cpp_extension.load( + name=plugin_name, + sources=source_paths, + extra_cflags=opts, + extra_cuda_cflags=opts + ["-lineinfo"], + extra_ldflags=ldflags, + with_cuda=True, + verbose=False, + ) # Import, cache, and return the compiled module. _cached_plugin[gl] = importlib.import_module(plugin_name) return _cached_plugin[gl] -#---------------------------------------------------------------------------- + +# ---------------------------------------------------------------------------- # Log level. -#---------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- + def get_log_level(): - '''Get current log level. + """Get current log level. Returns: Current log level in nvdiffrast. See `set_log_level()` for possible values. - ''' + """ return _get_plugin().get_log_level() + def set_log_level(level): - '''Set log level. + """Set log level. Log levels follow the convention on the C++ side of Torch: 0 = Info, @@ -131,21 +165,21 @@ def set_log_level(level): The default log level is 1. Args: - level: New log level as integer. Internal nvdiffrast messages of this + level: New log level as integer. Internal nvdiffrast messages of this severity or higher will be printed, while messages of lower severity will be silent. - ''' + """ _get_plugin().set_log_level(level) - -#---------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- # GL state wrapper. -#---------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- + class RasterizeGLContext: - def __init__(self, output_db=True, mode='automatic', device=None): - '''Create a new OpenGL rasterizer context. + def __init__(self, output_db=True, mode="automatic", device=None): + """Create a new OpenGL rasterizer context. Creating an OpenGL context is a slow operation so you should usually reuse the same context in all calls to `rasterize()` on the same CPU thread. The OpenGL context @@ -169,9 +203,9 @@ def __init__(self, output_db=True, mode='automatic', device=None): device. Returns: The newly created OpenGL rasterizer context. - ''' + """ assert output_db is True or output_db is False - assert mode in ['automatic', 'manual'] + assert mode in ["automatic", "manual"] self.output_db = output_db self.mode = mode if device is None: @@ -179,22 +213,21 @@ def __init__(self, output_db=True, mode='automatic', device=None): else: with torch.cuda.device(device): cuda_device_idx = torch.cuda.current_device() - self.cpp_wrapper = _get_plugin(gl=True).RasterizeGLStateWrapper(output_db, mode == 'automatic', cuda_device_idx) - self.active_depth_peeler = None # For error checking only. + self.cpp_wrapper = _get_plugin(gl=True).RasterizeGLStateWrapper( + output_db, mode == "automatic", cuda_device_idx + ) + self.active_depth_peeler = None # For error checking only. def set_context(self): - '''Set (activate) OpenGL context in the current CPU thread. - Only available if context was created in manual mode. - ''' - assert self.mode == 'manual' + """Set (activate) OpenGL context in the current CPU thread. 
+ Only available if context was created in manual mode. + """ + assert self.mode == "manual" self.cpp_wrapper.set_context() def release_context(self): - '''Release (deactivate) currently active OpenGL context. - Only available if context was created in manual mode. - ''' - assert self.mode == 'manual' + """Release (deactivate) currently active OpenGL context. + Only available if context was created in manual mode. + """ + assert self.mode == "manual" self.cpp_wrapper.release_context() - - - \ No newline at end of file diff --git a/bayes3d/rendering/nvdiffrast_jax/renderer_matching_pytorch.py b/bayes3d/rendering/nvdiffrast_jax/renderer_matching_pytorch.py index bf3c03c2..ce9fd5ca 100644 --- a/bayes3d/rendering/nvdiffrast_jax/renderer_matching_pytorch.py +++ b/bayes3d/rendering/nvdiffrast_jax/renderer_matching_pytorch.py @@ -1,42 +1,49 @@ +import argparse +import os +import time from collections import namedtuple + import jax import jax.numpy as jnp import numpy as np -import os, argparse -import time import torch -import bayes3d as b from jax_renderer import Renderer as JaxRenderer -# import nvdiffrast.torch as dr -import bayes3d.rendering.nvdiffrast_full.nvdiffrast.torch as dr # modified nvdiffrast to expose backward fn call api +import bayes3d as b + +# import nvdiffrast.torch as dr +import bayes3d.rendering.nvdiffrast_full.nvdiffrast.torch as dr # modified nvdiffrast to expose backward fn call api # -------------- # Which renderer to test # -------------- -parser=argparse.ArgumentParser() -parser.add_argument('TEST_NAME', type=str, help="jax or torch", default="jax") +parser = argparse.ArgumentParser() +parser.add_argument("TEST_NAME", type=str, help="jax or torch", default="jax") args = parser.parse_args() -JAX_RENDERER = args.TEST_NAME == 'jax' +JAX_RENDERER = args.TEST_NAME == "jax" print(f"Testing JAX: {JAX_RENDERER}") -#-------------------- -# Setup renderers -#-------------------- +# -------------------- +# Setup renderers +# -------------------- intrinsics = b.Intrinsics( - height=200, - width=200, - fx=200.0, fy=200.0, - cx=100.0, cy=100.0, - near=0.01, far=5.5 + height=200, width=200, fx=200.0, fy=200.0, cx=100.0, cy=100.0, near=0.01, far=5.5 ) -proj_cam = torch.from_numpy(np.array(b.camera._open_gl_projection_matrix( - intrinsics.height, intrinsics.width, - intrinsics.fx, intrinsics.fy, - intrinsics.cx, intrinsics.cy, - intrinsics.near, intrinsics.far -))).cuda() +proj_cam = torch.from_numpy( + np.array( + b.camera._open_gl_projection_matrix( + intrinsics.height, + intrinsics.width, + intrinsics.fx, + intrinsics.fy, + intrinsics.cx, + intrinsics.cy, + intrinsics.near, + intrinsics.far, + ) + ) +).cuda() if JAX_RENDERER: # setup Jax renderer @@ -46,25 +53,25 @@ torch_glctx = dr.RasterizeGLContext() -#--------------------- +# --------------------- # Load object -#--------------------- -model_dir = os.path.join(b.utils.get_assets_dir(),"bop/ycbv/models") +# --------------------- +model_dir = os.path.join(b.utils.get_assets_dir(), "bop/ycbv/models") idx = 14 -mesh_path = os.path.join(model_dir,"obj_" + "{}".format(idx).rjust(6, '0') + ".ply") +mesh_path = os.path.join(model_dir, "obj_" + "{}".format(idx).rjust(6, "0") + ".ply") m = b.utils.load_mesh(mesh_path) -m = b.utils.scale_mesh(m, 1.0/100.0) +m = b.utils.scale_mesh(m, 1.0 / 100.0) vtx_pos = torch.from_numpy(m.vertices.astype(np.float32)).cuda() pos_idx = torch.from_numpy(m.faces.astype(np.int32)).cuda() -col_idx = torch.from_numpy(np.zeros((vtx_pos.shape[0],3)).astype(np.int32)).cuda() -vtx_col = 
torch.from_numpy(np.ones((1,3)).astype(np.float32)).cuda() +col_idx = torch.from_numpy(np.zeros((vtx_pos.shape[0], 3)).astype(np.int32)).cuda() +vtx_col = torch.from_numpy(np.ones((1, 3)).astype(np.float32)).cuda() print("Mesh has %d triangles and %d vertices." % (pos_idx.shape[0], vtx_pos.shape[0])) -#-------------------- +# -------------------- # transform points op -#-------------------- +# -------------------- def xfm_points(points, matrix): """Transform points. Args: @@ -80,39 +87,48 @@ def xfm_points(points, matrix): ) return out -#---------------------- + +# ---------------------- # Get clip-space poses (torch) -#---------------------- -rot_mtx_44 = torch.tensor([[-0.9513, 0.1699, 0.2573, 0.0000], - [-0.2436, 0.0976, -0.9650, 0.0000], - [-0.1890, -0.9806, -0.0514, 2.5000], - [ 0.0000, 0.0000, 0.0000, 1.0000]],).cuda() -pos = vtx_pos[None,...] -posw = torch.cat([pos, torch.ones([pos.shape[0], pos.shape[1], 1]).cuda()], axis=2) # (xyz) -> (xyz1) -transform_mtx = torch.matmul(proj_cam, rot_mtx_44) # transform = projection + pose rotation -pos_clip_ja = xfm_points(pos, transform_mtx[None,...]) # transform points - -resolution = [200,200] +# ---------------------- +rot_mtx_44 = torch.tensor( + [ + [-0.9513, 0.1699, 0.2573, 0.0000], + [-0.2436, 0.0976, -0.9650, 0.0000], + [-0.1890, -0.9806, -0.0514, 2.5000], + [0.0000, 0.0000, 0.0000, 1.0000], + ], +).cuda() +pos = vtx_pos[None, ...] +posw = torch.cat( + [pos, torch.ones([pos.shape[0], pos.shape[1], 1]).cuda()], axis=2 +) # (xyz) -> (xyz1) +transform_mtx = torch.matmul( + proj_cam, rot_mtx_44 +) # transform = projection + pose rotation +pos_clip_ja = xfm_points(pos, transform_mtx[None, ...]) # transform points + +resolution = [200, 200] rast_out_shp = (len(pos_clip_ja), resolution[0], resolution[1], 4) -#--------------------- +# --------------------- # Test setup -#--------------------- +# --------------------- # randomly create manual gradient inputs KEY1, KEY2 = jax.random.split(jax.random.PRNGKey(0), num=2) -dummy_dy = jax.random.uniform(KEY1, rast_out_shp) -dummy_ddb = jax.random.uniform(KEY2, rast_out_shp) -if not JAX_RENDERER: +dummy_dy = jax.random.uniform(KEY1, rast_out_shp) +dummy_ddb = jax.random.uniform(KEY2, rast_out_shp) +if not JAX_RENDERER: dummy_dy = torch.from_numpy(np.array(dummy_dy)).cuda() dummy_ddb = torch.from_numpy(np.array(dummy_ddb)).cuda() # context variable for torch autograd testing with manual gradient input -TorchCtx = namedtuple('TorchCtx', ['saved_tensors', 'saved_grad_db', 'saved_misc']) +TorchCtx = namedtuple("TorchCtx", ["saved_tensors", "saved_grad_db", "saved_misc"]) -#---------------------- +# ---------------------- # TEST 1 : Rasterize -#---------------------- +# ---------------------- if JAX_RENDERER: print("\n\n---------------------TESTING JAX RASTERIZE---------------------\n\n") @@ -121,30 +137,51 @@ def xfm_points(points, matrix): resolution = jnp.array(resolution) # jit with dummy input - (rast_out, rast_out_db), rasterize_vjp = jax.vjp(jax_renderer.rasterize, - jnp.zeros_like(pos_clip_ja_jax), - jnp.ones_like(pos_idx_jax), - resolution) + (rast_out, rast_out_db), rasterize_vjp = jax.vjp( + jax_renderer.rasterize, + jnp.zeros_like(pos_clip_ja_jax), + jnp.ones_like(pos_idx_jax), + resolution, + ) # evaluate and time. 
start_time = time.time() - (rast_out, rast_out_db), rasterize_vjp = jax.vjp(jax_renderer.rasterize, - pos_clip_ja_jax, - pos_idx_jax, - resolution) - pos_grads = rasterize_vjp((dummy_dy, - dummy_ddb))[0] + (rast_out, rast_out_db), rasterize_vjp = jax.vjp( + jax_renderer.rasterize, pos_clip_ja_jax, pos_idx_jax, resolution + ) + pos_grads = rasterize_vjp((dummy_dy, dummy_ddb))[0] end_time = time.time() # print results - print("JAX FWD (sum, min, max):", rast_out.sum().item(), rast_out.min().item(), rast_out.max().item()) - print("JAX FWD (sum for channels):", rast_out[...,0].sum().item(), rast_out[...,1].sum().item(), rast_out[...,2].sum().item(),rast_out[...,3].sum().item(),) - print("JAX FWD grads (sum, min, max):", rast_out_db.sum().item(), rast_out_db.min().item(), rast_out_db.max().item()) - print("JAX BWD (sum, min, max):", pos_grads.sum().item(), pos_grads.min().item(), pos_grads.max().item()) + print( + "JAX FWD (sum, min, max):", + rast_out.sum().item(), + rast_out.min().item(), + rast_out.max().item(), + ) + print( + "JAX FWD (sum for channels):", + rast_out[..., 0].sum().item(), + rast_out[..., 1].sum().item(), + rast_out[..., 2].sum().item(), + rast_out[..., 3].sum().item(), + ) + print( + "JAX FWD grads (sum, min, max):", + rast_out_db.sum().item(), + rast_out_db.min().item(), + rast_out_db.max().item(), + ) + print( + "JAX BWD (sum, min, max):", + pos_grads.sum().item(), + pos_grads.min().item(), + pos_grads.max().item(), + ) print(f"JAX rasterization (eval + grad): {(end_time - start_time)*1000} ms") # save viz - b.viz.get_depth_image(rast_out[0][:,:,2]).save("img_jax.png") + b.viz.get_depth_image(rast_out[0][:, :, 2]).save("img_jax.png") else: print("\n\n---------------------TESTING TORCH RASTERIZE---------------------\n\n") @@ -153,70 +190,136 @@ def xfm_points(points, matrix): # evaluate and time. 
start_time = time.time() - rast_out, rast_out_db = dr.rasterize(torch_glctx, pos_clip_ja, pos_idx, resolution=resolution) - ctx = TorchCtx(saved_tensors=(pos_clip_ja, pos_idx, rast_out), saved_grad_db=True, saved_misc=None) - pos_grads = dr._rasterize_func.backward(ctx, dummy_dy, dummy_ddb)[1] # 7 outputs; all are None except pos input + rast_out, rast_out_db = dr.rasterize( + torch_glctx, pos_clip_ja, pos_idx, resolution=resolution + ) + ctx = TorchCtx( + saved_tensors=(pos_clip_ja, pos_idx, rast_out), + saved_grad_db=True, + saved_misc=None, + ) + pos_grads = dr._rasterize_func.backward(ctx, dummy_dy, dummy_ddb)[ + 1 + ] # 7 outputs; all are None except pos input end_time = time.time() # print results - print("TORCH FWD (sum, min, max):", rast_out.sum().item(), rast_out.min().item(), rast_out.max().item()) - print("TORCH FWD (sum for channels):", rast_out[...,0].sum().item(), rast_out[...,1].sum().item(), rast_out[...,2].sum().item(),rast_out[...,3].sum().item(),) - print("TORCH FWD grads (sum, min, max):", rast_out_db.sum().item(), rast_out_db.min().item(), rast_out_db.max().item()) - print("TORCH BWD (sum, min, max):", pos_grads.sum().item(), pos_grads.min().item(), pos_grads.max().item()) + print( + "TORCH FWD (sum, min, max):", + rast_out.sum().item(), + rast_out.min().item(), + rast_out.max().item(), + ) + print( + "TORCH FWD (sum for channels):", + rast_out[..., 0].sum().item(), + rast_out[..., 1].sum().item(), + rast_out[..., 2].sum().item(), + rast_out[..., 3].sum().item(), + ) + print( + "TORCH FWD grads (sum, min, max):", + rast_out_db.sum().item(), + rast_out_db.min().item(), + rast_out_db.max().item(), + ) + print( + "TORCH BWD (sum, min, max):", + pos_grads.sum().item(), + pos_grads.min().item(), + pos_grads.max().item(), + ) print(f"Torch rasterization (eval + grad): {(end_time - start_time)*1000} ms") # save viz - b.viz.get_depth_image(jnp.array(rast_out[0][:,:,2].cpu())).save("img_torch.png") + b.viz.get_depth_image(jnp.array(rast_out[0][:, :, 2].cpu())).save("img_torch.png") -#---------------------- +# ---------------------- # TEST: Interpolate -#---------------------- +# ---------------------- if JAX_RENDERER: print("\n\n---------------------TESTING JAX INTERPOLATE---------------------\n\n") posw_jax = jnp.array(posw.cpu()) - all_attributes = jnp.array([0,1,2,3]) + all_attributes = jnp.array([0, 1, 2, 3]) # jit with dummy input - jax_renderer.interpolate(jnp.zeros_like(posw_jax), jnp.zeros_like(rast_out), jnp.ones_like(pos_idx_jax), jnp.ones_like(rast_out_db), all_attributes) + jax_renderer.interpolate( + jnp.zeros_like(posw_jax), + jnp.zeros_like(rast_out), + jnp.ones_like(pos_idx_jax), + jnp.ones_like(rast_out_db), + all_attributes, + ) start_time = time.time() - (gb_pos, dummy), interpolate_vjp = jax.vjp(jax_renderer.interpolate, - posw_jax, - rast_out, - pos_idx_jax, - rast_out_db, - all_attributes) - g_attr, g_rast, _, _, _ = interpolate_vjp((dummy_dy, - dummy)) + (gb_pos, dummy), interpolate_vjp = jax.vjp( + jax_renderer.interpolate, + posw_jax, + rast_out, + pos_idx_jax, + rast_out_db, + all_attributes, + ) + g_attr, g_rast, _, _, _ = interpolate_vjp((dummy_dy, dummy)) end_time = time.time() # print results - print(f"JAX FWD (sum, min, max): {gb_pos.sum().item(), gb_pos.min().item(), gb_pos.max().item()}") - print("JAX FWD (sum for channels):", gb_pos[...,0].sum().item(), gb_pos[...,1].sum().item(), gb_pos[...,2].sum().item(),gb_pos[...,3].sum().item(),) - print(f"JAX BWD (sum, min, max): g_attr={g_attr.sum().item(), g_attr.min().item(), 
g_attr.max().item()}\ng_rast={g_rast.sum().item(), g_rast.min().item(), g_rast.max().item()}") + print( + f"JAX FWD (sum, min, max): {gb_pos.sum().item(), gb_pos.min().item(), gb_pos.max().item()}" + ) + print( + "JAX FWD (sum for channels):", + gb_pos[..., 0].sum().item(), + gb_pos[..., 1].sum().item(), + gb_pos[..., 2].sum().item(), + gb_pos[..., 3].sum().item(), + ) + print( + f"JAX BWD (sum, min, max): g_attr={g_attr.sum().item(), g_attr.min().item(), g_attr.max().item()}\ng_rast={g_rast.sum().item(), g_rast.min().item(), g_rast.max().item()}" + ) print(f"JAX interpolation: {(end_time - start_time)*1000} ms") # save viz - b.viz.get_depth_image(gb_pos[0][:,:,2]).save("interpolate_jax.png") + b.viz.get_depth_image(gb_pos[0][:, :, 2]).save("interpolate_jax.png") print("---------------------------------------------------------------\n\n") else: print("\n\n---------------------TESTING TORCH INTERPOLATE---------------------\n\n") start_time = time.time() - gb_pos, dummy = dr.interpolate(posw, rast_out, pos_idx, rast_db = rast_out_db, diff_attrs = "all") - ctx = TorchCtx(saved_tensors=(posw, rast_out, pos_idx, rast_out_db), saved_grad_db=None, saved_misc=(1, [])) - grads = dr._interpolate_func_da.backward(ctx, dummy_dy, dummy) # 6 outputs; all are None except pos input + gb_pos, dummy = dr.interpolate( + posw, rast_out, pos_idx, rast_db=rast_out_db, diff_attrs="all" + ) + ctx = TorchCtx( + saved_tensors=(posw, rast_out, pos_idx, rast_out_db), + saved_grad_db=None, + saved_misc=(1, []), + ) + grads = dr._interpolate_func_da.backward( + ctx, dummy_dy, dummy + ) # 6 outputs; all are None except pos input g_attr, g_rast = grads[0], grads[1] end_time = time.time() # print results - print(f"TORCH FWD (sum, min, max): {gb_pos.sum().item(), gb_pos.min().item(), gb_pos.max().item()}") - print("TORCH FWD (sum for channels):", gb_pos[...,0].sum().item(), gb_pos[...,1].sum().item(), gb_pos[...,2].sum().item(),gb_pos[...,3].sum().item(),) - print(f"TORCH BWD (sum, min, max): g_attr={g_attr.sum().item(), g_attr.min().item(), g_attr.max().item()}\ng_rast={g_rast.sum().item(), g_rast.min().item(), g_rast.max().item()}") + print( + f"TORCH FWD (sum, min, max): {gb_pos.sum().item(), gb_pos.min().item(), gb_pos.max().item()}" + ) + print( + "TORCH FWD (sum for channels):", + gb_pos[..., 0].sum().item(), + gb_pos[..., 1].sum().item(), + gb_pos[..., 2].sum().item(), + gb_pos[..., 3].sum().item(), + ) + print( + f"TORCH BWD (sum, min, max): g_attr={g_attr.sum().item(), g_attr.min().item(), g_attr.max().item()}\ng_rast={g_rast.sum().item(), g_rast.min().item(), g_rast.max().item()}" + ) print(f"Torch interpolation: {(end_time - start_time)*1000} ms") # save viz - b.viz.get_depth_image(jnp.array(gb_pos[0][:,:,2].cpu())).save("interpolate_torch.png") + b.viz.get_depth_image(jnp.array(gb_pos[0][:, :, 2].cpu())).save( + "interpolate_torch.png" + ) print("---------------------------------------------------------------\n\n") diff --git a/bayes3d/rendering/nvdiffrast_jax/setup.py b/bayes3d/rendering/nvdiffrast_jax/setup.py index 2a19555e..a24aa2f8 100755 --- a/bayes3d/rendering/nvdiffrast_jax/setup.py +++ b/bayes3d/rendering/nvdiffrast_jax/setup.py @@ -6,9 +6,10 @@ # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. 
+import os + import nvdiffrast import setuptools -import os with open("README.md", "r") as fh: long_description = fh.read() @@ -24,31 +25,34 @@ url="https://github.com/NVlabs/nvdiffrast", packages=setuptools.find_packages(), package_data={ - 'nvdiffrast': [ - 'common/*.h', - 'common/*.inl', - 'common/*.cu', - 'common/*.cpp', - 'common/cudaraster/*.hpp', - 'common/cudaraster/impl/*.cpp', - 'common/cudaraster/impl/*.hpp', - 'common/cudaraster/impl/*.inl', - 'common/cudaraster/impl/*.cu', - 'lib/*.h', - 'torch/*.h', - 'torch/*.inl', - 'torch/*.cpp', - 'tensorflow/*.cu', - 'jax/*.h', - 'jax/*.inl', - 'jax/*.cpp', - ] + (['lib/*.lib'] if os.name == 'nt' else []) + "nvdiffrast": [ + "common/*.h", + "common/*.inl", + "common/*.cu", + "common/*.cpp", + "common/cudaraster/*.hpp", + "common/cudaraster/impl/*.cpp", + "common/cudaraster/impl/*.hpp", + "common/cudaraster/impl/*.inl", + "common/cudaraster/impl/*.cu", + "lib/*.h", + "torch/*.h", + "torch/*.inl", + "torch/*.cpp", + "tensorflow/*.cu", + "jax/*.h", + "jax/*.inl", + "jax/*.cpp", + ] + + (["lib/*.lib"] if os.name == "nt" else []) }, include_package_data=True, - install_requires=['numpy'], # note: can't require torch here as it will install torch even for a TensorFlow container + install_requires=[ + "numpy" + ], # note: can't require torch here as it will install torch even for a TensorFlow container classifiers=[ "Programming Language :: Python :: 3", "Operating System :: OS Independent", ], - python_requires='>=3.6', + python_requires=">=3.6", ) diff --git a/bayes3d/rendering/photorealistic_renderers/_kubric_exec_parallel.py b/bayes3d/rendering/photorealistic_renderers/_kubric_exec_parallel.py index dd160cd5..1f6cc3b3 100644 --- a/bayes3d/rendering/photorealistic_renderers/_kubric_exec_parallel.py +++ b/bayes3d/rendering/photorealistic_renderers/_kubric_exec_parallel.py @@ -13,13 +13,14 @@ # limitations under the License. 
import logging + import kubric as kb import numpy as np from kubric.renderer.blender import Blender as KubricRenderer -from kubric.core.color import get_color -# import kubric.core.color as color -#unpacking the data from the npz file +# import kubric.core.color as color + +# unpacking the data from the npz file data_file = "/tmp/blenderproc_kubric.npz" data = np.load(data_file, allow_pickle=True) mesh_paths = data["mesh_paths"] @@ -42,7 +43,7 @@ logging.basicConfig(level="INFO") -#convert intrinsics to focal_length, sensor_width +# convert intrinsics to focal_length, sensor_width focal_length = float(fx) sensor_width = float(width) @@ -51,19 +52,24 @@ scene = kb.Scene(resolution=(width.item(), height.item())) scene.background = kb.Color(*background_color) renderer = KubricRenderer(scene) - # --- create perspective camera - scene += kb.PerspectiveCamera(name="camera", - position =camera_pose[0],quaternion=camera_pose[1], focal_length=focal_length, sensor_width=sensor_width) - scene += kb.PointLight(name='sun', position=camera_pose[0], intensity=intensity) + # --- create perspective camera + scene += kb.PerspectiveCamera( + name="camera", + position=camera_pose[0], + quaternion=camera_pose[1], + focal_length=focal_length, + sensor_width=sensor_width, + ) + scene += kb.PointLight(name="sun", position=camera_pose[0], intensity=intensity) for obj_number in range(len(poses[scene_number])): mesh_scales = [e * scaling_factor for e in mesh_scales] rng = np.random.default_rng() obj_mat = kb.FlatMaterial(color=kb.Color(*mesh_colors[obj_number])) obj = kb.FileBasedObject( - asset_id=f"1", + asset_id="1", render_filename=mesh_paths[obj_number], - material=obj_mat, + material=obj_mat, simulation_filename=None, scale=mesh_scales[obj_number], position=poses[scene_number][obj_number][0], @@ -73,7 +79,9 @@ scene += obj frame = renderer.render_still() - np.savez(f"/tmp/{scene_number}.npz", rgba=frame["rgba"], segmentation=frame["segmentation"], depth=frame["depth"]) - - - + np.savez( + f"/tmp/{scene_number}.npz", + rgba=frame["rgba"], + segmentation=frame["segmentation"], + depth=frame["depth"], + ) diff --git a/bayes3d/rendering/photorealistic_renderers/kubric_interface.py b/bayes3d/rendering/photorealistic_renderers/kubric_interface.py index 9966c795..274a8b19 100644 --- a/bayes3d/rendering/photorealistic_renderers/kubric_interface.py +++ b/bayes3d/rendering/photorealistic_renderers/kubric_interface.py @@ -1,10 +1,23 @@ -import bayes3d as j -import numpy as np -import jax.numpy as jnp -import subprocess import os +import subprocess -def render_many(mesh_paths, poses, intrinsics, mesh_scales = None, mesh_colors = None, scaling_factor=1.0, lighting=20.0, camera_pose=None, background_color= [0.5,0.5,0.5]): +import jax.numpy as jnp +import numpy as np + +import bayes3d as j + + +def render_many( + mesh_paths, + poses, + intrinsics, + mesh_scales=None, + mesh_colors=None, + scaling_factor=1.0, + lighting=20.0, + camera_pose=None, + background_color=[0.5, 0.5, 0.5], +): """Render a scene with multiple objects in it through kubric. Args: @@ -12,15 +25,15 @@ def render_many(mesh_paths, poses, intrinsics, mesh_scales = None, mesh_colors = poses (jnp.ndarray): Array of poses of shape (num_frames, num_objects, 4, 4). intrinsics (b.camera.Intrinsics): Camera intrinsics. Returns: - list: List of rendered RGBD images. + list: List of rendered RGBD images. 
""" - # warn if intrinsics are incompatible with Blender camera parameters + # warn if intrinsics are incompatible with Blender camera parameters if intrinsics.fx != intrinsics.fy: print("fx is not equal to fy!") - if intrinsics.cy != intrinsics.height/2: + if intrinsics.cy != intrinsics.height / 2: print("cy is not equal to height/2!") - if intrinsics.cy != intrinsics.height/2: + if intrinsics.cy != intrinsics.height / 2: print("cy is not equal to height/2!") K = j.camera.K_from_intrinsics(intrinsics) @@ -28,49 +41,63 @@ def render_many(mesh_paths, poses, intrinsics, mesh_scales = None, mesh_colors = for scene_index in range(poses.shape[0]): poses_pos_quat = [] for object_index in range(poses.shape[1]): - poses_pos_quat.append(( - np.array(poses[scene_index, object_index,:3,3]), - np.array(j.t3d.rotation_matrix_to_quaternion(poses[scene_index,object_index,:3,:3])) - )) + poses_pos_quat.append( + ( + np.array(poses[scene_index, object_index, :3, 3]), + np.array( + j.t3d.rotation_matrix_to_quaternion( + poses[scene_index, object_index, :3, :3] + ) + ), + ) + ) poses_pos_quat_all.append(poses_pos_quat) if camera_pose is None: camera_pose = jnp.eye(4) # camera_pose = camera_pose @ j.t3d.transform_from_axis_angle(jnp.array([1.0, 0.0,0.0]), jnp.pi) - cam_pose_pos_quat = (np.array(camera_pose[:3,3]), np.array(j.t3d.rotation_matrix_to_quaternion(camera_pose[:3,:3]))) + cam_pose_pos_quat = ( + np.array(camera_pose[:3, 3]), + np.array(j.t3d.rotation_matrix_to_quaternion(camera_pose[:3, :3])), + ) - if mesh_scales == None: - mesh_scales = [[1.0,1.0,1.0] for i in range(len(mesh_paths))] + if mesh_scales is None: + mesh_scales = [[1.0, 1.0, 1.0] for i in range(len(mesh_paths))] else: assert len(mesh_scales) == len(mesh_paths) - np.savez("/tmp/blenderproc_kubric.npz", + np.savez( + "/tmp/blenderproc_kubric.npz", mesh_paths=mesh_paths, - mesh_scales = mesh_scales, - mesh_colors = mesh_colors, + mesh_scales=mesh_scales, + mesh_colors=mesh_colors, scaling_factor=scaling_factor, - poses=np.array(poses_pos_quat_all, dtype=object) , + poses=np.array(poses_pos_quat_all, dtype=object), camera_pose=np.array(cam_pose_pos_quat, dtype=object), K=K, height=intrinsics.height, width=intrinsics.width, - fx = intrinsics.fx, - fy = intrinsics.fy, - cx = intrinsics.cx, - cy = intrinsics.cy, - near = intrinsics.near, - far = intrinsics.far, + fx=intrinsics.fx, + fy=intrinsics.fy, + cx=intrinsics.cx, + cy=intrinsics.cy, + near=intrinsics.near, + far=intrinsics.far, intensity=lighting, - background = background_color + background=background_color, ) path = os.path.dirname(os.path.dirname(__file__)) - print('path:');print(path) + print("path:") + print(path) command_string = f"""sudo docker run --rm --interactive --user $(id -u):$(id -g) --volume {path}:{path} --volume /tmp:/tmp """ - command_strings = "".join([ - f""" --volume {os.path.dirname(p)}:{os.path.dirname(p)} """ for p in mesh_paths - ]) + command_strings = "".join( + [ + f""" --volume {os.path.dirname(p)}:{os.path.dirname(p)} """ + for p in mesh_paths + ] + ) command_string2 = f""" kubricdockerhub/kubruntu /usr/bin/python3 {path}/photorealistic_renderers/_kubric_exec_parallel.py""" print(command_string + command_strings + command_string2) subprocess.run([command_string + command_strings + command_string2], shell=True) @@ -78,9 +105,13 @@ def render_many(mesh_paths, poses, intrinsics, mesh_scales = None, mesh_colors = rgbd_images = [] for i in range(poses.shape[0]): data = np.load(f"/tmp/{i}.npz") - rgb, seg, depth = data["rgba"], data["segmentation"][...,0], 
data["depth"][:,:,0] + rgb, seg, depth = ( + data["rgba"], + data["segmentation"][..., 0], + data["depth"][:, :, 0], + ) depth[depth > intrinsics.far] = intrinsics.far depth[depth < intrinsics.near] = intrinsics.near - rgbd_images.append(j.RGBD(rgb,depth, camera_pose, intrinsics, seg)) + rgbd_images.append(j.RGBD(rgb, depth, camera_pose, intrinsics, seg)) - return rgbd_images \ No newline at end of file + return rgbd_images diff --git a/bayes3d/rgbd.py b/bayes3d/rgbd.py index 490fddcc..c8eac2e8 100644 --- a/bayes3d/rgbd.py +++ b/bayes3d/rgbd.py @@ -1,15 +1,16 @@ -import bayes3d.camera -import bayes3d as j -import bayes3d as b -import bayes3d.transforms_3d as t3d -import numpy as npe import jax.numpy as jnp import numpy as np +import bayes3d as b +import bayes3d as j +import bayes3d.camera +import bayes3d.transforms_3d as t3d + + class RGBD(object): def __init__(self, rgb, depth, camera_pose, intrinsics, segmentation=None): """RGBD Image - + Args: rgb (np.array): RGB image depth (np.array): Depth image @@ -21,11 +22,11 @@ def __init__(self, rgb, depth, camera_pose, intrinsics, segmentation=None): self.depth = depth self.camera_pose = camera_pose self.intrinsics = intrinsics - self.segmentation = segmentation + self.segmentation = segmentation def construct_from_camera_image(camera_image, near=0.001, far=5.0): """Construct RGBD image from CameraImage - + Args: camera_image (CameraImage): CameraImage object Returns: @@ -35,26 +36,30 @@ def construct_from_camera_image(camera_image, near=0.001, far=5.0): rgb = np.array(camera_image.rgbPixels) camera_pose = t3d.pybullet_pose_to_transform(camera_image.camera_pose) K = camera_image.camera_matrix - fx, fy, cx, cy = K[0,0],K[1,1],K[0,2],K[1,2] - h,w = depth.shape + fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2] + h, w = depth.shape near = 0.001 - return RGBD(rgb, depth, camera_pose, j.Intrinsics(h,w,fx,fy,cx,cy,near,far)) + return RGBD( + rgb, depth, camera_pose, j.Intrinsics(h, w, fx, fy, cx, cy, near, far) + ) def construct_from_aidan_dict(d, near=0.001, far=5.0): """Construct RGBD image from Aidan's dictionary - + Args: d (dict): Dictionary containing rgb, depth, extrinsics, intrinsics Returns: RGBD: RGBD image """ - depth = np.array(d["depth"] / 1000.0) + depth = np.array(d["depth"] / 1000.0) camera_pose = t3d.pybullet_pose_to_transform(d["extrinsics"]) rgb = np.array(d["rgb"]) K = d["intrinsics"][0] - fx, fy, cx, cy = K[0,0],K[1,1],K[0,2],K[1,2] - h,w = depth.shape - observation = RGBD(rgb, depth, camera_pose, j.Intrinsics(h,w,fx,fy,cx,cy,near,far)) + fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2] + h, w = depth.shape + observation = RGBD( + rgb, depth, camera_pose, j.Intrinsics(h, w, fx, fy, cx, cy, near, far) + ) return observation def construct_from_step_metadata(step_metadata, intrinsics=None): @@ -75,13 +80,15 @@ def construct_from_step_metadata(step_metadata, intrinsics=None): fy = cy / np.tan(fov_y / 2.0) clipping_near, clipping_far = step_metadata.camera_clipping_planes intrinsics = j.Intrinsics( - height,width, fx,fy,cx,cy,clipping_near,clipping_far + height, width, fx, fy, cx, cy, clipping_near, clipping_far ) rgb = np.array(list(step_metadata.image_list)[-1]) depth = np.array(list(step_metadata.depth_map_list)[-1]) seg = np.array(list(step_metadata.object_mask_list)[-1]) - colors, seg_final_flat = np.unique(seg.reshape(-1,3), axis=0, return_inverse=True) + colors, seg_final_flat = np.unique( + seg.reshape(-1, 3), axis=0, return_inverse=True + ) seg_final = seg_final_flat.reshape(seg.shape[:2]) observation = RGBD(rgb, depth, 
jnp.eye(4), intrinsics, seg_final) return observation @@ -90,4 +97,4 @@ def scale_rgbd(self, scaling_factor): intrinsics = b.camera.scale_camera_parameters(self.intrinsics, scaling_factor) rgb = b.utils.scale(self.rgb, intrinsics.height, intrinsics.width) depth = b.utils.scale(self.depth, intrinsics.height, intrinsics.width) - return RGBD(rgb, depth, self.camera_pose, intrinsics, self.segmentation) \ No newline at end of file + return RGBD(rgb, depth, self.camera_pose, intrinsics, self.segmentation) diff --git a/bayes3d/scene_graph.py b/bayes3d/scene_graph.py index ddb9554f..8e096457 100644 --- a/bayes3d/scene_graph.py +++ b/bayes3d/scene_graph.py @@ -1,11 +1,26 @@ -import jax.numpy as jnp +from collections import namedtuple + import jax +import jax.numpy as jnp + import bayes3d.transforms_3d as t3d -from collections import namedtuple - -class SceneGraph(namedtuple('SceneGraph', ['root_poses', 'box_dimensions', 'parents', 'contact_params', 'face_parent', 'face_child'])): + + +class SceneGraph( + namedtuple( + "SceneGraph", + [ + "root_poses", + "box_dimensions", + "parents", + "contact_params", + "face_parent", + "face_child", + ], + ) +): """Scene graph data structure. - + Args: root_poses: Array of root poses. Shape (N,4,4). box_dimensions: Array of bounding box dimensions. Shape (N,3). @@ -24,11 +39,11 @@ def get_poses(self): self.face_parent, self.face_child, ) - + def visualize(self, filename, node_names=None, colors=None): + import distinctipy import graphviz import matplotlib - import distinctipy scene_graph = self num_nodes = len(scene_graph.root_poses) @@ -42,35 +57,41 @@ def visualize(self, filename, node_names=None, colors=None): g_out.attr("node", style="filled") for i in range(num_nodes): - g_out.node(f"{i}", node_names[i], fillcolor=matplotlib.colors.to_hex(colors[i])) - + g_out.node( + f"{i}", node_names[i], fillcolor=matplotlib.colors.to_hex(colors[i]) + ) edges = [] edge_label = [] - for i,parent in enumerate(scene_graph.parents): + for i, parent in enumerate(scene_graph.parents): if parent == -1: continue edges.append((parent, i)) - contact_string = f"contact:\n" + " ".join([f"{x:.2f}" for x in scene_graph.contact_params[i]]) - contact_string += f"\nfaces\n{scene_graph.face_parent[i].item()} --- {scene_graph.face_child[i].item()}" + contact_string = "contact:\n" + " ".join( + [f"{x:.2f}" for x in scene_graph.contact_params[i]] + ) + contact_string += f"\nfaces\n{scene_graph.face_parent[i].item()} --- {scene_graph.face_child[i].item()}" edge_label.append(contact_string) - for ((i,j),label) in zip(edges, edge_label): - if i==-1: + for (i, j), label in zip(edges, edge_label): + if i == -1: continue - g_out.edge(f"{i}",f"{j}", label=label) + g_out.edge(f"{i}", f"{j}", label=label) max_width_px = 2000 max_height_px = 2000 dpi = 200 - g_out.attr("graph", - # See https://graphviz.gitlab.io/_pages/doc/info/attrs.html#a:size - size="{},{}!".format(max_width_px / dpi, max_height_px / dpi), - dpi=f"{dpi}") + g_out.attr( + "graph", + # See https://graphviz.gitlab.io/_pages/doc/info/attrs.html#a:size + size="{},{}!".format(max_width_px / dpi, max_height_px / dpi), + dpi=f"{dpi}", + ) filename_prefix, filetype = filename.split(".") g_out.render(filename_prefix, format=filetype) + def create_floating_scene_graph(scene_graph): """Create a new scene graph with the same structure, but with all objects floating. 
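For orientation before the remaining scene-graph hunks, here is a minimal usage sketch of the `SceneGraph` API reformatted above. This is not part of the patch; the dimensions, parent indices, and contact parameters are invented for illustration.

```python
import jax.numpy as jnp

from bayes3d.scene_graph import SceneGraph

# Two boxes: object 0 is a root (parent == -1), object 1 is attached to it.
# Face indices follow get_contact_planes below:
# 0=bottom, 1=top, 2=back, 3=front, 4=left, 5=right.
sg = SceneGraph(
    root_poses=jnp.tile(jnp.eye(4), (2, 1, 1)),  # (N, 4, 4) poses for root nodes
    box_dimensions=jnp.array([[1.0, 1.0, 1.0], [0.2, 0.2, 0.2]]),  # (N, 3)
    parents=jnp.array([-1, 0]),  # -1 marks a root node
    contact_params=jnp.array(
        [[0.0, 0.0, 0.0], [0.1, 0.0, jnp.pi / 4]]
    ),  # (x, y, angle) in the contact plane, per edge
    face_parent=jnp.array([0, 1]),  # contact face on the parent
    face_child=jnp.array([0, 0]),  # contact face on the child
)
poses = sg.get_poses()  # (2, 4, 4) world poses resolved through the graph
```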
@@ -87,7 +108,9 @@ def create_floating_scene_graph(scene_graph): ) -def add_edge_scene_graph(scene_graph, parent, child, face_parent, face_child, contact_params): +def add_edge_scene_graph( + scene_graph, parent, child, face_parent, face_child, contact_params +): print(parent, child, face_parent, face_child) N = scene_graph.get_poses().shape[0] sg_parents = jnp.array(scene_graph.parents) @@ -107,77 +130,98 @@ def add_edge_scene_graph(scene_graph, parent, child, face_parent, face_child, co face_child=sg_face_child, ) + def get_contact_planes(dimensions): return jnp.stack( [ # bottom - t3d.transform_from_pos(jnp.array([0.0, dimensions[1]/2.0, 0.0])).dot(t3d.transform_from_axis_angle(jnp.array([1.0, 0.0, 0.0]), -jnp.pi/2)), + t3d.transform_from_pos(jnp.array([0.0, dimensions[1] / 2.0, 0.0])).dot( + t3d.transform_from_axis_angle(jnp.array([1.0, 0.0, 0.0]), -jnp.pi / 2) + ), # top - t3d.transform_from_pos(jnp.array([0.0, -dimensions[1]/2.0, 0.0])).dot(t3d.transform_from_axis_angle(jnp.array([1.0, 0.0, 0.0]), jnp.pi/2)), + t3d.transform_from_pos(jnp.array([0.0, -dimensions[1] / 2.0, 0.0])).dot( + t3d.transform_from_axis_angle(jnp.array([1.0, 0.0, 0.0]), jnp.pi / 2) + ), # back - t3d.transform_from_pos(jnp.array([0.0, 0.0, dimensions[2]/2.0])).dot(t3d.transform_from_axis_angle(jnp.array([1.0, 0.0, 0.0]), 0.0)), + t3d.transform_from_pos(jnp.array([0.0, 0.0, dimensions[2] / 2.0])).dot( + t3d.transform_from_axis_angle(jnp.array([1.0, 0.0, 0.0]), 0.0) + ), # front - t3d.transform_from_pos(jnp.array([0.0, 0.0, -dimensions[2]/2.0])).dot(t3d.transform_from_axis_angle(jnp.array([1.0, 0.0, 0.0]), jnp.pi)), + t3d.transform_from_pos(jnp.array([0.0, 0.0, -dimensions[2] / 2.0])).dot( + t3d.transform_from_axis_angle(jnp.array([1.0, 0.0, 0.0]), jnp.pi) + ), # left - t3d.transform_from_pos(jnp.array([-dimensions[0]/2.0, 0.0, 0.0])).dot(t3d.transform_from_axis_angle(jnp.array([0.0, 1.0, 0.0]), -jnp.pi/2)), + t3d.transform_from_pos(jnp.array([-dimensions[0] / 2.0, 0.0, 0.0])).dot( + t3d.transform_from_axis_angle(jnp.array([0.0, 1.0, 0.0]), -jnp.pi / 2) + ), # right - t3d.transform_from_pos(jnp.array([dimensions[0]/2.0, 0.0, 0.0])).dot(t3d.transform_from_axis_angle(jnp.array([0.0, 1.0, 0.0]), jnp.pi/2)), + t3d.transform_from_pos(jnp.array([dimensions[0] / 2.0, 0.0, 0.0])).dot( + t3d.transform_from_axis_angle(jnp.array([0.0, 1.0, 0.0]), jnp.pi / 2) + ), ] ) + def bounding_box_corners(dimensions): """ Returns the corners of an axis aligned bounding box. 
Args: dimensions: (3,) dimensions of the bounding box Returns: - corners: (8,3) corners of the bounding box + corners: (8,3) corners of the bounding box """ - corners = jnp.array([ - [-dimensions[0]/2, -dimensions[1]/2, -dimensions[2]/2], - [dimensions[0]/2, -dimensions[1]/2, -dimensions[2]/2], - [-dimensions[0]/2, dimensions[1]/2, -dimensions[2]/2], - [dimensions[0]/2, dimensions[1]/2, -dimensions[2]/2], - [-dimensions[0]/2, -dimensions[1]/2, dimensions[2]/2], - [dimensions[0]/2, -dimensions[1]/2, dimensions[2]/2], - [-dimensions[0]/2, dimensions[1]/2, dimensions[2]/2], - [dimensions[0]/2, dimensions[1]/2, dimensions[2]/2] - ]) + corners = jnp.array( + [ + [-dimensions[0] / 2, -dimensions[1] / 2, -dimensions[2] / 2], + [dimensions[0] / 2, -dimensions[1] / 2, -dimensions[2] / 2], + [-dimensions[0] / 2, dimensions[1] / 2, -dimensions[2] / 2], + [dimensions[0] / 2, dimensions[1] / 2, -dimensions[2] / 2], + [-dimensions[0] / 2, -dimensions[1] / 2, dimensions[2] / 2], + [dimensions[0] / 2, -dimensions[1] / 2, dimensions[2] / 2], + [-dimensions[0] / 2, dimensions[1] / 2, dimensions[2] / 2], + [dimensions[0] / 2, dimensions[1] / 2, dimensions[2] / 2], + ] + ) return corners + def get_contact_plane_dimenions(dimensions): - return jnp.array([ - [dimensions[0],dimensions[2]], - [dimensions[0],dimensions[2]], - [dimensions[0],dimensions[1]], - [dimensions[0],dimensions[1]], - [dimensions[2],dimensions[1]], - [dimensions[2],dimensions[1]], - ]) + return jnp.array( + [ + [dimensions[0], dimensions[2]], + [dimensions[0], dimensions[2]], + [dimensions[0], dimensions[1]], + [dimensions[0], dimensions[1]], + [dimensions[2], dimensions[1]], + [dimensions[2], dimensions[1]], + ] + ) + def contact_params_to_pose(contact_params): - x,y,angle = contact_params - return t3d.transform_from_pos(jnp.array([x,y, 0.0])).dot( + x, y, angle = contact_params + return t3d.transform_from_pos(jnp.array([x, y, 0.0])).dot( t3d.transform_from_axis_angle(jnp.array([1.0, 1.0, 0.0]), jnp.pi).dot( t3d.transform_from_axis_angle(jnp.array([0.0, 0.0, 1.0]), angle) ) ) + def relative_pose_from_edge( contact_params, - face_child, dims_child, + face_child, + dims_child, ): - x,y,angle = contact_params - contact_transform = ( - t3d.transform_from_pos(jnp.array([x,y, 0.0])).dot( - t3d.transform_from_axis_angle(jnp.array([1.0, 1.0, 0.0]), jnp.pi).dot( - t3d.transform_from_axis_angle(jnp.array([0.0, 0.0, 1.0]), angle) - ) + x, y, angle = contact_params + contact_transform = t3d.transform_from_pos(jnp.array([x, y, 0.0])).dot( + t3d.transform_from_axis_angle(jnp.array([1.0, 1.0, 0.0]), jnp.pi).dot( + t3d.transform_from_axis_angle(jnp.array([0.0, 0.0, 1.0]), angle) ) ) child_plane = get_contact_planes(dims_child)[face_child] return contact_transform.dot(jnp.linalg.inv(child_plane)) + relative_pose_from_edge_jit = jax.jit(relative_pose_from_edge) relative_pose_from_edge_parallel_jit = jax.jit( jax.vmap( @@ -186,33 +230,53 @@ def relative_pose_from_edge( ) ) -def iter(root_poses, box_dimensions, parent, child, contact_params, face_parent, face_child): + +def iter( + root_poses, box_dimensions, parent, child, contact_params, face_parent, face_child +): parent_plane = get_contact_planes(box_dimensions[parent])[face_parent] relative = parent_plane.dot( relative_pose_from_edge(contact_params, face_child, box_dimensions[child]) ) - return ( - root_poses[parent].dot(relative) * (parent != -1) - + - root_poses[child] * (parent == -1) + return root_poses[parent].dot(relative) * (parent != -1) + root_poses[child] * ( + parent == -1 ) -def 
poses_from_scene_graph(root_poses, box_dimensions, parents, contact_params, face_parent, face_child): + +def poses_from_scene_graph( + root_poses, box_dimensions, parents, contact_params, face_parent, face_child +): def _f(poses, _): - new_poses = jax.vmap(iter, in_axes=(None, None, 0, 0, 0, 0, 0))(poses, box_dimensions, parents, jnp.arange(parents.shape[0]), contact_params, face_parent, face_child) + new_poses = jax.vmap(iter, in_axes=(None, None, 0, 0, 0, 0, 0))( + poses, + box_dimensions, + parents, + jnp.arange(parents.shape[0]), + contact_params, + face_parent, + face_child, + ) return (new_poses, new_poses) + return jax.lax.scan(_f, root_poses, jnp.ones(parents.shape[0]))[0] + + poses_from_scene_graph_jit = jax.jit(poses_from_scene_graph) + def closest_approximate_contact_params(parent_contact_plane, child_contact_plane): contact_pose = t3d.inverse_pose(parent_contact_plane) @ child_contact_plane - (x, y, _) = contact_pose[:3,3] + (x, y, _) = contact_pose[:3, 3] pose_ = ( - t3d.inverse_pose(t3d.transform_from_axis_angle(jnp.array([1.0, 1.0, 0.0]), jnp.pi)) @ - contact_pose + t3d.inverse_pose( + t3d.transform_from_axis_angle(jnp.array([1.0, 1.0, 0.0]), jnp.pi) + ) + @ contact_pose ) - quaternion = t3d.rotation_matrix_to_quaternion(pose_[:3,:3]) + quaternion = t3d.rotation_matrix_to_quaternion(pose_[:3, :3]) angle = 2 * jnp.arctan2(quaternion[3], quaternion[0]) - inferred_contact_params = jnp.array([x,y,angle]) - slack = t3d.inverse_pose(contact_params_to_pose(inferred_contact_params)) @ contact_pose - return inferred_contact_params, slack \ No newline at end of file + inferred_contact_params = jnp.array([x, y, angle]) + slack = ( + t3d.inverse_pose(contact_params_to_pose(inferred_contact_params)) @ contact_pose + ) + return inferred_contact_params, slack diff --git a/bayes3d/transforms_3d.py b/bayes3d/transforms_3d.py index dccac4a4..ad898584 100644 --- a/bayes3d/transforms_3d.py +++ b/bayes3d/transforms_3d.py @@ -1,16 +1,17 @@ +import cv2 import jax import jax.numpy as jnp import numpy as np -from typing import Tuple -import cv2 + def identity_pose(): """Creates an identity pose matrix.""" return jnp.eye(4) + def inverse_pose(pose): """Inverts a pose matrix. - + Args: pose (jnp.ndarray): The pose matrix. Shape (4, 4) Returns: @@ -18,6 +19,7 @@ def inverse_pose(pose): """ return jnp.linalg.inv(pose) + def transform_from_pos(translation): """Creates a pose matrix from a translation vector. @@ -27,29 +29,38 @@ def transform_from_pos(translation): jnp.ndarray: The pose matrix. Shape (4, 4) """ return jnp.vstack( - [jnp.hstack([jnp.eye(3), translation.reshape(3,1)]), jnp.array([0.0, 0.0, 0.0, 1.0])] + [ + jnp.hstack([jnp.eye(3), translation.reshape(3, 1)]), + jnp.array([0.0, 0.0, 0.0, 1.0]), + ] ) + def transform_from_rot(rotation): """Creates a pose matrix from a rotation matrix. - + Args: rotation (jnp.ndarray): The rotation matrix. Shape (3, 3) Returns: jnp.ndarray: The pose matrix. Shape (4, 4) """ return jnp.vstack( - [jnp.hstack([rotation, jnp.zeros((3,1))]), jnp.array([0.0, 0.0, 0.0, 1.0])] + [jnp.hstack([rotation, jnp.zeros((3, 1))]), jnp.array([0.0, 0.0, 0.0, 1.0])] ) + def transform_from_rot_and_pos(rotation, translation): return jnp.vstack( - [jnp.hstack([rotation, translation.reshape(3,1)]), jnp.array([0.0, 0.0, 0.0, 1.0])] + [ + jnp.hstack([rotation, translation.reshape(3, 1)]), + jnp.array([0.0, 0.0, 0.0, 1.0]), + ] ) + def rotation_from_axis_angle(axis, angle): """Creates a rotation matrix from an axis and angle. - + Args: axis (jnp.ndarray): The axis vector. 
Shape (3,) angle (float): The angle in radians. @@ -63,14 +74,19 @@ def rotation_from_axis_angle(axis, angle): R = jnp.diag(jnp.array([cosa, cosa, cosa])) R = R + jnp.outer(direction, direction) * (1.0 - cosa) direction = direction * sina - R = R + jnp.array([[0.0, -direction[2], direction[1]], - [direction[2], 0.0, -direction[0]], - [-direction[1], direction[0], 0.0]]) + R = R + jnp.array( + [ + [0.0, -direction[2], direction[1]], + [direction[2], 0.0, -direction[0]], + [-direction[1], direction[0], 0.0], + ] + ) return R + def transform_from_axis_angle(axis, angle): """Creates a pose matrix from an axis and angle. - + Args: axis (jnp.ndarray): The axis vector. Shape (3,) angle (float): The angle in radians. @@ -79,9 +95,10 @@ def transform_from_axis_angle(axis, angle): """ return transform_from_rot(rotation_from_axis_angle(axis, angle)) + def rotation_from_rodrigues(rodrigues_vector): """Creates a rotation matrix from a rodrigues vector. - + Args: rodrigues_vector (jnp.ndarray): The rodrigues vector. Shape (3,) Returns: @@ -89,40 +106,45 @@ def rotation_from_rodrigues(rodrigues_vector): """ r_flat = rodrigues_vector.reshape(-1) theta = jnp.linalg.norm(r_flat) - r = r_flat/theta - A = jnp.array([[0, -r[2], r[1]],[r[2], 0, -r[0]],[-r[1], r[0], 0]]) - R = jnp.cos(theta) * jnp.eye(3) + (1 - jnp.cos(theta)) * r.reshape(-1,1) * r.transpose() + jnp.sin(theta) * A + r = r_flat / theta + A = jnp.array([[0, -r[2], r[1]], [r[2], 0, -r[0]], [-r[1], r[0], 0]]) + R = ( + jnp.cos(theta) * jnp.eye(3) + + (1 - jnp.cos(theta)) * r.reshape(-1, 1) * r.transpose() + + jnp.sin(theta) * A + ) return jnp.where(theta < 0.0001, jnp.eye(3), R) + def transform_to_posevec(transform): - rvec = jnp.array(cv2.Rodrigues(np.array(transform[:3,:3]))[0]).reshape(-1) - tvec = transform[:3,3].reshape(-1) + rvec = jnp.array(cv2.Rodrigues(np.array(transform[:3, :3]))[0]).reshape(-1) + tvec = transform[:3, 3].reshape(-1) posevec = jnp.concatenate([tvec, rvec]) return posevec + def transform_from_posevec(posevec): return transform_from_rot_and_pos(rotation_from_rodrigues(posevec[3:]), posevec[:3]) + def transform_from_rvec_tvec(rvec, tvec): - return transform_from_rot_and_pos( - rotation_from_rodrigues(rvec), - tvec.reshape(-1) - ) + return transform_from_rot_and_pos(rotation_from_rodrigues(rvec), tvec.reshape(-1)) def add_homogenous_ones(cloud): """Adds a column of ones to a point cloud. - + Args: cloud (jnp.ndarray): The point cloud. Shape (N, 3) Returns: jnp.ndarray: The point cloud with a column of ones. Shape (N, 4) """ - return jnp.concatenate([cloud, jnp.ones((*cloud.shape[:-1],1))],axis=-1) + return jnp.concatenate([cloud, jnp.ones((*cloud.shape[:-1], 1))], axis=-1) + def apply_transform(coords, transform): """Applies a transform to a point cloud. - + Args: coords (jnp.ndarray): The point cloud. Shape (N, 3) transform (jnp.ndarray): The transform matrix. Shape (4, 4) @@ -130,16 +152,19 @@ def apply_transform(coords, transform): jnp.ndarray: The transformed point cloud. Shape (N, 3) """ coords = jnp.einsum( - 'ij,...j->...i', + "ij,...j->...i", transform, jnp.concatenate([coords, jnp.ones(coords.shape[:-1] + (1,))], axis=-1), )[..., :-1] return coords + + apply_transform_jit = jax.jit(apply_transform) + def unproject_depth(depth, intrinsics): """Unprojects a depth image into a point cloud. - + Args: depth (jnp.ndarray): The depth image. Shape (H, W) intrinsics (b.camera.Intrinsics): The camera intrinsics. 
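As a quick illustration of the pose helpers above (a sketch, not part of the patch):

```python
import jax.numpy as jnp

import bayes3d.transforms_3d as t3d

# Rotate 90 degrees about z, then lift by one unit along z.
rotate = t3d.transform_from_axis_angle(jnp.array([0.0, 0.0, 1.0]), jnp.pi / 2)
lift = t3d.transform_from_pos(jnp.array([0.0, 0.0, 1.0]))
pose = lift @ rotate

cloud = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
moved = t3d.apply_transform(cloud, pose)
# moved is approximately [[0, 1, 1], [-1, 0, 1]]
```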
@@ -153,12 +178,15 @@ def unproject_depth(depth, intrinsics): y = (y - intrinsics.cy) / intrinsics.fy point_cloud_image = jnp.stack([x, y, jnp.ones_like(x)], axis=-1) * depth[:, :, None] return point_cloud_image + + unproject_depth_jit = jax.jit(unproject_depth) -unproject_depth_vmap_jit = jax.jit(jax.vmap(unproject_depth, in_axes=(0,None))) +unproject_depth_vmap_jit = jax.jit(jax.vmap(unproject_depth, in_axes=(0, None))) + def transform_from_pos_target_up(translation_of_camera, target_point, up): """Creates a pose matrix from a translation_of_camera, target_point and up vector. - + Args: translation_of_camera (jnp.ndarray): The position of the camera. Shape (3,) target_point (jnp.ndarray): The point at which the camera is looking at. Shape (3,) @@ -166,25 +194,24 @@ def transform_from_pos_target_up(translation_of_camera, target_point, up): Returns: jnp.ndarray: The camera pose matrix. Shape (4, 4) """ - z = target_point- translation_of_camera + z = target_point - translation_of_camera z = z / jnp.linalg.norm(z) x = jnp.cross(z, up) x = x / jnp.linalg.norm(x) - y = jnp.cross(z,x) + y = jnp.cross(z, x) y = y / jnp.linalg.norm(y) - R = jnp.hstack([ - x.reshape(-1,1),y.reshape(-1,1),z.reshape(-1,1) - ]) + R = jnp.hstack([x.reshape(-1, 1), y.reshape(-1, 1), z.reshape(-1, 1)]) return transform_from_rot_and_pos(R, translation_of_camera) + def estimate_transform_between_clouds(c1, c2): """Estimates a transform between two point clouds. transform = estimate_transform_between_clouds(c1, c2) - + `apply_transform(c1, transform)` should match `c2` as closely as possible. Args: @@ -193,28 +220,30 @@ def estimate_transform_between_clouds(c1, c2): Returns: jnp.ndarray: The transform matrix. Shape (4, 4) """ - - centroid1 = jnp.mean(c1, axis=0) + + centroid1 = jnp.mean(c1, axis=0) centroid2 = jnp.mean(c2, axis=0) c1_centered = c1 - centroid1 c2_centered = c2 - centroid2 H = jnp.transpose(c1_centered).dot(c2_centered) - U,_,V = jnp.linalg.svd(H) - rot = (jnp.transpose(V).dot(jnp.transpose(U))) + U, _, V = jnp.linalg.svd(H) + rot = jnp.transpose(V).dot(jnp.transpose(U)) - modifier = jnp.array([ - [1.0, 0.0, 0.0], - [0.0, 1.0, 0.0], - [0.0, 0.0, -1.0], - ]) + modifier = jnp.array( + [ + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, -1.0], + ] + ) V_mod = modifier.dot(V) - rot2 = (jnp.transpose(V_mod).dot(jnp.transpose(U))) + rot2 = jnp.transpose(V_mod).dot(jnp.transpose(U)) rot_final = (jnp.linalg.det(rot) < 0) * rot2 + (jnp.linalg.det(rot) > 0) * rot - T = (centroid2 - rot_final.dot(centroid1)) - transform = transform_from_rot_and_pos(rot_final, T) + T = centroid2 - rot_final.dot(centroid1) + transform = transform_from_rot_and_pos(rot_final, T) return transform @@ -226,6 +255,7 @@ def rotation_matrix_to_quaternion(matrix): Returns: jnp.ndarray: The quaternion. Shape (4,) """ + def case0(m): t = 1 + m[0, 0] - m[1, 1] - m[2, 2] q = jnp.array( @@ -297,6 +327,7 @@ def case3(m): ) return q * 0.5 / jnp.sqrt(t) + def pose_matrix_to_translation_and_quaternion(pose_matrix): """Converts a pose matrix to a translation and quaternion. @@ -308,6 +339,7 @@ def pose_matrix_to_translation_and_quaternion(pose_matrix): """ return pose_matrix[:3, 3], rotation_matrix_to_quaternion(pose_matrix[:3, :3]) + def translation_and_quaternion_to_pose_matrix(translation, quaternion): """Converts a translation and quaternion to a pose matrix. 
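The SVD-based alignment above can be sanity-checked with a round trip (illustrative only; the pose values are arbitrary):

```python
import jax
import jax.numpy as jnp

import bayes3d.transforms_3d as t3d

c1 = jax.random.uniform(jax.random.PRNGKey(0), (100, 3))
true_pose = t3d.transform_from_axis_angle(jnp.array([0.0, 1.0, 0.0]), 0.3)
true_pose = true_pose.at[:3, 3].set(jnp.array([0.1, -0.2, 0.5]))
c2 = t3d.apply_transform(c1, true_pose)

estimated = t3d.estimate_transform_between_clouds(c1, c2)
# apply_transform(c1, estimated) should reproduce c2 up to float32 error.
assert jnp.allclose(t3d.apply_transform(c1, estimated), c2, atol=1e-3)
```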
@@ -321,16 +353,17 @@ def translation_and_quaternion_to_pose_matrix(translation, quaternion):
         quaternion_to_rotation_matrix(quaternion), translation
     )
 
+
 def quaternion_to_rotation_matrix(Q_in):
     """
-    Covert a quaternion into a full three-dimensional rotation matrix.
-    
+    Convert a quaternion into a full three-dimensional rotation matrix.
+
     Input
-    :param Q: A 4 element array representing the quaternion (q0,q1,q2,q3) 
-    
+    :param Q: A 4 element array representing the quaternion (q0,q1,q2,q3)
+
     Output
-    :return: A 3x3 element matrix representing the full 3D rotation matrix. 
-             This rotation matrix converts a point in the local reference 
+    :return: A 3x3 element matrix representing the full 3D rotation matrix.
+             This rotation matrix converts a point in the local reference
     frame to a point in the global reference frame.
     """
     # Extract the values from Q
@@ -339,45 +372,45 @@ def quaternion_to_rotation_matrix(Q_in):
     q1 = Q[1]
     q2 = Q[2]
     q3 = Q[3]
-    
+
     # First row of the rotation matrix
     r00 = 2 * (q0 * q0 + q1 * q1) - 1
     r01 = 2 * (q1 * q2 - q0 * q3)
     r02 = 2 * (q1 * q3 + q0 * q2)
-    
+
     # Second row of the rotation matrix
     r10 = 2 * (q1 * q2 + q0 * q3)
     r11 = 2 * (q0 * q0 + q2 * q2) - 1
     r12 = 2 * (q2 * q3 - q0 * q1)
-    
+
     # Third row of the rotation matrix
     r20 = 2 * (q1 * q3 - q0 * q2)
     r21 = 2 * (q2 * q3 + q0 * q1)
     r22 = 2 * (q0 * q0 + q3 * q3) - 1
-    
+
     # 3x3 rotation matrix
-    rot_matrix = jnp.array([[r00, r01, r02],
-                           [r10, r11, r12],
-                           [r20, r21, r22]])
-    
+    rot_matrix = jnp.array([[r00, r01, r02], [r10, r11, r12], [r20, r21, r22]])
+
     return rot_matrix
 
+
 def rotation_matrix_to_xyzw(matrix):
     wxyz = rotation_matrix_to_quaternion(matrix)
     return jnp.array([*wxyz[1:], wxyz[0]])
 
+
 def xyzw_to_rotation_matrix(xyzw):
     return quaternion_to_rotation_matrix(jnp.array([xyzw[-1], *xyzw[:-1]]))
 
+
 def pybullet_pose_to_transform(pybullet_pose):
     translation = jnp.array(pybullet_pose[0])
     R = xyzw_to_rotation_matrix(pybullet_pose[1])
-    cam_pose = (
-        transform_from_rot_and_pos(R, translation)
-    )
+    cam_pose = transform_from_rot_and_pos(R, translation)
     return cam_pose
 
+
 def transform_to_pybullet_pose(pose):
-    translation = jnp.array(pose[:3,3])
-    quat = rotation_matrix_to_xyzw(pose[:3,:3])
+    translation = jnp.array(pose[:3, 3])
+    quat = rotation_matrix_to_xyzw(pose[:3, :3])
     return translation, quat
diff --git a/bayes3d/utils/__init__.py b/bayes3d/utils/__init__.py
index 7c03ae1f..1ba5a556 100644
--- a/bayes3d/utils/__init__.py
+++ b/bayes3d/utils/__init__.py
@@ -1,8 +1,8 @@
-from .utils import *
-from .icp import *
+from .bbox import *
 from .enumerations import *
-from .occlusion import *
-from .ycb_loader import *
+from .icp import *
 from .mesh import *
-from .bbox import *
+from .occlusion import *
 from .r3d_loader import *
+from .utils import *
+from .ycb_loader import *
diff --git a/bayes3d/utils/bbox.py b/bayes3d/utils/bbox.py
index 55ca8f38..7fcd32f6 100644
--- a/bayes3d/utils/bbox.py
+++ b/bayes3d/utils/bbox.py
@@ -1,6 +1,7 @@
 import jax
 import jax.numpy as jnp
 
+
 def separating_axis_test(axis, box1, box2):
     """
     Projects both boxes onto the given axis and checks for overlap.
@@ -8,12 +9,15 @@
     min1, max1 = project_box(axis, box1)
     min2, max2 = project_box(axis, box2)
 
-    return jax.lax.cond(jnp.logical_or(max1 < min2, max2 < min1), lambda: False, lambda: True)
+    return jax.lax.cond(
+        jnp.logical_or(max1 < min2, max2 < min1), lambda: False, lambda: True
+    )
 
     # if max1 < min2 or max2 < min1:
    #     return False
     # return True
 
+
 def project_box(axis, box):
     """
     Projects a box onto an axis and returns the min and max projection values. 
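A small usage sketch for the SAT helpers in this module (illustrative, not part of the patch; `are_bboxes_intersecting` is defined in the hunk below):

```python
import jax.numpy as jnp

from bayes3d.transforms_3d import transform_from_pos
from bayes3d.utils.bbox import are_bboxes_intersecting

dim = jnp.array([1.0, 1.0, 1.0])  # unit cube
pose_ref = jnp.eye(4)
pose_overlapping = transform_from_pos(jnp.array([0.5, 0.0, 0.0]))
pose_disjoint = transform_from_pos(jnp.array([2.0, 0.0, 0.0]))

assert are_bboxes_intersecting(dim, dim, pose_ref, pose_overlapping)
assert not are_bboxes_intersecting(dim, dim, pose_ref, pose_disjoint)
```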
@@ -22,23 +26,25 @@ def project_box(axis, box): projections = jnp.array([jnp.dot(corner, axis) for corner in corners]) return jnp.min(projections), jnp.max(projections) + def get_transformed_box_corners(box): """ Returns the 8 corners of the box based on its dimensions and pose. """ dim, pose = box corners = [] - for dx in [-dim[0]/2, dim[0]/2]: - for dy in [-dim[1]/2, dim[1]/2]: - for dz in [-dim[2]/2, dim[2]/2]: + for dx in [-dim[0] / 2, dim[0] / 2]: + for dy in [-dim[1] / 2, dim[1] / 2]: + for dz in [-dim[2] / 2, dim[2] / 2]: corner = jnp.array([dx, dy, dz, 1]) transformed_corner = pose @ corner corners.append(transformed_corner[:3]) return corners + def are_bboxes_intersecting(dim1, dim2, pose1, pose2): """ - Checks if two oriented bounding boxes (OBBs), which are AABBs with poses, are intersecting using the Separating + Checks if two oriented bounding boxes (OBBs), which are AABBs with poses, are intersecting using the Separating Axis Theorem (SAT). Args: @@ -62,9 +68,14 @@ def are_bboxes_intersecting(dim1, dim2, pose1, pose2): # Perform SAT on each axis count_ = 0 for axis in axes_to_test: - count_+= jax.lax.cond(separating_axis_test(axis, box1, box2), lambda:0,lambda:-1) + count_ += jax.lax.cond( + separating_axis_test(axis, box1, box2), lambda: 0, lambda: -1 + ) + + return jax.lax.cond(count_ < 0, lambda: False, lambda: True) - return jax.lax.cond(count_ < 0, lambda:False,lambda:True) # For one reference pose (object 1) and many possible poses for the second object -are_bboxes_intersecting_many = jax.vmap(are_bboxes_intersecting, in_axes = (None, None, None, 0)) +are_bboxes_intersecting_many = jax.vmap( + are_bboxes_intersecting, in_axes=(None, None, None, 0) +) diff --git a/bayes3d/utils/enumerations.py b/bayes3d/utils/enumerations.py index 5a209ba2..f819dd3e 100644 --- a/bayes3d/utils/enumerations.py +++ b/bayes3d/utils/enumerations.py @@ -1,19 +1,25 @@ -import jax.numpy as jnp import jax +import jax.numpy as jnp + from bayes3d.transforms_3d import transform_from_axis_angle, transform_from_pos + def angle_axis_helper_edgecase(newZ): zUnit = jnp.array([0.0, 0.0, 1.0]) axis = jnp.array([0.0, 1.0, 0.0]) - geodesicAngle = jax.lax.cond(jnp.allclose(newZ, zUnit, atol=1e-3), lambda:0.0, lambda:jnp.pi) + geodesicAngle = jax.lax.cond( + jnp.allclose(newZ, zUnit, atol=1e-3), lambda: 0.0, lambda: jnp.pi + ) return axis, geodesicAngle -def angle_axis_helper(newZ): +def angle_axis_helper(newZ): zUnit = jnp.array([0.0, 0.0, 1.0]) axis = jnp.cross(zUnit, newZ) theta = jax.lax.asin(jax.lax.clamp(-1.0, jnp.linalg.norm(axis), 1.0)) - geodesicAngle = jax.lax.cond(jnp.dot(zUnit, newZ) > 0, lambda:theta, lambda:jnp.pi - theta) + geodesicAngle = jax.lax.cond( + jnp.dot(zUnit, newZ) > 0, lambda: theta, lambda: jnp.pi - theta + ) return axis, geodesicAngle @@ -23,9 +29,16 @@ def geodesicHopf_rotate_within_axis(newZ, planarAngle): zUnit = jnp.array([0.0, 0.0, 1.0]) # todo: implement cases where newZ is approx. -zUnit or approx. 
zUnit - axis, geodesicAngle = jax.lax.cond(jnp.allclose(jnp.abs(newZ), zUnit, atol=1e-3), angle_axis_helper_edgecase, angle_axis_helper, newZ) + axis, geodesicAngle = jax.lax.cond( + jnp.allclose(jnp.abs(newZ), zUnit, atol=1e-3), + angle_axis_helper_edgecase, + angle_axis_helper, + newZ, + ) - return (transform_from_axis_angle(axis, geodesicAngle) @ transform_from_axis_angle(zUnit, planarAngle)) + return transform_from_axis_angle(axis, geodesicAngle) @ transform_from_axis_angle( + zUnit, planarAngle + ) def fibonacci_sphere(samples_in_range, phi_range=jnp.pi): @@ -33,23 +46,30 @@ def fibonacci_sphere(samples_in_range, phi_range=jnp.pi): eps = 1e-10 min_y = jnp.cos(phi_range) - samples = jnp.round(samples_in_range * (2 / (1-min_y+eps))) + samples = jnp.round(samples_in_range * (2 / (1 - min_y + eps))) def fib_point(i): - y = 1 - (i / (samples - 1)) * 2 # goes from 1 to -1 + y = 1 - (i / (samples - 1)) * 2 # goes from 1 to -1 radius = jnp.sqrt(1 - y * y) theta = ga * i x = jnp.cos(theta) * radius z = jnp.sin(theta) * radius - return jnp.array([x,z,y]) - + return jnp.array([x, z, y]) + fib_sphere = jax.vmap(fib_point, in_axes=(0)) points = jnp.arange(samples_in_range) return fib_sphere(points) -def make_rotation_grid_enumeration(fibonacci_sphere_points, num_planar_angle_points, min_rot_angle, max_rot_angle, sphere_angle_range): + +def make_rotation_grid_enumeration( + fibonacci_sphere_points, + num_planar_angle_points, + min_rot_angle, + max_rot_angle, + sphere_angle_range, +): """ - Generate uniformly spaced rotation proposals around a constrained region of SO(3) + Generate uniformly spaced rotation proposals around a constrained region of SO(3) Params: fib_sample: number of rotation axes to sample, on the region of the fibonacci sphere specified by `sphere_angle_range` @@ -57,60 +77,96 @@ def make_rotation_grid_enumeration(fibonacci_sphere_points, num_planar_angle_poi min_rot_angle, max_rot_angle: the minimum and maximum rotation angle values; max_rot_angle - min_rot_angle leq 2*pi sphere_angle_range: the maximum phi angle (in spherical coordinates) that bounds the region of the fibonacci sphere to sample rotation axes from; sphere_angle_range leq pi - Returns: + Returns: rotation proposals: (fib_sample*rot_sample, 4, 4) """ - unit_sphere_directions = fibonacci_sphere(fibonacci_sphere_points, sphere_angle_range) - geodesicHopf_select_axis_vmap = jax.vmap(jax.vmap(geodesicHopf_rotate_within_axis, in_axes=(0,None)), in_axes=(None,0)) - rot_stepsize = (max_rot_angle - min_rot_angle)/ num_planar_angle_points - rotation_proposals = geodesicHopf_select_axis_vmap(unit_sphere_directions, jnp.arange(min_rot_angle, max_rot_angle, rot_stepsize)).reshape(-1, 4, 4) + unit_sphere_directions = fibonacci_sphere( + fibonacci_sphere_points, sphere_angle_range + ) + geodesicHopf_select_axis_vmap = jax.vmap( + jax.vmap(geodesicHopf_rotate_within_axis, in_axes=(0, None)), in_axes=(None, 0) + ) + rot_stepsize = (max_rot_angle - min_rot_angle) / num_planar_angle_points + rotation_proposals = geodesicHopf_select_axis_vmap( + unit_sphere_directions, jnp.arange(min_rot_angle, max_rot_angle, rot_stepsize) + ).reshape(-1, 4, 4) return rotation_proposals + def make_translation_grid_enumeration( - min_x,min_y,min_z, - max_x,max_y,max_z, - num_x=2,num_y=2,num_z=2 + min_x, min_y, min_z, max_x, max_y, max_z, num_x=2, num_y=2, num_z=2 ): """ Generate uniformly spaced translation proposals in a 3D box Args: min_x, min_y, min_z: minimum x, y, z values """ - deltas = jnp.stack(jnp.meshgrid( - jnp.linspace(min_x, max_x, num_x), 
- jnp.linspace(min_y, max_y, num_y), - jnp.linspace(min_z, max_z, num_z) - ), - axis=-1) - deltas = deltas.reshape((-1,3),order='F') + deltas = jnp.stack( + jnp.meshgrid( + jnp.linspace(min_x, max_x, num_x), + jnp.linspace(min_y, max_y, num_y), + jnp.linspace(min_z, max_z, num_z), + ), + axis=-1, + ) + deltas = deltas.reshape((-1, 3), order="F") return jax.vmap(transform_from_pos)(deltas) -def make_translation_grid_enumeration_3d(min_x,min_y,min_z, max_x,max_y,max_z, num_x=2,num_y=2,num_z=2): - deltas = jnp.stack(jnp.meshgrid( - jnp.linspace(min_x,max_x,num_x), - jnp.linspace(min_y,max_y,num_y), - jnp.linspace(min_z,max_z,num_z), - ), - axis=-1) - deltas = deltas.reshape(-1,3) + +def make_translation_grid_enumeration_3d( + min_x, min_y, min_z, max_x, max_y, max_z, num_x=2, num_y=2, num_z=2 +): + deltas = jnp.stack( + jnp.meshgrid( + jnp.linspace(min_x, max_x, num_x), + jnp.linspace(min_y, max_y, num_y), + jnp.linspace(min_z, max_z, num_z), + ), + axis=-1, + ) + deltas = deltas.reshape(-1, 3) return deltas -def make_translation_grid_enumeration_2d(min_x,min_y, max_x, max_y, num_x,num_y): - deltas = jnp.stack(jnp.meshgrid( - jnp.linspace(min_x,max_x,num_x), - jnp.linspace(min_y,max_y,num_y), - ), - axis=-1) - deltas = deltas.reshape(-1,2) + +def make_translation_grid_enumeration_2d(min_x, min_y, max_x, max_y, num_x, num_y): + deltas = jnp.stack( + jnp.meshgrid( + jnp.linspace(min_x, max_x, num_x), + jnp.linspace(min_y, max_y, num_y), + ), + axis=-1, + ) + deltas = deltas.reshape(-1, 2) return deltas -def make_pose_grid_enumeration(min_x,min_y,min_z, min_rotation_angle, - max_x,max_y,max_z, max_rotation_angle, - num_x,num_y,num_z, - fibonacci_sphere_points, num_planar_angle_points, - sphere_angle_range=jnp.pi): - rotations = make_rotation_grid_enumeration(fibonacci_sphere_points, num_planar_angle_points, min_rotation_angle, max_rotation_angle, sphere_angle_range) - translations = make_translation_grid_enumeration(min_x,min_y,min_z, max_x,max_y,max_z, num_x,num_y,num_z) - all_proposals = jnp.einsum("aij,bjk->abik", rotations, translations).reshape(-1, 4, 4) +def make_pose_grid_enumeration( + min_x, + min_y, + min_z, + min_rotation_angle, + max_x, + max_y, + max_z, + max_rotation_angle, + num_x, + num_y, + num_z, + fibonacci_sphere_points, + num_planar_angle_points, + sphere_angle_range=jnp.pi, +): + rotations = make_rotation_grid_enumeration( + fibonacci_sphere_points, + num_planar_angle_points, + min_rotation_angle, + max_rotation_angle, + sphere_angle_range, + ) + translations = make_translation_grid_enumeration( + min_x, min_y, min_z, max_x, max_y, max_z, num_x, num_y, num_z + ) + all_proposals = jnp.einsum("aij,bjk->abik", rotations, translations).reshape( + -1, 4, 4 + ) return all_proposals diff --git a/bayes3d/utils/gaussian_splatting.py b/bayes3d/utils/gaussian_splatting.py index a4fb1cc2..0c339a81 100644 --- a/bayes3d/utils/gaussian_splatting.py +++ b/bayes3d/utils/gaussian_splatting.py @@ -1,13 +1,20 @@ -import diff_gaussian_rasterization as dgr -from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer -from diff_gaussian_rasterization import _C +import functools +import math + import jax import jax.numpy as jnp +import numpy as np import torch -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -import math +from diff_gaussian_rasterization import ( + _C, + GaussianRasterizationSettings, + GaussianRasterizer, +) + import bayes3d as b -import numpy as np + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + def 
getProjectionMatrix(znear, zfar, fovX, fovY): tanHalfFovY = math.tan((fovY / 2)) @@ -31,15 +38,17 @@ def getProjectionMatrix(znear, zfar, fovX, fovY): P[2, 3] = -(zfar * znear) / (zfar - znear) return P -def intrinsics_to_rasterizer(intrinsics, camera_pose_jax): +def intrinsics_to_rasterizer(intrinsics, camera_pose_jax): fovX = jnp.arctan(intrinsics.width / 2 / intrinsics.fx) * 2.0 fovY = jnp.arctan(intrinsics.height / 2 / intrinsics.fy) * 2.0 tan_fovx = math.tan(fovX) tan_fovy = math.tan(fovY) - proj_matrix = getProjectionMatrix(0.01, 100.0, fovX, fovY).transpose(0,1).cuda() - view_matrix = torch.transpose(torch.tensor(np.array(b.inverse_pose(camera_pose_jax))),0,1).cuda() + proj_matrix = getProjectionMatrix(0.01, 100.0, fovX, fovY).transpose(0, 1).cuda() + view_matrix = torch.transpose( + torch.tensor(np.array(b.inverse_pose(camera_pose_jax))), 0, 1 + ).cuda() raster_settings = GaussianRasterizationSettings( image_height=int(intrinsics.height), @@ -53,14 +62,25 @@ def intrinsics_to_rasterizer(intrinsics, camera_pose_jax): sh_degree=0, campos=torch.zeros(3).cuda(), prefiltered=False, - debug=None + debug=None, ) rasterizer = GaussianRasterizer(raster_settings=raster_settings) return rasterizer -def gaussian_raster_fwd(means3D, colors_precomp, opacity, scales, rotations, camera_pose, intrinsics): - means3D_torch, colors_precomp_torch, opacity_torch, scales_torch, rotations_torch, camera_pose_torch = [ - b.utils.jax_to_torch(x) for x in [means3D, colors_precomp, opacity, scales, rotations, camera_pose] + +def gaussian_raster_fwd( + means3D, colors_precomp, opacity, scales, rotations, camera_pose, intrinsics +): + ( + means3D_torch, + colors_precomp_torch, + opacity_torch, + scales_torch, + rotations_torch, + camera_pose_torch, + ) = [ + b.utils.jax_to_torch(x) + for x in [means3D, colors_precomp, opacity, scales, rotations, camera_pose] ] fovX = jnp.arctan(intrinsics.width / 2 / intrinsics.fx) * 2.0 @@ -68,8 +88,8 @@ def gaussian_raster_fwd(means3D, colors_precomp, opacity, scales, rotations, cam tan_fovx = math.tan(fovX) tan_fovy = math.tan(fovY) - proj_matrix = getProjectionMatrix(0.01, 100.0, fovX, fovY).transpose(0,1).cuda() - view_matrix = torch.transpose(torch.linalg.inv(camera_pose_torch),0,1).cuda() + proj_matrix = getProjectionMatrix(0.01, 100.0, fovX, fovY).transpose(0, 1).cuda() + view_matrix = torch.transpose(torch.linalg.inv(camera_pose_torch), 0, 1).cuda() raster_settings = GaussianRasterizationSettings( image_height=int(intrinsics.height), image_width=int(intrinsics.width), @@ -82,12 +102,12 @@ def gaussian_raster_fwd(means3D, colors_precomp, opacity, scales, rotations, cam sh_degree=1, campos=torch.zeros(3).cuda(), prefiltered=False, - debug=None + debug=None, ) cov3Ds_precomp = torch.Tensor([]) sh = torch.Tensor([]) args = ( - raster_settings.bg, + raster_settings.bg, means3D_torch, colors_precomp_torch, opacity_torch, @@ -105,20 +125,59 @@ def gaussian_raster_fwd(means3D, colors_precomp, opacity, scales, rotations, cam raster_settings.sh_degree, raster_settings.campos, raster_settings.prefiltered, - raster_settings.debug + raster_settings.debug, ) - num_rendered, color, radii, geomBuffer, binningBuffer, imgBuffer = _C.rasterize_gaussians(*args) + ( + num_rendered, + color, + radii, + geomBuffer, + binningBuffer, + imgBuffer, + ) = _C.rasterize_gaussians(*args) return b.utils.torch_to_jax(color), ( - intrinsics, num_rendered, camera_pose, colors_precomp, means3D, scales, rotations, opacity, - *[b.utils.torch_to_jax(i) for i in [radii, geomBuffer, binningBuffer, imgBuffer]] 
+ intrinsics, + num_rendered, + camera_pose, + colors_precomp, + means3D, + scales, + rotations, + opacity, + *[ + b.utils.torch_to_jax(i) + for i in [radii, geomBuffer, binningBuffer, imgBuffer] + ], ) + def gaussian_raster_bwd(saved_tensors, grad_output): - (intrinsics, num_rendered, camera_pose, colors_precomp, means3D, scales, rotations, opacity, radii, geomBuffer, binningBuffer, imgBuffer) = saved_tensors + ( + intrinsics, + num_rendered, + camera_pose, + colors_precomp, + means3D, + scales, + rotations, + opacity, + radii, + geomBuffer, + binningBuffer, + imgBuffer, + ) = saved_tensors - means3D_torch, colors_precomp_torch, opacity_torch, scales_torch, rotations_torch, camera_pose_torch = [ - b.utils.jax_to_torch(x) for x in [means3D, colors_precomp, opacity, scales, rotations, camera_pose] + ( + means3D_torch, + colors_precomp_torch, + opacity_torch, + scales_torch, + rotations_torch, + camera_pose_torch, + ) = [ + b.utils.jax_to_torch(x) + for x in [means3D, colors_precomp, opacity, scales, rotations, camera_pose] ] fovX = jnp.arctan(intrinsics.width / 2 / intrinsics.fx) * 2.0 @@ -126,8 +185,8 @@ def gaussian_raster_bwd(saved_tensors, grad_output): tan_fovx = math.tan(fovX) tan_fovy = math.tan(fovY) - proj_matrix = getProjectionMatrix(0.01, 100.0, fovX, fovY).transpose(0,1).cuda() - view_matrix = torch.transpose(torch.linalg.inv(camera_pose_torch),0,1).cuda() + proj_matrix = getProjectionMatrix(0.01, 100.0, fovX, fovY).transpose(0, 1).cuda() + view_matrix = torch.transpose(torch.linalg.inv(camera_pose_torch), 0, 1).cuda() raster_settings = GaussianRasterizationSettings( image_height=int(intrinsics.height), image_width=int(intrinsics.width), @@ -140,10 +199,9 @@ def gaussian_raster_bwd(saved_tensors, grad_output): sh_degree=1, campos=torch.zeros(3).cuda(), prefiltered=False, - debug=None + debug=None, ) - geomBuffer_torch = b.utils.jax_to_torch(geomBuffer) binningBuffer_torch = b.utils.jax_to_torch(binningBuffer) imgBuffer_torch = b.utils.jax_to_torch(imgBuffer) @@ -153,44 +211,71 @@ def gaussian_raster_bwd(saved_tensors, grad_output): cov3Ds_precomp = torch.Tensor([]) sh = torch.Tensor([]) - args = (raster_settings.bg, - means3D_torch, - radii_torch, - colors_precomp_torch, - scales_torch, - rotations_torch, - raster_settings.scale_modifier, - cov3Ds_precomp, - raster_settings.viewmatrix, - raster_settings.projmatrix, - raster_settings.tanfovx, - raster_settings.tanfovy, - grad_out_color_torch, - sh, - raster_settings.sh_degree, - raster_settings.campos, - geomBuffer_torch, - num_rendered, - binningBuffer_torch, - imgBuffer_torch, - raster_settings.debug) - - - - grad_means2D, grad_colors_precomp, grad_opacities, grad_means3D, grad_cov3Ds_precomp, grad_sh, grad_scales, grad_rotations = _C.rasterize_gaussians_backward( - *args + args = ( + raster_settings.bg, + means3D_torch, + radii_torch, + colors_precomp_torch, + scales_torch, + rotations_torch, + raster_settings.scale_modifier, + cov3Ds_precomp, + raster_settings.viewmatrix, + raster_settings.projmatrix, + raster_settings.tanfovx, + raster_settings.tanfovy, + grad_out_color_torch, + sh, + raster_settings.sh_degree, + raster_settings.campos, + geomBuffer_torch, + num_rendered, + binningBuffer_torch, + imgBuffer_torch, + raster_settings.debug, ) + ( + grad_means2D, + grad_colors_precomp, + grad_opacities, + grad_means3D, + grad_cov3Ds_precomp, + grad_sh, + grad_scales, + grad_rotations, + ) = _C.rasterize_gaussians_backward(*args) + grad_means3D, grad_colors_precomp, grad_opacities, grad_scales, grad_rotations = [ 
b.utils.torch_to_jax(i) - for i in [grad_means3D, grad_colors_precomp, grad_opacities, grad_scales, grad_rotations] + for i in [ + grad_means3D, + grad_colors_precomp, + grad_opacities, + grad_scales, + grad_rotations, + ] ] # input order means3D, colors_precomp, opacities, scales, rotations, camera_pose, intrinsics - return grad_means3D, grad_colors_precomp, grad_opacities, grad_scales, grad_rotations, None, None + return ( + grad_means3D, + grad_colors_precomp, + grad_opacities, + grad_scales, + grad_rotations, + None, + None, + ) + -import functools @functools.partial(jax.custom_vjp) -def gaussian_raster(means3D, colors_precomp, opacities, scales, rotations, camera_pose, intrinsics): - return gaussian_raster_fwd(means3D, colors_precomp, opacities, scales, rotations, camera_pose, intrinsics)[0] -gaussian_raster.defvjp(gaussian_raster_fwd, gaussian_raster_bwd) \ No newline at end of file +def gaussian_raster( + means3D, colors_precomp, opacities, scales, rotations, camera_pose, intrinsics +): + return gaussian_raster_fwd( + means3D, colors_precomp, opacities, scales, rotations, camera_pose, intrinsics + )[0] + + +gaussian_raster.defvjp(gaussian_raster_fwd, gaussian_raster_bwd) diff --git a/bayes3d/utils/icp.py b/bayes3d/utils/icp.py index 0a57995f..968375ff 100644 --- a/bayes3d/utils/icp.py +++ b/bayes3d/utils/icp.py @@ -1,14 +1,19 @@ -import jax.numpy as jnp -import jax import functools -import numpy as np + +import jax +import jax.numpy as jnp + import bayes3d as b -import functools + @functools.partial( jnp.vectorize, - signature='(m)->(z)', - excluded=(1,2,3,), + signature="(m)->(z)", + excluded=( + 1, + 2, + 3, + ), ) def find_closest_point_at_pixel( ij, @@ -16,28 +21,50 @@ def find_closest_point_at_pixel( rendered_xyz_padded: jnp.ndarray, filter_size, ): - rendered_filter = jax.lax.dynamic_slice(rendered_xyz_padded, (ij[0], ij[1], 0), (2*filter_size + 1, 2*filter_size + 1, 3)) + rendered_filter = jax.lax.dynamic_slice( + rendered_xyz_padded, + (ij[0], ij[1], 0), + (2 * filter_size + 1, 2 * filter_size + 1, 3), + ) distances = jnp.linalg.norm( - observed_xyz[ij[0], ij[1], :3] - rendered_filter, - axis=-1 + observed_xyz[ij[0], ij[1], :3] - rendered_filter, axis=-1 ) - best_point = rendered_filter[jnp.unravel_index(jnp.argmin(distances), distances.shape)] + best_point = rendered_filter[ + jnp.unravel_index(jnp.argmin(distances), distances.shape) + ] return best_point def get_nearest_neighbor( - observed_xyz: jnp.ndarray, - rendered_xyz: jnp.ndarray, - filter_size: int + observed_xyz: jnp.ndarray, rendered_xyz: jnp.ndarray, filter_size: int ): - rendered_xyz_padded = jax.lax.pad(rendered_xyz, -100.0, ((filter_size,filter_size,0,),(filter_size,filter_size,0,),(0,0,0,))) - jj, ii = jnp.meshgrid(jnp.arange(observed_xyz.shape[1]), jnp.arange(observed_xyz.shape[0])) - indices = jnp.stack([ii,jj],axis=-1) + rendered_xyz_padded = jax.lax.pad( + rendered_xyz, + -100.0, + ( + ( + filter_size, + filter_size, + 0, + ), + ( + filter_size, + filter_size, + 0, + ), + ( + 0, + 0, + 0, + ), + ), + ) + jj, ii = jnp.meshgrid( + jnp.arange(observed_xyz.shape[1]), jnp.arange(observed_xyz.shape[0]) + ) + indices = jnp.stack([ii, jj], axis=-1) matches = find_closest_point_at_pixel( - indices, - observed_xyz, - rendered_xyz_padded, - filter_size + indices, observed_xyz, rendered_xyz_padded, filter_size ) return matches @@ -49,39 +76,46 @@ def find_least_squares_transform_between_clouds(c1, c2, mask): c2_centered = c2 - centroid2 H = jnp.transpose(c1_centered * mask).dot(c2_centered * mask) - U,_,V = 
jnp.linalg.svd(H) - rot = (jnp.transpose(V).dot(jnp.transpose(U))) + U, _, V = jnp.linalg.svd(H) + rot = jnp.transpose(V).dot(jnp.transpose(U)) - modifier = jnp.array([ - [1.0, 0.0, 0.0], - [0.0, 1.0, 0.0], - [0.0, 0.0, -1.0], - ]) + modifier = jnp.array( + [ + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, -1.0], + ] + ) V_mod = modifier.dot(V) - rot2 = (jnp.transpose(V_mod).dot(jnp.transpose(U))) + rot2 = jnp.transpose(V_mod).dot(jnp.transpose(U)) rot_final = (jnp.linalg.det(rot) < 0) * rot2 + (jnp.linalg.det(rot) > 0) * rot - T = (centroid2 - rot_final.dot(centroid1)) - transform = b.transform_from_rot_and_pos(rot_final, T) + T = centroid2 - rot_final.dot(centroid1) + transform = b.transform_from_rot_and_pos(rot_final, T) return transform -def icp_images(img, img_reference, init_pose, error_threshold, iterations, intrinsics, filter_size): + +def icp_images( + img, img_reference, init_pose, error_threshold, iterations, intrinsics, filter_size +): def _icp_step(i, pose_and_error): pose, _ = pose_and_error transformed_cloud = b.apply_transform(img, pose) matches_in_img_reference = b.utils.get_nearest_neighbor( - transformed_cloud, - img_reference, - filter_size + transformed_cloud, img_reference, filter_size + ) + mask = (img[:, :, 2] < intrinsics.far) * ( + matches_in_img_reference[:, :, 2] < intrinsics.far ) - mask = (img[:,:,2] < intrinsics.far) * (matches_in_img_reference[:,:,2] < intrinsics.far) - avg_error = (jnp.linalg.norm(img - matches_in_img_reference,axis=-1) *mask).sum() / mask.sum() + avg_error = ( + jnp.linalg.norm(img - matches_in_img_reference, axis=-1) * mask + ).sum() / mask.sum() transform = b.utils.find_least_squares_transform_between_clouds( - transformed_cloud.reshape(-1,3), - matches_in_img_reference.reshape(-1,3), - mask.reshape(-1,1) + transformed_cloud.reshape(-1, 3), + matches_in_img_reference.reshape(-1, 3), + mask.reshape(-1, 1), ) return jnp.where(avg_error < error_threshold, pose, transform @ pose), avg_error - return jax.lax.fori_loop(0, iterations, _icp_step, (init_pose,0.0)) + return jax.lax.fori_loop(0, iterations, _icp_step, (init_pose, 0.0)) diff --git a/bayes3d/utils/mesh.py b/bayes3d/utils/mesh.py index 50ebf1ee..a230a66b 100644 --- a/bayes3d/utils/mesh.py +++ b/bayes3d/utils/mesh.py @@ -1,112 +1,144 @@ -import trimesh -import numpy as np -import bayes3d.transforms_3d as t3d -import bayes3d as j +from itertools import product + import jax import jax.numpy as jnp -from itertools import product +import numpy as np +import trimesh + +import bayes3d as j +import bayes3d.transforms_3d as t3d + def center_mesh(mesh, return_pose=False): _, pose = j.utils.aabb(mesh.vertices) - shift = np.array(pose[:3,3]) + shift = np.array(pose[:3, 3]) mesh.vertices = mesh.vertices - shift if return_pose: return mesh, pose return mesh + def scale_mesh(mesh, scaling=1.0): mesh.vertices = mesh.vertices * scaling return mesh + def load_mesh(mesh_filename, scaling=1.0): mesh = trimesh.load(mesh_filename) mesh.vertices = mesh.vertices * scaling return mesh + def export_mesh(mesh, filename): - normals = mesh.face_normals - normals = mesh.vertex_normals - with open(filename,"w") as f: - f.write(trimesh.exchange.obj.export_obj(mesh, include_normals=True, include_texture=True)) + with open(filename, "w") as f: + f.write( + trimesh.exchange.obj.export_obj( + mesh, include_normals=True, include_texture=True + ) + ) + def make_cuboid_mesh(dimensions): - mesh = trimesh.creation.box( - dimensions, - np.eye(4) - ) + mesh = trimesh.creation.box(dimensions, np.eye(4)) return mesh + def 
make_voxel_mesh_from_point_cloud(point_cloud, resolution): poses = jax.vmap(j.t3d.transform_from_pos)(point_cloud) all_voxels = [ - trimesh.creation.box(np.array([resolution,resolution,resolution]), p) for p in poses + trimesh.creation.box(np.array([resolution, resolution, resolution]), p) + for p in poses ] final_mesh = trimesh.util.concatenate(all_voxels) return final_mesh -def make_marching_cubes_mesh_from_point_cloud( - point_cloud, - pitch -): + +def make_marching_cubes_mesh_from_point_cloud(point_cloud, pitch): mesh = trimesh.voxel.ops.points_to_marching_cubes(point_cloud, pitch=pitch) return mesh + def open3d_mesh_to_trimesh(mesh): return trimesh.Trimesh( - vertices=np.asarray(mesh.vertices), - faces=np.asarray(mesh.triangles) + vertices=np.asarray(mesh.vertices), faces=np.asarray(mesh.triangles) ) -def make_alpha_mesh_from_point_cloud( - point_cloud, - alpha -): + +def make_alpha_mesh_from_point_cloud(point_cloud, alpha): import open3d as o3d + pcd = o3d.geometry.PointCloud() pcd.points = o3d.utility.Vector3dVector(np.array(point_cloud)) - learned_mesh = o3d.geometry.TriangleMesh.create_from_point_cloud_alpha_shape(pcd, alpha) - learned_mesh_trimesh = trimesh.Trimesh(vertices=np.asarray(learned_mesh.vertices), faces=np.asarray(learned_mesh.triangles)) + learned_mesh = o3d.geometry.TriangleMesh.create_from_point_cloud_alpha_shape( + pcd, alpha + ) + learned_mesh_trimesh = trimesh.Trimesh( + vertices=np.asarray(learned_mesh.vertices), + faces=np.asarray(learned_mesh.triangles), + ) return learned_mesh_trimesh + def make_table_mesh( - table_width, - table_length, - table_height, - table_thickness, - table_leg_width + table_width, table_length, table_height, table_thickness, table_leg_width ): - table_face = trimesh.creation.box( np.array([table_width, table_length, table_thickness]), - np.array(t3d.transform_from_pos(jnp.array([0.0, 0.0, table_height/2.0 - table_thickness/2.]))) + np.array( + t3d.transform_from_pos( + jnp.array([0.0, 0.0, table_height / 2.0 - table_thickness / 2.0]) + ) + ), ) - table_leg_height = table_height-table_thickness - leg_dims = np.array([table_leg_width, table_leg_width, table_leg_height]) - leg_center = np.array([table_width, table_length])/2. - table_leg_width/2.0*np.ones(2) - leg_xys = [np.multiply(leg_center, np.array(signs)) - for signs in product([-1, +1], repeat=len(leg_center))] + table_leg_height = table_height - table_thickness + leg_dims = np.array([table_leg_width, table_leg_width, table_leg_height]) + leg_center = np.array( + [table_width, table_length] + ) / 2.0 - table_leg_width / 2.0 * np.ones(2) + leg_xys = [ + np.multiply(leg_center, np.array(signs)) + for signs in product([-1, +1], repeat=len(leg_center)) + ] table_legs = [ trimesh.creation.box( leg_dims, - np.array(t3d.transform_from_pos(np.array([x, y, table_leg_height/2. 
- table_height/2.0]))) + np.array( + t3d.transform_from_pos( + np.array([x, y, table_leg_height / 2.0 - table_height / 2.0]) + ) + ), ) - for (x,y) in leg_xys + for (x, y) in leg_xys ] table = trimesh.util.concatenate([table_face] + table_legs) return table + def point_cloud_image_to_trimesh(point_cloud_image): height, width, _ = point_cloud_image.shape - ij_to_index = lambda i,j: i * width + j - ij_to_faces = lambda ij: jnp.array( - [ - [ij_to_index(ij[0], ij[1]), ij_to_index(ij[0]+1, ij[1]), ij_to_index(ij[0], ij[1]+1)], - [ij_to_index(ij[0]+1, ij[1]), ij_to_index(ij[0]+1, ij[1]+1), ij_to_index(ij[0], ij[1]+1)] - ] - ) - jj, ii = jnp.meshgrid(jnp.arange(width-1), jnp.arange(height-1)) - indices = jnp.stack([ii,jj],axis=-1) - faces = jax.vmap(ij_to_faces)(indices.reshape(-1,2)).reshape(-1,3) + + def ij_to_index(i, j): + return i * width + j + + def ij_to_faces(ij): + return jnp.array( + [ + [ + ij_to_index(ij[0], ij[1]), + ij_to_index(ij[0] + 1, ij[1]), + ij_to_index(ij[0], ij[1] + 1), + ], + [ + ij_to_index(ij[0] + 1, ij[1]), + ij_to_index(ij[0] + 1, ij[1] + 1), + ij_to_index(ij[0], ij[1] + 1), + ], + ] + ) + + jj, ii = jnp.meshgrid(jnp.arange(width - 1), jnp.arange(height - 1)) + indices = jnp.stack([ii, jj], axis=-1) + faces = jax.vmap(ij_to_faces)(indices.reshape(-1, 2)).reshape(-1, 3) print(faces.shape) - vertices = point_cloud_first.reshape(-1,3) - mesh = trimesh.Trimesh(vertices, faces) \ No newline at end of file + # vertices = point_cloud_first.reshape(-1,3) + # mesh = trimesh.Trimesh(vertices, faces) diff --git a/bayes3d/utils/occlusion.py b/bayes3d/utils/occlusion.py index 2a313ea4..515eea58 100644 --- a/bayes3d/utils/occlusion.py +++ b/bayes3d/utils/occlusion.py @@ -1,21 +1,34 @@ -import bayes3d as b -import jax.numpy as jnp -import bayes3d.transforms_3d as t3d import jax +import jax.numpy as jnp + +import bayes3d as b + def voxel_occupied_occluded_free(camera_pose, depth_image, grid, intrinsics, tolerance): grid_in_cam_frame = b.apply_transform(grid, b.t3d.inverse_pose(camera_pose)) pixels = b.project_cloud_to_pixels(grid_in_cam_frame, intrinsics).astype(jnp.int32) - valid_pixels = (0 <= pixels[:,0]) * (0 <= pixels[:,1]) * (pixels[:,0] < intrinsics.width) * (pixels[:,1] < intrinsics.height) - real_depth_vals = depth_image[pixels[:,1],pixels[:,0]] * valid_pixels + (1 - valid_pixels) * (intrinsics.far + 1.0) - - projected_depth_vals = grid_in_cam_frame[:,2] + valid_pixels = ( + (0 <= pixels[:, 0]) + * (0 <= pixels[:, 1]) + * (pixels[:, 0] < intrinsics.width) + * (pixels[:, 1] < intrinsics.height) + ) + real_depth_vals = depth_image[pixels[:, 1], pixels[:, 0]] * valid_pixels + ( + 1 - valid_pixels + ) * (intrinsics.far + 1.0) + + projected_depth_vals = grid_in_cam_frame[:, 2] occupied = jnp.abs(real_depth_vals - projected_depth_vals) < tolerance occluded = real_depth_vals < projected_depth_vals occluded = occluded * (1.0 - occupied) free = (1.0 - occluded) * (1.0 - occupied) return 1.0 * occupied + 0.5 * occluded + voxel_occupied_occluded_free_jit = jax.jit(voxel_occupied_occluded_free) -voxel_occupied_occluded_free_parallel_camera = jax.jit(jax.vmap(voxel_occupied_occluded_free, in_axes=(0, None, None, None, None))) -voxel_occupied_occluded_free_parallel_camera_depth = jax.jit(jax.vmap(voxel_occupied_occluded_free, in_axes=(0, 0, None, None, None))) +voxel_occupied_occluded_free_parallel_camera = jax.jit( + jax.vmap(voxel_occupied_occluded_free, in_axes=(0, None, None, None, None)) +) +voxel_occupied_occluded_free_parallel_camera_depth = jax.jit( + 
jax.vmap(voxel_occupied_occluded_free, in_axes=(0, 0, None, None, None)) +) diff --git a/bayes3d/utils/pybullet_sim.py b/bayes3d/utils/pybullet_sim.py index 42e8625c..cbfde663 100644 --- a/bayes3d/utils/pybullet_sim.py +++ b/bayes3d/utils/pybullet_sim.py @@ -1,31 +1,39 @@ -import pybullet as p -import pybullet_data -import numpy as np +import imageio import jax.numpy as jnp +import numpy as np import open3d as o3d +import pybullet as p +import pybullet_data import trimesh as tm +from PIL import Image +from scipy.spatial.transform import Rotation as R + import bayes3d as b import bayes3d.transforms_3d as t3d -import imageio -from scipy.spatial.transform import Rotation as R -from PIL import Image + def o3d_to_pybullet_position(o3d_position): return np.array(o3d_position) + def o3d_to_pybullet_pose(o3d_pose): pybullet_pose = o3d_pose - pybullet_pose[3, :3] = -pybullet_pose[3, :3] # Convert rotation from Open3D to PyBullet + pybullet_pose[3, :3] = -pybullet_pose[ + 3, :3 + ] # Convert rotation from Open3D to PyBullet return pybullet_pose + def pybullet_to_o3d_position(pybullet_position): return pybullet_position + def pybullet_to_o3d_pose(pybullet_pose): o3d_pose = pybullet_pose o3d_pose[3, :3] = -o3d_pose[3, :3] # Convert rotation from PyBullet to Open3D return o3d_pose + def o3d_to_trimesh(mesh): vertices = np.asarray(mesh.vertices) faces = np.asarray(mesh.triangles) @@ -33,12 +41,18 @@ def o3d_to_trimesh(mesh): mesh.compute_vertex_normals() tri_normals = np.asarray(mesh.triangle_normals) vert_normals = np.asarray(mesh.vertex_normals) - mesh = tm.Trimesh(vertices=vertices, faces=faces, vertex_normals=vert_normals, face_normals=tri_normals) + mesh = tm.Trimesh( + vertices=vertices, + faces=faces, + vertex_normals=vert_normals, + face_normals=tri_normals, + ) return mesh -def o3d_render(scene): + +def o3d_render(scene): intrinsics = o3d.camera.PinholeCameraIntrinsic() - renderer = b.o3d_viz.O3DVis(intrinsics=intrinsics) + renderer = b.o3d_viz.O3DVis(intrinsics=intrinsics) for body in scene.bodies.values(): mesh = body.mesh pose = body.pose @@ -47,6 +61,7 @@ def o3d_render(scene): image = renderer.render(camera_pose) return image + def pybullet_render(scene): """ Renders a scene using PyBullet. @@ -57,7 +72,7 @@ def pybullet_render(scene): Returns: PIL.Image.Image: The rendered image. """ - pyb_sim = PybulletSimulator(camera=scene.camera, floor = scene.floor) + pyb_sim = PybulletSimulator(camera=scene.camera, floor=scene.floor) for body in scene.bodies.values(): pyb_sim.add_body_to_simulation(body) image_rgb, depth, _ = pyb_sim.capture_image(scene.camera) @@ -70,16 +85,43 @@ def pybullet_render(scene): depth = Image.fromarray(depth) return image, depth -def create_box(pose, scale = [1,1,1], restitution=1, friction=0, velocity=0, angular_velocity = [0,0,0], id=None): + +def create_box( + pose, + scale=[1, 1, 1], + restitution=1, + friction=0, + velocity=0, + angular_velocity=[0, 0, 0], + id=None, +): """ Creates a box-shaped Body object. 
""" position = pose[:3, 3] orientation = pose[:3, :3] - return create_box(position, scale, restitution, friction, velocity, angular_velocity, orientation, id) - - -def create_box(position, scale=[1,1,1], restitution=1, friction=0, velocity=0, angular_velocity = [0,0,0], orientation=None, id=None): + return create_box( + position, + scale, + restitution, + friction, + velocity, + angular_velocity, + orientation, + id, + ) + + +def create_box( + position, + scale=[1, 1, 1], + restitution=1, + friction=0, + velocity=0, + angular_velocity=[0, 0, 0], + orientation=None, + id=None, +): """ Creates a box-shaped Body object. """ @@ -89,11 +131,29 @@ def create_box(position, scale=[1,1,1], restitution=1, friction=0, velocity=0, a pose[:3, :3] = orientation if orientation is not None else np.eye(3) pose[:3, 3] = position obj_id = "box" if id is None else id - body = Body(obj_id, pose, mesh, file_dir=path_to_box, restitution=restitution, friction=friction, velocity=velocity, angular_velocity=angular_velocity, scale=scale) + body = Body( + obj_id, + pose, + mesh, + file_dir=path_to_box, + restitution=restitution, + friction=friction, + velocity=velocity, + angular_velocity=angular_velocity, + scale=scale, + ) return body -def create_sphere(position, scale = [1,1,1], velocity=0, angular_velocity = [0,0,0], restitution=1, friction=0, id=None): +def create_sphere( + position, + scale=[1, 1, 1], + velocity=0, + angular_velocity=[0, 0, 0], + restitution=1, + friction=0, + id=None, +): """ Creates a sphere-shaped Body object. @@ -113,11 +173,30 @@ def create_sphere(position, scale = [1,1,1], velocity=0, angular_velocity = [0,0 pose = np.eye(4) pose[:3, 3] = position obj_id = "sphere" if id is None else id - body = Body(obj_id, pose, mesh, file_dir=path_to_sphere, restitution=restitution, friction=friction, velocity=velocity, angular_velocity=angular_velocity, scale = scale) + body = Body( + obj_id, + pose, + mesh, + file_dir=path_to_sphere, + restitution=restitution, + friction=friction, + velocity=velocity, + angular_velocity=angular_velocity, + scale=scale, + ) return body -def make_body_from_obj_pose(obj_path, pose, velocity=0, angular_velocity = [0,0,0], restitution=1, friction=0, id=None, scale = None ): +def make_body_from_obj_pose( + obj_path, + pose, + velocity=0, + angular_velocity=[0, 0, 0], + restitution=1, + friction=0, + id=None, + scale=None, +): """ Creates a Body object from an OBJ file with a given pose. @@ -134,11 +213,31 @@ def make_body_from_obj_pose(obj_path, pose, velocity=0, angular_velocity = [0,0, """ mesh = tm.load(obj_path) obj_id = "obj_mesh" if id is None else id - body = Body(obj_id, pose, mesh, velocity=velocity, angular_velocity = angular_velocity, friction=friction, restitution=restitution, file_dir=obj_path, scale = scale) + body = Body( + obj_id, + pose, + mesh, + velocity=velocity, + angular_velocity=angular_velocity, + friction=friction, + restitution=restitution, + file_dir=obj_path, + scale=scale, + ) return body -def make_body_from_obj(obj_path, position, friction=0, restitution=1, velocity=0, angular_velocity = [0,0,0], orientation=None, id=None, scale=None): +def make_body_from_obj( + obj_path, + position, + friction=0, + restitution=1, + velocity=0, + angular_velocity=[0, 0, 0], + orientation=None, + id=None, + scale=None, +): """ Creates a Body object from an OBJ file with a given position. 
@@ -157,10 +256,36 @@ def make_body_from_obj(obj_path, position, friction=0, restitution=1, velocity=0 pose = np.eye(4) pose[:3, :3] = orientation if orientation is not None else np.eye(3) pose[:3, 3] = position - return make_body_from_obj_pose(obj_path, pose, id=id, friction=friction, restitution=restitution, velocity=velocity, scale=scale, angular_velocity=angular_velocity) + return make_body_from_obj_pose( + obj_path, + pose, + id=id, + friction=friction, + restitution=restitution, + velocity=velocity, + scale=scale, + angular_velocity=angular_velocity, + ) + class Body: - def __init__(self, object_id, pose, mesh, file_dir = None, restitution=0.8, friction=0, damping=0, transparency=1, velocity=[0,0,0], angular_velocity = [0,0,0],mass=1, texture=None, color=[1, 0, 0], scale=None): + def __init__( + self, + object_id, + pose, + mesh, + file_dir=None, + restitution=0.8, + friction=0, + damping=0, + transparency=1, + velocity=[0, 0, 0], + angular_velocity=[0, 0, 0], + mass=1, + texture=None, + color=[1, 0, 0], + scale=None, + ): self.id = object_id self.pose = pose self.restitution = restitution @@ -200,7 +325,7 @@ def get_transparency(self): def get_velocity(self): return self.velocity - + def get_angular_velocity(self): return self.angular_velocity @@ -218,10 +343,10 @@ def get_position(self): def get_orientation(self): return self.pose[:3, :3] - + def get_scale(self): return self.scale - + # Setter methods def set_id(self, object_id): self.id = object_id @@ -265,18 +390,26 @@ def set_orientation(self, orientation): def set_scale(self, scale): self.scale = scale - # Miscellaneous methods def get_fields(self): return f"Body ID: {self.id}, Pose: {self.pose}, Restitution: {self.restitution}, Friction: {self.friction}, Damping: {self.damping}, Transparency: {self.transparency}, Velocity: {self.velocity}, Texture: {self.texture}, Color: {self.color}" - + def __str__(self): return f"Body ID: {self.id}, Position: {self.get_position()}" - class Scene: - def __init__(self, id = None, bodies=None, camera=None, timestep = 1/60, light=None, gravity = [0,0,0], downsampling = 1, floor = True): + def __init__( + self, + id=None, + bodies=None, + camera=None, + timestep=1 / 60, + light=None, + gravity=[0, 0, 0], + downsampling=1, + floor=True, + ): self.scene_id = id if id is not None else "scene" self.bodies = bodies if bodies is not None else {} self.gravity = gravity @@ -286,11 +419,10 @@ def __init__(self, id = None, bodies=None, camera=None, timestep = 1/60, light=N self.pyb_sim = None self.downsampling = downsampling - def add_body(self, body: Body): self.bodies[body.id] = body return self.bodies - + def add_bodies(self, bodies: list): for body in bodies: self.add_body(body) @@ -302,16 +434,16 @@ def remove_body(self, body_id): else: del self.bodies[body_id] return self.bodies - + def remove_bodies(self, body_ids): for body_id in body_ids: self.remove_body(body_id) - print('removed body: ', body_id) + print("removed body: ", body_id) return self.bodies - + def get_bodies(self): return self.bodies - + def set_floor(self, floor): self.floor = floor return self.floor @@ -321,35 +453,41 @@ def set_camera_position_target(self, position, target): self.camera.position = position self.camera.target = target return self.camera - + def set_camera_pose(self, pose): self.camera.position_target = False self.camera.pose = pose return self.camera def set_light(self, light: Body): - self.light = light + self.light = light return self.light - + def set_gravity(self, gravity): self.gravity = gravity return 
self.gravity - + def set_downsampling(self, downsampling): self.downsampling = downsampling return self.downsampling - + def set_timestep(self, timestep): self.timestep = timestep return self.timestep - + def render(self, render_func): image = render_func(self) return image - + def simulate(self, timesteps): - # create physics simulator - pyb = PybulletSimulator(timestep=self.timestep, gravity=self.gravity, camera = self.camera, downsampling = self.downsampling, floor = self.floor ) + # create physics simulator + pyb = PybulletSimulator( + timestep=self.timestep, + gravity=self.gravity, + camera=self.camera, + downsampling=self.downsampling, + floor=self.floor, + ) self.pyb_sim = pyb # add bodies to physics simulator @@ -358,11 +496,11 @@ def simulate(self, timesteps): # simulate for timesteps pyb.simulate(timesteps) - # returns pybullet simulation, which you can obtain a gif, poses from. + # returns pybullet simulation, which you can obtain a gif, poses from. return pyb - def close(self): - if self.pyb_sim == None: + def close(self): + if self.pyb_sim == None: raise ValueError("No pybullet simulation to close") else: p.disconnect(self.pyb_sim.client) @@ -370,12 +508,26 @@ def close(self): def __str__(self): body_str = "\n".join([" " + str(body) for body in self.bodies.values()]) return f"Scene ID: {self.scene_id}\nBodies:\n{body_str}" - -# TODO: reduce exposure + +# TODO: reduce exposure class Camera(object): - def __init__(self, position_target=True, position=None, near=0.1, far=100.0, fov=60, width=960, height=720, - up_vector=None, distance=7, yaw=0, pitch=-30, roll=0, intrinsics=None): + def __init__( + self, + position_target=True, + position=None, + near=0.1, + far=100.0, + fov=60, + width=960, + height=720, + up_vector=None, + distance=7, + yaw=0, + pitch=-30, + roll=0, + intrinsics=None, + ): self.position_target = position_target self.position = [0, -5.5, 3] if position is None else position self.target = [0, 0, 0] @@ -383,10 +535,7 @@ def __init__(self, position_target=True, position=None, near=0.1, far=100.0, fov self.near = near self.far = far self.fov = fov - self.fx,self.fy, self.cx,self.cy = ( - 500.0,500.0, - 320.0,240.0 - ) + self.fx, self.fy, self.cx, self.cy = (500.0, 500.0, 320.0, 240.0) self.width = width self.height = height self.up_vector = [0, 0, 1] if up_vector is None else up_vector @@ -398,30 +547,46 @@ def __init__(self, position_target=True, position=None, near=0.1, far=100.0, fov def __str__(self) -> str: return f"Camera: position_target = {self.position_target}, position={self.position}, target={self.target}, pose={self.pose}, near={self.near}, far={self.far}, fov={self.fov}, width={self.width}, height={self.height}, up_vector={self.up_vector}, distance={self.distance}, yaw={self.yaw}, pitch={self.pitch}, roll={self.roll}, intrinsics={self.intrinsics}" - + + class PybulletSimulator(object): - def __init__(self, timestep=1/60, gravity=[0,0,0], floor_restitution=0.5, camera = None, downsampling=1, floor = True): + def __init__( + self, + timestep=1 / 60, + gravity=[0, 0, 0], + floor_restitution=0.5, + camera=None, + downsampling=1, + floor=True, + ): self.timestep = timestep self.gravity = gravity self.client = p.connect(p.DIRECT) self.step_count = 0 - self.frames = [] + self.frames = [] self.depth = [] self.pyb_id_to_body_id = {} self.body_poses = {} self.camera = camera self.downsampling = downsampling - self.floor = floor + self.floor = floor # Set up the simulation environment p.resetSimulation(physicsClientId=self.client) - p.setGravity(self.gravity[0], 
self.gravity[1], self.gravity[2], physicsClientId=self.client) - p.setPhysicsEngineParameter(fixedTimeStep=self.timestep, physicsClientId=self.client) + p.setGravity( + self.gravity[0], + self.gravity[1], + self.gravity[2], + physicsClientId=self.client, + ) + p.setPhysicsEngineParameter( + fixedTimeStep=self.timestep, physicsClientId=self.client + ) if self.floor: p.setAdditionalSearchPath(pybullet_data.getDataPath()) self.plane_id = p.loadURDF("plane.urdf", physicsClientId=self.client) p.changeDynamics(self.plane_id, -1, restitution=floor_restitution) - + def add_body_to_simulation(self, body): """ Add a body to the pybullet simulation. @@ -433,24 +598,25 @@ def add_body_to_simulation(self, body): obj_file_dir = body.file_dir mesh_scale = body.scale - # Create visual and collision shapes - visualShapeId = p.createVisualShape(shapeType=p.GEOM_MESH, - # vertices=vertices, - # indices=faces, - fileName = obj_file_dir, - meshScale = mesh_scale, - physicsClientId=self.client, - rgbaColor=np.append(body.color, body.transparency), - ) - - collisionShapeId = p.createCollisionShape(shapeType=p.GEOM_MESH, - # vertices=vertices, - # indices=faces, - fileName = obj_file_dir, - meshScale = mesh_scale, - physicsClientId=self.client, - ) + visualShapeId = p.createVisualShape( + shapeType=p.GEOM_MESH, + # vertices=vertices, + # indices=faces, + fileName=obj_file_dir, + meshScale=mesh_scale, + physicsClientId=self.client, + rgbaColor=np.append(body.color, body.transparency), + ) + + collisionShapeId = p.createCollisionShape( + shapeType=p.GEOM_MESH, + # vertices=vertices, + # indices=faces, + fileName=obj_file_dir, + meshScale=mesh_scale, + physicsClientId=self.client, + ) # Get the orientation matrix rot_matrix = body.get_orientation() @@ -460,27 +626,40 @@ def add_body_to_simulation(self, body): quaternion = r.as_quat() # Create a multibody with the created shapes - pyb_id = p.createMultiBody(baseMass=body.mass, - baseCollisionShapeIndex=collisionShapeId, - baseVisualShapeIndex=visualShapeId, - basePosition=body.get_position(), - baseOrientation=quaternion, - physicsClientId=self.client, - ) - + pyb_id = p.createMultiBody( + baseMass=body.mass, + baseCollisionShapeIndex=collisionShapeId, + baseVisualShapeIndex=visualShapeId, + basePosition=body.get_position(), + baseOrientation=quaternion, + physicsClientId=self.client, + ) # Set physical properties - p.changeDynamics(pyb_id, -1, restitution=body.restitution, lateralFriction=body.friction, - linearDamping=body.damping, physicsClientId=self.client) + p.changeDynamics( + pyb_id, + -1, + restitution=body.restitution, + lateralFriction=body.friction, + linearDamping=body.damping, + physicsClientId=self.client, + ) # Set initial velocity if specified - if body.velocity != [0,0,0] or body.angular_velocity != [0,0,0]: - p.resetBaseVelocity(pyb_id, linearVelocity=body.velocity, angularVelocity = body.angular_velocity, physicsClientId=self.client) -# + if body.velocity != [0, 0, 0] or body.angular_velocity != [0, 0, 0]: + p.resetBaseVelocity( + pyb_id, + linearVelocity=body.velocity, + angularVelocity=body.angular_velocity, + physicsClientId=self.client, + ) + # # If texture is specified, load it if body.texture is not None: textureId = p.loadTexture(body.texture, physicsClientId=self.client) - p.changeVisualShape(pyb_id, -1, textureUniqueId=textureId, physicsClientId=self.client) + p.changeVisualShape( + pyb_id, -1, textureUniqueId=textureId, physicsClientId=self.client + ) # Add to mapping from pybullet id to body id self.pyb_id_to_body_id[pyb_id] = body.id 
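One step in `add_body_to_simulation` deserves a gloss: `Body` stores a 4x4 homogeneous pose, while `p.createMultiBody` expects a position vector plus an `[x, y, z, w]` quaternion, so the rotation block goes through SciPy (a sketch; `r` in the code above is presumably `R.from_matrix(rot_matrix)` from the elided context):

```python
# Sketch of the pose handoff in add_body_to_simulation: split a 4x4 pose into
# the pieces PyBullet's createMultiBody expects.
import numpy as np
from scipy.spatial.transform import Rotation as R

pose = np.eye(4)  # hypothetical body pose
position = pose[:3, 3]
quaternion = R.from_matrix(pose[:3, :3]).as_quat()  # SciPy returns [x, y, z, w]
# p.createMultiBody(..., basePosition=position, baseOrientation=quaternion)
```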
@@ -493,47 +672,78 @@ def check_collision(self): print(f"Body {body} is colliding.") def step_simulation(self): - self.step_count+=1 + self.step_count += 1 p.stepSimulation(physicsClientId=self.client) - + def update_body_poses(self): for pyb_id in self.pyb_id_to_body_id.keys(): - position, orientation = p.getBasePositionAndOrientation(pyb_id, physicsClientId=self.client) + position, orientation = p.getBasePositionAndOrientation( + pyb_id, physicsClientId=self.client + ) orientation = p.getRotationMatrixFromQuaternion(orientation) pose = np.eye(4) pose[:3, :3] = orientation pose[:3, 3] = position self.body_poses[self.pyb_id_to_body_id[pyb_id]].append(pose) - - def simulate(self, steps): + + def simulate(self, steps): # returns frames, poses of objects over time for i in range(steps): # if i % self.downsampling == 0: - # rgb, depth, segm = self.capture_image(self.camera) - # self.frames.append(rgb) - # self.update_body_poses() + # rgb, depth, segm = self.capture_image(self.camera) + # self.frames.append(rgb) + # self.update_body_poses() self.step_simulation() self.close() - + def capture_image(self, camera): if camera.position_target: - projMatrix = p.computeProjectionMatrixFOV(fov=camera.fov, aspect=float(camera.width) / camera.height, nearVal=camera.near, farVal=camera.far) - viewMatrix = p.computeViewMatrix(cameraEyePosition=camera.position, cameraTargetPosition=camera.target, cameraUpVector=camera.up_vector) + projMatrix = p.computeProjectionMatrixFOV( + fov=camera.fov, + aspect=float(camera.width) / camera.height, + nearVal=camera.near, + farVal=camera.far, + ) + viewMatrix = p.computeViewMatrix( + cameraEyePosition=camera.position, + cameraTargetPosition=camera.target, + cameraUpVector=camera.up_vector, + ) else: projMatrix = ( - 2*camera.fx/camera.width,0,0,0, - 0,2*camera.fy/camera.height,0,0, - 2*(camera.cx/camera.width)-1,2*(camera.cy/camera.height)-1,-(camera.far+camera.near)/(camera.far-camera.near), - -1,0,0,-2*camera.far*camera.near/(camera.far-camera.near),0 + 2 * camera.fx / camera.width, + 0, + 0, + 0, + 0, + 2 * camera.fy / camera.height, + 0, + 0, + 2 * (camera.cx / camera.width) - 1, + 2 * (camera.cy / camera.height) - 1, + -(camera.far + camera.near) / (camera.far - camera.near), + -1, + 0, + 0, + -2 * camera.far * camera.near / (camera.far - camera.near), + 0, + ) + mat = np.array( + t3d.transform_from_rot( + t3d.rotation_from_axis_angle(jnp.array([1.0, 0.0, 0.0]), jnp.pi) + ) + ) + viewMatrix = tuple( + np.linalg.inv(np.array(camera.pose).dot(mat)).T.reshape(-1) ) - mat = np.array(t3d.transform_from_rot(t3d.rotation_from_axis_angle(jnp.array([1.0, 0.0, 0.0]),jnp.pi))) - viewMatrix = tuple(np.linalg.inv(np.array(camera.pose).dot(mat)).T.reshape(-1)) - _,_, rgb, depth, segmentation = p.getCameraImage(camera.width, camera.height, + _, _, rgb, depth, segmentation = p.getCameraImage( + camera.width, + camera.height, viewMatrix, projMatrix, - renderer=p.ER_BULLET_HARDWARE_OPENGL + renderer=p.ER_BULLET_HARDWARE_OPENGL, ) rgb = np.array(rgb, dtype=np.uint8).reshape((camera.height, camera.width, 4)) @@ -544,27 +754,34 @@ def capture_image(self, camera): depth[depth > camera.far] = 0.0 segmentation = np.array(segmentation).reshape((camera.height, camera.width)) return rgb, depth_buffer, segmentation - + def create_gif(self, path, fps=15): - imageio.mimsave(path, self.frames, duration = (1000 * (1/fps))) - + imageio.mimsave(path, self.frames, duration=(1000 * (1 / fps))) + def set_velocity(self, obj_id): return - + # adjusts the timestep of the simulation def set_timestep(self, 
dt): self.timestep = dt - p.setPhysicsEngineParameter(fixedTimeStep=self.timestep, physicsClientId=self.client) + p.setPhysicsEngineParameter( + fixedTimeStep=self.timestep, physicsClientId=self.client + ) # adjusts the gravity of the simulation def set_gravity(self, g): self.gravity = g - p.setGravity(self.gravity[0],self.gravity[1], self.gravity[2], physicsClientId=self.client) + p.setGravity( + self.gravity[0], + self.gravity[1], + self.gravity[2], + physicsClientId=self.client, + ) def close(self): p.resetSimulation(physicsClientId=self.client) p.disconnect(self.client) - + # returns a mapping of body_id to poses over time def get_object_poses(self): return self.body_poses diff --git a/bayes3d/utils/r3d_loader.py b/bayes3d/utils/r3d_loader.py index cb5647c8..5873eb4f 100644 --- a/bayes3d/utils/r3d_loader.py +++ b/bayes3d/utils/r3d_loader.py @@ -4,28 +4,21 @@ import glob import json import os -import yaml -from dataclasses import dataclass from pathlib import Path -from typing import List, Tuple, Union import cv2 +import jax.numpy as jnp import liblzfse # https://pypi.org/project/pyliblzfse/ import numpy as np -import png # pip install pypng -import torch -import tyro from natsort import natsorted from PIL import Image from scipy.spatial.transform import Rotation -from tqdm import tqdm, trange -import subprocess import bayes3d as b -import jax.numpy as jnp + def load_depth(filepath): - with open(filepath, 'rb') as depth_fh: + with open(filepath, "rb") as depth_fh: raw_bytes = depth_fh.read() decompressed_bytes = liblzfse.decompress(raw_bytes) depth_img = np.frombuffer(decompressed_bytes, dtype=np.float32) @@ -37,7 +30,7 @@ def load_depth(filepath): def load_conf(filepath): - with open(filepath, 'rb') as conf_fh: + with open(filepath, "rb") as conf_fh: raw_bytes = conf_fh.read() decompressed_bytes = liblzfse.decompress(raw_bytes) conf_img = np.frombuffer(decompressed_bytes, dtype=np.uint8) @@ -63,9 +56,11 @@ def write_depth(outpath, depth): depth = Image.fromarray(depth) depth.save(outpath) + def write_conf(outpath, conf): np.save(outpath, conf) + def write_pose(outpath, pose): np.save(outpath, pose.astype(np.float32)) @@ -124,11 +119,15 @@ def get_intrinsics(metadata_dict: dict): return intrinsics, intrinsics_depth + def load_r3d(r3d_path): r3d_path = Path(r3d_path) import subprocess + subprocess.run([f"cp {r3d_path} /tmp/{r3d_path.name}.zip"], shell=True) - subprocess.run([f"unzip -qq -o /tmp/{r3d_path.name}.zip -d /tmp/{r3d_path.name}"], shell=True) + subprocess.run( + [f"unzip -qq -o /tmp/{r3d_path.name}.zip -d /tmp/{r3d_path.name}"], shell=True + ) datapath = f"/tmp/{r3d_path.name}" f = open(os.path.join(datapath, "metadata"), "r") @@ -145,14 +144,13 @@ def load_r3d(r3d_path): depths[np.isnan(depths)] = 0.0 poses = get_poses(metadata) - P = np.array( - [ - [1, 0, 0, 0], - [0, -1, 0, 0], - [0, 0, -1, 0], - [0, 0, 0, 1] - ] + P = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]) + poses = P @ poses @ P.T + + return ( + jnp.array(colors), + jnp.array(depths), + jnp.array(poses), + intrinsics, + intrinsics_depth, ) - poses = (P @ poses @ P.T) - - return jnp.array(colors), jnp.array(depths), jnp.array(poses), intrinsics, intrinsics_depth \ No newline at end of file diff --git a/bayes3d/utils/utils.py b/bayes3d/utils/utils.py index d0928e6a..01cccfd8 100644 --- a/bayes3d/utils/utils.py +++ b/bayes3d/utils/utils.py @@ -1,43 +1,51 @@ -import jax.numpy as jnp -import numpy as np +import inspect +import os +import subprocess as sp +import time +from pathlib import Path from 
typing import Tuple -import jax + import cv2 -import bayes3d.transforms_3d as t3d -import bayes3d as b -import os +import jax +import jax.numpy as jnp +import numpy as np import pyransac3d import sklearn.cluster -from jax.scipy.special import logsumexp -import time -import subprocess as sp -import os -import inspect import torch -from pathlib import Path +from jax.scipy.special import logsumexp + +import bayes3d as b +import bayes3d.transforms_3d as t3d + def video_to_images(video_path, image_directory): import cv2 + vidcap = cv2.VideoCapture(str(video_path)) - success,image = vidcap.read() + success, image = vidcap.read() count = 0 while success: - cv2.imwrite(str(Path(image_directory) / Path(f"frame_{count:05}.jpg")), image) # save frame as JPEG file - success,image = vidcap.read() - print('Read a new frame: ', success) + cv2.imwrite( + str(Path(image_directory) / Path(f"frame_{count:05}.jpg")), image + ) # save frame as JPEG file + success, image = vidcap.read() + print("Read a new frame: ", success) count += 1 def make_onehot(n, i, hot=1, cold=0): return tuple(cold if j != i else hot for j in range(n)) + def multivmap(f, args=None): if args is None: args = (True,) * len(inspect.signature(f).parameters) multivmapped = f - for (i, ismapped) in reversed(list(enumerate(args))): + for i, ismapped in reversed(list(enumerate(args))): if ismapped: - multivmapped = jax.vmap(multivmapped, in_axes=make_onehot(len(args), i, hot=0, cold=None)) + multivmapped = jax.vmap( + multivmapped, in_axes=make_onehot(len(args), i, hot=0, cold=None) + ) return multivmapped @@ -46,10 +54,14 @@ def time_code_block(func, args): output = func(*args) print(output[0]) end = time.time() - print ("Time elapsed:", end - start) + print("Time elapsed:", end - start) + def get_assets_dir(): - return os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))),"assets") + return os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "assets" + ) + def make_cube_point_cloud(side_width, num_points): """ @@ -61,31 +73,33 @@ def make_cube_point_cloud(side_width, num_points): object_model_cloud: (N,3) point cloud of the cube """ side_half_width = side_width / 2.0 - single_side = np.stack(np.meshgrid( - np.linspace(-side_half_width, side_half_width, num_points), - np.linspace(-side_half_width, side_half_width, num_points), - np.linspace(0.0, 0.0, num_points) - ), - axis=-1 - ).reshape(-1,3) + single_side = np.stack( + np.meshgrid( + np.linspace(-side_half_width, side_half_width, num_points), + np.linspace(-side_half_width, side_half_width, num_points), + np.linspace(0.0, 0.0, num_points), + ), + axis=-1, + ).reshape(-1, 3) all_faces = [] - for a in [0,1,2]: - for side in [-1.0, 1.0]: + for a in [0, 1, 2]: + for side in [-1.0, 1.0]: perm = np.arange(3) perm[a] = 2 perm[2] = a - face = single_side[:,perm] - face[:,a] = side * side_half_width + face = single_side[:, perm] + face[:, a] = side * side_half_width all_faces.append(face) object_model_cloud = np.vstack(all_faces) return jnp.array(object_model_cloud) + def discretize(data, resolution): """ - Discretizes a point cloud. + Discretizes a point cloud. """ - return jnp.round(data /resolution) * resolution + return jnp.round(data / resolution) * resolution def voxelize(data, resolution): @@ -101,6 +115,7 @@ def voxelize(data, resolution): data = jnp.unique(data, axis=0) return data + def aabb(object_points): """ Returns the axis aligned bounding box of a point cloud. 
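Since `multivmap` above is easy to misread, here is a usage sketch (assuming only the definition in this hunk): it nests one `jax.vmap` per argument flagged as mapped, so a scalar function gains one output axis per mapped argument, outermost first.

```python
# Hypothetical example for multivmap: map a scalar function over both
# arguments to get a (len(xs), len(ys)) grid of results.
import jax.numpy as jnp

def pairwise_distance(x, y):
    return jnp.linalg.norm(x - y)

xs = jnp.zeros((5, 3))
ys = jnp.ones((7, 3))
dists = multivmap(pairwise_distance, (True, True))(xs, ys)  # shape (5, 7)
```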
@@ -110,32 +125,36 @@ def aabb(object_points): dims: (3,) dimensions of the bounding box pose: (4,4) pose of the bounding box """ - maxs = jnp.max(object_points,axis=0) - mins = jnp.min(object_points,axis=0) - dims = (maxs - mins) + maxs = jnp.max(object_points, axis=0) + mins = jnp.min(object_points, axis=0) + dims = maxs - mins center = (maxs + mins) / 2 return dims, t3d.transform_from_pos(center) + def bounding_box_corners(dim): """ Returns the corners of an axis aligned bounding box. Args: dim: (3,) dimensions of the bounding box Returns: - corners: (8,3) corners of the bounding box + corners: (8,3) corners of the bounding box """ - corners = np.array([ - [-dim[0]/2, -dim[1]/2, -dim[2]/2], - [dim[0]/2, -dim[1]/2, -dim[2]/2], - [-dim[0]/2, dim[1]/2, -dim[2]/2], - [dim[0]/2, dim[1]/2, -dim[2]/2], - [-dim[0]/2, -dim[1]/2, dim[2]/2], - [dim[0]/2, -dim[1]/2, dim[2]/2], - [-dim[0]/2, dim[1]/2, dim[2]/2], - [dim[0]/2, dim[1]/2, dim[2]/2] - ]) + corners = np.array( + [ + [-dim[0] / 2, -dim[1] / 2, -dim[2] / 2], + [dim[0] / 2, -dim[1] / 2, -dim[2] / 2], + [-dim[0] / 2, dim[1] / 2, -dim[2] / 2], + [dim[0] / 2, dim[1] / 2, -dim[2] / 2], + [-dim[0] / 2, -dim[1] / 2, dim[2] / 2], + [dim[0] / 2, -dim[1] / 2, dim[2] / 2], + [-dim[0] / 2, dim[1] / 2, dim[2] / 2], + [dim[0] / 2, dim[1] / 2, dim[2] / 2], + ] + ) return corners + def plane_eq_to_plane_pose(plane_eq): """ Returns the pose of a plane from its equation. @@ -153,32 +172,46 @@ def plane_eq_to_plane_pose(plane_eq): plane_pose = b.t3d.transform_from_rot_and_pos(R, point_on_plane) return plane_pose -def find_plane(point_cloud, threshold, minPoints=100, maxIteration=1000): + +def find_plane(point_cloud, threshold, minPoints=100, maxIteration=1000): """ Returns the pose of a plane from a point cloud. """ plane = pyransac3d.Plane() - plane_eq, _ = plane.fit(point_cloud, threshold, minPoints=minPoints, maxIteration=maxIteration) + plane_eq, _ = plane.fit( + point_cloud, threshold, minPoints=minPoints, maxIteration=maxIteration + ) plane_pose = plane_eq_to_plane_pose(plane_eq) return plane_pose + def get_bounding_box_z_axis_aligned(point_cloud): """ Returns the axis aligned bounding box of a point cloud. """ dims, pose = aabb(point_cloud) point_cloud_centered = t3d.apply_transform(point_cloud, t3d.inverse_pose(pose)) - - (cx,cy), (width,height), rotation_deg = cv2.minAreaRect(np.array(point_cloud_centered[:,:2])) + + (cx, cy), (width, height), rotation_deg = cv2.minAreaRect( + np.array(point_cloud_centered[:, :2]) + ) pose_shift = t3d.transform_from_rot_and_pos( - t3d.rotation_from_axis_angle(jnp.array([0.0, 0.0, 1.0]), jnp.deg2rad(rotation_deg)), - jnp.array([cx,cy, 0.0]) + t3d.rotation_from_axis_angle( + jnp.array([0.0, 0.0, 1.0]), jnp.deg2rad(rotation_deg) + ), + jnp.array([cx, cy, 0.0]), ) new_pose = pose @ pose_shift - dims, _ = aabb( t3d.apply_transform(point_cloud, t3d.inverse_pose(new_pose))) + dims, _ = aabb(t3d.apply_transform(point_cloud, t3d.inverse_pose(new_pose))) return dims, new_pose -def find_plane_and_dims(point_cloud, ransac_threshold=0.001, inlier_threshold=0.002, segmentation_threshold=0.008): + +def find_plane_and_dims( + point_cloud, + ransac_threshold=0.001, + inlier_threshold=0.002, + segmentation_threshold=0.008, +): """ Returns the pose of a plane from a point cloud. Args: @@ -189,55 +222,75 @@ def find_plane_and_dims(point_cloud, ransac_threshold=0.001, inlier_threshold=0. 
""" plane_pose = find_plane(np.array(point_cloud), ransac_threshold) points_in_plane_frame = t3d.apply_transform(point_cloud, jnp.linalg.inv(plane_pose)) - inliers = (jnp.abs(points_in_plane_frame[:,2]) < inlier_threshold) + inliers = jnp.abs(points_in_plane_frame[:, 2]) < inlier_threshold inlier_plane_points = points_in_plane_frame[inliers] - inlier_table_points_seg = segment_point_cloud(inlier_plane_points, segmentation_threshold) + inlier_table_points_seg = segment_point_cloud( + inlier_plane_points, segmentation_threshold + ) + + most_frequent_seg_id = get_largest_cluster_id_from_segmentation( + inlier_table_points_seg + ) - most_frequent_seg_id = get_largest_cluster_id_from_segmentation(inlier_table_points_seg) - - table_points_in_plane_frame = inlier_plane_points[inlier_table_points_seg == most_frequent_seg_id] + table_points_in_plane_frame = inlier_plane_points[ + inlier_table_points_seg == most_frequent_seg_id + ] - (cx,cy), (width,height), rotation_deg = cv2.minAreaRect(np.array(table_points_in_plane_frame[:,:2])) + (cx, cy), (width, height), rotation_deg = cv2.minAreaRect( + np.array(table_points_in_plane_frame[:, :2]) + ) pose_shift = t3d.transform_from_rot_and_pos( - t3d.rotation_from_axis_angle(jnp.array([0.0, 0.0, 1.0]), jnp.deg2rad(rotation_deg)), - jnp.array([cx,cy, 0.0]) + t3d.rotation_from_axis_angle( + jnp.array([0.0, 0.0, 1.0]), jnp.deg2rad(rotation_deg) + ), + jnp.array([cx, cy, 0.0]), ) table_pose = plane_pose.dot(pose_shift) table_dims = jnp.array([width, height, 1e-10]) return table_pose, table_dims + def segment_point_cloud(point_cloud, threshold=0.01, min_points_in_cluster=0): c = sklearn.cluster.DBSCAN(eps=threshold).fit(point_cloud) labels = c.labels_ - unique, counts = np.unique(labels, return_counts=True) + unique, counts = np.unique(labels, return_counts=True) for val in unique[counts < min_points_in_cluster]: labels[labels == val] = -1 return labels -def segment_point_cloud_image(point_cloud_image, threshold=0.01, min_points_in_cluster=0): - point_cloud = point_cloud_image.reshape(-1,3) - non_zero = point_cloud[:,2] > 0.0 + +def segment_point_cloud_image( + point_cloud_image, threshold=0.01, min_points_in_cluster=0 +): + point_cloud = point_cloud_image.reshape(-1, 3) + non_zero = point_cloud[:, 2] > 0.0 segmentation_img = np.ones(point_cloud.shape[0]) * -1.0 if non_zero.sum() == 0: return segmentation_img non_zero_indices = np.where(non_zero)[0] - segmentation = segment_point_cloud(point_cloud[non_zero_indices,:], threshold=threshold, min_points_in_cluster=min_points_in_cluster) - unique, counts = np.unique(segmentation, return_counts=True) - for (i,val) in enumerate(unique[unique != -1]): + segmentation = segment_point_cloud( + point_cloud[non_zero_indices, :], + threshold=threshold, + min_points_in_cluster=min_points_in_cluster, + ) + unique, counts = np.unique(segmentation, return_counts=True) + for i, val in enumerate(unique[unique != -1]): segmentation_img[non_zero_indices[segmentation == val]] = i segmentation_img = segmentation_img.reshape(point_cloud_image.shape[:2]) return segmentation_img + def get_largest_cluster_id_from_segmentation(segmentation_array_or_img): """ Returns the id of the largest cluster in a segmentation. 
""" - unique, counts = jnp.unique(segmentation_array_or_img, return_counts=True) - non_neg_one = (unique != -1) + unique, counts = jnp.unique(segmentation_array_or_img, return_counts=True) + non_neg_one = unique != -1 unique = unique[non_neg_one] counts = counts[non_neg_one] return unique[counts.argmax()] + def normalize_log_scores(log_p): """ Normalizes log scores. @@ -248,6 +301,7 @@ def normalize_log_scores(log_p): """ return jnp.exp(log_p - logsumexp(log_p)) + def resize(depth, h, w): """ Resizes a depth image. @@ -258,7 +312,10 @@ def resize(depth, h, w): Returns: depth: (h,w) resized depth image """ - return cv2.resize(np.asarray(depth, dtype=depth.dtype), (w,h),interpolation=0).astype(depth.dtype) + return cv2.resize( + np.asarray(depth, dtype=depth.dtype), (w, h), interpolation=0 + ).astype(depth.dtype) + def scale(depth, h, w): """ @@ -270,38 +327,63 @@ def scale(depth, h, w): """ return resize(depth, h, w) -def infer_table_plane(point_cloud_image, camera_pose, intrinsics, ransac_threshold=0.001, inlier_threshold=0.002, segmentation_threshold=0.008): + +def infer_table_plane( + point_cloud_image, + camera_pose, + intrinsics, + ransac_threshold=0.001, + inlier_threshold=0.002, + segmentation_threshold=0.008, +): point_cloud_flat = point_cloud_image.reshape(-1, 3) - point_cloud_flat_not_far = point_cloud_flat[point_cloud_flat[:,2] < intrinsics.far, :] + point_cloud_flat_not_far = point_cloud_flat[ + point_cloud_flat[:, 2] < intrinsics.far, : + ] table_pose, table_dims = find_plane_and_dims( - t3d.apply_transform(point_cloud_flat_not_far, camera_pose), - ransac_threshold=ransac_threshold, inlier_threshold=inlier_threshold, segmentation_threshold=segmentation_threshold + t3d.apply_transform(point_cloud_flat_not_far, camera_pose), + ransac_threshold=ransac_threshold, + inlier_threshold=inlier_threshold, + segmentation_threshold=segmentation_threshold, ) table_pose_in_cam_frame = t3d.inverse_pose(camera_pose) @ table_pose - if table_pose_in_cam_frame[2,2] > 0: - table_pose = table_pose @ t3d.transform_from_axis_angle(jnp.array([1.0, 0.0, 0.0]), jnp.pi) + if table_pose_in_cam_frame[2, 2] > 0: + table_pose = table_pose @ t3d.transform_from_axis_angle( + jnp.array([1.0, 0.0, 0.0]), jnp.pi + ) return table_pose, table_dims + def get_gpu_memory(): command = "nvidia-smi --query-gpu=memory.free --format=csv" - memory_free_info = sp.check_output(command.split()).decode('ascii').split('\n')[:-1][1:] + memory_free_info = ( + sp.check_output(command.split()).decode("ascii").split("\n")[:-1][1:] + ) memory_free_values = [int(x.split()[0]) for i, x in enumerate(memory_free_info)] return memory_free_values + def make_schedules_contact_params(grid_widths, rotation_angle_widths, grid_params): sched = [] - for (grid_width, angle_width, grid_param) in zip(grid_widths, rotation_angle_widths, grid_params): + for grid_width, angle_width, grid_param in zip( + grid_widths, rotation_angle_widths, grid_params + ): cf = b.scene_graph.enumerate_contact_and_face_parameters( - -grid_width, -grid_width, -angle_width, - +grid_width, +grid_width, angle_width, + -grid_width, + -grid_width, + -angle_width, + +grid_width, + +grid_width, + angle_width, *grid_param, # *grid_param is num_x, num_y, num_angle - jnp.arange(6) + jnp.arange(6), ) sched.append(cf) return sched + def extract_2d_patches(data: jnp.ndarray, filter_shape: Tuple[int, int]) -> jnp.ndarray: """For each pixel, extract 2D patches centered at that pixel. 
Args: @@ -339,8 +421,10 @@ def extract_2d_patches(data: jnp.ndarray, filter_shape: Tuple[int, int]) -> jnp. ) return extracted_patches + def jax_to_torch(jax_array): return torch.utils.dlpack.from_dlpack(jax.dlpack.to_dlpack(jax_array)) + def torch_to_jax(torch_array): - return jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(torch_array)) \ No newline at end of file + return jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(torch_array)) diff --git a/bayes3d/utils/ycb_loader.py b/bayes3d/utils/ycb_loader.py index b50fbbf7..46480432 100644 --- a/bayes3d/utils/ycb_loader.py +++ b/bayes3d/utils/ycb_loader.py @@ -1,124 +1,128 @@ -import sys +import json import os -sys.path.append(os.getcwd()) - -from dataclasses import dataclass from glob import glob -from PIL import Image -import json -import bayes3d as j import jax.numpy as jnp import numpy as np -import pickle - -MODEL_NAMES = ["002_master_chef_can" -,"003_cracker_box" -,"004_sugar_box" -,"005_tomato_soup_can" -,"006_mustard_bottle" -,"007_tuna_fish_can" -,"008_pudding_box" -,"009_gelatin_box" -,"010_potted_meat_can" -,"011_banana" -,"019_pitcher_base" -,"021_bleach_cleanser" -,"024_bowl" -,"025_mug" -,"035_power_drill" -,"036_wood_block" -,"037_scissors" -,"040_large_marker" -,"051_large_clamp" -,"052_extra_large_clamp" -,"061_foam_brick"] +from PIL import Image + +import bayes3d as j + +MODEL_NAMES = [ + "002_master_chef_can", + "003_cracker_box", + "004_sugar_box", + "005_tomato_soup_can", + "006_mustard_bottle", + "007_tuna_fish_can", + "008_pudding_box", + "009_gelatin_box", + "010_potted_meat_can", + "011_banana", + "019_pitcher_base", + "021_bleach_cleanser", + "024_bowl", + "025_mug", + "035_power_drill", + "036_wood_block", + "037_scissors", + "040_large_marker", + "051_large_clamp", + "052_extra_large_clamp", + "061_foam_brick", +] + def remove_zero_pad(img_id): for i, ch in enumerate(img_id): - if ch != '0': + if ch != "0": return img_id[i:] def get_test_img(scene_id, img_id, ycb_dir): if len(scene_id) < 6: - scene_id = scene_id.rjust(6, '0') + scene_id = scene_id.rjust(6, "0") if len(img_id) < 6: - img_id = img_id.rjust(6, '0') + img_id = img_id.rjust(6, "0") data_dir = os.path.join(ycb_dir, "test") - scene_data_dir = os.path.join(data_dir, scene_id) # depth, mask, mask_visib, rgb; scene_camera.json, scene_gt_info.json, scene_gt.json + scene_data_dir = os.path.join( + data_dir, scene_id + ) # depth, mask, mask_visib, rgb; scene_camera.json, scene_gt_info.json, scene_gt.json - scene_rgb_images_dir = os.path.join(scene_data_dir, 'rgb') - scene_depth_images_dir = os.path.join(scene_data_dir, 'depth') - mask_visib_dir = os.path.join(scene_data_dir, 'mask_visib') + scene_rgb_images_dir = os.path.join(scene_data_dir, "rgb") + scene_depth_images_dir = os.path.join(scene_data_dir, "depth") + mask_visib_dir = os.path.join(scene_data_dir, "mask_visib") with open(os.path.join(scene_data_dir, "scene_camera.json")) as scene_cam_data_json: scene_cam_data = json.load(scene_cam_data_json) with open(os.path.join(scene_data_dir, "scene_gt.json")) as scene_imgs_gt_data_json: scene_imgs_gt_data = json.load(scene_imgs_gt_data_json) - + # get rgb image rgb = jnp.array(Image.open(os.path.join(scene_rgb_images_dir, f"{img_id}.png"))) # get depth image depth = jnp.array(Image.open(os.path.join(scene_depth_images_dir, f"{img_id}.png"))) - + # get camera intrinsics and pose for image image_cam_data = scene_cam_data[remove_zero_pad(img_id)] - cam_K = jnp.array(image_cam_data['cam_K']).reshape(3,3) - cam_R_w2c = 
jnp.array(image_cam_data['cam_R_w2c']).reshape(3,3) - cam_t_w2c = jnp.array(image_cam_data['cam_t_w2c']).reshape(3,1) - cam_pose_w2c = jnp.vstack([jnp.hstack([cam_R_w2c, cam_t_w2c]), jnp.array([0,0,0,1])]) + cam_K = jnp.array(image_cam_data["cam_K"]).reshape(3, 3) + cam_R_w2c = jnp.array(image_cam_data["cam_R_w2c"]).reshape(3, 3) + cam_t_w2c = jnp.array(image_cam_data["cam_t_w2c"]).reshape(3, 1) + cam_pose_w2c = jnp.vstack( + [jnp.hstack([cam_R_w2c, cam_t_w2c]), jnp.array([0, 0, 0, 1])] + ) cam_pose = jnp.linalg.inv(cam_pose_w2c) - cam_depth_scale = image_cam_data['depth_scale'] + cam_depth_scale = image_cam_data["depth_scale"] # get {visible mask, ID, pose} for each object in the scene anno = dict() # get GT object model ID+poses objects_gt_data = scene_imgs_gt_data[remove_zero_pad(img_id)] - mask_visib_image_paths = sorted(glob(os.path.join(mask_visib_dir, f"{img_id}_*.png"))) + mask_visib_image_paths = sorted( + glob(os.path.join(mask_visib_dir, f"{img_id}_*.png")) + ) gt_ids = [] - anno = [] gt_poses = [] masks = [] - for object_gt_data, mask_visib_image_path in zip(objects_gt_data, mask_visib_image_paths): + for object_gt_data, mask_visib_image_path in zip( + objects_gt_data, mask_visib_image_paths + ): mask_visible = jnp.array(Image.open(mask_visib_image_path)) - model_R = jnp.array(object_gt_data['cam_R_m2c']).reshape(3,3) - model_t = jnp.array(object_gt_data['cam_t_m2c']).reshape(3,1) - model_pose = jnp.vstack([jnp.hstack([model_R, model_t]), jnp.array([0,0,0,1])]) - model_pose = model_pose.at[:3,3].set(model_pose[:3,3]*1.0/1000.0) + model_R = jnp.array(object_gt_data["cam_R_m2c"]).reshape(3, 3) + model_t = jnp.array(object_gt_data["cam_t_m2c"]).reshape(3, 1) + model_pose = jnp.vstack( + [jnp.hstack([model_R, model_t]), jnp.array([0, 0, 0, 1])] + ) + model_pose = model_pose.at[:3, 3].set(model_pose[:3, 3] * 1.0 / 1000.0) gt_poses.append(model_pose) - - obj_id = object_gt_data['obj_id'] - 1 + + obj_id = object_gt_data["obj_id"] - 1 gt_ids.append(obj_id) masks.append(jnp.array(mask_visible > 0)) - cam_pose = cam_pose.at[:3,3].set(cam_pose[:3,3]*1.0/1000.0) + cam_pose = cam_pose.at[:3, 3].set(cam_pose[:3, 3] * 1.0 / 1000.0) cam_K = np.array(cam_K) intrinsics = j.Intrinsics( rgb.shape[0], rgb.shape[1], - cam_K[0,0], - cam_K[1,1], - cam_K[0,2], - cam_K[1,2], + cam_K[0, 0], + cam_K[1, 1], + cam_K[0, 2], + cam_K[1, 2], 0.01, - 2.0 + 2.0, ) return ( - j.RGBD( - rgb, - depth * cam_depth_scale / 1000.0, - cam_pose, - intrinsics - ), - jnp.array(gt_ids), jnp.array(gt_poses), masks + j.RGBD(rgb, depth * cam_depth_scale / 1000.0, cam_pose, intrinsics), + jnp.array(gt_ids), + jnp.array(gt_poses), + masks, ) - diff --git a/bayes3d/viz/meshcatviz.py b/bayes3d/viz/meshcatviz.py index 12211635..da0b6d78 100644 --- a/bayes3d/viz/meshcatviz.py +++ b/bayes3d/viz/meshcatviz.py @@ -1,10 +1,10 @@ +import jax.numpy as jnp import meshcat -import numpy as np import meshcat.geometry as g +import numpy as np from matplotlib.colors import rgb2hex -import bayes3d.transforms_3d as t3d -import jax.numpy as jnp +import bayes3d.transforms_3d as t3d RED = np.array([1.0, 0.0, 0.0]) GREEN = np.array([0.0, 1.0, 0.0]) @@ -12,25 +12,31 @@ VISUALIZER = None + def setup_visualizer(): global VISUALIZER VISUALIZER = meshcat.Visualizer() set_background_color([1, 1, 1]) + def get_visualizer(): global VISUALIZER return VISUALIZER + def set_background_color(color): VISUALIZER["/Background"].set_property("top_color", color) VISUALIZER["/Background"].set_property("bottom_color", color) + def clear_visualizer(): global VISUALIZER 
VISUALIZER.delete() + def set_pose(channel, pose): - VISUALIZER[channel].set_transform(np.array(pose,dtype=np.float64)) + VISUALIZER[channel].set_transform(np.array(pose, dtype=np.float64)) + def show_cloud(channel, point_cloud, color=None, size=0.01): global VISUALIZER @@ -40,7 +46,7 @@ def show_cloud(channel, point_cloud, color=None, size=0.01): if color is None: color = np.zeros_like(point_cloud) elif len(color.shape) == 1: - color = np.tile(color.reshape(-1,1), (1,point_cloud.shape[1])) + color = np.tile(color.reshape(-1, 1), (1, point_cloud.shape[1])) color = np.array(color) obj = g.PointCloud(point_cloud, color, size=size) VISUALIZER[channel].set_object(obj) @@ -50,27 +56,26 @@ def show_trimesh(channel, mesh, color=None, wireframe=False, opacity=1.0): global VISUALIZER if color is None: color = [1, 0, 0] - material = g.MeshLambertMaterial(color=int(rgb2hex(color)[1:],16), wireframe=wireframe, opacity=opacity) + material = g.MeshLambertMaterial( + color=int(rgb2hex(color)[1:], 16), wireframe=wireframe, opacity=opacity + ) obj = g.TriangularMeshGeometry(mesh.vertices, mesh.faces) VISUALIZER[channel].set_object(obj, material) def show_pose(channel, pose, size=0.1): global VISUALIZER - pose_x = t3d.transform_from_pos(jnp.array([size/2.0, 0.0, 0.0])) - objx = g.Box(np.array([size, size/10.0, size/10.0])) - matx = g.MeshLambertMaterial(color=0xf41515, - reflectivity=0.8) - - pose_y = t3d.transform_from_pos(jnp.array([0.0, size/2.0, 0.0])) - objy = g.Box(np.array([size/10.0, size, size/10.0])) - maty = g.MeshLambertMaterial(color=0x40ec00, - reflectivity=0.8) - - pose_z = t3d.transform_from_pos(jnp.array([0.0, 0.0, size/2.0])) - objz = g.Box(np.array([size/10.0, size/10.0, size])) - matz = g.MeshLambertMaterial(color=0x0b5cfc, - reflectivity=0.8) + pose_x = t3d.transform_from_pos(jnp.array([size / 2.0, 0.0, 0.0])) + objx = g.Box(np.array([size, size / 10.0, size / 10.0])) + matx = g.MeshLambertMaterial(color=0xF41515, reflectivity=0.8) + + pose_y = t3d.transform_from_pos(jnp.array([0.0, size / 2.0, 0.0])) + objy = g.Box(np.array([size / 10.0, size, size / 10.0])) + maty = g.MeshLambertMaterial(color=0x40EC00, reflectivity=0.8) + + pose_z = t3d.transform_from_pos(jnp.array([0.0, 0.0, size / 2.0])) + objz = g.Box(np.array([size / 10.0, size / 10.0, size])) + matz = g.MeshLambertMaterial(color=0x0B5CFC, reflectivity=0.8) VISUALIZER[channel]["x"].set_object(objx, matx) VISUALIZER[channel]["x"].set_transform(np.array(pose @ pose_x, dtype=np.float64)) diff --git a/bayes3d/viz/open3dviz.py b/bayes3d/viz/open3dviz.py index e5b1ed52..273c8f8c 100644 --- a/bayes3d/viz/open3dviz.py +++ b/bayes3d/viz/open3dviz.py @@ -1,79 +1,94 @@ -import open3d as o3d +import jax.numpy as jnp import numpy as np -import bayes3d as j +import open3d as o3d + import bayes3d as b -import jax.numpy as jnp +import bayes3d as j + def trimesh_to_o3d_triangle_mesh(trimesh_mesh): mesh = o3d.geometry.TriangleMesh() - mesh.vertices = o3d.utility.Vector3dVector(trimesh_mesh.vertices) + mesh.vertices = o3d.utility.Vector3dVector(trimesh_mesh.vertices) mesh.triangles = o3d.utility.Vector3iVector(trimesh_mesh.faces) - mesh.triangle_normals = o3d.utility.Vector3dVector(np.array(trimesh_mesh.face_normals)) + mesh.triangle_normals = o3d.utility.Vector3dVector( + np.array(trimesh_mesh.face_normals) + ) return mesh + class Open3DVisualizer(object): def __init__(self, intrinsics): - self.render = o3d.visualization.rendering.OffscreenRenderer(intrinsics.width, intrinsics.height) + self.render = o3d.visualization.rendering.OffscreenRenderer( 
+ intrinsics.width, intrinsics.height + ) # self.set_background(np.array([0.0, 0.0, 0.0, 0.0])) self.render.scene.set_background(np.array([1.0, 1.0, 1.0, 1.0])) - self.render.scene.set_lighting(self.render.scene.LightingProfile.NO_SHADOWS, (0, 0, 0)) + self.render.scene.set_lighting( + self.render.scene.LightingProfile.NO_SHADOWS, (0, 0, 0) + ) self.counter = 0 def set_background(self, background): self.render.scene.set_background(background) - def make_bounding_box(self, dims, pose, color=None, update=True): line_set = o3d.geometry.LineSet() if color is None: - color = j.RED - - points = np.zeros((9,3)) - points[0, :] = np.array([dims[0]/2, -dims[1]/2, dims[2]/2] ) - points[1, :] = np.array([-dims[0]/2, -dims[1]/2, dims[2]/2]) - points[2, :] = np.array([-dims[0]/2, dims[1]/2, dims[2]/2]) - points[3, :] = np.array([dims[0]/2, dims[1]/2, dims[2]/2]) - points[4, :] = np.array([dims[0]/2, -dims[1]/2, -dims[2]/2]) - points[5, :] = np.array([-dims[0]/2, -dims[1]/2, -dims[2]/2]) - points[6, :] = np.array([-dims[0]/2, dims[1]/2, -dims[2]/2]) - points[7, :] = np.array([dims[0]/2, dims[1]/2, -dims[2]/2]) + color = j.RED + + points = np.zeros((9, 3)) + points[0, :] = np.array([dims[0] / 2, -dims[1] / 2, dims[2] / 2]) + points[1, :] = np.array([-dims[0] / 2, -dims[1] / 2, dims[2] / 2]) + points[2, :] = np.array([-dims[0] / 2, dims[1] / 2, dims[2] / 2]) + points[3, :] = np.array([dims[0] / 2, dims[1] / 2, dims[2] / 2]) + points[4, :] = np.array([dims[0] / 2, -dims[1] / 2, -dims[2] / 2]) + points[5, :] = np.array([-dims[0] / 2, -dims[1] / 2, -dims[2] / 2]) + points[6, :] = np.array([-dims[0] / 2, dims[1] / 2, -dims[2] / 2]) + points[7, :] = np.array([dims[0] / 2, dims[1] / 2, -dims[2] / 2]) points[8, :] = np.array([0.0, 0.0, 0.0]) new_points = j.t3d.apply_transform(points, pose) - lines = np.array([ - [1,2], - [2,3], - [3,4], - [4,1], - [5,6], - [6,7], - [7,8], - [8,5], - [1,5], - [2,6], - [3,7], - [4,8] - ]) - 1 - - line_set.points = o3d.utility.Vector3dVector(new_points) + lines = ( + np.array( + [ + [1, 2], + [2, 3], + [3, 4], + [4, 1], + [5, 6], + [6, 7], + [7, 8], + [8, 5], + [1, 5], + [2, 6], + [3, 7], + [4, 8], + ] + ) + - 1 + ) + + line_set.points = o3d.utility.Vector3dVector(new_points) line_set.lines = o3d.utility.Vector2iVector(lines) line_set.paint_uniform_color(color) - mtl = o3d.visualization.rendering.MaterialRecord() # or MaterialRecord(), for later versions of Open3D + mtl = ( + o3d.visualization.rendering.MaterialRecord() + ) # or MaterialRecord(), for later versions of Open3D mtl.shader = "defaultUnlit" self.render.scene.add_geometry(f"{self.counter}", line_set, mtl) - self.counter+=1 + self.counter += 1 return line_set def make_cloud(self, cloud, color=None, update=True): if color is None: color = j.BLUE - + if color.shape[0] != cloud.shape[0]: - colors = np.tile(color, (cloud.shape[0],1)) + colors = np.tile(color, (cloud.shape[0], 1)) else: colors = color @@ -81,11 +96,13 @@ def make_cloud(self, cloud, color=None, update=True): pcd.points = o3d.utility.Vector3dVector(cloud) pcd.colors = o3d.utility.Vector3dVector(colors) - mtl = o3d.visualization.rendering.MaterialRecord() # or MaterialRecord(), for later versions of Open3D + mtl = ( + o3d.visualization.rendering.MaterialRecord() + ) # or MaterialRecord(), for later versions of Open3D mtl.shader = "defaultUnlit" self.render.scene.add_geometry(f"{self.counter}", pcd, mtl) - self.counter+=1 + self.counter += 1 return pcd def make_mesh_from_file(self, filename, pose, scaling_factor=1.0): @@ -93,8 +110,8 @@ def 
make_mesh_from_file(self, filename, pose, scaling_factor=1.0): mesh.meshes[0].mesh.scale(scaling_factor, np.array([0.0, 0.0, 0.0])) mesh.meshes[0].mesh.transform(pose) self.render.scene.add_model(f"{self.counter}", mesh) - self.counter+=1 - + self.counter += 1 + def clear(self): self.render.scene.clear_geometry() @@ -102,23 +119,26 @@ def make_trimesh(self, trimesh_mesh, pose, color): mesh = trimesh_to_o3d_triangle_mesh(trimesh_mesh) mesh.transform(pose) mtl = o3d.visualization.rendering.MaterialRecord() - mtl.shader = 'defaultLitTransparency' + mtl.shader = "defaultLitTransparency" mtl.base_color = color self.render.scene.add_geometry(f"{self.counter}", mesh, mtl) - self.counter+=1 + self.counter += 1 def capture_image(self, intrinsics, camera_pose): self.render.scene.camera.set_projection( b.camera.K_from_intrinsics(intrinsics), - intrinsics.near, intrinsics.far, + intrinsics.near, + intrinsics.far, intrinsics.width, - intrinsics.height + intrinsics.height, ) # Look at the origin from the front (along the -Z direction, into the screen), with Y as Up. - center = np.array(camera_pose[:3,3]) + np.array(camera_pose[:3,2]) # look_at target - eye = np.array(camera_pose[:3,3]) # camera position - up = -np.array(camera_pose[:3,1]) + center = np.array(camera_pose[:3, 3]) + np.array( + camera_pose[:3, 2] + ) # look_at target + eye = np.array(camera_pose[:3, 3]) # camera position + up = -np.array(camera_pose[:3, 1]) self.render.scene.camera.look_at(center, eye, up) img = np.array(self.render.render_to_image()) rgb = j.add_rgba_dimension(img) @@ -135,34 +155,37 @@ def make_camera(self, intrinsics, pose, size): height = intrinsics.height color = j.BLUE - dist=size + dist = size vertices = np.zeros((5, 3)) vertices[0, :] = [0, 0, 0] - vertices[1, :] = [(0-cx)*dist/fx, (0-cy)*dist/fy, dist] - vertices[2, :] = [(width-cx)*dist/fx, (0-cy)*dist/fy, dist] - vertices[3, :] = [(width-cx)*dist/fx, (height-cy)*dist/fy, dist] - vertices[4, :] = [(0-cx)*dist/fx, (height-cy)*dist/fy, dist] + vertices[1, :] = [(0 - cx) * dist / fx, (0 - cy) * dist / fy, dist] + vertices[2, :] = [(width - cx) * dist / fx, (0 - cy) * dist / fy, dist] + vertices[3, :] = [(width - cx) * dist / fx, (height - cy) * dist / fy, dist] + vertices[4, :] = [(0 - cx) * dist / fx, (height - cy) * dist / fy, dist] new_points = j.t3d.apply_transform(vertices, pose) - lines = np.array([ - [0,1], - [0,2], - [0,3], - [0,4], - [1,2], - [2,3], - [3,4], - [4,1], - ]) + lines = np.array( + [ + [0, 1], + [0, 2], + [0, 3], + [0, 4], + [1, 2], + [2, 3], + [3, 4], + [4, 1], + ] + ) line_set = o3d.geometry.LineSet() - line_set.points = o3d.utility.Vector3dVector(new_points) + line_set.points = o3d.utility.Vector3dVector(new_points) line_set.lines = o3d.utility.Vector2iVector(lines) line_set.paint_uniform_color(color) - mtl = o3d.visualization.rendering.MaterialRecord() # or MaterialRecord(), for later versions of Open3D + mtl = ( + o3d.visualization.rendering.MaterialRecord() + ) # or MaterialRecord(), for later versions of Open3D mtl.base_color = [0.0, 0.0, 1.0, 1.0] # RGBA mtl.shader = "defaultUnlit" self.render.scene.add_geometry(f"{self.counter}", line_set, mtl) - self.counter+=1 - + self.counter += 1 diff --git a/bayes3d/viz/viz.py b/bayes3d/viz/viz.py index 078173d9..c611825f 100644 --- a/bayes3d/viz/viz.py +++ b/bayes3d/viz/viz.py @@ -1,28 +1,30 @@ -from PIL import Image, ImageDraw, ImageFont -import numpy as np +import copy import os -from PIL import Image -import numpy as np -import bayes3d.utils -import matplotlib.pyplot as plt -import matplotlib 
-import graphviz + import distinctipy +import graphviz import jax.numpy as jnp -import copy +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +from PIL import Image, ImageDraw, ImageFont + +import bayes3d.utils RED = np.array([1.0, 0.0, 0.0]) GREEN = np.array([0.0, 1.0, 0.0]) BLUE = np.array([0.0, 0.0, 1.0]) BLACK = np.array([0.0, 0.0, 0.0]) + def load_image_from_file(filename): """Load an image from a file.""" return Image.open(filename) + def make_gif_from_pil_images(images, filename): """Save a list of PIL images as a GIF. - + Args: images (list): List of PIL images. filename (str): Filename to save GIF to. @@ -36,17 +38,20 @@ def make_gif_from_pil_images(images, filename): loop=0, ) + def preprocess_for_viz(img): depth_np = np.array(img) depth_np[depth_np >= depth_np.max()] = np.inf return depth_np -cmap = copy.copy(plt.get_cmap('turbo')) + +cmap = copy.copy(plt.get_cmap("turbo")) cmap.set_bad(color=(1.0, 1.0, 1.0, 1.0)) + def get_depth_image(image, min_val=None, max_val=None, remove_max=True): """Convert a depth image to a PIL image. - + Args: image (np.ndarray): Depth image. Shape (H, W). min (float): Minimum depth value for colormap. @@ -56,7 +61,7 @@ def get_depth_image(image, min_val=None, max_val=None, remove_max=True): PIL.Image: Depth image visualized as a PIL image. """ if len(image.shape) > 2: - depth = np.array(image[:,:,-1]) + depth = np.array(image[:, :, -1]) else: depth = np.array(image) @@ -64,10 +69,9 @@ def get_depth_image(image, min_val=None, max_val=None, remove_max=True): max_val = depth.max() if not remove_max: max_val += 1 - if min_val is None: + if min_val is None: min_val = depth.min() - - max_threhold = max_val + 1.0 * (not remove_max) + mask = (depth < max_val) * (depth > min_val) depth[np.logical_not(mask)] = np.nan depth = (depth - min_val) / (max_val - min_val + 1e-10) @@ -77,9 +81,10 @@ def get_depth_image(image, min_val=None, max_val=None, remove_max=True): ).convert("RGB") return img + def get_rgb_image(image, max=255.0): """Convert an RGB image to a PIL image. - + Args: image (np.ndarray): RGB image. Shape (H, W, 3). max (float): Maximum value for colormap. @@ -93,29 +98,29 @@ def get_rgb_image(image, max=255.0): image_type = "RGBA" img = Image.fromarray( - np.rint( - image / max * 255.0 - ).astype(np.int8), + np.rint(image / max * 255.0).astype(np.int8), mode=image_type, ).convert("RGB") return img -saveargs = dict(bbox_inches='tight', pad_inches=0) + +saveargs = dict(bbox_inches="tight", pad_inches=0) def add_depth_image(ax, depth): - d = ax.imshow(preprocess_for_viz(depth),cmap=cmap) - ax.axis('off') + d = ax.imshow(preprocess_for_viz(depth), cmap=cmap) + ax.axis("off") return d + def add_rgb_image(ax, rgb): ax.imshow(rgb) - ax.axis('off') + ax.axis("off") def viz_depth_image(depth): """Convert a depth image to a PIL image. - + Args: image (np.ndarray): Depth image. Shape (H, W). min (float): Minimum depth value for colormap. @@ -125,13 +130,14 @@ def viz_depth_image(depth): PIL.Image: Depth image visualized as a PIL image. """ fig = plt.figure() - ax = fig.add_subplot(1,1,1) + ax = fig.add_subplot(1, 1, 1) add_depth_image(ax, depth) return fig + def viz_rgb_image(image): """Convert an RGB image to a PIL image. - + Args: image (np.ndarray): RGB image. Shape (H, W, 3). max (float): Maximum value for colormap. @@ -139,28 +145,33 @@ def viz_rgb_image(image): PIL.Image: RGB image visualized as a PIL image. 
""" fig = plt.figure() - ax = fig.add_subplot(1,1,1) + ax = fig.add_subplot(1, 1, 1) add_rgb_image(ax, image) return fig + def pil_image_from_matplotlib(fig): - img = Image.frombytes('RGBA', fig.canvas.get_width_height(),bytes(fig.canvas.buffer_rgba())) + img = Image.frombytes( + "RGBA", fig.canvas.get_width_height(), bytes(fig.canvas.buffer_rgba()) + ) return img + def add_rgba_dimension(image): """Add an alpha channel to a particle image if it doesn't already have one. - + Args: image (np.ndarray): Particle image. Shape (H, W, 3) or (H, W, 4). """ if image.shape[-1] == 3: - p = jnp.concatenate([image, 255.0 * jnp.ones((*image.shape[:2],1))],axis=-1) + p = jnp.concatenate([image, 255.0 * jnp.ones((*image.shape[:2], 1))], axis=-1) return p return image + def overlay_image(img_1, img_2, alpha=0.5): """Overlay two images. - + Args: img_1 (PIL.Image): First image. img_2 (PIL.Image): Second image. @@ -170,6 +181,7 @@ def overlay_image(img_1, img_2, alpha=0.5): """ return Image.blend(img_1, img_2, alpha=alpha) + def resize_image(img, h, w): """Resize an image. @@ -182,19 +194,21 @@ def resize_image(img, h, w): """ return img.resize((w, h)) + def scale_image(img, factor): """Scale an image. - + Args: img (PIL.Image): Image to scale. factor (float): Scale factor. Returns: PIL.Image: Scaled image. """ - w,h = img.size + w, h = img.size return img.resize((int(w * factor), int(h * factor))) -def vstack_images(images, border = 10): + +def vstack_images(images, border=10): """Stack images vertically. Args: @@ -204,21 +218,22 @@ def vstack_images(images, border = 10): PIL.Image: Stacked image. """ max_w = 0 - sum_h = (len(images)-1)*border + sum_h = (len(images) - 1) * border for img in images: - w,h = img.size + w, h = img.size max_w = max(max_w, w) sum_h += h - full_image = Image.new('RGB', (max_w, sum_h), (255, 255, 255)) + full_image = Image.new("RGB", (max_w, sum_h), (255, 255, 255)) running_h = 0 for img in images: - w,h = img.size - full_image.paste(img, (int(max_w/2 - w/2), running_h)) + w, h = img.size + full_image.paste(img, (int(max_w / 2 - w / 2), running_h)) running_h += h + border return full_image -def hstack_images(images, border = 10): + +def hstack_images(images, border=10): """Stack images horizontally. Args: @@ -228,20 +243,21 @@ def hstack_images(images, border = 10): PIL.Image: Stacked image. """ max_h = 0 - sum_w = (len(images)-1)*border + sum_w = (len(images) - 1) * border for img in images: - w,h = img.size + w, h = img.size max_h = max(max_h, h) sum_w += w - full_image = Image.new('RGB', (sum_w, max_h),(255, 255, 255)) + full_image = Image.new("RGB", (sum_w, max_h), (255, 255, 255)) running_w = 0 for img in images: - w,h = img.size - full_image.paste(img, (running_w, int(max_h/2 - h/2))) + w, h = img.size + full_image.paste(img, (running_w, int(max_h / 2 - h / 2))) running_w += w + border return full_image + def hvstack_images(images, h, w, border=10): """Stack images in a grid. @@ -251,21 +267,31 @@ def hvstack_images(images, h, w, border=10): w (int): Number of columns. border (int): Border between images. Returns: - PIL.Image: Stacked image. + PIL.Image: Stacked image. 
""" assert len(images) == h * w images_to_vstack = [] for row_idx in range(h): - hstacked_row = hstack_images(images[row_idx*w:(row_idx+1)*w]) + hstacked_row = hstack_images(images[row_idx * w : (row_idx + 1) * w]) images_to_vstack.append(hstacked_row) - + return vstack_images(images_to_vstack) -def multi_panel(images, labels=None, title=None, bottom_text=None, title_fontsize=40, label_fontsize=30, bottom_fontsize=20, middle_width=10): + +def multi_panel( + images, + labels=None, + title=None, + bottom_text=None, + title_fontsize=40, + label_fontsize=30, + bottom_fontsize=20, + middle_width=10, +): """Combine multiple images into a single image. - + Args: images (list): List of PIL images. labels (list): List of labels for each image. @@ -285,13 +311,30 @@ def multi_panel(images, labels=None, title=None, bottom_text=None, title_fontsiz sum_of_widths = np.sum([img.width for img in images]) dst = Image.new( - "RGBA", (sum_of_widths + (num_images - 1) * middle_width, h), (255, 255, 255, 255) + "RGBA", + (sum_of_widths + (num_images - 1) * middle_width, h), + (255, 255, 255, 255), ) drawer = ImageDraw.Draw(dst) - font_bottom = ImageFont.truetype(os.path.join(bayes3d.utils.get_assets_dir(), "fonts", "IBMPlexSerif-Regular.ttf"), bottom_fontsize) - font_label = ImageFont.truetype(os.path.join(bayes3d.utils.get_assets_dir(), "fonts", "IBMPlexSerif-Regular.ttf"), label_fontsize) - font_title = ImageFont.truetype(os.path.join(bayes3d.utils.get_assets_dir(), "fonts", "IBMPlexSerif-Regular.ttf"), title_fontsize) + font_bottom = ImageFont.truetype( + os.path.join( + bayes3d.utils.get_assets_dir(), "fonts", "IBMPlexSerif-Regular.ttf" + ), + bottom_fontsize, + ) + font_label = ImageFont.truetype( + os.path.join( + bayes3d.utils.get_assets_dir(), "fonts", "IBMPlexSerif-Regular.ttf" + ), + label_fontsize, + ) + font_title = ImageFont.truetype( + os.path.join( + bayes3d.utils.get_assets_dir(), "fonts", "IBMPlexSerif-Regular.ttf" + ), + title_fontsize, + ) bottom_border = 0 title_border = 0 @@ -309,58 +352,85 @@ def multi_panel(images, labels=None, title=None, bottom_text=None, title_fontsiz _, _, text_w, text_h = drawer.textbbox((0, 0), msg, font=font_label) label_border = max(text_h, label_border) - bottom_border += 0 + bottom_border += 0 title_border += 20 - label_border += 20 + label_border += 20 dst = Image.new( - "RGBA", (sum_of_widths+ (num_images - 1) * middle_width, h + title_border + label_border + bottom_border), (255, 255, 255, 255) + "RGBA", + ( + sum_of_widths + (num_images - 1) * middle_width, + h + title_border + label_border + bottom_border, + ), + (255, 255, 255, 255), ) drawer = ImageDraw.Draw(dst) width_counter = 0 - for (j, img) in enumerate(images): - dst.paste( - img, - (width_counter + j * middle_width, title_border + label_border) - ) + for j, img in enumerate(images): + dst.paste(img, (width_counter + j * middle_width, title_border + label_border)) width_counter += img.width if title is not None: msg = title _, _, text_w, text_h = drawer.textbbox((0, 0), msg, font=font_title) - drawer.text(((sum_of_widths + (num_images - 1) * middle_width)/2.0 - text_w/2 , title_border/2 - text_h/2), msg, font=font_title, fill="black") - + drawer.text( + ( + (sum_of_widths + (num_images - 1) * middle_width) / 2.0 - text_w / 2, + title_border / 2 - text_h / 2, + ), + msg, + font=font_title, + fill="black", + ) width_counter = 0 if labels is not None: - for (i, msg) in enumerate(labels): + for i, msg in enumerate(labels): w = images[i].width _, _, text_w, text_h = drawer.textbbox((0, 0), msg, 
font=font_label) - drawer.text((width_counter + i * middle_width + w/2 - text_w/2, title_border + label_border/2 - text_h/2), msg, font=font_label, fill="black") + drawer.text( + ( + width_counter + i * middle_width + w / 2 - text_w / 2, + title_border + label_border / 2 - text_h / 2, + ), + msg, + font=font_label, + fill="black", + ) width_counter += w if bottom_text is not None: msg = bottom_text _, _, text_w, text_h = drawer.textbbox((0, 0), msg, font=font_bottom) - drawer.text((5, title_border + label_border + h + 5), msg, font=font_bottom, fill="black") + drawer.text( + (5, title_border + label_border + h + 5), + msg, + font=font_bottom, + fill="black", + ) return dst + def distinct_colors(num_colors, pastel_factor=0.5): """Get a list of distinct colors. - + Args: num_colors (int): Number of colors to generate. pastel_factor (float): Pastel factor. Returns: list: List of colors. """ - return [np.array(i) for i in distinctipy.get_colors(num_colors, pastel_factor=pastel_factor)] + return [ + np.array(i) + for i in distinctipy.get_colors(num_colors, pastel_factor=pastel_factor) + ] + def viz_graph(num_nodes, edges, filename, node_names=None): """Visualize a graph. - + Args: num_nodes (int): Number of nodes in graph. edges (list): List of edges in graph. @@ -372,24 +442,26 @@ def viz_graph(num_nodes, edges, filename, node_names=None): g_out = graphviz.Digraph() g_out.attr("node", style="filled") - + colors = matplotlib.cm.tab20(range(num_nodes)) colors = distinctipy.get_colors(num_nodes, pastel_factor=0.7) for i in range(num_nodes): g_out.node(str(i), node_names[i], fillcolor=matplotlib.colors.to_hex(colors[i])) - for (i,j) in edges: - if i==-1: + for i, j in edges: + if i == -1: continue - g_out.edge(str(i),str(j)) + g_out.edge(str(i), str(j)) max_width_px = 2000 max_height_px = 2000 dpi = 200 - g_out.attr("graph", - # See https://graphviz.gitlab.io/_pages/doc/info/attrs.html#a:size - size="{},{}!".format(max_width_px / dpi, max_height_px / dpi), - dpi=str(dpi)) + g_out.attr( + "graph", + # See https://graphviz.gitlab.io/_pages/doc/info/attrs.html#a:size + size="{},{}!".format(max_width_px / dpi, max_height_px / dpi), + dpi=str(dpi), + ) filename_prefix, filetype = filename.split(".") g_out.render(filename_prefix, format=filetype) diff --git a/demo.py b/demo.py index 25878b8f..fb23acd3 100644 --- a/demo.py +++ b/demo.py @@ -1,87 +1,94 @@ -import numpy as np -import jax.numpy as jnp -import jax -import bayes3d as b +import os import time -from PIL import Image + +import jax +import jax.numpy as jnp +from IPython import embed from scipy.spatial.transform import Rotation as R -import matplotlib.pyplot as plt -import cv2 -import trimesh -import os + +import bayes3d as b # Can be helpful for debugging: -# jax.config.update('jax_enable_checks', True) +# jax.config.update('jax_enable_checks', True) intrinsics = b.Intrinsics( - height=100, - width=100, - fx=50.0, fy=50.0, - cx=50.0, cy=50.0, - near=0.001, far=6.0 + height=100, width=100, fx=50.0, fy=50.0, cx=50.0, cy=50.0, near=0.001, far=6.0 ) b.setup_renderer(intrinsics) -b.RENDERER.add_mesh_from_file(os.path.join(b.utils.get_assets_dir(),"sample_objs/bunny.obj")) +b.RENDERER.add_mesh_from_file( + os.path.join(b.utils.get_assets_dir(), "sample_objs/bunny.obj") +) num_frames = 60 poses = [b.t3d.transform_from_pos(jnp.array([-3.0, 0.0, 3.5]))] delta_pose = b.t3d.transform_from_rot_and_pos( - R.from_euler('zyx', [-1.0, 0.1, 2.0], degrees=True).as_matrix(), - jnp.array([0.09, 0.05, 0.02]) + R.from_euler("zyx", [-1.0, 0.1, 2.0], 
degrees=True).as_matrix(), + jnp.array([0.09, 0.05, 0.02]), ) -for t in range(num_frames-1): +for t in range(num_frames - 1): poses.append(poses[-1].dot(delta_pose)) poses = jnp.stack(poses) print("Number of frames: ", poses.shape[0]) -observed_images = b.RENDERER.render_many(poses[:,None,...], jnp.array([0])) +observed_images = b.RENDERER.render_many(poses[:, None, ...], jnp.array([0])) print("observed_images.shape", observed_images.shape) -translation_deltas = b.utils.make_translation_grid_enumeration(-0.2, -0.2, -0.2, 0.2, 0.2, 0.2, 5, 5, 5) -rotation_deltas = jax.vmap(lambda key: b.distributions.gaussian_vmf_zero_mean(key, 0.00001, 800.0))( - jax.random.split(jax.random.PRNGKey(30), 100) +translation_deltas = b.utils.make_translation_grid_enumeration( + -0.2, -0.2, -0.2, 0.2, 0.2, 0.2, 5, 5, 5 +) +rotation_deltas = jax.vmap( + lambda key: b.distributions.gaussian_vmf_zero_mean(key, 0.00001, 800.0) +)(jax.random.split(jax.random.PRNGKey(30), 100)) + +likelihood = jax.vmap( + b.threedp3_likelihood_old, in_axes=(None, 0, None, None, None, None, None) ) -likelihood = jax.vmap(b.threedp3_likelihood_old, in_axes=(None, 0, None, None, None, None, None)) def update_pose_estimate(pose_estimate, gt_image): proposals = jnp.einsum("ij,ajk->aik", pose_estimate, translation_deltas) - rendered_images = jax.vmap(b.RENDERER.render, in_axes=(0, None))(proposals[:,None, ...], jnp.array([0])) + rendered_images = jax.vmap(b.RENDERER.render, in_axes=(0, None))( + proposals[:, None, ...], jnp.array([0]) + ) weights_new = likelihood(gt_image, rendered_images, 0.05, 0.1, 10**3, 0.1, 3) pose_estimate = proposals[jnp.argmax(weights_new)] proposals = jnp.einsum("ij,ajk->aik", pose_estimate, rotation_deltas) - rendered_images = jax.vmap(b.RENDERER.render, in_axes=(0, None))(proposals[:, None, ...], jnp.array([0])) + rendered_images = jax.vmap(b.RENDERER.render, in_axes=(0, None))( + proposals[:, None, ...], jnp.array([0]) + ) weights_new = likelihood(gt_image, rendered_images, 0.05, 0.1, 10**3, 0.1, 3) pose_estimate = proposals[jnp.argmax(weights_new)] return pose_estimate, pose_estimate -inference_program = jax.jit(lambda p,x: jax.lax.scan(update_pose_estimate, p,x)[1]) + +inference_program = jax.jit(lambda p, x: jax.lax.scan(update_pose_estimate, p, x)[1]) inferred_poses = inference_program(poses[0], observed_images) start = time.time() pose_estimates_over_time = inference_program(poses[0], observed_images) end = time.time() -print ("Time elapsed:", end - start) -print ("FPS:", poses.shape[0] / (end - start)) +print("Time elapsed:", end - start) +print("FPS:", poses.shape[0] / (end - start)) -rerendered_images = b.RENDERER.render_many(pose_estimates_over_time[:, None, ...], jnp.array([0])) +rerendered_images = b.RENDERER.render_many( + pose_estimates_over_time[:, None, ...], jnp.array([0]) +) viz_images = [ b.viz.multi_panel( [ - b.viz.scale_image(b.viz.get_depth_image(d[:,:,2]), 3), - b.viz.scale_image(b.viz.get_depth_image(r[:,:,2]), 3) - ], + b.viz.scale_image(b.viz.get_depth_image(d[:, :, 2]), 3), + b.viz.scale_image(b.viz.get_depth_image(r[:, :, 2]), 3), + ], labels=["Observed", "Rerendered"], - label_fontsize=20 + label_fontsize=20, ) for (r, d) in zip(rerendered_images, observed_images) ] b.make_gif_from_pil_images(viz_images, "assets/demo.gif") - -from IPython import embed; embed() \ No newline at end of file +embed() diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..ae978ce9 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,43 @@ +[tool.ruff] +exclude = [ + ".bzr", + 
".direnv", + ".eggs", + ".git", + ".git-rewrite", + ".hg", + ".mypy_cache", + ".nox", + ".pants.d", + ".pytype", + ".ruff_cache", + ".svn", + ".tox", + ".venv", + "__pypackages__", + "_build", + "buck-out", + "build", + "dist", + "node_modules", + "venv" +] +# extend-include = ["*.ipynb"] +line-length = 88 +indent-width = 4 + +[tool.ruff.lint] +extend-select = ["I"] +select = ["E4", "E7", "E9", "F"] + +# F403 disables errors from `*` imports, which we currently use heavily. +ignore = ["F403"] +fixable = ["ALL"] +unfixable = [] +dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" + +[tool.ruff.format] +quote-style = "double" +indent-style = "space" +skip-magic-trailing-comma = false +line-ending = "auto" diff --git a/scripts/_mkl/notebooks/kubric/kubric_helper.py b/scripts/_mkl/notebooks/kubric/kubric_helper.py index 35caae64..64d289d3 100644 --- a/scripts/_mkl/notebooks/kubric/kubric_helper.py +++ b/scripts/_mkl/notebooks/kubric/kubric_helper.py @@ -1,27 +1,23 @@ -import logging - -import bpy import kubric as kb -from kubric.simulator import PyBullet -from kubric.renderer import Blender import numpy as np def get_linear_camera_motion_start_end( movement_speed: float, - inner_radius: float = 8., - outer_radius: float = 12., + inner_radius: float = 8.0, + outer_radius: float = 12.0, z_offset: float = 0.1, ): - """Sample a linear path which starts and ends within a half-sphere shell.""" - while True: - camera_start = np.array(kb.sample_point_in_half_sphere_shell(inner_radius, - outer_radius, - z_offset)) - direction = rng.rand(3) - 0.5 - movement = direction / np.linalg.norm(direction) * movement_speed - camera_end = camera_start + movement - if (inner_radius <= np.linalg.norm(camera_end) <= outer_radius and - camera_end[2] > z_offset): - return camera_start, camera_end - + """Sample a linear path which starts and ends within a half-sphere shell.""" + while True: + camera_start = np.array( + kb.sample_point_in_half_sphere_shell(inner_radius, outer_radius, z_offset) + ) + direction = rng.rand(3) - 0.5 + movement = direction / np.linalg.norm(direction) * movement_speed + camera_end = camera_start + movement + if ( + inner_radius <= np.linalg.norm(camera_end) <= outer_radius + and camera_end[2] > z_offset + ): + return camera_start, camera_end diff --git a/scripts/_mkl/notebooks/nbexporter.py b/scripts/_mkl/notebooks/nbexporter.py index 482654fe..f7383445 100644 --- a/scripts/_mkl/notebooks/nbexporter.py +++ b/scripts/_mkl/notebooks/nbexporter.py @@ -3,41 +3,42 @@ Exports notebooks to `.py` files using `nbdev.nb_export`. """ -import argparse import glob import os -from nbdev.export import nb_export from pathlib import Path +from nbdev.export import nb_export + NBS = "." 
LIB = "../../../bayes3d/_mkl/" + class bcolors: - BLUE = '\033[94m' - CYAN = '\033[96m' - GREEN = '\033[92m' + BLUE = "\033[94m" + CYAN = "\033[96m" + GREEN = "\033[92m" PURPLE = "\033[95m" - ENDC = '\033[0m' - BOLD = '\033[1m' - UNDERLINE = '\033[4m' + ENDC = "\033[0m" + BOLD = "\033[1m" + UNDERLINE = "\033[4m" def main(): - lib_path = Path(__file__).parents[0]/LIB + lib_path = Path(__file__).parents[0] / LIB rel_lib_path = os.path.relpath(lib_path) - rel_nbs_path = os.path.relpath(Path(__file__).parents[0]/NBS) + rel_nbs_path = os.path.relpath(Path(__file__).parents[0] / NBS) file_pattern = f"{rel_nbs_path}/**/[a-zA-Z0-9]*.ipynb" print(f"{bcolors.BLUE}Trying to export the following files") for fname in glob.glob(file_pattern, recursive=True): - print(f"\t{bcolors.PURPLE}{fname}{bcolors.ENDC}") nb_export(fname, lib_path=rel_lib_path) print(f"{bcolors.BLUE}to{bcolors.ENDC}") print(f"\t{bcolors.PURPLE}{bcolors.BOLD}{rel_lib_path}{bcolors.ENDC}") + if __name__ == "__main__": main() diff --git a/scripts/experiments/collaborations/arijit_physics.py b/scripts/experiments/collaborations/arijit_physics.py index d5381cda..b11efe41 100644 --- a/scripts/experiments/collaborations/arijit_physics.py +++ b/scripts/experiments/collaborations/arijit_physics.py @@ -1,28 +1,32 @@ import genjax -import bayes3d as b -from genjax.generative_functions.distributions import ExactDensity +import jax import jax.numpy as jnp + +import bayes3d as b import bayes3d.genjax -import jax b.setup_visualizer() + @genjax.gen def body_fun(prev): (t, pose, velocity) = prev - velocity = b.gaussian_vmf_pose(velocity, 0.01, 10000.0) @ f"velocity" - pose = b.gaussian_vmf_pose(pose @ velocity, 0.01, 10000.0) @ f"pose" + velocity = b.gaussian_vmf_pose(velocity, 0.01, 10000.0) @ "velocity" + pose = b.gaussian_vmf_pose(pose @ velocity, 0.01, 10000.0) @ "pose" # Render return (t + 1, pose, velocity) # Creating a `SwitchCombinator` via the preferred `new` class method. + @genjax.gen def model(T): - pose = b.uniform_pose(jnp.ones(3)*-1.0, jnp.ones(3)*1.0) @ "init_pose" + pose = b.uniform_pose(jnp.ones(3) * -1.0, jnp.ones(3) * 1.0) @ "init_pose" velocity = b.gaussian_vmf_pose(jnp.eye(4), 0.01, 10000.0) @ "init_velocity" - evolve = genjax.UnfoldCombinator.new(body_fun, 100)(50,(0, pose, velocity)) @ "dynamics" + evolve = ( + genjax.UnfoldCombinator.new(body_fun, 100)(50, (0, pose, velocity)) @ "dynamics" + ) return 1.0 @@ -34,5 +38,5 @@ def model(T): # TODO: # 1. Add rendering and images likelihood - # Do simple SMC tracking of one object -# 2. Make this multiobject \ No newline at end of file +# Do simple SMC tracking of one object +# 2. Make this multiobject diff --git a/scripts/experiments/colmap/colmap_loader.py b/scripts/experiments/colmap/colmap_loader.py index 4a498ee1..ddef28cd 100644 --- a/scripts/experiments/colmap/colmap_loader.py +++ b/scripts/experiments/colmap/colmap_loader.py @@ -3,24 +3,27 @@ # GRAPHDECO research group, https://team.inria.fr/graphdeco # All rights reserved. # -# This software is free for non-commercial, research and evaluation use +# This software is free for non-commercial, research and evaluation use # under the terms of the LICENSE.md file. 
# # For inquiries contact george.drettakis@inria.fr # -import numpy as np import collections import struct +import numpy as np + CameraModel = collections.namedtuple( - "CameraModel", ["model_id", "model_name", "num_params"]) -Camera = collections.namedtuple( - "Camera", ["id", "model", "width", "height", "params"]) + "CameraModel", ["model_id", "model_name", "num_params"] +) +Camera = collections.namedtuple("Camera", ["id", "model", "width", "height", "params"]) BaseImage = collections.namedtuple( - "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"]) + "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"] +) Point3D = collections.namedtuple( - "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"]) + "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"] +) CAMERA_MODELS = { CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), CameraModel(model_id=1, model_name="PINHOLE", num_params=4), @@ -32,43 +35,63 @@ CameraModel(model_id=7, model_name="FOV", num_params=5), CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), - CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12) + CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12), } -CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model) - for camera_model in CAMERA_MODELS]) -CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model) - for camera_model in CAMERA_MODELS]) +CAMERA_MODEL_IDS = dict( + [(camera_model.model_id, camera_model) for camera_model in CAMERA_MODELS] +) +CAMERA_MODEL_NAMES = dict( + [(camera_model.model_name, camera_model) for camera_model in CAMERA_MODELS] +) def qvec2rotmat(qvec): - return np.array([ - [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2, - 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], - 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]], - [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], - 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2, - 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]], - [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], - 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], - 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]]) + return np.array( + [ + [ + 1 - 2 * qvec[2] ** 2 - 2 * qvec[3] ** 2, + 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], + 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2], + ], + [ + 2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], + 1 - 2 * qvec[1] ** 2 - 2 * qvec[3] ** 2, + 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1], + ], + [ + 2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], + 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], + 1 - 2 * qvec[1] ** 2 - 2 * qvec[2] ** 2, + ], + ] + ) + def rotmat2qvec(R): Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat - K = np.array([ - [Rxx - Ryy - Rzz, 0, 0, 0], - [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], - [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], - [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0 + K = ( + np.array( + [ + [Rxx - Ryy - Rzz, 0, 0, 0], + [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], + [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], + [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz], + ] + ) + / 3.0 + ) eigvals, eigvecs = np.linalg.eigh(K) qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] if qvec[0] < 0: qvec *= -1 return qvec + class Image(BaseImage): def qvec2rotmat(self): return qvec2rotmat(self.qvec) + def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): """Read and unpack the next bytes from a binary file. 
:param fid: @@ -80,6 +103,7 @@ def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): data = fid.read(num_bytes) return struct.unpack(endian_character + format_char_sequence, data) + def read_points3D_text(path): """ see: src/base/reconstruction.cc @@ -110,6 +134,7 @@ def read_points3D_text(path): errors = np.append(errors, error[None, ...], axis=0) return xyzs, rgbs, errors + def read_points3D_binary(path_to_model_file): """ see: src/base/reconstruction.cc @@ -117,7 +142,6 @@ def read_points3D_binary(path_to_model_file): void Reconstruction::WritePoints3DBinary(const std::string& path) """ - with open(path_to_model_file, "rb") as fid: num_points = read_next_bytes(fid, 8, "Q")[0] @@ -127,20 +151,25 @@ def read_points3D_binary(path_to_model_file): for p_id in range(num_points): binary_point_line_properties = read_next_bytes( - fid, num_bytes=43, format_char_sequence="QdddBBBd") + fid, num_bytes=43, format_char_sequence="QdddBBBd" + ) xyz = np.array(binary_point_line_properties[1:4]) rgb = np.array(binary_point_line_properties[4:7]) error = np.array(binary_point_line_properties[7]) - track_length = read_next_bytes( - fid, num_bytes=8, format_char_sequence="Q")[0] + track_length = read_next_bytes(fid, num_bytes=8, format_char_sequence="Q")[ + 0 + ] track_elems = read_next_bytes( - fid, num_bytes=8*track_length, - format_char_sequence="ii"*track_length) + fid, + num_bytes=8 * track_length, + format_char_sequence="ii" * track_length, + ) xyzs[p_id] = xyz rgbs[p_id] = rgb errors[p_id] = error return xyzs, rgbs, errors + def read_intrinsics_text(path): """ Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py @@ -156,15 +185,18 @@ def read_intrinsics_text(path): elems = line.split() camera_id = int(elems[0]) model = elems[1] - assert model == "PINHOLE", "While the loader support other types, the rest of the code assumes PINHOLE" + assert ( + model == "PINHOLE" + ), "While the loader support other types, the rest of the code assumes PINHOLE" width = int(elems[2]) height = int(elems[3]) params = np.array(tuple(map(float, elems[4:]))) - cameras[camera_id] = Camera(id=camera_id, model=model, - width=width, height=height, - params=params) + cameras[camera_id] = Camera( + id=camera_id, model=model, width=width, height=height, params=params + ) return cameras + def read_extrinsics_binary(path_to_model_file): """ see: src/base/reconstruction.cc @@ -176,27 +208,38 @@ def read_extrinsics_binary(path_to_model_file): num_reg_images = read_next_bytes(fid, 8, "Q")[0] for _ in range(num_reg_images): binary_image_properties = read_next_bytes( - fid, num_bytes=64, format_char_sequence="idddddddi") + fid, num_bytes=64, format_char_sequence="idddddddi" + ) image_id = binary_image_properties[0] qvec = np.array(binary_image_properties[1:5]) tvec = np.array(binary_image_properties[5:8]) camera_id = binary_image_properties[8] image_name = "" current_char = read_next_bytes(fid, 1, "c")[0] - while current_char != b"\x00": # look for the ASCII 0 entry + while current_char != b"\x00": # look for the ASCII 0 entry image_name += current_char.decode("utf-8") current_char = read_next_bytes(fid, 1, "c")[0] - num_points2D = read_next_bytes(fid, num_bytes=8, - format_char_sequence="Q")[0] - x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D, - format_char_sequence="ddq"*num_points2D) - xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])), - tuple(map(float, x_y_id_s[1::3]))]) + num_points2D = read_next_bytes(fid, num_bytes=8, format_char_sequence="Q")[ + 0 + ] 
+ x_y_id_s = read_next_bytes( + fid, + num_bytes=24 * num_points2D, + format_char_sequence="ddq" * num_points2D, + ) + xys = np.column_stack( + [tuple(map(float, x_y_id_s[0::3])), tuple(map(float, x_y_id_s[1::3]))] + ) point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) images[image_id] = Image( - id=image_id, qvec=qvec, tvec=tvec, - camera_id=camera_id, name=image_name, - xys=xys, point3D_ids=point3D_ids) + id=image_id, + qvec=qvec, + tvec=tvec, + camera_id=camera_id, + name=image_name, + xys=xys, + point3D_ids=point3D_ids, + ) return images @@ -211,20 +254,24 @@ def read_intrinsics_binary(path_to_model_file): num_cameras = read_next_bytes(fid, 8, "Q")[0] for _ in range(num_cameras): camera_properties = read_next_bytes( - fid, num_bytes=24, format_char_sequence="iiQQ") + fid, num_bytes=24, format_char_sequence="iiQQ" + ) camera_id = camera_properties[0] model_id = camera_properties[1] model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name width = camera_properties[2] height = camera_properties[3] num_params = CAMERA_MODEL_IDS[model_id].num_params - params = read_next_bytes(fid, num_bytes=8*num_params, - format_char_sequence="d"*num_params) - cameras[camera_id] = Camera(id=camera_id, - model=model_name, - width=width, - height=height, - params=np.array(params)) + params = read_next_bytes( + fid, num_bytes=8 * num_params, format_char_sequence="d" * num_params + ) + cameras[camera_id] = Camera( + id=camera_id, + model=model_name, + width=width, + height=height, + params=np.array(params), + ) assert len(cameras) == num_cameras return cameras @@ -248,13 +295,19 @@ def read_extrinsics_text(path): camera_id = int(elems[8]) image_name = elems[9] elems = fid.readline().split() - xys = np.column_stack([tuple(map(float, elems[0::3])), - tuple(map(float, elems[1::3]))]) + xys = np.column_stack( + [tuple(map(float, elems[0::3])), tuple(map(float, elems[1::3]))] + ) point3D_ids = np.array(tuple(map(int, elems[2::3]))) images[image_id] = Image( - id=image_id, qvec=qvec, tvec=tvec, - camera_id=camera_id, name=image_name, - xys=xys, point3D_ids=point3D_ids) + id=image_id, + qvec=qvec, + tvec=tvec, + camera_id=camera_id, + name=image_name, + xys=xys, + point3D_ids=point3D_ids, + ) return images @@ -266,8 +319,9 @@ def read_colmap_bin_array(path): :return: nd array with the floating point values in the value """ with open(path, "rb") as fid: - width, height, channels = np.genfromtxt(fid, delimiter="&", max_rows=1, - usecols=(0, 1, 2), dtype=int) + width, height, channels = np.genfromtxt( + fid, delimiter="&", max_rows=1, usecols=(0, 1, 2), dtype=int + ) fid.seek(0) num_delimiter = 0 byte = fid.read(1) @@ -279,4 +333,4 @@ def read_colmap_bin_array(path): byte = fid.read(1) array = np.fromfile(fid, np.float32) array = array.reshape((width, height, channels), order="F") - return np.transpose(array, (1, 0, 2)).squeeze() \ No newline at end of file + return np.transpose(array, (1, 0, 2)).squeeze() diff --git a/scripts/experiments/colmap/dataset_loader.py b/scripts/experiments/colmap/dataset_loader.py index 123f0cef..5acc112d 100644 --- a/scripts/experiments/colmap/dataset_loader.py +++ b/scripts/experiments/colmap/dataset_loader.py @@ -3,32 +3,40 @@ # GRAPHDECO research group, https://team.inria.fr/graphdeco # All rights reserved. # -# This software is free for non-commercial, research and evaluation use +# This software is free for non-commercial, research and evaluation use # under the terms of the LICENSE.md file. 
# # For inquiries contact george.drettakis@inria.fr # +import json +import math import os import sys -import math -from PIL import Image +from pathlib import Path from typing import NamedTuple -from colmap_loader import read_extrinsics_text, read_intrinsics_text, qvec2rotmat, \ - read_extrinsics_binary, read_intrinsics_binary, read_points3D_binary, read_points3D_text + import numpy as np -import json -from pathlib import Path +from colmap_loader import ( + qvec2rotmat, + read_extrinsics_binary, + read_extrinsics_text, + read_intrinsics_binary, + read_intrinsics_text, + read_points3D_binary, + read_points3D_text, +) +from PIL import Image from plyfile import PlyData, PlyElement -from typing import NamedTuple class BasicPointCloud(NamedTuple): - points : np.array - colors : np.array - normals : np.array + points: np.array + colors: np.array + normals: np.array + -def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0): +def getWorld2View2(R, t, translate=np.array([0.0, 0.0, 0.0]), scale=1.0): Rt = np.zeros((4, 4)) Rt[:3, :3] = R.transpose() Rt[:3, 3] = t @@ -41,11 +49,14 @@ def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0): Rt = np.linalg.inv(C2W) return np.float32(Rt) + def fov2focal(fov, pixels): return pixels / (2 * math.tan(fov / 2)) + def focal2fov(focal, pixels): - return 2*math.atan(pixels/(2*focal)) + return 2 * math.atan(pixels / (2 * focal)) + class CameraInfo(NamedTuple): uid: int @@ -59,6 +70,7 @@ class CameraInfo(NamedTuple): width: int height: int + class SceneInfo(NamedTuple): point_cloud: BasicPointCloud train_cameras: list @@ -66,12 +78,13 @@ class SceneInfo(NamedTuple): nerf_normalization: dict ply_path: str + def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder): cam_infos = [] for idx, key in enumerate(cam_extrinsics): - sys.stdout.write('\r') + sys.stdout.write("\r") # the exact output you're looking for: - sys.stdout.write("Reading camera {}/{}".format(idx+1, len(cam_extrinsics))) + sys.stdout.write("Reading camera {}/{}".format(idx + 1, len(cam_extrinsics))) sys.stdout.flush() extr = cam_extrinsics[key] @@ -83,11 +96,11 @@ def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder): R = np.transpose(qvec2rotmat(extr.qvec)) T = np.array(extr.tvec) - if intr.model=="SIMPLE_PINHOLE": + if intr.model == "SIMPLE_PINHOLE": focal_length_x = intr.params[0] FovY = focal2fov(focal_length_x, height) FovX = focal2fov(focal_length_x, width) - elif intr.model=="PINHOLE": + elif intr.model == "PINHOLE": focal_length_x = intr.params[0] focal_length_y = intr.params[1] FovY = focal2fov(focal_length_y, height) @@ -99,26 +112,46 @@ def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder): image_name = os.path.basename(image_path).split(".")[0] image = Image.open(image_path) - cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, image=image, - image_path=image_path, image_name=image_name, width=width, height=height) + cam_info = CameraInfo( + uid=uid, + R=R, + T=T, + FovY=FovY, + FovX=FovX, + image=image, + image_path=image_path, + image_name=image_name, + width=width, + height=height, + ) cam_infos.append(cam_info) - sys.stdout.write('\n') + sys.stdout.write("\n") return cam_infos + def fetchPly(path): plydata = PlyData.read(path) - vertices = plydata['vertex'] - positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T - colors = np.vstack([vertices['red'], vertices['green'], vertices['blue']]).T / 255.0 - normals = np.vstack([vertices['nx'], vertices['ny'], vertices['nz']]).T + vertices = 
plydata["vertex"] + positions = np.vstack([vertices["x"], vertices["y"], vertices["z"]]).T + colors = np.vstack([vertices["red"], vertices["green"], vertices["blue"]]).T / 255.0 + normals = np.vstack([vertices["nx"], vertices["ny"], vertices["nz"]]).T return BasicPointCloud(points=positions, colors=colors, normals=normals) + def storePly(path, xyz, rgb): # Define the dtype for the structured array - dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'), - ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'), - ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')] - + dtype = [ + ("x", "f4"), + ("y", "f4"), + ("z", "f4"), + ("nx", "f4"), + ("ny", "f4"), + ("nz", "f4"), + ("red", "u1"), + ("green", "u1"), + ("blue", "u1"), + ] + normals = np.zeros_like(xyz) elements = np.empty(xyz.shape[0], dtype=dtype) @@ -126,7 +159,7 @@ def storePly(path, xyz, rgb): elements[:] = list(map(tuple, attributes)) # Create the PlyData object and write to file - vertex_element = PlyElement.describe(elements, 'vertex') + vertex_element = PlyElement.describe(elements, "vertex") ply_data = PlyData([vertex_element]) ply_data.write(path) @@ -154,21 +187,26 @@ def get_center_and_diag(cam_centers): return {"translate": translate, "radius": radius} + def readColmapSceneInfo(path, images, eval, llffhold=8): try: cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.bin") cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.bin") cam_extrinsics = read_extrinsics_binary(cameras_extrinsic_file) cam_intrinsics = read_intrinsics_binary(cameras_intrinsic_file) - except: + except Exception: cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.txt") cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.txt") cam_extrinsics = read_extrinsics_text(cameras_extrinsic_file) cam_intrinsics = read_intrinsics_text(cameras_intrinsic_file) - reading_dir = "images" if images == None else images - cam_infos_unsorted = readColmapCameras(cam_extrinsics=cam_extrinsics, cam_intrinsics=cam_intrinsics, images_folder=os.path.join(path, reading_dir)) - cam_infos = sorted(cam_infos_unsorted.copy(), key = lambda x : x.image_name) + reading_dir = "images" if images is None else images + cam_infos_unsorted = readColmapCameras( + cam_extrinsics=cam_extrinsics, + cam_intrinsics=cam_intrinsics, + images_folder=os.path.join(path, reading_dir), + ) + cam_infos = sorted(cam_infos_unsorted.copy(), key=lambda x: x.image_name) if eval: train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold != 0] @@ -183,24 +221,29 @@ def readColmapSceneInfo(path, images, eval, llffhold=8): bin_path = os.path.join(path, "sparse/0/points3D.bin") txt_path = os.path.join(path, "sparse/0/points3D.txt") if not os.path.exists(ply_path): - print("Converting point3d.bin to .ply, will happen only the first time you open the scene.") + print( + "Converting point3d.bin to .ply, will happen only the first time you open the scene." 
+ ) try: xyz, rgb, _ = read_points3D_binary(bin_path) - except: + except Exception: xyz, rgb, _ = read_points3D_text(txt_path) storePly(ply_path, xyz, rgb) try: pcd = fetchPly(ply_path) - except: + except Exception: pcd = None - scene_info = SceneInfo(point_cloud=pcd, - train_cameras=train_cam_infos, - test_cameras=test_cam_infos, - nerf_normalization=nerf_normalization, - ply_path=ply_path) + scene_info = SceneInfo( + point_cloud=pcd, + train_cameras=train_cam_infos, + test_cameras=test_cam_infos, + nerf_normalization=nerf_normalization, + ply_path=ply_path, + ) return scene_info + def readCamerasFromTransforms(path, transformsfile, white_background, extension=".png"): cam_infos = [] @@ -219,7 +262,9 @@ def readCamerasFromTransforms(path, transformsfile, white_background, extension= # get the world-to-camera transform and set R, T w2c = np.linalg.inv(c2w) - R = np.transpose(w2c[:3,:3]) # R is stored transposed due to 'glm' in CUDA code + R = np.transpose( + w2c[:3, :3] + ) # R is stored transposed due to 'glm' in CUDA code T = w2c[:3, 3] image_path = os.path.join(path, cam_name) @@ -228,21 +273,36 @@ def readCamerasFromTransforms(path, transformsfile, white_background, extension= im_data = np.array(image.convert("RGBA")) - bg = np.array([1,1,1]) if white_background else np.array([0, 0, 0]) + bg = np.array([1, 1, 1]) if white_background else np.array([0, 0, 0]) norm_data = im_data / 255.0 - arr = norm_data[:,:,:3] * norm_data[:, :, 3:4] + bg * (1 - norm_data[:, :, 3:4]) - image = Image.fromarray(np.array(arr*255.0, dtype=np.byte), "RGB") + arr = norm_data[:, :, :3] * norm_data[:, :, 3:4] + bg * ( + 1 - norm_data[:, :, 3:4] + ) + image = Image.fromarray(np.array(arr * 255.0, dtype=np.byte), "RGB") fovy = focal2fov(fov2focal(fovx, image.size[0]), image.size[1]) - FovY = fovy + FovY = fovy FovX = fovx - cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image, - image_path=image_path, image_name=image_name, width=image.size[0], height=image.size[1])) - + cam_infos.append( + CameraInfo( + uid=idx, + R=R, + T=T, + FovY=FovY, + FovX=FovX, + image=image, + image_path=image_path, + image_name=image_name, + width=image.size[0], + height=image.size[1], + ) + ) + return cam_infos + sceneLoadTypeCallbacks = { "Colmap": readColmapSceneInfo, -} \ No newline at end of file +} diff --git a/scripts/experiments/colmap/run.py b/scripts/experiments/colmap/run.py index bc9d9616..1379ceb3 100644 --- a/scripts/experiments/colmap/run.py +++ b/scripts/experiments/colmap/run.py @@ -9,44 +9,67 @@ # For inquiries contact george.drettakis@inria.fr # -import os import logging -from argparse import ArgumentParser +import os import shutil +from argparse import ArgumentParser # This Python script is based on the shell converter script provided in the MipNerF 360 repository. 
parser = ArgumentParser("Colmap converter") -parser.add_argument("--no_gpu", action='store_true') -parser.add_argument("--skip_matching", action='store_true') +parser.add_argument("--no_gpu", action="store_true") +parser.add_argument("--skip_matching", action="store_true") parser.add_argument("--source_path", "-s", required=True, type=str) parser.add_argument("--camera", default="PINHOLE", type=str) parser.add_argument("--colmap_executable", default="", type=str) parser.add_argument("--resize", action="store_true") parser.add_argument("--magick_executable", default="", type=str) args = parser.parse_args() -colmap_command = '"{}"'.format(args.colmap_executable) if len(args.colmap_executable) > 0 else "colmap" -magick_command = '"{}"'.format(args.magick_executable) if len(args.magick_executable) > 0 else "magick" +colmap_command = ( + '"{}"'.format(args.colmap_executable) + if len(args.colmap_executable) > 0 + else "colmap" +) +magick_command = ( + '"{}"'.format(args.magick_executable) + if len(args.magick_executable) > 0 + else "magick" +) use_gpu = 1 if not args.no_gpu else 0 if not args.skip_matching: os.makedirs(args.source_path + "/distorted/sparse", exist_ok=True) ## Feature extraction - feat_extracton_cmd = colmap_command + " feature_extractor "\ - "--database_path " + args.source_path + "/distorted/database.db \ - --image_path " + args.source_path + "/input \ + feat_extracton_cmd = ( + colmap_command + " feature_extractor " + "--database_path " + + args.source_path + + "/distorted/database.db \ + --image_path " + + args.source_path + + "/input \ --ImageReader.single_camera 1 \ - --ImageReader.camera_model " + args.camera + " \ - --SiftExtraction.use_gpu " + str(use_gpu) + --ImageReader.camera_model " + + args.camera + + " \ + --SiftExtraction.use_gpu " + + str(use_gpu) + ) exit_code = os.system(feat_extracton_cmd) if exit_code != 0: logging.error(f"Feature extraction failed with code {exit_code}. Exiting.") exit(exit_code) ## Feature matching - feat_matching_cmd = colmap_command + " exhaustive_matcher \ - --database_path " + args.source_path + "/distorted/database.db \ - --SiftMatching.use_gpu " + str(use_gpu) + feat_matching_cmd = ( + colmap_command + + " exhaustive_matcher \ + --database_path " + + args.source_path + + "/distorted/database.db \ + --SiftMatching.use_gpu " + + str(use_gpu) + ) exit_code = os.system(feat_matching_cmd) if exit_code != 0: logging.error(f"Feature matching failed with code {exit_code}. Exiting.") @@ -55,11 +78,20 @@ ### Bundle adjustment # The default Mapper tolerance is unnecessarily large, # decreasing it speeds up bundle adjustment steps. - mapper_cmd = (colmap_command + " mapper \ - --database_path " + args.source_path + "/distorted/database.db \ - --image_path " + args.source_path + "/input \ - --output_path " + args.source_path + "/distorted/sparse \ - --Mapper.ba_global_function_tolerance=0.000001") + mapper_cmd = ( + colmap_command + + " mapper \ + --database_path " + + args.source_path + + "/distorted/database.db \ + --image_path " + + args.source_path + + "/input \ + --output_path " + + args.source_path + + "/distorted/sparse \ + --Mapper.ba_global_function_tolerance=0.000001" + ) exit_code = os.system(mapper_cmd) if exit_code != 0: logging.error(f"Mapper failed with code {exit_code}. Exiting.") @@ -67,11 +99,20 @@ ### Image undistortion ## We need to undistort our images into ideal pinhole intrinsics. 
-img_undist_cmd = (colmap_command + " image_undistorter \ - --image_path " + args.source_path + "/input \ - --input_path " + args.source_path + "/distorted/sparse/0 \ - --output_path " + args.source_path + "\ - --output_type COLMAP") +img_undist_cmd = ( + colmap_command + + " image_undistorter \ + --image_path " + + args.source_path + + "/input \ + --input_path " + + args.source_path + + "/distorted/sparse/0 \ + --output_path " + + args.source_path + + "\ + --output_type COLMAP" +) exit_code = os.system(img_undist_cmd) if exit_code != 0: logging.error(f"Mapper failed with code {exit_code}. Exiting.") @@ -81,13 +122,13 @@ os.makedirs(args.source_path + "/sparse/0", exist_ok=True) # Copy each file from the source directory to the destination directory for file in files: - if file == '0': + if file == "0": continue source_file = os.path.join(args.source_path, "sparse", file) destination_file = os.path.join(args.source_path, "sparse", "0", file) shutil.move(source_file, destination_file) -if(args.resize): +if args.resize: print("Copying and resizing...") # Resize images. @@ -102,23 +143,29 @@ destination_file = os.path.join(args.source_path, "images_2", file) shutil.copy2(source_file, destination_file) - exit_code = os.system(magick_command + " mogrify -resize 50% " + destination_file) + exit_code = os.system( + magick_command + " mogrify -resize 50% " + destination_file + ) if exit_code != 0: logging.error(f"50% resize failed with code {exit_code}. Exiting.") exit(exit_code) destination_file = os.path.join(args.source_path, "images_4", file) shutil.copy2(source_file, destination_file) - exit_code = os.system(magick_command + " mogrify -resize 25% " + destination_file) + exit_code = os.system( + magick_command + " mogrify -resize 25% " + destination_file + ) if exit_code != 0: logging.error(f"25% resize failed with code {exit_code}. Exiting.") exit(exit_code) destination_file = os.path.join(args.source_path, "images_8", file) shutil.copy2(source_file, destination_file) - exit_code = os.system(magick_command + " mogrify -resize 12.5% " + destination_file) + exit_code = os.system( + magick_command + " mogrify -resize 12.5% " + destination_file + ) if exit_code != 0: logging.error(f"12.5% resize failed with code {exit_code}. 
Exiting.") exit(exit_code) -print("Done.") \ No newline at end of file +print("Done.") diff --git a/scripts/experiments/deeplearning/kubric_dataset_gen/kubric_dataset_gen.py b/scripts/experiments/deeplearning/kubric_dataset_gen/kubric_dataset_gen.py index 61a6d8f0..ccbcd4c8 100644 --- a/scripts/experiments/deeplearning/kubric_dataset_gen/kubric_dataset_gen.py +++ b/scripts/experiments/deeplearning/kubric_dataset_gen/kubric_dataset_gen.py @@ -1,11 +1,12 @@ -import jax.numpy as jnp -import bayes3d as j -import trimesh import os + +import jax +import jax.numpy as jnp import numpy as np import trimesh -import jax +from IPython import embed +import bayes3d as j # --- creating the model dir from the working directory model_dir = os.path.join(j.utils.get_assets_dir(), "ycb_video_models/models") @@ -23,43 +24,61 @@ bop_ycb_dir = os.path.join(j.utils.get_assets_dir(), "bop/ycbv") -rgbd, gt_ids, gt_poses, masks = j.ycb_loader.get_test_img('52', '1', bop_ycb_dir) +rgbd, gt_ids, gt_poses, masks = j.ycb_loader.get_test_img("52", "1", bop_ycb_dir) intrinsics = j.Intrinsics( height=rgbd.intrinsics.height, width=rgbd.intrinsics.width, - fx=rgbd.intrinsics.fx, fy=rgbd.intrinsics.fx, - cx=rgbd.intrinsics.width/2.0, cy=rgbd.intrinsics.height/2.0, - near=0.001, far=3.0 + fx=rgbd.intrinsics.fx, + fy=rgbd.intrinsics.fx, + cx=rgbd.intrinsics.width / 2.0, + cy=rgbd.intrinsics.height / 2.0, + near=0.001, + far=3.0, ) - - NUM_IMAGES_PER_ITER = 10 NUM_ITER = 100 for iter in range(NUM_ITER): print("Iteration: ", iter) key = jax.random.PRNGKey(iter) - object_poses = jax.vmap(lambda key: j.distributions.gaussian_vmf(key, 0.00001, 0.001))( - jax.random.split(key, NUM_IMAGES_PER_ITER) - ) - object_poses = jnp.einsum("ij,ajk",j.t3d.inverse_pose(camera_pose),object_poses) + object_poses = jax.vmap( + lambda key: j.distributions.gaussian_vmf(key, 0.00001, 0.001) + )(jax.random.split(key, NUM_IMAGES_PER_ITER)) + object_poses = jnp.einsum("ij,ajk", j.t3d.inverse_pose(camera_pose), object_poses) mesh_paths = [] - mesh_path = os.path.join(model_dir,name,"textured.obj") + mesh_path = os.path.join(model_dir, name, "textured.obj") for _ in range(NUM_IMAGES_PER_ITER): mesh_paths.append(mesh_path) _, offset_pose = j.mesh.center_mesh(trimesh.load(mesh_path), return_pose=True) - - all_data = j.kubric_interface.render_multiobject_parallel(mesh_paths, object_poses[None,:,...], intrinsics, scaling_factor=1.0, lighting=3.0) # multi img singleobj + all_data = j.kubric_interface.render_multiobject_parallel( + mesh_paths, + object_poses[None, :, ...], + intrinsics, + scaling_factor=1.0, + lighting=3.0, + ) # multi img singleobj gt_poses = object_poses @ offset_pose DATASET_FILENAME = f"dataset_{iter}.npz" # npz file - DATASET_FILE = os.path.join(j.utils.get_assets_dir(), f"datasets/{DATASET_FILENAME}") - np.savez(DATASET_FILE, rgbds=all_data, poses=gt_poses, id=IDX, name=model_names[IDX], intrinsics=intrinsics, mesh_path=mesh_path) + DATASET_FILE = os.path.join( + j.utils.get_assets_dir(), f"datasets/{DATASET_FILENAME}" + ) + np.savez( + DATASET_FILE, + rgbds=all_data, + poses=gt_poses, + id=IDX, + name=model_names[IDX], + intrinsics=intrinsics, + mesh_path=mesh_path, + ) - rgb_images = j.hstack_images([j.get_rgb_image(r.rgb) for r in all_data]).save(f"dataset_{iter}.png") + rgb_images = j.hstack_images([j.get_rgb_image(r.rgb) for r in all_data]).save( + f"dataset_{iter}.png" + ) -from IPython import embed; embed() \ No newline at end of file +embed() diff --git a/scripts/experiments/deeplearning/sam/sam.py 
b/scripts/experiments/deeplearning/sam/sam.py index f763c732..824a4a12 100644 --- a/scripts/experiments/deeplearning/sam/sam.py +++ b/scripts/experiments/deeplearning/sam/sam.py @@ -1,37 +1,45 @@ -import bayes3d as j -import jax.numpy as jnp import os -import torch -import numpy as np import pickle -import warnings -from segment_anything import SamAutomaticMaskGenerator, sam_model_registry, build_sam import sys +import warnings + +import jax.numpy as jnp +import numpy as np +from segment_anything import SamAutomaticMaskGenerator, build_sam + +import bayes3d as j + sys.path.extend(["/home/nishadgothoskar/ptamp/pybullet_planning"]) sys.path.extend(["/home/nishadgothoskar/ptamp"]) warnings.filterwarnings("ignore") bop_ycb_dir = os.path.join(j.utils.get_assets_dir(), "bop/ycbv") -rgbd, gt_ids, gt_poses, masks = j.ycb_loader.get_test_img('52', '1', bop_ycb_dir) - -test_pkl_file = os.path.join(j.utils.get_assets_dir(),"sample_imgs/strawberry_error.pkl") -test_pkl_file = os.path.join(j.utils.get_assets_dir(),"sample_imgs/knife_spoon_box_real.pkl") -test_pkl_file = os.path.join(j.utils.get_assets_dir(),"sample_imgs/red_lego_multi.pkl") -test_pkl_file = os.path.join(j.utils.get_assets_dir(),"sample_imgs/demo2_nolight.pkl") - -file = open(test_pkl_file,'rb') +rgbd, gt_ids, gt_poses, masks = j.ycb_loader.get_test_img("52", "1", bop_ycb_dir) + +test_pkl_file = os.path.join( + j.utils.get_assets_dir(), "sample_imgs/strawberry_error.pkl" +) +test_pkl_file = os.path.join( + j.utils.get_assets_dir(), "sample_imgs/knife_spoon_box_real.pkl" +) +test_pkl_file = os.path.join(j.utils.get_assets_dir(), "sample_imgs/red_lego_multi.pkl") +test_pkl_file = os.path.join(j.utils.get_assets_dir(), "sample_imgs/demo2_nolight.pkl") + +file = open(test_pkl_file, "rb") camera_images = pickle.load(file)["camera_images"] images = [j.RGBD.construct_from_camera_image(c) for c in camera_images] rgbd = images[0] j.get_rgb_image(rgbd.rgb).save("rgb.png") -sam = build_sam(checkpoint="/home/nishadgothoskar/jax3dp3/assets/sam/sam_vit_h_4b8939.pth") +sam = build_sam( + checkpoint="/home/nishadgothoskar/jax3dp3/assets/sam/sam_vit_h_4b8939.pth" +) sam.to(device="cuda") mask_generator = SamAutomaticMaskGenerator(sam) -boxes= mask_generator.generate(np.array(rgbd.rgb)) +boxes = mask_generator.generate(np.array(rgbd.rgb)) full_segmentation = jnp.ones(rgbd.rgb.shape[:2]) * -1.0 num_objects_so_far = 0 @@ -40,41 +48,39 @@ matched = False for jj in range(num_objects_so_far): - seg_mask_existing_object = (full_segmentation == jj) - + seg_mask_existing_object = full_segmentation == jj + intersection = seg_mask * seg_mask_existing_object if intersection[seg_mask].mean() > 0.9: matched = True - + if not matched: full_segmentation = full_segmentation.at[seg_mask].set(num_objects_so_far) num_objects_so_far += 1 - segmentation_image = j.get_depth_image(full_segmentation + 1,max=full_segmentation.max() + 2) + segmentation_image = j.get_depth_image( + full_segmentation + 1, max=full_segmentation.max() + 2 + ) seg_viz = j.get_depth_image(seg_mask) - j.hstack_images([segmentation_image, seg_viz]).save(f"{i}.png") - -full_segmentation = full_segmentation.at[seg_mask].set(i+1) - + j.hstack_images([segmentation_image, seg_viz]).save(f"{i}.png") +full_segmentation = full_segmentation.at[seg_mask].set(i + 1) # sam = build_sam() # sam.to(device="cuda") # mask_generator = SamAutomaticMaskGenerator(sam) -j.get_rgb_image(rgbd.rgb).save("rgb.png") -mask_generator.generate(np.array(rgbd.rgb)) - -mask_generator.generate(np.array(img)) - -sam = 
sam_model_registry["default"](checkpoint=args.checkpoint) -_ = sam.to(device=args.device) -output_mode = "coco_rle" if args.convert_to_rle else "binary_mask" -amg_kwargs = get_amg_kwargs(args) -generator = SamAutomaticMaskGenerator(sam, output_mode=output_mode, **amg_kwargs) +# j.get_rgb_image(rgbd.rgb).save("rgb.png") +# mask_generator.generate(np.array(rgbd.rgb)) +# mask_generator.generate(np.array(img)) -mask_generator = SamAutomaticMaskGenerator["default"](build_sam()) +# sam = sam_model_registry["default"](checkpoint=args.checkpoint) +# _ = sam.to(device=args.device) +# output_mode = "coco_rle" if args.convert_to_rle else "binary_mask" +# amg_kwargs = get_amg_kwargs(args) +# generator = SamAutomaticMaskGenerator(sam, output_mode=output_mode, **amg_kwargs) +# mask_generator = SamAutomaticMaskGenerator["default"](build_sam()) diff --git a/scripts/experiments/gaussian_splatting/optimization.py b/scripts/experiments/gaussian_splatting/optimization.py index eddda0a0..9f07966f 100644 --- a/scripts/experiments/gaussian_splatting/optimization.py +++ b/scripts/experiments/gaussian_splatting/optimization.py @@ -1,18 +1,14 @@ -import diff_gaussian_rasterization as dgr -from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer -import torch import os + +import jax.numpy as jnp import numpy as np -import matplotlib.pyplot as plt -import math -from tqdm import tqdm +import torch + import bayes3d as b -import jax.numpy as jnp -from random import randint device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -model_dir = os.path.join(b.utils.get_assets_dir(),"bop/ycbv/models") -mesh_path = os.path.join(model_dir,"obj_" + "{}".format(3).rjust(6, '0') + ".ply") +model_dir = os.path.join(b.utils.get_assets_dir(), "bop/ycbv/models") +mesh_path = os.path.join(model_dir, "obj_" + "{}".format(3).rjust(6, "0") + ".ply") mesh = b.utils.load_mesh(mesh_path) -vertices = torch.tensor(np.array(jnp.array(mesh.vertices) / 1000.0),device=device) \ No newline at end of file +vertices = torch.tensor(np.array(jnp.array(mesh.vertices) / 1000.0), device=device) diff --git a/scripts/experiments/gaussian_splatting/splatting_simple.ipynb b/scripts/experiments/gaussian_splatting/splatting_simple.ipynb index 4c3a657a..7a98e08c 100644 --- a/scripts/experiments/gaussian_splatting/splatting_simple.ipynb +++ b/scripts/experiments/gaussian_splatting/splatting_simple.ipynb @@ -209,7 +209,7 @@ "try:\n", " del render_from_pos_quat_jit\n", " del value_and_grad_loss\n", - "except:\n", + "except Exception:\n", " pass\n", "\n", "def render_from_pos_quat(pos,quat):\n", diff --git a/scripts/experiments/icra/camera_pose_tracking/util.py b/scripts/experiments/icra/camera_pose_tracking/util.py index 8c53bad8..5d578012 100644 --- a/scripts/experiments/icra/camera_pose_tracking/util.py +++ b/scripts/experiments/icra/camera_pose_tracking/util.py @@ -9,65 +9,85 @@ import numpy as np import torch -#---------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- # Projection and transformation matrix helpers. 
-#---------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- + def projection(x=0.1, n=1.0, f=50.0): - return np.array([[n/x, 0, 0, 0], - [ 0, n/x, 0, 0], - [ 0, 0, -(f+n)/(f-n), -(2*f*n)/(f-n)], - [ 0, 0, -1, 0]]).astype(np.float32) + return np.array( + [ + [n / x, 0, 0, 0], + [0, n / x, 0, 0], + [0, 0, -(f + n) / (f - n), -(2 * f * n) / (f - n)], + [0, 0, -1, 0], + ] + ).astype(np.float32) + def translate(x, y, z): - return np.array([[1, 0, 0, x], - [0, 1, 0, y], - [0, 0, 1, z], - [0, 0, 0, 1]]).astype(np.float32) + return np.array([[1, 0, 0, x], [0, 1, 0, y], [0, 0, 1, z], [0, 0, 0, 1]]).astype( + np.float32 + ) + def rotate_x(a): s, c = np.sin(a), np.cos(a) - return np.array([[1, 0, 0, 0], - [0, c, s, 0], - [0, -s, c, 0], - [0, 0, 0, 1]]).astype(np.float32) + return np.array([[1, 0, 0, 0], [0, c, s, 0], [0, -s, c, 0], [0, 0, 0, 1]]).astype( + np.float32 + ) + def rotate_y(a): s, c = np.sin(a), np.cos(a) - return np.array([[ c, 0, s, 0], - [ 0, 1, 0, 0], - [-s, 0, c, 0], - [ 0, 0, 0, 1]]).astype(np.float32) + return np.array([[c, 0, s, 0], [0, 1, 0, 0], [-s, 0, c, 0], [0, 0, 0, 1]]).astype( + np.float32 + ) + def random_rotation_translation(t): m = np.random.normal(size=[3, 3]) m[1] = np.cross(m[0], m[2]) m[2] = np.cross(m[0], m[1]) m = m / np.linalg.norm(m, axis=1, keepdims=True) - m = np.pad(m, [[0, 1], [0, 1]], mode='constant') + m = np.pad(m, [[0, 1], [0, 1]], mode="constant") m[3, 3] = 1.0 m[:3, 3] = np.random.uniform(-t, t, size=[3]) return m -#---------------------------------------------------------------------------- + +# ---------------------------------------------------------------------------- # Bilinear downsample by 2x. -#---------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- + def bilinear_downsample(x): - w = torch.tensor([[1, 3, 3, 1], [3, 9, 9, 3], [3, 9, 9, 3], [1, 3, 3, 1]], dtype=torch.float32, device=x.device) / 64.0 - w = w.expand(x.shape[-1], 1, 4, 4) - x = torch.nn.functional.conv2d(x.permute(0, 3, 1, 2), w, padding=1, stride=2, groups=x.shape[-1]) + w = ( + torch.tensor( + [[1, 3, 3, 1], [3, 9, 9, 3], [3, 9, 9, 3], [1, 3, 3, 1]], + dtype=torch.float32, + device=x.device, + ) + / 64.0 + ) + w = w.expand(x.shape[-1], 1, 4, 4) + x = torch.nn.functional.conv2d( + x.permute(0, 3, 1, 2), w, padding=1, stride=2, groups=x.shape[-1] + ) return x.permute(0, 2, 3, 1) -#---------------------------------------------------------------------------- + +# ---------------------------------------------------------------------------- # Image display function using OpenGL. -#---------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- _glfw_window = None -def display_image(image, zoom=None, size=None, title=None): # HWC + + +def display_image(image, zoom=None, size=None, title=None): # HWC # Import OpenGL and glfw. - import OpenGL.GL as gl import glfw + import OpenGL.GL as gl # Zoom image if requested. image = np.asarray(image) @@ -80,7 +100,7 @@ def display_image(image, zoom=None, size=None, title=None): # HWC # Initialize window. 
if title is None: - title = 'Debug window' + title = "Debug window" global _glfw_window if _glfw_window is None: glfw.init() @@ -100,21 +120,25 @@ def display_image(image, zoom=None, size=None, title=None): # HWC gl.glWindowPos2f(0, 0) gl.glPixelStorei(gl.GL_UNPACK_ALIGNMENT, 1) gl_format = {3: gl.GL_RGB, 2: gl.GL_RG, 1: gl.GL_LUMINANCE}[channels] - gl_dtype = {'uint8': gl.GL_UNSIGNED_BYTE, 'float32': gl.GL_FLOAT}[image.dtype.name] + gl_dtype = {"uint8": gl.GL_UNSIGNED_BYTE, "float32": gl.GL_FLOAT}[image.dtype.name] gl.glDrawPixels(width, height, gl_format, gl_dtype, image[::-1]) glfw.swap_buffers(_glfw_window) if glfw.window_should_close(_glfw_window): return False return True -#---------------------------------------------------------------------------- + +# ---------------------------------------------------------------------------- # Image save helper. -#---------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- + def save_image(fn, x): import imageio + x = np.rint(x * 255.0) x = np.clip(x, 0, 255).astype(np.uint8) imageio.imsave(fn, x) -#---------------------------------------------------------------------------- + +# ---------------------------------------------------------------------------- diff --git a/scripts/experiments/mcs/cognitive-battery/model.py b/scripts/experiments/mcs/cognitive-battery/model.py index 468907ad..cc47755c 100644 --- a/scripts/experiments/mcs/cognitive-battery/model.py +++ b/scripts/experiments/mcs/cognitive-battery/model.py @@ -1,5 +1,6 @@ import json import os +from collections import deque import cog_utils as utils import jax @@ -12,9 +13,6 @@ from bayes3d.transforms_3d import transform_from_pos, unproject_depth from bayes3d.viz import get_depth_image, make_gif_from_pil_images, multi_panel -from collections import deque - - SCENE = "swap" @@ -123,6 +121,7 @@ ## Defining inference helper functions + # Enumerating proposals def make_unfiform_grid(n, d): # d: number of enumerated proposals on each dimension (x, y, z). 
@@ -146,6 +145,7 @@ def prior(new_pose, prev_poses):
 prior_parallel = jax.jit(jax.vmap(prior, in_axes=(0, None)))

+
 # Likelihood model
 def scorer(rendered_image, gt, r=0.1, op=0.005, ov=0.5):
     # Likelihood parameters
@@ -246,7 +246,7 @@ def scorer(rendered_image, gt, r=0.1, op=0.005, ov=0.5):
         multi_panel(
             [rgb_viz, gt_depth_1, rendered_image, *rendered_apple],
             [
-                f"\nRGB Image",
+                "\nRGB Image",
                 f" Frame: {t}\nActual Depth",
                 "\nReconstructed Depth",
                 *(["\nApple Only"] * len(rendered_apple)),
diff --git a/scripts/experiments/mcs/cognitive-battery/scene_graph.py b/scripts/experiments/mcs/cognitive-battery/scene_graph.py
index 5dcd1c94..4f7afeb8 100644
--- a/scripts/experiments/mcs/cognitive-battery/scene_graph.py
+++ b/scripts/experiments/mcs/cognitive-battery/scene_graph.py
@@ -2,6 +2,7 @@

 import jax.numpy as jnp

+
 from bayes3d.transforms_3d import transform_from_pos


@@ -26,11 +27,12 @@ def is_contained_in(self, other_bbox, threshold=0.03):
         return jnp.all(other_bbox.maxs > (self.maxs - threshold)) and jnp.all(
             other_bbox.mins < (self.mins + threshold)
         )
-    
+
     def move(self, vector):
         self.mins += vector
         self.maxs += vector

+
 class SceneObject:
     def __init__(self, mesh_name, bbox, transform):
         self.name = mesh_name
diff --git a/scripts/experiments/mcs/otp_gen/otp_gen/physics_priors.py b/scripts/experiments/mcs/otp_gen/otp_gen/physics_priors.py
index 1fbb9cd2..e34216d4 100644
--- a/scripts/experiments/mcs/otp_gen/otp_gen/physics_priors.py
+++ b/scripts/experiments/mcs/otp_gen/otp_gen/physics_priors.py
@@ -1,9 +1,10 @@
 import jax
 import jax.numpy as jnp
-from jax.debug import print as jprint
+
+
 def physics_prior(proposed_pose, physics_estimated_pose):
-    proposed_pos = proposed_pose[:3,3]
-    physics_estimated_pos = physics_estimated_pose[:3,3]
+    proposed_pos = proposed_pose[:3, 3]
+    physics_estimated_pos = physics_estimated_pose[:3, 3]
     return jax.scipy.stats.multivariate_normal.logpdf(
         proposed_pos, physics_estimated_pos, jnp.diag(jnp.array([0.02, 0.02, 0.02]))
     )
@@ -12,6 +13,7 @@ def physics_prior(proposed_pose, physics_estimated_pose):
 physics_prior_parallel_jit = jax.jit(jax.vmap(physics_prior, in_axes=(0, None)))
 physics_prior_parallel = jax.vmap(physics_prior, in_axes=(0, None))

+
 def physics_prior_v1(prev_pose, prev_prev_pose, bbox_dims, camera_pose, world2cam):
     """
     Score the physics of the simulation outside of a PPL, this will
@@ -21,7 +23,7 @@ def physics_prior_v1(prev_pose, prev_prev_pose, bbox_dims, camera_pose, world2ca
     ASSUMPTIONS:
     A1 - Single object
-    A2 - Change in position, no change in orientation 
+    A2 - Change in position, no change in orientation
     A3 - No Friction
     A4 - No Restitution/Damping (no bouncing)
     A5 - No Collision between objects (Except single object and Floor)
@@ -47,40 +49,47 @@ def physics_prior_v1(prev_pose, prev_prev_pose, bbox_dims, camera_pose, world2ca
     # Assuming all poses are in camera frame
     # extract x-y-z positions
-    prev_pos = prev_pose[:3,3]
-    prev_prev_pos = prev_prev_pose[:3,3]
+    prev_pos = prev_pose[:3, 3]
+    prev_prev_pos = prev_prev_pose[:3, 3]

     # find X-Y-Z velocity change
-    # I1 & I2 -> find simple difference in world frame + check if object 
+    # I1 & I2 -> find simple difference in world frame + check if object
     # is on the floor and force it to have no downward vector
     # conversions to world frame
-    prev_prev_pos_world = camera_pose[:3,:] @ jnp.concatenate([prev_prev_pos, 1], axis = None)
-    prev_pos_world = camera_pose[:3,:] @ jnp.concatenate([prev_pos, 1], axis = None)
+    prev_prev_pos_world = camera_pose[:3, :] @ jnp.concatenate(
+        [prev_prev_pos, 1], 
axis=None + ) + prev_pos_world = camera_pose[:3, :] @ jnp.concatenate([prev_pos, 1], axis=None) vel_pos_world = prev_pos_world - prev_prev_pos_world # find object's bottom in world frame - object_bottom = prev_pos_world[2] - 0.5*bbox_dims[2] + object_bottom = prev_pos_world[2] - 0.5 * bbox_dims[2] - vel_pos_world = jax.lax.cond(jnp.less_equal(object_bottom, 0.01 * bbox_dims[2]), + vel_pos_world = jax.lax.cond( + jnp.less_equal(object_bottom, 0.01 * bbox_dims[2]), lambda x: x.at[2].set(0), lambda x: x, - vel_pos_world) + vel_pos_world, + ) pred_pos_world = prev_pos_world + vel_pos_world - pred_pos = world2cam[:3,:] @ jnp.concatenate([pred_pos_world, 1], axis = None) - + pred_pos = world2cam[:3, :] @ jnp.concatenate([pred_pos_world, 1], axis=None) + # I1 -> Integrate X-Y-Z forward to current time step # jprint("pred pos: {}", camera_pose[:3,:] @ jnp.concatenate([pred_pos, 1], axis = None)) - physics_estimated_pose = jnp.copy(prev_pose) # orientation is the same - physics_estimated_pose = physics_estimated_pose.at[:3,3].set(pred_pos) + physics_estimated_pose = jnp.copy(prev_pose) # orientation is the same + physics_estimated_pose = physics_estimated_pose.at[:3, 3].set(pred_pos) return physics_estimated_pose - + + physics_prior_v1_jit = jax.jit(physics_prior_v1) -def physics_prior_v2(prev_poses, bbox_dims, camera_pose, world2cam, T, t_interval = 1.0/60.0): +def physics_prior_v2( + prev_poses, bbox_dims, camera_pose, world2cam, T, t_interval=1.0 / 60.0 +): """ Score the physics of the simulation outside of a PPL, this will score physics estimates independent of what we see (3DP3 likelihood) @@ -89,7 +98,7 @@ def physics_prior_v2(prev_poses, bbox_dims, camera_pose, world2cam, T, t_interva ASSUMPTIONS: A1 - Single object - A2 - Change in position, no change in orientation + A2 - Change in position, no change in orientation A3 - No Friction A4 - No Restitution/Damping (no bouncing) A5 - No Collision between objects (Except single object and Floor) @@ -114,34 +123,39 @@ def physics_prior_v2(prev_poses, bbox_dims, camera_pose, world2cam, T, t_interva # Assuming all poses are in camera frame # extract x-y-z positions - prev_pos = prev_poses[T,...] - prev_prev_pos = prev_poses[T-1,...] + prev_pos = prev_poses[T, ...] + prev_prev_pos = prev_poses[T - 1, ...] 
# find X-Y-Z velocity change - # I1 & I2 -> find simple difference in world frame + check if object + # I1 & I2 -> find simple difference in world frame + check if object # is on the floor and force it to have no downward vector # conversions to world frame - prev_prev_pos_world = camera_pose[:3,:] @ jnp.concatenate([prev_prev_pos, 1], axis = None) - prev_pos_world = camera_pose[:3,:] @ jnp.concatenate([prev_pos, 1], axis = None) + prev_prev_pos_world = camera_pose[:3, :] @ jnp.concatenate( + [prev_prev_pos, 1], axis=None + ) + prev_pos_world = camera_pose[:3, :] @ jnp.concatenate([prev_pos, 1], axis=None) vel_pos_world = prev_pos_world - prev_prev_pos_world # find object's bottom in world frame - object_bottom = prev_pos_world[2] - 0.5*bbox_dims[2] + object_bottom = prev_pos_world[2] - 0.5 * bbox_dims[2] - vel_pos_world = jax.lax.cond(jnp.less_equal(object_bottom, 0.01 * bbox_dims[2]), - lambda x: x, # x.at[2].set(0), + vel_pos_world = jax.lax.cond( + jnp.less_equal(object_bottom, 0.01 * bbox_dims[2]), + lambda x: x, # x.at[2].set(0), lambda x: x, - vel_pos_world) + vel_pos_world, + ) pred_pos_world = prev_pos_world + vel_pos_world - pred_pos = world2cam[:3,:] @ jnp.concatenate([pred_pos_world, 1], axis = None) - + pred_pos = world2cam[:3, :] @ jnp.concatenate([pred_pos_world, 1], axis=None) + # I1 -> Integrate X-Y-Z forward to current time step # jprint("pred pos: {}", camera_pose[:3,:] @ jnp.concatenate([pred_pos, 1], axis = None)) - physics_estimated_pose = jnp.copy(prev_pose) # orientation is the same - physics_estimated_pose = physics_estimated_pose.at[:3,3].set(pred_pos) + physics_estimated_pose = jnp.copy(prev_pose) # orientation is the same + physics_estimated_pose = physics_estimated_pose.at[:3, 3].set(pred_pos) return physics_estimated_pose - -physics_prior_v2_jit = jax.jit(physics_prior_v2) \ No newline at end of file + + +physics_prior_v2_jit = jax.jit(physics_prior_v2) diff --git a/scripts/experiments/tabletop/data_gen.py b/scripts/experiments/tabletop/data_gen.py index 9b7ff130..1dbaedf1 100644 --- a/scripts/experiments/tabletop/data_gen.py +++ b/scripts/experiments/tabletop/data_gen.py @@ -1,35 +1,32 @@ -import bayes3d as b +import os + import genjax -import jax.numpy as jnp import jax -import os -import matplotlib.pyplot as plt -import jax.tree_util as jtu -from tqdm import tqdm +import jax.numpy as jnp + +import bayes3d as b import bayes3d.genjax + console = genjax.pretty(show_locals=False) -from genjax._src.core.transforms.incremental import NoChange -from genjax._src.core.transforms.incremental import UnknownChange -from genjax._src.core.transforms.incremental import Diff -import inspect import joblib intrinsics = b.Intrinsics( - height=100, - width=100, - fx=500.0, fy=500.0, - cx=50.0, cy=50.0, - near=0.01, far=20.0 + height=100, width=100, fx=500.0, fy=500.0, cx=50.0, cy=50.0, near=0.01, far=20.0 ) b.setup_renderer(intrinsics) -model_dir = os.path.join(b.utils.get_assets_dir(),"bop/ycbv/models") +model_dir = os.path.join(b.utils.get_assets_dir(), "bop/ycbv/models") meshes = [] -for idx in range(1,22): - mesh_path = os.path.join(model_dir,"obj_" + "{}".format(idx).rjust(6, '0') + ".ply") - b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0/1000.0) +for idx in range(1, 22): + mesh_path = os.path.join( + model_dir, "obj_" + "{}".format(idx).rjust(6, "0") + ".ply" + ) + b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0 / 1000.0) -b.RENDERER.add_mesh_from_file(os.path.join(b.utils.get_assets_dir(), "sample_objs/cube.obj"), 
scaling_factor=1.0/1000000000.0) +b.RENDERER.add_mesh_from_file( + os.path.join(b.utils.get_assets_dir(), "sample_objs/cube.obj"), + scaling_factor=1.0 / 1000000000.0, +) table_pose = b.t3d.inverse_pose( b.t3d.transform_from_pos_target_up( @@ -48,29 +45,43 @@ while True: if scene_id >= 100: break - key, (_,trace) = importance_jit(key, genjax.choice_map({ - "parent_0": -1, - "parent_1": 0, - "parent_2": 0, - "parent_3": 0, - "id_0": jnp.int32(21), - "camera_pose": jnp.eye(4), - "root_pose_0": table_pose, - "face_parent_1": 2, - "face_parent_2": 2, - "face_parent_3": 2, - "face_child_1": 3, - "face_child_2": 3, - "face_child_3": 3, - }), ( - jnp.arange(4), - jnp.arange(22), - jnp.array([-jnp.ones(3)*100.0, jnp.ones(3)*100.0]), - jnp.array([jnp.array([-0.2, -0.2, -2*jnp.pi]), jnp.array([0.2, 0.2, 2*jnp.pi])]), - b.RENDERER.model_box_dims, OUTLIER_VOLUME) + key, (_, trace) = importance_jit( + key, + genjax.choice_map( + { + "parent_0": -1, + "parent_1": 0, + "parent_2": 0, + "parent_3": 0, + "id_0": jnp.int32(21), + "camera_pose": jnp.eye(4), + "root_pose_0": table_pose, + "face_parent_1": 2, + "face_parent_2": 2, + "face_parent_3": 2, + "face_child_1": 3, + "face_child_2": 3, + "face_child_3": 3, + } + ), + ( + jnp.arange(4), + jnp.arange(22), + jnp.array([-jnp.ones(3) * 100.0, jnp.ones(3) * 100.0]), + jnp.array( + [ + jnp.array([-0.2, -0.2, -2 * jnp.pi]), + jnp.array([0.2, 0.2, 2 * jnp.pi]), + ] + ), + b.RENDERER.model_box_dims, + OUTLIER_VOLUME, + ), ) if (b.genjax.get_indices(trace) == 21).sum() > 1: continue - - joblib.dump((trace.get_choices(), trace.get_args()), f"data/trace_{scene_id}.joblib") - scene_id += 1 \ No newline at end of file + + joblib.dump( + (trace.get_choices(), trace.get_args()), f"data/trace_{scene_id}.joblib" + ) + scene_id += 1 diff --git a/scripts/experiments/tabletop/inference.py b/scripts/experiments/tabletop/inference.py index 2a3137dc..72e0d35c 100644 --- a/scripts/experiments/tabletop/inference.py +++ b/scripts/experiments/tabletop/inference.py @@ -1,37 +1,33 @@ -import bayes3d as b +import os + import genjax -import jax.numpy as jnp import jax -import os -import matplotlib.pyplot as plt -import jax.tree_util as jtu +import jax.numpy as jnp from tqdm import tqdm + +import bayes3d as b import bayes3d.genjax + console = genjax.pretty(show_locals=False) -from genjax._src.core.transforms.incremental import NoChange -from genjax._src.core.transforms.incremental import UnknownChange -from genjax._src.core.transforms.incremental import Diff -import inspect import joblib - - intrinsics = b.Intrinsics( - height=100, - width=100, - fx=500.0, fy=500.0, - cx=50.0, cy=50.0, - near=0.01, far=20.0 + height=100, width=100, fx=500.0, fy=500.0, cx=50.0, cy=50.0, near=0.01, far=20.0 ) b.setup_renderer(intrinsics) -model_dir = os.path.join(b.utils.get_assets_dir(),"bop/ycbv/models") +model_dir = os.path.join(b.utils.get_assets_dir(), "bop/ycbv/models") meshes = [] -for idx in range(1,22): - mesh_path = os.path.join(model_dir,"obj_" + "{}".format(idx).rjust(6, '0') + ".ply") - b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0/1000.0) +for idx in range(1, 22): + mesh_path = os.path.join( + model_dir, "obj_" + "{}".format(idx).rjust(6, "0") + ".ply" + ) + b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0 / 1000.0) -b.RENDERER.add_mesh_from_file(os.path.join(b.utils.get_assets_dir(), "sample_objs/cube.obj"), scaling_factor=1.0/1000000000.0) +b.RENDERER.add_mesh_from_file( + os.path.join(b.utils.get_assets_dir(), "sample_objs/cube.obj"), + scaling_factor=1.0 / 
1000000000.0, +) OUTLIER_VOLUME = 100.0 VARIANCE_GRID = jnp.array([0.000001, 0.00001, 0.0001]) @@ -40,32 +36,42 @@ # OUTLIER_GRID = jnp.array([ 0.0001]) grid_params = [ - (0.2, jnp.pi, (11,11,11)), (0.1, jnp.pi/3, (11,11,11)), (0.05, 0.0, (11,11,1)), - (0.05, jnp.pi/5, (11,11,11)), (0.02, 2*jnp.pi, (5,5,51)), (0.02, jnp.pi/5, (11,11,11)), (0.005, jnp.pi/10, (11,11,11)) + (0.2, jnp.pi, (11, 11, 11)), + (0.1, jnp.pi / 3, (11, 11, 11)), + (0.05, 0.0, (11, 11, 1)), + (0.05, jnp.pi / 5, (11, 11, 11)), + (0.02, 2 * jnp.pi, (5, 5, 51)), + (0.02, jnp.pi / 5, (11, 11, 11)), + (0.005, jnp.pi / 10, (11, 11, 11)), ] contact_param_gridding_schedule = [ - b.utils.make_translation_grid_enumeration_3d( - -x, -x, -ang, - x, x, ang, - *nums - ) - for (x,ang,nums) in grid_params + b.utils.make_translation_grid_enumeration_3d(-x, -x, -ang, x, x, ang, *nums) + for (x, ang, nums) in grid_params ] key = jax.random.PRNGKey(500) importance_jit = jax.jit(b.genjax.model.importance) -contact_enumerators = [b.genjax.make_enumerator([f"contact_params_{i}", "variance", "outlier_prob"]) for i in range(5)] +contact_enumerators = [ + b.genjax.make_enumerator([f"contact_params_{i}", "variance", "outlier_prob"]) + for i in range(5) +] add_object_jit = jax.jit(b.genjax.add_object) -def c2f_contact_update(trace_, key, number, contact_param_deltas, VARIANCE_GRID, OUTLIER_GRID): + +def c2f_contact_update( + trace_, key, number, contact_param_deltas, VARIANCE_GRID, OUTLIER_GRID +): contact_param_grid = contact_param_deltas + trace_[f"contact_params_{number}"] - scores = contact_enumerators[number].enumerate_choices_get_scores(trace_, key, contact_param_grid, VARIANCE_GRID, OUTLIER_GRID) - i,j,k = jnp.unravel_index(scores.argmax(), scores.shape) + scores = contact_enumerators[number].enumerate_choices_get_scores( + trace_, key, contact_param_grid, VARIANCE_GRID, OUTLIER_GRID + ) + i, j, k = jnp.unravel_index(scores.argmax(), scores.shape) return contact_enumerators[number].update_choices( - trace_, key, - contact_param_grid[i], VARIANCE_GRID[j], OUTLIER_GRID[k] + trace_, key, contact_param_grid[i], VARIANCE_GRID[j], OUTLIER_GRID[k] ) + + c2f_contact_update_jit = jax.jit(c2f_contact_update, static_argnames=("number",)) V_VARIANT = 0 @@ -87,42 +93,51 @@ def c2f_contact_update(trace_, key, number, contact_param_deltas, VARIANCE_GRID V_GRID = VARIANCE_GRID O_GRID = OUTLIER_GRID else: - V_GRID, O_GRID = jnp.array([VARIANCE_GRID[V_VARIANT]]), jnp.array([OUTLIER_GRID[O_VARIANT]]) + V_GRID, O_GRID = ( + jnp.array([VARIANCE_GRID[V_VARIANT]]), + jnp.array([OUTLIER_GRID[O_VARIANT]]), + ) print(V_GRID, O_GRID) gt_trace = importance_jit(key, *joblib.load(f"data/trace_{scene_id}.joblib"))[1][1] choices = gt_trace.get_choices() - key, (_,trace) = importance_jit(key, choices, (jnp.arange(1), jnp.arange(22), *gt_trace.get_args()[2:])) + key, (_, trace) = importance_jit( + key, choices, (jnp.arange(1), jnp.arange(22), *gt_trace.get_args()[2:]) + ) all_all_paths = [] for _ in range(3): all_paths = [] - for obj_id in tqdm(range(len(b.RENDERER.meshes)-1)): + for obj_id in tqdm(range(len(b.RENDERER.meshes) - 1)): path = [] - trace_ = add_object_jit(trace, key, obj_id, 0, 2,3) + trace_ = add_object_jit(trace, key, obj_id, 0, 2, 3) number = b.genjax.get_contact_params(trace_).shape[0] - 1 path.append(trace_) for c2f_iter in range(len(contact_param_gridding_schedule)): - trace_ = c2f_contact_update_jit(trace_, key, number, - contact_param_gridding_schedule[c2f_iter], V_GRID, O_GRID) + trace_ = c2f_contact_update_jit( + trace_, + key, + number, + 
contact_param_gridding_schedule[c2f_iter], + V_GRID, + O_GRID, + ) path.append(trace_) # for c2f_iter in range(len(contact_param_gridding_schedule)): # trace_ = c2f_contact_update_jit(trace_, key, number, # contact_param_gridding_schedule[c2f_iter], VARIANCE_GRID, OUTLIER_GRID) - all_paths.append( - path - ) + all_paths.append(path) all_all_paths.append(all_paths) - + scores = jnp.array([t[-1].get_score() for t in all_paths]) print(scores) normalized_scores = b.utils.normalize_log_scores(scores) trace = all_paths[jnp.argmax(scores)][-1] - + print(b.genjax.get_indices(gt_trace)) print(b.genjax.get_indices(trace)) joblib.dump((trace.get_choices(), trace.get_args()), filename) del trace - del gt_trace \ No newline at end of file + del gt_trace diff --git a/scripts/run_colmap.py b/scripts/run_colmap.py index de712cb5..da1917f0 100644 --- a/scripts/run_colmap.py +++ b/scripts/run_colmap.py @@ -9,44 +9,67 @@ # For inquiries contact george.drettakis@inria.fr # -import os import logging -from argparse import ArgumentParser +import os import shutil +from argparse import ArgumentParser # This Python script is based on the shell converter script provided in the MipNerF 360 repository. parser = ArgumentParser("Colmap converter") -parser.add_argument("--no_gpu", action='store_true') -parser.add_argument("--skip_matching", action='store_true') +parser.add_argument("--no_gpu", action="store_true") +parser.add_argument("--skip_matching", action="store_true") parser.add_argument("--source_path", "-s", required=True, type=str) parser.add_argument("--camera", default="OPENCV", type=str) parser.add_argument("--colmap_executable", default="", type=str) parser.add_argument("--resize", action="store_true") parser.add_argument("--magick_executable", default="", type=str) args = parser.parse_args() -colmap_command = '"{}"'.format(args.colmap_executable) if len(args.colmap_executable) > 0 else "colmap" -magick_command = '"{}"'.format(args.magick_executable) if len(args.magick_executable) > 0 else "magick" +colmap_command = ( + '"{}"'.format(args.colmap_executable) + if len(args.colmap_executable) > 0 + else "colmap" +) +magick_command = ( + '"{}"'.format(args.magick_executable) + if len(args.magick_executable) > 0 + else "magick" +) use_gpu = 1 if not args.no_gpu else 0 if not args.skip_matching: os.makedirs(args.source_path + "/distorted/sparse", exist_ok=True) ## Feature extraction - feat_extracton_cmd = colmap_command + " feature_extractor "\ - "--database_path " + args.source_path + "/distorted/database.db \ - --image_path " + args.source_path + "/input \ + feat_extracton_cmd = ( + colmap_command + " feature_extractor " + "--database_path " + + args.source_path + + "/distorted/database.db \ + --image_path " + + args.source_path + + "/input \ --ImageReader.single_camera 1 \ - --ImageReader.camera_model " + args.camera + " \ - --SiftExtraction.use_gpu " + str(use_gpu) + --ImageReader.camera_model " + + args.camera + + " \ + --SiftExtraction.use_gpu " + + str(use_gpu) + ) exit_code = os.system(feat_extracton_cmd) if exit_code != 0: logging.error(f"Feature extraction failed with code {exit_code}. 
Exiting.") exit(exit_code) ## Feature matching - feat_matching_cmd = colmap_command + " exhaustive_matcher \ - --database_path " + args.source_path + "/distorted/database.db \ - --SiftMatching.use_gpu " + str(use_gpu) + feat_matching_cmd = ( + colmap_command + + " exhaustive_matcher \ + --database_path " + + args.source_path + + "/distorted/database.db \ + --SiftMatching.use_gpu " + + str(use_gpu) + ) exit_code = os.system(feat_matching_cmd) if exit_code != 0: logging.error(f"Feature matching failed with code {exit_code}. Exiting.") @@ -55,11 +78,20 @@ ### Bundle adjustment # The default Mapper tolerance is unnecessarily large, # decreasing it speeds up bundle adjustment steps. - mapper_cmd = (colmap_command + " mapper \ - --database_path " + args.source_path + "/distorted/database.db \ - --image_path " + args.source_path + "/input \ - --output_path " + args.source_path + "/distorted/sparse \ - --Mapper.ba_global_function_tolerance=0.000001") + mapper_cmd = ( + colmap_command + + " mapper \ + --database_path " + + args.source_path + + "/distorted/database.db \ + --image_path " + + args.source_path + + "/input \ + --output_path " + + args.source_path + + "/distorted/sparse \ + --Mapper.ba_global_function_tolerance=0.000001" + ) exit_code = os.system(mapper_cmd) if exit_code != 0: logging.error(f"Mapper failed with code {exit_code}. Exiting.") @@ -67,11 +99,20 @@ ### Image undistortion ## We need to undistort our images into ideal pinhole intrinsics. -img_undist_cmd = (colmap_command + " image_undistorter \ - --image_path " + args.source_path + "/input \ - --input_path " + args.source_path + "/distorted/sparse/0 \ - --output_path " + args.source_path + "\ - --output_type COLMAP") +img_undist_cmd = ( + colmap_command + + " image_undistorter \ + --image_path " + + args.source_path + + "/input \ + --input_path " + + args.source_path + + "/distorted/sparse/0 \ + --output_path " + + args.source_path + + "\ + --output_type COLMAP" +) exit_code = os.system(img_undist_cmd) if exit_code != 0: logging.error(f"Mapper failed with code {exit_code}. Exiting.") @@ -81,13 +122,13 @@ os.makedirs(args.source_path + "/sparse/0", exist_ok=True) # Copy each file from the source directory to the destination directory for file in files: - if file == '0': + if file == "0": continue source_file = os.path.join(args.source_path, "sparse", file) destination_file = os.path.join(args.source_path, "sparse", "0", file) shutil.move(source_file, destination_file) -if(args.resize): +if args.resize: print("Copying and resizing...") # Resize images. @@ -102,23 +143,29 @@ destination_file = os.path.join(args.source_path, "images_2", file) shutil.copy2(source_file, destination_file) - exit_code = os.system(magick_command + " mogrify -resize 50% " + destination_file) + exit_code = os.system( + magick_command + " mogrify -resize 50% " + destination_file + ) if exit_code != 0: logging.error(f"50% resize failed with code {exit_code}. Exiting.") exit(exit_code) destination_file = os.path.join(args.source_path, "images_4", file) shutil.copy2(source_file, destination_file) - exit_code = os.system(magick_command + " mogrify -resize 25% " + destination_file) + exit_code = os.system( + magick_command + " mogrify -resize 25% " + destination_file + ) if exit_code != 0: logging.error(f"25% resize failed with code {exit_code}. 
Exiting.") exit(exit_code) destination_file = os.path.join(args.source_path, "images_8", file) shutil.copy2(source_file, destination_file) - exit_code = os.system(magick_command + " mogrify -resize 12.5% " + destination_file) + exit_code = os.system( + magick_command + " mogrify -resize 12.5% " + destination_file + ) if exit_code != 0: logging.error(f"12.5% resize failed with code {exit_code}. Exiting.") exit(exit_code) -print("Done.") \ No newline at end of file +print("Done.") diff --git a/scripts/ssh.py b/scripts/ssh.py index 8c79b097..a1ed1667 100644 --- a/scripts/ssh.py +++ b/scripts/ssh.py @@ -1,6 +1,7 @@ import paramiko from scp import SCPClient + class SSHSender: def __init__(self, hostname, username, ssh_key_path, result_directory): self.ssh = paramiko.SSHClient() @@ -10,5 +11,16 @@ def __init__(self, hostname, username, ssh_key_path, result_directory): self.result_directory = result_directory -sender = SSHSender('34.123.143.56', 'nishadgothoskar', '/Users/nishadgothoskar/.ssh/id_ed25519.pub', ".") -sender = SSHSender('34.123.143.56', 'nishadgothoskar', '/Users/nishadgothoskar/.ssh/id_ed25519.pub', ".") + +sender = SSHSender( + "34.123.143.56", + "nishadgothoskar", + "/Users/nishadgothoskar/.ssh/id_ed25519.pub", + ".", +) +sender = SSHSender( + "34.123.143.56", + "nishadgothoskar", + "/Users/nishadgothoskar/.ssh/id_ed25519.pub", + ".", +) diff --git a/test/test_bbox_intersect.py b/test/test_bbox_intersect.py index 0d1a6045..aa49377f 100644 --- a/test/test_bbox_intersect.py +++ b/test/test_bbox_intersect.py @@ -1,42 +1,60 @@ import os + import jax import jax.numpy as jnp + import bayes3d as b are_bboxes_intersecting_jit = jax.jit(b.utils.are_bboxes_intersecting) # set up renderer intrinsics = b.Intrinsics( -height=100, -width=100, -fx=250, fy=250, -cx=100/2.0, cy=100/2.0, -near=0.1, far=20 + height=100, width=100, fx=250, fy=250, cx=100 / 2.0, cy=100 / 2.0, near=0.1, far=20 ) b.setup_renderer(intrinsics) -b.RENDERER.add_mesh_from_file(os.path.join(b.utils.get_assets_dir(), "sample_objs/cube.obj") - , scaling_factor=0.1, mesh_name = "cube_1") -b.RENDERER.add_mesh_from_file(os.path.join(b.utils.get_assets_dir(), "sample_objs/cube.obj") - , scaling_factor=0.1, mesh_name = "cube_2") +b.RENDERER.add_mesh_from_file( + os.path.join(b.utils.get_assets_dir(), "sample_objs/cube.obj"), + scaling_factor=0.1, + mesh_name="cube_1", +) +b.RENDERER.add_mesh_from_file( + os.path.join(b.utils.get_assets_dir(), "sample_objs/cube.obj"), + scaling_factor=0.1, + mesh_name="cube_2", +) # make poses intersect/collide/penetrate -pose_1 = jnp.eye(4).at[:3,3].set([-0.1,0,1.5]) -pose_1 = pose_1 @ b.transform_from_axis_angle(jnp.array([1,0,0]), jnp.pi/4) -pose_2 = jnp.eye(4).at[:3,3].set([-0.05,0,1.5]) -pose_2 = pose_2 @ b.transform_from_axis_angle(jnp.array([1,1,1]), jnp.pi/4) +pose_1 = jnp.eye(4).at[:3, 3].set([-0.1, 0, 1.5]) +pose_1 = pose_1 @ b.transform_from_axis_angle(jnp.array([1, 0, 0]), jnp.pi / 4) +pose_2 = jnp.eye(4).at[:3, 3].set([-0.05, 0, 1.5]) +pose_2 = pose_2 @ b.transform_from_axis_angle(jnp.array([1, 1, 1]), jnp.pi / 4) # make sure the output confirms the intersection -b.scale_image(b.get_depth_image(b.RENDERER.render(jnp.stack([pose_1,pose_2]), jnp.array([0,1]))[:,:,2]),4).save("intersecting.png") -is_intersecting = are_bboxes_intersecting_jit(b.RENDERER.model_box_dims[0], b.RENDERER.model_box_dims[1], pose_1, pose_2) -assert is_intersecting == True +b.scale_image( + b.get_depth_image( + b.RENDERER.render(jnp.stack([pose_1, pose_2]), jnp.array([0, 1]))[:, :, 2] + ), + 4, 
+).save("intersecting.png") +is_intersecting = are_bboxes_intersecting_jit( + b.RENDERER.model_box_dims[0], b.RENDERER.model_box_dims[1], pose_1, pose_2 +) +assert is_intersecting is True # make poses NOT intersect/collided/penetrate -pose_2 = jnp.eye(4).at[:3,3].set([0.04,0,1.5]) -pose_2 = pose_2 @ b.transform_from_axis_angle(jnp.array([1,1,1]), jnp.pi/4) +pose_2 = jnp.eye(4).at[:3, 3].set([0.04, 0, 1.5]) +pose_2 = pose_2 @ b.transform_from_axis_angle(jnp.array([1, 1, 1]), jnp.pi / 4) # make sure the output confirms NO intersection -b.scale_image(b.get_depth_image(b.RENDERER.render(jnp.stack([pose_1,pose_2]), jnp.array([0,1]))[:,:,2]),4).save("no_intersecting.png") -is_intersecting = are_bboxes_intersecting_jit(b.RENDERER.model_box_dims[0], b.RENDERER.model_box_dims[1], pose_1, pose_2) -assert is_intersecting == False \ No newline at end of file +b.scale_image( + b.get_depth_image( + b.RENDERER.render(jnp.stack([pose_1, pose_2]), jnp.array([0, 1]))[:, :, 2] + ), + 4, +).save("no_intersecting.png") +is_intersecting = are_bboxes_intersecting_jit( + b.RENDERER.model_box_dims[0], b.RENDERER.model_box_dims[1], pose_1, pose_2 +) +assert is_intersecting is False diff --git a/test/test_colmap.py b/test/test_colmap.py index bab08494..87f9d050 100644 --- a/test/test_colmap.py +++ b/test/test_colmap.py @@ -1,30 +1,31 @@ -import bayes3d as b -import bayes3d.colmap +import argparse import glob +import subprocess from pathlib import Path -import argparse + +import bayes3d as b +import bayes3d.colmap parser = argparse.ArgumentParser() -parser.add_argument("movie_path", - help="Path to movie file", - type=str) +parser.add_argument("movie_path", help="Path to movie file", type=str) args = parser.parse_args() b.setup_visualizer() movie_file_path = Path(args.movie_path) -dataset_path = Path(b.utils.get_assets_dir()) / Path(movie_file_path.name + "_colmap_dataset") +dataset_path = Path(b.utils.get_assets_dir()) / Path( + movie_file_path.name + "_colmap_dataset" +) input_path = dataset_path / Path("input") input_path.mkdir(parents=True, exist_ok=True) b.utils.video_to_images(movie_file_path, input_path) -import subprocess assets_dir = Path(b.utils.get_assets_dir()) script_path = assets_dir.parent / Path("scripts/run_colmap.py") -import subprocess -subprocess.run([f"python {str(script_path)} -s {str(dataset_path)}"],shell=True) + +subprocess.run([f"python {str(script_path)} -s {str(dataset_path)}"], shell=True) image_paths = sorted(glob.glob(str(input_path / Path("*.jpg")))) @@ -32,9 +33,7 @@ images = [b.viz.load_image_from_file(f) for f in image_paths] # b.make_gif_from_pil_images(images, "input.gif") (positions, colors, normals), train_cam_infos = b.colmap.readColmapSceneInfo( - dataset_path, - "images", - False + dataset_path, "images", False ) train_cam_infos[0].FovY @@ -46,5 +45,5 @@ ] b.show_cloud("cloud", positions * scaling_factor) -for (i,p) in enumerate(poses): +for i, p in enumerate(poses): b.show_pose(f"{i}", p) diff --git a/test/test_cosypose.py b/test/test_cosypose.py index 7a0a1be1..4e9b84e4 100644 --- a/test/test_cosypose.py +++ b/test/test_cosypose.py @@ -1,12 +1,15 @@ -import os -import bayes3d as b -import jax -import jax.numpy as jnp +import os + import numpy as np -import subprocess + +import bayes3d as b from bayes3d.neural.cosypose_baseline import cosypose_utils bop_ycb_dir = os.path.join(b.utils.get_assets_dir(), "bop/ycbv") -rgbd, gt_ids, gt_poses, masks = b.utils.ycb_loader.get_test_img('55', '1592', bop_ycb_dir) +rgbd, gt_ids, gt_poses, masks = b.utils.ycb_loader.get_test_img( + 
"55", "1592", bop_ycb_dir +) -pred = cosypose_utils.cosypose_interface(np.array(rgbd.rgb), b.K_from_intrinsics(rgbd.intrinsics)) \ No newline at end of file +pred = cosypose_utils.cosypose_interface( + np.array(rgbd.rgb), b.K_from_intrinsics(rgbd.intrinsics) +) diff --git a/test/test_differentiable_rendering.py b/test/test_differentiable_rendering.py index 9d81be4d..b7887c53 100644 --- a/test/test_differentiable_rendering.py +++ b/test/test_differentiable_rendering.py @@ -1,48 +1,44 @@ -from collections import namedtuple +import os + import jax import jax.numpy as jnp import numpy as np -import os, argparse -import time -import torch + import bayes3d as b from bayes3d.rendering.nvdiffrast_jax.jax_renderer import Renderer as JaxRenderer - - intrinsics = b.Intrinsics( - height=200, - width=200, - fx=200.0, fy=200.0, - cx=100.0, cy=100.0, - near=0.01, far=5.5 + height=200, width=200, fx=200.0, fy=200.0, cx=100.0, cy=100.0, near=0.01, far=5.5 ) jax_renderer = JaxRenderer(intrinsics) -#--------------------- +# --------------------- # Load object -#--------------------- -model_dir = os.path.join(b.utils.get_assets_dir(),"bop/ycbv/models") +# --------------------- +model_dir = os.path.join(b.utils.get_assets_dir(), "bop/ycbv/models") idx = 14 -mesh_path = os.path.join(model_dir,"obj_" + "{}".format(idx).rjust(6, '0') + ".ply") +mesh_path = os.path.join(model_dir, "obj_" + "{}".format(idx).rjust(6, "0") + ".ply") m = b.utils.load_mesh(mesh_path) -m = b.utils.scale_mesh(m, 1.0/100.0) +m = b.utils.scale_mesh(m, 1.0 / 100.0) vtx_pos = jnp.array(m.vertices.astype(np.float32)) pos_idx = jnp.array(m.faces.astype(np.int32)) print("Mesh has %d triangles and %d vertices." % (pos_idx.shape[0], vtx_pos.shape[0])) -resolution = jnp.array([200,200]) -pos = vtx_pos[None,...] -pos = jnp.concatenate([pos, jnp.ones((*pos.shape[:-1],1))], axis=-1) +resolution = jnp.array([200, 200]) +pos = vtx_pos[None, ...] 
+pos = jnp.concatenate([pos, jnp.ones((*pos.shape[:-1], 1))], axis=-1) def func(i): - rast_out, rast_out_db = jax_renderer.rasterize(pos + i, pos_idx, jnp.array([200,200])) + rast_out, rast_out_db = jax_renderer.rasterize( + pos + i, pos_idx, jnp.array([200, 200]) + ) return rast_out.mean() + func_jit = jax.jit(func) print(func_jit(0.0)) grad_func = jax.value_and_grad(func) -val,grad = grad_func(0.0) +val, grad = grad_func(0.0) print(grad) # # Test Torch @@ -60,14 +56,19 @@ def func(i): def func(i): - rast_out, rast_out_db = jax_renderer.rasterize(pos + i, pos_idx, jnp.array([200,200])) - colors,_ = jax_renderer.interpolate(pos + i, rast_out, pos_idx, rast_out_db, jnp.array([0,1,2,3])) + rast_out, rast_out_db = jax_renderer.rasterize( + pos + i, pos_idx, jnp.array([200, 200]) + ) + colors, _ = jax_renderer.interpolate( + pos + i, rast_out, pos_idx, rast_out_db, jnp.array([0, 1, 2, 3]) + ) return colors.mean() + func_jit = jax.jit(func) print(func_jit(0.0)) grad_func = jax.value_and_grad(func) -val,grad = grad_func(0.0) +val, grad = grad_func(0.0) print(grad) # # Test Torch @@ -83,5 +84,3 @@ def func(i): # print(loss) # loss.backward() # print(input_vec.grad) - - diff --git a/test/test_genjax_model.py b/test/test_genjax_model.py index b47adacc..4e80938b 100644 --- a/test/test_genjax_model.py +++ b/test/test_genjax_model.py @@ -1,23 +1,25 @@ -import bayes3d as b -import bayes3d.genjax -import jax import os -import jax.numpy as jnp + import genjax +import jax +import jax.numpy as jnp + +import bayes3d as b +import bayes3d.genjax key = jax.random.PRNGKey(1) intrinsics = b.Intrinsics( - height=100, - width=100, - fx=300.0, fy=300.0, - cx=50.0, cy=50.0, - near=0.01, far=20.0 + height=100, width=100, fx=300.0, fy=300.0, cx=50.0, cy=50.0, near=0.01, far=20.0 ) b.setup_renderer(intrinsics) -b.RENDERER.add_mesh_from_file(os.path.join(b.utils.get_assets_dir(), "sample_objs/cube.obj")) -b.RENDERER.add_mesh_from_file(os.path.join(b.utils.get_assets_dir(), "sample_objs/cube.obj")) +b.RENDERER.add_mesh_from_file( + os.path.join(b.utils.get_assets_dir(), "sample_objs/cube.obj") +) +b.RENDERER.add_mesh_from_file( + os.path.join(b.utils.get_assets_dir(), "sample_objs/cube.obj") +) importance_jit = jax.jit(b.model.importance) @@ -31,30 +33,46 @@ enumerators = b.make_enumerator(["contact_params_1"]) + def test_genjax_trace_contains_right_info(): key = jax.random.PRNGKey(1) low, high = jnp.array([-0.2, -0.2, -jnp.pi]), jnp.array([0.2, 0.2, jnp.pi]) - weight, trace = importance_jit(key, genjax.choice_map({ - "parent_0": -1, - "parent_1": 0, - "id_0": jnp.int32(1), - "id_1": jnp.int32(0), - "root_pose_0": table_pose, - "camera_pose": jnp.eye(4), - "face_parent_1": 3, - "face_child_1": 2, - "variance": 0.0001, - "outlier_prob": 0.0001, - "contact_params_1": jax.random.uniform(key, shape=(3,),minval=low, maxval=high) - }), ( - jnp.arange(2), - jnp.arange(22), - jnp.array([-jnp.ones(3)*100.0, jnp.ones(3)*100.0]), - jnp.array([jnp.array([-0.5, -0.5, -2*jnp.pi]), jnp.array([0.5, 0.5, 2*jnp.pi])]), - b.RENDERER.model_box_dims, 1.0, intrinsics.fx) + weight, trace = importance_jit( + key, + genjax.choice_map( + { + "parent_0": -1, + "parent_1": 0, + "id_0": jnp.int32(1), + "id_1": jnp.int32(0), + "root_pose_0": table_pose, + "camera_pose": jnp.eye(4), + "face_parent_1": 3, + "face_child_1": 2, + "variance": 0.0001, + "outlier_prob": 0.0001, + "contact_params_1": jax.random.uniform( + key, shape=(3,), minval=low, maxval=high + ), + } + ), + ( + jnp.arange(2), + jnp.arange(22), + jnp.array([-jnp.ones(3) * 100.0, 
jnp.ones(3) * 100.0]), + jnp.array( + [ + jnp.array([-0.5, -0.5, -2 * jnp.pi]), + jnp.array([0.5, 0.5, 2 * jnp.pi]), + ] + ), + b.RENDERER.model_box_dims, + 1.0, + intrinsics.fx, + ), ) - scores = enumerators.enumerate_choices_get_scores(trace, key, jnp.zeros((100,3))) + scores = enumerators.enumerate_choices_get_scores(trace, key, jnp.zeros((100, 3))) assert trace["parent_0"] == -1 assert (trace["camera_pose"] == jnp.eye(4)).all() diff --git a/test/test_icp.py b/test/test_icp.py index ebf7d270..1b01f29f 100644 --- a/test/test_icp.py +++ b/test/test_icp.py @@ -1,74 +1,72 @@ -import numpy as np -import jax.numpy as jnp +import os + import jax -import time -from PIL import Image -import matplotlib.pyplot as plt -import cv2 import jax.numpy as jnp + import bayes3d as b -from scipy.spatial.transform import Rotation as R -import os b.setup_visualizer() N = 100 -cloud = jax.random.uniform(jax.random.PRNGKey(10), shape=(N, 3))*0.1 +cloud = jax.random.uniform(jax.random.PRNGKey(10), shape=(N, 3)) * 0.1 b.show_cloud("c", cloud) -pose = b.distributions.gaussian_vmf_zero_mean(jax.random.PRNGKey(5), 0.1,10.0) +pose = b.distributions.gaussian_vmf_zero_mean(jax.random.PRNGKey(5), 0.1, 10.0) cloud_transformed = b.apply_transform(cloud, pose) b.show_cloud("d", cloud_transformed, color=b.RED) -transform = b.utils.find_least_squares_transform_between_clouds(cloud, cloud_transformed) +transform = b.utils.find_least_squares_transform_between_clouds( + cloud, cloud_transformed +) print(jnp.abs(cloud - cloud_transformed).sum()) print(jnp.abs(cloud_transformed - b.apply_transform(cloud, transform)).sum()) intrinsics = b.Intrinsics( - height=50, - width=50, - fx=50.0, fy=50.0, - cx=25.0, cy=25.0, - near=0.01, far=1.0 + height=50, width=50, fx=50.0, fy=50.0, cx=25.0, cy=25.0, near=0.01, far=1.0 ) b.setup_renderer(intrinsics) -model_dir = os.path.join(b.utils.get_assets_dir(),"bop/ycbv/models") +model_dir = os.path.join(b.utils.get_assets_dir(), "bop/ycbv/models") meshes = [] -for idx in range(1,22): - mesh_path = os.path.join(model_dir,"obj_" + "{}".format(idx).rjust(6, '0') + ".ply") - b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0/1000.0) - -b.RENDERER.add_mesh_from_file(os.path.join(b.utils.get_assets_dir(), "sample_objs/cube.obj"), scaling_factor=1.0/1000000000.0) - - +for idx in range(1, 22): + mesh_path = os.path.join( + model_dir, "obj_" + "{}".format(idx).rjust(6, "0") + ".ply" + ) + b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0 / 1000.0) + +b.RENDERER.add_mesh_from_file( + os.path.join(b.utils.get_assets_dir(), "sample_objs/cube.obj"), + scaling_factor=1.0 / 1000000000.0, +) pose = b.t3d.transform_from_pos(jnp.array([-1.0, -1.0, 4.0])) -pose2 = pose @ b.distributions.gaussian_vmf_zero_mean(jax.random.PRNGKey(5), 0.05,1000.0) - +pose2 = pose @ b.distributions.gaussian_vmf_zero_mean( + jax.random.PRNGKey(5), 0.05, 1000.0 +) b.show_pose("1", pose) b.show_pose("2", pose2) -img1 = b.RENDERER.render(pose.reshape(-1,4,4), jnp.array([0]))[...,:3] -img2 = b.RENDERER.render(pose2.reshape(-1,4,4), jnp.array([0]))[...,:3] +img1 = b.RENDERER.render(pose.reshape(-1, 4, 4), jnp.array([0]))[..., :3] +img2 = b.RENDERER.render(pose2.reshape(-1, 4, 4), jnp.array([0]))[..., :3] b.clear() -b.show_cloud("c", img1.reshape(-1,3)) -b.show_cloud("d", img2.reshape(-1,3), color=b.RED) +b.show_cloud("c", img1.reshape(-1, 3)) +b.show_cloud("d", img2.reshape(-1, 3), color=b.RED) -mask = (img1[:,:,2] < intrinsics.far) * (img2[:,:,2] < intrinsics.far) +mask = (img1[:, :, 2] < intrinsics.far) * (img2[:, :, 2] < 
intrinsics.far) -transform = b.utils.find_least_squares_transform_between_clouds(img1[mask,:], img2[mask, :]) +transform = b.utils.find_least_squares_transform_between_clouds( + img1[mask, :], img2[mask, :] +) print(jnp.abs(img2[mask, :] - img1[mask, :]).sum()) print(jnp.abs(img2[mask, :] - b.apply_transform(img1[mask, :], transform)).sum()) print(jnp.abs(cloud_transformed - b.apply_transform(cloud, transform)).sum()) - diff --git a/test/test_kubric.py b/test/test_kubric.py index 6549d58e..63108aff 100644 --- a/test/test_kubric.py +++ b/test/test_kubric.py @@ -1,45 +1,49 @@ -import jax.numpy as jnp -import bayes3d as b -import trimesh import os -import numpy as np + +import jax.numpy as jnp import trimesh +from IPython import embed from tqdm import tqdm + +import bayes3d as b from bayes3d.rendering.photorealistic_renderers.kubric_interface import render_many # --- creating the ycb dir from the working directory bop_ycb_dir = os.path.join(b.utils.get_assets_dir(), "bop/ycbv") -rgbd, gt_ids, gt_poses, masks = b.utils.ycb_loader.get_test_img('52', '1', bop_ycb_dir) - +rgbd, gt_ids, gt_poses, masks = b.utils.ycb_loader.get_test_img("52", "1", bop_ycb_dir) mesh_paths = [] offset_poses = [] model_dir = os.path.join(b.utils.get_assets_dir(), "ycb_video_models/models") for i in tqdm(gt_ids): - mesh_path = os.path.join(model_dir, b.utils.ycb_loader.MODEL_NAMES[i],"textured.obj") + mesh_path = os.path.join( + model_dir, b.utils.ycb_loader.MODEL_NAMES[i], "textured.obj" + ) _, pose = b.utils.mesh.center_mesh(trimesh.load(mesh_path), return_pose=True) offset_poses.append(pose) - mesh_paths.append( - mesh_path - ) + mesh_paths.append(mesh_path) intrinsics = b.Intrinsics( - rgbd.intrinsics.height, rgbd.intrinsics.width, - 200.0, 200.0, - rgbd.intrinsics.width/2, rgbd.intrinsics.height/2, - rgbd.intrinsics.near, rgbd.intrinsics.far + rgbd.intrinsics.height, + rgbd.intrinsics.width, + 200.0, + 200.0, + rgbd.intrinsics.width / 2, + rgbd.intrinsics.height / 2, + rgbd.intrinsics.near, + rgbd.intrinsics.far, ) print(intrinsics) poses = [] for i in range(len(gt_ids)): - poses.append( - gt_poses[i] @ b.t3d.inverse_pose(offset_poses[i]) - ) + poses.append(gt_poses[i] @ b.t3d.inverse_pose(offset_poses[i])) poses = jnp.array(poses) -rgbds = render_many(mesh_paths, poses[None,...], intrinsics, scaling_factor=1.0, lighting=5.0) +rgbds = render_many( + mesh_paths, poses[None, ...], intrinsics, scaling_factor=1.0, lighting=5.0 +) b.setup_renderer(intrinsics) @@ -50,8 +54,11 @@ kubri_rgb = b.get_rgb_image(rgbds[0].rgb) kubric_depth = b.get_depth_image(rgbds[0].depth) -rerendered_depth = b.get_depth_image(img[:,:,2]) +rerendered_depth = b.get_depth_image(img[:, :, 2]) overlay = b.overlay_image(kubric_depth, rerendered_depth, alpha=0.5) -b.multi_panel([kubri_rgb, kubric_depth, rerendered_depth, overlay],labels=["kubric_rgb", "kubric_depth", "rerendered_depth", "overlay"]).save("test_kubric.png") +b.multi_panel( + [kubri_rgb, kubric_depth, rerendered_depth, overlay], + labels=["kubric_rgb", "kubric_depth", "rerendered_depth", "overlay"], +).save("test_kubric.png") -from IPython import embed; embed() \ No newline at end of file +embed() diff --git a/test/test_likelihood.py b/test/test_likelihood.py index d1dd88c1..5b3ada2f 100644 --- a/test/test_likelihood.py +++ b/test/test_likelihood.py @@ -1,17 +1,8 @@ -import numpy as np import jax.numpy as jnp -import jax + import bayes3d as b -import trimesh -import os -import time -H=100 -W=200 -observed_xyz, rendered_xyz = jnp.ones((H,W,3)), jnp.ones((H,W,3)) +H = 100 +W = 200 
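+# Smoke test: observed and rendered clouds are identical all-ones arrays, so
+# this only checks that the likelihood evaluates without error.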
+observed_xyz, rendered_xyz = jnp.ones((H, W, 3)), jnp.ones((H, W, 3)) b.threedp3_likelihood(observed_xyz, rendered_xyz, 0.007, 0.1, 0.1, 1.0, 3) - - - - - diff --git a/test/test_open3d.py b/test/test_open3d.py index 03a10d69..5972ef84 100644 --- a/test/test_open3d.py +++ b/test/test_open3d.py @@ -1,59 +1,60 @@ -import jax.numpy as jnp -import bayes3d as b -import trimesh import os -import numpy as np + +import jax.numpy as jnp import trimesh +from IPython import embed from tqdm import tqdm -from bayes3d.viz.open3dviz import Open3DVisualizer +import bayes3d as b +from bayes3d.viz.open3dviz import Open3DVisualizer # --- creating the ycb dir from the working directory bop_ycb_dir = os.path.join(b.utils.get_assets_dir(), "bop/ycbv") -rgbd, gt_ids, gt_poses, masks = b.utils.ycb_loader.get_test_img('52', '1', bop_ycb_dir) - +rgbd, gt_ids, gt_poses, masks = b.utils.ycb_loader.get_test_img("52", "1", bop_ycb_dir) mesh_paths = [] offset_poses = [] model_dir = os.path.join(b.utils.get_assets_dir(), "ycb_video_models/models") for i in tqdm(gt_ids): - mesh_path = os.path.join(model_dir, b.utils.ycb_loader.MODEL_NAMES[i],"textured.obj") + mesh_path = os.path.join( + model_dir, b.utils.ycb_loader.MODEL_NAMES[i], "textured.obj" + ) _, pose = b.utils.mesh.center_mesh(trimesh.load(mesh_path), return_pose=True) offset_poses.append(pose) - mesh_paths.append( - mesh_path - ) + mesh_paths.append(mesh_path) intrinsics = b.Intrinsics( - rgbd.intrinsics.height, rgbd.intrinsics.width, - rgbd.intrinsics.fx, rgbd.intrinsics.fx, - rgbd.intrinsics.width/2, rgbd.intrinsics.height/2, - rgbd.intrinsics.near, rgbd.intrinsics.far + rgbd.intrinsics.height, + rgbd.intrinsics.width, + rgbd.intrinsics.fx, + rgbd.intrinsics.fx, + rgbd.intrinsics.width / 2, + rgbd.intrinsics.height / 2, + rgbd.intrinsics.near, + rgbd.intrinsics.far, ) poses = [] for i in range(len(gt_ids)): - poses.append( - gt_poses[i] @ b.t3d.inverse_pose(offset_poses[i]) - ) + poses.append(gt_poses[i] @ b.t3d.inverse_pose(offset_poses[i])) poses = jnp.array(poses) visualizer = Open3DVisualizer(intrinsics) visualizer.clear() -for (pose, path) in zip(poses, mesh_paths): +for pose, path in zip(poses, mesh_paths): visualizer.make_mesh_from_file(path, pose) rgbd_textured_reconstruction = visualizer.capture_image(intrinsics, jnp.eye(4)) visualizer.clear() colors = b.viz.distinct_colors(len(gt_ids)) -for (i,(pose, path)) in enumerate(zip(poses, mesh_paths)): +for i, (pose, path) in enumerate(zip(poses, mesh_paths)): mesh = b.utils.load_mesh(path) visualizer.make_trimesh(mesh, pose, (*tuple(colors[i]), 1.0)) -rgbd_color_mesh_reconstruction= visualizer.capture_image(intrinsics, jnp.eye(4)) +rgbd_color_mesh_reconstruction = visualizer.capture_image(intrinsics, jnp.eye(4)) panel = b.viz.multi_panel( [ @@ -63,13 +64,9 @@ b.overlay_image( b.get_rgb_image(rgbd.rgb), b.get_rgb_image(rgbd_color_mesh_reconstruction.rgb), - ) + ), ] ) panel.save("test.png") - - - - -from IPython import embed; embed() \ No newline at end of file +embed() diff --git a/test/test_renderer.py b/test/test_renderer.py index 07206f30..6ce88bb6 100644 --- a/test/test_renderer.py +++ b/test/test_renderer.py @@ -1,19 +1,12 @@ -import numpy as np -import jax.numpy as jnp -import jax -import bayes3d as b -import trimesh import os -import time +import jax +import jax.numpy as jnp +from IPython import embed + +import bayes3d as b -intrinsics = b.Intrinsics( - 300, - 300, - 200.0,200.0, - 150.0,150.0, - 0.001, 50.0 -) +intrinsics = b.Intrinsics(300, 300, 200.0, 200.0, 150.0, 150.0, 0.001, 50.0) 
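+# Positional order follows the keyword form used in the other tests:
+# height, width, fx, fy, cx, cy, near, far.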
b.setup_renderer(intrinsics, num_layers=1) renderer = b.RENDERER @@ -21,86 +14,106 @@ outlier_prob = 0.01 max_depth = 15.0 -renderer.add_mesh_from_file(os.path.join(b.utils.get_assets_dir(),"sample_objs/cube.obj")) -renderer.add_mesh_from_file(os.path.join(b.utils.get_assets_dir(),"sample_objs/sphere.obj")) - +renderer.add_mesh_from_file( + os.path.join(b.utils.get_assets_dir(), "sample_objs/cube.obj") +) +renderer.add_mesh_from_file( + os.path.join(b.utils.get_assets_dir(), "sample_objs/sphere.obj") +) num_parallel_frames = 20 -gt_poses_1 = jnp.tile(jnp.array([ - [1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, -1.0], - [0.0, 0.0, 1.0, 8.0], - [0.0, 0.0, 0.0, 1.0], - ] -)[None,...],(num_parallel_frames,1,1)) -gt_poses_1 = gt_poses_1.at[:,0,3].set(jnp.linspace(-2.0, 2.0, gt_poses_1.shape[0])) -gt_poses_1 = gt_poses_1.at[:,2,3].set(jnp.linspace(10.0, 5.0, gt_poses_1.shape[0])) - -gt_poses_2 = jnp.tile(jnp.array([ - [1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, -1.0], - [0.0, 0.0, 1.0, 8.0], - [0.0, 0.0, 0.0, 1.0], - ] -)[None,...],(num_parallel_frames,1,1)) -gt_poses_2 = gt_poses_2.at[:,0,3].set(jnp.linspace(4.0, -3.0, gt_poses_2.shape[0])) -gt_poses_2 = gt_poses_2.at[:,2,3].set(jnp.linspace(12.0, 5.0, gt_poses_2.shape[0])) - -gt_poses_all = jnp.stack([gt_poses_1, gt_poses_2],axis=1) - -indices = jnp.array( [0, 1]) +gt_poses_1 = jnp.tile( + jnp.array( + [ + [1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, -1.0], + [0.0, 0.0, 1.0, 8.0], + [0.0, 0.0, 0.0, 1.0], + ] + )[None, ...], + (num_parallel_frames, 1, 1), +) +gt_poses_1 = gt_poses_1.at[:, 0, 3].set(jnp.linspace(-2.0, 2.0, gt_poses_1.shape[0])) +gt_poses_1 = gt_poses_1.at[:, 2, 3].set(jnp.linspace(10.0, 5.0, gt_poses_1.shape[0])) + +gt_poses_2 = jnp.tile( + jnp.array( + [ + [1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, -1.0], + [0.0, 0.0, 1.0, 8.0], + [0.0, 0.0, 0.0, 1.0], + ] + )[None, ...], + (num_parallel_frames, 1, 1), +) +gt_poses_2 = gt_poses_2.at[:, 0, 3].set(jnp.linspace(4.0, -3.0, gt_poses_2.shape[0])) +gt_poses_2 = gt_poses_2.at[:, 2, 3].set(jnp.linspace(12.0, 5.0, gt_poses_2.shape[0])) + +gt_poses_all = jnp.stack([gt_poses_1, gt_poses_2], axis=1) + +indices = jnp.array([0, 1]) multiobject_scene_img = renderer.render(gt_poses_all[-1, ...], jnp.array([0, 1])) -multiobject_viz = b.get_depth_image(multiobject_scene_img[:,:,2]) +multiobject_viz = b.get_depth_image(multiobject_scene_img[:, :, 2]) multiobject_scene_parallel_img = renderer.render_many(gt_poses_all, jnp.array([0, 1])) -multiobject_parallel_viz = b.get_depth_image(multiobject_scene_parallel_img[-1,:,:,2]) +multiobject_parallel_viz = b.get_depth_image( + multiobject_scene_parallel_img[-1, :, :, 2] +) -segmentation_viz = b.get_depth_image(multiobject_scene_parallel_img[-1,:,:,3]) +segmentation_viz = b.get_depth_image(multiobject_scene_parallel_img[-1, :, :, 3]) -images = [b.get_depth_image(multiobject_scene_parallel_img[i,:,:,2]) for i in range(num_parallel_frames)] +images = [ + b.get_depth_image(multiobject_scene_parallel_img[i, :, :, 2]) + for i in range(num_parallel_frames) +] b.multi_panel( [multiobject_viz, multiobject_parallel_viz, segmentation_viz] + images ).save("test_renderer.png") def test_segmentation_produces_sensical_outputs(): - assert jnp.allclose(multiobject_scene_parallel_img[-1,:,:,3].max(), 2.0) - assert jnp.allclose(multiobject_scene_parallel_img[-1,:,:,3].min(), 0.0) - assert jnp.allclose(multiobject_scene_img[:,:,3].max(), 2.0) - assert jnp.allclose(multiobject_scene_img[:,:,3].min(), 0.0) + assert jnp.allclose(multiobject_scene_parallel_img[-1, :, :, 3].max(), 2.0) + assert 
jnp.allclose(multiobject_scene_parallel_img[-1, :, :, 3].min(), 0.0) + assert jnp.allclose(multiobject_scene_img[:, :, 3].max(), 2.0) + assert jnp.allclose(multiobject_scene_img[:, :, 3].min(), 0.0) + def test_something_is_being_rendered(): - assert not jnp.all(multiobject_scene_parallel_img[0,:,:,2] == intrinsics.far) - assert not jnp.all(multiobject_scene_parallel_img[-1,:,:,2] == intrinsics.far) - assert not jnp.all(multiobject_scene_img[:,:,2] == intrinsics.far) + assert not jnp.all(multiobject_scene_parallel_img[0, :, :, 2] == intrinsics.far) + assert not jnp.all(multiobject_scene_parallel_img[-1, :, :, 2] == intrinsics.far) + assert not jnp.all(multiobject_scene_img[:, :, 2] == intrinsics.far) + def render(key): pose = b.distributions.vmf(key, 1.0) img = b.RENDERER.render(jnp.array([pose]), jnp.array([0])) return img + render_parallel = jax.jit(jax.vmap(render)) x = render_parallel(jax.random.split(jax.random.PRNGKey(10), 10)) -y = renderer.render_many_custom_intrinsics(gt_poses_all, jnp.array([0, 1]),intrinsics) +y = renderer.render_many_custom_intrinsics(gt_poses_all, jnp.array([0, 1]), intrinsics) render_many_custom_intrinsics_jit = jax.jit(renderer.render_many_custom_intrinsics) -z = render_many_custom_intrinsics_jit(gt_poses_all, jnp.array([0, 1]),intrinsics) +z = render_many_custom_intrinsics_jit(gt_poses_all, jnp.array([0, 1]), intrinsics) -render_parallel = jax.jit(jax.vmap(b.RENDERER.render, in_axes=(0,None))) +render_parallel = jax.jit(jax.vmap(b.RENDERER.render, in_axes=(0, None))) render_parallel(gt_poses_all, jnp.array([0, 1])).shape -render_many_custom_intrinsics_parallel = jax.vmap(renderer.render, in_axes=(0,None)) +render_many_custom_intrinsics_parallel = jax.vmap(renderer.render, in_axes=(0, None)) render_many_custom_intrinsics_parallel(gt_poses_all, jnp.array([0, 1])).shape -render_many_custom_intrinsics_parallel = jax.vmap(renderer.render_custom_intrinsics, in_axes=(0,None, None)) +render_many_custom_intrinsics_parallel = jax.vmap( + renderer.render_custom_intrinsics, in_axes=(0, None, None) +) del render_many_custom_intrinsics_parallel -render_many_custom_intrinsics_parallel(gt_ -poses_all, jnp.array([0, 1]), intrinsics).shape +# render_many_custom_intrinsics_parallel(gt_poses_all, jnp.array([0, 1]), intrinsics).shape -from IPython import embed; embed() \ No newline at end of file +embed() diff --git a/test/test_renderer_internals.py b/test/test_renderer_internals.py index b6bce357..b330523a 100644 --- a/test/test_renderer_internals.py +++ b/test/test_renderer_internals.py @@ -1,21 +1,13 @@ -import numpy as np -import jax.numpy as jnp -import jax -import bayes3d as b -import trimesh import os -import time + +import jax +import jax.numpy as jnp import torch -import bayes3d._rendering.nvdiffrast.common as dr +import bayes3d as b +import bayes3d._rendering.nvdiffrast.common as dr -intrinsics = b.Intrinsics( - 300, - 300, - 200.0,200.0, - 150.0,150.0, - 0.001, 50.0 -) +intrinsics = b.Intrinsics(300, 300, 200.0, 200.0, 150.0, 150.0, 0.001, 50.0) b.setup_renderer(intrinsics) renderer = b.RENDERER @@ -23,28 +15,41 @@ outlier_prob = 0.01 max_depth = 15.0 -renderer.add_mesh_from_file(os.path.join(b.utils.get_assets_dir(),"sample_objs/cube.obj")) -renderer.add_mesh_from_file(os.path.join(b.utils.get_assets_dir(),"sample_objs/sphere.obj")) +renderer.add_mesh_from_file( + os.path.join(b.utils.get_assets_dir(), "sample_objs/cube.obj") +) +renderer.add_mesh_from_file( + os.path.join(b.utils.get_assets_dir(), "sample_objs/sphere.obj") +) -poses = jnp.array([ - [1.0, 0.0, 0.0, 
0.0], - [0.0, 1.0, 0.0, -1.0], - [0.0, 0.0, 1.0, 8.0], - [0.0, 0.0, 0.0, 1.0], +poses = jnp.array( + [ + [1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, -1.0], + [0.0, 0.0, 1.0, 8.0], + [0.0, 0.0, 0.0, 1.0], ] )[None, None, ...] indices = jnp.array([0]) -img = jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(dr._get_plugin(gl=True).rasterize_fwd_gl( - b.RENDERER.renderer_env.cpp_wrapper, - torch.utils.dlpack.from_dlpack(jax.dlpack.to_dlpack(jnp.tile(poses, (1,2,1,1)))), - b.RENDERER.proj_list, - [0] -))) -b.get_depth_image(img[0,:,:,2]).save("1.png") -assert not jnp.all(img[0,:,:,2] == 0.0) - -multiobject_scene_img = renderer._render_many(jnp.tile(poses, (2,1,1,1)), jnp.array([1]))[0] -b.get_depth_image(multiobject_scene_img[:,:,2]).save("0.png") -assert not jnp.all( multiobject_scene_img[:,:,2] == 0.0) +img = jax.dlpack.from_dlpack( + torch.utils.dlpack.to_dlpack( + dr._get_plugin(gl=True).rasterize_fwd_gl( + b.RENDERER.renderer_env.cpp_wrapper, + torch.utils.dlpack.from_dlpack( + jax.dlpack.to_dlpack(jnp.tile(poses, (1, 2, 1, 1))) + ), + b.RENDERER.proj_list, + [0], + ) + ) +) +b.get_depth_image(img[0, :, :, 2]).save("1.png") +assert not jnp.all(img[0, :, :, 2] == 0.0) + +multiobject_scene_img = renderer._render_many( + jnp.tile(poses, (2, 1, 1, 1)), jnp.array([1]) +)[0] +b.get_depth_image(multiobject_scene_img[:, :, 2]).save("0.png") +assert not jnp.all(multiobject_scene_img[:, :, 2] == 0.0) diff --git a/test/test_renderer_memory.py b/test/test_renderer_memory.py index 940bda13..42b498e8 100644 --- a/test/test_renderer_memory.py +++ b/test/test_renderer_memory.py @@ -1,19 +1,12 @@ -import numpy as np +import gc +import os + import jax.numpy as jnp -import jax + import bayes3d as b -import os -import gc -import time # setup renderer -intrinsics = b.Intrinsics( - 50, - 50, - 200.0,200.0, - 25.0,25.0, - 0.001, 10.0 -) +intrinsics = b.Intrinsics(50, 50, 200.0, 200.0, 25.0, 25.0, 0.001, 10.0) # Note: removing the b.RENDERER object does the same operation in C++ as clear_meshmem() b.setup_renderer(intrinsics, num_layers=1) renderer = b.RENDERER @@ -21,24 +14,27 @@ pre_test_clearmesh = b.utils.get_gpu_memory()[0] for i in range(5): - b.setup_renderer(intrinsics, num_layers=1) renderer = b.RENDERER pre_add_mesh = b.utils.get_gpu_memory()[0] for x in range(1): - renderer.add_mesh_from_file(os.path.join(b.utils.get_assets_dir(),"sample_objs/cube.obj"), mesh_name = f'cube_{i+1}') + renderer.add_mesh_from_file( + os.path.join(b.utils.get_assets_dir(), "sample_objs/cube.obj"), + mesh_name=f"cube_{i+1}", + ) post_add_mesh = b.utils.get_gpu_memory()[0] - pose = jnp.array([ - [1.0, 0.0, 0.0, 0.5], - [0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 10.0], - [0.0, 0.0, 0.0, 1.0], + pose = jnp.array( + [ + [1.0, 0.0, 0.0, 0.5], + [0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 10.0], + [0.0, 0.0, 0.0, 1.0], ] ) - depth = renderer.render(pose[None,...], jnp.array([0]))[...,2] + depth = renderer.render(pose[None, ...], jnp.array([0]))[..., 2] post_render = b.utils.get_gpu_memory()[0] @@ -46,16 +42,21 @@ post_clear_meshmem = b.utils.get_gpu_memory()[0] - # ensure the mesh memory is fully cleared assert pre_add_mesh - post_add_mesh == post_clear_meshmem - post_render gc.collect() - print(f"{i}: ",b.utils.get_gpu_memory()[0]) + print(f"{i}: ", b.utils.get_gpu_memory()[0]) post_test_clearmesh = b.utils.get_gpu_memory()[0] # Expected result should be around 2MiB for the given camera intrinsics -print("GPU memory lost with clear_meshmem() --> ", pre_test_clearmesh - post_test_clearmesh, " MiB") -print("The memeory lost is from the JAX 
memeory in GPU and not accumulations in the GPU")
\ No newline at end of file
+print(
+    "GPU memory lost with clear_meshmem() --> ",
+    pre_test_clearmesh - post_test_clearmesh,
+    " MiB",
+)
+print(
+    "The memory lost is from the JAX memory in GPU and not accumulations in the GPU"
+)
diff --git a/test/test_scene_graph.py b/test/test_scene_graph.py
index 93a767ba..130c3efc 100644
--- a/test/test_scene_graph.py
+++ b/test/test_scene_graph.py
@@ -1,17 +1,18 @@
-import bayes3d as b
-import jax.numpy as jnp
 import jax
-import os
+import jax.numpy as jnp

+import bayes3d as b

 N = 4
 scene_graph = b.scene_graph.SceneGraph(
-    root_poses= jnp.tile(jnp.eye(4)[None,...],(N,1,1)),
-    box_dimensions = jnp.ones((N,3)),
-    parents = jnp.array([-1, 0, 0, 2]),
-    contact_params = jax.random.uniform(jax.random.PRNGKey(10),(N,3), minval=-1.0, maxval=1.0),
-    face_parent = jnp.array([0 ,1, 1, 2]),
-    face_child = jnp.array([2 ,3, 4, 5])
+    root_poses=jnp.tile(jnp.eye(4)[None, ...], (N, 1, 1)),
+    box_dimensions=jnp.ones((N, 3)),
+    parents=jnp.array([-1, 0, 0, 2]),
+    contact_params=jax.random.uniform(
+        jax.random.PRNGKey(10), (N, 3), minval=-1.0, maxval=1.0
+    ),
+    face_parent=jnp.array([0, 1, 1, 2]),
+    face_child=jnp.array([2, 3, 4, 5]),
 )

 scene_graph.visualize("graph.png", node_names=["table", "apple", "can", "banana"])
@@ -28,12 +29,17 @@
 dims_parent = scene_graph.box_dimensions[parent_object_index]
 dims_child = scene_graph.box_dimensions[child_object_index]

-parent_contact_plane = parent_pose @ b.scene_graph.get_contact_planes(dims_parent)[face_parent]
-child_contact_plane = child_pose @ b.scene_graph.get_contact_planes(dims_child)[face_child]
+parent_contact_plane = (
+    parent_pose @ b.scene_graph.get_contact_planes(dims_parent)[face_parent]
+)
+child_contact_plane = (
+    child_pose @ b.scene_graph.get_contact_planes(dims_child)[face_child]
+)

-contact_params, slack = b.scene_graph.closest_approximate_contact_params(parent_contact_plane, child_contact_plane)
-assert jnp.isclose(slack[:3,3], 0.0, atol=1e-7).all()
-assert jnp.isclose(slack[:3,:3], jnp.eye(3), atol=1e-7).all()
+contact_params, slack = b.scene_graph.closest_approximate_contact_params(
+    parent_contact_plane, child_contact_plane
+)
+assert jnp.isclose(slack[:3, 3], 0.0, atol=1e-7).all()
+assert jnp.isclose(slack[:3, :3], jnp.eye(3), atol=1e-7).all()
 assert jnp.isclose(contact_params, scene_graph.contact_params[child_object_index]).all()
-
diff --git a/test/test_splatting.py b/test/test_splatting.py
index a2b9151d..3a6f15c4 100644
--- a/test/test_splatting.py
+++ b/test/test_splatting.py
@@ -1,26 +1,24 @@
-import diff_gaussian_rasterization as dgr
-from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
-import torch
-import os
-import numpy as np
-import matplotlib.pyplot as plt
 import math
-from tqdm import tqdm
-import bayes3d as b
+
 import jax.numpy as jnp
+import torch
+from diff_gaussian_rasterization import (
+    GaussianRasterizationSettings,
+    GaussianRasterizer,
+)
+from IPython import embed
+
+import bayes3d as b

 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

 intrinsics = b.Intrinsics(
-    height=100,
-    width=100,
-    fx=50.0, fy=50.0,
-    cx=50.0, cy=50.0,
-    near=0.1, far=6.0
+    height=100, width=100, fx=50.0, fy=50.0, cx=50.0, cy=50.0, near=0.1, far=6.0
 )

 # proj_matrix = torch.tensor(b.camera._open_gl_projection_matrix(intrinsics.height, intrinsics.width, intrinsics.fx, intrinsics.fy, intrinsics.cx, intrinsics.cy, intrinsics.near, intrinsics.far).astype(np.float32), device=device)
 # 
print(proj_matrix) + def getProjectionMatrix(znear, zfar, fovX, fovY): tanHalfFovY = math.tan((fovY / 2)) tanHalfFovX = math.tan((fovX / 2)) @@ -43,20 +41,23 @@ def getProjectionMatrix(znear, zfar, fovX, fovY): P[2, 3] = -(zfar * znear) / (zfar - znear) return P -fovX = jnp.deg2rad(45) + +fovX = jnp.deg2rad(45) fovY = jnp.deg2rad(45) -proj_matrix = torch.tensor(getProjectionMatrix(intrinsics.near, intrinsics.far, fovX, fovY), device=device) +proj_matrix = torch.tensor( + getProjectionMatrix(intrinsics.near, intrinsics.far, fovX, fovY), device=device +) N = 1 # means3D = torch.rand((N, 3)).cuda() - 0.5 + torch.tensor([0.0, 0.0, 4.8],device= device) -means3D = torch.tensor([[-0.01, 0.01, 1.0]],device= device) -means2D = torch.ones((N, 3),device= device) +means3D = torch.tensor([[-0.01, 0.01, 1.0]], device=device) +means2D = torch.ones((N, 3), device=device) colors = torch.rand((N, 3)).cuda() opacity = torch.rand((N, 1)).cuda() + 0.5 scales = torch.rand((N, 3)).cuda() * 0.1 rotations = torch.rand((N, 4)).cuda() -# tan_fovx = intrinsics.width / intrinsics.fx / 2.0 +# tan_fovx = intrinsics.width / intrinsics.fx / 2.0 # tan_fovy = intrinsics.hei"ght / intrinsics.fy / 2.0 # print(tan_fovx, tan_fovy) @@ -68,46 +69,43 @@ def getProjectionMatrix(znear, zfar, fovX, fovY): image_width=int(intrinsics.width), tanfovx=tan_fovx, tanfovy=tan_fovy, - bg=torch.tensor([1.,1.,1.]).cuda(), + bg=torch.tensor([1.0, 1.0, 1.0]).cuda(), scale_modifier=1.0, viewmatrix=torch.eye(4).cuda(), projmatrix=proj_matrix, sh_degree=1, campos=torch.zeros(3).cuda(), prefiltered=False, - debug=None + debug=None, ) rasterizer = GaussianRasterizer(raster_settings=raster_settings) ground_truth_image, radii = rasterizer( - means3D = means3D, - means2D = means2D, - shs = None, - colors_precomp = colors, - opacities = opacity, - scales = scales, - rotations = rotations, - cov3D_precomp = None + means3D=means3D, + means2D=means2D, + shs=None, + colors_precomp=colors, + opacities=opacity, + scales=scales, + rotations=rotations, + cov3D_precomp=None, ) -from IPython import embed; embed() +embed() -p_hom = torch.transpose(proj_matrix,0,1) @ torch.tensor([means3D[0,0], means3D[0,1],means3D[0,2], 1.0], device= device) +p_hom = torch.transpose(proj_matrix, 0, 1) @ torch.tensor( + [means3D[0, 0], means3D[0, 1], means3D[0, 2], 1.0], device=device +) print(p_hom) p_proj = p_hom / p_hom[3] print(p_proj) - - - - - # proj_matrix = torch.tensor(getProjectionMatrix(intrinsics.near, intrinsics.far, fovX, fovY), device=device) -# p_orig -0.010000 0.010000 1.000000 +# p_orig -0.010000 0.010000 1.000000 # p_hom -0.024142 0.024142 2.016949 -0.101695 # p_proj 0.237398 -0.237398 -19.833355 -# point_image 61.369896 37.630104 \ No newline at end of file +# point_image 61.369896 37.630104 diff --git a/test/test_transforms_3d.py b/test/test_transforms_3d.py index 41f36997..600fb710 100644 --- a/test/test_transforms_3d.py +++ b/test/test_transforms_3d.py @@ -1,13 +1,12 @@ -import bayes3d as b import jax import jax.numpy as jnp - +import bayes3d as b def test_estimate_transform_between_clouds(): key = jax.random.PRNGKey(500) - c1 = jax.random.uniform(jax.random.PRNGKey(0), (10,3)) * 5.0 + c1 = jax.random.uniform(jax.random.PRNGKey(0), (10, 3)) * 5.0 random_pose = b.distributions.gaussian_vmf_zero_mean(key, 0.1, 1.0) c2 = b.t3d.apply_transform(c1, random_pose) diff --git a/test/test_viz.py b/test/test_viz.py index 326ae15e..9b3e74fd 100644 --- a/test/test_viz.py +++ b/test/test_viz.py @@ -1,23 +1,22 @@ -import jax.numpy as jnp -import bayes3d as b import os 
-import jax -import functools + +import jax.numpy as jnp import matplotlib.pyplot as plt -import pathlib + +import bayes3d as b bop_ycb_dir = os.path.join(b.utils.get_assets_dir(), "bop/ycbv") -rgbd, gt_ids, gt_poses, masks = b.utils.ycb_loader.get_test_img('52', '1', bop_ycb_dir) +rgbd, gt_ids, gt_poses, masks = b.utils.ycb_loader.get_test_img("52", "1", bop_ycb_dir) fig = b.viz_depth_image(rgbd.depth) fig.savefig("depth.png", **b.saveargs) fig = b.viz_rgb_image(rgbd.rgb) fig.savefig("rgb.png", **b.saveargs) fig = plt.figure() -ax = fig.add_subplot(1,2,1) +ax = fig.add_subplot(1, 2, 1) b.add_rgb_image(ax, rgbd.rgb) ax.set_title("RGB") -ax = fig.add_subplot(1,2,2) +ax = fig.add_subplot(1, 2, 2) b.add_depth_image(ax, rgbd.depth) ax.set_title("DEPTH") fig.savefig("fig.png", **b.saveargs) @@ -28,39 +27,37 @@ ################################################################################## # set up renderer -intrinsics = b.Intrinsics( - 50, - 50, - 200.0,200.0, - 25.0,25.0, - 0.001, 20.0 -) +intrinsics = b.Intrinsics(50, 50, 200.0, 200.0, 25.0, 25.0, 0.001, 20.0) b.setup_renderer(intrinsics) renderer = b.RENDERER -renderer.add_mesh_from_file(os.path.join(b.utils.get_assets_dir(),"sample_objs/cube.obj")) +renderer.add_mesh_from_file( + os.path.join(b.utils.get_assets_dir(), "sample_objs/cube.obj") +) # Test 1: check if b.get_depth_image returns a valid image if there is no object in the scene -no_object_in_scene_pose = jnp.array([ - [1.0, 0.0, 0.0, -100.0], - [0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 10.0], - [0.0, 0.0, 0.0, 1.0], +no_object_in_scene_pose = jnp.array( + [ + [1.0, 0.0, 0.0, -100.0], + [0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 10.0], + [0.0, 0.0, 0.0, 1.0], ] ) -depth = renderer.render(no_object_in_scene_pose[None,...], jnp.array([0]))[...,2] -depth_image = b.scale_image(b.get_depth_image(depth),8) -depth_image.save('viz_test_no_object_in_scene.png') +depth = renderer.render(no_object_in_scene_pose[None, ...], jnp.array([0]))[..., 2] +depth_image = b.scale_image(b.get_depth_image(depth), 8) +depth_image.save("viz_test_no_object_in_scene.png") # Test 2: check if b.get_depth_image returns a valid image if object has only one unique depth value -object_unique_depth_pose = jnp.array([ - [1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 10.0], - [0.0, 0.0, 0.0, 1.0], +object_unique_depth_pose = jnp.array( + [ + [1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 10.0], + [0.0, 0.0, 0.0, 1.0], ] ) -depth = renderer.render(object_unique_depth_pose[None,...], jnp.array([0]))[...,2] -assert jnp.unique(depth).size == 2 # far and object's depth -depth_image = b.scale_image(b.get_depth_image(depth),8) -depth_image.save('viz_test_object_unique_depth.png') \ No newline at end of file +depth = renderer.render(object_unique_depth_pose[None, ...], jnp.array([0]))[..., 2] +assert jnp.unique(depth).size == 2 # far and object's depth +depth_image = b.scale_image(b.get_depth_image(depth), 8) +depth_image.save("viz_test_object_unique_depth.png") diff --git a/test/test_ycb_loading.py b/test/test_ycb_loading.py index 92b0d9f1..f1019ae0 100644 --- a/test/test_ycb_loading.py +++ b/test/test_ycb_loading.py @@ -1,93 +1,127 @@ +import os + +import jax import jax.numpy as jnp -import bayes3d as b import numpy as np + +import bayes3d as b import bayes3d.utils.ycb_loader -import trimesh -import jax -import os -from tqdm import tqdm def test_ycb_loading(): bop_ycb_dir = os.path.join(b.utils.get_assets_dir(), "bop/ycbv") - rgbd, gt_ids, gt_poses, masks = 
b.utils.ycb_loader.get_test_img('52', '1', bop_ycb_dir) + rgbd, gt_ids, gt_poses, masks = b.utils.ycb_loader.get_test_img( + "52", "1", bop_ycb_dir + ) b.setup_renderer(rgbd.intrinsics, num_layers=1) - model_dir =os.path.join(b.utils.get_assets_dir(), "bop/ycbv/models") - for idx in range(1,22): - b.RENDERER.add_mesh_from_file(os.path.join(model_dir,"obj_" + "{}".format(idx).rjust(6, '0') + ".ply"),scaling_factor=1.0/1000.0) + model_dir = os.path.join(b.utils.get_assets_dir(), "bop/ycbv/models") + for idx in range(1, 22): + b.RENDERER.add_mesh_from_file( + os.path.join(model_dir, "obj_" + "{}".format(idx).rjust(6, "0") + ".ply"), + scaling_factor=1.0 / 1000.0, + ) - reconstruction_depth = b.RENDERER.render(gt_poses, gt_ids)[:,:,2] + reconstruction_depth = b.RENDERER.render(gt_poses, gt_ids)[:, :, 2] match_fraction = (jnp.abs(rgbd.depth - reconstruction_depth) < 0.05).mean() assert match_fraction > 0.2 + bop_ycb_dir = os.path.join(b.utils.get_assets_dir(), "bop/ycbv") -rgbd, gt_ids, gt_poses, masks = b.utils.ycb_loader.get_test_img('55', '22', bop_ycb_dir) -poses = jnp.concatenate([jnp.eye(4)[None,...], rgbd.camera_pose @ gt_poses],axis=0) -ids = jnp.concatenate([jnp.array([21]), gt_ids],axis=0) +rgbd, gt_ids, gt_poses, masks = b.utils.ycb_loader.get_test_img("55", "22", bop_ycb_dir) +poses = jnp.concatenate([jnp.eye(4)[None, ...], rgbd.camera_pose @ gt_poses], axis=0) +ids = jnp.concatenate([jnp.array([21]), gt_ids], axis=0) b.setup_renderer(rgbd.intrinsics, num_layers=1) -model_dir =os.path.join(b.utils.get_assets_dir(), "bop/ycbv/models") -for idx in range(1,22): - b.RENDERER.add_mesh_from_file(os.path.join(model_dir,"obj_" + "{}".format(idx).rjust(6, '0') + ".ply"),scaling_factor=1.0/1000.0) - -b.RENDERER.add_mesh_from_file(os.path.join(b.utils.get_assets_dir(), "sample_objs/cube.obj"), scaling_factor=1.0/1000000000.0) - - +model_dir = os.path.join(b.utils.get_assets_dir(), "bop/ycbv/models") +for idx in range(1, 22): + b.RENDERER.add_mesh_from_file( + os.path.join(model_dir, "obj_" + "{}".format(idx).rjust(6, "0") + ".ply"), + scaling_factor=1.0 / 1000.0, + ) +b.RENDERER.add_mesh_from_file( + os.path.join(b.utils.get_assets_dir(), "sample_objs/cube.obj"), + scaling_factor=1.0 / 1000000000.0, +) scene_graph = b.scene_graph.SceneGraph( root_poses=poses, box_dimensions=b.RENDERER.model_box_dims[ids], parents=jnp.full(poses.shape[0], -1), - contact_params=jnp.zeros((poses.shape[0],3)), + contact_params=jnp.zeros((poses.shape[0], 3)), face_parent=jnp.zeros(poses.shape[0], dtype=jnp.int32), face_child=jnp.zeros(poses.shape[0], dtype=jnp.int32), ) assert jnp.isclose(scene_graph.get_poses(), poses).all() -def get_slack(scene_graph, parent_object_index, child_object_index, face_parent, face_child): + +def get_slack( + scene_graph, parent_object_index, child_object_index, face_parent, face_child +): parent_pose = scene_graph.get_poses()[parent_object_index] child_pose = scene_graph.get_poses()[child_object_index] dims_parent = scene_graph.box_dimensions[parent_object_index] dims_child = scene_graph.box_dimensions[child_object_index] - parent_contact_plane = parent_pose @ b.scene_graph.get_contact_planes(dims_parent)[face_parent] - child_contact_plane = child_pose @ b.scene_graph.get_contact_planes(dims_child)[face_child] + parent_contact_plane = ( + parent_pose @ b.scene_graph.get_contact_planes(dims_parent)[face_parent] + ) + child_contact_plane = ( + child_pose @ b.scene_graph.get_contact_planes(dims_child)[face_child] + ) + + contact_params, slack = 
b.scene_graph.closest_approximate_contact_params( + parent_contact_plane, child_contact_plane + ) + return ( + jnp.array([parent_object_index, child_object_index, face_parent, face_child]), + contact_params, + slack, + ) - contact_params, slack = b.scene_graph.closest_approximate_contact_params(parent_contact_plane, child_contact_plane) - return jnp.array([parent_object_index, child_object_index, face_parent, face_child]), contact_params, slack add_edge_scene_graph = jax.jit(b.scene_graph.add_edge_scene_graph) - N = poses.shape[0] b.setup_visualizer() -get_slack_vmap = jax.jit(b.utils.multivmap(get_slack, (False, False, False, True, True))) +get_slack_vmap = jax.jit( + b.utils.multivmap(get_slack, (False, False, False, True, True)) +) -edges = [(0,1),(0,2),(0,3),(0,4),(0,6),(2,5)] -for i,j in edges: - settings, contact_params, slacks = get_slack_vmap(scene_graph, i,j, jnp.arange(6), jnp.arange(6)) - settings = settings.reshape(-1,settings.shape[-1]) - contact_params = contact_params.reshape(-1,contact_params.shape[-1]) - error = jnp.abs(slacks - jnp.eye(4)).sum([-1,-2]).reshape(-1) +edges = [(0, 1), (0, 2), (0, 3), (0, 4), (0, 6), (2, 5)] +for i, j in edges: + settings, contact_params, slacks = get_slack_vmap( + scene_graph, i, j, jnp.arange(6), jnp.arange(6) + ) + settings = settings.reshape(-1, settings.shape[-1]) + contact_params = contact_params.reshape(-1, contact_params.shape[-1]) + error = jnp.abs(slacks - jnp.eye(4)).sum([-1, -2]).reshape(-1) indices = jnp.argsort(error.reshape(-1)) - parent_object_index, child_object_index, face_parent, face_child = settings[indices[0]] - scene_graph = add_edge_scene_graph(scene_graph,parent_object_index, child_object_index, face_parent, face_child, contact_params[indices[0]]) + parent_object_index, child_object_index, face_parent, face_child = settings[ + indices[0] + ] + scene_graph = add_edge_scene_graph( + scene_graph, + parent_object_index, + child_object_index, + face_parent, + face_child, + contact_params[indices[0]], + ) node_names = np.array([*b.utils.ycb_loader.MODEL_NAMES, "table"]) -scene_graph.visualize("graph.png", node_names=list(map(str,enumerate(node_names[ids])))) +scene_graph.visualize( + "graph.png", node_names=list(map(str, enumerate(node_names[ids]))) +) b.clear() -for i,p in enumerate(scene_graph.get_poses()): +for i, p in enumerate(scene_graph.get_poses()): b.show_trimesh(f"pose_{i}", b.RENDERER.meshes[ids[i]]) b.set_pose(f"pose_{i}", p) - - - From df18958664fe3b9aa7ae027e30dcb6a5960e737f Mon Sep 17 00:00:00 2001 From: Sam Ritchie Date: Tue, 23 Jan 2024 12:40:03 -0500 Subject: [PATCH 2/6] fix final ruff errors --- bayes3d/__init__.py | 10 +--------- bayes3d/_mkl/gaussian_renderer.py | 13 ++++++++++--- bayes3d/_mkl/simple_likelihood.py | 17 +++++++++-------- bayes3d/_mkl/trimesh_to_gaussians.py | 18 +++++++++--------- bayes3d/colmap/colmap_loader.py | 2 +- bayes3d/genjax/model.py | 9 +++++++-- .../neural/cosypose_baseline/cosypose_utils.py | 9 ++++----- bayes3d/renderer.py | 8 ++++---- bayes3d/scene_graph.py | 1 - bayes3d/utils/occlusion.py | 2 +- bayes3d/utils/pybullet_sim.py | 4 ++-- bayes3d/utils/r3d_loader.py | 3 +-- bayes3d/utils/ycb_loader.py | 2 +- pyproject.toml | 1 + scripts/_mkl/notebooks/kubric/kubric_helper.py | 2 ++ .../collaborations/arijit_physics.py | 2 +- scripts/experiments/colmap/colmap_loader.py | 2 +- .../mcs/otp_gen/otp_gen/physics_priors.py | 8 ++++---- scripts/experiments/tabletop/data_gen.py | 2 +- scripts/experiments/tabletop/inference.py | 2 +- test/test_genjax_model.py | 2 +- 21 files changed, 62 
insertions(+), 57 deletions(-) diff --git a/bayes3d/__init__.py b/bayes3d/__init__.py index 2d18c3df..4d72a187 100644 --- a/bayes3d/__init__.py +++ b/bayes3d/__init__.py @@ -1,6 +1,7 @@ """ .. include:: ./documentation.md """ + from .camera import * from .likelihood import * from .renderer import * @@ -8,13 +9,4 @@ from .transforms_3d import * from .viz import * -try: - import genjax - - from .genjax import * -except ImportError as e: - print("GenJAX not installed. Importing bayes3d without genjax dependencies.") - print(e) - - RENDERER = None diff --git a/bayes3d/_mkl/gaussian_renderer.py b/bayes3d/_mkl/gaussian_renderer.py index 5e56441d..ab9c0db5 100644 --- a/bayes3d/_mkl/gaussian_renderer.py +++ b/bayes3d/_mkl/gaussian_renderer.py @@ -34,12 +34,19 @@ normal_pdf = jax.scipy.stats.norm.pdf normal_logpdf = jax.scipy.stats.norm.logpdf inv = jnp.linalg.inv +from bayes3d._mkl.types import ( + Array, + CholeskyMatrix, + CovarianceMatrix, + Direction, + Float, + Matrix, + PrecisionMatrix, + Vector, +) key = jax.random.PRNGKey(0) -# %% ../../scripts/_mkl/notebooks/06a - Gaussian Renderer.ipynb 5 -from bayes3d._mkl.types import * - # %% ../../scripts/_mkl/notebooks/06a - Gaussian Renderer.ipynb 6 def ellipsoid_embedding(cov: CovarianceMatrix) -> Matrix: diff --git a/bayes3d/_mkl/simple_likelihood.py b/bayes3d/_mkl/simple_likelihood.py index 67a01940..79699bf6 100644 --- a/bayes3d/_mkl/simple_likelihood.py +++ b/bayes3d/_mkl/simple_likelihood.py @@ -1,6 +1,15 @@ # AUTOGENERATED! DO NOT EDIT! File to edit: ../../scripts/_mkl/notebooks/10 - Simple Likelihood.ipynb. # %% auto 0 +import genjax +import genjax._src.generative_functions.distributions.tensorflow_probability as gentfp +import jax +import jax.numpy as jnp +import tensorflow_probability.substrates.jax as tfp +from genjax._src.generative_functions.distributions.distribution import ExactDensity + +from bayes3d._mkl.utils import * + __all__ = [ "key", "tfd", @@ -20,17 +29,10 @@ ] # %% ../../scripts/_mkl/notebooks/10 - Simple Likelihood.ipynb 3 -import genjax -import jax -import jax.numpy as jnp - -from bayes3d._mkl.utils import * key = jax.random.PRNGKey(0) # %% ../../scripts/_mkl/notebooks/10 - Simple Likelihood.ipynb 4 -import genjax._src.generative_functions.distributions.tensorflow_probability as gentfp -import tensorflow_probability.substrates.jax as tfp tfd = tfp.distributions @@ -177,7 +179,6 @@ def sensor_model(Y, sig, out): # %% ../../scripts/_mkl/notebooks/10 - Simple Likelihood.ipynb 15 -from genjax._src.generative_functions.distributions.distribution import ExactDensity def wrap_into_dist(score_func): diff --git a/bayes3d/_mkl/trimesh_to_gaussians.py b/bayes3d/_mkl/trimesh_to_gaussians.py index 722e2284..6b616791 100644 --- a/bayes3d/_mkl/trimesh_to_gaussians.py +++ b/bayes3d/_mkl/trimesh_to_gaussians.py @@ -61,6 +61,15 @@ # AUTOGENERATED! DO NOT EDIT! File to edit: ../../scripts/_mkl/notebooks/05 - Trimesh to Gaussians.ipynb. # %% auto 0 +import jax +import jax.numpy as jnp +import jaxlib +import numpy as np +import trimesh +from jax import jit, vmap + +from bayes3d._mkl.utils import * + __all__ = [ "Array", "Shape", @@ -156,15 +165,6 @@ """ # %% ../../scripts/_mkl/notebooks/05 - Trimesh to Gaussians.ipynb 3 -import jax -import jax.numpy as jnp -import jaxlib -import numpy as np -import trimesh -from jax import jit, vmap - -from bayes3d._mkl.utils import * - Array = np.ndarray | jax.Array Shape = int | tuple[int, ...] 
FaceIndex = int diff --git a/bayes3d/colmap/colmap_loader.py b/bayes3d/colmap/colmap_loader.py index c9fcb6f1..9516ccb1 100644 --- a/bayes3d/colmap/colmap_loader.py +++ b/bayes3d/colmap/colmap_loader.py @@ -170,7 +170,7 @@ def read_points3D_binary(path_to_model_file): track_length = read_next_bytes(fid, num_bytes=8, format_char_sequence="Q")[ 0 ] - track_elems = read_next_bytes( + _track_elems = read_next_bytes( fid, num_bytes=8 * track_length, format_char_sequence="ii" * track_length, diff --git a/bayes3d/genjax/model.py b/bayes3d/genjax/model.py index 7257f1b8..6f667197 100644 --- a/bayes3d/genjax/model.py +++ b/bayes3d/genjax/model.py @@ -8,7 +8,12 @@ import bayes3d as b -from .genjax_distributions import * +from .genjax_distributions import ( + contact_params_uniform, + image_likelihood, + uniform_discrete, + uniform_pose, +) @genjax.static @@ -64,7 +69,7 @@ def model(array, possible_object_indices, pose_bounds, contact_bounds, all_box_d variance = genjax.uniform(0.00000000001, 10000.0) @ "variance" outlier_prob = genjax.uniform(-0.01, 10000.0) @ "outlier_prob" - image = image_likelihood(rendered, variance, outlier_prob) @ "image" + _image = image_likelihood(rendered, variance, outlier_prob) @ "image" return ( rendered, indices, diff --git a/bayes3d/neural/cosypose_baseline/cosypose_utils.py b/bayes3d/neural/cosypose_baseline/cosypose_utils.py index d9d1372a..b6c3918a 100644 --- a/bayes3d/neural/cosypose_baseline/cosypose_utils.py +++ b/bayes3d/neural/cosypose_baseline/cosypose_utils.py @@ -2,18 +2,17 @@ import signal import subprocess import sys +import time import numpy as np +import torch +import yaml cosypose_path = ( f"{os.path.dirname(os.path.abspath(__file__))}/cosypose_baseline/cosypose" ) sys.path.append(cosypose_path) # TODO cleaner import / add to path -import time - -import torch -import yaml torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False @@ -211,7 +210,7 @@ def cosypose_interface(rgb_imgs, camera_k): pred_poses = np.asarray(pred.poses.cpu()) pred_ids = [ - int(l[-3:]) - 1 for l in pred.infos.label + int(label[-3:]) - 1 for label in pred.infos.label ] # ex) 'obj_000014' for GT_IDX 13 pred_scores = [pred.infos.iloc[i].score for i in range(len(pred.infos))] diff --git a/bayes3d/renderer.py b/bayes3d/renderer.py index 49837703..b3e6fce1 100644 --- a/bayes3d/renderer.py +++ b/bayes3d/renderer.py @@ -382,10 +382,10 @@ def _load_vertices_lowering(ctx, vertices, triangles): # Extract the numpy type of the inputs vertices_aval, triangles_aval = ctx.avals_in - if np.dtype(vertices_aval.dtype) != np.float32: - raise NotImplementedError(f"Unsupported vertices dtype {np_dtype}") - if np.dtype(triangles_aval.dtype) != np.int32: - raise NotImplementedError(f"Unsupported triangles dtype {np_dtype}") + if (dt := np.dtype(vertices_aval.dtype)) != np.float32: + raise NotImplementedError(f"Unsupported vertices dtype {dt}") + if (dt := np.dtype(triangles_aval.dtype)) != np.int32: + raise NotImplementedError(f"Unsupported triangles dtype {dt}") opaque = dr._get_plugin(gl=True).build_load_vertices_descriptor( r.renderer_env.cpp_wrapper, vertices_aval.shape[0], triangles_aval.shape[0] diff --git a/bayes3d/scene_graph.py b/bayes3d/scene_graph.py index 8e096457..6b704e44 100644 --- a/bayes3d/scene_graph.py +++ b/bayes3d/scene_graph.py @@ -112,7 +112,6 @@ def add_edge_scene_graph( scene_graph, parent, child, face_parent, face_child, contact_params ): print(parent, child, face_parent, face_child) - N = scene_graph.get_poses().shape[0] sg_parents = 
jnp.array(scene_graph.parents) sg_parents = sg_parents.at[child].set(parent) sg_contact_params = jnp.array(scene_graph.contact_params) diff --git a/bayes3d/utils/occlusion.py b/bayes3d/utils/occlusion.py index 515eea58..aed4680a 100644 --- a/bayes3d/utils/occlusion.py +++ b/bayes3d/utils/occlusion.py @@ -21,7 +21,7 @@ def voxel_occupied_occluded_free(camera_pose, depth_image, grid, intrinsics, tol occupied = jnp.abs(real_depth_vals - projected_depth_vals) < tolerance occluded = real_depth_vals < projected_depth_vals occluded = occluded * (1.0 - occupied) - free = (1.0 - occluded) * (1.0 - occupied) + _free = (1.0 - occluded) * (1.0 - occupied) return 1.0 * occupied + 0.5 * occluded diff --git a/bayes3d/utils/pybullet_sim.py b/bayes3d/utils/pybullet_sim.py index cbfde663..01b579cd 100644 --- a/bayes3d/utils/pybullet_sim.py +++ b/bayes3d/utils/pybullet_sim.py @@ -86,7 +86,7 @@ def pybullet_render(scene): return image, depth -def create_box( +def create_box_from_pose( pose, scale=[1, 1, 1], restitution=1, @@ -500,7 +500,7 @@ def simulate(self, timesteps): return pyb def close(self): - if self.pyb_sim == None: + if self.pyb_sim is None: raise ValueError("No pybullet simulation to close") else: p.disconnect(self.pyb_sim.client) diff --git a/bayes3d/utils/r3d_loader.py b/bayes3d/utils/r3d_loader.py index 5873eb4f..94d6cb18 100644 --- a/bayes3d/utils/r3d_loader.py +++ b/bayes3d/utils/r3d_loader.py @@ -137,10 +137,9 @@ def load_r3d(r3d_path): color_paths = natsorted(glob.glob(os.path.join(datapath, "rgbd", "*.jpg"))) depth_paths = natsorted(glob.glob(os.path.join(datapath, "rgbd", "*.depth"))) - conf_paths = natsorted(glob.glob(os.path.join(datapath, "rgbd", "*.conf"))) colors = np.array([load_color(color_paths[i]) for i in range(len(color_paths))]) - depths = np.array([load_depth(depth_paths[i]) for i in range(len(color_paths))]) + depths = np.array([load_depth(depth_paths[i]) for i in range(len(depth_paths))]) depths[np.isnan(depths)] = 0.0 poses = get_poses(metadata) diff --git a/bayes3d/utils/ycb_loader.py b/bayes3d/utils/ycb_loader.py index 46480432..91fa2e63 100644 --- a/bayes3d/utils/ycb_loader.py +++ b/bayes3d/utils/ycb_loader.py @@ -80,7 +80,7 @@ def get_test_img(scene_id, img_id, ycb_dir): cam_depth_scale = image_cam_data["depth_scale"] # get {visible mask, ID, pose} for each object in the scene - anno = dict() + # anno = dict() # get GT object model ID+poses objects_gt_data = scene_imgs_gt_data[remove_zero_pad(img_id)] diff --git a/pyproject.toml b/pyproject.toml index ae978ce9..e08b51a3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ line-length = 88 indent-width = 4 [tool.ruff.lint] +exclude = ["bayes3d/_mkl/*.py"] extend-select = ["I"] select = ["E4", "E7", "E9", "F"] diff --git a/scripts/_mkl/notebooks/kubric/kubric_helper.py b/scripts/_mkl/notebooks/kubric/kubric_helper.py index 64d289d3..93a62592 100644 --- a/scripts/_mkl/notebooks/kubric/kubric_helper.py +++ b/scripts/_mkl/notebooks/kubric/kubric_helper.py @@ -1,6 +1,8 @@ import kubric as kb import numpy as np +rng = np.random.default_rng(2021) + def get_linear_camera_motion_start_end( movement_speed: float, diff --git a/scripts/experiments/collaborations/arijit_physics.py b/scripts/experiments/collaborations/arijit_physics.py index b11efe41..9ce0f5db 100644 --- a/scripts/experiments/collaborations/arijit_physics.py +++ b/scripts/experiments/collaborations/arijit_physics.py @@ -24,7 +24,7 @@ def body_fun(prev): def model(T): pose = b.uniform_pose(jnp.ones(3) * -1.0, jnp.ones(3) * 1.0) @ "init_pose" velocity = 
b.gaussian_vmf_pose(jnp.eye(4), 0.01, 10000.0) @ "init_velocity" - evolve = ( + _evolve = ( genjax.UnfoldCombinator.new(body_fun, 100)(50, (0, pose, velocity)) @ "dynamics" ) return 1.0 diff --git a/scripts/experiments/colmap/colmap_loader.py b/scripts/experiments/colmap/colmap_loader.py index ddef28cd..72541c8c 100644 --- a/scripts/experiments/colmap/colmap_loader.py +++ b/scripts/experiments/colmap/colmap_loader.py @@ -159,7 +159,7 @@ def read_points3D_binary(path_to_model_file): track_length = read_next_bytes(fid, num_bytes=8, format_char_sequence="Q")[ 0 ] - track_elems = read_next_bytes( + _track_elems = read_next_bytes( fid, num_bytes=8 * track_length, format_char_sequence="ii" * track_length, diff --git a/scripts/experiments/mcs/otp_gen/otp_gen/physics_priors.py b/scripts/experiments/mcs/otp_gen/otp_gen/physics_priors.py index e34216d4..024065ea 100644 --- a/scripts/experiments/mcs/otp_gen/otp_gen/physics_priors.py +++ b/scripts/experiments/mcs/otp_gen/otp_gen/physics_priors.py @@ -78,8 +78,7 @@ def physics_prior_v1(prev_pose, prev_prev_pose, bbox_dims, camera_pose, world2ca # I1 -> Integrate X-Y-Z forward to current time step # jprint("pred pos: {}", camera_pose[:3,:] @ jnp.concatenate([pred_pos, 1], axis = None)) - physics_estimated_pose = jnp.copy(prev_pose) # orientation is the same - physics_estimated_pose = physics_estimated_pose.at[:3, 3].set(pred_pos) + physics_estimated_pose = prev_pose.at[:3, 3].set(pred_pos) return physics_estimated_pose @@ -152,8 +151,9 @@ def physics_prior_v2( # I1 -> Integrate X-Y-Z forward to current time step # jprint("pred pos: {}", camera_pose[:3,:] @ jnp.concatenate([pred_pos, 1], axis = None)) - physics_estimated_pose = jnp.copy(prev_pose) # orientation is the same - physics_estimated_pose = physics_estimated_pose.at[:3, 3].set(pred_pos) + # NOTE @sritchie and @colin modified this line in response to a Ruff error; + # we'll flag this in code review but this could be buggy. 
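+    # (jnp arrays are immutable and .at[...].set(...) returns an updated copy,
+    # so the removed jnp.copy was redundant; the flagged risk is only whether
+    # prev_poses[T] matches the value the old prev_pose carried.)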
+ physics_estimated_pose = prev_poses[T].at[:3, 3].set(pred_pos) return physics_estimated_pose diff --git a/scripts/experiments/tabletop/data_gen.py b/scripts/experiments/tabletop/data_gen.py index 1dbaedf1..dbfe6012 100644 --- a/scripts/experiments/tabletop/data_gen.py +++ b/scripts/experiments/tabletop/data_gen.py @@ -3,12 +3,12 @@ import genjax import jax import jax.numpy as jnp +import joblib import bayes3d as b import bayes3d.genjax console = genjax.pretty(show_locals=False) -import joblib intrinsics = b.Intrinsics( height=100, width=100, fx=500.0, fy=500.0, cx=50.0, cy=50.0, near=0.01, far=20.0 diff --git a/scripts/experiments/tabletop/inference.py b/scripts/experiments/tabletop/inference.py index 72e0d35c..98f32e15 100644 --- a/scripts/experiments/tabletop/inference.py +++ b/scripts/experiments/tabletop/inference.py @@ -3,13 +3,13 @@ import genjax import jax import jax.numpy as jnp +import joblib from tqdm import tqdm import bayes3d as b import bayes3d.genjax console = genjax.pretty(show_locals=False) -import joblib intrinsics = b.Intrinsics( height=100, width=100, fx=500.0, fy=500.0, cx=50.0, cy=50.0, near=0.01, far=20.0 diff --git a/test/test_genjax_model.py b/test/test_genjax_model.py index 4e80938b..7bb49e48 100644 --- a/test/test_genjax_model.py +++ b/test/test_genjax_model.py @@ -72,7 +72,7 @@ def test_genjax_trace_contains_right_info(): ), ) - scores = enumerators.enumerate_choices_get_scores(trace, key, jnp.zeros((100, 3))) + _scores = enumerators.enumerate_choices_get_scores(trace, key, jnp.zeros((100, 3))) assert trace["parent_0"] == -1 assert (trace["camera_pose"] == jnp.eye(4)).all() From cd84f301e833e56641de57580a9df76d608b3146 Mon Sep 17 00:00:00 2001 From: Sam Ritchie Date: Thu, 25 Jan 2024 12:40:47 -0500 Subject: [PATCH 3/6] run formatter --- assets/sample_objs/bunny.obj | 2 +- assets/sample_objs/cube.obj | 2 +- assets/sample_objs/diamond.obj | 2 +- assets/sample_objs/icosahedron.obj | 2 +- assets/sample_objs/occulder.obj | 2 +- assets/sample_objs/pyramid.obj | 2 +- assets/sample_objs/sphere.obj | 2 +- bayes3d/colmap/__init__.py | 2 +- bayes3d/documentation.md | 2 +- bayes3d/genjax/__init__.py | 2 +- bayes3d/neural/cosypose_baseline/INSTALL.md | 2 - .../cosypose_baseline/cosypose_setup.sh | 2 +- bayes3d/neural/requirements_dino.txt | 2 +- bayes3d/prototyping.py | 2 +- bayes3d/rendering/nvdiffrast/__init__.py | 2 +- .../nvdiffrast/common/glutil_extlist.h | 2 +- .../nvdiffrast/common/rasterize_gl.cpp | 17 +++--- .../nvdiffrast/common/rasterize_gl.h | 2 +- bayes3d/rendering/nvdiffrast_jax/__init__.py | 2 +- .../nvdiffrast/common/interpolate.cu | 6 +-- .../nvdiffrast/common/rasterize.cu | 2 +- .../nvdiffrast/jax/jax_binding_ops.h | 2 +- .../nvdiffrast/jax/jax_bindings.cpp | 6 +-- .../nvdiffrast/jax/jax_interpolate.cpp | 52 +++++++++---------- .../nvdiffrast/jax/jax_interpolate.h | 2 +- .../nvdiffrast/jax/jax_rasterize_gl.cpp | 32 ++++++------ .../nvdiffrast/jax/jax_rasterize_gl.h | 2 +- bayes3d/viz/__init__.py | 2 +- docker/Dockerfile | 6 +-- docker/run.sh | 2 +- docs/installation.md | 2 +- mkdocs.yml | 2 +- requirements.txt | 2 +- scripts/_mkl/notebooks/.gitignore | 3 +- scripts/_mkl/notebooks/viz/draw_utils.js | 26 +++++----- scripts/_mkl/notebooks/viz/main.js | 38 +++++++------- scripts/experiments/icra/fork_knife/m1.obj | 2 +- scripts/experiments/icra/fork_knife/m2.obj | 2 +- .../mcs/cognitive-battery/.gitignore | 2 +- .../mcs/cognitive-battery/data/info.md | 2 +- setup.py | 10 ++-- 41 files changed, 124 insertions(+), 134 deletions(-) diff --git 
a/assets/sample_objs/bunny.obj b/assets/sample_objs/bunny.obj index 55e46e9f..af375e42 100644 --- a/assets/sample_objs/bunny.obj +++ b/assets/sample_objs/bunny.obj @@ -7469,4 +7469,4 @@ f 2411 1493 2503 f 1493 1487 2503 f 1487 1318 2503 f 1318 1320 2503 -f 1320 2443 2503 \ No newline at end of file +f 1320 2443 2503 diff --git a/assets/sample_objs/cube.obj b/assets/sample_objs/cube.obj index 72d47ec7..6c94abda 100644 --- a/assets/sample_objs/cube.obj +++ b/assets/sample_objs/cube.obj @@ -58,4 +58,4 @@ f 15//15 23//23 17//17 f 3//3 14//14 16//16 f 3//3 16//16 6//6 f 4//4 18//18 22//22 -f 4//4 22//22 11//11 \ No newline at end of file +f 4//4 22//22 11//11 diff --git a/assets/sample_objs/diamond.obj b/assets/sample_objs/diamond.obj index 999ce0d3..c4aa6fd3 100644 --- a/assets/sample_objs/diamond.obj +++ b/assets/sample_objs/diamond.obj @@ -13,4 +13,4 @@ f 6 5 4 f 6 4 3 f 6 3 2 f 6 2 1 -f 6 1 5 \ No newline at end of file +f 6 1 5 diff --git a/assets/sample_objs/icosahedron.obj b/assets/sample_objs/icosahedron.obj index c1f0acfb..7210ad9b 100644 --- a/assets/sample_objs/icosahedron.obj +++ b/assets/sample_objs/icosahedron.obj @@ -29,4 +29,4 @@ f 12 11 6 f 10 9 5 f 8 3 9 f 10 8 9 -f 2 8 10 \ No newline at end of file +f 2 8 10 diff --git a/assets/sample_objs/occulder.obj b/assets/sample_objs/occulder.obj index 0d2f2f11..76fee9e9 100644 --- a/assets/sample_objs/occulder.obj +++ b/assets/sample_objs/occulder.obj @@ -26,4 +26,4 @@ f 4//4 8//8 3//3 f 7//7 5//5 3//3 f 3//3 8//8 7//7 f 7//7 6//6 5//5 -f 8//8 6//6 7//7 \ No newline at end of file +f 8//8 6//6 7//7 diff --git a/assets/sample_objs/pyramid.obj b/assets/sample_objs/pyramid.obj index 6a96026e..1a7aed88 100644 --- a/assets/sample_objs/pyramid.obj +++ b/assets/sample_objs/pyramid.obj @@ -9,4 +9,4 @@ f 3 4 2 f 5 2 1 f 4 5 1 f 3 5 4 -f 5 3 2 \ No newline at end of file +f 5 3 2 diff --git a/assets/sample_objs/sphere.obj b/assets/sample_objs/sphere.obj index 0a320260..5d76b581 100644 --- a/assets/sample_objs/sphere.obj +++ b/assets/sample_objs/sphere.obj @@ -806,4 +806,4 @@ f 17/17/17 159/159/159 158/158/158 f 13/13/13 157/157/157 159/159/159 f 14/14/14 160/160/160 161/161/161 f 15/15/15 161/161/161 162/162/162 -f 13/13/13 162/162/162 160/160/160 \ No newline at end of file +f 13/13/13 162/162/162 160/160/160 diff --git a/bayes3d/colmap/__init__.py b/bayes3d/colmap/__init__.py index 73b4c81a..7c8b748f 100644 --- a/bayes3d/colmap/__init__.py +++ b/bayes3d/colmap/__init__.py @@ -1 +1 @@ -from .dataset_loader import * \ No newline at end of file +from .dataset_loader import * diff --git a/bayes3d/documentation.md b/bayes3d/documentation.md index 025ac85e..142cddbe 100644 --- a/bayes3d/documentation.md +++ b/bayes3d/documentation.md @@ -1 +1 @@ -`bayes3d` is a package for Bayesian 3D Inverse Graphics \ No newline at end of file +`bayes3d` is a package for Bayesian 3D Inverse Graphics diff --git a/bayes3d/genjax/__init__.py b/bayes3d/genjax/__init__.py index f3258d78..6a2e2b9c 100644 --- a/bayes3d/genjax/__init__.py +++ b/bayes3d/genjax/__init__.py @@ -1,2 +1,2 @@ from .genjax_distributions import * -from .model import * \ No newline at end of file +from .model import * diff --git a/bayes3d/neural/cosypose_baseline/INSTALL.md b/bayes3d/neural/cosypose_baseline/INSTALL.md index b9dd126d..cc6d2094 100644 --- a/bayes3d/neural/cosypose_baseline/INSTALL.md +++ b/bayes3d/neural/cosypose_baseline/INSTALL.md @@ -5,5 +5,3 @@ cd jax3dp3/cosypose_baseline bash cosypose_setup.sh ``` To test setup, run `test/test_cosypose.py`. 
- - diff --git a/bayes3d/neural/cosypose_baseline/cosypose_setup.sh b/bayes3d/neural/cosypose_baseline/cosypose_setup.sh index b763a823..ad5d65d2 100644 --- a/bayes3d/neural/cosypose_baseline/cosypose_setup.sh +++ b/bayes3d/neural/cosypose_baseline/cosypose_setup.sh @@ -23,7 +23,7 @@ conda activate cosypose git lfs pull python setup.py install mkdir local_data - + echo "Downloading data..." # it is required to download 'train_real', 'train_synt', but not 'train_all' python -m cosypose.scripts.download --bop_dataset=ycbv diff --git a/bayes3d/neural/requirements_dino.txt b/bayes3d/neural/requirements_dino.txt index cf89efdf..805aa251 100644 --- a/bayes3d/neural/requirements_dino.txt +++ b/bayes3d/neural/requirements_dino.txt @@ -2,4 +2,4 @@ omegaconf fvcore iopath xformers==0.0.18 -submitit \ No newline at end of file +submitit diff --git a/bayes3d/prototyping.py b/bayes3d/prototyping.py index 6b7b11e1..206f38a9 100644 --- a/bayes3d/prototyping.py +++ b/bayes3d/prototyping.py @@ -1 +1 @@ -from ._mkl.utils import * \ No newline at end of file +from ._mkl.utils import * diff --git a/bayes3d/rendering/nvdiffrast/__init__.py b/bayes3d/rendering/nvdiffrast/__init__.py index d65b6965..53d2ea76 100644 --- a/bayes3d/rendering/nvdiffrast/__init__.py +++ b/bayes3d/rendering/nvdiffrast/__init__.py @@ -6,4 +6,4 @@ # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. -__version__ = '0.3.0' +__version__ = "0.3.0" diff --git a/bayes3d/rendering/nvdiffrast/common/glutil_extlist.h b/bayes3d/rendering/nvdiffrast/common/glutil_extlist.h index 1be2862f..457dbe47 100644 --- a/bayes3d/rendering/nvdiffrast/common/glutil_extlist.h +++ b/bayes3d/rendering/nvdiffrast/common/glutil_extlist.h @@ -56,4 +56,4 @@ GLUTIL_EXT(void, glGenVertexArrays, GLsizei n, GLuint* arrays); GLUTIL_EXT(void, glMultiDrawElementsIndirect, GLenum mode, GLenum type, const void *indirect, GLsizei primcount, GLsizei stride); #endif -//------------------------------------------------------------------------ \ No newline at end of file +//------------------------------------------------------------------------ diff --git a/bayes3d/rendering/nvdiffrast/common/rasterize_gl.cpp b/bayes3d/rendering/nvdiffrast/common/rasterize_gl.cpp index 5912b5f5..41f4c2f0 100644 --- a/bayes3d/rendering/nvdiffrast/common/rasterize_gl.cpp +++ b/bayes3d/rendering/nvdiffrast/common/rasterize_gl.cpp @@ -419,14 +419,14 @@ void setup(RasterizeGLStateWrapper& stateWrapper, int height, int width, int num void jax_setup(cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len) { - const SetUpCustomCallDescriptor &d = + const SetUpCustomCallDescriptor &d = *UnpackDescriptor(opaque, opaque_len); RasterizeGLStateWrapper& stateWrapper = *d.gl_state_wrapper; _setup(stream, stateWrapper, d.height, d.width, d.num_layers); } -void _load_vertices_fwd(cudaStream_t stream, +void _load_vertices_fwd(cudaStream_t stream, RasterizeGLStateWrapper& stateWrapper, const float * pos, uint num_vertices, const int * tri, uint num_triangles) { // const at::cuda::OptionalCUDAGuard device_guard(device_of(pos)); @@ -515,7 +515,7 @@ void jax_load_vertices(cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len) { - const LoadVerticesCustomCallDescriptor &d = + const LoadVerticesCustomCallDescriptor &d = *UnpackDescriptor(opaque, opaque_len); RasterizeGLStateWrapper& stateWrapper = *d.gl_state_wrapper; // std::cerr << "load_vertices: " << d.num_vertices << "," << 
d.num_triangles << "\n"; @@ -592,9 +592,9 @@ void _rasterize_fwd_gl(cudaStream_t stream, RasterizeGLStateWrapper& stateWrappe poses_on_this_iter*16*sizeof(float), cudaMemcpyDeviceToDevice, stream)); NVDR_CHECK_CUDA_ERROR(cudaGraphicsUnmapResources(1, &s.cudaPoseTexture, stream)); glUniform1f(1, object_idx+1.0); - + NVDR_CHECK_GL_ERROR(glMultiDrawElementsIndirect(GL_TRIANGLES, GL_UNSIGNED_INT, &drawCmdBuffer[0], poses_on_this_iter, sizeof(GLDrawCmd))); - } + } @@ -632,7 +632,7 @@ void jax_rasterize_fwd_gl(cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len) { - const RasterizeCustomCallDescriptor &d = + const RasterizeCustomCallDescriptor &d = *UnpackDescriptor(opaque, opaque_len); RasterizeGLStateWrapper& stateWrapper = *d.gl_state_wrapper; @@ -718,8 +718,3 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { } //------------------------------------------------------------------------ - - - - - diff --git a/bayes3d/rendering/nvdiffrast/common/rasterize_gl.h b/bayes3d/rendering/nvdiffrast/common/rasterize_gl.h index a992d908..b4c84ca4 100644 --- a/bayes3d/rendering/nvdiffrast/common/rasterize_gl.h +++ b/bayes3d/rendering/nvdiffrast/common/rasterize_gl.h @@ -62,7 +62,7 @@ class RasterizeGLStateWrapper; struct SetUpCustomCallDescriptor { RasterizeGLStateWrapper* gl_state_wrapper; - + int height; int width; int num_layers; diff --git a/bayes3d/rendering/nvdiffrast_jax/__init__.py b/bayes3d/rendering/nvdiffrast_jax/__init__.py index d65b6965..53d2ea76 100644 --- a/bayes3d/rendering/nvdiffrast_jax/__init__.py +++ b/bayes3d/rendering/nvdiffrast_jax/__init__.py @@ -6,4 +6,4 @@ # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. -__version__ = '0.3.0' +__version__ = "0.3.0" diff --git a/bayes3d/rendering/nvdiffrast_jax/nvdiffrast/common/interpolate.cu b/bayes3d/rendering/nvdiffrast_jax/nvdiffrast/common/interpolate.cu index 84f5fb76..ed3cf2fe 100755 --- a/bayes3d/rendering/nvdiffrast_jax/nvdiffrast/common/interpolate.cu +++ b/bayes3d/rendering/nvdiffrast_jax/nvdiffrast/common/interpolate.cu @@ -94,9 +94,9 @@ static __forceinline__ __device__ void InterpolateFwdKernelTemplate(const Interp float dvdx = db.z; float dvdy = db.w; - // Calculate the pixel differentials of chosen attributes. + // Calculate the pixel differentials of chosen attributes. for (int i=0; i < p.numDiffAttr; i++) - { + { // Input attribute index. int j = p.diff_attrs_all ? i : p.diffAttrs[i]; if (j < 0) @@ -132,7 +132,7 @@ template static __forceinline__ __device__ void InterpolateGradKernelTemplate(const InterpolateKernelParams p) { // Temporary space for coalesced atomics. - CA_DECLARE_TEMP(IP_GRAD_MAX_KERNEL_BLOCK_WIDTH * IP_GRAD_MAX_KERNEL_BLOCK_HEIGHT); + CA_DECLARE_TEMP(IP_GRAD_MAX_KERNEL_BLOCK_WIDTH * IP_GRAD_MAX_KERNEL_BLOCK_HEIGHT); // Calculate pixel position. 
int px = blockIdx.x * blockDim.x + threadIdx.x; diff --git a/bayes3d/rendering/nvdiffrast_jax/nvdiffrast/common/rasterize.cu b/bayes3d/rendering/nvdiffrast_jax/nvdiffrast/common/rasterize.cu index 9c9578e5..cf6ca209 100755 --- a/bayes3d/rendering/nvdiffrast_jax/nvdiffrast/common/rasterize.cu +++ b/bayes3d/rendering/nvdiffrast_jax/nvdiffrast/common/rasterize.cu @@ -240,7 +240,7 @@ static __forceinline__ __device__ void RasterizeGradKernelTemplate(const Rasteri float a1p0 = fx * p2.y - fy * p2.x; float a1p2 = fy * p0.x - fx * p0.y; - float wdudX = 2.f * b0 * datdX - da0dX; + float wdudX = 2.f * b0 * datdX - da0dX; float wdudY = 2.f * b0 * datdY - da0dY; float wdvdX = 2.f * b1 * datdX - da1dX; float wdvdY = 2.f * b1 * datdY - da1dY; diff --git a/bayes3d/rendering/nvdiffrast_jax/nvdiffrast/jax/jax_binding_ops.h b/bayes3d/rendering/nvdiffrast_jax/nvdiffrast/jax/jax_binding_ops.h index 42d730f0..daee555a 100644 --- a/bayes3d/rendering/nvdiffrast_jax/nvdiffrast/jax/jax_binding_ops.h +++ b/bayes3d/rendering/nvdiffrast_jax/nvdiffrast/jax/jax_binding_ops.h @@ -34,4 +34,4 @@ const T* UnpackDescriptor(const char* opaque, std::size_t opaque_len) { throw std::runtime_error("Invalid opaque object size"); } return bit_cast(opaque); -} \ No newline at end of file +} diff --git a/bayes3d/rendering/nvdiffrast_jax/nvdiffrast/jax/jax_bindings.cpp b/bayes3d/rendering/nvdiffrast_jax/nvdiffrast/jax/jax_bindings.cpp index 01ee78ee..18b88051 100755 --- a/bayes3d/rendering/nvdiffrast_jax/nvdiffrast/jax/jax_bindings.cpp +++ b/bayes3d/rendering/nvdiffrast_jax/nvdiffrast/jax/jax_bindings.cpp @@ -61,13 +61,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("build_diff_interpolate_descriptor", [](std::vector attr_shape, std::vector rast_shape, - std::vector tri_shape, + std::vector tri_shape, int num_diff_attrs ) { DiffInterpolateCustomCallDescriptor d; - d.num_images = attr_shape[0], + d.num_images = attr_shape[0], d.num_vertices = attr_shape[1], - d.num_attributes = attr_shape[2], + d.num_attributes = attr_shape[2], d.rast_height = rast_shape[1], d.rast_width = rast_shape[2], d.rast_depth = rast_shape[0], diff --git a/bayes3d/rendering/nvdiffrast_jax/nvdiffrast/jax/jax_interpolate.cpp b/bayes3d/rendering/nvdiffrast_jax/nvdiffrast/jax/jax_interpolate.cpp index fd9e1a3e..3941ebfc 100644 --- a/bayes3d/rendering/nvdiffrast_jax/nvdiffrast/jax/jax_interpolate.cpp +++ b/bayes3d/rendering/nvdiffrast_jax/nvdiffrast/jax/jax_interpolate.cpp @@ -42,12 +42,12 @@ static void set_diff_attrs(InterpolateKernelParams& p, bool diff_attrs_all, std: //------------------------------------------------------------------------ // Forward op. 
-void _interpolate_fwd_da(cudaStream_t stream, - const float* attr, const float* rast, const int* tri, - const float* rast_db, bool diff_attrs_all, +void _interpolate_fwd_da(cudaStream_t stream, + const float* attr, const float* rast, const int* tri, + const float* rast_db, bool diff_attrs_all, std::vector& diff_attrs_vec, - std::vector attr_shape, std::vector rast_shape, - std::vector tri_shape, + std::vector attr_shape, std::vector rast_shape, + std::vector tri_shape, float* out, float* out_da) { @@ -99,7 +99,7 @@ void jax_interpolate_fwd(cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len) { - const DiffInterpolateCustomCallDescriptor &d = + const DiffInterpolateCustomCallDescriptor &d = *UnpackDescriptor(opaque, opaque_len); const float *attr = reinterpret_cast (buffers[0]); @@ -109,8 +109,8 @@ void jax_interpolate_fwd(cudaStream_t stream, const int *diff_attrs = reinterpret_cast (buffers[4]); float *out = reinterpret_cast (buffers[5]); - float *out_da = reinterpret_cast (buffers[6]); - + float *out_da = reinterpret_cast (buffers[6]); + auto opts = torch::dtype(torch::kFloat32).device(torch::kCUDA); std::vector attr_shape; @@ -139,15 +139,15 @@ void jax_interpolate_fwd(cudaStream_t stream, cudaStreamSynchronize(stream); _interpolate_fwd_da(stream, - attr, - rast_out, - tri, - rast_db, + attr, + rast_out, + tri, + rast_db, diff_attrs_all, diff_attrs_vec, - attr_shape, - rast_shape, - tri_shape, + attr_shape, + rast_shape, + tri_shape, out, out_da ); @@ -157,10 +157,10 @@ void jax_interpolate_fwd(cudaStream_t stream, //------------------------------------------------------------------------ // Gradient op. -void _interpolate_grad_da(cudaStream_t stream, +void _interpolate_grad_da(cudaStream_t stream, const float* attr, const float* rast, const int* tri, const float* dy, const float* rast_db, const float* dda, bool diff_attrs_all, std::vector& diff_attrs_vec, - std::vector attr_shape, std::vector rast_shape, std::vector tri_shape, + std::vector attr_shape, std::vector rast_shape, std::vector tri_shape, float* g_attr, float* g_rast, float* g_rast_db) { InterpolateKernelParams p = {}; // Initialize all fields to zero. 
@@ -224,7 +224,7 @@ void jax_interpolate_bwd(cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len) { - const DiffInterpolateCustomCallDescriptor &d = + const DiffInterpolateCustomCallDescriptor &d = *UnpackDescriptor(opaque, opaque_len); const float *attr = reinterpret_cast (buffers[0]); @@ -236,7 +236,7 @@ void jax_interpolate_bwd(cudaStream_t stream, const int *diff_attrs = reinterpret_cast (buffers[6]); float *g_attr = reinterpret_cast (buffers[7]); - float *g_rast = reinterpret_cast (buffers[8]); + float *g_rast = reinterpret_cast (buffers[8]); float* g_rast_db = reinterpret_cast (buffers[9]); cudaMemset(g_attr, 0, d.num_images*d.num_vertices*d.num_attributes*sizeof(float)); @@ -269,18 +269,18 @@ void jax_interpolate_bwd(cudaStream_t stream, cudaStreamSynchronize(stream); _interpolate_grad_da(stream, attr, - rast_out, - tri, + rast_out, + tri, dy, rast_db, // dda, // - diff_attrs_all, // + diff_attrs_all, // diff_attrs_vec, // - attr_shape, - rast_shape, - tri_shape, + attr_shape, + rast_shape, + tri_shape, g_attr, - g_rast, + g_rast, g_rast_db // ); cudaStreamSynchronize(stream); diff --git a/bayes3d/rendering/nvdiffrast_jax/nvdiffrast/jax/jax_interpolate.h b/bayes3d/rendering/nvdiffrast_jax/nvdiffrast/jax/jax_interpolate.h index 23a8601b..f199fef6 100644 --- a/bayes3d/rendering/nvdiffrast_jax/nvdiffrast/jax/jax_interpolate.h +++ b/bayes3d/rendering/nvdiffrast_jax/nvdiffrast/jax/jax_interpolate.h @@ -23,4 +23,4 @@ struct DiffInterpolateCustomCallDescriptor { int num_diff_attributes; // diff_attr }; -#endif // !(defined(NVDR_TORCH) && defined(__CUDACC__)) \ No newline at end of file +#endif // !(defined(NVDR_TORCH) && defined(__CUDACC__)) diff --git a/bayes3d/rendering/nvdiffrast_jax/nvdiffrast/jax/jax_rasterize_gl.cpp b/bayes3d/rendering/nvdiffrast_jax/nvdiffrast/jax/jax_rasterize_gl.cpp index 976452a4..6b6cd0ad 100644 --- a/bayes3d/rendering/nvdiffrast_jax/nvdiffrast/jax/jax_rasterize_gl.cpp +++ b/bayes3d/rendering/nvdiffrast_jax/nvdiffrast/jax/jax_rasterize_gl.cpp @@ -51,13 +51,13 @@ void RasterizeGLStateWrapper::releaseContext(void) // Forward op (OpenGL). 
// void _rasterize_fwd_gl(cudaStream_t stream, RasterizeGLStateWrapper& stateWrapper, torch::Tensor pos, torch::Tensor tri, std::tuple resolution, torch::Tensor ranges, int peeling_idx) -void _rasterize_fwd_gl(cudaStream_t stream, RasterizeGLStateWrapper& stateWrapper, - const float* pos, const int* tri, +void _rasterize_fwd_gl(cudaStream_t stream, RasterizeGLStateWrapper& stateWrapper, + const float* pos, const int* tri, std::vector dims, - std::vector resolution, + std::vector resolution, float* out, float* out_db) -{ +{ // const at::cuda::OptionalCUDAGuard device_guard(at::device_of(pos)); RasterizeGLState& s = *stateWrapper.pState; @@ -116,7 +116,7 @@ void jax_rasterize_fwd_gl(cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len) { - const DiffRasterizeCustomCallDescriptor &d = + const DiffRasterizeCustomCallDescriptor &d = *UnpackDescriptor(opaque, opaque_len); RasterizeGLStateWrapper& stateWrapper = *d.gl_state_wrapper; @@ -126,7 +126,7 @@ void jax_rasterize_fwd_gl(cudaStream_t stream, float *out = reinterpret_cast (buffers[3]); float *out_db = reinterpret_cast (buffers[4]); - + auto opts = torch::dtype(torch::kFloat32).device(torch::kCUDA); std::vector resolution; @@ -145,7 +145,7 @@ void jax_rasterize_fwd_gl(cudaStream_t stream, pos, tri, pos_dim, - resolution, + resolution, out, out_db ); @@ -163,10 +163,10 @@ void RasterizeGradKernelDb(const RasterizeGradParams p); //------------------------------------------------------------------------ void _rasterize_grad_db(cudaStream_t stream, - const float* pos, const int* tri, const float* rast_out, - const float* dy, const float* ddb, - std::vector pos_shape, - std::vector tri_shape, + const float* pos, const int* tri, const float* rast_out, + const float* dy, const float* ddb, + std::vector pos_shape, + std::vector tri_shape, std::vector rast_out_shape, float* grad) { @@ -220,7 +220,7 @@ void jax_rasterize_bwd(cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len) { - const DiffRasterizeBwdCustomCallDescriptor &d = + const DiffRasterizeBwdCustomCallDescriptor &d = *UnpackDescriptor(opaque, opaque_len); const float *pos = reinterpret_cast (buffers[0]); @@ -231,7 +231,7 @@ void jax_rasterize_bwd(cudaStream_t stream, float *grad = reinterpret_cast (buffers[5]); // output cudaMemset(grad, 0, d.num_images*d.num_vertices*4*sizeof(float)); - + auto opts = torch::dtype(torch::kFloat32).device(torch::kCUDA); std::vector pos_shape; @@ -252,11 +252,11 @@ void jax_rasterize_bwd(cudaStream_t stream, _rasterize_grad_db(stream, pos, tri, - rast_out, - dy, + rast_out, + dy, ddb, pos_shape, - tri_shape, + tri_shape, rast_out_shape, grad ); diff --git a/bayes3d/rendering/nvdiffrast_jax/nvdiffrast/jax/jax_rasterize_gl.h b/bayes3d/rendering/nvdiffrast_jax/nvdiffrast/jax/jax_rasterize_gl.h index 2f2abbbe..7193c003 100644 --- a/bayes3d/rendering/nvdiffrast_jax/nvdiffrast/jax/jax_rasterize_gl.h +++ b/bayes3d/rendering/nvdiffrast_jax/nvdiffrast/jax/jax_rasterize_gl.h @@ -76,4 +76,4 @@ void rasterizeReleaseBuffers(NVDR_CTX_ARGS, RasterizeGLState& s); //------------------------------------------------------------------------ -#endif \ No newline at end of file +#endif diff --git a/bayes3d/viz/__init__.py b/bayes3d/viz/__init__.py index 73ef1dfa..10113e08 100644 --- a/bayes3d/viz/__init__.py +++ b/bayes3d/viz/__init__.py @@ -1,2 +1,2 @@ from .meshcatviz import * -from .viz import * \ No newline at end of file +from .viz import * diff --git a/docker/Dockerfile b/docker/Dockerfile index b6590753..edbc26d3 100644 
--- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -1,5 +1,5 @@ # Base container: CUDA 12.21, cuDNN 8.9.4, Python 3.10, PyTorch 2.1.0 -ARG BASE_IMG=nvcr.io/nvidia/pytorch:23.08-py3 +ARG BASE_IMG=nvcr.io/nvidia/pytorch:23.08-py3 FROM ${BASE_IMG} WORKDIR /workspace @@ -8,7 +8,7 @@ WORKDIR /workspace COPY ./docker/requirements_docker.txt /workspace/requirements.txt COPY ./genjax /workspace/genjax RUN pip install -r /workspace/requirements.txt -RUN pip install -r /workspace/genjax/requirements.txt +RUN pip install -r /workspace/genjax/requirements.txt RUN pip install /workspace/genjax # Install JAX (0.4.16) and OpenGL dependencies @@ -17,7 +17,7 @@ RUN apt-get update RUN pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html RUN apt-get install -y mesa-common-dev libegl1-mesa-dev libglfw3-dev libgl1-mesa-dev libglu1-mesa-dev -# Cleanup and prepare env variables for graphics +# Cleanup and prepare env variables for graphics RUN rm -rf /workspace/requirements.txt RUN rm -rf /workspace/genjax ENV NVIDIA_VISIBLE_DEVICES all diff --git a/docker/run.sh b/docker/run.sh index d6319042..7d63def6 100644 --- a/docker/run.sh +++ b/docker/run.sh @@ -6,4 +6,4 @@ SCRIPT=$(realpath "$0") DOCKERPATH=$(dirname "$SCRIPT") BAYES3DPATH=$(dirname "$DOCKERPATH") echo "Mounting $BAYES3DPATH into /workspace/bayes3d" -docker run --runtime=nvidia -it -p 8888:8888 --gpus all --rm --ipc=host -v $(dirname "$BAYES3DPATH"):/workspace/ bayes3d:latest # mount the directory that contains Bayes3D into container \ No newline at end of file +docker run --runtime=nvidia -it -p 8888:8888 --gpus all --rm --ipc=host -v $(dirname "$BAYES3DPATH"):/workspace/ bayes3d:latest # mount the directory that contains Bayes3D into container diff --git a/docs/installation.md b/docs/installation.md index 83bc6970..7a8227cc 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -23,4 +23,4 @@ pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-re ### Install Torch ``` pip install torch torchvision torchaudio --upgrade --index-url https://download.pytorch.org/whl/cu118 -``` \ No newline at end of file +``` diff --git a/mkdocs.yml b/mkdocs.yml index e3f92a5a..2c8c8569 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -52,4 +52,4 @@ nav: markdown_extensions: - attr_list - - md_in_html \ No newline at end of file + - md_in_html diff --git a/requirements.txt b/requirements.txt index b9fd4ee5..dbc89256 100644 --- a/requirements.txt +++ b/requirements.txt @@ -36,4 +36,4 @@ mkdocs-material natsort pyliblzfse pypng -tyro \ No newline at end of file +tyro diff --git a/scripts/_mkl/notebooks/.gitignore b/scripts/_mkl/notebooks/.gitignore index 7d2c0b55..acb35640 100644 --- a/scripts/_mkl/notebooks/.gitignore +++ b/scripts/_mkl/notebooks/.gitignore @@ -23,7 +23,7 @@ local/ local # **/data/ -# tensorboard runs folder and +# tensorboard runs folder and # pytorch lightning stuff **/runs/ **/lightning_logs/ @@ -183,4 +183,3 @@ checklink/cookies.txt # Quarto .quarto - diff --git a/scripts/_mkl/notebooks/viz/draw_utils.js b/scripts/_mkl/notebooks/viz/draw_utils.js index 52f14669..2d54b42e 100644 --- a/scripts/_mkl/notebooks/viz/draw_utils.js +++ b/scripts/_mkl/notebooks/viz/draw_utils.js @@ -1,7 +1,7 @@ const arrays = require('./lib/arrays.js'); const THREE = require('three'); -import GUI from 'lil-gui'; -import Stats from 'stats-js'; +import GUI from 'lil-gui'; +import Stats from 'stats-js'; import * as BufferGeometryUtils from "three/examples/jsm/utils/BufferGeometryUtils"; @@ -56,8 
+56,8 @@ export function combine_meshes(meshes) { export function create_gaussian_meshes(transforms4x4, colorsRGBA) { const t = 0 - const N = transforms4x4.shape[1]; - + const N = transforms4x4.shape[1]; + const meshes = [] for (let i = 0; i < N; i++) { const color = new Float32Array(arrays.strided_slice(colorsRGBA, t, i, arrays.ALL).values) @@ -70,7 +70,7 @@ export function create_gaussian_meshes(transforms4x4, colorsRGBA) { const transform4x4 = new Float32Array(arrays.strided_slice(transforms4x4, t, i, arrays.ALL, arrays.ALL).values) const matrix = new THREE.Matrix4(); matrix.fromArray(transform4x4) - + const geometry = new THREE.SphereGeometry(1.0); geometry.applyMatrix4(matrix); @@ -97,7 +97,7 @@ export function update_gaussian_meshes(meshes, colorsRGBA, transforms4x4) { meshes[i].material.needsUpdate = true; // Update transform - meshes[i].matrixAutoUpdate = false; + meshes[i].matrixAutoUpdate = false; const transform = new Float32Array(arrays.strided_slice(transforms4x4, t, i, arrays.ALL, arrays.ALL).values) const matrix = new THREE.Matrix4(); matrix.fromArray(transform) @@ -135,7 +135,7 @@ export function create_instanced_sphere_mesh(centers, colorsRGBA, scales) { const rotation = new THREE.Quaternion(0,0,0,1); const scale = new THREE.Vector3(1., 1., 1.); const colors = new Float32Array(instanceCount * 3); // RGB for each instance - + for (let i = 0; i < instanceCount; i++) { const center = new Float32Array(arrays.strided_slice(centers, i, arrays.ALL).values) @@ -149,8 +149,8 @@ export function create_instanced_sphere_mesh(centers, colorsRGBA, scales) { } // instancedGeometry.setAttribute('color', colorAttribute); - - + + return instancedMesh; } // END OF create_instanced_sphere_mesh @@ -177,15 +177,15 @@ export function update_instanced_sphere_mesh(instanced_mesh, centers, colorsRGBA instanced_mesh.geometry.setAttribute('color', colorAttribute); instanced_mesh.instanceMatrix.needsUpdate = true; colorAttribute.needsUpdate = true; - - + + return instanced_mesh; } // END OF update_instanced_sphere_mesh export function create_sphere_meshes(positionsNx3, colorsRGBA, scales) { - const N = positionsNx3.shape[0]; + const N = positionsNx3.shape[0]; const geometry = new THREE.SphereGeometry(1, 5, 5); // Radius set to 1, which will be scaled let meshes = []; for (let i = 0; i < N; i++) { @@ -194,7 +194,7 @@ export function create_sphere_meshes(positionsNx3, colorsRGBA, scales) { const material = new THREE.MeshBasicMaterial({ color: new THREE.Color(rgba[0], rgba[1], rgba[2]), // RGB transparent: true, - opacity: rgba[3] + opacity: rgba[3] }); const mesh = new THREE.Mesh(geometry, material); diff --git a/scripts/_mkl/notebooks/viz/main.js b/scripts/_mkl/notebooks/viz/main.js index ab02f09d..546635b4 100644 --- a/scripts/_mkl/notebooks/viz/main.js +++ b/scripts/_mkl/notebooks/viz/main.js @@ -2,8 +2,8 @@ const messaging = require('./lib/messaging.js'); const arrays = require('./lib/arrays.js'); const viz_pb = require('./lib/viz_pb.js'); const THREE = require('three'); -import GUI from 'lil-gui'; -import Stats from 'stats-js'; +import GUI from 'lil-gui'; +import Stats from 'stats-js'; import * as BufferGeometryUtils from "three/examples/jsm/utils/BufferGeometryUtils"; import {MapControls} from 'three/examples/jsm/controls/MapControls'; import { MeshDepthMaterial } from 'three'; @@ -164,7 +164,7 @@ function initialize_interface(shared_data) { fixed_scene_objs: fixed_scene_objs }; return result; -} +} // END OF initialize_interface // <<<<<<<<<<<<<<<<<<<<<<<<<<< @@ -237,7 +237,7 @@ const shared_data 
= { mixers: [], actions: [] }; - + const ui = initialize_interface(shared_data); @@ -248,9 +248,9 @@ mc.add_handler(handle_message); /* * * * * * * * * * * * * * * * * * * - * - * - * + * + * + * * * * * * * * * * * * * * * * * * * */ function linear_index(multiIndex, shape) { @@ -280,7 +280,7 @@ function handle_dynamic_gaussians(transformsTxNx4x4, colors, ui, shared_data) { shared_data.num_frames = T shared_data.animation_frame_controller._max = shared_data.num_frames - 1 - + const meshes = [] for (let i = 0; i < N; i++) { @@ -294,7 +294,7 @@ function handle_dynamic_gaussians(transformsTxNx4x4, colors, ui, shared_data) { const transform4x4 = new Float32Array(arrays.strided_slice(transforms4x4, i, arrays.ALL, arrays.ALL).values) const matrix = new THREE.Matrix4(); matrix.fromArray(transform4x4) - + const geometry = new THREE.SphereGeometry(1.0); geometry.applyMatrix4(matrix); @@ -321,7 +321,7 @@ function handle_pytree_message(pytree, ui, shared_data) { switch(type) { - + case "setup": console.log("case 'setup'") @@ -331,8 +331,8 @@ function handle_pytree_message(pytree, ui, shared_data) { // ui.scene.background = new THREE.Color( 0xd3d3d3 ); ui.scene.background = new THREE.Color( 0xffffff ); const dir_light = new THREE.DirectionalLight( 0xffffff, 2 ); - dir_light.position.set( 0, .1, 0); - dir_light.castShadow = true; + dir_light.position.set( 0, .1, 0); + dir_light.castShadow = true; dir_light.shadow.camera.visible = true; const ambientLight = new THREE.AmbientLight(0x404040, 20); // soft white light ui.scene.add(dir_light); @@ -358,7 +358,7 @@ function handle_pytree_message(pytree, ui, shared_data) { meshes.forEach(mesh => ui.scene.add(mesh)); shared_data.num_frames = data.centers.shape[0] shared_data.animation_frame_controller._max = shared_data.num_frames - 1 - + break; @@ -395,7 +395,7 @@ function handle_pytree_message(pytree, ui, shared_data) { break; - + case "animated gaussians": console.log("animated gaussians", data.transforms.shape) var T = data.transforms.shape[0] @@ -408,8 +408,8 @@ function handle_pytree_message(pytree, ui, shared_data) { var colors = arrays.strided_slice(data.colors, t, arrays.ALL, arrays.ALL); var meshes = draw_utils.create_gaussian_meshes(transforms, colors); meshes.forEach(mesh => ui.scene.add(mesh)); - - + + shared_data.animation_update_handler = t => { const transforms = arrays.strided_slice(data.transforms, t, arrays.ALL, arrays.ALL, arrays.ALL) const colors = arrays.strided_slice(data.colors, t, arrays.ALL, arrays.ALL) @@ -418,11 +418,11 @@ function handle_pytree_message(pytree, ui, shared_data) { break; - + case "gaussians": console.log("case 'gaussians'", data.transforms.shape, data.colors.shape) - + // var transforms = arrays.strided_slice(data.transforms, 0, arrays.ALL, arrays.ALL, arrays.ALL); // var colors = arrays.strided_slice(data.colors, 0, arrays.ALL, arrays.ALL); @@ -440,5 +440,3 @@ function handle_pytree_message(pytree, ui, shared_data) { } // END OF handle_payload // <<<<<<<<<<<<<<<<<<<<< - - diff --git a/scripts/experiments/icra/fork_knife/m1.obj b/scripts/experiments/icra/fork_knife/m1.obj index 141882e1..08295868 100644 --- a/scripts/experiments/icra/fork_knife/m1.obj +++ b/scripts/experiments/icra/fork_knife/m1.obj @@ -54,4 +54,4 @@ f 12//12 16//16 11//11 f 15//15 13//13 11//11 f 11//11 16//16 15//15 f 15//15 14//14 13//13 -f 16//16 14//14 15//15 \ No newline at end of file +f 16//16 14//14 15//15 diff --git a/scripts/experiments/icra/fork_knife/m2.obj b/scripts/experiments/icra/fork_knife/m2.obj index 53d091fa..25083378 100644 --- 
a/scripts/experiments/icra/fork_knife/m2.obj +++ b/scripts/experiments/icra/fork_knife/m2.obj @@ -82,4 +82,4 @@ f 20//20 24//24 19//19 f 23//23 21//21 19//19 f 19//19 24//24 23//23 f 23//23 22//22 21//21 -f 24//24 22//22 23//23 \ No newline at end of file +f 24//24 22//22 23//23 diff --git a/scripts/experiments/mcs/cognitive-battery/.gitignore b/scripts/experiments/mcs/cognitive-battery/.gitignore index 73eceb38..8e26b40c 100644 --- a/scripts/experiments/mcs/cognitive-battery/.gitignore +++ b/scripts/experiments/mcs/cognitive-battery/.gitignore @@ -1 +1 @@ -_*.ipynb \ No newline at end of file +_*.ipynb diff --git a/scripts/experiments/mcs/cognitive-battery/data/info.md b/scripts/experiments/mcs/cognitive-battery/data/info.md index d9121581..df415842 100644 --- a/scripts/experiments/mcs/cognitive-battery/data/info.md +++ b/scripts/experiments/mcs/cognitive-battery/data/info.md @@ -3,4 +3,4 @@ * `experiment_video.mp4`: the generated video, with a framerate `30 fps`. * `experiment_stats.yaml`: contains information about the experiment and camera intrinsics (e.g. width, height, fov). * `frames/`: contains each RGB frame JPEG images. - * `depths/`: contains each depth frame as loadable numpy arrays. \ No newline at end of file + * `depths/`: contains each depth frame as loadable numpy arrays. diff --git a/setup.py b/setup.py index 54a1bdef..e34bb99c 100644 --- a/setup.py +++ b/setup.py @@ -1,10 +1,10 @@ import setuptools -NAME = 'bayes3d' -VERSION = '0.0.1' -if __name__ == '__main__': +NAME = "bayes3d" +VERSION = "0.0.1" +if __name__ == "__main__": setuptools.setup( name=NAME, version=VERSION, - packages=setuptools.find_namespace_packages(include=['bayes3d.*']), - ) \ No newline at end of file + packages=setuptools.find_namespace_packages(include=["bayes3d.*"]), + ) From bd7093aff83825365aaae741dd21cdbfe14828b4 Mon Sep 17 00:00:00 2001 From: Sam Ritchie Date: Thu, 25 Jan 2024 12:41:08 -0500 Subject: [PATCH 4/6] begin to tidy dependencies --- .gitignore | 1 - README.md | 2 +- requirements-dev.txt | 8 ++++++++ requirements.txt | 26 -------------------------- 4 files changed, 9 insertions(+), 28 deletions(-) create mode 100644 requirements-dev.txt diff --git a/.gitignore b/.gitignore index 76a960aa..08aeb8db 100644 --- a/.gitignore +++ b/.gitignore @@ -27,7 +27,6 @@ assets/* *.ply *.zip *.npz -*.txt *.pdf *.pkl .DS_Store diff --git a/README.md b/README.md index f79c92e1..0c5b59e1 100644 --- a/README.md +++ b/README.md @@ -32,7 +32,7 @@ pip install git+https://github.com/probcomp/genjax.git@v0.1.0 Install JAX and Torch: ``` pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html -pip install torch torchvision torchaudio --upgrade --index-url https://download.pytorch.org/whl/cu118 +pip install torch torchvision --upgrade --index-url https://download.pytorch.org/whl/cu118 ``` Download model and data assets: diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 00000000..527ba686 --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,8 @@ +ipython +jupyter +jupyterlab +joblib +flax +mkdocs +mkdocs-material +pyliblzfse diff --git a/requirements.txt b/requirements.txt index dbc89256..7035caa6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,39 +1,13 @@ -scikit-learn -wheel -ipython -jupyter numpy matplotlib -pillow -opencv-python open3d graphviz distinctipy trimesh -ninja pyransac3d meshcat -h5py -gdown -pytest -zmq plyfile -jupyterlab imageio timm -joblib -pdoc3 -addict tensorflow-probability -flax -omegaconf
-fvcore -iopath -submitit -wget -mkdocs -mkdocs-material natsort -pyliblzfse -pypng -tyro From b977ab760a1c626babf34fea1970cb9025f488fa Mon Sep 17 00:00:00 2001 From: Sam Ritchie Date: Thu, 25 Jan 2024 11:31:46 -0500 Subject: [PATCH 5/6] add initial pyproject.toml --- pyproject.toml | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index e08b51a3..91ef5666 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,31 @@ +[build-system] +requires = ["setuptools>=64", "setuptools_scm>=8"] +build-backend = "setuptools.build_meta" + +[project] +name = "bayes3d" +authors = [ + {name = "Nishad Gothoskar", email = "nishadg@mit.edu"}, +] +description = "Probabilistic inference in 3D." +readme = "README.md" +requires-python = ">=3.9" +keywords = [ + "artificial-intelligence", + "probabilistic-programming", + "bayesian-inference", + "differentiable-programming" + ] +license = {text = "Apache 2.0"} +classifiers = [ + "Development Status :: 4 - Beta", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12" +] +dynamic = ["version"] + [tool.ruff] exclude = [ ".bzr", From 5dc425a7aba846e967c5747bd1762fe76cfc58e9 Mon Sep 17 00:00:00 2001 From: Sam Ritchie Date: Thu, 25 Jan 2024 11:46:12 -0500 Subject: [PATCH 6/6] add to readme --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index 0c5b59e1..17c4b9cb 100644 --- a/README.md +++ b/README.md @@ -86,6 +86,8 @@ sudo apt-get update sudo apt-get install ninja-build ``` +I did something! + To check your CUDA version: ``` nvcc --version