Skip to content

Commit

Permalink
RIG OSS 1/?: Open-source the utilities we actually used.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 645130374
  • Loading branch information
SiegeLordEx authored and tensorflower-gardener committed Jun 20, 2024
1 parent 32ea239 commit e60a96d
Show file tree
Hide file tree
Showing 13 changed files with 2,706 additions and 0 deletions.
163 changes: 163 additions & 0 deletions discussion/robust_inverse_graphics/util/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
# Copyright 2024 The TensorFlow Probability Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Utilities.

# [internal] load pytype.bzl (pytype_strict_library, pytype_strict_test)
# [internal] load strict.bzl
# Placeholder: py_library

package(
# default_applicable_licenses
default_visibility = ["//discussion/robust_inverse_graphics:__subpackages__"],
)

licenses(["notice"])

# pytype_strict
py_library(
name = "array_util",
srcs = ["array_util.py"],
deps = [
# jax dep,
],
)

# pytype_strict
py_library(
name = "camera_util",
srcs = ["camera_util.py"],
deps = [
# numpy dep,
# pyquaternion dep,
],
)

# pytype_strict
py_library(
name = "gin_utils",
srcs = ["gin_utils.py"],
deps = [
# absl/flags dep,
# gin dep,
# yaml dep,
],
)

# pytype_strict
py_library(
name = "math_util",
srcs = ["math_util.py"],
deps = [
# jax dep,
],
)

# pytype_strict
py_test(
name = "math_util_test",
srcs = ["math_util_test.py"],
deps = [
":math_util",
":test_util",
# google/protobuf:use_fast_cpp_protos dep,
# jax dep,
],
)

# pytype_strict
py_library(
name = "plot_util",
srcs = ["plot_util.py"],
deps = [
# jax dep,
# matplotlib dep,
# numpy dep,
"//fun_mc:using_jax",
],
)

# Not strict or pytype due to the test_util.jax dep.
py_library(
name = "test_util",
testonly = 1,
srcs = ["test_util.py"],
deps = [
# absl/testing:absltest dep,
# jax dep,
"//tensorflow_probability/python/internal:test_util.jax",
],
)

# py_strict
py_test(
name = "test_util_test",
srcs = ["test_util_test.py"],
deps = [
":test_util",
# flax dep,
# google/protobuf:use_fast_cpp_protos dep,
# numpy dep,
],
)

# pytype_strict
py_library(
name = "tree2",
srcs = ["tree2.py"],
deps = [
# etils/epath dep,
# immutabledict dep,
# numpy dep,
],
)

# py_strict
py_test(
name = "tree2_test",
srcs = ["tree2_test.py"],
deps = [
":test_util",
":tree2",
# absl/testing:parameterized dep,
# flax:core dep,
# google/protobuf:use_fast_cpp_protos dep,
# immutabledict dep,
# jax dep,
# numpy dep,
"//tensorflow_probability:jax",
],
)

# pytype_strict
py_library(
name = "tree_util",
srcs = ["tree_util.py"],
deps = [
# jax dep,
],
)

# py_strict
py_test(
name = "tree_util_test",
srcs = ["tree_util_test.py"],
deps = [
":test_util",
":tree_util",
# flax:core dep,
# google/protobuf:use_fast_cpp_protos dep,
# jax dep,
],
)
36 changes: 36 additions & 0 deletions discussion/robust_inverse_graphics/util/array_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright 2024 The TensorFlow Probability Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Array utilities."""

from typing import TypeVar

import jax

__all__ = [
'shard_tree',
'unshard_tree',
]

T = TypeVar('T')


def shard_tree(tree: T) -> T:
shard_part = lambda x: x.reshape((len(jax.devices()), -1) + x.shape[1:])
return jax.tree.map(shard_part, tree)


def unshard_tree(tree: T) -> T:
unshard_part = lambda x: x.reshape((-1,) + x.shape[2:])
return jax.tree.map(unshard_part, tree)
154 changes: 154 additions & 0 deletions discussion/robust_inverse_graphics/util/camera_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
# Copyright 2024 The TensorFlow Probability Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Camera utilities."""

from typing import Optional

import numpy as np
import pyquaternion


