From e60a96d0bb240010f61fe95a8f8cdeb80c9c9532 Mon Sep 17 00:00:00 2001 From: siege Date: Thu, 20 Jun 2024 13:03:32 -0700 Subject: [PATCH] RIG OSS 1/?: Open-source the utilities we actually used. PiperOrigin-RevId: 645130374 --- discussion/robust_inverse_graphics/util/BUILD | 163 +++ .../util/array_util.py | 36 + .../util/camera_util.py | 154 +++ .../robust_inverse_graphics/util/gin_utils.py | 44 + .../robust_inverse_graphics/util/math_util.py | 83 ++ .../util/math_util_test.py | 57 + .../robust_inverse_graphics/util/plot_util.py | 360 ++++++ .../robust_inverse_graphics/util/test_util.py | 114 ++ .../util/test_util_test.py | 51 + .../robust_inverse_graphics/util/tree2.py | 1104 +++++++++++++++++ .../util/tree2_test.py | 334 +++++ .../robust_inverse_graphics/util/tree_util.py | 144 +++ .../util/tree_util_test.py | 62 + 13 files changed, 2706 insertions(+) create mode 100644 discussion/robust_inverse_graphics/util/BUILD create mode 100644 discussion/robust_inverse_graphics/util/array_util.py create mode 100644 discussion/robust_inverse_graphics/util/camera_util.py create mode 100644 discussion/robust_inverse_graphics/util/gin_utils.py create mode 100644 discussion/robust_inverse_graphics/util/math_util.py create mode 100644 discussion/robust_inverse_graphics/util/math_util_test.py create mode 100644 discussion/robust_inverse_graphics/util/plot_util.py create mode 100644 discussion/robust_inverse_graphics/util/test_util.py create mode 100644 discussion/robust_inverse_graphics/util/test_util_test.py create mode 100644 discussion/robust_inverse_graphics/util/tree2.py create mode 100644 discussion/robust_inverse_graphics/util/tree2_test.py create mode 100644 discussion/robust_inverse_graphics/util/tree_util.py create mode 100644 discussion/robust_inverse_graphics/util/tree_util_test.py diff --git a/discussion/robust_inverse_graphics/util/BUILD b/discussion/robust_inverse_graphics/util/BUILD new file mode 100644 index 0000000000..a6ece023f1 --- /dev/null +++ 
b/discussion/robust_inverse_graphics/util/BUILD @@ -0,0 +1,163 @@ +# Copyright 2024 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +# Utilities. + +# [internal] load pytype.bzl (pytype_strict_library, pytype_strict_test) +# [internal] load strict.bzl +# Placeholder: py_library + +package( + # default_applicable_licenses + default_visibility = ["//discussion/robust_inverse_graphics:__subpackages__"], +) + +licenses(["notice"]) + +# pytype_strict +py_library( + name = "array_util", + srcs = ["array_util.py"], + deps = [ + # jax dep, + ], +) + +# pytype_strict +py_library( + name = "camera_util", + srcs = ["camera_util.py"], + deps = [ + # numpy dep, + # pyquaternion dep, + ], +) + +# pytype_strict +py_library( + name = "gin_utils", + srcs = ["gin_utils.py"], + deps = [ + # absl/flags dep, + # gin dep, + # yaml dep, + ], +) + +# pytype_strict +py_library( + name = "math_util", + srcs = ["math_util.py"], + deps = [ + # jax dep, + ], +) + +# pytype_strict +py_test( + name = "math_util_test", + srcs = ["math_util_test.py"], + deps = [ + ":math_util", + ":test_util", + # google/protobuf:use_fast_cpp_protos dep, + # jax dep, + ], +) + +# pytype_strict +py_library( + name = "plot_util", + srcs = ["plot_util.py"], + deps = [ + # jax dep, + # matplotlib dep, + # numpy dep, + "//fun_mc:using_jax", + ], +) + +# Not strict or pytype due to the test_util.jax dep. 
+py_library( + name = "test_util", + testonly = 1, + srcs = ["test_util.py"], + deps = [ + # absl/testing:absltest dep, + # jax dep, + "//tensorflow_probability/python/internal:test_util.jax", + ], +) + +# py_strict +py_test( + name = "test_util_test", + srcs = ["test_util_test.py"], + deps = [ + ":test_util", + # flax dep, + # google/protobuf:use_fast_cpp_protos dep, + # numpy dep, + ], +) + +# pytype_strict +py_library( + name = "tree2", + srcs = ["tree2.py"], + deps = [ + # etils/epath dep, + # immutabledict dep, + # numpy dep, + ], +) + +# py_strict +py_test( + name = "tree2_test", + srcs = ["tree2_test.py"], + deps = [ + ":test_util", + ":tree2", + # absl/testing:parameterized dep, + # flax:core dep, + # google/protobuf:use_fast_cpp_protos dep, + # immutabledict dep, + # jax dep, + # numpy dep, + "//tensorflow_probability:jax", + ], +) + +# pytype_strict +py_library( + name = "tree_util", + srcs = ["tree_util.py"], + deps = [ + # jax dep, + ], +) + +# py_strict +py_test( + name = "tree_util_test", + srcs = ["tree_util_test.py"], + deps = [ + ":test_util", + ":tree_util", + # flax:core dep, + # google/protobuf:use_fast_cpp_protos dep, + # jax dep, + ], +) diff --git a/discussion/robust_inverse_graphics/util/array_util.py b/discussion/robust_inverse_graphics/util/array_util.py new file mode 100644 index 0000000000..84a9410d3d --- /dev/null +++ b/discussion/robust_inverse_graphics/util/array_util.py @@ -0,0 +1,36 @@ +# Copyright 2024 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Array utilities.""" + +from typing import TypeVar + +import jax + +__all__ = [ + 'shard_tree', + 'unshard_tree', +] + +T = TypeVar('T') + + +def shard_tree(tree: T) -> T: + shard_part = lambda x: x.reshape((len(jax.devices()), -1) + x.shape[1:]) + return jax.tree.map(shard_part, tree) + + +def unshard_tree(tree: T) -> T: + unshard_part = lambda x: x.reshape((-1,) + x.shape[2:]) + return jax.tree.map(unshard_part, tree) diff --git a/discussion/robust_inverse_graphics/util/camera_util.py b/discussion/robust_inverse_graphics/util/camera_util.py new file mode 100644 index 0000000000..6609e1d5b3 --- /dev/null +++ b/discussion/robust_inverse_graphics/util/camera_util.py @@ -0,0 +1,154 @@ +# Copyright 2024 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Camera utilities.""" + +from typing import Optional + +import numpy as np +import pyquaternion + + +def look_at_quat( + position: np.ndarray, + target: np.ndarray, + up: np.ndarray = np.array([0., 1., 0.]), + front: np.ndarray = np.array([0., 0., -1.]), + quaternion_atol: float = 1e-8, + quaternion_rtol: float = 1e-5 +) -> tuple[float, float, float, float]: + """Constructs a quaternion looking at `target` from `position`. 
+ + Args: + position: Camera position. Shape: [3] + target: Camera target. Shape: [3] + up: World up unit vector. Shape: [3] + front: World front unit vector. Shape: [3] + quaternion_atol: atol for pyquaternion matrix orthogonality checks. + quaternion_rtol: rtol for pyquaternion matrix orthogonality checks. + + Returns: + Quaternion as a 4-tuple. + """ + + right = np.cross(up, front) + + normalize = lambda x: x / (np.linalg.norm(x, axis=-1) + 1e-20) + + look_at_front = normalize(target - position) + look_at_right = normalize(np.cross(up, look_at_front)) + if np.linalg.norm(look_at_right, axis=-1) == 0.: + look_at_right = right + + look_at_up = normalize(np.cross(look_at_front, look_at_right)) + + rotation_matrix1 = np.stack([look_at_right, look_at_up, look_at_front]) + rotation_matrix2 = np.stack([right, up, front]) + + return tuple( + pyquaternion.Quaternion(matrix=(rotation_matrix1.T @ rotation_matrix2), + atol=quaternion_atol, + rtol=quaternion_rtol)) + + +def random_sphere(rng: Optional[np.random.RandomState] = None) -> np.ndarray: + """Generates points uniformly on a sphere.""" + if rng is None: + rng = np.random + z = rng.randn(3) + z /= (np.linalg.norm(z) + 1e-20) + return z + + +def random_half_sphere( + half_elem: int = 1, + rng: Optional[np.random.RandomState] = None) -> np.ndarray: + """Generates points uniformly on a half-sphere.""" + z = random_sphere(rng) + z[half_elem] = np.abs(z[half_elem]) + return z + + +def grid_sphere(num_slices: int) -> np.ndarray: + """Generates points on a regular grid on a sphere. + + This places the poles at (0, +-1, 0). + + Args: + num_slices: Number of slices. Should be even. + + Returns: + The generated points. This will generate + `2 + (num_slices // 2 - 1) * num_slices` points. 
+ """ + elevation = np.linspace(np.pi / 2, -np.pi / 2, num_slices // 2 + 1) + azimuth = np.linspace(0.0, 2 * np.pi, num_slices + 1)[:num_slices] + + points = [] + for ( + band, + el, + ) in enumerate(elevation): + if band == 0 or band == len(elevation) - 1: + band_azimuth = [0.0] + else: + band_azimuth = azimuth + + for az in band_azimuth: + r = np.cos(el) + x = r * np.sin(az) + z = r * np.cos(az) + y = np.sin(el) + points.append(np.array([x, y, z])) + return np.array(points) + + +def get_mipnerf_camera_intrinsics(width: int, + height: int, + focal_length: float, + sensor_width: float = 1., + sensor_height: float = 1.) -> np.ndarray: + """Constructs the mipnerf-compatible intrinsics matrix.""" + # See https://en.wikipedia.org/wiki/Camera_resectioning#Intrinsic_parameters + + fx = focal_length / sensor_width * width + fy = focal_length / sensor_height * height + + return np.array([ + [fx, 0., width / 2.], + [0., fy, height / 2.], + [0., 0., 1.], + ], np.float32) + + +def get_camera_position(radius: float, inclination: float, + azimuth: float) -> np.ndarray: + """Converts radius, inclination, azimuth to xyz. + + Uses this convention + https://en.wikipedia.org/wiki/Spherical_coordinate_system#/media/File:3D_Spherical.svg + + Args: + radius: float, how far is the camera from the center. + inclination: float, in radians, theta in the above. + azimuth: float, in radians, phi (LaTeX `varphi`) in the above. + + Returns: + camera_position: Shape [3], xyz position. + """ + return np.array([ + radius * np.cos(azimuth) * np.sin(inclination), + radius * np.sin(azimuth) * np.sin(inclination), + radius * np.cos(inclination) + ]) diff --git a/discussion/robust_inverse_graphics/util/gin_utils.py b/discussion/robust_inverse_graphics/util/gin_utils.py new file mode 100644 index 0000000000..c9dc15f150 --- /dev/null +++ b/discussion/robust_inverse_graphics/util/gin_utils.py @@ -0,0 +1,44 @@ +# Copyright 2024 The TensorFlow Probability Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Utilities for interfacing with Gin configurations.""" + +from typing import Any, Dict, Mapping + +from absl import flags +import gin +import yaml + +__all__ = ['bind_hparams', 'YAMLDictParser'] + + +class YAMLDictParser(flags.ArgumentParser): + syntactic_help = """Expects YAML one-line dictionaries without braces, e.g. + 'key1: val1, key2: val2'""" + + def parse(self, argument: str) -> Dict[str, Any]: + return yaml.safe_load('{' + argument + '}') + + def flag_type(self) -> str: + return 'Dict[str, Any]' + + +def bind_hparams(hparams: Mapping[str, Any]): + """Binds all Gin parameters from a dictionary. + + Args: + hparams: A dictionary of bindings. + """ + for k, v in hparams.items(): + gin.bind_parameter(k, v) diff --git a/discussion/robust_inverse_graphics/util/math_util.py b/discussion/robust_inverse_graphics/util/math_util.py new file mode 100644 index 0000000000..367c595b5c --- /dev/null +++ b/discussion/robust_inverse_graphics/util/math_util.py @@ -0,0 +1,83 @@ +# Copyright 2024 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Some math utilities."""
+
+from collections.abc import Callable
+from typing import Any, TypeVar
+
+import jax
+import jax.numpy as jnp
+
+__all__ = [
+    'transform_gradients',
+    'sanitize_gradients',
+    'clip_gradients',
+    'is_finite',
+]
+
+
+T = TypeVar('T')
+
+
+def is_finite(tree: Any) -> jax.Array:
+  """Returns a scalar bool: whether every element of `tree` is finite."""
+  leaves = jax.tree.leaves(
+      jax.tree.map(lambda x: jnp.isfinite(x).all(), tree)
+  )
+  if leaves:
+    return jnp.stack(leaves).all()
+  else:
+    return jnp.array(True)  # A tree without leaves is vacuously finite.
+
+
+def transform_gradients(x: T, handler_fn: Callable[[T], T]) -> T:
+  """Returns `x` unchanged, applying `handler_fn` to its incoming cotangent."""
+  wrapper = jax.custom_vjp(lambda x: x)
+
+  def fwd(x):
+    return x, ()
+
+  def bwd(_, g):
+    return (handler_fn(g),)
+
+  wrapper.defvjp(fwd, bwd)
+
+  return wrapper(x)
+
+
+def sanitize_gradients(x: T) -> T:
+  """Zeroes all gradients flowing through `x` if any element is non-finite."""
+
+  def sanitize_fn(x):
+    finite = is_finite(x)
+    return jax.tree.map(lambda x: jnp.where(finite, x, jnp.zeros_like(x)), x)
+
+  return transform_gradients(x, sanitize_fn)
+
+
+def clip_gradients(
+    x: T,
+    global_norm: jax.typing.ArrayLike = 1.0,
+    eps: jax.typing.ArrayLike = 1e-20,
+) -> T:
+  """Rescales gradients through `x` so their global norm <= `global_norm`."""
+
+  def clip_fn(x):
+    leaves = jax.tree.leaves(jax.tree.map(lambda x: jnp.square(x).sum(), x))
+    norm = jnp.sqrt(eps + sum(leaves))  # Builtin sum: leafless trees are fine.
+    new_norm = jnp.where(norm > global_norm, global_norm, norm)
+    return jax.tree.map(lambda x: x * new_norm / norm, x)
+
+  return transform_gradients(x, clip_fn)
diff --git a/discussion/robust_inverse_graphics/util/math_util_test.py b/discussion/robust_inverse_graphics/util/math_util_test.py
new file mode 100644
index 0000000000..851debe0d7
--- /dev/null
+++ b/discussion/robust_inverse_graphics/util/math_util_test.py
@@ -0,0 +1,57 @@
+# Copyright 2024 The TensorFlow Probability Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================ +import jax +import jax.numpy as jnp + +from discussion.robust_inverse_graphics.util import math_util +from discussion.robust_inverse_graphics.util import test_util + + +class MathUtilTest(test_util.TestCase): + + def test_transform_gradients(self): + def f(x): + return math_util.transform_gradients(x, lambda x: x + 1) + + grad = jax.grad(f)(0.0) + self.assertAllEqual(grad, 2.0) + + def test_sanitize_gradients(self): + def f(x): + return jnp.sqrt(math_util.sanitize_gradients(x['x'])) + + grad = jax.grad(f)({'x': 0.0}) + self.assertAllEqual(grad['x'], 0.0) + + def test_clip_gradients(self): + def f(x): + return jnp.square(math_util.clip_gradients(x['x'])).sum() + + grad_small = jax.grad(f)({'x': 0.1 * jnp.ones(3)}) + grad_big = jax.grad(f)({'x': jnp.ones(3)}) + + self.assertAllClose(grad_small['x'], 2 * 0.1 * jnp.ones(3)) + self.assertAllClose(grad_big['x'], jnp.ones(3) / jnp.sqrt(3)) + + def test_is_finite(self): + self.assertTrue(math_util.is_finite([])) + self.assertTrue(math_util.is_finite(0.)) + self.assertTrue(math_util.is_finite([0., 0.])) + self.assertFalse(math_util.is_finite(float('nan'))) + self.assertFalse(math_util.is_finite([0., float('nan')])) + + +if __name__ == '__main__': + test_util.main() diff --git a/discussion/robust_inverse_graphics/util/plot_util.py b/discussion/robust_inverse_graphics/util/plot_util.py new file mode 100644 index 0000000000..cd7fdc2c7c --- /dev/null +++ b/discussion/robust_inverse_graphics/util/plot_util.py @@ -0,0 +1,360 @@ +# Copyright 2024 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Plotting utilities."""
+from collections.abc import Sequence
+from typing import Any
+
+import jax
+import jax.numpy as jnp
+import matplotlib
+import matplotlib.pyplot as plt
+import numpy as np
+import tensorflow_probability.spinoffs.fun_mc.using_jax as fun_mc
+
+__all__ = [
+    'COLORS',
+    'polkagram_horiz',
+    'polkagram_vert',
+    'trace_plot',
+]
+
+# From
+# https://mikemol.github.io/technique/colorblind/2018/02/11/color-safe-palette.html,
+# ordered by luminosity.
+COLORS = ('#009E73', '#0072B2', '#D55E00', '#E69F00', '#56B4E9', '#F0E442')
+
+
+# TODO(siege): Move this to some sort of utils directory?
+def _exp_mean(vals: jax.Array, window_size: float) -> jax.Array:
+  """Exponential moving average over axis 0, skipping non-finite entries."""
+  vals = jnp.asarray(vals)
+
+  def kernel(rm_state, i):
+    v = vals[i]
+    cand_rm_state, _ = fun_mc.running_mean_step(
+        rm_state, v, window_size=window_size
+    )
+    rm_state = fun_mc.choose(jnp.isfinite(v), cand_rm_state, rm_state)
+    return (rm_state, i + 1), rm_state.mean
+
+  _, exp_mean = fun_mc.trace(
+      (fun_mc.running_mean_init(vals.shape[1:], jnp.float32), 0),
+      kernel,
+      vals.shape[0],
+  )
+  return exp_mean
+
+
+def trace_plot(ax: plt.Axes, vals: jax.Array, color: str = COLORS[0]):
+  """Plots a trace, with overlaid EMA trace."""
+  ax.plot(vals, color=color)
+  ax.plot(_exp_mean(vals, window_size=10), color='k')
+
+  # Drop the extremes.
+  y_min = jnp.nanpercentile(vals, 1)
+  y_max = jnp.nanpercentile(vals, 99)
+  if y_min == y_max:
+    y_min = y_min - 0.1
+    y_max = y_max + 0.1
+  ax.set_ylim(y_min, y_max)
+
+
+class HistBoxes(matplotlib.collections.PatchCollection):
+  """Collection for the histogram parts."""
+
+  def __init__(self, *args, scatter, **scatter_kwargs):
+    super().__init__(*args, **scatter_kwargs)
+    self.scatter = scatter
+
+
+class HistBoxesHandler:
+  """Legend handler for histograms."""
+
+  def legend_artist(
+      self,
+      legend: matplotlib.legend.Legend,
+      orig_handle: matplotlib.patches.Patch,
+      fontsize: int,
+      handlebox: matplotlib.offsetbox.DrawingArea,
+  ) -> matplotlib.patches.Patch:
+    """Creates the legend artist."""
+    del fontsize, legend  # Unused
+    x0, y0 = handlebox.xdescent, handlebox.ydescent
+    width, height = handlebox.width, handlebox.height
+
+    scatter_kwargs = {}
+    ec = orig_handle.get_edgecolor()
+    fc = orig_handle.get_facecolor()
+    if ec is not None and len(ec) > 0:  # pylint: disable=g-explicit-length-test
+      scatter_kwargs.update(ec=ec[0])
+    else:
+      scatter_kwargs.update(ec='none')
+    if fc is not None and len(fc) > 0:  # pylint: disable=g-explicit-length-test
+      scatter_kwargs.update(fc=fc[0])
+    else:
+      scatter_kwargs.update(fc='none')
+
+    patch = matplotlib.patches.Rectangle(
+        [x0, y0],
+        width,
+        height,
+        lw=orig_handle.get_linewidth()[0],
+        transform=handlebox.get_transform(),
+        **scatter_kwargs,
+    )
+    handlebox.add_artist(patch)
+    handlebox.add_artist(
+        matplotlib.collections.PathCollection(
+            orig_handle.scatter.get_paths(),
+            offsets=(x0 + width / 2, y0 + height / 2),
+            fc=orig_handle.scatter.get_fc(),
+            ec=orig_handle.scatter.get_ec(),
+            sizes=5 * np.array([min(width, height)]),
+        )
+    )
+    return patch
+
+
+matplotlib.legend.Legend.update_default_handler_map(
+    {HistBoxes: HistBoxesHandler()}
+)
+
+
+def polkagram_vert(
+    ys: Sequence[float] | np.ndarray,
+    x: float | np.generic = 0.0,
+    bins: int = 20,
+    width: float = 1.0,
+    rng: np.random.RandomState = np.random,
+    draw_boxes: bool = True,
+    box_ec: str = 'none',
+    box_fc: str = 'lightgray',
+    ax: plt.Axes | None = None,
+    center: bool = True,
+    max_point_density: float | None = None,
+    min_points: int = 10,
+    **scatter_kwargs: Any,
+) -> tuple[
+    matplotlib.collections.PathCollection,
+    matplotlib.collections.PatchCollection | None,
+]:
+  """Draw a histogram/scatter combo.
+
+  Args:
+    ys: Y-coordinates of the datapoints.
+    x: X-coordinate of the datapoints.
+    bins: Number of bins to use for the histogram.
+    width: Maximum bin width.
+    rng: PRNG for the scatter.
+    draw_boxes: Whether to add the boxes behind the scatter points.
+    box_ec: Box edge color.
+    box_fc: Box face color.
+    ax: Axis to draw on.
+    center: Whether to center the bins around the X axis.
+    max_point_density: Maximum number of points per bin size.
+    min_points: Minimum number of points to keep per bin. Only relevant when
+      `max_point_density` is not `None`.
+    **scatter_kwargs: Passed to ax.scatter, except `range` (histogram range).
+
+  Returns:
+    A pair of collections for the scatter points and boxes.
+  """
+  if ax is None:
+    ax = plt.gca()
+  range_ = scatter_kwargs.pop('range', None)
+
+  ys = np.array(ys, dtype=float)  # Float copy: thinning below writes NaNs.
+  ys = ys[np.isfinite(ys)]
+  if range_ is None:
+    range_ = (ys.min(), ys.max())
+
+  heights, edges = np.histogram(ys, bins=bins, range=range_, density=True)
+
+  ids = np.digitize(ys, edges[:-1]) - 1
+  heights /= heights.max()
+
+  if center:
+    start_frac = -0.5
+    end_frac = 0.5
+  else:
+    start_frac = 0.0
+    end_frac = 1.0
+
+  patches = []
+  xs = np.full(ys.shape, np.nan)  # NaN until placed, so strays get filtered.
+  for bin_id, bin_height in enumerate(heights):
+    mask = ids == bin_id
+    bin_xs = xs[mask]
+    if len(bin_xs) == 0:  # pylint: disable=g-explicit-length-test
+      continue
+    elif len(bin_xs) == 1:
+      bin_xs = np.zeros_like(bin_xs)
+    else:
+      bin_xs = (
+          width * bin_height * np.linspace(start_frac, end_frac, len(bin_xs))
+      )
+      rng.shuffle(bin_xs)
+    xs[mask] = bin_xs
+    if max_point_density is not None:
+      max_points = max(int(max_point_density * bin_height), min_points)
+      if len(bin_xs) > max_points:
+        idxs = np.where(mask)[0]
+        rng.shuffle(idxs)
+        nan_idxs = idxs[max_points:]
+        ys[nan_idxs] = np.nan
+    patches.append(
+        matplotlib.patches.Rectangle(
+            (x + width * bin_height * start_frac, edges[bin_id]),
+            width * bin_height,
+            edges[bin_id + 1] - edges[bin_id],
+        )
+    )
+
+  if draw_boxes:
+    label = scatter_kwargs.pop('label', None)
+
+  isfinite = np.isfinite(xs) & np.isfinite(ys)  # Drop unplaced/thinned points.
+  xs = xs[isfinite]
+  ys = ys[isfinite]
+  scatter = ax.scatter(x + xs, ys, **scatter_kwargs)
+
+  if draw_boxes:
+    boxes = ax.add_collection(
+        HistBoxes(
+            patches,
+            facecolor=box_fc,
+            edgecolor=box_ec,
+            label=label,
+            zorder=scatter.zorder - 1,
+            scatter=scatter,
+        )
+    )
+  else:
+    boxes = None
+  return scatter, boxes
+
+
+def polkagram_horiz(
+    xs: Sequence[float] | np.ndarray,
+    y: float | np.generic = 0.0,
+    bins: int = 20,
+    height: float = 1.0,
+    rng: np.random.RandomState = np.random,
+    draw_boxes: bool = True,
+    box_ec: str = 'none',
+    box_fc: str = 'lightgray',
+    ax: plt.Axes | None = None,
+    center: bool = True,
+    max_point_density: float | None = None,
+    min_points: int = 10,
+    **scatter_kwargs: Any,
+) -> tuple[
+    matplotlib.collections.PathCollection,
+    matplotlib.collections.PatchCollection | None,
+]:
+  """Draw a histogram/scatter combo.
+
+  Args:
+    xs: X-coordinates of the datapoints.
+    y: Y-coordinate of the datapoints.
+    bins: Number of bins to use for the histogram.
+    height: Maximum bin height.
+    rng: PRNG for the scatter.
+    draw_boxes: Whether to add the boxes behind the scatter points.
+    box_ec: Box edge color.
+    box_fc: Box face color.
+    ax: Axis to draw on.
+    center: Whether to center the bins around the Y axis.
+    max_point_density: Maximum number of points per bin size.
+    min_points: Minimum number of points to keep per bin. Only relevant when
+      `max_point_density` is not `None`.
+    **scatter_kwargs: Passed to ax.scatter, except `range` (histogram range).
+
+  Returns:
+    A pair of collections for the scatter points and boxes.
+  """
+  if ax is None:
+    ax = plt.gca()
+  range_ = scatter_kwargs.pop('range', None)
+
+  xs = np.array(xs, dtype=float)  # Float copy: thinning below writes NaNs.
+  xs = xs[np.isfinite(xs)]
+  if range_ is None:
+    range_ = (xs.min(), xs.max())
+
+  heights, edges = np.histogram(xs, bins=bins, range=range_, density=True)
+
+  ids = np.digitize(xs, edges[:-1]) - 1
+  heights /= heights.max()
+
+  if center:
+    start_frac = -0.5
+    end_frac = 0.5
+  else:
+    start_frac = 0.0
+    end_frac = 1.0
+
+  patches = []
+  ys = np.full(xs.shape, np.nan)  # NaN until placed, so strays get filtered.
+  for bin_id, bin_height in enumerate(heights):
+    mask = ids == bin_id
+    bin_ys = ys[mask]
+    if len(bin_ys) == 0:  # pylint: disable=g-explicit-length-test
+      continue
+    elif len(bin_ys) == 1:
+      bin_ys = np.zeros_like(bin_ys)
+    else:
+      bin_ys = (
+          height * bin_height * np.linspace(start_frac, end_frac, len(bin_ys))
+      )
+      rng.shuffle(bin_ys)
+    ys[mask] = bin_ys
+    if max_point_density is not None:
+      max_points = max(int(max_point_density * bin_height), min_points)
+      if len(bin_ys) > max_points:
+        idxs = np.where(mask)[0]
+        rng.shuffle(idxs)
+        nan_idxs = idxs[max_points:]
+        xs[nan_idxs] = np.nan
+    patches.append(
+        matplotlib.patches.Rectangle(
+            (edges[bin_id], y + height * bin_height * start_frac),
+            edges[bin_id + 1] - edges[bin_id],
+            height * bin_height,
+        )
+    )
+
+  if draw_boxes:
+    label = scatter_kwargs.pop('label', None)
+
+  isfinite = np.isfinite(xs) & np.isfinite(ys)  # Drop unplaced/thinned points.
+  xs = xs[isfinite]
+  ys = ys[isfinite]
+  scatter = ax.scatter(xs, y + ys, **scatter_kwargs)
+
+  if draw_boxes:
+    boxes = ax.add_collection(
+        HistBoxes(
+            patches,
+            facecolor=box_fc,
+            edgecolor=box_ec,
+            label=label,
+            zorder=scatter.zorder - 1,
+            scatter=scatter,
+        )
+    )
+  else:
+    boxes = None
+  return scatter, boxes
diff --git a/discussion/robust_inverse_graphics/util/test_util.py b/discussion/robust_inverse_graphics/util/test_util.py
new file mode 100644
index 0000000000..f09cee7dd7
--- /dev/null
+++ b/discussion/robust_inverse_graphics/util/test_util.py
@@ -0,0 +1,114 @@
+# Copyright 2024 The TensorFlow Probability Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================ +"""Test utilities for Electric Sheep.""" + +import functools +import math +import re + +from absl import flags +import jax +import numpy as np +from tensorflow_probability.substrates.jax.internal import test_util + +from absl.testing import absltest + +FLAGS = flags.FLAGS + +__all__ = [ + 'TestCase', + 'main', +] + + +class _LeafIndicator: + + def __init__(self, v): + self.v = v + + def __repr__(self): + return self.v + + +class TestCase(test_util.TestCase): + """Electric Sheep TestCase.""" + + def test_seed(self, *args, **kwargs): + return test_util.test_seed(*args, **kwargs) + + def assertAllEqual(self, a, b, msg=''): + np.testing.assert_array_equal(a, b, err_msg=msg) + + def assertAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=''): + assert_fn = functools.partial( + np.testing.assert_allclose, rtol=rtol, atol=atol, err_msg=msg + ) + + exceptions = [] + + def assert_part(a, b): + try: + assert_fn(a, b) + return _LeafIndicator('.') + except Exception as e: # pylint: disable=broad-except + exceptions.append(e) + return _LeafIndicator(f'#{len(exceptions)}') + + positions = jax.tree.map(assert_part, a, b) + + if exceptions: + lines = [ + 'Some leaves are not close. 
Differing leaves:\n', + f'{positions}\n', + ] + + for i, e in enumerate(exceptions): + lines.append(f'Exception #{i + 1}:') + lines.append(str(e)) + + raise AssertionError('\n'.join(lines)) + + def assertNear(self, f1, f2, err, msg=None): + if isinstance(f1, jax.Array): + f1 = float(f1.item()) + if isinstance(f2, jax.Array): + f2 = float(f2.item()) + self.assertTrue( + f1 == f2 or math.fabs(f1 - f2) <= err, + '%f != %f +/- %f%s' + % (f1, f2, err, ' (%s)' % msg if msg is not None else ''), + ) + + +class _TestLoader(absltest.TestLoader): + """A custom TestLoader that allows for Regex filtering test cases.""" + + def getTestCaseNames(self, testCaseClass): # pylint:disable=invalid-name + names = super().getTestCaseNames(testCaseClass) + if FLAGS.test_regex: # This flag is defined in TFP's test_util. + pattern = re.compile(FLAGS.test_regex) + names = [ + name + for name in names + if pattern.search(f'{testCaseClass.__name__}.{name}') + ] + # Remove the test_seed, as it's not a test despite starting with `test_`. + names = [name for name in names if name != 'test_seed'] + return names + + +def main(): + """Test main function that injects a custom loader.""" + absltest.main(testLoader=_TestLoader()) diff --git a/discussion/robust_inverse_graphics/util/test_util_test.py b/discussion/robust_inverse_graphics/util/test_util_test.py new file mode 100644 index 0000000000..fb86473873 --- /dev/null +++ b/discussion/robust_inverse_graphics/util/test_util_test.py @@ -0,0 +1,51 @@ +# Copyright 2024 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Tests for test_util.""" + +from flax import struct +import numpy as np +from discussion.robust_inverse_graphics.util import test_util + + +@struct.dataclass +class TestStruct: + a: int + b: int + + +class TestUtilTest(test_util.TestCase): + + def testAssertAllEqual(self): + self.assertAllEqual(np.arange(3), np.arange(3)) + with self.assertRaisesRegex(AssertionError, 'message'): + self.assertAllEqual(np.arange(3), np.arange(4), msg='message') + + def testAssertAllClose(self): + self.assertAllClose(np.arange(3), np.arange(3)) + self.assertAllClose({'a': np.arange(3)}, {'a': np.arange(3)}) + self.assertAllClose(np.linspace(0., 1.), np.linspace(0., 1.)) + self.assertAllClose( + np.linspace(0., 1.), np.linspace(0., 1.) + 0.1, atol=0.11) + self.assertAllClose( + np.linspace(1., 2.), 1 * 1.1 * np.linspace(1., 2.), rtol=0.11) + with self.assertRaisesRegex(AssertionError, 'not close'): + self.assertAllClose(np.arange(3), np.arange(4), msg='message') + self.assertAllClose(TestStruct(1, 2), TestStruct(1, 2)) + with self.assertRaisesRegex(AssertionError, 'not close'): + self.assertAllClose(TestStruct(1, 2), TestStruct(1, 3)) + + +if __name__ == '__main__': + test_util.main() diff --git a/discussion/robust_inverse_graphics/util/tree2.py b/discussion/robust_inverse_graphics/util/tree2.py new file mode 100644 index 0000000000..00e9774f17 --- /dev/null +++ b/discussion/robust_inverse_graphics/util/tree2.py @@ -0,0 +1,1104 @@ +# Copyright 2024 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Tree2 implementation. + +Example usage: + +```python +registry = tree2.Registry(allow_unknown_types=True) + +@registry.auto_register_type("MyClass") # Unique tag. +@dataclasses.dataclass +class MyClass: + a: int + b: float + +c = MyClass(1, 2.) + +registry.save_tree(c, '/tmp/c.tree2') + +c2 = registry.load_tree('/tmp/c.tree2') + +assert c.a == c2.a +assert c.b == c2.b +``` +""" + +import collections +from collections.abc import Mapping, Sequence +import dataclasses +import enum +import functools +import io +import json +from typing import Any, BinaryIO, Callable, Generic, Optional, Type, TypeVar, Union +import warnings + +from etils import epath +import immutabledict +import numpy as np + +_TREE2_TAG = 'TREE2' +_BLOCKS_TAG = 'BLOCKS' +_TREE2_TYPE_TAG = 'type' +_SAVE_VERSION = 1 +_JSON_SAFE_TYPES = frozenset({str, int, float, bool, type(None)}) + +UNKNOWN_SEQUENCE = 'unknown_sequence' +UNKNOWN_MAPPING = 'unknown_mapping' +UNKNOWN_NAMEDTUPLE = 'unknown_namedtuple' +UNKNOWN_DATACLASS = 'unknown_dataclass' +ARRAY = 'array' +SCALAR = 'scalar' + +Tree = TypeVar('Tree') +EncodedTree = TypeVar('EncodedTree') +InnerEncodeFn = Callable[[Tree], EncodedTree] +EncodeFn = Callable[[Tree, 'Context', InnerEncodeFn], EncodedTree] +DecodeFn = Callable[[EncodedTree, 'Context'], Tree] +DetectFn = Callable[[Any, 'Context'], Optional[str]] + +__all__ = [ + 'ARRAY', + 'Context', + 'DecodeFn', + 'DeferredNumpyArray', + 'DetectFn', + 'EncodedTree', + 'EncodeFn', + 'InnerEncodeFn', + 'Registry', + 
class DeferredNumpyArray:
  """A numpy-array like class that defers disk IO until accessed.

  The data is read from `filename` the first time the object is converted to
  a numpy array (e.g. with `np.asarray`); the loaded array is cached and reused
  by later conversions.
  """

  def __init__(self, filename: str, offset: int, shape: tuple[int, ...],
               dtype: np.dtype):
    """Creates a deferred numpy array.

    Args:
      filename: Filename.
      offset: Byte offset into the file.
      shape: Shape of the stored array.
      dtype: Dtype of the stored array.
    """
    self._filename = filename
    self._offset = offset
    # The shape may arrive as a list (e.g. when it was decoded from JSON);
    # normalize it so that `shape` is a tuple, like `np.ndarray.shape`.
    self._shape = tuple(shape)
    self._dtype = dtype
    self._value = None

  @property
  def shape(self) -> tuple[int, ...]:
    """Shape of the stored array, available without touching the file."""
    return self._shape

  @property
  def dtype(self) -> np.dtype:
    """Dtype of the stored array, available without touching the file."""
    return self._dtype

  def __repr__(self) -> str:
    return (f'{self.__class__.__name__}(shape={self.shape}, dtype={self.dtype},'
            f'numpy={self._value})')

  def __array__(self,
                dtype: Optional[np.dtype] = None,
                copy: Optional[bool] = None) -> np.ndarray:
    """Loads the array (once) and returns it as a numpy array.

    Args:
      dtype: Dtype to convert to. `None` keeps the dtype of the stored array.
      copy: NumPy 2 `__array__` protocol argument. `False` returns the cached
        array itself; `None` and `True` return a fresh array.

    Returns:
      The stored array, converted to `dtype` if one was requested.

    Raises:
      ValueError: If `copy` is `False` but a dtype conversion is required.
    """
    if self._value is None:
      with epath.Path(self._filename).open('rb') as f:
        f.seek(self._offset)
        self._value = np.load(f, allow_pickle=False)

    value = self._value
    # `astype(None)` means "convert to float64", so a missing dtype must be
    # replaced by the dtype of the data to avoid silently changing it.
    target = value.dtype if dtype is None else np.dtype(dtype)
    if copy is False:
      if target != value.dtype:
        raise ValueError(
            f'Cannot convert from {value.dtype} to {target} without a copy.')
      return value
    return value.astype(target)
def load_blocks(self, f: BinaryIO, block_format: str):
  """Reads the block section of a file.

  `f` must be positioned right after the block header line. The next line is
  a JSON object mapping block names to byte offsets, which are relative to the
  position just after that line.

  Args:
    f: File object to read from.
    block_format: Format of the block section. Only 'cat_npy' is supported.

  Raises:
    ValueError: If `block_format` is not a known format.
  """
  if block_format != 'cat_npy':
    raise ValueError(f'Unknown block format: {block_format}')

  relative_offsets = json.loads(f.readline().decode('utf-8'))
  base_offset = f.tell()

  if self.options['defer_numpy']:
    # Deferred mode: only remember where each array lives in the file.
    for name, relative in relative_offsets.items():
      self._array_offsets[name] = base_offset + relative
    return

  for name, relative in relative_offsets.items():
    f.seek(base_offset + relative)
    self._arrays[name] = np.load(f, allow_pickle=False)
_get_first_tag(tag: Union[str, Sequence[str]]) -> str: + if isinstance(tag, str): + return tag + else: + return tag[0] + + +@dataclasses.dataclass +class Registry: + """Type registry for Tree2. + + Attributes: + interactive_mode: Whether to construct the registry in interactive mode, + where duplicate registration errors are turned to warnings. + allow_unknown_types: Whether to allow saving/loading unknown types. + save_version: Which format version to save. + """ + + interactive_mode: bool = False + allow_unknown_types: bool = False + save_version: int = _SAVE_VERSION + + _tag_to_serializer: dict[str, _TreeSerialization] = dataclasses.field( + default_factory=dict) + _tree_type_to_tag: dict[Type[Any], Union[str, Sequence[str]]] = ( + dataclasses.field(default_factory=dict)) + _detectors: collections.OrderedDict[Union[Type[Any], str], DetectFn] = ( + dataclasses.field(default_factory=collections.OrderedDict)) + + def __post_init__(self): + self.register_sequence_type('list', fallback=None)(list) + self.register_sequence_type('tuple', fallback=None)(tuple) + self.register_sequence_type('set', fallback=None)(set) + self.register_sequence_type('fronzenset', fallback=None)(frozenset) + if self.allow_unknown_types: + self.register_tags(UNKNOWN_SEQUENCE, _encode_unknown_sequence, + _decode_unknown_sequence) + self.register_detector(UNKNOWN_SEQUENCE, _detect_unknown_sequence) + + self.register_mapping_type('dict', fallback=None)(dict) + self.register_mapping_type( + 'immutabledict', fallback=None)( + immutabledict.immutabledict) + if self.allow_unknown_types: + self.register_tags(UNKNOWN_MAPPING, _encode_unknown_mapping, + _decode_unknown_mapping) + self.register_detector(UNKNOWN_MAPPING, _detect_unknown_mapping) + + self.register_tags(UNKNOWN_NAMEDTUPLE, _encode_unknown_namedtuple, + _decode_unknown_namedtuple) + self.register_detector(UNKNOWN_NAMEDTUPLE, _detect_unknown_namedtuple) + self.register_tags(UNKNOWN_DATACLASS, _encode_unknown_dataclass, + 
_decode_unknown_dataclass) + self.register_detector(UNKNOWN_DATACLASS, _detect_unknown_dataclass) + + self._register_numpy() + self._maybe_register_flax() + self._maybe_register_jax() + self._maybe_register_tensorflow() + self._maybe_register_tfp() + + def register_type(self, tag: Union[str, Sequence[str]], tree_type: Type[Tree], + encode_fn: EncodeFn, decode_fn: DecodeFn) -> Type[Tree]: + """Registers a type. + + Args: + tag: One or more tags for this type. The first tag, if more than one, is + used for serialization. + tree_type: The type of the tree. + encode_fn: Encoding function. + decode_fn: Decoding function. + + Returns: + Same value as `tree_type`. + """ + if isinstance(tag, str): + tags = [tag] + else: + tags = tag + + if tree_type is not None: + existing_tags = self._tree_type_to_tag.get(tree_type) + if existing_tags is not None: + msg = f'Type \'{tree_type}\' is already registered.' + if self.interactive_mode: + warnings.warn(msg) + for tag in existing_tags: + del self._tag_to_serializer[tag] + else: + raise TypeError(msg) + + self._tree_type_to_tag[tree_type] = tags + + self.register_tags(tags, encode_fn, decode_fn) + return tree_type + + def register_tags(self, tag: Union[str, Sequence[str]], encode_fn: EncodeFn, + decode_fn: DecodeFn): + """Registers encode and decode functions for tags. + + This is typically paired with `register_detector`. + + Args: + tag: One or more tags for this type. The first tag, if more than one, is + used for serialization. + encode_fn: Encoding function. + decode_fn: Decoding function. + """ + if isinstance(tag, str): + tags = [tag] + else: + tags = tag + + for tag in tags: + if tag in self._tag_to_serializer: + msg = f'Tag \'{tag}\' is already registered.' 
def register_detector(self, type_hint_or_id: Union[Type[Any], str],
                      detector_fn: DetectFn):
  """Registers a detector function.

  A detector inspects a tree and returns the tag to serialize it with, or
  `None` if it does not recognize the tree. When `type_hint_or_id` is a type
  it doubles as a type hint: the detector is called for trees of exactly that
  type. When it is a string it only serves as an id under which the detector
  can be re-registered.

  Args:
    type_hint_or_id: Identifier of this detector or a type hint.
    detector_fn: The detector function.

  Raises:
    TypeError: If a detector is already registered under `type_hint_or_id`
      and the registry is not in interactive mode.
  """
  if type_hint_or_id in self._detectors:
    duplicate_msg = f'Detector \'{type_hint_or_id}\' is already registered.'
    if not self.interactive_mode:
      raise TypeError(duplicate_msg)
    # Interactive mode: warn and let the new detector replace the old one.
    warnings.warn(duplicate_msg)

  self._detectors[type_hint_or_id] = detector_fn
+ tag = detector_fn(tree, ctx) + else: + tag = tags[0] + + encode_fn = self._tag_to_serializer[tag].encode_fn + return encode_fn(tree, ctx, functools.partial(self._encode_tree, ctx=ctx)) + + def save_tree(self, + tree: Any, + path: Union[str, BinaryIO], + options: Mapping[str, Any] = immutabledict.immutabledict({})): + """Saves a tree to a path or a file object. + + Args: + tree: A tree. + path: Either path to a file or a file object. + options: Options for serialization. See below for options. + Options: + tree_format: Format of the tree structure encoding. Must be 'json'. + block_format: Format of the blocks encoding. Must be 'cat_npy'. + """ + ctx = Context(self.save_version, None, options) + + if isinstance(path, str): + f = epath.Path(path).open('wb') + need_close = True + else: + f = path + need_close = False + + try: + tree_format = ctx.options['tree_format'] + + f.write(f'{_TREE2_TAG},{ctx.version},{tree_format}\n'.encode('utf-8')) + + if tree_format == 'json': + tree = self._encode_tree(tree, ctx) + f.write( + json.dumps( + { + 'tree': tree + }, + indent=None, + ensure_ascii=False, + sort_keys=True, + ).encode('utf-8')) + else: + raise ValueError(f'Unknown tree format: {tree_format}') + + f.write(b'\n') + + ctx.save_blocks(f) + finally: + if need_close: + f.close() + + def load_tree( + self, + path: Union[str, BinaryIO], + options: Mapping[str, Any] = immutabledict.immutabledict({}) + ) -> Any: + """Loads a tree from a path or a file object. + + Args: + path: Either path to a file or a file object. + options: Options for serialization. See below for options. + + Returns: + The loaded tree. + + Options: + defer_numpy: Whether to defer loading numpy arrays. Numpy arrays will be + replaced with instances of `DeferredNumpyArray`. 
def auto_register_type(
    self, tag: Union[str, Sequence[str]]
) -> Callable[[Type[Tree]], Type[Tree]]:
  """Registers a type, with an automatic encoder/decoder.

  Only namedtuples, dataclasses, sequences, mappings and enums are supported.
  The kind of the type is decided by the first matching check, in this order:
  enum, namedtuple, sequence, mapping, dataclass.

  Args:
    tag: Tags to register the type under.

  Returns:
    Registration decorator. The decorator raises `TypeError` when the
    decorated type is of none of the supported kinds.
  """

  def reg_fn(tree_type: Type[Tree]) -> Type[Tree]:
    # Enums go first: an enum with a data-type mixin, such as
    # `class Color(str, enum.Enum)`, is also a `Sequence` subclass, and the
    # sequence encoding cannot round-trip it.
    if issubclass(tree_type, enum.Enum):
      return self.register_enum_type(tag)(tree_type)
    elif issubclass(tree_type, tuple) and hasattr(tree_type, '_fields'):
      return self.register_namedtuple_type(tag)(tree_type)
    elif issubclass(tree_type, Sequence):
      return self.register_sequence_type(tag)(tree_type)
    elif issubclass(tree_type, Mapping):
      return self.register_mapping_type(tag)(tree_type)
    elif dataclasses.is_dataclass(tree_type):
      return self.register_dataclass_type(tag)(tree_type)
    else:
      raise TypeError(
          f'Cannot register \'{tree_type}\' automatically. Use '
          '`register_type` with manual encode/decode functions.')

  return reg_fn
+ """ + + def reg_fn(tree_type: Type[Tree]) -> Type[Tree]: + return self.register_type( + tag, tree_type, + functools.partial( + _encode_mapping, tag=_get_first_tag(tag), fallback=fallback), + functools.partial(_decode_mapping, tree_type=tree_type)) + + return reg_fn + + def register_namedtuple_type( + self, + tag: Union[None, str, Sequence[str]] = None, + fallback: Optional[str] = UNKNOWN_NAMEDTUPLE + ) -> Callable[[Type[Tree]], Type[Tree]]: + """Registers a namedtuple type. + + Args: + tag: Tags to register the type under. + fallback: Fallback type to use for loading if this type is not registered + at loading time, typically `UNKNOWN_NAMEDTUPLE`. Can be `None` if you + want that situation to raise an error. + + Returns: + Registration decorator. + """ + + def reg_fn(tree_type: Type[Tree]) -> Type[Tree]: + return self.register_type( + tag, tree_type, + functools.partial( + _encode_namedtuple, tag=_get_first_tag(tag), fallback=fallback), + functools.partial(_decode_namedtuple, tree_type=tree_type)) + + return reg_fn + + def register_dataclass_type( + self, + tag: Union[None, str, Sequence[str]] = None, + fallback: Optional[str] = UNKNOWN_DATACLASS + ) -> Callable[[Type[Tree]], Type[Tree]]: + """Registers a dataclass type. + + Args: + tag: Tags to register the type under. + fallback: Fallback type to use for loading if this type is not registered + at loading time, typically `UNKNOWN_DATACLASS`. Can be `None` if you + want that situation to raise an error. + + Returns: + Registration decorator. + """ + + def reg_fn(tree_type: Type[Tree]) -> Type[Tree]: + return self.register_type( + tag, tree_type, + functools.partial( + _encode_dataclass, tag=_get_first_tag(tag), fallback=fallback), + functools.partial(_decode_dataclass, tree_type=tree_type)) + + return reg_fn + + def register_enum_type( + self, + tag: Union[None, str, Sequence[str]] = None + ) -> Callable[[Type[Tree]], Type[Tree]]: + """Registers an enum type. + + Args: + tag: Tags to register the type under. 
+ + Returns: + Registration decorator. + """ + + def reg_fn(tree_type: Type[Tree]) -> Type[Tree]: + return self.register_type( + tag, tree_type, + functools.partial(_encode_enum, tag=_get_first_tag(tag)), + functools.partial(_decode_enum, tree_type=tree_type)) + + return reg_fn + + def _maybe_register_flax(self): + """Registers Flax types if Flax is importable.""" + try: + # pytype: disable=import-error + import flax # pylint: disable=g-import-not-at-top + # pytype: enable=import-error + + self.register_mapping_type('flax_frozen_dict')( + flax.core.frozen_dict.FrozenDict) + except ImportError: + pass + + def _maybe_register_jax(self): + """Registers JAX types if JAX is importable.""" + + try: + # pytype: disable=import-error + import jax # pylint: disable=g-import-not-at-top + + # pytype: enable=import-error + + def detect_jax_array(tree: Any, ctx: Context) -> Optional[str]: + del ctx + if isinstance(tree, jax.Array): + return ARRAY + else: + return None + + self.register_detector('jax_array', detect_jax_array) + except ImportError: + pass + + def _maybe_register_tensorflow(self): + """Registers TensorFlow types if TensorFlow is importable.""" + + try: + # pytype: disable=import-error + import tensorflow as tf # pylint: disable=g-import-not-at-top + + # pytype: enable=import-error + + def detect_tf_tensor(tree: Any, ctx: Context) -> Optional[str]: + del ctx + if isinstance(tree, tf.Tensor): + return ARRAY + else: + return None + + self.register_detector('tensorflow_tensor', detect_tf_tensor) + except ImportError: + pass + + def _maybe_register_tfp(self): + """Registers TFP types if TFP is importable.""" + + structural_tuple = None + try: + # pytype: disable=import-error + from tensorflow_probability.python.internal import structural_tuple # pylint: disable=g-import-not-at-top + # pytype: enable=import-error + except ImportError: + pass + + if structural_tuple is None: + try: + # pytype: disable=import-error + import tensorflow_probability.substrates.jax as tfp # 
def _register_numpy(self):
  """Registers np.ndarray and np.generic handling.

  Arrays with fewer than 64 elements are stored inline in the tree under
  'val'. Larger arrays are handed to the context as blocks; only their first
  and last ten elements are kept in the tree ('head' and 'tail'). Instances of
  `np.generic` are detected and stored inline under the `SCALAR` tag.
  """

  def encode_array_fn(tree: Tree, ctx: Context,
                      encode_fn: InnerEncodeFn) -> EncodedTree:
    del encode_fn
    tree = np.asarray(tree)

    encoded = {}
    encoded[_TREE2_TYPE_TAG] = ARRAY
    encoded['dtype'] = np.dtype(tree.dtype).name
    encoded['shape'] = list(tree.shape)
    if np.size(tree) < 64:
      encoded['val'] = tree.tolist()
    else:
      # Flatten once; `reshape(-1)` avoids a copy whenever a view suffices.
      flat = tree.reshape(-1)
      encoded['head'] = flat[:10].tolist()
      encoded['tail'] = flat[-10:].tolist()
      encoded['block'] = ctx.add_array(tree)

    return encoded

  def decode_array_fn(encoded: Any, ctx: Context) -> Any:
    val = encoded.get('val')
    if val is None:
      array = ctx.get_array(encoded['block'], encoded['shape'],
                            np.dtype(encoded['dtype']))
    else:
      array = np.array(val).astype(encoded['dtype'])
      # Nested lists do not always determine the shape: an array of shape
      # (0, 3) is stored as `[]`. Restore the recorded shape explicitly.
      shape = encoded.get('shape')
      if shape is not None:
        array = array.reshape(shape)
    return array

  self.register_type(ARRAY, np.ndarray, encode_array_fn, decode_array_fn)

  def encode_scalar_fn(tree: Tree, ctx: Context,
                       encode_fn: InnerEncodeFn) -> EncodedTree:
    del ctx, encode_fn
    tree = np.asarray(tree)

    encoded = {}
    encoded[_TREE2_TYPE_TAG] = SCALAR
    encoded['dtype'] = np.dtype(tree.dtype).name
    encoded['val'] = tree.tolist()

    return encoded

  def decode_scalar_fn(encoded: Any, ctx: Context) -> Any:
    del ctx
    # Rebuild the numpy scalar type named by the stored dtype.
    return np.dtype(encoded['dtype']).type(encoded['val'])

  def detect_scalar(tree: Any, ctx: Context) -> Optional[str]:
    del ctx
    if isinstance(tree, np.generic):
      return SCALAR
    else:
      return None

  self.register_tags(SCALAR, encode_scalar_fn, decode_scalar_fn)
  self.register_detector(SCALAR, detect_scalar)
def _encode_mapping(tree: Type[Tree],
                    ctx: Context,
                    encode_fn: InnerEncodeFn,
                    tag: str,
                    fallback: Optional[str] = UNKNOWN_MAPPING) -> EncodedTree:
  """Encodes a mapping type.

  Args:
    tree: The mapping to encode.
    ctx: Serialization context (unused).
    encode_fn: Function used to encode the keys and values.
    tag: Tag to store as the type of the encoded mapping.
    fallback: Fallback type recorded in the encoding, or `None` to omit it.

  Returns:
    For a plain `dict` whose keys are all strings and which has no key equal
    to the reserved type tag: a dict of encoded values with no wrapper.
    Otherwise a wrapper dict holding the tag, the payload under 'val' and,
    if given, the fallback type. The payload is a dict when the string-key
    condition holds and a list of `[key, value]` pairs otherwise.
  """
  del ctx
  has_only_str_keys = all(isinstance(key, str) for key in tree)
  if has_only_str_keys and _TREE2_TYPE_TAG not in tree:
    # Fast path: the keys can be used directly as JSON object keys.
    payload = {key: encode_fn(value) for key, value in tree.items()}
    if type(tree) is dict:  # pylint: disable=unidiomatic-typecheck
      return payload
  else:
    payload = [[encode_fn(key), encode_fn(value)]
               for key, value in tree.items()]

  encoded = {_TREE2_TYPE_TAG: tag, 'val': payload}
  if fallback is not None:
    encoded['fallback_type'] = fallback
  return encoded
def _decode_unknown_namedtuple(encoded: Any, ctx: Context) -> Any:
  """Decodes a namedtuple whose type is not registered.

  A new namedtuple type is created on the fly. Its fields are the keys of
  `encoded['val']`; its name is derived from the stored type tag.

  Args:
    encoded: The encoded namedtuple.
    ctx: Serialization context (unused).

  Returns:
    An instance of the newly created namedtuple type.
  """
  del ctx
  import keyword  # pylint: disable=g-import-not-at-top

  tag = encoded[_TREE2_TYPE_TAG]
  warnings.warn(f'Decoding unknown namedtuple type: {tag}')

  # `collections.namedtuple` only accepts valid, non-keyword identifiers as
  # type names, but tags are arbitrary strings (e.g. 'pkg.MyTuple'). Tags that
  # are already valid identifiers are used unchanged.
  type_name = ''.join(c if c.isalnum() or c == '_' else '_' for c in str(tag))
  if not type_name.isidentifier() or keyword.iskeyword(type_name):
    type_name = f'_{type_name}'
  if not type_name.isidentifier():
    type_name = 'UnknownNamedTuple'

  tree_type = collections.namedtuple(type_name, list(encoded['val'].keys()))
  return tree_type(**encoded['val'])
def _detect_unknown_dataclass(tree: Any, ctx: Context) -> Optional[str]:
  """Detects instances of dataclasses.

  Args:
    tree: The tree to inspect.
    ctx: Serialization context (unused).

  Returns:
    `UNKNOWN_DATACLASS` if `tree` is a dataclass instance, `None` otherwise.
  """
  del ctx
  # `dataclasses.is_dataclass` is also true for the dataclass *type* itself.
  # Only instances carry field values, so class objects are not claimed here.
  if dataclasses.is_dataclass(tree) and not isinstance(tree, type):
    return UNKNOWN_DATACLASS
  return None
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tests for tree2."""

from collections.abc import Mapping, Sequence
import dataclasses
import enum
import os
from typing import Any, NamedTuple

from absl.testing import parameterized
import flax
import immutabledict
import jax
import jax.numpy as jnp
import numpy as np

from discussion.robust_inverse_graphics.util import test_util
from discussion.robust_inverse_graphics.util import tree2
import tensorflow_probability.substrates.jax as tfp

# Registry shared by most tests below; it is constructed with
# `allow_unknown_types=True` so the "unregistered" fixtures can be saved too.
global_registry = tree2.Registry(allow_unknown_types=True)


class UnregisteredNamedTuple(NamedTuple):
  """Namedtuple that is never registered with `global_registry`."""
  x: Any
  y: Any


@global_registry.auto_register_type('test.RegisteredNamedTuple')
class RegisteredNamedTuple(NamedTuple):
  """Namedtuple registered under the tag 'test.RegisteredNamedTuple'."""
  x: Any
  y: Any


@dataclasses.dataclass
class UnregisteredDataClass:
  """Dataclass that is never registered with `global_registry`."""
  x: Any
  y: Any


@global_registry.auto_register_type('test.RegisteredDataClass')
@dataclasses.dataclass
class RegisteredDataClass:
  """Dataclass registered under the tag 'test.RegisteredDataClass'."""
  x: Any
  y: Any


# IntEnum, so we can sort them.
@global_registry.auto_register_type('test.RegisteredEnum')
class RegisteredEnum(enum.IntEnum):
  """Enum registered under the tag 'test.RegisteredEnum'."""
  X = enum.auto()
  Y = enum.auto()


class UnregisteredSequence(Sequence):
  """Minimal `Sequence` backed by `values`; not registered anywhere."""

  def __init__(self, values):
    self._values = values

  def __getitem__(self, idx):
    return self._values[idx]

  def __len__(self):
    return len(self._values)


class UnregisteredMapping(Mapping):
  """Minimal `Mapping` backed by `values`; not registered anywhere."""

  def __init__(self, values):
    self._values = values

  def __getitem__(self, idx):
    return self._values[idx]

  def __len__(self):
    return len(self._values)

  def __iter__(self):
    return iter(self._values.keys())


# The V0/V1 pairs below stand in for two revisions of the same type: V1 lacks
# the `y` field that V0 has. See the `*_unknown_field` tests.
class NamedTupleV0(NamedTuple):
  """Namedtuple with fields `x` and `y`."""
  x: Any
  y: Any


class NamedTupleV1(NamedTuple):
  """Namedtuple with only field `x`."""
  x: Any


@dataclasses.dataclass
class DataClassV0:
  """Dataclass with fields `x` and `y`."""
  x: Any
  y: Any


@dataclasses.dataclass
class DataClassV1:
  """Dataclass with only field `x`."""
  x: Any


@dataclasses.dataclass
class DataClassMulti:
  """Dataclass registered under several tags in `test_multiple_tags`."""
  x: Any


def make_structtuple():
  """Returns a sample from a one-variable `JointDistributionCoroutine`.

  The sample is drawn with a fixed PRNG key (0). The 'structtuple' roundtrip
  case compares this value by type *name* only.
  """

  @tfp.distributions.JointDistributionCoroutine
  def model():
    yield tfp.distributions.Normal(0., 1., name='x')

  return model.sample(seed=jax.random.PRNGKey(0))


class TreeTest(test_util.TestCase):
  """Save/load tests for `tree2.Registry`."""

  # Each case is (name, tree[, need_numpy_lhs[, compare_type_names]]).
  # Callable `tree` entries are invoked inside the test body; presumably this
  # defers creating JAX/TFP values until the test runs -- TODO confirm.
  @parameterized.named_parameters(
      ('scalar', 0),
      ('string', 'abc'),
      ('list', [1, 2]),
      ('tuple', (1, 2)),
      ('dict', {
          'a': 1,
          'b': 2
      }),
      ('immutabledict', immutabledict.immutabledict({
          'a': 1,
          'b': 2
      })),
      ('flax_frozendict', flax.core.frozen_dict.FrozenDict({
          'a': 1,
          'b': 2
      })),
      ('dict_int_keys', {
          1: 2,
          3: 4,
      }),
      ('dict_enum_keys', {
          RegisteredEnum.X: 1,
          RegisteredEnum.Y: 2,
      }),
      ('dict_tuple_keys', {
          (1, 2): 1,
          (3, 4): 2,
      }),
      ('set', set([1, 2])),
      ('frozenset', frozenset([1, 2])),
      ('namedtuple', RegisteredNamedTuple(1, 2)),
      ('dataclass', RegisteredDataClass(1, 2)),
      ('numpy_array', np.arange(3)),
      ('structtuple', make_structtuple, True, True),
      ('jax_array', lambda: jnp.arange(3), True),
      ('jax_array_b16', lambda: jnp.array([1., 2.], jnp.bfloat16), True),
      ('nested',
       RegisteredNamedTuple(
           RegisteredNamedTuple([1, 2], {'a': 3}), np.zeros(100))),
      ('nested_dataclass',
       RegisteredDataClass(RegisteredDataClass([1, 2], {'a': 3}), 1)),
      (
          'enum',
          RegisteredEnum.X,
      ),
  )
  def test_roundtrip(self,
                     tree,
                     need_numpy_lhs=False,
                     compare_type_names=False):
    """Saves `tree`, loads it back and compares type and contents.

    Args:
      tree: The value to save, or a zero-argument callable producing it.
      need_numpy_lhs: If true, `np.asarray` is mapped over the input tree
        before it is compared against the loaded tree.
      compare_type_names: If true, only the type names of the input and the
        loaded tree are compared; otherwise the types must be identical.
    """
    if callable(tree):
      tree = tree()
    path = self.create_tempdir()
    tree_path = os.path.join(path, 'tree')

    global_registry.save_tree(tree, tree_path)
    out_tree = global_registry.load_tree(tree_path)

    if need_numpy_lhs:
      tree = jax.tree.map(np.asarray, tree)
    if compare_type_names:
      self.assertEqual(type(tree).__name__, type(out_tree).__name__)
    else:
      self.assertIs(type(tree), type(out_tree))
    self.assertAllEqualNested(tree, out_tree)

  def test_deferred(self):
    """Loads with `defer_numpy` and checks the deferred array wrapper."""
    tree = {
        'a': np.arange(10),
        'b': np.arange(1000).reshape([20, 50]).astype(np.float32)
    }
    path = self.create_tempdir()
    tree_path = os.path.join(path, 'tree')

    global_registry.save_tree(tree, tree_path)
    out_tree = global_registry.load_tree(tree_path, {'defer_numpy': True})
    # NOTE(review): the small array comes back as a plain ndarray while the
    # larger one is deferred; presumably tree2 decides by size -- confirm
    # against tree2.py.
    self.assertIsInstance(out_tree['a'], np.ndarray)
    self.assertIsInstance(out_tree['b'], tree2.DeferredNumpyArray)
    self.assertEqual(out_tree['b'].dtype, np.float32)
    self.assertEqual(out_tree['b'].shape, (20, 50))

    # Converting the deferred wrapper with np.array yields the saved values.
    array = np.array(out_tree['b'])
    self.assertEqual(array[19, 49], 999)

  def test_unregistered_sequence(self):
    """An unregistered `Sequence` is expected to load back as a `list`."""
    path = self.create_tempdir()
    tree_path = os.path.join(path, 'tree')
    tree = UnregisteredSequence([1, 2])

    global_registry.save_tree(tree, tree_path)
    out_tree = global_registry.load_tree(tree_path)

    self.assertIsInstance(out_tree, list)
    self.assertEqual(out_tree, list(tree))

  def test_unregistered_mapping(self):
    """An unregistered `Mapping` is expected to load back as a `dict`."""
    path = self.create_tempdir()
    tree_path = os.path.join(path, 'tree')
    tree = UnregisteredMapping({'x': 1, 'y': 2})

    global_registry.save_tree(tree, tree_path)
    out_tree = global_registry.load_tree(tree_path)

    self.assertIsInstance(out_tree, dict)
    self.assertEqual(out_tree, dict(tree))

  def test_unregistered_namedtuple(self):
    """An unregistered namedtuple roundtrips to an equal nested value."""
    path = self.create_tempdir()
    tree_path = os.path.join(path, 'tree')
    tree = UnregisteredNamedTuple(x=1, y=2)

    global_registry.save_tree(tree, tree_path)
    out_tree = global_registry.load_tree(tree_path)

    self.assertAllEqualNested(out_tree, tree)

  def test_unregistered_dataclass(self):
    """An unregistered dataclass roundtrips with the same field values."""
    path = self.create_tempdir()
    tree_path = os.path.join(path, 'tree')
    tree = UnregisteredDataClass(x=1, y=2)

    global_registry.save_tree(tree, tree_path)
    out_tree = global_registry.load_tree(tree_path)

    # Compared via asdict rather than directly: only the field values are
    # checked here, not the identity of the loaded class.
    self.assertAllEqualNested(
        dataclasses.asdict(out_tree), dataclasses.asdict(tree))

  def test_namedtuple_unknown_field(self):
    """Loading a namedtuple saved with an extra field keeps the known ones."""
    path = self.create_tempdir()
    tree_path = os.path.join(path, 'tree')

    reg1 = tree2.Registry()
    reg1.register_namedtuple_type('test.namedtuple')(NamedTupleV0)

    tree = NamedTupleV0(x=1, y=2)
    reg1.save_tree(tree, tree_path)

    # Simulate changing the definition of the type and loading an old serialized
    # copy.
    reg2 = tree2.Registry()
    reg2.register_namedtuple_type('test.namedtuple')(NamedTupleV1)
    out_tree = reg2.load_tree(tree_path)

    self.assertAllEqual({'x': 1}, out_tree._asdict())

  def test_dataclass_unknown_field(self):
    """Loading a dataclass saved with an extra field keeps the known ones."""
    path = self.create_tempdir()
    tree_path = os.path.join(path, 'tree')

    reg1 = tree2.Registry()
    reg1.register_dataclass_type('test.dataclass')(DataClassV0)

    tree = DataClassV0(x=1, y=2)
    reg1.save_tree(tree, tree_path)

    # Simulate changing the definition of the type and loading an old serialized
    # copy.
    reg2 = tree2.Registry()
    reg2.register_dataclass_type('test.dataclass')(DataClassV1)
    out_tree = reg2.load_tree(tree_path)

    self.assertAllEqual({'x': 1}, dataclasses.asdict(out_tree))

  def test_multiple_tags(self):
    """A type registered under [new, old] tags loads old and saves new."""
    path = self.create_tempdir()
    tree_path = os.path.join(path, 'tree')

    reg1 = tree2.Registry()
    reg1.register_dataclass_type('test.multi_old')(DataClassMulti)

    tree = DataClassMulti(x=1)
    reg1.save_tree(tree, tree_path)

    # Simulate changing the tag, but keeping backwards loading compatibility.
    reg2 = tree2.Registry()
    reg2.register_dataclass_type(['test.multi_new', 'test.multi_old'])(
        DataClassMulti)
    # NOTE(review): this load's result is overwritten below without being
    # inspected; it only checks that loading the old tag does not raise.
    out_tree = reg2.load_tree(tree_path)
    reg2.save_tree(tree, tree_path)

    # Verify that the tree got saved with `multi_new` tag.
    reg3 = tree2.Registry()
    reg3.register_dataclass_type('test.multi_new')(DataClassMulti)
    out_tree = reg3.load_tree(tree_path)

    self.assertAllEqual({'x': 1}, dataclasses.asdict(out_tree))

  def test_interactive_mode(self):
    """Re-registering a tag raises unless `interactive_mode` is enabled."""
    reg1 = tree2.Registry()
    reg1.auto_register_type('test.UnregisteredNamedTuple')(
        UnregisteredNamedTuple
    )
    with self.assertRaisesRegex(TypeError, 'already registered'):
      reg1.auto_register_type('test.UnregisteredNamedTuple')(
          UnregisteredNamedTuple
      )
    # With interactive mode on, the same registration must not raise.
    reg1.interactive_mode = True
    reg1.auto_register_type('test.UnregisteredNamedTuple')(
        UnregisteredNamedTuple
    )

if __name__ == '__main__':
  test_util.main()

# ---- patch scaffolding retained from the collapsed diff (next file) ----
# diff --git a/discussion/robust_inverse_graphics/util/tree_util.py b/discussion/robust_inverse_graphics/util/tree_util.py
# new file mode 100644
# index 0000000000..06497712bb
# --- /dev/null
# +++ b/discussion/robust_inverse_graphics/util/tree_util.py
# @@ -0,0 +1,144 @@
# Copyright 2024 The TensorFlow Probability Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tree utilities."""

from collections.abc import Callable, Sequence
import dataclasses
from typing import Any, Generic, TypeVar

import jax

__all__ = [
    'DataclassView',
    'get_element',
    'update_element',
]

T = TypeVar('T')


@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass
class DataclassView(Generic[T]):
  """Allows selecting which fields of a dataclass are visible to jax.tree_util.

  ```python
  e = Example(a=1, b=2)
  v_e = DataclassView(e, lambda n: n == 'a')
  v_e = jax.tree.map(lambda x: x + 1, v_e)
  assert Example(a=2, b=2) == v_e.value
  ```

  Attributes:
    value: The dataclass.
    field_selector_fn: A callback that determines which fields are selected.
  """

  value: T
  field_selector_fn: Callable[[str], bool]

  def __post_init__(self):
    """Rejects values that are not dataclasses.

    Raises:
      TypeError: If `value` is not a dataclass.
    """
    if not dataclasses.is_dataclass(self.value):
      raise TypeError(f'class_tree must be a dataclass: {self.value}.')

  # XXX(siege): This is very improper, since the contract for tree_util is that
  # the aux_data is hashable.
  def tree_flatten(self) -> tuple[list[Any], 'DataclassView[T]']:
    """Flattens the view into its selected field values.

    Returns:
      A `(children, aux_data)` pair. `children` holds the values of the fields
      accepted by `field_selector_fn`, in dataclass field order. `aux_data` is
      this view itself, which carries the unselected fields and the selector.
    """
    selected_fields = [
        getattr(self.value, f.name)
        for f in dataclasses.fields(self.value)
        if self.field_selector_fn(f.name)
    ]

    return selected_fields, self

  @classmethod
  def tree_unflatten(
      cls, aux_data: 'DataclassView[T]', children: list[Any]
  ) -> 'DataclassView[T]':
    """Rebuilds a view, substituting `children` for the selected fields.

    Args:
      aux_data: The view returned as aux data by `tree_flatten`.
      children: New values for the selected fields, in dataclass field order.

    Returns:
      A new view over a copy of `aux_data.value` with the selected fields
      replaced; unselected fields are carried over unchanged.
    """
    selected_field_names = [
        f.name
        for f in dataclasses.fields(aux_data.value)
        if aux_data.field_selector_fn(f.name)
    ]
    selected_fields = dict(zip(selected_field_names, children))
    # Use dataclasses.replace rather than a `.replace` method: only flax-style
    # dataclasses define that method, while __post_init__ accepts any
    # dataclass.
    return cls(
        dataclasses.replace(aux_data.value, **selected_fields),
        aux_data.field_selector_fn,
    )


def _handle_element(
    tree: Any,
    path: Sequence[Any],
    leaf_fn: Callable[[Any], Any],
    subtree_fn: Callable[[Any, Any, Any], Any],
) -> Any:
  """Implementation of get/update_element.

  Walks `tree` along `path`, applies `leaf_fn` to the element at the end of
  the path and then combines results on the way back up with `subtree_fn`.

  Args:
    tree: A nesting of lists, tuples, dicts, dataclasses and namedtuples.
    path: Keys to follow: indices for lists/tuples, keys for dicts and
      attribute names for dataclasses/namedtuples.
    leaf_fn: Called with the element found at the end of `path`.
    subtree_fn: Called as `subtree_fn(tree, res, cur)` for every container on
      the path, where `res` is the result for the child reached via `cur`.

  Returns:
    `leaf_fn(tree)` for an empty path, otherwise the outermost `subtree_fn`
    result.

  Raises:
    TypeError: If a container on the path is of an unsupported type.
  """
  if not path:
    return leaf_fn(tree)

  cur, *rest = path

  if isinstance(tree, list):
    subtree = tree[cur]
  elif isinstance(tree, tuple) and not hasattr(tree, '_fields'):
    subtree = tree[cur]
  elif isinstance(tree, dict):
    subtree = tree[cur]
  elif dataclasses.is_dataclass(tree):
    subtree = getattr(tree, cur)
  elif isinstance(cur, str) and hasattr(tree, cur):
    # Namedtuple. The isinstance guard keeps a non-string key from making
    # hasattr raise an unrelated "attribute name must be string" TypeError.
    subtree = getattr(tree, cur)
  else:
    raise TypeError(f'Cannot handle type: {type(tree)}')

  res = _handle_element(subtree, rest, leaf_fn, subtree_fn)
  return subtree_fn(tree, res, cur)


def _update_subtree(tree: Any, res: Any, cur: Any) -> Any:
  """Helper for update_element.

  Args:
    tree: The container to copy.
    res: The new value for the child.
    cur: The index, key or attribute name of the child within `tree`.

  Returns:
    A shallow copy of `tree` with the child at `cur` replaced by `res`.

  Raises:
    TypeError: If `tree` is of an unsupported type.
  """
  if isinstance(tree, list):
    new_tree = list(tree)
    new_tree[cur] = res
    return new_tree
  elif isinstance(tree, tuple) and not hasattr(tree, '_fields'):
    new_tree = list(tree)
    new_tree[cur] = res
    return tuple(new_tree)
  elif isinstance(tree, dict):
    new_tree = tree.copy()
    new_tree[cur] = res
    return new_tree
  elif dataclasses.is_dataclass(tree):
    return dataclasses.replace(tree, **{cur: res})
  elif isinstance(cur, str) and hasattr(tree, cur):
    # Namedtuple
    return tree._replace(**{cur: res})
  else:
    raise TypeError(f'Cannot handle type: {type(tree)}')


def get_element(tree: Any, path: Sequence[Any]) -> Any:
  """Returns an element from a tree given by its path."""
  return _handle_element(tree, path, lambda x: x, lambda _tree, res, _cur: res)


def update_element(
    tree: Any, path: Sequence[Any], update_fn: Callable[[Any], Any]
) -> Any:
  """Updates an element from a tree given its path and returns a new tree."""
  return _handle_element(tree, path, update_fn, _update_subtree)

# ---- patch scaffolding retained from the collapsed diff (next file) ----
# diff --git a/discussion/robust_inverse_graphics/util/tree_util_test.py b/discussion/robust_inverse_graphics/util/tree_util_test.py
# new file mode 100644
# index 0000000000..3daf31456e
# --- /dev/null
# +++ b/discussion/robust_inverse_graphics/util/tree_util_test.py
# @@ -0,0 +1,62 @@
# Copyright 2024 The TensorFlow Probability Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tests for tree_util."""

from typing import Any, NamedTuple

from flax import struct
import jax
import jax.numpy as jnp

from discussion.robust_inverse_graphics.util import test_util
from discussion.robust_inverse_graphics.util import tree_util


@struct.dataclass
class Example:
  """Two-field flax dataclass used as a tree node in the tests."""
  a: Any
  b: Any


class Example2(NamedTuple):
  """Single-field namedtuple used as a tree node in the tests."""
  x: Any


class TreeUtilTest(test_util.TestCase):
  """Tests for `DataclassView`, `get_element` and `update_element`."""

  def test_dataclass_view(self):
    """Mapping over a view only changes the selected field `a`."""
    view = tree_util.DataclassView(
        Example(a=1, b=2), lambda field_name: field_name == 'a'
    )
    incremented = jax.tree.map(lambda leaf: leaf + 1, view)
    self.assertAllEqual(Example(a=2, b=2), incremented.value)

  def test_get_element(self):
    """Paths through dataclass, list, namedtuple and tuple are resolved."""
    nested = Example(a=[0, [1, 2]], b=Example2(x=(3,)))
    expectations = (
        (0, ['a', 0]),
        ([1, 2], ['a', 1]),
        (Example2(x=(3,)), ['b']),
        (3, ['b', 'x', 0]),
    )

    for expected, path in expectations:
      self.assertEqual(expected, tree_util.get_element(nested, path))

  def test_update_element(self):
    """Updating via a path returns a tree with the new element in place."""

    def add_one(leaf):
      return leaf + 1

    original = Example(a=[0, jnp.array([1, 2])], b=Example2(x=(3,)))

    array_bumped = tree_util.update_element(original, ['a', 1], add_one)
    self.assertAllClose(jnp.array([2, 3]), array_bumped.a[1])

    tuple_bumped = tree_util.update_element(original, ['b', 'x', 0], add_one)
    self.assertEqual((4,), tuple_bumped.b.x)


if __name__ == '__main__':
  test_util.main()