def look_at_quat(
position: np.ndarray,
target: np.ndarray,
up: np.ndarray = np.array([0., 1., 0.]),
front: np.ndarray = np.array([0., 0., -1.]),
quaternion_atol: float = 1e-8,
quaternion_rtol: float = 1e-5
) -> tuple[float, float, float, float]:
"""Constructs a quaternion looking at `target` from `position`.
Args:
position: Camera position. Shape: [3]
target: Camera target. Shape: [3]
up: World up unit vector. Shape: [3]
front: World front unit vector. Shape: [3]
quaternion_atol: atol for pyquaternion matrix orthogonality checks.
quaternion_rtol: rtol for pyquaternion matrix orthogonality checks.
Returns:
Quaternion as a 4-tuple.
"""

right = np.cross(up, front)

normalize = lambda x: x / (np.linalg.norm(x, axis=-1) + 1e-20)

look_at_front = normalize(target - position)
look_at_right = normalize(np.cross(up, look_at_front))
if np.linalg.norm(look_at_right, axis=-1) == 0.:
look_at_right = right

look_at_up = normalize(np.cross(look_at_front, look_at_right))

rotation_matrix1 = np.stack([look_at_right, look_at_up, look_at_front])
rotation_matrix2 = np.stack([right, up, front])

return tuple(
pyquaternion.Quaternion(matrix=(rotation_matrix1.T @ rotation_matrix2),
atol=quaternion_atol,
rtol=quaternion_rtol))


def random_sphere(rng: Optional[np.random.RandomState] = None) -> np.ndarray:
"""Generates points uniformly on a sphere."""
if rng is None:
rng = np.random
z = rng.randn(3)
z /= (np.linalg.norm(z) + 1e-20)
return z


def random_half_sphere(
half_elem: int = 1,
rng: Optional[np.random.RandomState] = None) -> np.ndarray:
"""Generates points uniformly on a half-sphere."""
z = random_sphere(rng)
z[half_elem] = np.abs(z[half_elem])
return z


def grid_sphere(num_slices: int) -> np.ndarray:
"""Generates points on a regular grid on a sphere.
This places the poles at (0, +-1, 0).
Args:
num_slices: Number of slices. Should be even.
Returns:
The generated points. This will generate
`2 + (num_slices // 2 - 1) * num_slices` points.
"""
elevation = np.linspace(np.pi / 2, -np.pi / 2, num_slices // 2 + 1)
azimuth = np.linspace(0.0, 2 * np.pi, num_slices + 1)[:num_slices]

points = []
for (
band,
el,
) in enumerate(elevation):
if band == 0 or band == len(elevation) - 1:
band_azimuth = [0.0]
else:
band_azimuth = azimuth

for az in band_azimuth:
r = np.cos(el)
x = r * np.sin(az)
z = r * np.cos(az)
y = np.sin(el)
points.append(np.array([x, y, z]))
return np.array(points)


def get_mipnerf_camera_intrinsics(width: int,
height: int,
focal_length: float,
sensor_width: float = 1.,
sensor_height: float = 1.) -> np.ndarray:
"""Constructs the mipnerf-compatible intrinsics matrix."""
# See https://en.wikipedia.org/wiki/Camera_resectioning#Intrinsic_parameters

fx = focal_length / sensor_width * width
fy = focal_length / sensor_height * height

return np.array([
[fx, 0., width / 2.],
[0., fy, height / 2.],
[0., 0., 1.],
], np.float32)


def get_camera_position(radius: float, inclination: float,
azimuth: float) -> np.ndarray:
"""Converts radius, inclination, azimuth to xyz.
Uses this convention
https://en.wikipedia.org/wiki/Spherical_coordinate_system#/media/File:3D_Spherical.svg
Args:
radius: float, how far is the camera from the center.
inclination: float, in radians, theta in the above.
azimuth: float, in radians, phi (LaTeX `varphi`) in the above.
Returns:
camera_position: Shape [3], xyz position.
"""
return np.array([
radius * np.cos(azimuth) * np.sin(inclination),
radius * np.sin(azimuth) * np.sin(inclination),
radius * np.cos(inclination)
])
Loading

0 comments on commit e60a96d

Please sign in to comment.